import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from train_vgg19 import vgg19
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim
import numpy as np
from utils import Noisy,transform,random_label
noisy = Noisy.apply

import torchvision

from detect import l1_detection, untargeted_detection, targeted_detection

def run_pgd(model, img, dataset, steps=100, step_size=0.001, eps=0.009):
    """Run an untargeted L-infinity PGD attack on a single image.

    The label attacked is the model's own prediction on ``img`` (not a
    ground-truth label). Each iteration:
      1. takes a signed-gradient step that increases the cross-entropy of
         that label;
      2. clips the image to [0, 1];
      3. projects the perturbation back into the L-inf ball of radius
         ``eps`` around ``img``.

    Args:
        model: Classifier, called as ``model(transform(x, dataset=dataset))``.
            It is put into eval mode here.
        img: Image tensor. Only the first batch element's prediction is used
            as the label, and the loss target has batch size 1, so a batch of
            one image is assumed. It is moved to the GPU if it is not already
            there.
        dataset: Forwarded unchanged to ``transform`` (from ``utils``).
        steps: Number of PGD iterations (default 100).
        step_size: L-inf size of each signed-gradient step (default 0.001).
        eps: L-inf radius allowed for the total perturbation (default 0.009).

    Returns:
        numpy.ndarray holding the perturbed image, with the same shape as
        ``img``.
    """
    model.eval()
    # Keep img and the working copy on the same device; the projection
    # below subtracts one from the other.
    img = img.detach().cuda()
    x_var = img.clone().requires_grad_(True)

    with torch.no_grad():
        # Clone so that an in-place transform could never touch x_var.
        true_label = model(transform(x_var.clone(), dataset=dataset)).max(1)[1][0].item()
    target = torch.LongTensor([true_label]).cuda()

    for _ in range(steps):
        output = model(transform(x_var, dataset=dataset))
        loss = F.cross_entropy(output, target)
        # Gradient w.r.t. the image only. Unlike loss.backward(), this does not
        # accumulate .grad into the model's parameters.
        grad, = torch.autograd.grad(loss, x_var)
        with torch.no_grad():
            # Ascend the cross-entropy of the original prediction
            # (untargeted attack).
            x_var.copy_(torch.clamp(x_var + step_size * torch.sign(grad), min=0, max=1))
            # Project back into the eps-ball around the original image.
            x_var.copy_(torch.clamp(x_var - img, min=-eps, max=eps) + img)
    return x_var.detach().cpu().numpy()


# Pretrained torchvision Inception v3. transform_input=False turns off the
# model's own input re-normalization, so inputs must already be preprocessed.
# Presumably that is done by utils.transform(..., dataset='imagenet'); TODO confirm.
model = models.inception_v3(pretrained=True, transform_input=False)
# Alternative model: VGG19 from train_vgg19, loaded from a local checkpoint
# after the DataParallel wrap (uncomment vgg19() plus one torch.load line and
# the load_state_dict line below).
#model = vgg19()
# nn.Module.cuda() moves parameters in place, so this single call puts the
# wrapped model on the GPU; no further .cuda() is needed.
model = torch.nn.DataParallel(model).cuda()
#checkpoint = torch.load("vgg19model/checkpoint_99.tar")#save directory for vgg19 model
#checkpoint = torch.load("them_model/model_best.pth.tar")#save directory for vgg19 model
#model.load_state_dict(checkpoint['state_dict'])
# Evaluation mode (dropout / batch-norm in inference behavior).
model.eval()

# Per-image records filled by the evaluation loop and written with np.save.
clean, advex, labs = [], [], []

# Evaluation images from the "imagenetdata" folder. The shorter side is resized
# to 256, then a 224x224 center crop is taken and converted by ToTensor to a
# float tensor in [0, 1]. No mean/std normalization is applied here.
testset = torchvision.datasets.ImageFolder(
    root="imagenetdata",
    transform=torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
        ]
    ),
)

# Sweep 100 images (testset indices 0, 500, ..., 49500; needs at least 49,501
# images). Every image appends exactly one entry to each of clean, labs and
# advex, so the three lists stay index-aligned.
for j, i in enumerate(range(0, 50000, 500)):
    # j: position within this sweep (0..99); i: index into testset.
    view_data, view_data_label = testset[i]
    # Add a batch dimension (C, H, W) -> (1, C, H, W) and move to the GPU.
    view_data = view_data[None, :, :, :].cuda()
    print(view_data.shape)

    clean.append(view_data.detach().cpu().numpy())
    labs.append(view_data_label)

    # Only these sweep positions are attacked; every other image is stored
    # unchanged as its own "adversarial" entry.
    if j not in [16, 55, 64, 90]:
        advex.append(view_data.detach().cpu().numpy())
        continue

    print("True label", view_data_label)
    # Detection scores on the clean image; printed for reference only.
    clean_c1 = l1_detection(model, view_data, 'imagenet', 0.01)
    print('C1A', '%.12f' % clean_c1)

    clean_c2 = targeted_detection(model, view_data, 'imagenet', 0.005, 0.03, cap=100)
    clean_c3 = untargeted_detection(model, view_data, 'imagenet', 3, 0.03, cap=100)
    print('C2A', clean_c2)
    print('C2B', clean_c3)

    print("right?", model(transform(view_data, dataset='imagenet')).argmax(), view_data_label)

    adv = torch.tensor(run_pgd(model, view_data, 'imagenet')).cuda()
    print('adv label', model(transform(adv, dataset='imagenet')).argmax())

    # Binary search over the interpolation weight alpha between the clean
    # image (alpha = 0) and the PGD output (alpha = 1). It looks for a point
    # that is misclassified and passes all three detection thresholds.
    low = 0
    high = 1
    while high - low > .001:
        alpha = (high + low) / 2
        print("ALPHA", alpha, high, low)
        middle = adv * alpha + view_data * (1 - alpha)
        newlab = model(transform(middle, dataset='imagenet')).argmax()
        print(">> ", alpha, 'label', newlab)
        if newlab == view_data_label:
            # Still the true label: move toward the adversarial end.
            print("   FAIL")
            low = alpha
            continue
        c1 = np.max([l1_detection(model, middle, 'imagenet', 0.01) for _ in range(5)])
        c2 = untargeted_detection(model, middle, 'imagenet', .005, 0.03, cap=25)
        c3 = targeted_detection(model, middle, 'imagenet', .005, 0.03, cap=25)

        print('  C1C', c1)
        print('  C2C', c2)
        print('  C3C', c3)

        # c2 or c3 at or above 20: move toward the clean end.
        if c2 >= 20:
            high = alpha
            continue
        if c3 >= 20:
            high = alpha
            continue
        # Success requires c1 < 1.83 strictly. Every other c1 (including
        # exactly 1.83 or NaN) must still move an endpoint, otherwise the
        # search re-tests the same alpha and can loop forever.
        if not c1 < 1.83:
            low = alpha
            continue

        print("   ** SUCCESS")
        advex.append(middle.detach().cpu().numpy())
        break
    else:
        # The interval shrank below 0.001 without success: keep the clean image.
        print("   ** FAILURE")
        advex.append(view_data.detach().cpu().numpy())
# Write the per-image records to /tmp, in the same order as before:
# clean inputs, their labels, then the adversarial (or clean fallback) images.
for out_path, records in (
    ("/tmp/4clean.npy", clean),
    ("/tmp/4labs.npy", labs),
    ("/tmp/4adv.npy", advex),
):
    np.save(out_path, records)
