from typing import NamedTuple
import jax.numpy as jnp
from numpy.typing import NDArray


class State(NamedTuple):
    """
    Bundle of related probability arrays, built by ``create_state``.

    Indexed by [group, (x value)]: per-x fields carry a trailing x-value
    index, while the totals ``pr_Y1`` / ``pr_Y0`` are taken from the last
    cumulative entry and so have no x-value index.

    NOTE(review): the per-group rows are presumably probabilities
    conditional on the group -- confirm against the callers.
    """

    # Field-name legend:
    # "X":  "X=x"
    # "lX": "X<=x"
    # "a": and
    # "g": given
    # "Y1" / "Y0": qualified / unqualified (see pr_Y1 / pr_Y0 below)
    pr_G: NDArray  # Pr(G); presumably the probability of each group -- TODO confirm
    pr_X: NDArray  # Pr(X=x)
    pr_Y1gX: NDArray  # Pr(Y=1 | X=x)
    pr_Y1aX: NDArray  # Pr(Y=1 and X=x)
    pr_Y0aX: NDArray  # Pr(Y=0 and X=x)
    pr_lX: NDArray  # Pr(X<=x), cumulative over x
    pr_Y1alX: NDArray  # Pr(Y=1 and X<=x), cumulative over x
    pr_Y0alX: NDArray  # Pr(Y=0 and X<=x), cumulative over x
    pr_Y1: NDArray  # qualified
    pr_Y0: NDArray  # unqualified


def create_state(pr_G, pr_X, pr_Y1gX, rng=None) -> State:
    """Build a ``State`` from base probabilities and derive the rest.

    Naming: "X" is "X=x", "lX" is "X<=x", "a" is "and", "g" is "given".

    Args:
        pr_G: Stored on the returned ``State`` unchanged; it takes no part
            in the calculations here.
        pr_X: Pr(X=x). The last axis is treated as the x-value axis.
        pr_Y1gX: Pr(Y=1 | X=x). Must broadcast against ``pr_X``.
        rng: Unused by this function; retained so existing callers that
            pass it keep working.

    Returns:
        A ``State`` holding the three inputs plus the derived joint
        (``pr_Y1aX``, ``pr_Y0aX``), cumulative (``pr_lX``, ``pr_Y1alX``,
        ``pr_Y0alX``) and total (``pr_Y1``, ``pr_Y0``) probabilities.
    """
    # Product rule: Pr(Y=y and X=x) = Pr(X=x) * Pr(Y=y | X=x).
    # TODO: confirm broadcasting behaves as intended when pr_X and pr_Y1gX
    # do not share the same shape.
    pr_Y1aX = pr_X * pr_Y1gX
    pr_Y0aX = pr_X * (1 - pr_Y1gX)

    # Accumulate over the x-value axis, which is the last axis.
    # TODO: running sums accumulate floating-point rounding error, so the
    # final entry of pr_lX is not guaranteed to be exactly 1.
    pr_lX = jnp.cumsum(pr_X, axis=-1)
    pr_Y1alX = jnp.cumsum(pr_Y1aX, axis=-1)  # not a pdf
    pr_Y0alX = jnp.cumsum(pr_Y0aX, axis=-1)  # not a pdf

    # The last cumulative entry is the total over every x value. Index with
    # an Ellipsis so the axis always matches the cumsum axis above: the
    # previous `[:, -1]` picked axis 1, which failed on 1-D input and read
    # the wrong axis for inputs with more than two dimensions.
    pr_Y1 = pr_Y1alX[..., -1]
    pr_Y0 = pr_Y0alX[..., -1]

    return State(
        pr_G=pr_G,
        pr_X=pr_X,
        pr_Y1gX=pr_Y1gX,
        pr_Y1aX=pr_Y1aX,
        pr_Y0aX=pr_Y0aX,
        pr_lX=pr_lX,
        pr_Y1alX=pr_Y1alX,
        pr_Y0alX=pr_Y0alX,
        pr_Y1=pr_Y1,
        pr_Y0=pr_Y0,
    )
