clc;
clear;

% Put ./FastICA_25 and all of its subfolders on the path (presumably the
% toolbox that provides fastica(), which is called later in this script).
addpath(genpath('./FastICA_25/'));

%% Load data

% One timestamped output folder per run.
% NOTE(review): 'HH:mm:ss' puts ':' characters in the folder name, which are
% not allowed in Windows paths - mkdir under this root would fail there.
run_timestamp = datetime('now','TimeZone','local','Format','d-MMM-y HH:mm:ss Z');
root_fldr     = "./MNISTClassifcationExp/" + string(run_timestamp);

dataset = load("mnist-with-awgn.mat");
%% Generate data by considering sub-images

% Tile every 28x28 digit into non-overlapping window_size x window_size
% patches; each patch becomes one row of X.
window_size   = 7;
assert(mod(28, window_size) == 0, "window_size must divide 28 so the windows tile the image exactly");
x_num_windows = 28/window_size;
y_num_windows = 28/window_size;

% Use the first tenth of the training images. Images are the ROWS of
% train_x, so count rows explicitly: length() returns the largest dimension
% and would report 784 (the pixel count) if fewer than 784 images were loaded.
num_train   = floor(size(dataset.train_x,1)/10);

% Noise-covariance models below are only referenced by the commented-out
% mvnrnd experiments; the active pipeline uses noise_std_map instead.
noise_power = 10;
k           = 784;
U           = randn(k,k);
Sigma       = noise_power*(1/k)*(U*U');
blocksigma  = zeros(k,k);
Uprime      = randn(5,5);
blocksigma(k/2-2:k/2+2, k/2-2:k/2+2) = (1/5)*(Uprime*Uprime');
Sigma       = Sigma + blocksigma;

U_window = randn(window_size,window_size);
Sigma_window = noise_power*(1/window_size)*(U_window*U_window');
% Sigma_window = zeros(window_size, window_size);
% for i=1:window_size
%     for j=1:window_size
%         Sigma_window(i,j) = ((sin(5*pi*(i+j)/(2*window_size)))^2 + 1)*5;
%     end
% end
% Sigma_window = Sigma_window*Sigma_window';
% Sigma_window

% Per-pixel noise standard deviation |50*sin(5*pi*(x+y)/56)| at row x and
% column y of the (transposed) 28x28 image. Computed once instead of per pixel.
noise_std_map = abs(50*sin(5*pi*((1:28)' + (1:28))/(2*28)));

X           = zeros(num_train*x_num_windows*y_num_windows, window_size*window_size);
itr         = 1;
mkdir(strcat(root_fldr,"/TrainingMNIST"));
for i=1:num_train
    % Index of the 1 in the label row, minus one (assumes one-hot labels).
    label = find(dataset.train_y(i,:) == 1) - 1;
    % e   = mvnrnd(zeros(k,1), Sigma, 1);
    % img = double(img) + e;
    img   = reshape(double(dataset.train_x(i,:)), [28,28])';
    % Spatially varying Gaussian noise (named img to avoid shadowing image()).
    img   = img + noise_std_map.*randn(28,28);
    % figure;imagesc(img);
    % img(10:20,10:20) = img(10:20,10:20) + 10*randn(11,11);
    for x=1:x_num_windows
        rows = window_size*(x-1)+1 : window_size*x;
        for y=1:y_num_windows
            cols = window_size*(y-1)+1 : window_size*y;
            % noisy_image_window = img(rows,cols) + mvnrnd(zeros(window_size,1), Sigma_window, 1);
            X(itr, :) = reshape(img(rows,cols), [1,window_size*window_size]);
            itr = itr + 1;
        end
    end
    if(mod(i,10) == 0)
        fprintf("%d images processed \n", i);
    end
    % Save the noisy training image under a per-label subfolder.
    subfolder = strcat(root_fldr,"/TrainingMNIST/", string(label));
    if not(isfolder(subfolder))
        mkdir(subfolder);
    end
    filename = strcat(subfolder, "/", string(i), ".png");
    imwrite(uint8(img), filename, "png");
end

%% Perform PCA to reduce dimensionality 

% Center the patch matrix and project it onto its top-k_pca principal axes.
mean_X               = mean(X,1);
X                    = X - mean_X;
% Named C_X (not "cov") so the built-in cov() is not shadowed for the rest
% of the script. Symmetrizing removes round-off asymmetry so eig() treats
% the matrix as symmetric and returns real, orthonormal eigenvectors.
C_X                  = cov(X);
C_X                  = (C_X + C_X')/2;
[U,Lambda]           = eig(C_X);
U                    = real(U);
Lambda               = real(Lambda);
[Lambda, ind]        = sort(diag(Lambda));      % ascending eigenvalues
U                    = U(:, ind);
k_pca                = 25;
k_ica                = 25;
assert(k_pca <= size(U,2), "k_pca exceeds the patch dimension");
principal_components = U(:,end-k_pca+1:end);   % eigenvectors of the k_pca largest eigenvalues
new_X                = X*principal_components;  % one row per patch, k_pca columns

%% Run algorithms

% Baseline ICA estimates (new_X is patches-by-components, hence the transpose).
A_jade = jade(new_X', k_ica);
fprintf("JADE completed\n");
[A_fastica, ~] = fastica(new_X', 'numOfIC', k_ica);
fprintf("FASTICA completed\n");

% Contrast objects consumed by ICA_power in the column loop.
cgf_fn      = cgf(new_X);
chf_fn      = symmetric_chf(new_X);
kurtosis_fn = kurtosis(new_X);

% Covariance used by each estimator: CGF and CHF both reuse JADE's A*A',
% while the kurtosis object supplies its own estimate.
% NOTE(review): confirm that reusing the JADE covariance for CGF/CHF is intended.
C_cgf      = A_jade*A_jade';
C_chf      = C_cgf;
C_kurtosis = kurtosis_fn.estimate_C(20);

pinv_C_cgf      = pinv(C_cgf);
pinv_C_chf      = pinv(C_chf);
pinv_C_kurtosis = pinv(C_kurtosis);

% Mixing-matrix estimates (A_*) and their dual rows (B1_*), filled column by column.
[A_cgf, A_chf, A_kurtosis, A_meta]     = deal(zeros(k_pca, k_ica));
[B1_cgf, B1_chf, B1_kurtosis, B1_meta] = deal(zeros(k_ica, k_pca));

verbose_flag       = 0;
maxruns            = 20;
num_trials         = 200;
random_projections = randn(2*num_trials, k_ica);

% Recover the k_ica mixing-matrix columns one at a time with ICA_power,
% separately for the CGF, CHF and kurtosis contrasts.
for i=1:k_ica
    fprintf('Column %d\n',i);
    % One random starting vector per column, shared by all three contrasts.
    u_init = randn(k_pca,1);
    
    fprintf('CGF run\n');
    % M = I - A*B1 is built from the columns recovered so far and passed to
    % ICA_power (presumably to deflate them out - confirm in ICA_power).
    M_cgf = eye(k_pca) - A_cgf*B1_cgf;
    % NOTE(review): the role of the ones(k_pca,k_ica) argument is defined
    % inside ICA_power and is not visible here.
    [u_cgf,~,~,~] = ICA_power(new_X, ...
                              maxruns, ...
                              ones(k_pca,k_ica), ...
                              verbose_flag, ...
                              cgf_fn, ...
                              u_init, ...
                              C_cgf, ...
                              M_cgf);

    % Store column i, then its dual row: pinv(C)*A(:,i) scaled so that
    % B1(i,:)*A(:,i) = 1.
    A_cgf(:,i)=u_cgf';
    u1=pinv_C_cgf*A_cgf(:,i);
    B1_cgf(i,:)=u1/(u1'*A_cgf(:,i));

    % Same deflation step with the CHF contrast.
    fprintf('CHF run\n');
    M_chf = eye(k_pca) - A_chf*B1_chf;
    [u_chf,~,~,~] = ICA_power(new_X, ...
                              maxruns, ...
                              ones(k_pca,k_ica), ...
                              verbose_flag, ...
                              chf_fn, ...
                              u_init, ...
                              C_chf, ...
                              M_chf);

    A_chf(:,i)=u_chf';
    u1=pinv_C_chf*A_chf(:,i);
    B1_chf(i,:)=u1/(u1'*A_chf(:,i));

    % Same deflation step with the kurtosis contrast.
    fprintf('Kurtosis run\n');
    M_kurtosis = eye(k_pca) - A_kurtosis*B1_kurtosis;
    [u_kurtosis,~,~,~] = ICA_power(new_X, ...
                                   maxruns, ...
                                   ones(k_pca,k_ica), ...
                                   verbose_flag, ...
                                   kurtosis_fn, ...
                                   u_init, ...
                                   C_kurtosis, ...
                                   M_kurtosis);

    A_kurtosis(:,i)=u_kurtosis';
    u1=pinv_C_kurtosis*A_kurtosis(:,i);
    B1_kurtosis(i,:)=u1/(u1'*A_kurtosis(:,i));

    fprintf('=====================\n');
    fprintf('=====================\n');
end

fprintf('Meta run\n');

% Independence score of every candidate mixing matrix, all measured with the
% same random projections. JADE and FastICA are scored with A*A' as covariance.
ind_scores_chf      = Delta_Score_Total(new_X, C_chf, A_chf, ...
                                        random_projections, num_trials);
ind_scores_cgf      = Delta_Score_Total(new_X, C_cgf, A_cgf, ...
                                        random_projections, num_trials);
ind_scores_kurtosis = Delta_Score_Total(new_X, C_kurtosis, A_kurtosis, ...
                                        random_projections, num_trials);
ind_scores_jade     = Delta_Score_Total(new_X, A_jade*A_jade', A_jade, ...
                                        random_projections, num_trials);
ind_scores_fastica  = Delta_Score_Total(new_X, A_fastica*A_fastica', A_fastica, ...
                                        random_projections, num_trials);

fprintf("Independence Score CHF : %.10f\n", ind_scores_chf);
fprintf("Independence Score CGF : %.10f\n", ind_scores_cgf);
fprintf("Independence Score Kurtosis : %.10f\n", ind_scores_kurtosis);
fprintf("Independence Score JADE : %.10f\n", ind_scores_jade);
fprintf("Independence Score FASTICA : %.10f\n", ind_scores_fastica);

% "Meta" estimator: the candidate with the smallest independence score.
A_matrices = {A_cgf, A_chf, A_kurtosis, A_jade, A_fastica};
scores     = [ind_scores_cgf, ind_scores_chf, ind_scores_kurtosis, ...
              ind_scores_jade, ind_scores_fastica];
[ind_scores_meta, min_index] = min(scores);
A_meta     = A_matrices{min_index};
fprintf("Independence Score Meta : %.10f\n", ind_scores_meta);

% Normalize every candidate with normr before denoising.
% NOTE(review): normr normalizes ROWS, while each recovered component is a
% COLUMN of A - confirm rows (rather than columns) are intended here.
normalized_A = cellfun(@normr, ...
                       {A_cgf, A_chf, A_kurtosis, A_jade, A_fastica, A_meta}, ...
                       'UniformOutput', false);
[A_cgf, A_chf, A_kurtosis, A_jade, A_fastica, A_meta] = deal(normalized_A{:});

% %% Orthogonalize matrices
% A_cgf      = orth(A_cgf);
% A_chf      = orth(A_chf);
% A_kurtosis = orth(A_kurtosis);
% A_jade     = orth(A_jade);
% A_fastica  = orth(A_fastica);
% A_meta     = orth(A_meta);

%% Denoise test data

% Every candidate mixing matrix is evaluated on the SAME noisy test images.
% The noisy inputs are generated once, before the method loop, so each
% method's error and the saved "Original_Test_Image" files refer to exactly
% the inputs that method denoised.
write_test_image = 1;                      % 1 = also save the noisy input images
num_test         = min(100, size(dataset.test_x,1));
num_windows      = x_num_windows*y_num_windows;
patch_len        = window_size*window_size;
shrink_threshold = 50;                     % soft-threshold applied to ICA coefficients
% Per-pixel noise std |50*sin(5*pi*(x+y)/56)| at row x, column y - the same
% noise model used for the training images.
noise_std_map    = abs(50*sin(5*pi*((1:28)' + (1:28))/(2*28)));

% ---- Noisy test images, cut into windows (computed once) ----------------
X_test_images        = dataset.test_x;
noisy_images         = zeros(28, 28, num_test);
original_test_images = zeros(num_test*num_windows, patch_len);  % noisy windows, one per row
itr = 1;
for i = 1:num_test
    img = reshape(double(X_test_images(i,:)), [28,28])';
    img = img + noise_std_map.*randn(28,28);
    noisy_images(:,:,i) = img;
    for x = 1:x_num_windows
        rows = window_size*(x-1)+1 : window_size*x;
        for y = 1:y_num_windows
            cols = window_size*(y-1)+1 : window_size*y;
            original_test_images(itr,:) = reshape(img(rows,cols), [1, patch_len]);
            itr = itr + 1;
        end
    end
end

if write_test_image == 1
    orig_folder = strcat(root_fldr, "/", "Original_Test_Image");
    if not(isfolder(orig_folder))
        mkdir(orig_folder);
    end
    for i = 1:num_test
        imwrite(uint8(noisy_images(:,:,i)), strcat(orig_folder, "/", string(i), ".png"), "png");
    end
end

% Center and project all windows onto the principal subspace once; this does
% not depend on the mixing matrix. One column per window.
projected = ((original_test_images - mean_X)*principal_components)';

method_titles   = ["CGF", "CHF", "Kurtosis", "JADE", "FastICA", "Meta"];
method_matrices = {A_cgf, A_chf, A_kurtosis, A_jade, A_fastica, A_meta};

for j = 1:numel(method_titles)
    img_title = method_titles(j);
    A         = method_matrices{j};

    % Unmixing/mixing pair used for denoising:
    % inv_A = pinv(A)*(pinv(A)'*pinv(A)), then A = pinv(inv_A).
    inv_A = pinv(A);
    inv_A = inv_A*(inv_A'*inv_A);
    A     = pinv(inv_A);

    % Unmix, soft-threshold and remix all windows at once.
    s_test = inv_A*projected;
    s_test = sign(s_test).*max(0, abs(s_test) - shrink_threshold);
    s_test = real(s_test);
    denoised_test_images = (A*s_test)'*principal_components' + mean_X;

    subfolder = strcat(root_fldr, "/", img_title);
    if not(isfolder(subfolder))
        mkdir(subfolder);
    end

    squared_error = 0;
    for i = 1:num_test
        % Reassemble image i from its windows (stored x-major, y-minor).
        base           = num_windows*(i-1);
        denoised_image = zeros(28,28);
        w = 1;
        for x = 1:x_num_windows
            rows = window_size*(x-1)+1 : window_size*x;
            for y = 1:y_num_windows
                cols = window_size*(y-1)+1 : window_size*y;
                denoised_image(rows,cols) = reshape(denoised_test_images(base+w,:), [window_size,window_size]);
                w = w + 1;
            end
        end

        % Compare the 8-bit images that are written to disk. Cast to double
        % before subtracting: uint8 arithmetic saturates (negative differences
        % become 0 and squares clip at 255).
        % NOTE(review): the reference is the NOISY input, not the clean digit -
        % confirm this is the intended metric.
        pixel_diff    = double(uint8(noisy_images(:,:,i))) - double(uint8(denoised_image));
        squared_error = squared_error + sum(pixel_diff(:).^2);

        imwrite(uint8(denoised_image), strcat(subfolder, "/", string(i), ".png"), "png");
    end

    % Mean over test images of the per-image sum of squared pixel differences.
    fprintf("%s squared-error : %.5f\n", img_title, squared_error/max(num_test,1));
end
write_test_image = 0;

% train_path = strcat(root_fldr, "/TrainingMNIST");
% imds = imageDatastore(train_path, ...
%     'IncludeSubfolders',true,'LabelSource','foldernames');
% 
% % figure;
% % perm = randperm(6000,20);
% % for i = 1:20
% %     subplot(4,5,i);
% %     imshow(imds.Files{perm(i)});
% % end
% 
% labelCount = countEachLabel(imds);
% 
% img = readimage(imds,1);
% size(img);
% 
% numTrainFiles = 500;
% [imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,'randomize');
% 
% layers = [
%     imageInputLayer([28 28 1])
% 
%     convolution2dLayer(3,8,'Padding','same')
%     batchNormalizationLayer
%     reluLayer
% 
%     maxPooling2dLayer(2,'Stride',2)
% 
%     convolution2dLayer(3,16,'Padding','same')
%     batchNormalizationLayer
%     reluLayer
% 
%     maxPooling2dLayer(2,'Stride',2)
% 
%     convolution2dLayer(3,32,'Padding','same')
%     batchNormalizationLayer
%     reluLayer
% 
%     fullyConnectedLayer(10)
%     softmaxLayer
%     classificationLayer];
% 
% options = trainingOptions('sgdm', ...
%                         'InitialLearnRate',0.01, ...
%                         'MaxEpochs',4, ...
%                         'Shuffle','every-epoch', ...
%                         'ValidationData',imdsValidation, ...
%                         'ValidationFrequency',30, ...
%                         'Verbose',false, ...
%                         'Plots','training-progress');
% 
% net = trainNetwork(imdsTrain,layers,options);
% 
% YPred = classify(net,imdsTrain);
% YValidation = imdsTrain.Labels;
% 
% accuracy = sum(YPred == YValidation)/numel(YValidation);
% fprintf("Training Accuracy : %.5f\n", accuracy);
% 
% YPred = classify(net,imdsValidation);
% YValidation = imdsValidation.Labels;
% 
% accuracy = sum(YPred == YValidation)/numel(YValidation);
% fprintf("Validation Accuracy : %.5f\n", accuracy);
% 
% test_dataset = strcat(root_fldr,"/CGF");
% test_datastore = imageDatastore(test_dataset, ...
%                     'IncludeSubfolders',true,'LabelSource','foldernames');
% 
% YPred = classify(net,test_datastore);
% YValidation = test_datastore.Labels;
% 
% accuracy = sum(YPred == YValidation)/numel(YValidation);
% fprintf("CGF Accuracy : %.5f\n", accuracy);
% 
% test_dataset = strcat(root_fldr,"/CHF");
% test_datastore = imageDatastore(test_dataset, ...
%                     'IncludeSubfolders',true,'LabelSource','foldernames');
% 
% YPred = classify(net,test_datastore);
% YValidation = test_datastore.Labels;
% 
% accuracy = sum(YPred == YValidation)/numel(YValidation);
% fprintf("CHF Accuracy : %.5f\n", accuracy);
% 
% test_dataset = strcat(root_fldr,"/FastICA");
% test_datastore = imageDatastore(test_dataset, ...
%                     'IncludeSubfolders',true,'LabelSource','foldernames');
% 
% YPred = classify(net,test_datastore);
% YValidation = test_datastore.Labels;
% 
% accuracy = sum(YPred == YValidation)/numel(YValidation);
% fprintf("FastICA Accuracy : %.5f\n", accuracy);
% 
% test_dataset = strcat(root_fldr,"/JADE");
% test_datastore = imageDatastore(test_dataset, ...
%                     'IncludeSubfolders',true,'LabelSource','foldernames');
% 
% YPred = classify(net,test_datastore);
% YValidation = test_datastore.Labels;
% 
% accuracy = sum(YPred == YValidation)/numel(YValidation);
% fprintf("JADE Accuracy : %.5f\n", accuracy);
% 
% test_dataset = strcat(root_fldr,"/Kurtosis");
% test_datastore = imageDatastore(test_dataset, ...
%                     'IncludeSubfolders',true,'LabelSource','foldernames');
% 
% YPred = classify(net,test_datastore);
% YValidation = test_datastore.Labels;
% 
% accuracy = sum(YPred == YValidation)/numel(YValidation);
% fprintf("Kurtosis Accuracy : %.5f\n", accuracy);
% 
% test_dataset = strcat(root_fldr,"/Meta");
% test_datastore = imageDatastore(test_dataset, ...
%                     'IncludeSubfolders',true,'LabelSource','foldernames');
% 
% YPred = classify(net,test_datastore);
% YValidation = test_datastore.Labels;
% 
% accuracy = sum(YPred == YValidation)/numel(YValidation);
% fprintf("Meta Accuracy : %.5f\n", accuracy);
