#!/usr/bin/env python3
# encoding: utf-8

# Copyright 2018 Nagoya University (Tomoki Hayashi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

from __future__ import print_function
from __future__ import unicode_literals

import argparse
import codecs
import json
import logging
import sys

from distutils.util import strtobool

from espnet.utils.cli_utils import get_commandline_args

# True only under a Python 2 interpreter; used to pick which stream gets
# wrapped with a UTF-8 writer before the merged json is printed.
is_python2 = sys.version_info[:1] == (2,)


def get_parser():
    """Create the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser with a positional ``jsons`` list
        (one or more paths), ``-i/--is-input`` converted with ``strtobool``
        (default True) and ``-V/--verbose`` as an int (default 0).
    """
    p = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="add multiple json values to an input or output value",
    )
    p.add_argument("jsons", nargs="+", type=str, help="json files")
    # strtobool maps "true"/"false"/"yes"/"no"/... strings to 1/0.
    p.add_argument(
        "-i",
        "--is-input",
        type=strtobool,
        default=True,
        help="If true, add to input. If false, add to output",
    )
    p.add_argument("--verbose", "-V", type=int, default=0, help="Verbose option")
    return p


def _load_jsons(paths):
    """Load json files and intersect their utterance ids.

    Args:
        paths (list of str): json files, each with a top-level "utts" dict.

    Returns:
        tuple: (list of loaded json dicts, set of utterance ids present in
        every loaded file). Loading stops at the first file that makes the
        intersection empty, so the list may then be shorter than ``paths``.
    """
    js = []
    # None means "nothing read yet"; an empty set is a real (empty) result
    # and must not be mistaken for "uninitialised".
    intersec_ks = None
    for path in paths:
        with codecs.open(path, "r", encoding="utf-8") as f:
            j = json.load(f)
        ks = j["utts"].keys()
        logging.info("%s: has %d utterances", path, len(ks))
        if intersec_ks is None:
            intersec_ks = set(ks)
        else:
            intersec_ks &= set(ks)
            if not intersec_ks:
                logging.warning("Empty intersection")
                break
        js.append(j)
    return js, (set() if intersec_ks is None else intersec_ks)


def _merge_additional_values(js, keys):
    """Merge, per utterance id, the values of every json after the first.

    Args:
        js (list of dict): loaded jsons; ``js[0]`` is the base and is ignored.
        keys (iterable): utterance ids present in every json of ``js``.

    Returns:
        dict: utterance id -> new dict updated with ``js[1]``, ``js[2]``, ...
        in order, so later files win on duplicate fields. ``js`` is not
        modified.
    """
    merged = {}
    for k in keys:
        v = {}
        for j in js[1:]:
            v.update(j["utts"][k])
        merged[k] = v
    return merged


def _make_entry(adddic, dim_key, len_key, name):
    """Build one new input/output entry from merged additional values.

    "shape" is [len, dim] when both ``len_key`` and ``dim_key`` are present,
    [dim] when only ``dim_key`` is; those two keys are then dropped and every
    other key is copied (an explicit "shape" in ``adddic`` overrides the
    derived one). "name" is always set to ``name``.
    """
    entry = {}
    if dim_key in adddic and len_key in adddic:
        entry["shape"] = [int(adddic[len_key]), int(adddic[dim_key])]
    elif dim_key in adddic:
        entry["shape"] = [int(adddic[dim_key])]
    for key, value in adddic.items():
        if key not in (dim_key, len_key):
            entry[key] = value
    entry["name"] = name
    return entry


def _merge_utterance(orgdic, adddic, is_input):
    """Create the merged entry for one utterance.

    Args:
        orgdic (dict): entry from the first json; must hold "input" and
            "output" lists.
        adddic (dict): merged values from the other jsons.
        is_input (bool): append a new "input<N>" entry when true, otherwise
            a new "target<N>" entry to "output".

    Returns:
        dict: with "input", "output", "utt2spk" (defaults to "") and "lang"
        when ``orgdic`` has it. ``orgdic`` is not modified.
    """
    input_list = list(orgdic["input"])
    output_list = list(orgdic["output"])
    if is_input:
        input_list.append(
            _make_entry(adddic, "idim", "ilen", "input%d" % (len(input_list) + 1))
        )
    else:
        output_list.append(
            _make_entry(adddic, "odim", "olen", "target%d" % (len(output_list) + 1))
        )
    new = {
        "input": input_list,
        "output": output_list,
        "utt2spk": orgdic.get("utt2spk", ""),
    }
    # NOTE: "lang" is presumably used by machine translation data; keep it
    # whichever side the new entry was added to.
    if "lang" in orgdic:
        new["lang"] = orgdic["lang"]
    return new


def main():
    """Merge the json files named on the command line and print the result.

    The first json supplies the base utterance entries. For every utterance
    id found in all files, the values of the remaining jsons are merged and
    appended as a new input (``--is-input true``) or output entry. The merged
    json is written to stdout as UTF-8.
    """
    parser = get_parser()
    args = parser.parse_args()
    # Values to add come from js[1:], so at least two files are needed.
    if len(args.jsons) < 2:
        parser.error("at least two json files are required")

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    level = logging.INFO if args.verbose > 0 else logging.WARN
    logging.basicConfig(level=level, format=logfmt)
    logging.info(get_commandline_args())

    js, intersec_ks = _load_jsons(args.jsons)
    logging.info("new json has %d utterances", len(intersec_ks))

    add_dic = _merge_additional_values(js, intersec_ks)
    new_dic = {
        k: _merge_utterance(js[0]["utts"][k], add_dic[k], args.is_input)
        for k in intersec_ks
    }

    # ensure_ascii=False writes non-ASCII characters as-is instead of
    # \uXXXX escapes; stdout is therefore wrapped with a UTF-8 writer.
    jsonstring = json.dumps(
        {"utts": new_dic},
        indent=4,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ": "),
    )
    sys.stdout = codecs.getwriter("utf-8")(
        sys.stdout if is_python2 else sys.stdout.buffer
    )
    print(jsonstring)


if __name__ == "__main__":
    main()
