import os
import string
from collections import OrderedDict
import shutil
import argparse
import pickle
import subprocess


# Submission command written to each job's srun.sh.
# Placeholders: $job_name, $sweep_folder (stdout/stderr logs go to
# $sweep_folder/$job_name/) and $train_script (the script handed to sbatch).
# Every line but the last ends with a shell line continuation.
srun_template = string.Template(
    'sbatch --job-name $job_name \\\n'
    '       --mem 50G \\\n'
    '       --partition dev \\\n'
    '       --gres gpu:1 \\\n'
    '       --nodes 1 \\\n'
    '       --ntasks-per-node 1 \\\n'
    '       --cpus-per-task 40 \\\n'
    '       --time 2880 \\\n'
    '       --output $sweep_folder/$job_name/stdout.log \\\n'
    '       --error $sweep_folder/$job_name/stderr.log \\\n'
    '       $train_script\n'
)


# Bash script written to each job's train.sh; it launches
# $root/scripts/rl/train.py with one flag per swept/default argument.
# Placeholders: $root, $num_thread, $batchsize, $save_dir, $lr, $fow, $ppo,
# $adv (passed as --adversarial) and $win_decay (passed as --win_rate_decay).
# NOTE: the last flag line also ends with a line continuation, so the
# generated script finishes with a trailing backslash + newline.
train_template = string.Template(
    '#!/bin/bash\n'
    'python -u $root/scripts/rl/train.py \\\n'
    '--num_thread $num_thread \\\n'
    '--batchsize $batchsize \\\n'
    '--gpu 0 \\\n'
    '--save_dir $save_dir \\\n'
    '--lr $lr \\\n'
    '--fow $fow \\\n'
    '--ppo $ppo \\\n'
    '--adversarial $adv \\\n'
    '--win_rate_decay $win_decay \\\n'
)


# Repository root: three directory levels above this file's absolute path.
root = os.path.dirname(
    os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))))
print('root:', root)

# Baseline value for every placeholder of train_template except $save_dir,
# which is filled in per job. Sweep variables override entries of this dict.
default_args = dict(
    root=root,
    num_thread=1024,
    batchsize=128,
    lr=6.25e-5,
    fow=1,
    ppo=0,
    adv=0,
    win_decay=0.99,
)

# Experiment configuration.
# `sweep_folder` is given relative to <root>/scripts/rl and is resolved to a
# full path further down.
sweep_folder = 'sweep_rule/a2c_new_rule_sample'

# Values to sweep, in the order they appear in the job names. Each key must
# be a key of default_args; one job is generated per combination of values.
variables = OrderedDict()
# variables['fow'] = [1]
variables['num_thread'] = [512]
variables['batchsize'] = [64]
variables['lr'] = [1e-3]
variables['adv'] = [1]
variables['win_decay'] = [0.9, 0.99, 0.999]

# Resolve the sweep folder under <root>/scripts/rl and create it if missing.
sweep_folder = os.path.join(root, 'scripts/rl', sweep_folder)
sweep_folder_missing = not os.path.exists(sweep_folder)
if sweep_folder_missing:
    print('make dir:', sweep_folder)
    os.makedirs(sweep_folder)

# Copy this sweep script into the folder so the configuration that produced
# the jobs is stored next to them.
this_script = os.path.realpath(__file__)
shutil.copy2(this_script, sweep_folder)


# Expand `variables` into the full cross product of settings.
#
# After the loop, job_names[i] is the '_'-joined '<var><val>' label of job i
# and sweep_args[i] is a copy of default_args with those values applied.
# Variables are applied in `variables` order; within the resulting lists the
# last variable changes slowest.
job_names = ['']
# Start from a copy so that later per-job edits never touch default_args,
# even when `variables` is empty.
sweep_args = [default_args.copy()]
for var, values in variables.items():
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, and a misspelled name would then add a key the templates
    # never read, yielding jobs that differ only in their name.
    if var not in default_args:
        raise ValueError(
            'unknown sweep variable %r; expected one of %s'
            % (var, sorted(default_args)))

    new_job_names = []
    new_sweep_args = []
    for val in values:
        tag = '%s%s' % (var, val)
        for job_name, args in zip(job_names, sweep_args):
            # The first variable starts the name; later ones are appended
            # with an underscore separator.
            new_job_names.append((job_name + '_' + tag) if job_name else tag)

            new_args = args.copy()
            new_args[var] = val
            new_sweep_args.append(new_args)
    job_names = new_job_names
    sweep_args = new_sweep_args

# Internal invariant: both lists are built in lockstep, one entry per job.
assert len(job_names) == len(sweep_args)


# Generate the per-job files (train.sh and srun.sh; eval.sh is currently
# disabled) in one folder per job, and collect every srun.sh path in
# `srun_files` for submission.
def _write_script(path, text):
    """Write `text` to `path`, truncating any existing file."""
    with open(path, 'w') as out:
        out.write(text)


srun_files = []
for job_name, arg in zip(job_names, sweep_args):
    exp_folder = os.path.join(sweep_folder, job_name)
    if not os.path.exists(exp_folder):
        os.makedirs(exp_folder)
    print(job_name)
    print(arg)

    # train script: --save_dir points at the job folder itself
    train_file = os.path.join(exp_folder, 'train.sh')
    arg['save_dir'] = exp_folder
    _write_script(train_file, train_template.substitute(arg))

    # eval script: generation is disabled, but save_dir is still repointed
    # to the replay folder that the disabled step below would use
    arg['save_dir'] = os.path.join(exp_folder, 'replay_$model_id')
    # eval_ = eval_template.safe_substitute(arg) # $model_file is not replaced
    # eval_file = os.path.join(sweep_folder, job_name, 'eval.sh')
    # with open(eval_file, 'w') as f:
    #     f.write(eval_)

    # srun script: submits train.sh and logs into the job folder
    srun_file = os.path.join(exp_folder, 'srun.sh')
    srun_text = srun_template.substitute(
        sweep_folder=sweep_folder,
        job_name=job_name,
        train_script=train_file,
    )
    _write_script(srun_file, srun_text)
    srun_files.append(srun_file)


if __name__ == '__main__':
    # NOTE: the sweep folder and all per-job scripts are generated by the
    # module-level code above before arguments are parsed; --dry only
    # controls whether the generated srun.sh files are executed.
    parser = argparse.ArgumentParser(
        description='generate sweep scripts and submit them with sbatch')
    parser.add_argument(
        '--dry', action='store_true',
        help='only generate the scripts; do not submit any job')
    args = parser.parse_args()

    if not args.dry:
        submit_cwd = os.path.join(root, 'scripts')
        # Start every submission first (as before), then wait for each one so
        # that a failing sbatch call is reported instead of silently ignored.
        submissions = [
            (srun_file, subprocess.Popen(["sh", srun_file], cwd=submit_cwd))
            for srun_file in srun_files
        ]

        failed = []
        for srun_file, proc in submissions:
            returncode = proc.wait()
            if returncode != 0:
                failed.append((srun_file, returncode))

        if failed:
            for srun_file, returncode in failed:
                print('submission failed (exit code %d): %s'
                      % (returncode, srun_file))
            # Non-zero exit status so wrappers can detect a partial sweep.
            raise SystemExit(1)
