import torch
import torch.nn as nn
import numpy as np
import warnings


def poisson_loss(output, target):
    """Mean Poisson negative log-likelihood for log-rate predictions.

    ``output`` is treated as the log of the Poisson rate, so the rate is
    ``exp(output)``. Returns ``mean(exp(output) - target * output)``. This is
    the Poisson NLL with the ``log(target!)`` term dropped, since that term
    does not depend on ``output``.
    """
    log_rate = output
    per_element = torch.exp(log_rate) - target * log_rate
    return per_element.mean()


def neg_log_likelihood(output, target):
    """Full Poisson NLL for rate predictions, summed and divided by ``output.size(0)``.

    ``output`` is used directly as the Poisson rate (``log_input=False``). The
    original note says it has already gone through a softplus. ``full=True``
    adds the Stirling-approximation term. All per-element losses are summed,
    then divided by the size of dim 0 (presumably the batch size; confirm).
    """
    criterion = nn.PoissonNLLLoss(log_input=False, full=True, reduction="none")
    per_element = criterion(output, target)
    batch_size = output.size(0)
    return per_element.sum() / batch_size


def kl_divergence(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, 1)), summed over all elements, divided by ``mu.size(0)``.

    Uses the closed form ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    variance = logvar.exp()
    summed = torch.sum(1 + logvar - mu.pow(2) - variance)
    return -0.5 * summed / mu.size(0)


def latent_regularizer(z, cfg):
    """Return an L2 penalty on ``z`` plus a scaled multi-lag smoothness penalty.

    The smoothness term sums ``(z[..., t] - z[..., t + lag])**2`` along dim 2
    for ``lag = 1 .. cfg.training.td_k`` (default 5). Each lag is weighted by
    ``(lag + 1) / 2``. That term is divided by ``cfg.training.latent_beta``
    and multiplied by ``cfg.training.latent_td_beta`` before being added to
    ``sum(z**2)``.

    NOTE(review): the weight here grows with the lag, whereas the ``_v2``
    variants use decaying weights. Confirm this is intended.

    Args:
        z: Latent tensor indexed as ``z[:, :, t]`` (assumed ``(B, C, T)``; TODO confirm).
        cfg: Config whose ``training`` node supports ``.get`` and attribute access.
    """
    training = cfg.training
    max_lag = training.get("td_k", 5)  # number of time differences
    # GP-prior-like penalty on differences between time steps.
    smoothness = sum(
        ((z[:, :, :-lag] - z[:, :, lag:]) ** 2 * ((lag + 1) / 2)).sum()
        for lag in range(1, max_lag + 1)
    )
    squared_norm = torch.sum(z**2)
    return squared_norm + smoothness / training.latent_beta * training.latent_td_beta


def latent_regularizer_v2(z, cfg):
    """Regularizer penalizing latents' squared distance from 0 and from nearby time steps.

    Returns ``sum(z**2)`` plus a smoothness term. The smoothness term sums
    ``(z[..., t] - z[..., t + lag])**2`` along dim 2 for
    ``lag = 1 .. cfg.training.td_k`` (default 5). Each lag is weighted by
    ``1 / (1 + lag)``, so nearer time steps count more.

    The smoothness term is divided by ``cfg.training.latent_beta`` and
    multiplied by ``cfg.training.latent_td_beta``. The caller later scales the
    whole result by ``latent_beta``, which therefore effectively affects only
    the L2 part. The temporal-difference part ends up scaled only by
    ``latent_td_beta``.

    Args:
        z: Latent tensor indexed as ``z[:, :, t]`` (assumed ``(B, C, T)``; TODO confirm).
        cfg: Config whose ``training`` node supports ``.get`` and attribute access.
    """
    training = cfg.training
    max_lag = training.get("td_k", 5)  # number of time differences
    smoothness = sum(
        ((z[:, :, :-lag] - z[:, :, lag:]) ** 2 * (1 / (1 + lag))).sum()
        for lag in range(1, max_lag + 1)
    )
    squared_norm = torch.sum(z**2)
    return squared_norm + smoothness / training.latent_beta * training.latent_td_beta



def latent_squared_regularizer_v2(z, cfg):
    """Like ``latent_regularizer_v2`` but with lag weights ``1 / (1 + lag**2)``.

    Returns ``sum(z**2)`` plus a smoothness term. The smoothness term sums
    ``(z[..., t] - z[..., t + lag])**2`` along dim 2 for
    ``lag = 1 .. cfg.training.td_k`` (default 5). It is divided by
    ``cfg.training.latent_beta`` and multiplied by
    ``cfg.training.latent_td_beta``.

    Args:
        z: Latent tensor indexed as ``z[:, :, t]`` (assumed ``(B, C, T)``; TODO confirm).
        cfg: Config whose ``training`` node supports ``.get`` and attribute access.
    """
    training = cfg.training
    max_lag = training.get("td_k", 5)  # number of time differences
    smoothness = sum(
        ((z[:, :, :-lag] - z[:, :, lag:]) ** 2 * (1 / (1 + lag**2))).sum()
        for lag in range(1, max_lag + 1)
    )
    squared_norm = torch.sum(z**2)
    return squared_norm + smoothness / training.latent_beta * training.latent_td_beta


class GPNLL(nn.Module):
    """Quadratic-form penalty ``x^T Q x`` under a stationary GP prior over time.

    A covariance over the integer time points ``0 .. T-1`` is built from a
    Cauchy or RBF kernel. It is jittered with ``covariance_eps * I`` and
    symmetrised. Two precision matrices are precomputed and registered as
    buffers:

    * ``full_precision_matrix``: the inverse of the covariance.
    * ``banded_precision_matrix``: the inverse of the covariance after zeroing
      every entry farther than ``bandwidth`` from the diagonal. This is the
      inverse of a band-truncated *covariance*; the truncated covariance is not
      guaranteed to be positive definite.

    The (jittered) covariance itself is also stored as ``covariance_matrix``.

    Args:
        T: Maximum sequence length the prior is built for.
        lengthscale: Kernel lengthscale, in time steps.
        bandwidth: Half-width of the band kept in the truncated covariance.
        kernel: ``"cauchy"`` or ``"rbf"``.
        covariance_eps: Diagonal jitter added to the covariance before inversion.

    Raises:
        ValueError: If ``kernel`` is not ``"cauchy"`` or ``"rbf"``.
    """

    def __init__(
        self, T, lengthscale=2.0, bandwidth=10, kernel="cauchy", covariance_eps=1e-2
    ):
        super().__init__()
        if kernel not in ["cauchy", "rbf"]:
            raise ValueError("Kernel must be one of 'cauchy' or 'rbf'")
        # Integer time grid 0..T-1; only used to build the matrices below.
        self.T = torch.arange(T).float()
        self.lengthscale = lengthscale
        self.bandwidth = bandwidth
        self.kernel = kernel
        self.covariance_eps = covariance_eps
        # Registered as buffers so they move with .to(device) and live in state_dict.
        self.register_buffer(
            "full_precision_matrix", self.compute_full_precision_matrix()
        )
        self.register_buffer(
            "banded_precision_matrix", self.compute_banded_precision_matrix()
        )
        self.register_buffer("covariance_matrix", self.compute_covariance_matrix())

        self.check_psd()

    def cauchy_kernel(self, t1, t2, lengthscale):
        """Cauchy kernel ``1 / (1 + ((t1 - t2) / lengthscale)**2)``."""
        return 1 / (1 + ((t1 - t2) / lengthscale) ** 2)

    def rbf_kernel(self, t1, t2, lengthscale):
        """RBF kernel ``exp(-(t1 - t2)**2 / lengthscale**2)`` (no 1/2 factor)."""
        return torch.exp(-((t1 - t2) ** 2) / lengthscale**2)

    @staticmethod
    def is_psd(mat):
        """Return True if ``mat`` is exactly symmetric with all eigenvalues >= 0.

        This is a positive *semi*-definiteness check. Exact ``==`` symmetry is
        used, which holds for matrices symmetrised as ``(M + M.T) / 2``.
        """
        return bool(
            (mat == mat.T).all() and (torch.linalg.eigvals(mat).real >= 0).all()
        )

    def check_psd(self):
        """Print symmetric/PSD diagnostics for the precision matrices."""
        n = len(self.T)
        # Central window of up to 60x60. The bounds are clamped so that short
        # grids (n < 60) do not wrap around through negative indices.
        lo, hi = max(0, n // 2 - 30), min(n, n // 2 + 30)
        print("banded precision is_psd:", self.is_psd(self.banded_precision_matrix))
        print("full precision is_psd:", self.is_psd(self.full_precision_matrix))
        print(
            "extracted full precision is_psd:",
            self.is_psd(self.full_precision_matrix[lo:hi, lo:hi]),
        )

    def compute_covariance_matrix(self):
        """Return the symmetrised kernel matrix over ``self.T`` plus ``covariance_eps * I``."""
        t1 = self.T[:, None]
        t2 = self.T[None, :]
        if self.kernel == "cauchy":
            covariance_matrix = self.cauchy_kernel(t1, t2, self.lengthscale)
        else:
            covariance_matrix = self.rbf_kernel(t1, t2, self.lengthscale)

        covariance_matrix = covariance_matrix + self.covariance_eps * torch.eye(
            covariance_matrix.shape[0]
        )
        covariance_matrix = (covariance_matrix + covariance_matrix.T) / 2
        return covariance_matrix

    def compute_full_precision_matrix(self):
        """Return the symmetrised inverse of the full covariance matrix."""
        covariance_matrix = self.compute_covariance_matrix()
        precision_matrix = torch.linalg.inv(covariance_matrix)
        precision_matrix = (precision_matrix + precision_matrix.T) / 2
        return precision_matrix

    def compute_banded_precision_matrix(self):
        """Return the symmetrised inverse of the band-truncated covariance matrix.

        Entries with ``|i - j| > bandwidth`` are zeroed before inversion.
        """
        covariance_matrix = self.compute_covariance_matrix()
        idx = torch.arange(covariance_matrix.shape[0])
        # Vectorised form of "for each (i, j): if |i - j| > bandwidth: set 0".
        outside_band = (idx[:, None] - idx[None, :]).abs() > self.bandwidth
        covariance_matrix = covariance_matrix.masked_fill(outside_band, 0.0)
        precision_matrix = torch.linalg.inv(covariance_matrix)
        precision_matrix = (precision_matrix + precision_matrix.T) / 2
        return precision_matrix

    def forward(self, x, method="full", reduction="sum"):
        """
        Compute the quadratic form ``x^T Q x`` of the data under the GP prior.

        When ``x`` is shorter than the prior, the central ``T x T`` window of
        the precision matrix is used. The matrix is always divided by the
        centre diagonal entry of the full-length precision matrix (presumably
        to keep the penalty scale comparable across kernels/lengthscales).

        Args:
            x: Tensor of shape (B, C, T), where B is the batch size, C is the
                number of channels, and T is the number of time points.
                T must not exceed the length the prior was built for.
            method: ``"full"`` uses the exact precision matrix. Any other
                value uses the banded one and emits a warning.
            reduction: ``"sum"`` returns a scalar summed over B and C. Any
                other value returns the per-(B, C) tensor.

        Returns:
            A scalar tensor if ``reduction == "sum"``, otherwise a (B, C) tensor.

        Raises:
            ValueError: If ``x`` has more time points than the prior supports.
        """
        if method == "full":
            precision_matrix = self.full_precision_matrix
        else:
            warnings.warn(
                "Banded precision matrix has no guarantee to be positive definite."
            )
            precision_matrix = self.banded_precision_matrix

        max_T = precision_matrix.shape[0]
        _, _, T = x.shape
        if T > max_T:
            raise ValueError(
                f"Input has {T} time points but the prior supports at most {max_T}."
            )

        centre = max_T // 2
        scale = precision_matrix[centre, centre]
        # Take exactly T rows/cols starting at centre - T // 2. The previous
        # "centre + T // 2" upper bound dropped one row/col for odd T.
        start = centre - T // 2
        window = precision_matrix[start : start + T, start : start + T] / scale

        result = torch.einsum("bci,ij,bcj->bc", x, window, x)

        return result.sum() if reduction == "sum" else result


# gpnll_prior = GPNLL(T=1024, lengthscale=1.0)  # max T = 1024


# def latent_regularizer_v3(z, cfg):
#     return gpnll_prior(z)


if __name__ == "__main__":
    import lovely_tensors

    lovely_tensors.monkey_patch()

    # Sanity-check a few Cauchy-kernel priors over up to 1024 time points.
    # Each entry is (lengthscale, covariance_eps).
    configs = [(1.0, 1e-2), (2.0, 1e-2), (4.0, 1e-6), (4.0, 1e-2)]
    for lengthscale, covariance_eps in configs:
        gp = GPNLL(
            T=1024,
            lengthscale=lengthscale,
            kernel="cauchy",
            covariance_eps=covariance_eps,
        )
        x = torch.randn(4, 8, 1024)

        result = gp(x)
        # Evaluate the banded path as well. The earlier copy-pasted script
        # computed this with method="full" and never printed it.
        result_banded = gp(x, method="banded")
        print("X shape:", tuple(x.shape))
        print("Result:", result)
        print("Banded result:", result_banded)
        print(gp.full_precision_matrix[100, 100])
