import numpy as np
import tensorflow as tf
import random
from experiments.utils.ou_noise import OUNoise
from experiments.algorithms.ddpg.ddpg_algo import DDPGAlgo
from experiments.utils.common import Dataset

# Default hyper-parameters. Unless noted otherwise, each one is the fallback used by
# DDPGOprsV2FsaqinAlgo.set_algo_parameters for the kwarg named in its comment.
LR_ACTOR = 1e-4  # "lr_actor"
LR_CRITIC = 2e-4  # "lr_critic" (shared by the shaped and the true critic)
LR_F = 1e-4  # "lr_f": learning rate of the shaping weight function f_phi
GAMMA = 0.999  # "gamma": discount factor
TAU = 0.01  # "tau": target networks track the online ones with an EMA of decay 1 - tau
OU_NOISE_THETA = 0.15  # "ou_noise_theta" (explo_method == "OU")
OU_NOISE_SIGMA = 0.5  # "ou_noise_sigma" (explo_method == "OU")
GAUSSIAN_EXPLORATION_SIGMA_RATIO_MAX = 1.0  # "gaussian_explo_sigma_ratio_max" (decaying Gaussian)
GAUSSIAN_EXPLORATION_SIGMA_RATIO_MIN = 1e-5  # "gaussian_explo_sigma_ratio_min" (decaying Gaussian)
GAUSSIAN_EXPLORATION_SIGMA_RATIO_FIX = 0.2  # "gaussian_explo_sigma_ratio_fix" (explo_method == "GAUSSIAN_STATIC")
GAUSSIAN_EXPLORATION_SIGMA_RATIO_DECAY_EPISODE = 60000  # "gaussian_explo_sigma_ratio_decay_ep"
SHAPING_WEIGHT_SCALE_FACTOR = 10.0  # not read by set_algo_parameters
ACTOR_GRADIENT_NORM_CLIP = 1.0  # "actor_gradient_norm_clip"
CRITIC_GRADIENT_NORM_CLIP = 1.0  # "critic_gradient_norm_clip"
F_GRADIENT_NORM_CLIP = 50  # "f_gradient_norm_clip"
F_OPTIM_EPOCHS = 50  # "f_optim_epochs": epochs over the data per f_phi update phase
F_OPTIM_BATCH_SIZE = 1024  # "f_optim_batch_size": mini-batch size of the f_phi update phase


def my_flatten(x):
    """Flatten arbitrarily nested Python lists into one flat list.

    Only objects whose type is exactly ``list`` are descended into; everything
    else (tuples, NumPy arrays, list subclasses, scalars) is kept as a leaf.
    A non-list argument is returned wrapped in a one-element list.

    The traversal uses an explicit stack of iterators instead of recursion, so
    deeply nested inputs cannot hit Python's recursion limit.

    >>> my_flatten([[1, [2, 3]], 4])
    [1, 2, 3, 4]
    >>> my_flatten(5)
    [5]
    """
    if type(x) is not list:
        return [x]
    flat = []
    stack = [iter(x)]
    while stack:
        for item in stack[-1]:
            if type(item) is list:
                # descend first; the current iterator resumes after this sub-list
                stack.append(iter(item))
                break
            flat.append(item)
        else:
            # current list exhausted
            stack.pop()
    return flat

"""
    Approximate version of OPRS-V2:
    use only a few samples to compute the gradient of theta w.r.t. phi,
    but use many samples to compute the policy gradient.

"""


class DDPGOprsV2FsaqinAlgo(DDPGAlgo):

    def __init__(self, sess, graph, state_dim, action_dim, algo_name="ddpg_oprs_v2_fsaqin", **kwargs):
        """Create the agent, starting in the policy-optimization phase.

        ``optimize_policy`` and ``episode_traj`` are assigned before delegating
        to ``DDPGAlgo.__init__``, so they already exist if the base constructor
        touches them. Every argument is forwarded to the base class unchanged.
        """
        # True -> learn() trains actor / shaped critic; False -> it trains f_phi / true critic
        self.optimize_policy, self.episode_traj = True, []
        super(DDPGOprsV2FsaqinAlgo, self).__init__(sess, graph, state_dim, action_dim, algo_name, **kwargs)

    def set_algo_parameters(self, **kwargs):
        """Read hyper-parameters from ``kwargs``, falling back to the module-level defaults.

        Exploration is selected by ``explo_method``:
          * ``"OU"`` (default): Ornstein-Uhlenbeck noise built from
            ``ou_noise_theta`` / ``ou_noise_sigma``;
          * ``"GAUSSIAN_STATIC"``: Gaussian noise with the fixed sigma ratio
            ``gaussian_explo_sigma_ratio_fix``;
          * any other value: Gaussian noise whose sigma ratio starts at
            ``gaussian_explo_sigma_ratio_max``. A multiplicative decay factor is
            also computed; applying it ``gaussian_explo_sigma_ratio_decay_ep``
            times brings the ratio down to ``gaussian_explo_sigma_ratio_min``.
            The decay itself is not applied in this method.

        The clip threshold for f_phi gradients is stored as ``f_grad_norm_clip``.
        It is also exposed under the historical misspelled name
        ``f_grad_norm_clilp``, which other code reads.

        Raises:
            ValueError: for the decaying-Gaussian schedule, if
                ``gaussian_explo_sigma_ratio_decay_ep`` is not positive.
        """
        self.gamma = kwargs.get("gamma", GAMMA)
        self.tau = kwargs.get("tau", TAU)
        self.lr_actor = kwargs.get("lr_actor", LR_ACTOR)
        self.lr_critic = kwargs.get("lr_critic", LR_CRITIC)
        self.lr_f = kwargs.get("lr_f", LR_F)

        # exploration strategy
        self.explo_method = kwargs.get("explo_method", "OU")
        if self.explo_method == "OU":
            ou_noise_theta = kwargs.get("ou_noise_theta", OU_NOISE_THETA)
            ou_noise_sigma = kwargs.get("ou_noise_sigma", OU_NOISE_SIGMA)
            self.ou_noise = OUNoise(self.action_dim, mu=0, theta=ou_noise_theta, sigma=ou_noise_sigma)
        elif self.explo_method == "GAUSSIAN_STATIC":
            self.gaussian_explo_ratio = kwargs.get("gaussian_explo_sigma_ratio_fix",
                                                   GAUSSIAN_EXPLORATION_SIGMA_RATIO_FIX)
        else:
            self.gaussian_explo_sigma_ratio_max = kwargs.get("gaussian_explo_sigma_ratio_max",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_MAX)
            self.gaussian_explo_sigma_ratio_min = kwargs.get("gaussian_explo_sigma_ratio_min",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_MIN)
            gaussian_explo_sigma_ratio_decay_ep = kwargs.get("gaussian_explo_sigma_ratio_decay_ep",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_DECAY_EPISODE)
            # 0 would divide by zero below; a negative count would make the "decay" grow
            if gaussian_explo_sigma_ratio_decay_ep <= 0:
                raise ValueError("gaussian_explo_sigma_ratio_decay_ep must be positive, got {}".format(
                    gaussian_explo_sigma_ratio_decay_ep))
            self.gaussian_explo_ratio = self.gaussian_explo_sigma_ratio_max
            # (min / max) ** (1 / decay_ep)
            self.gaussian_explo_decay_factor = pow(self.gaussian_explo_sigma_ratio_min /
                                                   self.gaussian_explo_sigma_ratio_max,
                                                   1.0 / gaussian_explo_sigma_ratio_decay_ep)

        # gradient clipping switches and global-norm thresholds
        self.actor_grad_clip = kwargs.get("actor_gradient_clip", True)
        self.critic_grad_clip = kwargs.get("critic_gradient_clip", True)
        self.f_grad_clip = kwargs.get("f_gradient_clip", True)
        self.actor_grad_norm_clip = kwargs.get("actor_gradient_norm_clip", ACTOR_GRADIENT_NORM_CLIP)
        self.critic_grad_norm_clip = kwargs.get("critic_gradient_norm_clip", CRITIC_GRADIENT_NORM_CLIP)
        self.f_grad_norm_clip = kwargs.get("f_gradient_norm_clip", F_GRADIENT_NORM_CLIP)
        # historical misspelling kept as an alias: define_trainers reads `f_grad_norm_clilp`
        self.f_grad_norm_clilp = self.f_grad_norm_clip

        # network layer cell numbers
        self.actor_net_layers = kwargs.get("actor_net_layers", [4, 4])
        self.critic_net_layers = kwargs.get("critic_net_layers", [32, 32])
        self.critic_act_in_ly_index = int(kwargs.get("critic_action_input_layer_index", 1))
        self.f_net_layers = kwargs.get("f_net_layers", [16, 8])
        self.f_optim_epochs = kwargs.get("f_optim_epochs", F_OPTIM_EPOCHS)
        self.f_optim_batch_size = kwargs.get("f_optim_batch_size", F_OPTIM_BATCH_SIZE)

        # use nabla_{a} Q_{shaped}(s,a,f) (True) or just Q_{shaped}(s,a,f) (False)
        # to compute the gradient of the critic w.r.t. phi
        self.enable_grad_q_shaped_wrt_a = kwargs.get("enable_grad_q_shaped_wrt_a", True)

        # shaping weight function f_phi options
        self.net_add_one = kwargs.get("net_add_one", False)
        self.f_hidden_layer_act_func = kwargs.get("f_hidden_layer_act_func", tf.nn.relu)

        self.f_output_layer_act_func = kwargs.get("f_output_layer_act_func", None)
        # choose_action clips f_phi(s, a) to [f_phi_min, f_phi_max] only when both are given
        self.f_phi_min = kwargs.get("f_phi_min", None)
        self.f_phi_max = kwargs.get("f_phi_max", None)

    def init_networks(self):
        """Build the complete computation graph, then create the checkpoint saver.

        The construction order matters:
          1. placeholders (define_place_holders);
          2. online networks: shaping weight function, actor, shaped critic,
             true critic (build_networks);
          3. target networks, losses and trainers (define_trainers);
          4. a tf.train.Saver that keeps at most 100 checkpoints.
        """
        for build_step in (self.define_place_holders, self.build_networks, self.define_trainers):
            build_step()

        with self.sess.as_default(), self.graph.as_default():
            self.saver = tf.train.Saver(max_to_keep=100)

    def define_place_holders(self):
        """Create the graph inputs inside ``self.graph``.

        Placeholders: state s, next state s', action a, original reward
        R(s,a), additional reward F(s,a), done flag, and f_phi_sp (the shaping
        weight value of the next state s', computed outside and fed in).
        A scalar variable named 'temperature' is also created.
        """
        with self.sess.as_default(), self.graph.as_default():
            state_shape = [None, self.state_dim]
            column_shape = [None, 1]

            self.state_phd = tf.placeholder(tf.float32, state_shape, name='state')
            self.state_prime_phd = tf.placeholder(tf.float32, state_shape, name='state_prime')
            self.action_phd = tf.placeholder(tf.float32, [None, self.action_dim], name='action')
            # the original reward R(s,a)
            self.reward_phd = tf.placeholder(tf.float32, column_shape, name='reward')
            # the additional reward, namely F(s,a)
            self.add_reward_phd = tf.placeholder(tf.float32, column_shape, name='additional_reward')
            self.done_phd = tf.placeholder(tf.float32, column_shape, name='done')
            self.temp = tf.Variable(1.0, name='temperature')
            # shaping weight value of the next state s'
            self.f_phi_sp_phd = tf.placeholder(tf.float32, column_shape, name='f_phi_sp')

    def build_networks(self):
        """Build the online networks on top of ``self.state_phd``.

        * ``self.f_phi``: shaping weight function f_phi(s, a) of the state and
          the action placeholder;
        * ``self.actor_output``: policy mu(s);
        * ``self.shaped_critic``: Q_shaped(s, mu(s), f_phi), used to optimize the policy;
        * ``self.true_critic``: Q_true(s, mu(s)), used to optimize f_phi.
        """
        with self.sess.as_default(), self.graph.as_default():
            states = self.state_phd
            self.f_phi = self._build_weight_func(states, self.action_phd)
            self.actor_output = self._build_actor(states)
            self.shaped_critic = self._build_shaped_critic(states, self.actor_output, self.f_phi)
            self.true_critic = self._build_true_critic(states, self.actor_output)

    def define_trainers(self):
        """Create target networks, losses, optimizers and the gradient ops used by learn().

        Reads tensors created by define_place_holders() and build_networks()
        (self.actor_output, self.shaped_critic, self.true_critic, self.f_phi, ...)
        and the trainable variables of the scopes 'Actor', 'Shaped_Critic',
        'True_Critic' and 'Weight_Func', so it must run after both of them.

        Main attributes created: trainer_actor, trainer_shaped_critic,
        trainer_true_critic, trainer_f (applies gradients fed through
        f_phi_params_grad_phds), target_actor_output, grad_shaped_loss_wrt_phi,
        grad_mu_wrt_theta, grad_true_loss_wrt_theta, the parameter-shape lists,
        and the accumulators grad_theta_wrt_phi_aggr / grad_aggr_num.
        """
        with self.sess.as_default():
            with self.graph.as_default():
                """
                    first build target network of actor, shaped critic, and true critic
                """
                a_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Actor')
                shaped_c_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Shaped_Critic')
                true_c_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='True_Critic')
                self.f_phi_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Weight_Func')
                # target networks are exponential moving averages (decay = 1 - tau) of the online weights
                ema = tf.train.ExponentialMovingAverage(decay=1 - self.tau)

                # custom getter returning the EMA shadow of each requested variable, so networks
                # rebuilt with reuse=True and custom_getter=ema_getter read the averaged weights
                def ema_getter(getter, name, *args, **kwargs):
                    return ema.average(getter(name, *args, **kwargs))

                # EMA update ops. They become control dependencies of the critic losses below,
                # so running a critic trainer also refreshes the matching target weights.
                target_update_for_policy = [ema.apply(a_params), ema.apply(shaped_c_params)]
                target_update_for_f = [ema.apply(true_c_params)]

                """
                    target actor network, the input is next state s'
                """
                target_actor_output = self._build_actor(self.state_prime_phd, reuse=True, custom_getter=ema_getter)
                self.target_actor_output = target_actor_output

                """
                    target shaped critic network, the input is s', mu'(s'), and f_phi(s')
                """
                # f_phi(s') is not computed in-graph here; it is fed through f_phi_sp_phd
                target_shaped_critic = self._build_shaped_critic(self.state_prime_phd, target_actor_output,
                                                                 self.f_phi_sp_phd, reuse=True,
                                                                 custom_getter=ema_getter)

                """
                    build target true critic network, the input is s' and mu'(s')
                """
                target_true_critic = self._build_true_critic(self.state_prime_phd, target_actor_output,
                                                             reuse=True, custom_getter=ema_getter)

                """
                    define optimization of policy according to shaped rewards
                    we must record the gradient of actor w.r.t the parameter theta
                """
                # the actor maximizes the shaped critic Q_shaped(s, mu(s), f_phi)
                a_loss = tf.reduce_mean(-self.shaped_critic)
                if self.actor_grad_clip:
                    self.actor_param_gradients = tf.gradients(a_loss, a_params)
                    self.actor_clipped_grad_op, _ = tf.clip_by_global_norm(self.actor_param_gradients,
                                                                           self.actor_grad_norm_clip)
                    self.actor_opt = tf.train.AdamOptimizer(self.lr_actor)
                    self.trainer_actor = self.actor_opt.apply_gradients(zip(self.actor_clipped_grad_op, a_params))
                else:
                    self.trainer_actor = tf.train.AdamOptimizer(self.lr_actor).minimize(a_loss, var_list=a_params)

                with tf.name_scope('Actor_Loss'):
                    tf.summary.scalar('actor_exp_Q', a_loss)

                """
                    define optimization of shaped critic
                """
                with tf.control_dependencies(target_update_for_policy):
                    # shaped target: R + f_phi * F + gamma * (1 - done) * Q'_shaped(s', mu'(s'), f_phi(s'))
                    shaped_q_target = self.reward_phd + self.f_phi * self.add_reward_phd + \
                                      self.gamma * (1 - self.done_phd) * target_shaped_critic
                    shaped_td_error = tf.losses.mean_squared_error(labels=tf.stop_gradient(shaped_q_target),
                                                                   predictions=self.shaped_critic)
                    # NOTE(review): once the MSE exceeds 100, tf.minimum routes the gradient to the
                    # constant, so such a batch does not move the critic -- confirm this is intended
                    clipped_shaped_td_error = tf.minimum(shaped_td_error, 100.0)
                    if self.critic_grad_clip:
                        self.shp_c_param_gradients = tf.gradients(clipped_shaped_td_error, shaped_c_params)

                        self.shp_c_clipped_grad_op, _ = tf.clip_by_global_norm(self.shp_c_param_gradients,
                                                                               self.critic_grad_norm_clip)
                        self.shp_c_opt = tf.train.AdamOptimizer(self.lr_critic)
                        self.trainer_shaped_critic = self.shp_c_opt.apply_gradients(zip(self.shp_c_clipped_grad_op,
                                                                                        shaped_c_params))
                    else:
                        self.trainer_shaped_critic = tf.train.AdamOptimizer(self.lr_critic).minimize(
                            clipped_shaped_td_error, var_list=shaped_c_params)
                """
                    define optimization of f_phi
                    nabla_{phi} J = nabla_{theta} mu_{theta}(s) * 
                                        nabla_{phi} theta * 
                                        Q_{True}(s, mu(s))

                                     =  nabla_{theta} mu_{theta}(s) * 
                                        alpha * nabla_{theta'} mu_{theta'} * nabla_{phi} Q_{shaped}(s,mu)
                                        * Q_{True}(s, mu(s))
                """

                """
                    self.grad_return_wrt_phi = nabla_{phi} Q_{shaped}(s,mu)

                    it should be noted that although the shape of self.traj_f_phi is [None, 1] (one output for each input state)
                    but the gradient value is the sum over all gradients of the input states
                    that is the say, self.grad_return_wrt_phi is like [array(), array(), ..., array()],
                    where each array corresponds to the gradients of each layer's parameters
                    
                    please note that here we must make ys -self.shaped_critic
                """
                # With enable_grad_q_shaped_wrt_a, differentiate nabla_a(-Q_shaped) w.r.t. phi
                # (tf.gradients differentiates the sum of a list of ys); otherwise differentiate
                # -Q_shaped w.r.t. phi directly.
                # self.grad_shaped_loss_wrt_a = tf.gradients(ys=-self.shaped_critic, xs=self.actor_output)
                self.grad_shaped_loss_wrt_a = tf.gradients(ys=-self.shaped_critic, xs=self.actor_output)
                if self.enable_grad_q_shaped_wrt_a:
                    self.grad_shaped_loss_wrt_phi = tf.gradients(ys=self.grad_shaped_loss_wrt_a, xs=self.f_phi_params)
                else:
                    # self.grad_shaped_loss_wrt_phi = tf.gradients(ys=-self.shaped_critic, xs=self.f_phi_params)
                    self.grad_shaped_loss_wrt_phi = tf.gradients(ys=-self.shaped_critic, xs=self.f_phi_params)


                # record the parameter shapes; learn() uses f_phi_params_shapes to split a
                # flattened gradient vector back into per-variable arrays
                self.f_phi_params_shapes = [None] * len(self.f_phi_params)
                self.theta_params_shapes = [None] * len(a_params)
                for i in range(len(self.f_phi_params)):
                    self.f_phi_params_shapes[i] = self.f_phi_params[i].shape
                    print("The f_phi param {} shape is {}".format(i, self.f_phi_params_shapes[i]))

                for i in range(len(a_params)):
                    self.theta_params_shapes[i] = a_params[i].shape
                    print("The actor param {} shape is {}".format(i, self.theta_params_shapes[i]))

                """
                    self.grad_mu_wrt_theta = nabla_{theta'} mu_{theta'} and nabla_{theta} mu_{theta}
                    the gradient of policy mu w.r.t to the parameter theta
                """
                self.grad_mu_wrt_theta = tf.gradients(ys=self.actor_output, xs=a_params)

                """
                    gradient of theta w.r.t. phi
                    which is aggragated every update of policy parameters
                    and will be averaged before the update of the shaping weight function
                """
                # NOTE(review): learn() only sums into grad_theta_wrt_phi_aggr and increments
                # grad_aggr_num; no averaging by grad_aggr_num is visible in this file section
                self.grad_theta_wrt_phi_aggr = None
                self.grad_aggr_num = 0

                """
                    the policy gradient is for optimizing weight function parameter \phi
                    nabla_{theta} mu(s) * Q_{True}(s,mu(s)) * (nabla_{phi} theta)
                """
                true_a_loss = tf.reduce_mean(self.true_critic)
                self.grad_true_loss_wrt_theta = tf.gradients(ys=-true_a_loss, xs=a_params)
                # NOTE(review): clipped with actor_grad_norm_clip regardless of self.actor_grad_clip
                self.grad_true_loss_wrt_theta, _ = tf.clip_by_global_norm(self.grad_true_loss_wrt_theta,
                                                                          self.actor_grad_norm_clip)

                # zero-filled initial value of self.grad_J_wrt_phi (learn() replaces it with real gradients)
                self.grad_J_wrt_phi = [None] * len(self.f_phi_params)

                for ly in range(len(self.f_phi_params_shapes)):
                    ly_shape = self.f_phi_params_shapes[ly]
                    fake_gradient = np.full(shape=ly_shape, fill_value=0.0, dtype=np.float32)
                    self.grad_J_wrt_phi[ly] = fake_gradient

                """
                    another way for updating phi
                """
                # f_phi is trained from gradients computed outside the graph: one placeholder per
                # f_phi variable, applied with Adam (optionally after global-norm clipping)
                self.optimizer_f = tf.train.AdamOptimizer(self.lr_f)
                self.f_phi_params_grad_phds = [None] * len(self.f_phi_params)
                for ly in range(len(self.f_phi_params)):
                    ly_shape = self.f_phi_params_shapes[ly]
                    self.f_phi_params_grad_phds[ly] = tf.placeholder(tf.float32, shape=ly_shape,
                                                                     name="f_phi_phd_{}".format(ly))

                if self.f_grad_clip:
                    # `f_grad_norm_clilp` (sic) is the attribute name set in set_algo_parameters
                    self.f_clipped_grad_op, _ = tf.clip_by_global_norm(self.f_phi_params_grad_phds,
                                                                       self.f_grad_norm_clilp)
                    self.trainer_f = self.optimizer_f.apply_gradients(zip(self.f_clipped_grad_op, self.f_phi_params))
                else:
                    self.trainer_f = self.optimizer_f.apply_gradients(zip(self.f_phi_params_grad_phds,
                                                                          self.f_phi_params))

                """
                    define optimization of true critic
                """
                with tf.control_dependencies(target_update_for_f):
                    # true target uses only the environment reward (no shaping term)
                    true_q_target = self.reward_phd + self.gamma * (1 - self.done_phd) * target_true_critic
                    true_td_error = tf.losses.mean_squared_error(labels=tf.stop_gradient(true_q_target),
                                                                 predictions=self.true_critic)
                    clipped_true_td_error = tf.minimum(true_td_error, 100.0)
                    if self.critic_grad_clip:
                        self.true_c_param_gradients = tf.gradients(clipped_true_td_error, true_c_params)

                        self.true_c_clipped_grad_op, _ = tf.clip_by_global_norm(self.true_c_param_gradients,
                                                                                self.critic_grad_norm_clip)
                        self.true_c_opt = tf.train.AdamOptimizer(self.lr_critic)
                        self.trainer_true_critic = self.true_c_opt.apply_gradients(zip(self.true_c_clipped_grad_op,
                                                                                       true_c_params))
                    else:
                        self.trainer_true_critic = tf.train.AdamOptimizer(self.lr_critic).minimize(
                            clipped_true_td_error, var_list=true_c_params)

                # TensorBoard summaries of the (unclipped) TD errors
                with tf.name_scope('Critic_loss'):
                    tf.summary.scalar('shaped_td_error', shaped_td_error)
                    tf.summary.scalar('true_td_error', true_td_error)

    def choose_action(self, s, is_test):
        """Pick the action to execute in state ``s`` and evaluate f_phi on it.

        Args:
            s: a single state; it is fed to ``self.state_phd`` as a batch of one.
            is_test: when True, no exploration noise is added.

        Returns:
            ``(exe_action, info)``. ``exe_action`` is the actor output, plus
            exploration noise during policy optimization. ``info["f_phi_s"]``
            is f_phi evaluated at (s, exe_action), clipped to
            [f_phi_min, f_phi_max] when both bounds are configured.
        """
        action = self.sess.run(self.actor_output, {self.state_phd: [s]})[0]

        # no exploration while testing or while the shaping weight function is being optimized
        if is_test or not self.optimize_policy:
            exe_action = action
        elif self.explo_method == "OU":
            exe_action = action + self.ou_noise.noise()
        else:
            # The std is proportional to the action magnitude. np.abs keeps it non-negative:
            # np.random.normal raises ValueError for a negative scale, which previously
            # happened whenever any action component was negative.
            exe_action = np.random.normal(action, np.abs(action) * self.gaussian_explo_ratio)

        f_phi_sa = self.sess.run(self.f_phi, {
            self.state_phd: [s],
            self.action_phd: [exe_action]
        })
        if len(f_phi_sa.shape) == 2:
            # batch of one, single output -> scalar
            f_phi_sa = f_phi_sa[0][0]

        # occasional (~0.1%) debug print
        if random.uniform(0, 1.0) < 1e-3:
            print("f_phi_s is {}".format(f_phi_sa))

        if self.f_phi_min is not None and self.f_phi_max is not None:
            f_phi_sa = np.minimum(self.f_phi_max, np.maximum(self.f_phi_min, f_phi_sa))

        return exe_action, {"f_phi_s": f_phi_sa}

    def learn(self, bs, ba, br, bs_, bdone, **kwargs):
        """Run one learning step of the phase selected by ``self.optimize_policy``.

        Policy phase (``self.optimize_policy`` is True), see _learn_policy_step:
          1. update the actor on the whole batch ``bs``;
          2. estimate nabla_{phi} theta from the few states in
             ``kwargs["grad_compute_s_batch"]`` and add it to
             ``self.grad_theta_wrt_phi_aggr``;
          3. update the shaped critic using the additional rewards ``kwargs["F"]``.

        Shaping-weight phase (otherwise), see _learn_f_phi_step: for
        ``self.f_optim_epochs`` epochs over shuffled mini-batches of the given
        data, update f_phi with nabla_{phi} J_{True} and update the true critic.

        Args:
            bs, ba, br, bs_, bdone: batches of states, actions, rewards, next
                states and done flags (rewards and done flags are fed to
                [None, 1] placeholders).
            **kwargs: policy phase only -- ``grad_compute_s_batch`` and ``F``.

        Raises:
            ValueError: in the policy phase, if ``grad_compute_s_batch`` is
                missing or empty (checked before any parameter is updated).
            AssertionError: in the shaping-weight phase, if no policy-phase
                gradient has been accumulated yet.
        """
        with self.sess.graph.as_default():
            if self.optimize_policy:
                self._learn_policy_step(bs, ba, br, bs_, bdone, **kwargs)
            else:
                self._learn_f_phi_step(bs, ba, br, bs_, bdone)

    def _learn_policy_step(self, bs, ba, br, bs_, bdone, **kwargs):
        """Policy phase of learn(): actor update, nabla_{phi} theta accumulation, shaped-critic update."""
        grad_compute_s_batch = kwargs.get("grad_compute_s_batch")
        if grad_compute_s_batch is None or len(grad_compute_s_batch) == 0:
            raise ValueError("learn() needs a non-empty 'grad_compute_s_batch' kwarg "
                             "while optimizing the policy")

        # use all samples to optimize the actor
        bmu = self.sess.run(self.actor_output, {self.state_phd: bs})
        self.sess.run(self.trainer_actor, {self.state_phd: bs,
                                           self.action_phd: bmu})

        # Only a few samples are used to estimate the gradient of theta w.r.t. phi.
        # For each state s (with a = mu(s) from the just-updated actor):
        #   grad_theta_wrt_phi_s = outer(nabla_{theta} mu(s), nabla_{phi} Q_shaped-loss(s, a))
        # rows index the flattened actor parameters, columns the flattened f_phi parameters.
        grad_compute_a_batch = self.sess.run(self.actor_output, feed_dict={
            self.state_phd: grad_compute_s_batch
        })

        grad_theta_wrt_phi = None
        for s, a in zip(grad_compute_s_batch, grad_compute_a_batch):
            grad_shaped_loss_wrt_phi_s = self.sess.run(self.grad_shaped_loss_wrt_phi,
                                                       feed_dict={self.state_phd: [s],
                                                                  self.action_phd: [a]})
            grad_mu_wrt_theta_s = self.sess.run(self.grad_mu_wrt_theta, feed_dict={self.state_phd: [s]})

            grad_theta_wrt_phi_s = np.outer(self._flatten_gradients(grad_mu_wrt_theta_s),
                                            self._flatten_gradients(grad_shaped_loss_wrt_phi_s))
            if grad_theta_wrt_phi is None:
                grad_theta_wrt_phi = grad_theta_wrt_phi_s
            else:
                grad_theta_wrt_phi += grad_theta_wrt_phi_s

        # average over the samples and scale by the actor learning rate (alpha)
        grad_theta_wrt_phi = np.multiply(grad_theta_wrt_phi, self.lr_actor / len(grad_compute_s_batch))

        # aggregate the gradient of this update
        self.grad_aggr_num += 1
        if self.grad_theta_wrt_phi_aggr is None:
            self.grad_theta_wrt_phi_aggr = grad_theta_wrt_phi
        else:
            self.grad_theta_wrt_phi_aggr = self.grad_theta_wrt_phi_aggr + grad_theta_wrt_phi

        # Optimize the shaped critic. The target is
        # R + f_phi(s, a) * F + gamma * (1 - done) * Q'(s', mu'(s'), f_phi(s', mu'(s'))),
        # with both f_phi values computed here and fed in.
        add_reward_batch = kwargs.get("F")
        f_phi_s_batch = self.sess.run(self.f_phi, feed_dict={self.state_phd: bs,
                                                             self.action_phd: ba})

        # NOTE(review): the target actor supplies a' = mu'(s') here -- confirm vs. actor_output
        ba_ = self.sess.run(self.target_actor_output, feed_dict={self.state_prime_phd: bs_})
        f_phi_sp_batch = self.sess.run(self.f_phi, feed_dict={self.state_phd: bs_,
                                                              self.action_phd: ba_})

        self.sess.run(self.trainer_shaped_critic, feed_dict={self.state_phd: bs,
                                                             self.actor_output: ba,
                                                             self.reward_phd: br,
                                                             self.add_reward_phd: add_reward_batch,
                                                             self.f_phi: f_phi_s_batch,
                                                             self.state_prime_phd: bs_,
                                                             self.done_phd: bdone,
                                                             self.f_phi_sp_phd: f_phi_sp_batch})

    def _learn_f_phi_step(self, bs, ba, br, bs_, bdone):
        """Shaping-weight phase of learn(): on-policy updates of f_phi and of the true critic."""
        print("UPdate shaping weight function")

        assert self.grad_aggr_num != 0
        assert self.grad_theta_wrt_phi_aggr is not None

        # f is updated on-policy, so the data set is built from the given batch right away.
        # NOTE(review): grad_theta_wrt_phi_aggr is used as accumulated (summed, not divided by
        # grad_aggr_num) and is not reset in this method.
        d = Dataset(dict(bs=bs, ba=ba, br=br, bs_=bs_, bdone=bdone), deterministic=False)
        for _ in range(self.f_optim_epochs):
            for sub_mini_batch in d.iterate_once(self.f_optim_batch_size):
                # E_s[nabla_{theta} mu(s) * Q_True(s, mu(s))] on this mini-batch; feeding
                # actor_output replaces mu(s) with the stored actions
                grad_true_loss_wrt_theta = self.sess.run(self.grad_true_loss_wrt_theta,
                                                         {self.state_phd: sub_mini_batch["bs"],
                                                          self.actor_output: sub_mini_batch["ba"]})

                # chain rule: nabla_{phi} J_True = nabla_{theta} J_True . nabla_{phi} theta
                grad_J_wrt_phi = np.matmul(np.atleast_2d(self._flatten_gradients(grad_true_loss_wrt_theta)),
                                           self.grad_theta_wrt_phi_aggr)
                if len(grad_J_wrt_phi.shape) == 2 and grad_J_wrt_phi.shape[0] == 1:
                    grad_J_wrt_phi = grad_J_wrt_phi[0]

                self.grad_J_wrt_phi = self._split_phi_gradient(grad_J_wrt_phi)

                # finally, use nabla_{phi} J_True to optimize f_phi
                f_phi_dict = dict(zip(self.f_phi_params_grad_phds, self.grad_J_wrt_phi))
                self.sess.run(self.trainer_f, feed_dict=f_phi_dict)

                # optimize the true critic
                self.sess.run(self.trainer_true_critic,
                              feed_dict={self.state_phd: sub_mini_batch["bs"],
                                         self.actor_output: sub_mini_batch["ba"],
                                         self.reward_phd: sub_mini_batch["br"],
                                         self.state_prime_phd: sub_mini_batch["bs_"],
                                         self.done_phd: sub_mini_batch["bdone"]
                                         })

    @staticmethod
    def _flatten_gradients(grads):
        """Concatenate per-variable gradient arrays into one float64 vector (C order).

        Gives the same values as the former ``tolist()`` + ``my_flatten`` conversion,
        without building Python lists element by element.
        """
        if len(grads) == 0:
            return np.zeros(0, dtype=np.float64)
        return np.concatenate([np.asarray(g, dtype=np.float64).ravel() for g in grads])

    def _split_phi_gradient(self, flat_grad):
        """Cut a flat gradient vector into float32 arrays shaped like the f_phi parameters."""
        pieces = []
        element_start = 0
        for ly_shape in self.f_phi_params_shapes:
            element_num = 1
            for j in range(len(ly_shape)):
                element_num *= int(ly_shape[j])
            element_end = element_start + element_num
            pieces.append(np.array(flat_grad[element_start:element_end], dtype=np.float32).reshape(ly_shape))
            element_start = element_end
        return pieces

    def _build_weight_func(self, s, a, reuse=None, custom_getter=None):
        """Build the shaping weight network f_phi(s, a) under variable scope 'Weight_Func'.

        The concatenation [s, a] goes through one dense layer per entry of
        ``self.f_net_layers`` (each followed by layer norm and
        ``self.f_hidden_layer_act_func``), then a 1-unit dense output layer
        activated by ``self.f_output_layer_act_func``. 1 is added to the output
        when ``self.net_add_one`` is set. The dense layers are created with
        trainable=True only when ``reuse`` is None.
        """
        is_trainable = reuse is None
        hidden_init = tf.random_uniform_initializer(-1 / 8.0, 1 / 8.0)
        output_init = tf.random_uniform_initializer(-1 / 1000.0, 1 / 1000.0)
        with tf.variable_scope('Weight_Func', reuse=reuse, custom_getter=custom_getter):
            hidden = tf.concat([s, a], axis=1)
            for idx, units in enumerate(self.f_net_layers):
                hidden = tf.layers.dense(hidden, units,
                                         kernel_initializer=hidden_init,
                                         bias_initializer=hidden_init,
                                         name='l' + str(idx), trainable=is_trainable)
                hidden = self.f_hidden_layer_act_func(tf.contrib.layers.layer_norm(hidden))

            f_value = tf.layers.dense(hidden, 1, activation=self.f_output_layer_act_func,
                                      kernel_initializer=output_init,
                                      bias_initializer=output_init,
                                      name='f_value', trainable=is_trainable)
            return tf.add(f_value, 1) if self.net_add_one else f_value

    def _build_weight_func_reuse(self, s, a):
        """Rebuild f(s, a) on the existing 'Weight_Func' variables with trainable layers.

        This builds the same layer stack as ``_build_weight_func`` but always opens
        the scope with ``reuse=True`` and always passes ``trainable=True`` to the
        dense layers. (``_build_weight_func`` with a non-None ``reuse`` passes
        ``trainable=False``.) Because ``reuse=True`` is used, the 'Weight_Func'
        variables are expected to exist already, presumably created by a prior
        ``_build_weight_func`` call.

        Args:
            s: state tensor.
            a: action tensor.

        Returns:
            The one-unit output tensor, shifted by +1 when ``self.net_add_one``
            is truthy.
        """
        with tf.variable_scope('Weight_Func', reuse=True):
            hidden = tf.concat([s, a], axis=1)
            for layer_idx, units in enumerate(self.f_net_layers):
                hidden = tf.layers.dense(
                    hidden, units,
                    kernel_initializer=tf.random_uniform_initializer(-1 / 8.0, 1 / 8.0),
                    bias_initializer=tf.random_uniform_initializer(-1 / 8.0, 1 / 8.0),
                    name='l{}'.format(layer_idx),
                    trainable=True,
                )
                hidden = self.f_hidden_layer_act_func(tf.contrib.layers.layer_norm(hidden))

            f_value = tf.layers.dense(
                hidden, 1,
                activation=self.f_output_layer_act_func,
                kernel_initializer=tf.random_uniform_initializer(-1 / 1000.0, 1 / 1000.0),
                bias_initializer=tf.random_uniform_initializer(-1 / 1000.0, 1 / 1000.0),
                name='f_value',
                trainable=True,
            )
            if self.net_add_one:
                f_value = tf.add(f_value, 1)
        return f_value

    def _build_true_critic(self, s, action, reuse=None, custom_getter=None):
        """Build the critic Q(s, action) inside the 'True_Critic' scope.

        The state is passed through one dense layer per entry of
        ``self.critic_net_layers``, each followed by layer norm and ReLU. The action
        is concatenated to the running features just before the layer whose index
        equals ``self.critic_act_in_ly_index``. If no layer index matches, the
        action is never fed into the network.

        Args:
            s: state tensor.
            action: action tensor.
            reuse: forwarded to ``tf.variable_scope``. Any value other than None
                builds the dense layers with ``trainable=False``.
            custom_getter: forwarded to ``tf.variable_scope``.

        Returns:
            A single-unit output tensor.
        """
        is_trainable = reuse is None
        with tf.variable_scope('True_Critic', reuse=reuse, custom_getter=custom_getter):
            hidden = s
            for layer_idx, units in enumerate(self.critic_net_layers):
                if layer_idx == self.critic_act_in_ly_index:
                    hidden = tf.concat([hidden, action], axis=1)
                hidden = tf.layers.dense(hidden, units, name='l{}'.format(layer_idx),
                                         trainable=is_trainable)
                hidden = tf.nn.relu(tf.contrib.layers.layer_norm(hidden))
            q_value = tf.layers.dense(hidden, 1, trainable=is_trainable)
        return q_value

    def _build_shaped_critic(self, s, action, f_phi, reuse=None, custom_getter=None):
        """Build the shaped critic inside the 'Shaped_Critic' scope.

        The input is the state concatenated with ``f_phi`` along axis 1. That input
        is passed through one dense layer per entry of ``self.critic_net_layers``,
        each followed by layer norm and ReLU. The action is concatenated just
        before the layer whose index equals ``self.critic_act_in_ly_index``. If no
        layer index matches, the action is never fed into the network.

        Args:
            s: state tensor.
            action: action tensor.
            f_phi: tensor joined to ``s`` on axis 1 (presumably the weight-function
                output; confirm against the caller).
            reuse: forwarded to ``tf.variable_scope``. Any value other than None
                builds the dense layers with ``trainable=False``.
            custom_getter: forwarded to ``tf.variable_scope``.

        Returns:
            A single-unit output tensor.
        """
        is_trainable = reuse is None
        with tf.variable_scope('Shaped_Critic', reuse=reuse, custom_getter=custom_getter):
            hidden = tf.concat([s, f_phi], axis=1)
            for layer_idx, units in enumerate(self.critic_net_layers):
                if layer_idx == self.critic_act_in_ly_index:
                    hidden = tf.concat([hidden, action], axis=1)
                hidden = tf.layers.dense(hidden, units, name='l{}'.format(layer_idx),
                                         trainable=is_trainable)
                hidden = tf.nn.relu(tf.contrib.layers.layer_norm(hidden))
            q_value = tf.layers.dense(hidden, 1, trainable=is_trainable)
        return q_value

    def experience(self, one_exp):
        """Record one experience item in the current episode's trajectory buffer.

        The buffer is handed back and cleared by ``episode_done`` at the end of a
        non-test episode.

        Args:
            one_exp: the experience to store; must not be None.
        """
        assert one_exp is not None
        self.episode_traj.extend([one_exp])

    def episode_done(self, is_test):
        """Finish an episode and hand back its recorded trajectory.

        For a non-test episode, first decay the Gaussian exploration ratio. This
        happens only while the policy is being optimized and the exploration
        method is "GAUSSIAN_DYNAMIC", and the ratio is clamped below at
        ``self.gaussian_explo_sigma_ratio_min``. Then convert the buffered
        trajectory to a numpy array, clear the buffer, and return the array.

        Args:
            is_test: True for evaluation episodes.

        Returns:
            ``np.array`` of the buffered experiences for a non-test episode.
            For a test episode, None, and the buffer is left untouched.
        """
        if is_test:
            return None

        should_decay = self.optimize_policy and self.explo_method == "GAUSSIAN_DYNAMIC"
        if should_decay:
            decayed_ratio = self.gaussian_explo_ratio * self.gaussian_explo_decay_factor
            self.gaussian_explo_ratio = max(self.gaussian_explo_sigma_ratio_min, decayed_ratio)

        finished_traj = np.array(self.episode_traj)
        self.episode_traj = []
        return finished_traj

    def switch_optimization(self):
        """Toggle between the policy-optimization and weight-function phases.

        When switching back to policy optimization, discard the aggregated
        gradient of theta w.r.t. phi and reset its sample counter.

        When switching away from policy optimization, divide the aggregate by
        the number of aggregated samples, turning it into a mean. If nothing was
        aggregated (counter is 0 or the aggregate is None), leave the aggregate
        unchanged. Dividing in that case would yield NaN/inf (or raise on None)
        and corrupt the subsequent weight-function update.
        """
        self.optimize_policy = not self.optimize_policy
        if self.optimize_policy:
            # Start a fresh aggregation for the new policy-optimization phase.
            self.grad_theta_wrt_phi_aggr = None
            self.grad_aggr_num = 0
        elif self.grad_theta_wrt_phi_aggr is not None and self.grad_aggr_num > 0:
            # Average only when at least one sample was aggregated (avoid 0-division).
            self.grad_theta_wrt_phi_aggr = np.divide(self.grad_theta_wrt_phi_aggr, self.grad_aggr_num)
