import torch
import torch.nn as nn
import torch.optim as optim

import torch.utils.data
import torchvision.utils as vutils
import torchvision.datasets as dset
import torchvision.transforms as transfroms
import torch.nn.functional as F


def diff_liphistz_penalty(x, y, net_D, C=10., eps=1e-12):
    """Two-sided Lipschitz penalty on finite differences of the critic.

    For each sample pair ``(x_i, y_i)`` in the batch this computes the
    difference quotient ``q_i = ||D(x_i) - D(y_i)||_2 / ||x_i - y_i||_2``
    and penalizes its squared deviation from 1. The result is averaged
    over the batch.

    Args:
        x: Batch of inputs; dim 0 is the batch dimension.
        y: Batch of inputs with the same batch size as ``x``.
        net_D: Critic called on ``x`` and ``y``. Its output must keep the
            batch dimension first (any trailing dims are flattened).
        C: Weight of the penalty.
        eps: Lower bound on the per-sample input distance, so identical
            pairs do not divide by zero (their quotient becomes 0).

    Returns:
        Scalar tensor ``mean_i C * (q_i - 1) ** 2``.
    """
    n = x.size(0)
    d_x = net_D(x)
    d_y = net_D(y)

    # Norms are taken per sample, flattening all non-batch dims. A single
    # norm over the whole batch would make the final .mean() a no-op and
    # penalize one batch-aggregated ratio instead of every pair.
    out_dist = (d_x - d_y).reshape(n, -1).norm(p=2, dim=1)
    in_dist = (x - y).reshape(n, -1).norm(p=2, dim=1).clamp_min(eps)
    norm_up = out_dist / in_dist

    penalty = (C * (norm_up - 1) ** 2).mean()
    return penalty


def gan_trainer(dataloader, net_D, net_G, device, beta1=0.5, lr_D=2e-4, lr_G=2e-4,
                dim_z=100, nc=3, ngf=64, ndf=64, epoch_number=50,
                verbose=5, net_D_pf=None, net_G_pf=None, path_f='.', delta=0, aux_D=None):
    """Train a generator/critic pair with a Lipschitz-penalized WGAN-style loss.

    Each iteration performs three steps:

    1. Update the critic ``net_D`` with ``mean D(fake) - mean D(real)`` plus
       ``diff_liphistz_penalty`` on interpolates between real images and
       extra generator samples.
    2. Update the generator ``net_G`` with ``-mean D(fake)``.
    3. If ``aux_D`` is given, train it with BCE to separate real (label 1)
       from generated (label 0) images. This step does not update ``net_G``.

    Args:
        dataloader: Iterable of batches; ``data[0]`` of each batch is used
            as the real image tensor.
        net_D: Critic being trained.
        net_G: Generator fed noise of shape ``(b, dim_z, 1, 1)``. Its output
            must broadcast with the real batch (it is used in interpolation).
        device: Device for noise, real batches and labels.
        beta1: Adam ``beta1`` for every optimizer (``beta2`` is 0.999).
        lr_D: Learning rate for ``net_D`` (and ``aux_D``).
        lr_G: Learning rate for ``net_G``.
        dim_z: Latent dimension of the noise.
        nc, ngf, ndf, net_D_pf, net_G_pf: Accepted for interface
            compatibility; not used in this function.
        epoch_number: Number of passes over ``dataloader``.
        verbose: Losses are printed every ``10 * verbose`` iterations;
            checkpoints and a sample grid are written every ``verbose``
            epochs. Must be non-zero.
        path_f: Output directory. Checkpoints go to ``path_f``; sample grids
            go to ``path_f/img_grid_G`` (created if missing).
        delta: Offset added to the epoch index in output filenames
            (presumably for continuing numbering when resuming a run).
        aux_D: Optional auxiliary discriminator. ``nn.BCELoss`` requires
            its outputs to be probabilities in [0, 1].

    Returns:
        0
    """
    import os

    # Target labels for the auxiliary BCE discriminator. These were
    # previously referenced without being defined, so any call with aux_D
    # raised NameError.
    real_label = 1.
    fake_label = 0.

    fixed_noise = torch.randn(64, dim_z, 1, 1, device=device)

    if aux_D is not None:
        optimizer_aux_D = optim.Adam(aux_D.parameters(), lr=lr_D, betas=(beta1, 0.999))
        aux_criterion = nn.BCELoss()
    optimizer_D = optim.Adam(net_D.parameters(), lr=lr_D, betas=(beta1, 0.999))
    optimizer_G = optim.Adam(net_G.parameters(), lr=lr_G, betas=(beta1, 0.999))

    # save_image does not create directories; make sure the sample folder exists.
    os.makedirs(os.path.join(path_f, 'img_grid_G'), exist_ok=True)

    print("Start")
    for epoch in range(epoch_number):
        for i, data in enumerate(dataloader, 0):
            # ---- Update D (critic) ----
            real = data[0].to(device)
            b_size = real.size(0)
            output_real = net_D(real).view(-1)

            noise = torch.randn(b_size, dim_z, 1, 1, device=device)
            fake = net_G(noise)
            # Detach so D's loss does not backprop into G; the graph of
            # `fake` is kept for the generator step below.
            output_fake = net_D(fake.detach()).view(-1)

            loss = output_fake.mean() - output_real.mean()

            # Random interpolation weights, one per sample.
            u1 = torch.rand(b_size, 1, 1, 1).to(device)
            u2 = torch.rand(b_size, 1, 1, 1).to(device)

            noise_inter_x = torch.randn(b_size, dim_z, 1, 1, device=device)
            noise_inter_y = torch.randn(b_size, dim_z, 1, 1, device=device)
            # These samples only feed the critic penalty. Only optimizer_D
            # steps on loss_D, so no graph back into net_G is needed.
            with torch.no_grad():
                fake_inter_x = net_G(noise_inter_x)
                fake_inter_y = net_G(noise_inter_y)
            inter_x = (1. - u1) * real + u1 * fake_inter_x
            inter_y = (1. - u2) * real + u2 * fake_inter_y
            penalty = diff_liphistz_penalty(inter_x, inter_y, net_D)

            loss_D = loss + penalty
            net_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()

            # ---- Update G ----
            net_G.zero_grad()
            loss_G = - net_D(fake).mean()
            loss_G.backward()
            optimizer_G.step()

            # ---- Update auxiliary discriminator ----
            if aux_D is not None:
                # `fake` must be detached: its graph through net_G was freed
                # by loss_G.backward(), and aux_D must not push gradients
                # into the generator.
                batch = torch.cat([real, fake.detach()])
                label_true = torch.full((b_size,), real_label, device=device)
                label_fake = torch.full((b_size,), fake_label, device=device)
                label = torch.cat([label_true, label_fake])

                # Clear last iteration's gradients; otherwise they accumulate.
                aux_D.zero_grad()
                output_aux = aux_D(batch).view(-1)
                loss_D_aux = aux_criterion(output_aux, label)
                loss_D_aux.backward()
                optimizer_aux_D.step()

            if i % (10 * verbose) == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tPenalty: %.4f' %
                      (epoch, epoch_number, i, len(dataloader), loss_D.item(), loss_G.item(), penalty.item()))

        if epoch % verbose == 0:
            torch.save(net_G.state_dict(), '%s/net_G_epoch_%d.pth' % (path_f, epoch+delta))
            torch.save(net_D.state_dict(), '%s/net_D_epoch_%d.pth' % (path_f, epoch+delta))
            if aux_D is not None:
                torch.save(aux_D.state_dict(), '%s/aux_D_epoch_%d.pth' % (path_f, epoch+delta))
            # Sampling for visualization only; no autograd graph needed.
            with torch.no_grad():
                samples = net_G(fixed_noise)
            vutils.save_image(samples.detach(), '%s/img_grid_G/net_G_samples_epoch_%03d.png' % (path_f, epoch+delta),
                              normalize=True)
    return 0
