#!/usr/bin/env python3
# encoding: utf-8

# Copyright 2019 Johns Hopkins University (Xuankai Chang)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import json
import logging
import re
import six
import sys

import numpy as np


def permutationDFS(source, start, res):
    """Append every permutation of ``source[start:]`` (prefix kept fixed) to ``res``.

    Depth-first, swap-based enumeration: position ``start`` is swapped with
    each position at or after it, the tail is recursed into, and the swap is
    undone, so ``source`` is back in its original order on return. Each
    complete arrangement is stored as ``source.tolist()``.

    Output order for [1, 2, 3]:
    [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 2, 1], [3, 1, 2]]
    """
    last = len(source) - 1
    if start == last:
        # Only one slot left to fill: the current arrangement is complete.
        res.append(source.tolist())
        return
    for pos in range(start, last + 1):
        source[start], source[pos] = source[pos], source[start]
        permutationDFS(source, start + 1, res)
        source[start], source[pos] = source[pos], source[start]  # undo swap


# pre-set the permutation scheme (ref_idx, hyp_idx)
def permutation_schemes(num_spkrs):
    """Enumerate every reference/hypothesis pairing for ``num_spkrs`` speakers.

    Each permutation ``perm`` of ``[1, ..., num_spkrs]`` (in the order produced
    by ``permutationDFS``) becomes a list of pair keys ``"r<i>h<j>"`` that pairs
    reference ``i`` with hypothesis ``perm[i - 1]``::

        2 speakers: [[r1h1, r2h2], [r1h2, r2h1]]
        3 speakers: [[r1h1, r2h2, r3h3], [r1h1, r2h3, r3h2], [r1h2, r2h1, r3h3],
                     [r1h2, r2h3, r3h1], [r1h3, r2h2, r3h1], [r1h3, r2h1, r3h2]]

    Args:
        num_spkrs: number of speakers in the mixture.

    Returns:
        ``(flat, keys)``: ``keys`` is the list of per-permutation key lists;
        ``flat`` is all of them concatenated in the same order (individual
        keys repeat across permutations when ``num_spkrs > 2``).
    """
    perms = []
    permutationDFS(np.array(list(range(1, num_spkrs + 1))), 0, perms)

    keys = [["r%dh%d" % (i, j) for i, j in enumerate(perm, 1)] for perm in perms]
    # Linear-time flatten (sum(list_of_lists, []) is quadratic).
    flat = [k for perm_keys in keys for k in perm_keys]
    return flat, keys


def convert_score(keys, dic):
    """Parse the ``"Scores"`` string of each requested key into a count array.

    Every run of digits in ``dic[key]["Scores"]`` is read as an integer count,
    giving ``np.array([c, s, d, i])``. Exactly four numbers are required;
    otherwise an AssertionError is raised.

    Args:
        keys: iterable of keys into ``dic`` (duplicates are harmless).
        dic: mapping ``key -> {"Scores": str, ...}``.

    Returns:
        dict ``key -> np.ndarray`` of length 4.
    """
    digits = re.compile(r"\d+")
    parsed = {}
    for key in keys:
        counts = np.array([int(tok) for tok in digits.findall(dic[key]["Scores"])])
        assert len(counts) == 4
        parsed[key] = counts
    return parsed


def get_utt_permutation(result, num_spkrs=2):
    """Pick, per utterance, the speaker permutation with the lowest error rate.

    Args:
        result: dict ``utt_id -> dict`` holding one entry per ``rXhY`` pair key,
            each with a ``"Scores"`` string whose four numbers are
            ``c s d i`` (as built by ``merge_results``).
        num_spkrs: number of speakers in the mixture.

    Returns:
        dict ``utt_id -> dict`` containing the pair entries of the best
        permutation plus a ``"Scores"`` string
        ``"(#C #S #D #I) c s d i"`` with the counts summed over those pairs.
        Ties are broken in favour of the earliest permutation.
    """
    perm, keys = permutation_schemes(num_spkrs)
    new_dic = {}

    for utt_id, in_dic in result.items():  # utts
        score = convert_score(perm, in_dic)
        # (n_perms, 4): [c, s, d, i] summed over the pairs of each permutation.
        perm_score = np.stack(
            [np.sum([score[k] for k in ks], axis=0) for ks in keys]
        )

        errors = np.sum(perm_score[:, 1:4], axis=1)  # s + d + i
        ref_len = np.sum(perm_score[:, 0:3], axis=1)  # c + s + d
        # Every reference index appears exactly once in each permutation, so
        # ref_len is normally the same for all rows; clamping to 1 avoids a
        # 0/0 (NaN + RuntimeWarning) when all references are empty.
        error_rate = errors / np.maximum(ref_len, 1)
        best = int(np.argmin(error_rate))

        dic = {k: in_dic[k] for k in keys[best]}
        dic["Scores"] = "(#C #S #D #I) " + " ".join(map(str, perm_score[best]))
        new_dic[utt_id] = dic

    return new_dic


def get_results(result_file, result_key):
    """Parse a per-pair scoring/alignment file into per-utterance text fields.

    A line starting with ``id: `` opens a new utterance; its id is the
    second whitespace-separated token. Lines starting with
    ``Speaker sentences``, ``Scores: ``, ``REF: `` or ``HYP: `` are stored
    under ``"Speaker"``, ``"Scores"``, ``"REF"`` or ``"HYP"``, with the first
    token dropped and the remaining tokens re-joined by single spaces.
    (Presumably sclite ``.pra``-style output; confirm against the producer.)

    Args:
        result_file: path of the file to read (decoded as UTF-8).
        result_key: pair key (e.g. ``"r1h2"``) under which each utterance's
            fields are nested.

    Returns:
        ``{"utts": {utt_id: {result_key: {field: text}}}}``
    """
    re_id = re.compile(r"^id: ")
    re_patterns = {
        "Speaker": re.compile(r"^Speaker sentences"),
        "Scores": re.compile(r"^Scores: "),
        "REF": re.compile(r"^REF: "),
        "HYP": re.compile(r"^HYP: "),
    }

    results = {}
    tmp_id = None
    tmp_ret = {}

    # Decode explicitly as UTF-8: transcripts may be non-ASCII and the script
    # writes its own output as UTF-8, so do not depend on the locale default.
    with open(result_file, "r", encoding="utf-8") as f:
        for line in f:
            x = line.rstrip()
            x_split = x.split()

            if re_id.match(x):
                if tmp_id:
                    results[tmp_id] = {result_key: tmp_ret}
                    tmp_ret = {}
                tmp_id = x_split[1]
            for field, pattern in re_patterns.items():
                if pattern.match(x):
                    tmp_ret[field] = " ".join(x_split[1:])

    # Flush the last utterance: no further "id:" line follows it.
    if tmp_ret:
        results[tmp_id] = {result_key: tmp_ret}

    return {"utts": results}


def merge_results(results):
    """Merge per-pair results into one dict keyed by utterance id.

    Only utterances present in *every* input are kept. The input dicts are
    not modified.

    Args:
        results: dict ``pair_key -> {"utts": {utt_id: {...}}}`` (as returned
            by ``get_results``).

    Returns:
        dict ``utt_id -> dict`` combining every input's entry for that
        utterance (later inputs win on key clashes). Utterance ids are visited
        in sorted order so the output order is deterministic.
    """
    rslt_lst = []

    # Intersection of utterance keys; None means "no input seen yet", so an
    # intersection that becomes empty stays empty.
    intersec_keys = None
    for name, j in results.items():
        ks = j["utts"].keys()
        logging.info("%s: has %d utterances", name, len(ks))

        if intersec_keys is None:
            intersec_keys = set(ks)
        else:
            intersec_keys &= set(ks)
        rslt_lst.append(j)

    if intersec_keys is None:
        intersec_keys = set()

    logging.info("After merge, the result has %d utterances", len(intersec_keys))

    # merging results (copy so the first input's dicts are not mutated)
    dic = {}
    for k in sorted(intersec_keys, key=str):
        v = dict(rslt_lst[0]["utts"][k])
        for j in rslt_lst[1:]:
            v.update(j["utts"][k])
        dic[k] = v

    return dic


def get_parser():
    """Build the command-line parser: ``--num-spkrs`` plus one or more result files."""
    results_help = (
        "the scores between references and hypotheses, "
        "in ascending order of references (1st) and hypotheses (2nd), "
        "e.g. [r1h1, r1h2, r2h1, r2h2] in 2-speaker-mix case."
    )
    parser = argparse.ArgumentParser(description="evaluate permutation-free error")
    parser.add_argument(
        "--num-spkrs", default=2, type=int, help="number of mixed speakers."
    )
    parser.add_argument("results", nargs="+", type=str, help=results_help)
    return parser


def main(args):
    """Compute permutation-free scores from the per-pair result files.

    Args:
        args: namespace with ``num_spkrs`` (int) and ``results`` (list of
            paths; each ``rXhY`` key must occur in exactly one path).

    Returns:
        ``(score, new_results, score_per_hyp, score_per_ref)`` where
            score: float array (n_utts, 4) of [c, s, d, i] per utterance under
                its best permutation.
            new_results: output of ``get_utt_permutation``.
            score_per_hyp: list of ``num_spkrs`` int arrays (n_utts, 4), one
                per hypothesis index 1..num_spkrs.
            score_per_ref: same, one per reference index.

    Raises:
        ValueError: if a pair key matches zero or several result paths, or an
            utterance lacks exactly one chosen pair for some speaker.
    """
    num_spkrs = args.num_spkrs

    # Read results from files: one file per (reference, hypothesis) pair.
    results = {}
    for r_idx in range(1, num_spkrs + 1):
        for h_idx in range(1, num_spkrs + 1):
            key = f"r{r_idx}h{h_idx}"
            results[key] = get_results(_find_result_file(key, args.results), key)

    # Merge the results of every permutation
    results = merge_results(results)

    # Get the final results with best permutation
    new_results = get_utt_permutation(results, num_spkrs)

    # Counts of each hypothesis / reference under the chosen permutation.
    score_per_hyp = _scores_by_speaker(
        new_results, num_spkrs, lambda idx: rf"r\d+h{idx}"
    )
    score_per_ref = _scores_by_speaker(
        new_results, num_spkrs, lambda idx: rf"r{idx}h\d+"
    )

    # Get WER/CER: [c, s, d, i] per utterance
    pat = re.compile(r"\d+")
    score = np.zeros((len(new_results), 4))
    for idx, res in enumerate(new_results.values()):
        score[idx] = list(map(int, pat.findall(res["Scores"])))
    return score, new_results, score_per_hyp, score_per_ref


def _find_result_file(key, paths):
    """Return the only path whose leading non-whitespace run contains ``key``.

    A negative look-ahead for a digit keeps ``r1h1`` from also matching a
    path for ``r1h10`` when there are ten or more speakers.

    Raises:
        ValueError: if zero or several paths match.
    """
    pattern = re.compile(rf"\S*{re.escape(key)}(?!\d)")
    matches = [p for p in paths if pattern.match(p)]
    if len(matches) != 1:
        raise ValueError(f"expected exactly one result file for {key}, got {matches}")
    return matches[0]


def _scores_by_speaker(new_results, num_spkrs, key_regex):
    """Gather, per speaker, the [c, s, d, i] counts of its chosen pair.

    Args:
        new_results: output of ``get_utt_permutation``.
        num_spkrs: number of speakers.
        key_regex: callable mapping a 1-based speaker index to a regex that
            must fully match exactly one pair key of each utterance.

    Returns:
        list of ``num_spkrs`` int arrays, each of shape (n_utts, 4).

    Raises:
        ValueError: if an utterance has zero or several matching pair keys.
    """
    patterns = [re.compile(key_regex(idx)) for idx in range(1, num_spkrs + 1)]
    per_speaker = [[] for _ in patterns]
    for utt_id, res in new_results.items():
        for rows, pattern in zip(per_speaker, patterns):
            # fullmatch so that e.g. "r2h10" is not taken for hypothesis 1.
            pair_keys = [k for k in res if pattern.fullmatch(k)]
            if len(pair_keys) != 1:
                raise ValueError(
                    f"{utt_id}: expected one key matching {pattern.pattern}, got {pair_keys}"
                )
            # "Scores" is "(#C #S #D #I) c s d i": keep the last four numbers.
            rows.append([int(x) for x in res[pair_keys[0]]["Scores"].split()[-4:]])
    return [np.array(rows, dtype=int).reshape(-1, 4) for rows in per_speaker]


def _error_rate(counts):
    """Return 100 * (S + D + I) / (C + S + D) for a [c, s, d, i] count vector."""
    return 100 * sum(counts[1:4]) / float(sum(counts[0:3]))


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()

    if len(args.results) != args.num_spkrs ** 2:
        parser.print_help()
        sys.exit(1)

    sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer)

    scores, new_results, score_per_hyp, score_per_ref = main(args)
    score_sum = np.sum(scores, axis=0, dtype=int)

    # Print results
    print(sys.argv)
    print("Total Scores: (#C #S #D #I) " + " ".join(map(str, list(score_sum))))
    print("Error Rate:   {:0.2f}".format(_error_rate(score_sum)))
    print("Total Utts: ", str(scores.shape[0]))

    for h_idx, hyp_scores in enumerate(score_per_hyp):
        score_sum = np.sum(hyp_scores, axis=0, dtype=int)
        print(
            f"Total Scores for hypothesis {h_idx+1}: (#C #S #D #I) "
            + " ".join(map(str, list(score_sum)))
        )
        print("Error Rate:   {:0.2f}".format(_error_rate(score_sum)))

    # (n_spkrs, n_utts, 4); the last axis is #C, #S, #D, #I.
    score_per_hyp = np.stack(score_per_hyp)
    # Hypothesis length per utterance: C + S + I (= C+S+D+I minus D).
    hyp_length = np.sum(score_per_hyp, axis=2) - score_per_hyp[:, :, 2]
    # length_order[k, u] = index of the hypothesis with the k-th longest
    # output in utterance u.
    length_order = np.argsort(-hyp_length, axis=0)
    if hyp_length.shape[1] > 1:
        # Debug view of the second utterance: lengths, then the ordering.
        print(np.concatenate((hyp_length, length_order), axis=0)[:, 1])
    print('horizontal is rank, vertical axis is generation (hypothesis) order')
    conf_matrix = np.zeros((args.num_spkrs, args.num_spkrs))
    for i in range(args.num_spkrs):  # hypothesis (generation) order
        for j in range(args.num_spkrs):  # length rank
            # Number of utterances in which hypothesis i has length rank j.
            conf_matrix[i, j] = np.sum(length_order[j] == i)
    print('confusion matrix:\n', conf_matrix)

    # Following is for two speakers only
    if args.num_spkrs == 2:
        loose_constraint = 5
        # Hypothesis 1 counts as the longer one when it is within
        # loose_constraint tokens of hypothesis 2.
        n_first = int(np.sum(hyp_length[0] + loose_constraint >= hyp_length[1]))
        n_second = hyp_length.shape[1] - n_first
        conf_matrix = np.zeros((2, 2))
        conf_matrix[0, 0] = conf_matrix[1, 1] = n_first
        conf_matrix[0, 1] = conf_matrix[1, 0] = n_second
        print(f'confusion matrix under loosen constraints {loose_constraint}\n', conf_matrix)

    loose_constraint = 5
    # Per-utterance length of the longest hypothesis.
    hyp_length_max = np.max(hyp_length, axis=0)
    longest = int(np.sum(hyp_length[0] + loose_constraint >= hyp_length_max))
    print('longest is the first:', longest)

    for i in range(hyp_length.shape[0]):
        # NOTE(review): tolerance here is 4, not loose_constraint (5); confirm intent.
        print(f'longest {i}th:', np.sum((hyp_length[i] + 4) >= hyp_length_max))

    for r_idx, ref_scores in enumerate(score_per_ref):
        score_sum = np.sum(ref_scores, axis=0, dtype=int)
        print(
            f"Total Scores for reference {r_idx+1}: (#C #S #D #I) "
            + " ".join(map(str, list(score_sum)))
        )
        print("Error Rate:   {:0.2f}".format(_error_rate(score_sum)))

    # Per-utterance JSON dump is currently disabled; set to True to emit it.
    dump_json = False
    if dump_json:
        print(
            json.dumps(
                {"utts": new_results},
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
                separators=(",", ": "),
            )
        )
