# coding=utf-8
import numpy as np
import tensorflow as tf
from .ppo_algo import PPOAlgo

# NOTE(review): not referenced in this file; presumably read by callers to
# toggle environment rendering.
RENDER = False

# Default discount factor and GAE lambda used when computing advantages.
GAMMA = 0.999
LAMBDA = 0.95
# Default number of buffered experiences consumed by one PPO update.
TRUNCATION_SIZE = 20000

# Save a model checkpoint every MODEL_UPDATE_FREQ PPO updates (not episodes):
# the counter compared against it is PPOTrainer.update_cnt.
MODEL_UPDATE_FREQ = 1000


class PPOTrainer(object):
    """On-policy PPO trainer that buffers transitions and updates with GAE.

    Transitions are fed one at a time through ``experience``. Every time
    ``truncation_size`` transitions have been buffered, advantages are computed
    with Generalized Advantage Estimation and one ``PPOAlgo.learn`` call is
    made on the whole segment.
    """

    def __init__(self, state_space, action_space, algo_name="ppo", **kwargs):
        """Create the trainer and its underlying PPO algorithm.

        Args:
            state_space: forwarded unchanged to ``PPOAlgo``.
            action_space: forwarded unchanged to ``PPOAlgo``.
            algo_name: passed to ``PPOAlgo`` and used in checkpoint paths.
            **kwargs: trainer options (``gamma``, ``lmda``,
                ``truncation_size``); all kwargs are also forwarded to
                ``PPOAlgo``.
        """
        self.state_space = state_space
        self.action_space = action_space
        self.algo_name = algo_name

        # Hyper-parameters must be set first: init_algo sizes the buffer
        # from self.truncation_size.
        self.set_trainer_parameters(**kwargs)
        self.init_algo(**kwargs)

    def set_trainer_parameters(self, **kwargs):
        """Read GAE/buffer hyper-parameters, falling back to module defaults."""
        self.gamma = kwargs.get("gamma", GAMMA)
        self.lmda = kwargs.get("lmda", LAMBDA)
        self.truncation_size = kwargs.get("truncation_size", TRUNCATION_SIZE)

    def init_algo(self, **kwargs):
        """Build a dedicated TF graph/session, the PPO algorithm and the buffer."""
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = PPOAlgo(self.session, self.graph, self.state_space,
                                 self.action_space, algo_name=self.algo_name,
                                 **kwargs)

        # Fixed-size buffer of (s, a, r, sp, terminal, v_pred) tuples; one PPO
        # update runs each time truncation_size experiences are collected.
        self.exp_mini_buffer = [None] * self.truncation_size
        # Newest experience, held back until the next one arrives so that its
        # v_pred can bootstrap the last buffered transition.
        self.last_exp = None
        self.exp_cnt = 0
        self.update_cnt = 0

        self.my_writer = self.algorithm.train_writer

    def action(self, state, test_model):
        """Return ``(action, value)`` as produced by ``PPOAlgo.choose_action``."""
        a, v = self.algorithm.choose_action(state, test_model)
        return a, v

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """Record one transition; run a PPO update when the buffer is full.

        PPO keeps no replay memory. The newest transition is held in
        ``last_exp`` while the previous one is written into the buffer. When
        the buffer fills, the held-back transition's ``v_pred`` is the value
        of the state after the last buffered transition. It is used to
        bootstrap GAE.

        Args:
            s, a, r, sp, terminal: state, action, reward, next state, done flag.
            **kwargs: ``v_pred``, the value estimate for ``s`` (presumably the
                value returned by ``action``). Required for the update.
        """
        v_pred = kwargs.get("v_pred")
        current = (s, a, r, sp, terminal, v_pred)

        if self.last_exp is None:
            self.last_exp = current
            return

        self.exp_mini_buffer[self.exp_cnt % self.truncation_size] = self.last_exp
        self.last_exp = current
        self.exp_cnt += 1
        if self.exp_cnt % self.truncation_size == 0:
            self.ppo_update(next_v_pred=v_pred)

    def update(self, t):
        """No-op: PPO updates are driven by ``experience``, not by time steps."""
        return

    def ppo_update(self, **kwargs):
        """Run one PPO update on the full experience buffer.

        Keyword Args:
            next_v_pred: value estimate of the state that follows the last
                buffered transition. It bootstraps the final GAE step and is
                masked out if that transition was terminal.
        """
        last_state_v_pred = kwargs.get("next_v_pred")
        self.update_cnt += 1

        seg = self._build_segment()
        seg["adv"] = self._compute_gae(seg["rew"], seg["done"], seg["v_pred"],
                                       last_state_v_pred, self.gamma, self.lmda)
        # TD(lambda) return used as the value-function target: A_t + V(s_t).
        seg["td_lam_ret"] = seg["adv"] + seg["v_pred"]

        self.algorithm.learn(ob=seg["ob"], ac=seg["ac"], adv=seg["adv"],
                             td_lam_ret=seg["td_lam_ret"])

        # save_params only writes on every MODEL_UPDATE_FREQ-th update.
        self.save_params()

    def _build_segment(self):
        """Stack the buffered (s, a, r, sp, terminal, v_pred) tuples into arrays.

        Observation/action dtypes are inferred from every sample rather than
        only the first one. An integer first observation therefore cannot
        truncate later float observations.
        """
        size = self.truncation_size
        states, actions, rewards, _, dones, v_preds = zip(*self.exp_mini_buffer[:size])
        return {"ob": np.array(states),
                "ac": np.array(actions),
                "rew": self._scalars(rewards, float),
                "v_pred": self._scalars(v_preds, float),
                "done": self._scalars(dones, int)}

    @staticmethod
    def _scalars(values, cast):
        """Convert scalar-like values (numbers or size-1 arrays) to a 1-D array.

        ``cast`` is ``float`` or ``int``. A ``None`` entry (e.g. a missing
        ``v_pred``) raises ``TypeError`` rather than silently becoming NaN.
        """
        return np.array([cast(np.asarray(v).item()) for v in values], dtype=cast)

    @staticmethod
    def _compute_gae(rewards, dones, v_preds, next_v_pred, gamma, lmda):
        """Return GAE(gamma, lambda) advantages for one segment.

        Args:
            rewards, dones, v_preds: per-step arrays of equal length T.
            next_v_pred: value of the state following step T-1 (bootstrap).
            gamma: discount factor.
            lmda: GAE lambda.

        Returns:
            float array of length T.
        """
        size = len(rewards)
        # T+1 values so that step t can read V(s_{t+1}).
        values = np.append(v_preds, next_v_pred)
        advantages = np.empty(size, dtype=float)
        last_gae_lam = 0
        for t in reversed(range(size)):
            # A terminal step must neither bootstrap nor inherit later advantages.
            non_terminal = 1 - dones[t]
            delta = rewards[t] + gamma * values[t + 1] * non_terminal - values[t]
            advantages[t] = delta + gamma * lmda * non_terminal * last_gae_lam
            last_gae_lam = advantages[t]
        return advantages

    def _checkpoint_path(self, cnt):
        """Checkpoint file path for update number ``cnt``."""
        return './data/' + self.algo_name + '/model/{}.ckpt'.format(cnt)

    def save_params(self):
        """Save a checkpoint, but only on every MODEL_UPDATE_FREQ-th update."""
        if self.update_cnt % MODEL_UPDATE_FREQ == 0 and self.update_cnt > 0:
            print('model saved for update', self.update_cnt)
            self.algorithm.saver.save(self.algorithm.sess,
                                      self._checkpoint_path(self.update_cnt))

    def load_params(self, load_cnt):
        """Restore the checkpoint saved at update number ``load_cnt``."""
        self.algorithm.saver.restore(self.algorithm.sess,
                                     self._checkpoint_path(load_cnt))
        print("load model for update %s " % load_cnt)

    def episode_done(self, test_model):
        """No per-episode work is required for this trainer."""
        return

    def write_summary_scalar(self, iteration, tag, value, train_info):
        """Log a scalar via the algorithm (train_info) or the trainer's writer."""
        if train_info:
            self.algorithm.write_summary_scalar(iteration, tag, value)
        else:
            self.my_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]), iteration)
