#!/bin/bash

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

. ./path.sh || exit 1;
. ./cmd.sh || exit 1;

# general configuration
backend=pytorch
stage=0        # start from 0 if you need to start from data preparation
ngpu=1         # number of gpus ("0" uses cpu, otherwise use gpu)
debugmode=1
dumpdir=dump   # directory to dump full features
N=0            # number of minibatches to be used (mainly for debugging). "0" uses all minibatches.
verbose=0      # verbose option
resume=        # Resume the training from snapshot
seed=1

# configuration path
preprocess_config=conf/preprocess.json
train_config=conf/train.yaml
lm_config=conf/lm.yaml
decode_config=conf/decode.yaml

# feature configuration
do_delta=true

# network architecture
num_spkrs=2
use_spa=false
# frontend related
use_beamformer=True
blayers=3
bunits=512
bprojs=512
bnmask=3

# rnnlm related
use_wordlm=true     # false means to train/use a character LM
lm_vocabsize=65000  # effective only for word LMs
lm_resume=          # specify a snapshot file to resume LM training
lmtag=              # tag for managing LMs

# decoding parameter
recog_model=model.acc.best # set a model to be used for decoding: 'model.acc.best' or 'model.loss.best'

# data
wsj0=/export/corpora5/LDC/LDC93S6B
wsj1=/export/corpora5/LDC/LDC94S13B
wsj_full_wav=$PWD/data/wsj0/wsj0_wav
wsj_2mix_wav=$PWD/data/wsj0_mix/2speakers
wsj_2mix_scripts=$PWD/data/wsj0_mix/scripts

# cmvn
stats_file=

# exp tag
tag="" # tag for managing experiments.

. utils/parse_options.sh || exit 1;

# Strict mode: -e exits on an error, -u on an undefined variable,
# -o pipefail on an error anywhere in a pipeline.
set -e
set -u
set -o pipefail

environ="anechoic"
train_set="tr_spatialized_${environ}"
train_dev="cv_spatialized_${environ}"
train_test="tt_spatialized_${environ}"
# Sets dumped (stage 1) and converted to JSON (stage 2) for recognition.
# Stage 5 decodes the tt set, so it has to be prepared here; the cv set is
# already handled explicitly in stages 1 and 2 via ${train_dev}.
recog_set="${train_test}"

num_channels=2

if [ ${stage} -le 0 ]; then
    ### Task dependent: the data preparation below has to be adapted to your setup.
    echo "stage 0: Data preparation"

    ### WSJ0-2mix: fetch the mixture scripts, simulate the 2-speaker mixtures
    ### and build the corresponding data directories.
    local/wsj0_create_mixture.sh \
        ${wsj_2mix_scripts} ${wsj0} ${wsj_full_wav} ${wsj_2mix_wav} || exit 1;
    local/wsj0_2mix_data_prep.sh \
        ${wsj_2mix_wav}/wav16k/max ${wsj_2mix_scripts} ${wsj_full_wav} || exit 1;

    ### The WSJ corpus itself is needed for the language information
    ### (this part follows the Kaldi WSJ recipe). The resulting data
    ### directories are then moved under data/wsj.
    local/wsj_data_prep.sh ${wsj0}/??-{?,??}.? ${wsj1}/??-{?,??}.?
    local/wsj_format_data.sh
    mkdir -p data/wsj
    mv data/dev_dt_* data/local data/test_dev* data/test_eval* data/train_si284 data/wsj

    ### Alternative: WSJ-2mix, a larger two-speaker mixture corpus built from WSJ, as used in
    ### Seki H, Hori T, Watanabe S, et al. End-to-End Multi-Lingual Multi-Speaker Speech Recognition[J]. 2018. and
    ### Chang X, Qian Y, Yu K, et al. End-to-End Monaural Multi-speaker ASR System without Pretraining[J]. 2019
    ### It assumes the wsj_2mix corpus has already been generated (see wsj0_mixture for details).
    #local/wsj_2mix_data_prep.sh ${wsj_2mix_wav}/wav16k/max ${wsj_2mix_scripts} || exit 1;
fi

feat_tr_dir=${dumpdir}/${environ}/${train_set}_${num_channels}ch; mkdir -p ${feat_tr_dir}
feat_dt_dir=${dumpdir}/${environ}/${train_dev}_${num_channels}ch; mkdir -p ${feat_dt_dir}
if [ ${stage} -le 1 ]; then
    ### Task dependent. You have to design training and dev sets by yourself.
    ### But you can utilize Kaldi recipes in most cases
    echo "stage 1: Dump wav files into HDF5 format"

    # Drops the ".CH<digit>" part of an ID so that all channels of one
    # utterance share the same ID (single-digit channel numbers only).
    strip_ch='s/^(.*?).CH[0-9](_?.*?) /\1\2 /g'

    # Build a ${num_channels}-channel data directory for every set.
    for setname in ${train_set} ${train_dev} ${train_test}; do
        srcdir=data/${environ}/${setname}
        mcdir=data/${environ}/${setname}_${num_channels}ch
        mkdir -p ${mcdir}

        <${srcdir}/utt2spk sed -r "${strip_ch}" | sort -u >${mcdir}/utt2spk
        for ((spk = 1; spk <= num_spkrs; spk++)); do
            <${srcdir}/text_spk${spk} sed -r "${strip_ch}" | sort -u >${mcdir}/text_spk${spk}
        done
        <${mcdir}/utt2spk utils/utt2spk_to_spk2utt.pl >${mcdir}/spk2utt

        # One wav.scp per channel, merged into a single multichannel wav.scp.
        # The files are passed explicitly (not via a glob) so that channels
        # stay in numeric order (a glob would sort wav_ch10 before wav_ch2).
        ch_scps=()
        for ((ch = 1; ch <= num_channels; ch++)); do
            <${srcdir}/wav.scp grep "CH${ch}" | sed -r "${strip_ch}" | sort -u >${mcdir}/wav_ch${ch}.scp
            ch_scps+=(${mcdir}/wav_ch${ch}.scp)
        done
        mix-mono-wav-scp.py "${ch_scps[@]}" >${mcdir}/wav.scp
        rm -f "${ch_scps[@]}"
    done

    dump_pcm.sh --nj 32 --cmd "$train_cmd" --filetype "sound.hdf5" --format flac data/${environ}/${train_set}_${num_channels}ch
    dump_pcm.sh --nj 4 --cmd "$train_cmd" --filetype "sound.hdf5" --format flac data/${environ}/${train_dev}_${num_channels}ch
    for rtask in ${recog_set}; do
        dump_pcm.sh --nj 4 --cmd "$train_cmd" --filetype "sound.hdf5" --format flac data/${environ}/${rtask}_${num_channels}ch
    done
fi

dict=data/lang_1char/${train_set}_units.txt
nlsyms=data/lang_1char/non_lang_syms.txt
wsj_train_set=train_si284
wsj_train_dev=test_dev93
wsj_train_test=test_eval92

echo "dictionary: ${dict}"
if [ ${stage} -le 2 ]; then
    ### Task dependent. You have to check non-linguistic symbols used in the corpus.
    echo "stage 2: Dictionary and Json Data Preparation"
    mkdir -p data/lang_1char/

    # Stage 0 moves the WSJ data directories under data/wsj.
    wsj_train_text=data/wsj/${wsj_train_set}/text

    echo "make a non-linguistic symbol list"
    cut -f 2- ${wsj_train_text} | tr " " "\n" | sort | uniq | grep "<" > ${nlsyms}
    cat ${nlsyms}

    echo "make a dictionary"
    echo "<unk> 1" > ${dict} # <unk> must be 1, 0 will be used for "blank" in CTC
    text2token.py -s 1 -n 1 -l ${nlsyms} ${wsj_train_text} | cut -f 2- -d" " | tr " " "\n" \
    | sort | uniq | grep -v -e '^\s*$' | awk '{print $0 " " NR+1}' >> ${dict}
    wc -l ${dict}

    echo "make json files"
    # dump_pcm.sh in stage 1 is invoked with only the data directory, so its
    # feats.scp is read from that directory rather than from ${dumpdir}
    # (TODO confirm against utils/dump_pcm.sh). data.json is written next to it.
    tr_data=data/${environ}/${train_set}_${num_channels}ch
    local/data2json.sh --cmd "${train_cmd}" --nj 30 --num_spkrs ${num_spkrs} \
        --category "${num_channels}channels" \
        --preprocess_conf ${preprocess_config} --filetype sound.hdf5 \
        --feat ${tr_data}/feats.scp --nlsyms ${nlsyms} \
        --out ${tr_data}/data.json ${tr_data} ${dict}

    dt_data=data/${environ}/${train_dev}_${num_channels}ch
    local/data2json.sh --cmd "${train_cmd}" --nj 4 --num_spkrs ${num_spkrs} \
        --category "${num_channels}channels" \
        --preprocess_conf ${preprocess_config} --filetype sound.hdf5 \
        --feat ${dt_data}/feats.scp --nlsyms ${nlsyms} \
        --out ${dt_data}/data.json ${dt_data} ${dict}

    for rtask in ${recog_set}; do
        recog_data=data/${environ}/${rtask}_${num_channels}ch
        local/data2json.sh --cmd "${train_cmd}" --nj 4 --num_spkrs ${num_spkrs} \
            --category "${num_channels}channels" \
            --preprocess_conf ${preprocess_config} --filetype sound.hdf5 \
            --feat ${recog_data}/feats.scp --nlsyms ${nlsyms} \
            --out ${recog_data}/data.json ${recog_data} ${dict}
    done

    # Training JSON = multi-speaker multichannel set + WSJ single-speaker set.
    # NOTE(review): data/train_si284/data.json is not produced by this script
    # (and stage 0 moves train_si284 under data/wsj) -- confirm its origin.
    singlespkr_dir=data/${train_set}_${num_channels}ch_singlespkr
    mkdir -p ${singlespkr_dir}
    concatjson.py ${tr_data}/data.json data/train_si284/data.json > ${singlespkr_dir}/data.json
fi


# It takes about one day. If you just want to do end-to-end ASR without LM,
# you can skip this and remove --rnnlm option in the recognition (stage 5)
if [ -z "${lmtag}" ]; then
    lmtag=$(basename ${lm_config%.*})
    if [ ${use_wordlm} = true ]; then
        lmtag=${lmtag}_word${lm_vocabsize}
    fi
fi
lmexpname=train_rnnlm_${backend}_${lmtag}
lmexpdir=exp/${lmexpname}
mkdir -p ${lmexpdir}

if [ ${stage} -le 3 ]; then
    echo "stage 3: LM Preparation"

    # Stage 0 moves the WSJ data directories under data/wsj.
    wsj_text_train=data/wsj/${wsj_train_set}/text
    wsj_text_dev=data/wsj/${wsj_train_dev}/text
    wsj_text_test=data/wsj/${wsj_train_test}/text

    if [ ${use_wordlm} = true ]; then
        lmdatadir=data/local/wordlm_train
        lmdict=${lmdatadir}/wordlist_${lm_vocabsize}.txt
        mkdir -p ${lmdatadir}
        cut -f 2- -d" " ${wsj_text_train} > ${lmdatadir}/train_trans.txt
        zcat ${wsj1}/13-32.1/wsj1/doc/lng_modl/lm_train/np_data/{87,88,89}/*.z \
            | grep -v "<" | tr "[:lower:]" "[:upper:]" > ${lmdatadir}/train_others.txt
        cut -f 2- -d" " ${wsj_text_dev} > ${lmdatadir}/valid.txt
        cut -f 2- -d" " ${wsj_text_test} > ${lmdatadir}/test.txt
        cat ${lmdatadir}/train_trans.txt ${lmdatadir}/train_others.txt > ${lmdatadir}/train.txt
        text2vocabulary.py -s ${lm_vocabsize} -o ${lmdict} ${lmdatadir}/train.txt
    else
        lmdatadir=data/local/lm_train
        lmdict=${dict}
        mkdir -p ${lmdatadir}
        text2token.py -s 1 -n 1 -l ${nlsyms} ${wsj_text_train} \
            | cut -f 2- -d" " > ${lmdatadir}/train_trans.txt
        zcat ${wsj1}/13-32.1/wsj1/doc/lng_modl/lm_train/np_data/{87,88,89}/*.z \
            | grep -v "<" | tr "[:lower:]" "[:upper:]" \
            | text2token.py -n 1 | cut -f 2- -d" " > ${lmdatadir}/train_others.txt
        text2token.py -s 1 -n 1 -l ${nlsyms} ${wsj_text_dev} \
            | cut -f 2- -d" " > ${lmdatadir}/valid.txt
        text2token.py -s 1 -n 1 -l ${nlsyms} ${wsj_text_test} \
            | cut -f 2- -d" " > ${lmdatadir}/test.txt
        cat ${lmdatadir}/train_trans.txt ${lmdatadir}/train_others.txt > ${lmdatadir}/train.txt
    fi

    # LM hyper-parameters are taken from ${lm_config} via --config
    # (the same file that names ${lmtag} above).
    ${cuda_cmd} --gpu ${ngpu} ${lmexpdir}/train.log \
        lm_train.py \
        --config ${lm_config} \
        --ngpu ${ngpu} \
        --backend ${backend} \
        --verbose 1 \
        --outdir ${lmexpdir} \
        --tensorboard-dir tensorboard/${lmexpname} \
        --train-label ${lmdatadir}/train.txt \
        --valid-label ${lmdatadir}/valid.txt \
        --test-label ${lmdatadir}/test.txt \
        --resume ${lm_resume} \
        --dict ${lmdict}
fi


if [ -z "${tag}" ]; then
    expname=${train_set}_${backend}_b${blayers}_unit${bunits}_proj${bprojs}_$(basename ${train_config%.*})_$(basename ${preprocess_config%.*})_spa${use_spa}
    if ${do_delta}; then
        expname=${expname}_delta
    fi
else
    expname=${train_set}_${backend}_${tag}
fi
# spa stays empty unless use_spa is true, so "${spa:+--spa}" below only
# passes --spa when it was requested.
spa=
if ${use_spa}; then
    spa=true
fi
expdir=exp/${expname}
mkdir -p ${expdir}

if [ ${stage} -le 4 ]; then
    echo "stage 4: Network Training"

    # JSON files written in stage 2 (concatjson.py / local/data2json.sh --out).
    train_json=data/${train_set}_${num_channels}ch_singlespkr/data.json
    valid_json=data/${environ}/${train_dev}_${num_channels}ch/data.json

    ${cuda_cmd} --gpu ${ngpu} ${expdir}/train.log \
        asr_train.py \
        --config ${train_config} \
        --ngpu ${ngpu} \
        --backend ${backend} \
        --outdir ${expdir}/results \
        --tensorboard-dir tensorboard/${expname} \
        --debugmode ${debugmode} \
        --dict ${dict} \
        --debugdir ${expdir} \
        --minibatches ${N} \
        --verbose ${verbose} \
        --resume ${resume} \
        --seed ${seed} \
        --train-json ${train_json} \
        --valid-json ${valid_json} \
        --preprocess-conf ${preprocess_config} \
        --num-spkrs ${num_spkrs} \
        --use-frontend True \
        --use-beamformer ${use_beamformer} \
        --blayers ${blayers} \
        --bunits ${bunits} \
        --bprojs ${bprojs} \
        --bnmask ${bnmask} \
        --stats-file ${stats_file} \
        ${spa:+--spa}
    # The run stops after training; decoding (stage 5) has to be launched
    # separately with --stage 5.
    exit 0;
fi

if [ ${stage} -le 5 ]; then
    echo "stage 5: Decoding"
    nj=32

    pids=() # initialize pids
    recog_set="tt_spatialized_${environ}" #cv_spatialized_${environ}
    for rtask in ${recog_set}; do
    (
        decode_dir=decode_${rtask}_$(basename ${decode_config%.*})_${lmtag}
        if [ ${use_wordlm} = true ]; then
            recog_opts="--word-rnnlm ${lmexpdir}/rnnlm.model.best"
        else
            recog_opts="--rnnlm ${lmexpdir}/rnnlm.model.best"
        fi
        # Stage 2 writes the recognition data.json into this data directory.
        recog_data_dir=data/${environ}/${rtask}_${num_channels}ch

        # split data into ${nj} parts (read back as split${nj}utt/data.JOB.json)
        splitjson.py --parts ${nj} ${recog_data_dir}/data.json

        #### use CPU for decoding
        ngpu=0

        ${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \
            asr_recog.py \
            --num-spkrs ${num_spkrs} \
            --config ${decode_config} \
            --ngpu ${ngpu} \
            --backend ${backend} \
            --recog-json ${recog_data_dir}/split${nj}utt/data.JOB.json \
            --result-label ${expdir}/${decode_dir}/data.JOB.json \
            --model ${expdir}/results/${recog_model} \
            --beam-size 30 \
            ${recog_opts}

        score_sclite.sh --wer true --nlsyms ${nlsyms} --num_spkrs ${num_spkrs} ${expdir}/${decode_dir} ${dict}

    ) &
    pids+=($!) # store background pids
    done
    # Background failures do not trip `set -e`, so count them explicitly.
    i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
    [ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false
    echo "Finished"
fi
