import random
import numpy as np
import tensorflow as tf
from ....utils.mlp_policy import MlpPolicyOprsV2

import experiments.utils.tf_util as U
from experiments.utils.common import Dataset, zipsame
from ..ppo_algo import PPOAlgo

# Default learning rates of the actor, the critics and the shaping weight function f.
# NOTE: LR_F defaults to 0.0, so with the default Adam step f_phi does not move
# unless a non-zero "lr_f" is passed to set_algo_parameters.
LR_ACTOR = 1e-4
LR_CRITIC = 2e-4
LR_F = 0.0
# Discount factor (also used as the base of gamma^i when discounting F along trajectories).
GAMMA = 0.999

# Global-norm thresholds used with tf.clip_by_global_norm for each network's gradients.
ACTOR_GRADIENT_NORM_CLIP = 1.0
CRITIC_GRADIENT_NORM_CLIP = 1.0
F_GRADIENT_NORM_CLIP = 50.0

# PPO optimisation hyper-parameters.
ENTROPY_COEFF = 0.0
RATIO_CLIP_PARAM = 0.2
ADAM_EPSILON = 1e-5
OPTIM_EPOCHS = 50
OPTIM_BATCH_SIZE = 1024

def my_flatten(x):
    """Recursively flatten nested Python lists into one flat list.

    Only objects whose type is exactly ``list`` are descended into; anything
    else (tuples, list subclasses, scalars, ...) is a leaf and is returned
    wrapped in a one-element list.
    """
    if type(x) is not list:
        return [x]
    flat = []
    for item in x:
        flat.extend(my_flatten(item))
    return flat


"""
    For PPO algorithm, the updating of the shaping weight function f 
    should be conducted using on-policy mode, for short we call it "fop"
"""


class PPOOprsV2FopAlgo(PPOAlgo):
    def __init__(self, sess, graph, state_space, action_space, algo_name="ppo_oprs_v2_fop", **kwargs):
        """Create the algorithm in policy-optimization mode and delegate to ``PPOAlgo``.

        ``optimize_policy`` is assigned before the base constructor runs, so it is
        already available while ``PPOAlgo.__init__`` executes (presumably because the
        base constructor calls back into this class — confirm).
        """
        # learn() dispatches on this flag: True -> update_policy, False -> update_shaping_weight_func
        self.optimize_policy = True
        parent_init = super(PPOOprsV2FopAlgo, self).__init__
        parent_init(sess, graph, state_space, action_space, algo_name, **kwargs)

    def set_algo_parameters(self, **kwargs):
        """Read the hyper-parameters from ``kwargs``, falling back to module defaults.

        Each table row is ``(attribute_name, kwarg_key, default)``; the attribute is
        set to ``kwargs.get(kwarg_key, default)``. The table is rebuilt on every
        call, so list defaults are fresh objects each time.
        """
        parameter_table = (
            ("gamma", "gamma", GAMMA),
            ("lr_actor", "lr_actor", LR_ACTOR),
            ("lr_critic", "lr_critic", LR_CRITIC),
            ("lr_f", "lr_f", LR_F),
            # gradient-clipping switches ...
            ("actor_grad_clip", "actor_gradient_clip", True),
            ("critic_grad_clip", "critic_gradient_clip", True),
            ("f_grad_clip", "f_gradient_clip", True),
            # ... and the matching global-norm thresholds
            ("actor_grad_norm_clip", "actor_gradient_norm_clip", ACTOR_GRADIENT_NORM_CLIP),
            ("critic_grad_norm_clip", "critic_gradient_norm_clip", CRITIC_GRADIENT_NORM_CLIP),
            ("f_grad_norm_clip", "f_gradient_norm_clip", F_GRADIENT_NORM_CLIP),
            ("entropy_coeff", "entropy_coeff", ENTROPY_COEFF),
            ("ratio_clip_param", "ratio_clip_param", RATIO_CLIP_PARAM),
            ("adam_epsilon", "adam_epsilon", ADAM_EPSILON),
            ("optim_epochs", "optim_epochs", OPTIM_EPOCHS),
            ("optim_batch_size", "optim_batch_size", OPTIM_BATCH_SIZE),
            # hidden-layer sizes of the policy, value and shaping weight networks
            ("policy_net_layers", "policy_net_layers", [8, 8]),
            ("v_net_layers", "v_net_layers", [32, 32]),
            ("f_net_layers", "f_net_layers", [16, 8]),
            ("gaussian_fixed_var", "gaussian_fixed_var", False),
            ("joint_opt", "joint_opt", False),
            ("use_f_old", "use_f_old", False),
            ("net_add_one", "net_add_one", False),
            ("f_hidden_layer_act_func", "f_hidden_layer_act_func", tf.nn.relu),
            ("f_output_layer_act_func", "f_output_layer_act_func", None),
        )
        for attr_name, key, default in parameter_table:
            setattr(self, attr_name, kwargs.get(key, default))

    def init_networks(self):
        """Build the full computation graph of the algorithm.

        Order: the f_phi network and its placeholders, the PPO policy/value
        networks, the PPO and f losses, the optimizers, and finally a Saver that
        is created after all of those variables exist.
        """
        # f_phi has to exist first: policy_fn feeds self.f_phi into the policies
        self.init_f_network()
        self.init_ppo_networks()

        self.define_losses()
        # define_f_loss reads self.f_phi_params / self.policy_params set by define_losses
        self.define_f_loss()

        # define_optimizers reads self.f_phi_params_shapes set by define_f_loss
        self.define_optimizers()

        with self.sess.as_default(), self.graph.as_default():
            # keep at most 100 checkpoints
            self.saver = tf.train.Saver(max_to_keep=100)

    def _build_mlp_policy_oprs_v2(self, name, f_output):
        """Construct an ``MlpPolicyOprsV2`` called ``name`` that is given ``f_output``.

        All other constructor arguments come from the algorithm's configuration.
        """
        return MlpPolicyOprsV2(self.sess, self.graph,
                               name=name,
                               ob_space=self.state_space,
                               ac_space=self.action_space,
                               policy_net_layers=self.policy_net_layers,
                               v_net_layers=self.v_net_layers,
                               gaussian_fixed_var=self.gaussian_fixed_var,
                               f_output=f_output)

    def policy_fn(self, name):
        """Build a policy that receives the current shaping weight output ``self.f_phi``."""
        return self._build_mlp_policy_oprs_v2(name, self.f_phi)

    def policy_fn_old(self, name):
        """Build a policy that receives the old shaping weight output ``self.f_old``."""
        return self._build_mlp_policy_oprs_v2(name, self.f_old)

    def init_f_network(self):
        """Create the observation placeholder and the shaping weight function graphs.

        Sets ``self.state_phd`` and ``self.f_phi`` (plus ``self.f_old`` when
        ``use_f_old`` is set). It also sets the trajectory placeholders and
        ``self.traj_f_phi``, which are used to compute the gradient of the shaped
        return w.r.t. phi.
        """
        with self.sess.as_default(), self.graph.as_default():
            observation_shape = [None] + list(self.state_space.shape)
            self.state_phd = U.get_placeholder_with_graph(name="ob", dtype=tf.float32,
                                                          shape=observation_shape,
                                                          graph=self.graph)

            # shaping weight function f_phi(s)
            self.f_phi = self._build_weight_func(self.state_phd)

            # old shaping weight function f_old(s), the f-input of pi_old
            if self.use_f_old:
                self.f_old = self._build_weight_func_old(self.state_phd)

            # To compute nabla_phi R_tau(s, a) for each sampled (s, a), the states, time
            # steps and shaping rewards F along its trajectory are fed here; traj_f_phi
            # evaluates f on those states (presumably with f_phi's weights, given the
            # helper's "reuse" name — confirm).
            state_dim = self.state_space.shape[0]
            self.traj_state_phd = tf.placeholder(tf.float32, [None, state_dim], name='traj_state')
            self.traj_step_phd = tf.placeholder(tf.float32, [None, 1], name='traj_step')
            self.traj_F_phd = tf.placeholder(tf.float32, [None, 1], name='traj_F')
            self.traj_f_phi = self._build_weight_func_reuse(self.traj_state_phd)

    def init_ppo_networks(self):
        """Create pi / pi_old, the PPO target placeholders and the pi_old <- pi copy op.

        pi is built from f_phi. pi_old is built from f_old when ``use_f_old`` is set;
        otherwise it is also built from f_phi.
        """
        with self.sess.as_default(), self.graph.as_default():
            # the policy network is defined over the state-f-action space
            self.pi = self.policy_fn("pi")
            build_old_policy = self.policy_fn_old if self.use_f_old else self.policy_fn
            self.pi_old = build_old_policy("oldpi")

            # shaped advantage and return targets
            self.atarg = tf.placeholder(dtype=tf.float32, shape=[None])
            self.ret = tf.placeholder(dtype=tf.float32, shape=[None])

            # true (unshaped) advantage and return targets
            self.atarg_true = tf.placeholder(dtype=tf.float32, shape=[None], name="adv_true")
            self.ret_true = tf.placeholder(dtype=tf.float32, shape=[None], name="ret_true")

            # learning-rate multiplier (an input of compute_losses)
            self.lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])

            # fetch the cached "ob" placeholder (presumably the one made in init_f_network)
            self.state_phd = U.get_placeholder_cached(name="ob")
            self.action_phd = self.pi.pdtype.sample_placeholder([None])

            # one assign per variable pair, copying every pi variable into pi_old
            copy_ops = [tf.assign(old_var, new_var)
                        for old_var, new_var in zipsame(self.pi_old.get_variables(),
                                                        self.pi.get_variables())]
            self.assign_old_eq_new = U.function([], [], updates=copy_ops,
                                                sess=self.sess, graph=self.graph)

    def define_losses(self):
        """Build the PPO losses for both the shaped and the true advantages.

        Shaped objective:
            ``pol_surr`` is the clipped surrogate on ``atarg``.
            ``pol_ent_pen`` is the entropy penalty.
            ``vf_loss`` is the shaped value loss.
            ``total_loss`` is their sum.
        True objective:
            ``vf_true_loss`` is the true value loss.
            ``pol_surr_true`` is the clipped surrogate on ``atarg_true``, which is
            later differentiated w.r.t. theta to obtain nabla_phi J_true.

        Also collects the parameter lists of the policy, both critics, f_phi and,
        when ``use_f_old`` is set, f_old.
        """
        with self.sess.as_default(), self.graph.as_default():
            self.policy_params = self.pi.get_policy_variables()
            self.critic_params = self.pi.get_critic_variables()
            self.true_critic_params = self.pi.get_true_critic_variables()
            # The trailing "/" matters: tf.get_collection filters names with re.match,
            # so a bare 'Weight_Func' prefix would also pick up trainable variables
            # under 'Weight_Func_Old/' and mix f_old's parameters into f_phi's.
            self.f_phi_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                  scope='Weight_Func/')
            if self.use_f_old:
                self.f_old_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                      scope='Weight_Func_Old/')

            self.kl_old_new = self.pi_old.pd.kl(self.pi.pd)
            self.ent = self.pi.pd.entropy()
            self.mean_kl = tf.reduce_mean(self.kl_old_new)
            self.mean_ent = tf.reduce_mean(self.ent)
            self.pol_ent_pen = (-self.entropy_coeff) * self.mean_ent

            # pnew / pold; the log-ratio is clamped to [-40, 40] so exp() stays finite
            log_ratio = self.pi.pd.logp(self.action_phd) - self.pi_old.pd.logp(self.action_phd)
            self.ratio = tf.exp(tf.minimum(40.0, tf.maximum(-40.0, log_ratio)))

            def clipped_surrogate(adv):
                """Return (surr1, surr2, loss) of PPO's clipped objective for ``adv``.

                surr1 falls back to the raw advantage wherever ratio * adv is inf/NaN.
                """
                weighted = self.ratio * adv
                surr1 = tf.where(tf.logical_or(tf.is_inf(weighted), tf.is_nan(weighted)),
                                 adv, weighted)
                surr2 = tf.clip_by_value(self.ratio, 1.0 - self.ratio_clip_param,
                                         1.0 + self.ratio_clip_param) * adv
                # PPO's pessimistic surrogate (L^CLIP), negated so it can be minimized
                loss = -tf.reduce_mean(tf.minimum(surr1, surr2))
                return surr1, surr2, loss

            # shaped policy loss, shaped value loss and their combination
            self.surr1, self.surr2, self.pol_surr = clipped_surrogate(self.atarg)
            self.vf_loss = tf.reduce_mean(tf.square(self.pi.vpred - self.ret))
            self.total_loss = self.pol_surr + self.pol_ent_pen + self.vf_loss

            # true value loss
            self.vf_true_loss = tf.reduce_mean(tf.square(self.pi.vpred_true - self.ret_true))

            # true policy loss, used for computing nabla_phi J
            self.surr1_f, self.surr2_f, self.pol_surr_true = clipped_surrogate(self.atarg_true)

            self.losses = [self.pol_surr, self.pol_ent_pen, self.vf_loss, self.vf_true_loss,
                           self.mean_kl, self.mean_ent]
            self.loss_names = ["pol_surr", "pol_entpen", "vf_loss", "vf_true_loss", "kl", "ent"]

    def define_f_loss(self):
        """Build the graph pieces needed for the gradient of J_true w.r.t. phi.

        Chain rule used for each sampled (s, a):

            nabla_phi J = nabla_theta log pi_theta(s, a) * nabla_phi theta * Q_true(s, a)

        with

            nabla_phi theta ~ alpha * nabla_theta' log pi_theta'(s', a') * nabla_phi R_tau
            nabla_phi R_tau = sum_{i=0}^{|tau|-1} gamma^i F(s_i, a_i) nabla_phi f_phi(s_i)

        Here alpha is ``lr_actor`` (the factor update_policy scales by).
        """
        with self.sess.as_default(), self.graph.as_default():
            # grad_ys weights f_phi(s_i) of every fed trajectory state by gamma^i * F(s_i, a_i);
            # tf.gradients sums over the fed states, leaving one array per f_phi parameter
            discounted_F = tf.multiply(tf.pow(self.gamma, self.traj_step_phd), self.traj_F_phd)
            self.grad_return_wrt_phi = tf.gradients(ys=self.traj_f_phi, xs=self.f_phi_params,
                                                    grad_ys=discounted_F)

            self.f_phi_params_shapes = [param.shape for param in self.f_phi_params]
            self.theta_params_shapes = [param.shape for param in self.policy_params]
            for idx, shape in enumerate(self.f_phi_params_shapes):
                print("The f_phi param {} shape is {}".format(idx, shape))
            for idx, shape in enumerate(self.theta_params_shapes):
                print("The policy param {} shape is {}".format(idx, shape))

            # nabla_theta log pi_theta(s, a) for the fed (state, action)
            self.grad_logpi_wrt_theta = tf.gradients(ys=self.pi.pd.logp(self.action_phd),
                                                     xs=self.policy_params)

            # running aggregate of d theta / d phi, filled in by update_policy
            self.grad_theta_wrt_phi_aggr = None
            self.grad_aggr_num = 0

            # gradient of the true clipped policy loss w.r.t. theta, clipped by the actor norm
            raw_true_loss_grads = tf.gradients(ys=self.pol_surr_true, xs=self.policy_params)
            self.grad_true_loss_wrt_theta, _ = tf.clip_by_global_norm(raw_true_loss_grads,
                                                                      self.actor_grad_norm_clip)

            # all-zero placeholder values for nabla_phi J until the first f update
            self.grad_J_wrt_phi = [np.full(shape=shape, fill_value=0.0, dtype=np.float32)
                                   for shape in self.f_phi_params_shapes]

    def define_optimizers(self):
        """Create the Adam trainers for the actor, both critics and f_phi.

        When the matching ``*_grad_clip`` flag is set, the actor and critic trainers
        clip their gradients by global norm; otherwise they call ``minimize``
        directly. f_phi is trained from externally computed gradients that are fed
        through placeholders. The method also builds ``self.compute_losses``.
        """
        def clipped_trainer(optimizer, loss, params, clip_norm):
            # gradients -> global-norm clip -> apply; returns (raw, clipped, train_op)
            raw_grads = tf.gradients(loss, params)
            clipped_grads, _ = tf.clip_by_global_norm(raw_grads, clip_norm)
            return raw_grads, clipped_grads, optimizer.apply_gradients(zip(clipped_grads, params))

        with self.sess.as_default(), self.graph.as_default():
            # actor (optimizes the shaped total loss)
            self.policy_opt = tf.train.AdamOptimizer(self.lr_actor)
            if not self.actor_grad_clip:
                self.policy_trainer = self.policy_opt.minimize(loss=self.total_loss,
                                                               var_list=self.policy_params)
            else:
                (self.policy_gradients, self.policy_clipped_grad_op,
                 self.policy_trainer) = clipped_trainer(self.policy_opt, self.total_loss,
                                                        self.policy_params,
                                                        self.actor_grad_norm_clip)

            # shaped critic
            self.critic_opt = tf.train.AdamOptimizer(self.lr_critic)
            if not self.critic_grad_clip:
                self.critic_trainer = self.critic_opt.minimize(self.vf_loss,
                                                               var_list=self.critic_params)
            else:
                (self.critic_gradients, self.critic_clipped_grad_op,
                 self.critic_trainer) = clipped_trainer(self.critic_opt, self.vf_loss,
                                                        self.critic_params,
                                                        self.critic_grad_norm_clip)

            # true critic (same learning rate and clipping settings as the shaped critic)
            self.true_critic_opt = tf.train.AdamOptimizer(self.lr_critic)
            if not self.critic_grad_clip:
                self.true_critic_trainer = self.true_critic_opt.minimize(
                    self.vf_true_loss, var_list=self.true_critic_params)
            else:
                (self.true_critic_gradients, self.true_critic_clipped_grad_op,
                 self.true_critic_trainer) = clipped_trainer(self.true_critic_opt,
                                                             self.vf_true_loss,
                                                             self.true_critic_params,
                                                             self.critic_grad_norm_clip)

            # f_phi: gradients are computed outside TF and fed through these placeholders
            self.optimizer_f = tf.train.AdamOptimizer(self.lr_f)
            self.f_phi_params_grad_phds = [
                tf.placeholder(tf.float32, shape=shape, name="f_phi_phd_{}".format(idx))
                for idx, shape in enumerate(self.f_phi_params_shapes)]

            if self.f_grad_clip:
                self.f_clipped_grad_op, _ = tf.clip_by_global_norm(self.f_phi_params_grad_phds,
                                                                   self.f_grad_norm_clip)
                f_grads = self.f_clipped_grad_op
            else:
                f_grads = self.f_phi_params_grad_phds
            self.trainer_f = self.optimizer_f.apply_gradients(zip(f_grads, self.f_phi_params))

            loss_inputs = [self.state_phd, self.action_phd,
                           self.atarg, self.ret, self.lrmult,
                           self.atarg_true, self.ret_true]
            self.compute_losses = U.function(loss_inputs, self.losses,
                                             sess=self.sess, graph=self.graph)

    def choose_action(self, s, is_test):
        """Sample an action for state ``s`` and return it with its diagnostics.

        Returns ``(action, info)``, where ``info`` holds the shaped value prediction
        ("v_pred"), the true value prediction ("v_pred_true") and f_phi(s)
        ("f_phi_s"). ``is_test`` is not used; the action is always sampled with
        ``stochastic=True``.
        """
        with self.graph.as_default():
            action, vpred, vpred_true = self.pi.act(stochastic=True, ob=s)

            # evaluate f_phi on the single state; a 2-D result is reduced to its [0, 0] entry
            f_value = self.sess.run(self.f_phi, {self.state_phd: [s]})
            if f_value.ndim == 2:
                f_value = f_value[0, 0]

            info = {"v_pred": vpred, "v_pred_true": vpred_true, "f_phi_s": f_value}
            return action, info

    def learn(self, **kwargs):
        """Run one learning step: a policy update or a shaping-weight update.

        ``self.optimize_policy`` selects which one runs; every keyword argument is
        forwarded unchanged.
        """
        update = self.update_policy if self.optimize_policy else self.update_shaping_weight_func
        update(**kwargs)

    def update_policy(self, **kwargs):
        """Run PPO on the shaped advantages, then accumulate d theta / d phi.

        Keyword arguments (the trajectory batch is organized by the trainer):
            ob, ac: state and action batches.
            adv: shaped advantage estimates (standardized here).
            td_lam_ret: shaped TD(lambda) returns, the shaped critic's target.
            mini_batches: one entry per sampled pair, laid out as
                ``[states, steps, F, actions]`` along that pair's trajectory; used
                to evaluate nabla_phi R_tau(s, a).
            grad_compute_s_batch, grad_compute_a_batch: the sampled (s, a) pairs whose
                nabla_theta log pi(s, a) is paired with the matching mini batch.

        After the PPO epochs, ``lr_actor`` times the mean over sampled pairs of
        outer(nabla_theta log pi(s, a), nabla_phi R_tau(s, a)) is added to
        ``self.grad_theta_wrt_phi_aggr``, and ``self.grad_aggr_num`` is incremented.
        If no pairs are given, the aggregate is left untouched.

        Raises:
            ValueError: if the three gradient-sample inputs differ in length. This is
                checked before any training happens.
        """
        with self.graph.as_default():
            bs = kwargs.get("ob")
            ba = kwargs.get("ac")
            batch_adv = kwargs.get("adv")
            batch_td_lam_ret = kwargs.get("td_lam_ret")

            mini_batches = kwargs.get("mini_batches")
            grad_compute_s_batch = kwargs.get("grad_compute_s_batch")
            grad_compute_a_batch = kwargs.get("grad_compute_a_batch")
            # validate before training so a malformed call does not leave a half-done update
            if not (len(mini_batches) == len(grad_compute_s_batch) == len(grad_compute_a_batch)):
                raise ValueError(
                    "mini_batches, grad_compute_s_batch and grad_compute_a_batch must have the "
                    "same length, got {}, {} and {}".format(
                        len(mini_batches), len(grad_compute_s_batch), len(grad_compute_a_batch)))

            # Standardize the advantages. A constant batch (std == 0) is only centered,
            # instead of dividing by zero and turning every advantage into NaN.
            adv_std = batch_adv.std()
            batch_adv = batch_adv - batch_adv.mean()
            if adv_std > 0:
                batch_adv = batch_adv / adv_std

            # PPO keeps no replay buffer, so the dataset is built from this batch directly
            d = Dataset(dict(bs=bs, ba=ba, badv=batch_adv, bret=batch_td_lam_ret),
                        deterministic=self.pi.recurrent)
            batch_size = self.optim_batch_size or bs.shape[0]

            if hasattr(self.pi, "ob_rms"):
                self.pi.ob_rms.update(bs)

            # pi_old <- pi
            self.assign_old_eq_new()

            # f_old <- f_phi. The assign ops are built once and cached; creating them on
            # every call would keep adding nodes to the graph.
            if self.use_f_old:
                sync_ops = getattr(self, "_f_old_sync_ops", None)
                if sync_ops is None:
                    sync_ops = [old_param.assign(param) for old_param, param in
                                zip(self.f_old_params, self.f_phi_params)]
                    self._f_old_sync_ops = sync_ops
                self.sess.run(sync_ops)

            self._run_shaped_ppo_epochs(d, batch_size)

            grad_theta_wrt_phi = self._estimate_grad_theta_wrt_phi(
                mini_batches, grad_compute_s_batch, grad_compute_a_batch)
            if grad_theta_wrt_phi is None:
                # no (s, a) samples were provided for this update: nothing to aggregate
                return

            self.grad_aggr_num += 1
            if self.grad_theta_wrt_phi_aggr is None:
                self.grad_theta_wrt_phi_aggr = grad_theta_wrt_phi
            else:
                self.grad_theta_wrt_phi_aggr = self.grad_theta_wrt_phi_aggr + grad_theta_wrt_phi

    def _run_shaped_ppo_epochs(self, d, batch_size):
        """Run ``optim_epochs`` passes of actor + shaped-critic steps over dataset ``d``."""
        for _ in range(self.optim_epochs):
            for batch in d.iterate_once(batch_size):
                _, _, policy_loss, v_loss, kl = self.sess.run(
                    [self.policy_trainer, self.critic_trainer,
                     self.pol_surr, self.vf_loss, self.mean_kl],
                    feed_dict={self.state_phd: batch["bs"],
                               self.action_phd: batch["ba"],
                               self.atarg: batch["badv"],
                               self.ret: batch["bret"]})

                self.write_summary_scalar(self.update_cnt, "policy_loss", policy_loss)
                self.write_summary_scalar(self.update_cnt, "v_loss", v_loss)
                self.write_summary_scalar(self.update_cnt, "mean_kl", kl)
                self.update_cnt += 1

    def _estimate_grad_theta_wrt_phi(self, mini_batches, s_batch, a_batch):
        """Estimate d theta / d phi from the sampled pairs, or return None if there are none.

        For each pair, its nabla_theta log pi(s, a) and nabla_phi R_tau(s, a) are
        flattened, in parameter order and C order within each parameter. The result
        is ``lr_actor`` times the mean of their outer products: a float64 array of
        shape (#theta elements, #phi elements). Inputs are assumed to have equal
        lengths (checked by update_policy).
        """
        if len(s_batch) == 0:
            return None

        total = None
        for s, a, traj in zip(s_batch, a_batch, mini_batches):
            # traj = [states, steps, F, actions] along the trajectory of (s, a); the two
            # gradients depend on disjoint placeholders, so one session call computes both
            grad_return_wrt_phi_sa, grad_logpi_wrt_theta_sa = self.sess.run(
                [self.grad_return_wrt_phi, self.grad_logpi_wrt_theta],
                feed_dict={self.traj_state_phd: traj[0],
                           self.traj_step_phd: traj[1],
                           self.traj_F_phd: traj[2],
                           self.state_phd: [s],
                           self.action_phd: [a]})

            # exact float32 -> float64 widening, so products are accumulated in double precision
            flat_return = np.concatenate(
                [np.ravel(g) for g in grad_return_wrt_phi_sa]).astype(np.float64)
            flat_logpi = np.concatenate(
                [np.ravel(g) for g in grad_logpi_wrt_theta_sa]).astype(np.float64)

            outer = np.outer(flat_logpi, flat_return)
            total = outer if total is None else total + outer

        return np.multiply(total, self.lr_actor / len(s_batch))

    def update_shaping_weight_func(self, **kwargs):
        """Update f_phi along nabla_phi J_true and fit the true critic.

        Keyword arguments:
            ob, ac: state and action batches.
            adv_true: true advantages (standardized here).
            ret_true: true returns, the true critic's regression target.

        For every minibatch of every epoch:
            1. g_theta = gradient of the clipped true policy loss w.r.t. theta
               (already global-norm clipped in the graph);
            2. nabla_phi J = g_theta @ self.grad_theta_wrt_phi_aggr, split back into
               one float32 array per f_phi parameter and stored in
               ``self.grad_J_wrt_phi``;
            3. one ``self.trainer_f`` step with those gradients;
            4. one true-critic step.

        NOTE(review): grad_theta_wrt_phi_aggr is used as a running sum. It is neither
        divided by grad_aggr_num nor reset here, although define_f_loss describes it
        as averaged — confirm which is intended.

        Raises:
            AssertionError: if update_policy has not aggregated d theta / d phi yet.
            ValueError: if nabla_phi J does not have one entry per f_phi parameter element.
        """
        with self.graph.as_default():
            assert self.grad_aggr_num != 0
            assert self.grad_theta_wrt_phi_aggr is not None

            bs = kwargs.get("ob")
            ba = kwargs.get("ac")
            batch_adv = kwargs.get("adv_true")
            batch_ret = kwargs.get("ret_true")

            # Standardize the true advantages. A constant batch (std == 0) is only
            # centered, instead of dividing by zero and producing NaNs.
            adv_std = batch_adv.std()
            batch_adv = batch_adv - batch_adv.mean()
            if adv_std > 0:
                batch_adv = batch_adv / adv_std

            # PPO keeps no replay buffer, so the dataset is built from this batch directly
            d = Dataset(dict(bs=bs, ba=ba, badv=batch_adv, bret=batch_ret),
                        deterministic=self.pi.recurrent)
            batch_size = self.optim_batch_size or bs.shape[0]

            # pi_old is not re-synced here; only f_old <- f_phi is. The assign ops are built
            # once and cached, because creating them per call keeps growing the graph.
            if self.use_f_old:
                sync_ops = getattr(self, "_f_old_sync_ops", None)
                if sync_ops is None:
                    sync_ops = [old_param.assign(param) for old_param, param in
                                zip(self.f_old_params, self.f_phi_params)]
                    self._f_old_sync_ops = sync_ops
                self.sess.run(sync_ops)

            # loop-invariant layout of the flat phi vector: one segment per f_phi parameter
            param_shapes = [tuple(int(dim) for dim in shape) for shape in self.f_phi_params_shapes]
            param_sizes = [int(np.prod(shape)) for shape in param_shapes]
            n_phi = sum(param_sizes)
            split_points = np.cumsum(param_sizes)[:-1]

            for _ in range(self.optim_epochs):
                for batch in d.iterate_once(batch_size):
                    # gradient w.r.t. theta of the clipped policy loss evaluated with the
                    # true advantages (stands in for E_{s ~ rho}[nabla_theta log pi * Q_true])
                    grad_true_loss_wrt_theta = self.sess.run(self.grad_true_loss_wrt_theta,
                                                             {self.state_phd: batch["bs"],
                                                              self.action_phd: batch["ba"],
                                                              self.atarg_true: batch["badv"]})
                    flat_grad_theta = np.concatenate(
                        [np.ravel(g) for g in grad_true_loss_wrt_theta]).astype(np.float64)

                    # chain rule: nabla_phi J = (d loss / d theta) @ (d theta / d phi)
                    grad_J_wrt_phi = np.matmul(np.atleast_2d(flat_grad_theta),
                                               self.grad_theta_wrt_phi_aggr).reshape(-1)
                    if grad_J_wrt_phi.size != n_phi:
                        raise ValueError(
                            "nabla_phi J has {} elements but the f_phi parameters have {}".format(
                                grad_J_wrt_phi.size, n_phi))

                    # split the flat gradient back into one array per f_phi parameter
                    self.grad_J_wrt_phi = [
                        segment.astype(np.float32).reshape(shape)
                        for segment, shape in zip(np.split(grad_J_wrt_phi, split_points),
                                                  param_shapes)]

                    # apply nabla_phi J_true to f_phi
                    f_phi_feed = dict(zip(self.f_phi_params_grad_phds, self.grad_J_wrt_phi))
                    self.sess.run(self.trainer_f, feed_dict=f_phi_feed)

                    _, v_true_loss = self.sess.run([self.true_critic_trainer, self.vf_true_loss],
                                                   feed_dict={
                                                       self.state_phd: batch["bs"],
                                                       self.action_phd: batch["ba"],
                                                       self.atarg_true: batch["badv"],
                                                       self.ret_true: batch["bret"]
                                                   })

                    self.write_summary_scalar(self.update_cnt, "v_true_loss", v_true_loss)
                    self.update_cnt += 1

    def _build_weight_func(self, s, reuse=None, custom_getter=None):
        """Build the shaping-weight network f_phi inside the 'Weight_Func' scope.

        Architecture: for every width in ``self.f_net_layers`` a dense layer
        (kernel/bias drawn uniformly from [-1/8, 1/8]) followed by
        ``tf.contrib.layers.layer_norm`` and ``self.f_hidden_layer_act_func``;
        then a single-unit dense layer named 'f_value' using
        ``self.f_output_layer_act_func`` (kernel/bias drawn uniformly from
        [-1/1000, 1/1000]).

        Args:
            s: input tensor fed to the first layer (presumably the state batch —
                confirm against callers).
            reuse: forwarded to ``tf.variable_scope``. The dense layers are
                created trainable only when this is None; the layer-norm calls
                do not receive the flag.
            custom_getter: forwarded to ``tf.variable_scope``.

        Returns:
            The 'f_value' output tensor (last dimension 1), shifted by +1 when
            ``self.net_add_one`` is set.
        """
        is_trainable = reuse is None
        hidden_bound = 1 / 8.0
        output_bound = 1 / 1000.0

        with tf.variable_scope('Weight_Func', reuse=reuse, custom_getter=custom_getter):
            hidden = s
            # Layer creation order matters: layer_norm scopes are auto-named
            # (LayerNorm, LayerNorm_1, ...) in call order.
            for layer_idx, units in enumerate(self.f_net_layers):
                hidden = tf.layers.dense(
                    hidden, units,
                    kernel_initializer=tf.random_uniform_initializer(-hidden_bound, hidden_bound),
                    bias_initializer=tf.random_uniform_initializer(-hidden_bound, hidden_bound),
                    name='l' + str(layer_idx), trainable=is_trainable)
                hidden = tf.contrib.layers.layer_norm(hidden)
                hidden = self.f_hidden_layer_act_func(hidden)

            f_value = tf.layers.dense(
                hidden, 1, activation=self.f_output_layer_act_func,
                kernel_initializer=tf.random_uniform_initializer(-output_bound, output_bound),
                bias_initializer=tf.random_uniform_initializer(-output_bound, output_bound),
                name='f_value', trainable=is_trainable)

            if self.net_add_one:
                f_value = tf.add(f_value, 1)

        return f_value

    def _build_weight_func_reuse(self, s):
        """Apply the already-built 'Weight_Func' network to a new input ``s``.

        This used to be a line-for-line copy of ``_build_weight_func`` with the
        scope opened in reuse mode. It now delegates, so the two constructions
        can never drift apart (layer widths, initializers, activations,
        ``net_add_one``).

        ``_build_weight_func(..., reuse=True)`` passes ``trainable=False`` to
        its dense layers, whereas the old copy passed ``trainable=True``. In
        reuse mode ``get_variable`` returns the existing variables and ignores
        that flag, so the resulting graph and variables are the same.

        Args:
            s: input tensor fed to the shared network.

        Returns:
            The 'f_value' output tensor computed with the shared
            'Weight_Func' variables (shifted by +1 when ``self.net_add_one``).

        Raises:
            ValueError: from TensorFlow if the 'Weight_Func' variables have not
                been created yet (reuse=True requires them to exist).
        """
        return self._build_weight_func(s, reuse=True)

    def _build_weight_func_old(self, s, reuse=None, custom_getter=None):
        """Build a second copy of the f_phi architecture under 'Weight_Func_Old'.

        The layer stack is identical to the 'Weight_Func' network: per width in
        ``self.f_net_layers`` a dense layer (uniform init in [-1/8, 1/8]),
        layer norm, and ``self.f_hidden_layer_act_func``; then a single-unit
        'f_value' dense layer with ``self.f_output_layer_act_func`` (uniform
        init in [-1/1000, 1/1000]). Its variables live in a separate scope.
        NOTE(review): presumably holds a snapshot of phi from the previous
        iteration; confirm against the code that assigns these variables.

        Args:
            s: input tensor fed to the first layer.
            reuse: forwarded to ``tf.variable_scope``; dense layers are
                trainable only when this is None.
            custom_getter: forwarded to ``tf.variable_scope``.

        Returns:
            The 'f_value' output tensor, shifted by +1 when
            ``self.net_add_one`` is set.
        """
        is_trainable = reuse is None
        hidden_bound = 1 / 8.0
        output_bound = 1 / 1000.0

        with tf.variable_scope('Weight_Func_Old', reuse=reuse, custom_getter=custom_getter):
            hidden = s
            # Keep dense -> layer_norm -> activation order per layer so the
            # auto-generated LayerNorm scope names stay stable.
            for layer_idx, units in enumerate(self.f_net_layers):
                hidden = tf.layers.dense(
                    hidden, units,
                    kernel_initializer=tf.random_uniform_initializer(-hidden_bound, hidden_bound),
                    bias_initializer=tf.random_uniform_initializer(-hidden_bound, hidden_bound),
                    name='l' + str(layer_idx), trainable=is_trainable)
                hidden = tf.contrib.layers.layer_norm(hidden)
                hidden = self.f_hidden_layer_act_func(hidden)

            f_value = tf.layers.dense(
                hidden, 1, activation=self.f_output_layer_act_func,
                kernel_initializer=tf.random_uniform_initializer(-output_bound, output_bound),
                bias_initializer=tf.random_uniform_initializer(-output_bound, output_bound),
                name='f_value', trainable=is_trainable)

            if self.net_add_one:
                f_value = tf.add(f_value, 1)

        return f_value

    def switch_optimization(self):
        """Toggle between the policy-optimization and f-optimization phases.

        Flips ``self.optimize_policy``.

        - Switching back to policy optimization clears the aggregated gradient
          of theta w.r.t. phi (``self.grad_theta_wrt_phi_aggr = None``) and
          resets its counter (``self.grad_aggr_num = 0``).
        - Switching to f optimization divides the aggregated gradient by
          ``self.grad_aggr_num``. If nothing was aggregated (count is 0 or the
          accumulator is None), the accumulator is left unchanged instead of
          raising ``TypeError`` or producing inf/nan.
        """
        self.optimize_policy = not self.optimize_policy

        if self.optimize_policy:
            # A new policy phase starts: drop the gradient of theta w.r.t. phi
            # collected during the previous phase.
            self.grad_theta_wrt_phi_aggr = None
            self.grad_aggr_num = 0
        elif self.grad_aggr_num and self.grad_theta_wrt_phi_aggr is not None:
            # Normalise the accumulated gradient by the number of aggregation
            # steps; skipped when the count is zero, where the division would
            # fail (None accumulator) or yield inf/nan.
            self.grad_theta_wrt_phi_aggr = np.divide(self.grad_theta_wrt_phi_aggr,
                                                     self.grad_aggr_num)
