"""Behavior Cloning Environment."""

from pathlib import Path
from typing import Any, Optional
from PIL import Image

import carla
import gym.spaces
import gym
import numpy as np
import cv2
import os
import time


from carla_env.utils.roaming_agent import RoamingAgent
from carla_env.base import BaseCarlaEnvironment
from carla_env.utils.config import ExperimentConfigs
from carla_env.utils.vector import to_array
from carla_env.dataset import Dataset, load_datasets
from carla_env.simulator.actor import Actor
from carla_env.simulator.simulator import Simulator
from carla_env.utils.carla_sync_mode import CarlaSyncMode
from carla_env.utils.config import ExperimentConfigs
from carla_env.utils.roaming_agent import RoamingAgent
from carla_env.agents import Agent, AgentState
from carla_env.agents.navigation.local_planner import LocalPlanner, RoadOption
from carla_env.agents.tools.misc import compute_magnitude_angle, is_within_distance_ahead


class HighLevelBehaviorCloningCarlaEnvironment(BaseCarlaEnvironment):
    """Behavior Cloning Environment driven by high-level (graph-edge) actions.

    One call to ``step`` executes a whole high-level action. The local planner
    selects a graph edge, either from the caller's action or, in expert mode,
    from the next node of an expert path computed in ``reset``. The ego
    vehicle is then driven along that edge with low-level
    throttle/steer/brake commands, ticking the synchronous simulator after
    each command.
    """

    # Ego-to-target distance below which the route counts as completed.
    GOAL_RADIUS = 5.0
    # Maximum number of ticks spent homing in on an edge's exit waypoint, so a
    # vehicle that never gets close enough cannot hang the episode.
    MAX_EXIT_STEPS = 1000

    def __init__(
        self,
        config: ExperimentConfigs,
        port: Optional[int] = None,
        sampling_resolution: float = 0.25,
        expert_mode: bool = False,
    ):
        """Create the environment.

        Args:
            config: Experiment configuration; ``random_route`` is read here and
                ``lights`` is read in ``reset``.
            port: Forwarded to ``BaseCarlaEnvironment``.
            sampling_resolution: Forwarded to ``BaseCarlaEnvironment``.
            expert_mode: When True, ``step`` ignores its ``action`` argument and
                follows the expert node path computed in ``reset``.
        """
        super().__init__(config, port, sampling_resolution)
        route_len = self.sim.route_manager.route_selector.route_len

        # NOTE(review): 7 presumably matches the number of RoadOption values - confirm.
        self.action_space = gym.spaces.Discrete(7)
        self.observation_space = gym.spaces.Dict(
            {
                "obs": gym.spaces.Box(shape=(24,), low=-1, high=1),
                "task": gym.spaces.Box(low=np.zeros(route_len), high=np.ones(route_len)),
                "image": gym.spaces.Box(
                    shape=(160, 160, 3), low=0, high=255, dtype=np.uint8
                ),
            }
        )
        self.num_task = route_len

        self.agent = None
        self.count = 0
        self.res = 0
        self.fixed = False
        self.velocity_reset_count = 0
        # Values of the road options currently executable; set by ``reset``.
        self.able_actions = None
        self.expert_mode = expert_mode
        self.is_random = config.random_route

        # Expert-path bookkeeping, (re)populated by ``reset``.
        self.expert_nodes = None
        self.last_expert_action = None
        self.node_list = []
        # Per-episode recordings: every applied [throttle, steer, brake] triple,
        # and one [num_controls, road_option_value] entry per executed edge.
        # Initialised here as well so an expert step before ``reset`` cannot
        # hit an AttributeError.
        self.action_traj = []
        self.action_length_traj = []

    def step(
        self,
        action=None,
        traffic_light_color: Optional[str] = "",
    ):
        """Advance the environment by one high-level action.

        Args:
            action: High-level action value. Ignored in expert mode, where the
                next node of the expert path is executed instead.
            traffic_light_color: Forwarded to ``_simulator_step`` (currently
                unused there).

        Returns:
            ``(next_obs, reward, done, info)``. Only ``done`` is populated;
            the other entries are ``None``.
        """
        if self.expert_mode:
            return self._simulator_step()
        return self._simulator_step(action, traffic_light_color)

    def expert_high_compute_action(self, node):
        """Ask the local planner for the edge leading to expert ``node``.

        Returns:
            ``(path_edge, start_node_idx)`` as produced by
            ``expert_high_level_run_step``; ``path_edge`` may be ``None``.
        """
        self.agent._state = AgentState.NAVIGATING
        path, start_node_idx = self.agent._local_planner.expert_high_level_run_step(node)
        return path, start_node_idx

    def high_compute_action(self, action):
        """Ask the local planner for the edge selected by high-level ``action``.

        Returns:
            ``(path_edge, start_node_idx)`` as produced by ``high_level_run_step``.
        """
        self.agent._state = AgentState.NAVIGATING
        path, start_node_idx = self.agent._local_planner.high_level_run_step(action)
        return path, start_node_idx

    def _apply_control(self, throttle, steer, brake) -> None:
        """Send one control command to the ego vehicle and tick the world once."""
        vehicle_control = carla.VehicleControl(
            throttle=throttle,  # [0,1]
            steer=steer,  # [-1,1]
            brake=brake,  # [0,1]
            hand_brake=False,
            reverse=False,
            manual_gear_shift=False,
        )
        self.sim.ego_vehicle.apply_control(vehicle_control)
        self.sync_mode.tick(timeout=10.0)

    def _drive_towards(self, waypoint) -> list:
        """Run one local-planner step toward ``waypoint`` and execute it.

        Draws a debug marker at the waypoint.

        Returns:
            The applied ``[throttle, steer, brake]`` triple.
        """
        control = self.agent._local_planner.one_run_step(waypoint)
        self.sim.world.carla.debug.draw_string(
            waypoint.transform.location,
            'o',
            draw_shadow=False,
            color=carla.Color(r=255, g=0, b=0),
        )
        applied = [control.throttle, control.steer, control.brake]
        self._apply_control(*applied)
        return applied

    def expert_low_compute_action(self, edge, start_node_idx):
        """Drive the ego vehicle along ``edge`` until its exit waypoint is reached.

        Args:
            edge: Graph edge attributes; reads ``'exit_waypoint'``, ``'path'``
                (a sequence of waypoints) and ``'type'`` (a RoadOption).
            start_node_idx: Accepted for interface compatibility; unused.

        Returns:
            The list of ``[throttle, steer, brake]`` triples applied, in order.
        """
        target = edge['exit_waypoint']
        route = edge['path']
        road_option = edge['type'].value
        is_lane_change = road_option in (
            RoadOption.CHANGELANELEFT.value,
            RoadOption.CHANGELANERIGHT.value,
        )
        # Lane-follow and lane-change edges must get within 2.0 of the exit
        # waypoint; every other road option uses a looser 5.0 radius.
        if is_lane_change or road_option == RoadOption.LANEFOLLOW.value:
            min_dist = 2.0
        else:
            min_dist = 5.0

        def distance_to_exit():
            return self.sim.ego_vehicle.location.distance(target.transform.location)

        controls = []
        if not is_lane_change:
            # Follow the edge's intermediate waypoints (skipping the first one)
            # and stop early once the exit waypoint is within reach.
            for waypoint in route[1:]:
                controls.append(self._drive_towards(waypoint))
                if distance_to_exit() < min_dist:
                    break

        # Home in on the exit waypoint itself (lane changes start here
        # directly), bounded so a stuck vehicle cannot loop forever.
        for _ in range(self.MAX_EXIT_STEPS):
            if distance_to_exit() <= min_dist:
                break
            controls.append(self._drive_towards(target))

        return controls

    def _reached_target(self) -> bool:
        """Return True when the ego vehicle is within ``GOAL_RADIUS`` of the route target."""
        return self.sim.ego_vehicle.location.distance(self.sim.target_location) < self.GOAL_RADIUS

    def _simulator_step(
        self,
        action: Optional[np.ndarray] = None,
        traffic_light_color: Optional[str] = None,
    ):
        """Execute one high-level action (or the next expert node).

        ``traffic_light_color`` is currently unused.

        Returns:
            ``(None, None, done, None)``. An action not in ``able_actions``
            (including any action before ``reset``) is reported and yields
            ``done=False``.
        """
        if self.expert_mode:
            return None, None, self._expert_step(), None

        # ``able_actions`` is None until ``reset`` runs; treat that as "nothing
        # is executable" instead of raising TypeError on ``in None``.
        if action in (self.able_actions or ()):
            path_edge, start_node_idx = self.high_compute_action(action)
            self.expert_low_compute_action(path_edge, start_node_idx)

            self.able_actions = [
                able_action.value
                for able_action in self.agent._local_planner.get_able_actions()
            ]
            print("Able actions", self.able_actions)
            return None, None, self._reached_target(), None

        print("Not valid action")
        print("Able actions", self.able_actions)
        return None, None, False, None

    def _expert_step(self) -> bool:
        """Execute the next node of the expert path and record its controls.

        Returns:
            ``done``. It is True when the path is exhausted or missing, when
            the planner finds no edge to the node, or when the target has
            been reached.
        """
        if not self.expert_nodes:
            return True

        node = self.expert_nodes.pop(0)
        path_edge, _ = self.expert_high_compute_action(node)
        if path_edge is None:
            return True

        # NOTE(review): passes ``node`` where the non-expert path passes
        # ``start_node_idx``; the parameter is unused by the callee.
        controls = self.expert_low_compute_action(path_edge, node)
        self.action_traj.extend(controls)
        self.action_length_traj.append([len(controls), path_edge['type'].value])
        self.count += 1
        return self._reached_target()

    def reset(self, get_info: bool = False):
        """Reset the simulator, rebuild the agent and prepare a new episode.

        In expert mode this also computes and trims the expert node path.
        ``get_info`` is currently ignored.

        Returns:
            ``None``.
        """
        self.reset_simulator()
        self.sim.route_manager.planner._reasign_type()
        self.action_traj = []
        self.action_length_traj = []

        self.agent = RoamingAgent(
            self.sim.ego_vehicle.carla,
            follow_traffic_lights=self.config.lights,
        )
        self.agent._local_planner.set_global_plan(self.sim.route_manager.waypoints)
        self.agent._local_planner.set_global_planner(
            self.sim.route_manager.planner,
            self.sim.route_manager.initial_transform,
            self.sim.route_manager.target_transform,
        )
        print("Cur location / Initial location", self.sim.ego_vehicle.location, self.sim.route_manager.initial_transform)

        # Start from a neutral control and advance the world by one tick.
        self._apply_control(0, 0, 0)

        self.count = 0
        self.node_list = []

        able_actions = self.agent._local_planner.get_able_actions()
        if self.expert_mode:
            self._init_expert_path()
        else:
            print("Able actions", able_actions)

        self.able_actions = [able_action.value for able_action in able_actions]
        return None

    def _init_expert_path(self) -> None:
        """Compute the expert node path and trim it so it starts at the ego vehicle.

        Sets ``expert_nodes``, ``end_node``, ``spawn_node`` and ``start_node``.
        Tells the local planner which node the vehicle is leaving from.
        """
        planner = self.agent._local_planner
        graph_edges = planner.global_planner._graph.edges

        self.expert_nodes = self.sim.route_manager.path_search(self.sim.ego_vehicle.location)
        print('Origin expert', self.expert_nodes)
        self.end_node = self.expert_nodes[-1]

        self.spawn_node = planner._find_closest_in_node(self.sim.ego_vehicle.transform)[0]
        if self.spawn_node in self.expert_nodes:
            # Drop every node up to and including the spawn node (in place).
            self.start_node = self.spawn_node
            del self.expert_nodes[: self.expert_nodes.index(self.spawn_node) + 1]
        elif (self.spawn_node, self.expert_nodes[0]) in graph_edges:
            # The spawn node connects directly to the path's first node.
            self.start_node = self.spawn_node
        else:
            self.start_node = self.expert_nodes.pop(0)
        last_node = self.start_node

        # If the first remaining hop is a lane change, consume its target node
        # here so it becomes the planner's previous node. The emptiness guard
        # covers spawning on the final node of the path, which previously
        # raised IndexError.
        if self.expert_nodes and (self.start_node, self.expert_nodes[0]) in graph_edges:
            option = graph_edges[(self.start_node, self.expert_nodes[0])]['type'].value
            if option in (RoadOption.CHANGELANELEFT.value, RoadOption.CHANGELANERIGHT.value):
                last_node = self.expert_nodes.pop(0)

        planner.set_pre_node(last_node)
        print('Node', self.spawn_node, self.start_node, self.end_node, self.expert_nodes)
        self.count += 1

    def select_route(self, idx: int = 0):
        """Select the route with index ``idx`` in the route manager."""
        self.sim.route_manager.select_route_by_idx(idx)

    def set_expert(self):
        """Switch to expert mode; takes effect from the next ``reset``."""
        self.expert_mode = True
