{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"nested_n_spheres_node.ipynb","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"6S1576wMxDOx"},"source":["# Setup"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"NvJw9wWHxB9Q","executionInfo":{"status":"ok","timestamp":1620879024055,"user_tz":420,"elapsed":3863,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}},"outputId":"a8fa46be-0e91-4a08-b9d4-6c204fb6e39e"},"source":[
"import time\n",
"import os\n",
"import argparse\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"class ArgumentParser:\n",
"    \"\"\"Notebook stand-in for argparse: add_argument stores `default` as an attribute.\n",
"\n",
"    The attribute name is the flag without its leading '--'. `type` is accepted\n",
"    only for signature compatibility; no command line is parsed, so it is unused.\n",
"    \"\"\"\n",
"    def add_argument(self, flag, type, default):\n",
"        setattr(self, flag[2:], default)\n",
"\n",
"    def parse_args(self):\n",
"        return self\n",
"\n",
"parser = ArgumentParser()\n",
"parser.add_argument('--tol', type=float, default=1e-7)\n",
"parser.add_argument('--adjoint', type=eval, default=False)\n",
"parser.add_argument('--visualise', type=eval, default=True)\n",
"parser.add_argument('--niters', type=int, default=300)\n",
"parser.add_argument('--lr', type=float, default=0.01)\n",
"parser.add_argument('--gpu', type=int, default=0)\n",
"parser.add_argument('--extra_dim', type=int, default=1)\n",
"parser.add_argument('--data_dimension', type=int, default=2)\n",
"parser.add_argument('--npoints', type=int, default=50)\n",
"parser.add_argument('--ntest', type=int, default=10)\n",
"args = parser.parse_args()\n",
"\n",
"# %pip installs into the running kernel's environment; pinned to the version used for the saved results.\n",
"%pip install -q torchdiffeq==0.2.1\n",
"# Honour --adjoint instead of always using the adjoint solver.\n",
"if args.adjoint:\n",
"    from torchdiffeq import odeint_adjoint as odeint\n",
"else:\n",
"    from torchdiffeq import odeint\n"],"execution_count":1,"outputs":[{"output_type":"stream","text":["Requirement already satisfied: torchdiffeq in /usr/local/lib/python3.7/dist-packages (0.2.1)\n","Requirement already satisfied: torch>=1.3.0 in /usr/local/lib/python3.7/dist-packages (from torchdiffeq) (1.8.1+cu101)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torch>=1.3.0->torchdiffeq) (1.19.5)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.3.0->torchdiffeq) (3.7.4.3)\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"9nLTYRksx4QG","executionInfo":{"status":"ok","timestamp":1620879024410,"user_tz":420,"elapsed":4206,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}},"outputId":"0f2b221d-0a6a-45d6-fffb-c5e30dcf26e6"},"source":["from google.colab import drive\n","drive.mount('/content/drive')\n","! 
cp -r drive/MyDrive/nested-n-spheres/data ./data."],"execution_count":2,"outputs":[{"output_type":"stream","text":["Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"c_MHUQyK2EvJ","executionInfo":{"status":"ok","timestamp":1620879024411,"user_tz":420,"elapsed":4202,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}}},"source":[
"import csv\n",
"class Recorder:\n",
"    \"\"\"Collects per-step metrics and snapshots their means.\n",
"\n",
"    `rec[key] = value` appends a value under `key` (anything exposing\n",
"    detach/cpu/numpy, such as a torch tensor, is converted to numpy first);\n",
"    `capture()` averages each key into one row of `store` and starts a new row.\n",
"    \"\"\"\n",
"    def __init__(self):\n",
"        self.store = []\n",
"        self.current = dict()\n",
"\n",
"    def __setitem__(self, key, value):\n",
"        for method in ['detach', 'cpu', 'numpy']:\n",
"            if hasattr(value, method):\n",
"                value = getattr(value, method)()\n",
"        self.current.setdefault(key, []).append(value)\n",
"\n",
"    def capture(self, verbose=False):\n",
"        \"\"\"Average the pending values per key, append them to `store` and return that row.\"\"\"\n",
"        for i in self.current:\n",
"            self.current[i] = np.mean(self.current[i])\n",
"        self.store.append(self.current.copy())\n",
"        self.current = dict()\n",
"        if verbose:\n",
"            for i in self.store[-1]:\n",
"                print('{}: {}'.format(i, self.store[-1][i]))\n",
"        return self.store[-1]\n",
"\n",
"    def tolist(self):\n",
"        \"\"\"Return (sorted labels, rows); keys missing from a row become NaN.\"\"\"\n",
"        labels = sorted(set().union(*self.store))\n",
"        outlist = [[obs.get(i, np.nan) for i in labels] for obs in self.store]\n",
"        return labels, outlist\n",
"\n",
"    def writecsv(self, writer):\n",
"        \"\"\"Write a header row of labels followed by one row per capture.\n",
"\n",
"        `writer` is a file path, or an object with csv.writer's\n",
"        writerow/writerows methods.\n",
"        \"\"\"\n",
"        labels, outlist = self.tolist()\n",
"        if isinstance(writer, str):\n",
"            # `with` closes the file even if writing fails; newline='' is what\n",
"            # the csv module documentation requires for files it writes.\n",
"            with open(writer, 'w', newline='') as outfile:\n",
"                csvwriter = csv.writer(outfile)\n",
"                csvwriter.writerow(labels)\n",
"                csvwriter.writerows(outlist)\n",
"        else:\n",
"            csvwriter = writer\n",
"            csvwriter.writerow(labels)\n",
"            csvwriter.writerows(outlist)"],"execution_count":3,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bfV63AOmxGNX"},"source":["# Model"]},{"cell_type":"code","metadata":{"id":"GhQPSMEqJ0ps","executionInfo":{"status":"ok","timestamp":1620879024411,"user_tz":420,"elapsed":4198,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}}},"source":["totallist = []"],"execution_count":4,"outputs":[]},{"cell_type":"code","metadata":{"id":"M6zb7sWmve6K","executionInfo":{"status":"ok","timestamp":1620879024412,"user_tz":420,"elapsed":4194,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}}},"source":[
"modelname = 'NODE'\n",
"seed = 100\n",
"# Seed the RNGs once so that Restart & Run All reproduces the same runs.\n",
"np.random.seed(seed)\n",
"torch.manual_seed(seed)\n",
"res = 0.0\n",
"if modelname == 'GHBNODE':\n",
"    res = 2.0\n",
"\n",
"class ODEfunc(nn.Module):\n",
"    \"\"\"Right-hand side f(t, x) of the ODE: a 3-layer ELU MLP mapping dim -> dim.\n",
"\n",
"    `t` is ignored (autonomous dynamics); `nfe` counts function evaluations.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, dim, nhidden):\n",
"        super(ODEfunc, self).__init__()\n",
"        self.elu = nn.ELU(inplace=True)\n",
"        self.fc1 = nn.Linear(dim, nhidden)\n",
"        self.fc2 = nn.Linear(nhidden, nhidden)\n",
"        self.fc3 = nn.Linear(nhidden, dim)\n",
"        self.nfe = 0\n",
"\n",
"    def forward(self, t, x):\n",
"        self.nfe += 1\n",
"        out = self.fc1(x)\n",
"        out = self.elu(out)\n",
"        out = self.fc2(out)\n",
"        out = self.elu(out)\n",
"        out = self.fc3(out)\n",
"        return out\n",
"\n",
"class ODEBlock(nn.Module):\n",
"    \"\"\"Integrates `odefunc` over [t0_, tN_] and returns the state at tN_.\n",
"\n",
"    `nfe` proxies the evaluation counter of the wrapped odefunc.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, odefunc, t0_, tN_):\n",
"        super(ODEBlock, self).__init__()\n",
"        self.odefunc = odefunc\n",
"        # A buffer rather than a plain attribute, so model.to(device) moves the\n",
"        # integration times to the same device as the ODE state.\n",
"        self.register_buffer('integration_times', torch.tensor([t0_, tN_]).float())\n",
"\n",
"    def forward(self, x):\n",
"        out = odeint(self.odefunc, x, self.integration_times, rtol=args.tol, atol=args.tol)\n",
"        # odeint returns the state at every requested time; keep the final one.\n",
"        out = out[1]\n",
"        return out\n",
"\n",
"    @property\n",
"    def nfe(self):\n",
"        return self.odefunc.nfe\n",
"\n",
"    @nfe.setter\n",
"    def nfe(self, 
value):\n",
"        self.odefunc.nfe = value\n",
"\n",
"\n",
"class Decoder(nn.Module):\n",
"    \"\"\"Linear read-out to out_dim, clamped to [-1, 1] by a Hardtanh.\"\"\"\n",
"\n",
"    def __init__(self, in_dim, out_dim):\n",
"        super(Decoder, self).__init__()\n",
"        self.tanh = nn.Hardtanh(min_val=-1.0, max_val=1.0, inplace=False)\n",
"        self.fc = nn.Linear(in_dim, out_dim)\n",
"\n",
"    def forward(self, z):\n",
"        out = self.fc(z)\n",
"        out = self.tanh(out)\n",
"        return out\n",
"\n",
"\n",
"def count_parameters(model):\n",
"    \"\"\"Return the number of trainable (requires_grad) parameter elements in `model`.\"\"\"\n",
"    return sum(p.numel() for p in model.parameters() if p.requires_grad)"],"execution_count":5,"outputs":[]},{"cell_type":"code","metadata":{"id":"9B_1LEW_v4hV","executionInfo":{"status":"ok","timestamp":1620879024805,"user_tz":420,"elapsed":4583,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}}},"source":[
"def _load_points(path, device):\n",
"    \"\"\"Load a .npy array as a float32 tensor on `device`.\"\"\"\n",
"    return torch.tensor(np.load(path)).float().to(device)\n",
"\n",
"\n",
"def _augment(z, extra_dim, device):\n",
"    \"\"\"Append `extra_dim` zero columns to the rows of `z` (ANODE augmentation).\n",
"\n",
"    The zeros are allocated on `device`, so the concatenation also works on GPU.\n",
"    \"\"\"\n",
"    zeros = torch.zeros(z.shape[0], extra_dim, device=device)\n",
"    return torch.cat((z.to(device), zeros), dim=1)\n",
"\n",
"\n",
"def train():\n",
"    \"\"\"Train one (A)NODE on the nested n-spheres data, then test and visualise it.\n",
"\n",
"    Saves the per-iteration arrays and the model under node./ or anode(k)./,\n",
"    prints the test MSE and, if args.visualise, saves ODE trajectories of the\n",
"    visualisation points under figure_data./.\n",
"\n",
"    Returns (rec, outputarr, pred_z): the Recorder, the ODE-block output on the\n",
"    visualisation points after every iteration, and the last prediction made\n",
"    (numpy trajectories when visualised, otherwise the test-set output).\n",
"    \"\"\"\n",
"    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')\n",
"    if args.extra_dim == 0:\n",
"        filename = 'node./'\n",
"    else:\n",
"        filename = 'anode('+str(args.extra_dim)+')./'\n",
"    os.makedirs('./'+filename, exist_ok=True)\n",
"\n",
"    rec = Recorder()\n",
"    dim = args.data_dimension + args.extra_dim\n",
"    outputarr = []\n",
"\n",
"    # Load data. Every input gets extra_dim zero columns appended, because the\n",
"    # ODE state has dim = data_dimension + extra_dim features.\n",
"    folder_name = 'data./'\n",
"    # vis_data/ (not vis_data./) matches the 2d path that loaded successfully.\n",
"    vis_name = folder_name+'vis_data/'+str(args.data_dimension)+'d_vis_data.npy'\n",
"    name_in = str(args.data_dimension)+'din_'+str(args.npoints)+'_train.npy'\n",
"    name_out = str(args.data_dimension)+'dout_'+str(args.npoints)+'_train.npy'\n",
"    z0 = _augment(_load_points(folder_name+name_in, device), args.extra_dim, device)\n",
"    zN = _load_points(folder_name+name_out, device)\n",
"    # The visualisation points need the same augmentation: un-augmented, they\n",
"    # raised a RuntimeError in model[0](viz_z0), whose fc1 expects dim inputs.\n",
"    viz_z0 = _augment(_load_points(vis_name, device), args.extra_dim, device)\n",
"\n",
"    # model\n",
"    t0, tN = 0, 1\n",
"    nhidden = 20\n",
"    feature_layers = [ODEBlock(ODEfunc(dim, nhidden), t0, tN), Decoder(dim, 1)]\n",
"    model = nn.Sequential(*feature_layers).to(device)\n",
"    print(model)\n",
"    print(f'Model Parameters = {count_parameters(model)}')\n",
"    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n",
"    loss_func = nn.MSELoss()\n",
"\n",
"    itr_arr = np.empty(args.niters)\n",
"    loss_arr = np.empty(args.niters)\n",
"    nfe_arr = np.empty(args.niters)\n",
"    time_arr = np.empty(args.niters)\n",
"\n",
"    # training\n",
"    for itr in range(1, args.niters + 1):\n",
"        rec['epoch'] = itr\n",
"        feature_layers[0].nfe = 0\n",
"        iter_start_time = time.time()\n",
"        optimizer.zero_grad()\n",
"\n",
"        # forward in time and solve ode\n",
"        pred_z = model(z0)\n",
"\n",
"        # compute loss\n",
"        loss = loss_func(pred_z, zN)\n",
"        rec['forward_nfe'] = feature_layers[0].nfe\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        iter_end_time = time.time()\n",
"\n",
"        # per-iteration statistics; nfe is read after backward(), unlike forward_nfe\n",
"        itr_arr[itr-1] = itr\n",
"        loss_arr[itr-1] = loss.item()\n",
"        nfe_arr[itr-1] = feature_layers[0].nfe\n",
"        time_arr[itr-1] = iter_end_time-iter_start_time\n",
"        rec['epoch_nfe'] = feature_layers[0].nfe\n",
"        rec['loss'] = loss\n",
"        rec['log_loss'] = torch.log10(loss)\n",
"        rec.capture(verbose=False)\n",
"\n",
"        if itr % 100 == 0:\n",
"            print('Iter: {}, running MSE: {:.4f}'.format(itr, loss.item()))\n",
"\n",
"        # Snapshot the ODE block on the visualisation points. This pass is never\n",
"        # back-propagated, so no autograd graph is built for it.\n",
"        feature_layers[0].nfe = 0\n",
"        with torch.no_grad():\n",
"            pred_out = model[0](viz_z0)\n",
"        outputarr.append(pred_out.detach())\n",
"\n",
"    np.save(filename+'itr_arr.npy', itr_arr)\n",
"    np.save(filename+'nfe_arr.npy', nfe_arr)\n",
"    np.save(filename+'loss_arr.npy', loss_arr)\n",
"    np.save(filename+'time_arr.npy', time_arr)\n",
"    torch.save(model, filename+'model.pth')\n",
"\n",
"    # test data\n",
"    name_in = str(args.data_dimension)+'din_'+str(args.ntest)+'_test.npy'\n",
"    name_out = str(args.data_dimension)+'dout_'+str(args.ntest)+'_test.npy'\n",
"    z0 = _augment(_load_points(folder_name+name_in, device), args.extra_dim, device)\n",
"    zN = _load_points(folder_name+name_out, device)\n",
"\n",
"    # run test data through network; .cpu() lets .numpy() work on GPU too\n",
"    with torch.no_grad():\n",
"        pred_z = model(z0)\n",
"        loss = loss_func(pred_z, zN).cpu().numpy()\n",
"    print('Test MSE = ' +str(loss))\n",
"\n",
"    if args.visualise:\n",
"        os.makedirs('./figure_data./', exist_ok=True)\n",
"        d = args.data_dimension\n",
"        if d in (1, 2, 3):\n",
"            samp_ts = torch.linspace(t0, tN, 30).to(device)\n",
"            with torch.no_grad():\n",
"                pred_z = odeint(feature_layers[0].odefunc, viz_z0, samp_ts)\n",
"            pred_z = pred_z.cpu().numpy()\n",
"            if args.extra_dim == 0:\n",
"                name = 'figure_data./node_film_'+str(d)+'d.npy'\n",
"            else:\n",
"                # now defined for d == 3 as well (it raised NameError there before)\n",
"                name = 'figure_data./anode_film_('+str(d)+'+'+str(args.extra_dim)+')d.npy'\n",
"            np.save(name, pred_z)\n",
"    return rec, outputarr, pred_z"],"execution_count":6,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":561},"id":"B2dDpq9O1Uni","executionInfo":{"status":"error","timestamp":1620879024808,"user_tz":420,"elapsed":4581,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}},"outputId":"b2498de1-590d-43c1-da9b-1b166efce646"},"source":["rec_list = [train()[0] for t in range(100)]"],"execution_count":7,"outputs":[{"output_type":"stream","text":["Sequential(\n","  (0): ODEBlock(\n","    (odefunc): ODEfunc(\n","      (elu): ELU(alpha=1.0, 
inplace=True)\n","      (fc1): Linear(in_features=3, out_features=20, bias=True)\n","      (fc2): Linear(in_features=20, out_features=20, bias=True)\n","      (fc3): Linear(in_features=20, out_features=3, bias=True)\n","    )\n","  )\n","  (1): Decoder(\n","    (tanh): Hardtanh(min_val=-1.0, max_val=1.0)\n","    (fc): Linear(in_features=3, out_features=1, bias=True)\n","  )\n",")\n","Model Parameters = 567\n"],"name":"stdout"},{"output_type":"error","ename":"RuntimeError","evalue":"ignored","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)","\u001b[0;32m<ipython-input-7-035afbf79aa8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mrec_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;32m<ipython-input-7-035afbf79aa8>\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mrec_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;32m<ipython-input-6-522528899895>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m     
76\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     77\u001b[0m         \u001b[0mfeature_layers\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnfe\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 78\u001b[0;31m         \u001b[0mpred_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mviz_z0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     79\u001b[0m         \u001b[0moutputarr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpred_out\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     80\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    887\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    888\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    890\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    891\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m<ipython-input-5-9bad8cb92dcf>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     33\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m         \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0modeint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0modefunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mintegration_times\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrtol\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtol\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0matol\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     35\u001b[0m         \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     36\u001b[0m         \u001b[0;32mreturn\u001b[0m 
\u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/adjoint.py\u001b[0m in \u001b[0;36modeint_adjoint\u001b[0;34m(func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method, adjoint_options, adjoint_params)\u001b[0m\n\u001b[1;32m    197\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    198\u001b[0m     ans = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol,\n\u001b[0;32m--> 199\u001b[0;31m                                     adjoint_method, adjoint_options, t.requires_grad, *adjoint_params)\n\u001b[0m\u001b[1;32m    200\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    201\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mevent_fn\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/adjoint.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method, adjoint_options, t_requires_grad, *adjoint_params)\u001b[0m\n\u001b[1;32m     23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     24\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m             \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0modeint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrtol\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrtol\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0matol\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0matol\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moptions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevent_fn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mevent_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mevent_fn\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/odeint.py\u001b[0m in \u001b[0;36modeint\u001b[0;34m(func, y0, t, rtol, atol, method, options, event_fn)\u001b[0m\n\u001b[1;32m     75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     76\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mevent_fn\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 77\u001b[0;31m         \u001b[0msolution\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msolver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mintegrate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     78\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     79\u001b[0m         \u001b[0mevent_t\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msolution\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msolver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mintegrate_until_event\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mevent_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/solvers.py\u001b[0m in \u001b[0;36mintegrate\u001b[0;34m(self, t)\u001b[0m\n\u001b[1;32m     26\u001b[0m         \u001b[0msolution\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m         \u001b[0mt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_before_integrate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     29\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     30\u001b[0m             \u001b[0msolution\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_advance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/rk_common.py\u001b[0m in \u001b[0;36m_before_integrate\u001b[0;34m(self, 
t)\u001b[0m\n\u001b[1;32m    159\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_before_integrate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    160\u001b[0m         \u001b[0mt0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 161\u001b[0;31m         \u001b[0mf0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    162\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfirst_step\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    163\u001b[0m             first_step = _select_initial_step(self.func, t[0], self.y0, self.order - 1, self.rtol, self.atol,\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    887\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    888\u001b[0m         
\u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    890\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    891\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/misc.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, t, y, perturb)\u001b[0m\n\u001b[1;32m    189\u001b[0m             \u001b[0;31m# Do nothing.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    190\u001b[0m             \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 191\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    192\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    193\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    887\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    888\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    890\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    891\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/misc.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, t, y, perturb)\u001b[0m\n\u001b[1;32m    189\u001b[0m             \u001b[0;31m# Do nothing.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    190\u001b[0m             \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 191\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    192\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    193\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    887\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    888\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    890\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    891\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m<ipython-input-5-9bad8cb92dcf>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, t, x)\u001b[0m\n\u001b[1;32m     17\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     18\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnfe\u001b[0m \u001b[0;34m+=\u001b[0m 
\u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m         \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     20\u001b[0m         \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0melu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     21\u001b[0m         \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    887\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    888\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    890\u001b[0m         for hook in 
itertools.chain(\n\u001b[1;32m    891\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m     92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     93\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     96\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlinear\u001b[0;34m(input, weight, bias)\u001b[0m\n\u001b[1;32m   1751\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mhas_torch_function_variadic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1752\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mhandle_torch_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1753\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1754\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1755\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (120x2 and 3x20)"]}]},{"cell_type":"code","metadata":{"id":"OM_oztTKHHvH","executionInfo":{"status":"aborted","timestamp":1620879024806,"user_tz":420,"elapsed":4573,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}}},"source":["print(len(rec_list))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Bq3qbRINHjEF","executionInfo":{"status":"aborted","timestamp":1620879024807,"user_tz":420,"elapsed":4570,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}}},"source":["import scipy.io\n","\n","def calc_mean(rec_list, col=2):\n","  m = []\n","  for rec in rec_list:\n","    
m.append(np.array(rec.tolist()[1])[:, col])\n","  m = np.array(m)\n","  # The plotting cell calls this with col=2, 1 and 4; a single '_fnfe' filename would be\n","  # overwritten each time and end up holding the last column (the loss). col=2 (plotted as\n","  # 'Forward' NFEs) keeps the original '_fnfe' name; other columns get their own files.\n","  stem = f'{modelname}_fnfe' if col == 2 else f'{modelname}_col{col}'\n","  scipy.io.savemat(f'{stem}.mat', {'data': m})\n","  np.save(f'{stem}.npy', m)\n","  return np.mean(m, axis=0)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"cTY83EN3IGlO","executionInfo":{"status":"aborted","timestamp":1620879024807,"user_tz":420,"elapsed":4565,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}}},"source":["calc_mean(rec_list).shape"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"T_6sbS5lC4Fu","executionInfo":{"status":"aborted","timestamp":1620879025131,"user_tz":420,"elapsed":4884,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}}},"source":["rec, outputarr, pred_z = train()\n","rec.writecsv('train_{}.csv'.format(modelname))\n","np.save('trajectory_{}'.format(modelname), pred_z)\n","totallist.append(rec.tolist()[1])\n","print(np.array(totallist).shape)\n","np.save('train_{}'.format(modelname), np.array(totallist))\n","# detach()/cpu() so .numpy() also works if the stacked outputs are on the GPU or track gradients\n","outputarr = torch.stack(outputarr, dim=0).detach().cpu().numpy()\n","np.save(f'drive/MyDrive/PointCloud/{modelname}_outputarr.npy', outputarr, allow_pickle=True)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"NCUgX-ljU4OV","executionInfo":{"status":"aborted","timestamp":1620879025132,"user_tz":420,"elapsed":4880,"user":{"displayName":"Xia 
Hedi","photoUrl":"","userId":"10053593801688756203"}}},"source":["import matplotlib.pyplot as plt\n","import matplotlib\n","fsize = 15\n","plt.rc('axes', labelsize=fsize) #fontsize of the x and y labels\n","plt.rc('xtick', labelsize=fsize+5) #fontsize of the x tick labels\n","plt.rc('ytick', labelsize=fsize-5) #fontsize of the y tick labels\n","\n","print(outputarr.shape, type(outputarr))\n","\n","ending_time = 300\n","# builtin int: the np.int alias was deprecated in NumPy 1.20 and removed in 1.24\n","timespots = np.linspace(0, ending_time, 6, dtype=int)\n","timespots[-1] -= 1\n","fig = plt.figure(figsize=(10, 8))\n","gs = fig.add_gridspec(4, 6)\n","axs = 
[[fig.add_subplot(gs[0, k]) for k in range(6)], [fig.add_subplot(gs[1, :])], [fig.add_subplot(gs[2, :])], [fig.add_subplot(gs[3, :])]]\n","\n","for j in range(6):\n","    for i in range(outputarr.shape[1]):\n","        ts = timespots[j]\n","        col = 'bo' if i < 40 else 'ro'\n","        ax = axs[0][j]\n","        ax.plot(outputarr[ts,i,0], outputarr[ts,i,1], col, alpha=0.4)\n","        ax.set_xticks([])\n","        ax.set_yticks([])\n","\n","f_nfe = calc_mean(rec_list)[:ending_time]\n","f_nfeplot = axs[1][0]\n","f_nfeplot.plot(np.arange(ending_time), f_nfe, linewidth=3)\n","f_nfeplot.set_ylim([10, 42])\n","f_nfeplot.set_yticks([10, 20, 30, 40])\n","f_nfeplot.set_title(f\"{modelname}\".upper(), fontsize=35)\n","f_nfeplot.set_ylabel('Forward', fontsize=25)\n","print(f'Max Forward NFEs: {np.max(f_nfe)}')\n","\n","b_nfe = calc_mean(rec_list, 1)[:ending_time] - f_nfe\n","b_nfeplot = axs[2][0]\n","b_nfeplot.plot(np.arange(ending_time), b_nfe, linewidth=3)\n","b_nfeplot.set_ylim([10, 60])\n","b_nfeplot.set_yticks([10, 20, 30, 40, 50, 60])\n","b_nfeplot.set_ylabel('Backward', fontsize=25)\n","print(f'Max Backward NFEs: {np.max(b_nfe)}')\n","\n","loss = calc_mean(rec_list, col=4)[:ending_time]\n","lossplot = axs[3][0]\n","lossplot.plot(np.arange(ending_time), loss, linewidth=3)\n","lossplot.set_ylabel('Loss', fontsize=25)\n","lossplot.set_xlabel('Epochs', fontsize=30)\n","lossplot.set_yticks([0.0, 0.5, 1.0, 1.5])\n","lossplot.set_ylim([0, 1.5])\n","print(f'Total NFEs: {np.max(f_nfe) + np.max(b_nfe)}')\n","print(f'Min Loss: {np.min(loss)}')\n","\n","plt.tight_layout()\n","plt.savefig(f'{modelname}_fnfe_pc_plot.pdf', format=\"pdf\", bbox_inches='tight')\n","plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"4P0qHp5BCNj4","executionInfo":{"status":"aborted","timestamp":1620879025132,"user_tz":420,"elapsed":4875,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}}},"source":["m = 
np.array(rec_list)\n","\n","np.save(f'drive/MyDrive/PointCloud/{modelname}_recs_2.npy', m, allow_pickle=True)\n","\n","# from google.colab import files\n","# # files.download(f'{modelname}_fnfe.mat') \n","# files.download(f'{modelname}_fnfe_pc_plot.pdf')\n","# files.download(f'{modelname}_recs.npy')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"TSxuc1wLOtmS"},"source":["# Plotting"]},{"cell_type":"code","metadata":{"id":"gyHDfIwujLDe","executionInfo":{"status":"aborted","timestamp":1620879025133,"user_tz":420,"elapsed":4872,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}}},"source":["#! sudo apt install texlive texlive-latex-extra texlive-fonts-recommended dvipng\n","#! pip install latex\n","\n","\n","#sonode\n","import seaborn as sns\n","import matplotlib.pyplot as plt\n","import matplotlib\n","\n","sns.set_style('dark')\n","\n","\n","ax3 = plt.plot()\n","\n","film_data = np.load('figure_data./node_film_2d.npy')\n","\n","#which frame to stop by\n","a = len(film_data)\n","\n","frames = []\n","\n","for i in range(a):\n","    frames += [film_data[i][:120]]\n","    \n","intermediate = np.asarray(frames)\n","\n","inner = []\n","outer = []\n","\n","for i in range(a):\n","    inner += [intermediate[i][:40]]\n","    outer += [intermediate[i][40:]]\n","    \n","inner = np.asarray(inner)\n","outer = np.asarray(outer)\n","\n","\n","#make film image:\n","\n","inner_pic = np.empty((40, a, 2))\n","outer_pic = np.empty((80, a, 2))\n","\n","for i in range(40):\n","    for j in range(a):\n","        inner_pic[i][j] = inner[j][i]\n","           \n","for i in range(40):\n","    inner_pic_plot = np.transpose(inner_pic[i])\n","    plt.plot(inner_pic_plot[0], inner_pic_plot[1], color='#004488', linewidth=0.3)    \n","inner_start_frame = np.transpose(inner[0])\n","inner_end_frame = np.transpose(inner[len(inner)-1])\n","plt.scatter(inner_start_frame[0], inner_start_frame[1], color='#004488', 
s=25)\n","plt.scatter(inner_end_frame[0], inner_end_frame[1], color='#004488', s=25)\n","\n","\n","for i in range(80):\n","    for j in range(a):\n","        outer_pic[i][j] = outer[j][i]\n","        \n","for i in range(80):\n","    outer_pic_plot = np.transpose(outer_pic[i])\n","    plt.plot(outer_pic_plot[0], outer_pic_plot[1], color='#BB5566', linewidth=0.3)    \n","outer_start_frame = np.transpose(outer[0])\n","outer_end_frame = np.transpose(outer[len(inner)-1])\n","plt.scatter(outer_start_frame[0], outer_start_frame[1], color='#BB5566', s=25)\n","plt.scatter(outer_end_frame[0], outer_end_frame[1], color='#BB5566', s=25)\n","#plt.xlabel('x', fontsize=14)\n","#plt.ylabel('y',fontsize=14)\n","#plt.xticks([])\n","#plt.yticks([])\n","matplotlib.rc('font', size=20)\n","#rc('text', usetex=True)\n","plt.title(modelname, fontsize=40)\n","\n","\n","plt.tight_layout()\n","#sns.set_style('white')\n","plt.savefig('nested_n_spheres_{}.pdf'.format(modelname), bbox_inches='tight')\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"31rQnb4DKU9T","executionInfo":{"status":"aborted","timestamp":1620879025133,"user_tz":420,"elapsed":4867,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}}},"source":["#node\n","sns.set_style('dark')\n","# rc('font', family='serif')\n","# rc('text', usetex=True)\n","ax1 = plt.subplot(1, 3, 1)\n","ninner = 40\n","nouter = 80\n","\n","film_data = np.load('figure_data./node_film_2d.npy')\n","\n","a = len(film_data)\n","frames = []\n","\n","for i in range(a):\n","    frames += [film_data[i][:ninner+nouter]]\n","    \n","intermediate = np.asarray(frames)\n","\n","inner = []\n","outer = []\n","\n","for i in range(a):\n","    inner += [intermediate[i][:ninner]]\n","    outer += [intermediate[i][ninner:]]\n","    \n","inner = np.asarray(inner)\n","outer = np.asarray(outer)\n","\n","\n","#make film image:\n","\n","inner_pic = np.empty((ninner, a, 2))\n","outer_pic = np.empty((nouter, a, 
2))\n","\n","for i in range(ninner):\n","    for j in range(a):\n","        inner_pic[i][j] = inner[j][i]\n","           \n","for i in range(ninner):\n","    inner_pic_plot = np.transpose(inner_pic[i])\n","    plt.plot(inner_pic_plot[0], inner_pic_plot[1], color='#004488', linewidth=0.3)    \n","inner_start_frame = np.transpose(inner[0])\n","inner_end_frame = np.transpose(inner[len(inner)-1])\n","plt.scatter(inner_start_frame[0], inner_start_frame[1], color='#004488', s=15)\n","plt.scatter(inner_end_frame[0], inner_end_frame[1], color='#004488', s=15)\n","\n","\n","for i in range(nouter):\n","    for j in range(a):\n","        outer_pic[i][j] = outer[j][i]\n","        \n","for i in range(nouter):\n","    outer_pic_plot = np.transpose(outer_pic[i])\n","    plt.plot(outer_pic_plot[0], outer_pic_plot[1], color='#BB5566', linewidth=0.3)    \n","outer_start_frame = np.transpose(outer[0])\n","outer_end_frame = np.transpose(outer[len(inner)-1])\n","plt.scatter(outer_start_frame[0], outer_start_frame[1], color='#BB5566', s=15)\n","plt.scatter(outer_end_frame[0], outer_end_frame[1], color='#BB5566', s=15)\n","#plt.xlabel('x', fontsize=14)\n","#plt.ylabel('y',fontsize=14)\n","# rc('font', family='serif')\n","# rc('text', usetex=True)\n","plt.xticks([])\n","plt.yticks([])\n","plt.title('NODE', fontsize=24)\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","#anode\n","#Selects what frame to stop at, for representation\n","a = 6\n","\n","\n","ninner = 40\n","nouter = 80\n","\n","film_data = np.load('figure_data./anode_film_(2+1)d.npy')\n","\n","frames = []\n","\n","for i in range(len(film_data)):\n","    frames += [film_data[i][:ninner+nouter]]\n","    \n","intermediate = np.asarray(frames)\n","\n","inner = []\n","outer = []\n","\n","for i in range(a):\n","    inner += [intermediate[i][:ninner]]\n","    outer += [intermediate[i][ninner:]]\n","    \n","inner = np.asarray(inner)\n","outer = np.asarray(outer)\n","\n","\n","\n","#make film 
image:\n","sns.set_style('white')\n","# Attach the 3-D panel to the current figure (the one ax1 = plt.subplot(1, 3, 1) drew on);\n","# the global `fig` was last assigned in the earlier NFE-plot cell, so using it would put\n","# this panel on that figure instead.\n","ax2 = plt.gcf().add_subplot(132, projection='3d')\n","\n","inner_pic = np.empty((ninner, a, 3))\n","outer_pic = np.empty((nouter, a, 3))\n","\n","for i in range(ninner):\n","    for j in range(a):\n","        inner_pic[i][j] = inner[j][i]\n","           \n","for i in range(ninner):\n","    inner_pic_plot = np.transpose(inner_pic[i])\n","    ax2.plot(inner_pic_plot[0], inner_pic_plot[1], inner_pic_plot[2], color='#004488', linewidth=0.3)    \n","inner_start_frame = np.transpose(inner[0])\n","inner_end_frame = np.transpose(inner[len(inner)-1])\n","ax2.scatter(inner_start_frame[0], inner_start_frame[1], inner_start_frame[2], color='#004488', s=15)\n","ax2.scatter(inner_end_frame[0], inner_end_frame[1], inner_end_frame[2], color='#004488', s=15)\n","\n","\n","for i in range(nouter):\n","    for j in range(a):\n","        outer_pic[i][j] = outer[j][i]\n","        \n","for i in range(nouter):\n","    outer_pic_plot = np.transpose(outer_pic[i])\n","    ax2.plot(outer_pic_plot[0], outer_pic_plot[1], outer_pic_plot[2], color='#BB5566', linewidth=0.3)    \n","outer_start_frame = np.transpose(outer[0])\n","outer_end_frame = np.transpose(outer[len(outer)-1])\n","ax2.scatter(outer_start_frame[0], outer_start_frame[1], outer_start_frame[2], color='#BB5566', s=15)\n","ax2.scatter(outer_end_frame[0], outer_end_frame[1], outer_end_frame[2], color='#BB5566', s=15)\n","ax2.grid(False)\n","#ax2.xaxis.pane.fill = False\n","#ax2.yaxis.pane.fill = False\n","#ax2.zaxis.pane.fill = False\n","\n","#ax2.set_xlabel('x', fontsize=14)\n","#ax2.set_ylabel('y', fontsize=14)\n","#ax2.set_zlabel('z', fontsize=14)\n","ax2.set_xticks([])\n","ax2.set_yticks([])\n","ax2.set_zticks([])\n","# `rc` is not imported in these plotting cells 
(NameError), and usetex=True needs a LaTeX\n","# install (the texlive install line is commented out); kept disabled as in the NODE panel.\n","# rc('font', family='serif')\n","# rc('text', usetex=True)\n","ax2.set_title('ANODE(1)', fontsize=24, 
pad=27)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"KbH20vKp7LTI","executionInfo":{"status":"aborted","timestamp":1620879025134,"user_tz":420,"elapsed":4863,"user":{"displayName":"Xia Hedi","photoUrl":"","userId":"10053593801688756203"}}},"source":[""],"execution_count":null,"outputs":[]}]}