import numpy as np

from sklearn.base import TransformerMixin
from scipy.linalg import eigh


class Random(TransformerMixin):
    """Baseline transformer that ignores its input values and emits noise.

    ``transform`` returns one standard-normal feature per sample, drawn from
    NumPy's global random state, so repeated calls return different values.
    """

    def __init__(self):
        pass

    def fit(self, X, y=None):
        """No-op; present only for the scikit-learn estimator API.

        ``y`` is optional so that ``fit_transform(X)`` -- which calls
        ``fit(X)`` without a target when ``y`` is None -- works.

        Returns
        -------
        self
        """
        return self

    def transform(self, X):
        """Return an array of shape ``(X.shape[0], 1)`` of N(0, 1) draws.

        Only the number of rows of ``X`` is used.
        """
        return np.random.randn(X.shape[0], 1)


class logDiag(TransformerMixin):
    """Feature extractor: element-wise log of the diagonal of each matrix."""

    def __init__(self):
        pass

    def fit(self, X, y=None):
        """No-op; ``y`` is optional so that ``fit_transform(X)`` works.

        Returns
        -------
        self
        """
        return self

    def transform(self, X):
        """Return ``log(diag(X[i]))`` for every matrix ``X[i]``.

        Parameters
        ----------
        X : array-like, expected shape (n_samples, p, p)
            Stack of square matrices; the diagonal is taken over axes 1 and 2.
            Non-positive diagonal entries yield ``-inf`` / ``nan``.

        Returns
        -------
        ndarray of shape (n_samples, p)
        """
        return np.log(np.diagonal(X, axis1=1, axis2=2))


class Ledoit(TransformerMixin):
    """Shrink each matrix in a stack towards a fixed multiple of the identity.

    Every matrix ``x`` is replaced by ``(1 - a) * x + (a / p) * I_p``.
    The shrinkage target is the fixed ``(a / p) * I``; it is not rescaled
    by ``trace(x)``.

    Parameters
    ----------
    a : float
        Shrinkage weight; ``a = 0`` returns the input values unchanged.
    """

    def __init__(self, a):
        self.a = a

    def fit(self, X, y=None):
        """No-op; ``y`` is optional so that ``fit_transform(X)`` works.

        Returns
        -------
        self
        """
        return self

    def transform(self, X):
        """Apply the shrinkage to every matrix.

        Parameters
        ----------
        X : array-like of shape (n_samples, p, p)

        Returns
        -------
        ndarray of shape (n_samples, p, p)
        """
        X = np.asarray(X)
        a = self.a
        p = X.shape[1]
        # Broadcasting adds the same identity term to every matrix in one
        # vectorized step (same arithmetic as a per-matrix loop) and keeps
        # the (0, p, p) shape for empty input.
        return X * (1 - a) + (a / p) * np.eye(p)


class Common(TransformerMixin):
    """Project matrices onto the leading eigenvectors of their training mean.

    Parameters
    ----------
    n_projs : int
        Number of eigenvectors to keep; these are the ones with the largest
        eigenvalues of the mean training matrix.

    Attributes
    ----------
    proj : ndarray of shape (min(n_projs, p), p)
        Set by ``fit``. Rows are the kept eigenvectors in ascending
        eigenvalue order, the order returned by ``scipy.linalg.eigh``.
    """

    def __init__(self, n_projs):
        self.n_projs = n_projs

    def fit(self, X, y=None):
        """Compute the projection from the mean of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, p, p)
            Stack of symmetric matrices (``eigh`` assumes symmetry).
        y : ignored
            Optional so that ``fit_transform(X)`` works.

        Returns
        -------
        self
        """
        C = np.mean(X, axis=0)
        _, U = eigh(C)
        # eigh returns eigenvalues in ascending order, so the leading
        # eigenvectors are the last columns. An explicit start index is used
        # because ``U.T[-0:]`` would keep *every* row when n_projs == 0.
        start = max(U.shape[1] - self.n_projs, 0)
        self.proj = U.T[start:]
        return self

    def transform(self, X):
        """Return ``P @ x @ P.T`` for each matrix ``x`` in ``X``.

        Returns
        -------
        ndarray of shape (n_samples, k, k)
            ``k`` is the number of rows of ``proj``.
        """
        P = self.proj
        # matmul broadcasts over the leading (sample) axis.
        return np.matmul(P, np.matmul(np.asarray(X), P.T))


class Spoc(TransformerMixin):
    """Supervised projection of matrices driven by a continuous target.

    ``fit`` solves the generalized eigenproblem ``Cw u = lambda C u``:

    - ``Cw`` is the mean of the matrices weighted by the standardized target.
    - ``C`` is their plain mean, optionally regularized.

    It keeps the eigenvectors with the largest ``|lambda|``.

    Parameters
    ----------
    n_projs : int
        Number of projection vectors to keep.
    reg : float or False, default False
        If truthy, ``reg * I`` is added to ``C`` before the eigendecomposition.

    Attributes
    ----------
    proj : ndarray of shape (min(n_projs, p), p)
        Set by ``fit``. Rows are eigenvectors ordered by decreasing
        ``|lambda|``.
    """

    def __init__(self, n_projs, reg=False):
        self.n_projs = n_projs
        self.reg = reg

    def fit(self, X, y):
        """Learn the projection from matrices ``X`` and target ``y``.

        Parameters
        ----------
        X : array-like of shape (n_samples, p, p)
            Stack of symmetric matrices (``eigh`` assumes symmetry).
        y : array-like of shape (n_samples,)
            Continuous target; it is centered and scaled to unit variance.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``y`` has zero or non-finite variance, e.g. a constant target.
        """
        X = np.asarray(X)
        # Cast so list or integer targets work and the in-place division
        # below always operates on a float array.
        y_ = np.asarray(y, dtype=float)
        y_ = y_ - np.mean(y_)
        std = np.std(y_)
        if not np.isfinite(std) or std == 0:
            raise ValueError("y must have non-zero, finite variance")
        y_ /= std
        Cw = np.mean(y_[:, None, None] * X, axis=0)
        C = np.mean(X, axis=0)
        if self.reg:
            C = C + self.reg * np.eye(X.shape[1])
        eig, U = eigh(Cw, C)
        # Strongest (positive or negative) target covariation first.
        order = np.argsort(np.abs(eig))[::-1]
        self.proj = U[:, order].T[:self.n_projs]
        return self

    def transform(self, X):
        """Return ``P @ x @ P.T`` for each matrix ``x`` in ``X``.

        Returns
        -------
        ndarray of shape (n_samples, k, k)
            ``k`` is the number of rows of ``proj``.
        """
        P = self.proj
        # matmul broadcasts over the leading (sample) axis.
        return np.matmul(P, np.matmul(np.asarray(X), P.T))


class RandomProj(TransformerMixin):
    """Project matrices onto random orthonormal directions.

    Parameters
    ----------
    n_projs : int
        Number of random directions to keep.

    Attributes
    ----------
    proj : ndarray of shape (min(n_projs, p), p)
        Set by ``fit``; rows are orthonormal vectors.
    """

    def __init__(self, n_projs):
        self.n_projs = n_projs

    def fit(self, X, y=None):
        """Draw a random orthonormal projection sized to ``X``.

        Only ``X.shape[1]`` is used. Randomness comes from NumPy's global
        random state. ``y`` is optional so that ``fit_transform(X)`` works.

        Returns
        -------
        self
        """
        p = np.shape(X)[1]
        # eigh reads only the lower triangle by default, so this is the
        # eigenbasis of a random symmetric matrix; its columns are
        # orthonormal.
        _, U = eigh(np.random.randn(p, p))
        # Explicit start index: ``U.T[-0:]`` would keep every row when
        # n_projs == 0.
        start = max(U.shape[1] - self.n_projs, 0)
        self.proj = U.T[start:]
        return self

    def transform(self, X):
        """Return ``P @ x @ P.T`` for each matrix ``x`` in ``X``.

        Returns
        -------
        ndarray of shape (n_samples, k, k)
            ``k`` is the number of rows of ``proj``.
        """
        P = self.proj
        # matmul broadcasts over the leading (sample) axis.
        return np.matmul(P, np.matmul(np.asarray(X), P.T))


class Ravel(TransformerMixin):
    """Flatten every sample into a 1-D feature vector."""

    def __init__(self):
        pass

    def fit(self, X, y=None):
        """No-op; ``y`` is optional so that ``fit_transform(X)`` works.

        Returns
        -------
        self
        """
        return self

    def transform(self, X):
        """Reshape ``X`` to ``(n_samples, -1)``, keeping the sample axis.

        Parameters
        ----------
        X : array-like of shape (n_samples, ...)

        Returns
        -------
        ndarray of shape (n_samples, prod(X.shape[1:]))
        """
        X = np.asarray(X)
        return X.reshape(X.shape[0], -1)
