#include <algorithm>
#include <cassert>
#include <cstring>
#include <iostream>
#include <vector>
#include "crowd.h"
#include "infer.h"
#include "policy.h"

using namespace std;

/// shared functions ///

// Returns the index of the smallest element of `data` among the positions
// that are not flagged in `ignore`.
//
// Ties are resolved in favour of the lowest index. When every position is
// flagged (or both vectors are empty) there is no eligible element and the
// function returns data.size(); callers must not use that value as an index.
//
//   data    values compared with operator<
//   ignore  per-position flags, true meaning "skip this position"; must have
//           the same length as `data`
template<typename T> static unsigned int shared_arg_min(const vector<T>& data, const vector<bool>& ignore)
{
    assert(data.size() == ignore.size());

    const unsigned int n = data.size();
    unsigned int m = n; // n means "nothing eligible found yet"

    for(unsigned int i = 0; i < n; ++i) {
        // flagged positions never compete, wherever they sit in the vector
        if(ignore[i])
            continue;
        // strict comparison keeps the earliest index among equal minima
        if(m == n || data[i] < data[m])
            m = i;
    }
    return m;
}

/// uniform allocation ///

// Starts the step counter at zero, so the first choose_task() call must be
// made with t == 0.
policy_uniform::policy_uniform(): t_next(0) {}

// Assigns a task to the data point created at time step `t`.
//
// The worker recorded in c->data_seq[t] is given the task with the fewest
// labels so far, excluding every task found in that worker's history. The
// chosen index is written to c->data_seq[t].task_id and the label count of
// that task is incremented.
//
//   c    crowd; data_seq[t].work_id is read, data_seq[t].task_id is written
//   inf  inference state; task_odds gives the number of tasks and work_hist
//        gives the time steps already answered by each worker
//   t    current time step; must equal the number of previous calls
void policy_uniform::choose_task(crowd* c, infer* inf, unsigned int t)
{
    // avoid skipping any steps
    assert(t == t_next);
    ++t_next;

    // add new tasks on the fly
    if(inf->task_odds.size() > task_count.size()) {
        unsigned int n_tasks = inf->task_odds.size();
        task_ignore.resize(n_tasks, false);
        task_count.resize(n_tasks, 0);
    }

    // unpack the worker's id
    assert(t < c->data_seq.size());
    unsigned int j = c->data_seq[t].work_id;

    // skip the tasks this worker has already seen
    fill(task_ignore.begin(), task_ignore.end(), false);
    if(j < inf->work_hist.size())
        for(unsigned int k: inf->work_hist[j].time_steps) {
            assert(k < c->data_seq.size());
            unsigned int seen = c->data_seq[k].task_id;
            // a recorded task id beyond the known tasks would write out of bounds
            assert(seen < task_ignore.size());
            task_ignore[seen] = true;
        }

    // pick the task with the least number of labels
    unsigned int i = shared_arg_min(task_count, task_ignore);
    // the worker must have at least one unseen task left to receive
    assert(i < task_count.size());
    c->data_seq[t].task_id = i;
    ++task_count[i];
}

// Self-test for policy_uniform: simulates every time step of a small crowd,
// then checks the policy's label counts against the recorded history and
// checks that no task falls far behind the average.
void policy_uniform_test()
{
    policy_uniform poly;

    // simple crowd and inference
    double work_prior = 0.7;
    double p = 0.8;
    int quota = 5;
    unsigned int n_tasks = 10;
    unsigned int n_workers = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd_budget c(d_work_acc, d_work_quota, n_workers);
    infer_majority maj(n_tasks, work_prior);

    // one step per (worker, quota slot): create, allocate, then infer
    mt19937 ran_gen = well_seeded_mt19937(666);
    unsigned int n_steps = n_workers * quota;
    for(unsigned int step = 0; step < n_steps; ++step) {
        c.create_data_point(ran_gen, step);
        poly.choose_task(&c, &maj, step);
        maj.infer_update(ran_gen, &c, step);
    }

    // rebuild the per-task label counts from the workers' histories
    vector<unsigned int> labels_per_task(n_tasks, 0);
    for(unsigned int w = 0; w < n_workers; ++w) {
        for(unsigned int step: maj.work_hist[w].time_steps) {
            const auto task = c.data_seq[step].task_id;
            ++labels_per_task[task];
        }
    }

    // the policy's own bookkeeping must agree with the rebuilt counts
    for(unsigned int task = 0; task < n_tasks; ++task)
        assert(poly.task_count[task] == labels_per_task[task]);

    // stochastic test: roughly the same number of labels per task
    unsigned int r = n_steps / n_tasks;
    for(unsigned int task = 0; task < n_tasks; ++task)
        assert(poly.task_count[task] >= r - 1);
}

/// uncertainty sampling ///

// Starts the step counter at zero, so the first choose_task() call must be
// made with t == 0.
policy_uncertainty::policy_uncertainty(): t_next(0) {}

// Assigns a task to the data point created at time step `t`.
//
// The worker recorded in c->data_seq[t] is given the task whose entry in
// inf->task_odds has the smallest absolute value, excluding every task found
// in that worker's history. The chosen index is written to
// c->data_seq[t].task_id, and task_abs_odds is rebuilt on every call.
//
//   c    crowd; data_seq[t].work_id is read, data_seq[t].task_id is written
//   inf  inference state; task_odds holds one value per task and work_hist
//        gives the time steps already answered by each worker
//   t    current time step; must equal the number of previous calls
void policy_uncertainty::choose_task(crowd* c, infer* inf, unsigned int t)
{
    // avoid skipping any steps
    assert(t == t_next);
    ++t_next;

    // add new tasks on the fly
    if(inf->task_odds.size() > task_ignore.size())
        task_ignore.resize(inf->task_odds.size(), false);

    // unpack the worker's id
    assert(t < c->data_seq.size());
    unsigned int j = c->data_seq[t].work_id;

    // skip the tasks this worker has already seen
    fill(task_ignore.begin(), task_ignore.end(), false);
    if(j < inf->work_hist.size())
        for(unsigned int k: inf->work_hist[j].time_steps) {
            assert(k < c->data_seq.size());
            unsigned int seen = c->data_seq[k].task_id;
            // a recorded task id beyond the known tasks would write out of bounds
            assert(seen < task_ignore.size());
            task_ignore[seen] = true;
        }

    // compute the absolute log odds
    task_abs_odds.resize(0);
    for(double odds: inf->task_odds) {
        double abs_odds = (odds >= 0.0)? odds: -odds;
        task_abs_odds.push_back(abs_odds);
    }

    // pick the task with the smallest absolute log odds
    unsigned int i = shared_arg_min(task_abs_odds, task_ignore);
    // the worker must have at least one unseen task left to receive
    assert(i < task_abs_odds.size());
    c->data_seq[t].task_id = i;
}

// Self-test for policy_uncertainty: simulates every time step of a small
// crowd, then checks the cached absolute odds against the inference state and
// checks that the uncertainty is spread roughly evenly over the tasks.
void policy_uncertainty_test()
{
    policy_uncertainty poly;

    // simple crowd and inference
    double work_prior = 0.7;
    double p = 0.8;
    int quota = 5;
    unsigned int n_tasks = 10;
    unsigned int n_workers = 10;
    distro* d_work_acc = (distro*) new distro_dirac(p);
    distro* d_work_quota = (distro*) new distro_kronecker(quota);
    crowd_budget c(d_work_acc, d_work_quota, n_workers);
    infer_majority maj(n_tasks, work_prior);

    // one step per (worker, quota slot): create, allocate, then infer
    mt19937 ran_gen = well_seeded_mt19937(777);
    unsigned int n_steps = n_workers * quota;
    for(unsigned int step = 0; step < n_steps; ++step) {
        c.create_data_point(ran_gen, step);
        poly.choose_task(&c, &maj, step);
        maj.infer_update(ran_gen, &c, step);
    }

    auto magnitude = [](double x) { return (x >= 0.0)? x: -x; };

    // check the absolute log odds; the task chosen at the last step is
    // skipped because its odds were updated after the policy cached them
    const auto last_task = c.data_seq[n_steps - 1].task_id;
    for(unsigned int task = 0; task < n_tasks; ++task) {
        if(task == last_task)
            continue;
        double expected = magnitude(maj.task_odds[task]);
        assert(poly.task_abs_odds[task] == expected);
    }

    // stochastic test: roughly the same absolute uncertainty on all tasks
    double mean_abs_odds = 0.0;
    for(unsigned int task = 0; task < n_tasks; ++task)
        mean_abs_odds += poly.task_abs_odds[task];
    mean_abs_odds /= (double) n_tasks;

    // mean absolute deviation of the cached odds from their mean
    double mean_abs_dev = 0.0;
    for(unsigned int task = 0; task < n_tasks; ++task)
        mean_abs_dev += magnitude(mean_abs_odds - poly.task_abs_odds[task]);
    mean_abs_dev /= (double) n_tasks;

    assert(mean_abs_dev <= maj.work_weight);
}

/// parser ///

// Builds the policy named by the first remaining command-line argument.
//
// On success the recognised argument is consumed (*argc is decremented and
// *argv advanced by one) and a policy allocated with new is returned. On
// failure a message is written to cerr, the arguments are left untouched and
// NULL is returned.
//
//   argc  pointer to the number of remaining arguments
//   argv  pointer to the array of remaining arguments
policy* policy_parse(int *argc, char **argv[])
{
    if(*argc < 1) {
        cerr << "Not enough arguments to parse the policy type" << endl;
        return NULL;
    }

    // identify the requested policy before touching the argument list
    const char* name = (*argv)[0];
    const bool want_uniform = (strcmp(name, "policy_uniform") == 0);
    const bool want_uncertainty = !want_uniform && (strcmp(name, "policy_uncertainty") == 0);

    if(!want_uniform && !want_uncertainty) {
        cerr << "Unable to parse the policy type" << endl;
        return NULL;
    }

    // consume the argument that named the policy
    --(*argc);
    ++(*argv);

    if(want_uniform)
        return (policy*) new policy_uniform();
    return (policy*) new policy_uncertainty();
}

void policy_parse_test()
{
    const char *argv_uni[] = {"policy_uniform", "no", "big"};
    const char *argv_crt[] = {"policy_uncertainty", "deal"};
    char **argv_uniform = (char**) argv_uni;
    char **argv_uncertainty = (char**) argv_crt;

    int argc_uniform = 3;
    int argc_uncertainty = 2;

    policy_uniform* poly_uni = (policy_uniform*) policy_parse(&argc_uniform, &argv_uniform);
    policy_uncertainty* poly_crt = (policy_uncertainty*) policy_parse(&argc_uncertainty, &argv_uncertainty);

    assert(argc_uniform == 2);
    assert(argc_uncertainty == 1);

    assert(strcmp(argv_uniform[0], "no") == 0);
    assert(strcmp(argv_uniform[1], "big") == 0);
    assert(strcmp(argv_uncertainty[0], "deal") == 0);

    assert(poly_uni != NULL);
    assert(poly_crt != NULL);

    assert(poly_uni->task_count.size() == 0);
    assert(poly_crt->task_abs_odds.size() == 0);

    delete poly_uni;
    delete poly_crt;
}
