# coding=utf-8
# Copyright 2019 The Hal Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run experiments on high level tasks.."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
import cPickle as pickle
import os
import numpy as np
import tensorflow as tf
import random

from gym import spaces
from hal.clevr_env import ClevrEnv
from baselines import logger
from baselines import deepq

import hal.high_level_policy.visual_feature_extractor as vfe
import hal.high_level_policy.high_level_env as envs
import hal.language_processing_utils.word_vectorization as wv

import warnings
warnings.filterwarnings("ignore")

FLAGS = flags.FLAGS

# ---------------------------------------------------------------------------
# DQN hyper-parameters. Most of these are forwarded to `deepq.learn` in
# `main`; all of them except `sample_episode_n` and `train_freq` are also
# baked into the experiment directory name.
# ---------------------------------------------------------------------------
flags.DEFINE_integer('batch_size', 32, 'batch size')
flags.DEFINE_integer('buffer_size', int(1e5), 'Size of replay buffer')
flags.DEFINE_integer('max_step', int(4e6), 'training step')
flags.DEFINE_float('exp_frac', 0.5,
                   'Fraction of the training time spent exploring')
flags.DEFINE_float('final_eps', 0.05, 'Final exploration epsilon')
flags.DEFINE_integer('train_freq', 1, 'How often the network is trained')
flags.DEFINE_integer('train_step_per_step', 1, 'train step per interaction')
flags.DEFINE_integer('learning_starts', 10000, 'when learning starts')
flags.DEFINE_integer('target_update_freq', 1000, 'how often target is updated')
# Passed to deepq.learn(gamma=...), i.e. the reward discount factor. The help
# text previously described it as an exponential-moving-average coefficient.
flags.DEFINE_float('gamma', 0.99, 'discount factor')
flags.DEFINE_float('prior_alpha', 0.7,
                   'how much good experiences are prioritized')
flags.DEFINE_bool('reward_shaping', False, 'whether to shape reward')
# The literal string 'None' (not the None object) disables checkpointing.
flags.DEFINE_string('save_dir', 'None', 'where the experiments are stored')
flags.DEFINE_integer('print_freq', 10, 'how often to print')
# Integer flag that `main` converts with bool() before use.
flags.DEFINE_integer('dueling', 0, 'use dueling network')
flags.DEFINE_bool('direct_obs', True,
                  'whether to use direct state observation')
flags.DEFINE_string('action_type', 'perfect', 'what type of action is allowed')
flags.DEFINE_integer('sample_episode_n', 20, 'number of episode to sample')
flags.DEFINE_bool('repeat_obj', False,
                  'if a goal can be satisfied more than once')
flags.DEFINE_bool('prioritized_replay', False,
                  'if prioritized replay buffer is used')
flags.DEFINE_bool('stationary', True, 'if the target is stationary')
flags.DEFINE_bool('soft', False, 'if soft q learning is used')

# ---------------------------------------------------------------------------
# Task / environment selection.
# ---------------------------------------------------------------------------
# When true, `main` overrides obs_type with 'direct' (non-hierarchical DQN).
flags.DEFINE_bool('normal_dqn', True, 'if normal dqn')
flags.DEFINE_bool('augmented_action', False, 'use_augmented_action_space')
flags.DEFINE_string('task_type', 'statement', 'what type of task')
flags.DEFINE_string('obs_type', 'order_invariant', 'type of observation')
flags.DEFINE_bool('self_attention_high_level', False,
                  'use self attention in the high level policy')
flags.DEFINE_bool('spatial_softmax', True,
                  'use spatial softmax in the high level policy')
flags.DEFINE_bool('top_down', False, 'view from top down')
flags.DEFINE_string(
    'low_level_obs_type', 'image', 'what observation low level policy uses.')
flags.DEFINE_string(
    'low_level_action_type', 'discrete', 'action used by low level policy')
# Only used to distinguish repeated runs in the experiment directory name.
flags.DEFINE_integer('cp', 0, 'copy of experiment')


max_starting = -8
min_reward = 0

# Resolve the vocabulary file relative to the directory containing this script.
# The previous form, os.path.join(__file__, '..', ...), produced a path that
# traverses *through the script file itself* ("run.py/../assets/vocab.txt"),
# which POSIX filesystems reject because a regular file is not a directory.
vocab_path = diverse_vocab_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'assets/vocab.txt')


# Loaded at import time: vocabulary list -> index-to-token table -> decoder.
vocab_list = wv.load_vocab_list(vocab_path)
_, i2v = wv.create_look_up_table(vocab_list)
decode_fn = wv.decode_with_lookup_table(i2v)

# NOTE: the (misspelled) name is kept as-is in case other modules import it.
divserse_instruction_path = None  # instruction sets for diverse environment


def _cnn_to_mlp(convs, hiddens, dueling, inpt, num_actions, scope, reuse=False,
    layer_norm=False, self_attention=False, spatial_softmax=False):
  """Builds the Q-value graph for image observations.

  The features fed to the MLP heads come from a fixed (non-trainable) 1x1
  colour-filter convolution applied directly to `inpt`, followed by either a
  spatial softmax or instance-norm + flatten. A FiLM-modulated conv stack is
  also built under the "conv" scope, but its output is overwritten by the
  colour-filter branch and does not reach the Q-values.

  Args:
    convs: unused; kept for signature compatibility with the baselines model
      builders.
    hiddens: iterable of hidden layer widths for the action (and state) head.
    dueling: if true, adds a state-value head and combines it with the
      mean-centred action scores.
    inpt: input tensor. The colour filter has 3 input channels, so this is
      presumably an NHWC RGB image batch -- TODO confirm against the caller.
    num_actions: number of discrete actions (width of the output layer).
    scope: variable scope name.
    reuse: whether to reuse variables in `scope`.
    layer_norm: if true, applies layer normalisation before each ReLU in the
      heads.
    self_attention: unused.
    spatial_softmax: if true, uses a spatial softmax over the colour-filter
      response; otherwise instance-normalises and flattens it.

  Returns:
    Tensor of Q-values with `num_actions` outputs per input.
  """
  # `layers` was referenced below without ever being imported in this file,
  # which raised NameError; bind it to the module the block already relies on
  # for spatial_softmax / instance_norm.
  layers = tf.contrib.layers

  with tf.variable_scope(scope, reuse=reuse):
    out = inpt
    conv_layer_config = [(48, 8, 2), (128, 5, 2), (64, 3, 1)]
    # NOTE(review): `load_film_params` is neither defined nor imported in this
    # file as seen here -- confirm where it is supposed to come from.
    layer_film_params = load_film_params()
    layer_out = []
    with tf.variable_scope("conv"):
      for cfg, param in zip(conv_layer_config, layer_film_params):
        out = tf.layers.conv2d(
            out, cfg[0], cfg[1], cfg[2], padding='SAME', trainable=False)
        out = tf.layers.batch_normalization(
            out, center=False, scale=False, training=False, trainable=False)
        layer_out.append(out)
        # FiLM-style modulation: each param splits into a per-channel scale
        # (gamma) and shift (beta), broadcast over the two spatial axes.
        gamma, beta = tf.split(param, 2, axis=1)
        out *= tf.expand_dims(tf.expand_dims(gamma, 1), 1)
        out += tf.expand_dims(tf.expand_dims(beta, 1), 1)
        out = tf.nn.relu(out)
    # ===============================hack====================================
    # Five fixed directions in colour space; transposed and expanded into a
    # (1, 1, 3, 5) kernel, i.e. a 1x1 convolution from 3 to 5 channels.
    color_filter = [[0.57735027, -0.57735027, -0.57735027],
                    [-0.57735027, -0.57735027, 0.57735027],
                    [-0.57735027, 0.57735027, -0.57735027],
                    [-0.98162999, 0.163605, 0.098163],
                    [0.06311944, -0.9467916, 0.3155972]]
    kernel = np.expand_dims(np.expand_dims(np.float32(color_filter).T, 0), 0)
    # Applied to the raw input: this discards the conv-stack result above.
    out = tf.nn.conv2d(inpt, tf.constant(kernel), 1, "SAME")
    out = tf.nn.relu(out)
    # TODO: this might be the problem
    # out = out / tf.reduce_max(out, axis=(1, 2), keepdims=True)
    # ===============================hack====================================
    if spatial_softmax:
      # conv_out = layer_out[0]/0.01
      # conv_out = tf.contrib.layers.spatial_softmax(conv_out, trainable=False)
      # Dividing by 0.01 sharpens the softmax (acts as a fixed temperature).
      conv_out = layers.spatial_softmax(out/0.01, trainable=False)
    else:
      out = layers.instance_norm(out)
      conv_out = layers.flatten(out)

    with tf.variable_scope("action_value"):
      action_out = conv_out
      for hidden in hiddens:
        action_out = layers.fully_connected(
            action_out, num_outputs=hidden, activation_fn=None)
        if layer_norm:
          action_out = layers.layer_norm(action_out, center=True, scale=True)
        action_out = tf.nn.relu(action_out)
      action_scores = layers.fully_connected(
          action_out, num_outputs=num_actions, activation_fn=None)

    if dueling:
      with tf.variable_scope("state_value"):
        state_out = conv_out
        for hidden in hiddens:
          state_out = layers.fully_connected(
              state_out, num_outputs=hidden, activation_fn=None)
          if layer_norm:
            state_out = layers.layer_norm(state_out, center=True, scale=True)
          state_out = tf.nn.relu(state_out)
        state_score = layers.fully_connected(
            state_out, num_outputs=1, activation_fn=None)
      # Dueling combination: Q = V + (A - mean(A)).
      action_scores_mean = tf.reduce_mean(action_scores, 1)
      action_scores_centered = (
          action_scores - tf.expand_dims(action_scores_mean, 1))
      q_out = state_score + action_scores_centered
    else:
      q_out = action_scores

    return q_out


def cnn_to_mlp(convs, hiddens, dueling=False, layer_norm=False, self_attention=False,
               spatial_softmax=True):
  """Returns a Q-function builder with the architecture options bound.

  Parameters
  ----------
  convs: [(int, int, int)]
      list of convolutional layers in form of
      (num_outputs, kernel_size, stride)
  hiddens: [int]
      list of sizes of hidden layers
  dueling: bool
      if true double the output MLP to compute a baseline
      for action scores
  layer_norm: bool
      forwarded to `_cnn_to_mlp` as `layer_norm`
  self_attention: bool
      forwarded to `_cnn_to_mlp` as `self_attention`
  spatial_softmax: bool
      forwarded to `_cnn_to_mlp` as `spatial_softmax`

  Returns
  -------
  q_func: function
      q_function for DQN algorithm. Any positional and keyword arguments it
      receives are passed on to `_cnn_to_mlp` after `convs`, `hiddens` and
      `dueling`.
  """

  def q_func(*args, **kwargs):
    # Positional arguments land after (convs, hiddens, dueling); the bound
    # options are supplied by keyword alongside whatever the caller adds.
    return _cnn_to_mlp(
        convs, hiddens, dueling,
        layer_norm=layer_norm,
        self_attention=self_attention,
        spatial_softmax=spatial_softmax,
        *args, **kwargs)

  return q_func


# Tasks for which the ClevrEnv is built with variable_scene_content=True.
_DIVERSE_TASKS = ('colorsort', 'shapesort', 'colorshapesort')

# task_type -> (message to print, attribute name in `envs`) for the
# non-hierarchical ("normal DQN") wrappers. Classes are looked up by name with
# getattr so that only the wrapper actually requested is resolved.
_FLAT_WRAPPERS = {
    'statement': ('task is: statement recovering', 'HighLevelWrapper'),
    'sort': ('task is: sort', 'HighLevelSortWrapper'),
    '2dsort': ('task is 2D sort', 'HighLevel2DSortWrapper'),
    'colorsort': ('task is: colorsort', 'HighLevelColorSortWrapper'),
    'shapesort': ('task is: shapesort', 'HighLevelShapeSortWrapper'),
    'colorshapesort': ('task is: colorshapesort',
                       'HighLevelColorShapeSortWrapper'),
}

# Same mapping for the hierarchical (order-invariant observation) environments.
_HIERARCHICAL_ENVS = {
    'statement': ('task is: statement recovering', 'DiscreteHierarchicalEnv'),
    'sort': ('task is: sort', 'DiscreteHierarchicalSortEnv'),
    '2dsort': ('task is: 2dsort', 'DiscreteHierarchical2DSortEnv'),
    'colorsort': ('task is: colorsort', 'DiverseColorSort'),
    'shapesort': ('task is: shapesort', 'DiverseShapeSort'),
    'colorshapesort': ('task is: colorshapesort', 'DiverseColorShapeSort'),
}


def _make_clevr_env(obs_type, variable_scene_content=False):
  """Builds the base ClevrEnv with the settings shared by every task.

  Args:
    obs_type: high level observation type; the env uses direct observations
      unless this is 'image'.
    variable_scene_content: if true, passes variable_scene_content=True to the
      ClevrEnv constructor.

  Returns:
    A new ClevrEnv instance.
  """
  extra_kwargs = {}
  if variable_scene_content:
    # Only passed when requested, so the plain call keeps relying on the
    # constructor's own default exactly as before.
    extra_kwargs['variable_scene_content'] = True
  return ClevrEnv(
      num_object=5, num_objective=10, random_start=True, description_num=30,
      action_type=FLAGS.low_level_action_type, maximum_episode_steps=100,
      rew_shape=False, direct_obs=obs_type != 'image', repeat_objective=False,
      fixed_objective=True, obs_type=FLAGS.low_level_obs_type, resolution=64,
      top_down_view=FLAGS.top_down, **extra_kwargs)


def _build_high_level_env(obs_type, sess):
  """Creates the environment the high level DQN is trained on.

  Args:
    obs_type: effective high level observation type ('direct' whenever
      FLAGS.normal_dqn is set, otherwise FLAGS.obs_type).
    sess: tf.Session handed to the wrappers that take one.

  Returns:
    The wrapped environment for the selected task, or a bare ClevrEnv (after
    printing a warning) when the flag combination matches no wrapper.
  """
  task = FLAGS.task_type
  low_level_obs = FLAGS.low_level_obs_type
  diverse = task in _DIVERSE_TASKS

  if FLAGS.normal_dqn and low_level_obs in ('direct', 'image'):
    entry = _FLAT_WRAPPERS.get(task)
    # The diverse-scene wrappers are only wired up for image observations.
    if entry is not None and (low_level_obs == 'image' or not diverse):
      message, wrapper_name = entry
      print(message)
      base_env = _make_clevr_env(obs_type, variable_scene_content=diverse)
      wrapper = getattr(envs, wrapper_name)
      if low_level_obs == 'image':
        return wrapper(base_env, obs_type='image', sess=sess)
      return wrapper(base_env, obs_type='direct')
  elif obs_type == 'order_invariant':
    entry = _HIERARCHICAL_ENVS.get(task)
    if entry is not None:
      message, env_name = entry
      print(message)
      base_env = _make_clevr_env(obs_type, variable_scene_content=diverse)
      return getattr(envs, env_name)(
          base_env, sess,
          self_attention_low_level=True,
          augmented_actions=FLAGS.augmented_action,
          low_level_obs_type=FLAGS.low_level_obs_type,
          low_level_action_type=FLAGS.low_level_action_type)
  elif obs_type == 'image':
    print(('getting image based {} env'
          ' hierarchical = {}').format(FLAGS.task_type, not FLAGS.normal_dqn))
    return envs.get_hl_env(
        FLAGS.task_type, FLAGS.normal_dqn, sess, low_level_duration=5,
        low_level_obs_type=FLAGS.low_level_obs_type,
        low_level_action_type=FLAGS.low_level_action_type,
        top_down_view=FLAGS.top_down)

  # Previously this case fell through silently; training still proceeds on
  # the unwrapped environment, but the user is now told about it.
  print('warning: no high level wrapper for task_type {} with this flag '
        'combination; using the raw ClevrEnv'.format(task))
  return _make_clevr_env(obs_type)


def main(_):
  """Builds the environment and Q-network from FLAGS and runs deepq.learn.

  Checkpoints are written to <save_dir>/<joined flag settings> unless
  FLAGS.save_dir is the literal string 'None', in which case no directory is
  created and checkpointing is disabled.
  """
  # Plain DQN always works on direct observations at the high level.
  obs_type = 'direct' if FLAGS.normal_dqn else FLAGS.obs_type

  # A single session is shared by all wrappers (a second, redundant session
  # used to be created for the hierarchical branch).
  sess = tf.Session()
  env = _build_high_level_env(obs_type, sess)

  logger.configure()

  if FLAGS.direct_obs:
    model = deepq.models.mlp(hiddens=(512, 512))
  else:
    model = cnn_to_mlp(
        convs=[(16, 8, 1), (32, 5, 1), (32, 3, 1)],
        hiddens=[256, 512],
        dueling=bool(FLAGS.dueling),
        self_attention=bool(FLAGS.self_attention_high_level),
        spatial_softmax=bool(FLAGS.spatial_softmax)
    )

  # (prefix, value) pairs; joined they form the experiment directory name, so
  # the prefixes must stay byte-identical to keep existing runs discoverable.
  exp_config = [
      ('btchsz_', FLAGS.batch_size),
      ('bffrsz_', FLAGS.buffer_size),
      ('mxstp_', FLAGS.max_step),
      ('xpfrc_', FLAGS.exp_frac),
      ('finleps_', FLAGS.final_eps),
      ('lrn_strts_', FLAGS.learning_starts),
      ('trgt_updt_', FLAGS.target_update_freq),
      ('gmma_', FLAGS.gamma),
      ('prior_', FLAGS.prior_alpha),
      ('rw_sh_', FLAGS.reward_shaping),
      ('dueling_', FLAGS.dueling),
      ('drt_obs_', FLAGS.direct_obs),
      ('ac_type_', FLAGS.action_type),
      ('rpt_obj_', FLAGS.repeat_obj),
      ('pr_rply_', FLAGS.prioritized_replay),
      ('stnry_trgt_', FLAGS.stationary),
      ('soft_', FLAGS.soft),
      ('not_hierarchical_', FLAGS.normal_dqn),
      ('augmented_action_', FLAGS.augmented_action),
      ('task_', FLAGS.task_type),
      ('attention_', FLAGS.self_attention_high_level),
      ('sp_softmax_', FLAGS.spatial_softmax),
      ('llot_', FLAGS.low_level_obs_type),
      ('llat_', FLAGS.low_level_action_type),
      ('train_per_step', FLAGS.train_step_per_step),
      ('top_down_', FLAGS.top_down),
      ('cp_', FLAGS.cp),
  ]
  exp_config_str = ['{}{}'.format(prefix, value) for prefix, value in exp_config]

  checkpoint_path = None
  if FLAGS.save_dir != 'None':
    checkpoint_path = os.path.join(FLAGS.save_dir, '_'.join(exp_config_str))
    try:
      # MakeDirs creates missing parents and succeeds if the directory already
      # exists, so re-running an experiment no longer trips the error path.
      tf.gfile.MakeDirs(checkpoint_path)
    except tf.errors.OpError as e:
      # Best effort, as before: report the problem and let training start.
      print(e)

  print('action space: {}'.format(env.action_space))
  print('start building...')

  deepq.learn(
      env,
      q_func=model,
      lr=1e-4,
      max_timesteps=FLAGS.max_step,
      buffer_size=FLAGS.buffer_size,
      exploration_fraction=FLAGS.exp_frac,
      exploration_final_eps=FLAGS.final_eps,
      train_freq=FLAGS.train_freq,
      train_step_per_step=FLAGS.train_step_per_step,
      learning_starts=FLAGS.learning_starts,
      target_network_update_freq=FLAGS.target_update_freq,
      gamma=FLAGS.gamma,
      prioritized_replay=FLAGS.prioritized_replay,
      prioritized_replay_alpha=FLAGS.prior_alpha,
      checkpoint_freq=2000,
      checkpoint_path=checkpoint_path,
      print_freq=FLAGS.print_freq,
      param_noise=False,
      batch_size=FLAGS.batch_size,
  )


# Script entry point: hand `main` to absl's app runner.
if __name__ == '__main__':
  app.run(main)
