import os
import json
import argparse

# Command-line options: which LaMP task directory to process (--task) and
# which split subdirectory / filename prefix to use (--mode).
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='LaMP_2', help='Task')
parser.add_argument('--mode', type=str, default='dev', help='Mode')
args = parser.parse_args()

# Collect the ids listed in the history JSONL file (one JSON object with an
# 'id' key per line), de-duplicated while keeping first-seen order.
# NOTE: an earlier variant read "data/LaMP/{task}/{mode}/{mode}_user_dict.json".
# NOTE(review): this path depends only on args.task; args.mode is not part of
# it — confirm that is intended.
user_path = "data/LaMP-42/{}/42_test_history.jsonl".format(args.task)
with open(user_path, 'r', encoding='utf-8') as f:
    # Skip blank lines (e.g. a trailing newline) instead of crashing on them.
    ids_in_file_order = [json.loads(line)['id'] for line in f if line.strip()]
# dict.fromkeys de-duplicates in O(n) and preserves insertion order; kept as a
# list because downstream code iterates/tests against `user_list`.
user_list = list(dict.fromkeys(ids_in_file_order))
print(len(user_list))
# Load the gold outputs for this split. The file holds a single JSON object
# whose 'golds' entry is a list of {'id': ..., 'output': ...} records.
# json.load parses the whole document, so pretty-printed (multi-line) files
# work too; the previous per-line parse kept only the last line's object.
file_path = "data/LaMP/{}/{}/{}_outputs.json".format(args.task, args.mode, args.mode)
with open(file_path, 'r', encoding='utf-8') as f:
    solution_list = json.load(f)['golds']
# Index gold outputs by question id for O(1) lookup later.
solution_dict = {sol['id']: sol['output'] for sol in solution_list}

# Load the question records for this split (a single JSON document holding a
# list of question dicts). json.load handles both compact and pretty-printed
# files, unlike the previous per-line parse that kept only the last line.
file_path = "data/LaMP/{}/{}/{}_questions.json".format(args.task, args.mode, args.mode)
with open(file_path, 'r', encoding='utf-8') as f:
    data_list = json.load(f)
# Per task: (profile field used as the new example's input,
#            profile field used as the new example's gold output).
TASK_PROFILE_FIELDS = {
    'LaMP_2': ('text', 'category'),
    'LaMP_2_movie': ('description', 'tag'),
    'LaMP_3': ('text', 'score'),
    'LaMP_4': ('text', 'title'),
    'LaMP_5': ('abstract', 'title'),
}
# Fail fast on an unsupported task; previously this surfaced as a NameError
# on `profile_data` inside the loop.
if args.task not in TASK_PROFILE_FIELDS:
    raise ValueError("Unsupported task {!r}; expected one of {}".format(
        args.task, sorted(TASK_PROFILE_FIELDS)))
input_field, output_field = TASK_PROFILE_FIELDS[args.task]

# Set for O(1) membership tests inside the loop (user_list may be large).
selected_ids = set(user_list)

extended_data_list = []      # leave-one-out examples built from profile items
extended_sol_list = []       # gold outputs for extended_data_list
extended_query_list = []     # the selected original questions, unchanged
extended_querysol_list = []  # gold outputs for extended_query_list
for data in data_list:
    if data['id'] not in selected_ids:
        continue
    extended_query_list.append(data)
    extended_querysol_list.append({'id': data['id'], 'output': solution_dict[data['id']]})
    # Leave-one-out: each profile item becomes a new example whose own
    # profile is every other item of the same question's profile.
    profile_items = data['profile']
    for idx, profile in enumerate(profile_items):
        others = profile_items[:idx] + profile_items[idx + 1:]
        extended_data_list.append(
            {'id': profile['id'], 'input': profile[input_field], 'profile': others})
        extended_sol_list.append({'id': profile['id'], 'output': profile[output_field]})
# Gold files are wrapped as {'golds': [...]}, matching the *_outputs.json
# format this script reads.
extended_sol_list = {'golds': extended_sol_list}
extended_querysol_list = {'golds': extended_querysol_list}

# Write the four extended splits, each as a single JSON document.
for path_template, payload in (
    ("data/LaMP/{}/{}/{}_history_questions_extended.json", extended_data_list),
    ("data/LaMP/{}/{}/{}_history_outputs_extended.json", extended_sol_list),
    ("data/LaMP/{}/{}/{}_query_questions_extended.json", extended_query_list),
    ("data/LaMP/{}/{}/{}_query_outputs_extended.json", extended_querysol_list),
):
    output_path = path_template.format(args.task, args.mode, args.mode)
    with open(output_path, 'w') as f:
        json.dump(payload, f)
