import os
import subprocess
import argparse
import time
import concurrent.futures
import random

def create_train_script_deepspeed(dset_name, num_gpu, num_node, gradient_accum_step, per_dev_batch, seed, lr, train_size, epoch, folder, ds_config,mem,gpu,partition,include, script_dir='/home/bo/cais/agent/fine_tune/train_script'):
    """Write a SLURM batch script that launches a DeepSpeed SFT run.

    The script is written to ``<script_dir>/tmp_<dset_name>_<seed>.sh``; the
    function does not submit it.

    Args:
        dset_name: Dataset name; used for ``--dataset``, the job/log names,
            the W&B run name, the cache path and the output directory.
        num_gpu: GPU count requested via ``--gres`` and used to compute the
            effective batch size that appears in the names.
        num_node: Accepted for signature parity with ``create_train_script``
            but not referenced by this template (``--nodes=1`` is fixed).
        gradient_accum_step: Value for ``--gradient_accumulation_steps``.
        per_dev_batch: Value for ``--per_device_train_batch_size``.
        seed: Training seed; also part of the script file name.
        lr: Learning rate, inserted verbatim.
        train_size: Number of training examples (int or numeric string).
        epoch: Value for ``--num_train_epochs``.
        folder: Sub-directory of ``/home/bo/cais/agent/data`` holding the data.
        ds_config: DeepSpeed config path passed to ``--deepspeed``.
        mem: SLURM ``--mem`` value, e.g. ``'1000G'``.
        gpu: GPU type used in ``--gres=gpu:<gpu>:<num_gpu>``.
        partition: SLURM partition name.
        include: Value forwarded to ``deepspeed --include``.
        script_dir: Directory the script is written to; created if missing.
            Defaults to the location used historically by this launcher.

    Returns:
        None.
    """
    batch_size = num_gpu * per_dev_batch * gradient_accum_step
    # Training-set size in thousands, used only to label the output directory.
    size_k = int(train_size) // 1000
    # Last digit of the master port (9980-9989), picked at random per script.
    port_digit = random.randint(0, 9)
    # NOTE: this is a regular (non-raw) string, so every backslash-newline
    # pair below is consumed by Python and the launch command is emitted as a
    # single line in the generated script.
    train_string = f'''#!/bin/bash

#SBATCH --job-name=cl_{dset_name}_{lr}_b{batch_size}s{seed}
#SBATCH --output=cl_{dset_name}_{lr}_b{batch_size}s{seed}.out
#SBATCH --error=cl_{dset_name}_{lr}_b{batch_size}s{seed}.err

#SBATCH --partition={partition}
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=2

#SBATCH --gres=gpu:{gpu}:{num_gpu}
#SBATCH --mem={mem}
#SBATCH --exclude=inst-0-35,babel-4-23,babel-5-23
#SBATCH --time=1-12:00:00
#SBATCH --mail-type=ALL
#SBATCH --mail-user=bo@andrew.cmu.edu


source ~/.bashrc
source /opt/rh/devtoolset-10/enable
conda activate dpsk

WANDB__SERVICE_WAIT=500 WANDB_PROJECT=agent WANDB_ENTITY=vl001 WANDB_NAME={dset_name}_b{batch_size}s{seed} deepspeed --hostfile hostfile --include {include} --master_port=998{port_digit} /home/bo/cais/deepseek_llamafactory/LLaMA-Factory/src/train_bash.py \
    --deepspeed {ds_config} \
    --stage sft \
    --model_name_or_path  /data/tir/projects/tir3/users/bo/ckpts/cllama/cllama/models--codellama--CodeLlama-7b-hf/snapshots/bc5283229e2fe411552f55c71657e97edf79066c \
    --do_train \
    --dataset {dset_name} \
    --train_size {train_size} \
    --shuffle False \
    --dataset_dir /home/bo/cais/agent/data/{folder}/ \
    --template llama2 \
    --finetuning_type full \
    --output_dir /data/tir/projects/tir3/users/bo/ckpts/output_{size_k}k_{dset_name}_{lr}_b{batch_size}s{seed}/ \
    --overwrite_output_dir True \
    --cache_path /scratch/bo/{dset_name}/ \
    --per_device_train_batch_size {per_dev_batch} \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps {gradient_accum_step} \
    --gradient_checkpointing True \
    --lr_scheduler_type cosine \
    --evaluation_strategy "steps" \
    --save_strategy "epoch" \
    --logging_steps "30" \
    --save_total_limit 1 \
    --preprocessing_num_workers 16 \
    --learning_rate {lr} \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --num_train_epochs {epoch} \
    --plot_loss \
    --bf16 True \
    --cutoff_len 4096 \
    --report_to 'wandb' \
    --flash_attn True \
    --save_only_model \
    --seed {seed}

echo "exit code: $?"

    '''

    # Create the target directory on demand so a fresh checkout does not fail
    # with FileNotFoundError before anything is written.
    os.makedirs(script_dir, exist_ok=True)
    script_path = os.path.join(script_dir, f'tmp_{dset_name}_{seed}.sh')
    with open(script_path, 'w', encoding='utf-8') as file:
        file.write(train_string)

    return

def create_train_script(dset_name, num_gpu,num_node, gradient_accum_step,per_dev_batch, seed, lr, train_size, epoch, folder, script_dir='/home/bo/cais/agent/fine_tune/train_script'):
    """Write a SLURM batch script that launches an accelerate/FSDP SFT run.

    The script is written to ``<script_dir>/tmp_<dset_name>_<seed>.sh``; the
    function does not submit it.

    Args:
        dset_name: Dataset name; used for ``--dataset``, the job/log names,
            the W&B run name, the cache path and the output directory.
        num_gpu: GPU count requested via ``--gres`` and passed to
            ``accelerate launch --num_processes``.
        num_node: Passed to ``accelerate launch --num_machines`` (the SLURM
            request itself is fixed at ``--nodes=1``).
        gradient_accum_step: Value for ``--gradient_accumulation_steps``.
        per_dev_batch: Value for ``--per_device_train_batch_size``.
        seed: Training seed; also part of the script file name.
        lr: Learning rate, inserted verbatim.
        train_size: Number of training examples (int or numeric string).
        epoch: Value for ``--num_train_epochs``.
        folder: Sub-directory of ``/home/bo/cais/agent/data`` holding the data.
        script_dir: Directory the script is written to; created if missing.
            Defaults to the location used historically by this launcher.

    Returns:
        None.
    """
    batch_size = num_gpu * per_dev_batch * gradient_accum_step
    # Training-set size in thousands for the output directory label. Taking
    # only the first digit of train_size would turn 28000 into "2k"; integer
    # division matches create_train_script_deepspeed.
    size_k = int(train_size) // 1000
    # Last digit of the main process port (9980-9989), picked at random.
    port_digit = random.randint(0, 9)
    # NOTE: this is a regular (non-raw) string, so every backslash-newline
    # pair below is consumed by Python and the launch command is emitted as a
    # single line in the generated script.
    train_string = f'''#!/bin/bash

#SBATCH --job-name=cl_{dset_name}_{lr}_b{batch_size}s{seed}
#SBATCH --output=cl_{dset_name}_{lr}_b{batch_size}s{seed}.out
#SBATCH --error=cl_{dset_name}_{lr}_b{batch_size}s{seed}.err

#SBATCH --partition=general
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=2
# SBATCH --gres=gpu:A100:8
#SBATCH --gres=gpu:A6000:{num_gpu}
# SBATCH --gres=gpu:A100_80GB:{num_gpu}
#SBATCH --mem=512G
#SBATCH --exclude=inst-0-35,babel-4-23,babel-5-23
#SBATCH --time=1-23:00:00
#SBATCH --mail-type=ALL
#SBATCH --mail-user=bo@andrew.cmu.edu


source ~/.bashrc
source /opt/rh/devtoolset-10/enable
conda activate dpsk

WANDB__SERVICE_WAIT=300 WANDB_PROJECT=agent WANDB_ENTITY=vl001 WANDB_NAME={dset_name}_b{batch_size}s{seed} accelerate launch --main_process_port 998{port_digit} --num_processes {num_gpu}  --num_machines {num_node} /home/bo/cais/deepseek_llamafactory/LLaMA-Factory/src/train_bash.py \
    --stage sft \
    --model_name_or_path /data/tir/projects/tir3/users/bo/ckpts/cllama/cllama/models--codellama--CodeLlama-7b-hf/snapshots/bc5283229e2fe411552f55c71657e97edf79066c \
    --do_train \
    --do_eval \
    --dataset {dset_name} \
    --train_size {train_size} \
    --shuffle False \
    --dataset_dir /home/bo/cais/agent/data/{folder}/ \
    --template llama2 \
    --finetuning_type full \
    --output_dir /data/tir/projects/tir3/users/bo/ckpts/output_{size_k}k_{dset_name}_{lr}_b{batch_size}s{seed}/ \
    --overwrite_output_dir True \
    --overwrite_cache \
    --cache_path /scratch/bo/{dset_name}/ \
    --per_device_train_batch_size {per_dev_batch} \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps {gradient_accum_step} \
    --gradient_checkpointing True \
    --lr_scheduler_type cosine \
    --evaluation_strategy "steps" \
    --save_strategy "epoch" \
    --logging_steps "30" \
    --save_total_limit 1 \
    --learning_rate {lr} \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --num_train_epochs {epoch} \
    --plot_loss \
    --bf16 True \
    --cutoff_len 4096 \
    --report_to 'wandb' \
    --save_only_model \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --flash_attn True \
    --tf32 True \
    --seed {seed}

    '''
    # Create the target directory on demand so a fresh checkout does not fail
    # with FileNotFoundError before anything is written.
    os.makedirs(script_dir, exist_ok=True)
    script_path = os.path.join(script_dir, f'tmp_{dset_name}_{seed}.sh')
    with open(script_path, 'w', encoding='utf-8') as file:
        file.write(train_string)

    return



# Datasets to train on, as [dataset name, number of training examples].
dset = [
    ['m2w_scrp_8k_hist_1kstop_allgen_v2_numlen', 8430]
]

# Settings shared by every submitted job; none of them depend on the seed or
# the dataset, so they are bound once instead of on every loop iteration.
num_gpu = 3
num_node = 1
partition = 'general'
include = 'babel-5-31:0,1,2'  # forwarded to `deepspeed --include`
gpu = 'A100_80GB'
per_dev_batch = 6
gradient_accum_step = 2
lr = '1e-5'
folder = 'code_scrp'
epoch = 5
mem = '1000G'
# Alternative DeepSpeed configs: 'ds_config_3.json', 'ds_z2_config.json'.
ds_config = 'ds_config_3_xofl.json'

# Generate and submit one job per (seed, dataset) combination.
for seed in [13307]:
    for dset_name, train_size in dset:
        # To launch with accelerate/FSDP instead of DeepSpeed, call
        # create_train_script(dset_name, num_gpu, num_node, gradient_accum_step,
        #                     per_dev_batch, seed, lr, train_size, epoch, folder).
        create_train_script_deepspeed(dset_name, num_gpu, num_node, gradient_accum_step, per_dev_batch, seed, lr, train_size, epoch, folder, ds_config,mem,gpu,partition,include)

        # Must match the path the script generators write to.
        script_path = f'/home/bo/cais/agent/fine_tune/train_script/tmp_{dset_name}_{seed}.sh'
        submission = subprocess.run(['sbatch', script_path])
        if submission.returncode != 0:
            # Report the failure but keep going so the remaining jobs are
            # still submitted.
            print(f'sbatch exited with code {submission.returncode} for {script_path}')


# 28k 8 A6000 batch 2 ds_config_3      (with offload) 6 epoch 50 hours
# 28k 8 A6000 batch 1 ds_config_3_xofl (no offload)   6 epoch 38 hours
# 28k 8 A6000 batch 1 ds_z2_config     (no offload)   6 epoch 36 hours

        


        