"""
An implementation of a LIBRO-like operation
THIS IS NOT FINISHED YET, EXPECT WEIRD BEHAVIOR

Takes several samples for unittest generation,
looks at their evaluation trace and picks the one that most closely
resembles the issue description (picked by LLM or so)
"""
import difflib
import re
from collections import defaultdict
from typing import List, Optional
import json
import pathlib

import fire
from unidiff import PatchSet

from datasets import load_from_disk, Dataset, DatasetDict
from measure_coverage_patch import extract_good_case_from_eval_output, log, extract_coverages_from_eval_output, \
    extract_patch_from_eval_output, load_eval_outputs

# Prompt template filled in via str.format with two positional fields:
# the first `{}` receives the extracted user issue, the second `{}` receives
# the diff between the post-init and post-patch execution traces.
PROMPT = """
You are an automated expert software engineer working on a project. Below is a user issue in a repository.
{}

Another agent has generated a test case that tries to encapsulate the user issue.
The test suite of the repository was executed before and after adding the test case.
The difference between the execution traces is shown below:
```trace
{}
```

You are an automated expert software engineer working on a project. Above is a user issue in a repository.
Please look at the generated test case and the execution trace of running the test case on the current repository.
Please answer whether the test case accurately tests the issue described by the user.
Please answer with "yes" or "no".
"""


def extract_issue_from_text(text: List[str]) -> Optional[str]:
    """Extract the issue section from a list of text lines.

    The section starts at the first line containing ``<issue>`` and ends just
    before the first subsequent line containing ``</issue>``. The line holding
    the opening tag is part of the result; the line holding the closing tag is
    not.

    Args:
        text: The document, already split into lines.

    Returns:
        The selected lines joined with newlines, or ``None`` when the opening
        tag is missing or no closing tag follows it.
    """
    start = next((idx for idx, line in enumerate(text) if "<issue>" in line), None)
    if start is None:
        return None
    # Enumerate with ``start=start`` so ``end`` is an absolute index into
    # ``text`` and can be used directly as the slice bound.
    end = next(
        (idx for idx, line in enumerate(text[start:], start=start) if "</issue>" in line),
        None,
    )
    if end is None:
        return None
    return "\n".join(text[start:end])

def extract_execution_trace_from_eval_output(eval_output: str, pre_text: str) -> Optional[List[str]]:
    """Extract the execution trace that follows a marker line in an eval log.

    The trace starts at the first line that begins with ``pre_text`` and is
    immediately followed by a line beginning with ``"Test Script:"``. It ends
    at the first line from there on that begins with ``"Coverage Script: "``.

    Args:
        eval_output: Full evaluation log as a single string.
        pre_text: Prefix of the marker line that precedes the trace.

    Returns:
        The lines from the ``"Test Script:"`` line up to, but excluding, the
        line directly before the ``"Coverage Script: "`` line. ``None`` when
        the marker pair is not found (including for empty input).
    """
    lines = eval_output.splitlines()

    start = None
    # Stop one short of the end: the marker needs a following line to check.
    for idx in range(len(lines) - 1):
        if lines[idx].startswith(pre_text) and lines[idx + 1].startswith("Test Script:"):
            start = idx
            break
    if start is None:
        return None

    # NOTE(review): when no "Coverage Script: " line exists the last line index
    # is used as the end marker, which drops the final two lines of the log.
    # This mirrors the previous behavior -- confirm whether it is intended.
    end = len(lines) - 1
    for idx in range(start, len(lines)):
        if lines[idx].startswith("Coverage Script: "):
            end = idx
            break
    return lines[start + 1:end - 1]

def extract_exec_trace_after_applied_patch(eval_output: str) -> List[str]:
    """Return the trace that follows the ">>>>> Applied Patch" marker line.

    Delegates to extract_execution_trace_from_eval_output with that prefix.
    """
    marker = ">>>>> Applied Patch"
    return extract_execution_trace_from_eval_output(eval_output, marker)

def extract_exec_trace_after_init(eval_output: str) -> List[str]:
    """Return the trace that follows the ">>>>> Init Succeeded" marker line.

    Delegates to extract_execution_trace_from_eval_output with that prefix.
    """
    marker = ">>>>> Init Succeeded"
    return extract_execution_trace_from_eval_output(eval_output, marker)

def _parse_seeds(seeds) -> List[int]:
    """Normalize the ``seeds`` argument into a list of ints.

    Accepts the documented comma-separated string as well as a single int or
    an iterable of ints, because python-fire converts command-line values such
    as ``1,2,3`` or ``3`` into Python literals before calling ``main``.
    """
    if isinstance(seeds, str):
        return [int(s) for s in seeds.split(",")]
    if isinstance(seeds, int):
        return [seeds]
    return [int(s) for s in seeds]


def _render_trace_diff(trace_before: List[str], trace_after: List[str]) -> str:
    """Render a unified diff of two newline-free traces as one multi-line string."""
    # The traces come from str.splitlines() and carry no trailing newlines, so
    # the control lines must not get one either (lineterm="") and every diff
    # line is joined explicitly with "\n".
    return "\n".join(difflib.unified_diff(trace_before, trace_after, lineterm=""))


def main(
    eval_output_dir: str = "./evaluation_output/gpt-4-1106-preview__swt_bench_lite_aug1_bm25_27k_cl100k__seed=3,temperature=07__test/mode_custom",
    seeds: str = "1,2,3,4,5",
    dataset: str = "./datasets/swt_bench_lite_aug1_bm25_diff_27k_cl100k",
    out_dataset: str = "./datasets/libro_gpt-4-1106-preview__swt_bench_lite_aug1__test",
    split: str = "test",
    log: callable = log,
):
    """Build a dataset of judge prompts from per-seed evaluation outputs.

    For every example in ``dataset[split]`` and every seed whose evaluation
    output reports an initially failing test, a new example is created whose
    ``text`` is ``PROMPT`` filled with the user issue and the diff between the
    post-init and post-patch execution traces.

    Args:
        eval_output_dir: Path template of an evaluation output directory. The
            ``seed=<n>`` and ``00<n>__`` parts are substituted for each seed.
        seeds: Comma-separated seed numbers (an int or iterable of ints is
            accepted as well).
        dataset: Path of the source dataset, loaded with ``load_from_disk``.
        out_dataset: Path the resulting ``DatasetDict`` is saved to.
        split: Split of the source dataset to read; also the only split of the
            saved dataset.
        log: Callable that receives a dict for every skipped instance.
    """
    seeds = _parse_seeds(seeds)
    dataset = load_from_disk(dataset)

    eval_output_by_instance = {}
    for seed in seeds:
        seed_dir = re.sub(r"seed=\d+", f"seed={seed}", eval_output_dir)
        seed_dir = re.sub(r"00\d+__", f"00{seed}__", seed_dir)
        eval_output_by_instance[seed] = load_eval_outputs(seed_dir)

    new_examples = []
    for example in dataset[split]:
        instance_id = example["instance_id"]
        outputs_by_seed = [
            (seed, eval_output_by_instance[seed].get(instance_id)) for seed in seeds
        ]
        if not any(output for _, output in outputs_by_seed):
            log({
                "instance_id": instance_id,
                "message": "no eval output found",
            })
            continue

        user_issue = extract_issue_from_text(example["text"].splitlines())
        if user_issue is None:
            # Without an issue the prompt would contain the literal "None".
            log({
                "instance_id": instance_id,
                "message": "no issue found in example text",
            })
            continue

        for seed, eval_output in outputs_by_seed:
            if eval_output is None:
                continue
            _, _, fails_initially, _, _ = extract_good_case_from_eval_output(eval_output)
            if not fails_initially:
                continue
            execution_trace_after_init = extract_exec_trace_after_init(eval_output)
            execution_trace_after_patch = extract_exec_trace_after_applied_patch(eval_output)
            if execution_trace_after_init is None or execution_trace_after_patch is None:
                continue
            unittest_patch = extract_patch_from_eval_output(eval_output)

            trace_diff = _render_trace_diff(execution_trace_after_init, execution_trace_after_patch)

            new_examples.append({
                **example,
                "text": PROMPT.format(user_issue, trace_diff),
                "execution_trace_after_init": execution_trace_after_init,
                "execution_trace_after_patch": execution_trace_after_patch,
                "unittest_patch": unittest_patch,
                # Tag with the actual seed value, not its position in ``seeds``.
                "instance_id": f"{instance_id}_seed={seed}",
            })

    ds_l = Dataset.from_list(new_examples)
    ds = DatasetDict({split: ds_l})
    ds.save_to_disk(out_dataset)


        


# Script entry point: hand main to python-fire so it can be invoked from the
# command line.
if __name__ == "__main__":
    fire.Fire(main)
