% This file is the runner for the active learning experiments, averaging
% over all 12 annotators (4 per dataset, 3 datasets) and using consensus as
% the ground truth.
clear;clc;
% Load the dataset summary into the workspace. The sections below read
% choices_all, choices_gt_all and metrics_all, none of which is defined in
% this script, so they are presumably provided by this file -- TODO confirm.
load('hemisphere_dataset_summary.mat')
%% define the annotator indices
% Row d holds the annotator indices used for dataset d; each entry is later
% used as the first index into choices_all / choices_gt_all.
annotator_lst = [1, 2, 3, 4; % dataset 1
                 1, 3, 4, 5; % dataset 2
                 1, 2, 3, 6];% dataset 3
%% define the hyperparameters
% ratio and config are passed unchanged to play_active_learning_new.
ratio = 0.1;

config.DO_ZSCORING = true;
config.n = 1;
config.repeat = 1;

% Each method is described by a struct with at least a "name" field; the
% cell array `methods` fixes the order in which methods are run and stored.
method_rand = struct("name", "random");
method_cal  = struct("name", "cal");
method_dal  = struct("name", "dal");
methods     = {method_rand, method_cal, method_dal};

% One "dcal" method per weight, appended in the order listed here.
weights_lst = [0.3, 0.5, 0.7, 0.9];
for weight = weights_lst
    method_dcal = struct("name", "dcal", "weight", weight);
    methods{end+1} = method_dcal;
end

% [reward_func_lst, reward_name_lst] = get_reward_funcs([1:3]); % return all predefined reward function 5, 7, 10
% gamma_lst = [0.3, 0.9];
% for i=1:length(reward_func_lst)
%     for g=1:length(gamma_lst)
%         gamma       = gamma_lst(g);
%         reward_func = reward_func_lst{i};
%         reward_name = reward_name_lst(i);
%         method_mab  = struct("name", "mab-exp3", "gamma", gamma, ...
%                       "reward_func", reward_func, "reward_name", reward_name);
%         methods{end+1} = method_mab;
%     end
% end

% Drop the temporaries used to build `methods`. Names that only exist when
% the commented-out block above is enabled are listed too; clearing a
% variable that does not exist is harmless.
clear g i gamma method_rand method_cal method_dal method_dcal method_mab reward_func reward_name weight;
num_methods = length(methods);
% One result slot per (dataset, annotator column, method); the first two
% dimensions follow the shape of annotator_lst.
labels_saved_all = cell(size(annotator_lst, 1), size(annotator_lst, 2), num_methods);
%% ActSort 
% Run every method in `methods` for every (dataset, annotator) pair listed
% in annotator_lst and collect the recorded labels in labels_saved_all.
%
% labels_saved is a global that is reset to an empty record before each run
% and copied into labels_saved_all afterwards; presumably
% play_active_learning_new fills it through the same global -- TODO confirm.
global labels_saved;
num_datasets   = size(annotator_lst, 1);
num_annotators = size(annotator_lst, 2);
for d = 1:num_datasets
    metrics = metrics_all{d};   % depends on the dataset only
    for ann = 1:num_annotators
        ann_idx    = annotator_lst(d, ann);
        choices    = choices_all{ann_idx,d};
        choices_gt = choices_gt_all{ann_idx,d};
        for k = 1:num_methods
            % Start each run from an empty record so results of the
            % previous run cannot leak into this one.
            labels_saved = struct("labels_ex", [], "labels_ml_prob", [], "q_idxs", []);
            method  = methods{k};
            dspname = get_legend_name(method);
            fprintf("%s: Working on dataset %i, annotator %i using method %s\n",datetime("now"), d, ann, dspname);
            % The returned values are not read again in this section; only
            % the global record is kept.
            [valid, eval_metrics, dataset] = play_active_learning_new(metrics,choices,choices_gt,ratio,method,config);
            labels_saved_all{d, ann, k} = labels_saved;
        end
    end
end
%% continue
% Resume from a previous run: reuse the results stored in
% "exlabels_12datasets_6methods.mat" and run the active learning only for
% the methods that still have to be computed.
%
% NOTE: labels_saved_all is rebuilt from scratch here, so anything computed
% by the "ActSort" section above is discarded. This section is an
% alternative to that one, not a follow-up.
% NOTE(review): cached results are matched to `methods` by position only.
% This assumes the cached file was produced with the same method order as
% the current `methods` list -- confirm before trusting the merged results.
cached = load("exlabels_12datasets_6methods.mat", "labels_saved_all");
labels_saved_all_old = cached.labels_saved_all;
clear cached;

% Indices into `methods` that are always recomputed. Methods that lie
% beyond the third dimension of the cached array are recomputed as well,
% instead of failing with an out-of-range index.
rerun_method_idxs  = 7;
num_cached_methods = size(labels_saved_all_old, 3);

num_datasets   = size(annotator_lst, 1);
num_annotators = size(annotator_lst, 2);
labels_saved_all = cell(num_datasets, num_annotators, num_methods);

% Global record that is reset before each run and copied afterwards;
% presumably filled by play_active_learning_new -- TODO confirm.
global labels_saved;
for d = 1:num_datasets
    metrics = metrics_all{d};   % depends on the dataset only
    for ann = 1:num_annotators
        ann_idx    = annotator_lst(d, ann);
        choices    = choices_all{ann_idx,d};
        choices_gt = choices_gt_all{ann_idx,d};
        for k = 1:num_methods
            method  = methods{k};
            dspname = get_legend_name(method);
            if ismember(k, rerun_method_idxs) || k > num_cached_methods
                % Start from an empty record so results of the previous
                % run cannot leak into this one.
                labels_saved = struct("labels_ex", [], "labels_ml_prob", [], "q_idxs", []);
                fprintf("%s: Working on dataset %i, annotator %i using method %s\n",datetime("now"), d, ann, dspname);
                [valid, eval_metrics, dataset] = play_active_learning_new(metrics,choices,choices_gt,ratio,method,config);
                labels_saved_all{d, ann, k} = labels_saved;
            else
                fprintf("%s: dataset %i, annotator %i using method %s Exist! Loading ===> \n",datetime("now"), d, ann, dspname);
                labels_saved_all{d, ann, k} = labels_saved_all_old{d, ann, k};
            end
        end
    end
end