import argparse
import contextlib
import os
import subprocess
import sys

# Default location of the mastering-urlb project (overridable via --project_dir).
_DEFAULT_PROJECT_DIR = '/data/huangkc/works/se_explore/mastering-urlb-main/'
# Seeds swept when --seed is not given.
_SEEDS = ['1', '2', '3']
# Tasks fine-tuned for each domain; dict order is the domain iteration order.
_ALL_TASKS = {
    'walker': ['stand', 'walk', 'run', 'flip'],
    'quadruped': ['stand', 'walk', 'run', 'jump'],
    'jaco': ['reach_top_left', 'reach_top_right', 'reach_bottom_left', 'reach_bottom_right'],
}
# Maps the pretrained-model file-name prefix (text before the first '_')
# to the value passed as ``agent=`` to dreamer_finetune.py.
_AGENTS = {
    'icm': 'icm_dreamer', 'plan': 'plan2explore', 'rnd': 'rnd_dreamer',
    'lbs': 'lbs_dreamer', 'apt': 'apt_dreamer', 'diayn': 'diayn_dreamer',
    'aps': 'aps_dreamer', 'se': 'se_explore',
}
# Output-directory templates in dreamer_finetune.yaml; every occurrence is
# replaced by the concrete run directory so each run logs to a known place.
_YAML_DIR_TEMPLATES = (
    './exp_local/${now:%Y.%m.%d}/${now:%H%M%S}_${agent.name}',
    './exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${agent.name}_${experiment}',
)


def _parse_args(argv=None):
    """Parse the launcher's command-line options from *argv* (default: sys.argv[1:])."""
    parser = argparse.ArgumentParser(description="A simple program with argparse")
    parser.add_argument('--ts', type=str, required=True, help='timestep: 100k, 500k, 1m, 2m')
    parser.add_argument('--domain', type=str, default=None, choices=list(_ALL_TASKS),
                        help='domain: walker, quadruped, jaco')
    parser.add_argument('--device', type=str, required=True, help='cuda:x')
    parser.add_argument('--method', type=str, default=None, help='lbs, aps, apt and so on')
    parser.add_argument('--seed', type=str, default=None, help='seed')
    parser.add_argument('--task', type=str, default=None, help='task')
    parser.add_argument('--grayscale', type=int, default=1)
    parser.add_argument('--random_video', type=int, default=0)
    parser.add_argument('--project_dir', type=str, default=_DEFAULT_PROJECT_DIR,
                        help='root of the mastering-urlb project')
    parser.add_argument('--python', type=str, default='python',
                        help='interpreter used to launch the fine-tuning script')
    return parser.parse_args(argv)


def _run_dir(task_dir, method, ts, seed, grayscale, random_video):
    """Return the log directory for one run; the suffix encodes the observation mode."""
    if random_video:
        suffix = '_random'
    elif grayscale:
        suffix = ''
    else:
        suffix = '_color'
    return os.path.join(task_dir, method, f'{ts}_{seed}{suffix}')


def _write_temp_copies(project_dir, tag, run_dir, py_path, yaml_path):
    """Write per-run copies of dreamer_finetune.py/.yaml to *py_path*/*yaml_path*.

    The .py copy has ``config_name='dreamer_finetune'`` rewritten to
    ``config_name='<tag>'`` (the .yaml copy's basename); the .yaml copy has
    its output-directory templates replaced by *run_dir*.
    """
    with open(os.path.join(project_dir, 'dreamer_finetune.py'), encoding='utf-8') as src:
        py_text = src.read()
    py_text = py_text.replace("config_name='dreamer_finetune'", f"config_name='{tag}'")
    with open(py_path, 'w', encoding='utf-8') as dst:
        dst.write(py_text)

    with open(os.path.join(project_dir, 'dreamer_finetune.yaml'), encoding='utf-8') as src:
        yaml_text = src.read()
    for template in _YAML_DIR_TEMPLATES:
        yaml_text = yaml_text.replace(template, run_dir)
    with open(yaml_path, 'w', encoding='utf-8') as dst:
        dst.write(yaml_text)


def main(argv=None):
    """Fine-tune every matching pretrained snapshot on every matching task/seed.

    Example::

        python z_utils/finetune.py --seed 1 --domain jaco --ts 100k --device cuda:2 --method lbs

    For each domain (optionally restricted by ``--domain``), every file in
    ``<project_dir>/assets/<domain>/pretrain_models`` whose name contains
    ``--ts`` and whose method prefix (text before the first ``_``) equals
    ``--method`` is fine-tuned on each task/seed passing the ``--task`` /
    ``--seed`` filters. For each run this:

    1. creates ``assets/<domain>/<task>/<method>/<ts>_<seed>[_random|_color]``;
       if that directory already exists the run is skipped, so re-invoking
       the script resumes a sweep (an interrupted or failed run also leaves
       its directory behind and is skipped next time);
    2. writes temporary copies of dreamer_finetune.py/.yaml into project_dir;
    3. launches the copied script with ``--python`` and waits for it;
    4. removes the temporary copies, even if steps 2-3 raised.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``.

    Raises:
        FileNotFoundError: if a selected domain's pretrain_models directory
            or a selected task directory does not exist.
    """
    args = _parse_args(argv)
    ts = args.ts
    random_video = bool(args.random_video)
    project_dir = args.project_dir

    for domain, tasks in _ALL_TASKS.items():
        if args.domain is not None and domain != args.domain:
            continue
        work_dir = os.path.join(project_dir, 'assets', domain, 'pretrain_models')

        # Sorted so the launch order is reproducible across filesystems.
        for file_name in sorted(os.listdir(work_dir)):
            method = file_name.split('_')[0]
            if ts not in file_name:
                continue
            # Exact prefix match: a substring test would let e.g. '--method se'
            # select any file whose name merely contains 'se'.
            if args.method is not None and method != args.method:
                continue
            agent = _AGENTS.get(method)
            if agent is None:
                # No agent config for this prefix; never launch with agent=None.
                print(f'[finetune] skipping {file_name}: unknown method {method!r}',
                      file=sys.stderr)
                continue
            model_path = os.path.join(work_dir, file_name)
            if agent == 'se_explore':
                agent_args = ['agent=plan2explore', 'separate_wm=True']
            else:
                agent_args = [f'agent={agent}']

            for task in tasks:
                if args.task is not None and task != args.task:
                    continue
                task_dir = os.path.join(project_dir, 'assets', domain, task)
                if not os.path.isdir(task_dir):
                    raise FileNotFoundError(f'task directory does not exist: {task_dir}')

                for seed in _SEEDS:
                    if args.seed is not None and seed != args.seed:
                        continue
                    run_dir = _run_dir(task_dir, method, ts, seed, args.grayscale, random_video)
                    try:
                        # Creating the directory doubles as an atomic claim on
                        # the run: if it already exists, the run is skipped.
                        os.makedirs(run_dir)
                    except FileExistsError:
                        continue

                    # Temporary copies of dreamer_finetune.py and its yaml config.
                    tag = f'z_dreamer_finetune{method}{seed}{domain}{ts}{task}'
                    py_path = os.path.join(project_dir, tag + '.py')
                    yaml_path = os.path.join(project_dir, tag + '.yaml')
                    cmd = [
                        args.python, py_path, 'configs=dmc_pixels', *agent_args,
                        f'task={domain}_{task}', f'seed={seed}', f'device={args.device}',
                        f'snapshot_path={model_path}', f'grayscale={args.grayscale}',
                        f'random_video={random_video}',
                    ]
                    try:
                        _write_temp_copies(project_dir, tag, run_dir, py_path, yaml_path)
                        # argv list, no shell: paths with spaces or shell
                        # metacharacters are passed through verbatim.
                        # KeyboardInterrupt propagates, aborting the sweep.
                        returncode = subprocess.run(cmd, check=False).returncode
                    finally:
                        for path in (py_path, yaml_path):
                            with contextlib.suppress(FileNotFoundError):
                                os.remove(path)
                    if returncode != 0:
                        print(f'[finetune] run exited with status {returncode}: {run_dir}',
                              file=sys.stderr)

if __name__ == "__main__":
    # Script entry point: run the sweep only when executed directly, not on import.
    raise SystemExit(main())
