import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.utils as vutils
import torchvision.datasets as dset
import torchvision.transforms as transfroms
import torch.nn.functional as F


def gan_trainer(dataloader, net_D, net_G, device, beta1=0.5, lr_D=2e-4, lr_G=2e-4,
                dim_z=100, nc=3, ngf=64, ndf=64, epoch_number=50,
                verbose=5, net_D_pf=None, net_G_pf=None, path_f='.', delta=0, aux_D=None):
    """Train a generator/discriminator pair with the binary cross-entropy GAN loss.

    Each iteration performs, in order: one discriminator step on a real batch
    plus a generated batch, one generator step, and (if ``aux_D`` is given) one
    step of an auxiliary discriminator trained on the same real/generated
    batches. The auxiliary discriminator never contributes to the generator's
    gradients.

    Args:
        dataloader: Iterable of batches; element ``0`` of each batch is the
            tensor of real samples. ``len(dataloader)`` is used for logging.
        net_D: Discriminator module. Its output is flattened to 1-D and passed
            to ``nn.BCELoss``, so it is expected to emit one probability per
            sample.
        net_G: Generator module, called on noise of shape
            ``(batch, dim_z, 1, 1)``.
        device: Device on which noise and label tensors are created and to
            which real batches are moved.
        beta1: First Adam beta, shared by all optimizers.
        lr_D: Learning rate for ``net_D`` (and ``aux_D`` if present).
        lr_G: Learning rate for ``net_G``.
        dim_z: Number of channels of the latent noise.
        nc, ngf, ndf: Unused; kept for interface compatibility.
        epoch_number: Number of epochs to run.
        verbose: Progress is printed every ``10 * verbose`` iterations and
            checkpoints/sample grids are written every ``verbose`` epochs.
            Must be non-zero.
        net_D_pf, net_G_pf: Unused; kept for interface compatibility.
        path_f: Output directory for checkpoints; sample grids go to
            ``<path_f>/img_grid_G``.
        delta: Offset added to the epoch index in output file names.
        aux_D: Optional auxiliary discriminator module.

    Returns:
        Always ``0``.
    """
    import os  # local import: only needed to create the sample-grid directory

    # Fixed latent batch so sample grids are comparable across epochs.
    fixed_noise = torch.randn(64, dim_z, 1, 1, device=device)

    # Float labels: nn.BCELoss requires the target dtype to match the input.
    real_label = 1.0
    fake_label = 0.0

    optimizer_D = optim.Adam(net_D.parameters(), lr=lr_D, betas=(beta1, 0.999))
    if aux_D is not None:
        optimizer_aux_D = optim.Adam(aux_D.parameters(), lr=lr_D, betas=(beta1, 0.999))
    optimizer_G = optim.Adam(net_G.parameters(), lr=lr_G, betas=(beta1, 0.999))
    criterion = nn.BCELoss()

    print("Start")
    for epoch in range(epoch_number):
        for i, data in enumerate(dataloader, 0):
            # Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            net_D.zero_grad()
            real = data[0].to(device)
            b_size = real.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            output = net_D(real).view(-1)
            loss_D_real = criterion(output, label)
            loss_D_real.backward()

            # stats
            D_x = output.mean().item()

            noise = torch.randn(b_size, dim_z, 1, 1, device=device)
            fake = net_G(noise)
            label.fill_(fake_label)
            # Detach so the discriminator step does not backpropagate into G.
            output = net_D(fake.detach()).view(-1)
            loss_D_fake = criterion(output, label)
            loss_D_fake.backward()
            optimizer_D.step()

            # stats
            loss_D = loss_D_real.item() + loss_D_fake.item()
            D_G_z1 = output.mean().item()

            # Update G network: maximize log(D(G(z)))
            net_G.zero_grad()
            label.fill_(real_label)  # generator wants fakes classified as real
            output = net_D(fake).view(-1)
            loss_G = criterion(output, label)
            loss_G.backward()
            optimizer_G.step()

            # stats
            D_G_z2 = output.mean().item()

            # Update aux discriminator
            if aux_D is not None:
                # Clear gradients left over from the previous iteration.
                aux_D.zero_grad()
                # The generator graph was already consumed by loss_G.backward();
                # detach so this step only touches aux_D.
                batch = torch.cat([real, fake.detach()])
                label_true = torch.full((b_size,), real_label,
                                        dtype=torch.float, device=device)
                label_fake = torch.full((b_size,), fake_label,
                                        dtype=torch.float, device=device)
                label_aux = torch.cat([label_true, label_fake])

                output_aux = aux_D(batch).view(-1)
                loss_D_aux = criterion(output_aux, label_aux)
                loss_D_aux.backward()
                optimizer_aux_D.step()

            if (i % (10 * verbose) == 0):
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, epoch_number, i, len(dataloader),
                         loss_D, loss_G.item(), D_x, D_G_z1, D_G_z2))

        if (epoch % verbose == 0):
            # Make sure the output locations exist before writing to them.
            os.makedirs('%s/img_grid_G' % path_f, exist_ok=True)
            torch.save(net_G.state_dict(), '%s/net_G_epoch_%d.pth' % (path_f, epoch+delta))
            torch.save(net_D.state_dict(), '%s/net_D_epoch_%d.pth' % (path_f, epoch+delta))
            if aux_D is not None:
                torch.save(aux_D.state_dict(), '%s/aux_D_epoch_%d.pth' % (path_f, epoch+delta))
            # No gradients are needed for the preview samples.
            with torch.no_grad():
                samples = net_G(fixed_noise)
            vutils.save_image(samples, '%s/img_grid_G/net_G_samples_epoch_%03d.png' % (path_f, epoch+delta),
                              normalize=True)
    return 0
