# Rdkit import should be first, do not move it
try:
    from rdkit import Chem
except ModuleNotFoundError:
    pass
import copy
import utils
import argparse
import wandb
from configs.datasets_config import get_dataset_info
from os.path import join
from qm9 import dataset
from qm9.models import get_optim, get_model
from equivariant_diffusion import en_diffusion
from equivariant_diffusion.utils import assert_correctly_masked
from equivariant_diffusion import utils as flow_utils
import torch
import time
import pickle
from qm9.utils import prepare_context, compute_mean_mad
import qm9.visualizer as vis
from absl import logging
from qm9.sampling import sample_chain, sample, sample_sweep_conditional
from qm9.analyze import check_stability


from train_test import (
    train_epoch,
    test,
    analyze_and_save,
    data_callback,
    save_and_sample_chain,
)

# Command-line interface. Options declared with ``type=eval`` take Python literals
# (``True``/``False``, lists, numbers). ``eval`` executes the argument string, so only
# run this script with trusted command lines.
parser = argparse.ArgumentParser(description="E3Diffusion")
parser.add_argument("--exp_name", type=str, default="debug_10")
parser.add_argument("--gpu_id", type=str, default="0")
parser.add_argument(
    "--model",
    type=str,
    default="egnn_dynamics",
    help="our_dynamics | schnet | simple_dynamics | "
    "kernel_dynamics | egnn_dynamics |gnn_dynamics",
)
parser.add_argument(
    "--probabilistic_model", type=str, default="diffusion", help="diffusion|flow_match"
)
parser.add_argument(
    "--node_classifier_model_ckpt",
    type=str,
    default="/sharefs/anonymous/node_predict/model_ckpt/model_195.npy",
)

# Training complexity is O(1) (unaffected), but sampling complexity is O(steps).
parser.add_argument("--diffusion_steps", type=int, default=500)
parser.add_argument(
    "--diffusion_noise_schedule",
    type=str,
    default="polynomial_2",
    help="learned, cosine",
)

parser.add_argument(
    "--discrete_path", type=str, default="OT_path", help="OT_path, HB_path, VP_path"
)

parser.add_argument(
    "--diffusion_noise_precision",
    type=float,
    default=1e-5,
)
parser.add_argument("--diffusion_loss_type", type=str, default="l2", help="vlb, l2")

parser.add_argument("--cat_loss_step", type=float, default=-1)

parser.add_argument("--cat_loss", type=str, default="l2", help='"l2" or "cse"')

parser.add_argument("--on_hold_batch", type=int, default=-1)

parser.add_argument("--sampling_method", type=str, default="vanilla")
parser.add_argument("--weighted_methods", type=str, default="jump")

parser.add_argument(
    "--output_dir",
    type=str,
    default="outputs",
    help="outputs | /sharefs/anonymous/edm_output/outputs",
)

parser.add_argument("--n_epochs", type=int, default=200)
parser.add_argument("--batch_size", type=int, default=100)
parser.add_argument("--lr", type=float, default=2e-4)
parser.add_argument("--brute_force", type=eval, default=False, help="True | False")
parser.add_argument("--actnorm", type=eval, default=True, help="True | False")
parser.add_argument(
    "--break_train_epoch", type=eval, default=False, help="True | False"
)
# NOTE: the script currently forces args.dp = False right after parsing.
parser.add_argument("--dp", type=eval, default=True, help="True | False")
parser.add_argument("--condition_time", type=eval, default=True, help="True | False")
parser.add_argument("--clip_grad", type=eval, default=True, help="True | False")
parser.add_argument("--trace", type=str, default="hutch", help="hutch | exact")
# EGNN args -->
parser.add_argument("--n_layers", type=int, default=6, help="number of layers")
parser.add_argument(
    "--inv_sublayers", type=int, default=1, help="number of invariant sublayers"
)
parser.add_argument("--nf", type=int, default=128, help="number of hidden features")
parser.add_argument("--tanh", type=eval, default=True, help="use tanh in the coord_mlp")
parser.add_argument(
    "--attention", type=eval, default=True, help="use attention in the EGNN"
)
parser.add_argument(
    "--norm_constant", type=float, default=1, help="diff/(|diff| + norm_constant)"
)
parser.add_argument(
    "--sin_embedding",
    type=eval,
    default=False,
    help="whether using or not the sin embedding",
)
# <-- EGNN args
parser.add_argument("--ode_regularization", type=float, default=1e-3)
parser.add_argument(
    "--dataset",
    type=str,
    default="qm9",
    help="qm9 | qm9_second_half (train only on the last 50K samples of the training dataset)",
)
parser.add_argument("--datadir", type=str, default="qm9/temp", help="qm9 directory")
parser.add_argument(
    "--filter_n_atoms",
    type=int,
    default=None,
    help="When set to an integer value, QM9 will only contain molecules of that amount of atoms",
)
parser.add_argument(
    "--dequantization",
    type=str,
    default="argmax_variational",
    help="uniform | variational | argmax_variational | deterministic",
)
parser.add_argument("--n_report_steps", type=int, default=1)
parser.add_argument("--wandb_usr", type=str)
parser.add_argument("--no_wandb", action="store_true", help="Disable wandb")
# type=eval (not type=bool): bool("False") is True, so "--online False" never worked.
parser.add_argument(
    "--online",
    type=eval,
    default=True,
    help="True = wandb online -- False = wandb offline",
)
parser.add_argument(
    "--no-cuda", action="store_true", default=False, help="disables CUDA training"
)
parser.add_argument("--save_model", type=eval, default=True, help="save model")
# NOTE(review): help text looks copy-pasted from --save_model; confirm what this controls.
parser.add_argument("--generate_epochs", type=int, default=1, help="save model")
parser.add_argument(
    "--num_workers", type=int, default=0, help="Number of worker for the dataloader"
)

parser.add_argument("--test_epochs", type=int, default=10)

parser.add_argument("--sample_eva_epochs", type=int, default=20)

parser.add_argument(
    "--data_augmentation", type=eval, default=False, help="use data augmentation"
)
parser.add_argument(
    "--conditioning",
    nargs="+",
    default=[],
    help="arguments : homo | lumo | alpha | gap | mu | Cv",
)
parser.add_argument("--resume", type=str, default=None, help="")
parser.add_argument("--start_epoch", type=int, default=0, help="")
parser.add_argument(
    "--ema_decay",
    type=float,
    default=0.999,
    help="Amount of EMA decay, 0 means off. A reasonable value is 0.999.",
)
parser.add_argument("--augment_noise", type=float, default=0)
parser.add_argument(
    "--n_stability_samples",
    type=int,
    default=500,
    help="Number of samples to compute the stability",
)
parser.add_argument(
    "--normalize_factors",
    type=eval,
    default=[1, 4, 1],
    help="normalize factors for [x, categorical, integer]",
)
parser.add_argument("--remove_h", action="store_true")
parser.add_argument(
    "--include_charges", type=eval, default=True, help="include atom charge or not"
)
# argparse does not apply `type` to non-string defaults, so the default must itself
# be an int (it was the float 1e8).
parser.add_argument(
    "--visualize_every_batch",
    type=int,
    default=100_000_000,
    help="Can be used to visualize multiple times per epoch",
)
parser.add_argument(
    "--normalization_factor",
    type=float,
    default=1,
    help="Normalize the sum aggregation of EGNN",
)
parser.add_argument(
    "--aggregation_method", type=str, default="sum", help='"sum" or "mean"'
)
parser.add_argument("--ode_method", type=str, default="euler", help='"euler" or "rk4"')

parser.add_argument(
    "--minimize_type_entropy",
    action="store_true",
    default=False,
    help="minimize_type_entropy",
)
parser.add_argument(
    "--minimize_entropy_grad_coeff",
    type=float,
    default=0.5,
    help="minimize_entropy_grad_coeff",
)
parser.add_argument(
    "--extend_feature_dim",
    type=int,
    default=0,
    help="extend_feature_dim",
)
parser.add_argument(
    "--without_cat_loss", action="store_true", help="train without categorical loss"
)

parser.add_argument(
    "--angle_penalty", action="store_true", help="train with angle penalty"
)
args = parser.parse_args()

# Resolve the wandb entity from --wandb_usr (delegated to utils.get_wandb_username).
args.wandb_usr = utils.get_wandb_username(args.wandb_usr)

# Use CUDA only when it is available and not disabled with --no-cuda.
args.cuda = not args.no_cuda and torch.cuda.is_available()

# DataParallel is always disabled by this script, regardless of --dp.
args.dp = False
if not args.cuda:
    # torch.cuda.set_device would raise here, so fall back to the CPU.
    device = torch.device("cpu")
elif not args.dp:
    # Single-process run pinned to the GPU selected with --gpu_id.
    device_id = "cuda:%s" % args.gpu_id
    torch.cuda.set_device(device_id)
    device = torch.device(device_id)
else:
    device = torch.device("cuda")

dtype = torch.float32

# On resume, the pickled training args replace the command-line ones, except for these
# fields: their command-line values are copied onto the restored namespace.
_RESUME_CLI_OVERRIDES = (
    "resume",
    "node_classifier_model_ckpt",
    "dp",
    "start_epoch",
    "wandb_usr",
    "test_epochs",
    "sample_eva_epochs",
    "output_dir",
    "cat_loss",
    "cat_loss_step",
    "sampling_method",
    "weighted_methods",
    "on_hold_batch",
    "ode_method",
    "n_stability_samples",
    "minimize_entropy_grad_coeff",
    "minimize_type_entropy",
    "gpu_id",
    "extend_feature_dim",
    "without_cat_loss",
    "angle_penalty",
    "no_wandb",
    "batch_size",
    "diffusion_steps",
)

if args.resume is not None:
    cli_args = args
    # NOTE: pickle.load can execute arbitrary code; only resume from trusted directories.
    with open(join(cli_args.resume, f"args_{cli_args.start_epoch}.pickle"), "rb") as f:
        args = pickle.load(f)

    for attr in _RESUME_CLI_OVERRIDES:
        setattr(args, attr, getattr(cli_args, attr))
    args.exp_name = cli_args.exp_name + "_resume"
    args.break_train_epoch = False
    # NOTE(review): hard-coded cluster data path; --datadir is ignored when resuming.
    args.datadir = "/sharefs/anonymous/temp"

    # Pickles from older runs may predate these options; fill them from the CLI.
    for attr in ("normalization_factor", "aggregation_method"):
        if not hasattr(args, attr):
            setattr(args, attr, getattr(cli_args, attr))

    # NOTE(review): args.cuda now comes from the pickled run, while `device` above was
    # chosen from the command line.
    print(args)

# Build dataset metadata from the *final* args. On resume, args.dataset/remove_h come
# from the pickled run, so this keeps dataset_info consistent with the args that are
# later passed to the dataloaders and the model.
dataset_info = get_dataset_info(args.dataset, args.remove_h)

atom_encoder = dataset_info["atom_encoder"]
atom_decoder = dataset_info["atom_decoder"]
# Make sure the experiment's output folders exist before anything writes into them.
utils.create_folders(args)

# --- Weights & Biases --------------------------------------------------------
# --no_wandb wins; otherwise args.online picks online vs. offline logging.
if not args.no_wandb:
    mode = "online" if args.online else "offline"
else:
    mode = "disabled"
kwargs = dict(
    entity=args.wandb_usr,
    name=args.exp_name,
    project="e3_diffusion",
    config=args,
    settings=wandb.Settings(_disable_stats=True),
    reinit=True,
    mode=mode,
)
wandb.init(**kwargs)
wandb.save("*.txt")

# --- Data --------------------------------------------------------------------
dataloaders, charge_scale = dataset.retrieve_dataloaders(args)

# One training batch is always drawn; it is only consumed when conditioning is on.
data_dummy = next(iter(dataloaders["train"]))

if len(args.conditioning) == 0:
    context_node_nf = 0
    property_norms = None
else:
    print(f"Conditioning on {args.conditioning}")
    property_norms = compute_mean_mad(dataloaders, args.conditioning, args.dataset)
    # The size of dim 2 of the prepared context gives the context feature width.
    context_dummy = prepare_context(args.conditioning, data_dummy, property_norms)
    context_node_nf = context_dummy.size(2)

args.context_node_nf = context_node_nf

# --- Model and optimizer ------------------------------------------------------
model, nodes_dist, prop_dist, deq = get_model(
    args, device, dataset_info, dataloaders["train"]
)
if prop_dist is not None:
    prop_dist.set_normalizer(property_norms)
model = model.to(device)
optim = get_optim(args, model)

# Gradient-norm history, seeded with a large value that will be flushed out later.
gradnorm_queue = utils.Queue()
gradnorm_queue.add(3000)


def sample_only_stable_different_sizes_and_save(
    args, device, dequantizer, flow, nodes_dist, dataset_info, n_samples=10, n_tries=50
):
    """Sample ``n_tries`` molecules and save ``n_samples`` of them as XYZ files.

    Molecule sizes are drawn from ``nodes_dist`` and all ``n_tries`` molecules are
    generated in a single call to ``sample``. They are then scanned in order:

    * a molecule that passes ``check_stability`` is saved;
    * an unstable molecule is saved only when skipping it would leave fewer
      molecules than files still needed.

    Exactly ``n_samples`` files are therefore written whenever
    ``n_tries >= n_samples``. Files go to
    ``<output_dir>/<exp_name>/eval/molecules/`` with the name ``molecule_stable``
    (also used for unstable fallbacks) and ids ``0 .. n_samples - 1``.

    Args:
        args: run configuration; ``output_dir`` and ``exp_name`` are read here and
            the object is forwarded to ``sample``.
        device, dequantizer, flow: forwarded to ``sample``.
        nodes_dist: object whose ``sample(n)`` gives the per-molecule atom counts.
        dataset_info: dataset metadata forwarded to sampling, checking and saving.
        n_samples: number of molecules to save. Nothing is done if ``<= 0``.
        n_tries: number of molecules to generate.

    Raises:
        ValueError: if ``n_tries < n_samples``.
    """
    if n_tries < n_samples:
        raise ValueError(
            f"n_tries ({n_tries}) must be at least n_samples ({n_samples})"
        )
    if n_samples <= 0:
        return

    nodesxsample = nodes_dist.sample(n_tries)
    one_hot, charges, x, node_mask = sample(
        args, device, dequantizer, flow, dataset_info, nodesxsample=nodesxsample
    )
    out_path = join(f"{args.output_dir}/{args.exp_name}", "eval/molecules/")

    counter = 0
    for i in range(n_tries):
        # Atom count from the mask. The slicing below assumes real atoms occupy the
        # first num_atoms slots (padding after them).
        num_atoms = int(node_mask[i : i + 1].sum().item())
        atom_type = (
            one_hot[i : i + 1, :num_atoms].argmax(2).squeeze(0).detach().cpu().numpy()
        )
        x_squeeze = x[i : i + 1, :num_atoms].squeeze(0).detach().cpu().numpy()
        mol_stable = check_stability(x_squeeze, atom_type, dataset_info)[0]

        num_remaining_attempts = n_tries - i - 1  # molecules left after this one
        num_remaining_samples = n_samples - counter  # files still needed, incl. this
        # Fall back to an unstable molecule only when the molecules left after this
        # one cannot cover the files still needed. Using `<=` here would start the
        # fallback one molecule early, so the last molecule would never be considered.
        if not (mol_stable or num_remaining_attempts < num_remaining_samples):
            continue

        if mol_stable:
            print("Found stable mol.")
        vis.save_xyz_file(
            out_path,
            one_hot[i : i + 1],
            charges[i : i + 1],
            x[i : i + 1],
            id_from=counter,
            name="molecule_stable",
            dataset_info=dataset_info,
            node_mask=node_mask[i : i + 1],
        )
        counter += 1

        if counter >= n_samples:
            break


def check_mask_correct(variables, node_mask):
    """Run ``assert_correctly_masked`` on every non-empty entry of ``variables``.

    Entries with ``len(...) == 0`` are skipped. Nothing is returned; any failure
    surfaces from ``assert_correctly_masked`` itself.
    """
    non_empty = (tensor for tensor in variables if len(tensor) > 0)
    for tensor in non_empty:
        assert_correctly_masked(tensor, node_mask)


def _load_resume_checkpoint(resume_dir, start_epoch, flow_model, optimizer, map_location):
    """Load generative-model and optimizer state from ``resume_dir`` in place.

    When ``start_epoch`` is non-zero the epoch-tagged files
    ``generative_model_<epoch>.npy`` / ``optim_<epoch>.npy`` are used; otherwise the
    untagged ``generative_model.npy`` / ``optim.npy``. ``map_location`` lets
    checkpoints saved on a different device be loaded on this one.
    """
    if start_epoch != 0:
        print("Resuming from epoch %d" % start_epoch)
        suffix = "_%d" % start_epoch
    else:
        suffix = ""

    flow_state_dict = torch.load(
        join(resume_dir, f"generative_model{suffix}.npy"), map_location=map_location
    )
    optim_state_dict = torch.load(
        join(resume_dir, f"optim{suffix}.npy"), map_location=map_location
    )
    flow_model.load_state_dict(flow_state_dict)
    optimizer.load_state_dict(optim_state_dict)


def _build_ema_model(args, base_model, map_location):
    """Return the model to sample from.

    If ``args.ema_decay > 0`` this is a deep copy of ``base_model``. When resuming
    from a non-zero ``start_epoch``, the copy is loaded with
    ``generative_model_ema_<epoch>.npy``. Otherwise ``base_model`` itself is
    returned.
    """
    if args.ema_decay > 0:
        model_ema = copy.deepcopy(base_model)
        if args.start_epoch != 0 and args.resume is not None:
            ema_state_dict = torch.load(
                join(args.resume, "generative_model_ema_%d.npy" % args.start_epoch),
                map_location=map_location,
            )
            model_ema.load_state_dict(ema_state_dict)
            print("ema model loaded")
        return model_ema
    return base_model


def main():
    """Generate and save stable molecules with the (optionally resumed) model, then visualize them.

    The steps are:

    1. If ``--resume`` is set, load the model and optimizer weights.
    2. Pick the EMA copy of the model for sampling (or the raw model if EMA is off).
    3. Save 21 molecules selected from 100 samples.
    4. Render the saved molecules via ``vis.visualize`` and log them to wandb.

    No training happens here.
    """
    if args.resume is not None:
        _load_resume_checkpoint(args.resume, args.start_epoch, model, optim, device)

    model_ema = _build_ema_model(args, model, device)

    # To additionally run the stability analysis, call train_test.analyze_and_save
    # here with model_sample=model_ema, dequantizer=deq, nodes_dist=nodes_dist,
    # dataset_info=dataset_info, device=device, prop_dist=prop_dist,
    # n_samples=args.n_stability_samples, batch_size=args.batch_size.
    sample_only_stable_different_sizes_and_save(
        args=args,
        device=device,
        dequantizer=deq,
        flow=model_ema,
        nodes_dist=nodes_dist,
        dataset_info=dataset_info,
        n_samples=21,
        n_tries=100,
    )
    vis.visualize(
        f"{args.output_dir}/{args.exp_name}/eval/molecules/",
        dataset_info=dataset_info,
        wandb=wandb,
        spheres_3d=True,
    )


if __name__ == "__main__":
    logging.set_verbosity(logging.INFO)
    main()
