import sys
import os
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import argparse
from bce_pt import train_bce


def get_params(args=None):
    """Parse the command-line options of the discriminator-training script.

    Args:
        args: list of argument strings to parse, or ``None`` to let argparse
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``dataroot``,
        ``workers``, ``batch_size``, ``dim_z``, ``epoch_n``, ``verbose``,
        ``image_size``, ``lr``, ``lr_m``, ``device``, ``gan`` and ``G_pf``.
        ``lr_m`` and ``G_pf`` are ``None`` when not supplied.
    """
    # (flags, keyword options) pairs, registered in this order so that the
    # generated --help output lists them in the same sequence.
    option_specs = (
        (('-d', '--dataset'), dict(required=True, help='folder|cifar')),
        (('-r', '--dataroot'), dict(required=True)),
        (('-w', '--workers'), dict(type=int, default=40)),
        (('-b', '--batch_size'), dict(type=int, default=128)),
        (('--dim_z',), dict(type=int, default=100)),
        (('-e', '--epoch_n'), dict(type=int, default=2)),
        (('-v', '--verbose'), dict(type=int, default=100)),
        (('--image_size',), dict(type=int, default=64)),
        (('-l', '--lr'), dict(type=float, default=2e-4)),
        (('--lr_m',), dict(type=str)),
        (('--device',), dict(required=True)),
        (('--gan',), dict(required=True)),
        (('--G_pf',), dict(default=None)),
    )

    cli = argparse.ArgumentParser()
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args(args=args)


def get_dataloader(dataset, dataroot, image_size, batch_size, workers):
    """Build a shuffling DataLoader over an image dataset normalised to [-1, 1].

    Args:
        dataset: ``'folder'`` for a ``torchvision.datasets.ImageFolder`` rooted
            at ``dataroot`` (images resized, then center-cropped to
            ``image_size``), or ``'cifar'`` for CIFAR-10 that must already be
            present under ``dataroot`` (it is not downloaded).
        dataroot: root directory of the dataset.
        image_size: target side length passed to ``Resize`` (and ``CenterCrop``
            for ``'folder'``).
        batch_size: number of samples per batch.
        workers: number of DataLoader worker processes.

    Returns:
        torch.utils.data.DataLoader over the chosen dataset, with shuffling on.

    Raises:
        NotImplementedError: if ``dataset`` is neither ``'folder'`` nor ``'cifar'``.
    """
    # Validate first so an unsupported name fails fast with a clear message.
    if dataset not in ('folder', 'cifar'):
        raise NotImplementedError(
            "unsupported dataset %r (expected 'folder' or 'cifar')" % (dataset,))

    # ToTensor yields values in [0, 1]; per-channel mean/std of 0.5 maps them to [-1, 1].
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if dataset == 'folder':
        image_set = dset.ImageFolder(root=dataroot,
                                     transform=transforms.Compose([
                                         transforms.Resize(image_size),
                                         transforms.CenterCrop(image_size),
                                         transforms.ToTensor(),
                                         normalize,
                                     ]))
    else:  # 'cifar'
        image_set = dset.CIFAR10(root=dataroot, download=False,
                                 transform=transforms.Compose([
                                     transforms.Resize(image_size),
                                     transforms.ToTensor(),
                                     normalize,
                                 ]))

    return torch.utils.data.DataLoader(image_set,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=workers)


def _parse_milestones(spec):
    """Parse a comma-separated list of integers such as ``'30,60'``.

    Returns ``[]`` when ``spec`` is ``None`` or empty (``--lr_m`` has no
    default). Blank items, e.g. from a trailing comma, are ignored.

    Raises:
        ValueError: if a non-blank item is not an integer.
    """
    if not spec:
        return []
    return [int(item) for item in spec.split(',') if item.strip()]


def _load_architecture(gan):
    """Import and return ``(Generator, Discriminator, Wrapper)`` for ``gan``.

    Raises:
        NotImplementedError: if ``gan`` is not a supported architecture name.
    """
    if gan == 'dcgan_bn':
        from gan.dcgan_bn import Generator, Discriminator
        from D_wraper import WrapD as Wrapper
    elif gan == 'dcgan_isn':
        from gan.dcgan import Generator, Discriminator
        from D_wraper import WrapD as Wrapper
    elif gan == 'wpgan':
        from gan.wgan_wp import Generator, Discriminator
        from D_wraper import WrapCriticD as Wrapper
    else:
        raise NotImplementedError(
            "unsupported gan %r (expected 'dcgan_bn', 'dcgan_isn' or 'wpgan')" % (gan,))
    return Generator, Discriminator, Wrapper


def _weights_init(net, std=0.02):
    """Initialise one module in place.

    Conv weights are drawn from N(0, std). BatchNorm weights are drawn from
    N(1, std) and BatchNorm biases are set to 0. Other modules, and Conv
    biases, are left untouched.
    """
    classname = net.__class__.__name__
    if 'Conv' in classname:
        nn.init.normal_(net.weight.data, 0.0, std)
    elif 'BatchNorm' in classname and getattr(net, 'weight', None) is not None:
        # Non-affine BatchNorm layers have weight/bias set to None, so skip them.
        nn.init.normal_(net.weight.data, 1.0, std)
        nn.init.constant_(net.bias.data, 0)


def _build_nets(G_pf, Generator, Discriminator, Wrapper):
    """Load the pretrained generator and build a freshly initialised, wrapped discriminator.

    Returns:
        ``(G, W)``: ``G`` has its state dict loaded from ``G_pf`` and is in eval
        mode. ``W`` is ``Wrapper(D)`` around a newly initialised ``D``.
    """
    G = Generator()
    # NOTE: torch.load unpickles the file; only pass checkpoints you trust.
    G.load_state_dict(torch.load(G_pf, map_location='cpu'))
    G.eval()
    D = Discriminator()
    D.apply(_weights_init)
    return G, Wrapper(D)


def main():
    """Set up and run ``train_bce`` with a pretrained G and a fresh, wrapped D.

    Configuration comes from the command line (see ``get_params``). The parsed
    arguments are saved to ``<output folder>/_params_dict`` before training.

    Returns:
        The output folder name, ``'gG_tD_<dataset>_<gan>'``.

    Raises:
        ValueError: if ``--G_pf`` is missing or ``--lr_m`` contains a non-integer.
        NotImplementedError: for an unsupported ``--gan``, or an unsupported
            ``--dataset`` (from ``get_dataloader``).
    """
    # Make project-local packages (gan.*, D_wraper) importable from the CWD.
    sys.path.append(".")
    cudnn.benchmark = True

    params = get_params()
    # NOTE(review): presumably LR-decay milestones consumed by train_bce. An
    # empty list is passed when --lr_m is omitted -- confirm train_bce accepts it.
    lr_decay = _parse_milestones(params.lr_m)
    if params.G_pf is None:
        # Fail before building the dataloader; torch.load(None) fails opaquely.
        raise ValueError('--G_pf (path to the pretrained generator state_dict) is required')

    Generator, Discriminator, Wrapper = _load_architecture(params.gan)

    dataloader = get_dataloader(dataset=params.dataset, dataroot=params.dataroot,
                                image_size=params.image_size,
                                batch_size=params.batch_size,
                                workers=params.workers)

    device = torch.device(params.device)
    G, W = _build_nets(params.G_pf, Generator, Discriminator, Wrapper)

    path_f = 'gG_tD_' + params.dataset + '_' + params.gan
    os.makedirs(path_f, exist_ok=True)  # race-free "create if missing"

    torch.save(params, os.path.join(path_f, '_params_dict'))

    train_bce(D=W, G=G, dataloader=dataloader,
              lr=params.lr, epoch_n=params.epoch_n, device=device,
              verbose=params.verbose, ml_ls=lr_decay, path_f=path_f)
    return path_f


if __name__ == "__main__":  # run only when executed as a script, not on import
    main()
