# coding=utf-8
# Copyright 2020 The Attribution Gnn Benchmarking Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Attribution tasks."""
import collections
import enum
from typing import (Any, Callable, List, MutableMapping, Optional, Text, Tuple,
                    Union)

import numpy as np
import pandas as pd
import tensorflow as tf
from rdkit import Chem

import attribution_metrics as att_metrics
import fragment_identifier
import graph_nets
import graphs as graph_utils
import sklearn
import sklearn.metrics
import sklearn.preprocessing
import templates

# Typing aliases.
GraphsTuple = templates.GraphsTuple
Mol = Chem.Mol

# Default CSV of Crippen subgraph contributions, read by `CrippenLogPTask` when
# `load_data=True` and `data_path` is None. NOTE: this is a relative path, so it
# resolves against the current working directory.
CRIPPEN_PATH = 'data/crippen/crippen_subgraph_contributions.csv'


def _get_mol_sender_receivers(mol):
  """Get connectivity (messages) info for a data_dict."""
  senders, receivers = [], []
  for bond in mol.GetBonds():
    id1 = bond.GetBeginAtom().GetIdx()
    id2 = bond.GetEndAtom().GetIdx()
    senders.extend([id1, id2])
    receivers.extend([id2, id1])
  return np.array(senders), np.array(receivers)


def _make_attribution_from_nodes(mol, nodes, global_vec):
  """Wraps node values and a global vector into a single-graph GraphsTuple.

  Connectivity is taken from the molecule's bonds via
  `_get_mol_sender_receivers`; node and global features are cast to float32.

  Args:
    mol: Molecule whose bonds define the senders/receivers.
    nodes: Array of node features (presumably one row per atom).
    global_vec: Array of graph-level features.

  Returns:
    A GraphsTuple holding exactly one graph.
  """
  sender_ids, receiver_ids = _get_mol_sender_receivers(mol)
  graph_dict = dict(
      nodes=nodes.astype(np.float32),
      senders=sender_ids,
      receivers=receiver_ids,
      globals=global_vec.astype(np.float32),
  )
  return graph_nets.utils_np.data_dicts_to_graphs_tuple([graph_dict])


def get_crippen_features(
    mol):
  """Calculate per-atom Crippen logP contributions and atom typing.

  Args:
    mol: RDKit molecule. It must already contain all hydrogens explicitly.

  Returns:
    Tuple `(logp, atom_types, atom_labels)`. `logp` is a 1D float array with
    one logP contribution per atom (empty for an atom-less molecule).
    `atom_types` and `atom_labels` are the lists handed to RDKit to be
    populated with each atom's Crippen type and type label.

  Raises:
    ValueError: If adding hydrogens changes the atom count, i.e. `mol` does not
      have explicit hydrogens.
  """
  n_atoms = mol.GetNumAtoms()
  n_atoms_with_h = Chem.AddHs(mol).GetNumAtoms()
  if n_atoms != n_atoms_with_h:
    raise ValueError('Your molecule might not have explicit hydrogens!')
  # Pre-sized output lists that RDKit fills in place.
  atom_types = [None] * n_atoms
  atom_labels = [None] * n_atoms
  contribs = Chem.rdMolDescriptors._CalcCrippenContribs(
      mol,
      force=True,
      atomTypes=atom_types,
      atomTypeLabels=atom_labels)
  # Each entry is (logP, molar refractivity); only logP is kept. A comprehension
  # is used instead of `zip(*contribs)` unpacking, which fails when there are
  # no atoms (and therefore no contributions).
  logp = np.array([contrib[0] for contrib in contribs], dtype=float)
  return logp, atom_types, atom_labels


class CrippenLogPTask(templates.AttributionTask):
  """CrippenLogP weighting task.

  Crippen's model for LogP (octanol/water partition coefficient) treats LogP as
  a weighted sum of per-atom contributions. Each atom is classified by its local
  graph neighborhood, and its contribution is the score assigned to that class
  label. Better models exist for actual LogP; this task is only used as a
  real-world example for a synthetic attribution task. The score is:

    CrippenLogP = sum_i w(sub_graph(atom_i))

  where w maps an atom's local subgraph (its Crippen class) to a contribution.
  The ground-truth attributions are the per-atom contributions w.

  Task based on "Prediction of Physicochemical Parameters by
  Atomic Contributions" (https://pubs.acs.org/doi/10.1021/ci990307l).
  """

  def __init__(self, load_data=False, data_path=None):
    """CrippenLogPTask.

    Args:
      load_data: Whether to load additional data related to atom labels into
        `self.data`.
      data_path: CSV file to load when `load_data` is True. If None, loads
        from the default `CRIPPEN_PATH`.
    """
    self.data = None
    if load_data:
      # Honor an explicitly given path; use the default only when none is given.
      path = CRIPPEN_PATH if data_path is None else data_path
      self.data = pd.read_csv(path)

  @property
  def name(self):
    """Task name."""
    return 'CrippenLogP'

  def get_nn_activation_fn(self):
    """Activation useful for NN model building (identity, for regression)."""
    return tf.identity

  def get_nn_loss_fn(self):
    """Loss function useful for NN model training (mean squared error)."""

    def loss_fn(y_true, y_pred):
      return tf.reduce_mean(tf.losses.mean_squared_error(y_true, y_pred))

    return loss_fn

  def get_true_predictions(self, mols):
    """Gets Crippen logP values as a (len(mols), 1) array.

    Each value is the sum of the molecule's per-atom Crippen contributions.
    """
    values = [sum(get_crippen_features(mol)[0]) for mol in mols]
    return np.array(values).reshape(-1, 1)

  def evaluate_predictions(self, y_true, y_pred):
    """Scores predictions and returns an OrderedDict of (metric_name, value).

    Metrics are R2, RMSE, Kendall's tau and Pearson's r.
    """
    values = [('R2', sklearn.metrics.r2_score(y_true, y_pred)),
              ('RMSE', att_metrics.rmse(y_true, y_pred)),
              ('tau', att_metrics.kendall_tau_score(y_true, y_pred)),
              ('r', att_metrics.pearson_r_score(y_true, y_pred))]
    return collections.OrderedDict(values)

  def get_true_attributions(self, mols):
    """Gets per-atom Crippen contributions for each molecule as a GraphsTuple.

    Returns:
      One single-graph GraphsTuple per molecule. Nodes hold the per-atom logP
      contributions, edges are None, and connectivity comes from the bonds.
    """
    graph_list = []
    for mol in mols:
      logp = np.array(get_crippen_features(mol)[0])
      senders, receivers = _get_mol_sender_receivers(mol)
      data_dict = {
          'nodes': logp,
          'edges': None,
          'senders': senders,
          'receivers': receivers
      }
      graph_list.append(
          graph_nets.utils_np.data_dicts_to_graphs_tuple([data_dict]))
    return graph_list

  def preprocess_attributions(self, atts):
    """Prepare attributions for visualization or evaluation.

    Applies `graph_utils.reduce_sum_edges` and then `graph_utils.cast_to_np`
    to each attribution.
    """
    new_atts = []
    for att in atts:
      att = graph_utils.cast_to_np(graph_utils.reduce_sum_edges(att))
      new_atts.append(att)
    return new_atts

  def evaluate_attributions(self, true_atts, pred_atts,
                            reducer_fn=np.nanmean):
    """Scores attributions, return dict of (metric, reduce_fn(values)).

    Only `pred_atts` are preprocessed; `true_atts` are used as given (e.g. the
    output of `get_true_attributions`). Per-graph node-wise Kendall tau and
    Pearson r are each aggregated with `reducer_fn`.
    """
    pred_atts = self.preprocess_attributions(pred_atts)
    stats = collections.OrderedDict()
    stats['ATT tau'] = reducer_fn(
        att_metrics.nodewise_kendall_tau_score(true_atts, pred_atts))
    stats['ATT r'] = reducer_fn(
        att_metrics.nodewise_pearson_r_score(true_atts, pred_atts))
    return stats


class FragmentLogicTask(templates.AttributionTask):
  """Base Fragment logic task.

  General purpose fragment identification task. Given a fragment identification
  rule, it sets up an attribution task for binary classification of whether a
  graph obeys the rule or not. Supports complex logic like &, | and
  multi-fragments.
  """

  def __init__(self, rule):
    """Constructor for base FragmentLogicTask.

    Args:
      rule: Fragment rule. Its `label` is the task name, and its `match(mol)`
        result defines both the labels and the ground-truth attributions.
    """
    self.rule = rule

  @property
  def name(self):
    """Task name, taken from the rule's label."""
    return self.rule.label

  def get_true_predictions(self, x):
    """Returns a (len(x), 1) float32 array: 1.0 where the rule matches."""
    binary_matches = [bool(self.rule.match(m)) for m in x]
    return np.array(binary_matches, dtype=np.float32).reshape(-1, 1)

  def get_nn_activation_fn(self):
    """Sigmoid activation, giving probabilities for binary classification."""
    return tf.nn.sigmoid

  def get_nn_loss_fn(self):
    """Binary cross-entropy loss, averaged with `tf.reduce_mean`."""

    def loss_fn(y_true, y_pred):
      return tf.reduce_mean(tf.losses.binary_crossentropy(y_true, y_pred))

    return loss_fn

  def evaluate_predictions(self, y_true, y_prob):
    """Scores predicted probabilities; returns an OrderedDict of metrics.

    AUROC uses the raw probabilities. F1 and ACC use predictions binarized at
    the threshold returned by `att_metrics.get_optimal_threshold`.
    """
    p_tol = att_metrics.get_optimal_threshold(y_true, y_prob)
    y_pred = (y_prob >= p_tol).astype(np.float32)
    return collections.OrderedDict([
        ('AUROC', att_metrics.nan_auroc_score(y_true, y_prob)),
        ('F1', att_metrics.nan_f1_score(y_true, y_pred)),
        ('ACC', att_metrics.accuracy_score(y_true, y_pred))
    ])

  def get_true_attributions(self, mols):
    """Gets fragments matches and converts them to multi-truth attributions.

    Node arrays have one row per atom and one column per truth. When the rule
    does not match, a single all-zero column is used. The global vector flags,
    per truth, whether any atom is set.
    """
    atts = []
    for mol in mols:
      n_atoms = mol.GetNumAtoms()
      matches = self.rule.match(mol)
      if matches:
        nodes = np.array(matches).T.astype(np.float32)
      else:
        nodes = np.zeros((n_atoms, 1), dtype=np.float32)
      # Sum along the atom axis. Builtin `sum` would return the scalar 0 for a
      # molecule with no atoms instead of a per-truth vector.
      global_vec = np.asarray(nodes.sum(axis=0) > 0)
      atts.append(_make_attribution_from_nodes(mol, nodes, global_vec))

    return atts

  def preprocess_attributions(self, atts, positive=False, normalize=False):
    """Prepare attributions for visualization or evaluation.

    Args:
      atts: Attributions to prepare.
      positive: Passed to `normalize_attributions` when `normalize` is True.
      normalize: Whether to quantile-normalize nodes to the 0-1 range.

    Returns:
      New list of attributions with 1D numpy node arrays.
    """
    new_atts = []
    for att in atts:
      # If the attribution is 2D (multi-truth), pick the last truth.
      if att.nodes.ndim > 1:
        att = att.replace(nodes=att.nodes[:, -1])
      att = graph_utils.cast_to_np(graph_utils.reduce_sum_edges(att))
      new_atts.append(att)

    if normalize:
      new_atts = self.normalize_attributions(new_atts, positive)

    return new_atts

  def normalize_attributions(self, atts, positive=False):
    """Normalize all nodes to 0 to 1 range via quantiles.

    A single QuantileTransformer is fit on the node values pooled across all
    `atts` (only values > 0 when `positive` is True). It is then applied to
    each attribution's nodes.

    NOTE(review): with `positive=True` and no value > 0, the fit receives an
    empty array, which sklearn is expected to reject. Confirm whether callers
    can hit this.
    """
    all_values = np.concatenate([att.nodes for att in atts])
    all_values = all_values[all_values > 0] if positive else all_values

    normalizer = sklearn.preprocessing.QuantileTransformer()
    normalizer.fit(all_values.reshape(-1, 1))
    new_atts = []
    for att in atts:
      normed_nodes = normalizer.transform(att.nodes.reshape(-1, 1)).ravel()
      new_atts.append(att.replace(nodes=normed_nodes))
    return new_atts

  def evaluate_attributions(self, atts_true, atts_pred,
                            reducer_fn=np.nanmean):
    """Scores attributions, return dict of (metric, reduce_fn(values)).

    Predictions are quantile-normalized first. The binarization threshold is
    derived from the last truth of each true attribution. AUROC, F1 and ACC
    are then computed against the full (multi-truth) `atts_true`.
    """
    atts_probs = self.preprocess_attributions(atts_pred, normalize=True)
    atts_true_last = self.preprocess_attributions(atts_true)
    atts_binary = att_metrics.get_opt_binary_attributions(
        atts_true_last, atts_probs)
    stats = collections.OrderedDict()
    stats['ATT AUROC'] = reducer_fn(
        att_metrics.attribution_auroc(atts_true, atts_probs))
    stats['ATT F1'] = reducer_fn(
        att_metrics.attribution_f1(atts_true, atts_binary))
    stats['ATT ACC'] = reducer_fn(
        att_metrics.attribution_accuracy(atts_true, atts_binary))
    return stats


# Aliases for shorter task definitions.
_frag_rule = fragment_identifier.BasicFragmentRule


def _and_rules(rules):
  """Builds an 'AND' CompositeRule over `rules`."""
  return fragment_identifier.CompositeRule('AND', rules)


benzene_task = FragmentLogicTask(_frag_rule('benzene', 'c1ccccc1'))

# NOTE: 'flouride' is a misspelling of 'fluoride'. It is left unchanged because
# rule labels are runtime data (they may feed task names and saved results).
logic7_task = FragmentLogicTask(_and_rules([
    _frag_rule('flouride', '[FX1]'),
    _frag_rule('carbonyl', '[CX3]=O'),
]))

logic8_task = FragmentLogicTask(_and_rules([
    _frag_rule('unbranched alkane', '[R0;D2,D1][R0;D2][R0;D2,D1]'),
    _frag_rule('carbonyl', '[CX3]=O'),
]))

logic10_task = FragmentLogicTask(_and_rules([
    _frag_rule('amine', '[NX3;H2]'),
    _frag_rule('ether2', '[OD2](C)C'),
    _frag_rule('benzene', '[cX3]1[cX3H][cX3H][cX3H][cX3H][cX3H]1'),
]))


class Task(enum.Enum):
  """Types of task readily implemented.

  Member values are the strings accepted by `get_task`, which maps each member
  to its task object.
  """
  benzene = 'benzene'  # Benzene ring fragment (`benzene_task`).
  logic7 = 'logic7'  # Fluoride AND carbonyl fragments (`logic7_task`).
  logic8 = 'logic8'  # Unbranched alkane AND carbonyl (`logic8_task`).
  logic10 = 'logic10'  # Amine AND ether AND benzene (`logic10_task`).
  crippen = 'crippen'  # Crippen logP regression (`CrippenLogPTask`).


def get_task(task):
  """Retrieve a task by its enum member or name.

  Args:
    task: A `Task` member or its string value (e.g. 'benzene', 'crippen').

  Returns:
    The matching task object. Fragment tasks are the shared module-level
    instances. A new `CrippenLogPTask()` is built on each call.

  Raises:
    ValueError: If `task` is not a valid `Task` member or value.
  """
  task = Task(task)
  if task is Task.crippen:
    # Built only on demand, rather than on every lookup of any task.
    return CrippenLogPTask()
  fragment_tasks = {
      Task.benzene: benzene_task,
      Task.logic7: logic7_task,
      Task.logic8: logic8_task,
      Task.logic10: logic10_task,
  }
  return fragment_tasks[task]
