import colorsys
import tree

import gym
import numpy as np
from pycolab import rendering

from environments.empty_room import pycolab_maze
from environments.empty_room.pycolab_maze import PLACEABLES


# Size of the discrete action space exposed by MazeEnv.
NUM_ACTIONS = 4
# NOTE(review): not referenced anywhere in this module; presumably an
# episode-length cap used by callers -- confirm before relying on it.
DEFAULT_MAX_FRAMES = 50

# One hue per placeable symbol, evenly spaced over [0, 1) in the iteration
# order of PLACEABLES. Built with a comprehension so that no loop variables
# are left behind as module-level globals.
_NUM_PLACEABLES = len(PLACEABLES)
FIXED_HUES = {
    symbol: index / _NUM_PLACEABLES
    for index, symbol in enumerate(PLACEABLES.values())
}

# Fully saturated, full-brightness RGB triple (floats) for every hue.
_FIXED_COLOURS = tree.map_structure(
    lambda hue: colorsys.hsv_to_rgb(hue, 1., 1.), FIXED_HUES)

# map_structure descends into the RGB tuples, so each channel is scaled and
# truncated individually, giving integer channels in the 0-1000 range
# (MazeEnv rescales these by 255 / 1000 for 'rgb' observations).
FIXED_COLOURS = tree.map_structure(
    lambda channel: int(channel * 1000), _FIXED_COLOURS)

# Fixed greys for the two non-placeable board symbols.
FIXED_COLOURS[' '] = (100, 100, 100)
FIXED_COLOURS['#'] = (800, 800, 800)


def one_hot_converter(observation):
    """Turn a pycolab observation into a stacked per-symbol uint8 image.

    Three channels are produced, in this order: the wall symbol ``'#'``,
    ``PLACEABLES['player']`` and ``PLACEABLES['key']``. Each channel is the
    matching entry of ``observation.layers`` cast to ``uint8`` and multiplied
    by 255; a symbol with no layer contributes an all-zero plane shaped like
    the board.

    Args:
        observation: Object exposing ``layers`` (symbol -> array mapping) and
            ``board`` (an array with a two-element ``shape``).

    Returns:
        The channel planes stacked along axis 2.
    """
    masks_by_symbol = observation.layers
    height, width = observation.board.shape
    channel_symbols = ('#', PLACEABLES['player'], PLACEABLES['key'])

    def as_plane(symbol):
        # Fall back to an empty plane when the game has no layer for the symbol.
        if symbol in masks_by_symbol:
            mask = masks_by_symbol[symbol]
        else:
            mask = np.zeros((height, width))
        return mask.astype(np.uint8) * 255

    return np.stack([as_plane(symbol) for symbol in channel_symbols], axis=2)


class MazeEnv(gym.Env):
    """Gym environment wrapping the pycolab maze built by ``pycolab_maze``.

    Observations are ``uint8`` arrays produced by the converter selected via
    ``observation_type``; actions are integers from ``Discrete(NUM_ACTIONS)``.
    """

    def __init__(
        self,
        size,
        default_reward=0,
        observation_type='one_hot',
    ):
        """Create the environment and derive its spaces from a first frame.

        Args:
            size: Passed unchanged to ``pycolab_maze.make_game``.
            default_reward: Reward reported whenever the game engine returns
                ``None`` as the reward.
            observation_type: ``'one_hot'`` to use ``one_hot_converter`` or
                ``'rgb'`` to render the board with the colour table.

        Raises:
            KeyError: If ``observation_type`` is neither ``'rgb'`` nor
                ``'one_hot'``.
        """
        self._size = size
        self._num_actions = NUM_ACTIONS
        self._colours = FIXED_COLOURS.copy()
        self._default_reward = default_reward
        if observation_type == 'rgb':
            # Agents expect HWC uint8 observations, Pycolab uses CHW float observations.
            # The colour table stores channels on a 0-1000 scale; rescale to 0-255.
            colours = tree.map_structure(lambda c: float(c) * 255 / 1000, self._colours)
            self._observation_converter = rendering.ObservationToArray(
                    value_mapping=colours, permute=(1, 2, 0), dtype=np.uint8)
        elif observation_type == 'one_hot':
            self._observation_converter = one_hot_converter
        else:
            # Same exception type as before, now with a diagnosable message.
            raise KeyError(
                "Unknown observation_type {!r}; expected 'rgb' or 'one_hot'."
                .format(observation_type))

        # Start a throwaway episode so the observation shape can be measured.
        self._episode = self.make_episode()
        observation, _, _ = self._episode.its_showtime()
        self._image_shape = self._observation_converter(observation).shape
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=self._image_shape, dtype=np.uint8)
        self.action_space = gym.spaces.Discrete(self._num_actions)

    @property
    def colours(self):
        """Mapping from board symbol to its colour tuple (0-1000 per channel).

        This is the instance's own dict (a shallow copy of ``FIXED_COLOURS``
        taken at construction time), not a fresh copy per access.
        """
        return self._colours

    def make_episode(self):
        """Build a new, not-yet-started game for the configured size."""
        return pycolab_maze.make_game(self._size)

    def _process_outputs(self, observation, reward):
        """Convert a raw engine observation and fill in a missing reward.

        Args:
            observation: Raw observation returned by the game engine.
            reward: Reward returned by the game engine, possibly ``None``.

        Returns:
            ``(image, reward)`` where ``image`` is the converted observation
            and ``reward`` is ``default_reward`` if the engine gave ``None``.
        """
        if reward is None:
            reward = self._default_reward
        image = self._observation_converter(observation)
        return image, reward

    def reset(self):
        """Start a new episode and return its first converted observation."""
        self._episode = self.make_episode()
        observation, reward, _ = self._episode.its_showtime()
        # The initial reward is not part of the reset() return value.
        observation, _ = self._process_outputs(observation, reward)
        return observation

    def step(self, action):
        """Advance the current episode by one action.

        Args:
            action: Action forwarded to the game's ``play`` method.

        Returns:
            ``(observation, reward, done, info)`` where ``done`` is the
            game's ``game_over`` flag and ``info`` is always an empty dict.
        """
        # The engine's discount is not used by this wrapper.
        observation, reward, _ = self._episode.play(action)
        observation, reward = self._process_outputs(observation, reward)
        return observation, reward, self._episode.game_over, {}


if __name__ == '__main__':
    # Manual smoke test: play one episode with uniformly sampled actions,
    # printing every frame channel-first along with the step count and reward.
    demo_env = MazeEnv(size=9, observation_type='one_hot')
    frame = demo_env.reset()
    print(np.transpose(frame, (2, 0, 1)))
    steps_taken = 0
    while True:
        sampled_action = demo_env.action_space.sample()
        frame, reward, episode_over, _ = demo_env.step(sampled_action)
        print(np.transpose(frame, (2, 0, 1)))
        steps_taken += 1
        print(steps_taken, reward)
        # Stop on the first reward of 1 or as soon as the game reports it is over.
        if reward == 1 or episode_over:
            break
