import torch
from torch_geometric.nn import GCNConv, GATConv, MessagePassing
from torch.nn import Sequential, ReLU, Linear
import logging
import torch.nn as nn
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from models import layers
import torch_geometric.nn as gnn

class GeneralizedMPNN(torch.nn.Module):
    """Stack of message-passing layers followed by a graph-level pooling head.

    Each layer applies, in order: an optional edge-attribute encoder, an
    ``layers.MPNN_block``, a per-layer encoder (``layers.LINEAR`` or
    ``layers.MLP``), ``layers.NormReLU``, optional dropout and an optional
    residual connection. The node features are then pooled by
    ``Pooling_mpnn``.

    Args:
        cfg: configuration object. The constructor reads
            ``cfg.model.{attention_type, num_layer, H, final_dim, dropout,
            residual, dim_embed, sum_pooling}``,
            ``cfg.model.atom_encoder.{in_dim, linear}``,
            ``cfg.model.edge_encoder.{linear, use_edge_attr, in_dim}``,
            ``cfg.model.layer_encoder.linear`` and ``cfg.data.name``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        self.mpnn_type = cfg.model.attention_type

        # atom encoder
        self.atom_dim = cfg.model.atom_encoder.in_dim
        self.use_linear_atom_encoder = cfg.model.atom_encoder.linear

        # edge encoder
        self.use_linear_edge_encoder = cfg.model.edge_encoder.linear
        self.use_edge_attr = cfg.model.edge_encoder.use_edge_attr
        self.edge_attr_dim = cfg.model.edge_encoder.in_dim

        # model
        self.num_layers = cfg.model.num_layer
        self.H = cfg.model.H
        self.final_dim = cfg.model.final_dim
        self.dropout = cfg.model.dropout
        self.use_residual = cfg.model.residual
        self.use_linear = cfg.model.layer_encoder.linear
        self.dim_embed = cfg.model.dim_embed

        # pooling
        self.use_sum_pooling = cfg.model.sum_pooling

        # general
        self.dataset = cfg.data.name

        logging.info("Initializing Atom encoder")
        if self.use_linear_atom_encoder:  # linear layer
            self.embed_v = nn.Sequential(
                nn.Linear(self.atom_dim, self.dim_embed),
                nn.ReLU(),
                nn.Linear(self.dim_embed, self.dim_embed),
            )
        else:  # look up table
            self.embed_v = AtomEncoder(self.dim_embed)

        logging.info(f"Initializing all {self.num_layers} layers")
        # Per-layer modules, keyed by the layer index as a string
        # (nn.ModuleDict only accepts string keys).
        self.MPNNs = nn.ModuleDict()
        self.EDGE_ENCODER = nn.ModuleDict()
        self.LAYER_ENCODER = nn.ModuleDict()
        self.BNORM_RELUs = nn.ModuleDict()
        self.DROP_OUTs = nn.ModuleDict()

        for layer_idx in range(self.num_layers):
            logging.info(f"Initializing layer number {layer_idx}.")
            # Keep the integer index for configuration decisions and use a
            # separate string key for the ModuleDicts. Previously the index
            # was overwritten with its string form before being passed to
            # get_edge_encoder_out_dim, so the first-layer check there could
            # never match.
            key = str(layer_idx)

            if self.use_edge_attr:
                edge_out_dim = self.get_edge_encoder_out_dim(
                    layer_idx=layer_idx)
                self.EDGE_ENCODER[key] = self.init_edge_encoder(
                    use_linear=self.use_linear_edge_encoder,
                    in_dim=self.edge_attr_dim,
                    out_dim=edge_out_dim,
                )
            else:
                edge_out_dim = None
                self.EDGE_ENCODER[key] = None

            self.MPNNs[key] = self.MPNN_block_instnatiator(
                edge_out_dim=edge_out_dim)

            if self.use_linear:
                layer_encoder = layers.LINEAR(self.dim_embed, self.dim_embed)
            else:
                layer_encoder = layers.MLP(self.dim_embed, self.dim_embed)
            self.LAYER_ENCODER[key] = layer_encoder

            self.BNORM_RELUs[key] = layers.NormReLU(self.dim_embed)

            # Dropout modules exist only when dropout is enabled; forward()
            # guards on the same condition.
            if self.dropout > 0:
                self.DROP_OUTs[key] = nn.Dropout(p=self.dropout)

        logging.info("Initializing pooling")
        self.POOLING = Pooling_mpnn(idim=self.dim_embed, odim=self.final_dim)

    def forward(self, batch):
        """Run the full model on a batch.

        Args:
            batch: object exposing ``x``, ``edge_index`` and ``batch``, plus
                ``edge_attr`` when edge attributes are enabled. Note that
                ``batch.x`` is overwritten in place with the encoded node
                features after every layer.

        Returns:
            The output of the pooling head for this batch.
        """
        batch.x = self.embed_v(batch.x)
        # LAYERS
        for layer_idx in range(self.num_layers):
            key = str(layer_idx)
            if self.use_edge_attr:
                encoded_edge_atr = self.EDGE_ENCODER[key](
                    message=-1, attrs=batch.edge_attr, dont_use_message=True)
            else:
                encoded_edge_atr = None
            agg_element = self.MPNNs[key](
                x=batch.x, edge_index=batch.edge_index,
                edge_attr=encoded_edge_atr)
            batch_x = self.BNORM_RELUs[key](
                self.LAYER_ENCODER[key](agg_element))
            # DROPOUT
            if self.dropout > 0:
                batch_x = self.DROP_OUTs[key](batch_x)
            # RESIDUAL
            if self.use_residual:
                batch.x = batch_x + batch.x
            else:
                batch.x = batch_x
        # POOL
        pool_value = self.POOLING(
            batch=batch, sum_pooling=self.use_sum_pooling)
        return pool_value

    def MPNN_block_instnatiator(self, edge_out_dim):
        """Build one ``layers.MPNN_block`` for this model's configuration.

        Args:
            edge_out_dim: width of the encoded edge attributes fed to the
                block, or ``None`` when edge attributes are not used.

        Returns:
            The newly constructed ``layers.MPNN_block``.
        """
        mpnn = layers.MPNN_block(
            d=self.dim_embed, H=self.H, d_output=self.dim_embed,
            edge_dim=edge_out_dim, type=self.mpnn_type,
            use_linear=self.use_linear)
        return mpnn

    def get_edge_encoder_out_dim(self, layer_idx):
        """Return the edge-encoder output width for the given layer.

        Args:
            layer_idx: layer index, either as an ``int`` or as its decimal
                string form (the form used for the ``ModuleDict`` keys).

        Returns:
            ``6`` for the first layer when the dataset is ``"alchemy"``,
            otherwise ``self.dim_embed``.
        """
        # Normalise to int so that both 0 and "0" select the first layer; a
        # plain `layer_idx == 0` is always False for string indices.
        if int(layer_idx) == 0 and self.dataset == "alchemy":
            return 6
        return self.dim_embed

    def init_edge_encoder(self, use_linear, in_dim, out_dim):
        """Build a ``layers.Bond`` edge encoder.

        Args:
            use_linear: forwarded to ``layers.Bond`` as ``linear``.
            in_dim: forwarded to ``layers.Bond`` as ``linear_in_dim``.
            out_dim: forwarded to ``layers.Bond`` as ``dim``.

        Returns:
            The newly constructed ``layers.Bond``.
        """
        edge_encoder = layers.Bond(
            dim=out_dim, linear=use_linear, linear_in_dim=in_dim)
        return edge_encoder


class Pooling_mpnn(nn.Module):
    """Final pooling: pool node features per graph, then apply an MLP.

    Args:
        idim (int): input dimension
        odim (int): output dimension

    """

    def __init__(self, idim: int, odim: int):
        super().__init__()

        # Prediction head applied to the pooled representation.
        self.predict = layers.MLP(idim, odim, hdim=idim*2, norm=False)

    def forward(self, batch, sum_pooling):
        """Pool ``batch.x`` using ``batch.batch`` and run the prediction head.

        Args:
            batch: object exposing ``x`` and ``batch``.
            sum_pooling: if truthy use sum pooling, otherwise mean pooling.

        Returns:
            The output of ``self.predict`` on the pooled features.
        """
        # torch_geometric exposes the sum readout as `global_add_pool`;
        # it has no `global_sum_pool`, so the previous sum branch raised
        # AttributeError whenever it was selected.
        pool_fn = gnn.global_add_pool if sum_pooling else gnn.global_mean_pool
        return self.predict(pool_fn(batch.x, batch.batch))
