# coding=utf-8
# Copyright 2019 The Hal Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for converting words to integers and vice versa."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import numpy as np


def pad_to_max_length(data, max_l=None, eos_token=0):
  """Right-pads every sequence in `data` to a common length.

  Args:
    data: iterable of sequences; each must support `len()` and `list()`.
      It is iterated twice when `max_l` is falsy.
    max_l: target length. When falsy (`None` or `0`), the length of the
      longest sequence in `data` is used instead.
    eos_token: filler value appended to sequences shorter than `max_l`.

  Returns:
    `np.array` built from the padded rows. Sequences longer than `max_l` are
    NOT truncated, so in that case the rows are ragged (which recent NumPy
    versions reject when building the array).
  """
  if not max_l:
    # Seeding with -1 means empty `data` produces an empty array.
    max_l = max([-1] + [len(seq) for seq in data])
  padded_rows = []
  for seq in data:
    if len(seq) == max_l:
      padded_rows.append(list(seq))
      continue
    # A negative repeat count yields [], so over-long rows pass through as-is.
    padded_rows.append(list(seq) + [eos_token] * (max_l - len(seq)))
  return np.array(padded_rows)


def load_vocab_list(vocab_path, add_eos=True):
  """Reads a whitespace-separated vocabulary file into a list of tokens.

  Args:
    vocab_path: path to a text file whose tokens are separated by any
      whitespace (spaces, tabs, newlines).
    add_eos: if True, make sure 'eos' sits at index 0 by prepending it
      unless the file's first token is already 'eos'.

  Returns:
    List of token strings in file order, possibly with 'eos' prepended.
    An empty file yields `['eos']` (or `[]` when `add_eos` is False).
  """
  # Close the handle deterministically instead of leaving it to the GC.
  with open(vocab_path) as vocab_file:
    vocab_list = vocab_file.read().split()
  # Check `add_eos` and emptiness first so an empty file cannot trigger an
  # IndexError on `vocab_list[0]`.
  if add_eos and (not vocab_list or vocab_list[0] != 'eos'):
    vocab_list = ['eos'] + vocab_list
  return vocab_list


def create_look_up_table(vocab_list):
  """Builds forward and reverse lookups for a vocabulary.

  Args:
    vocab_list: ordered sequence of tokens; a token's position is its id.

  Returns:
    `(vocab2int, int2vocab)`. `vocab2int` maps each token to its index
    (for duplicated tokens the last index wins). `int2vocab` is `vocab_list`
    itself (not a copy), so indexing it by id returns the token.
  """
  word_to_id = {}
  for idx, word in enumerate(vocab_list):
    word_to_id[word] = idx
  return word_to_id, vocab_list


def encode_text(text, lookup_table):
  """Tokenizes `text` and maps each lowercased token to its id.

  Tokens are runs of word characters/apostrophes, or one of the punctuation
  marks `. , ! ? ;`; any other characters are dropped. An 'eos' token is
  appended before tokenizing.

  Args:
    text: input string.
    lookup_table: mapping from lowercase token to id.

  Returns:
    List of ids, one per token, ending with the id of 'eos'.

  Raises:
    KeyError: if a lowercased token is absent from `lookup_table`.
  """
  tokens = re.findall(r"[\w']+|[.,!?;]", text + ' eos')
  return [lookup_table[token.lower()] for token in tokens]


def decode_int(int_array, lookup_table, delimeter=' '):
  """Maps ids back to tokens and joins them into one string.

  Args:
    int_array: iterable of ids.
    lookup_table: indexable id -> token table (e.g. the `int2vocab` list
      returned by `create_look_up_table`); entries must be strings.
    delimeter: separator placed between tokens (name kept for callers).

  Returns:
    The tokens joined by `delimeter`; an empty input yields ''.
  """
  tokens = [lookup_table[idx] for idx in int_array]
  return delimeter.join(tokens)


def encode_text_with_lookup_table(look_up_table):
  """Returns a one-argument encoder bound to `look_up_table`.

  The returned callable takes a string and forwards it, together with
  `look_up_table`, to `encode_text`.
  """
  def _encode(text):
    return encode_text(text, look_up_table)
  return _encode


def decode_with_lookup_table(look_up_table):
  """Returns a one-argument decoder bound to `look_up_table`.

  The returned callable takes an iterable of ids and forwards it, together
  with `look_up_table`, to `decode_int` using its default delimiter.
  """
  def _decode(int_array):
    return decode_int(int_array, look_up_table)
  return _decode
