# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Doubly Robust IV for Heterogeneous Treatment Effects.

A doubly robust machine learning approach to the estimation of heterogeneous
treatment effects with an endogenous treatment and an instrument.

"""

import numpy as np
from sklearn.model_selection import KFold, train_test_split, StratifiedKFold
from econml.utilities import hstack
from sklearn.base import clone


class DRIV:
    """
    Implements the doubly robust algorithm for estimating CATE,
    i.e. the Algorithm in Section 3.1

    All nuisance models are fit with cross-fitting (one clone of every
    nuisance model per fold); the final effect model is then regressed on the
    doubly robust target built from the out-of-fold nuisance predictions.
    """

    def __init__(self, model_Y_X, model_T_X, model_Z_X,
                 prel_model_effect, model_TZ_X,
                 model_effect,
                 cov_clip=.1,
                 n_splits=3,
                 binary_instrument=False, binary_treatment=False):
        """
        Parameters
        ----------
        model_Y_X : model to predict E[Y | X]; must expose ``fit(X, y)``
            (returning the fitted model) and ``predict(X)``
        model_T_X : model to predict E[T | X]; must expose ``fit(X, y)`` and
            ``predict(X)``
        model_Z_X : model to predict E[Z | X]; must expose ``fit(X, y)`` and
            ``predict(X)``
        prel_model_effect : model that estimates a preliminary version of the
            CATE (e.g. via DMLIV or other method); must expose
            ``fit(y, T, X, Z)`` (returning the fitted model) and ``effect(X)``
        model_TZ_X : model to estimate E[T * Z | X]; must expose ``fit(X, y)``
            (returning the fitted model) and ``predict(X)``
        model_effect : model to estimate second stage effect model from doubly
            robust target; must expose ``fit(X, y)`` and ``predict(X)``
        cov_clip : lower bound imposed on the magnitude of the estimated
            covariance cov[T, Z | X] for regions with low "overlap",
            so as to reduce variance
        n_splits : number of splits to use in cross-fitting
        binary_instrument : whether to stratify cross-fitting splits by instrument
        binary_treatment : whether to stratify cross-fitting splits by treatment
        """
        # Every fold gets its own copy of each nuisance model, so that the
        # fitted per-fold models remain available through `fitted_nuisances`.
        self.prel_model_effect = [clone(prel_model_effect, safe=False) for _ in range(n_splits)]
        self.model_TZ_X = [clone(model_TZ_X, safe=False) for _ in range(n_splits)]
        self.model_T_X = [clone(model_T_X, safe=False) for _ in range(n_splits)]
        self.model_Z_X = [clone(model_Z_X, safe=False) for _ in range(n_splits)]
        self.model_Y_X = [clone(model_Y_X, safe=False) for _ in range(n_splits)]
        self.model_effect = clone(model_effect, safe=False)
        self.cov_clip = cov_clip
        self.n_splits = n_splits
        self.binary_instrument = binary_instrument
        self.binary_treatment = binary_treatment
        self.stored_final_data = False

    def fit(self, y, T, X, Z, store_final=False):
        """
        Fit the nuisance models with cross-fitting and then the final effect model.

        Parameters
        ----------
        y : outcome (single dimensional)
        T : treatment (single dimensional)
        X : features/controls
        Z : instrument (single dimensional)
        store_final : whether to store nuisance data that are used in the final for
            refitting the final stage later on

        Returns
        -------
        self

        Raises
        ------
        AssertionError
            If `Z`, `T` or `y` has more than one column.
        """
        # Accept any array-like for the one dimensional inputs; for numpy
        # arrays this is a no-op.
        y, T, Z = np.asarray(y), np.asarray(T), np.asarray(Z)
        if Z.ndim > 1 and Z.shape[1] > 1:
            raise AssertionError("Can only accept single dimensional instrument")
        if T.ndim > 1 and T.shape[1] > 1:
            raise AssertionError("Can only accept single dimensional treatment")
        if y.ndim > 1 and y.shape[1] > 1:
            raise AssertionError("Can only accept single dimensional outcome")
        Z = Z.flatten()
        T = T.flatten()
        y = y.flatten()

        n_samples = y.shape[0]
        prel_theta = np.zeros(n_samples)
        res_t = np.zeros(n_samples)
        res_y = np.zeros(n_samples)
        res_z = np.zeros(n_samples)
        cov = np.zeros(n_samples)

        # With the default n_splits=3 every fold trains on 2/3 of the sample, which
        # seems appropriate as a preliminary theta estimator typically requires
        # many samples.
        if self.n_splits == 1:
            # No cross-fitting: train and predict on the full sample.
            splits = [(np.arange(X.shape[0]), np.arange(X.shape[0]))]
        # TODO. Deal with multi-class instrument
        elif self.binary_instrument or self.binary_treatment:
            # Stratification label: only the flagged variables contribute.
            group = 2 * T * self.binary_treatment + Z * self.binary_instrument
            splits = StratifiedKFold(
                n_splits=self.n_splits, shuffle=True).split(X, group)
        else:
            splits = KFold(n_splits=self.n_splits, shuffle=True).split(X)

        for idx, (train, test) in enumerate(splits):
            # Estimate preliminary theta in cross fitting manner
            prel_theta[test] = self.prel_model_effect[idx].fit(
                y[train], T[train], X[train], Z[train]).effect(X[test]).flatten()
            # Estimate p(X) = E[T | X] in cross fitting manner
            self.model_T_X[idx].fit(X[train], T[train])
            pr_t_test = self.model_T_X[idx].predict(X[test])
            # Estimate r(X) = E[Z | X] in cross fitting manner
            self.model_Z_X[idx].fit(X[train], Z[train])
            pr_z_test = self.model_Z_X[idx].predict(X[test])
            # Calculate residual T_res = T - p(X) and Z_res = Z - r(X)
            res_t[test] = T[test] - pr_t_test
            res_z[test] = Z[test] - pr_z_test
            # Estimate residual Y_res = Y - q(X) = Y - E[Y | X] in cross fitting manner
            res_y[test] = y[test] - \
                self.model_Y_X[idx].fit(X[train], y[train]).predict(X[test])
            # Estimate cov[T, Z | X] = E[(T-p(X))*(Z-r(X)) | X] = E[T*Z | X] - E[T |X]*E[Z | X]
            cov[test] = self.model_TZ_X[idx].fit(
                X[train], T[train] * Z[train]).predict(X[test]) - pr_t_test * pr_z_test

        self.cov = cov

        # Estimate final model of theta(X) by minimizing the square loss:
        # (prel_theta(X) + (Y_res - prel_theta(X) * T_res) * Z_res / cov[T,Z | X] - theta(X))^2
        # We clip the covariance so that it is bounded away from zero, so as to reduce variance
        # at the expense of some small bias: the sign of the estimate is kept (zero counts as
        # positive) and its magnitude is raised to at least cov_clip.
        cov_sign = np.sign(cov)
        cov_sign[cov_sign == 0] = 1
        clipped_cov = cov_sign * np.clip(np.abs(cov),
                                         self.cov_clip, np.inf)
        theta_dr = prel_theta + \
            (res_y - prel_theta * res_t) * res_z / clipped_cov
        self.model_effect.fit(X, theta_dr)

        if store_final:
            # Keep what `refit_final` needs to re-run only the last stage.
            self.X = X
            self.theta_dr = theta_dr
            self.stored_final_data = True

        return self

    def refit_final(self, model_effect):
        """
        Change the final effect model and refit the final stage.

        Parameters
        ----------
        model_effect : an instance of the new effect model to be fitted in the final stage

        Returns
        -------
        self

        Raises
        ------
        AttributeError
            If `fit` has not been called with ``store_final=True``.
        """
        if not self.stored_final_data:
            raise AttributeError("Estimator is not yet fit with store_final=True")
        self.model_effect = model_effect
        self.model_effect.fit(self.X, self.theta_dr)
        return self

    def effect(self, X):
        """
        Predict the treatment effect with the final effect model.

        Parameters
        ----------
        X : features

        Returns
        -------
        The output of ``model_effect.predict(X)``.
        """
        return self.model_effect.predict(X)

    @property
    def effect_model(self):
        """The final stage effect model."""
        return self.model_effect

    @property
    def fitted_nuisances(self):
        """Dictionary mapping each nuisance name to its list of per-fold models."""
        return {'prel_model_effect': self.prel_model_effect,
                'model_TZ_X': self.model_TZ_X,
                'model_T_X': self.model_T_X,
                'model_Z_X': self.model_Z_X,
                'model_Y_X': self.model_Y_X}

    @property
    def coef_(self):
        """``coef_`` of the final effect model; AttributeError if it has none."""
        if not hasattr(self.effect_model, 'coef_'):
            raise AttributeError("Effect model is not linear!")
        return self.effect_model.coef_

    @property
    def intercept_(self):
        """``intercept_`` of the final effect model; AttributeError if it has none."""
        if not hasattr(self.effect_model, 'intercept_'):
            raise AttributeError("Effect model is not linear!")
        return self.effect_model.intercept_


class ProjectedDRIV:
    """
    This is a slight variant of DRIV where we use E[T|Z, X] as
    the instrument as opposed to Z. The rest is the same as the normal
    fit.

    All nuisance models are fit with cross-fitting (one clone of every
    nuisance model per fold); the final effect model is then regressed on the
    doubly robust target built from the nuisance predictions.
    """

    def __init__(self, model_Y_X, model_T_X, model_T_XZ,
                 prel_model_effect, model_TZ_X,
                 model_effect,
                 cov_clip=.1,
                 n_splits=3,
                 binary_instrument=False, binary_treatment=False):
        """
        Parameters
        ----------
        model_Y_X : model to predict E[Y | X]; must expose ``fit(X, y)``
            (returning the fitted model) and ``predict(X)``
        model_T_X : model to predict E[T | X]; must expose ``fit(X, y)`` and
            ``predict(X)``
        model_T_XZ : model to predict E[T | X, Z]; it is fit on the horizontal
            stack of X and Z and must expose ``fit(XZ, y)`` and ``predict(XZ)``
        prel_model_effect : model that estimates a preliminary version of the
            CATE (e.g. via DMLIV or other method); must expose
            ``fit(y, T, X, Z)`` (returning the fitted model) and ``effect(X)``
        model_TZ_X : model to estimate cov[T, E[T|X,Z] | X] = E[(T-E[T|X]) * (E[T|X,Z] - E[T|X]) | X];
            must expose ``fit(X, y)`` (returning the fitted model) and ``predict(X)``
        model_effect : model to estimate second stage effect model from doubly robust target;
            must expose ``fit(X, y)`` and ``predict(X)``
        cov_clip : lower bound imposed on the magnitude of the estimated
            covariance for regions with low "overlap", so as to reduce variance
        n_splits : number of splits to use in cross-fitting
        binary_instrument : whether to stratify cross-fitting splits by instrument
        binary_treatment : whether to stratify cross-fitting splits by treatment
        """
        # Every fold gets its own copy of each nuisance model, so that the
        # fitted per-fold models remain available through `fitted_nuisances`.
        self.prel_model_effect = [clone(prel_model_effect, safe=False) for _ in range(n_splits)]
        self.model_TZ_X = [clone(model_TZ_X, safe=False) for _ in range(n_splits)]
        self.model_T_X = [clone(model_T_X, safe=False) for _ in range(n_splits)]
        self.model_T_XZ = [clone(model_T_XZ, safe=False) for _ in range(n_splits)]
        self.model_Y_X = [clone(model_Y_X, safe=False) for _ in range(n_splits)]
        # Cloned (as in DRIV) so that fitting never mutates the caller's
        # estimator; the fitted model is available through `effect_model`.
        self.model_effect = clone(model_effect, safe=False)
        self.cov_clip = cov_clip
        self.n_splits = n_splits
        self.binary_instrument = binary_instrument
        self.binary_treatment = binary_treatment
        self.stored_final_data = False

    def fit(self, y, T, X, Z, store_final=False):
        """
        Fit the nuisance models with cross-fitting and then the final effect model.

        Parameters
        ----------
        y : outcome (single dimensional)
        T : treatment (single dimensional)
        X : features/controls
        Z : instrument; a one dimensional array is treated as a single column
        store_final : whether to store nuisance data that are used in the final for
            refitting the final stage later on

        Returns
        -------
        self

        Raises
        ------
        AssertionError
            If `T` or `y` has more than one column, or if `binary_instrument`
            is set while `Z` has more than one column.
        """
        # Accept any array-like for the one dimensional inputs; for numpy
        # arrays this is a no-op. Z and X are left untouched as they are only
        # indexed and handed over to hstack / the nuisance models.
        y, T = np.asarray(y), np.asarray(T)
        if T.ndim > 1 and T.shape[1] > 1:
            raise AssertionError("Can only accept single dimensional treatment")
        if y.ndim > 1 and y.shape[1] > 1:
            raise AssertionError("Can only accept single dimensional outcome")
        if len(Z.shape) == 1:
            Z = Z.reshape(-1, 1)
        if (Z.shape[1] > 1) and self.binary_instrument:
            raise AssertionError("Binary instrument flag is True, but instrument is multi-dimensional")
        T = T.flatten()
        y = y.flatten()

        n_samples = y.shape[0]
        prel_theta = np.zeros(n_samples)
        res_t = np.zeros(n_samples)
        res_y = np.zeros(n_samples)
        res_z = np.zeros(n_samples)
        cov = np.zeros(n_samples)

        # With the default n_splits=3 every fold trains on 2/3 of the sample, which
        # seems appropriate as a preliminary theta estimator typically requires
        # many samples.
        if self.n_splits == 1:
            # No cross-fitting: train and predict on the full sample.
            splits = [(np.arange(X.shape[0]), np.arange(X.shape[0]))]
        # TODO. Deal with multi-class instrument/treatment
        elif self.binary_instrument or self.binary_treatment:
            # Build the stratification label only from the flagged variables.
            # Z is flattened only when it is used: it is guaranteed to be a
            # single column in that case, whereas flattening a multi-column
            # instrument would not line up with the n_samples long treatment.
            group = np.zeros(n_samples)
            if self.binary_treatment:
                group = group + 2 * T
            if self.binary_instrument:
                group = group + Z.flatten()
            splits = StratifiedKFold(
                n_splits=self.n_splits, shuffle=True).split(X, group)
        else:
            splits = KFold(n_splits=self.n_splits, shuffle=True).split(X)

        for idx, (train, test) in enumerate(splits):
            # Estimate preliminary theta in cross fitting manner
            prel_theta[test] = self.prel_model_effect[idx].fit(
                y[train], T[train], X[train], Z[train]).effect(X[test]).flatten()
            # Estimate p(X) = E[T | X] in cross fitting manner
            self.model_T_X[idx].fit(X[train], T[train])
            pr_t_test = self.model_T_X[idx].predict(X[test])
            pr_t_train = self.model_T_X[idx].predict(X[train])
            # Estimate h(X, Z) = E[T | X, Z] in cross fitting manner
            self.model_T_XZ[idx].fit(hstack([X[train], Z[train]]), T[train])
            proj_t_test = self.model_T_XZ[idx].predict(hstack([X[test], Z[test]]))
            proj_t_train = self.model_T_XZ[idx].predict(
                hstack([X[train], Z[train]]))
            # Calculate residual T_res = T - p(X) and Z_res = h(Z, X) - p(X)
            res_t[test] = T[test] - pr_t_test
            res_z[test] = proj_t_test - pr_t_test
            # Estimate residual Y_res = Y - q(X) = Y - E[Y | X] in cross fitting manner
            res_y[test] = y[test] - \
                self.model_Y_X[idx].fit(X[train], y[train]).predict(X[test])
            # Estimate cov[T, E[T|X,Z] | X] = E[(T - p(X)) * (h(X, Z) - p(X)) | X] by regressing
            # the product of the training-fold residuals on X and predicting on the test fold.
            cov[test] = self.model_TZ_X[idx].fit(
                X[train],
                (T[train] - pr_t_train) * (proj_t_train - pr_t_train)).predict(X[test])

        self.cov = cov

        # Estimate final model of theta(X) by minimizing the square loss:
        # (prel_theta(X) + (Y_res - prel_theta(X) * T_res) * Z_res / cov[T,Z | X] - theta(X))^2
        # We clip the covariance so that it is bounded away from zero, so as to reduce variance
        # at the expense of some small bias: the sign of the estimate is kept (zero counts as
        # positive) and its magnitude is raised to at least cov_clip.
        # NOTE(review): an earlier comment here stated that the covariance must be positive in
        # this case and that its actual (not absolute) value is clipped; the code clips the
        # absolute value and preserves the sign -- confirm which behavior is intended.
        cov_sign = np.sign(cov)
        cov_sign[cov_sign == 0] = 1
        clipped_cov = cov_sign * np.clip(np.abs(cov), self.cov_clip, np.inf)
        theta_dr = prel_theta + \
            (res_y - prel_theta * res_t) * res_z / clipped_cov

        self.model_effect.fit(X, theta_dr)

        if store_final:
            # Keep what `refit_final` needs to re-run only the last stage.
            self.X = X
            self.theta_dr = theta_dr
            self.stored_final_data = True

        return self

    def refit_final(self, model_effect):
        """
        Change the final effect model and refit the final stage.

        Parameters
        ----------
        model_effect : an instance of the new effect model to be fitted in the final stage

        Returns
        -------
        self

        Raises
        ------
        AttributeError
            If `fit` has not been called with ``store_final=True``.
        """
        if not self.stored_final_data:
            raise AttributeError("Estimator is not yet fit with store_final=True")
        self.model_effect = model_effect
        self.model_effect.fit(self.X, self.theta_dr)
        return self

    def effect(self, X):
        """
        Predict the treatment effect with the final effect model.

        Parameters
        ----------
        X : features

        Returns
        -------
        The output of ``model_effect.predict(X)``.
        """
        return self.model_effect.predict(X)

    @property
    def effect_model(self):
        """The final stage effect model."""
        return self.model_effect

    @property
    def fitted_nuisances(self):
        """Dictionary mapping each nuisance name to its list of per-fold models."""
        return {'prel_model_effect': self.prel_model_effect,
                'model_TZ_X': self.model_TZ_X,
                'model_T_X': self.model_T_X,
                'model_T_XZ': self.model_T_XZ,
                'model_Y_X': self.model_Y_X}

    @property
    def coef_(self):
        """``coef_`` of the final effect model; AttributeError if it has none."""
        if not hasattr(self.effect_model, 'coef_'):
            raise AttributeError("Effect model is not linear!")
        return self.effect_model.coef_

    @property
    def intercept_(self):
        """``intercept_`` of the final effect model; AttributeError if it has none."""
        if not hasattr(self.effect_model, 'intercept_'):
            raise AttributeError("Effect model is not linear!")
        return self.effect_model.intercept_
