import helper.mnist_dataset as mnist_dataset

def load_dataset(dataset_name):
    """Return the dataset object registered under *dataset_name*.

    The returned value is the attribute itself, e.g. ``mnist_dataset.MnistDataset``,
    not the result of calling it; the caller is responsible for instantiating it.

    Args:
        dataset_name: Name of the dataset to look up. Currently only ``'mnist'``
            is supported.

    Returns:
        The dataset object (``mnist_dataset.MnistDataset`` for ``'mnist'``).

    Raises:
        ValueError: If *dataset_name* is not a supported dataset. Previously an
            unknown name silently returned ``None``, deferring the failure to a
            confusing error at the call site.
    """
    if dataset_name == 'mnist':
        return mnist_dataset.MnistDataset
    # Fail loudly instead of implicitly returning None for unknown names.
    raise ValueError(
        f"Unknown dataset {dataset_name!r}; supported datasets: 'mnist'"
    )
