import math 
import numpy as np
import sys

def _mixture_variances(d, sign):
    """Return the d per-coordinate variances of one mixture component.

    Every coordinate starts at 1 + sign / sqrt(d); the first third is then set
    to 1 + 2 * sign / sqrt(d) and the second third to 1 + 1.5 * sign / sqrt(d).
    """
    root_d = np.sqrt(d)
    variances = np.ones((d,))
    variances[:d] = 1.0 + sign * 1.0 / root_d
    # Floor division: slice bounds must be integers under Python 3.
    variances[:(d // 3)] = 1.0 + sign * 2.0 / root_d
    variances[(d // 3):(2 * d // 3)] = 1.0 + sign * 1.5 / root_d
    return variances


def gen_data(d, n, N, gamma, sigma, seed=None, func='quad'):
    """This function generates the data for the random feature problem.

    Arguments:
        d: The dimension of the data.
        n: The number of observations.
        N: The number of hidden units.
        gamma: a d-by-d matrix for generating the data.
        sigma: a d-vector for the covariance of the weights. It is rescaled
                to have mean one before use.
        seed: The random seed used for replication purposes. When None the
                global NumPy random state is left untouched.
        func: The function generating the data. Valid choices are
                {quad, mixture}.
    Returns:
        y: a vector of n-by-1 of the dependent variables.
        W: a matrix of N-by-d with Gaussian rows of covariance
                diag(sigma) / d (sigma after normalization).
        X: a matrix of n-by-d of Gaussian inputs. For 'quad' the entries are
                standard normal; for 'mixture' the first n // 2 rows and the
                remaining rows use two different diagonal covariances.
        stats: the tuple (zeta_1, zeta_2, y_stat) where
                zeta_1 = trace(diag(sigma) gamma) / d,
                zeta_2 = mean(sigma ** 2),
                y_stat = trace(gamma gamma) / d.
    Raises:
        ValueError: if func is not one of the valid choices.
    """
    assert len(sigma) == d
    assert gamma.shape[0] == d and gamma.shape[1] == d
    # Normalize the scaling
    sigma = sigma / np.mean(sigma)
    tau = np.diag(sigma)
    zeta_1 = np.trace(np.dot(tau, gamma)) / (d + 0.0)
    zeta_2 = np.mean(sigma ** 2.0)
    y_stat = np.trace(np.dot(gamma, gamma)) / (d + 0.0)

    if seed is not None:
        np.random.seed(seed)
    # W is always drawn before X so that a fixed seed gives a fixed (W, X).
    sigma_half = np.diag(sigma ** 0.5)
    W = np.random.normal(size=(N, d))
    W = np.dot(W, sigma_half)
    W = W / np.sqrt(d)

    if func == 'quad':
        X = np.random.normal(size=(n, d))
        mean = np.trace(gamma)
        # Row-wise quadratic form x^T gamma x, centered by trace(gamma).
        z = np.dot(X, gamma)
        z = np.sum(np.multiply(X, z), axis=1, keepdims=True) - mean
        y = z / np.sqrt(d)
    elif func == 'mixture':
        # Floor division keeps the slice bounds integral under Python 3
        # ("n / 2" is a float there and cannot be used as an index).
        half = n // 2
        y = np.zeros((n, 1))
        y[:half] = 1.0
        y[half:] = -1.0

        S1 = _mixture_variances(d, -1.0)
        S2 = _mixture_variances(d, 1.0)

        X = np.random.normal(size=(n, d))
        X[:half, :] = np.dot(X[:half, :], np.diag(S1 ** 0.5))
        X[half:, :] = np.dot(X[half:, :], np.diag(S2 ** 0.5))
    else:
        raise ValueError('%s is not a valid function'%(func))
    return (y, W, X, (zeta_1, zeta_2, y_stat))

def fit_model(y, W, X, sigma):
    """Fit the second layer of a random feature network by ridge regression.

    The features are Z = sigma(X W^T). The second-layer weights solve
    (Z^T Z / n + (eps / n) I) beta = Z^T y / n with eps fixed to 0.01.

    Arguments:
        y: a n-by-1 vector of dependent variables.
        W: a N-by-d matrix of random features.
        X: a n-by-d matrix of the independent variables.
        sigma: the random feature mapping, applied to the n-by-N matrix X W^T.

    Returns:
        err: Training MSE of the fitted model.
        var(y): Naive MSE from fitting a constant.
        beta: The N-by-1 solution for the second layer.
        bias: The value of the bias parameter, always 0.0.
    """
    ridge = 0.01
    num_units = W.shape[0]
    num_obs = X.shape[0] + 0.0

    features = sigma(np.dot(X, W.T))

    # Regularized second-moment matrix of the features.
    gram = np.dot(features.T, features) / num_obs
    gram = gram + ridge / num_obs * np.eye(num_units)

    # Feature/response cross moment; it must be a single column.
    cross = np.dot(features.T, y) / num_obs
    assert cross.shape[1] == 1

    beta = np.linalg.solve(gram, cross)
    residual = np.dot(features, beta) - y
    train_err = np.mean(np.power(residual, 2))
    return (train_err, np.var(y), beta, 0.0)

def mse(ah, n, W, gamma, sigma, func='quad'):
    """Estimate the out-of-sample MSE of a fitted model on fresh data.

    A new sample of size n is drawn with gen_data (using an all-ones weight
    covariance and no seed); the W drawn by that call is discarded and the
    supplied W is used instead.

    Arguments:
        ah: a N-by-1 vector of the second-layer weights.
        n: The number of samples to be used for averaging MSE.
        W: a N-by-d matrix of random features.
        gamma: The hyper-parameter for generating the ys.
        sigma: the random feature mapping.
        func: The function generating the data; forwarded to gen_data.

    Returns:
        err: The out-of-sample MSE of the data.
    """
    num_units = W.shape[0]
    dim = W.shape[1]
    unit_cov = np.ones((dim,))
    targets, _, inputs, _ = gen_data(dim, n, num_units, gamma, unit_cov, None, func)
    predictions = np.dot(sigma(np.dot(inputs, W.T)), ah)
    assert predictions.shape == targets.shape
    return np.mean(np.power(predictions - targets, 2))

def num_obs_exp(rho, gamma, sigma, nonlin, func, T, name, save=False):
    """Study the evolution of the MSE as the number of observations increases.

    For each of T runs a fresh W is drawn, then the model is refit for every
    sample size in a grid of 50 values between 30 * d and 300 * d.

    Arguments:
        rho: The aspect ratio; the number of hidden units is int(d / rho).
        gamma: a d-by-d matrix for generating the data.
        sigma: a d-vector for the covariance of the weights; d = len(sigma).
        nonlin: the random feature mapping passed to fit_model and mse.
        func: The function generating the data; forwarded to gen_data.
        T: The number of independent runs.
        name: A label inserted in the output file name.
        save: Whether to save the output on disk.
    Returns:
        values: a (50, T, 3) array; the last axis holds the out-of-sample MSE
                (on 10000 fresh samples), the training MSE and var(y).
        n_mat: The array of the number of observations used.
        N: The number of hidden units.
    """
    d = len(sigma)
    assert gamma.shape[0] == d and gamma.shape[1] == d
    N = int(d / rho)
    n_mat = d * np.linspace(30, 300, num=50)  # change the limits
    values = np.zeros((len(n_mat), T, 3))
    for run in range(T):
        # One set of random features per run, shared by all sample sizes.
        W = gen_data(d, 1, N, gamma, sigma, None, func)[1]
        for row, n_float in enumerate(n_mat):
            n = int(n_float)
            print('Iteration Number: %d, Number of obs used %d'%(run, n))
            y, _, X, _ = gen_data(d, n, N, gamma, sigma, None, func)
            train_err, naive_err, ah, _ = fit_model(y, W, X, nonlin)
            values[row, run, 1] = train_err
            values[row, run, 2] = naive_err
            values[row, run, 0] = mse(ah, 10000, W, gamma, nonlin, func)
    if save:
        np.save('gaussian_data/GaussianRF_Evolution_d_%d_N_%d_%s_%s.npy'%(d, N, func, name), values)
    return (values, n_mat, N)

def final_mse(index, gamma, sigma, nonlin, func, T, name, save=False):
    """Measure the MSE at a large sample size for one aspect ratio.

    The aspect ratio rho is picked from a fixed grid of 40 values
    (30 powers of two from 2**-5 to 2**0, followed by 10 more up to 2**4)
    and each run fits the model on 5 * 10**5 observations.

    Arguments:
        index: Position in the rho grid; converted with int().
        gamma: a d-by-d matrix for generating the data.
        sigma: a d-vector for the covariance of the weights; d = len(sigma).
        nonlin: the random feature mapping passed to fit_model and mse.
        func: The function generating the data; forwarded to gen_data.
        T: The number of independent runs.
        name: A label inserted in the output file name.
        save: Whether to save the output on disk. When set, the partial
                results are rewritten after every run.
    Returns:
        values: a (T, 3) array; the columns hold the out-of-sample MSE
                (on 10000 fresh samples), the training MSE and var(y).
    """
    d = len(sigma)
    assert gamma.shape[0] == d and gamma.shape[1] == d
    n = int(5 * 10 ** 5)

    # Grid of aspect ratios; the duplicated 2**0 entry is dropped.
    below_one = 2 ** np.linspace(-5, 0, num=30)
    above_one = 2 ** np.linspace(0, 4, num=11)
    rho_mat = np.concatenate([below_one, above_one[1:]])

    rho = rho_mat[int(index)]
    N = int(d / rho)
    values = np.zeros((T, 3))
    for run in range(T):
        print('Realization Number: %d, Number of obs used: %d, Number of units: %d'%(run, n, N))
        y, W, X, _ = gen_data(d, n, N, gamma, sigma, None, func)
        train_err, naive_err, ah, _ = fit_model(y, W, X, nonlin)
        values[run, 1] = train_err
        values[run, 2] = naive_err
        values[run, 0] = mse(ah, 10000, W, gamma, nonlin, func)
        if save:
            np.save('mixture/final_mse/finalmse_d_%d_N_%d_%s_%s.npy'%(d, N, func, name), values)
    return values

# Command-line driver.
# Usage: <script> d T rho mode experiment_type
#   d: data dimension, T: number of runs, rho: aspect ratio,
#   mode: 'naive' or 'optim', experiment_type: 'num_obs' or 'final_mse'.
d = int(sys.argv[1])
T = int(sys.argv[2])
# The np.float alias was removed in NumPy 1.24; the builtin is equivalent.
rho = float(sys.argv[3])
mode = sys.argv[4]
experiment_type = sys.argv[5]
func = 'mixture'
name = 'square'

print('Experiment Summary:')
print('The data dimension is %d'%(d))
print('Number of experiments is %d'%(T))
print('Aspect ratio rho is %f'%(rho))
print('Data generating function is %s'%(func))
print('acitvation is %s'%(name))
print('Sigma is chosen according to the mode: %s'%(mode))
print('Experiment type is: %s'%(experiment_type))
sys.stdout.flush()

if name == 'relu':
    def nonlinearity(x):
        """ReLU shifted by the constant 1 / sqrt(2 * pi)."""
        return np.maximum(x, 0) - 1.0 / np.sqrt(2 * math.pi)
else:
    def nonlinearity(x):
        """Square activation shifted by the constant 1."""
        return np.power(x, 2) - 1.0

# Fixed seed so that gamma is the same for every invocation.
np.random.seed(100)
gamma = np.random.exponential(1, size=(d,))
gamma = np.diag(gamma)
if mode == 'naive':
    sigma = np.ones((d,))
elif mode == 'optim':
    # np.diag on a 2-D array extracts its diagonal as a vector.
    sigma = np.diag(gamma)
else:
    # Previously sigma was left undefined and the run failed with a NameError.
    raise ValueError("Unknown mode %r; expected 'naive' or 'optim'" % (mode,))
expname = name + '_' + mode + '_' + experiment_type
if experiment_type == 'num_obs':
    vals = num_obs_exp(rho, gamma, sigma, nonlinearity, func, T, expname, True)
elif experiment_type == 'final_mse':
    # NOTE: here the third CLI argument is passed as final_mse's index, i.e.
    # it is truncated to an int and used to select rho from a fixed grid.
    vals = final_mse(rho, gamma, sigma, nonlinearity, func, T, expname, True)
