import os
import subprocess
import itertools
import argparse
import torch

def create_script(args):
    """Write an sbatch job script for a single training run.

    The script requests one node with ``args.cpu`` CPUs, ``args.mem`` GB of
    memory, ``args.gpu`` GPUs and ``args.time`` hours of wall time, then runs
    ``python main.py`` with the training options taken from ``args``.

    Args:
        args: namespace providing ``mem``, ``time``, ``gpu``, ``cpu``,
            ``seed``, ``dataset``, ``model``, ``lr``, ``epochs`` and
            ``save_dir``. An optional ``gpu_type`` attribute selects the GPU
            model in the ``--gres`` request (defaults to ``'k80'``).

    Returns:
        Path of the written script, ``'<save_dir>.sbatch'`` (relative to the
        current working directory when ``save_dir`` is relative).
    """
    import shlex

    gpu_type = getattr(args, 'gpu_type', 'k80')
    header = '\n'.join([
        '#!/bin/bash',
        '',
        '#SBATCH --job-name={}'.format(args.save_dir),
        '#SBATCH --nodes=1',
        '#SBATCH --cpus-per-task={}'.format(args.cpu),
        '#SBATCH --time={}:00:00'.format(args.time),
        '#SBATCH --mem={}GB'.format(args.mem),
        '#SBATCH --gres=gpu:{}:{}'.format(gpu_type, args.gpu),
        '',
        '#PRINCE PRINCE_GPU_MPS=YES',  # cluster-specific directive, kept verbatim
        '',
        'module purge',
    ])

    options = [
        ('--seed', args.seed),
        ('--dataset', args.dataset),
        ('--model', args.model),
        ('--lr', args.lr),
        ('--epochs', args.epochs),
        ('--save_dir', args.save_dir),
    ]
    # Quote each value so names/paths containing spaces or shell
    # metacharacters cannot break (or inject into) the generated command.
    cmd = 'python main.py ' + ' '.join(
        '{} {}'.format(flag, shlex.quote(str(value))) for flag, value in options
    )

    # The run is started in the background with '&'; 'wait' blocks until it
    # finishes. Without it the batch script would exit right after launching
    # the run.
    script = header + '\n\n' + cmd + ' &\n\nwait\n'

    file_path = '{}.sbatch'.format(args.save_dir)
    with open(file_path, 'w') as f:
        f.write(script)

    return file_path

def copy_py(dst_folder, src_folder='.'):
    """Copy every ``.py`` file found in ``src_folder`` into ``dst_folder``.

    Args:
        dst_folder: destination directory; it must already exist. If it is
            not an existing directory, a message is printed and nothing is
            copied.
        src_folder: directory to copy from; defaults to the current working
            directory.

    Returns:
        None.
    """
    import shutil

    # isdir (not exists): a regular file at dst_folder must not be treated as
    # a destination, otherwise every .py would be copied over that file.
    if not os.path.isdir(dst_folder):
        print("Folder doesn't exist!")
        return
    for name in os.listdir(src_folder):
        src = os.path.join(src_folder, name)
        # Skip directories (or other non-files) whose names end in '.py';
        # shutil.copy2 cannot copy them.
        if name.endswith('.py') and os.path.isfile(src):
            shutil.copy2(src, dst_folder)

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='skinnyprime', type=str, choices=['skinnyprime', 'alexnetprime'])
    parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'cifar100'])
    parser.add_argument('--seed', default=2, type=int)
    parser.add_argument('--gpu', default=2, type=int)
    parser.add_argument('--mem', default=10, type=int, help='in GB')
    parser.add_argument('--time', default=6, type=int, help='in hours')
    parser.add_argument('--cpu', default=2, type=int, help='might get more for multiple jobs per GPU')
    args = parser.parse_args()

    # Copy the current .py sources into the results folder, then generate and
    # submit the job script from inside it. The folder must already exist:
    # fail with a clear message rather than a bare chdir error.
    results_dir = 'results_{}_{}'.format(args.dataset, args.model)
    if not os.path.isdir(results_dir):
        raise SystemExit("Results folder '{}' doesn't exist; create it first.".format(results_dir))
    copy_py(results_dir)
    os.chdir(results_dir)

    args.lr = 0.1
    args.epochs = 150
    # One run directory per model (the bare name `model` is undefined here).
    args.save_dir = '{}'.format(args.model)
    os.makedirs(args.save_dir, exist_ok=True)

    file_path = create_script(args)
    # Submit the job and echo whatever sbatch prints on stdout.
    result = subprocess.run(['sbatch', file_path], stdout=subprocess.PIPE)
    print(result.stdout.decode('utf-8'))
