import random
import numpy as np
import tensorflow as tf
from ...utils.mlp_policy import MlpPolicy
import experiments.utils.tf_util as U
from experiments.utils.common import Dataset, zipsame
from gym import spaces
# Rendering toggle (presumably read by rollout code elsewhere -- confirm).
RENDER = False

# ---------------------------------------------------------------------------
# Default algorithm hyper-parameters. Each one can be overridden per instance
# through the matching keyword argument handled by
# PPOAlgo.set_algo_parameters (e.g. LR_ACTOR <-> lr_actor).
# ---------------------------------------------------------------------------

# Adam learning rates for the policy (actor) and value-function (critic) trainers.
LR_ACTOR = 1e-4
LR_CRITIC = 2e-4

# Discount factor, stored on the instance as PPOAlgo.gamma.
GAMMA = 0.999

# Global-norm thresholds passed to tf.clip_by_global_norm for each trainer.
ACTOR_GRADIENT_NORM_CLIP = 1.0
CRITIC_GRADIENT_NORM_CLIP = 1.0

# Weight of the entropy bonus; the policy loss adds -ENTROPY_COEFF * mean entropy.
ENTROPY_COEFF = 0.0

# PPO clip range: the probability ratio is clipped to [1 - eps, 1 + eps].
RATIO_CLIP_PARAM = 0.2

# Adam epsilon, stored on the instance as PPOAlgo.adam_epsilon.
ADAM_EPSILON = 1e-5

# Number of passes over each rollout per learn() call, and the minibatch size
# (a falsy value means "use the whole rollout as one batch").
OPTIM_EPOCHS = 50
OPTIM_BATCH_SIZE = 1024


class PPOAlgo(object):
    """Proximal Policy Optimization (clipped surrogate) on a TF1 graph.

    Builds two ``MlpPolicy`` networks: ``pi`` is optimized, and ``pi_old`` is
    the snapshot used for the probability ratio. Policy and critic variables
    are trained by separate Adam optimizers.

    Args:
        sess: ``tf.Session`` used to run every op.
        graph: ``tf.Graph`` the networks and losses are built into.
        state_space: gym observation space passed to ``MlpPolicy``.
        action_space: gym action space. ``spaces.Discrete`` selects the
            discrete branch; anything else is treated as continuous with
            ``shape[0]`` action dimensions.
        algo_name: used to build the summary directory
            ``./data/<algo_name>/summary/``.
        **kwargs: hyper-parameter overrides, see :meth:`set_algo_parameters`.
    """

    # Added to the advantage std in learn() so that a zero-variance batch
    # (a single sample, or identical advantages) does not divide by zero
    # and push NaNs into the gradients.
    _ADV_STD_EPS = 1e-8

    def __init__(self, sess, graph, state_space, action_space, algo_name="ppo",
                 **kwargs):
        self.pointer = 0
        self.sess = sess
        self.graph = graph
        self.algo_name = algo_name
        self.action_space, self.state_space = action_space, state_space

        if isinstance(action_space, spaces.Discrete):
            self.is_discrete = True
            self.action_dim = self.action_space.n
            print("Discrete Action Space, action num is {}".format(self.action_dim))
        else:
            self.is_discrete = False
            self.action_dim = self.action_space.shape[0]
            print("Continuous Action Space, action num is {}".format(self.action_dim))

        # Hyper-parameters must be set before the graph is built, because
        # network and optimizer construction read them.
        self.set_algo_parameters(**kwargs)

        self.init_networks()

        self.update_cnt = 0  # minibatch-update counter, used as the summary step
        self.train_writer = tf.summary.FileWriter("./data/" + self.algo_name + "/summary/", self.sess.graph)

        # Initialize every global variable in the graph. This includes the
        # Adam slot variables created in init_ppo_networks.
        with self.sess.as_default():
            with self.graph.as_default():
                tf.global_variables_initializer().run()

    def set_algo_parameters(self, **kwargs):
        """Store hyper-parameters on the instance, falling back to module defaults.

        Recognized keys: ``gamma``, ``lr_actor``, ``lr_critic``,
        ``actor_gradient_clip`` / ``critic_gradient_clip`` (bool, default
        True), ``actor_gradient_norm_clip``, ``critic_gradient_norm_clip``,
        ``entropy_coeff``, ``ratio_clip_param``, ``adam_epsilon``,
        ``optim_epochs``, ``optim_batch_size``, ``policy_net_layers``
        (default ``[8, 8]``), ``v_net_layers`` (default ``[32, 32]``) and
        ``gaussian_fixed_var`` (default False). Unknown keys are ignored.
        """
        self.gamma = kwargs.get("gamma", GAMMA)
        self.lr_actor = kwargs.get("lr_actor", LR_ACTOR)
        self.lr_critic = kwargs.get("lr_critic", LR_CRITIC)
        self.actor_grad_clip = kwargs.get("actor_gradient_clip", True)
        self.critic_grad_clip = kwargs.get("critic_gradient_clip", True)
        self.actor_grad_norm_clip = kwargs.get("actor_gradient_norm_clip", ACTOR_GRADIENT_NORM_CLIP)
        self.critic_grad_norm_clip = kwargs.get("critic_gradient_norm_clip", CRITIC_GRADIENT_NORM_CLIP)
        self.entropy_coeff = kwargs.get("entropy_coeff", ENTROPY_COEFF)
        self.ratio_clip_param = kwargs.get("ratio_clip_param", RATIO_CLIP_PARAM)
        self.adam_epsilon = kwargs.get("adam_epsilon", ADAM_EPSILON)
        self.optim_epochs = kwargs.get("optim_epochs", OPTIM_EPOCHS)
        self.optim_batch_size = kwargs.get("optim_batch_size", OPTIM_BATCH_SIZE)
        self.policy_net_layers = kwargs.get("policy_net_layers", [8, 8])
        self.v_net_layers = kwargs.get("v_net_layers", [32, 32])
        self.gaussian_fixed_var = kwargs.get("gaussian_fixed_var", False)

    def policy_fn(self, name):
        """Build an ``MlpPolicy`` under variable scope ``name`` using the stored layer config."""
        return MlpPolicy(self.sess, self.graph, name=name, ob_space=self.state_space,
                         ac_space=self.action_space, policy_net_layers=self.policy_net_layers,
                         v_net_layers=self.v_net_layers, gaussian_fixed_var=self.gaussian_fixed_var)

    def init_networks(self):
        """Build the PPO graph, then a ``tf.train.Saver`` keeping up to 100 checkpoints."""
        self.init_ppo_networks()

        with self.sess.as_default():
            with self.graph.as_default():
                self.saver = tf.train.Saver(max_to_keep=100)

    def init_ppo_networks(self):
        """Build both policies, the PPO losses, the pi -> pi_old sync op and both trainers."""
        with self.sess.as_default():
            with self.graph.as_default():
                self.pi = self.policy_fn("pi")  # policy being optimized
                self.pi_old = self.policy_fn("oldpi")  # snapshot used in the ratio

                self.atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # target advantage
                self.ret = tf.placeholder(dtype=tf.float32, shape=[None])  # empirical return
                # Learning-rate multiplier. It is an input of compute_losses,
                # but no op built here reads it, so it currently has no
                # effect on training.
                self.lrmult = tf.placeholder(name='lrmult', dtype=tf.float32,
                                             shape=[])

                # Presumably the observation placeholder that MlpPolicy
                # registers under the name "ob" -- confirm in tf_util.
                self.state_phd = U.get_placeholder_cached(name="ob")
                self.action_phd = self.pi.pdtype.sample_placeholder([None])

                # Diagnostics (KL, entropy) and the entropy penalty term.
                self.kl_old_new = self.pi_old.pd.kl(self.pi.pd)
                self.ent = self.pi.pd.entropy()
                self.mean_kl = tf.reduce_mean(self.kl_old_new)
                self.mean_ent = tf.reduce_mean(self.ent)
                self.pol_ent_pen = (-self.entropy_coeff) * self.mean_ent

                # Probability ratio pi / pi_old. The log-ratio is clamped to
                # [-40, 40] so that tf.exp cannot overflow to inf.
                self.ratio = tf.exp(tf.minimum(40.0, tf.maximum(-40.0,
                                                                self.pi.pd.logp(self.action_phd) -
                                                                self.pi_old.pd.logp(self.action_phd))))

                # Unclipped surrogate; any inf/NaN product is replaced by 0.
                unclipped_surr = self.ratio * self.atarg
                self.surr1 = tf.where(tf.logical_or(tf.is_inf(unclipped_surr),
                                                    tf.is_nan(unclipped_surr)),
                                      tf.zeros_like(self.ratio),
                                      unclipped_surr)

                # Clipped surrogate, then PPO's pessimistic objective (L^CLIP).
                self.surr2 = tf.clip_by_value(self.ratio, 1.0 - self.ratio_clip_param,
                                              1.0 + self.ratio_clip_param) * self.atarg
                self.pol_surr = - tf.reduce_mean(tf.minimum(self.surr1, self.surr2))
                self.vf_loss = tf.reduce_mean(tf.square(self.pi.vpred - self.ret))
                self.total_loss = self.pol_surr + self.pol_ent_pen + self.vf_loss
                self.losses = [self.pol_surr, self.pol_ent_pen, self.vf_loss, self.mean_kl, self.mean_ent]
                self.loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

                self.policy_params = self.pi.get_policy_variables()
                self.critic_params = self.pi.get_critic_variables()

                # Copy every pi variable into the matching pi_old variable.
                self.assign_old_eq_new = U.function([], [],
                                                    updates=[tf.assign(oldv, newv) for (oldv, newv) in
                                                             zipsame(self.pi_old.get_variables(),
                                                                     self.pi.get_variables())],
                                                    sess=self.sess,
                                                    graph=self.graph)

                # Policy trainer: minimizes total_loss w.r.t. the policy
                # variables only. The vf_loss term only matters if policy and
                # critic variables overlap -- NOTE(review): confirm intent.
                # adam_epsilon is passed explicitly; it was previously ignored.
                self.policy_opt = tf.train.AdamOptimizer(self.lr_actor, epsilon=self.adam_epsilon)
                if self.actor_grad_clip:
                    self.policy_gradients = tf.gradients(self.total_loss, self.policy_params)
                    self.policy_clipped_grad_op, _ = tf.clip_by_global_norm(self.policy_gradients,
                                                                            self.actor_grad_norm_clip)
                    self.policy_trainer = self.policy_opt.apply_gradients(zip(self.policy_clipped_grad_op,
                                                                              self.policy_params))
                else:
                    self.policy_trainer = self.policy_opt.minimize(loss=self.total_loss, var_list=self.policy_params)

                # Critic trainer: minimizes vf_loss w.r.t. the critic variables.
                self.critic_opt = tf.train.AdamOptimizer(self.lr_critic, epsilon=self.adam_epsilon)
                if self.critic_grad_clip:
                    self.critic_gradients = tf.gradients(self.vf_loss, self.critic_params)
                    self.critic_clipped_grad_op, _ = tf.clip_by_global_norm(self.critic_gradients,
                                                                            self.critic_grad_norm_clip)
                    self.critic_trainer = self.critic_opt.apply_gradients(zip(self.critic_clipped_grad_op,
                                                                              self.critic_params))
                else:
                    self.critic_trainer = self.critic_opt.minimize(self.vf_loss, var_list=self.critic_params)

                # Evaluate self.losses (ordered as self.loss_names) without training.
                self.compute_losses = U.function([self.state_phd, self.action_phd,
                                                  self.atarg, self.ret, self.lrmult],
                                                 self.losses,
                                                 sess=self.sess,
                                                 graph=self.graph)

    def choose_action(self, s, is_test):
        """Sample an action for observation ``s`` from ``pi``.

        Returns:
            ``(action, vpred)`` exactly as returned by ``self.pi.act``.
        """
        # NOTE(review): ``is_test`` is ignored -- actions are always sampled
        # stochastically, including during evaluation. Confirm this is intended.
        with self.graph.as_default():
            action, vpred = self.pi.act(stochastic=True, ob=s)
            return action, vpred

    def learn(self, **kwargs):
        """Run ``optim_epochs`` passes of minibatch PPO updates over one rollout.

        Keyword Args:
            ob: observations. ``.shape[0]`` is read when ``optim_batch_size``
                is falsy, so an array with a leading batch axis is expected.
            ac: actions taken.
            adv: advantage estimates (array supporting ``.mean()`` /
                ``.std()``); standardized here before use.
            td_lam_ret: TD(lambda) returns, used as the value-function target.

        Side effects: updates ``pi.ob_rms`` if present, copies ``pi`` into
        ``pi_old``, runs both trainers once per minibatch, and writes
        ``policy_loss`` / ``v_loss`` / ``mean_kl`` summaries per minibatch.
        """
        with self.graph.as_default():
            bs = kwargs.get("ob")
            ba = kwargs.get("ac")
            batch_adv = kwargs.get("adv")
            batch_td_lam_ret = kwargs.get("td_lam_ret")

            # Standardize advantages. The epsilon guards against a zero std,
            # whose NaNs would otherwise survive gradient clipping and
            # corrupt every parameter.
            batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + self._ADV_STD_EPS)

            # PPO has no replay buffer, so the dataset is built from this
            # rollout only. deterministic=pi.recurrent presumably disables
            # shuffling for recurrent policies -- confirm in Dataset.
            d = Dataset(dict(ob=bs, ac=ba, atarg=batch_adv, vtarg=batch_td_lam_ret),
                        deterministic=self.pi.recurrent)

            batch_size = self.optim_batch_size or bs.shape[0]

            if hasattr(self.pi, "ob_rms"):
                self.pi.ob_rms.update(bs)

            # Snapshot the current policy as the "old" policy for the ratio.
            self.assign_old_eq_new()

            for _ in range(self.optim_epochs):
                for batch in d.iterate_once(batch_size):
                    _, _, policy_loss, v_loss, kl = self.sess.run(
                        [self.policy_trainer, self.critic_trainer,
                         self.pol_surr, self.vf_loss, self.mean_kl],
                        feed_dict={
                            self.state_phd: batch["ob"],
                            self.action_phd: batch["ac"],
                            self.atarg: batch["atarg"],
                            self.ret: batch["vtarg"],
                        })

                    self.write_summary_scalar(self.update_cnt, "policy_loss", policy_loss)
                    self.write_summary_scalar(self.update_cnt, "v_loss", v_loss)
                    self.write_summary_scalar(self.update_cnt, "mean_kl", kl)
                    self.update_cnt += 1

    def write_summary_scalar(self, iteration, tag, value):
        """Write one scalar ``value`` under ``tag`` at step ``iteration`` to the summary writer."""
        self.train_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]), iteration)
