import torch.utils.data as data

class ExtraLabelDatasetWrapper(data.Dataset):
    """Dataset wrapper that appends a per-index extra label to every sample.

    Each sample returned by the wrapped dataset is unpacked and
    ``extra_labels[index]`` is appended. A dataset yielding ``(item, label)``
    therefore becomes one yielding ``(item, label, extra_label)``.
    """

    def __init__(self, dataset, extra_labels):
        """
        Initializes a new instance of the ExtraLabelDatasetWrapper class.

        Args:
            dataset (torch.utils.data.Dataset): The original dataset. It must
                support ``len()`` and integer indexing, and each sample must be
                iterable (typically an ``(item, label)`` tuple).
            extra_labels (Sequence): One extra label per dataset item, aligned
                by index with ``dataset``.

        Raises:
            ValueError: If ``dataset`` and ``extra_labels`` differ in length.
        """
        super().__init__()
        self.dataset = dataset
        self.extra_labels = extra_labels

        # Explicit check rather than `assert`: assertions are stripped when
        # Python runs with -O, which would silently disable this validation.
        n_items = len(self.dataset)
        n_extra = len(self.extra_labels)
        if n_items != n_extra:
            raise ValueError(
                "The number of items in the dataset and the number of extra "
                f"labels must be the same (got {n_items} items and "
                f"{n_extra} extra labels)."
            )

    def __getitem__(self, index):
        """
        Returns the sample at the specified index, with an extra label appended.

        Args:
            index (int): The index of the item to retrieve.

        Returns:
            tuple: The fields of the wrapped dataset's sample, followed by the
            extra label. For an ``(item, label)`` dataset this is
            ``(item, label, extra_label)``.
        """
        sample = self.dataset[index]
        # Splat the sample instead of unpacking exactly two fields, so datasets
        # whose samples carry more (or fewer) fields are supported too.
        return (*sample, self.extra_labels[index])

    def __len__(self):
        """
        Returns the number of items in the dataset.

        Returns:
            int: The number of items in the wrapped dataset, which equals the
            number of extra labels (checked at construction).
        """
        return len(self.dataset)
