# coding=utf-8
import numpy as np
import tensorflow as tf
import random
from experiments.utils.memory import ReplayMemory
from experiments.algorithms.ddpg.ddpg_oprs_v2_approx.ddpg_oprs_v2_approx_algo import DDPGOprsV2ApproxAlgo
from experiments.algorithms.ddpg.ddpg_trainer import DDPGTrainer


# Defaults used by DDPGOprsV2ApproxTrainer.set_trainer_parameters() when the
# matching keyword argument is not supplied.
POLICY_BATCH_SIZE = 40                      # mini-batch size passed to policy_memory.get_minibatch
WEIGHT_FUNC_BATCH_SIZE = 40                 # mini-batch size passed to weight_func_memory.get_minibatch
REPLAY_BUFFER_SIZE = 50000                  # capacity given to each ReplayMemory
UPDATE_FREQ = 4                             # learn only on steps t with t % UPDATE_FREQ == 0
UPDATE_NUM_PER_SWITCH = 5                   # updates before switching policy <-> weight-function optimization
COMPUTE_GRAD_THETA_WRT_PHI_SAMPLE_NUM = 4   # trajectories sampled for the gradient of theta w.r.t. phi

# Warm-up thresholds (not configurable via kwargs): learning starts only once
# the corresponding replay memory holds more than this many stored entries
# (each entry is one episode trajectory).
FIRST_UPDATE_POLICY_SAMPLE_NUM = 200
FIRST_UPDATE_WEIGHT_FUNC_SAMPLE_NUM = 200

# NOTE(review): not referenced by the trainer code in this file.
MODEL_UPDATE_FREQ = 10000

# Approximate version of OPRS-V2: only a few sampled states are used to
# compute the gradient of theta w.r.t. phi, while many samples (the whole
# mini-batch) are used to compute the policy gradient.
class DDPGOprsV2ApproxTrainer(DDPGTrainer):
    """Trainer for the approximate OPRS-V2 variant of DDPG.

    Training alternates between two phases, selected by
    ``self.algorithm.optimize_policy``: optimizing the policy, and optimizing
    the shaping-weight function f_phi(s). Each phase has its own replay memory
    of whole episode trajectories. Once ``update_num_per_switch`` updates have
    been performed in the current phase, the algorithm is asked to switch
    phase (checked at the end of each training episode).

    Approximation: the gradient of theta w.r.t. phi is estimated from only a
    few sampled trajectories (``nabla_theta_wrt_phi_sam_num``), whereas the
    policy gradient uses every transition of the sampled mini-batch.
    """

    def __init__(self, state_dim, action_dim, algo_name="ddpg_oprs_v2_approx", **kwargs):
        """Forward all arguments to DDPGTrainer, defaulting the algorithm name."""
        super(DDPGOprsV2ApproxTrainer, self).__init__(state_dim, action_dim, algo_name, **kwargs)

    def set_trainer_parameters(self, **kwargs):
        """Read trainer hyper-parameters from ``kwargs``, falling back to module defaults."""
        self.policy_batch_size = kwargs.get("policy_batch_size", POLICY_BATCH_SIZE)
        self.weight_func_batch_size = kwargs.get("weight_func_batch_size", WEIGHT_FUNC_BATCH_SIZE)
        self.replay_buffer_size = kwargs.get("replay_buffer_size", REPLAY_BUFFER_SIZE)
        self.update_freq = kwargs.get("update_freq", UPDATE_FREQ)
        self.update_num_per_switch = kwargs.get("update_num_per_switch", UPDATE_NUM_PER_SWITCH)
        self.nabla_theta_wrt_phi_sam_num = kwargs.get("nabla_theta_wrt_phi_sam_num",
                                                      COMPUTE_GRAD_THETA_WRT_PHI_SAMPLE_NUM)

    def init_algo(self, **kwargs):
        """Create a dedicated TF graph/session, the algorithm, counters and replay memories."""
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = DDPGOprsV2ApproxAlgo(self.session, self.graph, self.state_dim, self.action_dim,
                                              algo_name=self.algo_name, **kwargs)

        # Total number of updates performed.
        self.update_cnt = 0
        # Updates performed since the last optimization switch (reset in episode_done).
        self.local_update_cnt = 0

        # Separate replay buffers for optimizing the policy and the shaping
        # weight function; each stores whole episode trajectories.
        self.policy_memory = ReplayMemory(self.replay_buffer_size)
        self.weight_func_memory = ReplayMemory(self.replay_buffer_size)

        # Also expose the algorithm's tf file writer for writing other information.
        self.my_writer = self.algorithm.train_writer

    def action(self, state, test_model):
        """Return ``(action, f_phi_s)`` as produced by ``algorithm.choose_action``."""
        a, f_phi_s = self.algorithm.choose_action(state, test_model)
        return a, f_phi_s

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """Hand one transition to the algorithm.

        The shaping weight f_phi(s) and the additional reward F(s, a, s') are
        read from ``kwargs["f_phi_s"]`` and ``kwargs["F"]`` (``None`` if absent).
        The transition is stored as ``(s, a, r, sp, terminal, f_phi_s, F)``;
        replay memories are filled per episode in ``episode_done``.
        """
        f_phi_s = kwargs.get("f_phi_s")
        F_value = kwargs.get("F")
        self.algorithm.experience((s, a, r, sp, terminal, f_phi_s, F_value))

    def update(self, t):
        """Update whichever parameters the algorithm is currently optimizing."""
        if self.algorithm.optimize_policy:
            self.update_policy_parameters(t)
        else:
            self.update_weight_func_parameters(t)

    def _is_update_step(self, t):
        """Return True when time step ``t`` falls on the update frequency."""
        return t % self.update_freq == 0

    @staticmethod
    def _flatten_trajectories(traj_samples):
        """Concatenate the experiences of all sampled trajectories into one list."""
        return [exp for epi_traj in traj_samples for exp in epi_traj]

    def _learn_on_experiences(self, experiences, **extra_kwargs):
        """Stack experience tuples column-wise and run one ``algorithm.learn`` step.

        Each experience is indexed as (s, a, r, sp, done, f_phi_s, F). Rewards,
        done flags, f_phi_s and F are reshaped to column vectors.
        ``extra_kwargs`` are forwarded to ``algorithm.learn`` unchanged.
        """
        def column(index):
            return [exp[index] for exp in experiences]

        self.algorithm.learn(np.array(column(0)), np.array(column(1)),
                             np.array(column(2)).reshape([-1, 1]),
                             np.array(column(3)), np.array(column(4)).reshape([-1, 1]),
                             f_phi_s=np.array(column(5)).reshape([-1, 1]),
                             F=np.array(column(6)).reshape([-1, 1]),
                             **extra_kwargs)

    def _build_grad_theta_wrt_phi_batches(self, traj_samples):
        """Sample the states used to compute the gradient of theta w.r.t. phi.

        Picks up to ``nabla_theta_wrt_phi_sam_num`` distinct non-empty
        trajectories. In each, a random start index is drawn, and the rest of
        that trajectory becomes one mini-batch
        ``[states, relative step indices (column), F values (column)]``.

        Returns:
            (grad_compute_s_batch, mini_batches): the start states as an array,
            and the list of per-trajectory mini-batches.
        """
        # Only non-empty trajectories can supply a start state.
        candidates = [i for i, traj in enumerate(traj_samples) if len(traj) > 0]
        # Clamp so requesting more distinct trajectories than exist cannot
        # loop forever or raise.
        sample_num = max(0, min(self.nabla_theta_wrt_phi_sam_num, len(candidates)))

        grad_compute_s_batch, mini_batches = [], []
        for traj_index in random.sample(candidates, sample_num):
            traj = traj_samples[traj_index]
            start = random.randint(0, len(traj) - 1)
            tail = [traj[j] for j in range(start, len(traj))]

            grad_compute_s_batch.append(tail[0][0])
            mini_batches.append([np.array([exp[0] for exp in tail]),
                                 np.arange(len(tail)).reshape([-1, 1]),
                                 np.array([exp[6] for exp in tail]).reshape([-1, 1])])
        return np.array(grad_compute_s_batch), mini_batches

    def update_policy_parameters(self, t):
        """Run one policy-optimization step at time step ``t``, if one is due.

        Does nothing until ``policy_memory`` holds more than
        FIRST_UPDATE_POLICY_SAMPLE_NUM entries. It only learns on steps where
        ``t`` is a multiple of ``update_freq``. Parameters are saved after
        every update.
        """
        if len(self.policy_memory.store) <= FIRST_UPDATE_POLICY_SAMPLE_NUM:
            return
        if not self._is_update_step(t):
            return

        traj_samples = self.policy_memory.get_minibatch(self.policy_batch_size)
        experiences = self._flatten_trajectories(traj_samples)
        if not experiences:
            # Only empty trajectories were sampled: nothing to learn from.
            return

        self.update_cnt += 1
        self.local_update_cnt += 1

        grad_compute_s_batch, mini_batches = self._build_grad_theta_wrt_phi_batches(traj_samples)
        self._learn_on_experiences(experiences,
                                   grad_compute_s_batch=grad_compute_s_batch,
                                   mini_batches=mini_batches)

        self.save_params()

    def update_weight_func_parameters(self, t):
        """Run one shaping-weight-function step at time step ``t``, if one is due.

        Does nothing until ``weight_func_memory`` holds more than
        FIRST_UPDATE_WEIGHT_FUNC_SAMPLE_NUM entries. It only learns on steps
        where ``t`` is a multiple of ``update_freq``. Parameters are saved after
        every update.
        """
        if len(self.weight_func_memory.store) <= FIRST_UPDATE_WEIGHT_FUNC_SAMPLE_NUM:
            return
        if not self._is_update_step(t):
            return

        traj_samples = self.weight_func_memory.get_minibatch(self.weight_func_batch_size)
        experiences = self._flatten_trajectories(traj_samples)
        if not experiences:
            # Only empty trajectories were sampled: nothing to learn from.
            return

        self.update_cnt += 1
        self.local_update_cnt += 1

        self._learn_on_experiences(experiences)

        self.save_params()

    def episode_done(self, test_model):
        """Store the finished episode's trajectory and maybe switch optimization.

        Test episodes are not stored. Training episodes go to the memory of
        the phase currently being optimized. After
        ``update_num_per_switch`` updates in that phase, the algorithm switches
        between policy and weight-function optimization.
        """
        episode_traj = self.algorithm.episode_done(test_model)
        if test_model:
            return

        assert episode_traj is not None
        if self.algorithm.optimize_policy:
            self.policy_memory.add(episode_traj)
        else:
            self.weight_func_memory.add(episode_traj)

        # Switch optimization target once enough updates happened in this phase.
        if self.local_update_cnt >= self.update_num_per_switch:
            self.local_update_cnt = 0
            self.algorithm.switch_optimization()