from collections import OrderedDict
from typing import Tuple, Optional, List, Dict
import math
from operator import mul
from functools import reduce
import copy

import numpy as np
import torch
import torch.nn as nn
from torch.nn import Conv2d, Dropout

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Module-level tokenizer instance from the CLIP package, created at import time.
# NOTE(review): it is not referenced by the classes in the code visible here;
# presumably kept for text-side code elsewhere — confirm before removing.
_tokenizer = _Tokenizer()

class PromptVisionTransformer(nn.Module):
    """CLIP ViT image encoder that inserts visual prompt tokens after the class token.

    The visual modules of the wrapped CLIP model (patch conv, class/positional
    embeddings, ``ln_pre``, transformer, ``ln_post``, ``proj``) are reused as-is.
    Prompts come from two places:

    * frozen *source* prompts loaded from checkpoints by :meth:`source_prompt_init`
      (projected once and stored in ``source_prompt_list``), and
    * an optional trainable *meta* prompt created by :meth:`meta_prompt_init`.

    When more than one prompt is present, ``multi_p_mode`` — an indexable
    ``(strategy, weighting, reduction)`` spec such as ``("COMPOSE", "WEIGHTED", "AVG")``
    — selects how they are combined (see :meth:`forward`).
    """

    # Patch size for each supported backbone; matched as a substring of
    # ``clip_model_type`` in this order.
    _PATCH_SIZES = (
        ("ViT-B/32", (32, 32)),
        ("ViT-B/16", (16, 16)),
    )

    def __init__(self, clip_model, device, clip_model_type, DeepPrompt, n_vtk):
        """Wrap ``clip_model.visual`` for prompt tuning.

        Args:
            clip_model: loaded CLIP model; only ``clip_model.visual`` is used.
            device: device that source prompts are loaded onto and their
                projections are moved to.
            clip_model_type: backbone name; must contain ``"ViT-B/32"`` or
                ``"ViT-B/16"``.
            DeepPrompt: if truthy, :meth:`meta_prompt_init` also allocates
                per-layer deep prompt embeddings.
            n_vtk: number of visual prompt tokens.

        Raises:
            ValueError: if ``clip_model_type`` names an unsupported backbone
                (previously this surfaced later as an ``AttributeError``).
        """
        super().__init__()
        self.device = device
        visual = clip_model.visual
        self.input_resolution = visual.input_resolution
        self.output_dim = visual.output_dim
        self.conv1 = visual.conv1

        self.class_embedding = visual.class_embedding
        self.positional_embedding = visual.positional_embedding
        self.ln_pre = visual.ln_pre

        self.transformer = visual.transformer

        self.ln_post = visual.ln_post
        self.proj = visual.proj
        self.Deep = DeepPrompt
        self.n_vtk = n_vtk

        # prompt config
        for tag, patch_size in self._PATCH_SIZES:
            if tag in clip_model_type:
                self.patch_size = patch_size
                break
        else:
            raise ValueError(
                f"Unsupported clip_model_type {clip_model_type!r}; "
                "expected one containing 'ViT-B/32' or 'ViT-B/16'"
            )
        _, self.prompt_dim = self.positional_embedding.shape
        self.num_tokens = n_vtk
        # Output width of the prompt projections (prompt_dim -> hidden_size).
        self.hidden_size = 768
        self.prompt_dropout = Dropout(0.1)

        # Not used inside this class; kept for backward compatibility.
        self.attn_list = []

    def _make_compose_attention(self):
        """Build the per-source-prompt key transform used by the COMPOSE strategy."""
        return nn.Sequential(
            nn.Linear(self.hidden_size, 128, bias=False),
            nn.SiLU(),
            nn.Linear(128, self.hidden_size, bias=False),
            nn.LayerNorm(self.hidden_size),
        )

    def source_prompt_init(self, source_prompts_file, multi_p_mode):
        """Load frozen source prompts and, if needed, the modules that weight them.

        Each checkpoint must contain ``<prefix>.prompt_embeddings`` and
        ``<prefix>.prompt_proj.{weight,bias}``, where ``<prefix>`` is
        ``net.visual_backbone`` when the first key contains ``"net"`` and
        ``visual_backbone`` otherwise. Each prompt is projected once and stored,
        frozen, in ``source_prompt_list``.

        Args:
            source_prompts_file: sequence of checkpoint paths.
            multi_p_mode: ``(strategy, weighting, reduction)`` spec. Only
                ``[0]`` and ``[1]`` are read here, and only when more than one
                file is given.
        """
        self.prompt_num = len(source_prompts_file)
        self.multi_p_mode = multi_p_mode
        self.source_prompt_list = nn.ParameterList([])
        self.source_prompt_attn_weight_list = nn.ModuleList([])
        # Short-circuit: multi_p_mode[1] is only inspected for multi-prompt setups.
        weighted = self.prompt_num > 1 and multi_p_mode[1] == "WEIGHTED"

        for path in source_prompts_file:
            # map_location keeps the loaded tensors on the same device as the
            # projection below (the old `.to(self.device)` result was discarded).
            parameters = torch.load(path, map_location=self.device)
            if "net" in next(iter(parameters)):
                prefix = "net.visual_backbone"
            else:
                prefix = "visual_backbone"
            with torch.no_grad():
                embeddings = parameters[prefix + ".prompt_embeddings"]
                proj = nn.Linear(self.prompt_dim, self.hidden_size)
                proj.weight.copy_(parameters[prefix + ".prompt_proj.weight"])
                proj.bias.copy_(parameters[prefix + ".prompt_proj.bias"])
                proj.requires_grad_(False)
                proj.to(self.device)
                # Project WITHOUT prompt_dropout: the module is normally in
                # training mode here, and dropout would bake a random mask
                # permanently into the frozen source prompt.
                self.source_prompt_list.append(proj(embeddings))
            if weighted and multi_p_mode[0] == "COMPOSE":
                # One key transform per source prompt.
                self.source_prompt_attn_weight_list.append(self._make_compose_attention())

        if weighted and multi_p_mode[0] == "ATTEMPT":
            # BASELINE REFERENCE: ATTEMPT — a single shared query network.
            self.attn_prompt_W_down = nn.Linear(self.hidden_size, 128, bias=False)
            self.attn_prompt_W_up = nn.Linear(128, self.hidden_size, bias=False)
            self.attn_prompt_non_linear = nn.SiLU()
            self.prompt_layer_norm = nn.LayerNorm(self.hidden_size)

        self.source_prompt_list.requires_grad_(False)

    def meta_prompt_init(self,):
        """Create the trainable meta prompt and count it as one more prompt.

        Must be called after :meth:`source_prompt_init`, because it increments
        ``prompt_num``. When ``self.Deep`` is truthy it also allocates
        ``deep_meta_prompt_embeddings`` for 11 further layers. These are not
        consumed by :meth:`forward` in this class.
        """
        self.meta_prompt_proj = nn.Linear(self.prompt_dim, self.hidden_size)
        nn.init.kaiming_normal_(self.meta_prompt_proj.weight, a=0, mode='fan_out')

        # Uniform bound derived from the patch area and prompt dimension.
        val = math.sqrt(6. / float(3 * reduce(mul, self.patch_size, 1) + self.prompt_dim))  # noqa

        self.meta_prompt_embeddings = nn.Parameter(torch.zeros(
            1, self.num_tokens, self.prompt_dim))
        # xavier_uniform initialization
        nn.init.uniform_(self.meta_prompt_embeddings.data, -val, val)

        if self.Deep:  # Deep prompt version noqa
            total_d_layer = 12 - 1
            self.deep_meta_prompt_embeddings = nn.Parameter(torch.zeros(
                total_d_layer, self.num_tokens, self.prompt_dim))
            # xavier_uniform initialization
            nn.init.uniform_(self.deep_meta_prompt_embeddings.data, -val, val)

        self.prompt_num += 1
        print(self.prompt_num)

    def incorporate_prompt(self, prompt_embeddings, prompt_proj=None, x=None):
        """Project a prompt and/or splice it into a token sequence.

        Args:
            prompt_embeddings: prompt tokens, ``(1 or B, n_tok, dim)``.
            prompt_proj: optional module applied to ``prompt_embeddings``;
                when given, ``prompt_dropout`` is applied to its output.
            x: optional token sequence ``(B, 1 + n_patches, width)`` whose first
                token is the class token.

        Returns:
            If ``x`` is None: ``prompt_dropout(prompt_proj(prompt_embeddings))``.
            Otherwise: ``ln_pre`` of ``[cls, prompt, patches]`` concatenated on
            dim 1, shape ``(B, 1 + n_tok + n_patches, width)``.

        Raises:
            ValueError: if both ``x`` and ``prompt_proj`` are None.
        """
        if x is None:
            if prompt_proj is None:
                raise ValueError("incorporate_prompt needs prompt_proj when x is None")
            return self.prompt_dropout(prompt_proj(prompt_embeddings))

        batch = x.size(0)
        if prompt_proj is None:
            prompt = prompt_embeddings.expand(batch, -1, -1)
        else:
            # Dropout is applied after expand, so every sample gets its own mask.
            prompt = self.prompt_dropout(prompt_proj(prompt_embeddings).expand(batch, -1, -1))
        x = torch.cat((x[:, :1, :], prompt, x[:, 1:, :]), dim=1)
        return self.ln_pre(x)

    def _meta_prompt(self):
        """Return the projected meta prompt ``(1, n_tok, hidden)``, or None if absent."""
        if not hasattr(self, "meta_prompt_proj"):
            return None
        return self.incorporate_prompt(self.meta_prompt_embeddings, self.meta_prompt_proj)

    @staticmethod
    def _attend(query, keys, values, scale):
        """Scaled dot-product attention of one query token over n prompts.

        Args:
            query: ``(B, 1, D)``.
            keys: n tensors of shape ``(1, D)``.
            values: n prompts of shape ``(1, n_tok, D)``.
            scale: divisor applied to the keys before the dot product.

        Returns:
            The attention-weighted sum of ``values``, shape ``(B, n_tok, D)``.
        """
        batch, n, width = query.size(0), len(values), query.size(-1)
        multi_P = torch.cat(values).view(n, -1).unsqueeze(0).expand(batch, n, -1)   # (B, n, n_tok*D)
        multi_P_hat = torch.cat(keys).view(n, -1).unsqueeze(0).expand(batch, n, -1)  # (B, n, D)
        score = torch.bmm(query, multi_P_hat.transpose(1, 2) / scale)               # (B, 1, n)
        attn = torch.softmax(score, -1)
        return torch.bmm(attn, multi_P).view(batch, -1, width)

    def _compose_prompt(self, X):
        """COMPOSE: source prompts weighted by attention, plus the meta prompt if present.

        The query is the mean input token or, when a meta prompt exists, the
        mean meta-prompt token. Keys are the mean tokens of the source prompts,
        each passed through its own key transform.
        """
        if self.multi_p_mode[2] != "AVG":
            raise NotImplementedError(f"Unsupported reduction {self.multi_p_mode[2]!r} for COMPOSE")
        batch, _, width = X.shape
        query = X.mean(dim=1, keepdim=True)  # (B, 1, D)
        values, keys = [], []
        for P, attn_w in zip(self.source_prompt_list, self.source_prompt_attn_weight_list):
            values.append(P)
            keys.append(attn_w(P.mean(dim=1)))  # (1, D)
        meta_P = self._meta_prompt()
        if meta_P is not None:
            # Broadcast the single meta query over the batch so bmm batch sizes match.
            query = meta_P.mean(dim=1, keepdim=True).expand(batch, -1, -1)
        # Attend over the source prompts only (len(values)); prompt_num also
        # counts the meta prompt, which is the query here, not a value.
        context = self._attend(query, keys, values, np.sqrt(width))
        return context if meta_P is None else meta_P + context

    def _attempt_prompt(self, X):
        """ATTEMPT baseline: input-conditioned attention over source + meta prompts."""
        X_hat, _ = torch.max(X, 1)  # (B, D)
        query = self.attn_prompt_W_down(X_hat)
        query = self.attn_prompt_non_linear(query)
        query = self.attn_prompt_W_up(query)
        query = self.prompt_layer_norm(query).unsqueeze(1)  # (B, 1, D)
        values = list(self.source_prompt_list)
        meta_P = self._meta_prompt()
        if meta_P is not None:
            values.append(meta_P)
        keys = [P.max(dim=1).values for P in values]  # each (1, D)
        context = self._attend(query, keys, values, np.sqrt(query.shape[-1]))
        return context if meta_P is None else meta_P + context

    def _average_prompt(self, X):
        """Unweighted mean of the source prompts (the meta prompt is not included)."""
        if self.multi_p_mode[2] != "AVG":
            raise NotImplementedError(f"Unsupported reduction {self.multi_p_mode[2]!r}")
        return torch.cat(list(self.source_prompt_list)).mean(dim=0, keepdim=True)

    def pre_prompt_forward(self, X):
        """Fuse all prompts into one prompt per sample *before* the transformer.

        Args:
            X: embedded token sequence ``(B, 1 + n_patches, width)`` as built in
                :meth:`forward`.

        Returns:
            ``X`` with the fused prompt inserted after the class token and
            ``ln_pre`` applied.

        Raises:
            NotImplementedError: if ``multi_p_mode[2]`` is not ``"AVG"`` on the
                COMPOSE or averaging paths.
        """
        strategy, weighting = self.multi_p_mode[0], self.multi_p_mode[1]
        if strategy == "COMPOSE" and weighting == "WEIGHTED":
            prompt = self._compose_prompt(X)
        elif strategy == "ATTEMPT" and weighting == "WEIGHTED":
            prompt = self._attempt_prompt(X)
        else:
            prompt = self._average_prompt(X)
        return self.incorporate_prompt(prompt, x=X)

    def post_prompt_forward(self, X):
        """Build one prompted copy of the input per prompt, for ensembling.

        Returns:
            Shape ``(B * n, 1 + n_tok + n_patches, width)``, where n counts the
            source prompts plus the meta prompt (if initialised). Row ``b * n + p``
            holds sample ``b`` prompted with prompt ``p``.
        """
        prompted = [self.incorporate_prompt(P, x=X) for P in self.source_prompt_list]
        if hasattr(self, "meta_prompt_proj"):
            prompted.append(self.incorporate_prompt(self.meta_prompt_embeddings, self.meta_prompt_proj, X))
        batch, length, width = prompted[0].shape
        return torch.cat(prompted, dim=1).view(batch * len(prompted), length, width)

    def forward(self, x, mode):
        """Encode images with prompts.

        The patch embedding is computed under ``torch.no_grad()``. Prompt
        selection then works as follows:

        * ``mode`` truthy: use only the meta prompt.
        * one prompt in total: use the single source prompt.
        * strategy ``COMPOSE`` or ``ATTEMPT``: fuse the prompts before the
          transformer (:meth:`pre_prompt_forward`).
        * strategy ``ENSEMBLE`` or ``SESoM``: run one copy of each sample per
          prompt (:meth:`post_prompt_forward`). The output batch is then ``B * n``.

        Returns:
            ``ln_post`` of the class token, projected by ``proj`` when it is not None.

        Raises:
            NotImplementedError: for any other strategy.
        """
        with torch.no_grad():
            x = self.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
            x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            x = x + self.positional_embedding.to(x.dtype)

        if mode:
            x = self.incorporate_prompt(self.meta_prompt_embeddings, self.meta_prompt_proj, x)
        elif self.prompt_num == 1:
            x = self.incorporate_prompt(self.source_prompt_list[0], x=x)
        elif self.multi_p_mode[0] in ("COMPOSE", "ATTEMPT"):
            x = self.pre_prompt_forward(x)
        elif self.multi_p_mode[0] in ("ENSEMBLE", "SESoM"):
            x = self.post_prompt_forward(x)
        else:
            raise NotImplementedError

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x

class AVGMultiVisualPromptTuningCLIP(nn.Module):
    """CLIP image encoder whose visual tower is a ``PromptVisionTransformer``.

    Thin wrapper: it owns the prompted visual backbone and keeps a reference to
    the CLIP model's ``logit_scale``. All prompt handling is delegated to the
    backbone.
    """

    def __init__(self, clip_model, device, clip_model_type="CLIPViT-B/32", DeepPrompt=False, n_vtk=8):
        """Build the prompted visual backbone from ``clip_model``.

        Args:
            clip_model: loaded CLIP model; its ``visual`` tower and
                ``logit_scale`` are reused.
            device: forwarded to the backbone.
            clip_model_type: backbone name, forwarded to the backbone.
            DeepPrompt: forwarded to the backbone.
            n_vtk: number of visual prompt tokens, forwarded to the backbone.
        """
        super().__init__()
        backbone = PromptVisionTransformer(clip_model, device, clip_model_type, DeepPrompt, n_vtk)
        self.visual_backbone = backbone
        self.logit_scale = clip_model.logit_scale

    def prompt_init(self, source_prompts_file, multi_p_mode="COMPOSE", meta_mode=False):
        """Load the source prompts and, if ``meta_mode`` is truthy, add a meta prompt.

        Args:
            source_prompts_file: sequence of source-prompt checkpoint paths.
            multi_p_mode: prompt-combination spec passed to the backbone.
            meta_mode: when truthy, also call the backbone's ``meta_prompt_init``.

        NOTE(review): the backbone indexes ``multi_p_mode[0]``/``[1]`` (and
        sometimes ``[2]``) whenever more than one prompt is in use. The
        bare-string default therefore only behaves as intended for a single
        source prompt without a meta prompt. Pass a tuple such as
        ``("COMPOSE", "WEIGHTED", "AVG")`` otherwise.
        """
        backbone = self.visual_backbone
        backbone.source_prompt_init(source_prompts_file, multi_p_mode)
        if not meta_mode:
            return
        backbone.meta_prompt_init()

    def forward(self, images, mode=False):
        """Return the backbone's image features; a truthy ``mode`` selects the meta prompt path."""
        return self.visual_backbone(images, mode)