# Copyright (c) Facebook, Inc. and its affiliates.

from collections import defaultdict
import gym
import numpy as np
import random
import re
import warnings
from functools import lru_cache
from tokenizers import BertWordPieceTokenizer
import hashlib
import queue
import threading


# Whitelists used by FilterMessagesWrapper, keyed by the name passed as its
# ``keep_messages`` argument. Each value is a set of exact message byte
# strings (compared after trailing NULs and spaces are stripped); any message
# not listed is zeroed out by that wrapper.
KEEP_MESSAGES = dict(
    wod={
        b"You see here a M wand.",
        b"What do you want to zap? [f or ?*]",
        b"f - a M wand.",
        b"You kill the minotaur!",
        b"You kill it!",
        b"In what direction?",
        b"Welcome to experience level 2.",
        b"You see here a minotaur corpse.",
        b"g - a minotaur corpse.",
    },
)


class FilterMessagesWrapper(gym.Wrapper):
    """Zero out every message that is not in a whitelist.

    After ``reset`` and each ``step``, the message bytes (trailing NUL bytes
    stripped, then trailing spaces stripped) are compared against the set
    ``KEEP_MESSAGES[keep_messages]``. If they are not an exact match,
    ``obs["message"]`` is replaced by an all-zero array with the same shape
    and dtype.

    Args:
        env: Environment to wrap.
        keep_messages: Key into ``KEEP_MESSAGES`` selecting the whitelist.

    Raises:
        NotImplementedError: If ``keep_messages`` is not a known key.
    """

    def __init__(self, env, keep_messages):
        if keep_messages not in KEEP_MESSAGES:
            raise NotImplementedError(f"Unknown keep messages {keep_messages}")
        # Initialise through gym.Wrapper (as the other wrappers in this module
        # do) rather than only assigning ``self.env``, so the attributes that
        # gym.Wrapper sets up are present.
        super().__init__(env)
        self.keep_messages = KEEP_MESSAGES[keep_messages]

    def reset(self, wizkit_items=None):
        obs = self.env.reset(wizkit_items=wizkit_items)
        self._filter_message(obs)
        return obs

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._filter_message(obs)
        return obs, reward, done, info

    def _filter_message(self, obs):
        """Blank ``obs["message"]`` in place unless it is whitelisted."""
        # Messages are NUL-padded fixed-size buffers; strip the padding (and
        # any trailing spaces before it) so they compare against plain text.
        message_bytes = obs["message"].tobytes().rstrip(b"\x00").rstrip(b" ")
        if message_bytes not in self.keep_messages:
            obs["message"] = np.zeros_like(obs["message"])


class HashMessagesWrapper(gym.Wrapper):
    """Replace each message with a short SHA-1 hex digest of its bytes.

    ``obs["message"]`` becomes a 256-element uint8 array. Its first
    ``hash_len`` entries are the ASCII hex digits of SHA-1 over the original
    message buffer (including any padding bytes), and the rest are zero.

    Args:
        env: Environment to wrap.
        hash_len: Number of hex digits to keep, from 0 to 40 (the length of a
            SHA-1 hex digest).

    Raises:
        ValueError: If ``hash_len`` is outside that range. Previously such a
            value failed later, on the first ``reset``/``step``, with a numpy
            broadcasting ``ValueError``.
    """

    _MESSAGE_LEN = 256
    _SHA1_HEX_LEN = 40

    def __init__(self, env, hash_len=10):
        if not 0 <= hash_len <= self._SHA1_HEX_LEN:
            raise ValueError(
                f"hash_len must be between 0 and {self._SHA1_HEX_LEN}, "
                f"got {hash_len}"
            )
        # Initialise through gym.Wrapper like the other wrappers in this module.
        super().__init__(env)
        self.hash_len = hash_len

    def reset(self, wizkit_items=None):
        obs = self.env.reset(wizkit_items=wizkit_items)
        self._hash_obs(obs)
        return obs

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._hash_obs(obs)
        return obs, reward, done, info

    def _hash_obs(self, obs):
        """Replace ``obs["message"]`` with its hashed encoding."""
        obs["message"] = self._hash_message(obs["message"].tobytes())

    def _hash_message(self, message):
        """Return a new uint8 array holding the truncated hex digest of ``message``.

        A fresh array is built on every call. Downstream code can therefore
        modify the observation in place without corrupting any cached value.
        """
        digest = self._hex_digest(message, self.hash_len)
        new_message = np.zeros((self._MESSAGE_LEN,), dtype=np.uint8)
        new_message[: len(digest)] = np.frombuffer(digest, dtype=np.uint8)
        return new_message

    @staticmethod
    @lru_cache(maxsize=10000)
    def _hex_digest(message, hash_len):
        """Return the first ``hash_len`` ASCII hex digits of SHA-1(``message``).

        The cache is keyed on the immutable inputs only, not on the wrapper
        instance, so it does not keep wrappers alive.
        """
        return hashlib.sha1(message).hexdigest()[:hash_len].encode("ascii")


class WordWrapper(gym.Wrapper):
    """Class used to tokenize minihack messages.

    Replaces ``obs["message"]`` with an int64 array of length
    ``max_message_len``. The array holds word-piece token ids from
    ``BertWordPieceTokenizer`` (lowercased input), zero-padded. The wrapper
    also adds ``obs["message_len"]``, a 1-element int64 array with the number
    of ids written.

    Args:
        env: Environment to wrap.
        vocab_file: Vocabulary file passed to ``BertWordPieceTokenizer``.
        max_message_len: Length of the token array. Longer encodings are
            truncated, with a warning.
        max_vocab_size: Stored on the instance; not used by this class.
        template: Not implemented; ``True`` raises ``NotImplementedError``.
        remove_brackets: If true, delete ``[...]`` spans (for example prompt
            option lists such as ``[f or ?*]``) before tokenizing.
    """

    def __init__(self, env, vocab_file, max_message_len=50, max_vocab_size=30522, template=False, remove_brackets=True):
        super().__init__(env)
        self.max_message_len = max_message_len
        self.max_vocab_size = max_vocab_size
        self.template = template
        self.remove_brackets = remove_brackets

        # One bracketed span: "[" then anything except "]" then "]".
        # The previous pattern excluded "()" instead of "]". Because it was
        # greedy, it swallowed the text between two separate "[...]" groups.
        self.brackets_regex = re.compile(r"\[[^\]]*\]")
        self.tokenizer = BertWordPieceTokenizer(vocab_file, lowercase=True)

        self.hash = random.randrange(10000000)  # In case lru_cache checks hash

        if self.template:
            raise NotImplementedError

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._tokenize_message(obs)
        return obs, reward, done, info

    def reset(self, **kwargs):
        # Forward keyword arguments (e.g. ``wizkit_items``) so this wrapper can
        # be stacked over the other wrappers in this module, whose ``reset``
        # accepts them. Calling ``reset()`` with no arguments behaves as before.
        obs = self.env.reset(**kwargs)
        self._tokenize_message(obs)
        return obs

    def _tokenize_message(self, obs):
        message_bytes = obs["message"].tobytes()
        obs["message"], obs["message_len"] = self.tokenize(message_bytes)

    def tokenize(self, message_bytes):
        """Return ``(tokens, tokens_len)`` for a raw, NUL-padded message buffer.

        Results are memoized. Each call returns fresh copies so that callers
        and downstream wrappers may modify them without corrupting the cache.
        """
        tokens, tokens_len = self._tokenize_cached(message_bytes)
        return tokens.copy(), tokens_len.copy()

    @lru_cache(maxsize=10000)
    def _tokenize_cached(self, message_bytes):
        """Uncopied, memoized implementation of ``tokenize``."""
        message = message_bytes.decode("utf-8").rstrip("\x00").lower()
        message = message.replace("  ", " ")
        if self.remove_brackets:
            message = self.brackets_regex.sub("", message)

        ids = self.tokenizer.encode(message).ids
        if len(ids) > self.max_message_len:
            warnings.warn(f"exceeded max message len with message {message}")
            ids = ids[: self.max_message_len]

        tokens = np.zeros(self.max_message_len, dtype=np.int64)
        tokens[: len(ids)] = ids
        tokens_len = np.array([len(ids)], dtype=np.int64)
        return tokens, tokens_len

    def __hash__(self):
        return self.hash


class CounterWrapper(gym.Wrapper):
    """Add a per-episode visit count to each observation under ``key``.

    ``state_counter`` selects what counts as "the same state":

    * ``"none"``: the wrapper is a pass-through and adds nothing.
    * ``"ones"``: every step reports a count of 1.
    * ``"coordinates"``: the key is ``(blstats[12], blstats[0], blstats[1])``.
    * ``"messages"``: the key is the raw message bytes.
    * ``"coordinates_messages"``: the coordinate key extended with the
      message bytes.

    Any other value raises ``NotImplementedError`` on the first ``step``.
    Counts are cleared on ``reset`` and at the end of an episode. ``reset``
    reports a count of 1.
    """

    def __init__(self, env, state_counter="none", key="state_visits"):
        self.state_counter = state_counter
        self.key = key
        if state_counter != "none":
            # Maps a state key (see ``_visit_key``) to its visits this episode.
            self.state_count_dict = defaultdict(int)
        super().__init__(env)

    def _visit_key(self, obs):
        """Return the dictionary key identifying the state in ``obs``."""
        mode = self.state_counter
        if mode == "coordinates" or mode == "coordinates_messages":
            blstats = obs["blstats"]
            location = (blstats[12], blstats[0], blstats[1])
            if mode == "coordinates":
                return location
            # Same place with a different message counts as a new state.
            return location + (obs["message"].tobytes(),)
        if mode == "messages":
            return obs["message"].tobytes()
        raise NotImplementedError("state_counter=%s" % self.state_counter)

    def step(self, action):
        step_return = self.env.step(action)
        if self.state_counter == "none":
            return step_return

        obs, _, done, _ = step_return
        if self.state_counter == "ones":
            visits = 1
        else:
            state_key = self._visit_key(obs)
            self.state_count_dict[state_key] += 1
            visits = self.state_count_dict[state_key]

        # ``obs`` is the same dict held in ``step_return``, so the tuple
        # returned below carries the new entry.
        obs[self.key] = np.array([visits])

        if done:
            self.state_count_dict.clear()
        return step_return

    def reset(self, wizkit_items=None):
        obs = self.env.reset(wizkit_items=wizkit_items)
        if self.state_counter == "none":
            return obs
        self.state_count_dict.clear()
        # The starting state is its own first visit.
        obs[self.key] = np.array([1])
        return obs


class CropWrapper(gym.Wrapper):
    """Add crops of screen arrays centred on the terminal cursor.

    For every name in ``keys``, observations from both ``reset`` and ``step``
    gain a ``<name>_crop`` entry. It is an ``h`` x ``w`` window of
    ``obs[name]`` centred on ``obs["tty_cursor"]`` (row, column). Cells outside
    the original array are filled with ``pad``.

    Args:
        env: Environment to wrap; must expose ``_actions``.
        h: Crop height; must be odd.
        w: Crop width; must be odd.
        pad: Fill value for out-of-bounds cells.
        keys: Observation entries to crop. Each must be a 2-D
            (rows, columns) array.
    """

    def __init__(self, env, h=9, w=9, pad=0, keys=("tty_chars", "tty_colors")):
        super().__init__(env)
        self.env = env
        self.h = h
        self.w = w
        self.pad = pad
        self.keys = keys
        # Odd sizes give the crop a single centre cell under the cursor.
        assert self.h % 2 == 1
        assert self.w % 2 == 1
        self.last_observation = None
        self._actions = self.env._actions

    def render(self, mode="human", crop=True):
        """Render the wrapped env, then print the cropped tty view.

        ``mode`` and ``crop`` are accepted but not used. The method reads
        ``tty_chars_crop``/``tty_colors_crop`` from the last observation, so
        it needs a prior ``reset`` with those keys being cropped.
        """
        self.env.render()
        obs = self.last_observation
        tty_chars_crop = obs["tty_chars_crop"]
        tty_colors_crop = obs["tty_colors_crop"]
        rendering = self.env.get_tty_rendering(
            tty_chars_crop, tty_colors_crop, print_guides=True
        )
        print(rendering)

    def _add_crops(self, obs):
        """Add a cursor-centred ``<key>_crop`` entry to ``obs`` for each key."""
        dh = self.h // 2
        dw = self.w // 2
        # Convert to Python ints: the cursor may be an unsigned numpy array,
        # and arithmetic on its scalars can wrap for large crop sizes.
        y, x = (int(v) for v in obs["tty_cursor"])
        for key in self.keys:
            # Pad rows by dh and columns by dw on both sides. A flat
            # ``(dw, dh)`` would instead mean "dw before, dh after" on every
            # axis, which misplaces the window whenever h != w.
            padded = np.pad(
                obs[key],
                pad_width=((dh, dh), (dw, dw)),
                mode="constant",
                constant_values=self.pad,
            )
            # Cell (y, x) sits at (y + dh, x + dw) in ``padded``, so rows
            # [y, y + h) and columns [x, x + w) are centred on the cursor.
            obs[key + "_crop"] = padded[y : y + self.h, x : x + self.w]

    def step(self, action):
        next_state, reward, done, info = self.env.step(action)
        self._add_crops(next_state)
        self.last_observation = next_state
        return next_state, reward, done, info

    def reset(self, wizkit_items=None):
        obs = self.env.reset(wizkit_items=wizkit_items)
        # Crop the initial observation exactly like later steps. Previously
        # reset returned all-zero placeholders, and only for
        # tty_chars/tty_colors, whatever ``keys`` was.
        self._add_crops(obs)
        self.last_observation = obs
        return obs


class PrevWrapper(gym.Wrapper):
    """Expose the previous step's reward and action inside each observation.

    Adds ``prev_reward`` (a 1-element float32 array) and ``prev_action``
    (a 1-element uint8 array). Both are zero right after ``reset``.
    """

    def __init__(self, env):
        super().__init__(env)
        self.env = env
        self.last_observation = None  # most recent observation handed out
        self._actions = self.env._actions

    def _remember(self, obs, prev_reward, prev_action):
        """Store the previous-step fields in ``obs`` and record it as the latest."""
        obs["prev_reward"] = prev_reward
        obs["prev_action"] = prev_action
        self.last_observation = obs
        return obs

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._remember(
            obs,
            np.array([reward], dtype=np.float32),
            np.array([action], dtype=np.uint8),
        )
        return obs, reward, done, info

    def reset(self, wizkit_items=None):
        return self._remember(
            self.env.reset(wizkit_items=wizkit_items),
            np.zeros(1, dtype=np.float32),
            np.zeros(1, dtype=np.uint8),
        )


def target(resetqueue, readyqueue):
    """Worker loop for the CachedEnvWrapper reset threads.

    Takes environments from ``resetqueue`` one at a time, calls ``reset()`` on
    each, and publishes ``(initial_obs, env)`` on ``readyqueue``. Returns when
    it receives the ``None`` sentinel. An exception raised by ``env.reset()``
    is not caught and ends this worker.
    """
    env = resetqueue.get()
    while env is not None:
        readyqueue.put((env.reset(), env))
        env = resetqueue.get()


class CachedEnvWrapper(gym.Env):
    """Hide ``reset`` latency by resetting spare environments on worker threads.

    One environment is active at a time. ``reset`` hands the active
    environment to a worker thread, which calls its ``reset()``, and switches
    to whichever environment a worker has finished resetting. Consecutive
    episodes may therefore come from different members of ``envs``. Other
    attribute lookups are forwarded to the active environment.

    Args:
        envs: Non-empty sequence of environments to cycle through.
        num_threads: Number of worker threads performing background resets;
            must be at least 1.

    Raises:
        ValueError: If ``envs`` is empty or ``num_threads`` is less than 1.
    """

    def __init__(self, envs, num_threads=2):
        # Validate before starting any worker. The workers are created without
        # daemon=True and block on the queue, so leaking them (e.g. when
        # ``envs[0]`` raised IndexError after they started) would keep the
        # interpreter from exiting.
        if len(envs) == 0:
            raise ValueError("CachedEnvWrapper requires at least one env")
        if num_threads < 1:
            # With no workers, ``reset`` would block forever.
            raise ValueError(f"num_threads must be >= 1, got {num_threads}")

        self._envs = envs
        # Set early so ``__getattr__`` forwarding is safe from here on.
        self._env = envs[0]

        # This could alternatively also use concurrent.futures. I hesitate to do
        # that as futures.wait would have me deal with sets all the time where they
        # are really not necessary.
        self._resetqueue = queue.SimpleQueue()
        self._readyqueue = queue.SimpleQueue()

        self._threads = [
            threading.Thread(
                target=target, args=(self._resetqueue, self._readyqueue)
            )
            for _ in range(num_threads)
        ]
        for t in self._threads:
            t.start()

        # Every env except the active one is reset in the background.
        for env in envs[1:]:
            self._resetqueue.put(env)

        self.observation_space = self._env.observation_space
        # gym.Env may provide a class-level ``action_space`` default, which
        # would shadow ``__getattr__`` forwarding. Copy it explicitly, as is
        # done for ``observation_space``.
        self.action_space = self._env.action_space

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Guard ``_env`` itself so that
        # an instance without it (e.g. mid-copy or unpickling, before
        # ``__init__`` ran) raises AttributeError instead of recursing forever.
        if name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)

    def reset(self):
        """Queue the active env for a background reset and switch to a ready one.

        Blocks until a worker has finished resetting some environment, then
        returns that environment's initial observation.
        """
        # NOTE(review): ``target`` does not catch exceptions from env.reset();
        # a failing reset ends that worker, and this call may then block.
        self._resetqueue.put(self._env)
        obs, self._env = self._readyqueue.get()
        return obs

    def step(self, action):
        return self._env.step(action)

    def close(self):
        """Stop and join the worker threads, then close every environment."""
        # One ``None`` sentinel per worker; each worker exits on the first one
        # it receives.
        for _ in self._threads:
            self._resetqueue.put(None)

        for t in self._threads:
            t.join()

        for env in self._envs:
            env.close()

    def seed(self, seed=None):
        # NOTE(review): only the active env is seeded. The other envs were
        # already reset in the background, and the next ``reset`` may switch
        # to one of them, so the seed may not affect the next episode.
        self._env.seed(seed)

    @property
    def unwrapped(self):
        """Return the currently active environment.

        This is a property, matching ``gym.Env.unwrapped``. As a plain method,
        ``wrapper.unwrapped`` evaluated to a bound method instead of an env.
        """
        return self._env

    def __str__(self):
        return "<CachedEnvWrapper envs=%s>" % [str(env) for env in self._envs]

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
        return False  # Propagate exception.
