import json
import os


def load_jsonl(path):
    """Read a JSON Lines file into a list of decoded objects.

    Lines that are empty or contain only whitespace (for example a trailing
    blank line at the end of the file) are skipped rather than being passed
    to ``json.loads``, which would raise on them.

    Args:
        path: Location of the JSON Lines file to read.

    Returns:
        A list holding one decoded JSON value per non-blank line, in file
        order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default text encoding.
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def write_jsonl(insts, path):
    """Write every item of ``insts`` to ``path`` as one JSON document per line.

    The file is opened in write mode, so any existing content is replaced.

    Args:
        insts: Iterable of JSON-serializable values.
        path: Destination file path.
    """
    with open(path, "w") as out:
        # The generator keeps serialization lazy: records are encoded one at
        # a time as they are written.
        out.writelines(json.dumps(record) + "\n" for record in insts)


def prep_common_gen(path_list):
    """Rewrite each JSONL file in ``path_list`` in place with a ``source`` field.

    For every record the ``concepts`` entry is removed and its items are
    joined with ``", "`` to form ``source``. Each record must already carry a
    string ``target``; this is checked with an ``assert``.

    Args:
        path_list: Paths of the JSONL files to convert (overwritten).
    """
    for jsonl_path in path_list:
        records = load_jsonl(jsonl_path)
        for record in records:
            # pop() both reads and drops the old key in one step.
            record["source"] = ", ".join(record.pop("concepts"))
            assert isinstance(record["target"], str)
        write_jsonl(records, jsonl_path)


def prep_e2e_nlg(path_list):
    """Rename the fields of each JSONL file in ``path_list`` in place.

    ``meaning_representation`` becomes ``source`` and ``human_reference``
    becomes ``target``; the original keys are removed. The resulting
    ``target`` must be a string, which is checked with an ``assert``.

    Args:
        path_list: Paths of the JSONL files to convert (overwritten).
    """
    for jsonl_path in path_list:
        records = load_jsonl(jsonl_path)
        for record in records:
            record["source"] = record.pop("meaning_representation")
            record["target"] = record.pop("human_reference")
            assert isinstance(record["target"], str)
        write_jsonl(records, jsonl_path)


def prep_wiki_bio(path_list):
    """Convert each JSONL file in ``path_list`` in place to source/target form.

    ``input_text`` is JSON-encoded into the string field ``source`` and
    ``target_text`` is moved to ``target``; the original keys are removed.
    The resulting ``target`` must be a string, which is checked with an
    ``assert``.

    Args:
        path_list: Paths of the JSONL files to convert (overwritten).
    """
    for jsonl_path in path_list:
        records = load_jsonl(jsonl_path)
        for record in records:
            record["source"] = json.dumps(record.pop("input_text"))
            record["target"] = record.pop("target_text")
            assert isinstance(record["target"], str)
        write_jsonl(records, jsonl_path)


def prep_totto(path_list):
    """Replace each JSONL file in ``path_list`` with source/target records.

    Every output record has exactly two keys: ``source`` is the JSON encoding
    of the input record's ``table`` and ``target`` is the space-joined items
    of ``sentence_annotations["final_sentence"]``. All other input fields are
    dropped.

    Args:
        path_list: Paths of the JSONL files to convert (overwritten).
    """
    for jsonl_path in path_list:
        converted = []
        for record in load_jsonl(jsonl_path):
            entry = {
                "source": json.dumps(record["table"]),
                "target": " ".join(record["sentence_annotations"]["final_sentence"]),
            }
            assert isinstance(entry["target"], str)
            converted.append(entry)
        write_jsonl(converted, jsonl_path)


def reformat_multi_news():
    """Convert the parallel source/target text files under
    ``datasets_all/multi_news`` into JSONL files.

    For each split the i-th line of the source file is paired with the i-th
    line of the target file, newline characters are removed from both, and
    the pair is stored as ``{"source": ..., "target": ...}``. If the two
    files differ in length, the extra lines of the longer one are ignored
    (``zip`` semantics).

    The outputs ``train.jsonl``, ``val.jsonl`` and ``test.jsonl`` are written
    relative to the current working directory, not into the input directory.

    Raises:
        OSError: If any of the six input files cannot be opened. All inputs
            are read before anything is written, so a missing input leaves
            no partial output behind.
    """
    base_dir = "datasets_all/multi_news"
    # (source file, target file, output file) per split, in write order.
    splits = (
        ("train.src.processed", "train.tgt", "train.jsonl"),
        ("val.src.processed", "val.tgt", "val.jsonl"),
        ("test.src.processed", "test.tgt", "test.jsonl"),
    )

    def _read_lines(name):
        # Context manager guarantees the handle is closed after reading.
        with open(os.path.join(base_dir, name)) as f:
            return f.readlines()

    loaded = [
        (_read_lines(src_name), _read_lines(tgt_name), out_name)
        for src_name, tgt_name, out_name in splits
    ]

    for src_lines, tgt_lines, out_name in loaded:
        new_insts = [
            {"source": src.replace("\n", ''), "target": tgt.replace("\n", '')}
            for src, tgt in zip(src_lines, tgt_lines)
        ]
        write_jsonl(new_insts, out_name)


def reformat_cnndm():
    """Convert the ``data_cased/*.id.jsonl`` files into source/target JSONL.

    Each input record's ``text`` becomes ``source`` and its ``summary``
    becomes ``target``. Results are written to ``jsonl_files/cnndm`` as
    ``test.jsonl``, ``val.jsonl`` and ``train.jsonl``.
    """
    tgt_dir = "jsonl_files/cnndm"
    # (input path, output file name) in the order the splits are handled.
    splits = (
        ("data_cased/test.id.jsonl", "test.jsonl"),
        ("data_cased/val.id.jsonl", "val.jsonl"),
        ("data_cased/train.id.jsonl", "train.jsonl"),
    )

    # Every split is loaded before the first output file is written.
    loaded = [(load_jsonl(in_path), out_name) for in_path, out_name in splits]

    for records, out_name in loaded:
        pairs = [
            {"source": record["text"], "target": record["summary"]}
            for record in records
        ]
        write_jsonl(pairs, os.path.join(tgt_dir, out_name))


def reformat_totto(path_list):
    """Build source/target JSONL files from each file in ``path_list``.

    Every output record keeps ``example_id``, uses ``subtable_metadata_str``
    as ``source`` and the space-joined items of
    ``sentence_annotations["final_sentence"]`` as ``target``.

    The output path is the input path with every occurrence of ``"totto"``
    replaced by ``"totto_meta"``; the input files themselves are left
    untouched unless the path does not contain ``"totto"``.

    Args:
        path_list: Paths of the JSONL files to convert.

    Raises:
        KeyError: If a record lacks one of the required fields.
    """
    for path in path_list:
        # NOTE: a stray ``exit(0)`` used to sit inside this loop and ended the
        # process on the first record, so no output file was ever written.
        new_insts = [
            {
                "example_id": inst["example_id"],
                "source": inst["subtable_metadata_str"],
                "target": " ".join(inst["sentence_annotations"]["final_sentence"]),
            }
            for inst in load_jsonl(path)
        ]
        write_jsonl(new_insts, path.replace("totto", "totto_meta"))


def reformat_wikibio():
    """Convert the ``.box`` / ``.summary`` files under
    ``jsonl_files/wiki_bio`` into JSONL files in the same directory.

    Each split's lines are passed to ``_prepare_wiki_bio`` and the resulting
    records are written to ``test.jsonl``, ``train.jsonl`` and ``val.jsonl``
    (in that order).

    Raises:
        OSError: If any of the six input files cannot be opened. All inputs
            are read before anything is written.
    """
    base_dir = "jsonl_files/wiki_bio"

    def _read_lines(name):
        # Context manager guarantees the handle is closed after reading.
        with open(os.path.join(base_dir, name)) as f:
            return f.readlines()

    train_src_lines = _read_lines("train.box")
    train_tgt_lines = _read_lines("train.summary")
    val_src_lines = _read_lines("val.box")
    val_tgt_lines = _read_lines("val.summary")
    test_src_lines = _read_lines("test.box")
    test_tgt_lines = _read_lines("test.summary")

    outputs = (
        (test_src_lines, test_tgt_lines, "test.jsonl"),
        (train_src_lines, train_tgt_lines, "train.jsonl"),
        (val_src_lines, val_tgt_lines, "val.jsonl"),
    )
    for src_lines, tgt_lines, out_name in outputs:
        new_insts = _prepare_wiki_bio(src_lines, tgt_lines)
        write_jsonl(new_insts, os.path.join(base_dir, out_name))


def _prepare_wiki_bio(mrs, refs):
    """Turn tab-separated ``name:value`` lines into ``key[value]`` strings.

    Each line of ``mrs`` is split on tabs into ``name:value`` items. The name
    is split on ``"_"`` and, when it has more than one part, its last part is
    dropped (presumably a position index such as the ``1`` in ``name_1`` --
    TODO confirm against the data) and the rest is joined with spaces to form
    the key. Values sharing a key are joined with spaces, and the line is
    rendered as ``"key[value], key[value], ..."``.

    Items are skipped when the value contains ``"<none>"`` or ``"px"``, when
    one of the name parts is exactly ``"image"``, or when the item has no
    ``":"`` at all (for example on an empty line).

    Only the first ``":"`` separates name from value, so values that contain
    colons themselves (URLs, times) are kept intact.

    Args:
        mrs: Iterable of source lines.
        refs: Iterable of reference lines, parallel to ``mrs``. Pairs are
            formed with ``zip``, so extra items of the longer one are ignored.

    Returns:
        A list of ``{"source": ..., "target": ...}`` dicts, one per pair,
        where ``target`` is the stripped reference line.
    """
    e2e_like_list = []  # to run for T5 model

    for mr, ref in zip(mrs, refs):
        fields = {}
        for item in mr.strip().split("\t"):
            # partition() splits on the first colon only and never raises.
            name, sep, value = item.partition(":")
            if not sep:
                # No colon: not a name/value item, nothing to record.
                continue
            name_parts = name.split("_")
            if "<none>" in value or "image" in name_parts or "px" in value:
                continue
            if len(name_parts) == 1:
                key = name_parts[0]
            else:
                key = " ".join(name_parts[:-1])
            fields.setdefault(key, []).append(value)

        rendered = []
        for key, values in fields.items():
            joined = " ".join(values)
            rendered.append(f"{key}[{joined}]")

        e2e_like_list.append({"source": ", ".join(rendered), "target": ref.strip()})

    return e2e_like_list


def reformat_concode(file_paths):
    """Rename the fields of each JSONL file in ``file_paths`` in place.

    ``nl`` becomes ``source`` and ``code`` becomes ``target``; the original
    keys are removed and all other fields are kept.

    Args:
        file_paths: Paths of the JSONL files to convert (overwritten).
    """
    for jsonl_path in file_paths:
        records = load_jsonl(jsonl_path)
        for record in records:
            record["source"] = record.pop("nl")
            record["target"] = record.pop("code")
        write_jsonl(records, jsonl_path)


def reformat_code_sum(file_paths):
    """Replace each JSONL file in ``file_paths`` with source/target records.

    ``source`` is the space-joined ``code_tokens`` and ``target`` the
    space-joined ``docstring_tokens`` of each input record. All other input
    fields are dropped.

    Args:
        file_paths: Paths of the JSONL files to convert (overwritten).
    """
    for jsonl_path in file_paths:
        converted = [
            {
                "source": " ".join(record["code_tokens"]),
                "target": " ".join(record["docstring_tokens"]),
            }
            for record in load_jsonl(jsonl_path)
        ]
        write_jsonl(converted, jsonl_path)


if __name__ == '__main__':
    # Script entry point: reformat the code-summarization JSONL files of each
    # listed dataset in place. The commented-out lines are earlier invocations
    # kept for reference.
    # files = ["val.jsonl", "test.jsonl", "train.jsonl"]
    # base_path = "jsonl_files/python"
    # files = [os.path.join(base_path, file) for file in files]
    # reformat_code_sum(files)
    datasets = ["ruby", "go", "php", "javascript"]
    files = ["valid.jsonl", "test.jsonl", "train.jsonl"]

    for dataset in datasets:
        base_path = f"jsonl_files/{dataset}"
        # Build a fresh list for every dataset. Reassigning ``files`` here
        # would make later iterations join onto already-prefixed paths.
        dataset_files = [os.path.join(base_path, name) for name in files]
        reformat_code_sum(dataset_files)
    # prep_common_gen([files[2]])
    # reformat_totto(files)
    # reformat_cnndm()
