import functools
import unittest

import chex
import jax
import jax.numpy as jnp

from tabular_mvdrl.envs import mrp


class TestMarkovRewardProcesses(unittest.TestCase):
    """Tests for the terminal-reward MRP wrappers in ``tabular_mvdrl.envs.mrp``.

    Every test wraps the same randomly generated base MRP (``self.base_mrp``)
    in either ``FiniteHorizonTerminalRewardMRP`` or ``TerminalRewardMRP`` and
    checks the wrapped process's shapes, transition structure, cumulants, and
    Monte Carlo returns.
    """

    def setUp(self):
        """Build a random base MRP shared by every test."""
        self.rng = jax.random.PRNGKey(0)
        self.num_states = 10
        self.reward_dim = 2
        self.num_terminal = 3
        # Transitions come from a flat Dirichlet prior over next states and
        # cumulant vectors from a standard normal; the leading 0 is presumably
        # the seed for those draws (TODO confirm against from_independent_priors).
        self.base_mrp = mrp.MarkovRewardProcess.from_independent_priors(
            0,
            self.num_states,
            functools.partial(jax.random.dirichlet, alpha=jnp.ones(self.num_states)),
            functools.partial(jax.random.normal, shape=(self.reward_dim,)),
        )
        self.horizon = 4

    def _monte_carlo_returns(
        self, terminal_mrp, max_steps, discount=1.0, samples_per_state=10
    ):
        """Sample Monte Carlo returns from each base-MRP start state.

        Args:
            terminal_mrp: The wrapped MRP whose ``monte_carlo_return`` is used.
            max_steps: Rollout length cap forwarded to ``monte_carlo_return``.
            discount: Discount factor forwarded to ``monte_carlo_return``.
            samples_per_state: Number of PRNG keys, i.e. rollouts per start state.

        Returns:
            The sampled returns reshaped to ``(-1, terminal_mrp.reward_dim)``,
            one sampled return vector per row.
        """
        rollout = functools.partial(
            terminal_mrp.monte_carlo_return, discount=discount, max_steps=max_steps
        )
        keys = jax.random.split(self.rng, samples_per_state)
        start_states = jnp.arange(self.base_mrp.num_states)
        # Inner vmap sweeps start states with a fixed key; the outer vmap sweeps keys.
        samples = jax.vmap(
            jax.vmap(rollout, in_axes=(None, 0)), in_axes=(0, None)
        )(keys, start_states)
        return jnp.reshape(samples, (-1, terminal_mrp.reward_dim))

    ### FINITE HORIZON TERMINAL REWARD

    def test_finite_horizon_terminal_mrp_init(self):
        """The state space is ``horizon`` stacked copies of the base states."""
        terminal_mrp = mrp.FiniteHorizonTerminalRewardMRP(self.base_mrp, self.horizon)
        expanded_states = self.horizon * self.num_states
        self.assertListEqual(
            list(terminal_mrp.transition_kernel.shape),
            [expanded_states, expanded_states],
        )
        self.assertListEqual(
            list(terminal_mrp.cumulants.shape),
            [expanded_states, self.reward_dim],
        )

    def test_finite_horizon_terminal_mrp_transition_consistency(self):
        """Every row of the transition kernel sums to one."""
        terminal_mrp = mrp.FiniteHorizonTerminalRewardMRP(self.base_mrp, self.horizon)
        transition_probs = jnp.sum(terminal_mrp.transition_kernel, axis=-1)
        self.assertTrue(jnp.allclose(transition_probs, jnp.ones_like(transition_probs)))

    def test_finite_horizon_terminal_mrp_reward(self):
        """Only the last ``num_states`` states carry the base cumulants; others are zero."""
        terminal_mrp = mrp.FiniteHorizonTerminalRewardMRP(self.base_mrp, self.horizon)
        nonterminal_cumulants = jnp.reshape(
            terminal_mrp.cumulants[: -self.num_states, :], (-1,)
        )
        terminal_cumulants = jnp.reshape(
            terminal_mrp.cumulants[(self.horizon - 1) * self.num_states :, :], (-1,)
        )
        expected_cumulants = self.base_mrp.cumulants.reshape((-1,))
        # Values are copied, not computed, so exact comparison is appropriate.
        self.assertListEqual(
            nonterminal_cumulants.tolist(),
            jnp.zeros_like(nonterminal_cumulants).tolist(),
        )
        self.assertListEqual(terminal_cumulants.tolist(), expected_cumulants.tolist())

    def test_finite_horizon_terminal_mrp_absorbing_state(self):
        """A distribution supported on the last layer is unchanged by one step."""
        terminal_mrp = mrp.FiniteHorizonTerminalRewardMRP(self.base_mrp, self.horizon)
        last_layer_probs = jax.nn.softmax(
            jax.random.normal(self.rng, shape=(self.num_states,))
        )
        prior_probs = (
            jnp.zeros(terminal_mrp.num_states)
            .at[-self.num_states :]
            .set(last_layer_probs)
        )
        p_t = terminal_mrp.transition_kernel.T @ prior_probs
        # Tolerance-based comparison: matmul results need not be bit-exact
        # across backends/precision settings.
        self.assertTrue(jnp.allclose(p_t, prior_probs))

    def test_finite_horizon_terminal_mrp_terminal_state(self):
        """Mass starting in the first layer ends in the last layer after ``horizon`` steps."""
        terminal_mrp = mrp.FiniteHorizonTerminalRewardMRP(self.base_mrp, self.horizon)
        first_layer_probs = jax.nn.softmax(
            jax.random.normal(self.rng, shape=(self.num_states,))
        )
        p = (
            jnp.zeros(terminal_mrp.num_states)
            .at[: self.num_states]
            .set(first_layer_probs)
        )
        for _ in range(self.horizon):
            p = terminal_mrp.transition_kernel.T @ p
        self.assertTrue(
            jnp.allclose(p[: -self.num_states], jnp.zeros_like(p[: -self.num_states]))
        )
        # float32 accumulation over several matmuls: 6 places is the realistic bound.
        self.assertAlmostEqual(jnp.sum(p).item(), 1.0, places=6)

    def test_finite_horizon_terminal_mrp_return_dist(self):
        """Every sampled return equals one of the base MRP's cumulant vectors."""

        def encode_return(x: chex.Array) -> int:
            # Index of the base cumulant matching x, or -1 when none matches.
            for i in range(self.base_mrp.num_states):
                if jnp.allclose(x, self.base_mrp.cumulants[i, :]):
                    return i
            return -1

        terminal_mrp = mrp.FiniteHorizonTerminalRewardMRP(self.base_mrp, self.horizon)
        returns = self._monte_carlo_returns(terminal_mrp, max_steps=self.horizon)
        encodings = {encode_return(x) for x in returns}
        self.assertNotIn(-1, encodings)

    ### TERMINAL REWARD INFINITE HORIZON

    def test_terminal_mrp_init(self):
        """The wrapper adds exactly one state to the base state space."""
        # Use num_terminal like every other TerminalRewardMRP test.
        terminal_mrp = mrp.TerminalRewardMRP(self.base_mrp, self.num_terminal)
        self.assertListEqual(
            list(terminal_mrp.transition_kernel.shape),
            [self.num_states + 1, self.num_states + 1],
        )
        self.assertListEqual(
            list(terminal_mrp.cumulants.shape),
            [self.num_states + 1, self.reward_dim],
        )

    def test_terminal_mrp_transition_consistency(self):
        """Every row of the transition kernel sums to one."""
        terminal_mrp = mrp.TerminalRewardMRP(self.base_mrp, self.num_terminal)
        transition_probs = jnp.sum(terminal_mrp.transition_kernel, axis=-1)
        self.assertTrue(jnp.allclose(transition_probs, jnp.ones_like(transition_probs)))

    def test_terminal_mrp_reward(self):
        """Only the last ``num_terminal`` base states keep their cumulants.

        All other base states, and the extra appended state, have zero cumulants.
        """
        terminal_mrp = mrp.TerminalRewardMRP(self.base_mrp, self.num_terminal)
        first_terminal = self.num_states - self.num_terminal
        nonterminal_cumulants = jnp.reshape(
            terminal_mrp.cumulants[:first_terminal, :], (-1,)
        )
        terminal_cumulants = jnp.reshape(
            terminal_mrp.cumulants[first_terminal : self.num_states, :], (-1,)
        )
        expected_cumulants = self.base_mrp.cumulants[-self.num_terminal :].reshape(
            (-1,)
        )
        self.assertListEqual(
            nonterminal_cumulants.tolist(),
            jnp.zeros_like(nonterminal_cumulants).tolist(),
        )
        self.assertListEqual(terminal_cumulants.tolist(), expected_cumulants.tolist())
        # Sized from reward_dim rather than hard-coding a 2-dimensional reward.
        self.assertListEqual(
            terminal_mrp.cumulants[-1].tolist(), [0.0] * self.reward_dim
        )

    def test_terminal_mrp_absorbing_state(self):
        """After one step, all mass on terminal + extra states sits in the extra state."""
        terminal_mrp = mrp.TerminalRewardMRP(self.base_mrp, self.num_terminal)
        other_prob = 0.5
        # Spread (1 - other_prob) over the terminal states plus the appended
        # state; put other_prob on base state 0.
        tail_probs = (1 - other_prob) * jax.nn.softmax(
            jax.random.normal(self.rng, shape=(self.num_terminal + 1,))
        )
        prior_probs = (
            jnp.zeros(terminal_mrp.num_states)
            .at[0]
            .set(other_prob)
            .at[self.base_mrp.num_states - self.num_terminal :]
            .set(tail_probs)
        )
        p_t = terminal_mrp.transition_kernel.T @ prior_probs
        # float32 softmax/matmul: 6 places is the realistic bound.
        self.assertAlmostEqual(p_t[-1].item(), 1 - other_prob, places=6)

    def test_terminal_mrp_return_dist(self):
        """Every sampled return equals one of the terminal states' cumulants."""

        def encode_return(x: chex.Array) -> int:
            # Position (counted from the end) of the matching terminal cumulant,
            # or -1 when none matches.
            for i in range(self.num_terminal):
                if jnp.allclose(x, self.base_mrp.cumulants[-(i + 1), :]):
                    return i
            return -1

        terminal_mrp = mrp.TerminalRewardMRP(self.base_mrp, self.num_terminal)
        returns = self._monte_carlo_returns(terminal_mrp, max_steps=1000)
        encodings = {encode_return(x) for x in returns}
        self.assertNotIn(-1, encodings)


if __name__ == "__main__":
    # Discover and run the TestCase classes defined in this script when it is
    # executed directly (module="__main__" is unittest.main's default).
    unittest.main(module="__main__")
