# coding=utf-8
# Copyright 2020 The Attribution Gnn Benchmarking Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functionality for loading, setting up and saving datasets."""
import os
from typing import Dict, List, Optional, Text, Tuple

import numpy as np
import pandas as pd

import featurization
import fragment_identifier
import graph_nets
import tasks

# Typing alias: short name for the graph_nets GraphsTuple class, intended for
# use in type annotations.
GraphsTuple = graph_nets.graphs.GraphsTuple

# Root directory for attribution data; the dataset-bias data lives in a
# sub-directory of it.
ATTRIBUTION_DIR = 'data'
BIAS_DIR = os.path.join(ATTRIBUTION_DIR, 'dataset_bias')

# Default file names, keyed by the kind of artifact. Values that contain '{}'
# are templates which `get_output_filename` fills in with a name.
DEFAULT_FILES = {
    'hparams': 'hparams.flagfile',
    'init_state': 'init_state/checkpoint',
    'splits': '{}_traintest_indices.npz',
    'smiles': '{}_smiles.csv',
    'smarts': '{}_smarts.csv',
    'checkpoint': 'checkpoint',
    'saved_model': 'saved_model',
    'losses': 'losses.npz',
    'predictions': 'predictions.npz',
    'results': 'aggregate_results.csv',
    'attribution_metrics': '{}_attribution_metrics.npy',
    'attributions': '{}_raw_attribution_datadicts.npy',
}

def get_task_dir(t):
  """Return the data directory for a task.

  Args:
    t: Any value accepted by the `tasks.Task` constructor.

  Returns:
    `ATTRIBUTION_DIR` joined with the `name` attribute of `tasks.Task(t)`.
  """
  return os.path.join(ATTRIBUTION_DIR, tasks.Task(t).name)


def smiles_to_graphs_tuple(
    smiles_list,
    tensorizer):
  """Build one GraphsTuple from a list of SMILES.

  Args:
    smiles_list: SMILES inputs, handed unchanged to
      `tensorizer.transform_data_dict`.
    tensorizer: Object exposing `transform_data_dict`; its output is fed to
      graph_nets as a list of graph data dicts.

  Returns:
    The value returned by `graph_nets.utils_tf.data_dicts_to_graphs_tuple`
    for the tensorizer output.
  """
  return graph_nets.utils_tf.data_dicts_to_graphs_tuple(
      tensorizer.transform_data_dict(smiles_list))


def get_output_filename(fname_kwarg, work_dir,
                        name=None):
  """Get output file name using the default file naming scheme.

  Args:
    fname_kwarg: Key into `DEFAULT_FILES` selecting the kind of file.
    work_dir: Directory the file name is joined onto.
    name: Value substituted into templated file names (entries containing
      '{}'). It is ignored for non-templated entries and may be omitted for
      them.

  Returns:
    `work_dir` joined with the (possibly formatted) default file name.

  Raises:
    KeyError: If `fname_kwarg` is not a key of `DEFAULT_FILES`.
    ValueError: If the selected entry is a template and `name` is None.
  """
  fname = DEFAULT_FILES[fname_kwarg]
  is_template = '{' in fname and '}' in fname
  if not is_template:
    return os.path.join(work_dir, fname)
  # Templated entries (e.g. attributions, which are written per method) cannot
  # be resolved without a name, so fail loudly instead of emitting 'None_...'.
  if name is None:
    raise ValueError(f'{fname_kwarg} is formattable, missing name argument.')
  return os.path.join(work_dir, fname.format(name))


def load_train_test_indices(filename):
  """Read the train/test index arrays stored in a NumPy `.npz` archive.

  Args:
    filename: Path (or file-like object) handed to `np.load`. The archive must
      contain the arrays 'train_index' and 'test_index'.

  Returns:
    Tuple `(train_index, test_index)` with the two stored arrays.

  Raises:
    KeyError: If either array is missing from the archive.
  """
  # np.load returns a lazily-reading NpzFile that keeps the underlying file
  # open; the context manager closes it once both arrays have been read.
  with np.load(filename) as data:
    return data['train_index'], data['test_index']


def load_molgraph_dataset(exp_dir,
                          dataset_name,
                          use_h=False):
  """Load the data needed for a molecule graph problem.

  Reads the SMILES CSV and the train/test split archive that the default
  naming scheme places under `exp_dir` for `dataset_name`, converts every
  entry of the 'smiles' column to a molecule, and tensorizes the SMILES into
  a graphs tuple.

  Args:
    exp_dir: Directory containing the dataset files.
    dataset_name: Name used to fill in the templated default file names.
    use_h: Forwarded to `featurization.smiles_to_mol` as `infer_hydrogens`.

  Returns:
    Tuple `(df, x, mols, tensorizer, train_index, test_index)`: the dataframe
    (with an added 'mol' column), the graphs tuple built from the SMILES, the
    list of molecules, the tensorizer used, and the train and test indices.
  """
  smiles_path = get_output_filename('smiles', exp_dir, dataset_name)
  indices_path = get_output_filename('splits', exp_dir, dataset_name)

  def to_mol(smiles):
    # Shared by the 'mol' column and the tensorizer so both parse SMILES the
    # same way.
    return featurization.smiles_to_mol(smiles, infer_hydrogens=use_h)

  df = pd.read_csv(smiles_path)
  train_index, test_index = load_train_test_indices(indices_path)
  df['mol'] = df['smiles'].apply(to_mol)
  tensorizer = featurization.MolTensorizer(preprocess_fn=to_mol)
  graphs = smiles_to_graphs_tuple(df['smiles'].tolist(), tensorizer)
  return (df, graphs, df['mol'].tolist(), tensorizer, train_index,
          test_index)


def interpolate_bias_train_indices(df, train_indices,
                                   ab_mix):
  """Interpolate between two sets of positive labels for training.

  Positives are drawn from the rows flagged 'A&B' and the rows flagged 'A&~B'
  in proportions set by `ab_mix`, then capped at the number of negatives
  (rows flagged '~(A&B)').

  Args:
    df: DataFrame with boolean mask columns 'A&B', 'A&~B' and '~(A&B)'.
    train_indices: Positional row indices (used with `df.iloc`) selecting the
      training rows.
    ab_mix: Fraction in [0, 1]. 1.0 keeps all 'A&B' rows and no 'A&~B' rows,
      0.0 the opposite.

  Returns:
    List of index labels of `df`: the selected positives followed by all
    negatives. Labels equal positions only if `df` has a default index.

  Raises:
    ValueError: If `ab_mix` lies outside [0, 1] (including NaN).
  """
  # Outside [0, 1] one of the counts below turns negative and the slice would
  # silently drop elements from the end instead of taking a prefix.
  if not 0.0 <= ab_mix <= 1.0:
    raise ValueError(f'ab_mix must be in [0, 1], got {ab_mix}.')

  train_df = df.iloc[train_indices]

  def _labels_where(column):
    return train_df[train_df[column]].index.tolist()

  neg_indices = _labels_where('~(A&B)')
  a_and_b_indices = _labels_where('A&B')
  a_not_b_indices = _labels_where('A&~B')

  a_and_b_n = int(np.round(len(a_and_b_indices) * ab_mix))
  a_not_b_n = int(np.round(len(a_not_b_indices) * (1.0 - ab_mix)))
  pos_indices = a_and_b_indices[:a_and_b_n] + a_not_b_indices[:a_not_b_n]
  # Never return more positives than negatives.
  pos_indices = pos_indices[:len(neg_indices)]
  return pos_indices + neg_indices


def setup_bias_rules(
    exp_dir,
    dataset_name):
  """Build the fragment identification rules for a biased dataset experiment.

  Reads the dataset's SMARTS CSV (indexed by its 'role' column) and creates
  one basic rule per role 'A' and 'B' from that row's `smarts` and `label`
  fields, plus a composite 'AND' rule over both.

  Args:
    exp_dir: Directory containing the SMARTS file.
    dataset_name: Name used to fill in the templated SMARTS file name.

  Returns:
    Dict with keys 'A', 'B' and '(A & B)' mapping to the rule objects.
  """
  fragments = pd.read_csv(
      get_output_filename('smarts', exp_dir, dataset_name)).set_index('role')

  rules = {}
  for role in ('A', 'B'):
    row = fragments.loc[role]
    rules[role] = fragment_identifier.BasicFragmentRule(
        role, row.smarts, row.label)
  rules['(A & B)'] = fragment_identifier.CompositeRule(
      'AND', [rules['A'], rules['B']])
  return rules


def setup_bias_test_indices(df,
                            test_index):
  """Collect the test indices used for each subgraph task.

  Args:
    df: DataFrame with boolean mask columns '~(A&B)', 'A&~B', '~A&B' and
      'A&B'.
    test_index: Positional row indices (used with `df.iloc`) selecting the
      test rows.

  Returns:
    Dict with keys 'A', 'B' and '(A & B)'. Each value is a numpy array holding
    the index labels of that task's positive rows ('A&~B', '~A&B' and 'A&B'
    respectively) followed by the labels of the '~(A&B)' rows.
  """
  held_out = df.iloc[test_index]

  def labels_where(column):
    return held_out[held_out[column]].index.tolist()

  negatives = labels_where('~(A&B)')
  positives_by_task = {
      'A': labels_where('A&~B'),
      'B': labels_where('~A&B'),
      '(A & B)': labels_where('A&B'),
  }
  return {
      task: np.array(positives + negatives)
      for task, positives in positives_by_task.items()
  }
