from typing import Tuple, Union, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmengine.model.weight_init import xavier_init
from mmrazor.registry import MODELS


class SELayer(nn.Module):
    """Squeeze-and-excitation style channel gate.

    The input is globally average-pooled over its two spatial dimensions and
    passed through a bias-free two-layer bottleneck MLP ending in a sigmoid.
    Unlike the usual SE block, ``forward`` returns the per-channel gate
    itself and does NOT multiply it back into the input.

    Args:
        channel (int): Number of input channels (size of dim 1 of the input).
        reduction (int): Bottleneck reduction ratio. The hidden width is
            ``channel // reduction``, clamped to at least 1. Defaults to 16.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        # Without the clamp, ``channel < reduction`` yields a zero-width
        # hidden layer, which makes the gate a constant 0.5 regardless of
        # the input. For ``channel >= reduction`` this is unchanged.
        hidden = max(channel // reduction, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the sigmoid channel gate for ``x``.

        Args:
            x (torch.Tensor): 4-D tensor; the first two dims are read as
                (batch, channel) and the last two are pooled away.

        Returns:
            torch.Tensor: Gate of shape (batch, channel) with values in
            (0, 1).
        """
        b, c, _, _ = x.shape
        # (b, c, 1, 1) -> (b, c) so the linear layers act on channels.
        pooled = self.avg_pool(x).view(b, c)
        return self.fc(pooled)


# Designed to be compatible with BEVFusion-like teachers, using randomly initialized BEV queries.
@MODELS.register_module()
class BEVInitQueryGuidedDeformableTeacherReconstructedLoss(nn.Module):
    """Masked feature-distillation loss driven by learnable BEV queries.

    A learnable embedding table of ``bev_h * bev_w`` queries is linearly
    projected to ``embed_dims`` and passed, together with the teacher
    features, through a transformer encoder. The encoder output is turned
    into a spatial mask (1x1 conv + sigmoid) and a channel mask
    (``SELayer``). Both masks weight the squared student/teacher difference
    and the teacher feature map that is returned alongside the loss.

    Args:
        encoder: Config passed unchanged to
            ``build_transformer_layer_sequence``. Defaults to None.
        loss_weight (float): Weight of loss. Defaults to 100000.0.
        resize_stu (bool): If True, we'll down/up sample the features of the
            student model to the spatial size of those of the teacher model if
            their spatial sizes are different. And vice versa. Defaults to
            True.
        bev_h (int): BEV grid height; the query table has ``bev_h * bev_w``
            rows. Defaults to 180.
        bev_w (int): BEV grid width. Defaults to 180.
        query_dims (int): Width of each learnable query. Defaults to 256.
        embed_dims (int): Width the queries are projected to; also the
            channel count used by the channel/spatial mask heads.
            Defaults to 512.
    """

    def __init__(self, encoder=None, loss_weight=100000.0, resize_stu=True, bev_h=180, bev_w=180, query_dims=256, embed_dims=512):
        super().__init__()
        self.loss_weight = loss_weight
        self.resize_stu = resize_stu

        # NOTE: sub-module creation order is kept as-is so parameter
        # registration order and RNG consumption stay unchanged.
        self.query_proj = nn.Linear(query_dims, embed_dims)
        self.encoder = build_transformer_layer_sequence(encoder)
        self.channel_selayer = SELayer(channel=embed_dims)
        self.spatial_mask_conv = nn.Conv2d(in_channels=embed_dims, out_channels=1, kernel_size=1)
        self.init_weights()
        # Set to True once ``freeze_grad`` has run, so it is only done once.
        self.grad_already_freezed = False
        self.bev_queries = nn.Embedding(bev_h * bev_w, query_dims)

    def init_weights(self):
        """Xavier-uniform init for the query projection and spatial conv."""
        xavier_init(self.query_proj, distribution='uniform', bias=0.)
        nn.init.xavier_uniform_(self.spatial_mask_conv.weight)
        if self.spatial_mask_conv.bias is not None:
            nn.init.zeros_(self.spatial_mask_conv.bias)

    def freeze_grad(self):
        """Freeze the mask generation network.

        Sets ``requires_grad = False`` on every parameter of this module and
        records that in ``grad_already_freezed``.
        """
        for param in self.parameters():
            param.requires_grad = False
        self.grad_already_freezed = True

    def forward(self, preds_S: Union[torch.Tensor, Tuple],
                preds_T: Union[torch.Tensor, Tuple],
                mask_learning_stopped: bool) -> List[torch.Tensor]:
        """Forward computation.

        Args:
            preds_S (torch.Tensor | Tuple[torch.Tensor]): The student model
                prediction. If tuple, it should be several tensors with shape
                (N, C, H, W).
            preds_T (torch.Tensor | Tuple[torch.Tensor]): The teacher model
                prediction. If tuple, it should be several tensors with shape
                (N, C, H, W). Exactly one teacher feature map is supported.
            mask_learning_stopped (bool): When True, all parameters of this
                module are frozen (once) via ``freeze_grad``.

        The internal BEV queries have shape (H*W, C).

        Return:
            list[torch.Tensor]: ``[loss * loss_weight, teacher_maskedmap]``,
            i.e. the weighted scalar loss and the detached teacher feature
            map multiplied by the spatial and channel masks.
        """
        if isinstance(preds_S, torch.Tensor):
            preds_S, preds_T = (preds_S, ), (preds_T, )

        teacher_feat = preds_T[0]
        B, C, H, W = teacher_feat.shape

        queries = self.bev_queries.weight.to(teacher_feat.dtype).to(teacher_feat.device)
        masks = self.query_proj(queries)
        masks = self.encoder(preds_T, masks, H, W, C)

        # B, H, W, C -> B, C, H, W
        # NOTE(review): assumes the encoder output holds B*H*W*C elements in
        # (B, H*W, C) order - confirm against the encoder implementation.
        masks = masks.reshape(B, H, W, C).permute(0, 3, 1, 2)
        # B, 1, H, W
        mask_spatial = torch.sigmoid(self.spatial_mask_conv(masks))
        # B, C
        mask_channel = self.channel_selayer(masks)

        # Freezing happens after the masks of this step were computed, as in
        # the original control flow.
        if mask_learning_stopped and not self.grad_already_freezed:
            self.freeze_grad()

        # apply mask to teacher feature map
        teacher_maskedmap = self.mask_featuremap(mask_spatial, mask_channel, preds_T)

        # Accumulate on the feature device (default dtype, so type promotion
        # of the final loss is the same as with a plain ``torch.tensor(0.0)``).
        loss = torch.zeros((), device=teacher_feat.device)

        for pred_S, pred_T in zip(preds_S, preds_T):
            # Both inputs are detached, so gradients from this loss can only
            # reach the mask network through the two masks.
            pred_S = pred_S.detach()
            pred_T = pred_T.detach()
            size_S, size_T = pred_S.shape[2:], pred_T.shape[2:]
            # Compare the full spatial size; checking only the height let a
            # width-only mismatch skip the resize and hit the assert below.
            if size_S != size_T:
                print("Warning: The feature map of student doesn't match the feature map of teacher!")
                if self.resize_stu:
                    pred_S = F.interpolate(pred_S, size_T, mode='bilinear')
                else:
                    # NOTE(review): the masks stay at the teacher resolution,
                    # so this branch presumably fails to broadcast in the
                    # loss - confirm whether the masks should be resized too.
                    pred_T = F.interpolate(pred_T, size_S, mode='bilinear')
            assert pred_S.shape == pred_T.shape

            loss = loss + self.masked_root_mse_loss(pred_S, pred_T, mask_spatial, mask_channel)

        return [loss * self.loss_weight, teacher_maskedmap]

    def masked_root_mse_loss(self,
                             student_featuremap: torch.Tensor,
                             teacher_featuremap: torch.Tensor,
                             spatial_mask: torch.Tensor,
                             channel_mask: torch.Tensor) -> torch.Tensor:
        """Calculate the masked distillation loss.

        Despite the name, no square root is taken: the result is the mean of
        the squared difference weighted by both masks.

        Args:
            student_featuremap (torch.Tensor): Student features, (B, C, H, W).
            teacher_featuremap (torch.Tensor): Teacher features, same shape.
            spatial_mask (torch.Tensor): Mask multiplied into the squared
                difference by broadcasting (``forward`` passes (B, 1, H, W)).
            channel_mask (torch.Tensor): Mask viewed as (B, C, 1, 1) before
                being multiplied in.

        Returns:
            torch.Tensor: Scalar mean of the masked squared difference.
        """
        B, C, _, _ = teacher_featuremap.shape
        assert teacher_featuremap.shape == student_featuremap.shape, "student & teacher feature map must have the same shape when cal MSE"
        diff_squared = (student_featuremap - teacher_featuremap).pow(2)

        weighted = diff_squared * spatial_mask
        weighted = weighted * channel_mask.view(B, C, 1, 1)

        return torch.mean(weighted)

    def mask_featuremap(self, softmax_spatial_mask, softmax_channel_mask, featuremaps):
        """Apply the spatial and channel masks to the single feature map.

        Args:
            softmax_spatial_mask (torch.Tensor): Spatial mask, broadcast
                against the feature map (``forward`` passes (B, 1, H, W)).
            softmax_channel_mask (torch.Tensor): Channel mask; two trailing
                singleton dims are appended before multiplying.
            featuremaps (Sequence[torch.Tensor]): Must contain exactly one
                feature map. It is not modified.

        Returns:
            torch.Tensor: ``featuremaps[0].detach()`` times both masks.
        """
        channel_gate = softmax_channel_mask.unsqueeze(2).unsqueeze(3)

        assert len(featuremaps) == 1, 'Featuremap length is not 1'
        # Detach into a local instead of assigning back into ``featuremaps``:
        # the container may be a tuple (item assignment raises TypeError),
        # and a caller-owned list should not be mutated.
        featuremap = featuremaps[0].detach()

        return featuremap * softmax_spatial_mask * channel_gate
