% synth_run_and_compare:
%
% Compare methods based on the default choice of priors and on problems
% generated by 'generate_problems'. Here we assess the outputs of several
% models, including the deformation model, under default priors.
%
% Input:
%
% - dat: observational data;
% - model: model as generated by "synth/generate_model.m", required in
%   order to generate data under simulated interventions;
% - num_Z_sel: number of common causes that will be visible to the
%   algorithm;
% - Z_sorted_oracle: a sorting of the covariates that is used to decide
%   which common causes are visible in the observational fitting;
% - M, burn_in: number of MCMC samples and burn-in iterations;                 
% - num_total_exps: vector with a set of sample sizes for the
%   interventional simulations;
% - num_X_space: number of dose levels, to be generated uniformly in the
%   space observed in the observational sample;
% - num_iter: number of optimization steps as used internally to fit
%   hyperparameters of some of the Gaussian process models;
% - verbose: whether to display information or not.
%
% Output:
%
% - X_exp_uniform, Y_exp_uniform: samples generated by simulated
%   interventions. Since num_total_exps can request more than one sample
%   size, each of these is a cell array with one entry per interventional
%   sample size;
% - X_space: dose levels, as implied by the observational data and
%   num_X_space;
% - f_space: the corresponding response, as implied by the model;
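% - model1, model2, model3: result of the inference, with and without a distortion
%   function, and the additive model, respectively;
% - model4, model5: simpler fits using, respectively, the observational
%   covariance function plus a generic squared exponential kernel, and the
%   interventional data only with a generic squared exponential kernel;
% - eval_error_space*: assessment of the five models;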
% - prior_info: default prior generated internally;
% - do_params_X_space: information generated when fitting the observational
%   data;
% - f_hat_confounded: dose-response obtained by naive regression with no
%   adjustment;
% - f_hat_naive_space: dose-response obtained by direct regression on 
%   the interventional data only. Two versions: with Matern and squared
%   exponential kernels;
% - Z_sorted_oracle: same as in input;
% - Z_sel: the subset of covariates selected to be given to the
%   observational model;
      
function [X_exp_uniform, Y_exp_uniform, X_space, f_space, model1, model2, model3, model4, model5, ...
          eval_error_space1, eval_error_space2, eval_error_space3, eval_error_space4, eval_error_space5, ...
          prior_info, do_params_X_space, f_hat_confounded, f_hat_naive_space, ...
          hyper_a, hyper_likv, Z_sorted_oracle, Z_sel] = ...
          synth_run_and_compare(dat, model, num_Z_sel, Z_sorted_oracle, M, burn_in, num_total_exps, num_X_space, num_iter, verbose)
% Fit five dose-response models on simulated interventional data and assess
% each of them against the response implied by 'model'.
%
% Layout of 'dat' as used below: the last two columns are the treatment (x)
% and the outcome (y); all preceding columns are covariates. Every column is
% standardized before use, so X_space, f_space and the simulated samples
% are all expressed in standardized units.
%
% Outputs beyond the three main models:
% - model4, eval_error_space4: observational covariance function plus a
%   generic squared exponential kernel;
% - model5, eval_error_space5: interventional data only, generic squared
%   exponential kernel;
% - hyper_a, hyper_likv: hyperprior settings (fields mu / var) built in the
%   "Priors" section; hyper_a is passed to the Gibbs samplers, hyper_likv is
%   only returned.

if nargin < 10, verbose = false; end

%% Auxiliary fitting

n = size(dat, 1);
p = size(dat, 2) - 2;   % number of covariate columns
x = p + 1;              % column index of the treatment
y = p + 2;              % column index of the outcome

% Standardize every column; keep the treatment/outcome moments so that doses
% can be mapped back to the original scale when calling the simulator.
m_data  = mean(dat);
sd_data = std(dat);
dat     = (dat - repmat(m_data, n, 1)) ./ repmat(sd_data, n, 1);
m_x     = m_data(x);
sd_x    = sd_data(x);
m_y     = m_data(y);
sd_y    = sd_data(y);

% Dose grid spanning the observed (standardized) treatment range, and the
% corresponding response taken from the second output of run_intervention.
num_exps = length(num_total_exps);
X_space = linspace(min(dat(:, x)), max(dat(:, x)), num_X_space)';
f_space = zeros(num_X_space, 1);
for i = 1:num_X_space
  [~, f_space(i)] = run_intervention(X_space(i) * sd_x + m_x, model);
  f_space(i) = (f_space(i) - m_y) / sd_y;
end

if isempty(Z_sorted_oracle)
  Z_sorted_oracle = 1:p;
end
% Fail with an explicit message instead of an opaque index-out-of-bounds
% error when more covariates are requested than are available.
if num_Z_sel > numel(Z_sorted_oracle)
  error('synth_run_and_compare:numZSel', ...
        'num_Z_sel (%d) exceeds the number of available covariates (%d).', ...
        num_Z_sel, numel(Z_sorted_oracle));
end
Z_sel = Z_sorted_oracle(1:num_Z_sel);

if verbose, fprintf('[OBSERVATIONAL LEARNING]\n'); end
prior_info = observational_learning(dat, x, y, Z_sel, num_iter, false);

if verbose, fprintf('[BUILD DO-PREKERNEL]\n'); end
do_params_X_space = build_do_prekernel(X_space, dat(:, x), dat(:, Z_sel), prior_info, true);
if verbose, fprintf('[DIRECT LEARNING, NO ADJUSTMENT]\n'); end
f_hat_confounded = simple_learn_from_experiments(X_space, dat(:, x), dat(:, y), num_iter);

model1 = cell(num_exps, 1); % Affine transformation
model2 = cell(num_exps, 1); % No transformation at all, just fixed observational covariance function
model3 = cell(num_exps, 1); % No distortion function, just additive transform
model4 = cell(num_exps, 1); % Observational covariance function + generic squared exponential kernel
model5 = cell(num_exps, 1); % Just interventional data, generic squared exponential kernel
X_exp_uniform = cell(num_exps, 1);
Y_exp_uniform = cell(num_exps, 1);
f_hat_naive_space = cell(num_exps, 1);

if verbose, fprintf('[DIRECT LEARNING, EXPERIMENTAL]\n'); end

% Interventional samples are nested: entry v holds all samples of entry
% v - 1 plus enough fresh replicates of the dose grid to approach
% num_total_exps(v).
total_so_far = 0;
for v = 1:num_exps
  rep_uniform = round((num_total_exps(v) - total_so_far) / num_X_space);
  X_new = repmat(X_space, rep_uniform, 1);
  Y_new = zeros(length(X_new), 1);
  for i = 1:length(X_new)
    Y_new(i) = run_intervention(X_new(i) * sd_x + m_x, model);
    Y_new(i) = (Y_new(i) - m_y) / sd_y;
  end
  if v > 1
    X_exp_uniform{v} = [X_exp_uniform{v - 1}; X_new];
    Y_exp_uniform{v} = [Y_exp_uniform{v - 1}; Y_new];
  else
    X_exp_uniform{v} = X_new;
    Y_exp_uniform{v} = Y_new;
  end
  f_hat_naive_space{v} = simple_learn_from_experiments(X_space, X_exp_uniform{v}, Y_exp_uniform{v}, 2 * num_iter);
  total_so_far = length(Y_exp_uniform{v});
end

%% Priors

hyper_a.sf2.mu  = log(1);
hyper_a.sf2.var = 0.5;

hyper_a.ell.mu  = log(1);
hyper_a.ell.var = 0.1;

hyper_likv.mu  = log(prior_info.X_hyp.lik);
hyper_likv.var = 1;

%% Sampling with the distortion model

sel_M = (burn_in + 1):M;   % MCMC iterations kept after burn-in
predict_samples_space        = cell(num_exps, 1);
predict_samples_space_locked = cell(num_exps, 1);
predict_samples_space_add    = cell(num_exps, 1);
predict_samples_space_add2   = cell(num_exps, 1);
predict_samples_space_plain  = cell(num_exps, 1);

for v = 1:num_exps

  if verbose, fprintf('[RUNNING FACTORIAL MODEL (N = %d)]\n', num_total_exps(v)); end
  [a, b, f_obs, kernel_mcmc, llik_mcmc] = gibbs_sample_affine_model_nostrata(X_exp_uniform{v}, Y_exp_uniform{v}, X_space, do_params_X_space, prior_info, hyper_a, dat(:, x), dat(:, Z_sel), M, false, false);
  affine_samples = (a .* f_obs)' + b';
  predict_samples_space{v} = affine_samples(sel_M, :);
  model1{v}.a = a;
  model1{v}.b = b;
  model1{v}.f_obs = f_obs;
  model1{v}.kernel_mcmc = kernel_mcmc;
  model1{v}.llik_mcmc = llik_mcmc;

  if verbose, fprintf('[RUNNING SINGLE LAYER MODEL (N = %d)]\n', num_total_exps(v)); end
  [a_locked, b_locked, f_obs_locked, kernel_mcmc_locked, llik_mcmc_locked] = gibbs_sample_affine_model_nostrata(X_exp_uniform{v}, Y_exp_uniform{v}, X_space, do_params_X_space, prior_info, hyper_a, dat(:, x), dat(:, Z_sel), M, true, false);
  locked_samples = f_obs_locked';
  predict_samples_space_locked{v} = locked_samples(sel_M, :);
  model2{v}.a = a_locked;
  model2{v}.b = b_locked;
  model2{v}.f_obs = f_obs_locked;
  model2{v}.kernel_mcmc = kernel_mcmc_locked;
  model2{v}.llik_mcmc = llik_mcmc_locked;

  if verbose, fprintf('[RUNNING ADDITIVE MODEL (N = %d)]\n', num_total_exps(v)); end
  % The additive sampler takes cell-array inputs and returns cell-array
  % outputs; here every cell array has a single entry.
  X_exp_add = {X_exp_uniform{v}};
  Y_exp_add = {Y_exp_uniform{v}};
  do_params_X_space_add = {do_params_X_space};
  prior_info_add = {prior_info};
  X_obs_add = {dat(:, x)};
  Z_obs_add = {dat(:, Z_sel)};
  [f_add, kernel_add_mcmc, llik_add_mcmc] = gibbs_sample_additive_model(X_exp_add, Y_exp_add, X_space, do_params_X_space_add, prior_info_add, hyper_a, X_obs_add, Z_obs_add, M, false);
  model3{v}.f           = f_add{1};
  model3{v}.kernel_mcmc = kernel_add_mcmc{1};
  model3{v}.llik_mcmc   = llik_add_mcmc{1};
  predict_samples_space_add{v} = model3{v}.f(:, sel_M)';

  if verbose, fprintf('[RUNNING SIMPLER ADDITIVE MODEL (N = %d)]\n', num_total_exps(v)); end
  A = multi_exp_mapping(X_space, X_exp_uniform{v}, dat(:, x), dat(:, Z_sel), prior_info);
  [pred_mean, pred_cov, theta_hat] = simple_learn_obs_int(A, Y_exp_uniform{v}, X_space, do_params_X_space);
  model4{v}.f = pred_mean';
  model4{v}.cov_f = pred_cov;
  model4{v}.theta_hat = theta_hat;
  % Only a point prediction is available here; it is stacked twice to form
  % a two-row "sample" matrix for assess_methods.
  predict_samples_space_add2{v} = repmat(pred_mean', 2, 1);

  if verbose, fprintf('[RUNNING SIMPLER MODEL, NO SPECIAL PRIOR (N = %d)]\n', num_total_exps(v)); end
  % NOTE(review): same arguments as the multi_exp_mapping call above; the
  % result could be reused if that function is deterministic -- confirm
  % before removing this call.
  A = multi_exp_mapping(X_space, X_exp_uniform{v}, dat(:, x), dat(:, Z_sel), prior_info);
  [pred_mean, pred_cov, theta_hat] = simple_learn_obs_int(A, Y_exp_uniform{v}, X_space);
  model5{v}.f = pred_mean';
  model5{v}.cov_f = pred_cov;
  model5{v}.theta_hat = theta_hat;
  predict_samples_space_plain{v} = repmat(pred_mean', 2, 1);

end

if verbose, fprintf('\n'); end

%% Sampling with the distortion model EVALUATE

eval_error_space1 = cell(num_exps, 1);
eval_error_space2 = cell(num_exps, 1);
eval_error_space3 = cell(num_exps, 1);
eval_error_space4 = cell(num_exps, 1);
eval_error_space5 = cell(num_exps, 1);

visualize_it = false;

% All five assessments share every argument except the predictive samples.
assess = @(samples, k) assess_methods(samples, do_params_X_space.mu_do, f_hat_naive_space{k}, f_hat_confounded, X_space, f_space, dat(:, x), visualize_it, false);

for v = 1:num_exps
  eval_error_space1{v} = assess(predict_samples_space{v}, v);
  eval_error_space2{v} = assess(predict_samples_space_locked{v}, v);
  eval_error_space3{v} = assess(predict_samples_space_add{v}, v);
  eval_error_space4{v} = assess(predict_samples_space_add2{v}, v);
  eval_error_space5{v} = assess(predict_samples_space_plain{v}, v);
end
