import os
import torch
import numbers
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torchvision.datasets import CIFAR10, ImageFolder
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
from datasets.celeba import CelebA
from datasets.ffhq import FFHQ
from datasets.lsun import LSUN
from torch.utils.data import Subset
import numpy as np
from PIL import Image
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union


class Crop(object):
    """Cut a fixed rectangle out of an image.

    The bounds are forwarded to torchvision's ``F.crop`` as
    ``(top=x1, left=y1, height=x2 - x1, width=y2 - y1)``. So ``x1``/``x2``
    delimit rows and ``y1``/``y2`` delimit columns.
    """

    def __init__(self, x1, x2, y1, y2):
        self.x1, self.x2 = x1, x2
        self.y1, self.y2 = y1, y2

    def __call__(self, img):
        height = self.x2 - self.x1
        width = self.y2 - self.y1
        return F.crop(img, self.x1, self.y1, height, width)

    def __repr__(self):
        name = type(self).__name__
        return f"{name}(x1={self.x1}, x2={self.x2}, y1={self.y1}, y2={self.y2})"

class ImageDataset64(ImageFolder):
    """``ImageFolder`` that yields square, center-cropped float arrays.

    Each sample is loaded with ``self.loader`` and downsampled so its shorter
    side equals ``resolution``. It is then converted to RGB, center-cropped to
    ``resolution x resolution``, scaled to ``[0, 1]`` and returned
    channels-first as a ``float32`` numpy array. ``self.transform`` is not
    applied; ``self.target_transform`` is.

    Args:
        root: directory laid out as expected by ``ImageFolder``.
        resolution: side length of the returned images. Defaults to 64, the
            value this class previously hard-coded.
        random_flip: when True, each sample is mirrored left-right with
            probability 0.5. Defaults to False (no augmentation), which is
            the previous behavior.
    """

    def __init__(self, root: str, resolution: int = 64, random_flip: bool = False):
        super().__init__(root=root)
        self.resolution = resolution
        self.random_flip = random_flip

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(sample, target)`` for ``index``.

        ``sample`` is a float32 array of shape ``(3, resolution, resolution)``
        with values in ``[0, 1]``. ``target`` is the class index, passed
        through ``target_transform`` when one is set.
        """
        path, target = self.samples[index]
        pil_image = self.loader(path)

        # Halve with BOX filtering while the image is at least twice the
        # target size. This emulates Pillow's ``reducing_gap`` argument
        # (unavailable on older Pillow) and improves downsampling quality.
        while min(*pil_image.size) >= 2 * self.resolution:
            pil_image = pil_image.resize(
                tuple(x // 2 for x in pil_image.size), resample=Image.BOX
            )

        # Final BICUBIC resize so the shorter side becomes ``resolution``.
        scale = self.resolution / min(*pil_image.size)
        pil_image = pil_image.resize(
            tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
        )

        arr = np.array(pil_image.convert("RGB"))  # (H, W, 3) uint8
        crop_y = (arr.shape[0] - self.resolution) // 2
        crop_x = (arr.shape[1] - self.resolution) // 2
        arr = arr[crop_y : crop_y + self.resolution, crop_x : crop_x + self.resolution]

        # Draw from torch's RNG rather than numpy's: DataLoader seeds torch
        # per worker, whereas numpy state may be duplicated across workers.
        if self.random_flip and torch.rand(1).item() < 0.5:
            # Materialize the mirrored view; torch rejects negative strides.
            arr = np.ascontiguousarray(arr[:, ::-1])
        arr = arr.astype(np.float32) / 255.

        if self.target_transform is not None:
            target = self.target_transform(target)

        return np.transpose(arr, [2, 0, 1]), target

def _compose_with_flip(steps, random_flip):
    """Return ``Compose(steps + [optional horizontal flip] + [ToTensor()])``.

    The flip (p=0.5) goes just before ``ToTensor``, which is the position
    every training pipeline in ``get_dataset`` uses.
    """
    pipeline = list(steps)
    if random_flip:
        pipeline.append(transforms.RandomHorizontalFlip(p=0.5))
    pipeline.append(transforms.ToTensor())
    return transforms.Compose(pipeline)


def get_dataset(args, config):
    """Build the ``(train, test)`` datasets selected by ``config.data.dataset``.

    Args:
        args: run arguments. ``args.exp`` is the experiment directory; CIFAR10,
            CelebA and FFHQ data live under ``<args.exp>/datasets/``.
        config: configuration whose ``data`` section provides ``dataset``,
            ``image_size`` and ``random_flip``. LSUN additionally needs
            ``root`` and ``category``; IMAGENET64 needs ``root``.

    Returns:
        ``(dataset, test_dataset)``. ``test_dataset`` is ``None`` for LSUN.
        Both are ``None`` for an unrecognized dataset name.

    When ``config.data.random_flip`` is truthy, a random horizontal flip is
    added to the training pipeline of CIFAR10, CelebA, LSUN and FFHQ (never
    to test data).
    """
    data_cfg = config.data
    image_size = data_cfg.image_size
    # Plain truthiness for every dataset. The former ``is False`` check made
    # CIFAR10 flip when random_flip was 0/None while the other branches did not.
    random_flip = bool(data_cfg.random_flip)
    name = data_cfg.dataset

    if name == "CIFAR10":
        dataset = CIFAR10(
            os.path.join(args.exp, "datasets", "cifar10"),
            train=True,
            download=True,
            transform=_compose_with_flip([transforms.Resize(image_size)], random_flip),
        )
        test_dataset = CIFAR10(
            os.path.join(args.exp, "datasets", "cifar10_test"),
            train=False,
            download=True,
            transform=_compose_with_flip([transforms.Resize(image_size)], False),
        )

    elif name == "CELEBA":
        # Fixed 128x128 window centered on (row=cy, col=cx).
        cx, cy = 89, 121
        x1, x2 = cy - 64, cy + 64
        y1, y2 = cx - 64, cx + 64
        celeba_root = os.path.join(args.exp, "datasets", "celeba")
        dataset = CelebA(
            root=celeba_root,
            split="train",
            transform=_compose_with_flip(
                [Crop(x1, x2, y1, y2), transforms.Resize(image_size)], random_flip
            ),
            download=True,
        )
        test_dataset = CelebA(
            root=celeba_root,
            split="test",
            transform=_compose_with_flip(
                [Crop(x1, x2, y1, y2), transforms.Resize(image_size)], False
            ),
            download=True,
        )

    elif name == "LSUN":
        train_folder = "{}_train".format(data_cfg.category)
        dataset = LSUN(
            root=data_cfg.root,
            classes=[train_folder],
            transform=_compose_with_flip(
                [transforms.Resize(image_size), transforms.CenterCrop(image_size)],
                random_flip,
            ),
        )
        # No held-out split is built for LSUN; the "<category>_val" loader
        # was disabled.
        test_dataset = None

    elif name == "FFHQ":
        full_dataset = FFHQ(
            path=os.path.join(args.exp, "datasets", "FFHQ"),
            transform=_compose_with_flip([], random_flip),
            resolution=image_size,
        )
        num_items = len(full_dataset)
        indices = list(range(num_items))
        # Deterministic 90/10 split. A private RandomState(2019) yields the same
        # permutation as seeding the global numpy RNG, without touching it.
        np.random.RandomState(2019).shuffle(indices)
        split = int(num_items * 0.9)
        test_dataset = Subset(full_dataset, indices[split:])
        dataset = Subset(full_dataset, indices[:split])

    elif name == "IMAGENET64":
        # NOTE(review): neither random_flip nor image_size is forwarded here.
        # ImageDataset64 is constructed from the folder path only, so train
        # data is not flipped even when random_flip is set.
        dataset = ImageDataset64(root="{}/train".format(data_cfg.root))
        test_dataset = ImageDataset64(root="{}/val".format(data_cfg.root))

    else:
        dataset, test_dataset = None, None

    return dataset, test_dataset


def logit_transform(image, lam=1e-6):
    """Return ``logit(p)`` where ``p = lam + (1 - 2 * lam) * image``.

    The affine squeeze maps ``[0, 1]`` onto ``[lam, 1 - lam]`` so the
    logit stays finite at the endpoints.
    """
    squeezed = image * (1 - 2 * lam) + lam
    log_p = torch.log(squeezed)
    log_one_minus_p = torch.log1p(-squeezed)
    return log_p - log_one_minus_p


def data_transform(config, X):
    """Apply the configured preprocessing to a batch ``X``.

    The steps run in this order:

    1. If ``uniform_dequantization`` is set, apply ``X * 255/256 + U[0,1)/256``.
    2. If ``gaussian_dequantization`` is set, add noise ``N(0, 0.01^2)``.
    3. If ``rescaled`` is set, map ``x -> 2x - 1``. Otherwise, if
       ``logit_transform`` is set, apply ``logit_transform``. ``rescaled``
       wins when both are set.
    4. If ``config`` has an ``image_mean`` attribute, subtract it, broadcast
       over the leading (batch) dimension.
    """
    opts = config.data

    if opts.uniform_dequantization:
        uniform_noise = torch.rand_like(X)
        X = X / 256.0 * 255.0 + uniform_noise / 256.0
    if opts.gaussian_dequantization:
        X = X + torch.randn_like(X) * 0.01

    if opts.rescaled:
        X = 2 * X - 1.0
    elif opts.logit_transform:
        X = logit_transform(X)

    if not hasattr(config, "image_mean"):
        return X
    mean = config.image_mean.to(X.device)
    return X - mean[None, ...]


def inverse_data_transform(config, X):
    """Undo ``data_transform``'s value mapping and clamp the result to ``[0, 1]``.

    If ``config`` has an ``image_mean`` attribute, it is added back first.
    Then ``rescaled`` (``x -> (x + 1) / 2``) is checked before
    ``logit_transform`` (``sigmoid``). This is the same precedence that
    ``data_transform`` uses. Previously the sigmoid was chosen when both
    flags were set, which does not invert the ``2x - 1`` that
    ``data_transform`` actually applied in that case.
    Dequantization noise is not removed.
    """
    if hasattr(config, "image_mean"):
        X = X + config.image_mean.to(X.device)[None, ...]

    if config.data.rescaled:
        X = (X + 1.0) / 2.0
    elif config.data.logit_transform:
        X = torch.sigmoid(X)

    return torch.clamp(X, 0.0, 1.0)
