import numpy as np
import tensorflow as tf
import random
from experiments.utils.ou_noise import OUNoise
from experiments.algorithms.ddpg.ddpg_algo import DDPGAlgo
from experiments.utils.common import Dataset

# Rendering switch; it is not referenced in the visible part of this module.
RENDER = False

# Default learning rates of the Adam optimizers for the actor, the two critics
# and the shaping weight function f_phi (overridable through kwargs).
LR_ACTOR = 1e-4
LR_CRITIC = 2e-4
LR_F = 1e-5
# Default discount factor used in the TD targets of both critics.
GAMMA = 0.999
# Default tau; the target networks are exponential moving averages with decay 1 - tau.
TAU = 0.01

# Default theta / sigma handed to OUNoise when explo_method is "OU".
OU_NOISE_THETA = 0.15
OU_NOISE_SIGMA = 0.5
# Gaussian exploration: the ratio multiplies the action to give the noise scale.
# MAX is the starting ratio of the decaying variant, MIN and DECAY_EPISODE
# define its decay factor (min / max) ** (1 / decay_episode),
# FIX is the constant ratio used by "GAUSSIAN_STATIC".
GAUSSIAN_EXPLORATION_SIGMA_RATIO_MAX = 1.0
GAUSSIAN_EXPLORATION_SIGMA_RATIO_MIN = 1e-5
GAUSSIAN_EXPLORATION_SIGMA_RATIO_FIX = 0.2
GAUSSIAN_EXPLORATION_SIGMA_RATIO_DECAY_EPISODE = 60000
# Default global-norm thresholds for gradient clipping of actor, critics and f_phi.
ACTOR_GRADIENT_NORM_CLIP = 1.0
CRITIC_GRADIENT_NORM_CLIP = 1.0
F_GRADIENT_NORM_CLIP = 50.0

# Default number of passes over the on-policy data set, and the mini-batch size,
# when the shaping weight function is optimized.
F_OPTIM_EPOCHS = 50
F_OPTIM_BATCH_SIZE = 1024

"""
    DDPG with optimization of parameterized reward shaping (OPRS), version 1,
    in which the shaping weight function f_phi is fed directly into the policy.

    "fop" means that f_phi is updated with an on-policy strategy.
"""


class DDPGOprsV1FopAlgo(DDPGAlgo):
    def __init__(self, sess, graph, state_dim, action_dim, algo_name="ddpg_oprs_v1_fop", **kwargs):
        """
            Start in policy-optimization mode and hand all remaining
            construction work over to DDPGAlgo.
        """
        # the flag is assigned before the parent constructor runs
        # NOTE(review): presumably DDPGAlgo.__init__ triggers set_algo_parameters /
        # init_networks - confirm in the parent class
        self.optimize_policy = True

        parent = super(DDPGOprsV1FopAlgo, self)
        parent.__init__(sess, graph, state_dim, action_dim, algo_name, **kwargs)

    def set_algo_parameters(self, **kwargs):
        """
            Read every hyper-parameter from kwargs and store it on the instance,
            falling back to the module-level defaults when a key is absent.
        """
        option = kwargs.get

        # discounting, target update rate and learning rates
        self.gamma = option("gamma", GAMMA)
        self.tau = option("tau", TAU)
        self.lr_actor = option("lr_actor", LR_ACTOR)
        self.lr_critic = option("lr_critic", LR_CRITIC)
        self.lr_f = option("lr_f", LR_F)

        # exploration: "OU" noise, a fixed Gaussian ratio ("GAUSSIAN_STATIC"),
        # and for any other value a Gaussian ratio with a decay factor
        self.explo_method = option("explo_method", "OU")
        if self.explo_method == "GAUSSIAN_STATIC":
            self.gaussian_explo_ratio = option("gaussian_explo_sigma_ratio_fix",
                                               GAUSSIAN_EXPLORATION_SIGMA_RATIO_FIX)
        elif self.explo_method == "OU":
            theta = option("ou_noise_theta", OU_NOISE_THETA)
            sigma = option("ou_noise_sigma", OU_NOISE_SIGMA)
            self.ou_noise = OUNoise(self.action_dim, mu=0, theta=theta, sigma=sigma)
        else:
            ratio_max = option("gaussian_explo_sigma_ratio_max",
                               GAUSSIAN_EXPLORATION_SIGMA_RATIO_MAX)
            ratio_min = option("gaussian_explo_sigma_ratio_min",
                               GAUSSIAN_EXPLORATION_SIGMA_RATIO_MIN)
            decay_episodes = option("gaussian_explo_sigma_ratio_decay_ep",
                                    GAUSSIAN_EXPLORATION_SIGMA_RATIO_DECAY_EPISODE)
            self.gaussian_explo_sigma_ratio_max = ratio_max
            self.gaussian_explo_sigma_ratio_min = ratio_min
            # start at the maximum ratio; multiplying by the factor once per
            # episode reaches the minimum after decay_episodes episodes
            self.gaussian_explo_ratio = ratio_max
            self.gaussian_explo_decay_factor = pow(ratio_min / ratio_max, 1.0 / decay_episodes)

        # gradient clipping switches and their global-norm thresholds
        self.actor_grad_clip = option("actor_gradient_clip", True)
        self.critic_grad_clip = option("critic_gradient_clip", True)
        self.f_grad_clip = option("f_gradient_clip", True)
        self.actor_grad_norm_clip = option("actor_gradient_norm_clip", ACTOR_GRADIENT_NORM_CLIP)
        self.critic_grad_norm_clip = option("critic_gradient_norm_clip", CRITIC_GRADIENT_NORM_CLIP)
        # NOTE: the attribute name is misspelled ("clilp"); it is kept because
        # define_trainers reads it under exactly this name
        self.f_grad_norm_clilp = option("f_gradient_norm_clip", F_GRADIENT_NORM_CLIP)

        # network layer cell numbers
        self.actor_net_layers = option("actor_net_layers", [4, 4])
        self.critic_net_layers = option("critic_net_layers", [32, 32])
        # index of the critic layer whose input gets the action concatenated
        self.critic_act_in_ly_index = int(option("critic_action_input_layer_index", 1))
        self.f_net_layers = option("f_net_layers", [16, 8])

        # epochs and mini-batch size for optimizing the shaping weight function
        self.f_optim_epochs = option("f_optim_epochs", F_OPTIM_EPOCHS)
        self.f_optim_batch_size = option("f_optim_batch_size", F_OPTIM_BATCH_SIZE)

        # shape of the weight function output: optional "+1" offset,
        # activation functions and optional clipping bounds
        self.net_add_one = option("net_add_one", False)
        self.f_hidden_layer_act_func = option("f_hidden_layer_act_func", tf.nn.relu)
        self.f_output_layer_act_func = option("f_output_layer_act_func", None)
        self.f_phi_min = option("f_phi_min", None)
        self.f_phi_max = option("f_phi_max", None)


    def init_networks(self):
        """
            Build the whole graph in three stages and then create the saver.
        """
        # stage order matters:
        #   1. place holders
        #   2. networks (actor, shaping weight function, shaped critic, true critic)
        #   3. trainers (target networks, losses and optimizers)
        build_stages = (self.define_place_holders, self.build_networks, self.define_trainers)
        for build_stage in build_stages:
            build_stage()

        with self.sess.as_default(), self.graph.as_default():
            self.saver = tf.train.Saver(max_to_keep=100)

    def define_place_holders(self):
        """
            Create the float32 place holders fed during training and acting,
            plus the 'temperature' variable.
        """
        def batch_input(name, width=1):
            # one row per sample, `width` columns
            return tf.placeholder(tf.float32, [None, width], name=name)

        with self.sess.as_default(), self.graph.as_default():
            self.state_phd = batch_input('state', self.state_dim)
            self.state_prime_phd = batch_input('state_prime', self.state_dim)

            # the original reward R(s,a)
            self.reward_phd = batch_input('reward')

            # the additional reward, namely F(s,a)
            self.add_reward_phd = batch_input('additional_reward')

            self.done_phd = batch_input('done')
            self.temp = tf.Variable(1.0, name='temperature')

            # shaping weight value of the next state s'
            self.f_phi_sp_phd = batch_input('f_phi_sp')

    def build_networks(self):
        """
            Build the four online networks on top of the state place holder.
        """
        with self.sess.as_default(), self.graph.as_default():
            state = self.state_phd

            # shaping weight function f_phi(s)
            self.f_phi = self._build_weight_func(state)

            # actor: the input is the state and the shaping weight f_phi(s)
            self.actor_output = self._build_extended_actor(state, self.f_phi)

            # shaped critic, used for optimizing the policy;
            # the input is s, mu(s) and f_phi(s)
            self.shaped_critic = self._build_shaped_critic(state, self.actor_output, self.f_phi)

            # true critic, used for optimizing the weight function f_phi;
            # the input is the state and the action
            self.true_critic = self._build_true_critic(state, self.actor_output)

    def define_trainers(self):
        """
            Build the target networks, the losses and the training ops.

            Creates four training ops on the instance:
            trainer_actor, trainer_shaped_critic, trainer_f and trainer_true_critic.
            The actor and f_phi are trained to maximize the mean output of the
            shaped critic and the true critic respectively; both critics are
            trained on a (capped) mean squared TD error.
        """
        with self.sess.as_default():
            with self.graph.as_default():
                """
                    first build target network of actor, shaped critic, and true critic
                """
                # trainable variables of each network, selected by variable scope
                a_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Actor')
                shaped_c_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Shaped_Critic')
                true_c_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='True_Critic')
                f_phi_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Weight_Func')
                # target parameters are moving averages of the online ones with decay 1 - tau
                ema = tf.train.ExponentialMovingAverage(decay=1 - self.tau)

                # custom getter: replaces each requested variable by its moving average,
                # so networks rebuilt with reuse=True act as target networks
                def ema_getter(getter, name, *args, **kwargs):
                    return ema.average(getter(name, *args, **kwargs))

                # moving-average update ops; no update op is created for the weight function
                target_update_for_policy = [ema.apply(a_params), ema.apply(shaped_c_params)]
                target_update_for_f = [ema.apply(true_c_params)]

                """
                    target actor network, the input is next state s' and f_phi(s')
                """
                target_actor_output = self._build_extended_actor(self.state_prime_phd, self.f_phi_sp_phd,
                                                                 reuse=True, custom_getter=ema_getter)

                """
                    target shaped critic network, the input is s', mu'(s'), and f_phi(s')
                """
                target_shaped_critic = self._build_shaped_critic(self.state_prime_phd, target_actor_output,
                                                                 self.f_phi_sp_phd, reuse=True,
                                                                 custom_getter=ema_getter)

                """
                    build target true critic network, the input is s' and mu'(s')
                """
                target_true_critic = self._build_true_critic(self.state_prime_phd, target_actor_output,
                                                             reuse=True, custom_getter=ema_getter)

                """
                    define optimization of policy according to shaped rewards
                """
                # a_loss is the mean shaped Q value; the actor minimizes -a_loss, i.e. maximizes it
                a_loss = tf.reduce_mean(self.shaped_critic)
                if self.actor_grad_clip:
                    self.actor_param_gradients = tf.gradients(-a_loss, a_params)
                    self.actor_clipped_grad_op, _ = tf.clip_by_global_norm(self.actor_param_gradients,
                                                                           self.actor_grad_norm_clip)
                    self.actor_opt = tf.train.AdamOptimizer(self.lr_actor)
                    self.trainer_actor = self.actor_opt.apply_gradients(zip(self.actor_clipped_grad_op, a_params))
                else:
                    self.trainer_actor = tf.train.AdamOptimizer(self.lr_actor).minimize(-a_loss, var_list=a_params)

                with tf.name_scope('Actor_Loss'):
                    tf.summary.scalar('actor_exp_Q', a_loss)

                """
                    define optimization of shaped critic
                """
                # ops created in this block depend on the moving-average updates of the
                # actor and the shaped critic, so a shaped-critic step also runs them
                with tf.control_dependencies(target_update_for_policy):
                    # shaped target: R + f_phi(s) * F + gamma * (1 - done) * Q'_shaped
                    shaped_q_target = self.reward_phd + self.f_phi * self.add_reward_phd + \
                                      self.gamma * (1 - self.done_phd) * target_shaped_critic
                    shaped_td_error = tf.losses.mean_squared_error(labels=tf.stop_gradient(shaped_q_target),
                                                                   predictions=self.shaped_critic)
                    # NOTE(review): the loss is capped at 100; once the MSE exceeds the cap
                    # the constant is selected, so presumably no gradient reaches the
                    # critic - confirm this is intended
                    clipped_shaped_td_error = tf.minimum(shaped_td_error, 100.0)
                    if self.critic_grad_clip:
                        self.shp_c_param_gradients = tf.gradients(clipped_shaped_td_error, shaped_c_params)
                        self.shp_c_clipped_grad_op, _ = tf.clip_by_global_norm(self.shp_c_param_gradients,
                                                                               self.critic_grad_norm_clip)
                        self.shp_c_opt = tf.train.AdamOptimizer(self.lr_critic)
                        self.trainer_shaped_critic = self.shp_c_opt.apply_gradients(zip(self.shp_c_clipped_grad_op,
                                                                                        shaped_c_params))
                    else:
                        self.trainer_shaped_critic = tf.train.AdamOptimizer(self.lr_critic).minimize(
                            clipped_shaped_td_error, var_list=shaped_c_params)

                """
                    define optimization of f_phi
                """
                # f_loss is the mean true Q value; only the Weight_Func variables are
                # updated, with the gradient of -f_loss
                f_loss = tf.reduce_mean(self.true_critic)
                if self.f_grad_clip:
                    self.f_param_gradients = tf.gradients(-f_loss, f_phi_params)
                    # "f_grad_norm_clilp" is the (misspelled) name assigned in set_algo_parameters
                    self.f_clipped_grad_op, _ = tf.clip_by_global_norm(self.f_param_gradients,
                                                                       self.f_grad_norm_clilp)
                    self.f_opt = tf.train.AdamOptimizer(self.lr_f)
                    self.trainer_f = self.f_opt.apply_gradients(zip(self.f_clipped_grad_op, f_phi_params))
                else:
                    self.trainer_f = tf.train.AdamOptimizer(self.lr_f).minimize(-f_loss, var_list=f_phi_params)

                with tf.name_scope('Weight_Func_Loss'):
                    tf.summary.scalar('f_phi_exp_q', f_loss)

                """
                    define optimization of true critic
                """
                # a true-critic step also runs the moving-average update of the true critic
                with tf.control_dependencies(target_update_for_f):
                    # true target uses the original reward only: R + gamma * (1 - done) * Q'_true
                    true_q_target = self.reward_phd + self.gamma * (1 - self.done_phd) * target_true_critic
                    true_td_error = tf.losses.mean_squared_error(labels=tf.stop_gradient(true_q_target),
                                                                 predictions=self.true_critic)
                    # same cap of 100 as for the shaped critic
                    clipped_true_td_error = tf.minimum(true_td_error, 100.0)
                    if self.critic_grad_clip:
                        self.true_c_param_gradients = tf.gradients(clipped_true_td_error, true_c_params)
                        self.true_c_clipped_grad_op, _ = tf.clip_by_global_norm(self.true_c_param_gradients,
                                                                                self.critic_grad_norm_clip)
                        self.true_c_opt = tf.train.AdamOptimizer(self.lr_critic)
                        self.trainer_true_critic = self.true_c_opt.apply_gradients(zip(self.true_c_clipped_grad_op,
                                                                                       true_c_params))
                    else:
                        self.trainer_true_critic = tf.train.AdamOptimizer(self.lr_critic).minimize(
                            clipped_true_td_error, var_list=true_c_params)

                # the summaries record the uncapped TD errors
                with tf.name_scope('Critic_loss'):
                    tf.summary.scalar('shaped_td_error', shaped_td_error)
                    tf.summary.scalar('true_td_error', true_td_error)

    def choose_action(self, s, is_test):
        """
            firstly compute the action
        """
        f_phi_s, action = self.sess.run([self.f_phi, self.actor_output], {self.state_phd: [s]})
        action = action[0]
        if len(f_phi_s.shape) == 2:
            f_phi_s = f_phi_s[0][0]

        if random.uniform(0, 1.0) < 1e-3:
            print("f_phi_s for state {} is {}".format(s, f_phi_s))

        if self.f_phi_min is not None and self.f_phi_max is not None:
            f_phi_s = np.minimum(self.f_phi_max, np.maximum(self.f_phi_min, f_phi_s))

        """
            if currently it is test or optimizing the shaping weight function
        """
        if is_test:
            return action, {"f_phi_s": f_phi_s}
        else:
            # print("Computed action is {}".format(action))
            if self.explo_method == "OU":
                action = action + self.ou_noise.noise()
            else:
                action = np.random.normal(action, action * self.gaussian_explo_ratio)

            return action, {"f_phi_s": f_phi_s}

    def learn(self, bs, ba, br, bs_, bdone, **kwargs):
        """
            Run one learning step in the currently active optimization mode.

            In policy mode the actor and the shaped critic are updated once;
            the additional reward batch is read from kwargs["F"].
            Otherwise the shaping weight function and the true critic are
            updated on mini-batches drawn from the given (on-policy) data.

            :param bs: batch fed to the state place holder
            :param ba: batch fed in place of the actor output
            :param br: batch fed to the reward place holder
            :param bs_: batch fed to the next-state place holder
            :param bdone: batch fed to the done place holder
        """
        with self.sess.graph.as_default():
            run = self.sess.run

            if not self.optimize_policy:
                # f is updated on-policy, so the data set is constructed immediately
                dataset = Dataset(dict(bs=bs, ba=ba, br=br, bs_=bs_, bdone=bdone), deterministic=False)
                for _epoch in range(self.f_optim_epochs):
                    for chunk in dataset.iterate_once(self.f_optim_batch_size):
                        # weight function step
                        run(self.trainer_f, {self.state_phd: chunk["bs"]})

                        # true critic step; its target is Q(s', mu'(s', f(s'))),
                        # so f_phi(s') is evaluated first
                        next_weights = run(self.f_phi, feed_dict={self.state_phd: chunk["bs_"]})
                        true_critic_feed = {self.state_phd: chunk["bs"],
                                            self.actor_output: chunk["ba"],
                                            self.reward_phd: chunk["br"],
                                            self.state_prime_phd: chunk["bs_"],
                                            self.done_phd: chunk["bdone"],
                                            self.f_phi_sp_phd: next_weights}
                        run(self.trainer_true_critic, feed_dict=true_critic_feed)
                return

            # additional reward batch F(s,a)
            shaping_rewards = kwargs.get("F")

            # actor step
            run(self.trainer_actor, {self.state_phd: bs})

            # shaped critic step; f_phi(s) and f_phi(s') are evaluated first,
            # the target shaped critic is Q(s', mu'(s',f(s')), f(s'))
            weights = run(self.f_phi, feed_dict={self.state_phd: bs})
            next_weights = run(self.f_phi, feed_dict={self.state_phd: bs_})
            shaped_critic_feed = {self.state_phd: bs,
                                  self.actor_output: ba,
                                  self.reward_phd: br,
                                  self.add_reward_phd: shaping_rewards,
                                  self.f_phi: weights,
                                  self.state_prime_phd: bs_,
                                  self.done_phd: bdone,
                                  self.f_phi_sp_phd: next_weights}
            run(self.trainer_shaped_critic, feed_dict=shaped_critic_feed)

    def _build_weight_func(self, s, reuse=None, custom_getter=None):
        """
            Build the shaping weight function f_phi(s) under the scope 'Weight_Func'.

            Every hidden layer is dense -> layer norm -> hidden activation.
            The single output unit uses the configured output activation and,
            when net_add_one is set, is shifted by +1.
            Layers are trainable only when reuse is None.
        """
        trainable = reuse is None
        uniform = tf.random_uniform_initializer
        with tf.variable_scope('Weight_Func', reuse=reuse, custom_getter=custom_getter):
            hidden = s
            for depth, width in enumerate(self.f_net_layers):
                hidden = tf.layers.dense(hidden, width,
                                         kernel_initializer=uniform(-1 / 8.0, 1 / 8.0),
                                         bias_initializer=uniform(-1 / 8.0, 1 / 8.0),
                                         name='l' + str(depth), trainable=trainable)
                hidden = tf.contrib.layers.layer_norm(hidden)
                hidden = self.f_hidden_layer_act_func(hidden)

            # the output layer starts with very small weights and biases
            weight = tf.layers.dense(hidden, 1, activation=self.f_output_layer_act_func,
                                     kernel_initializer=uniform(-1 / 1000.0, 1 / 1000.0),
                                     bias_initializer=uniform(-1 / 1000.0, 1 / 1000.0),
                                     name='f_value', trainable=trainable)

            if not self.net_add_one:
                return weight
            return tf.add(weight, 1)

    def _build_extended_actor(self, s, f_s, reuse=None, custom_getter=None):
        """
            Build the actor under the scope 'Actor'; its input is the state
            concatenated with the shaping weight f_s.

            Every hidden layer is dense -> layer norm -> relu, the action
            layer is linear. Layers are trainable only when reuse is None.
        """
        trainable = reuse is None
        with tf.variable_scope('Actor', reuse=reuse, custom_getter=custom_getter):
            hidden = tf.concat([s, f_s], axis=1)
            for depth, width in enumerate(self.actor_net_layers):
                hidden = tf.layers.dense(hidden, width, name='l' + str(depth), trainable=trainable)
                hidden = tf.contrib.layers.layer_norm(hidden)
                hidden = tf.nn.relu(hidden)

            # no output activation: the action is unbounded
            return tf.layers.dense(hidden, self.action_dim, activation=None, name='action',
                                   trainable=trainable)

    def _build_true_critic(self, s, action, reuse=None, custom_getter=None):
        """
            Build the true critic Q(s, a) under the scope 'True_Critic'.

            The action is concatenated to the input of the hidden layer whose
            index equals critic_act_in_ly_index. Every hidden layer is
            dense -> layer norm -> relu, followed by one linear output unit.
            Layers are trainable only when reuse is None.
        """
        trainable = reuse is None
        with tf.variable_scope('True_Critic', reuse=reuse, custom_getter=custom_getter):
            hidden = s
            for depth, width in enumerate(self.critic_net_layers):
                if depth == self.critic_act_in_ly_index:
                    hidden = tf.concat([hidden, action], axis=1)

                hidden = tf.layers.dense(hidden, width, name='l' + str(depth), trainable=trainable)
                hidden = tf.contrib.layers.layer_norm(hidden)
                hidden = tf.nn.relu(hidden)

            q_value = tf.layers.dense(hidden, 1, trainable=trainable)
            return q_value

    def _build_shaped_critic(self, s, action, f_phi, reuse=None, custom_getter=None):
        """
            Build the shaped critic Q(s, a, f_phi) under the scope 'Shaped_Critic'.

            The network input is the state concatenated with f_phi; the action
            is concatenated to the input of the hidden layer whose index equals
            critic_act_in_ly_index. Every hidden layer is
            dense -> layer norm -> relu, followed by one linear output unit.
            Layers are trainable only when reuse is None.
        """
        trainable = reuse is None
        with tf.variable_scope('Shaped_Critic', reuse=reuse, custom_getter=custom_getter):
            hidden = tf.concat([s, f_phi], axis=1)
            for depth, width in enumerate(self.critic_net_layers):
                if depth == self.critic_act_in_ly_index:
                    hidden = tf.concat([hidden, action], axis=1)

                hidden = tf.layers.dense(hidden, width, name='l' + str(depth), trainable=trainable)
                hidden = tf.contrib.layers.layer_norm(hidden)
                hidden = tf.nn.relu(hidden)

            q_value = tf.layers.dense(hidden, 1, trainable=trainable)
            return q_value

    def experience(self, one_exp):
        """
            Print the given experience; nothing is stored.
        """
        message = "The is experience method of ddpg-oprs-v1, {}".format(one_exp)
        print(message)

    def switch_optimization(self):
        self.optimize_policy = not self.optimize_policy
