import sys
import os
import numpy as np

# Prepend the parent of this file's directory to the import path, so modules
# there take precedence over same-named installed packages. Nothing in the code
# visible here imports from it. NOTE(review): presumably other project modules
# rely on this side effect; confirm before removing.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

def get_top_k_indices_scores(sampled_score_matrix, sampled_node_indices, k):
    """Return the k highest-scoring columns of each row, excluding the row's own node.

    Row ``i`` of ``sampled_score_matrix`` holds scores against every node, with
    column index == node index. The row's own node is ``sampled_node_indices[i]``.
    That column is masked with ``-inf``, so it is only selected when ``k``
    equals the number of columns, and then it carries score ``-inf``.

    Args:
        sampled_score_matrix: 2-D array of scores, one row per sampled node.
            It is not modified.
        sampled_node_indices: sequence whose i-th entry is the column of row
            i's own node.
        k: number of entries to keep per row.

    Returns:
        ``(top_k_indices, top_k_scores)``: an int array and a float array, both
        of shape ``(num_rows, k)``. Each row is ordered by descending score.

    Raises:
        ValueError: if ``k`` exceeds the number of columns and there is at
            least one row. The sorted row is then too short to fill an
            output row.
    """
    num_sampled_nodes = sampled_score_matrix.shape[0]
    top_k_indices = np.zeros((num_sampled_nodes, k), dtype=int)
    top_k_scores = np.zeros((num_sampled_nodes, k))

    for i in range(num_sampled_nodes):
        # Copy the row. Basic slicing returns a view, so masking it in place
        # would corrupt the caller's matrix. Casting to float also lets
        # integer score matrices hold -inf.
        node_scores = np.array(sampled_score_matrix[i, :], dtype=float)
        # Exclude the node itself from its own top-k list.
        node_scores[sampled_node_indices[i]] = -np.inf
        top_nodes = np.argsort(-node_scores)[:k]  # descending by score

        top_k_indices[i, :] = top_nodes
        top_k_scores[i, :] = node_scores[top_nodes]

    return top_k_indices, top_k_scores


def dcg(scores):
    """Return the discounted cumulative gain of a ranked list of relevances.

    Uses the exponential-gain form ``sum((2**rel - 1) / log2(rank + 1))``
    with 1-based ranks. The first item is therefore divided by
    ``log2(2) == 1``. An empty sequence yields 0.0.
    """
    gains = []
    for rank, rel in enumerate(scores, start=1):
        gains.append((2**rel - 1) / np.log2(rank + 1))
    return np.sum(gains)


def calculate_ndcg(top_k_indices, top_k_scores, top_k_indices_private):
    """Score how well ``top_k_indices_private`` recovers the reference ranking, via NDCG.

    For each row ``i``:

    * Every index in ``top_k_indices_private[i]`` is assigned its score from the
      reference ranking (``top_k_scores[i]`` at the same position as in
      ``top_k_indices[i]``). Indices absent from the reference row get 0.
    * DCG is taken over those scores in the private order.
    * The result is normalised by the DCG of ``top_k_scores[i]`` sorted in
      descending order (the ideal DCG). When the ideal DCG is not positive,
      the row's NDCG is 0.

    Args:
        top_k_indices: 2-D array of reference indices, one row per sample.
        top_k_scores: scores aligned position-by-position with
            ``top_k_indices``.
        top_k_indices_private: per-row indices whose ordering is evaluated.

    Returns:
        ``(mean, std)`` of the per-row NDCG values. ``np.std`` uses its
        default ``ddof=0``.
    """
    num_samples, _ = top_k_indices.shape
    ndcg_scores = []  # NDCG value for each sample

    for i in range(num_samples):
        # Build an index -> reference-score map once per row, so each lookup
        # is O(1) instead of a linear scan of the row (previously O(k^2) per
        # row). setdefault keeps the FIRST occurrence, matching list.index().
        score_of = {}
        for idx, score in zip(top_k_indices[i], top_k_scores[i]):
            score_of.setdefault(idx, score)

        # DCG of the private ranking, gained with reference scores.
        predicted_scores = [score_of.get(idx, 0) for idx in top_k_indices_private[i]]
        dcg_val = dcg(predicted_scores)

        # Ideal DCG: reference scores in descending order.
        idcg_val = dcg(sorted(top_k_scores[i], reverse=True))

        ndcg_val = dcg_val / idcg_val if idcg_val > 0 else 0
        ndcg_scores.append(ndcg_val)

    avg_ndcg = np.mean(ndcg_scores)
    std_ndcg = np.std(ndcg_scores)

    return avg_ndcg, std_ndcg


def calculate_recall(top_k_indices, top_k_indices_private, r):
    """Return the mean and std of per-row recall@r of the private ranking.

    For each row ``i``:

    * The reference set is the first ``r`` entries of ``top_k_indices[i]``.
    * The retrieved set is the first ``r`` entries of
      ``top_k_indices_private[i]``.
    * Recall is ``|reference & retrieved|`` divided by the number of reference
      entries actually available, ``min(r, top_k_indices.shape[1])``.

    Dividing by the available count rather than ``r`` matters when ``r``
    exceeds the list length: two identical rankings then still score 1.0.

    Args:
        top_k_indices: 2-D array of reference indices, one row per sample.
        top_k_indices_private: per-row indices being evaluated.
        r: cut-off rank. Must be at least 1.

    Returns:
        ``(mean, std)`` of the per-row recall values. ``np.std`` uses its
        default ``ddof=0``.

    Raises:
        ValueError: if ``r < 1`` or ``top_k_indices`` has no columns, because
            recall would be undefined (division by zero) or negative.
    """
    num_samples, num_cols = top_k_indices.shape
    effective_r = min(r, num_cols)  # size of each reference prefix
    if effective_r < 1:
        raise ValueError(
            "r must be >= 1 and top_k_indices must have at least one column"
        )

    recall_scores = [
        len(set(top_k_indices[i, :r]) & set(top_k_indices_private[i, :r])) / effective_r
        for i in range(num_samples)
    ]

    avg_recall = np.mean(recall_scores)
    std_recall = np.std(recall_scores)

    return avg_recall, std_recall