import collections
import numpy as np

from dm_control.utils import rewards
from dm_control.rl import control
from dm_control.suite.cartpole import (
    _DEFAULT_TIME_LIMIT,
    get_model_and_assets,
    Physics as CartpolePhysics,
    Balance as BalanceTask,
    SUITE)


@SUITE.add()
def custom_swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT,
                          control_cost_weight=0.1,
                          random=None,
                          environment_kwargs=None):
    """Builds the sparse-reward Cartpole Swingup environment.

    Args:
      time_limit: Forwarded to `control.Environment` as `time_limit`.
      control_cost_weight: Weight of the control penalty, forwarded to
        `CustomBalanceTask`.
      random: Forwarded unchanged to the task (presumably a seed or
        `RandomState`, per dm_control convention — confirm).
      environment_kwargs: Optional mapping of extra keyword arguments for
        `control.Environment`; a falsy value means "no extras".

    Returns:
      A `control.Environment` driving a swing-up, sparse `CustomBalanceTask`.
    """
    model_and_assets = get_model_and_assets()
    cartpole_physics = CartpolePhysics.from_xml_string(*model_and_assets)
    swingup_task = CustomBalanceTask(
        control_cost_weight=control_cost_weight,
        swing_up=True,
        sparse=True,
        random=random,
    )
    extra_env_kwargs = environment_kwargs or {}
    return control.Environment(cartpole_physics,
                               swingup_task,
                               time_limit=time_limit,
                               **extra_env_kwargs)


@SUITE.add()
def custom_swingup_vision_sparse(time_limit=_DEFAULT_TIME_LIMIT,
                                 control_cost_weight=0.1,
                                 random=None,
                                 environment_kwargs=None):
    """Returns the pixel-observation, sparse-reward Cartpole Swingup task.

    Same reward as `custom_swingup_sparse`, but the task is
    `CustomBalanceVisionTask`, whose observations are a stack of rendered
    frames instead of the physical state.

    Args:
      time_limit: Forwarded to `control.Environment` as `time_limit`.
      control_cost_weight: Weight of the control penalty, forwarded to the
        task.
      random: Forwarded unchanged to the task (presumably a seed or
        `RandomState`, per dm_control convention — confirm).
      environment_kwargs: Optional mapping of extra keyword arguments for
        `control.Environment`; a falsy value means "no extras".

    Returns:
      A `control.Environment` driving a swing-up, sparse
      `CustomBalanceVisionTask`.
    """
    physics = CartpolePhysics.from_xml_string(*get_model_and_assets())
    task = CustomBalanceVisionTask(
        swing_up=True,
        sparse=True,
        control_cost_weight=control_cost_weight,
        random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(
        physics, task, time_limit=time_limit, **environment_kwargs)


class CustomBalanceTask(BalanceTask):
    """Cartpole task with a sparse "in-bounds" reward and a control penalty.

    Each step, the reward is the product of four `rewards.tolerance` terms
    minus `control_cost_weight` times the summed absolute control signal.
    The four terms cover cart position, pole angle cosine, cart velocity and
    pole angular velocity, each checked against a class-level range below.
    Only the sparse reward variant is supported.
    """

    # Bounds passed to `rewards.tolerance` with its default margin, so these
    # presumably act as hard in/out checks (per dm_control's `tolerance`
    # semantics — confirm before changing the call).
    _CART_RANGE = (-0.1, 0.1)
    # Stricter alternative kept for reference: (.995, 1).
    _ANGLE_COSINE_RANGE = (.95, 1)
    _CART_VELOCITY_RANGE = (-1.0, 1.0)
    _ANGLE_VELOCITY_RANGE = (-1.0, 1.0)

    def __init__(self,
                 control_cost_weight,
                 *args,
                 **kwargs):
        """Initializes the task.

        Args:
          control_cost_weight: Multiplier on the summed absolute control
            signal that is subtracted from the reward every step.
          *args: Forwarded to `BalanceTask.__init__`.
          **kwargs: Forwarded to `BalanceTask.__init__` (the factories in this
            module pass `swing_up`, `sparse` and `random`).
        """
        self._control_cost_weight = control_cost_weight
        super(CustomBalanceTask, self).__init__(*args, **kwargs)

    def _get_reward(self, physics, sparse):
        """Computes the in-bounds reward minus the control cost.

        Args:
          physics: The cartpole physics instance.
          sparse: Must be truthy; only the sparse reward is implemented.

        Returns:
          The reward as a Python float.

        Raises:
          ValueError: If `sparse` is falsy.
        """
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if not sparse:
            raise ValueError(
                f'CustomBalanceTask only supports sparse=True, '
                f'got sparse={sparse!r}.')

        in_bounds_terms = (
            rewards.tolerance(
                physics.cart_position(), self._CART_RANGE),
            rewards.tolerance(
                physics.pole_angle_cosine()[0], self._ANGLE_COSINE_RANGE),
            rewards.tolerance(
                physics.velocity()[0], self._CART_VELOCITY_RANGE),
            rewards.tolerance(
                physics.angular_vel()[0], self._ANGLE_VELOCITY_RANGE),
        )
        in_bounds_reward = float(np.prod(in_bounds_terms))

        # physics.control() may be a per-actuator array; summing the absolute
        # values keeps the penalty (and hence the reward) a scalar instead of
        # a length-1 array. For a single actuator the value is unchanged.
        control_magnitude = float(np.sum(np.abs(physics.control())))
        control_cost = self._control_cost_weight * control_magnitude

        return in_bounds_reward - control_cost


class CustomBalanceVisionTask(CustomBalanceTask):
    """`CustomBalanceTask` whose observations are a stack of rendered frames.

    `get_observation` returns an `OrderedDict` with keys `pixels-0` ...
    `pixels-{num_frames - 1}`, ordered oldest to newest. Each value is a
    `pixels_size` x `pixels_size` frame rendered from camera 0. Slots that
    have not received a rendered frame yet hold all-zero uint8 arrays of
    shape `(pixels_size, pixels_size, 3)`.
    """

    def __init__(self, *args, pixels_size=32, num_frames=3, **kwargs):
        """Initializes the task and an all-zero frame stack.

        Args:
          *args: Forwarded to `CustomBalanceTask.__init__`.
          pixels_size: Side length, in pixels, of each square rendered frame.
          num_frames: Number of most recent frames kept in the observation.
          **kwargs: Forwarded to `CustomBalanceTask.__init__`.

        Raises:
          ValueError: If `pixels_size` or `num_frames` is not positive.
        """
        if pixels_size <= 0:
            raise ValueError(
                f'pixels_size must be positive, got {pixels_size}.')
        if num_frames <= 0:
            raise ValueError(
                f'num_frames must be positive, got {num_frames}.')
        self._pixels_size = pixels_size
        # Bounded deque: appending a new frame evicts the oldest one.
        self._pixels_stack = collections.deque(maxlen=num_frames)
        self._reset_pixels_stack()
        super(CustomBalanceVisionTask, self).__init__(*args, **kwargs)

    def _reset_pixels_stack(self):
        """Replaces every slot in the frame stack with an all-zero frame."""
        self._pixels_stack.clear()
        blank_shape = (self._pixels_size, self._pixels_size, 3)
        for _ in range(self._pixels_stack.maxlen):
            self._pixels_stack.append(np.zeros(blank_shape, dtype=np.uint8))

    def _render_pixels(self, physics):
        """Renders a square frame of side `pixels_size` from camera 0."""
        return physics.render(
            width=self._pixels_size, height=self._pixels_size, camera_id=0)

    def initialize_episode(self, physics):
        """Initializes the episode and seeds the stack with one rendered frame.

        After this call, the stack holds `num_frames - 1` zero frames followed
        by the newly rendered frame.
        """
        result = super(CustomBalanceVisionTask, self).initialize_episode(physics)
        # NOTE(review): presumably stepped so the first rendered frame reflects
        # the freshly initialized state. This also advances the simulation by
        # one step; `physics.forward()` may be the intended call — confirm
        # before changing.
        physics.step()

        self._reset_pixels_stack()
        self._pixels_stack.append(self._render_pixels(physics))
        return result

    def after_step(self, physics):
        """Runs the parent hook, then pushes a newly rendered frame."""
        # super() starts from this class so that an `after_step` override on
        # CustomBalanceTask (if one is added) is not skipped.
        result = super(CustomBalanceVisionTask, self).after_step(physics)
        self._pixels_stack.append(self._render_pixels(physics))
        return result

    def get_observation(self, physics):
        """Returns the frame stack as an ordered dict, oldest frame first."""
        return collections.OrderedDict(
            (f'pixels-{i}', step_pixels)
            for i, step_pixels in enumerate(self._pixels_stack))
