from dataclasses import dataclass
from typing import Optional

from omegaconf import OmegaConf, MISSING

from .modules.configs import (
    DataEncoderConfig,
    GridCoordSamplerConfig,
    RayCoordSamplerConfig,
    DecoderWithCrossAttentionConfig,
    MultiBandDecoderWithCrossAttentionConfig,
    TransformerConfig,
)


@dataclass
class GINRWithXAttnDecoderConfig:
    type: str = "ginr_with_xattn_decoder"
    ema: Optional[bool] = None

    data_encoder: DataEncoderConfig = DataEncoderConfig()
    transformer: TransformerConfig = TransformerConfig()

    num_data_tokens: int = MISSING
    num_latent_tokens: int = MISSING
    output_dim: int = MISSING
    latent_dim: int = MISSING
    use_latent_projection: bool = False

    coord_sampler: GridCoordSamplerConfig = GridCoordSamplerConfig()
    decoder: DecoderWithCrossAttentionConfig = DecoderWithCrossAttentionConfig()

    @classmethod
    def create(cls, config):
        # We need to specify the type of the default DataEncoderConfig.
        # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value)
        # hence merging with the config with other type would cause config error.
        default_dataenc_config = DataEncoderConfig(type=config.data_encoder.type)
        defaults = OmegaConf.structured(cls(ema=False, data_encoder=default_dataenc_config))
        config = OmegaConf.merge(defaults, config)  # type: GINRWithXAttnDecoderConfig
        config.transformer.block.embed_dim = config.transformer.embed_dim
        config.decoder.latent_dim = config.latent_dim

        return config


@dataclass
class GINRWithMultiBandDecoderConfig:
    type: str = "ginr_with_multi_band_decoder"
    ema: Optional[bool] = None

    data_encoder: DataEncoderConfig = DataEncoderConfig()
    transformer: TransformerConfig = TransformerConfig()

    num_data_tokens: int = MISSING
    num_latent_tokens: int = MISSING
    output_dim: int = MISSING
    latent_dim: int = MISSING
    use_latent_projection: bool = False

    coord_sampler: GridCoordSamplerConfig = GridCoordSamplerConfig()
    decoder: MultiBandDecoderWithCrossAttentionConfig = MultiBandDecoderWithCrossAttentionConfig()

    @classmethod
    def create(cls, config):
        # We need to specify the type of the default DataEncoderConfig.
        # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value)
        # hence merging with the config with other type would cause config error.
        default_dataenc_config = DataEncoderConfig(type=config.data_encoder.type)
        defaults = OmegaConf.structured(cls(ema=False, data_encoder=default_dataenc_config))
        config = OmegaConf.merge(defaults, config)  # type: GINRWithMultiBandDecoderConfig
        config.transformer.block.embed_dim = config.transformer.embed_dim
        config.decoder.latent_dim = config.latent_dim

        return config


@dataclass
class GNeRFWithXAttnDecoderConfig:
    type: str = "gnerf_with_xattn_decoder"
    ema: Optional[bool] = None

    data_encoder: DataEncoderConfig = DataEncoderConfig()
    transformer: TransformerConfig = TransformerConfig()

    num_data_tokens: int = MISSING
    num_latent_tokens: int = MISSING
    latent_dim: int = MISSING

    coord_sampler: RayCoordSamplerConfig = RayCoordSamplerConfig()
    decoder: DecoderWithCrossAttentionConfig = DecoderWithCrossAttentionConfig()

    @classmethod
    def create(cls, config):
        # We need to specify the type of the default DataEncoderConfig.
        # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value)
        # hence merging with the config with other type would cause config error.
        default_dataenc_config = DataEncoderConfig(type=config.data_encoder.type)
        defaults = OmegaConf.structured(cls(ema=False, data_encoder=default_dataenc_config))
        config = OmegaConf.merge(defaults, config)  # type: GNeRFWithXAttnDecoderConfig
        config.transformer.block.embed_dim = config.transformer.embed_dim
        config.decoder.latent_dim = config.latent_dim

        return config


@dataclass
class LightFieldWithXAttnDecoderConfig:
    type: str = "light_field_with_xattn_decoder"
    ema: Optional[bool] = None

    data_encoder: DataEncoderConfig = DataEncoderConfig()
    transformer: TransformerConfig = TransformerConfig()

    num_data_tokens: int = MISSING
    num_latent_tokens: int = MISSING
    latent_dim: int = MISSING

    # `coord_sampler` is removed here unlike in GNeRFWithXAttnDecoderConfig.
    # LightFieldWithXAttnDecoder converts rays into Plucker coordinate (as opposed to sampling points along the rays),
    # and this is directly dealt with the method `sample_coord_inputs`.

    decoder: DecoderWithCrossAttentionConfig = DecoderWithCrossAttentionConfig()

    @classmethod
    def create(cls, config):
        # We need to specify the type of the default DataEncoderConfig.
        # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value)
        # hence merging with the config with other type would cause config error.
        default_dataenc_config = DataEncoderConfig(type=config.data_encoder.type)
        defaults = OmegaConf.structured(cls(ema=False, data_encoder=default_dataenc_config))
        config = OmegaConf.merge(defaults, config)  # type: LightFieldWithXAttnDecoderConfig
        config.transformer.block.embed_dim = config.transformer.embed_dim
        config.decoder.latent_dim = config.latent_dim

        return config


@dataclass
class LightFieldWithMultiBandDecoderConfig:
    type: str = "light_field_with_multi_band_decoder"
    ema: Optional[bool] = None

    data_encoder: DataEncoderConfig = DataEncoderConfig()
    transformer: TransformerConfig = TransformerConfig()

    num_data_tokens: int = MISSING
    num_latent_tokens: int = MISSING
    latent_dim: int = MISSING

    # `coord_sampler` is removed here unlike in GNeRFWithXAttnDecoderConfig.
    # LightFieldWithXAttnDecoder converts rays into Plucker coordinate (as opposed to sampling points along the rays),
    # and this is directly dealt with the method `sample_coord_inputs`.

    decoder: MultiBandDecoderWithCrossAttentionConfig = MultiBandDecoderWithCrossAttentionConfig()

    @classmethod
    def create(cls, config):
        # We need to specify the type of the default DataEncoderConfig.
        # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value)
        # hence merging with the config with other type would cause config error.
        default_dataenc_config = DataEncoderConfig(type=config.data_encoder.type)
        defaults = OmegaConf.structured(cls(ema=False, data_encoder=default_dataenc_config))
        config = OmegaConf.merge(defaults, config)  # type: MultiBandDecoderWithCrossAttentionConfig
        config.transformer.block.embed_dim = config.transformer.embed_dim
        config.decoder.latent_dim = config.latent_dim

        return config
