import random
import numpy as np
import math
import tensorflow as tf
from ....utils.mlp_policy import MlpPolicyOprsV1

import experiments.utils.tf_util as U
from experiments.utils.common import Dataset, zipsame
from ..ppo_algo import PPOAlgo

# Default hyper-parameters. Each one can be overridden through the matching
# keyword argument read in PPOOprsV1FopAlgo.set_algo_parameters.

# --- Adam learning rates ---------------------------------------------------
LR_ACTOR = 1e-4   # actor (policy) network
LR_CRITIC = 2e-4  # shaped critic and true critic
# NOTE(review): with a learning rate of 0 the shaping weight function f is not
# updated unless ``lr_f`` is passed explicitly - confirm this is intended.
LR_F = 0.0

# --- Discounting ------------------------------------------------------------
# Stored as ``self.gamma``; presumably the discount factor used by the rollout /
# advantage code outside this file (TODO confirm).
GAMMA = 0.999

# --- Global-norm gradient clipping thresholds -------------------------------
ACTOR_GRADIENT_NORM_CLIP = 1.0
CRITIC_GRADIENT_NORM_CLIP = 1.0
F_GRADIENT_NORM_CLIP = 50.0

# --- PPO objective ----------------------------------------------------------
ENTROPY_COEFF = 0.0     # weight of the entropy bonus (0 disables it)
RATIO_CLIP_PARAM = 0.2  # clip range of the probability ratio in the policy loss

# --- Optimisation loop ------------------------------------------------------
# NOTE(review): read into ``self.adam_epsilon`` but not passed to the Adam
# optimizers built in this file - confirm whether it is used elsewhere.
ADAM_EPSILON = 1e-5
OPTIM_EPOCHS = 50        # passes over each collected batch per update
OPTIM_BATCH_SIZE = 1024  # minibatch size for each pass

"""
    For PPO algorithm, the updating of the shaping weight function f 
    should be conducted using on-policy mode, for short we call it "fop"
"""


class PPOOprsV1FopAlgo(PPOAlgo):
    """PPO with a learned shaping weight function f_phi(s) that is optimized on-policy.

    Training alternates between two phases, selected by ``optimize_policy`` and
    toggled with ``switch_optimization``:

    * policy phase: the actor and the shaped critic are trained with PPO's
      clipped surrogate on the shaped advantages / returns;
    * f phase: the shaping weight function f and the true critic are trained
      with a clipped surrogate on the true advantages / returns.

    Both policies receive an f value as an extra input (``f_output``).
    """

    def __init__(self, sess, graph, state_space, action_space, algo_name="ppo_oprs_v1_fop", **kwargs):
        # Start in the policy phase; set before the parent constructor runs.
        self.optimize_policy = True
        super(PPOOprsV1FopAlgo, self).__init__(sess, graph, state_space, action_space, algo_name, **kwargs)

    def set_algo_parameters(self, **kwargs):
        """Read hyper-parameters from ``kwargs``, falling back to the module defaults.

        ``ratio_clip_param_f`` (default 0.2) is the ratio clip range of the f
        loss. It is kept separate from ``ratio_clip_param``, which is used by the
        policy loss.
        """
        self.gamma = kwargs.get("gamma", GAMMA)
        self.lr_actor = kwargs.get("lr_actor", LR_ACTOR)
        self.lr_critic = kwargs.get("lr_critic", LR_CRITIC)
        self.lr_f = kwargs.get("lr_f", LR_F)
        self.actor_grad_clip = kwargs.get("actor_gradient_clip", True)
        self.critic_grad_clip = kwargs.get("critic_gradient_clip", True)
        self.f_grad_clip = kwargs.get("f_gradient_clip", True)
        self.actor_grad_norm_clip = kwargs.get("actor_gradient_norm_clip", ACTOR_GRADIENT_NORM_CLIP)
        self.critic_grad_norm_clip = kwargs.get("critic_gradient_norm_clip", CRITIC_GRADIENT_NORM_CLIP)
        self.f_grad_norm_clip = kwargs.get("f_gradient_norm_clip", F_GRADIENT_NORM_CLIP)
        self.entropy_coeff = kwargs.get("entropy_coeff", ENTROPY_COEFF)
        self.ratio_clip_param = kwargs.get("ratio_clip_param", RATIO_CLIP_PARAM)
        self.ratio_clip_param_f = kwargs.get("ratio_clip_param_f", 0.2)
        self.adam_epsilon = kwargs.get("adam_epsilon", ADAM_EPSILON)
        self.optim_epochs = kwargs.get("optim_epochs", OPTIM_EPOCHS)
        self.optim_batch_size = kwargs.get("optim_batch_size", OPTIM_BATCH_SIZE)
        self.policy_net_layers = kwargs.get("policy_net_layers", [8, 8])
        self.v_net_layers = kwargs.get("v_net_layers", [32, 32])
        self.f_net_layers = kwargs.get("f_net_layers", [16, 8])
        self.gaussian_fixed_var = kwargs.get("gaussian_fixed_var", False)
        self.joint_opt = kwargs.get("joint_opt", False)
        self.use_f_old = kwargs.get("use_f_old", False)
        self.net_add_one = kwargs.get("net_add_one", False)
        self.f_hidden_layer_act_func = kwargs.get("f_hidden_layer_act_func", tf.nn.relu)
        self.f_output_layer_act_func = kwargs.get("f_output_layer_act_func", None)

        # Optional bounds applied to f(s) in choose_action; clipping only happens
        # when both bounds are set.
        self.f_phi_min = kwargs.get("f_phi_min", None)
        self.f_phi_max = kwargs.get("f_phi_max", None)

    def init_networks(self):
        """Build f, the PPO networks, the losses, the optimizers and a Saver, in that order."""
        # f must exist first because the policy networks take its output as input.
        self.init_f_network()
        self.init_ppo_networks()
        self.define_losses()
        self.define_optimizers()

        with self.sess.as_default():
            with self.graph.as_default():
                self.saver = tf.train.Saver(max_to_keep=100)

    def policy_fn(self, name):
        """Build the current policy ``name``, conditioned on the live f network output ``f_phi``."""
        return MlpPolicyOprsV1(self.sess, self.graph, name=name, ob_space=self.state_space,
                               ac_space=self.action_space, policy_net_layers=self.policy_net_layers,
                               v_net_layers=self.v_net_layers, gaussian_fixed_var=self.gaussian_fixed_var,
                               f_output=self.f_phi)

    def policy_fn_old(self, name):
        """Build the "old" policy ``name`` used in the PPO probability ratio.

        With ``use_f_old`` it is conditioned on the separate network ``f_old``.
        Otherwise it is conditioned on the placeholder ``f_phi_phd``, which the
        update methods feed with the batch's ``f_phi_s`` values.
        """
        f_output = self.f_old if self.use_f_old else self.f_phi_phd
        return MlpPolicyOprsV1(self.sess, self.graph, name=name, ob_space=self.state_space,
                               ac_space=self.action_space, policy_net_layers=self.policy_net_layers,
                               v_net_layers=self.v_net_layers, gaussian_fixed_var=self.gaussian_fixed_var,
                               f_output=f_output)

    def init_f_network(self):
        """Build the shaping weight function and its related tensors.

        Creates:

        * the shared observation placeholder ``ob`` (``state_phd``);
        * the network ``f_phi`` (variable scope ``Weight_Func``);
        * the placeholder ``f_phi_phd`` of shape [None, 1] for externally supplied f values;
        * if ``use_f_old`` is set, the network ``f_old`` (scope ``Weight_Func_Old``).
        """
        with self.sess.as_default():
            with self.graph.as_default():
                self.state_phd = U.get_placeholder_with_graph(name="ob", dtype=tf.float32,
                                                              shape=[None] + list(self.state_space.shape),
                                                              graph=self.graph)

                self.f_phi = self._build_weight_func(self.state_phd)
                self.f_phi_phd = U.get_placeholder_with_graph(name="f_phi_phd", dtype=tf.float32,
                                                              shape=[None, 1], graph=self.graph)

                # Separate copy of f that feeds pi_old; it is synced from f at the
                # start of every update.
                if self.use_f_old:
                    self.f_old = self._build_weight_func_old(self.state_phd)

    def init_ppo_networks(self):
        """Build pi / pi_old, the advantage and return placeholders, and the pi -> pi_old copy op."""
        with self.sess.as_default():
            with self.graph.as_default():
                # The policy networks are defined over the state-f-action space.
                self.pi = self.policy_fn("pi")  # new policy
                self.pi_old = self.policy_fn_old("oldpi")

                # Shaped advantage (target) and empirical return.
                self.atarg = tf.placeholder(dtype=tf.float32, shape=[None])
                self.ret = tf.placeholder(dtype=tf.float32, shape=[None])

                # True (unshaped) advantage and return.
                self.atarg_true = tf.placeholder(dtype=tf.float32, shape=[None], name="adv_true")
                self.ret_true = tf.placeholder(dtype=tf.float32, shape=[None], name="ret_true")

                # Learning-rate multiplier; in this class it is only an input of compute_losses.
                self.lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])

                self.state_phd = U.get_placeholder_cached(name="ob")
                self.action_phd = self.pi.pdtype.sample_placeholder([None])

                # Copies every variable of pi into the matching variable of pi_old.
                self.assign_old_eq_new = U.function([], [],
                                                    updates=[tf.assign(oldv, newv) for (oldv, newv) in
                                                             zipsame(self.pi_old.get_variables(),
                                                                     self.pi.get_variables())],
                                                    sess=self.sess,
                                                    graph=self.graph)

    @staticmethod
    def _zero_non_finite(x):
        """Return tensor ``x`` with its inf and NaN entries replaced by 0."""
        return tf.where(tf.logical_or(tf.is_inf(x), tf.is_nan(x)), tf.zeros_like(x), x)

    def define_losses(self):
        """Define the PPO policy and value losses, the true-value loss, and the f loss."""
        with self.sess.as_default():
            with self.graph.as_default():
                self.kl_old_new = self.pi_old.pd.kl(self.pi.pd)
                self.ent = self.pi.pd.entropy()
                self.mean_kl = tf.reduce_mean(self.kl_old_new)
                self.mean_ent = tf.reduce_mean(self.ent)
                self.pol_ent_pen = (-self.entropy_coeff) * self.mean_ent

                # pnew / pold, with the log-ratio clamped to [-40, 40] so exp() cannot overflow.
                self.ratio = tf.exp(tf.minimum(40.0,
                                               tf.maximum(-40.0,
                                                          self.pi.pd.logp(self.action_phd) -
                                                          self.pi_old.pd.logp(self.action_phd))))

                # Unclipped surrogate; inf/NaN entries are zeroed so they cannot poison the mean.
                self.surr1 = self._zero_non_finite(self.ratio * self.atarg)
                self.surr2 = tf.clip_by_value(self.ratio, 1.0 - self.ratio_clip_param,
                                              1.0 + self.ratio_clip_param) * self.atarg
                # PPO's pessimistic surrogate (L^CLIP).
                self.pol_surr = - tf.reduce_mean(tf.minimum(self.surr1, self.surr2))

                # Loss of the shaped value function.
                self.vf_loss = tf.reduce_mean(tf.square(self.pi.vpred - self.ret))
                self.total_loss = self.pol_surr + self.pol_ent_pen + self.vf_loss

                # Loss of the true value function.
                self.vf_true_loss = tf.reduce_mean(tf.square(self.pi.vpred_true - self.ret_true))

                # Loss of the shaping weight function f: the same clipped surrogate,
                # driven by the true advantage and using its own clip range.
                self.surr1_f = self._zero_non_finite(self.ratio * self.atarg_true)
                self.surr2_f = tf.clip_by_value(self.ratio, 1.0 - self.ratio_clip_param_f,
                                                1.0 + self.ratio_clip_param_f) * self.atarg_true
                self.f_loss = - tf.reduce_mean(tf.minimum(self.surr1_f, self.surr2_f))

                self.losses = [self.pol_surr, self.pol_ent_pen, self.vf_loss, self.vf_true_loss,
                               self.mean_kl, self.mean_ent]
                self.loss_names = ["pol_surr", "pol_entpen", "vf_loss", "vf_true_loss", "kl", "ent"]

    def define_optimizers(self):
        """Create the Adam trainers for the actor, shaped critic, true critic and f.

        Each trainer clips its gradients by global norm when the matching
        ``*_grad_clip`` flag is set. With ``joint_opt``, the f trainer also updates
        the policy variables. This method also builds ``compute_losses`` and, with
        ``use_f_old``, the ops that copy f into f_old.
        """
        with self.sess.as_default():
            with self.graph.as_default():
                self.policy_params = self.pi.get_policy_variables()
                self.critic_params = self.pi.get_critic_variables()
                self.true_critic_params = self.pi.get_true_critic_variables()
                # get_collection filters with re.match, so the trailing "/" is required:
                # without it, 'Weight_Func' would also match the 'Weight_Func_Old' variables.
                self.f_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Weight_Func/')
                if self.use_f_old:
                    self.f_old_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                          scope='Weight_Func_Old/')
                    # Built once here; creating assign ops inside the update methods
                    # would add new nodes to the graph on every call.
                    self.assign_f_old_eq_f = [tf.assign(old_param, param) for old_param, param in
                                              zipsame(self.f_old_params, self.f_params)]

                # NOTE(review): self.adam_epsilon is not passed to the Adam optimizers below,
                # so they use TensorFlow's default epsilon - confirm this is intended.

                # Policy trainer: minimizes the total PPO loss over the policy variables.
                self.policy_opt = tf.train.AdamOptimizer(self.lr_actor)
                if self.actor_grad_clip:
                    self.policy_gradients = tf.gradients(self.total_loss, self.policy_params)
                    self.policy_clipped_grad_op, _ = tf.clip_by_global_norm(self.policy_gradients,
                                                                            self.actor_grad_norm_clip)
                    self.policy_trainer = self.policy_opt.apply_gradients(zip(self.policy_clipped_grad_op,
                                                                              self.policy_params))
                else:
                    self.policy_trainer = self.policy_opt.minimize(loss=self.total_loss, var_list=self.policy_params)

                # Shaped-critic trainer.
                self.critic_opt = tf.train.AdamOptimizer(self.lr_critic)
                if self.critic_grad_clip:
                    self.critic_gradients = tf.gradients(self.vf_loss, self.critic_params)
                    self.critic_clipped_grad_op, _ = tf.clip_by_global_norm(self.critic_gradients,
                                                                            self.critic_grad_norm_clip)
                    self.critic_trainer = self.critic_opt.apply_gradients(zip(self.critic_clipped_grad_op,
                                                                              self.critic_params))
                else:
                    self.critic_trainer = self.critic_opt.minimize(self.vf_loss, var_list=self.critic_params)

                # True-critic trainer; it shares the critic learning rate and clipping settings.
                self.true_critic_opt = tf.train.AdamOptimizer(self.lr_critic)
                if self.critic_grad_clip:
                    self.true_critic_gradients = tf.gradients(self.vf_true_loss, self.true_critic_params)
                    self.true_critic_clipped_grad_op, _ = tf.clip_by_global_norm(self.true_critic_gradients,
                                                                                 self.critic_grad_norm_clip)
                    self.true_critic_trainer = self.true_critic_opt.apply_gradients(
                        zip(self.true_critic_clipped_grad_op,
                            self.true_critic_params))
                else:
                    self.true_critic_trainer = self.true_critic_opt.minimize(self.vf_true_loss,
                                                                             var_list=self.true_critic_params)

                # Trainer for the shaping weight function f.
                self.f_opt = tf.train.AdamOptimizer(self.lr_f)
                if self.joint_opt:
                    # Build a new list. Extending self.f_params in place would also
                    # add the policy variables to self.f_params.
                    self.extended_f_param = list(self.f_params) + list(self.policy_params)
                    if self.f_grad_clip:
                        self.f_gradients = tf.gradients(self.f_loss, self.extended_f_param)
                        self.f_clipped_grad_op, _ = tf.clip_by_global_norm(self.f_gradients,
                                                                           self.f_grad_norm_clip)
                        self.f_trainer = self.f_opt.apply_gradients(zip(self.f_clipped_grad_op,
                                                                        self.extended_f_param))
                    else:
                        self.f_trainer = self.f_opt.minimize(self.f_loss, var_list=self.extended_f_param)
                else:
                    if self.f_grad_clip:
                        self.f_gradients = tf.gradients(self.f_loss, self.f_params)
                        self.f_clipped_grad_op, _ = tf.clip_by_global_norm(self.f_gradients,
                                                                           self.f_grad_norm_clip)
                        self.f_trainer = self.f_opt.apply_gradients(zip(self.f_clipped_grad_op,
                                                                        self.f_params))
                    else:
                        self.f_gradients = tf.gradients(self.f_loss, self.f_params)
                        self.f_trainer = self.f_opt.minimize(self.f_loss, var_list=self.f_params)

                self.compute_losses = U.function([self.state_phd, self.action_phd,
                                                  self.atarg, self.ret, self.lrmult,
                                                  self.atarg_true, self.ret_true],
                                                 self.losses,
                                                 sess=self.sess,
                                                 graph=self.graph)

    def choose_action(self, s, is_test):
        """Sample an action for the single state ``s``.

        Returns ``(action, info)``. ``info`` holds ``v_pred``, ``v_pred_true`` and
        ``f_phi_s``; ``f_phi_s`` is clipped to [f_phi_min, f_phi_max] when both
        bounds are set.

        NOTE(review): ``is_test`` is unused; the action is always sampled
        stochastically.
        """
        with self.graph.as_default():
            action, vpred, vpred_true = self.pi.act(stochastic=True, ob=s)

            # Evaluate f on the single state; a [1, 1] result is unwrapped to a scalar.
            f_phi_s = self.sess.run(self.f_phi, {self.state_phd: [s]})
            if len(f_phi_s.shape) == 2:
                f_phi_s = f_phi_s[0][0]

            if self.f_phi_min is not None and self.f_phi_max is not None:
                f_phi_s = np.minimum(self.f_phi_max, np.maximum(self.f_phi_min, f_phi_s))

            return action, {"v_pred": vpred, "v_pred_true": vpred_true, "f_phi_s": f_phi_s}

    def learn(self, **kwargs):
        """Run one update for the active phase (policy or f); see ``switch_optimization``."""
        if self.optimize_policy:
            self.update_policy(**kwargs)
        else:
            self.update_shaping_weight_func(**kwargs)

    @staticmethod
    def _standardize(adv):
        """Return ``adv`` shifted to zero mean and scaled by its standard deviation.

        The std is floored at 1e-8, so a constant (or single-sample) batch yields
        zeros instead of NaNs. For any std above the floor, the result equals
        ``(adv - mean) / std``.
        """
        adv = np.asarray(adv)
        return (adv - adv.mean()) / max(float(adv.std()), 1e-8)

    def _sync_f_old(self):
        """Copy the current f parameters into f_old; does nothing unless ``use_f_old`` is set."""
        if self.use_f_old:
            self.sess.run(self.assign_f_old_eq_f)

    def update_policy(self, **kwargs):
        """Run PPO epochs on the actor and the shaped critic.

        Keyword arguments read: ``ob``, ``ac``, ``adv`` (shaped advantages),
        ``td_lam_ret`` (shaped TD(lambda) returns) and ``f_phi_s`` (f values fed to
        ``f_phi_phd``). Assumes these are equal-length numpy arrays (TODO confirm
        against the caller).
        """
        with self.graph.as_default():
            bs = kwargs.get("ob")
            ba = kwargs.get("ac")
            batch_adv = self._standardize(kwargs.get("adv"))
            batch_td_lam_ret = kwargs.get("td_lam_ret")
            batch_f_phi_s = kwargs.get("f_phi_s")

            # PPO has no replay buffer, so the dataset is built directly from this batch.
            d = Dataset(dict(bs=bs, ba=ba, badv=batch_adv, bret=batch_td_lam_ret,
                             bf=batch_f_phi_s),
                        deterministic=self.pi.recurrent)

            batch_size = self.optim_batch_size or bs.shape[0]

            if hasattr(self.pi, "ob_rms"):
                self.pi.ob_rms.update(bs)

            # Snapshot pi (and f, if used) as the "old" reference for the ratio.
            self.assign_old_eq_new()
            self._sync_f_old()

            for _ in range(self.optim_epochs):
                for batch in d.iterate_once(batch_size):
                    feed_dict = {
                        self.state_phd: batch["bs"],
                        self.action_phd: batch["ba"],
                        self.atarg: batch["badv"],
                        self.ret: batch["bret"],
                        self.f_phi_phd: batch["bf"].reshape([-1, 1]),
                    }
                    _, _, policy_loss, v_loss, kl = self.sess.run(
                        [self.policy_trainer, self.critic_trainer, self.pol_surr, self.vf_loss, self.mean_kl],
                        feed_dict=feed_dict)

                    self.write_summary_scalar(self.update_cnt, "policy_loss", policy_loss)
                    self.write_summary_scalar(self.update_cnt, "v_loss", v_loss)
                    self.write_summary_scalar(self.update_cnt, "mean_kl", kl)
                    self.update_cnt += 1

    def update_shaping_weight_func(self, **kwargs):
        """Run epochs on the shaping weight function f and the true critic.

        Keyword arguments read: ``ob``, ``ac``, ``adv_true`` (true advantages),
        ``ret_true`` (true returns) and ``f_phi_s`` (f values fed to
        ``f_phi_phd``). Assumes these are equal-length numpy arrays (TODO confirm
        against the caller).
        """
        with self.graph.as_default():
            bs = kwargs.get("ob")
            ba = kwargs.get("ac")
            batch_adv = self._standardize(kwargs.get("adv_true"))
            batch_ret = kwargs.get("ret_true")
            batch_f_phi_s = kwargs.get("f_phi_s")

            # PPO has no replay buffer, so the dataset is built directly from this batch.
            d = Dataset(dict(bs=bs, ba=ba, badv=batch_adv, bret=batch_ret,
                             bf=batch_f_phi_s),
                        deterministic=self.pi.recurrent)

            batch_size = self.optim_batch_size or bs.shape[0]

            # The ratio inside f_loss compares pi against this pi_old snapshot
            # (and against f_old, if used).
            self.assign_old_eq_new()
            self._sync_f_old()

            for _ in range(self.optim_epochs):
                for batch in d.iterate_once(batch_size):
                    feed_dict = {
                        self.state_phd: batch["bs"],
                        self.action_phd: batch["ba"],
                        self.atarg_true: batch["badv"],
                        self.ret_true: batch["bret"],
                        self.f_phi_phd: batch["bf"].reshape([-1, 1]),
                    }
                    _, _, f_loss, v_true_loss = self.sess.run(
                        [self.f_trainer, self.true_critic_trainer, self.f_loss, self.vf_true_loss],
                        feed_dict=feed_dict)

                    self.write_summary_scalar(self.update_cnt, "f_loss", f_loss)
                    self.write_summary_scalar(self.update_cnt, "v_true_loss", v_true_loss)
                    self.update_cnt += 1

    def _build_weight_net(self, s, scope, reuse=None, custom_getter=None):
        """Build an f network on input ``s`` under variable scope ``scope`` and return its output.

        Architecture:

        * for each width in ``f_net_layers``: a dense layer, then layer norm, then
          ``f_hidden_layer_act_func``;
        * a 1-unit dense output layer with activation ``f_output_layer_act_func``;
        * plus 1 if ``net_add_one`` is set.

        The dense layers are trainable only when ``reuse`` is None.
        """
        trainable = reuse is None
        with tf.variable_scope(scope, reuse=reuse, custom_getter=custom_getter):
            net = s
            for ly_index, ly_cell_num in enumerate(self.f_net_layers):
                net = tf.layers.dense(net, ly_cell_num,
                                      kernel_initializer=tf.random_uniform_initializer(-1 / 8.0, 1 / 8.0),
                                      bias_initializer=tf.random_uniform_initializer(-1 / 8.0, 1 / 8.0),
                                      name='l' + str(ly_index), trainable=trainable)
                net = tf.contrib.layers.layer_norm(net)
                net = self.f_hidden_layer_act_func(net)

            # The output range is set by f_output_layer_act_func. The default (None)
            # is linear, i.e. unbounded. The +/-1e-3 initializers keep the
            # pre-activation output close to 0 at initialization.
            net = tf.layers.dense(net, 1, activation=self.f_output_layer_act_func,
                                  kernel_initializer=tf.random_uniform_initializer(-1 / 1000.0, 1 / 1000.0),
                                  bias_initializer=tf.random_uniform_initializer(-1 / 1000.0, 1 / 1000.0),
                                  name='f_value', trainable=trainable)

            if self.net_add_one:
                net = tf.add(net, 1)

            return net

    def _build_weight_func(self, s, reuse=None, custom_getter=None):
        """Build the shaping weight function f_phi(s) in variable scope ``Weight_Func``."""
        return self._build_weight_net(s, 'Weight_Func', reuse=reuse, custom_getter=custom_getter)

    def _build_weight_func_old(self, s, reuse=None, custom_getter=None):
        """Build the old copy of f in variable scope ``Weight_Func_Old``; it feeds pi_old."""
        return self._build_weight_net(s, 'Weight_Func_Old', reuse=reuse, custom_getter=custom_getter)

    def switch_optimization(self):
        """Toggle between the policy phase and the f phase used by ``learn``."""
        self.optimize_policy = not self.optimize_policy
