from torch import nn
from einops import rearrange
from croma_transformer import BaseTransformerCrossAttn
import croma_utils

class CROMA(nn.Module):
    """CROMA pretraining model for paired S1 (radar) / S2 (optical) imagery.

    Components (constructed in this order):
      * ``s1_encoder`` / ``s2_encoder``: per-modality ``croma_utils.ViT`` encoders.
      * ``cross_encoder``: ``BaseTransformerCrossAttn`` with S1 tokens as the
        input and S2 tokens as the context.
      * ``GAP_FFN_s1`` / ``GAP_FFN_s2``: MLP heads applied to mean-pooled
        encoder outputs, feeding ``global_contrast_loss``.
      * ``decoder``: ``croma_utils.DecoderMAE`` reconstructing the stacked input.

    ``forward`` returns ``(contrastive_loss, mae_loss)``.
    """

    # Channel layout of the stacked input ``s12_imgs``: the first
    # ``_S2_CHANNELS`` channels go to the S2 (optical) encoder, the remaining
    # ``_S1_CHANNELS`` channels go to the S1 (radar) encoder.
    _S2_CHANNELS = 12
    _S1_CHANNELS = 2

    def __init__(self,
                 patch_size=8,
                 encoder_width=768,
                 encoder_layers=12,
                 attention_heads=16,
                 decoder_width=512,
                 decoder_layers=1,
                 total_channels=14,
                 num_patches=225,
                 ):
        """Build the encoders, projection heads, decoder and loss module.

        Args:
            patch_size: Patch side length, shared by the ViT encoders, the
                decoder and the MAE target layout in ``forward``.
            encoder_width: Token width of every encoder and of the GAP heads.
            encoder_layers: Depth of the S2 encoder. The S1 encoder and the
                cross encoder each get ``encoder_layers // 2`` layers.
            attention_heads: Heads for the encoders and the ALiBi bias. The
                decoder is always built with 8 heads.
            decoder_width: Width passed to ``croma_utils.DecoderMAE``.
            decoder_layers: Depth passed to ``croma_utils.DecoderMAE``.
            total_channels: Channel count passed to the decoder. The default
                14 equals ``_S2_CHANNELS + _S1_CHANNELS``.
            num_patches: Tokens per image. It is passed to the encoders, the
                decoder and ``croma_utils.get_alibi``.
        """
        super().__init__()
        self.encoder_width = encoder_width
        self.encoder_layers = encoder_layers
        self.decoder_width = decoder_width
        self.decoder_layers = decoder_layers
        self.attention_heads = attention_heads
        self.num_patches = num_patches
        self.patch_size = patch_size
        self.total_channels = total_channels

        half_layers = self.encoder_layers // 2

        # NOTE: keep the submodule construction order below stable. It fixes
        # parameter registration order (relevant for optimizer checkpoints)
        # and the order in which random initialisation consumes the RNG.
        self.s1_encoder = croma_utils.ViT(num_patches=self.num_patches,
                                          width=self.encoder_width,
                                          layers=half_layers,
                                          attention_heads=self.attention_heads,
                                          in_channels=self._S1_CHANNELS,
                                          patch_size=self.patch_size,
                                          )
        self.s2_encoder = croma_utils.ViT(num_patches=self.num_patches,
                                          width=self.encoder_width,
                                          layers=self.encoder_layers,
                                          attention_heads=self.attention_heads,
                                          in_channels=self._S2_CHANNELS,
                                          patch_size=self.patch_size,
                                          )
        self.cross_encoder = BaseTransformerCrossAttn(width=self.encoder_width,
                                                      layers=half_layers,
                                                      attention_heads=self.attention_heads,
                                                      )
        self.GAP_FFN_s1 = self._gap_ffn(self.encoder_width)
        self.GAP_FFN_s2 = self._gap_ffn(self.encoder_width)
        self.decoder = croma_utils.DecoderMAE(num_patches=self.num_patches,
                                              encoder_width=self.encoder_width,
                                              decoder_width=self.decoder_width,
                                              decoder_layers=self.decoder_layers,
                                              attention_heads=8,
                                              total_channels=self.total_channels,
                                              patch_size=self.patch_size,
                                              )
        # Plain attribute, not a registered buffer, so it does NOT follow
        # ``model.to(device)``. ``forward`` moves it to the input's device.
        self.attn_bias = croma_utils.get_alibi(attention_heads=self.attention_heads,
                                               num_patches=self.num_patches)
        self.global_contrast_loss = croma_utils.ContrastLossInputMix(projection_input=self.encoder_width,
                                                                     projection_output=self.encoder_width,
                                                                     )

    @staticmethod
    def _gap_ffn(width):
        """Return the LayerNorm -> Linear(4x) -> GELU -> Linear head used on pooled features."""
        # Layout (indices 0..3) must stay fixed so state_dict keys are unchanged.
        return nn.Sequential(
            nn.LayerNorm(width),
            nn.Linear(width, 4 * width),
            nn.GELU(),
            nn.Linear(4 * width, width),
        )

    def _masked_alibi(self, alibi, query_mask_info, key_mask_info, batch_size):
        """Build a masked attention bias with ``croma_utils.apply_mask_to_alibi``.

        The query mask supplies ``ids_keep_queries``. The key mask supplies
        both ``ids_keep_keys`` and ``masked_seq_len``.
        """
        return croma_utils.apply_mask_to_alibi(alibi=alibi,
                                               ids_keep_queries=query_mask_info['ids_keep'],
                                               ids_keep_keys=key_mask_info['ids_keep'],
                                               batch_size=batch_size,
                                               orig_seq_len=self.num_patches,
                                               masked_seq_len=key_mask_info['len_keep'],
                                               attention_heads=self.attention_heads)

    def forward(self, s12_imgs, s1_mask_info, s2_mask_info, rank, world_size, lam, mixup_labels):
        """Encode both modalities and return the contrastive and MAE losses.

        Args:
            s12_imgs: Stacked images laid out as ``(b, c, h*patch, w*patch)``,
                which is the layout the MAE target rearrange expects. Channels
                ``[:12]`` are S2 (optical) and channels ``[12:]`` are S1 (radar).
            s1_mask_info: Mask dict for S1. Its ``'ids_keep'`` and ``'len_keep'``
                entries are read here, and the whole dict is passed to the S1
                encoder and to the decoder (as ``mask_info_radar``).
            s2_mask_info: Same as ``s1_mask_info``, for S2 (``mask_info_optical``).
            rank, world_size, lam, mixup_labels: Forwarded unchanged to
                ``global_contrast_loss``. Presumably these are distributed and
                mixup parameters; confirm in ``croma_utils``.

        Returns:
            Tuple ``(contrastive_loss, mae_loss)``.
        """
        s1_imgs = s12_imgs[:, self._S2_CHANNELS:, ...]
        s2_imgs = s12_imgs[:, :self._S2_CHANNELS, ...]
        batch_size = s12_imgs.shape[0]
        # Move the (unregistered) ALiBi tensor once per call instead of once per use.
        attn_bias = self.attn_bias.to(s12_imgs.device)

        # Self-attention biases for each unimodal encoder.
        s1_masked_attn_bias = self._masked_alibi(attn_bias, s1_mask_info, s1_mask_info, batch_size)
        s2_masked_attn_bias = self._masked_alibi(attn_bias, s2_mask_info, s2_mask_info, batch_size)
        s1_encodings = self.s1_encoder(imgs=s1_imgs, attn_bias=s1_masked_attn_bias, mask_info=s1_mask_info)
        s2_encodings = self.s2_encoder(imgs=s2_imgs, attn_bias=s2_masked_attn_bias, mask_info=s2_mask_info)

        # Mean-pool over dim 1 (presumably the token axis), then project for the contrastive loss.
        s1_GAP = self.GAP_FFN_s1(s1_encodings.mean(dim=1))
        s2_GAP = self.GAP_FFN_s2(s2_encodings.mean(dim=1))
        contrastive_loss = self.global_contrast_loss(s1_features=s1_GAP,
                                                     s2_features=s2_GAP,
                                                     world_size=world_size,
                                                     rank=rank,
                                                     lam=lam,
                                                     mixup_labels=mixup_labels)

        # Cross-attention bias: S1 tokens are queries and S2 tokens are keys.
        # NOTE(review): masked_seq_len comes from the S2 (key) mask. If
        # apply_mask_to_alibi uses it for the query axis too, this assumes both
        # modalities keep the same number of patches. Confirm.
        cross_attn_bias = self._masked_alibi(attn_bias, s1_mask_info, s2_mask_info, batch_size)
        s12_encodings = self.cross_encoder(x=s1_encodings,
                                           context=s2_encodings,
                                           alibi=cross_attn_bias)

        # MAE target: for each patch, the flattened pixels of every channel.
        s12_target = rearrange(s12_imgs, 'b c (h i) (w j) -> b (h w) (c i j)', i=self.patch_size, j=self.patch_size)
        mae_loss = self.decoder(x=s12_encodings,
                                mask_info_radar=s1_mask_info,
                                mask_info_optical=s2_mask_info,
                                target=s12_target,
                                )
        return contrastive_loss, mae_loss

