%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% mcmc_probitsem:
%
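% Calculates the posterior mean E[X | Y] (and generates posterior samples, if
% requested) for a probit SEM where Y = sign(Y*), Y* = L * X + L0 + e,
% e ~ N(0, I) and X ~ N(0, S); equivalently P(Y_i = 1 | X) = Phi(L_i * X + L0_i).
% This is done by Gibbs sampling (MCMC) over X and the latent Y*.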
%
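% Input (call order: Y, L, invS, num_iter, burn_in, verbose):
%
% - Y: N x num_y matrix of data points, one data point per row, encoded as
%      {-1, 1}
% - L: num_y x (num_x + 1) coefficients of the measurement model. The last
%      column is the intercept term L0
% - invS: inverse latent covariance matrix (num_x x num_x)
% - burn_in: number of burn-in steps (excluded from the posterior mean)
% - num_iter: total number of iterations, burn-in included
% - verbose: if true, print iteration numbers (optional, default false)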
%
% Output:
%
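% - m: num_x x N matrix; column n is the posterior mean of X for row n of Y,
%      averaged over the post-burn-in iterations
% - X_out: N x num_x x num_iter array of all samples (burn-in included), if
%      requested as a second output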
%
% Created by: Ricardo Silva, London, 21/04/2011
% University College London
% Current version: 21/04/2011

function [m X_out] = mcmc_probitsem(Y, L, invS, num_iter, burn_in, verbose)

% Implementation notes. Gibbs sampler; all N rows of Y are processed in
% parallel. Writing Lx for the first num_x columns of L and L0 for its last
% column, each iteration draws
%   X  | Y*    ~ N(Sigma * Lx' * (Y* - L0), Sigma), Sigma = inv(invS + Lx' * Lx)
%   Y* | X, Y  ~ N(Lx * X + L0, 1) per entry, truncated to [0, Inf) where
%                Y = 1 and to (-Inf, 0] where Y = -1.

if nargin < 6
  verbose = false;
end

[num_y num_xp1] = size(L); num_x = num_xp1 - 1;
if num_y ~= size(Y, 2)
  error('Incoherent number of observable variables!')
end

% The posterior mean is averaged over iterations burn_in+1 .. num_iter, so at
% least one post-burn-in iteration is needed; otherwise the final division
% yields NaN (0/0) or a meaningless value (negative or oversized divisor).
if burn_in < 0 || num_iter <= burn_in
  error('mcmc_probitsem:burnIn', ...
        'Require 0 <= burn_in < num_iter (got burn_in = %d, num_iter = %d)', ...
        burn_in, num_iter)
end

N = size(Y, 1);

if nargout == 2
  X_out = zeros(N, num_x, num_iter);
end

% Truncation bounds for Y*: Y = 1 forces Y* >= 0, Y = -1 forces Y* <= 0.
lb_ystar = -Inf(N, num_y);
up_ystar =  Inf(N, num_y);
lb_ystar(Y == 1)  = 0;
up_ystar(Y == -1) = 0;

% Initial Y* with the sign required by Y.
y_star = abs(randn(N, num_y)) .* Y;

% Quantities that are constant across iterations.
subL     = L(:, 1:num_x)';              % num_x x num_y, transposed loadings
invSigma = invS + subL * subL';         % posterior precision of X given Y*
chol_iS  = chol(invSigma);              % chol_iS' * chol_iS = invSigma
coeff    = (invSigma \ subL)';          % num_y x num_x, i.e. Lx * Sigma
offset   = repmat(L(:, end)', N, 1);    % intercepts replicated for each row

m = zeros(num_x, N);

% Perform MCMC

for iter = 1:num_iter
  if verbose
    fprintf('MCMC Iteration [%d]\n', iter)
  end

  % Draw X | Y*. chol_iS \ z (z standard normal) has covariance
  % inv(chol_iS' * chol_iS) = Sigma.
  mu = (y_star - offset) * coeff;
  X = mu + (chol_iS \ randn(num_x, N))';
  mean_y_star = X * subL + offset;

  % Draw Y* | X, Y one observed variable at a time.
  for y = 1:num_y
    % NOTE(review): assumes FastTruncatedGaussRND(mu, lb, ub) returns unit-variance
    % Gaussian draws truncated to [lb, ub]; "double check this with normcdf".
    y_star(:, y) = FastTruncatedGaussRND(mean_y_star(:, y), lb_ystar(:, y), up_ystar(:, y));
    % Infinite draws (presumably numerical overflow deep in the tail) are
    % replaced by a large finite value with the sign required by Y.
    bad = isinf(y_star(:, y));
    y_star(bad & Y(:, y) == 1,  y) = 7;
    y_star(bad & Y(:, y) == -1, y) = -7;
  end

  if iter > burn_in
    m = m + X';
  end

  if nargout == 2
    X_out(:, :, iter) = X;
  end

  % Diagnostic, only in verbose mode: report observed variables for which
  % some data point has zero likelihood P(Y | X) under the current sample.
  if verbose
    avg_loglik = mean(log(normcdf(Y .* mean_y_star)), 1);
    for y = find(isinf(avg_loglik))
      fprintf('  zero likelihood for observed variable %d\n', y)
    end
  end
end

m = m / (num_iter - burn_in);
