import argparse
import os
import time
from typing import List, Union

import gym
import joblib
import numpy as np
import yaml

from action_masking.sb3_contrib.common.maskable.utils import generator_center_to_array
from action_masking.provably_safe_env.envs.seeker_circle_env import SeekerCircleEnv
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import (
    NormalActionNoise,
    OrnsteinUhlenbeckActionNoise,
)
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.utils import configure_logger
from stable_baselines3.common.vec_env import DummyVecEnv

from action_masking.experiments.benchmark import Benchmark
from action_masking.callbacks.seeker_callback import SeekerCallback
from action_masking.sb3_contrib.common.wrappers.action_masking import (
    ActionMaskingWrapper,
)
from action_masking.sb3_contrib.common.wrappers.informer import InformerWrapper
from action_masking.util.sets import Zonotope
from action_masking.util.util import (
    ActionSpace,
    Algorithm,
    Approach,
    ContMaskingMode,
    Stage,
    TransitionTuple,
    load_configs_from_dir,
    hyperparam_optimization,
    gen_experiment,
)


class BenchmarkSeeker(Benchmark):
    def __init__(self, config: Union[dict, None] = None, adapt_gradient: bool = False):
        """
        Set up the seeker benchmark.
        Args:
            config: benchmark configuration; it is stored on the instance and forwarded to the
                parent constructor. ``None`` (the default) is replaced by a fresh empty dict.
            adapt_gradient: flag stored on the instance; it is read when the policy is
                configured for ``ContMaskingMode.Generator``.
        """
        # A literal ``{}`` default would be one dict object shared by every instance
        # that is created without an explicit config.
        self.config = {} if config is None else config
        self.env_name = "SeekerCircleEnv-v0"
        self.adapt_gradient = adapt_gradient
        super().__init__(self.config, use_zonotope=True)

    def safe_control_fn(self, env, safe_region):
        """
        Compute a fail-safe action that steers the agent away from the obstacle.

        The action points from the obstacle to the agent and is scaled to length 1/5. It is
        reversed if the resulting position reaches the environment size in any coordinate.
        If the resulting position lies within the obstacle radius around the obstacle, the
        action is rotated by 90 degrees and the bounds check is repeated.
        Args:
            env: a ``DummyVecEnv`` or an environment exposing ``unwrapped``; the agent position,
                obstacle position, obstacle radius and size are read from it
            safe_region: not used by this implementation
        Returns:
            the fail-safe action
        """
        print("Fail Safe Action")
        is_vec_env = isinstance(env, DummyVecEnv)

        def read_env_attr(name):
            if is_vec_env:
                # ``get_attr`` returns a list with one entry per sub-environment, which
                # cannot be used in the arithmetic below; use the first sub-environment.
                return env.get_attr(name)[0]
            return getattr(env.unwrapped, name)

        agent_pos = read_env_attr("_agent_position")
        obstacle_pos = read_env_attr("_obstacle_position")
        obstacle_radius = read_env_attr("_obstacle_radius")
        env_size = read_env_attr("size")

        action = agent_pos - obstacle_pos
        # NOTE(review): the norm is zero if agent and obstacle positions coincide, which
        # yields a non-finite action - confirm this cannot happen.
        action = action / (np.linalg.norm(action) * 5)

        # Check if action would lead out of bounds
        # NOTE(review): only the upper bound is checked - confirm whether a lower bound
        # check is needed as well.
        if ((agent_pos + action) >= env_size).any():
            action = -action

        # Check if action now would collide with the obstacle. The distance has to be
        # measured to the obstacle position, not to the origin.
        if np.linalg.norm(agent_pos + action - obstacle_pos) <= obstacle_radius:
            # 90 degrees
            action = np.array([action[1], -action[0]])

            # Check if action would lead out of bounds
            if ((agent_pos + action) >= env_size).any():
                action = -action

        return action

    def continuous_safe_space_fn(self, env, safe_region) -> np.ndarray:
        """
        Return the continuous safe space as computed by the parent class.
        Args:
            env: environment, forwarded unchanged to the parent implementation
            safe_region: safe region, forwarded unchanged to the parent implementation
        Returns:
            the value returned by the parent implementation
        """
        parent_impl = super().continuous_safe_space_fn
        return parent_impl(env, safe_region)

    def sampling_fn(self):
        """
        Return the result of the parent class's ``sampling_fn`` without changes.
        """
        parent_impl = super().sampling_fn
        return parent_impl()

    def continuous_safe_space_fn_masking_zonotope(self, env, safe_region) -> Zonotope:
        """
        Ask the environment for the safe input set based on the template input set.
        Args:
            env: environment providing ``calc_safe_input_set_zono``
            safe_region: not used by this implementation
        Returns:
            the value returned by ``env.calc_safe_input_set_zono`` for ``self.template_input_set``
        """
        return env.calc_safe_input_set_zono(self.template_input_set, debug=False)

    def create_env(
            self,
            env_name: str,
            space: ActionSpace,
            approach: Approach,
            transition_tuple: TransitionTuple,
            sampling_seed: int = None,
            continuous_action_masking_mode: ContMaskingMode = None,
            render_mode: str = None,
            render_safe_input_set: bool = False,
            randomize: bool = True,
    ) -> DummyVecEnv:
        """
        Create the environment for the given approach, wrapped in a ``DummyVecEnv``.

        The environment is created with ``gym.make`` and limited to ``self.max_episode_steps``
        steps. For ``Approach.Baseline`` it is wrapped in an ``InformerWrapper``, for
        ``Approach.Masking`` in an ``ActionMaskingWrapper``; for any other approach no further
        wrapper is applied. The result is wrapped in ``Monitor`` and ``DummyVecEnv``.
        Args:
            env_name: id passed to ``gym.make``
            space: action space type (not used by this implementation)
            approach: selects the wrapper, see above
            transition_tuple: ``AdaptionPenalty``/``Both`` enable the punishment function,
                ``SafeAction``/``Both`` enable ``generate_wrapper_tuple``
            sampling_seed: passed to ``gym.make`` as ``seed``
            continuous_action_masking_mode: masking mode for ``Approach.Masking``
            render_mode: passed to ``gym.make``
            render_safe_input_set: passed to ``gym.make``
            randomize: passed to ``gym.make``
        Returns:
            a ``DummyVecEnv`` holding the monitored environment
        Raises:
            ValueError: for ``Approach.Masking`` if the action space is a ``gym.spaces.Box``
                and no continuous action masking mode is given
        """

        env = gym.make(
            env_name,
            seed=sampling_seed,
            render_mode=render_mode,
            render_safe_input_set=render_safe_input_set,
            template_input_set=self.template_input_set,
            randomize=randomize,
        )
        env = gym.wrappers.TimeLimit(env, max_episode_steps=self.max_episode_steps)

        # Normalized action space that is handed to the wrappers.
        alter_action_space = gym.spaces.Box(
            low=-1, high=1, shape=(2,), dtype=np.float32
        )

        punishment_fn = None
        generate_wrapper_tuple = None
        if (
                transition_tuple is TransitionTuple.AdaptionPenalty
                or transition_tuple is TransitionTuple.Both
        ):
            def punishment_fn(env, action, reward, safe_action):
                return reward + self.punishment

        if (
                transition_tuple is TransitionTuple.SafeAction
                or transition_tuple is TransitionTuple.Both
        ):
            generate_wrapper_tuple = True

        def transform_action_space_fn(action):
            """Convert action from [-1, 1] to [u_min, u_max]."""
            # [-1,1] -> [u_min, u_max]
            return np.clip(
                ((action + 1) / 2) * (self.u_high - self.u_low) + self.u_low,
                self.u_low,
                self.u_high,
            )

        def inv_transform_action_space_fn(u):
            """Convert action from [u_min, u_max] to [-1, 1]."""
            # [u_min, u_max] -> [-1,1]
            return np.clip(
                ((u - self.u_low) / (self.u_high - self.u_low)) * 2 - 1, -1, 1
            )

        def transform_action_space_zonotope_fn(
                action: Union[np.ndarray, List[float], float], safe_space: Zonotope
        ) -> np.ndarray:
            """
            Convert action in generator dimension (number of zonotope alphas) to action in action space and convert to be in [u_min, u_max].
            Args:
                action: action in generator dimension
                safe_space: safe space
            Returns:
                action in action space in [u_min, u_max]
            """
            action = safe_space.G @ action + np.squeeze(safe_space.c)
            return np.clip(action, self.u_low, self.u_high)

        def inv_transform_action_space_zonotope_fn(
                u: np.ndarray, safe_space: Zonotope
        ) -> np.ndarray:
            """
            Convert action in action space to action in generator dimension (number of zonotope alphas) and convert to be in [-1, 1].
            Args:
                u: action in action space
                safe_space: safe space
            Returns:
                action in generator dimension in [-1, 1]
            """
            action = np.linalg.pinv(safe_space.G) @ (u - np.squeeze(safe_space.c))
            return np.clip(
                action, -1, 1
            )

        if approach is Approach.Baseline:
            env = InformerWrapper(
                env=env,
                alter_action_space=alter_action_space,
                transform_action_space_fn=transform_action_space_fn,
            )
        elif approach is Approach.Masking:
            # ``action_space`` is an instance, so an identity comparison with the class
            # ``gym.spaces.Box`` could never be true; an isinstance check is required.
            if (
                    isinstance(env.action_space, gym.spaces.Box)
                    and continuous_action_masking_mode is None
            ):
                raise ValueError("action masking mode not set")
            if continuous_action_masking_mode and (
                    continuous_action_masking_mode == ContMaskingMode.Generator
            ):
                generator_dim = self.template_input_set.G.shape[1]
            else:
                generator_dim = None

            use_generator_transform = (
                continuous_action_masking_mode == ContMaskingMode.Generator
            )

            env = ActionMaskingWrapper(
                env,
                safe_region=self.safe_region,
                dynamics_fn=self.dynamics_fn,
                safe_control_fn=self.safe_control_fn,
                punishment_fn=punishment_fn,
                continuous_safe_space_fn=self.continuous_safe_space_fn_masking_zonotope,
                continuous_action_space_fn_polytope=(
                    self.continuous_safe_space_fn_masking_inner_interval
                    if self.log_polytope_space
                    else None
                ),
                safe_region_polytope=(
                    self.safe_region_polytope if self.log_polytope_space else None
                ),
                alter_action_space=alter_action_space,
                transform_action_space_fn=(
                    transform_action_space_zonotope_fn
                    if use_generator_transform
                    else transform_action_space_fn
                ),
                generate_wrapper_tuple=generate_wrapper_tuple,
                # The inverse mapping must be the inverse functions defined above,
                # not the forward transforms.
                inv_transform_action_space_fn=(
                    inv_transform_action_space_zonotope_fn
                    if use_generator_transform
                    else inv_transform_action_space_fn
                ),
                continuous_action_masking_mode=continuous_action_masking_mode,
                generator_dim=generator_dim,
                safe_center_obs=self.safe_center_obs,
            )

        return DummyVecEnv([lambda: Monitor(env)])

    def run_experiment(
            self,
            alg: Algorithm,
            policy: BasePolicy,
            space: ActionSpace,
            approach: Approach,
            transition_tuple: TransitionTuple,
            path: str,
            continuous_action_masking_mode=None,
    ):
        """
        Train and deploy models for one experiment configuration.

        The method iterates over all members of ``Stage``. In ``Stage.Train`` it trains
        ``self.train_iters`` models and saves them with consecutive numbers in the model
        directory. In ``Stage.Deploy`` it loads the highest-numbered ``self.train_iters``
        models and runs ``self.n_eval_ep`` evaluation episodes with each of them.
        Args:
            alg: algorithm; ``alg.value`` is used as model class and ``alg.name`` selects the
                hyperparameters from ``self.config["algorithms"]``
            policy: policy passed to the model constructor
            space: action space type, used for the callback and the masking flags
            approach: approach forwarded to ``create_env``; also used for the masking flags
            transition_tuple: transition tuple variant, determines the learning flags
            path: experiment path, appended to the tensorboard and model directories
            continuous_action_masking_mode: masking mode forwarded to ``create_env``
        """

        policy_kwargs = dict()
        # Work on a copy: entries are deleted and rescaled below, and the shared config
        # must stay intact for later calls.
        hyperparams = dict(self.config.get("algorithms").get(alg.name, {}))

        if alg is Algorithm.PPO and (
                transition_tuple is TransitionTuple.SafeAction
                or transition_tuple is TransitionTuple.Both
        ):
            hyperparams["normalize_advantage"] = False
            if transition_tuple is TransitionTuple.SafeAction:
                hyperparams["learning_rate"] *= 1e-2
        # Move the entries that configure the policy from the hyperparameters
        # into the policy keyword arguments.
        if "activation_fn" in hyperparams:
            from torch import nn

            policy_kwargs["activation_fn"] = {"tanh": nn.Tanh, "relu": nn.ReLU}[
                hyperparams["activation_fn"]
            ]
            del hyperparams["activation_fn"]
        if "log_std_init" in hyperparams:
            policy_kwargs["log_std_init"] = hyperparams["log_std_init"]
            del hyperparams["log_std_init"]
        if "network_size" in hyperparams:
            net_size = hyperparams["network_size"]
            policy_kwargs["net_arch"] = [net_size, net_size]
            del hyperparams["network_size"]
        else:
            policy_kwargs["net_arch"] = [32, 32]
        if "noise_type" in hyperparams and "noise_std" in hyperparams:
            action_dim = (
                self.template_input_set.G.shape[1]
                if hasattr(self, "template_input_set")
                else 2
            )
            if continuous_action_masking_mode is ContMaskingMode.Interval:
                action_dim = 2
            if hyperparams["noise_type"] == "normal":
                hyperparams["action_noise"] = NormalActionNoise(
                    mean=np.zeros(action_dim),
                    sigma=hyperparams["noise_std"] * np.ones(action_dim),
                )
            elif hyperparams["noise_type"] == "ornstein-uhlenbeck":
                hyperparams["action_noise"] = OrnsteinUhlenbeckActionNoise(
                    mean=np.zeros(action_dim),
                    sigma=hyperparams["noise_std"] * np.ones(action_dim),
                )
            del hyperparams["noise_type"]
            del hyperparams["noise_std"]
        elif "noise_type" in hyperparams and "noise_std" not in hyperparams:
            # An incomplete noise specification is dropped.
            del hyperparams["noise_type"]
        elif "noise_std" in hyperparams and "noise_type" not in hyperparams:
            del hyperparams["noise_std"]

        if continuous_action_masking_mode == ContMaskingMode.ConstrainedNormal:
            policy_kwargs["use_zono_gaussian_dist"] = True
            hyperparams["template_generator_shape"] = self.template_input_set.G.shape

        if continuous_action_masking_mode == ContMaskingMode.Generator and self.adapt_gradient:
            policy_kwargs["use_generator_gaussian_dist"] = True
            hyperparams["template_generator_shape"] = self.template_input_set.G.shape

        replace_policy_tuple = transition_tuple is TransitionTuple.SafeAction

        # Policy should use (s, a_phi, s′, r) for TransitionTuple.SafeAction or TransitionTuple.Both
        use_wrapper_tuple = (
                replace_policy_tuple or transition_tuple is TransitionTuple.Both
        )

        # Note: This is redundant!
        use_discrete_masking = (
                approach is Approach.Masking and space is ActionSpace.Discrete
        )
        use_continuous_masking = (
                approach is Approach.Masking and space is ActionSpace.Continuous
        )

        # NOTE: the local variable names in this method are handed to the callback via
        # ``callback.update_locals(locals())``; keep them stable.
        cur_time = time.perf_counter()
        for stage in Stage:
            print(f"Running experiment {path} in stage {stage.name}")
            tb_log_dir = os.getcwd() + f"/tensorboard/{stage.name}/{path}"
            model_dir = os.getcwd() + f"/models/{path}/"

            if stage is Stage.Train:
                callback = SeekerCallback(
                    safe_region=self.safe_region,
                    action_space=space,
                    action_space_area=self.action_space_area_eq,
                    verbose=2,
                )

                seeds = range(1, self.train_iters + 1)

                os.makedirs(model_dir, exist_ok=True)
                # Count the existing models once. Counting inside the loop would also
                # count the models saved by this run and produce the names 1, 3, 5, ...,
                # which the Deploy stage cannot find again.
                n_existing_models = len(os.listdir(model_dir))

                for i in range(1, self.train_iters + 1):
                    env = self.create_env(
                        env_name=self.env_name,
                        space=space,
                        approach=approach,
                        transition_tuple=transition_tuple,
                        sampling_seed=seeds[i - 1],
                        continuous_action_masking_mode=continuous_action_masking_mode,
                        randomize=self.config["randomize"]
                    )

                    start_time = time.perf_counter()

                    model = alg.value(
                        seed=i,
                        env=env,
                        policy=policy,
                        tensorboard_log=tb_log_dir,
                        policy_kwargs=policy_kwargs,
                        device="cpu",
                        **hyperparams,
                    )
                    try:
                        model.learn(
                            tb_log_name="",
                            callback=callback,
                            log_interval=None,
                            total_timesteps=self.steps,
                            use_wrapper_tuple=use_wrapper_tuple,
                            replace_policy_tuple=replace_policy_tuple,
                            use_discrete_masking=use_discrete_masking,
                            use_continuous_masking=use_continuous_masking,
                        )
                    except ValueError as e:
                        # Best effort: the model is still saved in its current state.
                        print("Error in training: ", e)

                    model.save(model_dir + str(i + n_existing_models))
                    print(
                        "Iteration time {} s".format((time.perf_counter() - start_time))
                    )
                    # Release the training environment before the next one is created.
                    env.close()

            elif stage is Stage.Deploy:
                callback = SeekerCallback(
                    safe_region=self.safe_region,
                    action_space=space,
                    action_space_area=self.action_space_area_eq,
                    verbose=2,
                    train=False,
                )

                for i in range(1, self.train_iters + 1):

                    # Create env
                    env: DummyVecEnv = self.create_env(
                        env_name=self.env_name,
                        space=space,
                        approach=approach,
                        transition_tuple=transition_tuple,
                        sampling_seed=self.train_iters + i,
                        continuous_action_masking_mode=continuous_action_masking_mode,
                        randomize=self.config["randomize"]
                    )

                    # Load model (the highest-numbered models are used)
                    n_existing_models = len(os.listdir(model_dir))
                    model_path = model_dir + str(n_existing_models - (i - 1))
                    model = alg.value.load(model_path)
                    model.set_env(env)

                    # Setup callback
                    logger = configure_logger(
                        tb_log_name="", tensorboard_log=tb_log_dir
                    )
                    model.set_logger(logger)
                    callback.init_callback(model=model)

                    for j in range(self.n_eval_ep):
                        done = False
                        action_mask = None
                        obs = env.reset()
                        # Give access to local variables
                        callback.update_locals(locals())
                        callback.on_rollout_start()
                        while not done:
                            safe_set_zono = env.envs[0].get_safe_space()
                            safe_set = None
                            if continuous_action_masking_mode == ContMaskingMode.ConstrainedNormal:
                                safe_set = generator_center_to_array(safe_set_zono.G, safe_set_zono.c)

                            action, _ = model.predict(
                                observation=obs,
                                action_masks=safe_set,
                                deterministic=True,
                            )

                            obs, reward, done, info = env.step(action)

                            # Give access to local variables
                            callback.update_locals(locals())
                            if callback.on_step() is False:
                                # Do not leak the environment on early termination.
                                env.close()
                                return

                    env.close()

        print("Experiment time {} min".format((time.perf_counter() - cur_time) / 60))

    def optimize_hyperparams(self,
                             alg: Algorithm,
                             policy: BasePolicy,
                             space: ActionSpace,
                             approach: Approach,
                             transition_tuple: TransitionTuple,
                             path: str,
                             cont_action_masking_mode: ContMaskingMode):
        """
        Run a hyperparameter search and store the resulting study.

        The environment, learning and model arguments are collected and handed to
        ``hyperparam_optimization``; the returned study is dumped with joblib to
        ``study.pkl`` inside the study directory.
        Args:
            alg: algorithm; ``alg.value`` is passed as the model factory
            policy: policy passed on as part of the model arguments
            space: action space type, used for the environment and the masking flags
            approach: approach, used for the environment and the masking flags
            transition_tuple: transition tuple variant, determines the learning flags
            path: experiment path, appended to the tensorboard and study directories
            cont_action_masking_mode: continuous action masking mode
        Raises:
            ValueError: if the study directory already exists
        """
        working_dir = os.getcwd()
        tb_log_dir = f"{working_dir}/tensorboard/Optimize/{path}"
        study_dir = f"{working_dir}/optuna/studies/{path}"

        # Never overwrite the results of an earlier study
        if os.path.exists(study_dir):
            raise ValueError(f"Study directory {study_dir} already exists.")

        # Only (s, a_phi, s′, r) is used for TransitionTuple.SafeAction
        replace_policy_tuple = transition_tuple is TransitionTuple.SafeAction
        # (s, a_phi, s′, r) is used for TransitionTuple.SafeAction and TransitionTuple.Both
        if replace_policy_tuple:
            use_wrapper_tuple = True
        else:
            use_wrapper_tuple = transition_tuple is TransitionTuple.Both

        is_masking = approach is Approach.Masking
        use_discrete_masking = is_masking and space is ActionSpace.Discrete
        use_continuous_masking = is_masking and space is ActionSpace.Continuous

        env_kwargs = dict(
            env_name=self.env_name,
            space=space,
            approach=approach,
            transition_tuple=transition_tuple,
            sampling_seed=0,
            continuous_action_masking_mode=cont_action_masking_mode,
            randomize=self.config["randomize"],
        )
        learn_kwargs = dict(
            use_wrapper_tuple=use_wrapper_tuple,
            replace_policy_tuple=replace_policy_tuple,
            use_discrete_masking=use_discrete_masking,
            use_continuous_masking=use_continuous_masking,
        )
        model_kwargs = dict(
            seed=0,
            policy=policy,
            tensorboard_log=tb_log_dir,
            device="cpu",
        )

        # Select the distribution flag that belongs to the masking mode, if any
        distribution_flag = None
        if cont_action_masking_mode == ContMaskingMode.ConstrainedNormal:
            distribution_flag = "use_zono_gaussian_dist"
        elif cont_action_masking_mode == ContMaskingMode.Generator and self.adapt_gradient:
            distribution_flag = "use_generator_gaussian_dist"
        if distribution_flag is not None:
            model_kwargs[distribution_flag] = True
            model_kwargs["template_generator_shape"] = self.template_input_set.G.shape

        study = hyperparam_optimization(
            algo=alg,
            model_fn=alg.value,
            env_fn=self.create_env,
            env_args=env_kwargs,
            learn_args=learn_kwargs,
            n_trials=50,
            n_timesteps=self.steps,
            hyperparams=model_kwargs,
            n_jobs=1,
            sampler_method="tpe",
            pruner_method="median",
            seed=0,
            verbose=1,
            study_dir=study_dir,
        )
        os.makedirs(study_dir, exist_ok=True)
        joblib.dump(study, f"{study_dir}/study.pkl")

    def plot_importance_hyperparams(
            self,
            path: str
    ):
        """
        Show the hyperparameter importances of a stored study.

        The study is loaded with joblib from ``study.pkl`` in the study directory of ``path``
        and plotted with Optuna's ``plot_param_importances``.
        Args:
            path: experiment path below the ``optuna/studies`` directory
        Returns:
            the figure, after ``show()`` has been called on it
        """
        import optuna

        def trial_value(trial):
            # Target of the importance evaluation: the value of the trial
            return trial.value

        study_file = f"{os.getcwd()}/optuna/studies/{path}/study.pkl"
        study = joblib.load(study_file)
        figure = optuna.visualization.plot_param_importances(
            study, target=trial_value, target_name="value"
        )
        figure.show()
        return figure


if __name__ == "__main__":
    # Command line interface: choose between hyperparameter search and a regular
    # experiment, and select the approach and the continuous masking mode.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-optimize",
        "--optimize-hyperparameters",
        action="store_true",
        default=False,
        help="Run hyperparameters search",
    )
    parser.add_argument(
        "-approach",
        "--approach",
        default="masking",
        help="'baseline' or 'masking'",
    )
    parser.add_argument(
        "-mm",
        "--masking-mode",
        default="generator",
        help="'generator', 'ray', or 'distribution'",
    )
    # Unknown arguments are ignored
    args, _ = parser.parse_known_args()
    args = vars(args)

    space = ActionSpace.Continuous
    transition = TransitionTuple.Naive

    # Every approach/mode combination has its own hyperparameter file
    adapt_gradient_generator = False
    if args["approach"] == "baseline":
        approach = Approach.Baseline
        mode = None
        hp_file = "hyperparams/hyperparams_seeker.yml"
    else:
        # Any value other than 'baseline' selects the masking approach
        approach = Approach.Masking
        if args["masking_mode"] == "generator":
            mode = ContMaskingMode.Generator
            adapt_gradient_generator = True
            hp_file = "hyperparams/hyperparams_seeker_gen.yml"
        elif args["masking_mode"] == "ray":
            mode = ContMaskingMode.Ray
            hp_file = "hyperparams/hyperparams_seeker_ray.yml"
        elif args["masking_mode"] == "distribution":
            mode = ContMaskingMode.ConstrainedNormal
            hp_file = "hyperparams/hyperparams_seeker_dist.yml"
        else:
            raise ValueError("Masking mode not set")

    with open(hp_file, 'r', encoding="utf-8") as file:
        config = yaml.safe_load(file)

    env_name = "SeekerCircleEnv"  # "2dQuadrotorCoupledDynamics"
    benchmark = BenchmarkSeeker(config=config[env_name], adapt_gradient=adapt_gradient_generator)

    start_time = time.perf_counter()

    # Use custom configuration
    alg = Algorithm.PPO
    from action_masking.util.util import get_policy
    policy = get_policy(alg)

    path = f"{env_name}/{approach}/{transition}/{space}/{mode}/{alg}"

    if args["optimize_hyperparameters"]:
        benchmark.optimize_hyperparams(
            alg=alg,
            policy=policy,
            space=space,
            approach=approach,
            transition_tuple=transition,
            path=path,
            cont_action_masking_mode=mode,
        )
    else:
        benchmark.run_experiment(
            alg=alg,
            policy=policy,
            space=space,
            approach=approach,
            transition_tuple=transition,
            path=path,
            continuous_action_masking_mode=mode,
        )

    # Print in bold and reset the terminal attributes afterwards, so that the
    # shell does not stay in bold mode.
    print(
        "\033[1m"
        + f"Total elapsed time {(time.perf_counter() - start_time) / 60:.2f} min"
        + "\033[0m"
    )
