% gibbs_sample_additive_model: Gibbs sampler for per-group latent functions on X_space with an additive kernel correction.

function [f, kernel_mcmc, llik_mcmc] = gibbs_sample_additive_model...
            (X_exp, Y_exp, X_space, do_params_X_space, prior_info, hyper_a, X_obs, Z_obs, M, verbose)
% GIBBS_SAMPLE_ADDITIVE_MODEL  Gibbs sampler for per-group latent functions
% on X_space whose prior covariance is do_params_X_space{s}.K plus an
% additive squared-exponential-style correction K_add.
%
% For every iteration m = 2..M and every group s = 1..S the sampler:
%   1. draws f{s}(:, m) from its Gaussian full conditional, using prior mean
%      do_params_X_space{s}.mu_do, prior covariance
%      do_params_X_space{s}.K + K_add + noise_matrix, and a Gaussian
%      likelihood Y_exp{s} ~ A{s} * f with variance exp(llik_mcmc{s}(m));
%   2. slice-samples the two K_add log-hyperparameters in kernel_mcmc{s}:
%      row 1 is the log amplitude (K_add is scaled by exp(row 1)), row 2 is
%      the log of the squared length-scale (squared distances are divided by
%      exp(row 2));
%   3. slice-samples the log observation variance llik_mcmc{s}(m); if the
%      slice sampler throws, a warning is issued and the previous value kept.
%
% Inputs
%   X_exp, Y_exp       S-by-1 cell arrays of experimental inputs/responses.
%   X_space            vector of points (dose levels) on which f is represented.
%   do_params_X_space  S-by-1 cell; each entry provides fields .K and .mu_do.
%   prior_info         S-by-1 cell, passed through to multi_exp_mapping.
%   hyper_a            struct with fields sf2.mu, sf2.var, ell.mu, ell.var,
%                      forwarded to hyper_add_seiso_logpdf for rows 1 and 2.
%   X_obs, Z_obs       S-by-1 cells, passed through to multi_exp_mapping.
%   M                  number of stored states; column 1 is the all-zero
%                      initial state.
%   verbose            (optional, default false) print the iteration number.
%
% Outputs
%   f            S-by-1 cell; f{s} is num_X_space-by-M.
%   kernel_mcmc  S-by-1 cell; kernel_mcmc{s} is 2-by-M.
%   llik_mcmc    S-by-1 cell; llik_mcmc{s} is 1-by-M.

if nargin < 10 || isempty(verbose)
  verbose = false;
end

%% Preliminaries

S = length(X_exp);
num_X_space = length(X_space);
A = cell(S, 1);
AtA = cell(S, 1);   % A{s}' * A{s}; constant across iterations
AtY = cell(S, 1);   % A{s}' * Y_exp{s}; constant across iterations
f = cell(S, 1);
n = zeros(S, 1);

for s = 1:S
  A{s} = multi_exp_mapping(X_space, X_exp{s}, X_obs{s}, Z_obs{s}, prior_info{s});
  if size(A{s}, 2) ~= num_X_space
    error('Currently requires at least one observation per dose level')
  end

  n(s) = length(Y_exp{s});
  f{s} = zeros(num_X_space, M);
  AtA{s} = A{s}' * A{s};
  AtY{s} = A{s}' * Y_exp{s};
end

noise_matrix = get_noise_matrix(num_X_space);

%% Prepare prior information
% SD2{s}  : outer product of the prior standard deviations, rescaled so the
%           largest standard deviation is 1.
% core{s} : pairwise squared Euclidean distances between the points
%           (X, mu_do), after rescaling each coordinate to [0, 1].
% Constant vectors are left at zero instead of being divided by zero, which
% would otherwise fill the kernel with NaN.

% The rescaled inputs do not depend on the group, so compute them once.
X1 = X_space(:) - min(X_space(:));
if max(X1) > 0, X1 = X1 / max(X1); end
sq_dist_X = bsxfun(@minus, X1, X1') .^ 2;

SD2 = cell(S, 1);
core = cell(S, 1);
for s = 1:S
  sd = sqrt(diag(do_params_X_space{s}.K));
  if max(sd) > 0, sd = sd / max(sd); end
  SD2{s} = sd * sd';

  mu = do_params_X_space{s}.mu_do(:);
  Y1 = mu - min(mu);
  if max(Y1) > 0, Y1 = Y1 / max(Y1); end
  core{s} = sq_dist_X + bsxfun(@minus, Y1, Y1') .^ 2;
end

%% Main sampling

kernel_mcmc = cell(S, 1);
llik_mcmc = cell(S, 1);
for s = 1:S
  kernel_mcmc{s} = zeros(2, M);
  llik_mcmc{s} = zeros(1, M);
end

for m = 2:M

  if verbose, fprintf('Sampling iteration %d\n', m); end

  for s = 1:S

    % Every group's chain starts iteration m from its own state at m - 1.
    % This must run inside the group loop so that all S chains advance,
    % not only the last one.
    kernel_mcmc{s}(:, m) = kernel_mcmc{s}(:, m - 1);
    llik_mcmc{s}(m) = llik_mcmc{s}(m - 1);

    % Sample f from its Gaussian full conditional

    sigma2 = exp(llik_mcmc{s}(m));
    K_add = exp(kernel_mcmc{s}(1, m)) * SD2{s} .* exp(-0.5 * core{s} / exp(kernel_mcmc{s}(2, m)));
    K_f = do_params_X_space{s}.K + K_add + noise_matrix;
    prior_meancov_f_obs = K_f \ do_params_X_space{s}.mu_do;
    inv_prior_f_obs = inv(K_f);

    inv_cov_f = AtA{s} / sigma2 + inv_prior_f_obs;
    % chol returns upper-triangular R with R' * R = inv_cov_f. Reuse it both to
    % solve for the mean and to draw R \ z, whose covariance is inv(inv_cov_f).
    R = chol(inv_cov_f);
    mean_f = R \ (R' \ (AtY{s} / sigma2 + prior_meancov_f_obs));
    f{s}(:, m) = mean_f + R \ randn(num_X_space, 1);

    % Sample kernel parameters, one coordinate at a time. Each closure captures
    % kernel_mcmc{s}(:, m) by value when it is created, so the row-2 update
    % sees the freshly sampled row 1.
    % NOTE(review): mean_f is the conditional posterior mean computed with the
    % pre-update kernel; confirm hyper_add_seiso_logpdf expects this rather
    % than the prior mean do_params_X_space{s}.mu_do.

    kernel_mcmc{s}(1, m) = slicesample(kernel_mcmc{s}(1, m), 1, 'logpdf', ...
                               @(x)hyper_add_seiso_logpdf(x, 1, kernel_mcmc{s}(:, m), ...
                               do_params_X_space{s}.K, f{s}(:, m), mean_f, ...
                               hyper_a.sf2.mu, hyper_a.sf2.var, SD2{s}, core{s}, noise_matrix));
    kernel_mcmc{s}(2, m) = slicesample(kernel_mcmc{s}(2, m), 1, 'logpdf', ...
                               @(x)hyper_add_seiso_logpdf(x, 2, kernel_mcmc{s}(:, m), ...
                               do_params_X_space{s}.K, f{s}(:, m), mean_f, ...
                               hyper_a.ell.mu, hyper_a.ell.var, SD2{s}, core{s}, noise_matrix));

  end

  % Sample the log observation variance

  for s = 1:S
    exp2 = sum((Y_exp{s} - A{s} * f{s}(:, m)).^2);   % residual sum of squares
    try
      llik_mcmc{s}(m) = slicesample(llik_mcmc{s}(m), 1, 'logpdf', @(x)likv_logpdf(x, n(s), exp2));
    catch %#ok<CTCH>
      % Best effort: keep the previous state instead of aborting the run.
      warning('step out procedure failed')
      llik_mcmc{s}(m) = llik_mcmc{s}(m - 1);
    end
  end

end

