{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from train_by_reconnect.LaPerm import LaPermTrainLoop\n",
    "from train_by_reconnect.weight_utils import random_prune\n",
    "from train_by_reconnect.viz_utils import Profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.datasets import cifar10\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten\n",
    "from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization\n",
    "from tensorflow.keras import optimizers\n",
    "from tensorflow.keras import regularizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize(x_train, x_test):\n",
    "    # normalize inputs for zero mean and unit variance\n",
    "    mean, std = np.mean(x_train), np.std(x_train)\n",
    "    X_train = (x_train-mean)/(std+1e-8)\n",
    "    X_test = (x_test-mean)/(std+1e-8)\n",
    "    return x_train, x_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- optimisation hyper-parameters ---\n",
    "batch_size, epochs = 50, 125\n",
    "learning_rate = 0.001  # initial learning rate\n",
    "lr_drop = 10           # read by lr_scheduler as the epoch interval between drops\n",
    "\n",
    "# --- evaluation settings ---\n",
    "tsize = 30000     # size of data for getting the train accuracy\n",
    "vali_freq = 250   # validate per vali_freq batches\n",
    "\n",
    "# --- data: load CIFAR-10, cast the images to float32, then normalize ---\n",
    "(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n",
    "x_train, x_test = normalize(x_train.astype('float32'),\n",
    "                            x_test.astype('float32'))\n",
    "\n",
    "def lr_scheduler(epoch):\n",
    "    \"\"\"Step-decay learning rate for the given epoch.\n",
    "\n",
    "    Starts from the notebook-level `learning_rate` and multiplies it by 0.6\n",
    "    once for every `lr_drop` completed epochs (both are globals defined in\n",
    "    this cell).\n",
    "    \"\"\"\n",
    "    # Use the configured `learning_rate` instead of a second hard-coded copy,\n",
    "    # so editing the hyper-parameter actually changes the schedule.\n",
    "    return learning_rate * (0.6 ** (epoch // lr_drop))\n",
    "\n",
    "def k_scheduler(epoch):\n",
    "    return 1000\n",
    "\n",
    "# data augmentation: random rotations, shifts and horizontal flips\n",
    "datagen = ImageDataGenerator(\n",
    "    rotation_range=15,\n",
    "    width_shift_range=0.1,\n",
    "    height_shift_range=0.1,\n",
    "    horizontal_flip=True)\n",
    "# NOTE: datagen.fit(x_train) is intentionally not called. Keras only needs\n",
    "# fit() when featurewise_center, featurewise_std_normalization or\n",
    "# zca_whitening is enabled; none are, so it would only copy the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_shape = (32, 32, 3)\n",
    "\n",
    "initializer = 'he_uniform'\n",
    "regularizer = regularizers.l2(1e-4)\n",
    "\n",
    "\n",
    "def add_conv_unit(model, filters, dropout_rate=None, **conv_kwargs):\n",
    "    \"\"\"Append one conv unit to `model`.\n",
    "\n",
    "    The unit is a 3x3 'same'-padded Conv2D, then ReLU, then\n",
    "    BatchNormalization, and finally Dropout when `dropout_rate` is given.\n",
    "\n",
    "    Args:\n",
    "        model: the Sequential model being built (modified in place).\n",
    "        filters: number of output channels of the convolution.\n",
    "        dropout_rate: rate of the trailing Dropout layer, or None to skip it.\n",
    "        **conv_kwargs: extra Conv2D arguments (used for `input_shape`).\n",
    "    \"\"\"\n",
    "    model.add(Conv2D(filters, (3, 3), padding='same',\n",
    "                     kernel_regularizer=regularizer,\n",
    "                     kernel_initializer=initializer,\n",
    "                     bias_initializer=initializer,\n",
    "                     **conv_kwargs))\n",
    "    model.add(Activation('relu'))\n",
    "    model.add(BatchNormalization())\n",
    "    if dropout_rate is not None:\n",
    "        model.add(Dropout(dropout_rate))\n",
    "\n",
    "\n",
    "# One entry per stage: (filters, dropout rate after each conv unit).\n",
    "# None means the unit has no dropout; every stage ends with 2x2 max pooling.\n",
    "conv_stages = [\n",
    "    (64, [0.3, None]),\n",
    "    (128, [0.4, None]),\n",
    "    (256, [0.4, 0.4, None]),\n",
    "    (512, [0.4, 0.4, None]),\n",
    "    (512, [0.4, 0.4, None]),\n",
    "]\n",
    "\n",
    "model = Sequential()\n",
    "\n",
    "# Only the very first layer declares the input shape.\n",
    "input_kwargs = {'input_shape': x_shape}\n",
    "for filters, dropout_rates in conv_stages:\n",
    "    for rate in dropout_rates:\n",
    "        add_conv_unit(model, filters, rate, **input_kwargs)\n",
    "        input_kwargs = {}\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "model.add(Dropout(0.5))\n",
    "\n",
    "# classifier head\n",
    "model.add(Flatten())\n",
    "model.add(Dense(512, kernel_regularizer=regularizer,\n",
    "                kernel_initializer=initializer, bias_initializer=initializer))\n",
    "model.add(Activation('relu'))\n",
    "model.add(BatchNormalization())\n",
    "\n",
    "model.add(Dropout(0.5))\n",
    "# The output layer carries no kernel regularizer.\n",
    "model.add(Dense(10, kernel_initializer=initializer, bias_initializer=initializer))\n",
    "model.add(Activation('softmax'))\n",
    "\n",
    "# random_prune(model, prune_rate=0.7) # uncomment for random pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the LaPerm training loop around `model`, with Adam as the inner\n",
    "# optimizer and the two schedule functions defined earlier in the notebook.\n",
    "inner_optimizer = tf.keras.optimizers.Adam()\n",
    "loop = LaPermTrainLoop(\n",
    "    model=model,\n",
    "    loss='sparse_categorical_crossentropy',\n",
    "    inner_optimizer=inner_optimizer,\n",
    "    k_schedule=k_scheduler,\n",
    "    lr_schedule=lr_scheduler,\n",
    ")\n",
    "\n",
    "# Keyword options for fit(); the test split is passed as validation data.\n",
    "fit_options = dict(\n",
    "    epochs=epochs,\n",
    "    datagen=datagen,\n",
    "    validation_data=(x_test, y_test),\n",
    "    validation_freq=vali_freq,\n",
    "    tsize=tsize,\n",
    ")\n",
    "loop.fit(x_train, y_train, batch_size, **fit_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the model to the project's Profiler (train_by_reconnect.viz_utils).\n",
    "# NOTE(review): presumably this renders a profile of the trained weights;\n",
    "# confirm against viz_utils.\n",
    "Profiler(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize train and validation accuracies\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Explicit figure/axes objects instead of the implicit pyplot state machine.\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "curves = [('val accuracy', 'Validation Accuracy'),\n",
    "          ('accuracy', 'Train Accuracy')]\n",
    "for history_key, legend_label in curves:\n",
    "    ax.plot(loop._history[history_key], label=legend_label)\n",
    "ax.grid(linestyle='--')\n",
    "ax.set_xlabel('Epochs', size=15)\n",
    "ax.set_ylabel('Accuracy', size=15)\n",
    "ax.legend(prop={'size': 15})\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
