################################################################################
# spectral/modules/spectral_set.py
#
# 
# 
# 
# 2024
#
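# Implements a variant of PyTorch's spectral-normalization parametrization.
# Instead of always dividing the weight by its spectral norm, a user-supplied
# function set_fn(weight, sigma) computes the reparametrized weight, so that,
# for example, the spectral norm can be set to a specific value.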

from typing import Callable, Optional, Tuple, Union

import torch

Tensor = torch.Tensor
Module = torch.nn.Module

class _SetSpectralNorm(torch.nn.utils.parametrizations._SpectralNorm):

  def __init__(self,
      # Arguments:
      weight: Tensor,
      set_fn: Callable[[Tensor, Tensor], Tensor],
      # Keyword Arguments:
      n_power_iterations: int = 1,
      dim:                int = 0,
      eps:                float = 1e-12
    ):
    super(_SetSpectralNorm, self).__init__(
      weight,
      n_power_iterations = n_power_iterations,
      dim                = dim,
      eps                = eps
    )
    self.set_fn = set_fn
    self._enabled = True

  def forward(self,
      weight: Tensor
    ) -> Tensor:
    if weight.ndim == 1:
      sigma = torch.linalg.norm(weight)
      return self.set_fn(weight, sigma + self.eps)
    else:
      weight_mat = self._reshape_weight_to_matrix(weight)
      if self.training and self._enabled:
        self._power_method(weight_mat, self.n_power_iterations)
      u = self._u.clone(memory_format = torch.contiguous_format)
      v = self._v.clone(memory_format = torch.contiguous_format)
      sigma = torch.dot(u, torch.mv(weight_mat, v))
      return self.set_fn(weight, sigma)

def set_spectral_norm(
    # Arguments:
    module: Module,
    set_fn: Callable[[Tensor, Tensor], Tensor],
    # Keyword Arguments:
    name:                str           = "weight",
    n_power_iterations:  int           = 1,
    eps:                 float         = 1e-12,
    get_parametrization: bool          = False,
    dim:                 Optional[int] = None
  ) -> Union[Module, Tuple[Module, "_SetSpectralNorm"]]:
  """Register a ``_SetSpectralNorm`` parametrization on ``module.<name>``.

  The module is modified in place through
  ``torch.nn.utils.parametrize.register_parametrization``.

  Args:
    module: module that owns the tensor to parametrize.
    set_fn: ``set_fn(weight, sigma)`` returns the reparametrized weight, where
      ``sigma`` is the (estimated) spectral norm of the weight.
    name: attribute name of the parameter or buffer to parametrize.
    n_power_iterations: power-iteration steps per training forward pass.
    eps: numerical-stability epsilon passed to the parametrization.
    get_parametrization: if True, also return the parametrization object,
      for example so its ``_enabled`` flag can be toggled later.
    dim: dimension passed to the parametrization for flattening the weight
      to a matrix. Defaults to 1 for ``ConvTranspose{1,2,3}d`` and to 0
      otherwise.

  Returns:
    ``module``, or ``(module, parametrization)`` when ``get_parametrization``
    is True.

  Raises:
    AssertionError: if ``module`` has no tensor attribute called ``name``.
  """
  weight = getattr(module, name, None)
  if not isinstance(weight, Tensor):
    # Raised explicitly instead of via ``assert`` so the check still runs
    # under ``python -O``. The AssertionError type is kept for callers.
    raise AssertionError(
      f"{module} does not have a parameter or buffer called '{name}'."
    )

  if dim is None:
    # Same default as torch's spectral_norm: transposed convolutions keep
    # their output channels in dimension 1.
    is_conv_transpose = isinstance(
      module,
      (
        torch.nn.ConvTranspose1d,
        torch.nn.ConvTranspose2d,
        torch.nn.ConvTranspose3d
      )
    )
    dim = 1 if is_conv_transpose else 0

  parametrization = _SetSpectralNorm(
    weight,
    set_fn,
    n_power_iterations,
    dim,
    eps
  )
  torch.nn.utils.parametrize.register_parametrization(
    module,
    name,
    parametrization
  )
  return (module, parametrization) if get_parametrization else module