""" feedforward_net_recallx_2d.py
    Feed-forward 2D convolutional neural network.
    With recall using inputs.
    August 2021
"""

import torch
from torch import nn
import torch.nn.functional as F

from icecream import ic

# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), No exception (W0702),
#     Too many locals (R0914)
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914


class BasicBlock(nn.Module):
    """Residual block: two bias-free 3x3 convolutions plus a skip connection.

    The skip path is an empty ``nn.Sequential`` (identity) unless the block
    changes the spatial stride or the channel count, in which case it is a
    bias-free 1x1 convolution that matches the main path's output.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by both 3x3 convolutions.
        stride: Stride of the first convolution (and of the 1x1 skip
            convolution when one is created).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        skip_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )

        # Identity skip only when neither resolution nor channel count changes.
        if stride == 1 and in_planes == skip_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, skip_planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        """Return ``relu(conv2(relu(conv1(x))) + shortcut(x))``."""
        hidden = F.relu(self.conv1(x))
        summed = self.conv2(hidden)
        summed += self.shortcut(x)
        return F.relu(summed)


class FeedForwardNet(nn.Module):
    """Modified ResidualNetworkSegment model class.

    Feed-forward 2D convolutional network made of an input projection,
    ``max_iters`` separately parameterised stacks of blocks, and a
    three-convolution head that ends in 2 output channels. With
    ``recall=True`` the input ``x`` is concatenated onto the features before
    every stack and reduced back to ``width`` channels by one ``recall_layer``
    that is shared by all stacks.

    Args:
        block: Block factory called as ``block(in_planes, planes, stride)``;
            it must expose an ``expansion`` attribute.
        num_blocks: Sequence whose j-th entry is the number of blocks in the
            j-th group of every stack.
        width: Number of feature channels. Coerced with ``int()``, so numeric
            strings and integral floats are accepted.
        in_channels: Number of channels of the network input.
        recall: Whether to re-inject the input before every stack.
        max_iters: Number of stacks to build; also the default number of
            stacks applied by ``forward``.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, max_iters=8):
        super().__init__()

        # Coerce once and use the coerced value for every layer. Previously
        # only ``self.width`` was coerced while the raw value was handed to
        # nn.Conv2d, so a non-int width failed during construction.
        width = int(width)
        self.width = width
        self.recall = recall

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3, stride=1, padding=1, bias=False)

        if self.recall:
            # Takes the features concatenated with the input along dim 1.
            self.recall_layer = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                          stride=1, padding=1, bias=False)
        else:
            self.recall_layer = nn.Sequential()

        self.feedforward_layers = nn.ModuleList()
        for _ in range(max_iters):
            groups = [self._make_layer(block, width, count, stride=1) for count in num_blocks]
            self.feedforward_layers.append(nn.Sequential(*groups))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3, stride=1, padding=1, bias=False)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3, stride=1, padding=1, bias=False)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3, stride=1, padding=1, bias=False)

        self.iters = max_iters
        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(), head_conv2, nn.ReLU(), head_conv3)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build ``num_blocks`` consecutive blocks as one ``nn.Sequential``.

        Only the first block receives ``stride``; the rest use stride 1.
        ``self.width`` is updated to ``planes * block.expansion`` after each
        block so the next block is created with the matching input width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, interim_thought=None, **kwargs):
        """Apply stacks ``n`` .. ``n + k - 1`` and the head after each of them.

        Args:
            x: Input tensor; dimensions 0, 2 and 3 are used as the batch and
                spatial sizes of the output buffer, so it must be 4-D.
            interim_thought: Optional features to continue from. When ``None``
                the features are computed as ``self.projection(x)``.
            **kwargs: ``n`` (default 0) is the index of the first stack to
                apply and ``k`` (default ``self.iters``) is how many stacks
                to apply.

        Returns:
            In training mode, the tuple ``(out, interim_thought)`` with the
            head output and features after the last applied stack. In eval
            mode, a tensor of shape ``(x.size(0), k, 2, x.size(2), x.size(3))``
            holding the head output after every applied stack.

        Raises:
            AssertionError: If ``n + k`` exceeds the number of stacks.
            ValueError: In training mode, if the requested slice selects no
                stack, so there is no output to return.
        """
        n, k = kwargs.get("n", 0), kwargs.get("k", self.iters)
        assert (n + k) <= self.iters, "n + k must not exceed the number of feed-forward stacks"

        # The projection is only needed when no starting features are given.
        if interim_thought is None:
            interim_thought = self.projection(x)

        # Allocate directly on the input's device; 2 matches the head's
        # output channel count.
        all_outputs = torch.zeros((x.size(0), k, 2, x.size(2), x.size(3)), device=x.device)

        out = None
        for i, layer in enumerate(self.feedforward_layers[n:n + k]):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
                interim_thought = self.recall_layer(interim_thought)
            interim_thought = layer(interim_thought)
            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if out is None:
                # Previously this surfaced as an UnboundLocalError on `out`.
                raise ValueError("no feed-forward stack was applied; "
                                 "the slice [n:n+k] must select at least one stack in training mode")
            return out, interim_thought
        return all_outputs


def feedforward_net_recallx_2d(width, **kwargs):
    """Build a recall-enabled ``FeedForwardNet`` of ``BasicBlock`` stacks.

    Each stack consists of a single group of two blocks.

    Args:
        width: Channel width forwarded to ``FeedForwardNet``.
        **kwargs: Must contain ``in_channels`` and ``max_iters``; a missing
            key raises ``KeyError``. Other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    max_iters = kwargs["max_iters"]
    blocks_per_group = [2]
    return FeedForwardNet(
        BasicBlock,
        blocks_per_group,
        width,
        in_channels=in_channels,
        recall=True,
        max_iters=max_iters,
    )


# Smoke test: print the model and the output shapes in training and eval mode.
if __name__ == "__main__":
    net = feedforward_net_recallx_2d(width=5, in_channels=3, max_iters=5)
    print(net)
    x = torch.rand(4*3*5*5).reshape([4, 3, 5, 5])

    # First with the default (all) stacks, then with two stacks from index 2.
    call_variants = ({}, {"n": 2, "k": 2})

    # Training mode returns (last output, last features).
    for variant in call_variants:
        out, _ = net(x, **variant)
        print(out.shape)

    # Eval mode returns the stacked outputs of every applied stack.
    net.eval()
    for variant in call_variants:
        outputs = net(x, **variant)
        print(outputs.shape)
