import os
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from coinpress_mod.algos import L2, multivariate_mean_iterative, multivariate_mean_iterative_n
from coinpress_mod.functions import coinpress_mean
from quantile_binary_search.method import random_rotation_mean, clipped_mean

def generate_gauss_paths(mean,cov,nPaths,nSamples,d,fileOut=''):
    """Draw ``nPaths`` independent multivariate-normal sample matrices.

    Every list element is ``np.random.multivariate_normal(mean, cov,
    int(nSamples))``, drawn from NumPy's global RNG in list order.
    ``d`` and ``fileOut`` are not used; they keep the signature
    interchangeable with ``generate_laplace_paths`` (see the
    ``generate_func`` argument of ``test_dist``).

    Returns:
        list of ``nPaths`` arrays.
    """
    return [np.random.multivariate_normal(mean, cov, int(nSamples))
            for _ in range(nPaths)]
    
def generate_laplace_paths(mean,cov,nPaths,nSamples,d,fileOut=''):
    """Draw ``nPaths`` independent matrices of i.i.d. Laplace samples.

    Coordinate ``j`` of every sample is Laplace-distributed with location
    ``mean[j]`` and scale ``sqrt(cov[j, j] / 2)``. A Laplace(b) variable has
    variance ``2*b**2``, so each coordinate has variance ``cov[j, j]``.
    Off-diagonal entries of ``cov`` are ignored, so coordinates are
    independent.

    Previously only ``mean[0]`` and ``cov[0, 0]`` were used for every
    coordinate. For a constant mean and isotropic covariance the two versions
    draw bit-identical samples from NumPy's global RNG.

    Args:
        mean: length-``d`` (or length-1) location vector.
        cov: ``d x d`` (or ``1 x 1``) covariance; only the diagonal is used.
        nPaths: number of independent sample matrices.
        nSamples: rows per matrix (cast with ``int``).
        d: number of columns.
        fileOut: unused; keeps the signature aligned with
            ``generate_gauss_paths``.

    Returns:
        list of ``nPaths`` arrays of shape ``(int(nSamples), d)``.
    """
    loc = np.asarray(mean, dtype=float)
    # Per-coordinate scale; broadcasts across the (nSamples, d) output.
    scale = np.sqrt(np.diag(np.asarray(cov, dtype=float)) / 2)
    return [np.random.laplace(loc, scale, (int(nSamples), d))
            for _ in range(nPaths)]

def test_dist(mean,cov,p,nPaths,nSamples,d,c,r,u,bGenerate=True,generate_func=generate_gauss_paths,prefix='gauss_',folder=''):
    """Evaluate every mean estimator on ``nPaths`` independent data sets.

    Data comes from one of two sources:

    * ``bGenerate=True``: produced by ``generate_func``. When ``folder`` is
      non-empty it is also stored with ``save_paths``.
    * ``bGenerate=False``: read with ``load_paths`` from
      ``folder + <prefix><nPaths>paths_<nSamples>samples_dim<d>.txt``.

    For every path the ``L2`` distance between an estimate and ``mean`` is
    recorded for these estimators:

    * the empirical mean;
    * ``random_rotation_mean``;
    * ``clipped_mean``;
    * ``coinpress_mean`` with t = 1, 2, 3, 4, 10, each with both
      ``multivariate_mean_iterative`` and ``multivariate_mean_iterative_n``.

    Each estimator receives its own copy of the path.

    Returns:
        A 13-tuple of lists (one error per path each), in the order:
        non-private, t1, t1_n, t2, t2_n, t3, t3_n, t4, t4_n, t10, t10_n,
        random-rotation mean, clipped mean.
    """
    data_file = prefix + str(nPaths) + 'paths_' + str(nSamples) + 'samples_dim' + str(d) + '.txt'
    if not bGenerate:
        X_paths = load_paths(folder + data_file, nPaths, nSamples, d)
    else:
        X_paths = generate_func(mean, cov, nPaths, nSamples, d)
        if folder != '':
            save_paths(X_paths, nPaths, nSamples, d, folder, data_file)

    # (t, solver) pairs in evaluation order. The order is kept identical
    # across runs because the private estimators may draw from the global
    # RNG (not visible here - TODO confirm).
    configs = [(t, solver)
               for t in (1, 2, 3, 4, 10)
               for solver in (multivariate_mean_iterative, multivariate_mean_iterative_n)]

    nonprivate_errs, rr_errs, clip_errs = [], [], []
    coinpress_errs = [[] for _ in configs]

    def error_of(estimate):
        return L2(estimate - mean)

    for idx in range(nPaths):
        sample = X_paths[idx]
        nonprivate_errs.append(error_of(np.mean(sample, axis=0)))
        rr_errs.append(error_of(random_rotation_mean(sample.copy(), d, u, p)))
        clip_errs.append(error_of(clipped_mean(sample.copy(), nSamples, d, u, p)))
        for bucket, (t, solver) in zip(coinpress_errs, configs):
            bucket.append(error_of(coinpress_mean(sample.copy(), c, r, t, p, func=solver)))

    return (nonprivate_errs, *coinpress_errs, rr_errs, clip_errs)


def test_dim(mean0,sigma0,dims,p,nPaths,nSamples,bGenerate=True,folder='',generate_func=generate_gauss_paths,prefix=''):
    """Sweep the dimension ``d`` over ``dims`` and record every estimator's error.

    For each ``d`` the problem is set up as follows:

    * mean vector ``[mean0]*d``;
    * covariance ``sigma0**2 * I``;
    * ``c = 0``;
    * ``r = 10*sqrt(d)``;
    * ``u = 2*r``.

    ``test_dist`` is then run, and both the per-path errors and their average
    are stored. Output goes to ``folder + subfolder``:

    * ``err_<key>.txt``: the averages;
    * ``err_<key>_paths.txt``: the per-path errors;
    * ``params.txt``: the run parameters.

    Because ``r`` and ``u`` depend on ``d``, they are recorded once per
    dimension, ``|``-joined in the same order as ``d``.

    Args:
        mean0: value of every coordinate of the true mean.
        sigma0: per-coordinate standard deviation.
        dims: non-empty sequence of dimensions.
        p: parameter forwarded unchanged to ``test_dist``.
        nPaths: independent data sets per dimension.
        nSamples: samples per data set.
        bGenerate: generate data (True) or load it (False) in ``test_dist``.
        folder: output root. It is concatenated verbatim, so it must end
            with a path separator if non-empty.
        generate_func: sampler with the ``generate_gauss_paths`` signature.
        prefix: inserted into the output subfolder name.

    Returns:
        The subfolder name (relative to ``folder``) holding the results.

    Raises:
        ValueError: if ``dims`` is empty, if ``||mean - c|| > r``, or if
            ``nSamples`` is below the sample-size bound checked per ``d``.
    """
    keys = ('nonpr', 't1', 't1n', 't2', 't2n', 't3', 't3n',
            't4', 't4n', 't10', 't10n', 'hada', 'clip')
    if len(dims) == 0:
        raise ValueError('dims must contain at least one dimension')
    T = 18  # constant used only in the sample-size precondition below

    err = {key: [] for key in keys}
    err_paths = {key: [] for key in keys}
    rs, us = [], []
    for d in dims:
        print(d)
        mean = [mean0]*d
        cov = sigma0*sigma0*np.eye(d)
        c = [0.0]*d
        r = 10*np.sqrt(d)
        u = 2.0*r
        rs.append(r)
        us.append(u)
        # Explicit checks (not assert) so they also run under `python -O`;
        # written as `not (...)` to keep the original NaN behaviour.
        if not (np.linalg.norm(np.array(mean)-np.array(c)) <= r):
            raise ValueError('true mean lies farther than r from c for d=%s' % d)
        if not (nSamples >= 3 * np.sqrt(d)*np.sqrt(T)/np.sqrt(2.0*p)):
            raise ValueError('nSamples=%s is too small for d=%s, p=%s' % (nSamples, d, p))
        results = test_dist(mean,cov,p,nPaths,nSamples,d,c,r,u,bGenerate=bGenerate,generate_func=generate_func,folder='')
        # test_dist returns its 13 per-path error lists in `keys` order.
        for key, per_path in zip(keys, results):
            err[key].append(np.mean(per_path))
            err_paths[key].append(per_path)

    strdims = '|'.join(str(di) for di in dims)
    strus = '|'.join(str(ui) for ui in us)
    strrs = '|'.join(str(ri) for ri in rs)
    params = [['nPaths','nSamples','d','mean','sigma','p','u','r'],
              [str(nPaths),str(nSamples),strdims,str(mean0),str(sigma0),str(p),strus,strrs]]
    subfolder = str(nPaths) + prefix + 'paths_' + str(nSamples) + 'samples_' + str(dims[0]) + '_' + str(len(dims)) + '/'
    strfolder = folder + subfolder
    write_text(params,strfolder,'params.txt')
    for key in keys:
        write_output(dims,err[key],strfolder,'err_' + key + '.txt')
    for key in keys:
        write_output(dims,err_paths[key],strfolder,'err_' + key + '_paths.txt')
    return subfolder

def test_rho(mean0,sigma0,d,ps,nPaths,nSamples,bGenerate=True,folder='',generate_func=generate_gauss_paths,prefix=''):
    """Sweep the parameter ``p`` (printed as rho) over ``ps`` at fixed dimension ``d``.

    The problem is set up once:

    * mean vector ``[mean0]*d``;
    * covariance ``sigma0**2 * I``;
    * ``c = 0``;
    * ``r = 10*sqrt(d)``;
    * ``u = 2*r``.

    For each ``p``, ``test_dist`` is run and both the per-path errors and
    their average are stored. Output goes to ``folder + subfolder``:

    * ``err_<key>.txt``: the averages;
    * ``err_<key>_paths.txt``: the per-path errors;
    * ``params.txt``: the run parameters.

    Args:
        mean0: value of every coordinate of the true mean.
        sigma0: per-coordinate standard deviation.
        d: dimension.
        ps: sequence of ``p`` values forwarded to ``test_dist``.
        nPaths: independent data sets per ``p``.
        nSamples: samples per data set.
        bGenerate: generate data (True) or load it (False) in ``test_dist``.
        folder: output root. It is concatenated verbatim, so it must end
            with a path separator if non-empty.
        generate_func: sampler with the ``generate_gauss_paths`` signature.
        prefix: inserted into the output subfolder name.

    Returns:
        The subfolder name (relative to ``folder``) holding the results.

    Raises:
        ValueError: if ``||mean - c|| > r``, or if ``nSamples`` is below the
            sample-size bound for some ``p``.
    """
    keys = ('nonpr', 't1', 't1n', 't2', 't2n', 't3', 't3n',
            't4', 't4n', 't10', 't10n', 'hada', 'clip')
    mean = [mean0]*d
    cov = sigma0*sigma0*np.eye(d)
    c = [0.0]*d
    r = 10.0*np.sqrt(d)
    u = 2.0*r
    T = 18  # constant used only in the sample-size precondition below

    err = {key: [] for key in keys}
    err_paths = {key: [] for key in keys}
    for p in ps:
        print('rho = ',p)
        # Explicit checks (not assert) so they also run under `python -O`;
        # written as `not (...)` to keep the original NaN behaviour.
        if not (np.linalg.norm(np.array(mean)-np.array(c)) <= r):
            raise ValueError('true mean lies farther than r from c for d=%s' % d)
        if not (nSamples >= 3 * np.sqrt(d)*np.sqrt(T)/np.sqrt(2.0*p)):
            raise ValueError('nSamples=%s is too small for d=%s, p=%s' % (nSamples, d, p))
        results = test_dist(mean,cov,p,nPaths,nSamples,d,c,r,u,bGenerate=bGenerate,generate_func=generate_func,folder='')
        # test_dist returns its 13 per-path error lists in `keys` order.
        for key, per_path in zip(keys, results):
            err[key].append(np.mean(per_path))
            err_paths[key].append(per_path)

    strps = '|'.join(str(pi) for pi in ps)
    params = [['nPaths','nSamples','d','mean','sigma','p','u','r'],
              [str(nPaths),str(nSamples),str(d),str(mean0),str(sigma0),strps,str(u),str(r)]]
    subfolder = str(nPaths) + prefix + 'paths_' + str(nSamples) + 'samples_' + str(len(ps)) + 'rho/'
    strfolder = folder + subfolder
    write_text(params,strfolder,'params.txt')
    for key in keys:
        write_output(ps,err[key],strfolder,'err_' + key + '.txt')
    for key in keys:
        write_output(ps,err_paths[key],strfolder,'err_' + key + '_paths.txt')
    return subfolder

def test_n(mean0,sigma0,d,p,nPaths,ns,bGenerate=True,folder='',generate_func=generate_gauss_paths,prefix=''):
    """Sweep the per-path sample count over ``ns`` at fixed ``d`` and ``p``.

    The problem is set up once:

    * mean vector ``[mean0]*d``;
    * covariance ``sigma0**2 * I``;
    * ``c = 0``;
    * ``r = 10*sqrt(d)``;
    * ``u = 2*r``.

    For each ``n``, ``test_dist`` is run with ``n`` samples per path, and
    both the per-path errors and their average are stored. Output goes to
    ``folder + subfolder``:

    * ``err_<key>.txt``: the averages;
    * ``err_<key>_paths.txt``: the per-path errors;
    * ``params.txt``: the run parameters.

    Args:
        mean0: value of every coordinate of the true mean.
        sigma0: per-coordinate standard deviation.
        d: dimension.
        p: parameter forwarded unchanged to ``test_dist``.
        nPaths: independent data sets per ``n``.
        ns: non-empty sequence of sample counts.
        bGenerate: generate data (True) or load it (False) in ``test_dist``.
        folder: output root. It is concatenated verbatim, so it must end
            with a path separator if non-empty.
        generate_func: sampler with the ``generate_gauss_paths`` signature.
        prefix: inserted into the output subfolder name.

    Returns:
        The subfolder name (relative to ``folder``) holding the results.

    Raises:
        ValueError: if ``ns`` is empty (checked up front; previously this
            crashed with IndexError after the sweep), if ``||mean - c|| > r``,
            or if some ``n`` is below the sample-size bound.
    """
    keys = ('nonpr', 't1', 't1n', 't2', 't2n', 't3', 't3n',
            't4', 't4n', 't10', 't10n', 'hada', 'clip')
    if len(ns) == 0:
        raise ValueError('ns must contain at least one sample count')
    mean = [mean0]*d
    cov = sigma0*sigma0*np.eye(d)
    c = [0.0]*d
    r = 10.0*np.sqrt(d)
    u = 2.0*r
    T = 18  # constant used only in the sample-size precondition below

    err = {key: [] for key in keys}
    err_paths = {key: [] for key in keys}
    for n in ns:
        print('n = ',n)
        # Explicit checks (not assert) so they also run under `python -O`;
        # written as `not (...)` to keep the original NaN behaviour.
        if not (np.linalg.norm(np.array(mean)-np.array(c)) <= r):
            raise ValueError('true mean lies farther than r from c for d=%s' % d)
        if not (n >= 10 * 2 * np.sqrt(d)*np.sqrt(T)/np.sqrt(2.0*p)):
            raise ValueError('n=%s is too small for d=%s, p=%s' % (n, d, p))
        results = test_dist(mean,cov,p,nPaths,n,d,c,r,u,bGenerate=bGenerate,generate_func=generate_func,folder='')
        # test_dist returns its 13 per-path error lists in `keys` order.
        for key, per_path in zip(keys, results):
            err[key].append(np.mean(per_path))
            err_paths[key].append(per_path)

    strns = '|'.join(str(ni) for ni in ns)
    params = [['nPaths','nSamples','d','mean','sigma','p','u','r'],
              [str(nPaths),strns,str(d),str(mean0),str(sigma0),str(p),str(u),str(r)]]
    subfolder = str(nPaths) + prefix + 'paths_' + str(ns[0]) + '-' + str(ns[-1]) + 'samples_' + str(len(ns)) + '/'
    strfolder = folder + subfolder
    write_text(params,strfolder,'params.txt')
    for key in keys:
        write_output(ns,err[key],strfolder,'err_' + key + '.txt')
    for key in keys:
        write_output(ns,err_paths[key],strfolder,'err_' + key + '_paths.txt')
    return subfolder


def write_output(x,y,folder,filename):
    """Save ``y`` against ``x`` as a whitespace-separated text table.

    The layout depends on the shape of ``y``:

    * One-dimensional ``y`` (one value per ``x``): the file has two columns,
      ``x`` and ``y``.
    * Otherwise ``y[i]`` is a sequence of values belonging to ``x[i]``.
      Column ``i`` of the file holds ``x[i]`` followed by ``y[i]``, so the
      first row is ``x``.

    ``folder`` is created if missing. It is prepended verbatim to
    ``filename``, so it must end with a path separator. An empty ``folder``
    means the current directory; previously that crashed in
    ``os.makedirs('')``.
    """
    if folder:
        # exist_ok avoids the isdir/makedirs race of the old two-step check.
        os.makedirs(folder, exist_ok=True)
    if np.ndim(y) == 1:
        table = [x, y]
    else:
        table = [[x[i], *y[i]] for i in range(len(x))]
    np.savetxt(folder+filename,np.transpose(table))

def write_text(x,folder,filename):
    """Write ``x``, a sequence of rows of strings, as comma-separated lines.

    ``folder`` is created if missing. It is prepended verbatim to
    ``filename``, so it must end with a path separator. An empty ``folder``
    means the current directory.

    The file is closed even if writing fails, which the previous manual
    open/close did not guarantee.
    """
    if folder:
        os.makedirs(folder, exist_ok=True)
    with open(folder+filename, 'w', encoding='utf-8') as fOut:
        for row in x:
            fOut.write(','.join(row)+'\n')

def save_paths(X,nPaths,nSamples,d,folder,filename):
    """Store ``nPaths`` sample matrices of shape ``(nSamples, d)`` in one text file.

    Each path is flattened row-major to length ``d*nSamples``. The file holds
    one path per column, so it has ``d*nSamples`` rows and ``nPaths``
    columns; ``load_paths`` reverses this layout.

    Counts are cast with ``int`` (the generators accept float ``nSamples``
    the same way), so a float count no longer breaks ``np.reshape``.

    ``folder`` is created if missing. It is prepended verbatim to
    ``filename``, so it must end with a path separator. An empty ``folder``
    means the current directory.
    """
    if folder:
        os.makedirs(folder, exist_ok=True)
    flat = np.reshape(X, (int(nPaths), int(d)*int(nSamples)))
    np.savetxt(folder+filename, np.transpose(flat))

def load_paths(filename,nPaths,nSamples,d):
    """Read sample paths written by ``save_paths``.

    The file is expected to hold one flattened path per column
    (``d*nSamples`` rows, ``nPaths`` columns).

    Counts are cast with ``int``, matching how the generators treat
    ``nSamples``, so float counts no longer break ``np.reshape``.

    Returns:
        ndarray of shape ``(nPaths, nSamples, d)``.
    """
    data = np.transpose(np.genfromtxt(filename))
    return np.reshape(data, (int(nPaths), int(nSamples), int(d)))

def make_plot(folder,savename,xlabel,ylabel,xscale='linear',yscale='linear',basex=1,basey=1,loc='upper left',xlim=None,ylim=None):
    """Plot the averaged error curves stored in ``folder`` and save the figure.

    Reads the two-column ``err_*.txt`` files written by ``write_output``
    (column 0: swept value, column 1: averaged error) for these estimators:

    * the non-private mean;
    * CoinPress with t = 1, 2, 3, 4, 10;
    * the clipped mean;
    * the RR mean.

    One line per file is drawn on the current pyplot figure, and the figure
    is saved to ``folder + savename``. The ``*_n`` CoinPress variants are not
    plotted. The figure is cleared afterwards, even when reading or saving
    fails, so no curves leak into the next plot.

    Args:
        folder: directory prefix; must end with a path separator.
        savename: output image file name.
        xlabel, ylabel: axis labels.
        xscale, yscale: passed to ``plt.xscale``/``plt.yscale``.
        basex, basey: unused; kept for backward compatibility.
        loc: legend location.
        xlim, ylim: optional axis limits for ``plt.xlim``/``plt.ylim``. Any
            value other than ``None`` is applied, including NumPy arrays,
            which previously raised in the ``== None`` comparison.
    """
    # (file, color, marker, label); curves and legend entries follow this order.
    curves = (
        ('err_nonpr.txt', 'blue', 'x', 'Non-Private'),
        ('err_t1.txt', 'darkviolet', '+', 'coinpress t=1'),
        ('err_t2.txt', 'red', '+', 'coinpress t=2'),
        ('err_t3.txt', 'magenta', '+', 'coinpress t=3'),
        ('err_t4.txt', 'gray', '+', 'coinpress t=4'),
        ('err_t10.txt', 'olive', '+', 'coinpress t=10'),
        ('err_clip.txt', 'yellow', 'o', 'clipped mean'),
        ('err_hada.txt', 'green', '*', 'RR mean'),
    )
    try:
        for filename, color, marker, label in curves:
            data = np.transpose(np.genfromtxt(folder+filename))
            plt.plot(data[0],data[1],color=color,marker=marker,label=label)
        plt.legend(loc=loc, framealpha=0.4)
        if xlim is not None:
            plt.xlim(xlim)
        if ylim is not None:
            plt.ylim(ylim)
        plt.xscale(xscale)
        plt.yscale(yscale)
        plt.xlabel(xlabel,fontsize=16)
        plt.ylabel(ylabel,fontsize=16)
        plt.grid()
        plt.savefig(folder+savename,bbox_inches='tight', pad_inches=0)
    finally:
        plt.clf()
