% Learning OT map between multivariate Gaussian distributions
% NeurIPS 2020 submission, please do not distribute
% Start from an empty workspace and a fixed RNG state so that every run
% draws exactly the same samples.
clear
rng('default');

% Problem sizes: ambient dimension, number of source training points,
% number of target training points, and number of held-out source points.
d = 100;
m = 50;
n = 50;

M_test = 200;

% Numerical and kernel settings.
myRidge = 1e-8;          % diagonal jitter added to the random covariances
flag_squared = true;     % forwarded to euclidean_distances for the train cost
sigma_rbf = 0.5;         % width parameter forwarded to gaussianKernel
% Listed options: gaussian, dirac, euclidean.
% NOTE(review): the train cost below is always computed with euclidean_distances.
cost_type = 'euclidean';

fprintf('Data generation: zero mean, trace normalized variance\n');

% Both distributions are centred at the origin.
mu_src = zeros(d,1);
mu_tgt = zeros(d,1);

% Random covariance recipe, applied to the source first and then the target.
% The order matters because each rand(d,d) call consumes the RNG stream.
%   A*A' + ridge*I  -> symmetric, positive definite in exact arithmetic
%   (S+S')/2        -> removes round-off asymmetry
%   S/trace(S)      -> unit trace
make_spd   = @(A) A*A' + eye(d)*myRidge;
symmetrize = @(S) (S+S')/2;
unit_trace = @(S) S/trace(S);

sigma_src = unit_trace(symmetrize(make_spd(rand(d,d))));
sigma_tgt = unit_trace(symmetrize(make_spd(rand(d,d))));

% The helper handles are only needed here; keep them out of the workspace.
clear make_spd symmetrize unit_trace

% Draw the point clouds. Keep this order (source test, source train,
% target train): every mvnrnd call consumes the same global RNG stream.
Xste = mvnrnd(mu_src,sigma_src,M_test);

Xstr = mvnrnd(mu_src,sigma_src,m);
Xttr = mvnrnd(mu_tgt,sigma_tgt,n);

% Ground cost between source and target training points. Only the Euclidean
% cost is implemented here, so fail loudly for any other cost_type instead
% of silently computing the Euclidean cost anyway.
if ~strcmp(cost_type,'euclidean')
    error('cost_type ''%s'' is not implemented; only ''euclidean'' is supported.', cost_type);
end
cost_Xstr_Xttr = euclidean_distances(Xstr,Xttr,flag_squared);

% Gaussian-kernel Gram matrices on each training set, plus the
% source train-vs-test kernel used later for the out-of-sample extension.
kernel_Xstr = gaussianKernel(Xstr,Xstr,sigma_rbf);
kernel_Xttr = gaussianKernel(Xttr,Xttr,sigma_rbf);

kernel_Xstr_Xste = gaussianKernel(Xstr,Xste,sigma_rbf);

fprintf('\n\ndimension src = %d, m = %d, n = %d, sigma = %g\n',d,m,n,sigma_rbf);

%% Gaussian optimal
% Reference map from gaussian_optimal. Only its second output (a matrix) is
% used: it is applied to the centred source training points and the target
% mean is added back.
[~, opt_mapper] = gaussian_optimal(mu_src,mu_tgt,sigma_src,sigma_tgt);
fprintf('Gaussian optimal transport map computed\n');
Xstr_centred = Xstr' - repmat(mu_src,1,m);        % d-by-m
opt_pred = (mu_tgt + opt_mapper*Xstr_centred)';   % one row per row of Xstr
clear Xstr_centred

%% emd_train
% Coupling computed by emd_train on the training cost, turned into target-space
% predictions by barycenterSquaredEuclideanCost, then scored against the
% Gaussian reference map (mean over points of the squared row-wise error).
alpha_emd = emd_train(cost_Xstr_Xttr);
emd_pred  = barycenterSquaredEuclideanCost(alpha_emd,Xttr);

residual_emd      = emd_pred - opt_pred;
emd_squared_error = sum(residual_emd.^2, 2);   % one value per source point
emd_pred_error    = mean(emd_squared_error);
clear residual_emd
fprintf('EMD MSE: %g\n', emd_pred_error);

%% Proposed method (ADMM)
% Weights forwarded to proposed_admm (presumably its regularization
% strengths - confirm against proposed_admm).
lambda_1 = 100;
lambda_2 = 100;
[alpha_mat,beta_mat,gamma_mat] = proposed_admm(cost_Xstr_Xttr,kernel_Xstr,kernel_Xttr,lambda_1,lambda_2);

% Same mapping and error metric as in the EMD section.
proposed_pred = barycenterSquaredEuclideanCost(alpha_mat,Xttr);
residual_proposed      = proposed_pred - opt_pred;
proposed_squared_error = sum(residual_proposed.^2, 2);
proposed_pred_error    = mean(proposed_squared_error);
clear residual_proposed
fprintf('EMD MSE: %g, Proposed MSE: %g\n', emd_pred_error,proposed_pred_error);

% %For domain adaptation problems, classification and evaluation can be performed by 1-NN using the following code:
% proposed_accuracy = computeAccuracyDomainAdaptationOneNN(proposed_pred,Xtte,Ystr,Ytte);
% %In the above line, proposed_pred is obtained via barycenterSquaredEuclideanCost, Ystr contains the classification labels of Xstr, Xtte is the test set, and Ytte contains the labels of the test set (Xtte).

%% out-of-sample predictions
% Reference: the Gaussian map applied to the held-out source points Xste.
opt_pred_test = mu_tgt + opt_mapper*( Xste' - mu_src*ones(1,M_test) );
opt_pred_test = opt_pred_test'; % same orientation as Xste

% Extend the learned coupling to unseen source points through the source
% kernel: solve kernel_Xstr * W = alpha_mat, then combine W with the
% train-vs-test kernel. NOTE(review): rows of pseudo_alpha_mat are not
% guaranteed to be nonnegative or to sum to one; confirm that
% barycenterSquaredEuclideanCost handles that.
pseudo_alpha_mat = kernel_Xstr_Xste'*(kernel_Xstr\alpha_mat);
proposed_pred_test = barycenterSquaredEuclideanCost(pseudo_alpha_mat,Xttr);

% Mean over test points of the squared row-wise error vs. the reference.
proposed_squared_error_test = sum((proposed_pred_test-opt_pred_test).^2,2);
proposed_pred_error_test = mean(proposed_squared_error_test);
fprintf('Proposed out-of-sample MSE: %g\n', proposed_pred_error_test);

