{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/users/u6537967/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.layers import Dense, Activation, Dropout, Flatten, Reshape, MaxPooling1D, GlobalAveragePooling1D\n",
    "from keras.models import Model, Sequential\n",
    "from keras.regularizers import l2\n",
    "from keras.optimizers import Adam\n",
    "from keras.callbacks import ModelCheckpoint\n",
    "import keras.backend as K\n",
    "from sklearn.model_selection import train_test_split\n",
    "from keras import optimizers\n",
    "from keras import regularizers\n",
    "from keras import initializers\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "from utils import *\n",
    "from keras_dgl.layers import GraphCNN\n",
    "from keras.layers.normalization import BatchNormalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Where the best weights are written, relative to the notebook's working directory\n",
    "# (previously an absolute, user-specific path); adjust if running from elsewhere.\n",
    "model_check_point_loc = 'node_model.h5'\n",
    "# monitor/mode only take effect together with save_best_only=True; without it Keras\n",
    "# overwrites the file after every epoch regardless of val_acc.\n",
    "# NOTE(review): val_acc must be present in the fit logs (validation data passed to fit),\n",
    "# otherwise Keras warns and skips saving.\n",
    "model_checkpoint = ModelCheckpoint(model_check_point_loc, monitor='val_acc', verbose=0,\n",
    "                                   save_best_only=True, save_weights_only=True, mode='max')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading cora dataset...\n",
      "Dataset has 2708 nodes, 5429 edges, 1433 features.\n"
     ]
    }
   ],
   "source": [
    "# Features X, adjacency A (sparse, densified below) and labels Y for Cora.\n",
    "X, A, Y = load_data(dataset='cora')\n",
    "# np.asarray drops the np.matrix subclass returned by todense(), giving a plain ndarray.\n",
    "A = np.asarray(A.todense())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.matrixlib.defmatrix.matrix'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n"
     ]
    }
   ],
   "source": [
    "# Container types of the features, adjacency and labels (one per line).\n",
    "print(type(X), type(A), type(Y), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2708, 1433)\n",
      "(2708, 2708)\n",
      "(2708, 7)\n"
     ]
    }
   ],
   "source": [
    "# Shapes of the feature, adjacency and label matrices (one per line).\n",
    "print(X.shape, A.shape, Y.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, Y_val, _, train_idx, val_idx, test_idx, train_mask = get_splits(Y)\n",
    "train_idx = np.array(train_idx)\n",
    "val_idx = np.array(val_idx)\n",
    "test_idx = np.array(test_idx)\n",
    "# Class ids shifted to 1..C so that 0 can mark nodes outside a split in the\n",
    "# masked label vectors built below.\n",
    "labels = np.argmax(Y, axis=1) + 1\n",
    "\n",
    "\n",
    "def row_normalize(features):\n",
    "    '''Return features as an ndarray whose rows each sum to 1.\n",
    "\n",
    "    All-zero rows are left as zeros instead of being divided by 0 (which gives NaN).\n",
    "    The input is not modified.\n",
    "    '''\n",
    "    features = np.asarray(features)\n",
    "    row_sums = features.sum(axis=1, keepdims=True)\n",
    "    row_sums[row_sums == 0] = 1\n",
    "    return features / row_sums\n",
    "\n",
    "\n",
    "def masked_targets(idx):\n",
    "    '''Copy the rows of Y and the entries of labels at idx; every other entry stays 0.'''\n",
    "    Y_masked = np.zeros(Y.shape)\n",
    "    labels_masked = np.zeros(labels.shape)\n",
    "    Y_masked[idx] = Y[idx]\n",
    "    labels_masked[idx] = labels[idx]\n",
    "    return Y_masked, labels_masked\n",
    "\n",
    "\n",
    "X = row_normalize(X)\n",
    "Y_train, labels_train = masked_targets(train_idx)\n",
    "Y_test, labels_test = masked_targets(test_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2708, 7)\n",
      "(2708,)\n",
      "(2708,)\n",
      "(2708, 7)\n"
     ]
    }
   ],
   "source": [
    "# Masked targets keep the full node dimension (one shape per line).\n",
    "print(Y_train.shape, labels_train.shape, labels_test.shape, Y_test.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n"
     ]
    }
   ],
   "source": [
    "# Container types of the masked training targets (one per line).\n",
    "print(type(Y_train), type(labels_train), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2708, 2708)\n",
      "<class 'numpy.ndarray'>\n"
     ]
    }
   ],
   "source": [
    "# Normalized adjacency from the utils helper, kept under its own name: the next cell\n",
    "# builds the A_norm that the model actually uses, so assigning A_norm here was dead\n",
    "# code (and clobbered the model's adjacency if this cell was re-run afterwards).\n",
    "# NOTE(review): the True argument presumably selects symmetric normalization -- confirm in utils.\n",
    "A_norm_sym = preprocess_adj_numpy(A, True)\n",
    "print(A_norm_sym.shape)\n",
    "print(type(A_norm_sym))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random-walk normalization with self-loops: A_norm = D_hat^-1 (A + I), where D_hat is\n",
    "# the diagonal matrix of row sums of A + I, so (for non-negative A) every row sums to 1.\n",
    "A_hat = A + np.eye(A.shape[0])\n",
    "deg_hat = A_hat.sum(axis=1, keepdims=True)\n",
    "\n",
    "# Divide by the degrees directly instead of via np.matrix: D_hat**-1 on an np.matrix is a\n",
    "# dense N x N matrix inverse, and column sums (axis=0) equal row sums only when A is symmetric.\n",
    "A_norm = A_hat / deg_hat\n",
    "\n",
    "print(A_norm.shape)\n",
    "print(type(A_norm))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.ndarray'>\n",
      "(2708, 2708)\n"
     ]
    }
   ],
   "source": [
    "# A single graph filter: the normalized adjacency matrix itself.\n",
    "num_filters = 1\n",
    "graph_conv_filters = np.asarray(A_norm)\n",
    "print(type(graph_conv_filters), graph_conv_filters.shape, sep='\\n')\n",
    "\n",
    "# Wrap it as a constant backend tensor for the GraphCNN layers.\n",
    "graph_conv_filters = K.constant(graph_conv_filters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two GraphCNN layers (16 then 32 outputs per node) with dropout, then a softmax output layer.\n",
    "model = Sequential()\n",
    "model.add(GraphCNN(16, num_filters, graph_conv_filters, input_shape=(X.shape[1],), kernel_regularizer=l2(5e-4)))\n",
    "model.add(Activation('relu'))\n",
    "model.add(Dropout(0.2))\n",
    "\n",
    "model.add(GraphCNN(32, num_filters, graph_conv_filters, kernel_regularizer=l2(5e-4)))\n",
    "model.add(Activation('elu'))\n",
    "model.add(Dropout(0.2))\n",
    "\n",
    "# One output unit per class, taken from the label matrix instead of a hard-coded 7.\n",
    "# The former l2(0) kernel/activity regularizers added nothing to the loss and are omitted.\n",
    "model.add(Dense(Y.shape[1], kernel_initializer=initializers.he_normal(seed=None),\n",
    "                use_bias=False, activation='softmax', name='fc_4'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "graph_cnn_1 (GraphCNN)       (None, 16)                22944     \n",
      "_________________________________________________________________\n",
      "activation_1 (Activation)    (None, 16)                0         \n",
      "_________________________________________________________________\n",
      "dropout_1 (Dropout)          (None, 16)                0         \n",
      "_________________________________________________________________\n",
      "graph_cnn_2 (GraphCNN)       (None, 32)                544       \n",
      "_________________________________________________________________\n",
      "activation_2 (Activation)    (None, 32)                0         \n",
      "_________________________________________________________________\n",
      "dropout_2 (Dropout)          (None, 32)                0         \n",
      "_________________________________________________________________\n",
      "fc_4 (Dense)                 (None, 7)                 224       \n",
      "=================================================================\n",
      "Total params: 23,712\n",
      "Trainable params: 23,712\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "(1433,)\n"
     ]
    }
   ],
   "source": [
    "model.summary()\n",
    "# Per-node input shape and the shape of the masked training targets (one per line).\n",
    "print((X.shape[1],), Y_train.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam (learning rate 0.01) on categorical cross-entropy, reporting accuracy as 'acc'.\n",
    "model.compile(\n",
    "    optimizer=Adam(lr=0.01),\n",
    "    loss='categorical_crossentropy',\n",
    "    metrics=['acc'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0000 train_acc= 0.2929 test_acc= 0.3090\n",
      "Epoch: 0001 train_acc= 0.2929 test_acc= 0.3090\n",
      "Epoch: 0002 train_acc= 0.2929 test_acc= 0.3090\n",
      "Epoch: 0003 train_acc= 0.2929 test_acc= 0.3090\n",
      "Epoch: 0004 train_acc= 0.2929 test_acc= 0.3090\n",
      "Epoch: 0005 train_acc= 0.2929 test_acc= 0.3090\n",
      "Epoch: 0006 train_acc= 0.2929 test_acc= 0.3090\n",
      "Epoch: 0007 train_acc= 0.2929 test_acc= 0.3090\n",
      "Epoch: 0008 train_acc= 0.2929 test_acc= 0.3090\n",
      "Epoch: 0009 train_acc= 0.2929 test_acc= 0.3090\n",
      "Epoch: 0010 train_acc= 0.2929 test_acc= 0.3090\n",
      "Epoch: 0011 train_acc= 0.3429 test_acc= 0.3140\n",
      "Epoch: 0012 train_acc= 0.4429 test_acc= 0.3370\n",
      "Epoch: 0013 train_acc= 0.4786 test_acc= 0.3580\n",
      "Epoch: 0014 train_acc= 0.5214 test_acc= 0.3780\n",
      "Epoch: 0015 train_acc= 0.5357 test_acc= 0.3810\n",
      "Epoch: 0016 train_acc= 0.5429 test_acc= 0.3780\n",
      "Epoch: 0017 train_acc= 0.5429 test_acc= 0.3750\n",
      "Epoch: 0018 train_acc= 0.5571 test_acc= 0.3810\n",
      "Epoch: 0019 train_acc= 0.5929 test_acc= 0.3970\n",
      "Epoch: 0020 train_acc= 0.6500 test_acc= 0.4340\n",
      "Epoch: 0021 train_acc= 0.6929 test_acc= 0.4690\n",
      "Epoch: 0022 train_acc= 0.7071 test_acc= 0.5110\n",
      "Epoch: 0023 train_acc= 0.7286 test_acc= 0.5330\n",
      "Epoch: 0024 train_acc= 0.7286 test_acc= 0.5470\n",
      "Epoch: 0025 train_acc= 0.7286 test_acc= 0.5530\n",
      "Epoch: 0026 train_acc= 0.7286 test_acc= 0.5560\n",
      "Epoch: 0027 train_acc= 0.7214 test_acc= 0.5460\n",
      "Epoch: 0028 train_acc= 0.7214 test_acc= 0.5500\n",
      "Epoch: 0029 train_acc= 0.7571 test_acc= 0.5620\n",
      "Epoch: 0030 train_acc= 0.7786 test_acc= 0.5750\n",
      "Epoch: 0031 train_acc= 0.7929 test_acc= 0.5980\n",
      "Epoch: 0032 train_acc= 0.8429 test_acc= 0.6240\n",
      "Epoch: 0033 train_acc= 0.8643 test_acc= 0.6570\n",
      "Epoch: 0034 train_acc= 0.8786 test_acc= 0.6900\n",
      "Epoch: 0035 train_acc= 0.8786 test_acc= 0.7050\n",
      "Epoch: 0036 train_acc= 0.8857 test_acc= 0.7250\n",
      "Epoch: 0037 train_acc= 0.8929 test_acc= 0.7340\n",
      "Epoch: 0038 train_acc= 0.8929 test_acc= 0.7430\n",
      "Epoch: 0039 train_acc= 0.9000 test_acc= 0.7450\n",
      "Epoch: 0040 train_acc= 0.9000 test_acc= 0.7480\n",
      "Epoch: 0041 train_acc= 0.9143 test_acc= 0.7470\n",
      "Epoch: 0042 train_acc= 0.9000 test_acc= 0.7500\n",
      "Epoch: 0043 train_acc= 0.9071 test_acc= 0.7490\n",
      "Epoch: 0044 train_acc= 0.9071 test_acc= 0.7470\n",
      "Epoch: 0045 train_acc= 0.9286 test_acc= 0.7490\n",
      "Epoch: 0046 train_acc= 0.9286 test_acc= 0.7560\n",
      "Epoch: 0047 train_acc= 0.9286 test_acc= 0.7590\n",
      "Epoch: 0048 train_acc= 0.9357 test_acc= 0.7640\n",
      "Epoch: 0049 train_acc= 0.9357 test_acc= 0.7660\n",
      "Epoch: 0050 train_acc= 0.9357 test_acc= 0.7640\n",
      "Epoch: 0051 train_acc= 0.9429 test_acc= 0.7700\n",
      "Epoch: 0052 train_acc= 0.9571 test_acc= 0.7730\n",
      "Epoch: 0053 train_acc= 0.9857 test_acc= 0.7790\n",
      "Epoch: 0054 train_acc= 0.9857 test_acc= 0.7810\n",
      "Epoch: 0055 train_acc= 0.9857 test_acc= 0.7900\n",
      "Epoch: 0056 train_acc= 0.9857 test_acc= 0.7970\n",
      "Epoch: 0057 train_acc= 0.9857 test_acc= 0.8040\n",
      "Epoch: 0058 train_acc= 0.9857 test_acc= 0.8050\n",
      "Epoch: 0059 train_acc= 0.9857 test_acc= 0.8050\n",
      "Epoch: 0060 train_acc= 0.9857 test_acc= 0.8050\n",
      "Epoch: 0061 train_acc= 0.9857 test_acc= 0.7970\n",
      "Epoch: 0062 train_acc= 0.9857 test_acc= 0.7980\n",
      "Epoch: 0063 train_acc= 0.9857 test_acc= 0.7940\n",
      "Epoch: 0064 train_acc= 0.9857 test_acc= 0.7980\n",
      "Epoch: 0065 train_acc= 0.9857 test_acc= 0.8040\n",
      "Epoch: 0066 train_acc= 0.9857 test_acc= 0.8110\n",
      "Epoch: 0067 train_acc= 0.9857 test_acc= 0.8140\n",
      "Epoch: 0068 train_acc= 0.9857 test_acc= 0.8150\n",
      "Epoch: 0069 train_acc= 0.9929 test_acc= 0.8100\n",
      "Epoch: 0070 train_acc= 0.9929 test_acc= 0.8050\n",
      "Epoch: 0071 train_acc= 0.9857 test_acc= 0.8010\n",
      "Epoch: 0072 train_acc= 0.9857 test_acc= 0.8090\n",
      "Epoch: 0073 train_acc= 0.9857 test_acc= 0.7990\n",
      "Epoch: 0074 train_acc= 0.9857 test_acc= 0.8020\n",
      "Epoch: 0075 train_acc= 0.9929 test_acc= 0.8100\n",
      "Epoch: 0076 train_acc= 0.9929 test_acc= 0.8090\n",
      "Epoch: 0077 train_acc= 0.9929 test_acc= 0.8090\n",
      "Epoch: 0078 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0079 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0080 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0081 train_acc= 0.9929 test_acc= 0.8120\n",
      "Epoch: 0082 train_acc= 0.9929 test_acc= 0.8120\n",
      "Epoch: 0083 train_acc= 0.9929 test_acc= 0.8110\n",
      "Epoch: 0084 train_acc= 1.0000 test_acc= 0.8200\n",
      "Epoch: 0085 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0086 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0087 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0088 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0089 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0090 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0091 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0092 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0093 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0094 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0095 train_acc= 1.0000 test_acc= 0.8000\n",
      "Epoch: 0096 train_acc= 1.0000 test_acc= 0.7990\n",
      "Epoch: 0097 train_acc= 1.0000 test_acc= 0.8030\n",
      "Epoch: 0098 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0099 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0100 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0101 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0102 train_acc= 1.0000 test_acc= 0.8220\n",
      "Epoch: 0103 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0104 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0105 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0106 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0107 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0108 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0109 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0110 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0111 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0112 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0113 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0114 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0115 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0116 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0117 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0118 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0119 train_acc= 1.0000 test_acc= 0.8190\n",
      "Epoch: 0120 train_acc= 1.0000 test_acc= 0.8220\n",
      "Epoch: 0121 train_acc= 1.0000 test_acc= 0.8230\n",
      "Epoch: 0122 train_acc= 1.0000 test_acc= 0.8250\n",
      "Epoch: 0123 train_acc= 1.0000 test_acc= 0.8180\n",
      "Epoch: 0124 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0125 train_acc= 1.0000 test_acc= 0.8170\n",
      "Epoch: 0126 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0127 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0128 train_acc= 1.0000 test_acc= 0.8170\n",
      "Epoch: 0129 train_acc= 1.0000 test_acc= 0.8220\n",
      "Epoch: 0130 train_acc= 1.0000 test_acc= 0.7940\n",
      "Epoch: 0131 train_acc= 1.0000 test_acc= 0.7950\n",
      "Epoch: 0132 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0133 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0134 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0135 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0136 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0137 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0138 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0139 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0140 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0141 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0142 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0143 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0144 train_acc= 1.0000 test_acc= 0.8020\n",
      "Epoch: 0145 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0146 train_acc= 1.0000 test_acc= 0.8000\n",
      "Epoch: 0147 train_acc= 1.0000 test_acc= 0.8000\n",
      "Epoch: 0148 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0149 train_acc= 1.0000 test_acc= 0.8180\n",
      "Epoch: 0150 train_acc= 1.0000 test_acc= 0.8200\n",
      "Epoch: 0151 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0152 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0153 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0154 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0155 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0156 train_acc= 1.0000 test_acc= 0.8180\n",
      "Epoch: 0157 train_acc= 1.0000 test_acc= 0.8170\n",
      "Epoch: 0158 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0159 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0160 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0161 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0162 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0163 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0164 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0165 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0166 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0167 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0168 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0169 train_acc= 1.0000 test_acc= 0.8020\n",
      "Epoch: 0170 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0171 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0172 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0173 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0174 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0175 train_acc= 1.0000 test_acc= 0.8090\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0176 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0177 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0178 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0179 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0180 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0181 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0182 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0183 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0184 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0185 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0186 train_acc= 1.0000 test_acc= 0.8000\n",
      "Epoch: 0187 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0188 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0189 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0190 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0191 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0192 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0193 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0194 train_acc= 1.0000 test_acc= 0.8010\n",
      "Epoch: 0195 train_acc= 1.0000 test_acc= 0.8020\n",
      "Epoch: 0196 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0197 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0198 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0199 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0200 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0201 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0202 train_acc= 1.0000 test_acc= 0.7970\n",
      "Epoch: 0203 train_acc= 1.0000 test_acc= 0.8000\n",
      "Epoch: 0204 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0205 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0206 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0207 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0208 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0209 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0210 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0211 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0212 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0213 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0214 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0215 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0216 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0217 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0218 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0219 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0220 train_acc= 1.0000 test_acc= 0.8000\n",
      "Epoch: 0221 train_acc= 1.0000 test_acc= 0.7960\n",
      "Epoch: 0222 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0223 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0224 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0225 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0226 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0227 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0228 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0229 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0230 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0231 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0232 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0233 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0234 train_acc= 1.0000 test_acc= 0.8170\n",
      "Epoch: 0235 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0236 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0237 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0238 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0239 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0240 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0241 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0242 train_acc= 1.0000 test_acc= 0.8190\n",
      "Epoch: 0243 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0244 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0245 train_acc= 1.0000 test_acc= 0.8030\n",
      "Epoch: 0246 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0247 train_acc= 1.0000 test_acc= 0.8030\n",
      "Epoch: 0248 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0249 train_acc= 1.0000 test_acc= 0.8180\n",
      "Epoch: 0250 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0251 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0252 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0253 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0254 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0255 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0256 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0257 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0258 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0259 train_acc= 1.0000 test_acc= 0.8020\n",
      "Epoch: 0260 train_acc= 1.0000 test_acc= 0.7980\n",
      "Epoch: 0261 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0262 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0263 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0264 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0265 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0266 train_acc= 1.0000 test_acc= 0.8030\n",
      "Epoch: 0267 train_acc= 1.0000 test_acc= 0.8030\n",
      "Epoch: 0268 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0269 train_acc= 1.0000 test_acc= 0.8020\n",
      "Epoch: 0270 train_acc= 1.0000 test_acc= 0.8000\n",
      "Epoch: 0271 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0272 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0273 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0274 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0275 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0276 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0277 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0278 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0279 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0280 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0281 train_acc= 1.0000 test_acc= 0.8030\n",
      "Epoch: 0282 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0283 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0284 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0285 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0286 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0287 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0288 train_acc= 1.0000 test_acc= 0.8010\n",
      "Epoch: 0289 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0290 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0291 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0292 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0293 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0294 train_acc= 1.0000 test_acc= 0.8020\n",
      "Epoch: 0295 train_acc= 1.0000 test_acc= 0.7930\n",
      "Epoch: 0296 train_acc= 1.0000 test_acc= 0.8010\n",
      "Epoch: 0297 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0298 train_acc= 1.0000 test_acc= 0.8020\n",
      "Epoch: 0299 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0300 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0301 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0302 train_acc= 1.0000 test_acc= 0.8030\n",
      "Epoch: 0303 train_acc= 1.0000 test_acc= 0.8020\n",
      "Epoch: 0304 train_acc= 1.0000 test_acc= 0.7980\n",
      "Epoch: 0305 train_acc= 1.0000 test_acc= 0.7990\n",
      "Epoch: 0306 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0307 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0308 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0309 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0310 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0311 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0312 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0313 train_acc= 1.0000 test_acc= 0.8020\n",
      "Epoch: 0314 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0315 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0316 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0317 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0318 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0319 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0320 train_acc= 1.0000 test_acc= 0.7970\n",
      "Epoch: 0321 train_acc= 1.0000 test_acc= 0.7980\n",
      "Epoch: 0322 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0323 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0324 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0325 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0326 train_acc= 1.0000 test_acc= 0.8010\n",
      "Epoch: 0327 train_acc= 1.0000 test_acc= 0.8030\n",
      "Epoch: 0328 train_acc= 1.0000 test_acc= 0.8020\n",
      "Epoch: 0329 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0330 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0331 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0332 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0333 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0334 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0335 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0336 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0337 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0338 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0339 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0340 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0341 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0342 train_acc= 1.0000 test_acc= 0.8170\n",
      "Epoch: 0343 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0344 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0345 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0346 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0347 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0348 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0349 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0350 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0351 train_acc= 1.0000 test_acc= 0.8100\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0352 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0353 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0354 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0355 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0356 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0357 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0358 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0359 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0360 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0361 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0362 train_acc= 1.0000 test_acc= 0.8010\n",
      "Epoch: 0363 train_acc= 1.0000 test_acc= 0.7960\n",
      "Epoch: 0364 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0365 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0366 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0367 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0368 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0369 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0370 train_acc= 1.0000 test_acc= 0.8000\n",
      "Epoch: 0371 train_acc= 1.0000 test_acc= 0.7990\n",
      "Epoch: 0372 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0373 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0374 train_acc= 1.0000 test_acc= 0.8190\n",
      "Epoch: 0375 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0376 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0377 train_acc= 1.0000 test_acc= 0.8180\n",
      "Epoch: 0378 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0379 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0380 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0381 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0382 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0383 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0384 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0385 train_acc= 1.0000 test_acc= 0.8170\n",
      "Epoch: 0386 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0387 train_acc= 1.0000 test_acc= 0.7970\n",
      "Epoch: 0388 train_acc= 1.0000 test_acc= 0.8010\n",
      "Epoch: 0389 train_acc= 1.0000 test_acc= 0.8180\n",
      "Epoch: 0390 train_acc= 1.0000 test_acc= 0.8170\n",
      "Epoch: 0391 train_acc= 1.0000 test_acc= 0.8190\n",
      "Epoch: 0392 train_acc= 1.0000 test_acc= 0.8180\n",
      "Epoch: 0393 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0394 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0395 train_acc= 1.0000 test_acc= 0.8020\n",
      "Epoch: 0396 train_acc= 1.0000 test_acc= 0.8030\n",
      "Epoch: 0397 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0398 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0399 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0400 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0401 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0402 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0403 train_acc= 1.0000 test_acc= 0.8150\n",
      "Epoch: 0404 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0405 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0406 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0407 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0408 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0409 train_acc= 1.0000 test_acc= 0.8210\n",
      "Epoch: 0410 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0411 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0412 train_acc= 1.0000 test_acc= 0.8070\n",
      "Epoch: 0413 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0414 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0415 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0416 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0417 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0418 train_acc= 1.0000 test_acc= 0.7970\n",
      "Epoch: 0419 train_acc= 1.0000 test_acc= 0.8000\n",
      "Epoch: 0420 train_acc= 1.0000 test_acc= 0.8020\n",
      "Epoch: 0421 train_acc= 1.0000 test_acc= 0.8000\n",
      "Epoch: 0422 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0423 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0424 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0425 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0426 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0427 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0428 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0429 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0430 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0431 train_acc= 1.0000 test_acc= 0.8190\n",
      "Epoch: 0432 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0433 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0434 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0435 train_acc= 1.0000 test_acc= 0.8010\n",
      "Epoch: 0436 train_acc= 1.0000 test_acc= 0.8030\n",
      "Epoch: 0437 train_acc= 1.0000 test_acc= 0.8030\n",
      "Epoch: 0438 train_acc= 1.0000 test_acc= 0.8030\n",
      "Epoch: 0439 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0440 train_acc= 1.0000 test_acc= 0.8020\n",
      "Epoch: 0441 train_acc= 1.0000 test_acc= 0.8020\n",
      "Epoch: 0442 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0443 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0444 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0445 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0446 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0447 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0448 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0449 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0450 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0451 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0452 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0453 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0454 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0455 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0456 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0457 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0458 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0459 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0460 train_acc= 1.0000 test_acc= 0.7980\n",
      "Epoch: 0461 train_acc= 1.0000 test_acc= 0.7950\n",
      "Epoch: 0462 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0463 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0464 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0465 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0466 train_acc= 1.0000 test_acc= 0.8160\n",
      "Epoch: 0467 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0468 train_acc= 1.0000 test_acc= 0.8000\n",
      "Epoch: 0469 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0470 train_acc= 1.0000 test_acc= 0.8090\n",
      "Epoch: 0471 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0472 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0473 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0474 train_acc= 1.0000 test_acc= 0.8170\n",
      "Epoch: 0475 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0476 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0477 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0478 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0479 train_acc= 1.0000 test_acc= 0.8100\n",
      "Epoch: 0480 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0481 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0482 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0483 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0484 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0485 train_acc= 1.0000 test_acc= 0.8130\n",
      "Epoch: 0486 train_acc= 1.0000 test_acc= 0.8120\n",
      "Epoch: 0487 train_acc= 1.0000 test_acc= 0.8000\n",
      "Epoch: 0488 train_acc= 1.0000 test_acc= 0.7940\n",
      "Epoch: 0489 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0490 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0491 train_acc= 1.0000 test_acc= 0.8080\n",
      "Epoch: 0492 train_acc= 1.0000 test_acc= 0.8040\n",
      "Epoch: 0493 train_acc= 1.0000 test_acc= 0.7990\n",
      "Epoch: 0494 train_acc= 1.0000 test_acc= 0.7890\n",
      "Epoch: 0495 train_acc= 1.0000 test_acc= 0.7920\n",
      "Epoch: 0496 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0497 train_acc= 1.0000 test_acc= 0.8050\n",
      "Epoch: 0498 train_acc= 1.0000 test_acc= 0.8060\n",
      "Epoch: 0499 train_acc= 1.0000 test_acc= 0.8030\n"
     ]
    }
   ],
   "source": [
    "# Full-batch training: all nodes go into a single batch (A is assumed to be the\n",
    "# N x N adjacency matrix, so A.shape[0] is the node count -- confirm), and\n",
    "# train_mask is passed as per-node sample weights (presumably 1 for labelled\n",
    "# training nodes and 0 otherwise -- confirm in utils). Fitting one epoch per\n",
    "# loop iteration lets us evaluate train/test accuracy after every epoch.\n",
    "# NOTE(review): test accuracy is only reported here; do not use it to pick\n",
    "# the number of epochs, or the reported test score becomes optimistic.\n",
    "nb_epochs = 500\n",
    "n_nodes = A.shape[0]\n",
    "acc_history = {'train': [], 'test': []}\n",
    "\n",
    "for epoch in range(nb_epochs):\n",
    "    model.fit(X, Y_train, sample_weight=train_mask, batch_size=n_nodes, epochs=1, shuffle=False, verbose=0)\n",
    "    Y_pred = model.predict(X, batch_size=n_nodes)\n",
    "    _, train_acc = evaluate_preds(Y_pred, [Y_train], [train_idx])\n",
    "    _, test_acc = evaluate_preds(Y_pred, [Y_test], [test_idx])\n",
    "    # keep the curves so they can be plotted/inspected after the run\n",
    "    acc_history['train'].append(train_acc[0])\n",
    "    acc_history['test'].append(test_acc[0])\n",
    "    print(\"Epoch: {:04d}\".format(epoch), \"train_acc= {:.4f}\".format(train_acc[0]), \"test_acc= {:.4f}\".format(test_acc[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
