clc; clear;
addpath(genpath('./FastICA_25/'));

%% Prepare data
% Load the document-term matrix and vocabulary, centre the data column-wise,
% and reduce it to k_pca dimensions via an eigendecomposition of the Gram
% matrix X*X' (cheap when X has far fewer rows than columns).
document_term_matrix = load("document_term_matrix.mat");
vocabulary           = load("vocabulary.mat");
X                    = document_term_matrix.data;
mean_X               = mean(X,1);          % 1-by-size(X,2) row of column means
X                    = X - mean_X;         % implicit expansion: centre each column
% Gram matrix over the rows of X. Deliberately NOT named "cov": a variable
% called cov would shadow the built-in cov() for the rest of the session.
gram                 = X*X';
[U,Lambda]           = eig(gram);
U                    = real(U);            % defensive: gram is symmetric
Lambda               = real(Lambda);
[Lambda, ind]        = sort(diag(Lambda)); % ascending, so the largest come last
U                    = U(:, ind);
k_pca                = 50;                 % number of PCA dimensions kept
k_ica                = 5;                  % number of independent components sought
% Map the top-k_pca eigenvectors of X*X' into the column space of X
% (size(X,2)-by-k_pca). Columns are not unit-normalised; assuming eig returns
% orthonormal U for the symmetric gram, column j has norm sqrt(Lambda(j)).
principal_components = X'*U(:,end-k_pca+1:end);
new_X                = X*principal_components;   % data in PCA coordinates

%% Run algorithms

% Baselines on the PCA-reduced data (passed transposed: one sample per column).
% With two outputs, FastICA returns [A, W]; only the mixing matrix A is kept.
A_jade                = jade(new_X',k_ica); fprintf("JADE completed\n");
[A_fastica,~]         = fastica(new_X','numOfIC',k_ica); fprintf("FASTICA completed\n");

% Contrast-function objects consumed by ICA_power in the deflation loop.
cgf_fn                = cgf(new_X);
chf_fn                = symmetric_chf(new_X);
kurtosis_fn           = kurtosis(new_X);

% C matrices (k_pca-by-k_pca) handed to ICA_power, plus their pseudo-inverses
% used to build the deflation rows B1_*.
% NOTE(review): CGF and CHF both reuse the JADE-based A_jade*A_jade', while
% kurtosis uses its own estimate_C(20) -- confirm this is intentional.
C_cgf                 = A_jade*A_jade';
C_chf                 = A_jade*A_jade';
C_kurtosis            = kurtosis_fn.estimate_C(20);
pinv_C_cgf            = pinv(C_cgf);
pinv_C_chf            = pinv(C_chf);
pinv_C_kurtosis       = pinv(C_kurtosis);

% Per-method estimates, filled one column (A_*) / one row (B1_*) per
% component by the deflation loop. (The meta estimate is chosen later from
% the finished matrices, so it needs no preallocation here.)
A_cgf                 = zeros(k_pca,k_ica);
B1_cgf                = zeros(k_ica,k_pca);
A_chf                 = zeros(k_pca,k_ica);
B1_chf                = zeros(k_ica,k_pca);
A_kurtosis            = zeros(k_pca,k_ica);
B1_kurtosis           = zeros(k_ica,k_pca);
verbose_flag          = 0;
maxruns               = 20;
num_trials            = 200;
% Random projections shared by every Delta_Score_Total call so all methods
% are scored against the same draws.
random_projections    = randn(2*num_trials,k_ica);

% Deflation loop: for each of the k_ica components, run ICA_power once per
% contrast (CGF, CHF, kurtosis) and store the returned direction as column i
% of the matching A_* matrix. Row i of B1_* is pinv(C_*)*A_*(:,i), scaled so
% that B1_*(i,:)*A_*(:,i) == 1; M_* = I - A_*B1_* is passed to ICA_power
% (presumably to project out directions already found -- TODO confirm
% against ICA_power).
% NOTE(review): the ones(k_pca,k_ica) argument's meaning is defined by
% ICA_power and is not visible here -- confirm.
for i=1:k_ica
    fprintf('Column %d\n',i);
    % One random start per component, shared by all three methods.
    u_init = randn(k_pca,1);
    
    fprintf('CGF run\n');
    % Built from the components found in earlier iterations (zeros at i == 1).
    M_cgf = eye(k_pca) - A_cgf*B1_cgf;
    [u_cgf,~,~,~] = ICA_power(new_X, ...
                              maxruns, ...
                              ones(k_pca,k_ica), ...
                              verbose_flag, ...
                              cgf_fn, ...
                              u_init, ...
                              C_cgf, ...
                              M_cgf);

    % Store the direction, then the matching normalised deflation row.
    A_cgf(:,i)=u_cgf';
    u1=pinv_C_cgf*A_cgf(:,i);
    B1_cgf(i,:)=u1/(u1'*A_cgf(:,i));

    fprintf('CHF run\n');
    % Same procedure as the CGF run, with the CHF contrast.
    M_chf = eye(k_pca) - A_chf*B1_chf;
    [u_chf,~,~,~] = ICA_power(new_X, ...
                              maxruns, ...
                              ones(k_pca,k_ica), ...
                              verbose_flag, ...
                              chf_fn, ...
                              u_init, ...
                              C_chf, ...
                              M_chf);

    A_chf(:,i)=u_chf';
    u1=pinv_C_chf*A_chf(:,i);
    B1_chf(i,:)=u1/(u1'*A_chf(:,i));

    fprintf('Kurtosis original run\n');
    % Same procedure as the CGF run, with the kurtosis contrast.
    M_kurtosis = eye(k_pca) - A_kurtosis*B1_kurtosis;
    [u_kurtosis,~,~,~] = ICA_power(new_X, ...
                                   maxruns, ...
                                   ones(k_pca,k_ica), ...
                                   verbose_flag, ...
                                   kurtosis_fn, ...
                                   u_init, ...
                                   C_kurtosis, ...
                                   M_kurtosis);

    A_kurtosis(:,i)=u_kurtosis';
    u1=pinv_C_kurtosis*A_kurtosis(:,i);
    B1_kurtosis(i,:)=u1/(u1'*A_kurtosis(:,i));

    fprintf('=====================\n');
    fprintf('=====================\n');
end

fprintf('Meta run\n');

% Independence score for every candidate mixing matrix, all evaluated
% against the same random projections. JADE and FastICA have no separate C,
% so A*A' is used for them.
ind_scores_chf      = Delta_Score_Total(new_X,C_chf,A_chf,random_projections,num_trials);
ind_scores_cgf      = Delta_Score_Total(new_X,C_cgf,A_cgf,random_projections,num_trials);
ind_scores_kurtosis = Delta_Score_Total(new_X,C_kurtosis,A_kurtosis,random_projections,num_trials);
ind_scores_jade     = Delta_Score_Total(new_X,A_jade*A_jade',A_jade,random_projections,num_trials);
ind_scores_fastica  = Delta_Score_Total(new_X,A_fastica*A_fastica',A_fastica,random_projections,num_trials);

fprintf("Independence Score CHF : %.10f\n", ind_scores_chf);
fprintf("Independence Score CGF : %.10f\n", ind_scores_cgf);
fprintf("Independence Score Kurtosis : %.10f\n", ind_scores_kurtosis);
fprintf("Independence Score JADE : %.10f\n", ind_scores_jade);
fprintf("Independence Score FASTICA : %.10f\n", ind_scores_fastica);

% Meta estimate: the candidate with the lowest score (lower is treated as
% better). The order of A_matrices must match the order of scores.
A_matrices    = {A_cgf, A_chf, A_kurtosis, A_jade, A_fastica};
scores        = [ind_scores_cgf, ind_scores_chf, ind_scores_kurtosis, ...
                 ind_scores_jade, ind_scores_fastica];
[~,min_index] = min(scores);
A_meta        = A_matrices{min_index};

% NOTE(review): normr rescales each ROW of these k_pca-by-k_ica matrices to
% unit norm. The recovered directions are stored as COLUMNS; if the intent is
% to unit-normalise each direction, this should be normc -- confirm.
A_cgf      = normr(A_cgf);
A_chf      = normr(A_chf);
A_kurtosis = normr(A_kurtosis);
A_jade     = normr(A_jade);
A_fastica  = normr(A_fastica);
A_meta     = normr(A_meta);

%% Projecting back to original space
% Map each method's directions from PCA coordinates back to the original
% column space of X and add back the column means.

signals_cgf      = principal_components*A_cgf + mean_X';
signals_chf      = principal_components*A_chf + mean_X';
signals_kurtosis = principal_components*A_kurtosis + mean_X';
signals_jade     = principal_components*A_jade + mean_X';
signals_fastica  = principal_components*A_fastica + mean_X';
% Must use the selected A_meta (previously A_fastica, which silently
% discarded the meta selection above).
signals_meta     = principal_components*A_meta + mean_X';


%% Top-r positive (and optionally negative) words per component
% For every recovered component (column of all_signals), print the r
% vocabulary words with the largest values, largest first. Rows of
% all_signals index vocabulary.data.
% NOTE(review): only the FastICA signals are displayed; switch to
% signals_meta (or another method) if that is what should be inspected.

all_signals   = signals_fastica;
r             = 20;     % words printed per component
show_negative = false;  % also print the r most negative words per component

% Guard against a vocabulary smaller than r (end-r+1 would fall below 1).
r = min(r, size(all_signals,1));

% Iterate over the components actually present instead of a hard-coded 5.
for signal_id = 1:size(all_signals,2)
    [~, order] = sort(all_signals(:,signal_id));  % ascending

    % Last r indices, reversed so the largest value is printed first.
    top_words = order(end:-1:end-r+1);
    for word_idx = top_words.'   % transpose: for-loops iterate over columns
        fprintf(1,'%s ', vocabulary.data{word_idx});
    end
    fprintf('\n')

    if show_negative
        % First r indices: most negative value printed first.
        for word_idx = order(1:r).'
            fprintf(1,'%s ', vocabulary.data{word_idx});
        end
        fprintf('\n')
    end
end