import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import numpy as np
import sys
import matplotlib.pyplot as plt
import wandb
sys.path.append("lib/")
sys.path.append("../")
from deepthinking.lib.solvers import anderson, broyden
from deepthinking.lib.optimizations import weight_norm, VariationalHidDropout2d

# Number of groups used by the GroupNorm layers in BasicBlock. This is a
# module-level default: DEQMazeNet.parse_cfg overwrites it with
# cfg.deq.num_groups before the DEQ blocks are constructed.
NUM_GROUPS = 4
# Whether the GroupNorm layers in BasicBlock carry learnable affine parameters.
BLOCK_GN_AFFINE = True  

class BasicBlock(nn.Module):
    """Residual block with normalization layers and an additive input injection.

    The forward pass computes::

        out = norm1(relu(conv1(x) + injection))
        out = conv2(out) + shortcut(x)
        out = norm3(relu(norm2(out)))

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels (before ``expansion``).
        stride: Stride of the first 3x3 convolution and of the 1x1 shortcut.
        dropout: Accepted for interface compatibility; not used by this block.
        wnorm: If True, register weight normalization on the convolutions.
        norm_type: One of ``'group'``, ``'instance'`` or ``'batch'``.

    Raises:
        ValueError: If ``norm_type`` is not one of the supported values.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, dropout=0.1, wnorm=False, norm_type='group'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)

        # Identity shortcut unless the spatial size or channel count changes,
        # in which case a 1x1 convolution projects the input.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False)
            )
        print(f"Using NUMGROUP {NUM_GROUPS}")
        # The attributes keep the "gn" prefix for every normalization type so
        # that parameter names stay the same across configurations.
        if norm_type == 'group':
            self.gn1 = nn.GroupNorm(NUM_GROUPS, planes, affine=BLOCK_GN_AFFINE)
            self.gn2 = nn.GroupNorm(NUM_GROUPS, planes, affine=BLOCK_GN_AFFINE)
            self.gn3 = nn.GroupNorm(NUM_GROUPS, planes, affine=BLOCK_GN_AFFINE)
        elif norm_type == 'instance':
            self.gn1 = nn.InstanceNorm2d(planes)
            self.gn2 = nn.InstanceNorm2d(planes)
            self.gn3 = nn.InstanceNorm2d(planes)
        elif norm_type == 'batch':
            self.gn1 = nn.BatchNorm2d(planes)
            self.gn2 = nn.BatchNorm2d(planes)
            self.gn3 = nn.BatchNorm2d(planes)
        else:
            raise ValueError(f"Unknown normalization type {norm_type}")

        if wnorm:
            self._wnorm()

    def _wnorm(self):
        """Register weight normalization on conv1, conv2 and the shortcut conv.

        ``weight_norm`` is used here as returning a ``(module, handle)`` pair;
        the handle is kept so that ``_reset`` can recompute the weights.
        """
        self.conv1, self.conv1_fn = weight_norm(self.conv1, names=['weight'], dim=0)
        self.conv2, self.conv2_fn = weight_norm(self.conv2, names=['weight'], dim=0)
        if len(self.shortcut) > 0:
            # shortcut[0] is the 1x1 nn.Conv2d itself (it has no `.conv`
            # attribute), so the normalization is applied to it directly.
            self.shortcut[0], self.shortcut_fn = weight_norm(self.shortcut[0], names=['weight'], dim=0)

    def _reset(self, bsz, d, H, W):
        """Recompute the weight-normalized weights, if weight norm is registered.

        The shape arguments are accepted for interface compatibility and are
        not used by this block.
        """
        if 'conv1_fn' in self.__dict__:
            self.conv1_fn.reset(self.conv1)
        if 'conv2_fn' in self.__dict__:
            self.conv2_fn.reset(self.conv2)
        if 'shortcut_fn' in self.__dict__:
            self.shortcut_fn.reset(self.shortcut[0])

    def forward(self, x, injection=None):
        """Apply the block to ``x``; ``injection`` is added after conv1 (0 if None)."""
        if injection is None:
            injection = 0
        out = self.conv1(x) + injection
        out = self.gn1(F.relu(out))
        out = self.conv2(out)
        out += self.shortcut(x)
        out = self.gn3(F.relu(self.gn2(out)))
        return out

class BasicResBlock(nn.Module):
    """Plain residual block (no normalization layers) with an additive injection.

    The forward pass computes::

        out = relu(conv1(x) + injection)
        out = relu(conv2(out) + shortcut(x))

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels (before ``expansion``).
        stride: Stride of the first 3x3 convolution and of the 1x1 shortcut.
        wnorm: If True, register weight normalization on the convolutions.
        norm_type: Accepted so the constructor matches ``BasicBlock``; this
            block has no normalization layers and ignores the value.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, wnorm=False, norm_type='group'):
        super(BasicResBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)

        # Identity shortcut unless the spatial size or channel count changes,
        # in which case a 1x1 convolution projects the input.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False)
            )

        if wnorm:
            self._wnorm()

    def _wnorm(self):
        """Register weight normalization on conv1, conv2 and the shortcut conv.

        ``weight_norm`` is used here as returning a ``(module, handle)`` pair;
        the handle is kept so that ``_reset`` can recompute the weights.
        """
        self.conv1, self.conv1_fn = weight_norm(self.conv1, names=['weight'], dim=0)
        self.conv2, self.conv2_fn = weight_norm(self.conv2, names=['weight'], dim=0)
        if len(self.shortcut) > 0:
            # shortcut[0] is the 1x1 nn.Conv2d itself (it has no `.conv`
            # attribute), so the normalization is applied to it directly.
            self.shortcut[0], self.shortcut_fn = weight_norm(self.shortcut[0], names=['weight'], dim=0)

    def _reset(self, bsz, d, H, W):
        """Recompute the weight-normalized weights, if weight norm is registered.

        The shape arguments are accepted for interface compatibility and are
        not used by this block.
        """
        if 'conv1_fn' in self.__dict__:
            self.conv1_fn.reset(self.conv1)
        if 'conv2_fn' in self.__dict__:
            self.conv2_fn.reset(self.conv2)
        if 'shortcut_fn' in self.__dict__:
            self.shortcut_fn.reset(self.shortcut[0])

    def forward(self, x, injection=None):
        """Apply the block to ``x``; ``injection`` is added after conv1 (0 if None)."""
        if injection is None:
            injection = 0
        out = self.conv1(x) + injection
        out = F.relu(out)
        out = self.conv2(out)
        out += self.shortcut(x)
        out = F.relu(out)
        return out

# Maps the block name read from the config (cfg.deq.extra.block) to its class.
blocks_dict = { 'BASIC': BasicBlock, 'BASICRES': BasicResBlock}

class DEQModule(nn.Module):
    """Sequence of residual blocks that together form one DEQ layer function.

    Args:
        block: Block class to instantiate; called as
            ``block(in_planes, planes, stride, norm_type=...)`` and expected
            to expose an ``expansion`` class attribute.
        num_blocks: Number of blocks applied in ``forward``.
        width: Channel width of every block.
        norm_type: Normalization type forwarded to each block.
    """

    def __init__(self, block, num_blocks, width, norm_type='group'):
        super().__init__()
        self.in_planes = int(width)
        self.num_blocks = num_blocks
        self.norm_type = norm_type
        self.deq_layer = self._make_layer(block, width, num_blocks, stride=1)

    def _wnorm(self):
        """Register weight normalization on every block of the layer."""
        for layer in self.deq_layer:
            layer._wnorm()

        # Release cached GPU memory left over from re-parameterization.
        torch.cuda.empty_cache()

    def _reset(self, xs):
        """Call ``_reset`` on every block, passing the dimensions of ``xs``."""
        for layer in self.deq_layer:
            layer._reset(*xs.shape)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Build the list of blocks; only the first one uses ``stride``.

        ``self.in_planes`` is updated after each block so that the next block
        receives the previous block's output channel count.
        """
        stride_plan = [stride]
        stride_plan.extend([1] * (num_blocks - 1))

        built = []
        for cur_stride in stride_plan:
            new_block = block(self.in_planes, planes, cur_stride, norm_type=self.norm_type)
            built.append(new_block)
            self.in_planes = planes * block.expansion
        return nn.ModuleList(built)

    def forward(self, x, injection):
        """Run ``x`` through the first ``num_blocks`` blocks, feeding ``injection`` to each."""
        out = x
        for idx in range(self.num_blocks):
            layer = self.deq_layer[idx]
            out = layer(out, injection)
        return out

class DEQMazeNet(nn.Module):
    """Deep-equilibrium network: input projection, DEQ core and a conv head.

    The input is projected to ``width`` channels, the DEQ core is either
    unrolled for ``num_layers`` steps (pre-training / plotting) or solved to
    a fixed point with ``f_solver``, and a three-layer convolutional head maps
    the resulting state to 2 output channels.

    Args:
        width: Channel width of the projection and of the DEQ blocks.
        config: Configuration object; ``config.problem.deq`` and
            ``config.problem.train`` are read by ``parse_cfg``.
        **kwargs: Stored on ``self.kwargs``; not otherwise used here.
    """

    def __init__(self, width, config, **kwargs):
        super(DEQMazeNet, self).__init__()
        self.kwargs = kwargs
        self.config = config
        self.parse_cfg(config)
        self.hook = None
        self.width = width

        self.proj_conv = nn.Conv2d(self.in_channels, self.width, kernel_size=3,
                               stride=1, padding=1, bias=False)
        conv2 = nn.Conv2d(self.width, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        conv3 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=False)
        conv4 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=False)

        self.head = nn.Sequential(  conv2, nn.ReLU(),
                                    conv3, nn.ReLU(),
                                    conv4)
        # Pass the configured normalization through; without it the blocks
        # would always be built with the 'group' default regardless of config.
        self.deq = DEQModule(blocks_dict[self.block_type], self.num_blocks, self.width,
                             norm_type=self.norm_type)

        # Running totals of forward-solver iterations and solver calls.
        self.avg_iters = 0
        self.total_count = 0

        if self.wnorm:
            self.deq._wnorm()

    def parse_cfg(self, config):
        """Copy the settings this model needs from ``config.problem`` onto ``self``.

        Also overwrites the module-level ``NUM_GROUPS`` with
        ``cfg.deq.num_groups`` so that blocks built afterwards use it.
        """
        cfg = config.problem
        # DEQ related
        # SECURITY NOTE: eval() executes the configured string as Python code
        # (expected to name a solver such as `anderson` or `broyden`, or to be
        # "None" for b_solver). Only use trusted configuration files.
        self.f_solver = eval(cfg.deq.f_solver)
        self.b_solver = eval(cfg.deq.b_solver)
        if self.b_solver is None:
            self.b_solver = self.f_solver
        self.f_thres = cfg.deq.f_thres
        self.b_thres = cfg.deq.b_thres
        self.stop_mode = cfg.deq.stop_mode

        self.in_channels = cfg.deq.in_channels
        # Model related
        self.num_layers = cfg.deq.num_layers
        self.num_blocks = cfg.deq.num_blocks
        self.block_type = cfg.deq.extra.block

        # Training related config
        self.pretrain_steps = cfg.train.pretrain_steps
        self.wnorm = cfg.deq.wnorm

        global NUM_GROUPS
        NUM_GROUPS = cfg.deq.num_groups

        print(f"Using {cfg.deq.norm} normalization")
        self.norm_type = cfg.deq.norm

        self.fp_init = cfg.deq.fp_init

    @staticmethod
    def _log_activation_hist(z, lnum, logger):
        """Draw a histogram of the flattened activations ``z`` and log it with summary stats."""
        cur_z1 = z.clone().detach().cpu().numpy().reshape(-1)
        plt.clf()
        plt.hist(cur_z1, bins=100)
        # Without a logger there is nowhere to send the figure; skip instead
        # of failing on a call to None.
        if logger is not None:
            logger({f"Activation {lnum}": wandb.Image(plt), "mean": cur_z1.mean(), "std": cur_z1.std(), "min": cur_z1.min(), "max": cur_z1.max()})

    def forward(self, x_init, train_step=-1, return_interm_vals=False, plot=False, logger=None):
        """Run the network on ``x_init``.

        Args:
            x_init: Input tensor fed to the projection convolution.
            train_step: Current training step. DEQ mode is used when it is
                negative or at least ``pretrain_steps``; otherwise the core is
                unrolled ``num_layers`` times.
            return_interm_vals: If True, the head is also applied to the
                solver's intermediate iterates and the stacked result is
                stored in ``self.interm_thoughts``. Only DEQ mode produces
                intermediate iterates; otherwise ``self.interm_thoughts``
                stays an empty list.
            plot: If True, forces unrolled mode and logs an activation
                histogram on every even unrolled step.
            logger: Callable receiving a dict of values to log when plotting.

        Returns:
            ``(thought, new_z1)`` when the module is in training mode, else
            ``thought`` only, where ``thought`` is the head output and
            ``new_z1`` the final core state.

        Raises:
            ValueError: If ``self.fp_init`` is not a known initialization.
        """
        x = F.relu(self.proj_conv(x_init))

        deq_mode = (train_step < 0) or (train_step >= self.pretrain_steps)
        func = lambda z: self.deq(z, x)

        if self.fp_init == 'zeros':
            z1 = torch.zeros_like(x, device=x_init.device)
        elif self.fp_init == 'x_proj':
            z1 = x.clone()
        elif self.fp_init == 'x_init':
            # NOTE(review): this state has width + in_channels channels while
            # the DEQ blocks are built for `width` channels -- confirm that
            # this initialization is actually usable.
            z1 = torch.cat([x, x_init], 1)
        else:
            raise ValueError(f"Unknown initialization {self.fp_init}")

        # For weight normalization re-computations
        self.deq._reset(x)

        if plot:
            deq_mode = False

        # Defined up front so that requesting intermediate values in unrolled
        # mode does not hit an undefined name further down.
        interm_vals = []

        if not deq_mode:
            print(f"Train step {train_step}! Not in DEQ mode!")
            for lnum in range(self.num_layers):
                z1 = func(z1)
                if plot and lnum % 2 == 0:
                    self._log_activation_hist(z1, lnum, logger)

            new_z1 = z1
        else:
            layer_idx = []
            if return_interm_vals:
                layer_idx = np.arange(0, self.f_thres, 1)

            # The fixed point is found without building an autograd graph.
            with torch.no_grad():
                result = self.f_solver(func, z1, threshold=self.f_thres, stop_mode=self.stop_mode, layer_idx=layer_idx, name="forward")
                z1 = result['result']
                self.avg_iters += result["nstep"]
                self.total_count += 1
                if return_interm_vals:
                    interm_vals = result["interm_vals"]
                if train_step % 100 == 0 or train_step == -1:
                    print(f"[For] {train_step} {result['nstep']} {min(result['abs_trace'])} {min(result['rel_trace'])}")

            new_z1 = z1

            if self.training:
                # One extra differentiable application of the layer; the hook
                # below replaces its incoming gradient with the result of the
                # backward solver.
                new_z1 = func(z1.requires_grad_())
                def backward_hook(grad):
                    if self.hook is not None:
                        # Remove the hook first so the autograd.grad calls
                        # below do not trigger it again.
                        self.hook.remove()
                        # synchronize() fails on hosts without CUDA.
                        if torch.cuda.is_available():
                            torch.cuda.synchronize()
                    result = self.b_solver(lambda y: autograd.grad(new_z1, z1, y, retain_graph=True)[0] + grad, torch.zeros_like(grad),
                                          threshold=self.b_thres, stop_mode=self.stop_mode, name="backward")
                    if train_step % 100 == 0:
                        print(f"[Back] {train_step} {result['nstep']} {min(result['abs_trace'])} {min(result['rel_trace'])}")
                    return result['result']

                self.hook = new_z1.register_hook(backward_hook)

        thought = self.head(new_z1)

        self.interm_thoughts = []
        if return_interm_vals:
            thoughts = [self.head(val) for val in interm_vals]
            # torch.stack rejects an empty list; keep the empty default then.
            if thoughts:
                self.interm_thoughts = torch.stack(thoughts)

        if self.training:
            return thought, new_z1
        return thought

def deq_net(width, config, **kwargs):
    """Factory returning a ``DEQMazeNet`` built from ``width`` and ``config``.

    Extra keyword arguments (e.g. a depth) are accepted only so the signature
    matches the factories of other models; they are not passed to the network,
    since depth has no meaning for a DEQ.
    """
    return DEQMazeNet(width, config)