#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>
#include "crowd.h"
#include "infer.h"
#include "policy.h"

using namespace std;

/// shared functions ///

/// Index of the smallest element of `data` among the positions whose
/// `ignore` flag is false. Ties go to the lowest index.
///
/// @param data   values to minimise over
/// @param ignore same length as `data`; `true` excludes that position
/// @return the selected index, or `data.size()` when every position is
///         ignored (including the empty case) -- callers must not use that
///         sentinel as an index.
template<typename T> static unsigned int shared_arg_min(std::vector<T>& data, std::vector<bool>& ignore)
{
    assert(data.size() == ignore.size());

    // first admissible position; the bound is tested BEFORE indexing so an
    // all-ignored (or empty) input never reads past the end of `ignore`
    unsigned int i = 0;
    while(i < ignore.size() && ignore[i]) ++i;
    unsigned int m = i;

    // scan the rest, keeping the earliest strict minimum
    for(++i; i < data.size(); ++i)
        if(!ignore[i] && data[i] < data[m])
            m = i;
    return m;
}

double shared_kl_divergence(vector<double>& old_odds, vector<double>& new_odds)
{
    double tot_kl = 0.0;
    for(unsigned int i = 0; i < old_odds.size(); ++i) {
        double old_p = math_expit(old_odds[i]);
        double new_p = math_expit(new_odds[i]);
        tot_kl += new_p * log(new_p / old_p);
        tot_kl += (1.0 - new_p) * log((1.0 - new_p) / (1.0 - old_p));
    }
    return tot_kl;
}

/// uniform allocation ///

policy_uniform::policy_uniform():
    t_next(0)
{}

/// Uniform allocation: give the worker of data point t the task with the
/// fewest labels so far among those it has not labelled yet, store the choice
/// in c->data_seq[t].task_id and count it in task_count.
///
/// Must be called for consecutive t = 0, 1, 2, ...; the chosen worker must
/// still have at least one unlabelled task.
void policy_uniform::choose_task(crowd* c, infer* inf, unsigned int t)
{
    // avoid skipping any steps
    assert(t == t_next);
    ++t_next;

    // add new tasks on the fly
    if(inf->task_odds.size() > task_count.size()) {
        unsigned int n_tasks = inf->task_odds.size();
        task_ignore.resize(n_tasks, false);
        task_count.resize(n_tasks, 0);
    }

    // unpack the worker's id
    assert(t < c->data_seq.size());
    unsigned int j = c->data_seq[t].work_id;

    // skip the tasks this worker has already seen
    fill(task_ignore.begin(), task_ignore.end(), false);
    if(j < inf->work_hist.size())
        for(unsigned int k: inf->work_hist[j].time_steps) {
            assert(k < c->data_seq.size());
            task_ignore[c->data_seq[k].task_id] = true;
        }

    // pick the task with the least number of labels; shared_arg_min returns
    // task_count.size() when every task is ignored, and incrementing that
    // slot would write past the end of task_count
    unsigned int i = shared_arg_min(task_count, task_ignore);
    assert(i < task_count.size());
    c->data_seq[t].task_id = i;
    ++task_count[i];
}

void policy_uniform_test()
{
    policy_uniform poly;

    // simple crowd and inference
    double work_prior = 0.7;
    double p = 0.8;
    int quota = 5;
    unsigned int n_tasks = 10;
    unsigned int n_workers = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd_budget c(d_work_acc, d_work_quota, n_workers);
    infer_majority maj = infer_majority(n_tasks, work_prior);

    // generate and allocate all the data
    mt19937 ran_gen = well_seeded_mt19937(666);
    unsigned int t_final = n_workers * quota - 1;
    for(unsigned t = 0; t <= t_final; ++t) {
        c.create_data_point(ran_gen, t);
        poly.choose_task(&c, &maj, t);
        maj.infer_update(ran_gen, &c, t);
    }

    // check that the counts match reality
    vector<unsigned int> task_count_final(n_tasks);
    for(unsigned int j = 0; j < n_workers; ++j)
        for(unsigned int k: maj.work_hist[j].time_steps)
            ++task_count_final[c.data_seq[k].task_id];
    for(unsigned int i = 0; i < n_tasks; ++i)
        assert(poly.task_count[i] == task_count_final[i]);

    // stochastic test: roughly the same number of labels per task
    unsigned int r = n_workers * quota / n_tasks;
    for(unsigned int i = 0; i < n_tasks; ++i)
        assert(poly.task_count[i] >= r - 1);
}

/// weight balance ///

policy_balance::policy_balance():
    t_next(0)
{}

/// Weight-balance allocation. Each earlier label adds its worker's weight
/// -log(4 p (1 - p)) to the task it labelled, where p is the inference's
/// accuracy estimate. The worker of data point t then receives the task with
/// the smallest accumulated weight among those it has not labelled yet. The
/// choice is stored in c->data_seq[t].task_id.
///
/// Must be called for consecutive t = 0, 1, 2, ...
void policy_balance::choose_task(crowd* c, infer* inf, unsigned int t)
{
    // avoid skipping any steps
    assert(t == t_next);
    ++t_next;

    // add new tasks on the fly
    if(inf->task_odds.size() > task_bound.size()) {
        unsigned int n_tasks = inf->task_odds.size();
        task_ignore.resize(n_tasks, false);
        task_bound.resize(n_tasks, 0);
    }

    // unpack the worker's id
    assert(t < c->data_seq.size());
    unsigned int j = c->data_seq[t].work_id;

    // skip the tasks this worker has already seen
    fill(task_ignore.begin(), task_ignore.end(), false);
    if(j < inf->work_hist.size())
        for(unsigned int k: inf->work_hist[j].time_steps) {
            assert(k < c->data_seq.size());
            task_ignore[c->data_seq[k].task_id] = true;
        }

    // add new workers on the fly; grow by the same vector the loop below
    // iterates over (the old check looked at c->work_acc instead, so the
    // recompute could write past the end of work_weight)
    if(inf->work_estim.size() > work_weight.size())
        work_weight.resize(inf->work_estim.size(), 0.5);

    // recompute the worker weights
    for(unsigned int k = 0; k < inf->work_estim.size(); ++k)
        work_weight[k] = -log(4.0 * inf->work_estim[k] * (1 - inf->work_estim[k]));

    // compute the bound over all labels collected before step t
    fill(task_bound.begin(), task_bound.end(), 0);
    for(unsigned int k = 0; k < t; ++k) {
        assert(c->data_seq[k].work_id < work_weight.size());
        task_bound[c->data_seq[k].task_id] += work_weight[c->data_seq[k].work_id];
    }

    // pick the task with the smallest bound
    unsigned int i = shared_arg_min(task_bound, task_ignore);
    c->data_seq[t].task_id = i;
}

/// Self-test for policy_balance: after spending the whole budget, the per-task
/// accumulated weights should deviate from their mean by less than the weight
/// of one label from a worker of accuracy p.
void policy_balance_test()
{
    policy_balance poly;

    // identical workers with a fixed quota, aggregated by majority vote
    double work_prior = 0.7;
    double p = 0.8;
    int quota = 5;
    unsigned int n_tasks = 10;
    unsigned int n_workers = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd_budget c(d_work_acc, d_work_quota, n_workers);
    infer_majority maj(n_tasks, work_prior);

    // spend the whole budget: create, allocate, infer -- one step at a time
    mt19937 ran_gen = well_seeded_mt19937(777);
    unsigned int n_steps = n_workers * quota;
    for(unsigned int t = 0; t < n_steps; ++t) {
        c.create_data_point(ran_gen, t);
        poly.choose_task(&c, &maj, t);
        maj.infer_update(ran_gen, &c, t);
    }

    // mean accumulated weight per task
    double mean_bound = 0.0;
    for(unsigned int i = 0; i < n_tasks; ++i)
        mean_bound += poly.task_bound[i];
    mean_bound /= (double) n_tasks;

    // stochastic check: mean absolute deviation stays within one label's weight
    double mean_dev = 0.0;
    for(unsigned int i = 0; i < n_tasks; ++i) {
        double gap = mean_bound - poly.task_bound[i];
        mean_dev += (gap >= 0.0)? gap: -gap;
    }
    mean_dev /= (double) n_tasks;
    assert(mean_dev <= -log(4.0 * p * (1 - p)));
}

/// uncertainty sampling ///

policy_uncertainty::policy_uncertainty():
    t_next(0)
{}

/// Uncertainty sampling: give the worker of data point t the task whose
/// posterior log odds are closest to zero (smallest |log odds|) among the
/// tasks it has not labelled yet. The choice is written to
/// c->data_seq[t].task_id.
///
/// Must be called for consecutive t = 0, 1, 2, ...
void policy_uncertainty::choose_task(crowd* c, infer* inf, unsigned int t)
{
    // the policy must be driven one step at a time, in order
    assert(t == t_next);
    t_next += 1;

    // grow the per-task buffers whenever the inference reports more tasks
    if(task_ignore.size() < inf->task_odds.size()) {
        const unsigned int n_known = inf->task_odds.size();
        task_abs_odds.resize(n_known, 0);
        task_ignore.resize(n_known, false);
    }

    // who is asking for a task
    assert(t < c->data_seq.size());
    const unsigned int worker = c->data_seq[t].work_id;

    // mask every task this worker has labelled before
    fill(task_ignore.begin(), task_ignore.end(), false);
    if(worker < inf->work_hist.size()) {
        for(unsigned int step: inf->work_hist[worker].time_steps) {
            assert(step < c->data_seq.size());
            task_ignore[c->data_seq[step].task_id] = true;
        }
    }

    // distance of each posterior from a coin flip, in log-odds units
    fill(task_abs_odds.begin(), task_abs_odds.end(), 0);
    for(unsigned int k = 0; k < inf->task_odds.size(); ++k) {
        const auto odds = inf->task_odds[k];
        task_abs_odds[k] = (odds >= 0.0)? odds: -odds;
    }

    // query the most uncertain admissible task
    c->data_seq[t].task_id = shared_arg_min(task_abs_odds, task_ignore);
}

/// Self-test for policy_uncertainty: the cached |log odds| must match the final
/// posterior, and the uncertainty should end up spread roughly evenly.
void policy_uncertainty_test()
{
    policy_uncertainty poly;

    // identical workers with a fixed quota, aggregated by majority vote
    double work_prior = 0.7;
    double p = 0.8;
    int quota = 5;
    unsigned int n_tasks = 10;
    unsigned int n_workers = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd_budget c(d_work_acc, d_work_quota, n_workers);
    infer_majority maj(n_tasks, work_prior);

    // spend the whole budget: create, allocate, infer -- one step at a time
    mt19937 ran_gen = well_seeded_mt19937(777);
    unsigned int n_steps = n_workers * quota;
    for(unsigned int t = 0; t < n_steps; ++t) {
        c.create_data_point(ran_gen, t);
        poly.choose_task(&c, &maj, t);
        maj.infer_update(ran_gen, &c, t);
    }

    // cached values were computed before the last update, so the task chosen
    // at the final step is the only one whose odds may differ now
    auto last_task = c.data_seq[n_steps - 1].task_id;
    for(unsigned int i = 0; i < n_tasks; ++i) {
        if(i == last_task)
            continue;
        double odds = maj.task_odds[i];
        double abs_odds = (odds >= 0.0)? odds: -odds;
        assert(poly.task_abs_odds[i] == abs_odds);
    }

    // mean |log odds| over all tasks
    double mean_abs_odds = 0.0;
    for(unsigned int i = 0; i < n_tasks; ++i)
        mean_abs_odds += poly.task_abs_odds[i];
    mean_abs_odds /= (double) n_tasks;

    // stochastic check: mean absolute deviation within one majority-vote weight
    double mean_dev = 0.0;
    for(unsigned int i = 0; i < n_tasks; ++i) {
        double gap = mean_abs_odds - poly.task_abs_odds[i];
        mean_dev += (gap >= 0.0)? gap: -gap;
    }
    mean_dev /= (double) n_tasks;
    assert(mean_dev <= maj.work_weight);
}

/// zero-one loss reduction ///

policy_los::policy_los():
    t_next(0)
{}

/// Zero-one loss reduction: for every task, compute the expected drop in the
/// Bayes zero-one risk expit(-|log odds|) if the worker of data point t
/// labelled it. The risk is averaged over the two possible labels using the
/// worker's estimated accuracy (0.5 if unknown). task_los stores the negated
/// drop, and the admissible task with the smallest value (the largest drop)
/// is written to c->data_seq[t].task_id.
///
/// Must be called for consecutive t = 0, 1, 2, ...
void policy_los::choose_task(crowd* c, infer* inf, unsigned int t)
{
    // the policy must be driven one step at a time, in order
    assert(t == t_next);
    t_next += 1;

    // grow the per-task buffers whenever the inference reports more tasks
    if(task_los.size() < inf->task_odds.size()) {
        const unsigned int n_known = inf->task_odds.size();
        task_los.resize(n_known, 0);
        task_ignore.resize(n_known, false);
    }

    // who is asking for a task
    assert(t < c->data_seq.size());
    const unsigned int worker = c->data_seq[t].work_id;

    // mask every task this worker has labelled before
    fill(task_ignore.begin(), task_ignore.end(), false);
    if(worker < inf->work_hist.size()) {
        for(unsigned int step: inf->work_hist[worker].time_steps) {
            assert(step < c->data_seq.size());
            task_ignore[c->data_seq[step].task_id] = true;
        }
    }

    // estimated accuracy of this worker and the log-odds shift one label applies
    const double work_p = (worker < inf->work_estim.size())? inf->work_estim[worker]: 0.5;
    const double work_w = math_logit_safe(work_p);

    // Bayes zero-one risk of a binary posterior with the given log odds
    auto zero_one_risk = [](double odds) {
        const double magnitude = (odds >= 0.0)? odds: -odds;
        return math_expit(-magnitude);
    };

    for(unsigned int i = 0; i < task_los.size(); ++i) {
        const double odds = inf->task_odds[i];

        // predictive probability of each label
        const double task_post = math_expit(odds);
        const double p_pos = work_p * task_post + (1.0 - work_p) * (1.0 - task_post);
        const double p_neg = (1.0 - work_p) * task_post + work_p * (1.0 - task_post);

        // current risk minus the expected risk after the label
        const double gain = zero_one_risk(odds) - p_pos * zero_one_risk(odds + work_w) - p_neg * zero_one_risk(odds - work_w);
        task_los[i] = -gain;
    }

    // largest expected reduction == smallest negated value
    c->data_seq[t].task_id = shared_arg_min(task_los, task_ignore);
}

/// Self-test for policy_los: after spending the whole budget, every stored
/// (negated) expected risk reduction must be non-positive up to rounding.
void policy_los_test()
{
    policy_los poly;

    // identical workers with a fixed quota, aggregated by majority vote
    double work_prior = 0.7;
    double p = 0.8;
    int quota = 5;
    unsigned int n_tasks = 10;
    unsigned int n_workers = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd_budget c(d_work_acc, d_work_quota, n_workers);
    infer_majority maj(n_tasks, work_prior);

    // spend the whole budget: create, allocate, infer -- one step at a time
    mt19937 ran_gen = well_seeded_mt19937(777);
    unsigned int n_steps = n_workers * quota;
    for(unsigned int t = 0; t < n_steps; ++t) {
        c.create_data_point(ran_gen, t);
        poly.choose_task(&c, &maj, t);
        maj.infer_update(ran_gen, &c, t);
    }

    // one more label should never increase the expected zero-one risk, so
    // each negated reduction is <= 0 (with a small tolerance for rounding)
    for(unsigned int i = 0; i < n_tasks; ++i)
        assert(poly.task_los[i] <= 0.0000000001);
}

/// expected information gain ///

policy_eig::policy_eig(unsigned int ran_seed):
    t_next(0)
{
    ran_gen = well_seeded_mt19937(ran_seed);
}

/// Expected information gain allocation.
///
/// For every task the worker of data point t has not labelled yet, data
/// point t is pointed at that task and labelled +1 and then -1. Each labelled
/// version is fed to a clone of the inference. Each simulated update is scored
/// by the KL divergence it induces on the task posteriors. The two scores are
/// averaged with the predictive label probabilities, and task_eig stores the
/// negated expectation. Data point t is restored afterwards, and the
/// admissible task with the largest expected gain is written to
/// c->data_seq[t].task_id.
///
/// Must be called for consecutive t = 0, 1, 2, ...
void policy_eig::choose_task(crowd* c, infer* inf, unsigned int t)
{
    // avoid skipping any steps
    assert(t == t_next);
    ++t_next;

    // add new tasks on the fly
    if(inf->task_odds.size() > task_eig.size()) {
        unsigned int n_tasks = inf->task_odds.size();
        task_ignore.resize(n_tasks, false);
        task_eig.resize(n_tasks, 0);
    }

    // unpack the worker's id
    assert(t < c->data_seq.size());
    unsigned int j = c->data_seq[t].work_id;

    // skip the tasks this worker has already seen
    fill(task_ignore.begin(), task_ignore.end(), false);
    if(j < inf->work_hist.size())
        for(unsigned int k: inf->work_hist[j].time_steps) {
            assert(k < c->data_seq.size());
            task_ignore[c->data_seq[k].task_id] = true;
        }

    // data point t is used as scratch space below; keep the original
    data_point tmp = c->data_seq[t];

    // information gained by a speculative update in which data point t (task
    // already set) carries `label`; the clone is owned by a unique_ptr so it
    // is released even if infer_update throws
    auto info_gain = [&](double label) {
        c->data_seq[t].label = label;
        std::unique_ptr<infer> inf_next(inf->clone());
        inf_next->infer_update(ran_gen, c, t);
        return shared_kl_divergence(inf->task_odds, inf_next->task_odds);
    };

    // compute the negative expected information gain
    double work_p = (inf->work_estim.size() > j)? inf->work_estim[j]: 0.5;
    for(unsigned int i = 0; i < task_eig.size(); ++i) {
        if(task_ignore[i])
            continue;
        c->data_seq[t].task_id = i;
        double ig_pos = info_gain(+1.0);
        double ig_neg = info_gain(-1.0);

        // expectation under the predictive label distribution
        double task_post = math_expit(inf->task_odds[i]);
        double pos_label = work_p * task_post + (1.0 - work_p) * (1.0 - task_post);
        double neg_label = (1.0 - work_p) * task_post + work_p * (1.0 - task_post);
        task_eig[i] = -(pos_label * ig_pos + neg_label * ig_neg);
    }

    // restore the crowd status
    c->data_seq[t] = tmp;

    // pick the task with the smallest negative expected information gain
    unsigned int i = shared_arg_min(task_eig, task_ignore);
    c->data_seq[t].task_id = i;
}

/// Self-test for policy_eig on a small simulated crowd with variational
/// inference: every stored entry is minus an expected KL divergence, hence
/// must be non-positive up to rounding.
void policy_eig_test()
{
    policy_eig poly(0);

    // identical workers of accuracy p with a fixed quota
    double p = 0.8;
    int quota = 5;
    unsigned int n_tasks = 10;
    unsigned int n_workers = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd_budget c(d_work_acc, d_work_quota, n_workers);
    infer_variational var(n_tasks, 3.0, 2.0, 5, 5);

    // spend the whole budget: create, allocate, infer -- one step at a time
    mt19937 ran_gen = well_seeded_mt19937(777);
    unsigned int n_steps = n_workers * quota;
    for(unsigned int t = 0; t < n_steps; ++t) {
        c.create_data_point(ran_gen, t);
        poly.choose_task(&c, &var, t);
        var.infer_update(ran_gen, &c, t);
    }

    // KL divergences are non-negative, so their negated expectation is <= 0
    for(unsigned int i = 0; i < n_tasks; ++i)
        assert(poly.task_eig[i] <= 0.0000000001);
}

/// parser ///

/// Parse a policy name from the front of an argument list.
///
/// On success, the name is consumed (*argc decremented, *argv advanced) and a
/// newly allocated policy is returned; the caller owns it and must delete it.
/// On failure, a message goes to stderr, the arguments are left untouched, and
/// NULL is returned.
policy* policy_parse(int *argc, char **argv[])
{
    if(*argc < 1) {
        cerr << "Not enough arguments to parse the policy type" << endl;
        return NULL;
    }

    // recognised names and a factory for each
    struct policy_entry {
        const char* name;
        policy* (*make)();
    };
    static const policy_entry known[] = {
        {"policy_uniform",     []() { return (policy*) new policy_uniform(); }},
        {"policy_balance",     []() { return (policy*) new policy_balance(); }},
        {"policy_uncertainty", []() { return (policy*) new policy_uncertainty(); }},
        {"policy_los",         []() { return (policy*) new policy_los(); }},
        {"policy_eig",         []() { return (policy*) new policy_eig(0); }},
    };

    const char* requested = (*argv)[0];
    for(const policy_entry& entry: known) {
        if(strcmp(requested, entry.name) != 0)
            continue;
        // consume the name, then build the policy
        --(*argc);
        ++(*argv);
        return entry.make();
    }

    cerr << "Unable to parse the policy type" << endl;
    return NULL;
}

/// Self-test for policy_parse: each recognised name must yield a fresh, empty
/// policy of the matching type and consume exactly one argument.
void policy_parse_test()
{
    // each list: a policy name followed by arguments left for later parsers
    const char *args_uniform[]     = {"policy_uniform", "no", "big"};
    const char *args_uncertainty[] = {"policy_uncertainty", "deal"};
    const char *args_balance[]     = {"policy_balance", "yo!"};
    const char *args_los[]         = {"policy_los", "Hummus?"};
    const char *args_eig[]         = {"policy_eig", "Yes, please!"};

    // policy_parse wants mutable pointers but never writes through them
    char **argv_uniform     = (char**) args_uniform;
    char **argv_uncertainty = (char**) args_uncertainty;
    char **argv_balance     = (char**) args_balance;
    char **argv_los         = (char**) args_los;
    char **argv_eig         = (char**) args_eig;

    int argc_uniform = 3;
    int argc_uncertainty = 2;
    int argc_balance = 2;
    int argc_los = 2;
    int argc_eig = 2;

    policy_uniform* poly_uni = (policy_uniform*) policy_parse(&argc_uniform, &argv_uniform);
    policy_uncertainty* poly_crt = (policy_uncertainty*) policy_parse(&argc_uncertainty, &argv_uncertainty);
    policy_balance* poly_bal = (policy_balance*) policy_parse(&argc_balance, &argv_balance);
    policy_los* poly_los = (policy_los*) policy_parse(&argc_los, &argv_los);
    policy_eig* poly_eig = (policy_eig*) policy_parse(&argc_eig, &argv_eig);

    // uniform: two arguments left over
    assert(argc_uniform == 2);
    assert(strcmp(argv_uniform[0], "no") == 0);
    assert(strcmp(argv_uniform[1], "big") == 0);
    assert(poly_uni != NULL);
    assert(poly_uni->task_count.size() == 0);

    // uncertainty
    assert(argc_uncertainty == 1);
    assert(strcmp(argv_uncertainty[0], "deal") == 0);
    assert(poly_crt != NULL);
    assert(poly_crt->task_abs_odds.size() == 0);

    // balance
    assert(argc_balance == 1);
    assert(strcmp(argv_balance[0], "yo!") == 0);
    assert(poly_bal != NULL);
    assert(poly_bal->task_bound.size() == 0);

    // zero-one loss
    assert(argc_los == 1);
    assert(strcmp(argv_los[0], "Hummus?") == 0);
    assert(poly_los != NULL);
    assert(poly_los->task_los.size() == 0);

    // expected information gain
    assert(argc_eig == 1);
    assert(strcmp(argv_eig[0], "Yes, please!") == 0);
    assert(poly_eig != NULL);
    assert(poly_eig->task_eig.size() == 0);

    delete poly_uni;
    delete poly_crt;
    delete poly_bal;
    delete poly_los;
    delete poly_eig;
}
