import numpy as np
import os
import random
import tensorflow as tf
import cv2
import keras.datasets.mnist as mnist
import multiprocessing as mp

# Make only the first GPU visible to CUDA (and therefore TensorFlow).
os.environ.update(CUDA_VISIBLE_DEVICES='0')

def show(img, width=28):
    """Print a coarse ASCII-art rendering of one image, preceded by "START".

    The real part of each pixel is multiplied by 3 and rounded to the nearest
    integer (ties to even), then drawn with the palette ``" .*#"``:
    0 -> ' ', 1 -> '.', 2 -> '*', 3 or more -> '#'. Levels are clipped to
    that range, so negative intensities (which IHT reconstructions can
    produce) render as blank instead of wrapping around to '#' through
    negative indexing.

    Args:
        img: a 2-D array (each row is printed as one line) or any array whose
            flattened length is a multiple of ``width``. Complex input is
            accepted; only its real part is drawn.
        width: characters per printed row when ``img`` is not 2-D.
    """
    palette = " .*#"
    pixels = np.asarray(img).real
    if pixels.ndim == 2:
        width = pixels.shape[1]
    levels = np.clip(np.rint(pixels.flatten() * 3), 0, len(palette) - 1).astype(int)
    print("START")
    for start in range(0, len(levels), width):
        print("".join(palette[level] for level in levels[start:start + width]))

import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Activation
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from keras.preprocessing.image import ImageDataGenerator

import multiprocessing as mp


def make_model(filters=64, s1=5, s2=5, s3=3,
               mp1=True, mp2=True, d1=0, d2=0, fc=256,
               opt=0, lr=1e-3, decay=1e-3):
    """Build the MNIST CNN and a softmax-terminated wrapper around it.

    Layer stack of the inner model (bracketed stages are optional):
        Conv(filters, s1) -> [MaxPool 2x2 if mp1] -> Conv(2*filters, s2)
        -> BatchNorm -> [Conv(2*filters, s3) -> BatchNorm if s3 > 0]
        -> [MaxPool 2x2 if mp2] -> Dropout(d1) -> Flatten
        -> Dense(fc, relu) -> Dropout(d2) -> Dense(10)
    Input shape is fixed at (28, 28, 1).

    Args:
        filters: channel count of the first convolution (later ones use 2x).
        s1, s2, s3: kernel sizes; ``s3 <= 0`` omits the third convolution.
        mp1, mp2: whether to add the first / second 2x2 max-pooling layer.
        d1, d2: dropout rates before Flatten and after the hidden Dense.
        fc: width of the hidden fully connected layer.
        opt: optimizer code. 0 = Adam, 1 = RMSprop, and 2/3/4 = SGD with
            momentum 0.99/0.95/0.9. Any non-integer value (e.g. an optimizer
            instance or name) is passed straight to ``compile``.
        lr: learning rate for the coded optimizers.
        decay: learning-rate decay for the coded optimizers.

    Returns:
        ``(model, final)``. ``model`` ends in a linear ``Dense(10)``, while
        ``final`` is ``model`` followed by a softmax ``Activation``. Both are
        compiled with categorical cross-entropy and accuracy, sharing one
        optimizer instance.

    Raises:
        ValueError: if ``opt`` is an integer code other than 0-4 (previously
            the bare int was handed to ``compile`` and failed obscurely).
    """
    model = Sequential()
    model.add(Conv2D(filters, kernel_size=(s1, s1),
                     activation='relu',
                     input_shape=(28, 28, 1)))
    if mp1:
        model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(filters*2, (s2, s2), activation='relu'))
    model.add(BatchNormalization())
    if s3 > 0:
        model.add(Conv2D(filters*2, (s3, s3), activation='relu'))
        model.add(BatchNormalization())
    if mp2:
        model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(d1))
    model.add(Flatten())
    model.add(Dense(fc, activation='relu'))
    model.add(Dropout(d2))
    # No activation here: the softmax lives in `final` below.
    model.add(Dense(10))

    model.summary()

    if opt == 0:
        optimizer = keras.optimizers.Adam(lr, decay=decay)
    elif opt == 1:
        optimizer = keras.optimizers.RMSprop(lr, decay=decay)
    elif opt in (2, 3, 4):
        momentum = {2: .99, 3: .95, 4: .9}[opt]
        optimizer = keras.optimizers.SGD(lr, momentum=momentum, decay=decay)
    elif isinstance(opt, (int, np.integer)):
        raise ValueError("unsupported optimizer code %r (expected 0-4)" % (opt,))
    else:
        # Optimizer instances / names are forwarded unchanged.
        optimizer = opt

    # NOTE(review): `model` outputs raw Dense scores, yet categorical
    # cross-entropy expects probabilities by default, so train/evaluate
    # through `final` rather than `model`. `model` stays compiled as before
    # because callers save and load its weights.
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=optimizer,
                  metrics=['accuracy'])

    final = Sequential()
    final.add(model)
    final.add(Activation('softmax'))
    final.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=optimizer,
                  metrics=['accuracy'])

    return model, final

class StopEarly(keras.callbacks.Callback):
    """Keras callback that ends training once the fit is essentially perfect.

    After every epoch, training stops if the logged training ``loss`` is
    below ``min_loss`` or the logged training accuracy has reached
    ``target_acc``. Log entries that are missing are ignored.
    """

    def __init__(self, min_loss=1e-3, target_acc=1.0):
        """
        Args:
            min_loss: stop once ``logs['loss']`` falls below this value.
            target_acc: stop once the logged accuracy reaches this value.
        """
        # Name this class, not Callback, so Callback.__init__ actually runs;
        # super(keras.callbacks.Callback, self) skipped straight past it.
        super(StopEarly, self).__init__()
        self.min_loss = min_loss
        self.target_acc = target_acc

    def on_epoch_end(self, epoch, logs=None):
        """Set ``self.model.stop_training`` when either threshold is met."""
        # None default instead of a shared mutable {} default.
        logs = logs or {}
        loss = logs.get('loss')
        if loss is not None and loss < self.min_loss:
            self.model.stop_training = True
        # Older Keras logs the metric as 'acc', newer versions as 'accuracy'.
        acc = logs.get('acc', logs.get('accuracy'))
        if acc is not None and acc >= self.target_acc:
            self.model.stop_training = True

    
def train_model(model, x_train, y_train, batch_size=256,
                epochs=20, data_augmentation=False):
    """Fit ``model`` on ``(x_train, y_train)`` and return the same model.

    Training shuffles every epoch, logs one line per epoch (``verbose=2``)
    and may end early through the ``StopEarly`` callback.

    Args:
        model: a compiled Keras model.
        x_train, y_train: training inputs and one-hot targets.
        batch_size: mini-batch size.
        epochs: maximum number of epochs.
        data_augmentation: must be false. Augmentation is not implemented.

    Returns:
        ``model``, after fitting in place.

    Raises:
        NotImplementedError: if ``data_augmentation`` is true. Previously the
            call silently skipped training and returned an untrained model.
    """
    if data_augmentation:
        raise NotImplementedError(
            "data augmentation is not implemented; "
            "call train_model with data_augmentation=False")
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              shuffle=True,
              verbose=2,
              callbacks=[StopEarly()])
    return model

def IHT(y, iterations=10, freq_cutoff=6, error_cutoff=8):
    """Iteratively split ``y`` into a low-frequency part and a sparse error.

    Starting from a zero error ``e``, each of ``iterations`` rounds does:
      * ``x = fft2(y - e)``, then zeroes every coefficient at index (i, j)
        with ``i + j > freq_cutoff``;
      * ``e = y - ifft2(x)``, then zeroes every pixel with
        ``i + j > error_cutoff``.
    The function returns ``ifft2(x)``.

    NOTE(review): ``np.fft.fft2`` output is not centred, so the ``i + j``
    mask keeps only the low non-negative-frequency corner and drops the
    conjugate negative-frequency coefficients. The result is therefore
    generally complex. The error term is likewise confined to the top-left
    pixels rather than hard-thresholded by magnitude. Confirm both are
    intended before changing them: cached ``x_*_iht.npy`` data was produced
    with exactly this mask.

    Args:
        y: 2-D array, e.g. a 28x28 MNIST digit.
        iterations: number of alternating updates (formerly ``T``).
        freq_cutoff: largest ``i + j`` Fourier index kept in ``x``.
        error_cutoff: largest ``i + j`` pixel index kept in ``e``.

    Returns:
        A complex ndarray with ``y``'s shape. It is all zeros when
        ``iterations`` is 0.
    """
    y = np.asarray(y)
    rows, cols = y.shape
    # index_sum[i, j] == i + j, used for both the frequency and pixel masks.
    index_sum = np.add.outer(np.arange(rows), np.arange(cols))

    x = np.zeros(y.shape)
    e = np.zeros(y.shape)
    for _ in range(iterations):
        x = np.fft.fft2(y - e)
        x[index_sum > freq_cutoff] = 0

        e = y - np.fft.ifft2(x)
        e[index_sum > error_cutoff] = 0

    return np.fft.ifft2(x)

if __name__ == "__main__":
    # The experiment runs only when this file is executed directly. Under the
    # "spawn" start method (Windows, and macOS on Python 3.8+) every Pool
    # worker re-imports this module, so unguarded top-level code would
    # rebuild the model and try to create a new Pool inside each worker.
    RETRAIN = False        # True: recompute the IHT data, retrain, save weights
    N = 1000               # number of corrupted test images to evaluate
    FLIPS_PER_IMAGE = 50   # random pixel inversions applied to each of them

    model, final = make_model()
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    img_rows = img_cols = 28
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols)

    # Scale pixel values from [0, 255] to [0, 1].
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255

    # The pool is created at the same point as before (after data loading,
    # before the weights are restored) and is now shut down on exit.
    with mp.Pool(16) as p:
        if RETRAIN:
            x_train_iht = np.array(p.map(IHT, list(x_train)))
            np.save("x_train_iht.npy", x_train_iht)
            x_test_iht = np.array(p.map(IHT, list(x_test)))
            np.save("x_test_iht.npy", x_test_iht)
            # IHT returns complex arrays; hand Keras the real part explicitly.
            train_model(final, x_train_iht.real.reshape((-1, 28, 28, 1)),
                        keras.utils.to_categorical(y_train, 10))
            model.save("mnist.model")
        x_train_iht = np.load("x_train_iht.npy")
        x_test_iht = np.load("x_test_iht.npy")
        model.load_weights("mnist.model")
        # print(final.evaluate(x_test_iht.real.reshape((-1, 28, 28, 1)),
        #                      keras.utils.to_categorical(y_test, 10)))

        # Corrupt the first N test images by inverting random pixels. The
        # same pixel may be drawn twice, which undoes the earlier flip.
        for j in range(N):
            for _ in range(FLIPS_PER_IMAGE):
                o = (j, random.randint(0, 27), random.randint(0, 27))
                x_test[o] = 1 - x_test[o]

        adv = np.array(p.map(IHT, list(x_test[:N])))

    for i in range(10):
        show(x_test[i])
        show(adv[i])

    print(final.evaluate(adv.real.reshape((-1, 28, 28, 1)),
                         keras.utils.to_categorical(y_test[:N], 10)))
    
# TODO: document IHT's tuning constants (the iteration count and the frequency / error-support cutoffs).
