#!/bin/bash

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

. ./path.sh || exit 1;
. ./cmd.sh || exit 1;

# ---------------------------------------------------------------------------
# General configuration.
# Every variable defined before utils/parse_options.sh is sourced (below)
# may be overridden from the command line.
# ---------------------------------------------------------------------------
backend=pytorch
stage=0           # first stage to run; use 0 to include data preparation
stop_stage=100    # last stage to run
ngpu=1            # number of GPUs ("0" uses the CPU, otherwise GPUs are used)
debugmode=1
N=0               # number of minibatches to use (mainly for debugging); "0" uses all of them
verbose=0         # verbosity level
resume=""         # snapshot to resume training from
seed=1

# Configuration files
preprocess_config=conf/no_preprocess.yaml  # use conf/specaug.yaml for data augmentation
train_config=conf/train_multispkr_joint_conditional.yaml
lm_config=conf/lm.yaml
decode_config=conf/decode.yaml

# Multi-speaker ASR
num_spkrs=2       # number of speakers
use_spa=false     # speaker parallel attention

# RNNLM
use_wordlm=true     # true: word LM; false: character LM
lm_vocabsize=65000  # vocabulary size; only used for word LMs
lm_resume=""        # snapshot to resume LM training from
lmtag=""            # tag for managing LMs (derived from lm_config when empty)

# Decoding
n_average=10                # number of snapshots averaged for Transformer models
recog_model=model.acc.best  # model used for decoding: 'model.acc.best' or 'model.loss.best'

# Corpora and locations of generated data
wsj0=/export/corpora5/LDC/LDC93S6B
wsj1=/export/corpora5/LDC/LDC94S13B
wsj_full_wav="${PWD}/data/wsj0/wsj0_wav"
wsj_2mix_wav="${PWD}/data/wsj0_mix/2speakers"
wsj_3mix_wav="${PWD}/data/wsj0_mix/3speakers"
wsj_2mix_scripts="${PWD}/data/wsj0_mix/scripts"

# Experiment tag
tag=""  # tag for managing experiments (derived from the config names when empty)

. utils/parse_options.sh || exit 1;

# Strict mode: exit on errors (-e), on undefined variables (-u) and on a
# failure anywhere inside a pipeline (-o pipefail).
set -euo pipefail

# Data directory names under data/
train_set="tr"
train_dev="cv"
recog_set="tt"

if [ "${stage}" -le 0 ] && [ "${stop_stage}" -ge 0 ]; then
    ### Task dependent: adapt the data preparation below to your own setup.
    echo "stage 0: Data preparation"

    ## wsj0-2mix: download the mixing scripts, create the 2-speaker mixtures
    ## and prepare the data directories from ${wsj_2mix_wav}/wav8k/max.
    local/wsj0_create_mixture.sh "${wsj_2mix_scripts}" "${wsj0}" \
        "${wsj_full_wav}" "${wsj_2mix_wav}" || exit 1
    local/wsj0_2mix_data_prep.sh "${wsj_2mix_wav}/wav8k/max" \
        "${wsj_2mix_scripts}" "${wsj_full_wav}" || exit 1

    ## wsj0-3mix: the same procedure for the 3-speaker mixtures.
    local/wsj0_create_3mixture.sh "${wsj_2mix_scripts}" "${wsj0}" \
        "${wsj_full_wav}" "${wsj_3mix_wav}" || exit 1
    local/wsj0_3mix_data_prep.sh "${wsj_3mix_wav}/wav8k/max" \
        "${wsj_2mix_scripts}" "${wsj_full_wav}" || exit 1

    ## The WSJ corpus itself is prepared as well (Kaldi WSJ recipe) to provide
    ## the language information; the resulting directories go under data/wsj.
    local/wsj_data_prep.sh "${wsj0}"/??-{?,??}.? "${wsj1}"/??-{?,??}.?
    local/wsj_format_data.sh
    mkdir -p data/wsj
    mv data/{dev_dt_*,local,test_dev*,test_eval*,train_si284} data/wsj

    ## Alternative: WSJ mix, a larger two-speaker mixture corpus created from
    ## the WSJ corpus, used in
    ##   Seki H, Hori T, Watanabe S, et al. End-to-End Multi-Lingual Multi-Speaker Speech Recognition[J]. 2018. and
    ##   Chang X, Qian Y, Yu K, et al. End-to-End Monaural Multi-speaker ASR System without Pretraining[J]. 2019
    ## It requires the wsj_2mix corpus to be generated beforehand (see
    ## wsj0_mixture for details).
    # local/wsj_2mix_data_prep.sh ${wsj_2mix_wav}/wav16k/max ${wsj_2mix_scripts} || exit 1;
fi

# Prefix every line of a per-speaker scp file with a "<aaa>_<bbb>" tag built
# from the first three characters of the first and third "_"-separated fields
# of the utterance ID, then sort the result.
# NOTE(review): presumably those fields are the IDs of the first two source
# utterances of a mixture — confirm for other naming schemes.
# Arguments: $1 - input scp file. Outputs: the prefixed, sorted scp on stdout.
add_spk_prefix() {
    awk '{split($1, lst, "_"); spk=substr(lst[1],1,3)"_"substr(lst[3],1,3); print(spk"_"$0)}' \
        "$1" | sort
}

# Dump the wav files of one data directory into sound.hdf5 archives with
# local/dump_pcm.sh, then write the speaker-prefixed feats_spk<N>.scp files
# (N = 1..num_spkrs) into the data directory itself.
# Arguments: $1 - number of parallel jobs, $2 - data directory.
# Globals (read): train_cmd, num_spkrs.
dump_mixture_set() {
    local nj=$1
    local data_dir=$2
    local spk
    local/dump_pcm.sh --cmd "${train_cmd}" --nj "${nj}" --filetype "sound.hdf5" \
        --num-spkrs "${num_spkrs}" "${data_dir}"
    # Follow num_spkrs (as stage 4 does) instead of hard-coding two speakers.
    for ((spk = 1; spk <= num_spkrs; spk++)); do
        add_spk_prefix "${data_dir}/data/feats_spk${spk}.scp" \
            > "${data_dir}/feats_spk${spk}.scp"
    done
}

if [ "${stage}" -le 1 ] && [ "${stop_stage}" -ge 1 ]; then
    ### Task dependent. You have to design training and dev sets by yourself.
    ### But you can utilize Kaldi recipes in most cases
    echo "stage 1: Dump wav files into a HDF5 file started @ $(date)"
    dump_mixture_set 32 "data/${train_set}"
    dump_mixture_set 4 "data/${train_dev}"
    for rtask in ${recog_set}; do
        dump_mixture_set 4 "data/${rtask}"
    done
    echo "stage 1: Done @ $(date)"
fi

dict=data/lang_1char/${train_set}_units.txt
dict_wblank=data/lang_1char/${train_set}_units_wblank.txt
nlsyms=data/lang_1char/non_lang_syms.txt
wsj_train_set=wsj/train_si284
wsj_train_dev=wsj/test_dev93
wsj_train_test=wsj/test_eval92

echo "dictionary: ${dict}"

# Write <data_dir>/data.json for one multi-speaker data directory with
# local/data2json.sh, using the blank-augmented dictionary.
# Arguments: $1 - number of parallel jobs, $2 - data directory.
# Globals (read): train_cmd, nlsyms, num_spkrs, dict_wblank.
make_data_json() {
    local nj=$1
    local data_dir=$2
    local/data2json.sh --cmd "${train_cmd}" --nj "${nj}" --filetype sound.hdf5 \
        --feat "${data_dir}/feats.scp" --nlsyms "${nlsyms}" \
        --num-spkrs "${num_spkrs}" \
        "${data_dir}" "${dict_wblank}" > "${data_dir}/data.json"
}

if [ "${stage}" -le 2 ] && [ "${stop_stage}" -ge 2 ]; then
    echo "stage 2: Dictionary and Json Data Preparation started @ $(date)"
    ### Task dependent. You have to check non-linguistic symbols used in the corpus.
    ### The dictionary, the non-linguistic symbol list and the CTC alignments
    ### are NOT built here: the commented commands below are kept for
    ### reference, and ${nlsyms} and ${dict_wblank} are expected to exist.
    # mkdir -p data/lang_1char/

    # echo "make a non-linguistic symbol list"
    # cut -f 2- data/${wsj_train_set}/text | tr " " "\n" | sort | uniq | grep "<" > ${nlsyms}
    # cat ${nlsyms}

    # echo "make a dictionary"
    # echo "<unk> 1" > ${dict} # <unk> must be 1, 0 will be used for "blank" in CTC
    # text2token.py -s 1 -n 1 -l ${nlsyms} data/${wsj_train_set}/text | cut -f 2- -d" " | tr " " "\n" \
    # | sort | uniq | grep -v -e '^\s*$' | awk '{print $0 " " NR+1}' >> ${dict}
    # wc -l ${dict}
    # # add blank for dict, only use to convert CTC alignment into training units index format
    # sed '1 i <blank> 0' ${dict} > ${dict_wblank}

    # echo "make json files"
    # # assert CTC alignment file already be generated by multi-speaker scripts
    # required_file="ctc_alignment_spk1 ctc_alignment_spk2"
    # for sdata in ${train_set} ${train_dev}; do
    #     for f in $required_file; do
    #         if [ ! -f ../wsj0_mix_asr/data_wsj0_2mix/${sdata}/${f} ]; then
    #             echo "Can not find ../wsj0_mix_asr/data_wsj0_2mix/${sdata}/${f}"
    #             echo "Assert file ${f} is generated by multi-speaker scripts"
    #             echo "The detail of generating ${f} can be seen in ../wsj0_mix_asr/run.sh"
    #             exit 1;
    #         fi
    #         cp ../wsj0_mix_asr/data_wsj0_2mix/${sdata}/${f} data/${sdata}/${f}
    #     done
    # done

    make_data_json 32 "data/${train_set}"
    make_data_json 4 "data/${train_dev}"
    for rtask in ${recog_set}; do
        make_data_json 4 "data/${rtask}"
    done
    echo "stage 2: Done @ $(date)"
fi

# LM training (stage 3) takes about one day. If you just want to do end-to-end
# ASR without LM, you can skip it and remove the --rnnlm option in the
# recognition (stage 5).

# Print the default LM tag: the basename of the LM config without its
# extension, plus "_word<vocab size>" when a word-level LM is used.
# Arguments: $1 - LM config path, $2 - use_wordlm (true/false), $3 - vocab size.
default_lmtag() {
    local name
    name=$(basename "${1%.*}")
    if [ "$2" = true ]; then
        name=${name}_word$3
    fi
    printf '%s\n' "${name}"
}

# Quoted tests so that empty values or paths with spaces do not break them.
if [ -z "${lmtag}" ]; then
    lmtag=$(default_lmtag "${lm_config}" "${use_wordlm}" "${lm_vocabsize}")
fi
lmexpname=train_rnnlm_${backend}_${lmtag}
lmexpdir=exp/${lmexpname}
mkdir -p "${lmexpdir}"

if [ "${stage}" -le 3 ] && [ "${stop_stage}" -ge 3 ]; then
    echo "stage 3: LM Preparation started @ $(date)"

    # Pick the LM data directory and its vocabulary: a word list of size
    # lm_vocabsize (built below by text2vocabulary.py) for a word LM, or the
    # ASR character dictionary for a character LM.
    if [ "${use_wordlm}" = true ]; then
        lmdatadir=data/local/wordlm_train
        lmdict=${lmdatadir}/wordlist_${lm_vocabsize}.txt
    else
        lmdatadir=data/local/lm_train
        lmdict=${dict}
    fi
    mkdir -p "${lmdatadir}"

    if [ "${use_wordlm}" = true ]; then
        # Word LM: transcripts are used as plain text.
        cut -f 2- -d" " "data/${wsj_train_set}/text" > "${lmdatadir}/train_trans.txt"
        zcat "${wsj1}"/13-32.1/wsj1/doc/lng_modl/lm_train/np_data/{87,88,89}/*.z \
            | grep -v "<" | tr "[:lower:]" "[:upper:]" > "${lmdatadir}/train_others.txt"
        cut -f 2- -d" " "data/${wsj_train_dev}/text" > "${lmdatadir}/valid.txt"
        cut -f 2- -d" " "data/${wsj_train_test}/text" > "${lmdatadir}/test.txt"
        cat "${lmdatadir}/train_trans.txt" "${lmdatadir}/train_others.txt" > "${lmdatadir}/train.txt"
        text2vocabulary.py -s "${lm_vocabsize}" -o "${lmdict}" "${lmdatadir}/train.txt"
    else
        # Character LM: every text is passed through text2token.py first.
        text2token.py -s 1 -n 1 -l "${nlsyms}" "data/${wsj_train_set}/text" \
            | cut -f 2- -d" " > "${lmdatadir}/train_trans.txt"
        zcat "${wsj1}"/13-32.1/wsj1/doc/lng_modl/lm_train/np_data/{87,88,89}/*.z \
            | grep -v "<" | tr "[:lower:]" "[:upper:]" \
            | text2token.py -n 1 | cut -f 2- -d" " > "${lmdatadir}/train_others.txt"
        text2token.py -s 1 -n 1 -l "${nlsyms}" "data/${wsj_train_dev}/text" \
            | cut -f 2- -d" " > "${lmdatadir}/valid.txt"
        text2token.py -s 1 -n 1 -l "${nlsyms}" "data/${wsj_train_test}/text" \
            | cut -f 2- -d" " > "${lmdatadir}/test.txt"
        cat "${lmdatadir}/train_trans.txt" "${lmdatadir}/train_others.txt" > "${lmdatadir}/train.txt"
    fi

    # ${cuda_cmd} may carry its own options, so it stays unquoted; ${lm_resume}
    # also stays unquoted so that an empty value yields a bare --resume.
    ${cuda_cmd} --gpu "${ngpu}" "${lmexpdir}/train.log" \
        lm_train.py \
        --config "${lm_config}" \
        --ngpu "${ngpu}" \
        --backend "${backend}" \
        --verbose 1 \
        --outdir "${lmexpdir}" \
        --tensorboard-dir "tensorboard/${lmexpname}" \
        --train-label "${lmdatadir}/train.txt" \
        --valid-label "${lmdatadir}/valid.txt" \
        --test-label "${lmdatadir}/test.txt" \
        --resume ${lm_resume} \
        --dict "${lmdict}"
    echo "stage 3: Done @ $(date)"
fi

# Experiment name: derived from the training and preprocessing config names
# unless --tag was given.
if [ -z "${tag}" ]; then
    expname=${train_set}_${backend}_$(basename "${train_config%.*}")_$(basename "${preprocess_config%.*}")
else
    expname=${train_set}_${backend}_${tag}
fi
# spa is non-empty (which makes stage 4 pass --spa) only when use_spa is
# exactly "true". It is initialised explicitly, and use_spa is compared
# instead of executed as a command (an empty use_spa used to enable --spa).
spa=
if [ "${use_spa}" = true ]; then
    spa=true
fi
expdir=exp/${expname}
mkdir -p "${expdir}"

if [ "${stage}" -le 4 ] && [ "${stop_stage}" -ge 4 ]; then
    echo "stage 4: Network Training started @ $(date)"

    # Train the multi-speaker model with asr_separate_train.py on the JSON
    # files written in stage 2.
    # - ${cuda_cmd} may carry its own options, so it stays unquoted.
    # - ${resume} stays unquoted so that an empty value yields a bare --resume.
    # - ${spa:+--spa} adds --spa only when spa is non-empty.
    ${cuda_cmd} --gpu "${ngpu}" "${expdir}/train.log" \
        asr_separate_train.py \
        --config "${train_config}" \
        --ngpu "${ngpu}" \
        --backend "${backend}" \
        --outdir "${expdir}/results" \
        --tensorboard-dir "${expdir}/tensorboard" \
        --debugmode "${debugmode}" \
        --dict "${dict}" \
        --debugdir "${expdir}" \
        --minibatches "${N}" \
        --verbose "${verbose}" \
        --resume ${resume} \
        --seed "${seed}" \
        --train-json "data/${train_set}/data.json" \
        --valid-json "data/${train_dev}/data.json" \
        --num-spkrs "${num_spkrs}" \
        ${spa:+--spa}
    echo "stage 4: Done @ $(date)"
fi

if [ "${stage}" -le 5 ] && [ "${stop_stage}" -ge 5 ]; then
    echo "stage 5: Decoding started @ $(date)"
    nj=32
    # Decode with ${recog_model} as configured at the top of the script (or via
    # --recog_model). This stage used to overwrite it with a hard-coded
    # snapshot.ep.17, silently ignoring that option; pass
    # --recog_model snapshot.ep.17 to reproduce the old setup.
    # Transformer models are instead decoded with an average of ${n_average}
    # snapshots written by average_checkpoints.py.
    if [[ $(get_yaml.py "${train_config}" model-module) = *transformer* ]]; then
        recog_model=model.last${n_average}.avg.best
        average_checkpoints.py --backend "${backend}" \
                               --snapshots "${expdir}"/results/snapshot.ep.* \
                               --out "${expdir}/results/${recog_model}" \
                               --num "${n_average}"
    fi

    pids=() # PIDs of the background decoding jobs, one per recognition set
    for rtask in ${recog_set}; do
    (
        decode_dir=decode_${rtask}_$(basename "${decode_config%.*}")_lm
        if [ "${use_wordlm}" = true ]; then
            recog_opts="--word-rnnlm ${lmexpdir}/rnnlm.model.best"
        else
            recog_opts="--rnnlm ${lmexpdir}/rnnlm.model.best"
        fi

        # split data
        splitjson.py --parts "${nj}" "data/${rtask}/data.json"

        #### use CPU for decoding (only affects this subshell)
        ngpu=0

        # NOTE(review): training passes --num-spkrs; --num-spkr here presumably
        # relies on argparse prefix matching — confirm against asr_separate_recog.py.
        # ${recog_opts} is intentionally unquoted so it splits into option + path.
        ${decode_cmd} JOB=1:${nj} "${expdir}/${decode_dir}/log/decode.JOB.log" \
            asr_separate_recog.py \
            --config "${decode_config}" \
            --ngpu "${ngpu}" \
            --num-spkr "${num_spkrs}" \
            --backend "${backend}" \
            --recog-json "data/${rtask}/split${nj}utt/data.JOB.json" \
            --result-label "${expdir}/${decode_dir}/data.JOB.json" \
            --separate-dir "${expdir}/${decode_dir}/wav" \
            --model "${expdir}/results/${recog_model}" \
            ${recog_opts}

        # Score with the configured number of speakers instead of a hard-coded 2.
        score_sclite.sh --wer true --nlsyms "${nlsyms}" --num_spkrs "${num_spkrs}" \
            "${expdir}/${decode_dir}" "${dict}"

    ) &
    pids+=($!) # store background pids
    done
    # Wait for every job and count failures; any failure aborts the script.
    i=0; for pid in "${pids[@]}"; do wait "${pid}" || ((++i)); done
    [ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." >&2 && false
    echo "stage 5: Done @ $(date)"
fi
