# import argparse
# import utils
# import random
# import numpy as np
# import torch
# import resnet
# import os

# NETWORKS = {
#     "resnet18": resnet.resnet18,
#     "resnet34": resnet.resnet34,
#     "resnet50": resnet.resnet50,
#     "resnet101": resnet.resnet101,
#     "resnet152": resnet.resnet152,
# }

# DATA_SIZES = ["1", "0.5", "0.2", "0.1", "0.02", "0.01"]
# AUGS = ["baseaug", "contrastaug", "randaug", "autoaug"]
# ITERS = [4, 9, 14]
# OUT_DIR = "output/cifar10_ckpts"
# # OUT_DIR = "to_send/"

# parser = argparse.ArgumentParser()
# parser = utils.add_args(parser)
# args = parser.parse_args()

# random.seed(args.seed)
# np.random.seed(args.seed)
# torch.manual_seed(args.seed)
# torch.cuda.manual_seed(args.seed)
# torch.cuda.manual_seed_all(args.seed)

# # for data_size in DATA_SIZES:
# #     for iter in ITERS:
# #         for aug in AUGS:
# #             ckpt = os.path.join(OUT_DIR, f"sparse_{data_size}_{aug}", f"best_imp_{iter}.ckpt")
# #             ckpt = torch.load(ckpt, map_location=torch.device("cpu"))
# #             model = NETWORKS[args.net](n_cls=10, pre_conv="small")
# #             if iter:
# #                 model.load_state_dict(ckpt["init"])
# #                 curr_mask = utils.extract_mask(ckpt["model"])
# #                 utils.mask_prune(model, curr_mask)
# #             model.load_state_dict(ckpt["model"])

# #             sparsity = []
# #             for name, m in model.named_modules():
# #                 if isinstance(m, torch.nn.Conv2d):
# #                     actual = float(m.weight.nelement())
# #                     sparse = float(torch.sum(m.weight == 0))
# #                     sparsity.append(str(round((1 - sparse / actual) * 100, 2)))
# #             print(" ".join(sparsity))
# #         print("\n")
# #     print("\n")

# CKPTS = [
#     # best winning tickets
#     "sparse_advprop_0.01_baseaug/best_imp_15.ckpt",
#     "sparse_advprop_0.02_baseaug/best_imp_14.ckpt",

#     "sparse_advprop_0.01_randaug/best_imp_15.ckpt",
#     "sparse_advprop_0.02_autoaug/best_imp_14.ckpt",
#     # "sparse_1_autoaug/best_imp_2.ckpt",
#     # "sparse_0.5_autoaug/best_imp_3.ckpt",
#     # "sparse_0.2_autoaug/best_imp_2.ckpt",
#     # "sparse_0.1_autoaug/best_imp_4.ckpt",
#     # "sparse_0.02_autoaug/best_imp_15.ckpt",
#     # "sparse_0.01_randaug/best_imp_15.ckpt",
# ]

# for ckpt_f in CKPTS:
#     ckpt = os.path.join(OUT_DIR, ckpt_f)
#     ckpt = torch.load(ckpt, map_location=torch.device("cpu"))
#     model = NETWORKS[args.net](n_cls=10, pre_conv="small", pretrained=False)
#     utils.modify_bn(model)
#     setattr(
#         model,
#         "attacker",
#         utils.PGDAttacker(
#             args.attack_n_iter, args.attack_eps, args.attack_step_size, 0.2
#         ),
#     )
#     model = model
#     if int(os.path.basename(ckpt_f).split(".")[0].split("_")[-1]):
#         model.load_state_dict(ckpt["init"])
#         curr_mask = utils.extract_mask(ckpt["model"])
#         utils.mask_prune(model, curr_mask)
#     model.load_state_dict(ckpt["model"])

#     sparsity = []
#     for name, m in model.named_modules():
#         if isinstance(m, torch.nn.Conv2d):
#             actual = float(m.weight.nelement())
#             sparse = float(torch.sum(m.weight == 0))
#             sparsity.append(str(round((1 - sparse / actual) * 100, 2)))
#     print(" ".join(sparsity))
#     print("\n")

# import argparse
# import utils
# import random
# import numpy as np
# import torch
# import resnet
# import os
# import matplotlib.pyplot as plt

# NETWORKS = {
#     "resnet18": resnet.resnet18,
#     "resnet34": resnet.resnet34,
#     "resnet50": resnet.resnet50,
#     "resnet101": resnet.resnet101,
#     "resnet152": resnet.resnet152,
# }

# CKPTS = [
#     # best winning tickets
#     # "sparse_1_autoaug/best_imp_2.ckpt",
#     # "sparse_0.5_autoaug/best_imp_3.ckpt",
#     # "sparse_0.2_autoaug/best_imp_2.ckpt",
#     # "sparse_0.1_autoaug/best_imp_4.ckpt",


#     "sparse_0.02_baseaug/best_imp_14.ckpt",
#     "sparse_0.01_baseaug/best_imp_15.ckpt",
    
#     "sparse_0.02_autoaug/best_imp_14.ckpt",
#     "sparse_0.01_randaug/best_imp_15.ckpt",
# ]
# OUT_DIR = "output/cifar10_ckpts"

# parser = argparse.ArgumentParser()
# parser = utils.add_args(parser)
# args = parser.parse_args()

# random.seed(args.seed)
# np.random.seed(args.seed)
# torch.manual_seed(args.seed)
# torch.cuda.manual_seed(args.seed)
# torch.cuda.manual_seed_all(args.seed)

# device, _ = utils.setup_device(False)
# n_cls = 10

# for ckpt_f in CKPTS:
#     ckpt = os.path.join(OUT_DIR, ckpt_f)
#     print(f"Evaluating: {ckpt}")
#     ckpt = torch.load(ckpt)
#     model = NETWORKS[args.net](n_cls=n_cls, pre_conv="small", pretrained=False).to(device)
#     if int(os.path.basename(ckpt_f).split(".")[0].split("_")[-1]):
#         model.load_state_dict(ckpt["init"])
#         curr_mask = utils.extract_mask(ckpt["model"])
#         utils.mask_prune(model, curr_mask)
#         print("remaining weight = ", utils.check_sparsity(model))
#     model.load_state_dict(ckpt["model"])

#     sparsity = []
#     for name, m in model.named_modules():
#         if isinstance(m, torch.nn.Conv2d):
#             actual = float(m.weight.nelement())
#             sparse = float(torch.sum(m.weight == 0))
#             sparsity.append(str(round((1 - sparse / actual) * 100, 2)))
#     print(" ".join(sparsity))
#     print("\n")

#     # plt.figure()
#     # for i in range(5):
#     #     for j in range(4):
#     #         plt.subplot2grid((5, 4), (i, j))
#     # i = 0
#     # for name, mask in curr_mask.items():
#     #     m = mask.view(-1, mask.shape[2], mask.shape[3]).mean(dim=0)
#     #     if m.shape[1] != 1:
#     #         ax = plt.subplot2grid((5, 4), (i // 4, i % 4))
#     #         ax.imshow(m.detach().cpu().numpy(), cmap="gray")
#     #     i += 1
#     # plt.show()
#     # exit()


import argparse
import utils
import random
import numpy as np
import torch
import resnet
import os
import matplotlib.pyplot as plt

# Maps the architecture name given on the command line (``args.net``) to the
# constructor of the same name in the local ``resnet`` module.
NETWORKS = {
    arch: getattr(resnet, arch)
    for arch in ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152")
}

# Checkpoints to compare, one row per run (the token after ``sparse_`` is the
# data-size setting, the next token the augmentation). Column 0 is the best
# augmentation run at IMP iteration 0 (not pruned by the loop below); column 1
# is the best winning ticket for that data size. Paths are relative to OUT_DIR.
_CKPT_PAIRS = (
    ("sparse_1_autoaug/best_imp_0.ckpt", "sparse_1_autoaug/best_imp_2.ckpt"),
    ("sparse_0.5_autoaug/best_imp_0.ckpt", "sparse_0.5_autoaug/best_imp_3.ckpt"),
    ("sparse_0.2_autoaug/best_imp_0.ckpt", "sparse_0.2_autoaug/best_imp_2.ckpt"),
    ("sparse_0.1_autoaug/best_imp_0.ckpt", "sparse_0.1_autoaug/best_imp_4.ckpt"),
    ("sparse_0.02_autoaug/best_imp_0.ckpt", "sparse_0.02_autoaug/best_imp_15.ckpt"),
    ("sparse_0.01_contrastaug/best_imp_0.ckpt", "sparse_0.01_randaug/best_imp_15.ckpt"),
)
CKPTS1 = [best_aug for best_aug, _ in _CKPT_PAIRS]  # best augs
CKPTS2 = [ticket for _, ticket in _CKPT_PAIRS]  # best winning tickets

OUT_DIR = "output/cifar10_ckpts"

# Command-line options; the actual argument set is defined by utils.add_args.
parser = utils.add_args(argparse.ArgumentParser())
args = parser.parse_args()

# Seed the Python, NumPy and torch (CPU and CUDA) generators with the same
# value, in the same order as before.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(args.seed)
del _seed_fn

# NOTE(review): the meaning of the False flag is defined in utils.setup_device,
# which is not visible here; only its first return value is used below.
device, _ = utils.setup_device(False)
n_cls = 10  # number of output classes passed to the network (CIFAR-10)

from numpy.linalg import norm


def _load_ticket(ckpt_f):
    """Build a network and load the checkpoint ``OUT_DIR/ckpt_f`` into it.

    The IMP iteration is parsed from the file name (the trailing ``_<k>``
    before the extension). For k > 0 the checkpoint's ``"init"`` weights are
    loaded first, the mask extracted from ``ckpt["model"]`` is applied with
    ``utils.mask_prune``, and the remaining-weight figure is printed. In every
    case ``ckpt["model"]`` is loaded last.

    Returns the model, already moved to ``device``.
    """
    path = os.path.join(OUT_DIR, ckpt_f)
    print(f"Evaluating: {path}")
    # Deserialize onto the CPU (as the earlier versions of this script do) so
    # GPU-saved checkpoints also open on CPU-only hosts; load_state_dict
    # copies the tensors onto the model's device afterwards.
    ckpt = torch.load(path, map_location=torch.device("cpu"))
    model = NETWORKS[args.net](n_cls=n_cls, pre_conv="small", pretrained=False).to(device)
    imp_iter = int(os.path.basename(ckpt_f).split(".")[0].split("_")[-1])
    if imp_iter:
        model.load_state_dict(ckpt["init"])
        curr_mask = utils.extract_mask(ckpt["model"])
        utils.mask_prune(model, curr_mask)
        print("remaining weight = ", utils.check_sparsity(model))
    model.load_state_dict(ckpt["model"])
    return model


def _effective_conv_weight(module):
    """Return the weight of ``module`` as defined by its loaded tensors.

    NOTE(review): utils.mask_prune is not visible here. If it is built on
    torch.nn.utils.prune (which creates the ``weight_orig``/``weight_mask``
    pair checked for below), ``module.weight`` is a plain attribute that is
    only recomputed in a forward pre-hook. After load_state_dict it would still
    hold the value computed at pruning time. In that case recompute it from the
    loaded tensors; otherwise use ``module.weight`` as-is.
    """
    orig = getattr(module, "weight_orig", None)
    mask = getattr(module, "weight_mask", None)
    if orig is not None and mask is not None:
        return orig * mask
    return module.weight


if len(CKPTS1) != len(CKPTS2):
    # zip() below would otherwise silently drop the unmatched tail.
    raise ValueError("CKPTS1 and CKPTS2 must have the same length")

# For each (CKPTS1, CKPTS2) pair, print one ratio per Conv2d layer:
# ||w2|| / ||w1||, restricted to the positions where model2's weight is nonzero.
for ckpt_f1, ckpt_f2 in zip(CKPTS1, CKPTS2):
    model1 = _load_ticket(ckpt_f1)
    model2 = _load_ticket(ckpt_f2)

    norms = []
    # Pure inspection: no autograd graph is needed for the masked indexing.
    with torch.no_grad():
        # Assumes both models share the same architecture, so named_modules()
        # yields corresponding layers in the same order.
        for (_, m1), (_, m2) in zip(model1.named_modules(), model2.named_modules()):
            if not isinstance(m1, torch.nn.Conv2d):
                continue
            w2 = _effective_conv_weight(m2)
            keep = w2 != 0  # positions that survive in model2
            wt1 = _effective_conv_weight(m1)[keep]
            wt2 = w2[keep]

            n = norm(wt2.detach().cpu().numpy()) / norm(wt1.detach().cpu().numpy())
            norms.append(str(n))
    print(" ".join(norms))
    print("\n")