import torch
import numpy as np
import torch.nn.functional as F
import random
from cleverhans.torch.utils import clip_eta
import math

def get_string(args, eps):
    """Build the underscore-separated file-name stem describing one attack run.

    Side effect: for any attack other than 'PGD',
    ``args.attack_iteration_number`` is overwritten with 1 before the name is
    assembled.

    Args:
        args: Namespace-like object providing ``dataset``, ``net``, ``attack``,
            ``attack_samples``, ``max_eps``, ``droprate`` and
            ``attack_iteration_number``; ``stoch_varianz`` is read only when
            ``args.net == 'SNN'``.
        eps: Perturbation size embedded in the name.

    Returns:
        The file-name stem as a string (no extension).
    """
    if args.attack != 'PGD':
        args.attack_iteration_number = 1

    stem = f'''{args.dataset}_{args.net}_{args.attack}_{args.attack_samples}_{args.max_eps}_{eps}_{args.droprate}_{args.attack_iteration_number}'''
    if args.net != 'SNN':
        return stem
    # 'SNN' runs additionally encode the stochastic variance setting.
    return f'''{stem}_{args.stoch_varianz}'''


def _mean_model_output(model, x, n_attacks, as_probabilities):
    """Average ``n_attacks`` forward passes of ``model`` on ``x`` (optionally after softmax)."""
    if n_attacks < 1:
        raise ValueError("n_attacks must be at least 1.")
    weight = 1 / n_attacks
    mean_output = None
    for _ in range(n_attacks):
        output = model(x)
        if as_probabilities:
            output = F.softmax(output, dim=1)
        contribution = weight * output
        mean_output = contribution if mean_output is None else mean_output + contribution
    return mean_output


def _margin_loss(scores, y):
    """Return minus the summed gap between each sample's true-class score and its best other score."""
    batch_size, num_classes = scores.shape
    class_ids = torch.arange(num_classes, device=scores.device).view(1, num_classes).expand_as(scores)
    not_true_class = y.unsqueeze(1).expand_as(class_ids) != class_ids
    runner_up, _ = torch.max(scores[not_true_class].view(batch_size, num_classes - 1), 1)
    # gather picks scores[i, y[i]] per sample; indexing with scores[:, y] would
    # instead build a (batch, batch) matrix and corrupt the sum for batches > 1.
    true_score = scores.gather(1, y.unsqueeze(1)).squeeze(1)
    return -torch.sum(true_score - runner_up)


def get_loss(model, x, y, n_attacks ,attack):
    """Compute the attack objective whose input gradient drives the perturbation.

    The model output is averaged over ``n_attacks`` forward passes (useful for
    stochastic models) before the loss is evaluated.

    Args:
        model: Callable mapping ``x`` to a 2-D tensor of per-class scores.
        x: Input batch.
        y: Integer class labels, one per sample, on the same device as the
            model output.
        n_attacks: Number of forward passes to average; must be >= 1.
        attack: One of 'FGM', 'FGSM' (cross entropy on averaged softmax
            outputs), 'margin' (margin on averaged softmax outputs), 'strong'
            (margin on averaged logits) or 'CW' (``cw`` on averaged logits).

    Returns:
        A scalar loss tensor.

    Raises:
        ValueError: If ``attack`` is not a supported name or ``n_attacks < 1``.
    """
    if attack == 'FGM' or attack == 'FGSM':
        pred = _mean_model_output(model, x, n_attacks, as_probabilities=True)
        # NOTE(review): CrossEntropyLoss applies log-softmax itself, so feeding
        # it probabilities softmaxes twice; kept to preserve the existing attack.
        return torch.nn.CrossEntropyLoss()(pred, y)
    if attack == 'margin':
        return _margin_loss(_mean_model_output(model, x, n_attacks, as_probabilities=True), y)
    if attack == 'strong':
        # This corresponds to the margin loss on the logits.
        return _margin_loss(_mean_model_output(model, x, n_attacks, as_probabilities=False), y)
    if attack == 'CW':
        return cw(_mean_model_output(model, x, n_attacks, as_probabilities=False), y)
    raise ValueError(f"Unsupported attack: {attack!r}")


def fast_gradient_method(model, x, y, eps, n_attacks, attack, clip_min, clip_max):
    """
    Single-step gradient attack, adapted from the cleverhans PyTorch
    implementation of the Fast Gradient Method.

    The loss selected by ``attack`` is evaluated via ``get_loss`` (averaging
    ``n_attacks`` forward passes), differentiated with respect to the input,
    and one step of size ``eps`` is taken: the gradient sign for 'FGSM', an
    L2-normalised gradient for every other attack name. The result is clamped
    to ``[clip_min, clip_max]``. When ``eps == 0`` the input is returned
    unchanged.

    MIT License

    Copyright (c) 2019 Google Inc., OpenAI and Pennsylvania State University

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.

    Partly modified.
    """
    # A zero budget leaves nothing to do; hand the input back untouched.
    if eps == 0:
        return x

    # backward() only stores .grad on leaf tensors, so work on a detached
    # floating-point copy that requires gradients.
    x_leaf = x.clone().detach().to(torch.float).requires_grad_(True)

    get_loss(model, x_leaf, y, n_attacks, attack).backward()

    norm_order = np.inf if attack == 'FGSM' else 2
    step = optimize_linear(x_leaf.grad, eps, norm_order)

    # Perturb, then keep the adversarial example inside the valid data range.
    return torch.clamp(x_leaf + step, clip_min, clip_max)


def fast_gradient_method_smooth( model, x, y, eps, args):
    """
    Fast Gradient Method against a noise-smoothed classifier, adapted from the
    cleverhans PyTorch implementation.

    The prediction is the average softmax output over
    ``args.noisy_pred_samples`` Gaussian-perturbed copies of ``x`` (zero mean,
    standard deviation ``sqrt(args.smooth_level)``, clipped to [0, 1]), each
    evaluated ``args.attack_samples`` times. One L2-normalised gradient step of
    size ``eps`` is taken on the cross entropy of that average and the result
    is clamped to [0, 1]. When ``eps == 0`` the input is returned unchanged.

    Args:
        model: Callable mapping a batch to per-class logits (10 classes).
        x: Input batch.
        y: Integer class labels, one per sample.
        eps: L2 step size.
        args: Namespace providing ``noisy_pred_samples``, ``attack_samples``
            and ``smooth_level``.

    Returns:
        The adversarial batch, or ``x`` itself when ``eps == 0``.

    MIT License

    Copyright (c) 2019 Google Inc., OpenAI and Pennsylvania State University

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.

    Partly modified.
    """

    if eps == 0:
        return x
    # x needs to be a leaf variable, of floating point type and have requires_grad being True for
    # its grad to be computed and stored properly in a backward call
    x = x.clone().detach().to(torch.float).requires_grad_(True)

    # Compute loss
    loss_fn = torch.nn.CrossEntropyLoss()
    pred = torch.zeros(x.shape[0], 10, device=x.device)
    # The noise parameters do not change between draws, so build them once.
    # zeros_like/ones_like already live on x's device, so the noise is sampled
    # there directly; a hard-coded .cuda() would break CPU-only runs.
    noise_mean = torch.zeros_like(x)
    noise_std = torch.ones_like(x) * math.sqrt(args.smooth_level)
    weight = 1 / (args.attack_samples * args.noisy_pred_samples)
    for _ in range(args.noisy_pred_samples):
        X_mb = torch.normal(mean=noise_mean, std=noise_std) + x
        X_mb = torch.clip(X_mb, 0, 1)
        for _ in range(args.attack_samples):
            logits = model(X_mb)
            pred += weight * F.softmax(logits, dim=1)
    # NOTE(review): CrossEntropyLoss applies log-softmax itself, so the averaged
    # probabilities are softmaxed a second time; kept to preserve behaviour.
    loss = loss_fn(pred, y)

    # Define gradient of loss wrt input
    loss.backward()
    optimal_perturbation = optimize_linear(x.grad, eps, 2)

    # Add perturbation to original example to obtain adversarial example
    adv_x = x + optimal_perturbation
    adv_x = torch.clamp(adv_x, 0, 1)
    return adv_x



def projected_gradient_descent(model, x, eps, attack_iter, attack_samples, y):
    """
    Iterative L2 attack built from repeated 'FGM' steps, adapted from the
    cleverhans PyTorch implementation.

    Each iteration takes one ``fast_gradient_method`` step of size
    ``eps / (attack_iter / 2)`` from the current iterate, clips the accumulated
    perturbation back to L2 norm ``eps`` with ``clip_eta``, and clamps the
    result to [0, 1]. The input is returned unchanged when ``eps`` or the
    per-step size is zero.

    Args:
        model: Model under attack.
        x: Clean input batch.
        eps: Total L2 perturbation budget.
        attack_iter: Number of gradient steps.
        attack_samples: Forward passes averaged per gradient step.
        y: Integer class labels.

    Returns:
        The adversarial batch.

    MIT License

    Copyright (c) 2019 Google Inc., OpenAI and Pennsylvania State University

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.

    Partly modified.
    """
    if eps == 0:
        return x
    # Each step covers twice the average share of the budget.
    step_size = eps / (attack_iter / 2)
    if step_size == 0:
        return x

    adv_x = torch.clamp(x, 0, 1)

    steps_left = attack_iter
    while steps_left > 0:
        stepped = fast_gradient_method(model, adv_x, y, step_size, attack_samples, 'FGM', 0, 1)
        # Project the total perturbation back onto the L2 ball of radius eps,
        # then re-apply the data-range clamp the projection may have undone.
        perturbation = clip_eta(stepped - x, 2, eps)
        adv_x = torch.clamp(x + perturbation, 0, 1)
        steps_left -= 1

    return adv_x



def cw(pred, label):
    """Carlini-Wagner style margin objective on a batch of class scores.

    For each sample the margin is the true-class score minus the largest
    other-class score plus a confidence offset of 0.1; margins are floored at
    zero, negated and averaged. The number of classes is taken from ``pred``
    rather than being fixed at 10.

    Args:
        pred: Tensor of shape (batch, num_classes) with per-class scores.
        label: Integer tensor of shape (batch,) with the true class indices.

    Returns:
        Scalar tensor: the mean of ``-max(real - other + 0.1, 0)``.
    """
    num_classes = pred.shape[1]
    y_onehot = torch.nn.functional.one_hot(label, num_classes).to(torch.float)
    real = torch.sum(y_onehot * pred, 1)
    # Push the true class far down so it can never win the "best other" max.
    other, _ = torch.max((1 - y_onehot) * pred - y_onehot * 1e4, 1)
    loss = - torch.max(
        (real - other) + 0.1,
        torch.tensor(0.).to(real.device)
    )
    return torch.mean(loss)


def optimize_linear(grad, eps, norm):
    """
    Return the perturbation of norm ``eps`` that maximises the linearised loss,
    adapted from the cleverhans PyTorch implementation.

    Args:
        grad: Gradient of the loss with respect to the input; the first
            dimension is treated as the batch dimension.
        eps: Scale applied to the unit perturbation.
        norm: ``np.inf`` (sign of the gradient) or ``2`` (gradient divided by
            its per-sample L2 norm).

    Returns:
        Tensor shaped like ``grad`` holding ``eps`` times the unit perturbation.

    Raises:
        NotImplementedError: If ``norm`` is neither ``np.inf`` nor ``2``.

    MIT License

    Copyright (c) 2019 Google Inc., OpenAI and Pennsylvania State University

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.

    Partly modified.
    """
    # Reduce over every dimension except the leading (batch) one.
    red_ind = list(range(1, grad.dim()))
    avoid_zero_div = torch.tensor(1e-12, dtype=grad.dtype, device=grad.device)
    if norm == np.inf:
        # Take sign of gradient
        optimal_perturbation = torch.sign(grad)
    elif norm == 2:
        # Floor the squared norm so an all-zero gradient does not divide by zero.
        square = torch.max(avoid_zero_div, torch.sum(grad ** 2, red_ind, keepdim=True))
        optimal_perturbation = grad / torch.sqrt(square)
        # Sanity check: the perturbation has unit norm wherever the gradient
        # was not floored; floored samples are compared against themselves.
        opt_pert_norm = (
            optimal_perturbation.pow(2).sum(dim=red_ind, keepdim=True).sqrt()
        )
        one_mask = (square <= avoid_zero_div).to(grad.dtype) * opt_pert_norm + (
            square > avoid_zero_div
        ).to(grad.dtype)
        assert torch.allclose(opt_pert_norm, one_mask, rtol=1e-05, atol=1e-08)
    else:
        # Only the two branches above exist; the message must not promise L1.
        raise NotImplementedError(
            "Only L-inf and L2 norms are currently implemented."
        )
    scaled_perturbation = eps * optimal_perturbation
    return scaled_perturbation


def predict_model(model, adv_data, true_data, label, pred_number, delta, args):
    '''
    Predict the label of an adversarial input and compute per-class "r-values"
    from the gradients of the averaged prediction at the clean input.

    Steps performed:
      1. Increment ``args.randseed`` and seed numpy, torch and ``random`` with it.
      2. Average ``pred_number`` forward passes on ``adv_data`` (softmax is
         applied unless ``args.net == 'SNN'``) and take the argmax.
      3. Re-seed with the same value, average ``pred_number`` forward passes on
         ``true_data`` with gradients enabled, and for every class ``i != label``
         differentiate ``predicted[label] - predicted[i]`` w.r.t. ``true_data``.
      4. Combine margin, gradient norm and the angle between ``delta`` and the
         negative gradient into one r-value per class and pick the smallest.

    NOTE(review): the r-values look like a linearised distance to each class
    boundary along ``delta`` -- confirm against the accompanying derivation.

    Args:
        model: Network exposing ``forward``.
        adv_data: Adversarial input; presumably a batch of one sample, since
            the batch dimension is squeezed and argmax is taken over the whole
            tensor -- TODO confirm.
        true_data: Clean input; ``requires_grad_(True)`` is set on it in place.
        label: True class index; used to index the 10 averaged class scores.
        pred_number: Number of forward passes to average.
        delta: Perturbation vector passed to ``angle_between``; presumably a
            CPU array/tensor -- verify against caller.
        args: Namespace; ``randseed`` is read and incremented, ``net`` is read.

    Returns:
        Tuple ``(pred_label_adv, min_rvalue, min_r_idx, zaehler, norm, alpha)``:
        the predicted adversarial label as a numpy value, the smallest r-value
        and its class index, and the per-class lists of margins ("zaehler",
        German for numerator), gradient norms and angles.
    '''
    # calculate adversarial example prediction
    # Side effect: the caller's seed counter is advanced on every call.
    args.randseed +=1
    preds= []
    np.random.seed(args.randseed)
    torch.manual_seed(args.randseed)
    random.seed(args.randseed)
    with torch.no_grad():
        for _ in range(pred_number):
            pred = model.forward(adv_data)
            # Outputs of non-'SNN' nets are passed through softmax here;
            # 'SNN' outputs are used as returned.
            if args.net != 'SNN':
                pred = F.softmax(pred, dim=1)
            preds.append(pred.squeeze(0).detach().cpu())
    preds= torch.stack(preds, dim=0)
    predicted = torch.mean(preds, dim=0)
    # argmax without `dim` indexes the flattened tensor.
    pred_label_adv= torch.argmax(predicted)

    # Compute the gradients at the clean input. The RNGs are reset to the same
    # seed that was used for the adversarial pass above.
    preds = []
    true_data.requires_grad_(True)
    np.random.seed(args.randseed)
    torch.manual_seed(args.randseed)
    random.seed(args.randseed)
    for _ in range(pred_number):
        pred = model.forward(true_data)
        if args.net != 'SNN':
            pred = F.softmax(pred, dim=1)
        preds.append(pred.squeeze(0))
    # now build mean
    preds= torch.stack(preds, dim=0)
    predicted = torch.mean(preds, dim=0)
    pred_label = torch.argmax(predicted)
    true_class = predicted[label]
    # NOTE: `gradients` is filled below but never read or returned.
    gradients = []
    zaehler = []  # "Zaehler" = numerator: margin between true class and class i
    norm = []
    alpha = []
    for i in range(10):
        if i != label:
            loss = true_class - predicted[i]
            # The graph is reused for every class, hence retain_graph=True.
            loss.backward(retain_graph=True)
            zaehler.append(loss.data.item())
            grad_value = true_data.grad.detach().cpu().squeeze(0)
            gradients.append(grad_value)
            # Reset the accumulated gradient so the next class starts from zero.
            true_data.grad = torch.zeros(true_data.shape, device=true_data.device)
            norm_val =np.linalg.norm(grad_value)
            norm.append(norm_val)
            angle = angle_between(delta, -grad_value)
            alpha.append(angle)
        else:
            # Placeholders for the true class; its r-value is overwritten with
            # inf further down.
            zaehler.append(0)
            norm.append(1)
            alpha.append(1)
    del true_data, loss, pred
    torch.cuda.empty_cache()

    # 1e-20 guards against an exactly zero denominator.
    denominator = (np.cos(np.array(alpha)) * np.array(norm))+1e-20
    # NOTE(review): the numerator is negated here, whereas
    # predict_model_smooth divides the un-negated margin -- confirm which
    # sign is intended.
    r_values = -np.array(zaehler) / denominator
    zero_cos = np.where(np.cos(np.array(alpha)) <= 0)
    # NOTE(review): with `> 1`, a single class with non-positive cosine is NOT
    # set to inf; `> 0` may have been intended -- confirm.
    if len(zero_cos[0]) > 1:
        r_values[zero_cos] = np.inf
    r_values[label]= np.inf
    # If the clean input is already misclassified, every r-value is set to 0.
    if pred_label!= label:
        r_values[:]=0

    min_r_idx = np.argmin(r_values)
    min_rvalue = r_values[min_r_idx]
    return pred_label_adv.numpy(), min_rvalue, min_r_idx, zaehler, norm, alpha


def predict_model_smooth(model, adv_data, true_data, label, pred_number, delta, args, eps):
    '''
    Noise-smoothed counterpart of ``predict_model``: predictions are averaged
    over Gaussian-perturbed copies of the input, and two sets of per-class
    r-values ("linear" and "smooth") are computed from the gradients at the
    clean input.

    Steps performed:
      1. Increment ``args.randseed`` and seed numpy, torch and ``random`` with it.
      2. For ``args.noisy_pred_samples`` noise draws (zero mean, standard
         deviation ``sqrt(args.smooth_level)``, result clipped to [0, 1]),
         average the softmax of ``pred_number`` forward passes on the noisy
         ``adv_data`` and take the argmax.
      3. Re-seed with the same value and repeat on ``true_data`` with gradients
         enabled; for every class ``i != label`` differentiate
         ``pred[label] - pred[i]`` w.r.t. ``true_data``.
      4. Build the "linear" r-values ``margin / (cos(alpha) * norm)`` and the
         "smooth" r-values, whose denominator additionally contains
         ``eps / args.smooth_level``; pick the smallest of each.

    NOTE(review): the noise tensors are moved with ``.cuda()``, so this
    function requires a CUDA device as written.

    Args:
        model: Callable network.
        adv_data: Adversarial input; presumably a batch of one sample, since
            argmax is taken over the whole prediction tensor -- TODO confirm.
        true_data: Clean input; ``requires_grad_(True)`` is set on it in place.
        label: True class index; used to index the 10 averaged class scores.
        pred_number: Forward passes per noise draw.
        delta: Perturbation vector passed to ``angle_between``; presumably a
            CPU array/tensor -- verify against caller.
        args: Namespace; ``randseed`` is read and incremented,
            ``noisy_pred_samples`` and ``smooth_level`` are read.
        eps: Perturbation size entering the "smooth" denominator.

    Returns:
        Tuple ``(pred_label_adv, min_rvalue_linear, min_rvalue_smooth,
        min_r_idx_linear, min_r_idx_smooth, zaehler, norm, alpha)``; the last
        three are the per-class lists of margins ("zaehler", German for
        numerator), gradient norms and angles.
    '''
    # calculate adversarial example prediction
    # Side effect: the caller's seed counter is advanced on every call.
    args.randseed +=1
    np.random.seed(args.randseed)
    torch.manual_seed(args.randseed)
    random.seed(args.randseed)
    # This accumulator is created without a device argument; the logits are
    # moved to the CPU before being added to it.
    pred = torch.zeros(adv_data.shape[0], 10)
    for ind_x in range(args.noisy_pred_samples):
        # One noisy copy of the adversarial input, clipped to [0, 1].
        X_mb = torch.normal(mean=torch.zeros_like(adv_data),
                            std=torch.ones_like(adv_data) * math.sqrt(args.smooth_level)).cuda() + adv_data
        X_mb = torch.clip(X_mb, 0, 1)
        with torch.no_grad():
            for _ in range(pred_number):
                logits = model(X_mb)
                pred += (1 / (pred_number * args.noisy_pred_samples)) * F.softmax(logits.cpu(), dim=1)

    # argmax without `dim` indexes the flattened tensor.
    pred_label_adv= torch.argmax(pred)

    # Compute the gradients at the clean input. The RNGs are reset to the same
    # seed that was used for the adversarial pass above.
    # NOTE: `preds` is assigned here but never used afterwards.
    preds = []
    true_data.requires_grad_(True)
    np.random.seed(args.randseed)
    torch.manual_seed(args.randseed)
    random.seed(args.randseed)
    pred = torch.zeros(true_data.shape[0], 10, device=true_data.device)
    for ind in range(args.noisy_pred_samples):
        X_mb = torch.normal(mean=torch.zeros_like(true_data),
                            std=torch.ones_like(true_data) * math.sqrt(args.smooth_level)).cuda() + true_data
        X_mb = torch.clip(X_mb, 0, 1)
        for _ in range(pred_number):
            logits = model(X_mb)
            pred += (1 / (pred_number * args.noisy_pred_samples)) * F.softmax(logits, dim=1)

    pred= pred.squeeze(0)
    pred_label = torch.argmax(pred)
    true_class = pred[label]
    # NOTE: `gradients` is filled below but never read or returned.
    gradients = []
    zaehler = []  # "Zaehler" = numerator: margin between true class and class i
    norm = []
    alpha = []
    for i in range(10):
        if i != label:
            loss = true_class - pred[i]
            # The graph is reused for every class, hence retain_graph=True.
            loss.backward(retain_graph=True)
            zaehler.append(loss.data.item())
            grad_value = true_data.grad.detach().cpu().squeeze(0)
            gradients.append(grad_value)
            # Reset the accumulated gradient so the next class starts from zero.
            true_data.grad = torch.zeros(true_data.shape, device=true_data.device)
            norm_val =np.linalg.norm(grad_value)
            norm.append(norm_val)
            angle = angle_between(delta, -grad_value)
            alpha.append(angle)
        else:
            # Placeholders for the true class; its r-values are overwritten
            # with inf further down.
            zaehler.append(0)
            norm.append(1)
            alpha.append(1)
    del true_data, loss, pred
    torch.cuda.empty_cache()
    # in the linear case:
    # NOTE(review): unlike predict_model, no 1e-20 is added to the denominator
    # and the numerator is not negated -- confirm which variant is intended.
    denominator = (np.cos(np.array(alpha)) * np.array(norm))
    r_values_linear = np.array(zaehler) / denominator
    zero_cos = np.where(np.cos(np.array(alpha)) <= 0)
    # NOTE(review): with `> 1`, a single class with non-positive cosine is NOT
    # set to inf; `> 0` may have been intended -- confirm.
    if len(zero_cos[0]) > 1:
        r_values_linear[zero_cos] = np.inf
    r_values_linear[label]= np.inf
    # If the clean input is already misclassified, every r-value is set to 0.
    if pred_label!= label:
        r_values_linear[:]=0

    min_r_idx_linear = np.argmin(r_values_linear)
    min_rvalue_linear = r_values_linear[min_r_idx_linear]

    # in the smooth case
    # Same quantities, with eps / smooth_level added to the denominator.
    denominator = (np.cos(np.array(alpha)) * np.array(norm)) + ((1/args.smooth_level)*eps)
    r_values_smooth = np.array(zaehler) / denominator
    zero_denominator = np.where(denominator <= 0)
    # NOTE(review): same `> 1` guard as in the linear case -- confirm.
    if len(zero_denominator[0]) > 1:
        r_values_smooth[zero_denominator] = np.inf
    r_values_smooth[label] = np.inf
    if pred_label != label:
        r_values_smooth[:] = 0

    min_r_idx_smooth = np.argmin(r_values_smooth)
    min_rvalue_smooth = r_values_smooth[min_r_idx_smooth]

    return pred_label_adv.numpy(), min_rvalue_linear, min_rvalue_smooth, min_r_idx_linear, min_r_idx_smooth ,zaehler, norm, alpha


def angle_between(v1, v2):
    """Return the angle (in radians) between ``v1`` and ``v2`` after flattening both."""
    first_dir = unit_vector(v1.flatten())
    second_dir = unit_vector(v2.flatten())
    cosine = np.dot(first_dir, second_dir)
    # Rounding can push the cosine marginally outside [-1, 1], where arccos
    # would return NaN, so clamp it first.
    cosine = np.clip(cosine, -1.0, 1.0)
    return np.arccos(cosine)


def unit_vector(vector):
    """Scale ``vector`` to (approximately) unit Euclidean length.

    A tiny constant is added to the norm so that an all-zero vector yields
    zeros instead of a division by zero.
    """
    length = np.linalg.norm(vector) + 1e-20
    return vector / length
