#!/usr/bin/env python3
#!/usr/bin/env python3.7 -W ignore::DeprecationWarning

# Script to generate W Distance matrix from a set of embeddings (CWL)
import igraph as ig
import numpy as np

import argparse
import os
import ot
import time

# Remove deprecation warnings
import warnings
warnings.filterwarnings("ignore")

def pad_cost_function(C, val, padding='full'):
    """Pad a rectangular ground-cost matrix so that it becomes square.

    Two padding schemes are supported:

    * ``'full'``:    (n1, n2) -> (n1 + n2, n1 + n2). The original costs sit in
      the top-left corner, the two off-diagonal blocks are filled with
      ``val`` and the bottom-right (n1 x n2 padding) block is left at zero.
    * ``'partial'``: (n1, n2) -> (n, n) with ``n = max(n1, n2)``. The missing
      rows (when n1 < n2) or columns (when n1 > n2) are filled with ``val``.
      An already-square ``C`` is returned unchanged (the same object).

    Args:
        C: 2-D cost matrix of shape (n1, n2).
        val: Cost written into every padded entry.
        padding: Either ``'full'`` or ``'partial'``.

    Returns:
        The padded square cost matrix (a new float array, except for the
        square ``'partial'`` case noted above).

    Raises:
        ValueError: If ``padding`` is not one of the supported schemes.
    """
    n1, n2 = C.shape
    if padding == 'full':
        C_new = np.zeros((n1 + n2, n1 + n2))
        # Slice assignment copies the data, so no explicit C.copy() is needed.
        C_new[:n1, :n2] = C
        C_new[n1:, :n2] = val
        C_new[:n1, n2:] = val
        return C_new
    if padding == 'partial':
        if n1 == n2:
            return C
        n = max(n1, n2)
        C_new = np.zeros((n, n))
        C_new[:n1, :n2] = C
        if n1 < n2:
            C_new[n1:, :] = val
        else:
            C_new[:, n2:] = val
        return C_new
    # Previously an unknown scheme surfaced as an UnboundLocalError on C_new.
    raise ValueError(
        "padding must be 'full' or 'partial', got {!r}".format(padding))

if __name__ == '__main__':
    # CLI entry point: for every WL iteration h in [0, num_iterations], compute
    # the pairwise Wasserstein distance between all graphs' node embeddings
    # (truncated to the columns belonging to iterations 0..h) and save the
    # resulting n x n matrix as a .npy file in the output directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--embeddings', type=str, help='File containing the embeddings generated by CWL', required=True)
    parser.add_argument('-o', '--output', required=True, help='Output directory; the script will generate one file for each iter of WL inside it.')
    parser.add_argument('-n', '--num-iterations', default=3, type=int, help='Max number of Weisfeiler-Lehman iterations')
    parser.add_argument('-s', '--sinkhorn', default=False, action='store_true', help='Use Sinkhorn approximation')
    parser.add_argument('-sl', '--sinkhorn-lambda', default=1e-2, type=float, help='The lambda coefficient for the Sinkhorn approximation')
    parser.add_argument('-p', '--padding', default=None, choices=[None, 'partial', 'full'], help='Set padding on (options: partial, full; default: None)')

    start = time.time()
    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)

    # Load embeddings. When graphs have different numbers of nodes the file
    # holds a ragged object array, which NumPy >= 1.16.3 only loads with
    # allow_pickle=True. SECURITY: unpickling can execute arbitrary code, so
    # only point --embeddings at files you produced yourself.
    label_sequences = np.load(args.embeddings, allow_pickle=True)

    embeddings_name = os.path.basename(args.embeddings)
    print('Computing the Wasserstein distance for {}'.format(embeddings_name))

    # The number of WL iterations used to build the embeddings is read from
    # the digits following the last 'h' in the file name (e.g. "..._h3.npy").
    digits = ''.join(c for c in embeddings_name.split('h')[-1] if c.isdigit())
    if not digits:
        parser.error('cannot infer the number of WL iterations from {!r}'.format(embeddings_name))
    n_it_orig = int(digits)

    n = len(label_sequences)
    emb_size = label_sequences[0].shape[1]
    # Each of the n_it_orig + 1 iterations contributes n_feat columns.
    n_feat = emb_size // (n_it_orig + 1)

    if args.num_iterations > n_it_orig:
        # Column slices past emb_size silently saturate, so higher h would
        # just repeat the h = n_it_orig result.
        print('Warning: requested {} iterations but the embeddings only contain {}.'.format(
            args.num_iterations, n_it_orig))

    # The output file name depends only on the CLI options, not on h.
    filename_prefix = 'wasserstein_distance_matrix'
    if args.padding is not None:
        filename_prefix += '_{}_padding'.format(args.padding)
    if args.sinkhorn:
        filename_prefix += '_sinkhorn_l{}'.format(args.sinkhorn_lambda)

    ot_time = 0
    dist_time = 0
    for h in range(args.num_iterations + 1):
        n_cols = n_feat * (h + 1)
        M = np.zeros((n, n))
        # Only the upper triangle (diagonal included) is computed; the matrix
        # is mirrored afterwards.
        for i in range(n):
            labels_1 = label_sequences[i][:, :n_cols]
            for j in range(i, n):
                start_dist = time.time()
                labels_2 = label_sequences[j][:, :n_cols]
                costs = ot.dist(labels_1, labels_2, metric='euclidean')

                if args.padding is not None:
                    costs = pad_cost_function(costs, np.max(costs), padding=args.padding)
                dist_time += time.time() - start_dist

                start_ot = time.time()
                if args.sinkhorn:
                    # Uniform marginals sized to the (possibly padded) cost
                    # matrix, matching the uniform weights ot.emd2 infers from
                    # empty weight lists below. Sizing them from the unpadded
                    # label counts made padding + Sinkhorn fail on a shape
                    # mismatch.
                    a = np.ones(costs.shape[0]) / costs.shape[0]
                    b = np.ones(costs.shape[1]) / costs.shape[1]
                    plan = ot.sinkhorn(a, b, costs, args.sinkhorn_lambda, numItermax=50)
                    M[i, j] = np.sum(plan * costs)
                else:
                    M[i, j] = ot.emd2([], [], costs)

                ot_time += time.time() - start_ot
            if args.sinkhorn:
                print('Dist time: {}'.format(dist_time))
                print('W time: {}'.format(ot_time))
                print(i)

        # Mirror the upper triangle into the lower one. Subtracting the
        # diagonal once prevents it from being doubled (it is non-zero for
        # the entropic Sinkhorn approximation).
        M = M + M.T - np.diag(np.diag(M))

        # Save output inside the output directory, regardless of whether the
        # user supplied a trailing path separator.
        filext = filename_prefix + '_it{}.npy'.format(h)
        np.save(os.path.join(args.output, filext), M)
        print("Iteration for h = {} completed.".format(h))

    print('Total elapsed time: {:2.2f} s'.format(time.time() - start))
    print('Ground distance computation time: {:2.2f} s'.format(dist_time))
    print('Wasserstein computation time: {:2.2f} s'.format(ot_time))
