# coding=utf-8
# Copyright 2019 The Hal Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hindsight Instruction Relabeling."""

import tensorflow as tf
import numpy as np
import random
import os
import re
import gc
import time

from tqdm import tqdm
from hal.clevr_env import ClevrEnv
from hal.models.language_modules import encoder
from hal.language_processing_utils.sentences_mutator import mutate_sentence
from hal.language_processing_utils.sentences_mutator import negate_unary_sentence
import hal.language_processing_utils.word_vectorization as wv
from hal.models.film import film_pi_network
from hal.models.film import mlp_pi_network
from hal.models.film import combine_variable_input
from hal.low_level_policy.model import Model
from hal.low_level_policy.model import VariableInputModel
from hal.low_level_policy.model import ImageModel

from absl import app
from absl import flags

# Command-line flags for the experiment. Every flag is read through `FLAGS`
# after absl has parsed argv (i.e. from inside `main`).
FLAGS = flags.FLAGS
flags.DEFINE_string('save_dir', 'tmp', 'experiment home directory')
flags.DEFINE_bool('direct_obs', True, 'direct observation')
flags.DEFINE_bool('shape_reward', False, 'use shaped reward')
flags.DEFINE_bool('save_model', False, 'save model and log')
flags.DEFINE_bool('pseudo_boltzman', False, 'use boltzman exploration')
flags.DEFINE_bool('save_video', True, 'save video for evaluation')
flags.DEFINE_string('action_type', 'perfect', 'what type of action to use')
flags.DEFINE_integer('max_episode_length', 50, 'maximum episode duration')
flags.DEFINE_integer('num_epoch', 200, 'number of epoch')
flags.DEFINE_integer('num_cycle', 50, 'number of cycle per epoch')
flags.DEFINE_integer('num_episode', 50, 'number of episode per cycle')
flags.DEFINE_integer('optimisation_steps', 100, 'optimization per episode')
flags.DEFINE_integer('collect_cycle', 10, 'cycles for populating buffer')
flags.DEFINE_integer('K', 3, 'number of future to put into buffer')
flags.DEFINE_integer('buffer_size', int(1e6), 'size of replay buffer')
flags.DEFINE_integer('batchsize', 128, 'batchsize')
flags.DEFINE_float('tau', 0.95, 'moving average factor for target')
flags.DEFINE_float('gamma', 0.5, 'discount factor')
flags.DEFINE_float('initial_epsilon', 1.0, 'initial epsilon')
flags.DEFINE_float('min_epsilon', 0.1, 'minimum epsilon')
# The help text here used to be a copy of the `min_epsilon` description.
flags.DEFINE_float('learning_rate', 1e-4, 'learning rate')
flags.DEFINE_float('epsilon_decay', 0.95, 'decay factor for epsilon')
flags.DEFINE_integer('log_iter', 25, 'number of episode between each log')
flags.DEFINE_string('obs_type', 'direct', 'type of observation')
flags.DEFINE_bool('cpu_collect', False, 'use cpu for rolling out')
flags.DEFINE_bool('record_atomic', False, 'record atomic goals')
flags.DEFINE_integer('K_immediate', 2,
                     'number of immediate correct statements added')
flags.DEFINE_bool('masking_q', False, 'mask the end of the episode')
flags.DEFINE_bool('onehot_goal', False, 'use one hot goal')
flags.DEFINE_bool('state_goal', False, 'use state space as the goal')
flags.DEFINE_bool('use_pretrained_encoder', False, 'use pretrained encoder')
flags.DEFINE_bool('mutate', False, 'mutate_sentences')
flags.DEFINE_bool('HER', True, 'use hindsight experience replay')
flags.DEFINE_bool('use_subset', False,
                  'use a subset of 600 sentences, for generalization')
flags.DEFINE_bool('rollout_only', False, 'only roll out')
flags.DEFINE_bool('continual_learning', False, 'the environment never resets')
flags.DEFINE_string('image_action_parameterization', 'regular',
                    'type of action parameterization used by the image model')
flags.DEFINE_integer('frame_skip', 20, 'simulation step for the physics')
flags.DEFINE_bool('pretrained_embedding', False, 'use pretrained embedding')
flags.DEFINE_bool('use_polar', False,
                  'use polar coordinate for neighbor assignment')
flags.DEFINE_bool('suppress', False,
                  'suppress movements of unnecessary objects')
flags.DEFINE_bool('variable_scene_content', False,
                  'variable scene content')
flags.DEFINE_float('shape_val', 0.25, 'Value for reward shaping')
flags.DEFINE_bool('use_vqa', False, 'Use vqa as the auxiliary task')


def pad_to_max_length(data):
  """Right-pads every sequence in `data` with 0 up to the longest length.

  Args:
    data: iterable of sized sequences (e.g. lists of token ids).

  Returns:
    A numpy array built from the padded rows; each row is a list copy of the
    corresponding input sequence followed by as many zeros as needed.
  """
  pad_token = 0
  # First pass: find the longest sequence (-1 when `data` is empty).
  target_len = -1
  for seq in data:
    if len(seq) > target_len:
      target_len = len(seq)
  # Second pass: a sequence already at full length gets an empty pad.
  padded = [list(seq) + [pad_token] * (target_len - len(seq)) for seq in data]
  return np.array(padded)


# Experience replay buffer
class Buffer(object):
  """Bounded experience replay buffer backed by a plain Python list.

  New experiences are appended at the end; once the list grows past
  `buffer_size`, a small chunk of the oldest entries is dropped.
  """

  def __init__(self, buffer_size=50000):
    """Creates an empty buffer.

    Args:
      buffer_size: maximum number of experiences kept after an `add`.
    """
    self.buffer = []
    self.buffer_size = buffer_size

  def add(self, experience):
    """Appends `experience`, evicting the oldest entries when over capacity.

    Eviction removes 0.01% of `buffer_size` in one slice so the O(n) list copy
    is amortised over many adds. At least one entry is always evicted:
    without the lower bound the chunk size rounds to 0 for
    `buffer_size < 10000` and the buffer would grow without limit.
    """
    self.buffer.append(experience)
    if len(self.buffer) > self.buffer_size:
      n_drop = max(1, int(0.0001 * self.buffer_size))
      self.buffer = self.buffer[n_drop:]

  def sample(self, size):
    """Draws `size` experiences and returns them stacked in a numpy array.

    When the buffer holds fewer than `size` items, the population is
    replicated `size` times before sampling so the same experience can be
    drawn more than once.

    Args:
      size: number of experiences to draw.

    Returns:
      A freshly allocated numpy array built from the sampled experiences.

    Raises:
      ValueError: (from `random.sample`) if the buffer is empty and
        `size` > 0.
    """
    if len(self.buffer) >= size:
      population = self.buffer
    else:
      population = self.buffer * size
    # np.array already allocates a new array, so no extra copy is needed.
    return np.array(random.sample(population, size))


class Model1(object):
  """Language-conditioned Q-network graph (TF1 placeholders + Adam train op).

  All ops and variables are created inside `tf.variable_scope(name)`.
  The instance exposes the placeholders (`inputs`, `word_inputs`, `action`,
  `Q_next`) and the tensors/ops built on them (`Q_`, `predict`, `Q`, `loss`,
  `train_op`, `init_op`).
  """

  def __init__(self, input_dim, ac_dim, name, vocab_size, embedding_size,
      conv_layer_config, dense_layer_config, encoder_n_unit,
      seq_length, direct_obs, learning_rate=1e-3, reuse=False,
      variable_input=False, trainable_encoder=True):
    """Builds the graph.

    Args:
      input_dim: list with the per-example observation shape; it is prefixed
        with a `None` batch dimension for the input placeholder.
      ac_dim: indexable whose first element is the number of actions.
      name: variable scope name for everything created here.
      vocab_size: number of rows of the word-embedding matrix.
      embedding_size: number of columns of the word-embedding matrix.
      conv_layer_config: forwarded to `film_pi_network`; only used when
        `direct_obs` is False.
      dense_layer_config: forwarded to the policy/Q network builder.
      encoder_n_unit: forwarded to `encoder`.
      seq_length: not used by this constructor.
      direct_obs: if True use `mlp_pi_network`, otherwise `film_pi_network`.
      learning_rate: learning rate of the Adam optimizer.
      reuse: forwarded to `tf.variable_scope`.
      variable_input: if True, the observation is first passed through
        `combine_variable_input` together with the goal embedding.
      trainable_encoder: forwarded to `encoder` as `trainable`.
    """
    with tf.variable_scope(name, reuse=reuse):
      self.input_dim = input_dim
      print('input dimension: {}'.format(self.input_dim))
      self.ac_dim = ac_dim[0]
      # Word-embedding table; always trainable regardless of
      # `trainable_encoder`.
      embedding = tf.get_variable(
          name='word_embedding',
          shape=(vocab_size, embedding_size),
          dtype=tf.float32, trainable=True)
      self.inputs = tf.placeholder(
          shape=[None]+input_dim, dtype=tf.float32, name='input_ph')
      print('input placeholder: {}'.format(self.inputs))
      # Integer token ids; both dimensions are left unspecified.
      self.word_inputs = tf.placeholder(
          shape=[None, None], dtype=tf.int32, name='text_ph')
      # Only the second value returned by `encoder` is used, as the goal
      # embedding.
      _, goal_embedding = encoder(self.word_inputs, embedding, encoder_n_unit,
                                  trainable=trainable_encoder)
      if variable_input:
        self.inputs_ = combine_variable_input(
            self.inputs, goal_embedding, 64, 128)
      else:
        self.inputs_ = self.inputs

      if not direct_obs:
        self.Q_ = film_pi_network(
            self.inputs_, goal_embedding, self.ac_dim, conv_layer_config,
            dense_layer_config)
      else:
        self.Q_ = mlp_pi_network(self.inputs_, goal_embedding, self.ac_dim,
                                 dense_layer_config)
      # Greedy action: index of the largest entry along the last axis of Q_.
      self.predict = tf.argmax(self.Q_, axis=-1)
      self.action = tf.placeholder(shape=None, dtype=tf.int32)
      self.action_onehot = tf.one_hot(
          self.action, self.ac_dim, dtype=tf.float32)
      # Selects the Q_ entry of the fed action by masking with the one-hot
      # vector and summing over axis 1 (assumes Q_ is [batch, ac_dim] --
      # TODO confirm against the network builders).
      self.Q = tf.reduce_sum(tf.multiply(self.Q_, self.action_onehot), axis=1)
      # Regression target fed by the caller.
      self.Q_next = tf.placeholder(shape=None, dtype=tf.float32)
      self.loss = tf.losses.huber_loss(self.Q_next, self.Q,
                                       reduction=tf.losses.Reduction.MEAN)
      self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
      self.train_op = self.optimizer.minimize(self.loss)
      # NOTE(review): global_variables_initializer is not restricted to this
      # scope; it presumably covers every global variable that exists when
      # this line runs -- confirm before relying on per-model initialisation.
      self.init_op = tf.global_variables_initializer()


def updateTargetGraph(tfVars_, tau):
  """Builds soft-update assign ops between two halves of a variable list.

  The first half of `tfVars_` is treated as the source and the second half as
  the destination: destination variable k is assigned
  `source_k * (1 - tau) + tau * destination_k`. With an odd number of
  variables the last one is left out.

  Args:
    tfVars_: iterable of variables exposing `value()` and `assign()`.
    tau: weight kept on the destination's current value.

  Returns:
    List of assign ops, one per source/destination pair.
  """
  variables = list(tfVars_)
  half = len(variables) // 2
  sources = variables[:half]
  destinations = variables[half:half + half]
  update_ops = []
  for src, dst in zip(sources, destinations):
    blended = (src.value() * (1. - tau)) + (tau * dst.value())
    update_ops.append(dst.assign(blended))
  return update_ops


def updateTarget(op_holder, sess):
  """Executes each op in `op_holder` on `sess`, one `sess.run` call per op.

  Args:
    op_holder: iterable of ops, e.g. the list from `updateTargetGraph`.
    sess: session-like object exposing `run(op)`.
  """
  for update_op in op_holder:
    sess.run(update_op)


def encode_text(text, lookup_table):
  """Tokenises `text`, appends an 'eos' token and maps tokens to ids.

  Tokens are runs of word characters/apostrophes or one of `. , ! ? ;`.
  Each token is lower-cased before the lookup.

  Args:
    text: sentence to encode.
    lookup_table: mapping from lower-cased token to id.

  Returns:
    List of ids, ending with the id of 'eos'.

  Raises:
    KeyError: if a token is missing from `lookup_table`.
  """
  tokens = re.findall(r"[\w']+|[.,!?;]",  text + ' eos')
  return [lookup_table[token.lower()] for token in tokens]


def relabeling(
    obs, obs_next, action, env, achieved_goals,
    buff, encode_fn, episode_length, current_t, future_k,
    episode_experience, gamma):
  """Adds hindsight-relabeled language transitions for one step to `buff`.

  Two kinds of transitions are added for the step (obs, action, obs_next):

  1. Immediate: up to `FLAGS.K_immediate` instructions drawn from
     `achieved_goals` with reward `env.reward_scale`. For instructions shorter
     than 40 characters a negated version (if `negate_unary_sentence` returns
     one) is also added with reward 0.
  2. Future: up to `future_k` instructions achieved at randomly chosen later
     steps of the episode, with reward
     `gamma ** (future - current_t) * env.reward_scale`. Only steps with more
     than one achieved goal are considered, and only goals longer than 30
     characters are used. The search gives up after
     `2 * (episode_length - current_t)` draws.

  Args:
    obs: observation before the action.
    obs_next: observation after the action.
    action: action taken.
    env: environment; only `env.reward_scale` is read.
    achieved_goals: instructions achieved by this transition (may be empty).
    buff: replay buffer exposing `add`.
    encode_fn: maps an instruction to its encoded form.
    episode_length: number of steps in `episode_experience` to sample from.
    current_t: index of the current step.
    future_k: number of future goals to add.
    episode_experience: per-step 6-tuples whose last field is the list of
      goals achieved at that step. It is not modified.
    gamma: discount applied per step of distance to the future goal.
  """
  # Transitions mix arrays, scalars and token lists, so they are stored as
  # object arrays; recent NumPy refuses to infer that from ragged input.
  if achieved_goals:
    for _ in range(min(FLAGS.K_immediate, len(achieved_goals)+1)):
      achieved_instruction = random.choice(achieved_goals)
      if FLAGS.mutate and len(achieved_instruction) > 40:
        achieved_instruction = mutate_sentence(
            achieved_instruction, delete_color=FLAGS.variable_scene_content)
      sarsng = [obs, action, env.reward_scale,
                obs_next, encode_fn(achieved_instruction)]
      buff.add(np.array(sarsng, dtype=object))
      if len(achieved_instruction) < 40:
        negative_instruction = negate_unary_sentence(achieved_instruction)
        if negative_instruction:
          sarsng = [obs, action, 0, obs_next, encode_fn(negative_instruction)]
          buff.add(np.array(sarsng, dtype=object))

  goal_count, repeat = 0, 0
  while goal_count < future_k and repeat < (episode_length - current_t) * 2:
    future = np.random.randint(current_t, episode_length)
    _, _, _, _, _, g_n = episode_experience[future]
    if len(g_n) > 1:
      # Shuffle a copy: the goal list belongs to the caller's episode record
      # and must keep its original order.
      candidates = list(g_n)
      random.shuffle(candidates)
      for single_g in candidates:
        if len(single_g) > 30:
          discount = gamma ** (future-current_t)
          if FLAGS.mutate:
            single_g = mutate_sentence(
                single_g, delete_color=FLAGS.variable_scene_content)
          sarsng = [obs, action, discount*env.reward_scale,
                    obs_next, encode_fn(single_g)]
          buff.add(np.array(sarsng, dtype=object))
          goal_count += 1
          break
    repeat += 1


def relabeling_state(
    obs, obs_next, action, env, achieved_goals,
    buff, encode_fn, episode_length, current_t, future_k,
    episode_experience, gamma):
  """Adds hindsight-relabeled state-goal transitions for one step to `buff`.

  If `achieved_goals` is non-empty, one of them is added with reward 0 after
  conversion through `env.convert_order_invariant_to_direct`. Then goals
  recorded at randomly chosen steps in `[current_t, episode_length)` are added
  until at least `future_k` have been added; every goal of a chosen step is
  added, each with reward `-np.linalg.norm(next_state - goal)`. When
  `obs_next` has more than one dimension, both it and the goal are first
  converted with `env.convert_order_invariant_to_direct`.

  Args:
    obs: observation before the action.
    obs_next: observation after the action (array-like with `.shape`).
    action: action taken.
    env: environment providing `convert_order_invariant_to_direct`.
    achieved_goals: goals achieved by this transition (may be empty).
    buff: replay buffer exposing `add`.
    encode_fn: maps a goal to its encoded form.
    episode_length: exclusive upper bound of the steps to sample from.
    current_t: index of the current step.
    future_k: minimum number of future goals to add.
    episode_experience: per-step 6-tuples whose last field holds the goals of
      that step.
    gamma: unused; kept for signature parity with `relabeling`.
  """
  del gamma  # Unused
  # Transitions mix arrays and scalars, so they are stored as object arrays;
  # recent NumPy refuses to infer that from ragged input.
  if achieved_goals:
    achieved_instruction = random.choice(achieved_goals)
    sarsng = [obs, action, 0, obs_next,
              env.convert_order_invariant_to_direct(achieved_instruction)]
    buff.add(np.array(sarsng, dtype=object))

  # Without any goal in the sampling window the loop below could never make
  # progress and would spin forever, so bail out up front.
  if future_k > 0 and current_t < episode_length:
    window = episode_experience[current_t:episode_length]
    if not any(len(step[5]) > 0 for step in window):
      return

  goal_count = 0
  while goal_count < future_k:
    future = np.random.randint(current_t, episode_length)
    _, _, _, _, _, g_n = episode_experience[future]
    for single_g in g_n:
      if len(obs_next.shape) > 1:
        direct_next_state = env.convert_order_invariant_to_direct(obs_next)
        single_g = env.convert_order_invariant_to_direct(single_g)
      else:
        direct_next_state = obs_next
      # Reward is the negative distance between the reached state and the
      # relabeled goal.
      sarsng = [obs,
                action,
                -np.linalg.norm(direct_next_state-single_g),
                obs_next,
                encode_fn(single_g)]
      buff.add(np.array(sarsng, dtype=object))
      goal_count += 1


def main(_):
  xm.setup_work_unit()

  HER = FLAGS.HER
  shaped_reward = FLAGS.shape_reward
  action_type = FLAGS.action_type
  max_episode_len = FLAGS.max_episode_length
  num_epochs = FLAGS.num_epoch
  num_cycles = FLAGS.num_cycle
  num_episodes = FLAGS.num_episode
  optimisation_steps = FLAGS.optimisation_steps
  collect_cycle = FLAGS.collect_cycle
  K = FLAGS.K
  buffer_size = FLAGS.buffer_size
  tau = FLAGS.tau
  gamma = FLAGS.gamma
  epsilon = FLAGS.initial_epsilon
  min_epsilon = FLAGS.min_epsilon
  batch_size = FLAGS.batchsize
  learning_rate = FLAGS.learning_rate
  obs_type = FLAGS.obs_type
  assert obs_type in ['image', 'order_invariant', 'direct']
  assert action_type in ['perfect', 'discrete', 'continuous']
  direct_obs = True if obs_type != 'image' else False

  res = 64
  input_dim = [res, res, 3] if not direct_obs else [10]
  ac_dim = [800] if action_type == 'discrete' else [40]

  vocab_path = os.path.join(__file__, '..', 'assets/vocab.txt')
  if FLAGS.variable_scene_content:
    vocab_path =os.path.join(__file__, '..', 'assets/variable_input_vocab.txt')


  vocab_list = open(vocab_path).read().split()
  if vocab_list[0] != 'eos':
    vocab_list = ['eos'] + vocab_list

  # order invariant specifics ==============================================
  max_num_shape = 1
  max_num_size = 0
  max_num_texture = 0
  max_num_color = 5
  single_obj_feature_length = (2 + max_num_shape + max_num_size
                               + max_num_texture + max_num_color)
  descriptor_length = 64
  inner_product_length = 32

  if obs_type == 'order_invariant':
    input_dim = [None, single_obj_feature_length]
    ac_dim = 8 if action_type != 'discrete' else ac_dim

  # layer settings =========================================================
  conv_layer_config = [(48, 8, 2), (128, 5, 2), (64, 3, 1)]

  dense_layer_config = [256, 512, 1024]
  if obs_type == 'image':
    dense_layer_config = [512, 512]
  encoder_n_unit = 32
  vocab_size = len(vocab_list)
  embedding_size = 8
  MAX_LENGTH = 21

  if FLAGS.use_pretrained_encoder or FLAGS.pretrained_embedding:
    encoder_n_unit = 64

  train = True
  save_model = FLAGS.save_model


  # index 0 will always be eos token
  vocab_dict = {word: i for i, word in enumerate(vocab_list)}
  encode_fn = lambda text: encode_text(text, vocab_dict)
  vocab_list = wv.load_vocab_list(vocab_path)
  _, i2v = wv.create_look_up_table(vocab_list)
  decode_fn = wv.decode_with_lookup_table(i2v)

  exp_config_str = ['dirct_obs_{}'.format(FLAGS.direct_obs),
                    'shprew_{}'.format(FLAGS.shape_reward),
                    FLAGS.action_type,
                    'bffr_{}'.format(FLAGS.buffer_size),
                    'mxstp_{}'.format(FLAGS.max_episode_length),
                    'nepoch_{}'.format(FLAGS.num_epoch),
                    'ncycle_{}'.format(FLAGS.num_cycle),
                    'nepisode_{}'.format(FLAGS.num_episode),
                    'opstep_{}'.format(FLAGS.optimisation_steps),
                    'cllct_{}'.format(FLAGS.collect_cycle),
                    'k_{}'.format(FLAGS.K),
                    'batch_{}'.format(FLAGS.batchsize),
                    'tau_{}'.format(FLAGS.tau),
                    'gamma_{}'.format(FLAGS.gamma),
                    'init_eps_{}'.format(FLAGS.initial_epsilon),
                    'min_eps_{}'.format(FLAGS.min_epsilon),
                    'lr_{}'.format(FLAGS.learning_rate),
                    'eps_decay_{}'.format(FLAGS.epsilon_decay),
                    '{}'.format(FLAGS.obs_type),
                    'K_immediate_{}'.format(FLAGS.K_immediate),
                    'record_atomic_{}'.format(FLAGS.record_atomic),
                    'masking_q_{}'.format(FLAGS.masking_q),
                    'onehot_goal_{}'.format(FLAGS.onehot_goal),
                    'state_goal_{}'.format(FLAGS.state_goal),
                    'pretrained_encoder_{}'.format(FLAGS.use_pretrained_encoder),
                    'mutate_{}'.format(FLAGS.mutate),
                    'generalization_{}'.format(FLAGS.use_subset),
                    'continual_learning_{}'.format(FLAGS.continual_learning),
                    'boltzman_{}'.format(FLAGS.pseudo_boltzman),
                    'img_param_{}'.format(FLAGS.image_action_parameterization),
                    'fs_{}'.format(FLAGS.frame_skip),
                    'ptr_emb_{}'.format(FLAGS.pretrained_embedding),
                    'polar_{}'.format(FLAGS.use_polar),
                    'suppress_{}'.format(FLAGS.suppress),
                    'vqa_{}'.format(FLAGS.use_vqa)
                    ]

  if FLAGS.save_model:
    try:
      tf.gfile.MkDir(FLAGS.save_dir)
    except tf.gfile.Error as e:
      print(e)
    model_dir = os.path.join(FLAGS.save_dir, '_'.join(exp_config_str))
    try:
      tf.gfile.MkDir(model_dir)
    except tf.gfile.Error as e:
      print(e)

  # building environment ======================================================
  env = ClevrEnv(num_object=5, num_objective=100, random_start=True,
                 fixed_objective=True, maximum_episode_steps=max_episode_len,
                 description_num=0, rew_shape=shaped_reward,
                 direct_obs=direct_obs, action_type=action_type,
                 reward_scale=1.0, shape_val=FLAGS.shape_val, resolution=res,
                 obs_type=obs_type, use_subset_instruction=FLAGS.use_subset,
                 continual_learning=FLAGS.continual_learning,
                 use_state_goal=FLAGS.state_goal, frame_skip=FLAGS.frame_skip,
                 use_polar=FLAGS.use_polar,
                 suppress_other_movement=FLAGS.suppress,
                 variable_scene_content=FLAGS.variable_scene_content,
                 decode_fn=decode_fn)
  # env.encode_objective(encode_fn)

  # onehot
  duplicate = 10
  goal_enumerate = {}
  enumerate_goal = {}
  count = 0
  for i in range(env.all_question_num):
    bins = [count+j for j in range(duplicate)]
    goal_enumerate[env.all_objective[i][0]] = bins
    for j in bins:
      enumerate_goal[j] = env.all_objective[i][0]
    count += duplicate
  if FLAGS.onehot_goal:
    encode_fn = lambda text: random.choice(goal_enumerate[text])
    decode_fn = lambda bin: enumerate_goal[int(bin)]
  # state as goal
  if FLAGS.state_goal:
    encode_fn = lambda goal: goal
    decode_fn = lambda goal: goal

  # Building up models=========================================================
  if obs_type == 'direct':
    modelNetwork = Model(input_dim=input_dim, ac_dim=ac_dim,
                         vocab_size=vocab_size, embedding_size=embedding_size,
                         conv_layer_config=conv_layer_config,
                         dense_layer_config=dense_layer_config,
                         encoder_n_unit=encoder_n_unit, seq_length=MAX_LENGTH,
                         name="model", direct_obs = direct_obs,
                         learning_rate=learning_rate,
                         variable_input=obs_type=='order_invariant')
    targetNetwork = Model(input_dim=input_dim, ac_dim=ac_dim,
                          vocab_size=vocab_size, embedding_size=embedding_size,
                          conv_layer_config=conv_layer_config,
                          dense_layer_config=dense_layer_config,
                          encoder_n_unit=encoder_n_unit,
                          seq_length = MAX_LENGTH,
                          name="target", direct_obs = direct_obs,
                          variable_input=obs_type=='order_invariant')

    if FLAGS.cpu_collect:
      with tf.device("/cpu:0"):
        modelNetwork_cpu = Model(input_dim=input_dim, ac_dim=ac_dim,
                                 vocab_size=vocab_size,
                                 embedding_size=embedding_size,
                                 conv_layer_config=conv_layer_config,
                                 dense_layer_config=dense_layer_config,
                                 encoder_n_unit=encoder_n_unit,
                                 seq_length = MAX_LENGTH,
                                 name="model", direct_obs = direct_obs,
                                 learning_rate=learning_rate, reuse=True,
                                 variable_input=obs_type=='order_invariant')
        targetNetwork_cpu = Model(input_dim=input_dim, ac_dim=ac_dim,
                                  vocab_size=vocab_size,
                                  embedding_size=embedding_size,
                                  conv_layer_config=conv_layer_config,
                                  dense_layer_config=dense_layer_config,
                                  encoder_n_unit=encoder_n_unit,
                                  seq_length = MAX_LENGTH, reuse=True,
                                  name="target", direct_obs=direct_obs,
                                  variable_input=obs_type=='order_invariant')
    else:
      modelNetwork_cpu = modelNetwork
      targetNetwork_cpu = targetNetwork
  elif obs_type == 'order_invariant':
    print('using order invariant input...')
    modelNetwork = VariableInputModel(
        name='model',
        input_dim=input_dim,
        per_input_ac_dim=ac_dim,
        vocab_size=vocab_size,
        embedding_size=embedding_size,
        des_len=descriptor_length,
        inner_len=inner_product_length,
        encoder_n_unit=encoder_n_unit,
        learning_rate=learning_rate,
        action_type=action_type,
        onehot_goal=FLAGS.onehot_goal,
        state_goal=FLAGS.state_goal,
        pretrained_embedding=FLAGS.pretrained_embedding,
        max_input_length=env.all_question_num*duplicate,
        trainable_encoder=True,
    )
    targetNetwork = VariableInputModel(
        name='target',
        input_dim=input_dim,
        per_input_ac_dim=ac_dim,
        vocab_size=vocab_size,
        embedding_size=embedding_size,
        des_len=descriptor_length,
        inner_len=inner_product_length,
        encoder_n_unit=encoder_n_unit,
        learning_rate=learning_rate,
        action_type=action_type,
        onehot_goal=FLAGS.onehot_goal,
        state_goal=FLAGS.state_goal,
        pretrained_embedding=FLAGS.pretrained_embedding,
        max_input_length=env.all_question_num*duplicate,
        trainable_encoder=True
    )
    modelNetwork_cpu = modelNetwork
    targetNetwork_cpu = targetNetwork
  elif obs_type == 'image':
    print('Using image model')
    modelNetwork = ImageModel(
        name='model',
        input_dim=input_dim,
        ac_dim=ac_dim,
        vocab_size=vocab_size,
        embedding_size=embedding_size,
        conv_layer_config=conv_layer_config,
        dense_layer_config=dense_layer_config,
        encoder_n_unit=encoder_n_unit,
        action_type=FLAGS.action_type,
        action_parameterization=FLAGS.image_action_parameterization,
        onehot_goal=FLAGS.onehot_goal,
        pretrained_embedding=FLAGS.pretrained_embedding,
        max_input_length=env.all_question_num*duplicate,
        use_vqa=FLAGS.use_vqa,
    )
    targetNetwork = ImageModel(
        name='target',
        input_dim=input_dim,
        ac_dim=ac_dim,
        vocab_size=vocab_size,
        embedding_size=embedding_size,
        conv_layer_config=conv_layer_config,
        dense_layer_config=dense_layer_config,
        encoder_n_unit=encoder_n_unit,
        action_type=FLAGS.action_type,
        action_parameterization=FLAGS.image_action_parameterization,
        onehot_goal=FLAGS.onehot_goal,
        pretrained_embedding=FLAGS.pretrained_embedding,
        max_input_length=env.all_question_num*duplicate,
        use_vqa=FLAGS.use_vqa,
    )
    modelNetwork_cpu = modelNetwork
    targetNetwork_cpu = targetNetwork


  # End of building models=====================================================

  # summary utils
  summary_writer = tf.summary.FileWriter(model_dir)
  # update ops
  trainables = tf.trainable_variables()
  updateOps = updateTargetGraph(trainables, tau)

  buff = Buffer(buffer_size)
  vqa_buff = Buffer(buffer_size//10)

  ##############################################################################
  # Roll out functions
  ##############################################################################
  def rollout(sess, episode_num, episode_n=5, custom_dir=None, timeout=10):
    all_frames = []
    black_scene = env.render(mode='rgb_array')*0.0
    for i in range(episode_n):
      s = env.reset(True)
      for _ in range(20):
        all_frames.append(black_scene)
      g, p = env.sample_goal()
      if env.all_goals_satisfied:
        env.reset()
        g, p = env.sample_goal()
      if FLAGS.mutate:
        g = mutate_sentence(g, delete_color=FLAGS.variable_scene_content)
      g_text = g
      print('current goal {}'.format(g))
      g = np.squeeze(encode_fn(g))
      current_goal_repeat = 0
      for t in range(100):
        action, q_val = sess.run(
            [targetNetwork_cpu.predict, targetNetwork_cpu.Q_],
            feed_dict={targetNetwork_cpu.inputs: [s],
                       targetNetwork_cpu.word_inputs: [g],
                       targetNetwork_cpu.is_training: False})
        action = action[0]
        s_next, reward, done, _ = env.step(action, True, p)
        # all_frames.append(add_text(env.render(mode='rgb_array'), g_text))
        s = s_next
        current_goal_repeat += 1
        q_val.sort()
        print('episode {}  step {}  action {}'.format(i, t, action))
        # print('q values: {}'.format(q_val[0][::-1][:50]))
        if reward > env.shape_val or current_goal_repeat >= timeout:
          if reward > env.shape_val:
            # for _ in range(3):
            #   all_frames.append(
            #       add_text(env.render(mode='rgb_array'), g_text, color='green'))
            print('episode {} success'.format(i))
          else:
            # all_frames.append(
            #     add_text(env.render(mode='rgb_array'), 'timeout', False))
            print('timed out :( {}'.format(g_text))
          g, p = env.sample_goal()
          if env.all_goals_satisfied:
            break
          g_text = g
          g = np.squeeze(encode_fn(g))
          current_goal_repeat = 0
    # if custom_dir:
    #   save_video(
    #       np.uint8(all_frames),
    #       os.path.join(custom_dir, 'rollout2.mp4'.format(episode_num)), fps=10.)
    # else:
    #   save_video(
    #       np.uint8(all_frames),
    #       os.path.join(model_dir, '{}_video.mp4'.format(episode_num)))

  def rollout_with_questions(sess, questions, episode_n=10, timeout=10, video=False):
    print(len(questions))
    print('rolling out with questions')

    if video:
      all_frames = []
      black_scene = env.render(mode='rgb_array')*0.0

    achieved_goal = 0
    for i in range(episode_n):
      print('================ep {}==================='.format(i))
      if video:
        for _ in range(20):
          all_frames.append(black_scene)
      s = env.reset(True)
      # questions = env.valid_questions
      g, p = random.choice(questions)
      while env.get_answer(p):
        g, p = random.choice(questions)
      g_text = g
      g = np.squeeze(encode_fn(g))
      current_goal_repeat = 0
      for j in range(30):
        # action = sess.run(
        #     targetNetwork_cpu.predict,
        #     feed_dict={targetNetwork_cpu.inputs: [s],
        #                targetNetwork_cpu.word_inputs: [g],
        #                targetNetwork_cpu.is_training: False})
        # action = action[0]
        action_q = sess.run(
            modelNetwork_cpu.Q_,
            feed_dict={modelNetwork_cpu.inputs: [s],
                       modelNetwork_cpu.word_inputs: [g],
                       modelNetwork_cpu.is_training: False})
        action_q = np.squeeze(action_q)
        action = np.squeeze(action_q).argsort()[-1]

        s_next, reward, _, _ = env.step(action, True, p)

        # if video:
        #   all_frames.append(add_text(env.render(mode='rgb_array'), g_text))

        current_goal_repeat += 1
        s = s_next
        if reward > env.shape_val or current_goal_repeat >= timeout:
          if reward > env.shape_val:
            print(g_text)
            print('achieved')
            achieved_goal += 1
            # if video:
            #   for _ in range(3):
            #     all_frames.append(
            #         add_text(env.render(mode='rgb_array'), g_text, color='green'))
          # else:
            # if video:
            #   all_frames.append(
            #       add_text(env.render(mode='rgb_array'), 'timeout', False))
          g, p = random.choice(questions)
          while env.get_answer(p):
            g, p = random.choice(questions)
          g_text = g
          g = np.squeeze(encode_fn(g))
          current_goal_repeat = 0
    print(achieved_goal/float(episode_n)*10./3.)
    return achieved_goal/float(episode_n)*10./3.

  def rollout_training(sess, episode_num, epsilon, custom_dir=None, timeout=10):
    all_frames = []
    black_scene = env.render(mode='rgb_array')*0.0
    for n in range(5):
      s = env.reset(True)
      for _ in range(20):
        all_frames.append(black_scene)
      g, p = env.sample_goal()
      if env.all_goals_satisfied:
        env.reset()
        g, p = env.sample_goal()
      g_text = g
      print('current goal {}'.format(g_text))
      g = np.squeeze(encode_fn(g))
      current_goal_repeat = 0
      for t in range(100):
        action_q = sess.run(
            modelNetwork_cpu.Q_,
            feed_dict={modelNetwork_cpu.inputs: [s],
                       modelNetwork_cpu.word_inputs: [g],
                       modelNetwork_cpu.is_training: False})
        action_q = np.squeeze(action_q)
        action = np.squeeze(action_q).argsort()[-1]
        a_type = 'max q value'

        s_next, reward, done, _ = env.step(
            action, True, goal=p, atomic_goal=FLAGS.record_atomic)
        # all_frames.append(add_text(env.render(mode='rgb_array'), g_text))
        print('train episode {}  step {}  action {}  action category {}'.format(n, t, action, a_type))
        s = s_next
        current_goal_repeat += 1

        if reward > env.shape_val or current_goal_repeat >= timeout:
          if reward > env.shape_val:
            # for _ in range(3):
              # all_frames.append(
              #     add_text(env.render(mode='rgb_array'), g_text, color='green'))
            print('episode {} success'.format(i))
          else:
            # all_frames.append(
            #     add_text(env.render(mode='rgb_array'), 'timeout', False))
            print('timed out :( {}'.format(g_text))
          g, p = env.sample_goal()
          if env.all_goals_satisfied:
            break
          g_text = g
          g = np.squeeze(encode_fn(g))
          current_goal_repeat = 0
          print('new goal: {}'.format(g_text))

    # if custom_dir:
    #   save_video(
    #       np.uint8(all_frames),
    #       os.path.join(custom_dir, 'train_rollout2.mp4'), fps=10.)
    # else:
    #   save_video(
    #       np.uint8(all_frames),
    #       os.path.join(model_dir, 'train_{}_video.mp4'.format(episode_num)),
    #       fps=10.
    #   )

  ##############################################################################
  # Training
  ##############################################################################
  print('action dimension: {}'.format(ac_dim))

  if train:
    with tf.Session() as sess:
      # load from checkpoint if there is one
      saver = tf.train.Saver()
      global_step = tf.train.get_or_create_global_step(graph=None)
      increment_global_step_op = tf.assign(global_step, global_step+1)
      sess.run([modelNetwork.init_op,
                targetNetwork.init_op,
                tf.variables_initializer([global_step])])

      if FLAGS.use_pretrained_encoder or FLAGS.pretrained_embedding:
        print('resotring pretrained language encoder.')
        modelNetwork.load_pretrained_encoder(sess, encoder_decoder_path)
        targetNetwork.load_pretrained_encoder(sess, encoder_decoder_path)

      if FLAGS.save_model:
        try:
          print('trying to load checkpoints from {}'.format(model_dir))
          saver.restore(
              sess,
              os.path.join(model_dir, 'model.ckpt'),
          )
        except ValueError:
          print('No existing checkpoints')


      tq = tqdm(range(num_epochs), total=num_epochs)
      for i in tq:
        average_episode_rew = []
        average_episode_achieved_n = []
        for j in range(num_cycles):
          # rollout(sess, 0, custom_dir=None, episode_n=3, timeout=10)
          gc.collect()
          # decay epsilon
          if epsilon > min_epsilon and (j > collect_cycle or i > 0):
            epsilon *= FLAGS.epsilon_decay
          start_time = time.time()  # new place for timing
          for n in range(num_episodes):
            sess.run(increment_global_step_op)
            # Decide whether to reset the scene content.
            # TODO(ydjiang): probability
            sample_new_scene = random.uniform(0, 1) < 0.1
            s = env.reset(sample_new_scene)
            episode_experience = []
            episode_rew = []
            episode_achieved_n = 0
            g, p = env.sample_goal()
            if env.all_goals_satisfied:
              s = env.reset(True)
              g, p = env.sample_goal()
            g_text = g
            g = np.squeeze(encode_fn(g))

            if FLAGS.use_vqa and sample_new_scene:
              vq_questions, vq_programs = zip(*env.valid_questions)
              vq_encoded_question = [np.squeeze(encode_fn(vq)) for vq in vq_questions]
              vqa_answers = [env.get_answer(vp) for vp in vq_programs]
              for vq, va in zip(vq_encoded_question, vqa_answers):
                vq_ = decode_fn(vq)
                for _ in range(5):
                  mvq_ = mutate_sentence(vq_, delete_color=True)
                  mvq = encode_fn(mvq_)
                  vqa_buff.add((s, mvq, va))

            for t in range(max_episode_len):
              if np.random.rand(1) < epsilon:
                if FLAGS.pseudo_boltzman and i > 20:
                  action_q = sess.run(
                      modelNetwork_cpu.Q_,
                      feed_dict={modelNetwork_cpu.inputs: [s],
                                 modelNetwork_cpu.word_inputs: [g],
                                 modelNetwork_cpu.is_training: False})
                  top_idx = np.squeeze(action_q).argsort()[-100:]
                  selected_q = action_q[top_idx]
                  selected_q -= np.max(selected_q)
                  top_q = np.exp(selected_q)
                  top_q /= top_q.sum()
                  action = np.random.choice(top_idx, p=top_q)
                else:
                  action = env.sample_action()
              else:
                if obs_type != 'order_invariant' or action_type != 'perfect':
                  action_q = sess.run(
                      modelNetwork_cpu.Q_,
                      feed_dict={modelNetwork_cpu.inputs: [s],
                                 modelNetwork_cpu.word_inputs: [g],
                                 modelNetwork_cpu.is_training: False})
                  action_q = np.squeeze(action_q)
                  # if FLAGS.pseudo_boltzman and np.random.rand(1) < epsilon:
                  #   top_idx = np.squeeze(action_q).argsort()[-20:]
                  #   top_q = np.exp(action_q[top_idx])
                  #   top_q /= top_q.sum()
                  #   action = np.random.choice(top_idx, p=top_q)
                  # else:
                  #   action = np.squeeze(action_q).argsort()[-1]
                  if FLAGS.pseudo_boltzman:
                    top_idx = np.squeeze(action_q).argsort()[-10:]
                    selected_q = action_q[top_idx]/(epsilon-min_epsilon+0.01)
                    selected_q -= np.max(selected_q)
                    top_q = np.exp(selected_q)
                    top_q /= top_q.sum()
                    action = np.random.choice(top_idx, p=top_q)
                  else:
                    action = np.squeeze(action_q).argsort()[-1]
                else:
                  action = sess.run(
                      modelNetwork_cpu.predict,
                      feed_dict={modelNetwork_cpu.inputs: [s],
                                 modelNetwork_cpu.word_inputs: [g],
                                 modelNetwork_cpu.is_training: False})
                  action = np.squeeze(action)

              if not FLAGS.state_goal:
                s_next, reward, done, _ = env.step(
                    action, True, goal=p, atomic_goal=FLAGS.record_atomic)
                ag = env.achieved_last_step
              else:
                # using state as direct reward
                s_next, _, done, _ = env.step(
                    action, False, goal=None, atomic_goal=False)
                reward = env.get_answer(g)
                ag = [s_next]

              episode_experience.append((s, action, reward, s_next, g, ag))
              episode_rew.append(reward)
              s = s_next

              if FLAGS.state_goal:
                # use state goal, threshold set to 1.0
                if reward > -0.8:
                  episode_achieved_n += 1
                if reward > -0.5:
                  g, p = env.sample_goal()
                  g_text = g
                  g = np.squeeze(encode_fn(g))
              elif reward > env.shape_val:
                episode_achieved_n += 1
                g, p = env.sample_goal()
                if env.all_goals_satisfied:
                  break
                g_text = g
                g = np.squeeze(encode_fn(g))

              if done: break

            average_episode_rew.append(np.sum(episode_rew))
            if len(average_episode_rew) > 100:
              average_episode_rew = average_episode_rew[-100:]
            average_episode_achieved_n.append(episode_achieved_n)
            if len(average_episode_achieved_n) > 100:
              average_episode_achieved_n = average_episode_achieved_n[-100:]

            # processing trajectory
            episode_length = len(episode_experience)
            for t in range(episode_length):
              s, a, r, s_n, g, ag = episode_experience[t]
              # =================================================
              g = decode_fn(g)
              if FLAGS.mutate:
                g = mutate_sentence(g, delete_color=FLAGS.variable_scene_content)
              g = encode_fn(g)
              # =================================================
              buff.add(np.array((s, a, r, s_n, g)))
              if HER:
                if FLAGS.state_goal:
                  relabeling_state(
                      s, s_n, a, env, ag, buff, encode_fn, episode_length,
                      t, K, episode_experience, gamma)
                else:
                  relabeling(
                      s, s_n, a, env, ag, buff, encode_fn, episode_length,
                      t, K, episode_experience, gamma)

            mean_loss = []
            mean_vqa_loss = []
            # start_time = time.time()
            # training
            if j > collect_cycle or i > 0:
              for k in range(optimisation_steps):
                experience = buff.sample(batch_size)
                s, a, r, s_next, g = [
                    np.squeeze(elem, axis=1) for elem in np.split(experience, 5, 1)
                ]
                s = np.stack(s)
                s_next = np.stack(s_next)
                g = np.array([gg for gg in g])

                if not FLAGS.onehot_goal and not FLAGS.state_goal:
                  g = pad_to_max_length(g)

                # compute q values and relevant stuff
                Q2 = sess.run(
                    targetNetwork.Q_,
                    feed_dict={targetNetwork.inputs: s_next,
                               targetNetwork.word_inputs: g,
                               targetNetwork.is_training: True})

                if obs_type != 'order_invariant' or action_type != 'perfect':
                  Q1 = sess.run(
                      modelNetwork.Q_,
                      feed_dict={modelNetwork.inputs: s_next,
                                 modelNetwork.word_inputs: g,
                                 modelNetwork.is_training: True})
                  doubleQ = Q2[:, np.argmax(Q1, axis=-1)]
                else:
                  action = sess.run(
                      modelNetwork.predict,
                      feed_dict={modelNetwork.inputs: s_next,
                                 modelNetwork.word_inputs: g,
                                 modelNetwork.is_training: True})
                  indices = np.arange(action.shape[0])
                  doubleQ = Q2[indices, action[:, 0], action[:, 1]]
                  a = np.stack(a)

                if FLAGS.masking_q:
                  # masking at the end of the episode
                  doubleQ *= (1. - (r >= env.reward_scale).astype(np.float32))

                Q_target = np.clip(r + gamma * doubleQ, -1. / (1 - gamma), 10.0)

                _, loss = sess.run(
                    [modelNetwork.train_op, modelNetwork.loss],
                    feed_dict={
                        modelNetwork.inputs: s,
                        modelNetwork.Q_next: Q_target,
                        modelNetwork.action: a,
                        modelNetwork.word_inputs: g,
                        modelNetwork.is_training: True
                    })
                mean_loss.append(loss)

                if FLAGS.use_vqa and len(vqa_buff.buffer) > 30000 and i < 1:
                  vqa_data = vqa_buff.sample(batch_size)
                  obs, question, answer = [
                      np.squeeze(elem, axis=1) for elem in np.split(vqa_data, 3, 1)
                  ]
                  obs = np.stack(obs)
                  question = pad_to_max_length(question)
                  answer = np.stack(answer)
                  _, vqa_loss = sess.run(
                      [modelNetwork.train_vqa_op, modelNetwork.vqa_accuracy],
                      feed_dict={
                          modelNetwork.inputs: obs,
                          modelNetwork.answer_ph: np.int32(answer),
                          modelNetwork.word_inputs: question,
                          modelNetwork.is_training: True
                      })
                  mean_vqa_loss.append(vqa_loss)

            time_elapse = time.time()-start_time

            mean_100_eps_loss = np.mean(mean_loss) if mean_loss else 0
            mean_100_eps_achieved = np.mean(average_episode_achieved_n)
            mean_100_eps_reward = np.mean(average_episode_rew)

            if n % FLAGS.log_iter == 0:
              tq.set_description(
                  "cycle {}, epi {}, num achieve: {}, reward: {}, loss: {}, step/sec: {}".format(
                      j, n,
                      mean_100_eps_achieved,
                      mean_100_eps_reward,
                      mean_100_eps_loss,
                      n/(time_elapse+1e-6)
                  )
              )
              tq.refresh()
              if save_model:
                summary_list = [
                    tf.Summary.Value(tag='mean 100 eps loss',
                                     simple_value=mean_100_eps_loss),
                    tf.Summary.Value(tag='mean 100 achieved n',
                                     simple_value=mean_100_eps_achieved),
                    tf.Summary.Value(tag='mean 100 reward',
                                     simple_value=mean_100_eps_reward),
                    tf.Summary.Value(tag='exploration probability',
                                     simple_value=epsilon)
                ]
                if FLAGS.use_vqa:
                  vqa_value = np.mean(mean_vqa_loss) if mean_vqa_loss else 0
                  summary_list.append(
                      tf.Summary.Value(tag='vqa accuracy',
                                       simple_value=vqa_value))

                if FLAGS.use_subset:
                  achieved_on_training = rollout_with_questions(
                      sess, env.all_objective)
                  achieved_on_test = rollout_with_questions(
                      sess, env.held_out_objective)
                  summary_list.append(
                      tf.Summary.Value(tag='training goal achieved',
                                       simple_value=achieved_on_training))
                  summary_list.append(
                      tf.Summary.Value(tag='test goal achieved',
                                       simple_value=achieved_on_test))

                summary = tf.Summary(value=summary_list)
                summary_writer.add_summary(
                    summary, sess.run(global_step))

          # save videos
          if FLAGS.save_video and j % 20 == 0 and j > 0:
            print('~~~~~~~~~~making videos~~~~~~~~~~')
            # rollout(sess, sess.run(global_step), episode_n=10)
            rollout_training(sess, sess.run(global_step), epsilon)

          # update target network
          updateTarget(updateOps, sess)
          if save_model:
            saver.save(
                sess, os.path.join(model_dir, 'model.ckpt'),
                global_step=global_step
            )


# Script entry point: only when this file is executed directly (not imported),
# hand control to absl's app.run, which invokes main.
if __name__ == '__main__':
  app.run(main)
