# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmrazor.registry import HOOKS


# Stop gradient computation and updates of the mask generation network by
# setting the distiller's ``mask_learning_stopped`` flag.
@HOOKS.register_module()
class StopMaskLearningEpochHook(Hook):
    """Stop mask learning once training reaches a given epoch.

    Before every training epoch whose index is ``>= stop_epoch``, this hook
    sets ``mask_learning_stopped = True`` on the model's ``distiller``. If
    the model is a model wrapper, it is unwrapped first. Using ``>=``
    rather than ``==`` keeps the flag set when training is resumed from a
    checkpoint saved after ``stop_epoch``. In that case the epoch
    ``stop_epoch`` itself is never visited again.

    Args:
        stop_epoch (int): Index of the first training epoch (as reported by
            ``runner.epoch``) in which mask learning is stopped.
    """

    priority = 'HIGH'

    def __init__(self, stop_epoch: int) -> None:
        self.stop_epoch = stop_epoch

    def before_train_epoch(self, runner) -> None:
        """Stop mask learning if ``runner.epoch`` has reached ``stop_epoch``.

        Args:
            runner: The runner driving training. Its ``model`` (possibly a
                model wrapper) must expose a ``distiller`` that has a
                ``mask_learning_stopped`` attribute.

        Raises:
            AttributeError: If the distiller has no ``mask_learning_stopped``
                attribute, i.e. the hook is used with an unsupported model.
        """
        if runner.epoch < self.stop_epoch:
            return

        model = runner.model
        # TODO: refactor after mmengine using model wrapper
        if is_model_wrapper(model):
            model = model.module

        distiller = model.distiller
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        # Without it, a misconfigured distiller would silently gain a new
        # attribute and mask learning would never actually stop.
        if not hasattr(distiller, 'mask_learning_stopped'):
            raise AttributeError(
                f'{type(distiller).__name__} has no attribute '
                "'mask_learning_stopped'; StopMaskLearningEpochHook "
                'requires a distiller that supports stopping mask learning.')

        # Only act on the transition, so the message is logged once rather
        # than at every epoch after ``stop_epoch``.
        if not distiller.mask_learning_stopped:
            runner.logger.info('Mask learning has been stopped!')
            distiller.mask_learning_stopped = True
