# coding=utf-8
import numpy as np
import tensorflow as tf
import random
from experiments.utils.memory import ReplayMemory
from experiments.algorithms.ddpg.ddpg_oprs_v2.ddpg_oprs_v2_algo import DDPGOprsV2Algo
from experiments.algorithms.ddpg.ddpg_trainer import DDPGTrainer

# Default hyperparameters of the trainer. The first five can be overridden
# through the keyword arguments read in `set_trainer_parameters`.

# Default size passed to `policy_memory.get_minibatch` for a policy update.
POLICY_BATCH_SIZE = 40
# Default size passed to `weight_func_memory.get_minibatch` for a
# shaping-weight-function update.
WEIGHT_FUNC_BATCH_SIZE = 40
# Default size argument used to construct both replay memories.
REPLAY_BUFFER_SIZE = 50000
# An update is attempted only on steps `t` with `t % UPDATE_FREQ == 0`.
UPDATE_FREQ = 4
# Number of performed updates after which, at the end of a training episode,
# the trainer asks the algorithm to switch the optimization target.
UPDATE_NUM_PER_SWITCH = 5

# A memory must hold more than this many stored entries before the
# corresponding kind of update is performed for the first time.
FIRST_UPDATE_POLICY_SAMPLE_NUM = 200
FIRST_UPDATE_WEIGHT_FUNC_SAMPLE_NUM = 200
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably consumed elsewhere or left over -- confirm before removing.
MODEL_UPDATE_FREQ = 10000


class DDPGOprsV2Trainer(DDPGTrainer):
    r"""
        DDPG with optimization of parameterized reward shaping (OPRS) v2,
        which optimizes the shaping weight function parameters \phi by
        computing the gradient of the policy parameters \theta w.r.t. \phi,
        under the assumption that \nabla_{\phi} \theta is only related to
        \Delta \theta.

        The trainer dispatches every update on `self.algorithm.optimize_policy`:
        either the policy or the shaping weight function is updated. Each of
        the two targets has its own replay memory, and both memories store
        whole episode trajectories (one entry per finished training episode).
        After `update_num_per_switch` performed updates, the end of the next
        training episode calls `self.algorithm.switch_optimization()`
        (presumably toggling `optimize_policy` -- confirm in the algorithm).

        NOTE: this docstring is a raw string on purpose; in a normal string
        the sequences "\t" and "\n" in "\theta" / "\nabla" would be turned
        into a tab and a newline.
    """

    def __init__(self, state_dim, action_dim, algo_name="ddpg_oprs_v2", **kwargs):
        """
            Forward all arguments unchanged to the `DDPGTrainer` constructor.
        """
        super(DDPGOprsV2Trainer, self).__init__(state_dim, action_dim, algo_name, **kwargs)

    def set_trainer_parameters(self, **kwargs):
        """
            Read the trainer hyperparameters from `kwargs`, falling back to
            the module-level defaults for every key that is absent.

            Recognized keys: "policy_batch_size", "weight_func_batch_size",
            "replay_buffer_size", "update_freq", "update_num_per_switch".
        """
        self.policy_batch_size = kwargs.get("policy_batch_size", POLICY_BATCH_SIZE)
        self.weight_func_batch_size = kwargs.get("weight_func_batch_size", WEIGHT_FUNC_BATCH_SIZE)
        self.replay_buffer_size = kwargs.get("replay_buffer_size", REPLAY_BUFFER_SIZE)
        self.update_freq = kwargs.get("update_freq", UPDATE_FREQ)
        self.update_num_per_switch = kwargs.get("update_num_per_switch", UPDATE_NUM_PER_SWITCH)

    def init_algo(self, **kwargs):
        """
            Create the TensorFlow graph and session, the OPRS-v2 algorithm
            object, the update counters and the two replay memories.

            `kwargs` are forwarded unchanged to `DDPGOprsV2Algo`.
        """
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = DDPGOprsV2Algo(self.session, self.graph, self.state_dim, self.action_dim,
                                        algo_name=self.algo_name, **kwargs)

        # total number of performed updates / number of updates performed
        # since the optimization target was last switched
        self.update_cnt = 0
        self.local_update_cnt = 0

        # replay buffers for optimizing the policy and the shaping weight function
        self.policy_memory = ReplayMemory(self.replay_buffer_size)
        self.weight_func_memory = ReplayMemory(self.replay_buffer_size)

        # reuse the algorithm's tf file writer for writing other information
        self.my_writer = self.algorithm.train_writer

    def action(self, state, test_model):
        """
            Ask the algorithm for an action in `state`.

            Returns the pair `(a, f_phi_s)` produced by
            `self.algorithm.choose_action`: the action and the shaping
            weight f_phi(s).
        """
        a, f_phi_s = self.algorithm.choose_action(state, test_model)
        return a, f_phi_s

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """
            Hand one transition to the algorithm.

            The shaping weight f_phi(s) and the additional reward F(s,a,s')
            are read from the keyword arguments "f_phi_s" and "F"; a missing
            key is stored as None.
        """
        f_phi_s = kwargs.get("f_phi_s")
        F_value = kwargs.get("F")
        self.algorithm.experience((s, a, r, sp, terminal, f_phi_s, F_value))

    def update(self, t):
        """
            Run one update attempt at step `t` for whichever target the
            algorithm is currently optimizing.
        """
        if self.algorithm.optimize_policy:
            self.update_policy_parameters(t)
        else:
            self.update_weight_func_parameters(t)

    def _memory_ready_for_update(self, memory, min_stored, t):
        """
            Return True when `memory` holds more than `min_stored` entries
            and `t` is a multiple of `self.update_freq`.
        """
        return len(memory.store) > min_stored and t % self.update_freq == 0

    def _learn_from_transitions(self, transitions, **learn_kwargs):
        """
            Stack `transitions` column-wise, run one learning step and save
            the parameters.

            Each transition is indexed as
            (s, a, r, sp, terminal, f_phi_s, F). Rewards, terminal flags,
            shaping weights and additional rewards are reshaped to column
            vectors. `learn_kwargs` are forwarded to `self.algorithm.learn`.

            Nothing is done (and no update is counted) when `transitions`
            is empty.
        """
        if not transitions:
            return

        self.update_cnt += 1
        self.local_update_cnt += 1

        # explicit indexing (instead of zip(*...)) so that transitions with
        # extra trailing fields are still accepted
        s_batch, a_batch, r_batch, sp_batch, done_batch, f_phi_batch, F_batch = [
            [exp[field] for exp in transitions] for field in range(7)
        ]

        self.algorithm.learn(np.array(s_batch), np.array(a_batch),
                             np.array(r_batch).reshape([-1, 1]),
                             np.array(sp_batch), np.array(done_batch).reshape([-1, 1]),
                             f_phi_s=np.array(f_phi_batch).reshape([-1, 1]),
                             F=np.array(F_batch).reshape([-1, 1]),
                             **learn_kwargs)

        # save param
        self.save_params()

    def update_policy_parameters(self, t):
        """
            Perform one policy update at step `t`, if the policy memory holds
            enough entries and `t` is an update step.

            One transition is drawn uniformly at random from every sampled
            episode trajectory. For each drawn transition a "mini batch" is
            built from the remainder of its episode (from the drawn index to
            the end): the states, their step offsets 0, 1, 2, ... relative to
            the drawn transition, and their additional rewards F. The list of
            mini batches is passed to the algorithm as `mini_batches`.

            Empty trajectories are skipped, because no transition can be
            drawn from them.
        """
        if not self._memory_ready_for_update(self.policy_memory,
                                             FIRST_UPDATE_POLICY_SAMPLE_NUM, t):
            return

        # get mini batch from replay buffer
        traj_samples = self.policy_memory.get_minibatch(self.policy_batch_size)

        transitions, mini_batches = [], []
        for epi_traj in traj_samples:
            traj_len = len(epi_traj)
            if traj_len == 0:
                # random.randint(0, -1) would raise ValueError
                continue

            # we only pick up one state from each episode
            sample_index = random.randint(0, traj_len - 1)
            transitions.append(epi_traj[sample_index])

            # the mini-batch of the state: the rest of the episode
            tail = [epi_traj[j] for j in range(sample_index, traj_len)]
            mini_batches.append([np.array([exp[0] for exp in tail]),
                                 np.arange(len(tail)).reshape([-1, 1]),
                                 np.array([exp[6] for exp in tail]).reshape([-1, 1])])

        self._learn_from_transitions(transitions, mini_batches=mini_batches)

    def update_weight_func_parameters(self, t):
        """
            Perform one shaping-weight-function update at step `t`, if the
            weight function memory holds enough entries and `t` is an update
            step.

            Unlike the policy update, every transition of every sampled
            episode trajectory is used.
        """
        if not self._memory_ready_for_update(self.weight_func_memory,
                                             FIRST_UPDATE_WEIGHT_FUNC_SAMPLE_NUM, t):
            return

        # get mini batch from replay buffer
        traj_samples = self.weight_func_memory.get_minibatch(self.weight_func_batch_size)

        # flatten the sampled episodes into one list of transitions
        transitions = [exp for epi_traj in traj_samples for exp in epi_traj]

        self._learn_from_transitions(transitions)

    def episode_done(self, test_model):
        """
            Finish the current episode.

            The algorithm is always notified. When `test_model` is false, the
            trajectory it returns is added to the memory of the target that
            is currently being optimized, and the optimization target is
            switched once `update_num_per_switch` updates have been performed
            since the last switch.
        """
        episode_traj = self.algorithm.episode_done(test_model)
        if test_model:
            return

        assert episode_traj is not None, "no trajectory returned for a training episode"
        if self.algorithm.optimize_policy:
            self.policy_memory.add(episode_traj)
        else:
            self.weight_func_memory.add(episode_traj)

        # switch optimization here
        if self.local_update_cnt >= self.update_num_per_switch:
            self.local_update_cnt = 0
            self.algorithm.switch_optimization()
