{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras import initializers, activations, constraints, regularizers\n",
    "import keras.backend as K\n",
    "from keras.engine.topology import Layer\n",
    "import tensorflow as tf\n",
    "from graph_arma_conv_op import *\n",
    "\n",
    "class ArmaGraphFiltersCNN(Layer):\n",
    "    \n",
    "    def __init__(self,\n",
    "                output_dim,\n",
    "                num_filters,\n",
    "                arma_conv_filters,\n",
    "                numerator_poly_coeff,\n",
    "                denominator_poly_coeff,\n",
    "                max_iterations,\n",
    "                tolerance=1e-10,\n",
    "                activation=None,\n",
    "                use_bias=True,\n",
    "                kernel_initializer='glorot_uniform',\n",
    "                bias_initializer='zeros',\n",
    "                kernel_regularizer=None,\n",
    "                bias_regularizer=None,\n",
    "                activity_regularizer=None,\n",
    "                kernel_constraint=None,\n",
    "                bias_constraint=None,\n",
    "                **kwargs):\n",
    "        super(ArmaGraphFiltersCNN, self).__init__(**kwargs)\n",
    "        \n",
    "        self.output_dim = output_dim\n",
    "        self.num_filters = num_filters\n",
    "        self.arma_conv_filters = arma_conv_filters\n",
    "        self.numerator_poly_coeff = numerator_poly_coeff\n",
    "        self.denominator_poly_coeff = denominator_poly_coeff\n",
    "        self.max_iterations = max_iterations\n",
    "        self.tolerance = tolerance\n",
    "        \n",
    "        self.activation = activations.get(activation)\n",
    "        self.use_bias = use_bias\n",
    "        self.kernel_initializer = initializers.get(kernel_initializer)\n",
    "        self.kernel_initializer.__name__ = kernel_initializer\n",
    "        self.bias_initializer = initializers.get(bias_initializer)\n",
    "        self.kernel_regularizer = regularizers.get(kernel_regularizer)\n",
    "        self.bias_regularizer = regularizers.get(bias_regularizer)\n",
    "        self.activity_regularizer = regularizers.get(activity_regularizer)\n",
    "        self.kernel_constraint = constraints.get(kernel_constraint)\n",
    "        self.bias_constraint = constraints.get(bias_constraint)\n",
    "        \n",
    "    def build(self, input_shape):\n",
    "        \n",
    "        self.input_dim = input_shape[-1]\n",
    "        kernel_shape = (self.num_filters * self.input_dim, self.output_dim)\n",
    "        \n",
    "        self.kernel = self.add_weight(shape=kernel_shape,\n",
    "                                      initializer=self.kernel_initializer,\n",
    "                                      name='kernel',\n",
    "                                      regularizer=self.kernel_regularizer,\n",
    "                                      constraint=self.kernel_constraint)\n",
    "        if(self.use_bias):\n",
    "            self.bias = self.add_weight(shape=(self.output_dim,),\n",
    "                                        initializer=self.kernel_initializer,\n",
    "                                        name='bias',\n",
    "                                        regularizer=self.bias_regularizer,\n",
    "                                        constraint=self.bias_constraint)\n",
    "        else:\n",
    "            self.bias = None\n",
    "            \n",
    "        self.built = True\n",
    "        \n",
    "    def call(self, input):\n",
    "        \n",
    "        output = arma_graph_conv(input,\n",
    "                                 self.num_filters,\n",
    "                                 self.arma_conv_filters,\n",
    "                                 self.numerator_poly_coeff,\n",
    "                                 self.denominator_poly_coeff,\n",
    "                                 self.max_iterations,\n",
    "                                 self.tolerance,\n",
    "                                 self.kernel)\n",
    "        if(self.use_bias):\n",
    "            output = K.bias_add(output, self.bias)\n",
    "        if(self.activation is not None):\n",
    "            output = self.activation(output)\n",
    "        \n",
    "        return output\n",
    "    \n",
    "    def compute_output_shape(self, input_shape):\n",
    "        output_shape = (input_shape[0], self.output_dim)\n",
    "        return output_shape\n",
    "    \n",
    "    def get_config(self):\n",
    "        \n",
    "        config = {\n",
    "            'output_dim': self.output_dim,\n",
    "            'num_filters': self.num_filters,\n",
    "            'arma_conv_filters': self.arma_conv_filters,\n",
    "            'numerator_poly_coeff': self.numerator_poly_coeff,\n",
    "            'denominator_poly_coeff': self.denominator_poly_coeff,\n",
    "            'max_iterations': self.max_iterations,\n",
    "            'tolerance': self.tolerance,\n",
    "            'activation': activations.serialize(self.activation),\n",
    "            'use_bias': self.use_bias,\n",
    "            'kernel_initializer': initializers.serialize(self.kernel_initializer),\n",
    "            'bias_initializer': initializers.serialize(self.bias_initializer),\n",
    "            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),\n",
    "            'bias_regularizer': regularizers.serialize(self.bias_regularizer),\n",
    "            'activity_regularizer': regularizers.serialize(self.activity_regularizer),\n",
    "            'kernel_constraint': constraints.serialize(self.kernel_constraint),\n",
    "            'bias_constraint': constraints.serialize(self.bias_constraint)\n",
    "        }\n",
    "        \n",
    "        base_config = super(ArmaGraphFiltersCNN, self).get_config()\n",
    "        \n",
    "        return dict(list(base_config.items()) + list(config.items()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
