# coding=utf-8
# Copyright 2019 The Hal Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for extracting visual features."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


# Hand-coded 1x1 colour filters used by hard_coded_conv: each row is a weight
# vector dotted with the last (colour) axis of a 3-channel image, producing one
# output channel per row. HardCodedExtractor's placeholder expects one channel
# per row (5).
color_filter = [[0.57735027, -0.57735027, -0.57735027],
                [-0.57735027, -0.57735027, 0.57735027],
                [-0.57735027, 0.57735027, -0.57735027],
                [-0.98162999, 0.163605, 0.098163],
                [0.06311944, -0.9467916, 0.3155972]]


# The paths below are None in this file. They are read through tf.gfile.GFile
# (FiLM parameter .npy files) and tf.train.Saver.restore (checkpoints) when the
# FiLM extractors are constructed, so presumably the caller assigns real paths
# first; with None, loading fails.
two_statement_path = None # film params layer 1 for standard setting

diverse_layer1_statement_path =  None  # film params layer 1 for diverse setting

diverse_layer2_statement_path = None  # film params layer 2 for diverse setting

fixed_object_ckpt_path = None  # checkpoint for standard setting

diverse_object_ckpt_path = None  # checkpoint for diverse setting




def one_by_one_conv(img, kernel):
  """Applies a single 1x1 convolution kernel to a 3-channel image.

  Args:
    img: numpy array whose last axis has exactly 3 entries (colour channels).
    kernel: length-3 sequence of weights.

  Returns:
    Array of shape img.shape[:-1] holding, for every pixel, the dot product of
    that pixel's colour vector with `kernel`.
  """
  weights = np.array(kernel)
  pixels_by_channel = np.transpose(img.reshape([-1, 3]))
  responses = np.dot(weights, pixels_by_channel)
  spatial_shape = img.shape[:-1]
  return responses.reshape(spatial_shape)


def hard_coded_conv(img):
  """Filters an image with every kernel in `color_filter`.

  Each filter response is rectified (negative values set to 0) and divided by
  0.01; the responses are stacked along a new last axis, giving one channel
  per row of `color_filter`.
  """
  def _rectified_response(kernel):
    response = one_by_one_conv(img, kernel)
    response[response < 0] = 0.
    response /= 0.01
    return response

  channels = [_rectified_response(kernel) for kernel in color_filter]
  return np.stack(channels, axis=-1)


class HardCodedExtractor(object):
  """Spatial-softmax keypoints over the hand-coded colour filter responses.

  `get_feature` filters one 64x64 image with `hard_coded_conv` and runs the
  resulting map (one channel per `color_filter` row) through
  `tf.contrib.layers.spatial_softmax`.
  """

  def __init__(self, sess):
    """Builds the spatial-softmax graph and initialises its variables.

    Args:
      sess: tf.Session used to initialise and run the extractor.
    """
    self._sess = sess
    # One input channel per hand-coded colour filter (see hard_coded_conv).
    self.img_ph = tf.placeholder(tf.float32, [1, 64, 64, len(color_filter)])
    # Remember which variables already exist so that only the ones created by
    # this spatial_softmax get (re-)initialised; initialising every
    # 'temperature' variable in the graph could clobber other models' state.
    existing_names = set(v.name for v in tf.global_variables())
    self.spsm = tf.contrib.layers.spatial_softmax(self.img_ph, trainable=False)
    self._new_variables = [
        v for v in tf.global_variables() if v.name not in existing_names]
    self._initialize()

  def _initialize(self):
    """Initialises the temperature variable(s) created in __init__."""
    temperatures = [v for v in self._new_variables if 'temperature' in v.name]
    print('temperature variables: {}'.format(temperatures))
    # tf.initialize_variables is deprecated; variables_initializer replaces it.
    self._sess.run(tf.variables_initializer(temperatures))

  def get_feature(self, img):
    """Returns the spatial-softmax keypoints for a single image."""
    filtered_img = hard_coded_conv(img)
    return self._sess.run(self.spsm, {self.img_ph: [filtered_img]})[0]


class FilmFeatureExtractor(object):
  """Spatial-softmax keypoints of a FiLM-conditioned conv layer.

  One conv + batch-norm layer (weights restored from a checkpoint) is applied
  to the input image once per statement, and each copy is modulated by that
  statement's FiLM (gamma, beta) row loaded from a .npy file. The ReLU'd maps
  are concatenated along channels, divided by 0.001 and passed through
  `tf.contrib.layers.spatial_softmax`.
  """

  # Checkpoint-side names of the restored variables. Each maps to the graph
  # variable 'film_extractor/conv_feature/<name without "model/">'.
  _CKPT_VAR_NAMES = (
      u'model/batch_normalization/moving_mean',
      u'model/batch_normalization/moving_variance',
      u'model/conv2d/bias',
      u'model/conv2d/kernel',
  )

  def __init__(self, sess, num_statements=2, res=64, channel=3, diverse=False):
    """Builds the graph, initialises it and restores the conv weights.

    Args:
      sess: tf.Session used to initialise, restore and run the extractor.
      num_statements: number of FiLM parameter rows (statements) to use.
      res: height/width of the square input image.
      channel: number of input image channels.
      diverse: if True, use the diverse-setting FiLM params and checkpoint.

    Raises:
      ValueError: if the FiLM parameter path has not been set.
    """
    self._sess = sess
    self._num_statements = num_statements
    self._film_param_path = two_statement_path
    if diverse:
      self._film_param_path = diverse_layer1_statement_path
    self._all_film_params = self._load_npy(self._film_param_path)
    self._res = res
    self._channel = channel
    self._input_image_ph = tf.placeholder(tf.float32, [1, res, res, channel])
    self._conv_layer_config = [(48, 8, 2), (128, 5, 2), (64, 3, 1)]
    # spatial_softmax emits an (x, y) pair per channel, and each statement
    # contributes one channel per first-layer filter.
    self.obs_shape = self._num_statements * self._conv_layer_config[0][0] * 2
    if diverse:
      self.ckpt_path = diverse_object_ckpt_path
    else:
      self.ckpt_path = fixed_object_ckpt_path
    self._spsm = self._build()
    self._initialize()

  @staticmethod
  def _load_npy(path):
    """Loads a numpy array from `path` through tf.gfile."""
    if path is None:
      raise ValueError('FiLM parameter path is not set; assign the '
                       'module-level path before building the extractor.')
    # np.load reads raw bytes; a text-mode ('r') handle fails under Python 3.
    with tf.gfile.GFile(path, mode='rb') as f:
      return np.load(f)

  def get_feature(self, img):
    """Returns the spatial-softmax keypoints for a single image."""
    return self._sess.run(self._spsm, {self._input_image_ph: [img]})[0]

  def _build(self):
    """Builds the per-statement FiLM features and their spatial softmax."""
    with tf.variable_scope('film_extractor', reuse=tf.AUTO_REUSE):
      all_features = []
      for i in range(self._num_statements):
        # Cast so float64 parameter files still match the float32 activations.
        film_param = tf.constant(
            self._all_film_params[i][None, :], dtype=tf.float32)
        feature = self._build_single_conv_feature(film_param)
        all_features.append(feature)
      all_features = tf.concat(all_features, axis=-1)/0.001
      return tf.contrib.layers.spatial_softmax(all_features, trainable=False)

  def _build_single_conv_feature(self, film_params):
    """Applies the shared conv + BN layer, FiLM modulation and ReLU.

    The 'conv_feature' scope is re-entered with AUTO_REUSE for every
    statement, so all statements share the same conv/BN variables.
    """
    with tf.variable_scope('conv_feature', reuse=tf.AUTO_REUSE):
      cfg = self._conv_layer_config[0]
      out = tf.layers.conv2d(
          self._input_image_ph, cfg[0], cfg[1], cfg[2], padding='SAME')
      out = tf.layers.batch_normalization(
          out, center=False, scale=False, training=False)
      gamma, beta = tf.split(film_params, 2, axis=1)
      out *= tf.expand_dims(tf.expand_dims(gamma, 1), 1)
      out += tf.expand_dims(tf.expand_dims(beta, 1), 1)
      out = tf.nn.relu(out)
      return out

  def _checkpoint_var_map(self):
    """Maps checkpoint variable names to graph variables by exact name.

    Matching by name (instead of zipping with a sorted variable list) raises
    if an expected variable is missing rather than silently pairing the
    wrong tensors.
    """
    prefix = 'film_extractor/conv_feature/'
    by_name = {v.op.name: v
               for v in tf.global_variables('film_extractor/conv_feature')}
    var_map = {}
    for ckpt_name in self._CKPT_VAR_NAMES:
      graph_name = prefix + ckpt_name[len('model/'):]
      if graph_name not in by_name:
        raise ValueError(
            'Expected graph variable {} was not found.'.format(graph_name))
      var_map[ckpt_name] = by_name[graph_name]
    return var_map

  def _initialize(self):
    """Initialises extractor variables, then restores the conv/BN weights."""
    variables = tf.global_variables('film_extractor')
    self._sess.run(tf.variables_initializer(variables))
    saver_1 = tf.train.Saver(var_list=self._checkpoint_var_map())
    saver_1.restore(self._sess, self.ckpt_path)


class DiverseFilmFeatureExtractor(object):
  """Spatial-softmax keypoints of a FiLM-conditioned conv layer (diverse set).

  Uses the diverse-setting layer-1 FiLM parameters and checkpoint. The shared
  conv + BN layer is modulated per statement, the ReLU'd maps are
  concatenated, divided by 0.01 and passed through
  `tf.contrib.layers.spatial_softmax`. Layer-2 FiLM parameters are loaded and
  stored but not used in the graph built here.
  """

  # Checkpoint-side names of the restored variables. Each maps to the graph
  # variable 'film_extractor/conv_feature/<name without "model/">'.
  _CKPT_VAR_NAMES = (
      u'model/batch_normalization/moving_mean',
      u'model/batch_normalization/moving_variance',
      u'model/conv2d/bias',
      u'model/conv2d/kernel',
  )

  def __init__(self, sess, num_statements=2, res=64, channel=3, state_fn=None):
    """Builds the graph, initialises it and restores the conv weights.

    Args:
      sess: tf.Session used to initialise, restore and run the extractor.
      num_statements: number of FiLM parameter rows (statements) to use.
      res: height/width of the square input image.
      channel: number of input image channels.
      state_fn: unused; accepted for interface parity with other extractors.

    Raises:
      ValueError: if a FiLM parameter path has not been set.
    """
    self._sess = sess
    self._num_statements = num_statements
    self._all_film_params_layer1 = self._load_npy(diverse_layer1_statement_path)
    self._all_film_params_layer2 = self._load_npy(diverse_layer2_statement_path)
    self._res = res
    self._channel = channel
    self._input_image_ph = tf.placeholder(tf.float32, [1, res, res, channel])
    self._conv_layer_config = [(48, 8, 2), (128, 5, 2), (64, 3, 1)]
    # spatial_softmax emits an (x, y) pair per channel.
    self.obs_shape = self._num_statements * self._conv_layer_config[0][0] * 2
    self._spsm = self._build()
    self._initialize()

  @staticmethod
  def _load_npy(path):
    """Loads a numpy array from `path` through tf.gfile."""
    if path is None:
      raise ValueError('FiLM parameter path is not set; assign the '
                       'module-level path before building the extractor.')
    # np.load reads raw bytes; a text-mode ('r') handle fails under Python 3.
    with tf.gfile.GFile(path, mode='rb') as f:
      return np.load(f)

  def get_feature(self, img):
    """Returns the spatial-softmax keypoints for a single image."""
    return self._sess.run(self._spsm, {self._input_image_ph: [img]})[0]

  def _build(self):
    """Builds the per-statement FiLM features and their spatial softmax."""
    with tf.variable_scope('film_extractor', reuse=tf.AUTO_REUSE):
      all_features = []
      for i in range(self._num_statements):
        # Cast so float64 parameter files still match the float32 activations.
        film_param = tf.constant(
            self._all_film_params_layer1[i][None, :], dtype=tf.float32)
        feature = self._build_single_conv_feature(film_param)
        all_features.append(feature)
      all_features = tf.concat(all_features, axis=-1)/0.01
      return tf.contrib.layers.spatial_softmax(all_features, trainable=False)

  def _build_single_conv_feature(self, film_params):
    """Applies the shared conv + BN layer, FiLM modulation and ReLU."""
    with tf.variable_scope('conv_feature', reuse=tf.AUTO_REUSE):
      cfg = self._conv_layer_config[0]
      out = tf.layers.conv2d(
          self._input_image_ph, cfg[0], cfg[1], cfg[2], padding='SAME')
      out = tf.layers.batch_normalization(
          out, center=False, scale=False, training=False)
      gamma, beta = tf.split(film_params, 2, axis=1)
      out *= tf.expand_dims(tf.expand_dims(gamma, 1), 1)
      out += tf.expand_dims(tf.expand_dims(beta, 1), 1)
      out = tf.nn.relu(out)
      return out

  def _checkpoint_var_map(self):
    """Maps checkpoint variable names to graph variables by exact name."""
    prefix = 'film_extractor/conv_feature/'
    by_name = {v.op.name: v
               for v in tf.global_variables('film_extractor/conv_feature')}
    var_map = {}
    for ckpt_name in self._CKPT_VAR_NAMES:
      graph_name = prefix + ckpt_name[len('model/'):]
      if graph_name not in by_name:
        raise ValueError(
            'Expected graph variable {} was not found.'.format(graph_name))
      var_map[ckpt_name] = by_name[graph_name]
    return var_map

  def _initialize(self):
    """Initialises extractor variables, then restores the conv/BN weights."""
    variables = tf.global_variables('film_extractor')
    self._sess.run(tf.variables_initializer(variables))
    saver_1 = tf.train.Saver(var_list=self._checkpoint_var_map())
    saver_1.restore(self._sess, diverse_object_ckpt_path)


class DiverseGTFeatureExtractor(object):
  """Ground-truth extractor for the multi-statement setting.

  Instead of looking at pixels, `get_feature` returns the state produced by
  `state_fn`, with its rows sorted by their first element and then flattened
  into a 1-D numpy array.
  """

  def __init__(self, sess, num_statements=2, res=64, channel=3, state_fn=None):
    """Wraps `state_fn` and records the flattened state size.

    Args:
      sess: unused; accepted for interface parity with other extractors.
      num_statements: unused.
      res: unused.
      channel: unused.
      state_fn: zero-argument callable returning an iterable of indexable
        rows (e.g. one row per object). It is called once here to compute
        `obs_shape` and again on every `get_feature` call.

    Raises:
      TypeError: if `state_fn` is not callable (including the None default).
    """
    del sess, num_statements, res, channel  # Unused.
    if not callable(state_fn):
      raise TypeError('state_fn must be a callable returning the state rows, '
                      'got {!r}'.format(state_fn))

    def state_fn_sorted():
      state = state_fn()
      # Sort rows by their first entry, presumably so the feature layout does
      # not depend on the order in which state_fn lists objects.
      state = np.array(sorted(state, key=lambda row: row[0]))
      return state.flatten()

    self.state_fn = state_fn_sorted
    self.obs_shape = self.state_fn().shape[0]

  def get_feature(self, img):
    """Returns the sorted, flattened state; `img` is ignored."""
    del img
    feature = self.state_fn()
    return feature


class MultiPatchDiverseFilmFeatureExtractor(object):
  """Spatial-softmax keypoints of a FiLM-conditioned conv layer (diverse set).

  Same graph as the diverse-setting extractor, but the concatenated ReLU'd
  maps are divided by 0.001 before `tf.contrib.layers.spatial_softmax`.
  Layer-2 FiLM parameters are loaded and stored but not used in the graph
  built here.
  """

  # Checkpoint-side names of the restored variables. Each maps to the graph
  # variable 'film_extractor/conv_feature/<name without "model/">'.
  _CKPT_VAR_NAMES = (
      u'model/batch_normalization/moving_mean',
      u'model/batch_normalization/moving_variance',
      u'model/conv2d/bias',
      u'model/conv2d/kernel',
  )

  def __init__(self, sess, num_statements=2, res=64, channel=3, state_fn=None):
    """Builds the graph, initialises it and restores the conv weights.

    Args:
      sess: tf.Session used to initialise, restore and run the extractor.
      num_statements: number of FiLM parameter rows (statements) to use.
      res: height/width of the square input image.
      channel: number of input image channels.
      state_fn: unused; accepted for interface parity with other extractors.

    Raises:
      ValueError: if a FiLM parameter path has not been set.
    """
    self._sess = sess
    self._num_statements = num_statements
    self._all_film_params_layer1 = self._load_npy(diverse_layer1_statement_path)
    self._all_film_params_layer2 = self._load_npy(diverse_layer2_statement_path)
    self._res = res
    self._channel = channel
    self._input_image_ph = tf.placeholder(tf.float32, [1, res, res, channel])
    self._conv_layer_config = [(48, 8, 2), (128, 5, 2), (64, 3, 1)]
    # spatial_softmax emits an (x, y) pair per channel.
    self.obs_shape = self._num_statements * self._conv_layer_config[0][0] * 2
    self._spsm = self._build()
    self._initialize()

  @staticmethod
  def _load_npy(path):
    """Loads a numpy array from `path` through tf.gfile."""
    if path is None:
      raise ValueError('FiLM parameter path is not set; assign the '
                       'module-level path before building the extractor.')
    # np.load reads raw bytes; a text-mode ('r') handle fails under Python 3.
    with tf.gfile.GFile(path, mode='rb') as f:
      return np.load(f)

  def get_feature(self, img):
    """Returns the spatial-softmax keypoints for a single image."""
    return self._sess.run(self._spsm, {self._input_image_ph: [img]})[0]

  def _build(self):
    """Builds the per-statement FiLM features and their spatial softmax."""
    with tf.variable_scope('film_extractor', reuse=tf.AUTO_REUSE):
      all_features = []
      for i in range(self._num_statements):
        # Cast so float64 parameter files still match the float32 activations.
        film_param = tf.constant(
            self._all_film_params_layer1[i][None, :], dtype=tf.float32)
        feature = self._build_single_conv_feature(film_param)
        all_features.append(feature)
      all_features = tf.concat(all_features, axis=-1)/0.001
      return tf.contrib.layers.spatial_softmax(all_features, trainable=False)

  def _build_single_conv_feature(self, film_params):
    """Applies the shared conv + BN layer, FiLM modulation and ReLU."""
    with tf.variable_scope('conv_feature', reuse=tf.AUTO_REUSE):
      cfg = self._conv_layer_config[0]
      out = tf.layers.conv2d(
          self._input_image_ph, cfg[0], cfg[1], cfg[2], padding='SAME')
      out = tf.layers.batch_normalization(
          out, center=False, scale=False, training=False)
      gamma, beta = tf.split(film_params, 2, axis=1)
      out *= tf.expand_dims(tf.expand_dims(gamma, 1), 1)
      out += tf.expand_dims(tf.expand_dims(beta, 1), 1)
      out = tf.nn.relu(out)
      return out

  def _checkpoint_var_map(self):
    """Maps checkpoint variable names to graph variables by exact name."""
    prefix = 'film_extractor/conv_feature/'
    by_name = {v.op.name: v
               for v in tf.global_variables('film_extractor/conv_feature')}
    var_map = {}
    for ckpt_name in self._CKPT_VAR_NAMES:
      graph_name = prefix + ckpt_name[len('model/'):]
      if graph_name not in by_name:
        raise ValueError(
            'Expected graph variable {} was not found.'.format(graph_name))
      var_map[ckpt_name] = by_name[graph_name]
    return var_map

  def _initialize(self):
    """Initialises extractor variables, then restores the conv/BN weights."""
    variables = tf.global_variables('film_extractor')
    self._sess.run(tf.variables_initializer(variables))
    saver_1 = tf.train.Saver(var_list=self._checkpoint_var_map())
    saver_1.restore(self._sess, diverse_object_ckpt_path)


class DiverseFilmFeatureExtractorV2(object):
  """FiLM conv features summarised as peak locations plus grey-scale patches.

  For every channel of the concatenated FiLM-modulated first conv layer,
  except channels listed in `null_idx` (indices within one statement's block
  of channels), `get_feature` finds the `_n_argmax` strongest activations.
  For each peak it emits the (row, col) location normalised by the feature-map
  size, followed by the flattened grey-scale image patch around it.
  """

  # Side length of the grey-scale patch cropped around each peak.
  _PATCH_SIZE = 4
  # Rows/cols [idx - r, idx + r) are zeroed after each peak is taken.
  _SUPPRESS_RADIUS = 3

  # Checkpoint-side names of the restored variables. Each maps to the graph
  # variable 'film_extractor/conv_feature/<name without "model/">'.
  _CKPT_VAR_NAMES = (
      u'model/batch_normalization/moving_mean',
      u'model/batch_normalization/moving_variance',
      u'model/conv2d/bias',
      u'model/conv2d/kernel',
  )

  def __init__(self, sess, num_statements=2, res=64, channel=3, state_fn=None):
    """Builds the graph, initialises it and restores the conv weights.

    Args:
      sess: tf.Session used to initialise, restore and run the extractor.
      num_statements: number of FiLM parameter rows (statements) to use.
      res: height/width of the square input image.
      channel: number of input image channels.
      state_fn: unused; accepted for interface parity with other extractors.

    Raises:
      ValueError: if a FiLM parameter path has not been set.
    """
    self._sess = sess
    self._num_statements = num_statements
    self._all_film_params_layer1 = self._load_npy(diverse_layer1_statement_path)
    self._all_film_params_layer2 = self._load_npy(diverse_layer2_statement_path)
    self.null_idx = [1, 4, 7, 10, 12, 13, 14, 17, 18, 20, 21, 27, 28, 29, 31, 32, 37, 39, 40, 45, 46]
    self._res = res
    self._channel = channel
    self._input_image_ph = tf.placeholder(tf.float32, [1, res, res, channel])
    self._conv_layer_config = [(48, 8, 2), (128, 5, 2), (64, 3, 1)]
    self._n_argmax = 2
    self._spsm = self._build()
    self._initialize()
    channels_per_statement = self._conv_layer_config[0][0]
    # Per kept channel and peak: 2 location values + a flattened patch.
    per_peak = self._PATCH_SIZE * self._PATCH_SIZE + 2
    self.obs_shape = (self._num_statements
                      * (channels_per_statement - len(self.null_idx))
                      * per_peak * self._n_argmax)
    print('observation shape', self.obs_shape)

  @staticmethod
  def _load_npy(path):
    """Loads a numpy array from `path` through tf.gfile."""
    if path is None:
      raise ValueError('FiLM parameter path is not set; assign the '
                       'module-level path before building the extractor.')
    # np.load reads raw bytes; a text-mode ('r') handle fails under Python 3.
    with tf.gfile.GFile(path, mode='rb') as f:
      return np.load(f)

  def get_feature(self, img):
    """Returns a 1-D vector of peak locations and surrounding grey patches.

    Args:
      img: image array whose first three entries on the last axis are
        treated as RGB for the grey-scale conversion.

    Returns:
      1-D numpy array of length `obs_shape`.
    """
    features = self._sess.run(self._spsm, {self._input_image_ph: [img]})[0]
    gray = np.dot(img[Ellipsis, :3], [0.2989, 0.5870, 0.1140])
    channels_per_statement = self._conv_layer_config[0][0]
    stride = self._conv_layer_config[0][2]
    patch = self._PATCH_SIZE
    half_patch = patch // 2
    radius = self._SUPPRESS_RADIUS
    # Largest top-left corner that still leaves a full patch inside the image.
    max_row = gray.shape[0] - patch
    max_col = gray.shape[1] - patch
    # Peak locations are normalised by the feature-map size.
    fmap_size = np.array(features.shape[:2], dtype=np.float64)
    null_channels = frozenset(self.null_idx)
    loc = []
    for i in range(features.shape[-1]):
      # null_idx indexes channels within each statement's block, matching
      # the per-statement count used for obs_shape.
      if i % channels_per_statement in null_channels:
        continue
      f = features[:, :, i]
      for _ in range(self._n_argmax):
        idx = np.unravel_index(f.argmax(), f.shape)
        loc.append(np.array(idx) / fmap_size)
        # Map the feature-map peak back to image pixels via the conv stride.
        start_0 = min(max(0, idx[0] * stride - half_patch), max_row)
        start_1 = min(max(0, idx[1] * stride - half_patch), max_col)
        loc.append(
            gray[start_0:start_0 + patch, start_1:start_1 + patch].flatten())
        # Clamp the window at 0: a negative start would wrap around, select
        # nothing, and leave the same peak to be found again.
        f[max(0, idx[0] - radius):idx[0] + radius,
          max(0, idx[1] - radius):idx[1] + radius] = 0
    if not loc:
      return np.zeros(0)
    return np.concatenate(loc).flatten()

  def _build(self):
    """Builds the concatenated per-statement FiLM feature maps."""
    with tf.variable_scope('film_extractor', reuse=tf.AUTO_REUSE):
      all_features = []
      for i in range(self._num_statements):
        # Cast so float64 parameter files still match the float32 activations.
        film_param = tf.constant(
            self._all_film_params_layer1[i][None, :], dtype=tf.float32)
        feature = self._build_single_conv_feature(film_param)
        all_features.append(feature)
      all_features = tf.concat(all_features, axis=-1)
      return all_features

  def _build_single_conv_feature(self, film_params):
    """Applies the shared conv + BN layer, FiLM modulation and ReLU."""
    with tf.variable_scope('conv_feature', reuse=tf.AUTO_REUSE):
      cfg = self._conv_layer_config[0]
      out = tf.layers.conv2d(
          self._input_image_ph, cfg[0], cfg[1], cfg[2], padding='SAME')
      out = tf.layers.batch_normalization(
          out, center=False, scale=False, training=False)
      gamma, beta = tf.split(film_params, 2, axis=1)
      out *= tf.expand_dims(tf.expand_dims(gamma, 1), 1)
      out += tf.expand_dims(tf.expand_dims(beta, 1), 1)
      out = tf.nn.relu(out)
      return out

  def _checkpoint_var_map(self):
    """Maps checkpoint variable names to graph variables by exact name."""
    prefix = 'film_extractor/conv_feature/'
    by_name = {v.op.name: v
               for v in tf.global_variables('film_extractor/conv_feature')}
    var_map = {}
    for ckpt_name in self._CKPT_VAR_NAMES:
      graph_name = prefix + ckpt_name[len('model/'):]
      if graph_name not in by_name:
        raise ValueError(
            'Expected graph variable {} was not found.'.format(graph_name))
      var_map[ckpt_name] = by_name[graph_name]
    return var_map

  def _initialize(self):
    """Initialises extractor variables, then restores the conv/BN weights."""
    variables = tf.global_variables('film_extractor')
    self._sess.run(tf.variables_initializer(variables))
    saver_1 = tf.train.Saver(var_list=self._checkpoint_var_map())
    saver_1.restore(self._sess, diverse_object_ckpt_path)


class DiverseFilmFeatureExtractorV3(object):
  """FiLM conv features summarised as the top peaks of every channel.

  For each channel of the concatenated FiLM-modulated first conv layer,
  `get_feature` finds the `_n_argmax` strongest activations and emits, per
  peak, its (row, col) location normalised by the feature-map size followed
  by the activation value.
  """

  # Rows/cols [idx - r, idx + r) are zeroed after each peak is taken.
  _SUPPRESS_RADIUS = 3

  # Checkpoint-side names of the restored variables. Each maps to the graph
  # variable 'film_extractor/conv_feature/<name without "model/">'.
  _CKPT_VAR_NAMES = (
      u'model/batch_normalization/moving_mean',
      u'model/batch_normalization/moving_variance',
      u'model/conv2d/bias',
      u'model/conv2d/kernel',
  )

  def __init__(self, sess, num_statements=2, res=64, channel=3, state_fn=None):
    """Builds the graph, initialises it and restores the conv weights.

    Args:
      sess: tf.Session used to initialise, restore and run the extractor.
      num_statements: number of FiLM parameter rows (statements) to use.
      res: height/width of the square input image.
      channel: number of input image channels.
      state_fn: unused; accepted for interface parity with other extractors.

    Raises:
      ValueError: if a FiLM parameter path has not been set.
    """
    self._sess = sess
    self._num_statements = num_statements
    self._all_film_params_layer1 = self._load_npy(diverse_layer1_statement_path)
    self._all_film_params_layer2 = self._load_npy(diverse_layer2_statement_path)
    self._res = res
    self._channel = channel
    self._input_image_ph = tf.placeholder(tf.float32, [1, res, res, channel])
    self._conv_layer_config = [(48, 8, 2), (128, 5, 2), (64, 3, 1)]
    self._n_argmax = 3
    self._spsm = self._build()
    self._initialize()
    # Per channel and peak: (row, col, value).
    self.obs_shape = (self._num_statements * self._conv_layer_config[0][0]
                      * 3 * self._n_argmax)
    print('observation shape', self.obs_shape)

  @staticmethod
  def _load_npy(path):
    """Loads a numpy array from `path` through tf.gfile."""
    if path is None:
      raise ValueError('FiLM parameter path is not set; assign the '
                       'module-level path before building the extractor.')
    # np.load reads raw bytes; a text-mode ('r') handle fails under Python 3.
    with tf.gfile.GFile(path, mode='rb') as f:
      return np.load(f)

  def get_feature(self, img):
    """Returns normalised locations and values of each channel's top peaks.

    Peaks are taken one at a time; after each one the surrounding window is
    zeroed so later searches look elsewhere (unless the map is all zero).

    Returns:
      1-D numpy array of length `obs_shape`.
    """
    features = self._sess.run(self._spsm, {self._input_image_ph: [img]})[0]
    fmap_size = np.array(features.shape[:2], dtype=np.float64)
    radius = self._SUPPRESS_RADIUS
    loc = []
    for i in range(features.shape[-1]):
      f = features[:, :, i]
      for _ in range(self._n_argmax):
        idx = np.unravel_index(f.argmax(), f.shape)
        loc.append(np.array(idx) / fmap_size)
        loc.append([f[idx]])
        # Clamp the window at 0: a negative start would wrap around, select
        # nothing, and leave the same peak to be found again.
        f[max(0, idx[0] - radius):idx[0] + radius,
          max(0, idx[1] - radius):idx[1] + radius] = 0
    if not loc:
      return np.zeros(0)
    return np.concatenate(loc).flatten()

  def _build(self):
    """Builds the concatenated per-statement FiLM feature maps."""
    with tf.variable_scope('film_extractor', reuse=tf.AUTO_REUSE):
      all_features = []
      for i in range(self._num_statements):
        # Cast so float64 parameter files still match the float32 activations.
        film_param = tf.constant(
            self._all_film_params_layer1[i][None, :], dtype=tf.float32)
        feature = self._build_single_conv_feature(film_param)
        all_features.append(feature)
      all_features = tf.concat(all_features, axis=-1)
      return all_features

  def _build_single_conv_feature(self, film_params):
    """Applies the shared conv + BN layer, FiLM modulation and ReLU."""
    with tf.variable_scope('conv_feature', reuse=tf.AUTO_REUSE):
      cfg = self._conv_layer_config[0]
      out = tf.layers.conv2d(
          self._input_image_ph, cfg[0], cfg[1], cfg[2], padding='SAME')
      out = tf.layers.batch_normalization(
          out, center=False, scale=False, training=False)
      gamma, beta = tf.split(film_params, 2, axis=1)
      out *= tf.expand_dims(tf.expand_dims(gamma, 1), 1)
      out += tf.expand_dims(tf.expand_dims(beta, 1), 1)
      out = tf.nn.relu(out)
      return out

  def _checkpoint_var_map(self):
    """Maps checkpoint variable names to graph variables by exact name."""
    prefix = 'film_extractor/conv_feature/'
    by_name = {v.op.name: v
               for v in tf.global_variables('film_extractor/conv_feature')}
    var_map = {}
    for ckpt_name in self._CKPT_VAR_NAMES:
      graph_name = prefix + ckpt_name[len('model/'):]
      if graph_name not in by_name:
        raise ValueError(
            'Expected graph variable {} was not found.'.format(graph_name))
      var_map[ckpt_name] = by_name[graph_name]
    return var_map

  def _initialize(self):
    """Initialises extractor variables, then restores the conv/BN weights."""
    variables = tf.global_variables('film_extractor')
    self._sess.run(tf.variables_initializer(variables))
    saver_1 = tf.train.Saver(var_list=self._checkpoint_var_map())
    saver_1.restore(self._sess, diverse_object_ckpt_path)
