import argparse
import sys
import time
import numpy as np
import random
import tensorflow as tf
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

import gym
from gym import logger
from experiments.algorithms.ddpg.ddpg_trainer import DDPGTrainer
from experiments.algorithms.ddpg.ddpg_dpba.ddpg_dpba_trainer import DDPGDpbaTrainer
from experiments.algorithms.ddpg.ddpg_pbrs.ddpg_pbrs_trainer import DDPGPbrsTrainer
from experiments.algorithms.ddpg.ddpg_oprs_v1.ddpg_oprs_v1_trainer import DDPGOprsV1Trainer
from experiments.algorithms.ddpg.ddpg_oprs_v1_freeze.ddpg_oprs_v1_freeze_trainer import DDPGOprsV1FreezeTrainer
from experiments.algorithms.ddpg.ddpg_oprs_v1_fop.ddpg_oprs_v1_fop_trainer import DDPGOprsV1FopTrainer
from experiments.algorithms.ddpg.ddpg_oprs_v2.ddpg_oprs_v2_trainer import DDPGOprsV2Trainer
from experiments.algorithms.ddpg.ddpg_oprs_v2_approx.ddpg_oprs_v2_approx_trainer import DDPGOprsV2ApproxTrainer
from experiments.algorithms.ddpg.ddpg_oprs_v2_fop.ddpg_oprs_v2_fop_trainer import DDPGOprsV2FopTrainer
from experiments.algorithms.ddpg.ddpg_oprs_v2_fsa.ddpg_oprs_v2_fsa_trainer import DDPGOprsV2FsaTrainer
from experiments.algorithms.ddpg.ddpg_oprs_v2_fsaqin.ddpg_oprs_v2_fsaqin_trainer import DDPGOprsV2FsaqinTrainer
from experiments.algorithms.ddpg.ddpg_oprs_v3_fsaqin.ddpg_oprs_v3_fsaqin_trainer import DDPGOprsV3FsaqinTrainer
from experiments.algorithms.ddpg.ddpg_algo_parameters import ddpg_hyper_params, ddpg_dpba_hyper_params,\
    ddpg_oprs_v1_hyper_params, ddpg_oprs_v1_fop_hyper_params, ddpg_oprs_v2_hyper_params,\
    ddpg_oprs_v2_approx_hyper_params, ddpg_oprs_v2_fop_hyper_params, ddpg_oprs_v2_fsa_hyper_params,\
    ddpg_oprs_v2_fsaqin_hyper_params, ddpg_oprs_v3_fsaqin_hyper_params

from experiments.algorithms.ppo.ppo_trainer import PPOTrainer
from experiments.algorithms.ppo.ppo_pbrs.ppo_pbrs_trainer import PPOPbrsTrainer
from experiments.algorithms.ppo.ppo_dpba.ppo_dpba_trainer import PPODpbaTrainer
from experiments.algorithms.ppo.ppo_oprs_v1_fop.ppo_oprs_v1_fop_trainer import PPOOprsV1FopTrainer
from experiments.algorithms.ppo.ppo_oprs_v2_fop.ppo_oprs_v2_fop_trainer import PPOOprsV2FopTrainer
from experiments.algorithms.ppo.ppo_oprs_v2_fsa.ppo_oprs_v2_fsa_trainer import PPOOprsV2FsaTrainer
from experiments.algorithms.ppo.ppo_oprs_v3_fsart.ppo_oprs_v3_fsart_trainer import PPOOprsV3FsartTrainer
from experiments.algorithms.ppo.ppo_algo_parameter_mujoco import ppo_hyper_params,\
    ppo_dpba_hyper_params, ppo_oprs_v1_hyper_params, ppo_oprs_v2_hyper_params,\
    ppo_oprs_v3_hyper_params

from experiments.algorithms.rcpo.rcpo_ppo.rcpo_ppo_trainer import RcpoPpoTrainer
from experiments.algorithms.rcpo.rcpo_algo_parameter_mujoco import rcpo_ppo_hyper_params

def parse_args():
    """Parse the command-line options of the experiment script.

    Returns:
        argparse.Namespace: the parsed options. Dashes in option names become
        underscores (``--step-to-test`` -> ``step_to_test``); ``--env_id`` is
        already underscored.
    """
    # The first positional parameter of ArgumentParser is ``prog`` (the program
    # name printed in the usage line), so the title must go to ``description``.
    parser = argparse.ArgumentParser(
        description="Reinforcement Learning experiments for multiagent environments")
    # Environment
    parser.add_argument("--max-episode-len", type=int, default=1000, help="maximum episode length")
    parser.add_argument("--num-episodes", type=int, default=5000, help="number of episodes")
    parser.add_argument("--num-adversaries", type=int, default=3, help="number of adversaries")
    parser.add_argument("--good-policy", type=str, default="ddpg", help="policy for good agents")
    parser.add_argument("--adv-policy", type=str, default="maddpg", help="policy of adversaries")
    # Core training parameters
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate for Adam optimizer")
    parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
    parser.add_argument("--batch-size", type=int, default=1024, help="number of episodes to optimize at the same time")
    parser.add_argument("--num-units", type=int, default=64, help="number of units in the mlp")
    # Checkpointing
    parser.add_argument("--exp-name", type=str, default='maddpg', help="name of the experiment")  # org:None
    parser.add_argument("--save-dir", type=str, default="./tmp/policy/", help="directory in which training state and model should be saved")
    parser.add_argument("--save-rate", type=int, default=1000, help="save model once every time this many episodes are completed")

    # Evaluation schedule: a test phase starts every --step-to-test training
    # steps and averages over --test-count episodes.
    parser.add_argument("--step-to-test", type=int, default=4000, help="after how many steps we test the model")
    parser.add_argument("--step-to-finish", type=int, default=3200000, help="after how many steps we finish training")
    parser.add_argument("--test-count", type=int, default=20, help="number of test episodes averaged at each evaluation")

    parser.add_argument("--load-dir", type=str, default="", help="directory in which training state and model are loaded")
    # Evaluation
    parser.add_argument("--restore", action="store_true", default=False)
    parser.add_argument("--display", action="store_true", default=False)
    parser.add_argument("--benchmark", action="store_true", default=False)
    parser.add_argument("--benchmark-iters", type=int, default=100000, help="number of iterations run for benchmarking")
    parser.add_argument("--benchmark-dir", type=str, default="./benchmark_files/", help="directory where benchmark data is saved")
    parser.add_argument("--plots-dir", type=str, default="./learning_curves/", help="directory where plot data is saved")
    parser.add_argument("--algo", type=str, default="ppo", help="the algorithm name")
    parser.add_argument('--env_id', type=str, nargs='?', default='Swimmer-v2', help='Select the environment to run')
    # The reward-shaping method. Names dispatched by get_trainers()/experience():
    #   "none" (or any unrecognised name): no reward shaping
    #   "ns":   naive shaping (the shaping signal info["c"] is added to the reward)
    #   "dpba": dynamic potential-based advice
    #   "pbrs": potential-based reward shaping
    #   "oprs_*" variants: see get_trainers() for the per-algorithm list
    # "dps" (double policy shaping) and "myp_pbrs" (myopic PBRS) were listed
    # previously but are not dispatched by this script.
    # NOTE(review): train_all() currently ignores this option and iterates
    # over its own list of shaping methods.
    parser.add_argument("--shaping-method", type=str, default="none", help="the reward shaping method")
    return parser.parse_args()


def get_trainers(algo, shaping_method, obs_space, action_space, obs_shape, action_dim,):
    """Instantiate the trainer for an (algorithm, reward-shaping method) pair.

    Args:
        algo: "ddpg", "ppo" or "rcpo_ppo".
        shaping_method: shaping method name. Names missing from the tables
            below fall back to the un-shaped trainer of ``algo``; the value is
            ignored for "rcpo_ppo".
        obs_space: observation space, forwarded to PPO / RCPO trainers.
        action_space: action space, forwarded to PPO / RCPO trainers.
        obs_shape: observation size, forwarded to DDPG trainers as ``state_dim``.
        action_dim: action size, forwarded to DDPG trainers.

    Returns:
        The trainer instance, or None when ``algo`` is not recognised.
    """
    if algo == "ddpg":
        # shaping method -> (trainer class, algo_name, hyper-parameter dict)
        ddpg_table = {
            "ns": (DDPGTrainer, "ddpg_ns", ddpg_hyper_params),
            "dpba": (DDPGDpbaTrainer, "ddpg_dpba", ddpg_dpba_hyper_params),
            "pbrs": (DDPGPbrsTrainer, "ddpg_pbrs", ddpg_hyper_params),
            "oprs_v1": (DDPGOprsV1Trainer, "ddpg_oprs_v1", ddpg_oprs_v1_hyper_params),
            "oprs_v1_freeze": (DDPGOprsV1FreezeTrainer, "ddpg_oprs_v1_freeze", ddpg_oprs_v1_hyper_params),
            "oprs_v1_fop": (DDPGOprsV1FopTrainer, "ddpg_oprs_v1_fop", ddpg_oprs_v1_fop_hyper_params),
            "oprs_v2": (DDPGOprsV2Trainer, "ddpg_oprs_v2", ddpg_oprs_v2_hyper_params),
            "oprs_v2_approx": (DDPGOprsV2ApproxTrainer, "ddpg_oprs_v2_approx", ddpg_oprs_v2_approx_hyper_params),
            "oprs_v2_fop": (DDPGOprsV2FopTrainer, "ddpg_oprs_v2_fop", ddpg_oprs_v2_fop_hyper_params),
            "oprs_v2_fsa": (DDPGOprsV2FsaTrainer, "ddpg_oprs_v2_fsa", ddpg_oprs_v2_fsa_hyper_params),
            "oprs_v2_fsaqin": (DDPGOprsV2FsaqinTrainer, "ddpg_oprs_v2_fsaqin", ddpg_oprs_v2_fsaqin_hyper_params),
            "oprs_v3_fsaqin": (DDPGOprsV3FsaqinTrainer, "ddpg_oprs_v3_fsaqin", ddpg_oprs_v3_fsaqin_hyper_params),
        }
        trainer_cls, trainer_name, hyper = ddpg_table.get(
            shaping_method, (DDPGTrainer, "ddpg", ddpg_hyper_params))
        return trainer_cls(state_dim=obs_shape, action_dim=action_dim,
                           algo_name=trainer_name, **hyper)

    if algo == "ppo":
        # shaping method -> (trainer class, algo_name, hyper-parameter dict)
        ppo_table = {
            "ns": (PPOTrainer, "ppo_ns", ppo_hyper_params),
            "dpba": (PPODpbaTrainer, "ppo_dpba", ppo_dpba_hyper_params),
            "pbrs": (PPOPbrsTrainer, "ppo_pbrs", ppo_hyper_params),
            "oprs_v1_fop": (PPOOprsV1FopTrainer, "ppo_oprs_v1_fop", ppo_oprs_v1_hyper_params),
            "oprs_v2_fop": (PPOOprsV2FopTrainer, "ppo_oprs_v2_fop", ppo_oprs_v2_hyper_params),
            "oprs_v2_fsa": (PPOOprsV2FsaTrainer, "ppo_oprs_v2_fsa", ppo_oprs_v2_hyper_params),
            "oprs_v3_fsart": (PPOOprsV3FsartTrainer, "ppo_oprs_v3_fsart", ppo_oprs_v3_hyper_params),
        }
        trainer_cls, trainer_name, hyper = ppo_table.get(
            shaping_method, (PPOTrainer, "ppo", ppo_hyper_params))
        return trainer_cls(state_space=obs_space, action_space=action_space,
                           algo_name=trainer_name, **hyper)

    if algo == "rcpo_ppo":
        # RCPO has a single variant; the shaping method is not consulted.
        return RcpoPpoTrainer(state_space=obs_space, action_space=action_space,
                              algo_name="rcpo_ppo", **rcpo_ppo_hyper_params)

    # Unknown algorithm.
    return None


def _write_values(path, values):
    """Write each element of *values* on its own line to *path*, overwriting it."""
    with open(path, 'w') as f:
        for value in values:
            f.write(str(value) + "\n")


def train_all(arglist, num_runs=10, shaping_methods=None):
    """Train repeatedly with several shaping methods and save averaged test curves.

    For each run index ``r`` and each shaping method, the numpy, TensorFlow and
    ``random`` seeds are set to ``seeds[r]`` and ``train`` is called. All
    methods therefore share the same seed within a run. After every call, the
    per-test-phase means over all runs so far are printed and written, one
    value per line, to

        ./<env_id>_<algo>_<method>_test_steps
        ./<env_id>_<algo>_<method>_test_rewards
        ./<env_id>_<algo>_<method>_test_torques

    From the second run on, the population variances (``np.var``, ddof=0) are
    also written to the same names with a ``_var`` suffix. Files are overwritten
    on every call. The element-wise averaging assumes every run returns the
    same number of test phases.

    Args:
        arglist: parsed command-line options (see ``parse_args``).
        num_runs: number of repetitions per shaping method (default 10).
        shaping_methods: shaping methods to compare; defaults to
            ``["oprs_v1_fop", "oprs_v2_fsa", "oprs_v3_fsart"]``.
            NOTE(review): ``arglist.shaping_method`` is not consulted.

    Raises:
        ValueError: if ``num_runs`` exceeds the number of predefined seeds.
    """
    seeds = [10011301, 104590487, 106312663, 107281571, 110900381,
             110900689, 110914091, 110942057, 99992777, 99999509,
             122420729, 163227661, 217636919, 290182597, 386910137,
             515880193, 687840301, 917120411, 1222827239, 1610612741]
    if shaping_methods is None:
        shaping_methods = ["oprs_v1_fop", "oprs_v2_fsa", "oprs_v3_fsart"]
    if num_runs > len(seeds):
        raise ValueError("num_runs={} exceeds the {} predefined seeds".format(num_runs, len(seeds)))

    # shaping method -> [rewards per run, steps per run, torques per run]
    data_dict = {}

    for r in range(num_runs):
        for shaping_method in shaping_methods:
            # Re-seed before every method so that all methods in run r start
            # from the same seed.
            # NOTE(review): the gym environment itself is not seeded here.
            np.random.seed(seeds[r])
            tf.set_random_seed(seeds[r])
            random.seed(seeds[r])

            returned_rewards, returned_steps, return_torques = train(arglist, shaping_method, r)
            test_data = data_dict.setdefault(shaping_method, [[], [], []])
            test_data[0].append(returned_rewards)
            test_data[1].append(returned_steps)
            test_data[2].append(return_torques)

            # After each run, aggregate across runs and write the data to files.
            rewards_np = np.array(test_data[0])
            steps_np = np.array(test_data[1])
            torques_np = np.array(test_data[2])
            rewards_mean = np.mean(rewards_np, axis=0)
            steps_mean = np.mean(steps_np, axis=0)
            torques_mean = np.mean(torques_np, axis=0)
            print("current step mean is {}".format(steps_mean))
            print("current reward mean is {}".format(rewards_mean))
            print("current torque mean is {}".format(torques_mean))

            prefix = "./" + arglist.env_id + "_" + arglist.algo + "_" + shaping_method
            _write_values(prefix + "_test_steps", steps_mean)
            _write_values(prefix + "_test_rewards", rewards_mean)
            _write_values(prefix + "_test_torques", torques_mean)

            # A variance needs at least two runs.
            if r > 0:
                _write_values(prefix + "_test_steps_var", np.var(steps_np, axis=0))
                _write_values(prefix + "_test_rewards_var", np.var(rewards_np, axis=0))
                _write_values(prefix + "_test_torques_var", np.var(torques_np, axis=0))


def train(arglist, shaping_method, test_run):
    """Train one agent on ``arglist.env_id`` with periodic evaluation.

    The loop alternates between two phases.

    * Training: every transition is passed to ``experience`` and
      ``trainer.update(train_step)`` is called.
    * Testing: entered every ``arglist.step_to_test`` training steps. It runs
      ``arglist.test_count`` episodes with ``test_model=True``, without
      recording experience or updating, and records the mean episode reward,
      mean episode length and mean absolute action ("torque").

    Training stops after the first completed test phase at which
    ``train_step >= arglist.step_to_finish``. With ``--display`` the loop
    renders every step and never terminates.

    Args:
        arglist: parsed command-line options (see ``parse_args``).
        shaping_method: reward-shaping method, forwarded to ``get_trainers``
            and ``experience``.
        test_run: index of the current repetition; used only in log messages.

    Returns:
        Tuple of three numpy arrays ``(rewards, steps, torques)`` with one
        averaged entry per completed test phase.

    Raises:
        ValueError: if ``get_trainers`` does not recognise ``arglist.algo``.
    """
    # You can set the level to logger.DEBUG or logger.WARN if you want to
    # change the amount of output.
    logger.set_level(logger.INFO)

    # Create the environment and get the initial state.
    env = gym.make(arglist.env_id)
    observation = env.reset()
    print("The observation is {}".format(observation))

    print("Observation space is {}".format(env.observation_space))
    print("Action space is {}".format(env.action_space))
    print("Observation space shape is {}".format(env.observation_space.shape))
    print("Action space dim is {}".format(env.action_space.shape[0]))

    # Create the trainer according to the algorithm and shaping method.
    trainer = get_trainers(arglist.algo, shaping_method, obs_space=env.observation_space,
                           action_space=env.action_space, obs_shape=observation.shape[0],
                           action_dim=env.action_space.shape[0])
    if trainer is None:
        # get_trainers() returns None for unknown algorithms; fail clearly here
        # rather than with an AttributeError on the first trainer call.
        env.close()
        raise ValueError("Unsupported algorithm: {}".format(arglist.algo))

    if arglist.display or arglist.restore or arglist.benchmark:
        print('Loading previous state...')
        # NOTE(review): the checkpoint id 24000 is hard-coded.
        trainer.load_params(24000)

    episode_step = 0       # steps taken in the current episode
    train_step = 0         # total training (non-test) steps so far
    test_model = False     # True while running evaluation episodes
    test_episode_count = 0

    # Per-episode accumulators of the current test phase; the last element
    # belongs to the episode in progress.
    test_episode_rewards = [0.0]
    test_episode_steps = [0.0]
    test_episode_torques = [0.0]

    # One averaged value per completed test phase (the returned curves).
    global_test_ep_rewards = []
    global_test_ep_steps = []
    global_test_ep_torques = []

    # write the first point
    trainer.write_summary_scalar(1, "Test_Episode_Reward", 0, True)
    logger.info('Starting iterations...')
    logger.warn('Starting iterations...')

    while True:
        # The agent takes an action and the environment takes a step.
        action, action_info = trainer.action(observation, test_model=test_model)
        next_observation, reward, done, info = env.step(action)
        episode_step += 1

        # The episode ends when the env reports done or the length cap is hit.
        terminal = episode_step >= arglist.max_episode_len
        episode_terminal = done or terminal

        # Only training transitions are recorded.
        if not test_model:
            experience(trainer, arglist.algo, shaping_method, observation, action, reward,
                       next_observation, episode_terminal, action_info, info)

        observation = next_observation

        if not test_model:
            train_step += 1
        else:
            test_episode_rewards[-1] += reward
            test_episode_steps[-1] += 1
            # Sum the mean absolute action component; it is divided by the
            # episode length when the episode ends.
            test_episode_torques[-1] += np.mean(np.abs(action))

        # Displaying a learned policy: render, skip updates and test bookkeeping.
        # NOTE(review): experience is still recorded in display mode because
        # test_model stays False.
        if arglist.display:
            time.sleep(0.1)
            env.render()
            # Start a new episode when the current one ends (previously the
            # env was stepped past termination because the reset below was
            # skipped by ``continue``).
            if episode_terminal:
                observation = env.reset()
                episode_step = 0
            continue

        if not test_model:
            trainer.update(train_step)

        # Start a new episode once the current one is done or terminal.
        if episode_terminal:
            trainer.episode_done(test_model=test_model)
            observation = env.reset()
            episode_step = 0

        if not test_model:
            # Enough training steps have been run to conduct one test phase.
            if train_step % arglist.step_to_test == 0 and train_step > 0:
                # NOTE(review): unless the training episode ended on this very
                # step, the first test episode continues from the current
                # environment state and episode_step.
                test_model = True
                test_episode_count = 0
                test_episode_rewards = [0.0]
                test_episode_steps = [0.0]
                test_episode_torques = [0.0]
                print("Round {}, train step {}, begin to test algorithm".format(len(global_test_ep_steps),
                                                                                train_step))
        elif episode_terminal:
            # A test episode just finished: turn its torque sum into a mean.
            test_episode_count += 1
            test_episode_torques[-1] /= test_episode_steps[-1]

            # Summarise the last ``arglist.test_count`` test episodes.
            if test_episode_count == arglist.test_count:
                avg_test_ep_rewards = sum(test_episode_rewards) / len(test_episode_rewards)
                avg_test_ep_steps = sum(test_episode_steps) / len(test_episode_steps)
                avg_test_ep_torque_abs = sum(test_episode_torques) / len(test_episode_torques)

                global_test_ep_rewards.append(avg_test_ep_rewards)
                global_test_ep_steps.append(avg_test_ep_steps)
                global_test_ep_torques.append(avg_test_ep_torque_abs)

                print("{}-{}, Test run {}, "
                      "Average test step, rewards, torque abs are {}, {}, {}".format(arglist.algo, shaping_method, test_run,
                                                                                     avg_test_ep_steps, avg_test_ep_rewards,
                                                                                     avg_test_ep_torque_abs))

                test_model = False

                # Finish once the training-step budget is used up.
                if train_step >= arglist.step_to_finish:
                    print('...Finished total of {} training steps.'.format(train_step))
                    break
            else:
                # Open accumulators for the next test episode. The torque list
                # must grow too; otherwise every episode's torque is folded
                # into a single entry and the reported average is wrong.
                test_episode_rewards.append(0)
                test_episode_steps.append(0)
                test_episode_torques.append(0)

    # print the return value
    print("The return test rewards are")
    print(global_test_ep_rewards)
    print("The return test steps are")
    print(global_test_ep_steps)
    print("The return test avg torque abs are")
    print(global_test_ep_torques)

    # Close the env and return the per-test-phase curves.
    env.close()
    return np.array(global_test_ep_rewards), np.array(global_test_ep_steps), np.array(global_test_ep_torques)


def experience(trainer, algo, shaping_method, s, a, r, sp, done, action_info, info):
    """Hand one transition to ``trainer.experience`` in the form it expects.

    Args:
        trainer: the trainer built by ``get_trainers``.
        algo: "ddpg", "ppo" or "rcpo_ppo"; any other value records nothing.
        shaping_method: selects which extra values are forwarded.
        s, a, r, sp, done: state, action, reward, next state, end-of-episode flag.
        action_info: second value returned by ``trainer.action``. It is used
            as-is for plain PPO (``v_pred``) and read with ``.get`` elsewhere.
        info: the ``info`` dict returned by ``env.step``.

    With the "ns" method, ``info["c"]`` is added to the reward itself. Other
    methods forward their shaping values as keyword arguments.
    """
    # Shaping methods that share the f_phi_s / F signature.
    ddpg_oprs_methods = ("oprs_v1", "oprs_v1_freeze", "oprs_v1_fop", "oprs_v2",
                         "oprs_v2_approx", "oprs_v2_fop", "oprs_v2_fsa",
                         "oprs_v2_fsaqin", "oprs_v3_fsaqin")
    ppo_oprs_methods = ("oprs_v1_fop", "oprs_v2_fop", "oprs_v2_fsa", "oprs_v3_fsart")

    if algo == "rcpo_ppo":
        trainer.experience(s, a, r, sp, done,
                           v_pred=action_info.get("v_pred"),
                           lagrange_multi=action_info.get("lagrange_multi"),
                           penalty=info.get("penalty"))
        return

    if algo not in ("ddpg", "ppo"):
        # Unknown algorithm: nothing is recorded.
        return

    reward = r
    extra = {}
    if algo == "ddpg":
        if shaping_method == "ns":
            reward = r + info.get("c")
        elif shaping_method == "dpba":
            extra = {"c": info.get("c")}
        elif shaping_method == "pbrs":
            extra = {"phi_s": info.get("phi_s"), "phi_sp": info.get("phi_sp")}
        elif shaping_method in ddpg_oprs_methods:
            extra = {"f_phi_s": action_info.get("f_phi_s"), "F": info.get("c")}
    else:
        # PPO family.
        if shaping_method == "ns":
            reward = r + info.get("c")
            extra = {"v_pred": action_info}
        elif shaping_method == "pbrs":
            extra = {"phi_s": info.get("phi_s"), "phi_sp": info.get("phi_sp"),
                     "v_pred": action_info}
        elif shaping_method == "dpba":
            extra = {"c": info.get("c"),
                     "v_pred": action_info.get("v_pred"),
                     "phi_sa": action_info.get("phi_sa")}
        elif shaping_method in ppo_oprs_methods:
            extra = {"v_pred": action_info.get("v_pred"),
                     "v_pred_true": action_info.get("v_pred_true"),
                     "f_phi_s": action_info.get("f_phi_s"),
                     "F": info.get("c")}
        else:
            extra = {"v_pred": action_info}

    trainer.experience(s, a, reward, sp, done, **extra)


if __name__ == "__main__":
    # Script entry point: parse the command-line options, then run every
    # configured experiment.
    arglist = parse_args()
    train_all(arglist)


