from sswimlib.utils.math_pytorch import nlml_chol_fast
from sswimlib.utils.containers import Prediction
from sswimlib.utils.opt import ParamsList
from sswimlib.utils.cholesky import psd_safe_cholesky as cholesky
from tqdm import trange
import torch
from sswimlib.utils.metrics import mnlp as mnlp_np, rmse


class BayLinRegUICholFast:
    """
    SSWIM augmented SSGPR (Using Cholesky)

    Bayesian linear regression on the features produced by ``kphi``. The inputs
    can optionally be passed through a chain of warping levels first: at every
    level a multiplicative and an additive warp model predict a mean and a
    variance, and the inputs become ``x * warpmul_mean + warpadd_mean`` while
    the input variance is propagated alongside.

    After :meth:`fit` the instance holds ``R_lower`` (lower Cholesky factor of
    ``PHI^T PHI + (alpha / beta) * I``), ``Rb_solve`` (``R_lower^{-1} PHI^T y``)
    and ``mu`` (the weights used for the predictive mean).
    """

    def __init__(self,
                 kphi,
                 has_pseudo_training=False,
                 warpmul_models=None,
                 warpadd_models=None,
                 alpha=0.05,
                 beta=1.5,
                 x_trn=None,
                 y_trn=None,
                 verbose=0):
        """
        :param kphi:                Feature map. Must expose ``M``, ``get_params()``,
                                    ``sample_frequencies()`` and ``transform(X, X_var=...)``.
        :param has_pseudo_training: When True, ``x_trn`` and ``y_trn`` are treated as
                                    parameters: they are added to the parameter list and
                                    ``.forward()`` is called on them inside ``fit``.
        :param warpmul_models:      Sequence of multiplicative warp models, one per level, or None.
        :param warpadd_models:      Sequence of additive warp models, one per level, or None.
                                    Must be given together with ``warpmul_models``.
        :param alpha:               Hyperparameter object. ``fit`` calls ``alpha.forward()``, so
                                    the plain-float default cannot be used for fitting.
        :param beta:                Hyperparameter object. ``fit`` and ``predict`` call
                                    ``beta.forward()``; the same caveat as ``alpha`` applies.
        :param x_trn:               Training inputs (mean of the gaussian inputs).
        :param y_trn:               Training targets.
        :param verbose:             Stored on the instance; not read anywhere in this class.
        """
        self.x_trn = x_trn  # Mean of the gaussian inputs
        self.y_trn = y_trn
        self.verbose = verbose

        self.kphi = kphi
        self.M = self.kphi.M  # Feature dimensionality
        self.has_pseudo_training = has_pseudo_training
        self.warpmul_models = warpmul_models
        self.warpadd_models = warpadd_models
        self.alpha = alpha
        self.beta = beta
        self.params = [self.alpha,
                       self.beta]
        if self.has_pseudo_training is True:
            self.params = self.params + [self.x_trn, self.y_trn]

        self.params_list = ParamsList(self.get_params(),
                                      numpy_2_torch=False,
                                      torch_2_numpy=False,
                                      keep_tensor_list=True)

        self.optim_class = None
        self.optimizer = None
        self.logging = {"loss": []}

    def get_params(self,
                   full_depth=True):
        """
        Returns all parameter objects from this class
        Optionally (default) returns every parameter from all parameter holding objects contained within
        :param full_depth:  When True, also collect the parameters of ``kphi`` and of every
                            warp model. When False, return only this object's own list.
        :return:            A list of parameter objects. With ``full_depth=True`` it is a new
                            list, so ``self.params`` is never extended in place.
        """
        if full_depth is not True:
            return self.params

        # ``+`` builds a fresh list, so the ``+=`` below cannot grow self.params.
        params = self.params + self.kphi.get_params()
        for warp_models in (self.warpmul_models, self.warpadd_models):
            if warp_models is not None:
                for warp_model in warp_models:
                    params += warp_model.get_params()
        return params

    @staticmethod
    def _solve_triangular(tri, rhs, upper):
        """
        Solve ``tri @ X = rhs`` for ``X`` where ``tri`` is a triangular matrix.

        ``torch.solve`` has been removed from recent PyTorch releases, and it ran a
        general LU factorisation on a matrix that is already triangular.
        :param tri:     Square triangular coefficient matrix.
        :param rhs:     Right hand side (2-D).
        :param upper:   True when ``tri`` is upper triangular, False when lower.
        :return:        The solution ``X``.
        """
        linalg = getattr(torch, "linalg", None)
        solver = getattr(linalg, "solve_triangular", None)
        if solver is not None:
            return solver(tri, rhs, upper=upper)
        # Older PyTorch builds: same solve, legacy argument order (rhs first).
        solution, _ = torch.triangular_solve(rhs, tri, upper=upper)
        return solution

    def _has_warping(self):
        """
        Tell whether input warping is configured, validating the configuration.
        :return:            True when both warp model sequences are set, False when neither is.
        :raises ValueError: When only one sequence is set, or their lengths differ.
        """
        warpmul_models = self.warpmul_models
        warpadd_models = self.warpadd_models
        if warpmul_models is None and warpadd_models is None:
            return False
        if warpmul_models is None or warpadd_models is None:
            raise ValueError("warpmul_models and warpadd_models must be provided together")
        if len(warpmul_models) != len(warpadd_models):
            raise ValueError("warpmul_models and warpadd_models must have the same length")
        return True

    def _warp_inputs(self, x, x_var, with_grad):
        """
        Push inputs (and their variance) through every warping level in order.
        :param x:           Input means.
        :param x_var:       Input variances, or None when the inputs are certain.
        :param with_grad:   Forwarded to the warp models' ``predict``.
        :return:            Tuple ``(warped_x, warped_x_var)``.
        """
        for warpmul_model, warpadd_model in zip(self.warpmul_models, self.warpadd_models):
            warpmul_pred = warpmul_model.predict(x,
                                                 x_var=x_var,  # The variance of the previous step!
                                                 with_grad=with_grad,
                                                 with_var=True)
            warpadd_pred = warpadd_model.predict(x,
                                                 x_var=x_var,  # The variance of the previous step!
                                                 with_grad=with_grad,
                                                 with_var=True)
            warpmul_mean = warpmul_pred.mean
            warpmul_var = warpmul_pred.var
            warpadd_mean = warpadd_pred.mean
            warpadd_var = warpadd_pred.var

            # Update the variance term (uses the not-yet-warped x).
            if x_var is None:
                warped_var = warpmul_var * (x * x) + warpadd_var
            else:
                # Do the V[XY] = V[X]V[Y] + V[X]E[Y]^2 + V[Y]E[X]^2 calculation,
                # with X the multiplicative warp and Y the current inputs.
                # Original author's note: predvar is a (N,1) column vector and
                # x is a (N,D) matrix, so this broadcasts.
                warped_var = (warpmul_var * x_var
                              + warpmul_var * x ** 2
                              + x_var * warpmul_mean ** 2) + warpadd_var

            # Update the mean term.
            x = x * warpmul_mean + warpadd_mean
            x_var = warped_var
        return x, x_var

    def fit(self,
            x_trn=None,
            x_var=None,  # for uncertain inputs prediction
            y_trn=None,
            with_grad=False):
        """
        Fit the weights
        :param x_trn:       Training inputs. When None, both ``self.x_trn`` and
                            ``self.y_trn`` are used (``y_trn`` is then ignored).
        :param x_var:       Variance of the training inputs, or None.
        :param y_trn:       Training targets.
        :param with_grad:   Forwarded to the warp models' ``fit`` and ``predict``.
        :return:            None. Sets ``self.R_lower``, ``self.Rb_solve`` and ``self.mu``.
        :raises ValueError: When the warp model configuration is inconsistent.
        """
        if x_trn is None:
            x_trn = self.x_trn
            y_trn = self.y_trn
        if self.has_pseudo_training is True:
            x_trn = x_trn.forward()
            y_trn = y_trn.forward()

        if self._has_warping():
            # Fit every warp model first, then do the warping level by level.
            for warpmul_model, warpadd_model in zip(self.warpmul_models, self.warpadd_models):
                warpmul_model.fit(with_grad=with_grad)
                warpadd_model.fit(with_grad=with_grad)
            x_trn, x_var = self._warp_inputs(x_trn, x_var, with_grad)

        self.kphi.sample_frequencies()
        PHI = self.kphi.transform(x_trn, X_var=x_var)
        ridge = self.alpha.forward() / self.beta.forward()
        # Build the identity on PHI's dtype/device so float64 and GPU features work.
        identity = torch.eye(self.kphi.M, dtype=PHI.dtype, device=PHI.device)
        A = torch.matmul(PHI.t(), PHI) + ridge * identity
        self.R_lower = cholesky(A, upper=False)
        b = torch.matmul(PHI.t(), y_trn)
        # Two triangular solves: R_lower z = b, then R_lower^T mu = z.
        self.Rb_solve = self._solve_triangular(self.R_lower, b, upper=False)
        self.mu = self._solve_triangular(self.R_lower.t(), self.Rb_solve, upper=True)

    def fit_get_loss(self,
                     x_trn=None,
                     x_var=None,
                     y_trn=None,
                     loss_type="nlml",
                     log_the_loss=False,
                     with_grad=False):
        """
        Run a fit as well as calculate the loss (NLML)
        :param x_trn:           Training inputs. When None, ``self.x_trn`` / ``self.y_trn`` are used.
        :param x_var:           Variance of the training inputs, or None.
        :param y_trn:           Training targets.
        :param loss_type:       Loss type. Only "nlml" is implemented.
        :param log_the_loss:    When True, append ``loss.item()`` to ``self.logging["loss"]``.
        :param with_grad:       Forwarded to ``fit``.
        :return:                The loss returned by ``nlml_chol_fast``.
        :raises ValueError:     When ``loss_type`` is not supported.
        """
        # Fail before the (expensive) fit rather than on an unbound ``loss`` afterwards.
        if loss_type != "nlml":
            raise ValueError(
                "Unsupported loss_type {!r}; only 'nlml' is implemented".format(loss_type))

        if x_trn is None:
            x_trn = self.x_trn
            y_trn = self.y_trn

        self.fit(x_trn=x_trn,
                 x_var=x_var,
                 y_trn=y_trn,
                 with_grad=with_grad)
        # NOTE(review): x_trn / y_trn are handed over exactly as received, i.e. without the
        # ``.forward()`` that ``fit`` applies when has_pseudo_training is True - confirm that
        # the pseudo-training parameter type supports ``.shape`` and is accepted by
        # nlml_chol_fast.
        loss = nlml_chol_fast(y_trn=y_trn,
                              r=self.R_lower,
                              rb_solve=self.Rb_solve,
                              n=x_trn.shape[0],
                              m=self.M,
                              alpha=self.alpha.forward(),
                              beta=self.beta.forward())

        if log_the_loss is True:
            self.logging["loss"].append(loss.item())

        return loss

    def optimize(self,
                 optimizer=None,
                 optimizer_kwargs=None,
                 loss_type="nlml",
                 log_the_loss=True,
                 test_logging=None):
        """
        Optimise every parameter in ``self.params_list`` against the training loss.

        :param optimizer:           The optimizer. Only "adam" is handled; any other value
                                    makes this call a no-op.
        :param optimizer_kwargs:    Kwargs for the optimizer. Optional dict with the keys
                                    "optim_epochs" (default 100) and "optim_kwargs"
                                    (passed to ``torch.optim.Adam``). None means all defaults.
        :param loss_type:           Loss type. Only NLML for now.
        :param log_the_loss:        Whether or not to log values during training.
        :param test_logging:        Whether or not to log the test data during training.
                                    For overfitting analysis. Dict with the keys "X_tst",
                                    "Y_tst_np" and "Y_scaler" (may be None).
        :return:                    None
        :raises ValueError:         When ``loss_type`` is not supported (raised by ``fit_get_loss``).
        """
        if optimizer != "adam":
            return

        if optimizer_kwargs is None:
            optimizer_kwargs = {}
        epochs = optimizer_kwargs.get("optim_epochs", 100)
        optim_kwargs = optimizer_kwargs.get("optim_kwargs")
        if optim_kwargs is None:
            optim_kwargs = {"lr": .01,
                            "betas": (0.9, 0.999),
                            "eps": 1e-10}

        self.optim_class = torch.optim.Adam
        self.optimizer = self.optim_class(self.params_list.params_flat, **optim_kwargs)

        if test_logging:
            self.logging["rmse_test_normscale"] = []
            self.logging["mnlp_test_normscale"] = []
            self.logging["rmse_test_origscale"] = []
            self.logging["mnlp_test_origscale"] = []
            self.logging["nlml"] = []
            X_tst = test_logging["X_tst"]
            Y_tst_np = test_logging["Y_tst_np"]
            Y_scaler = test_logging["Y_scaler"]

        for i in trange(epochs, position=0, leave=True):
            self.optimizer.zero_grad()
            loss = self.fit_get_loss(loss_type=loss_type,
                                     with_grad=True)

            loss.backward(retain_graph=False)
            # No step after the final epoch, so the weights fitted above still
            # correspond to the parameter values the model ends up with.
            if i < epochs - 1:
                self.optimizer.step()
                print("LOSS: {}".format(loss.item()))

            if log_the_loss:
                self.logging["loss"].append(loss.item())

            if test_logging:
                prediction_tst = self.predict(x=X_tst, with_var=True, with_grad=False)
                Y_predmean_tst = prediction_tst.mean.cpu().data.numpy()
                Y_predvar_tst = prediction_tst.var.cpu().data.numpy()

                if Y_scaler:
                    test_rmse = rmse(y_actual=Y_scaler.inverse_transform(Y_tst_np),
                                     y_pred=Y_scaler.inverse_transform(Y_predmean_tst))
                    test_mnll = mnlp_np(actual_mean=Y_scaler.inverse_transform(Y_tst_np),
                                        pred_mean=Y_scaler.inverse_transform(Y_predmean_tst),
                                        pred_var=Y_scaler.var_ * Y_predvar_tst)
                    self.logging["rmse_test_origscale"].append(test_rmse)
                    self.logging["mnlp_test_origscale"].append(test_mnll)
                    print(f"[ORIG SCALE] <RMSE>: {test_rmse}, <MNLP>: {test_mnll}", end="")
                test_rmse = rmse(y_actual=Y_tst_np,
                                 y_pred=Y_predmean_tst)
                test_mnll = mnlp_np(actual_mean=Y_tst_np,
                                    pred_mean=Y_predmean_tst,
                                    pred_var=Y_predvar_tst)
                self.logging["rmse_test_normscale"].append(test_rmse)
                self.logging["mnlp_test_normscale"].append(test_mnll)
                print(f"[NORM SCALE] <RMSE>: {test_rmse}, <MNLP>: {test_mnll}")

    def _fill_prediction(self, prediction, x, x_var, with_var):
        """
        Write the predictive mean (and optionally variance) of already-warped inputs
        into ``prediction``.
        """
        PHI = self.kphi.transform(x, X_var=x_var)
        prediction.mean = torch.matmul(PHI, self.mu)
        if with_var:
            RPhiPred_solve = self._solve_triangular(self.R_lower, PHI.t(), upper=False)
            inv_beta = 1 / self.beta.forward()
            # 1/beta + (1/beta) * squared column norms of R_lower^{-1} PHI^T.
            prediction.var = (inv_beta + inv_beta * torch.norm(RPhiPred_solve, dim=0) ** 2).t()

    def predict(self,
                x,
                x_var=None,
                with_var=True,
                with_grad=False):
        """
        Predict at ``x`` with the weights computed by the last call to ``fit``.

        :param x:           Query inputs.
        :param x_var:       Variance of the query inputs, or None.
        :param with_var:    When True, also fill ``prediction.var``.
        :param with_grad:   When False, the mean and the variance are computed under
                            ``torch.no_grad()``. Also forwarded to the warp models.
        :return:            A ``Prediction`` with ``mean`` (and ``var`` when requested) set.
        :raises ValueError: When the warp model configuration is inconsistent.
        """

        prediction = Prediction()

        if self._has_warping():
            # Do the warping!
            x, x_var = self._warp_inputs(x, x_var, with_grad)

        if with_grad:
            self._fill_prediction(prediction, x, x_var, with_var)
        else:
            with torch.no_grad():
                self._fill_prediction(prediction, x, x_var, with_var)
        return prediction
