
from ogb.utils.features import (atom_to_feature_vector,bond_to_feature_vector) 
from rdkit.Chem import AllChem
from rdkit import Chem
import numpy as np
import torch
from utils.mol_utils import bond_to_feature_vector as bond_to_feature_vector_non_santize
from utils.mol_utils import atom_to_feature_vector as atom_to_feature_vector_non_santize

class ConcatDataset(torch.utils.data.Dataset):
    """Zip several datasets together, index by index.

    Item ``i`` is a tuple holding item ``i`` of every wrapped dataset, and
    the length is that of the shortest wrapped dataset.
    """

    def __init__(self, *datasets):
        # Kept as the tuple of positional arguments, in the order given.
        self.datasets = datasets

    def __len__(self):
        # Truncate to the shortest dataset so every index is valid for all.
        return min(map(len, self.datasets))

    def __getitem__(self, i):
        items = [dataset[i] for dataset in self.datasets]
        return tuple(items)
        
def getmorganfingerprint(mol):
    """Return the Morgan fingerprint of ``mol`` as a plain Python list.

    The fingerprint is computed with
    ``AllChem.GetMorganFingerprintAsBitVect(mol, 2)`` (second argument is the
    radius in RDKit's signature) and every entry of the resulting bit vector
    is copied into a list.
    """
    bit_vector = AllChem.GetMorganFingerprintAsBitVect(mol, 2)
    return list(bit_vector)

def getmaccsfingerprint(mol):
    """Return the MACCS keys fingerprint of ``mol`` as a list of ints.

    The fingerprint's bit string is converted character by character with
    ``int``, so each entry is expected to be 0 or 1.
    """
    bit_string = AllChem.GetMACCSKeysFingerprint(mol).ToBitString()
    return list(map(int, bit_string))

def log_base(base, x):
    """Return the logarithm of ``x`` in the given ``base``.

    Uses the change-of-base identity ``log(x) / log(base)`` with NumPy's
    natural logarithm, so ``x`` may be a scalar or an array.
    """
    log_of_x = np.log(x)
    log_of_base = np.log(base)
    return log_of_x / log_of_base

def smiles2graph(smiles_string, sanitize=True):
    """
    Converts SMILES string to graph Data object

    :input: SMILES string (str). ``sanitize`` is forwarded to
        ``Chem.MolFromSmiles`` and also selects the featurizers: the OGB
        ``atom_to_feature_vector`` / ``bond_to_feature_vector`` when True,
        the project's non-sanitizing variants when False.
    :return: graph object, a dict with the keys

        - ``edge_index``: int64 array of shape [2, num_edges]; every bond is
          stored in both directions
        - ``edge_feat``: int64 array of shape [num_edges, num_bond_features]
        - ``node_feat``: int64 array with one row per atom
        - ``num_nodes``: number of atoms

        ``None`` is returned when the SMILES cannot be parsed or featurized.
    """
    try:
        mol = Chem.MolFromSmiles(smiles_string, sanitize=sanitize)
        if mol is None:
            # RDKit signals an unparseable SMILES by returning None.
            return None

        if sanitize:
            featurize_atom = atom_to_feature_vector
            featurize_bond = bond_to_feature_vector
        else:
            featurize_atom = atom_to_feature_vector_non_santize
            featurize_bond = bond_to_feature_vector_non_santize

        # atoms
        x = np.array(
            [featurize_atom(atom) for atom in mol.GetAtoms()],
            dtype=np.int64,
        )

        # bonds
        num_bond_features = 3  # bond type, bond stereo, is_conjugated
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = featurize_bond(bond)

            # add edges in both directions
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        if edges_list:  # mol has bonds
            # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
            edge_index = np.array(edges_list, dtype=np.int64).T

            # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
            edge_attr = np.array(edge_features_list, dtype=np.int64)
        else:  # mol has no bonds
            edge_index = np.empty((2, 0), dtype=np.int64)
            edge_attr = np.empty((0, num_bond_features), dtype=np.int64)

        return {
            'edge_index': edge_index,
            'edge_feat': edge_attr,
            'node_feat': x,
            'num_nodes': len(x),
        }

    except Exception:
        # Best effort by design: any parsing or featurization failure yields
        # None. KeyboardInterrupt / SystemExit are deliberately not caught.
        return None


import copy
import pathlib
import pandas as pd
from tqdm import tqdm
from torch_geometric.data import Data
def labeled2graphs(raw_dir):
    '''
    Build one graph dict per row of a labeled property csv.

        - raw_dir: the position where property csv stored; it must have a
          ``SMILES`` column, which is used as the index.

    Every remaining column of a row is stored in that row's graph dict
    under the column name, as a ``[[value]]`` numpy array of shape (1, 1).

    Returns the list of graph dicts, in file order.

    Raises ValueError if the path is not a ``.csv`` file or if a SMILES
    string cannot be converted into a graph.
    '''
    path_suffix = pathlib.Path(raw_dir).suffix
    if path_suffix != '.csv':
        raise ValueError("Support only csv.")

    df_full = pd.read_csv(raw_dir, engine='python')
    df_full.set_index('SMILES', inplace=True)
    print(df_full[:5])

    graph_list = []
    # iterrows() yields exactly one Series per row, so repeated SMILES
    # strings are handled row by row (a .loc lookup would return a frame).
    for smiles, props in tqdm(df_full.iterrows(), total=len(df_full)):
        graph_dict = smiles2graph(smiles)
        if graph_dict is None:
            raise ValueError(
                "Could not convert SMILES to a graph: {!r}".format(smiles))
        # Series.items() replaces iteritems(), which pandas 2.0 removed.
        for name, value in props.items():
            graph_dict[name] = np.array([[value]])
        graph_list.append(graph_dict)
    return graph_list

def unlabel2graphs(raw_dir, property_name=None, drop_property=False):
    '''
    Build one graph per SMILES string found in an unlabeled data file.

        - raw_dir: the position where property csv stored; a ``.csv`` with a
          ``SMILES`` column, or a ``.txt`` read as a single ``SMILES`` column.
        - property_name: only used when ``drop_property`` is set; the part
          after the first ``-`` names the csv column to filter on.
        - drop_property: for csv input, keep only rows whose property column
          is NaN.

    Returns a list holding the result of ``smiles2graph`` for every SMILES,
    in file order (entries are ``None`` where the conversion failed).
    '''
    suffix = pathlib.Path(raw_dir).suffix
    if suffix not in ('.csv', '.txt'):
        raise ValueError("Support only csv and txt.")

    if suffix == '.csv':
        frame = pd.read_csv(raw_dir, engine='python')
        print(frame[:5])
        # select data without current property
        if drop_property:
            column = property_name.split('-')[1]
            frame = frame[frame[column].isna()]
        frame = frame.dropna(subset=['SMILES'])
    else:
        frame = pd.read_csv(raw_dir, sep=" ", header=None, names=['SMILES'])
        print(frame[:5])

    return [smiles2graph(smiles) for smiles in tqdm(frame['SMILES'])]
    
def read_graph_list(raw_dir, property_name=None, drop_property=False, process_labeled=False):
    """Load molecules from ``raw_dir`` and return them as PyG ``Data`` objects.

    :param raw_dir: path handed to ``labeled2graphs`` (when
        ``process_labeled`` is set) or to ``unlabel2graphs`` (otherwise).
    :param property_name: name of the form ``<prefix>-<column>``; the part
        after the first ``-`` is the key used as the label ``y`` for labeled
        data, and the column filtered on when ``drop_property`` is set.
    :param drop_property: forwarded to ``unlabel2graphs``.
    :param process_labeled: whether to read labeled data and fill ``y``.
    :return: list of ``Data`` objects with ``edge_index``, ``edge_attr``,
        ``x`` and ``num_nodes`` set; any other entry of a graph dict is
        attached as a tensor under its own key. Graphs that could not be
        built (``None`` entries) are skipped.
    :raises ValueError: if ``process_labeled`` is set without
        ``property_name``.
    """
    print('raw_dir', raw_dir)
    if process_labeled:
        if property_name is None:
            raise ValueError("property_name is required when process_labeled=True.")
        label_key = property_name.split('-')[1]
        graph_list = labeled2graphs(raw_dir)
    else:
        label_key = None
        graph_list = unlabel2graphs(raw_dir, property_name=property_name, drop_property=drop_property)

    pyg_graph_list = []
    num_skipped = 0
    print('Converting graphs into PyG objects...')
    for graph in graph_list:
        if graph is None:
            # smiles2graph returns None for molecules it cannot convert.
            num_skipped += 1
            continue

        g = Data()
        # Use the public attribute rather than the private ``__num_nodes__``.
        g.num_nodes = graph.pop('num_nodes')
        g.edge_index = torch.from_numpy(graph.pop('edge_index'))

        if process_labeled:
            g.y = torch.from_numpy(graph.pop(label_key))

        edge_feat = graph.pop('edge_feat')
        if edge_feat is not None:
            g.edge_attr = torch.from_numpy(edge_feat)

        node_feat = graph.pop('node_feat')
        if node_feat is not None:
            g.x = torch.from_numpy(node_feat)

        # Whatever is left are additional per-graph properties.
        for key, value in graph.items():
            g[key] = torch.tensor(value)
        # The original emptied each dict as it went; keep releasing the
        # numpy buffers here.
        graph.clear()

        pyg_graph_list.append(g)

    if num_skipped:
        print('Skipped {} molecules that could not be converted.'.format(num_skipped))

    return pyg_graph_list

import math
def x_u_split(args, labels):
    """Draw a class-balanced labeled index set; all indices are unlabeled.

    ``args`` must provide ``num_labeled``, ``num_classes``, ``batch_size``,
    ``eval_step`` and ``expand_labels``. For each class id in
    ``range(args.num_classes)``, ``num_labeled // num_classes`` indices are
    sampled without replacement. When ``expand_labels`` is set, or there
    are fewer labeled samples than ``batch_size``, the labeled indices are
    repeated ``ceil(batch_size * eval_step / num_labeled)`` times. The
    labeled indices are shuffled before being returned.

    Returns ``(labeled_idx, unlabeled_idx)`` as numpy arrays.
    """
    per_class = args.num_labeled // args.num_classes
    labels = np.array(labels)

    # unlabeled data: all data (https://github.com/kekmodel/FixMatch-pytorch/issues/10)
    unlabeled_idx = np.array(range(len(labels)))

    # One np.random.choice call per class, in class order.
    labeled_idx = np.array([
        sample
        for cls in range(args.num_classes)
        for sample in np.random.choice(np.where(labels == cls)[0], per_class, False)
    ])
    assert len(labeled_idx) == args.num_labeled

    needs_expansion = args.expand_labels or args.num_labeled < args.batch_size
    if needs_expansion:
        repeats = math.ceil(
            args.batch_size * args.eval_step / args.num_labeled)
        labeled_idx = np.hstack([labeled_idx] * repeats)

    np.random.shuffle(labeled_idx)
    return labeled_idx, unlabeled_idx


def nx_to_graph_data_obj_with_edge_attr(g):
    """Convert a networkx graph with ``w1``..``w7`` edge weights to PyG ``Data``.

    :param g: networkx graph whose ``g.graph`` dict provides
        ``center_node_idx``, ``go_target_downstream`` and ``species_id``,
        and whose edges each carry the attributes ``w1`` to ``w7``.
    :return: ``Data`` object with

        - ``x``: all-ones long tensor of shape [num_nodes, 1] (dummy feature)
        - ``edge_index``: long tensor of shape [2, num_edges]; every edge
          is stored in both directions
        - ``edge_attr``: float tensor of shape [num_edges, 9]
        - ``species_id``, ``y`` (downstream target) and ``center_node_idx``
    """
    num_edge_features = 9  # w1..w7 plus two zero slots (self-loop, masking)
    n_nodes = g.number_of_nodes()

    # graph level attributes
    center_node_idx = g.graph['center_node_idx']
    go_target_downstream = g.graph['go_target_downstream']
    species_id = g.graph['species_id']

    # Position of every networkx node id, built once so each edge endpoint
    # is a dict lookup rather than a list scan.
    node_position = {node: pos for pos, node in enumerate(g.nodes())}
    x = torch.ones((n_nodes, 1), dtype=torch.long)

    edges_list = []
    edge_features_list = []
    for node_1, node_2, attr_dict in g.edges(data=True):
        edge_feature = [attr_dict['w1'], attr_dict['w2'], attr_dict['w3'],
                        attr_dict['w4'], attr_dict['w5'], attr_dict['w6'],
                        attr_dict['w7'], 0, 0]  # last 2 indicate self-loop
        # and masking
        edge_feature = np.array(edge_feature, dtype=int)
        i = node_position[node_1]
        j = node_position[node_2]
        edges_list.append((i, j))
        edge_features_list.append(edge_feature)
        edges_list.append((j, i))
        edge_features_list.append(edge_feature)

    # The reshapes keep the [2, 0] / [0, 9] shapes for a graph without edges.
    edge_index = torch.tensor(
        np.array(edges_list, dtype=np.int64).reshape(-1, 2).T,
        dtype=torch.long)
    edge_attr = torch.tensor(
        np.array(edge_features_list).reshape(-1, num_edge_features),
        dtype=torch.float)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    data.species_id = torch.tensor(species_id, dtype=torch.long)
    data.y = torch.tensor([np.array(go_target_downstream)], dtype=torch.long)
    data.center_node_idx = torch.tensor([center_node_idx], dtype=torch.long)

    return data

def nx_to_graph_data_obj(g):
    """Convert a labeled networkx graph to PyG ``Data`` without edge weights.

    :param g: networkx graph whose ``g.graph`` dict provides
        ``go_target_downstream`` and ``species_id``.
    :return: ``Data`` object with

        - ``x``: all-ones long tensor of shape [num_nodes, 1] (dummy feature)
        - ``edge_index``: long tensor of shape [2, num_edges]; every edge
          is stored in both directions
        - ``edge_attr``: float tensor of shape [num_edges, 1], constant 1
        - ``species_id`` and ``y`` (downstream target)
    """
    num_edge_features = 1  # single constant placeholder feature
    n_nodes = g.number_of_nodes()

    # graph level attributes
    go_target_downstream = g.graph['go_target_downstream']
    species_id = g.graph['species_id']

    # nodes
    # Position of every networkx node id, built once so each edge endpoint
    # is a dict lookup rather than a list scan.
    node_position = {node: pos for pos, node in enumerate(g.nodes())}
    # we don't have any node labels, so set to dummy 1. dim n_nodes x 1
    x = torch.ones((n_nodes, 1), dtype=torch.long)

    # edges
    edges_list = []
    edge_features_list = []
    for node_1, node_2, _attr_dict in g.edges(data=True):
        edge_feature = np.array([1], dtype=int)  # no edge feature
        i = node_position[node_1]
        j = node_position[node_2]
        edges_list.append((i, j))
        edge_features_list.append(edge_feature)
        edges_list.append((j, i))
        edge_features_list.append(edge_feature)

    # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
    # (the reshape keeps shape [2, 0] for a graph without edges)
    edge_index = torch.tensor(
        np.array(edges_list, dtype=np.int64).reshape(-1, 2).T,
        dtype=torch.long)

    # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
    edge_attr = torch.tensor(
        np.array(edge_features_list).reshape(-1, num_edge_features),
        dtype=torch.float)

    # construct data obj
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    data.species_id = torch.tensor(species_id, dtype=torch.long)
    data.y = torch.tensor([np.array(go_target_downstream)], dtype=torch.long)

    return data


def nx_to_graph_data_obj_unlabeled(g):
    """Convert an unlabeled networkx graph to a PyG ``Data`` object.

    :param g: networkx graph; only its nodes and edges are read.
    :return: ``Data`` object with

        - ``x``: all-ones long tensor of shape [num_nodes, 1] (dummy feature)
        - ``edge_index``: long tensor of shape [2, num_edges]; every edge
          is stored in both directions
        - ``edge_attr``: float tensor of shape [num_edges, 1], constant 1
    """
    num_edge_features = 1  # single constant placeholder feature
    n_nodes = g.number_of_nodes()

    # nodes
    # Position of every networkx node id, built once so each edge endpoint
    # is a dict lookup rather than a list scan.
    node_position = {node: pos for pos, node in enumerate(g.nodes())}
    # we don't have any node labels, so set to dummy 1. dim n_nodes x 1
    x = torch.ones((n_nodes, 1), dtype=torch.long)

    # edges
    edges_list = []
    edge_features_list = []
    for node_1, node_2, _attr_dict in g.edges(data=True):
        edge_feature = np.array([1], dtype=int)  # no edge feature
        i = node_position[node_1]
        j = node_position[node_2]
        edges_list.append((i, j))
        edge_features_list.append(edge_feature)
        edges_list.append((j, i))
        edge_features_list.append(edge_feature)

    # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
    # (the reshape keeps shape [2, 0] for a graph without edges)
    edge_index = torch.tensor(
        np.array(edges_list, dtype=np.int64).reshape(-1, 2).T,
        dtype=torch.long)

    # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
    edge_attr = torch.tensor(
        np.array(edge_features_list).reshape(-1, num_edge_features),
        dtype=torch.float)

    # construct data obj
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    return data