import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import pickle
import copy

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision

from utils import utils
from models.MVPT_ConPE import CONPEMultiVisualPromptTuningCLIP

from clip import clip

# Turn on autograd anomaly detection for the whole process.  This is a
# debugging aid; PyTorch documents it as slowing execution down.
torch.autograd.set_detect_anomaly(True)

def pre_attention_foward(model, prompt, x, image, prefix=""):
    """Weight ``image`` by the last-block CLS attention of a prompted ViT.

    The input is embedded the way the visual transformer does it (patch
    convolution, class token, positional embedding), ``prompt`` is inserted
    between the class token and the patch tokens, and every residual block
    except the last one is run normally.  Of the last block only the
    attention weights are used: the row belonging to the class token is
    min-max normalised, restricted to the patch tokens of the first batch
    element, upsampled bicubically to the spatial size of ``image`` and
    multiplied into ``image``.

    Args:
        model: Visual backbone exposing ``conv1``, ``class_embedding``,
            ``positional_embedding``, ``prompt_dropout``, ``ln_pre`` and
            ``transformer.resblocks`` (the last block must expose ``ln_1``,
            ``attn`` and ``attn_mask``).
        prompt: 3-D prompt tensor; it is expanded over the batch and
            concatenated along dimension 1, so ``prompt.shape[1]`` is the
            number of prompt tokens.
        x: Image batch fed to ``model.conv1``.
        image: Tensor that is multiplied by the attention map; its last two
            dimensions define the size of the upsampled map.
        prefix: Unused; kept so existing callers keep working.

    Returns:
        ``image`` multiplied element-wise by the upsampled attention map.
    """
    x = model.conv1(x)  # shape = [*, width, grid, grid]
    # Remember the patch grid so the flat attention row can be folded back
    # into 2-D later instead of assuming a 7x7 layout.
    grid_size = tuple(x.shape[-2:])
    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
    class_token = model.class_embedding.to(x.dtype) + torch.zeros(
        x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
    )
    x = torch.cat([class_token, x], dim=1)  # shape = [*, grid ** 2 + 1, width]

    x = x + model.positional_embedding.to(x.dtype)

    # Insert the prompt tokens right after the class token.
    batch_size = x.size(0)
    num_prompt_tokens = prompt.shape[1]
    x = torch.cat((
            x[:, :1, :],
            model.prompt_dropout(prompt.expand(batch_size, -1, -1)),
            x[:, 1:, :]
        ), dim=1)

    x = model.ln_pre(x)
    x = x.permute(1, 0, 2)  # NLD -> LND

    last_index = len(model.transformer.resblocks) - 1
    for index, layer in enumerate(model.transformer.resblocks):
        if index == last_index:
            # Only the attention weights of the final block are needed, so
            # its attention is called directly and the rest of the block
            # (residual connection, MLP) is skipped.
            x = layer.ln_1(x)
            attn_mask = layer.attn_mask
            if attn_mask is not None:
                attn_mask = attn_mask.to(dtype=x.dtype, device=x.device)
            # [1] selects the weights; assumes ``layer.attn`` returns
            # (output, weights) like nn.MultiheadAttention -- TODO confirm.
            x = layer.attn(x, x, x, need_weights=True, attn_mask=attn_mask)[1]
            break
        x = layer(x)

    # Row of the class token, without its attention to itself.
    image_attention = x[:, 0, 1:].detach()

    # Min-max normalise over the whole row (prompt tokens included).
    min_v = torch.min(image_attention).item()
    max_v = torch.max(image_attention).item()
    image_attention = (image_attention - min_v) / (max_v - min_v)

    # Keep the patch tokens of the first batch element only; the leading
    # entries belong to the prompt tokens.
    image_attention = image_attention[0][num_prompt_tokens:]
    image_attention = torch.nn.functional.interpolate(
        image_attention.view(1, 1, *grid_size),
        tuple(image.shape[-2:]),
        mode='bicubic',
        align_corners=False,
    )

    image = image*image_attention
    return image
    
    # image_attention = x.detach().cpu().numpy()
    # fig = plt.figure(figsize=[10, 5], frameon=False)
    # ax = fig.add_subplot(1, 2, 1)
    # ax.axis("off")
    # ax.imshow(image.resize((300,300)))
    # ax = fig.add_subplot(1, 2, 2)
    # ax.axis("off")
    # ax.imshow(image_attention[0][8:].reshape(7, 7))
    # fig.subplots_adjust(hspace=0, wspace=0)
    # fig.savefig(f"attention_map.png")
    # plt.cla()
    # return

def attention_foward(model, x, image, prefix=""):
    """Weight ``image`` by the last-block CLS attention of an un-prompted ViT.

    Embeds ``x`` (patch convolution, class token, positional embedding), runs
    every residual block but the last, and takes only the attention weights
    of the last block.  The class-token row is min-max normalised, the first
    batch element is folded into a 7x7 map, upsampled bicubically to 224x224
    and multiplied into ``image``.

    Side effect: the unmodified ``image`` is written to
    ``vis/{prefix}_test_image.png`` with ``torchvision.utils.save_image``.

    Args:
        model: Visual backbone exposing ``conv1``, ``class_embedding``,
            ``positional_embedding``, ``ln_pre`` and ``transformer.resblocks``.
        x: Image batch fed to ``model.conv1``.
        image: Tensor that is saved and then multiplied by the attention map.
        prefix: File-name prefix of the saved input image.

    Returns:
        ``image`` multiplied element-wise by the ``[1, 1, 224, 224]`` map.
    """
    # Patch embedding: [*, width, grid, grid] -> [*, grid ** 2, width]
    patch_tokens = model.conv1(x)
    patch_tokens = patch_tokens.reshape(patch_tokens.shape[0], patch_tokens.shape[1], -1)
    patch_tokens = patch_tokens.permute(0, 2, 1)

    # Prepend one class token per batch element: [*, grid ** 2 + 1, width]
    class_tokens = model.class_embedding.to(patch_tokens.dtype) + torch.zeros(
        patch_tokens.shape[0],
        1,
        patch_tokens.shape[-1],
        dtype=patch_tokens.dtype,
        device=patch_tokens.device,
    )
    tokens = torch.cat([class_tokens, patch_tokens], dim=1)
    tokens = tokens + model.positional_embedding.to(tokens.dtype)

    tokens = model.ln_pre(tokens).permute(1, 0, 2)  # NLD -> LND

    final_index = len(model.transformer.resblocks) - 1
    for index, layer in enumerate(model.transformer.resblocks):
        if index != final_index:
            tokens = layer(tokens)
            continue
        # Final block: only its attention weights are of interest, so the
        # attention module is called directly.
        tokens = layer.ln_1(tokens)
        mask = layer.attn_mask
        layer.attn_mask = None if mask is None else mask.to(dtype=tokens.dtype, device=tokens.device)
        _, tokens = layer.attn(tokens, tokens, tokens, need_weights=True, attn_mask=layer.attn_mask)
        break

    # Class-token row without its attention to itself.
    cls_attention = tokens[:, 0, 1:].detach()
    lowest = torch.min(cls_attention).item()
    highest = torch.max(cls_attention).item()
    cls_attention = (cls_attention - lowest) / (highest - lowest)

    # First batch element, folded to the patch grid and upsampled.
    attention_map = torch.nn.functional.interpolate(
        cls_attention[0].view(1, 1, 7, 7),
        (224, 224),
        mode='bicubic',
        align_corners=False,
    )

    torchvision.utils.save_image(image, f"vis/{prefix}_test_image.png")
    return image * attention_map
    

def post_attention_foward(model, prompt, x, image, prefix=""):
    """Weight ``image`` by the CLS attention averaged over a batch of prompts.

    Every batch element is run with its own prompt (``prompt`` is expanded
    over the batch and inserted between the class token and the patch
    tokens).  All residual blocks but the last are run normally; of the last
    block only the attention weights are used.  The class-token row of each
    batch element is min-max normalised on its own, the rows are averaged,
    the average is min-max normalised again, restricted to the patch tokens,
    upsampled bicubically to the spatial size of ``image`` and multiplied
    into ``image``.

    Args:
        model: Visual backbone exposing ``conv1``, ``class_embedding``,
            ``positional_embedding``, ``prompt_dropout``, ``ln_pre`` and
            ``transformer.resblocks`` (the last block must expose ``ln_1``,
            ``attn`` and ``attn_mask``).
        prompt: 3-D prompt tensor; ``prompt.shape[1]`` is the number of
            prompt tokens per batch element.
        x: Image batch fed to ``model.conv1`` (one copy per prompt at the
            visible call sites).
        image: Tensor that is multiplied by the attention map; its last two
            dimensions define the size of the upsampled map.
        prefix: Unused; kept so existing callers keep working.

    Returns:
        ``image`` multiplied element-wise by the upsampled attention map.
    """
    x = model.conv1(x)  # shape = [*, width, grid, grid]
    # Remember the patch grid so the flat attention row can be folded back
    # into 2-D later instead of assuming a 7x7 layout.
    grid_size = tuple(x.shape[-2:])
    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
    class_token = model.class_embedding.to(x.dtype) + torch.zeros(
        x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
    )
    x = torch.cat([class_token, x], dim=1)  # shape = [*, grid ** 2 + 1, width]

    x = x + model.positional_embedding.to(x.dtype)

    # Insert the prompt tokens right after the class token.
    batch_size = x.size(0)
    num_prompt_tokens = prompt.shape[1]
    x = torch.cat((
            x[:, :1, :],
            model.prompt_dropout(prompt.expand(batch_size, -1, -1)),
            x[:, 1:, :]
        ), dim=1)

    x = model.ln_pre(x)
    x = x.permute(1, 0, 2)  # NLD -> LND

    last_index = len(model.transformer.resblocks) - 1
    for index, layer in enumerate(model.transformer.resblocks):
        if index == last_index:
            # Only the attention weights of the final block are needed, so
            # its attention is called directly.
            x = layer.ln_1(x)
            attn_mask = layer.attn_mask
            if attn_mask is not None:
                attn_mask = attn_mask.to(dtype=x.dtype, device=x.device)
            # [1] selects the weights; assumes ``layer.attn`` returns
            # (output, weights) like nn.MultiheadAttention -- TODO confirm.
            x = layer.attn(x, x, x, need_weights=True, attn_mask=attn_mask)[1]
            break
        x = layer(x)

    # One class-token row per batch element (self-attention entry dropped).
    cls_attention = x[:, 0, 1:].detach()

    # Normalise every row on its own so each prompt contributes on the same
    # scale, then average over ALL rows.  (Previously only the last row made
    # it into the result, so the average was a no-op.)
    row_min = cls_attention.min(dim=1, keepdim=True).values
    row_max = cls_attention.max(dim=1, keepdim=True).values
    cls_attention = (cls_attention - row_min) / (row_max - row_min)
    image_attention = torch.mean(cls_attention, 0, True)

    # Re-normalise the averaged row to the [0, 1] range.
    min_v = torch.min(image_attention).item()
    max_v = torch.max(image_attention).item()
    image_attention = (image_attention - min_v) / (max_v - min_v)

    # Drop the prompt-token entries; what is left are the patch tokens.
    image_attention = image_attention[0][num_prompt_tokens:]
    image_attention = torch.nn.functional.interpolate(
        image_attention.view(1, 1, *grid_size),
        tuple(image.shape[-2:]),
        mode='bicubic',
        align_corners=False,
    )

    image = image*image_attention
    return image

# Fix the random seed through the project helper so repeated runs match.
# NOTE(review): which RNGs get seeded is decided inside utils.seed_fix, which
# is not visible here -- confirm it covers torch/numpy if that matters.
utils.seed_fix(777)

# Checkpoint files handed to ``prompt_init`` further down.
#
# Alternative checkpoint paths kept for reference (currently unused):
#   ./logs/illumination_brightness/checkpoints/contrastive__latest.pth
#   ./logs/illumination_contrast/checkpoints/contrastive__latest.pth
#   ./logs/illumination_saturation/checkpoints/contrastive__latest.pth
#   ./logs/illumination_hue/checkpoints/contrastive__latest.pth
#   ./logs/FOV_39-59/checkpoints/comparative_action_byol_latest.pth
#   ./logs/FOV_69-89/checkpoints/comparative_action_byol_latest.pth
#   ./logs/FOV_99-139/checkpoints/comparative_action_byol_latest.pth
#   ./logs/illumination/checkpoints/contrastive__latest.pth
#   ./logs/fov_large_epoche/checkpoints/comparative_action_byol_latest.pth
#   ./logs/look/checkpoints/comparative_action_byol_latest.pth
#   ./logs/rotate/checkpoints/comparative_action_byol_latest.pth
#   ./logs/stepsize/checkpoints/comparative_action_byol_latest.pth

# The ten active checkpoints, grouped by the log directory they live in.
_ILLUMINATION_CHECKPOINTS = (
    "./logs/illuminations_/brightness/checkpoints/contrastive__latest.pth",
    "./logs/illuminations_/contrast/checkpoints/contrastive__latest.pth",
    "./logs/illuminations_/saturation/checkpoints/contrastive__latest.pth",
    "./logs/illuminations_/hue/checkpoints/contrastive__latest.pth",
)
_FOV_CHECKPOINTS = (
    "./logs/FOVs/FOV_39-59_/checkpoints/comparative_action_byol_latest.pth",
    "./logs/FOVs/FOV_69-89_/checkpoints/comparative_action_byol_latest.pth",
    "./logs/FOVs/FOV_99-139_/checkpoints/comparative_action_byol_latest.pth",
)
_ACTION_CHECKPOINTS = (
    "./logs/LOOK_/checkpoints/comparative_action_byol_latest.pth",
    "./logs/ROTATE_/checkpoints/comparative_action_byol_latest.pth",
    "./logs/STEPSIZE_/checkpoints/comparative_action_byol_latest.pth",
)

# "all" variant: every active checkpoint.
pre_all_prompts = _ILLUMINATION_CHECKPOINTS + _FOV_CHECKPOINTS + _ACTION_CHECKPOINTS

# "selective" variant: brightness, saturation, FOV_99-139 and ROTATE only.
pre_prompts = (
    "./logs/illuminations_/brightness/checkpoints/contrastive__latest.pth",
    "./logs/illuminations_/saturation/checkpoints/contrastive__latest.pth",
    "./logs/FOVs/FOV_99-139_/checkpoints/comparative_action_byol_latest.pth",
    "./logs/ROTATE_/checkpoints/comparative_action_byol_latest.pth",
)

# The "post" runs use exactly the same checkpoint selections as the "pre"
# runs; tuples are immutable, so sharing the objects is safe.
post_all_prompts = pre_all_prompts
post_prompts = pre_prompts

# Available ``multi_p_mode`` settings for ``prompt_init``.  Each entry is a
# triple of strings; the meaning of the three positions is defined by
# ``prompt_init`` and is not visible here.
multi_p_mode = [
    (first, second, third)
    for first in ("COMPOSE", "ENSEMBLE")
    for second in ("UNIFORM", "WEIGHTED")
    for third in ("AVG", "CAT")
]
multi_p_mode += [
    ("ATTEMPT", "WEIGHTED", "AVG"),
    ("SESoM", "WEIGHTED", "AVG"),
]

# Frame preprocessing: resize to 224x224, convert to a tensor, then
# normalise each channel with the mean/std below.
# NOTE(review): these look like the usual ImageNet statistics rather than the
# values of the ``preprocess`` transform returned by ``clip.load`` -- confirm
# this matches what the checkpoints were trained with.
_FRAME_SIZE = (224, 224)
_CHANNEL_MEAN = (0.485, 0.456, 0.406)
_CHANNEL_STD = (0.229, 0.224, 0.225)

_frame_steps = [
    torchvision.transforms.Resize(_FRAME_SIZE),
    torchvision.transforms.ToTensor(),
    # Optional augmentations, currently disabled:
    # torchvision.transforms.ColorJitter(
    #     brightness=(0.7, 2),
    #     contrast=(0.9, 1.5),
    #     saturation=(1.5, 2), 
    #     hue=(-0.4, 0.4)
    # ),
    # torchvision.transforms.RandomGrayscale(p=0.2),
    # # torchvision.transforms.RandomInvert(),
    # # torchvision.transforms.RandomSolarize(threshold=0.75),
    # # torchvision.transforms.RandomAdjustSharpness(sharpness_factor=2),
    # torchvision.transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
    # torchvision.transforms.RandomAutocontrast(),
    torchvision.transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
]
to_tensor_aug = torchvision.transforms.Compose(_frame_steps)

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# CLIP is loaded on the CPU first.  ``preprocess`` is not used in the visible
# part of this script.
clip_model, preprocess = clip.load("ViT-B/32", device="cpu")
# Set the momentum of every module whose class name contains "BatchNorm" to 0.
for module in clip_model.modules():
    if "BatchNorm" in type(module).__name__:
        module.momentum = 0.0
clip_model.eval().float()


def _build_embedder(prompt_paths):
    """Create one prompt-tuning wrapper around ``clip_model``.

    The wrapper is put in eval mode, initialised from ``prompt_paths`` with
    the first ``multi_p_mode`` entry, and then moved to ``device``.  Using
    ``.to(device)`` instead of an unconditional ``.cuda()`` keeps the CPU
    fallback of ``device`` working on machines without CUDA.

    Args:
        prompt_paths: Checkpoint paths passed to ``prompt_init``.

    Returns:
        The initialised ``CONPEMultiVisualPromptTuningCLIP`` instance.
    """
    embedder = CONPEMultiVisualPromptTuningCLIP(clip_model, device)
    embedder.eval()
    embedder.prompt_init(prompt_paths, multi_p_mode=multi_p_mode[0])
    embedder.to(device)
    return embedder


# "pre" pair: all checkpoints / selective subset.
pre_all_embedder = _build_embedder(pre_all_prompts)
pre_embedder = _build_embedder(pre_prompts)
# "post" pair: all checkpoints / selective subset.
post_all_embedder = _build_embedder(post_all_prompts)
post_embedder = _build_embedder(post_prompts)

# Per-checkpoint prompt tensors held by each wrapper's visual backbone.
pre_all_prompts_embeddings = pre_all_embedder.visual_backbone.source_prompt_list
pre_prompts_embeddings = pre_embedder.visual_backbone.source_prompt_list
post_all_prompts_embeddings = post_all_embedder.visual_backbone.source_prompt_list
post_prompts_embeddings = post_embedder.visual_backbone.source_prompt_list

# Pickled trajectory datasets to visualise.  The trailing "# a, b" notes look
# like the (i, j) indices selected below for each dataset -- TODO confirm.
datapath = ["/path/to/MMRL/trajdata/ObjNav/original/FOV_99-139/train_dataset.pkl"] # 4, 3
# datapath = ["/path/to/MMRL/trajdata/ObjNav/original/LOOK/train_dataset.pkl"] # 7, 21
# datapath = ["/path/to/MMRL/trajdata/ObjNav/original/FOV_39-59/train_dataset.pkl"] # 0, 7

# Every overlay is written below this directory; create it up front so the
# first save_image call cannot fail on a missing folder.
os.makedirs("vis", exist_ok=True)

# Pure inference: no autograd graph is needed for the visualisations.
with torch.no_grad():
    # The prompt tensors do not depend on the dataset, so build them once.
    # PRE: concatenate the source prompts along dim 0 and average them into
    # a single prompt of leading size 1.
    pre_all_prompt_embedding = torch.mean(torch.cat(list(pre_all_prompts_embeddings)), 0, True)
    print(pre_all_prompt_embedding.shape)

    pre_prompt_embedding = torch.mean(torch.cat(list(pre_prompts_embeddings)), 0, True)
    print(pre_prompt_embedding.shape)

    # POST: keep the source prompts separate (concatenated along dim 0); the
    # input frame is repeated once per prompt further down.
    post_all_prompt_embedding = torch.cat(list(post_all_prompts_embeddings))
    print(post_all_prompt_embedding.shape)

    post_prompt_embedding = torch.cat(list(post_prompts_embeddings))
    print(post_prompt_embedding.shape)

    for path in datapath:
        # Name of the directory that contains the dataset file.
        DF = path.split("/")[-2]
        print(DF)
        # NOTE: pickle.load can execute arbitrary code stored in the file --
        # only point ``datapath`` at trusted, locally produced datasets.
        with open(path, 'rb') as f:
            data = pickle.load(f)
        # Drop the metadata entries so only the indexable entries remain.
        del data["classes"]
        del data["goal_types"]
        del data["actions"]
        print(data.keys())
        for i in [4]: # data.keys():
            print(data[i].keys())
            for j in [3]: # data[i].keys():
                for num, image in enumerate(data[i][j]["frame"]):
                    image = Image.fromarray(image)
                    # image.save("test_ori.png")
                    input_x = to_tensor_aug(image).unsqueeze(0).to(device)
                    tag = "input_" + str(num)

                    # Plain CLIP backbone (also saves the input frame).
                    out = attention_foward(clip_model.visual, input_x, input_x, prefix=tag)
                    torchvision.utils.save_image(out, f"vis/clip_{num}_test_attn_image.png")

                    # Averaged prompt: all checkpoints, then the subset.
                    out = pre_attention_foward(pre_all_embedder.visual_backbone, pre_all_prompt_embedding, input_x, input_x, prefix=tag)
                    torchvision.utils.save_image(out, f"vis/composition_all_{num}_test_attn_image.png")
                    out = pre_attention_foward(pre_embedder.visual_backbone, pre_prompt_embedding, input_x, input_x, prefix=tag)
                    torchvision.utils.save_image(out, f"vis/composition_{num}_test_attn_image.png")

                    # One forward pass per prompt: repeat the frame along the
                    # batch dimension to match the number of prompts.
                    x = input_x.expand(post_all_prompt_embedding.size(0), -1, -1, -1)
                    out = post_attention_foward(post_all_embedder.visual_backbone, post_all_prompt_embedding, x, input_x, prefix=tag)
                    torchvision.utils.save_image(out, f"vis/ensemble_all_{num}_test_attn_image.png")
                    x = input_x.expand(post_prompt_embedding.size(0), -1, -1, -1)
                    out = post_attention_foward(post_embedder.visual_backbone, post_prompt_embedding, x, input_x, prefix=tag)
                    torchvision.utils.save_image(out, f"vis/ensemble_{num}_test_attn_image.png")
                # Stop the script after the first selected trajectory.
                # SystemExit is raised directly because the ``exit`` helper is
                # only installed by the ``site`` module.
                raise SystemExit(0)

