# coding=utf-8
# Copyright 2019 The Hal Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model Networks for low level policy"""
from __future__ import print_function

import numpy as np
import tensorflow as tf

from hal.models.film import combine_variable_input
from hal.models.film import film_params
from hal.models.film import film_pi_network
from hal.models.film import mlp_pi_network
from hal.models.language_modules import encoder


class Model(object):
  """Q-network conditioned on a text instruction, for fixed-size observations.

  Every graph node is created inside the variable scope `name`. The
  constructor wires up a word-embedding table, the feed placeholders, the
  instruction encoder and a Q head (`film_pi_network` or `mlp_pi_network`),
  followed by a Huber-loss objective minimized with Adam.

  Graph handles exposed as attributes:
    inputs: float32 placeholder of shape [None] + input_dim.
    word_inputs: int32 placeholder of shape [None, None].
    is_training: scalar bool placeholder forwarded to `film_pi_network`.
    action: int32 placeholder holding the indices of the actions taken.
    Q_next: float32 placeholder holding the regression targets.
    Q_: network output; `predict` is its argmax over the last axis.
    Q: entries of `Q_` selected by `action`.
    loss, train_op, init_op: objective, update step and initializer.
  """

  def __init__(self, input_dim, ac_dim, name, vocab_size, embedding_size,
      conv_layer_config, dense_layer_config, encoder_n_unit,
      seq_length, direct_obs, learning_rate=1e-3, reuse=False,
      variable_input=False, trainable_encoder=True):
    """Builds the whole graph for this model.

    Args:
      input_dim: list with the per-example observation shape.
      ac_dim: sequence whose first element is the number of actions.
      name: variable scope name; also used by `variables` to filter.
      vocab_size: number of rows of the word-embedding table.
      embedding_size: number of columns of the word-embedding table.
      conv_layer_config: forwarded to `film_pi_network`.
      dense_layer_config: forwarded to whichever Q head is built.
      encoder_n_unit: forwarded to `encoder`.
      seq_length: accepted for interface compatibility; not used here.
      direct_obs: if true, build the MLP head instead of the FiLM head.
      learning_rate: Adam learning rate.
      reuse: `reuse` flag of the variable scope.
      variable_input: if true, observations first go through
        `combine_variable_input` together with the goal embedding.
      trainable_encoder: forwarded to `encoder` as `trainable`.
    """
    self.name = name
    with tf.variable_scope(name, reuse=reuse):
      self.input_dim = input_dim
      print('input dimension: {}'.format(self.input_dim))
      self.ac_dim = ac_dim[0]

      # Word-embedding table; always trainable in this model.
      embedding_shape = (vocab_size, embedding_size)
      self.embedding = tf.get_variable(
          name='word_embedding', shape=embedding_shape,
          dtype=tf.float32, trainable=True)

      # Feed placeholders.
      observation_shape = [None] + input_dim
      self.inputs = tf.placeholder(
          dtype=tf.float32, shape=observation_shape, name='input_ph')
      print('input placeholder: {}'.format(self.inputs))
      self.word_inputs = tf.placeholder(
          dtype=tf.int32, shape=[None, None], name='text_ph')
      self.is_training = tf.placeholder(
          dtype=tf.bool, shape=(), name='training_indicator_ph')

      # Only the second value returned by the encoder is used.
      encoder_outputs = encoder(
          self.word_inputs, self.embedding, encoder_n_unit,
          trainable=trainable_encoder)
      goal_embedding = encoder_outputs[1]

      if not variable_input:
        self.inputs_ = self.inputs
      else:
        self.inputs_ = combine_variable_input(
            self.inputs, goal_embedding, 64, 128)

      # Q head: MLP when observing the state directly, FiLM otherwise.
      if direct_obs:
        q_values = mlp_pi_network(
            self.inputs_, goal_embedding, self.ac_dim, dense_layer_config)
      else:
        q_values = film_pi_network(
            self.inputs_, goal_embedding, self.ac_dim, conv_layer_config,
            dense_layer_config, self.is_training)
      self.Q_ = q_values
      self.predict = tf.argmax(q_values, axis=-1)

      # Q-value of the action actually taken, selected with a one-hot mask.
      self.action = tf.placeholder(shape=None, dtype=tf.int32)
      self.action_onehot = tf.one_hot(
          self.action, self.ac_dim, dtype=tf.float32)
      masked_q = tf.multiply(self.Q_, self.action_onehot)
      self.Q = tf.reduce_sum(masked_q, axis=1)

      # Huber regression of the selected Q-value onto the fed target.
      self.Q_next = tf.placeholder(shape=None, dtype=tf.float32)
      self.loss = tf.losses.huber_loss(
          self.Q_next, self.Q, reduction=tf.losses.Reduction.MEAN)
      self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
      minimize_op = self.optimizer.minimize(self.loss)
      # Run whatever is registered in UPDATE_OPS together with the gradient
      # step.
      pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      self.train_op = tf.group([minimize_op, pending_updates])
      self.init_op = tf.global_variables_initializer()

  @property
  def variables(self):
    """Trainable variables under this model's scope, sorted by name."""
    scoped_vars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
    return sorted(scoped_vars, key=lambda variable: variable.name)


class VariableInputModel(object):
  """Goal-conditioned Q-network over a variable number of input vectors.

  The goal is either a token sequence passed through `encoder` (default), an
  integer id that is one-hot encoded (`onehot_goal`), or a 10-dimensional
  float vector (`state_goal`). Two action heads are available:

    * 'perfect': one Q-value per (input, action) pair; `predict` is an
      [input index, action index] pair and `action` is fed the same way.
    * 'discrete': one Q-value per action of a flat action set.

  All graph nodes are created inside the variable scope `name`.
  """

  _ACTION_TYPES = ('perfect', 'discrete')

  def __init__(self, input_dim, name, vocab_size, embedding_size,
      des_len, inner_len, encoder_n_unit, per_input_ac_dim,
      learning_rate=1e-4, reuse=False, trainable_encoder=True,
      action_type='perfect', onehot_goal=False, state_goal=False,
      pretrained_embedding=False, max_input_length=-1):
    """Builds the whole graph for this model.

    Args:
      input_dim: list with the per-example input shape ([?, di] according to
        the comments in this file).
      name: variable scope name; also used by `variables` to filter.
      vocab_size: rows of the word-embedding table (text goals only).
      embedding_size: columns of the word-embedding table (text goals only).
      des_len: channel count of the pairwise descriptor.
      inner_len: size of the goal/descriptor inner-product space.
      encoder_n_unit: size of the goal embedding.
      per_input_ac_dim: for 'perfect' an int (actions per input); for
        'discrete' a sequence whose first element is the action count.
      learning_rate: Adam learning rate.
      reuse: `reuse` flag of the variable scope.
      trainable_encoder: whether embedding and encoder are trainable; forced
        to False when `pretrained_embedding` is set.
      action_type: 'perfect' or 'discrete'.
      onehot_goal: feed the goal as an integer id and one-hot encode it.
      state_goal: feed the goal as a [None, 10] float vector.
      pretrained_embedding: freezes the embedding and encoder.
      max_input_length: depth of the one-hot goal (`onehot_goal` only).

    Raises:
      ValueError: if `action_type` is not one of the supported values.
    """
    # Fail before touching the graph; previously an unknown action type only
    # surfaced later as an AttributeError on `self.Q`.
    if action_type not in self._ACTION_TYPES:
      raise ValueError(
          'Unrecognized action_type {!r}; expected one of {}'.format(
              action_type, self._ACTION_TYPES))

    self.name = name
    with tf.variable_scope(name, reuse=reuse):
      if onehot_goal:
        self.word_inputs = tf.placeholder(
            shape=(None), dtype=tf.int32, name='goal_ph')
        print('goal input for one-hot max len {}'.format(max_input_length))
        one_hot_goal = tf.one_hot(self.word_inputs, max_input_length)
        # The placeholder has unknown rank, so pin the static shape that the
        # dense layers below need.
        one_hot_goal.set_shape([None, max_input_length])
        layer_cfg = [max_input_length//8, encoder_n_unit]
        goal_embedding = stack_dense_layer(one_hot_goal, layer_cfg)
      elif state_goal:
        self.word_inputs = tf.placeholder(
            shape=(None, 10), dtype=tf.float32, name='goal_ph')
        print('using state representation for goals')
        layer_cfg = [100, encoder_n_unit]
        goal_embedding = stack_dense_layer(self.word_inputs, layer_cfg)
      else:
        trainable_encoder = trainable_encoder and not pretrained_embedding
        print('the encoder is trainable = {}'.format(trainable_encoder))
        self.embedding = tf.get_variable(
            name='word_embedding',
            shape=(vocab_size, embedding_size),
            dtype=tf.float32, trainable=trainable_encoder)
        self.word_inputs = tf.placeholder(
            shape=[None, None], dtype=tf.int32, name='text_ph')
        _, goal_embedding = encoder(
            self.word_inputs, self.embedding, encoder_n_unit,
            trainable=trainable_encoder)

      # variable number of inputs ([B, ?, di])
      self.inputs = tf.placeholder(
          shape=[None]+input_dim, dtype=tf.float32, name='input_ph')
      print('input placeholder: {}'.format(self.inputs))
      self.is_training = tf.placeholder(
          shape=(), dtype=tf.bool, name='training_indicator_ph')

      if action_type == 'perfect':
        print('using perfect action Q function...')
        self.Q_, self.predict_input, self.predict_action = self.build_Q(
            des_len, inner_len, goal_embedding, per_input_ac_dim)
        self.predict = tf.stack(
            [self.predict_input, self.predict_action], axis=1)

        # Actions are fed as [input index, action index]; prepend the batch
        # index so gather_nd picks Q_[b, input, action].
        self.action = tf.placeholder(shape=(None, 2), dtype=tf.int32)
        stacked_indices = tf.concat(
            [tf.expand_dims(tf.range(0, tf.shape(self.action)[0]), axis=1),
             self.action],
            axis=1
        )
        self.Q = tf.gather_nd(self.Q_, stacked_indices)
      elif action_type == 'discrete':
        print('using discrete action Q function...')
        self.ac_dim = per_input_ac_dim[0]
        self.Q_ = self.build_Q_discrete(
            des_len, inner_len, goal_embedding, self.ac_dim)
        self.predict = tf.argmax(self.Q_, axis=-1)
        self.action = tf.placeholder(shape=None, dtype=tf.int32)
        self.action_onehot = tf.one_hot(
            self.action, self.ac_dim, dtype=tf.float32)
        self.Q = tf.reduce_sum(tf.multiply(self.Q_, self.action_onehot), axis=1)

      self.Q_next = tf.placeholder(shape=None, dtype=tf.float32)
      self.loss = tf.losses.huber_loss(
          self.Q_next, self.Q, reduction=tf.losses.Reduction.MEAN)
      self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
      self.train_op = self.optimizer.minimize(self.loss)
      self.init_op = tf.global_variables_initializer()

  @property
  def variables(self):
    """Global variables under this model's scope, sorted by name."""
    all_vars = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name)
    return sorted(all_vars, key=lambda x: x.name)

  def load_pretrained_encoder(self, sess, path):
    """Restores the word embedding and encoder variables from a checkpoint.

    Variables are looked up in the checkpoint under their names with this
    model's scope prefix stripped; optimizer (Adam) slots are skipped.

    Args:
      sess: session in which to run the restore.
      path: checkpoint path handed to `tf.train.Saver.restore`.

    Raises:
      ValueError: if the model was built with `onehot_goal` or `state_goal`
        and therefore has no word embedding to restore.
    """
    if not hasattr(self, 'embedding'):
      raise ValueError(
          'load_pretrained_encoder needs the text encoder, but this model '
          'was built with onehot_goal or state_goal and has no embedding.')
    print('=====restoring trained encoder weights from {}...====='.format(path))
    var_list = {'word_embedding': self.embedding}
    for var in self.variables:
      name = var.op.name
      if 'encoder' in name and 'Adam' not in name:
        var_list[name.split(self.name+'/')[1]] = var
    print(var_list)
    print(len(var_list))
    for var in tf.global_variables():
      if 'encoder' in var.name: print(var)
    saver = tf.train.Saver(var_list)
    saver.restore(sess, path)

  def _goal_augmented_inputs(self, des_len, inner_len, goal_embedding):
    """Appends a goal-attended pairwise summary and the goal to every input.

    Shared front end of `build_Q` and `build_Q_discrete`; layers are created
    in the same order as when the code was duplicated in both methods.

    Args:
      des_len: channel count of the pairwise descriptor.
      inner_len: size of the goal/descriptor inner-product space.
      goal_embedding: goal representation, [B, dg].

    Returns:
      Tensor of shape [B, ?, 1, di+des_len+dg].
    """
    num_object = tf.shape(self.inputs)[1]
    tp_concat = vector_tensor_product(self.inputs, self.inputs)
    conv_layer_cfg = [[des_len*8, 1, 1], [des_len*4, 1, 1], [des_len, 1, 1]]
    # [B, ?, ?, des_len]
    tp_concat = stack_conv_layer(tp_concat, conv_layer_cfg)

    # similarity with goal
    goal_key = stack_dense_layer(
        goal_embedding, [inner_len*2, inner_len])  # [B, d_inner]
    goal_key = tf.expand_dims(goal_key, 1)  # [B, 1, d_inner]
    # [B, ?, ?, d_inner]
    obs_query = tf.layers.conv2d(tp_concat, inner_len, 1, padding='same')
    # [B, ?*?, d_inner]
    obs_query = tf.reshape(obs_query, [-1, num_object**2, inner_len])
    obs_query_t = tf.transpose(obs_query, perm=(0, 2, 1))  # [B, d_inner, ?*?]
    inner = tf.matmul(goal_key, obs_query_t)  # [B, 1, ?*?]
    weight = tf.nn.softmax(inner, axis=-1)  # [B, 1, ?*?]
    # Attention-weighted sum of the pairwise descriptors.
    prod = tf.matmul(
        weight,
        tf.reshape(tp_concat, [-1, num_object**2, des_len]))  # [B, 1, des_len]

    goal_embedding_ = tf.expand_dims(goal_embedding, 1)  # [B, 1, dg]
    # [B, ?, dg]
    goal_embedding_ = tf.tile(goal_embedding_, multiples=[1, num_object, 1])
    # [B, ?, des_len]
    pair_wise_summary = tf.tile(prod, multiples=[1, num_object, 1])
    # [B, ?, des_len+di+dg]
    augmented_inputs = tf.concat(
        [self.inputs, pair_wise_summary, goal_embedding_], axis=-1)
    # [B, ?, 1, des_len+di+dg]
    return tf.expand_dims(augmented_inputs, axis=2)

  def build_Q(self, des_len, inner_len, goal_embedding, per_input_ac_dim):
    """Builds the per-input Q head used by the 'perfect' action type.

    Args:
      des_len: channel count of the pairwise descriptor.
      inner_len: size of the goal/descriptor inner-product space.
      goal_embedding: goal representation, [B, dg].
      per_input_ac_dim: int, number of actions available per input.

    Returns:
      Tuple (Q_out, input_select, action_select): Q-values of shape
      [B, ?, per_input_ac_dim], the argmax input index and the argmax action
      index, each of the latter two of shape [B].
    """
    augmented_inputs = self._goal_augmented_inputs(
        des_len, inner_len, goal_embedding)
    conv_layer_cfg = [
        [per_input_ac_dim*64, 1, 1],
        [per_input_ac_dim*64, 1, 1],
        [per_input_ac_dim, 1, 1]
    ]
    # [B, ?, per_input_ac_dim]
    Q_out = tf.squeeze(
        stack_conv_layer(augmented_inputs, conv_layer_cfg), axis=2)
    # Best input: the one whose best action has the highest value.
    input_max_q = tf.reduce_max(Q_out, axis=2)
    input_select = tf.argmax(input_max_q, axis=1)
    # Best action: the one whose best input has the highest value.
    action_max_q = tf.reduce_max(Q_out, axis=1)
    action_select = tf.argmax(action_max_q, axis=1)

    return Q_out, input_select, action_select

  def build_Q_discrete(
      self, des_len, inner_len, goal_embedding, ac_dim):
    """Builds the flat Q head used by the 'discrete' action type.

    Eight heads each pool the per-input features with a softmax over the
    inputs; their outputs are concatenated and mapped to `ac_dim` values.

    Args:
      des_len: channel count of the pairwise descriptor.
      inner_len: size of the goal/descriptor inner-product space.
      goal_embedding: goal representation, [B, dg].
      ac_dim: number of discrete actions.

    Returns:
      Q-values of shape [B, ac_dim].
    """
    augmented_inputs = self._goal_augmented_inputs(
        des_len, inner_len, goal_embedding)
    cfg = [[ac_dim//8, 1, 1], [ac_dim//8, 1, 1]]
    heads = []
    for _ in range(8):
      # [B, ?, 1, ac_dim//8]
      head_out = stack_conv_layer(augmented_inputs, cfg)
      weights = tf.layers.conv2d(head_out, 1, 1)  # [B, ?, 1, 1]
      softmax_weights = tf.nn.softmax(weights, axis=1)  # [B, ?, 1, 1]
      heads.append(
          tf.reduce_sum(softmax_weights*head_out, axis=(1,2))
      )
    # heads = 8 X [B, ac_dim//8]
    out = tf.concat(heads, axis=1)  # [B, 8*(ac_dim//8)]
    return tf.layers.dense(out, ac_dim)


class ImageModel(object):
  """Goal-conditioned Q-network over image observations with FiLM layers.

  The goal is either a token sequence passed through `encoder` (default), an
  integer id that is one-hot encoded (`onehot_goal`), or a 10-dimensional
  float vector (`state_goal`). The convolutional trunk is modulated by FiLM
  parameters computed from the goal embedding, and `action_parameterization`
  selects how the trunk output is mapped to Q-values. All graph nodes are
  created inside the variable scope `name`.
  """

  def __init__(self, input_dim, ac_dim, name, vocab_size, embedding_size,
      conv_layer_config, dense_layer_config, encoder_n_unit,
      action_type='discrete', action_parameterization='regular',
      learning_rate=1e-3, reuse=False, trainable_encoder=True,
      temperature=1.0, onehot_goal=False, state_goal=False,
      max_input_length=-1, pretrained_embedding=False,
      use_vqa=False):
    """Builds the whole graph for this model.

    Args:
      input_dim: list with the per-example observation shape.
      ac_dim: sequence whose first element is the number of actions.
      name: variable scope name; also used by `variables` to filter.
      vocab_size: rows of the word-embedding table (text goals only).
      embedding_size: columns of the word-embedding table (text goals only).
      conv_layer_config: list of [channels, kernel, stride]; a negative
        channel count marks a plain conv+ReLU layer without FiLM.
      dense_layer_config: hidden sizes of the dense layers of the Q head.
      encoder_n_unit: size of the goal embedding.
      action_type: one of `allowed_action_type`.
      action_parameterization: one of `allowed_action_parameterization`.
      learning_rate: Adam learning rate.
      reuse: `reuse` flag of the variable scope.
      trainable_encoder: whether embedding and encoder are trainable; forced
        to False when `pretrained_embedding` is set.
      temperature: divides the Q-values before sampling `predict` when
        `action_type` is 'discrete'.
      onehot_goal: feed the goal as an integer id and one-hot encode it.
      state_goal: feed the goal as a [None, 10] float vector.
      max_input_length: depth of the one-hot goal (`onehot_goal` only).
      pretrained_embedding: freezes the embedding and encoder.
      use_vqa: additionally build a 2-way classification head and its loss.
    """
    self.name = name
    self.use_vqa = use_vqa
    with tf.variable_scope(name, reuse=reuse):
      if onehot_goal:
        self.word_inputs = tf.placeholder(
            shape=(None), dtype=tf.int32, name='goal_ph')
        print('goal input for one-hot max len {}'.format(max_input_length))
        one_hot_goal = tf.one_hot(self.word_inputs, max_input_length)
        # The placeholder has unknown rank, so pin the static shape that the
        # dense layers below need.
        one_hot_goal.set_shape([None, max_input_length])
        layer_cfg = [max_input_length//8, encoder_n_unit]
        goal_embedding = stack_dense_layer(one_hot_goal, layer_cfg)
      elif state_goal:
        self.word_inputs = tf.placeholder(
            shape=(None, 10), dtype=tf.float32, name='goal_ph')
        print('using state representation for goals')
        layer_cfg = [100, encoder_n_unit]
        goal_embedding = stack_dense_layer(self.word_inputs, layer_cfg)
      else:
        trainable_encoder = trainable_encoder and not pretrained_embedding
        print('the encoder is trainable = {}'.format(trainable_encoder))
        self.embedding = tf.get_variable(
            name='word_embedding',
            shape=(vocab_size, embedding_size),
            dtype=tf.float32, trainable=trainable_encoder)
        self.word_inputs = tf.placeholder(
            shape=[None, None], dtype=tf.int32, name='text_ph')
        _, goal_embedding = encoder(
            self.word_inputs, self.embedding, encoder_n_unit,
            trainable=trainable_encoder)

      self.is_training = tf.placeholder(
          shape=(), dtype=tf.bool, name='training_indicator_ph')
      self.input_dim = input_dim
      print('input dimension: {}'.format(self.input_dim))
      self.ac_dim = ac_dim[0]
      self.inputs = tf.placeholder(
          shape=[None]+input_dim, dtype=tf.float32, name='input_ph')
      print('input placeholder: {}'.format(self.inputs))
      # NOTE: 'low_rank' and 'autoregressive' pass validation but have no
      # implementation in `_build_q`, which raises ValueError for them.
      self.allowed_action_parameterization = [
          'regular', 'spatial', 'low_rank', 'autoregressive', 'factor', 'spatial_softmax']
      self.allowed_action_type = ['perfect', 'discrete']

      self.temp = temperature
      self.action_type = action_type
      self.action_parameterization = action_parameterization
      self.Q_ = self._build_q(
          self.inputs, goal_embedding, conv_layer_config, dense_layer_config,
          self.ac_dim, action_type=self.action_type,
          action_parameterization=self.action_parameterization)
      print('q value: {}'.format(self.Q_))

      # setting up predicted action
      if self.action_type != 'discrete':
        self.predict = tf.argmax(self.Q_, axis=-1)
      else:
        # Boltzmann sampling. tf.multinomial expects unnormalized
        # log-probabilities, so the temperature-scaled Q-values are passed
        # directly; exponentiating them first would sample from
        # softmax(exp(Q/T)), which is close to uniform. Subtracting the max
        # leaves the distribution unchanged and keeps the values bounded.
        q_minus_max = self.Q_ - tf.reduce_max(self.Q_, axis=-1, keepdims=True)
        self.predict = tf.multinomial(q_minus_max/self.temp, 1)[:, 0]

      self.action = tf.placeholder(shape=None, dtype=tf.int32)
      self.action_onehot = tf.one_hot(
          self.action, self.ac_dim, dtype=tf.float32)
      self.Q = tf.reduce_sum(tf.multiply(self.Q_, self.action_onehot), axis=1)
      self.Q_next = tf.placeholder(shape=None, dtype=tf.float32)
      self.loss = tf.losses.huber_loss(self.Q_next, self.Q,
                                       reduction=tf.losses.Reduction.MEAN)
      self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
      train_op = self.optimizer.minimize(self.loss)
      # Run the ops registered in UPDATE_OPS (e.g. by batch normalization)
      # together with the gradient step.
      update_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      self.train_op = tf.group([train_op, update_op])

      if use_vqa:
        # Auxiliary 2-way classification on `self.vqa_out`, which
        # `_build_q` creates when `use_vqa` is set.
        self.answer_ph = tf.placeholder(shape=(None), dtype=tf.int32)
        onehot_answer = tf.one_hot(self.answer_ph, 2, dtype=tf.float32)
        vqa_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=onehot_answer, logits=self.vqa_out, axis=-1)
        self.vqa_loss = tf.reduce_mean(vqa_loss)
        self.train_vqa_op = self.optimizer.minimize(0.1 * self.vqa_loss)
        self.vqa_accuracy = tf.reduce_mean(
            tf.cast(
                tf.equal(
                    self.answer_ph,
                    tf.cast(tf.argmax(self.vqa_out, axis=-1), dtype=tf.int32)
                ),
                tf.float32
            )
        )

      self.init_op = tf.global_variables_initializer()

  def load_pretrained_encoder(self, sess, path):
    """Restores the word embedding and encoder variables from a checkpoint.

    Variables are looked up in the checkpoint under their names with this
    model's scope prefix stripped; optimizer (Adam) slots are skipped.

    Args:
      sess: session in which to run the restore.
      path: checkpoint path handed to `tf.train.Saver.restore`.

    Raises:
      ValueError: if the model was built with `onehot_goal` or `state_goal`
        and therefore has no word embedding to restore.
    """
    if not hasattr(self, 'embedding'):
      raise ValueError(
          'load_pretrained_encoder needs the text encoder, but this model '
          'was built with onehot_goal or state_goal and has no embedding.')
    print('=====restoring trained encoder weights from {}...====='.format(path))
    var_list = {'word_embedding': self.embedding}
    for var in self.variables:
      name = var.op.name
      if 'encoder' in name and 'Adam' not in name:
        print(name.split(self.name+'/')[1])
        var_list[name.split(self.name+'/')[1]] = var
    print('number of var', len(var_list))
    for var in tf.global_variables():
      if 'encoder' in var.name: print(var)
    saver = tf.train.Saver(var_list)
    saver.restore(sess, path)
    print('=====restored trained encoder======')

  def _build_q(self, obs, goal_embedding, conv_layer_config, dense_layer_config,
      ac_dim, spatial_dim1=10, spatial_dim2=10, action_type='direct',
      factors=None, action_parameterization='regular'):
    """Builds the FiLM-modulated conv trunk and the selected Q head.

    Args:
      obs: observation tensor fed to the first convolution.
      goal_embedding: goal representation, [B, dg].
      conv_layer_config: list of [channels, kernel, stride]; a negative
        channel count marks a plain conv+ReLU layer without FiLM.
      dense_layer_config: hidden sizes for the 'regular' and
        'spatial_softmax' heads.
      ac_dim: number of discrete actions.
      spatial_dim1: first spatial size of the 'spatial' head.
      spatial_dim2: second spatial size of the 'spatial' head.
      action_type: must be in `self.allowed_action_type`; note that the
        default 'direct' is not, so callers have to pass it explicitly.
      factors: three factor sizes for the 'factor'/'sequential' heads.
      action_parameterization: must be in
        `self.allowed_action_parameterization`.

    Returns:
      Q-values of shape [B, ac_dim] for 'regular', 'spatial_softmax' and
      'spatial'; [B, prod(factors)] for 'factor'.

    Raises:
      AssertionError: if `action_type` or `action_parameterization` is not
        in the allowed lists.
      ValueError: if no head is implemented for the requested combination.
    """
    assert action_parameterization in self.allowed_action_parameterization
    assert action_type in self.allowed_action_type

    # FiLM parameters are generated only for layers with a positive channel
    # count.
    n_layer_channel = []
    for layer_config in conv_layer_config:
      if layer_config[0] > 0:
        n_layer_channel.append(layer_config[0])

    layer_film_params = film_params(goal_embedding, n_layer_channel)
    # Assumes `film_params` returns one entry per element of
    # `n_layer_channel`. Parameters are consumed only by FiLM layers so they
    # stay aligned when plain layers are interleaved; zipping them against
    # every conv layer would misalign them and drop trailing layers.
    film_param_iter = iter(layer_film_params)
    out = obs
    # building convnet
    for cfg in conv_layer_config:
      if cfg[0] < 0:
        out = tf.layers.conv2d(out, -cfg[0], cfg[1], cfg[2], padding='SAME')
        out = tf.nn.relu(out)
      else:
        param = next(film_param_iter)
        out = tf.layers.conv2d(out, cfg[0], cfg[1], cfg[2], padding='SAME')
        # Normalize without learned affine terms; FiLM supplies them.
        out = tf.layers.batch_normalization(
            out, center=False, scale=False, training=self.is_training)
        gamma, beta = tf.split(param, 2, axis=1)
        out *= tf.expand_dims(tf.expand_dims(gamma, 1), 1)
        out += tf.expand_dims(tf.expand_dims(beta, 1), 1)
        out = tf.nn.relu(out)

    if self.use_vqa:
      with tf.variable_scope('vqa'):
        vqa_out = out
        vqa_out = tf.layers.conv2d(vqa_out, 256, 1, 1)
        vqa_out = tf.reduce_max(vqa_out, axis=(1,2))
        vqa_out = tf.nn.relu(tf.layers.dense(vqa_out, 512))
        vqa_out = tf.nn.relu(tf.layers.dense(vqa_out, 512))
        self.vqa_out = tf.layers.dense(vqa_out, 2)

    out_shape = out.get_shape()

    if action_parameterization == 'regular':
      out = tf.reshape(out, (-1, np.prod(out_shape[1:])))
      for cfg in dense_layer_config:
        out = tf.nn.relu(tf.layers.dense(out, cfg))
      return tf.layers.dense(out, ac_dim)

    if action_parameterization == 'factor' and action_type == 'discrete':
      if not factors: factors = [8, 10, 10]
      # [B, s1*s2, s3]
      out = tf.reshape(out, (-1, np.prod(out_shape[1:-1]), out_shape[-1]))
      projection_mat = tf.get_variable(
          name='projection_matrix',
          shape=(sum(factors), np.prod(out_shape[1:-1])),
          dtype=tf.float32, trainable=True)
      projection_mat = tf.expand_dims(projection_mat, axis=0)
      projection_mat = tf.tile(projection_mat, [tf.shape(out)[0], 1, 1])
      out = tf.matmul(projection_mat, out)  # [B, sum(fac), s3]
      fac1, fac2, fac3 = tf.split(out, factors, axis=1)
      out = tensor_concat(fac1, fac2, fac3)  # [B, f1, f2, f3, s3*3]
      # [B, prod(factors), s3*3]
      out = tf.reshape(out, [-1, np.prod(factors), out_shape[-1]*3])
      print('tensor concat: {}'.format(out))
      goal_tile = tf.expand_dims(
          tf.layers.dense(goal_embedding, out_shape[-1]), 1)  # [B, 1, s3]
      print('goal: {}'.format(goal_tile))
      goal_tile = tf.tile(
          goal_tile, multiples=[1, np.prod(factors), 1])
      out = tf.concat([out, goal_tile], axis=-1)
      out = tf.expand_dims(out, axis=1)
      out = tf.nn.relu(tf.layers.conv2d(out, 100, 1, 1))
      out = tf.nn.relu(tf.layers.conv2d(out, 32, 1, 1))
      out = tf.layers.conv2d(out, 1, 1, 1)
      return tf.squeeze(out, axis=[1, 3])

    if action_parameterization == 'spatial_softmax':
      out = tf.contrib.layers.spatial_softmax(out, trainable=True)
      for cfg in dense_layer_config:
        out = tf.nn.relu(tf.layers.dense(out, cfg))
      return tf.layers.dense(out, ac_dim)

    if action_parameterization == 'spatial':
      out = tf.reshape(out, (-1, out_shape[1]*out_shape[2], out_shape[3]))
      projection_mat = tf.get_variable(
          name='projection_matrix',
          shape=(spatial_dim1*spatial_dim2, out_shape[1]*out_shape[2]),
          dtype=tf.float32, trainable=True)
      projection_mat = tf.expand_dims(projection_mat, axis=0)
      projection_mat = tf.tile(projection_mat, [tf.shape(out)[0], 1, 1])
      out = tf.matmul(projection_mat, out)
      out = tf.nn.relu(out)
      out = tf.reshape(out, (-1, spatial_dim1, spatial_dim2, out_shape[3]))
      goal_proj = tf.layers.dense(goal_embedding, out_shape[3])
      goal_tile = tf.expand_dims(tf.expand_dims(goal_proj, 1), 1)
      goal_tile = tf.tile(
          goal_tile, multiples=[1, spatial_dim1, spatial_dim2, 1])
      out = tf.concat([out, goal_tile], axis=-1)
      out = tf.layers.conv2d(
          out, ac_dim//(spatial_dim1*spatial_dim2)*10, 1, 1, padding='SAME')
      out = tf.nn.relu(out)
      out = tf.layers.conv2d(
          out, ac_dim//(spatial_dim1*spatial_dim2), 1, 1, padding='SAME')
      # transpose to follow the same convention as the action set
      out = tf.transpose(out, [0, 3, 1, 2])
      return tf.reshape(out, [-1, ac_dim])

    # NOTE: 'sequential' is not in `allowed_action_parameterization`, so the
    # assertion above currently rejects it before this branch is reached. It
    # also returns three tensors rather than a single Q tensor.
    if action_parameterization == 'sequential':
      if not factors: factors = [10, 10, 8]
      # [B, s1*s2, s3]
      out = tf.reshape(out, (-1, np.prod(out_shape[1:-1]), out_shape[-1]))
      projection_mat = tf.get_variable(
          name='projection_matrix',
          shape=(np.prod(out_shape[1:-1])//2, np.prod(out_shape[1:-1])),
          dtype=tf.float32, trainable=True)
      projection_mat = tf.expand_dims(projection_mat, axis=0)
      projection_mat = tf.tile(projection_mat, [tf.shape(out)[0], 1, 1])
      out = tf.matmul(projection_mat, out)  # [B, s1*s2//2, s3]
      # Flatten using the product of the spatial dims; a shape slice cannot
      # be floor-divided directly.
      out = tf.reshape(
          out, [-1, np.prod(out_shape[1:-1])//2*out_shape[-1]])
      goal_proj = tf.layers.dense(goal_embedding, 16)
      out = tf.concat([out, goal_proj], axis=-1)
      # first axis
      q1 = stack_dense_layer(out, [100, 50, factors[0]])
      a1 = tf.one_hot(tf.argmax(q1, axis=-1), factors[0])
      out = tf.concat([out, a1], axis=-1)
      # second axis
      q2 = stack_dense_layer(out, [100, 50, factors[1]])
      a2 = tf.one_hot(tf.argmax(q2, axis=-1), factors[1])
      out = tf.concat([out, a2], axis=-1)
      # third axis
      q3 = stack_dense_layer(out, [100, 50, factors[2]])
      return q1, q2, q3

    raise ValueError('Unrecognized action parameterization')

  @property
  def variables(self):
    """Global variables under this model's scope, sorted by name."""
    all_vars = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name)
    return sorted(all_vars, key=lambda x: x.name)


#####################################################################
# Tensor operation utilities
#####################################################################
def vector_tensor_product(a, b):
  """Pairs every vector of `a` with every vector of `b` by concatenation.

  Args:
    a: tensor of shape [B, ?, d].
    b: tensor of shape [B, ?, d]; tiled by the length of `a`, so both are
      expected to share the same variable length.

  Returns:
    Tensor of shape [B, ?, ?, 2*d] whose entry [n, i, j] is the
    concatenation of a[n, i] and b[n, j].
  """
  length = tf.shape(a)[1]
  repeat_axis2 = [1, 1, length, 1]
  # [B, ?, 1, d] -> [B, ?, ?, d]; a varies along axis 1.
  a_grid = tf.tile(tf.expand_dims(a, axis=2), multiples=repeat_axis2)
  # Same tiling for b, then swap axes 1 and 2 so b varies along axis 2.
  b_grid = tf.tile(tf.expand_dims(b, axis=2), multiples=repeat_axis2)
  b_grid = tf.transpose(b_grid, perm=[0, 2, 1, 3])
  return tf.concat([a_grid, b_grid], axis=-1)


def stack_dense_layer(inputs, layer_cfg):
  """Applies a stack of dense layers; all but the last use ReLU.

  Args:
    inputs: input tensor.
    layer_cfg: non-empty sequence of layer widths; the last entry is the
      width of the linear output layer.

  Returns:
    Output of the final (activation-free) dense layer.
  """
  hidden_units, output_units = layer_cfg[:-1], layer_cfg[-1]
  net = inputs
  for units in hidden_units:
    net = tf.layers.dense(net, units, activation=tf.nn.relu)
  return tf.layers.dense(net, output_units)


def stack_conv_layer(inputs, layer_cfg, padding='same'):
  """Applies a stack of 2-D convolutions; all but the last use ReLU.

  Args:
    inputs: input tensor.
    layer_cfg: non-empty sequence of [filters, kernel_size, strides]
      entries; the last entry configures the linear output layer.
    padding: padding mode passed to every convolution.

  Returns:
    Output of the final (activation-free) convolution.
  """
  hidden_cfgs, output_cfg = layer_cfg[:-1], layer_cfg[-1]
  net = inputs
  for layer in hidden_cfgs:
    net = tf.layers.conv2d(
        net, filters=layer[0], kernel_size=layer[1], strides=layer[2],
        padding=padding, activation=tf.nn.relu)
  return tf.layers.conv2d(
      net, filters=output_cfg[0], kernel_size=output_cfg[1],
      strides=output_cfg[2], padding=padding)


def image_self_attention(images, inner_prod_dim, feature_dim):
  """Self-attention across the spatial positions of a feature map.

  Args:
    images: tensor of shape [B, d1, d2, c] with statically known d1 and d2.
    inner_prod_dim: channel count of the key and query projections.
    feature_dim: channel count of the feature projection and of the output.

  Returns:
    Tensor of shape [B, d1, d2, feature_dim].
  """
  static_shape = images.get_shape()
  height, width = static_shape[1], static_shape[2]
  n_positions = height*width

  # 1x1 convolutions, created in the order key, query, feature.
  key_map = tf.layers.conv2d(images, inner_prod_dim, 1, 1)  # [B, d1, d2, id]
  query_map = tf.layers.conv2d(images, inner_prod_dim, 1, 1)  # [B, d1, d2, id]
  feature_map = tf.layers.conv2d(images, feature_dim, 1, 1)  # [B, d1, d2, fd]

  # Collapse the two spatial axes into one: [B, d1*d2, channels].
  key_flat = tf.reshape(key_map, [-1, n_positions, inner_prod_dim])
  query_flat = tf.reshape(query_map, [-1, n_positions, inner_prod_dim])
  feature_flat = tf.reshape(feature_map, [-1, n_positions, feature_dim])

  # [B, d1*d2, d1*d2] similarity, softmax-normalized over the last axis.
  scores = tf.matmul(key_flat, tf.transpose(query_flat, perm=[0, 2, 1]))
  attention = tf.nn.softmax(scores)
  attended = tf.matmul(attention, feature_flat)
  return tf.reshape(attended, shape=[-1, height, width, feature_dim])


def tensor_concat(a, b, c):
  """Concatenates every triple of vectors drawn from `a`, `b` and `c`.

  Args:
    a: tensor of shape [B, da, de].
    b: tensor of shape [B, db, de].
    c: tensor of shape [B, dc, de].

  Returns:
    Tensor of shape [B, da, db, dc, 3*de] whose entry [n, i, j, k] is the
    concatenation of a[n, i], b[n, j] and c[n, k].
  """
  size_a = tf.shape(a)[1]
  size_b = tf.shape(b)[1]
  size_c = tf.shape(c)[1]

  a_col = tf.expand_dims(a, axis=2)  # [B, da, 1, de]
  b_col = tf.expand_dims(b, axis=2)  # [B, db, 1, de]
  c_col = tf.expand_dims(c, axis=2)  # [B, dc, 1, de]
  c_col = tf.expand_dims(c_col, axis=3)  # [B, dc, 1, 1, de]

  a_grid = tf.tile(a_col, multiples=[1, 1, size_b, 1])  # [B, da, db, de]
  b_grid = tf.tile(b_col, multiples=[1, 1, size_a, 1])  # [B, db, da, de]
  # [B, dc, da, db, de]
  c_grid = tf.tile(c_col, multiples=[1, 1, size_a, size_b, 1])

  # Pair a with b: [B, da, db, 2*de].
  b_grid = tf.transpose(b_grid, perm=[0, 2, 1, 3])  # [B, da, db, de]
  pair_grid = tf.concat([a_grid, b_grid], axis=-1)

  # Broadcast the pairs along a new dc axis: [B, da, db, dc, 2*de].
  pair_grid = tf.expand_dims(pair_grid, axis=3)
  pair_grid = tf.tile(pair_grid, multiples=[1, 1, 1, size_c, 1])

  # Move dc behind da and db, then append c: [B, da, db, dc, 3*de].
  c_grid = tf.transpose(c_grid, perm=[0, 2, 3, 1, 4])
  triple_grid = tf.concat([pair_grid, c_grid], axis=-1)
  return tf.identity(triple_grid)

# ============================================================================
# Other baselines
# ============================================================================

class OnehotInputModel(object):
  """Baseline Q-network whose goal is an integer id, one-hot encoded.

  The goal id is one-hot encoded to depth `max_input_length`, embedded with
  two dense layers, and combined with a variable number of input vectors to
  produce one Q-value per (input, action) pair. All graph nodes are created
  inside the variable scope `name`.
  """

  def __init__(self, input_dim, name, max_input_length,
      des_len, inner_len, encoder_n_unit, per_input_ac_dim,
      learning_rate=1e-4, reuse=False, trainable_encoder=True):
    """Builds the whole graph for this model.

    Args:
      input_dim: list with the per-example input shape ([?, di] according to
        the comments in this file).
      name: variable scope name; also used by `variables` to filter.
      max_input_length: depth of the one-hot goal encoding.
      des_len: channel count of the pairwise descriptor.
      inner_len: size of the goal/descriptor inner-product space.
      encoder_n_unit: size of the goal embedding.
      per_input_ac_dim: int, number of actions available per input.
      learning_rate: Adam learning rate.
      reuse: `reuse` flag of the variable scope.
      trainable_encoder: accepted for interface compatibility; not used.
    """
    self.name = name
    with tf.variable_scope(name, reuse=reuse):
      self.word_inputs = tf.placeholder(
          shape=(None), dtype=tf.int32, name='goal_ph')
      one_hot_goal = tf.one_hot(self.word_inputs, max_input_length)
      # `shape=(None)` leaves the placeholder rank unknown, so the one-hot
      # tensor has no static shape either. Pin it to [batch, depth] so the
      # dense layers know their input width, as the one-hot branches of the
      # other models in this file do.
      one_hot_goal.set_shape([None, max_input_length])
      layer_cfg = [max_input_length//8, encoder_n_unit]
      goal_embedding = stack_dense_layer(one_hot_goal, layer_cfg)

      # variable number of inputs ([B, ?, di])
      self.inputs = tf.placeholder(
          shape=[None]+input_dim, dtype=tf.float32, name='input_ph')
      print('input placeholder: {}'.format(self.inputs))
      self.Q_, self.predict_input, self.predict_action = self.build_Q(
          des_len, inner_len, goal_embedding, per_input_ac_dim)
      self.predict = tf.stack(
          [self.predict_input, self.predict_action], axis=1)

      # Actions are fed as [input index, action index]; prepend the batch
      # index so gather_nd picks Q_[b, input, action].
      self.action = tf.placeholder(shape=(None, 2), dtype=tf.int32)
      stacked_indices = tf.concat(
          [tf.expand_dims(tf.range(0, tf.shape(self.action)[0]), axis=1),
           self.action],
          axis=1
      )
      self.Q = tf.gather_nd(self.Q_, stacked_indices)
      self.Q_next = tf.placeholder(shape=None, dtype=tf.float32)
      self.loss = tf.losses.huber_loss(
          self.Q_next, self.Q, reduction=tf.losses.Reduction.MEAN)
      self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
      self.train_op = self.optimizer.minimize(self.loss)
      self.init_op = tf.global_variables_initializer()

  @property
  def variables(self):
    """Trainable variables under this model's scope, sorted by name."""
    all_vars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
    return sorted(all_vars, key=lambda x: x.name)

  def build_Q(self, des_len, inner_len, goal_embedding, per_input_ac_dim):
    """Builds the per-input Q head.

    Args:
      des_len: channel count of the pairwise descriptor.
      inner_len: size of the goal/descriptor inner-product space.
      goal_embedding: goal representation, [B, dg].
      per_input_ac_dim: int, number of actions available per input.

    Returns:
      Tuple (Q_out, input_select, action_select): Q-values of shape
      [B, ?, per_input_ac_dim], the argmax input index and the argmax action
      index, each of the latter two of shape [B].
    """
    num_object = tf.shape(self.inputs)[1]
    tp_concat = vector_tensor_product(self.inputs, self.inputs)
    conv_layer_cfg = [[des_len*8, 1, 1], [des_len*4, 1, 1], [des_len, 1, 1]]
    # [B, ?, ?, des_len]
    tp_concat = stack_conv_layer(tp_concat, conv_layer_cfg)

    # similarity with goal
    goal_key = stack_dense_layer(
        goal_embedding, [inner_len*2, inner_len])  # [B, d_inner]
    goal_key = tf.expand_dims(goal_key, 1)  # [B, 1, d_inner]
    # [B, ?, ?, d_inner]
    obs_query = tf.layers.conv2d(tp_concat, inner_len, 1, padding='same')
    # [B, ?*?, d_inner]
    obs_query = tf.reshape(obs_query, [-1, num_object**2, inner_len])
    obs_query_t = tf.transpose(obs_query, perm=(0, 2, 1))  # [B, d_inner, ?*?]
    inner = tf.matmul(goal_key, obs_query_t)  # [B, 1, ?*?]
    weight = tf.nn.softmax(inner, axis=-1)  # [B, 1, ?*?]
    # Attention-weighted sum of the pairwise descriptors.
    prod = tf.matmul(
        weight,
        tf.reshape(tp_concat, [-1, num_object**2, des_len]))  # [B, 1, des_len]

    goal_embedding_ = tf.expand_dims(goal_embedding, 1)  # [B, 1, dg]
    # [B, ?, dg]
    goal_embedding_ = tf.tile(goal_embedding_, multiples=[1, num_object, 1])
    # [B, ?, des_len]
    pair_wise_summary = tf.tile(prod, multiples=[1, num_object, 1])
    # [B, ?, des_len+di+dg]
    augmented_inputs = tf.concat(
        [self.inputs, pair_wise_summary, goal_embedding_], axis=-1)
    # [B, ?, 1, des_len+di+dg]
    augmented_inputs = tf.expand_dims(augmented_inputs, axis=2)
    conv_layer_cfg = [
        [per_input_ac_dim*64, 1, 1],
        [per_input_ac_dim*64, 1, 1],
        [per_input_ac_dim, 1, 1]
    ]
    # [B, ?, per_input_ac_dim]
    Q_out = tf.squeeze(
        stack_conv_layer(augmented_inputs, conv_layer_cfg), axis=2)
    # Best input: the one whose best action has the highest value.
    input_max_q = tf.reduce_max(Q_out, axis=2)
    input_select = tf.argmax(input_max_q, axis=1)
    # Best action: the one whose best input has the highest value.
    action_max_q = tf.reduce_max(Q_out, axis=1)
    action_select = tf.argmax(action_max_q, axis=1)

    return Q_out, input_select, action_select


class StateGoalModel(OnehotInputModel):
  """Baseline Q-network whose goal is a 10-dimensional float vector.

  Reuses `build_Q` from `OnehotInputModel` but builds its own graph: the
  parent constructor is deliberately not called, and the goal placeholder is
  a [None, 10] float tensor embedded with two dense layers.
  """

  def __init__(self, input_dim, name, max_input_length,
      des_len, inner_len, encoder_n_unit, per_input_ac_dim,
      learning_rate=1e-4, reuse=False, trainable_encoder=True):
    """Builds the whole graph for this model.

    Args:
      input_dim: list with the per-example input shape.
      name: variable scope name.
      max_input_length: accepted for interface compatibility; not used.
      des_len: forwarded to `build_Q`.
      inner_len: forwarded to `build_Q`.
      encoder_n_unit: size of the goal embedding.
      per_input_ac_dim: forwarded to `build_Q`.
      learning_rate: Adam learning rate.
      reuse: `reuse` flag of the variable scope.
      trainable_encoder: accepted for interface compatibility; not used.
    """
    self.name = name
    with tf.variable_scope(name, reuse=reuse):
      goal_ph = tf.placeholder(
          dtype=tf.float32, shape=(None, 10), name='goal_ph')
      self.word_inputs = goal_ph
      goal_embedding = stack_dense_layer(goal_ph, [100, encoder_n_unit])

      # variable number of inputs ([B, ?, di])
      self.inputs = tf.placeholder(
          dtype=tf.float32, shape=[None]+input_dim, name='input_ph')
      print('input placeholder: {}'.format(self.inputs))

      q_values, best_input, best_action = self.build_Q(
          des_len, inner_len, goal_embedding, per_input_ac_dim)
      self.Q_ = q_values
      self.predict_input = best_input
      self.predict_action = best_action
      self.predict = tf.stack([best_input, best_action], axis=1)

      # Actions are fed as [input index, action index]; prepend the batch
      # index so gather_nd picks Q_[b, input, action].
      self.action = tf.placeholder(shape=(None, 2), dtype=tf.int32)
      batch_size = tf.shape(self.action)[0]
      batch_ids = tf.expand_dims(tf.range(0, batch_size), axis=1)
      gather_ids = tf.concat([batch_ids, self.action], axis=1)
      self.Q = tf.gather_nd(self.Q_, gather_ids)

      self.Q_next = tf.placeholder(shape=None, dtype=tf.float32)
      self.loss = tf.losses.huber_loss(
          self.Q_next, self.Q, reduction=tf.losses.Reduction.MEAN)
      self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
      self.train_op = self.optimizer.minimize(self.loss)
      self.init_op = tf.global_variables_initializer()
