import random
import time
import argparse
import shutil
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import os.path as osp
import pickle
import copy
import pprint

import tqdm
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision

from utils import utils
from utils import tsne
from utils.meter import AverageMeter, ProgressMeter
from utils.logger import CompleteLogger

from models.VPT import VisualPromptTuningCLIP # for transformer

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Shared tokenizer and PIL -> tensor transform instances.
_tokenizer = _Tokenizer()
_tf_toTensor = torchvision.transforms.ToTensor()

# Use the GPU when one is visible; the script moves models and batches here via .to(device).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# NOTE(review): anomaly detection is a debugging aid that slows every backward
# pass noticeably; consider turning it off for full training runs.
torch.autograd.set_detect_anomaly(True)


class EMA:
    """Exponential-moving-average helper used to update the BYOL target network."""

    def __init__(self, beta):
        super().__init__()
        # Weight kept on the previous average; (1 - beta) goes to the new value.
        self.beta = beta

    def update_average(self, old, new):
        """Blend ``new`` into the running average ``old``.

        ``old is None`` means there is no history yet, so ``new`` is returned as is.
        """
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

def update_moving_average(ema_updater, ma_model, current_model):
    """Move every parameter of ``ma_model`` towards its counterpart in ``current_model``.

    Parameters are paired positionally (``zip`` stops at the shorter sequence) and
    blended with ``ema_updater.update_average(old, new)``. Only ``parameters()`` are
    touched; buffers (e.g. BatchNorm running statistics) are left unchanged.
    """
    paired = zip(ma_model.parameters(), current_model.parameters())
    for averaged, live in paired:
        averaged.data = ema_updater.update_average(averaged.data, live.data)

def MLP(dim, projection_size, hidden_size=4096):
    """Build a two-layer head: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        dim: input feature width.
        projection_size: output width.
        hidden_size: width of the hidden layer.
    """
    layers = [
        nn.Linear(dim, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size),
    ]
    return nn.Sequential(*layers)

def set_requires_grad(model, val):
    """Set ``requires_grad`` to ``val`` on every parameter of ``model`` (e.g. to freeze it)."""
    for param in model.parameters():
        param.requires_grad_(val)

def loss_fn(x, y):
    """BYOL regression loss ``2 - 2 * cos(x, y)`` computed along the last dimension.

    Both inputs are L2-normalised first, so the result lies in [0, 4] and has one
    value per leading index (no reduction).
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    cosine = (x_unit * y_unit).sum(dim=-1)
    return 2 - 2 * cosine

class NetWrapper(nn.Module):
    """Attach an MLP projector on top of a backbone network.

    ``forward`` returns ``(projection, representation)``: the representation is the
    raw output of ``net`` and the projection is ``projector(representation)``.
    """

    def __init__(self, net, dim, projection_size, projection_hidden_size):
        super().__init__()
        self.net = net
        self.dim = dim
        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size
        self.projector = self._get_projector()

    def get_parameters(self, optimize_head=False, base_lr=1.0):
        """Return optimizer parameter groups.

        A single group holds ``net.visual_backbone`` and projector parameters with
        ``lr = base_lr``. ``optimize_head`` is accepted but currently ignored.
        """
        trainable = [*self.net.visual_backbone.parameters(), *self.projector.parameters()]
        return [{"params": trainable, "lr": 1.0 * base_lr}]

    def _get_projector(self):
        """Build the projector MLP from the stored sizes."""
        return MLP(self.dim, self.projection_size, self.projection_hidden_size)

    def forward(self, x):
        representation = self.net(x)
        return self.projector(representation), representation


class Sampler:
    """Build BYOL positive pairs from a pickled multi-MDP trajectory dataset.

    The pickle at ``args.root`` is a dict holding ``"goal_types"``, ``"classes"``,
    ``"actions"`` and one entry per MDP that maps episode -> {"frame", "action",
    "reward"} lists. Frames are regrouped by (episode, temporal segment, action);
    a positive pair is two frames of the same group, the key view colour-jittered.

    Args:
        args: must provide ``root`` (dataset path) and ``batch_size``.
        model, device: stored as attributes; not used by this class.
    """

    # Top-level pickle keys that hold metadata rather than per-MDP trajectories.
    _META_KEYS = ("goal_types", "classes", "actions")

    def __init__(self, args, model=None, device=None):
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # args.root at trusted dataset files.
        with open(args.root, 'rb') as f:
            data = pickle.load(f)
        self.traj = data
        self.args = args
        self.goal_types = data["goal_types"]
        self.classes = data["classes"]
        self.actions = data["actions"]
        self.model = model
        self.device = device
        self.to_tensor = torchvision.transforms.Compose([
            torchvision.transforms.Resize((224,224)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self.to_tensor_aug = torchvision.transforms.Compose([
            torchvision.transforms.Resize((224,224)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.ColorJitter(
                brightness=(0.2, 2),
                contrast=(0.9, 1.5),
                saturation=(1.5, 2),
                hue=(-0.4, 0.4)
            ),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self._preprocess_dataset()

    def sample_positive_pairs(self):
        """Draw ``args.batch_size`` positive pairs.

        Both views of a pair come from the same (episode, tempo, action) group, with
        independently drawn frame indices; only the key view is augmented.

        Returns:
            (query, key): transformed frames stacked along a new batch dimension.
        """
        counts = self.traj_temp_action_timestpes
        batch = self.args.batch_size
        self._episode_index = np.random.randint(0, len(counts), size=(batch))
        self._tempo_index = np.random.randint(
            0, [len(counts[i]) for i in self._episode_index], size=(batch))
        self._actions_index = np.random.randint(
            0, [len(counts[i][j]) for i, j in zip(self._episode_index, self._tempo_index)],
            size=(batch))

        self._timestep_index = self._draw_frame_indices()
        query = self._gather_frames(augmentation=False)

        self._timestep_index = self._draw_frame_indices()
        key = self._gather_frames(augmentation=True)

        return query, key

    def _draw_frame_indices(self):
        """Uniformly pick one frame index inside each currently selected group."""
        sizes = [self.traj_temp_action_timestpes[i][j][k]
                 for i, j, k in zip(self._episode_index, self._tempo_index, self._actions_index)]
        return np.random.randint(0, sizes, size=(self.args.batch_size))

    def _gather_frames(self, augmentation):
        """Look up and transform the frames addressed by the ``_*_index`` arrays."""
        episode_keys = list(self.traj_temp_action.keys())
        frames = []
        for b in range(self.args.batch_size):
            tempos = self.traj_temp_action[episode_keys[self._episode_index[b]]]
            actions = tempos[list(tempos.keys())[self._tempo_index[b]]]
            group = actions[list(actions.keys())[self._actions_index[b]]]
            frames.append(self._to_tensor(group[self._timestep_index[b]], augmentation))
        return torch.stack(frames)

    def action_contrastive_loss(self, q, k, neg, temp):
        """InfoNCE loss with one positive key and N negatives per query.

        Args:
            q, k: (b, d) query / positive-key features.
            neg: (b, N, d) negative features.
            temp: softmax temperature dividing the logits.
        """
        N = neg.shape[1]
        b = q.shape[0]
        l_pos = torch.bmm(q.view(b, 1, -1), k.view(b, -1, 1)) # (b,1,1)
        l_neg = torch.bmm(q.view(b, 1, -1), neg.transpose(1,2)) # (b,1,N)
        logits = torch.cat([l_pos.view(b, 1), l_neg.view(b, N)], dim=1)

        # The positive is always column 0; build labels on the logits' own device
        # instead of relying on a module-level device.
        labels = torch.zeros(b, dtype=torch.long, device=logits.device)
        cross_entropy_loss = nn.CrossEntropyLoss()
        return cross_entropy_loss(logits/temp, labels)

    def _preprocess_dataset(self,):
        """Convert the raw dataset and build all sampling / visualisation indices.

        Sets ``mdps``, ``episodes``, ``timesteps``, ``traj_temp_action``,
        ``traj_temp_action_timestpes`` and ``traj_vis``; drops never-seen entries
        from ``actions``; frees ``traj`` afterwards.
        """
        self.mdps = [key for key in self.traj.keys() if key not in self._META_KEYS]
        self.episodes = [list(self.traj[mdp].keys()) for mdp in self.mdps]
        self.timesteps = [
            [len(self.traj[mdp][episode]["frame"]) for episode in episodes]
            for mdp, episodes in zip(self.mdps, self.episodes)
        ]

        self._convert_records()
        self._group_by_tempo_and_action()
        self._build_visualization_set()
        self._prune_empty_groups()

        del self.traj
        pprint.pprint(self.traj_temp_action_timestpes)

    def _convert_records(self):
        """In place: frames -> PIL images, actions and rewards -> tensors."""
        for mdp in self.mdps:
            for episode, record in self.traj[mdp].items():
                frames = record["frame"]
                print(mdp, episode, len(frames))
                for i in range(len(frames)):
                    frames[i] = Image.fromarray(frames[i])
                    record["action"][i] = torch.tensor(record["action"][i])
                    record["reward"][i] = torch.tensor(record["reward"][i])

    def _group_by_tempo_and_action(self):
        """Bucket every frame into ``traj_temp_action[episode][tempo][action]``.

        Frames of the same episode id are pooled across MDPs. Each episode is cut
        into segments of ``len // base_tempo`` frames; the dict pre-creates
        ``base_tempo + 1`` tempo slots per episode.
        """
        base_tempo = 1
        self.traj_temp_action = {}
        # Create slots for the episodes of *every* MDP (first-seen order) so that
        # episodes missing from the first MDP do not raise KeyError below.
        for episodes in self.episodes:
            for episode in episodes:
                if episode not in self.traj_temp_action:
                    self.traj_temp_action[episode] = {
                        tempo: {action: [] for action in self.actions}
                        for tempo in range(base_tempo + 1)
                    }

        for mdp in self.mdps:
            for episode, record in self.traj[mdp].items():
                frames = record["frame"]
                timestep = len(frames)
                if timestep == 0:
                    continue  # nothing to bucket (and a zero range step would raise)
                step = timestep // base_tempo
                tempo_group = list(range(step, timestep + step, step))
                for i, frame in enumerate(frames):
                    # Segment = first boundary beyond i, or the last one if none is.
                    j = next((n for n, bound in enumerate(tempo_group) if i < bound),
                             len(tempo_group) - 1)
                    action = self.actions[record["action"][i]]
                    self.traj_temp_action[episode][j][action].append(frame)

    def _build_visualization_set(self):
        """Pick two MDPs at random and keep tensors only for the episodes they share."""
        vis_mdps = random.sample(self.mdps, 2)
        shared = set(self.traj[vis_mdps[0]].keys()) & set(self.traj[vis_mdps[1]].keys())
        self.traj_vis = {
            mdp: {
                episode: [self._to_tensor(frame, False) for frame in record["frame"]]
                for episode, record in self.traj[mdp].items()
                if episode in shared
            }
            for mdp in vis_mdps
        }

    def _prune_empty_groups(self):
        """Drop empty action/tempo/episode groups and record per-group frame counts.

        ``traj_temp_action_timestpes[e][t][a]`` is the frame count of the a-th
        remaining action of the t-th remaining tempo of the e-th remaining episode,
        i.e. it is indexed positionally in the same order as ``traj_temp_action``.
        Actions that never occur are removed from ``self.actions``.
        """
        self.traj_temp_action_timestpes = []
        unseen_actions = self.actions.copy()

        for tempos in self.traj_temp_action.values():
            episode_counts = []
            for groups in tempos.values():
                tempo_counts = []
                for action in self.actions:
                    if not groups[action]:
                        del groups[action]
                        continue
                    tempo_counts.append(len(groups[action]))
                    if action in unseen_actions:
                        unseen_actions.remove(action)
                episode_counts.append(tempo_counts)
            self.traj_temp_action_timestpes.append(episode_counts)

        # Delete back-to-front so earlier positional indices stay valid.
        episode_keys = list(self.traj_temp_action.keys())
        for e in reversed(range(len(episode_keys))):
            tempos = self.traj_temp_action[episode_keys[e]]
            tempo_keys = list(tempos.keys())
            counts = self.traj_temp_action_timestpes[e]
            for t in reversed(range(len(counts))):
                if not counts[t]:
                    del tempos[tempo_keys[t]]
                    del counts[t]
            if not counts:
                # An episode with no frames at all could never be sampled.
                del self.traj_temp_action[episode_keys[e]]
                del self.traj_temp_action_timestpes[e]

        if len(unseen_actions):
            print(unseen_actions)
            for action in unseen_actions:
                self.actions.remove(action)

    def _to_tensor(self, x, augmentation=False):
        """Apply the plain or the colour-jitter transform pipeline to ``x``."""
        transform = self.to_tensor_aug if augmentation else self.to_tensor
        return transform(x)


def main(args: argparse.Namespace):
    """Build the sampler and BYOL model, then pretrain (or only analyse) it.

    With ``--phase analysis`` the checkpoint at ``args.ckpt_path`` is loaded and
    visualised without any training. Otherwise the online network is trained for
    ``args.epochs`` epochs, checkpointed every epoch and visualised every
    ``args.print_freq`` epochs and once more at the end.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    utils.seed_fix(args.seed)

    # contrastive learning sampler
    sampler = Sampler(args)

    # create model
    print("=> using clip_model model '{}'".format(args.arch))
    clip_model = utils.get_model(args.arch, pretrain = True)

    # Probe the image-feature width once; no autograd graph is needed for this.
    with torch.no_grad():
        dummy_input = torch.randn([1, 3, 224, 224]).to(device)
        feat_dim = clip_model.encode_image(dummy_input).size(-1)

    if args.model_cls == "VPTCLIP":
        class_list = sampler.classes
        model = VisualPromptTuningCLIP(clip_model,
                            class_list,
                            feat_dim, device,
                            clip_model_type=args.arch, DeepPrompt=args.deep, n_vtk=8).to(device)
    else:
        raise NotImplementedError

    print("BYOL version")
    projection_hidden_size = 4096
    moving_average_decay = 0.99
    projection_size = 16
    model = NetWrapper(model, dim=feat_dim, projection_size=projection_size,
                       projection_hidden_size=projection_hidden_size).to(device)

    if args.phase == "analysis":
        # Analysis only: evaluate a saved checkpoint, no training.
        model.load_state_dict(torch.load(args.ckpt_path))
        # latent space analysis
        visualize(model, args, logger, post_fix='_'+args.phase, sampler=sampler)
        return

    # Load resumed weights *before* copying, so the target starts from them too.
    if args.resume:
        model.load_state_dict(torch.load(args.resume_path))

    # BYOL target: frozen copy of the online network, afterwards moved only by EMA.
    target_encoder = copy.deepcopy(model)
    set_requires_grad(target_encoder, False)
    online_predictor = MLP(projection_size, projection_size, 512).to(device)
    target_ema_updater = EMA(moving_average_decay)

    # train the prompt with multi domain data
    train_func = train_comparative

    # The predictor must be optimised too, otherwise it stays at its random init.
    # Its relative lr of 1.0 matches NetWrapper.get_parameters()'s default base_lr.
    param_groups = model.get_parameters() + [
        {"params": list(online_predictor.parameters()), "lr": 1.0},
    ]

    # define optimizer and lr scheduler
    if args.resume:
        # NOTE(review): every param group sets its own "lr", which overrides the
        # lr=5e-5 default passed here; the effective lr is group lr * scheduler factor.
        optimizer = Adam(param_groups, lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)
    else:
        optimizer = SGD(param_groups, args.lr,
                        momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    # LambdaLR scales each group's initial lr (1.0) by this factor, giving an
    # effective lr of args.lr * (1 + lr_gamma * step) ** -lr_decay.
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # start training
    for epoch in range(args.epochs):
        print(lr_scheduler.get_last_lr())

        # train for one epoch
        train_func(sampler, model, target_encoder, online_predictor, target_ema_updater,
                   optimizer, lr_scheduler, epoch, args)

        # save checkpoint
        torch.save(model.state_dict(), logger.get_checkpoint_path(args.objective +"_"+ args.contrastive_task + '_latest'))

        # evaluate
        if epoch % args.print_freq == 0:
            # latent space analysis
            visualize(model, args, logger, post_fix='_'+str(epoch), sampler=sampler)

        torch.cuda.empty_cache()

    print("Pretraining process is done.")

    print("***Last model visualization")
    visualize(model, args, logger, post_fix="_latest", sampler=sampler)


def train_comparative(sampler, model, target_encoder,online_predictor, target_ema_updater,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch (``args.iters_per_epoch`` steps) of BYOL training.

    Each step samples positive pairs and projects both views with the online
    network. Each view's predictor output is regressed onto the other view's
    target projection (symmetric ``loss_fn``). The step then advances the
    optimizer and the LR scheduler and EMA-updates ``target_encoder`` towards
    ``model``.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    action_losses = AverageMeter('Action Loss', ':6.6f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, action_losses],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode (online network and its predictor)
    model.train()
    online_predictor.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        # Sampling + transforms + host->device copy are this step's data cost,
        # so measure data time only after they are done.
        query, key = sampler.sample_positive_pairs()
        query = query.to(device)
        key = key.to(device)
        data_time.update(time.time() - end)

        # Online branch: both views in one forward pass, then split back.
        proj_features, _ = model(torch.cat((query, key), dim=0))
        f_query, f_key = proj_features.chunk(2, dim=0)

        online_pred_one = online_predictor(f_query)
        online_pred_two = online_predictor(f_key)

        # Target branch: no gradients (outputs of no_grad need no detach). Views
        # go through separate passes so the projector's BatchNorm sees each alone.
        with torch.no_grad():
            target_proj_one, _ = target_encoder(query)
            target_proj_two, _ = target_encoder(key)

        # Symmetric BYOL loss: each view predicts the other view's target.
        loss_one = loss_fn(online_pred_one, target_proj_two)
        loss_two = loss_fn(online_pred_two, target_proj_one)
        action_loss = (loss_one + loss_two).mean()

        # Compute gradient and do an optimizer step
        optimizer.zero_grad()
        action_loss.backward()
        optimizer.step()
        lr_scheduler.step()
        update_moving_average(target_ema_updater, target_encoder, model)

        action_losses.update(action_loss.item(), f_query.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


def _embed_episodes(model, episodes):
    """Return the concatenated per-frame representations for ``{episode: [frame tensor]}``."""
    features = []
    # Inference only: skip autograd bookkeeping for every frame.
    with torch.no_grad():
        for episode in tqdm.tqdm(episodes):
            for x in episodes[episode]:
                _, rep = model(x.to(device).unsqueeze(0))
                features.append(rep.cpu())
    return torch.cat(features, dim=0)


def visualize(model, args, logger, post_fix, sampler=None):
    """Embed the frames of ``sampler.traj_vis`` and save a t-SNE plot.

    The first key of ``sampler.traj_vis`` is plotted as the source domain and the
    second as the target, using the representation (second output) of ``model``.
    The plot goes to ``<logger.visualize_directory>/<objective><post_fix>_val_TSNE.pdf``.
    ``model`` is left in eval mode.
    """
    print("Val Plot")
    model.eval()
    # latent space analysis
    vis_keys = list(sampler.traj_vis.keys())
    source_feature = _embed_episodes(model, sampler.traj_vis[vis_keys[0]])
    target_feature = _embed_episodes(model, sampler.traj_vis[vis_keys[1]])

    # plot t-SNE
    tSNE_filename = osp.join(logger.visualize_directory, args.objective+post_fix+'_val_TSNE.pdf')
    tsne.visualize(source_feature, target_feature, tSNE_filename)
    print("Saving t-SNE to", tSNE_filename)


if __name__ == '__main__':
    # Command-line interface for BYOL pretraining; parsed args are handed to main().
    parser = argparse.ArgumentParser(description='Contrastive Learning for MMRL')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: )')
    parser.add_argument('--objective', type=str, default=None,
                        help='learning objectives')
    parser.add_argument('--contrastive-task', type=str, default=None,
                        help='task type')
    # training parameters
    parser.add_argument('--state-batch-size', default=16, type=int,
                        metavar='N',
                        help='state sample size')
    parser.add_argument('-b', '--batch-size', default=16, type=int,
                        metavar='N',
                        help='mini-batch size (default: 16)')
    parser.add_argument('--n-size', default=16, type=int,
                        metavar='N',
                        help='dynamics negative sample size')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate of the classifier', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=100, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=100, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=777, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='',
                        help="Where to save logs, checkpoints and debugging images.")
    # The two help literals are concatenated implicitly; keep the separating space.
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analyze the model.")
    parser.add_argument("--ckpt-path", type=str, default='',
                        help="")
    parser.add_argument("--model-cls", type=str, default='',
                        choices=["VPTCLIP"])
    parser.add_argument("--deep", action='store_true')
    parser.add_argument("--resume", action='store_true')
    parser.add_argument("--resume-path", type=str, default='',
                        help="")
    args = parser.parse_args()
    main(args)
