from allenact.utils.experiment_utils import TrainingPipeline
from projects.plugins.ithor_plugin.ithor_sensors import RGBSensorThor
from projects.plugins.robothor_plugin.robothor_sensors import GPSCompassSensorRoboThor
from projects.point_navigation.baseline_configs.mixins import (
    PointNavUnfrozenResNetWithGRUActorCriticMixin,
    PointNavPPOMixin
)
from projects.point_navigation.baseline_configs.ithor.pointnav_ithor_base import (
    PointNaviThorBaseConfig,
)


class PointNaviThorRGBPPOExperimentConfig(PointNaviThorBaseConfig):
    """A Point Navigation experiment configuration in iTHOR with RGB input.

    Observations are a square RGB frame (``SCREEN_SIZE`` x ``SCREEN_SIZE``,
    ResNet-normalized, uuid ``"rgb_lowres"``) plus a GPS/compass reading.
    The model is built by ``PointNavUnfrozenResNetWithGRUActorCriticMixin``
    with a ``"simple_cnn"`` backbone, and training uses the pipeline from
    ``PointNavPPOMixin``.
    """

    # Auxiliary-task uuids and belief mode. Both the model handler (built in
    # ``__init__``) and the training pipeline (``training_pipeline``) must
    # agree on these values, so they are defined once here rather than
    # repeated as literals at each call site.
    _AUXILIARY_UUIDS = ()
    _MULTIPLE_BELIEFS = False

    SENSORS = [
        RGBSensorThor(
            height=PointNaviThorBaseConfig.SCREEN_SIZE,
            width=PointNaviThorBaseConfig.SCREEN_SIZE,
            use_resnet_normalization=True,
            uuid="rgb_lowres",
        ),
        GPSCompassSensorRoboThor(),
    ]

    def __init__(self):
        super().__init__()

        #### Domain defined by domain factors ####
        # Each factor is a list of candidate values; presumably the base
        # config consumes these lists to build the environment domain
        # (not visible here -- TODO confirm).
        self.STEP_SIZE = [0.25]
        self.ROTATION_DEGREES = [30.0]
        self.VISIBILITY_DISTANCE = [1.5]
        self.LIGHTING_VALUE = [None]
        self.HORIZONTAL_FIELD_OF_VIEW = [79]
        self.LOOK_DEGREES = [30]

        # NOTE(review): flag presumably read by the base config; its exact
        # meaning is not visible here -- confirm.
        self.DATA_GEN = False

        self.model_creation_handler = PointNavUnfrozenResNetWithGRUActorCriticMixin(
            backbone="simple_cnn",
            sensors=self.SENSORS,
            # Fresh list per instance so the handler never aliases class state.
            auxiliary_uuids=list(self._AUXILIARY_UUIDS),
            add_prev_actions=True,
            multiple_beliefs=self._MULTIPLE_BELIEFS,
            belief_fusion=None,
        )

    def training_pipeline(self, **kwargs) -> TrainingPipeline:
        """Return the training pipeline built by ``PointNavPPOMixin``.

        Uses the same auxiliary uuids and belief mode as the model handler,
        enables advantage normalization, and forwards this config's
        ``ADVANCE_SCENE_ROLLOUT_PERIOD`` (presumably defined on the base
        config). Extra ``kwargs`` are accepted for interface compatibility
        and ignored.
        """
        return PointNavPPOMixin.training_pipeline(
            auxiliary_uuids=list(self._AUXILIARY_UUIDS),
            multiple_beliefs=self._MULTIPLE_BELIEFS,
            normalize_advantage=True,
            advance_scene_rollout_period=self.ADVANCE_SCENE_ROLLOUT_PERIOD,
        )

    def create_model(self, **kwargs):
        """Build the model by delegating to the handler created in ``__init__``.

        All ``kwargs`` are forwarded unchanged to
        ``self.model_creation_handler.create_model``.
        """
        return self.model_creation_handler.create_model(**kwargs)

    def tag(self):
        """Return the human-readable identifier for this experiment."""
        return "PointNav-iTHOR-RGB-SimpleConv-DDPPO"
