function Key =  MinimaxEntropy_Vector_crowd_model(Model, varargin)
% Minimax entropy algorithm in "Learning from the Wisdom of Crowds by Minimax Entropy. Zhou, Platt, Basu, Mao. NIPS 2012"
%
% varargin: name-value pair arguments in matlab style
%       -- maxIter[default=100]: maximum number of iteration
%       -- TOL[default=1e-3]: error threshold for convergence
%       -- lambda_worker[default=0]: regularization coefficient on the workers
%       -- lambda_task[default=0]: regularization coefficience on the tasks (items)
%       -- verbose[default=1]: print nothing (verbose=0); print final (verbose=1); print iterative (verbose=2)
% 
% Output: 
%       -- Key.ans_labels: predicted (deterministic) labels for the items
%       -- Key.error_rate: the prediction error rate (if provided with true_labels)
%       -- Key.soft_labels: the predicted soft labels (i.e., the posterior distribution of the item labels)
%       -- Key.parameter_worker: the worker-wise parameters   
%       -- Key.parameter_task: the task-wise parameters
%
% This function used the LBFGS code 'minFunc.m' by Mark Schmidt (http://www.di.ens.fr/~mschmidt/Software/minFunc.html). 
%
% Qiang Liu @Jan 2013

%% Preparing
[verbose, maxIter, TOL, prior_workers,  prior_tasks, inner_maxIter, damping, issparse] = process_varargin(varargin, ...
    'verbose', 1, 'maxIter', 100, 'TOL', 1e-3, {'prior_worker', 'ell'}, [], ...
    'prior_task', [],   'inner_maxIter', 1, 'damping', 0, 'issparse', false);
[offdiag] = process_varargin(varargin,  'offdiag', false);
[lambda_alpha, lambda_w] = process_varargin(varargin,   {'lambda_worker'}, 0,                 {'lambda_task'}, 0);
[update_alpha, update_w] = process_varargin(varargin,   {'update_parameter_worker'}, true,       {'update_parameter_task'}, true);

if issparse, L = sparse(Model.L); else L = full(Model.L); end;
Ntask = Model.Ntask; Nwork = Model.Nwork; NeibTask = Model.NeibTask; NeibWork = Model.NeibWork;
LabelDomain =Model.LabelDomain; Ndom = length(LabelDomain);

% set default prior parameters
if isempty(prior_tasks), prior_tasks = ones(Ndom, Ntask)/Ndom; end
if isempty(prior_workers), prior_workers = ones(Ndom, Ndom); end

% Set the optimization parameters for LBFGS
[maxIter_Mstep, optTOL_Mstep, progTOL_Mstep] = process_varargin(varargin, 'maxIter_Mstep',25, 'optTOL_Mstep', 1e-3, 'progTOL_Mstep', 1e-3);
options = []; options.bbType = 1;options.Method = 'lbfgs'; options.maxIter = maxIter_Mstep; options.optTol = optTOL_Mstep; options.progTOL = progTOL_Mstep;
if verbose >3, options.Display = 'iter'; elseif verbose >2, options.Display = 'final'; else options.Display = 'off'; end

% initializing mu (the posterior prob) using majority voting counting
mu = zeros(2, Ntask);
for i = 1:Ntask
    neib = NeibTask{i}; labs = full(L(i, neib));
    for k = 1:length(LabelDomain)
        mu(k, i) = nnz(labs==LabelDomain(k)) / length(labs);    
    end   
end

wX = zeros(Ndom, Ntask); 
alpha = zeros(Ndom, Ndom, Nwork);        
logp_task = zeros(Ndom, Ntask);

cal_truth = false;
if isfield(Model, 'true_labels'), 
    cal_truth = true;
    true_labels = Model.true_labels(:)';
    dx_with_ans = find(isfinite(true_labels));
    prob_err_Vec = zeros(1,maxIter); 
end

%% main iteration
err = NaN;
tic
for iter = 1:maxIter    
   %M step 
   for inner_iter = 1:inner_maxIter
       % M step: update alpha (worker-wise confusion matrix)
       if update_alpha
        for j = 1:Nwork
           neib = NeibWork{j}; labs = full(L(neib, j))';
           mu_neibj = mu(:, neib); wX_neibj = wX(:, neib);
           obj_handle = @(unknown)obj_update_alpha_minimaxent_per_worker(unknown, wX_neibj, mu_neibj, labs, LabelDomain, lambda_alpha, offdiag);       
           alpha_tmp = minFunc(obj_handle, reshape(alpha(:,:,j), [],1), options);           
           alpha(:,:,j) = reshape(alpha_tmp, Ndom, Ndom);
        end 
       end
       
        %update wx (task-wise confusion)
        if update_w
        for i = 1:Ntask
            neib = NeibTask{i}; labs = full(L(i, neib));
            mui = mu(:, i);
            alpha_neibi = alpha(:,:, neib);                
            obj_handle = @(wxi)obj_update_wx_per_task(wxi, mui, alpha_neibi, labs, LabelDomain, lambda_w);        
            wX(:, i) = minFunc(obj_handle, wX(:, i), options);           
        end
        end
   end
   
    %E step: update the posterior probability (mu)
    old_mu = mu;
    for i = 1:Ntask
        neib = NeibTask{i}; labs = (L(i, neib));
        logp_task(:,i) = log(prior_tasks(:, i));
        for jdx = 1:length(neib)
            j = neib(jdx);
            tmp = alpha(:,:, j) + wX(:, i*ones(1,Ndom))';
            logz = logsumexp2_stable(tmp);
            logp_task(:, i) = logp_task(:, i) + tmp(:, labs(jdx)) - logz;
        end      
        tmp = logp_task(:,i) + damping * log(old_mu(:,i)+eps);
        mu(:,i) = exp(tmp - max(tmp));
        mu(:,i) = mu(:,i)/sum(mu(:,i));
    end
    
    % check the error    
    err = max(abs(old_mu(:)-mu(:)));
    if verbose >=2, printstr = sprintf('%s: iter=%d, congerr=%f, ', mfilename, iter, err); end
       
    % evaluate ground truth
    if cal_truth
        [~, mxdx] = max(mu.*(1+ rand(size(mu))*eps)); % add some random noise to break ties. 
        ans_labels = LabelDomain(mxdx);
        prob_err_Vec(iter) = mean(ans_labels(dx_with_ans) ~= true_labels(dx_with_ans));
        if verbose >=2, printstr = horzcat(printstr, sprintf('err_rate=%f, ', prob_err_Vec(iter))); end
    end    
    if verbose >=2, printstr = horzcat(printstr, sprintf('time=%f', toc));end
    if verbose >=2, fprintf('%s\n', printstr); end
    
    if err < TOL, break; end         
          
end

%% check the error rate

Key.method = mfilename;
[~, mxdx] = max(mu.*(1+ rand(size(mu))*eps)); %decode the labels of tasks: add random noise to break ties. 
Key.ans_labels = LabelDomain(mxdx);
%if cal_truth, Key.error_rate = double(mean(Key.ans_labels(dx_with_ans) ~= true_labels(dx_with_ans))); end
if isfield(Model, 'true_labels') && ~isempty(Model.true_labels)     
    [Key.error_rate, Key.MoreInfo.error_L1, Key.MoreInfo.error_L2]=cal_error_using_soft_label(mu, Model.true_labels);    
end
Key.soft_labels = mu;
Key.parameter_worker = alpha;
Key.parameter_task = wX;
Key.converge_error = err;


% Print out final information 
if verbose >= 1 
    printstr = sprintf('%s:\n\t-- break at %dth iteration, congerr=%f\n', mfilename, iter, err); 
    if cal_truth, printstr = horzcat(printstr, sprintf('\t-- error rate = %f', Key.error_rate)); end
    fprintf('%s\n',printstr);
end

return;
end

function [obj, Dobj_Dalpha] = obj_update_alpha_minimaxent_per_worker(alpha, wX_neibj, mu_neibj, labs, LabelDomain, lambda, offdiag)

if nargin < 6, lambda = 0; end
if nargin < 7, offdiag=false; end

obj = 0; Ndom = length(LabelDomain); alpha = reshape(alpha, Ndom, Ndom); Dobj_Dalpha = zeros(size(alpha));

for i = 1:size(labs,2)    
    wxi = wX_neibj(:, i); 
    MM = wxi(:,ones(1,Ndom))' + alpha; maxtmp = max(MM, [], 2);  
    expMM = exp( MM - maxtmp(:, ones(1,Ndom)) ); sumexpMM = sum(expMM, 2);
    logz = log(sumexpMM) + maxtmp;            
    obj = obj + (MM(:,labs(i)) - logz)' * mu_neibj(:,i);   
     
    probmatrix = expMM ./ sumexpMM(:, ones(1,Ndom));
    %probmatrix = exp(MM - logz(:, ones(1,Ndom))); %Dobj_Dalpha = Dobj_Dalpha +  mu_neibj(:, i) * (LabelDomain==labs(i)) - diag(mu_neibj(:, i)) * probmatrix';      
    Dobj_Dalpha(:, labs(i)) = Dobj_Dalpha(:, labs(i)) + mu_neibj(:,i);
    Dobj_Dalpha = Dobj_Dalpha - mu_neibj(:,i*ones(1,Ndom)) .* probmatrix;
end

if offdiag
    tmp = diag(diag(alpha));
    obj = -obj + .5 * lambda * (sum(alpha(:).^2) - sum(diag(alpha).^2));
    Dobj_Dalpha = -Dobj_Dalpha(:)  + lambda * (alpha(:) - tmp(:));    
else
    obj = -obj + .5 * lambda * sum(alpha(:).^2);    
    Dobj_Dalpha = -Dobj_Dalpha(:)  + lambda * alpha(:);    
end

return;
end


function [obj, Dobj]= obj_update_wx_per_task(wxi, mui, alpha_neibi, labs, LabelDomain, lambda)
if nargin < 5, lambda = 0; end
Ndom = length(LabelDomain);

obj = 0; Dobj_Dw = 0*wxi;
for j = 1:length(labs)    
    tmp = wxi(:, ones(1,Ndom)) + alpha_neibi(:,  :, j)'; maxtmp = max(tmp, [], 1); % w: Nfeature * Ndom        
    logz = log(sum(exp(tmp - maxtmp(ones(1,Ndom), :)), 1)) + maxtmp;                
    obj = obj + (tmp(labs(j), :) - logz) * mui;               
        
    %%%        
    probmatrix = exp(tmp - logz(ones(1, Ndom), :));        
    Dobj_Dw(:) = Dobj_Dw(:) + (double(LabelDomain==labs(j)) - (probmatrix * mui)')';                  
end

obj = -obj + 0.5 * lambda * sum(wxi(:).^2);
Dobj = -Dobj_Dw(:) + lambda * wxi(:);

return;
end

