from .chair_dataset import *

def get_train_dataloaders(args):
    """Build the training DataLoader for the dataset selected in ``args``.

    Args:
        args: Configuration namespace. This function reads ``dataset``,
            ``data_dir``, ``steps_per_epoch``, ``train_cluster_size``,
            ``train_num_actions``, ``image_channels`` and
            ``train_batch_size``.

    Returns:
        A ``DataLoader`` over a ``ChairTrainDataset``, batched with
        ``args.train_batch_size`` and with ``shuffle=True``.

    Raises:
        ValueError: If ``args.dataset`` is not ``'chair'`` (the only
            dataset wired up here). ``ValueError`` subclasses ``Exception``,
            so callers catching ``Exception`` keep working.
    """
    print("Loading training data....")
    if args.dataset != 'chair':
        # Name the unsupported value so misconfigurations are easy to spot.
        raise ValueError(f"Dataset not implemented: {args.dataset!r}.")

    train_dataset = ChairTrainDataset(
        args.data_dir,
        # NOTE(review): the dataset length is taken from steps_per_epoch.
        # If DataLoader is torch's, one epoch yields
        # ceil(len_dataset / train_batch_size) batches, not steps_per_epoch.
        # Confirm ChairTrainDataset / the training loop account for this.
        len_dataset=args.steps_per_epoch,
        cluster_size=args.train_cluster_size,
        num_actions=args.train_num_actions,
        image_channels=args.image_channels,
    )
    return DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True
    )

def get_test_dataloaders(args):
    """Build the evaluation DataLoader for the dataset selected in ``args``.

    Args:
        args: Configuration namespace. This function reads ``dataset``,
            ``data_dir`` and ``evaluation_batch_size``.

    Returns:
        A ``DataLoader`` over a ``ChairTestDataset``, batched with
        ``args.evaluation_batch_size`` and with ``shuffle=True``.

    Raises:
        ValueError: If ``args.dataset`` is not ``'chair'`` (the only
            dataset wired up here). ``ValueError`` subclasses ``Exception``,
            so callers catching ``Exception`` keep working.
    """
    print("Loading test data....")
    if args.dataset != 'chair':
        # Name the unsupported value so misconfigurations are easy to spot.
        raise ValueError(f"Dataset not implemented: {args.dataset!r}.")

    test_dataset = ChairTestDataset(args.data_dir)
    # NOTE(review): the evaluation loader shuffles. This is kept as-is in
    # case callers rely on sampling random evaluation batches. Confirm it is
    # intended: shuffling presumably makes evaluation order vary between
    # runs unless the RNG is seeded.
    return DataLoader(
        test_dataset, batch_size=args.evaluation_batch_size, shuffle=True
    )