from typing import Sequence, Union

import torch.nn as nn

from allenact.base_abstractions.preprocessor import Preprocessor
from allenact.utils.experiment_utils import Builder, TrainingPipeline
from allenact_plugins.clip_plugin.clip_preprocessors import (
    ClipViTPreprocessor,
    NaivePreprocessor
)
from projects.plugins.ithor_plugin.ithor_sensors import (
    GoalObjectTypeThorSensor,
    RGBSensorThor,
)
from projects.object_navigation.baseline_configs.navigation_base import (
    VAEPreprocessGRUActorCriticMixin
)
from projects.object_navigation.baseline_configs.ithor.objectnav_ithor_base import (
    ObjectNaviThorMultiMDPsBaseConfig
)
from projects.object_navigation.baseline_configs.navigation_base import ObjectNavPPOMixin


class ObjectNaviThorLUSRRGBPPOExperimentConfig(ObjectNaviThorMultiMDPsBaseConfig):
    """An iTHOR Object Navigation experiment configuration with RGB input.

    The agent is built by ``VAEPreprocessGRUActorCriticMixin`` (initialised
    from the checkpoint at ``SOURCE_MODEL``) and trained with the pipeline
    returned by ``ObjectNavPPOMixin.training_pipeline``. Several environment
    "domain factors" (step size, rotation, lighting, field of view, ...) are
    configured as parallel lists in ``__init__``.
    """

    # Forwarded to the preprocessing/model mixin as ``clip_model_type``.
    CLIP_MODEL_TYPE = "ViT-B/32"
    # NOTE(review): machine-specific absolute path to the pretrained
    # checkpoint; it must exist on the host that runs this experiment.
    SOURCE_MODEL = '/home/chois/MMRL/LUSR/checkpoints/domain_factor/model_200_ithor_cnn.pt'

    SENSORS = [
        # RGB frames of SCREEN_SIZE x SCREEN_SIZE, normalised with the CLIP
        # channel means/stds, published under the uuid "rgb_lowres".
        RGBSensorThor(
            height=ObjectNaviThorMultiMDPsBaseConfig.SCREEN_SIZE,
            width=ObjectNaviThorMultiMDPsBaseConfig.SCREEN_SIZE,
            use_resnet_normalization=True,
            mean=ClipViTPreprocessor.CLIP_RGB_MEANS,
            stdev=ClipViTPreprocessor.CLIP_RGB_STDS,
            uuid="rgb_lowres",
        ),
        # Goal sensor restricted to the target types of the base config.
        GoalObjectTypeThorSensor(object_types=ObjectNaviThorMultiMDPsBaseConfig.TARGET_TYPES,),
    ]

    # Not read anywhere in this class; presumably consumed by the base
    # config or the task sampler -- TODO confirm.
    PROMPT = False

    def __init__(self, **kwargs) -> None:
        """Set the domain factors and build the preprocessing/model helper.

        Args:
            **kwargs: Forwarded unchanged to the base configuration's
                ``__init__``.
        """
        super().__init__(**kwargs)

        #### Domain defined by domain factors ####
        # Every list below has four entries; presumably index ``i`` across
        # all lists describes one domain (MDP) -- confirm against the base
        # config, which is where these attributes are consumed.
        self.STEP_SIZE =             [0.1, 0.15, 0.25, 0.35]
        self.ROTATION_DEGREES =      [90.0, 60.0, 30.0, 10.0]
        self.VISIBILITY_DISTANCE =   [1.0 , 1.0, 1.0, 1.0]
        # NOTE(review): meaning of the 4-tuples (and of ``None``) is defined
        # by the consumer of LIGHTING_VALUE, not visible here.
        self.LIGHTING_VALUE =        [(0.6, 0.2, 1.5, -0.4), (1.1, 1.0, 0.5, -0.1), None, (2.0, 3.5, 2, 0.4)]
        self.HORIZONTAL_FIELD_OF_VIEW = [59, 69, 79, 99]
        self.LOOK_DEGREES = [40, 10, 30, 20]
        ##########################################
        self.DATA_GEN = False

        # Helper that supplies both the observation preprocessors and the
        # actor-critic model for this experiment.
        self.preprocessing_and_model = VAEPreprocessGRUActorCriticMixin(
            sensors=self.SENSORS,
            clip_model_type=self.CLIP_MODEL_TYPE,
            screen_size=self.SCREEN_SIZE,
            goal_sensor_type=GoalObjectTypeThorSensor,
            target_types=self.TARGET_TYPES,
            pool=False,
            pooling_type='',
            source_model=self.SOURCE_MODEL
        )

    def training_pipeline(self, **kwargs) -> TrainingPipeline:
        """Return the PPO training pipeline for this experiment.

        No auxiliary losses and a single belief are requested. ``**kwargs``
        is accepted for interface compatibility but is not forwarded.
        """
        return ObjectNavPPOMixin.training_pipeline(
            auxiliary_uuids=[],
            multiple_beliefs=False,
            advance_scene_rollout_period=self.ADVANCE_SCENE_ROLLOUT_PERIOD,
        )

    def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
        """Return the observation preprocessors provided by the model helper."""
        return self.preprocessing_and_model.preprocessors()

    def create_model(self, **kwargs) -> nn.Module:
        """Build the actor-critic model.

        Args:
            **kwargs: Forwarded to the helper's ``create_model`` together
                with ``num_actions`` taken from ``self.ACTION_SPACE.n``.
        """
        return self.preprocessing_and_model.create_model(
            num_actions=self.ACTION_SPACE.n, **kwargs
        )

    @classmethod
    def tag(cls) -> str:
        """Return the experiment tag.

        Declared as a classmethod so it can be called on the class itself
        as well as on an instance.
        """
        return "ObjectNav-iTHOR-RGB-LUSR"
# objectnav_ithor_rgb_lusrgru_ddppo