from __future__ import division
import numpy as np
import torch
import math 
import pyro.distributions as dist
from pyro.distributions import TorchDistribution
from torch.distributions import constraints

# Public API of this module (one exported name per line for clean diffs).
__all__ = [
    'isotropic_gauss_loglike',
    'laplace_prior',
    'isotropic_gauss_prior',
    'spike_slab_2GMM',
    'spike_slab_2GMM_pyro',
]

def isotropic_gauss_loglike(x, mu, sigma, do_sum=True):
    """Evaluate the log-density of ``x`` under an isotropic Gaussian N(mu, sigma^2).

    Args:
        x: values to score (tensor).
        mu: mean, broadcast against ``x``.
        sigma: standard deviation; must be a tensor because ``torch.log`` is
            applied to it.
        do_sum: when True, return the total log-density summed over every
            element; otherwise return the element-wise log-densities.

    Returns:
        A scalar tensor (``do_sum=True``) or a tensor of element-wise
        log-densities.
    """
    standardized = (x - mu) / sigma
    log_density = -(0.5) * math.log(2 * math.pi) - torch.log(sigma) - 0.5 * standardized ** 2
    if not do_sum:
        return log_density
    return log_density.sum()  # total over all weights


class laplace_prior(object):
    """Factorised Laplace prior with location ``mu`` and scale ``b``.

    Each element is scored with ``-log(2 b) - |x - mu| / b``.
    """

    def __init__(self, mu, b):
        self.mu = mu
        self.b = b

    def loglike(self, x, do_sum=True):
        """Return the Laplace log-density of ``x``.

        Args:
            x: tensor of values to score.
            do_sum: when True, sum the element-wise log-densities into a
                scalar; otherwise return them element-wise.
        """
        log_normaliser = -np.log(2 * self.b)
        elementwise = log_normaliser - torch.abs(x - self.mu) / self.b
        return elementwise.sum() if do_sum else elementwise


class isotropic_gauss_prior(object):
    """Factorised (isotropic) Gaussian prior N(mu, sigma^2).

    ``mu`` and ``sigma`` are exposed as plain instance attributes (not
    methods); ``spike_slab_2GMM`` reads them directly as ``N.mu`` / ``N.sigma``.
    """

    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma

        # Terms of the log-density that depend only on the parameters,
        # precomputed once instead of on every loglike() call.
        self.cte_term = -(0.5) * np.log(2 * np.pi)
        self.det_sig_term = -np.log(self.sigma)

    def loglike(self, x, do_sum=True):
        """Return the Gaussian log-density of ``x``.

        Args:
            x: tensor of values to score.
            do_sum: when True, sum the element-wise log-densities into a
                scalar; otherwise return them element-wise.
        """
        dist_term = -(0.5) * ((x - self.mu) / self.sigma) ** 2
        if do_sum:
            return (self.cte_term + self.det_sig_term + dist_term).sum()
        else:
            return (self.cte_term + self.det_sig_term + dist_term)

    def sample(self, size):
        """Draw samples of shape ``size`` from N(mu, sigma^2)."""
        return torch.randn(size) * self.sigma + self.mu


class spike_slab_2GMM(object):
    """Two-component Gaussian scale-mixture ("spike and slab") prior.

    Every element is treated as an independent draw from
    ``pi * N(mu1, sigma1^2) + (1 - pi) * N(mu2, sigma2^2)``.
    """

    def __init__(self, mu1, mu2, sigma1, sigma2, pi):
        self.N1 = isotropic_gauss_prior(mu1, sigma1)
        self.N2 = isotropic_gauss_prior(mu2, sigma2)

        self.pi1 = pi
        self.pi2 = (1 - pi)
        # math.log(0) raises ValueError; a zero mixing weight simply means the
        # component never contributes, i.e. its log-weight is -inf.
        self.lpi1 = math.log(self.pi1) if self.pi1 != 0 else -math.inf
        self.lpi2 = math.log(self.pi2) if self.pi2 != 0 else -math.inf

    def loglike(self, x, do_sum=True):
        """Return the mixture log-density of ``x``.

        The mixture is formed per element (weights are independent), using
        logsumexp for numerical stability, and only then summed.

        Args:
            x: tensor of values to score.
            do_sum: when True (default) return the scalar total; otherwise
                return the element-wise mixture log-densities.
        """
        N1_ll = self.N1.loglike(x, do_sum=False) + self.lpi1
        N2_ll = self.N2.loglike(x, do_sum=False) + self.lpi2

        # torch.stack (not torch.tensor) keeps the autograd graph intact so
        # gradients of the prior flow back to ``x``.
        loglike = torch.logsumexp(torch.stack([N1_ll, N2_ll]), dim=0)

        if do_sum:
            return loglike.sum()
        return loglike

    def sample(self, size):
        """Draw samples of shape ``size``: pick a component per element, then sample it."""
        z = torch.bernoulli(torch.ones(size) * self.pi1)
        s1 = self.N1.sample(size)
        s2 = self.N2.sample(size)
        return s1 * z + s2 * (1 - z)

    def mu(self):
        """Mean of the mixture: ``pi1 * mu1 + pi2 * mu2``."""
        return self.pi1 * self.N1.mu + self.pi2 * self.N2.mu

    def sigma(self):
        """Standard deviation of the mixture (law of total variance)."""
        return math.sqrt(self.pi1 * self.N1.sigma ** 2 + self.pi2 * self.N2.sigma ** 2 \
            + self.pi1 * self.N1.mu ** 2 + self.pi2 * self.N2.mu ** 2 - self.mu() ** 2)

class spike_slab_2GMM_pyro(TorchDistribution):
    """Pyro distribution for a two-component Gaussian ("spike and slab") mixture.

    Density, evaluated element-wise:
    ``pi * N(mu1, sigma1^2) + (1 - pi) * N(mu2, sigma2^2)``.
    """
    support = constraints.real
    has_rsample = False
    arg_constraints = {}

    def __init__(self, mu1, mu2, sigma1, sigma2, pi):
        # Name this class (not the base) so the full MRO is honoured.
        super(spike_slab_2GMM_pyro, self).__init__()
        self.N1 = dist.Normal(mu1, sigma1)
        self.N2 = dist.Normal(mu2, sigma2)
        self.pi1 = pi
        self.pi2 = (1 - pi)
        # NOTE(review): requires ``pi`` to be a torch tensor (``.device`` is read).
        self.device = self.pi1.device

    def log_prob(self, x):
        """Return the element-wise mixture log-density of ``x``."""
        N1_ll = self.N1.log_prob(x)
        N2_ll = self.N2.log_prob(x)

        # Numerical stability trick -> unnormalising logprobs will underflow otherwise
        max_loglike = torch.max(N1_ll, N2_ll)
        normalised_like = self.pi1 * torch.exp(N1_ll - max_loglike) + self.pi2 * torch.exp(N2_ll - max_loglike)
        loglike = torch.log(normalised_like) + max_loglike

        return loglike

    def sample(self, size=torch.Size()):
        """Draw samples with sample shape ``size``.

        Args:
            size: an ``int`` or a shape (tuple / ``torch.Size``). It defaults to
                ``torch.Size()`` because Pyro's ``__call__`` passes a
                ``sample_shape`` and callers may invoke ``.sample()`` bare.
        """
        # torch's Normal.sample rejects a bare int, so normalise to torch.Size.
        sample_shape = torch.Size([size]) if isinstance(size, int) else torch.Size(size)
        z = torch.bernoulli(torch.ones(sample_shape, device=self.device) * self.pi1)
        s1 = self.N1.sample(sample_shape)
        s2 = self.N2.sample(sample_shape)
        return s1 * z + s2 * (1 - z)


class HorseShoe(object):
    """Prior object exposing ``loglike(x, do_sum=True)``.

    NOTE(review): despite its name, ``loglike`` evaluates the Laplace
    log-density ``-log(2 b) - |x - mu| / b``, the same formula as
    ``laplace_prior``, not a horseshoe density. Confirm whether a real
    horseshoe prior was intended before relying on this class.
    """

    def __init__(self, mu, b):
        self.mu, self.b = mu, b

    def loglike(self, x, do_sum=True):
        """Return the (Laplace) log-density of ``x``, summed when ``do_sum`` is True."""
        scaled_deviation = torch.abs(x - self.mu) / self.b
        values = -np.log(2 * self.b) - scaled_deviation
        if not do_sum:
            return values
        return values.sum()
