import os

import numpy as np
import tensorflow as tf

from ...utils.memory import ReplayMemory

# Defaults for the trainer keyword arguments of the same (lowercase) name.
BATCH_SIZE = 1024            # minibatch size ("batch_size")
REPLAY_BUFFER_SIZE = 10 ** 6  # replay-buffer capacity ("replay_buffer_size")
UPDATE_FREQ = 4              # "update_freq"

# Replay fill level intended to gate the first training update (25 minibatches).
FIRST_UPDATE_SAMPLE_NUM = 25 * BATCH_SIZE
# Checkpoint interval, counted in training updates.
MODEL_UPDATE_FREQ = 10 ** 4

class MADDPGTrainer(object):
    """Own a MADDPG algorithm, its TF graph/session and a replay buffer.

    Transitions are collected with :meth:`experience`. :meth:`update` periodically
    samples minibatches to train ``self.algorithm``, and checkpoints are written
    under ``./data/<algo_name>/model/``.
    """

    def __init__(self, state_dim, action_dim, agent_num=2, algo_name="maddpg", **kwargs):
        """Create the trainer and build the underlying algorithm.

        Args:
            state_dim: state dimension, forwarded to ``MADDPGAlgo``.
            action_dim: number of actions for discrete-action problems, or the
                action dimension for continuous-action problems.
            agent_num: number of agents. With the default of 2, one is the true
                agent and the other is the shaping-weight agent.
            algo_name: algorithm name; also names the checkpoint directory.
            **kwargs: trainer settings (``batch_size``, ``replay_buffer_size``,
                ``update_freq``, ``first_update_sample_num``). The whole dict is
                also forwarded to ``MADDPGAlgo``.
        """
        self.agent_num = agent_num
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.algo_name = algo_name

        self.set_trainer_parameters(**kwargs)
        self.init_algo(**kwargs)

    def set_trainer_parameters(self, **kwargs):
        """Read trainer hyper-parameters from ``kwargs``, falling back to module defaults."""
        self.batch_size = kwargs.get("batch_size", BATCH_SIZE)
        self.replay_buffer_size = kwargs.get("replay_buffer_size", REPLAY_BUFFER_SIZE)
        # NOTE(review): update_freq is stored but update() trains every 100 steps
        # regardless; confirm which interval is intended before wiring it in.
        self.update_freq = kwargs.get("update_freq", UPDATE_FREQ)
        # Replay size that must be exceeded before the first training update.
        self.first_update_sample_num = kwargs.get("first_update_sample_num",
                                                  FIRST_UPDATE_SAMPLE_NUM)

    def init_algo(self, **kwargs):
        """Build the TF graph/session, the MADDPG algorithm and the replay buffer."""
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        # NOTE(review): MADDPGAlgo is not imported in the visible part of this
        # module; confirm it is defined or imported elsewhere.
        self.algorithm = MADDPGAlgo(self.session, self.graph, self.state_dim, self.action_dim,
                                    agent_num=self.agent_num, algo_name=self.algo_name, **kwargs)

        # Number of training updates performed; also used as the checkpoint index.
        self.update_cnt = 0
        # Exploration scale decayed after each training update in update().
        # It must exist before the first update; nothing in this class reads it.
        # NOTE(review): confirm the intended starting value.
        self.var = 1.0

        self.memory = ReplayMemory(self.replay_buffer_size)

        # Reuse the algorithm's summary writer for trainer-level scalars.
        self.my_writer = self.algorithm.train_writer

    def action(self, state, temperature, test_model):
        """Return the algorithm's action(s) for ``state``.

        In test mode the temperature is overridden with 1e-8 (presumably to make
        the policy near-deterministic; depends on ``choose_action``).
        """
        if test_model:
            temperature = 1e-8
        return self.algorithm.choose_action(np.asarray(state), temperature)

    def experience(self, s, a, r, s_n, done, terminal, test_model):
        """Store a ``(s, a, r, s_n, done, terminal)`` transition unless in test mode."""
        if not test_model:
            self.memory.add((s, a, r, s_n, done, terminal))

    def update(self, t):
        """Run one training update on a replay minibatch when one is due.

        Nothing happens until the replay buffer holds more than
        ``self.first_update_sample_num`` transitions. After that, an update runs
        whenever ``t`` is a multiple of 100. Each update:

        - samples ``self.batch_size`` transitions;
        - calls ``self.algorithm.learn``;
        - decays ``self.var``;
        - lets :meth:`save_params` checkpoint if due.

        Args:
            t: caller-supplied step counter.
        """
        if len(self.memory.store) <= self.first_update_sample_num:
            return
        # NOTE(review): fixed 100-step interval; self.update_freq is not used here.
        if t % 100 != 0:
            return

        self.update_cnt += 1
        sample = self.memory.get_minibatch(self.batch_size)
        s_all, a_all, r_all, s__all, done_all = self._unpack_minibatch(sample)

        self.algorithm.learn(s_all, a_all, r_all, s__all, done_all)

        self.var *= 0.999  # exploration decay

        self.save_params()

    @staticmethod
    def _unpack_minibatch(sample):
        """Split transitions stored by experience() into numpy arrays.

        Rewards and done flags are reshaped to column vectors; ``terminal``
        (element 5) is not used.
        """
        states = np.array([tr[0] for tr in sample])
        actions = np.array([tr[1] for tr in sample])
        rewards = np.array([tr[2] for tr in sample]).reshape([-1, 1])
        next_states = np.array([tr[3] for tr in sample])
        dones = np.array([tr[4] for tr in sample]).reshape([-1, 1])
        return states, actions, rewards, next_states, dones

    def _checkpoint_path(self, cnt):
        """Return the checkpoint path for update count ``cnt``."""
        return './data/{}/model/{}.ckpt'.format(self.algo_name, cnt)

    def save_params(self):
        """Checkpoint the model every MODEL_UPDATE_FREQ training updates."""
        if self.update_cnt % MODEL_UPDATE_FREQ == 0 and self.update_cnt > 0:
            print('model saved for update', self.update_cnt)
            save_path = self._checkpoint_path(self.update_cnt)
            # Saver.save fails when the parent directory does not exist yet.
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            self.algorithm.saver.save(self.algorithm.sess, save_path)

    def load_params(self, load_cnt):
        """Restore the checkpoint written at update count ``load_cnt``."""
        load_path = self._checkpoint_path(load_cnt)
        self.algorithm.saver.restore(self.algorithm.sess, load_path)
        print("load model for update %s " % load_cnt)

    def episode_done(self, test_model):
        """Forward end-of-episode notification to the algorithm."""
        self.algorithm.episode_done(test_model)

    def write_summary_scalar(self, iteration, tag, value, train_info):
        """Log a scalar, via the algorithm if ``train_info`` else via ``self.my_writer``."""
        if train_info:
            self.algorithm.write_summary_scalar(iteration, tag, value)
        else:
            self.my_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]), iteration)