import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import numpy as np
import sys
import matplotlib.pyplot as plt
import wandb
sys.path.append("lib/")
sys.path.append("../")
from deepthinking.lib.solvers import anderson, broyden
from deepthinking.lib.optimizations import weight_norm, VariationalHidDropout2d
from .pixel_norm import PixelNormalization

# NUM_GROUPS: group count passed to nn.GroupNorm when norm_type == 'group'.
#   DEQMazeNet.parse_cfg overwrites it from the config before any block is built.
# BLOCK_GN_AFFINE: whether those GroupNorm layers learn a per-channel affine transform.
NUM_GROUPS, BLOCK_GN_AFFINE = 4, True

class BasicBlock(nn.Module):
    """Basic 1D residual block class.

    Computes ``gn3(relu(gn2(conv2(relu(gn1(conv1(x)))) + injection + shortcut(x))))``.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels (times ``expansion``).
        stride: Stride of ``conv1`` and of the projection shortcut.
        wnorm: If True, register weight normalization on the convolutions.
        norm_type: One of 'group', 'instance', 'batch', 'pixel', 'none'.

    Raises:
        ValueError: If ``norm_type`` is not recognised.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, wnorm=False, norm_type='group'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv1d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)

        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)

        # Identity shortcut unless stride or channel count changes; then a 1x1 projection.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False)
            )

        print(f"Using NUMGROUP {NUM_GROUPS} norm_type {norm_type}")
        self.gn1 = self._make_norm(norm_type, planes)
        self.gn2 = self._make_norm(norm_type, planes)
        self.gn3 = self._make_norm(norm_type, planes)

        if wnorm: self._wnorm()

    @staticmethod
    def _make_norm(norm_type, planes):
        """Return a fresh normalization layer for the block's 1D activations.

        Raises:
            ValueError: If ``norm_type`` is not recognised.
        """
        # NUM_GROUPS is looked up at call time because DEQMazeNet.parse_cfg
        # rebinds the module-level value before blocks are constructed.
        if norm_type == 'group':
            return nn.GroupNorm(NUM_GROUPS, planes, affine=BLOCK_GN_AFFINE)
        # The convolutions are Conv1d, so per-channel norms must be the 1D
        # variants: BatchNorm2d rejects 3D input, and InstanceNorm2d would treat
        # a 3D tensor as one unbatched (C, H, W) sample and normalize the wrong axes.
        if norm_type == 'instance':
            return nn.InstanceNorm1d(planes)
        if norm_type == 'batch':
            return nn.BatchNorm1d(planes)
        if norm_type == 'pixel':
            return PixelNormalization()
        if norm_type == 'none':
            return nn.Sequential()
        raise ValueError(f"Unknown normalization type {norm_type}")

    def _wnorm(self):
        """
        Register weight normalization on conv1, conv2 and the shortcut conv (if any).
        """
        self.conv1, self.conv1_fn = weight_norm(self.conv1, names=['weight'], dim=0)
        self.conv2, self.conv2_fn = weight_norm(self.conv2, names=['weight'], dim=0)
        if len(self.shortcut) > 0:
            # shortcut[0] *is* the nn.Conv1d; it has no `.conv` attribute.
            self.shortcut[0], self.shortcut_fn = weight_norm(self.shortcut[0], names=['weight'], dim=0)

    def _reset(self, bsz, d, H):
        """
        Recompute weight-normalized weights; a no-op if _wnorm was never called.

        The shape arguments (batch size, channels, length) are accepted for
        interface compatibility with DEQModule._reset and are not used here.
        """
        if 'conv1_fn' in self.__dict__:
            self.conv1_fn.reset(self.conv1)
        if 'conv2_fn' in self.__dict__:
            self.conv2_fn.reset(self.conv2)
        if 'shortcut_fn' in self.__dict__:
            self.shortcut_fn.reset(self.shortcut[0])

    def forward(self, x, injection=None):
        """Apply the block; ``injection`` (default 0) is added after conv2."""
        if injection is None: injection = 0
        out = self.conv1(x)
        out = F.relu(self.gn1(out))
        out = self.conv2(out) + injection
        out += self.shortcut(x)
        out = self.gn3(F.relu(self.gn2(out)))
        return out

class BasicBlockV2(nn.Module):
    """Basic 1D residual block class (variant 2).

    Computes ``gn3(relu(gn2(conv2(relu(gn1(conv1(x)))) + injection) + shortcut(x)))``,
    i.e. the second norm is applied before the shortcut is added.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels (times ``expansion``).
        stride: Stride of ``conv1`` and of the projection shortcut.
        wnorm: If True, register weight normalization on the convolutions.
        norm_type: One of 'group', 'instance', 'batch', 'pixel', 'none'.

    Raises:
        ValueError: If ``norm_type`` is not recognised.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, wnorm=False, norm_type='group'):
        super(BasicBlockV2, self).__init__()
        self.conv1 = nn.Conv1d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)

        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)

        # Identity shortcut unless stride or channel count changes; then a 1x1 projection.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False)
            )
        print(f"Using NUMGROUP {NUM_GROUPS} norm_type {norm_type}")
        self.gn1 = self._make_norm(norm_type, planes)
        self.gn2 = self._make_norm(norm_type, planes)
        self.gn3 = self._make_norm(norm_type, planes)

        if wnorm: self._wnorm()

    @staticmethod
    def _make_norm(norm_type, planes):
        """Return a fresh normalization layer for the block's 1D activations.

        Raises:
            ValueError: If ``norm_type`` is not recognised.
        """
        # NUM_GROUPS is looked up at call time because DEQMazeNet.parse_cfg
        # rebinds the module-level value before blocks are constructed.
        if norm_type == 'group':
            return nn.GroupNorm(NUM_GROUPS, planes, affine=BLOCK_GN_AFFINE)
        # The convolutions are Conv1d, so per-channel norms must be the 1D
        # variants: BatchNorm2d rejects 3D input, and InstanceNorm2d would treat
        # a 3D tensor as one unbatched (C, H, W) sample and normalize the wrong axes.
        if norm_type == 'instance':
            return nn.InstanceNorm1d(planes)
        if norm_type == 'batch':
            return nn.BatchNorm1d(planes)
        if norm_type == 'pixel':
            return PixelNormalization()
        if norm_type == 'none':
            return nn.Sequential()
        raise ValueError(f"Unknown normalization type {norm_type}")

    def _wnorm(self):
        """
        Register weight normalization on conv1, conv2 and the shortcut conv (if any).
        """
        self.conv1, self.conv1_fn = weight_norm(self.conv1, names=['weight'], dim=0)
        self.conv2, self.conv2_fn = weight_norm(self.conv2, names=['weight'], dim=0)
        if len(self.shortcut) > 0:
            # shortcut[0] *is* the nn.Conv1d; it has no `.conv` attribute.
            self.shortcut[0], self.shortcut_fn = weight_norm(self.shortcut[0], names=['weight'], dim=0)

    def _reset(self, bsz, d, H):
        """
        Recompute weight-normalized weights; a no-op if _wnorm was never called.

        The shape arguments (batch size, channels, length) are accepted for
        interface compatibility with DEQModule._reset and are not used here.
        """
        if 'conv1_fn' in self.__dict__:
            self.conv1_fn.reset(self.conv1)
        if 'conv2_fn' in self.__dict__:
            self.conv2_fn.reset(self.conv2)
        if 'shortcut_fn' in self.__dict__:
            self.shortcut_fn.reset(self.shortcut[0])

    def forward(self, x, injection=None):
        """Apply the block; ``injection`` (default 0) is added after conv2, before gn2."""
        if injection is None: injection = 0
        out = self.conv1(x)
        out = F.relu(self.gn1(out))
        out = self.gn2(self.conv2(out) + injection)
        out += self.shortcut(x)
        out = self.gn3(F.relu(out))
        return out

class BasicBlockV3(nn.Module):
    """Basic 1D residual block class (variant 3, two norms).

    Computes ``relu(gn2(conv2(relu(gn1(conv1(x)))) + injection + shortcut(x)))``.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels (times ``expansion``).
        stride: Stride of ``conv1`` and of the projection shortcut.
        wnorm: If True, register weight normalization on the convolutions.
        norm_type: One of 'group', 'instance', 'batch', 'pixel', 'none'.

    Raises:
        ValueError: If ``norm_type`` is not recognised.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, wnorm=False, norm_type='group'):
        super(BasicBlockV3, self).__init__()
        self.conv1 = nn.Conv1d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)

        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)

        # Identity shortcut unless stride or channel count changes; then a 1x1 projection.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False)
            )
        print(f"Using NUMGROUP {NUM_GROUPS} norm_type {norm_type}")
        self.gn1 = self._make_norm(norm_type, planes)
        self.gn2 = self._make_norm(norm_type, planes)

        if wnorm: self._wnorm()

    @staticmethod
    def _make_norm(norm_type, planes):
        """Return a fresh normalization layer for the block's 1D activations.

        Raises:
            ValueError: If ``norm_type`` is not recognised.
        """
        # NUM_GROUPS is looked up at call time because DEQMazeNet.parse_cfg
        # rebinds the module-level value before blocks are constructed.
        if norm_type == 'group':
            return nn.GroupNorm(NUM_GROUPS, planes, affine=BLOCK_GN_AFFINE)
        # The convolutions are Conv1d, so per-channel norms must be the 1D
        # variants: BatchNorm2d rejects 3D input, and InstanceNorm2d would treat
        # a 3D tensor as one unbatched (C, H, W) sample and normalize the wrong axes.
        if norm_type == 'instance':
            return nn.InstanceNorm1d(planes)
        if norm_type == 'batch':
            return nn.BatchNorm1d(planes)
        if norm_type == 'pixel':
            return PixelNormalization()
        if norm_type == 'none':
            return nn.Sequential()
        raise ValueError(f"Unknown normalization type {norm_type}")

    def _wnorm(self):
        """
        Register weight normalization on conv1, conv2 and the shortcut conv (if any).
        """
        self.conv1, self.conv1_fn = weight_norm(self.conv1, names=['weight'], dim=0)
        self.conv2, self.conv2_fn = weight_norm(self.conv2, names=['weight'], dim=0)
        if len(self.shortcut) > 0:
            # shortcut[0] *is* the nn.Conv1d; it has no `.conv` attribute.
            self.shortcut[0], self.shortcut_fn = weight_norm(self.shortcut[0], names=['weight'], dim=0)

    def _reset(self, bsz, d, H):
        """
        Recompute weight-normalized weights; a no-op if _wnorm was never called.

        The shape arguments (batch size, channels, length) are accepted for
        interface compatibility with DEQModule._reset and are not used here.
        """
        if 'conv1_fn' in self.__dict__:
            self.conv1_fn.reset(self.conv1)
        if 'conv2_fn' in self.__dict__:
            self.conv2_fn.reset(self.conv2)
        if 'shortcut_fn' in self.__dict__:
            self.shortcut_fn.reset(self.shortcut[0])

    def forward(self, x, injection=None):
        """Apply the block; ``injection`` (default 0) is added after conv2."""
        if injection is None: injection = 0
        out = self.conv1(x)
        out = F.relu(self.gn1(out))
        out = self.conv2(out) + injection
        out += self.shortcut(x)
        out = F.relu(self.gn2(out))
        return out

# Registry mapping the config's `deq.extra.block` name to a residual block class.
blocks_dict = dict(BASIC=BasicBlock, BASIC2=BasicBlockV2, BASIC3=BasicBlockV3)

class DEQModule(nn.Module):
    """Stack of residual blocks forming the DEQ transformation ``f(z; injection)``.

    Args:
        block: Block class (e.g. a value of ``blocks_dict``). It must expose an
            ``expansion`` attribute and accept ``(in_planes, planes, stride, norm_type=...)``.
        num_blocks: Number of blocks applied in sequence (0 gives the identity).
        width: Channel count of the state ``z``.
        norm_type: Normalization type forwarded to every block.
    """
    def __init__(self, block, num_blocks, width, norm_type='group'):
        super(DEQModule, self).__init__()
        self.in_planes = int(width)
        self.num_blocks = num_blocks
        self.norm_type = norm_type
        self.recur_layer = self._make_layer(block, width, num_blocks, stride=1)

    def _wnorm(self):
        """
        Apply weight normalization to the learnable parameters of every block.
        """
        for block in self.recur_layer:
            block._wnorm()

        # Release unused cached GPU memory.
        torch.cuda.empty_cache()

    def _reset(self, xs):
        """
        Ask every block to recompute its weight-normalized weights.

        Args:
            xs: Tensor whose shape is unpacked into each block's ``_reset``.
        """
        for block in self.recur_layer:
            block._reset(*xs.shape)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """
        Build ``num_blocks`` blocks of type ``block``; only the first uses ``stride``.

        Updates ``self.in_planes`` so consecutive blocks chain their channel
        counts. Returns an empty ModuleList when ``num_blocks`` <= 0 (previously
        one unused block was built in that case).
        """
        layers = []
        for idx in range(num_blocks):
            strd = stride if idx == 0 else 1
            layers.append(block(self.in_planes, planes, strd, norm_type=self.norm_type))
            self.in_planes = planes * block.expansion
        return nn.ModuleList(layers)

    def forward(self, x, injection):
        """Apply the blocks in order, feeding the same ``injection`` to each."""
        for layer in self.recur_layer:
            x = layer(x, injection)
        return x

class DEQMazeNet(nn.Module):
    """1D deep-equilibrium (DEQ) network.

    Pipeline: ``proj_conv`` (2 -> ``width`` channels) + ReLU gives ``x``; the
    state ``z`` is driven to a fixed point of ``z = deq(z, x)``; a 3-conv head
    maps the ``width``-channel state back to 2 channels.

    Training steps in ``[0, pretrain_steps)`` unroll ``deq`` explicitly
    ``num_layers`` times. Any other ``train_step`` (including the default -1)
    solves for the fixed point with ``f_solver`` and, in training mode,
    replaces the gradient through a backward hook that runs ``b_solver``.

    Args:
        width: Channel count of the hidden (equilibrium) state.
        config: Config object; see ``parse_cfg`` for the fields read.
        **kwargs: Stored on ``self.kwargs``; otherwise unused.
    """

    def __init__(self, width, config, **kwargs):
        super(DEQMazeNet, self).__init__()
        self.kwargs = kwargs
        self.config = config
        self.parse_cfg(config)
        self.hook = None  # handle of the backward hook registered in forward()
        self.width = width

        self.proj_conv = nn.Conv1d(2, self.width, kernel_size=3,
                                   stride=1, padding=1, bias=False)
        conv2 = nn.Conv1d(self.width, self.width, kernel_size=3,
                          stride=1, padding=1, bias=False)
        conv3 = nn.Conv1d(self.width, int(self.width/2), kernel_size=3,
                          stride=1, padding=1, bias=False)
        conv4 = nn.Conv1d(int(self.width/2), 2, kernel_size=3,
                          stride=1, padding=1, bias=False)

        self.head = nn.Sequential(conv2, nn.ReLU(),
                                  conv3, nn.ReLU(),
                                  conv4)
        self.deq = DEQModule(blocks_dict[self.block_type], self.num_blocks, self.width, self.norm_type)

        # Cumulative forward-solver step count and number of solver calls.
        self.avg_iters = 0
        self.total_count = 0

        if self.wnorm:
            self.deq._wnorm()

    def parse_cfg(self, config):
        """Copy DEQ, model and training settings from ``config.problem``.

        Also rebinds the module-level ``NUM_GROUPS`` used by the blocks'
        GroupNorm layers, which is why ``__init__`` calls this before building
        the DEQ module.
        """
        cfg = config.problem
        # DEQ related
        # NOTE(review): eval() executes arbitrary code from the config string
        # (presumably "anderson", "broyden" or "None"); only load trusted configs.
        self.f_solver = eval(cfg.deq.f_solver)
        self.b_solver = eval(cfg.deq.b_solver)
        if self.b_solver is None:
            self.b_solver = self.f_solver

        self.f_thres = cfg.deq.f_thres
        self.b_thres = cfg.deq.b_thres
        self.stop_mode = cfg.deq.stop_mode

        # Model related
        self.num_layers = cfg.deq.num_layers
        self.num_blocks = cfg.deq.num_blocks
        self.block_type = cfg.deq.extra.block
        self.in_channels = cfg.deq.in_channels

        self.anderson_lam = cfg.deq.solver.lam
        self.anderson_m = cfg.deq.solver.m

        # Training related config
        self.pretrain_steps = cfg.train.pretrain_steps
        self.wnorm = cfg.deq.wnorm

        global NUM_GROUPS
        NUM_GROUPS = cfg.deq.num_groups

        print(f"Using {cfg.deq.norm} normalization")
        self.norm_type = cfg.deq.norm

        self.fp_init = cfg.deq.fp_init

    def forward(self, x_init, train_step=-1, return_interm_vals=False, plot=False, logger=None):
        """Run the network on ``x_init``.

        Args:
            x_init: Input fed to ``proj_conv`` (which expects 2 input channels).
            train_step: Current training step; -1 (default) selects solver mode.
                Steps in ``[0, pretrain_steps)`` use explicit unrolling.
            return_interm_vals: If True, ask the forward solver for intermediate
                iterates (every 5th step) and decode each through the head.
            plot: If True, return ``(thought, interm_thoughts, abs_trace, rel_trace)``.
            logger: Currently unused; kept for interface compatibility.

        Returns:
            The 4-tuple above if ``plot``; otherwise ``(thought, new_z1)`` in
            training mode and ``thought`` in eval mode. In unrolled mode the
            traces are None and ``interm_thoughts`` is an empty list.

        Raises:
            ValueError: If ``fp_init`` is not recognised.
        """
        x = F.relu(self.proj_conv(x_init))

        deq_mode = (train_step < 0) or (train_step >= self.pretrain_steps)
        func = lambda z: self.deq(z, x)

        if self.fp_init == 'zeros':
            z1 = torch.zeros_like(x, device=x_init.device)
        elif self.fp_init == 'x_proj':
            z1 = x.clone()
        elif self.fp_init == 'x_init':
            # NOTE(review): this state has width + C_in channels, while the DEQ
            # blocks are built for `width` input channels -- confirm intended.
            z1 = torch.cat([x, x_init], 1)
        else:
            raise ValueError(f"Unknown initialization {self.fp_init}")

        # For weight normalization re-computations
        if self.wnorm:
            self.deq._reset(x)

        # Defaults so the return paths below are also valid in unrolled mode
        # (previously `plot` / `return_interm_vals` raised NameError there).
        interm_vals = []
        abs_trace = rel_trace = None

        if not deq_mode:
            print(f"Train step {train_step}! Not in DEQ mode!")
            for _ in range(self.num_layers):
                z1 = func(z1)
            new_z1 = z1
        else:
            layer_idx = []
            if return_interm_vals:
                layer_idx = np.arange(0, self.f_thres, 5)

            with torch.no_grad():
                # NOTE(review): unlike the backward solve, lam/m are not passed
                # here -- confirm this asymmetry is intended.
                result = self.f_solver(func, z1, threshold=self.f_thres, stop_mode=self.stop_mode, layer_idx=layer_idx, name="forward")
                z1 = result['result']

                abs_trace = result['abs_trace']
                rel_trace = result['rel_trace']

                if return_interm_vals:
                    interm_vals = result["interm_vals"]

                # -1 % 500 == 499 in Python, so evaluation (-1) is matched explicitly.
                if train_step % 500 == 0 or train_step == -1:
                    print(f"[For] {train_step} {result['nstep']} {min(result['abs_trace'])} {min(result['rel_trace'])}")

            self.avg_iters += result["nstep"]
            self.total_count += 1

            new_z1 = z1

            if self.training:
                # One differentiable application of f at the fixed point; the
                # hook replaces its incoming gradient with the b_solver result.
                new_z1 = func(z1.requires_grad_())

                def backward_hook(grad):
                    if self.hook is not None:
                        self.hook.remove()
                        # synchronize() raises on CPU-only builds; guard it.
                        if torch.cuda.is_available():
                            torch.cuda.synchronize()
                    result = self.b_solver(lambda y: autograd.grad(new_z1, z1, y, retain_graph=True)[0] + grad, torch.zeros_like(grad),
                                           lam=self.anderson_lam, m=self.anderson_m, threshold=self.b_thres, stop_mode=self.stop_mode, name="backward")
                    if train_step % 500 == 0:
                        print(f"[Back] {train_step} {result['nstep']} {min(result['abs_trace'])} {min(result['rel_trace'])}")
                    return result['result']

                self.hook = new_z1.register_hook(backward_hook)

        thought = self.head(new_z1)

        interm_thoughts = []
        # torch.stack rejects an empty sequence, so only stack when there is data.
        if return_interm_vals and len(interm_vals) > 0:
            interm_thoughts = torch.stack([self.head(val) for val in interm_vals])

        if plot:
            return thought, interm_thoughts, abs_trace, rel_trace

        if self.training:
            return thought, new_z1

        return thought

def deq_net_1d(width, config, **kwargs):
    """Factory that builds a ``DEQMazeNet`` with the given width and config.

    Extra keyword arguments (such as a depth setting shared with other model
    factories) are accepted only for backward compatibility and are ignored,
    since depth has no meaning for a DEQ.
    """
    model = DEQMazeNet(width, config)
    return model
