import numpy as np
import tensorflow as tf

from mayo.util import object_from_params, Percent
from mayo.override import util
from mayo.override.base import OverriderBase, Parameter


class IncrementalQuantizer(OverriderBase):
    """
    Incremental quantization overrider, after
    https://arxiv.org/pdf/1702.03044.pdf

    Wraps another quantizer and keeps a boolean ``mask`` parameter shaped
    like the overridden value.  Elements whose mask entry is True are
    replaced by the wrapped quantizer's output; all other elements pass
    through unchanged.  Each ``_update`` grows the mask via ``_policy``.
    """
    mask = Parameter('mask', None, None, 'bool')

    def __init__(
            self, session, quantizer, interval,
            stop_gradient=False, count_zero=True,
            should_update=True, enable=True):
        """
        session: forwarded to OverriderBase and to the wrapped quantizer.
        quantizer: parameters that ``object_from_params`` resolves into a
            quantizer class plus its constructor keyword arguments.
        interval: fraction of elements used by ``_policy`` to size the
            quantization threshold; a value >= 1.0 quantizes everything.
        stop_gradient: if True, gradients are blocked through the
            quantized branch of the output.
        count_zero: if False, zero-valued elements are excluded when
            ``_policy`` computes the threshold.
        should_update, enable: forwarded to OverriderBase.
        """
        super().__init__(session, should_update, enable)
        cls, params = object_from_params(quantizer)
        self.quantizer = cls(session, **params)
        self.stop_gradient = stop_gradient
        self.count_zero = count_zero
        self.interval = interval

    def _quantize(self, value, mean_quantizer=False):
        """
        Apply the wrapped quantizer to ``value`` under a scope nested in
        this overrider's scope.

        ``mean_quantizer`` is accepted for backward compatibility but is
        currently unused.
        """
        scope = '{}/{}'.format(self._scope, self.__class__.__name__)
        return self.quantizer.apply(
            self.node, scope, self._original_getter, value)

    def _apply(self, value):
        """
        Blend ``value`` with its quantized counterpart according to
        ``self.mask``: entries where the mask is True take the quantized
        value, the rest keep the original value.
        """
        # Config for the ``mask`` Parameter: starts all False (nothing
        # quantized) with the same shape as ``value``.  Presumably consumed
        # by OverriderBase when ``self.mask`` is materialized — confirm.
        self._parameter_config = {
            'mask': {
                'initial': tf.zeros_initializer(tf.bool),
                'shape': value.shape,
            }
        }
        quantized_value = self._quantize(value)
        if self.stop_gradient:
            quantized_value = tf.stop_gradient(quantized_value)
        off_mask = util.cast(util.logical_not(self.mask), float)
        mask = util.cast(self.mask, float)
        # on mask indicates the quantized values
        return value * off_mask + quantized_value * mask

    def _policy(self, value, quantized, previous_mask, interval):
        """
        Compute the next quantization mask from evaluated (numpy) values.

        value: current unquantized values.
        quantized: the wrapped quantizer's output for ``value``.
        previous_mask: current boolean mask; entries already set stay set.
        interval: fraction of elements (of non-zero elements only when
            ``count_zero`` is False) used as ``k`` for ``util.top_k``.

        Entries whose quantization error ``|value - quantized|`` is below
        the threshold returned by ``util.top_k`` are added to the mask.

        Returns the new boolean mask.
        Raises ValueError if the computed ``k`` is negative.
        """
        if interval >= 1.0:
            # ensure all values are quantized
            return np.ones(previous_mask.shape, dtype=bool)
        previous_pruned = util.sum(previous_mask)
        if self.count_zero:
            th_arg = util.cast(util.count(value) * interval, int)
        else:
            num_nonzero = util.count(value[value != 0])
            # indices (into the flattened array) of the non-zero values
            flat_value_arg = util.where(value.flatten() != 0)
            th_arg = util.cast(num_nonzero * interval, int)
        if th_arg < 0:
            raise ValueError(
                'mask has {} elements, interval is {}'.format(
                    previous_pruned, interval))
        off_mask = util.cast(
            util.logical_not(util.cast(previous_mask, bool)), float)
        metric = util.abs(value - quantized)
        # already-quantized entries contribute zero error here, so the
        # threshold is chosen from the not-yet-quantized entries' errors
        flat_value = (metric * off_mask).flatten()
        if self.count_zero:
            th = util.top_k(flat_value, th_arg)
        else:
            th = util.top_k(flat_value[flat_value_arg], th_arg)
        th = util.cast(th, float)
        new_mask = util.logical_not(util.greater_equal(metric, th))
        return util.logical_or(new_mask, previous_mask)

    # override assign_parameters to assign quantizer as well
    def assign_parameters(self):
        super().assign_parameters()
        self.quantizer.assign_parameters()

    def _update(self):
        """
        Update the wrapped quantizer, then recompute ``mask`` with
        ``_policy`` from the current unquantized and quantized values and
        assign it back into the session.
        """
        # update the wrapped quantizer before reading its output below
        self.quantizer.update()
        value, quantized, mask = self.session.run(
            [self.before, self.quantizer.after, self.mask])
        new_mask = self._policy(value, quantized, mask, self.interval)
        self.session.assign(self.mask, new_mask)

    def dump(self):
        """Delegate to the wrapped quantizer's ``dump``."""
        return self.quantizer.dump()

    def _info(self):
        """
        Return the wrapped quantizer's info (without its ``name`` field)
        extended with ``interval``: the fraction of mask entries that are
        currently True, i.e. quantized.
        """
        info = self.quantizer.info()._asdict()
        info.pop('name')
        mask = self.session.run(self.mask)
        interval = mask.sum() / mask.size
        return self._info_tuple(**info, interval=Percent(interval))

    def finalize_info(self, table):
        # plain delegation to OverriderBase
        return super().finalize_info(table)
