import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from train_vgg19 import vgg19
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim
import numpy as np
from utils import Noisy,transform,random_label
noisy = Noisy.apply

from detect import l1_detection, untargeted_detection, targeted_detection

#def l1_detection(model, 
#                 img, 
#                 dataset, 
#                 n_radius):
#    return torch.norm(F.softmax(model(transform(img, dataset=dataset))) - F.softmax(
#        model(transform(noisy(img, n_radius), dataset=dataset))), 1).item()
#
#def untargeted_detection(model, 
#                         img, 
#                         dataset, 
#                         lr, 
#                         u_radius, 
#                         cap=100,
#                         margin=20,
#                         use_margin=False):
#    model.eval()
#    x_var = torch.autograd.Variable(img.clone().cuda(), requires_grad=True)
#    true_label = model(transform(x_var.clone(), dataset=dataset)).data.max(1, keepdim=True)[1][0].item()
#    optimizer_s = optim.SGD([x_var], lr=lr)
#    counter = 0
#    while model(transform(x_var.clone(), dataset=dataset)).data.max(1, keepdim=True)[1][0].item() == true_label:
#        optimizer_s.zero_grad()
#        output = model(transform(x_var, dataset=dataset))
#        if use_margin:
#            _, top2_1 = output.data.cpu().topk(2)
#            argmax11 = top2_1[0][0]
#            if argmax11 == true_label:
#                argmax11 = top2_1[0][1]
#            loss = (output[0][true_label] - output[0][argmax11] + margin).clamp(min=0)
#        else:
#            loss = -F.cross_entropy(output, torch.LongTensor([true_label]).cuda())
#        loss.backward()
#
#        x_var.data = torch.clamp(x_var - lr * x_var.grad.data, min=0, max=1)
#        x_var.data = torch.clamp(x_var - img, min=-u_radius, max=u_radius) + img
#        counter += 1
#        if counter >= cap:
#            break
#    return counter

def run_pgd(model,
            img,
            dataset,
            eps=0.031,
            step_size=0.001,
            steps=100):
    """Run an untargeted L-inf PGD attack on ``img`` and return the result.

    The attack pushes the image away from the label the model *currently
    predicts* for it (not a ground-truth label). Each step:

    1. moves ``x`` by ``step_size * sign(grad)`` of the cross-entropy loss
       against that predicted label (gradient ascent);
    2. clips ``x`` to [0, 1];
    3. projects ``x`` back into the L-inf ball of radius ``eps`` around ``img``.

    All ``steps`` iterations are always run. There is no early stop once the
    label flips.

    Args:
        model: classifier, called as ``model(transform(x, dataset=dataset))``.
        img: input image tensor. Assumed to already lie in [0, 1]
            (TODO confirm against the data loader).
        dataset: dataset name forwarded to ``transform``.
        eps: L-inf radius of the allowed perturbation.
        step_size: magnitude of each signed-gradient step.
        steps: number of PGD iterations.

    Returns:
        numpy.ndarray with the perturbed image, same shape as ``img``.
    """
    model.eval()
    # Keep the reference image on the GPU as well, so the projection step
    # never mixes CPU and CUDA tensors.
    ref = img.detach().clone().cuda()
    x = ref.clone().requires_grad_(True)

    with torch.no_grad():
        true_label = model(transform(x, dataset=dataset)).argmax(1)[0].item()
    target = torch.tensor([true_label], dtype=torch.long, device=ref.device)

    for _ in range(steps):
        output = model(transform(x, dataset=dataset))
        loss = F.cross_entropy(output, target)
        # Gradient w.r.t. the input only. Unlike loss.backward(), this does not
        # accumulate .grad into the model's parameters on every step.
        grad, = torch.autograd.grad(loss, x)
        with torch.no_grad():
            # Ascend the loss, clip to the valid pixel range, then project
            # back into the eps-ball around the original image.
            x.add_(step_size * grad.sign()).clamp_(0, 1)
            x.copy_(ref + (x - ref).clamp(-eps, eps))
    return x.detach().cpu().numpy()


# Build the VGG19 classifier. Its feature extractor is wrapped in DataParallel
# before loading weights; presumably this makes the state_dict keys match the
# checkpoint (TODO confirm).
model = vgg19()
model.features = torch.nn.DataParallel(model.features)
model.cuda()
#checkpoint = torch.load("vgg19model/checkpoint_99.tar")#save directory for vgg19 model
checkpoint = torch.load("them_model/model_best.pth.tar")  # save directory for vgg19 model
model.load_state_dict(checkpoint['state_dict'])
model.eval()


def _search_blend(adv_img, base_img, true_label):
    """Scan alpha = 0.1 .. 0.9 along the line from ``base_img`` to ``adv_img``.

    For each alpha, the blend ``adv_img * alpha + base_img * (1 - alpha)`` is
    classified. Blends still predicted as ``true_label`` are skipped. Any other
    blend is scored with ``l1_detection`` and ``targeted_detection``.

    Returns the first blend whose L1 score is below 0.000087 and whose targeted
    count is below 40. Returns ``None`` when no alpha qualifies. Per-alpha
    diagnostics are printed along the way.
    """
    for weight in (.1, .2, .3, .4, .5, .6, .7, .8, .9):
        blend = adv_img * weight + base_img * (1 - weight)
        predicted = model(transform(blend, dataset='cifar')).argmax()
        print(">> ", weight, 'label', predicted)
        if predicted == true_label:
            print("   FAIL")
            continue
        l1_score = l1_detection(model, blend, 'cifar', 0.01)
        targeted_count = targeted_detection(model, blend, 'cifar', .0005, 0.5, cap=100)
        if l1_score < 0.000087 and targeted_count < 40:
            print("   ** SUCCESS")
            return blend
        print('  C1C', l1_score)
        print('  C2C', targeted_count)
    return None


clean = []
advex = []
labs = []

for i in range(50):
    view_data = torch.load("tmp/real_dir/vgg/%d_img.pt" % i)
    view_data_label = torch.load("tmp/real_dir/vgg/%d_label.pt" % i)
    print("True label", view_data_label)

    # Detector readings on the unmodified image.
    clean_c1 = l1_detection(model, view_data, 'cifar', 0.01)
    print('C1A', '%.12f' % clean_c1)
    print('C2A', untargeted_detection(model, view_data, 'cifar', 1, 0.5, cap=100))
    clean_c2 = targeted_detection(model, view_data, 'cifar', 0.0005, 0.5, cap=100)
    #print('C2B', clean_c2)

    adv = torch.tensor(run_pgd(model, view_data, 'cifar')).cuda()
    print('adv label', model(transform(adv, dataset='cifar')).argmax())
    #print('C1B', l1_detection(model, adv, 'cifar', 0.01))
    #print('C2B', targeted_detection(model, adv, 'cifar', .0005, 0.5, cap=100))

    clean.append(view_data.detach().cpu().numpy())
    labs.append(view_data_label.detach().cpu().numpy())

    accepted = _search_blend(adv, view_data, view_data_label)
    # When no blend passes, store the clean image instead. This keeps exactly
    # one entry per image in advex, as in clean and labs.
    chosen = view_data if accepted is None else accepted
    advex.append(chosen.detach().cpu().numpy())

    # The full lists are re-written after every image.
    np.save("/tmp/clean.npy", clean)
    np.save("/tmp/labs.npy", labs)
    np.save("/tmp/adv.npy", advex)
