import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import numpy as np
from torch.nn import functional as F


class DSU(nn.Module):
    """Distribution uncertainty module that randomly perturbs feature statistics.

    In training mode, with probability ``p``, the input is normalized along
    dim 3 and then re-scaled and re-shifted with a standard deviation and
    mean that have been jittered by Gaussian noise. The noise magnitude for
    each statistic is that statistic's own standard deviation taken across
    dim 0. In eval mode the module returns its input unchanged.

    Args:
        p (float): probability of applying the perturbation on a training
            forward pass; expected to lie in [0, 1].
        eps (float): constant added to every variance before taking the
            square root, keeping the result strictly positive.
    """

    def __init__(self, p=0.5, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.p = p
        # Multiplier applied to the sampled Gaussian noise; 0.0 disables the
        # jitter so the forward pass reconstructs its input.
        self.factor = 1.0

    def _reparameterize(self, mu, std):
        """Return ``mu + noise * std`` with ``noise ~ N(0, factor**2)``.

        The noise tensor takes its shape, dtype and device from ``std``.
        """
        epsilon = torch.randn_like(std) * self.factor
        return mu + epsilon * std

    def sqrtvar(self, x):
        """Return the standard deviation of ``x`` across dim 0.

        ``eps`` is added to the variance before the square root. Dim 0 is
        kept with size 1 so the result broadcasts against ``x``.
        """
        return (x.var(dim=0, keepdim=True) + self.eps).sqrt()

    def forward(self, x):
        """Apply the statistic perturbation to ``x``, or pass it through.

        Args:
            x: tensor with at least four dimensions. The original author's
                note gives the layout as ``B x K x C x L``; the code only
                relies on dim 0 and dim 3.

        Returns:
            ``x`` itself when the module is in eval mode, when the random
            gate skips the perturbation, or when dim 0 or dim 3 has fewer
            than two entries; otherwise a tensor of the same shape with
            perturbed statistics along dim 3.
        """
        # The gate draws from NumPy's global RNG, not torch's.
        if (not self.training) or np.random.random() > self.p:
            return x

        # The variance calls below use the default (unbiased) estimator,
        # which is NaN over a single element. With a batch of one or a
        # length-one last axis there is no spread to estimate, so skip the
        # perturbation instead of filling the output with NaN.
        if x.shape[0] < 2 or x.shape[3] < 2:
            return x

        # Per-slice statistics along dim 3, kept broadcastable against x.
        mean = x.mean(dim=3, keepdim=True)
        std = (x.var(dim=3, keepdim=True) + self.eps).sqrt()

        # Jitter each statistic by noise scaled with its spread over dim 0.
        # beta is sampled before gamma to keep the torch RNG draw order.
        beta = self._reparameterize(mean, self.sqrtvar(mean))
        gamma = self._reparameterize(std, self.sqrtvar(std))

        # Normalize with the true statistics, then restore with the
        # perturbed ones.
        x = (x - mean) / std
        return x * gamma + beta