"""Loads dataset ground truth clusters and similarities."""

import os
import pickle as pkl

import numpy as np
import pandas as pd
import scipy.sparse as sp
from sklearn.metrics.pairwise import cosine_similarity


# Names of the UCI datasets known to this module. Each one is read from
# $DATAPATH/<name>/<name>.data by load_uci_data.
UCI_DATASETS = ["glass", "zoo", "spambase", "letter", "iris", "segmentation"]

def load_data(dataset, normalize=True):
    """Load dataset.

    Loads features and labels with ``load_uci_data`` and builds a dense
    pairwise similarity matrix ``0.5 * (1 + x @ x.T)``. When ``normalize`` is
    true the rows are first scaled to unit L2 norm, so the dot product is the
    cosine similarity and the similarities lie in [0, 1]. The diagonal is
    forced to 1 and values above 1 are clipped to 1.

    @param dataset: dataset name
    @type dataset: str
    @param normalize: whether to L2-normalize feature vectors before computing similarities
    @type normalize: bool
    @return: feature vectors, labels, and pairwise similarities computed with cosine similarity
    @rtype: Tuple[np.array, np.array, np.array]
    """
    x, y = load_uci_data(dataset)
    if normalize:
        norms = np.linalg.norm(x, axis=1, keepdims=True)
        # An all-zero row has no direction. Keep it at zero instead of
        # dividing by 0, which would fill its similarities with NaN.
        norms[norms == 0] = 1.0
        x = x / norms
    # A matrix product gives all pairwise dot products using O(n^2) memory.
    # The previous broadcast materialized an (n, n, d) array.
    cos = x @ x.T
    similarities = 0.5 * (1 + cos)
    # Mirror the upper triangle so the matrix is exactly symmetric. The
    # diagonal gets doubled here and is overwritten on the next line.
    similarities = np.triu(similarities) + np.triu(similarities).T
    similarities[np.diag_indices_from(similarities)] = 1.0
    # Clip values that exceed 1 (floating-point round-off, or raw dot
    # products when normalize is False).
    similarities[similarities > 1.0] = 1.0
    return x, y, similarities

def load_uci_data(dataset):
    """Loads data from UCI repository.

    Reads ``$DATAPATH/<dataset>/<dataset>.data``, a comma-separated file.
    In each non-blank line, the first column is ignored, the last column is
    the class label, and the columns in between are parsed as float features.
    Features are standardized to zero mean and unit variance. Zero-variance
    columns are dropped because they cannot be standardized (they would
    become 0/0 = NaN).

    @param dataset: UCI dataset name
    @type dataset: str
    @return: standardized feature vectors of shape (n_samples, n_features),
        and integer labels numbered in order of first appearance
    @rtype: Tuple[np.array, np.array]
    @raise KeyError: if the DATAPATH environment variable is not set
    @raise ValueError: if the file contains no samples
    """
    data_path = os.path.join(os.environ["DATAPATH"], dataset, "{}.data".format(dataset))
    features = []
    labels = []
    classes = {}
    with open(data_path, "r") as f:
        for line in f:
            # Strip the line terminator so the last line (which may lack a
            # newline) yields the same label string as the others.
            line = line.strip()
            # Skip blank lines, e.g. trailing empty lines at the end of UCI files.
            if not line:
                continue
            split_line = line.split(",")
            features.append([float(value) for value in split_line[1:-1]])
            # setdefault assigns the next integer id the first time a label is seen.
            labels.append(classes.setdefault(split_line[-1].strip(), len(classes)))
    if not features:
        raise ValueError("no samples found in {}".format(data_path))
    y = np.array(labels, dtype=int)
    x = np.array(features, dtype=float)
    mean = x.mean(0)
    std = x.std(0)
    keep = std > 0
    x = (x[:, keep] - mean[keep]) / std[keep]
    return x, y

