import numpy as np
import torch


def sample_n(start, end, n, sampling):
    """Produce ``n`` abscissae covering the interval from ``start`` to ``end``.

    Args:
        start: Lower bound of the interval.
        end: Upper bound of the interval.
        n: Number of points to produce.
        sampling: ``'uniform'`` for evenly spaced points that include both
            endpoints (``np.linspace``), or ``'random'`` for ``n`` draws from
            NumPy's global random generator (``np.random.rand``) scaled onto
            the interval and sorted ascending.

    Returns:
        A 1-D NumPy array of length ``n``, in ascending order.

    Raises:
        ValueError: If ``sampling`` is neither ``'uniform'`` nor ``'random'``.
    """
    if sampling == 'uniform':
        return np.linspace(start, end, n)
    if sampling != 'random':
        raise ValueError("Invalid sampling type")

    # Map draws from [0, 1) onto the interval; sorting keeps the output
    # monotone so it is ordered like the uniform grid.
    draws = np.random.rand(n)
    return np.sort((end - start) * draws + start)


def load_triangle_1(n, dtype=np.float32, sampling='uniform'):
    """Sample the tent function ``y = 1 - |x|`` on [-1, 1].

    Args:
        n: Number of sample points.
        dtype: NumPy dtype both arrays are cast to before tensor conversion.
        sampling: Point placement mode forwarded to ``sample_n``.

    Returns:
        Tuple ``(y, x)`` of torch tensors, each of shape ``(n, 1)``.
    """
    grid = sample_n(-1, 1, n, sampling)
    tent = 1 - np.abs(grid)
    return tuple(
        torch.from_numpy(col.astype(dtype)).unsqueeze(1) for col in (tent, grid)
    )


def load_triangle_2(n, dtype=np.float32, sampling='uniform'):
    """Sample a tent on [-1, 1] followed by a rising ramp on (1, 2].

    The target is ``1 - |x|`` where ``-1 <= x <= 1`` and ``x - 1``
    otherwise. Both pieces are 0 at ``x = 1``, so the curve is continuous.

    Args:
        n: Number of sample points.
        dtype: NumPy dtype both arrays are cast to before tensor conversion.
        sampling: Point placement mode forwarded to ``sample_n``.

    Returns:
        Tuple ``(y, x)`` of torch tensors, each of shape ``(n, 1)``.
    """
    x = sample_n(-1, 2, n, sampling)

    # Vectorised piecewise definition (replaces a per-element Python loop);
    # both branches are evaluated on the whole array and selected by mask.
    in_tent = (x >= -1) & (x <= 1)
    y = np.where(in_tent, -np.abs(x) + 1, x - 1)

    return torch.from_numpy(y.astype(dtype)).unsqueeze(1), torch.from_numpy(x.astype(dtype)).unsqueeze(1)


def load_parabola(n, dtype=np.float32, sampling='uniform'):
    """Sample the parabola ``y = x**2`` on [-1, 1].

    Args:
        n: Number of sample points.
        dtype: NumPy dtype both arrays are cast to before tensor conversion.
        sampling: Point placement mode forwarded to ``sample_n``.

    Returns:
        Tuple ``(y, x)`` of torch tensors, each of shape ``(n, 1)``.
    """
    grid = sample_n(-1, 1, n, sampling)
    values = grid**2
    return tuple(
        torch.from_numpy(col.astype(dtype)).unsqueeze(1) for col in (values, grid)
    )


def load_cubic(n, dtype=np.float32, sampling='uniform'):
    """Sample the cubic ``y = x**3`` on [-1, 1].

    Args:
        n: Number of sample points.
        dtype: NumPy dtype both arrays are cast to before tensor conversion.
        sampling: Point placement mode forwarded to ``sample_n``.

    Returns:
        Tuple ``(y, x)`` of torch tensors, each of shape ``(n, 1)``.
    """
    grid = sample_n(-1, 1, n, sampling)
    values = grid**3
    return tuple(
        torch.from_numpy(col.astype(dtype)).unsqueeze(1) for col in (values, grid)
    )


def load_quartic(n, dtype=np.float32, sampling='uniform'):
    """Sample the quartic ``y = x**4`` on [-1, 1].

    Args:
        n: Number of sample points.
        dtype: NumPy dtype both arrays are cast to before tensor conversion.
        sampling: Point placement mode forwarded to ``sample_n``.

    Returns:
        Tuple ``(y, x)`` of torch tensors, each of shape ``(n, 1)``.
    """
    x = sample_n(-1, 1, n, sampling)
    # Degree 4 to match the name; this previously computed x**5 (a quintic).
    y = x**4

    return torch.from_numpy(y.astype(dtype)).unsqueeze(1), torch.from_numpy(x.astype(dtype)).unsqueeze(1)


def load_parabola2(n, dtype=np.float32, sampling='uniform'):
    """Sample the shifted parabola ``y = (x - 1)**2 - 1`` on [-1, 1].

    Args:
        n: Number of sample points.
        dtype: NumPy dtype both arrays are cast to before tensor conversion.
        sampling: Point placement mode forwarded to ``sample_n``.

    Returns:
        Tuple ``(y, x)`` of torch tensors, each of shape ``(n, 1)``.
    """
    grid = sample_n(-1, 1, n, sampling)
    shifted = grid - 1
    values = shifted**2 - 1
    return tuple(
        torch.from_numpy(col.astype(dtype)).unsqueeze(1) for col in (values, grid)
    )


def load_sine1(n, dtype=np.float32, sampling='uniform'):
    """Sample ``y = sin(pi * x)`` on [-1, 1] (one full period).

    Args:
        n: Number of sample points.
        dtype: NumPy dtype both arrays are cast to before tensor conversion.
        sampling: Point placement mode forwarded to ``sample_n``.

    Returns:
        Tuple ``(y, x)`` of torch tensors, each of shape ``(n, 1)``.
    """
    grid = sample_n(-1, 1, n, sampling)
    phase = np.pi * grid
    values = np.sin(phase)
    return tuple(
        torch.from_numpy(col.astype(dtype)).unsqueeze(1) for col in (values, grid)
    )


def load_sine2(n, dtype=np.float32, sampling='uniform'):
    """Sample ``y = sin(2 * pi * x)`` on [-1, 1] (two full periods).

    Args:
        n: Number of sample points.
        dtype: NumPy dtype both arrays are cast to before tensor conversion.
        sampling: Point placement mode forwarded to ``sample_n``.

    Returns:
        Tuple ``(y, x)`` of torch tensors, each of shape ``(n, 1)``.
    """
    grid = sample_n(-1, 1, n, sampling)
    phase = 2*np.pi*grid
    values = np.sin(phase)
    return tuple(
        torch.from_numpy(col.astype(dtype)).unsqueeze(1) for col in (values, grid)
    )


def load_sine22(n, dtype=np.float32, sampling='uniform'):
    """Sample ``y = sin(2 * pi * x) - 1`` on [-1, 1].

    This is the two-period sine shifted down by 1.

    Args:
        n: Number of sample points.
        dtype: NumPy dtype both arrays are cast to before tensor conversion.
        sampling: Point placement mode forwarded to ``sample_n``.

    Returns:
        Tuple ``(y, x)`` of torch tensors, each of shape ``(n, 1)``.
    """
    grid = sample_n(-1, 1, n, sampling)
    phase = 2*np.pi*grid
    values = np.sin(phase) - 1
    return tuple(
        torch.from_numpy(col.astype(dtype)).unsqueeze(1) for col in (values, grid)
    )


def load_sine5(n, dtype=np.float32, sampling='uniform'):
    """Sample ``y = sin(2.5 * pi * x)`` on [-1, 1] (2.5 periods).

    Args:
        n: Number of sample points.
        dtype: NumPy dtype both arrays are cast to before tensor conversion.
        sampling: Point placement mode forwarded to ``sample_n``.

    Returns:
        Tuple ``(y, x)`` of torch tensors, each of shape ``(n, 1)``.
    """
    grid = sample_n(-1, 1, n, sampling)
    phase = 2.5*np.pi*grid
    values = np.sin(phase)
    return tuple(
        torch.from_numpy(col.astype(dtype)).unsqueeze(1) for col in (values, grid)
    )


def load_line(n, dtype=np.float32, sampling='uniform'):
    """Sample the line ``y = x + 5`` on [-1, 1].

    Args:
        n: Number of sample points.
        dtype: NumPy dtype both arrays are cast to before tensor conversion.
        sampling: Point placement mode forwarded to ``sample_n``.

    Returns:
        Tuple ``(y, x)`` of torch tensors, each of shape ``(n, 1)``.
    """
    grid = sample_n(-1, 1, n, sampling)
    # ``grid + 5`` already allocates a new array, so no explicit copy is needed.
    values = grid + 5
    return tuple(
        torch.from_numpy(col.astype(dtype)).unsqueeze(1) for col in (values, grid)
    )


def load_line_flat(n, dtype=np.float32, sampling='uniform'):
    """Sample the constant function ``y = 2`` on [-1, 1].

    Args:
        n: Number of sample points.
        dtype: NumPy dtype both arrays are cast to before tensor conversion.
        sampling: Point placement mode forwarded to ``sample_n``.

    Returns:
        Tuple ``(y, x)`` of torch tensors, each of shape ``(n, 1)``.
    """
    x = sample_n(-1, 1, n, sampling)
    # Single float64 fill; replaces a dead ``np.full(n, 0)`` store that was
    # immediately overwritten by ``np.zeros(n) + 2``.
    y = np.full(n, 2.0)

    return torch.from_numpy(y.astype(dtype)).unsqueeze(1), torch.from_numpy(x.astype(dtype)).unsqueeze(1)


def load_square_wave1(n, dtype=np.float32, sampling='uniform'):
    """Sample a single rectangular pulse over [-1, 1].

    The pulse is defined in *index* space, not in x: samples whose index
    lies in ``[n // 3, 2 * (n // 3))`` are 1 and all others are 0. Because
    ``sample_n`` returns points in ascending order, this selects a middle
    block of samples. For uniform sampling that block is roughly the middle
    third of the interval. When ``n`` is not a multiple of 3, the trailing
    zero block is longer than the leading one.

    Args:
        n: Number of sample points.
        dtype: NumPy dtype both arrays are cast to before tensor conversion.
        sampling: Point placement mode forwarded to ``sample_n``.

    Returns:
        Tuple ``(y, x)`` of torch tensors, each of shape ``(n, 1)``.
    """
    grid = sample_n(-1, 1, n, sampling)
    third = n // 3
    idx = np.arange(n)
    pulse = ((idx >= third) & (idx < 2 * third)).astype(dtype)
    return tuple(
        torch.from_numpy(col.astype(dtype)).unsqueeze(1) for col in (pulse, grid)
    )


def load_geometry(geometry_type, n, dtype=np.float32, sampling='uniform'):
    """Generate a toy 1-D dataset by name.

    Args:
        geometry_type: Name of the target curve, e.g. ``"triangle1"``,
            ``"sine2"`` or ``"square1"``.
        n: Number of sample points.
        dtype: NumPy dtype forwarded to the selected loader.
        sampling: Point placement mode forwarded to the selected loader.

    Returns:
        Whatever the selected ``load_*`` function returns: a ``(y, x)`` pair
        of ``(n, 1)`` torch tensors.

    Raises:
        ValueError: If ``geometry_type`` is not a registered name.
    """
    # Ordered (name, loader) pairs, scanned with ``==`` like the original chain.
    registry = (
        ("triangle1", load_triangle_1),
        ("triangle2", load_triangle_2),
        ("parabola", load_parabola),
        ("cubic", load_cubic),
        ("quartic", load_quartic),
        ("parabola2", load_parabola2),
        ("sine1", load_sine1),
        ("sine2", load_sine2),
        ("sine22", load_sine22),
        ("sine5", load_sine5),
        ("line", load_line),
        ("line-flat", load_line_flat),
        ("square1", load_square_wave1),
    )
    for name, loader in registry:
        if geometry_type == name:
            return loader(n, dtype, sampling)
    raise ValueError("Invalid Geometry Type!")
