import pickle
from torch.utils.data import DataLoader
import torch
from omegaconf import OmegaConf
import os
from os.path import join
from ntldm.networks import VAE  # Import your model definition here

def load_LDS_dataset_and_create_dataloaders(
    file_path, return_indices=False, batch_size=100, drop_last_eval=True
):
    """Load the pickled LDS dataset and wrap its three splits in DataLoaders.

    The pickle must hold a dict with keys ``"dataset"`` (an object whose
    ``samples`` attribute can be indexed by the stored index arrays),
    ``"train_indices"``, ``"valid_indices"`` and ``"test_indices"``.
    The file is read with :mod:`pickle`, so only load trusted files.

    :param file_path: Path to the pickle file.
    :param return_indices: If True, also return the three index arrays.
    :param batch_size: Batch size shared by all three loaders.
    :param drop_last_eval: Whether the validation and test loaders drop their
        final incomplete batch. Defaults to True (the historical behaviour).
        Pass False to evaluate on every sample. The training loader always
        drops its last incomplete batch.
    :return: ``(train_loader, valid_loader, test_loader)``, followed by
        ``train_indices, valid_indices, test_indices`` when ``return_indices``
        is True.
    """
    with open(file_path, "rb") as f:
        data = pickle.load(f)

    samples = data["dataset"].samples
    train_indices = data["train_indices"]
    valid_indices = data["valid_indices"]
    test_indices = data["test_indices"]

    def _make_loader(indices, shuffle, drop_last):
        # One place for the shared DataLoader settings of all three splits.
        return DataLoader(
            samples[indices],
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
        )

    loaders = (
        _make_loader(train_indices, shuffle=True, drop_last=True),
        # With drop_last=True, up to batch_size - 1 evaluation samples are
        # skipped; drop_last_eval lets callers opt into full evaluation.
        _make_loader(valid_indices, shuffle=False, drop_last=drop_last_eval),
        _make_loader(test_indices, shuffle=False, drop_last=drop_last_eval),
    )

    if return_indices:
        return (*loaders, train_indices, valid_indices, test_indices)
    return loaders


def load_LDS_dataset(file_path, indices=False):
    """Unpickle an LDS dataset file.

    :param file_path: Path to a pickle holding a dict with a ``"dataset"`` key
        and, for ``indices=True``, ``"train_indices"``, ``"valid_indices"``
        and ``"test_indices"``.
    :param indices: If truthy, also return the three split index arrays.
    :return: The stored dataset object, or the tuple
        ``(dataset, train_indices, valid_indices, test_indices)``.
    """
    with open(file_path, "rb") as handle:
        payload = pickle.load(handle)
    if not indices:
        return payload["dataset"]
    wanted = ("dataset", "train_indices", "valid_indices", "test_indices")
    return tuple(payload[key] for key in wanted)



def load_vae_data(pickle_file_path):
    """
    Load saved VAE outputs (true samples, true latents, reconstructions, mu, logvar) from a pickle file.

    :param pickle_file_path: Path to the pickle file containing the saved values
        (presumably PyTorch tensors; this function does not check their type).
    :return: A new dict with exactly the keys ``true_samples``, ``true_latents``,
        ``reconstructions``, ``mu`` and ``logvar`` copied from the pickled
        mapping. Any other keys in the file are dropped, and a missing key
        raises ``KeyError``.
    """
    wanted = ("true_samples", "true_latents", "reconstructions", "mu", "logvar")
    with open(pickle_file_path, "rb") as fh:
        stored = pickle.load(fh)
    return {key: stored[key] for key in wanted}



def load_model_deprecated(run_dir, device, ModelClass):
    """Deprecated: instantiate ``ModelClass`` and load a state dict into it.

    Despite its name, ``run_dir`` is passed straight to ``torch.load``, so it
    must be the path of the checkpoint file itself, not a directory.

    :param run_dir: Path to a saved ``state_dict``.
    :param device: Device the model is moved to and the checkpoint is mapped onto.
    :param ModelClass: Zero-argument callable returning a module-like object
        (needs ``.to``, ``.load_state_dict`` and ``.eval``).
    :return: The model with weights loaded, switched to eval mode.
    """
    net = ModelClass().to(device)
    weights = torch.load(run_dir, map_location=device)
    net.load_state_dict(weights)
    # nn.Module.eval() returns the module itself.
    return net.eval()
    
    
def load_config(config_path):
    """Read the YAML file at ``config_path`` and return it as an OmegaConf config."""
    return OmegaConf.load(config_path)

def load_cfg_model_data(run_dir, device='cpu', model_name="model_end.pth", cfg_name="cfg.yaml"):
    """Load the config, trained model, dataset and dataloaders of a run.

    :param run_dir: Run directory containing ``cfg_name`` and ``model_name``.
    :param device: Passed on to ``load_model``.
    :param model_name: File name of the saved model weights inside ``run_dir``.
    :param cfg_name: File name of the YAML config inside ``run_dir``.
    :return: Dict with keys ``"cfg"``, ``"model"``, ``"dataset"``,
        ``"dataloader"``, ``"dataloader_valid"`` and ``"dataloader_test"``.

    Note: ``load_model``, ``load_dataset`` and ``load_dataloaders`` each read
    the dataset file on their own, so it is unpickled more than once.
    """
    cfg = load_config(join(run_dir, cfg_name))
    model = load_model(cfg, run_dir, device, model_name)
    dataset = load_dataset(cfg)
    train_loader, valid_loader, test_loader = load_dataloaders(cfg)
    return {
        "cfg": cfg,
        "model": model,
        "dataset": dataset,
        "dataloader": train_loader,
        "dataloader_valid": valid_loader,
        "dataloader_test": test_loader,
    }


def load_model(cfg, run_dir, device='cpu', model_name="model_end.pth"):
    """Rebuild the model described by ``cfg`` and restore its saved weights.

    :param cfg: Config passed to ``init_model`` to construct the architecture.
    :param run_dir: Directory containing the weights file.
    :param device: Device the model is placed on and the checkpoint is mapped to.
    :param model_name: File name of the saved ``state_dict`` inside ``run_dir``.
    :return: The model with loaded weights. ``eval()`` is not called here;
        callers should set train/eval mode themselves.
    """
    # Move the freshly built model first. load_state_dict copies values into
    # the existing parameters, so map_location alone would leave a model that
    # was built on the CPU on the CPU, whatever ``device`` says.
    model = init_model(cfg).to(device)
    state_dict = torch.load(join(run_dir, model_name), map_location=device)
    model.load_state_dict(state_dict)
    return model


def load_train_test_datasets(cfg, indices=False):
    """Unpickle the dataset file referenced by ``cfg``.

    The path is ``os.path.join(cfg.dataset.filepath, cfg.dataset.filename)``.

    :param cfg: Config with ``dataset.filepath`` and ``dataset.filename``.
    :param indices: If truthy, also return the stored split index arrays.
    :return: The stored dataset object, or the tuple
        ``(dataset, train_indices, valid_indices, test_indices)``.
    """
    location = os.path.join(cfg.dataset.filepath, cfg.dataset.filename)
    with open(location, "rb") as stream:
        contents = pickle.load(stream)
    if not indices:
        return contents["dataset"]
    split_keys = ("train_indices", "valid_indices", "test_indices")
    return (contents["dataset"], *(contents[k] for k in split_keys))


def load_dataset(cfg,indices=True):
    """Load the dataset object named by ``cfg.dataset.name``.

    :param cfg: Config with ``dataset.name`` and, for ``"lds"``,
        ``dataset.filepath`` and ``dataset.filename``.
    :param indices: Currently ignored. Only the dataset object is returned,
        never the split indices.
    :return: The LDS dataset object for ``"lds"``; ``None`` for any other name.
    """
    if cfg.dataset.name != "lds":
        return None
    return load_LDS_dataset(os.path.join(cfg.dataset.filepath, cfg.dataset.filename))

def load_dataloaders(cfg):
    """Return the ``(train, valid, test)`` DataLoaders built by ``init_dataloaders``."""
    train_loader, valid_loader, test_loader = init_dataloaders(cfg)
    return train_loader, valid_loader, test_loader
    
    
def init_dataloaders(cfg):
    """Build the train/valid/test DataLoaders described by ``cfg``.

    Only ``cfg.dataset.name == "lds"`` is supported. The loaders come from
    ``load_LDS_dataset_and_create_dataloaders`` with its default batch size.
    The file path is joined from ``cfg.dataset.filepath`` and
    ``cfg.dataset.filename`` before the dataset name is checked, so those
    fields must exist even for unsupported datasets.

    :return: ``(train_loader, valid_loader, test_loader)``. Split indices are
        not returned.
    :raises NotImplementedError: For any dataset other than ``"lds"``.
    """
    data_file = os.path.join(cfg.dataset.filepath, cfg.dataset.filename)
    if cfg.dataset.name != "lds":
        raise NotImplementedError(f"Dataset {cfg.dataset.name} not implemented yet!")
    train_loader, valid_loader, test_loader = load_LDS_dataset_and_create_dataloaders(
        data_file, return_indices=False
    )
    return train_loader, valid_loader, test_loader


# TODO move to model
def init_model(cfg):
    """Construct an untrained model from ``cfg``.

    Only the ``"lds"`` dataset with the ``"VAE"`` network is supported. The
    LDS dataset pickle is loaded for two reasons: to check that its sample
    dimension matches ``cfg.network.input_dim``, and to pass the transpose of
    ``dataset.C`` and ``dataset.b`` to the VAE. The VAE is created with
    ``frozen=True`` and ``bias=True``.

    :param cfg: Config with ``dataset.{name, filepath, filename}`` and
        ``network.{name, input_dim, hidden_sizes, latent_dim}``.
    :return: The constructed model.
    :raises NotImplementedError: For an unsupported dataset or network name.
    :raises AssertionError: If the last dimension of ``dataset.samples`` does
        not equal ``cfg.network.input_dim``.
    """
    if cfg.dataset.name != "lds":
        raise NotImplementedError(f"Dataset {cfg.dataset.name} not implemented yet!")

    filepath = os.path.join(cfg.dataset.filepath, cfg.dataset.filename)
    dataset = load_LDS_dataset(filepath)

    d_dimension = dataset.samples.shape[-1]
    if d_dimension != cfg.network.input_dim:
        # Explicit raise instead of ``assert`` so the check still runs under
        # ``python -O``. AssertionError is kept for backward compatibility.
        raise AssertionError(
            f"Input dimension mismatch: dataset has {d_dimension}, "
            f"cfg.network.input_dim is {cfg.network.input_dim}"
        )

    if cfg.network.name != "VAE":
        raise NotImplementedError(f"Model {cfg.network.name} not implemented yet!")

    return VAE(
        input_channels=cfg.network.input_dim,
        hidden_sizes=cfg.network.hidden_sizes,
        bottleneck_size=cfg.network.latent_dim,
        output_channels=cfg.network.input_dim,
        C=dataset.C.T,
        b=dataset.b,
        frozen=True,
        bias=True,
    )
