import numpy as np
import tensorflow as tf
from experiments.algorithms.ddpg.ddpg_algo import DDPGAlgo
from experiments.utils.ou_noise import OUNoise

RENDER = False

# ---------------------------------------------------------------------------
# Default hyper-parameters of ddpg-dpba.
# Each constant is only a fallback: DDPGDpbaAlgo.set_algo_parameters looks up
# the matching keyword argument first and uses these values when it is absent.
# ---------------------------------------------------------------------------

# Adam learning rates for the actor, the critic and the potential network Phi.
LR_ACTOR = 1e-4
LR_CRITIC = 2e-4
LR_PHI = 1e-4

# Discount factor, and soft target-update rate (target EMA decay is 1 - TAU).
GAMMA = 0.999
TAU = 0.01

# Ornstein-Uhlenbeck exploration noise, used when explo_method == "OU".
OU_NOISE_THETA = 0.15
OU_NOISE_SIGMA = 0.5

# Gaussian exploration: the noise standard deviation is a ratio scaled by the
# action (see choose_action).  For the decaying variant the ratio starts at
# MAX and the decay factor is (MIN / MAX) ** (1 / DECAY_EPISODE), i.e. MAX
# reaches MIN after DECAY_EPISODE multiplications.  FIX is the constant ratio
# used when explo_method == "GAUSSIAN_STATIC".
GAUSSIAN_EXPLORATION_SIGMA_RATIO_MAX = 1.0
GAUSSIAN_EXPLORATION_SIGMA_RATIO_MIN = 1e-5
GAUSSIAN_EXPLORATION_SIGMA_RATIO_FIX = 0.2
GAUSSIAN_EXPLORATION_SIGMA_RATIO_DECAY_EPISODE = 60000

# Global-norm thresholds passed to tf.clip_by_global_norm when the matching
# *_gradient_clip flag is enabled.
ACTOR_GRADIENT_NORM_CLIP = 1.0
CRITIC_GRADIENT_NORM_CLIP = 1.0
PHI_GRADIENT_NORM_CLIP = 50.0


class DDPGDpbaAlgo(DDPGAlgo):
    """DDPG agent whose critic target is augmented with a learned potential-based shaping term.

    Besides the actor and critic (built via the parent's ``_build_actor`` /
    ``_build_critic``), the graph contains a potential network under the
    "PhiNet" scope mapping a state-action pair to a scalar Phi(s, a).  The
    critic's TD target adds the shaping term

        F = gamma * Phi_target(s', a') - Phi(s, a)

    to the reward, and Phi itself is trained by TD regression towards
    ``-shaping_reward + gamma * (1 - done) * Phi_target(s', a')``.
    The "dpba" suffix presumably stands for dynamic potential-based advice.
    """

    def __init__(self, sess, graph, state_dim, action_dim, algo_name="ddpg_dpba", **kwargs):
        """Forward all arguments to ``DDPGAlgo.__init__`` (``algo_name`` defaults to "ddpg_dpba").

        NOTE(review): the parent constructor is not visible here; it is assumed
        to call ``set_algo_parameters(**kwargs)`` and ``init_networks()``.
        """
        super(DDPGDpbaAlgo, self).__init__(sess, graph, state_dim, action_dim, algo_name, **kwargs)

    def set_algo_parameters(self, **kwargs):
        """Read hyper-parameters from ``kwargs``, falling back to the module-level defaults.

        Recognised keys: gamma, tau, lr_actor, lr_critic, lr_phi; explo_method
        ("OU" by default, "GAUSSIAN_STATIC", any other value selects decaying
        Gaussian noise) with its noise parameters; actor/critic/phi_gradient_clip
        flags and the matching *_gradient_norm_clip values; actor_net_layers,
        critic_net_layers, critic_action_input_layer_index, phi_net_layers and
        phi_hidden_layer_act_func.
        """
        self.gamma = kwargs.get("gamma", GAMMA)
        self.tau = kwargs.get("tau", TAU)
        self.lr_actor = kwargs.get("lr_actor", LR_ACTOR)
        self.lr_critic = kwargs.get("lr_critic", LR_CRITIC)
        self.lr_phi = kwargs.get("lr_phi", LR_PHI)

        # Exploration strategy.
        self.explo_method = kwargs.get("explo_method", "OU")
        if self.explo_method == "OU":
            ou_noise_theta = kwargs.get("ou_noise_theta", OU_NOISE_THETA)
            ou_noise_sigma = kwargs.get("ou_noise_sigma", OU_NOISE_SIGMA)
            self.ou_noise = OUNoise(self.action_dim, mu=0, theta=ou_noise_theta, sigma=ou_noise_sigma)
        elif self.explo_method == "GAUSSIAN_STATIC":
            self.gaussian_explo_ratio = kwargs.get("gaussian_explo_sigma_ratio_fix",
                                                   GAUSSIAN_EXPLORATION_SIGMA_RATIO_FIX)
        else:
            # Decaying Gaussian noise: the ratio starts at ``max`` and the factor
            # is chosen so that max * factor ** decay_ep == min.
            self.gaussian_explo_sigma_ratio_max = kwargs.get("gaussian_explo_sigma_ratio_max",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_MAX)
            self.gaussian_explo_sigma_ratio_min = kwargs.get("gaussian_explo_sigma_ratio_min",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_MIN)
            gaussian_explo_sigma_ratio_decay_ep = kwargs.get("gaussian_explo_sigma_ratio_decay_ep",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_DECAY_EPISODE)
            self.gaussian_explo_ratio = self.gaussian_explo_sigma_ratio_max
            self.gaussian_explo_decay_factor = pow(self.gaussian_explo_sigma_ratio_min /
                                                   self.gaussian_explo_sigma_ratio_max,
                                                   1.0 / gaussian_explo_sigma_ratio_decay_ep)

        # Gradient clipping (global-norm) switches and thresholds.
        self.actor_grad_clip = kwargs.get("actor_gradient_clip", True)
        self.critic_grad_clip = kwargs.get("critic_gradient_clip", True)
        self.phi_grad_clip = kwargs.get("phi_gradient_clip", True)
        self.actor_grad_norm_clip = kwargs.get("actor_gradient_norm_clip", ACTOR_GRADIENT_NORM_CLIP)
        self.critic_grad_norm_clip = kwargs.get("critic_gradient_norm_clip", CRITIC_GRADIENT_NORM_CLIP)
        self.phi_grad_norm_clip = kwargs.get("phi_gradient_norm_clip", PHI_GRADIENT_NORM_CLIP)
        # Backward-compatible alias for the previously misspelled attribute name.
        self.phi_grad_norm_clilp = self.phi_grad_norm_clip

        # Network layer cell numbers.
        self.actor_net_layers = kwargs.get("actor_net_layers", [4, 4])
        self.critic_net_layers = kwargs.get("critic_net_layers", [32, 32])
        self.critic_act_in_ly_index = int(kwargs.get("critic_action_input_layer_index", 1))
        self.phi_net_layers = kwargs.get("phi_net_layers", [16, 8])

        self.phi_hidden_layer_act_func = kwargs.get("phi_hidden_layer_act_func", tf.nn.relu)

    def init_networks(self):
        """Build the full TF graph: online/target actor, critic and Phi networks plus their train ops.

        Target networks are not separate variables.  They reuse the online
        networks through a custom getter that returns the exponential moving
        average (decay ``1 - tau``) of each trainable variable.  The EMA
        updates run as control dependencies of the critic and Phi train ops.

        NOTE: op creation order is kept stable on purpose, because it
        determines auto-generated variable names (e.g. Adam slots) and
        therefore checkpoint compatibility.
        """
        with self.sess.as_default():
            with self.graph.as_default():
                self.state_phd = tf.placeholder(tf.float32, [None, self.state_dim], name='state')  # the general states
                self.state_prime_phd = tf.placeholder(tf.float32, [None, self.state_dim], name='state_prime')
                self.reward_phd = tf.placeholder(tf.float32, [None, 1], name='reward')
                self.shaping_reward_phd = tf.placeholder(tf.float32, [None, 1], name="shaping_reward")
                self.done_phd = tf.placeholder(tf.float32, [None, 1], name='done')
                # Not referenced elsewhere in this class, but kept so the
                # variable set (and therefore saved checkpoints) is unchanged.
                self.temp = tf.Variable(1.0, name='temperature')

                # Online actor network.
                self.actor_output = self._build_actor(self.state_phd)
                self.action_phd = tf.placeholder(tf.float32, [None, self.action_dim], name='action')
                self.action_prime_phd = tf.placeholder(tf.float32, [None, self.action_dim], name='action_prime')

                # Online critic network, evaluated at Q(s, actor(s)).  During the
                # critic update, ``actor_output`` is fed with replayed actions.
                q = self._build_critic(self.state_phd, self.actor_output)

                # Online potential network Phi(s, a).
                self.phi_sa = self.build_phi_net(self.state_phd, self.action_phd)

                # Soft target replacement through an exponential moving average.
                a_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Actor')
                c_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Critic')
                phi_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="PhiNet")
                ema = tf.train.ExponentialMovingAverage(decay=1 - self.tau)

                def ema_getter(getter, name, *args, **kwargs):
                    # Return the EMA shadow of the variable instead of the variable itself.
                    return ema.average(getter(name, *args, **kwargs))

                target_update = [ema.apply(a_params), ema.apply(c_params)]
                phi_target_update = [ema.apply(phi_params)]

                # Target actor, critic and potential networks (EMA weights).
                target_actor_output = self._build_actor(self.state_prime_phd, reuse=True, custom_getter=ema_getter)
                q_ = self._build_critic(self.state_prime_phd, target_actor_output, reuse=True,
                                        custom_getter=ema_getter)
                self.phi_spap = self.build_phi_net(self.state_prime_phd, self.action_prime_phd,
                                                   reuse=True, custom_getter=ema_getter)

                # Actor training: maximise mean Q(s, actor(s)) by minimising its negation.
                a_loss = tf.reduce_mean(q)
                if self.actor_grad_clip:
                    self.actor_param_gradients = tf.gradients(-a_loss, a_params)
                    self.actor_clipped_grad_op, _ = tf.clip_by_global_norm(self.actor_param_gradients,
                                                                           self.actor_grad_norm_clip)
                    self.actor_opt = tf.train.AdamOptimizer(self.lr_actor)
                    self.atrain = self.actor_opt.apply_gradients(zip(self.actor_clipped_grad_op, a_params))
                else:
                    self.atrain = tf.train.AdamOptimizer(self.lr_actor).minimize(-a_loss, var_list=a_params)

                with tf.name_scope('Actor_exp_Q'):
                    tf.summary.scalar('a_exp_Q', a_loss)

                # Critic training.  Ops created here depend on ``target_update``,
                # so running ``ctrain`` also refreshes the target actor/critic EMA.
                with tf.control_dependencies(target_update):
                    # Potential-based shaping term F = gamma * Phi_target(s', a') - Phi(s, a).
                    Fsaspap = self.gamma * self.phi_spap - self.phi_sa
                    q_target = self.reward_phd + Fsaspap + self.gamma * (1 - self.done_phd) * q_
                    td_error = tf.losses.mean_squared_error(labels=tf.stop_gradient(q_target), predictions=q)
                    # NOTE(review): tf.minimum passes a zero gradient once the loss
                    # exceeds 100, so the critic is not updated on such batches.
                    # Confirm this hard cap (rather than a Huber loss) is intended.
                    clipped_td_error = tf.minimum(td_error, 100.0)
                    if self.critic_grad_clip:
                        self.critic_param_gradients = tf.gradients(clipped_td_error, c_params)
                        self.critic_clipped_grad_op, _ = tf.clip_by_global_norm(self.critic_param_gradients,
                                                                                self.critic_grad_norm_clip)
                        self.critic_opt = tf.train.AdamOptimizer(self.lr_critic)
                        self.ctrain = self.critic_opt.apply_gradients(zip(self.critic_clipped_grad_op, c_params))
                    else:
                        self.ctrain = tf.train.AdamOptimizer(self.lr_critic).minimize(clipped_td_error,
                                                                                      var_list=c_params)

                # Phi training: TD regression of Phi(s, a) towards
                # -shaping_reward + gamma * (1 - done) * Phi_target(s', a').
                # Running ``phi_optimizer`` also refreshes the target Phi EMA.
                with tf.control_dependencies(phi_target_update):
                    phi_target = -self.shaping_reward_phd + self.gamma * (1 - self.done_phd) * self.phi_spap
                    self.phi_target_op = tf.stop_gradient(phi_target)
                    self.phi_td_error_op = self.phi_target_op - self.phi_sa
                    self.phi_squared_error_op = tf.square(self.phi_td_error_op)
                    self.phi_loss = tf.reduce_mean(self.phi_squared_error_op, name="phi_loss")
                    if self.phi_grad_clip:
                        self.phi_param_gradients = tf.gradients(self.phi_loss, phi_params)
                        self.phi_clipped_grad_op, _ = tf.clip_by_global_norm(self.phi_param_gradients,
                                                                             self.phi_grad_norm_clip)
                        self.phi_opt = tf.train.AdamOptimizer(self.lr_phi)
                        self.phi_optimizer = self.phi_opt.apply_gradients(zip(self.phi_clipped_grad_op, phi_params))
                    else:
                        # Restrict to Phi's own variables, matching the clipped branch
                        # and the actor/critic optimizers.
                        self.phi_optimizer = tf.train.AdamOptimizer(self.lr_phi).minimize(
                            self.phi_loss, var_list=phi_params, name="phi_adam_optimizer")

                with tf.name_scope('Centralized_Critic_loss'):
                    tf.summary.scalar('td_error', td_error)
                    tf.summary.scalar("phi_loss", self.phi_loss)

                self.saver = tf.train.Saver(max_to_keep=100)

    def choose_action(self, s, is_test):
        """Return the actor's action for a single state ``s``, optionally with exploration noise.

        Args:
            s: one state; it is wrapped in a list to form a batch of size 1.
            is_test: when False, noise is added.  For ``explo_method == "OU"``
                the noise is OU noise; otherwise it is Gaussian noise with
                per-dimension standard deviation ``|action| * gaussian_explo_ratio``.

        Returns:
            ``(action, None)``; the second element is always ``None`` here.
        """
        action = self.sess.run(self.actor_output, {self.state_phd: [s]})[0]
        if not is_test:
            if self.explo_method == "OU":
                action = action + self.ou_noise.noise()
            else:
                # np.random.normal raises ValueError for a negative scale, so the
                # std must use the action's magnitude, not its signed value.
                action = np.random.normal(action, np.abs(action) * self.gaussian_explo_ratio)

        return action, None

    def learn(self, bs, ba, br, bs_, bdone, **kwargs):
        """Run one actor update, then one joint critic + Phi update, on a replay batch.

        Args:
            bs, ba, br, bs_, bdone: batches of states, actions, rewards, next
                states and done flags, fed to the matching placeholders
                (``br`` and ``bdone`` must fit the ``[None, 1]`` placeholders).
            c_batch (keyword, required): shaping rewards of the current
                transitions; Phi(s, a) is regressed towards ``-c_batch + ...``.

        Raises:
            ValueError: if ``c_batch`` is not supplied.
        """
        c_batch = kwargs.get("c_batch")
        if c_batch is None:
            raise ValueError("ddpg-dpba learn() requires the 'c_batch' keyword argument (shaping rewards)")

        with self.sess.graph.as_default():
            # Greedy next-state actions from the online (not target) actor; they
            # are fed as a' to the target Phi network.
            action_sp_batch = self.sess.run(self.actor_output, feed_dict={self.state_phd: bs_})

            # Actor step on Q(s, actor(s)).
            self.sess.run(self.atrain, {self.state_phd: bs})

            # Critic + Phi step.  Feeding ``actor_output`` with the replayed
            # actions makes the critic evaluate Q(s, a) for those actions.
            self.sess.run([self.ctrain, self.phi_optimizer],
                          feed_dict={self.state_phd: bs, self.actor_output: ba,
                                     self.reward_phd: br, self.state_prime_phd: bs_,
                                     self.action_phd: ba, self.action_prime_phd: action_sp_batch,
                                     self.done_phd: bdone, self.shaping_reward_phd: c_batch
                                     })

    def build_phi_net(self, state_phd, action_phd, reuse=None, custom_getter=None):
        """Build the potential network Phi(s, a) under the "PhiNet" variable scope.

        State and action are concatenated along axis 1.  The result passes
        through one dense -> layer-norm -> activation block per entry of
        ``self.phi_net_layers`` and then a linear layer with a single output.

        Args:
            state_phd: state tensor, concatenated with ``action_phd`` on axis 1.
            action_phd: action tensor.
            reuse: ``None`` creates new, trainable dense variables (online
                network).  Any other value reuses the existing variables and
                builds the dense layers as non-trainable (target network).
            custom_getter: optional variable getter, e.g. the EMA getter used
                for the target network.

        Returns:
            The Phi(s, a) tensor with a last dimension of 1.
        """
        trainable = reuse is None
        with tf.variable_scope('PhiNet', reuse=reuse, custom_getter=custom_getter):
            net = tf.concat([state_phd, action_phd], axis=1)
            for ly_index, ly_cell_num in enumerate(self.phi_net_layers):
                net = tf.layers.dense(net, ly_cell_num,
                                      kernel_initializer=tf.random_uniform_initializer(-1 / 8.0, 1 / 8.0),
                                      bias_initializer=tf.random_uniform_initializer(-1 / 8.0, 1 / 8.0),
                                      name="l" + str(ly_index),
                                      trainable=trainable)
                net = tf.contrib.layers.layer_norm(net)
                net = self.phi_hidden_layer_act_func(net)

            # Output layer: near-zero initialisation so Phi starts close to 0.
            phisa = tf.layers.dense(net, 1, activation=None,
                                    kernel_initializer=tf.random_uniform_initializer(-1 / 1000.0, 1 / 1000.0),
                                    bias_initializer=tf.random_uniform_initializer(-1 / 1000.0, 1 / 1000.0),
                                    name='phi_sa',
                                    trainable=trainable
                                    )
            return phisa

    def experience(self, one_exp):
        """Print the given transition; nothing else is done with it here."""
        print("This is the experience method of ddpg-dpba, {}".format(one_exp))

    def write_summary_scalar(self, iteration, tag, value):
        """Write one scalar summary ``value`` under ``tag`` at step ``iteration``.

        NOTE(review): ``self.train_writer`` is not created in this class; it is
        presumably a summary FileWriter set up by the parent.
        """
        self.train_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]), iteration)