%% IHDP DEMONSTRATION FILE
%
% The IHDP data is described in detail in Hill (2011), as well as the
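% appendix of the companion paper. The following assumes that the script
% "ihdp/process_ihdp.R" has been executed (note that it is an R script), so
% that files "ihdp/ihdp.dat" and "ihdp/ihdp.names" have been generated.
% Both files are loaded by bare file name, so they must be reachable from
% the current folder or the MATLAB path.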
%
% Version 1.0, April 27th 2016


%% PART 0: LOAD DATA

fprintf('[LOAD DATA]\n')

% Main data matrix; its column order matches the names in 'ihdp.names'.
ihdp = load('ihdp.dat'); % Main file
num_vars = size(ihdp, 2);

% Variable names, for reference. Each token is wrapped in one leading and
% one trailing character (quotes, presumably as written by R -- confirm
% against process_ihdp.R), which are stripped below.
file_id = fopen('ihdp.names', 'r');
if file_id == -1
  error('Could not open "ihdp.names"; run "ihdp/process_ihdp.R" first.');
end
names_0 = textscan(file_id, '%s');
fclose(file_id);
if length(names_0{1}) < num_vars
  error('"ihdp.names" lists %d names but "ihdp.dat" has %d columns.', length(names_0{1}), num_vars);
end

y = 1; % Column of the outcome
x = 2; % Column of the treatment
names = cell(length(names_0{1}), 1);
for i = 1:num_vars
  names{i} = names_0{1}{i}(2:(end - 1));
end

% Stratify by mother's education (last matching column, as before)
strata = find(strcmp(names(1:num_vars), 'momed'), 1, 'last');
if isempty(strata)
  error('Variable "momed" not found in "ihdp.names".');
end
S_names = {'high school', 'college', 'all'}; % Labels of the two strata, plus the label given to the full dataset

% Covariate columns (all but outcome, treatment and the stratification
% variable), re-indexed for the column layout in which the stratification
% column has been removed.
Z_sel = setdiff(1:num_vars, [x y strata]);
Z_sel(Z_sel > strata) = Z_sel(Z_sel > strata) - 1;

%% PART 1: STRATIFY BY MOTHER'S EDUCATION AND NORMALIZE OUTPUT

% Normalization helps with inference, and we can always reverse that when
% displaying results.

% One stratum per observed level of the stratification variable, plus a
% final "stratum" S holding the full (pooled) dataset.
strata_levels = unique(ihdp(:, strata));
S = length(strata_levels) + 1;
dat = cell(S, 1);
Z_sels = cell(S, 1);

row_choice = cell(S, 1);
col_choice = cell(S, 1);
for s = 1:(S - 1)
  % Within a stratum the stratification variable is constant, so drop it.
  row_choice{s} = find(ihdp(:, strata) == strata_levels(s));
  col_choice{s} = setdiff(1:num_vars, strata);
end
% The pooled dataset keeps all rows and all columns, 'momed' included.
row_choice{S} = 1:size(ihdp, 1);
col_choice{S} = 1:num_vars;

for s = 1:S
  dat_s = ihdp(row_choice{s}, col_choice{s});
  % Covariate columns are indexed in *this* stratum's column layout. The
  % shifted 'Z_sel' only matches layouts with 'momed' removed; applied to
  % the pooled data it would pick 'momed' under a shifted index and drop
  % the last covariate.
  Z_cand = setdiff(1:size(dat_s, 2), [x y]);
  mean_sel = mean(dat_s(:, Z_cand));
  sd_sel = std(dat_s(:, Z_cand));
  % Covariates that are constant in this stratum cannot be standardized.
  pos_var = find(sd_sel > 0);
  mean_sel = mean_sel(pos_var);
  sd_sel = sd_sel(pos_var);
  Z_sels{s} = Z_cand(pos_var);
  % Layout of dat{s}: [outcome, treatment, standardized covariates]
  dat{s} = [dat_s(:, [y x]) bsxfun(@rdivide, bsxfun(@minus, dat_s(:, Z_sels{s}), mean_sel), sd_sel)];
end

% The treatment uses one center/scale computed on the full dataset, so all
% strata share the same dose axis; the outcome is standardized per stratum.
x_center = mean(ihdp(:, x));
x_scale = sqrt(var(ihdp(:, x)));
y_center = zeros(S, 1);
y_scale = zeros(S, 1);
for s = 1:S
  outcome_s = dat{s}(:, y);
  y_center(s) = mean(outcome_s);
  y_scale(s) = sqrt(var(outcome_s));
  dat{s}(:, y) = (outcome_s - y_center(s)) / y_scale(s);
  dat{s}(:, x) = (dat{s}(:, x) - x_center) / x_scale;
end

% Dose grid from 0 to 4.5 (raw treatment units) in steps of 0.25, mapped
% onto the normalized treatment scale.
X_space = ((0:25:450)' / 100 - x_center) / x_scale;
num_X_space = numel(X_space);

%% PART 2: OBSERVATIONAL LEARNING

% Here we build a model based on what the population would look like if it
% were truly generated by a dose-response resulting from marginalizing the
% observed covariates.

% Per stratum, fit the observational model using the outcome, the treatment
% and the covariate columns (3 onwards) of dat{s}.
num_iter = 2000;
fprintf('[OBSERVATIONAL LEARNING (PLEASE WAIT)]\n')
prior_info = cell(S, 1);
for s = 1:S
  fprintf('  stratum %d\n', s)
  covariate_cols = 3:size(dat{s}, 2);
  prior_info{s} = observational_learning(dat{s}, x, y, covariate_cols, num_iter, false);
end

% Per stratum, evaluate the fitted model on the dose grid X_space; the
% fields 'mu_do' and 'var_err_do' of the result are used further down as
% the dose-response mean and noise variance.
fprintf('[BUILD DO PREKERNEL]\n')
do_params_X_space = cell(S, 1);
for s = 1:S
  fprintf('  stratum %d\n', s)
  treatment_s = dat{s}(:, x);
  covariates_s = dat{s}(:, 3:end);
  do_params_X_space{s} = build_do_prekernel(X_space, treatment_s, covariates_s, prior_info{s}, true);
end

%% PART 3: GENERATE ARTIFICIAL OBSERVATIONAL/INTERVENTIONAL DATASETS

% 'f_space' is the 'true' dose-response function, built artificially from
% the existing data. We then select design points 'X_exp' (the whole dose
% grid, repeated) at which (artificial) outcomes 'Y_exp' are generated by
% adding Gaussian noise to the true curve, 10 data points per dose level.

% "True" dose-response per stratum, on the normalized dose grid
f_space = cellfun(@(params) params.mu_do, do_params_X_space, 'UniformOutput', false);

% Dose grid in original treatment units times 100, used for plotting
X_space_100 = (X_space * x_scale + x_center) * 100;

num_rep = 10; % Number of data points per dose level
X_exp = cell(S, 1);
Y_exp = cell(S, 1);

% Each repetition adds one noisy copy of the true curve over the full grid
for s = 1:S
  params_s = do_params_X_space{s};
  noise_sd = sqrt(params_s.var_err_do);
  X_exp{s} = repmat(X_space, num_rep, 1);
  Y_exp{s} = zeros(length(X_exp{s}), 1);
  for i = 1:num_rep
    rep_rows = (i - 1) * num_X_space + (1:num_X_space);
    Y_exp{s}(rep_rows) = params_s.mu_do + randn(num_X_space, 1) .* noise_sd;
  end
end

% Observational data, rearranged per stratum: the treatment column and the
% standardized covariate columns of dat{s}.
X_obs = cellfun(@(d) d(:, x), dat, 'UniformOutput', false);
Z_obs = cellfun(@(d) d(:, 3:end), dat, 'UniformOutput', false);

%% PART 4: SET PRIORS

% 'sf2' and 'ell' are the amplitude and lengthscale parameters of the
% observational covariance matrix K_obs, as described in the paper.

% Each hyperparameter gets a prior mean 'mu' (written as log(1) = 0,
% presumably a prior on the log-parameter -- confirm in
% gibbs_sample_affine_model) and a prior variance 'var'.
hyper_a.sf2 = struct('mu', log(1), 'var', 0.5);
hyper_a.ell = struct('mu', log(1), 'var', 0.1);

%% PART 5: RUN SIMPLE DEMONSTRATION

% Fit the observational/interventional data using
% 'gibbs_sample_affine_model', our main procedure. 'a' and 'b' are the
% multiplicative and additive distortion functions and 'f_obs' the
% observational function, so that 'a .* f_obs + b' are samples from the
% posterior distribution of the dose-response.
% Structure 'f_hat_direct' is the fit to interventional data only, using
% standard GP regression.

fprintf('[SIMPLE DEMONSTRATION]\n')

% Affine-model sampler: M iterations; the first 'burn_in' are discarded
% when plotting.
M = 5000;
burn_in = 500;
verbose = false;
[a, b, f_obs] = gibbs_sample_affine_model(X_exp, Y_exp, X_space, do_params_X_space, ...
  prior_info, hyper_a, X_obs, Z_obs, [], M, false, verbose);

% Baseline fit using the interventional data only
num_iter = 2000;
f_hat_direct = cell(S, 1);
for s = 1:S
  X_exp_s = X_exp{s};
  Y_exp_s = Y_exp{s};
  f_hat_direct{s} = simple_learn_from_experiments(X_space, X_exp_s, Y_exp_s, num_iter);
end

%% PART 6: VISUALIZE IT

% The green curve is the "true" dose-response, dashed-red is the posterior
% expected curve from our method, dashed-blue is the one given by fitting
% interventional data only.

% Simulated interventional data (scatter) against the true dose-response
% (green), both mapped back to the original units.
for s = 1:S
  to_raw_y = @(v) v * y_scale(s) + y_center(s);
  X_exp_raw_100 = 100 * (X_exp{s} * x_scale + x_center);
  figure
  hold on
  scatter(X_exp_raw_100, to_raw_y(Y_exp{s}))
  plot(X_space_100, to_raw_y(f_space{s}), '-g', 'LineWidth', 5)
  xlabel('Treatment {\it X}', 'FontSize', 20)
  ylabel('Outcome {\it Y}', 'FontSize', 20)
  title(strcat('Data (', S_names{s}, ')'), 'FontSize', 20)
  hold off
end

% Posterior samples after burn-in (grey), truth (green), our posterior mean
% (dashed red) and the interventional-only fit (dashed blue), in original
% units.
for s = 1:S
  to_raw_y = @(v) v * y_scale(s) + y_center(s);
  figure
  post_samples = a{s} .* f_obs{s} + b{s};
  post_samples = post_samples(:, (burn_in + 1):end);
  plot(X_space_100, to_raw_y(post_samples), 'Color', [0.4,0.4,0.4]);
  hold on
  plot(X_space_100, to_raw_y(f_space{s}), '-g', 'LineWidth', 5)
  plot(X_space_100, to_raw_y(mean(post_samples, 2)), '--r', 'LineWidth', 5)
  plot(X_space_100, to_raw_y(f_hat_direct{s}{2}), '--b', 'LineWidth', 5)
  xlabel('Treatment {\it X}', 'FontSize', 20)
  ylabel('Outcome {\it Y}', 'FontSize', 20)
  title(strcat('Posterior on dose-response (', S_names{s}, ')'), 'FontSize', 20)
end

%% PART 7: QUANTITATIVE COMPARISON

% Assess errors by running it many times, compare our method against
% interventional data only fitting.

num_trials = 100;
num_iter = 1000;

fprintf('[LARGE SCALE COMPARISON (%d TRIALS)]\n', num_trials)
% errors{s}(n, 1): mean absolute error of the affine-model posterior mean
% errors{s}(n, 2): mean absolute error of the interventional-only fit
errors = cell(S, 1);
for s = 1:S
  errors{s} = zeros(num_trials, 2);
end

M = 2000;
burn_in = 500;

for n = 1:num_trials

  % Fresh noisy interventional outcomes on the same design points
  for s = 1:S
    Y_exp{s} = zeros(length(X_exp{s}), 1);
    for i = 1:num_rep
      Y_exp{s}((i - 1) * num_X_space + (1:num_X_space)) = do_params_X_space{s}.mu_do + randn(num_X_space, 1) .* sqrt(do_params_X_space{s}.var_err_do);
    end
  end

  fprintf('   Trial %d\n', n)
  [a, b, f_obs] = gibbs_sample_affine_model(X_exp, Y_exp, X_space, do_params_X_space, prior_info, hyper_a, X_obs, Z_obs, [], M, false, verbose);

  for s = 1:S
    % Separate name so the Part 5 cell 'f_hat_direct' is not overwritten
    f_hat_trial = simple_learn_from_experiments(X_space, X_exp{s}, Y_exp{s}, num_iter);
    dr = a{s} .* f_obs{s} + b{s};
    dr = mean(dr(:, (burn_in + 1):end), 2);
    errors{s}(n, 1) = mean(abs(dr - f_space{s}));
    errors{s}(n, 2) = mean(abs(f_hat_trial{2} - f_space{s}));
  end

end

% Per stratum, report two quantities:
% - the error reduction of the affine model over the interventional-only
%   fit, normalized by the range of the true curve;
% - the fraction of trials in which that reduction is positive.
for s = 1:S
  err_reduction = (errors{s}(:, 2) - errors{s}(:, 1)) / (max(f_space{s}) - min(f_space{s}));
  fprintf('Normalized absolute error reduction (S = %d) = %.2f (reduction frequency = %.2f)\n', s, mean(err_reduction), mean(err_reduction > 0))
end

%% PART 8: ACTIVE LEARNING

% 'X_exp_active' and 'Y_exp_active' are the respective treatments/outcomes
% obtained by the simple active learning scheme. 'idx' holds the index in
% the dose space corresponding to each intervention.
%
% 'a_act', 'b_act' and 'f_act' are the corresponding distortion/observational
% curves, while 'a_uniform', 'b_uniform' and 'f_uniform' are obtained using
% a uniform design instead.

fprintf('[ACTIVE LEARNING DEMONSTRATION]\n')

M = 550;             % Number of MCMC iterations per round
burn_in = 50;        % Corresponding burn-in
num_rep_active = 3;  % Number of interventional samples per dose level to be collected
num_iter = 1000;     % Forwarded to active_dose_response_learning; set here so this section does not depend on earlier ones

% Only the real strata (1..S-1) are used here, not the pooled dataset.
X_exp_init = cell(S - 1, 1);
Y_exp_init = cell(S - 1, 1);
X_exp_uniform = cell(S - 1, 1);
Y_exp_uniform = cell(S - 1, 1);
Z_active_sel = cell(S - 1, 1);
dat_active = cell(S - 1, 1);
models = cell(S - 1, 1);

for s = 1:(S - 1)

  % Initial round: one noisy outcome at every dose level
  X_exp_init{s} = X_space;
  Y_exp_init{s} = do_params_X_space{s}.mu_do + randn(num_X_space, 1) .* sqrt(do_params_X_space{s}.var_err_do);

  % Ground-truth curve and noise level handed to the active learner
  % (presumably to simulate the outcome of each chosen intervention --
  % confirm in active_dose_response_learning)
  models{s}.poly = false;
  models{s}.X_space = X_space;
  models{s}.f_space = do_params_X_space{s}.mu_do;
  models{s}.var_error = do_params_X_space{s}.var_err_do;

  dat_active{s} = dat{s};
  Z_active_sel{s} = 3:size(dat{s}, 2);

  % Uniform design with (num_rep_active + 1) samples per dose level, i.e.
  % the initial round plus num_rep_active * length(X_space) further points
  X_exp_uniform{s} = repmat(X_space, num_rep_active + 1, 1);
  Y_exp_uniform{s} = zeros(length(X_exp_uniform{s}), 1);
  for i = 1:(num_rep_active + 1)
    Y_exp_uniform{s}((i - 1) * num_X_space + (1:num_X_space)) = do_params_X_space{s}.mu_do + randn(num_X_space, 1) .* sqrt(do_params_X_space{s}.var_err_do);
  end

end

[X_exp_active, Y_exp_active, idx] = active_dose_response_learning(num_rep_active * length(X_space), X_space, dat_active, x, y, Z_active_sel, X_exp_init, Y_exp_init, num_iter, prior_info, do_params_X_space, models, burn_in, M);
[a_act, b_act, f_act] = gibbs_sample_affine_model(X_exp_active, Y_exp_active, X_space, do_params_X_space, prior_info, hyper_a, X_obs, Z_obs, [], M, false, verbose);
[a_uniform, b_uniform, f_uniform] = gibbs_sample_affine_model(X_exp_uniform, Y_exp_uniform, X_space, do_params_X_space, prior_info, hyper_a, X_obs, Z_obs, [], M, false, verbose);

%% PART 9: VISUALIZE ACTIVE LEARNING OUTPUT

% Distribution of the dose-grid indices chosen by the active learner. With
% integer bins each index gets its own bar, whereas hist's default 10 bins
% would merge neighbouring dose levels.
for s = 1:(S - 1)
  figure
  histogram(idx{s}, 'BinMethod', 'integers')
  xlabel('Index in dose grid', 'FontSize', 20)
  title(strcat('Active selection (', S_names{s}, ')'), 'FontSize', 20)
end

% Active design, in original units:
% - posterior samples after burn-in (grey);
% - truth (green);
% - posterior mean of the active-design fit (dashed red);
% - posterior mean of the uniform-design fit (dashed blue).
for s = 1:(S - 1)
  to_raw_y = @(v) v * y_scale(s) + y_center(s);
  figure
  samples_active = to_raw_y(a_act{s} .* f_act{s} + b_act{s});
  samples_active = samples_active(:, (burn_in + 1):end);
  plot(X_space_100, samples_active, 'Color', [0.4,0.4,0.4]);
  hold on
  plot(X_space_100, to_raw_y(f_space{s}), '-g', 'LineWidth', 5)
  plot(X_space_100, mean(samples_active, 2), '--r', 'LineWidth', 5)
  samples_uniform = to_raw_y(a_uniform{s} .* f_uniform{s} + b_uniform{s});
  samples_uniform = samples_uniform(:, (burn_in + 1):end);
  plot(X_space_100, mean(samples_uniform, 2), '--b', 'LineWidth', 5)
  xlabel('Treatment {\it X}', 'FontSize', 20)
  ylabel('Outcome {\it Y}', 'FontSize', 20)
  title(strcat('Posterior on dose-response (', S_names{s}, ')'), 'FontSize', 20)
end

% Uniform design, in original units:
% - posterior samples after burn-in (grey);
% - truth (green);
% - posterior mean of the uniform-design fit (dashed blue);
% - posterior mean of the active-design fit (dashed red).
for s = 1:(S - 1)
  to_raw_y = @(v) v * y_scale(s) + y_center(s);
  figure
  samples_uniform = to_raw_y(a_uniform{s} .* f_uniform{s} + b_uniform{s});
  samples_uniform = samples_uniform(:, (burn_in + 1):end);
  plot(X_space_100, samples_uniform, 'Color', [0.4,0.4,0.4]);
  hold on
  plot(X_space_100, to_raw_y(f_space{s}), '-g', 'LineWidth', 5)
  plot(X_space_100, mean(samples_uniform, 2), '--b', 'LineWidth', 5)
  samples_active = to_raw_y(a_act{s} .* f_act{s} + b_act{s});
  samples_active = samples_active(:, (burn_in + 1):end);
  plot(X_space_100, mean(samples_active, 2), '--r', 'LineWidth', 5)
  xlabel('Treatment {\it X}', 'FontSize', 20)
  ylabel('Outcome {\it Y}', 'FontSize', 20)
  title(strcat('Posterior from uniform (', S_names{s}, ')'), 'FontSize', 20)
end

%% PART 10: COMPARISON AGAINST STANDARD DEEP GAUSSIAN PROCESS

% This requires that Stan and ProcessManager are installed (see
% "setpath.m")

fprintf('[DEEP GAUSSIAN PROCESS COMPARISON (PLEASE WAIT)]\n')

% Fresh noisy interventional outcomes on the same design points
for s = 1:S
  Y_exp{s} = zeros(length(X_exp{s}), 1);
  for i = 1:num_rep
    Y_exp{s}((i - 1) * num_X_space + (1:num_X_space)) = do_params_X_space{s}.mu_do + randn(num_X_space, 1) .* sqrt(do_params_X_space{s}.var_err_do);
  end
end

burn_in_hmc = 20; % Burn-in for HMC
M_hmc = 200;      % Number of HMC iterations
% Dedicated timer handles: a bare tic/toc pair would be reset by any tic
% call made inside the samplers.
timer_standard_deep = tic;
[theta_standard_deep, f_standard_deep, f_obs_standard_deep] = hmc_sample_standard_deep_gp(X_exp, Y_exp, X_space, do_params_X_space, prior_info, hyper_a, X_obs, Z_obs, burn_in_hmc, M_hmc);
time_standard_deep = toc(timer_standard_deep);

% Run the Gibbs sampler long enough that, after discarding burn_in_gibbs
% iterations and keeping every thinning_gibbs-th, M_hmc samples remain.
burn_in_gibbs = 100;
thinning_gibbs = 10;
M_gibbs = burn_in_gibbs + M_hmc * thinning_gibbs;
M_sel = (burn_in_gibbs + 1):thinning_gibbs:M_gibbs;
timer_affine = tic;
[a, b, f] = gibbs_sample_affine_model(X_exp, Y_exp, X_space, do_params_X_space, prior_info, hyper_a, X_obs, Z_obs, [], M_gibbs, false, verbose);
f_affine = cell(S, 1);
for s = 1:S
  f_affine{s} = a{s}(:, M_sel) .* f{s}(:, M_sel) + b{s}(:, M_sel);
end
time_affine = toc(timer_affine);

fprintf('Time taken: %ds for non-factorized, %ds for affine\n', round(time_standard_deep), round(time_affine))

%% PART 11: VISUALIZATION OF COMPARISON AGAINST STANDARD DEEP GAUSSIAN PROCESS

% Factorized (affine) model, in original units: thinned posterior samples
% (grey), truth (green) and posterior mean (dashed red). Each figure's
% y-limits are stored so the non-factorized plots share the same axes.
ylims = cell(S, 1);
for s = 1:S
  to_raw_y = @(v) v * y_scale(s) + y_center(s);
  figure
  plot(X_space_100, to_raw_y(f_affine{s}), 'Color', [0.4,0.4,0.4]);
  hold on
  plot(X_space_100, to_raw_y(f_space{s}), '-g', 'LineWidth', 5)
  plot(X_space_100, to_raw_y(mean(f_affine{s}, 2)), '--r', 'LineWidth', 5)
  xlabel('Treatment {\it X}', 'FontSize', 20)
  ylabel('Outcome {\it Y}', 'FontSize', 20)
  title(strcat('Factorized model (', S_names{s}, ')'), 'FontSize', 20)
  ylims{s} = ylim;
end

% Non-factorized (standard deep GP) model with the same layout. Its
% samples appear to be stored one per row, so the mean is taken along
% dimension 1.
for s = 1:S
  to_raw_y = @(v) v * y_scale(s) + y_center(s);
  figure
  plot(X_space_100, to_raw_y(f_standard_deep{s}), 'Color', [0.4,0.4,0.4]);
  hold on
  plot(X_space_100, to_raw_y(f_space{s}), '-g', 'LineWidth', 5)
  plot(X_space_100, to_raw_y(mean(f_standard_deep{s})), '--r', 'LineWidth', 5)
  xlabel('Treatment {\it X}', 'FontSize', 20)
  ylabel('Outcome {\it Y}', 'FontSize', 20)
  title(strcat('Non-factorized model (', S_names{s}, ')'), 'FontSize', 20)
  ylim(ylims{s})
end

%% Evaluate log-likelihood

% Average Gaussian log density of the true curve f_space{s} over the dose
% grid, under each method's per-point posterior mean and variance:
%   mean_i log N(t_i | m_i, v_i) = -0.5 * (log(2*pi) + mean(log v) + mean((t - m).^2 ./ v))
% The normalizing constant is log(2*pi); the former log(sqrt(2*pi)) made the
% reported values too high by 0.25*log(2*pi).
% Printed per stratum as [factorized, standard deep GP].
gauss_avg_llik = @(target, mu, v) -0.5 * (log(2 * pi) + mean(log(v)) + mean((target - mu).^2 ./ v));
for s = 1:S

  % f_affine{s}: one column per retained posterior sample
  llik_factorized = gauss_avg_llik(f_space{s}, mean(f_affine{s}, 2), var(f_affine{s}, [], 2));

  % f_standard_deep{s}: statistics taken along dimension 1, then transposed
  % to match the column layout of f_space{s}
  llik_standard_deep = gauss_avg_llik(f_space{s}, mean(f_standard_deep{s})', var(f_standard_deep{s})');

  fprintf('Stratum %d: [%1.2f, %1.2f]\n', s, llik_factorized, llik_standard_deep)

end
