import argparse
import os
import shlex
import subprocess

# Command-line interface for the sweep launcher.
#
# The nargs='+' options (num_ens, d_model, ffn_ratio, nlayers, nhead, lr, bsh)
# define the sweep: each takes either a single value shared by every job or one
# value per job. Options whose default is argparse.SUPPRESS are left out of the
# parsed namespace unless given on the command line, so they are forwarded to
# main.py only when explicitly set.
parser = argparse.ArgumentParser(description=
  '''
  This file writes a script to run the various main.py scripts in the examples directories.
  It then submits the script to the cluster.

  Example usage:

  python write_script.py --num_ens 1 2 --d_model 256 512 --ffn_ratio 1 2 --nlayers 2 3 --nhead 2 4 --lr 0.5 1 --optimizer musgd --coord_check --coord_check_nsteps 3 --coord_check_nseeds 3

  This will write a script to run the following experiments:

  python main.py --num_ens 1 --d_model 256 --ffn_ratio 1 --nlayers 2 --nhead 2 --lr 0.5 --optimizer musgd --coord_check --cuda
  python main.py --num_ens 1 --d_model 512 --ffn_ratio 2 --nlayers 3 --nhead 4 --lr 1 --optimizer musgd --coord_check --cuda

  If the base shapes for that d_model have not been created yet, then the script will also create them.
  ''', formatter_class=argparse.RawTextHelpFormatter)

# Model family; the launcher only has a recipe for "Transformer".
parser.add_argument('-t', type=str, default="Transformer")
parser.add_argument('--data', type=str, default='wikitext-103',
                    help='location of the data corpus')
parser.add_argument('--max_data', type=int, default=argparse.SUPPRESS,
                    help='maximum number of tokens in the dataset')
parser.add_argument('--bias', action='store_true', default=argparse.SUPPRESS,
                    help='use bias')
parser.add_argument('--save_base_shapes', type=str, default=argparse.SUPPRESS,
                    help='file location to save base shapes at')
parser.add_argument('--load_base_shapes', type=str, default=argparse.SUPPRESS,
                    help='file location to load base shapes from')
# --- Sweep (per-job) options: one value for all jobs, or one value per job. ---
parser.add_argument('--d_model', type=int, nargs='+', default=[256],
                    help='width of the model')
parser.add_argument('--ffn_ratio', type=int, nargs='+', default=[1],
                    help='the ratio of d_ffn to d_model')
parser.add_argument('--nlayers', type=int, nargs='+', default=[2],
                    help='number of layers')
parser.add_argument('--nhead', type=int, nargs='+', default=[2],
                    help='the number of heads in the encoder/decoder of the transformer model')
parser.add_argument('--lr', type=float, nargs='+', default=[0.5],
                    help='initial learning rate')
# Raw string: a plain '\s' is an invalid escape sequence (SyntaxWarning on
# Python 3.12+); the resulting help text is unchanged.
parser.add_argument('--lr_rescale', action='store_true', default=argparse.SUPPRESS,
                    help=r'rescale lr with \sqrt N')
parser.add_argument('--momentum', type=float, default=argparse.SUPPRESS,
                    help='momentum')
parser.add_argument('--output_mult', type=float, default=1,
                    help='output is multiplied by sqrt(output_mult/d_model)')
parser.add_argument('--input_mult', type=float, default=1,
                    help='input is multiplied by sqrt(input_mult*d_model)')
parser.add_argument('--attn_mult', type=float, default=1,
                    help='attn is multiplied by sqrt(attn_mult)/head_dim')
parser.add_argument('--optimizer', default='musgd', choices=['sgd', 'musgd', 'adam', 'muadam'])
parser.add_argument('--init_var', type=float, default=1,
                    help='weights are initialized with variance init_var/ninp')
parser.add_argument('--clip', type=float, default=argparse.SUPPRESS,
                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=argparse.SUPPRESS,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=32, metavar='N',
                    help='batch size')
parser.add_argument('--bptt', type=int, default=35,
                    help='sequence length')
parser.add_argument('--dropout', type=float, default=0,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--tied', action='store_true', default=argparse.SUPPRESS,
                    help='tie the word embedding and softmax weights')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
# NOTE: store_true with default=True means --cuda is always on; it cannot be
# switched off from the command line.
parser.add_argument('--cuda', action='store_true', default=True,
                    help='use CUDA')
parser.add_argument('--precision', type=str, default='float',
                    help='float | double | half')
parser.add_argument('--log_interval', type=int, default=200, metavar='N',
                    help='report interval')
parser.add_argument('--save_dir', type=str, default="", # Put your directory here
                    help='path to save the final model')
parser.add_argument('--resume_dir', type=str, default="", # Put your directory here
                    help='path to resume training')
parser.add_argument('--log_dir', type=str, default="", # Put your directory here
                    help='path to save logs')
parser.add_argument('--coord_check', action='store_true', default=argparse.SUPPRESS,
                    help='test μ parametrization is correctly implemented by collecting statistics on coordinate distributions for a few steps of training.')
parser.add_argument('--coord_check_nsteps', type=int, default=argparse.SUPPRESS,
                    help='Do coord check with this many steps.')
parser.add_argument('--coord_check_nseeds', type=int, default=argparse.SUPPRESS,
                    help='number of seeds for testing correctness of μ parametrization')
parser.add_argument('--num_ens', type=int, nargs='+', default=[1],
                    help='number of random inits to ensemble over')
parser.add_argument('--eval_point', action='store_true', default=argparse.SUPPRESS,
                    help='Evaluate on a held out validation point')
# Base width used when generating base shapes; also a sweep (per-job) option.
parser.add_argument('--bsh', type=int, nargs='+', default=[256],
                    help='Base d_model setting')

args = parser.parse_args()

# Python interpreter (path or name on PATH) that the generated job scripts run
# main.py with. Change this to match your cluster environment.
python_loc = "python"

print(args)

# Directory containing this script, with a trailing slash; generated scripts,
# logs and base shapes are located relative to it.
dir_path = os.path.dirname(os.path.realpath(__file__)) + '/'

# Per-job (sweep) arguments: each is a list holding either one value shared by
# every job or one value per job.
variable_args = "num_ens d_model ffn_ratio nlayers nhead lr bsh".split()

# The sweep has as many jobs as its longest per-job list. bsh is included so
# that a mismatched --bsh list is rejected here rather than silently truncating
# the zip that generates the jobs.
num_scripts = max(len(getattr(args, arg)) for arg in variable_args)

# Check that each of the lists is either length 1 or the same length as num_scripts.
for arg in variable_args:
  if len(getattr(args, arg)) not in (1, num_scripts):
    raise ValueError("Each of the lists of arguments must have either matching length or be length 1.")

# Broadcast length-1 lists so every per-job list has exactly num_scripts entries.
for arg in variable_args:
  if len(getattr(args, arg)) == 1:
    setattr(args, arg, getattr(args, arg) * num_scripts)

opt = args.optimizer

print("Number of scripts to run: ", num_scripts)
# Print one dictionary of per-job arguments for each script.
for i in range(num_scripts):
  print({arg: getattr(args, arg)[i] for arg in variable_args})


# Only the Transformer example has a launch recipe below. Reject anything else
# before any script is written or submitted (parser.error exits non-zero).
if args.t != "Transformer":
  parser.error(f"-t {args.t}: Not implemented")

# Options declared with action='store_true': forwarded bare, and only when set.
boolean_flags = "tied coord_check lr_rescale cuda bias eval_point".split()


def _build_flags(i, d_model_override=None):
  """Return the main.py command-line flags for job ``i`` as one string.

  Every parsed option is forwarded except ``t``, ``bsh`` and ``data``;
  ``data`` is emitted as ``--data ./data/<name>``. Per-job options use their
  ``i``-th value. If ``d_model_override`` is given it replaces the job's
  ``--d_model`` value (used for the base-shape generation run). Values are
  shell-quoted so that empty strings (e.g. the default ``--save_dir ""``)
  still reach main.py as an argument instead of breaking the command line.
  """
  parts = [f"--data {shlex.quote('./data/' + args.data)}"]
  for arg, value in vars(args).items():
    if arg in ("bsh", "t", "data"):
      continue
    if arg in boolean_flags:
      if value:
        parts.append(f"--{arg}")
      continue
    if arg in variable_args:
      value = value[i]
      if arg == "d_model" and d_model_override is not None:
        value = d_model_override
    parts.append(f"--{arg} {shlex.quote(str(value))}")
  return " ".join(parts)


# Neither open() nor SLURM's -o/-e create missing directories, so make them now.
os.makedirs(dir_path + "bash_scripts", exist_ok=True)
os.makedirs(dir_path + "out_files", exist_ok=True)

for i, (num_ens, d_model, ffn_ratio, nlayers, nhead, lr, bsh) in enumerate(zip(args.num_ens, args.d_model, args.ffn_ratio, args.nlayers, args.nhead, args.lr, args.bsh)):
  save_str = f'{args.t}_{args.data}_E={num_ens}_d={d_model}_nheads={nhead}_nlayers={nlayers}_ffnrat={ffn_ratio}_lr={lr}_om={args.output_mult}_drp={args.dropout}_opt={args.optimizer}_bs={args.batch_size}_bptt={args.bptt}_bsh={bsh}'

  flags = _build_flags(i)
  bash_file_name = dir_path + "bash_scripts/" + save_str + ".sh"
  # Widths above 1024 are scheduled on the 80 GB A100 variant, others on 40 GB.
  gpu_mem = 80 if d_model > 1024 else 40

  base_shape_dir = f"{dir_path}/examples/{args.t}/base_shapes/"
  base_shape_file = f"width={bsh}_nhead={nhead}_nlayers={nlayers}_ffnrat={ffn_ratio}.bsh"

  with open(bash_file_name, 'w') as rsh:
    rsh.write('''\
#!/bin/bash
#SBATCH -n 1                # Number of cores
#SBATCH -N 1                # Ensure that all cores are on one machine
#SBATCH -t 3-00:00          # Runtime in D-HH:MM, minimum of 10 minutes
#SBATCH -p kempner,seas_gpu,gpu,gpu_requeue	    # Partition to submit to
''')
    rsh.write(f"#SBATCH --gres=gpu:nvidia_a100-sxm4-{gpu_mem}gb:1\n")
    rsh.write("#SBATCH --constraint=a100\n")
    rsh.write(f"#SBATCH --mem-per-gpu={gpu_mem}G\n")
    # STDOUT / STDERR files and job name, all keyed by the run's save_str.
    rsh.write(f"#SBATCH -o {dir_path}/out_files/out_{save_str}.out\n")
    rsh.write(f"#SBATCH -e {dir_path}/out_files/err_{save_str}.err\n")
    rsh.write(f"#SBATCH --job-name={save_str}\n")
    rsh.write('''\


nvidia-smi

which python

''') # Modify appropriately to import libraries
    rsh.write(f"export PYTHONPATH=\"${{PYTHONPATH}}:{dir_path}\"\n\n")
    # NOTE(review): the existence check below looks in examples/<t>/base_shapes/,
    # but main.py is launched from this cd'd directory with a relative
    # base_shapes/ path; these presumably agree only when sbatch is invoked
    # from examples/<t>/. Modify appropriately.
    rsh.write("cd ./ \n")
    if not os.path.exists(base_shape_dir + base_shape_file):
      # Base shapes are built at the base width bsh, so run with d_model=bsh.
      base_flags = _build_flags(i, d_model_override=bsh)
      rsh.write(f"{python_loc} main.py --save_base_shapes base_shapes/{base_shape_file} {base_flags}\n")
    rsh.write(f"{python_loc} main.py --load_base_shapes base_shapes/{base_shape_file} {flags}\n")

  # Submit the job. The argument list avoids shell parsing of the path; a failed
  # or impossible submission is reported but does not stop the sweep.
  try:
    result = subprocess.run(["sbatch", bash_file_name])
  except FileNotFoundError:
    print(f"sbatch not found; wrote {bash_file_name} without submitting")
    continue
  if result.returncode != 0:
    print(f"sbatch failed for {bash_file_name} (exit code {result.returncode})")

