import os
from glob import glob

import numpy as np
from PIL import Image
import torch
from torch import nn
from torch.utils.data import Dataset


class ADE20KDataset(Dataset):
    """ADE20K dataset.

    http://groups.csail.mit.edu/vision/datasets/ADE20K/

    Expects the directory layout::

        root/images/<split>/<id>.jpg
        root/annotations/<split>/<id>.png

    Args:
        root: Dataset root directory.
        split: Either ``'training'`` or ``'validation'``.
        transform: Optional callable applied to the RGB ``PIL.Image``.
            When ``None`` the image is returned unchanged.
        target_transform: Optional callable applied to the annotation
            ``PIL.Image``. When ``None`` the annotation is returned unchanged.

    Raises:
        ValueError: If ``split`` is not a supported split name.
    """

    def __init__(self, root, split, transform=None, target_transform=None):
        super().__init__()
        # Number of label classes (hard-coded for this dataset).
        self.num_classes = 150
        self.root = root
        self.split = split
        self.transform = transform
        self.target_transform = target_transform

        if self.split not in ('training', 'validation'):
            raise ValueError(f'Invalid split name {self.split}')

        # Local import: the module only imports the ``glob`` function itself.
        from glob import escape as glob_escape

        self.image_dir = os.path.join(self.root, 'images', self.split)
        self.label_dir = os.path.join(self.root, 'annotations', self.split)

        # Escape the directory so glob metacharacters in ``root`` (e.g. ``[``)
        # are matched literally rather than as a pattern.
        pattern = os.path.join(glob_escape(self.image_dir), '*.jpg')
        # splitext strips only the extension, so ids containing dots survive
        # intact; sorting makes index -> sample independent of the
        # filesystem's directory listing order.
        self.files = sorted(
            os.path.splitext(os.path.basename(path))[0] for path in glob(pattern)
        )
        print(f'Found {len(self.files)} images in the {self.split} split')

    def __len__(self):
        """Return the number of images found in the split's image directory."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for sample ``idx``.

        The image is converted to RGB; the annotation is loaded in whatever
        mode it is stored in. Each is passed through its transform when one
        was given; otherwise the ``PIL.Image`` is returned as is.
        """
        image_id = self.files[idx]
        image_path = os.path.join(self.image_dir, image_id + '.jpg')
        label_path = os.path.join(self.label_dir, image_id + '.png')

        # Context managers release file handles deterministically. Both
        # ``convert`` and ``copy`` force a full load into a new image, so the
        # results remain usable after the source image is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        with Image.open(label_path) as lbl:
            label = lbl.copy()

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
