'''
INVERSION CODE
This script takes an object id and a checkpoint as input and renders the NeRF at a fixed resolution.
The input pose is taken from one of the predefined camera poses.

Output (at a fixed render resolution):
[i] rendered images from different camera poses
[ii] video
[iii] the posed input image that was used
'''

import torch
import argparse

from nerf.provider_abo import NeRFDataset,MetaNeRFDataset, MetaNeRFInversionDataset
from nerf.gui import NeRFGUI
# from nerf.utils_inversion import *
from nerf.utils import *

from functools import partial
from loss import huber_loss

import gzip
import os.path as osp
import json
import os

'''
Load all the objects in the dataset for inversion.
'''
def init(root_path, listing_path, split_type="test"):
    """Load the dataset split, the rendered-object listing and the ABO listing metadata.

    Args:
        root_path: directory containing ``train_test_split.csv`` and one
            sub-directory per rendered object.
        listing_path: directory containing the gzipped listing shards
            ``listings_0.json.gz`` ... ``listings_f.json.gz`` (one JSON value
            per line).
        split_type: ``'train'`` selects the rows tagged TRAIN; any other value
            selects the rows tagged TEST.

    Returns:
        ``(metadata, types, objects, split)``:
        ``metadata`` is the list of listing dicts that carry an ``'item_id'``;
        ``types`` is the list of those item ids, index-aligned with
        ``metadata`` (``types[i]`` is ``metadata[i]['item_id']``);
        ``objects`` is the sorted list of entries of ``root_path`` that contain
        a ``metadata.json`` file; ``split`` is the list of ids (first CSV
        column) of the selected split.
    """
    with open(os.path.join(root_path, "train_test_split.csv")) as f:
        rows = f.read().splitlines()
    train_split = [row.split(",")[0] for row in rows if "TRAIN" in row]
    test_split = [row.split(",")[0] for row in rows if "TEST" in row]

    print(len(train_split))

    split = train_split if split_type == 'train' else test_split

    # Keep only entries that contain a metadata.json file. A new list is built
    # instead of calling objects.remove() while iterating over objects (which
    # skips the element following every removal). The bookkeeping files are
    # excluded by name independently of each other, so one missing file no
    # longer prevents the others from being dropped.
    non_objects = {'README.md', 'test_sample_idx.json',
                   'train_sample_idx.json', 'train_test_split.csv'}
    objects = [
        name for name in os.listdir(root_path)
        if name not in non_objects
        and os.path.exists(os.path.join(root_path, name, 'metadata.json'))
    ]

    print(f'Total number of files with a valid metadata.json file : {len(objects)}')

    # Read the zipped json listing shards (suffixes 0-9 and a-f).
    load_all = True
    shards = "0123456789abcdef" if load_all else "0"
    metadata = []
    for shard in shards:
        shard_path = os.path.join(listing_path, f"listings_{shard}.json.gz")
        with gzip.open(shard_path, mode="r") as f:
            metadata.extend(json.loads(line) for line in f)

    if load_all:
        print(f'All files loaded successfully')

    objects.sort()

    # Keep metadata and types index-aligned: get_class_objects reads
    # metadata[types.index(item_id)], so an entry without an 'item_id' must be
    # dropped from both lists, not only from types.
    valid_metadata = []
    types = []
    for d in metadata:
        if isinstance(d, dict) and 'item_id' in d:
            valid_metadata.append(d)
            types.append(d['item_id'])
        else:
            print(d)
    metadata = valid_metadata

    print(f"Total objects in ABO Dataset: {len(types)}")
    print(f"Total rendered objects {len(objects)}")

    return metadata, types, objects, split

def get_class_objects(metadata, types, objects, split, class_choice):
    """Return the objects that are in ``split`` and whose product type is ``class_choice``.

    Args:
        metadata: listing dicts, index-aligned with ``types``.
        types: item ids; ``types[i]`` identifies ``metadata[i]``.
        objects: candidate object ids, in the order they should be returned.
        split: ids belonging to the selected split.
        class_choice: product type to keep, compared case-insensitively with
            ``metadata[i]['product_type'][0]['value']``.

    Returns:
        The ids from ``objects`` (original order preserved) that appear in
        both ``types`` and ``split`` and match ``class_choice``. When an id
        occurs several times in ``types``, its first occurrence is used
        (same as ``list.index``).
    """
    # Build the lookups once. Testing membership / calling .index() on the
    # lists for every object was O(len(objects) * len(types)).
    first_index = {}
    for i, item_id in enumerate(types):
        first_index.setdefault(item_id, i)
    split_ids = set(split)
    wanted = class_choice.lower()

    filtered_objects = []
    for val in objects:
        idx = first_index.get(val)
        if idx is None or val not in split_ids:
            continue
        if metadata[idx]['product_type'][0]['value'].lower() == wanted:
            filtered_objects.append(val)

    print(f"Total rendered {class_choice} objects: {len(filtered_objects)}")

    return filtered_objects

'''
TODO hardcode the parameter values 
'''


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('path', type=str)
    parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --preload")
    parser.add_argument('--test', action='store_true', help="test mode")
    parser.add_argument('--workspace', type=str, default='workspace')
    parser.add_argument('--class_choice', type=str, default='chair')
    parser.add_argument('--seed', type=int, default=0)

    ### training options
    parser.add_argument('--iters', type=int, default=30000, help="training iters")
    parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate")
    parser.add_argument('--ckpt', type=str, default='latest')
    parser.add_argument('--load_ckpt', action="store_true", help="if the checkpoint should not be loaded, the checkpoint would be deleted, beware!")
    parser.add_argument('--num_rays', type=int, default=4096, help="num rays sampled per image for each training step")
    parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
    parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
    parser.add_argument('--num_steps', type=int, default=512, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
    parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
    parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")

    ### network backbone options
    parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
    parser.add_argument('--ff', action='store_true', help="use fully-fused MLP")
    parser.add_argument('--tcnn', action='store_true', help="use TCNN backend")

    ### dataset options
    parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
    parser.add_argument('--preload', action='store_true', help="preload all data into GPU, accelerate training but use more GPU memory")
    # (the default value is for the fox dataset)
    parser.add_argument('--bound', type=float, default=2, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
    parser.add_argument('--scale', type=float, default=0.33, help="scale camera location into box[-bound, bound]^3")
    parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
    parser.add_argument('--dt_gamma', type=float, default=1/128, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
    parser.add_argument('--min_near', type=float, default=0.2, help="minimum near distance for camera")
    parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
    parser.add_argument('--bg_radius', type=float, default=-1, help="if positive, use a background model at sphere(bg_radius)")

    ### GUI options
    parser.add_argument('--gui', action='store_true', help="start a GUI")
    parser.add_argument('--W', type=int, default=1920, help="GUI width")
    parser.add_argument('--H', type=int, default=1080, help="GUI height")
    parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center")
    parser.add_argument('--fovy', type=float, default=50, help="default GUI camera fovy")
    parser.add_argument('--max_spp', type=int, default=64, help="GUI rendering max sample per pixel")

    ### experimental
    parser.add_argument('--error_map', action='store_true', help="use error map to sample rays")
    parser.add_argument('--clip_text', type=str, default='', help="text input for CLIP guidance")
    parser.add_argument('--rand_pose', type=int, default=-1, help="<0 uses no rand pose, =0 only uses rand pose, >0 sample one rand pose every $ known poses")
    parser.add_argument('-b', type=int, default=1, help="batch size")
    parser.add_argument('--custom_hashhn', action='store_true', help="if selected, the hypernetwork will use 5 layers and 512 dimensional MLP for predicting the hash grid")

    parser.add_argument('--clip_mapping', action='store_true', help="learn a mapping from clip space to the hypernetwork space")
    parser.add_argument('--kd_clip', action='store_true', help="knowledge distillation on clip")

    parser.add_argument('--superresolution', action='store_true', help="superresolution")
    parser.add_argument('--sr_ckpt', type=str, help="the checkpoint to load low-resolution hypernet")

    parser.add_argument('--varprior', action='store_true', help="variational prior")

    parser.add_argument('--clipcondition', action='store_true', help="clip condition")
    parser.add_argument('--invert', action='store_true', help="invert")
    parser.add_argument('--finetune', action='store_true', help="finetune")

    parser.add_argument('--multiview_inversion', action='store_true', help='invert using multiple views')

    opt = parser.parse_args()

    # -O is documented as shorthand for --fp16 --cuda_ray --preload; without
    # this expansion the flag had no effect at all.
    if opt.O:
        opt.fp16 = True
        opt.cuda_ray = True
        opt.preload = True

    if opt.patch_size > 1:
        opt.error_map = False # do not use error_map if use patch-based training
        # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
        assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."

    # NOTE(review): only the default branch imports NeRFGen, which is what the
    # model below is built from; with --ff / --tcnn the NeRFGen(...) call will
    # raise NameError unless those modules are extended.
    if opt.ff:
        opt.fp16 = True
        assert opt.bg_radius <= 0, "background model is not implemented for --ff"
        from nerf.network_ff import NeRFNetwork
    elif opt.tcnn:
        opt.fp16 = True
        assert opt.bg_radius <= 0, "background model is not implemented for --tcnn"
        from nerf.network_tcnn import NeRFNetwork
    else:
        from nerf.network_fcblock import NeRFNetwork,NeRFGen,NeRFSuperresolution

    # Without --load_ckpt any existing checkpoints in the workspace are deleted.
    checkpoints_path = os.path.join(opt.workspace, "checkpoints")
    if not opt.load_ckpt and os.path.exists(checkpoints_path):
        import shutil
        shutil.rmtree(checkpoints_path)
        print("Deleted previous checkpoints!")

    print(f"Options: {opt}")

    # seed_everything(opt.seed)

    criterion = torch.nn.MSELoss(reduction='none')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    opt.device = device

    print(f'Running init')

    root_path = ''
    listing_path = ''

    metadata, types, objects, split = init(root_path, listing_path, split_type="test")

    # class_choices = ['CHAIR', 'TABLE', 'SOFA']
    class_choices = ['CHAIR']

    # Only the first class choice is inverted at the moment.
    # get objects belonging to a given class
    objects = get_class_objects(metadata, types, objects, split, class_choices[0])

    chair_num_epochs = 1800

    num_epochs = chair_num_epochs

    # TODO bound and scale_factor needs to be hardcoded for object indices

    # DataLoader, optim, Trainer and the meters are presumably provided by
    # `from nerf.utils import *` at the top of the file.
    train_dataset = MetaNeRFInversionDataset(objects, opt, device=device, type='train')
    val_dataset = MetaNeRFInversionDataset(objects, opt, device=device, type='val')
    test_dataset = MetaNeRFInversionDataset(objects, opt, device=device, type='test')

    train_loader = DataLoader(train_dataset, batch_size=opt.b, shuffle=True, num_workers=20)
    valid_loader = DataLoader(val_dataset, batch_size=opt.b, shuffle=True, num_workers=20)
    test_loader = DataLoader(test_dataset, batch_size=opt.b, shuffle=False, num_workers=20)

    num_examples = len(train_dataset)

    chair_instances = 1038

    instances = chair_instances

    model = NeRFGen(opt, instances, custom_hashhn=opt.custom_hashhn)

    print(model)
    print(f"Number of training exmaples: {len(train_dataset)}")

    # for the specific object id that is provided as the input -> get the object index
    # object_id = '' # provided as the input
    object_ids = ['B07B4D49HD', 'B07B4M68LW', 'B07THSVKCJ', 'B07QBMQCP1', 'B0746H4BP6', 'B075YQXR2Y', 'B07BWJMB9F', 'B075X52BMR', 'B07HSBJ7J9', 'B082VLJR5V']

    # The object and pose are drawn from an RNG seeded with --seed (which was
    # otherwise unused) so a run can be reproduced; pass a different --seed to
    # sample a different object/pose. A private Random instance leaves the
    # global `random` state untouched.
    import random
    rng = random.Random(opt.seed)
    object_id = rng.choice(object_ids)

    object_poses = [4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 16, 18, 19, 21, 22, 23, 25, 26, 28, 29, 30, 31, 33, 35, 36, 37, 48, 49, 50, 52, 53, 54, 56, 57, 58, 59, 77, 82]
    pose_index = rng.choice(object_poses)

    # TODO posed input image that was used

    # input_ckpts_path = os.path.join(opt.workspace, 'checkpoints','abo_chairs',object_id)
    input_ckpts_path = ''
    print(f'Loading pretrained checkpoint for inversion : {input_ckpts_path}')     # provided in a certain dir

    # Only the per-instance shape and color codes are optimized; the rest of
    # the model is left out of the optimizer.
    # optimizer = lambda model: torch.optim.Adam(model.parameters(), betas=(0.9, 0.99), eps=1e-15)
    optimizer = lambda model: torch.optim.Adam(list(model.shape_code.parameters())+list(model.color_code.parameters()), betas=(0.9, 0.99), eps=1e-15)

    # decay to 0.1 * init_lr at last iter step
    scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda step: 0.1 ** min(step / opt.iters, 1))

    metrics = [PSNRMeter(), LPIPSMeter(device=device)]

    trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer,
        criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler,
        scheduler_update_every_step=True, metrics=metrics, use_checkpoint=opt.ckpt, eval_interval=num_epochs,checkpoint=input_ckpts_path)

    # Map the object id to its position in the filtered object list.
    if object_id not in objects:
        raise ValueError(f'Object {object_id} is not among the rendered {class_choices[0]} objects of the selected split')
    object_index = objects.index(object_id)

    print(f'Using pose {pose_index} for object {object_index}, object id {object_id}')

    train_loader.dataset.object_index = object_index
    train_loader.dataset.pose_index = pose_index

    print(f'Started inversion!!')

    inverted_ckpt_path = ''
    # os.makedirs('') raises FileNotFoundError, so only create the directory
    # once a real path has been filled in.
    if inverted_ckpt_path:
        os.makedirs(inverted_ckpt_path, exist_ok=True)

    trainer.train(train_loader, valid_loader, num_epochs, inverted_ckpt_path)

    # Run the test loader for the chosen object and save the rendered images
    # under <workspace>/results/<object_id>.
    test_loader.dataset.object_index = object_index

    # construct the save_path where the rendered images are saved
    save_path = osp.join(opt.workspace, f'results/{object_id}')

    trainer.test(test_loader, save_path=save_path, write_video=False)