% Start from an empty workspace, then put the helper (aux), figure and
% result folders on the MATLAB search path. Each addpath call prepends its
% folder, so adding them in this order leaves '../results' first.
clear;
for path_entry = {'../aux', '../figures', '../results'}
    addpath(path_entry{1});
end
clear path_entry; % keep the workspace free of the loop variable

%% Design: simulate one data set and plot it against the structural function

% number of simulated observations
N=1000;

% data-generating design; one of 'NP', 'HLLT', 'CC'.
% get_design returns the structural function f, a simulator sim, and a
% grid x_vis with the structural function's values f_vis for plotting.
design='NP';
[f,sim,x_vis,f_vis]=get_design(design);

% Reset the global RNG so the simulated sample is reproducible (presumably
% sim draws from the global stream — confirm), then draw N observations.
% NOTE(review): x/y/z are presumably treatment/outcome/instrument — confirm in get_design.
rng('default');
[x,y,z]=sim(f,N);

% structural function as a thick line, simulated (x,y) pairs as grey dots
plot(x_vis,f_vis,'LineWidth',5);
hold('on');
scatter(x,y,36,[.5 .5 .5]);
xlabel('x','FontSize',20)
ylabel('y','FontSize',20)
legend({'Structural function','Data'},'Location','southeast','FontSize',20);
hold('off');
% To export the figure, uncomment:
% saveas(gcf,fullfile('../figures',strcat('design_',design)),'epsc');

%% KIV - IV, causal validation varying vx manually
% Tune the stage-1 (lambda) and stage-2 (xi) regularisation parameters one
% after the other with fminunc, predict on the x_vis grid, and report the
% MSE against f_vis. vx is set by hand; re-run this cell to compare values.

vx=1;
% NOTE(review): df presumably bundles the kernel matrices built with vx — confirm in get_K
df=get_K(x,y,z,x_vis,vx);

% Both parameters are searched on the log scale and mapped back with exp,
% which keeps them strictly positive. log(0.05) is the start used for NP.
lambda_0=log(0.05);
xi_0=log(0.05);

% stage 1: minimise the stage-1 loss over lambda (vx, vz held fixed)
lambda_star=fminunc(@(log_lambda) KIV1_loss(df,exp(log_lambda)),lambda_0);
lambda_hat=exp(lambda_star);

% stage 2: minimise the stage-2 loss over xi with lambda fixed at its optimum
xi_star=fminunc(@(log_xi) KIV2_loss(df,[lambda_hat,exp(log_xi)]),xi_0);

% predict with both tuned hyperparameters; the trailing 3 is an option of
% KIV_pred (original note: "evaluate on full sample") — TODO confirm its meaning
y_vis=KIV_pred(df,[lambda_hat,exp(xi_star)],3);

% mean squared error of the estimate against the true structural function
disp('mse:');
disp(mse(y_vis,f_vis));

% Optional: overlay the KIV estimate on the design plot and save it.
% vx is numeric, so it is converted with num2str for the file name, using
% the same vx*10 convention as the simulation result files.
%plot(x_vis,f_vis,'LineWidth',5);
%hold on;
%scatter(x,y,36,[.5 .5 .5]);
%plot(x_vis,y_vis,'--r','LineWidth',5);
%xlabel('x','FontSize',20)
%ylabel('y','FontSize',20)
%legend({'Structural function','Data','KernelIV'},'Location','southeast','FontSize',20);
%hold off;
%saveas(gcf,fullfile('../figures',strcat('KIV_',design,'_',num2str(vx*10))),'epsc');

%% KIV - 40 simulations
% For every vx in vx_list, run the KIV pipeline (sim_pred) on n_trials
% simulated data sets and write the per-trial MSEs against f_vis to
% ../results/<alg>_<design>_<N>_<10*vx>.csv.

clear;
rng('default')

% data-generating design; one of 'NP', 'HLLT', 'CC'
design='NP';
[f,sim,x_vis,f_vis]=get_design(design);

N=1000;        % sample size per trial
n_trials=40;   % repetitions per vx value

% preallocated per-trial predictions (one row per trial) and their MSEs;
% both are overwritten in full for every vx
results=zeros(n_trials,length(f_vis));
results_mse=zeros(n_trials,1);

% robustness: repeat the experiment over a range of vx values
alg='KIV';
vx_list=[.2 .4 .6 .8 1];

for vx=vx_list
    for i=1:n_trials
        % progress report every 10th trial
        if ~mod(i,10)
            fprintf('%d\n',i);
        end
        results(i,:)=sim_pred(design,N,vx)';
        results_mse(i)=mse(results(i,:)',f_vis);
    end
    % vx is encoded as 10*vx in the file name (e.g. 0.2 -> '2')
    out_name=[alg '_' design '_' num2str(N) '_' num2str(vx*10) '.csv'];
    csvwrite(fullfile('../results',out_name),results_mse);
end



