import torch

from .MILFeatExt import MILFeatExt
from .MILTransformer import TransformerEncoder
from .MILPool import MILAttentionPool, MILMaxPool, MILMeanPool
from .MILModel import MILModel

class DeepMILAttModel(MILModel):
    """Deep multiple-instance-learning (MIL) bag classifier.

    Pipeline: ``MILFeatExt`` -> optional ``TransformerEncoder`` -> MIL pooling
    (attention / max / mean) -> ``torch.nn.Linear(feat_dim, 1)``, producing
    one logit per bag.
    """

    def __init__(
        self,
        input_shape: tuple,
        feat_ext_name: str = 'none',
        pool_name: str = 'att',
        transformer_encoder_kwargs: 'dict | None' = None,
        pool_kwargs: 'dict | None' = None,
        ce_criterion: 'torch.nn.Module | None' = None,
        **kwargs
    ) -> None:
        """
        Args:
            input_shape: shape of a single instance, forwarded to ``MILFeatExt``.
            feat_ext_name: feature-extractor name, forwarded to ``MILFeatExt``.
            pool_name: one of ``'att'``, ``'max'``, ``'mean'``.
            transformer_encoder_kwargs: if non-empty, a ``TransformerEncoder``
                is built with ``in_dim=feat_dim`` and these kwargs and applied
                before pooling. ``None``/empty disables it.
            pool_kwargs: extra kwargs for the pooling module (``in_dim`` is
                always set to ``feat_dim``). ``None`` means no extra kwargs.
            ce_criterion: bag-level loss applied to (logits, labels) in
                ``compute_loss``. ``None`` creates a fresh
                ``torch.nn.BCEWithLogitsLoss``.
            **kwargs: stored on ``self.kwargs``; not otherwise used here.

        Raises:
            ValueError: if ``pool_name`` is not recognised.
        """
        super().__init__()
        # Build fresh objects per instance. Literal ``{}`` / module defaults
        # are evaluated once at definition time and would be shared (and
        # aliased through ``self.*_kwargs``) by every model instance.
        transformer_encoder_kwargs = dict(transformer_encoder_kwargs or {})
        pool_kwargs = dict(pool_kwargs or {})
        if ce_criterion is None:
            ce_criterion = torch.nn.BCEWithLogitsLoss()

        self.input_shape = input_shape
        self.feat_ext_name = feat_ext_name
        self.pool_name = pool_name
        self.pool_kwargs = pool_kwargs
        self.transformer_encoder_kwargs = transformer_encoder_kwargs
        self.kwargs = kwargs
        self.ce_criterion = ce_criterion

        self.feat_ext = MILFeatExt(input_shape=input_shape, feat_ext_name=feat_ext_name)
        self.feat_dim = self.feat_ext.output_size
        if transformer_encoder_kwargs:
            self.transformer_encoder = TransformerEncoder(in_dim=self.feat_dim, **transformer_encoder_kwargs)
        else:
            self.transformer_encoder = None
        self.pool = self._get_pool(pool_name)(in_dim=self.feat_dim, **pool_kwargs)
        # NOTE(review): assumes the transformer encoder (if any) and the pool
        # keep the feature dimension equal to feat_dim -- confirm.
        self.classifier = torch.nn.Linear(self.feat_dim, 1)

    def _get_pool(self, pool_name):
        """Return the pooling class selected by ``pool_name``.

        Raises:
            ValueError: if ``pool_name`` is not one of 'att', 'max', 'mean'.
        """
        pools = {
            'att': MILAttentionPool,
            'max': MILMaxPool,
            'mean': MILMeanPool,
        }
        try:
            return pools[pool_name]
        except KeyError:
            raise ValueError(f'Invalid pool_name: {pool_name!r}') from None

    def forward(self, X, adj_mat, mask, return_att=False):
        """
        input:
            X: tensor (batch_size, bag_size, ...)
            adj_mat: sparse coo tensor (batch_size, bag_size, bag_size)
            mask: tensor (batch_size, bag_size)
            return_att: if True, also return the pool's per-instance scores.
        output:
            T_logits_pred: tensor (batch_size,)
            att: tensor (batch_size, bag_size) if return_att is True
        """
        X = self.feat_ext(X)  # (batch_size, bag_size, D)
        if self.transformer_encoder is not None:
            X = self.transformer_encoder(X, adj_mat, mask)

        out_pool = self.pool(X, adj_mat, mask, return_att=return_att)
        if return_att:
            # Z: (batch_size, D[, n_samples]), f: (batch_size, bag_size[, n_samples])
            Z, f = out_pool
            if f.dim() == 2:
                f = f.unsqueeze(dim=-1)
            f = torch.mean(f, dim=2)  # average over samples -> (batch_size, bag_size)
        else:
            Z = out_pool  # (batch_size, D[, n_samples])

        # Pools may or may not emit a trailing "samples" axis; normalise to 3-D.
        if Z.dim() == 2:
            Z = Z.unsqueeze(dim=-1)  # (batch_size, D, 1)

        Z = Z.transpose(1, 2)  # (batch_size, n_samples, D)
        T_logits_pred = self.classifier(Z)  # (batch_size, n_samples, 1)
        T_logits_pred = torch.mean(T_logits_pred, dim=(1, 2))  # (batch_size,)

        if return_att:
            return T_logits_pred, f
        return T_logits_pred

    def compute_loss(self, T_labels, X, adj_mat, mask, *args, **kwargs):
        """
        Input:
            T_labels: tensor (batch_size,)
            X: tensor (batch_size, bag_size, ...)
            adj_mat: sparse coo tensor (batch_size, bag_size, bag_size)
            mask: tensor (batch_size, bag_size)
            *args, **kwargs: forwarded to ``self.pool.compute_loss``.
        Output:
            T_logits_pred: tensor (batch_size,)
            loss_dict: dict {'BCEWithLogitsLoss': ce_loss, **pool_loss_dict}
                (the key name is fixed even if a custom ce_criterion is used)
        """
        T_logits_pred = self.forward(X, adj_mat, mask)
        # Positional *args bind before keywords regardless of call-site order;
        # writing them first just makes that explicit.
        pool_loss_dict = self.pool.compute_loss(*args, T_labels=T_labels, X=X, adj_mat=adj_mat, mask=mask, **kwargs)
        ce_loss = self.ce_criterion(T_logits_pred.float(), T_labels.float())
        return T_logits_pred, {'BCEWithLogitsLoss': ce_loss, **pool_loss_dict}

    @torch.no_grad()
    def predict(self, X, adj_mat, mask, *args, return_y_pred=True, **kwargs):
        """Gradient-free inference.

        Returns:
            (T_logits_pred, att_val): bag logits of shape (batch_size,) and,
            when ``return_y_pred`` is True, the per-instance scores of shape
            (batch_size, bag_size); otherwise ``att_val`` is ``None``.
            Extra ``*args``/``**kwargs`` are accepted and ignored.
        """
        out = self.forward(X, adj_mat, mask, return_att=return_y_pred)
        if return_y_pred:
            T_logits_pred, att_val = out
        else:
            # forward returns a bare tensor here; unpacking it would split
            # the batch dimension instead of separating logits/attention.
            T_logits_pred, att_val = out, None
        return T_logits_pred, att_val
        


        
        