import jax.random as random
import jax.numpy as jnp


def elegant_pairing(x: int, y: int):
    """Map a pair of (possibly negative) integers to one non-negative integer.

    Each coordinate is first folded onto the non-negative integers
    (0, -1, 1, -2, 2, ... -> 0, 1, 2, 3, 4, ...); Szudzik's pairing function
    is then applied to the folded pair (http://szudzik.com/ElegantPairing.pdf).
    Works elementwise on arrays via ``jnp.where``.
    """

    def _fold(v):
        # Non-negative values go to even numbers, negative values to odd ones.
        return jnp.where(v < 0, -2 * v - 1, 2 * v)

    a, b = _fold(x), _fold(y)

    below_diagonal = a * a + a + b   # case a >= b
    above_diagonal = b * b + a       # case a < b
    return jnp.where(a < b, above_diagonal, below_diagonal)


def log1mexp(x):
    """Compute log(1 - exp(-|x|)) without catastrophic cancellation.

    The sign of ``x`` is ignored. For |x| < log(2) the expm1 form is used
    (1 - exp(-t) is small there); otherwise the log1p form is used
    (exp(-t) is small there).
    """
    t = jnp.abs(x)
    via_expm1 = jnp.log(-jnp.expm1(-t))
    via_log1p = jnp.log1p(-jnp.exp(-t))
    return jnp.where(t >= jnp.log(2.), via_log1p, via_expm1)


def logsubexp(x, y, return_sign=False):
    """Return log|exp(x) - exp(y)|, computed in log space.

    The difference is rewritten as hi + log(1 - exp(-(hi - lo))) with
    hi = max(x, y) and lo = min(x, y), which is evaluated by ``log1mexp``.
    Wherever hi is -inf (both inputs -inf) the result is forced to -inf.

    If ``return_sign`` is true, a second array is returned holding the sign
    of exp(x) - exp(y): -1. where x < y and 1. otherwise (including x == y).
    """
    hi = jnp.maximum(x, y)
    gap = jnp.maximum(hi - jnp.minimum(x, y), 0.)
    out = jnp.where(hi == -jnp.inf, -jnp.inf, hi + log1mexp(gap))

    if not return_sign:
        return out
    return out, jnp.where(x < y, -1., 1.)


def trunc_gumbel(key, shape, loc, bound):
    """
    Sample Gumbel(loc) variates conditioned to lie at or below ``bound``.

    Inverse-CDF form of the upper-truncated Gumbel:
        g = loc - log(E + exp(loc - bound)),   E = -log(U),  U ~ Uniform(0, 1)
    It is evaluated as ``loc - logaddexp(log(E), loc - bound)`` so that
    ``exp(loc - bound)`` cannot overflow when ``bound`` lies far below ``loc``
    (the naive form returns -inf there instead of a value close to ``bound``).

    Args:
        key: JAX PRNG key.
        shape: shape passed to ``random.uniform`` for the uniform draws.
        loc: Gumbel location; broadcast against ``shape``.
        bound: upper truncation point; samples are <= ``bound`` up to rounding.
            ``bound = inf`` yields ordinary Gumbel(loc) samples.

    Returns:
        Array of samples with the broadcast shape of ``shape``, ``loc`` and
        ``bound``.
    """
    # A strictly positive minval keeps -log(u) finite: u == 0 would otherwise
    # produce an infinite E and hence a -inf sample.
    u = random.uniform(key, shape, minval=jnp.finfo(jnp.float32).tiny, maxval=1.)
    log_e = jnp.log(-jnp.log(u))

    return loc - jnp.logaddexp(log_e, loc - bound)


def laplace_kl(q_loc, q_scale, p_loc, p_scale):
    """
    KL divergence KL(q || p) between two Laplace distributions, elementwise.

    With delta = |q_loc - p_loc|:
        KL = log(p_scale / q_scale) + delta / p_scale
             + (q_scale / p_scale) * exp(-delta / q_scale) - 1

    Args:
        q_loc, q_scale: location and (positive) scale of q.
        p_loc, p_scale: location and (positive) scale of p.

    Returns:
        The KL divergence, broadcast over the inputs.
    """
    delta = jnp.abs(q_loc - p_loc)
    log_scale_ratio = jnp.log(p_scale) - jnp.log(q_scale)  # log(p_scale / q_scale)

    # (q_scale / p_scale) * exp(-delta / q_scale), folded into a single exp.
    scaled_exp = jnp.exp(-(log_scale_ratio + delta / q_scale))

    # Previously `kl = log_scale_ratio = ...` silently dropped the log term.
    kl = log_scale_ratio + scaled_exp
    kl = kl + delta / p_scale - 1.

    return kl