{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from train_by_reconnect.LaPerm import LaPermTrainLoop\n",
    "from train_by_reconnect.weight_utils import random_prune\n",
    "from train_by_reconnect.viz_utils import Profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.datasets import cifar10\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten\n",
    "from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization\n",
    "from tensorflow.keras import optimizers\n",
    "from tensorflow.keras import regularizers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing\n",
    "\n",
    "Inputs are standardized with a single scalar mean and standard deviation computed on the training set only, then applied to both splits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize(x_train, x_test, eps=1e-8):\n",
    "    \"\"\"Standardize inputs to zero mean and unit variance.\n",
    "\n",
    "    The mean and std are scalars computed over all elements of ``x_train``\n",
    "    and are reused for ``x_test`` so no test-set statistics leak into training.\n",
    "\n",
    "    Args:\n",
    "        x_train: training inputs (numeric array).\n",
    "        x_test: test inputs, scaled with the training statistics.\n",
    "        eps: small constant added to the std to avoid division by zero.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(x_train_norm, x_test_norm)`` of standardized arrays.\n",
    "    \"\"\"\n",
    "    mean, std = np.mean(x_train), np.std(x_train)\n",
    "    x_train_norm = (x_train - mean) / (std + eps)\n",
    "    x_test_norm = (x_test - mean) / (std + eps)\n",
    "    # Return the standardized arrays, not the raw inputs.\n",
    "    return x_train_norm, x_test_norm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training configuration and data\n",
    "\n",
    "Hyperparameters, CIFAR-10 loading, the learning-rate / `k` schedules, and data augmentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# training parameters\n",
    "batch_size = 50\n",
    "epochs = 125\n",
    "\n",
    "learning_rate = 0.001  # initial learning rate\n",
    "lr_drop = 10  # decay the learning rate once every `lr_drop` epochs\n",
    "lr_decay = 0.6  # multiplicative decay factor applied at each drop\n",
    "laperm_k = 1000  # constant value returned by k_scheduler (passed to LaPermTrainLoop as k_schedule)\n",
    "\n",
    "tsize = 30000  # size of data for getting the train accuracy\n",
    "vali_freq = 250  # validate per vali_freq batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n",
    "x_train = x_train.astype('float32')\n",
    "x_test = x_test.astype('float32')\n",
    "x_train, x_test = normalize(x_train, x_test)\n",
    "\n",
    "# the model below is built for 32x32 RGB inputs (channels_last)\n",
    "assert x_train.shape[1:] == x_test.shape[1:] == (32, 32, 3), 'expected channels_last CIFAR-10 images'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lr_scheduler(epoch):\n",
    "    \"\"\"Step decay: scale `learning_rate` by `lr_decay` once every `lr_drop` epochs.\"\"\"\n",
    "    return learning_rate * (lr_decay ** (epoch // lr_drop))\n",
    "\n",
    "\n",
    "def k_scheduler(epoch):\n",
    "    \"\"\"Return `laperm_k` for every epoch (constant schedule).\"\"\"\n",
    "    return laperm_k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data augmentation: random rotations, shifts and horizontal flips\n",
    "datagen = ImageDataGenerator(\n",
    "    rotation_range=15,\n",
    "    width_shift_range=0.1,\n",
    "    height_shift_range=0.1,\n",
    "    horizontal_flip=True)\n",
    "# fit() only computes featurewise statistics / ZCA components, none of which\n",
    "# are enabled above; kept for parity with the original setup.\n",
    "datagen.fit(x_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "Two 64-filter convolution blocks followed by two 256-unit dense layers and a 10-way softmax."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_shape = (32, 32, 3)\n",
    "initializer = 'he_uniform'\n",
    "regularizer = regularizers.l2(1e-4)\n",
    "prune_rate = None  # e.g. 0.7 to call random_prune(model, prune_rate=...); None skips pruning\n",
    "\n",
    "# keyword arguments shared by every regularized Conv2D / Dense layer\n",
    "reg_kwargs = dict(kernel_regularizer=regularizer,\n",
    "                  kernel_initializer=initializer,\n",
    "                  bias_initializer=initializer)\n",
    "\n",
    "model = Sequential()\n",
    "model.add(Conv2D(64, (3, 3), padding='same', input_shape=x_shape, **reg_kwargs))\n",
    "model.add(Activation('relu'))\n",
    "model.add(BatchNormalization())\n",
    "model.add(Dropout(0.3))\n",
    "\n",
    "model.add(Conv2D(64, (3, 3), padding='same', **reg_kwargs))\n",
    "model.add(Activation('relu'))\n",
    "model.add(BatchNormalization())\n",
    "\n",
    "model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "\n",
    "model.add(Flatten())\n",
    "model.add(Dense(256, **reg_kwargs))\n",
    "model.add(Activation('relu'))\n",
    "model.add(Dropout(0.5))\n",
    "\n",
    "model.add(BatchNormalization())\n",
    "model.add(Dense(256, **reg_kwargs))\n",
    "model.add(Activation('relu'))\n",
    "model.add(BatchNormalization())\n",
    "\n",
    "model.add(Dropout(0.5))\n",
    "# output layer: initializers only, no kernel regularizer\n",
    "model.add(Dense(10, kernel_initializer=initializer, bias_initializer=initializer))\n",
    "model.add(Activation('softmax'))\n",
    "\n",
    "if prune_rate is not None:\n",
    "    random_prune(model, prune_rate=prune_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train with LaPerm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the model in a LaPerm training loop driven by the k / learning-rate schedules.\n",
    "inner_optimizer = tf.keras.optimizers.Adam()\n",
    "loop = LaPermTrainLoop(model=model,\n",
    "                       loss='sparse_categorical_crossentropy',\n",
    "                       inner_optimizer=inner_optimizer,\n",
    "                       k_schedule=k_scheduler,\n",
    "                       lr_schedule=lr_scheduler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Long-running cell: all options come from the configuration cells above.\n",
    "fit_options = dict(epochs=epochs,\n",
    "                   datagen=datagen,\n",
    "                   validation_data=(x_test, y_test),\n",
    "                   validation_freq=vali_freq,\n",
    "                   tsize=tsize)\n",
    "loop.fit(x_train, y_train, batch_size, **fit_options)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results\n",
    "\n",
    "Train and validation accuracy recorded by the LaPerm training loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize train and validation accuracies.\n",
    "# NOTE(review): `_history` is a private attribute of LaPermTrainLoop. Confirm whether\n",
    "# its entries are recorded per epoch or per validation step (vali_freq counts batches);\n",
    "# the x-axis label below assumes epochs.\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "ax.plot(loop._history['val accuracy'], label='Validation Accuracy')\n",
    "ax.plot(loop._history['accuracy'], label='Train Accuracy')\n",
    "ax.grid(linestyle='--')\n",
    "ax.set_xlabel('Epochs', size=15)\n",
    "ax.set_ylabel('Accuracy', size=15)\n",
    "ax.set_title('LaPerm training on CIFAR-10', size=15)\n",
    "ax.legend(prop={'size': 15})\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
