% active_dose_response_learning:
%
% An active learning method for generating a sample of size K corresponding
% to manipulations of a treatment variable X with an outcome variable Y
% across different subpopulations (strata). In this implementation,
% experiments are carried out one point at a time, for simplicity.
%
% Input:
%
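% - num_total_exp: number of experiments over all strata;
% - X_space: vector of candidate treatment levels (the dose space) at which
%   experiments can be placed;
% - dat: cell array containing the data matrix for each stratum;
% - x, y: indices of the treatment and outcome columns in each data matrix;
% - Z: cell array, where Z{s} holds the indices of the adjustment
%   covariates for stratum s;
% - X_exp_init, Y_exp_init: cell arrays with the initial samples of each
%   stratum prior to active learning; every initial treatment level must
%   be an element of X_space;
% - num_iter: number of iterations for hyperparameter fitting (used only
%   when prior_info is passed as empty);
% - prior_info, do_params: cell arrays of structures obtained by
%   "observational_learning" and "build_do_prekernel"; if passed as empty,
%   they are computed inside this function;
% - models: cell array of artificial models, one per stratum, that can
%   generate synthetic outcomes as provided by "generate_problem";
% - burn_in: number of initial MCMC samples discarded when averaging the
%   sampled hyperparameters;
% - M: number of Monte Carlo samples to represent predictive distribution.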
%
% Output:
%
% - X_exp, Y_exp: points where experiments were performed and the
%   corresponding output in each stratum;
% - idx: corresponding position of the points according to the dose space;
% - a, b, f_obs, kernel_mcmc, llik_mcmc: model learned at the last
%   iteration, see gibbs_sample_affine_model.


function [X_exp, Y_exp, idx, a, b, f_obs, kernel_mcmc, llik_mcmc] = ...
           active_dose_response_learning(num_total_exp, X_space, dat, x, y, Z, ...
                                         X_exp_init, Y_exp_init, ...
                                         num_iter, prior_info, do_params, models, burn_in, M)
% Runs num_total_exp rounds of active learning. Each round selects a
% (treatment level, stratum) pair with "search_info_gain_strata", simulates
% the experiment with "run_intervention" and re-samples the model with
% "gibbs_sample_affine_model".
%
% X_space is the vector of candidate treatment levels; burn_in is the
% number of leading MCMC samples dropped when hyperparameters are averaged.
% If prior_info or do_params is empty it is computed here from dat.

%% Tuning constants

hyper_refresh_period = 5; % Full hyperparameter re-sampling whenever the
                          % chosen stratum's sample size is a multiple of this
M_latent = 100;           % Number of samples requested when hyperparameters are locked

%% Basic initialization

num_X_space = length(X_space); % Number of possible treatment levels
S = length(dat);               % Number of strata

X_exp = cell(S, 1);
Y_exp = cell(S, 1);
idx   = cell(S, 1);
X_obs = cell(S, 1);
Z_obs = cell(S, 1);
for s = 1:S
  X_exp{s} = X_exp_init{s};                   % Treatment levels
  Y_exp{s} = Y_exp_init{s};                   % Outcomes
  num_init = length(X_exp_init{s});
  idx{s}   = zeros(num_init, 1);              % Index of treatment levels
  for j = 1:num_init
    % Each initial dose must map to exactly one position of the dose space;
    % otherwise the index cannot be stored as a scalar.
    pos = find(X_space == X_exp{s}(j));
    if numel(pos) ~= 1
      error('active_dose_response_learning:badInitialDose', ...
            'Initial treatment level %g (stratum %d, sample %d) matches %d entries of X_space; exactly one is required.', ...
            X_exp{s}(j), s, j, numel(pos));
    end
    idx{s}(j) = pos;
  end
  X_obs{s} = dat{s}(:, x);
  Z_obs{s} = dat{s}(:, Z{s});
end

%% Setup prior information: part 1

% Both structures are optional inputs: build them from the observational
% data when the caller passes them as empty.
if isempty(prior_info)
  prior_info = cell(S, 1);
  for s = 1:S
    prior_info{s} = observational_learning(dat{s}, x, y, Z{s}, num_iter);
  end
end
if isempty(do_params)
  do_params = cell(S, 1);
  for s = 1:S
    do_params{s} = build_do_prekernel(X_space, dat{s}(:, x), dat{s}(:, Z{s}), prior_info{s}, true);
  end
end

%% Priors

% Hyperprior settings handed to the sampler. The sf2/ell and mu/var names
% suggest log-scale mean and variance for a kernel signal variance and
% length-scale -- TODO confirm against gibbs_sample_affine_model.
hyper_a.sf2.mu  = log(1);
hyper_a.sf2.var = 0.5;

hyper_a.ell.mu = log(1);
hyper_a.ell.var = 0.1;

% Starting state of the sampler, one entry per stratum.
start_mcmc.a = cell(S, 1);
start_mcmc.b = cell(S, 1);
start_mcmc.log_sf2_a = zeros(S, 1);
start_mcmc.log_ell_a = zeros(S, 1);
start_mcmc.log_sf2_b = zeros(S, 1);
start_mcmc.log_ell_b = zeros(S, 1);
start_mcmc.f = cell(S, 1);
start_mcmc.llikv = zeros(1, S); % Row vector: same shape as implicit growth by index
for s = 1:S
  start_mcmc.a{s} = ones(num_X_space, 1);
  start_mcmc.b{s} = zeros(num_X_space, 1);
  start_mcmc.log_sf2_a(s) = hyper_a.sf2.mu;
  start_mcmc.log_ell_a(s) = hyper_a.ell.mu;
  start_mcmc.log_sf2_b(s) = hyper_a.sf2.mu;
  start_mcmc.log_ell_b(s) = hyper_a.ell.mu;
  start_mcmc.f{s} = do_params{s}.mu_do;
  start_mcmc.llikv(s) = log(prior_info{s}.X_hyp.lik);
end

%% Initialize sampling

% Fit on the initial experimental samples only. The meaning of the two
% trailing "false" flags is defined by gibbs_sample_affine_model.
[a, b, f_obs, kernel_mcmc, llik_mcmc] = gibbs_sample_affine_model(X_exp_init, Y_exp_init, X_space, do_params, prior_info, hyper_a, X_obs, Z_obs, start_mcmc, M, false, false);  

%% Core iterations

for i = 1:num_total_exp
 
  %% Decide where to measure new point
  
  [new_idx, new_s] = search_info_gain_strata(a, b, f_obs);
  fprintf('Experiment [%d/%d]: position %d, strata %d\n', i, num_total_exp, new_idx, new_s)
  idx{new_s} = [idx{new_s}; new_idx];
  X_exp{new_s} = [X_exp{new_s}; X_space(new_idx)];
  Y_exp{new_s} = [Y_exp{new_s}; run_intervention(X_space(new_idx), models{new_s})];
  
  %% Update hyperparameter distribution
  
  if mod(length(Y_exp{new_s}), hyper_refresh_period) == 0
     
    % Full re-sampling (hyperparameters included), restarted from the
    % initial state.
    [a, b, f_obs, kernel_mcmc, llik_mcmc] = gibbs_sample_affine_model(X_exp, Y_exp, X_space, do_params, prior_info, hyper_a, X_obs, Z_obs, start_mcmc, M, false, false);  

  else
      
    % Cheaper update: lock the hyperparameters at the post-burn-in averages
    % of the most recent full run and re-sample with M_latent samples.
    start_mcmc_l = start_mcmc;    
    for s = 1:S
      % An empty post-burn-in slice would average to NaN and silently
      % corrupt the locked hyperparameters, so fail loudly instead.
      if size(kernel_mcmc{s}, 2) <= burn_in || numel(llik_mcmc{s}) <= burn_in
        error('active_dose_response_learning:burnInTooLarge', ...
              'burn_in (%d) leaves no MCMC samples to average for stratum %d.', ...
              burn_in, s);
      end
      start_mcmc_l.log_sf2_a(s) = mean(kernel_mcmc{s}(1, burn_in + 1:end));
      start_mcmc_l.log_ell_a(s) = mean(kernel_mcmc{s}(2, burn_in + 1:end));
      start_mcmc_l.log_sf2_b(s) = mean(kernel_mcmc{s}(3, burn_in + 1:end));
      start_mcmc_l.log_ell_b(s) = mean(kernel_mcmc{s}(4, burn_in + 1:end));
      % NOTE(review): this writes field "likv" (natural scale), while the
      % initial state above uses "llikv" (log scale). Confirm that
      % gibbs_sample_affine_model reads "likv" when locked_hyper is set.
      start_mcmc_l.likv(s) = mean(exp(llik_mcmc{s}(burn_in + 1:end)));
    end
    start_mcmc_l.locked_hyper = true;
    [a, b, f_obs] = gibbs_sample_affine_model(X_exp, Y_exp, X_space, do_params, prior_info, hyper_a, X_obs, Z_obs, start_mcmc_l, M_latent, false, false);
    
  end
  
end
