{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Functional Encryption - Classification and information leakage\n",
    "\n",
    "This notebook converts the private layers of the trained model into a format compatible with the QFE project, which implements Quadratic Functional Encryption for real.\n",
    "\n",
    "- Make sure the PyTorch model is saved at the appropriate path.\n",
    "- Load it back, then transform and pickle the private layers so that they can be read by the QFE project."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Allow to load packages from parent\n",
    "import sys, os\n",
    "sys.path.insert(1, os.path.realpath(os.path.pardir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from math import log2, ceil\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.data as utils\n",
    "\n",
    "import learn\n",
    "from learn import collateral\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "PRIVATE_OUTPUT_SIZE = 4\n",
    "N_CHARS = 10\n",
    "N_FONTS = 2\n",
    "prec=(3, 7, 5)\n",
    "\n",
    "HIDDEN_WIDTH = 40"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To load the PyTorch model, you need to re-specify its structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CollateralNet(nn.Module):\n",
    "    \"\"\"Private quadratic network followed by a character head and a font head.\n",
    "\n",
    "    The quadratic part (proj1, element-wise square, diag1) maps an input\n",
    "    flattened to 784 values onto PRIVATE_OUTPUT_SIZE values. That output is\n",
    "    consumed by a feed-forward classifier over N_CHARS classes and by a CNN\n",
    "    classifier over N_FONTS classes.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Quadratic (private) part\n",
    "        self.proj1 = nn.Linear(784, HIDDEN_WIDTH)\n",
    "        self.diag1 = nn.Linear(HIDDEN_WIDTH, PRIVATE_OUTPUT_SIZE, bias=False)\n",
    "\n",
    "        # Feed-forward head for characters\n",
    "        self.lin1 = nn.Linear(PRIVATE_OUTPUT_SIZE, 32)\n",
    "        self.lin2 = nn.Linear(32, N_CHARS)\n",
    "\n",
    "        # Junction mapping the private output back to 784 values\n",
    "        self.jct = nn.Linear(PRIVATE_OUTPUT_SIZE, 784)\n",
    "\n",
    "        # Convolutional head for font families\n",
    "        self.conv1 = nn.Conv2d(1, 20, 5, 1)\n",
    "        self.conv2 = nn.Conv2d(20, 50, 5, 1)\n",
    "        self.fc1 = nn.Linear(4 * 4 * 50, 500)\n",
    "        self.fc2 = nn.Linear(500, N_FONTS)\n",
    "\n",
    "    def quad(self, x):\n",
    "        \"\"\"Apply the quadratic part: flatten, proj1, element-wise square, diag1.\"\"\"\n",
    "        flat = x.view(-1, 784)\n",
    "        hidden = self.proj1(flat)\n",
    "        return self.diag1(hidden * hidden)\n",
    "\n",
    "    def char_net(self, x):\n",
    "        \"\"\"Character head: ReLU, lin1, ReLU, lin2; returns the raw scores.\"\"\"\n",
    "        activated = F.relu(x)\n",
    "        hidden = F.relu(self.lin1(activated))\n",
    "        return self.lin2(hidden)\n",
    "\n",
    "    def font_net(self, x):\n",
    "        \"\"\"Font head: junction to a 1x28x28 map, two conv/pool stages, two linear layers.\"\"\"\n",
    "        # Junction: expand the private output and reshape it as a single-channel map\n",
    "        image = self.jct(x).view(-1, 1, 28, 28)\n",
    "\n",
    "        # CNN\n",
    "        feats = F.max_pool2d(F.relu(self.conv1(image)), 2, 2)\n",
    "        feats = F.max_pool2d(F.relu(self.conv2(feats)), 2, 2)\n",
    "        feats = feats.view(-1, 4 * 4 * 50)\n",
    "        return self.fc2(F.relu(self.fc1(feats)))\n",
    "\n",
    "    def forward_char(self, x):\n",
    "        \"\"\"Log-probabilities of the character classes for input x.\"\"\"\n",
    "        scores = self.char_net(self.quad(x))\n",
    "        return F.log_softmax(scores, dim=1)\n",
    "\n",
    "    def forward_font(self, x):\n",
    "        \"\"\"Log-probabilities of the font classes for input x.\"\"\"\n",
    "        scores = self.font_net(self.quad(x))\n",
    "        return F.log_softmax(scores, dim=1)\n",
    "\n",
    "    # Layers can be frozen so that training the collateral task does not\n",
    "    # modify the quadratic net.\n",
    "\n",
    "    def get_params(self, net):\n",
    "        \"\"\"Return the list of parameters of one part of the net.\n",
    "\n",
    "        net is one of 'quad', 'char' or 'font'; any other value raises\n",
    "        AttributeError.\n",
    "        \"\"\"\n",
    "        groups = (\n",
    "            ('quad', [self.proj1, self.diag1]),\n",
    "            ('char', [self.lin1, self.lin2]),\n",
    "            ('font', [self.jct, self.fc1, self.fc2, self.conv1, self.conv2]),\n",
    "        )\n",
    "        for label, layers in groups:\n",
    "            if net == label:\n",
    "                selected = []\n",
    "                for layer in layers:\n",
    "                    selected.extend(layer.parameters())\n",
    "                return selected\n",
    "        raise AttributeError(f'{net} type not recognized')\n",
    "\n",
    "    def freeze(self, net):\n",
    "        \"\"\"Disable gradients for every parameter of the given part of the net.\"\"\"\n",
    "        for tensor in self.get_params(net):\n",
    "            tensor.requires_grad = False\n",
    "\n",
    "    def unfreeze(self):\n",
    "        \"\"\"Re-enable gradients for every parameter of the net.\"\"\"\n",
    "        for tensor in self.parameters():\n",
    "            tensor.requires_grad = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This path needs to be adapted to your project setup."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CollateralNet(\n",
       "  (proj1): Linear(in_features=784, out_features=40, bias=True)\n",
       "  (diag1): Linear(in_features=40, out_features=4, bias=False)\n",
       "  (lin1): Linear(in_features=4, out_features=32, bias=True)\n",
       "  (lin2): Linear(in_features=32, out_features=10, bias=True)\n",
       "  (jct): Linear(in_features=4, out_features=784, bias=True)\n",
       "  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))\n",
       "  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))\n",
       "  (fc1): Linear(in_features=800, out_features=500, bias=True)\n",
       "  (fc2): Linear(in_features=500, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Root of the local code checkout; adapt to your machine\n",
    "CODE_PATH = \"/Users/.../code/\"\n",
    "path = CODE_PATH + 'QFEproject/mnist/objects/ml_models/quad_conv.pt'\n",
    "model = CollateralNet()\n",
    "results = {}\n",
    "\n",
    "# map_location='cpu' lets a checkpoint saved from a GPU be loaded on a CPU-only machine\n",
    "model.load_state_dict(torch.load(path, map_location='cpu'))\n",
    "# Switch to evaluation mode; the returned module is displayed as the cell output\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[-2., -2.,  2.,  ...,  0.,  2.,  3.],\n",
       "        [ 2.,  0., -4.,  ...,  1.,  0., -2.],\n",
       "        [ 3.,  0.,  0.,  ...,  0.,  0., -3.],\n",
       "        ...,\n",
       "        [-1.,  3., -2.,  ..., -2., -2.,  0.],\n",
       "        [ 4., -3., -3.,  ...,  4., -3.,  3.],\n",
       "        [ 0.,  0., -3.,  ..., -1., -2., -3.]], requires_grad=True)"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.proj1.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([784, 40])"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.proj1.weight.t().shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_prec, proj_prec, diag_prec = prec\n",
    "\n",
    "# First row: proj1 bias scaled down by 2**data_prec; following rows: transposed proj1 weights\n",
    "bias_row = model.proj1.bias.reshape(1, HIDDEN_WIDTH) / 2**data_prec\n",
    "weight_rows = model.proj1.weight.t()\n",
    "\n",
    "# Cast to integers and convert to nested Python lists\n",
    "proj = torch.cat((bias_row, weight_rows), 0).long().tolist()\n",
    "diag = model.diag1.weight.t().long().tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pickle the private layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity checks: one bias row plus 784 weight rows, and matching inner dimensions\n",
    "assert len(proj) == (784 + 1)\n",
    "assert len(proj[0]) == len(diag)\n",
    "assert len(diag[0]) == PRIVATE_OUTPUT_SIZE\n",
    "\n",
    "# Use a dedicated name for the export: rebinding `model` would replace the\n",
    "# nn.Module with a tuple and break re-running the cells above.\n",
    "private_layers = (proj, diag)\n",
    "\n",
    "# Build the output location from CODE_PATH so it follows the path configured earlier\n",
    "output_path = CODE_PATH + 'QFEproject/mnist/objects/ml_models/torch_cl_large.mlm'\n",
    "with open(output_path, 'wb') as f:\n",
    "    pickle.dump(private_layers, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
