# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

""" Utility classes and functions.
"""

from sklearn.linear_model import LinearRegression
import numpy as np
import copy

class RegWrapper:
    """
    A simple wrapper that makes a binary classifier behave like a regressor.

    ``fit`` delegates to the classifier's ``fit`` and ``predict`` returns the
    second column of the classifier's ``predict_proba`` output, i.e. the
    probability of the second class (label 1 for a 0/1 target). Any other
    attribute access is forwarded to the wrapped classifier, except
    ``get_params``, which is deliberately hidden (presumably so sklearn
    utilities such as ``clone`` do not treat the wrapper as an sklearn
    estimator and fall back to ``copy.deepcopy`` instead).
    """

    def __init__(self, clf):
        """
        Parameters
        ----------
        clf : the classifier model; must provide ``fit`` and ``predict_proba``
        """
        self._clf = clf

    def fit(self, X, y):
        """
        Fit the wrapped classifier.

        Parameters
        ----------
        X : features
        y : binary label

        Returns
        -------
        self
        """
        self._clf.fit(X, y)
        return self

    def predict(self, X):
        """
        Return the classifier's probability for the second class.

        Parameters
        ----------
        X : features

        Returns
        -------
        Column 1 of ``clf.predict_proba(X)``.
        """
        return self._clf.predict_proba(X)[:, 1]

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        if name == 'get_params':
            raise AttributeError("not sklearn")
        # Read _clf straight from the instance dict: on a half-built instance
        # (copy.copy / unpickling create it via __new__ and then probe for
        # __setstate__), ``self._clf`` would re-enter __getattr__ and recurse
        # forever.
        try:
            clf = self.__dict__['_clf']
        except KeyError:
            raise AttributeError(name) from None
        return getattr(clf, name)

    def __deepcopy__(self, memo):
        return RegWrapper(copy.deepcopy(self._clf, memo))

class SubsetWrapper:
    """
    A simple wrapper that fits the data on a subset of the
    features given by an index list.

    ``fit`` selects the columns ``inds`` from ``X`` before fitting, while
    ``predict`` expects its input to already be restricted to those columns.
    Any other attribute access is forwarded to the wrapped model, except
    ``get_params``, which is deliberately hidden.
    """

    def __init__(self, model, inds):
        """
        Parameters
        ----------
        model : an sklearn model
        inds : a subset of the input features (column indices) to use
        """
        self._model = model
        self._inds = inds

    def fit(self, X, y):
        """
        Fit the wrapped model on the columns ``inds`` of ``X``.

        Parameters
        ----------
        X : features (all columns; indexed as ``X[:, inds]``)
        y : target values, passed through to ``model.fit``

        Returns
        -------
        self
        """
        self._model.fit(X[:, self._inds], y)
        return self

    def predict(self, X):
        """
        Predict with the wrapped model.

        Parameters
        ----------
        X : features already restricted to the columns in ``inds``;
            unlike ``fit``, no column selection is applied here.
        """
        return self._model.predict(X)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        if name == 'get_params':
            raise AttributeError("not sklearn")
        # Read _model straight from the instance dict: on a half-built
        # instance (copy.copy / unpickling probe __setstate__ before the
        # state is restored), ``self._model`` would re-enter __getattr__
        # and recurse forever.
        try:
            model = self.__dict__['_model']
        except KeyError:
            raise AttributeError(name) from None
        return getattr(model, name)

    def __deepcopy__(self, memo):
        # Copy the index container too, so mutating it on one copy cannot
        # silently change the other.
        return SubsetWrapper(copy.deepcopy(self._model, memo),
                             copy.deepcopy(self._inds, memo))

class SelectiveLasso:
    """
    Fit ``lasso_model`` on a subset of the features and plain no-intercept
    linear regressions on the remaining ones.

    The model is fit by partialling out. First, ``y`` and ``X[:, inds]`` are
    both residualized on the complementary columns. ``lasso_model`` is then
    fit on those residuals. Finally, a regression on the complementary
    columns is fit to what the lasso prediction leaves over. When ``inds``
    covers every column, ``lasso_model`` is simply fit on ``X``.
    """

    def __init__(self, inds, lasso_model):
        """
        Parameters
        ----------
        inds : integer column indices handed to ``lasso_model`` (the
            complement is computed with ``np.setdiff1d`` against
            ``np.arange(n_features)``, so a boolean mask is not supported)
        lasso_model : estimator with ``fit``/``predict`` that exposes
            ``coef_``/``intercept_`` after fitting; presumably a penalized
            linear model such as sklearn's ``Lasso``
        """
        self.inds = inds
        self.lasso_model = lasso_model
        self.model_Y_X = LinearRegression(fit_intercept=False)
        self.model_X1_X2 = LinearRegression(fit_intercept=False)
        self.model_X2 = LinearRegression(fit_intercept=False)

    def fit(self, X, y):
        """
        Fit the model.

        Parameters
        ----------
        X : 2-D array of features
        y : target values

        Returns
        -------
        self
        """
        self.n_feats = X.shape[1]
        inds = self.inds
        inds_c = np.setdiff1d(np.arange(self.n_feats), inds)
        self.inds_c = inds_c
        if len(inds_c) == 0:
            self.lasso_model.fit(X, y)
            return self
        # Slice once; fancy indexing copies the data every time.
        X1 = X[:, inds]
        X2 = X[:, inds_c]
        # Residualize y and the lasso features on the unpenalized features.
        res_y = y - self.model_Y_X.fit(X2, y).predict(X2)
        res_X1 = X1 - self.model_X1_X2.fit(X2, X1).predict(X2)
        self.lasso_model.fit(res_X1, res_y)
        # Explain what the lasso prediction on the raw features leaves over.
        self.model_X2.fit(X2, y - self.lasso_model.predict(X1))
        return self

    def predict(self, X):
        """
        Predict: lasso prediction on ``X[:, inds]`` plus the regression on
        the complementary columns (or just the lasso on ``X`` if there are
        none).
        """
        inds = self.inds
        inds_c = self.inds_c
        if len(inds_c) == 0:
            return self.lasso_model.predict(X)
        return self.lasso_model.predict(X[:, inds]) + self.model_X2.predict(X[:, inds_c])

    @property
    def model(self):
        """The wrapped ``lasso_model``."""
        return self.lasso_model

    @property
    def coef_(self):
        """
        Coefficients over all ``n_feats`` columns.

        The lasso coefficients go at ``inds`` and the complementary
        regression's coefficients at the other columns. Assumes a single
        target (a 1-D coefficient vector).
        """
        inds_c = self.inds_c
        if len(inds_c) == 0:
            return self.lasso_model.coef_
        coef = np.zeros(self.n_feats)
        coef[self.inds] = self.lasso_model.coef_
        coef[inds_c] = self.model_X2.coef_
        return coef

    @property
    def intercept_(self):
        """Combined intercept of the lasso and the complementary regression."""
        if len(self.inds_c) == 0:
            # model_X2 is never fitted in this case, so it has no intercept_.
            return self.lasso_model.intercept_
        return self.lasso_model.intercept_ + self.model_X2.intercept_
