import torch
import torchvision
from torchvision import transforms

from config import cfg

def get_fashion_mnist(root='./data/pytorch/fashion_mnist'):
    """Build the train and test DataLoaders for Fashion-MNIST.

    Each image goes through ``ToTensor`` and is then normalized with
    mean 0.5 / std 0.5 on its single channel. This maps the ``[0, 1]``
    output of ``ToTensor`` to ``[-1, 1]``.

    Args:
        root: Directory the dataset is read from. It is downloaded there
            if missing. The default is the path previously hard-coded here.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Loader settings come from
        ``cfg.data``:
        - Train loader: ``batch_size``, ``shuffle``, ``num_workers``.
        - Test loader: ``test_batch_size`` and ``num_workers``, never shuffled.
    """
    # Fashion-MNIST is grayscale (one channel), so mean/std need exactly one
    # entry. The previous 3-channel tuples fail to broadcast in place against
    # a (1, H, W) tensor and raise a RuntimeError when the first sample loads.
    # Train and test preprocessing are identical, so one pipeline serves both.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    trainset = torchvision.datasets.FashionMNIST(
        root=root, train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=cfg.data.batch_size,
        shuffle=cfg.data.shuffle,
        num_workers=cfg.data.num_workers,
    )

    testset = torchvision.datasets.FashionMNIST(
        root=root, train=False, download=True, transform=transform)
    # Evaluation order is kept deterministic: the test split is never shuffled.
    test_loader = torch.utils.data.DataLoader(
        testset,
        batch_size=cfg.data.test_batch_size,
        shuffle=False,
        num_workers=cfg.data.num_workers,
    )

    return train_loader, test_loader
