import os
from contextlib import contextmanager
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import ticker
from trajdata import MapAPI, UnifiedDataset, VectorMap

from trajdata_analysis.analysis.dataset_wise import (
    collision_rate_fix_size,
    collision_rate_var_size,
    off_road_rates,
    plot_position_heatmap,
    total_position_area_fix_size,
    total_position_area_var_size,
)
from trajdata_analysis.analysis.point_wise import (
    agent_ego_distance,
    pointwise_acceleration,
    pointwise_heading,
    pointwise_jerk,
    pointwise_speed,
)
from trajdata_analysis.analysis.scene_wise import max_sim_agents, sim_agents
from trajdata_analysis.analysis.trajectory_wise import (
    agent_type,
    max_heading_delta,
    rel_heading_delta,
    traj_length_time,
)
from trajdata_analysis.data_preprocessing.data_preprocessing import (
    parallel_extract_dataframes,
)


@contextmanager
def _pdf_figure(out_path, **subplots_kwargs):
    """Yield a fresh Axes, then save its figure to ``out_path`` and close it.

    Args:
        out_path: Destination file. matplotlib infers the format from the
            extension (every call in this script uses ``.pdf``).
        **subplots_kwargs: Forwarded unchanged to ``plt.subplots`` (e.g.
            ``subplot_kw=dict(projection="polar")``).

    The figure is always closed, even if the plotting code inside the
    ``with`` block raises. That way a failing plot cannot leak an open
    figure. If the body raises, nothing is written to ``out_path``.
    """
    fig, ax = plt.subplots(**subplots_kwargs)
    try:
        yield ax
        fig.savefig(out_path, bbox_inches="tight")
    finally:
        plt.close(fig)


def main():
    """Analyze each configured trajdata environment and save plots as PDFs.

    For every environment name in the loop below, this:

    - builds a ``UnifiedDataset`` (the data must already be cached; see the
      note on ``data_dirs``);
    - extracts per-sample dataframes in parallel;
    - writes point-wise, trajectory-wise and scene-wise plots into
      ``plots/<env_name>/``.
    """
    # Setting plot defaults.
    sns.set_theme(style="ticks")

    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to a single worker instead of passing None downstream.
    num_workers = os.cpu_count() or 1

    cache_path = Path("~/.unified_data_cache").expanduser()
    # Only used by the (currently disabled) map-based analyses further down.
    map_api = MapAPI(cache_path)

    for env_name in ["nusc_trainval"]:
        print()
        print("#" * 40)
        print(f"Analyzing {env_name}...")
        print("#" * 40)

        plot_dir = Path(f"plots/{env_name}")
        plot_dir.mkdir(parents=True, exist_ok=True)

        dataset = UnifiedDataset(
            desired_data=[env_name],
            num_workers=num_workers,
            verbose=True,
            data_dirs={
                # NOTE: This assumes the data is already cached!
                env_name: ""
            },
        )

        print(f"# Data Samples: {len(dataset):,}")

        data_df = parallel_extract_dataframes(dataset, num_workers)

        # Timestep (seconds) passed to the time-derivative / duration plots.
        # NOTE(review): 0.5 s matches nuScenes' 2 Hz annotation rate; confirm
        # (or derive from the dataset) before adding other environments.
        dt = 0.5

        ### Loading random scene and initializing VectorMap.
        # vec_map: VectorMap = map_api.get_map(
        #     f"{env_name}:singapore-onenorth",
        #     incl_road_lanes=True,
        #     incl_road_areas=True,
        #     incl_ped_crosswalk=True,
        #     incl_ped_walkways=True,
        # )

        # Point-wise analyses.
        with _pdf_figure(plot_dir / "pointwise_speed.pdf") as ax:
            pointwise_speed(data_df, ax)

        with _pdf_figure(plot_dir / "pointwise_acceleration.pdf") as ax:
            pointwise_acceleration(data_df, ax)

        with _pdf_figure(plot_dir / "pointwise_jerk.pdf") as ax:
            pointwise_jerk(data_df, DT=dt, ax=ax)

        with _pdf_figure(
            plot_dir / "pointwise_heading.pdf",
            subplot_kw=dict(projection="polar"),
        ) as ax:
            pointwise_heading(data_df, ax)

        with _pdf_figure(plot_dir / "agent_ego_distance.pdf") as ax:
            agent_ego_distance(data_df, ax)

        # Trajectory-wise analyses.
        with _pdf_figure(plot_dir / "traj_length.pdf") as ax:
            traj_length_time(data_df, DT=dt, ax=ax)

        with _pdf_figure(plot_dir / "max_heading_delta.pdf") as ax:
            max_heading_delta(data_df, ax=ax)

        with _pdf_figure(plot_dir / "rel_heading_delta.pdf") as ax:
            rel_heading_delta(data_df, ax=ax)

        # dataset.scenes() is called afresh for each consumer below rather
        # than reused, as in the original code.
        with _pdf_figure(plot_dir / "agent_type.pdf") as ax:
            agent_type(dataset.scenes(), ax=ax)

        # class_error(dataset.scenes())

        # Scene-wise analyses.
        with _pdf_figure(plot_dir / "max_sim_agents.pdf") as ax:
            max_sim_agents(dataset.scenes(), ax=ax)

        with _pdf_figure(plot_dir / "sim_agents.pdf") as ax:
            sim_agents(dataset.scenes(), ax=ax)

        # Dataset-wise analyses (currently disabled). The off-road analysis
        # additionally requires the `vec_map` block above to be re-enabled.
        # fig, ax = plt.subplots(figsize=(10, 10))
        # if "lyft" in env_name or "waymo" in env_name:
        #     covered_area = total_position_area_var_size(data_df, ax=ax)
        # else:
        #     covered_area = total_position_area_fix_size(
        #         data_df, dataset.scenes(), ax=ax
        #     )
        # print(f"{dataset.envs[0].name} Covered Area: {covered_area:.2f} m^2")
        # fig.savefig(plot_dir / "total_position_area_fix_size.pdf", bbox_inches="tight")
        # plt.close(fig)

        # if "lyft" in env_name or "waymo" in env_name:
        #     collision_rates, class_collision_rates_df = collision_rate_var_size(
        #         data_df, dataset.scenes()
        #     )
        # else:
        #     collision_rates, class_collision_rates_df = collision_rate_fix_size(
        #         data_df, dataset.scenes()
        #     )

        # fig, ax = plt.subplots()
        # sns.histplot(collision_rates, ax=ax)
        # ax.set(xlabel="Collision Rate", ylabel="# Scenes")
        # ax.xaxis.set_major_formatter(ticker.PercentFormatter(1.0, decimals=0))
        # fig.savefig(
        #     plot_dir / "collision_rate_fix_size_per_scene.pdf", bbox_inches="tight"
        # )
        # plt.close(fig)

        # fig, ax = plt.subplots()
        # sns.barplot(
        #     x="agent_type", y="collision_rate", data=class_collision_rates_df, ax=ax
        # )
        # ax.set(xlabel="Agent Type", ylabel="Collision Rate")
        # ax.yaxis.set_major_formatter(ticker.PercentFormatter(1.0, decimals=0))
        # fig.savefig(
        #     plot_dir / "collision_rate_fix_size_per_agent_class.pdf",
        #     bbox_inches="tight",
        # )
        # plt.close(fig)

        # # plot_position_heatmap(data_df, vec_map)
        # offroad_vehicle_rate, offroad_vehicle_position_rate = off_road_rates(data_df, dataset.scenes(), vec_map)
        # print("Proportions of vehicles that have been off-road:", offroad_vehicle_rate)
        # print("Proportion of vehicle position points that are off-road:", offroad_vehicle_position_rate)
        # # plt.show()


if __name__ == "__main__":
    main()
