from .sparsecut import Sparsecut
import torch
from .backbones import resnet18_cifar_variant1, resnet18_cifar_variant2
from .backbones import resnet50_mlp1_norelu, resnet50_mlp2_norelu, resnet50_mlp3_norelu, resnet50_mlp4_norelu, resnet50_mlp5_norelu, resnet50_mlp6_norelu, resnet50_mlp10_norelu, resnet50_mlp20_norelu, resnet50_mlp100_norelu, resnet50_mlp1000_norelu, resnet50_mlp2000_norelu, resnet50_mlp2048_norelu, resnet50_mlp256_norelu, resnet50_mlp512_norelu
from .backbones import resnet18_cifar_variant1_mlp1000_norelu, resnet18_cifar_variant1_mlp512_norelu, resnet18_cifar_variant1_mlp256_norelu, resnet18_cifar_variant1_mlp128_norelu, resnet18_cifar_variant1_mlp64_norelu, resnet18_cifar_variant1_mlp32_norelu, resnet18_cifar_variant1_mlp16_norelu, resnet18_cifar_variant1_mlp8_norelu, resnet18_cifar_variant1_mlp4_norelu, resnet18_cifar_variant1_mlp2_norelu, resnet50_cifar_variant1_mlp8_norelu, resnet50_cifar_variant1_mlp512_norelu
from .backbones import resnet50_mlp2048_norelu_3layer, resnet50_mlp1024_norelu_3layer, resnet50_mlp4096_norelu_3layer, resnet50_mlp8192_norelu_3layer, resnet50_mlp16384_norelu_3layer


def get_backbone(backbone, castrate=True):
    """Instantiate a backbone network from its name.

    Args:
        backbone: Name of a zero-argument factory visible in this module's
            namespace (e.g. one of the ``.backbones`` imports such as
            ``"resnet18_cifar_variant1"``). Dotted paths like
            ``"torch.nn.Identity"`` are resolved attribute by attribute.
            Leading/trailing whitespace is ignored.
        castrate: If True, record ``net.fc.in_features`` as
            ``net.output_dim`` and replace ``net.fc`` with
            ``torch.nn.Identity()``. This requires the network to expose an
            ``fc`` attribute that has ``in_features``.

    Returns:
        The instantiated network (with its ``fc`` head replaced when
        ``castrate`` is True).

    Raises:
        NameError: if the first component of ``backbone`` is not defined in
            this module or in builtins.
        AttributeError: if a later dotted component does not exist, or if
            ``castrate`` is True and the network has no ``fc.in_features``.
    """
    import builtins

    # Resolve the name explicitly rather than eval()-ing a config string:
    # eval would execute any expression placed in the config file.
    # Module globals shadow builtins, the same lookup order eval used.
    namespace = {**vars(builtins), **globals()}
    head, *attrs = str(backbone).strip().split(".")
    if head not in namespace:
        raise NameError(f"name {head!r} is not defined")
    factory = namespace[head]
    for attr in attrs:
        factory = getattr(factory, attr)
    net = factory()

    if castrate:
        # Drop the final fc layer. Presumably forward() routes through
        # self.fc, so the network then emits features of width output_dim.
        net.output_dim = net.fc.in_features
        net.fc = torch.nn.Identity()

    return net


def get_model(model_cfg):
    """Build the model described by ``model_cfg``.

    Only ``model_cfg.name == 'sparsecut'`` is supported. The attributes read
    are ``name``, ``backbone`` (passed to ``get_backbone``), ``version``,
    and the optional ``lam``, ``r`` and ``proj_layers``.

    Side effect: if ``lam`` or ``r`` is not present in the config object's
    instance ``__dict__``, it is set on ``model_cfg`` in place to ``1.0``.

    Args:
        model_cfg: Attribute-style config object (must have a ``__dict__``).

    Returns:
        A ``Sparsecut`` model. If ``proj_layers`` is set and not None, it has
        been passed to ``model.projector.set_layers``.

    Raises:
        NotImplementedError: if ``model_cfg.name`` is not ``'sparsecut'``.
    """
    if model_cfg.name != 'sparsecut':
        raise NotImplementedError(f"Unsupported model name: {model_cfg.name!r}")

    # Write defaults back onto the config so later readers of model_cfg see
    # the values that were actually used.
    for key, default in (("lam", 1.0), ("r", 1.0)):
        if key not in model_cfg.__dict__:
            setattr(model_cfg, key, default)

    model = Sparsecut(get_backbone(model_cfg.backbone), model_cfg.version, model_cfg.lam, r=model_cfg.r)

    # proj_layers is optional: a missing attribute is treated the same as an
    # explicit None instead of raising AttributeError.
    proj_layers = getattr(model_cfg, "proj_layers", None)
    if proj_layers is not None:
        model.projector.set_layers(proj_layers)

    return model






