# Author: Yu-Chia Chen <yuchaz@uw.edu>
# LICENSE: Simplified BSD https://github.com/yuchaz/homology_emb/blob/main/LICENSE

from __future__ import print_function, absolute_import, division
import numpy as np
from homology_emb import choose_n_farthest_points
import collections


def punctplane(seed=42, **kwargs):
    """
    Build a planar grid with regions removed (a "punctured plane").

    The base grid comes from ``generate_2d_grids(ratio=1)``; only its x and y
    columns are used, so the random z column it produces does not affect the
    result.

    --- parameters ---
    seed: accepted for a uniform generator signature; not used here
    kwargs: accepted and ignored

    --- returns ---
    dict with keys
        point_cloud: (n_points, 2) array of the remaining grid points
        delta: 1
        beta1: 2
    """
    data_orig = generate_2d_grids(ratio=1)[0][:, :2]

    # ord=inf along axis 1 is the max-abs (Chebyshev) distance, so each mask
    # excludes a square of half-width 0.3 centred at (1, 0) / (-1, 0).
    # np.inf is used instead of the np.infty alias, which NumPy 2.0 removed.
    ix1 = np.linalg.norm(data_orig - [[1, 0]], np.inf, 1) > .3
    ix2 = np.linalg.norm(data_orig - [[-1, 0]], np.inf, 1) > .3
    # Keep points with (|x| > 0.2 or |y| < 0.3), and additionally drop every
    # point with |x| < 0.05.
    ix3 = np.logical_and(np.logical_or(abs(data_orig[:, 0]) > .2,
                                       abs(data_orig[:, 1]) < .3),
                         ~(abs(data_orig[:, 0]) < .05))
    rix = np.logical_and(np.logical_and(ix1, ix2), ix3)

    # data_orig already holds only the x and y columns.
    data = data_orig[rix]

    return dict(
        point_cloud=data,
        delta=1,
        beta1=2
    )


def generate_2d_grids(N_total=5000, height=2, ratio=3, sigma_z=0.05):
    """
    Generate a 2D grid with Gaussian noise as a third coordinate.

    The grid is produced with np.meshgrid, so indices can be traced back to
    grid positions.

    --- parameters ---
    N_total: approximate number of grid points; the actual count is
        N * round(N * ratio) with N = int(sqrt(N_total / ratio)), bumped to
        the next odd number when even
    height: y spans [-height, height]
    ratio: width / height ratio, x spans [-height * ratio, height * ratio];
        may be a non-integer value
    sigma_z: standard deviation of the noise stored in the third column

    --- returns ---
    data_all: (n_points, 3) array with columns x, y, z-noise
    x_zero_idx: indices of the rows whose x is exactly zero
    y_zero_idx: indices of the rows whose y is exactly zero
    """
    N = int(np.sqrt(N_total / ratio))
    if N % 2 == 0:
        # An odd number of samples puts a grid line on y == 0.
        N += 1
    width = height * ratio
    # np.linspace requires an integer sample count; rounding keeps integer
    # ratios unchanged and makes fractional ratios usable.
    n_x = int(round(N * ratio))
    grid_x = np.linspace(-width, width, n_x)
    grid_y = np.linspace(-height, height, N)
    x, y = np.meshgrid(grid_x, grid_y)
    data = np.vstack([x.flatten(), y.flatten()]).T
    x_zero_idx = (data[:, 0] == 0.0).nonzero()[0]
    y_zero_idx = (data[:, 1] == 0.0).nonzero()[0]
    # Noise is drawn from the global NumPy random state (not seeded here).
    z_noise = np.random.normal(0, sigma_z, data.shape[0])
    data_all = np.hstack([data, z_noise[:, None]])
    return data_all, x_zero_idx, y_zero_idx


def torus(n=1200, outer_rad=1, inner_rad=.5, height=1, primary_sigma=0.03,
          additional_dims=10, additional_sigma=0.01, **kwargs):
    """
    Create noisy torus data on a near uniform grid.

    With k = int(sqrt(n)), k tube angles (phi) are used; for each of them the
    number of samples around the ring (theta) is proportional to that ring's
    radius, so the total number of points is generally not exactly n.

    --- parameters ---
    n: controls the grid resolution through k = int(sqrt(n))
    outer_rad: distance from the torus centre to the centre of the tube
    inner_rad: radius of the tube
    height: multiplicative factor applied to the z coordinate
    primary_sigma: noise level added to the x, y, z coordinates
    additional_dims: number of pure-noise columns appended
    additional_sigma: noise level of the appended columns
    kwargs: accepted and ignored

    --- returns ---
    dict with keys
        point_cloud: (n_points, 3 + additional_dims) noisy torus
        delta: 1.1
        beta1: 2
    """
    k = int(np.sqrt(n))
    tube_angles = np.linspace(0, 2*np.pi, k, endpoint=False)

    ring_samples = []
    tube_samples = []
    for tube_angle in tube_angles:
        # Sample count on this ring scales with the ring's radius.
        n_on_ring = int(np.rint(
            k * (outer_rad + inner_rad * np.cos(tube_angle)) / float(outer_rad)))
        ring_samples.append(np.linspace(0, 2*np.pi, n_on_ring, endpoint=False))
        tube_samples.append(np.repeat(tube_angle, n_on_ring))
    theta_all = np.hstack(ring_samples)
    phi_all = np.hstack(tube_samples)

    ring_radius = outer_rad + inner_rad*np.cos(phi_all)
    clean = np.vstack([
        ring_radius * np.cos(theta_all),
        ring_radius * np.sin(theta_all),
        inner_rad * np.sin(phi_all) * height,
    ]).T

    # Perturb the three embedding coordinates first, then append the
    # pure-noise columns.
    noisy = add_noises_on_additional_dimensions(
        add_noises_on_primary_dimensions(clean, primary_sigma),
        addition_dims=additional_dims, sigmas=additional_sigma)

    return dict(
        point_cloud=noisy,
        delta=1.1,
        beta1=2
    )


def three_torus(n0=100000, seed=42):
    """
    Sample a three-torus embedded in 4D and subsample it to 2000 points.

    Three angles are drawn uniformly from [0, 2*pi) with a RandomState seeded
    by ``seed`` and mapped to (w, x, y, z) coordinates. 2000 of the n0 samples
    are then selected with choose_n_farthest_points.

    --- parameters ---
    n0: number of random samples drawn before subsampling
    seed: seed for both the angle sampling and the subsampling call

    --- returns ---
    dict with keys
        point_cloud: selected rows of the (n0, 4) embedded samples
        intrinsic_coord: the (theta, phi, varphi) angles of the selected rows
        delta: 1
        beta1: 3
    """
    rng = np.random.RandomState(seed=seed)
    # Draw order theta -> phi -> varphi determines which random numbers each
    # angle receives.
    theta, phi, varphi = [rng.uniform(0, 2*np.pi, n0) for _ in range(3)]

    inner = 2 + np.cos(theta)
    outer = 4 + inner*np.cos(phi)
    embedded = np.vstack([
        outer*np.cos(varphi),
        outer*np.sin(varphi),
        inner*np.sin(phi),
        np.sin(theta),
    ]).T
    angles = np.vstack([theta, phi, varphi]).T

    keep = choose_n_farthest_points(embedded, 2000, seed)

    return dict(
        point_cloud=embedded[keep],
        intrinsic_coord=angles[keep],
        delta=1,
        beta1=3
    )


def genus_two(seed=42, **kwargs):
    """
    Sample an implicit surface and subsample it to 1500 points.

    The surface is z**2 = 0.01 - ((x**2 + y**2)**2 - .75*x**2 + .75*y**2)**2,
    evaluated on a 1000 x 1000 grid over [-1, 1] x [-1, 1]. Grid points with a
    real solution are kept, together with their mirror image in z.

    --- parameters ---
    seed: forwarded to choose_n_farthest_points; defaults to 42, the value
        that was previously hard-coded
    kwargs: accepted and ignored

    --- returns ---
    dict with keys
        point_cloud: the selected rows of the (x, y, z) samples
        delta: 0.9
        beta1: 4
    """
    def solve_for_z(x, y):
        # Complex arithmetic turns a negative radicand into an imaginary
        # value instead of NaN, so np.isreal can tell the cases apart.
        x, y = map(lambda v: np.array(v).astype('complex'), [x, y])
        return np.sqrt(
            0.01 - ((x**2 + y**2) ** 2 - .75*x**2 + .75*y**2)**2
        )
    aa = np.linspace(-1, 1, 1000)
    xx, yy = np.meshgrid(aa, aa)
    xx = xx.flatten()
    yy = yy.flatten()
    zz = solve_for_z(xx, yy)
    ix_on_surface = np.isreal(zz)

    # Upper sheet (z >= 0).
    data_pos = np.vstack([
        xx[ix_on_surface],
        yy[ix_on_surface],
        np.real(zz[ix_on_surface]),
    ]).T

    # Lower sheet: same (x, y) with the sign of z flipped.
    data_neg = np.vstack([
        xx[ix_on_surface],
        yy[ix_on_surface],
        -np.real(zz[ix_on_surface]),
    ]).T
    data = np.vstack([data_pos, data_neg])
    # Honour the caller's seed (load_data passes one) instead of a fixed 42.
    chosen_ix = choose_n_farthest_points(data, 1500, seed)
    data_orig = data[chosen_ix]
    return dict(
        point_cloud=data_orig,
        delta=0.9,
        beta1=4
    )


def tori_concat(**kwargs):
    """
    Stack four tori shifted along the first coordinate.

    Each torus is generated by ``torus(**kwargs)`` and has one of the offsets
    -3, 0, 3, 6 subtracted from its first column.

    --- parameters ---
    kwargs: forwarded unchanged to every torus() call

    --- returns ---
    dict with keys
        point_cloud: the four shifted point clouds stacked row-wise
        delta: 1
        beta1: 8
    """
    def shifted_torus(offset):
        cloud = torus(**kwargs)['point_cloud']
        cloud[:, 0] -= offset
        return cloud

    offsets = np.array([-1, 0, 1, 2]) * 3
    return dict(
        point_cloud=np.vstack([shifted_torus(offset) for offset in offsets]),
        delta=1,
        beta1=8
    )


def genereate_noises(sigmas, size, dimensions, seed=123):
    """
    Draw zero-mean Gaussian noise of shape (size, dimensions).

    The global NumPy random state is re-seeded with ``seed`` on every call,
    so identical arguments give identical noise.

    --- parameters ---
    sigmas: scalar, or a sequence / ndarray with one entry per dimension
    size: number of rows (samples)
    dimensions: number of columns
    seed: value passed to np.random.seed

    --- returns ---
    (size, dimensions) array of noise

    --- raises ---
    AssertionError: if sigmas is array-like and its length differs from
        dimensions
    """
    # collections.Sequence was removed in Python 3.10; the ABC lives in
    # collections.abc. The fallback keeps Python 2 working.
    try:
        from collections.abc import Sequence
    except ImportError:
        from collections import Sequence

    np.random.seed(seed)
    if isinstance(sigmas, (Sequence, np.ndarray)):
        assert len(sigmas) == dimensions, \
            'The size of sigmas should be the same as noises dimensions'
        # NOTE(review): here sigmas fill the diagonal of the covariance
        # matrix (i.e. act as variances), whereas the scalar branch passes
        # sigmas as the standard deviation. Confirm which is intended.
        return np.random.multivariate_normal(np.zeros(dimensions),
                                             np.diag(sigmas), size)
    return np.random.normal(0, sigmas, [size, dimensions])


def add_noises_on_primary_dimensions(data, sigmas=0.1, seed=123):
    """
    Add Gaussian noise to every entry of ``data``.

    --- parameters ---
    data: 2D array, (n_points, n_dims)
    sigmas: noise level(s) passed to genereate_noises
    seed: seed passed to genereate_noises

    --- returns ---
    new array of the same shape as data; data itself is not modified
    """
    n_points, n_dims = data.shape
    # Forward the seed; it used to be accepted but silently dropped.
    noises = genereate_noises(sigmas, n_points, n_dims, seed)
    return data + noises


def add_noises_on_additional_dimensions(data, addition_dims, sigmas=1,
                                        seed=123):
    """
    Append pure-noise columns to ``data``.

    --- parameters ---
    data: 2D array, (n_points, n_dims)
    addition_dims: number of noise columns to append
    sigmas: noise level(s) passed to genereate_noises
    seed: seed passed to genereate_noises

    --- returns ---
    (n_points, n_dims + addition_dims) array: data followed by the noise
    """
    n_points = data.shape[0]
    extra_columns = genereate_noises(sigmas, n_points, addition_dims, seed)
    return np.hstack((data, extra_columns))


# Registry used by load_data: lower-case dataset alias -> generator function.
datagen_dict = dict((
    ('punctplane', punctplane),
    ('torus', torus),
    ('3-torus', three_torus),
    ('genus-2', genus_two),
    ('tori-concat', tori_concat),
))


def load_data(alias_data, seed=42, **kwargs):
    """
    Generate the dataset registered under ``alias_data``.

    --- parameters ---
    alias_data: key of datagen_dict; matched case-insensitively (lower-cased
        before the lookup)
    seed: passed to the generator as the ``seed`` keyword
    kwargs: passed through to the generator

    --- returns ---
    Dataset wrapping the dict returned by the generator

    --- raises ---
    KeyError: if the lower-cased alias is not in datagen_dict
    """
    generator = datagen_dict[alias_data.lower()]
    return Dataset(generator(seed=seed, **kwargs))


class Dataset(object):
    """Container exposing the entries of a dict as instance attributes."""

    def __init__(self, datadict):
        # Copy every key/value pair straight into the instance namespace.
        vars(self).update(datadict)
