import json
import os
from transformers import T5Tokenizer, AutoTokenizer

from transformers import PegasusTokenizer


# Common-sense generation, data-to-text, and summarization datasets
def tokenize_t5(ds_name, prefix="summarize: ", mname="t5-small"):
    """Tokenize the JSONL splits of a dataset and store the token ids.

    For each of ``test.jsonl``, ``val.jsonl`` and ``train.jsonl`` under
    ``jsonl_files/<ds_name>``, every record gets two extra keys:
    ``src_id`` (ids of ``prefix + record["source"]``) and ``tgt_id`` (ids of
    ``record["target"]``), both as returned by ``tokenizer.encode``. The
    augmented records are written, one JSON object per line, to
    ``tokenized_files/<ds_name>/<split>.t5.jsonl``.

    Splits are processed one at a time, so only a single split is held in
    memory. Blank lines in the input are skipped.

    Args:
        ds_name: Name of the dataset sub-directory.
        prefix: Text prepended to each ``source`` before encoding.
        mname: Model name or path passed to ``AutoTokenizer.from_pretrained``
            (loaded with ``local_files_only=True``).

    Raises:
        FileNotFoundError: If an input split file does not exist.
        KeyError: If a record lacks the ``source`` or ``target`` key.
    """
    tokenizer = AutoTokenizer.from_pretrained(mname, local_files_only=True)
    base_dir = f"jsonl_files/{ds_name}"
    tgt_dir = f"tokenized_files/{ds_name}"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(tgt_dir, exist_ok=True)

    for split in ("test", "val", "train"):
        src_path = os.path.join(base_dir, f"{split}.jsonl")
        dst_path = os.path.join(tgt_dir, f"{split}.t5.jsonl")

        # Explicit UTF-8: the output is dumped with ensure_ascii=False, so the
        # platform's default encoding must not decide how text is read/written.
        with open(src_path, encoding="utf-8") as f:
            insts = [json.loads(line) for line in f if line.strip()]

        for inst in insts:
            inst["src_id"] = tokenizer.encode(prefix + inst["source"])
            inst["tgt_id"] = tokenizer.encode(inst["target"])

        # The output is opened only after the whole split has been tokenized,
        # so a tokenization failure does not leave a truncated file behind.
        with open(dst_path, "w", encoding="utf-8") as f:
            for inst in insts:
                print(json.dumps(inst, ensure_ascii=False), file=f)


# Summarization experiments
def tokenize_pegasus(ds_name):
    """Tokenize the JSONL splits of a dataset with the Pegasus tokenizer.

    Loads ``PegasusTokenizer`` from ``google/pegasus-xsum`` and prints its
    bos/eos/pad token ids. Then, for each of ``test.jsonl``, ``val.jsonl`` and
    ``train.jsonl`` under ``jsonl_files/<ds_name>``, every record gets two
    extra keys: ``src_id`` (ids of ``record["source"]``) and ``tgt_id`` (ids
    of ``record["target"]``), both as returned by ``tokenizer.encode``. The
    augmented records are written, one JSON object per line, to
    ``tokenized_files/<ds_name>/<split>.pegasus.jsonl``.

    Splits are processed one at a time, so only a single split is held in
    memory. Blank lines in the input are skipped.

    Args:
        ds_name: Name of the dataset sub-directory.

    Raises:
        FileNotFoundError: If an input split file does not exist.
        KeyError: If a record lacks the ``source`` or ``target`` key.
    """
    tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
    print("bos", tokenizer.bos_token_id)
    print("eos", tokenizer.eos_token_id)
    print("pad", tokenizer.pad_token_id)
    base_dir = f"jsonl_files/{ds_name}"
    tgt_dir = f"tokenized_files/{ds_name}"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(tgt_dir, exist_ok=True)

    for split in ("test", "val", "train"):
        src_path = os.path.join(base_dir, f"{split}.jsonl")
        dst_path = os.path.join(tgt_dir, f"{split}.pegasus.jsonl")

        # Explicit UTF-8: the output is dumped with ensure_ascii=False, so the
        # platform's default encoding must not decide how text is read/written.
        with open(src_path, encoding="utf-8") as f:
            insts = [json.loads(line) for line in f if line.strip()]

        for inst in insts:
            inst["src_id"] = tokenizer.encode(inst["source"])
            inst["tgt_id"] = tokenizer.encode(inst["target"])

        # The output is opened only after the whole split has been tokenized,
        # so a tokenization failure does not leave a truncated file behind.
        with open(dst_path, "w", encoding="utf-8") as f:
            for inst in insts:
                print(json.dumps(inst, ensure_ascii=False), file=f)


# tokenize_pegasus("xsum")
# tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
# print("bos", tokenizer.bos_token_id)
# print("eos", tokenizer.eos_token_id)
# print("pad", tokenizer.pad_token_id)
# print(tokenizer.encode("my endless love."))

# tokenizer = T5Tokenizer.from_pretrained("t5-small")
# print(tokenizer.decode(2))
# print("bos", tokenizer.bos_token_id)
# print("eos", tokenizer.eos_token_id)
# print("pad", tokenizer.pad_token_id)
# print(tokenizer.encode("my endless love."))
# tokenize_pegasus("multi_news")
# tokenize_pegasus("cnndm")

# tokenize_t5("xsum")
# tokenize_t5("multi_news")
# tokenize_t5("cnndm")
# tokenize_t5("concode", '<java> ',"Salesforce/codet5-base")
# tokenize_t5("common_gen", "generate a sentence with: ")
# tokenize_t5("totto_meta", "convert the table to text: ")
# Tokenize the per-language code datasets with the CodeT5 tokenizer; each
# entry is (dataset name, prefix prepended to every source).
_CODE_DATASETS = (
    ("ruby", "<ruby> "),
    ("go", "<go> "),
    ("javascript", "<javascript> "),
    ("php", "<php> "),
)
for _ds_name, _ds_prefix in _CODE_DATASETS:
    tokenize_t5(_ds_name, _ds_prefix, "Salesforce/codet5-base")
