from torchvision.models.video.resnet import BasicBlock, R2Plus1dStem, Conv2Plus1D, VideoResNet
import torch
import torch.nn as nn
import torchvision.models.resnet


def build_video_encoder(num_classes):
    """Build an R(2+1)D video encoder laid out like the Caffe2 (VMZ) model.

    The backbone is a ``VideoResNet`` with four ``BasicBlock`` stages of depth
    [2, 2, 2, 2], ``Conv2Plus1D`` convolutions and an ``R2Plus1dStem``.

    Args:
        num_classes: Number of outputs of the freshly created ``fc`` layer.

    Returns:
        The configured ``VideoResNet``.
    """
    net = VideoResNet(
        block=BasicBlock,
        conv_makers=[Conv2Plus1D] * 4,
        layers=[2, 2, 2, 2],
        stem=R2Plus1dStem,
    )

    # Replace the classifier head with one of the requested width.
    net.fc = nn.Linear(net.fc.in_features, out_features=num_classes)

    # PyTorch and Caffe2 disagree on the mid-plane count of the second (2+1)D
    # conv in the first block of layers 2-4; use the Caffe2 numbers.
    # https://github.com/facebookresearch/VMZ/issues/89
    # https://github.com/pytorch/vision/issues/1265
    caffe2_midplanes = (
        (net.layer2, 128, 288),
        (net.layer3, 256, 576),
        (net.layer4, 512, 1152),
    )
    for stage, planes, midplanes in caffe2_midplanes:
        stage[0].conv2[0] = Conv2Plus1D(planes, planes, midplanes)

    # Caffe2-style BatchNorm settings. This runs after the swaps above, so any
    # BatchNorm3d inside the swapped-in modules is covered too.
    # NOTE(review): PyTorch's ``momentum`` weights the *new* batch statistic,
    # while Caffe2's weights the running value, so Caffe2's 0.9 would be 0.1
    # here. Kept as-is (matches the upstream VMZ port; no effect in eval mode).
    # Confirm before training with this model.
    batch_norms = (mod for mod in net.modules() if isinstance(mod, nn.BatchNorm3d))
    for bn in batch_norms:
        bn.eps = 1e-3
        bn.momentum = 0.9
    return net


class XDC(nn.Module):
    """Run a video encoder and an audio encoder side by side.

    Args:
        vid_arch: Module applied to the video input.
        aud_base_arch: Module applied to the audio input.
    """

    def __init__(self, vid_arch, aud_base_arch):
        super().__init__()
        self.video_network = vid_arch
        self.audio_network = aud_base_arch

    @staticmethod
    def _squeeze_keep_batch(x):
        """Drop size-1 dimensions of ``x`` except the leading one.

        Assumes dim 0 is the batch dimension. A bare ``Tensor.squeeze()`` would
        also drop it for a single-sample batch, turning ``(1, C)`` into
        ``(C,)``. For batches larger than one the result equals ``x.squeeze()``.
        """
        if x.dim() == 0:
            return x
        kept = [size for size in x.shape[1:] if size != 1]
        return x.reshape(x.shape[0], *kept)

    def forward(self, img, spec, whichhead=0):
        """Encode both modalities.

        Args:
            img: Input passed to ``video_network``.
            spec: Input passed to ``audio_network``.
            whichhead: Unused; kept so existing call sites keep working.

        Returns:
            ``(img_features, aud_features)``: each network's output with every
            size-1 dimension after the first removed.
        """
        img_features = self._squeeze_keep_batch(self.video_network(img))
        aud_features = self._squeeze_keep_batch(self.audio_network(spec))
        return img_features, aud_features
    

class AudioEncoder(torchvision.models.ResNet):
    """ResNet-style 2-D encoder for single-channel inputs.

    The torchvision ``ResNet`` constructor is called to set up the inherited
    pieces (``relu``, ``maxpool``, ``fc``, ``_make_layer``). The stem and every
    residual stage are then rebuilt:

    * ``conv1``: 3x3, stride-2 convolution from 1 to 16 channels, with bias
      and no BatchNorm (``bn1`` becomes ``nn.Identity``).
    * ``layer1``..``layer5``: ``BasicBlock`` stages with 32, 64, 128, 256 and
      512 channels. ``layer1`` uses stride 1 and the others use stride 2.
    * ``avgpool`` (name kept from ``ResNet``): global max pool to 1x1.
    * ``fc``: the linear head created by the base constructor (400 outputs).

    Args:
        use_max_pool: If True, apply the inherited ``maxpool`` right after the
            stem. If False (default), that step is an identity.
    """

    def __init__(self, use_max_pool=False):
        super().__init__(torchvision.models.resnet.BasicBlock, [3, 2, 2, 2], 400, width_per_group=64)

        # State consumed by ``_make_layer`` when the stages below are rebuilt.
        # ``inplanes`` is also the stem's output width.
        self.inplanes = 16
        self.dilation = 1

        self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=3, stride=2, padding=1, bias=True)
        self.bn1 = nn.Identity()
        if not use_max_pool:
            self.maxpool = nn.Identity()

        self.layer1 = self._make_layer(torchvision.models.resnet.BasicBlock, 32, 3, stride=1, dilate=False)
        self.layer2 = self._make_layer(torchvision.models.resnet.BasicBlock, 64, 2, stride=2, dilate=False)
        self.layer3 = self._make_layer(torchvision.models.resnet.BasicBlock, 128, 2, stride=2, dilate=False)
        self.layer4 = self._make_layer(torchvision.models.resnet.BasicBlock, 256, 2, stride=2, dilate=False)
        self.layer5 = self._make_layer(torchvision.models.resnet.BasicBlock, 512, 2, stride=2, dilate=False)

        # Global max over whatever spatial extent remains. The former
        # MaxPool2d((2, 4), stride=1) could only feed the 512-input ``fc`` when
        # the final map was exactly 2x4; for that size the result is identical.
        # Other sizes now also work, e.g. the smaller map produced when
        # ``use_max_pool=True``.
        self.avgpool = nn.AdaptiveMaxPool2d((1, 1))

        # Re-initialise conv and norm layers, including the ones rebuilt above.
        # ``fc`` keeps its base-class initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                # Caffe2-style BatchNorm settings.
                # NOTE(review): PyTorch's ``momentum`` weights the *new* batch
                # statistic (Caffe2's weights the running value), so Caffe2's
                # 0.9 would be 0.1 here. Confirm before training.
                m.eps = 1e-3
                m.momentum = 0.9

    def forward(self, x):
        """Encode ``x`` and return the ``fc`` outputs.

        Args:
            x: Input with a single channel, as required by ``conv1``.
                Assumed to be (batch, 1, H, W); the meaning of H/W (e.g.
                frequency x time) is TODO confirm.

        Returns:
            ``fc`` applied to the flattened, globally max-pooled features.
        """
        x = self.conv1(x)
        x = self.relu(x)
        # Identity unless the encoder was built with use_max_pool=True.
        x = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            x = stage(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
