#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstring>
#include <iostream>
#include <limits>
#include <cfloat>
#include <vector>
#include "crowd.h"
#include "distro.h"
#include "infer.h"

using namespace std;

/// math functions ///

// Logistic sigmoid: converts the log-odds x into a probability.
double math_expit(double x)
{
    const double denom = 1.0 + exp(-x);
    return 1.0 / denom;
}

// Log-odds of the probability x, which must lie strictly inside (0, 1).
double math_logit(double x)
{
    assert(x > 0.0);
    assert(x < 1.0);

    const double odds = x / (1.0 - x);
    return log(odds);
}

// Log-odds of x, with x first pulled inside (0, 1): values at or below zero
// become epsilon, values at or above one become 1 - epsilon.
double math_logit_safe(double x)
{
    const double eps = numeric_limits<double>::epsilon();

    double p = x;
    if(p <= 0.0)
        p = eps;
    else if(p >= 1.0)
        p = 1.0 - eps;

    const double odds = p / (1.0 - p);
    return log(odds);
}

// Rescale vec in place so that the sum of the absolute values of its
// entries is one. A vector whose absolute sum is zero (empty or all zeros)
// is left untouched, because dividing by zero would fill it with NaNs.
void math_normalise_1(vector<double>& vec)
{
    double abs_sum = 0.0;
    for(double x: vec)
        abs_sum += fabs(x);

    // nothing sensible to scale by
    if(abs_sum == 0.0)
        return;

    for(double& x: vec)
        x /= abs_sum;
}

// Lookup table used by lgammau(): for n >= 1, entry n holds log((n - 1)!).
vector<double> lgammau_table;

// Fill lgammau_table with the entries for n = 0 .. LGU_TBL_MAX - 1.
// The table is rebuilt from scratch, so calling this function more than
// once no longer appends a duplicate copy of the entries.
void lgammau_init()
{
    lgammau_table.clear();
    lgammau_table.reserve(LGU_TBL_MAX);

    lgammau_table.push_back(DBL_MAX); // n = 0: large sentinel value
    lgammau_table.push_back(0.0);     // n = 1: log(0!) = 0
    lgammau_table.push_back(0.0);     // n = 2: log(1!) = 0

    // running sum of logarithms: entry n adds log(n - 1) to entry n - 1
    double lgu_tot = 0.0;
    for(unsigned int n = 3; n < LGU_TBL_MAX; ++n) {
        lgu_tot += log((double) n - 1.0);
        lgammau_table.push_back(lgu_tot);
    }
}

double lgammau(double n)
{
    if(n < LGU_TBL_MAX)
        return lgammau_table[(unsigned int) n];

    double lgu_tot = lgammau_table[LGU_TBL_MAX - 1];
    for(double k = LGU_TBL_MAX; k <= n; k += 1.0)
        lgu_tot += log(k - 1.0);
    return lgu_tot;
}

// Log of the beta function of (a, b), assembled from three lgammau() calls.
// Both arguments must be non-negative.
double math_log_beta(double a, double b)
{
    assert(a >= 0.0);
    assert(b >= 0.0);

    const double lg_a = lgammau(a);
    const double lg_b = lgammau(b);
    const double lg_ab = lgammau(a + b);
    return lg_a + lg_b - lg_ab;
}

/// shared functions ///

// Count the errors encoded in the task log-odds: a negative value counts as
// one error, a value of exactly zero as half an error.
static double shared_error_num(vector<double>& task_odds)
{
    double n_errors = 0.0;
    for(size_t i = 0; i < task_odds.size(); ++i) {
        const double odds = task_odds[i];
        if(odds == 0.0)
            n_errors += 0.5; // 50% posterior is worth half an error
        else if(odds < 0.0)
            n_errors += 1.0;
    }
    return n_errors;
}

// Predicted error rate: one minus the average, over all the tasks, of the
// posterior probability math_expit(|log-odds|) of the favoured answer.
static double shared_error_predict(vector<double>& task_odds)
{
    double confidence = 0.0;
    for(size_t i = 0; i < task_odds.size(); ++i)
        confidence += math_expit(fabs(task_odds[i]));

    const double n_tasks = (double) task_odds.size();
    return 1.0 - confidence / n_tasks;
}

/// worker history ///

// Two histories are equal when they hold the same sequence of time steps.
bool history::operator==(const history& other) const
{
    const bool same_steps = (time_steps == other.time_steps);
    return same_steps;
}

/// weighted majority voting ///

// Prepare weighted majority voting over n_tasks tasks (at least one).
infer_weighted::infer_weighted(unsigned int n_tasks):
    t_next(0)
{
    assert(n_tasks != 0);

    // every task starts from neutral log-odds
    this->task_odds.resize(n_tasks, 0.0);
}

// Redo the inference from scratch on the data points 0 .. t (inclusive).
void infer_weighted::infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // drop all the per-worker state: with an empty history infer_update
    // computes work_weight again on its first call (safer)
    work_weight.clear();
    work_estim.clear();
    work_hist.clear();
    fill(task_odds.begin(), task_odds.end(), 0.0);
    t_next = 0;

    // replay the data one step at a time
    unsigned int step = 0;
    while(step <= t) {
        this->infer_update(ran_gen, c, step);
        ++step;
    }
}

// Fold the data point collected at step t into the task log-odds, weighting
// the label by the logit of the worker's accuracy stored in c->work_acc.
void infer_weighted::infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // steps must arrive strictly in order
    assert(t == t_next);
    ++t_next;

    // the crowd has grown since the last call
    const size_t n_workers = c->work_acc.size();
    if(n_workers > work_hist.size()) {
        // cache the weight of every worker that does not have one yet
        for(size_t w = work_weight.size(); w < n_workers; ++w)
            work_weight.push_back(math_logit(c->work_acc[w]));
        work_hist.resize(n_workers, history());
        work_estim = c->work_acc; // perfect estimates
    }

    // read the data point of this step
    assert(t < c->data_seq.size());
    const unsigned int task = c->data_seq[t].task_id;
    const unsigned int worker = c->data_seq[t].work_id;
    const double label = c->data_seq[t].label;
    assert(task < task_odds.size());
    assert(worker < work_hist.size());

    // record the step and add the weighted vote
    work_hist[worker].time_steps.push_back(t);
    task_odds[task] += label * work_weight[worker];
}

// Number of errors in the current task log-odds (see shared_error_num).
double infer_weighted::error_num()
{
    const double n_errors = shared_error_num(task_odds);
    return n_errors;
}

// Number of errors divided by the number of tasks.
double infer_weighted::error_rate()
{
    const double n_errors = this->error_num();
    const double n_tasks = (double) task_odds.size();
    return n_errors / n_tasks;
}

// Error rate predicted from the current task log-odds
// (see shared_error_predict).
double infer_weighted::error_predict()
{
    const double predicted = shared_error_predict(task_odds);
    return predicted;
}

// Self-test of infer_weighted on a small synthetic crowd.
void infer_weighted_test()
{
    const unsigned int n_tasks = 10;
    infer_weighted wmv(n_tasks);
    assert(wmv.task_odds.size() == n_tasks);

    // define a simple crowd
    double p = 0.8;
    int quota = 5;
    unsigned int n_workers = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd_budget c(d_work_acc, d_work_quota, n_workers);

    // generate all the data, handing out the tasks in round robin
    mt19937 ran_gen = well_seeded_mt19937(1990);
    const unsigned int n_points = n_workers * quota;
    for(unsigned int t = 0; t < n_points; ++t) {
        c.create_data_point(ran_gen, t);
        c.data_seq[t].task_id = t % n_tasks;
    }
    const unsigned int t_final = n_points - 1;
    wmv.infer_full(ran_gen, &c, t_final);

    assert(wmv.work_hist.size() == n_workers);
    assert(wmv.work_weight.size() == n_workers);
    assert(wmv.error_rate() < 0.5);

    // each cached weight must match the logit of the worker's accuracy
    const double epsilon = numeric_limits<double>::epsilon();
    for(unsigned int w = 0; w < n_workers; ++w) {
        const double diff = wmv.work_weight[w] - math_logit(c.work_acc[w]);
        assert(fabs(diff) <= epsilon);
    }
}

/// golden task estimates ///

// Prepare golden-task inference over n_tasks tasks (at least one).
// n_trials, alpha and beta are stored for use by run_trials().
infer_golden::infer_golden(unsigned int n_tasks, unsigned int n_trials, double alpha, double beta):
    t_next(0),
    n_trials(n_trials),
    alpha(alpha),
    beta(beta)
{
    assert(n_tasks != 0);

    // every task starts from neutral log-odds
    this->task_odds.resize(n_tasks, 0.0);
}

// Redo the inference from scratch on the data points 0 .. t (inclusive).
void infer_golden::infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // drop all the per-worker state: with an empty history infer_update
    // computes work_weight again on its first call (safer)
    work_weight.clear();
    work_estim.clear();
    work_hist.clear();
    fill(task_odds.begin(), task_odds.end(), 0.0);
    t_next = 0;

    // replay the data one step at a time
    unsigned int step = 0;
    while(step <= t) {
        this->infer_update(ran_gen, c, step);
        ++step;
    }
}

void infer_golden::run_trials(std::mt19937& ran_gen, crowd* c, unsigned int new_work)
{
    distro_uniform d(0.0, 1.0);

    // estimate the accuracy of each new worker
    unsigned int n_workers = c->work_acc.size();
    work_estim.resize(n_workers);
    work_weight.resize(n_workers);
    for(unsigned int j = new_work; j < n_workers; ++j) {

        // count the number of successes in n_trials
        unsigned int n_success = 0;
        vector<double> r = d.extract_vec(ran_gen, n_trials);
        for(unsigned int k = 0; k < n_trials; ++k)
            if(r[k] < c->work_acc[j])
                ++n_success;

        // estimate the accuracy with the posterior mean
        work_estim[j] = ((double) n_success + alpha) / ((double) n_trials + alpha + beta);
        work_weight[j] = math_logit_safe(work_estim[j]);
    }
}

// Fold the data point collected at step t into the task log-odds, weighting
// the label by the worker's weight estimated in run_trials().
void infer_golden::infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // steps must arrive strictly in order
    assert(t == t_next);
    ++t_next;

    // the crowd has grown: run the trials for the workers not seen before
    const size_t n_known = work_hist.size();
    const size_t n_workers = c->work_acc.size();
    if(n_workers > n_known) {
        this->run_trials(ran_gen, c, n_known);
        work_hist.resize(n_workers, history());
    }

    // read the data point of this step
    assert(t < c->data_seq.size());
    const unsigned int task = c->data_seq[t].task_id;
    const unsigned int worker = c->data_seq[t].work_id;
    const double label = c->data_seq[t].label;
    assert(task < task_odds.size());
    assert(worker < work_hist.size());

    // record the step and add the weighted vote
    work_hist[worker].time_steps.push_back(t);
    task_odds[task] += label * work_weight[worker];
}

// Number of errors in the current task log-odds (see shared_error_num).
double infer_golden::error_num()
{
    const double n_errors = shared_error_num(task_odds);
    return n_errors;
}

// Number of errors divided by the number of tasks.
double infer_golden::error_rate()
{
    const double n_errors = this->error_num();
    const double n_tasks = (double) task_odds.size();
    return n_errors / n_tasks;
}

// Error rate predicted from the current task log-odds
// (see shared_error_predict).
double infer_golden::error_predict()
{
    const double predicted = shared_error_predict(task_odds);
    return predicted;
}

// Self-test of infer_golden on a small synthetic crowd.
void infer_golden_test()
{
    const unsigned int n_tasks = 10;
    const unsigned int n_trials = 5;
    const double alpha = 2.0;
    const double beta = 1.0;
    infer_golden gold(n_tasks, n_trials, alpha, beta);

    // the constructor must store its arguments verbatim
    assert(gold.task_odds.size() == n_tasks);
    assert(gold.n_trials == n_trials);
    assert(gold.alpha == alpha);
    assert(gold.beta == beta);

    // define a simple crowd
    double p = 0.8;
    int quota = 5;
    unsigned int n_workers = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd_budget c(d_work_acc, d_work_quota, n_workers);

    // generate all the data, handing out the tasks in round robin
    mt19937 ran_gen = well_seeded_mt19937(1990);
    const unsigned int n_points = n_workers * quota;
    for(unsigned int t = 0; t < n_points; ++t) {
        c.create_data_point(ran_gen, t);
        c.data_seq[t].task_id = t % n_tasks;
    }
    const unsigned int t_final = n_points - 1;
    gold.infer_full(ran_gen, &c, t_final);

    // one entry per worker, and better than chance
    assert(gold.work_hist.size() == n_workers);
    assert(gold.work_estim.size() == n_workers);
    assert(gold.work_weight.size() == n_workers);
    assert(gold.error_rate() < 0.5);
}

/// majority voting ///

// Prepare majority voting over n_tasks tasks (at least one). All the workers
// share one weight: the clamped logit of work_prior.
infer_majority::infer_majority(unsigned int n_tasks, double work_prior):
    t_next(0),
    work_prior(work_prior),
    work_weight(math_logit_safe(work_prior))
{
    assert(n_tasks != 0);

    // every task starts from neutral log-odds
    this->task_odds.resize(n_tasks, 0.0);
}

// Redo the inference from scratch on the data points 0 .. t (inclusive).
void infer_majority::infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    t_next = 0;

    // Drop the per-worker state entirely. infer_update() only allocates
    // work_estim when work_hist is smaller than the crowd, so the two have
    // to be emptied together: emptying work_estim alone while work_hist kept
    // its size left work_estim empty after every re-run.
    work_hist.clear();
    work_estim.clear();
    fill(task_odds.begin(), task_odds.end(), 0.0);

    // replay the data one step at a time
    for(unsigned k = 0; k <= t; ++k)
        this->infer_update(ran_gen, c, k);
}

// Fold the data point collected at step t into the task log-odds; every
// label carries the same weight work_weight.
void infer_majority::infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // steps must arrive strictly in order
    assert(t == t_next);
    ++t_next;

    // the crowd has grown: new workers get the prior as their estimate
    const size_t n_workers = c->work_acc.size();
    if(n_workers > work_hist.size()) {
        work_estim.resize(n_workers, work_prior);
        work_hist.resize(n_workers, history());
    }

    // read the data point of this step
    assert(t < c->data_seq.size());
    const unsigned int task = c->data_seq[t].task_id;
    const unsigned int worker = c->data_seq[t].work_id;
    const double label = c->data_seq[t].label;
    assert(task < task_odds.size());
    assert(worker < work_hist.size());

    // record the step and add the vote
    work_hist[worker].time_steps.push_back(t);
    task_odds[task] += label * work_weight;
}

// Number of errors in the current task log-odds (see shared_error_num).
double infer_majority::error_num()
{
    const double n_errors = shared_error_num(task_odds);
    return n_errors;
}

// Number of errors divided by the number of tasks.
double infer_majority::error_rate()
{
    const double n_errors = this->error_num();
    const double n_tasks = (double) task_odds.size();
    return n_errors / n_tasks;
}

// Error rate predicted from the current task log-odds
// (see shared_error_predict).
double infer_majority::error_predict()
{
    const double predicted = shared_error_predict(task_odds);
    return predicted;
}

// Self-test of infer_majority on a small synthetic crowd.
void infer_majority_test()
{
    const unsigned int n_tasks = 10;
    const double work_prior = 0.7;
    infer_majority maj(n_tasks, work_prior);

    // the shared weight must be the logit of the prior
    const double epsilon = numeric_limits<double>::epsilon();
    assert(maj.task_odds.size() == n_tasks);
    const double diff = maj.work_weight - math_logit(work_prior);
    assert(fabs(diff) <= epsilon);

    // define a simple crowd
    double p = 0.8;
    int quota = 5;
    unsigned int n_workers = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd_budget c(d_work_acc, d_work_quota, n_workers);

    // generate all the data, handing out the tasks in round robin
    mt19937 ran_gen = well_seeded_mt19937(1990);
    const unsigned int n_points = n_workers * quota;
    for(unsigned int t = 0; t < n_points; ++t) {
        c.create_data_point(ran_gen, t);
        c.data_seq[t].task_id = t % n_tasks;
    }
    const unsigned int t_final = n_points - 1;
    maj.infer_full(ran_gen, &c, t_final);

    // one history per worker, and better than chance
    assert(maj.work_hist.size() == n_workers);
    assert(maj.error_rate() < 0.5);
}

/// acyclic inference ///

// Prepare acyclic inference over n_tasks tasks (at least one); alpha and
// beta are the pseudo-counts used when estimating a worker's accuracy.
infer_acyclic::infer_acyclic(unsigned int n_tasks, double alpha, double beta):
    t_next(0),
    alpha(alpha),
    beta(beta)
{
    assert(n_tasks != 0);

    // neutral start: zero log-odds, i.e. a 50% posterior on every task
    this->task_post.resize(n_tasks, 0.5);
    this->task_odds.resize(n_tasks, 0.0);
}

void infer_acyclic::infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    t_next = 0;
    fill(work_hist.begin(), work_hist.end(), history());
    fill(task_odds.begin(), task_odds.end(), 0.0);
    fill(task_post.begin(), task_post.end(), 0.5);
    work_estim.resize(0);

    for(unsigned k = 0; k <= t; ++k)
        this->infer_update(ran_gen, c, k);
}

// Fold the data point collected at step t into the inference: the worker's
// accuracy is re-estimated from the current posteriors of the tasks the
// worker labelled before, then the new label is added with that weight.
void infer_acyclic::infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // steps must arrive strictly in order
    assert(t == t_next);
    ++t_next;

    // the crowd has grown
    const size_t n_workers = c->work_acc.size();
    if(n_workers > work_hist.size()) {
        work_hist.resize(n_workers, history());
        work_estim.resize(n_workers, 0.5);
        work_weight.resize(n_workers, 0.0); // overwritten below before use
    }

    // read the data point of this step
    assert(t < c->data_seq.size());
    const unsigned int task = c->data_seq[t].task_id;
    const unsigned int worker = c->data_seq[t].work_id;
    const double label = c->data_seq[t].label;
    assert(task < task_odds.size());
    assert(worker < work_hist.size());

    // expected number of correct past labels, starting from the prior alpha
    const auto& past_steps = work_hist[worker].time_steps;
    double correct = alpha;
    for(unsigned int k: past_steps) {
        assert(c->data_seq[k].work_id == worker);
        assert(c->data_seq[k].task_id < task_post.size());

        const double post = task_post[c->data_seq[k].task_id];
        const bool positive = (c->data_seq[k].label > 0.0);
        correct += positive? post: 1.0 - post;
    }
    const double total = alpha + beta + (double) past_steps.size();
    work_estim[worker] = correct / total;
    work_weight[worker] = math_logit(work_estim[worker]);

    // record the step, add the weighted vote and refresh the posterior
    work_hist[worker].time_steps.push_back(t);
    task_odds[task] += label * work_weight[worker];
    task_post[task] = math_expit(task_odds[task]);
}

// Number of errors in the current task log-odds (see shared_error_num).
double infer_acyclic::error_num()
{
    const double n_errors = shared_error_num(task_odds);
    return n_errors;
}

// Number of errors divided by the number of tasks.
double infer_acyclic::error_rate()
{
    const double n_errors = this->error_num();
    const double n_tasks = (double) task_odds.size();
    return n_errors / n_tasks;
}

// Error rate predicted from the current task log-odds
// (see shared_error_predict).
double infer_acyclic::error_predict()
{
    const double predicted = shared_error_predict(task_odds);
    return predicted;
}

// Self-test of infer_acyclic on a small synthetic crowd.
void infer_acyclic_test()
{
    const unsigned int n_tasks = 10;
    const double alpha = 2.0;
    const double beta = 1.0;
    infer_acyclic acy(n_tasks, alpha, beta);

    // the constructor must size the task vectors and store the prior
    assert(acy.task_odds.size() == n_tasks);
    assert(acy.task_post.size() == n_tasks);
    assert(acy.alpha == alpha);
    assert(acy.beta == beta);

    // define a simple crowd
    double p = 0.8;
    int quota = 5;
    unsigned int n_workers = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd_budget c(d_work_acc, d_work_quota, n_workers);

    // generate all the data, handing out the tasks in round robin
    mt19937 ran_gen = well_seeded_mt19937(12345);
    const unsigned int n_points = n_workers * quota;
    for(unsigned int t = 0; t < n_points; ++t) {
        c.create_data_point(ran_gen, t);
        c.data_seq[t].task_id = t % n_tasks;
    }
    const unsigned int t_final = n_points - 1;
    acy.infer_full(ran_gen, &c, t_final);

    // one entry per worker, and better than chance
    assert(acy.work_weight.size() == n_workers);
    assert(acy.work_hist.size() == n_workers);
    assert(acy.error_rate() < 0.5);

    // after a single data point the worker's weight is the prior log-odds
    acy.infer_full(ran_gen, &c, 0);
    const unsigned int first_worker = c.data_seq[0].work_id;
    const double epsilon = numeric_limits<double>::epsilon();
    const double diff = acy.work_weight[first_worker] - log(alpha / beta);
    assert(fabs(diff) <= epsilon);
}

/// delayed acyclic inference ///

// Prepare delayed acyclic inference over n_tasks tasks (at least one).
// view_odds and view_post are n_tasks x n_tasks tables that start from
// zero log-odds and 50% posteriors respectively.
infer_delayed::infer_delayed(unsigned int n_tasks, double alpha, double beta):
    t_next(0),
    alpha(alpha),
    beta(beta)
{
    assert(n_tasks != 0);

    const vector<double> neutral_odds(n_tasks, 0.0);
    const vector<double> neutral_post(n_tasks, 0.5);

    task_odds.resize(n_tasks, 0.0);
    view_odds.resize(n_tasks, neutral_odds);
    view_post.resize(n_tasks, neutral_post);
}

// Redo the inference from scratch on the data points 0 .. t (inclusive).
void infer_delayed::infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    t_next = 0;

    // Drop the per-worker state entirely. infer_update() only allocates
    // work_estim when work_hist is smaller than the crowd, so the two have
    // to be emptied together: emptying work_estim alone while work_hist kept
    // its size made the next infer_update() write work_estim[j] past the end
    // of an empty vector.
    work_hist.clear();
    work_estim.clear();
    fill(task_odds.begin(), task_odds.end(), 0.0);

    // reset every view in one go: zero log-odds, 50% posteriors
    const unsigned int n_tasks = task_odds.size();
    view_odds.assign(n_tasks, vector<double>(n_tasks, 0.0));
    view_post.assign(n_tasks, vector<double>(n_tasks, 0.5));

    // replay the data one step at a time
    for(unsigned k = 0; k <= t; ++k)
        this->infer_update(ran_gen, c, k);
}

// Process the data point collected at step t and rebuild task_odds.
//
// Judging from the indexing below, row v of view_odds / view_post holds the
// log-odds / posteriors of all the tasks as seen from the "view" of task v,
// in which the worker weights are estimated without the labels given on
// task v itself. Each call:
//   1. adds the new label to every view except the one of its own task;
//   2. appends step t to the worker's history;
//   3. zeroes task_odds and rebuilds it by replaying all the data points
//      0 .. t, weighting each label on task i with view_post[i].
void infer_delayed::infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // avoid skipping any steps
    assert(t == t_next);
    ++t_next;

    unsigned int n_tasks = task_odds.size();
    unsigned int n_workers = c->work_acc.size();

    // extra workers are joining!
    if(n_workers > work_hist.size()) {
        work_hist.resize(n_workers, history());
        work_estim.resize(n_workers, 0.5);
    }

    // unpack the new data point
    assert(t < c->data_seq.size());
    unsigned int i = c->data_seq[t].task_id;
    unsigned int j = c->data_seq[t].work_id;
    double label = c->data_seq[t].label;
    assert(i < task_odds.size());
    assert(j < work_hist.size());

    // the new label changes the view for all tasks but the current one
    for(unsigned int view_i = 0; view_i < n_tasks; ++view_i) {
        if(view_i != i) {

            // estimate the worker's weight (skip the labels that the worker
            // gave on the viewed task itself); the history does not contain
            // step t yet, so only earlier labels are counted
            double correct = alpha;
            double n_labels = 0.0;
            for(unsigned int k: work_hist[j].time_steps) {

                assert(c->data_seq[k].work_id == j);
                assert(c->data_seq[k].task_id < view_post[view_i].size());
                unsigned int task_id = c->data_seq[k].task_id;

                if(task_id != view_i) {
                    // a positive label counts as correct with the posterior
                    // probability of the task, a negative one with its
                    // complement
                    if(c->data_seq[k].label > 0.0)
                        correct += view_post[view_i][task_id];
                    else
                        correct += 1.0 - view_post[view_i][task_id];
                    n_labels += 1.0;
                }
            }
            double total = alpha + beta + n_labels;
            double work_weight = math_logit(correct / total);

            // update the odds of this view
            view_odds[view_i][i] += label * work_weight;
            view_post[view_i][i] = math_expit(view_odds[view_i][i]);
        }
        task_odds[view_i] = 0.0; // reset the log-odds on every task
    }

    // update the history
    work_hist[j].time_steps.push_back(t);

    // update the delayed odds: replay every data point seen so far
    // (i, j and label are reused for the replayed data point from here on)
    for(unsigned int u = 0; u <= t; ++u) {
        i = c->data_seq[u].task_id;
        j = c->data_seq[u].work_id;
        label = c->data_seq[u].label;

        // estimate the worker's weight from the view of task i, skipping
        // the worker's labels on task i itself
        double correct = alpha;
        double n_labels = 0.0;
        for(unsigned int k: work_hist[j].time_steps) {
            unsigned int task_id = c->data_seq[k].task_id;

            if(task_id != i) {
                if(c->data_seq[k].label > 0.0)
                    correct += view_post[i][task_id];
                else
                    correct += 1.0 - view_post[i][task_id];
                n_labels += 1.0;
            }
        }
        double total = alpha + beta + n_labels;
        double work_weight = math_logit(correct / total);

        task_odds[i] += label * work_weight;
        // overwritten at every data point of worker j, so it ends up as the
        // estimate computed for the worker's last replayed label
        work_estim[j] = correct / total;
    }
}

// Number of errors in the current task log-odds (see shared_error_num).
double infer_delayed::error_num()
{
    const double n_errors = shared_error_num(task_odds);
    return n_errors;
}

// Number of errors divided by the number of tasks.
double infer_delayed::error_rate()
{
    const double n_errors = this->error_num();
    const double n_tasks = (double) task_odds.size();
    return n_errors / n_tasks;
}

// Error rate predicted from the current task log-odds
// (see shared_error_predict).
double infer_delayed::error_predict()
{
    const double predicted = shared_error_predict(task_odds);
    return predicted;
}

// Self-test of infer_delayed on a small synthetic crowd, run on-line.
void infer_delayed_test()
{
    const unsigned int n_tasks = 10;
    const double alpha = 2.0;
    const double beta = 1.0;
    infer_delayed del(n_tasks, alpha, beta);

    // the constructor must build square views and store the prior
    assert(del.task_odds.size() == n_tasks);
    assert(del.view_odds.size() == n_tasks);
    assert(del.view_post.size() == n_tasks);
    assert(del.view_odds[3].size() == n_tasks);
    assert(del.view_post[8].size() == n_tasks);
    assert(del.alpha == alpha);
    assert(del.beta == beta);

    // define a simple crowd
    double p = 0.8;
    int quota = 5;
    unsigned int n_workers = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd_budget c(d_work_acc, d_work_quota, n_workers);

    // generate the data in round robin, updating the inference at each step
    mt19937 ran_gen = well_seeded_mt19937(12345);
    const unsigned int n_points = n_workers * quota;
    for(unsigned int t = 0; t < n_points; ++t) {
        c.create_data_point(ran_gen, t);
        c.data_seq[t].task_id = t % n_tasks;
        del.infer_update(ran_gen, &c, t);
    }

    // one history per worker, and better than chance
    assert(del.work_hist.size() == n_workers);
    assert(del.error_rate() < 0.5);
}

/// optimised acyclic Bayesian inference ///

// Prepare the optimised acyclic inference over n_tasks tasks (at least one).
// sub_odds and sub_post are n_tasks x n_tasks tables, filled in infer_full().
infer_quick::infer_quick(unsigned int n_tasks, double alpha, double beta):
    t_next(0),
    alpha(alpha),
    beta(beta)
{
    assert(n_tasks != 0);

    const vector<double> empty_row(n_tasks);

    task_odds.resize(n_tasks, 0.0);
    sub_odds.resize(n_tasks, empty_row);
    sub_post.resize(n_tasks, empty_row);
}

// Run one pass over the data points 0 .. t_max - 1 at subset level n_bit,
// updating the partial posteriors stored in sub_odds / sub_post.
//
// Each row of sub_odds / sub_post holds one partial posterior over all the
// tasks. At level n_bit > 0 the rows touched are the multiples of
// 2^(n_bit - 1); at level 0 every row is touched. The worker histories in
// work_hist must already cover the data points scanned here.
// NOTE(review): the indexing suggests that every level splits the subsets
// of tasks of the previous level in two halves - confirm against the
// description of the algorithm.
void infer_quick::dataset_pass(crowd* c, int n_bit, unsigned int t_max)
{
    // prepare some clever indexing bitmaps:
    //   bit_swap - the single bit (n_bit - 1), or 0 at level 0
    //   bit_mask - bit_swap and the 15 bits above it, so that `& bit_mask`
    //              clears the bits of a task index below bit_swap
    //   sub_inc  - distance between two consecutive rows at this level
    unsigned int bit_swap, bit_mask, sub_inc;
    if(n_bit > 0) {
        bit_swap = 1 << (n_bit - 1);
        bit_mask = bit_swap * 0xFFFF; // works up to 65K tasks
        sub_inc = bit_swap;
    } else {
        bit_swap = 0;
        bit_mask = 0xFFFF; // works up to 65K tasks
        sub_inc = 1;
    }

    // copy the parent's subset partial posterior: i_prev is i_sub with the
    // bit_swap bit cleared (a row with that bit already clear is copied
    // onto itself)
    for(unsigned int i_sub = 0; i_sub < sub_odds.size(); i_sub += sub_inc) {
        unsigned int i_prev = i_sub & ~bit_swap;
        for(unsigned int i = 0; i < sub_odds.size(); ++i) {
            sub_odds[i_sub][i] = sub_odds[i_prev][i];
            sub_post[i_sub][i] = sub_post[i_prev][i];
        }
    }

    // scan the whole dataset once
    for(unsigned int t = 0; t < t_max; ++t) {

        // unpack the current data point
        assert(t < c->data_seq.size());
        unsigned int i = c->data_seq[t].task_id;
        unsigned int j = c->data_seq[t].work_id;
        double label = c->data_seq[t].label;
        assert(i < task_odds.size());
        assert(j < work_hist.size());

        // find the subset to update: drop the low bits of the task index
        // and flip the bit_swap bit; rows beyond the table are skipped
        unsigned int i_sub = (i & bit_mask) ^ bit_swap;
        if(i_sub >= sub_odds.size())
            continue;

        // estimate the worker's weight, starting from the prior
        // pseudo-counts alpha (correct) and alpha + beta (total)
        double correct = alpha;
        double total = alpha + beta;
        for(unsigned int k: work_hist[j].time_steps) {
            unsigned int task_id = c->data_seq[k].task_id;
            unsigned int base_id = (task_id & bit_mask) ^ bit_swap;

            // skip tasks that belong to this subset
            // or have yet to be taken into consideration
            if((task_id >= i_sub && task_id < i_sub + bit_swap) ||
               (base_id == i_sub && k >= t))
                continue;

            // increment the pseudo-counts
            if(c->data_seq[k].label > 0.0)
                correct += sub_post[i_sub][task_id];
            else
                correct += 1.0 - sub_post[i_sub][task_id];
            total += 1.0;
        }
        // work_estim[j] is overwritten at every data point of worker j
        work_estim[j] = correct / total;
        double work_weight = math_logit(work_estim[j]);

        // update the partial posterior
        sub_odds[i_sub][i] += label * work_weight;
        sub_post[i_sub][i] = math_expit(sub_odds[i_sub][i]);
    }
}

void infer_quick::infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    unsigned int n_tasks = sub_odds.size();
    for(unsigned int i = 0; i < n_tasks; ++i)
        for(unsigned int j = 0; j < n_tasks; ++j) {
            sub_odds[i][j] = 0.0;
            sub_post[i][j] = 0.5;
        }

    // initialise the full worker history
    fill(work_hist.begin(), work_hist.end(), history());
    for(unsigned int h = 0; h <= t; ++h) {
        unsigned int j = c->data_seq[h].work_id;
        if(j >= work_hist.size())
            work_hist.resize(j + 1, history());
        work_hist[j].time_steps.push_back(h);
    }
    work_estim.resize(work_hist.size(), 0.5);

    // find the minimum power of two greater than the number of tasks
    unsigned int pow_2 = 1;
    int n_bit = 0;
    while(pow_2 < n_tasks) {
        pow_2 *= 2;
        ++n_bit;
    }

    // split the inference in smaller and smaller subsets
    for(; n_bit >= 0; --n_bit)
        this->dataset_pass(c, n_bit, t + 1);

    // copy the final predictions
    for(unsigned int i = 0; i < n_tasks; ++i)
        task_odds[i] = sub_odds[i][i];
}

// Record the data point collected at step t in the worker's history.
// infer_quick only works in off-line mode, so no inference is done here:
// task_odds is only refreshed by infer_full().
void infer_quick::infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // steps must arrive strictly in order
    assert(t == t_next);
    ++t_next;

    // the crowd has grown
    const size_t n_workers = c->work_acc.size();
    if(n_workers > work_hist.size()) {
        work_estim.resize(n_workers, 0.5);
        work_hist.resize(n_workers, history());
    }

    // find out which worker produced this data point
    assert(t < c->data_seq.size());
    const unsigned int worker = c->data_seq[t].work_id;
    assert(worker < work_hist.size());

    // extend the worker's history
    work_hist[worker].time_steps.push_back(t);
}

// Number of errors in the current task log-odds (see shared_error_num).
double infer_quick::error_num()
{
    const double n_errors = shared_error_num(task_odds);
    return n_errors;
}

// Number of errors divided by the number of tasks.
double infer_quick::error_rate()
{
    const double n_errors = this->error_num();
    const double n_tasks = (double) task_odds.size();
    return n_errors / n_tasks;
}

// Error rate predicted from the current task log-odds
// (see shared_error_predict).
double infer_quick::error_predict()
{
    const double predicted = shared_error_predict(task_odds);
    return predicted;
}

// Self-test of infer_quick on a small synthetic crowd, run off-line.
void infer_quick_test()
{
    const unsigned int n_tasks = 10;
    const double alpha = 2.0;
    const double beta = 1.0;
    infer_quick qck(n_tasks, alpha, beta);

    // the constructor must build square tables and store the prior
    assert(qck.task_odds.size() == n_tasks);
    assert(qck.sub_odds.size() == n_tasks);
    assert(qck.sub_post.size() == n_tasks);
    assert(qck.sub_odds[3].size() == n_tasks);
    assert(qck.sub_post[8].size() == n_tasks);
    assert(qck.alpha == alpha);
    assert(qck.beta == beta);

    // define a simple crowd
    double p = 0.8;
    int quota = 5;
    unsigned int n_workers = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd_budget c(d_work_acc, d_work_quota, n_workers);

    // generate all the data, handing out the tasks in round robin
    mt19937 ran_gen = well_seeded_mt19937(12345);
    const unsigned int n_points = n_workers * quota;
    for(unsigned int t = 0; t < n_points; ++t) {
        c.create_data_point(ran_gen, t);
        c.data_seq[t].task_id = t % n_tasks;
    }
    const unsigned int t_final = n_points - 1;
    qck.infer_full(ran_gen, &c, t_final);

    // one history per worker, and better than chance
    assert(qck.work_hist.size() == n_workers);
    assert(qck.error_rate() < 0.5);
}

/// variational inference ///

// Prepare variational inference over n_tasks tasks (at least one).
// n_iter_full (at least one) and n_iter_update are the numbers of
// maximisation/expectation rounds run by infer_full() and infer_update().
infer_variational::infer_variational(unsigned int n_tasks, double alpha, double beta, unsigned int n_iter_full, unsigned int n_iter_update):
    t_next(0),
    alpha(alpha),
    beta(beta),
    n_iter_full(n_iter_full),
    n_iter_update(n_iter_update)
{
    assert(n_tasks != 0);
    assert(n_iter_full != 0);

    // neutral start: zero log-odds, i.e. a 50% posterior on every task
    this->task_post.resize(n_tasks, 0.5);
    this->task_odds.resize(n_tasks, 0.0);
}

// Maximisation step: re-estimate the accuracy and the weight of every worker
// from the current task posteriors, using the data points 0 .. t_next - 1.
void infer_variational::maximisation(crowd* c)
{
    // The scan below reads data_seq[0 .. t_next - 1], so at least t_next
    // data points are required. (The previous bound, t_next - 1, was one
    // short and wrapped around for t_next == 0.)
    assert(c->data_seq.size() >= t_next);
    assert(work_hist.size() >= correct.size());
    assert(work_weight.size() >= correct.size());
    assert(work_estim.size() >= correct.size());

    // scan all the data observed so far
    fill(correct.begin(), correct.end(), 0.0);
    for(unsigned int t = 0; t < t_next; ++t) {

        // unpack the data point
        unsigned int i = c->data_seq[t].task_id;
        unsigned int j = c->data_seq[t].work_id;
        assert(i < task_post.size());
        assert(j < correct.size());

        // accumulate the posterior probability that the label is correct
        if(c->data_seq[t].label > 0.0)
            correct[j] += task_post[i];
        else
            correct[j] += 1.0 - task_post[i];
    }

    // compute the workers' weights from the posterior mean accuracy
    for(unsigned int j = 0; j < correct.size(); ++j) {
        double total = alpha + beta + (double) work_hist[j].time_steps.size();
        work_estim[j] = (correct[j] + alpha) / total;
        work_weight[j] = math_logit(work_estim[j]);
    }
}

// Expectation step: recompute the log-odds and the posterior of every task
// as a weighted vote over the data points 0 .. t_next - 1.
void infer_variational::expectation(crowd* c)
{
    // The scan below reads data_seq[0 .. t_next - 1], so at least t_next
    // data points are required. (The previous bound, t_next - 1, was one
    // short and wrapped around for t_next == 0.)
    assert(c->data_seq.size() >= t_next);
    assert(task_post.size() >= task_odds.size());

    // scan all the data observed so far
    fill(task_odds.begin(), task_odds.end(), 0.0);
    for(unsigned int t = 0; t < t_next; ++t) {

        // unpack the data point
        unsigned int i = c->data_seq[t].task_id;
        unsigned int j = c->data_seq[t].work_id;
        assert(i < task_odds.size());
        assert(j < work_weight.size());

        // weighted majority voting
        task_odds[i] += c->data_seq[t].label * work_weight[j];
    }

    // compute the posterior on the tasks
    for(unsigned int i = 0; i < task_odds.size(); ++i)
        task_post[i] = math_expit(task_odds[i]);
}

// Redo the inference from scratch on the data points 0 .. t (inclusive),
// running n_iter_full rounds of maximisation followed by expectation.
void infer_variational::infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // data_seq[t] itself is read below, so strictly more than t data points
    // are required (the previous check, size() >= t, was one short)
    assert(c->data_seq.size() > t);
    t_next = t + 1;

    // back to the neutral state on every task
    fill(task_odds.begin(), task_odds.end(), 0.0);
    fill(task_post.begin(), task_post.end(), 0.5);

    // one slot per worker; the histories are emptied first so that resize
    // recreates them all. work_weight and work_estim are fully overwritten
    // by the first maximisation step.
    const size_t n_workers = c->work_acc.size();
    work_hist.resize(0);
    work_hist.resize(n_workers, history());
    work_weight.resize(n_workers);
    work_estim.resize(n_workers);
    correct.resize(n_workers);

    // prepare the worker history
    for(unsigned int step = 0; step < t_next; ++step) {
        unsigned int j = c->data_seq[step].work_id;
        assert(j < work_hist.size());
        work_hist[j].time_steps.push_back(step);
    }

    // run the required number of iterations
    for(unsigned h = 0; h < n_iter_full; ++h) {
        this->maximisation(c);
        this->expectation(c);
    }
}

void infer_variational::infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // avoid skipping any steps
    assert(t == t_next);
    ++t_next;

    // extra workers are joining!
    if(c->work_acc.size() > work_hist.size()) {
        unsigned int n_workers = c->work_acc.size();
        double work_prior = math_logit(alpha / (alpha + beta));
        work_hist.resize(n_workers, history());
        work_weight.resize(n_workers, work_prior); // initialise work_weight
        work_estim.resize(n_workers, alpha / (alpha + beta));
        correct.resize(n_workers);
    }

    // unpack the new data point
    assert(t < c->data_seq.size());
    unsigned int i = c->data_seq[t].task_id;
    unsigned int j = c->data_seq[t].work_id;
    double label = c->data_seq[t].label;
    assert(i < task_odds.size());
    assert(j < work_hist.size());

    // perform a quick expectation step
    work_hist[j].time_steps.push_back(t);
    task_odds[i] += label * work_weight[j];
    task_post[i] = math_expit(task_odds[i]);

    // run the required number of iterations
    for(unsigned int h = 0; h < n_iter_update; ++h) {
        this->maximisation(c);
        this->expectation(c);
    }
}

// Number of errors in the current task log-odds (see shared_error_num).
double infer_variational::error_num()
{
    const double n_errors = shared_error_num(task_odds);
    return n_errors;
}

// Number of errors divided by the number of tasks.
double infer_variational::error_rate()
{
    const double n_errors = this->error_num();
    const double n_tasks = (double) task_odds.size();
    return n_errors / n_tasks;
}

// Error rate predicted from the current task log-odds
// (see shared_error_predict).
double infer_variational::error_predict()
{
    const double predicted = shared_error_predict(task_odds);
    return predicted;
}

/// Self-test of infer_variational.
///
/// Feeds a small simulated crowd one data point at a time through
/// infer_update, then reruns infer_full on the same data and requires both
/// paths to end in the same state (within machine epsilon).
void infer_variational_test()
{
    // configuration under test
    unsigned int n_tasks = 10;
    double alpha = 2.0;
    double beta = 1.0;
    unsigned int n_iter_full = 50;
    unsigned int n_iter_update = 50;
    infer_variational var(n_tasks, alpha, beta, n_iter_full, n_iter_update);

    // the constructor must size the task vectors and store its parameters
    assert(var.task_odds.size() == n_tasks);
    assert(var.task_post.size() == n_tasks);
    assert(var.alpha == alpha);
    assert(var.beta == beta);
    assert(var.n_iter_full == n_iter_full);
    assert(var.n_iter_update == n_iter_update);

    // a crowd where every worker has accuracy p and a fixed quota of labels
    double p = 0.8;
    int quota = 5;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    unsigned int n_workers = 10;
    crowd_budget c(d_work_acc, d_work_quota, n_workers);

    // sequential pass: tasks are assigned in round robin order
    mt19937 ran_gen = well_seeded_mt19937(12345);
    unsigned int t_final = n_workers * quota - 1;
    for(unsigned int step = 0; step <= t_final; ++step) {
        c.create_data_point(ran_gen, step);
        c.data_seq[step].task_id = step % n_tasks;
        var.infer_update(ran_gen, &c, step);
    }

    assert(var.work_weight.size() == n_workers);
    assert(var.work_hist.size() == n_workers);
    assert(var.error_rate() < 0.5);

    // snapshot of the sequential result
    vector<double> task_odds_final = var.task_odds;
    vector<double> task_post_final = var.task_post;
    vector<double> work_weight_final = var.work_weight;
    vector<history> work_hist_final = var.work_hist;

    // batch pass over exactly the same data
    var.infer_full(ran_gen, &c, t_final);

    // both passes must agree
    double epsilon = numeric_limits<double>::epsilon();
    for(unsigned int task = 0; task < n_tasks; ++task) {
        double diff_odds = var.task_odds[task] - task_odds_final[task];
        double diff_post = var.task_post[task] - task_post_final[task];
        assert(-epsilon <= diff_odds && diff_odds <= epsilon);
        assert(-epsilon <= diff_post && diff_post <= epsilon);
    }
    for(unsigned int worker = 0; worker < n_workers; ++worker) {
        double diff_weight = var.work_weight[worker] - work_weight_final[worker];
        assert(-epsilon <= diff_weight && diff_weight <= epsilon);
    }
    assert(var.work_hist == work_hist_final);
}

/// matrix factorisation ///

/// Set up the matrix factorisation method.
///
/// @param n_tasks        number of tasks; must be positive
/// @param n_iter_full    power iterations run by infer_full; must be positive
/// @param n_iter_update  power iterations run by infer_update (zero means
///                       that infer_update only records the history)
infer_eigen::infer_eigen(unsigned int n_tasks, unsigned int n_iter_full, unsigned int n_iter_update):
    t_next(0),
    n_iter_full(n_iter_full),
    n_iter_update(n_iter_update)
{
    assert(n_tasks > 0);
    assert(this->n_iter_full > 0);

    // every task starts with neutral odds
    const double neutral_odds = 0.0;
    task_odds.resize(n_tasks, neutral_odds);
}

/// One round of the alternating update used by the matrix factorisation.
///
/// First every worker weight is recomputed as the sum of label * task_odds
/// over the worker's data points, then every task odds value is recomputed
/// as the sum of label * work_weight over the task's data points, and
/// finally the task odds are rescaled with math_normalise_1.
///
/// @param c  crowd whose first t_next data points are used
void infer_eigen::power_iteration(crowd *c)
{
    // both loops below read every time step in [0, t_next), so all of them
    // must exist (the former check allowed one missing element and wrapped
    // around for t_next == 0)
    assert(c->data_seq.size() >= t_next);

    // estimate the worker log odds
    fill(work_weight.begin(), work_weight.end(), 0.0);
    for(unsigned int t = 0; t < t_next; ++t) {

        // unpack the data point
        unsigned int i = c->data_seq[t].task_id;
        unsigned int j = c->data_seq[t].work_id;
        assert(i < task_odds.size());
        assert(j < work_weight.size());

        // spammer/hammer model
        work_weight[j] += c->data_seq[t].label * task_odds[i];
    }

    // estimate the task log odds (needs the complete worker weights,
    // hence the second pass over the data)
    fill(task_odds.begin(), task_odds.end(), 0.0);
    for(unsigned int t = 0; t < t_next; ++t) {

        // unpack the data point
        unsigned int i = c->data_seq[t].task_id;
        unsigned int j = c->data_seq[t].work_id;
        assert(i < task_odds.size());
        assert(j < work_weight.size());

        // weighted majority voting
        task_odds[i] += c->data_seq[t].label * work_weight[j];
    }

    // normalise the task odds
    math_normalise_1(task_odds);
}

/// Batch inference over the data points 0..t (inclusive).
///
/// Resets the task odds and the worker history, initialises the task odds
/// with a plain majority vote, normalises them and then runs n_iter_full
/// power iterations.
///
/// @param ran_gen  random generator (not referenced by this method)
/// @param c        crowd providing work_acc and data_seq
/// @param t        index of the last data point to include
void infer_eigen::infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // data_seq[t] itself is read below, so it has to exist
    assert(t < c->data_seq.size());
    t_next = t + 1;

    // start from a clean state
    fill(task_odds.begin(), task_odds.end(), 0.0);
    work_hist.resize(0);
    work_hist.resize(c->work_acc.size(), history());
    work_weight.resize(c->work_acc.size());
    work_estim.resize(c->work_acc.size(), 0.5); // this method provides no estimates

    // scan all the data observed so far
    for(t = 0; t < t_next; ++t) {

        // unpack the data point
        unsigned int i = c->data_seq[t].task_id;
        unsigned int j = c->data_seq[t].work_id;
        assert(i < task_odds.size());
        assert(j < work_hist.size());

        // majority voting initialisation
        task_odds[i] += c->data_seq[t].label;

        // prepare the worker history
        work_hist[j].time_steps.push_back(t);
    }

    // normalise the task odds
    math_normalise_1(task_odds);

    // run the required number of iterations
    for(unsigned h = 0; h < n_iter_full; ++h)
        this->power_iteration(c);
}

void infer_eigen::infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // avoid skipping any steps
    assert(t == t_next);
    ++t_next;

    // just update the history
    if(n_iter_update == 0) {
        if(c->work_acc.size() > work_hist.size()) {
            work_hist.resize(c->work_acc.size(), history());
            work_estim.resize(c->work_acc.size(), 0.5);
        }
        assert(t < c->data_seq.size());
        unsigned int j = c->data_seq[t].work_id;
        assert(j < work_hist.size());
        work_hist[j].time_steps.push_back(t);

    // there is no easy way of updating the factorisation
    // we use the full algorithm instead
    } else {
        unsigned int n_iter_backup = n_iter_full;
        n_iter_full = n_iter_update;
        this->infer_full(ran_gen, c, t);
        n_iter_full = n_iter_backup;
    }
}

/// Error count of the current estimate, as computed by shared_error_num:
/// one for every negative entry of task_odds, one half for every zero entry.
double infer_eigen::error_num()
{
    const double n_errors = shared_error_num(task_odds);
    return n_errors;
}

double infer_eigen::error_rate()
{
    return this->error_num() / (double) task_odds.size();
}

/// Predicted error of the current estimate; delegates to
/// shared_error_predict applied to task_odds.
double infer_eigen::error_predict()
{
    const double predicted = shared_error_predict(task_odds);
    return predicted;
}

/// Self-test of infer_eigen: runs the sequential update over a small
/// simulated crowd and checks container sizes, the error rate and that the
/// configuration survived the updates.
void infer_eigen_test()
{
    // configuration under test
    unsigned int n_tasks = 10;
    unsigned int n_iter_full = 50;
    unsigned int n_iter_update = 50;
    infer_eigen eig(n_tasks, n_iter_full, n_iter_update);

    // a crowd where every worker has accuracy p and a fixed quota of labels
    double p = 0.8;
    int quota = 5;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    unsigned int n_workers = 10;
    crowd_budget c(d_work_acc, d_work_quota, n_workers);

    // sequential pass: tasks are assigned in round robin order
    mt19937 ran_gen = well_seeded_mt19937(12345);
    unsigned int t_final = n_workers * quota - 1;
    for(unsigned int step = 0; step <= t_final; ++step) {
        c.create_data_point(ran_gen, step);
        c.data_seq[step].task_id = step % n_tasks;
        eig.infer_update(ran_gen, &c, step);
    }

    // state after the last update
    assert(eig.work_weight.size() == n_workers);
    assert(eig.work_hist.size() == n_workers);
    assert(eig.error_rate() < 0.5);

    // infer_update temporarily swaps n_iter_full: it must have been restored
    assert(eig.task_odds.size() == n_tasks);
    assert(eig.n_iter_full == n_iter_full);
    assert(eig.n_iter_update == n_iter_update);
}

/// triangular estimation ///

/// Set up the triangular estimation method.
///
/// @param n_tasks         number of tasks; must be positive
/// @param bool_iter_full  when non-zero, infer_update only accumulates the
///                        worker statistics and skips the estimation step
infer_triangle::infer_triangle(unsigned int n_tasks, unsigned int bool_iter_full):
    t_next(0),
    bool_iter_full(bool_iter_full)
{
    assert(n_tasks > 0);

    // every task starts with neutral odds
    const double neutral_odds = 0.0;
    task_odds.resize(n_tasks, neutral_odds);
}

void infer_triangle::infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // avoid skipping any steps
    assert(c->data_seq.size() >= t);
    t_next = 0;

    // call the sequential routine
    bool_iter_full = 1;
    for(unsigned int h = 0; h < t; ++h)
        this->infer_update(ran_gen, c, h);

    bool_iter_full = 0;
    this->infer_update(ran_gen, c, t);
}

/// Locate the entry with the largest absolute value in the strict upper
/// triangle of the leading n x n part of matrix.
///
/// Row i_skip and column j_skip are excluded from the search. On return
/// *i_max / *j_max hold the position of the winner, or (0, 0) when no entry
/// qualified.
///
/// @return the largest absolute value found, or -1.0 when no entry qualified
static double find_max(vector<vector<double>>& matrix,
                       unsigned int *i_max,
                       unsigned int *j_max,
                       unsigned int i_skip,
                       unsigned int j_skip,
                       unsigned int n)
{
    double best = -1.0;
    unsigned int best_row = 0;
    unsigned int best_col = 0;

    for(unsigned int row = 0; row < n; ++row) {
        if(row != i_skip) {
            for(unsigned int col = row + 1; col < n; ++col) {
                if(col != j_skip) {
                    // absolute value of the candidate
                    double magnitude = matrix[row][col];
                    if(!(magnitude >= 0.0))
                        magnitude = -magnitude;

                    // strict comparison: the first of several equal maxima wins
                    if(magnitude > best) {
                        best = magnitude;
                        best_row = row;
                        best_col = col;
                    }
                }
            }
        }
    }

    *i_max = best_row;
    *j_max = best_col;
    return best;
}

/// Sequential update with the data point observed at time step t.
///
/// The pairwise statistics between the labelling worker and every other
/// worker are updated first. Unless bool_iter_full is set, the method then
/// derives an absolute accuracy for each worker from products of
/// correlations, assigns the signs relative to a reference worker,
/// converts the accuracies into weights and recomputes all task odds by
/// weighted majority voting.
///
/// @param ran_gen  random generator (not referenced by this method)
/// @param c        crowd providing work_acc and data_seq
/// @param t        index of the new data point; must be equal to t_next
void infer_triangle::infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // avoid skipping any steps
    assert(t == t_next);
    ++t_next;

    // new workers: the rows that already exist have to grow as well,
    // otherwise the matrices become ragged and [j][k] runs out of bounds
    unsigned int n_workers = c->work_acc.size();
    if(work_weight.size() < n_workers) {
        work_corr.resize(n_workers);
        work_norm.resize(n_workers);
        work_prod.resize(n_workers);
        for(unsigned int r = 0; r < n_workers; ++r) {
            work_corr[r].resize(n_workers, 0.0);
            work_norm[r].resize(n_workers, 0.0);
            work_prod[r].resize(n_workers, 0.0);
        }
        work_hist.resize(n_workers, history());
        work_theta.resize(n_workers);
        work_weight.resize(n_workers);
        work_estim.resize(n_workers, 0.5);
    }

    // unpack the new data point
    assert(t < c->data_seq.size());
    unsigned int i = c->data_seq[t].task_id;
    unsigned int j = c->data_seq[t].work_id;
    double label = c->data_seq[t].label;
    assert(i < task_odds.size());
    assert(j < work_hist.size());

    // update the history
    work_hist[j].time_steps.push_back(t);

    // for each worker
    for(unsigned int k = 0; k < n_workers; ++k) {

        // is there a label that matches the current task?
        // (the most recent one wins; 0.0 means "no label")
        double label_k = 0.0;
        for(unsigned int h: work_hist[k].time_steps)
            if(c->data_seq[h].task_id == i)
                label_k = c->data_seq[h].label;

        // update the scalar products and the absolute norm
        double prod = label * label_k;
        double norm = (prod >= 0.0)? prod: -prod;
        work_prod[j][k] += prod;
        work_norm[j][k] += norm;
        work_prod[k][j] = work_prod[j][k];
        work_norm[k][j] = work_norm[j][k];

        // update the correlation estimate (the denominator is at least one)
        double denom = (work_norm[j][k] >= 1.0)? work_norm[j][k]: 1.0;
        work_corr[j][k] = work_prod[j][k] / denom;
        work_corr[k][j] = work_corr[j][k];
    }

    // skip unnecessary computation
    if(bool_iter_full)
        return;

    // find the highest correlation index
    unsigned int r_max, s_max;
    find_max(work_corr, &r_max, &s_max, n_workers, n_workers, n_workers);

    // estimate the absolute workers' accuracy
    for(unsigned int k = 0; k < n_workers; ++k) {

        // a worker cannot be part of its own reference pair
        unsigned int r_alt = r_max;
        unsigned int s_alt = s_max;
        if(k == r_max || k == s_max)
            find_max(work_corr, &r_alt, &s_alt, k, k, n_workers);

        double theta = 0.0;
        if(work_corr[r_alt][s_alt] != 0.0)
            theta = work_corr[r_alt][k] * work_corr[s_alt][k] / work_corr[r_alt][s_alt];
        work_theta[k] = (theta > 0.0)? theta: -theta;
    }

    // find the reference sign
    double v_max = 0.0, abs_v_max = -1.0;
    unsigned int k_max = 0;
    for(unsigned int k = 0; k < n_workers; ++k) {
        double v = work_theta[k];
        for(unsigned int h = 0; h < n_workers; ++h) {
            if(h == k) continue;
            v += work_corr[h][k];
        }
        double abs_v = (v >= 0.0)? v: -v;
        if(abs_v > abs_v_max) {
            v_max = v;
            abs_v_max = abs_v;
            k_max = k;
        }
    }
    work_theta[k_max] = sqrt(work_theta[k_max]);
    if(v_max < 0.0) work_theta[k_max] = -work_theta[k_max];

    // find the sign of all other workers
    for(unsigned int k = 0; k < n_workers; ++k) {
        if(k == k_max) continue;
        double sign = work_theta[k_max] * work_corr[k][k_max];
        if(sign >= 0.0)
            work_theta[k] = sqrt(work_theta[k]);
        else
            work_theta[k] = -sqrt(work_theta[k]);
    }

    // compute the weights
    for(unsigned int k = 0; k < n_workers; ++k) {
        work_estim[k] = 0.5 * (work_theta[k] + 1.0);
        work_weight[k] = math_logit_safe(work_estim[k]);
    }

    // weighted majority voting
    fill(task_odds.begin(), task_odds.end(), 0.0);
    for(unsigned int h = 0; h <= t; ++h) {
        unsigned int i = c->data_seq[h].task_id;
        unsigned int j = c->data_seq[h].work_id;
        double label = c->data_seq[h].label;
        task_odds[i] += label * work_weight[j];
    }
}

/// Error count of the current estimate, as computed by shared_error_num:
/// one for every negative entry of task_odds, one half for every zero entry.
double infer_triangle::error_num()
{
    const double n_errors = shared_error_num(task_odds);
    return n_errors;
}

double infer_triangle::error_rate()
{
    return this->error_num() / (double) task_odds.size();
}

/// Predicted error of the current estimate; delegates to
/// shared_error_predict applied to task_odds.
double infer_triangle::error_predict()
{
    const double predicted = shared_error_predict(task_odds);
    return predicted;
}

/// Self-test of infer_triangle: runs the sequential update (accumulation
/// only, since bool_iter_full is set) over a small simulated crowd and
/// checks the sizes of all per-worker containers.
void infer_triangle_test()
{
    // configuration under test
    unsigned int n_tasks = 10;
    unsigned int bool_iter_full = 1;
    infer_triangle tria(n_tasks, bool_iter_full);
    assert(tria.task_odds.size() == n_tasks);

    // a crowd where every worker has accuracy p and a fixed quota of labels
    double p = 0.8;
    int quota = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    unsigned int n_workers = 10;
    crowd_budget c(d_work_acc, d_work_quota, n_workers);

    // sequential pass: tasks are assigned in round robin order
    mt19937 ran_gen = well_seeded_mt19937(12345);
    unsigned int t_final = n_workers * quota - 1;
    for(unsigned int step = 0; step <= t_final; ++step) {
        c.create_data_point(ran_gen, step);
        c.data_seq[step].task_id = step % n_tasks;
        tria.infer_update(ran_gen, &c, step);
    }

    // every per-worker container must cover the whole crowd
    assert(tria.work_corr.size() == n_workers);
    assert(tria.work_norm.size() == n_workers);
    assert(tria.work_prod.size() == n_workers);
    assert(tria.work_hist.size() == n_workers);
    assert(tria.work_theta.size() == n_workers);
    assert(tria.work_weight.size() == n_workers);

    // non-strict bound: untouched (zero) odds count as half an error each
    assert(tria.error_rate() <= 0.5);
}

/// mirrored importance sampling ///

/// Copy constructor: duplicates the class assignment, both worker counters
/// and the log-weight of another particle.
particle::particle(const particle& part)
{
    // hypothesis carried by the particle
    task_class = part.task_class;
    log_weight = part.log_weight;

    // per-worker counters that follow from the hypothesis
    work_right = part.work_right;
    work_wrong = part.work_wrong;
}

/// Create a particle with a random class assignment: every task gets +1.0
/// or -1.0 with equal probability. The log-weight starts at zero and the
/// worker counters are left untouched.
///
/// @param ran_gen  random generator; one draw is consumed per task
/// @param n_tasks  number of tasks
particle::particle(mt19937& ran_gen,
                   unsigned int n_tasks)
{
    log_weight = 0.0;

    task_class.resize(n_tasks, 0.0);

    // fair coin toss for each task, in task order
    uniform_real_distribution<double> coin(0.0, 1.0);
    for(auto& cls: task_class) {
        if(coin(ran_gen) < 0.5)
            cls = +1.0;
        else
            cls = -1.0;
    }
}

/// Account for the data point at time step t: the label is counted as right
/// or wrong for its worker according to this particle's class assignment,
/// and log_weight changes by the difference of math_log_beta evaluated
/// after and before the counter update.
///
/// @param c      crowd providing work_acc and data_seq
/// @param t      index of the data point inside c->data_seq
/// @param alpha  added to the worker's right counter inside math_log_beta
/// @param beta   added to the worker's wrong counter inside math_log_beta
void particle::update_weight(crowd *c,
                             unsigned int t,
                             double alpha,
                             double beta)
{
    // make room for workers that joined the crowd in the meantime
    if(work_right.size() < c->work_acc.size()) {
        work_right.resize(c->work_acc.size());
        work_wrong.resize(c->work_acc.size());
    }

    // read the fields of the data point
    assert(t < c->data_seq.size());
    const unsigned int task = c->data_seq[t].task_id;
    const unsigned int worker = c->data_seq[t].work_id;
    const double vote = c->data_seq[t].label;
    assert(task < task_class.size());
    assert(worker < work_right.size());

    // remove the old contribution of this worker ...
    log_weight -= math_log_beta(work_right[worker] + alpha, work_wrong[worker] + beta);

    // ... count the new label ...
    const bool agrees = (vote == task_class[task]);
    if(agrees)
        work_right[worker] += 1.0;
    else
        work_wrong[worker] += 1.0;

    // ... and add the contribution back with the new counters
    log_weight += math_log_beta(work_right[worker] + alpha, work_wrong[worker] + beta);
}

/// Gibbs step on the class of task i.
///
/// The log-odds of swapping task_class[i] are accumulated over all data
/// points of the task; the swap is performed when ran_num falls below the
/// corresponding probability, in which case the worker counters are moved
/// accordingly.
///
/// @param c          crowd providing data_seq
/// @param task_hist  time steps of the data points that belong to task i
/// @param i          task whose class may change
/// @param ran_num    random number the swap probability is compared with
/// @param alpha      added to the right counters inside math_log_beta
/// @param beta       added to the wrong counters inside math_log_beta
void particle::extract_class(crowd* c,
                             history& task_hist,
                             unsigned int i,
                             double ran_num,
                             double alpha,
                             double beta)
{
    // accumulate the log-likelihood difference caused by the swap
    double swap_log_odds = 0.0;
    for(unsigned int step: task_hist.time_steps) {

        // read the related data point
        assert(step < c->data_seq.size());
        assert(c->data_seq[step].task_id == i);
        const unsigned int worker = c->data_seq[step].work_id;
        const double vote = c->data_seq[step].label;
        assert(worker < work_right.size());

        // product of the label and the current class: the amount moved
        // from the right counter to the wrong one by the swap
        const double shift = vote * task_class[i];
        swap_log_odds -= math_log_beta(work_right[worker] + alpha,
                                       work_wrong[worker] + beta);
        swap_log_odds += math_log_beta(work_right[worker] + alpha - shift,
                                       work_wrong[worker] + beta + shift);
    }

    // stochastic decision: nothing else to do when the class is kept
    const double swap_prob = math_expit(swap_log_odds);
    if(!(ran_num < swap_prob))
        return;

    task_class[i] = -task_class[i];

    // bring the worker counters in line with the new class
    for(unsigned int step: task_hist.time_steps) {
        const unsigned int worker = c->data_seq[step].work_id;
        const double vote = c->data_seq[step].label;

        if(vote != task_class[i]) {
            // the label used to be right and is now wrong
            assert(work_right[worker] > 0);
            work_right[worker] -= 1.0;
            work_wrong[worker] += 1.0;
        } else {
            // the label used to be wrong and is now right
            assert(work_wrong[worker] > 0);
            work_right[worker] += 1.0;
            work_wrong[worker] -= 1.0;
        }
    }
}

/// Try to negate the class of every task at once.
///
/// The log-odds of the flip compare math_log_beta with the right/wrong
/// counters exchanged against the current arrangement, summed over all
/// workers. The flip happens when ran_num falls below the resulting
/// probability; the two counter vectors are then swapped.
///
/// @param ran_num  random number the flip probability is compared with
/// @param alpha    added to the first argument of math_log_beta
/// @param beta     added to the second argument of math_log_beta
void particle::flip_classes(double ran_num,
                            double alpha,
                            double beta)
{
    // log-likelihood difference between the mirrored and the current state
    double flip_log_odds = 0.0;
    const unsigned int n_counters = work_right.size();
    for(unsigned int worker = 0; worker < n_counters; ++worker) {
        flip_log_odds -= math_log_beta(work_right[worker] + alpha,
                                       work_wrong[worker] + beta);
        flip_log_odds += math_log_beta(work_wrong[worker] + alpha,
                                       work_right[worker] + beta);
    }

    // keep the current state unless the draw falls below the probability
    const double flip_prob = math_expit(flip_log_odds);
    if(!(ran_num < flip_prob))
        return;

    // mirror the classes; right and wrong exchange their roles
    for(auto& cls: task_class)
        cls = -cls;
    work_right.swap(work_wrong);
}

// replace every entry with the sum of itself and all the entries before it
// (i.e. turn categorical weights into their cumulative sums)
static void compute_cumsum_inplace(vector<double>& weight)
{
    if(weight.empty())
        return;

    vector<double>::iterator prev = weight.begin();
    for(vector<double>::iterator cur = prev + 1; cur != weight.end(); ++cur) {
        *cur += *prev;
        prev = cur;
    }
}

// binary search over the cumulative weights cum_sum[a..b-1]: returns the
// first index whose cumulative weight exceeds ran_num, or b - 1 when none
// does; the range must not be empty
static unsigned int extract_id(vector<double>& cum_sum,
                               unsigned int a,
                               unsigned int b,
                               double ran_num)
{
    unsigned int lo = a;
    unsigned int hi = b;

    // shrink [lo, hi) until a single candidate is left
    while(hi - lo != 1) {
        const unsigned int pivot = (hi - lo) / 2 + lo - 1;

        if(ran_num >= cum_sum[pivot])
            lo = pivot + 1; // the answer lies to the right of the pivot
        else
            hi = pivot + 1; // the pivot itself is still a candidate
    }

    return lo;
}

// extract from a categorical distribution with replacement
//
// weight holds the (unnormalised) probabilities on entry and is overwritten
// with their cumulative sums; sample[i] receives the number of times that
// index i was drawn, over weight.size() draws in total. An empty weight
// vector yields an empty sample and consumes no random numbers.
static void extract_replace(vector<double>& weight,
                            vector<unsigned int>& sample,
                            mt19937& ran_gen)
{
    sample.resize(weight.size());
    fill(sample.begin(), sample.end(), 0);

    // nothing to draw from: weight.back() would be out of bounds
    if(weight.empty())
        return;

    // the last cumulative sum is the total mass to sample from
    compute_cumsum_inplace(weight);
    uniform_real_distribution<double> d(0.0, weight.back());

    for(unsigned int n = 0; n < weight.size(); ++n) {
        unsigned int i = extract_id(weight, 0, weight.size(), d(ran_gen));
        ++sample[i];
    }
}

/// Resample the swarm in proportion to the particle weights.
///
/// The log-weights are exponentiated into swarm_weight and reset to zero,
/// extract_replace decides how many copies of each particle survive, and
/// the particles that were not drawn are overwritten with the extra copies
/// of those that were drawn more than once.
///
/// @param ran_gen  random generator used for the draws
void infer_particle::resample_particles(mt19937& ran_gen)
{
    // compute the weights (no need to normalise them)
    for(unsigned int i = 0; i < swarm.size(); ++i) {
        swarm_weight[i] = exp(swarm[i].log_weight);
        swarm[i].log_weight = 0.0; // we are sampling according to the likelihood
    }

    // extract a subset of particles with replacement (duplicates!)
    extract_replace(swarm_weight, swarm_sample, ran_gen);

    // overwrite the particles we do not want (j) with the duplicates (i)
    unsigned int j = 0;
    for(unsigned int i = 0; i < swarm.size(); ++i)
        for(unsigned int n = 1; n < swarm_sample[i]; ++n) {
            // the bounds check has to come before the element access
            while(j < swarm.size() && swarm_sample[j] > 0) ++j;

            // the counts sum to swarm.size(), so a free slot must exist
            assert(j < swarm.size());
            swarm[j] = swarm[i];
            ++j;
        }
}

/// Set up the particle-based inference method.
///
/// @param n_tasks      number of tasks; must be positive
/// @param alpha        first prior parameter, passed on to the particles
/// @param beta         second prior parameter, passed on to the particles
/// @param n_particles  size of the swarm; must be positive
/// @param n_gibbs      Gibbs steps applied to each particle per update
infer_particle::infer_particle(unsigned int n_tasks,
                           double alpha,
                           double beta,
                           unsigned int n_particles,
                           unsigned int n_gibbs):
    t_next(0),
    alpha(alpha),
    beta(beta),
    n_particles(n_particles),
    n_gibbs(n_gibbs)
{
    assert(n_tasks > 0);
    assert(this->n_particles > 0);

    // neutral starting point: zero odds, 50% posterior, no history
    const double neutral_odds = 0.0;
    const double neutral_post = 0.5;
    task_odds.resize(n_tasks, neutral_odds);
    task_post.resize(n_tasks, neutral_post);
    task_hist.resize(n_tasks, history());
}

void infer_particle::infer_full(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // avoid skipping any steps
    assert(c->data_seq.size() >= t);
    t_next = 0;

    // call the sequential routine
    for(unsigned int h = 0; h <= t; ++h)
        this->infer_update(ran_gen, c, h);
}

void infer_particle::infer_update(std::mt19937& ran_gen, crowd* c, unsigned int t)
{
    // avoid skipping any steps
    assert(t == t_next);
    ++t_next;

    // extra workers are joining!
    if(c->work_acc.size() > work_hist.size()) {
        unsigned int n_workers = c->work_acc.size();
        work_hist.resize(n_workers, history());
        work_estim.resize(n_workers, 0.5);
    }

    // create or extend the set of particles
    if(swarm.size() < n_particles) {
        for(unsigned int k = swarm.size(); k < n_particles; ++k)
            swarm.push_back(particle(ran_gen, task_odds.size()));
        swarm_weight.resize(n_particles);
        swarm_sample.resize(n_particles);
    }

    // unpack the new data point
    assert(t < c->data_seq.size());
    unsigned int i = c->data_seq[t].task_id;
    unsigned int j = c->data_seq[t].work_id;
    assert(i < task_hist.size());
    assert(j < work_hist.size());

    // update the histories
    task_hist[i].time_steps.push_back(t);
    work_hist[j].time_steps.push_back(t);

    // update the particle weights with the new data point
    for(unsigned int k = 0; k < swarm.size(); ++k)
        swarm[k].update_weight(c, t, alpha, beta);

    // resample the particles
    this->resample_particles(ran_gen);

    // move the particles around by the required number of Gibbs steps
    // note: doing so in random order is better than n_tasks steps in order!
    unsigned int n_tasks = task_odds.size();
    uniform_int_distribution<int> u(0, n_tasks - 1);
    uniform_real_distribution<double> d(0.0, 1.0);
    for(unsigned int k = 0; k < swarm.size(); ++k)
        for(unsigned int g = 0; g < n_gibbs; ++g) {
            unsigned int i_change = u(ran_gen);
            swarm[k].extract_class(c, task_hist[i_change], i_change, d(ran_gen), alpha, beta);
        }

    // try and flip the classes of one random particle
    // note: this avoids getting stuck in the wrong mode of the posterior
    unsigned int p = d(ran_gen) * n_particles;
    swarm[p].flip_classes(d(ran_gen), alpha, beta);

    // compute the posterior
    fill(task_post.begin(), task_post.end(), 0.0);
    for(unsigned int k = 0; k < swarm.size(); ++k)
        for(i = 0; i < task_post.size(); ++i)
            task_post[i] += swarm[k].task_class[i];
    for(i = 0; i < task_post.size(); ++i) {
        task_post[i] = (1.0 + task_post[i] / (double) n_particles) / 2.0;
        task_odds[i] = math_logit_safe(task_post[i]);
    }
}

/// Error count of the current estimate, as computed by shared_error_num:
/// one for every negative entry of task_odds, one half for every zero entry.
double infer_particle::error_num()
{
    const double n_errors = shared_error_num(task_odds);
    return n_errors;
}

double infer_particle::error_rate()
{
    return this->error_num() / (double) task_odds.size();
}

/// Predicted error of the current estimate; delegates to
/// shared_error_predict applied to task_odds.
double infer_particle::error_predict()
{
    const double predicted = shared_error_predict(task_odds);
    return predicted;
}

/// Self-test of infer_particle: checks that the constructor stores its
/// parameters, then runs the sequential update over a small simulated crowd
/// and checks the history size and the error rate.
void infer_particle_test()
{
    // configuration under test
    unsigned int n_tasks = 10;
    double alpha = 2.0;
    double beta = 1.0;
    unsigned int n_particles = 10;
    unsigned int n_gibbs = 1;
    infer_particle mir(n_tasks, alpha, beta, n_particles, n_gibbs);

    // the constructor must size the task vector and store its parameters
    assert(mir.task_odds.size() == n_tasks);
    assert(mir.alpha == alpha);
    assert(mir.beta == beta);
    assert(mir.n_particles == n_particles);
    assert(mir.n_gibbs == n_gibbs);

    // a crowd where every worker has accuracy p and a fixed quota of labels
    double p = 0.8;
    int quota = 5;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    unsigned int n_workers = 10;
    crowd_budget c(d_work_acc, d_work_quota, n_workers);

    // sequential pass: tasks are assigned in round robin order
    mt19937 ran_gen = well_seeded_mt19937(12345);
    unsigned int t_final = n_workers * quota - 1;
    for(unsigned int step = 0; step <= t_final; ++step) {
        c.create_data_point(ran_gen, step);
        c.data_seq[step].task_id = step % n_tasks;
        mir.infer_update(ran_gen, &c, step);
    }

    // state after the last update
    assert(mir.work_hist.size() == n_workers);
    assert(mir.error_rate() < 0.5);
}

/// parser ///

/// Build an inference object from a command-line style argument list.
///
/// (*argv)[0] names the type, the following entries are its constructor
/// parameters. On success the consumed entries are removed from the front
/// of the list (*argc and *argv are advanced) and a newly allocated object
/// is returned. On failure a message is written to cerr, the list is left
/// untouched and NULL is returned.
///
/// @param argc  number of entries still available in *argv
/// @param argv  argument list; advanced past the consumed entries
/// @return      the new inference object, or NULL on failure
infer* infer_parse(int *argc, char **argv[])
{
    if(*argc < 1) {
        cerr << "Not enough arguments to parse the infer type" << endl;
        return NULL;
    }

    const char *kind = (*argv)[0];

    // checks that `needed` entries (type name included) are available,
    // printing `complaint` when they are not
    auto enough = [&](int needed, const char *complaint) -> bool {
        if(*argc >= needed)
            return true;
        cerr << complaint << endl;
        return false;
    };

    // readers for the k-th entry of the list
    auto as_count = [&](int k) -> unsigned int { return strtoul((*argv)[k], NULL, 10); };
    auto as_real = [&](int k) -> double { return strtod((*argv)[k], NULL); };

    // drops the first n entries from the list
    auto consume = [&](int n) { *argc -= n; *argv += n; };

    if(strcmp(kind, "infer_weighted") == 0) {
        if(!enough(2, "Not enough arguments to instantiate infer_weighted"))
            return NULL;
        unsigned int n_tasks = as_count(1);
        consume(2);
        return (infer*) new infer_weighted(n_tasks);
    }

    if(strcmp(kind, "infer_golden") == 0) {
        if(!enough(5, "Not enough arguments to instantiate infer_golden"))
            return NULL;
        unsigned int n_tasks = as_count(1);
        unsigned int n_trials = as_count(2);
        double alpha = as_real(3);
        double beta = as_real(4);
        consume(5);
        return (infer*) new infer_golden(n_tasks, n_trials, alpha, beta);
    }

    if(strcmp(kind, "infer_majority") == 0) {
        if(!enough(3, "Not enough arguments to instantiate infer_majority"))
            return NULL;
        unsigned int n_tasks = as_count(1);
        double work_prior = as_real(2);
        consume(3);
        return (infer*) new infer_majority(n_tasks, work_prior);
    }

    if(strcmp(kind, "infer_acyclic") == 0) {
        if(!enough(4, "Not enough arguments to instantiate infer_acyclic"))
            return NULL;
        unsigned int n_tasks = as_count(1);
        double alpha = as_real(2);
        double beta = as_real(3);
        consume(4);
        return (infer*) new infer_acyclic(n_tasks, alpha, beta);
    }

    if(strcmp(kind, "infer_delayed") == 0) {
        if(!enough(4, "Not enough arguments to instantiate infer_delayed"))
            return NULL;
        unsigned int n_tasks = as_count(1);
        double alpha = as_real(2);
        double beta = as_real(3);
        consume(4);
        return (infer*) new infer_delayed(n_tasks, alpha, beta);
    }

    if(strcmp(kind, "infer_quick") == 0) {
        if(!enough(4, "Not enough arguments to instantiate infer_quick"))
            return NULL;
        unsigned int n_tasks = as_count(1);
        double alpha = as_real(2);
        double beta = as_real(3);
        consume(4);
        return (infer*) new infer_quick(n_tasks, alpha, beta);
    }

    if(strcmp(kind, "infer_variational") == 0) {
        if(!enough(6, "Not enough arguments to instantiate infer_variational"))
            return NULL;
        unsigned int n_tasks = as_count(1);
        double alpha = as_real(2);
        double beta = as_real(3);
        unsigned int n_iter_full = as_count(4);
        unsigned int n_iter_update = as_count(5);
        consume(6);
        return (infer*) new infer_variational(n_tasks, alpha, beta, n_iter_full, n_iter_update);
    }

    if(strcmp(kind, "infer_eigen") == 0) {
        if(!enough(4, "Not enough arguments to instantiate infer_eigen"))
            return NULL;
        unsigned int n_tasks = as_count(1);
        unsigned int n_iter_full = as_count(2);
        unsigned int n_iter_update = as_count(3);
        consume(4);
        return (infer*) new infer_eigen(n_tasks, n_iter_full, n_iter_update);
    }

    if(strcmp(kind, "infer_triangle") == 0) {
        if(!enough(3, "Not enough arguments to instantiate infer_triangle"))
            return NULL;
        unsigned int n_tasks = as_count(1);
        unsigned int bool_iter_full = as_count(2);
        consume(3);
        return (infer*) new infer_triangle(n_tasks, bool_iter_full);
    }

    if(strcmp(kind, "infer_particle") == 0) {
        if(!enough(6, "Not enough arguments to instantiate infer_particle"))
            return NULL;
        unsigned int n_tasks = as_count(1);
        double alpha = as_real(2);
        double beta = as_real(3);
        unsigned int n_particles = as_count(4);
        unsigned int n_gibbs = as_count(5);
        consume(6);
        return (infer*) new infer_particle(n_tasks, alpha, beta, n_particles, n_gibbs);
    }

    // none of the known type names matched
    cerr << "Unable to parse the infer type" << endl;
    return NULL;
}

void infer_parse_test()
{
    const char *argv_wmv[] = {"infer_weighted", "10", "Lorem", "ipsum"};
    const char *argv_gold[] = {"infer_golden", "10", "3", "1.5", "1.0", "DODATKI"};
    const char *argv_maj[] = {"infer_majority", "10", "0.8", "dolor"};
    const char *argv_acy[] = {"infer_acyclic", "10", "4.0", "3.0", "sit", "amet"};
    const char *argv_del[] = {"infer_delayed", "10", "4.0", "3.0", "Run, you fools!"};
    const char *argv_qck[] = {"infer_quick", "10", "4.0", "3.0", "Run, you fools!"};
    const char *argv_var[] = {"infer_variational", "10", "2.0", "1.0", "10", "5", "consectetur"};
    const char *argv_eig[] = {"infer_eigen", "10", "10", "5", "adipiscing", "elit"};
    const char *argv_tri[] = {"infer_triangle", "10", "0", "dummy"};
    const char *argv_par[] = {"infer_particle", "10", "2.0", "1.0", "70", "7"};

    char **argv_weighted = (char**) argv_wmv;
    char **argv_golden = (char**) argv_gold;
    char **argv_majority = (char**) argv_maj;
    char **argv_acyclic = (char**) argv_acy;
    char **argv_delayed = (char**) argv_del;
    char **argv_quick = (char**) argv_qck;
    char **argv_variational = (char**) argv_var;
    char **argv_eigen = (char**) argv_eig;
    char **argv_triangle = (char**) argv_tri;
    char **argv_particle = (char**) argv_par;

    int argc_weighted = 4;
    int argc_golden = 6;
    int argc_majority = 4;
    int argc_acyclic = 6;
    int argc_delayed = 5;
    int argc_quick = 5;
    int argc_variational = 7;
    int argc_eigen = 6;
    int argc_triangle = 4;
    int argc_particle = 6;

    infer_weighted* i_weighted = (infer_weighted*) infer_parse(&argc_weighted, &argv_weighted);
    infer_golden* i_golden = (infer_golden*) infer_parse(&argc_golden, &argv_golden);
    infer_majority* i_majority = (infer_majority*) infer_parse(&argc_majority, &argv_majority);
    infer_acyclic* i_acyclic = (infer_acyclic*) infer_parse(&argc_acyclic, &argv_acyclic);
    infer_delayed* i_delayed = (infer_delayed*) infer_parse(&argc_delayed, &argv_delayed);
    infer_quick* i_quick = (infer_quick*) infer_parse(&argc_quick, &argv_quick);
    infer_variational* i_variational = (infer_variational*) infer_parse(&argc_variational, &argv_variational);
    infer_eigen* i_eigen = (infer_eigen*) infer_parse(&argc_eigen, &argv_eigen);
    infer_triangle* i_triangle = (infer_triangle*) infer_parse(&argc_triangle, &argv_triangle);
    infer_particle* i_particle = (infer_particle*) infer_parse(&argc_particle, &argv_particle);

    assert(argc_weighted == 2);
    assert(argc_golden == 1);
    assert(argc_majority == 1);
    assert(argc_acyclic == 2);
    assert(argc_delayed == 1);
    assert(argc_quick == 1);
    assert(argc_variational == 1);
    assert(argc_eigen == 2);
    assert(argc_triangle == 1);
    assert(argc_particle == 0);

    assert(strcmp(argv_weighted[0], "Lorem") == 0);
    assert(strcmp(argv_weighted[1], "ipsum") == 0);
    assert(strcmp(argv_golden[0], "DODATKI") == 0);
    assert(strcmp(argv_majority[0], "dolor") == 0);
    assert(strcmp(argv_acyclic[0], "sit") == 0);
    assert(strcmp(argv_acyclic[1], "amet") == 0);
    assert(strcmp(argv_delayed[0], "Run, you fools!") == 0);
    assert(strcmp(argv_quick[0], "Run, you fools!") == 0);
    assert(strcmp(argv_variational[0], "consectetur") == 0);
    assert(strcmp(argv_eigen[0], "adipiscing") == 0);
    assert(strcmp(argv_eigen[1], "elit") == 0);
    assert(strcmp(argv_triangle[0], "dummy") == 0);

    assert(i_weighted != NULL);
    assert(i_golden != NULL);
    assert(i_majority != NULL);
    assert(i_acyclic != NULL);
    assert(i_delayed != NULL);
    assert(i_quick != NULL);
    assert(i_variational != NULL);
    assert(i_eigen != NULL);
    assert(i_triangle != NULL);
    assert(i_particle != NULL);

    double diff_maj = i_majority->work_weight - math_logit_safe(0.8);
    double epsilon = numeric_limits<double>::epsilon();
    assert(diff_maj <= epsilon);
    assert(diff_maj >= -epsilon);
    assert(i_golden->n_trials == 3);
    assert(i_golden->alpha == 1.5);
    assert(i_golden->beta == 1.0);
    assert(i_acyclic->alpha == 4.0);
    assert(i_acyclic->beta == 3.0);
    assert(i_delayed->alpha == 4.0);
    assert(i_delayed->beta == 3.0);
    assert(i_quick->alpha == 4.0);
    assert(i_quick->beta == 3.0);
    assert(i_variational->alpha == 2.0);
    assert(i_variational->beta == 1.0);
    assert(i_variational->n_iter_full == 10);
    assert(i_variational->n_iter_update == 5);
    assert(i_eigen->n_iter_full == 10);
    assert(i_eigen->n_iter_update == 5);
    assert(i_triangle->bool_iter_full == 0);
    assert(i_particle->alpha == 2.0);
    assert(i_particle->beta == 1.0);
    assert(i_particle->n_particles == 70);
    assert(i_particle->n_gibbs == 7);

    delete i_weighted;
    delete i_golden;
    delete i_majority;
    delete i_acyclic;
    delete i_delayed;
    delete i_quick;
    delete i_variational;
    delete i_eigen;
    delete i_triangle;
    delete i_particle;
}
