import os

from src.methods.catsvsdogs_train import CatsVsDogsTrain
from src.methods.celebA_train import CelebATrain
from src.methods.cifar10_train import CIFAR10Train
from src.methods.svhn_train import SVHNTrain


from src.arguments import init_train_argparse
from src.methods import MnistTrain, IN9lTrain




def main(args):
    """Construct the dataset-specific trainer from ``args`` and run it.

    When exactly one GPU id is given, restricts ``CUDA_VISIBLE_DEVICES`` to
    it. Then builds the trainer class matching ``args.dataset`` and, if
    ``args.train`` is set, trains it: ``train_augmask()`` when
    ``args.masktune`` is set, otherwise ``train()``. Finally it evaluates
    the trainer. Evaluation uses ``test_selective_classification()`` when
    ``args.selective_classification`` is set, otherwise ``test()``.
    Evaluation runs whether or not training was requested.

    Args:
        args: Parsed command-line namespace. This function reads
            ``gpu_ids``, ``dataset``, ``train``, ``masktune`` and
            ``selective_classification``. The whole namespace is also
            forwarded to the trainer constructor.

    Raises:
        NotImplementedError: If ``args.dataset`` is not one of
            'mnist', 'in9l', 'svhn', 'catsvsdogs', 'cifar10', 'celeba'.
    """
    if len(args.gpu_ids) == 1:
        # Expose only the requested GPU to this process.
        # NOTE(review): this only has an effect if CUDA has not been
        # initialised yet. The trainer modules are imported before main()
        # runs, so confirm none of them touch CUDA at import time.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_ids[0])

    # An if/elif chain (rather than a dict of classes) means a trainer class
    # is only referenced when its dataset is actually selected.
    if args.dataset == 'mnist':
        method = MnistTrain(args)
    elif args.dataset == 'in9l':
        method = IN9lTrain(args)
    elif args.dataset == 'svhn':
        method = SVHNTrain(args)
    elif args.dataset == 'catsvsdogs':
        method = CatsVsDogsTrain(args)
    elif args.dataset == 'cifar10':
        method = CIFAR10Train(args)
    elif args.dataset == 'celeba':
        method = CelebATrain(args)
    else:
        # Name the offending value so a CLI typo is immediately diagnosable.
        raise NotImplementedError(
            f"Unsupported dataset {args.dataset!r}; expected one of: "
            "mnist, in9l, svhn, catsvsdogs, cifar10, celeba"
        )

    if args.train:
        if args.masktune:
            method.train_augmask()
        else:
            method.train()

    if args.selective_classification:
        method.test_selective_classification()
    else:
        method.test()


if __name__ == '__main__':
    # Script entry point: parse the training CLI and hand the namespace to main().
    main(init_train_argparse().parse_args())
