
import torch.nn as nn
from models.modules import EmbSeq, ConvBlock1, ConvBlock2, FCBlock
import torch

class ConvResNet(nn.Module):
    """Convolutional network: strided stem, ``ConvBlock1`` stack, linear head.

    The stem is two ``Conv2d`` layers (kernel size 2, stride 2), each followed
    by ``BatchNorm2d``; a ReLU is applied only after the second pair. The stem
    is followed by ``depth`` repetitions of ``ConvBlock1(width, alpha)`` + ReLU.
    The full layer list is wrapped in ``EmbSeq``, and the ``ConvBlock1``
    instances are additionally exposed through ``analysis_layers``.

    NOTE(review): the linear head is sized for ``width * 8 * 8`` flattened
    features. Presumably this matches 32x32 inputs, assuming ``ConvBlock1``
    preserves the spatial size -- confirm against ``ConvBlock1`` and the
    data pipeline.
    """

    def __init__(self, width: int, depth: int, output_size: int):
        """Assemble the layer stack and the classification head.

        Args:
            width: Number of channels produced by the stem; also passed as the
                first argument to every ``ConvBlock1``.
            depth: Number of ``ConvBlock1`` blocks appended after the stem.
            output_size: Number of output features of the final linear layer.
        """
        super().__init__()

        # Plain (non-module) configuration values.
        self.blockType = "ConvBlock1"
        self.num_input_channels = 3
        self.width = width
        self.depth = depth
        self.output_size = output_size
        self.alpha = 1  # second argument handed to every ConvBlock1
        self.num_embs = depth
        self.num_matrices = 2 * depth

        # One ReLU instance is shared by every position that needs it.
        self.activation = nn.ReLU()

        in_channels, channels = self.num_input_channels, self.width

        # Stem: two stride-2 convolutions, each with batch norm.
        layers = [
            nn.Conv2d(in_channels, channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(channels),
            nn.Conv2d(channels, channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(channels),
            self.activation,
        ]

        # Body: `depth` blocks, each followed by the shared activation.
        for _ in range(self.depth):
            layers.append(ConvBlock1(channels, self.alpha))
            layers.append(self.activation)

        self.layers = EmbSeq(layers)
        print(f"layers:{layers}")

        # Only objects whose type is exactly ConvBlock1 are collected here.
        self.analysis_layers = nn.ModuleList(
            [module for module in layers if type(module) is ConvBlock1]
        )

        self.fc = nn.Linear(channels * 8 * 8, self.output_size)

    def forward(self, x):
        """Run the layer stack and the linear head.

        Args:
            x: Input passed directly to ``self.layers``.

        Returns:
            A pair ``(output, inter_outputs)``. ``output`` is ``self.fc``
            applied to the first value returned by ``self.layers`` after
            flattening everything but its leading dimension. ``inter_outputs``
            is the second value returned by ``self.layers``, passed through
            unchanged (presumably the intermediate outputs collected by
            ``EmbSeq`` -- confirm there).
        """
        features, inter_outputs = self.layers(x)
        batch_size = features.shape[0]
        flattened = features.view(batch_size, -1)
        return self.fc(flattened), inter_outputs

def convresnet(settings):
    """Build a ``ConvResNet`` from a settings container.

    ``ConvResNet`` requires ``width``, ``depth`` and ``output_size``; passing
    ``settings`` through as a single positional argument can never satisfy
    that signature, so the three values are read from ``settings`` instead.

    NOTE(review): the schema of ``settings`` is not visible in this file. It is
    assumed to expose ``width``, ``depth`` and ``output_size`` either as
    mapping keys or as attributes -- confirm against the callers.

    Args:
        settings: Mapping (e.g. ``dict``) or attribute-style object (e.g.
            ``argparse.Namespace``) holding ``width``, ``depth`` and
            ``output_size``.

    Returns:
        A newly constructed ``ConvResNet``.

    Raises:
        KeyError: If ``settings`` is a mapping and a required key is missing.
        AttributeError: If ``settings`` is not a mapping and a required
            attribute is missing.
    """
    from collections.abc import Mapping

    required = ("width", "depth", "output_size")
    if isinstance(settings, Mapping):
        kwargs = {name: settings[name] for name in required}
    else:
        kwargs = {name: getattr(settings, name) for name in required}
    return ConvResNet(**kwargs)

if __name__ == "__main__":
    # Smoke check: constructing the model also prints its layer list.
    _demo_kwargs = {"width": 100, "depth": 5, "output_size": 512}
    ConvResNet(**_demo_kwargs)

