import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pickle
import random
import tensorflow as tf
from predict_spectrum import CK
from tensorflow import keras
from tensorflow.keras import initializers

# Global matplotlib font size for every figure produced by this script.
font = {'size': 15}
matplotlib.rc('font', **font)
# Seed every random number generator the script draws from. Python's own
# `random` module must be seeded as well: it drives the random.shuffle calls
# that order the training and test samples, so leaving it unseeded makes the
# sample order (and everything downstream) differ from run to run.
random.seed(123)
np.random.seed(321)
tf.random.set_seed(456)

# Define the nonlinear activation function and its derivative
# c rescales the centred sigmoid. NOTE(review): presumably chosen so that the
# activation has unit second moment under a standard Gaussian input -- confirm
# the origin of the numeric constant.
c = np.sqrt(0.21747 / (2 * np.sqrt(2 * np.pi)))
def activ_function(x):
  """Hidden-layer activation: the sigmoid centred at zero and divided by c."""
  return (keras.backend.sigmoid(x) - 0.5) / c
# b is passed to predict_spectrum.CK when the theoretical spectrum is computed.
# a is not referenced anywhere else in the visible part of this script.
b = 2 * 0.258961 / np.sqrt(2 * np.pi) / 0.208276
a = 2 * 0.0561939 / (np.sqrt(2 * np.pi) * (c ** 2))
def dactiv_function(x):
  """Derivative of activ_function, evaluated with NumPy.

  Mathematically this is exp(-x) / ((1 + exp(-x)) ** 2 * c). That expression
  is even in x, so it is evaluated at -|x|: the exponential then never
  overflows, whereas the direct form yields inf / inf = nan for large
  negative x.

  x may be a scalar or a NumPy array; the result has the same shape.
  """
  decay = np.exp(-np.abs(x))
  return decay / (((1 + decay) ** 2) * c)

# Setting parameters
# Number of training samples; used for label reshapes, zero-padding of spectra and output file names.
sample_dim = 10000
# NOTE(review): presumably the number of test samples; not referenced elsewhere in the visible script.
test_sample_dim = 2000
# Width of the network input (shape of the Keras Input layer).
input_dim = 3072
# Number of units in each hidden Dense layer.
weight_dim = 1000
# Step size handed to the Adam optimizer of the main network.
learning_rate = 0.01
# Mini-batch size passed to model.fit.
batch = 128
# Number of training epochs; also the index of the last recorded snapshot.
epoch = 60
# Factor applied to every hidden layer's output by the Lambda layers.
t_n = 1.0 / np.sqrt(weight_dim)
# Index of the layer whose representation is saved and analysed (X[L]); also passed to CK.
L = 4
# Number of leading singular components subtracted from the data matrices (0 disables the step).
remove_PCs = 5

# Generate the training data/labels from Cifar-10
# NOTE: pickle.load can execute arbitrary code while unpickling; only point
# this at CIFAR-10 batch files obtained from a trusted source.
Xcifar = []
Ycifar = []
for j in range(1, 6):
  with open('cifar-10-batches-py/data_batch_%d' % j, 'rb') as fo:
    dat = pickle.load(fo, encoding='bytes')
  Xcifar.append(np.transpose(dat[b'data']))
  Ycifar.append(np.transpose(dat[b'labels']))
# After stacking, every column of Xcifar is one sample (train_data below is
# the transpose that is fed to the network).
Xcifar = np.hstack(Xcifar)
Ycifar = np.hstack(Ycifar)
# Keep only the samples labelled 0 or 1. flatnonzero takes however many
# samples each class has instead of insisting on a hard-coded count.
inds_0 = np.flatnonzero(Ycifar == 0).tolist()
inds_1 = np.flatnonzero(Ycifar == 1).tolist()
inds = inds_1 + inds_0
random.shuffle(inds)
X0 = Xcifar[:, inds]
Y0 = Ycifar[inds]
d0 = X0.shape[0]
# Standardise every sample (column): zero mean, then norm 1 via std * sqrt(d0).
X0 = X0 - X0.mean(axis=0)
X0 = X0 / X0.std(axis=0) / np.sqrt(d0)
if remove_PCs > 0:
  # Only the leading remove_PCs singular triplets are used, so request the
  # reduced SVD: the default full_matrices=True would also build a square vh
  # whose side is the number of samples, which is never read.
  u, s, vh = np.linalg.svd(X0, full_matrices=False)
  for i in range(remove_PCs):
    X0 -= s[i] * np.outer(u[:, i], vh[i, :])
# Re-standardise, since removing components changes the column means/scales.
X0 = X0 - X0.mean(axis=0)
X0 = X0 / X0.std(axis=0) / np.sqrt(d0)
train_data = np.transpose(X0)
train_labels = np.array(Y0)

# Generate the test data/labels from Cifar-10
# NOTE: pickle.load can execute arbitrary code while unpickling; only point
# this at a CIFAR-10 test batch obtained from a trusted source.
with open('cifar-10-batches-py/test_batch', 'rb') as fo:
  dat = pickle.load(fo, encoding='bytes')
# Every column of Xcifar_test is one sample (test_data below is the transpose).
Xcifar_test = np.transpose(dat[b'data'])
Ycifar_test = np.asarray(dat[b'labels'])
# Keep only the samples labelled 0 or 1. flatnonzero takes however many
# samples each class has instead of insisting on a hard-coded count.
inds_0 = np.flatnonzero(Ycifar_test == 0).tolist()
inds_1 = np.flatnonzero(Ycifar_test == 1).tolist()
inds = inds_1 + inds_0
random.shuffle(inds)
X0_test = Xcifar_test[:, inds]
Y0_test = Ycifar_test[inds]
d0 = X0_test.shape[0]
# Standardise every sample (column): zero mean, then norm 1 via std * sqrt(d0).
X0_test = X0_test - X0_test.mean(axis=0)
X0_test = X0_test / X0_test.std(axis=0) / np.sqrt(d0)
if remove_PCs > 0:
  # Only the leading remove_PCs singular triplets are used, so request the
  # reduced SVD rather than the default full one with its unused square vh.
  # The components removed here are computed from the test matrix itself.
  u, s, vh = np.linalg.svd(X0_test, full_matrices=False)
  for i in range(remove_PCs):
    X0_test -= s[i] * np.outer(u[:, i], vh[i, :])
# Re-standardise, since removing components changes the column means/scales.
X0_test = X0_test - X0_test.mean(axis=0)
X0_test = X0_test / X0_test.std(axis=0) / np.sqrt(d0)
test_data = np.transpose(X0_test)
test_labels = np.array(Y0_test)

# Create the neural network model
def _normal_init():
  """Return a fresh N(0, 1) kernel initializer (one new instance per layer)."""
  return initializers.RandomNormal(mean=0.0, stddev=1.0)

def _hidden_block(tensor):
  """Apply one hidden Dense layer to `tensor`.

  Returns the pair (raw Dense output, that output multiplied by t_n).
  """
  raw = tf.keras.layers.Dense(weight_dim, activation=activ_function, use_bias=True, kernel_initializer=_normal_init(), bias_initializer='zeros')(tensor)
  scaled = keras.layers.Lambda(lambda x: x * t_n)(raw)
  return raw, scaled

inputs = tf.keras.Input(shape=(input_dim,))
# Four hidden blocks stacked in sequence. Each layer_k is a sub-model mapping
# the network input to the scaled output of hidden block k, so intermediate
# representations can be read out with layer_k.predict(...).
x_10, x_1 = _hidden_block(inputs)
layer_1 = keras.models.Model(inputs, x_1, name="first_layer")
x_20, x_2 = _hidden_block(x_1)
layer_2 = keras.models.Model(inputs, x_2, name="second_layer")
x_30, x_3 = _hidden_block(x_2)
layer_3 = keras.models.Model(inputs, x_3, name="third_layer")
x_40, x_4 = _hidden_block(x_3)
layer_4 = keras.models.Model(inputs, x_4, name="forth_layer")
# Single sigmoid unit on top of the last hidden block: binary classifier head.
outputs = tf.keras.layers.Dense(1, activation='sigmoid', use_bias=True, kernel_initializer=_normal_init(), bias_initializer='zeros')(x_4)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()
optimizer = tf.keras.optimizers.Adam(learning_rate)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

# Record the layer outputs and the weights throughout the training process
X_0, Y_0 = train_data, train_labels
# X_k collects the output of hidden block k on the training data; entry 0 is
# the snapshot taken here, before any training step has run.
X_1, X_2, X_3, X_4 = (
  [sub_model.predict(train_data)]
  for sub_model in (layer_1, layer_2, layer_3, layer_4)
)
# One model.get_weights() snapshot is appended to Ws after every epoch.
Ws = []

class MyCallback(keras.callbacks.Callback):
  """Keras callback that snapshots layer outputs and weights after each epoch.

  After every epoch the output of each of the four hidden sub-models on the
  full training set is appended to the module-level lists X_1..X_4, and the
  complete weight list of `model` is appended to Ws.

  NOTE: every snapshot is kept in memory for the whole run, so memory use
  grows linearly with the number of epochs.
  """
  def on_epoch_end(self, epoch, logs=None):
    """Record one snapshot; called by Keras at the end of every epoch.

    Args:
      epoch: index of the epoch that just finished (unused; note that it
        shadows the module-level `epoch` inside this method).
      logs: metric dict supplied by Keras (unused). Defaults to None rather
        than a shared mutable `{}`.
    """
    X_1.append(layer_1.predict(train_data))
    X_2.append(layer_2.predict(train_data))
    X_3.append(layer_3.predict(train_data))
    X_4.append(layer_4.predict(train_data))
    Ws.append(model.get_weights())

# Train the network, recording a snapshot after every epoch via the callback.
callback = MyCallback()
history = model.fit(train_data, train_labels, epochs=epoch, batch_size=batch, validation_data=(test_data, test_labels), verbose=False, callbacks=[callback])
# evaluate() returns [loss, accuracy] because the model was compiled with
# metrics=['accuracy'].
score = model.evaluate(test_data, test_labels, verbose=0)
print('The test accuracy: %f' % score[1])

# Training / validation accuracy per epoch.
for series in ('accuracy', 'val_accuracy'):
  plt.plot(history.history[series])
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='lower right')
# Single title call: an earlier plt.title('model accuracy') was always
# overwritten by this one and has been dropped.
plt.title('Test Accuracy=%f L=%d n=%d d=%d' % (score[1], L, sample_dim, weight_dim))
# Save before show(), so the figure is written before the window is opened.
plt.savefig('Accuracy L=%d n=%d d=%d' % (L, sample_dim, weight_dim))
plt.show()

def cross(X):
  """Return the Gram matrix of the rows of X, i.e. the product of X with its transpose."""
  mat = np.asanyarray(X)
  return mat.dot(mat.T)

# Define the Histogram of eigenvalues
def histeig(eigs, ax, xgrid=None, dgrid=None, bins=100, xlim=None, ylim=None, title=None, scale='linear'):
  """Draw a histogram of eigenvalues on `ax`, optionally with a density curve.

  Args:
    eigs: array of eigenvalues (indexed with an index array when xlim is set).
    ax: axes-like object providing hist/plot/legend/set_* methods.
    xgrid, dgrid: if xgrid is given, dgrid is drawn over it in red, scaled
      by (number of plotted eigenvalues) * (bin width) to match the counts.
    bins: number of bins, or number of bin edges between the limits when
      xlim is given.
    xlim: optional (low, high); eigenvalues above `high` are dropped.
    ylim: optional y-limits; defaults to [0, 1.5 * tallest bin].
    title: axes title.
    scale: scale name applied to both the x and the y axis.

  Returns:
    The same `ax` that was passed in.
  """
  if xlim is None:
    hist_out = ax.hist(eigs, bins=bins)
  else:
    low, high = xlim[0], xlim[1]
    eigs = eigs[np.where(eigs <= high)[0]]
    hist_out = ax.hist(eigs, bins=np.linspace(low, high, num=bins))
  counts, edges = hist_out[0], hist_out[1]

  if xgrid is not None:
    bin_width = edges[1] - edges[0]
    ax.plot(xgrid, dgrid * len(eigs) * bin_width, 'r', label='Theoretical spectrum', linewidth=1.5)
    ax.legend()

  ax.set_title(title, fontsize=15)
  ax.set_xscale(scale)
  ax.set_yscale(scale)
  ax.set_xlim(xlim)
  # Leave headroom above the tallest bar unless the caller fixed the range.
  ax.set_ylim([0, max(counts) * 1.5] if ylim is None else ylim)
  return ax

# Save matrices
# X[0] is the training data; X[k] for k >= 1 is the list of recorded outputs
# of hidden block k (index 0 = before training, index `epoch` = after the
# last epoch).
X = [X_0, X_1, X_2, X_3, X_4]
run_tag = (L, sample_dim, weight_dim)
# (file-name template, object to pickle), written in this order.
pickle_jobs = [
  ('Initial_Layers_L=%d_n=%d_d=%d.pickle', X[L][0]),
  ('Trained_Layers_L=%d_n=%d_d=%d.pickle', X[L][epoch]),
  ('Train_labels_L=%d_n=%d_d=%d.pickle', Y_0),
  ('Weights_L=%d_n=%d_d=%d.pickle', Ws[epoch - 1]),
  ('Train_data_L=%d_n=%d_d=%d.pickle', X_0),
  ('Test_data_L=%d_n=%d_d=%d.pickle', test_data),
  ('Test_labels_L=%d_n=%d_d=%d.pickle', test_labels),
]
for name_template, payload in pickle_jobs:
  with open(name_template % run_tag, 'wb') as f:
    pickle.dump(payload, f)

# PCA analysis for the last layer
pcs = 2  # number of leading principal components kept
area = np.pi * 0.8  # marker size used by the scatter plots

# x-grids on which the reference sigmoid curve is drawn (initial / trained).
x0 = np.linspace(-1, 1, 100)
x1 = np.linspace(-7, 7, 200)

# The matrices decomposed here are symmetric (X^T X), so use eigh: it returns
# real eigenvalues in ascending order, whereas np.linalg.eig guarantees no
# ordering at all and may return complex output. Reverse the order so that the
# first `pcs` columns really are the leading principal directions.
v0, w0 = np.linalg.eigh(cross(np.transpose(X[L][0])))
v0, w0 = v0[::-1], w0[:, ::-1]
v1, w1 = np.linalg.eigh(cross(np.transpose(X[L][epoch])))
v1, w1 = v1[::-1], w1[:, ::-1]
V0 = w0[:, :pcs]
V1 = w1[:, :pcs]
# Coordinates of every sample along the leading directions.
U0 = np.matmul(X[L][0], V0)
U1 = np.matmul(X[L][epoch], V1)
# Leading singular values; not used further in the visible script.
S0 = np.diag(np.sqrt(v0[:pcs]))
S1 = np.diag(np.sqrt(v1[:pcs]))
# Representations projected onto the span of the top `pcs` directions
# (initial and trained); same shape as X[L][0] / X[L][epoch].
X0 = np.matmul(U0, np.transpose(V0))
X1 = np.matmul(U1, np.transpose(V1))

# Fit a logistic-regression head (one sigmoid unit) on the representations
# projected onto the top principal components, first for the initial network
# (X0) and then for the trained one (X1).
# Labels as a column vector, one row per training sample.
train_labels = train_labels.reshape((sample_dim, 1))
inputs = tf.keras.Input(shape=(weight_dim,))
outputs = keras.layers.Dense(1, activation='sigmoid', use_bias=True, kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=1.0), bias_initializer='zeros')(inputs)
model_pcs = tf.keras.Model(inputs=inputs, outputs=outputs)
model_pcs.summary()
optimizer = tf.keras.optimizers.Adam(0.05)
model_pcs.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

# Fit on the projected initial representation; w0 = [kernel, bias].
history0 = model_pcs.fit(X0, train_labels, epochs=20, batch_size=200, verbose=False)
w0 = model_pcs.get_weights()
score0 = model_pcs.evaluate(X0, train_labels, verbose=0)
# Pre-sigmoid linear output for every sample, laid out as a single row.
y0 = np.matmul(np.transpose(w0[0]), np.transpose(X0)) + w0[1]
print('Initial accuracy: %f' % score0[1])

# NOTE(review): the same model_pcs instance is fitted again here, so this
# second fit starts from the weights (and optimizer state) left by the fit on
# X0 rather than from a fresh initialisation -- confirm this is intended.
history1 = model_pcs.fit(X1, train_labels, epochs=20, batch_size=200, verbose=False)
w1 = model_pcs.get_weights()
score1 = model_pcs.evaluate(X1, train_labels, verbose=0)
# Pre-sigmoid linear output for every sample, laid out as a single row.
y1 = np.matmul(np.transpose(w1[0]), np.transpose(X1)) + w1[1]
print('Trained accuracy: %f' % score1[1])

# Accuracy curves of the two logistic fits (initial first, then trained).
f, ax = plt.subplots(1, 1)
for fit_history in (history0, history1):
  ax.plot(fit_history.history['accuracy'])
plt.show()

# Labels as a single row, the same layout as y0 / y1.
train_labels = train_labels.reshape((1, sample_dim))

# One scatter panel per network state:
# (linear predictions, sigmoid x-grid, title template, output file name).
pcs_tag = (pcs, L, sample_dim, weight_dim)
pcs_panels = (
  (y0, x0, 'Top %d Pcs of initial CK, layer 4', 'Initial Pcs=%d L=%d n=%d d=%d' % pcs_tag),
  (y1, x1, 'Top %d Pcs of trained CK, layer 4', 'Trained Pcs=%d L=%d n=%d d=%d at epoch=%d' % (pcs_tag + (epoch,))),
)
for linear_pred, sig_grid, heading, out_name in pcs_panels:
  f, ax = plt.subplots(1, 1)
  ax.scatter(linear_pred, train_labels, s=area, alpha=1)
  # Reference sigmoid curve plus the 0.5 / 0 decision guides.
  ax.plot(sig_grid, 1 / (1 + np.exp(-sig_grid)), '--', linewidth=0.9, label='Sigmoid')
  ax.axhline(y=0.5, color='k', linewidth=0.9)
  ax.axvline(x=0, color='k', linewidth=0.9)
  ax.set_title(heading % pcs)
  ax.set_xlabel('Best linear prediction with top %d Pcs' % pcs, fontsize=12)
  ax.set_ylabel('Training label y', fontsize=12)
  ax.legend(loc='center right')
  f.savefig(out_name)
plt.show()

# Plot the CK spectrum before and after training
np.random.seed(421)
tf.random.set_seed(456)
bins = 80
ylim = [0, 20]

# Eigenvalues of the input Gram matrix, zero-padded to sample_dim entries.
spec = np.pad(np.linalg.eigvalsh(cross(X[0].T)), (0, int(sample_dim - input_dim)), 'constant')
# Ratios of sample count to layer width: the input layer first, then the
# four hidden layers.
gamma = np.array([sample_dim / input_dim] + 4 * [sample_dim / weight_dim])
eigs_initial = np.linalg.eigvalsh(cross(X[L][0].T))
top_eig = max(eigs_initial)
# Grid reaching slightly below zero and 20% beyond the largest eigenvalue.
xgrid = np.linspace(top_eig * (-0.05), top_eig * 1.2, num=1000)
dgrid = CK(L, gamma[:L], b, xgrid, spec=spec)

# (snapshot index, curve x, curve y, title, output file); only the initial
# spectrum is overlaid with the curve returned by CK.
ck_tag = (L, sample_dim, weight_dim)
ck_panels = (
  (0, xgrid, dgrid, 'Initial CK spectrum, layer 4', 'Initial CK L=%d n=%d d=%d.png' % ck_tag),
  (epoch, None, None, 'Trained CK spectrum, layer 4', 'Epoch=%d_CK L=%d n=%d d=%d.png' % ((epoch,) + ck_tag)),
)
for k, curve_x, curve_y, heading, out_name in ck_panels:
  eigs = np.linalg.eigvalsh(cross(X[L][k].T))
  # Zero-pad to sample_dim entries.
  eigs = np.pad(eigs, (0, int(sample_dim - weight_dim)), 'constant')
  f, ax = plt.subplots(1, 1)
  histeig(eigs, ax, xgrid=curve_x, dgrid=curve_y, bins=bins, ylim=ylim, title=heading)
  f.savefig(out_name)
plt.show()

