{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "occupancy.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "4mavG5wNgb09",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import jax\n",
        "from jax import random, grad, jit, vmap\n",
        "from jax.config import config\n",
        "from jax.lib import xla_bridge\n",
        "import jax.numpy as np\n",
        "\n",
        "from jax.experimental import stax\n",
        "from jax.experimental import optimizers\n",
        "\n",
        "from livelossplot import PlotLosses\n",
        "import matplotlib.pyplot as plt\n",
        "from tqdm.notebook import tqdm as tqdm\n",
        "\n",
        "import time\n",
        "import imageio\n",
        "import json\n",
        "import os\n",
        "\n",
        "import numpy as onp\n",
        "\n",
        "from IPython.display import clear_output\n",
        "\n",
        "## Random seed\n",
        "# JAX PRNG key; handed to init_fn when the network parameters are created.\n",
        "# NOTE: the NumPy global RNG (onp.random) used for point sampling is not seeded in this cell.\n",
        "rand_key = random.PRNGKey(0)\n",
        "\n",
        "# Colors of matplotlib's default property cycle.\n",
        "prop_cycle = plt.rcParams['axes.prop_cycle']\n",
        "colors = prop_cycle.by_key()['color']\n",
        "\n",
        "basedir = '' # base output dir; the log directory is created beneath it\n",
        "\n",
        "import trimesh\n",
        "import pyembree"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "A8PxkO55hMHO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def as_mesh(scene_or_mesh):\n",
        "    \"\"\"\n",
        "    Flatten a trimesh Scene into a single Trimesh; pass a Trimesh through.\n",
        "\n",
        "    Returns None for a scene without geometry. When a scene is merged, each\n",
        "    geometry is rebuilt from its vertices and faces only, so texture\n",
        "    information is dropped.\n",
        "    \"\"\"\n",
        "    if not isinstance(scene_or_mesh, trimesh.Scene):\n",
        "        assert(isinstance(scene_or_mesh, trimesh.Trimesh))\n",
        "        return scene_or_mesh\n",
        "\n",
        "    if len(scene_or_mesh.geometry) == 0:\n",
        "        return None  # empty scene\n",
        "\n",
        "    # Keep bare vertex/face data of every geometry, then merge them.\n",
        "    stripped = tuple(\n",
        "        trimesh.Trimesh(vertices=geom.vertices, faces=geom.faces)\n",
        "        for geom in scene_or_mesh.geometry.values())\n",
        "    return trimesh.util.concatenate(stripped)\n",
        "\n",
        "\n",
        "def recenter_mesh(mesh):\n",
        "  mesh.vertices -= mesh.vertices.mean(0)\n",
        "  mesh.vertices /= np.max(np.abs(mesh.vertices))\n",
        "  mesh.vertices = .5 * (mesh.vertices + 1.)\n",
        "\n",
        "\n",
        "def load_mesh(mesh_name, verbose=True):\n",
        "  \"\"\"\n",
        "  Load a named mesh, normalize it, and load (or create) its cached test points.\n",
        "\n",
        "  Reads the file path from the global mesh_files dict and caches the test\n",
        "  points in the global logdir as '<mesh_name>_test_pts.npy'.\n",
        "\n",
        "  Args:\n",
        "    mesh_name: key into mesh_files.\n",
        "    verbose: when True, print diagnostics about the mesh and the cache.\n",
        "\n",
        "  Returns:\n",
        "    (mesh, corners, test_pts): the normalized mesh, the [c0, c1] bounding-box\n",
        "    corners padded by 1e-3, and the array of test points.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if the loaded file is a scene without any geometry.\n",
        "  \"\"\"\n",
        "\n",
        "  mesh = trimesh.load(mesh_files[mesh_name])\n",
        "  mesh = as_mesh(mesh)\n",
        "  if mesh is None:\n",
        "    # as_mesh returns None for an empty scene; fail here with a clear message\n",
        "    # instead of an AttributeError on mesh.vertices below.\n",
        "    raise ValueError(f'mesh {mesh_name!r} ({mesh_files[mesh_name]}) contains no geometry')\n",
        "  if verbose:\n",
        "    print(mesh.vertices.shape)\n",
        "  recenter_mesh(mesh)\n",
        "\n",
        "  # Axis-aligned bounds of the normalized mesh, padded by 1e-3 on each side.\n",
        "  c0, c1 = mesh.vertices.min(0) - 1e-3, mesh.vertices.max(0) + 1e-3\n",
        "  corners = [c0, c1]\n",
        "  if verbose:\n",
        "    print(c0, c1)\n",
        "    print(c1-c0)\n",
        "    print(np.prod(c1-c0))\n",
        "    print(.5 * (c0+c1) * 2 - 1)\n",
        "\n",
        "\n",
        "  test_pt_file = os.path.join(logdir, mesh_name + '_test_pts.npy')\n",
        "  if not os.path.exists(test_pt_file):\n",
        "    if verbose: print('regen pts')\n",
        "    # Two independent draws; run_training unpacks them as (validation, testing).\n",
        "    test_pts = np.array([make_test_pts(mesh, corners), make_test_pts(mesh, corners)])\n",
        "    np.save(test_pt_file, test_pts)\n",
        "  else:\n",
        "    if verbose: print('load pts')\n",
        "    test_pts = np.load(test_pt_file)\n",
        "\n",
        "  if verbose: print(test_pts.shape)\n",
        "\n",
        "  return mesh, corners, test_pts\n",
        "\n",
        "\n",
        "###################\n",
        "\n",
        "\n",
        "\n",
        "def make_network(num_layers, num_channels):\n",
        "  layers = []\n",
        "  for i in range(num_layers-1):\n",
        "      layers.append(stax.Dense(num_channels))\n",
        "      layers.append(stax.Relu)\n",
        "  layers.append(stax.Dense(1))\n",
        "  return stax.serial(*layers)\n",
        "\n",
        "\n",
        "input_encoder = jit(lambda x, a, b: (np.concatenate([a * np.sin((2.*np.pi*x) @ b.T), \n",
        "                                                     a * np.cos((2.*np.pi*x) @ b.T)], axis=-1) / np.linalg.norm(a)) if a is not None else (x * 2. - 1.))\n",
        "\n",
        "\n",
        "\n",
        "trans_t = lambda t : np.array([\n",
        "    [1,0,0,0],\n",
        "    [0,1,0,0],\n",
        "    [0,0,1,t],\n",
        "    [0,0,0,1],\n",
        "], dtype=np.float32)\n",
        "\n",
        "rot_phi = lambda phi : np.array([\n",
        "    [1,0,0,0],\n",
        "    [0,np.cos(phi),-np.sin(phi),0],\n",
        "    [0,np.sin(phi), np.cos(phi),0],\n",
        "    [0,0,0,1],\n",
        "], dtype=np.float32)\n",
        "\n",
        "rot_theta = lambda th : np.array([\n",
        "    [np.cos(th),0,-np.sin(th),0],\n",
        "    [0,1,0,0],\n",
        "    [np.sin(th),0, np.cos(th),0],\n",
        "    [0,0,0,1],\n",
        "], dtype=np.float32)\n",
        "\n",
        "\n",
        "def pose_spherical(theta, phi, radius):\n",
        "    \"\"\"\n",
        "    Compose a 4x4 camera-to-world matrix from angles in degrees and a radius.\n",
        "\n",
        "    Applies, in order: translation by radius (trans_t), rotation by phi\n",
        "    (rot_phi), rotation by theta (rot_theta).\n",
        "    \"\"\"\n",
        "    tilted = rot_phi(phi/180.*np.pi) @ trans_t(radius)\n",
        "    # An extra axis-swap matrix was applied here at some point; it is disabled:\n",
        "    # np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]]) @ c2w\n",
        "    return rot_theta(theta/180.*np.pi) @ tilted\n",
        "\n",
        "\n",
        "\n",
        "def get_rays(H, W, focal, c2w):\n",
        "    i, j = np.meshgrid(np.arange(W), np.arange(H), indexing='xy')\n",
        "    dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1)\n",
        "    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)\n",
        "    rays_o = np.broadcast_to(c2w[:3,-1], rays_d.shape)\n",
        "    return np.stack([rays_o, rays_d], 0)\n",
        "\n",
        "get_rays = jit(get_rays, static_argnums=(0, 1, 2,))\n",
        "\n",
        "#########\n",
        "\n",
        "\n",
        "def render_rays_native_hier(params, ab, rays, corners, near, far, N_samples, N_samples_2, clip):\n",
        "    \"\"\"\n",
        "    Two-pass ray march through the thresholded occupancy network.\n",
        "\n",
        "    Pass 1 samples N_samples depths linearly in [near, far]. Pass 2 samples\n",
        "    N_samples_2 depths within +/- 0.01 of the pass-1 depth of each ray.\n",
        "    In both passes the sigmoid of the network output (global apply_fn) is\n",
        "    binarized at 0.5; with clip set, samples outside the box given by corners\n",
        "    are forced to empty first.\n",
        "\n",
        "    Args:\n",
        "      params: network parameters for apply_fn.\n",
        "      ab: (a, b) pair forwarded to input_encoder.\n",
        "      rays: stacked ray origins (rays[0]) and directions (rays[1]).\n",
        "      corners: (c0, c1) box corners, compared against .5 * (pts + 1).\n",
        "      near, far: depth range of the first pass.\n",
        "      N_samples, N_samples_2: sample counts of the first / second pass.\n",
        "      clip: whether to zero the occupancy outside the box.\n",
        "\n",
        "    Returns:\n",
        "      (depth_map, acc_map) of the second pass.\n",
        "    \"\"\"\n",
        "    rays_o, rays_d = rays[0], rays[1]\n",
        "    c0, c1 = corners\n",
        "    th = .5\n",
        "\n",
        "    def march(z_vals):\n",
        "        # 3D query points along every ray, then mapped to network coordinates\n",
        "        pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]\n",
        "        unit_pts = .5 * (pts + 1)\n",
        "\n",
        "        # Run network\n",
        "        alpha = jax.nn.sigmoid(np.squeeze(apply_fn(params, input_encoder(unit_pts, *ab))))\n",
        "        if clip:\n",
        "            outside = np.logical_or(np.any(unit_pts < c0, -1), np.any(unit_pts > c1, -1))\n",
        "            alpha = np.where(outside, 0., alpha)\n",
        "\n",
        "        # Binarize, then weight each sample by the (shifted) transmittance before it\n",
        "        alpha = np.where(alpha > th, 1., 0)\n",
        "        trans = 1.-alpha + 1e-10\n",
        "        trans = np.concatenate([np.ones_like(trans[...,:1]), trans[...,:-1]], -1)\n",
        "        weights = alpha * np.cumprod(trans, -1)\n",
        "\n",
        "        return np.sum(weights * z_vals, -1), np.sum(weights, -1)\n",
        "\n",
        "    # First pass over the full depth range\n",
        "    coarse_depth, _ = march(np.linspace(near, far, N_samples))\n",
        "\n",
        "    # Second pass to refine isosurface\n",
        "    fine_z = np.linspace(-1., 1., N_samples_2) * .01 + coarse_depth[...,None]\n",
        "    depth_map, acc_map = march(fine_z)\n",
        "\n",
        "    return depth_map, acc_map\n",
        "\n",
        "render_rays = jit(render_rays_native_hier, static_argnums=(3,4,5,6,7,8))\n",
        "\n",
        "\n",
        "\n",
        "@jit\n",
        "def make_normals(rays, depth_map):\n",
        "  rays_o, rays_d = rays\n",
        "  pts = rays_o + rays_d * depth_map[...,None]\n",
        "  dx = pts - np.roll(pts, -1, axis=0)\n",
        "  dy = pts - np.roll(pts, -1, axis=1)\n",
        "  normal_map = np.cross(dx, dy)\n",
        "  normal_map = normal_map / np.maximum(np.linalg.norm(normal_map, axis=-1, keepdims=True), 1e-5)\n",
        "  return normal_map\n",
        "\n",
        "\n",
        "def render_mesh_normals(mesh, rays):\n",
        "  origins, dirs = rays.reshape([2,-1,3])\n",
        "  origins = origins * .5 + .5\n",
        "  dirs = dirs * .5\n",
        "  z = mesh.ray.intersects_first(origins, dirs)\n",
        "  pic = onp.zeros([origins.shape[0],3]) \n",
        "  pic[z!=-1] = mesh.face_normals[z[z!=-1]]\n",
        "  pic = np.reshape(pic, rays.shape[1:])\n",
        "  return pic\n",
        "\n",
        "def uniform_bary(u):\n",
        "  su0 = np.sqrt(u[..., 0])\n",
        "  b0 = 1. - su0\n",
        "  b1 = u[..., 1] * su0\n",
        "  return np.stack([b0, b1, 1. - b0 - b1], -1)\n",
        "\n",
        "\n",
        "def get_normal_batch(mesh, bsize):\n",
        "  \"\"\"\n",
        "  Sample bsize random points on the mesh surface with their face normals.\n",
        "\n",
        "  Faces are drawn uniformly by index (onp.random.randint); a point inside\n",
        "  each face is drawn via uniform_bary. Returns (batch_pts, batch_normals).\n",
        "  \"\"\"\n",
        "  face_ids = np.array(onp.random.randint(0, mesh.faces.shape[0], [bsize]))\n",
        "  barys = np.array(uniform_bary(onp.random.uniform(size=[bsize, 2])))\n",
        "\n",
        "  # Corner positions of every sampled face, blended by the barycentric weights.\n",
        "  tri_verts = mesh.vertices[mesh.faces[face_ids]]\n",
        "  batch_pts = np.sum(tri_verts * barys[...,None], 1)\n",
        "  batch_normals = mesh.face_normals[face_ids]\n",
        "\n",
        "  return batch_pts, batch_normals\n",
        "\n",
        "\n",
        "def make_test_pts(mesh, corners, test_size=2**18):\n",
        "  \"\"\"\n",
        "  Draw two sets of test_size evaluation points.\n",
        "\n",
        "  test_easy is uniform inside the box spanned by corners; test_hard is\n",
        "  surface samples from get_normal_batch perturbed by Gaussian noise with\n",
        "  standard deviation .01. Returns (test_easy, test_hard).\n",
        "  \"\"\"\n",
        "  lo, hi = corners\n",
        "  test_easy = lo + onp.random.uniform(size=[test_size, 3]) * (hi-lo)\n",
        "  surface_pts, _ = get_normal_batch(mesh, test_size)\n",
        "  jitter = onp.random.normal(size=[test_size,3]) * .01\n",
        "  test_hard = surface_pts + jitter\n",
        "  return test_easy, test_hard\n",
        "\n",
        "gt_fn = lambda queries, mesh : mesh.ray.contains_points(queries.reshape([-1,3])).reshape(queries.shape[:-1])\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z0qHpC4ahM2_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\n",
        "# Default embedding settings. The method name must be one that run_training\n",
        "# recognizes ('gauss', 'posenc', 'basic' or 'none').\n",
        "embedding_size = 256\n",
        "embedding_method = 'gauss'\n",
        "embedding_param = 12.\n",
        "embed_params = [embedding_method, embedding_size, embedding_param]\n",
        "\n",
        "# Network built by make_network(8, 256).\n",
        "init_fn, apply_fn = make_network(8, 256)\n",
        "\n",
        "# Optimization settings read by run_training.\n",
        "N_iters = 10000\n",
        "batch_size = 64*64*2 * 4\n",
        "lr = 5e-4\n",
        "# Learning-rate schedule: lr * .1 ** (i / 5000)\n",
        "step = optimizers.exponential_decay(lr, 5000, .1)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3sw-wlOlEtCV",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Camera pose shared by the low- and high-resolution renders.\n",
        "R = 2.\n",
        "c2w = pose_spherical(90. + 10 + 45, -30., R)\n",
        "\n",
        "# Layout of the render-argument lists (everything render_rays takes after\n",
        "# params and ab): [rays, corners, near, far, N_samples, N_samples_2, clip].\n",
        "# corners starts as None and is filled in per mesh before rendering.\n",
        "\n",
        "# Low-resolution preview used during training\n",
        "N_samples = 64\n",
        "N_samples_2 = 64\n",
        "H = 180\n",
        "W = H\n",
        "focal = H * .9\n",
        "rays = get_rays(H, W, focal, c2w[:3,:4])\n",
        "\n",
        "# Reuse the rays computed above instead of calling get_rays a second time.\n",
        "render_args_lr = [rays, None, R-1, R+1, N_samples, N_samples_2, True]\n",
        "\n",
        "# High-resolution final render. H, W, focal, N_samples, N_samples_2 and rays\n",
        "# keep these values after the cell has run.\n",
        "N_samples = 256\n",
        "N_samples_2 = 256\n",
        "H = 512\n",
        "W = H\n",
        "focal = H * .9\n",
        "rays = get_rays(H, W, focal, c2w[:3,:4])\n",
        "\n",
        "render_args_hr = [rays, None, R-1, R+1, N_samples, N_samples_2, True]\n",
        "\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BZH0Ymj4FJLN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\n",
        "def run_training(embed_params, mesh, corners, test_pts, render_args_lr, name=''):\n",
        "  \"\"\"\n",
        "  Train an occupancy network on one mesh with one input embedding, plotting progress.\n",
        "\n",
        "  Relies on notebook globals: init_fn, apply_fn, input_encoder, rand_key, step,\n",
        "  N_iters, batch_size, gt_fn, render_rays, make_normals and tests_all (IoU\n",
        "  curves of earlier runs, overlaid in the plots).\n",
        "\n",
        "  Args:\n",
        "    embed_params: [embedding_method, embedding_size, embedding_scale] with\n",
        "      embedding_method one of 'gauss', 'posenc', 'basic', 'none'.\n",
        "    mesh: mesh handed to gt_fn for ground-truth occupancy.\n",
        "    corners: (c0, c1) corners of the box training points are sampled from.\n",
        "    test_pts: pair (validation_pts, testing_pts); each is iterated as a\n",
        "      sequence of point sets.\n",
        "    render_args_lr: arguments forwarded to render_rays after (params, ab).\n",
        "    name: label printed with the progress line.\n",
        "\n",
        "  Returns:\n",
        "    [(params, ab), array of validation IoU histories, final test IoUs,\n",
        "    last preview renderings].\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if embedding_method is not one of the supported names.\n",
        "  \"\"\"\n",
        "\n",
        "  validation_pts, testing_pts = test_pts\n",
        "\n",
        "  # Fixed z = 0.5 slice of the unit cube, used for the preview images.\n",
        "  N = 256\n",
        "  x_test = np.linspace(0.,1.,N, endpoint=False) * 1.\n",
        "  x_test = np.stack(np.meshgrid(*([x_test]*2), indexing='ij'), -1)\n",
        "  queries_plot = np.concatenate([x_test, .5 + np.zeros_like(x_test[...,0:1])], -1)\n",
        "\n",
        "  embedding_method, embedding_size, embedding_scale = embed_params\n",
        "  c0, c1 = corners\n",
        "\n",
        "  # Frequency matrix (bvals) of the chosen embedding. Unknown names are\n",
        "  # rejected up front instead of failing later on an unbound bvals.\n",
        "  if embedding_method == 'gauss':\n",
        "    print('gauss bvals')\n",
        "    bvals = onp.random.normal(size=[embedding_size,3]) * embedding_scale\n",
        "  elif embedding_method == 'posenc':\n",
        "    print('posenc bvals')\n",
        "    bvals = 2.**np.linspace(0,embedding_scale,embedding_size//3) - 1\n",
        "    bvals = np.reshape(np.eye(3)*bvals[:,None,None], [len(bvals)*3, 3])\n",
        "  elif embedding_method == 'basic':\n",
        "    print('basic bvals')\n",
        "    bvals = np.eye(3)\n",
        "  elif embedding_method == 'none':\n",
        "    print('NO abvals')\n",
        "    bvals = None\n",
        "  else:\n",
        "    raise ValueError(f'unknown embedding_method: {embedding_method!r}')\n",
        "\n",
        "  # Unit amplitude per frequency row; None disables the encoding entirely.\n",
        "  avals = None if bvals is None else np.ones_like(bvals[:,0])\n",
        "\n",
        "  ab = (avals, bvals)\n",
        "  x_enc = input_encoder(np.ones([1,3]), avals, bvals)\n",
        "  print(x_enc.shape)\n",
        "\n",
        "  _, net_params = init_fn(rand_key, (-1, x_enc.shape[-1]))\n",
        "\n",
        "  opt_init, opt_update, get_params = optimizers.adam(step)\n",
        "  opt_state = opt_init(net_params)\n",
        "\n",
        "  @jit\n",
        "  def network_pred(params, inputs):\n",
        "    return jax.nn.sigmoid(np.squeeze(apply_fn(params, input_encoder(inputs, *ab))))\n",
        "\n",
        "  @jit\n",
        "  def loss_fn(params, inputs, z):\n",
        "    # Sigmoid cross-entropy between logits x and targets z, in the overflow-safe form.\n",
        "    x = (np.squeeze(apply_fn(params, input_encoder(inputs, *ab))[...,0]))\n",
        "    loss_main = np.mean(np.maximum(x, 0) - x * z + np.log(1 + np.exp(-np.abs(x))))\n",
        "    return loss_main\n",
        "\n",
        "  @jit\n",
        "  def step_fn(i, opt_state, inputs, outputs):\n",
        "    params = get_params(opt_state)\n",
        "    g = grad(loss_fn)(params, inputs, outputs)\n",
        "    return opt_update(i, g, opt_state)\n",
        "\n",
        "  psnrs = []\n",
        "  losses = []\n",
        "  tests = [[],[]]\n",
        "  xs = []\n",
        "\n",
        "  # Ground truth of the validation sets is computed once, not per evaluation.\n",
        "  gt_val = [gt_fn(test, mesh) for test in validation_pts]\n",
        "\n",
        "  for i in tqdm(range(N_iters+1)):\n",
        "\n",
        "    inputs = onp.random.uniform(size=[batch_size, 3]) * (c1-c0) + c0\n",
        "    opt_state = step_fn(i, opt_state, inputs, gt_fn(inputs, mesh))\n",
        "\n",
        "    if i%100==0:\n",
        "      clear_output(wait=True)\n",
        "\n",
        "      inputs = queries_plot\n",
        "      outputs = gt_fn(inputs, mesh)\n",
        "\n",
        "      losses.append(loss_fn(get_params(opt_state), inputs, outputs))\n",
        "\n",
        "      pred = network_pred(get_params(opt_state), inputs)\n",
        "      psnrs.append(-10.*np.log10(np.mean(np.square(pred-outputs))))\n",
        "      xs.append(i)\n",
        "      slices = [outputs, pred, np.abs(pred - outputs)]\n",
        "\n",
        "      renderings = list(render_rays(get_params(opt_state), ab, *render_args_lr))\n",
        "      renderings.append(make_normals(render_args_lr[0], renderings[0]) * .5 + .5)\n",
        "\n",
        "      for to_show in [slices, renderings]:\n",
        "        L = len(to_show)\n",
        "        plt.figure(figsize=(6*L,6))\n",
        "        # Dedicated loop names so the training counter i printed below is not clobbered.\n",
        "        for panel, img in enumerate(to_show):\n",
        "          plt.subplot(1,L,panel+1)\n",
        "          plt.imshow(img)\n",
        "          plt.colorbar()\n",
        "        plt.show()\n",
        "\n",
        "\n",
        "      plt.figure(figsize=(25,4))\n",
        "\n",
        "      plt.subplot(151)\n",
        "      plt.plot(xs, psnrs)\n",
        "      plt.subplot(152)\n",
        "      plt.plot(xs, np.log10(np.array(losses)))\n",
        "\n",
        "      for j, test in enumerate(validation_pts):\n",
        "        full_pred = network_pred(get_params(opt_state), test)\n",
        "        outputs = gt_val[j]\n",
        "        val_iou = np.logical_and(full_pred > .5, outputs > .5).sum() / np.logical_or(full_pred > .5, outputs > .5).sum()\n",
        "        tests[j].append(val_iou)\n",
        "\n",
        "      # IoU curves are plotted as log10(1 - IoU); tests_all holds earlier runs.\n",
        "      plt.subplot(153)\n",
        "      for t in tests:\n",
        "        plt.plot(np.log10(1-np.array(t)))\n",
        "      plt.subplot(154)\n",
        "      for t in tests[:1]:\n",
        "        plt.plot(np.log10(1-np.array(t)))\n",
        "      for k in tests_all:\n",
        "        plt.plot(np.log10(1-tests_all[k][0]), label=k + ' easy')\n",
        "      plt.legend()\n",
        "      plt.subplot(155)\n",
        "      for t in tests[1:]:\n",
        "        plt.plot(np.log10(1-np.array(t)))\n",
        "      for k in tests_all:\n",
        "        plt.plot(np.log10(1-tests_all[k][1]), label=k + ' hard')\n",
        "      plt.legend()\n",
        "      plt.show()\n",
        "      print(name, i, tests[0][-1], tests[1][-1])\n",
        "\n",
        "  # Final IoU on the held-out test sets.\n",
        "  scores = []\n",
        "  for test in testing_pts:\n",
        "    full_pred = network_pred(get_params(opt_state), test)\n",
        "    outputs = gt_fn(test, mesh)\n",
        "    val_iou = np.logical_and(full_pred > .5, outputs > .5).sum() / np.logical_or(full_pred > .5, outputs > .5).sum()\n",
        "    scores.append(val_iou)\n",
        "\n",
        "  meta_run = [\n",
        "              (get_params(opt_state), ab),\n",
        "              np.array(tests),\n",
        "              scores,\n",
        "              renderings,\n",
        "  ]\n",
        "\n",
        "  return meta_run"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zs3JKYzSGa8X",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Put your mesh files here\n",
        "# Maps each mesh name to the file path that load_mesh passes to trimesh.load.\n",
        "mesh_files = {\n",
        "    'dragon'    : 'dragon_obj.obj',\n",
        "    'armadillo' : 'Armadillo.ply',\n",
        "    'buddha'    : 'buddha_obj.obj',\n",
        "    'lucy'      : 'Alucy.obj',\n",
        "}\n",
        "\n",
        "# Directory for the cached test points and the experiment outputs; created if missing.\n",
        "logdir = os.path.join(basedir, 'occupancy_logs')\n",
        "os.makedirs(logdir, exist_ok=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yxJ_Q8ypFNRs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "N_iters = 10000\n",
        "tests_all = {}  # expname -> validation IoU histories; also read by run_training's plots\n",
        "out_all = {}    # expname -> full run_training result\n",
        "scores = {}     # expname -> final test IoUs\n",
        "\n",
        "mesh_names = ['dragon', 'buddha', 'armadillo', 'lucy']\n",
        "\n",
        "# Each task is [embedding_method, embedding_size, embedding_scale].\n",
        "embed_tasks = [\n",
        "               ['gauss', 256, 12.],\n",
        "               ['posenc', 256, 6.],\n",
        "               ['basic', None, None],\n",
        "               ['none', None, None],\n",
        "]\n",
        "\n",
        "expdir = os.path.join(logdir, 'full_runs')\n",
        "os.makedirs(expdir, exist_ok = True)\n",
        "print(expdir)\n",
        "\n",
        "for mesh_name in mesh_names:\n",
        "\n",
        "  mesh, corners, test_pts = load_mesh(mesh_name)\n",
        "\n",
        "  # Fill the corners slot of the shared render-argument lists for this mesh.\n",
        "  render_args_lr[1] = corners\n",
        "  render_args_hr[1] = corners\n",
        "\n",
        "  # Reference image: normals of the ground-truth mesh.\n",
        "  mesh_normal_map = render_mesh_normals(mesh, render_args_hr[0])\n",
        "  plt.imshow(mesh_normal_map * .5 + .5)\n",
        "  plt.show()\n",
        "\n",
        "  for embed_params in embed_tasks:\n",
        "    embedding_method, embedding_size, embedding_param = embed_params\n",
        "\n",
        "    expname = f'{mesh_name}_{embedding_method}_{embedding_param}'\n",
        "    print(expname)\n",
        "\n",
        "    out = run_training(embed_params, mesh, corners, test_pts, render_args_lr, expname)\n",
        "    tests_all[expname] = out[1]\n",
        "    out_all[expname] = out\n",
        "\n",
        "    # High-resolution render, hbatch image rows at a time.\n",
        "    rays = render_args_hr[0]\n",
        "    rets = []\n",
        "    hbatch = 16\n",
        "    # Take the row count from rays itself rather than the global H, which holds\n",
        "    # whatever resolution was configured last.\n",
        "    for i in tqdm(range(0, rays.shape[1], hbatch)):\n",
        "      rets.append(render_rays(*out[0], rays[:,i:i+hbatch], *render_args_hr[1:]))\n",
        "    depth_map, acc_map = [np.concatenate([r[i] for r in rets], 0) for i in range(2)]\n",
        "\n",
        "    normal_map = make_normals(rays, depth_map)\n",
        "    normal_map = (255 * (.5 * normal_map + .5)).astype(np.uint8)\n",
        "    imageio.imsave(os.path.join(expdir, expname + '.png'), normal_map)\n",
        "\n",
        "    # NOTE(review): out[0] is the nested (params, ab) structure, not a plain array;\n",
        "    # confirm np.save accepts it with the installed NumPy version.\n",
        "    np.save(os.path.join(expdir, expname + '_netparams.npy'), out[0])\n",
        "\n",
        "    # Both score files are rewritten in full after every run.\n",
        "    scores[expname] = out[2]\n",
        "    with open(os.path.join(expdir, 'scores.txt'), 'w') as f:\n",
        "      f.write(str(scores))\n",
        "    with open(os.path.join(expdir, 'scores_json.txt'), 'w') as f:\n",
        "      json.dump({k : onp.array(scores[k]).tolist() for k in scores}, f, indent=4)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NbBGibwMHu8r",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}