# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0


import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, to_tree


def measure_ip_stability(dist_matrix, clustering):
    '''
    measures whether / to what extent a clustering is IP-stable and also returns closest cluster for every point

    A point in a cluster C with |C| > 1 is IP-stable if its average distance to the other points of C is at most its
    average distance to the points of every other cluster; points in singleton clusters are always stable.

    INPUT:
    dist_matrix: 2-dim array of size n x n, or condensed distance vector of length n*(n-1)/2 (it is not modified)
    clustering: 1-dim array (or sequence) of size n

    OUTPUT:
    is_stable: either True or False indicating whether clustering is stable
    nr_unstable: number of points that are not stable
    violation: 1-dim array of size n; violation[i] measures multiplicative violation of stability constraint for point i,
               i.e. the maximum over the other clusters C' of (average distance to own cluster) / (average distance to
               C'); it is np.inf if the average distance to some other cluster is 0 while the one to the own cluster is
               positive, and 0 for points in singleton clusters;
               violation[i]>=0 and violation[i]<=1 if and only if point i is stable
    closest_cluster: closest_cluster[i] is the cluster that point i is closest to: its own cluster if point i is
                     stable, otherwise the other cluster attaining violation[i]
    '''

    dist_matrix = np.asarray(dist_matrix)
    if dist_matrix.ndim == 1:
        dist_matrix = squareform(dist_matrix)
    clustering = np.asarray(clustering)

    n = dist_matrix.shape[0]
    clustering_unique, counts = np.unique(clustering, return_counts=True)
    nr_in_clusters = dict(zip(clustering_unique, counts))

    # row_sum_for_clusters[c][i] = sum of dist_matrix[j, i] over all points j in cluster c.
    # Boolean-mask indexing returns a copy, so -- unlike accumulating with += into a row view -- the caller's
    # dist_matrix is never modified in place.
    row_sum_for_clusters = {c: dist_matrix[clustering == c].sum(axis=0) for c in clustering_unique}

    nr_unstable = 0
    violation = np.zeros(n)
    closest_cluster = np.zeros(n)

    for ell in range(n):
        cl_index = clustering[ell]
        closest_cluster[ell] = cl_index
        nr_in_cluster = nr_in_clusters[cl_index]
        if nr_in_cluster <= 1:
            # points in singleton clusters are stable; violation stays 0
            continue

        # the point itself contributes distance 0, hence division by |C| - 1
        av_dist_own = row_sum_for_clusters[cl_index][ell] / (nr_in_cluster - 1)
        is_stable_point = True
        for jjj in clustering_unique:
            if jjj == cl_index:
                continue
            av_dist = row_sum_for_clusters[jjj][ell] / nr_in_clusters[jjj]
            if av_dist < av_dist_own:
                is_stable_point = False
            if av_dist > 0:
                ratio = av_dist_own / av_dist
            elif av_dist == 0 and av_dist_own > 0:
                # zero average distance to another cluster but positive to the own one: unbounded violation
                ratio = np.inf
            else:
                continue
            if ratio > violation[ell]:
                violation[ell] = ratio
                closest_cluster[ell] = jjj

        if is_stable_point:
            closest_cluster[ell] = cl_index
        else:
            nr_unstable += 1

    is_stable = (nr_unstable == 0)
    return is_stable, nr_unstable, violation, closest_cluster


########################################################################################################################


def compute_clustering_cost(dist_matrix, clustering):
    '''
    computes the cost of a clustering as defined in Eq. (4), together with a k-means-style cost

    INPUT:
    dist_matrix: 2-dim array of size n x n, or condensed distance vector of length n*(n-1)/2
    clustering: 1-dim array (or sequence) of size n

    OUTPUT:
    cost: sum over clusters C with |C|>1 of sum_{i,j in C} d(i,j) / (|C| (|C|-1)), i.e. of the average distance
          between two distinct points of C
    kmeans_cost: sum over clusters C with |C|>1 of sum_{i,j in C} d(i,j)^2 / |C|; the double sum runs over ordered
                 pairs, so for Euclidean distances this equals twice the within-cluster sum of squared distances to the
                 cluster means
    '''

    dist_matrix = np.asarray(dist_matrix)
    if dist_matrix.ndim == 1:
        dist_matrix = squareform(dist_matrix)
    # with a plain list, `clustering == label` would be a single bool and the cost would silently come out as 0
    clustering = np.asarray(clustering)

    cost = 0
    kmeans_cost = 0
    for ell in np.unique(clustering):
        in_cluster = (clustering == ell)
        size = np.sum(in_cluster)
        if size > 1:
            # singleton clusters contribute 0 to both costs
            within = dist_matrix[np.ix_(in_cluster, in_cluster)]
            cost = cost + np.sum(within) / (size * (size - 1))
            kmeans_cost = kmeans_cost + np.sum(within ** 2) / size

    return cost, kmeans_cost


########################################################################################################################


def DynamicProgramming_for_1Dimension_PInfinity(points, k, target_sizes):
    '''
    implementation of the dynamic programming approach of Section 5.1 / Appendix F for the case p=infinity

    Clusters are contiguous runs of the sorted points. The table entry T[i-1, j-1, l-1] holds the smallest value of
    max_m |target_sizes[m] - (size of cluster m)| over clusterings of the first i points into l contiguous clusters
    whose last cluster has size j and in which every pair of consecutive clusters satisfies the two boundary
    inequalities checked below (ineq1 / ineq2); it is np.inf if no such clustering exists.

    INPUT:
    points: SORTED 1-dim array of n points in 1D
    k: number of clusters
    target_sizes: 1-dim array of size k comprising target cluster sizes

    OUTPUT:
    clustering: 1-dim array of size n (float labels 0, ..., k-1, non-decreasing along the points)
    cl_sizes: 1-dim array of size k comprising cluster sizes for clustering
    '''

    n = points.size
    dmat = squareform(pdist(points.reshape((n, 1))))

    # building table
    T = np.inf * np.ones((n, n, k))
    # base case: the first ell+1 points form a single cluster
    for ell in np.arange(n):
        T[ell, ell, 0] = np.abs(ell + 1 - target_sizes[0])
    # base case: the first ell+1 points form ell+1 singleton clusters
    for ell in np.arange(k):
        T[ell, 0, ell] = np.amax(np.abs(1 - target_sizes[0:(ell + 1)]))

    # l = number of clusters, i = number of points covered, j = size of the last cluster (points i-j, ..., i-1),
    # s = size of the previous cluster (points i-j-s, ..., i-j-1)
    # ineq1: the last point of the previous cluster (index i-j-1) has an average distance to the other points of its
    #        own cluster that is at most its average distance to the last cluster (vacuous if s == 1)
    # ineq2: the first point of the last cluster (index i-j) has an average distance to the other points of its own
    #        cluster that is at most its average distance to the previous cluster (vacuous if j == 1)
    for l in np.arange(2, k + 1):
        for i in np.arange(l, n + 1):

            # last cluster is a singleton: only ineq1 has to be checked
            j = 1
            ineq1_lhs_UNNORM = 0
            ineq1_rhs = np.mean(dmat[(i - j) - 1, (i - j):i])
            ineq2_rhs_UNNORM = 0
            temp = np.inf

            s = 1
            temp = np.amin([temp, T[i - j - 1, s - 1, (l - 1) - 1]])
            for s in np.arange(2, i - j - (l - 2) + 1):
                ineq1_lhs_UNNORM = ineq1_lhs_UNNORM + dmat[(i - j) - 1, (i - j - s)]
                if (ineq1_lhs_UNNORM / (s - 1)) <= ineq1_rhs:
                    temp = np.amin([temp, T[i - j - 1, s - 1, (l - 1) - 1]])
            T[i - 1, j - 1, l - 1] = np.amax([np.abs(target_sizes[l - 1] - j), temp])

            # last cluster has at least two points: the l-1 previous clusters need at least l-1 points
            for j in np.arange(2, i - l + 1 + 1):
                ineq1_lhs_UNNORM = 0
                ineq1_rhs = np.mean(dmat[(i - j) - 1, (i - j):i])
                ineq2_lhs = np.mean(dmat[(i - j), (i - j + 1):i])
                ineq2_rhs_UNNORM = 0
                temp = np.inf

                s = 1
                ineq2_rhs_UNNORM = ineq2_rhs_UNNORM + dmat[i - j, (i - j - s)]
                if ineq2_lhs <= (ineq2_rhs_UNNORM / s):
                    temp = np.amin([temp, T[i - j - 1, s - 1, (l - 1) - 1]])
                for s in np.arange(2, i - j - (l - 2) + 1):
                    ineq1_lhs_UNNORM = ineq1_lhs_UNNORM + dmat[(i - j) - 1, (i - j - s)]
                    ineq2_rhs_UNNORM = ineq2_rhs_UNNORM + dmat[i - j, (i - j - s)]
                    if ((ineq1_lhs_UNNORM / (s - 1)) <= ineq1_rhs) and (ineq2_lhs <= (ineq2_rhs_UNNORM / s)):
                        temp = np.amin([temp, T[i - j - 1, s - 1, (l - 1) - 1]])
                T[i - 1, j - 1, l - 1] = np.amax([np.abs(target_sizes[l - 1] - j), temp])

    # recovering clustering from table: fix the size of the last cluster from the optimal entry, then walk backwards
    # and, for each cluster ell, choose the smallest size of cluster ell-1 that satisfies ineq1 / ineq2 and whose
    # table entry does not exceed the optimal value v_star
    cl_sizes = np.zeros(k, dtype=int)
    cl_sizes[-1] = np.argmin(T[n - 1, :, k - 1]) + 1
    v_star = np.amin(T[n - 1, :, k - 1])
    sum_current_cl_sizes = cl_sizes[-1]

    for ell in np.arange(k - 1, 1, -1):

        ineq1_lhs_UNNORM = 0
        ineq1_rhs = np.mean(
            dmat[(n - sum_current_cl_sizes) - 1, (n - sum_current_cl_sizes):(n - sum_current_cl_sizes + cl_sizes[ell])])
        if cl_sizes[ell] > 1:
            ineq2_lhs = np.mean(dmat[(n - sum_current_cl_sizes),
                                (n - sum_current_cl_sizes + 1):(n - sum_current_cl_sizes + cl_sizes[ell])])
        ineq2_rhs_UNNORM = 0

        mmm = 1
        if cl_sizes[ell] == 1:
            if T[n - sum_current_cl_sizes - 1, mmm - 1, ell - 1] <= v_star:
                cl_sizes[ell - 1] = mmm
                sum_current_cl_sizes += mmm
                continue
        else:
            ineq2_rhs_UNNORM = ineq2_rhs_UNNORM + dmat[(n - sum_current_cl_sizes), (n - sum_current_cl_sizes - mmm)]
            if ineq2_lhs <= (ineq2_rhs_UNNORM / mmm):
                if T[n - sum_current_cl_sizes - 1, mmm - 1, ell - 1] <= v_star:
                    cl_sizes[ell - 1] = mmm
                    sum_current_cl_sizes += mmm
                    continue

        for mmm in np.arange(2, n - sum_current_cl_sizes):
            if cl_sizes[ell] == 1:
                ineq1_lhs_UNNORM = ineq1_lhs_UNNORM + dmat[
                    (n - sum_current_cl_sizes) - 1, (n - sum_current_cl_sizes - mmm)]
                if (ineq1_lhs_UNNORM / (mmm - 1)) <= ineq1_rhs:
                    if T[n - sum_current_cl_sizes - 1, mmm - 1, ell - 1] <= v_star:
                        cl_sizes[ell - 1] = mmm
                        sum_current_cl_sizes += mmm
                        break
            else:
                ineq1_lhs_UNNORM = ineq1_lhs_UNNORM + dmat[
                    (n - sum_current_cl_sizes) - 1, (n - sum_current_cl_sizes - mmm)]
                ineq2_rhs_UNNORM = ineq2_rhs_UNNORM + dmat[(n - sum_current_cl_sizes), (n - sum_current_cl_sizes - mmm)]
                if ((ineq1_lhs_UNNORM / (mmm - 1)) <= ineq1_rhs) and (ineq2_lhs <= (ineq2_rhs_UNNORM / mmm)):
                    if T[n - sum_current_cl_sizes - 1, mmm - 1, ell - 1] <= v_star:
                        cl_sizes[ell - 1] = mmm
                        sum_current_cl_sizes += mmm
                        break

    # the first cluster takes the remaining points; exclude cl_sizes[0] itself from the sum, since for k == 1 it
    # already holds n (it is also cl_sizes[-1]) and would otherwise be reset to 0
    cl_sizes[0] = n - np.sum(cl_sizes[1:])

    # sizes of the clusters define the clustering
    clustering = np.zeros(n)
    clustering[0:cl_sizes[0]] = 0
    for ell in np.arange(1, k):
        clustering[np.sum(cl_sizes[:ell]):np.sum(cl_sizes[:ell + 1])] = ell

    return clustering, cl_sizes


########################################################################################################################


def linkage_clustering_with_greedy_pruning(dist_matrix, k, method='single', criterion='number of points'):
    '''
    implementation of the heuristic approach of Section 6.1 / Appendix G.3: first runs linkage clustering to produce the
    whole linkage tree and then splits clusters (beginning from the root of the tree, which corresponds to one cluster
    comprising all data points) in a greedy fashion; when we have l clusters, it splits one of these such that the
    resulting (l+1)-clustering ...

        *) if criterion='number of points':
            ... has the smallest (over the choice of the cluster to split) number of points that are not stable
        *) if criterion='violation' (any value other than 'number of points' is treated this way)
            ... has the smallest (over the choice of the cluster to split) maximum (over the data points) stability
                violation

    Ties are broken in favour of the cluster with the smallest label.

    INPUT:
    dist_matrix: 2-dim array of size n x n, or condensed distance vector of length n*(n-1)/2
    k: number of clusters; for k == 1 all points are put into cluster 0
    method: which linkage algorithm to use --- see

            https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html

            for available options
    criterion: if 'number of points', strategy aims at minimizing the number of points that are not stable; if
               'violation', strategy aims at minimizing the maximum violation

    OUTPUT:
    clustering: 1-dim array of size n

    RAISES:
    ValueError: if k < 1, or if k clusters cannot be reached because all current clusters are singletons
    '''

    dist_matrix = np.asarray(dist_matrix)
    s = dist_matrix.shape
    if len(s) == 2:
        n = s[0]
        dist_matrix = squareform(dist_matrix)
    else:
        # condensed length is n*(n-1)/2, and sqrt(n*(n-1)) lies strictly between n-1 and n for n >= 2
        n = int(np.ceil(np.sqrt(s[0] * 2)))

    if k < 1:
        raise ValueError('k must be at least 1')
    if k == 1:
        return np.zeros(n)

    Z = linkage(dist_matrix, method=method)

    # clusters[t] is the tree node whose leaves carry label t in `clustering`
    root = to_tree(Z)
    clusters = [root.left, root.right]
    clustering = np.zeros(n)
    clustering[clusters[1].pre_order()] = 1
    kk = 2

    for _ in range(k - 2):

        best_score = None
        best_to_split_ind = None
        clustering_best = None

        for ttt in range(kk):
            # a singleton cluster is a leaf of the tree and cannot be split
            if np.sum(clustering == ttt) == 1:
                continue
            # splitting: the right child becomes the new cluster kk, the left child keeps label ttt
            clustering_new = clustering.copy()
            clustering_new[clusters[ttt].right.pre_order()] = kk
            # dist_matrix is in condensed form at this point
            _, nr_unstable, violation, _ = measure_ip_stability(dist_matrix, clustering_new)
            if criterion == 'number of points':
                score = nr_unstable
            else:
                score = np.amax(violation)
            # strict < keeps the first candidate among ties; the None check guarantees a choice even if every
            # candidate has an infinite score
            if best_score is None or score < best_score:
                best_score = score
                clustering_best = clustering_new
                best_to_split_ind = ttt

        if best_to_split_ind is None:
            raise ValueError('cannot form k clusters: all current clusters are singletons')

        best_to_split = clusters[best_to_split_ind]
        clusters[best_to_split_ind] = best_to_split.left
        clusters.append(best_to_split.right)
        clustering = clustering_best
        kk += 1

    return clustering
