% sort_covariates_oracle:
%
% Sort importance of covariates according to the known truth. Importance
% here is defined by the amount of confounding between treatment and
% outcome when adjusting for each variable individually. This is quantified
% by the estimated bias implied by the adjustment.
%
% Input:
%
% - dat: the observational data;
% - x, y, Z: column indices in dat of treatment, outcome and possible covariates;
% - num_iter: number of optimization steps used to fit Gaussian processes;
% - use_gp: if false, don't use GPs, use a simple fixed-basis regression
%   instead;
% - X_space: dose levels to be used;
% - f_space: corresponding expected responses;
% - verbose: if true, display output.
%
% Output:
%
% - Z_sorted: positions within Z (values in 1..length(Z), so the sorted
%   covariate columns are Z(Z_sorted)), ordered from the variable that allows
%   for the highest bias to the one that allows for the least bias;

function Z_sorted = sort_covariates_oracle(dat, x, y, Z, num_iter, use_gp, X_space, f_space, verbose) %#ok<*INUSL>
% For each candidate covariate Z(i), fit a model of dat(:, y) on the
% treatment column and that single covariate, predict the average response
% at every dose in X_space (treatment fixed at the dose, covariate kept at
% its observed values), and score the covariate by the mean absolute gap
% between those predictions and f_space. Returns the positions within Z
% ordered from largest to smallest score.

p = length(Z);
n = size(dat, 1);
Z_scores = zeros(p, 1);

% GP configuration. These variables (and num_iter) are referenced only from
% inside the evalc string below, hence the code-analyzer suppressions.
likfunc = @likGauss; %#ok<*NASGU>
covfunc = {@covMaternard, 3}; 
hyp.lik = log(0.1);

num_space = length(X_space);

% y_hat below is always a column. Force the reference curve into a column as
% well, so that a row-vector f_space is compared element by element instead
% of being expanded into a num_space-by-num_space matrix.
f_target = f_space(:);

% GP predictions condition on at most the first 500 rows.
subsample = 1:min(500, n);

for i = 1:p
    
  if verbose, fprintf('Scoring %d\n', i); end
  
  y_hat = zeros(num_space, 1);
  z_col = dat(:, Z(i)); % observed covariate values, reused for every dose
  
  if use_gp % Slow!
      
    % regress y on x and z
    
    % One hyperparameter per covMaternard input (x and z) plus one more,
    % restarted from zero for every covariate.
    hyp.cov = zeros(3, 1); 
    % evalc captures whatever minimize prints; the string is evaluated in
    % this workspace, so the variable names used inside it must not change.
    [~, hyp_obs1] = evalc('minimize(hyp, @gp, -num_iter, @infExact, [], covfunc, likfunc, dat(:, [x Z(i)]), dat(:, y))');
    
    % Hyperparameters are fitted on all rows, but predictions condition on
    % the subsample only.
    train_in = dat(subsample, [x Z(i)]);
    train_out = dat(subsample, y);
    for j = 1:num_space
      dat_j = [ones(n, 1) * X_space(j) z_col];
      y_hat(j) = mean(gp(hyp_obs1, @infExact, [], covfunc, likfunc, train_in, train_out, dat_j));
    end    
    
  else % Not as flexible
   
    % regress y on x and z
    
    % nonlinear_basis2 is defined elsewhere; presumably it builds the design
    % matrix from the two indicated columns -- confirm against its source.
    XZ_design = nonlinear_basis2(dat, x, Z(i));
    beta = (XZ_design' * XZ_design) \ (XZ_design' * dat(:, y));
    for j = 1:num_space
      dat_j = [ones(n, 1) * X_space(j) z_col];
      % In dat_j the dose is column 1 and the covariate is column 2.
      pred_design = nonlinear_basis2(dat_j, 1, 2);
      y_hat(j) = mean(pred_design * beta);
    end
    
  end
    
  % assess Z scores
  
  Z_scores(i) = mean(abs(y_hat - f_target));
  
end

% Sorted from 'worst' to 'best' variables to adjust for. The scores are means
% of absolute values, so no further abs() is needed before sorting.
[~, Z_sorted] = sort(Z_scores, 'descend');
