"""Here the encoders all operate in the time domain, i.e. they stretch the density ratio.
"""
from functools import partial

from jax import Array
from jax import jit
import jax.numpy as jnp
import jax.random as random
from jax.lax import cond
from jax.scipy.stats import norm as normal_dist
from jax.experimental.ode import odeint
from .gaussian_density_ratios import IsotropicGaussianDensityRatio
from .gauss_helpers import gauss_ratio_super_level_sols
from greedy_rejection_process.util import trunc_gumbel

from .util import logsubexp


@partial(jit, static_argnames=['gauss_dr'])
def slow_gauss_loop_body(k, key, log_time, gauss_dr):
  """Run proposal step ``k`` of the unrestricted (slow) encoder.

  Args:
    k: step index; it is folded into ``key`` so every step gets fresh randomness.
    key: base PRNG key shared by all steps.
    log_time: log arrival time of the previous step (``-inf`` before the first).
    gauss_dr: density-ratio object; static under ``jit``, so it must be hashable.

  Returns:
    ``(x, f, log_time)``: the proposal drawn with ``gauss_dr.sample_p``, the
    stretched ratio at that proposal, and the new log arrival time.
  """
  step_key = random.fold_in(key, k)
  time_key, sample_key = random.split(step_key, num=2)

  # NOTE(review): presumably `bound=-log_time` truncates the Gumbel so the new
  # arrival comes after the previous one -- confirm in greedy_rejection_process.util.
  new_log_time = -trunc_gumbel(time_key, shape=(), loc=0., bound=-log_time)

  sample = gauss_dr.sample_p(sample_key, shape=())

  # Solve the stretch ODE from y(0) = 0 over t in [0, r(x)] and keep y(r(x)).
  t_span = jnp.array([0., gauss_dr.ratio(sample)])
  stretched = odeint(gauss_dr.stretch_ode, jnp.array(0.), t_span)[-1]

  return sample, stretched, new_log_time

def slow_gauss_encoder(seed, gauss_dr, max_iter=1_000):
  """Encode a sample by plain (unrestricted) greedy rejection in the time domain.

  Proposals are drawn from P by ``slow_gauss_loop_body``. The first step whose
  log arrival time lies strictly below the log stretched ratio is accepted.

  Args:
    seed: integer seed for ``random.PRNGKey``; the decoder must use the same one.
    gauss_dr: density-ratio object passed through to the loop body.
    max_iter: maximum number of proposals to try.

  Returns:
    ``(x, k)``: the accepted sample and its 0-based step index, which is what
    ``slow_gauss_decoder`` folds into the key.

  Raises:
    ValueError: if none of the ``max_iter`` proposals is accepted.
  """
  current_log_time = -jnp.inf
  base_key = random.PRNGKey(seed)

  for step in range(max_iter):
    sample, stretched_f, current_log_time = slow_gauss_loop_body(
        step, base_key, current_log_time, gauss_dr)

    if jnp.log(stretched_f) > current_log_time:
      return sample, step

  raise ValueError("did not terminate!")


def slow_gauss_decoder(seed: int, 
                       k: int, 
                       p_loc: Array, 
                       p_scale: Array) -> Array:
  """Regenerate the ``k``-th proposal of ``slow_gauss_encoder``.

  The key path mirrors ``slow_gauss_loop_body``: fold ``k`` into the seed's key,
  split it in two and use the second half for the sample.

  Args:
    seed: the seed the encoder used.
    k: the step index returned by the encoder.
    p_loc: location of P; its shape fixes the output shape.
    p_scale: scale of P.

  Returns:
    ``p_loc + p_scale * z`` with ``z`` a standard-normal draw.
  """
  step_key = random.fold_in(random.PRNGKey(seed), k)
  sample_key = random.split(step_key, num=2)[1]

  # `random.truncated_normal` (with infinite bounds) is used on purpose:
  # `sample_p` uses it for the encoding, so the decoder reproduces that draw.
  z = random.truncated_normal(sample_key, shape=p_loc.shape, lower=-jnp.inf, upper=jnp.inf)
  return p_loc + p_scale * z


@partial(jit, static_argnames=['gauss_dr'])
def ggrs_gaus_encoder_loop_body(base_key: Array,
                               k: int,
                               bounds: Array,
                               log_time: Array,
                               gauss_dr: IsotropicGaussianDensityRatio):
  """One proposal step of the interval-restricted encoders.

  ``bounds`` is an interval in P's CDF coordinate ``u``; this is consistent with
  ``binary_gauss_oracle_computation``, which maps level-set endpoints with
  ``normal_dist.cdf(., p_loc, p_scale)``. The proposal is therefore drawn
  uniformly in ``u`` and mapped to x-space through P's quantile function.

  Args:
    base_key: PRNG key shared by all steps; ``k`` is folded into it.
    k: step index.
    bounds: ``[lower, upper]`` of the current interval in u-space.
    log_time: log arrival time of the previous step (``-inf`` before the first).
    gauss_dr: density ratio; static under ``jit``, so it must be hashable.

  Returns:
    ``(log_time, log_f, x, u, bound_size, b_key)``:
      - the new log arrival time;
      - the log of the stretched ratio at ``x``;
      - the proposal ``x``;
      - its u-coordinate ``u``;
      - the interval width ``bound_size``;
      - an unused PRNG key ``b_key`` for the caller's branching decision.
  """
  key = random.fold_in(base_key, k)
  log_time_key, u_key, b_key = random.split(key, num=3)

  bound_size = bounds[1] - bounds[0]

  # NOTE(review): presumably `bound=-log_time` keeps the new arrival after the
  # previous one and `loc=log(bound_size)` rescales the rate to the interval's
  # P-mass -- confirm in greedy_rejection_process.util.trunc_gumbel.
  log_time = -trunc_gumbel(log_time_key, shape=(), loc=jnp.log(bound_size), bound=-log_time)

  # Uniform position inside the current u-interval.
  u = bounds[0] + bound_size * random.uniform(u_key, shape=())

  # Map u back to x-space with P's quantile function. The original used the
  # *standard* normal ppf, which only agrees with the u-space convention above
  # (and with `slow_gauss_decoder`) when p_loc == 0 and p_scale == 1.
  x = normal_dist.ppf(u, gauss_dr.p_loc, gauss_dr.p_scale)

  r = gauss_dr.ratio(x)

  # Solve the stretch ODE from y(0) = 0 over t in [0, r(x)] and keep y(r(x)).
  f = odeint(gauss_dr.stretch_ode, jnp.array(0.), jnp.array([0., r]))[-1]

  return log_time, jnp.log(f), x, u, bound_size, b_key


def sac_gauss_encoder(seed: int,
                      gauss_dr: IsotropicGaussianDensityRatio,
                      max_iter: int = 100) -> tuple[Array, int, Array]:
  """Greedy rejection encoder that shrinks the search interval after each rejection.

  Proposals are drawn inside ``bounds`` (u-space, starting at ``[0, 1]``) by
  ``ggrs_gaus_encoder_loop_body``. After a rejection, the interval is cut at the
  rejected proposal's ``u``. The side kept is the one that lies towards
  ``gauss_dr.r_loc`` (presumably the mode of the ratio -- confirm).

  Args:
    seed: integer seed for ``random.PRNGKey``.
    gauss_dr: density-ratio object.
    max_iter: maximum number of proposals; steps are numbered from 1.

  Returns:
    ``(x, k, -log2(bound_size))``: the accepted sample, its 1-based step index
    and ``-log2`` of the width of the interval it was drawn from.

  Raises:
    ValueError: if no proposal is accepted within ``max_iter`` steps.
  """
  log_time = -jnp.inf
  base_key = random.PRNGKey(seed)

  bounds = jnp.array([0., 1.])

  for k in range(1, max_iter + 1):
    log_time, log_f, x, u, bound_size, _ = ggrs_gaus_encoder_loop_body(base_key, k, bounds, log_time, gauss_dr)

    if log_time <= log_f:
      return x, k, -jnp.log2(bound_size)

    # Rejected: keep only the part of the interval on r_loc's side of u.
    if x < gauss_dr.r_loc:
      bounds = jnp.array([u, bounds[1]])
    else:
      bounds = jnp.array([bounds[0], u])

  raise ValueError('did not terminate!')
    
  
def sac_gauss_decoder(seed: int,
                      k: int,
                      p_loc: Array,
                      p_scale: Array):
  """Decoder counterpart of ``sac_gauss_encoder`` -- not implemented yet.

  Previously this was an empty stub that silently returned ``None``, which a
  caller could mistake for a decoded sample.

  The encoder's final interval depends on which side of ``gauss_dr.r_loc`` each
  rejected proposal fell. This signature receives neither ``gauss_dr`` nor that
  interval, so it cannot yet reproduce the encoder's draw.

  Raises:
    NotImplementedError: always.
  """
  raise NotImplementedError("sac_gauss_decoder is not implemented yet")


@partial(jit, static_argnames=['gauss_dr'])
def binary_gauss_oracle_computation(log_time0: Array,
                                       log_g_inv0: Array,
                                       bounds: Array,
                                       log_time: Array,
                                       gauss_dr: IsotropicGaussianDensityRatio):
  """Halve ``bounds`` and locate the current super-level set in u-space.

  ``u`` is P's CDF coordinate. The endpoints returned by
  ``gauss_ratio_super_level_sols`` are mapped through ``normal_dist.cdf`` with
  P's location and scale before being returned.

  Args:
    log_time0: in ``binary_gauss_encoder`` this is the running maximum of past
      log arrival times. Only the second ``cond`` branch reads it, and that
      branch is currently disabled.
    log_g_inv0: running maximum of past ``log_g_inv`` values (as passed by
      ``binary_gauss_encoder``). Likewise only read by the disabled branch.
    bounds: ``[lower, upper]`` of the current interval in u-space.
    log_time: current log arrival time.
    gauss_dr: density ratio; static under ``jit``, so it must be hashable.

  Returns:
    ``(sol_up, sol_down, bound_left, bound_right, log_g_inv)``:
      - ``sol_up`` and ``sol_down`` (note the order) are the level-set
        endpoints in u-space;
      - ``bound_left`` and ``bound_right`` are the two halves of ``bounds``;
      - ``log_g_inv`` is the log level used for the set.
  """
  # Python side effect: runs only while jit traces this function, so it acts as
  # a (re)trace indicator rather than a per-call log.
  print("tracing oracle function")

  bound_center = (bounds[0] + bounds[1]) / 2.
  bound_left = jnp.array([bounds[0], bound_center])
  bound_right = jnp.array([bound_center, bounds[1]])

  # NOTE(review): the predicate is hard-coded to True (the original check is
  # commented out). log_g_inv is therefore always recomputed from scratch:
  # solve inv_stretch_ode from 0 over t in [0, exp(log_time)] and take the log.
  # The incremental branch restarts log_inv_stretch_ode from
  # (exp(log_time0), log_g_inv0). It is dead code but must still produce an
  # output of matching shape/dtype for lax.cond.
  log_g_inv = cond(True,#jnp.all(log_time0 == -jnp.inf),
                   lambda: jnp.log(odeint(gauss_dr.inv_stretch_ode, 0., jnp.array([0., jnp.exp(log_time)]))[-1]),
                   lambda: odeint(gauss_dr.log_inv_stretch_ode, log_g_inv0, jnp.array([jnp.exp(log_time0), jnp.exp(log_time)]))[-1])

  # The level is passed in log form (log_g_inv). Presumably the helper returns
  # the x-space endpoints of the ratio's super-level set at that level --
  # confirm in gauss_helpers.
  sol_down, sol_up = gauss_ratio_super_level_sols(
    gauss_dr.q_loc, gauss_dr.q_scale, gauss_dr.p_loc, gauss_dr.p_scale, log_g_inv)

  # Map the endpoints from x-space to u-space (P's CDF).
  sol_down = normal_dist.cdf(sol_down, gauss_dr.p_loc, gauss_dr.p_scale)
  sol_up = normal_dist.cdf(sol_up, gauss_dr.p_loc, gauss_dr.p_scale)

  return sol_up, sol_down, bound_left, bound_right, log_g_inv


def _log_interval_measures(set_down: Array,
                           set_up: Array,
                           log_g_inv: Array,
                           gauss_dr: IsotropicGaussianDensityRatio) -> tuple[Array, Array, Array]:
  """Log Q-measure, log P-measure and log(Q - g_inv * P) of ``[set_down, set_up]`` in u-space.

  ``u`` is P's CDF coordinate, so the interval length is its P-measure. The
  Q-measure comes from mapping the endpoints to x-space with P's quantile
  function and differencing Q's log-CDF. This assumes
  ``logsubexp(a, b) == log(exp(a) - exp(b))``.
  """
  log_p_measure = jnp.log(set_up - set_down)
  log_q_upper = normal_dist.logcdf(normal_dist.ppf(set_up, gauss_dr.p_loc, gauss_dr.p_scale), gauss_dr.q_loc, gauss_dr.q_scale)
  log_q_lower = normal_dist.logcdf(normal_dist.ppf(set_down, gauss_dr.p_loc, gauss_dr.p_scale), gauss_dr.q_loc, gauss_dr.q_scale)
  log_q_measure = logsubexp(log_q_upper, log_q_lower)
  return log_q_measure, log_p_measure, logsubexp(log_q_measure, log_g_inv + log_p_measure)


def binary_gauss_encoder(seed: int,
                         gauss_dr: IsotropicGaussianDensityRatio,
                         max_iter: int = 100) -> tuple[Array, int, Array]:
  """Greedy rejection encoder that halves its search interval after each rejection.

  Proposals are drawn inside ``bounds`` (u-space, starting at ``[0, 1]``) by
  ``ggrs_gaus_encoder_loop_body``. After a rejection, the interval is replaced
  by its left or right half. The right half is chosen with probability

      (Q - g_inv * P)(right ∩ S) / (Q - g_inv * P)(bounds ∩ S),

  where ``S`` and ``g_inv`` come from ``binary_gauss_oracle_computation``.

  Args:
    seed: integer seed for ``random.PRNGKey``.
    gauss_dr: density-ratio object.
    max_iter: maximum number of proposals; steps are numbered from 1.

  Returns:
    ``(x, k, -log2(bound_size))``. Because every step halves the unit
    interval, the last entry equals the number of halvings performed.

  Raises:
    ValueError: if no proposal is accepted within ``max_iter`` steps.
    AssertionError: if an internal consistency check fails.
  """
  log_g_inv0 = -jnp.inf
  log_time = log_time0 = -jnp.inf
  base_key = random.PRNGKey(seed)

  bounds = jnp.array([0., 1.])

  for k in range(1, max_iter + 1):
    log_time, log_f, x, _, bound_size, b_key = ggrs_gaus_encoder_loop_body(base_key, k, bounds, log_time, gauss_dr)

    if log_time <= log_f:
      return x, k, -jnp.log2(bound_size)

    sol_up, sol_down, bound_left, bound_right, log_g_inv = binary_gauss_oracle_computation(log_time0, log_g_inv0, bounds, log_time, gauss_dr)

    # Running maxima handed to the oracle on the next iteration.
    log_g_inv0 = jnp.maximum(log_g_inv, log_g_inv0)
    log_time0 = jnp.maximum(log_time, log_time0)

    assert not jnp.isnan(log_g_inv), "log g inv was nan"

    # Normalising constant over the parent interval; the parent is expected
    # to intersect the super-level set (checked just below).
    set_down = jnp.maximum(bounds[0], sol_down)
    set_up = jnp.minimum(bounds[1], sol_up)
    _, _, log_norm_const = _log_interval_measures(set_down, set_up, log_g_inv, gauss_dr)

    assert not (sol_down > bounds[1] or bounds[0] > sol_up), "parent bounds did not intersect the superlevel set!"

    if sol_down > bound_right[1] or bound_right[0] > sol_up:
      # Right half misses the set entirely. The measures are defined here too,
      # so the diagnostic message below never hits an undefined or stale name.
      log_right_q_measure = log_right_p_measure = log_right_prob = -jnp.inf
    else:
      right_set_down = jnp.maximum(bound_right[0], sol_down)
      right_set_up = jnp.minimum(bound_right[1], sol_up)
      log_right_q_measure, log_right_p_measure, log_right_prob = _log_interval_measures(
          right_set_down, right_set_up, log_g_inv, gauss_dr)

    right_prob = jnp.exp(log_right_prob - log_norm_const)
    assert 0. <= right_prob <= 1., f"{k, log_time, right_prob, log_right_q_measure, log_right_p_measure, log_g_inv, log_right_prob, bounds, log_norm_const=}"

    b = random.bernoulli(b_key, right_prob).astype(jnp.int32)

    # b == 1 -> descend into the right half, otherwise into the left half.
    bounds = bound_right if b else bound_left

  raise ValueError('did not terminate!')
