import datetime
import os
from typing import Dict, Optional, Tuple

import dask_geopandas as dgpd
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from geopandas import sjoin
from matplotlib.colors import ListedColormap
from shapely.geometry import Point, Polygon
from shapely.ops import unary_union
from tqdm import tqdm, trange
from trajdata.maps.vec_map_elements import MapElementType


def create_agent_polygon(x, y, length, width, heading):
    """Return an agent's oriented rectangular footprint as a shapely ``Polygon``.

    The rectangle is ``length`` long along the local x-axis and ``width`` wide along
    the local y-axis. It is centred on the origin, rotated by ``heading`` radians
    using the standard 2-D rotation matrix, and then translated to ``(x, y)``.

    Args:
        x: Centre x coordinate.
        y: Centre y coordinate.
        length: Extent along the heading direction.
        width: Extent perpendicular to the heading direction.
        heading: Rotation angle in radians.

    Returns:
        A 4-vertex ``Polygon``.
    """
    dx = length / 2
    dy = width / 2
    # Corners in the agent's local frame, listed in polygon order.
    local_corners = np.array([[-dx, -dy], [dx, -dy], [dx, dy], [-dx, dy]])
    cos_h = np.cos(heading)
    sin_h = np.sin(heading)
    rotation = np.array([[cos_h, -sin_h], [sin_h, cos_h]])
    world_corners = local_corners @ rotation.T + np.array([x, y])
    return Polygon(world_corners)


def total_position_area_fix_size(df: pd.DataFrame, scenes, ax: Optional[plt.Axes] = None) -> float:
    """Return the area covered by the union of all agent footprints, using fixed agent sizes.

    Each row of ``df`` becomes an oriented rectangle. The rectangle's length and width
    come from the ``extent`` of the agent with the same name in ``scenes``; if a name
    appears more than once, the last occurrence wins. The rectangles are unioned in
    chunks, and the merged shape is drawn on ``ax``.

    Args:
        df: Positions with columns ``x``, ``y``, ``heading``. Index level 1 is moved into
            a column that must be named ``agent_id``.
        scenes: Iterable of scenes whose ``agents`` expose ``name`` and ``extent``.
        ax: Matplotlib Axes to draw on, forwarded to ``GeoDataFrame.plot``.

    Returns:
        Area of the merged (non-overlapping) footprint geometry.

    Raises:
        KeyError: if a row's agent id is not the name of any agent in ``scenes``.
    """
    positions = df[["x", "y", "heading"]].reset_index(level=1).reset_index(drop=True)
    extents = {
        agent.name: {"length": agent.extent.length, "width": agent.extent.width}
        for scene in scenes
        for agent in scene.agents
    }

    def footprint(row):
        size = extents[row["agent_id"]]
        return create_agent_polygon(row["x"], row["y"], size["length"], size["width"], row["heading"])

    footprints = positions.apply(footprint, axis=1)

    # Union chunk-by-chunk first, then union the partial results.
    step = 1000
    partial_unions = [
        unary_union(footprints[start : start + step])
        for start in trange(0, len(footprints), step, desc="Merging Geom Chunks")
    ]
    covered = unary_union(partial_unions)

    gpd.GeoDataFrame(geometry=[covered]).plot(ax=ax, color="lightblue", edgecolor="black")
    return covered.area


def create_agent_polygon_from_row(row):
    """Build an agent footprint polygon from a row-like mapping.

    ``row`` must provide the keys ``x``, ``y``, ``heading``, ``length`` and ``width``;
    they are forwarded to ``create_agent_polygon``.
    """
    return create_agent_polygon(
        x=row["x"],
        y=row["y"],
        heading=row["heading"],
        length=row["length"],
        width=row["width"],
    )


def split_geoseries(geoseries, n_chunks):
    """Split a sliceable sequence into at most ``n_chunks`` contiguous, non-empty chunks.

    Works with anything that supports ``len()`` and integer slicing (lists, pandas or
    geopandas Series with a default index, ...). Every chunk except possibly the last
    has ``ceil(len(geoseries) / n_chunks)`` items.

    Args:
        geoseries: The sequence to split.
        n_chunks: Maximum number of chunks; must be at least 1.

    Returns:
        A list of slices of ``geoseries``. It is empty when ``geoseries`` is empty.

    Raises:
        ValueError: if ``n_chunks < 1``.
    """
    if n_chunks < 1:
        raise ValueError(f"n_chunks must be >= 1, got {n_chunks}")
    total = len(geoseries)
    # Ceiling division keeps the number of chunks <= n_chunks. Floor division would
    # create an extra chunk for the remainder. max(..., 1) avoids a zero range step
    # when there are fewer items than chunks.
    chunk_size = max(1, -(-total // n_chunks))
    return [geoseries[start : start + chunk_size] for start in range(0, total, chunk_size)]


def total_position_area_var_size(df: pd.DataFrame, ax: Optional[plt.Axes] = None) -> float:
    """Return the area covered by the union of all agent footprints, using per-row sizes.

    Each row of ``df`` becomes an oriented rectangle built from its own ``length`` and
    ``width``. The rectangles are unioned, and the merged shape is drawn and titled.

    Args:
        df: Rows with columns ``x``, ``y``, ``heading``, ``length``, ``width``.
        ax: Matplotlib Axes to draw on. When ``None`` (the default), ``GeoDataFrame.plot``
            creates a new Axes, and that Axes receives the title.

    Returns:
        Area of the merged (non-overlapping) footprint geometry.
    """
    df = df[["x", "y", "heading", "length", "width"]].reset_index(drop=True)

    st = datetime.datetime.now()
    # A list comprehension (instead of DataFrame.apply) also behaves for an empty
    # frame, where apply would return a DataFrame rather than a Series.
    polygons = [create_agent_polygon_from_row(row) for _, row in df.iterrows()]
    print("Polygon created, time used: ", datetime.datetime.now() - st)
    st = datetime.datetime.now()

    merged_geometry = unary_union(gpd.GeoSeries(polygons))
    print("Polygon merged, time used: ", datetime.datetime.now() - st)
    st = datetime.datetime.now()

    merged_gdf = gpd.GeoDataFrame(geometry=[merged_geometry])
    # GeoDataFrame.plot returns the Axes it drew on. Using that return value fixes
    # the AttributeError the old code raised on ax.set_title when ax was None.
    ax = merged_gdf.plot(ax=ax, color="lightblue", edgecolor="black")
    ax.set_title("Covered Area (No Overlap)")
    print("Plot graphed, time used: ", datetime.datetime.now() - st)

    return merged_geometry.area


def collision_rate_fix_size(df: pd.DataFrame, scenes) -> Tuple[Dict[str, float], pd.DataFrame]:
    """Compute per-scene and per-agent-class collision rates using fixed agent extents.

    An agent counts as collided in a scene if, at any timestep of that scene, its
    oriented footprint intersects the footprint of a different agent at the same
    timestep. Footprint sizes come from ``agent.extent`` of the agents in ``scenes``.

    Args:
        df: Positions with columns ``x``, ``y``, ``heading`` and a 3-level index
            ``(scene_id, agent_id, scene_ts)``. Levels ``scene_id`` and ``scene_ts``
            are looked up by name; the agent id is index level 1.
        scenes: Iterable of scenes whose ``agents`` expose ``name``, ``extent`` and
            ``type``. The class name is the part after the first ``.`` of
            ``str(agent.type)``.

    Returns:
        ``(collision_rates, class_collision_rates_df)``:

        - ``collision_rates`` maps each scene id to the fraction of that scene's
          agents that collided.
        - ``class_collision_rates_df`` has columns ``agent_type`` and
          ``collision_rate``. Class rates are taken over every agent listed in
          ``scenes``. Agents are keyed by name only, so equal names in different
          scenes are treated as one agent.
    """
    df = df[["x", "y", "heading"]]
    agent_meta_dict = {}
    for scene in scenes:
        for agent in scene.agents:
            agent_meta_dict[agent.name] = {
                "length": agent.extent.length,
                "width": agent.extent.width,
                "type": str(agent.type).split(".")[1],
            }

    collision_rates = {}
    all_collided_agent_ids = set()

    for scene_id in tqdm(df.index.get_level_values("scene_id").unique()):
        scene_df = df.loc[pd.IndexSlice[scene_id, :, :], :]

        scene_collided_agent_ids = set()
        for scene_ts in scene_df.index.get_level_values("scene_ts").unique():
            timestep_df = scene_df.loc[pd.IndexSlice[:, :, scene_ts], :]
            agent_ids = timestep_df.index.get_level_values(1)
            footprints = [
                create_agent_polygon(
                    x,
                    y,
                    agent_meta_dict[agent_id]["length"],
                    agent_meta_dict[agent_id]["width"],
                    heading,
                )
                for agent_id, x, y, heading in zip(
                    agent_ids, timestep_df["x"], timestep_df["y"], timestep_df["heading"]
                )
            ]
            boxes = gpd.GeoDataFrame({"agent_id": agent_ids}, geometry=footprints)

            # Join on a plain "agent_id" column over a default RangeIndex. Overlapping
            # column names are suffixed with lsuffix/rsuffix in every geopandas version.
            # In contrast, the name of the column holding the right frame's *index*
            # changed in geopandas 1.0: it is no longer always "index_right" when the
            # index is named.
            pairs = sjoin(boxes, boxes, predicate="intersects", how="inner", lsuffix="left", rsuffix="right")
            # Every footprint intersects itself; keep only pairs of distinct agents.
            pairs = pairs[pairs["agent_id_left"] != pairs["agent_id_right"]]
            scene_collided_agent_ids.update(pairs["agent_id_left"].unique())

        collision_rates[scene_id] = len(scene_collided_agent_ids) / len(scene_df.index.get_level_values(1).unique())
        all_collided_agent_ids |= scene_collided_agent_ids

    agent_classes = {agent_meta["type"]: 0 for agent_meta in agent_meta_dict.values()}
    collided_agent_classes = dict(agent_classes)

    for agent_id, agent_meta in agent_meta_dict.items():
        agent_type = agent_meta["type"]
        agent_classes[agent_type] += 1
        if agent_id in all_collided_agent_ids:
            collided_agent_classes[agent_type] += 1

    # Every class key comes from at least one agent, so the denominators are >= 1.
    class_collision_rates = {
        agent_type: collided_agent_classes[agent_type] / agent_classes[agent_type] for agent_type in agent_classes
    }

    class_collision_rates_df = pd.DataFrame(
        list(class_collision_rates.items()), columns=["agent_type", "collision_rate"]
    )

    return collision_rates, class_collision_rates_df


def collision_rate_var_size(df: pd.DataFrame, scenes) -> Tuple[Dict[str, float], pd.DataFrame]:
    """Compute, print and plot collision rates using per-row agent sizes.

    An agent counts as collided in a scene if, at any timestep of that scene, its
    oriented footprint (built from the row's own ``length``/``width``) intersects the
    footprint of a different agent at the same timestep.

    Two new figures are created: a histogram of per-scene rates and a bar plot of
    per-class rates. Previously both were drawn onto the same current Axes.

    Args:
        df: Rows with columns ``x``, ``y``, ``heading``, ``length``, ``width`` and a
            3-level index ``(scene_id, agent_id, scene_ts)``. Levels ``scene_id`` and
            ``scene_ts`` are looked up by name; the agent id is index level 1.
        scenes: Iterable of scenes whose ``agents`` expose ``name`` and ``type``. The
            class name is the part after the first ``.`` of ``str(agent.type)``.

    Returns:
        ``(collision_rates, class_collision_rates_df)``, matching
        ``collision_rate_fix_size``:

        - ``collision_rates`` maps each scene id to the fraction of that scene's
          agents that collided.
        - ``class_collision_rates_df`` has columns ``agent_type`` and
          ``collision_rate``.

        Callers that ignored the previous ``None`` return value are unaffected.
    """
    # Work on an explicit copy so adding the geometry column never writes through a view.
    df = df[["x", "y", "heading", "length", "width"]].copy()
    df["geometry"] = [create_agent_polygon_from_row(row) for _, row in df.iterrows()]

    agent_types = {}
    for scene in scenes:
        for agent in scene.agents:
            agent_types[agent.name] = str(agent.type).split(".")[1]

    collision_rates = {}
    all_collided_agent_ids = set()
    for scene_id in tqdm(df.index.get_level_values("scene_id").unique()):
        scene_df = df.loc[pd.IndexSlice[scene_id, :, :], :]

        scene_collided_agent_ids = set()
        for scene_ts in scene_df.index.get_level_values("scene_ts").unique():
            timestep_df = scene_df.loc[pd.IndexSlice[:, :, scene_ts], :]
            boxes = gpd.GeoDataFrame(
                {"agent_id": timestep_df.index.get_level_values(1)},
                geometry=timestep_df["geometry"].to_list(),
            )
            # Join on a plain "agent_id" column so the result columns are named by the
            # documented lsuffix/rsuffix rules. The right-index column name
            # ("index_right") is geopandas-version dependent for named indexes.
            pairs = sjoin(boxes, boxes, predicate="intersects", how="inner", lsuffix="left", rsuffix="right")
            # Every footprint intersects itself; keep only pairs of distinct agents.
            pairs = pairs[pairs["agent_id_left"] != pairs["agent_id_right"]]
            scene_collided_agent_ids.update(pairs["agent_id_left"].unique())

        collision_rates[scene_id] = len(scene_collided_agent_ids) / len(scene_df.index.get_level_values(1).unique())
        all_collided_agent_ids |= scene_collided_agent_ids

    print("Collision rates per scene: ", collision_rates)
    _, hist_ax = plt.subplots()
    sns.histplot(collision_rates, ax=hist_ax).set(xlabel="Collision Rate Per Scene", ylabel="Frequency")

    agent_classes = {agent_type: 0 for agent_type in agent_types.values()}
    collided_agent_classes = dict(agent_classes)

    for agent_id, agent_type in agent_types.items():
        agent_classes[agent_type] += 1
        if agent_id in all_collided_agent_ids:
            collided_agent_classes[agent_type] += 1

    # Every class key comes from at least one agent, so the denominators are >= 1.
    class_collision_rates = {
        agent_type: collided_agent_classes[agent_type] / agent_classes[agent_type] for agent_type in agent_classes
    }
    print("Collision rates per agent class: ", class_collision_rates)

    class_collision_rates_df = pd.DataFrame(
        list(class_collision_rates.items()), columns=["agent_type", "collision_rate"]
    )
    _, bar_ax = plt.subplots()
    sns.barplot(x="agent_type", y="collision_rate", data=class_collision_rates_df, ax=bar_ax).set(
        xlabel="Agent Type", ylabel="Collision Rate"
    )

    return collision_rates, class_collision_rates_df


def plot_position_heatmap(df, vec_map):
    """Plot agent positions over the map in two figures and show them.

    Figure 1 shows the rasterised map, the road-area / walkway / crosswalk polygons
    and a scatter of every ``(x, y)`` position. Figure 2 shows a 500x500-bin 2-D
    position histogram (colour scale capped at 10) with the polygons and the raster
    map overlaid.

    Both figures place the raster at ``vec_map.extent`` so that it lines up with the
    world-coordinate polygons and points. Previously, figure 1 drew the raster in
    pixel coordinates.

    Args:
        df: Frame with ``x`` and ``y`` columns in map (world) coordinates.
        vec_map: trajdata vector map providing ``rasterize``, ``elements`` and a
            6-value ``extent`` ``(min_x, min_y, min_z, max_x, max_y, max_z)``.
    """
    df = df[["x", "y"]]
    map_img, _raster_from_world = vec_map.rasterize(
        resolution=1,
        return_tf_mat=True,
        incl_centerlines=False,
        area_color=(255, 255, 255),
        edge_color=(0, 0, 0),
        scene_ts=100,
    )

    polygons = [
        Polygon(
            road_area.exterior_polygon.points,
            holes=[hole.points for hole in road_area.interior_holes],
        )
        for road_area in vec_map.elements[MapElementType.ROAD_AREA].values()
    ]
    polygons += [walkway.polygon.points for walkway in vec_map.elements[MapElementType.PED_WALKWAY].values()]
    polygons += [crosswalk.polygon.points for crosswalk in vec_map.elements[MapElementType.PED_CROSSWALK].values()]
    # Walkway/crosswalk entries above are raw point arrays; wrap them as Polygons here.
    polygons = [poly if isinstance(poly, Polygon) else Polygon(poly) for poly in polygons]
    gdf = gpd.GeoDataFrame(geometry=gpd.GeoSeries(polygons))

    min_x, min_y, _min_z, max_x, max_y, _max_z = vec_map.extent
    map_extent = [min_x, max_x, min_y, max_y]

    # Figure 1: map + polygons + raw position scatter.
    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(map_img, alpha=0.5, aspect="auto", origin="lower", extent=map_extent)
    gdf.plot(ax=ax, color="lightblue", edgecolor="black", alpha=0.5)
    sns.scatterplot(x="x", y="y", data=df, s=5, ax=ax)

    # Figure 2: 2-D occupancy histogram with polygons and map overlaid.
    _, ax = plt.subplots()
    num_bins = 500
    hist, xedges, yedges = np.histogram2d(df["x"], df["y"], bins=num_bins)
    cmap = ListedColormap(sns.color_palette(n_colors=10))
    gdf.plot(ax=ax, color="lightblue", edgecolor="black", alpha=0.5)
    ax.imshow(
        hist.T,
        cmap=cmap,
        alpha=1,
        origin="lower",
        aspect="auto",
        vmax=10,
        extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
    )
    ax.imshow(
        map_img,
        alpha=0.5,
        aspect="auto",
        origin="lower",
        extent=map_extent,
    )
    plt.show()


def off_road_rates(df: pd.DataFrame, scenes, vec_map) -> Tuple[float, float]:
    """Compute off-road rates for VEHICLE agents against the map's road areas.

    A position is off-road when its point is not ``within`` the union of all
    ``ROAD_AREA`` polygons (holes included). Only agents whose type is ``VEHICLE``
    are evaluated. Pedestrian/walkway checks are not implemented.

    Args:
        df: Positions with ``x``/``y`` columns and an ``agent_id`` index level.
        scenes: Iterable of scenes whose ``agents`` expose ``name`` and ``type``. The
            class name is the part after the first ``.`` of ``str(agent.type)``.
        vec_map: trajdata vector map providing ``elements[MapElementType.ROAD_AREA]``.

    Returns:
        ``(offroad_vehicles, offroad_vehicle_positions)``:

        - ``offroad_vehicles``: fraction of vehicle agents with at least one
          off-road position.
        - ``offroad_vehicle_positions``: fraction of all vehicle positions that
          are off-road.

        Both are NaN when ``df`` contains no vehicle positions (previously a
        ZeroDivisionError).
    """
    agent_types = {agent.name: str(agent.type).split(".")[1] for scene in scenes for agent in scene.agents}

    road_series = [
        Polygon(
            road_area.exterior_polygon.points,
            holes=[hole.points for hole in road_area.interior_holes],
        )
        for road_area in vec_map.elements[MapElementType.ROAD_AREA].values()
    ]

    # Union in chunks first, then union the partial results.
    chunk_size = 100
    merged_geoms = [
        unary_union(road_series[chunk_idx : chunk_idx + chunk_size])
        for chunk_idx in trange(0, len(road_series), chunk_size, desc="Merging Roads")
    ]
    road_poly = unary_union(merged_geoms)

    # Filter to vehicles in pandas before handing work to dask, so only the
    # relevant points are tested against the road geometry.
    positions = df[["x", "y"]].reset_index()
    positions["agent_type"] = positions["agent_id"].map(agent_types)
    vehicles = positions[positions["agent_type"] == "VEHICLE"]
    if vehicles.empty:
        return float("nan"), float("nan")

    vehicle_gdf = gpd.GeoDataFrame(
        vehicles[["agent_id"]],
        geometry=gpd.points_from_xy(vehicles["x"], vehicles["y"]),
    )

    st = datetime.datetime.now()
    # os.cpu_count() may return None. Also avoid more partitions than rows.
    n_partitions = max(1, min(os.cpu_count() or 1, len(vehicle_gdf)))
    # Compute the expensive point-in-polygon test exactly once. Previously every
    # len() on the lazy dask objects re-triggered the whole computation.
    onroad = dgpd.from_geopandas(vehicle_gdf, npartitions=n_partitions).within(road_poly).compute()
    vehicle_gdf["onroad"] = onroad  # aligned on the (unique) index

    always_onroad = vehicle_gdf.groupby("agent_id")["onroad"].all()
    offroad_vehicles = float((~always_onroad).mean())
    offroad_vehicle_positions = float((~vehicle_gdf["onroad"]).mean())
    print("time to compute offroad info: ", datetime.datetime.now() - st)

    return offroad_vehicles, offroad_vehicle_positions
