# python3.7
"""Contains the function for rendering images.

Here rendering is the process of generating 2D images or features from the input
primitives, including 3D point colors or features, densities, etc.

Image rendering is an essential step for Neural Radiance Field (NeRF).

Paper: https://arxiv.org/pdf/2003.08934.pdf

Note that the original NeRF renders three-channel (RGB) images. In 3D
generation, however, the rendering result can be either three-channel images
or multichannel feature maps.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .point_sampler import PointSampler
from .point_integrator import PointIntegrator

__all__ = ['Renderer']


class Renderer(nn.Module):
    """Defines the class to render images.

    This class implements the `forward()` function for rendering, which
    includes the following steps:

    1. Sample points in 3D space with `PointSampler`.
    2. Get the density and color (or feature) of each sampled point by passing
       the points, together with an optional reference representation (e.g.
       a feature volume or a triplane, supplied by the caller), through the
       embedder and the MLP.
    3. Integrate along the rays of the coarse pass with `PointIntegrator`.
    4. If `num_importance > 0`, hierarchically sample extra points based on
       the coarse-pass weights, query them, merge them with the coarse points
       (sorted by radial distance), and integrate again (the fine pass).
    """

    def __init__(self, point_sampling_kwargs=None, ray_marching_kwargs=None):
        """Initializes hyper-parameters for rendering.

        Detailed description of each element of `point_sampling_kwargs` and
        `ray_marching_kwargs` can be found in class `PointSampler` and
        `PointIntegrator` respectively.
        """
        super().__init__()

        if point_sampling_kwargs is None:
            point_sampling_kwargs = {}
        if ray_marching_kwargs is None:
            ray_marching_kwargs = {}

        self.point_sampler = PointSampler(**point_sampling_kwargs)
        self.point_integrator = PointIntegrator(**ray_marching_kwargs)

    def forward(self,
                latents,
                embedder,
                image_size,
                cam2world_matrix=None,
                ref_representation=None,
                mlp=None,
                mlp_kwargs=None,
                num_importance=0,
                smooth_weights=True,
                noise_std=0.,
                point_sampling_kwargs=None,
                ray_marching_kwargs=None):
        """Renders images.

        For simplicity, we define the following notations:

        `N` denotes batch size.
        `H` denotes the height of image.
        `W` denotes the width of image.
        `R` denotes the number of rays, which usually equals `H * W`.
        `K` denotes the number of points on each ray.
        `C` denotes the number of channels w.r.t. 2D features or images.

        Args:
            latents: Latent codes in W+ space. Only `latents.shape[0]` (the
                batch size `N`) is read here; the tensor is forwarded to `mlp`.
            embedder: Feature embedder which takes in points and the reference
                representation and outputs point features.
            image_size: Size of the rendered image. One element indicates
                square image, while two elements stand for height and width
                respectively. Denoted as `H` and `W`.
            cam2world_matrix: Camera-to-world matrix. Defaults to `None`.
            ref_representation: Reference representation. Defaults to `None`.
            mlp: Multilayer perceptron for extracting point features. It must
                return a dict with at least `density` and `color` entries.
            mlp_kwargs: Keyword arguments of `mlp`. This dict is never
                modified. A copy of it is extended with `latents`, `ray_dirs`
                and `raw_shape`, which is `(H, W, K)` for the coarse pass and
                `(H, W, num_importance)` for the fine pass.
            num_importance: Number of points for importance sampling. `0`
                disables the fine pass. Defaults to `0`.
            smooth_weights: Whether to smooth weights when do importance
                sampling. Defaults to `True`.
            noise_std: Standard deviation of noise added to output densities.
                Defaults to `0.`.
            point_sampling_kwargs: Additional keyword arguments to override the
                variables initialized in `__init__()` w.r.t. `PointSampler`.
            ray_marching_kwargs: Additional keyword arguments to override the
                variables initialized in `__init__()` w.r.t. `PointIntegrator`.

        Returns:
            A dictionary containing every entry returned by the point
                integrator of the last pass. It also contains
                `camera_azimuthal` and `camera_polar`, each unsqueezed on the
                last dim (or `None` if the sampler returned `None`).
        """
        if point_sampling_kwargs is None:
            point_sampling_kwargs = {}
        if ray_marching_kwargs is None:
            ray_marching_kwargs = {}
        # Work on a shallow copy so that the caller's dict is never mutated.
        mlp_kwargs = {} if mlp_kwargs is None else dict(mlp_kwargs)

        N = latents.shape[0]

        point_sampling_result = self.point_sampler(
            batch_size=N,
            image_size=image_size,
            cam2world_matrix=cam2world_matrix,
            **point_sampling_kwargs)

        points = point_sampling_result['points_world']  # [N, H, W, K, 3]
        ray_dirs = point_sampling_result['rays_world']  # [N, H, W, 3]
        radii_coarse = point_sampling_result['radii']  # [N, H, W, K]
        # Translation column of the camera-to-world matrix, i.e. the origin
        # shared by all rays of a camera.
        ray_origins = point_sampling_result['cam2world_matrix'][:, :3, -1]

        camera_polar = point_sampling_result['camera_polar']  # [N]
        camera_azimuthal = point_sampling_result['camera_azimuthal']  # [N]
        if camera_polar is not None:
            camera_polar = camera_polar.unsqueeze(-1)
        if camera_azimuthal is not None:
            camera_azimuthal = camera_azimuthal.unsqueeze(-1)

        _, H, W, K, _ = points.shape
        R = H * W
        points = points.reshape(N, -1, 3)  # [N, R * K, 3]
        ray_dirs = ray_dirs.reshape(N, R, 3)
        # `expand` is enough: the origins are only read via broadcasting.
        ray_origins = ray_origins.unsqueeze(1).expand(-1, R, -1)  # [N, R, 3]
        radii_coarse = radii_coarse.reshape(N, R, K, 1)

        def query(query_points, num_per_ray):
            # Evaluates `query_points` ([N, R * num_per_ray, 3]) and returns
            # densities [N, R, num_per_ray, 1] and colors
            # [N, R, num_per_ray, C]. `raw_shape` tracks the number of points
            # per ray of the pass that is currently being queried.
            pass_kwargs = dict(mlp_kwargs,
                               latents=latents,
                               ray_dirs=ray_dirs,
                               raw_shape=(H, W, num_per_ray))
            result = get_density_color(query_points,
                                       embedder,
                                       ref_representation=ref_representation,
                                       mlp=mlp,
                                       mlp_kwargs=pass_kwargs,
                                       noise_std=noise_std)
            densities = result['density'].reshape(N, R, num_per_ray, 1)
            colors = result['color']
            colors = colors.reshape(N, R, num_per_ray, colors.shape[-1])
            return densities, colors

        # Coarse pass: query and integrate the initially sampled points.
        densities_coarse, colors_coarse = query(points, K)
        rendering_result = self.point_integrator(colors_coarse,
                                                 densities_coarse,
                                                 radii_coarse,
                                                 **ray_marching_kwargs)

        if num_importance > 0:
            # Importance sampling driven by the coarse-pass weights.
            radii_fine = sample_importance(
                radii_coarse,
                rendering_result['weight'],
                num_importance,
                smooth_weights=smooth_weights)  # [N, R, num_importance, 1]
            points_fine = (ray_origins.unsqueeze(-2) +
                           radii_fine * ray_dirs.unsqueeze(-2))
            points_fine = points_fine.reshape(N, -1, 3)

            # Fine pass: query the new points, merge them with the coarse
            # ones (sorted along each ray), and integrate everything.
            densities_fine, colors_fine = query(points_fine, num_importance)
            all_radiis, all_colors, all_densities = unify_attributes(
                radii_coarse, colors_coarse, densities_coarse, radii_fine,
                colors_fine, densities_fine)
            rendering_result = self.point_integrator(all_colors,
                                                     all_densities,
                                                     all_radiis,
                                                     **ray_marching_kwargs)

        return {
            **rendering_result,
            'camera_azimuthal': camera_azimuthal,
            'camera_polar': camera_polar,
        }


def get_density_color(points,
                      embedder,
                      ref_representation=None,
                      mlp=None,
                      mlp_kwargs=None,
                      noise_std=0):
    """Get density and color or feature of each point.

    Args:
        points: Coordinates of sampled points in 3D space, with a trailing
            dimension of 3 (`Renderer.forward()` passes them flattened as
            [N, R * K, 3]).
        embedder: Feature embedder called as
            `embedder(points, ref_representation)`.
        ref_representation: Reference representation. Defaults to `None`.
        mlp: Multilayer perceptron for extracting point features. It must
            return a dict containing at least a `density` entry.
            Defaults to `None`.
        mlp_kwargs: Keyword arguments of `mlp`. The dict itself is not
            modified. When `ref_representation` is given, `points` is passed
            to `mlp` as an extra keyword argument. Defaults to `None`.
        noise_std: Standard deviation of Gaussian noise added to the output
            densities. No noise is added when it is not positive.
            Defaults to `0`.

    Returns:
        The dictionary returned by `mlp`, containing
            - `density`: density value of each point (with noise added when
                `noise_std > 0`).
            - `color`: color or feature value of each point.
    """
    # Note: If `ref_representation` is `None`, then `point_features` is
    # actually equivalent to `points`.
    point_features = embedder(points, ref_representation)

    # Build a new dict so that the caller's `mlp_kwargs` is never mutated.
    mlp_kwargs = {} if mlp_kwargs is None else dict(mlp_kwargs)
    if ref_representation is not None:
        # Since `point_features` is not equivalent to `points` in this
        # scenario, we also include points as part of `mlp_kwargs`.
        mlp_kwargs['points'] = points

    results = mlp(point_features, **mlp_kwargs)

    if noise_std > 0:
        results['density'] = results['density'] + torch.randn_like(
            results['density']) * noise_std

    return results


def sample_importance(radial_dists,
                      weights,
                      num_importance,
                      smooth_weights=False):
    """Implements importance sampling, which is the crucial step in hierarchical
    sampling of NeRF. Hierarchical volume sampling mainly includes the following
    steps:

    1. Sample a set of `Nc` points using stratified sampling.
    2. Evaluate the 'coarse' network at locations of these points as described
       in Eq. (2) & (3) in the paper.
    3. Normalize the output weights to get a piecewise-constant PDF (probability
       density function) along the ray.
    4. Sample a second set of `Nf` points from this distribution using inverse
       transform sampling.

    And importance sampling refers to step 4 specifically.

    Code is borrowed from:

    https://github.com/NVlabs/eg3d/blob/main/eg3d/training/volumetric_rendering/renderer.py

    The whole computation runs under `torch.no_grad()`, so no gradient flows
    back to `radial_dists` or `weights`.

    Args:
        radial_dists: Radial distances, with shape [N, R, K, 1].
        weights: Per-point weight for integral, with shape [N, R, K, 1].
        num_importance: Number of points for importance sampling.
        smooth_weights: Whether to smooth weights. Defaults to `False`.

    Returns:
        importance_radial_dists: Radial distances of importance sampled points
            along rays, with shape [N, R, num_importance, 1].
    """
    with torch.no_grad():
        batch_size, num_rays, samples_per_ray, _ = radial_dists.shape
        radial_dists = radial_dists.reshape(batch_size * num_rays,
                                            samples_per_ray)
        weights = weights.reshape(batch_size * num_rays, -1) + 1e-5

        # Smooth weights: a max-pool followed by an average-pool, both with
        # kernel 2 and stride 1. The max-pool padding of 1 makes the two pools
        # together keep the length of each ray unchanged.
        if smooth_weights:
            weights = F.max_pool1d(weights.unsqueeze(1).float(),
                                   2, 1, padding=1)
            # Squeeze only the channel dim; a bare `squeeze()` would also drop
            # the ray dim when `batch_size * num_rays == 1`.
            weights = F.avg_pool1d(weights, 2, 1).squeeze(1)
            weights = weights + 0.01

        # The midpoints between consecutive samples act as bin edges, and the
        # interior weights (first and last dropped) act as per-bin weights.
        radial_dists_mid = 0.5 * (radial_dists[:, :-1] + radial_dists[:, 1:])
        importance_radial_dists = sample_pdf(radial_dists_mid, weights[:, 1:-1],
                                             num_importance)
        importance_radial_dists = importance_radial_dists.detach().reshape(
            batch_size, num_rays, num_importance, 1)

    return importance_radial_dists


def sample_pdf(bins, weights, num_importance, det=False, eps=1e-5):
    """Samples `num_importance` values from `bins` with distribution `weights`.

    The weights define a piecewise-constant PDF over the bins. Samples are
    drawn from it with inverse transform sampling. Borrowed from:

        https://github.com/kwea123/nerf_pl/blob/master/models/rendering.py

    Args:
        bins: Bin edges distributed along rays, with shape [n_rays, M + 1].
        weights: Non-negative weight of each bin, with shape [n_rays, M].
        num_importance: The number of samples to draw from the distribution
            per ray.
        det: Whether to use evenly spaced quantiles in [0, 1] instead of
            uniform random ones. Defaults to `False`.
        eps: A small number to prevent division by zero. Defaults to `1e-5`.

    Returns:
        samples: The sampled values, with shape [n_rays, num_importance].
    """
    n_rays, n_samples_ = weights.shape
    # Prevent division by zero (don't do inplace op!).
    weights = weights + eps
    pdf = weights / torch.sum(weights, -1, keepdim=True)  # [n_rays, M]
    # Cumulative distribution function, prefixed with 0.
    cdf = torch.cumsum(pdf, -1)  # [n_rays, M]
    cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # [n_rays, M + 1]

    # `torch.searchsorted` requires the query values to have the same dtype as
    # the sorted sequence, so `u` follows `cdf`'s dtype and device.
    if det:
        u = torch.linspace(0, 1, num_importance,
                           device=cdf.device, dtype=cdf.dtype)
        u = u.expand(n_rays, num_importance)
    else:
        u = torch.rand(n_rays, num_importance,
                       device=cdf.device, dtype=cdf.dtype)
    u = u.contiguous()

    # Locate the CDF interval [below, above] containing each quantile.
    indices = torch.searchsorted(cdf, u)
    below = torch.clamp_min(indices - 1, 0)
    above = torch.clamp_max(indices, n_samples_)

    indices_sampled = torch.stack([below, above], -1).view(n_rays,
                                                           2 * num_importance)
    cdf_g = torch.gather(cdf, 1, indices_sampled)
    cdf_g = cdf_g.view(n_rays, num_importance, 2)
    bins_g = torch.gather(bins, 1, indices_sampled).view(n_rays,
                                                         num_importance, 2)

    # `denom` equals 0 means a bin has weight 0, in which case it will not be
    # sampled anyway, therefore any value for it is fine (set to 1 here).
    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < eps, torch.ones_like(denom), denom)

    # Linearly interpolate inside the selected bin.
    samples = (bins_g[..., 0] + (u - cdf_g[..., 0]) /
               denom * (bins_g[..., 1] - bins_g[..., 0]))

    return samples


def unify_attributes(radial_dists1,
                     colors1,
                     densities1,
                     radial_dists2,
                     colors2,
                     densities2):
    """Merges two passes of point samples, ordered by radial distance.

    The samples of both passes are concatenated along the points-per-ray
    dimension (dim -2). Each ray is then reordered so that its radial
    distances are ascending, and colors and densities follow the same
    permutation.

    Args:
        radial_dists1: Radial distances of the first pass, with shape
            [N, R, K1, 1].
        colors1: Colors or features of the first pass, with shape [N, R, K1, C].
        densities1: Densities of the first pass, with shape [N, R, K1, 1].
        radial_dists2: Radial distances of the second pass, with shape
            [N, R, K2, 1].
        colors2: Colors or features of the second pass, with shape
            [N, R, K2, C].
        densities2: Densities of the second pass, with shape [N, R, K2, 1].

    Returns:
        A tuple `(radial_dists, colors, densities)` with shapes
            [N, R, K1 + K2, 1], [N, R, K1 + K2, C] and [N, R, K1 + K2, 1].
    """
    merged_dists = torch.cat((radial_dists1, radial_dists2), dim=-2)
    merged_colors = torch.cat((colors1, colors2), dim=-2)
    merged_densities = torch.cat((densities1, densities2), dim=-2)

    # Permutation that sorts every ray by distance; shape [N, R, K1 + K2, 1].
    order = torch.argsort(merged_dists, dim=-2)
    num_channels = merged_colors.shape[-1]

    sorted_dists = merged_dists.gather(-2, order)
    sorted_colors = merged_colors.gather(
        -2, order.expand(-1, -1, -1, num_channels))
    sorted_densities = merged_densities.gather(-2,
                                               order.expand(-1, -1, -1, 1))
    return sorted_dists, sorted_colors, sorted_densities
