import argparse
import sys
import time
import numpy as np
import os
import random
import tensorflow as tf
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

import gym
from gym import logger
from experiments.algorithms.dqn.dqn_trainer import DqnTrainer
from experiments.algorithms.dqn.dqn_pbrs_trainer import DqnPbrsTrainer
from experiments.algorithms.dqn.dqn_myp_pbrs_trainer import DqnMypPbrsTrainer
from experiments.algorithms.dqn.dqn_dps_trainer import DqnDpsTrainer
from experiments.algorithms.dqn.dqn_dpba_trainer import DqnDpbaTrainer
from experiments.algorithms.ppo.ppo_trainer import PPOTrainer
from experiments.algorithms.ppo.ppo_pbrs.ppo_pbrs_trainer import PPOPbrsTrainer
from experiments.algorithms.ppo.ppo_dpba.ppo_dpba_trainer import PPODpbaTrainer
from experiments.algorithms.ppo.ppo_oprs_v1_fop.ppo_oprs_v1_fop_trainer import PPOOprsV1FopTrainer
from experiments.algorithms.ppo.ppo_oprs_v2_fop.ppo_oprs_v2_fop_trainer import PPOOprsV2FopTrainer
from experiments.algorithms.ppo.ppo_oprs_v2_fsa.ppo_oprs_v2_fsa_trainer import PPOOprsV2FsaTrainer
from experiments.algorithms.ppo.ppo_oprs_v3_fsart.ppo_oprs_v3_fsart_trainer import PPOOprsV3FsartTrainer
from experiments.algorithms.ppo.ppo_algo_parameters import ppo_hyper_params,\
    ppo_dpba_hyper_params, ppo_oprs_v1_hyper_params, ppo_oprs_v2_hyper_params,\
    ppo_oprs_v3_hyper_params

def parse_args():
    """Parse the command-line options for one experiment.

    Returns:
        argparse.Namespace with one attribute per option below (dashes become
        underscores, e.g. ``--step-to-test`` -> ``step_to_test``) plus the
        optional positional ``env_id`` (default ``CartPole-v0``).
    """
    parser = argparse.ArgumentParser("Reinforcement Learning experiments for multiagent environments")
    # NOTE(review): only env_id, --algo, --shaping-method, --max-episode-len,
    # --step-to-test, --step-to-finish, --test-count, --display, --restore and
    # --benchmark are read by the code in this file; the remaining options are
    # accepted but unused here.

    # Environment
    parser.add_argument("--scenario", type=str, default="simple_tag_graph", help="name of the scenario script")
    parser.add_argument("--max-episode-len", type=int, default=200, help="maximum episode length")
    parser.add_argument("--num-episodes", type=int, default=5000, help="number of episodes")
    parser.add_argument("--num-adversaries", type=int, default=3, help="number of adversaries")
    parser.add_argument("--good-policy", type=str, default="ddpg", help="policy for good agents")
    parser.add_argument("--adv-policy", type=str, default="maddpg", help="policy of adversaries")
    # Core training parameters
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate for Adam optimizer")
    parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
    parser.add_argument("--batch-size", type=int, default=1024, help="number of episodes to optimize at the same time")
    parser.add_argument("--num-units", type=int, default=64, help="number of units in the mlp")
    # Checkpointing
    parser.add_argument("--exp-name", type=str, default='maddpg', help="name of the experiment")  # org:None
    parser.add_argument("--save-dir", type=str, default="./tmp/policy/", help="directory in which training state and model should be saved")
    parser.add_argument("--save-rate", type=int, default=1000, help="save model once every time this many episodes are completed")

    # Training / evaluation schedule
    parser.add_argument("--step-to-test", type=int, default=4000, help="after how many steps we test the model")
    parser.add_argument("--step-to-finish", type=int, default=1600000, help="after how many steps we finish training")
    # Number of test episodes averaged in each evaluation round.
    parser.add_argument("--test-count", type=int, default=20, help="number of test episodes run in each evaluation round")

    parser.add_argument("--load-dir", type=str, default="", help="directory in which training state and model are loaded")
    # Evaluation
    parser.add_argument("--restore", action="store_true", default=False)
    parser.add_argument("--display", action="store_true", default=False)
    parser.add_argument("--benchmark", action="store_true", default=False)
    parser.add_argument("--benchmark-iters", type=int, default=100000, help="number of iterations run for benchmarking")
    parser.add_argument("--benchmark-dir", type=str, default="./benchmark_files/", help="directory where benchmark data is saved")
    parser.add_argument("--plots-dir", type=str, default="./learning_curves/", help="directory where plot data is saved")
    parser.add_argument("--algo", type=str, default="ppo", help="the algorithm name")
    parser.add_argument('env_id', nargs='?', default='CartPole-v0', help='Select the environment to run')

    # The reward-shaping method. Values handled by get_trainers/experience:
    #   "none"      no reward shaping (any unrecognised value behaves the same)
    #   "ns"        naive shaping (shaping reward added to the env reward)
    #   "dps"       double policy shaping (DQN only)
    #   "dpba"      dynamic potential-based advice
    #   "myp_pbrs"  myopic potential-based reward shaping (DQN only)
    #   "pbrs"      potential-based reward shaping
    #   "oprs_v1_fop", "oprs_v2_fop", "oprs_v2_fsa", "oprs_v3_fsart"  (PPO only)
    parser.add_argument("--shaping-method", type=str, default="oprs_v3_fsart", help="the reward shaping method")
    return parser.parse_args()


def get_trainers(algo, shaping_method, obs_space, action_space, obs_shape, action_dim):
    """Construct the trainer for an (algorithm, shaping method) pair.

    Args:
        algo: "ppo" or "dqn"; any other value yields ``None``.
        shaping_method: selects the trainer variant; a value with no dedicated
            trainer falls back to the plain PPOTrainer / DqnTrainer.
        obs_space, action_space: environment spaces (used by PPO trainers).
        obs_shape, action_dim: observation length and number of discrete
            actions (used by DQN trainers).

    Returns:
        The constructed trainer, or ``None`` when ``algo`` is unsupported.
    """
    if algo == "ppo":
        # shaping method -> (trainer class, algo_name label, hyper-parameter kwargs)
        ppo_variants = {
            "ns": (PPOTrainer, "ppo_ns", ppo_hyper_params),
            "dpba": (PPODpbaTrainer, "ppo_dpba", ppo_dpba_hyper_params),
            "pbrs": (PPOPbrsTrainer, "ppo_pbrs", ppo_hyper_params),
            "oprs_v1_fop": (PPOOprsV1FopTrainer, "ppo_oprs_v1_fop", ppo_oprs_v1_hyper_params),
            "oprs_v2_fop": (PPOOprsV2FopTrainer, "ppo_oprs_v2_fop", ppo_oprs_v2_hyper_params),
            "oprs_v2_fsa": (PPOOprsV2FsaTrainer, "ppo_oprs_v2_fsa", ppo_oprs_v2_hyper_params),
            "oprs_v3_fsart": (PPOOprsV3FsartTrainer, "ppo_oprs_v3_fsart", ppo_oprs_v3_hyper_params),
        }
        trainer_cls, label, hyper_params = ppo_variants.get(
            shaping_method, (PPOTrainer, "ppo", ppo_hyper_params))
        return trainer_cls(state_space=obs_space, action_space=action_space,
                           algo_name=label, **hyper_params)

    if algo == "dqn":
        # shaping method -> (trainer class, algo_name label)
        dqn_variants = {
            "pbrs": (DqnPbrsTrainer, "dqn_pbrs"),
            "myp_pbrs": (DqnMypPbrsTrainer, "dqn_myp_pbrs"),
            "ns": (DqnTrainer, "dqn_ns"),
            "dps": (DqnDpsTrainer, "dqn_dps"),
            "dpba": (DqnDpbaTrainer, "dqn_dpba"),
        }
        trainer_cls, label = dqn_variants.get(shaping_method, (DqnTrainer, "dqn"))
        return trainer_cls(state_dim=obs_shape, action_num=action_dim, algo_name=label)

    # Unsupported algorithm: no trainer is built.
    return None

def _write_column(path, values):
    """Overwrite ``path`` with one ``str(value)`` per line for each entry of ``values``."""
    with open(path, 'w') as f:
        f.writelines(str(value) + "\n" for value in values)


def train_all(arglist):
    """Run repeated, seeded training runs and save the averaged test curves.

    For each seed in a fixed list, numpy, TensorFlow and ``random`` are seeded
    and ``train`` is run once per selected shaping method. After every run,
    the per-round test rewards and test episode lengths of all runs so far are
    averaged across runs and written, one value per line, to
    ``./<env_id>_<algo>_<shaping_method>_test_steps`` and ``..._test_rewards``.
    From the second run on, the across-run variances are also written to the
    matching ``*_var`` files. Files are overwritten on every run.

    Args:
        arglist: parsed command-line arguments (see ``parse_args``).
    """
    seeds = [10011301, 104590487, 106312663, 107281571, 110900381,
             110900689, 110914091, 110942057, 99992777, 99999509,
             122420729, 163227661, 217636919, 290182597, 386910137,
             515880193, 687840301, 917120411, 1222827239, 1610612741]

    if arglist.shaping_method is not None:
        shaping_methods = [arglist.shaping_method]
        print("The arg shaping method is {}".format(arglist.shaping_method))
    else:
        shaping_methods = ["none"]

    # shaping method -> [list of per-run reward arrays, list of per-run step arrays]
    data_dict = {}
    # One independent run per seed; derived from the list so the two can't drift apart.
    test_run = len(seeds)
    for r in range(test_run):
        seed = seeds[r]
        np.random.seed(seed)
        tf.set_random_seed(seed)
        random.seed(seed)
        for shaping_method in shaping_methods:
            returned_rewards, returned_steps = train(arglist, shaping_method, r)
            test_data = data_dict.setdefault(shaping_method, [[], []])
            test_data[0].append(returned_rewards)
            test_data[1].append(returned_steps)

            # Re-aggregate over every run completed so far and rewrite the files.
            rewards_np = np.array(test_data[0])
            steps_np = np.array(test_data[1])
            rewards_mean = np.mean(rewards_np, axis=0)
            steps_mean = np.mean(steps_np, axis=0)
            print("current step mean is {}".format(steps_mean))
            print("current reward mean is {}".format(rewards_mean))

            prefix = "./" + arglist.env_id + "_" + arglist.algo + "_" + shaping_method
            _write_column(prefix + "_test_steps", steps_mean)
            _write_column(prefix + "_test_rewards", rewards_mean)

            # Variances are only written once at least two runs exist.
            if r > 0:
                _write_column(prefix + "_test_steps_var", np.var(steps_np, axis=0))
                _write_column(prefix + "_test_rewards_var", np.var(rewards_np, axis=0))


def train(arglist, shaping_method, test_run):
    """Train one agent on ``arglist.env_id`` with periodic evaluation.

    Training alternates with evaluation. Every ``arglist.step_to_test``
    training steps the loop switches to test mode for ``arglist.test_count``
    episodes. In test mode no experience is recorded, no updates run, and
    ``trainer.action`` is called with ``test_model=True``. Each evaluation
    round contributes its average test reward and average test episode length
    to the returned curves. The loop stops after the first completed
    evaluation round once ``arglist.step_to_finish`` training steps have been
    taken.

    With ``arglist.display`` set, the loop only renders: it never updates,
    evaluates or terminates.

    Args:
        arglist: parsed command-line arguments (see ``parse_args``).
        shaping_method: reward-shaping method passed to ``get_trainers`` and
            ``experience``.
        test_run: index of the current run; only used in log messages.

    Returns:
        Tuple ``(rewards, steps)`` of numpy arrays holding, per evaluation
        round, the average test reward and the average test episode length.
    """
    # Set the level to logger.DEBUG or logger.WARN to change the amount of gym output.
    logger.set_level(logger.INFO)

    # Create the environment and obtain the initial state.
    env = gym.make(arglist.env_id)
    observation = env.reset()
    print("The observation is {}".format(observation))

    # Create the trainer; load saved parameters when displaying/restoring/benchmarking.
    print("Observation space is {}".format(env.observation_space))
    print("Action space is {}".format(env.action_space))
    print("Observation space shape is {}".format(env.observation_space.shape))
    print("Action space num is {}".format(env.action_space.n))

    trainer = get_trainers(arglist.algo, shaping_method, obs_space=env.observation_space,
                           action_space=env.action_space, obs_shape=observation.shape[0],
                           action_dim=env.action_space.n)

    if arglist.display or arglist.restore or arglist.benchmark:
        print('Loading previous state...')
        # NOTE(review): the checkpoint id 24000 is hard-coded.
        trainer.load_params(24000)

    # Loop bookkeeping.
    episode_step = 0              # steps taken in the current episode
    train_step = 0                # total training steps so far
    test_model = False            # True while running evaluation episodes
    test_episode_count = 0        # finished episodes in the current evaluation round
    test_episode_rewards = [0.0]  # per-episode rewards of the current evaluation round
    global_test_ep_rewards = []   # average test reward of each finished round
    test_episode_steps = [0.0]    # per-episode lengths of the current evaluation round
    global_test_ep_steps = []     # average test episode length of each finished round

    # Write the first point of the test-reward curve.
    trainer.write_summary_scalar(1, "Test_Episode_Reward", 0, True)
    print('Starting iterations...')

    while True:
        # The agent chooses an action; test_model tells the trainer whether this is evaluation.
        action, action_info = trainer.action(observation, test_model=test_model)

        # The environment takes a step.
        next_observation, reward, done, info = env.step(action)
        episode_step += 1

        # Collect the transition; `terminal` marks hitting the episode-length limit.
        s, a, r, sp = observation, action, reward, next_observation
        terminal = episode_step >= arglist.max_episode_len
        episode_terminal = done or terminal

        # Record the experience (training only).
        if not test_model:
            experience(trainer, arglist.algo, shaping_method, s, a, r, sp,
                       episode_terminal, action_info, info)

        # Transition to the next state.
        observation = next_observation

        # Count training steps, or accumulate the current test episode's reward and length.
        if not test_model:
            train_step += 1
        else:
            test_episode_rewards[-1] += r
            test_episode_steps[-1] += 1

        # Display mode: render the learned policy; no updates or evaluation.
        if arglist.display:
            time.sleep(0.1)
            env.render()
            # Start a new episode when this one ends; otherwise the finished
            # environment would keep being stepped.
            if episode_terminal:
                observation = env.reset()
                episode_step = 0
            continue

        # Update the trainer (training only).
        if not test_model:
            trainer.update(train_step)

        # If this episode is done or hit the length limit, start a new one.
        if episode_terminal:
            trainer.episode_done(test_model=test_model)
            observation = env.reset()
            episode_step = 0

        if not test_model:
            # Enough training steps for another evaluation round: switch to test mode.
            # NOTE(review): the switch can happen mid-episode, in which case the
            # first test episode continues the current training episode.
            if train_step % arglist.step_to_test == 0 and train_step > 0:
                test_model = True
                test_episode_count = 0
                test_episode_rewards = [0.0]
                test_episode_steps = [0.0]
                print("Round {}, train step {}, begin to test algorithm".format(len(global_test_ep_steps),
                                                                                train_step))
        elif episode_terminal:
            # A test episode just ended.
            test_episode_count += 1

            # After arglist.test_count test episodes, summarise this evaluation round.
            if test_episode_count == arglist.test_count:
                avg_test_ep_rewards = sum(test_episode_rewards) / (len(test_episode_rewards))
                avg_test_ep_steps = sum(test_episode_steps) / (len(test_episode_steps))
                global_test_ep_rewards.append(avg_test_ep_rewards)
                global_test_ep_steps.append(avg_test_ep_steps)

                print("Shaping Method {}, Test run {}, "
                      "Average test step and rewards are {}, {}".format(shaping_method, test_run,
                                                                        avg_test_ep_steps, avg_test_ep_rewards))

                test_model = False

                # Stop once the training-step budget has been reached.
                if train_step >= arglist.step_to_finish:
                    print('...Finished total of {} training steps.'.format(train_step))
                    break
            else:
                # Open the accumulators for the next test episode.
                test_episode_rewards.append(0)
                test_episode_steps.append(0)

    # Print the values being returned.
    print("The return test rewards are")
    print(global_test_ep_rewards)
    print("The return test steps are")
    print(global_test_ep_steps)

    # Close the environment.
    env.close()
    return np.array(global_test_ep_rewards), np.array(global_test_ep_steps)


def experience(trainer, algo, shaping_method, s, a, r, sp, done, action_info, info):
    """Forward one transition to ``trainer.experience`` in the form its shaping method expects.

    Shaping signals are read from ``info`` (the env step info: keys ``"c"``,
    ``"c_sp"``, ``"phi_s"``, ``"phi_sp"``) and, for the PPO "dpba"/oprs
    variants, from the ``action_info`` mapping returned alongside the action.
    With naive shaping ("ns") the value ``info["c"]`` is added directly to the
    reward. Transitions for algorithms other than "dqn"/"ppo" are dropped.
    """
    oprs_methods = ("oprs_v1_fop", "oprs_v2_fop", "oprs_v2_fsa", "oprs_v3_fsart")

    if algo == "dqn":
        if shaping_method == "ns":
            shaping = info.get("c")
            trainer.experience(s, a, r + shaping, sp, done)
        elif shaping_method in ("dps", "dpba"):
            shaping = info.get("c")
            trainer.experience(s, a, r, sp, done, c=shaping)
        elif shaping_method == "myp_pbrs":
            shaping = info.get("c")
            next_shaping = info.get("c_sp")  # shaping reward of every action in the next state
            trainer.experience(s, a, r, sp, done, c=shaping, c_sp=next_shaping)
        elif shaping_method == "pbrs":
            potential = info.get("phi_s")
            next_potential = info.get("phi_sp")
            trainer.experience(s, a, r, sp, done, phi_s=potential, phi_sp=next_potential)
        else:
            trainer.experience(s, a, r, sp, done)
        return

    if algo != "ppo":
        return

    if shaping_method == "ns":
        shaping = info.get("c")
        trainer.experience(s, a, r + shaping, sp, done, v_pred=action_info)
    elif shaping_method == "pbrs":
        potential = info.get("phi_s")
        next_potential = info.get("phi_sp")
        trainer.experience(s, a, r, sp, done, v_pred=action_info,
                           phi_s=potential, phi_sp=next_potential)
    elif shaping_method == "dpba":
        shaping = info.get("c")
        value = action_info.get("v_pred")
        advice = action_info.get("phi_sa")
        trainer.experience(s, a, r, sp, done, v_pred=value, c=shaping, phi_sa=advice)
    elif shaping_method in oprs_methods:
        value = action_info.get("v_pred")
        true_value = action_info.get("v_pred_true")
        shaping_weight = action_info.get("f_phi_s")
        shaping = info.get("c")
        trainer.experience(s, a, r, sp, done, v_pred=value, v_pred_true=true_value,
                           f_phi_s=shaping_weight, F=shaping)
    else:
        trainer.experience(s, a, r, sp, done, v_pred=action_info)

if __name__ == "__main__":
    # Script entry point: parse the command line once, then run every experiment.
    arglist = parse_args()
    train_all(arglist=arglist)


