import torch
from skopt.sampler import Sobol, Lhs, Halton, Grid, Hammersly

def normalize_samples(x):
    """Standardize ``x`` along dim 0 to zero mean and unit variance.

    Args:
        x: Floating-point tensor; statistics are taken over dim 0.

    Returns:
        Tensor shaped like ``x``: the dim-0 mean is subtracted and the result
        is divided by the dim-0 standard deviation (``torch.std`` defaults).
        Columns whose standard deviation is zero or undefined (constant
        column, single row) are only centred, so they come out as zeros
        instead of NaN/inf.
    """
    mean = torch.mean(x, dim=0, keepdim=True)
    std = torch.std(x, dim=0, keepdim=True)
    # `std > 0` is False for both 0 and NaN, so either degenerate case falls
    # back to a divisor of 1 and the centred values (all zero) pass through.
    safe_std = torch.where(std > 0, std, torch.ones_like(std))
    return (x - mean) / safe_std

def rvs_sample(space, n_samples):
    """Draw random points from ``space`` and lay them out on a 2-D grid.

    Args:
        space: Object exposing ``rvs(n)``; it is expected to return ``n``
            points, each a sequence with one value per dimension.
        n_samples: Pair ``(rows, cols)``; ``rows * cols`` points are drawn.

    Returns:
        Tensor of shape ``(n_dims, rows, cols)`` in which ``out[d, r, c]`` is
        coordinate ``d`` of point number ``r * cols + c``.
    """
    rows, cols = n_samples[0], n_samples[1]
    points = torch.tensor(space.rvs(rows * cols))  # expected (rows * cols, n_dims)
    # Bring the dimension axis to the front *before* reshaping. Reshaping the
    # point-major buffer directly would interleave coordinates of different
    # dimensions inside each output channel.
    return points.T.reshape(points.shape[-1], rows, cols)

def solab_sample(space, n_samples):
    """Generate Sobol-sequence points for ``space`` on a 2-D grid layout.

    Args:
        space: Object whose ``dimensions`` attribute is passed to
            ``Sobol().generate``.
        n_samples: Pair ``(rows, cols)``; ``rows * cols`` points are generated.

    Returns:
        Tensor of shape ``(n_dims, rows, cols)`` in which ``out[d, r, c]`` is
        coordinate ``d`` of point number ``r * cols + c``.
    """
    rows, cols = n_samples[0], n_samples[1]
    points = torch.tensor(Sobol().generate(space.dimensions, rows * cols))
    # The sampler output is expected to be point-major (one row per point), so
    # transpose before reshaping; a direct reshape would mix coordinates of
    # different dimensions within each output channel.
    return points.T.reshape(points.shape[-1], rows, cols)
    
def lhs_sample(space, n_samples, lhs_mode='centered'):
    """Generate Latin-hypercube points for ``space`` on a 2-D grid layout.

    Args:
        space: Object whose ``dimensions`` attribute is passed to
            ``Lhs.generate``.
        n_samples: Pair ``(rows, cols)``; ``rows * cols`` points are generated.
        lhs_mode: One of ``classic``, ``centered``, ``maximin``,
            ``correlation`` or ``ratio``. ``classic``/``centered`` are passed
            as the ``lhs_type`` with no optimisation criterion; any other
            value is passed as the ``criterion`` with 100 iterations.

    Returns:
        Tensor of shape ``(n_dims, rows, cols)`` in which ``out[d, r, c]`` is
        coordinate ``d`` of point number ``r * cols + c``.
    """
    rows, cols = n_samples[0], n_samples[1]
    if lhs_mode in ('classic', 'centered'):
        sampler = Lhs(lhs_type=lhs_mode, criterion=None)
    else:
        sampler = Lhs(criterion=lhs_mode, iterations=100)
    points = torch.tensor(sampler.generate(space.dimensions, rows * cols))
    # The sampler output is expected to be point-major (one row per point), so
    # transpose before reshaping; a direct reshape would mix coordinates of
    # different dimensions within each output channel.
    return points.T.reshape(points.shape[-1], rows, cols)

def halton_sample(space, n_samples):
    """Generate Halton-sequence points for ``space`` on a 2-D grid layout.

    Args:
        space: Object whose ``dimensions`` attribute is passed to
            ``Halton().generate``.
        n_samples: Pair ``(rows, cols)``; ``rows * cols`` points are generated.

    Returns:
        Tensor of shape ``(n_dims, rows, cols)`` in which ``out[d, r, c]`` is
        coordinate ``d`` of point number ``r * cols + c``.
    """
    rows, cols = n_samples[0], n_samples[1]
    points = torch.tensor(Halton().generate(space.dimensions, rows * cols))
    # The sampler output is expected to be point-major (one row per point), so
    # transpose before reshaping; a direct reshape would mix coordinates of
    # different dimensions within each output channel.
    return points.T.reshape(points.shape[-1], rows, cols)

def hammersly_sample(space, n_samples):
    """Generate Hammersley-sequence points for ``space`` on a 2-D grid layout.

    Args:
        space: Object whose ``dimensions`` attribute is passed to
            ``Hammersly().generate``.
        n_samples: Pair ``(rows, cols)``; ``rows * cols`` points are generated.

    Returns:
        Tensor of shape ``(n_dims, rows, cols)`` in which ``out[d, r, c]`` is
        coordinate ``d`` of point number ``r * cols + c``.
    """
    rows, cols = n_samples[0], n_samples[1]
    points = torch.tensor(Hammersly().generate(space.dimensions, rows * cols))
    # The sampler output is expected to be point-major (one row per point), so
    # transpose before reshaping; a direct reshape would mix coordinates of
    # different dimensions within each output channel.
    return points.T.reshape(points.shape[-1], rows, cols)
    
def grid_sample(space, n_samples):
    """Generate regular-grid points for ``space`` on a 2-D grid layout.

    The sampler is configured with ``border="include"`` and
    ``use_full_layout=False``.

    Args:
        space: Object whose ``dimensions`` attribute is passed to
            ``Grid.generate``.
        n_samples: Pair ``(rows, cols)``; ``rows * cols`` points are generated.

    Returns:
        Tensor of shape ``(n_dims, rows, cols)`` in which ``out[d, r, c]`` is
        coordinate ``d`` of point number ``r * cols + c``.
    """
    rows, cols = n_samples[0], n_samples[1]
    sampler = Grid(border="include", use_full_layout=False)
    points = torch.tensor(sampler.generate(space.dimensions, rows * cols))
    # The sampler output is expected to be point-major (one row per point), so
    # transpose before reshaping; a direct reshape would mix coordinates of
    # different dimensions within each output channel.
    return points.T.reshape(points.shape[-1], rows, cols)

# Lookup table from sampler name to the function that produces the samples.
# NOTE: the key 'solab' (sic) is kept as spelled, since callers may look it up.
_sampling_factory = dict(
    rvs=rvs_sample,
    solab=solab_sample,
    halton=halton_sample,
    hammersly=hammersly_sample,
    grid=grid_sample,
    lhs=lhs_sample,
)
    