from .spectral import Spectral
from .simsiam import SimSiam
import torch
from .backbones import resnet18_cifar_variant1, resnet18_cifar_variant2
from .backbones import resnet18_cifar_variant1_mlp1000_norelu
from .backbones import resnet50_mlp8192_norelu_3layer, resnet50_mlp1024_norelu_3layer, resnet50_mlp2048_norelu_2layer
from torchvision.models import resnet50, resnet18
from .simclr import SimCLR

def get_backbone(backbone, castrate=True):
    """Instantiate a backbone network from its constructor name.

    Args:
        backbone: Name of a zero-argument constructor available in this
            module's namespace (e.g. one of the imported ``resnet*``
            builders).
        castrate: If True, record the classifier head's input width as
            ``output_dim`` and replace ``fc`` with ``torch.nn.Identity`` so
            the network returns the features that fed the head.

    Returns:
        The constructed network.

    Raises:
        ValueError: If ``backbone`` is not a string naming a public object in
            this module's namespace.
    """
    # Resolve the constructor by name instead of eval()-ing the config string.
    # Every valid bare name resolves to the same object eval() would return.
    # With a lookup, a typo yields a clear error, and an arbitrary expression
    # placed in a config file cannot execute code.
    constructor = None
    if isinstance(backbone, str):
        name = backbone.strip()  # eval() tolerated surrounding whitespace
        if name and not name.startswith("_"):
            constructor = globals().get(name)
    if constructor is None:
        raise ValueError(f"Unknown backbone {backbone!r}")

    network = constructor()

    if castrate:
        # Drop the classification head; keep its input width so callers know
        # the feature dimension the network now outputs.
        network.output_dim = network.fc.in_features
        network.fc = torch.nn.Identity()

    return network


def get_model(model_cfg):
    """Build the model described by ``model_cfg``.

    Args:
        model_cfg: Config object read via attributes. ``name`` selects the
            model (``'spectral'``, ``'simsiam'`` or ``'simclr'``) and
            ``backbone`` is passed to ``get_backbone``. Optional attributes:
            ``mu`` (spectral only; defaults to 1.0 and the default is written
            back onto ``model_cfg``) and ``proj_layers`` (simsiam only; when
            present and not None it is passed to
            ``model.projector.set_layers``).

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: If ``model_cfg.name`` is not a supported model.
    """
    name = model_cfg.name

    if name == 'spectral':
        # Fill in the default and store it on the config so model_cfg.mu
        # reflects the value actually used.
        if "mu" not in model_cfg.__dict__:
            model_cfg.mu = 1.0
        return Spectral(get_backbone(model_cfg.backbone), mu=model_cfg.mu)

    if name == 'simsiam':
        model = SimSiam(get_backbone(model_cfg.backbone))
        # proj_layers is optional: a config without the attribute previously
        # crashed with AttributeError instead of keeping the default projector.
        proj_layers = getattr(model_cfg, "proj_layers", None)
        if proj_layers is not None:
            model.projector.set_layers(proj_layers)
        return model

    if name == 'simclr':
        return SimCLR(get_backbone(model_cfg.backbone))

    raise NotImplementedError(f"Unsupported model name: {name!r}")






