"""
    This code accompanies the NeurIPS 2022 paper: Neural Surface Reconstruction of Dynamic Scenes with Monocular RGB-D Camera.
    It is built on the framework of the open-source project https://github.com/Totoro97/NeuS,
    which is related to the following paper:
        NeuS: Learning Neural Implicit Surfaces by Volume Rendering for Multi-view Reconstruction. Wang et al. NeurIPS 2021.
"""

import os
import logging
import argparse
import numpy as np
import cv2 as cv
import trimesh
import torch
from shutil import copyfile
from pyhocon import ConfigFactory
from models.dataset import Dataset
from models.fields import RenderingNetwork, SDFNetwork, SingleVarianceNetwork, DeformNetwork, AppearanceNetwork, AmbientNetwork
from models.renderer import NeuSRenderer, DeformNeuSRenderer

class Runner:
    """Build an (optionally deformable) NeuS model from a config and export validation outputs.

    The constructor parses the HOCON config at ``conf_path``, creates the dataset, the networks,
    the renderer and an Adam optimizer over every trainable tensor and, when ``is_continue`` is
    set, restores all of them from ``<base_exp_dir>/checkpoints/test.pth``.  The ``validate_*``
    methods render images / extract meshes and write them below ``base_exp_dir``.
    """

    def __init__(self, conf_path, mode='train', case='CASE_NAME', is_continue=False):
        """Create the runner.

        Args:
            conf_path: path to the HOCON configuration file.
            mode: run-mode string, stored on ``self.mode``.
            case: substituted for every ``CASE_NAME`` placeholder in the config text.
            is_continue: when True, load the ``test.pth`` checkpoint from
                ``<base_exp_dir>/checkpoints``.
        """
        self.device = torch.device('cuda')
        self.gpu = torch.cuda.current_device()
        self.dtype = torch.get_default_dtype()

        # Configuration
        self.conf_path = conf_path
        with open(self.conf_path) as conf_file:
            conf_text = conf_file.read().replace('CASE_NAME', case)

        self.conf = ConfigFactory.parse_string(conf_text)
        self.conf['dataset.data_dir'] = self.conf['dataset.data_dir'].replace('CASE_NAME', case)
        self.base_exp_dir = self.conf['general.base_exp_dir']
        os.makedirs(self.base_exp_dir, exist_ok=True)
        self.dataset = Dataset(self.conf['dataset'])
        self.iter_step = 0

        # Deform: one learnable code per frame for the deformation network and one per frame for
        # the appearance network.  They are created directly on the target device so they stay
        # leaf tensors: `randn(..., requires_grad=True).to(device)` returns a non-leaf copy
        # whenever the tensor was created on another device, and Adam rejects non-leaf tensors.
        self.use_deform = self.conf.get_bool('train.use_deform')
        if self.use_deform:
            self.deform_dim = self.conf.get_int('model.deform_network.d_feature')
            self.deform_codes = torch.randn(self.dataset.n_images, self.deform_dim,
                                            device=self.device, requires_grad=True)
            self.appearance_dim = self.conf.get_int('model.appearance_rendering_network.d_global_feature')
            self.appearance_codes = torch.randn(self.dataset.n_images, self.appearance_dim,
                                                device=self.device, requires_grad=True)

        # Training parameters
        self.end_iter = self.conf.get_int('train.end_iter')
        self.important_begin_iter = self.conf.get_int('model.neus_renderer.important_begin_iter')

        # Anneal (consumed by get_alpha_ratio)
        self.max_pe_iter = self.conf.get_int('train.max_pe_iter')

        self.save_freq = self.conf.get_int('train.save_freq')
        self.report_freq = self.conf.get_int('train.report_freq')
        self.val_freq = self.conf.get_int('train.val_freq')
        self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')
        self.validate_idx = self.conf.get_int('train.validate_idx', default=-1)
        self.batch_size = self.conf.get_int('train.batch_size')
        self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level')
        self.learning_rate = self.conf.get_float('train.learning_rate')
        self.learning_rate_alpha = self.conf.get_float('train.learning_rate_alpha')
        self.warm_up_end = self.conf.get_float('train.warm_up_end', default=0.0)
        self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0)
        self.test_batch_size = self.conf.get_int('test.test_batch_size')

        # Weights
        self.igr_weight = self.conf.get_float('train.igr_weight')
        self.mask_weight = self.conf.get_float('train.mask_weight')
        self.is_continue = is_continue
        self.mode = mode
        self.model_list = []
        self.writer = None

        # Depth
        self.use_depth = self.conf.get_bool('dataset.use_depth')
        if self.use_depth:
            self.geo_weight = self.conf.get_float('train.geo_weight')
            self.angle_weight = self.conf.get_float('train.angle_weight')

        # Networks
        if self.use_deform:
            self.deform_network = DeformNetwork(**self.conf['model.deform_network']).to(self.device)
            self.ambient_network = AmbientNetwork(**self.conf['model.ambient_network']).to(self.device)

        self.sdf_network = SDFNetwork(**self.conf['model.sdf_network']).to(self.device)
        self.deviation_network = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device)

        if self.use_deform:
            self.color_network = AppearanceNetwork(**self.conf['model.appearance_rendering_network']).to(self.device)
            self.renderer = DeformNeuSRenderer(self.report_freq,
                                               self.deform_network,
                                               self.ambient_network,
                                               self.sdf_network,
                                               self.deviation_network,
                                               self.color_network,
                                               **self.conf['model.neus_renderer'])
        else:
            self.color_network = RenderingNetwork(**self.conf['model.rendering_network']).to(self.device)
            self.renderer = NeuSRenderer(self.sdf_network,
                                         self.deviation_network,
                                         self.color_network,
                                         **self.conf['model.neus_renderer'])

        # Optimizer.  Every group carries a 'name' (used by _rebind_optimizer_params); the group
        # order must stay fixed because Optimizer.load_state_dict matches saved state by position.
        lr = self.learning_rate
        params_to_train = []
        if self.use_deform:
            params_to_train += [{'name': 'deform_network', 'params': self.deform_network.parameters(), 'lr': lr}]
            params_to_train += [{'name': 'ambient_network', 'params': self.ambient_network.parameters(), 'lr': lr}]
            params_to_train += [{'name': 'deform_codes', 'params': self.deform_codes, 'lr': lr}]
            params_to_train += [{'name': 'appearance_codes', 'params': self.appearance_codes, 'lr': lr}]
        params_to_train += [{'name': 'sdf_network', 'params': self.sdf_network.parameters(), 'lr': lr}]
        params_to_train += [{'name': 'deviation_network', 'params': self.deviation_network.parameters(), 'lr': lr}]
        params_to_train += [{'name': 'color_network', 'params': self.color_network.parameters(), 'lr': lr}]

        if self.dataset.camera_trainable:
            params_to_train += [{'name': 'intrinsics_paras', 'params': self.dataset.intrinsics_paras, 'lr': lr}]
            params_to_train += [{'name': 'poses_paras', 'params': self.dataset.poses_paras, 'lr': lr}]
            # Depth
            if self.use_depth:
                params_to_train += [{'name': 'depth_intrinsics_paras', 'params': self.dataset.depth_intrinsics_paras, 'lr': lr}]

        self.optimizer = torch.optim.Adam(params_to_train)

        # Load checkpoint: a fixed file name is used rather than searching for the newest one.
        latest_model_name = 'test.pth' if is_continue else None
        if latest_model_name is not None:
            logging.info('Find checkpoint: {}'.format(latest_model_name))
            self.load_checkpoint(latest_model_name)

    def get_cos_anneal_ratio(self):
        """Return ``min(1, iter_step / anneal_end)``, or 1.0 when annealing is disabled (``anneal_end == 0``)."""
        if self.anneal_end == 0.0:
            return 1.0
        else:
            return np.min([1.0, self.iter_step / self.anneal_end])

    def get_alpha_ratio(self):
        """Return the ``alpha_ratio`` passed to the renderer: ``iter_step / max_pe_iter`` clamped to [0, 1].

        ``max_pe_iter == 0`` is treated as "no annealing" and yields 1.0 instead of dividing by zero.
        """
        if self.max_pe_iter == 0:
            return 1.0
        return max(min(self.iter_step / self.max_pe_iter, 1.), 0.)

    def _batch_size_for(self, mode):
        """Ray batch size: the training batch size for ``mode == 'train'``, else the test batch size."""
        return self.batch_size if mode == 'train' else self.test_batch_size

    def _frame_codes(self, idx):
        """Return ``(deform_code, appearance_code)`` of frame ``idx``, each with a leading batch axis of 1."""
        return self.deform_codes[idx][None, ...], self.appearance_codes[idx][None, ...]

    def _rebind_optimizer_params(self):
        """Point the named optimizer groups at the tensor objects currently held by the runner/dataset.

        ``load_checkpoint`` replaces the latent codes and camera tensors with new objects; without
        this, the optimizer (and any state restored into it) would keep referring to the old ones.
        """
        current = {}
        if self.use_deform:
            current['deform_codes'] = self.deform_codes
            current['appearance_codes'] = self.appearance_codes
        if self.dataset.camera_trainable:
            current['intrinsics_paras'] = self.dataset.intrinsics_paras
            current['poses_paras'] = self.dataset.poses_paras
            if self.use_depth:
                current['depth_intrinsics_paras'] = self.dataset.depth_intrinsics_paras
        for group in self.optimizer.param_groups:
            tensor = current.get(group.get('name'))
            if tensor is not None:
                group['params'] = [tensor]

    def load_checkpoint(self, checkpoint_name):
        """Restore networks, latent codes, camera parameters, optimizer state and iteration counter.

        Args:
            checkpoint_name: file name inside ``<base_exp_dir>/checkpoints``.
        """
        checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name), map_location=self.device)
        self.sdf_network.load_state_dict(checkpoint['sdf_network_fine'])
        self.deviation_network.load_state_dict(checkpoint['variance_network_fine'])
        self.color_network.load_state_dict(checkpoint['color_network_fine'])
        # Deform
        if self.use_deform:
            self.deform_network.load_state_dict(checkpoint['deform_network'])
            self.ambient_network.load_state_dict(checkpoint['ambient_network'])
            self.deform_codes = torch.from_numpy(checkpoint['deform_codes']).to(self.device).requires_grad_()
            self.appearance_codes = torch.from_numpy(checkpoint['appearance_codes']).to(self.device).requires_grad_()
            logging.info('Use_deform True')
        self.dataset.intrinsics_paras = torch.from_numpy(checkpoint['intrinsics_paras']).to(self.device)
        self.dataset.poses_paras = torch.from_numpy(checkpoint['poses_paras']).to(self.device)
        # Depth
        if self.use_depth:
            self.dataset.depth_intrinsics_paras = torch.from_numpy(checkpoint['depth_intrinsics_paras']).to(self.device)
        # Camera
        if self.dataset.camera_trainable:
            self.dataset.intrinsics_paras.requires_grad_()
            self.dataset.poses_paras.requires_grad_()
            # Depth
            if self.use_depth:
                self.dataset.depth_intrinsics_paras.requires_grad_()
        else:
            self.dataset.static_paras_to_mat()
        # The tensors above are new objects: rebind the optimizer groups first so the restored
        # Adam state attaches to the tensors that are actually used for rendering.
        self._rebind_optimizer_params()
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.iter_step = checkpoint['iter_step']

        logging.info('End')

    def validate_image(self, idx=-1, resolution_level=-1, mode='train', normal_filename='normals', rgb_filename='rgbs', depth_filename='depths'):
        """Render frame ``idx`` and save it stacked above the dataset image as ``<rgb_filename>/<idx>.png``.

        Args:
            idx: frame index; a negative value picks a random frame.
            resolution_level: ray downsampling level; negative uses ``validate_resolution_level``.
            mode: ``'train'`` uses the training batch size, anything else the test batch size.
            rgb_filename: output sub-directory of ``base_exp_dir``.
            normal_filename, depth_filename: accepted for interface compatibility; not used.
        """
        if idx < 0:
            idx = np.random.randint(self.dataset.n_images)
        print('Validate image: frame: {}'.format(idx))
        if self.use_deform:
            deform_code, appearance_code = self._frame_codes(idx)
        batch_size = self._batch_size_for(mode)

        if resolution_level < 0:
            resolution_level = self.validate_resolution_level
        rays_o, rays_d = self.dataset.gen_rays_at(idx, resolution_level=resolution_level)
        H, W, _ = rays_o.shape
        rays_o = rays_o.reshape(-1, 3).split(batch_size)
        rays_d = rays_d.reshape(-1, 3).split(batch_size)

        out_rgb_fine = []
        for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
            near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch)
            if self.use_deform:
                render_out = self.renderer.render(deform_code,
                                                  appearance_code,
                                                  rays_o_batch,
                                                  rays_d_batch,
                                                  near,
                                                  far,
                                                  cos_anneal_ratio=self.get_cos_anneal_ratio(),
                                                  alpha_ratio=self.get_alpha_ratio(),
                                                  iter_step=self.iter_step)
            else:
                render_out = self.renderer.render(rays_o_batch,
                                                  rays_d_batch,
                                                  near,
                                                  far,
                                                  cos_anneal_ratio=self.get_cos_anneal_ratio())
            color = render_out['color_fine'] if 'color_fine' in render_out else None
            if color is not None:
                out_rgb_fine.append(color.detach().cpu().numpy())
            del render_out

        os.makedirs(os.path.join(self.base_exp_dir, rgb_filename), exist_ok=True)
        if not out_rgb_fine:
            # Nothing to write; previously this crashed on `None.shape`.
            logging.warning('No color output rendered for frame {}'.format(idx))
            return

        # Colors are scaled from [0, 1] to 8-bit range; the trailing axis holds any extra outputs.
        img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255)
        gt_image = self.dataset.image_at(idx, resolution_level=resolution_level)
        for i in range(img_fine.shape[-1]):
            # NOTE(review): every slice i writes the same file name, so only the last one survives.
            cv.imwrite(os.path.join(self.base_exp_dir, rgb_filename, '{}.png'.format(idx)),
                       np.concatenate([img_fine[..., i], gt_image]))

    def validate_image_with_depth(self, idx=-1, resolution_level=-1, mode='train'):
        """Render color and normals at the depth-derived surface points of frame ``idx`` and save them.

        Images go to ``rgbsondepth/`` and ``normalsondepth/`` under ``base_exp_dir``.  Requires
        ``train.use_deform`` because the depth renderer is driven by the per-frame codes.

        Raises:
            RuntimeError: if the runner was built without deformation.
        """
        if not self.use_deform:
            raise RuntimeError('validate_image_with_depth requires train.use_deform = True')
        if idx < 0:
            idx = np.random.randint(self.dataset.n_images)

        deform_code, appearance_code = self._frame_codes(idx)
        print('Validate: iter: {}, camera: {}'.format(self.iter_step, idx))
        batch_size = self._batch_size_for(mode)

        if resolution_level < 0:
            resolution_level = self.validate_resolution_level
        rays_o, rays_d, rays_s, mask = self.dataset.gen_rays_at_depth(idx, resolution_level=resolution_level)
        H, W, _ = rays_o.shape
        rays_o = rays_o.reshape(-1, 3).split(batch_size)
        rays_d = rays_d.reshape(-1, 3).split(batch_size)
        rays_s = rays_s.reshape(-1, 3).split(batch_size)
        mask = (mask > 0.5).to(self.dtype).detach().cpu().numpy()[..., None]

        out_rgb_fine = []
        out_normal_fine = []
        for rays_o_batch, rays_d_batch, rays_s_batch in zip(rays_o, rays_d, rays_s):
            color_batch, gradients_batch = self.renderer.renderondepth(deform_code,
                                                                       appearance_code,
                                                                       rays_o_batch,
                                                                       rays_d_batch,
                                                                       rays_s_batch,
                                                                       alpha_ratio=self.get_alpha_ratio())
            out_rgb_fine.append(color_batch.detach().cpu().numpy())
            out_normal_fine.append(gradients_batch.detach().cpu().numpy())
            del color_batch, gradients_batch

        os.makedirs(os.path.join(self.base_exp_dir, 'rgbsondepth'), exist_ok=True)
        os.makedirs(os.path.join(self.base_exp_dir, 'normalsondepth'), exist_ok=True)
        if not out_rgb_fine:
            logging.warning('No rays generated for frame {}'.format(idx))
            return

        img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255)
        img_fine = img_fine * mask

        # Rotate normals from world frame into the camera frame (inverse of the pose rotation),
        # then map [-1, 1] to [0, 256).
        normal_img = np.concatenate(out_normal_fine, axis=0)
        if self.dataset.camera_trainable:
            _, pose = self.dataset.dynamic_paras_to_mat(idx)
        else:
            pose = self.dataset.poses_all[idx]
        rot = np.linalg.inv(pose[:3, :3].detach().cpu().numpy())
        normal_img = (np.matmul(rot[None, :, :], normal_img[:, :, None])
                      .reshape([H, W, 3, -1]) * 128 + 128).clip(0, 255)
        normal_img = normal_img * mask

        file_name = '{:0>8d}_depth_{}.png'.format(self.iter_step, idx)
        gt_image = self.dataset.image_at(idx, resolution_level=resolution_level)
        for i in range(img_fine.shape[-1]):
            cv.imwrite(os.path.join(self.base_exp_dir, 'rgbsondepth', file_name),
                       np.concatenate([img_fine[..., i], gt_image]))
            cv.imwrite(os.path.join(self.base_exp_dir, 'normalsondepth', file_name),
                       normal_img[..., i])

    def validate_all_image(self, resolution_level=-1):
        """Run ``validate_image`` in test mode for every frame of the dataset."""
        for image_idx in range(self.dataset.n_images):
            self.validate_image(image_idx, resolution_level, 'test', 'validations_normals', 'validations_rgbs', 'validations_depths')
            print('device:', self.gpu)

    def _object_bbox(self):
        """Return the dataset object bounding box as ``(bound_min, bound_max)`` tensors of the default dtype."""
        bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=self.dtype)
        bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=self.dtype)
        return bound_min, bound_max

    def _export_mesh(self, vertices, triangles, world_space, dir_name, file_name):
        """Write a mesh to ``<base_exp_dir>/<dir_name>/<file_name>``.

        With ``world_space`` the vertices are mapped back using the first scale matrix
        (uniform scale ``[0, 0]`` plus translation ``[:3, 3]``).
        """
        out_dir = os.path.join(self.base_exp_dir, dir_name)
        os.makedirs(out_dir, exist_ok=True)
        if world_space:
            scale_mat = self.dataset.scale_mats_np[0]
            vertices = vertices * scale_mat[0, 0] + scale_mat[:3, 3][None]
        trimesh.Trimesh(vertices, triangles).export(os.path.join(out_dir, file_name))

    def validate_mesh(self, world_space=False, resolution=64, threshold=0.0):
        """Extract the (canonical) surface mesh at ``threshold`` and save it as ``meshes/<iter_step>.ply``."""
        bound_min, bound_max = self._object_bbox()
        vertices, triangles = \
            self.renderer.extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold)
        self._export_mesh(vertices, triangles, world_space, 'meshes', '{:0>8d}.ply'.format(self.iter_step))

        logging.info('End')

    # Deform
    def validate_observation_mesh(self, idx=-1, world_space=False, resolution=64, threshold=0.0, filename='meshes'):
        """Extract the mesh of frame ``idx`` (deformed by its code) and save it as ``<filename>/<idx>.ply``.

        Requires ``train.use_deform``; a negative ``idx`` picks a random frame.
        """
        if idx < 0:
            idx = np.random.randint(self.dataset.n_images)
        print('Validate mesh: frame: {}'.format(idx))
        deform_code = self.deform_codes[idx][None, ...]

        bound_min, bound_max = self._object_bbox()
        vertices, triangles = \
            self.renderer.extract_observation_geometry(deform_code, bound_min, bound_max, resolution=resolution,
                                                       threshold=threshold, alpha_ratio=self.get_alpha_ratio())
        self._export_mesh(vertices, triangles, world_space, filename, '{}.ply'.format(idx))

        logging.info('End')

    # Deform
    def validate_all_mesh(self, world_space=False, resolution=64, threshold=0.0):
        """Run ``validate_observation_mesh`` for every frame, writing into ``validations_meshes/``."""
        for image_idx in range(self.dataset.n_images):
            self.validate_observation_mesh(image_idx, world_space, resolution, threshold, 'validations_meshes')
            print('device:', self.gpu)



if __name__ == '__main__':
    # Script entry point: only the 'validate_mesh' mode is implemented here.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
    logging.basicConfig(level=logging.DEBUG, format=FORMAT)

    parser = argparse.ArgumentParser()
    parser.add_argument('--conf', type=str, default='./confs/base.conf')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument('--mcube_threshold', type=float, default=0.0)
    parser.add_argument('--is_continue', default=False, action="store_true")
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--case', type=str, default='')

    args = parser.parse_args()

    torch.set_default_dtype(torch.float32)
    torch.cuda.set_device(args.gpu)

    runner = Runner(args.conf, args.mode, args.case, args.is_continue)

    if args.mode == 'validate_mesh':
        # Deformable models get one mesh per frame; static models a single canonical mesh.
        if runner.use_deform:
            runner.validate_all_mesh(world_space=False, resolution=512, threshold=args.mcube_threshold)
        else:
            runner.validate_mesh(world_space=False, resolution=512, threshold=args.mcube_threshold)
        runner.validate_all_image(resolution_level=1)
    else:
        # Previously any other mode (including the default 'train') silently did nothing.
        logging.warning("Mode '%s' is not handled by this script; only 'validate_mesh' is supported.", args.mode)