import argparse
import json
import os
from typing import Callable, Tuple

import numpy as np
import pandas as pd
import torchvision
from PIL import Image
from tqdm import tqdm

# Enable the `progress_apply` method on pandas objects; `main` relies on it
# to show a progress bar while iterating over rows.
tqdm.pandas()

import torch
from torchvision import transforms


def safe_mean(df):
    """Collapse a (grouped) frame into one value per column.

    Columns of ``object`` dtype cannot be averaged: for those, the shared
    value is kept when every row holds the same one, and ``NaN`` is used
    otherwise. Every other column is reduced with ``Series.mean()``.

    Args:
        df: The frame to summarise (e.g. one group passed in by
            ``DataFrame.groupby(...).apply``).

    Returns:
        A ``pd.Series`` indexed by the column names of ``df``.
    """
    summary = {}
    for col in df.columns:
        column = df[col]
        if column.dtype == object:
            try:
                unique_values = column.unique()
            except Exception:
                # Best effort: values that cannot be de-duplicated (e.g.
                # unhashable cells such as lists) count as "no common value".
                # Catching ``Exception`` instead of using a bare ``except``
                # keeps KeyboardInterrupt / SystemExit propagating.
                value = np.nan
            else:
                value = unique_values[0] if len(unique_values) == 1 else np.nan
        else:
            value = column.mean()
        summary[col] = value

    return pd.Series(summary)


def get_df(fn):
    """Load a pickled response table and aggregate it per image set.

    Demo and catch trials are removed, the table is restricted to the columns
    used downstream, and all remaining rows that share the same
    (model, channel, batch, layer, mode) are collapsed with ``safe_mean``.

    Args:
        fn: Path of the pickled ``pandas.DataFrame``.

    Returns:
        A DataFrame with one row per (model, channel, batch, layer, mode)
        combination; the grouping keys are regular columns again.
    """
    group_columns = ["model", "channel", "batch", "layer", "mode"]
    value_columns = ["correct", "confidence", "rt", "is_demo", "catch_trial"]

    # NOTE(review): unpickling can execute arbitrary code -- only pass files
    # from a trusted source.
    df = pd.read_pickle(fn)
    df = df[~df["is_demo"] & ~df["catch_trial"]]
    df = df[group_columns + value_columns]

    df = (
        df.groupby(group_columns)
        .apply(safe_mean)
        # Depending on the pandas version the grouping columns are or are not
        # handed to `safe_mean`; `errors="ignore"` copes with both instead of
        # raising KeyError when they are already absent.
        .drop(columns=group_columns, errors="ignore")
    )

    # Turn the group keys from the index back into ordinary columns.
    return df.reset_index()


def process_row(
    row: pd.Series,
    imagenet_dir: str,
    preprocess_fn: Callable,
    get_similarities_fn: Callable[[torch.Tensor], torch.Tensor],
    natural_structured_filenames: dict,
    device: str,
) -> np.ndarray:
    """Compute the similarity scores for the images belonging to one row.

    For ``mode == "natural"`` every image listed for the row's
    model/layer/channel/batch is loaded from ``imagenet_dir`` (sorted by image
    name). For any other mode the batch consists of, in this order: the
    synthetic images ``max_0`` .. ``max_8`` from the feature-visualization
    archive, the natural ``max_9.png``, the synthetic ``min_0`` .. ``min_8``
    and the natural ``min_9.png``.

    Args:
        row: Record whose ``model``, ``layer``, ``channel``, ``batch`` and
            ``mode`` entries select the images.
        imagenet_dir: Directory the natural-image paths are joined onto.
        preprocess_fn: Turns one PIL image into a tensor; the results are
            concatenated along dimension 0.
        get_similarities_fn: Called with the concatenated batch (already moved
            to ``device``); its result is returned unchanged.
        natural_structured_filenames: Nested mapping
            ``model -> "<layer>/channel_<c>" -> "batch_<b>" -> image name ->
            path relative to imagenet_dir``.
        device: Torch device the image batch is moved to.

    Returns:
        The result of ``get_similarities_fn``, or ``None`` when ``row.model``
        has no entry in ``natural_structured_filenames``.
    """
    if row.model not in natural_structured_filenames:
        return None

    # Natural images are resized and centre-cropped to 224 px; the synthetic
    # images from the archive are used exactly as stored.
    preprocess_natural = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ]
    )

    # image name (e.g. "max_9.png") -> path relative to `imagenet_dir`
    batch_files = natural_structured_filenames[row.model][
        f"{row.layer}/channel_{row.channel}"
    ][f"batch_{row.batch}"]

    def load_natural(relative_fn):
        return preprocess_natural(Image.open(os.path.join(imagenet_dir, relative_fn)))

    # `row["mode"]` rather than `row.mode`: attribute access would resolve to
    # the `Series.mode` method instead of the column value.
    if row["mode"] == "natural":
        pils = [load_natural(batch_files[name]) for name in sorted(batch_files)]
    else:
        import io
        import zipfile

        def zip_member(image_name):
            member = f"release/{row.model}/{row.layer}/channel_{row.channel}/optimized_images/{image_name}"
            # Entries for "resnet50_hard95" are looked up under "resnet50"
            # (presumably both share one set of visualizations -- confirm).
            return member.replace("resnet50_hard95", "resnet50")

        pils = []
        # The context manager closes the archive again; previously one file
        # handle per processed row was left open. `fvs.read` returns the
        # complete member as bytes, so the images stay usable after closing.
        with zipfile.ZipFile(
            "$USERn/IMI/feature_visualizations.zip", "r"
        ) as fvs:
            for kind in ("max", "min"):
                pils.extend(
                    Image.open(io.BytesIO(fvs.read(zip_member(f"{kind}_{i}.png"))))
                    for i in range(9)
                )
                pils.append(load_natural(batch_files[f"{kind}_9.png"]))

    images = torch.cat([preprocess_fn(pil) for pil in pils], 0).to(device)

    # Compute scores
    return get_similarities_fn(images)


def setup_dreamsim_similarity_model(
    device: str = "cpu",
) -> Tuple[Callable[[torch.Tensor], np.ndarray], Callable]:
    """Create the DreamSim-based similarity function and its preprocessing.

    The pretrained weights are cached in ``<TORCH_HOME>/dreamsim``; when the
    ``TORCH_HOME`` environment variable is unset, ``~/.torch/dreamsim`` is used.

    Args:
        device: Torch device the model is moved to.

    Returns:
        ``(get_similarities, preprocess)``: ``get_similarities`` maps an image
        batch to the matrix of pairwise cosine similarities of the DreamSim
        embeddings, ``preprocess`` is the callable returned by ``dreamsim``.
    """
    from dreamsim import dreamsim

    cache_dir = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch") + "/dreamsim")
    dreamsim_model, dreamsim_preprocess = dreamsim(
        pretrained=True,
        cache_dir=cache_dir,
    )
    dreamsim_model = dreamsim_model.to(device)

    def get_similarities(images: torch.Tensor) -> np.ndarray:
        # Inference only: like the LPIPS and DISTS backends, run without
        # autograd. This also guarantees that `.numpy()` below cannot fail
        # because the result requires grad.
        with torch.no_grad():
            features = dreamsim_model.embed(images)
            # Broadcasting (N, 1, D) against (1, N, D) yields an (N, N) matrix
            # (assumes `embed` returns one vector per image -- confirm).
            scores = torch.nn.functional.cosine_similarity(
                features.unsqueeze(1), features.unsqueeze(0), dim=-1
            )
        return scores.cpu().numpy()

    return get_similarities, dreamsim_preprocess


def setup_lpips_similarity_model(
    device: str = "cpu",
) -> Tuple[Callable[[torch.Tensor], np.ndarray], Callable]:
    """Create the LPIPS (AlexNet) based scoring function and its preprocessing.

    Args:
        device: Torch device the model is moved to.

    Returns:
        ``(get_similarities, preprocess)``: ``preprocess`` converts a PIL image
        into a ``(1, 3, H, W)`` tensor with values in [0, 1];
        ``get_similarities`` maps a batch of N such images to an ``(N, N)``
        array whose entry ``[i, j]`` is the LPIPS value between image ``i``
        and image ``j``.
    """
    import lpips

    lpips_model = lpips.LPIPS(net="alex")
    lpips_model = lpips_model.to(device)

    to_tensor = transforms.ToTensor()

    def lpips_preprocess(image: Image.Image) -> torch.Tensor:
        image = image.convert("RGB")
        # Add the batch dimension so that the results can be concatenated.
        return to_tensor(image).unsqueeze(0)

    def get_similarities(images: torch.Tensor) -> np.ndarray:
        n_images = images.shape[0]
        with torch.no_grad():
            # Build every ordered pair (i, j). Passing the same batch for both
            # arguments would only compare each image with itself, which gives
            # N degenerate scores instead of the N x N matrix produced by the
            # other similarity backends.
            first, second = torch.broadcast_tensors(images[:, None], images[None])
            first = first.reshape(-1, *images.shape[1:])
            second = second.reshape(-1, *images.shape[1:])
            # `ToTensor` yields values in [0, 1]; LPIPS expects [-1, 1] unless
            # `normalize=True` asks it to rescale the inputs itself.
            scores = lpips_model(first, second, normalize=True).cpu().numpy()
        return scores.reshape(n_images, n_images)

    return get_similarities, lpips_preprocess


def setup_dists_similarity_model(
    device: str = "cpu",
) -> Tuple[Callable[[torch.tensor], torch.tensor], Callable]:
    """Create the DISTS-based scoring function and its preprocessing.

    Returns a pair ``(get_similarities, preprocess)``. ``preprocess`` converts
    a PIL image to RGB and passes it to ``DISTS_pt.prepare_image``;
    ``get_similarities`` evaluates the DISTS model on every ordered pair of
    images in a batch and returns the scores as an ``(N, N)`` array, where
    entry ``[i, j]`` pairs image ``i`` with image ``j``.
    """
    import DISTS_pytorch

    dists_model = DISTS_pytorch.DISTS().to(device)

    def dists_preprocess(image: Image) -> torch.Tensor:
        return DISTS_pytorch.DISTS_pt.prepare_image(image.convert("RGB"))

    def get_similarities(images: torch.Tensor) -> np.ndarray:
        n_images = len(images)
        image_shape = images.shape[1:]
        pair_shape = (n_images, n_images, *image_shape)
        with torch.no_grad():
            # Row index selects the first image of a pair, column index the
            # second; both are flattened to a batch of N * N pairs.
            first = images[:, None].expand(pair_shape).reshape(-1, *image_shape)
            second = images[None].expand(pair_shape).reshape(-1, *image_shape)
            pair_scores = dists_model(first, second).cpu().numpy()
        return pair_scores.reshape(n_images, n_images)

    return get_similarities, dists_preprocess


def main():
    """Add a pairwise-similarity column to both response tables and save them.

    Parses the command line, loads the image-name mapping and the selected
    similarity model, aggregates the main and the lower-quality response
    tables with ``get_df``, stores the result of ``process_row`` for every row
    in a ``<similarity-function>_similarity`` column and pickles each table to
    its output path.
    """
    available_similarity_functions = {
        "dreamsim": setup_dreamsim_similarity_model,
        "lpips": setup_lpips_similarity_model,
        "dists": setup_dists_similarity_model,
    }
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--imagenet-dir", type=str, default="/scratch_local/datasets/ImageNet2012/train"
    )
    parser.add_argument(
        "--main-df",
        type=str,
        default="/IMI/responses_main.pd.pkl",
    )
    parser.add_argument(
        "--lowerq-df",
        type=str,
        default="/IMI/responses_lower_quality.pd.pkl",
    )
    parser.add_argument("--main-df-output", type=str, required=True)
    parser.add_argument("--lowerq-df-output", type=str, required=True)
    parser.add_argument(
        "--similarity-function",
        type=str,
        required=True,
        choices=list(available_similarity_functions),
    )
    # Previously hard-coded; the default keeps the old location so existing
    # invocations behave as before.
    parser.add_argument(
        "--name-mapping",
        type=str,
        default="$USERn/IMI/name_mapping_all_experiments.json",
        help="JSON file mapping model/layer/channel/batch to image filenames.",
    )
    args = parser.parse_args()

    with open(args.name_mapping, encoding="utf-8") as json_file:
        natural_structured_filenames = json.load(json_file)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using device:", device)

    setup_fn = available_similarity_functions[args.similarity_function]
    get_similarities_fn, preprocess_fn = setup_fn(device)

    main_df = get_df(args.main_df)
    lowerq_df = get_df(args.lowerq_df)

    print(
        f"Found {len(main_df)} sets of images in main data "
        f"and {len(lowerq_df)} in lower quality data."
    )

    similarity_column = f"{args.similarity_function}_similarity"

    def process_df(df):
        # Adds the column in place; `progress_apply` is `apply` with a
        # progress bar.
        df[similarity_column] = df.progress_apply(
            lambda row: process_row(
                row,
                args.imagenet_dir,
                preprocess_fn,
                get_similarities_fn,
                natural_structured_filenames,
                device,
            ),
            axis=1,
        )

    # Each table is written as soon as it is finished, so the main result is
    # already on disk while the lower-quality table is being processed.
    process_df(main_df)
    main_df.to_pickle(args.main_df_output)

    process_df(lowerq_df)
    lowerq_df.to_pickle(args.lowerq_df_output)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
