import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def generate_spiral(n_points, n_spirals=2, noise=0.5):
    """
    Generate a noisy 2D spiral dataset.

    Every spiral arm is sampled at the same ``n_points`` parameter values
    ``t``, evenly spaced over ``[0, 4 * pi]``. Arm ``i`` has radius
    ``(1 + 0.1 * i) * t`` and is rotated by ``i * pi / n_spirals`` radians.
    Independent Gaussian noise is added to the x and y coordinates.

    Args:
    n_points (int): Number of points sampled on EACH spiral arm (the result
        therefore holds ``n_points * n_spirals`` rows in total).
    n_spirals (int): Number of spirals. Default is 2.
    noise (float): Standard deviation of the zero-mean Gaussian noise added
        to each coordinate. Default is 0.5.

    Returns:
    numpy.ndarray: Array of shape ``(n_points * n_spirals, 3)`` whose columns
        are x, y and the parameter t. Rows are grouped arm by arm. If
        ``n_spirals`` is less than 1 an empty ``(0, 3)`` array is returned.
    """
    # Without any arm there is nothing to stack; np.vstack would raise on an
    # empty list, so hand back a well-formed empty result instead.
    if n_spirals < 1:
        return np.empty((0, 3))

    # Shared "time" parameter for all arms
    t = np.linspace(0, 4 * np.pi, n_points)  # adjust range as needed

    arms = []
    for i in range(n_spirals):
        radius = (1 + 0.1 * i) * t
        angle = t + i * np.pi / n_spirals
        # Noise is drawn x first, then y, for each arm in turn; keeping this
        # order preserves the output for a given np.random seed.
        x = radius * np.cos(angle) + np.random.normal(0, noise, n_points)
        y = radius * np.sin(angle) + np.random.normal(0, noise, n_points)
        arms.append(np.column_stack((x, y, t)))

    return np.vstack(arms)

# Build a single-arm spiral with 1000 samples and store it as comma-separated
# rows, one sample per line.
output = generate_spiral(1000, 1)
np.savetxt('spiral.csv', output, delimiter=',')

# Draw the samples on a square canvas, coloring each point by the third
# column of the data, then write the picture to disk.
x_col, y_col, t_col = output.T
fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(x_col, y_col, s=1, c=t_col, cmap='Spectral')
ax.axis('equal')  # identical scaling on both axes keeps the spiral round
fig.savefig('spiral.png', dpi=300)
