from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms

from itertools import chain

from core import metrics, dataloader, utils, samplers
from scripts.vae import models
from time import time
from core.logger import Logger

import numpy as np

# Command-line options. They are parsed once, at import time, and `main`
# reads them through the module-level `flags` namespace:
#   --nocuda      select the CPU device instead of CUDA
#   --checkpoint  path handed to torch.load for the generator weights
parser = argparse.ArgumentParser()
parser.add_argument("--nocuda", default=False, action="store_true")
parser.add_argument("--checkpoint", default="", type=str)
flags = parser.parse_args()


class Generator(nn.Module):
    """DCGAN-style generator that decodes latent codes into images.

    A latent tensor with spatial size 1x1 (as produced in ``main``:
    ``(batch, dim_z, 1, 1)``) is upsampled 1 -> 4 -> 8 -> 16 -> 32 -> 64,
    giving ``(batch, nc, 64, 64)`` outputs squashed into [-1, 1] by Tanh.

    Args:
        nc: number of output image channels.
        dim_z: number of latent channels expected on input.
        ngf: base number of generator feature maps.

    All layers live in ``self.main`` in a fixed order, so the
    ``state_dict`` keys (``main.0.weight``, ``main.1.weight``, ...) line up
    with saved checkpoints.
    """

    def __init__(self, nc=3, dim_z=100, ngf=64):
        super(Generator, self).__init__()

        def up_block(c_in, c_out, stride, padding):
            # ConvTranspose -> BatchNorm -> ReLU. The conv carries no bias;
            # the following BatchNorm provides the shift.
            return [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # First stage lifts the 1x1 code to 4x4; each later stage doubles H and W.
        layers = up_block(dim_z, ngf * 8, 1, 0)
        for mult_in, mult_out in ((8, 4), (4, 2), (2, 1)):
            layers.extend(up_block(ngf * mult_in, ngf * mult_out, 2, 1))
        # Project down to image channels and bound the output to [-1, 1].
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Decode a batch of latent codes into a batch of images."""
        images = self.main(input)
        return images


class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring 64x64 images.

    ``forward`` returns one sigmoid probability per image, flattened to shape
    ``(batch,)``. ``dre`` returns the matching density-ratio estimate
    D(x) / (1 - D(x)).

    Args:
        nc: number of input image channels.
        ndf: base number of discriminator feature maps.
    """

    def __init__(self, nc=3, ndf=64):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(ndf * 2, affine=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(ndf * 4, affine=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(ndf * 8, affine=False),
            nn.LeakyReLU(0.2, inplace=True),

            # 4x4 feature map -> a single logit per image.
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return the probability D(input) for each image, shape ``(batch,)``."""
        return self.main(input).flatten()

    def dre(self, x):
        """Return the density-ratio estimate D(x) / (1 - D(x)) per image.

        Since D = sigmoid(logit), D / (1 - D) == exp(logit). Computing it from
        the pre-sigmoid logit needs a single forward pass (the previous form
        ran the network twice). It also stays finite where the float32 sigmoid
        rounds to exactly 1.0 (logits above roughly 17); there the old quotient
        divided by zero. ``exp`` itself overflows only for logits above ~88.
        """
        # Every layer except the trailing Sigmoid, i.e. the raw logit.
        logits = self.main[:-1](x).flatten()
        return torch.exp(logits)


def main():
    """Train a discriminator against a pre-trained generator on CelebA.

    Generator weights are loaded from ``--checkpoint``. Only the
    discriminator's parameters are optimized, with summed binary
    cross-entropy (real images labelled 1, generated images 0).

    Logging happens every ``log_every`` batches (the first batch of each
    epoch included) and on the last batch of every epoch. Each log records:
    - the mean loss per real image;
    - ``AR``, the mean over the batch of
      ``D(y)/(1-D(y)) * (1-D(x))/D(x)`` for fake ``y`` and real ``x``;
    - the learning rate;
    - the elapsed time.
    Each log also checkpoints both state dicts.

    The generator stays in training mode, so its BatchNorm layers use
    per-batch statistics (and update their running buffers).
    NOTE(review): confirm this matches how samples are drawn downstream.

    Raises:
        ValueError: if ``--checkpoint`` was not supplied.
    """
    if not flags.checkpoint:
        # Fail before creating logs/models; torch.load('') would fail later anyway.
        raise ValueError("--checkpoint must point to a saved Generator state_dict")

    fmt = {'lr': '.4f',
           'tr_loss': '.4f',
           'AR': '.4f',
           'time': '.3f'}
    logger = Logger(base='./logs/GAN-MH-CELEBA', name="cross-ent", fmt=fmt)
    torch.manual_seed(322)
    device = torch.device("cpu") if flags.nocuda else torch.device("cuda")

    decoder = Generator().to(device)
    discriminator = Discriminator().to(device)
    decoder_dict = torch.load(flags.checkpoint, map_location='cpu')
    decoder.load_state_dict(decoder_dict)

    dataroot = "../../data/celeba"
    image_size = 64
    trainset = datasets.ImageFolder(root=dataroot,
                                    transform=transforms.Compose([
                                        transforms.Resize(image_size),
                                        transforms.CenterCrop(image_size),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                    ]))
    batch_size = 256
    lr_start = 1e-7
    log_every = 200
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
    optimizer = optim.Adam(discriminator.parameters(), lr=lr_start, betas=(0.9, 0.999))
    criterion = torch.nn.BCELoss(reduction='sum')

    epochs = 4
    num_batches = len(trainloader)
    for epoch in range(epochs):
        t0 = time()
        train_loss = 0.0
        AR = 0.0
        length = 0
        # ImageFolder class labels are irrelevant here; real/fake labels are built below.
        for i, (real_images, _) in enumerate(trainloader):
            real_images = real_images.to(device)
            n_real = real_images.shape[0]
            # Noise is drawn on the CPU generator and then moved, as before,
            # so the random stream for a given seed is unchanged.
            z = torch.randn([n_real, 100, 1, 1]).to(device)
            # The generator is not trained: skip building its autograd graph.
            with torch.no_grad():
                fake_images = decoder(z)
            n_fake = fake_images.shape[0]
            batch_images = torch.cat([real_images, fake_images], 0)
            batch_labels = torch.cat([torch.ones(n_real, device=device),
                                      torch.zeros(n_fake, device=device)])
            preds = discriminator(batch_images)
            loss = criterion(preds, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            preds_x = preds[:n_real].detach()
            preds_y = preds[n_real:].detach()
            # NOTE(review): this ratio is unbounded if a probability reaches 0 or 1.
            AR += torch.sum(preds_y / (1 - preds_y) * (1 - preds_x) / preds_x).item()
            train_loss += loss.item()
            length += n_real

            # Also flush on the last batch of each epoch, so the epoch tail and
            # the final trained discriminator are logged and checkpointed.
            if i % log_every == 0 or i == num_batches - 1:
                step = epoch * len(trainset) + i * batch_size
                logger.add(step, tr_loss=train_loss / length)
                logger.add(step, AR=AR / length)
                logger.add(step, lr=optimizer.param_groups[0]['lr'])
                logger.add(step, time=time() - t0)
                logger.iter_info()
                logger.save(silent=True)
                torch.save([decoder.state_dict(), discriminator.state_dict()],
                           logger.get_checkpoint(step))
                t0 = time()
                train_loss = 0.0
                AR = 0.0
                length = 0


if __name__ == "__main__":
    # Start training only when this file is executed as a script.
    main()
