# coding=utf-8
import numpy as np
import tensorflow as tf
from experiments.utils.memory import ReplayMemory
from experiments.algorithms.ddpg.ddpg_oprs_v1_freeze.ddpg_oprs_v1_freeze_algo import DDPGOprsV1FreezeAlgo
from experiments.algorithms.ddpg.ddpg_trainer import DDPGTrainer

# Default trainer hyper-parameters. The first four can be overridden through
# keyword arguments (see DDPGOprsV1FreezeTrainer.set_trainer_parameters).
BATCH_SIZE = 1024               # transitions drawn per minibatch
REPLAY_BUFFER_SIZE = 1000000    # capacity of each replay memory
UPDATE_FREQ = 4                 # only update on steps t with t % update_freq == 0
UPDATE_NUM_PER_SWITCH = 400     # updates before alternating policy / weight function

# Warm-up: no update happens until the corresponding replay memory holds
# more than this many transitions.
FIRST_UPDATE_POLICY_SAMPLE_NUM = 25600
FIRST_UPDATE_WEIGHT_FUNC_SAMPLE_NUM = 25600
# NOTE(review): not referenced by the trainer code in this file; presumably
# consumed elsewhere (e.g. when saving the model) -- confirm before changing.
MODEL_UPDATE_FREQ = 10000

# DDPG with optimization of parameterized reward shaping (OPRS) v1,
# which directly relates the shaping weight function f_phi to the policy.


class DDPGOprsV1FreezeTrainer(DDPGTrainer):
    """DDPG trainer with optimization of parameterized reward shaping (OPRS) v1.

    Two targets are trained, each from its own replay memory:

    * the policy, from transitions ``(s, a, r, s', done, f_phi(s), F)``;
    * the shaping weight function f_phi, from transitions
      ``(s, a, r, s', done, f_phi(s))``.

    By default the targets are optimized alternately. New experience and
    updates go to whichever target the algorithm is currently optimizing
    (``self.algorithm.optimize_policy``). After ``update_num_per_switch``
    updates, the algorithm is switched to the other target.

    In simultaneous mode (config key ``update_pi_f_simutaneously``), every
    transition is stored in both memories and both targets are updated on
    every eligible step.
    """

    def __init__(self, state_dim, action_dim, algo_name="ddpg_oprs_v1_freeze", **kwargs):
        super(DDPGOprsV1FreezeTrainer, self).__init__(state_dim, action_dim, algo_name, **kwargs)

    def set_trainer_parameters(self, **kwargs):
        """Read trainer hyper-parameters from ``kwargs``, falling back to the module defaults."""
        self.batch_size = kwargs.get("batch_size", BATCH_SIZE)
        self.replay_buffer_size = kwargs.get("replay_buffer_size", REPLAY_BUFFER_SIZE)
        self.update_freq = kwargs.get("update_freq", UPDATE_FREQ)
        self.update_num_per_switch = kwargs.get("update_num_per_switch", UPDATE_NUM_PER_SWITCH)

        # The established key is misspelled ("simutaneously"). Keep honouring it,
        # and also accept the correct spelling, which wins when both are given.
        self.update_f_pi_sim = kwargs.get("update_pi_f_simultaneously",
                                          kwargs.get("update_pi_f_simutaneously", False))

    def init_algo(self, **kwargs):
        """Create the TF graph/session, the OPRS algorithm and both replay memories."""
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = DDPGOprsV1FreezeAlgo(self.session, self.graph, self.state_dim, self.action_dim,
                                              algo_name=self.algo_name, **kwargs)

        # Total number of updates, and updates since the last optimization switch.
        self.update_cnt = 0
        self.local_update_cnt = 0

        # Separate replay memories for optimizing the policy and the shaping weight function.
        self.policy_memory = ReplayMemory(self.replay_buffer_size)
        self.weight_func_memory = ReplayMemory(self.replay_buffer_size)

        # Reuse the algorithm's TF file writer for writing other information.
        self.my_writer = self.algorithm.train_writer

    def action(self, state, test_model):
        """Delegate action selection to the algorithm; returns ``(action, action_info)``."""
        a, action_info = self.algorithm.choose_action(state, test_model)
        return a, action_info

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """Store one transition in the replay memory (or memories) being trained.

        Keyword args:
            f_phi_s: shaping weight f_phi(s) for this transition.
            F: additional shaping reward F(s, a, s').
        """
        f_phi_s = kwargs.get("f_phi_s")
        F_value = kwargs.get("F")

        if self.update_f_pi_sim:
            self.policy_memory.add((s, a, r, sp, terminal, f_phi_s, F_value))
            self.weight_func_memory.add((s, a, r, sp, terminal, f_phi_s))
        elif self.algorithm.optimize_policy:
            self.policy_memory.add((s, a, r, sp, terminal, f_phi_s, F_value))
        else:
            self.weight_func_memory.add((s, a, r, sp, terminal, f_phi_s))

    def update(self, t):
        """Run the update(s) scheduled for environment step ``t``."""
        if self.update_f_pi_sim:
            # The policy step runs first; each step switches the algorithm's target.
            self.update_policy_parameters(t)
            self.update_weight_func_parameters(t)
        elif self.algorithm.optimize_policy:
            self.update_policy_parameters(t)
        else:
            self.update_weight_func_parameters(t)

    def update_policy_parameters(self, t):
        """Do one policy update at step ``t`` once the policy memory is warmed up."""
        if not self._ready_to_update(self.policy_memory, FIRST_UPDATE_POLICY_SAMPLE_NUM, t):
            return

        self.update_cnt += 1
        self.local_update_cnt += 1

        s_batch, a_batch, r_batch, sp_batch, done_batch, f_phi_batch, F_batch = \
            self._sample_fields(self.policy_memory, 7)

        self.algorithm.learn(np.array(s_batch), np.array(a_batch),
                             self._to_column(r_batch),
                             np.array(sp_batch), self._to_column(done_batch),
                             f_phi_s=self._to_column(f_phi_batch),
                             F=self._to_column(F_batch))

        self._finish_update("Switch to optimize weight function")

    def update_weight_func_parameters(self, t):
        """Do one shaping-weight-function update at step ``t`` once its memory is warmed up."""
        if not self._ready_to_update(self.weight_func_memory, FIRST_UPDATE_WEIGHT_FUNC_SAMPLE_NUM, t):
            return

        self.update_cnt += 1
        self.local_update_cnt += 1

        s_batch, a_batch, r_batch, sp_batch, done_batch, f_phi_batch = \
            self._sample_fields(self.weight_func_memory, 6)

        # No F argument here: weight-function transitions do not store F(s, a, s').
        self.algorithm.learn(np.array(s_batch), np.array(a_batch),
                             self._to_column(r_batch),
                             np.array(sp_batch), self._to_column(done_batch),
                             f_phi_s=self._to_column(f_phi_batch))

        self._finish_update("Switch to optimize policy")

    # ------------------------------------------------------------------ helpers

    def _ready_to_update(self, memory, min_samples, t):
        """Return True if ``memory`` holds more than ``min_samples`` transitions and ``t`` is an update step."""
        return len(memory.store) > min_samples and t % self.update_freq == 0

    def _sample_fields(self, memory, num_fields):
        """Sample a minibatch from ``memory`` and split it into ``num_fields`` per-field lists.

        Field ``i`` of every sampled transition is appended to list ``i``.
        """
        sample = memory.get_minibatch(self.batch_size)
        fields = [[] for _ in range(num_fields)]
        for transition in sample:
            for index, field in enumerate(fields):
                field.append(transition[index])
        return fields

    @staticmethod
    def _to_column(values):
        """Convert ``values`` to a NumPy array reshaped to a single column (``reshape([-1, 1])``)."""
        return np.array(values).reshape([-1, 1])

    def _finish_update(self, switch_message):
        """Save parameters, then switch the optimization target if it is due."""
        self.save_params()

        if self.update_f_pi_sim:
            # Simultaneous mode: hand over to the other target after every update.
            self.algorithm.switch_optimization()
        elif self.local_update_cnt >= self.update_num_per_switch:
            self.local_update_cnt = 0
            self.algorithm.switch_optimization()
            print(switch_message)

