{"title": "Learning Hierarchical Information Flow with Recurrent Neural Modules", "book": "Advances in Neural Information Processing Systems", "page_first": 6724, "page_last": 6733, "abstract": "We propose ThalNet, a deep learning model inspired by neocortical communication via the thalamus. Our model consists of recurrent neural modules that send features through a routing center, endowing the modules with the flexibility to share features over multiple time steps. We show that our model learns to route information hierarchically, processing input data by a chain of modules. We observe common architectures, such as feed forward neural networks and skip connections, emerging as special cases of our architecture, while novel connectivity patterns are learned for the text8 compression task. Our model outperforms standard recurrent neural networks on several sequential benchmarks.", "full_text": "Learning Hierarchical Information Flow\n\nwith Recurrent Neural Modules\n\nDanijar Hafner \u2217\nGoogle Brain\n\nmail@danijar.com\n\nJames Davidson\n\nGoogle Brain\n\njcdavidson@google.com\n\nAlex Irpan\nGoogle Brain\n\nalexirpan@google.com\n\nNicolas Heess\n\nGoogle DeepMind\nheess@google.com\n\nAbstract\n\nWe propose ThalNet, a deep learning model inspired by neocortical communication\nvia the thalamus. Our model consists of recurrent neural modules that send features\nthrough a routing center, endowing the modules with the \ufb02exibility to share features\nover multiple time steps. We show that our model learns to route information\nhierarchically, processing input data by a chain of modules. We observe common\narchitectures, such as feed forward neural networks and skip connections, emerging\nas special cases of our architecture, while novel connectivity patterns are learned\nfor the text8 compression task. 
Our model outperforms standard recurrent neural networks on several sequential benchmarks.

1 Introduction

Deep learning models make use of modular building blocks such as fully connected layers, convolutional layers, and recurrent layers. Researchers often combine them in strictly layered or task-specific ways. Instead of prescribing this connectivity a priori, our method learns how to route information as part of learning to solve the task. We achieve this using recurrent modules that communicate via a routing center that is inspired by the thalamus.
Warren McCulloch and Walter Pitts proposed the first mathematical model of neural information processing in 1943 [22], laying the groundwork for modern research on artificial neural networks. Since then, researchers have continued looking to neuroscience for inspiration to identify new deep learning architectures [11, 13, 16, 31].
While some of these efforts have been directed at learning biologically plausible mechanisms in an attempt to explain brain behavior, our interest is to achieve a flexible learning model. In the neocortex, communication between areas can be broadly classified into two pathways: direct communication and communication via the thalamus [28]. In our model, we borrow this latter notion of a centralized routing system to connect specializing neural modules.
In our experiments, the presented model learns to form connection patterns that process input hierarchically, including skip connections as known from ResNet [12], Highway networks [29], and DenseNet [14], as well as feedback connections, which are known to both play an important role in the neocortex and improve deep learning [7, 20]. The learned connectivity structure is adapted to the task, allowing the model to trade off computational width and depth. 
In this paper, we study these properties with the goal of building an understanding of the interactions between recurrent neural modules.

∗Work done during an internship with Google Brain.

31st Conference on Neural Information Processing Systems (NIPS 2017), Long Beach, CA, USA.

Figure 1: Several modules share their learned features via a routing center. Dashed lines are used for dynamic reading only. We define both static and dynamic reading mechanisms in Section 2.2. (a) Module f^1 receives the task input, f^2 can be used for side computation, f^3 is trained on an auxiliary task, and f^4 produces the output for the main task. (b) Computation of 3 modules unrolled in time. One possible path of hierarchical information flow is highlighted in green. We show that our model learns hierarchical information flow, skip connections, and feedback connections in Section 4.

Section 2 defines our computational model. We point out two critical design axes, which we explore experimentally in the supplementary material. In Section 3 we compare the performance of our model on three sequential tasks, and show that it consistently outperforms multi-layer recurrent networks. In Section 4, we apply the best performing design to a language modeling task, where we observe that the model automatically learns hierarchical connectivity patterns.

2 Thalamus Gated Recurrent Modules

We find inspiration for our work in the neurological structure of the neocortex. Areas of the neocortex communicate via two principal pathways: the cortico-cortical pathway comprises direct connections between areas, and the cortico-thalamo-cortical pathway comprises connections relayed via the thalamus. Inspired by this second pathway, we develop a sequential deep learning model in which modules communicate via a routing center. 
We name the proposed model ThalNet.

2.1 Model Definition

Our system comprises a tuple of computation modules F = (f^1, ..., f^I) that route their respective features into a shared center vector Φ. An example instance of our ThalNet model is shown in Figure 1a. At every time step t, each module f^i reads from the center vector via a context input c^i_t and receives an optional task input x^i_t. The features φ^i_t = f^i(c^i_t, x^i_t) that each module produces are directed into the center Φ.2 Output modules additionally produce task output from their feature vector as a function o^i(φ^i) = y^i.
All modules send their features to the routing center, where they are merged into a single feature vector Φ_t = m(φ^1_t, ..., φ^I_t). In our experiments, we simply implement m as the concatenation of all φ^i.
At the next time step, the center vector Φ_t is then read selectively by each module using a reading mechanism to obtain the context input c^i_{t+1} = r^i(Φ_t, φ^i_t).3 This reading mechanism allows modules to read individual features, allowing for complex and selective reuse of information between modules. The initial center vector Φ_0 is the zero vector.

2In practice, we experiment with both feed forward and recurrent implementations of the modules f^i. For simplicity, we omit the hidden state used in recurrent modules in our notation.
3The reading mechanism is conditioned on both Φ_t and φ^i_t separately, as the merging does not preserve φ^i_t in the general case.

Figure 2: The ThalNet model from the perspective of a single module. In this example, the module receives input x^i and produces features to the center Φ and output y^i. Its context input c^i is determined as a linear mapping of the center features from the previous time step. 
In practice, we apply weight normalization to encourage interpretable weight matrices (analyzed in Section 4).

In summary, ThalNet is governed by the following equations:

Module features:      φ^i_t = f^i(c^i_t, x^i_t)       (1)
Module output:        y^i_t = o^i(φ^i_t)              (2)
Center features:      Φ_t = m(φ^1_t, ..., φ^I_t)      (3)
Read context input:   c^i_{t+1} = r^i(Φ_t, φ^i_t)     (4)

The choice of input and output modules depends on the task at hand. In a simple scenario (e.g., a single task), there is exactly one input module receiving task input, some number of side modules, and exactly one output module producing predictions. The output modules get trained using appropriate loss functions, with their gradients flowing backwards through the fully differentiable routing center into all modules.
Modules can operate in parallel, as reads target the center vector from the previous time step. An unrolling of the multi-step process can be seen in Figure 1b. This figure illustrates the ability to arbitrarily route information between modules across time steps. This suggests a sequential nature of our model, even though application to static input is possible by letting the model observe the input for multiple time steps.
We hypothesize that modules will use the center to route information through a chain of modules before producing the final output (see Section 4). For tasks that require producing an output at every time step, we repeat input frames to allow the model to process information through multiple modules first, before producing an output. This is because communication between modules always spans a time step.4

2.2 Reading Mechanisms

We now discuss implementations of the reading mechanism r^i(Φ, φ^i) and modules f^i(c^i, x^i), as defined in Section 2.1. We draw a distinction between static and dynamic reading mechanisms for ThalNet. For static reading, r^i(Φ) is conditioned on independent parameters. 
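As a minimal runnable sketch of Equations (1)-(4) with static linear reads: module sizes, the tanh feed forward modules, and which module receives input are illustrative assumptions, and the output functions o^i are omitted.

```python
import numpy as np

# One ThalNet time step (Eqs. 1-4) with feed forward modules and
# linear-mapping reads. Module 0 receives the task input; sizes are made up.
I, x_dim, phi_dim, c_dim = 4, 8, 16, 16
center_dim = I * phi_dim  # m is the concatenation of all module features

rng = np.random.default_rng(0)
W_mod = [rng.normal(0, 0.1, (c_dim + x_dim, phi_dim)) for _ in range(I)]  # f^i
W_read = [rng.normal(0, 0.1, (c_dim, center_dim)) for _ in range(I)]      # r^i

def thalnet_step(Phi, x):
    """Map the previous center vector and the task input to the next center."""
    phis = []
    for i in range(I):
        c = W_read[i] @ Phi                               # Eq. 4: read context
        x_i = x if i == 0 else np.zeros(x_dim)            # only module 0 sees input
        phis.append(np.tanh(np.concatenate([c, x_i]) @ W_mod[i]))  # Eq. 1
    return np.concatenate(phis)                           # Eq. 3: merge features

Phi = np.zeros(center_dim)  # the initial center vector is the zero vector
for x_t in rng.normal(size=(5, x_dim)):
    Phi = thalnet_step(Phi, x_t)
print(Phi.shape)  # (64,)
```

Note how each module only touches the center vector from the previous step, so the loop over modules could run in parallel.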
For dynamic reading, r^i(Φ, φ^i) is conditioned on the current features of the corresponding module, allowing the model to adapt its connectivity within a single sequence. We investigate the following reading mechanisms:

• Linear Mapping. In its simplest form, static reading consists of a fully connected layer r(Φ, ·) = WΦ with weights W ∈ R^{|c|×|Φ|}, as illustrated in Figure 2. This approach performs reasonably well, but can exhibit unstable learning dynamics and learns noisy weight matrices that are hard to interpret. Regularizing weights using L1 or L2 penalties does not help here, since it can cause side modules to not get read from anymore.
• Weight Normalization. We found linear mappings with a weight normalization [26] parameterization to be effective. For this, the context input is computed as r(Φ, ·) = β (W/|W|) Φ with scaling factor β ∈ R, weights W ∈ R^{|c|×|Φ|}, and the Euclidean matrix norm |W|. Normalization results in interpretable weights, since increasing one weight pushes other, less important, weights closer to zero, as demonstrated in Section 4.

4Please refer to Graves [8] for a study of a similar approach.

• Fast Softmax. To achieve dynamic routing, we condition the reading weight matrix on the current module features φ^i. This can be seen as a form of fast weights, providing a biologically plausible method for attention [2, 27]. We then apply softmax normalization to the computed weights so that each element of the context is computed as a weighted average over center elements, rather than just a weighted sum. 
Specifically, r(Φ, φ)_(j) = (exp((Wφ + b)_(j)) / Σ_{k=1}^{|Φ|} exp((Wφ + b)_(jk))) · Φ with weights W ∈ R^{|φ|×|Φ|×|c|} and biases b ∈ R^{|Φ|×|c|}. While this allows for a different connectivity pattern at each time step, it introduces (|φ^i| + 1) × |Φ| × |c^i| learned parameters per module.
• Fast Gaussian. As a compact parameterization for dynamic routing, we consider choosing each context element as a Gaussian weighted average of Φ, with only mean and variance vectors learned, conditioned on φ^i. The context input is computed as r(Φ, φ)_(j) = f((1, 2, ..., |Φ|) | (Wφ + b)_(j), (Uφ + d)_(j)) · Φ with weights W, U ∈ R^{|c|×|φ|}, biases b, d ∈ R^{|c|}, and the Gaussian density function f(x | µ, σ²). The density is evaluated for each index in Φ based on its distance from the mean. This reading mechanism only requires (|φ^i| + 1) × 2 × |c^i| parameters per module and thus makes dynamic reading more practical.

Reading mechanisms could also select between modules on a high level, instead of individual feature elements. We do not explore this direction since it seems less biologically plausible. Moreover, we demonstrate that such knowledge about feature boundaries is not necessary, and hierarchical information flow emerges when using fine-grained routing (see Figure 4). Theoretically, fine-grained routing also allows our model to perform a wider class of computations.

3 Performance Comparison

We investigate the properties and performance of our model on several benchmark tasks. First, we compare reading mechanisms and module designs on a simple sequential task to obtain a good configuration for the later experiments. Please refer to the supplementary material for the precise experiment description and results. 
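As an illustration of the weight-normalized read from Section 2.2 (dimensions here are arbitrary, and the Frobenius norm is used as the Euclidean matrix norm):

```python
import numpy as np

# Weight-normalized reading: the read matrix is reparameterized as
# beta * W / |W|, so scale and direction are learned separately.
center_dim, context_dim = 64, 16
rng = np.random.default_rng(0)
W = rng.normal(size=(context_dim, center_dim))
beta = 1.0  # learned scalar in the full model

def read(Phi):
    """Context input c = beta * (W / |W|) @ Phi."""
    return beta * (W / np.linalg.norm(W)) @ Phi

c = read(rng.normal(size=center_dim))
print(c.shape)  # (16,)
```

Because the matrix is divided by its overall norm, increasing one entry implicitly shrinks the effective magnitude of all the others, which is what makes the trained matrices interpretable.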
We find that the weight normalized reading mechanism provides the best performance and stability during training. We use ThalNet models with four modules of the best-performing configuration for all experiments in this section. To explore the performance of ThalNet, we now conduct experiments on three sequential tasks of increasing difficulty:

• Sequential Permuted MNIST. We use images from the MNIST [19] data set, permute the pixels of every image by a fixed random permutation, and show them to the model as a sequence of rows. The model outputs its prediction of the handwritten digit at the last time step, so that it must integrate and remember observed information from previous rows. This delayed prediction, combined with the permutation of pixels, makes the task harder than the static image classification task, with a multi-layer recurrent neural network achieving ~65 % test error. We use the standard split of 60,000 training images and 10,000 testing images.
• Sequential CIFAR-10. In a similar spirit, we use the CIFAR-10 [17] data set and feed images to the model row by row. We flatten the color channels of every row so that the model observes a vector of 96 elements at every time step. The classification is given after observing the last row of the image. This task is more difficult than the MNIST task, as the images show more complex and often ambiguous objects. The data set contains 50,000 training images and 10,000 testing images.
• Text8 Language Modeling. This text corpus, consisting of the first 10^8 bytes of the English Wikipedia, is commonly used as a language modeling benchmark for sequential models. At every time step, the model observes one byte, usually corresponding to 1 character, encoded as a one-hot vector of length 256. The task is to predict the distribution of the next character in the sequence. Performance is measured in bits per character (BPC), computed as −(1/N) Σ_{i=1}^{N} log2 p(x_i). Following Cooijmans et al. [4], we train on the first 90% and evaluate performance on the following 5% of the corpus.

For the two image classification tasks, we compare variations of our model to a stacked Gated Recurrent Unit (GRU) [3] network of 4 layers as baseline. The variations we compare are different choices of feed-forward layers and GRU layers for implementing the modules f^i(c^i, x^i): we test with two fully connected layers (FF), a GRU layer (GRU), fully connected followed by GRU (FF-GRU), GRU followed by fully connected (GRU-FF), and a GRU sandwiched between fully connected layers (FF-GRU-FF).5 For all models, we pick the largest layer sizes such that the number of parameters does not exceed 50,000. Training is performed for 100 epochs on batches of size 50 using RMSProp [30] with a learning rate of 10^−3.

Figure 3: Performance on the permuted sequential MNIST, sequential CIFAR, and text8 language modeling tasks. The stacked GRU baseline reaches higher training accuracy on CIFAR, but fails to generalize well. On both tasks, ThalNet clearly outperforms the baseline in testing accuracy. On CIFAR, we see how recurrency within the modules speeds up training. The same pattern is shown for the text8 experiment, where ThalNet using 12M parameters matches the performance of the baseline with 14M parameters. The step number 1 or 2 refers to repeated inputs as discussed in Section 2. We had to smooth the graphs using a running average, since the models were evaluated on testing batches on a rolling basis.

For language modeling, we simulate ThalNet for 2 steps per token, as described in Section 2, to allow the output module to read information about the current input before making its prediction. Note that on this task, our model uses only half of its capacity directly, since its side modules can only integrate longer-term dependencies from previous time steps. 
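As a worked example of the BPC metric defined above (the per-character probabilities are made up for illustration):

```python
import numpy as np

# BPC = -(1/N) * sum_i log2 p(x_i), where p(x_i) is the probability the
# model assigned to the character that actually occurred at step i.
p_true_char = np.array([0.5, 0.25, 0.125, 0.5])
bpc = -np.mean(np.log2(p_true_char))
print(bpc)  # 1.75
```

A perfect model would assign probability 1 to every observed character and score 0 BPC; a uniform model over 256 byte values would score 8 BPC.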
We run the baseline once without extra steps and once with 2 steps per token, allowing it to apply its full capacity once or twice on each token, respectively. This makes the comparison a bit difficult, but only by favouring the baseline. The fact that our model remains competitive while using only half of its capacity directly suggests that architectural modifications, such as explicit skip connections between modules, could further improve performance.
The Text8 task requires larger models. We train ThalNet with 4 modules, each consisting of a feed forward layer of size 400 and a GRU layer of size 600, totaling 12 million model parameters. We compare to a standard baseline in language modeling, a single GRU with 2000 units, totaling 14 million parameters. We train on batches of 100 sequences, each containing 200 bytes, using the Adam optimizer [15] with a default learning rate of 10^−3. We scale down gradients exceeding a norm of 1. Results for 50 epochs of training are shown in Figure 3. The training took about 8 days for ThalNet with 2 steps per token, 6 days for the baseline with 2 steps per token, and 3 days for the baseline without extra steps.

5Note that the modules require some amount of local structure to allow them to specialize. Implementing the modules as a single fully connected layer recovers a standard recurrent neural network with one large layer.

Figure 3 shows the training and testing curves for the three tasks described in this section. ThalNet outperforms standard GRU networks in all three tasks. Interestingly, ThalNet experiences a much smaller gap between training and testing performance than our baseline, a trend we observed across all experimental results.
On the Text8 task, ThalNet scores 1.39 BPC using 12M parameters, while our GRU baseline scores 1.41 BPC using 14M parameters (lower is better). Our model thus slightly improves on the baseline while using fewer parameters. This result places ThalNet in between the baseline and regularization methods designed for language modeling, which can also be applied to our model. The baseline performance is consistent with published results of LSTMs with a similar number of parameters [18].
We hypothesize that the information bottleneck at the reading mechanism acts as an implicit regularizer that encourages generalization. Compared to one large RNN, which has a lot of freedom in modeling the input-output mapping, ThalNet imposes local structure on how the input-output mapping can be implemented. 
In particular, it encourages the model to decompose into several modules that have stronger intra-connectivity than inter-connectivity. Thus, to some extent, every module needs to learn a self-contained computation.

4 Hierarchical Connectivity Patterns

Using its routing center, our model is able to learn its structure as part of learning to solve the task. In this section, we explore the emergent connectivity patterns. We show that our model learns to route features in hierarchical ways as hypothesized, including skip connections and feedback connections. For this purpose, we choose the text8 corpus, a medium-scale language modeling benchmark consisting of the first 10^8 bytes of Wikipedia, preprocessed for the Hutter Prize [21]. The model observes one one-hot encoded byte per time step, and is trained to predict its input at the next time step.
We use comparatively small models to be able to run experiments quickly, comparing ThalNet models of 4 FF-GRU-FF modules with layer sizes 50, 100, 50 and 50, 200, 50. Both experiments use weight normalized reading. Our focus here is on exploring learned connectivity patterns; we show competitive results on the task using larger models in Section 3.
We simulate two sub time steps to allow the output module to receive information about the current input frame, as discussed in Section 2. Models are trained for 50 epochs on batches of size 10 containing sequences of length 50, using RMSProp with a learning rate of 10^−3. In general, we observe different random seeds converging to similar connectivity patterns with recurring elements.

4.1 Trained Reading Weights

Figure 4 shows trained reading weights for various reading mechanisms, along with their manually deduced connectivity graphs.6 Each image represents a reading weight matrix for modules 1 to 4 (top to bottom). 
Each pixel row shows the weight factors that get multiplied with Φ to produce a single element of the context vector of that module. The weight matrices thus have dimensions |Φ| × |c^i|. White pixels represent large magnitudes, suggesting focus on features at those positions.
The weight matrices of weight normalized reading clearly resemble the boundaries of the four concatenated module features φ^1, ..., φ^4 in the center vector Φ, even though the model has no notion of the origin and ordering of elements in the center vector.
A similar structure emerges with fast softmax reading. These weight matrices are sparser than the weights from weight normalization. Over the course of a sequence, we observe some weights staying constant while others change their magnitudes at each time step. This suggests that optimal connectivity might include both static and dynamic elements. However, this reading mechanism leads to less stable training. This problem could potentially be alleviated by normalizing the fast weight matrix.
With fast Gaussian reading, we see that the distributions occasionally tighten on specific features in the first and last modules, the modules that receive input and emit output. The other modules learn large variance parameters, effectively spanning all center features. This could potentially be addressed by reading using mixtures of Gaussians for each context element instead. We generally find that weight normalized and fast softmax reading select features in a more targeted way.

6Developing formal measurements for this deduction process seems beneficial in the future.

(a) Weight Normalization (b) Fast Softmax (c) Fast Gaussian

Figure 4: Reading weights learned by different reading mechanisms with 4 modules on the text8 language modeling task, alongside manually deduced connectivity graphs. 
We plot the weight matrices that produce the context inputs to the four modules, top to bottom: the top image shows the focus of the input module, followed by the side modules, with the output module at the bottom. Each pixel row gets multiplied with the center vector Φ to produce one scalar element of the context input c^i. We visualize the magnitude of weights between the 5% and 95% percentiles. We do not include the connectivity graph for fast Gaussian reading, as its reading weights are not clearly structured.

4.2 Commonly Learned Structures

The top row in Figure 4 shows manually deduced connectivity graphs between modules. Arrows represent the main direction of information flow in the model. For example, the two incoming arrows to module 4 in Figure 4a indicate that module 4 mainly attends to features produced by modules 1 and 3. We infer the connections from the larger weight magnitudes in the first and third quarters of the reading weights for module 4 (bottom row).
A typical pattern that emerges during the experiments can be seen in the connectivity graphs of both weight normalized and fast softmax reading (Figures 4a and 4b): the output module reads features directly from the input module. This direct connection is established early on during training, likely because it is the most direct gradient path from output to input. Later on, the side modules develop useful features to support the input and output modules.
In another pattern, one module reads from all other modules and combines their information. In Figure 4b, module 2 takes this role, reading from modules 1, 3, and 4, and distributing these features via the input module. In additional experiments with more than four modules, we observed this pattern to emerge predominantly. 
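The deduction of connections from weight magnitudes can be approximated programmatically. A sketch, assuming the center vector is the concatenation of equally sized module features and using an illustrative magnitude threshold:

```python
import numpy as np

# Deduce edges j -> i ("module i reads from module j") by averaging the
# magnitudes of module i's reading weights over each source module's block.
num_modules, phi_dim, context_dim = 4, 16, 8

def deduce_edges(read_weights, threshold=0.5):
    """read_weights[i] has shape (context_dim, num_modules * phi_dim)."""
    edges = []
    for i, W in enumerate(read_weights):
        blocks = np.abs(W).reshape(context_dim, num_modules, phi_dim)
        score = blocks.mean(axis=(0, 2))  # one attention score per source module
        score = score / score.max()
        edges += [(j, i) for j in range(num_modules) if score[j] >= threshold]
    return edges

# Toy weights where every module reads only from module 1's block.
W = np.zeros((context_dim, num_modules * phi_dim))
W[:, phi_dim:2 * phi_dim] = 1.0
print(deduce_edges([W] * num_modules))  # [(1, 0), (1, 1), (1, 2), (1, 3)]
```

Such a heuristic would be one way to formalize the manual deduction; choosing the threshold and the aggregation is exactly the open measurement question noted in the footnote above.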
This connection pattern provides a more efficient way of sharing information than cross-connecting all modules.
Both connectivity graphs in Figure 4 include hierarchical computation paths through the modules. They include learned skip connections, which are known from popular models such as ResNet [12], Highway networks [29], and DenseNet [14] to improve gradient flow. Furthermore, the connectivity graphs contain backward connections, creating feedback loops over two or more modules. Feedback connections are known to play a critical role in the neocortex, which inspired our work [7].

5 Related Work

We describe a recurrent mixture of experts model that learns to dynamically pass information between its modules. Related approaches can be found in various recurrent and multi-task methods, as outlined in this section.

Modular Neural Networks. ThalNet consists of several recurrent modules that interact and exploit each other. Modularity is a common property of existing neural models. Devin et al. [5] learn a matrix of tasks and robot bodies to improve both multitask and transfer learning. Andreas et al. [1] learn modules specific to objects present in the scene, which are selected by an object classifier. These approaches specify modules corresponding to a specific task or variable manually. In contrast, our model automatically discovers and exploits the inherent modularity of the task and does not require a one-to-one correspondence of modules to task variables.
The Column Bundle model [23] consists of a central column and several mini-columns around it. While not applied to temporal data, we observe a structural similarity between our modules and the mini-columns, in the case where weights are shared among layers of the mini-columns, which the authors mention as a possibility.

Learned Computation Paths. 
We learn the connectivity between modules alongside the task. There are various methods in the multi-task context that also learn the connectivity between modules. Fernando et al. [6] learn paths through multiple layers of experts using an evolutionary approach. Rusu et al. [25] learn adapter connections to connect to fixed, previously trained experts and exploit their information. These approaches focus on feed-forward architectures. The recurrency in our approach allows for complex and flexible computational paths. Moreover, we learn interpretable weight matrices that can be examined directly without performing costly sensitivity analysis.
The Neural Programmer-Interpreter presented by Reed and De Freitas [24] is related to our dynamic gating mechanisms. In their work, a network recursively calls itself in a parameterized way to perform tree-shaped computations. In comparison, our model allows for parallel computation between modules and for unrestricted connectivity patterns between modules.
Memory Augmented RNNs. The center vector in our model can be interpreted as an external memory, with multiple recurrent controllers operating on it. Preceding work proposes recurrent neural networks operating on external memory structures. The Neural Turing Machine proposed by Graves et al. [9], and follow-up work [10], investigate differentiable ways to address a memory for reading and writing. In the ThalNet model, we use multiple recurrent controllers accessing the center vector. Moreover, our center vector is recomputed at each time step, and thus should not be confused with the persistent memory that is typical for models with external memory.

6 Conclusion

We presented ThalNet, a recurrent modular framework that learns to pass information between neural modules in a hierarchical way. Experiments on sequential and permuted variants of MNIST and CIFAR-10 are a promising sign of the viability of this approach. 
In these experiments, ThalNet learns novel connectivity patterns that include hierarchical paths, skip connections, and feedback connections.
In our current implementation, we assume the center features to be a vector. Introducing a matrix shape for the center features would open up ways to integrate convolutional modules and similarity-based attention mechanisms for reading from the center. While matrix shaped features are easily interpretable for visual input, it is less clear how this structure would be leveraged for other modalities.
A further direction of future work is to apply our paradigm to tasks with multiple modalities for inputs and outputs. It seems natural to either have a separate input module for each modality, or to have multiple output modules that can all share information through the center. We believe this could be used to encourage specialization into specific patterns and to create more controllable connectivity patterns between modules. Similarly, an interesting direction is to explore how the proposed model can be leveraged to learn and remember a sequence of tasks.
We believe modular computation in neural networks will become more important as researchers approach more complex tasks and employ deep learning in rich, multi-modal domains. Our work provides a step in the direction of automatically organizing neural modules that leverage each other in order to solve a wide range of tasks in a complex world.

References

[1] J. Andreas, M. Rohrbach, T. Darrell, and D. Klein. Neural module networks. In IEEE Conference on Computer Vision and Pattern Recognition, pages 39–48, 2016.
[2] J. Ba, G. E. Hinton, V. Mnih, J. Z. Leibo, and C. Ionescu. Using fast weights to attend to the recent past. In Advances in Neural Information Processing Systems, pages 4331–4339, 2016.
[3] K. Cho, B. van Merriënboer, D. Bahdanau, and Y. Bengio. 
On the properties of neural machine translation: Encoder–decoder approaches. Syntax, Semantics and Structure in Statistical Translation, page 103, 2014.

[4] T. Cooijmans, N. Ballas, C. Laurent, Ç. Gülçehre, and A. Courville. Recurrent batch normalization. arXiv preprint arXiv:1603.09025, 2016.

[5] C. Devin, A. Gupta, T. Darrell, P. Abbeel, and S. Levine. Learning modular neural network policies for multi-task and multi-robot transfer. arXiv preprint arXiv:1609.07088, 2016.

[6] C. Fernando, D. Banarse, C. Blundell, Y. Zwols, D. Ha, A. A. Rusu, A. Pritzel, and D. Wierstra. PathNet: Evolution channels gradient descent in super neural networks. arXiv preprint arXiv:1701.08734, 2017.

[7] C. D. Gilbert and M. Sigman. Brain states: Top-down influences in sensory processing. Neuron, 54(5):677–696, 2007.

[8] A. Graves. Adaptive computation time for recurrent neural networks. arXiv preprint arXiv:1603.08983, 2016.

[9] A. Graves, G. Wayne, and I. Danihelka. Neural Turing machines. arXiv preprint arXiv:1410.5401, 2014.

[10] A. Graves, G. Wayne, M. Reynolds, T. Harley, I. Danihelka, A. Grabska-Barwińska, S. G. Colmenarejo, E. Grefenstette, T. Ramalho, J. Agapiou, et al. Hybrid computing using a neural network with dynamic external memory. Nature, 538(7626):471–476, 2016.

[11] J. Hawkins and D. George. Hierarchical temporal memory: Concepts, theory and terminology. Technical report, Numenta, 2006.

[12] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.

[13] G. Hinton, A. Krizhevsky, and S. Wang. Transforming auto-encoders. In Artificial Neural Networks and Machine Learning (ICANN), pages 44–51, 2011.

[14] G. Huang, Z. Liu, K. Q. Weinberger, and L. van der Maaten. Densely connected convolutional networks.
arXiv preprint arXiv:1608.06993, 2016.

[15] D. Kingma and J. Ba. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.

[16] J. Kirkpatrick, R. Pascanu, N. Rabinowitz, J. Veness, G. Desjardins, A. A. Rusu, K. Milan, J. Quan, T. Ramalho, A. Grabska-Barwińska, et al. Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences, page 201611835, 2017.

[17] A. Krizhevsky. Learning multiple layers of features from tiny images, 2009.

[18] D. Krueger, T. Maharaj, J. Kramár, M. Pezeshki, N. Ballas, N. R. Ke, A. Goyal, Y. Bengio, H. Larochelle, A. Courville, et al. Zoneout: Regularizing RNNs by randomly preserving hidden activations. arXiv preprint arXiv:1606.01305, 2016.

[19] Y. LeCun and C. Cortes. The MNIST database of handwritten digits, 1998.

[20] T. P. Lillicrap, D. Cownden, D. B. Tweed, and C. J. Akerman. Random synaptic feedback weights support error backpropagation for deep learning. Nature Communications, 7, 2016.

[21] M. Mahoney. About the test data. http://mattmahoney.net/dc/textdata, 2011.

[22] W. S. McCulloch and W. Pitts. A logical calculus of the ideas immanent in nervous activity. The Bulletin of Mathematical Biophysics, 5(4):115–133, 1943.

[23] T. Pham, T. Tran, and S. Venkatesh. One size fits many: Column bundle for multi-x learning. arXiv preprint arXiv:1702.07021, 2017.

[24] S. Reed and N. De Freitas. Neural programmer-interpreters. In International Conference on Learning Representations, 2015.

[25] A. A. Rusu, N. C. Rabinowitz, G. Desjardins, H. Soyer, J. Kirkpatrick, K. Kavukcuoglu, R. Pascanu, and R. Hadsell. Progressive neural networks. arXiv preprint arXiv:1606.04671, 2016.

[26] T. Salimans and D. P. Kingma. Weight normalization: A simple reparameterization to accelerate training of deep neural networks.
In Advances in Neural Information Processing Systems, pages 901–909, 2016.

[27] J. Schmidhuber. Learning to control fast-weight memories: An alternative to dynamic recurrent networks. Neural Computation, 4(1):131–139, 1992.

[28] S. M. Sherman. Thalamus plays a central role in ongoing cortical functioning. Nature Neuroscience, 19(4):533–541, 2016.

[29] R. K. Srivastava, K. Greff, and J. Schmidhuber. Highway networks. arXiv preprint arXiv:1505.00387, 2015.

[30] T. Tieleman and G. Hinton. Lecture 6.5-rmsprop: Divide the gradient by a running average of its recent magnitude. COURSERA: Neural Networks for Machine Learning, 4(2), 2012.

[31] F. Zenke, B. Poole, and S. Ganguli. Improved multitask learning through synaptic intelligence. arXiv preprint arXiv:1703.04200, 2017.