from __future__ import print_function

import numpy as np
from rdkit.Chem import MolFromSmiles

import sys
sys.path.insert(0, '../graph_methods')
sys.path.insert(0, '../src')
from function import reshape_data_into_2_dim
from graph_util import extract_atom_features, extract_bond_features

# Node degrees that get their own bucket: MolGraph.sort_nodes_by_degree groups
# nodes under (ntype, degree) keys for these values, and the collate function
# emits one neighbor array per degree listed here.
degrees = list(range(6))


class MolGraph(object):
    """Container of graph nodes grouped by node type.

    ``self.nodes`` maps a node-type key to the list of nodes of that type.
    Keys are the ``ntype`` values given to ``new_node``; after
    ``sort_nodes_by_degree(ntype)`` there are additional ``(ntype, degree)``
    tuple keys, one per entry of the module-level ``degrees`` list.
    """

    def __init__(self):
        # node-type key -> list of nodes registered under that key
        self.nodes = {}

    def _require_ntypes(self, *ntypes):
        """Raise ``KeyError`` naming any of ``ntypes`` absent from the graph.

        Explicit check instead of ``assert`` so that the validation is not
        removed when Python runs with ``-O``.
        """
        missing = [ntype for ntype in ntypes if ntype not in self.nodes]
        if missing:
            raise KeyError(
                "no nodes of type(s) {!r} in graph".format(missing))

    def new_node(self, ntype, features=None, rdkit_ix=None):
        """Create a ``Node``, register it under ``ntype`` and return it."""
        new_node = Node(ntype, features, rdkit_ix)
        self.nodes.setdefault(ntype, []).append(new_node)
        return new_node

    def add_subgraph(self, subgraph):
        """Append every node list of ``subgraph`` to this graph's lists.

        The node objects themselves are shared with ``subgraph``, not copied.
        """
        for ntype, nodes in subgraph.nodes.items():
            self.nodes.setdefault(ntype, []).extend(nodes)

    def sort_nodes_by_degree(self, ntype):
        """Reorder ``self.nodes[ntype]`` by ascending same-type degree.

        A node's degree is the number of its neighbors of type ``ntype``.
        Besides reordering ``self.nodes[ntype]``, each bucket is stored under
        the key ``(ntype, degree)`` for every value in ``degrees`` (empty
        buckets included). Relative order inside a bucket is preserved.

        Raises:
            KeyError: If ``ntype`` is not in the graph, or a node has a
                degree that is not listed in ``degrees``.
        """
        nodes_by_degree = {i: [] for i in degrees}
        for node in self.nodes[ntype]:
            degree = len(node.get_neighbors(ntype))
            if degree not in nodes_by_degree:
                # Same exception type as the bare dict lookup would raise,
                # but with a message that says what went wrong.
                raise KeyError(
                    "node of type {!r} has degree {}; supported degrees "
                    "are {}".format(ntype, degree, degrees))
            nodes_by_degree[degree].append(node)

        new_nodes = []
        for degree in degrees:
            cur_nodes = nodes_by_degree[degree]
            self.nodes[(ntype, degree)] = cur_nodes
            new_nodes.extend(cur_nodes)

        self.nodes[ntype] = new_nodes

    def feature_array(self, ntype):
        """Return ``np.array`` of the ``features`` of all ``ntype`` nodes.

        Raises:
            KeyError: If the graph holds no node list for ``ntype``.
        """
        self._require_ntypes(ntype)
        return np.array([node.features for node in self.nodes[ntype]])

    def rdkit_ix_array(self):
        """Return ``np.array`` of ``rdkit_ix`` for the nodes under 'atom'."""
        return np.array([node.rdkit_ix for node in self.nodes['atom']])

    def neighbor_list(self, self_ntype, neighbor_ntype):
        """Return, per ``self_ntype`` node, indices of its neighbors.

        The result has one inner list per node in ``self.nodes[self_ntype]``;
        each entry is the position of a neighbor within the current
        ``self.nodes[neighbor_ntype]`` list.

        Raises:
            KeyError: If either node type is missing from the graph.
        """
        self._require_ntypes(self_ntype, neighbor_ntype)
        neighbor_idxs = {n: i for i, n in enumerate(self.nodes[neighbor_ntype])}
        return [[neighbor_idxs[neighbor]
                 for neighbor in self_node.get_neighbors(neighbor_ntype)]
                for self_node in self.nodes[self_ntype]]


class Node(object):
    """A graph vertex holding its type, a feature payload and its links."""
    __slots__ = ['ntype', 'features', '_neighbors', 'rdkit_ix']

    def __init__(self, ntype, features, rdkit_ix):
        self.rdkit_ix = rdkit_ix
        self._neighbors = []  # adjacent nodes of any type, in insertion order
        self.features = features
        self.ntype = ntype  # in [molecule, atom, bond]

    def add_neighbors(self, neighbor_list):
        """Link this node with every node in ``neighbor_list``.

        The link is recorded on both ends: the other node is appended to this
        node's neighbors and this node is appended to the other's.
        """
        own_links = self._neighbors
        for other in neighbor_list:
            own_links.append(other)
            other._neighbors.append(self)

    def get_neighbors(self, ntype):
        """Return the adjacent nodes whose ``ntype`` equals ``ntype``."""
        matching = []
        for other in self._neighbors:
            if other.ntype == ntype:
                matching.append(other)
        return matching


def neural_fingerprint_collate_fn(data):
    """Collate (SMILES, label) samples into one batched molecular graph.

    Args:
        data: Sequence of samples. Element 0 of each sample is a SMILES
            string; element 1 is its label (the labels are combined with
            ``np.stack``).

    Returns:
        A tuple ``(array_rep, label_list)``.

        ``array_rep`` is a dict with the keys 'atom_features',
        'bond_features', 'atom_list' (per molecule, the indices of its
        atoms), 'rdkit_ix', and, for every ``d`` in ``degrees``,
        ``('atom_neighbors', d)`` and ``('bond_neighbors', d)`` holding the
        atom / bond indices adjacent to each atom of degree ``d``.

        ``label_list`` is the stacked labels passed through
        ``reshape_data_into_2_dim`` (presumably making them 2-D -- confirm
        against that helper).

    Raises:
        ValueError: If RDKit cannot parse one of the SMILES strings.
    """
    def graph_from_smiles(smiles):
        """Build a MolGraph of atom, bond and molecule nodes for one SMILES."""
        graph = MolGraph()
        mol = MolFromSmiles(smiles)
        if mol is None:
            # One formatted message; passing the SMILES as a second exception
            # argument made the error render as a tuple.
            raise ValueError(
                "Could not parse SMILES string: {!r}".format(smiles))
        atoms_by_rd_idx = {}
        for atom in mol.GetAtoms():
            new_atom_node = graph.new_node('atom', features=extract_atom_features(atom), rdkit_ix=atom.GetIdx())
            atoms_by_rd_idx[atom.GetIdx()] = new_atom_node

        for bond in mol.GetBonds():
            atom1_node = atoms_by_rd_idx[bond.GetBeginAtom().GetIdx()]
            atom2_node = atoms_by_rd_idx[bond.GetEndAtom().GetIdx()]
            new_bond_node = graph.new_node('bond', features=extract_bond_features(bond))
            new_bond_node.add_neighbors((atom1_node, atom2_node))
            # add_neighbors links both ends, so one call connects the atoms.
            atom1_node.add_neighbors((atom2_node,))

        # A single molecule node linked to every atom lets neighbor_list
        # report which atoms belong to which molecule after batching.
        mol_node = graph.new_node('molecule')
        mol_node.add_neighbors(graph.nodes['atom'])
        return graph

    def graph_from_smiles_tuple(smiles_tuple):
        """Merge the graphs of all SMILES into one degree-sorted MolGraph."""
        graph_list = [graph_from_smiles(s) for s in smiles_tuple]
        big_graph = MolGraph()
        for subgraph in graph_list:
            big_graph.add_subgraph(subgraph)

        # Sort before any index array is built: neighbor_list indexes into
        # the node lists in their current order.
        big_graph.sort_nodes_by_degree('atom')
        return big_graph

    smiles_list = [d[0] for d in data]
    label_list = np.stack([d[1] for d in data])
    label_list = reshape_data_into_2_dim(label_list)

    molgraph = graph_from_smiles_tuple(smiles_list)
    array_rep = {'atom_features': molgraph.feature_array('atom'),
                 'bond_features': molgraph.feature_array('bond'),
                 'atom_list': molgraph.neighbor_list('molecule', 'atom'),
                 'rdkit_ix': molgraph.rdkit_ix_array()
                 }
    for degree in degrees:
        # Every atom in the ('atom', degree) bucket has exactly `degree`
        # atom neighbors, so these nested lists are rectangular.
        array_rep[('atom_neighbors', degree)] = \
            np.array(molgraph.neighbor_list(('atom', degree), 'atom'), dtype=int)
        array_rep[('bond_neighbors', degree)] = \
            np.array(molgraph.neighbor_list(('atom', degree), 'bond'), dtype=int)

    return array_rep, label_list