from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import math

# Shared absl flag registry; values are read at call time inside setup()/apply().
FLAGS = flags.FLAGS

# Handle passed to hub.load() in apply(). Defaults to None, which setup()
# rejects with a UsageError, so the flag is effectively mandatory.
flags.DEFINE_string("aug_tfhub_module", None, "TF Hub module handle")

from .reader import generic
from .reader.matrix import apply_fn_matrices

def setup():
    """Validate that the flag this module depends on has been provided.

    Raises:
        app.UsageError: if ``--aug_tfhub_module`` is unset or empty.
    """
    module_handle = FLAGS.aug_tfhub_module
    if module_handle:
        return
    raise app.UsageError("--aug_tfhub_module has to be specified!")

def apply(features, dim, samples, labels):
    """Run the configured TF Hub module over images via ``apply_fn_matrices``.

    Validates the flags with ``setup()``, loads the module named by
    ``--aug_tfhub_module`` and builds a per-batch callback that invokes the
    module's ``from_decoded_images`` signature with augmentation switched on.

    Args:
        features: forwarded unchanged to ``apply_fn_matrices``.
        dim: forwarded unchanged to ``apply_fn_matrices``.
        samples: forwarded unchanged to ``apply_fn_matrices``.
        labels: forwarded unchanged to ``apply_fn_matrices``.

    Returns:
        Whatever ``apply_fn_matrices`` returns for the given arguments and
        the callback built here.

    Raises:
        app.UsageError: if ``--aug_tfhub_module`` is not set (from ``setup()``).
    """
    setup()

    hub_module = hub.load(FLAGS.aug_tfhub_module)

    def fn(images):
        # TODO dynamically check input signature

        def to_resized_float(image):
            # Cast to float32, then resize to the size given by the
            # resize_height / resize_width flags (defined outside this file).
            as_float = tf.cast(image, tf.float32)
            return tf.image.resize(
                as_float, [FLAGS.resize_height, FLAGS.resize_width]
            )

        # Each image is resized on its own before stacking into one batch.
        batch = tf.stack([to_resized_float(image) for image in images])

        signature = hub_module.signatures["from_decoded_images"]
        # NOTE(review): image_size is fixed at [32, 32] regardless of the
        # resize flags used above -- confirm this is intended.
        outputs = signature(
            images=batch,
            image_size=tf.constant([32, 32]),
            augmentation=tf.constant(True),
        )
        # Hand the "default" output back as a NumPy array.
        return outputs["default"].numpy()

    return apply_fn_matrices(features, dim, samples, labels, fn)
