'''
Rendering code for the toy dataset: renders one object, chosen by object id
or codebook index, using a pretrained checkpoint.
'''

'''
RENDERING CODE
renders views for a given checkpoint at a specific output resolution
'''

import torch
import argparse

from nerf.provider_abo_small import NeRFDataset,MetaNeRFDataset, MetaNeRFInversionDataset
from nerf.gui import NeRFGUI
# from nerf.utils_inversion import *
from nerf.utils import *

from functools import partial
from loss import huber_loss

import gzip
import os.path as osp
import json
import os

'''
TODO: hard-code the parameter values
'''


def build_arg_parser():
    """Build the command-line parser for the rendering script.

    Returns:
        argparse.ArgumentParser: parser exposing the training/rendering
        options consumed by the script body below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('path', type=str)
    parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --preload")
    parser.add_argument('--test', action='store_true', help="test mode")
    parser.add_argument('--workspace', type=str, default='workspace')
    parser.add_argument('--class_choice', type=str, default='chair')
    parser.add_argument('--seed', type=int, default=0)

    ### training options
    parser.add_argument('--iters', type=int, default=30000, help="training iters")
    parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate")
    parser.add_argument('--ckpt', type=str, default='latest')
    parser.add_argument('--load_ckpt', action="store_true", help="if the checkpoint should not be loaded, the checkpoint would be deleted, beware!")
    parser.add_argument('--num_rays', type=int, default=4096, help="num rays sampled per image for each training step")
    parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
    parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
    parser.add_argument('--num_steps', type=int, default=512, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
    parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
    parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")

    ### network backbone options
    parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
    parser.add_argument('--ff', action='store_true', help="use fully-fused MLP")
    parser.add_argument('--tcnn', action='store_true', help="use TCNN backend")

    ### dataset options
    parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
    parser.add_argument('--preload', action='store_true', help="preload all data into GPU, accelerate training but use more GPU memory")
    # (the default value is for the fox dataset)
    parser.add_argument('--bound', type=float, default=2, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
    parser.add_argument('--scale', type=float, default=0.33, help="scale camera location into box[-bound, bound]^3")
    parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
    parser.add_argument('--dt_gamma', type=float, default=1/128, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
    parser.add_argument('--min_near', type=float, default=0.2, help="minimum near distance for camera")
    parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
    parser.add_argument('--bg_radius', type=float, default=-1, help="if positive, use a background model at sphere(bg_radius)")

    ### GUI options
    parser.add_argument('--gui', action='store_true', help="start a GUI")
    parser.add_argument('--W', type=int, default=1920, help="GUI width")
    parser.add_argument('--H', type=int, default=1080, help="GUI height")
    parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center")
    parser.add_argument('--fovy', type=float, default=50, help="default GUI camera fovy")
    parser.add_argument('--max_spp', type=int, default=64, help="GUI rendering max sample per pixel")

    ### experimental
    parser.add_argument('--error_map', action='store_true', help="use error map to sample rays")
    parser.add_argument('--clip_text', type=str, default='', help="text input for CLIP guidance")
    parser.add_argument('--rand_pose', type=int, default=-1, help="<0 uses no rand pose, =0 only uses rand pose, >0 sample one rand pose every $ known poses")
    parser.add_argument('-b', type=int, default=1, help="batch size")
    parser.add_argument('--custom_hashhn', action='store_true', help="if selected, the hypernetwork will use 5 layers and 512 dimensional MLP for predicting the hash grid")

    parser.add_argument('--clip_mapping', action='store_true', help="learn a mapping from clip space to the hypernetwork space")
    parser.add_argument('--kd_clip', action='store_true', help="knowledge distillation on clip")

    parser.add_argument('--superresolution', action='store_true', help="superresolution")
    parser.add_argument('--sr_ckpt', type=str, help="the checkpoint to load low-resolution hypernet")

    parser.add_argument('--varprior', action='store_true', help="variational prior")

    parser.add_argument('--clipcondition', action='store_true', help="clip condition")
    parser.add_argument('--invert', action='store_true', help="invert")
    parser.add_argument('--finetune', action='store_true', help="finetune")
    parser.add_argument('--multiview_inversion', action='store_true', help='invert using multiple views')
    # if not provided, the script falls back to its predefined object id
    parser.add_argument('--codebook_index', type=int, default=-1, help="Render a specific codebook index")
    return parser


def select_render_object(objects, default_object_id, codebook_index=-1):
    """Resolve which object to render and its position in ``objects``.

    The default object id is only looked up when no explicit index is given,
    so a dataset that lacks the default object can still be rendered through
    ``codebook_index``.

    Args:
        objects: list of object ids; an object's position in this list is
            what the script uses as its codebook index.
        default_object_id: object id used when ``codebook_index`` is ``-1``.
        codebook_index: explicit index into ``objects``; ``-1`` (the CLI
            default) means "use ``default_object_id``".

    Returns:
        tuple: ``(object_index, object_id)``.

    Raises:
        IndexError: ``codebook_index`` does not address an entry of ``objects``.
        ValueError: no explicit index was given and ``default_object_id`` is
            not in ``objects``.
    """
    if codebook_index != -1:
        try:
            return codebook_index, objects[codebook_index]
        except IndexError:
            raise IndexError(
                f"codebook index {codebook_index} is out of range for "
                f"{len(objects)} objects"
            ) from None

    try:
        return objects.index(default_object_id), default_object_id
    except ValueError:
        raise ValueError(
            f"default object id {default_object_id!r} was not found among "
            f"{len(objects)} objects; pass --codebook_index to pick another one"
        ) from None


if __name__ == '__main__':

    opt = build_arg_parser().parse_args()
    # NOTE(review): `-O` is parsed but never expanded into --fp16/--cuda_ray/--preload
    # in this script, despite its help text — confirm whether that is intended.

    if opt.patch_size > 1:
        opt.error_map = False # do not use error_map if use patch-based training
        # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
        assert opt.num_rays % (opt.patch_size ** 2) == 0, "num_rays should be divisible by patch_size ** 2."

    # Pick the network implementation; the fully-fused and TCNN backends force fp16.
    if opt.ff:
        opt.fp16 = True
        assert opt.bg_radius <= 0, "background model is not implemented for --ff"
        from nerf.network_ff import NeRFNetwork
    elif opt.tcnn:
        opt.fp16 = True
        assert opt.bg_radius <= 0, "background model is not implemented for --tcnn"
        from nerf.network_tcnn import NeRFNetwork
    else:
        from nerf.network_fcblock import NeRFNetwork,NeRFGen,NeRFSuperresolution

    # WARNING: unless --load_ckpt is passed, every checkpoint stored under
    # <workspace>/checkpoints is removed before anything else happens.
    checkpoints_path = os.path.join(opt.workspace, "checkpoints")
    if not opt.load_ckpt and os.path.exists(checkpoints_path):
        import shutil
        shutil.rmtree(checkpoints_path)
        print("Deleted previous checkpoints!")

    print(f"Options: {opt}")

    # seed_everything(opt.seed)

    criterion = torch.nn.MSELoss(reduction='none')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    opt.device = device

    # Every entry of <path>/abo_small is treated as one object.
    root_path = osp.join(opt.path, 'abo_small')
    objects = os.listdir(root_path + '/')

    test_dataset = MetaNeRFInversionDataset(objects, opt, device=device, type='test')
    test_loader = DataLoader(test_dataset, batch_size=opt.b, shuffle=False, num_workers=20)

    # The generator is sized by the number of objects found on disk.
    model = NeRFGen(opt, len(objects), custom_hashhn=opt.custom_hashhn)

    # TODO posed input image that was used

    # input_ckpts_path = os.path.join(opt.workspace, 'checkpoints','abo_chairs',object_id)
    # NOTE(review): the pretrained checkpoint path is left empty here — fill in before use.
    input_ckpts_path = ''
    print(f'Loading pretrained checkpoint for inversion : {input_ckpts_path}')     # provided in a certain dir

    # Only the shape and color codes are handed to the optimizer.
    # optimizer = lambda model: torch.optim.Adam(model.parameters(), betas=(0.9, 0.99), eps=1e-15)
    optimizer = lambda model: torch.optim.Adam(
        list(model.shape_code.parameters()) + list(model.color_code.parameters()),
        betas=(0.9, 0.99), eps=1e-15)

    # decay to 0.1 * init_lr at last iter step
    scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(
        optimizer, lambda step: 0.1 ** min(step / opt.iters, 1))

    metrics = [PSNRMeter(), LPIPSMeter(device=device)]

    num_epochs = 10

    trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer,
        criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler,
        scheduler_update_every_step=True, metrics=metrics, use_checkpoint=opt.ckpt, eval_interval=num_epochs,checkpoint=input_ckpts_path)

    # Object rendered when --codebook_index is not supplied.
    default_object_id = 'B07HSBJ7J9'

    if opt.codebook_index != -1:
        print(f'Rendering user defined codebook index : {opt.codebook_index}')

    object_index, object_id = select_render_object(objects, default_object_id, opt.codebook_index)

    print(f'Rendering codebook index : {object_index} for object {object_id}')

    # run the test code for generating the images from the 91 different poses
    # and save them to the path
    test_loader.dataset.object_index = object_index

    # construct the save_path where the rendered images are saved
    save_path = osp.join(opt.workspace, f'results/{object_id}')

    print(f'Starting rendering')

    # code for rendering the loaded nerf checkpoint from multiple predefined input poses
    trainer.test(test_loader, save_path=save_path, write_video=False)

    # TODO save the output as a flythrough video - using different poses