import torch
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import argparse
from gan_trainer_bce import gan_trainer

def get_params(args=None):
    """Parse the command-line options of the GAN training script.

    Args:
        args: optional list of argument strings. ``None`` lets argparse read
            the process command line.

    Returns:
        argparse.Namespace holding one attribute per option declared below.

    Raises:
        SystemExit: raised by argparse when a required option is missing or a
            value is rejected (wrong type, or not among the allowed choices).
    """
    parser = argparse.ArgumentParser()

    # Data source. Only the listed dataset kinds are handled by get_dataloader,
    # so anything else is rejected here instead of failing later.
    parser.add_argument('--dataset', required=True, choices=['folder', 'cifar'], help='folder|cifar')
    parser.add_argument('--dataroot', required=True)
    parser.add_argument('--workers', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--image_size', type=int, default=64)

    # Network shapes.
    parser.add_argument('--dim_z', type=int, default=100)
    parser.add_argument('--nc', type=int, default=3, help='channels number')
    parser.add_argument('--ngf', type=int, default=64, help='number of feature maps in generator')
    parser.add_argument('--ndf', type=int, default=64, help='number of feature maps in discriminator')
    # Only these three variants have a matching model module in main().
    parser.add_argument('--normalization', required=True, choices=['spectral', 'instance', 'batch'],
                        help='spectral|instance|batch')

    # Optimisation and run length.
    parser.add_argument('--epoch_n', type=int, default=151)
    parser.add_argument('--lr_D', type=float, default=2e-4)
    parser.add_argument('--lr_G', type=float, default=2e-4)
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--device', required=True)

    # Logging, checkpoints and the optional auxiliary discriminator.
    parser.add_argument('--verbose', type=int, default=5, help='higher -> less saving, printing .etc')
    parser.add_argument('--net_D_pf', default=None, help='path to load discriminator net')
    parser.add_argument('--net_G_pf', default=None, help='path to load generator net')
    parser.add_argument('--path_f', required=True, help='path to save learned networks')
    parser.add_argument('--delta_save', type=int, default=0)
    parser.add_argument('--aux_D', type=int, default=0)
    parser.add_argument('--aux_D_pf', default=None, help='path to load aux discriminator')

    return parser.parse_args(args=args)


def get_dataloader(dataset, dataroot, image_size, batch_size, workers):
    """Build a shuffling DataLoader over an image folder or CIFAR-10.

    Args:
        dataset: 'folder' for a torchvision ImageFolder rooted at ``dataroot``,
            or 'cifar' for torchvision CIFAR10 (requested with download=True).
        dataroot: root directory handed to the torchvision dataset.
        image_size: size passed to Resize (and to CenterCrop for 'folder').
        batch_size: batch size of the returned loader.
        workers: number of loader worker processes.

    Returns:
        torch.utils.data.DataLoader with shuffle=True.

    Raises:
        NotImplementedError: if ``dataset`` is not one of the supported kinds.
    """
    # Validate before building anything so a typo fails fast with a useful message.
    if dataset not in ('folder', 'cifar'):
        raise NotImplementedError(
            "unsupported dataset %r; expected 'folder' or 'cifar'" % (dataset,))

    # Both sources end with the same tensor conversion and normalisation.
    # NOTE: the mean/std tuples are hard-coded for three channels.
    shared_tail = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    if dataset == 'folder':
        # Folder images are resized and then centre-cropped to image_size.
        pipeline = transforms.Compose(
            [transforms.Resize(image_size), transforms.CenterCrop(image_size)] + shared_tail)
        data = dset.ImageFolder(root=dataroot, transform=pipeline)
    else:
        # CIFAR images are only resized; no crop is applied.
        pipeline = transforms.Compose([transforms.Resize(image_size)] + shared_tail)
        data = dset.CIFAR10(root=dataroot, download=True, transform=pipeline)

    return torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=workers)


def main():
    """Parse options, build data and networks, and run ``gan_trainer``.

    Steps, in order: enable cuDNN benchmark mode, parse the command line,
    import the model module matching ``--normalization``, build the
    dataloader, choose the device, create (or load) the discriminator,
    generator and optional auxiliary discriminator, then call ``gan_trainer``.

    Returns:
        0 after ``gan_trainer`` returns.

    Raises:
        NotImplementedError: if ``--normalization`` names no known model module.
    """
    cudnn.benchmark = True

    params = get_params()

    # Each normalization variant lives in its own module exposing the same
    # three names, so the import is chosen at run time.
    if params.normalization == 'spectral':
        from dcgan_spn import Generator, Discriminator, weights_init
    elif params.normalization == 'instance':
        from dcgan import Generator, Discriminator, weights_init
    elif params.normalization == 'batch':
        from dcgan_bn import Generator, Discriminator, weights_init
    else:
        raise NotImplementedError(
            "unsupported normalization %r; expected 'spectral', 'instance' or 'batch'"
            % (params.normalization,))

    def _init_or_load(net, checkpoint_pf, device):
        """Load ``net`` from ``checkpoint_pf`` if given, else apply ``weights_init``."""
        if checkpoint_pf is None:
            net.apply(weights_init)
        else:
            # NOTE(review): torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            net.load_state_dict(torch.load(checkpoint_pf, map_location=device))
        return net

    dataloader = get_dataloader(dataset=params.dataset, dataroot=params.dataroot,
                                image_size=params.image_size, batch_size=params.batch_size,
                                workers=params.workers)

    # --device is honoured only when CUDA is available; otherwise everything
    # runs on the CPU regardless of the requested value.
    device = torch.device(params.device) if torch.cuda.is_available() else torch.device("cpu")

    net_D = _init_or_load(Discriminator(params.nc, params.ndf).to(device), params.net_D_pf, device)
    net_G = _init_or_load(Generator(params.nc, params.dim_z, params.ngf).to(device),
                          params.net_G_pf, device)

    # The auxiliary discriminator is created only for --aux_D 1; any other
    # value leaves it as None (and --aux_D_pf is then unused).
    aux_D = None
    if params.aux_D == 1:
        aux_D = _init_or_load(Discriminator(params.nc, params.ndf).to(device),
                              params.aux_D_pf, device)

    # --delta_save is forwarded only when a discriminator checkpoint was given.
    # NOTE(review): presumably an offset used when resuming a run; confirm
    # against gan_trainer.
    delta = 0 if params.net_D_pf is None else int(params.delta_save)

    gan_trainer(dataloader=dataloader, net_D=net_D, net_G=net_G,
                device=device, lr_D=params.lr_D, lr_G=params.lr_G, beta1=params.beta1,
                dim_z=params.dim_z, epoch_number=params.epoch_n,
                verbose=params.verbose, path_f=params.path_f, delta=delta, aux_D=aux_D)

    return 0


if __name__ == '__main__':
    # main() returns a status code; hand it to the interpreter as the process
    # exit status instead of discarding it.
    raise SystemExit(main())
