# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
# =====================================================================================
"""
This module contains extra constraint objects. These object can be added as params to
regular layers.
"""
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from .normalizers import reshaped_kernel_orthogonalization,DEFAULT_NITER_SPECTRAL,spectral_normalization_conv
from tensorflow.keras.utils import register_keras_serializable
#import tensorflow_riemopt as riemopt
import numpy as np

@register_keras_serializable("deel-lip", "WeightClipConstraint")
class WeightClipConstraint(Constraint):
    def __init__(self, c=2):
        """
        Clips the weights incident to each hidden unit to be inside the range `[-c,+c]`.

        Args:
            c: clipping parameter.
        """
        self.c = c

    def __call__(self, p):
        return K.clip(p, -self.c, self.c)

    def get_config(self):
        return {"c": self.c}


@register_keras_serializable("deel-lip", "AutoWeightClipConstraint")
class AutoWeightClipConstraint(Constraint):
    def __init__(self, scale=1):
        """
        Clips the weights incident to each hidden unit to be inside the range `[-c,+c]`.
        With c = 1/sqrt(size(kernel)).

        Args:
            scale: scaling factor to increase/decrease clipping value.
        """
        self.scale = scale

    def __call__(self, w):
        c = 1 / (tf.sqrt(tf.cast(tf.size(w), dtype=w.dtype)) * self.scale)
        return tf.clip_by_value(w, -c, c)

    def get_config(self):
        return {"scale": self.scale}


@register_keras_serializable("deel-lip", "FrobeniusConstraint")
class FrobeniusConstraint(Constraint):
    # todo: duplicate of keras/constraints/UnitNorm ?

    def __init__(self, eps=1e-7,axis = None,k_coef_lip=1.0):
        """
        Constrain the weights by dividing the weight matrix by it's L2 norm.
        """
        self.eps = eps
        self.axis = axis
        self.k_coef_lip=k_coef_lip

    def __call__(self, w):
        #return w
        return self.k_coef_lip*w / (tf.sqrt(tf.reduce_sum(tf.square(w), axis=self.axis)) + self.eps)

    def get_config(self):
        return {"eps": self.eps}
    
    
@register_keras_serializable("deel-lip", "NormInfConstraint")
class NormInfConstraint(Constraint):
    # todo: duplicate of keras/constraints/UnitNorm ?

    def __init__(self):
        self.sigma = None
       
    def set_sigma(self,sigma):
        self.sigma = sigma
    
    def __call__(self, w):
        #return w
        shape = tf.shape(w)
        #return w / (tf.sqrt(tf.cast(shape[0]*shape[1]*shape[2],tf.float32))*tf.abs(tf.reduce_max(w)))
        return (tf.sqrt(tf.cast(shape[0]*shape[1]*shape[2],tf.float32)))*w / self.sigma

    def get_config(self):
        return {}


@register_keras_serializable("deel-lip", "SpectralConstraint")
class SpectralConstraint(Constraint):
    def __init__(
        self, k_coef_lip=1.0, niter_spectral=3, niter_bjorck=15, u=None
    ) -> None:
        """
        Ensure that *all* singular values of the weight matrix equals to 1. Computation
        based on Bjorck algorithm. The computation is done in two steps:

        1. reduce the larget singular value to k_coef_lip, using iterate power method.
        2. increase other singular values to k_coef_lip, using bjorck algorithm.

        Args:
            k_coef_lip: lipschitz coefficient of the weight matrix
            niter_spectral: number of iteration to find the maximum singular value.
            niter_bjorck: number of iteration with Bjorck algorithm..
            u: vector used for iterated power method, can be set to None (used for
                serialization/deserialization purposes).
        """
        self.niter_spectral = niter_spectral
        self.niter_bjorck = niter_bjorck
        self.k_coef_lip = k_coef_lip
        if not (isinstance(u, tf.Tensor) or (u is None)):
            u = tf.convert_to_tensor(u)
        self.u = u
        #self.manifold = riemopt.manifolds.StiefelEuclidean()
        super(SpectralConstraint, self).__init__()
    def set_shape(self,kernel):
        None
    def __call__(self, w):
        wbar, u, sigma = reshaped_kernel_orthogonalization(
            w,
            self.u,
            self.k_coef_lip,
            self.niter_spectral,
            self.niter_bjorck,
        )
 
        return wbar

    def get_config(self):
        config = {
            "k_coef_lip": self.k_coef_lip,
            "niter_spectral": self.niter_spectral,
            "niter_bjorck": self.niter_bjorck,
            "u": None if self.u is None else self.u.numpy(),
        }
        base_config = super(SpectralConstraint, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
    
    
@register_keras_serializable("deel-lip", "SpectralStiefelConstraint")
class SpectralStiefelConstraint(Constraint):
    def __init__(
        self,strides = (1,1),conv_first = False, niter=DEFAULT_NITER_SPECTRAL
    ) -> None:
        """
        Ensure that *all* singular values of the weight matrix equals to 1. Computation
        based on Bjorck algorithm. The computation is done in two steps:

        1. reduce the larget singular value to k_coef_lip, using iterate power method.
        2. increase other singular values to k_coef_lip, using bjorck algorithm.

        Args:
            k_coef_lip: lipschitz coefficient of the weight matrix
            niter_spectral: number of iteration to find the maximum singular value.
            niter_bjorck: number of iteration with Bjorck algorithm..
            u: vector used for iterated power method, can be set to None (used for
                serialization/deserialization purposes).
        """
        super(SpectralStiefelConstraint, self).__init__()
        self.niter = niter
        self.conv_first = conv_first
        self.stride = strides[0]
    
    def set_shape(self,shape,transpose):
        self.conv_shape = shape
        self.transpose = transpose
        (R0,R,C,M) = shape
        self.cPad=[int(R0/2),int(R/2)]
        r = R//2
        #print(r)
        if r<1:
            N=5
        else:
            N = 4*r+1
            if self.stride >1:
                N = int(0.5+N/self.stride)
        #FM 01 N = 6*N
        #FM 01 print("test N = 2*N")
        #FM 01 print("self.niter_spectral",self.niter_spectral)
        if C*self.stride**2>M:
            self.spectral_input_shape = (N,N,M)
            self.RO_case = True
        else:
            self.spectral_input_shape = (self.stride*N,self.stride*N,C)
            self.RO_case = False
        self.usize = np.prod(self.spectral_input_shape)
        self.u = tf.Variable(tf.random_normal_initializer(mean=0.0, stddev=1)(shape=(1,)+self.spectral_input_shape,dtype = tf.float32))
        
        #print(self.u.shape)
    def __call__(self, w):
        #return w
        #tf.print("projection")
        #u = tf.random_normal_initializer(mean=0.0, stddev=1)(shape=(1,)+self.spectral_input_shape,dtype = tf.float32)
        wbar =w
        if self.transpose:
            wbar = tf.transpose(wbar)
            
        wbar =  tf.reshape(wbar, self.conv_shape)
        wbar, u, sigma = spectral_normalization_conv(wbar,  u=self.u, stride = self.stride, conv_first = not self.RO_case, w_pad=self.cPad[0],h_pad = self.cPad[1], niter=self.niter)
        self.sigma.assign([[sigma]])
        self.u.assign(u)
        return w
    def set_sigma(self,sigma):
        self.sigma = sigma
    def get_config(self):
        config = {
            "k_coef_lip": self.k_coef_lip,
            "niter_spectral": self.niter_spectral,
            "niter_bjorck": self.niter_bjorck,
            "u": None if self.u is None else self.u.numpy(),
        }
        base_config = super(SpectralConstraint, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
    

@register_keras_serializable("deel-lip", "SpectralConvConstraint")
class SpectralConvConstraint(Constraint):
    def __init__(
        self,strides = (1,1),conv_first = False, niter=DEFAULT_NITER_SPECTRAL
    ) -> None:
        """
        Ensure that *all* singular values of the weight matrix equals to 1. Computation
        based on Bjorck algorithm. The computation is done in two steps:

        1. reduce the larget singular value to k_coef_lip, using iterate power method.
        2. increase other singular values to k_coef_lip, using bjorck algorithm.

        Args:
            k_coef_lip: lipschitz coefficient of the weight matrix
            niter_spectral: number of iteration to find the maximum singular value.
            niter_bjorck: number of iteration with Bjorck algorithm..
            u: vector used for iterated power method, can be set to None (used for
                serialization/deserialization purposes).
        """
        super(SpectralConvConstraint, self).__init__()
        self.niter = niter
        self.conv_first = conv_first
        self.stride = strides[0]
    
    def set_shape(self,kernel):
        (R0,R,C,M) = kernel.shape
        self.cPad=[int(R0/2),int(R/2)]
        r = R//2
        #print(r)
        if r<1:
            N=5
        else:
            N = 4*r+1
            if self.stride >1:
                N = int(0.5+N/self.stride)
        #FM 01 N = 6*N
        #FM 01 print("test N = 2*N")
        #FM 01 print("self.niter_spectral",self.niter_spectral)
        if C*self.stride**2>M:
            self.spectral_input_shape = (N,N,M)
            self.RO_case = True
        else:
            self.spectral_input_shape = (self.stride*N,self.stride*N,C)
            self.RO_case = False
        self.usize = np.prod(self.spectral_input_shape)
        self.u = tf.Variable(tf.random_normal_initializer(mean=0.0, stddev=1)(shape=(1,)+self.spectral_input_shape,dtype = tf.float32))
        #print(self.u.shape)
    def __call__(self, w):
        #return w
        #tf.print("projection")
        #u = tf.random_normal_initializer(mean=0.0, stddev=1)(shape=(1,)+self.spectral_input_shape,dtype = tf.float32)
        wbar, u, sigma = spectral_normalization_conv(w,  u=self.u, stride = self.stride, conv_first = self.conv_first, cPad=self.cPad, niter=self.niter)
        self.u.assign(u)
        return wbar

    def get_config(self):
        config = {
            "k_coef_lip": self.k_coef_lip,
            "niter_spectral": self.niter_spectral,
            "niter_bjorck": self.niter_bjorck,
            "u": None if self.u is None else self.u.numpy(),
        }
        base_config = super(SpectralConstraint, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
