#################################################
# Evaluate across iterations
#################################################

# import argparse
# import utils
# import random
# import numpy as np
# import torch
# from torchvision import transforms, datasets
# from torch.utils.data import DataLoader, Dataset
# import resnet
# import os

# NETWORKS = {
#     "resnet18": resnet.resnet18,
#     "resnet34": resnet.resnet34,
#     "resnet50": resnet.resnet50,
#     "resnet101": resnet.resnet101,
#     "resnet152": resnet.resnet152,
# }

# DATA_SIZES = ["1", "0.5", "0.2", "0.1", "0.02", "0.01"]
# AUGS = ["baseaug", "contrastaug", "randaug", "autoaug"]
# ITERS = [0, 4, 9, 14]
# OUT_DIR = "cifar10_ckpts"

# parser = argparse.ArgumentParser()
# parser = utils.add_args(parser)
# parser.add_argument("--rob_data_root", type=str, required=True, help="path to transformed data directory")
# args = parser.parse_args()

# random.seed(args.seed)
# np.random.seed(args.seed)
# torch.manual_seed(args.seed)
# torch.cuda.manual_seed(args.seed)
# torch.cuda.manual_seed_all(args.seed)

# device, _ = utils.setup_device(False)
# criterion = torch.nn.CrossEntropyLoss()
# metric_meter = utils.AvgMeter()

# @torch.no_grad()
# def eval(loader, model, metric_meter):
#     metric_meter.reset()
#     model.eval()
#     for indx, (img, target) in enumerate(loader):
#         img, target = img.to(device), target.to(device)

#         pred = model(img)
#         loss = criterion(pred, target)

#         pred_cls = pred.argmax(dim=1)
#         acc = pred_cls.eq(target.view_as(pred_cls)).sum().item() / img.shape[0]

#         metrics = {"loss": loss.item(), "acc": acc}
#         metric_meter.add(metrics)
#         utils.pbar(indx / len(loader), msg=metric_meter.msg())
#     utils.pbar(1, msg=metric_meter.msg())


# class CIFARRobustness(Dataset):
#     TYPES = [
#         # noise
#         "gaussian_noise",
#         "shot_noise",
#         "impulse_noise",

#         # blur
#         "defocus_blur",
#         "glass_blur",
#         "motion_blur",
#         "zoom_blur",

#         # weather
#         "snow",
#         "frost",
#         "fog",
#         "brightness",

#         # digital
#         "contrast",
#         "elastic_transform",
#         "pixelate",
#         "jpeg_compression",

#         # extra
#         "gaussian_blur",
#         "saturate",
#         "spatter",
#         "speckle_noise",
#     ]
#     LEVELS = [1, 2, 3, 4, 5]

#     def __init__(self, root, type, level, transform):
#         assert type in self.TYPES
#         assert level in self.LEVELS
#         imgs = np.load(os.path.join(root, f"{type}.npy"))
#         labels = np.load(os.path.join(root, "labels.npy"))
#         self.imgs = imgs[(level - 1) * 10_000 : level * 10_000]
#         self.labels = labels[(level - 1) * 10_000 : level * 10_000]
#         self.transform = transform

#     def __getitem__(self, indx):
#         img = self.imgs[indx]
#         label = self.labels[indx]
#         img = self.transform(img)
#         return img, label

#     def __len__(self):
#         return len(self.imgs)

# f = open(f"{args.dset}_rob_s_results.txt", "w")
# for data_size in DATA_SIZES:
#     for aug in AUGS:
#         for iter in ITERS:
#             ckpt = os.path.join(OUT_DIR, f"sparse_{data_size}_{aug}", f"best_imp_{iter}.ckpt")
#             print(f"Evaluating: {ckpt}")
#             ckpt = torch.load(ckpt)
#             if args.dset == "cifar10":
#                 norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
#                 n_cls = 10
#             elif args.dset == "cifar100":
#                 norm = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
#                 n_cls = 100
#             else:
#                 raise NotImplementedError(f"args.dset = {args.dset} not implemented.")
#             model = NETWORKS[args.net](n_cls=n_cls, pre_conv="small").to(device)
#             if iter:
#                 model.load_state_dict(ckpt["init"])
#                 curr_mask = utils.extract_mask(ckpt["model"])
#                 utils.mask_prune(model, curr_mask)
#                 print("remaining weight = ", utils.check_sparsity(model))
#             model.load_state_dict(ckpt["model"])

#             # basic
#             transform = transforms.Compose(
#                 [
#                     transforms.ToTensor(),
#                     norm,
#                 ]
#             )
#             dset = datasets.CIFAR10(
#                 root=args.data_root,
#                 train=False,
#                 transform=transform,
#                 download=True,
#             )
#             loader = DataLoader(
#                 dset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_workers
#             )
#             eval(loader, model, metric_meter)
#             metrics = metric_meter.get()
#             print(f"{args.dset}: loss {round(metrics['loss'], 5)}, acc: {round(metrics['acc'], 5)}")

#             temp = []
#             for type in CIFARRobustness.TYPES:
#                 for level in CIFARRobustness.LEVELS:
#                     dset = CIFARRobustness(
#                         root=args.rob_data_root, type=type, level=level, transform=transform
#                     )
#                     loader = DataLoader(
#                         dset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_workers
#                     )
#                     eval(loader, model, metric_meter)
#                     metrics = metric_meter.get()
#                     print(
#                         f"{args.dset} {type}_{level}: loss {round(metrics['loss'], 5)}, acc: {round(metrics['acc'], 5)}"
#                     )
#                     temp.append(str(round(metrics['acc'], 4)*100))
#             f.write(" ".join(temp) + "\n")
#             f.flush()

#             print("finished evaluating on ckpt")
#             print("---------------------------")
# f.close()

#################################################
# Evaluate best aug and best ticket
#################################################

import argparse
import utils
import random
import numpy as np
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
import resnet
import os

# Lookup table from an ``args.net`` value to the matching constructor in the local
# ``resnet`` module; the chosen entry is called with ``n_cls``/``pre_conv``/``pretrained``.
NETWORKS = {
    arch_name: getattr(resnet, arch_name)
    for arch_name in ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152")
}

# Checkpoints to evaluate, as paths relative to OUT_DIR, of the form
# "sparse_<data_size>_<aug>/best_imp_<k>.ckpt". The trailing integer <k> is parsed as the
# pruning iteration: k == 0 loads ckpt["model"] directly, while k > 0 first loads
# ckpt["init"], applies the mask extracted from ckpt["model"], then loads ckpt["model"].
# NOTE: every entry below is currently commented out, so CKPTS is empty: no checkpoint is
# evaluated and the results file is left empty. Uncomment the group(s) you want to run.
CKPTS = [
    # # best augs
    # "sparse_1_autoaug/best_imp_0.ckpt",
    # "sparse_0.5_autoaug/best_imp_0.ckpt",
    # "sparse_0.2_autoaug/best_imp_0.ckpt",
    # "sparse_0.1_autoaug/best_imp_0.ckpt",
    # "sparse_0.02_autoaug/best_imp_0.ckpt",
    # "sparse_0.01_contrastaug/best_imp_0.ckpt",
    # # best winning tickets
    # "sparse_1_autoaug/best_imp_2.ckpt",
    # "sparse_0.5_autoaug/best_imp_3.ckpt",
    # "sparse_0.2_autoaug/best_imp_2.ckpt",
    # "sparse_0.1_autoaug/best_imp_4.ckpt",
    # "sparse_0.02_autoaug/best_imp_15.ckpt",
    # "sparse_0.01_randaug/best_imp_15.ckpt",

    # "sparse_1_baseaug/best_imp_2.ckpt",
    # "sparse_0.5_baseaug/best_imp_6.ckpt",
    # "sparse_0.2_baseaug/best_imp_6.ckpt",
    # "sparse_0.1_baseaug/best_imp_15.ckpt",
    # "sparse_0.02_baseaug/best_imp_14.ckpt",
    # "sparse_0.01_baseaug/best_imp_15.ckpt",

    # "sparse_1_contrastaug/best_imp_3.ckpt",
    # "sparse_0.5_contrastaug/best_imp_3.ckpt",
    # "sparse_0.2_contrastaug/best_imp_2.ckpt",
    # "sparse_0.1_contrastaug/best_imp_12.ckpt",
    # "sparse_0.02_contrastaug/best_imp_15.ckpt",
    # "sparse_0.01_contrastaug/best_imp_15.ckpt",

    # "sparse_1_randaug/best_imp_1.ckpt",
    # "sparse_0.5_randaug/best_imp_1.ckpt",
    # "sparse_0.2_randaug/best_imp_1.ckpt",
    # "sparse_0.1_randaug/best_imp_10.ckpt",
    # "sparse_0.02_randaug/best_imp_13.ckpt",
    # "sparse_0.01_randaug/best_imp_15.ckpt",

    # "sparse_1_autoaug/best_imp_2.ckpt",
    # "sparse_0.5_autoaug/best_imp_1.ckpt",
    # "sparse_0.2_autoaug/best_imp_2.ckpt",
    # "sparse_0.1_autoaug/best_imp_4.ckpt",
    # "sparse_0.02_autoaug/best_imp_15.ckpt",
    # "sparse_0.01_autoaug/best_imp_15.ckpt",
]

# Base directory that every relative path in CKPTS is joined onto.
# NOTE(review): hard-coded to the CIFAR-10 checkpoints even when --dset is cifar100 — confirm.
OUT_DIR = "output/cifar10_ckpts"
# Debug helper: uncomment to print the CKPTS entries missing under OUT_DIR and stop.
# for ckpt in CKPTS:
#     if not os.path.exists(os.path.join(OUT_DIR, ckpt)):
#         print(ckpt)
# exit()

# CLI = the project's shared flags (from utils.add_args) + the corrupted-data directory.
parser = utils.add_args(argparse.ArgumentParser())
parser.add_argument(
    "--rob_data_root",
    type=str,
    required=True,
    help="path to transformed data directory",
)
args = parser.parse_args()

# Seed Python, NumPy, torch (CPU) and the CUDA generators with the same seed.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(args.seed)

# Shared evaluation state used by ``eval`` and the main loop.
device, _ = utils.setup_device(False)
criterion = torch.nn.CrossEntropyLoss()
metric_meter = utils.AvgMeter()


@torch.no_grad()
def eval(loader, model, metric_meter):
    """Run ``model`` over every batch of ``loader`` and accumulate loss/accuracy.

    Resets ``metric_meter``, puts ``model`` in eval mode, and for each batch adds the
    batch's ``criterion`` loss and top-1 accuracy to the meter under the keys ``"loss"``
    and ``"acc"``. Read the aggregate afterwards with ``metric_meter.get()``.

    Args:
        loader: iterable of ``(img, target)`` batches that supports ``len()``.
        model: classifier whose output is argmax-ed over dim 1 (assumed logits of shape
            (batch, n_cls) — confirm for new architectures).
        metric_meter: meter exposing ``reset()``, ``add(dict)`` and ``msg()``
            (``utils.AvgMeter`` in this script).
    """
    metric_meter.reset()
    model.eval()
    n_batches = len(loader)
    for indx, (img, target) in enumerate(loader):
        img = img.to(device)
        # CrossEntropyLoss needs int64 class indices; labels coming from .npy files may be
        # stored in a narrower integer dtype (e.g. uint8). ``.long()`` is a no-op for int64.
        target = target.to(device).long()

        pred = model(img)
        loss = criterion(pred, target)

        pred_cls = pred.argmax(dim=1)
        # NOTE(review): this is per-batch accuracy; if AvgMeter averages batches
        # unweighted, a smaller final batch is over-weighted. Confirm against utils.AvgMeter.
        acc = pred_cls.eq(target.view_as(pred_cls)).sum().item() / img.shape[0]

        metric_meter.add({"loss": loss.item(), "acc": acc})
        utils.pbar(indx / n_batches, msg=metric_meter.msg())
    utils.pbar(1, msg=metric_meter.msg())


class CIFARRobustness(Dataset):
    """One corruption type at one severity level, read from ``.npy`` files in ``root``.

    ``root`` must contain ``{type}.npy`` (images for all levels stacked along axis 0) and
    ``labels.npy``. Level ``k`` (1-based) is the ``k``-th consecutive chunk of
    ``LEVEL_SIZE`` rows in both arrays. Items are ``(transform(image), int label)``.
    """

    TYPES = [
        # noise
        "gaussian_noise",
        "shot_noise",
        "impulse_noise",
        # blur
        "defocus_blur",
        "glass_blur",
        "motion_blur",
        "zoom_blur",
        # weather
        "snow",
        "frost",
        "fog",
        "brightness",
        # digital
        "contrast",
        "elastic_transform",
        "pixelate",
        "jpeg_compression",
        # extra
        "gaussian_blur",
        "saturate",
        "spatter",
        "speckle_noise",
    ]
    LEVELS = [1, 2, 3, 4, 5]
    # Number of rows per severity level in each .npy file.
    LEVEL_SIZE = 10_000

    def __init__(self, root, type, level, transform):
        """Load the rows of ``{type}.npy`` / ``labels.npy`` belonging to ``level``.

        Args:
            root: directory holding the ``.npy`` files.
            type: corruption name, one of ``TYPES`` (shadows the builtin; kept because
                callers pass it by keyword).
            level: severity, one of ``LEVELS``.
            transform: callable applied to each image in ``__getitem__``.

        Raises:
            ValueError: if ``type`` is not in ``TYPES`` or ``level`` is not in ``LEVELS``.
        """
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if type not in self.TYPES:
            raise ValueError(f"unknown corruption type {type!r}; expected one of {self.TYPES}")
        if level not in self.LEVELS:
            raise ValueError(f"level must be one of {self.LEVELS}, got {level!r}")
        start, stop = (level - 1) * self.LEVEL_SIZE, level * self.LEVEL_SIZE
        # Copy the slices so the full multi-level arrays can be freed instead of being
        # kept alive by a view for the lifetime of this dataset.
        self.imgs = np.load(os.path.join(root, f"{type}.npy"))[start:stop].copy()
        self.labels = np.load(os.path.join(root, "labels.npy"))[start:stop].copy()
        self.transform = transform

    def __getitem__(self, indx):
        # int() so the default collate builds an int64 target tensor regardless of the
        # on-disk label dtype (CrossEntropyLoss expects int64 class indices).
        return self.transform(self.imgs[indx]), int(self.labels[indx])

    def __len__(self):
        return len(self.imgs)


def _dataset_spec(dset_name):
    """Return ``(n_cls, normalize, test_set_cls)`` for the CIFAR variant ``dset_name``.

    ``test_set_cls`` is the torchvision dataset class whose clean test split has the same
    label space as the model (``n_cls`` classes).

    Raises:
        NotImplementedError: if ``dset_name`` is neither ``"cifar10"`` nor ``"cifar100"``.
    """
    if dset_name == "cifar10":
        norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        return 10, norm, datasets.CIFAR10
    if dset_name == "cifar100":
        norm = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
        return 100, norm, datasets.CIFAR100
    raise NotImplementedError(f"args.dset = {dset_name} not implemented.")


def _imp_iteration(ckpt_f):
    """Return the pruning iteration encoded in a name like ``"run/best_imp_3.ckpt"`` (-> 3)."""
    return int(os.path.basename(ckpt_f).split(".")[0].split("_")[-1])


def _make_loader(dataset):
    """Wrap ``dataset`` in an unshuffled loader using the CLI batch-size/worker settings."""
    return DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_workers
    )


def _load_model(ckpt_f, n_cls):
    """Build ``NETWORKS[args.net]`` on ``device`` and load ``OUT_DIR/ckpt_f`` into it.

    For iteration > 0 the network is first loaded with ``ckpt["init"]`` and passed to
    ``utils.mask_prune`` with the mask ``utils.extract_mask`` derives from
    ``ckpt["model"]``; the trained ``ckpt["model"]`` weights are loaded last in every case.
    """
    ckpt_path = os.path.join(OUT_DIR, ckpt_f)
    print(f"Evaluating: {ckpt_path}")
    ckpt = torch.load(ckpt_path)
    model = NETWORKS[args.net](n_cls=n_cls, pre_conv="small", pretrained=False).to(device)
    if _imp_iteration(ckpt_f):
        model.load_state_dict(ckpt["init"])
        curr_mask = utils.extract_mask(ckpt["model"])
        utils.mask_prune(model, curr_mask)
        print("remaining weight = ", utils.check_sparsity(model))
    model.load_state_dict(ckpt["model"])
    return model


def _metrics_line(prefix, metrics):
    """Format the per-split progress line printed after each evaluation."""
    return f"{prefix}: loss {round(metrics['loss'], 5)}, acc: {round(metrics['acc'], 5)}"


# Loop-invariant setup: class count, normalization and clean test set depend only on --dset.
n_cls, norm, test_set_cls = _dataset_spec(args.dset)
transform = transforms.Compose([transforms.ToTensor(), norm])
# Shared by every checkpoint; built on first use so an empty CKPTS triggers no download.
clean_loader = None

# One output line per checkpoint: accuracy (%) for every corruption type x level, in
# CIFARRobustness.TYPES x LEVELS order. ``with`` closes the file even if a step fails.
with open(f"{args.dset}_rob_s_results.txt", "w") as f:
    for ckpt_f in CKPTS:
        model = _load_model(ckpt_f, n_cls)

        # clean test split (CIFAR10 or CIFAR100, matching the model's class count)
        if clean_loader is None:
            clean_loader = _make_loader(
                test_set_cls(root=args.data_root, train=False, transform=transform, download=True)
            )
        eval(clean_loader, model, metric_meter)
        print(_metrics_line(args.dset, metric_meter.get()))

        # corrupted test splits
        accs = []
        for corruption in CIFARRobustness.TYPES:
            for level in CIFARRobustness.LEVELS:
                rob_set = CIFARRobustness(
                    root=args.rob_data_root, type=corruption, level=level, transform=transform
                )
                eval(_make_loader(rob_set), model, metric_meter)
                metrics = metric_meter.get()
                print(_metrics_line(f"{args.dset} {corruption}_{level}", metrics))
                # Scale before rounding: round(acc, 4) * 100 leaves float noise in the
                # file (e.g. 0.57 * 100 -> "56.99999999999999").
                accs.append(str(round(metrics["acc"] * 100, 2)))
        f.write(" ".join(accs) + "\n")
        f.flush()

        print("finished evaluating on ckpt")
        print("---------------------------")
