# coding=utf-8
# Copyright 2020 The Attribution Gnn Benchmarking Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Functions for running experiments (colabs) and storing/saving data."""
import collections
import dataclasses
import os
from typing import (Any, Callable, List, MutableMapping, Optional, Text, Tuple,
                    Union)

import more_itertools
import numpy as np
import pandas as pd
import rdkit.Chem
import tensorflow.compat.v2 as tf

import datasets
import graph_nets
import graphnet_models as models
import graphnet_techniques as techniques
import graphs as graph_utils
import sonnet as snt
import tasks
import templates

# Typing aliases for third-party types used in annotations below.
GraphsTuple = graph_nets.graphs.GraphsTuple
Mol = rdkit.Chem.Mol
# Typing aliases for the interfaces declared in the project `templates` module.
TransparentModel = templates.TransparentModel
AttributionTechnique = templates.AttributionTechnique
AttributionTask = templates.AttributionTask
NodeEdgeTensors = templates.NodeEdgeTensors
# Container aliases: technique name -> technique, and a generic ordered mapping.
MethodDict = MutableMapping[Text, AttributionTechnique]
OrderedDict = MutableMapping


def set_seed(random_seed):
  """Seeds the global TensorFlow and NumPy random number generators.

  Args:
    random_seed: seed value forwarded unchanged to both libraries.
  """
  # TensorFlow is seeded first, then NumPy.
  for seed_fn in (tf.random.set_seed, np.random.seed):
    seed_fn(random_seed)


def get_graph_layer(layer_type, node_size,
                    edge_size, global_size, index):
  """Builds one message-passing layer of the requested kind.

  Args:
    layer_type: a `models.ModelType` member; `gcn`, `mpnn` and `graphnet`
      are supported.
    node_size: width passed to `models.get_mlp_fn` for node updates.
    edge_size: width passed to `models.get_mlp_fn` for edge updates.
    global_size: width passed to `models.get_mlp_fn` for global updates
      (only used by `graphnet`).
    index: zero-based layer position. The layer is named
      `<layer_type.name>_<index + 1>`.

  Returns:
    The constructed layer module.

  Raises:
    ValueError: if `layer_type` is not one of the supported kinds.
  """
  layer_name = f'{layer_type.name}_{index+1}'

  def two_layer_mlp(width):
    # Each update function is requested as [width, width].
    return models.get_mlp_fn([width] * 2)

  if layer_type == models.ModelType.gcn:
    return models.GCNLayer(two_layer_mlp(node_size), name=layer_name)

  if layer_type == models.ModelType.mpnn:
    return models.NodeEdgeLayer(
        two_layer_mlp(node_size),
        two_layer_mlp(edge_size),
        name=layer_name)

  if layer_type == models.ModelType.graphnet:
    # Globals are not used as block inputs for the very first layer
    # (index 0). Presumably this is because `GNN.encode` has no global
    # model; confirm.
    globals_enabled = index != 0
    return graph_nets.modules.GraphNetwork(
        node_model_fn=two_layer_mlp(node_size),
        edge_model_fn=two_layer_mlp(edge_size),
        global_model_fn=two_layer_mlp(global_size),
        edge_block_opt={'use_globals': globals_enabled},
        node_block_opt={'use_globals': globals_enabled},
        global_block_opt={'use_globals': globals_enabled},
        name=layer_name)

  raise ValueError(f'layer_type={layer_type} not implemented')


class GNN(snt.Module, templates.TransparentModel):
  """A general graph neural network for graph property prediction.

  The forward pass has four stages:
    1. `encode`: per-node and per-edge linear maps.
    2. `gnn`: `n_layers` layers built by `get_graph_layer`.
    3. `gap`: a readout whose `globals` are used as the graph embedding.
    4. `pred_layer`: a bias-free linear layer followed by `activation`.

  `task_index` (default 0) selects the output column used by `predict`,
  `get_gradient`, `get_prediction_weights` and
  `get_intermediate_activations_gradients`.

  NOTE(review): `task_index` is a plain Python attribute that is read inside
  `tf.function`-decorated methods. TF likely bakes its value in at trace
  time, so changing it after the first call may not take effect without
  retracing. Confirm this before relying on mutation.
  """

  def __init__(self,
               node_size,
               edge_size,
               global_size,
               y_output_size,
               layer_type,
               activation,
               n_layers=3):
    """Builds all sub-modules.

    Args:
      node_size: output width of the node encoder. Also the width passed to
        `models.get_mlp_fn` for node updates.
      edge_size: output width of the edge encoder. Also the width passed to
        `models.get_mlp_fn` for edge updates.
      global_size: forwarded to `get_graph_layer` and `models.ReadoutGAP`.
      y_output_size: number of output columns of the prediction layer.
      layer_type: `models.ModelType` member selecting the layer kind. Its
        `.name` is also used as this module's name.
      activation: converted with `models.cast_activation` and applied to
        the linear output.
      n_layers: number of message-passing layers.
    """
    super(GNN, self).__init__(name=layer_type.name)
    self.task_index = 0
    # Graph encoding step, basic linear mapping.
    self.encode = graph_nets.modules.GraphIndependent(
        node_model_fn=lambda: snt.Linear(node_size),
        edge_model_fn=lambda: snt.Linear(edge_size))

    # Message passing steps or GNN blocks.
    gnn_layers = [
        get_graph_layer(layer_type, node_size, edge_size, global_size, index)
        for index in range(0, n_layers)
    ]
    self.gnn = models.SequentialWithActivations(gnn_layers)
    self.gap = models.ReadoutGAP(global_size, tf.nn.softmax)
    self.linear = snt.Linear(y_output_size, with_bias=False)
    self.activation = models.cast_activation(activation)
    # Prediction head: bias-free linear layer followed by the activation.
    self.pred_layer = snt.Sequential([self.linear, self.activation])

  @tf.function(experimental_relax_shapes=True)
  def get_graph_embedding(self, x):
    """Returns the readout `globals` of the encoded, message-passed graph."""
    out_graph = self.gap(self.gnn(self.encode(x)))
    return out_graph.globals

  def __call__(self, x):
    """Typical forward pass; returns all `y_output_size` output columns."""
    graph_emb = self.get_graph_embedding(x)
    y = self.pred_layer(graph_emb)
    return y

  @tf.function(experimental_relax_shapes=True)
  def predict(self, x):
    """Forward pass restricted to the task of interest (y[:, task_index])."""
    return self(x)[:, self.task_index]

  @tf.function(experimental_relax_shapes=True)
  def get_gradient(self, x):
    """Gets gradients of the `predict` output w.r.t. input nodes and edges.

    The target returned by `predict` is a vector, and TF differentiates the
    sum of a non-scalar target.

    Returns:
      A tuple `(nodes_grad, edges_grad)`.
    """
    with tf.GradientTape(watch_accessed_variables=False) as gtape:
      gtape.watch([x.nodes, x.edges])
      y = self.predict(x)
    nodes_grad, edges_grad = gtape.gradient(y, [x.nodes, x.edges])
    return nodes_grad, edges_grad

  @tf.function(experimental_relax_shapes=True)
  def get_gap_activations(self, x):
    """Gets node-wise and edge-wise contributions to graph embedding."""
    return self.gap.get_activations(self.gnn(self.encode(x)))

  def get_prediction_weights(self):
    """Gets the last-layer weight column for `task_index`.

    NOTE(review): Sonnet creates `self.linear.w` lazily, so this presumably
    needs at least one forward pass first. Confirm.
    """
    w = self.linear.w[:, self.task_index]
    return w

  @tf.function(experimental_relax_shapes=True)
  def get_intermediate_activations_gradients(
      self, x
  ):
    """Gets intermediate layer activations and their gradients.

    Returns:
      A tuple `(acts, grads, y)`:
        acts: a list of `(nodes, edges)` pairs, one per activation returned
          by `self.gnn.call_with_activations`.
        grads: the same structure as `acts`, holding the gradients of `y`.
        y: the prediction for column `self.task_index`.
    """
    # Only one `gradient` call is made, so a non-persistent tape is enough.
    # It releases its resources as soon as that call returns.
    with tf.GradientTape(watch_accessed_variables=False) as gtape:
      gtape.watch([x.nodes, x.edges])
      x = self.encode(x)
      outputs, acts = self.gnn.call_with_activations(x)
      out_graph = self.gap(outputs)
      # Same head as `__call__`: linear followed by activation.
      y = self.pred_layer(out_graph.globals)[:, self.task_index]
    acts = [(act.nodes, act.edges) for act in acts]
    grads = gtape.gradient(y, acts)
    return acts, grads, y


def get_batched_attributions(method,
                             model,
                             inputs,
                             batch_size = 2500):
  """Computes attributions chunk by chunk to bound memory use (e.g. for IG).

  Each chunk holds `ceil(batch_size / method.sample_size)` graphs, so the
  chunk size shrinks as the technique's per-graph sample count grows.

  Returns:
    A flat list: the concatenation of `method.attribute(chunk, model)` over
    consecutive index chunks of `inputs`, in chunk order.
  """
  num_graphs = graph_utils.get_num_graphs(inputs)
  graphs_per_chunk = int(np.ceil(batch_size / method.sample_size))
  attributions = []
  for index_chunk in more_itertools.chunked(range(num_graphs),
                                            graphs_per_chunk):
    graph_chunk = graph_utils.get_graphs_tf(inputs, np.array(index_chunk))
    attributions += method.attribute(graph_chunk, model)
  return attributions


def generate_result(model,
                    method,
                    task,
                    inputs,
                    y_true,
                    true_atts,
                    pred_atts = None,
                    reducer_fn = np.nanmean,
                    batch_size = 1000):
  """Evaluates one (model, technique, task) combination.

  Args:
    model: model exposing `predict(inputs)` and `name`.
    method: attribution technique exposing `name`. It is used to compute
      attributions only when `pred_atts` is None.
    task: task exposing `evaluate_attributions`, `evaluate_predictions` and
      `name`.
    inputs: graphs to attribute and predict on.
    y_true: ground-truth targets.
    true_atts: ground-truth attributions.
    pred_atts: precomputed attributions. If None, they are computed with
      `get_batched_attributions`.
    reducer_fn: forwarded to `task.evaluate_attributions`.
    batch_size: forwarded to `get_batched_attributions`.

  Returns:
    The dict from `task.evaluate_attributions`, updated with the prediction
    metrics and the 'Task', 'Technique' and 'Model' names.
  """
  attributions = pred_atts
  if attributions is None:
    attributions = get_batched_attributions(method, model, inputs, batch_size)
  metrics = task.evaluate_attributions(
      true_atts, attributions, reducer_fn=reducer_fn)
  # `predict` yields a 1D array; reshape it into a single column.
  predictions = model.predict(inputs).numpy()
  metrics.update(task.evaluate_predictions(y_true, predictions.reshape(-1, 1)))
  metrics.update({
      'Task': task.name,
      'Technique': method.name,
      'Model': model.name,
  })
  return metrics


@dataclasses.dataclass
class ExperimentData:
  """Train/test split of a dataset together with its ground truth.

  On construction, `__post_init__` raises `ValueError` if, within either
  split, the index array, molecule list, graph batch, labels and
  attributions do not all have the same length.
  """
  # Full dataset table, stored unsplit.
  df: pd.DataFrame
  # Featurizer returned by the dataset loader.
  tensorizer: Any
  # Indices into the full dataset selecting each split.
  train_index: np.ndarray
  test_index: np.ndarray
  # Per-split molecules, graphs, labels and true attributions, each
  # aligned with the corresponding index array.
  mol_train: List[Mol]
  mol_test: List[Mol]
  x_train: GraphsTuple
  x_test: GraphsTuple
  y_train: np.ndarray
  y_test: np.ndarray
  atts_train: List[GraphsTuple]
  atts_test: List[GraphsTuple]

  def __post_init__(self):
    """Checks that all train (then all test) containers agree in length."""
    per_split = (
        ('train', self.train_index, self.mol_train, self.x_train,
         self.y_train, self.atts_train),
        ('test', self.test_index, self.mol_test, self.x_test,
         self.y_test, self.atts_test),
    )
    for label, index, mols, graphs, labels, atts in per_split:
      lengths = (len(index), len(mols), graph_utils.get_num_graphs(graphs),
                 len(labels), len(atts))
      if len(set(lengths)) != 1:
        raise ValueError(f'{label} data lengths are different ({lengths})!')

  @classmethod
  def from_data_and_splits(cls, df, tensorizer, train_index, test_index,
                           mol_list, x, y, atts):
    """Slices full-dataset containers into train/test parts.

    Raises:
      ValueError: if `train_index` and `test_index` share any element.
    """
    shared = np.intersect1d(train_index, test_index)
    if shared.shape[0]:
      raise ValueError('train/test indices have overlap!.')

    def pick(items, index):
      # Python-level gather for list-like containers.
      return [items[i] for i in index]

    return cls(
        df=df,
        tensorizer=tensorizer,
        train_index=train_index,
        test_index=test_index,
        mol_train=pick(mol_list, train_index),
        mol_test=pick(mol_list, test_index),
        x_train=graph_utils.get_graphs_tf(x, np.array(train_index)),
        x_test=graph_utils.get_graphs_tf(x, np.array(test_index)),
        y_train=y[train_index],
        y_test=y[test_index],
        atts_train=pick(atts, train_index),
        atts_test=pick(atts, test_index))


def get_experiment_setup(
    task_type
):
  """Loads data, ground truth and attribution techniques for one task.

  Args:
    task_type: a `tasks.Task` member, or any value `tasks.Task(...)` accepts.

  Returns:
    A tuple `(exp, task, methods)`:
      exp: the train/test `ExperimentData`.
      task: the task object from `tasks.get_task`.
      methods: the techniques dict built from the tensorizer's null vectors.
  """
  task_enum = tasks.Task(task_type)
  # The loader's `use_h` flag is enabled only for the crippen task
  # (presumably "use hydrogens"; confirm in `datasets`).
  include_h = task_enum == tasks.Task.crippen
  task_dir = datasets.get_task_dir(task_enum)
  (df, x, mols, tensorizer, train_index,
   test_index) = datasets.load_molgraph_dataset(task_dir, task_enum.name,
                                                include_h)

  task = tasks.get_task(task_enum)
  labels = task.get_true_predictions(mols)
  true_atts = task.get_true_attributions(mols)
  exp = ExperimentData.from_data_and_splits(df, tensorizer, train_index,
                                            test_index, mols, x, labels,
                                            true_atts)
  null_vectors = tensorizer.get_null_vectors()
  methods = techniques.get_techniques_dict(*null_vectors)
  return exp, task, methods


def get_bias_experiment_setup(
    dataset_name, ab_mix
):
  """Builds one experiment per bias rule for a dataset under BIAS_DIR.

  Args:
    dataset_name: sub-directory of `datasets.BIAS_DIR` holding the data. It
      is also passed as the dataset name to the loader and the rule setup.
    ab_mix: forwarded to `datasets.interpolate_bias_train_indices` when
      choosing the shared training indices.

  Returns:
    A tuple `(exp_dict, task_dict, methods)`:
      exp_dict: `collections.OrderedDict` mapping rule name to its
        `ExperimentData`. All rules share one training split.
      task_dict: `collections.OrderedDict` mapping rule name to its
        `tasks.FragmentLogicTask`.
      methods: the techniques dict.
  """
  bias_dir = os.path.join(datasets.BIAS_DIR, dataset_name)
  (df, x, mols, tensorizer, loaded_train_index,
   test_index) = datasets.load_molgraph_dataset(bias_dir, dataset_name)

  shared_train_index = datasets.interpolate_bias_train_indices(
      df, loaded_train_index, ab_mix)
  rules = datasets.setup_bias_rules(bias_dir, dataset_name)
  test_index_by_rule = datasets.setup_bias_test_indices(df, test_index)

  task_dict = collections.OrderedDict()
  exp_dict = collections.OrderedDict()
  for name, rule in rules.items():
    rule_task = tasks.FragmentLogicTask(rule=rule)
    task_dict[name] = rule_task
    labels = rule_task.get_true_predictions(mols)
    true_atts = rule_task.get_true_attributions(mols)
    exp_dict[name] = ExperimentData.from_data_and_splits(
        df, tensorizer, shared_train_index, test_index_by_rule[name], mols, x,
        labels, true_atts)

  null_vectors = tensorizer.get_null_vectors()
  methods = techniques.get_techniques_dict(*null_vectors)
  return exp_dict, task_dict, methods
