from typing import Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.serialization import msgpack_restore
from flax.training import train_state  # Useful dataclass to keep train state
import flax
import numpy as np  # Ordinary NumPy
import optax  # Optimizers
from functools import partial
from gym import spaces
import matplotlib.pyplot as plt
import seaborn as sns


def pretty_time(elapsed):
    """
    Converts a duration given in seconds into a short human-readable string.

    Examples: 9240 -> "2h 34 min", 90 -> "1min 30s", 2.5 -> "2.5s", 0.25 -> "250.0ms"

    :param elapsed: Duration in seconds (int or float)
    :return: Formatted string using hours/minutes, minutes/seconds, seconds or milliseconds
    """
    # Inclusive comparisons so that exactly one hour / one minute use the larger unit
    # (a strict ">" would render 3600 as "60min 0s" and 60 as "60.0s").
    if elapsed >= 60 * 60:
        h = int(elapsed // (60 * 60))
        mins = int((elapsed // 60) % 60)
        return f"{h}h {mins:02d} min"
    elif elapsed >= 60:
        mins = int(elapsed // 60)
        secs = int(elapsed) % 60
        return f"{mins}min {secs}s"
    elif elapsed < 1:
        return f"{elapsed*1000:0.1f}ms"
    else:
        return f"{elapsed:0.1f}s"


def pretty_number(number):
    """
    Converts a number into a short SI-suffixed string with 3 significant digits
    (3409230384 -> "3.41G", 1500 -> "1.5k"). Numbers below 1000 are returned as
    their plain string representation.

    :param number: Number to format
    :return: Always a string (so callers can rely on a single return type)
    """
    # "not >=" (instead of "<") keeps NaN on the plain-string path
    if not number >= 1.0e3:
        return str(number)

    units = ((1.0e3, "k"), (1.0e6, "M"), (1.0e9, "G"))
    for scale, suffix in units[:-1]:
        scaled = number / scale
        # If rounding to 3 significant digits reaches 1000 (e.g. 999999 -> 1e+03),
        # move on to the next larger unit instead of printing scientific notation.
        if float(f"{scaled:0.3g}") < 1000:
            return f"{scaled:0.3g}{suffix}"

    scale, suffix = units[-1]
    return f"{number/scale:0.3g}{suffix}"


def v_contains(box, states):
    """
    Returns a bool-array with one entry per row of states that is True iff the
    row lies inside the box (bounds inclusive).
    NumPy version (for the JAX version see jv_contains)
    """
    lower = np.expand_dims(box.low, axis=0)
    upper = np.expand_dims(box.high, axis=0)
    # A state is inside iff every coordinate is within [low, high]
    within = np.logical_and(states >= lower, states <= upper)
    return np.all(within, axis=1)


def jv_contains(box, states):
    """
    Computes a bool-array indicating whether states are inside the box (bounds inclusive).
    JAX version (for the NumPy version see v_contains)

    :param box: Object with array attributes `low` and `high`
    :param states: 2D array with one state per row
    :return: 1D bool JAX array with one entry per row of states
    """
    b_low = jnp.expand_dims(box.low, axis=0)
    b_high = jnp.expand_dims(box.high, axis=0)
    # Use jnp (not np) so the function also works on traced values inside jax.jit
    contains = jnp.logical_and(
        jnp.all(states >= b_low, axis=1), jnp.all(states <= b_high, axis=1)
    )
    return contains


def v_intersect(box, lb, ub):
    """
    Computes a bool-array indicating whether (lb,ub) boxes overlap/intersect with the box.
    NumPy version (for Jax version see jv_intersect)

    :param box: Object with array attributes `low` and `high`
    :param lb: 2D array of lower corners, one box per row
    :param ub: 2D array of upper corners, one box per row (expects lb <= ub)
    :return: 1D bool array, True where the row's box shares at least one point with box
    """
    b_low = np.expand_dims(box.low, axis=0)
    b_high = np.expand_dims(box.high, axis=0)
    # Two closed intervals overlap iff each one starts before the other ends.
    # (Testing only whether lb or ub lies inside [b_low, b_high] would miss the
    # case where [lb, ub] completely covers the box along an axis.)
    overlap_per_axis = np.logical_and(lb <= b_high, ub >= b_low)
    # Boxes intersect iff their projections overlap on every axis
    return np.all(overlap_per_axis, axis=1)


def jv_intersect(box, lb, ub):
    """
    Computes a bool-array indicating whether (lb,ub) boxes overlap/intersect with the box.
    Jax version (for numpy version see v_intersect)

    :param box: Object with array attributes `low` and `high`
    :param lb: 2D array of lower corners, one box per row
    :param ub: 2D array of upper corners, one box per row (expects lb <= ub)
    :return: 1D bool JAX array, True where the row's box shares at least one point with box
    """
    b_low = jnp.expand_dims(box.low, axis=0)
    b_high = jnp.expand_dims(box.high, axis=0)
    # Two closed intervals overlap iff each one starts before the other ends.
    # (Testing only whether lb or ub lies inside [b_low, b_high] would miss the
    # case where [lb, ub] completely covers the box along an axis.)
    overlap_per_axis = jnp.logical_and(lb <= b_high, ub >= b_low)
    # Boxes intersect iff their projections overlap on every axis
    return jnp.all(overlap_per_axis, axis=1)


def clip_and_filter_spaces(obs_space, space_list):
    """
    Clips every Box in space_list to the bounds of obs_space and returns the list
    of clipped Boxes, dropping those whose volume after clipping is not positive.
    """
    projected = []
    for candidate in space_list:
        clipped_low = np.clip(candidate.low, obs_space.low, obs_space.high)
        clipped_high = np.clip(candidate.high, obs_space.low, obs_space.high)
        clipped = spaces.Box(low=clipped_low, high=clipped_high)
        # Volume = product of the side lengths; zero if any side collapsed
        side_lengths = clipped.high - clipped.low
        if np.prod(side_lengths) > 0:
            projected.append(clipped)
    return projected


def make_unsafe_spaces(obs_space, unsafe_bounds):
    """
    Builds, for every dimension i of obs_space, two Box spaces covering the parts
    of obs_space outside of [-unsafe_bounds[i], unsafe_bounds[i]] along that dimension:
    one with its upper bound at -unsafe_bounds[i] and one with its lower bound at
    +unsafe_bounds[i]. Candidates whose low and high vectors are (numerically)
    identical are skipped.
    """
    candidates = []
    for axis in range(obs_space.shape[0]):
        # Region below the negative bound along this axis
        below_high = np.array(obs_space.high)
        below_high[axis] = -unsafe_bounds[axis]
        candidates.append((np.array(obs_space.low), below_high))

        # Region above the positive bound along this axis
        above_low = np.array(obs_space.low)
        above_low[axis] = unsafe_bounds[axis]
        candidates.append((above_low, np.array(obs_space.high)))

    return [
        spaces.Box(low=lo, high=hi, dtype=np.float32)
        for lo, hi in candidates
        if not np.allclose(lo, hi)
    ]


def make_corner_spaces(obs_space, unsafe_bounds):
    """
    Builds, for every dimension i of obs_space, two Box spaces along the border of
    obs_space: one reaching from the lower border up to low[i] + unsafe_bounds[i]
    and one reaching from high[i] - unsafe_bounds[i] up to the upper border.
    Candidates whose low and high vectors are (numerically) identical are skipped.
    """
    candidates = []
    for axis in range(obs_space.shape[0]):
        # Strip of width unsafe_bounds[axis] attached to the lower border
        near_low_high = np.array(obs_space.high)
        near_low_high[axis] = obs_space.low[axis] + unsafe_bounds[axis]
        candidates.append((np.array(obs_space.low), near_low_high))

        # Strip of width unsafe_bounds[axis] attached to the upper border
        near_high_low = np.array(obs_space.low)
        near_high_low[axis] = obs_space.high[axis] - unsafe_bounds[axis]
        candidates.append((near_high_low, np.array(obs_space.high)))

    return [
        spaces.Box(low=lo, high=hi, dtype=np.float32)
        for lo, hi in candidates
        if not np.allclose(lo, hi)
    ]


def enlarge_space(space, bound, limit_space=None):
    """
    Returns a new Box obtained by growing space by bound in every direction
    (low - bound, high + bound). If limit_space is given, the grown Box is
    clipped to the bounds of limit_space.
    """
    grown = spaces.Box(low=space.low - bound, high=space.high + bound)
    if limit_space is None:
        return grown

    clipped_low = np.clip(grown.low, limit_space.low, limit_space.high)
    clipped_high = np.clip(grown.high, limit_space.low, limit_space.high)
    return spaces.Box(low=clipped_low, high=clipped_high)


@jax.jit
def clip_grad_norm(grad, max_norm):
    """
    Rescales a gradient pytree so that its global L2 norm is at most max_norm.

    :param grad: Pytree of arrays (gradients)
    :param max_norm: Maximum allowed global L2 norm
    :return: Pytree with the same structure as grad; unchanged (up to float
        rounding) if its norm is already within max_norm
    """
    leaves = jax.tree_util.tree_leaves(grad)
    # Global norm = L2 norm over all entries of all leaves
    norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
    # The scale factor must be capped at 1 (not at max_norm): gradients that are
    # already small enough are left untouched, larger ones are scaled down to max_norm.
    factor = jnp.minimum(1.0, max_norm / (norm + 1e-6))
    return jax.tree_util.tree_map(lambda x: x * factor, grad)


def contained_in_any(spaces, state):
    """
    Returns True if state is contained in at least one of the given spaces,
    False otherwise (including when spaces is empty)
    """
    return any(space.contains(state) for space in spaces)


def triangular(rng_key, shape):
    """
    Draws samples of the given shape from the symmetric triangular distribution
    on (-1,+1) centered at 0, by transforming uniform samples with the inverse CDF
    """
    u = jax.random.uniform(rng_key, shape=shape)
    # Inverse CDF has one branch per half of the triangle
    left_half = jnp.sqrt(2 * u) - 1
    right_half = 1 - jnp.sqrt(2 * (1 - u))
    return jnp.where(u <= 0.5, left_half, right_half)


# Not used?
# def softhuber(x):
#     return jnp.sqrt(1 + jnp.square(x)) - 1


class MLP(nn.Module):
    """Fully-connected network built from flax.linen.Dense layers.

    All layers use Glorot-uniform kernel initialization. Every layer except the
    last one is followed by the chosen activation; the last layer is linear,
    optionally followed by a softplus.
    """

    # Output width of every layer; the last entry is the output dimension
    features: Sequence[int]
    # "relu" selects ReLU, any other value selects tanh
    activation: str = "relu"
    # If True, softplus is applied to the output of the last layer
    softplus_output: bool = False

    @nn.compact
    def __call__(self, x):
        hidden_act = nn.relu if self.activation == "relu" else nn.tanh

        for width in self.features[:-1]:
            layer = nn.Dense(width, kernel_init=jax.nn.initializers.glorot_uniform())
            x = hidden_act(layer(x))

        output_layer = nn.Dense(
            self.features[-1], kernel_init=jax.nn.initializers.glorot_uniform()
        )
        x = output_layer(x)
        return jax.nn.softplus(x) if self.softplus_output else x


# The class name has to stay "Dense": flax derives variable names from
# self.__class__.__name__
class Dense(nn.Module):
    """Interval-bound propagation (IBP) counterpart of a flax.linen.Dense layer.

    Takes a pair (lower bound, upper bound) and returns the pair of bounds after
    the affine layer. IBP paper: https://arxiv.org/abs/1810.12715
    """

    # Number of output units
    features: int

    @nn.compact
    def __call__(self, inputs):
        lb_in, ub_in = inputs
        in_dim = lb_in.shape[-1]  # input width determines the kernel shape
        kernel = self.param(
            "kernel", jax.nn.initializers.glorot_uniform(), (in_dim, self.features)
        )
        bias = self.param("bias", nn.initializers.zeros, (self.features,))

        # Represent the input interval by its midpoint and half-width.
        # The maximum with 0 guards against slightly negative widths (numerical issues).
        midpoint = 0.5 * (ub_in + lb_in)
        radius = 0.5 * jnp.maximum(ub_in - lb_in, 0)

        # The midpoint goes through the affine map; the half-width is propagated
        # through |kernel| and receives no bias.
        out_midpoint = jnp.matmul(midpoint, kernel) + bias
        out_radius = jnp.matmul(radius, jnp.abs(kernel))

        # self.sow("intermediates", "edge_len", out_radius)
        return [out_midpoint - out_radius, out_midpoint + out_radius]


class IBPMLP(nn.Module):
    """Interval-bound propagation version of the MLP model.

    Operates on a pair [lower bound, upper bound] and returns the pair of
    output bounds. The activations are applied to both bounds separately.
    """

    # Output width of every layer; the last entry is the output dimension
    features: Sequence[int]
    # "relu" selects ReLU, any other value selects tanh
    activation: str = "relu"
    # If True, softplus is applied to both output bounds
    softplus_output: bool = False

    @nn.compact
    def __call__(self, x):
        hidden_act = nn.relu if self.activation == "relu" else nn.tanh

        for width in self.features[:-1]:
            lower, upper = Dense(width)(x)
            x = [hidden_act(lower), hidden_act(upper)]

        lower, upper = Dense(self.features[-1])(x)
        if self.softplus_output:
            return [jax.nn.softplus(lower), jax.nn.softplus(upper)]
        return [lower, upper]


def martingale_loss(l, l_next, eps):
    """
    Mean hinge penalty on the decrease condition: averages max(l_next - l + eps, 0)
    over all entries, i.e. the loss is zero once l_next is below l by at least eps.
    """
    violation = jnp.maximum((l_next - l) + eps, 0)
    return jnp.mean(violation)


def jax_save(params, filename):
    """Serializes params with flax and writes the resulting bytes to filename"""
    # Serialize first so that a failing serialization does not touch the file
    serialized = flax.serialization.to_bytes(params)
    with open(filename, "wb") as out_file:
        out_file.write(serialized)


def jax_load(params, filename):
    """
    Reads filename and restores the stored values into the structure of params.
    If restoring fails with a ValueError, the raw content of the file is printed
    for debugging before the error is re-raised.
    """
    with open(filename, "rb") as in_file:
        raw = in_file.read()

    try:
        return flax.serialization.from_bytes(params, raw)
    except ValueError:
        print("Caught exeception")
        # Dump the unstructured content to help diagnosing the mismatch
        print(msgpack_restore(raw))
        raise


def lipschitz_l1_jax(params):
    """
    Product over all layers in params["params"] of max_i sum_j |kernel[i, j]|
    (the sum runs over axis 1 because flax stores dense kernels transposed)
    """
    bound = 1
    for layer in params["params"].values():
        abs_sums = jnp.sum(jnp.abs(layer["kernel"]), axis=1)
        bound = bound * jnp.max(abs_sums)
    return bound


def lipschitz_linf_jax(params):
    """
    Product over all layers in params["params"] of max_j sum_i |kernel[i, j]|
    (the sum runs over axis 0 because flax stores dense kernels transposed)
    """
    bound = 1
    for layer in params["params"].values():
        abs_sums = jnp.sum(jnp.abs(layer["kernel"]), axis=0)
        bound = bound * jnp.max(abs_sums)
    return bound


def create_train_state(model, rng, in_dim, learning_rate, ema=0, clip_norm=None):
    """
    Initializes the model on a dummy input of shape [1, in_dim] and wraps the
    resulting variables together with an Adam optimizer into a `TrainState`.

    If clip_norm is given, optax.clip_by_global_norm(clip_norm) is chained after Adam;
    if ema > 0, optax.ema(ema) is chained after that.
    """
    dummy_input = jnp.ones([1, in_dim])
    variables = model.init(rng, dummy_input)

    optimizer = optax.adam(learning_rate)
    if clip_norm is not None:
        # NOTE(review): chained after Adam, so the clipping acts on Adam's updates
        # rather than on the raw gradients - confirm that this is intended
        optimizer = optax.chain(optimizer, optax.clip_by_global_norm(clip_norm))
    if ema > 0:
        optimizer = optax.chain(optimizer, optax.ema(ema))

    return train_state.TrainState.create(
        apply_fn=model.apply, params=variables, tx=optimizer
    )


def get_pmass_grid(env, n):
    """
    Compute the bounds of the sum terms and corresponding probability masses
    for the expectation computation.

    Every noise dimension [env.noise_bounds[0][i], env.noise_bounds[1][i]] is split
    into n equally sized cells; the grid is the cartesian product of these cells.

    :param env: Environment providing noise_bounds (pair of lower/upper bound
        sequences), observation_dim and integrate_noise(lb_list, ub_list)
    :param n: Number of cells per noise dimension
    :return: Tuple (pmass, batched_grid_lb, batched_grid_ub); the grid arrays have
        one cell per row and one column per dimension
    """
    dims = len(env.noise_bounds[0])
    grid, steps = [], []
    for i in range(dims):
        # endpoint=False yields the n lower cell edges; step is the cell width
        samples, step = jnp.linspace(
            env.noise_bounds[0][i],
            env.noise_bounds[1][i],
            n,
            endpoint=False,
            retstep=True,
        )
        grid.append(samples)
        steps.append(step)
    grid_lb = jnp.meshgrid(*grid)
    grid_lb = [x.flatten() for x in grid_lb]
    grid_ub = [grid_lb[i] + steps[i] for i in range(dims)]

    if dims < env.observation_dim:
        # Fill remaining dimensions with 0.
        # The number of missing dimensions is relative to the number of noise
        # dimensions (dims); len(env.noise_bounds) is the length of the
        # (lower, upper) pair and would give the wrong count.
        remaining = env.observation_dim - dims
        for _ in range(remaining):
            grid_lb.append(jnp.zeros_like(grid_lb[0]))
            grid_ub.append(jnp.zeros_like(grid_lb[0]))
    batched_grid_lb = jnp.stack(grid_lb, axis=1)  # stack on input dim
    batched_grid_ub = jnp.stack(grid_ub, axis=1)  # stack on input dim
    pmass = env.integrate_noise(grid_lb, grid_ub)
    return pmass, batched_grid_lb, batched_grid_ub


@partial(jax.jit, static_argnums=(0, 1))
def compute_expected_l(
    env, ibb_apply_fn, params, s, a, pmass, pmass_grid_lb, pmass_grid_ub
):
    """
    Jit-compiled kernel that computes, for every state/action pair in the batch,
    an upper bound on the expected value of L at the next state.

    The noise grid cells are shifted by the deterministic next state, pushed through
    the interval-bound network, and the resulting upper bounds are summed up weighted
    by the probability mass of the cells.
    """
    n_states = s.shape[0]
    n_cells = pmass_grid_lb.shape[0]
    state_dim = env.observation_dim

    # Shapes are arranged as (batch, cell, dim) so that adding the next states to
    # the noise cells broadcasts over both batch and cells (instead of using vmap)
    next_state = env.v_next(s, a).reshape((n_states, 1, state_dim))
    cells_lb = pmass_grid_lb.reshape((1, n_cells, state_dim)) + next_state
    cells_ub = pmass_grid_ub.reshape((1, n_cells, state_dim)) + next_state

    # Flatten batch and cell axes for the network call
    flat_lb = cells_lb.reshape((-1, state_dim))
    flat_ub = cells_ub.reshape((-1, state_dim))
    _, upper = ibb_apply_fn(params, [flat_lb, flat_ub])
    upper = upper.reshape((n_states, n_cells))

    # Probability masses are shared across the batch
    weights = pmass.reshape((1, n_cells))
    return jnp.sum(weights * upper, axis=1)


def _plot_box_outlines(ax, boxes, color):
    """Draws the closed rectangular outline (first two dimensions) of every box on ax."""
    for box in boxes:
        xs = [box.low[0], box.high[0], box.high[0], box.low[0], box.low[0]]
        ys = [box.low[1], box.low[1], box.high[1], box.high[1], box.low[1]]
        ax.plot(xs, ys, color=color, alpha=0.5, zorder=7)


def plot_policy(env, policy, filename, rsm=None, title=None):
    """
    Rolls out the policy in the environment and saves a plot of the first two
    observation dimensions to filename.

    The plot shows 50 rollout traces, their final points, and the outlines of the
    init (cyan), reward (yellow, only if env has _reward_boxes), unsafe (red) and
    target (green) boxes. If rsm is given and the environment is 2-dimensional,
    the values of rsm on a 50x50 grid over the observation space are drawn as a
    colored background.

    :param env: Environment (uses observation_dim, observation_space, v_reset, v_step,
        name, init_spaces, unsafe_spaces, target_spaces)
    :param policy: Object with apply_fn and params that maps observations to actions
    :param filename: Path the figure is saved to
    :param rsm: Optional object with apply_fn and params evaluated on the grid
    :param title: Optional plot title (defaults to env.name); return statistics
        (mean [min,max]) are appended
    """
    dims = env.observation_dim

    sns.set()
    fig, ax = plt.subplots(figsize=(6, 6))

    if dims == 2 and rsm is not None:
        # Evaluate the rsm on a regular grid over the observation space.
        # (observation_space holds the bounds; observation_dim is only the count.)
        axes_samples = [
            jnp.linspace(
                env.observation_space.low[i],
                env.observation_space.high[i],
                50,
                endpoint=False,
            )
            for i in range(dims)
        ]
        mesh = jnp.meshgrid(*axes_samples)
        # Flatten every coordinate array so that the grid has one point per row
        grid = jnp.stack([m.flatten() for m in mesh], axis=1)
        l = np.array(rsm.apply_fn(rsm.params, grid).flatten())
        sc = ax.scatter(grid[:, 0], grid[:, 1], marker="s", c=l, zorder=1, alpha=0.7)
        fig.colorbar(sc)

    n = 50
    rng = jax.random.PRNGKey(3)  # fixed seed -> reproducible plots
    rng, r = jax.random.split(rng)
    r = jax.random.split(r, n)
    state, obs = env.v_reset(r)
    done = jnp.zeros(n, dtype=jnp.bool_)
    total_returns = jnp.zeros(n)
    obs_list = []
    done_list = []
    # NOTE(review): the rollout stops as soon as ANY of the n episodes is done;
    # the done-masking below suggests "until all are done" may have been intended
    # - confirm against the environments' termination behavior.
    while not jnp.any(done):
        action = policy.apply_fn(policy.params, obs)
        rng, r = jax.random.split(rng)
        r = jax.random.split(r, n)
        state, new_obs, reward, new_done = env.v_step(state, action, r)
        # Only count rewards of episodes that are still running
        total_returns += reward * (1.0 - done)
        done_list.append(done)
        obs_list.append(obs)
        obs, done = new_obs, new_done
    obs_list = jnp.stack(obs_list, 1)
    done_list = jnp.stack(done_list, 1)
    # Per episode: keep only the observations recorded before it was done
    traces = [obs_list[i, jnp.logical_not(done_list[i])] for i in range(n)]

    if title is None:
        title = env.name

    title = (
        title
        + f" ({jnp.mean(total_returns):0.1f} [{jnp.min(total_returns):0.1f},{jnp.max(total_returns):0.1f}])"
    )
    ax.set_title(title)

    trace_color = sns.color_palette()[0]
    terminals_x, terminals_y = [], []
    for trace in traces:
        ax.plot(trace[:, 0], trace[:, 1], color=trace_color, zorder=2, alpha=0.15)
        ax.scatter(
            trace[:, 0],
            trace[:, 1],
            color=trace_color,
            zorder=2,
            marker=".",
            alpha=0.4,
        )
        terminals_x.append(float(trace[-1, 0]))
        terminals_y.append(float(trace[-1, 1]))
    ax.scatter(terminals_x, terminals_y, color="white", marker="x", zorder=5)

    _plot_box_outlines(ax, env.init_spaces, "cyan")
    if hasattr(env, "_reward_boxes"):
        # _reward_boxes holds (box, reward) pairs; only the boxes are drawn
        _plot_box_outlines(ax, [box for box, _ in env._reward_boxes], "yellow")
    _plot_box_outlines(ax, env.unsafe_spaces, "red")
    _plot_box_outlines(ax, env.target_spaces, "green")

    ax.set_xlim([env.observation_space.low[0], env.observation_space.high[0]])
    ax.set_ylim([env.observation_space.low[1], env.observation_space.high[1]])
    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)


if __name__ == "__main__":
    # Smoke test: compares the output of a freshly initialized MLP with the
    # interval bounds computed by the IBP version of the same network.

    learning_rate = 0.0005
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)
    layer_size = [64, 16, 5]
    model = MLP(layer_size)
    ibp_model = IBPMLP(layer_size)
    state = create_train_state(model, init_rng, 8, learning_rate)

    # Use the Lipschitz helpers defined in this module
    print("Lipschitz (L1): ", lipschitz_l1_jax(state.params))
    print("Lipschitz (Linf): ", lipschitz_linf_jax(state.params))
    del init_rng  # Must not be used anymore.

    fake_x = jax.random.uniform(
        jax.random.PRNGKey(0), shape=(1, 8), minval=-1, maxval=1
    )
    fake_y = model.apply(state.params, fake_x)
    # Small box around the input point
    fake_x_lb = fake_x - 0.01
    fake_x_ub = fake_x + 0.01
    print("fake_x\n", fake_x)
    print("fake_x_lb\n", fake_x_lb)
    print("fake_x_ub\n", fake_x_ub)
    print("#### output ####")
    print("Fake y\n", fake_y)
    # The IBP model is applied with the parameters of the ordinary MLP
    (fake_y_lb, fake_y_ub), mod_vars = ibp_model.apply(
        state.params, [fake_x_lb, fake_x_ub], mutable="intermediates"
    )
    print("Fake lb\n", fake_y_lb)
    print("Fake ub\n", fake_y_ub)

    print("diff", fake_y_ub - fake_y_lb)
    # print("sowed vars", mod_vars)

    # print("Params: ", state.params)
    # model = MLP([12, 8, 4])
    # batch = jnp.ones((32, 10))
    # variables = model.init(jax.random.PRNGKey(0), batch)
    # output = model.apply(variables, batch)