# coding=utf-8
# Copyright 2019 The Hal Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of the language modules of the model"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def encoder(inputs, embeddings, n_hidden_unit, trainable=True,
    reuse=False, name='encoder', time_major=False,
    cell_collection=None):
  """One layer GRU unit encoder.

  Looks up `inputs` in `embeddings` and runs the resulting sequence through a
  single GRU cell with `tf.nn.dynamic_rnn`.

  Args:
    inputs: tensor of ids passed to `tf.nn.embedding_lookup`.
    embeddings: embedding parameters passed to `tf.nn.embedding_lookup`.
    n_hidden_unit: number of units in the GRU cell.
    trainable: whether the GRU cell's variables are trainable.
    reuse: forwarded to `tf.variable_scope` to reuse existing variables.
    name: name of the variable scope the encoder is built in.
    time_major: forwarded to `tf.nn.dynamic_rnn`; selects whether the
      embedded inputs are laid out time-major or batch-major.
    cell_collection: optional list. If given, the created GRU cell is
      appended to it so callers can access the cell later.

  Returns:
    The `(outputs, final_state)` tuple produced by `tf.nn.dynamic_rnn`.
  """
  with tf.variable_scope(name, reuse=reuse):
    input_embedding = tf.nn.embedding_lookup(embeddings, inputs)
    encoder_cell = tf.contrib.rnn.GRUCell(n_hidden_unit, trainable=trainable)
    # Previously this argument was accepted but ignored; register the cell
    # the same way `decoder` does.
    if cell_collection is not None:
      cell_collection.append(encoder_cell)
    encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
        encoder_cell, input_embedding, dtype=tf.float32,
        time_major=time_major)
    return encoder_outputs, encoder_final_state


def decoder(inputs, embeddings, encoder_final_states, n_hidden_unit,
    trainable=True, reuse=False, name='decoder', time_major=False,
    cell_collection=None):
  """One layer GRU unit decoder.

  Looks up `inputs` in `embeddings` and runs the resulting sequence through a
  single GRU cell with `tf.nn.dynamic_rnn`. The RNN is initialized from
  `encoder_final_states`.

  Args:
    inputs: tensor of ids passed to `tf.nn.embedding_lookup`.
    embeddings: embedding parameters passed to `tf.nn.embedding_lookup`.
    encoder_final_states: initial state for the decoder GRU. Presumably the
      final state of the encoder; its size must match `n_hidden_unit`.
    n_hidden_unit: number of units in the GRU cell.
    trainable: whether the GRU cell's variables are trainable.
    reuse: forwarded to `tf.variable_scope` to reuse existing variables.
    name: name of the variable scope the decoder is built in.
    time_major: forwarded to `tf.nn.dynamic_rnn`; selects whether the
      embedded inputs are laid out time-major or batch-major.
    cell_collection: optional list. If given, the created GRU cell is
      appended to it so callers can access the cell later.

  Returns:
    The `(outputs, final_state)` tuple produced by `tf.nn.dynamic_rnn`.
  """
  with tf.variable_scope(name, reuse=reuse):
    input_embedding = tf.nn.embedding_lookup(embeddings, inputs)
    decoder_cell = tf.contrib.rnn.GRUCell(n_hidden_unit, trainable=trainable)
    if cell_collection is not None:
      cell_collection.append(decoder_cell)
    decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
        decoder_cell, input_embedding, initial_state=encoder_final_states,
        dtype=tf.float32, time_major=time_major)
    return decoder_outputs, decoder_final_state


def projection(inputs, vocab_n, name='projection', reuse=False):
  """Applies a dense layer with `vocab_n` output units to `inputs`.

  `tf.layers.dense` is called without an activation, so the result is an
  affine (linear) projection. `vocab_n` is presumably the vocabulary size.

  Args:
    inputs: tensor to project.
    vocab_n: number of output units of the dense layer.
    name: name of the variable scope holding the layer's variables.
    reuse: forwarded to `tf.variable_scope` to reuse existing variables.

  Returns:
    The output tensor of the dense layer.
  """
  with tf.variable_scope(name, reuse=reuse):
    projected = tf.layers.dense(inputs, units=vocab_n)
  return projected
