import abc

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import (
    Layer,
    Dense,
    Conv2D,
    AveragePooling2D,
    GlobalAveragePooling2D,
)

from .constraints import SpectralConstraint,FrobeniusConstraint,SpectralConvConstraint
from .initializers import SpectralInitializer,FrobenusInitializer
from .normalizers import (
    DEFAULT_NITER_BJORCK,
    DEFAULT_NITER_SPECTRAL,
    DEFAULT_NITER_SPECTRAL_INIT,
    reshaped_kernel_orthogonalization,
    reshaped_kernel_orthogonalization_dense,
    bjork_normalization_conv,
    spectral_normalization_conv,
    spectral_normalization,
    DEFAULT_BETA_BJORCK,
)
from tensorflow.keras.utils import register_keras_serializable
from tensorflow.keras.initializers import GlorotUniform
from .regularizers import Lorth2D,LorthRegularizer
from tensorflow.keras.layers import DepthwiseConv2D

from tensorflow.keras.utils import register_keras_serializable
from tensorflow.keras import backend as K
import tensorflow as tf
from deel.lip.layers import LipschitzLayer, Condensable
import numpy as np

def cayley_norm(kernel):
    """Turn a square kernel into an orthogonal matrix via a Cayley-type map.

    The upper triangle ``C`` of ``kernel`` is used to build the skew-symmetric
    matrix ``D = C - C^T`` (the diagonal of ``C`` cancels out), and
    ``(D - I) (D + I)^-1`` is returned.  ``D + I`` is always invertible since
    the eigenvalues of a real skew-symmetric matrix are purely imaginary.

    Args:
        kernel: tensor or variable whose two innermost dimensions form a
            square matrix.  Any leading dimensions are treated as batch
            dimensions (e.g. a ``(1, 1, c, c)`` 1x1 convolution kernel).

    Returns:
        A tensor with the same shape and dtype as ``kernel``.
    """
    upper = tf.linalg.band_part(kernel, 0, -1)
    # Only swap the two innermost axes: a plain tf.transpose would reverse
    # *all* axes and silently broadcast when the input has rank > 2.
    skew = upper - tf.linalg.matrix_transpose(upper)
    # Size the identity from the matrix dimension (not the leading axis) and
    # match the input dtype so non-float32 kernels do not hit a dtype clash.
    identity = tf.eye(upper.shape[-1], dtype=upper.dtype)
    inverse = tf.linalg.inv(skew + identity)
    return (skew - identity) @ inverse
    
@register_keras_serializable("deel-lip", "CayleyConv1x1")
class CayleyConv1x1(Conv2D, LipschitzLayer, Condensable):
    """1x1 convolution whose kernel is re-parametrized through ``cayley_norm``.

    While training, the effective kernel is recomputed from ``kernel`` at
    every call and cached in the non-trainable weight ``wbar``; at inference
    the cached ``wbar`` is used directly.  The Cayley transform needs a square
    matrix, so the number of input channels must equal ``filters``.
    """

    def __init__(
        self,
        filters,
        strides=(1, 1),
        padding="same",
        data_format=None,
        dilation_rate=(1, 1),
        activation=None,
        use_bias=True,
        kernel_initializer=SpectralInitializer(
            niter_spectral=DEFAULT_NITER_SPECTRAL_INIT,
            niter_bjorck=DEFAULT_NITER_BJORCK,
        ),
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs
    ):
        """Create the layer.

        All arguments have the same meaning as for ``Conv2D``; the kernel size
        is fixed to ``(1, 1)``.

        Raises:
            RuntimeError: if ``dilation_rate`` is anything other than 1.
        """
        if not (
            (dilation_rate == (1, 1))
            or (dilation_rate == [1, 1])
            or (dilation_rate == 1)
        ):
            raise RuntimeError("NormalizedConv does not support dilation rate")

        super(CayleyConv1x1, self).__init__(
            filters=filters,
            kernel_size=(1, 1),
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            # Previously dropped, which silently disabled any regularizer.
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs
        )
        # Extra keyword arguments are replayed on the exported vanilla layer.
        self._kwargs = kwargs
        self.built = False
        # Cached effective kernel; created in build().
        self.wbar = None

    def build(self, input_shape):
        """Create the weights and the reshape helpers used by the transform.

        Raises:
            ValueError: if the kernel's channel matrix is not square.
        """
        super(CayleyConv1x1, self).build(input_shape)

        kernel_shape = self.kernel.shape
        if kernel_shape[-2] != kernel_shape[-1]:
            raise ValueError(
                "CayleyConv1x1 requires as many input channels as filters, "
                "got a kernel of shape %s" % (tuple(kernel_shape),)
            )

        # Some regularizers need to know the kernel shape; standard Keras
        # regularizers do not expose this hook.
        if self.kernel_regularizer is not None and hasattr(
            self.kernel_regularizer, "set_kernel_shape"
        ):
            self.kernel_regularizer.set_kernel_shape(kernel_shape)

        # wbar is needed by call() whether or not a regularizer is set.
        self.wbar = self.add_weight(
            shape=kernel_shape,
            name="wbar",
            trainable=False,
            dtype=self.dtype,
        )

        self.k_shape = tf.convert_to_tensor(kernel_shape)
        k_flat = tuple([-1, kernel_shape[-1]])
        self.k_flat = tf.convert_to_tensor(k_flat)

        # Seed the cache so that inference before any training step does not
        # run with the arbitrary default initial value of `wbar`.
        self.wbar.assign(self._orthogonal_kernel())
        self.built = True

    def _compute_lip_coef(self, input_shape=None):
        """Return the Lipschitz correction factor, which is always 1 here."""
        return 1

    def _orthogonal_kernel(self):
        """Apply ``cayley_norm`` to the 2-D view of the kernel, in 4-D shape."""
        # cayley_norm is fed a plain (channels, channels) matrix; the 1x1
        # spatial axes are restored afterwards.
        flat_kernel = tf.reshape(self.kernel, self.k_flat)
        return tf.reshape(cayley_norm(flat_kernel), self.k_shape)

    def call(self, x, training=True):
        """Convolve ``x`` with the effective kernel.

        Args:
            x: input tensor.
            training: when truthy the effective kernel is recomputed from
                ``kernel`` and stored in ``wbar``; otherwise the stored
                ``wbar`` is used.

        Returns:
            The convolution output, with bias and activation applied when
            configured.
        """
        if training:
            wbar = self._orthogonal_kernel()
            self.wbar.assign(wbar)
        else:
            wbar = self.wbar

        outputs = K.conv2d(
            x,
            wbar,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )

        if self.use_bias:
            outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)

        return outputs

    def get_config(self):
        """Return a config that can be fed back to ``__init__``."""
        config = super(CayleyConv1x1, self).get_config()
        # The kernel size is fixed by __init__, which already passes it to
        # Conv2D; keeping the key would pass `kernel_size` twice on reload.
        config.pop("kernel_size", None)
        return config

    def condense(self):
        """Overwrite ``kernel`` with its transformed version.

        NOTE(review): the transform is not idempotent, so training the layer
        again after condensing applies it to the already transformed kernel.
        """
        self.kernel.assign(self._orthogonal_kernel())

    def vanilla_export(self):
        """Build a plain ``Conv2D`` carrying the cached ``wbar`` and the bias."""
        # Work on a copy so that exporting does not mutate the stored kwargs.
        export_kwargs = dict(self._kwargs, name=self.name)
        layer = Conv2D(
            filters=self.filters,
            kernel_size=(1, 1),
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
            activation=self.activation,
            use_bias=self.use_bias,
            kernel_initializer="glorot_uniform",
            bias_initializer="zeros",
            **export_kwargs
        )
        layer.build(self.input_shape)
        layer.kernel.assign(self.wbar)

        if self.use_bias:
            layer.bias.assign(self.bias)
        return layer
    
    
@register_keras_serializable("deel-lip", "CayleyConv2D")
class CayleyConv2D(Conv2D, LipschitzLayer, Condensable):
    """Convolution whose kernel is a frozen base kernel times a Cayley matrix.

    The effective kernel is
    ``reshape(kernel, (-1, filters)) @ (D - I) @ (D + I)^-1 * coef`` where
    ``D = C - C^T`` is built from the upper triangle ``C`` of the trainable
    ``orth_kernel`` and ``coef`` comes from ``_get_coef()``.  ``kernel`` itself
    is created as a non-trainable weight.  While training the effective kernel
    is recomputed at every call and cached in ``wbar``; at inference the cached
    ``wbar`` is used.
    """

    def __init__(
        self,
        filters,
        kernel_size,
        strides=(1, 1),
        padding="same",
        data_format=None,
        dilation_rate=(1, 1),
        activation=None,
        use_bias=True,
        kernel_initializer=SpectralInitializer(
            niter_spectral=20,
            niter_bjorck=25,
        ),
        k_coef_lip=1,
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        by_constraint=False,
        **kwargs
    ):
        """Create the layer.

        Args:
            k_coef_lip: factor handed to ``set_klip_factor``.
            by_constraint: when true, ``condense`` is a no-op and
                ``vanilla_export`` exports ``kernel`` instead of ``wbar``.

        The remaining arguments are forwarded to ``Conv2D``.

        Raises:
            RuntimeError: if ``dilation_rate`` is anything other than 1.
        """
        if not (
            (dilation_rate == (1, 1))
            or (dilation_rate == [1, 1])
            or (dilation_rate == 1)
        ):
            raise RuntimeError("NormalizedConv does not support dilation rate")

        super(CayleyConv2D, self).__init__(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs
        )
        # Extra keyword arguments are replayed on the exported vanilla layer.
        self._kwargs = kwargs
        self.set_klip_factor(k_coef_lip)
        self.built = False
        self.by_constraint = by_constraint
        self.delta2one = 0.02
        if not self.by_constraint:
            self.u = None
            self.u_dense = None
            self.sig = None
            self.wbar = None

    def build(self, input_shape):
        """Create the weights and the helpers used by ``cayley_param``."""
        super(CayleyConv2D, self).build(input_shape)
        self._init_lip_coef(input_shape)

        # Reuse the shape Conv2D.build derived for its kernel so the channel
        # dimension is right for both data formats (input_shape[-1] is not
        # the channel axis for "channels_first").
        base_kernel_shape = self.kernel.shape

        # NOTE(review): Conv2D.build has already created `kernel` (and `bias`
        # when use_bias is set); the weights below replace those attributes,
        # presumably leaving the first ones tracked but unused. Kept as is to
        # avoid changing the layout of saved weights - confirm before cleanup.
        self.kernel = self.add_weight(
            shape=base_kernel_shape,
            name="kernel",
            trainable=False,
            initializer=self.kernel_initializer,
            dtype=self.dtype,
        )
        self.k_shape = tf.convert_to_tensor(self.kernel.shape)
        self.id_mat = tf.eye(self.filters, dtype=self.dtype)

        # Cache of the effective kernel, used at inference and for export.
        self.wbar = self.add_weight(
            shape=self.kernel.shape,
            name="wbar",
            trainable=False,
            dtype=self.dtype,
        )
        self.bias = self.add_weight(
            shape=(self.kernel.shape[-1],),
            name="bias",
            trainable=True,
            initializer=tf.keras.initializers.Zeros(),
            dtype=self.dtype,
        )
        self.orth_kernel = self.add_weight(
            shape=(self.kernel.shape[-1], self.kernel.shape[-1]),
            name="orth_kernel",
            trainable=True,
            initializer="uniform",
            dtype=self.dtype,
        )
        # Only the upper triangle is read by cayley_param; zero the rest.
        self.orth_kernel.assign(tf.linalg.band_part(self.orth_kernel, 0, -1))

        k_flat = tuple([-1, self.kernel.shape[-1]])
        self.k_flat = tf.convert_to_tensor(k_flat)

        # Seed the cache so that inference before any training step does not
        # run with the arbitrary default initial value of `wbar`.
        self.wbar.assign(self.cayley_param())
        self.built = True

    def _compute_lip_coef(self, input_shape=None):
        """Compute the Lipschitz correction factor of the convolution.

        Args:
            input_shape: shape of the layer input; its spatial dimensions are
                read according to ``data_format`` (not needed for "valid"
                padding).

        Returns:
            The correction factor as a float.

        Raises:
            RuntimeError: if ``data_format`` is not recognised.
        """
        # According to the file lipschitz_CNN.pdf
        if self.padding == "valid":
            return float(self.strides[0]) / float(self.kernel_size[0])
        stride = np.prod(self.strides)
        k1 = self.kernel_size[0]
        k1_div2 = (k1 - 1) / 2
        k2 = self.kernel_size[1]
        k2_div2 = (k2 - 1) / 2
        if self.data_format == "channels_last":
            h = input_shape[-3]
            w = input_shape[-2]
        elif self.data_format == "channels_first":
            h = input_shape[-2]
            w = input_shape[-1]
        else:
            # The message needs a placeholder, otherwise `%` itself raises
            # TypeError and masks the intended error.
            raise RuntimeError("data_format not understood: %s" % self.data_format)
        if stride == 1:
            coefLip = np.sqrt(
                (w * h)
                / (
                    (k1 * h - k1_div2 * (k1_div2 + 1))
                    * (k2 * w - k2_div2 * (k2_div2 + 1))
                )
            )
        else:
            sn1 = self.strides[0]
            sn2 = self.strides[1]
            ho = np.floor(h / sn1)
            wo = np.floor(w / sn2)
            alphabar1 = np.floor(k1_div2 / sn1)
            alphabar2 = np.floor(k2_div2 / sn2)
            betabar1 = k1_div2 - alphabar1 * sn1
            betabar2 = k2_div2 - alphabar2 * sn2
            zl1 = (alphabar1 * sn1 + 2 * betabar1) * (alphabar1 + 1) / 2
            zl2 = (alphabar2 * sn2 + 2 * betabar2) * (alphabar2 + 1) / 2
            gamma1 = h - 1 - sn1 * np.ceil((h - 1 - k1_div2) / sn1)
            gamma2 = w - 1 - sn2 * np.ceil((w - 1 - k2_div2) / sn2)
            alphah1 = np.floor(gamma1 / sn1)
            alphaw2 = np.floor(gamma2 / sn2)
            zr1 = (alphah1 + 1) * (k1_div2 - gamma1 + sn1 * alphah1 / 2.0)
            zr2 = (alphaw2 + 1) * (k2_div2 - gamma2 + sn2 * alphaw2 / 2.0)
            coefLip = np.sqrt((h * w) / ((k1 * ho - zl1 - zr1) * (k2 * wo - zl2 - zr2)))
        return coefLip

    def cayley_param(self):
        """Return the effective kernel built from ``kernel`` and ``orth_kernel``."""
        flat_kernel = tf.reshape(self.kernel, self.k_flat)
        upper = tf.linalg.band_part(self.orth_kernel, 0, -1)
        skew = upper - tf.transpose(upper)

        # Match the dtype of the weights in case they are not float32.
        identity = tf.cast(self.id_mat, skew.dtype)
        inverse = tf.linalg.inv(skew + identity)
        wbar = (flat_kernel @ (skew - identity) @ inverse) * self._get_coef()
        return tf.reshape(wbar, self.k_shape)

    def call(self, x, training=True):
        """Convolve ``x`` with the effective kernel.

        Args:
            x: input tensor.
            training: when truthy the effective kernel is recomputed and
                stored in ``wbar``; otherwise the stored ``wbar`` is used.

        Returns:
            The convolution output, with bias and activation applied when
            configured.
        """
        if training:
            wbar = self.cayley_param()
            self.wbar.assign(wbar)
        else:
            wbar = self.wbar

        outputs = K.conv2d(
            x,
            wbar,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )
        if self.use_bias:
            outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)

        return outputs

    def get_config(self):
        """Return the layer config, including the layer-specific arguments."""
        # Without these two entries a reloaded layer silently falls back to
        # the constructor defaults.
        config = {
            "k_coef_lip": self.k_coef_lip,
            "by_constraint": self.by_constraint,
        }
        base_config = super(CayleyConv2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def condense(self):
        """Overwrite ``kernel`` with the effective kernel (unless by_constraint).

        NOTE(review): this replaces the frozen base kernel, so calling
        ``cayley_param`` afterwards applies the Cayley factor a second time.
        """
        if not self.by_constraint:
            wbar = self.cayley_param()
            self.kernel.assign(wbar)

    def vanilla_export(self):
        """Build a plain ``Conv2D`` carrying this layer's weights.

        The exported kernel is ``kernel`` when ``by_constraint`` is set and
        the cached ``wbar`` otherwise.
        """
        # Work on a copy so that exporting does not mutate the stored kwargs.
        export_kwargs = dict(self._kwargs, name=self.name)
        layer = Conv2D(
            filters=self.filters,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
            activation=self.activation,
            use_bias=self.use_bias,
            kernel_initializer="glorot_uniform",
            bias_initializer="zeros",
            **export_kwargs
        )
        layer.build(self.input_shape)
        if self.by_constraint:
            layer.kernel.assign(self.kernel)
        else:
            layer.kernel.assign(self.wbar)

        if self.use_bias:
            layer.bias.assign(self.bias)
        return layer