import random
import numpy as np
import math
import tensorflow as tf
from ....utils.mlp_policy import MlpPolicyOprsV2
import experiments.utils.tf_util as U
from experiments.utils.common import Dataset, zipsame
from ..ppo_algo import PPOAlgo

# Defaults for the hyper-parameters read in set_algo_parameters();
# each one can be overridden through the keyword argument named in its comment.

# discount factor (kwarg "gamma")
GAMMA = 0.999

# learning rates
LR_ACTOR = 1e-4   # policy optimizer (kwarg "lr_actor")
LR_CRITIC = 2e-4  # shaped critic and true critic optimizers (kwarg "lr_critic")
LR_F = 1e-4       # optimizer of the shaping weight function f (kwarg "lr_f")
LR_H = 1e-2       # kwarg "lr_h"

# maximum global norm used when gradient clipping is enabled
ACTOR_GRADIENT_NORM_CLIP = 1.0
CRITIC_GRADIENT_NORM_CLIP = 1.0
F_GRADIENT_NORM_CLIP = 50.0

# PPO settings
RATIO_CLIP_PARAM = 0.2   # probability ratio is clipped to [1 - p, 1 + p]
ENTROPY_COEFF = 0.0      # weight of the entropy term in the policy loss
ADAM_EPSILON = 1e-5      # kwarg "adam_epsilon"
OPTIM_EPOCHS = 50        # passes over the collected batch per update
OPTIM_BATCH_SIZE = 1024  # minibatch size inside each pass

def my_flatten(x):
    """Flatten arbitrarily nested lists into one flat list of leaves.

    Only ``list`` objects (including list subclasses) are descended into;
    anything else (scalars, tuples, arrays, ...) is kept as a single leaf,
    so ``my_flatten(3)`` returns ``[3]``.

    The traversal is iterative: it neither re-copies the partial result at
    every nesting level nor depends on the interpreter's recursion limit.
    """
    if not isinstance(x, list):
        return [x]

    flat = []
    # stack of iterators, one per list that is currently being walked
    pending = [iter(x)]
    while pending:
        for item in pending[-1]:
            if isinstance(item, list):
                # descend; the outer iterator resumes after this sub-list
                pending.append(iter(item))
                break
            flat.append(item)
        else:
            # innermost iterator is exhausted
            pending.pop()
    return flat

"""
    For the PPO algorithm, the shaping weight function f must be updated
    in on-policy mode; we call this mode "fop" for short.
"""


class PPOOprsV3FsartAlgo(PPOAlgo):
    def __init__(self, sess, graph, state_space, action_space, algo_name="ppo_oprs_v3_fsart", **kwargs):
        """
            Create the algorithm in policy-optimization mode and hand every
            argument over to the PPOAlgo constructor.

            While optimize_policy is True, learn() dispatches to update_policy();
            otherwise it dispatches to update_shaping_weight_func().
        """
        # the mode flag is assigned first, so it already exists while the parent constructor runs
        self.optimize_policy = True

        parent = super(PPOOprsV3FsartAlgo, self)
        parent.__init__(sess, graph, state_space, action_space, algo_name, **kwargs)

    def set_algo_parameters(self, **kwargs):
        """
            Read all hyper-parameters from kwargs.

            Every value is optional; a missing key falls back to the
            module-level default or to the literal given below.
        """
        opt = kwargs.get

        # discount factor and learning rates
        self.gamma = opt("gamma", GAMMA)
        self.lr_actor = opt("lr_actor", LR_ACTOR)
        self.lr_critic = opt("lr_critic", LR_CRITIC)
        self.lr_f = opt("lr_f", LR_F)
        self.lr_h = opt("lr_h", LR_H)

        # gradient clipping: on/off switches, then the maximum global norms
        self.actor_grad_clip = opt("actor_gradient_clip", True)
        self.critic_grad_clip = opt("critic_gradient_clip", True)
        self.f_grad_clip = opt("f_gradient_clip", True)
        self.actor_grad_norm_clip = opt("actor_gradient_norm_clip", ACTOR_GRADIENT_NORM_CLIP)
        self.critic_grad_norm_clip = opt("critic_gradient_norm_clip", CRITIC_GRADIENT_NORM_CLIP)
        self.f_grad_norm_clip = opt("f_gradient_norm_clip", F_GRADIENT_NORM_CLIP)

        # PPO optimization settings
        self.entropy_coeff = opt("entropy_coeff", ENTROPY_COEFF)
        self.ratio_clip_param = opt("ratio_clip_param", RATIO_CLIP_PARAM)
        self.adam_epsilon = opt("adam_epsilon", ADAM_EPSILON)
        self.optim_epochs = opt("optim_epochs", OPTIM_EPOCHS)
        self.optim_batch_size = opt("optim_batch_size", OPTIM_BATCH_SIZE)

        # network architectures (hidden layer sizes)
        self.policy_net_layers = opt("policy_net_layers", [8, 8])
        self.v_net_layers = opt("v_net_layers", [32, 32])
        self.f_net_layers = opt("f_net_layers", [16, 8])
        self.gaussian_fixed_var = opt("gaussian_fixed_var", False)

        self.joint_opt = opt("joint_opt", False)

        # whether a separate "old" copy of f is built and kept in sync with f
        self.use_f_old = opt("use_f_old", False)

        # whether a hessian op is built in define_f_loss() and the second half
        # of the gradient of theta w.r.t. phi is computed in update_policy()
        self.enable_hessian_computing = opt("enable_hessian_computing", False)

        # when the hessian is enabled: use the outer-product-of-gradients op
        # (get_logpi_opg_op) instead of get_logpi_hessian_op
        self.hessian_opg_approx = opt("hessian_opg_approx", False)

        # options of the network of f
        self.net_add_one = opt("net_add_one", False)
        self.f_hidden_layer_act_func = opt("f_hidden_layer_act_func", tf.nn.relu)
        self.f_output_layer_act_func = opt("f_output_layer_act_func", None)

    def init_networks(self):
        """
            Build the complete graph and create the checkpoint saver.

            The build steps run in a fixed order, because each step uses
            attributes created by the previous ones.
        """
        # f goes first: policy_fn() reads the f_phi_phd placeholder created there
        self.init_f_network()
        self.init_ppo_networks()

        # losses of policy and critics, then the quantities needed for updating f
        self.define_losses()
        self.define_f_loss()

        # finally one optimizer per network
        self.define_optimizers()

        with self.sess.as_default(), self.graph.as_default():
            self.saver = tf.train.Saver(max_to_keep=100)

    def policy_fn(self, name):
        """
            Create the MLP policy called `name`.

            The policy receives the value of f through the f_phi_phd placeholder.
        """
        net_config = dict(name=name,
                          ob_space=self.state_space,
                          ac_space=self.action_space,
                          policy_net_layers=self.policy_net_layers,
                          v_net_layers=self.v_net_layers,
                          gaussian_fixed_var=self.gaussian_fixed_var,
                          f_output=self.f_phi_phd)
        return MlpPolicyOprsV2(self.sess, self.graph, **net_config)

    def policy_fn_old(self, name):
        """
            Create the MLP policy called `name` for the old policy when use_f_old is enabled.

            NOTE(review): this currently builds exactly what policy_fn() builds;
            it is also fed from f_phi_phd and not from f_old - confirm this is intended.
        """
        net_config = dict(name=name,
                          ob_space=self.state_space,
                          ac_space=self.action_space,
                          policy_net_layers=self.policy_net_layers,
                          v_net_layers=self.v_net_layers,
                          gaussian_fixed_var=self.gaussian_fixed_var,
                          f_output=self.f_phi_phd)
        return MlpPolicyOprsV2(self.sess, self.graph, **net_config)

    def init_f_network(self):
        """
            Create the placeholders and the networks of the shaping weight function f.

            Builds f_phi(s, a), optionally its old copy f_old (when use_f_old is set),
            and a weight-sharing copy traj_f_phi that is evaluated on the
            state-action pairs of whole trajectories.
        """
        with self.sess.as_default(), self.graph.as_default():
            def as_net_input(action_tensor):
                # discrete actions enter the network one-hot encoded, continuous ones unchanged
                if self.is_discrete:
                    return tf.one_hot(action_tensor, self.action_dim)
                return action_tensor

            # dtype and shape shared by both action placeholders
            if self.is_discrete:
                action_dtype, action_shape = tf.int32, [None, ]
            else:
                action_dtype, action_shape = tf.float32, [None, self.action_space.shape[0]]

            self.state_phd = U.get_placeholder_with_graph(name="ob",
                                                          dtype=tf.float32,
                                                          shape=[None] + list(self.state_space.shape),
                                                          graph=self.graph)
            self.f_phi_phd = U.get_placeholder_with_graph(name="f_phi_phd",
                                                          dtype=tf.float32,
                                                          shape=[None, 1],
                                                          graph=self.graph)

            # the shaping weight function f_phi(s, a)
            self.action_phd_for_f = tf.placeholder(action_dtype, action_shape, name='action_for_f')
            self.f_phi = self._build_weight_func(self.state_phd, as_net_input(self.action_phd_for_f))

            # the old shaping weight function f_old, meant as the input of pi_old
            if self.use_f_old:
                self.f_old = self._build_weight_func_old(self.state_phd, as_net_input(self.action_phd_for_f))

            # nabla_{phi} R_{tau}(s,a) needs the gradients of the state-action pairs
            # along the trajectories, which get their own placeholders
            self.traj_state_phd = tf.placeholder(tf.float32, [None, self.state_space.shape[0]], name='traj_state')
            self.traj_step_phd = tf.placeholder(tf.float32, [None, 1], name='traj_step')
            self.traj_F_phd = tf.placeholder(tf.float32, [None, 1], name='traj_F')

            self.traj_action_phd = tf.placeholder(action_dtype, action_shape, name='traj_action')
            self.traj_f_phi = self._build_weight_func_reuse(self.traj_state_phd,
                                                            as_net_input(self.traj_action_phd))

    def init_ppo_networks(self):
        """
            Create the new and the old policy, the training placeholders and
            the function that copies the new policy's variables into the old one.
        """
        with self.sess.as_default(), self.graph.as_default():
            # the policy network pi is defined over the state-f-action space
            self.pi = self.policy_fn("pi")
            build_old_policy = self.policy_fn_old if self.use_f_old else self.policy_fn
            self.pi_old = build_old_policy("oldpi")

            # shaped advantage target and shaped empirical return
            self.atarg = tf.placeholder(dtype=tf.float32, shape=[None])
            self.ret = tf.placeholder(dtype=tf.float32, shape=[None])

            # true advantage target and true return
            self.atarg_true = tf.placeholder(dtype=tf.float32, shape=[None], name="adv_true")
            self.ret_true = tf.placeholder(dtype=tf.float32, shape=[None], name="ret_true")

            # learning rate multiplier, updated with schedule
            self.lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])

            # observation placeholder (looked up by name) and action placeholder
            self.state_phd = U.get_placeholder_cached(name="ob")
            self.action_phd = self.pi.pdtype.sample_placeholder([None])

            # parameter assignment: old policy <- new policy
            copy_ops = [tf.assign(old_var, new_var)
                        for old_var, new_var in zipsame(self.pi_old.get_variables(),
                                                        self.pi.get_variables())]
            self.assign_old_eq_new = U.function([], [],
                                                updates=copy_ops,
                                                sess=self.sess,
                                                graph=self.graph)

    def define_losses(self):
        """
            Collect the parameter groups and build all loss tensors.

            Two clipped PPO surrogates are created from the same probability ratio:
            pol_surr uses the shaped advantage and trains the policy,
            pol_surr_true uses the true advantage and is used for nabla_{phi} J.
            vf_loss / vf_true_loss are the squared errors of the shaped / true critic.
        """
        with self.sess.as_default(), self.graph.as_default():
            pi, pi_old = self.pi, self.pi_old

            # parameter groups
            self.policy_params = pi.get_policy_variables()
            self.policy_params_num = U.flatten_tensors(self.policy_params).get_shape().as_list()[0]
            self.critic_params = pi.get_critic_variables()
            self.true_critic_params = pi.get_true_critic_variables()
            self.f_phi_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Weight_Func')
            if self.use_f_old:
                self.f_old_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Weight_Func_Old')

            # KL(old || new) and entropy of the new policy
            self.kl_old_new = pi_old.pd.kl(pi.pd)
            self.ent = pi.pd.entropy()
            self.mean_kl = tf.reduce_mean(self.kl_old_new)
            self.mean_ent = tf.reduce_mean(self.ent)
            self.pol_ent_pen = (-self.entropy_coeff) * self.mean_ent

            # pnew / pold, with the log-ratio limited to [-40, 40] before exponentiation
            log_ratio = pi.pd.logp(self.action_phd) - pi_old.pd.logp(self.action_phd)
            self.ratio = tf.exp(tf.minimum(40.0, tf.maximum(-40.0, log_ratio)))

            def clipped_surrogate(advantage):
                # surrogate from conservative policy iteration; wherever ratio * advantage
                # is inf or nan, the plain advantage is used instead
                weighted = self.ratio * advantage
                invalid = tf.logical_or(tf.is_inf(weighted), tf.is_nan(weighted))
                unclipped = tf.where(invalid, advantage, weighted)

                clipped_ratio = tf.clip_by_value(self.ratio,
                                                 1.0 - self.ratio_clip_param,
                                                 1.0 + self.ratio_clip_param)
                clipped = clipped_ratio * advantage

                # PPO's pessimistic surrogate (L^CLIP), negated for minimization
                return unclipped, clipped, - tf.reduce_mean(tf.minimum(unclipped, clipped))

            # surrogate of the shaped advantage
            self.surr1, self.surr2, self.pol_surr = clipped_surrogate(self.atarg)

            # loss of shaped value function
            self.vf_loss = tf.reduce_mean(tf.square(pi.vpred - self.ret))
            self.total_loss = self.pol_surr + self.pol_ent_pen + self.vf_loss

            # loss of true value function
            self.vf_true_loss = tf.reduce_mean(tf.square(pi.vpred_true - self.ret_true))

            # the true loss of policy, which is used for computing nabla_{phi} J
            self.surr1_f, self.surr2_f, self.pol_surr_true = clipped_surrogate(self.atarg_true)

            self.losses = [self.pol_surr, self.pol_ent_pen, self.vf_loss, self.vf_true_loss,
                           self.mean_kl, self.mean_ent]
            self.loss_names = ["pol_surr", "pol_entpen", "vf_loss", "vf_true_loss", "kl", "ent"]

    def define_f_loss(self):
        """
            Build the tensors and the bookkeeping state needed for updating f.

            The gradient used for optimizing phi is

                nabla_{phi} J = { nabla_{theta} log pi_{theta}(s,a) } *
                                { nabla_{phi} theta } *
                                Q_{True}(s, a)

                              = { nabla_{theta} log pi_{theta}(s,a) } *
                                alpha * { nabla_{theta'} log pi_{theta'}(s',a') } * { nabla_{phi} R_{tau} } *
                                Q_{True}(s, a)

                              = { nabla_{theta} log pi_{theta}(s,a) } *
                                alpha * { nabla_{theta'} log pi_{theta'}(s',a') } *
                                { sum_{i=0}^{|tau|-1 } gamma^i F(s_i,a_i) nabla_{phi} f_{phi}(s_i) } *
                                Q_{True}(s, a)
        """
        with self.sess.as_default(), self.graph.as_default():
            # grad_return_wrt_phi = sum_{i=0}^{|tau|-1} gamma^i F(s_i,a_i) nabla_{phi} f_{phi}(s_i)
            #
            # traj_f_phi has shape [None, 1] (one output per input state), but tf.gradients
            # sums the gradients over all inputs; the result is a list with one tensor
            # per parameter of f, i.e. [array(), array(), ..., array()] once evaluated
            discounted_F = tf.multiply(tf.pow(self.gamma, self.traj_step_phd), self.traj_F_phd)
            self.grad_return_wrt_phi = tf.gradients(ys=self.traj_f_phi,
                                                    xs=self.f_phi_params,
                                                    grad_ys=discounted_F)

            # remember (and report) the shape of every parameter of f and of the policy
            self.f_phi_params_shapes = [param.shape for param in self.f_phi_params]
            self.theta_params_shapes = [param.shape for param in self.policy_params]
            for idx, shape in enumerate(self.f_phi_params_shapes):
                print("The f_phi param {} shape is {}".format(idx, shape))
            for idx, shape in enumerate(self.theta_params_shapes):
                print("The policy param {} shape is {}".format(idx, shape))

            # logpi = log pi_{theta}(s,a) and its gradient w.r.t. the policy parameters theta;
            # this gradient serves as both { nabla_{theta'} log pi_{theta'}(s',a') }
            # and { nabla_{theta} log pi_{theta}(s,a) }
            self.logpi = self.pi.pd.logp(self.action_phd)
            self.grad_logpi_wrt_theta = tf.gradients(ys=self.pi.pd.logp(self.action_phd),
                                                     xs=self.policy_params)

            # h: the gradient of theta w.r.t. phi, accumulated step by step
            self.h = None
            if self.enable_hessian_computing:
                self.hessian_theta = (self.get_logpi_opg_op(1)
                                      if self.hessian_opg_approx
                                      else self.get_logpi_hessian_op(self.policy_params_num))

            # gradient of theta w.r.t. phi, aggregated at every update of the policy parameters
            # and to be averaged before the update of the shaping weight function
            self.grad_theta_wrt_phi_aggr = None
            self.grad_aggr_num = 0

            # policy gradient of the true loss, used for optimizing the weight function parameter phi:
            # { nabla_{theta} log pi_{theta}(s,a) } * Q_{True}(s,a) * { nabla_{phi} theta }
            true_loss_grads = tf.gradients(ys=self.pol_surr_true, xs=self.policy_params)
            self.grad_true_loss_wrt_theta, _ = tf.clip_by_global_norm(true_loss_grads,
                                                                      self.actor_grad_norm_clip)

            # start nabla_{phi} J as all-zero arrays, one per parameter of f
            self.grad_J_wrt_phi = [np.full(shape=shape, fill_value=0.0, dtype=np.float32)
                                   for shape in self.f_phi_params_shapes]

    def define_optimizers(self):
        """
            Create one Adam optimizer and one train op for the policy, the shaped critic,
            the true critic and the shaping weight function f, plus the loss evaluation function.

            The gradients of f are not derived from a loss tensor: they are fed from
            outside through one placeholder per parameter of f.

            NOTE(review): adam_epsilon is read in set_algo_parameters() but is not passed
            to any optimizer here - confirm whether that is intended.
        """
        with self.sess.as_default(), self.graph.as_default():
            def clipped_trainer(optimizer, loss, params, max_norm):
                # gradients of loss w.r.t. params, clipped by global norm, and the op applying them
                raw_grads = tf.gradients(loss, params)
                clipped_grads, _ = tf.clip_by_global_norm(raw_grads, max_norm)
                return raw_grads, clipped_grads, optimizer.apply_gradients(zip(clipped_grads, params))

            # policy optimizer and policy trainer
            self.policy_opt = tf.train.AdamOptimizer(self.lr_actor)
            if not self.actor_grad_clip:
                self.policy_trainer = self.policy_opt.minimize(loss=self.total_loss,
                                                               var_list=self.policy_params)
            else:
                (self.policy_gradients,
                 self.policy_clipped_grad_op,
                 self.policy_trainer) = clipped_trainer(self.policy_opt, self.total_loss,
                                                        self.policy_params, self.actor_grad_norm_clip)

            # critic optimizer and critic trainer
            self.critic_opt = tf.train.AdamOptimizer(self.lr_critic)
            if not self.critic_grad_clip:
                self.critic_trainer = self.critic_opt.minimize(self.vf_loss, var_list=self.critic_params)
            else:
                (self.critic_gradients,
                 self.critic_clipped_grad_op,
                 self.critic_trainer) = clipped_trainer(self.critic_opt, self.vf_loss,
                                                        self.critic_params, self.critic_grad_norm_clip)

            # true critic optimizer and true critic trainer
            self.true_critic_opt = tf.train.AdamOptimizer(self.lr_critic)
            if not self.critic_grad_clip:
                self.true_critic_trainer = self.true_critic_opt.minimize(self.vf_true_loss,
                                                                         var_list=self.true_critic_params)
            else:
                (self.true_critic_gradients,
                 self.true_critic_clipped_grad_op,
                 self.true_critic_trainer) = clipped_trainer(self.true_critic_opt, self.vf_true_loss,
                                                             self.true_critic_params,
                                                             self.critic_grad_norm_clip)

            # another way for updating phi: externally computed gradients are fed in
            self.optimizer_f = tf.train.AdamOptimizer(self.lr_f)
            self.f_phi_params_grad_phds = [
                tf.placeholder(tf.float32, shape=self.f_phi_params_shapes[ly], name="f_phi_phd_{}".format(ly))
                for ly in range(len(self.f_phi_params))
            ]

            if self.f_grad_clip:
                self.f_clipped_grad_op, _ = tf.clip_by_global_norm(self.f_phi_params_grad_phds,
                                                                   self.f_grad_norm_clip)
                self.trainer_f = self.optimizer_f.apply_gradients(zip(self.f_clipped_grad_op,
                                                                      self.f_phi_params))
            else:
                self.trainer_f = self.optimizer_f.apply_gradients(zip(self.f_phi_params_grad_phds,
                                                                      self.f_phi_params))

                # NOTE(review): the shapes are only reported in this no-clipping branch
                for ly, (param, grad_phd) in enumerate(zip(self.f_phi_params, self.f_phi_params_grad_phds)):
                    print("f_phi param {} shape is {}".format(ly, param.shape))
                    print("f_phi_param_grad_phd {} shape is {}".format(ly, grad_phd.shape))

            loss_inputs = [self.state_phd, self.action_phd,
                           self.atarg, self.ret, self.lrmult,
                           self.atarg_true, self.ret_true]
            self.compute_losses = U.function(loss_inputs,
                                             self.losses,
                                             sess=self.sess,
                                             graph=self.graph)

    def choose_action(self, s, is_test):
        """
            Sample an action for the single state s and evaluate f and both critics for it.

            :param s: one state; it is fed to f as a batch of size one
            :param is_test: not used here, the action is always sampled stochastically
            :return: (action, info) where info holds "v_pred", "v_pred_true" and "f_phi_s"
        """
        with self.graph.as_default():
            action = self.pi.act(stochastic=True, ob=s)[0]

            # value of the shaping weight function for this state-action pair
            f_feed = {self.state_phd: [s],
                      self.action_phd_for_f: [action]}
            f_phi_sa = self.sess.run(self.f_phi, f_feed)
            if len(f_phi_sa.shape) == 2:
                # unwrap the [1, 1] network output
                f_phi_sa = f_phi_sa[0][0]

            assert not (np.isnan(f_phi_sa) or math.isnan(f_phi_sa))

            # the critics take the value of f as an additional input
            vpred, vpred_true = self.pi.v_predict(ob=s, f_output=f_phi_sa)
            info = {"v_pred": vpred, "v_pred_true": vpred_true, "f_phi_s": f_phi_sa}
            return action, info

    def learn(self, **kwargs):
        """
            Run one learning step in the current mode.

            optimize_policy selects the step: the policy update when it is True,
            the update of the shaping weight function otherwise.
            All keyword arguments are forwarded unchanged.
        """
        step = self.update_policy if self.optimize_policy else self.update_shaping_weight_func
        step(**kwargs)

    def update_policy(self, **kwargs):
        """
            Run one PPO update of the policy and the shaped critic, then add this
            update's contribution to the aggregated gradient of theta w.r.t. phi.

            Keyword arguments (all read with kwargs.get):
                ob, ac                         states and actions of the trajectory batch
                adv, td_lam_ret                shaped advantages and td-lambda returns
                batch_f_phi                    value of f per sample, fed to f_phi_phd
                mini_batches,
                grad_compute_s_batch,
                grad_compute_a_batch,
                grad_compute_shaped_ret_batch  samples used for the gradient of theta w.r.t. phi
                                               (mini_batches and grad_compute_s_batch must
                                               have the same length)

            Side effects: trains the policy and the shaped critic for optim_epochs passes,
            writes summaries, increases update_cnt per minibatch and grad_aggr_num per call,
            and accumulates into grad_theta_wrt_phi_aggr.
        """
        with self.graph.as_default():
            """
                oprs-v2 learns with trajectory batch, which is organized in the trainer
                use all samples to optimize actor

                get:
                advantage values
                td-lambda returns
                state value predictions
            """
            bs = kwargs.get("ob")
            ba = kwargs.get("ac")
            batch_adv = kwargs.get("adv")
            batch_td_lam_ret = kwargs.get("td_lam_ret")
            batch_f_phi = kwargs.get("batch_f_phi")

            """
                standardized advantage function estimate

                when all advantages are equal the standard deviation is zero;
                the values are then only centered, because dividing would turn
                the whole batch into NaN
            """
            adv_std = batch_adv.std()
            batch_adv = batch_adv - batch_adv.mean()
            if adv_std > 0:
                batch_adv = batch_adv / adv_std

            """
                note that ppo has no replay buffer
                so construct data set immediately
            """
            d = Dataset(dict(bs=bs, ba=ba, badv=batch_adv, bret=batch_td_lam_ret, bf=batch_f_phi),
                        deterministic=self.pi.recurrent)

            batch_size = self.optim_batch_size or bs.shape[0]

            if hasattr(self.pi, "ob_rms"):
                self.pi.ob_rms.update(bs)

            """
                set old parameter values to new parameter values
            """
            self.assign_old_eq_new()

            """
                set f parameters values to f_old parameters

                the assign ops are created once and reused afterwards:
                building them on every call would add new nodes to the graph at each update
            """
            if self.use_f_old:
                if getattr(self, "_f_old_sync_ops", None) is None:
                    self._f_old_sync_ops = [old_param.assign(param) for old_param, param in
                                            zip(self.f_old_params, self.f_phi_params)]
                self.sess.run(self._f_old_sync_ops)

            """
                Here we do a bunch of optimization epochs over the data
            """
            for _ in range(self.optim_epochs):
                for batch in d.iterate_once(batch_size):
                    _, _, policy_loss, v_loss, kl = self.sess.run([self.policy_trainer, self.critic_trainer,
                                                                   self.pol_surr, self.vf_loss, self.mean_kl],
                                                                  feed_dict={
                                                                      self.state_phd: batch["bs"],
                                                                      self.action_phd: batch["ba"],
                                                                      self.atarg: batch["badv"],
                                                                      self.ret: batch["bret"],
                                                                      self.action_phd_for_f: batch["ba"],
                                                                      self.f_phi_phd: batch["bf"].reshape([-1, 1])
                                                                  })

                    """
                        write summary
                    """
                    self.write_summary_scalar(self.update_cnt, "policy_loss", policy_loss)
                    self.write_summary_scalar(self.update_cnt, "v_loss", v_loss)
                    self.write_summary_scalar(self.update_cnt, "mean_kl", kl)
                    self.update_cnt += 1

            """
                for each state in the state batch
                compute { nabla_{theta} log pi_{theta}(s,a) } and { nabla_{phi} R_{tau}(s,a) }

                1. compute nabla_{phi} R_{tau}(s, a) using the minibatch of the state
                2. compute nabla_{theta} log pi_{theta}(s, a)
                3. conduct matrix multiplication of nabla_{phi} R_{tau}(s, a) and nabla_{theta} log pi_{theta}(s, a)
                4. sum over the multiplication results of all states

                only a few samples are picked for computing the gradient of theta w.r.t. phi
            """
            mini_batches = kwargs.get("mini_batches")
            grad_compute_s_batch = kwargs.get("grad_compute_s_batch")
            grad_compute_a_batch = kwargs.get("grad_compute_a_batch")

            grad_compute_shaped_ret_batch = kwargs.get("grad_compute_shaped_ret_batch")

            assert len(mini_batches) == len(grad_compute_s_batch)

            """
                the first part of nabla_{phi} Delta(theta)
                which is { nabla_{theta} logpi_{theta}(s,a) nabla_{phi} R_{tau}(s,a) }
            """
            grad_delta_theta_wrt_phi = self.compute_grad_delta_theta_phi_first_half(grad_compute_s_batch,
                                                                                    grad_compute_a_batch,
                                                                                    mini_batches)

            """
                compute the second part of the gradient nabla_{phi} delta_{theta}
                if computing the hessian matrix of theta is allowed,
                which is { hessian_{theta} logpi_{theta}(s,a) * h_t } * R_{tau}(s,a)
            """
            if self.enable_hessian_computing:
                grad_delta_theta_wrt_phi_2p = self.compute_grad_delta_theta_phi_second_half(grad_compute_s_batch,
                                                                                            grad_compute_a_batch,
                                                                                            grad_compute_shaped_ret_batch)
                if grad_delta_theta_wrt_phi_2p is not None:
                    grad_delta_theta_wrt_phi = grad_delta_theta_wrt_phi + grad_delta_theta_wrt_phi_2p

            """
                then aggregate the gradient of this update
            """
            self.grad_aggr_num += 1
            if self.grad_theta_wrt_phi_aggr is None:
                self.grad_theta_wrt_phi_aggr = grad_delta_theta_wrt_phi
            else:
                self.grad_theta_wrt_phi_aggr = self.grad_theta_wrt_phi_aggr + grad_delta_theta_wrt_phi

    def update_shaping_weight_func(self, **kwargs):
        """
            On-policy ("fop") update of the shaping weight function f_phi.
            On every mini-batch the f_phi trainer is run first and the
            true critic trainer second.

            Keyword Args:
                ob: batch of states
                ac: batch of actions
                adv_true: true advantage values (standardized here before use)
                ret_true: true returns, fed to the true critic trainer

            self.h must already be set; the method asserts this.
        """
        with self.graph.as_default():
            assert self.h is not None

            states = kwargs.get("ob")
            actions = kwargs.get("ac")
            true_adv = kwargs.get("adv_true")
            true_ret = kwargs.get("ret_true")

            # standardize the advantage estimates
            true_adv = (true_adv - true_adv.mean()) / true_adv.std()

            # ppo keeps no replay buffer, so the data set is built right away
            samples = dict(bs=states, ba=actions, badv=true_adv, bret=true_ret)
            data_set = Dataset(samples, deterministic=self.pi.recurrent)

            batch_size = self.optim_batch_size or states.shape[0]

            # (syncing pi_old with pi is intentionally not done here)
            # self.assign_old_eq_new()

            # copy the current f parameters into the f_old parameters
            if self.use_f_old:
                sync_ops = [old_param.assign(param) for old_param, param in
                            zip(self.f_old_params, self.f_phi_params)]
                self.sess.run(sync_ops)

            # several optimization epochs over the same data
            for _ in range(self.optim_epochs):
                for batch in data_set.iterate_once(batch_size):
                    policy_feed = {self.state_phd: batch["bs"],
                                   self.action_phd: batch["ba"],
                                   self.atarg_true: batch["badv"]}

                    # E_{s ~ rho} [nabla_{theta} log pi_{theta}(s,a) * Q_{True}(s,a)];
                    # the surrogate loss is fetched as well but not used
                    theta_grads, _ = self.sess.run([self.grad_true_loss_wrt_theta,
                                                    self.pol_surr_true],
                                                   policy_feed)

                    # one flat vector over all policy parameters
                    flat_theta_grad = np.array(my_flatten([g.tolist() for g in theta_grads]))

                    # nabla_{phi} J_{True} = (flat theta gradient) x h
                    flat_phi_grad = np.matmul(np.atleast_2d(flat_theta_grad), self.h)
                    if flat_phi_grad.ndim == 2 and flat_phi_grad.shape[0] == 1:
                        flat_phi_grad = flat_phi_grad[0]

                    # cut the flat vector back into pieces shaped like the phi parameters
                    segments = []
                    offset = 0
                    for param_shape in self.f_phi_params_shapes:
                        size = 1
                        for dim in param_shape:
                            size *= int(dim)

                        piece = flat_phi_grad[offset:offset + size]
                        segments.append(np.array(piece, dtype=np.float32).reshape(param_shape))
                        offset += size
                    self.grad_J_wrt_phi = segments

                    # feed the recovered gradients to the f_phi trainer
                    f_feed = {self.f_phi_params_grad_phds[ly]: self.grad_J_wrt_phi[ly]
                              for ly in range(len(self.f_phi_params))}
                    self.sess.run(self.trainer_f, feed_dict=f_feed)

                    # update the true critic on the same mini-batch
                    critic_feed = {
                        self.state_phd: batch["bs"],
                        self.action_phd: batch["ba"],
                        self.atarg_true: batch["badv"],
                        self.ret_true: batch["bret"]
                    }
                    _, v_true_loss = self.sess.run([self.true_critic_trainer, self.vf_true_loss],
                                                   feed_dict=critic_feed)

                    # write summary
                    self.write_summary_scalar(self.update_cnt, "v_true_loss", v_true_loss)
                    self.update_cnt += 1

    def _build_weight_func(self, s, a, reuse=None, custom_getter=None):
        """
            Build the shaping weight network f(s, a) under the variable scope
            'Weight_Func'.

            Args:
                s: state tensor
                a: action tensor, concatenated with s along axis 1
                reuse: forwarded to tf.variable_scope; the dense layers are
                    marked trainable only when this is None
                custom_getter: forwarded to tf.variable_scope

            Returns:
                output tensor of the final one-unit dense layer
                (plus 1 when self.net_add_one is set)
        """
        trainable = reuse is None
        hidden_bound = 1 / 8.0
        output_bound = 1 / 1000.0

        with tf.variable_scope('Weight_Func', reuse=reuse, custom_getter=custom_getter):
            hidden = tf.concat([s, a], axis=1)

            # hidden stack: dense -> layer norm -> activation
            for layer_idx, units in enumerate(self.f_net_layers):
                hidden = tf.layers.dense(
                    hidden, units,
                    kernel_initializer=tf.random_uniform_initializer(-hidden_bound, hidden_bound),
                    bias_initializer=tf.random_uniform_initializer(-hidden_bound, hidden_bound),
                    name='l' + str(layer_idx), trainable=trainable)
                hidden = tf.contrib.layers.layer_norm(hidden)
                hidden = self.f_hidden_layer_act_func(hidden)

            # single-unit output head with a much narrower initialization range
            output = tf.layers.dense(
                hidden, 1, activation=self.f_output_layer_act_func,
                kernel_initializer=tf.random_uniform_initializer(-output_bound, output_bound),
                bias_initializer=tf.random_uniform_initializer(-output_bound, output_bound),
                name='f_value', trainable=trainable)

            if self.net_add_one:
                output = tf.add(output, 1)

            return output

    def _build_weight_func_reuse(self, s, a):
        """
            Build the shaping weight network f(s, a) again under the variable
            scope 'Weight_Func', always with reuse=True and trainable layers.

            Args:
                s: state tensor
                a: action tensor, concatenated with s along axis 1

            Returns:
                output tensor of the final one-unit dense layer
                (plus 1 when self.net_add_one is set)
        """
        hidden_bound = 1 / 8.0
        output_bound = 1 / 1000.0

        with tf.variable_scope('Weight_Func', reuse=True):
            hidden = tf.concat([s, a], axis=1)

            # hidden stack: dense -> layer norm -> activation
            for layer_idx, units in enumerate(self.f_net_layers):
                hidden = tf.layers.dense(
                    hidden, units,
                    kernel_initializer=tf.random_uniform_initializer(-hidden_bound, hidden_bound),
                    bias_initializer=tf.random_uniform_initializer(-hidden_bound, hidden_bound),
                    name='l' + str(layer_idx), trainable=True)
                hidden = tf.contrib.layers.layer_norm(hidden)
                hidden = self.f_hidden_layer_act_func(hidden)

            # single-unit output head with a much narrower initialization range
            output = tf.layers.dense(
                hidden, 1, activation=self.f_output_layer_act_func,
                kernel_initializer=tf.random_uniform_initializer(-output_bound, output_bound),
                bias_initializer=tf.random_uniform_initializer(-output_bound, output_bound),
                name='f_value', trainable=True)

            if self.net_add_one:
                output = tf.add(output, 1)

            return output

    def _build_weight_func_old(self, s, a, reuse=None, custom_getter=None):
        """
            Build the "old" copy of the shaping weight network under the
            variable scope 'Weight_Func_Old'.

            Args:
                s: state tensor
                a: action tensor, concatenated with s along axis 1
                reuse: forwarded to tf.variable_scope; the dense layers are
                    marked trainable only when this is None
                custom_getter: forwarded to tf.variable_scope

            Returns:
                output tensor of the final one-unit dense layer
                (plus 1 when self.net_add_one is set)
        """
        trainable = reuse is None
        hidden_bound = 1 / 8.0
        output_bound = 1 / 1000.0

        with tf.variable_scope('Weight_Func_Old', reuse=reuse, custom_getter=custom_getter):
            hidden = tf.concat([s, a], axis=1)

            # hidden stack: dense -> layer norm -> activation
            for layer_idx, units in enumerate(self.f_net_layers):
                hidden = tf.layers.dense(
                    hidden, units,
                    kernel_initializer=tf.random_uniform_initializer(-hidden_bound, hidden_bound),
                    bias_initializer=tf.random_uniform_initializer(-hidden_bound, hidden_bound),
                    name='l' + str(layer_idx), trainable=trainable)
                hidden = tf.contrib.layers.layer_norm(hidden)
                hidden = self.f_hidden_layer_act_func(hidden)

            # single-unit output head with a much narrower initialization range
            output = tf.layers.dense(
                hidden, 1, activation=self.f_output_layer_act_func,
                kernel_initializer=tf.random_uniform_initializer(-output_bound, output_bound),
                bias_initializer=tf.random_uniform_initializer(-output_bound, output_bound),
                name='f_value', trainable=trainable)

            if self.net_add_one:
                output = tf.add(output, 1)

            return output

    def switch_optimization(self):
        """
            Toggle self.optimize_policy and do the bookkeeping for the new phase.

            Switching back to policy optimization clears the aggregated
            gradient of theta w.r.t. phi and its counter.

            Switching away from policy optimization averages the aggregated
            gradient over self.grad_aggr_num and folds it into the
            accumulative gradient variable h:
                h <- aggr * lr_h                      (first time, h is None)
                h <- h * (1 - lr_h) + aggr * lr_h     (afterwards)

            If nothing was aggregated, h is left unchanged instead of
            dividing None by zero.
        """
        self.optimize_policy = not self.optimize_policy

        if self.optimize_policy:
            # back to optimizing the policy: start a fresh aggregation
            self.grad_theta_wrt_phi_aggr = None
            self.grad_aggr_num = 0
            return

        if self.grad_theta_wrt_phi_aggr is None or self.grad_aggr_num == 0:
            # no gradient was collected during the policy phase, so there is
            # nothing to average and nothing to add to h
            return

        # in oprs-v3, the (averaged) gradient of theta w.r.t. phi is blended
        # into the accumulative gradient variable h
        self.grad_theta_wrt_phi_aggr = np.divide(self.grad_theta_wrt_phi_aggr, self.grad_aggr_num)
        weighted_new = np.multiply(self.grad_theta_wrt_phi_aggr, self.lr_h)
        if self.h is None:
            self.h = weighted_new
        else:
            self.h = np.multiply(self.h, 1 - self.lr_h) + weighted_new

    def get_logpi_Hv_op(self, v):
        """
            Hessian-vector product op for logpi w.r.t. the policy parameters theta.

            The gradient of self.logpi is clipped by global norm, flattened and
            multiplied element-wise with v (treated as a constant); the result
            is differentiated once more w.r.t. theta, clipped and flattened.

            Args:
                v: vector to multiply with the Hessian (tensor)

            Returns:
                Hv_op: flattened Hessian vector product op (tensor)
        """
        theta = self.policy_params
        max_norm = self.actor_grad_norm_clip

        # first-order gradient, norm-limited, as one flat tensor
        clipped_grad, _ = tf.clip_by_global_norm(tf.gradients(self.logpi, theta), max_norm)
        flat_grad = U.flatten_tensors(clipped_grad)

        # no gradient may flow back into v
        grad_times_v = tf.math.multiply(flat_grad, tf.stop_gradient(v))

        # differentiate again and limit the norm once more
        clipped_hv, _ = tf.clip_by_global_norm(tf.gradients(grad_times_v, theta), max_norm)
        return U.flatten_tensors(clipped_hv)

    def get_logpi_hessian_op(self, var_num):
        """
            Full Hessian estimator op for logpi w.r.t. the policy parameters,
            formed by mapping self.get_logpi_Hv_op over the rows of a
            var_num x var_num identity matrix.

            Args:
                var_num: the number of all variables of the neural network

            Returns:
                H_op: Hessian matrix op (tensor)
        """
        basis_vectors = tf.eye(var_num, var_num)
        return tf.map_fn(self.get_logpi_Hv_op, basis_vectors, dtype='float32')

    def get_logpi_opg_op(self, batch_size_G):
        """
            OPG approximation op of the Hessian matrix of logpi w.r.t. the
            actor parameters, built from per-example gradients:
            G = J^T J / batch_size_G, where row i of J is the flattened
            gradient of the i-th piece of self.logpi.

            Args:
                batch_size_G: number of pieces self.logpi is split into

            Returns:
                G_op: Hessian matrix OPG approximation op (tensor)
        """
        logpi_pieces = tf.split(self.logpi, batch_size_G)

        # one flattened gradient per piece
        piece_grads = []
        for ex in range(batch_size_G):
            piece_grads.append(U.flatten_tensors(tf.gradients(logpi_pieces[ex], self.policy_params)))
        jacobian = tf.stack(piece_grads)

        return tf.matmul(tf.transpose(jacobian), jacobian) / batch_size_G

    def compute_grad_delta_theta_phi_first_half(self, s_batch, a_batch, mini_batches):
        """
            Compute the first half of the gradient nabla_{phi} Delta{theta},
            which is { nabla_{theta} logpi_{theta}(s,a) nabla_{phi} R_{tau}(s,a) },
            averaged over the batch and scaled by self.lr_actor.

            Args:
                s_batch: states, one per sample (must be non-empty, the result
                    is divided by len(s_batch))
                a_batch: actions, indexed like s_batch
                mini_batches: per sample an indexable [state, action, step, F]
                    whose entries are fed to the trajectory placeholders

            Returns:
                2-D float64 array; rows follow the flattened
                grad_logpi_wrt_theta, columns the flattened grad_return_wrt_phi
        """
        grad_theta_wrt_phi_1p = None
        for sample_index, s in enumerate(s_batch):
            a = a_batch[sample_index]
            mini_batch_s = mini_batches[sample_index]

            # [state, action, step, F]
            mini_feed_dict = {
                self.traj_state_phd: mini_batch_s[0],
                self.traj_action_phd: mini_batch_s[1],
                self.traj_step_phd: mini_batch_s[2],
                self.traj_F_phd: mini_batch_s[3],
            }

            grad_return_wrt_phi_sa = self.sess.run(self.grad_return_wrt_phi, feed_dict=mini_feed_dict)
            grad_logpi_wrt_theta_sa = self.sess.run(self.grad_logpi_wrt_theta,
                                                    feed_dict={self.state_phd: [s],
                                                               self.action_phd: [a]})

            # flatten each list of per-parameter arrays into one float64 vector
            # (row-major per array, arrays in list order); this stays in numpy
            # instead of going through nested Python lists
            flat_grad_return_wrt_phi = np.concatenate(
                [np.asarray(g, dtype=np.float64).ravel() for g in grad_return_wrt_phi_sa])
            flat_grad_logpi_wrt_theta = np.concatenate(
                [np.asarray(g, dtype=np.float64).ravel() for g in grad_logpi_wrt_theta_sa])

            # column vector (theta) times row vector (phi)
            grad_theta_wrt_phi_sa = np.outer(flat_grad_logpi_wrt_theta, flat_grad_return_wrt_phi)

            if grad_theta_wrt_phi_1p is None:
                grad_theta_wrt_phi_1p = grad_theta_wrt_phi_sa
            else:
                # the accumulator is a fresh float64 array owned by this call,
                # so adding in place is safe and avoids a new matrix per sample
                grad_theta_wrt_phi_1p += grad_theta_wrt_phi_sa

        return np.multiply(grad_theta_wrt_phi_1p, self.lr_actor / len(s_batch))

    def compute_grad_delta_theta_phi_second_half(self, s_batch, a_batch, shaped_ret_batch):
        """
            Compute the second part of nabla_{phi} Delta{theta},
            which is hessian_{theta} logpi_{theta}(s,a) * h_t * R_{tau}(s,a),
            averaged over the batch and scaled by self.lr_actor.

            Args:
                s_batch: states, one per sample
                a_batch: actions, indexed like s_batch
                shaped_ret_batch: the factor R_{tau}(s,a) per sample,
                    indexed like s_batch

            Returns:
                the accumulated array, or None when there is nothing to
                contribute: self.h has not been set yet, or the batch is empty
        """
        if self.h is None or len(s_batch) == 0:
            # an empty batch would otherwise divide by zero below
            return None

        grad_delta_theta_wrt_phi_2p = None
        for sample_index, s in enumerate(s_batch):
            a = a_batch[sample_index]
            shaped_ret = shaped_ret_batch[sample_index]

            # whatever self.hessian_theta evaluates to for this (s, a):
            # the hessian of logpi w.r.t. theta or its OPG approximation
            hessian_theta_sa = self.sess.run(self.hessian_theta, feed_dict={self.state_phd: [s],
                                                                            self.action_phd: [a]})

            # hessian x h_t, then scaled by R_{tau}(s,a)
            hessian_mul_ht_sa = np.matmul(hessian_theta_sa, self.h)
            hessian_mul_ht_mult_grad = np.multiply(hessian_mul_ht_sa, shaped_ret)
            # NOTE(review): the original author left an open question here
            # about whether this term should be negated; the sign is unchanged

            if grad_delta_theta_wrt_phi_2p is None:
                grad_delta_theta_wrt_phi_2p = hessian_mul_ht_mult_grad
            else:
                grad_delta_theta_wrt_phi_2p = grad_delta_theta_wrt_phi_2p + hessian_mul_ht_mult_grad

        return np.multiply(grad_delta_theta_wrt_phi_2p, self.lr_actor / len(s_batch))
