
import warnings

import torch
import torchvision
from torch.hub import load_state_dict_from_url
from torchvision import models
from torchvision.models.resnet import Bottleneck, model_urls

from style_adaptation import *

__all__ = ['ResNet', 'resnet50', 'resnet101']

class ResNet(models.ResNet):
    """torchvision ResNet used as a feature extractor, with the fc layer removed.

    In training mode, a ``Style_Adaptation`` module (``self.stymix``) is
    applied twice: once to the stem output and once to the output of
    ``layer1``. The same module instance is used at both points. In eval
    mode the plain backbone is run with no style adaptation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Record the pooled feature width before dropping the classifier, so
        # callers can size their own heads via ``out_features``.
        self._out_features = self.fc.in_features
        del self.fc
        self.stymix = Style_Adaptation()

    def forward(self, x):
        """Compute pooled backbone features for the input batch.

        Args:
            x: input tensor for the torchvision ResNet stem (assumed to be
                an image batch shaped (N, C, H, W); confirm against callers).

        Returns:
            A 2-D tensor with ``out_features`` columns.
        """
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        if not self.training:
            feats = self.layer2(self.layer1(stem))
        else:
            # Style mixing before layer1 and before layer2.
            # (A third mixing step after layer2 is currently disabled.)
            mixed = self.layer1(self.stymix(stem))
            feats = self.layer2(self.stymix(mixed))

        feats = self.layer4(self.layer3(feats))
        pooled = torch.flatten(self.avgpool(feats), 1)
        return pooled.view(-1, self._out_features)

    @property
    def out_features(self) -> int:
        """Width of the feature vector returned by ``forward``."""
        return self._out_features


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` backbone and optionally load released weights.

    Args:
        arch: key into ``model_urls`` (e.g. ``'resnet50'``); only used when
            ``pretrained`` is true.
        block: residual block class passed to ``ResNet``.
        layers: number of blocks per stage, passed to ``ResNet``.
        pretrained: if true, download the checkpoint for ``arch`` and load it.
        progress: show a progress bar while downloading.
        **kwargs: forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet``.

    Raises:
        KeyError: if ``pretrained`` is true and ``arch`` is not in ``model_urls``.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        # NOTE(review): newer torchvision releases no longer export
        # ``model_urls`` from ``torchvision.models.resnet``; confirm the
        # pinned torchvision version.
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        # Non-strict on purpose. This backbone deletes ``fc``, and the
        # downloaded checkpoint presumably still carries ``fc.*`` weights.
        # ``stymix`` parameters (if any) have no counterpart in the
        # checkpoint.
        result = model.load_state_dict(state_dict, strict=False)
        # Report only mismatches beyond those expected ones, instead of
        # silently ignoring them.
        missing = [k for k in result.missing_keys
                   if not k.startswith('stymix.')]
        unexpected = [k for k in result.unexpected_keys
                      if not k.startswith('fc.')]
        if missing or unexpected:
            warnings.warn(
                f"Pretrained weights for {arch!r} did not fully match the "
                f"backbone: missing={missing}, unexpected={unexpected}",
                stacklevel=2,
            )
    return model


def resnet50(pretrained=False, progress=True, **kwargs):
    """Build a ResNet-50 backbone (Bottleneck blocks, stage depths 3-4-6-3).

    Args:
        pretrained: if true, load the weights listed under ``'resnet50'``.
        progress: show a progress bar while downloading weights.
        **kwargs: extra keyword arguments forwarded to ``ResNet``.

    Returns:
        A ``ResNet`` with the fully connected layer removed.
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    """Build a ResNet-101 backbone (Bottleneck blocks, stage depths 3-4-23-3).

    Args:
        pretrained: if true, load the weights listed under ``'resnet101'``.
        progress: show a progress bar while downloading weights.
        **kwargs: extra keyword arguments forwarded to ``ResNet``.

    Returns:
        A ``ResNet`` with the fully connected layer removed.
    """
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnet101', Bottleneck, stage_depths,
                   pretrained, progress, **kwargs)