#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#


import os
import os.path as osp
import torch
from random import randint
import sys
import uuid
from tqdm import tqdm
from argparse import ArgumentParser, Namespace
import numpy as np
import yaml
import json

from gaussian_renderer import render, query
from arguments import ModelParams, PipelineParams, OptimizationParams
from utils.cfg_utils import load_config, args2string
from utils.general_utils import safe_state
from scene import Scene, GaussianModel
from utils.loss_utils import l1_loss, ssim, tv_3d_loss
from utils.image_utils import metric_vol, metric_proj
from utils.plot_utils import show_two_slice

try:
    from torch.utils.tensorboard import SummaryWriter

    TENSORBOARD_FOUND = True
except ImportError:
    TENSORBOARD_FOUND = False


def training(
    dataset: ModelParams,
    opt: OptimizationParams,
    pipe: PipelineParams,
    test_every,
    save_every,
    ckpt_every,
    checkpoint,
    debug_from,
    init_from,
    args,
):
    """Optimize a Gaussian volume representation against 2D projection images.

    Reads scanner geometry (``scanner`` -> ``sVoxel``/``nVoxel``) and the scene
    bounding box (``bbox``) from ``<dataset.source_path>/meta_data.json``, builds
    the ``GaussianModel`` and ``Scene``, optionally resumes from a checkpoint and
    then runs up to ``opt.iterations`` steps. Each step:

    1. renders a randomly drawn training view,
    2. computes an L1 loss, plus optional D-SSIM and 3D total-variation terms,
    3. backpropagates,
    4. densifies/prunes,
    5. steps the optimizer.

    Checkpoints, Gaussians and evaluation reports are written under
    ``scene.model_path``.

    Args:
        dataset: Model/dataset parameters (``source_path`` is read here).
        opt: Optimization hyper-parameters (learning schedule, loss weights,
            densification settings, scale bounds).
        pipe: Rendering pipeline parameters; ``pipe.debug`` is switched on at
            iteration ``debug_from + 1``.
        test_every: Evaluation period forwarded to ``training_report``.
        save_every: Period (in iterations) for ``scene.save``; the final
            iteration is always saved.
        ckpt_every: Period (in iterations) for writing
            ``<model_path>/ckpt/chkpnt<iteration>.pth``; the final iteration
            is always checkpointed.
        checkpoint: Path to a checkpoint ``(model_params, iteration)`` to
            resume from, or a falsy value to start from scratch.
        debug_from: 0-based iteration index at which debug mode is enabled;
            a negative value never triggers.
        init_from: Initialization source forwarded to ``Scene``.
        args: Full parsed argument namespace, used for output/logging setup.
    """
    first_iter = 0
    tb_writer = prepare_output_and_logger(args)

    # Scene geometry from the dataset metadata
    with open(osp.join(dataset.source_path, "meta_data.json"), "r") as handle:
        meta_data = json.load(handle)
    scanner_cfg = meta_data["scanner"]
    bbox = torch.tensor(meta_data["bbox"])  # bbox[0] / bbox[1] used as min / max corners
    sVoxel = torch.tensor(scanner_cfg["sVoxel"])
    nVoxel = torch.tensor(scanner_cfg["nVoxel"])
    # Per-axis voxel spacing (assumes sVoxel = volume extent, nVoxel = voxel count)
    dVoxel = sVoxel / nVoxel
    # Config scale limits are relative: to the smallest voxel spacing or to
    # the smallest volume extent
    scale_min_bound = opt.scale_min_bound * float(dVoxel.min())
    max_scale = opt.max_scale * float(sVoxel.min()) if opt.max_scale else None
    scale_max_bound = opt.scale_max_bound * float(dVoxel.min())
    densify_scale_threshold = (
        opt.densify_scale_threshold * float(sVoxel.min())
        if opt.densify_scale_threshold
        else None
    )

    # Load gaussian and scene
    gaussians = GaussianModel([scale_min_bound, scale_max_bound])
    scene = Scene(
        dataset,
        gaussians,
        load_iteration=None,
        init_from=init_from,
        shuffle=False,
    )
    gaussians.training_setup(opt)
    if checkpoint:
        # Checkpoints are written below as (gaussians.capture(), iteration).
        # NOTE(review): torch.load unpickles; only load trusted checkpoints.
        (model_params, first_iter) = torch.load(checkpoint)
        gaussians.restore(model_params, opt)

    # Optional 3D total-variation regularizer on a cubic sub-volume of
    # tv_vol_size^3 voxels, using the scene's voxel spacing
    use_tv = opt.lambda_tv > 0
    if use_tv:
        print("Use total variation loss")
        tv_vol_size = opt.tv_vol_size
        tv_vol_nVoxel = torch.tensor([tv_vol_size, tv_vol_size, tv_vol_size])
        tv_vol_sVoxel = dVoxel * tv_vol_nVoxel

    # Train
    iter_start = torch.cuda.Event(enable_timing=True)
    iter_end = torch.cuda.Event(enable_timing=True)
    ckpt_save_path = osp.join(scene.model_path, "ckpt")
    os.makedirs(ckpt_save_path, exist_ok=True)
    viewpoint_stack = None
    progress_bar = tqdm(range(0, opt.iterations), desc="train")
    progress_bar.update(first_iter)
    first_iter += 1
    for iteration in range(first_iter, opt.iterations + 1):
        iter_start.record()
        gaussians.update_learning_rate(iteration)
        # Draw training views without replacement; refill once exhausted
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()
        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
        if (iteration - 1) == debug_from:
            pipe.debug = True

        render_pkg = render(viewpoint_cam, gaussians, pipe)

        image, viewspace_point_tensor, visibility_filter, radii = (
            render_pkg["render"],
            render_pkg["viewspace_points"],
            render_pkg["visibility_filter"],
            render_pkg["radii"],
        )
        gt_image = viewpoint_cam.original_image.cuda()

        # Loss: L1 + optional weighted D-SSIM + optional weighted 3D TV
        loss = {}
        render_loss = l1_loss(image, gt_image)
        loss["render"] = render_loss
        loss["total"] = render_loss
        if opt.lambda_dssim > 0:
            loss_dssim = 1.0 - ssim(image, gt_image)
            loss["dssim"] = loss_dssim
            loss["total"] = loss["total"] + opt.lambda_dssim * loss_dssim
        if use_tv:
            # Random center in [bbox0 + s/2, bbox1 - s/2], so the sampled cube
            # lies inside the bounding box
            tv_vol_center = (bbox[0] + tv_vol_sVoxel / 2) + (
                bbox[1] - tv_vol_sVoxel - bbox[0]
            ) * torch.rand(3)
            vol_pred = query(
                scene.gaussians,
                tv_vol_center,
                tv_vol_nVoxel,
                tv_vol_sVoxel,
                pipe,
            )["vol"]
            loss_tv = tv_3d_loss(vol_pred, reduction="mean")
            loss["tv"] = loss_tv
            loss["total"] = loss["total"] + opt.lambda_tv * loss_tv

        loss["total"].backward()

        iter_end.record()
        torch.cuda.synchronize()

        with torch.no_grad():
            # Adaptive control: track screen-space radii and gradient stats
            gaussians.max_radii2D[visibility_filter] = torch.max(
                gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
            )
            gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
            # Average accumulated gradients; 0/0 for never-seen points -> 0
            grads = gaussians.xyz_gradient_accum / gaussians.denom
            grads[grads.isnan()] = 0.0
            if iteration < opt.densify_until_iter:
                if (
                    iteration > opt.densify_from_iter
                    and iteration % opt.densification_interval == 0
                ):
                    gaussians.densify_and_prune(
                        grads,
                        opt.densify_grad_threshold,
                        opt.density_min_threshold,
                        opt.max_screen_size,
                        max_scale,
                        opt.max_num_gaussians,
                        densify_scale_threshold,
                        bbox,
                    )
            # Optimizer step (skipped on the final iteration)
            if iteration < opt.iterations:
                gaussians.optimizer.step()
                gaussians.optimizer.zero_grad(set_to_none=True)

            # Save checkpoints
            if iteration % ckpt_every == 0 or iteration == opt.iterations:
                tqdm.write("[ITER {}] Saving Checkpoint".format(iteration))
                torch.save(
                    (gaussians.capture(), iteration),
                    osp.join(ckpt_save_path, "chkpnt{}.pth".format(iteration)),
                )

            # Prune Gaussians whose density became NaN
            prune_mask = (torch.isnan(gaussians.get_density)).squeeze()
            num_nan = int(prune_mask.sum())  # plain int: single sync, clean message
            if num_nan > 0:
                gaussians.prune_points(prune_mask)
                tqdm.write("Prune {} gaussians because of nan.".format(num_nan))

            # Progress bar
            if iteration % 10 == 0:
                progress_bar.set_postfix({"Loss": f"{loss['total'].item():.3e}"})
                progress_bar.update(10)
            if iteration == opt.iterations:
                progress_bar.close()

            # Tensorboard stats
            metrics_train = {
                "psnr": metric_vol(gt_image, image)[0],
            }
            for param_group in gaussians.optimizer.param_groups:
                metrics_train["lr_{}".format(param_group["name"])] = param_group["lr"]

            # Log and save
            training_report(
                tb_writer,
                iteration,
                loss,
                metrics_train,
                iter_start.elapsed_time(iter_end),
                test_every,
                scene,
                render,
                query,
                pipe,
            )

            # Save gaussians
            if iteration % save_every == 0 or iteration == opt.iterations:
                tqdm.write("[ITER {}] Saving Gaussians".format(iteration))
                scene.save(iteration, query, pipe)


def prepare_output_and_logger(args):
    """Create the run's output folder, dump its config and open a logger.

    If ``args.model_path`` is empty, a folder name is derived from the
    ``OAR_JOB_ID`` environment variable (or a random UUID when that is unset
    or empty), truncated to 10 characters, under ``./output/``. The chosen
    path is written back into ``args.model_path``.

    The full argument namespace is written twice into that folder: as its
    ``repr`` (``cfg_args``) and as YAML (``cfg_args.yml``).

    Returns:
        A TensorBoard ``SummaryWriter`` on the output folder, or ``None`` when
        TensorBoard is not installed.
    """
    if not args.model_path:
        run_id = os.getenv("OAR_JOB_ID") or str(uuid.uuid4())
        args.model_path = osp.join("./output/", run_id[:10])

    out_dir = args.model_path
    print("Output folder: {}".format(out_dir))
    os.makedirs(out_dir, exist_ok=True)

    settings = vars(args)
    with open(osp.join(out_dir, "cfg_args"), "w") as handle:
        handle.write(str(Namespace(**settings)))
    with open(osp.join(out_dir, "cfg_args.yml"), "w") as handle:
        yaml.dump(settings, handle, default_flow_style=False, sort_keys=False)

    if not TENSORBOARD_FOUND:
        print("Tensorboard not available: not logging progress")
        return None

    writer = SummaryWriter(out_dir)
    writer.add_text("args", args2string(settings), global_step=0)
    return writer


def training_report(
    tb_writer,
    iteration,
    loss,
    metrics_train,
    elapsed,
    test_every,
    scene: Scene,
    renderFunc,
    queryFunc,
    renderArgs,
):
    """Log per-iteration training stats and periodically evaluate the model.

    On every call, when ``tb_writer`` is set, this logs the loss terms, the
    ``metrics_train`` entries, the iteration time and the current number of
    Gaussians.

    On iteration 1 and every ``test_every`` iterations it also evaluates into
    ``<model_path>/test/iter_XXXXXX/``:

    * 2D: renders every camera of each non-empty split (train/test), scores
      PSNR/SSIM against the ground-truth projections and writes
      ``eval2d_<split>.yml``. Scalars are logged as ``<split>/psnr_2d`` and
      ``<split>/ssim_2d``.
    * 3D: queries the volume on the scanner grid, scores PSNR/SSIM against
      ``scene.vol_gt`` and writes ``eval.yml`` and the predicted volume
      ``vol_pred.npy``.

    Args:
        tb_writer: TensorBoard ``SummaryWriter``, or ``None`` to skip logging.
        iteration: Current (1-based) training iteration.
        loss: Mapping of loss name to scalar tensor.
        metrics_train: Mapping of metric name to scalar value.
        elapsed: Iteration time, logged as-is under ``train/iter_time``.
        test_every: Evaluation period in iterations.
        scene: Scene providing cameras, Gaussians, metadata and the GT volume.
        renderFunc: Called as ``renderFunc(camera, gaussians, renderArgs)``;
            its ``"render"`` entry is used.
        queryFunc: Called as
            ``queryFunc(gaussians, offOrigin, nVoxel, sVoxel, renderArgs)``;
            its ``"vol"`` entry is used.
        renderArgs: Pipeline parameters forwarded to both functions.
    """
    if tb_writer:
        for key, value in loss.items():
            tb_writer.add_scalar(f"train/loss_{key}", value.item(), iteration)
        for key, value in metrics_train.items():
            tb_writer.add_scalar(f"train/{key}", value, iteration)
        tb_writer.add_scalar("train/iter_time", elapsed, iteration)
        tb_writer.add_scalar(
            "train/total_points", scene.gaussians.get_xyz.shape[0], iteration
        )

    if not (iteration % test_every == 0 or iteration == 1):
        return

    test_save_path = osp.join(
        scene.model_path, "test", "iter_{:06d}".format(iteration)
    )
    os.makedirs(test_save_path, exist_ok=True)
    torch.cuda.empty_cache()

    # Evaluate 2D projections per split; empty splits are skipped
    metrics_2d = {}  # split name -> (psnr_2d, ssim_2d)
    for split, cameras in (
        ("train", scene.getTrainCameras()),
        ("test", scene.getTestCameras()),
    ):
        if cameras:
            metrics_2d[split] = _evaluate_projections(
                tb_writer,
                iteration,
                split,
                cameras,
                scene,
                renderFunc,
                renderArgs,
                test_save_path,
            )

    # Evaluate 3D volume
    psnr_3d, ssim_3d = _evaluate_volume(
        tb_writer, iteration, scene, queryFunc, renderArgs, test_save_path
    )

    # Console summary: 3D metrics plus the last evaluated split's 2D metrics
    # (the test split whenever it has cameras)
    if metrics_2d:
        split, (psnr_2d, ssim_2d) = list(metrics_2d.items())[-1]
        tqdm.write(
            "[ITER {}] Evaluating {}: psnr3d {:.3f}, ssim3d {:.3f}, psnr2d {:.3f}, ssim2d {:.3f}".format(
                iteration, split, psnr_3d, ssim_3d, psnr_2d, ssim_2d
            )
        )
    else:
        tqdm.write(
            "[ITER {}] Evaluating: psnr3d {:.3f}, ssim3d {:.3f}".format(
                iteration, psnr_3d, ssim_3d
            )
        )

    # Metrics
    if tb_writer:
        for split, (psnr_2d, ssim_2d) in metrics_2d.items():
            tb_writer.add_scalar(split + "/psnr_2d", psnr_2d, iteration)
            tb_writer.add_scalar(split + "/ssim_2d", ssim_2d, iteration)
        # 3D metrics keep the historical "test/" tag prefix for log continuity
        tb_writer.add_scalar("test/psnr_3d", psnr_3d, iteration)
        tb_writer.add_scalar("test/ssim_3d", ssim_3d, iteration)

    # Only release cached GPU memory after an evaluation, not on every iteration
    torch.cuda.empty_cache()


def _evaluate_projections(
    tb_writer, iteration, split, cameras, scene, renderFunc, renderArgs, save_path
):
    """Render every camera of one split and score it against its GT image.

    Writes ``eval2d_<split>.yml`` into ``save_path`` and, when ``tb_writer`` is
    set, a strip of GT/render comparison figures tagged
    ``<split>/proj-gt_render_diff``. Returns ``(psnr_2d, ssim_2d)``.
    """
    images = []
    gt_images = []
    figures = []
    # Roughly evenly spaced camera indices to visualize (at most five)
    show_idx = np.linspace(0, len(cameras), 7).astype(int)[1:-1]
    for idx, viewpoint in enumerate(cameras):
        image = renderFunc(viewpoint, scene.gaussians, renderArgs)["render"]
        gt_image = viewpoint.original_image.to("cuda")
        images.append(image)
        gt_images.append(gt_image)
        if tb_writer and idx in show_idx:
            figures.append(
                torch.from_numpy(
                    show_two_slice(
                        gt_image[0],
                        image[0],
                        "{} gt".format(viewpoint.image_name),
                        "{} render".format(viewpoint.image_name),
                        vmin=None,
                        vmax=None,
                        save=True,
                    )
                )
            )
    # Stack along dim 0, then move that stack axis last for metric_proj
    images = torch.concat(images, 0).permute(1, 2, 0)
    gt_images = torch.concat(gt_images, 0).permute(1, 2, 0)
    psnr_2d, psnr_2d_projs = metric_proj(gt_images, images, "psnr")
    ssim_2d, ssim_2d_projs = metric_proj(gt_images, images, "ssim")
    eval_dict_2d = {
        "psnr_2d": psnr_2d,
        "ssim_2d": ssim_2d,
        "psnr_2d_projs": psnr_2d_projs,
        "ssim_2d_projs": ssim_2d_projs,
    }
    with open(osp.join(save_path, "eval2d_{}.yml".format(split)), "w") as f:
        yaml.dump(eval_dict_2d, f, default_flow_style=False, sort_keys=False)

    if tb_writer:
        image_show_2d = torch.from_numpy(np.concatenate(figures, axis=0))[
            None
        ].permute([0, 3, 1, 2])
        tb_writer.add_images(
            split + "/proj-gt_render_diff",
            image_show_2d,
            global_step=iteration,
        )
    return psnr_2d, ssim_2d


def _evaluate_volume(tb_writer, iteration, scene, queryFunc, renderArgs, save_path):
    """Query the reconstructed volume on the scanner grid and score it.

    Writes ``eval.yml`` and the predicted (clipped to [0, 1]) volume
    ``vol_pred.npy`` into ``save_path``. When ``tb_writer`` is set, it also
    logs GT/prediction slice figures tagged ``test/slice-gt_pred_diff``.
    Returns ``(psnr_3d, ssim_3d)``.
    """
    scanner_cfg = scene.meta_data["scanner"]
    query_pkg = queryFunc(
        scene.gaussians,
        scanner_cfg["offOrigin"],
        scanner_cfg["nVoxel"],
        scanner_cfg["sVoxel"],
        renderArgs,
    )
    vol_pred = query_pkg["vol"].clip(0, 1)
    vol_gt = scene.vol_gt
    psnr_3d, _ = metric_vol(vol_gt, vol_pred, "psnr")
    ssim_3d, ssim_3d_axis = metric_vol(vol_gt, vol_pred, "ssim")
    eval_dict = {
        "psnr_3d": psnr_3d,
        "ssim_3d": ssim_3d,
        "ssim_3d_x": ssim_3d_axis[0],
        "ssim_3d_y": ssim_3d_axis[1],
        "ssim_3d_z": ssim_3d_axis[2],
    }
    with open(osp.join(save_path, "eval.yml"), "w") as f:
        yaml.dump(eval_dict, f, default_flow_style=False, sort_keys=False)
    # Save the prediction (not the ground truth); detach so this also works
    # when called outside torch.no_grad()
    np.save(osp.join(save_path, "vol_pred.npy"), vol_pred.detach().cpu().numpy())

    if tb_writer:
        slice_ids = np.linspace(0, vol_gt.shape[2], 7).astype(int)[1:-1]
        image_show_3d = np.concatenate(
            [
                show_two_slice(
                    vol_gt[..., i],
                    vol_pred[..., i],
                    "slice {} gt".format(i),
                    "slice {} pred".format(i),
                    vmin=vol_gt[..., i].min(),
                    vmax=vol_gt[..., i].max(),
                    save=True,
                )
                for i in slice_ids
            ],
            axis=0,
        )
        image_show_3d = torch.from_numpy(image_show_3d)[None].permute([0, 3, 1, 2])
        tb_writer.add_images(
            "test/slice-gt_pred_diff",
            image_show_3d,
            global_step=iteration,
        )
    return psnr_3d, ssim_3d


if __name__ == "__main__":
    # fmt: off
    # CLI: model / optimization / pipeline parameter groups + script options
    parser = ArgumentParser(description="Training script parameters")
    model_group = ModelParams(parser)
    optim_group = OptimizationParams(parser)
    pipe_group = PipelineParams(parser)
    parser.add_argument("--debug_from", type=int, default=-1)
    parser.add_argument("--detect_anomaly", action="store_true", default=False)
    parser.add_argument("--test_every", type=int, default=5000)
    parser.add_argument("--save_every", type=int, default=10000)
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--ckpt_every", type=int, default=100000)
    parser.add_argument("--start_checkpoint", type=str, default=None)
    parser.add_argument("--config", type=str, default="configs/foot_cone_50.yml")
    args = parser.parse_args(sys.argv[1:])
    # fmt: on

    # Every key of the config file overwrites the matching parsed argument
    if args.config is not None:
        vars(args).update(load_config(args.config))

    # Initialize system state (RNG)
    safe_state(args.quiet)

    print("Optimizing " + args.model_path)

    torch.autograd.set_detect_anomaly(args.detect_anomaly)
    training(
        dataset=model_group.extract(args),
        opt=optim_group.extract(args),
        pipe=pipe_group.extract(args),
        test_every=args.test_every,
        save_every=args.save_every,
        ckpt_every=args.ckpt_every,
        checkpoint=args.start_checkpoint,
        debug_from=args.debug_from,
        init_from=args.init_from,
        args=args,
    )

    # All done
    print("\nTraining complete.")
