import os
import pickle
import pprint as pp

from PIL import Image
import clip
import torch
import torchvision
from glob import glob

from models.MVPT_ConPE import CONPEMultiVisualPromptTuningCLIP

def get_jaccard_similarity(x1, x2):
    """Return the ratio of summed element-wise minima to summed maxima.

    The two tensors are stacked along a new leading axis, reduced along that
    axis to their element-wise minimum and maximum, and the total of the
    minima is divided by the total of the maxima.

    Args:
        x1: First tensor.
        x2: Second tensor. ``torch.stack`` requires it to have the same shape
            as ``x1``.

    Returns:
        A zero-dimensional tensor holding
        ``sum(min(x1, x2)) / sum(max(x1, x2))``. The denominator is not
        guarded against zero.
    """
    paired = torch.stack([x1, x2])
    smaller, _ = paired.min(dim=0)
    larger, _ = paired.max(dim=0)
    return smaller.sum() / larger.sum()

# Checkpoint paths of the source visual prompts handed to
# `embedder.prompt_init`. The entries are grouped by the log directory they
# come from: four illumination factors (brightness, contrast, saturation,
# hue), three FOV ranges, and three directories named LOOK_ / ROTATE_ /
# STEPSIZE_.
# NOTE(review): the paths are relative, so they resolve against the current
# working directory — presumably the repository root; confirm.
prompts = (
    "./logs/illuminations_/brightness/checkpoints/contrastive__latest.pth",
    "./logs/illuminations_/contrast/checkpoints/contrastive__latest.pth",
    "./logs/illuminations_/saturation/checkpoints/contrastive__latest.pth",
    "./logs/illuminations_/hue/checkpoints/contrastive__latest.pth",

    "./logs/FOVs/FOV_39-59_/checkpoints/comparative_action_byol_latest.pth",
    "./logs/FOVs/FOV_69-89_/checkpoints/comparative_action_byol_latest.pth",
    "./logs/FOVs/FOV_99-139_/checkpoints/comparative_action_byol_latest.pth",

    "./logs/LOOK_/checkpoints/comparative_action_byol_latest.pth",
    "./logs/ROTATE_/checkpoints/comparative_action_byol_latest.pth",
    "./logs/STEPSIZE_/checkpoints/comparative_action_byol_latest.pth",

)
# Candidate multi-prompt modes, each a triple of strings. The first eight
# entries are the full COMPOSE/ENSEMBLE x UNIFORM/WEIGHTED x AVG/CAT grid
# (last field varying fastest); ATTEMPT and SESoM are listed only with
# WEIGHTED + AVG. The script passes entry 0 to `embedder.prompt_init`.
# NOTE(review): the field names used below (strategy / weighting / reduction)
# are inferred from the values — confirm against the model's `prompt_init`.
multi_p_mode = [
    (strategy, weighting, reduction)
    for strategy in ("COMPOSE", "ENSEMBLE")
    for weighting in ("UNIFORM", "WEIGHTED")
    for reduction in ("AVG", "CAT")
] + [
    ("ATTEMPT", "WEIGHTED", "AVG"),
    ("SESoM", "WEIGHTED", "AVG"),
]

# PIL image -> normalised tensor pipeline applied to every evaluation image.
# Only tensor conversion and per-channel normalisation are active; the
# augmentations kept below as comments are disabled.
# NOTE(review): the mean/std values look like the ImageNet statistics rather
# than those of the `preprocess` transform returned by `clip.load` (which this
# script never uses) — confirm this matches how the prompts were trained.
to_tensor_aug = torchvision.transforms.Compose(
    [
        # torchvision.transforms.Resize((224,224)),
        # torchvision.transforms.ColorJitter(
        #     brightness=(0.2, 0.2),
        #     contrast=(3.4, 3.4),
        #     saturation=(0.5, 0.5),
        #     hue=(0.4, 0.4)
        # ),
        # torchvision.transforms.RandomGrayscale(p=0.2),
        # torchvision.transforms.RandomInvert(),
        # torchvision.transforms.RandomSolarize(threshold=0.75),
        # torchvision.transforms.RandomAdjustSharpness(sharpness_factor=2),
        # torchvision.transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
        # torchvision.transforms.RandomAutocontrast(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# CLIP is loaded on the CPU; the `preprocess` transform it returns is not used
# anywhere in this script.
clip_model, preprocess = clip.load("ViT-B/32", device="cpu")
# A momentum of 0.0 keeps the running statistics of any BatchNorm layer fixed,
# even if such a layer were later switched back to training mode.
for module in clip_model.modules():
    if "BatchNorm" in type(module).__name__:
        module.momentum = 0.0
# Inference mode, with all floating-point weights cast to float32.
clip_model.eval().float()

# Wrap the CLIP model and load the source prompts, using the first entry of
# `multi_p_mode`.
embedder = CONPEMultiVisualPromptTuningCLIP(clip_model, device)
embedder.eval()
embedder.prompt_init(prompts, multi_p_mode=multi_p_mode[0])
# Move the embedder to the selected device instead of calling `.cuda()`
# unconditionally, so this step no longer fails on a machine without CUDA.
# With CUDA present this is the same placement as before (cuda:0).
embedder.to(device)

print("Turning off gradients in both the image and the text encoder")
# Freeze every embedder parameter whose name does not contain "prompt".
for param_name, parameter in embedder.named_parameters():
    if "prompt" in param_name:
        continue
    parameter.requires_grad_(False)

# Double check: collect and report whatever is still trainable.
enabled = {
    param_name
    for param_name, parameter in embedder.named_parameters()
    if parameter.requires_grad
}
print(f"Parameters to be updated: {enabled}")

# Every parameter of the CLIP model itself is frozen unconditionally, and the
# result is verified immediately.
for clip_parameter in clip_model.parameters():
    clip_parameter.requires_grad_(False)
    assert not clip_parameter.requires_grad

# Source prompt collection exposed by the embedder's visual backbone; the
# analysis loop below iterates over it, one pass per entry.
# NOTE(review): presumably one entry per checkpoint in `prompts` — confirm
# against the model implementation.
prompts_embeddings = embedder.visual_backbone.source_prompt_list

# Discover every pickle one directory level below `root`. The result is only
# printed: `datapath` is overwritten with a single hard-coded file right after.
root = "/path/to/MMRL/trajdata/ObjNav/original"
datapath = glob(os.path.join(root, "*", '*.pkl'))
print(datapath)
# NOTE(review): the trailing number pairs are presumably the (i, j) indices
# intended for each dataset — confirm.
datapath = ["/path/to/MMRL/trajdata/ObjNav/original/FOV_99-139/train_dataset.pkl"] # 4, 3
# datapath = ["/path/to/MMRL/trajdata/ObjNav/original/LOOK/train_dataset.pkl"] # 7, 21
# datapath = ["/path/to/MMRL/trajdata/ObjNav/original/FOV_39-59/train_dataset.pkl"] # 0, 7

# Intended layout: {dataset folder name: {prompt index: [mean cosine,
# mean distance, mean jaccard]}}. See the early `exit()` in the loop below:
# as written, the inner entries are never filled in.
result_dict = {}
for path in datapath:
    images = []
    # The dataset is identified by the name of the pickle's parent directory.
    DF = path.split("/")[-2]
    print(DF)
    result_dict[DF] = {}
    # NOTE(review): `pickle.load` can execute arbitrary code; only use it on
    # trusted dataset files.
    with open(path, 'rb') as f:
        data = pickle.load(f)
    # Drop three metadata keys. In the active code `data` is not read again:
    # the frame loop that used it is commented out below.
    del data["classes"]
    del data["goal_types"]
    del data["actions"]
    # `i` and `j` are not referenced by the active loop body; they only
    # mattered for the commented-out `data[i][j]["frame"]` iteration.
    for i in [0]: # data.keys():
        for j in [7]: # data[i].keys():
            # for image in data[i][j]["frame"]:
            #     image = Image.fromarray(image)
            # Five fixed test images are used instead of the dataset frames.
            for image in ["allenact/test1.png", "allenact/test2.png", "allenact/test3.png", "allenact/test4.png", "allenact/test5.png"]:
                image = Image.open(image)
                # Debug dump; the same file is overwritten on every iteration,
                # so it ends up holding the last image.
                image.save("test_ori.png")
                # Convert + normalise, add a leading batch dimension, and move
                # the tensor to the selected device.
                image = to_tensor_aug(image).unsqueeze(0).to(device)
                # torchvision.utils.save_image(image, "test_aug.png")
                # exit()
                images.append(image)
    # Earlier variant of the analysis (disabled): compared the mean embedder
    # output with the mean of each prompt embedding.
    # for k, prompt_embeddings in enumerate(prompts_embeddings):
    #     print(prompt_embeddings[0][0][0])
    #     similarities = []
    #     cos_sim = []
    #     dist_sim = []
    #     jac_sim = []
    #     for img in images:
    #         input_x = embedder(img.cuda())
    #         avg_x = torch.mean(input_x, dim=1)
    #         avg_p = torch.mean(prompt_embeddings, dim=1)
    #         cos = torch.nn.functional.cosine_similarity(avg_x, avg_p, dim=1)
    #         dist = torch.cdist(avg_x, avg_p, p=2)
    #         jac = get_jaccard_similarity(avg_x, avg_p)

    #         cos_sim.append(cos.item())
    #         dist_sim.append(dist.item())
    #         jac_sim.append(jac.item())
    #     similarities.append(cos_sim)
    #     similarities.append(dist_sim)
    #     similarities.append(jac_sim)
    #     result_dict[DF][k] = [round(torch.mean(torch.tensor(sim)).item(),3) for sim in similarities]
    # pp.pprint(result_dict)
    # exit()
    # Current variant: compare the plain CLIP visual feature of each image
    # with the embedder output. `prompt_embeddings` itself is not read in the
    # loop body; only its index `k` is used (in the unreachable part).
    for k, prompt_embeddings in enumerate(prompts_embeddings):
        similarities = []
        cos_sim = []
        dist_sim = []
        jac_sim = []
        for img in images:
            # Repeat the single image five times along the batch dimension.
            # `expand` only broadcasts size-1 dimensions, so the image must
            # already be 3x224x224 (no resize is active in `to_tensor_aug`).
            img = img.expand(5, 3, 224, 224)
            # Debug dump of the batch; overwritten on every iteration.
            torchvision.utils.save_image(torchvision.utils.make_grid(img, nrow=1, normalize=True), "grid_image_val_.png")
            # exit()
            # CLIP visual output with a singleton dimension inserted at axis 1.
            ori_x = clip_model.visual(img.cuda()).unsqueeze(1)
            input_x = embedder(img.cuda())
            # Regroup the embedder output as (5, -1, 512).
            # NOTE(review): presumably one 512-d vector per prompt — confirm.
            input_x = input_x.view(5,-1,512)
            print(ori_x.shape)
            print(input_x[:,:,0])
            # Cosine similarity computed by hand: batched dot products divided
            # by the batched product of the L2 norms along the last axis.
            Q_norm = torch.norm(ori_x, dim=2, keepdim=True)
            K_norm = torch.norm(input_x, dim=2, keepdim=True)
            dot_prod = torch.bmm(ori_x, input_x.permute(0, 2, 1))
            cos = dot_prod / torch.bmm(Q_norm, K_norm.permute(0, 2, 1)) # torch.Size([B, 1, prompt_num])
            print(cos)
            # The script terminates here, after the first image of the first
            # prompt; every statement below this call is unreachable.
            exit()
            # NOTE(review): if the exit above is removed, `cos.item()` below
            # would raise unless `cos` holds exactly one element, which the
            # [B, 1, prompt_num] shape noted above contradicts for B = 5.
            dist = torch.cdist(ori_x, input_x[k].unsqueeze(0), p=2)
            jac = get_jaccard_similarity(ori_x, input_x[k].unsqueeze(0))

            cos_sim.append(cos.item())
            dist_sim.append(dist.item())
            jac_sim.append(jac.item())
        similarities.append(cos_sim)
        similarities.append(dist_sim)
        similarities.append(jac_sim)

        # Per-prompt summary: mean of each metric over all images, rounded to
        # three decimals, in the order cosine, distance, jaccard.
        result_dict[DF][k] = [round(torch.mean(torch.tensor(sim)).item(),3) for sim in similarities]
    pp.pprint(result_dict)