import torch
import torch.nn as nn
from utils import *
from models.rnode import quadratic_cost, jacobian_frobenius_regularization_fn, RegularizedODEfunc, divergence_approx


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved; larger strides downsample.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) ``nn.Conv2d`` with no padding."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


def norm(dim, device='cuda'):
    """Return a ``nn.GroupNorm`` over ``dim`` channels.

    Uses ``min(32, dim)`` groups when that divides ``dim`` (the historical behavior).
    Otherwise it falls back to the largest divisor of ``dim`` that is <= 32. GroupNorm
    requires ``dim % num_groups == 0``, so e.g. ``dim=48`` previously raised.

    Args:
        dim: number of channels (must be >= 1).
        device: device the GroupNorm parameters are created on (default ``'cuda'``).

    Raises:
        ValueError: if ``dim < 1``.
    """
    if dim < 1:
        raise ValueError(f'norm() needs at least one channel, got dim={dim}')
    groups = min(32, dim)
    # Step down to the largest valid group count; terminates at 1 at the latest.
    while dim % groups:
        groups -= 1
    return nn.GroupNorm(groups, dim, device=device)


class ResBlock(nn.Module):
    """Pre-activation residual block: ``x + conv2(relu(norm2(conv1(relu(norm1(x))))))``.

    When ``downsample`` is given, the shortcut is ``downsample(relu(norm1(x)))``, i.e. the
    projection sees the pre-activated input rather than raw ``x``.

    NOTE(review): ``norm`` is called with its default device (``'cuda'``), while the convs are
    created on the default device; move the block with ``.to(...)`` before use.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Attribute names / registration order are kept identical for state_dict compatibility.
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=False)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        pre = self.relu(self.norm1(x))
        shortcut = x if self.downsample is None else self.downsample(pre)
        residual = self.conv2(self.relu(self.norm2(self.conv1(pre))))
        return residual + shortcut


class RecurrentBlock(nn.Module):
    """Apply one wrapped module ``num_recurrences`` times in a row (shared weights).

    ``num_recurrences <= 0`` returns the input unchanged.
    """

    def __init__(self, module: nn.Module, num_recurrences=1):
        super().__init__()
        self.module = module
        self.num_recurrences = num_recurrences

    def forward(self, x):
        state = x
        for _step in range(self.num_recurrences):
            state = self.module(state)
        return state


class ConcatConv2d(nn.Module):
    """Time-conditioned 2-D convolution.

    ``forward(t, x)`` prepends one constant channel filled with ``t`` to ``x`` along dim 1. The
    result is passed to an ``nn.Conv2d`` (or ``nn.ConvTranspose2d`` when ``transpose=True``)
    with ``dim_in + 1`` input channels.
    """

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super().__init__()
        conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
        conv_kwargs = dict(
            kernel_size=ksize,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # The extra input channel carries the time value.
        self._layer = conv_cls(dim_in + 1, dim_out, **conv_kwargs)

    def forward(self, t, x):
        time_plane = torch.ones_like(x[:, :1, :, :]) * t
        stacked = torch.cat([time_plane, x], 1)
        return self._layer(stacked)


class Flatten(nn.Module):
    """Flatten every axis after the batch axis: ``(B, d1, d2, ...) -> (B, d1*d2*...)``.

    The input must have at least 2 dimensions.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # torch.flatten avoids building a host tensor and calling .item() on every call.
        # Unlike .view it also accepts non-contiguous inputs (e.g. channels_last maps),
        # copying only when required.
        return torch.flatten(x, start_dim=1)

class UnitBlock(nn.Module):
    """Single time-conditioned conv unit: ``conv1(t, act(norm1(x)))``.

    Normalisation is currently disabled (``norm1`` is an identity). ``conv1`` is a 3x3,
    stride-1, padding-1 ``ConcatConv2d`` that keeps ``dim`` channels.

    Args:
        dim: channel count in and out.
        act_func: activation class, instantiated with ``inplace=False``.
    """

    def __init__(self, dim, act_func):
        super().__init__()
        # Normalisation intentionally disabled (previously ``norm(dim)``).
        self.norm1 = nn.Identity()
        self.act_func = act_func(inplace=False)
        self.conv1 = ConcatConv2d(dim, dim, ksize=3, stride=1, padding=1)

    def forward(self, t, x):
        activated = self.act_func(self.norm1(x))
        return self.conv1(t, activated)

class ConvBlock(nn.Module):
    """Channel-preserving ``act(conv3x3(x))`` block (3x3, stride 1, padding 1, with bias).

    No normalisation is applied.
    """

    def __init__(self, dim, act_func):
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1)
        self.act_func = act_func(inplace=False)

    def forward(self, x):
        return self.act_func(self.conv1(x))

class Simulator(nn.Module):
    """MLP mapping ``(z0, z1, t)`` to a ``dim``-sized output.

    The inputs are flattened to 2-D ``(batch, features)`` and concatenated along dim 1, so the
    first ``nn.Linear`` expects ``dim + dim + 1`` features. ``z0`` and ``z1`` must each
    flatten to ``dim`` features and ``t`` to a single feature per row.

    Args:
        dim: feature size of ``z0``, ``z1`` and of the output.
        num_layers: number of ``nn.Linear`` layers (1 gives a single linear map); hidden
            layers are ``2 * dim`` wide and followed by SiLU.
        use_norm: insert ``nn.LayerNorm`` after every hidden ``nn.Linear``.
    """

    def __init__(self, dim, num_layers=3, use_norm=True):
        super().__init__()
        layers = []
        if num_layers == 1:
            layers.append(nn.Linear(dim+dim+1, dim))
        else:
            layers.append(nn.Linear(dim+dim+1, dim+dim))
            if use_norm:
                layers.append(nn.LayerNorm(dim+dim))
            layers.append(nn.SiLU(inplace=False))
            for _ in range(num_layers-2):
                layers.append(nn.Linear(dim+dim, dim+dim))
                if use_norm:
                    layers.append(nn.LayerNorm(dim+dim))
                layers.append(nn.SiLU(inplace=False))
            layers.append(nn.Linear(dim+dim, dim))
        self.layers = nn.Sequential(*layers)

    @staticmethod
    def _to_2d(x):
        """Reshape ``x`` to ``(batch, features)``; tensors with < 2 dims become one row."""
        if x.dim() > 2:
            return x.view(x.shape[0], -1)
        if x.dim() < 2:
            return x.view(1, -1)
        return x

    def forward(self, z0, z1, t):
        z0 = self._to_2d(z0)
        z1 = self._to_2d(z1)
        batch = z0.shape[0]
        if t.dim() == 1 and t.shape[0] == batch:
            # One time value per sample: (B,) -> (B, 1), not a single (1, B) row.
            t = t.view(batch, 1)
        else:
            t = self._to_2d(t)
        if t.shape[0] == 1 and batch != 1:
            # A single shared time (scalar / one row) is broadcast over the batch.
            t = t.expand(batch, -1)
        z = torch.cat([z0, z1, t], 1)
        return self.layers(z)

class ODEfunc(nn.Module):
    """Time-conditioned convolutional vector field ``f(t, x)`` used by the ODE block.

    Forward pipeline: norm1 -> act -> ConcatConv2d(dim -> hidden_dim)
    -> ``add_blocks`` x UnitBlock -> norm2 -> act -> [dropout] -> ConcatConv2d(hidden_dim -> dim)
    -> [norm3 when ``final_norm``].

    Args:
        dim: channel count of the state ``x`` (input and output).
        final_norm: if True, the output goes through ``norm3 = norm(dim)`` and is returned directly.
        act_func: 'relu' or 'silu'/'swish'; anything else raises ValueError.
        dropout: dropout probability before the last conv; disabled when 0.
        hidden_dim: channels of the intermediate feature map; 0 means ``dim``.
        add_blocks: number of extra UnitBlocks at ``hidden_dim`` channels.
        use_norm: if True, norm1/norm2 are ``norm(...)`` GroupNorms, otherwise identities.

    NOTE(review): ``norm`` is called with its default device ('cuda'), while the convs are
    created on the default device. With ``use_norm``/``final_norm`` the module must be moved
    with ``.to(...)`` before use, and construction needs CUDA to be available.
    """

    def __init__(self, dim, final_norm=True, act_func='relu', dropout=0.0, hidden_dim=0, add_blocks=0, use_norm=False):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim) if use_norm else nn.Identity()
        # act_func is rebound from its string name to the activation class (also passed to UnitBlock below).
        if act_func == 'relu':
            act_func = nn.ReLU
        elif act_func in ['silu', 'swish']:
            act_func = nn.SiLU
        else:
            raise ValueError(f'Unknown activation function: {act_func}')

        self.act_func = act_func(inplace=False)
        if hidden_dim == 0:
            hidden_dim = dim

        self.conv1 = ConcatConv2d(dim, hidden_dim, 3, 1, 1)
        self.norm2 = norm(hidden_dim) if use_norm else nn.Identity()
        self.conv2 = ConcatConv2d(hidden_dim, dim, 3, 1, 1)
        self.final_norm = final_norm
        self.drop_p = dropout
        if self.drop_p > 0:
            self.dropout = nn.Dropout(self.drop_p, inplace=False)
        if final_norm:
            self.norm3 = norm(dim)
        # Number of function evaluations; incremented on every forward call.
        self.nfe = 0
        self.input_dim = dim

        # Registered after conv2, but applied between conv1 and norm2 in forward.
        blocks = [UnitBlock(hidden_dim, act_func=act_func) for _ in range(add_blocks)]
        self.blocks = nn.ModuleList(blocks)
        self.divergence_fn = divergence_approx

    def forward(self, t, x):
        """Evaluate the vector field at time ``t`` and state ``x``.

        Gradients are enabled only in training mode. In eval mode the output carries no
        autograd graph even if the caller has grad enabled. The call sets ``requires_grad``
        on the caller's ``x`` and ``t`` in place, so ``t`` must be a tensor.

        Returns:
            ``self.norm3(out)`` (a bare tensor) when ``final_norm`` is True; otherwise the
            1-tuple ``(out,)``. In training mode without ``final_norm``, it also stores the
            random probe ``self._e`` and ``self.sqjacnorm`` from ``divergence_fn``.

        NOTE(review): the two return types differ (tensor vs 1-tuple). The ``final_norm`` path
        also skips the ``sqjacnorm`` update, so any reader of ``self.sqjacnorm`` sees a stale
        or missing value. Confirm that RegularizedODEfunc handles both cases.
        """
        self.nfe += 1

        with torch.set_grad_enabled(self.training):
            # Presumably so divergence_fn can differentiate out w.r.t. x; mutates caller tensors.
            x.requires_grad_(True)
            t.requires_grad_(True)

            out = self.norm1(x)
            out = self.act_func(out)
            out = self.conv1(t, out)

            for block in self.blocks:
                out = block(t, out)

            out = self.norm2(out)
            out = self.act_func(out)
            if self.drop_p > 0:
                out = self.dropout(out)
            out = self.conv2(t, out)
            if self.final_norm:
                # Early return: bare tensor, no divergence / sqjacnorm computation.
                return self.norm3(out)

            if self.training:
                # Standard-normal probe with x's shape, passed to divergence_fn as a 1-element list.
                self._e = [torch.randn_like(x), ]
                # divergence is computed but unused here; only the squared-Jacobian-norm term is kept.
                divergence, sqjacnorm = self.divergence_fn(out, x, e=self._e)
                self.sqjacnorm = sqjacnorm

        return (out, )



class AppendRepeat(nn.Module):
    '''
    Append one trailing axis per entry of ``rep_dims`` and tile it to that size.
    e.g. rep_dims=(H, W) maps (B, C) -> (B, C, H, W), where every (H, W) plane holds x[b, c].
    The result is a materialised copy (Tensor.repeat), not a broadcast view.
    '''
    def __init__(self, rep_dims):
        super().__init__()
        self.rep_dims = rep_dims

    def forward(self, x):
        n_new = len(self.rep_dims)
        # Indexing with None adds trailing singleton axes (same as unsqueeze(-1) n_new times).
        widened = x[(Ellipsis,) + (None,) * n_new]
        tiling = (1,) * x.ndim + tuple(self.rep_dims)
        return widened.repeat(*tiling)


class PaddingLayer(nn.Module):
    """Append zero channels along dim 1, growing ``input_dim`` channels to ``output_dim``.

    Works for any input rank >= 2, e.g. ``(B, C)``, ``(B, C, L)`` or ``(B, C, H, W)``. The
    zeros share the input's dtype and device.

    Args:
        input_dim: expected channel count of the input.
        output_dim: channel count after padding (``output_dim - input_dim`` zeros are appended).
        mode: stored for configuration compatibility; not used by ``forward``.
    """

    def __init__(self, input_dim, output_dim, mode=0):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.pad = output_dim-input_dim
        self.mode = mode

    def forward(self, x):
        # Same shape as x except for the channel axis, which becomes the pad width.
        pad_shape = list(x.shape)
        pad_shape[1] = self.pad
        zeros = x.new_zeros(pad_shape)
        return torch.cat([x, zeros], dim=1)

class Scale(nn.Module):
    '''
    Multiply the input by a fixed factor given at construction time.
    '''
    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        factor = self.scale
        return x * factor


class OurModel(nn.Module):
    '''
    Neural-ODE image classifier:
    x -> in_projection -> z0 -> odeblock (integrate dynamics) -> features -> out_projection -> logits.
    label_projection maps a label vector y to a latent z1; it is only used by get_traj_cheat.
    '''
    def __init__(self, channel_in=3, emb_res=(7,7), device='cuda', label_proj_strategy='repeat', latent_chan=64,
                 norm_fix=False, final_norm=True, in_proj_scale=None, label_proj_scale=None, proj_norm='none',
                 out_norm=True, t_final=1.0, in_latent_chan=64, f_act='relu', h_act='relu', h_dropout=0.0,
                 h_dim=0, h_add_blocks=0, f_add_blocks=0, adjoint=False, augment_dim=0,
                 g_add_blocks=0, h_norm=False, num_classes=10):
        '''
        params:
        - channel_in: input channel
        - emb_res: spatial resolution at embedding space; used to shape label_projection's output,
          so it should match the spatial size produced by in_projection
        - device: device the whole model is moved to (self.to(device)); also passed to ODEBlock
        - label_proj_strategy: 'repeat', 'reshape' or 'mlp' (how a label vector y -> z1)
        - latent_chan: channels of the ODE state; replaced by channel_in + augment_dim when augment_dim > 0
        - norm_fix: not implemented; must be False
        - final_norm: forwarded to ODEfunc (GroupNorm on the vector-field output)
        - in_proj_scale / label_proj_scale: only checked here (setting either requires the 'repeat'
          strategy). NOTE(review): not otherwise used by this class; confirm whether a caller reads them
        - proj_norm: 'none' or 'bn' (BatchNorm after each in_projection conv except the last);
          'ln' and 'gn' are not implemented
        - out_norm: if True, start out_projection layer with GroupNorm
        - t_final: integration end time, forwarded to ODEBlock
        - in_latent_chan: hidden channels of in_projection
        - f_act: activation of in_projection and of the out_projection ConvBlocks ('relu', 'silu'/'swish')
        - h_act, h_dropout, h_dim, h_add_blocks, h_norm: forwarded to ODEfunc as
          act_func, dropout, hidden_dim, add_blocks, use_norm
        - f_add_blocks: number of extra padded 3x3 conv layers in in_projection
        - g_add_blocks: number of ConvBlocks in out_projection
        - adjoint: forwarded to ODEBlock
        - augment_dim: if > 0, in_projection becomes zero-padding of the raw input
          to channel_in + augment_dim channels
        - num_classes: size of the logit vector and of the label vector y (default 10)

        Raises:
        - NotImplementedError: proj_norm in ('ln', 'gn'), norm_fix=True, or a proj scale
          combined with a non-'repeat' label_proj_strategy
        - ValueError: unknown proj_norm, f_act or label_proj_strategy
        '''
        from .mlp_model import ODEBlock # TODO: refactor this import structure
        super(OurModel, self).__init__()
        self.augment_dim = augment_dim
        if augment_dim > 0:
            # The latent state is the raw input padded with augment_dim zero channels (see below).
            latent_chan = channel_in + augment_dim

        # Explicit raises instead of `assert`: asserts are stripped under `python -O`, which
        # previously led to a NameError or a silently mis-configured model.
        if proj_norm == 'bn':
            add_norm = lambda dim: nn.BatchNorm2d(dim)
        elif proj_norm in ('ln', 'gn'):
            raise NotImplementedError(f'proj_norm={proj_norm!r} is not implemented')
        elif proj_norm == 'none':
            add_norm = lambda dim: nn.Identity()
        else:
            raise ValueError(f"Unknown proj_norm: {proj_norm!r} (expected 'none' or 'bn')")

        if f_act == 'relu':
            f_act_cls = nn.ReLU
        elif f_act in ['silu', 'swish']:
            f_act_cls = nn.SiLU
        else:
            raise ValueError(f'Unknown activation: {f_act}')

        ### in projection: 3x3 conv (no padding) -> f_add_blocks padded 3x3 convs -> two 4x4 stride-2 convs
        in_proj_layer = [
            nn.Conv2d(channel_in, in_latent_chan, 3, 1),
            add_norm(in_latent_chan),
            f_act_cls(inplace=False),
        ]

        for _ in range(f_add_blocks):
            in_proj_layer += [
                nn.Conv2d(in_latent_chan, in_latent_chan, 3, 1, 1),
                add_norm(in_latent_chan),
                f_act_cls(inplace=False),
            ]

        in_proj_layer += [
            nn.Conv2d(in_latent_chan, in_latent_chan, 4, 2, 1),
            add_norm(in_latent_chan),
            f_act_cls(inplace=False),
            nn.Conv2d(in_latent_chan, latent_chan, 4, 2, 1),
        ]
        self.in_projection = nn.Sequential(
            *in_proj_layer
        )
        ### out projection: [GroupNorm] -> g_add_blocks ConvBlocks -> (ReLU if no ConvBlocks) -> GAP -> Linear
        # Pass the requested device so the GroupNorm is not forced onto 'cuda' (norm's default).
        out_proj_layer = [norm(latent_chan, device=device)] if out_norm else []
        out_proj_layer += [
            ConvBlock(latent_chan, act_func=f_act_cls) for _ in range(g_add_blocks)
        ]
        out_proj_layer += [
            nn.ReLU(inplace=False) if g_add_blocks==0 else nn.Identity(),
            nn.AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            nn.Linear(latent_chan, num_classes),
        ]
        self.out_projection = nn.Sequential(
            *out_proj_layer
        )
        ### label projection: y (num_classes) -> z1 (latent_chan, *emb_res)
        if not ((in_proj_scale is None and label_proj_scale is None) or label_proj_strategy == 'repeat'):
            raise NotImplementedError('proj scale only implemented to repeat strategy')
        if label_proj_strategy == 'repeat':
            label_proj_layer = [
                nn.Linear(num_classes, latent_chan),
                AppendRepeat(emb_res),
            ]
            self.label_projection = nn.Sequential(
                *label_proj_layer
            )
        elif label_proj_strategy == 'reshape':
            target_dim = latent_chan * emb_res[0] * emb_res[1]
            self.label_projection = nn.Sequential(
                nn.Linear(num_classes, target_dim),
                nn.Unflatten(1, (latent_chan, emb_res[0], emb_res[1])),
            )
        elif label_proj_strategy == 'mlp':
            self.label_projection = nn.Sequential(
                nn.Linear(num_classes, latent_chan),
                nn.ReLU(inplace=False),
                nn.Linear(latent_chan, latent_chan),
                nn.ReLU(inplace=False),
                nn.Linear(latent_chan, latent_chan),
                AppendRepeat(emb_res),
            )
        else:
            raise ValueError(f'Unknown label_proj_strategy: {label_proj_strategy}')

        ### dynamics model
        if norm_fix:
            raise NotImplementedError('norm_fix=True is not implemented')
        # Norm inside h is controlled by h_norm (off by default).
        odefunc = ODEfunc(latent_chan, final_norm=final_norm, act_func=h_act, dropout=h_dropout, hidden_dim=h_dim,
                          add_blocks=h_add_blocks, use_norm=h_norm)
        reg_odefunc = RegularizedODEfunc(odefunc)

        if augment_dim > 0:
            # The conv in_projection above is discarded. It is still built first so that RNG
            # consumption, and hence seeded initialisation of the later layers, is unchanged.
            self.in_projection = PaddingLayer(channel_in, latent_chan, mode=0)

        self.adjoint = adjoint
        self.odeblock = ODEBlock(device, reg_odefunc, is_conv=True, t_final=t_final, adjoint=adjoint)
        self.to(device)
        self.device = device
        self.t_final = t_final


    def forward(self, x, return_features=False, method='dopri5'):
        '''
        Encode x, integrate the ODE with the given solver `method`, and classify the result.

        Returns (pred, regs) by default, where regs is the second output of self.odeblock.
        With return_features=True it returns (features, pred) instead (regs dropped).
        '''
        x = self.in_projection(x)
        features, regs = self.odeblock(x, method=method)
        pred = self.out_projection(features)
        if return_features:
            return features, pred
        return pred, regs

    def steer(self, x, return_features=False, method='dopri5', b=0):
        raise NotImplementedError

    def get_traj(self, x, timesteps=100+1, method='dopri5'):
        '''
        Return (trajectory, logits of the last trajectory state).

        timestep: int
            note: should do +1 to timesteps since it is both start & end inclusive.
        '''
        z0 = self.in_projection(x)
        out = self.odeblock.trajectory(z0, timesteps, method=method)[0]
        return out, self.out_projection(out[-1].clone())


    def get_traj_cheat(self, x, y, epsilon, timesteps=1+1, method='dopri5'):
        '''
        same as get_traj but starts with z_eps = (1-eps) * z0 + eps * z1, where z1 = label_projection(y),
        and evaluates on times linspace(epsilon, 1.0, timesteps).

        NOTE(review): unlike get_traj, the trajectory result is not indexed with [0], and a time
        tensor (not an int) is passed. If trajectory returns (states, ...) as get_traj implies,
        out[-1] below is not the final state. The end time is also fixed at 1.0 rather than
        self.t_final. Confirm against ODEBlock.trajectory.
        '''
        z0 = self.in_projection(x)
        z1 = self.label_projection(y)
        z_eps = (1-epsilon) * z0 + epsilon * z1
        # Create the evaluation times on the latent's device (the model may live on GPU).
        eval_t = torch.linspace(epsilon, 1., timesteps, device=z_eps.device)
        out = self.odeblock.trajectory(z_eps, eval_t, method=method)
        return out, self.out_projection(out[-1].clone())


    def pred_v(self, z, t):
        '''
        Reset self.odeblock.odefunc.nfe and evaluate that function once at (t, z).
        Note the argument order: the dynamics function is called as odefunc(t, z).
        '''
        self.odeblock.odefunc.nfe = 0
        return self.odeblock.odefunc(t, z)
