% build_do_prekernel:
%
% For some given treatment levels (X_exp), generates the corresponding prior
% hypermean and hypervariances. This is done by the observational
% *predictive* mean and variance. So the variance is relatively low, but we
% can re-inflate it later, as done in 'hyper_learning'.
%
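% Input:
%
% - X_exp: treatment levels (one scalar treatment value per entry)
% - X_obs, Z_obs: observational data used in the construction of the prior.
% - prior_info: outcome of observational_learning.
% - compute_ell: if TRUE, also compute the prior covariance matrix across
%                treatment levels (fields K and K_scale of the output)
% - exact_way: (optional, default FALSE) if TRUE, each pairwise covariance
%              is taken from the joint predictive covariance of the two
%              averaged predictions; if FALSE, the predictive covariances
%              are averaged over the observed Z rows, converted to a
%              correlation matrix and rescaled by the hypervariances.
%
% Output:
%
% - do_params: structure with fields
%     empty:      always FALSE
%     mu_do:      prior hypermean at each treatment level
%     var_do:     prior hypervariance at each treatment level
%     var_err_do: scalar error-variance estimate averaged over levels
%     K, K_scale: prior covariance across treatment levels and a version
%                 of it with standard deviations divided by the largest
%                 one (only set when compute_ell is TRUE)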

function do_params = build_do_prekernel(X_exp, X_obs, Z_obs, prior_info, ...
                                        compute_ell, exact_way)
% Builds the prior hypermean/hypervariance (and optionally the prior
% covariance) of the dose-response function at the treatment levels X_exp,
% using the predictive distribution of the observational model averaged
% over the observed covariates Z_obs.
%
% - X_exp:       vector of scalar treatment levels (row or column)
% - X_obs/Z_obs: observational treatments / covariates, one row per sample
% - prior_info:  structure providing XZ_covfunc, XZ_hyp.cov, chol_K, K_w
%                and X_hyp.lik
% - compute_ell: if TRUE, also fill do_params.K and do_params.K_scale
% - exact_way:   optional (default FALSE); selects how K is computed
%
% Returns do_params with fields empty, mu_do, var_do, var_err_do and, when
% compute_ell is TRUE, K and K_scale.

if nargin < 6, exact_way = false; end

% Accept row or column vectors of treatment levels.
X_exp = X_exp(:);

n_exp = length(X_exp);
n_obs = size(X_obs, 1);
n_obs_sq = n_obs^2;
XZ_dat = [X_obs Z_obs];

mu_do = zeros(n_exp, 1);
var_do = zeros(n_exp, 1);

% Transposed factor used in all triangular-style solves below.
% NOTE(review): presumably chol_K is the upper Cholesky factor of the
% training covariance, so that W_s * W_s' reproduces it -- confirm against
% observational_learning.
W_s = prior_info.chol_K';

var_err_do = zeros(n_exp, 1);
for i = 1:n_exp
  % Observational covariates with the treatment forced to X_exp(i).
  do_dat = [X_exp(i) * ones(n_obs, 1) Z_obs];
  Kss = feval(prior_info.XZ_covfunc{:}, prior_info.XZ_hyp.cov, do_dat);
  Ks = feval(prior_info.XZ_covfunc{:}, prior_info.XZ_hyp.cov, XZ_dat, do_dat);
  y_zhat = Ks' * prior_info.K_w;
  mu_do(i) = mean(y_zhat);
  % y_zhat can be very non-Gaussian, so a robust spread estimate (e.g. one
  % derived from the interval around the median covering 95% of the
  % empirical mass) may be preferable. The raw sample variance is what is
  % currently used.
  var_err_do(i) = var(y_zhat);
  % Predictive variance of the average prediction over the n_obs rows.
  w_s = W_s \ sum(Ks, 2);
  var_do(i) = (sum(sum(Kss)) - w_s' * w_s) / n_obs_sq;
end
% NOTE(review): this adds X_hyp.lik as-is, while every other hyperparameter
% here comes from XZ_hyp; confirm that X_hyp.lik is the intended field and
% that it is stored on the variance scale.
var_err_do = mean(var_err_do + prior_info.X_hyp.lik);

do_params.empty = false;
do_params.mu_do = mu_do;
do_params.var_do = var_do;
do_params.var_err_do = var_err_do;

%% Learn covariances in priors

if ~compute_ell, return; end

if exact_way

  % Row 1 of W averages the first n_obs predictions (treatment i), row 2
  % averages the last n_obs predictions (treatment j).
  W = zeros(2, 2 * n_obs);
  W(1, 1:n_obs) = 1 / n_obs;
  W(2, n_obs + (1:n_obs)) = 1 / n_obs;
  W_p = W';

  % The diagonal of the exact covariance is the predictive variance of each
  % averaged prediction, i.e. var_do (pred_Kss(1, 1) below equals
  % var_do(i)). Leaving it at zero would make the normalisation at the end
  % of this function divide by zero.
  out_K = diag(var_do);

  for i = 1:(n_exp - 1)
    for j = (i + 1):n_exp
      do_dat = [X_exp(i) * ones(n_obs, 1) Z_obs; X_exp(j) * ones(n_obs, 1) Z_obs];
      Kss = feval(prior_info.XZ_covfunc{:}, prior_info.XZ_hyp.cov, do_dat);
      Ks = feval(prior_info.XZ_covfunc{:}, prior_info.XZ_hyp.cov, XZ_dat, do_dat);
      w_s = W_s \ (Ks * W_p);
      pred_Kss = W * Kss * W_p - w_s' * w_s;
      out_K(i, j) = pred_Kss(1, 2);
      out_K(j, i) = out_K(i, j);
    end
  end

else

  % Average, over the observed Z rows, the predictive covariance across all
  % treatment levels; keep only its correlation structure and rescale it so
  % that the diagonal matches var_do.
  cov_pred = zeros(n_exp);
  for i = 1:n_obs
    do_dat = [X_exp Z_obs(ones(n_exp, 1) * i, :)];
    Kss = feval(prior_info.XZ_covfunc{:}, prior_info.XZ_hyp.cov, do_dat);
    Ks = feval(prior_info.XZ_covfunc{:}, prior_info.XZ_hyp.cov, XZ_dat, do_dat);
    w_s = W_s \ Ks;
    cov_pred = cov_pred + Kss - w_s' * w_s;
  end
  cov_pred = cov_pred / n_obs;
  sd_pred = sqrt(diag(cov_pred));
  out_K = cov_pred ./ (sd_pred * sd_pred');
  out_K = out_K .* (sqrt(var_do) * sqrt(var_do)');

end

do_params.K = out_K;

% K_scale: same correlations as K, standard deviations divided by the
% largest one.
sd_K = sqrt(diag(do_params.K));
corr_K = do_params.K ./ (sd_K * sd_K');
sd_K_adj = sd_K ./ max(sd_K);
do_params.K_scale = (sd_K_adj * sd_K_adj') .* corr_K;
