from typing import Sequence, Union, Type, List, Tuple

import attr
import gym
import numpy as np
import torch.nn as nn

from allenact.base_abstractions.preprocessor import Preprocessor
from allenact.base_abstractions.sensor import Sensor
from allenact.embodiedai.sensors.vision_sensors import RGBSensor, DepthSensor
from allenact.utils.experiment_utils import Builder
from allenact_plugins.clip_plugin.clip_preprocessors import (
    NaivePreprocessor,
    ClipResNetPreprocessor,
    ClipViTPreprocessor,
    ClipTextPreprocessor,
    PromptClipViTPreprocessor,
    SNPromptClipViTPreprocessor,
    PromptClipATTMViTPreprocessor,
    SNPromptClipATTMViTPreprocessor
)
from projects.plugins.ithor_plugin.ithor_sensors import GoalObjectTypeThorSensor
from allenact_plugins.clip_plugin.clip_zeroshot_objectnav_models import (
    ObjectNavActorCritic,
    COMObjectNavActorCritic,
    ENSObjectNavActorCritic,
    CLIPObjectNavActorCritic,
    CONPEObjectNavActorCritic
)


@attr.s(kw_only=True)
class CLIPViTGRUActorCriticMixin:
    """Preprocessor/model factory for the multi-prompt CLIP object-nav agent.

    ``preprocessors()`` builds a single ``NaivePreprocessor`` that writes to the
    ``"rgb_clip_vit"`` observation; ``create_model()`` instantiates the
    actor-critic variant selected by ``multi_p_mode[0]``.
    """

    sensors: Sequence[Sensor] = attr.ib()
    clip_model_type: str = attr.ib()
    screen_size: int = attr.ib()
    # Sensors of this type are searched for the goal uuid handed to the
    # preprocessor and to the actor-critic.
    goal_sensor_type: Type[Sensor] = attr.ib()
    pool: bool = attr.ib(default=False)
    pooling_type: str = attr.ib()
    target_types: List[str] = attr.ib()
    prompt: Tuple = attr.ib()
    # Indexed with [0] in create_model(): the first entry selects the model
    # ("COMPOSE"/"ATTEMPT", "ENSEMBLE"/"SESoM", or None for the plain model).
    multi_p_mode: Sequence = attr.ib()
    meta_mode: bool = attr.ib()
    noise_std: float = attr.ib(default=0.0)
    # May be None (the default).
    source_model: str = attr.ib(default=None)

    def _find_goal_sensor_uuid(self):
        """Return the uuid of the first sensor that is a ``goal_sensor_type``, else None."""
        return next(
            (s.uuid for s in self.sensors if isinstance(s, self.goal_sensor_type)),
            None,
        )

    def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
        """Build the list of preprocessors (a single ``NaivePreprocessor``).

        Side effects:
            Sets ``self.goal_sensor_uuid`` and ``self.preprocessor_output_shape``.

        Raises:
            AssertionError: if no ``RGBSensor`` or ``GoalObjectTypeThorSensor``
                is configured, or if the RGB sensor's normalisation statistics
                differ from the preprocessor's CLIP statistics.
        """
        rgb_sensor = next((s for s in self.sensors if isinstance(s, RGBSensor)), None)
        # Only used for the presence check below; the uuid actually forwarded
        # comes from ``goal_sensor_type``.
        goal_sensor = next(
            (s for s in self.sensors if isinstance(s, GoalObjectTypeThorSensor)), None
        )
        self.goal_sensor_uuid = self._find_goal_sensor_uuid()
        preprocessor_model = NaivePreprocessor

        assert rgb_sensor is not None, "An RGBSensor is required."
        assert goal_sensor is not None, "A GoalObjectTypeThorSensor is required."

        # The RGB sensor's mean/std must equal CLIP's (to 1e-5 in L2 norm).
        assert (
            np.linalg.norm(
                np.array(rgb_sensor._norm_means)
                - np.array(preprocessor_model.CLIP_RGB_MEANS)
            )
            < 1e-5
        ), "RGBSensor normalisation means must match CLIP_RGB_MEANS."
        assert (
            np.linalg.norm(
                np.array(rgb_sensor._norm_sds)
                - np.array(preprocessor_model.CLIP_RGB_STDS)
            )
            < 1e-5
        ), "RGBSensor normalisation stds must match CLIP_RGB_STDS."

        preprocessor = preprocessor_model(
            rgb_input_uuid=rgb_sensor.uuid,
            goal_sensor_uuid=self.goal_sensor_uuid,
            clip_model_type=self.clip_model_type,
            pool=self.pool,
            pooling_type=self.pooling_type,
            class_emb_only=False,
            output_uuid="rgb_clip_vit",
            task="imagenav",
        )

        self.preprocessor_output_shape = preprocessor.output_shape

        return [preprocessor]

    def create_model(self, num_actions: int, **kwargs) -> nn.Module:
        """Instantiate the actor-critic selected by ``multi_p_mode[0]``.

        Args:
            num_actions: size of the discrete action space.
            **kwargs: must contain ``"sensor_preprocessor_graph"``, whose
                ``observation_spaces`` are passed to the model.

        Raises:
            NotImplementedError: if ``multi_p_mode[0]`` is not a supported mode.
        """
        mode = self.multi_p_mode[0]
        if mode in ["COMPOSE", "ATTEMPT"]:
            actor_critic_cls = COMObjectNavActorCritic
        elif mode in ["ENSEMBLE", "SESoM"]:
            actor_critic_cls = ENSObjectNavActorCritic
        elif mode is None:
            actor_critic_cls = ObjectNavActorCritic
        else:
            raise NotImplementedError(
                f"Unsupported multi_p_mode[0]={mode!r}; expected one of "
                "'COMPOSE', 'ATTEMPT', 'ENSEMBLE', 'SESoM' or None."
            )

        # preprocessors() normally sets goal_sensor_uuid; fall back to looking
        # it up so this method does not depend on preprocessors() running first.
        goal_sensor_uuid = getattr(self, "goal_sensor_uuid", None)
        if goal_sensor_uuid is None:
            goal_sensor_uuid = self._find_goal_sensor_uuid()

        return actor_critic_cls(
            action_space=gym.spaces.Discrete(num_actions),
            observation_space=kwargs["sensor_preprocessor_graph"].observation_spaces,
            goal_sensor_uuid=goal_sensor_uuid,
            hidden_size=512 + 512,
            clip_rgb_preprocessor_uuid="rgb_clip_vit",
            clip_embedding_dim=512,
            # embedding params
            clip_model_type=self.clip_model_type,
            prompt=self.prompt,
            multi_p_mode=self.multi_p_mode,
            meta_mode=self.meta_mode,
            noise_std=self.noise_std,
            source_model=self.source_model,
        )


@attr.s(kw_only=True)
class CONPEGRUActorCriticMixin:
    """Preprocessor/model factory for the CONPE CLIP object-nav agent.

    ``preprocessors()`` builds a single ``NaivePreprocessor`` that writes to the
    ``"rgb_clip_vit"`` observation; ``create_model()`` always instantiates
    ``CONPEObjectNavActorCritic``.
    """

    sensors: Sequence[Sensor] = attr.ib()
    clip_model_type: str = attr.ib()
    screen_size: int = attr.ib()
    # Sensors of this type are searched for the goal uuid handed to the
    # preprocessor and to the actor-critic.
    goal_sensor_type: Type[Sensor] = attr.ib()
    pool: bool = attr.ib(default=False)
    pooling_type: str = attr.ib()
    target_types: List[str] = attr.ib()
    prompt: Tuple = attr.ib()
    # Forwarded to the model unchanged; presumably a sequence of modes like
    # CLIPViTGRUActorCriticMixin.multi_p_mode — TODO confirm.
    multi_p_mode: Sequence = attr.ib()
    meta_mode: bool = attr.ib()
    noise_std: float = attr.ib(default=0.0)
    # NOTE(review): annotated as tuple but defaults to the float 0.0 —
    # confirm which type CONPEObjectNavActorCritic expects.
    sm_noise: tuple = attr.ib(default=0.0)
    # May be None (the default).
    source_model: str = attr.ib(default=None)

    def _find_goal_sensor_uuid(self):
        """Return the uuid of the first sensor that is a ``goal_sensor_type``, else None."""
        return next(
            (s.uuid for s in self.sensors if isinstance(s, self.goal_sensor_type)),
            None,
        )

    def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
        """Build the list of preprocessors (a single ``NaivePreprocessor``).

        Side effects:
            Sets ``self.goal_sensor_uuid`` and ``self.preprocessor_output_shape``.

        Raises:
            AssertionError: if no ``RGBSensor`` or ``GoalObjectTypeThorSensor``
                is configured, or if the RGB sensor's normalisation statistics
                differ from the preprocessor's CLIP statistics.
        """
        rgb_sensor = next((s for s in self.sensors if isinstance(s, RGBSensor)), None)
        # Only used for the presence check below; the uuid actually forwarded
        # comes from ``goal_sensor_type``.
        goal_sensor = next(
            (s for s in self.sensors if isinstance(s, GoalObjectTypeThorSensor)), None
        )
        self.goal_sensor_uuid = self._find_goal_sensor_uuid()
        preprocessor_model = NaivePreprocessor

        assert rgb_sensor is not None, "An RGBSensor is required."
        assert goal_sensor is not None, "A GoalObjectTypeThorSensor is required."

        # The RGB sensor's mean/std must equal CLIP's (to 1e-5 in L2 norm).
        assert (
            np.linalg.norm(
                np.array(rgb_sensor._norm_means)
                - np.array(preprocessor_model.CLIP_RGB_MEANS)
            )
            < 1e-5
        ), "RGBSensor normalisation means must match CLIP_RGB_MEANS."
        assert (
            np.linalg.norm(
                np.array(rgb_sensor._norm_sds)
                - np.array(preprocessor_model.CLIP_RGB_STDS)
            )
            < 1e-5
        ), "RGBSensor normalisation stds must match CLIP_RGB_STDS."

        preprocessor = preprocessor_model(
            rgb_input_uuid=rgb_sensor.uuid,
            goal_sensor_uuid=self.goal_sensor_uuid,
            clip_model_type=self.clip_model_type,
            pool=self.pool,
            pooling_type=self.pooling_type,
            class_emb_only=False,
            output_uuid="rgb_clip_vit",
            task="imagenav",
        )

        self.preprocessor_output_shape = preprocessor.output_shape

        return [preprocessor]

    def create_model(self, num_actions: int, **kwargs) -> nn.Module:
        """Instantiate ``CONPEObjectNavActorCritic``.

        Args:
            num_actions: size of the discrete action space.
            **kwargs: must contain ``"sensor_preprocessor_graph"``, whose
                ``observation_spaces`` are passed to the model.
        """
        # preprocessors() normally sets goal_sensor_uuid; fall back to looking
        # it up so this method does not depend on preprocessors() running first.
        goal_sensor_uuid = getattr(self, "goal_sensor_uuid", None)
        if goal_sensor_uuid is None:
            goal_sensor_uuid = self._find_goal_sensor_uuid()

        return CONPEObjectNavActorCritic(
            action_space=gym.spaces.Discrete(num_actions),
            observation_space=kwargs["sensor_preprocessor_graph"].observation_spaces,
            goal_sensor_uuid=goal_sensor_uuid,
            hidden_size=512 + 512,
            clip_rgb_preprocessor_uuid="rgb_clip_vit",
            clip_embedding_dim=512,
            # embedding params
            clip_model_type=self.clip_model_type,
            prompt=self.prompt,
            multi_p_mode=self.multi_p_mode,
            meta_mode=self.meta_mode,
            noise_std=self.noise_std,
            sm_noise=self.sm_noise,
            source_model=self.source_model,
        )


@attr.s(kw_only=True)
class ClipViTPreprocessGRUActorCriticMixin:
    """Factory for an (optionally prompted) CLIP ViT preprocessor and its actor-critic.

    The preprocessor class depends on ``prompt``:

    * empty / falsy prompt -> ``ClipViTPreprocessor``
    * first element is a tuple -> ``SNPromptClipViTPreprocessor``
    * otherwise -> ``PromptClipViTPreprocessor``

    ``create_model`` sizes the actor-critic from ``preprocessor_output_shape``,
    which is recorded by ``preprocessors()``, so that method must run first.
    """

    sensors: Sequence[Sensor] = attr.ib()
    clip_model_type: str = attr.ib()
    screen_size: int = attr.ib()
    # Sensors of this type are searched for the goal uuid given to the model.
    goal_sensor_type: Type[Sensor] = attr.ib()
    pool: bool = attr.ib(default=False)
    pooling_type: str = attr.ib()
    target_types: List[str] = attr.ib()
    prompt: Tuple = attr.ib()
    noise_std: float = attr.ib(default=0.0)

    def _pick_preprocessor_cls(self):
        """Choose the CLIP ViT preprocessor class according to ``self.prompt``."""
        if not self.prompt:
            return ClipViTPreprocessor
        if isinstance(self.prompt[0], tuple):
            return SNPromptClipViTPreprocessor
        return PromptClipViTPreprocessor

    def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
        """Return a one-element list holding the CLIP ViT preprocessor.

        Also records the preprocessor's ``output_shape`` on
        ``self.preprocessor_output_shape`` for later use by ``create_model``.
        Asserts that an ``RGBSensor`` and a ``GoalObjectTypeThorSensor`` are
        present and that the RGB sensor's normalisation matches CLIP's.
        """
        rgb = next((sensor for sensor in self.sensors if isinstance(sensor, RGBSensor)), None)
        goal = next(
            (sensor for sensor in self.sensors if isinstance(sensor, GoalObjectTypeThorSensor)),
            None,
        )

        clip_cls = self._pick_preprocessor_cls()

        assert rgb is not None and goal is not None

        # Each sensor statistic must equal CLIP's within 1e-5 (L2 norm).
        for sensor_attr, clip_attr in (
            ("_norm_means", "CLIP_RGB_MEANS"),
            ("_norm_sds", "CLIP_RGB_STDS"),
        ):
            assert (
                np.linalg.norm(
                    np.array(getattr(rgb, sensor_attr))
                    - np.array(getattr(clip_cls, clip_attr))
                )
                < 1e-5
            )

        clip_preprocessor = clip_cls(
            rgb_input_uuid=rgb.uuid,
            clip_model_type=self.clip_model_type,
            pool=self.pool,
            pooling_type=self.pooling_type,
            class_emb_only=False,
            output_uuid="rgb_clip_vit",
            prompt=self.prompt,
            noise_std=self.noise_std,
        )
        self.preprocessor_output_shape = clip_preprocessor.output_shape
        return [clip_preprocessor]

    def create_model(self, num_actions: int, **kwargs) -> nn.Module:
        """Build a ``CLIPObjectNavActorCritic`` sized to the preprocessor output.

        ``kwargs["sensor_preprocessor_graph"]`` supplies the observation spaces;
        both the hidden size and the CLIP embedding size come from
        ``self.preprocessor_output_shape[0]``.
        """
        goal_uuid = next(
            (sensor.uuid for sensor in self.sensors if isinstance(sensor, self.goal_sensor_type)),
            None,
        )
        action_space = gym.spaces.Discrete(num_actions)
        obs_spaces = kwargs["sensor_preprocessor_graph"].observation_spaces
        clip_dim = self.preprocessor_output_shape[0]
        return CLIPObjectNavActorCritic(
            action_space=action_space,
            observation_space=obs_spaces,
            goal_sensor_uuid=goal_uuid,
            hidden_size=clip_dim,
            clip_rgb_preprocessor_uuid='rgb_clip_vit',
            clip_embedding_dim=clip_dim,
        )