# coding=utf-8
import random
import numpy as np
import tensorflow as tf
from .dqn_algo import DqnAlgo

# Hyper-parameters of the DPBA variant of DQN defined below.
ALPHA = 1e-4 # learning rate passed to both Adam optimizers (0.001 noted by the author as an alternative)
GAMMA = 0.9     # reward discount, used in the Q target, the shaping term F and the Phi target
TAU = 0.01      # soft replacement: the target networks read an EMA with decay 1 - TAU
RENDER = False  # not referenced in the visible code of this file
# BATCH_SIZE = 1024
# EPSILON = 0.1
# NOTE(review): the EPSILON* constants below are not referenced in the visible
# code of this file; presumably exploration settings consumed elsewhere -- confirm.
EPSILON = 0.0001
EPSILON_MAX = 1.0
EPSILON_MIN = 0.1 # 0.02 noted by the author as an alternative
EPSILON_DECAY_EPISODE = 6000

class DqnDpbaAlgo(DqnAlgo):
    """DQN algorithm with dynamic potential-based advice (DPBA).

    In addition to the Q-network, a potential network Phi(s, a) is trained
    on the negated shaping reward.  The shaping term
    ``F = GAMMA * Phi(s', a') - Phi(s, a)`` is added to the environment
    reward inside the Q-learning target.
    """

    def __init__(self, sess, graph, state_dim, action_dim, epsilon_decay=None,
                 is_test=False, algo_name="dqn_dpba"):
        """Forward all arguments, positionally and unchanged, to ``DqnAlgo``.

        The attributes read by the other methods of this class (``sess``,
        ``graph``, ``state_dim``, ``action_dim``, ``is_test``,
        ``train_writer``, ``update_cnt``) are presumably set up by the base
        class -- confirm against ``DqnAlgo``.
        """
        super(DqnDpbaAlgo, self).__init__(sess, graph, state_dim, action_dim, epsilon_decay,
                                          is_test, algo_name)

    def init_networks(self):
        """Build placeholders, the Q/Phi networks, their targets and both losses.

        Creates, as attributes of ``self``: the input placeholders, the online
        Q-network and Phi-network outputs, EMA-based target networks, the two
        loss tensors with their Adam training ops, the merged summary op and
        a ``tf.train.Saver``.
        """
        with self.sess.as_default():
            with self.graph.as_default():
                # Inputs: states are [batch, state_dim]; actions are [batch];
                # rewards, shaping rewards and done flags are [batch, 1].
                self.state_phd = tf.placeholder(tf.float32, [None, self.state_dim], name="s_")
                self.action_phd = tf.placeholder(tf.int32, [None, ], name="a_")
                self.state_prime_phd = tf.placeholder(tf.float32, [None, self.state_dim],
                                                      name="s_prime_")
                self.action_prime_phd = tf.placeholder(tf.int32, [None, ], name="a_prime_")
                self.reward_phd = tf.placeholder(tf.float32, [None, 1], name="reward_")
                self.shaping_reward_phd = tf.placeholder(tf.float32, [None, 1],
                                                         name="shaping_reward_")
                self.done_phd = tf.placeholder(tf.float32, [None, 1], name="done_")

                # build Q-network and Potential network
                self.netQ = self.build_q_net(self.state_phd)
                self.phi_sa = self.build_phi_net(self.state_phd, self.action_phd)

                # define Q(s,a) and argmax_a Q(s,a); Q(s,a) is picked out of the
                # per-action outputs with a one-hot mask, keeping shape [batch, 1]
                self.q_sa_op = tf.reduce_sum(self.netQ * tf.one_hot(self.action_phd, self.action_dim),
                                             axis=1, keepdims=True, name="q_sa_")
                self.argmax_qsa = tf.argmax(self.netQ, axis=1, name="argmax_qsa_")

                # build the target Q-network and the target Phi-network; both read
                # exponential moving averages (decay 1 - TAU) of the online parameters
                q_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Qnet_")
                phi_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="PhiNet_")
                ema = tf.train.ExponentialMovingAverage(decay=1 - TAU)

                def ema_getter(getter, name, *args, **kwargs):
                    # Replace every variable lookup by its EMA shadow variable.
                    return ema.average(getter(name, *args, **kwargs))

                target_update = [ema.apply(q_params)]
                phi_target_update = [ema.apply(phi_params)]

                # Ops created in this context carry a control dependency on the
                # EMA update of the Q-network parameters.
                with tf.control_dependencies(target_update):
                    # define the optimization of Q-network
                    self.netQ_target = self.build_q_net(self.state_prime_phd, reuse=True, custom_getter=ema_getter)
                    self.phi_spap = self.build_phi_net(self.state_prime_phd, self.action_prime_phd,
                                                       reuse=True, custom_getter=ema_getter)
                    self.q_max_sp_op = tf.reduce_max(self.netQ_target,
                                                     name="q_max_sp_",
                                                     axis=1,
                                                     keepdims=True)
                    # Shaping term F(s,a,s',a') = GAMMA * Phi(s',a') - Phi(s,a).
                    # NOTE(review): unlike phi_target below, phi_spap is not masked
                    # with (1 - done) here -- confirm this is intended for terminal
                    # transitions.
                    Fsaspap = GAMMA * self.phi_spap - self.phi_sa
                    q_target = self.reward_phd + Fsaspap + GAMMA * (1 - self.done_phd) * self.q_max_sp_op
                    # The target is treated as a constant: no gradient flows through it.
                    self.target_op = tf.stop_gradient(q_target)
                    self.td_error_op = self.target_op - self.q_sa_op
                    self.squared_error_op = tf.square(self.td_error_op)
                    self.loss = tf.reduce_mean(self.squared_error_op, name="loss_")
                    self.optimizer = tf.train.AdamOptimizer(ALPHA).minimize(self.loss,
                                                                            name="adam_optimizer_")

                # Ops created in this context carry a control dependency on the
                # EMA update of the Phi-network parameters.
                with tf.control_dependencies(phi_target_update):
                    # define the optimization of Phi-network; its "reward" is the
                    # negated shaping reward
                    phi_target = -self.shaping_reward_phd + GAMMA * (1 - self.done_phd) * self.phi_spap
                    self.phi_target_op = tf.stop_gradient(phi_target)
                    self.phi_td_error_op = self.phi_target_op - self.phi_sa
                    self.phi_squared_error_op = tf.square(self.phi_td_error_op)
                    self.phi_loss = tf.reduce_mean(self.phi_squared_error_op,
                                                   name="phi_loss_")
                    self.phi_optimizer = tf.train.AdamOptimizer(ALPHA).minimize(self.phi_loss,
                                                                                name="phi_adam_optimizer_")

                # Both losses are logged under the same "q_loss" name scope.
                with tf.name_scope("q_loss"):
                    tf.summary.scalar("q_loss_", self.loss)
                    tf.summary.scalar("phi_loss_", self.phi_loss)

                self.merged = tf.summary.merge_all()
                self.saver = tf.train.Saver(max_to_keep=100)

    def learn(self, state_batch, action_batch, reward_batch,
              state_prime_batch, done_batch, **kwargs):
        """Run one training step of both the Q-network and the Phi-network.

        Does nothing when ``self.is_test`` is true.

        Args:
            state_batch: fed to ``state_phd`` ([batch, state_dim]).
            action_batch: fed to ``action_phd`` ([batch]).
            reward_batch: fed to ``reward_phd`` ([batch, 1]).
            state_prime_batch: fed to ``state_prime_phd`` ([batch, state_dim]).
            done_batch: fed to ``done_phd`` ([batch, 1]).
            **kwargs: must contain ``c_batch``, the shaping rewards fed to
                ``shaping_reward_phd`` ([batch, 1]); they are used for
                learning the potential function Phi(s,a).

        Raises:
            ValueError: if ``c_batch`` is missing or ``None`` (training mode
                only).
        """
        if self.is_test:
            return

        # Fail fast with an explicit message instead of letting a None feed
        # surface later as an opaque placeholder-shape error.
        c_batch = kwargs.get("c_batch")
        if c_batch is None:
            raise ValueError("learn() requires the keyword argument 'c_batch' "
                             "(the shaping reward batch)")

        with self.sess.graph.as_default():
            # get the max action of the next state according to the
            # Q-network (not target); it is used as a' in Phi(s', a')
            max_act_sp_batch = self.sess.run(self.argmax_qsa,
                                             feed_dict={self.state_phd: state_prime_batch})

            feed_dict = {self.state_phd: state_batch,
                         self.action_phd: action_batch,
                         self.state_prime_phd: state_prime_batch,
                         self.reward_phd: reward_batch,
                         self.done_phd: done_batch,
                         self.shaping_reward_phd: c_batch,
                         self.action_prime_phd: max_act_sp_batch
                         }
            _, _, summary = self.sess.run([self.optimizer, self.phi_optimizer, self.merged],
                                          feed_dict=feed_dict)

        self.train_writer.add_summary(summary, self.update_cnt)
        self.update_cnt += 1

    def build_phi_net(self, state_phd, action_phd, reuse=None, custom_getter=None):
        """Build the potential network Phi(s, a) inside the ``PhiNet_`` scope.

        Args:
            state_phd: rank-2 float tensor of states ([batch, state_dim]).
            action_phd: integer tensor of action indices ([batch]); it is
                one-hot encoded with depth ``self.action_dim``.
            reuse: forwarded to ``tf.variable_scope``.  The dense layers are
                flagged trainable only when this is ``None``.
            custom_getter: forwarded to ``tf.variable_scope``.

        Returns:
            The output of the final one-unit dense layer, i.e. one scalar
            Phi(s, a) per row ([batch, 1]).
        """
        # Only the first (non-reusing) build flags its dense layers as trainable.
        trainable = reuse is None
        with tf.variable_scope('PhiNet_', reuse=reuse, custom_getter=custom_getter):

            x = 1.0

            # the first hidden layer: state concatenated with the one-hot action
            net = tf.layers.dense(tf.concat([state_phd, tf.one_hot(action_phd, self.action_dim)], axis=1), 64,
                                  kernel_initializer=tf.random_uniform_initializer(-x / 8.0, x / 8.0),
                                  bias_initializer=tf.random_uniform_initializer(-x / 8.0, x / 8.0),
                                  name="l1",
                                  trainable=trainable)
            net = tf.contrib.layers.layer_norm(net)
            net = tf.nn.relu(net)

            # the second hidden layer: the one-hot action is concatenated again
            net = tf.layers.dense(tf.concat([net, tf.one_hot(action_phd, self.action_dim)], axis=1), 32,
                                  kernel_initializer=tf.random_uniform_initializer(-x / 8.0, x / 8.0),
                                  bias_initializer=tf.random_uniform_initializer(-x / 8.0, x / 8.0),
                                  name="l2",
                                  trainable=trainable)
            net = tf.contrib.layers.layer_norm(net)
            net = tf.nn.relu(net)

            # the output layer
            # note that the output is a single value per row: the potential
            # Phi(s, a) of the given state-action pair (not a per-action vector)
            phisa = tf.layers.dense(net, 1, activation=None,
                                    kernel_initializer=tf.random_uniform_initializer(-x / 1000.0, x / 1000.0),
                                    bias_initializer=tf.random_uniform_initializer(-x / 1000.0, x / 1000.0),
                                    name='phi_sa',
                                    trainable=trainable
                                    )
            return phisa
