import sys
import shutil
import os
import torch
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import argparse
from evaluation import g_scores_evaluation, imh_scores_evaluation


def get_params(args=None):
    """Parse the command-line options for the GAN scoring script.

    Args:
        args: Optional list of argument strings. When ``None``, argparse reads
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dataset', required=True, help='folder|cifar')
    parser.add_argument('-r', '--dataroot', required=True)
    # NOTE(review): main() in this file does not forward --workers to get_init.
    parser.add_argument('-w', '--workers', type=int, default=40)
    parser.add_argument('-b', '--batch_size', type=int, default=2048)
    parser.add_argument('--dim_z', type=int, default=100)
    parser.add_argument('--device', required=True)
    parser.add_argument('--gan', required=True)
    # Checkpoint paths. They default to None, but main() passes them straight
    # to torch.load, so in practice they must be supplied.
    parser.add_argument('--D_pf', default=None)
    parser.add_argument('--G_pf', default=None)
    parser.add_argument('--image_size', type=int, default=64)
    parser.add_argument('--act_path', required=True)
    parser.add_argument('--chain_size', type=int, default=10000)
    parser.add_argument('--g_loops', type=int, default=5)
    parser.add_argument('--mh_loops', type=int, default=5)
    # This used to be declared required=True alongside default=1, which made
    # the default unreachable. It is now optional and defaults to 1.
    parser.add_argument('--wrap', type=int, default=1)

    return parser.parse_args(args=args)


def get_init(dataset, dataroot, image_size, batch_size, num_workers=1):
    """Load one shuffled batch of real images.

    Args:
        dataset: ``'folder'`` for a torchvision ``ImageFolder`` rooted at
            *dataroot*, or ``'cifar'`` for CIFAR-10. CIFAR-10 is opened with
            ``download=False``, so it must already be on disk.
        dataroot: Root directory of the dataset.
        image_size: Size passed to ``Resize`` (and to ``CenterCrop`` for
            ``'folder'``).
        batch_size: Number of images to draw.
        num_workers: DataLoader worker processes. The default of 1 matches the
            previously hard-coded value.

    Returns:
        Element 0 of the first batch, i.e. the image tensor. Pixels are
        normalized with mean 0.5 and std 0.5 per channel.

    Raises:
        NotImplementedError: If *dataset* is not a supported name.
        RuntimeError: If the dataset yields no batches.
    """
    # Validate the name before touching torchvision so that bad input fails
    # fast and clearly.
    if dataset == 'folder':
        data = dset.ImageFolder(root=dataroot,
                                transform=transforms.Compose([
                                    transforms.Resize(image_size),
                                    transforms.CenterCrop(image_size),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                ]))
    elif dataset == 'cifar':
        data = dset.CIFAR10(root=dataroot, download=False,
                            transform=transforms.Compose([
                                transforms.Resize(image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                            ]))
    else:
        raise NotImplementedError('unsupported dataset: %r' % (dataset,))

    dataloader = torch.utils.data.DataLoader(data,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=num_workers)

    # Only the first batch is needed. An empty dataset would otherwise leak a
    # bare StopIteration to the caller.
    batch = next(iter(dataloader), None)
    if batch is None:
        raise RuntimeError('dataset %r at %r yielded no images' % (dataset, dataroot))
    return batch[0]


def main():
    """Score a pretrained GAN's samples with and without the MH procedure.

    Steps:
      1. Parse options with get_params().
      2. Import the Generator/Discriminator/wrapper selected by --gan.
      3. Load one batch of real images with get_init().
      4. Load G and the wrapped D from --G_pf / --D_pf.
      5. Run g_scores_evaluation --g_loops times and imh_scores_evaluation
         --mh_loops times.

    Each evaluation result is unpacked into a pair, named (is, fid) here and
    presumably Inception Score and FID. The collected lists are written with
    torch.save to ``<path_f>/_score_g`` and ``<path_f>/_score_mh``.

    Returns:
        0 on completion.

    Raises:
        NotImplementedError: If --gan is not one of the supported names.
    """
    # Put the current working directory on sys.path so the project-local
    # ``gan`` and ``D_wraper`` modules imported below can be found.
    sys.path.append(".")
    cudnn.benchmark = True
    params = get_params()

    # Imported lazily so only the selected architecture needs to be importable.
    if params.gan == 'dcgan_bn':
        from gan.dcgan_bn import Generator, Discriminator
        from D_wraper import WrapD as Wrapper
    elif params.gan == 'dcgan_isn':
        from gan.dcgan import Generator, Discriminator
        from D_wraper import WrapD as Wrapper
    elif params.gan == 'wpgan':
        from gan.wgan_wp import Generator, Discriminator
        from D_wraper import WrapCriticD as Wrapper
    else:
        raise NotImplementedError

    def get_nets(G_pf, D_pf):
        """Build G and the wrapped D, load their weights on CPU, set eval mode."""
        # torch.load unpickles the file, so only pass trusted checkpoints.
        G = Generator()
        G.load_state_dict(torch.load(G_pf, map_location='cpu'))
        G.eval()
        # The state dict is loaded into the wrapper, not the bare Discriminator.
        W = Wrapper(Discriminator())
        W.load_state_dict(torch.load(D_pf, map_location='cpu'))
        W.eval()
        return G, W

    init = get_init(dataset=params.dataset, dataroot=params.dataroot,
                    image_size=params.image_size, batch_size=params.batch_size)
    print(init.size(0))

    device = torch.device(params.device)
    # NOTE(review): G and W stay on CPU here, and --wrap is parsed but unused.
    # This relies on the evaluation helpers handling device placement; confirm.
    G, W = get_nets(G_pf=params.G_pf, D_pf=params.D_pf)

    # NOTE(review): there is no separator between the prefix and the dataset
    # name (e.g. 'gG_tD_scorecifar_dcgan_bn'). Kept as-is because other tooling
    # may rely on this directory name.
    path_f = 'gG_tD_score' + params.dataset + '_' + params.gan
    # Always start from an empty output directory; previous results are deleted.
    if os.path.exists(path_f):
        shutil.rmtree(path_f)
    os.makedirs(path_f)

    score_g = []
    for _ in range(params.g_loops):
        is_g, fid_g = g_scores_evaluation(G, params.dim_z, init, device,
                                          params.act_path, chain_size=params.chain_size)
        score_g.append((is_g, fid_g))
    # Save the generator-only scores now, so a failure in the MH runs below
    # does not discard them.
    torch.save(score_g, os.path.join(path_f, '_score_g'))

    print('-----------------')

    score_mh = []
    for _ in range(params.mh_loops):
        is_mh, fid_mh = imh_scores_evaluation(W, G, params.dim_z, init, device,
                                              params.act_path, chain_size=params.chain_size)
        score_mh.append((is_mh, fid_mh))

    torch.save(score_mh, os.path.join(path_f, '_score_mh'))
    return 0


if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
