from typing import Sequence, Union, Optional, Dict, Tuple, Type, List

import attr
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torchvision import models
from gym.spaces.dict import Dict as SpaceDict
from allenact.algorithms.onpolicy_sync.losses import PPO
from allenact.algorithms.onpolicy_sync.losses.abstract_loss import (
    AbstractActorCriticLoss,
)
from allenact.algorithms.onpolicy_sync.losses.imitation import Imitation
from allenact.algorithms.onpolicy_sync.losses.ppo import PPOConfig
from allenact.base_abstractions.preprocessor import Preprocessor
from allenact.base_abstractions.sensor import Sensor
from allenact.embodiedai.aux_losses.losses import (
    InverseDynamicsLoss,
    TemporalDistanceLoss,
    CPCA1Loss,
    CPCA2Loss,
    CPCA4Loss,
    CPCA8Loss,
    CPCA16Loss,
    MultiAuxTaskNegEntropyLoss,
    CPCA1SoftMaxLoss,
    CPCA2SoftMaxLoss,
    CPCA4SoftMaxLoss,
    CPCA8SoftMaxLoss,
    CPCA16SoftMaxLoss,
)
from allenact.embodiedai.preprocessors.resnet import ResNetPreprocessor
from allenact.embodiedai.sensors.vision_sensors import RGBSensor, DepthSensor
from allenact.embodiedai.models.visual_nav_models import VisualNavActorCritic
from allenact.utils.experiment_utils import (
    Builder,
    TrainingPipeline,
    PipelineStage,
    LinearDecay,
)
from allenact.algorithms.onpolicy_sync.policy import (
    ObservationType,
    DistributionType
)
from projects.plugins.ithor_plugin.ithor_sensors import GoalObjectTypeThorSensor
from projects.object_navigation.navigation.baseline_models import (
    ResnetTensorNavActorCritic,
    ObjectNavActorCritic,
)
from projects.object_navigation.navigation.model import IthorDisentangledVAE
from projects.plugins.robothor_plugin.robothor_tasks import ObjectNavTask
from allenact_plugins.clip_plugin.clip_preprocessors import (
    NaivePreprocessor,
)

@attr.s(kw_only=True)
class ResNetPreprocessGRUActorCriticMixin:
    """Mixin building torchvision-ResNet preprocessors and a matching actor-critic.

    Every RGB and/or depth sensor in ``sensors`` (first match of each kind) gets
    an un-pooled ResNet preprocessor; ``create_model`` wires the resulting
    outputs into a ``ResnetTensorNavActorCritic``.
    """

    sensors: Sequence[Sensor] = attr.ib()
    resnet_type: str = attr.ib()
    screen_size: int = attr.ib()
    goal_sensor_type: Type[Sensor] = attr.ib()

    def _resnet_preprocessor(
        self, input_uuid: str, output_uuid: str, output_dims: int
    ) -> ResNetPreprocessor:
        """Build one un-pooled ResNet preprocessor with (output_dims, 7, 7) output."""
        return ResNetPreprocessor(
            input_height=self.screen_size,
            input_width=self.screen_size,
            output_width=7,
            output_height=7,
            output_dims=output_dims,
            pool=False,
            torchvision_resnet_model=getattr(
                models, f"resnet{self.resnet_type.replace('RN', '')}"
            ),
            input_uuids=[input_uuid],
            output_uuid=output_uuid,
        )

    def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
        """Return ResNet preprocessors for the RGB and depth sensors that are present.

        :raises NotImplementedError: If ``resnet_type`` is not one of
            RN18, RN34, RN50, RN101 or RN152.
        """
        # Channel count of the final ResNet feature map depends on the depth
        # of the network (basic blocks vs. bottleneck blocks).
        if self.resnet_type in ["RN18", "RN34"]:
            channels = 512
        elif self.resnet_type in ["RN50", "RN101", "RN152"]:
            channels = 2048
        else:
            raise NotImplementedError(
                "`RESNET_TYPE` must be one 'RNx' with x equaling one of"
                " 18, 34, 50, 101, or 152."
            )

        sensor_targets = (
            (RGBSensor, "rgb_resnet_imagenet"),
            (DepthSensor, "depth_resnet_imagenet"),
        )
        built = []
        for sensor_cls, target_uuid in sensor_targets:
            found = next((s for s in self.sensors if isinstance(s, sensor_cls)), None)
            if found is None:
                continue
            built.append(self._resnet_preprocessor(found.uuid, target_uuid, channels))
        return built

    def create_model(self, **kwargs) -> nn.Module:
        """Build a ``ResnetTensorNavActorCritic`` reading this mixin's ResNet outputs.

        ``kwargs`` must contain ``"sensor_preprocessor_graph"``.
        """
        sensors = self.sensors
        goal_uuid = next(
            (sensor.uuid for sensor in sensors if isinstance(sensor, self.goal_sensor_type)),
            None,
        )
        rgb_uuid = (
            "rgb_resnet_imagenet"
            if any(isinstance(sensor, RGBSensor) for sensor in sensors)
            else None
        )
        depth_uuid = (
            "depth_resnet_imagenet"
            if any(isinstance(sensor, DepthSensor) for sensor in sensors)
            else None
        )
        num_actions = len(ObjectNavTask.class_action_names())

        return ResnetTensorNavActorCritic(
            action_space=gym.spaces.Discrete(num_actions),
            observation_space=kwargs["sensor_preprocessor_graph"].observation_spaces,
            goal_sensor_uuid=goal_uuid,
            rgb_resnet_preprocessor_uuid=rgb_uuid,
            depth_resnet_preprocessor_uuid=depth_uuid,
            hidden_size=512,
            goal_dims=32,
        )


@attr.s(kw_only=True)
class ObjectNavUnfrozenResNetWithGRUActorCriticMixin:
    """Mixin building an ``ObjectNavActorCritic`` with a trainable visual backbone."""

    backbone: str = attr.ib()
    sensors: Sequence[Sensor] = attr.ib()
    auxiliary_uuids: Sequence[str] = attr.ib()
    add_prev_actions: bool = attr.ib()
    multiple_beliefs: bool = attr.ib()
    belief_fusion: Optional[str] = attr.ib()

    def create_model(self, **kwargs) -> nn.Module:
        """Build the actor-critic from the configured sensors.

        ``kwargs`` must contain ``"sensor_preprocessor_graph"``.

        :raises ValueError: If ``sensors`` contains no ``GoalObjectTypeThorSensor``.
        """
        rgb_uuid = next(
            (s.uuid for s in self.sensors if isinstance(s, RGBSensor)), None
        )
        depth_uuid = next(
            (s.uuid for s in self.sensors if isinstance(s, DepthSensor)), None
        )
        goal_sensor_uuid = next(
            (s.uuid for s in self.sensors if isinstance(s, GoalObjectTypeThorSensor)),
            None,
        )
        # The goal sensor is mandatory for this model; fail with a clear message
        # instead of letting a bare StopIteration escape from `next`.
        if goal_sensor_uuid is None:
            raise ValueError(
                "ObjectNavUnfrozenResNetWithGRUActorCriticMixin requires a"
                " GoalObjectTypeThorSensor in `sensors`."
            )

        # With several per-task beliefs the hidden size is reduced to 192.
        uses_multiple_beliefs = self.multiple_beliefs and len(self.auxiliary_uuids) > 1

        return ObjectNavActorCritic(
            action_space=gym.spaces.Discrete(len(ObjectNavTask.class_action_names())),
            observation_space=kwargs["sensor_preprocessor_graph"].observation_spaces,
            rgb_uuid=rgb_uuid,
            depth_uuid=depth_uuid,
            goal_sensor_uuid=goal_sensor_uuid,
            hidden_size=192 if uses_multiple_beliefs else 512,
            backbone=self.backbone,
            resnet_baseplanes=32,
            object_type_embedding_dim=32,
            num_rnn_layers=1,
            rnn_type="GRU",
            add_prev_actions=self.add_prev_actions,
            action_embed_size=6,
            auxiliary_uuids=self.auxiliary_uuids,
            multiple_beliefs=self.multiple_beliefs,
            beliefs_fusion=self.belief_fusion,
        )


class ObjectNavDAggerMixin:
    """Mixin supplying a two-stage imitation-learning training pipeline.

    Stage 1 trains with teacher forcing held at 1.0. Stage 2 linearly anneals
    teacher forcing from 1.0 towards 0.0 over ``anneal_steps`` and runs for the
    rest of the 300M-step budget.
    """

    @staticmethod
    def training_pipeline(
        advance_scene_rollout_period: Optional[int] = None,
    ) -> TrainingPipeline:
        """Return the imitation-only ``TrainingPipeline``."""
        total_steps = 300_000_000
        forced_steps = 5_000_000
        decay_steps = 5_000_000
        free_steps = total_steps - forced_steps - decay_steps
        assert free_steps > 0

        learning_rate = 3e-4
        on_gpu = torch.cuda.is_available()

        stages = [
            PipelineStage(
                loss_names=["imitation_loss"],
                max_stage_steps=forced_steps,
                teacher_forcing=LinearDecay(startp=1.0, endp=1.0, steps=forced_steps),
            ),
            PipelineStage(
                loss_names=["imitation_loss"],
                max_stage_steps=decay_steps + free_steps,
                teacher_forcing=LinearDecay(startp=1.0, endp=0.0, steps=decay_steps),
            ),
        ]

        return TrainingPipeline(
            save_interval=10000,
            # Log far less frequently when training on GPU.
            metric_accumulate_interval=10000 if on_gpu else 1,
            optimizer_builder=Builder(optim.Adam, {"lr": learning_rate}),
            num_mini_batch=1,
            update_repeats=4,
            max_grad_norm=0.5,
            num_steps=128,
            named_losses={"imitation_loss": Imitation()},
            gamma=0.99,
            use_gae=True,
            gae_lambda=0.95,
            advance_scene_rollout_period=advance_scene_rollout_period,
            pipeline_stages=stages,
            lr_scheduler_builder=Builder(
                LambdaLR, {"lr_lambda": LinearDecay(steps=total_steps)}
            ),
        )


def update_with_auxiliary_losses(
    named_losses: Dict[str, Tuple[AbstractActorCriticLoss, float]],
    auxiliary_uuids: Sequence[str],
    multiple_beliefs: bool,
) -> Dict[str, Tuple[AbstractActorCriticLoss, float]]:
    """Add the requested auxiliary losses (with their weights) to ``named_losses``.

    ``named_losses`` is updated in place and also returned.

    :param named_losses: Mapping from loss name to ``(loss, weight)``.
    :param auxiliary_uuids: UUIDs of the auxiliary losses to add, in order.
    :param multiple_beliefs: When True, also add a ``MultiAuxTaskNegEntropyLoss``
        over ``auxiliary_uuids`` with weight 0.01.
    :raises KeyError: If an entry of ``auxiliary_uuids`` is not a known
        auxiliary-loss UUID. In that case ``named_losses`` is left unmodified.
    """
    # Multiplier applied to every relative weight in the table below.
    aux_loss_total_weight = 2.0

    # uuid -> (loss class, constructor kwargs, relative weight).
    # Losses are only instantiated when requested, rather than building all of
    # them up front and discarding most.
    aux_loss_specs: Dict[
        str, Tuple[Type[AbstractActorCriticLoss], Dict[str, float], float]
    ] = {
        # TODO: test the effect of the subsampling / pair-count settings.
        InverseDynamicsLoss.UUID: (
            InverseDynamicsLoss,
            dict(subsample_rate=0.2, subsample_min_num=10),
            0.05,
        ),
        TemporalDistanceLoss.UUID: (
            TemporalDistanceLoss,
            dict(num_pairs=8, epsiode_len_min=5),
            0.2,
        ),
        CPCA1Loss.UUID: (CPCA1Loss, dict(subsample_rate=0.2), 0.05),
        CPCA2Loss.UUID: (CPCA2Loss, dict(subsample_rate=0.2), 0.05),
        CPCA4Loss.UUID: (CPCA4Loss, dict(subsample_rate=0.2), 0.05),
        CPCA8Loss.UUID: (CPCA8Loss, dict(subsample_rate=0.2), 0.05),
        CPCA16Loss.UUID: (CPCA16Loss, dict(subsample_rate=0.2), 0.05),
        CPCA1SoftMaxLoss.UUID: (CPCA1SoftMaxLoss, dict(subsample_rate=1.0), 0.05),
        CPCA2SoftMaxLoss.UUID: (CPCA2SoftMaxLoss, dict(subsample_rate=1.0), 0.05),
        CPCA4SoftMaxLoss.UUID: (CPCA4SoftMaxLoss, dict(subsample_rate=1.0), 0.05),
        CPCA8SoftMaxLoss.UUID: (CPCA8SoftMaxLoss, dict(subsample_rate=1.0), 0.05),
        CPCA16SoftMaxLoss.UUID: (CPCA16SoftMaxLoss, dict(subsample_rate=1.0), 0.05),
    }

    # Validate before mutating so a bad uuid leaves `named_losses` untouched.
    unknown = [uuid for uuid in auxiliary_uuids if uuid not in aux_loss_specs]
    if unknown:
        raise KeyError(
            f"Unknown auxiliary loss uuid(s) {unknown}; "
            f"expected a subset of {list(aux_loss_specs)}."
        )

    for uuid in auxiliary_uuids:
        loss_cls, loss_kwargs, relative_weight = aux_loss_specs[uuid]
        named_losses[uuid] = (
            loss_cls(**loss_kwargs),
            relative_weight * aux_loss_total_weight,
        )

    if multiple_beliefs:  # add weight entropy loss automatically
        named_losses[MultiAuxTaskNegEntropyLoss.UUID] = (
            MultiAuxTaskNegEntropyLoss(auxiliary_uuids),
            0.01,
        )

    return named_losses


class ObjectNavPPOMixin:
    """Mixin supplying a single-stage PPO training pipeline with optional aux losses."""

    @staticmethod
    def training_pipeline(
        auxiliary_uuids: Sequence[str],
        multiple_beliefs: bool,
        normalize_advantage: bool = True,
        advance_scene_rollout_period: Optional[int] = None,
        lr=3e-4,
        num_mini_batch=1,
        update_repeats=4,
        num_steps=100,
        save_interval=10000,
        log_interval=10000 if torch.cuda.is_available() else 1,
        gamma=0.99,
        use_gae=True,
        gae_lambda=0.95,
        max_grad_norm=0.5,
        anneal_lr: bool = True,
        extra_losses: Optional[Dict[str, Tuple[AbstractActorCriticLoss, float]]] = None,
        ppo_steps: int = int(4000000),
        lr_decay_steps: int = int(300000000),
    ) -> TrainingPipeline:
        """Build the PPO ``TrainingPipeline``.

        :param auxiliary_uuids: Auxiliary losses to add alongside PPO.
        :param multiple_beliefs: Forwarded to ``update_with_auxiliary_losses``.
        :param extra_losses: Additional ``name -> (loss, weight)`` entries. The
            caller's dict is copied, not mutated.
        :param ppo_steps: Length of the single PPO stage.
        :param lr_decay_steps: ``steps`` of the ``LinearDecay`` LR schedule used
            when ``anneal_lr`` is True.

        NOTE(review): the default schedule length (300M) is much longer than the
        default stage length (4M), so the LR schedule covers only a small part
        of its range during training. Pass ``lr_decay_steps=ppo_steps`` if a
        full decay over the stage is intended.
        """
        # Merge instead of passing two `**`/keyword sources, so a
        # `normalize_advantage` entry in PPOConfig cannot cause a duplicate-kwarg
        # TypeError; the explicit argument wins.
        ppo_kwargs = {**PPOConfig, "normalize_advantage": normalize_advantage}

        named_losses = {
            "ppo_loss": (PPO(**ppo_kwargs), 1.0),
            **({} if extra_losses is None else extra_losses),
        }
        named_losses = update_with_auxiliary_losses(
            named_losses=named_losses,
            auxiliary_uuids=auxiliary_uuids,
            multiple_beliefs=multiple_beliefs,
        )

        lr_scheduler_builder = (
            Builder(LambdaLR, {"lr_lambda": LinearDecay(steps=lr_decay_steps)})
            if anneal_lr
            else None
        )

        return TrainingPipeline(
            save_interval=save_interval,
            metric_accumulate_interval=log_interval,
            optimizer_builder=Builder(optim.Adam, dict(lr=lr)),
            num_mini_batch=num_mini_batch,
            update_repeats=update_repeats,
            max_grad_norm=max_grad_norm,
            num_steps=num_steps,
            named_losses={key: val[0] for key, val in named_losses.items()},
            gamma=gamma,
            use_gae=use_gae,
            gae_lambda=gae_lambda,
            advance_scene_rollout_period=advance_scene_rollout_period,
            pipeline_stages=[
                # loss_names and loss_weights share the dict's iteration order.
                PipelineStage(
                    loss_names=list(named_losses.keys()),
                    max_stage_steps=ppo_steps,
                    loss_weights=[val[1] for val in named_losses.values()],
                )
            ],
            lr_scheduler_builder=lr_scheduler_builder,
        )


# policies
@attr.s(kw_only=True)
class VAEPreprocessGRUActorCriticMixin:
    """Mixin pairing a ``NaivePreprocessor`` with a ``VAEObjectNavActorCritic``.

    ``preprocessors`` checks that the RGB sensor uses CLIP normalization
    constants and builds a single preprocessor writing to ``"rgb_clip_vit"``.
    ``create_model`` builds the actor-critic that reads that output.
    """

    sensors: Sequence[Sensor] = attr.ib()
    screen_size: int = attr.ib()
    goal_sensor_type: Type[Sensor] = attr.ib()
    target_types: List[str] = attr.ib()
    pool: bool = attr.ib(default=False)
    pooling_type: str = attr.ib()
    clip_model_type: str = attr.ib()
    # Path to the pretrained VAE state dict, forwarded as `source_model` to
    # VAEObjectNavActorCritic; must be set before `create_model` is used.
    source_model: Optional[str] = attr.ib(default=None)

    def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
        """Build the single CLIP-normalized RGB + goal preprocessor.

        Side effects: sets ``self.goal_sensor_uuid`` and
        ``self.preprocessor_output_shape``.

        :raises AssertionError: If there is no ``RGBSensor`` or no
            ``GoalObjectTypeThorSensor``, or if the RGB sensor's normalization
            means/stds differ from the preprocessor's CLIP constants.
        """
        rgb_sensor = next((s for s in self.sensors if isinstance(s, RGBSensor)), None)
        goal_sensor = next(
            (s for s in self.sensors if isinstance(s, GoalObjectTypeThorSensor)), None
        )
        self.goal_sensor_uuid = next(
            (s.uuid for s in self.sensors if isinstance(s, self.goal_sensor_type)),
            None,
        )
        preprocessor_model = NaivePreprocessor

        assert rgb_sensor is not None and goal_sensor is not None, (
            "An RGBSensor and a GoalObjectTypeThorSensor are both required."
        )

        # The preprocessor expects frames normalized with CLIP statistics.
        assert (
            np.linalg.norm(
                np.array(rgb_sensor._norm_means)
                - np.array(preprocessor_model.CLIP_RGB_MEANS)
            )
            < 1e-5
        ), "RGB sensor normalization means must match CLIP_RGB_MEANS."
        assert (
            np.linalg.norm(
                np.array(rgb_sensor._norm_sds)
                - np.array(preprocessor_model.CLIP_RGB_STDS)
            )
            < 1e-5
        ), "RGB sensor normalization stds must match CLIP_RGB_STDS."

        preprocessor = preprocessor_model(
            rgb_input_uuid=rgb_sensor.uuid,
            goal_sensor_uuid=self.goal_sensor_uuid,
            clip_model_type=self.clip_model_type,
            pool=self.pool,
            pooling_type=self.pooling_type,
            class_emb_only=False,
            output_uuid="rgb_clip_vit",
        )
        self.preprocessor_output_shape = preprocessor.output_shape

        return [preprocessor]

    def create_model(self, num_actions: int, **kwargs) -> nn.Module:
        """Build the VAE-based actor-critic over ``num_actions`` discrete actions.

        ``kwargs`` must contain ``"sensor_preprocessor_graph"``.
        """
        goal_sensor_uuid = next(
            (s.uuid for s in self.sensors if isinstance(s, self.goal_sensor_type)),
            None,
        )
        return VAEObjectNavActorCritic(
            action_space=gym.spaces.Discrete(num_actions),
            observation_space=kwargs["sensor_preprocessor_graph"].observation_spaces,
            goal_sensor_uuid=goal_sensor_uuid,
            hidden_size=1024,
            rgb_preprocessor_uuid="rgb_clip_vit",
            # NOTE(review): presumably 32 VAE content features + 1 goal value;
            # must equal the width produced by forward_encoder — confirm.
            embedding_dim=32 + 1,
            source_model=self.source_model,
        )


class VAEObjectNavActorCritic(VisualNavActorCritic):
    """Actor-critic whose observation encoder is a frozen, pretrained VAE.

    The observation under ``rgb_preprocessor_uuid`` is split along its last
    dimension. The first ``3 * 224 * 224`` entries are reshaped into an RGB
    image and encoded by the VAE encoder. The remaining entries (the goal
    encoding) are concatenated unchanged with the VAE features.
    """

    # Image layout of the RGB slice of the preprocessed observation.
    _IMG_CHANNELS = 3
    _IMG_SIZE = 224

    def __init__(
        # base params
        self,
        action_space: gym.spaces.Discrete,
        observation_space: SpaceDict,
        goal_sensor_uuid: str,
        rgb_preprocessor_uuid: str,
        # RNN
        hidden_size=1024,
        num_rnn_layers=1,
        rnn_type="GRU",
        add_prev_actions=False,
        add_prev_action_null_token=False,
        action_embed_size=6,
        # custom params
        embedding_dim: int = 512,
        source_model: Optional[str] = None,
    ):
        """Build the recurrent actor-critic and load the frozen VAE embedder.

        :param goal_sensor_uuid: Accepted for interface compatibility; not used
            by this class.
        :param rgb_preprocessor_uuid: Observation key holding image + goal data.
        :param embedding_dim: Width of the encoder output fed to the state
            encoder. NOTE(review): must equal the VAE feature size plus the goal
            width produced by ``forward_encoder`` — confirm.
        :param source_model: Path to the VAE state dict (loaded with
            ``torch.load``).
        :raises ValueError: If ``source_model`` is None.
        """
        if source_model is None:
            raise ValueError("`source_model` must be the path to the VAE weights.")

        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            hidden_size=hidden_size,
        )

        assert rgb_preprocessor_uuid is not None

        self.rgb_preprocessor_uuid = rgb_preprocessor_uuid

        self.create_state_encoders(
            obs_embed_size=embedding_dim,
            num_rnn_layers=num_rnn_layers,
            rnn_type=rnn_type,
            add_prev_actions=add_prev_actions,
            add_prev_action_null_token=add_prev_action_null_token,
            prev_action_embed_size=action_embed_size,
        )

        self.create_actorcritic_head()

        self.create_aux_models(
            obs_embed_size=embedding_dim,
            action_embed_size=action_embed_size,
        )

        self.train()

        self.embedder = IthorDisentangledVAE(class_latent_size=16, content_latent_size=32)
        # Load onto CPU first (where the freshly built embedder lives) so
        # checkpoints saved on a GPU also load on CPU-only hosts.
        self.embedder.load_state_dict(torch.load(source_model, map_location="cpu"))
        print("Turning off gradients in both the image and the text encoder")
        for param in self.embedder.parameters():
            param.requires_grad_(False)
        # Kept for backward compatibility; forward_encoder follows the
        # embedder's actual device instead of this fixed value.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    @property
    def is_blind(self) -> bool:
        return False

    def forward_encoder(self, observations: ObservationType) -> torch.FloatTensor:
        """Encode observations as VAE features concatenated with the goal part.

        Returns a tensor of shape ``(B, env_n, feature_dim + goal_dim)``, where
        ``B`` and ``env_n`` are the two leading dims of the input observation.
        """
        # Use the device the (frozen) embedder actually lives on rather than a
        # hard-coded "cuda:0", so workers on other GPUs or on CPU still work.
        device = next(self.embedder.parameters()).device
        obs = observations[self.rgb_preprocessor_uuid].to(device)

        channels, size = self._IMG_CHANNELS, self._IMG_SIZE
        image_numel = channels * size * size

        # clone() yields a contiguous copy, which the `.view` below requires.
        x = obs[:, :, :image_numel].detach().clone()
        batch, env_n, _ = x.shape
        x = x.view(batch * env_n, channels, size, size)
        goal = obs[:, :, image_numel:].detach().clone()

        x = self.embedder.encoder.get_feature(x)
        x = x.view(batch, env_n, x.size(-1))
        return torch.cat([x, goal], dim=-1)