#!/usr/bin/env python3
# -*- coding: utf-8 -*-


from __future__ import print_function
from __future__ import unicode_literals

import argparse
import codecs
import json
import logging
import os
import sys

from espnet.utils.cli_utils import get_commandline_args

# True only when the running interpreter's major version is 2.
is_python2 = sys.version_info.major == 2

def get_parser():
    """Build the command-line parser of the mixture alignment preparation script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--input-json``,
        ``--utt2spk``, ``--verbose``/``-V``, ``--blank-sym``, ``--num-spkrs``
        and ``-O`` (one or more output files, stored as ``output``).
    """
    parser = argparse.ArgumentParser(
        description='prepare mixture alignments, e.g. ./local/wsj_mix_alignments_prep.py '
                    '--input-json /export/c09/xkc09/tools/espnet/egs/wsj/asr1/exp/'
                    'train_si284_pytorch_train_no_preprocess/'
                    'align_train_si284_decode_lm_word65000/data.json '
                    '--utt2spk data_wsj/tr/utt2spk '
                    '-O data_wsj/tr/ctc_alignment_spk1 data_wsj/tr/ctc_alignment_spk2 '
                    '--verbose 1',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # The script cannot do anything without the alignment json, and the old
    # optional default of None ended up in os.path.isfile(None), which raises
    # TypeError. Requiring the option turns that into a regular usage error.
    parser.add_argument('--input-json', type=str, default=None, required=True,
                        help='Json file for the input')
    parser.add_argument('--utt2spk', type=str, default='', required=True,
                        help='utt2spk files to identify the mixture schemes')
    parser.add_argument('--verbose', '-V', default=0, type=int, help='Verbose option')
    parser.add_argument('--blank-sym', default='<blank>', type=str, help='Blank symbol to be appended')
    parser.add_argument('--num-spkrs', default=2, type=int, help='Number of speakers')
    parser.add_argument('-O', dest='output', type=str, nargs='+', default=[],
                        help='Output files')
    return parser


def pad_alignments(alignments, blank_sym):
    """Pad whitespace-tokenized alignments to a common number of tokens.

    Args:
        alignments (list[str]): alignment strings whose tokens are separated
            by whitespace. Must not be empty.
        blank_sym (str): token appended to the shorter alignments.

    Returns:
        list[str]: the alignments re-joined with single spaces, each holding
        as many tokens as the longest input.
    """
    tokenized = [a.split() for a in alignments]
    max_len = max(len(tokens) for tokens in tokenized)
    # Pad at the end: in the wsj0 mixture generation process, 0s are padded
    # to the end of the short utterances.
    return [' '.join(tokens + [blank_sym] * (max_len - len(tokens)))
            for tokens in tokenized]


def mixture_alignments(utts, utt_name, blank_sym):
    """Collect the padded CTC alignments of the two sources of one mixture.

    Args:
        utts (dict): the ``utts`` mapping of the alignment json; every source
            utterance entry must provide a ``ctc_alignment`` string.
        utt_name (str): mixture id made of six ``_``-separated fields, the
            third and fifth of which are the source utterance ids.
        blank_sym (str): token used to pad the shorter alignment.

    Returns:
        dict: ``{'ctc_alignment1': ..., 'ctc_alignment2': ...}`` with both
        values padded to the same number of tokens.

    Raises:
        ValueError: if ``utt_name`` does not consist of six fields.
        KeyError: if a source utterance is missing from ``utts``.
    """
    # FIXME(xkc09): ad-hoc for wsj_2mix
    fields = utt_name.split('_')
    if len(fields) != 6:
        raise ValueError(
            f'Unexpected mixture name {utt_name!r}: expected 6 "_"-separated '
            f'fields, got {len(fields)}')
    names = (fields[2], fields[4])
    for name in names:
        # Dict membership is O(1); raising (instead of assert) keeps the
        # check alive under ``python -O``.
        if name not in utts:
            raise KeyError(
                f'{name} (source of mixture {utt_name}) has no alignment')
    padded = pad_alignments([utts[name]['ctc_alignment'] for name in names],
                            blank_sym)
    return {f'ctc_alignment{i}': a for i, a in enumerate(padded, start=1)}


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()

    # Validate the command line explicitly; asserts vanish under ``-O``.
    if args.num_spkrs != len(args.output):
        parser.error(f'Number of speakers {args.num_spkrs} is not consistent '
                     f'with number of outputs {len(args.output)}')
    if args.num_spkrs > 2:
        # The mixture name parsing yields exactly two source utterances, so a
        # third output could never be filled; fail before touching any file.
        parser.error('only up to 2 speakers are supported by the wsj_2mix '
                     'naming scheme')
    if args.input_json is None:
        parser.error('--input-json is required')

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    # Read alignment json. A missing file now fails right here with the
    # usual I/O error instead of a NameError further down.
    with codecs.open(args.input_json, encoding="utf-8") as f:
        utts = json.load(f)['utts']
    logging.info('%s: has %d utterances', args.input_json, len(utts))

    # Read utt2spk: the first column of every line is a mixture id. A missing
    # file is an error rather than a silent run that writes empty outputs.
    output_dic = dict()
    with codecs.open(args.utt2spk, encoding="utf-8") as f:
        for input_line in f:
            columns = input_line.split()
            if not columns:
                # tolerate blank lines
                continue
            utt_name = columns[0]
            output_dic[utt_name] = mixture_alignments(
                utts, utt_name, args.blank_sym)

    # One output file per speaker, each line being "<mixture id> <alignment>".
    for n, path in enumerate(args.output, start=1):
        key = f'ctc_alignment{n}'
        with open(path, 'w', encoding='utf-8') as f:
            f.writelines(f'{utt} {ali[key]}\n'
                         for utt, ali in output_dic.items())
