import numpy as np
import time
from coinpress_mod.functions import coinpress_mean
from quantile_binary_search.method import random_rotation_mean, clipped_mean
from tests_functions import test_dim
from scipy import stats
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


# Privacy parameters swept along the x-axis (plotted as rho; an earlier
# revision converted eps via eps*eps*0.5, suggesting zCDP -- confirm).
p_l = [0.1, 0.2, 0.4, 0.8, 1.6]
# COINPRESS iteration counts compared, with the plot colour for each.
cp_colors = {1: 'darkviolet', 2: 'red', 3: 'magenta', 4: 'gray', 10: 'olive'}
cp_iters = list(cp_colors)
n_trials = 30
# Fraction trimmed from EACH end by scipy.stats.trim_mean when averaging trials.
trim = 0.1


def _pad_width(d, minimum=1024):
    """Return the padded width for the rotation estimator.

    Returns ``minimum`` (the historical hard-coded 1024) unless ``d`` exceeds
    it, in which case the next power of two >= ``d`` is used so the padding
    assignment cannot fail. Power-of-two width is presumably required by the
    rotation in ``random_rotation_mean`` -- confirm.
    """
    return max(minimum, 1 << max(0, (d - 1).bit_length()))


def _pad_columns(x, d_pad):
    """Return ``x`` zero-padded on the right to ``d_pad`` columns (new array)."""
    x_pad = np.zeros((x.shape[0], d_pad))
    x_pad[:, :x.shape[1]] = x
    return x_pad


for cc in range(10):
    data = np.load('./data/syn_data_{}.npy'.format(cc))
    # Column 0 is dropped (presumably a label -- confirm); values scaled by 1/255.
    x = data[:, 1:]/255.

    n, d = x.shape
    # Non-private empirical mean; every estimate is scored by l2 distance to it.
    mean = np.mean(x, axis=0)
    r = 50*np.sqrt(d)   # radius argument passed to coinpress_mean
    u = 2*r             # bound argument passed to random_rotation_mean

    # Build the padded matrix once per dataset instead of once per trial.
    d_pad = _pad_width(d)
    x_pad_template = _pad_columns(x, d_pad)

    cp_err = {t: [] for t in cp_iters}
    rr_err_l = []

    for p in p_l:
        cp_trials = {t: [] for t in cp_iters}
        rr_l = []

        for _ in range(n_trials):
            c = [0.0]*d  # centre argument; fresh per trial as in the original
            # Same call order as before (t=1,2,3,4,10 then rotation) so any
            # shared random state is consumed identically.
            for t in cp_iters:
                est = coinpress_mean(x, c, r, t, p)
                cp_trials[t].append(np.linalg.norm(est - mean))

            # Pass a copy in case random_rotation_mean modifies its input
            # in place (not visible here); drop .copy() once confirmed safe.
            y_hat = random_rotation_mean(x_pad_template.copy(), d_pad, u, p)
            rr_l.append(np.linalg.norm(y_hat[:d] - mean))

        for t in cp_iters:
            cp_err[t].append(stats.trim_mean(cp_trials[t], trim))
        rr_err_l.append(stats.trim_mean(rr_l, trim))

    # Saved rows, in order: COINPRESS t=1, 2, 3, 4, 10, then Shifted-CM.
    # Reload with np.loadtxt to re-plot without rerunning the trials.
    rows = [cp_err[t] for t in cp_iters] + [rr_err_l]
    np.savetxt('./mnist_{}.txt'.format(cc), np.array(rows))

    for t in cp_iters:
        plt.plot(p_l, cp_err[t], color=cp_colors[t], marker='+',
                 label=r'COINPRESS $t={}$'.format(t))
    plt.plot(p_l, rr_err_l, color='green', marker='*', label='Shifted-CM')

    plt.legend(loc='upper right', framealpha=0.4)
    plt.xlim((0.08, 2.2))
    # `base` replaces `basex`, which was removed in matplotlib 3.5.
    plt.xscale('log', base=2)
    plt.yscale('log')
    plt.xlabel(r'$\rho$', fontsize=16)
    plt.ylabel(r'$\ell_2$ error', fontsize=16)
    plt.grid()
    plt.savefig('./mnist_{}.pdf'.format(cc), bbox_inches='tight', pad_inches=0)
    plt.clf()



