import sys
import os
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import argparse

from criterion_pt import train_zmce
from DG_wrapper import *


def get_params(args=None):
    """Parse the command-line options for this training script.

    Args:
        args: Optional list of argument strings. When None, argparse falls
            back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    # (flags, add_argument keyword options), in declaration order.
    option_table = (
        (('--dataset',), dict(required=True)),
        (('--data_f',), dict(required=True)),
        (('-w', '--workers'), dict(type=int, default=40)),
        (('-b', '--batch_size'), dict(type=int, default=1)),
        (('--dim_z',), dict(type=int, default=100)),
        (('-e', '--epoch_n'), dict(type=int, default=2)),
        (('-v', '--verbose'), dict(type=int, default=100)),
        (('--image_size',), dict(type=int, default=64)),
        (('-l', '--lr'), dict(type=float, default=2e-4)),
        (('--lr_m',), dict(type=str)),
        (('--device',), dict(required=True)),
        (('--gan',), dict(required=True)),
        (('--G_pf',), dict(default=None)),
        (('--time',), dict(type=float, required=True)),
        # original note: "mce UB" -- presumably the accepted criterion names; confirm
        (('--crit',), dict(type=str, required=True)),
        (('--name',), dict(type=str, required=True)),
        (('--relative',), dict(type=str, default=None)),
    )

    cli = argparse.ArgumentParser()
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    return cli.parse_args(args=args)


def main():
    """Build the networks and data loader from CLI options, then run ``train_zmce``.

    Steps: parse options (``get_params``); select the ``Generator`` class from
    ``--gan``; load pretrained generator weights from ``--G_pf`` and wrap them
    in ``MarkovWrap`` with an ``HmcTranstion(--time)`` transition; build a
    discriminator wrapper; load the ``ZDataset`` from ``--data_f``; save the
    parsed options into the output directory; and call ``train_zmce``.

    Returns:
        str: the output directory (also passed to ``train_zmce`` as ``path_f``).

    Raises:
        ValueError: if ``--G_pf`` was not supplied.
        NotImplementedError: if ``--gan`` is not one of the supported names.
    """
    # Lets the `gan.*` generator modules imported below be found relative to
    # the current working directory.
    sys.path.append(".")
    cudnn.benchmark = True

    params = get_params()

    # --lr_m is an optional comma-separated list of ints (e.g. "10,20"),
    # forwarded to train_zmce as ``ml_ls``. It has no argparse default, so a
    # missing/empty value becomes an empty list instead of crashing on
    # None.split(). Blank items from stray commas are ignored.
    # NOTE(review): assumes train_zmce accepts an empty ml_ls -- confirm.
    lr_decay = [int(item) for item in (params.lr_m or '').split(',') if item.strip()]

    # --G_pf defaults to None, but the generator weights are mandatory: fail
    # early with a clear message rather than inside torch.load after the
    # dataset has already been built.
    if params.G_pf is None:
        raise ValueError('--G_pf is required: path to the pretrained generator state_dict')

    if params.gan == 'dcgan_bn':
        from gan.dcgan_bn import Generator
    elif params.gan == 'dcgan_isn':
        from gan.dcgan import Generator
    elif params.gan == 'wpgan':
        from gan.wgan_wp import Generator
    else:
        raise NotImplementedError(
            f"unsupported --gan {params.gan!r}; expected 'dcgan_bn', 'dcgan_isn' or 'wpgan'")

    def weights_init(net, std=0.02):
        """Initialise a module in place (intended for ``Module.apply``).

        Conv* layers: weight ~ N(0, std). BatchNorm* layers: weight ~ N(1, std),
        bias = 0. Other modules are left untouched.
        """
        classname = type(net).__name__
        if 'Conv' in classname:
            nn.init.normal_(net.weight.data, 0.0, std)
        elif 'BatchNorm' in classname and net.weight is not None:
            # affine=False BatchNorm layers have weight/bias set to None.
            nn.init.normal_(net.weight.data, 1.0, std)
            nn.init.constant_(net.bias.data, 0)

    def get_nets(G_pf, time, device, relative):
        """Build the wrapped generator and the discriminator wrapper.

        Args:
            G_pf: path to the generator's state_dict checkpoint (loaded on CPU).
            time: forwarded to ``HmcTranstion``.
            device: forwarded to ``MarkovWrap``.
            relative: None selects ``Discriminator(nc=6)`` (with ``weights_init``
                applied) wrapped in ``DWrapper``; any other value selects
                ``LogitDiscriminator`` wrapped in ``DWrapperRelativeMCE``.

        Returns:
            tuple: ``(MG, W)``, the ``MarkovWrap`` generator and the
            discriminator wrapper.
        """
        G = Generator()
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        G.load_state_dict(torch.load(G_pf, map_location='cpu'))
        G.eval()
        trans = HmcTranstion(time)
        MG = MarkovWrap(G, trans, device)

        if relative is None:
            D = Discriminator(nc=6)
            D.apply(weights_init)
            W = DWrapper(D)
        else:
            D = LogitDiscriminator()
            W = DWrapperRelativeMCE(D)

        return MG, W

    # NOTE(review): magic per-dataset constant passed to ZDataset (323 for
    # 'folder', 196 otherwise); presumably the number of stored z batches -- confirm.
    max_batch = 323 if params.dataset == 'folder' else 196
    z_matched_dataset = ZDataset(params.data_f, max_batch)
    # batch_size/num_workers are fixed here; --batch_size and --workers are not used.
    z_matched_loader = torch.utils.data.DataLoader(
        z_matched_dataset, batch_size=1, shuffle=False, num_workers=1)
    MG, W = get_nets(G_pf=params.G_pf, time=params.time,
                     device=params.device, relative=params.relative)

    d_type = 'cat' if params.relative is None else 'rel'

    # Output directory name: dataset, GAN type, discriminator type, run name.
    path_f = f'gG_MG_tD{params.dataset}_{params.gan}_{d_type}_{params.name}'
    os.makedirs(path_f, exist_ok=True)

    torch.save(params, os.path.join(path_f, '_params_dict'))

    train_zmce(W, MG, params.crit, z_matched_loader,
               lr=params.lr, epoch_n=params.epoch_n, device=params.device,
               verbose=params.verbose, ml_ls=lr_decay, path_f=path_f)
    return path_f

if __name__ == '__main__':
    # Script entry point. main() returns the output directory path, which is
    # not needed here.
    main()


