import numpy as np
import scipy.linalg as l
from numpy.linalg import cholesky, det
from numpy.linalg import inv, pinv
from numpy.linalg import det


def exponential_cov(x, y, params):
    '''
    Description: Squared exponential kernel evaluated on every pair drawn from x and y.
    Parameters:
        - x, y: Inputs whose pairwise differences are formed with np.subtract.outer.
        - params: Indexable; params[0] multiplies the kernel value and params[1]
          multiplies the squared difference inside the exponent.
    Returns: params[0] * exp(-0.5 * params[1] * (x_i - y_j)**2) for every pair (i, j).
    '''
    pairwise_diff = np.subtract.outer(x, y)
    exponent = -0.5 * params[1] * pairwise_diff**2
    return params[0] * np.exp(exponent)

def conditional(x_new, x, y, params):
    '''
    Description: Gaussian conditioning rule under the exponential_cov kernel.
    Parameters: 
        - x_new: New points to be predicted by GP.
        - x, y: Existing input and output pairs.
        - params: Kernel parameters, forwarded unchanged to exponential_cov.
    Returns: Tuple (mu, sigma), each passed through .squeeze():
        - mu = B C^-1 y
        - sigma = A - B C^-1 B^T
      with A = k(x_new, x_new), B = k(x_new, x) and C = k(x, x).
    '''

    B = exponential_cov(x_new, x, params)
    C = exponential_cov(x, x, params)
    A = exponential_cov(x_new, x_new, params)

    # Solve C Z = B^T once and reuse Z for both outputs. This avoids forming
    # inv(C) twice and is better conditioned than multiplying by an explicit
    # inverse.
    C_inv_Bt = np.linalg.solve(C, B.T)

    mu = C_inv_Bt.T.dot(y)
    sigma = A - B.dot(C_inv_Bt)
    return(mu.squeeze(), sigma.squeeze())

def predict(x, x_star, kernel, params, K, f):
    '''
    Description: Gaussian conditioning rule for a caller-supplied kernel.
    Parameters: 
        - x: First argument of every kernel call; kernel(x, x, params) is used
          as the prior covariance term K_ss.
        - x_star: Iterable of points; one entry kernel(x, x_, params) is built
          per element to form K_s.
        - kernel: Kernel function used, called as kernel(a, b, params).
        - params: Kernel parameters, forwarded unchanged to kernel.
        - K: Square covariance matrix; must be conformable with K_s.
        - f: Signal values; must be conformable with K.
    Returns: Tuple (mean, cov) with
        - mean = K_s K^-1 f
        - cov = K_ss - K_s K^-1 K_s
    NOTE(review): K_s has one entry per element of x_star and is multiplied by
    K^-1, so x_star appears to play the role of the observed inputs (those K
    and f belong to) and x the point being predicted - confirm against callers.
    '''

    K_s = np.asarray([kernel(x, x_, params) for x_ in x_star])
    K_ss = kernel(x, x, params)

    # weights = K_s K^-1, obtained from one linear solve (K^T W^T = K_s^T)
    # instead of an explicit inverse, and reused for both outputs.
    weights = np.linalg.solve(np.asarray(K).T, K_s.T).T

    mean = weights.dot(f)
    cov = K_ss - weights.dot(K_s)

    return mean, cov

def predict_noisy(x, x_star, kernel, params, K, f, noise=0.5):
    '''
    Description: Gaussian conditioning rule with noise.
    Parameters: 
        - x: First argument of every kernel call; kernel(x, x, params) is used
          as the prior covariance term K_ss.
        - x_star: Iterable of points; one entry kernel(x, x_, params) is built
          per element to form K_s.
        - kernel: Kernel function used, called as kernel(a, b, params).
        - params: Kernel parameters, forwarded unchanged to kernel.
        - K: Square covariance matrix; noise**2 is added to its diagonal
          before it is used.
        - f: Signal values; must be conformable with K.
        - noise: Noise level; enters only as noise**2 on the diagonal of K.
    Returns: Tuple (mean, cov) with
        - mean = K_s (K + noise**2 I)^-1 f
        - cov = K_ss - K_s (K + noise**2 I)^-1 K_s
    NOTE(review): K_s has one entry per element of x_star and is multiplied by
    the inverse of the noisy K, so x_star appears to play the role of the
    observed inputs and x the point being predicted - confirm against callers.
    '''

    K_s = np.asarray([kernel(x, x_, params) for x_ in x_star])
    K_ss = kernel(x, x, params)
    K_noisy = np.asarray(K) + (noise**2)*np.eye(len(K))

    # weights = K_s K_noisy^-1, obtained from one linear solve instead of an
    # explicit inverse, and reused for both outputs.
    weights = np.linalg.solve(K_noisy.T, K_s.T).T

    mean = weights.dot(f)
    cov = K_ss - weights.dot(K_s)

    return mean, cov


def ARD_kernel(x_idx, y_idx, Q, Lambda, size, beta=5, sigma=1, noise = 0):
    '''
    Description: Diffusion kernel to compare graphs, rescaled to a unit diagonal
    and sliced to the requested nodes.
    Parameters: 
        - x_idx, y_idx: Indices for the nodes we want to compare; x_idx selects
          rows and y_idx selects columns of the Gram matrix.
        - Q, Lambda: Eigenvector and eigenvalue matrices; the Gram matrix is
          Q expm(-beta * Lambda) Q^T.
        - size: Dimension of the identity used to build the normalising matrix;
          must match the dimension of the Gram matrix.
        - beta: Multiplier applied to Lambda inside the matrix exponential.
        - sigma: Factor multiplying the selected slice.
        - noise: noise**2 is added to every entry of the returned slice.
    '''

    gram = Q.dot(l.expm(-beta*Lambda)).dot(Q.T)

    # Diagonal matrix holding 1/sqrt(gram[i, i]); applying it on both sides
    # rescales the Gram matrix so that its diagonal becomes 1.
    inv_root_diag = 1/np.sqrt(np.diag(gram))
    scaler = np.eye(size)*inv_root_diag
    normalised = scaler.dot(gram).dot(scaler.T)

    # Columns are selected first, then rows, as two separate indexing steps.
    selected = normalised[:,y_idx][x_idx,:]
    return selected*sigma + noise**2

def predict_GGP(x, x_star, kernel, Q, Lambda, beta, K, f, size, noise=0.1, sigma=1):
    '''
    Description: Gaussian conditioning rule for graphs.
    Parameters: 
        - x: Non-empty iterable of node indices. Only its first element is used
          for the cross-covariance K_s; kernel(x, x, ...)[0, 0] is used as the
          prior variance K_ss.
        - x_star: Node indices passed as the second argument of the
          cross-covariance kernel call.
        - kernel: Kernel function used, called as
          kernel(a, b, Q, Lambda, size, beta=beta, sigma=sigma).
        - Q, Lambda, size, beta, sigma: Forwarded unchanged to kernel.
        - K: Square covariance matrix; noise**2 is added to its diagonal
          before it is used.
        - f: Signal values; must be conformable with K.
        - noise: Noise level; enters only as noise**2 on the diagonal of K.
    Returns: Tuple (mean, cov) with
        - mean = K_s (K + noise**2 I)^-1 f
        - cov = K_ss - K_s (K + noise**2 I)^-1 K_s
    Raises: IndexError if x is empty.
    NOTE(review): K_s is indexed by x_star and multiplied by the inverse of the
    noisy K, so x_star appears to hold the observed nodes and x the node being
    predicted - confirm against callers.
    '''

    # Only the cross-covariance of the first element of x is ever used, so
    # evaluate the kernel once instead of once per element of x.
    for x_first in x:
        break
    else:
        raise IndexError("x must contain at least one node index")

    K_s = np.asarray(kernel(x_first, x_star, Q, Lambda, size, beta=beta, sigma=sigma))
    K_ss = kernel(x, x, Q, Lambda, size, beta=beta, sigma=sigma)[0,0]
    K_noisy = np.asarray(K) + (noise**2)*np.eye(len(K))

    # weights = K_s K_noisy^-1, obtained from one linear solve instead of an
    # explicit inverse, and reused for both outputs.
    weights = np.linalg.solve(K_noisy.T, K_s.T).T

    mean = weights.dot(f)
    cov = K_ss - weights.dot(K_s)

    return mean, cov


def ARD_kernel_nll(x_idx, y_idx, Q, Lambda, size, beta=1, sigma=1, noise=0):
    '''
    Description: Diffusion kernel rescaled to a unit diagonal, restricted to the
    requested nodes, with noise**2 added along the diagonal of the result.
    Parameters:
    - x_idx, y_idx: Indices for the nodes we want to compare, which define slice on Gram matrix. 
    - Q, Lambda: Eigenvector and eigenvalue matrices; the Gram matrix is
      Q expm(-beta * Lambda) Q^T.
    - size: Dimension of the identity used to build the normalising matrix;
      must match the dimension of the Gram matrix.
    - beta: Multiplier applied to Lambda inside the matrix exponential.
    - sigma: Accepted but not used by this function.
    - noise: noise**2 times an identity (sized by the number of selected rows)
      is added to the selected block.
    '''
    gram = Q.dot(l.expm(-beta*Lambda)).dot(Q.T)

    # Diagonal matrix holding 1/sqrt(gram[i, i]); applying it on both sides
    # rescales the Gram matrix so that its diagonal becomes 1.
    scaler = np.eye(size)*(1/np.sqrt(np.diag(gram)))
    gram = scaler.dot(gram).dot(scaler.T)

    # np.ix_ builds the open mesh so rows x_idx and columns y_idx are taken together.
    block = gram[np.ix_(x_idx,y_idx)]
    jitter = noise**2*np.eye(block.shape[0])

    return np.array(block) + jitter


def nll_fn_graph(X_train, Y_train, Q, Lambda, size, noise=0, naive=True):
    '''
    Description: Build the Gaussian negative log likelihood of Y_train under the
    ARD_kernel_nll covariance, as a function of the kernel parameter vector.
    Parameters:
    - X_train: Node indices of the training points; used for both the rows and
      the columns of the covariance block.
    - Y_train: Array of training signal values; flattened with .ravel().
    - Q, Lambda, size: Forwarded unchanged to ARD_kernel_nll.
    - noise: Forwarded to ARD_kernel_nll as its noise argument.
    - naive: Accepted but not used; the same function is returned either way.
    Returns: Function nll_naive(theta) where theta[0] is passed as beta, giving
        0.5 * log det K + 0.5 * Y^T pinv(K) Y + 0.5 * n * log(2 pi)
      with n = len(X_train).
    '''

    Y_train = Y_train.ravel()
    
    def nll_naive(theta):
        K = ARD_kernel_nll(X_train, X_train, Q, Lambda, size, beta=theta[0], noise=noise) 

        # slogdet works on the log scale, so the determinant cannot underflow
        # to 0 or overflow to inf the way det(K) can for larger K.
        # A negative determinant still yields nan, and a zero determinant
        # -inf, matching what np.log(det(K)) produced.
        sign, log_abs_det = np.linalg.slogdet(K)
        log_det = log_abs_det if sign >= 0 else np.nan
       
        return 0.5 * log_det + \
               0.5 * Y_train.dot(pinv(K).dot(Y_train)) + \
               0.5 * len(X_train) * np.log(2*np.pi)
        

    return nll_naive