# Graph Separation
# This code generates the plots of Figures 2 and 3.

import numpy as np
import pandas as pd
import argparse
import os

import torch
import torch.nn as nn

from data_generator import generate_set_pairs
from embedding import Embed
from utils import Log, mem_report, nowstr, Timer


def parse_args():
    """Parse and post-process the command-line options of the experiment.

    After parsing, option values that depend on other options are filled in:
    ``n_dX_max`` falls back to (and is capped at) ``n``, empty output file
    names are derived from ``testName``, and an empty ``outDir`` becomes
    ``./out``. The output directory is created if missing, and the output
    file names are prefixed with it unless the directory is ``.``.

    Returns:
        argparse.Namespace: The parsed and post-processed arguments.
    """
    # (flag, type, default, help) -- kept in the order shown by --help.
    option_specs = (
        ("--d", int, 3, "Dimension of ambient space of data vectors"),
        ("--n", int, 10, "Number of vectors in each set"),
        ("--nSamples", int, 20, "Number of samples of cloud-pairs to test at each repetition"),
        ("--nReps", int, 25000, "Number of times to repeat the test (to increase accuracy of result)"),
        ("--std_X", float, 1, "Standard deviation of point-coordinates in cloud"),
        ("--std_rel_dX_min", float, 0.02, "Minimal std of an entry in the point-cloud difference, relative to <std_X>"),
        ("--std_rel_dX_max", float, 1, "Maximal std of an entry in the point-cloud difference, relative to <std_X>"),
        ("--n_dX_min", int, 1, "Minimal number of points that differ in each pair of point clouds. Must be >= 1."),
        ("--n_dX_max", int, -1, "Maximal number of points that differ in each pair of point clouds. Default: <n>"),
        ("--testName", str, 'injectivity', "Test name"),
        ("--outDir", str, '', "Directory to save output .log and .csv files. Default: ./out"),
        ("--log", str, '', "Output log file. Default is <testname>.log"),
        ("--csv", str, '', "Output csv file. Default is <testname>.csv"),
        ("--csvtemp", str, '', "Temporary output csv file. Default is <testname>_temp.csv"),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, default, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default, help=help_text)

    opts = cli.parse_args()

    # -1 means "not given"; values above n are capped at n.
    if opts.n_dX_max == -1 or opts.n_dX_max > opts.n:
        opts.n_dX_max = opts.n

    # Empty strings mean "not given": derive the names from the test name.
    opts.log = opts.log or opts.testName + '.log'
    opts.csv = opts.csv or opts.testName + '.csv'
    opts.csvtemp = opts.csvtemp or opts.testName + '_temp' + '.csv'
    opts.outDir = opts.outDir or './out'

    os.makedirs(opts.outDir, exist_ok=True)

    # Place all output files inside the output directory.
    if opts.outDir != '.':
        for field in ('log', 'csv', 'csvtemp'):
            setattr(opts, field, os.path.join(opts.outDir, getattr(opts, field)))

    return opts


def choose_m_vals(args):
    """Pick the grid of embedding dimensions ``m`` to evaluate.

    The grid depends on ``m_sep = 2*n*d + 1`` (presumably the embedding
    dimension that guarantees separation -- confirm against the paper):

    * ``m_sep <= 601`` or ``m_sep == 6001``: hand-picked grid from 1 to 620.
    * otherwise, ``m_sep <= 20001``: hand-picked grid from 1 to 20000.
    * otherwise: up to 100 distinct integers spread linearly over
      ``[1, 1.5*m_sep]``.

    Args:
        args: Object exposing integer attributes ``n`` and ``d``.

    Returns:
        A list of ints for the two hand-picked grids, or a sorted NumPy
        integer array for the linearly spaced grid.
    """
    m_sep = 2 * args.n * args.d + 1

    if m_sep > 20001:
        # Large problems: linearly spaced values, duplicates removed after
        # truncation to integers.
        upper = 1.5 * (2 * args.n * args.d + 1)
        spaced = np.linspace(1.0, upper, num=100)
        return np.unique(spaced.astype(int))

    # Each segment is a (start, stop, step) triple passed to range().
    if m_sep <= 601 or m_sep == 6001:
        segments = (
            (1, 21, 1),
            (22, 51, 2),
            (55, 101, 5),
            (110, 201, 10),
            (220, 601 + 21, 20),
        )
    else:
        segments = (
            (1, 11, 1),
            (12, 21, 2),
            (25, 101, 5),
            (110, 301, 10),
            (325, 1001, 25),
            (1100, 20001, 100),
        )

    return [m for start, stop, step in segments for m in range(start, stop, step)]


def get_activations():
    """Build the lookup table of activation functions used in the test.

    Returns:
        pandas.DataFrame: One row per activation, indexed by its name, with a
        single column ``'act'`` holding the callable. The names can be listed
        with ``list(df.index.values)`` and a callable fetched with
        ``df.loc['relu']['act']``.
    """
    # (name, callable) pairs, in the order in which they are tested.
    catalogue = (
        ('relu', nn.ReLU()),
        ('sigmoid', nn.Sigmoid()),
        ('tanh', torch.tanh),
        ('hardtanh', nn.Hardtanh()),
        ('swish', nn.SiLU()),
        ('mish', nn.Mish()),
        ('sin', torch.sin),
        ('cos', torch.cos),
    )

    names = [name for name, _ in catalogue]
    rows = [[fn] for _, fn in catalogue]

    return pd.DataFrame(data=rows, index=names, columns=['act'])


def run_test(args, m_vals = None):
    """Run the injectivity experiment and write the results to disk.

    For each of ``args.nReps`` repetitions, a batch of set pairs is generated
    and embedded with every embedding dimension in ``m_vals`` and every
    activation from ``get_activations()``. For each (m, activation)
    combination the ratios ``||E(X1) - E(X2)|| / dists`` are computed, and
    their min, max, sum and count are accumulated across repetitions.

    After every repetition the running summary is written to
    ``args.csvtemp``; the final summary is written to ``args.csv``. Progress
    and results are logged to ``args.log`` and to the screen.

    Args:
        args: Namespace as returned by ``parse_args()``.
        m_vals: Optional sequence of embedding dimensions. Defaults to
            ``choose_m_vals(args)``. It is iterated in reverse order;
            presumably it must be ascending so that the reversed order is
            descending (see the truncation step below) -- verify for custom
            inputs.

    NOTE(review): if ``args.nReps`` is 0 the repetition loop never runs and
    ``results_tot`` is never assigned, so the final logging raises NameError.
    """
    # Set parameters
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    dtype=torch.float64

    # Used for the elapsed-time / ETA strings in the progress log.
    t = Timer()

    # Passed to embed() on the first call of each repetition; on that call
    # embed() returns the projection operator it used (see below).
    projOp = 'stochastic'

    log = Log(fname=args.log, screen=True)
    log('Injectivity test')

    if m_vals is None:
        m_vals = choose_m_vals(args)

    # Iterate m in reverse order, so that for an ascending input the largest
    # m comes first and smaller m can reuse a truncated projection operator.
    #m_vals = np.asarray(m_vals).sort()[::-1]
    m_vals = m_vals[::-1]

    # Report
    log(nowstr())
    log('Test name: %s' % (args.testName))

    log()
    log('Device: %s' % (device))
    log('Arguments: ')
    log(args)

    log()
    log('m-values: (%d total)' % (len(m_vals),))
    log(m_vals)

    log()

    acts = get_activations()
    act_names = list(acts.index.values)

    for iRep in range(1, 1+args.nReps):
        # Fresh batch of set pairs, together with the distance of each pair.
        X1, X2, dists = generate_set_pairs(args.nSamples, n=args.n, d=args.d, std_X=args.std_X, std_rel_dX = (args.std_rel_dX_min, args.std_rel_dX_max), n_dX=(args.n_dX_min,args.n_dX_max), calc_distance=True, dtype=dtype, device=device)

        # True until the first embed() call of this repetition has produced
        # the projection operator that the rest of the repetition reuses.
        new_rep = True
        results_currRep = pd.DataFrame()

        for m in m_vals:        
            # The constructor's activation is overridden on every call below.
            embed = Embed(args.d, m, std_proj=1, assume_std_in = args.std_X, activation=torch.sin, dtype=dtype, device=device)

            if not new_rep:
                # Keep only the first m entries along axis 1 of the operator
                # from the previous (larger) m, instead of drawing a new one.
                (projMat, offsetVec) = projOp_stochastic
                projMat = projMat[:, range(m), :]
                offsetVec = offsetVec[:, range(m), :]
                projOp_stochastic = (projMat, offsetVec)

                #print('Shapes: Mat: %s Vec: %s' % (projMat.shape, offsetVec.shape))

            #log('=== m=%d ===' % (m,))
            for actname in act_names:
                act = acts.loc[actname]['act']

                if new_rep:
                    # First call of the repetition: embed() also returns the
                    # projection operator it used.
                    [X1m, projOp_stochastic] = embed(X1, override_activation=act, projOp = projOp)
                    new_rep = False
                else:
                    X1m = embed(X1, override_activation=act, projOp = projOp_stochastic)

                X2m = embed(X2, override_activation=act, projOp = projOp_stochastic)

                # Embedding distance divided by the input distance, per pair.
                lip_ratios = torch.norm(X1m-X2m, dim=1) / dists

                #log('m:%d  act:%s  cond: %g   max:%g avg: %g min: %g' % (m, actname, torch.min(lip_ratios)/torch.mean(lip_ratios), torch.max(lip_ratios), torch.mean(lip_ratios), torch.min(lip_ratios)))

                # One summary row per (m, activation) combination.
                result_temp_dict = {'name':actname, 'm':m, 'min': torch.min(lip_ratios).item(), 'max': torch.max(lip_ratios).item(), 'sum': torch.sum(lip_ratios).item(), 'nSamples': torch.numel(lip_ratios) }
                result_temp_df = pd.DataFrame(data=result_temp_dict, index=[0])
                results_currRep = pd.concat((results_currRep, result_temp_df), ignore_index=True)

        if iRep == 1:
            results_tot = results_currRep

        else:
            # Merge this repetition into the running totals, row by row; the
            # rows line up because every repetition visits the (m, activation)
            # combinations in the same order.
            results_tot['min'] = pd.concat([results_tot['min'], results_currRep['min']], axis=1).min(axis=1)
            results_tot['max'] = pd.concat([results_tot['max'], results_currRep['max']], axis=1).max(axis=1)
            results_tot['sum'] = pd.concat([results_tot['sum'], results_currRep['sum']], axis=1).sum(axis=1)
            results_tot['nSamples'] = pd.concat([results_tot['nSamples'], results_currRep['nSamples']], axis=1).sum(axis=1)

        #log(results_currRep)
        log('Finished repetition %d / %d. tElapsed: %s  ETA: %s' % (iRep, args.nReps, t.str(), t.etastr(iRep/args.nReps)))

        # Derived statistics, recomputed from the running totals.
        results_tot['mean'] = results_tot['sum'] / results_tot['nSamples']

        results_tot['min_max_ratio'] = results_tot['min'] / results_tot['max']
        results_tot['mean_max_ratio'] = results_tot['mean'] / results_tot['max']
        results_tot['min_mean_ratio'] = results_tot['min'] / results_tot['mean']

        # Write the intermediate summary after every repetition.
        results_temp = process_results(results_tot, log)
        results_temp.to_csv(args.csvtemp)

    log('Finished. Final result (raw):')
    log(results_tot)

    results = process_results(results_tot, log)

    log('Finished. Final result:')
    log(results)

    results.to_csv(args.csv)


def process_results(results, log=None):
    """Pivot the per-(m, activation) results into one row per ``m``.

    Args:
        results: DataFrame with columns ``name``, ``m``, ``nSamples``,
            ``min_max_ratio``, ``mean_max_ratio`` and ``min_mean_ratio``,
            holding one row per (m, activation) combination.
        log: Unused; accepted for call compatibility.

    Returns:
        pandas.DataFrame: Sorted by ``m``. Columns are ``index`` (the row
        position before sorting), ``m`` and ``nSamples`` (both taken from the
        first activation's rows), followed by ``<metric>_<activation>`` for
        each activation and each of the three ratio metrics.
    """
    metrics = ('min_max_ratio', 'mean_max_ratio', 'min_mean_ratio')
    names = list(pd.unique(results['name']))

    # Rows of each activation, renumbered from 0 so that the activations
    # line up positionally with one another.
    per_activation = [results.loc[results['name'] == name].reset_index() for name in names]

    # The first activation supplies the shared columns (and the row index).
    reference = per_activation[0]
    table = pd.DataFrame()
    for shared in ('m', 'nSamples'):
        table[shared] = reference[shared]

    for name, rows in zip(names, per_activation):
        for metric in metrics:
            table[metric + '_' + name] = rows[metric]

    return table.sort_values(by=['m']).reset_index()

# Run the experiment only when this file is executed as a script, so that
# importing the module does not parse sys.argv, create the output directory
# or start the (long-running) test as a side effect.
if __name__ == "__main__":
    args = parse_args()
    run_test(args)

