#!/usr/bin/env python
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Decode from trained T2T models.

This binary performs inference using the Estimator API.

Example usage to decode from dataset:

  t2t-decoder \
      --data_dir ~/data \
      --problems=algorithmic_identity_binary40 \
      --model=transformer \
      --hparams_set=transformer_base

Set FLAGS.decode_interactive or FLAGS.decode_from_file for alternative decode
sources.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports

from tensor2tensor.utils import decoding
from tensor2tensor.utils import trainer_utils
from tensor2tensor.utils import usr_dir

import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

# Flags owned by this binary. main() additionally reads flags that are not
# defined in this file (data_dir, problems, model, hparams_set, hparams,
# decode_hparams, worker_id, eval_use_test_set) -- presumably they are
# registered by one of the imported tensor2tensor modules; verify there.

# Expanded with os.path.expanduser in main() before being used for the run
# config.
flags.DEFINE_string("output_dir", "", "Training directory to load from.")

# Decode source/sink selection. main() checks decode_interactive first, then
# decode_from_file, and otherwise decodes from the problem dataset.
flags.DEFINE_string("decode_from_file", None, "Path to decode file")
flags.DEFINE_string("decode_to_file", None,
                    "Path prefix to inference output file")
flags.DEFINE_string("decode_targets_file", None,
                    "Path to decode ground truth")
flags.DEFINE_bool("decode_interactive", False,
                  "Interactive local inference mode.")

# Copied into the decode hparams as "shards" by main().
flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")

# Passed to usr_dir.import_usr_dir at the start of main().
flags.DEFINE_string("t2t_usr_dir", "",
                    "Path to a Python module that will be imported. The "
                    "__init__.py file should include the necessary imports. "
                    "The imported files should contain registrations, "
                    "e.g. @registry.register_model calls, that will then be "
                    "available to the t2t-decoder.")

# NOTE(review): "master" and "dedup" are not read anywhere in this file;
# presumably consumed by imported modules -- confirm before removing.
flags.DEFINE_string("master", "", "Address of TensorFlow master.")

# main() refuses to run unless this keeps its default value.
flags.DEFINE_string("schedule", "train_and_evaluate",
                    "Must be train_and_evaluate for decoding.")
flags.DEFINE_bool("dedup", False, "Deduplicate the output.")

def main(_):
  """Decodes from a trained T2T model as configured by command-line flags.

  Imports the user directory, logs the registry, validates flags, builds the
  hparams and estimator, and then decodes from exactly one source, chosen in
  this order of precedence:

    1. FLAGS.decode_interactive: interactive decoding.
    2. FLAGS.decode_from_file: decode the given file, passing along
       FLAGS.decode_to_file and FLAGS.decode_targets_file.
    3. Otherwise: decode from the datasets of the problems named in
       FLAGS.problems (split on "-"), using the "test" split when
       FLAGS.eval_use_test_set is set.

  Args:
    _: Unused positional argument supplied by the app runner.

  Raises:
    ValueError: If FLAGS.schedule is not "train_and_evaluate".
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  trainer_utils.log_registry()
  trainer_utils.validate_flags()

  # Validate with an explicit exception rather than `assert`: assertions are
  # stripped when Python runs with -O, which would silently skip this check,
  # and a bare assert gives the user no hint about what went wrong.
  if FLAGS.schedule != "train_and_evaluate":
    raise ValueError(
        "--schedule must be 'train_and_evaluate' for decoding, got %r." %
        FLAGS.schedule)

  # Allow "~" in user-supplied paths.
  data_dir = os.path.expanduser(FLAGS.data_dir)
  output_dir = os.path.expanduser(FLAGS.output_dir)

  hparams = trainer_utils.create_hparams(
      FLAGS.hparams_set, data_dir, passed_hparams=FLAGS.hparams)
  trainer_utils.add_problem_hparams(hparams, FLAGS.problems)

  # Only the estimator is needed for decoding; the second component returned
  # by create_experiment_components is discarded.
  estimator, _ = trainer_utils.create_experiment_components(
      data_dir=data_dir,
      model_name=FLAGS.model,
      hparams=hparams,
      run_config=trainer_utils.create_run_config(output_dir))

  # Extend the decode hparams with the sharding setup taken from flags.
  decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
  decode_hp.add_hparam("shards", FLAGS.decode_shards)
  decode_hp.add_hparam("shard_id", FLAGS.worker_id)

  if FLAGS.decode_interactive:
    decoding.decode_interactively(estimator, decode_hp)
  elif FLAGS.decode_from_file:
    decoding.decode_from_file(estimator, FLAGS.decode_from_file, decode_hp,
                              FLAGS.decode_to_file, FLAGS.decode_targets_file)
  else:
    decoding.decode_from_dataset(
        estimator,
        FLAGS.problems.split("-"),
        decode_hp,
        decode_to_file=FLAGS.decode_to_file,
        dataset_split="test" if FLAGS.eval_use_test_set else None)


# Script entry point: only runs when the file is executed directly, not when
# it is imported. No arguments are passed to tf.app.run(); it is expected to
# parse the command-line flags and then call main() above -- see the TF docs.
if __name__ == "__main__":
  tf.app.run()
