
import logging
import dataset
import utils
import os
import torch
import numpy as np
import matplotlib
import time
import argparse
import json
import random
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

os.environ["CUDA_VISIBLE_DEVICES"]="1,0"

parser = argparse.ArgumentParser(description='Training ProxyNCA++')
parser.add_argument('--epochs', default = 40, type=int, dest = 'nb_epochs')
parser.add_argument('--log-filename', default = 'example')
parser.add_argument('--lr_steps', default=[1000], nargs='+', type=int)
parser.add_argument('--source_dir', default='', type=str)
parser.add_argument('--root_dir', default='', type=str)
parser.add_argument('--recall', default=[1, 2, 4, 8], nargs='+', type=int)
parser.add_argument('--init_eval', default=False, action='store_true')
parser.add_argument('--apex', default=False, action='store_true')
parser.add_argument('--warmup_k', default=5, type=int)

parser.add_argument('--dataset', default='inshop_noisy')
parser.add_argument('--seed', default=3, type=int)
parser.add_argument('--eval_nmi', default=True, action='store_true')
parser.add_argument('--embedding-size', default = 512, type=int, dest = 'sz_embedding')
parser.add_argument('--config', default='config/inshop_SoftTriple.json')
parser.add_argument('--mode', default='trainval', choices=['train', 'trainval',
                                                           'test', 'testontrain'],
                    help='train with train data or train with trainval')
parser.add_argument('--batch-size', default = 32, type=int, dest = 'sz_batch')
parser.add_argument('--no_warmup', default=False, action='store_true')
parser.add_argument('--mislabel_percentage', default=0.1, type=float)
parser.add_argument('--loss-type', default='SoftTriple_noisy_0.1', type=str)
parser.add_argument('--workers', default = 4, type=int, dest = 'nb_workers')

args = parser.parse_args()

def save_best_checkpoint(model):
    torch.save(model.state_dict(), 'results/' + args.log_filename + '.pt')

def load_best_checkpoint(model):
    try:
        model.load_state_dict(torch.load('results/' + args.log_filename + '.pt'))
    except FileNotFoundError:
        model.load_state_dict(torch.load('results/' + args.log_filename + '.pth'))
    model = model.cuda()
    return model

if __name__ == '__main__':

    # set random seed for all gpus
    random.seed(args.seed)
    np.random.seed(args.seed) # FIXME: to not set seed
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    os.makedirs('results', exist_ok=True); os.makedirs('log', exist_ok=True)

    curr_fn = os.path.basename(args.config).split(".")[0]
    config = utils.load_config(args.config)
    dataset_config = utils.load_config('dataset/config.json')

    if args.source_dir != '':
        bs_name = os.path.basename(dataset_config['dataset'][args.dataset]['source'])
        dataset_config['dataset'][args.dataset]['source'] = os.path.join(args.source_dir, bs_name)
    if args.root_dir != '':
        bs_name = os.path.basename(dataset_config['dataset'][args.dataset]['root'])
        dataset_config['dataset'][args.dataset]['root'] = os.path.join(args.root_dir, bs_name)

    #set NMI or recall accordingly depending on dataset. note for cub and cars R=1,2,4,8
    if (args.mode =='trainval' or args.mode == 'test' or args.mode == 'testontrain'):
        if args.dataset == 'sop' or args.dataset == 'sop_h5':
            args.recall = [1, 10, 100, 1000]
        args.eval_nmi = True

    args.nb_epochs = config['nb_epochs']
    args.sz_batch = config['sz_batch']
    args.sz_embedding = config['sz_embedding']
    if 'warmup_k' in config:
        args.warmup_k = config['warmup_k']

    transform_key = 'transform_parameters'
    if 'transform_key' in config.keys():
        transform_key = config['transform_key']
    print('Transformation: ', transform_key)

    out_results_fn = "log/%s_%s_%s_%d_%d_loss%s.json" % (args.dataset, curr_fn,
                                                         args.mode, args.sz_embedding,
                                                        args.seed,
                                                        args.loss_type)

    args.log_filename = '%s_%s_%s_%d_%d_loss%s' % (args.dataset, curr_fn,
                                                   args.mode, args.sz_embedding,
                                                   args.seed,
                                                   args.loss_type)

    if args.mode == 'test':
        args.log_filename = args.log_filename.replace('test', 'trainval')
    elif args.mode == 'testontrain':
        args.log_filename = args.log_filename.replace('testontrain', 'trainval')
    best_epoch = args.nb_epochs

    '''Dataloader'''
    if args.mode == 'trainval':
        train_results_fn = "log/%s_%s_%s_%d_%d_loss%s.json" % (args.dataset, curr_fn,
                                                               args.mode, args.sz_embedding,
                                                               args.seed,
                                                               args.loss_type)

        if os.path.exists(train_results_fn):
            with open(train_results_fn, 'r') as f:
                train_results = json.load(f)
            args.lr_steps = train_results['lr_steps']
            best_epoch = train_results['best_epoch']

    train_transform = dataset.utils.make_transform(
                **dataset_config[transform_key]
    )
    print('best_epoch', best_epoch)

    results = {}

    if ('inshop' not in args.dataset ):
        dl_ev = torch.utils.data.DataLoader(
            dataset.load(
                name = args.dataset.split('_noisy')[0],
                root = dataset_config['dataset'][args.dataset]['root'],
                source = dataset_config['dataset'][args.dataset]['source'],
                classes = dataset_config['dataset'][args.dataset]['classes']['eval'],
                transform = dataset.utils.make_transform(
                    **dataset_config[transform_key],
                    is_train = False
                ),
            ),
            batch_size = args.sz_batch,
            shuffle = False,
            num_workers = args.nb_workers,
        )
    else:
        #inshop trainval mode
        dl_query = torch.utils.data.DataLoader(
            dataset.load_inshop(
                name = args.dataset.split('_noisy')[0],
                root = dataset_config['dataset'][args.dataset]['root'],
                source = dataset_config['dataset'][args.dataset]['source'],
                classes = dataset_config['dataset'][args.dataset]['classes']['eval'],
                transform = dataset.utils.make_transform(
                    **dataset_config[transform_key],
                    is_train = False
                ),
                dset_type = 'query'
            ),
            batch_size = args.sz_batch,
            shuffle = False,
            num_workers = args.nb_workers,
            #pin_memory = True
        )
        dl_gallery = torch.utils.data.DataLoader(
            dataset.load_inshop(
                name = args.dataset.split('_noisy')[0],
                root = dataset_config['dataset'][args.dataset]['root'],
                source = dataset_config['dataset'][args.dataset]['source'],
                classes = dataset_config['dataset'][args.dataset]['classes']['eval'],
                transform = dataset.utils.make_transform(
                    **dataset_config[transform_key],
                    is_train = False
                ),
                dset_type = 'gallery'
            ),
            batch_size = args.sz_batch,
            shuffle = False,
            num_workers = args.nb_workers,
            #pin_memory = True
        )

    logging.basicConfig(
        format="%(asctime)s %(message)s",
        level=logging.INFO,
        handlers=[
            logging.FileHandler("{0}/{1}.log".format('log', args.log_filename)),
            logging.StreamHandler()
        ]
    )

    if 'inshop' not in args.dataset:
        tr_dataset = dataset.load_noisy(
                name = args.dataset,
                root = dataset_config['dataset'][args.dataset]['root'],
                source = dataset_config['dataset'][args.dataset]['source'],
                classes = dataset_config['dataset'][args.dataset]['classes']['trainval'],
                transform = train_transform,
                seed = args.seed,
                mislabel_percentage=args.mislabel_percentage
        )
    else:
        tr_dataset = dataset.load_noisy_inshop(
            name=args.dataset,
            root=dataset_config['dataset'][args.dataset]['root'],
            source=dataset_config['dataset'][args.dataset]['source'],
            classes=dataset_config['dataset'][args.dataset]['classes']['trainval'],
            transform=train_transform,
            seed=args.seed,
            dset_type='train',
            mislabel_percentage=args.mislabel_percentage
        )

    num_class_per_batch = config['num_class_per_batch']
    num_gradcum = config['num_gradcum']
    is_random_sampler = config['is_random_sampler']
    if is_random_sampler:
        batch_sampler = dataset.utils.RandomBatchSampler(tr_dataset.ys, args.sz_batch, True, num_class_per_batch, num_gradcum)
    else:

        batch_sampler = dataset.utils.BalancedBatchSampler(torch.Tensor(tr_dataset.ys), num_class_per_batch,
                                                           int(args.sz_batch / num_class_per_batch))


    dl_tr = torch.utils.data.DataLoader(
        tr_dataset,
        batch_sampler = batch_sampler,
        num_workers = args.nb_workers,
    )


    print("===")
    '''Model'''
    feat = config['model']['type']()
    feat.eval()
    in_sz = feat(torch.rand(1, 3, 256, 256)).squeeze().size(0)
    feat.train()
    emb = torch.nn.Linear(in_sz, args.sz_embedding)
    model = torch.nn.Sequential(feat, emb)
    model = torch.nn.DataParallel(model)
    model = model.cuda()

    '''Loss'''
    criterion = config['criterion']['type'](
        nb_classes = dl_tr.dataset.nb_classes(),
        sz_embed = args.sz_embedding,
        **config['criterion']['args']
    ).cuda()

    opt_warmup = config['opt']['type'](
        [
            {
                **{'params': list(feat.parameters()
                    )
                },
                'lr': 0
            },
            {
                **{'params': list(emb.parameters()
                    )
                },
                **config['opt']['args']['embedding']

            },

            {
                **{'params': criterion.proxies}
                ,
                **config['opt']['args']['proxynca']

            },
        ],
        **config['opt']['args']['base']
    )

    # options for model and loss
    opt = config['opt']['type'](
        [
            {
                **{'params': list(feat.parameters()
                    )
                },
                **config['opt']['args']['backbone']
            },
            {
                **{'params': list(emb.parameters()
                    )
                },
                **config['opt']['args']['embedding']
            },

            {
                **{'params': criterion.proxies},
                **config['opt']['args']['proxynca']
            },
        ],
        **config['opt']['args']['base']
    )

    if args.mode == 'test':
        with torch.no_grad():
            logging.info("**Evaluating...(test mode)**")
            model = load_best_checkpoint(model)
            if 'inshop' in args.dataset:
                utils.evaluate_inshop(model, dl_query, dl_gallery)
            else:
                utils.evaluate(model, dl_ev, args.eval_nmi, args.recall)
        exit() # exit the program

    elif args.mode == 'trainval':
        scheduler = config['lr_scheduler2']['type'](
            opt,
            milestones=args.lr_steps,
            gamma=0.1
            #opt, **config['lr_scheduler2']['args']
        )

    logging.info("Training parameters: {}".format(vars(args)))
    logging.info("Training for {} epochs.".format(args.nb_epochs))
    losses = []
    scores = []
    scores_tr = []

    t1 = time.time()

    if args.init_eval:
        logging.info("**Evaluating initial model...**")
        with torch.no_grad():
            c_dl = dl_ev
            utils.evaluate(model, c_dl, args.eval_nmi, args.recall) #dl_val

    it = 0
    best_val_hmean = 0
    best_val_nmi = 0
    best_val_epoch = 0
    best_val_r1 = 0
    best_val_mapr = 0
    best_test_nmi = 0
    best_test_r1 = 0
    best_test_r2 = 0
    best_test_r5 = 0
    best_test_r8 = 0
    best_tnmi = 0
    best_test_mapr = 0

    prev_lr = opt.param_groups[0]['lr']
    lr_steps = []

    logging.info('Number of training: {}'.format(len(dl_tr.dataset)))
    if 'inshop' not in args.dataset:
        logging.info('Number of testing: {}'.format(len(dl_ev.dataset)))
    else:
        logging.info('Number of query set: {}'.format(len(dl_query.dataset)))
        logging.info('Number of gallery set: {}'.format(len(dl_gallery.dataset)))

    '''Warmup training'''
    if not args.no_warmup:
        #warm up training for 5 epochs
        logging.info("**warm up for %d epochs.**" % args.warmup_k)
        for e in range(0, args.warmup_k):
            for ct, (x, y, _) in tqdm(enumerate(dl_tr)):
                opt_warmup.zero_grad()
                m = model(x.cuda())
                loss = criterion(m, None, y.cuda())
                loss.backward()
                torch.nn.utils.clip_grad_value_(model.parameters(), 10)
                opt_warmup.step()
            logging.info('warm up ends in %d epochs' % (args.warmup_k-e))


    '''training loop'''
    for e in range(0, args.nb_epochs):

        time_per_epoch_1 = time.time()
        losses_per_epoch = []
        tnmi = []

        for ct, (x, y, indices) in tqdm(enumerate(dl_tr)):
            it += 1
            x, y = x.cuda(), y.cuda()
            m = model(x)
            loss = criterion(m, indices, y)
            opt.zero_grad()
            loss.backward() # backprop
            torch.nn.utils.clip_grad_value_(model.parameters(), 10) # clip gradient?
            opt.step() # gradient descent

            losses_per_epoch.append(loss.data.cpu().numpy())

        time_per_epoch_2 = time.time()
        losses.append(np.mean(losses_per_epoch[-20:]))

        print('it: {}'.format(it))
        print(opt)
        logging.info(
            "Epoch: {}, loss: {:.3f}, time (seconds): {:.2f}.".format(
                e,
                losses[-1],
                time_per_epoch_2 - time_per_epoch_1
            )
        )

        model.losses = losses
        model.current_epoch = e


        if e == best_epoch:
            break

        if e == args.nb_epochs-1:
            save_dir = 'models/dvi_data_{}_{}_loss{}/ResNet_{}_Model'.format(args.dataset, args.seed,
                                                                      args.loss_type, str(args.sz_embedding))
            os.makedirs('{}'.format(save_dir), exist_ok=True)
            os.makedirs('{}/Epoch_{}'.format(save_dir, e+1), exist_ok=True)
            torch.save(model.state_dict(), '{}/Epoch_{}/{}_{}_{}_{}_{}.pth'.format(save_dir, e+1, args.dataset,
                                                                                   args.dataset, args.mode,
                                                                                   str(args.sz_embedding), str(args.seed)))
            torch.save({"proxies": criterion.proxies},
                       '{}/Epoch_{}/proxy.pth'.format(save_dir, e+1))
        ######################################################################################

        if args.mode == 'trainval':
            scheduler.step(e) # adjust learning rate

    if args.mode == 'trainval':
        save_best_checkpoint(model)

        with torch.no_grad():
            logging.info("**Evaluating...**")
            model = load_best_checkpoint(model)
            if 'inshop' in args.dataset:
                best_test_nmi, (best_test_r1, best_test_r10, best_test_r20, best_test_r30, best_test_r40, best_test_r50), best_mapr = utils.evaluate_inshop(model, dl_query, dl_gallery)
            else:
                best_test_nmi, (best_test_r1, best_test_r2, best_test_r4, best_test_r8), best_mapr = utils.evaluate(model, dl_ev, args.eval_nmi, args.recall)
            #logging.info('Best test r8: %s', str(best_test_r8))
        if 'inshop' in args.dataset:
            results['NMI'] = best_test_nmi
            results['R1']  = best_test_r1
            results['R10'] = best_test_r10
            results['R20'] = best_test_r20
            results['R30'] = best_test_r30
            results['R40'] = best_test_r40
            results['R50'] = best_test_r50
            results['MAP@R'] = best_test_mapr

        else:
            results['NMI'] = best_test_nmi
            results['R1'] = best_test_r1
            results['R2'] = best_test_r2
            results['R4'] = best_test_r4
            results['R8'] = best_test_r8
            results['MAP@R'] = best_test_mapr

    if args.mode == 'train':
        print('lr_steps', lr_steps)
        results['lr_steps'] = lr_steps

    with open(out_results_fn,'w') as outfile:
        json.dump(results, outfile)

    t2 = time.time()
    logging.info("Total training time (minutes): {:.2f}.".format((t2 - t1) / 60))
