# coding=utf-8
import random
import numpy as np
import tensorflow as tf

# ---------------------------------------------------------------------------
# Optimisation and target-network hyper-parameters
# ---------------------------------------------------------------------------
# Learning rate handed to the Adam optimizer.
ALPHA = 0.001
# Reward discount applied to the bootstrapped next-state value.
GAMMA = 0.999
# Soft-replacement rate: the target network's moving average uses decay 1 - TAU.
TAU = 0.01

# ---------------------------------------------------------------------------
# Epsilon-greedy exploration
# ---------------------------------------------------------------------------
# Fixed exploration rate used when no decay schedule is selected
# (0.1 was an alternative left commented out next to this value).
EPSILON = 1e-4
# Start and floor of the decaying schedules
# (floors of 0.05 and 0.02 were also noted as alternatives).
EPSILON_MAX = 1.0
EPSILON_MIN = 0.1
# Number of episodes over which epsilon moves from EPSILON_MAX to EPSILON_MIN
# (40000 and 60000 were also noted as alternatives).
EPSILON_DECAY_EPISODE = 10000

# ---------------------------------------------------------------------------
# Constants that are not referenced by the code visible in this file
# ---------------------------------------------------------------------------
# Beta range and its decay horizon (0.7 was noted as an alternative maximum;
# 20000 and 60000 as alternative horizons).
BETA_MAX = 1.0
BETA_MIN = 1e-4
BETA_DECAY_EPISODE = 10000
# Rendering switch.
RENDER = False
# A batch size of 1024 was left commented out here (BATCH_SIZE is not defined).

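# Best settings found for single-agent DQN (note that the constants above are
# currently set to different values, e.g. EPSILON_MIN = 0.1 over 10000 episodes):
# - epsilon decay: exponential, from 1.0 down to 0.02 (0.2 when there are many
#   agents, such as 6, 7 or 8), over 60000 episodes
# - soft target update using an exponential moving average (EMA), tau = 0.01
# - learning rate alpha = 1e-3 with the Adam optimizer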

class DqnAlgo(object):
    """Deep Q-Network agent built with the TensorFlow 1.x graph API.

    The online Q-network lives in the ``Qnet_`` variable scope. The target
    network is not a second set of trainable variables: it re-uses the
    ``Qnet_`` scope through a custom getter that swaps every variable for its
    exponential-moving-average (EMA) shadow copy, giving a soft target update
    with rate ``TAU``.
    """

    def __init__(self, sess, graph, state_dim,
                 action_dim, epsilon_decay=None, is_test=False,
                 algo_name="dqn", grad_clip=True):
        """Build the graph, open the summary writer and initialise variables.

        Args:
            sess: TensorFlow session used for every ``run`` call.
            graph: TensorFlow graph in which all ops are created.
            state_dim: width of the state vectors fed to the network.
            action_dim: number of discrete actions (network output width).
            epsilon_decay: ``"exponential"``, ``"linear"``, or any other value
                (including ``None``) for a constant exploration rate.
            is_test: when true, ``learn`` returns immediately.
            algo_name: used in the summary directory
                ``./data/<algo_name>/summary``.
            grad_clip: when true, gradients are clipped to a global norm of
                1.0 before being applied.
        """
        # NOTE: `pointer` is not read by any method visible in this class.
        self.pointer = 0
        self.sess = sess
        self.graph = graph
        self.algo_name = algo_name
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.epsilon_decay = epsilon_decay
        self.is_test = is_test
        self.grad_clip = grad_clip

        # set exploration factor epsilon
        self.init_epsilon()

        # initialize neural networks
        self.init_networks()

        self.update_cnt = 0
        self.episode_cnt = 0
        self.episode_reward = []

        self.train_writer = tf.summary.FileWriter("./data/" + self.algo_name + "/summary", self.sess.graph)

        # initialization
        with self.sess.as_default():
            with self.graph.as_default():
                tf.global_variables_initializer().run()

    def init_epsilon(self):
        """Set the starting exploration rate and the per-episode decay step.

        The exponential schedule multiplies epsilon by a constant factor so
        that it reaches ``EPSILON_MIN`` after ``EPSILON_DECAY_EPISODE``
        episodes; the linear schedule subtracts a constant amount instead.
        Any other ``epsilon_decay`` value keeps epsilon fixed at ``EPSILON``.
        """
        if self.epsilon_decay == "exponential":
            print("Exponential decay")
            self.epsilon = EPSILON_MAX
            self.epsilon_decay_factor = pow(EPSILON_MIN / EPSILON_MAX, 1.0 / EPSILON_DECAY_EPISODE)
        elif self.epsilon_decay == "linear":
            print("Linear decay")
            self.epsilon = EPSILON_MAX
            self.epsilon_inc = (EPSILON_MAX - EPSILON_MIN) / EPSILON_DECAY_EPISODE
        else:
            self.epsilon = EPSILON

    def init_networks(self):
        """Create placeholders, both networks, the loss, and the training ops."""
        with self.sess.as_default():
            with self.graph.as_default():
                self.state_phd = tf.placeholder(tf.float32, [None, self.state_dim], name="s_")
                self.action_phd = tf.placeholder(tf.int32, [None, ], name="a_")
                self.state_prime_phd = tf.placeholder(tf.float32, [None, self.state_dim],
                                                      name="s_prime_")
                self.reward_phd = tf.placeholder(tf.float32, [None, 1], name="reward_")
                self.done_phd = tf.placeholder(tf.float32, [None, 1], name="done_")

                # online Q-network
                self.netQ = self.build_q_net(self.state_phd)

                # Q(s, a): mask the Q-vector with the one-hot encoded action
                self.q_sa_op = tf.reduce_sum(self.netQ * tf.one_hot(self.action_phd, self.action_dim),
                                             axis=1, keepdims=True, name="q_sa_")

                # The scope filter is applied with re.match, i.e. as a prefix.
                # The trailing "/" keeps variables of sibling scopes that
                # merely start with "Qnet_" (e.g. "Qnet_Target_") out.
                q_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Qnet_/")
                ema = tf.train.ExponentialMovingAverage(decay=1 - TAU)

                def ema_getter(getter, name, *args, **kwargs):
                    # hand out the EMA shadow copy instead of the live variable
                    return ema.average(getter(name, *args, **kwargs))

                target_update = [ema.apply(q_params)]  # soft update operation

                # Ops created in this context carry a control dependency on
                # the EMA update, so the soft update is applied before the
                # target network is evaluated.
                with tf.control_dependencies(target_update):
                    self.netQ_target = self.build_q_net(self.state_prime_phd, reuse=True, custom_getter=ema_getter)
                    self.q_max_sp_op = tf.reduce_max(self.netQ_target,
                                                     name="q_max_sp_",
                                                     axis=1,
                                                     keepdims=True)
                    # (1 - done) removes the bootstrap term on terminal steps
                    q_target = self.reward_phd + GAMMA * (1 - self.done_phd) * self.q_max_sp_op
                    self.target_op = tf.stop_gradient(q_target)
                    self.td_error_op = self.target_op - self.q_sa_op
                    self.squared_error_op = tf.square(self.td_error_op)
                    self.loss = tf.reduce_mean(self.squared_error_op, name="loss_")

                    # define the optimizer
                    if self.grad_clip:
                        self.parameter_gradients = tf.gradients(self.loss, q_params)
                        self.param_clipped_grad_op, _ = tf.clip_by_global_norm(self.parameter_gradients,
                                                                                1.0)
                        self.opt = tf.train.AdamOptimizer(ALPHA)
                        self.optimizer = self.opt.apply_gradients(zip(self.param_clipped_grad_op, q_params))
                    else:
                        self.optimizer = tf.train.AdamOptimizer(ALPHA).minimize(self.loss,
                                                                                name="adam_optimizer_")

                # Hard target update, built once and outside the control
                # dependency above: overwrite every EMA shadow variable with
                # the current value of its online variable.
                self.hard_update_op = tf.group(
                    *[ema.average(param).assign(param) for param in q_params],
                    name="hard_target_update_")

                with tf.name_scope("q_loss"):
                    tf.summary.scalar("q_loss_", self.loss)

                self.merged = tf.summary.merge_all()
                self.saver = tf.train.Saver(max_to_keep=100)

    def update(self):
        """Do nothing; placeholder kept so existing callers keep working."""

    def update_target_hard(self):
        """Copy the online Q-network weights into the target network at once.

        The target network reads the EMA shadow variables created in
        ``init_networks``, so those are what get overwritten. The assign ops
        are built once in ``init_networks`` and only run here, which keeps the
        graph from growing on every call.
        """
        self.sess.run(self.hard_update_op)

    def update_target_soft(self, tau=0.05):
        """Do nothing; the soft update is wired into the graph via the EMA.

        Args:
            tau: accepted for interface compatibility and ignored.
        """

    def choose_action(self, s, test_model):
        """Pick an action epsilon-greedily.

        Args:
            s: a single state vector (it is wrapped into a batch of one).
            test_model: when true, exploration is disabled and the greedy
                action is always returned.

        Returns:
            A tuple ``(action, None)``; the second element is always ``None``.
        """
        if not test_model and random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1), None

        qsa_vec = self.sess.run(self.netQ, feed_dict={self.state_phd: [s]})
        max_act = np.argmax(qsa_vec)
        return max_act, None

    def learn(self, state_batch, action_batch, reward_batch,
              state_prime_batch, done_batch, **kwargs):
        """Run one optimisation step on a batch and log the loss summary.

        Does nothing when the agent was created with ``is_test=True``.

        Args:
            state_batch: states, fed to a ``[None, state_dim]`` placeholder.
            action_batch: integer actions, fed to a ``[None]`` placeholder.
            reward_batch: rewards, fed to a ``[None, 1]`` placeholder.
            state_prime_batch: next states, same layout as ``state_batch``.
            done_batch: terminal flags, fed to a ``[None, 1]`` placeholder.
            **kwargs: accepted and ignored.
        """
        if self.is_test:
            return

        feed = {self.state_phd: state_batch,
                self.action_phd: action_batch,
                self.state_prime_phd: state_prime_batch,
                self.reward_phd: reward_batch,
                self.done_phd: done_batch}

        with self.sess.graph.as_default():
            _, summary = self.sess.run([self.optimizer, self.merged],
                                       feed_dict=feed)

        self.train_writer.add_summary(summary, self.update_cnt)

        self.update_cnt += 1

    def episode_done(self, test_model):
        """Close an episode: decay epsilon and reset the per-episode state.

        Nothing is changed when ``test_model`` is true.
        """
        if not test_model:
            if self.epsilon_decay == "exponential":
                self.epsilon = max(EPSILON_MIN, self.epsilon * self.epsilon_decay_factor)
            elif self.epsilon_decay == "linear":
                self.epsilon = max(EPSILON_MIN, self.epsilon - self.epsilon_inc)

            self.episode_cnt += 1
            self.episode_reward = []

    def experience(self, s, a, r, sp):
        """Record the reward of one transition; ``s``, ``a``, ``sp`` are unused."""
        self.episode_reward.append(r)

    def _build_mlp(self, scope_name, state_phd, reuse, custom_getter):
        """Build the shared Q-value MLP inside the variable scope ``scope_name``.

        Two hidden dense layers (32 and 16 units), each followed by layer
        normalisation and ReLU, then a linear layer with one output per
        action. All layers use TensorFlow's default initializers.
        """
        # Dense variables are flagged trainable only on the first build,
        # i.e. when the scope is not being re-used.
        trainable = reuse is None
        with tf.variable_scope(scope_name, reuse=reuse, custom_getter=custom_getter):
            net = state_phd
            for layer_name, units in (("l1", 32), ("l2", 16)):
                net = tf.layers.dense(net, units,
                                      name=layer_name,
                                      trainable=trainable)
                net = tf.contrib.layers.layer_norm(net)
                net = tf.nn.relu(net)

            # the output is a vector holding the Q-values of all actions
            # in one state
            return tf.layers.dense(net, self.action_dim, activation=None,
                                   name='qs',
                                   trainable=trainable)

    def build_q_net(self, state_phd, reuse=None, custom_getter=None):
        """Build (or re-use) the Q-network in the ``Qnet_`` variable scope.

        Args:
            state_phd: input tensor of states.
            reuse: passed to ``tf.variable_scope``; ``None`` creates the
                variables, anything else re-uses them as non-trainable.
            custom_getter: optional variable getter passed to the scope.

        Returns:
            Tensor holding one Q-value per action for every input state.
        """
        return self._build_mlp('Qnet_', state_phd, reuse, custom_getter)

    def build_q_net_target(self, state_phd, reuse=None, custom_getter=None):
        """Build the same architecture in the ``Qnet_Target_`` variable scope.

        Not called by the graph construction in this class, which uses the
        EMA shadow copy of ``Qnet_`` as its target network instead.
        Arguments and return value mirror ``build_q_net``.
        """
        return self._build_mlp('Qnet_Target_', state_phd, reuse, custom_getter)

    def write_summary_scalar(self, iteration, tag, value):
        """Write one scalar ``value`` under ``tag`` at step ``iteration``."""
        self.train_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]), iteration)

