# coding=utf-8
import numpy as np
import tensorflow as tf
import random
from .rcpo_ppo_lagfunc_algo import RcpoPpoLagFuncAlgo

# Rendering flag; not read in this file (presumably consumed by the run script).
RENDER = False

# Discount factor used when computing the GAE advantages.
GAMMA = 0.999
# GAE lambda: trades bias against variance in the advantage estimate.
LAMBDA = 0.95
# Number of experiences collected before a policy update is attempted.
TRUNCATION_SIZE = 20000
# Constraint threshold subtracted from each step's penalty when building the
# Lagrange-multiplier-function training data.
CONSTRAINT = 0.25
# How many start indices are sampled per update for the Lagrange multiplier
# function (name kept with its historical spelling for compatibility).
LAG_FUNC_UPDAT_SAMPLE_NUM = 1000

# A checkpoint is saved every MODEL_UPDATE_FREQ policy updates (not episodes).
MODEL_UPDATE_FREQ = 1000

# Reward Constrained Policy Optimization (RCPO) based on PPO, where the Lagrange
# multiplier is a learned function instead of a single scalar.

class RcpoPpoLagFuncTrainer(object):
    """Trainer for Reward Constrained Policy Optimization (RCPO) on top of PPO.

    The Lagrange multiplier is not a single scalar: ``choose_action`` returns a
    multiplier value for every state it is queried on.

    Experiences are buffered until ``truncation_size`` steps have been seen. RCPO
    may only truncate a rollout at a terminal step, so if the last buffered step
    is not terminal, collection continues into an overflow buffer until a
    terminal step arrives. The combined rollout is then handed to
    ``RcpoPpoLagFuncAlgo.learn``: PPO data plus mini-batches for the multiplier
    function.
    """

    def __init__(self, state_space, action_space, algo_name="rcpo_ppo_lag_func", **kwargs):
        """Create the trainer.

        Args:
            state_space: observation space, forwarded to the algorithm.
            action_space: action space, forwarded to the algorithm.
            algo_name: algorithm name; also used in checkpoint paths.
            **kwargs: trainer hyper-parameters (see ``set_trainer_parameters``).
                The same kwargs are also forwarded to ``RcpoPpoLagFuncAlgo``.
        """
        self.state_space = state_space
        self.action_space = action_space
        self.algo_name = algo_name

        self.set_trainer_parameters(**kwargs)
        self.init_algo(**kwargs)

    def set_trainer_parameters(self, **kwargs):
        """Read trainer hyper-parameters from kwargs, falling back to module defaults.

        Recognised keys are ``gamma``, ``lmda``, ``truncation_size``,
        ``constraint`` and ``lag_func_update_sample_num``. The historical
        misspelling ``lag_func_updat_sample_num`` is still accepted and takes
        precedence if both spellings are given.

        Raises:
            ValueError: if ``truncation_size`` is not positive.
        """
        self.gamma = kwargs.get("gamma", GAMMA)
        self.lmda = kwargs.get("lmda", LAMBDA)
        self.truncation_size = kwargs.get("truncation_size", TRUNCATION_SIZE)
        if self.truncation_size <= 0:
            # The buffer index is computed modulo truncation_size.
            raise ValueError("truncation_size must be positive, got %r" % (self.truncation_size,))

        # Penalty constraint threshold.
        self.constraint = kwargs.get("constraint", CONSTRAINT)

        # How many samples are used for updating the Lagrange multiplier function.
        self.lag_func_update_sample_num = kwargs.get(
            "lag_func_updat_sample_num",
            kwargs.get("lag_func_update_sample_num", LAG_FUNC_UPDAT_SAMPLE_NUM))

    def init_algo(self, **kwargs):
        """Create the TF graph/session, the algorithm and the experience buffers."""
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = RcpoPpoLagFuncAlgo(self.session, self.graph, self.state_space,
                                            self.action_space, algo_name=self.algo_name,
                                            **kwargs)

        # Main buffer: an update is attempted every `truncation_size` experiences.
        self.exp_mini_buffer = [None] * self.truncation_size
        # Overflow buffer, filled when the main buffer is full but its last
        # experience is not terminal.
        self.add_exp_mini_buffer = []
        # The newest experience is held back one call so that the following
        # call's v_pred / lagrange_multi / penalty can be passed to the update.
        self.last_exp = None
        self.exp_cnt = 0
        self.update_cnt = 0

        # True while filling add_exp_mini_buffer (main buffer full, no terminal yet).
        self.adding_more_exp = False

        self.my_writer = self.algorithm.train_writer

    def action(self, state, test_model):
        """Choose an action for ``state``.

        Returns:
            ``(action, info)``. ``info`` holds ``v_pred`` and ``lagrange_multi``
            for ``state``; pass these back into ``experience`` as kwargs.
        """
        a, v, lagrange_multi = self.algorithm.choose_action(state, test_model)
        return a, {"v_pred": v, "lagrange_multi": lagrange_multi}

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """Record one transition and trigger an RCPO update when a rollout is ready.

        Args:
            s, a, r, sp, terminal: the transition.
            **kwargs: ``v_pred`` and ``lagrange_multi`` (as returned by
                ``action``) and ``penalty`` (the constraint cost of this step).
        """
        # PPO is on-policy: there is no replay memory, only the rollout buffers.
        v_pred = kwargs.get("v_pred")
        lagrange_multi = kwargs.get("lagrange_multi")
        penalty = kwargs.get("penalty")
        new_exp = (s, a, r, sp, terminal, v_pred, lagrange_multi, penalty)

        if self.last_exp is None:
            self.last_exp = new_exp
            return

        if not self.adding_more_exp:
            i = self.exp_cnt % self.truncation_size
            self.exp_mini_buffer[i] = self.last_exp
            self.last_exp = new_exp
            self.exp_cnt += 1

            if self.exp_cnt % self.truncation_size == 0:
                # For RCPO, truncation must happen at a terminal state.
                # Slot `i` (the last slot) holds the experience just written.
                if self.exp_mini_buffer[i][4]:
                    print("Terminal at truncation point")
                    self.rcpo_ppo_update(next_v_pred=v_pred,
                                         next_lagrange_multi=lagrange_multi,
                                         next_penalty=penalty)
                else:
                    print("Begin to add more experience until terminal")
                    self.adding_more_exp = True
                    self.add_exp_mini_buffer.clear()
        else:
            self.add_exp_mini_buffer.append(self.last_exp)
            self.last_exp = new_exp

            if self.add_exp_mini_buffer[-1][4]:
                print("Add a terminal state sample, we begin to update policy")
                self.rcpo_ppo_update(next_v_pred=v_pred,
                                     next_lagrange_multi=lagrange_multi,
                                     next_penalty=penalty)

                self.adding_more_exp = False
                self.add_exp_mini_buffer.clear()

    def update(self, t):
        """No-op: updates are triggered from ``experience`` once a rollout is ready."""
        return

    def rcpo_ppo_update(self, **kwargs):
        """Run one RCPO update on the buffered rollout.

        Keyword Args:
            next_v_pred: value prediction of the step after the last buffered
                one. It is the GAE bootstrap value and is masked out when the
                last buffered step is terminal, which ``experience`` guarantees.
                ``None`` is treated as 0.
            next_lagrange_multi, next_penalty: accepted for symmetry with
                ``experience``, but not needed by the update.
        """
        last_state_v_pred = kwargs.get("next_v_pred")

        self.update_cnt += 1

        seg = self._build_segment(self._gather_experiences())
        seg["adv"] = self._compute_gae(seg["rew"], seg["v_pred"], seg["done"],
                                       last_state_v_pred)
        seg["td_lam_ret"] = seg["adv"] + seg["v_pred"]

        mini_batches = self._build_lag_func_mini_batches(seg)

        self.algorithm.learn(ob=seg["ob"], ac=seg["ac"], adv=seg["adv"],
                             td_lam_ret=seg["td_lam_ret"],
                             mini_batches=mini_batches)

        # save_params only saves every MODEL_UPDATE_FREQ updates.
        self.save_params()

    def _gather_experiences(self):
        """Return the main buffer followed by the overflow buffer, as a NEW list.

        The original code extended self.exp_mini_buffer in place, which grew it
        past truncation_size and corrupted every later rollout.
        """
        return self.exp_mini_buffer + self.add_exp_mini_buffer

    def _build_segment(self, mini_buffer):
        """Convert a list of experience tuples into the arrays used by the update.

        Each experience is ``(s, a, r, sp, done, v_pred, lagrange_multi, penalty)``.
        """
        size = len(mini_buffer)
        seg = {"ob": np.array([exp[0] for exp in mini_buffer]),
               "ac": np.array([exp[1] for exp in mini_buffer]),
               "rew": np.zeros(size, dtype=float),
               "v_pred": np.zeros(size, dtype=float),
               "done": np.zeros(size, dtype=int),
               "penalty": np.zeros(size, dtype=float),
               "lagrange_multi": np.zeros(size, dtype=float)}

        for t, (_, _, r, _, done, v_pred, lagrange_multi, penalty) in enumerate(mini_buffer):
            # For RCPO the policy is updated on the penalized reward.
            seg["rew"][t] = r - lagrange_multi * penalty
            seg["done"][t] = done
            seg["v_pred"][t] = v_pred
            seg["penalty"][t] = penalty
            seg["lagrange_multi"][t] = lagrange_multi
        return seg

    def _compute_gae(self, rewards, v_preds, dones, last_v_pred):
        """Compute GAE advantages for t = T-1, ..., 0.

        ``last_v_pred`` is appended as the bootstrap value V(s_T); ``None`` is
        treated as 0.
        """
        v_ext = np.append(v_preds, 0.0 if last_v_pred is None else last_v_pred)
        adv = np.empty(len(rewards), dtype=float)
        last_gae_lam = 0.0
        for t in reversed(range(len(rewards))):
            non_terminal = 1 - dones[t]
            delta = rewards[t] + self.gamma * v_ext[t + 1] * non_terminal - v_ext[t]
            last_gae_lam = delta + self.gamma * self.lmda * non_terminal * last_gae_lam
            adv[t] = last_gae_lam
        return adv

    def _build_lag_func_mini_batches(self, seg):
        """Build mini-batches for updating the Lagrange multiplier function.

        Distinct start indices are sampled from the rollout, at most
        ``lag_func_update_sample_num`` and never more than the rollout length.
        Each mini-batch runs from its start index up to and including the first
        terminal step at or after it, or to the end of the rollout if there is none.

        Each mini-batch is ``[obs, actions, step_offsets (k,1),
        lagrange_multi (k,1), penalty - constraint (k,1)]``.
        """
        size = len(seg["done"])
        sample_num = min(max(int(self.lag_func_update_sample_num), 0), size)
        start_indices = random.sample(range(size), sample_num)

        # Sorted positions of terminal steps; searchsorted finds the first one >= start.
        done_indices = np.flatnonzero(seg["done"])

        mini_batches = []
        for start in start_indices:
            pos = np.searchsorted(done_indices, start)
            end = done_indices[pos] + 1 if pos < len(done_indices) else size
            mini_batches.append([
                np.array(seg["ob"][start:end]),
                np.array(seg["ac"][start:end]),
                np.arange(end - start).reshape([-1, 1]),
                np.array(seg["lagrange_multi"][start:end]).reshape([-1, 1]),
                (seg["penalty"][start:end] - self.constraint).reshape([-1, 1]),
            ])
        return mini_batches

    def save_params(self):
        """Save a checkpoint when update_cnt is a positive multiple of MODEL_UPDATE_FREQ."""
        if self.update_cnt % MODEL_UPDATE_FREQ == 0 and self.update_cnt > 0:
            import os

            print('model saved for update', self.update_cnt)
            save_path = './data/' + self.algo_name + '/model/{}.ckpt'.format(self.update_cnt)
            # TF1's Saver.save raises if the checkpoint's parent directory is missing.
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            self.algorithm.saver.save(self.algorithm.sess, save_path)

    def load_params(self, load_cnt):
        """Restore the checkpoint saved at update ``load_cnt``."""
        load_path = './data/' + self.algo_name + '/model/{}.ckpt'.format(load_cnt)
        self.algorithm.saver.restore(self.algorithm.sess, load_path)
        print("load model for update %s " % load_cnt)

    def episode_done(self, test_model):
        """Episode-end hook; this trainer has nothing to do here."""
        return

    def write_summary_scalar(self, iteration, tag, value, train_info):
        """Log a scalar summary.

        When ``train_info`` is truthy, the algorithm's own writer method is used;
        otherwise the value is written through ``self.my_writer``.
        """
        if train_info:
            self.algorithm.write_summary_scalar(iteration, tag, value)
        else:
            self.my_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]), iteration)
