# coding=utf-8
import numpy as np
import tensorflow as tf
from experiments.utils.memory import ReplayMemory
from experiments.algorithms.ddpg.ddpg_trainer import DDPGTrainer
from experiments.algorithms.ddpg.ddpg_dpba.ddpg_dpba_algo import DDPGDpbaAlgo

# Rendering toggle; not referenced in this chunk (presumably read by a runner
# elsewhere -- verify before changing).
RENDER = False

# The replay buffer must hold MORE than this many transitions before
# DDPGDpbaTrainer.update performs any learning step.
FIRST_UPDATE_SAMPLE_NUM = 25600
# Not referenced in this chunk; presumably a model update/save period used
# elsewhere -- TODO confirm.
MODEL_UPDATE_FREQ = 10000


class DDPGDpbaTrainer(DDPGTrainer):
    """DDPG trainer for the ``ddpg_dpba`` algorithm.

    Differs from the plain DDPG trainer by storing an extra per-transition
    value ``c`` in the replay buffer and forwarding the sampled values to
    ``DDPGDpbaAlgo.learn`` as ``c_batch``.
    """

    def __init__(self, state_dim, action_dim, algo_name="ddpg_dpba", **kwargs):
        """Create the trainer and its underlying DPBA algorithm.

        Args:
            state_dim: State dimensionality, forwarded to the algorithm.
            action_dim: Action dimensionality, forwarded to the algorithm.
            algo_name: Name passed to ``DDPGDpbaAlgo``.
            **kwargs: Passed both to ``set_trainer_parameters`` and to the
                algorithm constructor.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.algo_name = algo_name

        # Set trainer hyper-parameters (inherited helper; presumably defines
        # replay_buffer_size, batch_size and update_freq used below -- verify
        # in DDPGTrainer).
        self.set_trainer_parameters(**kwargs)

        self.init_algo(**kwargs)

    def init_algo(self, **kwargs):
        """Build the TF graph/session, the DPBA algorithm and the replay buffer.

        Args:
            **kwargs: Extra keyword arguments forwarded to ``DDPGDpbaAlgo``.
        """
        # Each call creates a fresh graph and a session bound to that graph.
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = DDPGDpbaAlgo(self.session, self.graph, self.state_dim, self.action_dim,
                                      algo_name=self.algo_name, **kwargs)

        self.update_cnt = 0

        # Replay buffer of (s, a, r, sp, terminal, c) tuples; see experience().
        self.memory = ReplayMemory(self.replay_buffer_size)

        # Reuse the algorithm's train writer for logging additional information.
        self.my_writer = self.algorithm.train_writer

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """Store one transition in the replay buffer.

        Args:
            s: Current state.
            a: Action taken.
            r: Reward received.
            sp: Next state.
            terminal: Terminal flag of the transition.
            **kwargs: May contain ``c``, the extra per-transition value later
                fed to the algorithm as ``c_batch`` (presumably the DPBA
                shaping signal -- confirm against DDPGDpbaAlgo).

        NOTE(review): when ``c`` is not supplied, ``None`` is stored, and
        update() will then build an object-dtype ``c_batch`` array.
        """
        c = kwargs.get("c")
        self.memory.add((s, a, r, sp, terminal, c))

    def update(self, t):
        """Run one learning step at time step ``t`` if the schedule allows it.

        Nothing happens until the replay buffer holds more than
        ``FIRST_UPDATE_SAMPLE_NUM`` transitions. After that, a step is taken
        only when ``t`` is a multiple of ``self.update_freq``. Each step
        samples ``self.batch_size`` transitions, calls
        ``self.algorithm.learn`` and then ``self.save_params()``.

        Args:
            t: Current time step, used for the update-frequency check.
        """
        # Warm-up: wait until the buffer is large enough.
        if len(self.memory.store) <= FIRST_UPDATE_SAMPLE_NUM:
            return
        # Update frequency.
        if t % self.update_freq != 0:
            return

        self.update_cnt += 1

        # Sample a minibatch and split the stored 6-tuples
        # (s, a, r, sp, terminal, c) into one array per field.
        sample = self.memory.get_minibatch(self.batch_size)
        s_batch, a_batch, r_batch, sp_batch, done_batch, c_batch = (
            self._field_array(sample, idx) for idx in range(6))

        # Per-transition scalars (reward, done flag, c) are fed as column vectors.
        self.algorithm.learn(s_batch, a_batch,
                             r_batch.reshape([-1, 1]),
                             sp_batch, done_batch.reshape([-1, 1]),
                             c_batch=c_batch.reshape([-1, 1]))

        # Save parameters after every learning step.
        self.save_params()

    @staticmethod
    def _field_array(sample, idx):
        """Return field ``idx`` of every transition in ``sample`` as one numpy array."""
        return np.array([transition[idx] for transition in sample])
