import os
import json
import argparse
# from data.datasets import GeneralSeq2SeqDataset, get_all_labels
from metrics.classification_metrics import create_metric_f1_accuracy_chatgpt, create_metric_mae_rmse_chatgpt
from metrics.generation_metrics import create_metric_bleu_rouge_meteor_chatgpt

def get_all_labels(task):
    """Return the label set associated with a hyphenated LaMP task name.

    Args:
        task: Task identifier such as ``"LaMP-1"`` or ``"LaMP-2-movie"``.

    Returns:
        A newly built list of label strings for the task. ``"LaMP-4"``
        through ``"LaMP-7"`` map to an empty list. Any task name that is
        not listed here yields ``None``.
    """
    # The table is rebuilt on each call so callers always receive their own
    # list object and can mutate it without affecting later calls.
    labels_by_task = {
        "LaMP-1": ["[1]", "[2]"],
        "LaMP-2": [
            'women', 'religion', 'politics', 'style & beauty',
            'entertainment', 'culture & arts', 'sports',
            'science & technology', 'travel', 'business', 'crime',
            'education', 'healthy living', 'parents', 'food & drink',
        ],
        "LaMP-2-movie": [
            'sci-fi', 'based on a book', 'comedy', 'action', 'twist ending',
            'dystopia', 'dark comedy', 'classic', 'psychology', 'fantasy',
            'romance', 'thought-provoking', 'social commentary', 'violence',
            'true story',
        ],
        "LaMP-3": ["1", "2", "3", "4", "5"],
        "LaMP-4": [],
        "LaMP-5": [],
        "LaMP-6": [],
        "LaMP-7": [],
    }
    # dict.get returns None for unknown task names.
    return labels_by_task.get(task)

# Command-line entry point: read a JSONL file of model generations, pair each
# generation with its reference target, and print the metrics that match the
# selected task.
parser = argparse.ArgumentParser()
parser.add_argument('--task', default='LaMP_2', type=str, help='task')
parser.add_argument('--file_path', default='data/LaMP_2/generation/llama-3-8b-origin/0_dev_chat.jsonl', type=str, help='debug')
args = parser.parse_args()

# get_all_labels() is keyed by hyphenated names, while the CLI task names use
# underscores. Look the labels up once and reuse them below.
labels = get_all_labels(args.task.replace('_', '-'))

# Pick the metric function before touching the predictions file, so an
# unsupported task is reported as a usage error straight away instead of
# surfacing later as a NameError on `compute_metrics`.
if args.task in ('LaMP_1', 'LaMP_2', 'LaMP_2_movie'):
    compute_metrics = create_metric_f1_accuracy_chatgpt(all_labels=labels)
elif args.task == 'LaMP_3':
    compute_metrics = create_metric_mae_rmse_chatgpt(all_labels=labels)
elif args.task in ('LaMP_4', 'LaMP_5', 'LaMP_6', 'LaMP_7'):
    compute_metrics = create_metric_bleu_rouge_meteor_chatgpt()
else:
    parser.error(f"unsupported task: {args.task!r}")

# One JSON object per line. Blank lines (for example a trailing newline at the
# end of the file) are skipped rather than crashing json.loads. The explicit
# encoding keeps reading independent of the platform's default locale.
with open(args.file_path, 'r', encoding='utf-8') as f:
    predictions = [json.loads(line) for line in f if line.strip()]

preds = []
answers = []
for data in predictions:
    generation = data['generation']
    # When the generation contains an "A: " marker, remove every occurrence
    # of it and flatten the text onto one line. Generations without the
    # marker are passed through untouched, newlines included.
    if 'A: ' in generation:
        generation = generation.replace('A: ', '').replace('\n', '')
        data['generation'] = generation
    preds.append(generation)
    answers.append(data['target'])

result = compute_metrics(preds, answers)
print(result)