% Load drug_data.csv and draw K well-separated arm means from it.
% csvread's row offset of 1 skips the first row (presumably a header). Zero
% entries are dropped, the values are divided by 100 (assumed to be
% percentages - TODO confirm) and mapped through log(1 - p). Values with
% p >= 1 are dropped because log(1 - p) would be -Inf or complex there.
data = csvread('drug_data.csv',1,0);
data = data(data~=0);        % also flattens the matrix into a column vector
data = data/100;
data = log(1 - data(data < 1));
K = 4;                       % number of arms
min_sep = 0.12;              % minimum distance between any two chosen means

% Rejection-sample K distinct values, each at least min_sep away from every
% value already chosen. The set of chosen means only grows, so a rejected
% candidate can never become valid later and is removed from the pool too;
% this keeps the loop finite and lets an infeasible request fail loudly.
pool = unique(data);
mu = zeros(1,K);
s = 1;
while s <= K
    if isempty(pool)
        error('Cannot draw %d means at least %g apart from the data.', K, min_sep);
    end
    % Index the pool directly: randsample(pool,1) would treat a scalar pool
    % as the population 1:pool.
    idx = randi(numel(pool));
    samp = pool(idx);
    pool(idx) = [];
    if all(abs(mu(1:s-1) - samp) >= min_sep)   % all([]) is true for the first draw
        mu(s) = samp;
        s = s + 1;
    end
end
mu = sort(mu,'descend');   % arm 1 is the best arm; success is later measured as returning arm 1
% Experiment configuration.
max_iter = 100000;   % columns of the per-arm sample buffer allocated in each run
delta = [0.17*10^-1, 0.125*10^-1, 10^-2, 0.5*10^-2, 0.25*10^-2];   % confidence levels to sweep
e = 1*10^-1;         % weight of the adversarial term in every sample; also the trimming fraction
avg = 1000;          % independent runs per (algorithm, delta) pair

% Accumulators, one entry (T_avg_*, success_*) or one row of K per-arm counts
% (N_avg_*) per delta: total samples, per-arm samples and successful runs,
% which are averaged over the avg runs afterwards.
n_delta = length(delta);

% gap-based
T_avg_gap   = zeros(1,n_delta);
N_avg_gap   = zeros(n_delta,K);
success_gap = zeros(1,n_delta);

% SE-trimmed mean
T_avg_SE   = zeros(1,n_delta);
N_avg_SE   = zeros(n_delta,K);
success_SE = zeros(1,n_delta);

% gap-based stopping rule with uniformly random sampling
T_avg_U   = zeros(1,n_delta);
N_avg_U   = zeros(n_delta,K);
success_U = zeros(1,n_delta);

% median-based (CBAI)
T_avg_med   = zeros(1,n_delta);
N_avg_med   = zeros(n_delta,K);
success_med = zeros(1,n_delta);

% Monte-Carlo comparison of four best-arm identification strategies, repeated
% for every confidence level delta(delta_iter):
%   gap-based     - forced exploration of the least-sampled arm, then sample the
%                   more uncertain of the empirical best arm and its most
%                   ambiguous rival; e-trimmed-mean estimates; stops when no
%                   rival's upper bound exceeds the best arm's lower bound.
%   SE-trimmed    - successive elimination with e-trimmed-mean estimates.
%   uniform (U)   - the gap-based stopping rule, but every sample goes to an
%                   arm chosen uniformly at random.
%   median (CBAI) - successive elimination with median estimates.
% Every sample of arm a is N(mu(a),1) blended with an adversarial value g(a):
%   x = (1-e)*x + e*g,  g(1) ~ U(0,1),  g(2:K) = U(0,1) + g(1) - 0.3.
% NOTE(review): the blend is applied to EVERY sample (a shift) rather than
% replacing a sample with probability e (Huber contamination) - confirm.
% mu is sorted in descending order, so a run succeeds when it returns arm 1.
% Per delta: T_avg_* = mean number of samples drawn, N_avg_*(delta_iter,:) =
% mean samples per arm, success_* = fraction of successful runs. After the
% loop alloc_* holds each arm's share of the samples (each row sums to 1).

report_every = 100;   % print progress every report_every runs

for delta_iter = 1:length(delta)
    d = delta(delta_iter);
    n_min = (1/e^2)*log(1/d);   % minimum per-arm samples used by the gap-based and SE rules

    %% gap-based
    for iter = 1:avg
        if mod(iter, report_every) == 0
            fprintf('delta = %g, gap-based: run %d/%d\n', d, iter, avg);
        end
        t = 1;
        mu_hat = zeros(1,K);
        Na = zeros(1,K);
        obs = zeros(K,max_iter);   % obs(a,1:Na(a)) are arm a's samples
        conf = sqrt(2*log(4*t/d))*ones(1,K);
        max_overlap = inf;
        % max(Na) < max_iter bounds every run by the preallocated buffer.
        while max_overlap > 0 && max(Na) < max_iter
            x = randn(1,K) + mu;
            g1 = rand();
            g = [g1, rand(1,K-1) + g1 - 0.3];
            x = (1-e)*x + e*g;
            forced = min(Na) <= max(sqrt(t), n_min);
            if forced
                % forced exploration: sample the least-sampled arm
                [~,arm] = min(Na);
            else
                [mu_best,best_arm] = max(mu_hat);
                % how far each rival's upper bound exceeds the best arm's lower bound
                overlap = (mu_hat + conf) - (mu_best - conf(best_arm));
                overlap(best_arm) = -inf;
                [max_overlap,amb_arm] = max(overlap);
                if conf(best_arm) >= conf(amb_arm)
                    arm = best_arm;
                else
                    arm = amb_arm;
                end
            end
            Na(arm) = Na(arm) + 1;
            obs(arm,Na(arm)) = x(arm);
            srt = sort(obs(arm,1:Na(arm)));
            cut = floor(Na(arm)*e);               % trim floor(n*e) samples from each end
            mu_hat(arm) = mean(srt(cut+1:end-cut));
            t = t + 1;
            if ~forced
                conf(arm) = sqrt((2/Na(arm))*log(log(t)/d));
            end
        end
        [~,p_arm] = max(mu_hat);
        success_gap(delta_iter) = success_gap(delta_iter) + (p_arm == 1);
        % t starts at 1, so the number of samples drawn is t-1 = sum(Na).
        T_avg_gap(delta_iter) = T_avg_gap(delta_iter) + sum(Na);
        N_avg_gap(delta_iter,:) = N_avg_gap(delta_iter,:) + Na;
    end
    % Average only this delta's entries; earlier rows are already averaged.
    T_avg_gap(delta_iter) = T_avg_gap(delta_iter)/avg;
    N_avg_gap(delta_iter,:) = N_avg_gap(delta_iter,:)/avg;
    success_gap(delta_iter) = success_gap(delta_iter)/avg;

    %% SE-trimmed mean
    for iter = 1:avg
        if mod(iter, report_every) == 0
            fprintf('delta = %g, SE-trimmed mean: run %d/%d\n', d, iter, avg);
        end
        t = 1;
        S = 1:K;   % surviving arms, kept as a sorted row vector
        mu_hat = zeros(1,K);
        Na = zeros(1,K);
        alpha = sqrt(2*log(4*pi^2/(6*d)));
        obs = zeros(K,max_iter);
        while numel(S) > 1 && max(Na) < max_iter
            % one round: sample every surviving arm once
            for a = S
                x = randn() + mu(a);
                g1 = rand();
                g = [g1, rand(1,K-1) + g1 - 0.3];
                x = (1-e)*x + e*g(a);
                Na(a) = Na(a) + 1;
                obs(a,Na(a)) = x;
                srt = sort(obs(a,1:Na(a)));
                cut = floor(Na(a)*e);
                mu_hat(a) = mean(srt(cut+1:end-cut));
                t = t + 1;
            end
            % Eliminate survivors (looked up by arm id) that trail the best
            % survivor by at least 2*alpha and already have n_min samples.
            elim = mu_hat(S) <= max(mu_hat(S)) - 2*alpha & Na(S) >= n_min;
            S = S(~elim);
            % NOTE(review): the radius is driven by the total sample count t,
            % not by the per-arm count (number of rounds) - confirm intended.
            alpha = sqrt((2/t)*log(4*t^2*pi^2/(6*d)));
        end
        success_SE(delta_iter) = success_SE(delta_iter) + (S(1) == 1);
        T_avg_SE(delta_iter) = T_avg_SE(delta_iter) + sum(Na);
        N_avg_SE(delta_iter,:) = N_avg_SE(delta_iter,:) + Na;
    end
    success_SE(delta_iter) = success_SE(delta_iter)/avg;
    T_avg_SE(delta_iter) = T_avg_SE(delta_iter)/avg;
    N_avg_SE(delta_iter,:) = N_avg_SE(delta_iter,:)/avg;

    %% gap-based stopping rule with uniformly random sampling
    for iter = 1:avg
        if mod(iter, report_every) == 0
            fprintf('delta = %g, uniform sampling: run %d/%d\n', d, iter, avg);
        end
        t = 1;
        mu_hat = zeros(1,K);
        Na = zeros(1,K);
        obs = zeros(K,max_iter);
        conf = sqrt(2*log(4*t/d))*ones(1,K);
        max_overlap = inf;
        while max_overlap > 0 && max(Na) < max_iter
            x = randn(1,K) + mu;
            g1 = rand();
            g = [g1, rand(1,K-1) + g1 - 0.3];
            x = (1-e)*x + e*g;
            [mu_best,best_arm] = max(mu_hat);
            overlap = (mu_hat + conf) - (mu_best - conf(best_arm));
            overlap(best_arm) = -inf;
            max_overlap = max(overlap);
            arm = randi(K);   % uniform sampling; there is no forced exploration phase
            Na(arm) = Na(arm) + 1;
            obs(arm,Na(arm)) = x(arm);
            srt = sort(obs(arm,1:Na(arm)));
            cut = floor(Na(arm)*e);
            mu_hat(arm) = mean(srt(cut+1:end-cut));
            t = t + 1;
            conf(arm) = sqrt((2/Na(arm))*log(log(t)/d));
        end
        [~,p_arm] = max(mu_hat);
        success_U(delta_iter) = success_U(delta_iter) + (p_arm == 1);
        T_avg_U(delta_iter) = T_avg_U(delta_iter) + sum(Na);
        N_avg_U(delta_iter,:) = N_avg_U(delta_iter,:) + Na;
    end
    T_avg_U(delta_iter) = T_avg_U(delta_iter)/avg;
    N_avg_U(delta_iter,:) = N_avg_U(delta_iter,:)/avg;
    success_U(delta_iter) = success_U(delta_iter)/avg;

    %% median-based (CBAI)
    for iter = 1:avg
        if mod(iter, report_every) == 0
            fprintf('delta = %g, median-based: run %d/%d\n', d, iter, avg);
        end
        t = 1;
        S = 1:K;
        mu_hat = zeros(1,K);
        Na = zeros(1,K);
        alpha = sqrt(2*log(4*pi^2/(6*d)));
        obs = zeros(K,max_iter);
        while numel(S) > 1 && max(Na) < max_iter
            for a = S
                x = randn() + mu(a);
                g1 = rand();
                g = [g1, rand(1,K-1) + g1 - 0.3];
                x = (1-e)*x + e*g(a);
                Na(a) = Na(a) + 1;
                obs(a,Na(a)) = x;
                mu_hat(a) = median(obs(a,1:Na(a)));
                t = t + 1;
            end
            % same elimination as SE, but without a minimum-sample requirement
            elim = mu_hat(S) <= max(mu_hat(S)) - 2*alpha;
            S = S(~elim);
            alpha = sqrt((2/t)*log(4*t^2*pi^2/(6*d)));
        end
        success_med(delta_iter) = success_med(delta_iter) + (S(1) == 1);
        T_avg_med(delta_iter) = T_avg_med(delta_iter) + sum(Na);
        N_avg_med(delta_iter,:) = N_avg_med(delta_iter,:) + Na;
    end
    success_med(delta_iter) = success_med(delta_iter)/avg;
    T_avg_med(delta_iter) = T_avg_med(delta_iter)/avg;
    N_avg_med(delta_iter,:) = N_avg_med(delta_iter,:)/avg;
end

% Share of each delta's samples given to each arm (each row sums to 1).
alloc_gap = bsxfun(@rdivide, N_avg_gap, sum(N_avg_gap,2));
alloc_SE  = bsxfun(@rdivide, N_avg_SE,  sum(N_avg_SE,2));
alloc_U   = bsxfun(@rdivide, N_avg_U,   sum(N_avg_U,2));
alloc_med = bsxfun(@rdivide, N_avg_med, sum(N_avg_med,2));

%% plots

% Average number of samples versus confidence level delta for the four strategies.
figure()
plot(delta,T_avg_SE,'k*-')
hold on
plot(delta,T_avg_gap,'ro-')
plot(delta,T_avg_med,'gv-')
plot(delta,T_avg_U,'b.-')
hold off   % do not leak the hold state into later plotting commands
grid on
xlabel('\delta')
ylabel('average number of samples')
legend('Algorithm 2','Algorithm 1','Median-based','Random')