import numpy as np
import tree
from gym import spaces


def rescale_values(values, old_low, old_high, new_low, new_high):
    """Linearly map `values` from one interval onto another, then clip.

    `old_low` is mapped to `new_low` and `old_high` to `new_high`; the mapped
    result is clipped to `[new_low, new_high]`, so inputs outside
    `[old_low, old_high]` saturate at the new bounds.

    Args:
        values: Scalar or array to rescale.
        old_low: Lower end of the source interval.
        old_high: Upper end of the source interval.
        new_low: Lower end of the target interval.
        new_high: Upper end of the target interval.

    Returns:
        The rescaled and clipped values.
    """
    # Relative position of each value inside the source interval.
    fraction = (values - old_low) / (old_high - old_low)
    target_span = new_high - new_low
    return np.clip(new_low + target_span * fraction, new_low, new_high)


def estimate_value_function(environment,
                            policy,
                            discount=1.0,
                            num_rollouts=100,
                            max_rollout_length=1000,
                            num_grid_points=25):
    """Estimate the value of `policy` on a state grid via Monte Carlo rollouts.

    A regular grid is laid over the `Box` observation space. From every grid
    state, `num_rollouts` rollouts are run and their discounted returns are
    averaged.

    Args:
        environment: Environment exposing `observation_space`,
            `action_space`, `reset(state)` and `step(action)`; `step` must
            return `(observation, reward, terminal, info)`.
        policy: Object whose `action(observation)` result has a `.numpy()`
            method.
        discount: Per-step discount factor applied to rewards.
        num_rollouts: Number of rollouts averaged per grid state.
        max_rollout_length: Maximum number of steps per rollout.
        num_grid_points: Number of grid points along each observation
            dimension. The grid has `num_grid_points ** D` states, so the
            total cost grows quickly with the dimensionality `D`.

    Returns:
        A tuple `(states, values)`: `states` has shape `(N, D)` and `values`
        has shape `(N, 1)`, where `N` is the number of grid states.

    Raises:
        NotImplementedError: If the observation space is not a `spaces.Box`.
    """
    observation_space = environment.observation_space
    if not isinstance(observation_space, spaces.Box):
        raise NotImplementedError(observation_space)

    # One row of evenly spaced coordinates per observation dimension.
    # NOTE(review): assumes `low`/`high` are 1-D arrays -- confirm for
    # multi-dimensional Box shapes.
    axes = np.linspace(
        observation_space.low,
        observation_space.high,
        num_grid_points,
    ).T
    grid = np.stack(np.meshgrid(*axes, indexing='ij'), axis=-1)
    # Take the column count from the grid instead of hard-coding 2, so that
    # observation spaces of any dimensionality are flattened correctly.
    states = np.reshape(grid, (-1, grid.shape[-1]))

    values = np.full((states.shape[0], num_rollouts), np.nan)
    for state_i, state in enumerate(states):
        print(f"state {state_i}/{len(states)}: {state}")
        for iteration in range(num_rollouts):
            print(f"iteration {iteration}/{num_rollouts}")
            discounted_rewards = []
            observation = environment.reset(state)
            for step_index in range(max_rollout_length):
                action = policy.action(observation).numpy().reshape(
                    environment.action_space.shape)
                observation, reward, terminal, _ = environment.step(action)
                discounted_rewards.append(
                    (discount ** float(step_index)) * reward)
                if terminal:
                    break

            values[state_i, iteration] = np.sum(discounted_rewards)

    # Average over rollouts; keep a trailing axis so values are (N, 1).
    values = np.mean(values, axis=-1)[..., None]

    return states, values


def generate_dataset(environment,
                     behavior_policy,
                     target_policy=None,
                     num_samples=1000,
                     max_path_length=None,
                     verbose=0.,
                     seed=1,
                     criteria=("MSE", ),
                     error_every=1,
                     eval_on_traces=False,
                     double_samples=False,
                     n_samples_eval=None,
                     independent_samples=True):
    """Sample transitions with `behavior_policy` and attach off-policy weights.

    Args:
        environment: Environment forwarded to `sample_environment`.
        behavior_policy: Policy used to collect samples; must also provide
            `probs(states, actions)` whose result has a `.numpy()` method.
        target_policy: Policy whose action probabilities form the numerator
            of the importance weights; must provide `probs(states, actions)`.
            Required, despite the `None` default.
        num_samples: Number of transitions to sample.
        max_path_length: Forwarded to `sample_environment`.
        verbose: Unused.
        seed: Unused.
        criteria: Unused.
        error_every: Unused.
        eval_on_traces: Not supported; must be falsy.
        double_samples: Not supported; must be falsy.
        n_samples_eval: Unused.
        independent_samples: Forwarded to `sample_environment`.

    Returns:
        A dict with two keys. `'samples'` holds the sampled transitions
        extended with `'rho'`, `'action_probability_target'` and
        `'action_probability_behavior'`. `'double_samples'` holds the (falsy)
        `double_samples` argument unchanged.

    Raises:
        NotImplementedError: If `double_samples` or `eval_on_traces` is
            truthy.
        ValueError: If `target_policy` is None.
    """
    # Validate up front so unsupported calls fail before any sampling is done.
    if double_samples:
        raise NotImplementedError("double_samples is not supported.")
    if eval_on_traces:
        raise NotImplementedError("eval_on_traces is not supported.")
    if target_policy is None:
        raise ValueError(
            "target_policy is required to compute off-policy weights.")

    print("Generate samples.")
    samples = sample_environment(
        environment,
        policy=behavior_policy,
        num_samples=num_samples,
        max_path_length=max_path_length,
        independent_samples=independent_samples)

    print("Generate off-policy weights.")
    action_probability_target = target_policy.probs(
        samples['state_0'], samples['action']).numpy()
    action_probability_behavior = behavior_policy.probs(
        samples['state_0'], samples['action']).numpy()
    # Importance ratio between the target and behavior action probabilities.
    rho = action_probability_target / action_probability_behavior
    samples.update({
        'rho': rho,
        'action_probability_target': action_probability_target,
        'action_probability_behavior': action_probability_behavior,
    })

    return {'samples': samples, 'double_samples': double_samples}


def sample_environment_independent_samples(environment,
                                           policy,
                                           num_samples=1000,
                                           max_path_length=None,
                                           independent_samples=True):
    """Collect single-step transitions, each starting from a fresh state.

    For every sample a state is drawn with
    `environment.uniform_random_state()`, the environment is reset to it,
    and exactly one policy action is executed.

    Args:
        environment: Environment exposing `uniform_random_state()`,
            `reset(state)`, `step(action)` and `action_space`.
        policy: Object whose `action(state)` result has a `.numpy()` method.
        num_samples: Number of transitions to collect.
        max_path_length: Unused in this function.
        independent_samples: Unused in this function.

    Returns:
        The per-transition dicts (keys `'state_0'`, `'action'`, `'state_1'`,
        `'reward'`, `'terminal'`, `'info'`) merged leaf-wise with `np.stack`,
        so each leaf gains a leading axis of length `num_samples`.
    """
    transitions = []
    for _ in range(num_samples):
        requested_state = environment.uniform_random_state()
        state_0 = environment.reset(requested_state)
        # The environment must start exactly at the requested state.
        assert np.all(state_0 == requested_state)

        raw_action = policy.action(state_0).numpy()
        action = raw_action.reshape(environment.action_space.shape)
        state_1, reward, terminal, info = environment.step(action)

        transition = {
            'state_0': np.atleast_1d(state_0),
            'action': np.atleast_1d(action),
            'state_1': np.atleast_1d(state_1),
            'reward': np.atleast_1d(reward),
            'terminal': np.atleast_1d(terminal),
            'info': info,
        }
        transitions.append(transition)

    return tree.map_structure(lambda *leaves: np.stack(leaves), *transitions)


# @memory.cache(hashfun={"mymdp": repr, "policy": repr}, ignore=["verbose"])
def sample_environment(environment,
                       policy,
                       num_samples=1000,
                       max_path_length=None,
                       independent_samples=True):
    """Sample `num_samples` transitions from `environment` under `policy`.

    With `independent_samples` set, the work is delegated to
    `sample_environment_independent_samples`. Otherwise transitions are
    collected along consecutive paths; the environment is reset whenever a
    path terminates or has reached `max_path_length` steps.

    Args:
        environment: Environment exposing `reset()` and `step(action)`;
            `step` must return `(state, reward, terminal, info)`.
        policy: Object whose `action(state)` result has a `.numpy()` method.
        num_samples: Total number of transitions to collect.
        max_path_length: Maximum number of steps per path. A falsy value
            (None or 0) falls back to `num_samples`, i.e. paths are only cut
            by termination.
        independent_samples: Whether to sample independent one-step
            transitions instead of sequential paths.

    Returns:
        The per-transition dicts (keys `'state_0'`, `'action'`, `'state_1'`,
        `'reward'`, `'terminal'`, `'info'`) merged leaf-wise with `np.stack`,
        so each leaf gains a leading axis of length `num_samples`.
    """
    if independent_samples:
        return sample_environment_independent_samples(
            environment,
            policy,
            num_samples=num_samples,
            max_path_length=max_path_length)

    max_path_length = max_path_length or num_samples

    samples = []
    # Start as "terminal" so the first iteration resets the environment.
    terminal = True
    path_length = 0

    for _ in range(num_samples):
        # `>=` (rather than `>`) caps each path at exactly `max_path_length`
        # steps instead of letting it run one step longer.
        if terminal or path_length >= max_path_length:
            state_0 = environment.reset()
            path_length = 0

        # NOTE(review): unlike sample_environment_independent_samples, the
        # action is not reshaped to `environment.action_space.shape` here --
        # confirm whether that difference is intentional.
        action = policy.action(state_0).numpy()
        state_1, reward, terminal, info = environment.step(action)

        samples.append({
            'state_0': np.atleast_1d(state_0),
            'action': np.atleast_1d(action),
            'state_1': np.atleast_1d(state_1),
            'reward': np.atleast_1d(reward),
            'terminal': np.atleast_1d(terminal),
            'info': info,
        })

        state_0 = state_1
        path_length += 1

    samples = tree.map_structure(lambda *x: np.stack(x), *samples)

    return samples
