import torch
import torch.nn as nn
from torch import log as tlog


class WrapD(nn.Module):
    """Wrap a discriminator ``D`` whose output is used directly as a probability.

    NOTE(review): assumes ``D`` returns values in [0, 1] (e.g. it ends in a
    sigmoid), one value per sample — confirm against the modules passed in.
    """

    def __init__(self, D):
        super(WrapD, self).__init__()
        self.D = D

    def forward(self, x):
        """Return ``D(x)`` unchanged (treated as a probability)."""
        prob = self.D(x)
        return prob

    def ar(self, x_s, x_p, eps=None):
        """Return ``min(1, (1 - d_s) / d_s * d_p / (1 - d_p))`` per sample.

        ``d_s = D(x_s)`` and ``d_p = D(x_p)``. The ratio is computed in log
        space, clipped at 0 (i.e. the ratio is capped at 1), then
        exponentiated. Presumably this is a Metropolis-Hastings style
        acceptance probability for ``x_p`` given ``x_s`` — verify with callers.

        Args:
            x_s: Batch scored as ``d_s``.
            x_p: Batch scored as ``d_p``.
            eps: Probabilities are clamped to ``[eps, 1 - eps]`` before taking
                logs. Defaults to the machine epsilon of ``D``'s output dtype.

        Returns:
            A 1-D CPU tensor with one entry per sample (length is the leading
            dimension). ``D`` must produce exactly one value per sample, e.g.
            shape ``(N,)`` or ``(N, 1)``, otherwise the reshape raises.
        """
        d_s = self.forward(x_s).cpu()
        d_p = self.forward(x_p).cpu()
        if eps is None:
            eps = torch.finfo(d_s.dtype).eps
        # Without this clamp, d == 0 or d == 1 gives log(0) = -inf, and e.g.
        # d_s == d_p == 1 produces -inf + inf = NaN instead of a ratio.
        d_s = torch.clamp(d_s, eps, 1. - eps)
        d_p = torch.clamp(d_p, eps, 1. - eps)
        # log1p(-d) is the accurate form of log(1 - d) for small d.
        log_prob = torch.log1p(-d_s) + tlog(d_p) - tlog(d_s) - torch.log1p(-d_p)
        # Cap the ratio at 1 (log ratio at 0) and flatten to one value per sample.
        log_prob = torch.clamp(log_prob, max=0.0).reshape(log_prob.size(0))
        return torch.exp(log_prob)

class WrapCriticD(nn.Module):
    """Wrap a critic ``D`` that outputs raw logits, exposing sigmoid probabilities.

    ``forward`` applies ``self.nonlinear`` (an ``nn.Sigmoid``) to ``D``'s output.
    """

    def __init__(self, D):
        super(WrapCriticD, self).__init__()
        self.D = D
        self.nonlinear = nn.Sigmoid()

    def forward(self, x):
        """Return ``nonlinear(D(x))``, i.e. the sigmoid probability of ``x``."""
        logit = self.D(x)
        prob = self.nonlinear(logit)
        return prob

    def ar(self, x_s, x_p):
        """Return ``min(1, (1 - d_s) / d_s * d_p / (1 - d_p))`` per sample.

        ``d_s`` / ``d_p`` are the probabilities returned by ``forward`` for
        ``x_s`` / ``x_p``. Presumably this is a Metropolis-Hastings style
        acceptance probability for ``x_p`` given ``x_s`` — verify with callers.

        For a sigmoid, ``log(d / (1 - d))`` is exactly the logit, so the log
        ratio equals ``logit_p - logit_s``. Computing it that way avoids
        sigmoid saturation. In float32, sigmoid(z) rounds to 1 for logits
        above roughly 17. Taking logs of those probabilities gives
        ``-inf + inf = NaN``.

        Returns:
            A 1-D CPU tensor with one entry per sample (length is the leading
            dimension). ``D`` must produce exactly one value per sample, e.g.
            shape ``(N,)`` or ``(N, 1)``, otherwise the reshape raises.
        """
        if isinstance(self.nonlinear, nn.Sigmoid):
            logit_s = self.D(x_s).cpu()
            logit_p = self.D(x_p).cpu()
            log_prob = logit_p - logit_s
        else:
            # The logit identity only holds for a sigmoid; if the nonlinearity
            # was swapped out, fall back to the explicit probability formula.
            d_s = self.forward(x_s).cpu()
            d_p = self.forward(x_p).cpu()
            log_prob = tlog(1. - d_s) + tlog(d_p) - tlog(d_s) - tlog(1. - d_p)
        # Cap the ratio at 1 (log ratio at 0) and flatten to one value per sample.
        log_prob = torch.clamp(log_prob, max=0.0).reshape(log_prob.size(0))
        return torch.exp(log_prob)

