import os
import argparse
import json
import numpy as np

def cal_success(path, save_dirs):
    """Print and return the mean success rate over a set of run directories.

    Each entry of ``save_dirs`` names a sub-directory of ``path``. The first
    file of each sub-directory (in sorted name order) is loaded as JSON and
    ``result[0]["success"]`` is collected. The mean and standard deviation of
    the collected values, scaled by 100, are printed.

    Args:
        path: Directory containing the run sub-directories.
        save_dirs: Names of the sub-directories to aggregate.

    Returns:
        ``mean(success) * 100``; NaN when ``save_dirs`` is empty.

    Raises:
        FileNotFoundError: if a run sub-directory contains no files.
    """
    successes = []
    for save_dir in save_dirs:
        run_dir = os.path.join(path, save_dir)
        # os.listdir order is arbitrary; sort so the chosen file is deterministic.
        file_names = sorted(os.listdir(run_dir))
        if not file_names:
            raise FileNotFoundError(f"No metrics file found in {run_dir}")
        with open(os.path.join(run_dir, file_names[0]), encoding="utf-8") as f:
            result = json.load(f)
        successes.append(result[0]["success"])

    if not successes:
        # Avoid numpy's "Mean of empty slice" RuntimeWarning; result is NaN as before.
        nan = float("nan")
        print("Average Success Rate:", nan, "STD Success Rate:", nan)
        return nan

    mean_rate = np.mean(successes) * 100
    std_rate = np.std(successes) * 100
    print("Average Success Rate:", round(mean_rate, 2), "STD Success Rate:", round(std_rate, 2))
    return mean_rate

parser = argparse.ArgumentParser(
    description="Report average/std success rate of evaluation runs, "
                "optionally per consecutive group of runs."
)
parser.add_argument("--save-path", type=str, default="storage/experiment_output/1681191019",
                    help="Experiment output directory; runs are read from <save-path>/metrics/<tag>.")
parser.add_argument("--tag", type=str, default="objectnav_ithor_rgb_prompt_attm_clip_vit32gru_ddppo_autotest",
                    help="Experiment tag (sub-directory of <save-path>/metrics).")
parser.add_argument('--split-num', type=int, nargs='+', default=[7, 9, 10, 10, 9],
                    help="Sizes of consecutive groups of (name-sorted) run directories, each "
                         "reported separately. A single value reports over all runs.")
args = parser.parse_args()

print(args.save_path)
print(args.tag)
print(args.split_num)
metrics_dir = os.path.join(args.save_path, "metrics", args.tag)
# Sort so the consecutive groups defined by --split-num are stable across runs.
save_dirs = sorted(os.listdir(metrics_dir))

if len(args.split_num) > 1:
    total = sum(args.split_num)
    if total != len(save_dirs):
        # A mismatch means trailing runs are skipped or some groups are short/empty.
        print(f"Warning: --split-num sums to {total} but {len(save_dirs)} "
              f"run directories were found in {metrics_dir}")
    start = 0
    for group_size in args.split_num:
        end = start + group_size
        cal_success(metrics_dir, save_dirs[start:end])
        start = end
else:
    # A single split value is not used as a size: all runs are aggregated together.
    cal_success(metrics_dir, save_dirs)