import os
import time
import functools
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_scatter import segment_coo

from . import grid
from .dvgo import Raw2Alpha, Alphas2Weights
from .dmpigo import create_full_step_id

from torch.utils.cpp_extension import load
# Directory containing this module; the extension sources live in its `cuda/` subfolder.
parent_dir = os.path.dirname(os.path.abspath(__file__))

# Build (or reuse a previous build of) the ub360 helper extension at import time.
# The methods below use its `cumdist_thres` and `segment_cumsum` entry points.
ub360_utils_cuda = load(
        name='ub360_utils_cuda',
        sources=[
            os.path.join(parent_dir, 'cuda/ub360_utils.cpp'),
            os.path.join(parent_dir, 'cuda/ub360_utils_kernel.cu'),
        ],
        verbose=True)


#TODO ORIGINAL bg_len=0.2
'''Model'''
class DirectContractedVoxGO(nn.Module):
    def __init__(self, xyz_min, xyz_max,
                 num_voxels=0, num_voxels_base=0, num_objects = 1,
                 alpha_init=None,
                 mask_cache_world_size=None,
                 fast_color_thres=0, bg_len=0.2,
                 contracted_norm='inf',
                 density_type='DenseGrid', k0_type='DenseGrid', f_k0_type='DenseGrid',
                 density_config={}, k0_config={}, f_k0_config={},
                 rgbnet_dim=0,
                 rgbnet_depth=3, rgbnet_width=128,
                 viewbase_pe=4,
                 **kwargs):
        '''Build the density, colour, segmentation and occupancy grids.

        Input:
            xyz_min, xyz_max:  3-vectors; the boundary that separates the
                               foreground and background scene. The box must
                               be a cube (checked up to float tolerance).
            num_voxels:        target voxel count of the initial grids.
            num_voxels_base:   voxel count defining the reference voxel size.
            num_objects:       channel count of the two segmentation grids.
            alpha_init:        initial alpha used to derive the density bias
                               shift; must be a number (None is not handled).
            mask_cache_world_size: shape of the occupancy mask; defaults to
                               the grid resolution.
            fast_color_thres:  a number, or a dict keyed by global step whose
                               entry 0 is the initial threshold.
            bg_len:            margin added around [-1, 1]^3 for the
                               contracted background.
            contracted_norm:   'inf' or 'l2' (see sample_ray).
            rgbnet_dim:        <= 0 selects a plain 3-channel colour grid;
                               otherwise a feature grid plus a small MLP.
            **kwargs:          must contain 'f_num_voxels' and
                               'f_num_voxels_base'; other keys are ignored.
        Raises:
            AssertionError if the bounding box is not a cube.
        '''
        super(DirectContractedVoxGO, self).__init__()
        # xyz_min/max are the boundary that separates fg and bg scene
        xyz_min = torch.Tensor(xyz_min)
        xyz_max = torch.Tensor(xyz_max)
        # All three side lengths must agree. The comparison is tolerant so that
        # float rounding in the caller's bbox arithmetic does not trip it.
        extent = xyz_max - xyz_min
        assert torch.allclose(extent, extent[:1].expand_as(extent), rtol=1e-4, atol=1e-5), \
            'scene bbox must be a cube in DirectContractedVoxGO'
        self.register_buffer('scene_center', (xyz_min + xyz_max) * 0.5)
        self.register_buffer('scene_radius', (xyz_max - xyz_min) * 0.5)
        # From here on xyz_min/xyz_max describe the normalized, contracted volume.
        self.register_buffer('xyz_min', torch.Tensor([-1,-1,-1]) - bg_len)
        self.register_buffer('xyz_max', torch.Tensor([1,1,1]) + bg_len)
        if isinstance(fast_color_thres, dict):
            self._fast_color_thres = fast_color_thres
            self.fast_color_thres = fast_color_thres[0]
        else:
            self._fast_color_thres = None
            self.fast_color_thres = fast_color_thres
        self.bg_len = bg_len
        self.contracted_norm = contracted_norm

        # determine based grid resolution
        self.num_voxels_base = num_voxels_base
        self.voxel_size_base = ((self.xyz_max - self.xyz_min).prod() / self.num_voxels_base).pow(1/3)
        self.f_num_voxels_base = kwargs['f_num_voxels_base']
        self.f_voxel_size_base = ((self.xyz_max - self.xyz_min).prod() / self.f_num_voxels_base).pow(1/3)

        # determine init grid resolution
        self._set_grid_resolution(num_voxels, kwargs['f_num_voxels'])

        # determine the density bias shift
        self.alpha_init = alpha_init
        self.register_buffer('act_shift', torch.FloatTensor([np.log(1/(1-alpha_init) - 1)]))
        print('dcvgo: set density bias shift to', self.act_shift)

        # init density voxel grid
        self.density_type = density_type
        self.density_config = density_config
        self.density = grid.create_grid(
            density_type, channels=1, world_size=self.world_size,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max,
            config=self.density_config)

        # segmentation grids share the density grid's type and config
        self.mode = 'coarse'
        self.num_objects = num_objects
        self.seg_mask_grid = grid.create_grid(
                density_type, channels=self.num_objects, world_size=self.world_size,
                xyz_min=self.xyz_min, xyz_max=self.xyz_max,
                config=self.density_config)

        # second segmentation grid, only queried in 'fine' mode
        self.dual_seg_mask_grid = grid.create_grid(
                density_type, channels=self.num_objects, world_size=self.world_size,
                xyz_min=self.xyz_min, xyz_max=self.xyz_max,
                config=self.density_config)

        # segmentation mask
        self.segmentation_mask = torch.nn.Parameter(torch.ones([1, 1, *self.world_size]), requires_grad=False)

        # init color representation
        self.rgbnet_kwargs = {
            'rgbnet_dim': rgbnet_dim,
            'rgbnet_depth': rgbnet_depth, 'rgbnet_width': rgbnet_width,
            'viewbase_pe': viewbase_pe,
        }
        self.k0_type = k0_type
        self.k0_config = k0_config

        # NOTE: the separate feature grid (f_k0) is currently disabled, so
        # f_k0_type / f_k0_config are accepted but unused; only the f_*
        # resolution bookkeeping above is kept.

        if rgbnet_dim <= 0:
            # color voxel grid  (coarse stage)
            self.k0_dim = 3
            self.k0 = grid.create_grid(
                k0_type, channels=self.k0_dim, world_size=self.world_size,
                xyz_min=self.xyz_min, xyz_max=self.xyz_max,
                config=self.k0_config)
            self.rgbnet = None
        else:
            # feature voxel grid + shallow MLP  (fine stage)
            self.k0_dim = rgbnet_dim
            self.k0 = grid.create_grid(
                k0_type, channels=self.k0_dim, world_size=self.world_size,
                xyz_min=self.xyz_min, xyz_max=self.xyz_max,
                config=self.k0_config)
            self.register_buffer('viewfreq', torch.FloatTensor([(2**i) for i in range(viewbase_pe)]))
            # MLP input: raw view direction + sin/cos embedding + grid features
            dim0 = (3+3*viewbase_pe*2)
            dim0 += self.k0_dim
            self.rgbnet = nn.Sequential(
                nn.Linear(dim0, rgbnet_width), nn.ReLU(inplace=True),
                *[
                    nn.Sequential(nn.Linear(rgbnet_width, rgbnet_width), nn.ReLU(inplace=True))
                    for _ in range(rgbnet_depth-2)
                ],
                nn.Linear(rgbnet_width, 3),
            )
            nn.init.constant_(self.rgbnet[-1].bias, 0)
            print('dcvgo: feature voxel grid', self.k0)
            print('dcvgo: mlp', self.rgbnet)

        # Using the coarse geometry if provided (used to determine known free space and unknown space)
        # Re-implement as occupancy grid (2021/1/31)
        if mask_cache_world_size is None:
            mask_cache_world_size = self.world_size
        # start with every cell marked as potentially occupied
        mask = torch.ones(list(mask_cache_world_size), dtype=torch.bool)
        self.mask_cache = grid.MaskGrid(
            path=None, mask=mask,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max)

    def _set_grid_resolution(self, num_voxels, f_num_voxels):
        # Determine grid resolution
        self.num_voxels = num_voxels
        self.f_num_voxels = f_num_voxels

        self.voxel_size = ((self.xyz_max - self.xyz_min).prod() / num_voxels).pow(1/3)
        self.f_voxel_size = ((self.xyz_max - self.xyz_min).prod() / f_num_voxels).pow(1/3)
        self.world_size = ((self.xyz_max - self.xyz_min) / self.voxel_size).long()
        self.f_world_size = ((self.xyz_max - self.xyz_min) / self.f_voxel_size).long()

        self.world_len = self.world_size[0].item()
        self.f_world_len = self.f_world_size[0].item()

        self.voxel_size_ratio = self.voxel_size / self.voxel_size_base
        self.f_voxel_size_ratio = self.f_voxel_size / self.f_voxel_size_base

        print('dcvgo: voxel_size      ', self.voxel_size)
        print('dcvgo: world_size      ', self.world_size)
        print('dcvgo: voxel_size_base ', self.voxel_size_base)
        print('dcvgo: voxel_size_ratio', self.voxel_size_ratio)

    def get_kwargs(self):
        return {
            'xyz_min': self.xyz_min.cpu().numpy(),
            'xyz_max': self.xyz_max.cpu().numpy(),
            'num_voxels': self.num_voxels,
            'num_voxels_base': self.num_voxels_base,
            'f_num_voxels': self.f_num_voxels,
            'f_num_voxels_base': self.f_num_voxels_base,
            'alpha_init': self.alpha_init,
            'voxel_size_ratio': self.voxel_size_ratio,
            'f_voxel_size_ratio': self.voxel_size_ratio,
            'mask_cache_world_size': list(self.mask_cache.mask.shape),
            'fast_color_thres': self.fast_color_thres,
            'contracted_norm': self.contracted_norm,
            'density_type': self.density_type,
            'k0_type': self.k0_type,
#             'f_k0_type': self.f_k0_type,
            'density_config': self.density_config,
            'k0_config': self.k0_config,
#             'f_k0_config': self.f_k0_config,
            **self.rgbnet_kwargs,
        }
    
    @torch.no_grad()
    def change_num_objects(self, num_obj):
        '''Replace both segmentation grids with fresh `num_obj`-channel grids.

        The new grids are always of type 'DenseGrid', use the current grid
        resolution and the density config, and are moved to the device the
        previous segmentation grid lived on. Previous grid contents are
        discarded.
        '''
        self.num_objects = num_obj
        target_device = self.seg_mask_grid.grid.device

        def fresh_grid():
            # one empty segmentation grid at the current resolution
            return grid.create_grid(
                'DenseGrid', channels=self.num_objects, world_size=self.world_size,
                xyz_min=self.xyz_min, xyz_max=self.xyz_max,
                config=self.density_config)

        self.seg_mask_grid = fresh_grid()
        self.dual_seg_mask_grid = fresh_grid()
        for seg_grid in (self.seg_mask_grid, self.dual_seg_mask_grid):
            seg_grid.to(target_device)
        print("Reset the seg_mask_grid with num_objects =", num_obj)
        
    @torch.no_grad()
    def segmentation_to_density(self):
        device = self.seg_mask_grid.grid.device
        assert self.seg_mask_grid.grid.shape[1] == 1 and "multi-object seg label cannot be applied directly to the density grid"
        thresh = self.seg_mask_grid.grid.mean()
        mask_grid = torch.zeros_like(self.seg_mask_grid.grid)
        mask_grid[self.seg_mask_grid.grid > max(thresh + 2*self.seg_mask_grid.grid.std(), 0)] = 1
        # mask_grid[self.seg_mask_grid.grid > 0] = 1
        
        self.density.grid = nn.Parameter(self.density.grid.to(device) * mask_grid.to(device))
        self.density.grid[self.density.grid == 0] = -1e7

        del mask_grid

    @torch.no_grad()
    def segmentation_only(self):
        device = self.seg_mask_grid.grid.device
        assert self.seg_mask_grid.grid.shape[1] == 1 and "multi-object seg label cannot be applied directly to the density grid"
        thresh = self.seg_mask_grid.grid.mean()
        invalid = (self.seg_mask_grid.grid <= max(thresh + 2*self.seg_mask_grid.grid.std(), 0)) & (self.seg_mask_grid.grid > 0)
        self.seg_mask_grid.grid[invalid] = 0
        
    @torch.no_grad()
    def remove_target_object(self):
        device = self.seg_mask_grid.grid.device
        assert self.seg_mask_grid.grid.shape[1] == 1 and "multi-object seg label cannot be applied directly to the density grid"
        thresh = self.seg_mask_grid.grid.mean()
        mask_grid = torch.ones_like(self.seg_mask_grid.grid)
        mask_grid[self.seg_mask_grid.grid > max(thresh - 2*self.seg_mask_grid.grid.std(), 0)] = 0
        
#         self.density.grid = nn.Parameter(self.density.grid.to(device) * mask_grid.to(device))
        tmp_density_grid = self.density.grid.to(device)
        tmp_density_grid[mask_grid == 0] = -1e9
        self.density.grid = nn.Parameter(tmp_density_grid)

        del mask_grid
        
    @torch.no_grad()
    def multi_object_segmentation_to_density(self):
        device = self.seg_mask_grid.grid.device

        tmp_grid = self.seg_mask_grid.grid.view(1, self.seg_mask_grid.grid.shape[1], -1)
        thresh = tmp_grid.mean(dim = -1) + tmp_grid.std(dim = -1)
        thresh = thresh[:,:,None,None,None]
        mask_grid = torch.zeros_like(self.seg_mask_grid.grid)
        mask_grid[torch.logical_and(self.seg_mask_grid.grid > thresh, self.seg_mask_grid.grid > 0)] = 1
        
        self.seg_mask_grid.grid = nn.Parameter(mask_grid.to(device))

        del mask_grid
        
    @torch.no_grad()
    def change_to_fine_mode(self):

#         thresh = self.seg_mask_grid.grid.mean()
#         self.seg_mask_grid.grid = nn.Parameter(self.seg_mask_grid.grid - thresh)
#         self.seg_mask_grid.grid.to(device)
        
#         print(device, "?????????????????", self.seg_mask_grid.grid.device)
        self.mode = 'fine'

    @torch.no_grad()
    def scale_volume_grid(self, num_voxels, f_num_voxels):
        '''Resize every voxel grid to a new target voxel count.

        Recomputes the grid resolution, rescales the density, both
        segmentation grids and the colour grid, and resets
        `segmentation_mask` to all ones at the new shape. When the new grid
        has at most 256^3 cells, the occupancy mask is also rebuilt at the
        new resolution: a cell stays occupied only if the old mask marks it
        occupied and its 3x3x3 max-pooled alpha exceeds `fast_color_thres`.
        '''
        print('dcvgo: scale_volume_grid start')
        prev_world_size = self.world_size
        prev_f_world_size = self.f_world_size
        self._set_grid_resolution(num_voxels, f_num_voxels)
        print('dcvgo: scale_volume_grid scale world_size from', prev_world_size.tolist(), 'to', self.world_size.tolist())
        print('dvgo: scale_volume_grid scale f_world_size from', prev_f_world_size.tolist(), 'to', self.f_world_size.tolist())

        for voxel_grid in (self.density, self.seg_mask_grid, self.dual_seg_mask_grid, self.k0):
            voxel_grid.scale_volume_grid(self.world_size)
        self.segmentation_mask = torch.nn.Parameter(torch.ones([1, 1, *self.world_size]), requires_grad=False)

        n_cells = np.prod(self.world_size.tolist())
        if n_cells <= 256**3:
            # coordinates of every cell of the new grid
            axes = [
                torch.linspace(self.xyz_min[axis], self.xyz_max[axis], self.world_size[axis])
                for axis in range(3)
            ]
            self_grid_xyz = torch.stack(torch.meshgrid(*axes), -1)
            dense_alpha = self.activate_density(self.density.get_dense_grid())
            self_alpha = F.max_pool3d(dense_alpha, kernel_size=3, padding=1, stride=1)[0,0]
            still_occupied = self.mask_cache(self_grid_xyz) & (self_alpha>self.fast_color_thres)
            self.mask_cache = grid.MaskGrid(
                path=None, mask=still_occupied,
                xyz_min=self.xyz_min, xyz_max=self.xyz_max)

        print('dcvgo: scale_volume_grid finish')

    @torch.no_grad()
    def update_occupancy_cache(self):
        '''Shrink the occupancy mask using the current density grid.

        Density is sampled at every cell of the mask, converted to alpha and
        3x3x3 max-pooled; cells whose pooled alpha does not exceed
        `fast_color_thres` are cleared. The mask is only ever AND-ed, so
        cleared cells stay cleared. Prints the occupied fraction before and
        after.
        '''
        ori_p = self.mask_cache.mask.float().mean().item()

        mask_shape = self.mask_cache.mask.shape
        axes = [
            torch.linspace(self.xyz_min[axis], self.xyz_max[axis], mask_shape[axis])
            for axis in range(3)
        ]
        cache_grid_xyz = torch.stack(torch.meshgrid(*axes), -1)

        # add batch and channel axes so the result can be pooled in 3D
        raw_density = self.density(cache_grid_xyz)[None,None]
        pooled_alpha = F.max_pool3d(
            self.activate_density(raw_density), kernel_size=3, padding=1, stride=1)[0,0]
        self.mask_cache.mask &= (pooled_alpha > self.fast_color_thres)

        new_p = self.mask_cache.mask.float().mean().item()
        print(f'dcvgo: update mask_cache {ori_p:.4f} => {new_p:.4f}')

    def update_occupancy_cache_lt_nviews(self, rays_o_tr, rays_d_tr, imsz, render_kwargs, maskout_lt_nviews):
        '''Clear occupancy cells that are covered by fewer than `maskout_lt_nviews` views.

        Rays are split into per-view groups by `imsz` and processed in chunks
        of 8192. For each view a throw-away single-channel grid is queried at
        the sampled points and back-propagated; cells whose accumulated
        gradient exceeds 1 count as covered by that view. The occupancy mask
        is then AND-ed with "count >= maskout_lt_nviews".
        '''
        print('dcvgo: update mask_cache lt_nviews start')
        eps_time = time.time()
        count = torch.zeros_like(self.density.get_dense_grid()).long()
        device = count.device

        per_view_rays = zip(rays_o_tr.split(imsz), rays_d_tr.split(imsz))
        for view_rays_o, view_rays_d in per_view_rays:
            # fresh probe grid per view so its gradient only reflects this view
            ones = grid.DenseGrid(1, self.world_size, self.xyz_min, self.xyz_max)
            chunks = zip(view_rays_o.split(8192), view_rays_d.split(8192))
            for rays_o, rays_d in chunks:
                ray_pts, _, _ = self.sample_ray(
                        ori_rays_o=rays_o.to(device), ori_rays_d=rays_d.to(device),
                        **render_kwargs)
                ones(ray_pts).sum().backward()
            count.data += (ones.grid.grad > 1)

        ori_p = self.mask_cache.mask.float().mean().item()
        self.mask_cache.mask &= (count >= maskout_lt_nviews)[0,0]
        new_p = self.mask_cache.mask.float().mean().item()
        print(f'dcvgo: update mask_cache {ori_p:.4f} => {new_p:.4f}')
        eps_time = time.time() - eps_time
        print(f'dcvgo: update mask_cache lt_nviews finish (eps time:', eps_time, 'sec)')

    def density_total_variation_add_grad(self, weight, dense_mode):
        w = weight * self.world_size.max() / 128
        self.density.total_variation_add_grad(w, w, w, dense_mode)

    def k0_total_variation_add_grad(self, weight, dense_mode):
        w = weight * self.world_size.max() / 128
        self.k0.total_variation_add_grad(w, w, w, dense_mode)

#     def f_k0_total_variation_add_grad(self, weight, dense_mode):
#         w = weight * self.f_world_size.max() / 128
#         self.f_k0.total_variation_add_grad(w, w, w, dense_mode)

    def activate_density(self, density, interval=None):
        '''Convert raw density to alpha, keeping the input's shape.

        `interval` defaults to the model's voxel_size_ratio. The input is
        flattened for Raw2Alpha and the result reshaped back.
        '''
        if interval is None:
            interval = self.voxel_size_ratio
        flat_alpha = Raw2Alpha.apply(density.flatten(), self.act_shift, interval)
        return flat_alpha.reshape(density.shape)

    def sample_ray(self, ori_rays_o, ori_rays_d, stepsize, is_train=False, **render_kwargs):
        '''Sample query points on rays.
        All the output points are sorted from near to far.
        Input:
            rays_o, rayd_d:   both in [N, 3] indicating ray configurations.
            stepsize:         the number of voxels of each sample step.
        Output:
            ray_pts:          [M, 3] storing all the sampled points.
            ray_id:           [M]    the index of the ray of each point.
            step_id:          [M]    the i'th step on a ray of each point.
        '''
        rays_o = (ori_rays_o - self.scene_center) / self.scene_radius
        rays_d = ori_rays_d / ori_rays_d.norm(dim=-1, keepdim=True)
        N_inner = int(2 / (2+2*self.bg_len) * self.world_len / stepsize) + 1
        N_outer = N_inner
        b_inner = torch.linspace(0, 2, N_inner+1)
        b_outer = 2 / torch.linspace(1, 1/128, N_outer+1)
        t = torch.cat([
            (b_inner[1:] + b_inner[:-1]) * 0.5,
            (b_outer[1:] + b_outer[:-1]) * 0.5,
        ])
        ray_pts = rays_o[:,None,:] + rays_d[:,None,:] * t[None,:,None]
        if self.contracted_norm == 'inf':
            norm = ray_pts.abs().amax(dim=-1, keepdim=True)
        elif self.contracted_norm == 'l2':
            norm = ray_pts.norm(dim=-1, keepdim=True)
        else:
            raise NotImplementedError
        inner_mask = (norm<=1)
        ray_pts = torch.where(
            inner_mask,
            ray_pts,
            ray_pts / norm * ((1+self.bg_len) - self.bg_len/norm)
        )
        return ray_pts, inner_mask.squeeze(-1), t

    @torch.no_grad()
    def forward(self, rays_o, rays_d, viewdirs, global_step=None, is_train=False, distill_active=False, render_fct=0.0, **render_kwargs):
        '''Volume rendering
        @rays_o:   [N, 3] the starting point of the N shooting rays.
        @rays_d:   [N, 3] the shooting direction of the N rays.
        @viewdirs: [N, 3] viewing direction to compute positional embedding for MLP.
        @global_step: optional training step; used to look up a scheduled
                   fast_color_thres and passed to sample_ray as is_train.
        @is_train: enables the random background when render_kwargs['rand_bkgd'] is set.
        @distill_active: accepted for compatibility; the feature branch is
                   disabled and 'f_marched' is always None.
        @render_fct: lower bound for the alpha / weight pruning threshold.
        @render_kwargs: must contain 'stepsize' and (unless a random
                   background is used) 'bg'; 'render_depth' adds 'depth'
                   and 'distance' to the result.

        Density and colour are evaluated without gradients. Gradients are
        only recorded for the segmentation grids, and only when
        seg_mask_grid.grid requires grad.

        Returns a dict with the marched colour and segmentation masks plus
        the per-sample quantities that survived pruning.
        '''
        assert len(rays_o.shape)==2 and rays_o.shape[-1]==3, 'Only suuport point queries in [N, 3] format'
        if isinstance(self._fast_color_thres, dict) and global_step in self._fast_color_thres:
            print(f'dcvgo: update fast_color_thres {self.fast_color_thres} => {self._fast_color_thres[global_step]}')
            self.fast_color_thres = self._fast_color_thres[global_step]

        ret_dict = {}
        N = len(rays_o)

        # sample points on rays
        ray_pts, inner_mask, t = self.sample_ray(
                ori_rays_o=rays_o, ori_rays_d=rays_d, is_train=global_step is not None, **render_kwargs)
        n_max = len(t)
        interval = render_kwargs['stepsize'] * self.voxel_size_ratio
        ray_id, step_id = create_full_step_id(ray_pts.shape[:2])

        # cumsum ray_pts to get distance from ray_o to any ray_pt in a ray
        ray_distance = torch.zeros_like(ray_pts)
        ray_distance[:, 1:] = torch.abs(ray_pts[:, 1:] - ray_pts[:, :-1])
        ray_distance = torch.cumsum(ray_distance, dim=1)

        # skip oversampled points outside scene bbox
        mask = inner_mask.clone()
        dist_thres = (2+2*self.bg_len) / self.world_len * render_kwargs['stepsize'] * 0.95
        dist = (ray_pts[:,1:] - ray_pts[:,:-1]).norm(dim=-1)
        mask[:, 1:] |= ub360_utils_cuda.cumdist_thres(dist, dist_thres)
        ray_pts = ray_pts[mask]
        ray_distance = ray_distance[mask]
        inner_mask = inner_mask[mask]
        t = t[None].repeat(N,1)[mask]
        ray_id = ray_id[mask.flatten()]
        step_id = step_id[mask.flatten()]

        # skip known free space
        mask = self.mask_cache(ray_pts)
        ray_pts = ray_pts[mask]
        ray_distance = ray_distance[mask]
        inner_mask = inner_mask[mask]
        t = t[mask]
        ray_id = ray_id[mask]
        step_id = step_id[mask]

        render_fct = max(render_fct, self.fast_color_thres)

        # query for alpha w/ post-activation
        density = self.density(ray_pts)
        alpha = self.activate_density(density, interval)
        if render_fct > 0:
            mask = (alpha > render_fct)
            ray_pts = ray_pts[mask]
            ray_distance = ray_distance[mask]
            inner_mask = inner_mask[mask]
            t = t[mask]
            ray_id = ray_id[mask]
            step_id = step_id[mask]
            density = density[mask]
            alpha = alpha[mask]

        # compute accumulated transmittance
        weights, alphainv_last = Alphas2Weights.apply(alpha, ray_id, N)
        if render_fct > 0:
            mask = (weights > render_fct)
            ray_pts = ray_pts[mask]
            ray_distance = ray_distance[mask]
            inner_mask = inner_mask[mask]
            t = t[mask]
            ray_id = ray_id[mask]
            step_id = step_id[mask]
            density = density[mask]
            alpha = alpha[mask]
            weights = weights[mask]

        # query for segmentation mask
        # only optimize the mask volume: gradients are recorded for these
        # queries only when the segmentation grid is trainable
        optimize_mask = self.seg_mask_grid.grid.requires_grad
        fine_mode = (self.mode == 'fine')
        dual_mask_pred = None
        with torch.set_grad_enabled(optimize_mask):
            mask_pred = self.seg_mask_grid(ray_pts)
            if fine_mode:
                dual_mask_pred = self.dual_seg_mask_grid(ray_pts)

        # query for color
        k0 = self.k0(ray_pts)
        if self.rgbnet is None:
            # no view-depend effect
            rgb = torch.sigmoid(k0)
        else:
            # view-dependent color emission
            viewdirs_emb = (viewdirs.unsqueeze(-1) * self.viewfreq).flatten(-2)
            viewdirs_emb = torch.cat([viewdirs, viewdirs_emb.sin(), viewdirs_emb.cos()], -1)
            viewdirs_emb = viewdirs_emb.flatten(0,-2)[ray_id]
            rgb_feat = torch.cat([k0, viewdirs_emb], -1)
            rgb_logit = self.rgbnet(rgb_feat)
            rgb = torch.sigmoid(rgb_logit)

        # Ray marching
        rgb_marched = segment_coo(
                src=(weights.unsqueeze(-1) * rgb),
                index=ray_id,
                out=torch.zeros([N, 3]),
                reduce='sum')

        # March the segmentation predictions. The dual grid goes through the
        # same shape handling as the primary grid, so multi-object models
        # also work in fine mode.
        dual_seg_mask_marched = None
        with torch.set_grad_enabled(optimize_mask):
            seg_mask_marched = self._accumulate_seg_mask(weights, mask_pred, ray_id, N)
            if fine_mode:
                dual_seg_mask_marched = self._accumulate_seg_mask(weights, dual_mask_pred, ray_id, N)

        if render_kwargs.get('rand_bkgd', False) and is_train:
            rgb_marched += (alphainv_last.unsqueeze(-1) * torch.rand_like(rgb_marched))
        else:
            rgb_marched += (alphainv_last.unsqueeze(-1) * render_kwargs['bg'])
        # total weight that falls on the non-contracted (inner) samples
        wsum_mid = segment_coo(
                src=weights[inner_mask],
                index=ray_id[inner_mask],
                out=torch.zeros([N]),
                reduce='sum')

        s = 1 - 1/(1+t)  # [0, inf] => [0, 1]
        ray_distance = ray_distance.norm(dim=-1)
        ret_dict.update({
            'alphainv_last': alphainv_last,
            'weights': weights,
            'wsum_mid': wsum_mid,
            'rgb_marched': rgb_marched,
            'raw_density': density,
            'raw_alpha': alpha,
            'raw_rgb': rgb,
            'ray_id': ray_id,
            'step_id': step_id,
            'n_max': n_max,
            't': t,
            's': s,
            'f_marched': None,
            'seg_mask_marched': seg_mask_marched,
            'dual_seg_mask_marched': dual_seg_mask_marched,
            'ray_distance': ray_distance
        })

        if render_kwargs.get('render_depth', False):
            with torch.no_grad():
                depth = segment_coo(
                        src=(weights * s),
                        index=ray_id,
                        out=torch.zeros([N]),
                        reduce='sum')
                distance = segment_coo(
                        src=(weights * ray_distance),
                        index=ray_id,
                        out=torch.zeros([N]),
                        reduce='sum')
            ret_dict.update({'depth': depth})
            ret_dict.update({'distance': distance})

        return ret_dict

    def _accumulate_seg_mask(self, weights, mask_values, ray_id, n_rays):
        '''Weight per-sample mask predictions and sum them per ray into [n_rays, num_objects].'''
        if self.num_objects == 1:
            # single-object predictions are handled as one value per sample,
            # so add the channel axis here
            mask_values = mask_values.unsqueeze(-1)
        # the rendering weights are treated as constants for the mask gradient
        return segment_coo(
                src=(weights.detach().unsqueeze(-1) * mask_values),
                index=ray_id,
                out=torch.zeros([n_rays, self.num_objects]),
                reduce='sum')
    
    @torch.no_grad()
    def forward_mask(self, rays_o, rays_d, render_fct=0.0,**render_kwargs):
        '''Render only the segmentation mask
        @rays_o:   [N, 3] the starting point of the N shooting rays.
        @rays_d:   [N, 3] the shooting direction of the N rays.
        @render_fct: lower bound for the alpha / weight pruning threshold.
        @render_kwargs: must contain 'stepsize'.

        Follows the same sampling and pruning steps as forward() but skips
        colour. Gradients are recorded for the segmentation grid only when
        seg_mask_grid.grid requires grad.

        Returns {'seg_mask_marched': [N, num_objects] tensor}.
        '''
        assert len(rays_o.shape)==2 and rays_o.shape[-1]==3, 'Only suuport point queries in [N, 3] format'

        N = len(rays_o)

        # sample points on rays
        ray_pts, inner_mask, _ = self.sample_ray(
                ori_rays_o=rays_o, ori_rays_d=rays_d, is_train=False, **render_kwargs)
        interval = render_kwargs['stepsize'] * self.voxel_size_ratio
        ray_id, _ = create_full_step_id(ray_pts.shape[:2])

        # skip oversampled points outside scene bbox
        mask = inner_mask.clone()
        dist_thres = (2+2*self.bg_len) / self.world_len * render_kwargs['stepsize'] * 0.95
        dist = (ray_pts[:,1:] - ray_pts[:,:-1]).norm(dim=-1)
        mask[:, 1:] |= ub360_utils_cuda.cumdist_thres(dist, dist_thres)
        ray_pts = ray_pts[mask]
        ray_id = ray_id[mask.flatten()]

        # skip known free space
        mask = self.mask_cache(ray_pts)
        ray_pts = ray_pts[mask]
        ray_id = ray_id[mask]

        render_fct = max(render_fct, self.fast_color_thres)

        # query for alpha w/ post-activation
        alpha = self.activate_density(self.density(ray_pts), interval)
        if render_fct > 0:
            mask = (alpha > render_fct)
            ray_pts = ray_pts[mask]
            ray_id = ray_id[mask]
            alpha = alpha[mask]

        # compute accumulated transmittance
        weights, _ = Alphas2Weights.apply(alpha, ray_id, N)
        if render_fct > 0:
            mask = (weights > render_fct)
            ray_pts = ray_pts[mask]
            ray_id = ray_id[mask]
            weights = weights[mask]

        # query for segmentation mask
        # only optimize the mask volume
        with torch.set_grad_enabled(self.seg_mask_grid.grid.requires_grad):
            mask_pred = self.seg_mask_grid(ray_pts)
            if self.num_objects == 1:
                # forward() adds the channel axis to single-object
                # predictions before marching; do the same here so the
                # product below is [M, 1] instead of broadcasting to [M, M]
                mask_pred = mask_pred.unsqueeze(-1)
            seg_mask_marched = segment_coo(
                    src=(weights.unsqueeze(-1) * mask_pred),
                    index=ray_id,
                    out=torch.zeros([N, self.num_objects]),
                    reduce='sum')

        return {
            'seg_mask_marched': seg_mask_marched,
        }

class DistortionLoss(torch.autograd.Function):
    '''Distortion loss over per-sample weights with a hand-written gradient.

    forward(w, s, n_max, ray_id) returns a scalar: the sum of a per-sample
    term (1/3) * (1/n_max) * w^2 and a pairwise term built from running sums
    within each ray, divided by the number of rays (ray_id.max() + 1).
    Only `w` receives a gradient. The running sums come from the
    `segment_cumsum` extension routine; presumably w_prefix / ws_prefix are
    exclusive per-ray prefix sums and w_total / ws_total per-ray totals --
    TODO confirm against the kernel.
    '''

    @staticmethod
    def forward(ctx, w, s, n_max, ray_id):
        ray_count = ray_id.max()+1
        bin_width = 1/n_max
        w_prefix, w_total, ws_prefix, ws_total = ub360_utils_cuda.segment_cumsum(w, s, ray_id)
        # per-sample term
        self_term = (1/3) * bin_width * w.pow(2)
        # pairwise term
        pair_term = 2 * w * (s * w_prefix - ws_prefix)
        ctx.save_for_backward(w, s, w_prefix, w_total, ws_prefix, ws_total, ray_id)
        ctx.interval = bin_width
        return (pair_term.sum() + self_term.sum()) / ray_count

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_back):
        w, s, w_prefix, w_total, ws_prefix, ws_total, ray_id = ctx.saved_tensors
        bin_width = ctx.interval
        # derivative of the per-sample term
        d_self = (1/3) * bin_width * 2 * w
        # sums over the samples that come after each sample on the same ray
        w_after = w_total[ray_id] - (w_prefix + w)
        ws_after = ws_total[ray_id] - (ws_prefix + w*s)
        # derivative of the pairwise term
        d_pair = 2 * (s * (w_prefix - w_after) + (ws_after - ws_prefix))
        grad_w = grad_back * (d_pair + d_self)
        # no gradient for s, n_max and ray_id
        return grad_w, None, None, None

distortion_loss = DistortionLoss.apply

