import argparse
import numpy as np
import os
import torch
#import tensorflow as tf
import tensorflow.compat.v1 as tf
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from setup_cifar import CIFAR, CIFARModel
import warnings
import Utils_CIFAR as util
import random

# Silence every Python warning for the whole run (library deprecation noise included).
warnings.filterwarnings(action='ignore')

# Expose only GPU index 0 to CUDA-backed libraries used by this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Seed shared by the `random`, NumPy and TensorFlow generators in the main script.
SEED = 121

def get_args():
    """Parse the attack hyper-parameters from the command line.

    Recognised options (each shown with its default):
        --iter  (int, 1000)    maximum attack iterations per image.
        --d     (int, 3072)    length of the flattened perturbation vector.
        --k     (int, 60)      number of perturbation entries kept per step.
        --eta   (float, 0.01)  step size applied to the gradient.
        --s2    (int, 100)     forwarded to ``util.compute_gradient``.
        --q     (int, 100)     forwarded to ``util.compute_gradient``.
        --miu   (float, 1e-3)  forwarded to ``util.compute_gradient``.
        --lname (str)          path of the text log the script appends to.

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.
    """
    # (flag, default, type) triples, registered in this order.
    option_specs = (
        ('--iter', 1000, int),
        ('--d', 3072, int),
        ('--k', 60, int),
        ('--eta', 0.01, float),
        ('--s2', 100, int),
        ('--q', 100, int),
        ('--miu', 1e-3, float),
        ('--lname', './CIFAR/result.txt', str),
    )
    cli = argparse.ArgumentParser()
    for flag, default_value, caster in option_specs:
        cli.add_argument(flag, default=default_value, type=caster)
    return cli.parse_args()

def append_log(path, text):
    """Append ``text`` verbatim to the log file at ``path`` (created if absent)."""
    with open(path, 'a+') as f:
        f.write(text)


def project_top_k(perturbation, k):
    """Keep the ``k`` largest-magnitude entries of a 1-D array and zero the rest.

    Args:
        perturbation: 1-D NumPy array.
        k: number of entries to keep; values >= the array length keep everything.

    Returns:
        A new array with the same shape and dtype as ``perturbation``.
    """
    keep_idx = np.argsort(-np.abs(perturbation))[0:k]
    sparse = np.zeros_like(perturbation)
    sparse[keep_idx] = perturbation[keep_idx]
    return sparse


def main():
    """Run the untargeted, L0-constrained attack and log per-image and summary results.

    Test images are drawn at random; images the model already misclassifies
    are skipped. Each remaining image is attacked for at most ``--iter``
    steps, and the run stops once ``image_number`` images have been attacked
    or the candidate pool is exhausted.
    """
    with tf.Session() as sess:
        random.seed(SEED)
        np.random.seed(SEED)
        tf.set_random_seed(SEED)
        args = get_args()

        # Make sure the log file's directory exists before the first append.
        log_dir = os.path.dirname(args.lname)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        image_number = 100
        # Draw 3x as many candidate ids as needed (from the first 1000 test
        # images) so misclassified images can be skipped.
        image_id_set = np.random.choice(range(1000), image_number * 3, replace=False)

        data, model = CIFAR(), CIFARModel('./models/cifar', sess, True)
        print(image_id_set)

        # Ground-truth class per test image; test_labels is reduced with argmax
        # over axis 1. Loop-invariant, so computed once.
        true_label_list = np.argmax(data.test_labels, axis=1)

        # succ_count doubles as the next free slot in both result arrays.
        succ_count, ii, iii = 0, 0, 0
        l2_distortion_collect = np.zeros(image_number)
        attack_succ_count = np.zeros(image_number)

        # ii walks the candidate pool, iii counts images actually attacked.
        while iii < image_number and ii < len(image_id_set):
            image_id = image_id_set[ii]
            ii = ii + 1
            orig_prob, orig_class, orig_prob_str = util.model_prediction(
                model, np.expand_dims(data.test_data[image_id], axis=0))  # orig_class: predicted label
            # untargeted attack
            target_label = orig_class
            orig_img, target = util.generate_data(data, image_id, target_label)
            true_label = true_label_list[image_id]

            header = "Image ID:{}, infer label:{}, true label:{}".format(image_id, orig_class, true_label)
            append_log(args.lname, "\n " + header + " \n")
            print(header)
            if true_label != orig_class:
                # Already misclassified: nothing to attack, try the next candidate.
                append_log(args.lname, "True Label is different from the original prediction, pass!\n")
                print("True Label is different from the original prediction, pass!")
                continue
            iii = iii + 1

            append_log(args.lname, '\n' + str(iii) + '/' + str(image_number) + '\n')
            print('\n', iii, '/', image_number)

            adv_image = orig_img
            attack_succ = None  # 1-based iteration of the first success, if any
            for i in range(args.iter):
                gradient = util.compute_gradient(model, adv_image, true_label, args.s2, args.miu, args.q)
                adv_image = adv_image - args.eta * gradient
                # Enforce ||delta||_0 <= k by keeping the k largest-magnitude entries.
                delta_flat = np.reshape(adv_image - orig_img, (args.d))
                delta = project_top_k(delta_flat, args.k)
                l2_dist = np.linalg.norm(delta, ord=2, keepdims=False)
                l0_dist = np.count_nonzero(delta) / args.d

                delta = np.reshape(delta, (1, 32, 32, 3))
                adv_image = np.clip(orig_img + delta, -0.5, 0.5)
                _, attack_predict_class, _ = util.model_prediction(model, adv_image)

                # The attack succeeds as soon as the prediction leaves the true label.
                succeeded = true_label != attack_predict_class
                status = "Succ" if succeeded else "Fail"
                line = "Iter %d (%s): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, TL = %d, PL = %d" % (
                    i + 1, status, image_id, l0_dist, l2_dist, true_label, attack_predict_class)
                append_log(args.lname, line + " \n")
                print(line)
                if succeeded:
                    attack_succ = i + 1
                    l2_distortion_collect[succ_count] = l2_dist
                    break

            if attack_succ is not None:
                attack_succ_count[succ_count] = attack_succ
                succ_count = succ_count + 1
                line = "It takes {} iterations to find the first attack".format(attack_succ)
                append_log(args.lname, line + " \n")
                print(line)
            else:
                append_log(args.lname, "Attack Fails\n")
                print("Attack Fails")

        if iii < image_number:
            # The candidate pool ran out; report on what was actually attacked.
            line = "Only {} of {} requested images could be attacked (candidate pool exhausted)".format(
                iii, image_number)
            append_log(args.lname, line + "\n")
            print(line)

        # Averages are over successful attacks only; NaN when there were none.
        nan = float("nan")
        l2_dist_avg = np.sum(l2_distortion_collect) / succ_count if succ_count else nan
        attack_succ_count_avg = np.sum(attack_succ_count) / succ_count if succ_count else nan
        succ_rate = succ_count / iii if iii else nan
        summary = "succ rate: %3.5f, l2_dist_avg: %3.5f, attack_succ_avg: %3.5f  \n" % (
            succ_rate, l2_dist_avg, attack_succ_count_avg)
        print(summary)
        print(l2_distortion_collect)
        print(attack_succ_count)
        append_log(args.lname, summary)


if __name__ == "__main__":
    main()


