#!/usr/bin/env python3
# encoding: utf-8

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

from __future__ import print_function
from __future__ import unicode_literals

import argparse
import codecs
import json
import logging
import os
import sys

from espnet.utils.cli_utils import get_commandline_args

# True only when running under a Python 2 interpreter; decides whether stdout
# is wrapped directly or through its underlying byte buffer when writing.
is_python2 = sys.version_info.major == 2


def get_parser():
    """Build the command-line parser for the json merging script.

    ``--input-jsons``, ``--output-jsons`` and ``--jsons`` may each be given
    several times. Every occurrence takes one or more paths and is collected
    as one group, so the parsed value is a list of lists of paths (empty
    list when the option is absent).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description="merge json files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # The three repeatable file-group options differ only in flag and help text.
    group_options = (
        ("--input-jsons", "Json files for the inputs"),
        ("--output-jsons", "Json files for the outputs"),
        ("--jsons", "The json files except for the input and outputs"),
    )
    for flag, help_text in group_options:
        # `default=[]` is evaluated anew on every iteration, so no list is shared.
        parser.add_argument(
            flag,
            type=str,
            nargs="+",
            action="append",
            default=[],
            help=help_text,
        )

    parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
    parser.add_argument("-O", dest="output", type=str, help="Output json file")
    return parser


def _make_entry(jtype, idx, dic):
    """Convert one utterance's merged fields into an "input"/"output" list entry.

    Args:
        jtype: "input" or "output"; selects the name prefix ("input"/"target")
            and which length/dim fields ("ilen"/"idim" or "olen"/"odim")
            are turned into a shape.
        idx: 1-based position of the json group among the groups of the same
            type; it is appended to the entry name ("input1", "target2", ...).
        dic: merged key/value pairs for one utterance from one json group.

    Returns:
        A new dict with "name", a "shape" list of ints when the length/dim
        fields or a "shape" field are present, and every other field of
        ``dic`` except "ilen", "idim", "olen", "odim" and "shape". A "name"
        field already present in ``dic`` overrides the generated name.
    """
    if jtype == "input":
        prefix, len_key, dim_key = "input", "ilen", "idim"
    else:
        # Output entries use the "o*" fields (the previous code tested the
        # "i*" fields here, which mis-shaped or crashed on target entries).
        prefix, len_key, dim_key = "target", "olen", "odim"

    entry = {"name": "{}{}".format(prefix, idx)}

    # FIXME(kamo): ad-hoc way to change str to List[int]
    # Yields [len, dim], [len] or [dim] depending on which fields exist.
    dims = [int(dic[key]) for key in (len_key, dim_key) if key in dic]
    if dims:
        entry["shape"] = dims
    if "shape" in dic:
        # An explicit shape string wins: "80,1000" -> [80, 1000]
        entry["shape"] = [int(s) for s in dic["shape"].split(",")]

    for key, value in dic.items():
        if key not in ("ilen", "idim", "olen", "odim", "shape"):
            entry[key] = value
    return entry


def _merge_utterances(js_dict, keys):
    """Build the merged "utts" mapping for the given utterance keys.

    Args:
        js_dict: maps "input", "output" and "other" to a list of groups,
            each group being a list of loaded json objects with a "utts" key.
        keys: utterance keys to emit; every loaded json must contain them.

    Returns:
        Dict mapping each key to {"input": [...], "output": [...], ...}.
        Within a group, fields of later files overwrite earlier ones. Fields
        of "other" groups are copied to the top level of the utterance
        (applied last, so they can overwrite "input"/"output").
    """
    new_dic = {}
    for k in keys:
        utt = {"input": [], "output": []}
        for jtype in ("input", "output", "other"):
            for idx, js in enumerate(js_dict[jtype], 1):
                # Merge the utterance's dicts from all files of this group.
                dic = {k2: v for j in js for k2, v in j["utts"][k].items()}
                if jtype == "other":
                    utt.update(dic)
                else:
                    utt[jtype].append(_make_entry(jtype, idx, dic))
        new_dic[k] = utt
    return new_dic


def main():
    """Merge the json files given on the command line into a single json.

    Only utterances present in every loaded file are kept. Files passed via
    --input-jsons / --output-jsons become "input" / "output" list entries;
    fields from --jsons files are copied onto each utterance. The result,
    {"utts": {...}}, is written to the -O file or to stdout.
    """
    args = get_parser().parse_args()

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    level = logging.INFO if args.verbose > 0 else logging.WARNING
    logging.basicConfig(level=level, format=logfmt)
    logging.info(get_commandline_args())

    # Dict[str, List[List[dict]]]: per type, one list of loaded jsons per group.
    js_dict = {}
    intersec_ks = None  # Set[str]; stays None until a file has been loaded
    for jtype, jsons_list in (
        ("input", args.input_jsons),
        ("output", args.output_jsons),
        ("other", args.jsons),
    ):
        groups = []
        for jsons in jsons_list:
            js = []
            for x in jsons:
                if not os.path.isfile(x):
                    # Best effort: skip missing files, but make it visible.
                    logging.warning("%s: not found, skipped", x)
                    continue
                with codecs.open(x, encoding="utf-8") as f:
                    j = json.load(f)
                ks = set(j["utts"])
                logging.info("%s: has %d utterances", x, len(ks))
                intersec_ks = ks if intersec_ks is None else intersec_ks & ks
                js.append(j)
            groups.append(js)
        js_dict[jtype] = groups

    if intersec_ks is None:
        # Nothing was loaded: emit an empty result instead of failing on None.
        logging.warning("No json file was loaded")
        intersec_ks = set()
    elif not intersec_ks:
        logging.warning("No intersection")
    logging.info("new json has %d utterances", len(intersec_ks))

    new_dic = _merge_utterances(js_dict, intersec_ks)

    # ensure_ascii=False keeps non-ASCII text as-is instead of \uXXXX escapes,
    # so the text must be written through an explicit UTF-8 encoder.
    text = json.dumps(
        {"utts": new_dic},
        indent=4,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ": "),
    )
    if args.output is not None:
        with codecs.open(args.output, "w", encoding="utf-8") as f:
            f.write(text + "\n")
    else:
        writer = codecs.getwriter("utf-8")(
            sys.stdout if is_python2 else sys.stdout.buffer
        )
        writer.write(text + "\n")
        writer.flush()


if __name__ == "__main__":
    main()
