from collections import OrderedDict
from typing import Tuple, Optional, List, Dict
import math
from operator import mul
from functools import reduce
import copy

import numpy as np
import torch
import torch.nn as nn
from torch.nn import Conv2d, Dropout

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Module-level CLIP BPE tokenizer instance.
# NOTE(review): not referenced by the classes in this file's visible code —
# confirm it is still needed.
_tokenizer: _Tokenizer = _Tokenizer()

class PromptVisionTransformer(nn.Module):
    """CLIP visual encoder that can splice prompt tokens into its input.

    The CLIP sub-modules (``conv1``, class/positional embeddings, ``ln_pre``,
    ``transformer``, ``ln_post``, ``proj``) are the same objects as in
    ``clip_model.visual`` (shared, not copied). Frozen source prompts are
    loaded by :meth:`source_prompt_init`; an optional learnable "meta" prompt
    is added by :meth:`meta_prompt_init`.
    """

    def __init__(self, clip_model, device, clip_model_type, DeepPrompt, n_vtk):
        super().__init__()
        self.device = device
        visual = clip_model.visual
        self.input_resolution = visual.input_resolution
        self.output_dim = visual.output_dim
        self.conv1 = visual.conv1

        self.class_embedding = visual.class_embedding
        self.positional_embedding = visual.positional_embedding
        self.ln_pre = visual.ln_pre

        self.transformer = visual.transformer

        self.ln_post = visual.ln_post
        self.proj = visual.proj
        self.Deep = DeepPrompt
        self.n_vtk = n_vtk

        # prompt config
        if "ViT-B/32" in clip_model_type:
            self.patch_size = (32, 32)
        elif "ViT-B/16" in clip_model_type:
            self.patch_size = (16, 16)
        else:
            # Other model types previously left these attributes unset; fall
            # back to the kernel size of the patch-embedding convolution.
            self.patch_size = tuple(self.conv1.kernel_size)
        # positional_embedding is (num_positions, width).
        _, self.prompt_dim = self.positional_embedding.shape
        self.num_tokens = n_vtk
        # Projected prompts are concatenated with the token sequence, so the
        # projection output must equal the transformer width (768 for ViT-B).
        self.hidden_size = self.prompt_dim
        self.prompt_dropout = Dropout(0.1)

        self.attn_list = []

    def source_prompt_init(self, source_prompts_file, multi_p_mode):
        """Load frozen, pre-projected source prompts from checkpoint files.

        Args:
            source_prompts_file: sequence of checkpoint paths. Each checkpoint
                must hold ``[net.]visual_backbone.prompt_embeddings`` and
                ``[net.]visual_backbone.prompt_proj.{weight,bias}``.
            multi_p_mode: indexable mode spec. One attention head per source
                prompt is built when there is more than one prompt and
                ``multi_p_mode[0] == "COMPOSE"`` and
                ``multi_p_mode[1] == "WEIGHTED"``.

        Sets ``self.prompt_num`` to the number of checkpoints.
        """
        self.prompt_num = len(source_prompts_file)
        self.multi_p_mode = multi_p_mode
        self.source_prompt_list = nn.ParameterList([])
        self.source_prompt_attn_weight_list = nn.ModuleList([])
        build_attn = (
            self.prompt_num > 1
            and multi_p_mode[1] == "WEIGHTED"
            and multi_p_mode[0] == "COMPOSE"
        )
        for path in source_prompts_file:
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            # map_location keeps every loaded tensor on self.device.
            parameters = torch.load(path, map_location=self.device)
            prefix = "net." if "net" in next(iter(parameters)) else ""
            key_prompt_embeddings = prefix + "visual_backbone.prompt_embeddings"
            key_prompt_proj = prefix + "visual_backbone.prompt_proj"
            with torch.no_grad():
                source_prompt_embeddings = parameters[key_prompt_embeddings].clone().to(self.device)
                source_prompt_proj = nn.Linear(self.prompt_dim, self.hidden_size)
                source_prompt_proj.weight.copy_(parameters[key_prompt_proj + ".weight"])
                source_prompt_proj.bias.copy_(parameters[key_prompt_proj + ".bias"])
                source_prompt_proj.requires_grad_(False)
                source_prompt_proj.to(self.device)
                # Project once and WITHOUT dropout: the result is stored frozen,
                # so dropout here would bake one random mask into it for good.
                prompt = source_prompt_proj(source_prompt_embeddings)
            self.source_prompt_list.append(nn.Parameter(prompt, requires_grad=False))
            if build_attn:
                # ATTENTION MODULE: maps a pooled prompt to its attention key.
                attn = nn.Sequential(
                    nn.Linear(self.hidden_size, 128, bias=False),
                    nn.SiLU(),
                    nn.Linear(128, self.hidden_size, bias=False),
                    nn.LayerNorm(self.hidden_size),
                )
                self.source_prompt_attn_weight_list.append(attn.to(self.device))

        self.source_prompt_list.requires_grad_(False)

    def meta_prompt_init(self,):
        """Add a learnable meta prompt and its projection on ``self.device``.

        Must be called after :meth:`source_prompt_init`, because it increments
        ``self.prompt_num`` set there.
        """
        meta_prompt_proj = nn.Linear(self.prompt_dim, self.hidden_size)
        nn.init.kaiming_normal_(meta_prompt_proj.weight, a=0, mode='fan_out')
        self.meta_prompt_proj = meta_prompt_proj.to(self.device)

        val = math.sqrt(6. / float(3 * reduce(mul, self.patch_size, 1) + self.prompt_dim))  # noqa

        # xavier_uniform initialization (initialised first, then moved, so the
        # CPU random stream is used as before)
        embeddings = torch.zeros(1, self.num_tokens, self.prompt_dim)
        nn.init.uniform_(embeddings, -val, val)
        self.meta_prompt_embeddings = nn.Parameter(embeddings.to(self.device))

        if self.Deep:  # Deep prompt version noqa
            # Assumes a 12-layer transformer (ViT-B): one row per layer but one.
            total_d_layer = 12 - 1
            deep = torch.zeros(total_d_layer, self.num_tokens, self.prompt_dim)
            # xavier_uniform initialization
            nn.init.uniform_(deep, -val, val)
            self.deep_meta_prompt_embeddings = nn.Parameter(deep.to(self.device))

        self.prompt_num += 1

    def incorporate_prompt(self, prompt_embeddings, prompt_proj=None, x=None):
        """Build prompt tokens and optionally splice them into a sequence.

        Args:
            prompt_embeddings: prompt tensor of shape ``(1 or B, n_prompt, dim)``.
            prompt_proj: optional module applied to ``prompt_embeddings``; its
                output is additionally passed through ``self.prompt_dropout``.
            x: optional sequence ``(B, 1 + n_patches, width)``, class token first.

        Returns:
            Without ``x``: the (projected, dropped-out) prompt tokens.
            With ``x``: ``ln_pre`` of ``[cls, prompts, patches]`` concatenated on
            dim 1, i.e. ``(B, 1 + n_prompt + n_patches, width)``.
        """
        if x is None:
            if prompt_proj is None:
                return prompt_embeddings
            return self.prompt_dropout(prompt_proj(prompt_embeddings))

        B = x.size(0)
        if prompt_proj is not None:
            # Dropout is applied after expanding, so the mask is drawn per sample.
            prompt_tokens = self.prompt_dropout(prompt_proj(prompt_embeddings).expand(B, -1, -1))
        else:
            prompt_tokens = prompt_embeddings.expand(B, -1, -1)
        x = torch.cat((x[:, :1, :], prompt_tokens, x[:, 1:, :]), dim=1)
        return self.ln_pre(x)

    def pre_prompt_forward(self, X):
        """Compose per-sample prompt instances by attention over source prompts.

        Query: token-wise max of ``X`` or, when a meta prompt exists, the
        token-wise max of the meta prompt. Keys: token-wise max of each source
        prompt passed through its attention head. Values: the source prompts.

        Args:
            X: token sequence of shape ``(B, L, D)``.

        Returns:
            ``(B, prompt_num, n_vtk, D)``. Each source prompt is scaled by its
            attention weight; with a meta prompt, index 0 on dim 1 is the meta
            prompt. Assumes every source prompt has ``n_vtk`` tokens.

        Raises:
            ValueError: if ``multi_p_mode[2]`` is not ``"CAT"``, or if there are
                no source prompts with attention heads.
        """
        if self.multi_p_mode[2] != "CAT":
            raise ValueError(
                "pre_prompt_forward only supports multi_p_mode[2] == 'CAT', got %r"
                % (self.multi_p_mode[2],))
        B, _, D = X.shape

        # VALUES P (1, n_vtk, D) and KEYS P_hat (1, D), one per source prompt.
        P_list = []
        P_hat_list = []
        for prompt, attn_w in zip(self.source_prompt_list, self.source_prompt_attn_weight_list):
            P_list.append(prompt)
            P_hat, _ = torch.max(prompt, 1)
            P_hat_list.append(attn_w(P_hat))
        if not P_list:
            raise ValueError("pre_prompt_forward needs source prompts with attention heads")
        n_src = len(P_list)  # excludes the meta prompt, unlike self.prompt_num

        has_meta = hasattr(self, "meta_prompt_proj")
        if has_meta:
            # Token-only meta prompt (1, n_vtk, D); its max-pool is the query.
            meta_P = self.incorporate_prompt(self.meta_prompt_embeddings, self.meta_prompt_proj)
            meta_P_hat, _ = torch.max(meta_P, 1)
            H = meta_P_hat.unsqueeze(1).expand(B, -1, -1)  # (B, 1, D)
        else:
            X_hat, _ = torch.max(X, 1)
            H = X_hat.unsqueeze(1)  # (B, 1, D)

        multi_P = torch.cat(P_list).view(n_src, -1)  # (n_src, n_vtk*D)
        multi_P = multi_P.unsqueeze(0).expand(B, n_src, -1)
        multi_P_hat = torch.cat(P_hat_list).view(n_src, -1)  # (n_src, D)
        multi_P_hat = multi_P_hat.unsqueeze(0).expand(B, n_src, -1)
        score = torch.bmm(H, multi_P_hat.transpose(1, 2) / math.sqrt(D))
        attn = torch.softmax(score, -1)  # (B, 1, n_src)

        # weighted cat: scale each source prompt by its attention weight
        context = attn.transpose(1, 2) * multi_P  # (B, n_src, n_vtk*D)
        prompt_instances = context.reshape(B, n_src, self.n_vtk, D)
        if has_meta:
            meta_instances = meta_P.unsqueeze(1).expand(B, 1, self.n_vtk, D)
            prompt_instances = torch.cat([meta_instances, prompt_instances], dim=1)
        return prompt_instances

    def post_prompt_forward(self, X):
        """Splice each prompt into ``X`` separately and stack the results.

        Args:
            X: token sequence ``(B, L0, D)``, class token first.

        Returns:
            ``(B, n, L, D)`` with one ``ln_pre``-normalised prompted sequence per
            source prompt (meta prompt last, if present), divided by
            ``self.prompt_num``.
        """
        P_list = [self.incorporate_prompt(prompt, x=X) for prompt in self.source_prompt_list]
        if hasattr(self, "meta_prompt_proj"):
            P_list.append(self.incorporate_prompt(self.meta_prompt_embeddings, self.meta_prompt_proj, X))
        multi_P = torch.stack(P_list, dim=1)
        return multi_P / self.prompt_num

    def forward(self, x):
        """Run the CLIP transformer on an already-embedded token sequence.

        Args:
            x: tokens of shape ``(N, L, D)``, class token first.

        Returns:
            ``ln_post`` of the class-token output, multiplied by ``self.proj``
            when it is not None.
        """
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x

class CATMultiVisualPromptTuningCLIP(nn.Module):
    """CLIP image encoder that produces one image feature per prompt.

    Wraps a ``PromptVisionTransformer`` and runs its CLIP transformer once per
    prompt on the corresponding prompted token sequence.
    """

    def __init__(self, clip_model, device, clip_model_type="CLIPViT-B/32", DeepPrompt=False, n_vtk=8):
        super(CATMultiVisualPromptTuningCLIP, self).__init__()
        self.visual_backbone = PromptVisionTransformer(clip_model, device, clip_model_type, DeepPrompt, n_vtk)
        self.logit_scale = clip_model.logit_scale

    def prompt_init(self, source_prompts_file, multi_p_mode="COMPOSE", meta_mode=False):
        """Load the source prompts and optionally add a learnable meta prompt.

        ``multi_p_mode`` is indexed positionally (``[0]``, ``[1]``, ``[2]``)
        by the backbone, e.g. ``("COMPOSE", "WEIGHTED", "CAT")``.
        NOTE(review): the default is a plain string, so those indices read
        single characters ("C", "O", "M") and weighted composition is never
        enabled by default -- confirm this is intended.
        """
        self.visual_backbone.source_prompt_init(source_prompts_file, multi_p_mode)
        if meta_mode:
            self.visual_backbone.meta_prompt_init()

    def forward(self, images):
        """Encode ``images`` once per prompt.

        Returns:
            Tensor ``(B, prompt_num, clip_emb_size)`` stacking the per-prompt
            image features along dim 1.
        """
        backbone = self.visual_backbone
        with torch.no_grad():
            # Frozen patch embedding + class token + positional embedding.
            x = backbone.conv1(images)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
            cls_token = backbone.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
            x = torch.cat([cls_token, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            x = x + backbone.positional_embedding.to(x.dtype)

        # Weighted composition needs the attention heads, which the backbone's
        # source_prompt_init builds only for mode[0] == "COMPOSE" and
        # mode[1] == "WEIGHTED" with more than one source prompt.
        mode = backbone.multi_p_mode
        use_pre_attention = (
            len(backbone.source_prompt_attn_weight_list) > 0
            and mode[0] == "COMPOSE"
            and mode[1] == "WEIGHTED"
        )
        if use_pre_attention:
            # pre prompt attention: making prompt instance
            prompts = backbone.pre_prompt_forward(x)
        else:
            # post prompt attention: making prompted input instance
            # NOTE(review): post_prompt_forward returns full prompted sequences
            # (class + prompt + patch tokens); the loop below splices them
            # between x's class and patch tokens once more -- confirm intended.
            prompts = backbone.post_prompt_forward(x)

        image_features = []
        for i in range(backbone.prompt_num):
            inputs = backbone.incorporate_prompt(prompts[:, i, :, :], x=x.clone())
            image_features.append(backbone(inputs))
        return torch.stack(image_features, dim=1)  # torch.size([B, prompt_num, clip_emb_size])