# coding=utf-8
import numpy as np
import tensorflow as tf
from .rcpo_ppo_algo import RcpoPpoAlgo

# Rendering switch; it is not referenced in the visible part of this file.
RENDER = False

# Default hyper-parameters. Each one is only a fallback: the trainer reads the
# matching keyword argument first (see RcpoPpoTrainer.set_trainer_parameters).
GAMMA = 0.999  # fallback for "gamma", the discount applied in the TD/GAE recursions
LAMBDA = 0.95  # fallback for "lmda", the lambda factor of the GAE recursion
TRUNCATION_SIZE = 20000  # fallback for "truncation_size", the experience buffer length
CONSTRAINT = 0.25  # fallback for "constraint", subtracted from every per-step penalty

"""
    save model per 1000 episodes
"""
# NOTE(review): the counter compared against this value is the trainer's
# ``update_cnt`` (number of policy updates performed), not an episode counter.
MODEL_UPDATE_FREQ = 1000

"""
    Reward Constrained Policy Optimization based on the PPO algorithm
"""

class RcpoPpoTrainer(object):
    def __init__(self, state_space, action_space, algo_name="rcpo_ppo", **kwargs):
        """Store the environment description, then configure and build the trainer.

        Args:
            state_space: description of the state space, passed on to the algorithm.
            action_space: description of the action space, passed on to the algorithm.
            algo_name: name used for the algorithm and for checkpoint paths.
            **kwargs: forwarded unchanged to both ``set_trainer_parameters``
                and ``init_algo``.
        """
        self.state_space, self.action_space = state_space, action_space
        self.algo_name = algo_name

        # Hyper-parameters are read first: init_algo sizes its experience
        # buffer from self.truncation_size.
        self.set_trainer_parameters(**kwargs)
        self.init_algo(**kwargs)

    def set_trainer_parameters(self, **kwargs):
        """Read the trainer hyper-parameters from ``kwargs``.

        Recognised keys are ``gamma``, ``lmda``, ``truncation_size`` and
        ``constraint`` (the penalty constraint). Each is stored on the
        instance under the same name; a missing key falls back to the
        corresponding module-level constant. Other keys are ignored.
        """
        fallbacks = (
            ("gamma", GAMMA),
            ("lmda", LAMBDA),
            ("truncation_size", TRUNCATION_SIZE),
            ("constraint", CONSTRAINT),
        )
        for key, fallback in fallbacks:
            setattr(self, key, kwargs.get(key, fallback))

    def init_algo(self, **kwargs):
        """Create the TensorFlow graph and session, the algorithm and the buffers.

        ``kwargs`` are forwarded unchanged to ``RcpoPpoAlgo``. The method
        relies on ``self.truncation_size`` having been set already.
        """
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = RcpoPpoAlgo(
            self.session,
            self.graph,
            self.state_space,
            self.action_space,
            algo_name=self.algo_name,
            **kwargs
        )

        # Experience storage. The main buffer has ``truncation_size`` slots;
        # an update is attempted every time that many experiences were stored.
        self.exp_mini_buffer = [None for _ in range(self.truncation_size)]
        # Overflow storage, used while waiting for a terminal experience.
        self.add_exp_mini_buffer = []
        # Most recent experience, not yet written to either buffer.
        self.last_exp = None
        self.exp_cnt = 0
        self.update_cnt = 0

        # True while the main buffer is full but its last experience is not a
        # terminal one, i.e. while experiences go to the overflow storage.
        self.adding_more_exp = False

        self.my_writer = self.algorithm.train_writer

    def action(self, state, test_model):
        """Ask the algorithm for an action in ``state``.

        Returns:
            A pair ``(action, info)`` where ``info`` is a dict holding the two
            other values returned by ``choose_action`` under the keys
            ``"v_pred"`` and ``"lagrange_multi"``.
        """
        chosen, value, multiplier = self.algorithm.choose_action(state, test_model)
        info = dict(v_pred=value, lagrange_multi=multiplier)
        return chosen, info

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """Record one transition and run a policy update at the end of a segment.

        Transitions are committed one call late: the transition passed in is
        parked in ``self.last_exp`` and the previously parked one is stored.
        As a result the ``v_pred`` of the current call, which belongs to the
        transition following the last stored one, is available as the
        bootstrap value when an update is triggered.

        Every time ``truncation_size`` transitions have been stored, an update
        runs immediately if the last stored transition is terminal. Otherwise
        further transitions are collected in ``self.add_exp_mini_buffer``
        until a terminal one has been stored, and the update runs then.

        Args:
            s, a, r, sp, terminal: state, action, reward, next state and
                terminal flag of the transition.
            **kwargs: ``v_pred``, ``lagrange_multi`` and ``penalty`` are stored
                with the transition (``None`` when a key is missing).
        """
        v_pred = kwargs.get("v_pred")
        lagrange_multi = kwargs.get("lagrange_multi")
        penalty = kwargs.get("penalty")

        # Park the new transition and commit the previously parked one.
        pending = self.last_exp
        self.last_exp = (s, a, r, sp, terminal, v_pred, lagrange_multi, penalty)
        if pending is None:
            # Very first transition: nothing to commit yet.
            return

        if self.adding_more_exp:
            # The main buffer is full; keep collecting until a terminal
            # transition has been stored.
            self.add_exp_mini_buffer.append(pending)
            if pending[4]:
                print("Add a terminal state sample, we begin to update policy")
                self.rcpo_ppo_update(next_v_pred=v_pred,
                                     next_lagrange_multi=lagrange_multi,
                                     next_penalty=penalty)
                self.adding_more_exp = False
                self.add_exp_mini_buffer.clear()
            return

        slot = self.exp_cnt % self.truncation_size
        self.exp_mini_buffer[slot] = pending
        self.exp_cnt += 1
        if self.exp_cnt % self.truncation_size != 0:
            return

        # For RCPO the segment must end on a terminal state. Test the
        # transition that was just stored rather than exp_mini_buffer[-1]:
        # the two only coincide while the list holds exactly
        # ``truncation_size`` entries.
        if pending[4]:
            print("Terminal at truncation point")
            self.rcpo_ppo_update(next_v_pred=v_pred,
                                 next_lagrange_multi=lagrange_multi,
                                 next_penalty=penalty)
        else:
            print("Begin to add more experience until terminal")
            self.adding_more_exp = True
            self.add_exp_mini_buffer.clear()

    def update(self, t):
        """Do nothing; ``t`` is ignored.

        Policy updates are triggered from ``experience`` instead.
        """
        return None

    def rcpo_ppo_update(self, **kwargs):
        """Run one RCPO-PPO policy update on the buffered experiences.

        The stored transitions (the main buffer followed by any overflow
        transitions in ``self.add_exp_mini_buffer``) are converted to arrays.
        GAE advantages are computed on the penalised reward
        ``r - lagrange_multi * penalty``, and the constrained return
        ``sum_t gamma^t * (c(s, a) - C)`` is computed with ``C`` being
        ``self.constraint``. Both recursions restart at every terminal
        transition. The results are passed to ``self.algorithm.learn``.

        Keyword Args:
            next_v_pred: value estimate for the state following the last
                stored transition; it bootstraps the last TD error. Treated
                as 0.0 when omitted or ``None``.
            next_lagrange_multi, next_penalty: accepted because ``experience``
                passes them, but not used by the computation.
        """
        bootstrap_value = kwargs.get("next_v_pred")
        if bootstrap_value is None:
            # Without this, None would end up in the value array and break
            # the TD-error arithmetic below.
            bootstrap_value = 0.0

        self.update_cnt += 1

        # Work on a fresh list. Extending self.exp_mini_buffer in place would
        # make the fixed-size buffer grow permanently, so the overflow
        # transitions of this update would be replayed in every later one.
        transitions = list(self.exp_mini_buffer)
        transitions.extend(self.add_exp_mini_buffer)
        n_steps = len(transitions)

        # The first observation/action serve as templates that fix the shape
        # and dtype of the batch arrays; every row is overwritten below.
        first_obs, first_act = transitions[0][0], transitions[0][1]
        obs = np.array([first_obs for _ in range(n_steps)])
        acts = np.array([first_act for _ in range(n_steps)])
        rewards = np.zeros(n_steps, dtype=float)
        values = np.zeros(n_steps, dtype=float)
        dones = np.zeros(n_steps, dtype=int)
        penalties = np.zeros(n_steps, dtype=float)

        for t, (s, a, r, _sp, done, v_pred, lagrange_multi, penalty) in enumerate(transitions):
            obs[t] = s
            acts[t] = a
            # For RCPO the policy is trained on the penalised reward.
            rewards[t] = r - lagrange_multi * penalty
            values[t] = v_pred
            dones[t] = done
            penalties[t] = penalty

        # One extra value so that values_ext[t + 1] exists for the last step.
        values_ext = np.append(values, bootstrap_value)

        # Backward recursions for t = T-1, ..., 0.
        advantages = np.empty(n_steps, dtype=float)
        cons_returns = np.empty(n_steps, dtype=float)
        running_adv = 0
        running_cons = 0
        for t in reversed(range(n_steps)):
            non_terminal = 1 - dones[t]
            delta = rewards[t] + self.gamma * values_ext[t + 1] * non_terminal - values_ext[t]
            running_adv = delta + self.gamma * self.lmda * non_terminal * running_adv
            advantages[t] = running_adv
            running_cons = self.gamma * running_cons * non_terminal + penalties[t] - self.constraint
            cons_returns[t] = running_cons

        self.algorithm.learn(ob=obs, ac=acts, adv=advantages,
                             td_lam_ret=advantages + values,
                             cons_ret=cons_returns)

        # Periodically checkpoint the model.
        if self.update_cnt % MODEL_UPDATE_FREQ == 0 and self.update_cnt > 0:
            self.save_params()

    def save_params(self):
        """Checkpoint the model if a save is due.

        A save is due when ``self.update_cnt`` is a positive multiple of
        ``MODEL_UPDATE_FREQ``; otherwise the call does nothing. The checkpoint
        goes to ``./data/<algo_name>/model/<update_cnt>.ckpt``.
        """
        count = self.update_cnt
        due = count % MODEL_UPDATE_FREQ == 0 and count > 0
        if not due:
            return

        print('model saved for update', count)
        checkpoint = '/model/{}.ckpt'.format(count)
        save_path = './data/' + self.algo_name + checkpoint
        self.algorithm.saver.save(self.algorithm.sess, save_path)

    def load_params(self, load_cnt):
        """Restore the checkpoint stored as ``./data/<algo_name>/model/<load_cnt>.ckpt``."""
        checkpoint = '/model/{}.ckpt'.format(load_cnt)
        load_path = './data/' + self.algo_name + checkpoint
        algo = self.algorithm
        algo.saver.restore(algo.sess, load_path)
        print("load model for update %s " % load_cnt)

    def episode_done(self, test_model):
        return
        # print("The current state explored are in [{},{}]".format(self.min_pos, self.max_pos))

    def write_summary_scalar(self, iteration, tag, value, train_info):
        """Log the scalar ``value`` under ``tag`` for step ``iteration``.

        When ``train_info`` is truthy the call is delegated to the algorithm's
        own ``write_summary_scalar``. Otherwise a ``tf.Summary`` holding the
        single scalar is written through ``self.my_writer``.
        """
        if not train_info:
            scalar = tf.Summary.Value(tag=tag, simple_value=value)
            self.my_writer.add_summary(tf.Summary(value=[scalar]), iteration)
            return

        self.algorithm.write_summary_scalar(iteration, tag, value)
