import pandas as pd
import torch
import numpy as np
from pathlib import Path
import torch.nn.functional as F

import utilis
import wandb
import argparse
import random
import torch.nn as nn
from config import get_config
from networks import get_network
from datasets import load_data, Dataset
import os

# Uncomment to run without logging to Weights & Biases.
# os.environ["WANDB_MODE"]="disabled"
wandb.init(project='Bound', reinit=True)


def _str2bool(value):
    """Parse a boolean command-line value such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is ``True``, so a
    flag declared that way can never be switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--net', type=str, default='ResNet', choices=['SNN', 'IM', 'BNN', 'ResNet'], help='model used')
parser.add_argument('--dataset', type=str, default='CIFAR10', choices=['FashionMNIST', 'CIFAR10'], help='dataset used')
parser.add_argument('--smooth', type=_str2bool, default=True, help='Set only to true for the experiments on the r-values')
parser.add_argument('--adv_load', type=_str2bool, default=False, help='If true, adversarials are loaded- you have to run this with false once')
parser.add_argument('--device', default=0, type=int, help='If you have more than one gpu, select the one on which the code is run')
parser.add_argument('--n_samples', default=50, type=int, help='Amount of samples used during inference')
parser.add_argument('--attack_iteration_number', default=100, type=int, help='Only applicable for PGD: number of iterations')
parser.add_argument('--droprate', type=float, default=0.6, help='Only applicable for ResNet, specifies the dropout probability')
parser.add_argument('--stoch_varianz', default=0.05, type=float, help='Only applicable for SNN models - variance of noise added to the input')
parser.add_argument('--max_eps', type=float, default=.2, help='Maximal attack strength, we take 10 equidistant steps to reach the maximal strength')
parser.add_argument('--attack', type=str, default='FGM',
                    choices=['FGM', 'PGD', 'CW', 'margin', 'FGSM', 'strong'], help='Different attacks')
parser.add_argument('--attack_samples', type=int, default=10, help='Number of samples used during the attack')
args = parser.parse_args()
# get_config (project-local) may add/override fields such as root_dir, batch_size, epochs.
args = get_config(args)
wandb.config.update(args)
torch.cuda.set_device(args.device)
wandb.run.name = f'''{args.net}_{args.attack_samples}_{wandb.run.id}'''


def _seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch with the same ``seed``."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)


def main(args):
    """Run the adversarial evaluation for the model configured by ``args``.

    Loads the reduced test loader and the trained weights selected by
    ``args``, crafts (or loads) adversarial examples for every epsilon level
    via ``create_adv_sample`` and evaluates the model on each level via
    ``eval_adv_samples``.

    Note: ``args.randseed`` is mutated in place (+1000 before the attack and
    +1000 again before the evaluation); the final value appears in the
    result file names written by ``eval_adv_samples``.
    """
    # get data (only the reduced test loader is used)
    _, _, red_test_loader = load_data(args.dataset, args.batch_size, args.root_dir)

    # get model
    model = get_network(args)

    # load model: the checkpoint name encodes the training configuration
    if args.smooth:
        model_name = (f'model_smooth_{args.net}_{args.dataset}_{args.epochs}_'
                      f'{args.randseed}_{args.droprate}_{args.smooth_level}.bin')
    elif args.net == 'SNN':
        model_name = (f'model_{args.net}_{args.dataset}_{args.epochs}_'
                      f'{args.randseed}_{args.layer}_{args.stoch_varianz}.bin')
    else:
        model_name = (f'model_{args.net}_{args.dataset}_{args.epochs}_'
                      f'{args.randseed}_{args.droprate}.bin')
    parameter = torch.load(Path(args.root_dir, f'models/{args.dataset}/{model_name}'),
                           map_location='cpu')
    model.load_state_dict(parameter)
    if args.net == 'ResNet':
        # Only the BatchNorm layers are switched to eval mode; every other module
        # keeps its current mode (presumably train mode, so dropout keeps sampling
        # during inference — confirm against get_network).
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    # different seed than during training
    args.randseed += 1000
    _seed_everything(args.randseed)
    _, advs, epsilon, targets = create_adv_sample(model, red_test_loader, args)

    # different seed than during adversarial creation: actually re-seed so the
    # stochastic evaluation passes are reproducible independently of the attack
    args.randseed += 1000
    _seed_everything(args.randseed)

    # evaluate adversarial examples, one epsilon level at a time
    for adv_level, target_level, eps_level in zip(advs, targets, epsilon):
        eval_adv_samples(model, adv_level, target_level, eps_level, args)

def attack(model, x, y, eps, args):
    """Craft adversarial examples for one mini-batch ``x`` with labels ``y``.

    Dispatch order:
      * ``args.attack == 'PGD'``: ``utilis.projected_gradient_descent``
        (takes precedence over ``args.smooth``);
      * otherwise, if ``args.smooth``: ``utilis.fast_gradient_method_smooth``;
      * otherwise: ``utilis.fast_gradient_method``, which receives
        ``args.attack`` to select the variant, plus ``clip_min=0`` and
        ``clip_max=1``.

    Returns whatever the selected ``utilis`` routine returns (used by the
    caller as a tensor of adversarial inputs).
    """
    if args.attack == 'PGD':
        return utilis.projected_gradient_descent(
            model, x, eps, args.attack_iteration_number, args.attack_samples, y)
    if args.smooth:
        return utilis.fast_gradient_method_smooth(model, x, y, eps, args)
    return utilis.fast_gradient_method(
        model, x, y, eps, args.attack_samples, args.attack,
        clip_min=0, clip_max=1)


def _adv_path(root_dir, prefix, tag):
    """Return the artefact path ``<root_dir>/data/adversarial/<prefix>_<tag>.npy``."""
    return Path(root_dir, f'data/adversarial/{prefix}_{tag}.npy')


def create_adv_sample(model, data_loader, args):
    """Craft (or load) adversarial examples for a grid of attack strengths.

    For every ``eps`` in ``np.arange(0, args.max_eps, step)`` with
    ``step = np.round(args.max_eps / 10, 3)``, the first 100 mini-batches of
    ``data_loader`` are attacked with ``attack``. If ``args.adv_load`` is set,
    the adversarials are instead loaded from ``<root_dir>/data/adversarial``.

    Freshly crafted adversarials are always saved. The perturbations, benign
    inputs, targets and per-sample epsilons are also saved for smooth models
    and for two hard-coded cases (ResNet at eps == 0.3, FashionMNIST at
    eps == 1.5).

    Returns:
        ``(all_data, all_advs, all_eps, all_targets)``: four lists with one
        entry per epsilon level. The entries are, respectively:
        a list of benign numpy mini-batches, a torch tensor of the stacked
        adversarial mini-batches, a per-sample epsilon array and a
        per-sample label array.
    """
    all_advs = []
    all_data = []
    all_eps = []
    all_targets = []
    model.cuda()
    max_eps = args.max_eps
    step = np.round(max_eps / 10, 3)
    # NOTE(review): np.arange excludes max_eps itself, and with a float step the
    # number of levels depends on rounding. Left unchanged so the evaluated eps
    # levels (and the file names derived from them) stay the same.
    for eps in np.arange(0, max_eps, step):
        eps = np.round(eps, 4)
        epsilon = []
        targets = []
        advs = []
        deltas = []
        benign = []
        for idx, (X_mb, t_mb) in enumerate(data_loader):
            # Cap the work at the first 100 mini-batches.
            if idx > 0 and idx % 100 == 0:
                print(idx)
                break
            X_mb, t_mb = X_mb.cuda(), t_mb.long().cuda()
            if not args.adv_load:
                x_advs = attack(model, X_mb, t_mb, eps, args)
                advs.append(x_advs.detach().cpu().numpy())
                deltas.append((x_advs - X_mb).detach().cpu().numpy())
            epsilon.append(np.repeat(eps, X_mb.shape[0]))
            targets.append(t_mb.cpu().numpy())
            benign.append(X_mb.cpu().numpy())

        epsilon = np.concatenate(epsilon, axis=0)
        targets = np.concatenate(targets, axis=0)

        file_name = utilis.get_string(args, eps)
        # Smooth-model artefacts are tagged "smooth_<file_name>_<smooth_level>".
        # Save and load must use the same tag, otherwise --adv_load in smooth
        # mode would read the non-smooth adversarials (or a missing file).
        tag = f'smooth_{file_name}_{args.smooth_level}' if args.smooth else file_name
        if args.adv_load:
            advs = torch.from_numpy(np.load(_adv_path(args.root_dir, 'advs', tag)))
        else:
            # np.stack requires every mini-batch to have the same size.
            advs_np = np.stack(advs, axis=0)
            to_save = {'advs': advs_np}
            save_extras = args.smooth or (
                (args.net == 'ResNet' and eps == 0.3)
                or (args.dataset == 'FashionMNIST' and eps == 1.5)
            )
            if save_extras:
                to_save['deltas'] = np.stack(deltas, axis=0)
                to_save['benign_data'] = np.stack(benign, axis=0)
                to_save['benign_target'] = targets
                to_save['epsilons_all'] = epsilon
            for prefix, array in to_save.items():
                np.save(_adv_path(args.root_dir, prefix, tag), array)
            advs = torch.from_numpy(advs_np)

        all_advs.append(advs)
        all_data.append(benign)
        all_targets.append(targets)
        all_eps.append(epsilon)
    return all_data, all_advs, all_eps, all_targets


def predict_model(model, test_data, pred_number, net=None):
    """Run ``pred_number`` forward passes of ``model`` on ``test_data``.

    Repeated passes let a stochastic model (e.g. one with active dropout)
    yield different samples. Outputs are passed through a softmax over
    dim 1 unless the network type is ``'SNN'``, whose outputs are kept as-is.

    Args:
        model: the torch module to evaluate.
        test_data: the input batch handed to ``model``.
        pred_number: number of forward passes.
        net: network type. Defaults to the script-level ``args.net`` for
            backward compatibility.

    Returns:
        A numpy array stacking the ``pred_number`` outputs along a new
        leading axis.
    """
    if net is None:
        net = args.net
    apply_softmax = net != 'SNN'
    preds = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(pred_number):
            pred = model(test_data)
            if apply_softmax:
                pred = F.softmax(pred, dim=1)
            preds.append(pred.detach().cpu().numpy())
    return np.array(preds)


def eval_adv_samples(model, advs, targets, eps, args):
    """Evaluate ``model`` on one epsilon level of adversarial examples.

    Averages ``args.n_samples`` stochastic predictions per input. It then logs
    the adversarial accuracy to stdout and wandb, and pickles a DataFrame with
    columns ``eps``, ``real``, ``predicted_adv`` and ``pred_value`` (the mean
    confidence of the predicted class) to ``<root_dir>/results``.

    Args:
        model: the torch module to evaluate.
        advs: adversarial inputs for this level, wrapped in the project's
            ``Dataset`` (presumably the stacked tensor from ``create_adv_sample``).
        targets: per-sample labels aligned with ``advs``.
        eps: per-sample epsilon array; ``eps[0]`` names the level.
        args: parsed script arguments (``n_samples``, ``root_dir``, ...).
    """
    preds = []
    labels = []
    adv_data = Dataset(advs, targets)
    adv_loader = torch.utils.data.DataLoader(adv_data, batch_size=100, shuffle=False)
    model.cuda()
    with torch.no_grad():
        for X_mb, t_mb in adv_loader:
            # NOTE(review): squeeze() removes every singleton axis (presumably the
            # per-batch axis left by np.stack in create_adv_sample). It would also
            # drop a size-1 batch or channel axis; confirm the input layout.
            X_mb = X_mb.cuda().squeeze()
            preds.append(predict_model(model, X_mb, args.n_samples))
            labels.append(t_mb.long().numpy())
    labels = np.concatenate(labels, 0)
    # (n_samples, n_inputs, n_classes): concatenate the mini-batches along the input axis
    predictions = np.concatenate(preds, 1)
    mean_pred = np.mean(predictions, axis=0)
    pred_value = np.max(mean_pred, axis=1)
    predicted = np.argmax(mean_pred, axis=1)
    accuracy = 1 - (np.count_nonzero((predicted - labels)) / len(predicted))
    print(f'accuracy : {accuracy} for eps = {eps[0]}')
    wandb.log({'adv_accuracy': accuracy, 'step': eps[0]})
    # eps has one entry per sample; slicing it to len(labels) keeps the columns
    # aligned (the former eps[:1000] made column_stack fail above 1000 samples).
    data_pred = np.column_stack(
        (eps[:len(labels)], labels, predicted, pred_value))
    info = pd.DataFrame(data=data_pred,
                        columns=['eps', 'real', 'predicted_adv', 'pred_value'])
    file_name = utilis.get_string(args, eps[0])
    # Written with to_pickle despite the ".h5" extension.
    info.to_pickle(Path(args.root_dir,
                        f'''results/r_value_adv_1000_{file_name}_{args.n_samples}_{args.randseed}.h5'''))


if __name__ == "__main__":
    # `args` was already parsed (and merged with get_config) at import time.
    main(args)
