import itertools
import logging

import numpy as np


def batchfy_by_seq(
    sorted_data,
    batch_size,
    max_length_in,
    max_length_out,
    drop_len,
    min_batch_size=1,
    shortest_first=False,
    ikey="input",
    iaxis=0,
    okey="output",
    oaxis=0,
):
    """Make batch set from sorted utterances, counting sequences per batch.

    Utterances whose input length is ``drop_len`` or more are removed first.
    The remaining ones are cut into consecutive minibatches whose size shrinks
    when the first utterance of a batch is long compared to ``max_length_in``
    / ``max_length_out``.

    :param List[Tuple[str, Dict[str, Any]]] sorted_data: ``(utt_id, info)``
        pairs, already sorted by length
    :param int batch_size: batch size
    :param int max_length_in: maximum length of input to decide adaptive batch size
    :param int max_length_out: maximum length of output to decide adaptive batch size
    :param int drop_len: utterances are kept only if their input length is
        smaller than drop_len
    :param int min_batch_size: mininum batch size (for multi-gpu)
    :param bool shortest_first: if true, each slice taken from ``sorted_data``
        is reversed before it is stored as a batch
    :param str ikey: key to access input
        (for ASR ikey="input", for TTS, MT ikey="output".)
    :param int iaxis: dimension to access input
        (for ASR, TTS iaxis=0, for MT iaxis="1".)
    :param str okey: key to access output
        (for ASR, MT okey="output". for TTS okey="input".)
    :param int oaxis: dimension to access output
        (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axis.)
    :return: List[List[Tuple[str, dict]]] list of batches
        (empty if every utterance was dropped)
    :raises ValueError: if batch_size is not positive or there are fewer
        utterances than min_batch_size
    """
    if batch_size <= 0:
        raise ValueError(f"Invalid batch_size={batch_size}")

    # check #utts is more than min_batch_size
    if len(sorted_data) < min_batch_size:
        raise ValueError(
            f"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size})."
        )

    def _input_length(sample):
        return int(sample[1][ikey][iaxis]["shape"][0])

    def _output_length(sample):
        outputs = sample[1][okey]
        if oaxis >= 0:
            return int(outputs[oaxis]["shape"][0])
        # negative oaxis: use the longest of all output entries
        return max(int(x["shape"][0]) for x in outputs)

    # Drop long utterances one by one *before* batching. Deciding per batch
    # from its first element only works when that element is the longest one,
    # and it lets the random padding below re-introduce dropped utterances.
    kept = [sample for sample in sorted_data if _input_length(sample) < drop_len]
    if len(kept) < len(sorted_data):
        logging.info(
            "# utts dropped by drop_len: " + str(len(sorted_data) - len(kept))
        )

    # make list of minibatches
    minibatches = []
    num_kept = len(kept)
    start = 0
    while start < num_kept:
        head = kept[start]
        factor = max(
            int(_input_length(head) / max_length_in),
            int(_output_length(head) / max_length_out),
        )
        # change batchsize depending on the input and output length
        # if ilen = 1000 and max_length_in = 800
        # then b = batchsize / 2
        # the floor of 1 guarantees progress even if min_batch_size <= 0
        bs = max(min_batch_size, int(batch_size / (1 + factor)), 1)
        end = min(num_kept, start + bs)
        minibatch = kept[start:end]
        if shortest_first:
            minibatch.reverse()

        # check each batch is more than minimum batchsize
        missing = min_batch_size - len(minibatch)
        if missing > 0:
            # Pad with random utterances taken from the earlier batches, or
            # from this batch itself when it is the only one.
            pool_end = start if start > 0 else end
            additional_minibatch = [
                kept[i] for i in np.random.randint(0, pool_end, missing)
            ]
            if shortest_first:
                additional_minibatch.reverse()
            minibatch.extend(additional_minibatch)

        minibatches.append(minibatch)
        start = end

    # batch: List[List[Tuple[str, dict]]]
    return minibatches


def batchfy_by_bin(
    sorted_data,
    batch_bins,
    num_batches=0,
    min_batch_size=1,
    shortest_first=False,
    ikey="input",
    okey="output",
):
    """Make variably sized batch set, which maximizes

    the number of bins up to `batch_bins`.

    The size of a candidate batch of ``b + 1`` utterances is estimated as
    ``(max_olen + ilen) * (b + 1)``, where ``ilen`` / ``olen`` are the number
    of frames times the feature dimension read from the first utterance.

    :param List[Tuple[str, Dict[str, Any]]] sorted_data: ``(utt_id, info)``
        pairs, already sorted by length
    :param int batch_bins: Maximum number of bins (frames x dim) of a batch
    :param int num_batches: # number of batches to use (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu); batches are
        grown to this size even if they then exceed ``batch_bins``
    :param bool shortest_first: if true, each slice taken from ``sorted_data``
        is reversed before it is stored as a batch

    :param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
    :param str okey: key to access output (for ASR okey="output". for TTS okey="input".)

    :return: List[List[Tuple[str, dict]]] list of batches
    :raises ValueError: if batch_bins is not positive, if there are fewer
        utterances than min_batch_size, or if a single utterance does not
        fit in batch_bins
    """
    if batch_bins <= 0:
        raise ValueError(f"invalid batch_bins={batch_bins}")
    length = len(sorted_data)
    # With too few utterances the repair loop below would merge a batch into
    # itself and end in an opaque "min() arg is an empty sequence" error.
    if length < min_batch_size:
        raise ValueError(
            f"#utts({length}) is less than min_batch_size({min_batch_size})."
        )
    if length == 0:
        return []
    idim = int(sorted_data[0][1][ikey][0]["shape"][1])
    odim = int(sorted_data[0][1][okey][0]["shape"][1])
    logging.info("# utts: " + str(length))
    minibatches = []
    start = 0
    while True:
        # Dynamic batch size depending on size of samples
        b = 0
        next_size = 0
        max_olen = 0
        while next_size < batch_bins and (start + b) < length:
            ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0]) * idim
            olen = int(sorted_data[start + b][1][okey][0]["shape"][0]) * odim
            if olen > max_olen:
                max_olen = olen
            next_size = (max_olen + ilen) * (b + 1)
            if next_size <= batch_bins:
                b += 1
            elif b == 0:
                # Not even the first sample of this batch fits.
                raise ValueError(
                    f"Can't fit one sample in batch_bins ({batch_bins}): "
                    f"Please increase the value"
                )
        end = min(length, start + max(min_batch_size, b))
        batch = sorted_data[start:end]
        if shortest_first:
            batch.reverse()
        minibatches.append(batch)
        # Check for min_batch_size and fixes the batches if needed:
        # walk backwards, refilling each short batch from the one before it.
        i = -1
        while len(minibatches[i]) < min_batch_size:
            missing = min_batch_size - len(minibatches[i])
            if -i == len(minibatches):
                # The first batch became too small: merge it into the second.
                minibatches[i + 1].extend(minibatches[i])
                minibatches = minibatches[1:]
                break
            else:
                minibatches[i].extend(minibatches[i - 1][:missing])
                minibatches[i - 1] = minibatches[i - 1][missing:]
                i -= 1
        if end == length:
            break
        start = end
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    lengths = [len(x) for x in minibatches]
    logging.info(
        f"{len(minibatches)} batches containing from {min(lengths)} to "
        f"{max(lengths)} samples (avg {int(np.mean(lengths))} samples)."
    )
    return minibatches


def batchfy_by_frame(
    sorted_data,
    max_frames_in,
    max_frames_out,
    max_frames_inout,
    num_batches=0,
    min_batch_size=1,
    shortest_first=False,
    ikey="input",
    okey="output",
):
    """Make variable batch set, which maximizes the number of frames per batch.

    A batch of ``b`` utterances is counted as ``b`` times the longest input
    (resp. output, input+output) length seen in it. A limit of 0 disables
    the corresponding check.

    :param List[Tuple[str, Dict[str, Any]]] sorted_data: ``(utt_id, info)``
        pairs, already sorted by length
    :param int max_frames_in: Maximum input frames of a batch
    :param int max_frames_out: Maximum output frames of a batch
    :param int max_frames_inout: Maximum input+output frames of a batch
    :param int num_batches: # number of batches to use (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: if true, each slice taken from ``sorted_data``
        is reversed before it is stored as a batch

    :param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
    :param str okey: key to access output (for ASR okey="output". for TTS okey="input".)

    :return: List[List[Tuple[str, dict]]] list of batches
    :raises ValueError: if no limit is positive, if there are fewer utterances
        than min_batch_size, or if a single utterance exceeds one of the limits
    """
    if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:
        raise ValueError(
            "At least, one of `--batch-frames-in`, `--batch-frames-out` or "
            "`--batch-frames-inout` should be > 0"
        )
    length = len(sorted_data)
    # With too few utterances the repair loop below would merge a batch into
    # itself and end in an opaque "min() arg is an empty sequence" error.
    if length < min_batch_size:
        raise ValueError(
            f"#utts({length}) is less than min_batch_size({min_batch_size})."
        )
    if length == 0:
        return []
    minibatches = []
    start = 0
    end = 0
    while end != length:
        # Dynamic batch size depending on size of samples
        b = 0
        max_olen = 0
        max_ilen = 0
        while (start + b) < length:
            ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0])
            if ilen > max_frames_in and max_frames_in != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-in ({max_frames_in}): "
                    f"Please increase the value"
                )
            olen = int(sorted_data[start + b][1][okey][0]["shape"][0])
            if olen > max_frames_out and max_frames_out != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-out ({max_frames_out}): "
                    f"Please increase the value"
                )
            if ilen + olen > max_frames_inout and max_frames_inout != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-inout "
                    f"({max_frames_inout}): Please increase the value"
                )
            max_olen = max(max_olen, olen)
            max_ilen = max(max_ilen, ilen)
            in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0
            out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0
            inout_ok = (max_ilen + max_olen) * (
                b + 1
            ) <= max_frames_inout or max_frames_inout == 0
            if in_ok and out_ok and inout_ok:
                # add more seq in the minibatch
                b += 1
            else:
                # no more seq in the minibatch
                break
        end = min(length, start + b)
        batch = sorted_data[start:end]
        if shortest_first:
            batch.reverse()
        minibatches.append(batch)
        # Check for min_batch_size and fixes the batches if needed:
        # walk backwards, refilling each short batch from the one before it.
        i = -1
        while len(minibatches[i]) < min_batch_size:
            missing = min_batch_size - len(minibatches[i])
            if -i == len(minibatches):
                # The first batch became too small: merge it into the second.
                minibatches[i + 1].extend(minibatches[i])
                minibatches = minibatches[1:]
                break
            else:
                minibatches[i].extend(minibatches[i - 1][:missing])
                minibatches[i - 1] = minibatches[i - 1][missing:]
                i -= 1
        start = end
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    lengths = [len(x) for x in minibatches]
    logging.info(
        f"{len(minibatches)} batches containing from {min(lengths)} to "
        f"{max(lengths)} samples (avg {int(np.mean(lengths))} samples)."
    )

    return minibatches


def batchfy_shuffle(data, batch_size, min_batch_size, num_batches, shortest_first):
    """Make batch set of randomly ordered utterances.

    :param Dict[str, Dict[str, Any]] data: utterance id -> info dictionary
    :param int batch_size: number of sequences in a minibatch
    :param int min_batch_size: minimum batch size (for multi-gpu); a short
        last batch is padded with random utterances from the earlier batches
    :param int num_batches: # number of batches to use (for debug)
    :param bool shortest_first: if true, each batch is reversed
    :return: List[List[Tuple[str, dict]]] list of batches
    :raises ValueError: if batch_size is not positive or there are fewer
        utterances than min_batch_size
    """
    import random

    if batch_size <= 0:
        raise ValueError(f"Invalid batch_size={batch_size}")
    if len(data) < min_batch_size:
        raise ValueError(
            f"#utts({len(data)}) is less than min_batch_size({min_batch_size})."
        )

    logging.info("use shuffled batch.")
    # random.sample() only accepts sequences since Python 3.11, so the
    # dict view has to be turned into a list first.
    items = list(data.items())
    sorted_data = random.sample(items, len(items))
    logging.info("# utts: " + str(len(sorted_data)))

    # Never cut batches smaller than min_batch_size; this also guarantees
    # that only a batch after the first one can need padding.
    step = max(batch_size, min_batch_size)

    # make list of minibatches
    minibatches = []
    for start in range(0, len(sorted_data), step):
        minibatch = sorted_data[start : start + step]
        if shortest_first:
            minibatch.reverse()
        # check each batch is more than minimum batchsize
        missing = min_batch_size - len(minibatch)
        if missing > 0:
            additional_minibatch = [
                sorted_data[i] for i in np.random.randint(0, start, missing)
            ]
            if shortest_first:
                additional_minibatch.reverse()
            minibatch.extend(additional_minibatch)
        minibatches.append(minibatch)

    # for debugging
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
        logging.info("# minibatches: " + str(len(minibatches)))
    return minibatches


# Accepted values for the `count` argument of make_batchset.
BATCH_COUNT_CHOICES = ["auto", "seq", "bin", "frame"]
# Accepted values for the `batch_sort_key` argument of make_batchset.
BATCH_SORT_KEY_CHOICES = ["input", "output", "shuffle"]


def make_batchset(
    data,
    batch_size=0,
    max_length_in=float("inf"),
    max_length_out=float("inf"),
    drop_len=float("inf"),
    num_batches=0,
    min_batch_size=1,
    shortest_first=False,
    batch_sort_key="input",
    swap_io=False,
    mt=False,
    count="auto",
    batch_bins=0,
    batch_frames_in=0,
    batch_frames_out=0,
    batch_frames_inout=0,
    iaxis=0,
    oaxis=0,
):
    """Make batch set from json dictionary

    Utterances are grouped by their "category" value (utterances without one
    share a single group), every group is batched separately and the
    resulting batch lists are concatenated:

        >>> data = {'utt1': {'category': 'A', 'input': ...},
        ...         'utt2': {'category': 'B', 'input': ...},
        ...         'utt3': {'category': 'B', 'input': ...},
        ...         'utt4': {'category': 'A', 'input': ...}}
        >>> make_batchset(data, batch_size=2, ...)
        [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3', ...)]]

    If no utterance has a "category", this performs the same as
    batchfy_by_{count}.

    :param Dict[str, Dict[str, Any]] data: dictionary loaded from data.json
    :param int batch_size: maximum number of sequences in a minibatch.
    :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.
    :param int batch_frames_in:  maximum number of input frames in a minibatch.
    :param int batch_frames_out: maximum number of output frames in a minibatch.
    :param int batch_frames_inout: maximum number of input+output frames
        in a minibatch.
    :param str count: strategy to count maximum size of batch, one of
        BATCH_COUNT_CHOICES. With "auto" the first non-zero option among
        batch_size, batch_bins and batch_frames_* selects "seq", "bin" or "frame".

    :param int max_length_in: maximum length of input to decide adaptive batch size
    :param int max_length_out: maximum length of output to decide adaptive batch size
    :param int drop_len: forwarded to batchfy_by_seq only (count="seq")
    :param int num_batches: # number of batches to use (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: Sort from batch with shortest samples
        to longest if true, otherwise reverse
    :param str batch_sort_key: how to sort data before creating minibatches,
        one of BATCH_SORT_KEY_CHOICES ("shuffle" requires count="seq")
    :param bool swap_io: if True, use "input" as output and "output"
        as input in `data` dict
    :param bool mt: if True, use 0-axis of "output" as output and 1-axis of "output"
        as input in `data` dict (requires iaxis=1 and oaxis=0)
    :param int iaxis: dimension to access input
        (for ASR, TTS iaxis=0, for MT iaxis="1".)
    :param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0,
        reserved for future research, -1 means all axis.)
    :return: List[List[Tuple[str, dict]]] list of batches
    """

    # check args
    if count not in BATCH_COUNT_CHOICES:
        raise ValueError(
            f"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}"
        )
    if batch_sort_key not in BATCH_SORT_KEY_CHOICES:
        raise ValueError(
            f"arg 'batch_sort_key' ({batch_sort_key}) should be "
            f"one of {BATCH_SORT_KEY_CHOICES}"
        )

    # TODO(karita): remove this by creating converter from ASR to TTS json format
    # Default (ASR) layout; the two special modes below override it.
    ikey, okey = "input", "output"
    batch_sort_axis = 0
    if swap_io:
        # for TTS: the roles of "input" and "output" are exchanged,
        # including the key used for sorting ("shuffle" is left untouched)
        ikey, okey = "output", "input"
        swapped = {"input": "output", "output": "input"}
        batch_sort_key = swapped.get(batch_sort_key, batch_sort_key)
    elif mt:
        # for MT
        # NOTE: input is json['output'][1] and output is json['output'][0]
        ikey = okey = "output"
        batch_sort_key = "output"
        batch_sort_axis = 1
        assert iaxis == 1
        assert oaxis == 0

    if count == "auto":
        if batch_size != 0:
            count = "seq"
        elif batch_bins != 0:
            count = "bin"
        elif not (
            batch_frames_in == 0 and batch_frames_out == 0 and batch_frames_inout == 0
        ):
            count = "frame"
        else:
            raise ValueError(
                f"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}"
            )
        logging.info(f"count is auto detected as {count}")

    use_shuffle = batch_sort_key == "shuffle"
    if use_shuffle and count != "seq":
        raise ValueError("batch_sort_key=shuffle is only available if batch_count=seq")

    # group utterances by category, keeping first-seen order of categories
    category2data = {}  # Dict[str, dict]
    for utt_id, info in data.items():
        category = info.get("category")
        if category not in category2data:
            category2data[category] = {}
        category2data[category][utt_id] = info

    def sort_length(item):
        # length of one (utt_id, info) pair along the sorting key/axis
        return int(item[1][batch_sort_key][batch_sort_axis]["shape"][0])

    batches_list = []  # List[List[List[Tuple[str, dict]]]]
    for group in category2data.values():
        if use_shuffle:
            batches_list.append(
                batchfy_shuffle(
                    group, batch_size, min_batch_size, num_batches, shortest_first
                )
            )
            continue

        # sort it by input lengths (long to short unless shortest_first)
        sorted_data = list(group.items())
        sorted_data.sort(key=sort_length, reverse=not shortest_first)
        logging.info(f"# utts: {len(sorted_data)}")

        if count == "seq":
            group_batches = batchfy_by_seq(
                sorted_data,
                batch_size=batch_size,
                max_length_in=max_length_in,
                max_length_out=max_length_out,
                drop_len=drop_len,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=ikey,
                iaxis=iaxis,
                okey=okey,
                oaxis=oaxis,
            )
        elif count == "bin":
            group_batches = batchfy_by_bin(
                sorted_data,
                batch_bins=batch_bins,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=ikey,
                okey=okey,
            )
        else:
            # count == "frame": the only remaining validated choice
            group_batches = batchfy_by_frame(
                sorted_data,
                max_frames_in=batch_frames_in,
                max_frames_out=batch_frames_out,
                max_frames_inout=batch_frames_inout,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=ikey,
                okey=okey,
            )
        batches_list.append(group_batches)

    if len(batches_list) == 1:
        batches = batches_list[0]
    else:
        # Concat list. This way is faster than "sum(batch_list, [])"
        batches = list(itertools.chain.from_iterable(batches_list))

    # for debugging
    if num_batches > 0:
        batches = batches[:num_batches]
    logging.info(f"# minibatches: {len(batches)}")

    # batch: List[List[Tuple[str, dict]]]
    return batches
