# coding=utf-8
import random
import numpy as np
import tensorflow as tf
from .dqn_algo import DqnAlgo

# --- Optimisation / TD-learning hyper-parameters ---------------------------
ALPHA = 0.0001          # learning rate passed to tf.train.AdamOptimizer
GAMMA = 0.9             # reward discount (TD target and PBRS shaping term)
TAU = 0.01              # soft target update; the EMA decay is 1 - TAU

# --- Beta schedule bounds --------------------------------------------------
# NOTE(review): not referenced in this file's visible code; their purpose
# and consumer are to be confirmed with the callers.
BETA_MAX = 1.0
BETA_MIN = 1e-4
BETA_DECAY_EPISODE = 10000

# --- Exploration parameters ------------------------------------------------
# NOTE(review): not referenced in this file's visible code; presumably an
# epsilon-greedy schedule used by callers or the base class -- confirm.
EPSILON = 1e-4
EPSILON_MAX = 1.0
EPSILON_MIN = 0.1
EPSILON_DECAY_EPISODE = 6000

# --- Misc -------------------------------------------------------------------
RENDER = False          # not referenced in this file's visible code

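# Reference configuration that worked best for single-agent DQN (kept for
# comparison; some of these values differ from the active constants above):
#   - epsilon decay: exponential, from 1.0 to 0.02 (0.2 for many agents,
#     e.g. 6, 7 or 8), over 60000 episodes
#   - soft target update using an EMA, tau = 0.01
#   - learning rate alpha = 1e-3, Adam optimizer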

class DqnMypPbrsAlgo(DqnAlgo):
    """DQN with potential-based reward shaping (PBRS).

    The TD target is computed with a target network whose weights are an
    exponential moving average (decay ``1 - TAU``) of the online Q-network's
    weights. ``learn`` adds the shaping term ``GAMMA * phi(s') - phi(s)`` to
    the environment reward, with the potentials supplied by the caller via
    the ``c_batch`` / ``c_sp_batch`` keyword arguments.
    """

    def __init__(self, sess, graph, state_dim, action_dim, epsilon_decay=None,
                 is_test=False, algo_name="dqn_myp_pbrs"):
        super(DqnMypPbrsAlgo, self).__init__(sess, graph, state_dim, action_dim,
                                             epsilon_decay, is_test, algo_name)

    def init_networks(self):
        """Build placeholders, the Q/target networks, loss, optimizer and summaries.

        Sets ``state_phd``, ``action_phd``, ``state_prime_phd``, ``reward_phd``,
        ``done_phd``, ``netQ``, ``q_sa_op``, ``argmax_qsa``, ``netQ_target``,
        ``q_max_sp_op``, ``target_op``, ``td_error_op``, ``squared_error_op``,
        ``loss``, ``optimizer``, ``merged`` and ``saver`` on ``self``.
        """
        with self.sess.as_default():
            with self.graph.as_default():
                self.state_phd = tf.placeholder(tf.float32, [None, self.state_dim], name="s_")
                # Actions are fed as integer indices; they are one-hot encoded below.
                self.action_phd = tf.placeholder(tf.int32, [None, ], name="a_")
                self.state_prime_phd = tf.placeholder(tf.float32, [None, self.state_dim],
                                                      name="s_prime_")
                # Rewards and done flags are fed as column vectors.
                self.reward_phd = tf.placeholder(tf.float32, [None, 1], name="reward_")
                self.done_phd = tf.placeholder(tf.float32, [None, 1], name="done_")

                # Online Q-network: one Q-value per action for each state.
                self.netQ = self.build_q_net(self.state_phd)

                # Q(s, a): select the taken action's Q-value with a one-hot mask.
                self.q_sa_op = tf.reduce_sum(self.netQ * tf.one_hot(self.action_phd, self.action_dim),
                                             axis=1, keepdims=True,
                                             name="q_sa_")
                # Greedy action under the online network (learn() uses it on s'
                # to choose which shaping value serves as phi(s')).
                self.argmax_qsa = tf.argmax(self.netQ, axis=1, name="argmax_qsa_")

                # Soft target network: same architecture re-evaluated with the
                # EMA shadow copies of the Q-network variables.
                # NOTE(review): assumes build_q_net creates its trainable
                # variables under the "Qnet_" scope -- confirm in DqnAlgo.
                q_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Qnet_")
                ema = tf.train.ExponentialMovingAverage(decay=1 - TAU)
                def ema_getter(getter, name, *args, **kwargs):
                    # Substitute each variable with its EMA shadow variable.
                    return ema.average(getter(name, *args, **kwargs))

                target_update = [ema.apply(q_params)]  # soft update operation

                # Every op created in this context depends on target_update, so
                # running the loss/optimizer first advances the EMA weights.
                with tf.control_dependencies(target_update):
                    self.netQ_target = self.build_q_net(self.state_prime_phd, reuse=True, custom_getter=ema_getter)
                    self.q_max_sp_op = tf.reduce_max(self.netQ_target,
                                                     name="q_max_sp_",
                                                     axis=1,
                                                     keepdims=True)
                    # TD target r + GAMMA * (1 - done) * max_a' Q_target(s', a'),
                    # held constant with respect to the gradient.
                    q_target = self.reward_phd + GAMMA * (1 - self.done_phd) * self.q_max_sp_op
                    self.target_op = tf.stop_gradient(q_target)
                    self.td_error_op = self.target_op - self.q_sa_op

                    self.squared_error_op = tf.square(self.td_error_op)
                    self.loss = tf.reduce_mean(self.squared_error_op, name="loss_")

                    # define the optimizer
                    self.optimizer = tf.train.AdamOptimizer(ALPHA).minimize(self.loss,
                                                                            name="adam_optimizer_")

                with tf.name_scope("q_loss"):
                    tf.summary.scalar("q_loss_", self.loss)

                self.merged = tf.summary.merge_all()
                self.saver = tf.train.Saver(max_to_keep=100)

    def learn(self, state_batch, action_batch, reward_batch,
              state_prime_batch, done_batch, **kwargs):
        """Run one gradient step on a batch, using a PBRS-shaped reward.

        The reward fed to the TD target is ``r + GAMMA * phi(s') - phi(s)``:

        * ``phi(s)`` is ``c_batch[i]``, the shaping value of the current
          (s, a) pair;
        * ``phi(s')`` is ``c_sp_batch[i][a*]``, where ``c_sp_batch[i]`` holds
          the shaping values of all actions in s' and ``a*`` is the action
          the online (not target) Q-network rates highest in s'.

        Keyword Args:
            c_batch: per-sample shaping value of (s, a) (required).
            c_sp_batch: per-sample sequence of shaping values for every action
                in s', indexable as ``c_sp_batch[i][action]`` (required).

        Raises:
            ValueError: if ``c_batch``/``c_sp_batch`` is missing or the length
                of ``c_sp_batch`` does not match the batch.

        Does nothing when ``self.is_test`` is set.
        """
        if self.is_test:
            return

        c_batch = kwargs.get("c_batch")
        c_sp_batch = kwargs.get("c_sp_batch")
        if c_batch is None or c_sp_batch is None:
            raise ValueError("learn() requires the 'c_batch' and 'c_sp_batch' "
                             "keyword arguments for reward shaping")

        # Greedy action in s' according to the online Q-network (not target).
        with self.sess.graph.as_default():
            max_act_sp_batch = self.sess.run(self.argmax_qsa, feed_dict={self.state_phd: state_prime_batch})

        if len(c_sp_batch) != len(max_act_sp_batch):
            raise ValueError("c_sp_batch has {} entries but the batch has {}".format(
                len(c_sp_batch), len(max_act_sp_batch)))

        # phi(s') = c_sp_batch[i][argmax_a Q(s'_i, a)], kept as floats: the
        # previous np.array([0] * n) buffer was integer-typed and truncated
        # fractional shaping values on assignment.
        potential_sp = np.array(
            [c_sp[act] for c_sp, act in zip(c_sp_batch, max_act_sp_batch)],
            dtype=np.float64).reshape(-1, 1)

        # Everything is reshaped to a column to match reward_phd ([None, 1]);
        # a 1-D c_batch would otherwise broadcast into an (N, N) matrix.
        # NOTE(review): phi(s') is not zeroed on terminal transitions
        # (done_batch). PBRS policy invariance in episodic tasks assumes
        # phi(terminal) == 0 -- confirm this is intended.
        shaped_reward_batch = (np.reshape(reward_batch, (-1, 1))
                               + GAMMA * potential_sp
                               - np.reshape(c_batch, (-1, 1)))

        with self.sess.graph.as_default():
            _, summary = self.sess.run([self.optimizer, self.merged],
                                       feed_dict={self.state_phd: state_batch,
                                                  self.action_phd: action_batch,
                                                  self.state_prime_phd: state_prime_batch,
                                                  self.reward_phd: shaped_reward_batch,
                                                  self.done_phd: done_batch
                                                  })

        self.train_writer.add_summary(summary, self.update_cnt)
        self.update_cnt += 1



