import argparse
import sys
import yaml
import numpy as np
import pandas as pd
import json
import os
import asyncio
import time
import re
from tqdm import tqdm
import shutil
import multiprocessing
from multiprocessing import Queue
from multiprocessing import active_children
import sys
import openai
import signal
import random
import google.generativeai as palm

API_KEYS = []  # gpt-3.5-turbo (ChatGPT) API keys here; one worker process is started per key
GPT4_KEYS = []  # gpt-4 API keys here; one worker process is started per key
BARD_KEYS = []  # Bard (PaLM) API keys here; one worker process is started per key
QA_LOCATION = ""  # directory holding the clip-data JSON files, read as <QA_LOCATION>/<data>.json


def parse_args():
    """Build the experiment command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes prompt_file, experiment, data,
        model, optimize, keep, context and examples.
    """
    cli = argparse.ArgumentParser(description="Running experiments")

    # (flag, destination, help text, default) for the plain string options.
    leading_options = (
        ("--p", "prompt_file", "path to prompt", ""),
        ("--e", "experiment", "experiment name", ""),
        ("--d", "data", "which data file to use", "sample"),
    )
    for flag, destination, help_text, fallback in leading_options:
        cli.add_argument(flag, dest=destination, help=help_text, default=fallback, type=str)

    cli.add_argument(
        "--model",
        choices=['gpt-4', 'gpt-3.5-turbo', 'bard'],
        default="gpt-4",
        help="name of the model",
        type=str,
    )
    cli.add_argument("-o", dest="optimize", action='store_true', help="whether it is a second run")
    cli.add_argument('--keep', nargs="+", type=str, default=['q'], help="what to keep from the first run")

    trailing_options = (
        ("--context", "context", "context file name for bard"),
        ("--examples", "examples", "examples file for bard"),
    )
    for flag, destination, help_text in trailing_options:
        cli.add_argument(flag, dest=destination, help=help_text, default="", type=str)

    return cli.parse_args()

"""
Some minor narration formatting changes 
"""
def filter_narration(narration_text):
    result = narration_text
    
    result = result.strip()
    
    if len(result) == 0:
        return None
    
    if narration_text[0] == "#":
        result = result[2:]
        
    result = result.replace('#unsure', '')
    result = result.replace('#Unsure', '')
        
    result = result.strip()
    
    if result[-1] == ".":
        result = result[:-1]
        
    return result

"""
Given clip dict get the clips narration 
"""
def get_clip_narrations(clip):
    clip_id = clip["clip_id"]
    clip_narrations = clip["clip_narrations"]#[:300]
    curr_start_sec = clip["starting_sec"]
    
    current_word_count = prompt_word_count
    narration_text = "\nTimestamps and narrations:\n"
    for narration in clip_narrations:

        narration_time = narration['timestamp_sec'] - curr_start_sec
        unfiltered_narration_text = narration["narration_text"]
        filtered_narration_text = filter_narration(unfiltered_narration_text)

        if filtered_narration_text is None:
            continue
        
        narration_text += str(int(narration_time))
        narration_text += " - "
        narration_text += filtered_narration_text
        narration_text += "\n"
        current_word_count += 3
        current_word_count += len(filtered_narration_text.split(" "))

        # limit word count in order to prevent cases with huge amount of narrations
        if args.model == "gpt-4":
            if current_word_count > 2300:
                break
        else:
            if current_word_count > 1300:
                break
        
    return narration_text[:-1]

"""
Give instructions to the chatgpt/ gpt4 what output formar should be
"""
def generate_output_format(output_type):
    output_format = '\n\n'
    
    if "s" in output_type:
        output_format += 'When announcing the summary please label summary as "Summary: "\n'
        
    if "q" in output_type:
        output_format += 'When announcing the question please label each question as "Question 1,2,3: [full question]"\n'
    
    if "a" in output_type or "w" in output_type:
        output_format += 'Do not use letters for the answer choices\n'
        
    if "a" in output_type:
        output_format += 'Print each correct answer exactly as "Correct answer: [full answer]"\n'
        
    if "w" in output_type:
        output_format += 'Please print each wrong answer on a new line and print each wrong answer as "Wrong answer 1,2,3,4: [full answer]"\n'
        
    output_format += '\n'
    return output_format

"""
Based on the input specified in the prompt plug it in the prompt
"""
def form_input(curr_prompt, input_type, data):

    if "n" in input_type:
        curr_prompt += data["n"]
         
    if "s" in input_type:
        if "s" not in data:
            return None
        
        curr_prompt += "\nSummary:"
        curr_prompt += data["s"]
        
    if "q" in input_type:
        if "q" not in data:
            return None
            
    if "a" in input_type:
        if "a" not in data:
            return None
        
    if "q" in input_type:
        curr_prompt += "\n\nQuestions:" 
        for q_i in range(3):
            curr_prompt += "\n"
            curr_prompt += "\nQuestion: "
            curr_prompt += data["q"][q_i]
            if "a" in input_type:
                curr_prompt += "\nCorrect answer: "
                curr_prompt += data["a"][q_i]
        
    return curr_prompt
    
def form_input_bard(curr_prompt, input_type, data):
    """Append the inputs requested by input_type to a Bard prompt.

    Unlike form_input, the questions are numbered and the narrations ("n")
    come last, after a blank line.

    Returns:
        The extended prompt, or None when a requested "s", "q" or "a" entry
        is missing from data.
    """
    text = curr_prompt

    if "s" in input_type:
        if "s" not in data:
            return None
        text += "\nSummary:" + data["s"]

    for required in ("q", "a"):
        if required in input_type and required not in data:
            return None

    if "q" in input_type:
        text += "\n\nQuestions:"
        for idx in range(3):
            text += f"\n\nQuestion {idx + 1}:" + data["q"][idx]
            if "a" in input_type:
                text += "\nCorrect answer: " + data["a"][idx]

    if "n" in input_type:
        text += "\n\n" + data["n"]

    return text

"""
Check if outout has all the required output
true - everything is good
false - not good
"""
def check_output(prompt_output, output_type, data):
    
    if "s" in output_type:
        sum_occurences = re.findall(r"summary[a-zA-Z1-9\s]*:[\n\s]+(.*?)[\n]",  prompt_output.strip().lower() + "\n")
        if len(sum_occurences) != 1:
            return False
    
    if "q" in output_type:
        q_occurences = re.findall(r"question[a-rt-zA-RT-Z\s]*[1-9]*[a-rt-zA-RT-Z\s]*:[\n\s]+(.*?)[\n]",  prompt_output.strip().lower() + "\n")
        if len(q_occurences) != 3:
            return False
        
    if "a" in output_type:
        a_occurences = re.findall(r"correct answer[a-zA-Z1-9\s]*:[\n\s]+(.*?)[\n]",  prompt_output.strip().lower() + "\n")
        if len(a_occurences) != 3:
            return False
        
    if "w" in output_type:
        w1_occurences = re.findall(r"wrong answer [1-3.]*1:[\n\s]+(.*?)[\n]",  prompt_output.strip().lower() + "\n")
        w2_occurences = re.findall(r"wrong answer [1-3.]*2:[\n\s]+(.*?)[\n]",  prompt_output.strip().lower() + "\n")
        w3_occurences = re.findall(r"wrong answer [1-3.]*3:[\n\s]+(.*?)[\n]",  prompt_output.strip().lower() + "\n")
        w4_occurences = re.findall(r"wrong answer [1-3.]*4:[\n\s]+(.*?)[\n]",  prompt_output.strip().lower() + "\n")
        
        if len(w1_occurences) != 3:
            return False
        if len(w2_occurences) != 3:
            return False
        if len(w3_occurences) != 3:
            return False
        if len(w4_occurences) != 3:
            return False
    return True

def check_output_bard(prompt_output, output_type, data):
    """Check a Bard reply; this is looser than check_output.

    The reply is stripped and lower-cased before matching. It must have exactly
    one summary ("s") and exactly three questions ("q"). It needs only at least
    three correct answers ("a") and at least three of each "wrong answer 1..4" ("w").

    Args:
        prompt_output: the model reply.
        output_type: string of output flags.
        data: unused; kept for interface compatibility.

    Returns:
        True if the reply is acceptable, otherwise False.
    """
    if not any(flag in output_type for flag in "sqaw"):
        return True

    normalized = prompt_output.strip().lower() + "\n"

    def count(pattern):
        return len(re.findall(pattern, normalized))

    if "s" in output_type and count(r"summary[a-zA-Z1-9\s]*:[\n\s]+(.*?)[\n]") != 1:
        return False
    if "q" in output_type and count(r"question[a-rt-zA-RT-Z\s]*[1-9]*[a-rt-zA-RT-Z\s]*:[\n\s]+(.*?)[\n]") != 3:
        return False
    if "a" in output_type and count(r"correct answer[a-zA-Z1-9\s]*:[\n\s]+(.*?)[\n]") < 3:
        return False
    if "w" in output_type:
        wrong_patterns = (
            r"wrong answer [1-3.]*1:[\n\s]+(.*?)[\n]",
            r"wrong answer [1-3.]*2:[\n\s]+(.*?)[\n]",
            r"wrong answer [1-3.]*3:[\n\s]+(.*?)[\n]",
            r"wrong answer [1-3.]*4:[\n\s]+(.*?)[\n]",
        )
        if any(count(pattern) < 3 for pattern in wrong_patterns):
            return False
    return True

"""
Parses the chatgpt output into the data dict
"""
def parse_output(prompt_output, output_type, data):
    
    if "s" in output_type:
        sum_occurences = re.findall(r"summary[a-zA-Z1-9\s]*:[\n\s]+(.*?)[\n]",  prompt_output.strip().lower() + "\n")
        data["s"] = sum_occurences[0]
    
    if "q" in output_type:
        q_occurences = re.findall(r"question[a-rt-zA-RT-Z\s]*[1-9]*[a-rt-zA-RT-Z\s]*:[\n\s]+(.*?)[\n]",  prompt_output.strip().lower() + "\n")
        data["q"] = q_occurences
        
    if "a" in output_type:
        a_occurences = re.findall(r"correct answer[a-zA-Z1-9\s]*:[\n\s]+(.*?)[\n]",  prompt_output.strip().lower() + "\n")
        data["a"] = a_occurences
        
    if "w" in output_type:
        w1_occurences = re.findall(r"wrong answer [1-3.]*1:[\n\s]+(.*?)[\n]",  prompt_output.strip().lower() + "\n")
        w2_occurences = re.findall(r"wrong answer [1-3.]*2:[\n\s]+(.*?)[\n]",  prompt_output.strip().lower() + "\n")
        w3_occurences = re.findall(r"wrong answer [1-3.]*3:[\n\s]+(.*?)[\n]",  prompt_output.strip().lower() + "\n")
        w4_occurences = re.findall(r"wrong answer [1-3.]*4:[\n\s]+(.*?)[\n]",  prompt_output.strip().lower() + "\n")     
        data["w"] = [[w1_occurences[q_i], w2_occurences[q_i], w3_occurences[q_i], w4_occurences[q_i]] for q_i in range(3)]
        

"""
In case you want to continue a conversation of a previous result
"""
def form_redundant_output(output_type, data):
    output = ""
    if "s" in output_type:
        output += "\nSummary:"
        output += data["s"]
        
    if "q" in output_type or "a" in output_type:
        output += "\n" 
        for q_i in range(3):
            output += "\n"
            output += f"\nQuestion {q_i + 1}: "
            output += data["q"][q_i]
            if "a" in output_type:
                output += "\nCorrect answer: "
                output += data["a"][q_i]
    return output.strip()

def check_redundancy(output_type, data):
    """Tell whether every output requested by output_type is already in data.

    Returns:
        False for a blank output_type. Otherwise True, unless one of the flags
        "q", "s", "a" or "w" appearing in output_type is missing from data.
    """
    if not output_type.strip():
        return False
    return all(flag in data for flag in ("q", "s", "a", "w") if flag in output_type)

"""
Simulates a chat with bard
"""
def simulate_chat_bard(process_id, data):
    response = None
    data["output"] = ""
    conv_end = True
    for prompt_i in range(len(prompt_list)):
        if not conv_end:
            return None
        prompt = prompt_list[prompt_i]
        
        # getting the prompt type
        prompt_type = prompt[prompt.find(":") + 1:prompt.find("\n")]
        prompt = prompt[prompt.find("\n") + 1:]

        # getting what the input should be
        input_type = prompt[prompt.find(":") + 1:prompt.find("\n")]
        prompt = prompt[prompt.find("\n") + 1:]

        # getting what the output should be
        output_type = prompt[prompt.find(":") + 1:prompt.find("\n")]
        prompt = prompt[prompt.find("\n") + 1:]
    
        # Prepare the prompt
        prompt = form_input_bard(prompt, input_type, data) + "\n\n"
        prompt += generate_output_format_bard_1(output_type, data)
        prompt = prompt.strip()
        
        no_need_to_run = check_redundancy(output_type, data)

        data["output"] += prompt
        data["output"] += ("\n" + "-" * 10 + "\n")

        # run the api
        conv_end = False
        for _ in range(3):
            try:
                if no_need_to_run:
                    output = form_redundant_output(output_type, data)
                else:
                    time.sleep(10)
                    if prompt_i == 0:
                        if context != "" and examples != "":
                            response = palm.chat(context = context, examples=2*examples, messages=prompt, candidate_count = 12)
                        elif context == "" and examples == "":
                            response = palm.chat(messages=prompt)
                        elif context != "" and examples == "":
                            response = palm.chat(context = context, messages=prompt)
                        elif context == "" and examples != "":
                            response = palm.chat(examples=4*examples, messages=prompt, candidate_count = 8)
                    else:
                        response = response.reply(prompt)
                    output = response.last
                    candidates = response.candidates
                if check_output_bard(output, output_type, data):
                    parse_output(output, output_type, data)
                    data["output"] += output 
                    data["output"] += ("\n" + "-" * 10 + "\n")
                    conv_end = True
                    break
                else:
                    print("------------------------------------------------------------------")
                    for cand in candidates:
                        if check_output_bard(cand['content'], output_type, data):
                            parse_output(cand['content'], output_type, data)
                            data["output"] += output 
                            data["output"] += ("\n" + "-" * 10 + "\n")
                            conv_end = True
                            print("used candidate")
                            break
                    if conv_end:
                        break
                    save_bad_format(data["clip_id"], data["output"] + "\n" + output)
                    print(f"Formar error at process {process_id}, clip {data['clip_id']}", flush=True)        
            except Exception as e:
                print(e)
                print(f"Something wrong with server at process {process_id}, clip {data['clip_id']}", flush=True)
                time.sleep(30)             
    if not conv_end:
        return None
    return data

"""
Simulates a chat with gpt
"""
def simulate_chat_gpt(process_id, data):
    curr_messages = [{"role": "system", "content": "You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible."}]
    data["output"] = ""
    conv_end = True
    
    for prompt_i in range(len(prompt_list)):
        if not conv_end:
            return None
        prompt = prompt_list[prompt_i]
        
        # getting the prompt type
        prompt_type = prompt[prompt.find(":") + 1:prompt.find("\n")]
        prompt = prompt[prompt.find("\n") + 1:]

        # getting what the input should be
        input_type = prompt[prompt.find(":") + 1:prompt.find("\n")]
        prompt = prompt[prompt.find("\n") + 1:]

        # getting what the output should be
        output_type = prompt[prompt.find(":") + 1:prompt.find("\n")]
        prompt = prompt[prompt.find("\n") + 1:]
    
        # Prepare the prompt
        prompt += generate_output_format(output_type)
        prompt = form_input(prompt, input_type, data)
        
        no_need_to_run = check_redundancy(output_type, data)
        
        # add it into the conversation
        curr_messages.append({"role": "user", "content": prompt}) 
        data["output"] += prompt
        data["output"] += ("\n" + "-" * 10 + "\n")
        
        # run the api
        conv_end = False
        for trial in range(3):
            try:
                if no_need_to_run:
                    output = form_redundant_output(output_type, data)
                else:
                    openai_output = openai.ChatCompletion.create(
                            model=model_name,
                            messages=curr_messages)
                    output = openai_output["choices"][0]["message"]["content"]
                
                if check_output(output, output_type, data):
                    parse_output(output, output_type, data)
                    curr_messages.append({"role": "assistant", "content": output})
                    data["output"] += output 
                    data["output"] += ("\n" + "-" * 10 + "\n")
                    conv_end = True
                    break
                else:
                    save_bad_format(data["clip_id"], data["output"] + "\n" + output)
                    print(f"Formar error at process {process_id}, clip {data['clip_id']}", flush=True)
            except Exception as e:
                print(e)
                time.sleep(1)
                print(f"Something wrong with server at process {process_id}, clip {data['clip_id']}", flush=True)
    if not conv_end:
        return None
    return data

"""
Format the gpt output in txt file
"""
def format_output(data):
    output_format = ''
    
    if "s" in data:
        
        output_format += "\nSummary:"
        output_format += data["s"]
        
    if "q" in data:
        output_format += ""
        for q_i in range(3):
            output_format += "\n"
            output_format += "\nQuestion: "
            output_format += data["q"][q_i]
            if "a" in data:
                output_format += "\nCorrect answer: "
                output_format += data["a"][q_i]
                
            if "w" in data:
                for w_i in range(4):
                    output_format += "\nWrong answer " + str(w_i + 1) + ": "
                    output_format += data["w"][q_i][w_i]
    return output_format

"""
If you get bad output save it to check if it is bad parsing
"""
def save_bad_format(clip_id, output):
    to_write = open(f"{experiment_folder}/bad_formatting/{clip_id}.txt", 'w')
    to_write.write(output)
    to_write.close()

"""
Saving good output in txt format
"""
def save_result(clip, data):
    
    intro_text = f"Here is the qa request output \nYou can find the link here: {data['clip_url']}"
    divider1 = "\n---------------------------------------------\n"
    output_formatted = format_output(data)
    divider2 = "\n---------------------------------------------\n"
    chat_txt = data["output"]
    
    text_to_save = intro_text + divider1 + output_formatted + divider2 + chat_txt
    to_write = open(f"{experiment_folder}/output_txt/{data['clip_id']}.txt", 'w')
    to_write.write(text_to_save)
    to_write.close()
    
def fill_previous_result(clip_id, data):
    """Copy the fields listed in results_to_keep from the previous run's result.

    Reads the module-level prev_results (keyed by str(clip_id)) and
    results_to_keep. Does nothing when the clip has no previous result.
    """
    key = str(clip_id)
    if key not in prev_results:
        return
    previous = prev_results[key]
    for result_type in results_to_keep:
        if result_type in previous:
            data[result_type] = previous[result_type]
        
    
"""
This will be once of the processes that run in parallel in mutiprocessing
"""
def run_one_process(process_id, api_key, clip_index_list, all_results):
    if model_name == "bard":
        palm.configure(api_key=api_key)
    else:
        openai.api_key = api_key
    
    process_results = {}
    process_results_path = f"{experiment_folder}/output_json/{process_id}.json"
    if os.path.isfile(process_results_path):
        process_results_f = open(process_results_path) 
        process_results = json.load(process_results_f)
    
    while True:
        if clip_index_list.empty():
            break
            
        clip_index = clip_index_list.get()
        clip = dataset[clip_index]
        data = {}
        
        clip_id = clip["clip_id"] 
        data["clip_id"] = clip_id
        
        if str(clip_id) in all_results:
            data = all_results[str(clip_id)]
            process_results[str(clip_id)] = data
            
            with open(process_results_path, 'w') as f:
                json.dump(process_results, f)
            continue
        
        video_uid = clip["video_uid"] 
        data["clip_url"] = f"https://gpt-qa.s3.amazonaws.com/clip_generation/all_videos/{video_uid}_{clip_id}.mp4"
        
        narrations = get_clip_narrations(clip)
        data["n"] = narrations
        
        fill_previous_result(clip_id, data)
        
        if model_name == "bard":
            res = simulate_chat_bard(process_id, data)
        else:
            res = simulate_chat_gpt(process_id, data)
        
        #saving the output
        if res is not None:
            process_results[str(clip_id)] = data
            save_result(clip, data)
            
            with open(process_results_path, 'w') as f:
                json.dump(process_results, f)

    print(f"Process {process_id}: done", flush=True)
    
"""
collect the result from all the proccesses into one json
"""
def collect_results():
    all_results = {}
    if os.path.isfile(f"{experiment_folder}/all_results.json"):
        all_results_f = open(f"{experiment_folder}/all_results.json") 
        all_results = json.load(all_results_f)
    all_processes_json = os.listdir(f"{experiment_folder}/output_json/")
    for process_json in all_processes_json:
        if ".ipynb_checkpoints" in process_json:
            continue
        process_results_f = open(f"{experiment_folder}/output_json/{process_json}") 
        process_results = json.load(process_results_f)
        all_results = {**all_results, **process_results}
        
    with open(f"{experiment_folder}/all_results.json", 'w') as f:
        json.dump(all_results, f)

"""
Randomly pick ten clips to manually check the output
"""
def pick_ten():
    random_ten_clips = random.sample(dataset, 10) if len(dataset) > 10 else dataset
    for clip in random_ten_clips:
        clip_id = clip["clip_id"]
        shutil.copyfile(f"{experiment_folder}/output_txt/{clip_id}.txt",
                        f"{experiment_folder}/manual_benchmark/{clip_id}.txt")
    
    
"""
one process that runs in parrallel and checks how much is done 
"""
def progress_tracker(initially_done, to_do, clip_index_list):
    global done_trigger
    print("Done: -1%. Time left: -1 seconds", flush=True, end='\r')
            
    start = time.time()
    while not done_trigger:
        time.sleep(1)
        sys.stdout.flush()
        done = to_do - clip_index_list.qsize()
            
        speed = done / (time.time() - start)
        rest = to_do - done
            
        if speed == 0:
            time_left = -1
        else:
            time_left = rest / speed
        
        print(f"Done: {100 * (initially_done + done) / len(dataset)}%. Time left: {int(time_left)} seconds", end='\r', flush = True)
    print("Process tracker done", flush=True)

def main():
    """Run one worker process per API key plus a progress tracker.

    Steps:
    1. Merge the results of earlier, possibly interrupted, runs.
    2. Queue every dataset clip not yet in those results.
    3. Wait for the workers to finish.
    4. Collect all results and copy up to ten random outputs for manual checking.
    """
    global done_trigger

    if num_processes == 0:
        # With no keys nothing would run; fail loudly instead.
        sys.exit(f"No API keys configured for model {model_name}; fill in the key lists at the top of the file.")

    # Collect the results so far.
    all_results_path = f"{experiment_folder}/all_results.json"
    all_results = {}
    if os.path.isfile(all_results_path):
        with open(all_results_path) as all_results_f:
            all_results = json.load(all_results_f)

    all_processes_json = os.listdir(f"{experiment_folder}/output_json/")
    print(all_processes_json)
    for process_json in all_processes_json:
        if ".ipynb_checkpoints" in process_json:
            continue
        with open(f"{experiment_folder}/output_json/{process_json}") as process_results_f:
            all_results.update(json.load(process_results_f))

    if all_results:
        with open(all_results_path, 'w') as f:
            json.dump(all_results, f)

    initially_done = len(all_results)
    print(f"Initially done {initially_done}")

    # Queue the clips that were not processed yet.
    clip_index_list = Queue()
    for clip_index, clip in enumerate(dataset):
        if str(clip["clip_id"]) not in all_results:
            clip_index_list.put(clip_index)

    # NOTE(review): multiprocessing.Queue.qsize() is documented as possibly
    # raising NotImplementedError on some platforms (e.g. macOS).
    to_do = clip_index_list.qsize()
    print(f"To do {to_do}")

    # Set up the processes.
    progress = multiprocessing.Process(target=progress_tracker, args=(initially_done, to_do, clip_index_list))
    processes = [
        multiprocessing.Process(target=run_one_process,
                                args=(process_id,
                                      key_list[process_id],
                                      clip_index_list,
                                      all_results))
        for process_id in range(num_processes)
    ]

    # Start the processes.
    print("Processses set")
    for process in processes:
        process.start()
    progress.start()
    print("Processses started")

    # Wait for the workers to finish.
    for process in processes:
        process.join()
    # This only updates the flag in this process. The tracker runs in its
    # own process, so it is stopped (and reaped) explicitly.
    done_trigger = True
    progress.terminate()
    progress.join()
    print("Processses done")
    collect_results()
    pick_ten()

def signal_handler(sig, frame):
    """SIGINT handler: terminate every live child process, then exit with status 0."""
    print("Stopping all the processes")
    children = active_children()
    for worker in children:
        worker.terminate()
    sys.exit(0)
                
        
if __name__ == "__main__":
    args = parse_args()
    qa_loc = QA_LOCATION##location of a json file contining the clip data
    dataset_path = f"{qa_loc}/{args.data}.json"
    
    dataset_file = open(dataset_path)
    dataset = json.load(dataset_file)

    context = ""
    context_file  = ""
    examples = "" 
    examples_file  = ""

    if args.optimize:
        prompt_name = args.prompt_file[args.prompt_file.find("/") + 1:-4]
        result_folder = f"results/{args.data}_results"
        prompt_folder = f"prompts/stage_2"

        if args.model == "bard":
            if args.context != "":
                context_file = open(f"{prompt_folder}/{args.context}")
                context = context_file.read()
                context_file = args.context[:args.context.rfind(".")]
    
            if args.examples != "":
                examples_file = open(f"{prompt_folder}/{args.examples}")
                examples = examples_file.read().split("---")
                examples_file = args.examples[:args.examples.rfind(".")]

        experiment_folder = f"{result_folder}/{args.experiment}/stage_2/{prompt_name}_{context_file}_{examples_file}"
        prev_results_f = open(f"{result_folder}/{args.experiment}/all_results.json")
        prev_results = json.load(prev_results_f)
        results_to_keep = args.keep
    else:
        result_folder = f"results/{args.data}_results"
        prompt_folder = f"prompts/stage_1"

        if args.model == "bard":
            if args.context != "":
                context_file = open(f"{prompt_folder}/{args.context}")
                context = context_file.read()
                context_file = args.context[:args.context.rfind(".")]
    
            if args.examples != "":
                examples_file = open(f"{prompt_folder}/{args.examples}")
                examples = examples_file.read().split("\n")
                examples_file = args.examples[:args.examples.rfind(".")]
            
        experiment_folder = f"{result_folder}/{args.experiment}"
        prev_results = {}
        results_to_keep = []

    if not os.path.exists(result_folder):
        os.mkdir(result_folder)
        
    if not os.path.exists(experiment_folder):
        os.mkdir(experiment_folder)
        os.mkdir(f"{experiment_folder}/output_json")
        os.mkdir(f"{experiment_folder}/output_txt")
        os.mkdir(f"{experiment_folder}/manual_benchmark")
        os.mkdir(f"{experiment_folder}/bad_formatting")
        os.mkdir(f"{experiment_folder}/annotations")
        if not args.optimize:
            os.mkdir(f"{experiment_folder}/stage_2")
        shutil.copyfile(f"{prompt_folder}/{args.prompt_file}", f"{experiment_folder}/prompt.txt")

        if args.context != "":
            shutil.copyfile(f"{prompt_folder}/{args.context}", f"{experiment_folder}/context.txt")

        if args.examples != "":
            shutil.copyfile(f"{prompt_folder}/{args.examples}", f"{experiment_folder}/examples.txt")

    prompt_file = open(f"{experiment_folder}/prompt.txt")
    prompt = prompt_file.read()
    prompt_list = prompt.split("\n" + "-" * 10 + "\n")
    prompt_word_count = sum([len(one_prompt.strip().split(" ")) for one_prompt in prompt_list])

    if args.model == "gpt-4":
        model_name = "gpt-4"
        key_list = GPT4_KEYS
    elif args.model == "gpt-3.5-turbo":
        model_name = "gpt-3.5-turbo"
        key_list = API_KEYS
    else:
        model_name = "bard"
        key_list = BARD_KEYS

    num_processes = len(key_list)
    done_trigger = False
    processes = []
    signal.signal(signal.SIGINT, signal_handler)
    main()