import time
from itertools import permutations

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.stats
import seaborn as sns
from sklearn import metrics
from tqdm import tqdm

def sigmaSquareHat(X):
    """Estimate tr(Sigma^2) from the rows of ``X`` with a U-statistic.

    The returned value equals the average, over all ordered 4-tuples of
    distinct row indices (i, j, k, l), of ((X_i - X_j)'(X_k - X_l))^2 / 4.
    It is computed in O(n^2) from the Gram matrix rather than by looping
    over n^4 tuples. Because only row differences enter, the result does
    not depend on the (unknown) mean of the rows. For i.i.d. rows with
    covariance Sigma its expectation is tr(Sigma^2).

    Args:
        X: (n, p) data matrix, one observation per row.

    Returns:
        The estimate S1 - 2 * S2 + S3, where S1, S2 and S3 are averages of
        H_ij^2, H_ij H_jk and H_ij H_kl over distinct index pairs, triples
        and quadruples of the Gram matrix H = X X'.

    Raises:
        ValueError: if ``X`` has fewer than 4 rows (S3 is undefined).
    """
    X = np.asarray(X, dtype=float)
    n = X.shape[0]
    if n < 4:
        raise ValueError("sigmaSquareHat needs at least 4 observations")
    H = X @ X.T
    # Zero the diagonal so every sum below automatically excludes i == j.
    A = H - np.diag(np.diag(H))
    sum_a2 = np.sum(A * A)            # sum_{i != j} H_ij^2 = tr(A^2)
    col = A.sum(axis=0)               # c_j = sum_{i != j} H_ij
    sum_col2 = col @ col              # 1' A^2 1
    total = col.sum()                 # 1' A 1

    # sum_{i != j} H_ij^2
    S1 = sum_a2 / (n * (n - 1))
    # sum over distinct (i, j, k) of H_ij H_jk: all j-centred pairs minus i == k
    S2 = (sum_col2 - sum_a2) / (n * (n - 1) * (n - 2))
    # sum over distinct (i, j, k, l) of H_ij H_kl by inclusion-exclusion:
    # (1'A1)^2 - 4 * (one shared index) + 2 * ({k, l} == {i, j})
    S3 = (total ** 2 - 4 * sum_col2 + 2 * sum_a2) / (n * (n - 1) * (n - 2) * (n - 3))
    return S1 - 2 * S2 + S3

def sigma_square_hat(Y):
    """Return the unbiased sample variance of the entries of ``Y``.

    ``n`` is the length of the first axis, so ``Y`` may be 1-D of length n
    or an (n, 1) column.

    The data are centred before squaring. The one-pass form
    ``sum(Y**2) - n * mean(Y)**2`` loses all precision when the mean is
    large relative to the spread (catastrophic cancellation).
    """
    Y = np.asarray(Y, dtype=float)
    n = np.shape(Y)[0]
    centered = Y - np.mean(Y)
    return np.sum(centered ** 2) / (n - 1)

def cgz_t(X, Y):
    """Return the CGZ statistic T(n, p).

    T = n / ((n - 1)(n - 2)^2) * sum_{i != j} DeltaX_ij * DeltaY_ij, where

        DeltaX_ij = (X_i - Xbar)'(X_j - Xbar) + ||X_i - X_j||^2 / (2n)
        DeltaY_ij = (Y_i - Ybar)(Y_j - Ybar) + (Y_i - Y_j)^2 / (2n).

    Both Delta matrices depend only on centred data and pairwise
    differences, so T is invariant to shifting X or Y by a constant.

    Args:
        X: (n, p) design matrix, one observation per row.
        Y: length-n response (any shape that flattens to n values).

    Raises:
        ValueError: if there are fewer than 3 observations.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float).reshape(-1)
    # Sample size comes from the data, not from a module-level variable.
    n = X.shape[0]
    if n < 3:
        raise ValueError("cgz_t needs at least 3 observations")

    # DeltaX
    X_c = X - X.mean(axis=0)
    H_c = X_c @ X_c.T
    d = np.diag(H_c)
    sq_dist_x = d[:, None] + d[None, :] - 2 * H_c      # ||X_i - X_j||^2
    delta_x = H_c + sq_dist_x / (2 * n)

    # DeltaY; squared differences keep it consistent with DeltaX.
    Y_c = Y - Y.mean()
    sq_dist_y = np.subtract.outer(Y, Y) ** 2           # (Y_i - Y_j)^2
    delta_y = np.outer(Y_c, Y_c) + sq_dist_y / (2 * n)

    # Elementwise (Hadamard) product, summed over off-diagonal entries.
    prod = delta_x * delta_y
    off_diag_sum = prod.sum() - np.trace(prod)
    return n / ((n - 1) * (n - 2) ** 2) * off_diag_sum

def cgz(X, Y):
    """Return the standardized CGZ test statistic (a z-score).

    Computes ``n * cgz_t(X, Y) / (sigma_square_hat(Y) * sqrt(2 * sigmaSquareHat(X)))``.
    Here ``sigma_square_hat(Y)`` is the sample variance of ``Y``, and
    ``sigmaSquareHat(X)`` is intended as an estimate of tr(Sigma^2).

    Args:
        X: (n, p) design matrix, one observation per row.
        Y: length-n response.
    """
    # Use this data set's sample size instead of a module-level ``n``.
    n = np.shape(X)[0]
    return n * cgz_t(X, Y) / (sigma_square_hat(Y) * np.sqrt(2 * sigmaSquareHat(X)))

def zc_t(X, Y):
    """Return the 4-index U-statistic (presumably Zhong & Chen's).

    The statistic averages, over all ordered 4-tuples (i, j, k, l) of
    distinct rows,

        (X_i - X_j)'(X_k - X_l) * (y_i - y_j) * (y_k - y_l) / 4.

    It is evaluated exactly in O(n^2) instead of with an O(n^4) loop. By
    relabelling symmetry, the four Gram terms of (X_i - X_j)'(X_k - X_l)
    contribute equally, so the sum equals

        sum_{i != k} H_ik * W_ik,
        W_ik = sum_{j != l; j, l not in {i, k}} (y_i - y_j)(y_k - y_l).

    Here W_ik = A_ik * A_ki - B_ik, where A_ik = sum_{j not in {i, k}} (y_i - y_j)
    and B_ik is the "diagonal" j == l part.

    Args:
        X: (n, p) design matrix, one observation per row.
        Y: responses, either an (n, m) array (column 0 is used, as before)
            or a 1-D array of length n.

    Raises:
        ValueError: if there are fewer than 4 observations.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(Y, dtype=float)
    if y.ndim == 2:
        y = y[:, 0]
    n = X.shape[0]
    if n < 4:
        raise ValueError("zc_t needs at least 4 observations")

    H = X @ X.T
    s = y.sum()
    q = y @ y
    yi = y[:, None]
    yk = y[None, :]
    # A[i, k] = sum_{j not in {i, k}} (y_i - y_j) = (n - 1) y_i + y_k - s
    A = (n - 1) * yi + yk - s
    # B[i, k] = sum_{j not in {i, k}} (y_i - y_j)(y_k - y_j). The j = i and
    # j = k terms are zero, so this is the sum over all j.
    B = n * yi * yk - (yi + yk) * s + q
    W = A * A.T - B
    np.fill_diagonal(W, 0.0)          # only pairs with i != k contribute
    return np.sum(H * W) / (n * (n - 1) * (n - 2) * (n - 3))

def zc(X, Y):
    """Return the standardized Zhong-Chen-type statistic (a z-score).

    Computes ``n * zc_t(X, Y) / (sigma_square_hat(Y) * sqrt(2 * sigmaSquareHat(X)))``.
    ``sigmaSquareHat(X)`` is intended as an estimate of tr(Sigma^2).

    Args:
        X: (n, p) design matrix, one observation per row.
        Y: responses in the layout ``zc_t`` accepts.
    """
    # Use this data set's sample size instead of a module-level ``n``.
    n = np.shape(X)[0]
    return n * zc_t(X, Y) / (sigma_square_hat(Y) * np.sqrt(2 * sigmaSquareHat(X)))

def generate_beta(p):
    """Draw a random coefficient vector of length ``p``.

    Every entry is 1 + Binomial(5, 0.6) + 0.3 * N(0, 1). The draws use the
    global NumPy RNG, binomial draws first, then the normal draws.
    """
    base = np.ones(shape=p)
    counts = np.random.binomial(n=5, p=0.6, size=p)
    jitter = 0.3 * np.random.normal(size=p)
    return base + counts + jitter

# Empirical distribution of the averaged sketched F-statistic (random alternative).
def signal_dist_empirical(p, n, k, N, n_sketch=4):
    """Simulate ``N`` draws of the sketched F-statistic, averaged over sketches.

    Each replicate draws:
      * a design ``X`` (n x p) of standard normals whose j-th column is
        scaled by 1/j;
      * a coefficient vector from ``generate_beta``, normalised so that
        ``||diag(1/j) beta||_2 == 1``;
      * ``Y = X beta + eps`` with ``eps ~ N(0, 1/n^2)``.

    For each of ``n_sketch`` Gaussian sketches ``S`` (p x k) it computes

        F = (Y' P Y / k) / ((||Y||^2 - Y' P Y) / (n - k)),

    where ``P`` projects onto the column space of ``X S``. The replicate's
    value is the mean of these ``n_sketch`` statistics.

    Args:
        p: number of covariates.
        n: sample size.
        k: sketch dimension (number of columns of ``X S``).
        N: number of Monte Carlo replicates.
        n_sketch: sketches averaged per replicate (formerly hard-coded 4).

    Returns:
        ndarray of shape (N,), one averaged F-statistic per replicate.
    """
    # Column scales 1, 1/2, ..., 1/p are the same for every replicate.
    col_scale = np.array([1 / j for j in range(1, p + 1)])
    signal = np.zeros(N)
    for i in tqdm(range(N)):
        beta = generate_beta(p)
        beta = beta / np.sqrt(np.sum((col_scale * beta) ** 2))
        X = col_scale * np.random.normal(size=[n, p])
        Y = X @ beta + np.random.normal(size=n) / n
        total_ss = np.sum(Y ** 2)
        f_stats = np.empty(n_sketch)
        for r in range(n_sketch):
            S = np.random.normal(size=[p, k])
            XS = X @ S
            # Project Y onto span(XS) by least squares. This avoids forming
            # inv(XS'XS) and the n x n hat matrix explicitly.
            coef = np.linalg.lstsq(XS, Y, rcond=None)[0]
            num = Y @ (XS @ coef)           # Y' P Y
            den = total_ss - num            # residual sum of squares
            f_stats[r] = (num * (n - k)) / (den * k)
        signal[i] = f_stats.mean()
    return signal

# Empirical distribution of the CGZ z-score under a random alternative.
def signal_dist_empirical_cgz(p, n, k, N):
    """Simulate ``N`` CGZ z-scores under a random (non-null) alternative.

    The data-generating process matches ``signal_dist_empirical``:
      * column j of the n x p design is scaled by 1/j;
      * beta comes from ``generate_beta``, normalised so that
        ``||diag(1/j) beta||_2 == 1``;
      * the noise has standard deviation 1/n.

    ``k`` is accepted for signature parity but is not used.

    Returns:
        ndarray of shape (N,) holding ``cgz(X, Y)`` for each replicate.
    """
    col_scale = np.array([1 / j for j in range(1, p + 1)])
    scores = np.zeros(N)
    for rep in tqdm(range(N)):
        coef = generate_beta(p)
        coef = coef / np.sqrt(np.sum((col_scale * coef) ** 2))
        design = col_scale * np.random.normal(size=[n, p])
        response = design @ coef + np.random.normal(size=n) / n
        scores[rep] = cgz(design, response)
    return scores

# Empirical distribution of the CGZ z-score under the null (beta = 0).
def signal_dist_empirical_cgz_null(p, n, k, N):
    """Simulate ``N`` CGZ z-scores under the global null beta = 0.

    The design columns are scaled by 1/j. The response is pure noise with
    standard deviation 1/n. ``k`` is accepted for signature parity but is
    not used.

    Returns:
        ndarray of shape (N,) holding ``cgz(X, Y)`` for each replicate.
    """
    col_scale = np.array([1 / j for j in range(1, p + 1)])
    null_beta = np.zeros(shape=p)
    scores = np.zeros(N)
    for rep in tqdm(range(N)):
        design = col_scale * np.random.normal(size=[n, p])
        response = design @ null_beta + np.random.normal(size=n) / n
        scores[rep] = cgz(design, response)
    return scores

# Simulation settings: N Monte Carlo replicates per dimension. ``alpha`` is
# the quantile level of the F critical value (an upper-5% test).
N = 200
alpha = 0.95

p_lst = [100, 300, 1000, 2000, 5000, 10000]
# Rejection rates per dimension for the CGZ z-test and the sketched F-test.
error_lst = [0] * len(p_lst)
error_cgz = [0] * len(p_lst)
for i, p in enumerate(p_lst):
    # Sample size grows like 10 * log(p)^2; the sketch keeps half of it.
    # ``n`` stays a module-level name on purpose (older code reads it).
    n = int(10 * np.log(p)**2)
    k = n // 2
    print(p, n, k, N)
    signal = signal_dist_empirical_cgz(p, n, k, N)
    # Two-sided rejection at |z| > 2.
    error_cgz[i] = np.mean(np.abs(signal) > 2)
    signal = signal_dist_empirical(p, n, k, N)
    f_crit = scipy.stats.f.ppf(alpha, k, n - k)
    error_lst[i] = np.mean(signal > f_crit)
    print(error_cgz[i], error_lst[i])
np.save('empirical_2.npy', error_lst)
np.save('empirical_cgz_2.npy', error_cgz)
