from collections import OrderedDict
import copy
import io
import logging
import sys

import yaml

from espnet.utils.dynamic_import import dynamic_import


PY2 = sys.version_info[0] == 2

if PY2:
    from collections import Sequence
    from funcsigs import signature
else:
    # The ABC aliases in 'collections' were removed in Python 3.10
    from collections.abc import Sequence
    from inspect import signature


# TODO(karita): inherit TransformInterface
# TODO(karita): register cmd arguments in asr_train.py
# Short alias -> "module.path:ClassName" spec; passed to ``dynamic_import``
# when resolving the "type" field of a process entry.
import_alias = {
    "identity": "espnet.transform.transform_interface:Identity",
    "time_warp": "espnet.transform.spec_augment:TimeWarp",
    "time_mask": "espnet.transform.spec_augment:TimeMask",
    "freq_mask": "espnet.transform.spec_augment:FreqMask",
    "spec_augment": "espnet.transform.spec_augment:SpecAugment",
    "speed_perturbation": "espnet.transform.perturb:SpeedPerturbation",
    "volume_perturbation": "espnet.transform.perturb:VolumePerturbation",
    "noise_injection": "espnet.transform.perturb:NoiseInjection",
    "bandpass_perturbation": "espnet.transform.perturb:BandpassPerturbation",
    "rir_convolve": "espnet.transform.perturb:RIRConvolve",
    "delta": "espnet.transform.add_deltas:AddDeltas",
    "cmvn": "espnet.transform.cmvn:CMVN",
    "utterance_cmvn": "espnet.transform.cmvn:UtteranceCMVN",
    "fbank": "espnet.transform.spectrogram:LogMelSpectrogram",
    "spectrogram": "espnet.transform.spectrogram:Spectrogram",
    "stft": "espnet.transform.spectrogram:Stft",
    "istft": "espnet.transform.spectrogram:IStft",
    "stft2fbank": "espnet.transform.spectrogram:Stft2LogMelSpectrogram",
    "wpe": "espnet.transform.wpe:WPE",
    "channel_selector": "espnet.transform.channel_selector:ChannelSelector",
}


class Transformation(object):
    """Apply a configured chain of functions to the mini-batch

    The configuration is a dict (or the path of a YAML file holding one)
    with an optional "mode" key (only "sequential", the default, is
    supported) and a "process" list.  Every entry of "process" is a dict:
    its "type" is resolved with ``dynamic_import`` using ``import_alias``
    and the remaining items are given to the resolved class as keyword
    arguments.  The built objects are kept in ``self.functions``, keyed by
    their position in "process".

    Examples:
        >>> kwargs = {"process": [{"type": "fbank",
        ...                        "n_mels": 80,
        ...                        "fs": 16000},
        ...                       {"type": "cmvn",
        ...                        "stats": "data/train/cmvn.ark",
        ...                        "norm_vars": True},
        ...                       {"type": "delta", "window": 2, "order": 2}]}
        >>> transform = Transformation(kwargs)
        >>> bs = 10
        >>> xs = [np.random.randn(100, 80).astype(np.float32)
        ...       for _ in range(bs)]
        >>> xs = transform(xs)
    """

    def __init__(self, conffile=None):
        """Build the function chain described by the configuration

        :param Union[str, dict, None] conffile: a configuration dict (it is
            deep-copied, so the caller's object is not modified), the path
            of a UTF-8 YAML file containing such a dict, or None for an
            empty sequential chain
        :raises NotImplementedError: if the mode is not "sequential"
        :raises TypeError: re-raised when a function cannot be built from
            its options; the expected signature is logged beforehand when
            it can be inspected
        """
        if conffile is None:
            self.conf = {"mode": "sequential", "process": []}
        elif isinstance(conffile, dict):
            self.conf = copy.deepcopy(conffile)
        else:
            with io.open(conffile, encoding="utf-8") as f:
                self.conf = yaml.safe_load(f)
            assert isinstance(self.conf, dict), type(self.conf)

        self.functions = OrderedDict()
        if self.conf.get("mode", "sequential") != "sequential":
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"])
            )

        for idx, process in enumerate(self.conf["process"]):
            assert isinstance(process, dict), type(process)
            # Work on a copy so that popping "type" leaves self.conf intact
            opts = dict(process)
            process_type = opts.pop("type")
            class_obj = dynamic_import(process_type, import_alias)
            # TODO(karita): assert issubclass(class_obj, TransformInterface)
            try:
                self.functions[idx] = class_obj(**opts)
            except TypeError:
                # Log the expected signature before re-raising; the options
                # in the configuration may not match the constructor.
                try:
                    signa = signature(class_obj)
                except ValueError:
                    # signature() fails for some callables, e.g. built-ins
                    pass
                else:
                    logging.error(
                        "Expected signature: {}({})".format(
                            class_obj.__name__, signa
                        )
                    )
                raise

    def __repr__(self):
        """Return the class name followed by one "idx: function" line each"""
        rep = "\n" + "\n".join(
            "    {}: {}".format(k, v) for k, v in self.functions.items()
        )
        return "{}({})".format(self.__class__.__name__, rep)

    def __call__(self, xs, uttid_list=None, **kwargs):
        """Return new mini-batch

        Every function of the chain is applied, in order, to each element
        of the batch.  A function only receives the keyword arguments whose
        names appear in its signature, and only receives the utterance id
        if it has a parameter named "uttid".

        :param Union[Sequence[np.ndarray], np.ndarray] xs: a batch, or a
            single sample (anything that is not a Sequence)
        :param Union[Sequence[str], str] uttid_list: one id per sample;
            a single str is repeated for every sample
        :param kwargs: optional arguments forwarded to the functions
            that declare them
        :return: batch: a list for a batch input, otherwise the single
            transformed sample
        :rtype: List[np.ndarray]
        :raises NotImplementedError: if the mode is not "sequential"
        """
        if not isinstance(xs, Sequence):
            is_batch = False
            xs = [xs]
        else:
            is_batch = True

        if isinstance(uttid_list, str):
            uttid_list = [uttid_list for _ in range(len(xs))]

        if self.conf.get("mode", "sequential") != "sequential":
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"])
            )

        for idx, func in self.functions.items():
            # TODO(karita): use TrainingTrans and UttTrans to check __call__ args
            # Derive only the args which the func has
            try:
                param = signature(func).parameters
            except ValueError:
                # signature() fails for some callables, e.g. built-ins
                param = {}
            _kwargs = {k: v for k, v in kwargs.items() if k in param}
            uttid_param = param.get("uttid") if uttid_list is not None else None
            try:
                if uttid_param is None:
                    xs = [func(x, **_kwargs) for x in xs]
                elif uttid_param.kind == uttid_param.POSITIONAL_ONLY:
                    # A positional-only "uttid" cannot be passed by keyword
                    xs = [func(x, u, **_kwargs) for x, u in zip(xs, uttid_list)]
                else:
                    # "uttid" was found by name, so bind it by name too: it
                    # is not necessarily the second positional parameter.
                    # NOTE: zip() stops at the shorter of xs and uttid_list.
                    xs = [
                        func(x, uttid=u, **_kwargs)
                        for x, u in zip(xs, uttid_list)
                    ]
            except Exception:
                logging.fatal(
                    "Catch a exception from {}th func: {}".format(idx, func)
                )
                raise

        if is_batch:
            return xs
        else:
            return xs[0]
