import json
import re
import pandas as pd
import yaml
import hotpot_evaluate

# Experiment settings are read from config.yml; the keys used in this file are
# 'model1', 'model2', 'eval_model' and 'persona'.
with open('config.yml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Result table: one row is appended per entry of the predictions file.
df = pd.DataFrame(columns=['worker_id', 'model1', 'model2', 'eval_model', 'fluency', 'helpfulness', 'ease', 'helpfulness_freetext', 'no_of_turns', 'accuracy', 'f1', 'recall'])

# The 'general' persona contributes nothing to the result file names; any
# other persona shows up as a '_<persona>' suffix.
persona_suffix = '' if config['persona'] == 'general' else '_' + config['persona']

# Conversations, re-keyed by their 'id' field for direct lookup.
conv_file_name = '../results/ambig_conversation_{model1}_{model2}{persona}_prompt-1.json'.format(
        model1=config['model1'],
        model2=config['model2'],
        persona=persona_suffix,
        )
# JSON is read as UTF-8 explicitly so decoding does not depend on the
# platform's locale encoding.
with open(conv_file_name, 'r', encoding='utf-8') as f:
    conv_data = {conv['id']: conv for conv in json.load(f)}

# Evaluator predictions for the same model pair / persona.
# (An earlier naming scheme was '../results/predictions_{model1}_{model2}_{eval_model}.json'.)
file_name = '../results/ambig_predictions_{model1}_{model2}_{eval_model}{persona}_prompt-1.json'.format(
        model1=config['model1'],
        model2=config['model2'],
        eval_model=config['eval_model'],
        persona=persona_suffix,
        )
with open(file_name, 'r', encoding='utf-8') as f:
    predictions = json.load(f)

# A single-digit score in parentheses preceded by a space, e.g. "Fluency (4): ...".
# Raw string: "\(" and "\d" are regex escapes, not valid string escapes.
_PAREN_SCORE_RE = re.compile(r' \((\d)\)')


def parse_line(line):
    """Extract the integer score from one line of evaluator output.

    Three layouts are recognised, tried in this order:

    * the line contains a colon and a space followed by a parenthesised
      single digit (``"Fluency (4): reads well"``) -> that digit;
    * the line contains a colon but no such group (``"Fluency: 4"``) ->
      the text after the last colon, stripped and parsed as an int;
    * the line contains no colon (``"Fluency (4)"``) -> the second-to-last
      character of the stripped line.

    Args:
        line: One line of the evaluator's response.

    Returns:
        The score as an int.

    Raises:
        ValueError: If the selected text is not an integer literal.
        IndexError: If a colon-free line has fewer than two characters
            after stripping.
    """
    if ':' not in line:
        # e.g. "... (4)": the digit sits just before the closing parenthesis.
        return int(line.strip()[-2])
    match = _PAREN_SCORE_RE.search(line)
    if match is not None:
        return int(match.group(1))
    return int(line.split(':')[-1].strip())

def _ask_score(prompt):
    """Prompt on stdin until an integer is entered, then return it.

    Re-prompting keeps a single typo from aborting the whole run.
    """
    while True:
        try:
            return int(input(prompt))
        except ValueError:
            print('Please enter an integer.')


def _strip_helpfulness_label(text):
    """Remove a leading 'helpfulness' label (case-insensitive) and surrounding whitespace."""
    lowered = text.lower()
    for label in ('helpfulness:', 'helpfulness (free-form): '):
        if lowered.startswith(label):
            return text[len(label):].strip()
    return text.strip()


def _gold_answers(annotations):
    """Return the distinct, stripped, non-empty answers found in *annotations*.

    An annotation either carries a list of 'qaPairs' (each with its own
    'answer' list) or a top-level 'answer' list.
    """
    collected = []
    for annotation in annotations:
        if 'qaPairs' in annotation:
            for qa_pair in annotation['qaPairs']:
                collected.extend(qa_pair['answer'])
        elif 'answer' in annotation:
            collected.extend(annotation['answer'])
    # Strip before de-duplicating so answers differing only in whitespace
    # collapse, and drop empty strings: '' is a substring of every user
    # answer and would otherwise force accuracy to 1.
    distinct = {answer.strip() for answer in collected}
    distinct.discard('')
    return sorted(distinct)


# Turn every evaluator prediction into one row of ``df``.
for idx, pred in enumerate(predictions):
    # Collapse every run of newlines (not just pairs) so that blank lines
    # never occupy one of the first three positions reserved for scores.
    pred_text = re.sub(r'\n+', '\n', pred['prediction'].strip()).split('\n')

    # The first three lines are expected to hold fluency, helpfulness and
    # ease scores, in that order.
    try:
        fluency, helpful, ease = pred_text[:3]
        print(fluency)
        print(helpful)
        print(ease)
        fluency = parse_line(fluency)
        helpful = parse_line(helpful)
        ease = parse_line(ease)
    except (ValueError, IndexError):
        # Fewer than three lines, or a line parse_line could not read.
        print(pred_text)
        if not ''.join(pred_text).strip():
            # Empty prediction: record zeros instead of asking.
            fluency, helpful, ease = 0, 0, 0
        else:
            # Unparseable but non-empty: fall back to manual entry.
            fluency = _ask_score('fluency: ')
            helpful = _ask_score('helpful: ')
            ease = _ask_score('ease: ')

    # Everything after the three score lines is the free-text helpfulness note.
    helpful_text = '\n'.join(pred_text[3:])
    print(helpful_text)
    helpful_text = _strip_helpfulness_label(helpful_text)

    # number of turns
    worker_id = pred['worker_id']
    line = conv_data[worker_id]
    no_of_turns = len(line['lm_responses'])
    print('# of turns:', no_of_turns)

    # accuracy: 1 if any gold answer occurs verbatim inside the user's answer;
    # F1 / recall: best score over all gold answers.
    user_answer = line['user_answer']
    flag = False
    max_f1 = 0
    max_recall = 0
    for golden in _gold_answers(line['annotations']):
        if golden in user_answer:
            flag = True
        f1, _, recall = hotpot_evaluate.f1_score(user_answer, golden)
        max_f1 = max(max_f1, f1)
        max_recall = max(max_recall, recall)
    acc = int(flag)
    print('Accuracy:', acc)
    print('F1:', max_f1)
    print('Recall:', max_recall)

    # NOTE(review): the 'worker_id' column receives the 1-based position in
    # the predictions list, not pred['worker_id'] -- confirm this is intended.
    row = [idx+1, config['model1'], config['model2'], config['eval_model'], fluency, helpful, ease, helpful_text, no_of_turns, acc, max_f1, max_recall]
    df.loc[len(df.index)] = row

# Write the collected rows to a CSV named after the model pair, evaluator
# model and (non-default) persona.
# (An earlier naming scheme was '../results/predictions_{model1}_{model2}_{eval_model}.csv'.)
if config['persona'] != 'general':
    csv_persona_part = '_' + config['persona']
else:
    # The 'general' persona adds no suffix to the file name.
    csv_persona_part = ''
csv_name_fields = {
    'model1': config['model1'],
    'model2': config['model2'],
    'eval_model': config['eval_model'],
    'persona': csv_persona_part,
}
file_name = '../results/ambig_predictions_{model1}_{model2}_{eval_model}{persona}_prompt-1.csv'.format(**csv_name_fields)
df.to_csv(file_name, index=False)

