import numpy
import random
import scipy.linalg
from sklearn.linear_model import HuberRegressor

# Multiplier applied to both filtering thresholds computed in gmm_sever.
FILTER_CONSTANT = 2
# gmm_sever skips its first (gradient-direction) filter when the norm of the
# mean moment vector is at most this value.
tol = 1e-8

class FilterData:
    """One Sever-style filtering step over a weighted set of vectors.

    ``xi`` is an ``m x p`` matrix with one vector per row, and ``present`` is a
    length-``m`` weight vector marking the rows still in play (0/1 as used in
    this module). The mean and covariance are computed over the present rows
    only. :meth:`threshold` may zero out further entries of ``present``.
    """

    def __init__(self, xi, present):
        self.xi = xi  # m x p
        self.present = present  # m
        n_present = numpy.sum(self.present)
        self.mu = (self.present @ self.xi) / n_present
        # Weighted second moment minus the OUTER product mu mu^T. ``mu`` is 1-D,
        # so ``mu @ mu.T`` would be the scalar ||mu||^2 and would wrongly be
        # subtracted from every entry of the matrix.
        second_moment = xi.T @ (xi * self.present[:, None]) / n_present
        self.cov = second_moment - numpy.outer(self.mu, self.mu)

    def mean(self):
        """Return the weighted mean (length p) of the present rows."""
        return self.mu

    def covariance(self):
        """Return the p x p weighted covariance of the present rows."""
        return self.cov

    def threshold(self, v, t):
        """Randomly drop rows that are outlying along direction ``v``.

        Each row i gets the score ``((xi_i - mu) . v)**2``; rows not present
        score zero. If the average score over the present rows is below ``t``,
        nothing changes and ``False`` is returned. Otherwise a cutoff is drawn
        uniformly from ``[0, max score]``. Every row scoring above the cutoff
        is removed from ``self.present``, and ``True`` is returned.

        Args:
            v: direction of shape (p,) or (p, 1), e.g. the top eigenvector of
                :meth:`covariance`.
            t: threshold on the average score.

        Returns:
            True if rows were filtered (``self.present`` updated), else False.
        """
        scores = numpy.ravel(((self.xi - self.mu[None, :]) @ v) ** 2) * self.present
        # Average over the present rows, using the same normalisation as the
        # covariance, so this compares v^T C v against t. A plain mean over
        # all m rows would be diluted by the zero scores of removed rows.
        if numpy.sum(scores) / numpy.sum(self.present) < t:
            return False
        cutoff = random.uniform(0, numpy.max(scores))
        self.present = self.present * (scores <= cutoff)
        return True
    

class IV_Moments:
    """Linear instrumental-variable moment conditions with a per-sample mask.

    For sample i the moment vector is ``z_i * (y_i - x_i . w)`` (length p).
    Every statistic is averaged over the samples, weighted by ``present``.
    Filtering code removes samples by zeroing their entries.
    """

    def __init__(self, X, Y, Z):
        self.X = numpy.array(X)  # m x d regressors
        self.Y = numpy.array(Y)  # m responses
        self.Z = numpy.array(Z)  # m x p instruments
        self.m = self.X.shape[0]
        self.d = self.X.shape[1]
        self.p = self.Z.shape[1]
        self.present = numpy.ones((self.m))  # per-sample weight: 1 kept, 0 removed

    def reset(self):
        """Mark every sample as present again."""
        self.present = numpy.ones((self.m))

    def _n_present(self):
        """Return the total weight of the present samples (the averaging denominator)."""
        return numpy.sum(self.present)

    def _masked(self, A):
        """Return the m x k matrix ``A`` with each row scaled by its sample weight.

        This replaces ``numpy.diag(self.present) @ A``, which materialises a
        dense m x m matrix.
        """
        return A * self.present[:, None]

    def compute_gradients(self):
        """Store in ``self.gradients`` the p x d matrix -(1/n) sum_i present_i z_i x_i^T."""
        self.gradients = -(self.Z.T @ self._masked(self.X)) / self._n_present()

    def moment_mean(self, w):
        """Return the weighted mean of ``z_i (y_i - x_i . w)`` (length p)."""
        residuals = self.Y - self.X @ w
        return self.Z.T @ (residuals * self.present) / self._n_present()

    def gradient_mean(self, w, u):
        """Return the weighted mean of ``-z_i (x_i . u)`` (length p); ``w`` is not used."""
        return -(self.Z.T @ ((self.X @ u) * self.present)) / self._n_present()

    def moment_covariance(self, w, mu):
        """Return the weighted second moment of ``z_i (y_i - x_i . w)`` minus ``mu mu^T``.

        With ``mu = self.moment_mean(w)`` this is the p x p covariance of the
        moment vectors. ``mu`` may be 1-D or a single row/column vector.
        """
        xi = self.Z * (self.Y - self.X @ w)[:, None]
        # outer() is the p x p product for any single-vector shape of mu; the
        # previous ``mu.T @ mu`` collapsed to a scalar for a 1-D mu.
        return xi.T @ self._masked(xi) / self._n_present() - numpy.outer(mu, mu)

    def gradient_covariance(self, w, u, mu):
        """Return the weighted second moment of ``-z_i (x_i . u)`` minus ``mu mu^T``.

        With ``mu = self.gradient_mean(w, u)`` this is the p x p covariance of
        those vectors. ``w`` is not used.
        """
        xi = -self.Z * (self.X @ u)[:, None]
        return xi.T @ self._masked(xi) / self._n_present() - numpy.outer(mu, mu)

    def optimize(self):
        """Solve the weighted moment equations ``(Z^T P X) w = Z^T P Y`` for ``w``.

        ``P`` is the diagonal sample-weight matrix. Falls back to the
        pseudo-inverse (minimum-norm least-squares solution) when the system
        matrix is singular or not square (p != d, over-identified case).
        """
        A = self.Z.T @ self._masked(self.X)
        b = self.Z.T @ (self.Y * self.present)
        try:
            return scipy.linalg.solve(A, b)
        except (numpy.linalg.LinAlgError, ValueError):
            # LinAlgError: singular A. ValueError: A is p x d with p != d, or
            # contains non-finite values.
            return numpy.linalg.pinv(A) @ b
    
        
def _top_eigenvector(C):
    """Return the eigenvector of symmetric ``C`` with the largest eigenvalue, as a (k, 1) column."""
    k = C.shape[0]
    # eigh orders eigenvalues ascending, so index k-1 is the largest one.
    return scipy.linalg.eigh(C, subset_by_index=[k - 1, k - 1])[1]


def gmm_sever(moments, L, R, sigma, w0):
    """Fit an IV-GMM estimate, alternately solving and filtering samples.

    Each round solves the moment equations on the currently present samples
    (``moments.optimize()``) and then runs up to two filters. Each filter may
    zero out entries of ``moments.present``:

    1. If the mean moment ``u`` has norm above ``tol``, filter the vectors
       ``z_i (x_i . u)`` with threshold ``FILTER_CONSTANT * L**2 * ||u||**2``.
       If that removes samples, the next round starts immediately.
    2. Filter the per-sample moments ``z_i (y_i - x_i . w)`` with threshold
       ``FILTER_CONSTANT * (sigma**2 + L**2 * (R + ||w - w0||)**2)``.

    The loop returns ``w`` from the first round in which neither filter fires.

    Args:
        moments: an ``IV_Moments``-like object (uses ``X``, ``Y``, ``Z``,
            ``present``, ``optimize`` and ``moment_mean``). Its ``present``
            mask is modified in place.
        L, R, sigma: scalars entering the thresholds above.
            NOTE(review): presumably a scale bound on the data, a radius
            around ``w0`` and the noise level. Confirm against the method's
            derivation.
        w0: reference estimate (length d) used in the second threshold.

    Returns:
        The final estimate ``w``.
    """
    while True:
        w = moments.optimize()
        u = moments.moment_mean(w)
        if numpy.linalg.norm(u) > tol:
            # NOTE(review): u has length p (one entry per instrument), but
            # X @ u needs length d, so this only works when p == d. Confirm
            # the intended direction for over-identified problems.
            F = FilterData(moments.Z * (moments.X @ u)[:, None], moments.present)
            v = _top_eigenvector(F.covariance())
            if F.threshold(v, FILTER_CONSTANT * L * L * (u @ u)):
                moments.present = F.present
                continue
        F = FilterData(moments.Z * (moments.Y - moments.X @ w)[:, None], moments.present)
        v = _top_eigenvector(F.covariance())
        radius = R + numpy.linalg.norm(w - w0)
        if not F.threshold(v, FILTER_CONSTANT * (sigma**2 + L**2 * radius**2)):
            return w
        moments.present = F.present

def repeated_gmm_sever(moments, L, R, sigma, steps):
    """Run ``gmm_sever`` ``steps`` times, shrinking the radius each time.

    The first round is centred on the zero vector (length ``moments.d``). Each
    later round is centred on the previous round's estimate, with the radius
    ``R`` halved. Every round starts from the full sample (``moments.reset()``).

    Returns:
        The last estimate, or the zero vector when ``steps`` is 0 or negative.
    """
    estimate = numpy.zeros((moments.d))
    radius = R
    for _ in range(steps):
        moments.reset()
        estimate = gmm_sever(moments, L, radius, sigma, estimate)
        radius = radius / 2.0
    return estimate

def two_stage_sls(X, Y, Z):
    """Return the two-stage least-squares (2SLS) IV estimate, without intercepts.

    Stage one regresses every column of ``X`` on the instruments ``Z`` by
    ordinary least squares and replaces ``X`` with its fitted values ``PX``.
    Stage two regresses ``Y`` on ``PX``. Both stages use
    ``numpy.linalg.lstsq``, which returns the minimum-norm solution when a
    design matrix is rank deficient.

    Args:
        X: m x d regressors.
        Y: length-m response. An m x k array is also accepted.
        Z: m x p instruments.

    Returns:
        Second-stage coefficients of shape (d,) for a 1-D ``Y``, or (k, d) for
        an m x k ``Y``.
    """
    # LinearRegression was used here but never imported, so every call
    # raised NameError. No-intercept OLS is exactly a least-squares solve.
    X = numpy.asarray(X)
    Y = numpy.asarray(Y)
    Z = numpy.asarray(Z)
    # Stage 1: regress all columns of X on Z in one solve.
    gamma = numpy.linalg.lstsq(Z, X, rcond=None)[0]
    PX = Z @ gamma
    # Stage 2: regress Y on the projected regressors.
    beta = numpy.linalg.lstsq(PX, Y, rcond=None)[0]
    return beta.T

def two_stage_robust_sls(X, Y, Z, threshold=1.35):
    """Return a two-stage IV estimate that uses Huber regression in both stages.

    Stage one fits each column of ``X`` on the instruments ``Z`` with a
    ``HuberRegressor`` and replaces the column with its predictions. Stage two
    fits ``Y`` on those predictions with another ``HuberRegressor``.

    Args:
        X: m x d regressors.
        Y: length-m response.
        Z: m x p instruments.
        threshold: the Huber ``epsilon`` passed to every regressor.

    Returns:
        The second-stage ``coef_``. Any intercept fitted by the regressor is
        not returned.
    """
    fitted = numpy.zeros(X.shape)
    for col in range(X.shape[1]):
        stage_one = HuberRegressor(epsilon=threshold)
        stage_one.fit(Z, X[:, col])
        fitted[:, col] = stage_one.predict(Z)
    stage_two = HuberRegressor(epsilon=threshold)
    stage_two.fit(fitted, Y)
    return stage_two.coef_
