from jax import random
import jax.numpy as jnp
from .util import log1mexp


def modu_pdf(u, kl2, num_modes, offset):
  """Evaluate a periodic rectangular ("modulated uniform") function at `u`.

  The input is mapped to `x = num_modes * (u - offset)`. Wherever the
  fractional part of `x` is strictly below `2**-kl2` the result is `2**kl2`;
  everywhere else it is 0. Each unit-length stretch of `u` therefore holds
  `num_modes` pulses of width `2**-kl2 / num_modes` and height `2**kl2`.

  Args:
    u: Point(s) at which to evaluate; a Python scalar or an array.
    kl2: Base-2 exponent controlling pulse height (`2**kl2`) and relative
      pulse width (`2**-kl2`).
    num_modes: Number of pulses per unit length of `u`.
    offset: Shift applied to `u` before the periodic folding.

  Returns:
    `2**kl2` where `u` falls inside a pulse and 0 elsewhere, broadcast to
    the shape of the inputs.
  """
  height = jnp.exp2(kl2)
  width = jnp.exp2(-kl2)

  x = num_modes * (u - offset)

  # Fractional part of x locates u inside its period.
  in_pulse = x - jnp.floor(x) < width

  # jnp.result_type accepts Python scalars as well as arrays, whereas
  # reading `u.dtype` directly fails for a plain float.
  return height * in_pulse.astype(jnp.result_type(u))


def modu_width_p_in_bounds(kl2, num_modes, offset, lower, upper):
  """Length of the pulse region of `modu_pdf` lying between two bounds.

  Periods have length `1 / num_modes` and start at `offset + k / num_modes`.
  Within each period the pulse occupies the first `2**-kl2 / num_modes`.
  The result is assembled from the whole periods between the period starts
  of `lower` and `upper`, minus the pulse part preceding `lower` in its
  period, plus the pulse part preceding `upper` in its period.

  Args:
    kl2: Base-2 exponent; the pulse fills a fraction `2**-kl2` of a period.
    num_modes: Number of periods per unit length.
    offset: Position at which a period starts.
    lower: Lower bound of the interval.
    upper: Upper bound of the interval.

  Returns:
    The total pulse length between `lower` and `upper`.
  """

  def period_start(point):
    # Left edge of the period that contains `point`.
    return jnp.floor(num_modes * (point - offset)) / num_modes + offset

  pulse_fraction = jnp.exp2(-kl2)
  pulse_length = pulse_fraction / num_modes

  start_of_lower = period_start(lower)
  start_of_upper = period_start(upper)

  # Pulse length contributed by the whole periods between the two starts.
  whole_periods = pulse_fraction * (start_of_upper - start_of_lower)

  # Pulse length between each period start and the bound, capped at one pulse.
  before_lower = jnp.minimum(lower - start_of_lower, pulse_length)
  before_upper = jnp.minimum(upper - start_of_upper, pulse_length)

  return whole_periods - before_lower + before_upper


def modu_width_q_in_bounds(kl2, num_modes, offset, lower, upper):
  """Scale the result of `modu_width_p_in_bounds` by the pulse height.

  Args:
    kl2: Base-2 exponent; the scale factor is `2**kl2`.
    num_modes: Forwarded to `modu_width_p_in_bounds`.
    offset: Forwarded to `modu_width_p_in_bounds`.
    lower: Lower bound of the interval.
    upper: Upper bound of the interval.

  Returns:
    `2**kl2` times the value of `modu_width_p_in_bounds` for the same
    arguments.
  """
  height = jnp.exp2(kl2)
  pulse_length = modu_width_p_in_bounds(kl2, num_modes, offset, lower, upper)
  return height * pulse_length


class ModuUniformDensityRatio:
  """Density-ratio object built on the periodic rectangular `modu_pdf`.

  Stores the three `modu_pdf` parameters together with two derived values:
  `kl = kl2 * log(2)` and `height = 2**kl2`.
  """

  def __init__(self, kl2, num_modes, offset):
    """Store the parameters and precompute `kl` and `height`.

    Args:
      kl2: Base-2 exponent passed on to the `modu_*` helpers.
      num_modes: Number of pulses per unit length.
      offset: Shift of the pulse pattern.
    """
    self.kl2 = kl2
    self.num_modes = num_modes
    self.offset = offset

    # Derived quantities: height is 2**kl2 and kl is kl2 rescaled by log(2).
    self.height = jnp.exp2(kl2)
    self.kl = kl2 * jnp.log(2)

    super().__init__()

  def log_ratio(self, u):
    """Return the natural log of `modu_pdf` at `u` (-inf where it is 0)."""
    density = modu_pdf(u, self.kl2, self.num_modes, self.offset)
    return jnp.log(density)

  def width_p(self, h, lower, upper):
    """Return `modu_width_p_in_bounds` for `[lower, upper]`; `h` is unused."""
    return modu_width_p_in_bounds(
        self.kl2, self.num_modes, self.offset, lower, upper)

  def width_q(self, h, lower, upper):
    """Return `modu_width_q_in_bounds` for `[lower, upper]`; `h` is unused."""
    return modu_width_q_in_bounds(
        self.kl2, self.num_modes, self.offset, lower, upper)

  def stretch(self, h):
    """Return `-height * log1p(-h / height)`."""
    peak = self.height
    return -peak * jnp.log1p(-h / peak)

  def inv_stretch(self, t):
    """Return `exp(kl + log1mexp(-t / height))`.

    NOTE(review): if `log1mexp(x)` computes `log(1 - exp(x))`, this equals
    `height * (1 - exp(-t / height))`, the inverse of `stretch` -- confirm
    against the helper in `.util`.
    """
    log_fraction = log1mexp(-t / self.height)
    return jnp.exp(self.kl + log_fraction)