{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from train_by_reconnect.LaPerm import LaPermTrainLoop\n",
    "from train_by_reconnect.weight_utils import random_prune\n",
    "from train_by_reconnect.viz_utils import Profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, Activation\n",
    "from tensorflow.keras.layers import AveragePooling2D, Input, Flatten\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.datasets import cifar10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# training hyperparameters\n",
    "batch_size = 50  # orig paper trained all networks with batch_size=128\n",
    "epochs = 200\n",
    "\n",
    "tsize = 30000 # size of data for getting the train accuracy\n",
    "vali_freq = 800 # validate per vali_freq batches\n",
    "\n",
    "learning_rate = 0.001\n",
    "\n",
    "# Load the CIFAR10 data.\n",
    "(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n",
    "\n",
    "# Input image dimensions.\n",
    "input_shape = x_train.shape[1:]\n",
    "\n",
    "# Normalize data.\n",
    "x_train = x_train.astype('float32') / 255\n",
    "x_test = x_test.astype('float32') / 255\n",
    "\n",
    "# If subtract pixel mean is enabled\n",
    "x_train_mean = np.mean(x_train, axis=0)\n",
    "x_train -= x_train_mean\n",
    "x_test -= x_train_mean\n",
    "\n",
    "print('x_train shape:', x_train.shape)\n",
    "print(x_train.shape[0], 'train samples')\n",
    "print(x_test.shape[0], 'test samples')\n",
    "print('y_train shape:', y_train.shape)\n",
    "\n",
    "# Convert class vectors to binary class matrices.\n",
    "y_train = tf.keras.utils.to_categorical(y_train, 10)\n",
    "y_test = tf.keras.utils.to_categorical(y_test, 10)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 8\n",
    "depth = n * 6 + 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Modified based on https://keras.io/examples/cifar10_resnet/\n",
    "\n",
    "initializer = 'he_uniform'\n",
    "regularizer = l2(1e-4)\n",
    "\n",
    "\n",
    "def resnet_layer(inputs,\n",
    "                 num_filters=16,\n",
    "                 kernel_size=3,\n",
    "                 strides=1,\n",
    "                 activation='relu',\n",
    "                 batch_normalization=True):\n",
    "\n",
    "    conv = Conv2D(num_filters,\n",
    "                  kernel_size=kernel_size,\n",
    "                  strides=strides,\n",
    "                  padding='same',\n",
    "                  kernel_initializer=initializer,\n",
    "                  kernel_regularizer=regularizer)\n",
    "    x = inputs\n",
    "    x = conv(x)\n",
    "    if batch_normalization:\n",
    "        x = BatchNormalization()(x)\n",
    "    if activation is not None:\n",
    "        x = Activation(activation)(x)\n",
    "    return x\n",
    "\n",
    "\n",
    "def resnet(input_shape, depth, num_classes=10):\n",
    "    assert (depth - 2) % 6 == 0, \"incorrect depth.\"\n",
    "\n",
    "    num_filters = 16\n",
    "    num_res_blocks = int((depth - 2) / 6)\n",
    "\n",
    "    inputs = Input(shape=input_shape)\n",
    "    x = resnet_layer(inputs=inputs)\n",
    "\n",
    "    for stack in range(3):\n",
    "        for res_block in range(num_res_blocks):\n",
    "            strides = 1\n",
    "            if stack > 0 and res_block == 0:\n",
    "                strides = 2\n",
    "            y = resnet_layer(inputs=x,\n",
    "                             num_filters=num_filters,\n",
    "                             strides=strides)\n",
    "            y = resnet_layer(inputs=y,\n",
    "                             num_filters=num_filters,\n",
    "                             activation=None)\n",
    "            if stack > 0 and res_block == 0:\n",
    "                x = resnet_layer(inputs=x,\n",
    "                                 num_filters=num_filters,\n",
    "                                 kernel_size=1,\n",
    "                                 strides=strides,\n",
    "                                 activation=None,\n",
    "                                 batch_normalization=False)\n",
    "            x = tf.keras.layers.add([x, y])\n",
    "            x = Activation('relu')(x)\n",
    "        num_filters *= 2\n",
    "\n",
    "    x = AveragePooling2D(pool_size=8)(x)\n",
    "    y = Flatten()(x)\n",
    "    outputs = Dense(num_classes,\n",
    "                    activation='softmax',\n",
    "                    kernel_initializer=initializer)(y)\n",
    "\n",
    "    return Model(inputs=inputs, outputs=outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datagen = ImageDataGenerator(\n",
    "    width_shift_range=0.1,\n",
    "    height_shift_range=0.1,\n",
    "    horizontal_flip=True)\n",
    "\n",
    "# Compute quantities required for featurewise normalization\n",
    "# (std, mean, and principal components if ZCA whitening is applied).\n",
    "datagen.fit(x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lr_scheduler(epoch):\n",
    "    lr = learning_rate\n",
    "    if epoch > 180:\n",
    "        lr *= 0.5e-3\n",
    "    elif epoch > 160:\n",
    "        lr *= 1e-3\n",
    "    elif epoch > 120:\n",
    "        lr *= 1e-2\n",
    "    elif epoch > 80:\n",
    "        lr *= 1e-1\n",
    "    return lr\n",
    "\n",
    "def k_scheduler(epoch):\n",
    "    return 800"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = resnet(input_shape=input_shape, depth=depth)\n",
    "# random_prune(model, prune_rate=0.7) # uncomment for random pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_test.shape, y_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 200\n",
    "loop = LaPermTrainLoop(model=model,\n",
    "                       loss='categorical_crossentropy',\n",
    "                       inner_optimizer=Adam(),\n",
    "                       k_schedule=k_scheduler, \n",
    "                       lr_schedule=lr_scheduler)\n",
    "\n",
    "loop.fit(x_train, y_train, batch_size, epochs=epochs,\n",
    "         datagen=datagen, validation_data=(x_test, y_test),\n",
    "         validation_freq=vali_freq, \n",
    "         tsize=tsize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Profiler(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize train and validation accuracies\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.plot(loop._history['val accuracy'], label='Validation Accuracy')\n",
    "plt.plot(loop._history['accuracy'], label='Train Accuracy')\n",
    "plt.grid(linestyle='--')\n",
    "plt.xlabel('Epochs', size=15)\n",
    "plt.ylabel('Accuracy', size=15)\n",
    "plt.legend(prop={'size':15})\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
