"""
Utility classes and methods.
"""

import collections
import json
import os
import pickle
import platform
import re
from types import TracebackType
from typing import (Any, Callable, Generator, List, NoReturn, Optional,
                    OrderedDict, Sequence, Tuple, Type, TypeVar)

import numpy as np
import pandas as pd
import timm
import torch
import torch.nn.functional as F
import torchvision
import webdataset as wds
from lucent.modelzoo import inceptionv1
from lucent.optvis.hooks import ModelHook, ModuleHook
from open_clip import create_model_and_transforms
from PIL import Image, ImageFile
from torch import nn
from torchvision import datasets, transforms

try:
    import stimuli_generation.resnet as resnet
    from stimuli_generation.clip_utils import (imagenet_classnames,
                                               openai_imagenet_template,
                                               zero_shot_classifier)
except:
    import resnet
    from clip_utils import (imagenet_classnames, openai_imagenet_template,
                            zero_shot_classifier)


# Work around PIL errors on truncated image files by letting PIL load them anyway.
# This is only enabled on macOS ("Darwin"), where it was needed on one developer's
# machine, and is only relevant for debugging.
if platform.system() == "Darwin":
    ImageFile.LOAD_TRUNCATED_IMAGES = True

# Checkpoint URLs, keyed by model name, for the models whose weights load_model downloads
# with torch.hub.load_state_dict_from_url. The "-linf"/"-l2" entries are Madry-lab robust
# ResNet-50 checkpoints (per the file names: L-inf eps 4.0 and L2 eps 3).
ckpt_dict = {
    "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
    "resnet50-linf": "https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_linf_eps4.0.ckpt",  # noqa: E501
    "resnet50-l2": "https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt",  # noqa: E501
    "wide_resnet50": "https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
}

# Advertised ImageNet accuracies of the models (presumably top-1 on the validation set),
# stored so we can check whether the models are loaded and evaluated correctly.
accuracies = {
    "resnet50": 0.7613,
    "resnet50-linf": 0.6386,
    "resnet50-l2": 0.6238,
    "googlenet": 0.6915,
    "clip-resnet50": 0.5983,
    "wide_resnet50": 0.8160,  # Using 232 instead of 256 resize before central crop
    "densenet_201": 0.7689,
    "convnext_b": 0.838,
    "clip-vit_b32": 0.666,
    "in1k-vit_b32": 0.74904,
}

# Cache of the preprocessing transforms for each model, so we don't have to construct them
# every time. For CLIP models, getting the transforms requires loading the model, so it's a
# lot faster to just store them. Entries start as None and are filled by the first call to
# get_transforms for that model; later calls reuse them.
model_transforms = {model_name: None for model_name in accuracies}


def get_default_device(gpu_id: Optional[int] = None):
    """Return the torch device to use and print it.

    With an explicit ``gpu_id`` this is always ``cuda:<gpu_id>``; otherwise it is
    ``cuda`` when CUDA is available and ``cpu`` when it is not.
    """
    if gpu_id is None:
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device_name = f"cuda:{gpu_id}"
    device = torch.device(device_name)
    print(f"Using device {device}.")
    return device


_T_co = TypeVar("_T_co")
_T_contra = TypeVar("_T_contra")
_V_co = TypeVar("_V_co")


class KnownLengthGenerator(Generator[_T_co, _T_contra, _V_co]):
    """Wraps a generator and reports a caller-supplied length via ``len()``.

    Iteration, ``send`` and ``throw`` are delegated to the wrapped generator. The
    length is not checked against the number of items the generator actually yields.
    """

    def __init__(self, generator: Generator[_T_co, _T_contra, _V_co], length: int):
        self.generator = generator
        self.length = length

    def __iter__(self) -> Generator[_T_co, _T_contra, _V_co]:
        return self.generator

    def __len__(self) -> int:
        return self.length

    def send(self, __value: _T_contra) -> _T_co:
        return self.generator.send(__value)

    def throw(
        self,
        typ: Type[BaseException],
        val: Optional[BaseException] = None,
        tb: Optional[TracebackType] = None,
    ) -> _T_co:
        # Only forward the optional arguments when they were given: forwarding
        # placeholder defaults makes generator.throw() reject the traceback argument,
        # which also broke the close() inherited from collections.abc.Generator.
        if val is None and tb is None:
            return self.generator.throw(typ)
        return self.generator.throw(typ, val, tb)


def split_unit(layer_unit):
    """
    Splits a fully described unit into layer and index.

    The unit index is the part after the LAST "__", so layer names that themselves
    contain "__" are kept intact.

    :param layer_unit: the full unit, e.g. "layer1_0_conv1__0"
    :returns: tuple (layer name, unit index as int)
    :raises ValueError: if there is no "__" separator or the index is not an integer
    """
    layer, index = layer_unit.rsplit("__", 1)
    return layer, int(index)


def chunks(lst, n):
    """
    Lazily yield consecutive slices of ``lst`` of length ``n``; the last may be shorter.

    :param lst: the sequence to be chunked
    :param n: the size of each chunk
    """
    yield from (lst[start : start + n] for start in range(0, len(lst), n))


def store_units(units, filename):
    """
    Write ``{"units": units}`` as JSON to ``<filename>.json``.

    :param units: list of units
    :param filename: path of the output file, without the ".json" extension
    """
    target_path = filename + ".json"
    with open(target_path, "w", encoding="utf-8") as out_file:
        json.dump({"units": units}, out_file)


def test_layer_relevance(layer, model_name):
    """
    Returns true if a layer-name is relevant, for a given model_name.

    :param layer: the layer
    :param model_name: the model
    """

    if model_name == "googlenet":
        return layer.startswith("mixed") and "pre_relu_conv" in layer

    if model_name == "densenet_201":
        dense = layer.startswith("features_denseblock") and layer.endswith(
            ("conv1", "conv2", "norm1", "norm2")
        )
        transition = layer.startswith("features_transition") and layer.endswith(
            ("conv", "norm")
        )
        return dense or transition

    if model_name == "clip-vit_b32":
        return layer.startswith("transformer_resblocks") and layer.endswith(
            ("ln_1", "ls_1", "ln_2", "mlp_c_fc", "mlp_c_proj", "ls_2")
        )

    if model_name == "in1k-vit_b32":
        return layer.startswith("blocks_") and layer.endswith(
            ("norm1", "ls1", "norm2", "mlp_fc1", "mlp_fc2", "ls2")
        )

    if model_name == "convnext_b":
        regex = "stages_\\d+_blocks_\\d+"
        block = (
            re.fullmatch(regex, layer) is not None
        )  # corresponds to shortcut connection
        other = re.match(regex, layer) is not None and layer.endswith(
            ("conv_dw", "norm", "mlp_fc1", "mlp_fc2")
        )

        return block or other

    return (
        layer.startswith("layer") or layer.startswith("visual_layer")
    ) and layer.endswith(("conv1", "conv2", "conv3", "bn1", "bn2", "bn3", "shortcut"))


def test_permute(model_name, layer_name):
    """
    Tests if the activations at a layer should be permuted from (batch, h, w, c) to (batch, c, h, w).
    This is the case for 1x1-Conv- and Norm-layers in ConvNext,
    because they use LayerNorm and implement 1x1 convolutions as LinearLayers.

    :param model_name: the name of the model
    :param layer_name: the name of the layer
    :returns: True if the layer needs to be permuted
    """

    return model_name == "convnext_b" and (
        "mlp_fc" in layer_name or "norm" in layer_name
    )


def get_relevant_layers(
    model, model_name, strict_mode: bool = True, get_modules: bool = False
):
    """
    Select the layers of interest from a model.

    For "timm://" models, or when ``strict_mode`` is False, a layer is relevant if it
    is a conv / linear / normalization module. Otherwise the name-based rules of
    ``test_layer_relevance`` are used.

    :param model: the torch model
    :param model_name: str, model name
    :param strict_mode: bool, if True, use old (name-based) layer selection criteria
    :param get_modules: bool, if True, return a dict mapping layer name to module
    :returns: list of layer names, or a dict name -> module if get_modules is True
    """
    generic_layer_types = (
        nn.Conv1d,
        nn.Conv2d,
        nn.Linear,
        nn.BatchNorm1d,
        nn.BatchNorm2d,
        nn.LayerNorm,
        nn.GroupNorm,
    )

    def _is_relevant(module: nn.Module, name: str) -> bool:
        if not (model_name.startswith("timm://") or not strict_mode):
            return test_layer_relevance(name, model_name)
        return isinstance(module, generic_layer_types)

    selected = {
        name: module
        for name, module in get_model_layers(model, True).items()
        if _is_relevant(module, name)
    }
    return selected if get_modules else list(selected)


def get_model_layers(model, get_modules=False):
    """
    Custom version of Lucent's modelzoo.util.get_model_layers.

    Walks the module tree depth-first (parent before children) and names every
    submodule by joining its attribute path with underscores, e.g. "layer1_0_conv1".
    Children that are None (e.g. GoogLeNet's aux1 and aux2) are skipped.

    :param model: the torch model; a (Distributed)DataParallel wrapper is unwrapped first
    :param get_modules: if True, return an OrderedDict mapping layer name to module;
        otherwise return a list of layer names
    """
    layers = collections.OrderedDict() if get_modules else []

    def _collect(net, prefix=()):
        # Recursively record every named child of ``net`` under its joined path.
        if not hasattr(net, "_modules"):
            return
        for name, layer in net._modules.items():
            if layer is None:
                # e.g. GoogLeNet's aux1 and aux2 layers
                continue
            path = prefix + (name,)
            full_name = "_".join(path)
            if get_modules:
                layers[full_name] = layer
            else:
                layers.append(full_name)
            _collect(layer, prefix=path)

    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        model = model.module

    _collect(model)
    return layers


def read_units_file(unitfile):
    """
    Load the list of units stored under the "units" key of a JSON file.
    (Overkill to have a function here but maybe this gets more complicated)

    :param unitfile: the json-file with units
    :returns: the stored list of units
    """
    with open(unitfile, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload["units"]


def get_layers_from_units_list(units):
    """
    Collect the distinct layers referenced by a list of units.

    :param units: list[str] of units, e.g. ["layer1_0_conv1__0", ...]
    :returns: sorted list[str] of unique layer names
    """
    unique_layers = {split_unit(unit)[0] for unit in units}
    return sorted(unique_layers)


def transform_and_copy_img(src_path, dest_path):
    """
    Resize (shorter side to 256) and center-crop (224x224) the image at src_path and
    save the result to dest_path.

    Images that are not in RGB mode (e.g. CMYK) are converted to RGB before saving.
    The source file is opened in a context manager so its file handle is released
    even if processing or saving fails.
    """

    resize_and_crop = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ]
    )

    with Image.open(src_path) as src_img:
        img = resize_and_crop(src_img)

        # there were some issues with images being in CMYK mode
        if img.mode != "RGB":
            print(
                f"image {src_path} is not in RGB mode! Converting before "
                f"saving to {dest_path}"
            )
            img = img.convert("RGB")
        img.save(dest_path)


class ImageFolderWithPaths(datasets.ImageFolder):
    """
    ImageFolder dataset whose samples also carry the image file path.

    Each item is the usual ImageFolder tuple extended by the path, and, when
    ``return_indices`` is True, additionally by the dataset index.
    """

    def __init__(self, *args, return_indices: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self.return_indices = return_indices

    # DataLoader calls __getitem__; append the extra fields to the base result.
    def __getitem__(self, index):
        base_item = super().__getitem__(index)
        image_path = self.imgs[index][0]
        extras = (image_path, index) if self.return_indices else (image_path,)
        return base_item + extras


def _googlenet_normalize(x):
    """Scale a [0, 1] image tensor to [0, 255] and subtract 117 (googlenet preprocessing).

    Defined at module level instead of as a lambda so that the transform can be
    pickled, e.g. when DataLoader workers are started with the "spawn" method.
    """
    return x * 255 - 117


def get_transforms(model_name):
    """
    Obtains a Composition of the model-specific transforms.

    Results are cached in ``model_transforms`` keyed by ``model_name``, so the
    (possibly expensive) construction happens once per model.

    :param model_name: the identifier of the model
    :returns: a torch transform
    """

    # check if we have transforms, use them if so, otherwise build them
    cached = model_transforms.get(model_name)
    if cached is not None:
        return cached

    if model_name == "clip-resnet50":
        # openclip can conveniently just give us the transforms
        _, _, transformations = create_model_and_transforms("RN50", pretrained="openai")
    elif model_name == "clip-vit_b32":
        _, _, transformations = create_model_and_transforms(
            "ViT-B-32", pretrained="laion2b_s34b_b79k"
        )
    elif model_name == "in1k-vit_b32":
        model = timm.create_model("vit_base_patch32_224.augreg_in1k", pretrained=True)
        data_config = timm.data.resolve_data_config(args={}, model=model)
        transformations = timm.data.create_transform(**data_config, is_training=False)
    elif model_name.startswith("timm://"):
        model = timm.create_model(model_name[len("timm://") :], pretrained=True)
        data_config = timm.data.resolve_data_config(args={}, model=model)
        transformations = timm.data.create_transform(**data_config, is_training=False)
    else:
        if model_name == "googlenet":
            normalize = _googlenet_normalize
        else:
            normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            )

        # NOTE(review): 229 is an unusual crop size for googlenet (224/299 are common);
        # kept as-is, confirm it is intended.
        img_size = 229 if model_name == "googlenet" else 224

        transformations = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(img_size),
                transforms.ToTensor(),
                normalize,
            ]
        )

    # Cache under the model's actual name (previously the literal "model_name" was
    # used as key, so the cache was never hit).
    model_transforms[model_name] = transformations
    return transformations


def get_dataloader(
    datadir: str,
    model_name: Optional[str] = None,
    batch_size: Optional[int] = None,
    transform=None,
    return_indices: bool = False,
    use_webdataset: bool = False,
):
    """
    Build a dataloader for the ImageNet validation set.

    Either an ImageFolder-based torch DataLoader (items include the file path, and
    optionally the index) or, with ``use_webdataset``, a WebLoader over a pre-batched
    WebDataset. If ``transform`` is None it is derived from ``model_name``.

    :param datadir: path to dataset, e.g. /path/to/imagenet/val
    :raises ValueError: if neither transform nor model_name is given, or the batch
        size cannot be determined / is invalid
    """
    if transform is None:
        if model_name is None:
            raise ValueError("Either model_name or transform must be provided")
        transform = get_transforms(model_name)

    if not use_webdataset:
        folder_dataset = ImageFolderWithPaths(
            datadir, transform, return_indices=return_indices
        )
        if batch_size is None:
            if model_name is None:
                raise ValueError("Either model_name or batch_size must be provided")
            batch_size = 30 if model_name == "clip-resnet50" else 32
        return torch.utils.data.DataLoader(
            folder_dataset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=8,
        )

    if batch_size is None:
        raise ValueError("batch_size must be provided when using webdataset")
    if batch_size < 1:
        raise ValueError("batch_size must be greater than 0")

    def passthrough(value):
        return value

    # NOTE(review): shuffle(True) is interpreted by webdataset as a buffer size of 1,
    # which does not actually reorder samples — confirm this is intended.
    # Batching happens inside the dataset, hence batch_size=None for the loader.
    shard_dataset = (
        wds.WebDataset(datadir)
        .shuffle(True)
        .decode("pil")
        .to_tuple("jpeg.jpg;png.png jpeg.cls __key__")
        .map_tuple(transform, passthrough, passthrough)
        .batched(batch_size, partial=False)
    )
    return wds.WebLoader(
        shard_dataset,
        batch_size=None,
        shuffle=False,
        num_workers=8,
    )


def _checkpoint_map_location(device):
    """Return the ``map_location`` for downloading checkpoints onto ``device``.

    MPS targets (given as a string or a torch.device) are loaded onto the CPU;
    every other value, including None, is passed through unchanged.
    """
    # Compare via the string form: torch.device("mps") != "mps" is always True.
    return "cpu" if str(device).split(":", 1)[0] == "mps" else device


def _download_state_dict(model_name, device):
    """Download the checkpoint listed in ckpt_dict for ``model_name``."""
    return torch.hub.load_state_dict_from_url(
        ckpt_dict[model_name], map_location=_checkpoint_map_location(device)
    )


def load_model(model_name, checkpoint: Optional[str] = None, device=None):
    """Loads a model to device and puts it in eval mode.

    :param model_name: the model identifier (see the branches below, or "timm://<name>")
    :param checkpoint: optional checkpoint path; only used for "timm://" models
        (if given, pretrained weights are not downloaded)
    :param device: the device to move the model to
    :raises KeyError: if the model name is not known
    """

    if model_name == "googlenet":
        model = inceptionv1(pretrained=True)
    elif model_name == "clip-resnet50":
        model, _, _ = create_model_and_transforms("RN50", pretrained="openai")
    elif model_name.startswith("resnet50"):
        # we use our own resnet50 to get pre-relu activations
        model = resnet.resnet50()
        state_dict = _download_state_dict(model_name, device)

        if model_name in ["resnet50-linf", "resnet50-l2"]:
            # Consider only the model and not normalizers or attacker
            prefix = "module.model."
            state_dict = {
                key[len(prefix) :]: value
                for key, value in state_dict["model"].items()
                if key.startswith(prefix)
            }
        model.load_state_dict(state_dict)
    elif model_name == "wide_resnet50":
        model = resnet.wide_resnet50_2()
        model.load_state_dict(_download_state_dict(model_name, device))
    elif model_name == "densenet_201":
        model = torchvision.models.densenet201(
            weights=torchvision.models.DenseNet201_Weights.IMAGENET1K_V1
        )
        # Replace ReLU with non-inplace ReLUs
        replace_relu(model)
    elif model_name == "in1k-vit_b32":
        model = timm.create_model("vit_base_patch32_224.augreg_in1k", pretrained=True)
    elif model_name == "clip-vit_b32":
        model, _, _ = create_model_and_transforms(
            "ViT-B-32", pretrained="laion2b_s34b_b79k"
        )
    elif model_name == "convnext_b":
        model = timm.create_model("hf-hub:timm/convnext_base.fb_in1k", pretrained=True)
        # No need to replace ReLU since this model uses non-inplace GeLU units.
    elif model_name.startswith("timm://"):
        model = timm.create_model(
            model_name[len("timm://") :],
            pretrained=checkpoint is None,
            checkpoint_path=checkpoint if checkpoint is not None else "",
        )
    else:
        raise KeyError(f"Model {model_name} not known!")

    # in any case, push model to device and set to eval mode
    model.to(device).eval()

    return model


def aggregate_activations(activations, layer_name, batch_size):
    """
    Prepares the activations obtained at a layer for storage by reducing them to
    one value per unit, i.e. a tensor of shape batch_size x num_units.

    * 4-D activations are averaged over their two spatial axes. If the tensor looks
      channel-last ((b, h, w, c) with h == w != c) it is first permuted to (b, c, h, w).
    * 3-D activations are averaged over the sequence axis: axis 0 if
      ``activations.shape[0] != batch_size`` (seq x batch x dim), else axis 1
      (batch x seq x dim).
    * Activations of any other rank are returned unchanged.

    :param activations: the activations tensor
    :param layer_name: the name of the layer (currently unused; kept for API compatibility)
    :param batch_size: the number of samples in the batch, used to locate the batch
        axis of 3-D activations
    :returns: tensor of activations of shape batch_size x num_units
    """

    if activations.ndim == 4:
        # Heuristic test for channel-last activations: (b, h, w, c) -> (b, c, h, w).
        # TODO(zimmerrol): Verify correctness of this heuristic. It replaced an explicit
        # test_permute(model_name, layer_name) check (ConvNext 1x1 convolutions are
        # implemented as linear layers).
        if (
            activations.shape[1] == activations.shape[2]
            and activations.shape[2] != activations.shape[3]
        ):
            activations = torch.permute(activations, (0, 3, 1, 2))
        # activations now has shape batch_size x channels x h x w
        return torch.mean(activations, dim=(-1, -2))

    if activations.ndim == 3:
        if activations.shape[0] != batch_size:
            # acts has shape seq_length x batch x embedding_dim
            return torch.mean(activations, dim=0, keepdim=False)
        # acts has shape batch x seq_length x embedding_dim
        return torch.mean(activations, dim=1, keepdim=False)

    return activations


def get_activations(
    model, model_name, layer, channel, images, device: Optional[Any] = None
):
    """
    Gets the activations that a list of images (224x224 feature visualisations) achieves at a unit.

    :param model: the model
    :param model_name: the name of the model (selects the input normalization)
    :param layer: the layer of the target unit
    :param channel: the channel of the target unit
    :param images: the list of images
    :param device: the device to use
    :returns: a list of activations (same order as images)
    """

    # Lucent applies InceptionTransform to googlenet.
    # Apart from that, irrespective of model-specific transforms,
    # FVs were generated using torchvision standard transforms

    transform_list = [transforms.ToTensor()]
    if model_name == "googlenet":
        transform_list.append(lambda x: x * 255 - 117)
    elif model_name in ("clip-resnet50", "in21k_in1k-vit_b32", "clip-vit_b32"):
        # NOTE(review): "in21k_in1k-vit_b32" is not a model name used elsewhere in
        # this file; "in1k-vit_b32" falls through to the ImageNet normalization below.
        transform_list.append(
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            )
        )
    else:
        # For RN50, WRN, densenet and convnext.
        transform_list.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )

    transform = transforms.Compose(transform_list)

    # apply transforms to images and push them to device
    imgs = torch.cat(
        [transform(img).to(device).unsqueeze(0) for img in images], dim=0
    )

    # Feed images through the model and record the layer's activations. The two
    # context managers must be separated by a comma: `a and b` only enters `b`,
    # which previously left gradient tracking enabled.
    with torch.no_grad(), ModelHook(model, layer_names=[layer]) as hook:
        model(imgs)
        acts = hook(layer)

    # Pass the real batch size so 3-D (sequence) activations are reduced over the
    # correct axis.
    acts = aggregate_activations(acts, layer, imgs.shape[0])

    # select channel
    acts = acts[:, channel].cpu()

    # return the activations for the images
    return [acts[i].item() for i in range(len(images))]


def read_pickled_activations_file(filepath, units=None):
    """
    Load a pickled activations file into a pandas DataFrame.

    The pickle holds a dict with an "activations" array (images x units) and "paths".
    The result has one string-named column per selected unit ("0", "1", ...), followed
    by either a single "path" column or, when "paths" is a 2-D numpy array, one
    "path_<unit>" column per selected unit.

    :param filepath: full path to .pkl file
    :param units: list of ints, the units to be loaded into the df (all if falsy)

    :returns: pandas DataFrame
    """
    # NOTE(security): pickle.load can execute arbitrary code; only read trusted files.
    with open(filepath, "rb") as fhandle:
        res_dict = pickle.load(fhandle)

    activations = res_dict["activations"]
    if units:
        activations = activations[:, units]
    else:
        units = list(range(activations.shape[1]))

    frame = pd.DataFrame(activations, columns=[str(u) for u in units])

    paths = res_dict["paths"]
    has_per_unit_paths = isinstance(paths, np.ndarray) and paths.ndim == 2
    if not has_per_unit_paths:
        frame["path"] = paths
        return frame

    path_columns = pd.DataFrame(
        np.array([paths[:, u] for u in units]).T,
        columns=[f"path_{u}" for u in units],
    )
    return pd.concat([frame, path_columns], axis=1)


def get_min_max_exemplar_activations(
    model_name, unit_name, activations_root, use_csv=False
):
    """
    Returns the activations achieved by minimally and maximally activating image.

    :param model_name: str, the name of the model
    :param unit_name: str, the unit for which to find exemplars, e.g. "layer1_0_conv1__0"
    :param activations_root: str, the root directory of the activations
    :param use_csv: bool, whether the old CSV files should be used

    :returns: min, max activation
    :raises AssertionError: if the activations directory or file does not exist
    """

    layer, unit = split_unit(unit_name)

    # construct path to activations files
    activations_dir = os.path.join(activations_root, model_name)
    assert os.path.exists(activations_dir), "Could not find directory with activations!"

    if use_csv:
        # get path to CSV of this layer and get df
        csv_path = os.path.join(activations_dir, layer + ".csv")
        assert os.path.exists(csv_path), f"Could not find path to csv: {csv_path}"
        unit_df = pd.read_csv(csv_path, usecols=[str(unit)])
    else:
        pkl_path = os.path.join(activations_dir, layer + ".pkl")
        assert os.path.exists(
            pkl_path
        ), f"Could not find path to pickle-file: {pkl_path}"
        unit_df = read_pickled_activations_file(pkl_path, [unit])

    # Reduce only the unit's column; the pickled frame also carries path column(s),
    # which need not (and may not be able to) be reduced.
    unit_activations = unit_df[str(unit)]
    return unit_activations.min(), unit_activations.max()


def get_label_translator():
    """
    Returns a function that transforms a batch of labels to the new label values for InceptionV1.

    Reads "old_imagenet_labels.txt" and "imagenet_labels.txt" from the current working
    directory; the first space-separated token of each line is the WordNet id, and the
    line number is the class index.

    old labels:
        https://raw.githubusercontent.com/rgeirhos/lucent/dev/lucent/modelzoo/misc/old_imagenet_labels.txt
    new labels:
        https://raw.githubusercontent.com/conan7882/GoogLeNet-Inception/master/data/imageNetLabel.txt
    """

    with open("old_imagenet_labels.txt", "r", encoding="utf-8") as fhandle:
        old_lines = fhandle.read().strip().split("\n")
    with open("imagenet_labels.txt", "r", encoding="utf-8") as fhandle:
        new_lines = fhandle.read().strip().split("\n")

    # maps a wordnet-id to its class index in the old convention
    old_index_by_wid = {
        line.split(" ")[0].strip(): cid for cid, line in enumerate(old_lines)
    }
    # maps a class index in the new convention to its wordnet-id
    wid_by_new_index = {
        cid: line.split(" ")[0].strip() for cid, line in enumerate(new_lines)
    }

    def remap_torch_to_tf_labels(y):
        """Map PyTorch-style ImageNet labels to old convention used by GoogLeNet/InceptionV1.

        :param y: tensor of class indices in the new convention
        :returns: long tensor on y's device with the old-convention indices + 1
            (presumably offset for a leading background class — confirm)
        :raises ValueError: if a class has no counterpart in the old convention
        """
        res = []
        for yi in y.cpu().numpy():
            wid = wid_by_new_index[yi]
            if wid not in old_index_by_wid:
                raise ValueError(f"Unknown class {yi}/{wid}.")
            res.append(old_index_by_wid[wid])

        # Explicit dtype so an empty batch still yields an integer label tensor.
        return torch.tensor(res, dtype=torch.long).to(y.device) + 1

    return remap_torch_to_tf_labels


def replace_module(
    model: nn.Module,
    check_fn: Callable[[nn.Module], bool],
    get_replacement_fn: Callable[[nn.Module], nn.Module],
):
    """Recursively swap out submodules of ``model`` in place.

    Args:
        model: The model to replace modules in.
        check_fn: Returns True for a module that should be replaced.
        get_replacement_fn: Given the module to replace, returns the new module.

    Note: the recursion descends into the original child, not its replacement.
    """
    # Snapshot the children first, since setattr mutates the module registry.
    for child_name, child in list(model.named_children()):
        if check_fn(child):
            replacement = get_replacement_fn(child)
            setattr(model, child_name, replacement)
        replace_module(child, check_fn, get_replacement_fn)


def replace_relu(model: nn.Module):
    """Swap every nn.ReLU inside model for a freshly constructed, non-inplace nn.ReLU.

    Args:
        model: The model to replace modules in.
    """

    def is_relu(module):
        return isinstance(module, nn.ReLU)

    def fresh_relu(_old_module):
        return nn.ReLU()

    replace_module(model, is_relu, fresh_relu)


def get_clip_zero_shot_classifier(model, model_name, device: Optional[Any] = None):
    """
    Build an ImageNet-1k zero-shot classifier for a CLIP model.

    :param model: the pytorch model
    :param model_name: the model name ("clip-resnet50" or "clip-vit_b32")
    :param device: the device to use
    :returns: whatever zero_shot_classifier returns for the ImageNet class names/templates
    """
    # open_clip architecture / pretrained-weights tag for each supported model name
    known_architectures = {
        "clip-resnet50": ("RN50", "openai"),
        "clip-vit_b32": ("ViT-B-32", "laion2b_s34b_b79k"),
    }

    class ClipArgs:
        """Clip expects some object that holds configuration."""

        def __init__(self, model_name, device: Optional[Any] = None):
            if model_name in known_architectures:
                self.model, self.pretrained = known_architectures[model_name]

            self.distributed = False
            self.horovod = False
            self.precision = "fp32"
            self.batch_size = 32
            self.device = device

    return zero_shot_classifier(
        model,
        imagenet_classnames,
        openai_imagenet_template,
        ClipArgs(model_name, device),
    )


def get_clip_logits(model, classifier, images):
    """
    Compute class logits for a CLIP model: 100 * (L2-normalized image features) @ classifier.

    :param model: the pytorch model (must provide encode_image)
    :param classifier: the zero-shot classifier weight matrix
    :param images: the batch of images
    """
    normalized_features = F.normalize(model.encode_image(images), dim=-1)
    return 100.0 * normalized_features @ classifier


def read_activations_file(activations_dir, layer, force_csv: bool = False):
    """
    Reads the file (.csv or .pkl) of activations for a layer and returns it as a pandas DataFrame.

    The pickled file "<layer>.pkl" is preferred; "<layer>.csv" is read if ``force_csv``
    is set or no pickled file exists.

    :param activations_dir: path to directory where activations files lie
    :param layer: the layer we are recording
    :param force_csv: whether to force reading the csv file instead of the pickled file

    :returns: a pandas DataFrame that for all units of a layer and all images of the dataset stores their activation
    :raises AssertionError: if the required file does not exist (callers rely on this
        exception type to skip layers without activations)
    """
    pkl_path = os.path.join(activations_dir, layer + ".pkl")
    if not force_csv and os.path.exists(pkl_path):
        return read_pickled_activations_file(pkl_path)

    csv_path = os.path.join(activations_dir, layer + ".csv")
    assert os.path.exists(csv_path), f"Could not find path to csv: {csv_path}"
    # Loads every column: all units plus the path column(s).
    return pd.read_csv(csv_path)


def extract_stimuli_for_layer_units_from_dataframe(
    dataframe: pd.DataFrame,
    num_batches: int,
    start_idx_min: int,
    stop_idx_min: int,
    start_idx_max: int,
    stop_idx_max: int,
) -> Generator[tuple[str, list[str]], Any, None]:
    """Extracts the stimuli for every unit (column) of a layer's activation dataframe.

    Every column whose name does not start with "path" is treated as a unit. Each unit
    uses the shared "path" column if it exists, otherwise its own "path_<unit>" column.

    :param dataframe: activations of one layer: one column per unit plus path column(s)
    :param num_batches: the number of batches to extract
    :param start_idx_min: the index from where to begin sampling the min query images
    :param stop_idx_min: the index where to stop sampling the min query images
    :param start_idx_max: the index from where to begin sampling the max query images
    :param stop_idx_max: the index where to stop sampling the max query images
    :raises ValueError: if a query range overlaps the reference range, i.e. the first /
        last ``num_batches * 9`` positions (in the index space used by
        extract_stimuli_range)
    :returns: a KnownLengthGenerator (length = number of units) yielding
        ``(unit, batches)``; each batch is a max-exemplar list followed by a
        min-exemplar list, as produced by make_fair_batches
    """
    num_images_total = len(dataframe)  # how many images there are in total

    # make sure that there is no overlap between query and reference images
    # Min side: the references are the first num_batches * 9 positions, so the min
    # queries must start at or after them (a negative start counts from the end).
    if (start_idx_min >= 0 and num_batches * 9 > start_idx_min) or (
        start_idx_min < 0 and num_batches * 9 > num_images_total + start_idx_min
    ):
        raise ValueError(
            "Illegal combination of arguments! Queries and " "References would overlap!"
        )

    # Max side: the references are the last num_batches * 9 positions, so the max
    # queries must stop at or before num_images_total - num_batches * 9.
    # NOTE(review): the second clause tests `start_idx_max < 0` but compares against
    # `stop_idx_max`; for a negative stop_idx_max it can never be true, so overlaps
    # given with negative indices go undetected. Probably intended:
    # `stop_idx_max < 0 and num_batches * 9 > -stop_idx_max` — confirm against the
    # index semantics of extract_stimuli_range before changing.
    if (stop_idx_max >= 0 and num_images_total - num_batches * 9 < stop_idx_max) or (
        start_idx_max < 0 and num_batches * 9 <= stop_idx_max
    ):
        raise ValueError(
            "Illegal combination of arguments! Queries and " "References would overlap!"
        )

    units = [c for c in dataframe.columns if not c.startswith("path")]

    # Without a shared "path" column, each unit has its own "path_<unit>" column.
    use_per_unit_paths = "path" not in dataframe.columns

    def inner() -> Generator[tuple[str, list[str]], Any, None]:
        for unit in units:
            # Two-column view (activation, path) for this unit, with the path column
            # always named "path".
            if use_per_unit_paths:
                dataframe_view = dataframe[[unit, f"path_{unit}"]].copy()
                dataframe_view.rename(columns={f"path_{unit}": "path"}, inplace=True)
            else:
                dataframe_view = dataframe[[unit, "path"]].copy()

            # extract query images from given range
            # extract num_batches * 9 reference images from top and bottom
            # (e.g. 99 for 11 batches with 9 refs each)
            (min_queries, max_queries), (min_refs, max_refs) = extract_stimuli_range(
                dataframe_view,
                unit,
                [
                    (start_idx_min, stop_idx_min, start_idx_max, stop_idx_max),
                    (
                        0,
                        num_batches * 9,
                        num_images_total - (num_batches * 9),
                        num_images_total,
                    ),
                ],
            )

            # Combine the lists - both lists go from least to most, so min list starts with
            # queries (the first / last ten images land in the batch from which queries are
            # sourced, so this is fine).
            min_exemplars = min_queries + min_refs
            max_exemplars = max_refs + max_queries

            min_lists = make_fair_batches(min_exemplars, num_batches, reverse=True)
            max_lists = make_fair_batches(max_exemplars, num_batches)

            # for each unit, yield its batches (one per output folder downstream,
            # presumably); each batch = that batch's max exemplars followed by its min ones
            yield unit, [maxs + mins for mins, maxs in zip(min_lists, max_lists)]

    return KnownLengthGenerator(inner(), len(units))


def extract_stimuli_for_layer_units(
    model_name,
    layer,
    num_batches,
    start_idx_min,
    stop_idx_min,
    start_idx_max,
    stop_idx_max,
    force_csv,
    activations_root: str,
) -> Generator[tuple[str, list[str]], None, None]:
    """
    Loads the activations of one layer and extracts the query / reference
    stimuli for every unit in it.

    Activations are read from ``<activations_root>/<model_name>`` via
    ``read_activations_file``. If that call raises ``AssertionError`` (no
    activations file was found for this layer), a message is printed and an
    empty list is returned instead of per-unit results.

    :param model_name: the model name (subdirectory of ``activations_root``)
    :param layer: the layer name
    :param num_batches: the number of batches to extract
    :param start_idx_min: the index from where to begin sampling the min query images
    :param stop_idx_min: the index where to stop sampling the min query images
    :param start_idx_max: the index from where to begin sampling the max query images
    :param stop_idx_max: the index where to stop sampling the max query images
    :param force_csv: whether to force reading the csv file instead of the pickled file
    :param activations_root: the root directory where the activations lie
    :returns: the result of ``extract_stimuli_for_layer_units_from_dataframe``
        for this layer's dataframe, or ``[]`` if the layer was skipped
    :raises AssertionError: if the model's activations directory does not exist
    """
    model_dir = os.path.join(activations_root, model_name)
    assert os.path.exists(
        model_dir
    ), f"Could not find directory with activations: {model_dir}"

    try:
        layer_df = read_activations_file(model_dir, layer, force_csv)
    except AssertionError:
        # A missing activations file is not fatal: just skip this layer.
        print(f"Skipping layer {layer} because no activations file was found.")
        return []

    # One row per image.
    print(f"Found {len(layer_df)} images in total for layer {layer}.")

    return extract_stimuli_for_layer_units_from_dataframe(
        layer_df,
        num_batches,
        start_idx_min,
        stop_idx_min,
        start_idx_max,
        stop_idx_max,
    )


def extract_stimuli_range(df, unit, ranges) -> List[Tuple[List[str], List[str]]]:
    """
    Extracts stimuli for the given unit and the given range, returns them as sorted
    list from least to most activating.

    :param df: the pandas dataframe of activation values
    :param unit: the index of the unit
    :param ranges: list of tuples (start_min, stop_min, start_max, stop_max)
    """

    # select only this unit and sort in ascending order
    unit_df = df[["path", str(unit)]]
    unit_df = unit_df.sort_values(str(unit), ascending=True)

    results = []
    for start_min, stop_min, start_max, stop_max in ranges:
        test_values = np.arange(unit_df.shape[0])
        min_test_values = test_values[start_min:stop_min]
        max_test_values = test_values[start_max:stop_max]
        # Check if min_test_values are already sorted
        if not np.all(np.diff(min_test_values) > 0):
            raise ValueError(
                "Indices are not reasonable. Min indices are in " "the wrong order."
            )
        # Check if max_test_values are already sorted
        if not np.all(np.diff(max_test_values) > 0):
            raise ValueError(
                "Indices are not reasonable. Max indices are in " "the wrong order."
            )
        # Ensure there is no overlap between min and max test values
        if np.logical_and(
            np.any(np.isin(min_test_values, max_test_values)),
            np.any(np.isin(max_test_values, min_test_values)),
        ) or np.max(min_test_values) > np.min(max_test_values):
            raise ValueError("Indices are not reasonable. Min and Max indices overlap.")

        assert (
            len(unit_df) >= stop_max
        ), f"Not enough activations for unit {unit} and index {stop_max}!"

        # Select the first few exemplars = minima, then make fair lists but reverse them
        # first so that min_9 is the strongest negatively activating image
        min_exemplars = unit_df.iloc[start_min:stop_min]["path"].tolist()
        max_exemplars = unit_df.iloc[start_max:stop_max]["path"].tolist()
        results.append((min_exemplars, max_exemplars))

    return results


def make_fair_batches(paths, num_lists: int, reverse=False):
    """
    Makes batches of natural stimuli from a list of paths, sorted ascending by
    activation (i.e. so that the most activating image is the last in paths).

    The (optionally reversed) paths are split into ``len(paths) // num_lists``
    consecutive bins of ``num_lists`` paths each. Every bin is shuffled, and
    list ``i`` receives the ``i``-th path of every bin. Each returned list
    therefore holds exactly one path per bin, ordered from the first bin to the
    last, so its final element always comes from the last bin. Trailing paths
    that do not fill a complete bin are dropped.

    The caller's ``paths`` is never modified.

    :param paths: sequence of paths, sorted ascending by activation
    :param num_lists: how many fair lists to generate
    :param reverse: whether to reverse paths first. This is done for minima, so
        that the last bin (and thus the last element of every list) holds the
        least activating images.
    :returns: num_lists lists with ``len(paths) // num_lists`` paths each
    """
    # Work on a private copy. Reversing must not leak back into the caller's
    # list, and for array-like input this also keeps the bin slices below from
    # being views that np.random.shuffle would permute in place.
    ordered = list(reversed(paths)) if reverse else list(paths)

    elems_per_list = len(ordered) // num_lists  # usually 10, i.e. 9 ref + 1 query

    # Create elems_per_list bins
    # (we need num_lists elements per bin, because each list will get one element
    # from each bin).
    bins = [
        ordered[i * num_lists : (i + 1) * num_lists] for i in range(elems_per_list)
    ]

    # shuffle every bin so the assignment of bin members to lists is random
    for data_bin in bins:
        np.random.shuffle(data_bin)

    # construct fair lists by taking the i-th value from every bin
    # note that the last image, which will be {min/max}_9, is taken from the last bin
    return [[data_bin[i] for data_bin in bins] for i in range(num_lists)]


class ModuleHookWithAggregation(ModuleHook):
    """
    Forward hook that records an aggregated version of a module's output.

    ``ModuleHook.__init__`` is deliberately not called; this class registers its
    own forward hook instead. Every tensor the module emits is passed through
    ``aggregation_fn`` and stored in ``self._features``:

    * a single tensor output is stored under ``str(tensor.device)``;
    * for a tuple/list output, each tensor element is stored under
      ``f"{position}_{device}"``; non-tensor elements are ignored.

    :param module: the module to hook
    :param aggregation_fn: function applied to every recorded output tensor
    """

    def __init__(
        self, module: nn.Module, aggregation_fn: Callable[[torch.Tensor], torch.Tensor]
    ):
        self._features: OrderedDict[str, torch.Tensor] = collections.OrderedDict()

        def store(tensor: torch.Tensor, position: Optional[int] = None):
            # The key uses the device of the raw module output, read before
            # aggregation in case the aggregation moves the tensor.
            device_key = str(tensor.device)
            aggregated = aggregation_fn(tensor)
            if position is not None:
                device_key = f"{position}_{device_key}"
            self._features[device_key] = aggregated

        def on_forward(m: nn.Module, args: Any, output: torch.Tensor):
            if torch.is_tensor(output):
                store(output)
                return
            if not isinstance(output, (tuple, list)):
                # Unsupported output type: nothing is recorded.
                return
            for position, element in enumerate(output):
                if torch.is_tensor(element):
                    store(element, position)

        self.hook = module.register_forward_hook(on_forward)

    def clear_features(self):
        """Drop all recorded (aggregated) outputs."""
        self._features.clear()


class ModelHookWithAggregation(ModelHook):
    """
    Variant of lucent's ``ModelHook`` that attaches a
    ``ModuleHookWithAggregation`` to every selected layer. The recorded
    activations are therefore the output of a per-layer aggregation function
    instead of the raw layer output.

    ``__enter__`` registers the hooks and returns a ``hook(layer_name)``
    accessor for the recorded features.
    NOTE(review): hook removal on exit is presumably inherited from
    ``ModelHook.__exit__`` (not shown here). Confirm against the installed
    lucent version.
    """

    def __init__(
        self,
        model: nn.Module,
        get_aggregation_fn: Callable[[str], Callable[[torch.Tensor], torch.Tensor]],
        image_f: Optional[Callable[[], torch.Tensor]] = None,
        layer_names: Optional[Sequence[str]] = None,
    ):
        """
        :param model: the model whose layers are hooked (a ``DataParallel``
            wrapper is unwrapped in ``__enter__``)
        :param get_aggregation_fn: maps a layer's effective name (the path of
            child-module names joined by "_") to the aggregation function
            applied to that layer's outputs
        :param image_f: callable returning the input image; it is what
            ``hook("input")`` returns
        :param layer_names: names of layers to hook. ``None``, or a sequence
            containing "all", hooks every layer. Otherwise only the listed
            layers are hooked, except that the last child of every module is
            always hooked.
        """
        # Presumably stores model / image_f / layer_names and creates the
        # ``self.features`` dict. Confirm against lucent's ModelHook.
        super().__init__(model, image_f, layer_names)
        self.get_aggregation_fn = get_aggregation_fn

    def clear_features(self):
        """Clear the recorded features of every registered layer hook."""
        for k in self.features:
            self.features[k].clear_features()

    def __enter__(self):
        """
        Register aggregation hooks on the selected layers of the model.

        :returns: ``hook(layer)`` accessor. ``"input"`` returns ``image_f()``,
            ``"labels"`` returns the features of the most recently registered
            layer hook, and any other name returns that layer's features.
        """
        # "all" inside layer_names disables the per-layer filter entirely.
        hook_all_layers = self.layer_names is not None and "all" in self.layer_names

        # recursive hooking function
        # (``prefix`` is never mutated: ``prefix + [name]`` builds a new list,
        # so the mutable default is harmless here.)
        def hook_layers(net, prefix: List[str] = []):
            if hasattr(net, "_modules"):
                layers = list(net._modules.items())
                for i, (name, layer) in enumerate(layers):
                    # Fully-qualified name, e.g. "layer1_0_conv1".
                    effective_name = "_".join(prefix + [name])
                    if layer is None:
                        # e.g. GoogLeNet's aux1 and aux2 layers
                        continue

                    # The last child of each module (i == len(layers) - 1) skips
                    # this filter and is always hooked.
                    if self.layer_names is not None and i < len(layers) - 1:
                        # only save activations for chosen layers
                        if (
                            effective_name not in self.layer_names
                            and not hook_all_layers
                        ):
                            # Don't save activations for this layer but check if it
                            # has any layers we want to save.
                            hook_layers(layer, prefix=prefix + [name])
                            continue

                    self.features[effective_name] = ModuleHookWithAggregation(
                        layer, self.get_aggregation_fn(effective_name)
                    )
                    # Hooked layers are descended into as well, so nested layers
                    # can be selected independently of their parents.
                    hook_layers(layer, prefix=prefix + [name])

        # Hook the wrapped module of a DataParallel so that layer names do not
        # start with a "module" prefix.
        if isinstance(self.model, torch.nn.DataParallel):
            hook_layers(self.model.module)
        else:
            hook_layers(self.model)

        def hook(layer):
            if layer == "input":
                out = self.image_f()
            elif layer == "labels":
                # Relies on insertion order: the last hook registered by the
                # depth-first traversal above.
                out = list(self.features.values())[-1].features
            else:
                assert layer in self.features, (
                    f"Invalid layer {layer}. Retrieve the list of layers with "
                    "`lucent.modelzoo.util.get_model_layers(model)`."
                )
                # NOTE(review): ``.features`` is not defined on
                # ModuleHookWithAggregation itself. Presumably it is a property
                # inherited from lucent's ModuleHook that reads ``_features``;
                # confirm.
                out = self.features[layer].features
            if out is None:
                raise RuntimeError(
                    "No activations were recorded for this layer. "
                    "Make sure to put the model in eval mode, like so: "
                    "`model.to(device).eval()`. See README for example."
                )
            return out

        return hook
