import json
import pandas as pd
import yaml

# Run configuration (this script reads model1, model2, persona and eval_model).
# The path is relative to the current working directory. Decode explicitly as
# UTF-8 (the YAML default) rather than relying on the platform locale.
with open('config.yml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

from tqdm import tqdm

import openai_api
import anthropic_api
from persona import persona_eval, persona_eval_sys

# System message sent as the first chat turn to the evaluator model.
sys_message = "You are a helpful and precise assistant for checking the quality of the AI assistant's responses in conversations."

# Default evaluation rubric (used for the 'general' persona). One metric per
# line; the final line asks the evaluator to answer line-by-line.
sys_prompt = '\n'.join([
    'Please evaluate the above conversations between user and AI assistant by using the following metrics:',
    'Fluency (5-point Likert): How clear (or fluent) were the responses from the AI Assistant?',
    'Helpfulness (5-point Likert): Independent of its fluency, how helpful was the AI Assistant to the user?',
    'Ease of interaction (5-point Likert): How easy was it to interact with the AI Assistant?',
    'Helpfulness (free-form): Why did you find the AI Assistant helpful or unhelpful?',
    'Please output each of the above metrics line-by-line.',
])
#if 'general' != config['persona']:
#    sys_prompt = '\n'.join([persona_eval[config['persona']], sys_prompt])

#data = pd.read_csv('../data/event_blocks.csv')
#workers = pd.read_csv('../results/accuracy_by_id.csv')
# Input: the conversation log generated for the model1/model2 pair. The
# persona suffix is omitted for the 'general' persona. The path is relative
# to the current working directory.
file_name = '../results/hotpot_conversation_{model1}_{model2}{persona}_prompt-2.json'.format(
        model1=config['model1'],
        model2=config['model2'],
        persona='_'+config['persona'] if 'general'!=config['persona'] else ''
        )
# JSON text is UTF-8 (RFC 8259); don't depend on the platform locale encoding.
# Presumably a list of conversation dicts (it is iterated record by record).
with open(file_name, 'r', encoding='utf-8') as f:
    data = json.load(f)

def _clean_response(response):
    """Trim one raw LM response for display in the evaluation transcript.

    Keeps the lines before the first line that contains ``context:``
    (case-insensitive, anywhere in the line), drops empty or whitespace-only
    lines, and rejoins the survivors with newlines.
    """
    kept = []
    for text in response.split('\n'):
        # NOTE(review): presumably the model echoes its prompt context after
        # this marker, so everything from here on is discarded — confirm.
        if 'context:' in text.lower():
            break
        if text.strip():
            kept.append(text)
    return '\n'.join(kept)


def extract_line(line):
    """Render one conversation record as the text shown to the evaluator.

    The rendered message has four sections separated by blank lines: the
    question text (``context`` immediately followed by ``question``), the
    gold answer, the cleaned conversation transcript, and the user's answer.

    Args:
        line: dict with keys ``context``, ``question``, ``answer``,
            ``user_queries``, ``lm_responses`` and ``user_answer``.
            ``user_queries`` and ``lm_responses`` are paired turn by turn;
            if their lengths differ, ``zip`` silently drops the extras.

    Returns:
        The same dict, mutated in place, with the rendered text stored under
        ``'model_message'``.
    """
    # NOTE(review): context and question are joined with no separator, which
    # assumes ``context`` already ends in whitespace — confirm against data.
    question_text = line['context'] + line['question']
    answer_golden = 'True Answer: {ans}'.format(ans=line['answer'])

    transcript = ['Conversation:']
    for query, response in zip(line['user_queries'], line['lm_responses']):
        transcript.append('User: {up}\nAI Assistant: {ar}'.format(
            up=query, ar=_clean_response(response)))
    conversation = '\n'.join(transcript)

    answer_user = 'User Answer: {ans}'.format(ans=line['user_answer'])

    line['model_message'] = '\n\n'.join(
        [question_text, answer_golden, conversation, answer_user])
    return line

def extract(line):
    """Ask the configured evaluator model to grade one conversation.

    Builds a request made of the system message plus a user prompt (the
    rendered conversation followed by the evaluation rubric), sends it to
    the backend selected by ``config['eval_model']``, and returns a summary.

    Args:
        line: One conversation record. It must carry the keys read by
            ``extract_line`` plus ``'id'``. ``extract_line`` adds
            ``'model_message'`` to this dict in place.

    Returns:
        dict with keys ``worker_id`` (copied from ``line['id']``),
        ``question``, ``prediction`` (whatever the selected API wrapper
        returns) and ``no_of_turns`` (number of user queries).
    """
    # Pick the rubric locally instead of rebinding the module-level
    # ``sys_prompt`` through ``global``. A persona rubric replaces (does not
    # extend) the default one.
    if 'general' != config['persona']:
        rubric = persona_eval_sys[config['persona']]
    else:
        rubric = sys_prompt

    line = extract_line(line)
    prompt = '\n\n'.join([line['model_message'], rubric])
    messages = [
        {'role': 'system', 'content': sys_message},
        {'role': 'user', 'content': prompt},
    ]

    eval_model = config['eval_model']
    if eval_model.startswith('claude'):
        # NOTE(review): eval_model is not forwarded here; presumably
        # anthropic_api selects its own model — confirm.
        prediction = anthropic_api.call(messages)
    elif eval_model.startswith('gpt-'):
        prediction = openai_api.call_chat(messages, eval_model)
    else:
        # Completion-style models take one string, so flatten the system
        # message and the user prompt into a single text.
        completion_prompt = '\n\n'.join([sys_message, prompt])
        prediction = openai_api.call_completion(completion_prompt, eval_model)

    return {
        'worker_id': line['id'],
        'question': line['question'],
        'prediction': prediction,
        'no_of_turns': len(line['user_queries']),
    }

# Evaluate every conversation (one evaluator call per record).
predictions = [extract(d) for d in tqdm(data)]

# Output path: like the input conversation file, plus the evaluator model.
# The persona suffix is omitted for the 'general' persona.
file_name = '../results/hotpot_predictions_{model1}_{model2}_{eval_model}{persona}_prompt-2.json'.format(
        model1=config['model1'],
        model2=config['model2'],
        eval_model=config['eval_model'],
        persona='_'+config['persona'] if 'general'!=config['persona'] else ''
        )
with open(file_name, 'w', encoding='utf-8') as f:
    json.dump(predictions, f, indent=2)

