% generate_problem:
%
% Generates a synthetic model following a bipartite structure. Here, X 
% denotes treatment, Y outcome and Z common causes of X and Y.
%
% A jointly Gaussian top layer of variables Z is generated, with or 
% without some correlation.
%
% Then, a function f_i(Z_i) is generated from a Gaussian process prior for
% every covariate Z_i, rescaled appropriately. Treatment X is generated by
% summing all f_i(Z_i) and adding some noise. The proportion of signal to
% noise is chosen from a predefined interval.
%
% The function mapping Z to Y is built by first linearly mixing the Z_i, then
% feeding the linear combination to a second degree polynomial. This is
% done so that its expected value can be calculated exactly. Y is then
% generated additively by combining some polynomial function in X and the
% second degree polynomial in Z.
%
% Input:
%
% - n: sample size;
% - p: number of covariates (confounders);
% - lik_range: interval in which we uniformly sample error variances for
%              the likelihood function. This will be proportional to the 
%              amount of signal;
% - ell, sf: hyperparameters of a squared exponential covariance function 
%            used in the generation of treatment assignment. That is,
%            functions mapping covariates Z to treatment X;
% - degree: degree of the polynomial corresponding to the effect of X on Y;
% - signal_level: a way of regulating the contribution of X to Y. First,
%                 get the signal_level and  1 - signal_level
%                 quantiles of the distribution of all other causes of Y
%                 that are not X (including error terms). Then rescale the
%                 dose-response function such that the difference between
%                 its maximum and minimum output (within the sampled X)
%                 will be equal to the difference between these quantiles.
%                 The larger signal_level is (between 0 and 0.5), the more
%                 the dose-response function gets washed out by the
%                 contribution of the other causes;
% - correlate: if true, define covariates to be correlated.
%
% Output:
%
% - dat: the dataset. The first p columns are the confounders, column
%        p + 1 is the treatment, and the final column is the outcome;
% - model: the generative model (a struct holding the sampled parameters);
% - confound_strength: (estimate of) the rank correlation of X and the
%                      total effect of Z on Y.

function [dat, model, confound_strength] = ...
              generate_problem(n, p, lik_range, ell, sf, degree, signal_level, correlate)
% GENERATE_PROBLEM  Sample a synthetic confounded dose-response problem.
%
% Draws covariates Z (n x p), a treatment X built from Gaussian process
% functions of each Z column plus noise, and an outcome Y that adds a
% function of X to a quadratic function of a linear mix of Z plus noise.
%
% Inputs:
%   n, p          sample size and number of covariates;
%   lik_range     two-element interval; noise-to-signal variance ratios for
%                 X and Y are drawn uniformly from it;
%   ell, sf       squared exponential hyperparameters (passed to covSEiso
%                 as log([ell; sf]));
%   degree        polynomial degree of the effect of X on Y; 0 selects a
%                 Gaussian process draw instead of a polynomial;
%   signal_level  quantile level in [0, 0.5] used to size the effect of X;
%   correlate     optional (default false); if true, covariates have
%                 pairwise covariance 0.5.
%
% Outputs:
%   dat                n x (p + 2) matrix [Z, X, Y];
%   model              struct with the sampled parameters and f_xy;
%   confound_strength  Spearman correlation between the noiseless treatment
%                      signal and the effect of Z on Y.
%
% Raises an error if signal_level is outside [0, 0.5] (NaN included).

%% Defaults

if nargin < 8
  correlate = false;
end

% Written as a negated inclusion test so that NaN is rejected as well.
if ~(signal_level >= 0 && signal_level <= 0.5)
  error('Signal level must be between 0 and 0.5')
end

%% Setup parameters

if correlate
  % Unit variances, 0.5 covariance between every pair of covariates.
  cov_Z = 0.5 * ones(p) + 0.5 * eye(p);
else
  cov_Z = eye(p);
end

% Noise variance as a proportion of the signal variance, for X and for Y.
lik_width = lik_range(2) - lik_range(1);
x_noise_prop = rand() * lik_width + lik_range(1);
y_noise_prop = rand() * lik_width + lik_range(1);

covfunc = {@covSEiso}; hyp.cov = log([ell; sf]);

% Small diagonal jitter so the kernel matrices admit a Cholesky factor.
nugget = 1.e-10;

%% Generate covariates

% chol returns upper-triangular R with R' * R = cov_Z, so the rows of
% randn(n, p) * R have covariance cov_Z.
Z = randn(n, p) * chol(cov_Z);

%% Generate treatment

% One Gaussian process draw per covariate, evaluated at the sampled values.
f_Zx = zeros(n, p);
for z = 1:p
  K = feval(covfunc{:}, hyp.cov, Z(:, z)) + eye(n) * nugget;
  f_Zx(:, z) = chol(K)' * randn(n, 1);
end
f_Zx = f_Zx / sqrt(p);

% z = 1; [~, idx] = sort(Z(:, z)); plot(Z(idx, z), f_Zx(idx, z))

X_signal = sum(f_Zx, 2);
x_var = x_noise_prop * var(X_signal);
X = X_signal + randn(n, 1) * sqrt(x_var);

%% Generate outcome

% Layer 1: affine mix of Z (last coefficient is the offset).
% Layer 2: quadratic a * Z_1^2 + b * Z_1 + c applied to the mix.
coeff_Zy_layer_1 = randn(p + 1, 1) / sqrt(p + 1);
coeff_Zy_layer_2 = generate_coeff(2);
Z_1 = [Z ones(n, 1)] * coeff_Zy_layer_1;
f_Zy_total = [Z_1.^2 Z_1 ones(n, 1)] * coeff_Zy_layer_2;

% Exact mean and variance of the Gaussian variable Z_1.
m_Zy_total = coeff_Zy_layer_1(end);
var_Zy_total = coeff_Zy_layer_1(1:p)' * cov_Z * coeff_Zy_layer_1(1:p);

confound_strength = corr(X_signal, f_Zy_total, 'type', 'Spearman');

y_var = y_noise_prop * var(f_Zy_total);
other_causes = f_Zy_total + randn(n, 1) * sqrt(y_var);

% Empirical signal_level and (1 - signal_level) quantiles of everything
% that causes Y other than X. The order-statistic indices are clamped to
% [1, n]: without this, signal_level * n < 0.5 (e.g. signal_level == 0)
% rounds to index 0, which is out of bounds.
bottom_idx = min(max(round(signal_level * n), 1), n);
top_idx = min(max(round((1 - signal_level) * n), 1), n);
sort_other_causes = sort(other_causes);
effect_size = sort_other_causes(top_idx) - sort_other_causes(bottom_idx);

if degree == 0
  % Nonparametric dose-response: a Gaussian process draw at the sampled X.
  % NOTE(review): unlike the polynomial branch, this draw is not rescaled
  % by effect_size -- confirm whether that is intended.
  K = feval(covfunc{:}, hyp.cov, X) + eye(n) * nugget;
  f_xy = chol(K)' * randn(n, 1);
  model.poly = false;
else
  model.poly = true;
  model.center = mean(X);
  model.scale = std(X);
  X_sd = (X - model.center) / model.scale;
  model.coeff = generate_coeff(degree);
  % Columns are X_sd.^degree, ..., X_sd.^1, X_sd.^0.
  input_f_xy = bsxfun(@power, X_sd, degree:-1:0);
  pre_f_xy = input_f_xy * model.coeff;
  % Rescale so that the range of f_xy over the sample equals effect_size.
  model.coeff = model.coeff / (max(pre_f_xy) - min(pre_f_xy)) * effect_size;
  f_xy = input_f_xy * model.coeff;
end

Y = other_causes + f_xy;

%% Finalize

model.cov_Z = cov_Z;
model.coeff_Zy_layer_1 = coeff_Zy_layer_1;
model.coeff_Zy_layer_2 = coeff_Zy_layer_2;
model.x_var = x_var;
model.y_var = y_var;

% Raw moments of Z_1 ~ N(m, v), used to get the exact mean and variance of
% the quadratic a * Z_1^2 + b * Z_1 + c.
a = coeff_Zy_layer_2(1);
b = coeff_Zy_layer_2(2);
c = coeff_Zy_layer_2(3);
m2_Zy_total = m_Zy_total^2 + var_Zy_total;
m3_Zy_total = m_Zy_total^3 + 3 * m_Zy_total * var_Zy_total;
m4_Zy_total = m_Zy_total^4 + 6 * m_Zy_total^2 * var_Zy_total + 3 * var_Zy_total^2;

% Second raw moment of the quadratic, term by term.
m2 =   a * a * m4_Zy_total + ... % a^2 Z_1^4
       b * b * m2_Zy_total + ... % b^2 Z_1^2
       c * c               + ... % c^2
   2 * a * b * m3_Zy_total + ... % 2ab Z_1^3
   2 * a * c * m2_Zy_total + ... % 2ac Z_1^2
   2 * b * c *  m_Zy_total;      % 2bc Z_1
model.intercept = a * m2_Zy_total + b * m_Zy_total + c;
model.exo_var = m2 - model.intercept^2;

dat = zeros(n, p + 2);
dat(:, 1:p) = Z;
dat(:, p + 1) = X;
dat(:, p + 2) = Y;

model.f_xy = f_xy;
