import time

import torch
import torch.nn.functional as F
from pykeops.torch import LazyTensor
from torch_cluster import fps

from rel2abs.openfaiss import FaissIndex

# from sklearn.cluster import KMeans


try:
    # StrEnum is part of the standard library from Python 3.11; older interpreters use the backport below
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum

import random
from enum import auto
from typing import List, Optional, Sequence

import numpy as np
from pytorch_lightning import seed_everything

# @torch.no_grad()
# def anchor_pruning(anchors: torch.Tensor, threshold: float, min_anchors: int = 1):
#     anchor_similarities = (anchors @ anchors.T) - torch.eye(anchors.size(0), device=anchors.device)

#     current_anchor_index: int = 0
#     while current_anchor_index < anchor_similarities.size(0):
#         current_anchor_similarities = anchor_similarities[current_anchor_index, :]
#         too_similar = (current_anchor_similarities >= threshold).nonzero()
#         anchor_similarities[too_similar, :] = 0
#         anchor_similarities[:, too_similar] = 0
#         current_anchor_index += 1

#     keep_indices = anchor_similarities.sum(dim=0) != 0
#     if keep_indices.sum() < min_anchors:
#         raise RuntimeError(f"Not enough anchors remaining with threshold {threshold}")


#     return keep_indices


@torch.no_grad()
def anchor_pruning(
    anchors: torch.Tensor, stop_distance: float, random_seed: Optional[int] = None
) -> List[torch.Tensor]:
    """Greedily pick a subset of anchors that covers the rest within ``stop_distance``.

    The distance between two anchors is ``1 - |cosine similarity|``. For every
    subspace the selection starts from a seed set and repeatedly adds the anchor
    that is farthest from its nearest already-kept anchor, until that farthest
    distance is no larger than ``stop_distance``.

    Args:
        anchors: either a 2-D tensor (a single subspace, one anchor per row) or a
            3-D tensor whose first dimension iterates over subspaces.
        stop_distance: non-negative coverage radius. ``0`` keeps every anchor.
        random_seed: when ``None`` the selection is seeded with the pair of anchors
            that are farthest apart; otherwise ``seed_everything(random_seed)`` is
            called for each subspace and a single starting anchor is drawn with
            ``random.randint``.

    Returns:
        One 1-D index tensor per subspace holding the positions of the kept
        anchors in ascending order. In the ``stop_distance == 0`` shortcut the
        indices are built with ``torch.arange`` without an explicit device.

    Raises:
        ValueError: if ``stop_distance`` is negative (the greedy loop could never
            reach a negative threshold).
    """
    if stop_distance < 0:
        raise ValueError(f"stop_distance must be non-negative, got {stop_distance}")

    if anchors.dim() == 2:
        anchors = anchors.unsqueeze(0)

    if stop_distance == 0:
        return [torch.arange(subspace_anchors.size(0)) for subspace_anchors in anchors]

    keep_indices = []
    for subspace_anchors in anchors:
        num_anchors = subspace_anchors.size(0)
        keep_mask = torch.zeros(num_anchors, dtype=torch.bool, device=subspace_anchors.device)

        if num_anchors == 0:
            # Nothing to choose from: report an empty selection instead of failing on argmax/randint.
            keep_indices.append(keep_mask.nonzero()[:, 0])
            continue

        subspace_anchors = F.normalize(subspace_anchors, p=2, dim=-1)
        anchor_distances = 1 - (subspace_anchors @ subspace_anchors.T).abs()

        if random_seed is None:
            # Seed with the two anchors that are farthest apart (row/column of the flat argmax).
            first, second = divmod(int(anchor_distances.argmax()), num_anchors)
            keep_mask[first] = True
            keep_mask[second] = True
        else:
            seed_everything(seed=random_seed)
            keep_mask[random.randint(0, num_anchors - 1)] = True

        # Distance from every anchor to its nearest kept anchor, maintained incrementally
        # so each iteration only has to look at the newly added anchor's row.
        nearest_kept = anchor_distances[keep_mask, :].min(dim=0).values
        while True:
            # Anchors that are already kept must never be selected again.
            candidate_dist = nearest_kept.masked_fill(keep_mask, 0.0)
            if candidate_dist.max() <= stop_distance:
                break

            farthest = candidate_dist.argmax()
            keep_mask[farthest] = True
            nearest_kept = torch.min(nearest_kept, anchor_distances[farthest])

        # nonzero() on a boolean mask yields each kept position exactly once, in ascending order.
        keep_indices.append(keep_mask.nonzero()[:, 0])

    return keep_indices


class AnchorChoice(StrEnum):
    UNIFORM = auto()
    FPS = auto()
    KMEANS = auto()
    TOP_K = auto()


def KMeans(x, K, Niter=300, verbose=True):
    """Implements Lloyd's algorithm for the Euclidean metric.

    Args:
        x: (N, D) point cloud.
        K: number of clusters; the first ``K`` rows of ``x`` are used as the
            initial centroids.
        Niter: number of Lloyd iterations to run (must be at least 1).
        verbose: when true, print a short timing report after the loop.

    Returns:
        A tuple ``(cl, c)`` where ``cl`` is the (N,) vector of class labels and
        ``c`` is the (K, D) cloud of cluster centroids.

    Raises:
        ValueError: if ``Niter`` is smaller than 1 (no labels would be computed).
    """
    if Niter < 1:
        raise ValueError(f"Niter must be at least 1, got {Niter}")

    start = time.time()
    N, D = x.shape  # Number of samples, dimension of the ambient space

    c = x[:K, :].clone()  # Simplistic initialization for the centroids

    x_i = LazyTensor(x.view(N, 1, D))  # (N, 1, D) samples
    # c_j is built once from a view of c and is not rebuilt inside the loop,
    # so the centroid update below has to write into c in place.
    c_j = LazyTensor(c.view(1, K, D))  # (1, K, D) centroids

    # K-means loop:
    # - x  is the (N, D) point cloud,
    # - cl is the (N,) vector of class labels
    # - c  is the (K, D) cloud of cluster centroids
    for _ in range(Niter):

        # E step: assign points to the closest cluster -------------------------
        D_ij = ((x_i - c_j) ** 2).sum(-1)  # (N, K) symbolic squared distances
        cl = D_ij.argmin(dim=1).long().view(-1)  # Points -> Nearest cluster

        # M step: update the centroids to the cluster average: -----------------
        # Compute the sum of points per cluster:
        cluster_sums = torch.zeros_like(c)
        cluster_sums.scatter_add_(0, cl[:, None].repeat(1, D), x)

        # Number of points per cluster:
        Ncl = torch.bincount(cl, minlength=K).type_as(c).view(K, 1)

        # A cluster that received no points would give 0 / 0 = NaN; keep its
        # previous centroid instead. The write stays in place (see c_j above).
        c[:] = torch.where(Ncl > 0, cluster_sums / Ncl.clamp(min=1), c)

    if verbose:  # Fancy display -----------------------------------------------
        if x.device.type == "cuda":
            torch.cuda.synchronize()
        end = time.time()
        print(f"K-means for the Euclidean metric with {N:,} points in dimension {D:,}, K = {K:,}:")
        print(
            "Timing for {} iterations: {:.5f}s = {} x {:.5f}s\n".format(
                Niter, end - start, Niter, (end - start) / Niter
            )
        )

    return cl, c


# https://www.kernel-operations.io/keops/_auto_tutorials/kmeans/plot_kmeans_torch.html?highlight=kmeans
def KMeansCosine(x, K, Niter=300, verbose=True):
    """Run Lloyd's algorithm using dot products against unit-norm centroids.

    ``x`` is the (N, D) point cloud and the first ``K`` rows, L2-normalised,
    serve as initial centroids. Each of the ``Niter`` iterations assigns every
    point to the centroid with the largest dot product and then replaces each
    centroid by the L2-normalised sum of its assigned points.

    Returns the (N,) label vector and the (K, D) centroid tensor. When
    ``verbose`` is true a timing summary is printed.
    """
    t_start = time.time()
    num_points, dim = x.shape

    # Initial centroids: the first K samples, rescaled to unit length.
    centroids = torch.nn.functional.normalize(x[:K, :].clone(), dim=1, p=2)

    # Symbolic (N, 1, D) samples and (1, K, D) centroids. The centroid wrapper is
    # created once; the loop below only ever writes into `centroids` in place.
    samples_lazy = LazyTensor(x.view(num_points, 1, dim))
    centroids_lazy = LazyTensor(centroids.view(1, K, dim))

    for _ in range(Niter):
        # Assignment: label of each point = centroid with the highest dot product.
        gram = samples_lazy | centroids_lazy  # (N, K) symbolic Gram matrix
        labels = gram.argmax(dim=1).long().view(-1)

        # Update: accumulate the points of every cluster ...
        centroids.zero_()
        scatter_index = labels.unsqueeze(1).repeat(1, dim)
        centroids.scatter_add_(0, scatter_index, x)

        # ... and bring the sums back to unit norm, still in place.
        centroids[:] = torch.nn.functional.normalize(centroids, dim=1, p=2)

    if verbose:
        if x.device.type == "cuda":
            torch.cuda.synchronize()
        elapsed = time.time() - t_start
        print(f"K-means for the cosine similarity with {num_points:,} points in dimension {dim:,}, K = {K:,}:")
        print(
            "Timing for {} iterations: {:.5f}s = {} x {:.5f}s\n".format(
                Niter, elapsed, Niter, elapsed / Niter
            )
        )

    return labels, centroids


def get_anchors(
    x: torch.Tensor, anchor_choice: AnchorChoice, num_anchors: int, seed: int, limit: Optional[int] = None
) -> Sequence[int]:
    """Select anchor sample indices from ``x`` with the requested strategy.

    ``seed_everything(seed)`` is called first, so the random choices below are
    tied to ``seed``.

    Strategies:
        * ``UNIFORM``: ``num_anchors`` indices drawn with ``random.sample``.
        * ``FPS``: ``torch_cluster.fps`` on the L2-normalised rows of ``x`` with
          ``ratio = num_anchors / N``.
        * ``KMEANS``: cluster the L2-normalised rows with ``KMeansCosine``, average
          the members of every non-empty cluster, and take for each average the
          most similar sample according to a ``FaissIndex`` lookup.
        * ``TOP_K``: ``num_anchors`` indices drawn with ``random.sample`` from the
          first ``limit`` indices.

    Args:
        x: 2-D tensor with one sample per row.
        anchor_choice: which strategy to use.
        num_anchors: how many anchors to request from the strategy.
        seed: seed forwarded to ``seed_everything``.
        limit: size of the index prefix to sample from; required for ``TOP_K``.

    Returns:
        The selected sample indices, sorted in ascending order.

    Raises:
        ValueError: if ``anchor_choice`` is ``TOP_K`` and ``limit`` is ``None``.
        NotImplementedError: if ``anchor_choice`` is not one of the strategies above.
    """
    seed_everything(seed)

    N, _ = x.shape
    ids: List[int] = list(range(N))

    anchor_ids: Sequence[int]
    if anchor_choice == AnchorChoice.UNIFORM:
        anchor_ids = random.sample(ids, num_anchors)
    elif anchor_choice == AnchorChoice.FPS:
        normalized = F.normalize(x, p=2, dim=-1)
        sampled = fps(normalized, random_start=True, ratio=num_anchors / N)
        anchor_ids = sampled.cpu().tolist()
    elif anchor_choice == AnchorChoice.KMEANS:
        vectors = F.normalize(x, p=2, dim=-1)
        clustered, _ = KMeansCosine(x=vectors, K=num_anchors)

        # Mean embedding of every cluster that actually received samples.
        populated_clusters = sorted(set(clustered.cpu().tolist()))
        cluster_means = np.array(
            [vectors[clustered == cluster].mean(dim=0).cpu().numpy() for cluster in populated_clusters],
            dtype="float32",
        )

        # Index all samples under their stringified position, then look up the
        # closest sample to each cluster mean.
        # NOTE(review): assumes search_by_vectors returns one mapping per query
        # keyed by those string ids -- confirm against FaissIndex.
        index: FaissIndex = FaissIndex(d=vectors.shape[1])
        index.add_vectors(list(zip(list(map(str, ids)), vectors.cpu().numpy())), normalize=False)
        nearest_samples = index.search_by_vectors(query_vectors=cluster_means, k_most_similar=1, normalize=True)

        anchor_ids = [int(list(sample2score.keys())[0]) for sample2score in nearest_samples]
    elif anchor_choice == AnchorChoice.TOP_K:
        if limit is None:
            # Without a limit the slice below would silently cover every sample.
            raise ValueError("`limit` is required when anchor_choice is TOP_K")
        anchor_ids = random.sample(ids[:limit], num_anchors)
    else:
        raise NotImplementedError(f"Unsupported anchor choice: {anchor_choice!r}")

    return sorted(anchor_ids)
