# coding=utf-8
import numpy as np
import random
import tensorflow as tf
from .ppo_oprs_v2_fsa_algo import PPOOprsV2FsaAlgo
from ..ppo_trainer import PPOTrainer

# Default hyper-parameters, overridable through kwargs in
# PPOOprsV2FsaTrainer.set_trainer_parameters.
GAMMA = 0.999                 # discount factor
LAMBDA = 0.95                 # GAE lambda
TRUNCATION_SIZE = 20000       # transitions collected per update (buffer size)
UPDATE_NUM_PER_SWITCH = 10    # updates before switching policy <-> shaping weight func

# number of distinct buffer indices sampled per policy update for the
# grad_compute_* / mini_batches inputs of the algorithm's learn()
COMPUTE_GRAD_THETA_WRT_PHI_SAMPLE_NUM = 100


# Save a model checkpoint every MODEL_UPDATE_FREQ updates. The count is the
# trainer's update counter, which is incremented once per policy update and
# once per shaping-weight-function update; it does not count episodes.
MODEL_UPDATE_FREQ = 1000


# For the PPO algorithm, the shaping weight function f must also be updated
# in on-policy mode ("fop" for short).


class PPOOprsV2FsaTrainer(PPOTrainer):
    """On-policy PPO trainer that also learns a shaping weight function f_phi.

    Transitions are written into a buffer of ``truncation_size`` entries.
    Every time ``truncation_size`` new transitions have been stored, one
    update runs on the buffer. Depending on ``self.algorithm.optimize_policy``,
    that update targets one of two things:

    * the policy, with GAE advantages computed from the shaped reward
      ``r + f_phi(s) * F`` and ``v_pred``;
    * the shaping weight function, with GAE advantages computed from the true
      reward ``r`` and ``v_pred_true``.

    The target switches every ``update_num_per_switch`` updates. When
    ``update_pi_f_simutaneously`` is set, both targets are updated on every
    full buffer instead.
    """

    def __init__(self, state_space, action_space, algo_name="ppo_oprs_v2_fsa", **kwargs):
        super(PPOOprsV2FsaTrainer, self).__init__(state_space, action_space, algo_name, **kwargs)

    def set_trainer_parameters(self, **kwargs):
        """Read trainer hyper-parameters from kwargs, falling back to the
        module-level defaults."""
        self.gamma = kwargs.get("gamma", GAMMA)
        self.lmda = kwargs.get("lmda", LAMBDA)
        self.truncation_size = kwargs.get("truncation_size", TRUNCATION_SIZE)
        self.update_num_per_switch = kwargs.get("update_num_per_switch",
                                                UPDATE_NUM_PER_SWITCH)
        self.nabla_theta_wrt_phi_sam_num = kwargs.get("nabla_theta_wrt_phi_sam_num",
                                                      COMPUTE_GRAD_THETA_WRT_PHI_SAMPLE_NUM)
        # NOTE: the (misspelled) key name is part of the public kwargs API.
        self.update_f_pi_sim = kwargs.get("update_pi_f_simutaneously", False)

    def init_algo(self, **kwargs):
        """Create the TF graph/session, the algorithm and the sample buffer."""
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = PPOOprsV2FsaAlgo(self.session, self.graph, self.state_space,
                                          self.action_space, algo_name=self.algo_name,
                                          **kwargs)

        # Buffer of transitions, indexed modulo truncation_size; one update is
        # performed each time truncation_size new transitions have been stored.
        self.exp_mini_buffer = [None] * self.truncation_size
        # Most recent transition, held back until the next one arrives.
        self.last_exp = None
        self.exp_cnt = 0           # transitions stored in the buffer so far
        self.update_cnt = 0        # total number of updates (drives checkpoints)
        self.local_update_cnt = 0  # updates since the last optimization switch
        self.my_writer = self.algorithm.train_writer

    def action(self, state, test_model):
        """Choose an action for ``state``.

        Returns ``(action, action_info)``. Per the original author's note,
        ``action_info`` carries "v_pred", "v_pred_true" and "f_phi_s".
        """
        a, action_info = self.algorithm.choose_action(state, test_model)
        return a, action_info

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """Record one transition and trigger an update when the buffer fills.

        Transitions are stored one call late. The previous transition is
        written to the buffer when the current one arrives. At update time,
        the current transition's predictions (``v_pred``, ``v_pred_true``,
        ``f_phi_s``, ``F``) are passed as the "next_*" values. These are used
        to bootstrap the last buffered step.

        Recognised kwargs: ``v_pred``, ``v_pred_true``, ``f_phi_s``, ``F``.
        """
        v_pred = kwargs.get("v_pred")
        v_pred_true = kwargs.get("v_pred_true")
        f_phi_s = kwargs.get("f_phi_s")
        F_value = kwargs.get("F")

        if self.last_exp is None:
            self.last_exp = (s, a, r, sp, terminal, v_pred, v_pred_true, f_phi_s, F_value)
        else:
            i = self.exp_cnt % self.truncation_size
            self.exp_mini_buffer[i] = self.last_exp
            self.last_exp = (s, a, r, sp, terminal, v_pred, v_pred_true, f_phi_s, F_value)

            self.exp_cnt += 1
            if self.exp_cnt % self.truncation_size == 0:
                # the buffer has been refilled: update on its contents
                self.ppo_update(next_v_pred=v_pred, next_ac=a,
                                next_v_pred_true=v_pred_true,
                                next_f_phi=f_phi_s, next_F=F_value)

    def update(self, t):
        """No-op: updates are driven from ``experience`` once the buffer fills."""
        return

    def ppo_update(self, **kwargs):
        """Dispatch one update to the policy and/or the shaping weight function.

        kwargs are forwarded unchanged to the update method(s).
        """
        if not self.update_f_pi_sim:
            if self.algorithm.optimize_policy:
                self.ppo_update_policy(**kwargs)
            else:
                self.ppo_update_shaping_weight_func(**kwargs)
        else:
            self.ppo_update_policy(**kwargs)
            self.ppo_update_shaping_weight_func(**kwargs)

    def _compute_gae(self, rewards, values, next_value, dones):
        """Return GAE(gamma, lambda) advantages for one buffered segment.

        ``values[t]`` is the value prediction for step t. ``next_value``
        bootstraps the step after the last one. ``dones[t]`` cuts both the
        bootstrap term and the advantage recursion. The loop runs from
        t = T-1 down to t = 0.
        """
        values = np.append(values, next_value)
        size = len(rewards)
        gae_lam = np.empty(size, dtype=float)
        last_gae_lam = 0
        for t in reversed(range(size)):
            non_terminal = 1 - dones[t]
            delta = rewards[t] + self.gamma * values[t + 1] * non_terminal - values[t]
            gae_lam[t] = delta + self.gamma * self.lmda * non_terminal * last_gae_lam
            last_gae_lam = gae_lam[t]
        return gae_lam

    def _sample_grad_indices(self):
        """Draw distinct, uniformly random buffer indices.

        At most ``nabla_theta_wrt_phi_sam_num`` indices are drawn. The count
        is capped at ``truncation_size``; without the cap, asking for more
        distinct indices than the buffer holds would loop forever.
        """
        sample_num = max(0, min(self.nabla_theta_wrt_phi_sam_num, self.truncation_size))
        return random.sample(range(self.truncation_size), sample_num)

    def _build_grad_batches(self, seg, sample_indices):
        """Collect the inputs used for the gradient of theta w.r.t. phi.

        Three lists are returned:

        * ``grad_s`` and ``grad_a`` hold, for each sampled index i, the
          observation and action at i.
        * ``mini_batches`` holds one entry per sampled index. Each entry is
          ``[obs, acts, steps, F]``, covering the trajectory suffix from i up
          to and including the next terminal step, or to the end of the
          buffer. ``steps`` counts 0, 1, ... from i. ``steps`` and ``F`` are
          column vectors.
        """
        last = self.truncation_size - 1
        grad_s, grad_a, mini_batches = [], [], []
        for start in sample_indices:
            grad_s.append(seg["ob"][start])
            grad_a.append(seg["ac"][start])

            end = start
            while end < last and not seg["done"][end]:
                end += 1
            stop = end + 1

            mini_batches.append([np.array(seg["ob"][start:stop]),
                                 np.array(seg["ac"][start:stop]),
                                 np.arange(stop - start).reshape([-1, 1]),
                                 np.array(seg["F"][start:stop]).reshape([-1, 1])])
        return grad_s, grad_a, mini_batches

    def _maybe_switch_optimization(self, message):
        """Switch the optimization target after an update.

        In simultaneous mode the switch happens after every update. Otherwise
        it happens once every ``update_num_per_switch`` updates, and
        ``message`` is printed when it does.
        """
        if self.update_f_pi_sim:
            self.algorithm.switch_optimization()
        elif self.local_update_cnt >= self.update_num_per_switch:
            self.local_update_cnt = 0
            self.algorithm.switch_optimization()
            print(message)

    def ppo_update_policy(self, **kwargs):
        """Run one PPO update of the policy on the buffered transitions.

        Advantages use GAE on the shaped reward ``r + f_phi(s) * F``. The
        returns passed to learn() are ``adv + v_pred``. Expects kwarg
        ``next_v_pred``, the value prediction used to bootstrap the last
        buffered step.
        """
        self.update_cnt += 1
        self.local_update_cnt += 1

        size = self.truncation_size
        obs0 = self.exp_mini_buffer[0][0]
        act0 = self.exp_mini_buffer[0][1]

        # Preallocate using the first transition as shape/dtype template, then
        # fill every step from the buffer.
        seg = {"ob": np.array([obs0 for _ in range(size)]),
               "ac": np.array([act0 for _ in range(size)]),
               "rew": np.zeros(size, dtype=float),
               "v_pred": np.zeros(size, dtype=float),
               "done": np.zeros(size, dtype=int),
               "F": np.zeros(size, dtype=float),
               "f_phi_s": np.zeros(size, dtype=float)
               }
        for t in range(size):
            s, a, r, _, done, v_pred, _, f_phi_s, F = self.exp_mini_buffer[t]
            seg["ob"][t] = s
            seg["ac"][t] = a
            seg["rew"][t] = r
            seg["done"][t] = done
            seg["v_pred"][t] = v_pred
            seg["f_phi_s"][t] = f_phi_s
            seg["F"][t] = F

        shaped_rew = seg["rew"] + seg["f_phi_s"] * seg["F"]
        seg["adv"] = self._compute_gae(shaped_rew, seg["v_pred"],
                                       kwargs.get("next_v_pred"), seg["done"])
        seg["td_lam_ret"] = seg["adv"] + seg["v_pred"]

        grad_s, grad_a, mini_batches = self._build_grad_batches(
            seg, self._sample_grad_indices())

        self.algorithm.learn(ob=seg["ob"], ac=seg["ac"],
                             adv=seg["adv"], td_lam_ret=seg["td_lam_ret"],
                             batch_f_phi=seg["f_phi_s"],
                             grad_compute_s_batch=np.array(grad_s),
                             grad_compute_a_batch=np.array(grad_a),
                             mini_batches=mini_batches
                             )

        self.save_params()
        self._maybe_switch_optimization("Begin to optimize shaping weight function")

    def ppo_update_shaping_weight_func(self, **kwargs):
        """Run one update of the shaping weight function on the buffer.

        Advantages use GAE on the true (unshaped) reward with
        ``v_pred_true``. The returns passed to learn() are
        ``adv_true + v_pred_true``. Expects kwarg ``next_v_pred_true``, used
        to bootstrap the last buffered step.
        """
        self.update_cnt += 1
        self.local_update_cnt += 1

        size = self.truncation_size
        obs0 = self.exp_mini_buffer[0][0]
        act0 = self.exp_mini_buffer[0][1]

        seg = {"ob": np.array([obs0 for _ in range(size)]),
               "ac": np.array([act0 for _ in range(size)]),
               "rew": np.zeros(size, dtype=float),
               "v_pred_true": np.zeros(size, dtype=float),
               "done": np.zeros(size, dtype=int)
               }
        for t in range(size):
            s, a, r, _, done, _, v_pred_true, _, _ = self.exp_mini_buffer[t]
            seg["ob"][t] = s
            seg["ac"][t] = a
            seg["rew"][t] = r
            seg["done"][t] = done
            seg["v_pred_true"][t] = v_pred_true

        seg["adv_true"] = self._compute_gae(seg["rew"], seg["v_pred_true"],
                                            kwargs.get("next_v_pred_true"), seg["done"])
        seg["ret_true"] = seg["adv_true"] + seg["v_pred_true"]

        self.algorithm.learn(ob=seg["ob"], ac=seg["ac"], adv_true=seg["adv_true"],
                             ret_true=seg["ret_true"])

        self.save_params()
        self._maybe_switch_optimization("Begin to optimize policy")

    def save_params(self):
        """Checkpoint the algorithm every MODEL_UPDATE_FREQ updates.

        The checkpoint goes to ``./data/<algo_name>/model/<update_cnt>.ckpt``.
        """
        if self.update_cnt % MODEL_UPDATE_FREQ == 0 and self.update_cnt > 0:
            print('model saved for update', self.update_cnt)
            save_path = './data/' + self.algo_name + '/model/{}.ckpt'.format(self.update_cnt)
            self.algorithm.saver.save(self.algorithm.sess, save_path)


