import math
import numpy as np
import torch
from sklearn.datasets import make_swiss_roll
from torch.utils.data import Dataset
import random

# Public API: the base class, every dataset class reachable through
# DataStreamer.dataset_map, and the streaming helpers.
__all__ = [
    "ToyDataset", "GenToyDataset", "DataStreamer",
    "Gaussian8", "Gaussian25", "Gaussian25_NoSTD", "Gaussian25_Rotated",
    "Gaussian25_Imbalanced", "SwissRoll", "GaussianMixture2D",
    "Gaussian1_1D", "Gaussian2_1D", "Gaussian3_1D", "Gaussian4_1D",
    "GaussianND_", "GaussianND_More_Modes", "GaussianND_Odd_Even",
    "Gaussian2D_Composition_Test",
]


class ToyDataset(Dataset):
    """Base class for synthetic datasets that are sampled once and kept in memory.

    Subclasses override ``_sample`` (returning the full data array) and,
    optionally, ``_calc_stdev`` (returning a normalisation constant).

    Args:
        size: number of samples.
        stdev: noise level; stored as ``self.noise``. ``self.stdev`` is instead
            whatever ``_calc_stdev`` returns (``None`` unless overridden).
        random_state: seed made available to subclasses as ``self.random_state``.
    """

    def __init__(self, size: int, stdev: float, random_state: int = None):
        self.size = size
        self.noise = stdev
        self.random_state = random_state
        # _calc_stdev runs before _sample so subclasses can normalise by it.
        self.stdev = self._calc_stdev()
        self.data = self._sample()

    def _calc_stdev(self):
        """Return the normalisation std of the data (``None`` by default)."""
        return None

    def _sample(self):
        """Return a freshly drawn data array; subclasses override this."""
        return None

    def resample(self):
        """Replace ``self.data`` with a new draw from ``_sample``."""
        self.data = self._sample()

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Indexing 1-D data yields a NumPy scalar, which torch.from_numpy
        # rejects; np.asarray turns it into a 0-d array. Rows of 2-D data are
        # already ndarrays and pass through without a copy.
        return torch.from_numpy(np.asarray(self.data[idx]))

class GenToyDataset(Dataset):
    """Expose an existing in-memory array as a ``torch`` Dataset.

    Args:
        data: an indexable sequence of samples (typically a NumPy array).
    """

    def __init__(self, data):
        self.size = len(data)
        self.data = data

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # np.asarray handles 1-D data, whose items are NumPy scalars that
        # torch.from_numpy would reject; ndarray rows pass through unchanged.
        return torch.from_numpy(np.asarray(self.data[idx]))



class Gaussian8(ToyDataset):
    """Eight isotropic Gaussians equally spaced on a circle of radius ``scale``.

    Samples are divided by the analytic per-axis standard deviation of the
    mixture, so each coordinate has roughly unit variance.

    Args:
        size: number of samples.
        stdev: standard deviation of the noise around each mode.
        random_state: seed for ``np.random.default_rng``; ``None`` gives a
            non-reproducible stream.
    """

    scale = 2
    modes = [
        (math.cos(0.25 * t * math.pi), math.sin(0.25 * t * math.pi))
        for t in range(8)
    ]  # the 8 roots of z^8 = 1; multiplied by ``scale`` per instance

    def __init__(self, size, stdev=0.02, random_state=1234):
        self.modes = self.scale * np.array(self.modes, dtype=np.float32)
        super(Gaussian8, self).__init__(size, stdev, random_state)

    def _calc_stdev(self):
        # total variance = expected conditional variance + variance of conditional expectation
        # x-y symmetric: E[cos^2] = 1/2, so about 1.414 for scale 2
        return math.sqrt(self.noise ** 2 + (self.scale ** 2) * 0.5)

    def _sample(self):
        rng = np.random.default_rng(seed=self.random_state)
        data = self.noise * rng.standard_normal((self.size, 2), dtype=np.float32)
        # Draw mode assignments from the same seeded generator so that a fixed
        # random_state reproduces the whole sample (the global np.random state
        # ignores the seed).
        mode_idx = rng.integers(0, len(self.modes), size=self.size)
        data += self.modes[mode_idx]
        data /= self.stdev
        return data


class Gaussian25(ToyDataset):
    """5x5 grid of isotropic Gaussians at ``scale * (i, j)``, i, j in -2..2.

    Samples are assigned to modes round-robin (sample k goes to mode k % 25)
    and divided by the analytic per-axis standard deviation of the mixture.
    """

    scale = 2
    modes = [(i, j) for i in range(-2, 3) for j in range(-2, 3)]

    def __init__(self, size, stdev=0.05, random_state=1234):
        grid = np.array(self.modes, dtype=np.float32)
        self.modes = self.scale * grid
        super().__init__(size, stdev, random_state)

    def _calc_stdev(self):
        # Per-axis variance of the integer grid -2..2 is 2; x-y symmetric,
        # giving about 2.828 for scale 2.
        grid_var = 2.
        return math.sqrt(self.noise ** 2 + (self.scale ** 2) * grid_var)

    def _sample(self):
        gen = np.random.default_rng(self.random_state)
        samples = self.noise * gen.standard_normal((self.size, 2), dtype=np.float32)
        assignment = np.arange(self.size) % len(self.modes)
        samples += self.modes[assignment]
        samples /= self.stdev
        return samples
    
class Gaussian25_NoSTD(ToyDataset):
    """Unscaled, unnormalised 5x5 grid of Gaussians centred on (i, j), i, j in -2..2.

    Modes are assigned round-robin. ``_calc_stdev`` is not overridden, so
    ``self.stdev`` is ``None`` and the data is left in raw units.
    """

    modes = [(i, j) for i in range(-2, 3) for j in range(-2, 3)]

    def __init__(self, size, stdev=0.05, random_state=1234):
        self.modes = np.array(self.modes, dtype=np.float32)
        super().__init__(size, stdev, random_state)

    def _sample(self):
        gen = np.random.default_rng(self.random_state)
        jitter = gen.standard_normal((self.size, 2), dtype=np.float32)
        points = self.noise * jitter
        points += self.modes[np.arange(self.size) % len(self.modes)]
        return points
    
class Gaussian25_Rotated(ToyDataset):
    """25 Gaussians on the Gaussian25 grid rotated by 45 degrees and shrunk by 1/sqrt(2).

    Before scaling, mode (r, c) for r, c in 0..4 sits at
    ``(-2 + 0.5 * (r + c), 0.5 * (c - r))``: a diamond with corners at
    (+-2, 0) and (0, +-2). Modes are assigned round-robin.
    """

    scale = 2
    modes = [(-2 + 0.5 * (r + c), 0.5 * (c - r)) for r in range(5) for c in range(5)]

    def __init__(self, size, stdev=0.05, random_state=1234):
        base = np.array(self.modes, dtype=np.float32)
        self.modes = self.scale * base
        super().__init__(size, stdev, random_state)

    def _calc_stdev(self):
        # NOTE(review): this reuses Gaussian25's divisor (about 2.828). The
        # rotated grid's own per-axis mode variance is 1 (not 2), so its true
        # std is about 2.0; kept as-is, presumably to share Gaussian25's scale.
        grid_var = 2.
        return math.sqrt(self.noise ** 2 + (self.scale ** 2) * grid_var)

    def _sample(self):
        gen = np.random.default_rng(self.random_state)
        samples = self.noise * gen.standard_normal((self.size, 2), dtype=np.float32)
        samples += self.modes[np.arange(self.size) % len(self.modes)]
        samples /= self.stdev
        return samples
    
class GaussianND_(ToyDataset):
    """16-D data in which every sample has exactly one non-zero coordinate.

    Each sample picks a dimension uniformly at random and draws that
    coordinate from ``N(mu[dim], stdev)``; all other coordinates are 0.
    ``mu`` holds 10 or 20 per dimension, chosen with the global ``random``
    module (so it is not controlled by ``random_state``).
    """

    def __init__(self, size, stdev=1, random_state=1234):
        self.d = 16
        self.mu = np.zeros(self.d)
        for k in range(self.d):
            self.mu[k] = 10 if random.random() < 0.5 else 20
        self.sigma = stdev
        super().__init__(size, stdev, random_state)

    def _calc_stdev(self):
        # Per-coordinate noise scale; the data is not normalised by it.
        return self.sigma

    def _sample(self):
        gen = np.random.default_rng(self.random_state)
        out = np.zeros((self.size, self.d), dtype=np.float32)
        active = gen.integers(0, self.d, size=self.size)
        for row in range(self.size):
            col = active[row]
            out[row, col] = gen.normal(self.mu[col], self.sigma)
        return out

class GaussianND_More_Modes(ToyDataset):
    """32-D one-non-zero-coordinate data with per-dimension means from 8 levels.

    Like ``GaussianND_`` but with ``d = 32`` and each ``mu[dim]`` drawn from
    {10, 20, ..., 80} via the global ``random`` module (independent of
    ``random_state``). The chosen means are printed on construction.
    """

    def __init__(self, size, stdev=1, random_state=1234):
        self.d = 32
        self.mu_choices = [10, 20, 30, 40, 50, 60, 70, 80]
        self.mu = np.array(
            [random.choice(self.mu_choices) for _ in range(self.d)],
            dtype=np.float64,
        )
        print(self.mu)
        self.sigma = stdev
        super().__init__(size, stdev, random_state)

    def _calc_stdev(self):
        # Per-coordinate noise scale; the data is not normalised by it.
        return self.sigma

    def _sample(self):
        gen = np.random.default_rng(self.random_state)
        out = np.zeros((self.size, self.d), dtype=np.float32)
        active = gen.integers(0, self.d, size=self.size)
        for row in range(self.size):
            col = active[row]
            out[row, col] = gen.normal(self.mu[col], self.sigma)
        return out


class GaussianND_Odd_Even(ToyDataset):
    """16-D data whose samples populate either all even or all odd coordinates.

    Means alternate 10 (even dims) / 20 (odd dims). Each sample flips a fair
    coin (from the seeded generator): on 1 it fills every odd coordinate, on 0
    every even one, each from ``N(mu[dim], stdev)``; the rest stay 0. The
    means are printed on construction.
    """

    def __init__(self, size, stdev=1, random_state=1234):
        self.d = 16
        self.mu = np.where(np.arange(self.d) % 2 == 0, 10.0, 20.0)
        self.mu_choices = [10, 20]  # not read anywhere in this file; kept for compatibility
        print(self.mu)
        self.sigma = stdev
        super().__init__(size, stdev, random_state)

    def _calc_stdev(self):
        # Per-coordinate noise scale; the data is not normalised by it.
        return self.sigma

    def _sample(self):
        rng = np.random.default_rng(self.random_state)
        data = np.zeros((self.size, self.d), dtype=np.float32)
        odd = rng.integers(0, 2, size=self.size)
        # mask[i, dim] is True when dim's parity matches sample i's coin.
        mask = (np.arange(self.d) % 2)[None, :] == odd[:, None]
        # np.nonzero is row-major, so normals are drawn in the same
        # (sample, dim) order as a nested per-sample loop would draw them.
        rows, cols = np.nonzero(mask)
        data[rows, cols] = rng.normal(self.mu[cols], self.sigma)
        return data

class Gaussian2D_Composition_Test(ToyDataset):
    """Two 2-D Gaussians centred at (10, 20) and its mirror (20, 10).

    Each sample is drawn around (10, 20) or, with probability 0.5, around the
    swapped point (20, 10), with isotropic noise ``stdev``. The coin flips come
    from the seeded generator, so ``random_state`` makes the sample
    reproducible. The means are printed on construction.
    """

    def __init__(self, size, stdev=1, random_state=1234):
        self.d = 2
        self.mu = np.where(np.arange(self.d) % 2 == 0, 10.0, 20.0)
        self.mu_choices = [10, 20]  # not read anywhere in this file; kept for compatibility
        print(self.mu)
        self.sigma = stdev
        super().__init__(size, stdev, random_state)

    def _calc_stdev(self):
        # Per-coordinate noise scale; the data is not normalised by it.
        return self.sigma

    def _sample(self):
        rng = np.random.default_rng(self.random_state)
        # Use the seeded generator for the branch choice (the global `random`
        # module ignored random_state and broke reproducibility).
        swap = rng.random(self.size) < 0.5
        means = np.where(swap[:, None], self.mu[::-1], self.mu)  # (size, 2)
        return rng.normal(means, self.sigma).astype(np.float32)


class GaussianMixture2D(ToyDataset):
    """Two degenerate 2-D Gaussians: points (N(1, s), 0) and (0, N(1, s)).

    Half the samples (``size // 2``) lie on the x-axis and the rest on the
    y-axis. Rows are shuffled with the seeded generator so that contiguous
    slices (as taken by ``DataStreamer``) contain both modes.

    Args:
        size: number of samples.
        stdev: noise scale ``s`` along the active axis; also ``self.stdev``.
        random_state: seed for ``np.random.default_rng``.
    """

    def __init__(self, size, stdev=0.05, random_state=1234):
        # Read by _calc_stdev, which the base constructor calls before _sample.
        self.std = stdev
        super(GaussianMixture2D, self).__init__(size, stdev, random_state)

    def _calc_stdev(self):
        # Noise scale of each mode; the data is not normalised by it.
        return self.std

    def _sample(self):
        rng = np.random.default_rng(self.random_state)
        half_size = self.size // 2
        rest = self.size - half_size

        # Mode 1: (x, 0) with x ~ N(1, stdev)
        samples_x = rng.normal(1, self.stdev, half_size)
        on_x_axis = np.stack((samples_x, np.zeros(half_size)), axis=-1)

        # Mode 2: (0, y) with y ~ N(1, stdev)
        samples_y = rng.normal(1, self.stdev, rest)
        on_y_axis = np.stack((np.zeros(rest), samples_y), axis=-1)

        samples = np.vstack((on_x_axis, on_y_axis))
        # Without shuffling, the first half of the rows is all one mode, so
        # mini-batches sliced in order would each see a single mode.
        samples = samples[rng.permutation(self.size)]
        return samples.astype(np.float32)


class Gaussian1_1D(ToyDataset):
    """Single 1-D Gaussian centred at ``distance`` with noise ``stdev``.

    ``_calc_stdev`` is not overridden, so ``self.stdev`` is ``None`` and the
    data is not normalised. The mode array is printed on construction.
    """

    def __init__(self, size, distance=1, stdev=0.05, random_state=1234):
        self.distance = distance
        self.modes = np.array([distance], dtype=np.float32)
        print(self.modes)
        super().__init__(size, stdev, random_state)

    def _sample(self):
        gen = np.random.default_rng(self.random_state)
        values = self.noise * gen.standard_normal(self.size, dtype=np.float32)
        # Every sample belongs to the only mode.
        values += self.modes[0]
        return values


class Gaussian2_1D(ToyDataset):
    """Two 1-D Gaussians at 1 and ``1 + distance``, alternating sample by sample.

    ``_calc_stdev`` is not overridden, so ``self.stdev`` is ``None`` and the
    data is not normalised. The mode array is printed on construction.
    """

    def __init__(self, size, distance=1, stdev=0.05, random_state=1234):
        self.distance = distance
        self.modes = np.array([1, 1 + distance], dtype=np.float32)
        print(self.modes)
        super().__init__(size, stdev, random_state)

    def _sample(self):
        gen = np.random.default_rng(self.random_state)
        values = self.noise * gen.standard_normal(self.size, dtype=np.float32)
        # Even-indexed samples go to mode 0, odd-indexed to mode 1.
        values += self.modes[np.arange(self.size) % 2]
        return values

class Gaussian3_1D(ToyDataset):
    """Three 1-D Gaussians at 1, 1 + distance, 1 + distance + distance_2.

    Samples cycle through the modes round-robin. The data is not normalised;
    ``self.stdev`` is the mixture's total standard deviation, for reference.
    Distances, modes, stdev and noise are printed on construction.
    """

    def __init__(self, size, distance=1, distance_2=1, stdev=0.05, random_state=1234):
        self.distance = distance
        self.distance_2 = distance_2
        a = distance
        b = distance_2
        print('Distances: ', distance, distance_2)
        self.modes = np.array([1, 1 + a, 1 + a + b], dtype=np.float32)
        print(self.modes)
        super(Gaussian3_1D, self).__init__(size, stdev, random_state)
        print("Stdev: ", self.stdev)
        print("Noise:", self.noise)

    def _calc_stdev(self):
        # total variance = noise variance + variance of the equally weighted
        # mode locations. The previous d^2 / 4 is only right for two modes and
        # ignored distance_2.
        return math.sqrt(self.noise ** 2 + float(np.var(self.modes, dtype=np.float64)))

    def _sample(self):
        rng = np.random.default_rng(self.random_state)
        data = self.noise * rng.standard_normal(self.size, dtype=np.float32)
        # Sample k goes to mode k % 3.
        data += self.modes[np.arange(self.size) % len(self.modes)]
        return data

class Gaussian4_1D(ToyDataset):
    """Four 1-D Gaussians with gaps (distance, distance_2, distance), starting at 1.

    Modes are at 1, 1+a, 1+a+b, 1+a+b+a with a = distance and b = distance_2.
    Samples cycle through the modes round-robin. The data is not normalised;
    ``self.stdev`` is the mixture's total standard deviation, for reference.
    Distances, modes, stdev and noise are printed on construction.
    """

    def __init__(self, size, distance=1, distance_2=1, stdev=0.05, random_state=1234):
        self.distance = distance
        self.distance_2 = distance_2
        a = distance
        b = distance_2
        c = distance  # the third gap deliberately reuses `distance`
        print('Distances: ', distance, distance_2)
        self.modes = np.array([1, 1 + a, 1 + a + b, 1 + a + b + c], dtype=np.float32)
        print(self.modes)
        super(Gaussian4_1D, self).__init__(size, stdev, random_state)
        print("Stdev: ", self.stdev)
        print("Noise:", self.noise)

    def _calc_stdev(self):
        # total variance = noise variance + variance of the equally weighted
        # mode locations. The previous d^2 / 4 is only right for two modes.
        return math.sqrt(self.noise ** 2 + float(np.var(self.modes, dtype=np.float64)))

    def _sample(self):
        rng = np.random.default_rng(self.random_state)
        data = self.noise * rng.standard_normal(self.size, dtype=np.float32)
        # Sample k goes to mode k % 4.
        data += self.modes[np.arange(self.size) % len(self.modes)]
        return data

class Gaussian25_Imbalanced(ToyDataset):
    """Gaussian25 grid with two rare modes.

    Mode indices 3 and 21 (grid points (-4, 2) and (4, -2) for scale 2) get
    weight 0.01 against 1 for the others. Modes are drawn at random from the
    seeded generator. Normalisation reuses the balanced grid's std (about 2.828).
    """

    scale = 2
    modes = [(i, j) for i in range(-2, 3) for j in range(-2, 3)]

    def __init__(self, size, stdev=0.05, random_state=1234):
        self.modes = self.scale * np.array(self.modes, dtype=np.float32)
        weights = np.ones(25)
        weights[[3, 21]] = 0.01
        self.mode_distribution = weights / weights.sum()
        super().__init__(size, stdev, random_state)

    def _calc_stdev(self):
        # Balanced-grid value (x-y symmetric); the rare modes are ignored.
        grid_var = 2.
        return math.sqrt(self.noise ** 2 + (self.scale ** 2) * grid_var)

    def _sample(self):
        gen = np.random.default_rng(self.random_state)
        # Noise is drawn before the mode choice; keep this order for reproducibility.
        points = self.noise * gen.standard_normal((self.size, 2), dtype=np.float32)
        picks = gen.choice(len(self.modes), size=self.size, p=self.mode_distribution)
        points += self.modes[picks]
        points /= self.stdev
        return points

class SwissRoll(ToyDataset):
    """2-D Swiss roll: the (x, z) projection of scikit-learn's 3-D Swiss roll.

    Reference implementation:
    https://homepages.ecs.vuw.ac.nz/~marslast/Code/Ch6/lle.py

    Noise-free moments of the projected coordinates:
        E[x] = 2,          Var(x) = (39/8) * pi^2 - 17/4
        E[z] = 2/(3*pi),   Var(z) = (39/8) * pi^2 - 15/4
    Each coordinate is divided by sqrt(Var + noise^2). The data is not
    mean-centred.
    """

    def __init__(self, size, stdev=0.25, random_state=1234):
        super().__init__(size, stdev, random_state)

    def _calc_stdev(self):
        # Shared term 39*pi^2/8 - 4, then -1/4 for x and +1/4 for z, plus noise.
        shared = 39 * math.pi ** 2 / 8 - 4
        variances = np.full((1, 2), shared) + (np.array([[-1, 1]]) * 0.25 + self.noise ** 2)
        return np.sqrt(variances)

    def _sample(self):
        points_3d = make_swiss_roll(
            self.size, noise=self.noise, random_state=self.random_state)[0]
        projected = points_3d[:, [0, 2]].astype(np.float32)
        # In-place division keeps the result float32 (stdev is float64).
        projected /= self.stdev
        return projected

# class CustomDataStreamer:

#     def __init__(self, dataset, batch_size: int, num_batches: int, resample: bool = False):
        
#         self.dataset = dat
#         self.batch_size = batch_size
#         self.num_batches = num_batches
#         self.resample = resample
#         self.dataset = dataset(batch_size * num_batches, random_state=None)

#     def __iter__(self):
#         cnt = 0
#         while True:
#             start = cnt * self.batch_size
#             end = start + self.batch_size
#             yield torch.from_numpy(self.dataset.data[start:end])
#             cnt += 1
#             if cnt >= self.num_batches:
#                 break
#         if self.resample:
#             self.dataset.resample()

#     def __len__(self):
#         return self.num_batches

class DataStreamer:
    """Iterate over a fixed-size dataset in contiguous mini-batches.

    Args:
        dataset: either a key of ``dataset_map``, in which case a dataset of
            ``batch_size * num_batches`` samples is built with
            ``random_state=None``, or an indexable array wrapped in
            ``GenToyDataset``.
        batch_size: samples per batch.
        num_batches: batches yielded per pass.
        resample: if True, call ``self.dataset.resample()`` after each full
            pass. This is skipped when the dataset has no ``resample`` method
            (e.g. a wrapped array).
        distance, distance_2: mode gaps forwarded only to the
            ``gaussian{2,3,4}_1d`` datasets (``distance_2`` is not passed to
            ``gaussian2_1d``).

    Raises:
        ValueError: if ``dataset`` is a string not present in ``dataset_map``.
    """

    # 1-D datasets whose constructors take distance arguments.
    _DISTANCE_DATASETS = ("gaussian2_1d", "gaussian3_1d", "gaussian4_1d")

    def __init__(self, dataset: "str | np.ndarray", batch_size: int, num_batches: int, resample: bool = False, distance=1, distance_2=1):
        total = batch_size * num_batches
        if isinstance(dataset, str):
            dataset_name = dataset
            dataset_cls = self.dataset_map(dataset_name)
            if dataset_cls is None:
                # Fail with a clear message instead of "'NoneType' object is not callable".
                raise ValueError(f"Unknown dataset name: {dataset_name!r}")
            if dataset_name in self._DISTANCE_DATASETS:
                print("Using Gaussian2_1D")
                if dataset_name == "gaussian2_1d":
                    print("Here")
                    self.dataset = dataset_cls(size=total, random_state=None, distance=distance)
                else:
                    self.dataset = dataset_cls(size=total, random_state=None, distance=distance, distance_2=distance_2)
            else:
                self.dataset = dataset_cls(total, random_state=None)
        else:
            self.dataset = GenToyDataset(dataset)
        self.batch_size = batch_size
        self.num_batches = num_batches
        self.resample = resample

    def __iter__(self):
        # range() yields nothing when num_batches == 0; the previous
        # while-True loop yielded one batch before checking the count.
        for cnt in range(self.num_batches):
            start = cnt * self.batch_size
            end = start + self.batch_size
            yield torch.from_numpy(self.dataset.data[start:end])
        # Only ToyDataset subclasses can resample; wrapped arrays cannot.
        if self.resample and hasattr(self.dataset, "resample"):
            self.dataset.resample()

    def __len__(self):
        return self.num_batches

    @staticmethod
    def dataset_map(dataset):
        """Return the dataset class registered under ``dataset``, or ``None``."""
        return {
            "gaussian8": Gaussian8,
            "gaussian25": Gaussian25,
            "swissroll": SwissRoll,
            "gaussian25_imbalanced": Gaussian25_Imbalanced,
            "gaussian2_1d": Gaussian2_1D,
            "gaussian_mixture_2d": GaussianMixture2D,
            "gaussian_nd_zeros": GaussianND_,
            "gaussian_nd_more_modes": GaussianND_More_Modes,
            "gaussian_nd_odd_even": GaussianND_Odd_Even,
            "gaussian2d_composition_test": Gaussian2D_Composition_Test,
            "gaussian25_rotated": Gaussian25_Rotated,
            "gaussian25_no_std": Gaussian25_NoSTD,
            "gaussian3_1d": Gaussian3_1D,
            "gaussian1_1d": Gaussian1_1D,
            "gaussian4_1d": Gaussian4_1D,
        }.get(dataset, None)


if __name__ == "__main__":
    # Plot a few datasets to ./figs and smoke-test them with a DataLoader.
    import os
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from torch.utils.data import DataLoader

    mpl.rcParams["figure.dpi"] = 144

    out_dir = "./figs"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    n_points = 100000
    to_plot = {
        "gaussian8": Gaussian8,
        "gaussian25": Gaussian25,
        "swissroll": SwissRoll,
    }

    for label, dataset_cls in to_plot.items():
        ds = dataset_cls(n_points)
        plt.figure(figsize=(6, 6))
        xs, ys = np.hsplit(ds.data, 2)
        plt.scatter(xs, ys, s=0.5, alpha=0.7)
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, f"{label}.jpg"))
        # Fetch one item to check the dataset works with a DataLoader.
        next(iter(DataLoader(ds)))
