# coding=utf-8
# Copyright 2019 The Hal Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Image based High Level tasks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import cPickle as pickle

import numpy as np
import tensorflow as tf
from gym import spaces

from hal.clevr_env import ClevrEnv
from hal.high_level_policy.high_level_env import HighLevelEnv
import hal.language_processing_utils.word_vectorization as wv
import hal.high_level_policy.visual_feature_extractor as vfe


# NOTE: `get_hl_env`, `max_starting` and `min_reward` are defined further down
# in this module; the last module-level assignment of each name is the one in
# effect when the wrappers below run.


def get_hl_env(
    env_name='Arrange', normal_dqn=False, session=None,
    low_level_duration=5, low_level_obs_type='image',
    low_level_action_type='discrete', top_down_view=False):
  """Builds a CLEVR-based high-level environment.

  Args:
    env_name: task name. Hierarchical mode (`normal_dqn=False`) accepts
      'Arrange'/'statement', 'Order'/'sort', 'Sort'/'2dsort' and 'Toy'/'toy';
      non-hierarchical mode accepts only the first three.
    normal_dqn: if False, wrap the environment in one of the image-based
      hierarchical wrappers (requires `session`); if True, delegate to
      `nhe.get_high_env`.
    session: session passed to the hierarchical wrappers.
    low_level_duration: low-level steps per high-level action.
    low_level_obs_type: observation type of the low-level env and policy.
    low_level_action_type: action type of the low-level env and policy.
    top_down_view: forwarded to `ClevrEnv`.

  Returns:
    The wrapped environment.

  Raises:
    ValueError: if `env_name` is unknown for the selected mode, or if
      `session` is missing in hierarchical mode.
  """
  hierarchical_wrappers = {
      'Arrange': ImageArrangeWrapper, 'statement': ImageArrangeWrapper,
      'Order': ImageOrderWrapper, 'sort': ImageOrderWrapper,
      'Sort': ImageSortWrapper, '2dsort': ImageSortWrapper,
      'Toy': ImageToyWrapper, 'toy': ImageToyWrapper,
  }
  flat_task_names = {
      'Arrange': 'Arrange', 'statement': 'Arrange',
      'Order': 'Order', 'sort': 'Order',
      'Sort': 'Sort', '2dsort': 'Sort',
  }

  # Validate cheaply before building the (expensive) low-level environment.
  known_names = flat_task_names if normal_dqn else hierarchical_wrappers
  if env_name not in known_names:
    raise ValueError('{} is not an existing environment'.format(env_name))
  if not normal_dqn and not session:
    # Explicit check rather than `assert`, which `python -O` strips.
    raise ValueError(
        'A session is required for the hierarchical high level environment.')

  env = ClevrEnv(
      num_object=5, num_objective=10, random_start=True, description_num=30,
      action_type=low_level_action_type, maximum_episode_steps=100,
      rew_shape=False, direct_obs=False, repeat_objective=False,
      fixed_objective=True, obs_type=low_level_obs_type, resolution=64,
      top_down_view=top_down_view)

  print('Creating image based high level environment...')

  if normal_dqn:
    # NOTE(review): `nhe` is not imported in the visible part of this module,
    # so this branch raises NameError unless `nhe` is bound elsewhere —
    # confirm the intended module and import it.
    return nhe.get_high_env(
        flat_task_names[env_name], env, action_type='discrete')

  print('Running hierarchical image based high level environemnt...')
  wrapper_cls = hierarchical_wrappers[env_name]
  return wrapper_cls(
      env, session, low_level_step=low_level_duration,
      low_level_obs_type=low_level_obs_type,
      low_level_action_type=low_level_action_type)


class ImageArrangeWrapper(HighLevelEnv):
  """Hierarchical wrapper with image observations and instruction actions.

  A high-level action indexes `self.text`; the selected instruction is given
  to the low-level policy (`get_low_level_action`) for `low_level_step` steps
  of the wrapped environment. Observations are the wrapped environment's
  images (`self.env.get_image_obs()`). The reward is based on how many of
  `self.objective_program` answers match `self.objective_goal`.
  """

  def __init__(
      self, env, sess, low_level_step=5, augmented_actions=False,
      low_level_obs_type='image', low_level_action_type='discrete'
  ):
    """Creates the wrapper.

    Args:
      env: low-level environment to wrap.
      sess: session forwarded to `HighLevelEnv`.
      low_level_step: low-level steps executed per high-level action.
      augmented_actions: accepted for interface compatibility but currently
        ignored; the action space covers only `self.text`.
      low_level_obs_type: low-level observation type. 'image' selects the
        low-level policy named 'model'; anything else selects 'target'.
      low_level_action_type: low-level action type, forwarded to
        `HighLevelEnv`.
    """
    low_level_name = 'model' if low_level_obs_type == 'image' else 'target'
    super(ImageArrangeWrapper, self).__init__(
        env,
        sess,
        self_attention_low_level=True,
        obs_type=low_level_obs_type,
        low_level_action_type=low_level_action_type,
        low_level_step=low_level_step,
        low_level_policy_name=low_level_name
    )

    action_num = len(self.text)
    # if augmented_actions: action_num += len(self.discrete_action_set)
    self.action_space = spaces.Discrete(action_num)
    self.curr_step = 0
    self.low_level_step = low_level_step
    print('Image hi-level: number of actions: {}'.format(len(self.text)))
    print('Image hi-level: number of low step: {}'.format(low_level_step))
    self.sample_episode_n = 20  # number of episodes run by `rollout`

  def step(self, action):
    """Runs instruction `self.text[action]` for `low_level_step` steps.

    Returns:
      (image observation, reward, done, None).
    """
    text = self.text[action]
    for _ in range(self.low_level_step):
      low_level_action = self.get_low_level_action(text)
      self.env.step(low_level_action)
    rew = self._reward()
    self.curr_step += 1
    # NOTE(review): the counter is incremented before the limit check, so an
    # episode ends after max_episode_steps - 1 high-level steps, whereas
    # HighLevelWrapper.step checks before incrementing — confirm intent.
    done = self._is_done(rew) or self.curr_step+1 >= self.max_episode_steps
    return self.env.get_image_obs(), rew, done, None

  def _objective_gap(self):
    """Returns (#satisfied objectives - #objectives); 0 when all hold."""
    ans = [self.get_answer(question) for question in self.objective_program]
    satisfied_goal = np.logical_not(np.logical_xor(ans, self.objective_goal))
    return np.sum(satisfied_goal.astype(float)) - len(self.objective_program)

  def _reward(self):
    """Objective gap, replaced by -10 when below the module's `min_reward`."""
    reward = self._objective_gap()
    return -10. if reward < min_reward else reward

  def _complete(self):
    """Unclipped objective gap; used by `reset` to screen start states."""
    return self._objective_gap()

  def _is_done(self, reward):
    return reward == 0.

  def reset(self, max_reset=50):
    """Resets, re-sampling (up to `max_reset` times) while the start score
    `_complete()` is above the module's `max_starting`.

    Returns:
      The image observation of the accepted start state.
    """
    _ = self.env.reset()
    r = self._complete()
    rep = 0
    while r > max_starting and rep < max_reset:
      _ = self.env.reset()
      r = self._complete()
      rep += 1
    self.curr_step = 0
    return self.env.get_image_obs()

  def rollout(self, act, exp_file_name):
    """Runs `sample_episode_n` episodes (up to 60 high-level steps) with `act`.

    Frames are collected in a local list only; video saving is disabled, so
    `exp_file_name` is currently unused.
    """
    print('rolling out...')
    all_frames = []
    for _ in range(self.sample_episode_n):
      obs = self.reset(5)
      black_scene = self.render(mode='rgb_array')*0.0
      for _ in range(10):
        all_frames.append(black_scene)
      all_frames.append(self.render(mode='rgb_array'))
      for _ in range(60):
        action = act(np.array(obs)[None])
        # `act` may return a Python int, a NumPy integer scalar or a batch of
        # size one. Indexing a NumPy scalar raises, so only unwrap batches.
        if not isinstance(action, (int, np.integer)):
          action = action[0]
        if action < len(self.text):
          text = self.text[action]
          actual_text = self.decode_fn(text)
          # print(actual_text)
          for _ in range(self.low_level_step):
            low_level_action = self.get_low_level_action(text)
            # print(low_level_action)
            obs, _, _, _ = self.env.step(low_level_action)
            frame = self.render(mode='rgb_array')
            # all_frames.append(add_text(frame, actual_text))
        else:
          self.env.step_perfect(action-len(self.text))
          all_frames.append(self.render(mode='rgb_array'))
        obs = self.env.get_image_obs()
        if self._is_done(self._reward()):
          final_frame = self.render(mode='rgb_array')
          # final_frame = add_text(final_frame, 'success', False)
          for _ in range(10):
            all_frames.append(final_frame)
          break
    # save_video(np.uint8(all_frames), exp_file_name, fps=10.0)
    print('finish roll out...')


class ImageOrderWrapper(ImageArrangeWrapper):
  """Image-based ordering task: objects in a row of decreasing x.

  Consecutive objects (in `self.obj_name` order) form an ordered pair when the
  later object's x is more than 0.01 below the earlier one's and their y
  coordinates differ by less than 0.4.
  """

  def _sorted_pairs(self):
    """Yields, per consecutive object pair, whether that pair is ordered."""
    positions = np.array(
        [self.get_body_com(name) for name in self.obj_name])
    for before, after in zip(positions[:-1], positions[1:]):
      yield after[0] < before[0]-0.01 and abs(after[1]-before[1]) < 0.4

  def _reward(self):
    """Returns 0. when every consecutive pair is ordered, else -10."""
    if all(self._sorted_pairs()):
      return 0.
    return -10.

  def _complete(self):
    """Returns 10 when more than one pair is ordered, else -10."""
    ordered_count = sum(1. for ok in self._sorted_pairs() if ok)
    return 10 if ordered_count > 1 else -10

  def _is_done(self, reward):
    """The episode ends once the ordering reward reaches 0."""
    return reward == 0.


class ImageSortWrapper(ImageArrangeWrapper):
  """Image-based 2D arrangement of objects 0-3 around object 4.

  `constraint_count` scores how many of objects 0-3 satisfy their positional
  constraints; the task is solved when all four do.
  """

  # Per constrained object: (object index, ((axis, other object, direction),
  # ...)). Direction 1 requires pos[obj][axis] > pos[other][axis] + margin;
  # direction -1 requires pos[obj][axis] < pos[other][axis] - margin.
  _SORT_CONSTRAINTS = (
      (0, ((1, 4, 1), (0, 3, 1), (0, 1, -1))),
      (1, ((0, 4, 1), (1, 2, 1), (1, 0, -1))),
      (2, ((1, 4, -1), (0, 3, 1), (0, 1, -1))),
      (3, ((0, 4, -1), (1, 2, 1), (1, 0, -1))),
  )

  def constraint_count(self):
    """Returns how many of objects 0-3 are placed correctly (float, 0.-4.).

    An object counts when all of its ordering constraints hold and its
    Euclidean distance to object 4 is below 0.5.
    """
    max_dist = 0.5
    margin = 0.01
    positions = np.array([self.get_body_com(name) for name in self.obj_name])
    center = positions[4]
    total = 0.
    for obj, checks in self._SORT_CONSTRAINTS:
      pos = positions[obj]
      ordered = all(
          (pos[axis] > positions[other][axis] + margin) if direction > 0
          else (pos[axis] < positions[other][axis] - margin)
          for axis, other, direction in checks)
      near_center = np.linalg.norm(center - pos) < max_dist
      total += float(ordered and near_center)
    return total

  def _reward(self):
    """Returns 0. once all four constraints hold, else -10."""
    if self.constraint_count() >= 4.0:
      return 0.
    return -10.

  def _complete(self):
    """Returns 10 when more than three constraints hold, else -10."""
    satisfied = self.constraint_count()
    return 10 if satisfied > 3.0 else -10

  def _is_done(self, reward):
    """The episode ends once the arrangement reward reaches 0."""
    return reward == 0.


class ImageToyWrapper(ImageArrangeWrapper):
  """Toy image-based task constraining object 0's x relative to objects 1-3.

  Solved when object 0's x is more than 0.2 below the x of objects 1 and 2 and
  more than 0.2 above the x of object 3. Object 4 is not constrained.
  """

  def _reward(self):
    """Returns 0. when the x-constraints on object 0 hold, else -10."""
    state = np.array([self.get_body_com(name) for name in self.obj_name])
    # Indexes all five objects; the fifth one is unused.
    anchor, first, second, third, _ = [state[k] for k in range(5)]
    anchor_x = anchor[0]
    placed = (anchor_x < first[0] - 0.2 and
              anchor_x < second[0] - 0.2 and
              anchor_x > third[0] + 0.2)
    if placed:
      return 0.
    return -10.

  def _complete(self):
    """Returns 10. when `_reward()` is 0., else -10."""
    if self._reward() == 0.:
      return 10.
    return -10

  def _is_done(self, reward):
    """The episode ends once the toy reward reaches 0."""
    return reward == 0.


# `reset` methods re-sample the start state (up to `max_reset` times) while
# `_complete()` is above this value.
max_starting = -8
# Objective-based `_reward` methods replace rewards below this value with -10.
min_reward = 0

# Vocabulary stored next to this module in `assets/vocab.txt`. The directory
# of `__file__` is used: joining onto the .py file itself (`file.py/..`)
# fails on POSIX because the file is not a directory.
vocab_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'assets', 'vocab.txt')

vocab_list = wv.load_vocab_list(vocab_path)
_, i2v = wv.create_look_up_table(vocab_list)
decode_fn = wv.decode_with_lookup_table(i2v)

# Path of the pickled instruction set read by `DiverseColorSort.__init__`;
# must be set before that environment is constructed.
diverse_instruction_path = None


# Fixed RGB projection kernels, one per output channel of `hard_coded_conv`.
# Each row is a length-3 kernel for `one_by_one_conv` and has (approximately)
# unit L2 norm.
color_filter = [
    [0.57735027, -0.57735027, -0.57735027],
    [-0.57735027, -0.57735027, 0.57735027],
    [-0.57735027, 0.57735027, -0.57735027],
    [-0.98162999, 0.163605, 0.098163],
    [0.06311944, -0.9467916, 0.3155972],
]


def one_by_one_conv(img, kernel):
  """Applies a single bias-free 1x1 convolution over the last (channel) axis.

  Args:
    img: array whose trailing axis holds 3 channels.
    kernel: length-3 sequence of channel weights.

  Returns:
    Array of shape `img.shape[:-1]` holding the weighted channel sum per
    pixel.
  """
  spatial_shape = img.shape[:-1]
  pixels = img.reshape([-1, 3])
  weights = np.array(kernel)
  responses = weights.dot(pixels.T)
  return responses.reshape(spatial_shape)


def hard_coded_conv(img):
  """Projects an RGB image onto every kernel in `color_filter`.

  Each projection is rectified (negative responses set to 0) and divided by
  0.01; the results are stacked along a new trailing channel axis, one
  channel per kernel.
  """
  channels = []
  for kernel in color_filter:
    response = one_by_one_conv(img, kernel)
    rectified = np.where(response < 0, 0., response)
    channels.append(rectified / 0.01)
  return np.stack(channels, axis=-1)


class HighLevelWrapper(object):
  """Flat (non-hierarchical) task wrapper around a low-level environment.

  Actions are forwarded unchanged to `env.step`; reward and termination come
  from the wrapper's `_reward` / `_is_done`. Attribute lookups that fail on
  the wrapper are delegated to the wrapped environment.
  """

  def __init__(self, env, obs_type, sess=None):
    """Creates the wrapper.

    Args:
      env: environment to wrap.
      obs_type: 'direct' returns `env.get_direct_obs()`; 'image' returns
        `vfe.FilmFeatureExtractor` features of `env.get_image_obs()`.
      sess: session handed to the feature extractor (image mode only).
    """
    self.env = env
    self.curr_step = 0
    self.obs_type = obs_type
    self.sess = sess
    if self.obs_type == 'image':
      self.extractor = vfe.FilmFeatureExtractor(sess)
      self.observation_space = spaces.Box(
          low=-1.0, high=1.0, shape=[2*96], dtype=np.float32)

  def __getattr__(self, attr):
    """Delegates attributes not found on the wrapper to the wrapped env."""
    # Only reached when normal lookup fails. If `env` itself is missing (an
    # instance created via __new__, or copy/pickle probing `__setstate__`
    # before __init__ ran), evaluating `self.env` here would recurse forever.
    if attr == 'env':
      raise AttributeError(attr)
    return getattr(self.env, attr)

  def get_obs(self):
    """Returns the observation for `obs_type` (None for any other type)."""
    if self.obs_type == 'direct':
      return self.env.get_direct_obs()
    elif self.obs_type == 'image':
      return self.extractor.get_feature(self.env.get_image_obs())

  def step(self, a):
    """Forwards `a` to the wrapped env and scores the result.

    Returns:
      (observation, reward, done, None). `done` is also set once
      `max_episode_steps` steps have been taken.
    """
    self.env.step(a)
    rew = self._reward()
    done = self._is_done(rew) or self.curr_step+1 >= self.max_episode_steps
    self.curr_step += 1
    return self.get_obs(), rew, done, None

  def _objective_gap(self):
    """Returns (#satisfied objectives - #objectives); 0 when all hold."""
    ans = [self.get_answer(question) for question in self.objective_program]
    satisfied_goal = np.logical_not(np.logical_xor(ans, self.objective_goal))
    return np.sum(satisfied_goal.astype(float)) - len(self.objective_program)

  def _reward(self):
    """Objective gap, replaced by -10 when below the module's `min_reward`."""
    reward = self._objective_gap()
    return -10. if reward < min_reward else reward

  def _complete(self):
    """Unclipped objective gap; used by `reset` to screen start states."""
    return self._objective_gap()

  def _is_done(self, reward):
    return reward == 0.

  def reset(self, max_reset=10):
    """Resets, re-sampling (up to `max_reset` times) while the start score
    `_complete()` is above the module's `max_starting`.

    Returns:
      The observation of the accepted start state.
    """
    self.env.reset()
    r = self._complete()
    rep = 0
    while r > max_starting and rep < max_reset:
      self.env.reset()
      r = self._complete()
      rep += 1
    self.curr_step = 0
    return self.get_obs()

  def rollout(self, act, exp_file_name):
    """Runs 10 episodes (up to 60 steps each) with policy `act`.

    Frames are collected in a local list only; video saving is disabled, so
    `exp_file_name` is currently unused.
    """
    all_frames = []
    for _ in range(10):
      obs = self.reset(5)
      black_scene = self.render(mode='rgb_array')*0.0
      for _ in range(10):
        all_frames.append(black_scene)
      all_frames.append(self.render(mode='rgb_array'))
      for _ in range(60):
        action = act(np.array(obs)[None])
        # `act` may return a Python int, a NumPy integer scalar or a batch of
        # size one. Indexing a NumPy scalar raises, so only unwrap batches.
        if not isinstance(action, (int, np.integer)):
          action = action[0]
        obs, _, _, _ = self.step(action)
        all_frames.append(self.render(mode='rgb_array'))
        if self._is_done(self._reward()):
          final_frame = self.render(mode='rgb_array')
          for _ in range(10):
            all_frames.append(final_frame)
          break
    # save_video(np.uint8(all_frames), exp_file_name)


class HighLevelSortWrapper(HighLevelWrapper):
  """Flat ordering task: objects in a row of decreasing x.

  Consecutive objects (in `self.obj_name` order) form an ordered pair when the
  later object's x is more than 0.01 below the earlier one's and their y
  coordinates differ by less than 0.4.
  """

  def _sorted_pairs(self):
    """Yields, per consecutive object pair, whether that pair is ordered."""
    positions = np.array(
        [self.get_body_com(name) for name in self.obj_name])
    for before, after in zip(positions[:-1], positions[1:]):
      yield after[0] < before[0]-0.01 and abs(after[1]-before[1]) < 0.4

  def _reward(self):
    """Returns 0. when every consecutive pair is ordered, else -10."""
    if all(self._sorted_pairs()):
      return 0.
    return -10.

  def _complete(self):
    """Returns 10 when more than one pair is ordered, else -10."""
    ordered_count = sum(1. for ok in self._sorted_pairs() if ok)
    return 10 if ordered_count > 1 else -10

  def _is_done(self, reward):
    """The episode ends once the ordering reward reaches 0."""
    return reward == 0.


class HighLevel2DSortWrapper(HighLevelWrapper):
  """Flat 2D arrangement of objects 0-3 around object 4.

  `constraint_count` scores how many of objects 0-3 satisfy their positional
  constraints; the task is solved when all four do.
  """

  # Per constrained object: (object index, ((axis, other object, direction),
  # ...)). Direction 1 requires pos[obj][axis] > pos[other][axis] + margin;
  # direction -1 requires pos[obj][axis] < pos[other][axis] - margin.
  _SORT_CONSTRAINTS = (
      (0, ((1, 4, 1), (0, 3, 1), (0, 1, -1))),
      (1, ((0, 4, 1), (1, 2, 1), (1, 0, -1))),
      (2, ((1, 4, -1), (0, 3, 1), (0, 1, -1))),
      (3, ((0, 4, -1), (1, 2, 1), (1, 0, -1))),
  )

  def constraint_count(self):
    """Returns how many of objects 0-3 are placed correctly (float, 0.-4.).

    An object counts when all of its ordering constraints hold and its
    Euclidean distance to object 4 is below 0.5.
    """
    max_dist = 0.5
    margin = 0.01
    positions = np.array([self.get_body_com(name) for name in self.obj_name])
    center = positions[4]
    total = 0.
    for obj, checks in self._SORT_CONSTRAINTS:
      pos = positions[obj]
      ordered = all(
          (pos[axis] > positions[other][axis] + margin) if direction > 0
          else (pos[axis] < positions[other][axis] - margin)
          for axis, other, direction in checks)
      near_center = np.linalg.norm(center - pos) < max_dist
      total += float(ordered and near_center)
    return total

  def _reward(self):
    """Returns 0. once all four constraints hold, else -10."""
    if self.constraint_count() >= 4.0:
      return 0.
    return -10.

  def _complete(self):
    """Returns 10 when more than three constraints hold, else -10."""
    satisfied = self.constraint_count()
    return 10 if satisfied > 3.0 else -10

  def _is_done(self, reward):
    """The episode ends once the arrangement reward reaches 0."""
    return reward == 0.


class HighLevelColorSortWrapper(HighLevelWrapper):
  """Flat task: sort objects left-to-right by colour.

  Objects are read in ascending x; their colours must be non-decreasing
  under red < green < blue < cyan < purple.
  """

  def __init__(self, env, obs_type, sess=None):
    """Sets the same fields as `HighLevelWrapper.__init__` without calling it.

    In image mode the extractor is `vfe.DiverseGTFeatureExtractor` fed by
    `env.get_order_invariant_obs`, and the observation size comes from the
    extractor's `obs_shape`.
    """
    # `env` is assigned first: attribute misses are forwarded to it.
    self.env = env
    self.curr_step = 0
    self.obs_type = obs_type
    self.sess = sess
    if obs_type != 'image':
      return
    self.extractor = vfe.DiverseGTFeatureExtractor(
        sess, 10, state_fn=env.get_order_invariant_obs)
    self.observation_space = spaces.Box(
        low=-1.0, high=1.0, shape=[self.extractor.obs_shape],
        dtype=np.float32)

  def _color_keys_left_to_right(self):
    """Returns each object's colour rank, ordered by ascending x."""
    positions = np.array([self.get_body_com(name) for name in self.obj_name])
    palette = [self.scene_graph[i]['color'] for i in range(len(self.obj_name))]
    left_to_right = sorted(range(len(positions)),
                           key=lambda k: positions[k][0])
    rank = {'red': 0, 'green': 1, 'blue': 2, 'cyan': 3, 'purple': 4}
    return [rank[palette[k]] for k in left_to_right]

  def _reward(self):
    """Returns 0. when the colour sequence is fully sorted, else -10."""
    keys = self._color_keys_left_to_right()
    in_order = all(a <= b for a, b in zip(keys[:-1], keys[1:]))
    return 0. if in_order else -10.

  def _complete(self):
    """Returns 10 when more than three adjacent pairs are sorted, else -10."""
    keys = self._color_keys_left_to_right()
    sorted_pairs = sum(1. for a, b in zip(keys[:-1], keys[1:]) if a <= b)
    return 10 if sorted_pairs > 3 else -10

  def _is_done(self, reward):
    """The episode ends once the sorting reward reaches 0."""
    return reward == 0.


class HighLevelShapeSortWrapper(HighLevelColorSortWrapper):
  """Flat task: sort objects left-to-right by shape.

  Objects are read in ascending x; their shapes must be non-decreasing
  under sphere < cube < cylinder.
  """

  def _shape_keys_left_to_right(self):
    """Returns each object's shape rank, ordered by ascending x."""
    positions = np.array([self.get_body_com(name) for name in self.obj_name])
    kinds = [self.scene_graph[i]['shape'] for i in range(len(self.obj_name))]
    left_to_right = sorted(range(len(positions)),
                           key=lambda k: positions[k][0])
    rank = {'sphere': 0, 'cube': 1, 'cylinder': 2}
    return [rank[kinds[k]] for k in left_to_right]

  def _reward(self):
    """Returns 0. when the shape sequence is fully sorted, else -10."""
    keys = self._shape_keys_left_to_right()
    in_order = all(a <= b for a, b in zip(keys[:-1], keys[1:]))
    return 0. if in_order else -10.

  def _complete(self):
    """Returns 10 when more than three adjacent pairs are sorted, else -10."""
    keys = self._shape_keys_left_to_right()
    sorted_pairs = sum(1. for a, b in zip(keys[:-1], keys[1:]) if a <= b)
    return 10 if sorted_pairs > 3 else -10

  def _is_done(self, reward):
    """The episode ends once the sorting reward reaches 0."""
    return reward == 0.


class HighLevelColorShapeSortWrapper(HighLevelColorSortWrapper):
  """Flat task: sort left-to-right by colour, breaking colour ties by shape.

  Colour rank: red < green < blue < cyan < purple. Shape rank:
  sphere < cube < cylinder.
  """

  def _color_shape_pair_flags(self):
    """Returns, per adjacent pair in ascending-x order, whether it is sorted."""
    positions = np.array([self.get_body_com(name) for name in self.obj_name])
    kinds = [self.scene_graph[i]['shape'] for i in range(len(self.obj_name))]
    palette = [self.scene_graph[i]['color'] for i in range(len(self.obj_name))]
    left_to_right = sorted(range(len(positions)),
                           key=lambda k: positions[k][0])
    shape_rank = {'sphere': 0, 'cube': 1, 'cylinder': 2}
    color_rank = {'red': 0, 'green': 1, 'blue': 2, 'cyan': 3, 'purple': 4}
    shape_keys = [shape_rank[kinds[k]] for k in left_to_right]
    color_keys = [color_rank[palette[k]] for k in left_to_right]
    keys = list(zip(color_keys, shape_keys))
    # Lexicographic tuple comparison: a strictly smaller colour orders the
    # pair, a larger colour does not, and equal colours fall back to
    # shape rank (<=).
    return [a <= b for a, b in zip(keys[:-1], keys[1:])]

  def _reward(self):
    """Returns 0. when every adjacent pair is sorted, else -10."""
    return 0. if all(self._color_shape_pair_flags()) else -10.

  def _complete(self):
    """Returns 10 when more than three adjacent pairs are sorted, else -10."""
    sorted_pairs = sum(1. for ok in self._color_shape_pair_flags() if ok)
    return 10 if sorted_pairs > 3 else -10

  def _is_done(self, reward):
    """The episode ends once the sorting reward reaches 0."""
    return reward == 0.


#==============================================================================


class DiscreteHierarchicalEnv(HighLevelEnv):
  """Hierarchical environment over a discrete set of language instructions.

  Actions below `len(self.text)` run the corresponding instruction through
  the low-level policy for `low_level_step` steps. When `augmented_actions`
  is set, the remaining actions call `env.step_perfect`. Observations are
  `vfe.FilmFeatureExtractor` features of the wrapped environment's image.
  """

  def __init__(
      self, env, sess, self_attention_low_level=False,
      augmented_actions=False, low_level_obs_type='order_invariant',
      low_level_action_type='perfect', low_level_step=5
  ):
    """Creates the environment.

    Args:
      env: low-level environment to wrap.
      sess: session for the feature extractor and `HighLevelEnv`.
      self_attention_low_level: forwarded to `HighLevelEnv`.
      augmented_actions: if True, append `len(self.perfect_action_set)`
        actions that call `env.step_perfect`.
      low_level_obs_type: forwarded to `HighLevelEnv`.
      low_level_action_type: forwarded to `HighLevelEnv`.
      low_level_step: low-level steps executed per instruction.
    """

    # self.img_ph = tf.placeholder(tf.float32, [1, 64, 64, 5])
    # self.spsm = tf.contrib.layers.spatial_softmax(self.img_ph, trainable=False)
    # temperatures = []
    # for var in tf.global_variables():
    #   if 'temperature' in var.name: temperatures.append(var)
    # print('temperature variables: {}'.format(temperatures))
    # sess.run(tf.initialize_variables(temperatures))

    self.extractor = vfe.FilmFeatureExtractor(sess)
    self.observation_space = spaces.Box(
        low=-1.0, high=1.0, shape=[2*96], dtype=np.float32)

    super(DiscreteHierarchicalEnv, self).__init__(
        env, sess, self_attention_low_level=self_attention_low_level,
        obs_type='direct',
        low_level_action_type=low_level_action_type,
        low_level_obs_type=low_level_obs_type
    )

    # random.shuffle(self.text)
    action_num = len(self.text)
    if augmented_actions: action_num += len(self.perfect_action_set)
    self.action_space = spaces.Discrete(action_num)
    self.curr_step = 0
    self.low_level_step = low_level_step

    print('DiscreteHierarchicalEnv number of actions {}'.format(len(self.text)))
    print('DiscreteHierarchicalEnv number of perfect actions {}'.format(
        len(self.perfect_action_set)))
    print('DiscreteHierarchicalEnv low level step {}'.format(
        self.low_level_step))

  def _get_obs(self):
    """Returns extractor features of the wrapped env's current image."""
    # return self.env.get_direct_obs()

    # img = self.env.get_image_obs()
    # filtered_img = hard_coded_conv(img)
    # return self.sess.run(self.spsm, {self.img_ph: [filtered_img]})[0]

    return self.extractor.get_feature(self.env.get_image_obs())

  def step(self, action):
    """Executes one high-level action.

    Returns:
      (feature observation, reward, done, None).
    """
    if action < len(self.text):
      text = self.text[action]
      for _ in range(self.low_level_step):
        low_level_action = self.get_low_level_action(text)
        self.env.step(low_level_action)
    else:
      self.env.step_perfect(action-len(self.text))
    rew = self._reward()
    self.curr_step += 1
    # NOTE(review): the counter is incremented before the limit check, so an
    # episode ends after max_episode_steps - 1 high-level steps, whereas
    # HighLevelWrapper.step checks before incrementing — confirm intent.
    done = self._is_done(rew) or self.curr_step+1 >= self.max_episode_steps
    return self._get_obs(), rew, done, None

  def _objective_gap(self):
    """Returns (#satisfied objectives - #objectives); 0 when all hold."""
    ans = [self.get_answer(question) for question in self.objective_program]
    satisfied_goal = np.logical_not(np.logical_xor(ans, self.objective_goal))
    return np.sum(satisfied_goal.astype(float)) - len(self.objective_program)

  def _reward(self):
    """Objective gap, replaced by -10 when below the module's `min_reward`."""
    reward = self._objective_gap()
    return -10. if reward < min_reward else reward

  def _complete(self):
    """Unclipped objective gap; used by `reset` to screen start states."""
    return self._objective_gap()

  def _is_done(self, reward):
    return reward == 0.

  def reset(self, max_reset=50):
    """Resets, re-sampling (up to `max_reset` times) while the start score
    `_complete()` is above the module's `max_starting`.

    Returns:
      Feature observation of the accepted start state.
    """
    _ = self.env.reset()
    r = self._complete()
    rep = 0
    while r > max_starting and rep < max_reset:
      _ = self.env.reset()
      r = self._complete()
      rep += 1
    self.curr_step = 0
    return self._get_obs()

  def rollout(self, act, exp_file_name):
    """Runs 10 episodes (up to 60 high-level steps each) with policy `act`.

    Frames are collected in a local list only; video saving is disabled, so
    `exp_file_name` is currently unused.
    """
    print('rolling out...')
    all_frames = []
    for _ in range(10):
      obs = self.reset(5)
      black_scene = self.render(mode='rgb_array')*0.0
      for _ in range(10):
        all_frames.append(black_scene)
      all_frames.append(self.render(mode='rgb_array'))
      for _ in range(60):
        action = act(np.array(obs)[None])
        # `act` may return a Python int, a NumPy integer scalar or a batch of
        # size one. Indexing a NumPy scalar raises, so only unwrap batches.
        if not isinstance(action, (int, np.integer)):
          action = action[0]
        if action < len(self.text):
          text = self.text[action]
          actual_text = self.decode_fn(text)
          for _ in range(self.low_level_step):
            low_level_action = self.get_low_level_action(text)
            self.env.step(low_level_action)
            frame = self.render(mode='rgb_array')
            # all_frames.append(add_text(frame, actual_text))
        else:
          self.env.step_perfect(action-len(self.text))
          all_frames.append(self.render(mode='rgb_array'))
        obs = self._get_obs()
        if self._is_done(self._reward()):
          final_frame = self.render(mode='rgb_array')
          # final_frame = add_text(final_frame, 'success', False)
          for _ in range(10):
            all_frames.append(final_frame)
          break
    # save_video(np.uint8(all_frames), exp_file_name, fps=10.0)
    print('finish roll out...')


class DiscreteHierarchicalSortEnv(DiscreteHierarchicalEnv):
  """Hierarchical ordering task: objects in a row of decreasing x.

  Consecutive objects (in `self.obj_name` order) form an ordered pair when the
  later object's x is more than 0.01 below the earlier one's and their y
  coordinates differ by less than 0.4.
  """

  def _sorted_pairs(self):
    """Yields, per consecutive object pair, whether that pair is ordered."""
    positions = np.array(
        [self.get_body_com(name) for name in self.obj_name])
    for before, after in zip(positions[:-1], positions[1:]):
      yield after[0] < before[0]-0.01 and abs(after[1]-before[1]) < 0.4

  def _reward(self):
    """Returns 0. when every consecutive pair is ordered, else -10."""
    if all(self._sorted_pairs()):
      return 0.
    return -10.

  def _complete(self):
    """Returns 10 when more than one pair is ordered, else -10."""
    ordered_count = sum(1. for ok in self._sorted_pairs() if ok)
    return 10 if ordered_count > 1 else -10

  def _is_done(self, reward):
    """The episode ends once the ordering reward reaches 0."""
    return reward == 0.


class DiscreteHierarchical2DSortEnv(DiscreteHierarchicalEnv):
  """Hierarchical 2D arrangement of objects 0-3 around object 4.

  `constraint_count` scores how many of objects 0-3 satisfy their positional
  constraints; the task is solved when all four do.
  """

  # Per constrained object: (object index, ((axis, other object, direction),
  # ...)). Direction 1 requires pos[obj][axis] > pos[other][axis] + margin;
  # direction -1 requires pos[obj][axis] < pos[other][axis] - margin.
  _SORT_CONSTRAINTS = (
      (0, ((1, 4, 1), (0, 3, 1), (0, 1, -1))),
      (1, ((0, 4, 1), (1, 2, 1), (1, 0, -1))),
      (2, ((1, 4, -1), (0, 3, 1), (0, 1, -1))),
      (3, ((0, 4, -1), (1, 2, 1), (1, 0, -1))),
  )

  def constraint_count(self):
    """Returns how many of objects 0-3 are placed correctly (float, 0.-4.).

    An object counts when all of its ordering constraints hold and its
    Euclidean distance to object 4 is below 0.5.
    """
    max_dist = 0.5
    margin = 0.01
    positions = np.array([self.get_body_com(name) for name in self.obj_name])
    center = positions[4]
    total = 0.
    for obj, checks in self._SORT_CONSTRAINTS:
      pos = positions[obj]
      ordered = all(
          (pos[axis] > positions[other][axis] + margin) if direction > 0
          else (pos[axis] < positions[other][axis] - margin)
          for axis, other, direction in checks)
      near_center = np.linalg.norm(center - pos) < max_dist
      total += float(ordered and near_center)
    return total

  def _reward(self):
    """Returns 0. once all four constraints hold, else -10."""
    if self.constraint_count() >= 4.0:
      return 0.
    return -10.

  def _complete(self):
    """Returns 10 when more than three constraints hold, else -10."""
    satisfied = self.constraint_count()
    return 10 if satisfied > 3.0 else -10

  def _is_done(self, reward):
    """The episode ends once the arrangement reward reaches 0."""
    return reward == 0.


class DiverseColorSort(DiscreteHierarchicalEnv):
  """High-level task: sort the objects left to right (by x) by colour.

  The colour ranks are red < green < blue < cyan < purple. Neighbouring
  objects with the same colour count as ordered.
  """

  def __init__(
      self, env, sess=None, self_attention_low_level=False,
      augmented_actions=False, low_level_obs_type='order_invariant',
      low_level_action_type='perfect', low_level_step=5
  ):
    """Build the colour-sorting high-level environment.

    Args:
      env: the underlying environment. Its `get_order_invariant_obs` is given
        to the feature extractor as `state_fn`.
      sess: TensorFlow session. A new `tf.Session()` is created when falsy.
      self_attention_low_level: forwarded to the initialiser reached via
        super().
      augmented_actions: if True, the action space also counts
        `self.perfect_action_set` in addition to the loaded instructions.
      low_level_obs_type: forwarded to the initialiser reached via super().
      low_level_action_type: forwarded to the initialiser reached via super().
      low_level_step: stored as `self.low_level_step`.
    """
    # Fall back to a fresh session *before* building the extractor. The
    # extractor and the parent environment then share the same session;
    # previously the extractor was constructed with sess=None.
    if not sess:
      sess = tf.Session()
    self.extractor = vfe.DiverseGTFeatureExtractor(
        sess, 10, state_fn=env.get_order_invariant_obs)
    self.observation_space = spaces.Box(
        low=-1.0, high=1.0, shape=[self.extractor.obs_shape], dtype=np.float32)

    # NOTE(review): super() is given DiscreteHierarchicalEnv. That skips
    # DiscreteHierarchicalEnv.__init__ itself and initialises the next class
    # in the MRO. This looks intentional (the spaces and instruction text are
    # set up here instead); confirm before changing.
    super(DiscreteHierarchicalEnv, self).__init__(
        env, sess,
        self_attention_low_level=self_attention_low_level,
        obs_type='direct',
        low_level_action_type=low_level_action_type,
        low_level_obs_type=low_level_obs_type,
        diverse=True
    )

    # NOTE(review): pickle.load can execute arbitrary code on a malicious
    # file, so diverse_instruction_path must point at trusted data.
    with tf.gfile.GFile(diverse_instruction_path, mode='r') as f:
      self.text = pickle.load(f)
    action_num = len(self.text)
    if augmented_actions:
      action_num += len(self.perfect_action_set)
    self.action_space = spaces.Discrete(action_num)
    self.curr_step = 0
    self.low_level_step = low_level_step

    print('DiscreteHierarchicalEnv number of actions {}'.format(len(self.text)))
    print('DiscreteHierarchicalEnv number of perfect actions {}'.format(
        len(self.perfect_action_set)))
    print('DiscreteHierarchicalEnv low level step {}'.format(
        self.low_level_step))

  def _get_obs(self):
    """Return the extractor's features for `env.get_image_obs()`.

    NOTE(review): the extractor was built with
    state_fn=env.get_order_invariant_obs. Whether it actually uses the image
    argument cannot be told from here.
    """
    return self.extractor.get_feature(self.env.get_image_obs())

  def reset(self, max_reset=10):
    """Reset to a scene that does not start out already solved.

    `_complete()` returns either 10 or -10. The environment is therefore
    re-reset, at most `max_reset` extra times, while the starting scene
    scores above `max_starting`.

    Args:
      max_reset: maximum number of additional resets.

    Returns:
      The observation of the final reset scene.
    """
    _ = self.env.reset(True)
    r = self._complete()
    rep = 0
    while r > max_starting and rep < max_reset:
      _ = self.env.reset(True)
      r = self._complete()
      rep += 1
    self.curr_step = 0
    return self._get_obs()

  def _adjacent_pairs_sorted(self):
    """Check the colour order of every pair of x-neighbouring objects.

    Returns:
      A list with one bool per neighbouring pair, taken after sorting the
      objects by x. Each bool is True when the left object's colour rank is
      not greater than the right one's.
    """
    color_order = {'red': 0, 'green': 1, 'blue': 2, 'cyan': 3, 'purple': 4}
    curr_state = np.array([self.get_body_com(name) for name in self.obj_name])
    left_to_right = sorted(range(len(curr_state)),
                           key=lambda idx: curr_state[idx][0])
    ranks = [color_order[self.scene_graph[i]['color']] for i in left_to_right]
    return [left <= right for left, right in zip(ranks[:-1], ranks[1:])]

  def _reward(self):
    """Return 0. when every neighbouring pair is colour-ordered, else -10."""
    return 0. if all(self._adjacent_pairs_sorted()) else -10.

  def _complete(self):
    """Return 10 when more than three neighbouring pairs are ordered.

    With five objects there are four pairs, so this requires all of them.
    Otherwise return -10.
    """
    return 10 if sum(self._adjacent_pairs_sorted()) > 3 else -10

  def _is_done(self, reward):
    """Report done exactly when `reward` equals 0."""
    return reward == 0.


class DiverseShapeSort(DiverseColorSort):
  """Variant of DiverseColorSort that sorts objects by shape instead of colour.

  The target left-to-right (by x) order is sphere, cube, cylinder. Neighbours
  with the same shape count as ordered.
  """

  def _shape_pairs_in_order(self):
    """Return one bool per x-adjacent pair: left shape rank <= right rank."""
    rank_of = {'sphere': 0, 'cube': 1, 'cylinder': 2}
    com = np.array([self.get_body_com(name) for name in self.obj_name])
    by_x = sorted(range(len(com)), key=lambda k: com[k][0])
    ranks = [rank_of[self.scene_graph[k]['shape']] for k in by_x]
    return [a <= b for a, b in zip(ranks, ranks[1:])]

  def _reward(self):
    """Return 0. if all x-adjacent pairs are shape-ordered, otherwise -10."""
    if all(self._shape_pairs_in_order()):
      return 0.
    return -10.

  def _complete(self):
    """Return 10 if more than three x-adjacent pairs are ordered, else -10."""
    in_order = sum(self._shape_pairs_in_order())
    if in_order > 3:
      return 10
    return -10

  def _is_done(self, reward):
    """Report done exactly when `reward` equals 0."""
    return reward == 0.


class DiverseColorShapeSort(DiverseColorSort):
  """Variant of DiverseColorSort that sorts by colour, then by shape.

  Left to right (by x), colour ranks must be non-decreasing
  (red < green < blue < cyan < purple). Among neighbours of equal colour,
  shape ranks must be non-decreasing (sphere < cube < cylinder).
  """

  def _color_shape_pairs_in_order(self):
    """Return one bool per x-adjacent pair for (colour, shape) ordering.

    Comparing (colour rank, shape rank) tuples gives exactly this rule:
    - a strictly smaller colour rank means ordered;
    - an equal colour rank means ordered iff the shape rank is not larger;
    - otherwise the pair is not ordered.
    """
    shape_rank = {'sphere': 0, 'cube': 1, 'cylinder': 2}
    color_rank = {'red': 0, 'green': 1, 'blue': 2, 'cyan': 3, 'purple': 4}
    com = np.array([self.get_body_com(name) for name in self.obj_name])
    by_x = sorted(range(len(com)), key=lambda k: com[k][0])
    shape_ranks = [shape_rank[self.scene_graph[k]['shape']] for k in by_x]
    color_ranks = [color_rank[self.scene_graph[k]['color']] for k in by_x]
    keys = list(zip(color_ranks, shape_ranks))
    return [left <= right for left, right in zip(keys, keys[1:])]

  def _reward(self):
    """Return 0. if every x-adjacent pair is ordered, otherwise -10."""
    if all(self._color_shape_pairs_in_order()):
      return 0.
    return -10.

  def _complete(self):
    """Return 10 if more than three x-adjacent pairs are ordered, else -10."""
    in_order = sum(self._color_shape_pairs_in_order())
    if in_order > 3:
      return 10
    return -10

  def _is_done(self, reward):
    """Report done exactly when `reward` equals 0."""
    return reward == 0.

#==============================================================================
