from collections import OrderedDict
from typing import Tuple, Optional, List, Dict
import math
from operator import mul
from functools import reduce

import numpy as np
import torch
import torch.nn as nn
from torch.nn import Conv2d, Dropout


from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Shared module-level tokenizer instance. MTTextPromptLearner calls
# `_tokenizer.encode(name)` to measure how many tokens each class/action
# name occupies inside a prompt.
_tokenizer = _Tokenizer()

class PromptVisionTransformer(nn.Module):
    """CLIP ViT image encoder with learnable visual prompt tokens.

    Re-uses the sub-modules of ``clip_model.visual`` and inserts ``n_vtk``
    learnable prompt tokens directly after the class token. With
    ``DeepPrompt=True``, the prompt tokens produced by each transformer block
    (after the first) are replaced by a fresh, layer-specific set of prompts.

    Args:
        clip_model: Loaded CLIP model whose ``visual`` attribute is a ViT.
        clip_model_type: Model identifier; must contain ``"ViT-B/32"`` or
            ``"ViT-B/16"``.
        DeepPrompt: Whether to use per-layer (deep) prompts.
        n_vtk: Number of visual prompt tokens.
        mode: Which prompt set(s) to build. ``"state"`` builds the state
            prompts (``prompt_proj`` / ``prompt_embeddings``), any string
            containing ``"action"`` builds the action prompts
            (``action_prompt_proj`` / ``action_prompt_embeddings``), and
            ``None`` builds both.

    Raises:
        ValueError: If ``clip_model_type`` or ``mode`` is not supported.
    """

    def __init__(self, clip_model, clip_model_type, DeepPrompt, n_vtk, mode):
        super().__init__()
        self.input_resolution = clip_model.visual.input_resolution
        self.output_dim = clip_model.visual.output_dim
        self.conv1 = clip_model.visual.conv1

        self.class_embedding = clip_model.visual.class_embedding
        self.positional_embedding = clip_model.visual.positional_embedding
        self.ln_pre = clip_model.visual.ln_pre

        self.transformer = clip_model.visual.transformer

        self.ln_post = clip_model.visual.ln_post
        self.proj = clip_model.visual.proj
        self.Deep = DeepPrompt

        # prompt config
        if "ViT-B/32" in clip_model_type:
            patch_size = (32, 32)
        elif "ViT-B/16" in clip_model_type:
            patch_size = (16, 16)
        else:
            raise ValueError(f"Unsupported clip_model_type: {clip_model_type!r}")
        # Prompt tokens live in the transformer's token space, whose width is
        # the last dimension of the positional embedding.
        _, prompt_dim = self.positional_embedding.shape
        self.num_tokens = n_vtk
        # Projected prompts are concatenated with the patch tokens, so the
        # projection must output the transformer width (768 for ViT-B).
        hidden_size = prompt_dim
        self.prompt_dropout = Dropout(0.1)

        # xavier_uniform-style bound shared by every prompt embedding below.
        val = math.sqrt(6. / float(3 * reduce(mul, patch_size, 1) + prompt_dim))  # noqa

        build_state = mode is None or mode == "state"
        build_action = mode is None or "action" in mode
        if not (build_state or build_action):
            raise ValueError(
                f"Unsupported mode: {mode!r} (expected None, 'state' or a mode containing 'action')"
            )

        # multi-prompt
        # state
        if build_state:
            self.prompt_proj = nn.Linear(prompt_dim, hidden_size)
            self.prompt_embeddings = nn.Parameter(torch.zeros(
                1, self.num_tokens, prompt_dim))
            # initialization
            nn.init.kaiming_normal_(self.prompt_proj.weight, a=0, mode='fan_out')
            # xavier_uniform initialization
            nn.init.uniform_(self.prompt_embeddings.data, -val, val)
        # action
        if build_action:
            self.action_prompt_proj = nn.Linear(prompt_dim, hidden_size)
            self.action_prompt_embeddings = nn.Parameter(torch.zeros(
                1, self.num_tokens, prompt_dim))
            # initialization
            nn.init.kaiming_normal_(self.action_prompt_proj.weight, a=0, mode='fan_out')
            # xavier_uniform initialization
            nn.init.uniform_(self.action_prompt_embeddings.data, -val, val)

        if self.Deep:  # Deep prompt version
            # One prompt set for every transformer block after the first.
            total_d_layer = self.transformer.layers - 1
            self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
                total_d_layer, self.num_tokens, prompt_dim))
            # xavier_uniform initialization
            nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)

    def forward(self, x, mode=None):
        """Encode a batch of images.

        Args:
            x: Image batch fed to ``conv1``.
            mode: ``"action"`` selects the action prompts; anything else
                selects the state prompts. The selected prompt set must have
                been built in ``__init__``.

        Returns:
            ``ln_post`` of the class token, multiplied by ``proj`` when
            ``proj`` is not ``None``.
        """
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        cls_tokens = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_tokens, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)

        # Select the prompt set for this branch.
        if mode == "action":
            prompt_proj = self.action_prompt_proj
            prompt_embeddings = self.action_prompt_embeddings
        else:
            prompt_proj = self.prompt_proj
            prompt_embeddings = self.prompt_embeddings

        # incorporate_prompt: insert the prompt tokens right after the class token.
        # Dropout is applied after expand, so each sample gets its own mask.
        B = x.size(0)
        x = torch.cat((
            x[:, :1, :],
            self.prompt_dropout(prompt_proj(prompt_embeddings).expand(B, -1, -1)),
            x[:, 1:, :]
        ), dim=1)
        # shape = [*, 1 + num_tokens + grid ** 2, width]
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND

        if self.Deep:  # Deep prompt version
            hidden_states = None
            for i in range(self.transformer.layers):
                if i == 0:
                    hidden_states = self.transformer.resblocks[i](x)
                    continue
                if i <= self.deep_prompt_embeddings.shape[0]:
                    # Replace the previous block's prompt tokens with this
                    # layer's deep prompts, projected with the same projection
                    # as the shallow prompts of the selected branch.
                    deep_prompt_emb = self.prompt_dropout(
                        prompt_proj(self.deep_prompt_embeddings[i - 1]).expand(B, -1, -1)
                    )
                    deep_prompt_emb = deep_prompt_emb.permute(1, 0, 2)  # NLD -> LND
                    hidden_states = torch.cat((
                        hidden_states[:1, :, :],
                        deep_prompt_emb,
                        hidden_states[(1 + self.num_tokens):, :, :]
                    ), dim=0)
                hidden_states = self.transformer.resblocks[i](hidden_states)
            x = hidden_states
        else:
            x = self.transformer(x)

        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x

class TextEncoder(nn.Module):
    """CLIP text transformer driven by pre-computed prompt embeddings.

    The prompts are already embedded (so learnable context vectors can be
    spliced in); the token ids are only used to find where each sequence's
    end-of-text token sits.
    """

    def __init__(self, clip_model):
        super().__init__()
        # Borrow the text-side components of the CLIP model (same order as
        # they are registered on this module).
        for attr in ("transformer", "positional_embedding", "ln_final", "text_projection", "dtype"):
            setattr(self, attr, getattr(clip_model, attr))

    def forward(self, prompts, tokenized_prompts):
        """Return projected text features taken at the EOT position.

        Args:
            prompts: Prompt token embeddings, batch-first.
            tokenized_prompts: Token ids of the same prompts.

        Returns:
            One projected feature vector per prompt.
        """
        hidden = prompts + self.positional_embedding.type(self.dtype)
        hidden = self.transformer(hidden.permute(1, 0, 2))  # NLD -> LND
        hidden = self.ln_final(hidden.permute(1, 0, 2)).type(self.dtype)  # LND -> NLD

        # The eot_token has the highest id in each sequence, so argmax finds it.
        eot_positions = tokenized_prompts.argmax(dim=-1)
        batch_index = torch.arange(hidden.shape[0])
        pooled = hidden[batch_index, eot_positions]
        return pooled @ self.text_projection

class MTTextPromptLearner(nn.Module):
    """Learnable text prompts for two label sets: scenes/classes and actions.

    Each label set has its own context vectors (``ctx_class`` and
    ``ctx_action``). A context is either generic, shared by all labels, with
    shape ``(n_ctx, dim)``, or class-specific, with shape
    ``(n_labels, n_ctx, dim)``. The frozen token embeddings of the prompt
    ``"X X ... X <name>."`` are stored as prefix/suffix buffers and spliced
    around the learnable context in :meth:`forward`.

    Args:
        clip_model: Loaded CLIP model. Its ``dtype``, ``ln_final`` width,
            ``visual.input_resolution`` and ``token_embedding`` are read.
        class_names: Scene/class label names.
        action_names: Action label names.
        device: Device the tokenized prompts are moved to.
        use_csc: Class-specific-context flags ``(for_classes, for_actions)``;
            a single bool applies to both label sets.
        n_ctx: Number of learnable context tokens per prompt.
    """

    def __init__(self, clip_model, class_names, action_names, device, use_csc, n_ctx):
        super().__init__()
        n_cls_class = len(class_names)
        n_cls_action = len(action_names)
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = 224
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"
        # Accept a single flag (VLMultiDualPromptViTCLIP defaults to False)
        # as well as a per-label-set pair.
        if isinstance(use_csc, (bool, int)):
            use_csc = (bool(use_csc), bool(use_csc))

        # scene: random initialization
        self.ctx_class = nn.Parameter(
            self._init_ctx(use_csc[0], n_cls_class, n_ctx, ctx_dim, dtype)
        )  # to be optimized

        # dynamic (action): random initialization
        action_ctx = self._init_ctx(use_csc[1], n_cls_action, n_ctx, ctx_dim, dtype)
        prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx_action = nn.Parameter(action_ctx)  # to be optimized

        # class/scene prompts
        class_tokenized_prompts, class_embedding, class_name_lens = self._embed_prompts(
            clip_model, class_names, prompt_prefix, device, dtype
        )
        # action prompts
        action_tokenized_prompts, action_embedding, action_name_lens = self._embed_prompts(
            clip_model, action_names, prompt_prefix, device, dtype
        )

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("class_token_prefix", class_embedding[:, :1, :])  # SOS
        self.register_buffer("class_token_suffix", class_embedding[:, 1 + n_ctx :, :])  # CLS, EOS
        self.register_buffer("action_token_prefix", action_embedding[:, :1, :])  # SOS
        self.register_buffer("action_token_suffix", action_embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls_class
        self.n_cls_n = n_cls_action
        self.n_ctx = n_ctx
        self.class_tokenized_prompts = class_tokenized_prompts  # torch.Tensor
        self.action_tokenized_prompts = action_tokenized_prompts  # torch.Tensor
        self.scene_lens = class_name_lens
        self.name_lens = action_name_lens
        self.class_token_position = "end"

    @staticmethod
    def _init_ctx(class_specific, n_cls, n_ctx, ctx_dim, dtype):
        """Return context vectors drawn from N(0, 0.02^2).

        Shape is ``(n_cls, n_ctx, ctx_dim)`` when ``class_specific`` is true,
        else ``(n_ctx, ctx_dim)``.
        """
        if class_specific:
            print("Initializing class-specific contexts")
            ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
        else:
            print("Initializing a generic context")
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
        nn.init.normal_(ctx_vectors, std=0.02)
        return ctx_vectors

    @staticmethod
    def _embed_prompts(clip_model, names, prompt_prefix, device, dtype):
        """Tokenize and embed ``"<prompt_prefix> <name>."`` for every name.

        Underscores in names are replaced by spaces first.

        Returns:
            ``(tokenized_prompts, token_embeddings, name_lens)`` where
            ``name_lens`` is the tokenizer length of each (cleaned) name.
        """
        names = [name.replace("_", " ") for name in names]
        name_lens = [len(_tokenizer.encode(name)) for name in names]
        prompts = [prompt_prefix + " " + name + "." for name in names]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
        return tokenized_prompts, embedding, name_lens

    def forward(self, mode=None):
        """Assemble full prompt embeddings for one label set.

        Args:
            mode: ``"action"`` selects the action prompts; anything else
                selects the class/scene prompts.

        Returns:
            Tensor of shape ``(n_labels, 1 + n_ctx + suffix_len, dim)``.

        Raises:
            ValueError: If ``class_token_position`` is not one of
                ``"end"``, ``"middle"`` or ``"front"``.
        """
        # Pick the context, label count and name lengths of the SAME label
        # set, so the "middle"/"front" layouts index the right rows.
        if mode == "action":
            ctx = self.ctx_action
            n_cls = self.n_cls_n
            name_lens = self.name_lens
            prefix = self.action_token_prefix
            suffix = self.action_token_suffix
        else:
            ctx = self.ctx_class
            n_cls = self.n_cls
            name_lens = self.scene_lens
            prefix = self.class_token_prefix
            suffix = self.class_token_suffix
        if ctx.dim() == 2:
            # Generic context: share it across all labels.
            ctx = ctx.unsqueeze(0).expand(n_cls, -1, -1)

        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,     # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )

        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(n_cls):
                name_len = name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,     # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,      # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,     # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            prompts = []
            for i in range(n_cls):
                name_len = name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,   # (1, name_len, dim)
                        ctx_i,     # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError(
                f"Unknown class_token_position: {self.class_token_position!r}"
            )

        return prompts

class MTPromptTextual(nn.Module):
    """Text branch: learnable prompts for two label sets plus a frozen text encoder.

    Args:
        clip_model: Loaded CLIP model shared with the prompt learner and encoder.
        class_names: Scene/class label names.
        action_names: Action label names.
        device: Device for the tokenized prompts.
        use_csc: Class-specific-context flags, forwarded to the prompt learner.
        n_ctx: Number of learnable context tokens.
    """

    def __init__(self, clip_model, class_names, action_names, device, use_csc, n_ctx):
        super().__init__()
        learner = MTTextPromptLearner(clip_model, class_names, action_names, device, use_csc, n_ctx)
        self.prompt_learner = learner
        self.class_tokenized_prompts = learner.class_tokenized_prompts
        self.action_tokenized_prompts = learner.action_tokenized_prompts
        self.text_encoder = TextEncoder(clip_model)
        self.dtype = clip_model.dtype

    def forward(self, mode=None):
        """Return text features for the action labels (``mode == "action"``)
        or for the class labels (any other ``mode``), cast to ``self.dtype``."""
        is_action = mode == "action"
        branch = "action" if is_action else "class"
        tokenized = self.action_tokenized_prompts if is_action else self.class_tokenized_prompts

        prompt_embeds = self.prompt_learner(branch)
        features = self.text_encoder(prompt_embeds, tokenized)
        return features.type(self.dtype)

class VLMultiDualPromptViTCLIP(nn.Module):
    """CLIP with learnable visual and textual prompts for two label sets.

    Only parameters whose name contains ``"prompt"`` stay trainable: the
    visual prompt projections/embeddings and the text ``prompt_learner``
    contexts. Every CLIP weight is frozen.

    Args:
        clip_model: Loaded CLIP model (ViT visual backbone).
        class_names: Scene/class label names.
        action_names: Action label names.
        feat_dim: Value reported by :attr:`features_dim`.
        device: Device for the tokenized text prompts.
        clip_model_type: Forwarded to ``PromptVisionTransformer``.
        DeepPrompt: Whether the visual branch uses deep prompts.
        n_vtk: Number of visual prompt tokens.
        use_csc: Class-specific-context flag(s) for the text prompts.
        n_ctx: Number of text context tokens.
        mode: Visual prompt set(s) to build; see ``PromptVisionTransformer``.
    """

    def __init__(self, clip_model, class_names, action_names, feat_dim, device, clip_model_type="CLIPViT-B/16", DeepPrompt=False, n_vtk=8, use_csc=False, n_ctx=8, mode=None):
        super().__init__()
        # CLIP stores the logit scale in log space; it is exponentiated in forward().
        self.logit_scale = clip_model.logit_scale
        self.visual_backbone = PromptVisionTransformer(clip_model, clip_model_type, DeepPrompt, n_vtk, mode)
        self.textual_backbone = MTPromptTextual(clip_model, class_names, action_names, device, use_csc, n_ctx)
        self._features_dim = feat_dim
        self.temp_text_features = None

        print("Turning off gradients in both the image and the text encoder")
        # "prompt_learner" already contains "prompt", so one substring test
        # keeps both the visual prompts and the text prompt learner trainable.
        for name, param in self.named_parameters():
            if "prompt" not in name:
                param.requires_grad_(False)
        # Double check
        enabled = {name for name, param in self.named_parameters() if param.requires_grad}
        print(f"Parameters to be updated: {enabled}")

        # Belt and braces: the CLIP weights and the text encoder must stay frozen.
        for param in clip_model.parameters():
            param.requires_grad_(False)
            assert not param.requires_grad
        for param in self.textual_backbone.text_encoder.parameters():
            param.requires_grad_(False)
            assert not param.requires_grad

    def forward(self, images, mode=None):
        """Compute image-to-text logits for the label set selected by ``mode``.

        Args:
            images: Image batch for the visual branch.
            mode: ``"action"`` uses the action prompts in both branches;
                anything else uses the state/class prompts.

        Returns:
            ``(logits, text_features, image_features)``. Both feature sets are
            L2-normalized, and ``logits`` has one row per image and one
            column per label.
        """
        image_features = self.visual_backbone(images, mode)
        text_features = self.textual_backbone(mode)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Same temperature handling as CLIP: scale = exp(log-scale parameter).
        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits, text_features, image_features

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def freeze_bn(self):
        """Put every BatchNorm1d/BatchNorm2d sub-module into eval mode."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer

        ``optimize_head`` is accepted for interface compatibility and is not
        used. The visual group includes frozen parameters; they receive no
        gradients.
        """
        params = [
            {"params": self.visual_backbone.parameters(), "lr": 1.0 * base_lr},
            {"params": self.textual_backbone.prompt_learner.parameters(), "lr": 1.0 * base_lr},
        ]

        return params