# coding=utf-8
# Copyright 2020 The Attribution Gnn Benchmarking Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Implementations of graph neural networks, building on graph_nets."""
import enum
from typing import Callable, List, Text, Tuple, Union

import numpy as np
import tensorflow as tf

import graph_nets
import sonnet as snt
import templates

GraphsTuple = graph_nets.graphs.GraphsTuple
Activation = templates.Activation


@enum.unique
class ModelType(enum.Enum):
  """Enumeration of the GNN model types that can be selected by name.

  Every member's value is the string spelling of its own name, so a member can
  be looked up from its string, e.g. `ModelType('gcn') is ModelType.gcn`.
  """
  gcn = 'gcn'
  mpnn = 'mpnn'
  graphnet = 'graphnet'


def print_model(model):
  """Print a summary of a model's variables and parameter counts.

  Writes to stdout: the model's class name and `name`, the table produced by
  `snt.format_variables`, and the number of trainable parameters out of the
  total number of parameters.

  Args:
    model: Object exposing `name`, `variables` and `trainable_variables`
      (presumably an `snt.Module`). Every variable must expose `shape`.
  """

  def count_params(variables):
    # Cast each per-variable size to int and use the built-in sum: np.prod of
    # an empty (scalar) shape is the float 1.0 and np.sum of an empty list is
    # the float 0.0, either of which would make the totals print as floats.
    return sum(int(np.prod(v.shape)) for v in variables)

  print(f'{model.__class__.__name__} : {model.name}\n')
  print(snt.format_variables(model.variables))
  n_params = count_params(model.variables)
  trainable_params = count_params(model.trainable_variables)
  print(f'\nParams: {trainable_params} trainable out of {n_params}')


def cast_activation(act):
  """Map string to activation, or just pass the activation function.

  Args:
    act: Either a callable, which is returned unchanged, or one of the names
      'selu', 'softplus', 'relu', 'leaky_relu', 'tanh', 'sigmoid', 'softmax'
      or 'identity'.

  Returns:
    The activation callable.

  Raises:
    KeyError: If `act` is not callable and is not a known activation name.
  """
  # Callables need no lookup, so return before building the name table.
  if callable(act):
    return act

  activations = {
      'selu': tf.nn.selu,
      'softplus': tf.nn.softplus,
      'relu': tf.nn.relu,
      'leaky_relu': tf.nn.leaky_relu,
      'tanh': tf.nn.tanh,
      'sigmoid': tf.nn.sigmoid,
      'softmax': tf.nn.softmax,
      'identity': tf.identity
  }
  try:
    return activations[act]
  except KeyError:
    # Same exception type as a plain dict lookup, but with the valid choices
    # spelled out so a typo in a config is easy to diagnose.
    raise KeyError(
        f'Unknown activation {act!r}; expected a callable or one of '
        f'{sorted(activations)}.') from None


def get_mlp_fn(layer_sizes, act='relu'):
  """Return a zero-argument factory for an MLP followed by LayerNorm.

  Nothing is built until the returned factory is called, and every call to the
  factory creates fresh modules.

  Args:
    layer_sizes: Layer output sizes, forwarded to `snt.nets.MLP`.
    act: Activation name or callable, resolved with `cast_activation` each
      time the factory runs.

  Returns:
    A callable taking no arguments that returns an `snt.Sequential` holding the
    MLP (with its final layer activated) and a LayerNorm over the last axis.
  """

  def make_mlp():
    activation = cast_activation(act)
    mlp = snt.nets.MLP(
        layer_sizes, activation=activation, activate_final=True)
    norm = snt.LayerNorm(axis=-1, create_scale=True, create_offset=True)
    return snt.Sequential([mlp, norm])

  return make_mlp


class ReadoutGAP(snt.Module):
  """Readout that pools embedded node and edge features into the globals.

  Nodes and edges each go through their own linear layer plus activation, are
  reduced to globals by graph_nets aggregators, and the two results are added.
  NOTE: although the name refers to global average pooling, the reducer handed
  to both aggregators is `tf.math.unsorted_segment_sum`.
  """

  def __init__(self, global_size, activation, name='ReadoutGAP'):
    """Creates the aggregators and the node/edge embedding layers.

    Args:
      global_size: Output size of the node and edge linear layers.
      activation: Activation name or callable, resolved by `cast_activation`.
      name: Module name.
    """
    super().__init__(name=name)
    pool_fn = tf.math.unsorted_segment_sum
    self.node_reducer = graph_nets.blocks.NodesToGlobalsAggregator(pool_fn)
    self.edge_reducer = graph_nets.blocks.EdgesToGlobalsAggregator(pool_fn)

    def make_embedding():
      # A new Linear per call, so nodes and edges get separate layers.
      return snt.Sequential(
          [snt.Linear(global_size), cast_activation(activation)])

    self.node_emb = make_embedding()
    self.edge_emb = make_embedding()

  def get_activations(self, graph):
    """Returns the embedded (pre-pooling) nodes and edges as a pair."""
    node_acts = self.node_emb(graph.nodes)
    edge_acts = self.edge_emb(graph.edges)
    return node_acts, edge_acts

  def __call__(self, inputs):
    """Returns `inputs` with globals replaced by the pooled embeddings.

    Args:
      inputs: Graph exposing `nodes`, `edges` and `replace`.

    Returns:
      A graph with the original nodes and edges of `inputs` and new globals
      equal to the pooled node embeddings plus the pooled edge embeddings.
    """
    node_acts, edge_acts = self.get_activations(inputs)
    # The embeddings are only swapped in temporarily so the aggregators can
    # reduce them; the returned graph keeps the input nodes and edges.
    embedded = inputs.replace(nodes=node_acts, edges=edge_acts)
    pooled = self.node_reducer(embedded) + self.edge_reducer(embedded)
    return inputs.replace(globals=pooled)


class NodesAggregator(snt.Module):
  """Aggregates neighboring nodes based on sender and receiver indices.

  For every edge the sender's node features are gathered, then reduced into
  the slot of that edge's receiver.
  """

  def __init__(self,
               reducer=tf.math.unsorted_segment_sum,
               name='nodes_aggregator'):
    """Stores the reducer.

    Args:
      reducer: Callable invoked as `reducer(data, segment_ids, num_segments)`.
      name: Module name.
    """
    super().__init__(name=name)
    self.reducer = reducer

  def __call__(self, graph):
    """Reduces sender node features onto their receivers.

    Args:
      graph: Graph exposing `nodes`, `senders`, `receivers` and `n_node`.

    Returns:
      The reducer's output, computed with `sum(graph.n_node)` segments.
    """
    sender_features = tf.gather(graph.nodes, graph.senders)
    total_nodes = tf.reduce_sum(graph.n_node)
    return self.reducer(sender_features, graph.receivers, total_nodes)


class NodeLayer(graph_nets.blocks.NodeBlock):
  """`NodeBlock` that is always constructed with `use_globals=False`.

  All other constructor arguments are forwarded unchanged. Because
  `use_globals` is supplied here, passing it as a keyword as well raises a
  `TypeError` for the duplicated argument.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, use_globals=False, **kwargs)


class GCNLayer(graph_nets.blocks.NodeBlock):
  """Node-updating layer that also feeds aggregated neighbor nodes to the model.

  Built as a `NodeBlock` with `use_globals=False`, plus a `NodesAggregator`
  whose output is always appended to the node model's input.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, use_globals=False, **kwargs)
    self.gather_nodes = NodesAggregator()

  def __call__(self, graph):
    """Collect nodes, adjacent nodes, edges and update to get new nodes.

    The node model input is the last-axis concatenation, in this order, of:
    aggregated sent edges, aggregated received edges, the nodes themselves
    (each only if the matching `_use_*` flag is set), the aggregated adjacent
    nodes (always), and broadcast globals (only if `_use_globals` is set).

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s, whose individual edges
        features (if `use_received_edges` or `use_sent_edges` is `True`),
        individual nodes features (if `use_nodes` is True) and per graph globals
        (if `use_globals` is `True`) should be concatenable on the last axis.

    Returns:
      An output `graphs.GraphsTuple` with updated nodes.
    """
    features = []

    if self._use_sent_edges:
      features.append(self._sent_edges_aggregator(graph))

    if self._use_received_edges:
      features.append(self._received_edges_aggregator(graph))

    if self._use_nodes:
      features.append(graph.nodes)

    # Neighbor aggregation is unconditional; it is what this layer adds.
    features.append(self.gather_nodes(graph))

    if self._use_globals:
      # NOTE(review): the constructor passes use_globals=False, so this branch
      # is presumably never taken -- confirm against NodeBlock.
      # The hint is an integer when the graph has node features and the total
      # number of nodes is known at tensorflow graph definition time,
      # otherwise None.
      static_num_nodes = graph_nets.blocks._get_static_num_nodes(graph)
      broadcast_globals = graph_nets.blocks.broadcast_globals_to_nodes(
          graph, num_nodes_hint=static_num_nodes)
      features.append(broadcast_globals)

    new_nodes = self._node_model(tf.concat(features, axis=-1))
    return graph.replace(nodes=new_nodes)


class NodeEdgeLayer(snt.Module):
  """GNN layer that applies an edge block and then a node block.

  Both blocks are created with `use_globals=False`.
  """

  def __init__(self, node_model_fn, edge_model_fn, name='NodeEdgeLayer'):
    """Creates the edge and node blocks.

    Args:
      node_model_fn: Passed to `graph_nets.blocks.NodeBlock` as `node_model_fn`.
      edge_model_fn: Passed to `graph_nets.blocks.EdgeBlock` as `edge_model_fn`.
      name: Module name.
    """
    super().__init__(name=name)
    blocks = graph_nets.blocks
    self.edge_block = blocks.EdgeBlock(
        use_globals=False, edge_model_fn=edge_model_fn)
    self.node_block = blocks.NodeBlock(
        use_globals=False, node_model_fn=node_model_fn)

  def __call__(self, graph):
    """Returns the node block's output on the edge block's output."""
    after_edge_block = self.edge_block(graph)
    return self.node_block(after_edge_block)


class SequentialWithActivations(snt.Sequential):
  """`snt.Sequential` with a call variant exposing every layer's output."""

  def call_with_activations(self, inputs, *args, **kwargs):
    """Runs the layers in order and records each layer's output.

    Args:
      inputs: Input given to the first layer.
      *args: Extra positional arguments, passed to the first layer only.
      **kwargs: Extra keyword arguments, passed to the first layer only.

    Returns:
      A pair `(outputs, acts)`: `outputs` is the last layer's output (or
      `inputs` itself when there are no layers) and `acts` is the list of all
      layer outputs in execution order.
    """
    current = inputs
    recorded = []
    extra_args, extra_kwargs = args, kwargs
    for layer in self._layers:
      current = layer(current, *extra_args, **extra_kwargs)
      recorded.append(current)
      # Only the first layer receives the additional arguments.
      extra_args, extra_kwargs = (), {}
    return current, recorded
