#!/bin/bash

# SLURM job parameters. Every XXX is a placeholder that must be filled in before submitting with sbatch.
#SBATCH --constraint=XXX
#SBATCH --cpus-per-task=XXX
#SBATCH --error=XXX
#SBATCH --gres=gpu:XXX
#SBATCH --job-name=XXX
#SBATCH --mem=XXX
#SBATCH --nodes=XXX
#SBATCH --ntasks-per-node=XXX
#SBATCH --open-mode=XXX
#SBATCH --output=XXX
#SBATCH --partition=XXX
#SBATCH --signal=XXX
#SBATCH --time=XXX
#SBATCH --mail-user=XXX
#SBATCH --mail-type=END,FAIL,REQUEUE,BEGIN

# Activate the Python environment. `source activate` is the legacy conda
# activation entry point, so conda's bin directory must be on PATH here.
source activate lab_vid

#######################################
# Print the first hostname contained in a SLURM nodelist.
# Arguments: $1 - nodelist, compressed (e.g. "gpu[01-04]") or plain
#                 (e.g. "a01,b02")
# Outputs:   the first hostname on stdout (empty line if $1 is empty)
#######################################
first_node_hostname() {
  local nodelist=$1
  local host=
  # Prefer SLURM's own hostlist expander: it understands every nodelist form.
  # Errors are discarded on purpose; an empty result triggers the fallback.
  if command -v scontrol >/dev/null 2>&1; then
    host=$(scontrol show hostnames "$nodelist" 2>/dev/null | head -n 1)
  fi
  if [[ -z $host ]]; then
    # Fallback parser for the common "prefix[idx...]" and "a,b" forms:
    # group 1 = text before the first '[', group 2 = first index inside it.
    local re='^([^],[]*)\[([^],-]*)'
    if [[ $nodelist =~ $re ]]; then
      host=${BASH_REMATCH[1]}${BASH_REMATCH[2]}
    else
      host=${nodelist%%,*}
    fi
  fi
  printf '%s\n' "$host"
}

# Rendezvous address/port; the first node of the allocation acts as master.
# NOTE(review): presumably read by torch.distributed (env:// init) inside
# sela.py. The former fixed-offset slicing
# ${SLURM_NODELIST:0:9}${SLURM_NODELIST:10:4} only worked for 9-character
# prefixes followed by 4-digit bracketed indices.
MASTER_ADDR=$(first_node_hostname "${SLURM_NODELIST:-}")
export MASTER_ADDR
export MASTER_PORT=19500

# Debugging aids (optional): verbose NCCL logging and Python tracebacks on
# fatal signals.
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Network interfaces for NCCL sockets: the leading "^" makes this an exclude
# list, keeping traffic off the docker bridge and loopback.
export NCCL_SOCKET_IFNAME=^docker0,lo
echo "$SLURMD_NODENAME $SLURM_JOB_ID $CUDA_VISIBLE_DEVICES"

# Output directory, forwarded to sela.py as --output-dir.
# Placeholder value: fill in before submitting.
SAV_FOLDER="XXX"
if ! mkdir -p -- "${SAV_FOLDER}"; then
  echo "error: cannot create output folder '${SAV_FOLDER}'" >&2
  exit 1
fi

# Optimisation / LR schedule.
BATCH_SIZE=16
# "MUM" is a typo for "NUM"; the name is kept because the srun command reads it.
MUM_EPOCHS=201
LR=1e-2
MOMENTUM=0.9
WEIGHT_DECAY=1e-5
USE_SCHEDULER='True'
SCHEDULER_TYPE='multi_step'
LR_GAMMA=0.1
LR_WARM_EPOCHS=10
# Milestone 251 lies past the last epoch (MUM_EPOCHS=201), so presumably the
# multi_step decay never fires. Earlier setting: '120,160,200'.
# NOTE(review): confirm this is intentional.
LR_MIL='251'

# Model.
MODEL='resnet18'
VID_BASE_ARCH='r2plus1d_18'
PRETRAINED='False'
MLPTYPE=1

# Video clips / data loading.
AUGTYPE=1
NUM_CLIPS=1
CLIP_LEN=8
CROP_SIZE=112
SAMPLE_RATE=1
NUM_WORKERS=10

# Batch norm (--warmup-bn / --sync-bn).
WARM_BN='False'
SYNC_BN='True'

# SK parameters (presumably Sinkhorn-Knopp; forwarded as --gpu-sk, --nopts,
# --stoch-sk-modality).
GPU_SK='True'
NOPTS=100
STOCHASTIC_BLOCK=0

# Audio augmentation / preprocessing.
AUG_AUDIO='False'
AUD_AUG_TYPE='heavy'
AUD_BASE_ARCH='resnet9'
AUD_SAMPLE_RATE=24000
AUD_SPEC_TYPE=2
AUD_VOLUME_JITTERING='True'
AUD_TEMPORAL_JITTERING='False'
AUD_NUM_SEC=1
AUD_Z_NORMALIZE='True'

#######################################
# Apply the optional positional overrides; a missing or empty argument keeps
# the default.
# Globals (written): USE_MLP_HEAD NUM_CLUSTERS HEADCOUNT MATCH DISTR NGROUPS
#                    IND_GROUPS DS NUM_DATA_SAMPLES
# Arguments (sela.py flag in parentheses):
#   $1 - USE_MLP_HEAD     (--use-mlp)           default True
#   $2 - NUM_CLUSTERS     (--num-clusters)      default 309
#   $3 - HEADCOUNT        (--headcount)         default 1
#   $4 - MATCH            (--match)             default True
#   $5 - DISTR            (--distribution)      default default
#   $6 - NGROUPS          (--groups)            default 1
#   $7 - IND_GROUPS       (--ind-groups)        default 1
#   $8 - DS               (--dataset)           default vggsound
#   $9 - NUM_DATA_SAMPLES (--num-data-samples)  default 170752
#######################################
parse_positional_args() {
  USE_MLP_HEAD=${1:-True}
  NUM_CLUSTERS=${2:-309}
  HEADCOUNT=${3:-1}
  MATCH=${4:-True}
  DISTR=${5:-default}
  NGROUPS=${6:-1}
  IND_GROUPS=${7:-1}
  DS=${8:-vggsound}
  NUM_DATA_SAMPLES=${9:-170752}
}
parse_positional_args "$@"



# Launch training with srun inside the job allocation; --label prefixes each
# output line with the id of the task that produced it.
# Arguments are gathered in an array so every value stays one quoted word
# (safe for paths with spaces/globs) and no trailing line continuation is
# needed.
sela_args=(
  --groups "${NGROUPS}"
  --output-dir "${SAV_FOLDER}"
  --distribution "${DISTR}"
  --dataset "${DS}"
  --batch-size "${BATCH_SIZE}"
  --epochs "${MUM_EPOCHS}"
  --lr "${LR}"
  --momentum "${MOMENTUM}"
  --weight-decay "${WEIGHT_DECAY}"
  --use-scheduler "${USE_SCHEDULER}"
  --scheduler-type "${SCHEDULER_TYPE}"
  --lr-gamma "${LR_GAMMA}"
  --lr-warmup-epochs "${LR_WARM_EPOCHS}"
  --lr-milestones "${LR_MIL}"
  --warmup-bn "${WARM_BN}"
  --sync-bn "${SYNC_BN}"
  --model "${MODEL}"
  --vid-base-arch "${VID_BASE_ARCH}"
  --aud-base-arch "${AUD_BASE_ARCH}"
  --pretrained "${PRETRAINED}"
  --mlptype "${MLPTYPE}"
  --augtype "${AUGTYPE}"
  --num-clusters "${NUM_CLUSTERS}"
  --clip-len "${CLIP_LEN}"
  --train-crop-size "${CROP_SIZE}"
  --sample-rate "${SAMPLE_RATE}"
  --clips-per-video "${NUM_CLIPS}"
  --workers "${NUM_WORKERS}"
  --use-mlp "${USE_MLP_HEAD}"
  --aug-audio "${AUG_AUDIO}"
  --stoch-sk-modality "${STOCHASTIC_BLOCK}"
  --ind-groups "${IND_GROUPS}"
  --headcount "${HEADCOUNT}"
  --audio-augtype "${AUD_AUG_TYPE}"
  --gpu-sk "${GPU_SK}"
  --nopts "${NOPTS}"
  --aud-sample-rate "${AUD_SAMPLE_RATE}"
  --aud-spec-type "${AUD_SPEC_TYPE}"
  --use-volume-jittering "${AUD_VOLUME_JITTERING}"
  --use-temporal-jittering "${AUD_TEMPORAL_JITTERING}"
  --num-sec "${AUD_NUM_SEC}"
  --z-normalize "${AUD_Z_NORMALIZE}"
  --num-data-samples "${NUM_DATA_SAMPLES}"
  --match "${MATCH}"
)
srun --label python3 sela.py "${sela_args[@]}"
