import torch.nn as nn
from .ops import *


class stem(nn.Module):
    num_layer = 1

    def __init__(self, conv, inplanes, planes, stride=1, norm_layer=nn.BatchNorm2d):
        super(stem, self).__init__()

        self.conv1 = conv(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        return out


class basic(nn.Module):
    expansion = 1
    num_layer = 2

    def __init__(self, conv, inplanes, planes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
        super(basic, self).__init__()
        midplanes = planes if midplanes is None else midplanes
        self.conv1 = conv(inplanes, midplanes, stride)
        self.bn1 = norm_layer(midplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv(midplanes, planes)
        self.bn2 = norm_layer(planes)
        if stride!=1 or inplanes!=planes*self.expansion:
            self.downsample = nn.Sequential(
                conv1x1(inplanes, planes, stride),
                norm_layer(planes),
            )
        else:
            self.downsample = None

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class bottleneck(nn.Module):
    expansion = 4
    num_layer = 3

    def __init__(self, conv, inplanes, planes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
        super(bottleneck, self).__init__()
        midplanes = planes if midplanes is None else midplanes
        self.conv1 = conv1x1(inplanes, midplanes)
        self.bn1 = norm_layer(midplanes)
        self.conv2 = conv(midplanes, midplanes, stride)
        self.bn2 = norm_layer(midplanes)
        self.conv3 = conv1x1(midplanes, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        if stride!=1 or inplanes!=planes*self.expansion:
            self.downsample = nn.Sequential(
                conv1x1(inplanes, planes*self.expansion, stride),
                norm_layer(planes*self.expansion),
            )
        else:
            self.downsample = None

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class invert(nn.Module):
    def __init__(self, conv, inp, oup, stride=1, expand_ratio=1, norm_layer=nn.BatchNorm2d):
        super(invert, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                conv(hidden_dim, hidden_dim, stride),
                norm_layer(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                norm_layer(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                norm_layer(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                conv(hidden_dim, hidden_dim, stride),
                norm_layer(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                norm_layer(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


invert2 = lambda op, inp, outp, stride, **kwargs: invert(op, inp, outp, stride, expand_ratio=2, **kwargs)
invert3 = lambda op, inp, outp, stride, **kwargs: invert(op, inp, outp, stride, expand_ratio=3, **kwargs)
invert4 = lambda op, inp, outp, stride, **kwargs: invert(op, inp, outp, stride, expand_ratio=4, **kwargs)
invert6 = lambda op, inp, outp, stride, **kwargs: invert(op, inp, outp, stride, expand_ratio=6, **kwargs)


def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups
    # reshape
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten
    x = x.view(batchsize, -1, height, width)
    return x


class shuffle(nn.Module):
    expansion = 1
    num_layer = 3

    def __init__(self, conv, inplanes, outplanes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
        super(shuffle, self).__init__()
        inplanes = inplanes // 2 if stride == 1 else inplanes
        midplanes = outplanes // 2 if midplanes is None else midplanes
        rightoutplanes = outplanes - inplanes
        if stride == 2:
            self.left_branch = nn.Sequential(
                # dw
                conv(inplanes, inplanes, stride),
                norm_layer(inplanes),
                # pw-linear
                conv1x1(inplanes, inplanes),
                norm_layer(inplanes),
                nn.ReLU(inplace=True),
            )

        self.right_branch = nn.Sequential(
            # pw
            conv1x1(inplanes, midplanes),
            norm_layer(midplanes),
            nn.ReLU(inplace=True),
            # dw
            conv(midplanes, midplanes, stride),
            norm_layer(midplanes),
            # pw-linear
            conv1x1(midplanes, rightoutplanes),
            norm_layer(rightoutplanes),
            nn.ReLU(inplace=True),
        )

        self.reduce = stride==2

    def forward(self, x):
        if self.reduce:
            out = torch.cat((self.left_branch(x), self.right_branch(x)), 1)
        else:
            x1 = x[:, :(x.shape[1]//2), :, :]
            x2 = x[:, (x.shape[1]//2):, :, :]
            out = torch.cat((x1, self.right_branch(x2)), 1)

        return channel_shuffle(out, 2)


class shufflex(nn.Module):
    expansion = 1
    num_layer = 3

    def __init__(self, conv, inplanes, outplanes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
        super(shufflex, self).__init__()
        inplanes = inplanes // 2 if stride == 1 else inplanes
        midplanes = outplanes // 2 if midplanes is None else midplanes
        rightoutplanes = outplanes - inplanes
        if stride==2:
            self.left_branch = nn.Sequential(
                # dw
                conv(inplanes, inplanes, stride),
                norm_layer(inplanes),
                # pw-linear
                conv1x1(inplanes, inplanes),
                norm_layer(inplanes),
                nn.ReLU(inplace=True),
            )

        self.right_branch = nn.Sequential(
            # dw
            conv(inplanes, inplanes, stride),
            norm_layer(inplanes),
            # pw-linear
            conv1x1(inplanes, midplanes),
            norm_layer(midplanes),
            nn.ReLU(inplace=True),
            # dw
            conv(midplanes, midplanes, 1),
            norm_layer(midplanes),
            # pw-linear
            conv1x1(midplanes, midplanes),
            norm_layer(midplanes),
            nn.ReLU(inplace=True),
            # dw
            conv(midplanes, midplanes, 1),
            norm_layer(midplanes),
            # pw-linear
            conv1x1(midplanes, rightoutplanes),
            norm_layer(rightoutplanes),
            nn.ReLU(inplace=True),
        )

        self.reduce = stride==2

    def forward(self, x):
        if self.reduce:
            out = torch.cat((self.left_branch(x), self.right_branch(x)), 1)
        else:
            x1 = x[:, :(x.shape[1] // 2), :, :]
            x2 = x[:, (x.shape[1] // 2):, :, :]
            out = torch.cat((x1, self.right_branch(x2)), 1)

        return channel_shuffle(out, 2)