# coding=utf-8
import numpy as np
import tensorflow as tf
import random
from experiments.utils.memory import ReplayMemory
from experiments.algorithms.ddpg.ddpg_trainer import DDPGTrainer
from .ddpg_oprs_v2_fsa_algo import DDPGOprsV2FsaAlgo

# --- mini-batch and replay-memory sizes ------------------------------------
# trajectories drawn per call to the policy memory's get_minibatch
POLICY_BATCH_SIZE = 40
# default for the trainer's weight_func_batch_size setting
WEIGHT_FUNC_BATCH_SIZE = 40
# capacity passed to the policy ReplayMemory
REPLAY_BUFFER_SIZE = 50000

# --- update schedule ---------------------------------------------------------
# the policy is only updated at steps t with t % UPDATE_FREQ == 0
UPDATE_FREQ = 4
# number of sampled states used to estimate the gradient of theta w.r.t. phi
COMPUTE_GRAD_THETA_WRT_PHI_SAMPLE_NUM = 4

# size of one shaping-weight-function update; for OPRS-V2 this truncation
# size is counted in episodes, not in steps
EPISODE_TRUNCATION_SIZE = 40

# policy (actor) updates performed before switching to the weight function f
UPDATE_NUM_PER_SWITCH_A_TO_F = 5
# weight-function updates performed before switching back to the policy
UPDATE_NUM_PER_SWITCH_F_TO_A = 4

# policy updates start only once the policy memory holds more than this many
# entries (whole episode trajectories, despite "SAMPLE" in the name)
FIRST_UPDATE_POLICY_SAMPLE_NUM = 200
# save a checkpoint every this many policy updates ...
MODEL_UPDATE_FREQ = 10000
# ... and every this many shaping-weight-function updates
SHAPING_WEIGHT_FUNC_MODEL_UPDATE_FREQ = 200

"""
    approximate version of OPRS-V2
    use only few samples to compute the gradient of theta w.r.t phi
    but use many samples to compute policy gradient

    "fop" stands for that we update f using on-policy strategy
"""


class DDPGOprsV2FsaTrainer(DDPGTrainer):
    """Trainer for the approximate OPRS-V2 variant built on DDPG.

    Training alternates between two phases, selected by
    ``self.algorithm.optimize_policy``:

    * policy phase: finished episodes go to ``policy_memory`` and the actor is
      updated from a large batch of transitions, while only
      ``nabla_theta_wrt_phi_sam_num`` sampled states (each with the remainder
      of its trajectory) feed the gradient of theta w.r.t. phi;
    * shaping-weight phase: finished episodes go to ``weight_func_memory`` and
      the shaping weight function f is updated once enough data is collected.

    The phase is flipped with ``self.algorithm.switch_optimization()`` after
    ``update_num_per_switch_atof`` policy updates, or after
    ``update_num_per_switch_ftoa`` weight-function updates.
    """

    # Default minimum number of transitions that must be collected in the
    # current phase before the policy / the weight function is updated.
    DEFAULT_MIN_POLICY_MEMORY_SAMPLE_NUM = 20000
    DEFAULT_MIN_WEIGHT_FUNC_MEMORY_SAMPLE_NUM = 20000
    # Default minimum number of transitions in one policy-update batch.
    DEFAULT_MIN_POLICY_BATCH_SAMPLE_NUM = 1024

    def __init__(self, state_dim, action_dim, algo_name="ddpg_oprs_v2_fsa", **kwargs):
        super(DDPGOprsV2FsaTrainer, self).__init__(state_dim, action_dim, algo_name, **kwargs)

    def set_trainer_parameters(self, **kwargs):
        """Read trainer hyper-parameters from ``kwargs``, falling back to the defaults."""
        self.policy_batch_size = kwargs.get("policy_batch_size", POLICY_BATCH_SIZE)
        self.weight_func_batch_size = kwargs.get("weight_func_batch_size", WEIGHT_FUNC_BATCH_SIZE)
        self.replay_buffer_size = kwargs.get("replay_buffer_size", REPLAY_BUFFER_SIZE)
        self.update_freq = kwargs.get("update_freq", UPDATE_FREQ)
        self.episode_truncation_size = kwargs.get("episode_truncation_size", EPISODE_TRUNCATION_SIZE)
        self.nabla_theta_wrt_phi_sam_num = kwargs.get("nabla_theta_wrt_phi_sam_num",
                                                      COMPUTE_GRAD_THETA_WRT_PHI_SAMPLE_NUM)
        self.update_num_per_switch_atof = kwargs.get("update_num_per_switch_atof",
                                                     UPDATE_NUM_PER_SWITCH_A_TO_F)
        self.update_num_per_switch_ftoa = kwargs.get("update_num_per_switch_ftoa",
                                                     UPDATE_NUM_PER_SWITCH_F_TO_A)
        # previously hard-coded thresholds, now overridable (same defaults)
        self.min_policy_memory_sample_num = kwargs.get(
            "min_policy_memory_sample_num", self.DEFAULT_MIN_POLICY_MEMORY_SAMPLE_NUM)
        self.min_weight_func_memory_sample_num = kwargs.get(
            "min_weight_func_memory_sample_num", self.DEFAULT_MIN_WEIGHT_FUNC_MEMORY_SAMPLE_NUM)
        self.min_policy_batch_sample_num = kwargs.get(
            "min_policy_batch_sample_num", self.DEFAULT_MIN_POLICY_BATCH_SAMPLE_NUM)

    def init_algo(self, **kwargs):
        """Create the TF graph/session, the algorithm object, the counters and the memories."""
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = DDPGOprsV2FsaAlgo(self.session, self.graph, self.state_dim, self.action_dim,
                                           algo_name=self.algo_name, **kwargs)

        # total policy updates, total weight-function updates, and updates done
        # in the current phase (reset whenever the phase is switched)
        self.update_cnt = 0
        self.f_update_cnt = 0
        self.local_update_cnt = 0

        # replay memories for optimizing the policy and the shaping weight
        # function; both receive whole episode trajectories (see episode_done).
        # The *_sample_num counters track transitions added since the last clear.
        self.policy_memory = ReplayMemory(self.replay_buffer_size)
        self.policy_memory_sample_num = 0
        self.weight_func_memory = ReplayMemory(self.episode_truncation_size)
        self.weight_func_memory_sample_num = 0

        # also reuse the algorithm's writer for writing other information
        self.my_writer = self.algorithm.train_writer

    def action(self, state, test_model):
        """Return ``(action, action_info)`` chosen by the algorithm for ``state``."""
        a, action_info = self.algorithm.choose_action(state, test_model)
        return a, action_info

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """Forward one transition to the algorithm.

        ``kwargs`` may carry the shaping weight ``f_phi_s`` (f_phi(s)) and the
        additional reward ``F`` (F(s, a, s')); missing values are passed as None.
        """
        f_phi_s = kwargs.get("f_phi_s")
        F_value = kwargs.get("F")
        self.algorithm.experience((s, a, r, sp, terminal, f_phi_s, F_value))

    def update(self, t):
        """Update the actor at step ``t`` if the policy phase is active.

        Weight-function updates are triggered from ``episode_done`` instead.
        """
        if self.algorithm.optimize_policy:
            self.update_policy_parameters(t)

    @staticmethod
    def _trajectories_to_arrays(traj_samples):
        """Flatten a list of trajectories into the arrays passed to ``algorithm.learn``.

        Each transition is indexed as (s, a, r, s', terminal, f_phi(s), F),
        presumably the tuple handed to ``algorithm.experience`` - TODO confirm.
        Returns ``(s, a, r, sp, done, f_phi_s, F)``; r, done, f_phi_s and F
        are reshaped to ``(-1, 1)`` columns.
        """
        s_batch, a_batch, r_batch, sp_batch, done_batch, f_phi_batch, F_batch = [], [], [], [], [], [], []
        for epi_traj in traj_samples:
            for sample in epi_traj:
                s_batch.append(sample[0])
                a_batch.append(sample[1])
                r_batch.append(sample[2])
                sp_batch.append(sample[3])
                done_batch.append(sample[4])
                f_phi_batch.append(sample[5])
                F_batch.append(sample[6])

        def as_column(values):
            return np.array(values).reshape([-1, 1])

        return (np.array(s_batch), np.array(a_batch), as_column(r_batch),
                np.array(sp_batch), as_column(done_batch),
                as_column(f_phi_batch), as_column(F_batch))

    def _sample_policy_trajectories(self):
        """Draw trajectory mini-batches until they hold ``min_policy_batch_sample_num`` transitions."""
        traj_samples = list(self.policy_memory.get_minibatch(self.policy_batch_size))
        batch_sample_num = sum(len(traj) for traj in traj_samples)

        while batch_sample_num < self.min_policy_batch_sample_num:
            traj_samples_add = self.policy_memory.get_minibatch(self.policy_batch_size)
            if len(traj_samples_add) == 0:
                # an empty draw can never grow the batch; stop instead of spinning forever
                break
            traj_samples.extend(traj_samples_add)
            batch_sample_num += sum(len(traj) for traj in traj_samples_add)

        return traj_samples

    def _pick_grad_compute_samples(self, traj_samples):
        """Pick the states used for the gradient of theta w.r.t. phi.

        Trajectories are drawn with replacement, then one random start step is
        drawn per picked trajectory. Returns ``(grad_compute_s_batch,
        mini_batches)`` where each mini-batch is ``[s, a, t, F]`` for the tail of
        the trajectory starting at that step (t counts from 0 and t/F are
        ``(-1, 1)`` columns). Returns None if states are required but every
        trajectory is empty.
        """
        pick_num = self.nabla_theta_wrt_phi_sam_num
        # only non-empty trajectories have a step to start from; when none are
        # empty, this list is just range(len(traj_samples))
        candidates = [idx for idx, traj in enumerate(traj_samples) if len(traj) > 0]
        if pick_num > 0 and not candidates:
            return None

        # draw all trajectory indices first, then the start steps, keeping the
        # random-call order (and thus seeded reproducibility) of the original
        picked_traj_indices = [candidates[random.randint(0, len(candidates) - 1)]
                               for _ in range(pick_num)]

        grad_compute_s_batch, mini_batches = [], []
        for traj_index in picked_traj_indices:
            sample_traj = traj_samples[traj_index]
            exp_index = random.randint(0, len(sample_traj) - 1)
            grad_compute_s_batch.append(sample_traj[exp_index][0])

            # index-based tail so trajectories need not support slicing
            tail = [sample_traj[j] for j in range(exp_index, len(sample_traj))]
            mini_batches.append([np.array([exp[0] for exp in tail]),
                                 np.array([exp[1] for exp in tail]),
                                 np.arange(len(tail)).reshape([-1, 1]),
                                 np.array([exp[6] for exp in tail]).reshape([-1, 1])])

        return grad_compute_s_batch, mini_batches

    def update_policy_parameters(self, t):
        """Run one actor update at step ``t`` when enough policy data is available.

        The update is skipped unless the policy memory holds more than
        ``FIRST_UPDATE_POLICY_SAMPLE_NUM`` trajectories, ``t`` is a multiple of
        ``update_freq``, and at least ``min_policy_memory_sample_num``
        transitions were collected in this phase. After
        ``update_num_per_switch_atof`` updates, the algorithm switches to
        optimizing the shaping weight function and the policy memory is cleared.
        """
        if len(self.policy_memory.store) <= FIRST_UPDATE_POLICY_SAMPLE_NUM:
            return
        # update frequency
        if t % self.update_freq != 0:
            return
        if self.policy_memory_sample_num < self.min_policy_memory_sample_num:
            return

        traj_samples = self._sample_policy_trajectories()

        # a few states for grad(theta) w.r.t. phi; the whole batch for the policy gradient
        picked = self._pick_grad_compute_samples(traj_samples)
        if picked is None:
            # no transition at all in the batch: nothing to learn from
            return
        grad_compute_s_batch, mini_batches = picked

        s_arr, a_arr, r_arr, sp_arr, done_arr, f_phi_arr, F_arr = \
            self._trajectories_to_arrays(traj_samples)
        self.algorithm.learn(s_arr, a_arr, r_arr, sp_arr, done_arr,
                             f_phi_s=f_phi_arr,
                             F=F_arr,
                             grad_compute_s_batch=np.array(grad_compute_s_batch),
                             mini_batches=mini_batches)

        # save param (checked before the counter is incremented)
        self.save_params()

        # switch optimization after enough policy updates in this phase
        self.update_cnt += 1
        self.local_update_cnt += 1
        if self.local_update_cnt >= self.update_num_per_switch_atof:
            self.local_update_cnt = 0
            self.algorithm.switch_optimization()
            print("begin to optimize shaping weight function")

            # Discard this phase's policy data. Original note: "for neg".
            # NOTE(review): presumably so the next policy phase only uses data
            # gathered under the updated shaping weights - TODO confirm.
            self.policy_memory.store.clear()
            self.policy_memory_sample_num = 0

    def update_weight_func_parameters(self):
        """Run one update of the shaping weight function on the stored trajectories.

        After ``update_num_per_switch_ftoa`` such updates, the algorithm switches
        back to optimizing the policy.
        """
        # NOTE(review): the only caller guarantees len(store) >= episode_truncation_size,
        # so this max() equals len(store); if called with fewer stored trajectories
        # it would request more than are stored - confirm ReplayMemory handles that.
        ep_trunc_size = max(len(self.weight_func_memory.store), self.episode_truncation_size)
        traj_samples = self.weight_func_memory.get_minibatch(ep_trunc_size)

        s_arr, a_arr, r_arr, sp_arr, done_arr, f_phi_arr, F_arr = \
            self._trajectories_to_arrays(traj_samples)
        self.algorithm.learn(s_arr, a_arr, r_arr, sp_arr, done_arr,
                             f_phi_s=f_phi_arr,
                             F=F_arr)

        # save param (checked before the counter is incremented)
        self.save_params()

        # switch optimization after enough weight-function updates in this phase
        self.local_update_cnt += 1
        self.f_update_cnt += 1
        if self.local_update_cnt >= self.update_num_per_switch_ftoa:
            self.local_update_cnt = 0
            self.algorithm.switch_optimization()
            print("begin to optimize policy")

    def episode_done(self, test_model):
        """Add the finished episode's trajectory to the memory of the current phase.

        Test episodes are not stored. In the shaping-weight phase, once at least
        ``episode_truncation_size`` trajectories are stored and at least
        ``min_weight_func_memory_sample_num`` transitions were collected, f is
        updated and the weight-function memory is emptied.
        """
        episode_traj = self.algorithm.episode_done(test_model)
        if test_model:
            return
        assert episode_traj is not None

        if self.algorithm.optimize_policy:
            self.policy_memory_sample_num += len(episode_traj)
            self.policy_memory.add(episode_traj)
            return

        self.weight_func_memory_sample_num += len(episode_traj)
        self.weight_func_memory.add(episode_traj)
        # NOTE(review): weight_func_memory was created with capacity
        # episode_truncation_size, while the counter keeps every transition since
        # the last clear; if ReplayMemory evicts old entries, f is fitted on the
        # most recent episodes only - confirm this is intended.
        if (len(self.weight_func_memory.store) >= self.episode_truncation_size and
                self.weight_func_memory_sample_num >= self.min_weight_func_memory_sample_num):
            self.update_weight_func_parameters()
            self.weight_func_memory.store.clear()
            self.weight_func_memory_sample_num = 0

    def save_params(self):
        """Save a checkpoint when the counter of the active phase hits its save frequency.

        The policy phase uses ``update_cnt`` / ``MODEL_UPDATE_FREQ``; the
        weight-function phase uses ``f_update_cnt`` /
        ``SHAPING_WEIGHT_FUNC_MODEL_UPDATE_FREQ``. In both phases the checkpoint
        file is named after ``update_cnt``.
        """
        if self.algorithm.optimize_policy:
            counter, save_freq = self.update_cnt, MODEL_UPDATE_FREQ
        else:
            counter, save_freq = self.f_update_cnt, SHAPING_WEIGHT_FUNC_MODEL_UPDATE_FREQ

        if counter <= 0 or counter % save_freq != 0:
            return

        print('model saved for update', self.update_cnt)
        save_path = './data/' + self.algo_name + '/model/{}.ckpt'.format(self.update_cnt)
        self.algorithm.saver.save(self.algorithm.sess, save_path)


