import random
import numpy as np
import tensorflow as tf
from ...utils.ou_noise import OUNoise

RENDER = False  # NOTE(review): not referenced in this module; presumably read by callers.

# Default hyper-parameters of DDPG. Each can be overridden through the
# corresponding keyword argument of DDPGAlgo.
LR_ACTOR = 1e-4  # Adam learning rate of the actor
LR_CRITIC = 2e-4  # Adam learning rate of the critic
GAMMA = 0.999  # reward discount factor
TAU = 0.01  # soft target-update rate (EMA decay is 1 - TAU)
OU_NOISE_THETA = 0.15  # mean-reversion rate of the Ornstein-Uhlenbeck noise
OU_NOISE_SIGMA = 0.5  # volatility of the Ornstein-Uhlenbeck noise
GAUSSIAN_EXPLORATION_SIGMA_RATIO_MAX = 1.0  # starting ratio of the decaying Gaussian schedule
GAUSSIAN_EXPLORATION_SIGMA_RATIO_MIN = 1e-5  # floor of the decaying Gaussian schedule
GAUSSIAN_EXPLORATION_SIGMA_RATIO_FIX = 0.2  # ratio used by the static Gaussian schedule
GAUSSIAN_EXPLORATION_SIGMA_RATIO_DECAY_EPISODE = 60000  # episodes to decay from MAX to MIN
ACTOR_GRADIENT_NORM_CLIP = 1.0  # global-norm clip threshold for actor gradients
CRITIC_GRADIENT_NORM_CLIP = 1.0  # global-norm clip threshold for critic gradients

class DDPGAlgo(object):
    def __init__(self, sess, graph, state_dim, action_dim, algo_name="ddpg", **kwargs):
        # print(a_dim, s_dim)
        self.pointer = 0
        self.sess = sess
        self.graph = graph
        self.algo_name = algo_name
        self.action_dim, self.state_dim = action_dim, state_dim

        """
            initialize algorithm parameters
        """
        self.set_algo_parameters(**kwargs)

        self.init_networks()

        self.train_writer = tf.summary.FileWriter("./data/" + self.algo_name + "/summary/", self.sess.graph)

        with self.sess.as_default():
            with self.graph.as_default():
                tf.global_variables_initializer().run()

    def set_algo_parameters(self, **kwargs):
        self.gamma = kwargs.get("gamma", GAMMA)
        self.tau = kwargs.get("tau", TAU)
        self.lr_actor = kwargs.get("lr_actor", LR_ACTOR)
        self.lr_critic = kwargs.get("lr_critic", LR_CRITIC)
        self.explo_method = kwargs.get("explo_method", "OU")
        if self.explo_method == "OU":
            ou_noise_theta = kwargs.get("ou_noise_theta", OU_NOISE_THETA)
            ou_noise_sigma = kwargs.get("ou_noise_sigma", OU_NOISE_SIGMA)
            self.ou_noise = OUNoise(self.action_dim, mu=0, theta=ou_noise_theta, sigma=ou_noise_sigma)
        elif self.explo_method == "GAUSSIAN_STATIC":
            self.gaussian_explo_ratio = kwargs.get("gaussian_explo_sigma_ratio_fix",
                                                   GAUSSIAN_EXPLORATION_SIGMA_RATIO_FIX)
        else:
            self.gaussian_explo_sigma_ratio_max = kwargs.get("gaussian_explo_sigma_ratio_max",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_MAX)
            self.gaussian_explo_sigma_ratio_min = kwargs.get("gaussian_explo_sigma_ratio_min",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_MIN)
            gaussian_explo_sigma_ratio_decay_ep = kwargs.get("gaussian_explo_sigma_ratio_decay_ep",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_DECAY_EPISODE)
            self.gaussian_explo_ratio = self.gaussian_explo_sigma_ratio_max
            self.gaussian_explo_decay_factor = pow(self.gaussian_explo_sigma_ratio_min /
                                                   self.gaussian_explo_sigma_ratio_max,
                                                   1.0 / gaussian_explo_sigma_ratio_decay_ep)

        self.actor_grad_clip = kwargs.get("actor_gradient_clip", True)
        self.critic_grad_clip = kwargs.get("critic_gradient_clip", True)
        self.actor_grad_norm_clip = kwargs.get("actor_gradient_norm_clip", ACTOR_GRADIENT_NORM_CLIP)
        self.critic_grad_norm_clip = kwargs.get("critic_gradient_norm_clip", CRITIC_GRADIENT_NORM_CLIP)

        """
            network layer cell numbers
        """
        self.actor_net_layers = kwargs.get("actor_net_layers", [4, 4])
        self.critic_net_layers = kwargs.get("critic_net_layers", [32, 32])
        self.critic_act_in_ly_index = int(kwargs.get("critic_action_input_layer_index", 1))

    def init_networks(self):
        with self.sess.as_default():
            with self.graph.as_default():
                self.state_phd = tf.placeholder(tf.float32, [None, self.state_dim], name='state')  # the general states
                self.state_prime_phd = tf.placeholder(tf.float32, [None, self.state_dim], name='state_prime')
                self.reward_phd = tf.placeholder(tf.float32, [None, 1], name='reward')
                self.done_phd = tf.placeholder(tf.float32, [None, 1], name='done')
                self.temp = tf.Variable(1.0, name='temperature')

                # build decentralized  evaluate actor network
                self.actor_output = self._build_actor(self.state_phd, )
                # build centralized evaluate critic network
                q = self._build_critic(self.state_phd, self.actor_output)

                # get the params and apply to target net
                a_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Actor')
                c_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Critic')
                ema = tf.train.ExponentialMovingAverage(decay=1 - self.tau)  # soft replacement

                def ema_getter(getter, name, *args, **kwargs):
                    return ema.average(getter(name, *args, **kwargs))

                target_update = [ema.apply(a_params), ema.apply(c_params)]  # soft update operation
                target_actor_output = self._build_actor(self.state_prime_phd, reuse=True, custom_getter=ema_getter)
                q_ = self._build_critic(self.state_prime_phd, target_actor_output, reuse=True, custom_getter=ema_getter)

                # decentralized actor train
                pg_loss = tf.reduce_mean(q)  # maximize the q
                a_loss = pg_loss  # + p_reg * 1e-3
                if self.actor_grad_clip:
                    self.actor_param_gradients = tf.gradients(-a_loss, a_params)
                    self.actor_clipped_grad_op, _ = tf.clip_by_global_norm(self.actor_param_gradients,
                                                                           self.actor_grad_norm_clip)
                    self.actor_opt = tf.train.AdamOptimizer(self.lr_actor)
                    self.atrain = self.actor_opt.apply_gradients(zip(self.actor_clipped_grad_op, a_params))
                else:
                    self.atrain = tf.train.AdamOptimizer(self.lr_actor).minimize(-a_loss, var_list=a_params)

                with tf.name_scope('Actor_exp_Q'):
                    tf.summary.scalar('a_exp_Q', a_loss)

                # centralized critic train
                with tf.control_dependencies(target_update):  # soft replacement happened at here
                    q_target = self.reward_phd + self.gamma * (1 - self.done_phd) * q_
                    td_error = tf.losses.mean_squared_error(labels=tf.stop_gradient(q_target), predictions=q)
                    clipped_td_error = tf.minimum(td_error, 100.0)
                    if self.critic_grad_clip:
                        # self.critic_param_gradients = tf.gradients(td_error, c_params)
                        self.critic_param_gradients = tf.gradients(clipped_td_error, c_params)
                        self.critic_clipped_grad_op, _ = tf.clip_by_global_norm(self.critic_param_gradients,
                                                                                self.critic_grad_norm_clip)
                        self.critic_opt = tf.train.AdamOptimizer(self.lr_critic)
                        self.ctrain = self.critic_opt.apply_gradients(zip(self.critic_clipped_grad_op, c_params))
                    else:
                        # self.ctrain = tf.train.AdamOptimizer(LR_CRITIC).minimize(td_error, var_list=c_params)
                        self.ctrain = tf.train.AdamOptimizer(self.lr_critic).minimize(clipped_td_error, var_list=c_params)

                with tf.name_scope('Centralized_Critic_loss'):
                    tf.summary.scalar('td_error', td_error)
                self.saver = tf.train.Saver(max_to_keep=100)

    def choose_action(self, s, is_test):

        action = self.sess.run(self.actor_output, {self.state_phd: [s]})
        action = action[0]
        # print("Computed action is {}".format(action))
        if not is_test:
            if self.explo_method == "OU":
                action = action + self.ou_noise.noise()
            else:
                action = np.random.normal(action, action * self.gaussian_explo_ratio)

        return action, None

    def learn(self, bs, ba, br, bs_, bdone, **kwargs):
        self.sess.run(self.atrain, {self.state_phd: bs}) # why not use the exploration action??

        self.sess.run(self.ctrain, {self.state_phd: bs, self.actor_output: ba,
                                    self.reward_phd: br, self.state_prime_phd: bs_,
                                    self.done_phd: bdone})

    def _build_actor(self, s, reuse=None, custom_getter=None):
        trainable = True if reuse is None else False
        with tf.variable_scope('Actor', reuse=reuse, custom_getter=custom_getter):
            net = s
            for ly_index in range(len(self.actor_net_layers)):
                ly_cell_num = self.actor_net_layers[ly_index]
                net = tf.layers.dense(net, ly_cell_num, name='l'+str(ly_index), trainable=trainable)
                net = tf.contrib.layers.layer_norm(net)
                net = tf.nn.relu(net)

            action = tf.layers.dense(net, self.action_dim, activation=None, name='action', trainable=trainable)
            # action = tf.layers.dense(net, self.action_dim, activation=tf.nn.tanh, name='action', trainable=trainable)

            return action

    def _build_critic(self, s, action, reuse=None, custom_getter=None):
        trainable = True if reuse is None else False
        with tf.variable_scope('Critic', reuse=reuse, custom_getter=custom_getter):
            net = s
            for ly_index in range(len(self.critic_net_layers)):
                ly_cell_num = self.critic_net_layers[ly_index]
                if ly_index == self.critic_act_in_ly_index:
                    net = tf.concat([net, action], axis=1)

                net = tf.layers.dense(net, ly_cell_num, name='l'+str(ly_index), trainable=trainable)
                net = tf.contrib.layers.layer_norm(net)
                net = tf.nn.relu(net)

            return tf.layers.dense(net, 1, trainable=trainable)

    def experience(self, one_exp):
        print("The is experience method of ddpg, {}".format(one_exp))

    def episode_done(self, is_test):
        if not is_test:
            if self.explo_method == "GAUSSIAN_DYNAMIC":
                self.gaussian_explo_ratio = max(self.gaussian_explo_sigma_ratio_min,
                                                self.gaussian_explo_ratio * self.gaussian_explo_decay_factor)

    def write_summary_scalar(self, iteration, tag, value):
        self.train_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]), iteration)