from collections import OrderedDict
from typing import Tuple, Optional, List, Dict
import math
from operator import mul
from functools import reduce
import copy

import numpy as np
import torch
import torch.nn as nn
from torch.nn import Conv2d, Dropout

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Module-level CLIP BPE tokenizer instance (constructed at import time).
# NOTE(review): not referenced by the classes in this file — confirm no
# external importer relies on it before removing.
_tokenizer = _Tokenizer()

class PromptVisionTransformer(nn.Module):
    """CLIP ViT image encoder with multi-source visual prompt tuning.

    Reuses the submodules of ``clip_model.visual`` and inserts prompt tokens
    directly after the class token. Prompt sets come from frozen source
    checkpoints (``source_prompt_init``) and, optionally, one trainable meta
    prompt (``meta_prompt_init``). When more than one prompt set exists, an
    input-conditioned attention combines them either before insertion
    ("pre_attn": mixes prompt tokens) or after insertion ("post_attn": mixes
    whole prompted sequences).
    """

    def __init__(self, clip_model, device, clip_model_type, DeepPrompt, n_vtk):
        """Take over the visual submodules of ``clip_model`` and set prompt config.

        Args:
            clip_model: loaded CLIP model; only ``clip_model.visual`` is read.
            device: device that source prompt sets are loaded onto.
            clip_model_type: model name containing "ViT-B/32" or "ViT-B/16".
            DeepPrompt: if truthy, also replace the prompt tokens before every
                transformer block after the first (needs ``meta_prompt_init``).
            n_vtk: number of prompt tokens per prompt set.

        Raises:
            ValueError: if ``clip_model_type`` names no supported ViT-B model.
        """
        super().__init__()
        visual = clip_model.visual
        self.device = device
        self.input_resolution = visual.input_resolution
        self.output_dim = visual.output_dim
        self.conv1 = visual.conv1

        self.class_embedding = visual.class_embedding
        self.positional_embedding = visual.positional_embedding
        self.ln_pre = visual.ln_pre

        self.transformer = visual.transformer

        self.ln_post = visual.ln_post
        self.proj = visual.proj
        self.Deep = DeepPrompt

        # prompt config
        if "ViT-B/32" in clip_model_type:
            self.patch_size = (32, 32)
        elif "ViT-B/16" in clip_model_type:
            self.patch_size = (16, 16)
        else:
            # Fail fast: every prompt method below needs these attributes.
            raise ValueError(
                f"Unsupported clip_model_type {clip_model_type!r}; "
                "expected a ViT-B/32 or ViT-B/16 model"
            )
        # Second dim of the positional embedding = transformer width.
        _, self.prompt_dim = self.positional_embedding.shape
        self.num_tokens = n_vtk
        self.hidden_size = 768
        self.prompt_dropout = Dropout(0.1)

    def source_prompt_init(self, source_prompts_file, attn_mode):
        """Load frozen source prompt sets from checkpoint files.

        Args:
            source_prompts_file: sequence of checkpoint paths. Each must contain
                ``[net.]visual_backbone.prompt_embeddings`` and
                ``[net.]visual_backbone.prompt_proj.{weight,bias}``.
            attn_mode: "pre_attn" or "post_attn".

        Raises:
            NotImplementedError: for any other ``attn_mode``.
        """
        if attn_mode not in ("pre_attn", "post_attn"):
            raise NotImplementedError
        self.prompt_set_num = len(source_prompts_file)
        self.attn_mode = attn_mode
        self.source_prompt_set_list = []
        for path in source_prompts_file:
            # map_location: checkpoints saved on another device still load.
            parameters = torch.load(path, map_location=self.device)
            if "net.visual_backbone.prompt_embeddings" in parameters:
                prefix = "net.visual_backbone."
            else:
                prefix = "visual_backbone."
            with torch.no_grad():
                prompt_embeddings = nn.Parameter(
                    parameters[prefix + "prompt_embeddings"].clone().to(self.device),
                    requires_grad=False,
                )
                prompt_proj = nn.Linear(self.prompt_dim, self.hidden_size)
                prompt_proj.weight.copy_(parameters[prefix + "prompt_proj.weight"])
                prompt_proj.bias.copy_(parameters[prefix + "prompt_proj.bias"])
                prompt_proj.requires_grad_(False)
                prompt_proj.to(self.device)  # Module.to moves in place
                prompt_set = {
                    "prompt_embeddings": prompt_embeddings,
                    "prompt_proj": prompt_proj,
                }
                if attn_mode == "pre_attn":
                    # Source prompts are frozen, so project once and cache.
                    # Dropout is deliberately NOT applied: a mask sampled here
                    # would be baked permanently into the cached prompt.
                    prompt_set["prompt"] = prompt_proj(prompt_embeddings)
            self.source_prompt_set_list.append(prompt_set)
        if self.prompt_set_num > 1:
            self._build_prompt_attention()

    def _build_prompt_attention(self):
        """Create (or re-create) the query network of the prompt attention."""
        self.attn_prompt_W_down = nn.Linear(self.hidden_size, 128, bias=False)
        self.attn_prompt_W_up = nn.Linear(128, self.hidden_size, bias=False)
        self.attn_prompt_non_linear = nn.SiLU()
        self.prompt_layer_norm = nn.LayerNorm(self.hidden_size)

    def meta_prompt_init(self):
        """Add one trainable meta prompt set (plus deep prompts if ``Deep``).

        Must be called after ``source_prompt_init``; increments
        ``prompt_set_num``.
        """
        self.meta_prompt_proj = nn.Linear(self.prompt_dim, self.hidden_size)
        nn.init.kaiming_normal_(self.meta_prompt_proj.weight, a=0, mode='fan_out')

        # xavier-uniform-style bound
        val = math.sqrt(6. / float(3 * reduce(mul, self.patch_size, 1) + self.prompt_dim))  # noqa

        self.meta_prompt_embeddings = nn.Parameter(torch.zeros(
            1, self.num_tokens, self.prompt_dim))
        nn.init.uniform_(self.meta_prompt_embeddings.data, -val, val)

        if self.Deep:  # one prompt per transformer block after the first
            total_d_layer = 12 - 1
            self.deep_meta_prompt_embeddings = nn.Parameter(torch.zeros(
                total_d_layer, self.num_tokens, self.prompt_dim))
            nn.init.uniform_(self.deep_meta_prompt_embeddings.data, -val, val)

        self.prompt_set_num += 1
        # With a single source set the attention layers were never built, but
        # forward() now routes through the attention path.
        if self.prompt_set_num > 1 and not hasattr(self, "attn_prompt_W_down"):
            self._build_prompt_attention()

    def incorporate_prompt(self, prompt_embeddings, prompt_proj=None, x=None):
        """Project a prompt and/or insert it into a token sequence.

        Args:
            prompt_embeddings: prompt tokens (projected by ``prompt_proj`` if
                given, otherwise used as-is).
            prompt_proj: optional projection module applied to the prompt.
            x: optional token sequence (B, L, hidden).

        Returns:
            If ``x`` is None: ``dropout(prompt_proj(prompt_embeddings))``.
            Otherwise ``ln_pre`` of ``x`` with the (dropped-out) prompt tokens
            inserted after the first token:
            (B, 1 + n_prompt + (L - 1), hidden).
        """
        if x is None:
            return self.prompt_dropout(prompt_proj(prompt_embeddings))
        B = x.size(0)
        if prompt_proj is not None:
            prompt = self.prompt_dropout(prompt_proj(prompt_embeddings).expand(B, -1, -1))
        else:
            prompt = self.prompt_dropout(prompt_embeddings).expand(B, -1, -1)
        x = torch.cat((x[:, :1, :], prompt, x[:, 1:, :]), dim=1)
        return self.ln_pre(x)

    def _attention_query(self, X):
        """Map token sequences X (B, L, hidden) to attention queries (B, 1, hidden)."""
        X_hat, _ = torch.max(X, 1)
        H = self.attn_prompt_W_down(X_hat)
        H = self.attn_prompt_non_linear(H)
        H = self.attn_prompt_W_up(H)
        H = self.prompt_layer_norm(H)
        return H.unsqueeze(1)

    def pre_prompt_attention(self, X):
        """Mix the prompt sets per sample, then insert the mixture into X.

        Keys are the token-wise max of each prompt set. Values are the
        flattened prompt tokens. The query is derived from X. If a meta prompt
        exists it is added residually to the mixture.
        """
        B = X.size(0)
        H = self._attention_query(X)  # (B, 1, hidden)
        P_list = [prompt_set["prompt"] for prompt_set in self.source_prompt_set_list]
        has_meta = hasattr(self, "meta_prompt_proj")
        if has_meta:
            meta_P = self.incorporate_prompt(self.meta_prompt_embeddings, self.meta_prompt_proj)
            P_list.append(meta_P)
        n_sets = len(P_list)
        # Prompt sets are shared across the batch, so keys/values are expanded.
        values = torch.cat(P_list).view(n_sets, -1).unsqueeze(0).expand(B, -1, -1)
        keys = torch.cat([torch.max(P, 1)[0] for P in P_list])
        keys = keys.unsqueeze(0).expand(B, -1, -1)  # (B, n_sets, hidden)
        score = torch.bmm(H, keys.transpose(1, 2) / math.sqrt(H.shape[-1]))
        attn = torch.softmax(score, -1)  # (B, 1, n_sets)
        context = torch.bmm(attn, values).view(B, -1, H.shape[-1])  # (B, n_prompt, hidden)
        prompt = meta_P + context if has_meta else context
        return self.incorporate_prompt(prompt, x=X)

    def post_prompt_attention(self, X):
        """Prompt X with every prompt set, then mix the prompted sequences.

        Returns a tensor of shape (B, L + n_prompt, hidden). If a meta prompt
        exists, its prompted sequence is added residually to the mixture.
        """
        B = X.size(0)
        H = self._attention_query(X)  # (B, 1, hidden)
        P_list = [
            self.incorporate_prompt(ps["prompt_embeddings"], ps["prompt_proj"], X)
            for ps in self.source_prompt_set_list
        ]
        has_meta = hasattr(self, "meta_prompt_proj")
        if has_meta:
            meta_P = self.incorporate_prompt(self.meta_prompt_embeddings, self.meta_prompt_proj, X)
            P_list.append(meta_P)
        n_sets = len(P_list)
        # Stack on dim=1 (not cat on dim=0 + view) so each sample only attends
        # over its OWN prompted sequences.
        values = torch.stack(P_list, dim=1).view(B, n_sets, -1)
        keys = torch.stack([torch.max(P, 1)[0] for P in P_list], dim=1)  # (B, n_sets, hidden)
        score = torch.bmm(H, keys.transpose(1, 2) / math.sqrt(H.shape[-1]))
        attn = torch.softmax(score, -1)  # (B, 1, n_sets)
        context = torch.bmm(attn, values).view(B, -1, H.shape[-1])  # (B, L + n_prompt, hidden)
        return meta_P + context if has_meta else context

    def _insert_prompts(self, x):
        """Dispatch prompt insertion according to the configured prompt sets."""
        if self.prompt_set_num == 1:
            if self.source_prompt_set_list:
                prompt_set = self.source_prompt_set_list[0]
                return self.incorporate_prompt(
                    prompt_set["prompt_embeddings"], prompt_set["prompt_proj"], x)
            # Only the meta prompt exists (no source checkpoints were given).
            return self.incorporate_prompt(self.meta_prompt_embeddings, self.meta_prompt_proj, x)
        if self.attn_mode == "pre_attn":
            return self.pre_prompt_attention(x)
        if self.attn_mode == "post_attn":
            return self.post_prompt_attention(x)
        raise NotImplementedError

    def _run_deep_transformer(self, x):
        """Run the transformer blocks with per-layer deep meta prompts.

        Before every block after the first, the prompt tokens (positions
        1 .. num_tokens) are replaced by the projected deep prompt for that
        layer. x is LND: (seq_len, batch, hidden).
        """
        if not hasattr(self, "deep_meta_prompt_embeddings"):
            raise RuntimeError(
                "DeepPrompt=True requires meta_prompt_init() to create the deep prompts")
        B = x.shape[1]
        deep_prompts = self.deep_meta_prompt_embeddings
        hidden_states = x
        for i in range(self.transformer.layers):
            if 0 < i <= deep_prompts.shape[0]:
                deep_prompt_emb = self.prompt_dropout(
                    self.meta_prompt_proj(deep_prompts[i - 1]).expand(B, -1, -1))
                deep_prompt_emb = deep_prompt_emb.permute(1, 0, 2)  # NLD -> LND
                # assumes the shallow prompt also spans num_tokens positions
                hidden_states = torch.cat((
                    hidden_states[:1, :, :],
                    deep_prompt_emb,
                    hidden_states[(1 + self.num_tokens):, :, :],
                ), dim=0)
            hidden_states = self.transformer.resblocks[i](hidden_states)
        return hidden_states

    def forward(self, x):
        """Encode images into CLIP image features (output of ``ln_post``/``proj``)."""
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        cls_tokens = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls_tokens, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self._insert_prompts(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        if self.Deep:
            x = self._run_deep_transformer(x)
        else:
            x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = x @ self.proj
        return x

class MultiVisualPromptTuningCLIP(nn.Module):
    """Image-side CLIP model with multi-source visual prompt tuning.

    Thin wrapper that owns a ``PromptVisionTransformer`` built from
    ``clip_model`` and keeps a reference to CLIP's ``logit_scale``.
    """

    def __init__(self, clip_model, device, clip_model_type="CLIPViT-B/32", DeepPrompt=False, n_vtk=8):
        super().__init__()
        backbone = PromptVisionTransformer(
            clip_model, device, clip_model_type, DeepPrompt, n_vtk)
        self.visual_backbone = backbone
        self.logit_scale = clip_model.logit_scale

    def prompt_init(self, source_prompts_file, attn_mode="pre_attn", meta_mode=False):
        """Load the source prompt sets and optionally add a trainable meta prompt."""
        backbone = self.visual_backbone
        backbone.source_prompt_init(source_prompts_file, attn_mode)
        if not meta_mode:
            return
        backbone.meta_prompt_init()

    def forward(self, images):
        """Return the image features produced by the prompted visual backbone."""
        return self.visual_backbone(images)