# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Transformer speech recognition model (pytorch)."""

from argparse import Namespace
from distutils.util import strtobool
import logging
import math
import random

import chainer
from chainer import reporter
import numpy
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F

from espnet.nets.asr_interface import ASRInterface
from espnet.nets.ctc_prefix_score import CTCPrefixScore
from espnet.nets.e2e_asr_common import end_detect
from espnet.nets.e2e_asr_common import ErrorCalculator
from espnet.nets.pytorch_backend.ctc import CTC
from espnet.nets.pytorch_backend.e2e_asr import CTC_LOSS_THRESHOLD
from espnet.nets.pytorch_backend.nets_utils import get_subsample
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.decoder import Decoder
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.initializer import initialize
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.mask import target_mask
from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport
from espnet.nets.scorers.ctc import CTCPrefixScorer

import espnet.nets.pytorch_backend.separation.criterion as sep_criterion
from espnet.nets.pytorch_backend.separation.encoder import Encoder as SepEncoder
from espnet.nets.pytorch_backend.separation.temporal_convnet import TemporalConvNet
from espnet.nets.pytorch_backend.separation.decoder import Decoder as SepDecoder


# Keyword arguments forwarded to torchaudio.compliance.kaldi.fbank by
# E2E.create_feats: 8 kHz input, 25 ms frames, 80 mel bins from 20 Hz,
# dither 1.0.
params = dict(
    dither=1.0,
    sample_frequency=8000,
    frame_length=25,
    low_freq=20,
    num_mel_bins=80,
)


class Reporter(chainer.Chain):
    """Chainer reporter wrapper for the joint separation + ASR model."""

    def report(self, loss_ctc, loss_att, acc, cer_ctc, cer, wer,
               loss_tas, mtl_loss):
        """Send one training/validation step's statistics to the chainer reporter.

        Every value is reported under its own key; the total multi-task loss is
        additionally logged and reported last under ``"loss"``.
        """
        per_step_stats = (
            ("loss_ctc", loss_ctc),
            ("loss_att", loss_att),
            ("acc", acc),
            ("cer_ctc", cer_ctc),
            ("cer", cer),
            ("wer", wer),
            ("loss_tas", loss_tas),
        )
        for key, value in per_step_stats:
            reporter.report({key: value}, self)
        logging.info("mtl loss:" + str(mtl_loss))
        reporter.report({"loss": mtl_loss}, self)


class E2E(ASRInterface, torch.nn.Module):
    """E2E module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options

    """
    @staticmethod
    def add_arguments(parser):
        """Register the ASR options, then the separation (TasNet) options.

        :param argparse.ArgumentParser parser: parser to extend
        :return: the same parser
        """
        for register in (E2E.add_asr_arguments, E2E.add_sep_arguments):
            register(parser)
        return parser

    @staticmethod
    def add_asr_arguments(parser):
        """Register the transformer ASR options on ``parser``.

        Options are declared as ``(flags, keyword-arguments)`` pairs and added
        to a single argument group in declaration order.

        :param argparse.ArgumentParser parser: parser to extend
        :return: the same parser, for chaining
        """
        group = parser.add_argument_group("transformer asr model setting")
        specs = (
            # transformer-specific options
            (
                ("--transformer-init",),
                dict(
                    type=str,
                    default="pytorch",
                    choices=[
                        "pytorch",
                        "xavier_uniform",
                        "xavier_normal",
                        "kaiming_uniform",
                        "kaiming_normal",
                    ],
                    help="how to initialize transformer parameters",
                ),
            ),
            (
                ("--transformer-input-layer",),
                dict(
                    type=str,
                    default="conv2d",
                    choices=["conv2d", "linear", "embed"],
                    help="transformer input layer type",
                ),
            ),
            (
                ("--transformer-attn-dropout-rate",),
                dict(
                    default=None,
                    type=float,
                    help="dropout in transformer attention. use --dropout-rate if None is set",
                ),
            ),
            (
                ("--transformer-lr",),
                dict(default=10.0, type=float, help="Initial value of learning rate"),
            ),
            (
                ("--transformer-warmup-steps",),
                dict(default=25000, type=int, help="optimizer warmup steps"),
            ),
            (
                ("--transformer-length-normalized-loss",),
                dict(default=True, type=strtobool, help="normalize loss by length"),
            ),
            (
                ("--dropout-rate",),
                dict(default=0.0, type=float, help="Dropout rate for the encoder"),
            ),
            # encoder
            (
                ("--elayers",),
                dict(
                    default=4,
                    type=int,
                    help="Number of encoder layers (for shared recognition part "
                    "in multi-speaker asr mode)",
                ),
            ),
            (
                ("--eunits", "-u"),
                dict(default=300, type=int, help="Number of encoder hidden units"),
            ),
            # attention
            (
                ("--adim",),
                dict(default=320, type=int, help="Number of attention transformation dimensions"),
            ),
            (
                ("--aheads",),
                dict(default=4, type=int, help="Number of heads for multi head attention"),
            ),
            # decoder
            (
                ("--dlayers",),
                dict(default=1, type=int, help="Number of decoder layers"),
            ),
            (
                ("--dunits",),
                dict(default=320, type=int, help="Number of decoder hidden units"),
            ),
            # scheduled-sampling ratios
            (
                ("--sampling-probability",),
                dict(default=0.0, type=float, help="Ratio of using predicted wav to asr"),
            ),
            (
                ("--condition-sampling-probability",),
                dict(default=0.0, type=float, help="Ratio of predicted ctc alignment as condition"),
            ),
        )
        for flags, options in specs:
            group.add_argument(*flags, **options)
        return parser

    @staticmethod
    def add_sep_arguments(parser):
        """Register the TasNet separation options on ``parser``.

        :param argparse.ArgumentParser parser: parser to extend
        :return: the same parser, for chaining
        """
        group = parser.add_argument_group("Tasnet model paremeters.")

        # Integer-valued architecture sizes: (flag, default, help).
        architecture = (
            ("--N", 256, "Number of filters in encoder layers"),
            ("--L", 20, "Length of the filters (in samples)"),
            ("--B", 256, "Number of channels in bottleneck 1 * 1-conv block"),
            ("--H", 512, "Number of channels in convolutional blocks"),
            ("--P", 3, "Kernel size in convolutional blocks"),
            ("--X", 8, "Number of convolutional blocks in each repeat"),
            ("--R", 4, "Number of repeats"),
            ("--C", 2, "Number of speakers"),
        )
        for flag, default, help_text in architecture:
            group.add_argument(flag, default=default, type=int, help=help_text)

        group.add_argument(
            "--norm-type",
            default="gLN",
            type=str,
            choices=["BN", "gLN", "cLN"],
            help="BN, gLN, cLN",
        )
        group.add_argument("--causal", default=0, type=int, help="causal or non-causal")
        group.add_argument(
            "--mask-nonlinear",
            default="relu",
            type=str,
            help="use which non-linear function to generate mask",
        )

        # Integer-valued training/model modes, all defaulting to 0: (flag, help).
        modes = (
            ("--end-separation-mode", "end-separation-mode"),
            ("--greedy-tf", "greedy-tf"),
            ("--add-last-silence", "Add last silence"),
            ("--pit-without-tf", "PIT without teacher force."),
        )
        for flag, help_text in modes:
            group.add_argument(flag, default=0, type=int, help=help_text)

        return parser

    @property
    def attention_plot_class(self):
        """Class used to plot attention weights (``PlotAttentionReport``)."""
        plot_report_cls = PlotAttentionReport
        return plot_report_cls

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        Builds a TasNet-style separation front end (encoder, temporal conv
        separator, speaker LSTM cell, decoder) and a transformer ASR back end
        (encoder/decoder, plus CTC when ``args.mtlalpha > 0``). In ``forward``
        the ASR back end is applied to one extracted speaker per step.

        :param int idim: dimension of inputs to the ASR encoder; must match the
            feature dimension produced by ``create_feats``
        :param int odim: dimension of outputs (vocabulary size; the last index
            is used as both <sos> and <eos>)
        :param Namespace args: argument Namespace containing options (the ASR
            options from ``add_asr_arguments``, the TasNet options from
            ``add_sep_arguments``, plus ``sym_blank``, ``lsm_weight``,
            ``mtlalpha``, ``mtlbelta``, ``global_mean``, ``global_std``,
            ``ctc_type``, ``report_cer``, ``report_wer``, ``char_list`` and
            ``sym_space``)
        :param int ignore_id: padding label ignored by the losses
        """
        torch.nn.Module.__init__(self)
        # Attention dropout falls back to the generic dropout rate (mutates args).
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        # Token embedding (odim -> B channels) used by create_ctc_condition to turn
        # CTC alignments into the condition fed to the speaker LSTM.
        self.embed = nn.Embedding(odim, args.B)
        self.dropout_emb = nn.Dropout(p=args.dropout_rate)
        # <sos> and <eos> share the last vocabulary index.
        self.sos = odim - 1
        self.eos = odim - 1
        self.blank = args.sym_blank
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="transformer")
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(
            self.odim,
            self.ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )
        # self.verbose = args.verbose
        # NOTE(review): this runs before self.ctc and the sep_* modules are created
        # below, so presumably only the modules registered so far (encoder,
        # decoder, embed) are re-initialized - TODO confirm against initialize().
        self.reset_parameters(args)
        self.adim = args.adim
        # mtlalpha weights CTC vs. attention loss; mtlbelta weights separation vs. ASR loss.
        self.mtlalpha = args.mtlalpha
        self.mtlbelta = args.mtlbelta
        # Feature normalization statistics. These are plain tensors, not registered
        # buffers, so create_feats moves them to the input device on every call.
        self.global_mean = torch.tensor(args.global_mean)
        self.global_std = torch.tensor(args.global_std)
        if args.mtlalpha > 0.0:
            self.ctc = CTC(
                odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
            )
        else:
            self.ctc = None

        if args.report_cer or args.report_wer:
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None
        self.rnnlm = None
        # NOTE(review): duplicate of the self.blank assignment above; harmless.
        self.blank = args.sym_blank
        # Scheduled-sampling ratios: feed the predicted wav to ASR / use the
        # predicted CTC alignment as the next-step condition.
        self.sampling_probability = args.sampling_probability
        self.condition_sampling_probability = args.condition_sampling_probability

        # separation related
        self.add_last_silence = args.add_last_silence
        self.pit_without_tf = args.pit_without_tf
        self.greedy_tf = args.greedy_tf

        # tasnet
        self.sep_encoder = SepEncoder(args.L, args.N)
        self.sep_separator = TemporalConvNet(args.N, args.B, args.H, args.P, args.X, args.R,
                                             args.C, args.norm_type, args.causal, args.mask_nonlinear, args.end_separation_mode)
        self.sep_mask_conv1x1 = nn.Conv1d(args.B, args.N, 1, bias=False)
        self.sep_decoder = SepDecoder(args.N, args.L)
        # LSTM cell stepping over speakers; its input is the concatenation of the
        # separator output (B), the previous wav condition (N) and the previous
        # CTC condition (B), i.e. 2B+N channels per frame.
        self.sep_spk_lstm = nn.LSTMCell(args.B*2+args.N, args.B)  # LSTM over the speakers' step.
        self.sep_criterion = sep_criterion

    def reset_parameters(self, args):
        """Re-initialize parameters with the scheme named by ``args.transformer_init``."""
        init_scheme = args.transformer_init
        initialize(self, init_scheme)

    def forward(self, xs_pad, ilens, ys_pad, ys_wav_pad, ys_ctc_align_pad):
        """E2E forward.

        Speakers are extracted one per step. At each step the speaker LSTM
        predicts a waveform, the remaining reference with the highest SDR is
        chosen as its target (``choose_candidate``), and the ASR branch is
        trained on that speaker. The chosen speaker's reference waveform and
        CTC alignment condition the next step.

        :param torch.Tensor xs_pad: batch of padded mixture waveforms (B, Tmax)
        :param torch.Tensor ilens: batch of mixture lengths in samples (B)
        :param torch.Tensor ys_pad: batch of padded target sequences, indexed by
            speaker on dim 1 (B, num_spkrs, Lmax)
        :param torch.Tensor ys_wav_pad: batch of padded target wave for separation (B, num_spkrs, Tmax)
        :param torch.Tensor ys_ctc_align_pad: batch of padded forced alignment sequences (B, num_spkrs, Tmax')
        :return: weighted multi-task loss (also stored in ``self.loss``)
        :rtype: torch.Tensor
        """
        batch_size, num_spkrs = ys_wav_pad.size(0), ys_wav_pad.size(1)
        # Per-utterance {speaker index: reference} dicts. choose_candidate pops the
        # matched speaker, so each reference is used at most once.
        ys_dict, ys_wav_dict, ys_ctc_align_dict = [], [], []
        for i in range(batch_size):
            ys_dict.append(dict(enumerate(ys_pad[i])))
            ys_wav_dict.append(dict(enumerate(ys_wav_pad[i])))
            ys_ctc_align_dict.append(dict(enumerate(ys_ctc_align_pad[i])))

        # xs_pad: (BS, T)
        enc_output = self.sep_encoder(xs_pad)  # (BS, N, K)
        # sep_output: (BS, B, K), where K = (T - L) / (L / 2) + 1 = 2 T / L - 1
        sep_output = self.sep_separator(enc_output)

        # Both conditions are zero before the first speaker is extracted.
        wav_condition_pre = torch.zeros_like(enc_output)
        ctc_condition_pre = torch.zeros_like(ys_ctc_align_pad[:, 0])
        N, B = self.sep_encoder.N, self.sep_separator.B
        BS, K = enc_output.size(0), enc_output.size(-1)  # new length

        loss_att, loss_ctc, loss_tas = [], [], []
        acc, cer, cer_ctc, wer = [], [], [], []
        # With add_last_silence one extra step is run. Its prediction is computed,
        # but no loss is attached to it (see the `continue` below).
        repeat_time = num_spkrs + 1 if self.add_last_silence else num_spkrs

        for step_idx in range(repeat_time):
            ctc_condition_pre = self.create_ctc_condition(ctc_condition_pre, enc_output.size(-1))
            # (BS, 2B+N, K)
            cat_condition_cur = torch.cat((sep_output, wav_condition_pre, ctc_condition_pre), 1)
            # (BS, 2B+N, K) --> (BS, K, 2B+N) --> (BS*K, 2B+N)
            lstm_in = cat_condition_cur.transpose(1, 2).contiguous().view(-1, 2 * B + N)
            if step_idx == 0:
                h_0 = torch.zeros(BS * K, B).to(xs_pad.device)
                c_0 = torch.zeros(BS * K, B).to(xs_pad.device)
                lstm_h, lstm_c = self.sep_spk_lstm(lstm_in, (h_0, c_0))
                del h_0, c_0
            else:
                lstm_h, lstm_c = self.sep_spk_lstm(lstm_in, (lstm_h, lstm_c))

            # (BS*K, B) --> (BS, K, B) --> (BS, B, K) --> (BS, N, K)
            pred_wav_cur = self.sep_mask_conv1x1(lstm_h.view(-1, K, B).transpose(1, 2))
            pred_wav_cur = F.relu(pred_wav_cur).unsqueeze(1)  # (BS, 1, N, K)
            pred_wav_cur = self.sep_decoder(enc_output, pred_wav_cur).squeeze(1)  # (BS, T')
            # Pad the reconstruction back to the mixture length.
            pred_wav_cur = F.pad(pred_wav_cur, (0, xs_pad.size(-1) - pred_wav_cur.size(-1)))
            if self.add_last_silence and step_idx == repeat_time - 1:
                # Last (extra) step: nothing is left to match, so skip the losses.
                continue

            # Match the prediction to the closest remaining reference speaker.
            (y_cur, y_wav_cur, y_ctc_align_cur, spk_list_cur,
             ys_dict, ys_wav_dict, ys_ctc_align_dict) = self.choose_candidate(
                pred_wav_cur, ilens, ys_dict, ys_wav_dict, ys_ctc_align_dict, BS)
            logging.info("Step: {}, spk list: {}".format(step_idx, spk_list_cur))

            y_cur = torch.cat(y_cur, 0)  # (BS, tgt_len)
            y_ctc_align_cur = torch.cat(y_ctc_align_cur)  # (BS, num_frame)
            y_wav_cur = torch.cat(y_wav_cur, 0)  # (BS, T)
            loss_tas.append(self.sep_criterion.cal_loss_with_sdr_order(
                pred_wav_cur.unsqueeze(0),
                y_wav_cur.unsqueeze(0), ilens)[0])

            # 1. forward asr encoder on the predicted or the reference waveform
            if random.random() < self.sampling_probability:
                logging.info("Input predicted wav to ASR.")
                feats = self.create_feats(pred_wav_cur, self.global_mean, self.global_std)
            else:
                feats = self.create_feats(y_wav_cur, self.global_mean, self.global_std)  # (batch, num_frame, fbank_dim)
            # Samples -> fbank frames: 200-sample window (25 ms at 8 kHz) with an
            # 80-sample (10 ms) shift.
            ilens_freq = 1 + (ilens - 200) // 80
            src_mask = make_non_pad_mask(ilens_freq.tolist()).to(xs_pad.device).unsqueeze(-2)
            hs_pad, hs_mask = self.encoder(feats, src_mask)

            # 2. forward asr decoder
            y_cur_in, y_cur_out = add_sos_eos(y_cur, self.sos, self.eos, self.ignore_id)
            y_cur_mask = target_mask(y_cur_in, self.ignore_id)
            pred_pad, pred_mask = self.decoder(y_cur_in, y_cur_mask, hs_pad, hs_mask)

            # 3. compute asr attention loss
            loss_att.append(self.criterion(pred_pad, y_cur_out))
            acc.append(th_accuracy(
                pred_pad.view(-1, self.odim), y_cur_out, ignore_label=self.ignore_id))

            # 4. compute ctc loss; both values stay None when CTC is disabled or
            #    no error calculator is configured (previously they could be unbound
            #    or compared against the threshold while None).
            loss_ctc_cur, cer_ctc_cur = None, None
            if self.mtlalpha != 0.0:
                hs_len = hs_mask.view(feats.size(0), -1).sum(1)
                loss_ctc_cur = self.ctc(hs_pad, hs_len, y_cur)
                if self.error_calculator is not None:
                    y_cur_hat = self.ctc.argmax(hs_pad).data
                    cer_ctc_cur = self.error_calculator(
                        y_cur_hat.cpu(), y_cur.cpu(), is_ctc=True)
                if loss_ctc_cur > CTC_LOSS_THRESHOLD or math.isnan(loss_ctc_cur):
                    logging.warning("hs_pad size: {}, y_cur size: {}".format(
                        hs_pad.size(), y_cur.size()))
            loss_ctc.append(loss_ctc_cur)
            cer_ctc.append(cer_ctc_cur)

            # 5. compute attention cer/wer (evaluation only)
            if self.training or self.error_calculator is None:
                cer_cur, wer_cur = None, None
            else:
                y_cur_hat = pred_pad.argmax(dim=-1)
                cer_cur, wer_cur = self.error_calculator(
                    y_cur_hat.cpu(), y_cur.cpu())
            cer.append(cer_cur)
            wer.append(wer_cur)

            # Conditions for the next step:
            # - wav: the chosen reference plus Gaussian noise (std 0.5), encoded
            #   to (BS, N, K) by the separation encoder;
            # - ctc: scheduled sampling between the predicted and reference alignment.
            wav_condition_pre = self.sep_encoder(
                y_wav_cur + 0.5 * torch.randn_like(y_wav_cur))
            # random.random() is evaluated first so RNG consumption is unchanged.
            if random.random() < self.condition_sampling_probability and self.ctc is not None:
                logging.info('Input generated ctc alignment as condition.')
                ctc_condition_pre = self.ctc.argmax(hs_pad).data
            else:
                ctc_condition_pre = y_ctc_align_cur

        alpha = self.mtlalpha
        belta = self.mtlbelta
        loss_tas = torch.stack(loss_tas, dim=0).mean()
        loss_tas_data = float(loss_tas)
        if alpha == 0:
            loss_att = torch.stack(loss_att, dim=0).mean()
            self.loss = belta * loss_tas + (1 - belta) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
            acc = self._mean_or_none(acc)
            cer = self._mean_or_none(cer)
            wer = self._mean_or_none(wer)
            cer_ctc = None
        elif alpha == 1:
            loss_ctc = torch.stack(loss_ctc, dim=0).mean()
            # NOTE(review): unlike the other branches the separation loss is not
            # weighted by (belta, 1 - belta) here; kept as in the original.
            self.loss = loss_tas + belta * loss_ctc
            loss_ctc_data = float(loss_ctc)
            loss_att_data = None
            acc, cer, wer = None, None, None
            cer_ctc = self._mean_or_none(cer_ctc)
        else:
            loss_att = torch.stack(loss_att, dim=0).mean()
            loss_ctc = torch.stack(loss_ctc, dim=0).mean()
            loss_asr = alpha * loss_ctc + (1 - alpha) * loss_att
            self.loss = belta * loss_tas + (1 - belta) * loss_asr
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)
            acc = self._mean_or_none(acc)
            cer = self._mean_or_none(cer)
            wer = self._mean_or_none(wer)
            cer_ctc = self._mean_or_none(cer_ctc)

        loss_data = float(self.loss)

        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, acc, cer_ctc, cer, wer,
                loss_tas_data, loss_data)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss

    @staticmethod
    def _mean_or_none(values):
        """Average per-step metric values; return None if any value is missing.

        Per-step metrics are scalars (Python numbers or 0-d tensors), or None when
        they are not computed (e.g. CER/WER during training). They are averaged
        with plain arithmetic; torch.stack would reject Python numbers and None.
        """
        if not values or any(value is None for value in values):
            return None
        return sum(float(value) for value in values) / len(values)

    def choose_candidate(self, pred_wav, ilens, ys_dict, ys_wav_dict, ys_ctc_align_dict, BS):
        """Greedily match each predicted waveform to its best remaining reference.

        For every utterance in the batch, the remaining reference waveform with
        the highest ``cal_sdr_with_order`` score against the prediction wins; ties
        keep the earliest key. The winner's label sequence, waveform and CTC
        alignment are collected (each with a leading batch dim of 1) and removed
        from the per-utterance dicts in place.

        :param torch.Tensor pred_wav: predicted waveforms, one row per utterance
        :param torch.Tensor ilens: utterance lengths
        :param list ys_dict: per-utterance {speaker: label sequence}
        :param list ys_wav_dict: per-utterance {speaker: reference waveform}
        :param list ys_ctc_align_dict: per-utterance {speaker: CTC alignment}
        :param int BS: batch size
        :return: (labels, waveforms, alignments, chosen speaker keys,
            ys_dict, ys_wav_dict, ys_ctc_align_dict)
        """
        chosen_keys = []
        chosen_ys = []
        chosen_wavs = []
        chosen_aligns = []

        for utt in range(BS):
            estimate = pred_wav[utt].view(1, 1, -1)
            length = ilens[utt].view(1)

            best_key, best_sdr = None, None
            for key, reference in ys_wav_dict[utt].items():
                sdr = self.sep_criterion.cal_sdr_with_order(
                    reference.view(1, 1, -1), estimate, length)
                if best_sdr is None or sdr > best_sdr:
                    best_key, best_sdr = key, sdr

            chosen_keys.append(best_key)
            # pop() both fetches the winner and removes it from the candidates.
            chosen_ys.append(ys_dict[utt].pop(best_key).unsqueeze(0))
            chosen_wavs.append(ys_wav_dict[utt].pop(best_key).unsqueeze(0))  # (1, T)
            chosen_aligns.append(ys_ctc_align_dict[utt].pop(best_key).unsqueeze(0))

        return (chosen_ys, chosen_wavs, chosen_aligns, chosen_keys,
                ys_dict, ys_wav_dict, ys_ctc_align_dict)

    def create_feats(self, audios_pad, global_mean, global_std):
        """Compute globally normalized fbank features for a batch of waveforms.

        Each row of ``audios_pad`` is passed to
        ``torchaudio.compliance.kaldi.fbank`` with the module-level ``params``.
        The stacked features are then normalized with the given global
        statistics. The original implementation computed the normalized tensor
        but returned the raw features, silently ignoring ``global_mean`` and
        ``global_std``.

        :param torch.Tensor audios_pad: batch of equal-length waveforms (B, T)
        :param torch.Tensor global_mean: mean subtracted from the features
        :param torch.Tensor global_std: standard deviation the features are divided by
        :return: normalized features (B, num_frames, num_mel_bins)
        :rtype: torch.Tensor
        """
        # The statistics are plain tensors, so move them next to the audio.
        global_mean = global_mean.to(audios_pad.device)
        global_std = global_std.to(audios_pad.device)
        feats = torch.stack(
            [
                torchaudio.compliance.kaldi.fbank(waveform.unsqueeze(0), **params)
                for waveform in audios_pad  # iterates over the batch dimension
            ],
            dim=0,
        )
        # NOTE: checkpoints trained while this returned un-normalized features
        # will see different inputs and may need retraining.
        return (feats - global_mean) / global_std

    def create_ctc_condition(self, ys_pad, expended_len):
        """Turn a batch of token sequences into a dense condition of length ``expended_len``.

        Padding (``ignore_id``) is stripped and the sequences are re-padded with
        <eos>. They are then embedded (with dropout) and stretched in time with
        ``F.interpolate`` (default mode) to ``expended_len`` frames.

        :param ys_pad: batch of padded token / alignment sequences
        :param int expended_len: number of output frames
        :return: condition tensor (BS, B, expended_len), where B is the embedding size
        """
        unpadded = [seq[seq != self.ignore_id] for seq in ys_pad]
        repadded = pad_list(unpadded, self.eos)

        embedded = self.dropout_emb(self.embed(repadded))  # (BS, L, B)
        channels_first = embedded.transpose(1, 2)  # (BS, B, L)
        return F.interpolate(channels_first, size=[expended_len])

    def scorers(self):
        """Return the beam-search scorers: the attention decoder and a CTC prefix scorer."""
        ctc_scorer = CTCPrefixScorer(self.ctc, self.eos)
        return {"decoder": self.decoder, "ctc": ctc_scorer}

    def encode(self, x):
        """Run the ASR transformer encoder on a single utterance (switches to eval mode).

        :param ndarray x: source acoustic feature (T, D)
        :return: encoder outputs for that utterance, without the batch dim
        :rtype: torch.Tensor
        """
        self.eval()
        batched = torch.as_tensor(x)[None]  # add a batch dim of size 1
        encoded, _unused_mask = self.encoder(batched, None)
        return encoded.squeeze(0)

    def recognize(self, x, ilen, recog_args, char_list=None, rnnlm=None, use_jit=False, num_spkrs=2):
        """Recognize input speech.

        :param ndnarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argment Namespace contraining options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        enc_output = self.sep_encoder(x)  # (BS, N, K)
        # sep_output: (BS, B, K), where K = (T - L) / (L / 2) + 1 = 2 T / L - 1
        sep_output = self.sep_separator(enc_output)

        # First step to use all ZEROs
        wav_condition_pre = torch.zeros_like(enc_output)
        ctc_condition_pre = torch.zeros_like(enc_output)
        N, B = self.sep_encoder.N, self.sep_separator.B
        BS, K = enc_output.size(0), enc_output.size(-1) # new length

        repeat_time = num_spkrs + 1 if self.add_last_silence else num_spkrs

        nbest_hyps_list = []
        preds_wav = []
        for step_idx in range(repeat_time):
            if step_idx == 0:
                # (BS, B+N, K) --> (BS, K, B+N) --> (BS*K, B+N) --> (BS*K, B)
                cat_condition_cur = torch.cat((sep_output, wav_condition_pre, ctc_condition_pre), 1)
                h_0, c_0 = torch.zeros(BS*K, B).to(x.device), torch.zeros(BS*K, B).to(x.device)
                lstm_h, lstm_c = self.sep_spk_lstm(cat_condition_cur.transpose(1, 2).contiguous().view(-1, 2*B+N), (h_0, c_0))
                del h_0, c_0
            else:
                ctc_condition_pre = self.create_ctc_condition(ctc_condition_pre, enc_output.size(-1))
                # (BS, 2B+N, K)
                cat_condition_cur = torch.cat((sep_output, wav_condition_pre, ctc_condition_pre), 1)
                # (BS, B+N, K) --> (BS, K, B+N) --> (BS*K, B+N) --> (BS*K, B)
                lstm_h, lstm_c = self.sep_spk_lstm(cat_condition_cur.transpose(1, 2).contiguous().view(-1, 2*B+N), (lstm_h, lstm_c))

            # (BS*K, B) --> (BS, K, B) --> (BS, B, K) --> (BS, N, K)
            pred_wav_cur = self.sep_mask_conv1x1(lstm_h.view(-1, K, B).transpose(1, 2))
            pred_wav_cur = F.relu(pred_wav_cur).unsqueeze(1)    # (BS, 1, N, K)
            pred_wav_cur = self.sep_decoder(enc_output, pred_wav_cur).squeeze(1)    # (BS, T)
            T_origin = x.size(-1)
            T_conv = pred_wav_cur.size(-1)
            pred_wav_cur = F.pad(pred_wav_cur, (0, T_origin-T_conv))
            preds_wav.append(pred_wav_cur)

            # ASR encoder part
            feat = self.create_feats(pred_wav_cur, self.global_mean, self.global_std)  # (batch, num_frame, fbank_dim)
            hs_pad, _ = self.encoder(feat, None)

            if recog_args.ctc_weight > 0.0:
                lpz = self.ctc.log_softmax(hs_pad)
                lpz = lpz.squeeze(0)
            else:
                lpz = None

            h = hs_pad.squeeze(0)

            logging.info("input lengths: " + str(h.size(0)))
            # search parms
            beam = recog_args.beam_size
            penalty = recog_args.penalty
            ctc_weight = recog_args.ctc_weight

            # preprare sos
            y = self.sos
            vy = h.new_zeros(1).long()

            if recog_args.maxlenratio == 0:
                maxlen = h.shape[0]
            else:
                # maxlen >= 1
                maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
            minlen = int(recog_args.minlenratio * h.size(0))
            logging.info("max output length: " + str(maxlen))
            logging.info("min output length: " + str(minlen))

            # initialize hypothesis
            if rnnlm:
                hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None}
            else:
                hyp = {"score": 0.0, "yseq": [y]}
            if lpz is not None:
                ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
                hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
                hyp["ctc_score_prev"] = 0.0
                if ctc_weight != 1.0:
                    # pre-pruning based on attention scores
                    ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
                else:
                    ctc_beam = lpz.shape[-1]
            hyps = [hyp]
            ended_hyps = []

            import six
            traced_decoder = None
            for i in six.moves.range(maxlen):
                logging.debug("position " + str(i))

                hyps_best_kept = []
                for hyp in hyps:
                    vy[0] = hyp["yseq"][i]

                    # get nbest local scores and their ids
                    ys_mask = subsequent_mask(i + 1).unsqueeze(0)
                    ys = torch.tensor(hyp["yseq"]).unsqueeze(0)
                    n_batch, _ = ys.size()
                    local_att_scores = torch.zeros(n_batch, self.odim, device=ys.device)

                    if rnnlm:
                        rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                        local_scores = (
                            local_att_scores + recog_args.lm_weight * local_lm_scores
                        )
                    else:
                        local_scores = local_att_scores

                    if lpz is not None:
                        local_best_scores, local_best_ids = torch.topk(
                            local_att_scores, ctc_beam, dim=1
                        )
                        ctc_scores, ctc_states = ctc_prefix_score(
                            hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"]
                        )
                        local_scores = (1.0 - ctc_weight) * local_att_scores[
                            :, local_best_ids[0]
                        ] + ctc_weight * torch.from_numpy(
                            ctc_scores - hyp["ctc_score_prev"]
                        )
                        if rnnlm:
                            local_scores += (
                                recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                            )
                        local_best_scores, joint_best_ids = torch.topk(
                            local_scores, beam, dim=1
                        )
                        local_best_ids = local_best_ids[:, joint_best_ids[0]]
                    else:
                        local_best_scores, local_best_ids = torch.topk(
                            local_scores, beam, dim=1
                        )

                    for j in six.moves.range(beam):
                        new_hyp = {}
                        new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j])
                        new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                        new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                        new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                        if rnnlm:
                            new_hyp["rnnlm_prev"] = rnnlm_state
                        if lpz is not None:
                            new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[0, j]]
                            new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[0, j]]
                        # will be (2 x beam) hyps at most
                        hyps_best_kept.append(new_hyp)

                    hyps_best_kept = sorted(
                        hyps_best_kept, key=lambda x: x["score"], reverse=True
                    )[:beam]

                # sort and get nbest
                hyps = hyps_best_kept
                logging.debug("number of pruned hypothes: " + str(len(hyps)))
                if char_list is not None:
                    logging.debug(
                        "best hypo: "
                        + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
                    )

                # add eos in the final loop to avoid that there are no ended hyps
                if i == maxlen - 1:
                    logging.info("adding <eos> in the last postion in the loop")
                    for hyp in hyps:
                        hyp["yseq"].append(self.eos)

                # add ended hypothes to a final list, and removed them from current hypothes
                # (this will be a probmlem, number of hyps < beam)
                remained_hyps = []
                for hyp in hyps:
                    if hyp["yseq"][-1] == self.eos:
                        # only store the sequence that has more than minlen outputs
                        # also add penalty
                        if len(hyp["yseq"]) > minlen:
                            hyp["score"] += (i + 1) * penalty
                            if rnnlm:  # Word LM needs to add final <eos> score
                                hyp["score"] += recog_args.lm_weight * rnnlm.final(
                                    hyp["rnnlm_prev"]
                                )
                            ended_hyps.append(hyp)
                    else:
                        remained_hyps.append(hyp)

                # end detection

                if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                    logging.info("end detected at %d", i)
                    break

                hyps = remained_hyps
                if len(hyps) > 0:
                    logging.debug("remeined hypothes: " + str(len(hyps)))
                else:
                    logging.info("no hypothesis. Finish decoding.")
                    break

                if char_list is not None:
                    for hyp in hyps:
                        logging.debug(
                            "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                        )

                logging.debug("number of ended hypothes: " + str(len(ended_hyps)))

            nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
                : min(len(ended_hyps), recog_args.nbest)
            ]

            # check number of hypotheis
            if len(nbest_hyps) == 0:
                logging.warning(
                    "there is no N-best results, perform recognition "
                    "again with smaller minlenratio."
                )
                # should copy becasuse Namespace will be overwritten globally
                recog_args = Namespace(**vars(recog_args))
                recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
                return self.recognize(x, recog_args, char_list, rnnlm)

            logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
            logging.info(
                "normalized log probability: "
                + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
            )

            nbest_hyps_list.append(nbest_hyps)

            # update the conditon by estimated wav
            wav_condition_pre = self.sep_encoder(pred_wav_cur)
            ctc_condition_pre = self.predict_alignment(lpz, nbest_hyps[0]["yseq"], char_list)
    
        return nbest_hyps_list, torch.stack(preds_wav, dim=1)

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        Runs ``self.forward`` without gradient tracking, then collects the
        ``attn`` tensor cached on every ``MultiHeadedAttention`` submodule.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: dict mapping module name to attention weights with the
            following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: dict of float ndarray
        """
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad)
        # Skip attention modules whose ``attn`` is missing or None (presumably
        # modules not exercised by this forward pass) instead of raising
        # AttributeError on ``None.cpu()``.
        return {
            name: m.attn.cpu().numpy()
            for name, m in self.named_modules()
            if isinstance(m, MultiHeadedAttention)
            and getattr(m, "attn", None) is not None
        }

    def predict_alignment(self, ctc_log_prob, best_seq, char_list=None):
        """Return the frame-level CTC alignment of ``best_seq`` as a batch of one.

        :param torch.Tensor ctc_log_prob: CTC log probabilities (T, odim)
        :param best_seq: token id sequence to align
        :param list char_list: optional list of characters, forwarded to
            ``forward_process`` to look up the blank id
        :return: tensor of shape (1, T) holding one token id per frame
        :rtype: torch.Tensor
        """
        frame_ids = self.forward_process(ctc_log_prob, best_seq, char_list)
        as_tensor = torch.tensor(frame_ids)
        return as_tensor.unsqueeze(0)

    def forward_process(self, lpz, y, char_list):
        """Forward process of getting alignments of CTC.

        The label sequence ``y`` is interleaved with blanks
        (``[b, y1, b, y2, ..., yL, b]``). The best-scoring CTC path through that
        extended sequence is then found with a Viterbi recursion.

        An extended state ``s`` may be entered by staying at ``s``, by advancing
        from ``s - 1``, or by skipping from ``s - 2``. The skip is allowed only
        when ``s`` is a non-blank label different from the label at ``s - 2``.

        :param torch.Tensor lpz: log probabilities of CTC (T, odim); a numpy
            array is also accepted
        :param y: id sequence (L); list, numpy array or torch.Tensor
        :param list char_list: list of characters; when given, the blank id is
            ``char_list.index(self.blank)``, otherwise it is 0
        :return: best alignment results, one token id per frame (length T)
        :rtype: list
        """
        blank_id = 0
        if char_list is not None:
            blank_id = char_list.index(self.blank)

        # Convert once up front. Per-element torch indexing inside the DP loop is
        # slow and fails for CUDA tensors.
        if isinstance(lpz, torch.Tensor):
            lpz = lpz.detach().cpu().numpy()
        else:
            lpz = numpy.asarray(lpz)
        if isinstance(y, torch.Tensor):
            y = y.detach().cpu().numpy()
        labels = numpy.asarray(y, dtype=numpy.int64).reshape(-1)

        n_frames = lpz.shape[0]
        if n_frames == 0:
            return []
        if labels.size == 0:
            # Nothing to align: every frame is blank.
            return [blank_id] * n_frames

        # Extended sequence [blank, y1, blank, y2, ..., yL, blank].
        y_interp = numpy.full(2 * labels.size + 1, blank_id, dtype=numpy.int64)
        y_interp[1::2] = labels
        n_states = y_interp.size
        states = numpy.arange(n_states)

        # The skip s-2 -> s is legal only into a non-blank label that differs
        # from the label two positions back. Otherwise repeated labels would
        # merge under CTC collapsing.
        can_skip = numpy.zeros(n_states, dtype=bool)
        can_skip[2:] = (y_interp[2:] != blank_id) & (y_interp[2:] != y_interp[:-2])

        emit = lpz[:, y_interp].astype(numpy.float64)  # (T, n_states)
        log_zero = -100000000000.0  # stand-in for log(0)
        logdelta = numpy.full((n_frames, n_states), log_zero)
        # int64 so that long extended sequences cannot overflow (int16 could).
        state_path = numpy.full((n_frames, n_states), -1, dtype=numpy.int64)
        # A path may start on the leading blank or on the first label.
        logdelta[0, 0] = emit[0, 0]
        logdelta[0, 1] = emit[0, 1]

        # Rows: 0 = stay at s, 1 = advance from s-1, 2 = skip from s-2.
        # Predecessors that do not exist or are not allowed stay at -inf, so
        # they can never win the max. Indexing s-1 for s == 0 used to wrap to
        # the last state and allowed a jump from the end back to the start.
        cands = numpy.full((3, n_states), -numpy.inf)
        for t in range(1, n_frames):
            prev = logdelta[t - 1]
            cands[0] = prev
            cands[1, 1:] = prev[:-1]
            cands[2, 2:] = numpy.where(can_skip[2:], prev[:-2], -numpy.inf)
            # argmax returns the first maximum, so ties prefer stay, then advance.
            best = numpy.argmax(cands, axis=0)
            logdelta[t] = cands[best, states] + emit[t]
            state_path[t] = states - best

        # The path must end on the trailing blank or on the last label. Ties
        # prefer the trailing blank.
        state_seq = numpy.empty(n_frames, dtype=numpy.int64)
        if logdelta[-1, -1] >= logdelta[-1, -2]:
            state_seq[-1] = n_states - 1
        else:
            state_seq[-1] = n_states - 2
        for t in range(n_frames - 2, -1, -1):
            state_seq[t] = state_path[t + 1, state_seq[t + 1]]

        return y_interp[state_seq].tolist()

    def get_ctc_alignments(self, x, y, char_list):
        """E2E get alignments of CTC.

        Encodes ``x``, computes CTC log probabilities with ``self.ctc`` and
        delegates the Viterbi alignment to ``self.forward_process``.
        Previously this method duplicated that code using an ``np`` alias that
        this module never imports, so every call raised NameError.

        :param torch.Tensor x: input acoustic feature (T, D)
        :param torch.Tensor y: id sequence tensor (L)
        :param list char_list: list of characters; used to look up the blank id
        :return: best alignment results, one token id per frame of the CTC
            log-probability sequence
        :rtype: list
        """
        # Alignment is inference-only, so no autograd graph is needed.
        with torch.no_grad():
            enc_output = self.encode(x).unsqueeze(0)
            lpz = self.ctc.log_softmax(enc_output).squeeze(0)
        return self.forward_process(lpz, y, char_list)