import argparse
import json
import logging
import os
import shutil
from time import time
import uuid

import importlib
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from pydantic.dataclasses import dataclass

from pathlib import Path

import wandb
from pythae.data.preprocessors import DataProcessor
from pythae.models import AutoModel
from pythae.config import BaseConfig
from pythae.trainers import BaseTrainer, BaseTrainerConfig
from pythae.models.base.base_utils import ModelOutput


from sklearn.cluster import KMeans

# Module logger: INFO-level messages are echoed through a stream handler.
console = logging.StreamHandler()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(console)

# Absolute path of the directory containing this script; the `data/<dataset>`
# folders are looked up relative to it.
PATH = os.path.dirname(os.path.abspath(__file__))

# Command-line interface: location of the trained models, number of k-means
# fits to run, and the optional Weights & Biases logging settings.
ap = argparse.ArgumentParser()

ap.add_argument("--models_path", required=True, help="The path to a model to generate from")
ap.add_argument("--n_runs", default=20, type=int)
ap.add_argument("--use_wandb", action="store_true", help="whether to log the metrics in wandb")
ap.add_argument("--wandb_project", default="latent_dim_sensi_clustering", help="wandb project name")
ap.add_argument("--wandb_entity", default="benchmark_team", help="wandb entity name")

# NOTE: parsed at import time, so importing this module consumes sys.argv.
args = ap.parse_args()

# Use the GPU when torch reports one as available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


# Dataset name inferred from the input shape stored in the model config.
_DATASET_BY_INPUT_DIM = {(1, 28, 28): "mnist", (3, 32, 32): "cifar10"}
# Data splits evaluated by the script, in reporting order.
_SPLITS = ("train", "eval", "test")
# Number of k-means clusters fitted on the train latents.
_N_CLUSTERS = 10


def _load_split(data_dir, split):
    """Load one split from `data_dir` as (images divided by 255, labels)."""
    data = np.load(os.path.join(data_dir, f"{split}_data.npz"))["data"] / 255.0
    targets = np.load(os.path.join(data_dir, f"{split}_labels.npz"))["targets"]
    return data, targets


def _check_unit_range(split, data):
    """Raise ValueError unless every value of `data` lies in [0, 1]."""
    # Written as `not (... and ...)` so that NaN values are rejected as well.
    if not (data.min() >= 0 and data.max() <= 1):
        raise ValueError(f"{split.capitalize()} data must be in the range [0-1]")


def _encode_to_latents(model, data, data_processor, dataset_type, batch_size=100):
    """Run `data` through `model` batch by batch and return the latent codes `z`."""
    inputs = data_processor.process_data(data).to(device)
    dataset = data_processor.to_dataset(inputs, dataset_type=dataset_type)
    loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

    try:
        with torch.no_grad():
            latents = [model(batch).z for batch in loader]
    except RuntimeError:
        # Fallback with autograd enabled (presumably for models whose forward
        # pass needs gradients -- TODO confirm). The list is rebuilt from
        # scratch so batches encoded before the failure are not duplicated.
        latents = [model(batch).z.detach() for batch in loader]

    return torch.cat(latents)


def _cluster_label_map(cluster_ids, targets, n_clusters):
    """Return an array mapping each cluster id to its most frequent true label."""
    mapping = np.zeros(n_clusters, dtype=np.int64)
    for cluster in range(n_clusters):
        members = cluster_ids == cluster
        # `minlength=1` keeps argmax defined (label 0) for an empty cluster.
        mapping[cluster] = np.bincount(targets[members], minlength=1).argmax()
    return mapping


def main(args):
    """Evaluate k-means clustering accuracy in the latent space of a trained model.

    Reloads the first model found under ``args.models_path``, encodes the
    train/eval/test splits of the matching dataset into latent codes, then
    repeats ``args.n_runs`` times: fit k-means on the train latents, assign each
    cluster its most frequent train label and measure the resulting accuracy on
    every split. Mean and standard deviation over runs are printed and, when
    ``args.use_wandb`` is set, logged to wandb.

    Args:
        args: Parsed command-line namespace providing ``models_path``,
            ``n_runs``, ``use_wandb``, ``wandb_project`` and ``wandb_entity``.

    Raises:
        FileNotFoundError: If ``args.models_path`` is empty or the dataset files
            cannot be loaded.
        ValueError: If the model input shape matches no known dataset or a
            split is not in the range [0-1] after scaling.
        ModuleNotFoundError: If wandb logging is requested but unavailable.
    """
    candidates = os.listdir(args.models_path)
    if not candidates:
        raise FileNotFoundError(f"No trained model found in '{args.models_path}'.")

    # NOTE: os.listdir order is arbitrary; the first entry is used as before.
    model_signature = candidates[0]
    model_path = os.path.join(args.models_path, model_signature, "final_model")

    # reload the model
    trained_model = AutoModel.load_from_folder(model_path).to(device)
    trained_model.eval()

    logger.info(f"Successfully reloaded {trained_model.model_name.upper()} model !\n")

    input_dim = trained_model.model_config.input_dim
    dataset = _DATASET_BY_INPUT_DIM.get(
        tuple(input_dim) if input_dim is not None else None
    )
    if dataset is None:
        raise ValueError(
            f"Unsupported model input_dim {input_dim}: expected one of "
            f"{sorted(_DATASET_BY_INPUT_DIM)}."
        )

    data_dir = os.path.join(PATH, f"data/{dataset}")
    try:
        logger.info(f"\nLoading {dataset} data...\n")
        splits = {split: _load_split(data_dir, split) for split in _SPLITS}
    except Exception as e:
        raise FileNotFoundError(
            f"Unable to load the data from 'data/{dataset}' folder. Please check that "
            "'<split>_data.npz' and '<split>_labels.npz' are present in the folder for "
            "each of the train, eval and test splits.\n Data must be under the key "
            "'data', in the range [0-255] and shaped with channel in first position; "
            "labels must be under the key 'targets'.\n"
            f"Exception raised: {type(e)} with message: {e}"
        ) from e

    logger.info("Successfully loaded data !\n")
    logger.info("------------------------------------------------------------")
    logger.info("Dataset \t \t Shape \t \t \t Range")
    for split, (data, targets) in splits.items():
        logger.info(
            f"{dataset.upper()} {split} data: \t {data.shape, targets.shape} \t [{data.min()}-{data.max()}] "
        )
    logger.info("------------------------------------------------------------\n")

    dataset_type = (
        "DoubleBatchDataset"
        if trained_model.model_name == "FactorVAE"
        else "BaseDataset"
    )

    # Encode every split once; the latents are moved to NumPy here because they
    # do not change between k-means runs.
    data_processor = DataProcessor()
    latents = {}
    labels = {}
    for split, (data, targets) in splits.items():
        _check_unit_range(split, data)
        encoded = _encode_to_latents(trained_model, data, data_processor, dataset_type)
        latents[split] = encoded.detach().cpu().numpy()
        labels[split] = targets

    accuracies = {split: [] for split in _SPLITS}

    for run in range(args.n_runs):

        print(f"fits kmeans {run}")

        # fit kmeans on train and assign the most common label to each cluster
        kmeans = KMeans(n_clusters=_N_CLUSTERS).fit(latents["train"])
        cluster_to_label = _cluster_label_map(
            kmeans.labels_, labels["train"], _N_CLUSTERS
        )

        for split in _SPLITS:
            if split == "train":
                cluster_ids = kmeans.labels_
            else:
                cluster_ids = kmeans.predict(latents[split])

            # Index the mapping with the cluster ids instead of rewriting values
            # in place: an in-place rewrite maps an entry a second time whenever
            # its new label equals a cluster id handled later in the loop.
            predicted_labels = cluster_to_label[cluster_ids]
            accuracies[split].append((predicted_labels == labels[split]).mean())

    for split in _SPLITS:
        print("-------------------------------------")
        print(f"mean {split} accuracy : {np.mean(accuracies[split])}")
        print(f"std {split} accuracy : {np.std(accuracies[split])}")

    if args.use_wandb:
        # `import importlib` alone does not guarantee that the `util` submodule
        # is loaded, so import it explicitly before using it.
        import importlib.util

        if importlib.util.find_spec("wandb") is None:
            raise ModuleNotFoundError(
                "`wandb` package must be installed. Run `pip install wandb`"
            )

        wandb.init(project=args.wandb_project, entity=args.wandb_entity)
        wandb.config.update(
            {
                "n_runs": args.n_runs,
                "model_path": model_path,
                "model_config": trained_model.model_config.to_dict(),
            }
        )

        metrics = {}
        for split in _SPLITS:
            metrics[f"{split}/mean_accuracy"] = np.mean(accuracies[split])
            metrics[f"{split}/std_accuracy"] = np.std(accuracies[split])
        wandb.log(metrics)

# Script entry point: run the evaluation with the module-level parsed `args`
# only when this file is executed directly.
if __name__ == "__main__":

    main(args)
