import torch
import torchvision
from torchvision import transforms
import os

from config import cfg

def _get_loaders(root, *, resize_size=512, crop_size=448,
                 crop_scale=(0.08, 1.0)):
    """Build the train and validation DataLoaders for an ImageFolder split.

    ``root`` is expected to contain ``train`` and ``val`` sub-directories in
    the layout read by ``torchvision.datasets.ImageFolder`` (one
    sub-directory per class).

    Args:
        root: Directory containing the ``train`` and ``val`` splits.
        resize_size: Size passed to ``transforms.Resize`` before cropping.
            For an int, torchvision matches the image's shorter edge to it.
        crop_size: Side length of the square crop produced for both splits.
        crop_scale: ``(min, max)`` fraction of the image area sampled by
            ``RandomResizedCrop`` on the training split. Must satisfy
            ``0 < min <= max <= 1.0``.

    Returns:
        A ``(train_loader, val_loader)`` tuple. Batch sizes, training-set
        shuffling and worker count are read from ``cfg.data``; the
        validation loader never shuffles.

    Raises:
        ValueError: If ``crop_scale`` is outside ``0 < min <= max <= 1.0``.
    """
    # RandomResizedCrop cannot take a crop larger than the image: any draw
    # with area fraction > 1.0 is rejected and, after repeated rejections,
    # torchvision falls back to a centre crop. An upper bound above 1.0
    # therefore only distorts the augmentation, so refuse it up front.
    scale_min, scale_max = crop_scale
    if not 0 < scale_min <= scale_max <= 1.0:
        raise ValueError(
            f"crop_scale must satisfy 0 < min <= max <= 1.0, got {crop_scale!r}"
        )

    train_dir = os.path.join(root, 'train')
    val_dir = os.path.join(root, 'val')

    # ImageNet per-channel mean/std, the normalization torchvision's
    # pretrained models expect.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Training pipeline: random area/aspect crop plus horizontal flip.
    train_transform = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.RandomResizedCrop(crop_size, scale=(scale_min, scale_max)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Evaluation pipeline: deterministic resize + centre crop.
    val_transform = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = torchvision.datasets.ImageFolder(train_dir, train_transform)
    val_dataset = torchvision.datasets.ImageFolder(val_dir, val_transform)

    # Settings shared by both loaders.
    loader_kwargs = dict(num_workers=cfg.data.num_workers, pin_memory=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.data.batch_size,
        shuffle=cfg.data.shuffle,
        **loader_kwargs,
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=cfg.data.test_batch_size,
        shuffle=False,  # keep evaluation order stable across runs
        **loader_kwargs,
    )

    return train_loader, val_loader

def get_big_cub200(root='./data/CUB_200_2011/split'):
    """Return ``(train_loader, val_loader)`` for the CUB-200-2011 split.

    Args:
        root: Directory holding the ``train``/``val`` ImageFolder splits.
            The default is the previously hard-coded location. It is a
            relative path, so it is resolved against the process's current
            working directory.

    Returns:
        The ``(train_loader, val_loader)`` pair built by ``_get_loaders``.
    """
    return _get_loaders(root)
