# coding=utf-8
import random
import datetime
import os
import numpy as np
import tensorflow as tf
from ...utils.ou_noise import OUNoise

RENDER = False  # rendering switch; not read by the code shown in this module

# ---------------------------------------------------------------------------
# Default DDPG / MADDPG hyper-parameters. Each one is only a fallback: the
# corresponding keyword argument of MADDPGAlgo.set_algo_parameters wins.
# ---------------------------------------------------------------------------

# Adam step sizes for the actor and the critic.
LR_ACTOR = 0.0001
LR_CRITIC = 0.0002

# Discount factor and soft target-update rate (EMA decay is 1 - TAU).
GAMMA = 0.999
TAU = 0.01

# Ornstein-Uhlenbeck exploration noise, used when explo_method == "OU".
OU_NOISE_THETA = 0.15
OU_NOISE_SIGMA = 0.5

# Gaussian exploration: MAX/MIN bound the decaying sigma ratio and, together
# with DECAY_EPISODE, define its per-step multiplicative decay factor; FIX is
# the constant ratio used when explo_method == "GAUSSIAN_STATIC".
GAUSSIAN_EXPLORATION_SIGMA_RATIO_MAX = 1.0
GAUSSIAN_EXPLORATION_SIGMA_RATIO_MIN = 0.00001
GAUSSIAN_EXPLORATION_SIGMA_RATIO_FIX = 0.2
GAUSSIAN_EXPLORATION_SIGMA_RATIO_DECAY_EPISODE = 60000

# Global-norm thresholds for actor / critic gradient clipping.
ACTOR_GRADIENT_NORM_CLIP = 1.0
CRITIC_GRADIENT_NORM_CLIP = 1.0

###############################  MYMADDPG  ####################################

class MADDPGAlgo(object):
    """MADDPG learner built in TensorFlow 1.x graph mode.

    A shared actor trunk feeds ``agent_num`` output heads (one per agent), and a
    single centralized critic scores Q(s, joint_action). Target networks are not
    separate variables: they are obtained by rebuilding the ``Actor`` and
    ``Critic`` scopes with a custom getter that substitutes each trainable
    variable's ExponentialMovingAverage shadow (soft update with rate ``tau``).
    """

    def __init__(self, sess, graph, state_dim, action_dim, agent_num=2, algo_name="maddpg", **kwargs):
        """Build the graph, open the summary writer and initialize all variables.

        Args:
            sess: tf.Session used for initialization, inference and training.
            graph: tf.Graph in which every op of this learner is created.
            state_dim: width of the state vector fed to actor and critic.
            action_dim: action width of each agent head; either one int shared
                by all agents or a sequence with one entry per agent.
            agent_num: number of agents, i.e. actor output heads.
            algo_name: sub-directory name for summaries under ``./data/``.
            **kwargs: hyper-parameter overrides, see ``set_algo_parameters``.

        Raises:
            ValueError: if a per-agent ``action_dim`` sequence does not have
                exactly ``agent_num`` entries.
        """
        self.graph = graph
        self.sess = sess
        self.agent_num = agent_num
        self.action_dim, self.state_dim = action_dim, state_dim
        # _build_a needs one head width per agent (it reads self.action_dims,
        # which was never set before).
        if isinstance(action_dim, (list, tuple)):
            self.action_dims = list(action_dim)
            if len(self.action_dims) != agent_num:
                raise ValueError("action_dim has %d entries but agent_num is %d"
                                 % (len(self.action_dims), agent_num))
        else:
            self.action_dims = [action_dim] * agent_num
        self.algo_name = algo_name

        # Hyper-parameters must exist before the graph is built.
        self.set_algo_parameters(**kwargs)

        self.init_networks()

        self.train_writer = tf.summary.FileWriter("./data/" + self.algo_name + "/summary/", self.sess.graph)

        with self.sess.as_default():
            with self.graph.as_default():
                tf.global_variables_initializer().run()

    def set_algo_parameters(self, **kwargs):
        """Read hyper-parameters from ``kwargs``, falling back to module defaults.

        Recognised keys: gamma, tau, lr_actor, lr_critic, explo_method ("OU",
        "GAUSSIAN_STATIC", any other value selects the decaying Gaussian
        schedule), ou_noise_theta, ou_noise_sigma,
        gaussian_explo_sigma_ratio_fix, gaussian_explo_sigma_ratio_max,
        gaussian_explo_sigma_ratio_min, gaussian_explo_sigma_ratio_decay_ep,
        actor_gradient_clip, critic_gradient_clip, actor_gradient_norm_clip,
        critic_gradient_norm_clip, actor_net_layers, critic_net_layers,
        critic_action_input_layer_index, continuous_action.
        """
        self.gamma = kwargs.get("gamma", GAMMA)
        self.tau = kwargs.get("tau", TAU)
        self.lr_actor = kwargs.get("lr_actor", LR_ACTOR)
        self.lr_critic = kwargs.get("lr_critic", LR_CRITIC)
        self.explo_method = kwargs.get("explo_method", "OU")
        if self.explo_method == "OU":
            ou_noise_theta = kwargs.get("ou_noise_theta", OU_NOISE_THETA)
            ou_noise_sigma = kwargs.get("ou_noise_sigma", OU_NOISE_SIGMA)
            self.ou_noise = OUNoise(self.action_dim, mu=0, theta=ou_noise_theta, sigma=ou_noise_sigma)
        elif self.explo_method == "GAUSSIAN_STATIC":
            self.gaussian_explo_ratio = kwargs.get("gaussian_explo_sigma_ratio_fix",
                                                   GAUSSIAN_EXPLORATION_SIGMA_RATIO_FIX)
        else:
            self.gaussian_explo_sigma_ratio_max = kwargs.get("gaussian_explo_sigma_ratio_max",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_MAX)
            self.gaussian_explo_sigma_ratio_min = kwargs.get("gaussian_explo_sigma_ratio_min",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_MIN)
            gaussian_explo_sigma_ratio_decay_ep = kwargs.get("gaussian_explo_sigma_ratio_decay_ep",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_DECAY_EPISODE)
            self.gaussian_explo_ratio = self.gaussian_explo_sigma_ratio_max
            # Multiplying by this factor decay_ep times takes the ratio from max to min.
            self.gaussian_explo_decay_factor = pow(self.gaussian_explo_sigma_ratio_min /
                                                   self.gaussian_explo_sigma_ratio_max,
                                                   1.0 / gaussian_explo_sigma_ratio_decay_ep)

        # Consumed by _make_train_op when the actor/critic train ops are built.
        self.actor_grad_clip = kwargs.get("actor_gradient_clip", True)
        self.critic_grad_clip = kwargs.get("critic_gradient_clip", True)
        self.actor_grad_norm_clip = kwargs.get("actor_gradient_norm_clip", ACTOR_GRADIENT_NORM_CLIP)
        self.critic_grad_norm_clip = kwargs.get("critic_gradient_norm_clip", CRITIC_GRADIENT_NORM_CLIP)

        # Network layer cell numbers.
        # NOTE(review): these three settings are stored but not read by
        # _build_a/_build_c, which use fixed 64-unit layers.
        self.actor_net_layers = kwargs.get("actor_net_layers", [4, 4])
        self.critic_net_layers = kwargs.get("critic_net_layers", [32, 32])
        self.critic_act_in_ly_index = int(kwargs.get("critic_action_input_layer_index", 1))

        # Chooses dtype/shape of action_phd in define_place_holders; it was
        # previously read there as an undefined global name.
        self.continuous_action = kwargs.get("continuous_action", True)

    def init_networks(self):
        """Create placeholders, networks, train ops and the checkpoint saver."""
        self.define_place_holders()
        # Actor, centralized critic, and their EMA-backed target copies.
        self.build_networks()
        # Losses, optimizers and summaries.
        self.define_trainers()

        with self.sess.as_default():
            with self.graph.as_default():
                self.saver = tf.train.Saver(max_to_keep=100)

    def define_trainers(self):
        """Define the actor/critic losses, train ops and summaries.

        Must run after ``build_networks``; it reads ``self.q``, ``self.q_``,
        ``self.a_params``, ``self.c_params`` and ``self.target_update``.
        """
        with self.sess.as_default():
            with self.graph.as_default():
                # Decentralized actor train: maximize Q(s, pi(s)) by minimizing -Q.
                pg_loss = tf.reduce_mean(self.q)
                a_loss = pg_loss
                self.atrain = self._make_train_op(-a_loss, self.a_params, self.lr_actor,
                                                  self.actor_grad_clip, self.actor_grad_norm_clip)

                with tf.name_scope('Actor_exp_Q'):
                    tf.summary.scalar('a_exp_Q', a_loss)

                # Centralized critic train. Every critic step first runs the
                # EMA (soft target) update through the control dependency.
                with tf.control_dependencies(self.target_update):
                    q_target = self.R + self.gamma * (1 - self.Done) * self.q_
                    td_error = tf.losses.mean_squared_error(labels=q_target, predictions=self.q)
                    self.ctrain = self._make_train_op(td_error, self.c_params, self.lr_critic,
                                                      self.critic_grad_clip, self.critic_grad_norm_clip)

                with tf.name_scope('Centralized_Critic_loss'):
                    tf.summary.scalar('td_error', td_error)

                self.merged = tf.summary.merge_all()

    @staticmethod
    def _make_train_op(loss, var_list, learning_rate, clip, clip_norm):
        """Return an Adam step on ``loss`` over ``var_list``.

        When ``clip`` is truthy, gradients are rescaled by their global norm so
        it does not exceed ``clip_norm``; otherwise this equals ``minimize``.
        """
        optimizer = tf.train.AdamOptimizer(learning_rate)
        grads_and_vars = optimizer.compute_gradients(loss, var_list=var_list)
        if clip:
            grads, variables = zip(*grads_and_vars)
            # clip_by_global_norm ignores None entries (vars without a gradient).
            grads, _ = tf.clip_by_global_norm(grads, clip_norm)
            grads_and_vars = list(zip(grads, variables))
        return optimizer.apply_gradients(grads_and_vars)

    def define_place_holders(self):
        """Create the input placeholders and the temperature input."""
        with self.sess.as_default():
            with self.graph.as_default():
                self.state_phd = tf.placeholder(tf.float32, [None, self.state_dim], name='state')  # the general states
                self.state_prime_phd = tf.placeholder(tf.float32, [None, self.state_dim], name='state_prime')
                if self.continuous_action:
                    # Width of the concatenated joint action produced by _build_a.
                    self.action_phd = tf.placeholder(tf.float32, [None, sum(self.action_dims)], name='action')
                else:
                    self.action_phd = tf.placeholder(tf.int32, [None, ], name="action")

                # True reward.
                self.reward_phd = tf.placeholder(tf.float32, [None, 1], name='reward')

                # The additional reward, namely F(s,a).
                self.add_reward_phd = tf.placeholder(tf.float32, [None, 1], name='additional_reward')

                self.done_phd = tf.placeholder(tf.float32, [None, 1], name='done')

                # Gumbel-softmax temperature: defaults to 1.0 and may be fed per
                # run. A placeholder (instead of a fed tf.Variable) keeps it out
                # of the trainable set and is feedable for any variable type.
                self.temp = tf.placeholder_with_default(1.0, shape=[], name='temperature')

                # Short aliases used by the network/training code.
                self.S, self.S_ = self.state_phd, self.state_prime_phd
                self.R, self.Done = self.reward_phd, self.done_phd

    def build_networks(self):
        """Build evaluation actor/critic and their EMA-based target copies."""
        with self.sess.as_default():
            with self.graph.as_default():
                # Decentralized evaluation actor and centralized evaluation critic.
                self.a_all = self._build_a(self.S)
                self.q = self._build_c(self.S, self.a_all)

                # Parameters whose moving averages serve as target-network weights.
                self.a_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Actor')
                self.c_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Critic')
                ema = tf.train.ExponentialMovingAverage(decay=1 - self.tau)  # soft replacement

                def ema_getter(getter, name, *args, **kwargs):
                    # Substitute each variable with its EMA shadow.
                    return ema.average(getter(name, *args, **kwargs))

                self.target_update = [ema.apply(self.a_params), ema.apply(self.c_params)]
                a_all_ = self._build_a(self.S_, reuse=True, custom_getter=ema_getter)
                self.q_ = self._build_c(self.S_, a_all_, reuse=True, custom_getter=ema_getter)

    def choose_action(self, s, temp):
        """Evaluate the actor on a single state.

        Args:
            s: one state; indexed as ``s[np.newaxis, :]`` so presumably a 1-D
               array of length ``state_dim``.
            temp: temperature fed to the Gumbel-softmax heads.

        Returns:
            The evaluated ``a_all`` for that state, with the leading batch
            dimension of 1 kept.
        """
        return self.sess.run(self.a_all, {self.S: s[np.newaxis, :], self.temp: temp})

    def learn(self, bs, ba, br, bs_, bdone):
        """Run one actor step, then one critic step (which also soft-updates targets).

        Args:
            bs: batch of states.
            ba: batch of joint actions; fed in place of the actor output
                ``a_all`` so the critic is fitted on the given actions.
            br: batch of rewards, shape [batch, 1].
            bs_: batch of next states.
            bdone: batch of done flags, shape [batch, 1].
        """
        self.sess.run(self.atrain, {self.S: bs})
        self.sess.run(self.ctrain, {self.S: bs, self.a_all: ba, self.R: br, self.S_: bs_, self.Done: bdone})

    def _build_a(self, s, reuse=None, custom_getter=None):
        """Build the actor under the ``Actor`` variable scope.

        Two 64-unit dense + layer-norm + ReLU layers are shared by all agents;
        each agent then gets a dense head whose logits go through
        ``gumbel_softmax(..., hard=True)``. Heads are concatenated on axis 1.

        Args:
            s: state tensor.
            reuse: None creates fresh trainable variables; True reuses the
                existing ones (target network).
            custom_getter: optional variable getter, e.g. the EMA getter.

        Returns:
            Tensor of the concatenated per-agent action outputs.
        """
        trainable = reuse is None
        with tf.variable_scope('Actor', reuse=reuse, custom_getter=custom_getter):
            net = tf.layers.dense(s, 64, name='l1', trainable=trainable)
            net = tf.contrib.layers.layer_norm(net)
            net = tf.nn.relu(net)

            net = tf.layers.dense(net, 64, name='l2', trainable=trainable)
            net = tf.contrib.layers.layer_norm(net)
            net = tf.nn.relu(net)

            # NOTE(review): gumbel_softmax is neither defined nor imported in
            # the visible part of this module; it must be provided elsewhere.
            jnt_action = []
            for i, dim in enumerate(self.action_dims):
                a_i = tf.layers.dense(net, dim, activation=None, name='a' + str(i), trainable=trainable)
                a_i = gumbel_softmax(a_i, self.temp, hard=True)
                jnt_action.append(a_i)

            return tf.concat(jnt_action, axis=1)

    def _build_c(self, S, a_all, reuse=None, custom_getter=None):
        """Build the centralized critic Q(s, joint_action) under ``Critic``.

        The state and joint action are concatenated and passed through two
        64-unit dense + layer-norm + ReLU layers and a 1-unit linear output.

        Args:
            S: state tensor.
            a_all: joint action tensor (all agents' actions concatenated).
            reuse: None creates fresh trainable variables; True reuses them.
            custom_getter: optional variable getter, e.g. the EMA getter.

        Returns:
            Q-value tensor with one output unit.
        """
        trainable = reuse is None
        with tf.variable_scope('Critic', reuse=reuse, custom_getter=custom_getter):
            c_input = tf.concat([S, a_all], 1)
            net = tf.layers.dense(c_input, 64, name='l1', trainable=trainable)
            net = tf.contrib.layers.layer_norm(net)
            net = tf.nn.relu(net)

            net = tf.layers.dense(net, 64, name='l2', trainable=trainable)
            net = tf.contrib.layers.layer_norm(net)
            net = tf.nn.relu(net)

            return tf.layers.dense(net, 1, trainable=trainable)