import argparse
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from Bio.PDB import PDBParser
from torch_scatter import scatter_mean

import utils
from utils import compute_rmsd
from lightning_modules import LigandPocketDDPM
from preprocessing.constants import FLOAT_TYPE, INT_TYPE, dataset_params
from analysis.molecule_builder import build_molecule, process_molecule
from analysis.metrics import BasicMolecularMetrics 



def prepare_ligand(biopython_atoms, atom_encoder):
    """Convert Biopython atoms into coordinate and one-hot atom-type tensors.

    Args:
        biopython_atoms: iterable of Biopython ``Atom`` objects. Generators
            (e.g. ``residue.get_atoms()``) are accepted.
        atom_encoder: mapping from capitalised element symbol (the result of
            ``atom.element.capitalize()``) to an integer class index.

    Returns:
        ``(coord, one_hot)``: ``coord`` is an ``(N, 3)`` tensor of dtype
        ``FLOAT_TYPE``; ``one_hot`` is an ``(N, len(atom_encoder))`` integer
        tensor produced by ``F.one_hot``.

    Raises:
        KeyError: if an atom's element is missing from ``atom_encoder``.
    """
    # Materialise once: the atoms are traversed twice below, and a generator
    # would be exhausted after the first pass.
    atoms = list(biopython_atoms)

    # reshape keeps the (N, 3) layout even when there are no atoms.
    coord = torch.tensor(
        np.array([a.get_coord() for a in atoms]).reshape(-1, 3),
        dtype=FLOAT_TYPE)
    # Explicit long dtype: F.one_hot requires an index tensor, and an empty
    # list would otherwise default to float.
    types = torch.tensor([atom_encoder[a.element.capitalize()] for a in atoms],
                         dtype=torch.long)
    one_hot = F.one_hot(types, num_classes=len(atom_encoder))

    return coord, one_hot


def guide_ligand(task, model, pdb_file, n_samples, ligand, fix_atoms, scale,
                 add_n_nodes=None, center='ligand', sanitize=False,
                 largest_frag=False, relax_iter=0, timesteps=None,
                 save_traj=False, resamplings=1):
    """Sample ligands that contain a set of fixed atoms inside a PDB pocket.

    The pocket is defined from the reference ligand ``ligand``. Each sample
    starts with the fixed atoms (coordinates and types) followed by free
    nodes that the diffusion model generates.

    Args:
        task: ``'inpainting'`` runs ``model.ddpm.inpaint``; any other value is
            forwarded to ``model.ddpm.observation_guidance``. Only consulted
            when ``model.mode == 'pocket_conditioning'``.
        model: loaded ``LigandPocketDDPM``.
        pdb_file: path to the protein PDB file.
        n_samples: number of ligands to sample.
        ligand: reference ligand identifier in ``'chain:resi'`` form.
        fix_atoms: Biopython atoms to keep fixed, or atom names (strings)
            that are looked up in the reference ligand residue.
        scale: guidance scale forwarded to ``observation_guidance`` /
            ``guidance_resampling`` (not used for inpainting).
        add_n_nodes: number of free nodes added to the fixed atoms. If
            ``None``, sizes come from ``model.ddpm.size_distribution``
            conditioned on the pocket size, clamped to at least the number
            of fixed atoms.
        center: forwarded to the pocket-conditioning samplers.
        sanitize, largest_frag, relax_iter: forwarded to ``process_molecule``.
        timesteps: forwarded to the sampler; also the number of returned
            frames when ``save_traj`` is set.
        save_traj: return every sampling frame as a molecule (requires
            ``n_samples == 1`` and disables sanitize/relax/largest_frag).
        resamplings: forwarded to ``inpaint`` / ``guidance_resampling``.

    Returns:
        ``(molecules, fixed_mol, connected, valid)``:
        - ``molecules``: processed sampled molecules (``largest_frag`` applied).
        - ``fixed_mol``: the processed fixed substructure.
        - ``connected``, ``valid``: outputs of
          ``BasicMolecularMetrics.compute_connectivity`` / ``compute_validity``,
          computed on molecules built without largest-fragment filtering.

    Raises:
        NotImplementedError: if ``save_traj`` is set with ``n_samples > 1``.
        ValueError: if no fixed atoms are selected, or if some requested atom
            names are absent from the reference ligand.
    """
    if save_traj and n_samples > 1:
        raise NotImplementedError("Can only visualize trajectory with "
                                  "n_samples=1.")
    frames = timesteps if save_traj else 1
    # Trajectory frames are kept raw, so post-processing is switched off.
    sanitize = False if save_traj else sanitize
    relax_iter = 0 if save_traj else relax_iter
    largest_frag = False if save_traj else largest_frag

    # Load PDB
    pdb_model = PDBParser(QUIET=True).get_structure('', pdb_file)[0]

    # Define pocket based on reference ligand
    residues = utils.get_pocket_from_ligand(pdb_model, ligand)
    pocket = model.prepare_pocket(residues, repeats=n_samples)

    # Resolve the fixed atoms against the reference ligand residue
    chain, resi = ligand.split(':')
    ligand_residue = utils.get_residue_with_resi(pdb_model[chain], int(resi))
    fixed_atoms = _resolve_fixed_atoms(ligand_residue, fix_atoms)
    n_fixed = len(fixed_atoms)
    if n_fixed == 0:
        raise ValueError(f"No fixed atoms selected from ligand {ligand}.")
    x_fixed, one_hot_fixed = prepare_ligand(fixed_atoms, model.lig_type_encoder)

    # Molecule of the fixed substructure alone. Built directly on CPU tensors
    # (like the sampled molecules below) rather than via a CUDA-only batch.
    fixed_mol = build_molecule(x_fixed, one_hot_fixed.argmax(1),
                               model.dataset_info, add_coords=True)
    fixed_mol = process_molecule(fixed_mol, add_hydrogens=False,
                                 sanitize=sanitize, relax_iter=0,
                                 largest_frag=False)

    if add_n_nodes is None:
        num_nodes_lig = model.ddpm.size_distribution.sample_conditional(
            n1=None, n2=pocket['size'])
        # Every sample must have room for the fixed atoms.
        num_nodes_lig = torch.clamp(num_nodes_lig, min=n_fixed)
    else:
        num_nodes_lig = torch.ones(n_samples, dtype=int) * n_fixed + add_n_nodes

    ligand_mask = utils.num_nodes_to_batch_mask(
        len(num_nodes_lig), num_nodes_lig, model.device)

    lig_batch = {
        'x': torch.zeros((len(ligand_mask), model.x_dims),
                         device=model.device, dtype=FLOAT_TYPE),
        'one_hot': torch.zeros((len(ligand_mask), model.atom_nf),
                               device=model.device, dtype=FLOAT_TYPE),
        'size': num_nodes_lig,
        'mask': ligand_mask
    }

    # Fill the first n_fixed nodes of every sample with the fixed atoms and
    # flag them in lig_fixed.
    x_fixed_dev = x_fixed.to(model.device)
    one_hot_fixed_dev = one_hot_fixed.to(model.device)
    lig_fixed = torch.zeros_like(ligand_mask)
    for i in range(n_samples):
        sele = (ligand_mask == i)

        x_new = lig_batch['x'][sele]
        x_new[:n_fixed] = x_fixed_dev
        lig_batch['x'][sele] = x_new

        h_new = lig_batch['one_hot'][sele]
        h_new[:n_fixed] = one_hot_fixed_dev
        lig_batch['one_hot'][sele] = h_new

        fixed_new = lig_fixed[sele]
        fixed_new[:n_fixed] = 1
        lig_fixed[sele] = fixed_new

    # Pocket's center of mass, used to move samples back afterwards
    pocket_com_before = scatter_mean(pocket['x'], pocket['mask'], dim=0)

    # Run sampling
    if model.mode == 'pocket_conditioning':
        if task == 'inpainting':
            xh_lig, xh_pocket, lig_mask, pocket_mask = model.ddpm.inpaint(
                lig_batch, pocket, lig_fixed, center=center,
                resamplings=resamplings, timesteps=timesteps,
                return_frames=frames)
        else:
            xh_lig, xh_pocket, lig_mask, pocket_mask = \
                model.ddpm.observation_guidance(
                    task, lig_batch, pocket, lig_fixed, center=center,
                    timesteps=timesteps, return_frames=frames, scale=scale)
    else:
        pocket_mask_fixed = torch.ones(len(pocket['mask']), device=model.device)
        xh_lig, xh_pocket, lig_mask, pocket_mask = \
            model.ddpm.guidance_resampling(
                lig_batch, pocket, lig_fixed, pocket_mask_fixed,
                timesteps=timesteps, return_frames=frames, scale=scale,
                resamplings=resamplings)

    # Treat intermediate states as molecules for downstream processing
    if save_traj:
        xh_lig = utils.reverse_tensor(xh_lig)
        xh_pocket = utils.reverse_tensor(xh_pocket)

        # One batch index per frame, repeated over that frame's nodes.
        lig_mask = torch.arange(xh_lig.size(0), device=model.device
                                ).repeat_interleave(len(lig_mask))
        pocket_mask = torch.arange(xh_pocket.size(0), device=model.device
                                   ).repeat_interleave(len(pocket_mask))

        xh_lig = xh_lig.view(-1, xh_lig.size(2))
        xh_pocket = xh_pocket.view(-1, xh_pocket.size(2))

    # Move generated molecule back to the original pocket position
    pocket_com_after = scatter_mean(xh_pocket[:, :model.x_dims], pocket_mask,
                                    dim=0)
    com_shift = pocket_com_before - pocket_com_after
    xh_pocket[:, :model.x_dims] += com_shift[pocket_mask]
    xh_lig[:, :model.x_dims] += com_shift[lig_mask]

    # Build mol objects: leading x_dims columns are coordinates, the rest are
    # atom-type scores.
    x = xh_lig[:, :model.x_dims].detach().cpu()
    atom_type = xh_lig[:, model.x_dims:].argmax(1).detach().cpu()

    # Metrics use the same dataset info that build_molecule uses. They are
    # computed on a separately built set without largest-fragment filtering.
    # NOTE(review): kept separate from the returned molecules, presumably
    # because compute_validity may sanitize molecules in place; verify.
    basic_mol_metrics = BasicMolecularMetrics(model.dataset_info)
    metric_molecules = _build_molecules(x, atom_type, lig_mask,
                                        model.dataset_info, sanitize,
                                        relax_iter, largest_frag=False)
    valid, _validity = basic_mol_metrics.compute_validity(metric_molecules)
    connected, _connectivity, _connected_smiles = \
        basic_mol_metrics.compute_connectivity(valid)

    molecules = _build_molecules(x, atom_type, lig_mask, model.dataset_info,
                                 sanitize, relax_iter, largest_frag)

    return molecules, fixed_mol, connected, valid


def _resolve_fixed_atoms(ligand_residue, fix_atoms):
    """Return ``fix_atoms`` as a list of Biopython atoms.

    Atom objects are passed through unchanged. If every entry is a string,
    each one is treated as an atom name and looked up in ``ligand_residue``.

    Raises:
        ValueError: if some requested atom names are absent from the residue.
    """
    atoms = [] if fix_atoms is None else list(fix_atoms)
    if atoms and all(isinstance(a, str) for a in atoms):
        names = set(atoms)
        selected = [a for a in ligand_residue.get_atoms()
                    if a.get_name() in names]
        missing = names - {a.get_name() for a in selected}
        if missing:
            raise ValueError(
                f"Atoms not found in reference ligand: {sorted(missing)}")
        return selected
    return atoms


def _build_molecules(x, atom_type, lig_mask, dataset_info, sanitize,
                     relax_iter, largest_frag):
    """Build and post-process one molecule per batch index in ``lig_mask``.

    Samples for which ``process_molecule`` returns ``None`` are dropped.
    """
    molecules = []
    for mol_pc in zip(utils.batch_to_list(x, lig_mask),
                      utils.batch_to_list(atom_type, lig_mask)):
        mol = build_molecule(*mol_pc, dataset_info, add_coords=True)
        mol = process_molecule(mol,
                               add_hydrogens=False,
                               sanitize=sanitize,
                               relax_iter=relax_iter,
                               largest_frag=largest_frag)
        if mol is not None:
            molecules.append(mol)
    return molecules


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('checkpoint', type=Path)
    parser.add_argument('--pdbfile', type=str, required=True)
    parser.add_argument('--ref_ligand', type=str, required=True,
                        help="Reference ligand as 'chain:resi'.")
    parser.add_argument('--fix_atoms', type=str, nargs='+', required=True,
                        help="Names of reference-ligand atoms to keep fixed.")
    parser.add_argument('--outfile', type=Path, required=True)
    parser.add_argument('--task', type=str, default='inpainting',
                        help="'inpainting' or a task name understood by "
                             "model.ddpm.observation_guidance.")
    parser.add_argument('--scale', type=float, default=1.0,
                        help="Guidance scale (not used for inpainting).")
    parser.add_argument('--n_samples', type=int, default=20)
    parser.add_argument('--add_n_nodes', type=int, default=None)
    parser.add_argument('--all_frags', action='store_true')
    parser.add_argument('--relax', type=int, default=0)
    parser.add_argument('--raw', action='store_true')
    parser.add_argument('--resamplings', type=int, default=10)
    parser.add_argument('--timesteps', type=int, default=50)
    parser.add_argument('--save_traj', action='store_true')
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load model
    model = LigandPocketDDPM.load_from_checkpoint(
        args.checkpoint, map_location=device)
    model = model.to(device)

    # guide_ligand consumes Biopython atoms, so resolve the CLI atom names
    # against the reference ligand residue here.
    pdb_model = PDBParser(QUIET=True).get_structure('', args.pdbfile)[0]
    chain, resi = args.ref_ligand.split(':')
    ref_residue = utils.get_residue_with_resi(pdb_model[chain], int(resi))
    wanted = set(args.fix_atoms)
    fixed_atoms = [a for a in ref_residue.get_atoms()
                   if a.get_name() in wanted]
    missing = wanted - {a.get_name() for a in fixed_atoms}
    if missing:
        parser.error(f"Atoms not found in ligand {args.ref_ligand}: "
                     f"{sorted(missing)}")

    # guide_ligand returns (molecules, fixed_mol, connected, valid); only the
    # sampled molecules are written out.
    molecules, _fixed_mol, _connected, _valid = guide_ligand(
        args.task, model, args.pdbfile, args.n_samples, args.ref_ligand,
        fixed_atoms, args.scale,
        add_n_nodes=args.add_n_nodes,
        center='pocket',
        sanitize=not args.raw,
        largest_frag=not args.all_frags,
        relax_iter=args.relax,
        timesteps=args.timesteps,
        resamplings=args.resamplings,
        save_traj=args.save_traj)

    # Make SDF files
    utils.write_sdf_file(args.outfile, molecules)