#!/bin/bash

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

# Load the recipe environment; abort if either file cannot be sourced successfully.
# NOTE(review): ${train_cmd} and ${cuda_cmd}, used by the stages below, are not defined
# in this script and presumably come from cmd.sh — confirm.
. ./path.sh || exit 1;
. ./cmd.sh || exit 1;

# general configuration
# NOTE(review): the defaults in this section are presumably overridable from the command
# line as "--name value" through utils/parse_options.sh (sourced at the end of this
# section) — confirm against that script.
backend=pytorch
stage=0        # first stage to run (0 = data preparation)
stop_stage=100 # last stage to run (inclusive)
ngpu=1         # number of gpus ("0" uses cpu, otherwise use gpu)
debugmode=1
N=0            # number of minibatches to be used (mainly for debugging). "0" uses all minibatches.
verbose=0     # verbosity level passed to the training script
resume=        # snapshot to resume training from (empty = start from scratch)
seed=1

# config files
preprocess_config=conf/no_preprocess.yaml  # use conf/specaug.yaml for data augmentation; in this script only its basename is used (experiment name)
train_config=conf/train_multispkr_tasnet_conditional.yaml
lm_config=conf/lm.yaml          # not referenced by any stage in this script
decode_config=conf/decode.yaml  # in this script only its basename is used (decode directory name)

# multi-speaker asr related
num_spkrs=2         # number of speakers
use_spa=false       # speaker parallel attention

# rnnlm related
# (none of the four settings below is referenced by any stage in this script)
use_wordlm=true     # false means to train/use a character LM
lm_vocabsize=65000  # effective only for word LMs
lm_resume=          # specify a snapshot file to resume LM training
lmtag=              # tag for managing LMs

# decoding parameter
n_average=10                # number of snapshots to average before decoding
recog_model=model.loss.best # set a model to be used for decoding: 'model.acc.best' or 'model.loss.best'; NOTE: stage 4 overwrites this with the averaged model name

# data
wsj0=/export/corpora5/LDC/LDC93S6B
wsj1=/export/corpora5/LDC/LDC94S13B           # not referenced by any stage in this script
wsj_full_wav=$PWD/data/wsj0/wsj0_wav
wsj_2mix_wav=$PWD/data/wsj0_mix/2speakers     # output directory for the 2-speaker mixtures
wsj_3mix_wav=$PWD/data/wsj0_mix/3speakers     # output directory for the 3-speaker mixtures
wsj_2mix_scripts=$PWD/data/wsj0_mix/scripts   # mixture-creation scripts (also passed to the 3-speaker steps)

# exp tag
tag="" # tag for managing experiments; when empty the experiment name is derived from the config file names

. utils/parse_options.sh || exit 1;

# Strict mode: abort on a failing command (-e), on expansion of an unset
# variable (-u) and when any stage of a pipeline fails (-o pipefail).
# Command tracing (-x) is NOT enabled here.
set -euo pipefail

# Names of the data-set directories under data/ used by the stages below.
train_set=tr
train_dev=cv
recog_set=tt

# Stage 0: create the WSJ0 2- and 3-speaker mixtures and run the matching data
# preparation scripts. All path arguments are quoted because the defaults are derived
# from $PWD and would otherwise be word-split if the working directory contains whitespace.
if [ "${stage}" -le 0 ] && [ "${stop_stage}" -ge 0 ]; then
    ### Task dependent. You have to make data the following preparation part by yourself.
    echo "stage 0: Data preparation"
    ### This part is for WSJ0 mix
    ### Download mixture scripts and create mixtures for 2 speakers
    local/wsj0_create_mixture.sh "${wsj_2mix_scripts}" "${wsj0}" "${wsj_full_wav}" \
        "${wsj_2mix_wav}" || exit 1;
    local/wsj0_2mix_data_prep.sh "${wsj_2mix_wav}/wav8k/max" "${wsj_2mix_scripts}" \
        "${wsj_full_wav}" || exit 1;

    ### Same two steps for the 3-speaker mixtures.
    ### NOTE(review): these calls receive ${wsj_2mix_scripts} as their scripts directory,
    ### exactly as before — confirm that sharing it with the 2-speaker steps is intended.
    local/wsj0_create_3mixture.sh "${wsj_2mix_scripts}" "${wsj0}" "${wsj_full_wav}" \
        "${wsj_3mix_wav}" || exit 1;
    local/wsj0_3mix_data_prep.sh "${wsj_3mix_wav}/wav8k/max" "${wsj_2mix_scripts}" \
        "${wsj_full_wav}" || exit 1;
fi

# Prefix every line of an scp file with a speaker-pair tag and sort the result.
#   $1: input scp file
#   $2: output scp file
# The tag is built from the first column (split on "_"): the first three characters
# of its 1st field, "_", and the first three characters of its 3rd field. For example
# "011a_0.5_420b_-0.5 x" becomes "011_420_011a_0.5_420b_-0.5 x".
add_spk_prefix() {
    awk '{split($1, lst, "_"); spk=substr(lst[1],1,3)"_"substr(lst[3],1,3); print(spk"_"$0)}' "$1" \
        | sort > "$2"
}

# Run local/dump_pcm.sh on one data set, then write the speaker-prefixed scp files
# for both speakers next to it.
#   $1: number of parallel jobs
#   $2: data set name (directory under data/)
dump_pcm_set() {
    local nj=$1 dset=$2 spk
    local/dump_pcm.sh --cmd "${train_cmd}" --nj "${nj}" --filetype "sound.hdf5" --num-spkrs 2 "data/${dset}"
    # The speaker count stays fixed at 2 here to match the --num-spkrs 2 passed above.
    for spk in 1 2; do
        add_spk_prefix "data/${dset}/data/feats_spk${spk}.scp" "data/${dset}/feats_spk${spk}.scp"
    done
}

if [ "${stage}" -le 1 ] && [ "${stop_stage}" -ge 1 ]; then
    ### Task dependent. You have to design training and dev sets by yourself.
    ### But you can utilize Kaldi recipes in most cases
    echo "stage 1: Dump wav files into a HDF5 file started @ $(date)"
    dump_pcm_set 32 "${train_set}"
    dump_pcm_set 4 "${train_dev}"
    # recog_set is a space-separated list, so it is deliberately left unquoted.
    for rtask in ${recog_set}; do
        dump_pcm_set 4 "${rtask}"
    done
    echo "stage 1: Done @ $(date)"
fi

# Locations of the character dictionary, its variant with a leading <blank> entry,
# the non-linguistic symbol list, and the WSJ data sets (relative to data/).
dict=data/lang_1char/${train_set}_units.txt
dict_wblank=data/lang_1char/${train_set}_units_wblank.txt
nlsyms=data/lang_1char/non_lang_syms.txt
wsj_train_set=wsj/train_si284
wsj_train_dev=wsj/test_dev93
wsj_train_test=wsj/test_eval92

echo "dictionary: ${dict}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    ### Task dependent. You have to check non-linguistic symbols used in the corpus.
    echo "stage 2: Dictionary and Json Data Preparation started @ $(date)"
    mkdir -p data/lang_1char/

    echo "make a non-linguistic symbol list"
    # One token per line, de-duplicated, keeping only tokens that contain "<".
    cut -f 2- data/${wsj_train_set}/text \
        | tr " " "\n" \
        | sort \
        | uniq \
        | grep "<" > ${nlsyms}
    cat ${nlsyms}

    echo "make a dictionary"
    echo "<unk> 1" > ${dict} # <unk> must be 1, 0 will be used for "blank" in CTC
    # Remaining units are numbered from 2 upwards (NR+1) in sorted order.
    text2token.py -s 1 -n 1 -l ${nlsyms} data/${wsj_train_set}/text \
        | cut -f 2- -d" " \
        | tr " " "\n" \
        | sort \
        | uniq \
        | grep -v -e '^\s*$' \
        | awk '{print $0 " " NR+1}' >> ${dict}
    wc -l ${dict}
    # add blank for dict, only use to convert CTC alignment into training units index format
    sed '1 i <blank> 0' ${dict} > ${dict_wblank}

    echo "make json files"
    # Write data/<set>/data.json for one data set.
    #   $1: number of parallel jobs, $2: data set name (directory under data/)
    make_json() {
        local/data2json.sh --cmd "${train_cmd}" --nj $1 --filetype sound.hdf5 \
            --feat data/$2/feats.scp --nlsyms ${nlsyms} --num-spkrs 2 \
            data/$2 ${dict} > data/$2/data.json
    }
    make_json 32 ${train_set}
    make_json 4 ${train_dev}
    for rtask in ${recog_set}; do
        make_json 4 ${rtask}
    done
    echo "stage 2: Done @ $(date)"
fi

# Derive the experiment directory. Without a tag the name is built from the training
# set, the backend and the basenames (extension stripped) of the training and
# preprocessing configs; with a tag, the tag replaces the config-derived part.
# Expansions are quoted so that a tag or config path containing whitespace or glob
# characters is handled as a single word by test, basename and mkdir.
if [ -z "${tag}" ]; then
    expname=${train_set}_${backend}_$(basename "${train_config%.*}")_$(basename "${preprocess_config%.*}")
else
    expname=${train_set}_${backend}_${tag}
fi
expdir=exp/${expname}
mkdir -p "${expdir}"

# Print "--spa" when speaker parallel attention is requested; print nothing otherwise.
# The switch is the use_spa option ("true" enables it). A non-empty ${spa} is still
# honoured because the previous command line keyed on that (never assigned) variable.
spa_train_opt() {
    if [ "${use_spa:-false}" = true ] || [ -n "${spa:-}" ]; then
        echo "--spa"
    fi
}

# Stage 3: launch separate_train.py through ${cuda_cmd}, logging to ${expdir}/train.log.
if [ "${stage}" -le 3 ] && [ "${stop_stage}" -ge 3 ]; then
    echo "stage 3: Network Training started @ $(date)"

    # ${cuda_cmd} is left unquoted because it may hold a command plus its own options.
    # ${resume} is left unquoted on purpose: when empty, "--resume" is passed without a
    # value, exactly as before (NOTE(review): confirm separate_train.py accepts that).
    # $(spa_train_opt) is unquoted so that an empty result adds no argument at all.
    ${cuda_cmd} --gpu "${ngpu}" "${expdir}/train.log" \
        separate_train.py \
        --config "${train_config}" \
        --ngpu "${ngpu}" \
        --backend "${backend}" \
        --outdir "${expdir}/results" \
        --tensorboard-dir "${expdir}/tensorboard" \
        --debugmode "${debugmode}" \
        --dict "${dict}" \
        --debugdir "${expdir}" \
        --minibatches "${N}" \
        --verbose "${verbose}" \
        --resume ${resume} \
        --seed "${seed}" \
        --train-json "data/${train_set}/data.json" \
        --valid-json "data/${train_dev}/data.json" \
        --num-spkrs "${num_spkrs}" \
        $(spa_train_opt)
    echo "stage 3: Done @ $(date)"
fi

recog_set="tt"
# Stage 4: average training snapshots into one model, then run separate_recog.py
# on every set listed in recog_set.
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "stage 4: Decoding"

    # The averaged model name replaces whatever recog_model was configured to.
    recog_model=model.last${n_average}.avg.best
    results_dir=${expdir}/results
    average_checkpoints.py \
        --backend ${backend} \
        --snapshots ${results_dir}/snapshot.ep.* \
        --out ${results_dir}/${recog_model} \
        --num ${n_average}

    # The decode config contributes only its basename (extension stripped) to the
    # output directory name.
    decode_tag=$(basename ${decode_config%.*})
    for rtask in ${recog_set}; do
        decode_dir=decode_${rtask}_${decode_tag}
        out_dir=${expdir}/${decode_dir}
        mkdir -p ${out_dir}
        ${cuda_cmd} --gpu ${ngpu} ${out_dir}/decode.log \
            separate_recog.py \
            --ngpu 1 \
            --backend ${backend} \
            --recog-json data/${rtask}/data.json \
            --decode-dir ${out_dir} \
            --model ${results_dir}/${recog_model}
    done
    echo "Finished"
fi
