from jax.lax import stop_gradient
import jax.numpy as jnp
import flax.linen as nn
import functools


class VGG(nn.Module):
    """VGG-style convolutional feature extractor made of five conv blocks.

    Each block is ``num_layers`` 3x3 'same'-padded convolutions (each followed
    by ReLU), then a single BatchNorm, then 2x2 max pooling with stride 2.
    Blocks 1-2 always have two conv layers; blocks 3-5 have a depth chosen by
    ``architecture``.

    Attributes:
        greedy: If True, the outputs of blocks 1-4 (after BatchNorm, before
            pooling) are recorded in the returned dict and wrapped in
            ``stop_gradient`` before being fed onward, so each block's
            parameters only get gradient through its own recorded output.
        architecture: One of 'vgg8', 'vgg16', 'vgg19', giving 1, 3 or 4 conv
            layers in each of blocks 3-5.
        kernel_init: Initializer for conv kernels (used when no pretrained
            ``param_dict`` is set).
        bias_init: Initializer for conv biases (same condition).
        dtype: Computation dtype passed to every Conv and BatchNorm.
    """
    greedy: bool=False
    architecture: str='vgg16'
    kernel_init: functools.partial=nn.initializers.lecun_normal()
    bias_init: functools.partial=nn.initializers.zeros
    dtype: str='float32'

    def setup(self):
        # Pretrained-weight loading is currently disabled, so param_dict is
        # always None and conv layers use kernel_init / bias_init.
        # NOTE(review): the previous (inactive) loader read an h5py file:
        #     ckpt_file = utils.download(self.ckpt_dir, URLS[self.architecture])
        #     self.param_dict = h5py.File(ckpt_file, 'r')
        self.param_dict = None

    @nn.compact
    def __call__(self, x, train=True):
        """Run the network and collect activations.

        Args:
            x: Input batch; presumably NHWC images — TODO confirm with callers.
            train: If True, BatchNorm uses batch statistics (callers then need
                the 'batch_stats' collection to be mutable); if False it uses
                running averages.

        Returns:
            dict: ``'block_1'`` .. ``'block_4'`` (present only when ``greedy``)
            hold the post-BatchNorm, pre-pooling outputs of blocks 1-4;
            ``'out'`` always holds the pooled output of block 5.

        Raises:
            ValueError: If ``architecture`` is not a known VGG variant.
        """
        # Number of conv layers in each of blocks 3-5 for every variant.
        depth_by_arch = {'vgg8': 1, 'vgg16': 3, 'vgg19': 4}
        if self.architecture not in depth_by_arch:
            raise ValueError(
                f"Unknown VGG architecture {self.architecture!r}; "
                f"expected one of {sorted(depth_by_arch)}")
        model_dependent = depth_by_arch[self.architecture]

        # (features, num_layers) for blocks 1..5.
        stages = (
            (64, 2),
            (128, 2),
            (256, model_dependent),
            (512, model_dependent),
            (512, model_dependent),
        )
        last_block = len(stages)

        act = {}
        for block_num, (features, num_layers) in enumerate(stages, start=1):
            x = self._conv_block(x, features=features, num_layers=num_layers,
                                 block_num=block_num, act=act,
                                 dtype=self.dtype, train=train)
            # The final block is never recorded per-block; its pooled output
            # is returned under 'out' instead.
            if self.greedy and block_num < last_block:
                act[f'block_{block_num}'] = x
                # Cut the graph so later blocks do not backprop into this one.
                x = stop_gradient(x)
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        act['out'] = x

        return act

    def _conv_block(self, x, features, num_layers, block_num, act, dtype='float32', train=True):
        """Apply ``num_layers`` Conv(3x3)+ReLU layers, then one BatchNorm.

        Conv layers are named ``conv{block_num}_{i}`` (1-based ``i``), which is
        also the key used to look up weights in ``self.param_dict`` when set.

        Args:
            x: Input activations.
            features: Output channels of every conv in the block.
            num_layers: Number of Conv+ReLU layers.
            block_num: 1-based block index, used only for layer names.
            act: Not used here; kept for interface compatibility.
            dtype: Computation dtype for Conv and BatchNorm.
            train: If False, BatchNorm uses running averages.

        Returns:
            The block output after BatchNorm (pooling is done by the caller).
        """
        for layer_idx in range(1, num_layers + 1):
            layer_name = f'conv{block_num}_{layer_idx}'
            if self.param_dict is None:
                w, b = self.kernel_init, self.bias_init
            else:
                # Bind layer_name eagerly via a default argument so the
                # initializer cannot observe a later loop value.
                w = lambda *_, name=layer_name: jnp.array(self.param_dict[name]['weight'])
                b = lambda *_, name=layer_name: jnp.array(self.param_dict[name]['bias'])
            x = nn.Conv(features=features, kernel_size=(3, 3), kernel_init=w,
                        bias_init=b, padding='same', name=layer_name,
                        dtype=dtype)(x)
            x = nn.relu(x)
        # One BatchNorm per block, applied after the last ReLU.
        x = nn.BatchNorm(use_running_average=not train, dtype=dtype)(x)
        return x


def VGG8(greedy=False,
         kernel_init=nn.initializers.lecun_normal(),
         bias_init=nn.initializers.zeros,
         dtype='float32',
         dataset='cifar10'):
    """Construct the 'vgg8' variant of ``VGG``.

    Args:
        greedy: Forwarded to ``VGG.greedy``.
        kernel_init: Conv kernel initializer.
        bias_init: Conv bias initializer.
        dtype: Computation dtype.
        dataset: Accepted for API compatibility; not used.

    Returns:
        An unbound ``VGG`` module with ``architecture='vgg8'``.
    """
    del dataset  # currently has no effect on the model
    config = dict(architecture='vgg8',
                  greedy=greedy,
                  dtype=dtype,
                  kernel_init=kernel_init,
                  bias_init=bias_init)
    return VGG(**config)


def VGG16(greedy=False,
          kernel_init=nn.initializers.lecun_normal(),
          bias_init=nn.initializers.zeros,
          dtype='float32',
          dataset='cifar10'):
    """Construct the 'vgg16' variant of ``VGG``.

    Args:
        greedy: Forwarded to ``VGG.greedy``.
        kernel_init: Conv kernel initializer.
        bias_init: Conv bias initializer.
        dtype: Computation dtype.
        dataset: Accepted for API compatibility; not used.

    Returns:
        An unbound ``VGG`` module with ``architecture='vgg16'``.
    """
    del dataset  # currently has no effect on the model
    config = dict(architecture='vgg16',
                  greedy=greedy,
                  dtype=dtype,
                  kernel_init=kernel_init,
                  bias_init=bias_init)
    return VGG(**config)




def VGG19(greedy=False,
          kernel_init=nn.initializers.lecun_normal(),
          bias_init=nn.initializers.zeros,
          dtype='float32',
          dataset='cifar10'):
    """Construct the 'vgg19' variant of ``VGG``.

    Args:
        greedy: Forwarded to ``VGG.greedy``.
        kernel_init: Conv kernel initializer.
        bias_init: Conv bias initializer.
        dtype: Computation dtype.
        dataset: Accepted (and ignored) so the signature matches the other
            VGG factory functions; callers passing ``dataset=`` no longer fail.

    Returns:
        An unbound ``VGG`` module with ``architecture='vgg19'``.
    """
    del dataset  # currently has no effect on the model
    return VGG(greedy=greedy,
               architecture='vgg19',
               kernel_init=kernel_init,
               bias_init=bias_init,
               dtype=dtype)