# coding=utf-8
import numpy as np
import tensorflow as tf
import random
from experiments.utils.memory import ReplayMemory
from experiments.algorithms.ddpg.ddpg_trainer import DDPGTrainer
from ..ddpg_oprs_v3_fsaqin.ddpg_oprs_v3_fsaqin_algo import DDPGOprsV3FsaqinAlgo

# Defaults for the keyword arguments read in
# DDPGOprsV3FsaqinTrainer.set_trainer_parameters.

# number of trajectories requested from the policy replay buffer per minibatch draw
POLICY_BATCH_SIZE = 40
# stored on the trainer as weight_func_batch_size; not read anywhere else in this file
WEIGHT_FUNC_BATCH_SIZE = 40
# capacity passed to the ReplayMemory that holds trajectories for policy updates
REPLAY_BUFFER_SIZE = 50000
# a policy update is only attempted on steps t where t % UPDATE_FREQ == 0
UPDATE_FREQ = 4
# number of states picked from each policy batch for computing the gradient of
# theta w.r.t. phi (the author's trailing note recorded an alternative value: 4)
COMPUTE_GRAD_THETA_WRT_PHI_SAMPLE_NUM = 100 #4

# number of episode trajectories that must be stored before one update of the
# shaping weight function is run; also used as the capacity of the
# weight-function replay memory (the truncation size is counted in episodes)
EPISODE_TRUNCATION_SIZE = 40

# after this many policy updates, switch to optimizing the shaping weight
# function (the author's trailing note recorded alternative values: 400, 40)
UPDATE_NUM_PER_SWITCH_A_TO_F = 5  # 400 #40

# after this many shaping-weight-function updates, switch back to optimizing the policy
UPDATE_NUM_PER_SWITCH_F_TO_A = 4

# Fixed thresholds (not overridable through keyword arguments).

# the policy replay buffer must hold more than this many trajectories before
# any policy update is attempted
FIRST_UPDATE_POLICY_SAMPLE_NUM = 200
# while optimizing the policy, a checkpoint is written whenever the policy
# update counter is a positive multiple of this value
MODEL_UPDATE_FREQ = 10000
# while optimizing the shaping weight function, a checkpoint is written whenever
# the weight-function update counter is a positive multiple of this value
SHAPING_WEIGHT_FUNC_MODEL_UPDATE_FREQ = 200

"""
    Approximate version of OPRS-V3:
    uses only a few samples to compute the gradient of theta w.r.t. phi,
    but many samples to compute the policy gradient.
"""


class DDPGOprsV3FsaqinTrainer(DDPGTrainer):
    """Trainer that alternates between policy updates and shaping-weight-function updates.

    Which of the two is currently optimized is read from
    ``self.algorithm.optimize_policy``; the trainer flips it by calling
    ``self.algorithm.switch_optimization()`` after a configurable number of
    updates of the active component.
    """

    # A policy batch is topped up with further minibatch draws until it
    # contains at least this many transitions.
    _MIN_POLICY_BATCH_SAMPLE_NUM = 1024

    # Number of transitions that must have been collected for the active
    # component (policy or weight function) before its update is allowed.
    _MIN_SAMPLE_NUM_BEFORE_UPDATE = 20000

    # Each stored transition is (s, a, r, sp, terminal, f_phi_s, F).
    _TRANSITION_FIELD_NUM = 7

    def __init__(self, state_dim, action_dim, algo_name="ddpg_oprs_v3_fsaqin", **kwargs):
        """Forward all arguments unchanged to the ``DDPGTrainer`` constructor."""
        super(DDPGOprsV3FsaqinTrainer, self).__init__(state_dim, action_dim, algo_name, **kwargs)

    def set_trainer_parameters(self, **kwargs):
        """Read the trainer hyper-parameters from ``kwargs``.

        Every parameter falls back to the module-level default of the same
        (upper-case) meaning when its key is absent.
        """
        self.policy_batch_size = kwargs.get("policy_batch_size", POLICY_BATCH_SIZE)
        self.weight_func_batch_size = kwargs.get("weight_func_batch_size", WEIGHT_FUNC_BATCH_SIZE)
        self.replay_buffer_size = kwargs.get("replay_buffer_size", REPLAY_BUFFER_SIZE)
        self.update_freq = kwargs.get("update_freq", UPDATE_FREQ)
        self.episode_truncation_size = kwargs.get("episode_truncation_size", EPISODE_TRUNCATION_SIZE)
        self.nabla_theta_wrt_phi_sam_num = kwargs.get("nabla_theta_wrt_phi_sam_num",
                                                      COMPUTE_GRAD_THETA_WRT_PHI_SAMPLE_NUM)
        self.update_num_per_switch_atof = kwargs.get("update_num_per_switch_atof",
                                                     UPDATE_NUM_PER_SWITCH_A_TO_F)
        self.update_num_per_switch_ftoa = kwargs.get("update_num_per_switch_ftoa",
                                                     UPDATE_NUM_PER_SWITCH_F_TO_A)

    def init_algo(self, **kwargs):
        """Create the TF graph/session, the algorithm object, counters and replay memories."""
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = DDPGOprsV3FsaqinAlgo(self.session, self.graph, self.state_dim, self.action_dim,
                                              algo_name=self.algo_name, **kwargs)

        # update_cnt counts policy updates, f_update_cnt counts weight-function
        # updates; local_update_cnt counts updates since the last switch.
        self.update_cnt = 0
        self.f_update_cnt = 0
        self.local_update_cnt = 0

        # replay buffers for optimizing the policy and the shaping weight
        # function; the *_sample_num counters track how many transitions
        # (not trajectories) have been added
        self.policy_memory = ReplayMemory(self.replay_buffer_size)
        self.policy_memory_sample_num = 0
        self.weight_func_memory = ReplayMemory(self.episode_truncation_size)
        self.weight_func_memory_sample_num = 0

        # reuse the algorithm's tf file writer for writing other information
        self.my_writer = self.algorithm.train_writer

    def action(self, state, test_model):
        """Return ``(action, action_info)`` as chosen by the algorithm for ``state``."""
        a, action_info = self.algorithm.choose_action(state, test_model)
        return a, action_info

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """Pass one transition to the algorithm.

        The shaping weight ``f_phi(s)`` and the additional reward ``F(s,a,s')``
        are taken from the keyword arguments ``f_phi_s`` and ``F``; either is
        ``None`` when not supplied.
        """
        f_phi_s = kwargs.get("f_phi_s")
        F_value = kwargs.get("F")
        self.algorithm.experience((s, a, r, sp, terminal, f_phi_s, F_value))

    def update(self, t):
        """Run a policy update for step ``t`` if the policy is currently being optimized.

        The shaping weight function is not updated here; that happens from
        ``episode_done``.
        """
        if self.algorithm.optimize_policy:
            self.update_policy_parameters(t)

    @staticmethod
    def _flatten_trajectories(traj_samples):
        """Unpack trajectories into seven per-field lists.

        Returns ``(s, a, r, sp, done, f_phi, F)`` where each element is a list
        with one entry per transition, in trajectory order.
        """
        field_num = DDPGOprsV3FsaqinTrainer._TRANSITION_FIELD_NUM
        columns = [[] for _ in range(field_num)]
        for epi_traj in traj_samples:
            for sample in epi_traj:
                for field_index in range(field_num):
                    columns[field_index].append(sample[field_index])
        return tuple(columns)

    def update_policy_parameters(self, t):
        """Perform one policy update if all preconditions hold.

        Nothing happens unless the policy memory holds more than
        ``FIRST_UPDATE_POLICY_SAMPLE_NUM`` trajectories, ``t`` is a multiple of
        ``self.update_freq`` and enough transitions have been collected.
        After ``self.update_num_per_switch_atof`` updates the algorithm is
        switched to optimizing the shaping weight function.
        """
        if len(self.policy_memory.store) <= FIRST_UPDATE_POLICY_SAMPLE_NUM:
            return

        # update frequency
        if t % self.update_freq != 0:
            return

        if self.policy_memory_sample_num < self._MIN_SAMPLE_NUM_BEFORE_UPDATE:
            return

        # get mini batch from replay buffer; copy so that extending the batch
        # below cannot touch whatever container the memory handed out
        traj_samples = list(self.policy_memory.get_minibatch(self.policy_batch_size))
        batch_sample_num = sum(len(epi_traj) for epi_traj in traj_samples)

        # keep drawing minibatches until the batch holds enough transitions
        while batch_sample_num < self._MIN_POLICY_BATCH_SAMPLE_NUM:
            traj_samples_add = self.policy_memory.get_minibatch(self.policy_batch_size)
            traj_samples.extend(traj_samples_add)
            batch_sample_num += sum(len(epi_traj) for epi_traj in traj_samples_add)

        s_batch, a_batch, r_batch, sp_batch, done_batch, f_phi_batch, F_batch = \
            self._flatten_trajectories(traj_samples)

        # Pick the states used for computing the gradient of theta w.r.t. phi:
        # distinct indices drawn uniformly without replacement. The count is
        # clamped to the batch size so that a configured sample number larger
        # than the batch cannot stall the update.
        grad_sample_num = min(max(self.nabla_theta_wrt_phi_sam_num, 0), len(s_batch))
        grad_sample_indices = random.sample(range(len(s_batch)), grad_sample_num)
        grad_compute_s_batch = [s_batch[sample_index] for sample_index in grad_sample_indices]

        self.algorithm.learn(np.array(s_batch), np.array(a_batch),
                             np.array(r_batch).reshape([-1, 1]),
                             np.array(sp_batch), np.array(done_batch).reshape([-1, 1]),
                             f_phi_s=np.array(f_phi_batch).reshape([-1, 1]),
                             F=np.array(F_batch).reshape([-1, 1]),
                             grad_compute_s_batch=np.array(grad_compute_s_batch))

        # save param (checked before update_cnt is incremented)
        self.save_params()

        # switch optimization
        self.update_cnt += 1
        self.local_update_cnt += 1
        if self.local_update_cnt >= self.update_num_per_switch_atof:
            self.local_update_cnt = 0
            self.algorithm.switch_optimization()
            print("begin to optimize shaping weight function")

    def update_weight_func_parameters(self):
        """Perform one update of the shaping weight function.

        All transitions of the trajectories drawn from the weight-function
        memory are passed to ``self.algorithm.learn``. After
        ``self.update_num_per_switch_ftoa`` updates the algorithm is switched
        back to optimizing the policy.
        """
        # get mini batch from replay buffer
        # NOTE(review): max() requests at least episode_truncation_size
        # trajectories even if fewer are stored; the only caller in this file
        # checks that the store already holds that many - confirm min() was
        # not intended.
        ep_trunc_size = max(len(self.weight_func_memory.store), self.episode_truncation_size)
        traj_samples = self.weight_func_memory.get_minibatch(ep_trunc_size)

        s_batch, a_batch, r_batch, sp_batch, done_batch, f_phi_batch, F_batch = \
            self._flatten_trajectories(traj_samples)

        self.algorithm.learn(np.array(s_batch), np.array(a_batch),
                             np.array(r_batch).reshape([-1, 1]),
                             np.array(sp_batch), np.array(done_batch).reshape([-1, 1]),
                             f_phi_s=np.array(f_phi_batch).reshape([-1, 1]),
                             F=np.array(F_batch).reshape([-1, 1]))

        # save param (checked before f_update_cnt is incremented)
        self.save_params()

        # switch optimization
        self.local_update_cnt += 1
        self.f_update_cnt += 1
        if self.local_update_cnt >= self.update_num_per_switch_ftoa:
            self.local_update_cnt = 0
            self.algorithm.switch_optimization()
            print("begin to optimize policy")

    def episode_done(self, test_model):
        """Store the finished episode's trajectory and possibly update the weight function.

        Outside of test mode the trajectory returned by the algorithm is added
        to the memory of the component currently being optimized. When the
        weight function is being optimized and both the number of stored
        trajectories and the number of collected transitions are large enough,
        one weight-function update is run and that memory is emptied.
        """
        episode_traj = self.algorithm.episode_done(test_model)
        if test_model:
            return

        assert episode_traj is not None
        if self.algorithm.optimize_policy:
            self.policy_memory_sample_num += len(episode_traj)
            self.policy_memory.add(episode_traj)
            return

        self.weight_func_memory_sample_num += len(episode_traj)
        self.weight_func_memory.add(episode_traj)
        enough_trajectories = len(self.weight_func_memory.store) >= self.episode_truncation_size
        enough_samples = self.weight_func_memory_sample_num >= self._MIN_SAMPLE_NUM_BEFORE_UPDATE
        if enough_trajectories and enough_samples:
            self.update_weight_func_parameters()
            # the trajectories are used for exactly one update
            self.weight_func_memory.store.clear()
            self.weight_func_memory_sample_num = 0

    def save_params(self):
        """Write a checkpoint when the active update counter hits its save frequency.

        While the policy is optimized the policy update counter is compared
        against ``MODEL_UPDATE_FREQ``; otherwise the weight-function update
        counter is compared against ``SHAPING_WEIGHT_FUNC_MODEL_UPDATE_FREQ``.
        """
        if self.algorithm.optimize_policy:
            active_cnt, save_freq = self.update_cnt, MODEL_UPDATE_FREQ
        else:
            active_cnt, save_freq = self.f_update_cnt, SHAPING_WEIGHT_FUNC_MODEL_UPDATE_FREQ

        if active_cnt > 0 and active_cnt % save_freq == 0:
            # NOTE(review): in both modes the message and the checkpoint name
            # use update_cnt (the policy counter), so checkpoints written
            # during one weight-function phase share a file name - confirm
            # this is intended.
            print('model saved for update', self.update_cnt)
            save_path = './data/' + self.algo_name + '/model/{}.ckpt'.format(self.update_cnt)
            self.algorithm.saver.save(self.algorithm.sess, save_path)


