from collections import OrderedDict
from typing import Tuple, Optional, List, Dict
import math
from operator import mul
from functools import reduce

import numpy as np
import torch
import torch.nn as nn
from torch.nn import Conv2d, Dropout


from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Module-level SimpleTokenizer instance, constructed once at import time.
# NOTE(review): not referenced by the classes in this chunk (they call
# clip.tokenize instead) — confirm whether anything else still needs it.
_tokenizer = _Tokenizer()

class PromptVisionTransformer(nn.Module):
    """CLIP ViT image encoder with learnable visual prompt tokens (VPT).

    The sub-modules and parameters of ``clip_model.visual`` are shared (not
    copied). ``n_vtk`` learnable prompt tokens are inserted right after the
    class token. In deep mode, before every transformer block except the
    first, those prompt positions are overwritten with that layer's own
    learnable prompts.

    Args:
        clip_model: CLIP model whose ``visual`` attribute is a ViT exposing
            ``conv1``, ``class_embedding``, ``positional_embedding``,
            ``ln_pre``, ``transformer`` (with ``layers`` and ``resblocks``),
            ``ln_post``, ``proj``, ``input_resolution`` and ``output_dim``.
        clip_model_type: model name such as ``"CLIPViT-B/32"``. It is kept
            for interface compatibility. The patch size is read from
            ``conv1.kernel_size``, so ViT variants other than B/32 and B/16
            no longer fail.
        DeepPrompt: if true, learn a separate prompt set for every
            transformer block after the first.
        n_vtk: number of visual prompt tokens.
    """

    def __init__(self, clip_model, clip_model_type, DeepPrompt, n_vtk):
        super().__init__()
        visual = clip_model.visual
        self.input_resolution = visual.input_resolution
        self.output_dim = visual.output_dim
        self.conv1 = visual.conv1

        self.class_embedding = visual.class_embedding
        self.positional_embedding = visual.positional_embedding
        self.ln_pre = visual.ln_pre

        self.transformer = visual.transformer

        self.ln_post = visual.ln_post
        self.proj = visual.proj
        self.Deep = DeepPrompt

        # Prompt config, derived from the wrapped model instead of the model
        # name:
        # - the patch size is the kernel of the patch-embedding conv;
        # - the prompt width must equal the token width (the positional
        #   embedding width), because prompts are concatenated with the
        #   patch tokens.
        patch_size = tuple(self.conv1.kernel_size)
        _, prompt_dim = self.positional_embedding.shape
        hidden_size = prompt_dim
        self.num_tokens = n_vtk

        self.prompt_dropout = Dropout(0.1)
        self.prompt_proj = nn.Linear(prompt_dim, hidden_size)
        nn.init.kaiming_normal_(self.prompt_proj.weight, a=0, mode='fan_out')

        # Xavier-uniform style bound: sqrt(6 / (3 * patch_h * patch_w + prompt_dim)).
        val = math.sqrt(6. / float(3 * reduce(mul, patch_size, 1) + prompt_dim))  # noqa

        self.prompt_embeddings = nn.Parameter(torch.zeros(
            1, self.num_tokens, prompt_dim))
        nn.init.uniform_(self.prompt_embeddings.data, -val, val)

        if self.Deep:  # Deep prompt version
            # One prompt set per transformer block after the first; the first
            # block consumes the shallow prompts inserted in forward().
            total_d_layer = self.transformer.layers - 1
            self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
                total_d_layer, self.num_tokens, prompt_dim))
            nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)

    def _project_prompts(self, prompts, batch_size, dtype):
        """Project prompt embeddings, broadcast them over the batch and apply dropout.

        ``prompts`` is either ``(1, num_tokens, dim)`` or ``(num_tokens, dim)``.
        The result is ``(batch_size, num_tokens, hidden)``, cast to ``dtype``.
        The cast is needed because the prompt parameters stay float32 while the
        CLIP tokens may be half precision.
        """
        projected = self.prompt_proj(prompts).expand(batch_size, -1, -1)
        return self.prompt_dropout(projected).to(dtype)

    def _forward_deep(self, x, batch_size):
        """Run the transformer blocks one by one, swapping in per-layer prompts.

        ``x`` is in LND layout. Before block ``i`` (``i >= 1``), the
        ``num_tokens`` positions right after the class token are replaced by
        ``deep_prompt_embeddings[i - 1]``.
        """
        hidden_states = x
        for i in range(self.transformer.layers):
            if 0 < i <= self.deep_prompt_embeddings.shape[0]:
                deep_prompt_emb = self._project_prompts(
                    self.deep_prompt_embeddings[i - 1], batch_size, hidden_states.dtype)
                deep_prompt_emb = deep_prompt_emb.permute(1, 0, 2)  # NLD -> LND

                hidden_states = torch.cat((
                    hidden_states[:1, :, :],
                    deep_prompt_emb,
                    hidden_states[(1 + self.num_tokens):, :, :],
                ), dim=0)

            hidden_states = self.transformer.resblocks[i](hidden_states)
        return hidden_states

    def forward(self, x):
        """Encode a batch of images into (projected) class-token features.

        Args:
            x: image batch fed to ``conv1``. This is presumably
                ``(B, 3, H, W)`` with H and W equal to ``input_resolution``:
                the positional-embedding add requires exactly that number of
                patches.

        Returns:
            ``ln_post`` of the class token, multiplied by ``proj`` when
            ``proj`` is not None.
        """
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        cls_tokens = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls_tokens, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)

        # Insert the shallow prompts between the class token and the patches:
        # shape = [*, 1 + num_tokens + grid ** 2, width]
        B = x.size(0)
        prompts = self._project_prompts(self.prompt_embeddings, B, x.dtype)
        x = torch.cat((x[:, :1, :], prompts, x[:, 1:, :]), dim=1)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND

        if self.Deep:  # Deep prompt version
            x = self._forward_deep(x, B)
        else:
            x = self.transformer(x)

        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x

class VisualPromptTuningCLIP(nn.Module):
    """CLIP image encoder with visual prompt tuning, where only prompt parameters train.

    At construction, the text encoder embeds ``"a photo of a {classname}"`` for
    every class. The L2-normalised result is stored in ``self.text_features``;
    ``forward`` does not use it.

    Args:
        clip_model: CLIP model providing the visual backbone, ``logit_scale``
            and ``encode_text``. All of its parameters are frozen in place.
        classnames: class-name strings used to build the text prompts.
        feat_dim: value reported by :attr:`features_dim`.
        device: device the tokenized text prompts are moved to.
        clip_model_type: forwarded to ``PromptVisionTransformer``.
        DeepPrompt: forwarded to ``PromptVisionTransformer``.
        n_vtk: number of visual prompt tokens.
    """

    def __init__(self, clip_model, classnames, feat_dim, device, clip_model_type="CLIPViT-B/32", DeepPrompt=False, n_vtk=8):
        super().__init__()
        self.visual_backbone = PromptVisionTransformer(clip_model, clip_model_type, DeepPrompt, n_vtk)
        self._features_dim = feat_dim
        self.logit_scale = clip_model.logit_scale

        text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in classnames]).to(device)
        # The text encoder is frozen. Encoding without autograd keeps
        # self.text_features from holding the whole text-encoder graph alive.
        with torch.no_grad():
            text_features = clip_model.encode_text(text_inputs)
        self.text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        print("Turning off gradients in both the image and the text encoder")
        # Only parameters whose name contains "prompt" stay trainable
        # (prompt_embeddings, deep_prompt_embeddings and prompt_proj.*).
        for name, param in self.named_parameters():
            if "prompt" not in name:
                param.requires_grad_(False)
        # Double check
        enabled = {name for name, param in self.named_parameters() if param.requires_grad}
        print(f"Parameters to be updated: {enabled}")

        # Freeze the full CLIP model as well. This covers the text encoder,
        # which is not a sub-module of self. The prompt parameters live only
        # in visual_backbone, so they are unaffected.
        for param in clip_model.parameters():
            param.requires_grad_(False)

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def forward(self, images):
        """Return L2-normalised image features from the prompted visual backbone.

        Class logits are not computed here.
        """
        image_features = self.visual_backbone(images)
        return image_features / image_features.norm(dim=-1, keepdim=True)

    def freeze_bn(self):
        """Put every BatchNorm1d/BatchNorm2d sub-module into eval mode.

        NOTE: a later call to ``train()`` switches them back to training mode.
        """
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer

        Returns a single group holding every ``visual_backbone`` parameter,
        frozen ones included, at ``base_lr``. ``optimize_head`` is accepted
        for interface compatibility but is ignored.
        """
        return [
            {"params": self.visual_backbone.parameters(), "lr": 1.0 * base_lr},
        ]