import torch
from models.fp16 import network_to_half
from config import cfg

def get_resnet_for_cifar():
    """Construct the CIFAR ``ResNet18`` from ``cfg.model``.

    The network is built with ``cfg.model.embedding_size`` and
    ``cfg.model.num_class``, passed positionally in that order.
    """
    # The model module is imported at call time, not at module import.
    from models.cifar.resnet import ResNet18

    embedding_size = cfg.model.embedding_size
    num_classes = cfg.model.num_class
    return ResNet18(embedding_size, num_classes)

def get_vgg16_for_cifar():
    """Construct the CIFAR ``VGG`` model using its ``'VGG16'`` variant name.

    The class count is read from ``cfg.model.num_class``.
    """
    from models.cifar.vgg import VGG

    num_classes = cfg.model.num_class
    return VGG('VGG16', num_classes)

def get_convnet_for_mnist():
    """Construct the MNIST ``ConvNet``; it takes no configuration arguments."""
    from models.mnist.convnet import ConvNet

    net = ConvNet()
    return net

def get_resnet50_for_imagenet():
    """Construct the ImageNet ``Resnet50`` with ``cfg.model.num_class`` classes."""
    from models.imagenet.resnet50 import Resnet50

    num_classes = cfg.model.num_class
    return Resnet50(num_classes)

def get_vgg16_for_imagenet():
    """Construct the ImageNet ``VGG16`` with ``cfg.model.num_class`` classes."""
    from models.imagenet.vgg16 import VGG16

    num_classes = cfg.model.num_class
    return VGG16(num_classes)

def get_vgg16_gap_for_imagenet():
    """Construct the ImageNet ``VGG16GAP`` with ``cfg.model.num_class`` classes."""
    from models.imagenet.vgg16_gap import VGG16GAP

    num_classes = cfg.model.num_class
    return VGG16GAP(num_classes)

def get_wrn_for_cifar():
    """Construct the CIFAR ``Wide_ResNet`` with fixed hyper-parameters.

    The first three positional arguments are hard-coded; the fourth is
    ``cfg.model.num_class``.
    """
    from models.cifar.wrn import Wide_ResNet

    # NOTE(review): 28 / 10 / 0.3 are presumably depth, widen factor and
    # dropout rate -- confirm against Wide_ResNet's signature.
    positional_args = (28, 10, 0.3, cfg.model.num_class)
    return Wide_ResNet(*positional_args)

def get_resnet9_for_cifar():
    """Construct the CIFAR ``resnet9`` model with ``cfg.model.num_class`` classes."""
    from models.cifar.resnet9 import resnet9

    num_classes = cfg.model.num_class
    return resnet9(num_classes)

def get_se_resnet50():
    """Construct ``SEResnet50`` from ``cfg.model``.

    ``cfg.model.num_class`` and ``cfg.model.pretrained`` are passed
    positionally, in that order.
    """
    from models.senets.se_resnet50 import SEResnet50

    num_classes = cfg.model.num_class
    pretrained = cfg.model.pretrained
    return SEResnet50(num_classes, pretrained)

def get_cbam_resnet50():
    """Construct ``CBAMResnet50`` with ``cfg.model.num_class`` classes."""
    from models.cbam.imp import CBAMResnet50

    num_classes = cfg.model.num_class
    return CBAMResnet50(num_classes)

def get_gbam_resnet50():
    """Construct ``GBAMResnet50`` from ``cfg.model``.

    ``cfg.model.num_class`` and ``cfg.model.pretrained`` are passed
    positionally, in that order.
    """
    from models.gbam.gbam import GBAMResnet50

    num_classes = cfg.model.num_class
    pretrained = cfg.model.pretrained
    return GBAMResnet50(num_classes, pretrained)

def get_resnet56():
    """Construct the CIFAR ``resnet56`` model with ``cfg.model.num_class`` classes."""
    from models.cifar.resnet56 import resnet56

    num_classes = cfg.model.num_class
    return resnet56(num_classes)

def get_densenet40():
    """Construct the CIFAR ``DenseNet`` with a fixed depth-40 configuration.

    Only ``num_classes`` comes from the config (``cfg.model.num_class``);
    every other keyword argument is hard-coded below.
    """
    from models.cifar.densenet import DenseNet

    densenet_kwargs = {
        'depth': 40,
        'num_classes': cfg.model.num_class,
        'growth_rate': 12,
        'reduction': 1.0,
        'bottleneck': False,
        'dropRate': 0.0,
    }
    return DenseNet(**densenet_kwargs)

def get_mod_densenet40():
    """Construct the CIFAR ``ModDenseNet`` with a fixed depth-40 configuration.

    Only ``num_classes`` comes from the config (``cfg.model.num_class``);
    every other keyword argument is hard-coded below.
    """
    from models.cifar.mod_densenet import ModDenseNet

    densenet_kwargs = {
        'depth': 40,
        'num_classes': cfg.model.num_class,
        'growth_rate': 12,
        'reduction': 1.0,
        'bottleneck': False,
        'dropRate': 0.0,
    }
    return ModDenseNet(**densenet_kwargs)

def get_preact_resnet56_for_cifar():
    """Construct the CIFAR ``PreResNet`` with ``depth=56``.

    The class count is read from ``cfg.model.num_class``.
    """
    from models.cifar.preresnet import PreResNet

    preresnet_kwargs = {'depth': 56, 'num_classes': cfg.model.num_class}
    return PreResNet(**preresnet_kwargs)

def get_model():
    """Build the model selected by ``cfg.model.name`` and prepare it per ``cfg.base``.

    Steps, in order:

    1. Instantiate the architecture registered under ``cfg.model.name``.
    2. If ``cfg.base.checkpoint_path`` is set (non-empty), load that state
       dict into the model.
    3. If ``cfg.base.cuda`` is set, call ``model.cuda()`` and, when
       ``cfg.base.fp16`` is also set, pass the model through
       ``network_to_half``.
    4. If ``cfg.base.multi_gpus`` is set, call the model's own
       ``to_parallel()`` when it has one, otherwise wrap it in
       ``torch.nn.DataParallel``.

    Returns:
        The prepared model.

    Raises:
        KeyError: If ``cfg.model.name`` is not one of the registered names.
    """
    # Maps the configured model name to the zero-argument factory building it.
    factories = {
        'cifar.resnet18': get_resnet_for_cifar,
        'cifar.vgg16': get_vgg16_for_cifar,
        'cifar.wrn': get_wrn_for_cifar,
        'cifar.resnet9': get_resnet9_for_cifar,
        'mnist.convnet': get_convnet_for_mnist,
        'vgg16bn': get_vgg16_for_imagenet,
        'vgg16bn_gap': get_vgg16_gap_for_imagenet,
        'resnet50': get_resnet50_for_imagenet,
        'se_resnet50': get_se_resnet50,
        'cbam_resnet50': get_cbam_resnet50,
        'gbam_resnet50': get_gbam_resnet50,
        'cifar.resnet56': get_resnet56,
        'cifar.densenet40': get_densenet40,
        'cifar.mod_densenet40': get_mod_densenet40,
        'cifar.pre_resnet56': get_preact_resnet56_for_cifar,
    }

    model_name = cfg.model.name
    if model_name not in factories:
        # Same exception type as a plain dict lookup, but the message tells
        # the user which names are actually accepted.
        raise KeyError(
            'unknown model name {!r}; expected one of: {}'.format(
                model_name, ', '.join(sorted(factories))
            )
        )
    model = factories[model_name]()

    checkpoint_path = cfg.base.checkpoint_path
    # Truthiness check: both '' and None mean "no checkpoint to restore".
    if checkpoint_path:
        print('restore checkpoint: {}'.format(checkpoint_path))
        # NOTE: torch.load unpickles the file -- only load checkpoints that
        # come from a trusted source.
        state_dict = torch.load(
            checkpoint_path,
            map_location='cuda' if cfg.base.cuda else 'cpu',
        )
        model.load_state_dict(state_dict)

    if cfg.base.cuda:
        model = model.cuda()
        if cfg.base.fp16:
            model = network_to_half(model)

    if cfg.base.multi_gpus:
        # NOTE(review): if network_to_half returns a wrapper module, a
        # ``to_parallel`` method on the inner model may not be visible here
        # -- confirm this is intended when fp16 and multi_gpus are combined.
        if hasattr(model, 'to_parallel'):
            model.to_parallel()
        else:
            model = torch.nn.DataParallel(model)
    return model
