from typing import Callable

import chex
import einops
import jax
import jax.numpy as jnp

# A function that takes a PRNG key and returns an array of support points.
# This is the type accepted as `support_init` by `repeated_map` and
# `independent_map`, and returned by the support factories in this module.
SupportInitializer = Callable[[chex.PRNGKey], chex.Array]
# Same call signature as `SupportInitializer`; used as the return type of
# `repeated_map` / `independent_map`, whose outputs carry an extra leading
# axis of size `n` in front of the base initializer's output.
SupportMapInitializer = Callable[[chex.PRNGKey], chex.Array]


def repeated_map(support_init: SupportInitializer, n: int) -> SupportMapInitializer:
    """Lift `support_init` to a map made of `n` identical copies of one draw.

    Args:
        support_init: Base initializer, called exactly once per invocation
            of the returned function.
        n: Number of copies to place along the new leading axis.

    Returns:
        A function of a PRNG key whose output is the base initializer's
        output repeated `n` times along a new first axis.
    """

    def _tiled(key: chex.PRNGKey) -> chex.Array:
        # One draw from the base initializer, shared by every copy.
        single = support_init(key)
        tiled = einops.repeat(single, "... -> n ...", n=n)
        return tiled

    return _tiled


def independent_map(support_init: SupportInitializer, n: int) -> SupportMapInitializer:
    """Lift `support_init` to a map made of `n` independently drawn supports.

    Args:
        support_init: Base initializer; it is vmapped over a batch of keys.
        n: Number of sub-keys to split the incoming key into, and therefore
            the size of the new leading axis.

    Returns:
        A function of a PRNG key that splits the key `n` ways and evaluates
        `support_init` once per sub-key, stacking results on the first axis.
    """
    batched_init = jax.vmap(support_init)

    def _stacked(key: chex.PRNGKey) -> chex.Array:
        # Each entry along the leading axis gets its own sub-key.
        sub_keys = jax.random.split(key, n)
        return batched_init(sub_keys)

    return _stacked


def uniform_random_support(
    d: int,
    num_atoms: int,
    maxval: float | chex.Array,
    minval: float | chex.Array | None = None,
) -> Callable[[chex.PRNGKey], chex.Array]:
    """Build an initializer that samples atoms uniformly from a box.

    Args:
        d: Number of coordinates per atom.
        num_atoms: Number of atoms to sample.
        maxval: Upper bound of the box. A scalar (Python number, NumPy
            scalar or 0-d array) is applied to every coordinate; otherwise
            it must have one entry per coordinate (leading axis of size `d`).
        minval: Lower bound of the box, with the same conventions as
            `maxval`. Defaults to `-maxval`.

    Returns:
        A function of a PRNG key returning an array of shape
        `(num_atoms, d)`; column `i` is drawn uniformly between `minval[i]`
        and `maxval[i]` using its own sub-key.
    """
    # Detect scalars by rank rather than `isinstance(..., float)` so that
    # ints, NumPy scalars and 0-d arrays are broadcast to per-coordinate
    # bounds as well; the vmap below needs a leading axis of size `d`.
    if jnp.ndim(maxval) == 0:
        maxval = maxval * jnp.ones(d)
    if minval is None:
        minval = -maxval
    elif jnp.ndim(minval) == 0:
        minval = minval * jnp.ones(d)

    def _map(rng: chex.PRNGKey) -> chex.Array:
        def unif(key, lo, hi):
            # The sample shape is static, so it is closed over instead of
            # being threaded through vmap as an unbatched argument.
            return jax.random.uniform(key, shape=(num_atoms,), minval=lo, maxval=hi)

        # One sub-key and one (lo, hi) pair per coordinate; out_axes=-1
        # places the coordinate axis last, giving (num_atoms, d).
        return jax.vmap(unif, in_axes=(0, 0, 0), out_axes=-1)(
            jax.random.split(rng, d), minval, maxval
        )

    return _map


def uniform_lattice(
    d: int,
    bins_per_dim: int,
    maxval: float | chex.Array,
    minval: float | chex.Array | None = None,
) -> SupportInitializer:
    """Build an initializer that returns a regular grid of atoms in a box.

    Args:
        d: Number of coordinates per atom (any `d >= 1`).
        bins_per_dim: Number of evenly spaced grid values per coordinate,
            endpoints included.
        maxval: Upper bound of the box. A scalar (Python number, NumPy
            scalar or 0-d array) is applied to every coordinate; otherwise
            it must have one entry per coordinate (leading axis of size `d`).
        minval: Lower bound of the box, with the same conventions as
            `maxval`. Defaults to `-maxval`.

    Returns:
        A function of a PRNG key (which is ignored) returning an array of
        shape `(bins_per_dim ** d, d)` holding every combination of the
        per-coordinate grid values. Rows are in row-major order: the last
        coordinate varies fastest.
    """
    # Detect scalars by rank rather than `isinstance(..., float)` so that
    # ints, NumPy scalars and 0-d arrays are broadcast to per-coordinate
    # bounds as well; the vmap below needs a leading axis of size `d`.
    if jnp.ndim(maxval) == 0:
        maxval = maxval * jnp.ones(d)
    if minval is None:
        minval = -maxval
    elif jnp.ndim(minval) == 0:
        minval = minval * jnp.ones(d)

    def _map(key: chex.PRNGKey) -> chex.Array:
        del key  # The lattice is deterministic; the key is unused.
        # Shape (d, bins_per_dim): one linspace per coordinate.
        dim_bins = jax.vmap(jnp.linspace, in_axes=(0, 0, None))(
            minval, maxval, bins_per_dim
        )
        # Cartesian product over all d coordinates. "ij" indexing keeps
        # coordinate 0 on the slowest-varying axis, matching the row order
        # of the former two-dimensional-only nested-vmap construction.
        mesh = jnp.meshgrid(*dim_bins, indexing="ij")
        all_bins = jnp.stack(mesh, axis=-1)
        return jnp.reshape(all_bins, (-1, d))

    return _map


def explicit_support(support: chex.Array) -> SupportInitializer:
    """Wrap a fixed array as an initializer.

    Args:
        support: The array to hand back on every call.

    Returns:
        A function that ignores its single argument and returns `support`
        unchanged.
    """

    def _constant(_):
        # The argument is accepted only to satisfy the initializer signature.
        return support

    return _constant
