import torch
import torch.nn as nn
from torch import optim

class MCELoss(nn.Module):
    """Cross-entropy style loss over ordered pairs scored by a discriminator.

    ``D`` is any callable taking two arguments; this class calls it as
    ``D(x_s, x_p)`` and ``D(x_p, x_s)``.
    NOTE(review): the outputs of ``D`` are treated as probabilities, i.e.
    assumed to lie in [0, 1] -- confirm against the discriminator used.
    """

    def __init__(self, D):
        super(MCELoss, self).__init__()
        self.D = D

    def loss(self, x_s, x_p):
        """Return ``-(log D(x_s, x_p) + log(1 - D(x_p, x_s)))`` element-wise.

        Both log arguments are kept at or above ``1e-12`` so the result stays
        finite even when ``D`` saturates at exactly 0 or 1.

        Args:
            x_s: first input forwarded to ``D``.
            x_p: second input forwarded to ``D``.

        Returns:
            Tensor with the un-reduced loss (same shape as the output of ``D``).
        """
        b = 1e-12
        # Affine smoothing bounds the discriminator output away from 0.
        d_sp = (1. - b) * self.D(x_s, x_p) + b
        d_ps = (1. - b) * self.D(x_p, x_s) + b

        # The smoothing above does not bound d_ps away from 1, so 1 - d_ps can
        # be exactly 0 and log() would return -inf. Clamp the complement too.
        one_minus_d_ps = torch.clamp(1. - d_ps, min=b)

        return - (torch.log(d_sp) + torch.log(one_minus_d_ps))

    def forward(self, x_s, x_p):
        """Alias of :meth:`loss` so the module can be called directly."""
        return self.loss(x_s, x_p)

class UBLoss(nn.Module):
    """Loss of the form ``r + log(r)`` on a ratio of discriminator scores.

    ``r`` is the smoothed score of the swapped pair ``D(x_p, x_s)`` divided
    by the smoothed score of the original pair ``D(x_s, x_p)``.
    """

    def __init__(self, D):
        super().__init__()
        self.D = D

    def _smoothed(self, first, second):
        """Score ``(first, second)`` with ``D`` and shift it away from zero."""
        eps = 1e-12
        return (1. - eps) * self.D(first, second) + eps

    def loss(self, x_s, x_p):
        """Return the un-reduced ``r + log(r)`` loss for the pair."""
        # Evaluate the (x_s, x_p) ordering first, then the swapped ordering.
        score_fwd = self._smoothed(x_s, x_p)
        score_rev = self._smoothed(x_p, x_s)

        ratio = score_rev / score_fwd
        return ratio + torch.log(ratio)


def train_zmce(D, P, crit_type, dataloader, lr, epoch_n, device, verbose, ml_ls, path_f='.'):
    """Train discriminator ``D`` on pairs of loader samples and samples from ``P``.

    For every item of ``dataloader`` the loop draws ``x_p`` from ``P`` (without
    gradients), evaluates the selected criterion on ``(x_s, x_p)`` and takes one
    Adam step on the trainable parameters of ``D``.

    Args:
        D: discriminator module; called as ``D(a, b)`` and must expose
            ``D.ar(a, b)`` for the periodic report.
        P: sampler module, switched to eval mode; ``P(z)`` must return a
            2-tuple whose first element is the proposed sample.
        crit_type: ``'mce'`` selects ``MCELoss``, ``'UB'`` selects ``UBLoss``.
        dataloader: iterable yielding ``(x_s, z_s)`` pairs.
        lr: Adam learning rate.
        epoch_n: number of passes over ``dataloader``.
        device: device ``D``, ``P`` and the samples are moved to.
        verbose: report/checkpoint every ``verbose`` iterations of an epoch;
            a falsy value (``0``, ``None``, ``False``) disables periodic reports.
        ml_ls: ``MultiStepLR`` milestones. The scheduler is stepped once per
            batch, so they are counted in iterations, not epochs.
        path_f: directory receiving the checkpoints and history files.

    Returns:
        Always ``0``.

    Raises:
        NotImplementedError: if ``crit_type`` is not one of the supported names.

    Side effects:
        Writes ``net_D_co_<count>.pth`` checkpoints plus ``_ar_history``,
        ``_D_order`` and ``_loss_history`` into ``path_f``; leaves ``D`` on the
        CPU when it returns.
    """
    read_names = []
    loss_history = []
    ar_history = []
    count_object = 0

    D.to(device)
    P.eval().to(device)

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, D.parameters()), lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=ml_ls, gamma=0.1)

    if crit_type == 'mce':
        criterion = MCELoss(D)
    elif crit_type == 'UB':
        criterion = UBLoss(D)
    else:
        raise NotImplementedError(
            "unsupported crit_type %r; expected 'mce' or 'UB'" % (crit_type,))

    print('Start')
    for epoch in range(epoch_n):
        for i, data in enumerate(dataloader, 0):
            x_s, z_s = data
            # Dim 1 is counted as the number of objects and index 0 is taken
            # below -- presumably the loader wraps pre-batched tensors in a
            # leading singleton dimension (TODO confirm against the dataset).
            b_size = x_s.size(1)
            count_object += b_size

            with torch.no_grad():
                z_s = z_s.view(z_s.size(1), z_s.size(2))
                # NOTE(review): z_s is handed to P without an explicit
                # .to(device) -- confirm P handles device placement itself.
                x_p, _ = P(z_s)

            x_s, x_p = x_s[0].to(device), x_p.to(device)
            loss = torch.mean(criterion.loss(x_s, x_p))
            loss_history.append([count_object, loss.item()])

            D.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Guard against verbose == 0, which would otherwise raise
            # ZeroDivisionError in the modulo below.
            if verbose and i % verbose == 0:
                with torch.no_grad():
                    d_sp = D(x_s, x_p)
                    d_ps = D(x_p, x_s)
                    C = torch.mean(d_ps / d_sp).item()
                    ar = torch.mean(D.ar(x_s, x_p)).item()

                print('[%d/%d][%d/%d]\tLoss: %.4f\tC: %.4f\tAR :%.4f'
                      % (epoch, epoch_n - 1, i, len(dataloader), loss.item(), C, ar))
                save_name = '%s/net_D_co_%d.pth' % (path_f, count_object)
                read_names.append(save_name)
                # Checkpoints are stored from the CPU copy of the weights; the
                # model is moved back afterwards so training can continue.
                torch.save(D.to('cpu').state_dict(), save_name)
                ar_history.append(ar)
                D.to(device)

    # Pair each periodic checkpoint with its reported AR value; this is built
    # before the final checkpoint name is appended, so that one has no entry.
    ar_D = dict(zip(read_names, ar_history))
    torch.save(ar_D, '%s/_ar_history' % path_f)
    save_name = '%s/net_D_co_%d.pth' % (path_f, count_object)
    read_names.append(save_name)
    torch.save(D.to('cpu').state_dict(), save_name)
    torch.save(read_names, '%s/_D_order' % path_f)
    torch.save(loss_history, '%s/_loss_history' % path_f)
    torch.cuda.empty_cache()
    print(path_f)
    return 0
