import random
import time
import argparse
import shutil
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import os.path as osp
import pickle
import copy

import tqdm
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision

from utils import utils
from utils import tsne
from utils.meter import AverageMeter, ProgressMeter
from utils.logger import CompleteLogger

from models.VPT import VisualPromptTuningCLIP # for transformer

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Module-level helpers shared by the functions in this script.
# NOTE(review): _tokenizer and _tf_toTensor are not referenced by any code in
# this file; they may be leftovers — confirm before removing.
_tokenizer = _Tokenizer()
_tf_toTensor = torchvision.transforms.ToTensor() 

# Target device for every explicit `.to(device)` call in this script; falls
# back to the CPU when CUDA is not available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): anomaly detection is a debugging aid that PyTorch documents as
# slowing down execution; it is enabled globally here, so every training
# backward pass pays for it. Consider turning it off for real runs.
torch.autograd.set_detect_anomaly(True)


def _build_loaders(args: argparse.Namespace):
    """Build the train/val ``DataLoader`` pair and set ``args.class_names``.

    The fourth value returned by ``utils.get_dataset`` is stored on ``args``
    because the model constructor reads the class names from there.

    Returns:
        ``(train_loader, val_loader)``.
    """
    # contrastive learning sampler
    # Data loading code
    train_neg_transform = utils.get_negative_transform()
    train_pos_transform = utils.get_positive_transform()
    val_transform = utils.get_val_transform()
    # Validation "target" view: only hue jitter is applied, and the (0.4, 0.4)
    # range makes it a fixed hue shift of 0.4.
    val_target_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.ColorJitter(hue=(0.4, 0.4)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=args.norm_mean, std=args.norm_std),
        ]
    )

    print("train_pos_transform: ", train_pos_transform)
    print("train_neg_transform: ", train_neg_transform)
    print("val_transform: ", val_transform)
    print("val_target_transform: ", val_target_transform)

    # The third returned value (number of classes) is not used by this script.
    train_dataset, val_dataset, _, args.class_names = utils.get_dataset(
        args.data, args.root, args.source, args.target,
        train_neg_transform, val_transform, train_pos_transform, val_target_transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.workers, drop_last=False)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                            shuffle=False, num_workers=args.workers, drop_last=False)
    return train_loader, val_loader


def _build_model(args: argparse.Namespace) -> nn.Module:
    """Create the CLIP backbone and wrap it in the prompt-tuning model on ``device``.

    Raises:
        NotImplementedError: if ``args.model_cls`` is not ``"VPTCLIP"``.
    """
    print("=> using clip_model model '{}'".format(args.arch))
    clip_model = utils.get_model(args.arch, pretrain=not args.scratch)

    # Probe the image-embedding width with one dummy 224x224 forward pass.
    # Only the output shape is needed, so skip building an autograd graph.
    with torch.no_grad():
        dummy_input = torch.randn([1, 3, 224, 224]).to(device)
        feat_dim = clip_model.encode_image(dummy_input).size(-1)

    if args.model_cls != "VPTCLIP":
        raise NotImplementedError(
            f"--model-cls {args.model_cls!r} is not supported; only 'VPTCLIP' is implemented")
    return VisualPromptTuningCLIP(clip_model,
                                  args.class_names,
                                  feat_dim, device,
                                  clip_model_type=args.arch, DeepPrompt=args.deep, n_vtk=8).to(device)


def _build_optimizer(model: nn.Module, args: argparse.Namespace):
    """Create the optimizer and the per-iteration LR scheduler.

    Resumed runs use Adam with fixed hyper-parameters; fresh runs use Nesterov
    SGD configured from the command line.

    Returns:
        ``(optimizer, lr_scheduler)``.
    """
    if args.resume:
        # NOTE(review): the literal here used to be 5e-54 (effectively zero),
        # presumably a typo for 5e-5 — confirm. If model.get_parameters()
        # returns param groups that set their own "lr", this default is unused.
        optimizer = Adam(model.get_parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)
    else:
        optimizer = SGD(model.get_parameters(), args.lr,
                        momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    # LambdaLR multiplies each param group's initial lr by this factor, so the
    # effective lr is group_lr * args.lr * (1 + gamma * step) ** -decay.
    # NOTE(review): that applies args.lr twice unless get_parameters() gives the
    # groups a relative lr (e.g. 1.0) — confirm against models.VPT.
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    return optimizer, lr_scheduler


def main(args: argparse.Namespace):
    """Train the prompt-tuned CLIP model, or only inspect it with ``--phase analysis``.

    In training, ``train_contrastive`` runs for ``args.epochs`` epochs. The
    latest weights are saved after every epoch, a t-SNE plot is written every
    ``args.print_freq`` epochs, and a final plot follows training. With
    ``--phase analysis``, the checkpoint at ``args.ckpt_path`` is loaded and
    plotted, and the function returns without training.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    utils.seed_fix(args.seed)

    train_loader, val_loader = _build_loaders(args)
    model = _build_model(args)

    # map_location lets checkpoints saved on CUDA load on a CPU-only host.
    if args.resume:
        model.load_state_dict(torch.load(args.resume_path, map_location=device))

    if args.phase == "analysis":
        # Analysis must not train or overwrite checkpoints: load, plot, stop.
        model.load_state_dict(torch.load(args.ckpt_path, map_location=device))
        # latent space analysis
        visualize(model, args, logger, post_fix='_' + args.phase, sampler=val_loader)
        return
    # NOTE(review): "--phase test" has no dedicated branch and currently trains
    # exactly like "train".

    # train the prompt with multi domain data
    train_func = train_contrastive
    optimizer, lr_scheduler = _build_optimizer(model, args)

    # start training
    for epoch in range(args.epochs):
        print(lr_scheduler.get_last_lr())

        # train for one epoch
        train_func(train_loader, model, optimizer, lr_scheduler, epoch, args)

        # save checkpoint (the same file is overwritten every epoch)
        torch.save(model.state_dict(), logger.get_checkpoint_path(args.objective + "_" + '_latest'))

        # evaluate: latent space analysis every print_freq epochs
        if epoch % args.print_freq == 0:
            visualize(model, args, logger, post_fix='_' + str(epoch), sampler=val_loader)
        torch.cuda.empty_cache()

    print("Pretraining process is done.")

    print("***Last model evaluation")
    visualize(model, args, logger, post_fix="_latest", sampler=val_loader)


def _contrastive_losses(f_query: torch.Tensor, f_key: torch.Tensor, logit_scale: torch.Tensor):
    """Return ``(vis_loss, mse_loss)`` for a batch of paired embeddings.

    Row ``i`` of ``f_query`` and row ``i`` of ``f_key`` are the positive pair;
    every other row in the batch serves as a negative. ``vis_loss`` is the
    symmetric (query->key and key->query) cross-entropy over the similarity
    matrix scaled by ``exp(logit_scale)``. ``mse_loss`` is an auxiliary term
    that pulls paired embeddings together.
    """
    # auxiliary
    mse_loss = F.mse_loss(f_query, f_key)

    # NOTE(review): CLIP-style logits assume L2-normalised features; confirm
    # the model output is normalised, otherwise logits scale with feature norm.
    logits_per_image_q = logit_scale.exp() * f_query @ f_key.t()
    logits_per_image_k = logits_per_image_q.t()

    # The matching pair for row i sits on the diagonal.
    labels = torch.arange(f_query.size(0), dtype=torch.long, device=f_query.device)
    vis_loss = (F.cross_entropy(logits_per_image_q, labels) + F.cross_entropy(logits_per_image_k, labels)) / 2
    return vis_loss, mse_loss


def train_contrastive(train_loader, model,
                      optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of contrastive prompt tuning.

    Each batch supplies two views, ``batch[0]`` (query) and ``batch[1]`` (key),
    whose rows are treated as positive pairs. Both views are encoded in one
    forward pass, and the contrastive and MSE losses are summed. The optimizer
    and the LR scheduler each step once per batch. Progress is printed every
    ``args.print_freq`` iterations.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    total_losses = AverageMeter('Total Loss', ':6.6f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, total_losses],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, batch in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        query = batch[0].to(device)
        key = batch[1].to(device)

        # Encode both views in a single forward pass, then split them back.
        pos_image_features = model(torch.cat((query, key), dim=0))
        f_query, f_key = pos_image_features.chunk(2, dim=0)

        vis_loss, mse_loss = _contrastive_losses(f_query, f_key, model.logit_scale)
        total_loss = vis_loss + mse_loss

        # Compute gradient and do an optimizer step; the scheduler is per-iteration.
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_losses.update(total_loss.item(), f_query.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


def visualize(model, args, logger, post_fix, sampler=None, max_batches: int = 21):
    """Embed paired validation views and save a t-SNE plot comparing them.

    Args:
        model: callable mapping an image batch to embeddings; put in eval mode.
        args: namespace providing ``objective``, used in the output filename.
        logger: provides ``visualize_directory``, where the PDF is written.
        post_fix: suffix inserted into the filename (e.g. ``'_3'`` for epoch 3).
        sampler: iterable of batches; ``batch[0]`` is the source view and
            ``batch[1]`` is the augmented (target) view.
        max_batches: maximum number of batches to embed (default 21).

    Raises:
        ValueError: if ``sampler`` is None, ``max_batches`` < 1, or the sampler
            yields no batches.
    """
    if sampler is None:
        raise ValueError("visualize() needs a batch iterable via `sampler`")
    if max_batches < 1:
        raise ValueError(f"max_batches must be >= 1, got {max_batches}")

    print("Val Plot")
    model.eval()

    source_feature = []
    target_feature = []
    # Inference only: without no_grad an autograd graph would be built and kept
    # alive for every validation batch.
    with torch.no_grad():
        for i, batch in enumerate(sampler):
            # source view
            source_feature.append(model(batch[0].to(device)).detach().cpu())
            # target (augmented) view
            target_feature.append(model(batch[1].to(device)).detach().cpu())
            # Stop here, rather than at the top of the loop, so no extra batch is loaded.
            if i + 1 >= max_batches:
                break

    if not source_feature:
        raise ValueError("sampler yielded no batches; nothing to visualize")
    source_feature = torch.cat(source_feature, dim=0)
    target_feature = torch.cat(target_feature, dim=0)

    # plot t-SNE
    tSNE_filename = osp.join(logger.visualize_directory, args.objective + post_fix + '_val_image-pair_TSNE.pdf')
    tsne.visualize(source_feature, target_feature, tSNE_filename)
    print("Saving t-SNE to", tSNE_filename)


if __name__ == '__main__':
    # Command-line entry point: parse the options below and run main().
    parser = argparse.ArgumentParser(description='Contrastive Learning for MMRL')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    # NOTE(review): argparse does not check defaults against `choices`, so the
    # empty-string defaults of --data / --arch / --model-cls pass parsing and
    # presumably fail later; consider making these options required.
    parser.add_argument('-d', '--data', metavar='DATA', default='', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: )')
    parser.add_argument('-s', '--source', help='source dataset', nargs='+')
    parser.add_argument('-t', '--target', help='target dataset', nargs='+')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[1., 1.0], metavar='PCT',
                        help='Random resize scale (default: 1.0 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[1., 1.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 1.0 1.0)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: )')
    parser.add_argument('--objective', type=str, default=None,
                        help='learning objectives')
    parser.add_argument('--contrastive-task', type=str, default=None,
                        help='task type')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    # training parameters
    parser.add_argument('--state-batch-size', default=16, type=int,
                        metavar='N',
                        help='state sample size')
    parser.add_argument('-b', '--batch-size', default=16, type=int,
                        metavar='N',
                        help='mini-batch size (default: 16)')
    parser.add_argument('--n-size', default=16, type=int,
                        metavar='N',
                        help='dynamics negative sample size')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate of the classifier', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=100, type=int, metavar='N',
                        help='number of total epochs to run')
    # NOTE(review): --iters-per-epoch is not read anywhere in this file; each
    # epoch iterates the whole training loader.
    parser.add_argument('-i', '--iters-per-epoch', default=100, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=777, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analyze the model.")
    parser.add_argument("--ckpt-path", type=str, default='',
                        help="")
    parser.add_argument("--den-ckpt-path", type=str, default='',
                        help="")
    parser.add_argument("--model-cls", type=str, default='',
                        choices=["VPCLIP", "VPTCLIP", "MVPTCLIP", "DPCLIP", "DPTCLIP", "MDPTCLIP"])
    parser.add_argument("--deep", action='store_true')
    parser.add_argument("--resume", action='store_true')
    parser.add_argument("--resume-path", type=str, default='',
                        help="")
    parser.add_argument("--pretrain-prompt", type=str, default='',
                        help="")
    args = parser.parse_args()
    main(args)
