from typing import Any

import numpy as np
import torch.nn.functional as F
import torch


def soft_masker(heat_maps, **kwargs):
    """Turn heat maps into soft masks by inverting them.

    Every value ``v`` is mapped to ``1 - v`` (computed against an array of
    ones with the input's dtype), so large heat-map values become small mask
    weights. A trailing singleton axis is then appended.

    Args:
        heat_maps: array of heat-map values.
        **kwargs: accepted so all maskers share one call signature; ignored.

    Returns:
        A one-element list containing the inverted maps with an extra last axis.
    """
    inverted = np.subtract(np.ones_like(heat_maps), heat_maps)
    return [np.expand_dims(inverted, axis=-1)]


def mask_heatmap_using_threshold(heat_maps, **kwargs):
    """Build binary keep-masks by thresholding heat maps.

    Pixels strictly above the threshold get 0 (masked) and all others get 1.

    Threshold selection:
      * If ``mask_threshold_value`` is given and not ``None``, it is used as a
        single global threshold and one mask is returned.
      * Otherwise, for each coefficient ``c`` in ``std_coefficients``, the
        per-sample threshold is ``mean + c * std``. Mean and std are computed
        over the strictly positive pixels of that sample only. One mask is
        returned per coefficient.

    A sample with no positive pixels yields a NaN threshold (numpy emits a
    RuntimeWarning). Every comparison against NaN is False, so that sample
    stays fully unmasked.

    Args:
        heat_maps: array whose axes 1 and 2 are reduced over; presumably
            (batch, H, W) — TODO confirm against callers.
        **kwargs: ``mask_threshold_value`` (optional, default ``None``) and
            ``std_coefficients`` (iterable; required when no global
            threshold is given).

    Returns:
        List of integer masks, each with a trailing singleton axis appended.
    """
    global_threshold = kwargs.get("mask_threshold_value")
    if global_threshold is not None:
        thresholds = [np.ones((heat_maps.shape[0], 1, 1)) * global_threshold]
    else:
        # Ignore non-positive pixels in the statistics by turning them into NaN
        # (computed once and shared by mean and std).
        positive_only = np.where(heat_maps > 0, heat_maps, np.nan)
        mean = np.nanmean(positive_only, axis=(1, 2))[:, None, None]
        std = np.nanstd(positive_only, axis=(1, 2))[:, None, None]
        thresholds = [mean + coefficient * std
                      for coefficient in kwargs["std_coefficients"]]

    return [np.expand_dims(np.where(heat_maps > threshold, 0, 1), axis=-1)
            for threshold in thresholds]


def mask_heatmap_using_sort(heat_maps, **kwargs):
    """Mask the top-``remove_k`` fraction of pixels of each heat map.

    For every sample, the ``int(H * W * remove_k)`` pixels with the largest
    values are set to 0 and all others to 1. The count is clipped to
    ``[0, H * W]``, so ``remove_k <= 0`` masks nothing and ``remove_k >= 1``
    masks everything. Ties at the cut-off are broken arbitrarily by
    ``np.argpartition``. The input array is not modified.

    Args:
        heat_maps: 3-D array (batch, dim1, dim2).
        **kwargs: must contain ``remove_k``, the fraction of pixels to mask.

    Returns:
        A one-element list containing the mask (same dtype as ``heat_maps``)
        with a trailing singleton axis appended.
    """
    remove_k = kwargs["remove_k"]
    batch_size, dim1, dim2 = heat_maps.shape
    num_pixels = dim1 * dim2
    number_to_remove = min(max(int(num_pixels * remove_k), 0), num_pixels)

    flat = heat_maps.reshape(batch_size, -1)
    # Build the mask in a fresh array: `reshape` may return a view of the
    # caller's data, which must not be overwritten.
    mask = np.ones_like(flat)
    # Guard the zero case: slicing with [-0:] would select every index.
    if number_to_remove > 0:
        top_k_indices = np.argpartition(
            flat, -number_to_remove, axis=1)[:, -number_to_remove:]
        np.put_along_axis(mask, top_k_indices, 0, axis=1)

    return [np.expand_dims(mask.reshape(batch_size, dim1, dim2), axis=-1)]


def mask_the_max_pixel(heat_maps, **kwargs):
    """Mask every pixel equal to its sample's maximum value.

    The maximum is taken over axes 1 and 2 of each sample. Pixels equal to it
    get 0 and all others get 1, so ties at the maximum are all masked.

    Args:
        heat_maps: array indexed as (batch, ...) with axes 1 and 2 reduced.
        **kwargs: accepted so all maskers share one call signature; ignored.

    Returns:
        A one-element list containing the mask (dtype of ``heat_maps``) with a
        trailing singleton axis appended.
    """
    batch = heat_maps.shape[0]
    peak_per_sample = np.max(heat_maps, axis=(1, 2)).reshape(batch, 1, 1)
    keep = np.ones_like(heat_maps)
    keep[heat_maps == peak_per_sample] = 0
    return [keep[..., np.newaxis]]


class SparsityHeatmap:
    """Heat-map generator based on the sparsest channel of a layer.

    On each call, the activations of ``target_layer`` are captured with a
    temporary forward hook. For every sample, the channel with the highest
    sparsity (1 - fraction of strictly positive activations) is resized to
    the input resolution.
    """

    def __init__(self, model, target_layer) -> None:
        # model: callable applied to the images (presumably a torch.nn.Module).
        # target_layer: submodule of `model` whose output is inspected; it
        # must provide `register_forward_hook`.
        self.model = model
        self.target_layer = target_layer

    def __call__(self, images, *args: Any, **kwds: Any) -> Any:
        """Return one saliency map per image.

        Args:
            images: 4-D tensor (batch, channels, height, width) fed to the model.

        Returns:
            numpy array of shape (batch, height, width). The batch axis is kept
            even when batch == 1, so downstream maskers can rely on 3 axes.
        """
        activations = {}

        def forward_hook(module, input, output):
            activations["value"] = output.cpu().detach()

        handle = self.target_layer.register_forward_hook(forward_hook)
        try:
            # The model output is discarded; only the hooked activations
            # matter, so skip building the autograd graph.
            with torch.no_grad():
                self.model(images)
        finally:
            # Remove the hook so repeated calls do not stack hooks on the layer.
            handle.remove()

        _, _, h, w = images.shape
        # assumes the layer output is (batch, channels, H', W') — TODO confirm
        feature_maps = activations["value"]
        positive_fraction = torch.sum(
            torch.where(feature_maps > 0, 1, 0), dim=(2, 3)
        ) / (feature_maps.shape[2] * feature_maps.shape[3])
        sparsity = 1 - positive_fraction
        # Ascending sort: the last index per row is the sparsest channel.
        _, channel_order = torch.sort(sparsity, dim=1)
        sparsest_channel = channel_order[:, -1]

        top_activations = feature_maps[
            torch.arange(feature_maps.shape[0]), sparsest_channel]

        saliency_map = F.interpolate(
            top_activations.unsqueeze(1), size=(h, w), mode="bilinear", align_corners=False
        )
        # Only drop the channel axis added above; never squeeze the batch axis.
        return saliency_map.squeeze(1).detach().cpu().numpy()


def mask_with_bernoulli_distribution(heat_maps, rng=None, **kwargs):
    """Sample a random binary mask with per-pixel masking probability.

    Each pixel is masked (0) with probability equal to its heat-map value and
    kept (1) otherwise. Values must lie in [0, 1], or the binomial sampler
    raises ValueError.

    NOTE(review): unlike the other maskers, no trailing singleton axis is
    appended here. Confirm whether callers expect that.

    Args:
        heat_maps: array of probabilities.
        rng: optional ``numpy.random.Generator`` / ``RandomState`` for
            reproducible sampling; defaults to the global ``np.random`` state.
        **kwargs: accepted so all maskers share one call signature; ignored.

    Returns:
        A one-element list containing the integer mask.
    """
    sampler = np.random if rng is None else rng
    masked_heatmaps = 1 - sampler.binomial(1, heat_maps)
    return [masked_heatmaps]


def get_heat_map(images, heat_map_generator):
    """Apply ``heat_map_generator`` to ``images`` and return its result unchanged."""
    heat_maps = heat_map_generator(images)
    return heat_maps


def soft_mask_heatmap_using_threshold(heat_maps, **kwargs):
    """Build soft masks by thresholding heat maps with statistical cut-offs.

    For each coefficient ``c`` in ``std_coefficients``, the per-sample
    threshold is ``mean + c * std``. Mean and std are computed over the
    strictly positive pixels of that sample only. Pixels strictly above the
    threshold get ``masked_value`` (default 0.05) and all others get
    ``kept_value`` (default 0.95).

    A sample with no positive pixels yields a NaN threshold (numpy emits a
    RuntimeWarning), so every pixel of that sample gets ``kept_value``.

    Args:
        heat_maps: array whose axes 1 and 2 are reduced over; presumably
            (batch, H, W) — TODO confirm against callers.
        **kwargs: ``std_coefficients`` (required iterable), and optional
            ``masked_value`` / ``kept_value`` overriding 0.05 / 0.95.

    Returns:
        List with one mask per coefficient, each with a trailing singleton axis.
    """
    masked_value = kwargs.get("masked_value", 0.05)
    kept_value = kwargs.get("kept_value", 0.95)

    # Ignore non-positive pixels in the statistics by turning them into NaN
    # (computed once and shared by mean and std).
    positive_only = np.where(heat_maps > 0, heat_maps, np.nan)
    mean = np.nanmean(positive_only, axis=(1, 2))[:, None, None]
    std = np.nanstd(positive_only, axis=(1, 2))[:, None, None]
    thresholds = [mean + coefficient * std
                  for coefficient in kwargs["std_coefficients"]]

    return [np.expand_dims(
                np.where(heat_maps > threshold, masked_value, kept_value), axis=-1)
            for threshold in thresholds]


def get_mask(
    images,
    heat_map_generator,
    masking,
    remove_k,
    global_threshold,
    std_coefficients,
):
    """Generate heat maps for ``images`` and convert them into masks.

    The heat maps produced by ``heat_map_generator(images)`` are raised
    elementwise to a strategy-specific power, then passed to the selected
    masker.

    Args:
        images: input passed unchanged to ``heat_map_generator``.
        heat_map_generator: callable returning heat maps for ``images``.
        masking: one of ``"threshold_mask"``, ``"soft_mask"``,
            ``"sort_mask"``, ``"max_pixel"``, ``"bernoulli"``,
            ``"soft_mask_using_threshold"``.
        remove_k: fraction of pixels to remove (``"sort_mask"`` only).
        global_threshold: fixed threshold, or ``None`` to use statistics
            (``"threshold_mask"`` only).
        std_coefficients: std multipliers for statistical thresholds
            (``"threshold_mask"`` and ``"soft_mask_using_threshold"``).

    Returns:
        The list of masked heat maps returned by the selected masker.

    Raises:
        ValueError: if ``masking`` is not a recognised strategy. This is
            checked before ``heat_map_generator`` is called.
    """
    mask_config = {}
    if masking == "threshold_mask":
        masker = mask_heatmap_using_threshold
        mask_config = {
            "mask_threshold_value": global_threshold,
            "std_coefficients": std_coefficients
        }
        power = 1
    elif masking == "soft_mask":
        masker = soft_masker
        power = 2
    elif masking == "sort_mask":
        masker = mask_heatmap_using_sort
        mask_config = {"remove_k": remove_k}
        power = 2
        print(f"Removing top {remove_k} pixels")
    elif masking == "max_pixel":
        masker = mask_the_max_pixel
        power = 1
    elif masking == "bernoulli":
        masker = mask_with_bernoulli_distribution
        power = 0.5
    elif masking == "soft_mask_using_threshold":
        masker = soft_mask_heatmap_using_threshold
        mask_config = {
            "std_coefficients": std_coefficients
        }
        power = 1
    else:
        # Previously fell through and crashed with UnboundLocalError on `power`.
        raise ValueError(f"Unknown masking strategy: {masking!r}")

    heat_maps = np.power(heat_map_generator(images), power)
    return masker(heat_maps=heat_maps, **mask_config)
