import torch
import torchvision
import torch.nn.functional as F
from torch import nn
from .blocks import *
from torchvision import transforms



class Normalize:
    """Standardise the first ``opt.input_channel`` channels (dim 1) of a tensor.

    Channel ``c`` of the result is ``(x[:, c] - expected_values[c]) / variance[c]``.
    ``variance`` is used directly as the divisor, so it should hold per-channel
    standard deviations rather than variances. Channels beyond
    ``opt.input_channel`` are copied unchanged. The input is never modified;
    a transformed copy (same dtype as ``x``) is returned.

    Args:
        opt: options object; only ``opt.input_channel`` is read.
        expected_values: per-channel means; its length must equal
            ``opt.input_channel`` (checked with ``assert``).
        variance: per-channel divisors (indexed, not length-checked).
    """

    def __init__(self, opt, expected_values, variance):
        self.n_channels, self.expected_values, self.variance = (
            opt.input_channel,
            expected_values,
            variance,
        )
        # Exactly one mean per channel is required.
        assert self.n_channels == len(self.expected_values)

    def __call__(self, x):
        result = x.clone()
        means, divisors = self.expected_values, self.variance
        for c in range(self.n_channels):
            # Read from the untouched input, write into the copy.
            result[:, c] = (x[:, c] - means[c]) / divisors[c]
        return result


class Denormalize:
    """Undo ``Normalize`` on the first ``opt.input_channel`` channels (dim 1).

    Channel ``c`` of the result is ``x[:, c] * variance[c] + expected_values[c]``.
    ``variance`` is used directly as the multiplier, so it should hold the same
    per-channel standard deviations that were used to normalise. Channels
    beyond ``opt.input_channel`` are copied unchanged. The input is never
    modified; a transformed copy (same dtype as ``x``) is returned.

    Args:
        opt: options object; only ``opt.input_channel`` is read.
        expected_values: per-channel means; its length must equal
            ``opt.input_channel`` (checked with ``assert``).
        variance: per-channel scale factors (indexed, not length-checked).
    """

    def __init__(self, opt, expected_values, variance):
        self.n_channels, self.expected_values, self.variance = (
            opt.input_channel,
            expected_values,
            variance,
        )
        # Exactly one mean per channel is required.
        assert self.n_channels == len(self.expected_values)

    def __call__(self, x):
        result = x.clone()
        means, scales = self.expected_values, self.variance
        for c in range(self.n_channels):
            # Read from the untouched input, write into the copy.
            result[:, c] = x[:, c] * scales[c] + means[c]
        return result


#---------------------------- Generators ----------------------------#

class Generator(nn.Sequential):
    """Encoder/decoder network whose output is squashed into the unit interval.

    Modules are registered (and run by ``forward``) in this order:

    * ``steps`` encoder stages, each ``Conv2dBlock`` -> ``Conv2dBlock`` ->
      ``DownSampleBlock``; the channel count doubles after every stage except
      the last (``mnist``: 16 channels / 2 stages, otherwise 32 / 3);
    * one middle ``Conv2dBlock``;
    * ``steps`` decoder stages, each ``UpSampleBlock`` -> ``Conv2dBlock`` ->
      ``Conv2dBlock``, halving the channel count; the very last block maps to
      ``out_channels`` and is built with ``relu=False``.

    Args:
        opt: options object; reads ``opt.dataset`` and ``opt.input_channel``.
        out_channels: channel count of the output; defaults to
            ``opt.input_channel``.

    Raises:
        ValueError: if ``opt.dataset`` is not one of ``'cifar10'``, ``'mnist'``
            or ``'gtsrb'``. NOTE(review): this is only detected after all
            layers have been built.
    """

    # Per-dataset (mean, divisor) pairs shared by ``_get_normalize`` and
    # ``_get_denormalize`` so the two can never drift apart. ``None`` marks a
    # dataset that is used without (de)normalization.
    _DATASET_STATS = {
        'cifar10': ((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        'mnist': ((0.5,), (0.5,)),
        'gtsrb': None,
    }

    def __init__(self, opt, out_channels=None):
        super(Generator, self).__init__()
        if opt.dataset == 'mnist':
            channel_init = 16
            steps = 2
        else:
            channel_init = 32
            steps = 3

        # Encoder: channels double after every stage except the last one.
        channel_current = opt.input_channel
        channel_next = channel_init
        for step in range(steps):
            self.add_module('convblock_down_{}'.format(2 * step), Conv2dBlock(channel_current, channel_next))
            self.add_module('convblock_down_{}'.format(2 * step + 1), Conv2dBlock(channel_next, channel_next))
            self.add_module('downsample_{}'.format(step), DownSampleBlock())
            if step < steps - 1:
                channel_current = channel_next
                channel_next *= 2

        self.add_module('convblock_middle', Conv2dBlock(channel_next, channel_next))

        # Decoder: channels halve each stage; the final stage targets the
        # requested output channel count.
        channel_current = channel_next
        channel_next = channel_current // 2
        for step in range(steps):
            self.add_module('upsample_{}'.format(step), UpSampleBlock())
            self.add_module('convblock_up_{}'.format(2 * step), Conv2dBlock(channel_current, channel_current))
            if step == steps - 1:
                # Last block built with relu=False (presumably no activation,
                # so forward()'s tanh sees the raw output) — confirm in blocks.
                self.add_module('convblock_up_{}'.format(2 * step + 1), Conv2dBlock(channel_current, channel_next, relu=False))
            else:
                self.add_module('convblock_up_{}'.format(2 * step + 1), Conv2dBlock(channel_current, channel_next))
            channel_current = channel_next
            channel_next = channel_next // 2
            if step == steps - 2:
                # The next stage is the last one: aim it at the output channels.
                channel_next = opt.input_channel if out_channels is None else out_channels

        self._EPSILON = 1e-7
        self._normalizer = self._get_normalize(opt)
        self._denormalizer = self._get_denormalize(opt)

    def _dataset_stats(self, opt):
        """Return the ``(means, divisors)`` pair for ``opt.dataset``, or ``None``.

        Raises:
            ValueError: if ``opt.dataset`` is not a known dataset.
        """
        try:
            return self._DATASET_STATS[opt.dataset]
        except KeyError:
            raise ValueError("Invalid dataset: {!r}".format(opt.dataset)) from None

    def _get_denormalize(self, opt):
        """Build the ``Denormalize`` for ``opt.dataset`` (``None`` for gtsrb)."""
        stats = self._dataset_stats(opt)
        if stats is None:
            return None
        means, divisors = stats
        # Fresh lists so the class-level table can never be mutated via the object.
        return Denormalize(opt, list(means), list(divisors))

    def _get_normalize(self, opt):
        """Build the ``Normalize`` for ``opt.dataset`` (``None`` for gtsrb)."""
        stats = self._dataset_stats(opt)
        if stats is None:
            return None
        means, divisors = stats
        return Normalize(opt, list(means), list(divisors))

    def forward(self, x):
        """Run every registered child in order, then map the result into [0, 1]."""
        for module in self.children():
            x = module(x)
        # tanh is in (-1, 1); dividing by (2 + eps) and adding 0.5 keeps the
        # output inside the unit interval.
        return torch.tanh(x) / (2 + self._EPSILON) + 0.5

    def normalize_pattern(self, x):
        """Apply the dataset normalizer, or return ``x`` unchanged if there is none."""
        if self._normalizer is not None:
            x = self._normalizer(x)
        return x

    def denormalize_pattern(self, x):
        """Apply the dataset denormalizer, or return ``x`` unchanged if there is none."""
        if self._denormalizer is not None:
            x = self._denormalizer(x)
        return x

    def threshold(self, x):
        """Soft step centred at 0.5: ``tanh(20x - 10)`` rescaled into the unit interval."""
        return torch.tanh(x * 20 - 10) / (2 + self._EPSILON) + 0.5

#---------------------------- Classifiers ----------------------------#

class NetC_MNIST(nn.Module):
    """Small CNN classifier producing 10 raw logits (no softmax is applied).

    Two valid 5x5 convolutions, each followed by ReLU, dropout(0.1) and a 2x2
    max-pool, then a 512-unit hidden linear layer and a 10-way output layer.
    The 64*4*4 input size of ``linear6`` implies single-channel 28x28 inputs
    (28 -> 24 -> 12 -> 8 -> 4).

    The attribute names below are the ``state_dict`` keys; keep them (and their
    registration order) stable so existing checkpoints still load.
    """

    def __init__(self):
        super(NetC_MNIST, self).__init__()
        # Stage 1: 1 -> 32 channels.
        self.conv1 = nn.Conv2d(1, 32, (5, 5), 1, 0)
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout3 = nn.Dropout(0.1)
        self.maxpool4 = nn.MaxPool2d((2, 2))

        # Stage 2: 32 -> 64 channels.
        self.conv5 = nn.Conv2d(32, 64, (5, 5), 1, 0)
        self.relu6 = nn.ReLU(inplace=True)
        self.dropout7 = nn.Dropout(0.1)
        self.maxpool5 = nn.MaxPool2d((2, 2))

        # Classifier head.
        self.flatten = nn.Flatten()
        self.linear6 = nn.Linear(64 * 4 * 4, 512)
        self.relu7 = nn.ReLU(inplace=True)
        self.dropout8 = nn.Dropout(0.1)
        self.linear9 = nn.Linear(512, 10)

    def forward(self, x):
        # Same order in which the layers are registered in __init__.
        features = self.maxpool4(self.dropout3(self.relu2(self.conv1(x))))
        features = self.maxpool5(self.dropout7(self.relu6(self.conv5(features))))
        hidden = self.dropout8(self.relu7(self.linear6(self.flatten(features))))
        return self.linear9(hidden)


