import cv2
import numpy as np
from gymnasium import ObservationWrapper, spaces


class MiniworldObservationWrapper(ObservationWrapper):
    """Turn Miniworld RGB frames into a dict of single-channel grayscale images.

    Each observation becomes ``{"obs": ..., "state": ...}`` where:

    * ``"obs"``   is the wrapped environment's observation converted to
      grayscale;
    * ``"state"`` is the environment's ``render_top_view()`` image converted
      to grayscale.

    Both entries get a leading channel axis of size 1, giving shape
    ``(1, H, W)``, where ``H, W`` are the first two dimensions of the wrapped
    environment's observation space.
    """

    def __init__(self, env):
        super().__init__(env)

        rgb_space = self.observation_space
        height, width = rgb_space.shape[0], rgb_space.shape[1]

        def _gray_box():
            # Build a fresh Box per key so the two sub-spaces are independent
            # objects. Bounds are read from a single element, which assumes
            # the original Box has uniform low/high across all pixels and
            # channels.
            return spaces.Box(
                rgb_space.low[0, 0, 0],
                rgb_space.high[0, 0, 0],
                [1, height, width],
                dtype=rgb_space.dtype,
            )

        # NOTE(review): "state" reuses the observation's H, W. This assumes
        # render_top_view() returns an image of the same size as the
        # observation — TODO confirm against the Miniworld env in use.
        self.observation_space = spaces.Dict(
            {
                "obs": _gray_box(),
                "state": _gray_box(),
            }
        )

    def observation(self, obs):
        """Return the grayscale ``{"obs", "state"}`` dict for one frame.

        Args:
            obs: RGB image from the wrapped environment, shape ``(H, W, C)``.

        Returns:
            dict with ``"obs"`` (grayscale of *obs*) and ``"state"``
            (grayscale of the environment's top-down render), each of shape
            ``(1, H, W)`` and the same dtype as the input images.
        """
        # Call through ``unwrapped``. gymnasium deprecated (and in 1.0
        # removed) forwarding unknown attributes through wrapper layers, so
        # ``self.env.render_top_view`` fails when ``env`` is itself wrapped,
        # for example after ``gymnasium.make``.
        top_view = self.env.unwrapped.render_top_view()
        return {
            "obs": self._to_gray_channel_first(obs),
            "state": self._to_gray_channel_first(top_view),
        }

    @staticmethod
    def _to_gray_channel_first(rgb):
        """Convert an RGB ``(H, W, C)`` image to a ``(1, H, W)`` grayscale array."""
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        return np.expand_dims(gray, 0)
