# coding=utf-8
from .dqn_trainer import DqnTrainer

# --- Discounting -----------------------------------------------------------
# Reward discount.  In this file it weights the next-state potential in the
# shaping term F = GAMMA * phi(s') - phi(s).
# NOTE(review): PBRS keeps the optimal policy unchanged only when this equals
# the discount used by the learner's TD target -- confirm against DqnTrainer.
GAMMA = 0.9

# --- Target network --------------------------------------------------------
# Soft replacement rate for the target network.
TAU = 0.01
# High-frequency target soft updates work better for DQN, hence a frequency
# of 1 (presumably "every update step" -- confirm with the consumer).
TARGET_UPDATE_FREQ = 1

# --- Replay buffer / training schedule -------------------------------------
# None of these are referenced in this file; presumably they are read by the
# training loop elsewhere -- check consumers before changing.
BATCH_SIZE = 1024
REPLAY_BUFFER_SIZE = 10 ** 6
FIRST_UPDATE_SAMPLE_NUM = 25600
UPDATE_FREQ = 100
MODEL_UPDATE_FREQ = 1000

# --- Misc ------------------------------------------------------------------
RENDER = False


class DqnPbrsTrainer(DqnTrainer):
    """DqnTrainer variant that stores potential-based shaped rewards.

    Only experience storage differs from the parent: each transition is
    stored with ``r_n + F`` in place of ``r_n``, where
    ``F = GAMMA * phi(s_n) - phi(s)`` (potential-based reward shaping,
    Ng, Harada & Russell, 1999).
    """

    def __init__(self, state_dim, action_num, algo_name="dqn_pbrs"):
        super(DqnPbrsTrainer, self).__init__(state_dim, action_num, algo_name)

    def experience(self, s, a_n, r_n, s_n, terminal, **kwargs):
        """Store a transition whose reward includes the PBRS shaping term.

        The stored reward is ``r_n + GAMMA * phi_sp * (1 - terminal) - phi_s``.
        The potential of a terminal next state is therefore treated as 0.

        Args:
            s, a_n, r_n, s_n, terminal: transition fields. They are passed to
                ``self.memory.add`` as one tuple in this order, with only
                ``r_n`` replaced by the shaped reward.
            phi_s (keyword, required): potential of ``s``.
            phi_sp (keyword, required): potential of ``s_n``.

        Raises:
            TypeError: if ``phi_s`` or ``phi_sp`` is missing or None.
        """
        phi_s = kwargs.get("phi_s")
        phi_sp = kwargs.get("phi_sp")
        if phi_s is None or phi_sp is None:
            # Same exception type the arithmetic below would raise on None,
            # but with a message that names the actual problem.
            raise TypeError(
                "experience() requires 'phi_s' and 'phi_sp' keyword arguments "
                "for potential-based reward shaping")
        # Potential of a terminal next state must be 0.  In standard DQN the
        # learner does not bootstrap from a terminal s_n (presumably
        # DqnTrainer does the same -- confirm).  A nonzero GAMMA * phi_sp there
        # would not telescope away and could change which policy is optimal.
        # A multiplicative mask (rather than `if terminal:`) also handles
        # numpy bool/float arrays.
        # Assumes `terminal` is a bool / 0-1 number or an array of them.
        # NOTE(review): GAMMA should equal the learner's TD discount for PBRS
        # policy invariance -- confirm against DqnTrainer.
        f_ssp = GAMMA * phi_sp * (1.0 - terminal) - phi_s
        self.memory.add((s, a_n, r_n + f_ssp, s_n, terminal))


