import math
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from scipy.linalg import block_diag

class NCBlock(nn.Module):
    """Channel-affinity attention block with a residual connection.

    Channels are split into ``nblocks_channel`` contiguous groups. Within each
    group, every output channel is a softmax-weighted mix of the group's input
    channels. The weights come from the negative squared distance between
    per-channel means of the ``fc_in`` features (the means are detached from
    the graph). The mixed features pass through ``fc_out`` and BatchNorm, are
    added to the input, and go through a ReLU.

    ``fc_in``, ``fc_out`` and ``bn`` treat the flattened spatial axis
    (``h * w``) as their channel dimension. ``forward`` therefore requires
    ``h * w == self.spatial_area`` for its input.
    """

    def __init__(self, in_channels, in_height, in_width,
                    spatial_height=24, spatial_width=24,
                    reduction=8, size_is_consistant=True):
        """
        :param in_channels: number of input channels. Above 512 they are split
            into ``in_channels // 512`` groups.
        :param in_height: input feature-map height.
        :param in_width: input feature-map width.
        :param spatial_height: height used when the input map is not used
            directly (see ``size_is_consistant``).
        :param spatial_width: width used in the same case.
        :param reduction: bottleneck reduction factor of ``fc_in``/``fc_out``.
        :param size_is_consistant: if True and ``in_height * in_width`` is
            below 32*32, the layers are sized to the input map itself.
        """
        super(NCBlock, self).__init__()

        # Channel splits: one group up to 512 channels, else groups of ~512.
        if in_channels <= 512:
            self.nblocks_channel = 1
        else:
            self.nblocks_channel = in_channels // 512
        block_size = in_channels // self.nblocks_channel

        # Block-diagonal (in_channels x in_channels) mask of allowed
        # channel-to-channel interactions. Kept as a plain attribute (not a
        # buffer) so the module's state_dict keys are unchanged.
        self.mask = torch.zeros(in_channels, in_channels)
        for i in range(self.nblocks_channel):
            blk = slice(i * block_size, (i + 1) * block_size)
            self.mask[blk, blk] = 1

        # Spatial size that the layers below are built for.
        if in_height * in_width < 32 * 32 and size_is_consistant:
            self.spatial_area = in_height * in_width
            self.spatial_height = in_height
            self.spatial_width = in_width
        else:
            self.spatial_area = spatial_height * spatial_width
            self.spatial_height = spatial_height
            self.spatial_width = spatial_width

        # Clamp so that tiny maps (spatial_area < reduction) still get a
        # one-channel bottleneck instead of a zero-width layer.
        hidden = max(1, self.spatial_area // reduction)

        self.fc_in = nn.Sequential(
            nn.Conv1d(self.spatial_area, hidden, 1, 1, 0, bias=True),
            nn.ReLU(True),
            nn.Conv1d(hidden, self.spatial_area, 1, 1, 0, bias=True),
        )

        self.fc_out = nn.Sequential(
            nn.Conv1d(self.spatial_area, hidden, 1, 1, 0, bias=True),
            nn.ReLU(True),
            nn.Conv1d(hidden, self.spatial_area, 1, 1, 0, bias=True),
        )

        self.bn = nn.BatchNorm1d(self.spatial_area)

    def forward(self, x):
        '''
        :param x: (bt, c, h, w) tensor with h * w == self.spatial_area
        :return: (bt, c, h, w) tensor, ReLU(x + attention output)
        '''
        bt, c, h, w = x.shape
        residual = x
        groups = self.nblocks_channel
        cg = c // groups  # channels per group

        # Split channels into contiguous groups: (bt * groups) x cg x (h * w).
        x_grouped = x.reshape(bt * groups, cg, h * w)
        x_v = x_grouped.permute(0, 2, 1).contiguous()  # (bt*groups) x (h*w) x cg
        x_v = self.fc_in(x_v)  # (bt*groups) x (h*w) x cg

        # Per-channel mean over the spatial axis; detached, so the affinity
        # scores do not back-propagate into fc_in through this path.
        x_m = x_v.mean(1).view(-1, 1, cg).detach()  # (bt*groups) x 1 x cg
        score = -(x_m - x_m.permute(0, 2, 1)) ** 2  # (bt*groups) x cg x cg

        # The reshape above already limits attention to a single group, so
        # only one diagonal block of the full mask applies. All diagonal
        # blocks are identical. Using the full (c x c) mask here would fail
        # to broadcast whenever groups > 1.
        group_mask = self.mask[:cg, :cg].to(device=score.device).eq(0)
        score = score.masked_fill(group_mask.unsqueeze(0), -np.inf)

        # Normalise over input channels (dim=1), so each output channel j is
        # a convex combination of the group's input channels.
        attn = F.softmax(score, dim=1)
        out = self.bn(self.fc_out(torch.bmm(x_v, attn)))  # (bt*groups) x (h*w) x cg
        out = out.permute(0, 2, 1).contiguous().view(bt, c, h, w)
        return F.relu(residual + out)
