from __future__ import print_function
import copy
import math
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable


# Percentile level (as a fraction) used by test_with_margin when reporting
# the margin over all samples.
margin_epsilon = 0.05
# Number of classes, i.e. the width of the encoding built by to_one_hot.
num_class = 10
# Clamp bounds applied to FGSM-perturbed inputs in fgsm_attack.
MIN_VALUE, MAX_VALUE = -3, 3


def train(model, train_loader, optimizer, args):
    """Train ``model`` for a single pass over ``train_loader``.

    Every batch takes one optimizer step on the cross-entropy loss.

    Args:
        model: network to optimize; put into training mode.
        train_loader: iterable yielding ``(inputs, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        args: namespace whose ``cuda`` flag moves each batch to the GPU.

    Returns:
        None; ``model`` is updated in place.
    """
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for inputs, labels in train_loader:
        if args.cuda:
            inputs = inputs.cuda()
            labels = labels.cuda()
        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()
    return None


def test(model, data_loader, args):
    """Evaluate ``model`` on ``data_loader`` with the cross-entropy loss.

    Args:
        model: network to evaluate; switched to eval mode.
        data_loader: loader yielding ``(data, target)`` batches; must expose
            ``.dataset`` so the sample count can be taken from it.
        args: namespace whose ``cuda`` flag moves each batch to the GPU.

    Returns:
        ``(accuracy, loss)``: accuracy in percent and the mean per-sample
        cross-entropy over the whole dataset, both as Python floats.
    """
    criterion = nn.CrossEntropyLoss()
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for data, target in data_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            # The criterion averages over the batch; weight by the batch size
            # so the division by the dataset size below is a per-sample mean.
            # ``.item()`` replaces ``.data[0]``, which fails on 0-dim tensors.
            total_loss += criterion(output, target).item() * target.size(0)
            pred = output.max(1)[1]
            correct += pred.eq(target).sum().item()
    n_samples = len(data_loader.dataset)
    loss = total_loss / n_samples
    accuracy = 100. * correct / n_samples
    return accuracy, loss


# A function named ``test`` would otherwise be collected by pytest whenever
# this module is star-imported into a test file.
test.__test__ = False


def to_one_hot(y, num_classes=None):
    """Return a one-hot encoding of integer class labels.

    Args:
        y: tensor (or legacy ``Variable``) of class indices; it is flattened.
        num_classes: width of the encoding. Defaults to the module-level
            ``num_class`` when omitted.

    Returns:
        A ``torch.zeros``-default-dtype tensor of shape
        ``(y.numel(), num_classes)`` with a 1 in each label's column.
    """
    if num_classes is None:
        num_classes = num_class
    y_tensor = y.data if isinstance(y, Variable) else y
    # Converting to ``torch.LongTensor`` also places the indices on the CPU,
    # matching the CPU ``torch.zeros`` buffer they are scattered into.
    y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)
    return torch.zeros(y_tensor.size(0), num_classes).scatter_(1, y_tensor, 1)


def _batch_margins(output, target):
    """Return per-sample margins for one batch of logits.

    ``output`` is indexed as ``(batch, classes)``. The margin of row ``i`` is
    ``output[i, target[i]]`` minus the largest competing entry of that row,
    so it is negative for misclassified samples.
    """
    index = target.view(-1, 1)
    true_logit = output.gather(1, index).squeeze(1)
    # Overwrite each true-class logit with its row minimum so the row maximum
    # is taken over the competing classes only.
    competitors = output.clone()
    competitors.scatter_(1, index, output.min(1, keepdim=True)[0])
    return true_logit - competitors.max(1)[0]


def test_with_margin(model, data_loader, args, epsilon=None):
    """Evaluate ``model`` and report a low percentile of its classification margins.

    Args:
        model: network to evaluate; switched to eval mode.
        data_loader: loader yielding ``(data, target)`` batches; must expose
            ``.dataset`` so the sample count can be taken from it.
        args: namespace whose ``cuda`` flag moves each batch to the GPU.
        epsilon: percentile level as a fraction. The reported margin is the
            ``100 * epsilon``-th percentile of all per-sample margins.
            Defaults to the module-level ``margin_epsilon``.

    Returns:
        ``(accuracy, loss, margin)``: accuracy in percent, mean per-sample
        cross-entropy, and the margin percentile (which is also printed).

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if epsilon is None:
        epsilon = margin_epsilon
    criterion = nn.CrossEntropyLoss()
    model.eval()
    total_loss = 0.0
    correct = 0
    margins = []
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for data, target in data_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            # The criterion averages over the batch; weight by the batch size
            # so the division by N below yields a per-sample mean.
            total_loss += criterion(output, target).item() * target.size(0)
            pred = output.max(1)[1]
            correct += pred.eq(target).sum().item()
            margins.append(_batch_margins(output, target).cpu())
    if not margins:
        raise ValueError('data_loader yielded no batches')
    # Concatenate once instead of re-allocating a growing tensor per batch.
    margin_list = torch.cat(margins, 0)
    margin = np.percentile(margin_list.numpy(), 100 * epsilon)

    N = len(data_loader.dataset)
    loss = total_loss / N
    accuracy = 100. * correct / N
    print('margin\t', margin)

    return accuracy, loss, margin


# Keep pytest from collecting this evaluation helper as a test when the module
# is star-imported into a test file.
test_with_margin.__test__ = False


def n_param(module, init_module):
    """Return how many scalars a Linear/Conv layer holds (weight plus bias).

    ``init_module`` is accepted but ignored so the function fits the
    ``measure_func(child, init_child, **kwargs)`` call made by ``calc_measure``.
    """
    n_bias = 0 if module.bias is None else module.bias.size(0)
    rows = module.weight.size(0)
    cols = module.weight.view(rows, -1).size(1)
    return n_bias + rows * cols


def norm(module, init_module, p=2, q=2):
    """Return the mixed (p, q) norm of a layer's weight as a Python float.

    The weight is flattened to ``(weight.size(0), -1)``. The p-norm of each
    row is taken, then the q-norm of those row norms. With ``p = q = 2`` this
    is the Frobenius norm.

    ``init_module`` is unused; it keeps the ``calc_measure`` measure signature.
    """
    weight = module.weight
    row_norms = weight.reshape(weight.size(0), -1).norm(p=p, dim=1)
    # ``.item()`` replaces ``.data[0]``, which raises on 0-dim tensors in
    # current PyTorch.
    return row_norms.norm(q).item()


def op_norm(module, init_module, p=float('Inf')):
    """Return the p-norm of a layer's singular values as a Python float.

    The weight is flattened to ``(weight.size(0), -1)`` before the SVD. The
    default ``p = inf`` yields the largest singular value (spectral norm).

    ``init_module`` is unused; it keeps the ``calc_measure`` measure signature.
    """
    weight = module.weight
    with torch.no_grad():  # only the value is needed, not an SVD backward graph
        _, singular_values, _ = weight.reshape(weight.size(0), -1).svd()
    # ``.item()`` replaces ``.data[0]``, which raises on 0-dim tensors.
    return singular_values.norm(p).item()


def reparam(model, prev_layer=None):
    """Fold each BatchNorm layer into the Linear/Conv layer preceding it, in place.

    Modules are visited depth-first in registration order. Each BatchNorm1d/2d
    is merged into the most recent Linear/Conv1d/Conv2d/Conv3d layer:
    * that layer's weight slices along dim 0 are multiplied by
      ``gamma / sqrt(running_var + eps)``;
    * its bias absorbs the BatchNorm shift, or, if it has no bias, the shift
      stays in the BatchNorm bias.
    The BatchNorm is then reset to ``weight=1, running_mean=0, running_var=1``.

    Assumes affine BatchNorm layers that track running statistics and sit
    directly after their Linear/Conv layer (a nonlinearity in between would
    make the fold change the function).

    Args:
        model: module rewritten in place.
        prev_layer: most recent Linear/Conv layer seen before ``model``
            (carried through the recursion).

    Returns:
        The last Linear/Conv layer seen, so siblings can continue the chain.

    Raises:
        ValueError: if a BatchNorm appears before any Linear/Conv layer.
    """
    for child in model.children():
        prev_layer = reparam(child, prev_layer)
        # Match on the module type: ``named_children()`` yields attribute names
        # such as 'conv1' or '0', never class names like 'Conv2d'.
        if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            prev_layer = child
        elif isinstance(child, (nn.BatchNorm1d, nn.BatchNorm2d)):
            if prev_layer is None:
                raise ValueError('BatchNorm layer has no preceding Linear/Conv layer to fold into')
            with torch.no_grad():
                scale = child.weight / (child.running_var + child.eps).sqrt()
                shift = child.bias - scale * child.running_mean
                weight = prev_layer.weight
                # Scale output channel i (dim 0 of the weight) by scale[i].
                weight.mul_(scale.view(-1, *([1] * (weight.dim() - 1))))
                if prev_layer.bias is not None:
                    # new bias = beta + scale * (b - running_mean)
                    prev_layer.bias.mul_(scale).add_(shift)
                    child.bias.fill_(0)
                else:
                    child.bias.copy_(shift)
                child.weight.fill_(1)
                child.running_mean.fill_(0)
                child.running_var.fill_(1)
    return prev_layer


def calc_measure(model, init_model, measure_func, operator, kwargs=None, p=1):
    """Aggregate a per-layer measure over every Linear/Conv layer of ``model``.

    ``measure_func(layer, init_layer, **kwargs)`` is evaluated on each
    Linear/Conv1d/Conv2d/Conv3d module of ``model``, paired with the module at
    the same position in ``init_model``. The two models are walked in lockstep,
    so they are assumed to share an architecture. Other modules are recursed
    into.

    Args:
        model: network whose layers are measured.
        init_model: reference network paired layer-by-layer with ``model``.
        measure_func: callable ``(layer, init_layer, **kwargs) -> number``.
        operator: how per-layer values are combined:
            ``'sum'``         sum of ``value ** p``;
            ``'norm'``        ``(sum of value ** p) ** (1 / p)``;
            ``'log_product'`` sum of ``log(value)`` (values must be > 0);
            ``'product'``     ``exp`` of the ``'log_product'`` result;
            ``'max'``         largest value (0 if no layer exceeds 0).
        kwargs: extra keyword arguments forwarded to ``measure_func``.
        p: exponent used by ``'sum'`` and ``'norm'``.

    Raises:
        ValueError: for an unknown ``operator``.
    """
    if kwargs is None:
        kwargs = {}
    if operator == 'product':
        return math.exp(calc_measure(model, init_model, measure_func, 'log_product', kwargs, p))
    if operator == 'norm':
        return calc_measure(model, init_model, measure_func, 'sum', kwargs, p=p) ** (1. / p)
    if operator not in ('log_product', 'sum', 'max'):
        raise ValueError('unknown operator: {!r}'.format(operator))

    measure_val = 0
    for child, init_child in zip(model.children(), init_model.children()):
        # Match on the module type: ``named_children()`` names are attribute
        # names ('fc1', '0'), not class names.
        if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            value = measure_func(child, init_child, **kwargs)
            if operator == 'log_product':
                value = math.log(value)
            elif operator == 'sum':
                value = value ** p
        else:
            value = calc_measure(child, init_child, measure_func, operator, kwargs, p=p)
        # Sub-results must be combined with the same operator: for 'max',
        # adding a nested module's maximum would be wrong.
        if operator == 'max':
            measure_val = max(measure_val, value)
        else:
            measure_val += value
    return measure_val


def lp_path_norm(model, args, p, input_size):
    """Return the l_p path norm of ``model`` as a Python float.

    A double-precision deep copy of ``model`` (in eval mode) has every
    parameter with ``requires_grad`` replaced by ``|param| ** p``. It is run on
    an all-ones input of shape ``input_size``, and the p-th root of the summed
    output is returned. ``model`` itself is not modified.

    Args:
        model: network to measure.
        args: namespace whose ``cuda`` flag moves the probe input to the GPU.
        p: norm exponent.
        input_size: shape of the all-ones probe input, batch dim included.
    """
    tmp_model = copy.deepcopy(model)
    tmp_model.eval()
    tmp_model.double()
    with torch.no_grad():  # value only: no autograd graph needed
        for param in tmp_model.parameters():
            if param.requires_grad:
                param.copy_(param.abs().pow(p))
        data_ones = torch.ones(input_size).double()
        if args.cuda:
            data_ones = data_ones.cuda()
        # ``.item()`` returns a plain float instead of a 0-dim tensor.
        total = tmp_model(data_ones).sum().item()
    return total ** (1. / p)


def calculate_complexity(model, init_model, margin, args, nchannels=3, img_dim=32):
    """Compute margin-normalized norm-based complexity measures of ``model``.

    Deep copies of ``model`` and ``init_model`` are passed through ``reparam``,
    so the caller's networks are left untouched. Every measure below is
    divided by ``margin ** 2``:

    * Frobenius: product over layers of ``norm`` with ``p = q = 2``;
    * spectral: product over layers of ``op_norm`` with ``p = inf``;
    * l1 / l2 path norm: ``lp_path_norm`` with ``p = 1`` / ``p = 2`` on an
      input of shape ``[1, nchannels, img_dim, img_dim]``.

    Args:
        model: trained network.
        init_model: reference network paired layer-by-layer with ``model``.
        margin: margin used for normalization (must be non-zero).
        args: namespace whose ``cuda`` flag is forwarded to ``lp_path_norm``.
        nchannels: input channels of the path-norm probe (default 3).
        img_dim: input height and width of the path-norm probe (default 32).

    Returns:
        ``(frobenius_norm, l1_path_norm, l2_path_norm, spectral_norm)``.
    """
    input_size = [1, nchannels, img_dim, img_dim]

    model = copy.deepcopy(model)
    # reparam rewrites modules in place; copy init_model as well so the
    # caller's reference network is not modified as a side effect.
    init_model = copy.deepcopy(init_model)
    reparam(model)
    reparam(init_model)

    margin_sq = margin ** 2
    frobenius_norm = calc_measure(model, init_model, norm, 'product', {'p': 2, 'q': 2}) / margin_sq
    spectral_norm = calc_measure(model, init_model, op_norm, 'product', {'p': float('Inf')}) / margin_sq
    l1_path_norm = lp_path_norm(model, args, p=1, input_size=input_size) / margin_sq
    l2_path_norm = lp_path_norm(model, args, p=2, input_size=input_size) / margin_sq
    print('Frobenious norm: {}\tspectral norm: {}'.format(frobenius_norm, spectral_norm))
    print('l1 path norm: {}\tl2 path norm: {}'.format(l1_path_norm, l2_path_norm))
    return frobenius_norm, l1_path_norm, l2_path_norm, spectral_norm


def fgsm_attack(image, epsilon, data_grad, min_value=None, max_value=None):
    """Return an FGSM-perturbed copy of ``image``.

    Every element moves by ``epsilon`` in the direction of the sign of
    ``data_grad`` (elements with zero gradient stay put). The result is then
    clamped to ``[min_value, max_value]``.

    Args:
        image: input tensor to perturb.
        epsilon: step size of the perturbation.
        data_grad: gradient of the loss with respect to ``image``.
        min_value: lower clamp bound; defaults to the module-level ``MIN_VALUE``.
        max_value: upper clamp bound; defaults to the module-level ``MAX_VALUE``.
    """
    if min_value is None:
        min_value = MIN_VALUE
    if max_value is None:
        max_value = MAX_VALUE
    perturbed_image = image + epsilon * data_grad.sign()
    # Keep the perturbed input inside the valid input range.
    return torch.clamp(perturbed_image, min_value, max_value)


def get_fgsm_acc_list(model, data_loadder, epsilon_list=None):
    """Measure clean and FGSM-perturbed accuracy of ``model`` for several step sizes.

    For each batch, the input gradient of the cross-entropy loss is computed
    once and reused for every epsilon. ``fgsm_attack`` builds the perturbed
    batch, whose predictions are then scored. Results are printed per epsilon.

    Args:
        model: classifier to attack; switched to eval mode. Batches are moved
            to the device of its first parameter (CPU if it has none).
        data_loadder: loader yielding ``(data, target)`` batches; must expose
            ``.dataset`` for the sample count.
        epsilon_list: perturbation sizes to evaluate; defaults to
            0.025, 0.05, ..., 0.5.

    Returns:
        ``(epsilon_list, acc_list, perturbed_acc_list)`` with accuracies in
        percent. ``acc_list`` repeats the clean accuracy once per epsilon.
    """
    if epsilon_list is None:
        epsilon_list = [
            0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2,
            0.225, 0.25, 0.275, 0.3, 0.325, 0.35, 0.375, 0.4,
            0.425, 0.45, 0.475, 0.5
        ]
    criterion = nn.CrossEntropyLoss()
    model.eval()
    # Follow the model's device instead of assuming CUDA is available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    perturbed_correct = [0] * len(epsilon_list)
    for data, target in data_loadder:
        target = target.to(device)
        # detach() yields a fresh leaf, so the loader's tensor is not modified.
        data = data.to(device).detach().requires_grad_()
        output = model(data)
        correct += output.max(1)[1].eq(target).sum().item()

        # The input gradient does not depend on epsilon, so it is computed once
        # per batch; autograd.grad also leaves the parameters' .grad untouched.
        loss = criterion(output, target)
        data_grad, = torch.autograd.grad(loss, data)

        with torch.no_grad():
            for k, epsilon in enumerate(epsilon_list):
                perturbed_output = model(fgsm_attack(data, epsilon, data_grad))
                perturbed_correct[k] += perturbed_output.max(1)[1].eq(target).sum().item()

    N = len(data_loadder.dataset)
    accuracy = 100. * correct / N
    acc_list, perturbed_acc_list = [], []
    for epsilon, n_correct in zip(epsilon_list, perturbed_correct):
        perturbed_accuracy = 100. * n_correct / N
        print('epsilon: {}\t\tactual acc: {}\tperturbed acc: {}'.format(epsilon, accuracy, perturbed_accuracy))
        acc_list.append(accuracy)
        perturbed_acc_list.append(perturbed_accuracy)

    return epsilon_list, acc_list, perturbed_acc_list