# Rdkit import should be first, do not move it
try:
    from rdkit import Chem
except ModuleNotFoundError:
    pass
import build_geom_dataset
from configs.datasets_config import geom_with_h
import copy
import utils
import argparse
import wandb
from os.path import join
from qm9.models import get_optim, get_model
from equivariant_diffusion import en_diffusion

from equivariant_diffusion import utils as flow_utils
import torch
import time
import pickle

from qm9.utils import prepare_context, compute_mean_mad
from train_test import train_epoch, test, analyze_and_save, data_callback

# Command-line interface for GEOM-Drugs training. Boolean-valued options use
# ``type=eval`` so that ``--flag False`` / ``--flag True`` parse as Python bools
# (``type=bool`` would turn any non-empty string, including "False", into True).
parser = argparse.ArgumentParser(description="E3Diffusion")
parser.add_argument("--exp_name", type=str, default="debug_10")
parser.add_argument("--gpu_id", type=str, default="0")
parser.add_argument(
    "--model",
    type=str,
    default="egnn_dynamics",
    help="our_dynamics | schnet | simple_dynamics | "
    "kernel_dynamics | egnn_dynamics |gnn_dynamics",
)
parser.add_argument(
    "--probabilistic_model", type=str, default="diffusion", help="diffusion|flow_match"
)

# Training complexity is O(1) (unaffected), but sampling complexity O(steps).
parser.add_argument("--diffusion_steps", type=int, default=500)
parser.add_argument(
    "--diffusion_noise_schedule",
    type=str,
    default="polynomial_2",
    help="learned, cosine",
)
parser.add_argument("--diffusion_loss_type", type=str, default="l2", help="vlb, l2")
parser.add_argument("--diffusion_noise_precision", type=float, default=1e-5)

parser.add_argument(
    "--discrete_path", type=str, default="OT_path", help="OT_path, HB_path, VP_path"
)

parser.add_argument("--cat_loss_step", type=float, default=-1)

parser.add_argument("--cat_loss", type=str, default="l2", help='"l2" or "cse"')

parser.add_argument("--on_hold_batch", type=int, default=-1)

parser.add_argument("--sampling_method", type=str, default="vanilla")
parser.add_argument("--weighted_methods", type=str, default="jump")
parser.add_argument("--ode_method", type=str, default="dopri5")

parser.add_argument(
    "--output_dir",
    type=str,
    default="outputs",
    help="outputs | /sharefs/anonymous/edm_output/outputs",
)

parser.add_argument("--n_epochs", type=int, default=10000)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--lr", type=float, default=5e-5)
parser.add_argument(
    "--break_train_epoch", type=eval, default=False, help="True | False"
)
parser.add_argument("--dp", type=eval, default=True, help="True | False")
parser.add_argument("--condition_time", type=eval, default=True, help="True | False")
parser.add_argument("--clip_grad", type=eval, default=True, help="True | False")
parser.add_argument("--trace", type=str, default="hutch", help="hutch | exact")
# EGNN args -->
parser.add_argument("--n_layers", type=int, default=6, help="number of layers")
parser.add_argument(
    "--inv_sublayers", type=int, default=1, help="number of invariant sublayers"
)
parser.add_argument("--nf", type=int, default=192, help="number of hidden features")
parser.add_argument("--tanh", type=eval, default=True, help="use tanh in the coord_mlp")
parser.add_argument(
    "--attention", type=eval, default=True, help="use attention in the EGNN"
)
parser.add_argument(
    "--norm_constant", type=float, default=1, help="diff/(|diff| + norm_constant)"
)
parser.add_argument(
    "--sin_embedding",
    type=eval,
    default=False,
    help="whether using or not the sin embedding",
)
# <-- EGNN args
parser.add_argument("--ode_regularization", type=float, default=1e-3)
parser.add_argument("--dataset", type=str, default="geom", help="dataset name")
parser.add_argument(
    "--filter_n_atoms",
    type=int,
    default=None,
    help="When set to an integer value, QM9 will only contain molecules of that amount of atoms",
)
parser.add_argument(
    "--dequantization",
    type=str,
    default="argmax_variational",
    help="uniform | variational | argmax_variational | deterministic",
)
parser.add_argument("--n_report_steps", type=int, default=50)
parser.add_argument("--wandb_usr", type=str)
parser.add_argument("--no_wandb", action="store_true", help="Disable wandb")
parser.add_argument(
    "--online",
    # eval, not bool: bool("False") is True, which made offline mode unreachable.
    type=eval,
    default=True,
    help="True = wandb online -- False = wandb offline",
)
parser.add_argument(
    "--no-cuda", action="store_true", default=False, help="disable CUDA training"
)
parser.add_argument("--save_model", type=eval, default=True, help="save model")
parser.add_argument("--generate_epochs", type=int, default=1)
parser.add_argument(
    "--num_workers", type=int, default=0, help="Number of worker for the dataloader"
)
parser.add_argument("--test_epochs", type=int, default=1)
parser.add_argument(
    "--data_augmentation", type=eval, default=False, help="use data augmentation"
)
parser.add_argument("--sample_eva_epochs", type=int, default=20)

parser.add_argument(
    "--conditioning",
    nargs="+",
    default=[],
    help="multiple arguments can be passed, "
    "including: homo | onehot | lumo | num_atoms | etc. "
    'usage: "--conditioning H_thermo homo onehot H_thermo"',
)
parser.add_argument("--resume", type=str, default=None, help="")
parser.add_argument("--start_epoch", type=int, default=0, help="")
parser.add_argument(
    "--ema_decay",
    type=float,
    default=0,  # TODO
    help="Amount of EMA decay, 0 means off. A reasonable value" " is 0.999.",
)
parser.add_argument("--augment_noise", type=float, default=0)
parser.add_argument(
    "--n_stability_samples",
    type=int,
    default=20,
    help="Number of samples to compute the stability",
)
parser.add_argument(
    "--normalize_factors",
    type=eval,
    default=[1, 4, 10],
    help="normalize factors for [x, categorical, integer]",
)
parser.add_argument("--remove_h", action="store_true")
parser.add_argument(
    "--include_charges", type=eval, default=False, help="include atom charge or not"
)
parser.add_argument("--visualize_every_batch", type=int, default=5000)
parser.add_argument(
    "--normalization_factor",
    type=float,
    default=100,
    help="Normalize the sum aggregation of EGNN",
)
parser.add_argument(
    "--aggregation_method",
    type=str,
    default="sum",
    help='"sum" or "mean" aggregation for the graph network',
)
parser.add_argument(
    "--filter_molecule_size",
    type=int,
    default=None,
    help="Only use molecules below this size.",
)

parser.add_argument(
    "--without_cat_loss", action="store_true", help="train without categorical loss"
)

parser.add_argument("--node_classifier_model_ckpt", type=str)

parser.add_argument(
    "--angle_penalty", action="store_true", help="train with angle penalty"
)

parser.add_argument(
    "--sequential",
    action="store_true",
    help="Organize data by size to reduce average memory usage.",
)
parser.add_argument("--extend_feature_dim", type=int, default=0)

args = parser.parse_args()

# NOTE(review): hard-coded dataset location; consider exposing it as a CLI flag.
data_file = "/sharefs/anonymous/geom_drugs_30.npy"

if args.remove_h:
    raise NotImplementedError("--remove_h is not supported for GEOM-Drugs training")
dataset_info = geom_with_h

# CUDA is used only when available and not disabled with --no-cuda.
args.cuda = not args.no_cuda and torch.cuda.is_available()
if not args.dp and args.cuda:
    # Without DataParallel, pin the run to the single GPU selected by --gpu_id.
    # (Previously this branch forced CUDA even under --no-cuda / without a GPU.)
    device_id = "cuda:%s" % args.gpu_id
    torch.cuda.set_device(device_id)
    device = torch.device(device_id)
else:
    device = torch.device("cuda" if args.cuda else "cpu")
dtype = torch.float32

# Split the data file into train / val / test with 10% validation and 10% test.
split_data = build_geom_dataset.load_split_data(
    data_file,
    val_proportion=0.1,
    test_proportion=0.1,
    filter_size=args.filter_molecule_size,
)
transform = build_geom_dataset.GeomDrugsTransform(
    dataset_info, args.include_charges, device, args.sequential
)
dataloaders = {}
for key, data_list in zip(["train", "val", "test"], split_data):
    dataset = build_geom_dataset.GeomDrugsDataset(data_list, transform=transform)
    # Only the training split is shuffled, and not when --sequential orders data by size.
    shuffle = key == "train" and not args.sequential
    dataloaders[key] = build_geom_dataset.GeomDrugsDataLoader(
        sequential=args.sequential,
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
    )
# Drop the module-level reference to the raw splits.
del split_data

atom_encoder = dataset_info["atom_encoder"]
atom_decoder = dataset_info["atom_decoder"]

args.wandb_usr = utils.get_wandb_username(args.wandb_usr)

# When resuming, these command-line values take precedence over the values
# stored in the checkpoint's pickled args.
_RESUME_CLI_OVERRIDES = (
    "resume",
    "dp",
    "start_epoch",
    "wandb_usr",
    "test_epochs",
    "sample_eva_epochs",
    "output_dir",
    "cat_loss",
    "cat_loss_step",
    "sampling_method",
    "weighted_methods",
    "on_hold_batch",
    "ode_method",
    "visualize_every_batch",
)
# Taken from the command line only if the checkpoint's args lack them
# (i.e. the checkpoint predates these options). Careful with this.
_RESUME_CLI_FALLBACKS = ("normalization_factor", "aggregation_method")

if args.resume is not None:
    cli_args = args
    checkpoint_args_path = join(
        cli_args.resume, "args_%d.pickle" % cli_args.start_epoch
    )
    # NOTE(review): pickle.load can execute arbitrary code; only resume from
    # checkpoint directories you trust.
    with open(checkpoint_args_path, "rb") as f:
        args = pickle.load(f)

    for name in _RESUME_CLI_OVERRIDES:
        setattr(args, name, getattr(cli_args, name))
    args.exp_name = cli_args.exp_name + "_resume"
    args.break_train_epoch = False

    for name in _RESUME_CLI_FALLBACKS:
        if not hasattr(args, name):
            setattr(args, name, getattr(cli_args, name))


utils.create_folders(args)
print(args)

# Wandb config: --no_wandb disables logging entirely, otherwise --online picks the mode.
mode = "disabled" if args.no_wandb else ("online" if args.online else "offline")
kwargs = dict(
    entity=args.wandb_usr,
    name=args.exp_name,
    project="e3_diffusion_geom",
    config=args,
    settings=wandb.Settings(_disable_stats=True),
    reinit=True,
    mode=mode,
    dir="/sharefs/anonymous/logs",
)
wandb.init(**kwargs)
wandb.save("*.txt")

# One training batch, used below to size the conditioning context.
data_dummy = next(iter(dataloaders["train"]))

if not args.conditioning:
    context_node_nf = 0
    property_norms = None
else:
    print(f"Conditioning on {args.conditioning}")
    property_norms = compute_mean_mad(dataloaders, args.conditioning)
    context_dummy = prepare_context(args.conditioning, data_dummy, property_norms)
    # Number of per-node context features (third dimension of the context tensor).
    context_node_nf = context_dummy.size(2)
args.context_node_nf = context_node_nf

# Create EGNN flow: model, node-count distribution, property distribution, dequantizer.
model, nodes_dist, prop_dist, deq = get_model(
    args, device, dataset_info, dataloaders["train"]
)
print("==" * 20)
print(nodes_dist)

if prop_dist is not None:
    prop_dist.set_normalizer(property_norms)

model = model.to(device)
deq = deq.to(device)
optim = get_optim(args, model)

# Gradient-norm history consumed by training; seeded with a large value that will be flushed.
gradnorm_queue = utils.Queue()
gradnorm_queue.add(3000)


def _load_checkpoint():
    """Restore generative-model and optimizer weights from ``args.resume``.

    Loads ``generative_model_<start_epoch>.npy`` / ``optim_<start_epoch>.npy``
    when ``args.start_epoch`` is non-zero, otherwise the best-validation
    ``generative_model.npy`` / ``optim.npy``. Does nothing if ``args.resume``
    is None.
    """
    if args.resume is None:
        return
    if args.start_epoch != 0:
        print("Resuming from epoch %d" % args.start_epoch)
        suffix = "_%d" % args.start_epoch
    else:
        suffix = ""
    # map_location lets a checkpoint written on another device be loaded here.
    flow_state_dict = torch.load(
        join(args.resume, "generative_model%s.npy" % suffix), map_location=device
    )
    optim_state_dict = torch.load(
        join(args.resume, "optim%s.npy" % suffix), map_location=device
    )
    model.load_state_dict(flow_state_dict)
    optim.load_state_dict(optim_state_dict)


def _save_checkpoint(model_ema, suffix=""):
    """Save optimizer, model, optional EMA model and args for this run.

    Files go to ``<output_dir>/<exp_name>/`` as ``optim<suffix>.npy``,
    ``generative_model<suffix>.npy``, ``generative_model_ema<suffix>.npy``
    (only when ``args.ema_decay > 0``) and ``args<suffix>.pickle``.
    """
    run_dir = "%s/%s" % (args.output_dir, args.exp_name)
    utils.save_model(optim, "%s/optim%s.npy" % (run_dir, suffix))
    utils.save_model(model, "%s/generative_model%s.npy" % (run_dir, suffix))
    if args.ema_decay > 0:
        utils.save_model(
            model_ema, "%s/generative_model_ema%s.npy" % (run_dir, suffix)
        )
    with open("%s/args%s.pickle" % (run_dir, suffix), "wb") as f:
        pickle.dump(args, f)


def main():
    """Run the training loop with periodic sampling, evaluation and checkpointing.

    Every ``args.sample_eva_epochs`` epochs: run ``analyze_and_save`` (unless
    ``args.break_train_epoch``) and, if ``args.save_model``, save an
    epoch-tagged checkpoint. Every ``args.test_epochs`` epochs (skipping
    epoch 0): evaluate on the val and test splits, log to wandb, and save the
    checkpoint with the best validation loss.
    """
    _load_checkpoint()

    # Initialize dataparallel if enabled and possible.
    print(args.dp, "parallel or not")
    multi_gpu = args.dp and torch.cuda.device_count() > 1

    if multi_gpu:
        print(f"Training using {torch.cuda.device_count()} GPUs")
        model_dp = torch.nn.DataParallel(model.cpu())
        model_dp = model_dp.cuda()
        # The dequantizer is not wrapped: only ``deq`` itself is passed on below.
    else:
        model_dp = model

    # Initialize model copy for exponential moving average of params.
    if args.ema_decay > 0:
        model_ema = copy.deepcopy(model)
        if args.start_epoch != 0 and args.resume is not None:
            model_ema_state_dict = torch.load(
                join(args.resume, "generative_model_ema_%d.npy" % args.start_epoch),
                map_location=device,
            )
            model_ema.load_state_dict(model_ema_state_dict)
            print("ema model loaded")

        ema = flow_utils.EMA(args.ema_decay)
        model_ema_dp = torch.nn.DataParallel(model_ema) if multi_gpu else model_ema
    else:
        # Without EMA, the "EMA" model is simply the trained model.
        ema = None
        model_ema = model
        model_ema_dp = model_dp

    best_nll_val = 1e8
    best_nll_test = 1e8
    for epoch in range(args.start_epoch, args.n_epochs):
        epoch_start_time = time.time()
        train_epoch(
            args=args,
            deq=deq,
            loader=dataloaders["train"],
            epoch=epoch,
            model=model,
            model_dp=model_dp,
            model_ema=model_ema,
            ema=ema,
            device=device,
            dtype=dtype,
            property_norms=property_norms,
            nodes_dist=nodes_dist,
            dataset_info=dataset_info,
            gradnorm_queue=gradnorm_queue,
            optim=optim,
            prop_dist=prop_dist,
        )
        print(f"Epoch took {time.time() - epoch_start_time:.1f} seconds.")

        if epoch % args.sample_eva_epochs == 0:
            if not args.break_train_epoch:
                analyze_and_save(
                    args=args,
                    epoch=epoch,
                    model_sample=model_ema,
                    dequantizer=deq,
                    nodes_dist=nodes_dist,
                    dataset_info=dataset_info,
                    device=device,
                    prop_dist=prop_dist,
                    n_samples=args.n_stability_samples,
                )
            if args.save_model:
                _save_checkpoint(model_ema, suffix="_%d" % epoch)

        if epoch % args.test_epochs == 0 and epoch != 0:
            if isinstance(model, en_diffusion.EnVariationalDiffusion):
                wandb.log(model.log_info(), commit=True)

            # The dataloaders are keyed "train" / "val" / "test" (previously
            # "valid" here, which raised KeyError at the first evaluation).
            nll_val = test(
                args=args,
                deq=deq,
                loader=dataloaders["val"],
                epoch=epoch,
                eval_model=model_ema_dp,
                partition="Val",
                device=device,
                dtype=dtype,
                nodes_dist=nodes_dist,
                property_norms=property_norms,
            )
            nll_test = test(
                args=args,
                deq=deq,
                loader=dataloaders["test"],
                epoch=epoch,
                eval_model=model_ema_dp,
                partition="Test",
                device=device,
                dtype=dtype,
                nodes_dist=nodes_dist,
                property_norms=property_norms,
            )

            if nll_val < best_nll_val:
                best_nll_val = nll_val
                best_nll_test = nll_test
                if args.save_model:
                    args.current_epoch = epoch + 1
                    _save_checkpoint(model_ema)

            print("Val loss: %.4f \t Test loss:  %.4f" % (nll_val, nll_test))
            print(
                "Best val loss: %.4f \t Best test loss:  %.4f"
                % (best_nll_val, best_nll_test)
            )
            wandb.log({"Val loss ": nll_val}, commit=True)
            wandb.log({"Test loss ": nll_test}, commit=True)
            wandb.log({"Best cross-validated test loss ": best_nll_test}, commit=True)


if __name__ == "__main__":
    main()
