import os
import string
from collections import OrderedDict
import shutil
import argparse
import pickle
import subprocess


# Shell command template that submits one training job with `sbatch`.
# Placeholders, filled in via `string.Template.substitute`:
#   $job_name     - job name; also names the per-job sub-directory of the logs
#   $sweep_folder - directory under which `$job_name/stdout.log` and
#                   `$job_name/stderr.log` are written
#   $train_script - path of the script handed to sbatch
# Every resource request (memory, partition, GPU, CPUs, time limit) is
# hard-coded in the template text.
srun_template = string.Template("""\
sbatch --job-name $job_name \\
       --mem 100G \\
       --partition dev,priority \\
       --gres gpu:1 \\
       -C volta32gb \\
       --nodes 1 \\
       --ntasks-per-node 1 \\
       --cpus-per-task 10 \\
       --time 1440 \\
       --output $sweep_folder/$job_name/stdout.log \\
       --error $sweep_folder/$job_name/stderr.log \\
       --comment nips_ddl \\
       $train_script
""")


# Bash script template that runs `train.py` for one sweep configuration.
# Every `$name` placeholder is filled in via `string.Template.substitute`;
# the dataset/dictionary paths, batch size, GPU index and epoch count are
# hard-coded.  Two placeholders are named differently from the flag they
# feed: `$emb_dim` -> `--emb_field_dim` and `$pcmd_rnn` -> `--prev_cmd_rnn`.
# NOTE(review): the last option line ends with a line-continuation backslash
# followed only by the final newline -- confirm this is intended.
train_template = string.Template("""\
#!/bin/bash
python train.py \\
  --train_dataset /private/home/hengyuan/rts-replays/dataset3_ppcmd5/train.json \\
  --val_dataset /private/home/hengyuan/rts-replays/dataset3_ppcmd5/val.json \\
  --inst_dict_path /private/home/hengyuan/rts-replays/dataset3_ppcmd5/dict.pt \\
  --emb_field_dim $emb_dim \\
  --prev_cmd_dim $prev_cmd_dim \\
  --num_conv_layers $num_conv_layers \\
  --num_post_layers $num_post_layers \\
  --conv_hid_dim $conv_hid_dim \\
  --army_out_dim $army_out_dim \\
  --other_out_dim $other_out_dim \\
  --money_hid_dim $money_hid_dim \\
  --money_hid_layer $money_hid_layer \\
  --conv_dropout $conv_dropout \\
  --rnn_word_emb_dim $rnn_word_emb_dim \\
  --word_emb_dropout $word_emb_dropout \\
  --inst_hid_dim $inst_hid_dim \\
  --inst_encoder_type $inst_encoder_type \\
  --model_folder $model_folder \\
  --batch_size 128 \\
  --gpu 0 \\
  --grad_clip $grad_clip \\
  --lr $lr \\
  --optim $optim \\
  --epochs 50 \\
  --use_hist_inst $use_hist_inst \\
  --pos_dim $pos_dim \\
  --prev_cmd_rnn $pcmd_rnn \\
  --seed $seed \\
""")


# Absolute path of the directory that contains this script.
root = os.path.split(os.path.abspath(__file__))[0]

# Baseline value for every hyper-parameter; a sweep overrides a subset of
# these per job.  `ctype_hid_dim` and `chead_hid_dim` have no matching
# placeholder in the training template, so substitution ignores them.
default_args = dict(
    emb_dim=32,
    prev_cmd_dim=64,
    num_conv_layers=3,
    num_post_layers=2,
    conv_hid_dim=128,
    army_out_dim=128,
    other_out_dim=128,
    money_hid_dim=128,
    money_hid_layer=1,
    conv_dropout=0.0,
    rnn_word_emb_dim=128,
    word_emb_dropout=0,
    inst_hid_dim=128,
    inst_encoder_type='bow',
    ctype_hid_dim=512,
    chead_hid_dim=256,
    grad_clip=0.5,
    lr=2e-3,
    optim='adamax',
    use_hist_inst=0,
    pos_dim=32,
    pcmd_rnn=0,
    seed=1,
)


# sweep_folder = 'model_rnn15_dset3_p5'
# variables = OrderedDict([
#     ('inst_encoder_type', ['lstm']),
#     ('prev_cmd_dim', [64, 128]),
#     ('num_conv_layers', [3]),
#     ('num_post_layers', [1]),
#     ('conv_hid_dim', [256]),
#     ('army_out_dim', [256]),
#     ('other_out_dim', [128]),
#     ('conv_dropout', [0.1]),
#     ('rnn_word_emb_dim', [64]),
#     ('word_emb_dropout', [0.25]),
#     ('inst_hid_dim', [128]),
#     ('use_hist_inst', [1]),
#     ('pos_dim', [16, 32]),
#     ('pcmd_rnn', [1]),
#     ('seed', [1, 3, 9]),
# ])


# Name of the sweep; used as the folder name for everything it generates.
sweep_folder = 'model_bow14_dset3_p5'

# Values to try for each argument, in expansion order.  Arguments with a
# single value are plain overrides of the default; arguments with several
# values are swept.
variables = OrderedDict()
variables['inst_encoder_type'] = ['bow']
variables['prev_cmd_dim'] = [64]
variables['num_conv_layers'] = [3]
variables['num_post_layers'] = [1]
variables['conv_hid_dim'] = [256]
variables['army_out_dim'] = [256]
variables['other_out_dim'] = [128]
variables['conv_dropout'] = [0.1]
# 'rnn_word_emb_dim' ([64]) is disabled for this sweep and keeps its default
variables['word_emb_dropout'] = [0, 0.25, 0.5]
variables['inst_hid_dim'] = [128, 256]
variables['use_hist_inst'] = [1]
variables['pos_dim'] = [16, 32]
variables['pcmd_rnn'] = [1]
variables['seed'] = [1, 3, 9]


# Turn the sweep name into its output directory: <root>/sweep_executor/<name>.
sweep_folder = os.path.join(root, 'sweep_executor', sweep_folder)
print('sweep_folder', sweep_folder)
# exist_ok=True replaces the exists()-then-makedirs() pair, which could fail
# if the directory appeared between the check and the creation.
os.makedirs(sweep_folder, exist_ok=True)
# copy the sweep script to the folder (presumably as a record of the
# configuration that produced the sweep)
shutil.copy2(os.path.realpath(__file__), sweep_folder)


def get_acronym(name):
    """Return the acronym of an underscore-separated name.

    The acronym is made of the first character of every word obtained by
    splitting *name* on ``'_'``, e.g. ``'inst_hid_dim'`` -> ``'ihd'``.

    Empty words (produced by an empty *name* or by leading, trailing or
    doubled underscores) are skipped rather than raising ``IndexError``.
    """
    return ''.join(word[0] for word in name.split('_') if word)


# create list of args for sweeping
# Each variable in turn multiplies the current list of configurations: for
# every value of the variable, every existing (job_name, args) pair is
# extended with that value.  Only variables with more than one value leave an
# "<acronym><value>" tag in the job name; single-valued variables just
# override the default.
job_names = ['']
sweep_args = [default_args]
for var, values in variables.items():
    tag_in_name = len(values) > 1
    new_job_names = []
    new_sweep_args = []
    for val in values:
        for job_name, args in zip(job_names, sweep_args):
            new_job_name = job_name
            if tag_in_name:
                if new_job_name:
                    new_job_name += '_'
                new_job_name += '%s%s' % (get_acronym(var), val)
            new_job_names.append(new_job_name)

            # A misspelt variable would otherwise be silently ignored by the
            # template substitution.  Raise explicitly: an `assert` is
            # stripped when Python runs with -O.
            if var not in args:
                raise KeyError('unknown sweep variable: %r' % (var,))
            new_args = args.copy()
            new_args[var] = val
            new_sweep_args.append(new_args)
    job_names = new_job_names
    sweep_args = new_sweep_args

# internal invariant: one name per configuration
assert len(job_names) == len(sweep_args)


# generate sweep files (train.sh, srun.sh) for each job
# `srun_files` collects the path of every submission script, in the order
# the jobs were generated.
srun_files = []
for i, (job_name, arg) in enumerate(zip(job_names, sweep_args)):
    # skip combinations where army_out_dim exceeds conv_hid_dim
    if arg['army_out_dim'] > arg['conv_hid_dim']:
        continue
    # When no variable has several values the generated name is empty, which
    # would leave `--job-name` without a value and make the job folder
    # collide with the sweep folder itself; fall back to an index-based name.
    if not job_name:
        job_name = 'config%d' % i
    print(job_name)
    print(arg)

    model_folder = os.path.join(sweep_folder, job_name)
    # exist_ok=True avoids the exists()-then-makedirs() race
    os.makedirs(model_folder, exist_ok=True)

    # train script
    arg['model_folder'] = model_folder
    train = train_template.substitute(arg)
    train_file = os.path.join(model_folder, 'train.sh')
    with open(train_file, 'w') as f:
        f.write(train)

    # srun script
    srun_arg = {
        'sweep_folder': sweep_folder,
        'job_name': job_name,
        'train_script': train_file,
    }
    srun = srun_template.substitute(srun_arg)
    srun_file = os.path.join(model_folder, 'srun.sh')
    with open(srun_file, 'w') as f:
        f.write(srun)
    srun_files.append(srun_file)


if __name__ == '__main__':
    # Command-line entry point.  The sweep files are generated by the
    # module-level code in any case; this part only decides whether the
    # generated submission scripts are actually executed.
    parser = argparse.ArgumentParser(description='analysis')
    parser.add_argument('--dry', action='store_true',
                        help='generate the sweep files but do not run them')
    parser.add_argument('--max', type=int, default=0,
                        help='run at most this many submission scripts '
                             '(<= 0 means no limit)')
    args = parser.parse_args()

    if not args.dry:
        launched = []
        for i, srun_file in enumerate(srun_files):
            if args.max <= 0 or i < args.max:
                # all scripts are started first so they run concurrently
                p = subprocess.Popen(["sh", srun_file], cwd=root)
                launched.append((srun_file, p))
            else:
                print('File not run: ', srun_file)

        # Wait for every child so none is left unreaped, and report the
        # scripts that exited with an error instead of ignoring them.
        for srun_file, p in launched:
            if p.wait() != 0:
                print('Submission failed (exit code %d): %s'
                      % (p.returncode, srun_file))
