import os
import warnings
from typing import Optional, Callable, Tuple, Any, List, Iterable
import bisect
import pickle
from PIL import Image
import random
import json
import numpy as np

from torch.utils.data.dataset import Dataset, T_co, IterableDataset
import torchvision
import torchvision.datasets as datasets
from torchvision.datasets.folder import default_loader

# Per-channel RGB statistics used to normalise images in the Pkl* datasets'
# __getitem__ (these match the widely used ImageNet mean/std values).
norm_mean = (0.485, 0.456, 0.406)
norm_std = (0.229, 0.224, 0.225)

# Shared transforms: PIL image -> float tensor, then per-channel normalisation.
_tf_toTensor = torchvision.transforms.ToTensor()
_tf_norm = torchvision.transforms.Normalize(mean=norm_mean, std=norm_std)


class ImageList(datasets.VisionDataset):
    """A generic Dataset class for image classification.

    Args:
        root (str): Root directory of dataset. Relative image paths in
            ``data_list_file`` are resolved against it.
        classes (list[str]): The names of all the classes.
        data_list_file (str): File to read the image list from.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `data_list_file`, each line has 2 values in the following format.
        ::
            source_dir/dog_xxx.png 0
            source_dir/cat_123.png 1
            target_dir/dog_xxy.png 0
            target_dir/cat_nsdf3.png 1

        The first value is the path of an image (relative to ``root`` unless it is
        absolute), and the second value is the integer label of the corresponding image.
        The label is the last whitespace-separated token on the line, so paths may
        contain spaces. Blank lines are ignored.
        If your data_list_file has a different format, please override :meth:`~ImageList.parse_data_file`.
    """

    def __init__(self, root: str, classes: List[str], data_list_file: str,
                 transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.samples = self.parse_data_file(data_list_file)
        self.targets = [s[1] for s in self.samples]
        self.classes = classes
        self.class_to_idx = {cls: idx
                             for idx, cls in enumerate(self.classes)}
        self.loader = default_loader
        self.data_list_file = data_list_file

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """Load one sample.

        Args:
            index (int): Index into :attr:`samples`.

        Returns:
            tuple: ``(image, target)`` where ``image`` is the loaded (and, if set,
            transformed) image and ``target`` is the class index (passed through
            ``target_transform`` if set).
        """
        path, target = self.samples[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None and target is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self) -> int:
        return len(self.samples)

    def parse_data_file(self, file_name: str) -> List[Tuple[str, int]]:
        """Parse a data list file into ``(image path, class_index)`` tuples.

        Args:
            file_name (str): The path of data file.

        Returns:
            list: ``(image path, class_index)`` tuples; relative image paths are
            joined onto ``self.root``.

        Raises:
            ValueError: If a non-blank line does not consist of a path followed by
                an integer label.
        """
        data_list = []
        with open(file_name, "r") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. an empty line at EOF) instead of
                    # crashing with an IndexError.
                    continue
                # Only the last token is the label; splitting once from the right
                # keeps any whitespace inside the path exactly as written.
                parts = line.rsplit(maxsplit=1)
                if len(parts) != 2:
                    # A lone token would otherwise silently map to (root, label).
                    raise ValueError(
                        f"{file_name}:{line_no}: expected '<path> <label>', got {line!r}")
                path, target = parts
                if not os.path.isabs(path):
                    path = os.path.join(self.root, path)
                data_list.append((path, int(target)))
        return data_list

    @property
    def num_classes(self) -> int:
        """Number of classes"""
        return len(self.classes)

    @classmethod
    def domains(cls):
        """All possible domains in this dataset (must be provided by subclasses)."""
        # `NotImplemented` is not an exception; raising it yields a TypeError.
        raise NotImplementedError

class PklImageListIthor(datasets.VisionDataset):
    """Image dataset of frames loaded from pickle files with an mdp -> episode -> frames layout.

    Each pickle file is read as a dict. Every integer key maps to a dict of
    episodes, and every episode is a dict holding a ``"frame"`` sequence of image
    arrays. Non-integer keys are skipped when collecting frames. If ``classes`` is
    None, it is taken from the ``"classes"`` entry of the last file loaded.

    Args:
        root (str): Root directory, forwarded to ``VisionDataset``.
        classes (list[str], optional): Class names, or None to read them from the data.
        data_file_list (str or list[str]): Pickle file path(s) to load.
        transform (callable, optional): Applied to the PIL image before the two
            views are produced.
        target_transform (callable, optional): Applied to the second view.

    .. warning:: Files are read with :func:`pickle.load`, which can execute
        arbitrary code; only load trusted files.
    """

    def __init__(self, root: str, classes: Optional[List[str]], data_file_list: Iterable[str],
                 transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.classes = classes
        # parse_data_file may fill in self.classes, so it must run before class_to_idx.
        self.samples = self.parse_data_file(data_file_list)
        self.class_to_idx = {cls: idx
                             for idx, cls in enumerate(self.classes)}
        self.data_file_list = data_file_list

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return two views of frame ``index``.

        Args:
            index (int): Index into :attr:`samples`.

        Returns:
            tuple: ``(img_s, img_t)``. ``img_s`` is the (transformed) image converted
            to a tensor and normalised with ``norm_mean``/``norm_std``. ``img_t`` is a
            copy of the (transformed) image, passed through ``target_transform`` if set.
        """
        sample = self.samples[index]
        img = Image.fromarray(sample.copy())
        if self.transform is not None:
            img = self.transform(img)
        # NOTE(review): .copy() assumes `transform` returns a PIL image (or ndarray).
        img_s, img_t = img.copy(), img.copy()
        img_s = _tf_norm(_tf_toTensor(img_s))
        if self.target_transform is not None:
            img_t = self.target_transform(img_t)
        return img_s, img_t

    def __len__(self) -> int:
        return len(self.samples)

    def parse_data_file(self, file_name_list: Iterable[str]) -> List[Any]:
        """Load every frame from the given pickle file(s).

        Args:
            file_name_list (str or iterable of str): Pickle file path(s).

        Returns:
            list: Copies of every frame found under integer (mdp) keys.

        Raises:
            ValueError: If ``self.classes`` is None and no file was loaded to read
                class names from.
        """
        # Accept a single path too; iterating a bare string would try to open
        # each character as a file.
        if isinstance(file_name_list, (str, os.PathLike)):
            file_name_list = [file_name_list]

        sample_list = []
        data = None
        for file_name in file_name_list:
            with open(file_name, 'rb') as f:
                data = pickle.load(f)  # trusted input only: pickle can run code
            for mdp, episodes in data.items():
                if not isinstance(mdp, int):
                    continue  # skip non-mdp entries such as "classes"
                for episode in episodes.values():
                    sample_list.extend(frame.copy() for frame in episode["frame"])

        if self.classes is None:
            if data is None:
                raise ValueError("classes must be given when data_file_list is empty")
            self.classes = data["classes"]

        print("data samples:", len(sample_list))
        return sample_list

    @property
    def num_classes(self) -> int:
        """Number of classes"""
        return len(self.classes)

    @classmethod
    def domains(cls):
        """All possible domains in this dataset (not implemented)."""
        # `NotImplemented` is not an exception; raising it yields a TypeError.
        raise NotImplementedError


class PklImageListMetaworld(datasets.VisionDataset):
    """Image dataset of frames loaded from pickle files with an mdp -> frames layout.

    Each pickle file is read as a dict. Every integer key maps to a dict holding
    a ``"frame"`` sequence of image arrays. Non-integer keys are skipped when
    collecting frames. If ``classes`` is None, it is taken from the ``"classes"``
    entry of the last file loaded.

    Args:
        root (str): Root directory, forwarded to ``VisionDataset``.
        classes (list[str], optional): Class names, or None to read them from the data.
        data_file_list (str or list[str]): Pickle file path(s) to load.
        transform (callable, optional): Applied to the PIL image before the two
            views are produced.
        target_transform (callable, optional): Applied to the second view.

    .. warning:: Files are read with :func:`pickle.load`, which can execute
        arbitrary code; only load trusted files.
    """

    def __init__(self, root: str, classes: Optional[List[str]], data_file_list: Iterable[str],
                 transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.classes = classes
        # parse_data_file may fill in self.classes, so it must run before class_to_idx.
        self.samples = self.parse_data_file(data_file_list)
        self.class_to_idx = {cls: idx
                             for idx, cls in enumerate(self.classes)}
        self.data_file_list = data_file_list

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return two views of frame ``index``.

        Args:
            index (int): Index into :attr:`samples`.

        Returns:
            tuple: ``(img_s, img_t)``. ``img_s`` is the (transformed) image converted
            to a tensor and normalised with ``norm_mean``/``norm_std``. ``img_t`` is a
            copy of the (transformed) image, passed through ``target_transform`` if set.
        """
        sample = self.samples[index]
        img = Image.fromarray(sample.copy())
        if self.transform is not None:
            img = self.transform(img)
        # NOTE(review): .copy() assumes `transform` returns a PIL image (or ndarray).
        img_s, img_t = img.copy(), img.copy()
        img_s = _tf_norm(_tf_toTensor(img_s))
        if self.target_transform is not None:
            img_t = self.target_transform(img_t)
        return img_s, img_t

    def __len__(self) -> int:
        return len(self.samples)

    def parse_data_file(self, file_name_list: Iterable[str]) -> List[Any]:
        """Load every frame from the given pickle file(s).

        Args:
            file_name_list (str or iterable of str): Pickle file path(s).

        Returns:
            list: Copies of every frame found under integer (mdp) keys.

        Raises:
            ValueError: If ``self.classes`` is None and no file was loaded to read
                class names from.
        """
        # Accept a single path too; iterating a bare string would try to open
        # each character as a file.
        if isinstance(file_name_list, (str, os.PathLike)):
            file_name_list = [file_name_list]

        sample_list = []
        data = None
        for file_name in file_name_list:
            with open(file_name, 'rb') as f:
                data = pickle.load(f)  # trusted input only: pickle can run code
            for mdp, rollout in data.items():
                if not isinstance(mdp, int):
                    continue  # skip non-mdp entries such as "classes"
                sample_list.extend(frame.copy() for frame in rollout["frame"])

        if self.classes is None:
            if data is None:
                raise ValueError("classes must be given when data_file_list is empty")
            self.classes = data["classes"]

        print("data samples:", len(sample_list))
        return sample_list

    @property
    def num_classes(self) -> int:
        """Number of classes"""
        return len(self.classes)

    @classmethod
    def domains(cls):
        """All possible domains in this dataset (not implemented)."""
        # `NotImplemented` is not an exception; raising it yields a TypeError.
        raise NotImplementedError


class MultipleDomainsDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple per-domain datasets.

    Indexing works like :class:`torch.utils.data.ConcatDataset`, except that each
    returned item is the underlying dataset's item (which must be a tuple)
    extended with the id of the domain it came from.

    Args:
        domains (iterable of Dataset): Datasets to concatenate, one per domain.
            Any iterable (including a generator) is accepted; it is materialised
            into a list.
        domain_names (iterable of str): Names of the domains; stored as-is.
        domain_ids (sequence): ``domain_ids[k]`` is appended to every item taken
            from the k-th dataset.
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        """Return the running totals of ``len(e)`` for each ``e`` in *sequence*."""
        totals, running = [], 0
        for e in sequence:
            running += len(e)
            totals.append(running)
        return totals

    def __init__(self, domains: Iterable[Dataset], domain_names: Iterable[str], domain_ids) -> None:
        super().__init__()
        # Materialise first: calling len() on the raw argument raises TypeError
        # for generators and other one-shot iterables.
        self.datasets = self.domains = list(domains)
        assert len(self.domains) > 0, 'datasets should not be an empty iterable'
        for d in self.domains:
            assert not isinstance(d, IterableDataset), "MultipleDomainsDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.domains)
        self.domain_names = domain_names
        self.domain_ids = domain_ids

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return ``item + (domain_id,)`` for global index ``idx`` (negatives allowed).

        Raises:
            ValueError: If ``idx`` is negative and ``-idx`` exceeds the length.
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # bisect_right picks the first dataset whose cumulative end exceeds idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.domains[dataset_idx][sample_idx] + (self.domain_ids[dataset_idx],)

    @property
    def cummulative_sizes(self):
        """Deprecated misspelled alias of :attr:`cumulative_sizes`."""
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes