import os
import pandas as pd
import re
import numpy as np
from loguru import logger

def parse_answer(answer):
    """
    Parses the answer string of the GSM8k benchmark and extracts a floating-point number.

    The text between the first '####' marker and the next one (or the end of the
    string) is taken as the final answer. Every character except digits and '.'
    is stripped from it, so e.g. '$1,234' parses as 1234.0. A leading '-' on that
    answer text is kept as the sign of the result.

    Args:
        answer (str): The answer string to be parsed.

    Returns:
        float or None: The extracted floating-point number if found, otherwise None.
        None is also returned for non-string input, when there is no '####'
        marker, or when the remaining characters do not form a valid number
        (e.g. '' , '.' or '1.2.3').
    """
    if not isinstance(answer, str) or '####' not in answer:
        return None
    floating = answer.split('####')[1]
    # remove all characters that are neither digits nor '.' (currency signs, commas, units, ...)
    floating_parsed = re.sub(r'[^\d.]', '', floating)
    try:
        value = float(floating_parsed)
    except ValueError:
        # empty string, a lone '.', or several '.' characters: not a number
        return None
    # the character filter above drops '-', so restore the sign explicitly
    if floating.strip().startswith('-'):
        value = -value
    return value

def parse_gsm8k_synthetic_correctness(data):
    """
        Parses the 'answer' field in the given data and checks if the parsed answer is an integer
        and does not involve rounding up or down.

        Args:
            data (pandas.DataFrame): The input data containing the 'answer' field.

        Returns:
            numpy.ndarray: A boolean array indicating whether each answer satisfies the conditions.
    """
    data['answer'] = data['doc'].apply(lambda x: x['answer'])
    data['parsed_answer'] = data['answer'].apply(lambda x: parse_answer(x))
    data['is_int'] = data['parsed_answer'].apply(lambda x: not np.isnan(x) and int(x) == x)
    # manual check of the data indicated that these words were only used when the model used rounding operations.
    data['rounding_up_or_down'] = data['answer'].apply(lambda x: "can't" in x or 'rounds' in x or "not possible" in x or "a fraction of" in x or "cannot" in x)
    return np.logical_and(data['is_int'], np.logical_not(data['rounding_up_or_down']))

def get_synthetic_omit(benchmark, base_path_data):
    """
    Read the per-sample 'omit' flags of a benchmark's synthetic dataset.

    Parameters:
    - benchmark (str): Benchmark name; used as the sub-directory of base_path_data.
    - base_path_data (str): Root directory that holds one folder per benchmark.

    Returns:
    - pandas.Series: The 'omit' column of <base_path_data>/<benchmark>/synthetic.csv.
    """
    csv_location = f"{base_path_data}/{benchmark}/synthetic.csv"
    synthetic_table = pd.read_csv(csv_location)
    omit_flags = synthetic_table['omit']
    return omit_flags

def load_result(base_path_eval, model, benchmark_name, metric, base_path_data):
    """
    Load and process results from a specified benchmark for a given model.

    The first '*jsonl' file (in sorted name order) inside
    <base_path_eval>/<model>/<benchmark_name> is read. Rows with a repeated 'doc_id'
    are dropped. For 'mathqa' benchmarks only the first 2000 rows are used. For
    '*_synthetic' benchmarks, rows are filtered with the 'omit' flags of the
    matching synthetic.csv. For 'gsm8k_synthetic', an answer-correctness check is
    also applied.

    Args:
        base_path_eval (str): The base path for evaluation results.
        model (str): The name of the model.
        benchmark_name (str): The name of the benchmark.
        metric (str): The name of the metric to extract from the results.
        base_path_data (str): The base path for data.

    Returns:
        numpy.ndarray or None: The processed results for the specified benchmark and model,
        or None (after logging a warning) if no results file was found.
    """
    folder = os.path.join(base_path_eval, model, benchmark_name)
    path = None
    if os.path.isdir(folder):
        # sort so the chosen file does not depend on the filesystem's listing order
        jsonl_files = sorted(name for name in os.listdir(folder) if name.endswith('jsonl'))
        if jsonl_files:
            path = os.path.join(folder, jsonl_files[0])
    if path is None or not os.path.isfile(path):
        logger.warning(f'Not able to read {benchmark_name} for {model} from {path}')
        return None
    # NOTE(review): the file is parsed as one JSON document (lines=False) despite the
    # .jsonl extension; presumably the sample files are JSON arrays — confirm.
    data = pd.read_json(path, lines=False)
    # remove duplicated rows; reset the index so row labels equal row positions,
    # which the positional masks below rely on
    data = data.drop_duplicates(subset=['doc_id']).reset_index(drop=True)
    if 'mathqa' in benchmark_name:
        data = data.iloc[:2000]

    keep_indices = None
    if benchmark_name == 'gsm8k_synthetic':
        keep_indices = np.asarray(parse_gsm8k_synthetic_correctness(data), dtype=bool)
    if '_synthetic' in benchmark_name:
        omitted = get_synthetic_omit(benchmark_name.replace('_synthetic', ''), base_path_data)
        # use plain arrays so masks are combined by position, not by pandas index labels
        not_omitted = np.logical_not(np.asarray(omitted))
        if keep_indices is None:
            keep_indices = not_omitted
        else:
            # NOTE(review): a row is kept if EITHER the gsm8k answer check passes OR the row
            # is not marked as omitted; confirm `or` (rather than `and`) is intended.
            keep_indices = np.logical_or(keep_indices, not_omitted)

    results = np.array(data[metric])
    if keep_indices is not None:
        results = results[keep_indices]
    return results

def load_results(benchmark_name, metric,
                 base_path_eval, model_name, ref_models, base_path_data):
    """
    Load results for a given benchmark and metric.

    All loaded result arrays (reference models and model_name) are truncated to the
    length of the shortest one, so they can be compared element-wise.

    Args:
        benchmark_name (str): The name of the benchmark.
        metric (str): The metric to evaluate the results.
        base_path_eval (str): The base path where the evaluation results are stored.
        model_name (str): The name of the model to load results for.
        ref_models (list): A list of reference models to compare against.
        base_path_data (str): The base path where the data is stored.

    Returns:
        tuple: A tuple containing two elements. The first is a dictionary with the reference
               results for each reference model whose results could be loaded. The second is
               the results for model_name, or None if they could not be loaded.
    """
    reference_results = dict()
    for model in ref_models:
        model_results = load_result(base_path_eval, model, benchmark_name, metric, base_path_data)
        if model_results is not None:
            reference_results[model] = model_results
    results = load_result(base_path_eval, model_name, benchmark_name, metric, base_path_data)

    lengths = [len(ref) for ref in reference_results.values()]
    # load_result returns None when the file is missing; don't call len() on it
    if results is not None:
        lengths.append(len(results))
    if lengths:
        min_length = min(lengths)
        reference_results = {model: ref[:min_length] for model, ref in reference_results.items()}
        if results is not None:
            results = results[:min_length]
    return reference_results, results


import numpy as np

def prepare_ref_results(scores_ref_models, scores_ref_models_ref_data):
    """
    Prepare reference results for contamination detection.

    Only models that appear as keys in both dictionaries are used. They are taken in
    the iteration order of `scores_ref_models`. Each model's scores become one row
    of the corresponding output array.

    Parameters:
    scores_ref_models (dict): Model name -> scores ("normal" data).
    scores_ref_models_ref_data (dict): Model name -> scores on the reference ("not normal") data.

    Returns:
    normal_here_ref (numpy.ndarray): Stacked scores from `scores_ref_models`.
    not_normal_here_ref (numpy.ndarray): Stacked scores from `scores_ref_models_ref_data`,
        row-aligned with `normal_here_ref`.
    """
    shared_models = [name for name in scores_ref_models if name in scores_ref_models_ref_data]
    normal_here_ref = np.array([scores_ref_models[name] for name in shared_models])
    not_normal_here_ref = np.array([scores_ref_models_ref_data[name] for name in shared_models])
    return normal_here_ref, not_normal_here_ref