{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/users/u6537967/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.layers import Dense, Activation, Dropout, Reshape, concatenate, ReLU, Input\n",
    "from keras.models import Model, Sequential\n",
    "from keras.regularizers import l2, l1_l2\n",
    "from keras.optimizers import Adam, SGD, Adamax, Nadam\n",
    "from keras.callbacks import ModelCheckpoint\n",
    "from keras.layers.normalization import BatchNormalization\n",
    "from keras.constraints import unit_norm\n",
    "from keras import optimizers\n",
    "from keras import regularizers\n",
    "from keras import initializers\n",
    "import keras.backend as K\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.utils import class_weight\n",
    "from scipy.linalg import fractional_matrix_power\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "from utils import *\n",
    "from arma_optimizer import *\n",
    "from arma_graph_filters_cnn import ArmaGraphFiltersCNN\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_check_point_loc = '/home/asiri/projects/cora_ARMA/arma_cnn_cora_model.h5'\n",
    "model_checkpoint = ModelCheckpoint(model_check_point_loc, monitor='val_acc', verbose=0, save_weights_only=True, mode='max')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading cora dataset...\n",
      "Dataset has 2708 nodes, 5429 edges, 1433 features.\n"
     ]
    }
   ],
   "source": [
    "X, A, Y = load_data(dataset='cora')\n",
    "A = np.array(A.todense())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, Y_val, _, train_idx, val_idx, test_idx, train_mask = get_splits(Y)\n",
    "train_idx = np.array(train_idx)\n",
    "val_idx = np.array(val_idx)\n",
    "test_idx = np.array(test_idx)\n",
    "labels = np.argmax(Y, axis=1) + 1\n",
    "\n",
    "# Normalize X\n",
    "#X /= X.sum(1).reshape(-1, 1)\n",
    "X = np.array(X)\n",
    "\n",
    "Y_train = np.zeros(Y.shape)\n",
    "labels_train = np.zeros(labels.shape)\n",
    "Y_train[train_idx] = Y[train_idx]\n",
    "labels_train[train_idx] = labels[train_idx]\n",
    "\n",
    "Y_test = np.zeros(Y.shape)\n",
    "labels_test = np.zeros(labels.shape)\n",
    "Y_test[test_idx] = Y[test_idx]\n",
    "labels_test[test_idx] = labels[test_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "ECOS 2.0.7 - (C) embotech GmbH, Zurich Switzerland, 2012-15. Web: www.embotech.com/ECOS\n",
      "\n",
      "It     pcost       dcost      gap   pres   dres    k/t    mu     step   sigma     IR    |   BT\n",
      " 0  +0.000e+00  -9.000e-01  +3e+02  9e-01  5e-01  1e+00  2e+00    ---    ---    1  1  - |  -  - \n",
      " 1  -1.558e+00  -1.896e+00  +1e+02  3e-01  2e-01  5e-01  7e-01  0.6114  7e-02   2  1  1 |  0  0\n",
      " 2  -2.220e+01  -2.218e+01  +1e+02  3e-01  1e-01  5e-01  5e-01  0.9890  7e-01   2  2  2 |  0  0\n",
      " 3  -2.340e-02  +3.835e-02  +1e+01  2e-02  5e-03  1e-01  6e-02  0.9642  9e-02   2  2  2 |  0  0\n",
      " 4  -2.341e-01  -2.032e-01  +5e+00  2e-03  3e-03  5e-02  2e-02  0.7761  2e-01   2  3  3 |  0  0\n",
      " 5  +4.778e-01  +4.830e-01  +9e-01  4e-04  5e-04  9e-03  4e-03  0.8274  9e-03   3  3  3 |  0  0\n",
      " 6  +5.248e-01  +5.273e-01  +7e-01  3e-04  3e-04  5e-03  3e-03  0.4628  4e-01   3  3  3 |  0  0\n",
      " 7  +6.410e-01  +6.412e-01  +7e-02  3e-05  3e-05  5e-04  3e-04  0.9109  2e-02   3  3  3 |  0  0\n",
      " 8  +6.464e-01  +6.464e-01  +4e-02  2e-05  2e-05  2e-04  2e-04  0.6807  4e-01   3  3  3 |  0  0\n",
      " 9  +6.539e-01  +6.539e-01  +5e-03  2e-06  2e-06  2e-05  3e-05  0.8800  5e-03   3  2  2 |  0  0\n",
      "10  +6.544e-01  +6.544e-01  +1e-03  5e-07  6e-07  4e-06  6e-06  0.8733  1e-01   3  2  3 |  0  0\n",
      "11  +6.546e-01  +6.546e-01  +4e-04  2e-07  2e-07  1e-06  2e-06  0.7433  1e-01   3  3  3 |  0  0\n",
      "12  +6.546e-01  +6.546e-01  +4e-04  1e-07  2e-07  9e-07  2e-06  0.4087  5e-01   3  2  2 |  0  0\n",
      "13  +6.547e-01  +6.547e-01  +1e-05  5e-09  6e-09  3e-08  6e-08  0.9828  2e-02   3  2  2 |  0  0\n",
      "14  +6.547e-01  +6.547e-01  +2e-07  6e-11  7e-11  4e-10  7e-10  0.9879  1e-04   3  2  2 |  0  0\n",
      "15  +6.547e-01  +6.547e-01  +4e-09  2e-12  2e-12  1e-11  2e-11  0.9741  1e-04   3  2  2 |  0  0\n",
      "\n",
      "OPTIMAL (within feastol=2.2e-12, reltol=6.3e-09, abstol=4.1e-09).\n",
      "Runtime: 0.002199 seconds.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Identity matrix for self loop\n",
    "I = np.matrix(np.eye(A.shape[0]))\n",
    "A_hat = A + I\n",
    "\n",
    "# Degree matrix\n",
    "D_hat = np.array(np.sum(A_hat, axis=0))[0]\n",
    "D_hat = np.matrix(np.diag(D_hat))\n",
    "\n",
    "#Laplacian matrix\n",
    "L = I - (fractional_matrix_power(D_hat, -0.5) * A_hat * fractional_matrix_power(D_hat, -0.5))\n",
    "L = L - ((lmax(L)/2) * I)\n",
    "\n",
    "lambda_cut = 0.5\n",
    "\n",
    "def step(x, a):\n",
    "    for index in range(len(x)):\n",
    "        if(x[index] >= a):\n",
    "            x[index] = float(1)\n",
    "        else:\n",
    "            x[index] = float(0)\n",
    "    return x\n",
    "    \n",
    "response = lambda x: step(x, lmax(L)/2 - lambda_cut)\n",
    "\n",
    "# Since the eigenvalues might change, sample eigenvalue domain uniformly\n",
    "mu = np.linspace(0, lmax(L), 70) #100\n",
    "\n",
    "#AR filter order (decrease radius for larger values)\n",
    "Ka = 5\n",
    "\n",
    "#MA filter order\n",
    "Kb = 3\n",
    "\n",
    "#for speed make small, for accuracy increase. Should be below 1 if the distributed implementation is used. \n",
    "#With the (faster) conj. gradient implementation, any radius is allowed.\n",
    "radius = 0.90\n",
    "\n",
    "b, a, rARMA, error = agsp_design_ARMA(mu, response, Kb, Ka, radius)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "h_zero = np.zeros(L.shape[0])\n",
    "\n",
    "def L_mult_numerator(coef):\n",
    "    y = coef.item(0) * np.linalg.matrix_power(L, 0)\n",
    "    for i in range(1, len(coef)):\n",
    "        x = np.linalg.matrix_power(L, i)\n",
    "        y = y + coef.item(i) * x\n",
    "\n",
    "    return y\n",
    "\n",
    "def L_mult_denominator(coef):\n",
    "    y_d = h_zero\n",
    "    for i in range(0, len(coef)):\n",
    "        x_d = np.linalg.matrix_power(L, i+1)\n",
    "        y_d = y_d + coef.item(i) * x_d\n",
    "    \n",
    "    return y_d\n",
    "\n",
    "poly_num = L_mult_numerator(b)\n",
    "poly_denom = L_mult_denominator(a)\n",
    "\n",
    "arma_conv_AR = K.constant(poly_denom)\n",
    "arma_conv_MA = K.constant(poly_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dense_factor(inputs, input_signal, num_nodes, droput):\n",
    "    \n",
    "    h_1 = BatchNormalization()(inputs)\n",
    "    h_1 = ArmaGraphFiltersCNN(num_nodes, \n",
    "                              arma_conv_AR, \n",
    "                              arma_conv_MA, \n",
    "                              input_signal, \n",
    "                              kernel_initializer=initializers.glorot_normal(seed=1), \n",
    "                              kernel_regularizer=l2(9e-2), \n",
    "                              kernel_constraint=unit_norm(),\n",
    "                              use_bias=True,\n",
    "                              bias_initializer=initializers.glorot_normal(seed=1), \n",
    "                              bias_constraint=unit_norm())(h_1)\n",
    "    h_1 = ReLU()(h_1)\n",
    "    output = Dropout(droput)(h_1)\n",
    "    return output\n",
    "\n",
    "def dense_block(inputs):\n",
    "\n",
    "    concatenated_inputs = inputs\n",
    "    \n",
    "    num_nodes = [8, 16, 32, 64, 128]\n",
    "    droput = [0.9, 0.9, 0.9, 0.9, 0.9]\n",
    "\n",
    "    for i in range(5):\n",
    "        x = dense_factor(concatenated_inputs, inputs, num_nodes[i], droput[i])\n",
    "        concatenated_inputs = concatenate([concatenated_inputs, x], axis=1)\n",
    "\n",
    "    return concatenated_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_1 (InputLayer)            (None, 1433)         0                                            \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_1 (BatchNor (None, 1433)         5732        input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "arma_graph_filters_cnn_1 (ArmaG (None, 8)            22936       batch_normalization_1[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "re_lu_1 (ReLU)                  (None, 8)            0           arma_graph_filters_cnn_1[0][0]   \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (None, 8)            0           re_lu_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_1 (Concatenate)     (None, 1441)         0           input_1[0][0]                    \n",
      "                                                                 dropout_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_2 (BatchNor (None, 1441)         5764        concatenate_1[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "arma_graph_filters_cnn_2 (ArmaG (None, 16)           46000       batch_normalization_2[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "re_lu_2 (ReLU)                  (None, 16)           0           arma_graph_filters_cnn_2[0][0]   \n",
      "__________________________________________________________________________________________________\n",
      "dropout_2 (Dropout)             (None, 16)           0           re_lu_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_2 (Concatenate)     (None, 1457)         0           concatenate_1[0][0]              \n",
      "                                                                 dropout_2[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_3 (BatchNor (None, 1457)         5828        concatenate_2[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "arma_graph_filters_cnn_3 (ArmaG (None, 32)           92512       batch_normalization_3[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "re_lu_3 (ReLU)                  (None, 32)           0           arma_graph_filters_cnn_3[0][0]   \n",
      "__________________________________________________________________________________________________\n",
      "dropout_3 (Dropout)             (None, 32)           0           re_lu_3[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_3 (Concatenate)     (None, 1489)         0           concatenate_2[0][0]              \n",
      "                                                                 dropout_3[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_4 (BatchNor (None, 1489)         5956        concatenate_3[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "arma_graph_filters_cnn_4 (ArmaG (None, 64)           187072      batch_normalization_4[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "re_lu_4 (ReLU)                  (None, 64)           0           arma_graph_filters_cnn_4[0][0]   \n",
      "__________________________________________________________________________________________________\n",
      "dropout_4 (Dropout)             (None, 64)           0           re_lu_4[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_4 (Concatenate)     (None, 1553)         0           concatenate_3[0][0]              \n",
      "                                                                 dropout_4[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "batch_normalization_5 (BatchNor (None, 1553)         6212        concatenate_4[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "arma_graph_filters_cnn_5 (ArmaG (None, 128)          382336      batch_normalization_5[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "re_lu_5 (ReLU)                  (None, 128)          0           arma_graph_filters_cnn_5[0][0]   \n",
      "__________________________________________________________________________________________________\n",
      "dropout_5 (Dropout)             (None, 128)          0           re_lu_5[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_5 (Concatenate)     (None, 1681)         0           concatenate_4[0][0]              \n",
      "                                                                 dropout_5[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "fc_1 (Dense)                    (None, 7)            11774       concatenate_5[0][0]              \n",
      "==================================================================================================\n",
      "Total params: 772,122\n",
      "Trainable params: 757,376\n",
      "Non-trainable params: 14,746\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "def dense_block_model(x_train):\n",
    "    \n",
    "    inputs = Input((x_train.shape[1],))\n",
    "    \n",
    "    x = dense_block(inputs)\n",
    "\n",
    "    predictions = Dense(7, kernel_initializer=initializers.glorot_normal(seed=1), \n",
    "                        kernel_regularizer=regularizers.l2(1e-10), \n",
    "                        kernel_constraint=unit_norm(), \n",
    "                        activity_regularizer=regularizers.l2(1e-10), \n",
    "                        use_bias=True, \n",
    "                        bias_initializer=initializers.glorot_normal(seed=1), \n",
    "                        bias_constraint=unit_norm(), \n",
    "                        activation='softmax', name='fc_'+str(1))(x)\n",
    "    \n",
    "    model = Model(input=inputs, output=predictions)\n",
    "    \n",
    "    model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=0.002), metrics=['acc'])\n",
    "    \n",
    "    return model\n",
    "\n",
    "model_dense_block = dense_block_model(X)\n",
    "model_dense_block.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0000 train_acc= 0.2600 test_acc= 0.2500\n",
      "Epoch: 0001 train_acc= 0.5800 test_acc= 0.4640\n",
      "Epoch: 0002 train_acc= 0.7350 test_acc= 0.6220\n",
      "Epoch: 0003 train_acc= 0.8050 test_acc= 0.6830\n",
      "Epoch: 0004 train_acc= 0.8350 test_acc= 0.7130\n",
      "Epoch: 0005 train_acc= 0.8650 test_acc= 0.7370\n",
      "Epoch: 0006 train_acc= 0.8700 test_acc= 0.7580\n",
      "Epoch: 0007 train_acc= 0.8900 test_acc= 0.7760\n",
      "Epoch: 0008 train_acc= 0.8950 test_acc= 0.7880\n",
      "Epoch: 0009 train_acc= 0.9150 test_acc= 0.7960\n",
      "Epoch: 0010 train_acc= 0.9300 test_acc= 0.8110\n",
      "Epoch: 0011 train_acc= 0.9300 test_acc= 0.8100\n",
      "Epoch: 0012 train_acc= 0.9350 test_acc= 0.8170\n",
      "Epoch: 0013 train_acc= 0.9350 test_acc= 0.8180\n",
      "Epoch: 0014 train_acc= 0.9400 test_acc= 0.8190\n",
      "Epoch: 0015 train_acc= 0.9400 test_acc= 0.8190\n",
      "Epoch: 0016 train_acc= 0.9400 test_acc= 0.8190\n",
      "Epoch: 0017 train_acc= 0.9450 test_acc= 0.8220\n",
      "Epoch: 0018 train_acc= 0.9450 test_acc= 0.8220\n",
      "Epoch: 0019 train_acc= 0.9450 test_acc= 0.8240\n",
      "Epoch: 0020 train_acc= 0.9500 test_acc= 0.8240\n",
      "Epoch: 0021 train_acc= 0.9550 test_acc= 0.8240\n",
      "Epoch: 0022 train_acc= 0.9550 test_acc= 0.8240\n",
      "Epoch: 0023 train_acc= 0.9600 test_acc= 0.8250\n",
      "Epoch: 0024 train_acc= 0.9600 test_acc= 0.8280\n",
      "Epoch: 0025 train_acc= 0.9600 test_acc= 0.8320\n",
      "Epoch: 0026 train_acc= 0.9650 test_acc= 0.8300\n",
      "Epoch: 0027 train_acc= 0.9700 test_acc= 0.8280\n",
      "Epoch: 0028 train_acc= 0.9700 test_acc= 0.8290\n",
      "Epoch: 0029 train_acc= 0.9700 test_acc= 0.8310\n",
      "Epoch: 0030 train_acc= 0.9700 test_acc= 0.8330\n",
      "Epoch: 0031 train_acc= 0.9700 test_acc= 0.8310\n",
      "Epoch: 0032 train_acc= 0.9750 test_acc= 0.8270\n",
      "Epoch: 0033 train_acc= 0.9750 test_acc= 0.8290\n",
      "Epoch: 0034 train_acc= 0.9750 test_acc= 0.8270\n",
      "Epoch: 0035 train_acc= 0.9750 test_acc= 0.8260\n",
      "Epoch: 0036 train_acc= 0.9800 test_acc= 0.8260\n",
      "Epoch: 0037 train_acc= 0.9800 test_acc= 0.8230\n",
      "Epoch: 0038 train_acc= 0.9800 test_acc= 0.8230\n",
      "Epoch: 0039 train_acc= 0.9850 test_acc= 0.8230\n",
      "Epoch: 0040 train_acc= 0.9850 test_acc= 0.8230\n",
      "Epoch: 0041 train_acc= 0.9850 test_acc= 0.8240\n",
      "Epoch: 0042 train_acc= 0.9850 test_acc= 0.8240\n",
      "Epoch: 0043 train_acc= 0.9900 test_acc= 0.8240\n",
      "Epoch: 0044 train_acc= 0.9900 test_acc= 0.8260\n",
      "Epoch: 0045 train_acc= 0.9900 test_acc= 0.8260\n",
      "Epoch: 0046 train_acc= 0.9900 test_acc= 0.8260\n",
      "Epoch: 0047 train_acc= 0.9900 test_acc= 0.8260\n",
      "Epoch: 0048 train_acc= 0.9900 test_acc= 0.8260\n",
      "Epoch: 0049 train_acc= 0.9900 test_acc= 0.8270\n",
      "Epoch: 0050 train_acc= 0.9900 test_acc= 0.8280\n",
      "Epoch: 0051 train_acc= 0.9900 test_acc= 0.8300\n",
      "Epoch: 0052 train_acc= 0.9900 test_acc= 0.8290\n",
      "Epoch: 0053 train_acc= 0.9900 test_acc= 0.8280\n",
      "Epoch: 0054 train_acc= 0.9900 test_acc= 0.8260\n",
      "Epoch: 0055 train_acc= 0.9900 test_acc= 0.8260\n",
      "Epoch: 0056 train_acc= 0.9900 test_acc= 0.8270\n",
      "Epoch: 0057 train_acc= 0.9900 test_acc= 0.8260\n",
      "Epoch: 0058 train_acc= 0.9900 test_acc= 0.8250\n",
      "Epoch: 0059 train_acc= 0.9900 test_acc= 0.8260\n",
      "Epoch: 0060 train_acc= 0.9900 test_acc= 0.8270\n",
      "Epoch: 0061 train_acc= 0.9900 test_acc= 0.8300\n",
      "Epoch: 0062 train_acc= 0.9900 test_acc= 0.8310\n",
      "Epoch: 0063 train_acc= 0.9900 test_acc= 0.8300\n",
      "Epoch: 0064 train_acc= 0.9900 test_acc= 0.8320\n",
      "Epoch: 0065 train_acc= 0.9900 test_acc= 0.8320\n",
      "Epoch: 0066 train_acc= 0.9900 test_acc= 0.8340\n",
      "Epoch: 0067 train_acc= 0.9900 test_acc= 0.8350\n",
      "Epoch: 0068 train_acc= 0.9900 test_acc= 0.8360\n",
      "Epoch: 0069 train_acc= 0.9900 test_acc= 0.8360\n",
      "Epoch: 0070 train_acc= 0.9900 test_acc= 0.8350\n",
      "Epoch: 0071 train_acc= 0.9900 test_acc= 0.8390\n",
      "Epoch: 0072 train_acc= 0.9900 test_acc= 0.8370\n",
      "Epoch: 0073 train_acc= 0.9900 test_acc= 0.8370\n",
      "Epoch: 0074 train_acc= 0.9900 test_acc= 0.8380\n",
      "Epoch: 0075 train_acc= 0.9900 test_acc= 0.8390\n",
      "Epoch: 0076 train_acc= 0.9900 test_acc= 0.8420\n",
      "Epoch: 0077 train_acc= 0.9900 test_acc= 0.8410\n",
      "Epoch: 0078 train_acc= 0.9900 test_acc= 0.8390\n",
      "Epoch: 0079 train_acc= 0.9900 test_acc= 0.8370\n",
      "Epoch: 0080 train_acc= 0.9900 test_acc= 0.8380\n",
      "Epoch: 0081 train_acc= 0.9900 test_acc= 0.8390\n",
      "Epoch: 0082 train_acc= 0.9900 test_acc= 0.8390\n",
      "Epoch: 0083 train_acc= 0.9900 test_acc= 0.8400\n",
      "Epoch: 0084 train_acc= 0.9900 test_acc= 0.8400\n",
      "Epoch: 0085 train_acc= 0.9900 test_acc= 0.8400\n",
      "Epoch: 0086 train_acc= 0.9900 test_acc= 0.8390\n",
      "Epoch: 0087 train_acc= 0.9900 test_acc= 0.8410\n",
      "Epoch: 0088 train_acc= 0.9900 test_acc= 0.8420\n",
      "Epoch: 0089 train_acc= 0.9900 test_acc= 0.8410\n",
      "Epoch: 0090 train_acc= 0.9900 test_acc= 0.8400\n",
      "Epoch: 0091 train_acc= 0.9900 test_acc= 0.8410\n",
      "Epoch: 0092 train_acc= 1.0000 test_acc= 0.8490\n",
      "Epoch: 0093 train_acc= 0.9950 test_acc= 0.8480\n",
      "Epoch: 0094 train_acc= 0.9950 test_acc= 0.8440\n",
      "Epoch: 0095 train_acc= 0.9950 test_acc= 0.8440\n",
      "Epoch: 0096 train_acc= 0.9950 test_acc= 0.8450\n",
      "Epoch: 0097 train_acc= 0.9950 test_acc= 0.8460\n",
      "Epoch: 0098 train_acc= 0.9950 test_acc= 0.8470\n",
      "Epoch: 0099 train_acc= 0.9950 test_acc= 0.8470\n",
      "Epoch: 0100 train_acc= 0.9950 test_acc= 0.8460\n",
      "Epoch: 0101 train_acc= 0.9950 test_acc= 0.8470\n",
      "Epoch: 0102 train_acc= 0.9950 test_acc= 0.8480\n",
      "Epoch: 0103 train_acc= 0.9950 test_acc= 0.8480\n",
      "Epoch: 0104 train_acc= 0.9950 test_acc= 0.8480\n",
      "Epoch: 0105 train_acc= 0.9950 test_acc= 0.8490\n",
      "Epoch: 0106 train_acc= 0.9950 test_acc= 0.8490\n",
      "Epoch: 0107 train_acc= 0.9950 test_acc= 0.8500\n",
      "Epoch: 0108 train_acc= 0.9950 test_acc= 0.8490\n",
      "Epoch: 0109 train_acc= 0.9950 test_acc= 0.8480\n",
      "Epoch: 0110 train_acc= 0.9900 test_acc= 0.8510\n",
      "Epoch: 0111 train_acc= 0.9900 test_acc= 0.8520\n",
      "Epoch: 0112 train_acc= 0.9900 test_acc= 0.8550\n",
      "Epoch: 0113 train_acc= 0.9950 test_acc= 0.8510\n",
      "Epoch: 0114 train_acc= 0.9950 test_acc= 0.8520\n",
      "Epoch: 0115 train_acc= 1.0000 test_acc= 0.8510\n",
      "Epoch: 0116 train_acc= 1.0000 test_acc= 0.8480\n",
      "Epoch: 0117 train_acc= 1.0000 test_acc= 0.8470\n",
      "Epoch: 0118 train_acc= 0.9950 test_acc= 0.8420\n",
      "Epoch: 0119 train_acc= 0.9950 test_acc= 0.8450\n",
      "Epoch: 0120 train_acc= 0.9950 test_acc= 0.8450\n",
      "Epoch: 0121 train_acc= 0.9950 test_acc= 0.8480\n",
      "Epoch: 0122 train_acc= 0.9950 test_acc= 0.8480\n",
      "Epoch: 0123 train_acc= 0.9950 test_acc= 0.8480\n",
      "Epoch: 0124 train_acc= 0.9950 test_acc= 0.8480\n",
      "Epoch: 0125 train_acc= 0.9950 test_acc= 0.8480\n",
      "Epoch: 0126 train_acc= 0.9950 test_acc= 0.8460\n",
      "Epoch: 0127 train_acc= 1.0000 test_acc= 0.8500\n",
      "Epoch: 0128 train_acc= 1.0000 test_acc= 0.8500\n",
      "Epoch: 0129 train_acc= 1.0000 test_acc= 0.8490\n",
      "Epoch: 0130 train_acc= 1.0000 test_acc= 0.8490\n",
      "Epoch: 0131 train_acc= 0.9950 test_acc= 0.8490\n",
      "Epoch: 0132 train_acc= 0.9950 test_acc= 0.8460\n",
      "Epoch: 0133 train_acc= 0.9950 test_acc= 0.8450\n",
      "Epoch: 0134 train_acc= 0.9950 test_acc= 0.8440\n",
      "Epoch: 0135 train_acc= 0.9950 test_acc= 0.8450\n",
      "Epoch: 0136 train_acc= 0.9950 test_acc= 0.8440\n",
      "Epoch: 0137 train_acc= 0.9950 test_acc= 0.8470\n",
      "Epoch: 0138 train_acc= 1.0000 test_acc= 0.8490\n",
      "Epoch: 0139 train_acc= 1.0000 test_acc= 0.8450\n",
      "Epoch: 0140 train_acc= 1.0000 test_acc= 0.8450\n",
      "Epoch: 0141 train_acc= 1.0000 test_acc= 0.8460\n",
      "Epoch: 0142 train_acc= 1.0000 test_acc= 0.8440\n",
      "Epoch: 0143 train_acc= 0.9950 test_acc= 0.8430\n",
      "Epoch: 0144 train_acc= 0.9950 test_acc= 0.8450\n",
      "Epoch: 0145 train_acc= 0.9950 test_acc= 0.8440\n",
      "Epoch: 0146 train_acc= 1.0000 test_acc= 0.8450\n",
      "Epoch: 0147 train_acc= 0.9950 test_acc= 0.8430\n",
      "Epoch: 0148 train_acc= 1.0000 test_acc= 0.8430\n",
      "Epoch: 0149 train_acc= 1.0000 test_acc= 0.8440\n",
      "Epoch: 0150 train_acc= 1.0000 test_acc= 0.8440\n",
      "Epoch: 0151 train_acc= 1.0000 test_acc= 0.8440\n",
      "Epoch: 0152 train_acc= 1.0000 test_acc= 0.8450\n",
      "Epoch: 0153 train_acc= 1.0000 test_acc= 0.8440\n",
      "Epoch: 0154 train_acc= 1.0000 test_acc= 0.8480\n",
      "Epoch: 0155 train_acc= 1.0000 test_acc= 0.8500\n",
      "Epoch: 0156 train_acc= 1.0000 test_acc= 0.8510\n",
      "Epoch: 0157 train_acc= 1.0000 test_acc= 0.8530\n",
      "Epoch: 0158 train_acc= 1.0000 test_acc= 0.8510\n",
      "Epoch: 0159 train_acc= 0.9950 test_acc= 0.8460\n",
      "Epoch: 0160 train_acc= 0.9950 test_acc= 0.8480\n",
      "Epoch: 0161 train_acc= 0.9950 test_acc= 0.8480\n",
      "Epoch: 0162 train_acc= 0.9950 test_acc= 0.8520\n",
      "Epoch: 0163 train_acc= 0.9950 test_acc= 0.8530\n",
      "Epoch: 0164 train_acc= 0.9900 test_acc= 0.8550\n",
      "Epoch: 0165 train_acc= 0.9900 test_acc= 0.8540\n",
      "Epoch: 0166 train_acc= 0.9900 test_acc= 0.8500\n",
      "Epoch: 0167 train_acc= 0.9900 test_acc= 0.8510\n",
      "Epoch: 0168 train_acc= 0.9900 test_acc= 0.8530\n",
      "Epoch: 0169 train_acc= 0.9900 test_acc= 0.8520\n",
      "Epoch: 0170 train_acc= 0.9900 test_acc= 0.8490\n",
      "Epoch: 0171 train_acc= 0.9900 test_acc= 0.8510\n",
      "Epoch: 0172 train_acc= 0.9900 test_acc= 0.8540\n",
      "Epoch: 0173 train_acc= 0.9950 test_acc= 0.8550\n",
      "Epoch: 0174 train_acc= 0.9950 test_acc= 0.8550\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0175 train_acc= 0.9950 test_acc= 0.8550\n",
      "Epoch: 0176 train_acc= 0.9950 test_acc= 0.8510\n",
      "Epoch: 0177 train_acc= 0.9950 test_acc= 0.8410\n",
      "Epoch: 0178 train_acc= 0.9950 test_acc= 0.8400\n",
      "Epoch: 0179 train_acc= 0.9950 test_acc= 0.8380\n",
      "Epoch: 0180 train_acc= 0.9950 test_acc= 0.8380\n",
      "Epoch: 0181 train_acc= 0.9950 test_acc= 0.8370\n",
      "Epoch: 0182 train_acc= 0.9950 test_acc= 0.8390\n",
      "Epoch: 0183 train_acc= 0.9950 test_acc= 0.8420\n",
      "Epoch: 0184 train_acc= 0.9950 test_acc= 0.8430\n",
      "Epoch: 0185 train_acc= 0.9950 test_acc= 0.8440\n",
      "Epoch: 0186 train_acc= 0.9950 test_acc= 0.8450\n",
      "Epoch: 0187 train_acc= 0.9950 test_acc= 0.8490\n",
      "Epoch: 0188 train_acc= 1.0000 test_acc= 0.8520\n",
      "Epoch: 0189 train_acc= 1.0000 test_acc= 0.8530\n",
      "Epoch: 0190 train_acc= 1.0000 test_acc= 0.8550\n",
      "Epoch: 0191 train_acc= 0.9950 test_acc= 0.8530\n",
      "Epoch: 0192 train_acc= 0.9950 test_acc= 0.8530\n",
      "Epoch: 0193 train_acc= 0.9950 test_acc= 0.8530\n",
      "Epoch: 0194 train_acc= 0.9950 test_acc= 0.8470\n",
      "Epoch: 0195 train_acc= 0.9950 test_acc= 0.8430\n",
      "Epoch: 0196 train_acc= 0.9950 test_acc= 0.8360\n",
      "Epoch: 0197 train_acc= 0.9950 test_acc= 0.8300\n",
      "Epoch: 0198 train_acc= 0.9950 test_acc= 0.8300\n",
      "Epoch: 0199 train_acc= 0.9950 test_acc= 0.8300\n"
     ]
    }
   ],
   "source": [
    "nb_epochs = 200\n",
    "\n",
    "class_weight = class_weight.compute_class_weight('balanced', np.unique(labels_train), labels_train)\n",
    "class_weight_dic = dict(enumerate(class_weight))\n",
    "\n",
    "for epoch in range(nb_epochs):\n",
    "    model_dense_block.fit(X, Y_train, sample_weight=train_mask, batch_size=A.shape[0], epochs=1, shuffle=False, \n",
    "                          class_weight=class_weight_dic, verbose=0)\n",
    "    Y_pred = model_dense_block.predict(X, batch_size=A.shape[0])\n",
    "    _, train_acc = evaluate_preds(Y_pred, [Y_train], [train_idx])\n",
    "    _, test_acc = evaluate_preds(Y_pred, [Y_test], [test_idx])\n",
    "    print(\"Epoch: {:04d}\".format(epoch), \"train_acc= {:.4f}\".format(train_acc[0]), \"test_acc= {:.4f}\".format(test_acc[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
