import os
import glob
import argparse
from analysis.metrics import MoleculeProperties 
from analysis.metrics import BasicMolecularMetrics 
from rdkit import Chem
from pathlib import Path
from rdkit import Chem
import glob
from preprocessing.constants import dataset_params
from analysis.molecule_builder import build_molecule
import numpy as np


def dataset_key_from_path(test_dir):
    """Return the ``dataset_params`` key that matches a results directory.

    The dataset is inferred from the directory name: any path containing
    'moad' maps to 'bindingmoad' (checked first), otherwise any path
    containing 'crossdock' maps to 'crossdock'.

    Raises:
        ValueError: if neither substring occurs in ``test_dir``.
    """
    if 'moad' in test_dir:
        return 'bindingmoad'
    if 'crossdock' in test_dir:
        return 'crossdock'
    raise ValueError(
        f"cannot infer dataset from test_dir {test_dir!r}: "
        "expected the path to contain 'moad' or 'crossdock'"
    )


def find_sdf_files(test_dir, experiment_dir):
    """Return the ``*.sdf`` files in ``<test_dir>/<experiment_dir>/processed``.

    The list is sorted because ``glob.glob`` returns matches in arbitrary
    order; sorting makes the per-file output reproducible across runs.
    """
    pattern = os.path.join(test_dir, experiment_dir, 'processed', '*.sdf')
    return sorted(glob.glob(pattern))


def read_sdf_files(sdf_names):
    """Read each SDF file and return one list of molecules per file.

    Every record yielded by ``Chem.SDMolSupplier`` is kept, including
    ``None`` for records RDKit could not parse; callers filter as needed.
    """
    per_file = []
    for sdf_name in sdf_names:
        with Chem.SDMolSupplier(sdf_name) as suppl:
            per_file.append(list(suppl))
    return per_file


def main():
    """Evaluate generated molecules stored as SDF files and print metrics.

    Runs ``MoleculeProperties.evaluate`` on the parsable molecules grouped
    per file, then prints validity, connectivity, uniqueness and novelty
    (against ``<data_dir>/train_smiles.npy``) over all records.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='/path/to/moad/processed_noH_ca_only/')
    parser.add_argument('--test_dir', type=str, default='/path/to/results/moad_ca_joint/')
    parser.add_argument('--experiment_dir', type=str, default='debug_1')
    args = parser.parse_args()

    # Resolve the dataset before any expensive work so that an unrecognised
    # --test_dir fails immediately with a clear message instead of a late
    # NameError on an unbound variable.
    try:
        dataset_info = dataset_params[dataset_key_from_path(args.test_dir)]
    except ValueError as err:
        parser.error(str(err))

    sdf_names = find_sdf_files(args.test_dir, args.experiment_dir)
    if not sdf_names:
        processed_dir = os.path.join(args.test_dir, args.experiment_dir, 'processed')
        parser.error(f"no .sdf files found in {processed_dir}")

    # --- Molecular properties, grouped per file (one file per pocket) ---
    mol_metrics = MoleculeProperties()
    pocket_mols_lst = [
        [mol for mol in mols if mol is not None]
        for mols in read_sdf_files(sdf_names)
    ]
    all_qed, all_sa, all_logp, all_lipinski, per_pocket_diversity = mol_metrics.evaluate(pocket_mols_lst)
    print(len(pocket_mols_lst))
    print([len(x) for x in pocket_mols_lst])

    # --- Basic metrics over all generated records ---
    # NOTE(review): if train_smiles.npy holds an object array, np.load needs
    # allow_pickle=True — confirm the stored dtype.
    dataset_smiles_list_path = os.path.join(args.data_dir, 'train_smiles.npy')
    dataset_smiles_list = np.load(dataset_smiles_list_path)
    basic_mol_metrics = BasicMolecularMetrics(dataset_info, dataset_smiles_list)

    # Re-read the files rather than reusing the molecules above: evaluate()
    # may modify molecules in place (not visible here), and fresh objects keep
    # the validity check independent of it. None entries are deliberately kept;
    # presumably compute_validity counts them as invalid — confirm.
    generated_mols_list = [mol for mols in read_sdf_files(sdf_names) for mol in mols]
    print('generated_mols_list ', len(generated_mols_list))

    valid, validity = basic_mol_metrics.compute_validity(generated_mols_list)
    print(f"Validity over {len(generated_mols_list)} molecules: {validity * 100 :.2f}%")

    connected, connectivity, connected_smiles = basic_mol_metrics.compute_connectivity(valid)
    print(f"Connectivity over {len(valid)} valid molecules: "f"{connectivity * 100 :.2f}%")

    unique, uniqueness = basic_mol_metrics.compute_uniqueness(connected_smiles)
    print(f"Uniqueness over {len(connected)} connected molecules: "f"{uniqueness * 100 :.2f}%")

    _, novelty = basic_mol_metrics.compute_novelty(unique)
    print(f"Novelty over {len(unique)} unique connected molecules: "f"{novelty * 100 :.2f}%")


if __name__ == "__main__":
    main()