# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Loss functions used in the paper
"Elucidating the Design Space of Diffusion-Based Generative Models"."""

import numpy as np
import torch
from torch_utils import persistence

#----------------------------------------------------------------------------
# Loss function corresponding to the variance preserving (VP) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

@persistence.persistent_class
class VPLoss:
    """Denoising loss for the variance preserving (VP) formulation.

    Attributes:
        beta_d:    Coefficient of the quadratic term of the noise schedule.
        beta_min:  Coefficient of the linear term of the noise schedule.
        epsilon_t: Smallest time value that is sampled during training.
    """

    def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
        self.beta_d = beta_d
        self.beta_min = beta_min
        self.epsilon_t = epsilon_t

    def __call__(self, net, images, labels, augment_pipe=None):
        """Return the per-element weighted squared error for one batch.

        Args:
            net:          Callable invoked as
                          ``net(x, sigma, labels, augment_labels=...)``.
            images:       Batch of clean images; its first dimension is the
                          batch size.
            labels:       Passed through to ``net`` untouched.
            augment_pipe: Optional callable that maps ``images`` to an
                          ``(augmented_images, augment_labels)`` pair.

        Returns:
            The unreduced loss ``weight * (net_output - target) ** 2``.
        """
        batch_size = images.shape[0]

        # One time value per sample, uniform between epsilon_t and 1.
        uniform = torch.rand([batch_size, 1, 1, 1], device=images.device)
        t = 1 + uniform * (self.epsilon_t - 1)
        sigma = self.sigma(t)
        weight = 1 / sigma ** 2

        if augment_pipe is None:
            target, augment_labels = images, None
        else:
            target, augment_labels = augment_pipe(images)

        noise = torch.randn_like(target) * sigma
        denoised = net(target + noise, sigma, labels, augment_labels=augment_labels)
        squared_error = (denoised - target) ** 2
        return weight * squared_error

    def sigma(self, t):
        """Return the noise level for time ``t`` (a number or a tensor)."""
        t = torch.as_tensor(t)
        exponent = 0.5 * self.beta_d * (t ** 2) + self.beta_min * t
        return (exponent.exp() - 1).sqrt()

#----------------------------------------------------------------------------
# Loss function corresponding to the variance exploding (VE) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

@persistence.persistent_class
class VELoss:
    """Denoising loss for the variance exploding (VE) formulation.

    Attributes:
        sigma_min: Smallest noise level that is sampled during training.
        sigma_max: Largest noise level that is sampled during training.
    """

    def __init__(self, sigma_min=0.02, sigma_max=100):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def __call__(self, net, images, labels, augment_pipe=None):
        """Return the per-element weighted squared error for one batch.

        Args:
            net:          Callable invoked as
                          ``net(x, sigma, labels, augment_labels=...)``.
            images:       Batch of clean images; its first dimension is the
                          batch size.
            labels:       Passed through to ``net`` untouched.
            augment_pipe: Optional callable that maps ``images`` to an
                          ``(augmented_images, augment_labels)`` pair.

        Returns:
            The unreduced loss ``weight * (net_output - target) ** 2``.
        """
        batch_size = images.shape[0]

        # Geometric interpolation between sigma_min and sigma_max, i.e. the
        # noise level is uniform in log space.
        uniform = torch.rand([batch_size, 1, 1, 1], device=images.device)
        ratio = self.sigma_max / self.sigma_min
        sigma = self.sigma_min * (ratio ** uniform)
        weight = 1 / sigma ** 2

        if augment_pipe is None:
            target, augment_labels = images, None
        else:
            target, augment_labels = augment_pipe(images)

        noise = torch.randn_like(target) * sigma
        denoised = net(target + noise, sigma, labels, augment_labels=augment_labels)
        squared_error = (denoised - target) ** 2
        return weight * squared_error

#----------------------------------------------------------------------------
# Improved loss function proposed in the paper "Elucidating the Design Space
# of Diffusion-Based Generative Models" (EDM).

@persistence.persistent_class
class EDMLoss:
    """Denoising loss with log-normally distributed noise levels (EDM).

    Attributes:
        P_mean:     Mean of ``log(sigma)`` for the sampled noise levels.
        P_std:      Standard deviation of ``log(sigma)``.
        sigma_data: Constant used in the loss weighting term.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data

    def __call__(self, net, images, labels=None, augment_pipe=None):
        """Return the per-element weighted squared error for one batch.

        Args:
            net:          Callable invoked as
                          ``net(x, sigma, labels, augment_labels=...)``.
            images:       Batch of clean images; its first dimension is the
                          batch size.
            labels:       Passed through to ``net`` untouched.
            augment_pipe: Optional callable that maps ``images`` to an
                          ``(augmented_images, augment_labels)`` pair.

        Returns:
            The unreduced loss ``weight * (net_output - target) ** 2``.
        """
        batch_size = images.shape[0]

        # log(sigma) ~ N(P_mean, P_std^2), one draw per sample.
        gaussian = torch.randn([batch_size, 1, 1, 1], device=images.device)
        log_sigma = gaussian * self.P_std + self.P_mean
        sigma = log_sigma.exp()

        numerator = sigma ** 2 + self.sigma_data ** 2
        denominator = (sigma * self.sigma_data) ** 2
        weight = numerator / denominator

        if augment_pipe is None:
            target, augment_labels = images, None
        else:
            target, augment_labels = augment_pipe(images)

        noise = torch.randn_like(target) * sigma
        denoised = net(target + noise, sigma, labels, augment_labels=augment_labels)
        squared_error = (denoised - target) ** 2
        return weight * squared_error

#----------------------------------------------------------------------------

@persistence.persistent_class
class Patch_EDMLoss:
    """EDM loss evaluated on random square patches with position channels.

    For every call a patch size is drawn, one patch window is cropped from
    the whole batch, and channels describing where that window sits inside
    the full image are concatenated to both the network input and the
    regression target.

    Attributes:
        P_mean:      Mean of ``log(sigma)`` for the sampled noise levels.
        P_std:       Standard deviation of ``log(sigma)``.
        sigma_data:  Constant used in the loss weighting term.
        pos_embed:   Optional callable applied to the stacked ``(2, p, p)``
                     coordinate grid. When ``None`` the raw coordinates are
                     used as two extra channels.
        patch_sizes: Candidate patch side lengths drawn from in ``__call__``.
        patch_probs: Probability of each entry of ``patch_sizes``.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, pos_embed=None,
                 patch_sizes=(32, 64, 128, 256), patch_probs=(0.3, 0.3, 0.3, 0.1)):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data

        self.pos_embed = pos_embed

        # The defaults reproduce the table that used to be hard-coded in
        # __call__; pass different values instead of editing the source.
        self.patch_sizes = tuple(patch_sizes)
        self.patch_probs = tuple(patch_probs)

    def _sample_patch_grid(self, patch_size, resolution, device):
        """Draw a patch origin and build its normalised coordinate grids.

        Returns:
            ``(x_start, y_start, x_pos, y_pos)``. ``x_start`` is the first
            column and ``y_start`` the first row of the patch. ``x_pos`` and
            ``y_pos`` have shape ``(patch_size, patch_size)``; ``x_pos``
            varies along the last axis and ``y_pos`` along the first, and
            index 0 maps to -1 while index ``resolution - 1`` maps to +1.
        """
        assert patch_size <= resolution

        if patch_size < resolution:
            # randint's upper bound is exclusive, hence the +1: without it
            # the origin that puts the patch flush with the far edge could
            # never be drawn.
            num_origins = resolution - patch_size + 1
            x_start = int(np.random.randint(num_origins))
            y_start = int(np.random.randint(num_origins))
        else:
            x_start = y_start = 0

        steps = torch.arange(patch_size, device=device)
        base = torch.ones((patch_size, patch_size), device=device)
        x_pos = base * x_start + steps.view(1, -1)
        y_pos = base * y_start + steps.view(-1, 1)

        # Rescale pixel indices to the range [-1, 1].
        x_pos = (x_pos / (resolution - 1) - 0.5) * 2.
        y_pos = (y_pos / (resolution - 1) - 0.5) * 2.
        return x_start, y_start, x_pos, y_pos

    @staticmethod
    def _crop(images, x_start, y_start, patch_size):
        """Crop the patch whose top-left pixel is row y_start, column x_start."""
        # Rows (axis 2) follow y and columns (axis 3) follow x so that the
        # crop lines up with the coordinate grids.
        return images[:, :, y_start:y_start + patch_size, x_start:x_start + patch_size]

    def random_patch(self, images, patch_size, resolution):
        """Crop a random patch and return it with raw coordinate channels.

        Args:
            images:     4-D batch; axes 2 and 3 are cropped.
            patch_size: Side length of the square patch.
            resolution: Side length used to normalise the coordinates and to
                        bound the patch origin.

        Returns:
            ``(images_patch, images_pos)`` where ``images_pos`` has shape
            ``(batch, 2, patch_size, patch_size)`` holding the x and y
            coordinates, in that order.

        Raises:
            AssertionError: If ``patch_size`` exceeds ``resolution``.
        """
        x_start, y_start, x_pos, y_pos = self._sample_patch_grid(
            patch_size, resolution, images.device)

        batch_shape = (images.shape[0], 1, patch_size, patch_size)
        images_pos = torch.cat(
            [x_pos.expand(batch_shape), y_pos.expand(batch_shape)], dim=1)
        images_patch = self._crop(images, x_start, y_start, patch_size)

        return images_patch, images_pos

    def pos_encode(self, images, patch_size, resolution):
        """Crop a random patch and return it with embedded coordinates.

        Same sampling as ``random_patch``, but the stacked ``(2, p, p)``
        coordinate grid is passed through ``self.pos_embed`` and the result
        is repeated along a new leading batch axis.

        Returns:
            ``(images_patch, image_pos_embed)``.

        Raises:
            AssertionError: If ``patch_size`` exceeds ``resolution``.
        """
        x_start, y_start, x_pos, y_pos = self._sample_patch_grid(
            patch_size, resolution, images.device)

        image_pos = torch.stack([x_pos, y_pos], dim=0)
        # NOTE(review): pos_embed is assumed to return a 3-D tensor, since a
        # batch axis is prepended below - confirm against the embedder used.
        image_pos_embed = self.pos_embed(image_pos)

        image_pos_embed = image_pos_embed.unsqueeze(0).repeat(images.shape[0], 1, 1, 1)
        images_patch = self._crop(images, x_start, y_start, patch_size)

        return images_patch, image_pos_embed

    def __call__(self, net, images, min_patch_size, resolution, labels=None, augment_pipe=None):
        """Return the per-element weighted squared error on a random patch.

        Args:
            net:            Callable invoked as
                            ``net(x, sigma, labels, augment_labels=...)``.
                            Its output is compared against the patch with the
                            position channels appended, so it has to produce
                            a tensor broadcastable to that shape.
            images:         4-D batch of clean images.
            min_patch_size: Accepted for interface compatibility; not read
                            by this method.
            resolution:     Side length of the full images. Must be at least
                            as large as every entry of ``self.patch_sizes``.
            labels:         Passed through to ``net`` untouched.
            augment_pipe:   Optional callable that maps the cropped patch to
                            an ``(augmented, augment_labels)`` pair.

        Returns:
            The unreduced loss ``weight * (net_output - target) ** 2``.
        """
        random_patch_size = int(np.random.choice(list(self.patch_sizes), p=list(self.patch_probs)))
        if self.pos_embed is not None:
            images, images_pos = self.pos_encode(images, random_patch_size, resolution)
        else:
            images, images_pos = self.random_patch(images, random_patch_size, resolution)

        rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2

        y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
        n = torch.randn_like(y) * sigma
        # Append the (noise-free) positional channels to both input and target.
        yn, y = torch.cat([y + n, images_pos], dim=1), torch.cat([y, images_pos], dim=1)

        D_yn = net(yn, sigma, labels, augment_labels=augment_labels)
        loss = weight * ((D_yn - y) ** 2)
        return loss

#----------------------------------------------------------------------------

