import torch
import torch.distributions as distributions
import torch.distributions.transforms as transforms
import numpy as np


class TanhTransform(transforms.Transform):
    r"""Bijective transform :math:`y = \tanh(x)` from the real line onto ``(-1, 1)``.

    ``domain``, ``codomain``, ``bijective`` and ``sign`` are declared explicitly:
    the ``Transform`` base class does not provide concrete values for
    ``domain``/``codomain``, and ``TransformedDistribution`` reads them when a
    distribution is built from this transform.
    """

    domain = distributions.constraints.real
    codomain = distributions.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1  # tanh is strictly increasing

    # Plain Python float so the Jacobian expression stays pure tensor arithmetic.
    _LOG_2 = float(np.log(2.0))

    def __eq__(self, other):
        # The transform is stateless, so any two instances are interchangeable.
        return isinstance(other, TanhTransform)

    def __hash__(self):
        # Defining __eq__ alone would set __hash__ to None; keep the two consistent.
        return hash(TanhTransform)

    def _call(self, x):
        return torch.tanh(x)

    def _inverse(self, y):
        # atanh(y) = (log(1 + y) - log(1 - y)) / 2; yields +/-inf at y = +/-1.
        return (y.log1p() - (-y).log1p()) / 2

    def log_abs_det_jacobian(self, x, y):
        # log(1 - tanh(x)^2) rewritten as 2 * (log 2 - x - softplus(-2x)).
        # It is evaluated from x (not y), so it stays finite where tanh(x)
        # has rounded to +/-1.
        return 2.0 * (self._LOG_2 - x - torch.nn.functional.softplus(-2.0 * x))


class TanhGaussian(distributions.TransformedDistribution):
    """Normal distribution squashed through ``tanh`` onto ``(-1, 1)``.

    Samples are ``tanh(z)`` with ``z ~ Normal(loc, scale)``. ``log_prob`` includes
    the change-of-variables correction supplied by ``TanhTransform``.

    Args:
        loc: mean of the underlying (pre-tanh) Normal.
        scale: standard deviation of the underlying Normal; must be positive.
        validate_args: forwarded to both the base Normal and this distribution.
    """

    arg_constraints = {
        'loc': distributions.constraints.real,
        'scale': distributions.constraints.positive,
    }
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        base_dist = distributions.Normal(loc, scale, validate_args=validate_args)
        # cache_size=1: log_prob(y) on the tensor just returned by sample()/rsample()
        # reuses the exact pre-tanh value instead of recomputing atanh(y), which is
        # +/-inf (and makes log_prob NaN) once y has rounded to +/-1.
        super().__init__(base_dist, TanhTransform(cache_size=1), validate_args=validate_args)

    @property
    def loc(self):
        # Exposed so argument validation can check arg_constraints['loc'].
        return self.base_dist.loc

    @property
    def scale(self):
        # Exposed so argument validation can check arg_constraints['scale'].
        return self.base_dist.scale

    @property
    def mean(self):
        # NOTE: tanh applied to the base Normal's mean; this is NOT E[tanh(z)].
        # Kept for callers that use it as a deterministic action.
        return self.base_dist.mean.tanh()

    @property
    def stddev(self):
        # NOTE: standard deviation of the pre-tanh Normal, not of the squashed variable.
        return self.base_dist.stddev


# Names exported by ``from <this module> import *``.
__all__ = [
    "TanhGaussian",
]
