#include <cassert>
#include <cstring>
#include <iostream>
#include <random>
#include <vector>
#include "crowd.h"
#include "infer.h"
#include "sim.h"

using namespace std;

/// sim budget ///

// Builds a fixed-budget simulation that will collect exactly n_labels labels.
// The crowd, inference and policy pointers are stored as given; the
// destructor deletes all three, so the simulation takes ownership of them.
sim_budget::sim_budget(crowd* cwd, infer* inf, policy* poly, unsigned int n_labels):
    n_labels(n_labels)
{
    // an empty budget is rejected: run() uses n_labels - 1 as its final index
    assert(n_labels > 0);

    // the parameters shadow the members, hence the explicit this->
    this->poly = poly;
    this->inf = inf;
    this->cwd = cwd;
}

// Runs the simulation for exactly n_labels time steps, then repeats the
// inference from scratch with the index of the last step.
// ran_gen is the random generator forwarded to the crowd and the inference.
void sim_budget::run(mt19937& ran_gen)
{
    unsigned int step = 0;
    while(step < n_labels) {
        // one step: the crowd creates a data point, the policy chooses a
        // task, and the inference is updated incrementally
        cwd->create_data_point(ran_gen, step);
        poly->choose_task(cwd, inf, step);
        inf->infer_update(ran_gen, cwd, step);
        ++step;
    }

    // full inference pass, given the index of the final time step
    inf->infer_full(ran_gen, cwd, n_labels - 1);
}

// Releases the crowd, inference and policy objects that were handed to the
// constructor; the simulation owns them.
sim_budget::~sim_budget()
{
    delete this->cwd;
    delete this->inf;
    delete this->poly;
}

void sim_budget_test()
{
    // simple crowd, inference and policy
    double work_prior = 0.7;
    double p = 0.8;
    int quota = 5;
    unsigned int n_tasks = 10;
    unsigned int n_workers = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd* cwd = new crowd_budget(d_work_acc, d_work_quota, n_workers);
    infer* inf = new infer_majority(n_tasks, work_prior);
    policy* poly = new policy_uncertainty();

    unsigned int n_labels = quota * n_workers;
    sim_budget s(cwd, inf, poly, n_labels);
    assert(s.n_labels == n_labels);

    // stochastic test: the error rate should be less than random
    mt19937 ran_gen = well_seeded_mt19937(4095);
    s.run(ran_gen);
    double error_rate = inf->error_rate();
    assert(error_rate < 0.5);
}

/// sim error ///

// Builds a simulation that runs until every task is confident enough.
// target_err is stored as the threshold target_odds = -math_logit(target_err),
// which run() compares against each task's odds.
// The crowd, inference and policy pointers are stored as given; the
// destructor deletes all three, so the simulation takes ownership of them.
sim_error::sim_error(crowd* cwd, infer* inf, policy* poly, double target_err):
    target_odds(-math_logit(target_err))
{
    // the target error must be below one half
    assert(target_err < 0.5);

    // the parameters shadow the members, hence the explicit this->
    this->poly = poly;
    this->inf = inf;
    this->cwd = cwd;
}

// Runs the simulation one time step at a time until no task has odds
// strictly inside (-target_odds, target_odds), then repeats the inference
// from scratch with the index of the last step. At least one step is always
// executed. ran_gen is the random generator forwarded to the crowd and the
// inference.
void sim_error::run(mt19937& ran_gen)
{
    unsigned int step = 0;
    bool converged;
    do {
        // one step: the crowd creates a data point, the policy chooses a
        // task, and the inference is updated incrementally
        cwd->create_data_point(ran_gen, step);
        poly->choose_task(cwd, inf, step);
        inf->infer_update(ran_gen, cwd, step);

        // assume convergence, then stop scanning at the first task whose
        // odds are still within the undecided band
        converged = true;
        for(unsigned int task = 0; converged && task < inf->task_odds.size(); ++task) {
            if(inf->task_odds[task] < target_odds && inf->task_odds[task] > - target_odds)
                converged = false;
        }

        ++step;
    } while(!converged);

    // step now counts the executed steps, so step - 1 is the last index
    inf->infer_full(ran_gen, cwd, step - 1);
}

// Releases the crowd, inference and policy objects that were handed to the
// constructor; the simulation owns them.
sim_error::~sim_error()
{
    delete this->cwd;
    delete this->inf;
    delete this->poly;
}

void sim_error_test()
{
    // simple crowd, inference and policy
    double work_prior = 0.7;
    double p = 0.8;
    unsigned int n_tasks = 10;
    unsigned int n_workers = 20;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    crowd* cwd = new crowd_workers(d_work_acc, n_workers);
    infer* inf = new infer_majority(n_tasks, work_prior);
    policy* poly = new policy_uncertainty();

    double target_err = 0.1;
    sim_error s(cwd, inf, poly, target_err);
    double target_odds = -math_logit(target_err);
    assert(s.target_odds == target_odds);

    // stochastic test: the error rate should be below our target
    mt19937 ran_gen = well_seeded_mt19937(1111);
    s.run(ran_gen);
    double error_rate = inf->error_rate();
    assert(error_rate <= target_err);

    // ever task must have accuracy above target
    for(unsigned int i = 0; i < n_tasks; ++i)
        assert(inf->task_odds[i] >= target_odds ||
               inf->task_odds[i] <= - target_odds);
}

/// sim parser ///

// Reports a failure that happened after some components were already built
// and frees those components so that a failed sim_parse() does not leak them.
// Either pointer may be NULL (deleting a null pointer is a no-op).
// Always returns NULL so callers can write `return sim_parse_fail(...)`.
static sim* sim_parse_fail(const char* message, infer* inf, crowd* cwd)
{
    cerr << message << endl;
    delete cwd;
    delete inf;
    return NULL;
}

// Parses a simulation description from the front of an argument list.
//
// Recognised forms (the component arguments are consumed by infer_parse,
// crowd_parse and policy_parse, in that order):
//   sim_budget <avg_labels_per_task> <infer...> <crowd...> <policy...>
//   sim_error  <target_err>          <infer...> <crowd...> <policy...>
//
// argc  in/out: number of remaining arguments, decreased as they are consumed
// argv  in/out: pointer to the remaining arguments, advanced accordingly
//
// Returns a simulation allocated with new that owns its three components, or
// NULL after printing a message on cerr when the arguments cannot be parsed.
// On failure the position of argc/argv is unspecified.
sim* sim_parse(int *argc, char **argv[])
{
    if(*argc < 1) {
        cerr << "Not enough arguments to parse the sim type" << endl;
        return NULL;
    }

    if(strcmp((*argv)[0], "sim_budget") == 0) {
        if(*argc < 2) {
            cerr << "Not enough arguments to instantiate sim_budget" << endl;
            return NULL;
        }

        // strtoul leaves `end` at the start of the string when nothing could
        // be converted; a zero average would produce an empty budget, which
        // the sim_budget constructor rejects
        char* end = NULL;
        const unsigned long avg_labels_per_task = strtoul((*argv)[1], &end, 10);
        if(end == (*argv)[1] || avg_labels_per_task == 0) {
            cerr << "Invalid average number of labels per task for sim_budget" << endl;
            return NULL;
        }
        *argc -= 2; *argv += 2;

        infer* inf = infer_parse(argc, argv);
        if(inf == NULL)
            return sim_parse_fail("Unable to parse the inference for sim_budget", NULL, NULL);

        // the budget is avg_labels_per_task * n_tasks; refuse it when it is
        // empty or does not fit in an unsigned int instead of wrapping around
        const unsigned int n_tasks = inf->task_odds.size();
        const unsigned int max_labels = static_cast<unsigned int>(-1);
        if(n_tasks == 0 || avg_labels_per_task > max_labels / n_tasks)
            return sim_parse_fail("Invalid label budget for sim_budget", inf, NULL);
        const unsigned int n_labels = static_cast<unsigned int>(avg_labels_per_task * n_tasks);

        crowd* cwd = crowd_parse(argc, argv, n_labels); // match n_workers for fixed quotas
        if(cwd == NULL)
            return sim_parse_fail("Unable to parse the crowd for sim_budget", inf, NULL);

        policy* poly = policy_parse(argc, argv);
        if(poly == NULL)
            return sim_parse_fail("Unable to parse the policy for sim_budget", inf, cwd);

        return (sim*) new sim_budget(cwd, inf, poly, n_labels);

    } else if(strcmp((*argv)[0], "sim_error") == 0) {
        if(*argc < 2) {
            cerr << "Not enough arguments to instantiate sim_error" << endl;
            return NULL;
        }

        // the sim_error constructor requires target_err < 0.5; the negated
        // comparison also rejects NaN
        char* end = NULL;
        const double target_err = strtod((*argv)[1], &end);
        if(end == (*argv)[1] || !(target_err < 0.5)) {
            cerr << "Invalid target error for sim_error" << endl;
            return NULL;
        }
        *argc -= 2; *argv += 2;

        infer* inf = infer_parse(argc, argv);
        if(inf == NULL)
            return sim_parse_fail("Unable to parse the inference for sim_error", NULL, NULL);

        // no label budget is known here, so 0 is passed as in the budget-less case
        crowd* cwd = crowd_parse(argc, argv, 0);
        if(cwd == NULL)
            return sim_parse_fail("Unable to parse the crowd for sim_error", inf, NULL);

        policy* poly = policy_parse(argc, argv);
        if(poly == NULL)
            return sim_parse_fail("Unable to parse the policy for sim_error", inf, cwd);

        return (sim*) new sim_error(cwd, inf, poly, target_err);
    }

    cerr << "Unable to parse the sim type" << endl;
    return NULL;
}

void sim_parse_test()
{
    const char *argv_b[] = {"sim_budget", "3", "infer_weighted", "10", "crowd_budget", "80085", "dirac", "0.666", "kronecker", "4", "policy_uniform", "I am"};
    const char *argv_e[] = {"sim_error", "0.1", "infer_weighted", "10", "crowd_workers", "7", "dirac", "0.666", "policy_uncertainty", "your", "father"};
    char **argv_budget = (char**) argv_b;
    char **argv_error = (char**) argv_e;

    int argc_budget = 12;
    int argc_error = 11;

    sim_budget* sim_b = (sim_budget*) sim_parse(&argc_budget, &argv_budget);
    sim_error* sim_e = (sim_error*) sim_parse(&argc_error, &argv_error);

    assert(argc_budget == 1);
    assert(argc_error == 2);

    assert(strcmp(argv_budget[0], "I am") == 0);
    assert(strcmp(argv_error[0], "your") == 0);
    assert(strcmp(argv_error[1], "father") == 0);

    assert(sim_b != NULL);
    assert(sim_e != NULL);

    mt19937 ran_gen = well_seeded_mt19937(1492);
    sim_b->run(ran_gen);

    assert(sim_b->n_labels == 30);
    assert(sim_b->cwd->work_acc.size() == 8);

    delete sim_b;
    delete sim_e;
}
