# coding=utf-8
import numpy as np
import tensorflow as tf
from experiments.utils.memory import ReplayMemory
from experiments.algorithms.ddpg.ddpg_oprs_v1_fop.ddpg_oprs_v1_fop_algo import DDPGOprsV1FopAlgo
from experiments.algorithms.ddpg.ddpg_trainer import DDPGTrainer

# ---- replay memory / mini-batch defaults (overridable through trainer kwargs) ----
# capacity of the replay memory used for policy optimization
REPLAY_BUFFER_SIZE = 10 ** 6
# number of transitions requested per policy mini-batch
BATCH_SIZE = 1024
# a policy update is only attempted at steps t that are multiples of this value
UPDATE_FREQ = 4

# ---- alternation between policy (actor) and shaping weight function f ----
# number of policy updates after which the optimization is switched
UPDATE_NUM_PER_SWITCH_A_TO_F = 1000

# number of transitions collected for each single update of the shaping
# weight function f (also the capacity of its dedicated memory)
TRUNCATION_SIZE = 1000

# number of updates of f after which the optimization is switched
UPDATE_NUM_PER_SWITCH_F_TO_A = 4

# ---- warm-up and checkpointing ----
# the policy memory must hold more than this many transitions before the
# first policy update is performed
FIRST_UPDATE_POLICY_SAMPLE_NUM = 25600
# checkpoint period, counted in policy updates
MODEL_UPDATE_FREQ = 10 ** 4
# checkpoint period, counted in updates of the shaping weight function
SHAPING_WEIGHT_FUNC_MODEL_UPDATE_FREQ = 200


class DDPGOprsV1FopTrainer(DDPGTrainer):
    """Trainer for DDPG with optimization of parameterized reward shaping (OPRS) v1.

    This variant directly relates the shaping weight function f_phi to the
    policy. "fop" means that f_phi is updated with an on-policy strategy.

    Training alternates between two phases, selected by
    ``self.algorithm.optimize_policy``:

    * policy phase -- transitions are stored in ``policy_memory`` and
      mini-batch updates are triggered from ``update``;
    * shaping-weight phase -- transitions are stored in ``weight_func_memory``;
      each time it holds ``truncation_size`` transitions they are used for one
      update of f_phi and then discarded.

    After a configurable number of updates in a phase,
    ``self.algorithm.switch_optimization()`` is called (presumably toggling
    ``optimize_policy`` -- confirm against the algorithm class).
    """

    def __init__(self, state_dim, action_dim, algo_name="ddpg_oprs_v1_fop", **kwargs):
        """Build the trainer; all real initialisation is delegated to the base class."""
        # Assigned before the base constructor runs so the attribute already
        # exists during whatever set-up the base class performs.
        self.first_switch_ocurred = False
        super(DDPGOprsV1FopTrainer, self).__init__(state_dim, action_dim, algo_name, **kwargs)

    def set_trainer_parameters(self, **kwargs):
        """Read the trainer hyper-parameters from ``kwargs``.

        Each value falls back to the corresponding module-level default when
        its key is absent.
        """
        self.batch_size = kwargs.get("batch_size", BATCH_SIZE)
        self.replay_buffer_size = kwargs.get("replay_buffer_size", REPLAY_BUFFER_SIZE)
        self.update_freq = kwargs.get("update_freq", UPDATE_FREQ)
        self.truncation_size = kwargs.get("truncation_size", TRUNCATION_SIZE)
        self.update_num_per_switch_atof = kwargs.get("update_num_per_switch_atof", UPDATE_NUM_PER_SWITCH_A_TO_F)
        self.update_num_per_switch_ftoa = kwargs.get("update_num_per_switch_ftoa", UPDATE_NUM_PER_SWITCH_F_TO_A)

    def init_algo(self, **kwargs):
        """Create the TF graph and session, the algorithm, the counters and the memories."""
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = DDPGOprsV1FopAlgo(self.session, self.graph, self.state_dim, self.action_dim,
                                           algo_name=self.algo_name, **kwargs)

        # update_cnt: number of policy updates performed so far
        # f_update_cnt: number of shaping-weight-function updates performed so far
        # local_update_cnt: number of updates since the last switch (shared by both phases)
        self.update_cnt = 0
        self.f_update_cnt = 0
        self.local_update_cnt = 0

        # separate memories for optimizing the policy and the shaping weight function
        self.policy_memory = ReplayMemory(self.replay_buffer_size)
        self.weight_func_memory = ReplayMemory(self.truncation_size)

        # reuse the algorithm's tf file writer for writing other information
        self.my_writer = self.algorithm.train_writer

    def action(self, state, test_model):
        """Return ``(action, action_info)`` as chosen by the algorithm for ``state``."""
        a, action_info = self.algorithm.choose_action(state, test_model)
        return a, action_info

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """Store one transition in the memory of the phase currently being optimized.

        ``kwargs`` may carry the shaping weight ``f_phi_s`` (f_phi(s)) and the
        additional reward ``F`` (F(s, a, s')); a missing key is stored as
        ``None``. In the shaping-weight phase ``F`` is not stored, and an
        update of f_phi is run as soon as ``truncation_size`` transitions
        have been collected.
        """
        f_phi_s = kwargs.get("f_phi_s")
        F_value = kwargs.get("F")

        if self.algorithm.optimize_policy:
            self.policy_memory.add((s, a, r, sp, terminal, f_phi_s, F_value))
        else:
            self.weight_func_memory.add((s, a, r, sp, terminal, f_phi_s))
            if len(self.weight_func_memory.store) >= self.truncation_size:
                self.update_weight_func_parameters()
                # the collected transitions are used exactly once
                self.weight_func_memory.store.clear()

    def update(self, t):
        """Run a policy update at step ``t`` if the policy is being optimized.

        Updates of the shaping weight function are not triggered here; they
        are driven by ``experience``.
        """
        if self.algorithm.optimize_policy:
            self.update_policy_parameters(t)

    @staticmethod
    def _split_fields(sample, num_fields):
        """Regroup a sequence of transition tuples into ``num_fields`` per-field lists."""
        return [[transition[k] for transition in sample] for k in range(num_fields)]

    @staticmethod
    def _as_column(values):
        """Convert a list of per-transition values into an array reshaped to (-1, 1)."""
        return np.array(values).reshape([-1, 1])

    def update_policy_parameters(self, t):
        """Perform one mini-batch policy update when the warm-up and frequency gates pass.

        Nothing happens unless the policy memory holds more than
        ``FIRST_UPDATE_POLICY_SAMPLE_NUM`` transitions and ``t`` is a multiple
        of ``update_freq``.
        """
        if len(self.policy_memory.store) <= FIRST_UPDATE_POLICY_SAMPLE_NUM:
            return
        if t % self.update_freq != 0:
            return

        self.update_cnt += 1
        self.local_update_cnt += 1

        # draw a mini-batch and regroup it field by field
        sample = self.policy_memory.get_minibatch(self.batch_size)
        states, actions, rewards, next_states, dones, f_phis, shaping_rewards = \
            self._split_fields(sample, 7)

        self.algorithm.learn(np.array(states), np.array(actions),
                             self._as_column(rewards),
                             np.array(next_states), self._as_column(dones),
                             f_phi_s=self._as_column(f_phis),
                             F=self._as_column(shaping_rewards))

        self.save_params()

        # switch optimization once enough policy updates were done in this phase
        if self.local_update_cnt >= self.update_num_per_switch_atof:
            self.local_update_cnt = 0
            self.algorithm.switch_optimization()

    def update_weight_func_parameters(self):
        """Perform one update of the shaping weight function.

        Requests ``truncation_size`` transitions from ``weight_func_memory``
        (the caller invokes this once that many have been collected), then
        counts the update and switches the optimization after
        ``update_num_per_switch_ftoa`` updates.
        """
        sample = self.weight_func_memory.get_minibatch(self.truncation_size)
        states, actions, rewards, next_states, dones, f_phis = \
            self._split_fields(sample, 6)

        self.algorithm.learn(np.array(states), np.array(actions),
                             self._as_column(rewards),
                             np.array(next_states), self._as_column(dones),
                             f_phi_s=self._as_column(f_phis))

        # called before the counters are incremented, so the checkpoint test
        # sees the number of updates completed before this one
        self.save_params()

        # switch optimization once enough updates of f were done in this phase
        self.local_update_cnt += 1
        self.f_update_cnt += 1
        if self.local_update_cnt >= self.update_num_per_switch_ftoa:
            self.local_update_cnt = 0
            self.algorithm.switch_optimization()

    def save_params(self):
        """Write a checkpoint when the counter of the active phase hits its period.

        The policy phase uses ``update_cnt`` with ``MODEL_UPDATE_FREQ``; the
        shaping-weight phase uses ``f_update_cnt`` with
        ``SHAPING_WEIGHT_FUNC_MODEL_UPDATE_FREQ``. A counter of zero never
        triggers a save.
        """
        if self.algorithm.optimize_policy:
            counter, period = self.update_cnt, MODEL_UPDATE_FREQ
        else:
            counter, period = self.f_update_cnt, SHAPING_WEIGHT_FUNC_MODEL_UPDATE_FREQ

        if counter > 0 and counter % period == 0:
            # NOTE(review): the message and the file name use update_cnt in
            # both phases, so checkpoints written during the shaping-weight
            # phase carry the policy update count -- confirm this is intended.
            print('model saved for update', self.update_cnt)
            save_path = './data/' + self.algo_name + '/model/{}.ckpt'.format(self.update_cnt)
            self.algorithm.saver.save(self.algorithm.sess, save_path)
