{"title": "BatchBALD: Efficient and Diverse Batch Acquisition for Deep Bayesian Active Learning", "book": "Advances in Neural Information Processing Systems", "page_first": 7026, "page_last": 7037, "abstract": "We develop BatchBALD, a tractable approximation to the mutual information between a batch of points and model parameters, which we use as an acquisition function to select multiple informative points jointly for the task of deep Bayesian active learning. BatchBALD is a greedy linear-time $1 - \\nicefrac{1}{e}$-approximate algorithm amenable to dynamic programming and efficient caching. We compare BatchBALD to the commonly used approach for batch data acquisition and find that the current approach acquires similar and redundant points, sometimes performing worse than randomly acquiring data. We finish by showing that, using BatchBALD to consider dependencies within an acquisition batch, we achieve new state of the art performance on standard benchmarks, providing substantial data efficiency improvements in batch acquisition.", "full_text": "BatchBALD: E\ufb03cient and Diverse Batch Acquisition\n\nfor Deep Bayesian Active Learning\n\nAndreas Kirsch\u2217\n\nJoost van Amersfoort\u2217\n\nYarin Gal\n\nOATML\n\nDepartment of Computer Science\n\nUniversity of Oxford\n\n{andreas.kirsch, joost.van.amersfoort, yarin}@cs.ox.ac.uk\n\nAbstract\n\nWe develop BatchBALD, a tractable approximation to the mutual information\nbetween a batch of points and model parameters, which we use as an acquisition\nfunction to select multiple informative points jointly for the task of deep Bayesian\nactive learning. BatchBALD is a greedy linear-time 1 \u2212 1\ne -approximate algorithm\namenable to dynamic programming and e\ufb03cient caching. We compare BatchBALD\nto the commonly used approach for batch data acquisition and \ufb01nd that the current\napproach acquires similar and redundant points, sometimes performing worse\nthan randomly acquiring data. 
We \ufb01nish by showing that, using BatchBALD to\nconsider dependencies within an acquisition batch, we achieve new state of the\nart performance on standard benchmarks, providing substantial data e\ufb03ciency\nimprovements in batch acquisition.\n\n1\n\nIntroduction\n\nA key problem in deep learning is data e\ufb03ciency. While excellent performance can be obtained\nwith modern tools, these are often data-hungry, rendering the deployment of deep learning in the\nreal-world challenging for many tasks. Active learning (AL) [7] is a powerful technique for attaining\ndata e\ufb03ciency. Instead of a-priori collecting and labelling a large dataset, which often comes at a\nsigni\ufb01cant expense, in AL we iteratively acquire labels from an expert only for the most informative\ndata points from a pool of available unlabelled data. After each acquisition step, the newly labelled\npoints are added to the training set, and the model is retrained. This process is repeated until a suitable\nlevel of accuracy is achieved. The goal of AL is to minimise the amount of data that needs to be\nlabelled. AL has already made real-world impact in manufacturing [34], robotics [5], recommender\nsystems [1], medical imaging [18], and NLP [31], motivating the need for pushing AL even further.\nIn AL, the informativeness of new points is assessed by an acquisition function. There are a number\nof intuitive choices, such as model uncertainty and mutual information, and, in this paper, we focus\non BALD [19], which has proven itself in the context of deep learning [13, 30, 20]. BALD is based\non mutual information and scores points based on how well their label would inform us about the true\nmodel parameter distribution. In deep learning models [16, 32], we generally treat the parameters\nas point estimates instead of distributions. 
However, Bayesian neural networks have become a\npowerful alternative to traditional neural networks and do provide a distribution over their parameters.\nImprovements in approximate inference [4, 12] have enabled their usage for high dimensional data\nsuch as images and in conjunction with BALD for Bayesian AL of images [13].\nIn practical AL applications, instead of single data points, batches of data points are acquired\nduring each acquisition step to reduce the number of times the model is retrained and expert-time is\n\n\u2217joint \ufb01rst authors\n\n33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.\n\n\fFigure 1: Idealised acquisitions of BALD and\nBatchBALD. If a dataset were to contain many\n(near) replicas for each data point, then BALD\nwould select all replicas of a single informative\ndata point at the expense of other informative data\npoints, wasting data e\ufb03ciency.\n\nFigure 2: Performance on Repeated MNIST with\nacquisition size 10. See section 4.1 for further\ndetails. BatchBALD outperforms BALD while\nBALD performs worse than random acquisition\ndue to the replications in the dataset.\n\nrequested. Model retraining becomes a computational bottleneck for larger models while expert time\nis expensive: consider, for example, the e\ufb00ort that goes into commissioning a medical specialist to\nlabel a single MRI scan, then waiting until the model is retrained, and then commissioning a new\nmedical specialist to label the next MRI scan, and the extra amount of time this takes.\nIn Gal et al. [13], batch acquisition, i.e. the acquisition of multiple points, takes the top b points\nwith the highest BALD acquisition score. This naive approach leads to acquiring points that are\nindividually very informative, but not necessarily so jointly. 
See \ufb01gure 1 for such a batch acquisition\nof BALD in which it performs poorly whereas scoring points jointly (\"BatchBALD\") can \ufb01nd batches\nof informative data points. Figure 2 shows how a dataset consisting of repeated MNIST digits (with\nadded Gaussian noise) leads BALD to perform worse than random acquisition while BatchBALD\nsustains good performance.\nNaively \ufb01nding the best batch to acquire requires enumerating all possible subsets within the available\ndata, which is intractable as the number of potential subsets grows exponentially with the acquisition\nsize b and the size of available points to choose from. Instead, we develop a greedy algorithm that\nselects a batch in linear time, and show that it is at worst a 1\u2212 1/e approximation to the optimal choice\nfor our acquisition function. We provide an open-source implementation2.\nThe main contributions of this work are:\n\n1. BatchBALD, a data-e\ufb03cient active learning method that acquires sets of high-dimensional\n\nimage data, leading to improved data e\ufb03ciency and reduced total run time, section 3.1;\n\n2. a greedy algorithm to select a batch of points e\ufb03ciently, section 3.2; and\n3. an estimator for the acquisition function that scales to larger acquisition sizes and to datasets\n\nwith many classes, section 3.3.\n\n2 Background\n\n2.1 Problem Setting\nThe Bayesian active learning setup consists of an unlabelled dataset Dpool, the current training set\nDtrain, a Bayesian model M with model parameters \u03c9\u03c9\u03c9 \u223c p(\u03c9\u03c9\u03c9 | Dtrain), and output predictions p(y |\nx, \u03c9\u03c9\u03c9,Dtrain) for data point x and prediction y \u2208 {1, ..., c} in the classi\ufb01cation case. The conditioning of\n\u03c9\u03c9\u03c9 on Dtrain expresses that the model has been trained with Dtrain. Furthermore, an oracle can provide\nus with the correct label \u02dcy for a data point in the unlabelled pool x \u2208 Dpool. 
The goal is to obtain a certain level of prediction accuracy with the least amount of oracle queries. At each acquisition step, a batch of data points {x*_1, ..., x*_b} is selected using an acquisition function a which scores a candidate batch of unlabelled data points {x_1, ..., x_b} ⊆ Dpool using the current model parameters p(ω | Dtrain):\n\n{x*_1, ..., x*_b} = argmax_{{x_1, ..., x_b} ⊆ Dpool} a({x_1, ..., x_b}, p(ω | Dtrain)).    (1)\n\n2 https://github.com/BlackHC/BatchBALD\n\n2.2 BALD\n\nBALD (Bayesian Active Learning by Disagreement) [19] uses an acquisition function that estimates the mutual information between the model predictions and the model parameters. Intuitively, it captures how strongly the model predictions for a given data point and the model parameters are coupled, implying that finding out about the true label of data points with high mutual information would also inform us about the true model parameters. Originally introduced outside the context of deep learning, the only requirement on the model is that it is Bayesian. BALD is defined as:\n\nI(y ; ω | x, Dtrain) = H(y | x, Dtrain) − E_{p(ω | Dtrain)}[H(y | x, ω, Dtrain)].    (2)\n\nLooking at the two terms in equation (2), for the mutual information to be high, the left term has to be high and the right term low. The left term is the entropy of the model prediction, which is high when the model's prediction is uncertain. The right term is an expectation of the entropy of the model prediction over the posterior of the model parameters and is low when the model is overall certain for each draw of model parameters from the posterior. 
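Concretely, equation (2) can be estimated from k MC dropout samples of the predictive distribution. The following minimal NumPy sketch computes the score for a single candidate point; the `probs` layout of shape [k, c] and the helper names are our own assumptions, not the paper's implementation:

```python
import numpy as np

def entropy(p, axis=-1, eps=1e-12):
    """Shannon entropy in nats along `axis`."""
    return -np.sum(p * np.log(p + eps), axis=axis)

def bald_score(probs):
    """Estimate I(y ; w | x, D_train) of eq. (2) from MC samples.

    probs: array of shape [k, c] whose j-th row is p(y | x, w_j) for one
    of k posterior draws w_j over c classes (layout is an assumption).
    """
    predictive_entropy = entropy(probs.mean(axis=0))  # H(y | x): left term
    expected_entropy = entropy(probs).mean()          # E_w[H(y | x, w)]: right term
    return predictive_entropy - expected_entropy
```

A point on which the posterior draws confidently disagree (left term high, right term near zero) scores highly, while a point on which all draws agree scores near zero.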
Both can only happen when the model has many possible ways to explain the data, which means that the posterior draws are disagreeing among themselves.\n\nBALD was originally intended for acquiring individual data points and immediately retraining the model. This becomes a bottleneck in deep learning, where retraining takes a substantial amount of time. Applications of BALD [12, 20] usually acquire the top b points. This can be expressed as summing over individual scores:\n\na_BALD({x_1, ..., x_b}, p(ω | Dtrain)) = Σ_{i=1}^{b} I(y_i ; ω | x_i, Dtrain),    (3)\n\nand finding the optimal batch for this acquisition function using a greedy algorithm, which reduces to picking the top b highest-scoring data points.\n\n2.3 Bayesian Neural Networks (BNN)\n\nIn this paper we focus on BNNs as our Bayesian model because they scale well to high dimensional inputs, such as images. Compared to regular neural networks, BNNs maintain a distribution over their weights instead of point estimates. Performing exact inference in BNNs is intractable for any reasonably sized model, so we resort to using a variational approximation. Similar to Gal et al. 
[13], we use MC dropout [12], which is easy to implement, scales well to large models and datasets, and is straightforward to optimise.\n\n3 Methods\n\n3.1 BatchBALD\n\nWe propose BatchBALD as an extension of BALD whereby we jointly score points by estimating the mutual information between a joint of multiple data points and the model parameters:3\n\na_BatchBALD({x_1, ..., x_b}, p(ω | Dtrain)) = I(y_1, ..., y_b ; ω | x_1, ..., x_b, Dtrain).    (4)\n\nThis builds on the insight that independent selection of a batch of data points leads to data inefficiency as correlations between data points in an acquisition batch are not taken into account.\n\nTo understand how to compute the mutual information between a set of points and the model parameters, we express x_1, ..., x_b, and y_1, ..., y_b through joint random variables x_{1:b} and y_{1:b} in a product probability space and use the definition of the mutual information for two random variables:\n\nI(y_{1:b} ; ω | x_{1:b}, Dtrain) = H(y_{1:b} | x_{1:b}, Dtrain) − E_{p(ω | Dtrain)} H(y_{1:b} | x_{1:b}, ω, Dtrain).    (5)\n\n3 We use the notation I(x, y ; z | c) to denote the mutual information between the joint of the random variables x, y and the random variable z conditioned on c.\n\n(a) BALD: Σ_i I(y_i ; ω | x_i, Dtrain) = Σ_i μ*(y_i ∩ ω)\n\n(b) BatchBALD: I(y_1, ..., y_b ; ω | x_1, ..., x_b, Dtrain) = μ*(∪_i y_i ∩ ω)\n\nFigure 3: Intuition behind BALD and BatchBALD using I-diagrams [36]. BALD overestimates the joint mutual information. BatchBALD, however, takes the overlap between variables into account and will strive to acquire a better cover of ω. 
Areas contributing to the respective score are shown in grey, and areas that are double-counted in dark grey.\n\nIntuitively, the mutual information between two random variables can be seen as the intersection of their information content. In fact, Yeung [36] shows that a signed measure μ* can be defined for discrete random variables x, y, such that I(x ; y) = μ*(x ∩ y), H(x, y) = μ*(x ∪ y), E_{p(y)} H(x | y) = μ*(x \\ y), and so on, where we identify random variables with their counterparts in information space, and conveniently drop conditioning on Dtrain and x_i.\n\nUsing this, BALD can be viewed as the sum of individual intersections Σ_i μ*(y_i ∩ ω), which double counts overlaps between the y_i. Naively extending BALD to the mutual information between y_1, ..., y_b and ω, which is equivalent to μ*(∩_i y_i ∩ ω), would lead to selecting similar data points instead of diverse ones under maximisation. BatchBALD, on the other hand, takes overlaps into account by computing μ*(∪_i y_i ∩ ω) and is more likely to acquire a more diverse cover under maximisation:\n\nI(y_1, ..., y_b ; ω | x_1, ..., x_b, Dtrain) = H(y_{1:b} | x_{1:b}, Dtrain) − E_{p(ω | Dtrain)} H(y_{1:b} | x_{1:b}, ω, Dtrain)    (6)\n= μ*(∪_i y_i) − μ*(∪_i y_i \\ ω) = μ*(∪_i y_i ∩ ω).    (7)\n\nThis is depicted in figure 3 and also motivates that a_BatchBALD ≤ a_BALD, which we prove in appendix B.1. For acquisition size 1, BatchBALD and BALD are equivalent.\n\n3.2 Greedy approximation algorithm for BatchBALD\n\nTo avoid the combinatorial explosion that arises from jointly scoring subsets of points, we introduce a greedy approximation for computing BatchBALD, depicted in algorithm 1. In appendix A, we prove that a_BatchBALD is submodular, which means the greedy algorithm is 1 − 1/e-approximate [8, 24, 25].\n\nAlgorithm 1: Greedy BatchBALD 1 − 1/e-approximate algorithm\nInput: acquisition size b, unlabelled dataset Dpool, model parameters p(ω | Dtrain)\n1  A_0 ← ∅\n2  for n ← 1 to b do\n3      foreach x ∈ Dpool \\ A_{n−1} do s_x ← a_BatchBALD(A_{n−1} ∪ {x}, p(ω | Dtrain))\n4      x_n ← argmax_{x ∈ Dpool \\ A_{n−1}} s_x\n5      A_n ← A_{n−1} ∪ {x_n}\n6  end\nOutput: acquisition batch A_b = {x_1, ..., x_b}\n\nIn appendix B.2, we show that, under idealised conditions, when using BatchBALD and a fixed final |Dtrain|, the active learning loop itself can be seen as a greedy 1 − 1/e-approximation algorithm, and that an active learning loop with BatchBALD and acquisition size larger than 1 is bounded by an active learning loop with individual acquisitions, that is, BALD/BatchBALD with acquisition size 1, which is the ideal case.\n\n3.3 Computing a_BatchBALD\n\nFor brevity, we leave out conditioning on x_1, ..., x_n, and Dtrain, and p(ω) denotes p(ω | Dtrain) in this section. a_BatchBALD is then written as:\n\na_BatchBALD({x_1, ..., x_n}, p(ω)) = H(y_1, ..., y_n) − E_{p(ω)}[H(y_1, ..., y_n | ω)].    (8)\n\nBecause the y_i are independent when conditioned on ω, computing the right term of equation (8) is simplified as the conditional joint entropy decomposes into a sum. We can approximate the expectation using a Monte-Carlo estimator with k samples ω̂_j ∼ p(ω) from our model parameter distribution:\n\nE_{p(ω)}[H(y_1, ..., y_n | ω)] = Σ_{i=1}^{n} E_{p(ω)}[H(y_i | ω)] ≈ (1/k) Σ_{i=1}^{n} Σ_{j=1}^{k} H(y_i | ω̂_j).    (9)\n\nComputing the left term of equation (8) is difficult because the unconditioned joint probability does not factorise. Applying the equality p(y) = E_{p(ω)}[p(y | ω)], and, using sampled ω̂_j, we compute the entropy by summing over all possible configurations ŷ_{1:n} of y_{1:n}:\n\nH(y_1, ..., y_n) = E_{p(y_1,...,y_n)}[− log p(y_1, ..., y_n)]    (10)\n= E_{p(ω)} E_{p(y_1,...,y_n | ω)}[− log E_{p(ω)}[p(y_1, ..., y_n | ω)]]    (11)\n≈ − Σ_{ŷ_{1:n}} ((1/k) Σ_{j=1}^{k} p(ŷ_{1:n} | ω̂_j)) log ((1/k) Σ_{j=1}^{k} p(ŷ_{1:n} | ω̂_j)).    (12)\n\n3.4 Efficient implementation\n\nIn each iteration of the algorithm, x_1, ..., x_{n−1} stay fixed while x_n varies over Dpool \\ A_{n−1}. We can reduce the required computations by factorizing p(y_{1:n} | ω) into p(y_{1:n−1} | ω) p(y_n | ω). 
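To make the joint-entropy estimator of equations (8)-(12) and this factorisation concrete, here is a minimal NumPy sketch of the greedy selection of algorithm 1. The `pool_probs` layout of shape [N, k, c] and all helper names are our own assumptions; unlike the actual implementation, it always enumerates all c^n configurations of ŷ_{1:n} rather than switching to MC sampling of y_{1:n−1}, so it is only practical for small batches and few classes:

```python
import numpy as np

def entropy(p, axis=-1, eps=1e-12):
    """Shannon entropy in nats along `axis`."""
    return -np.sum(p * np.log(p + eps), axis=axis)

def extend_joint(P_prev, probs_x):
    """Build p(y_1:n | w_j) from p(y_1:n-1 | w_j) and p(y_n | x_n, w_j).

    P_prev: [c**(n-1), k]; probs_x: [k, c]. Conditioned on a draw w_j the
    y_i are independent, so the joint is an outer product sharing the j axis.
    """
    return (P_prev[:, None, :] * probs_x.T[None, :, :]).reshape(-1, P_prev.shape[1])

def batchbald_score(P, cond_entropy_sum):
    """Eq. (8): joint entropy (eq. 12) minus summed conditional entropies (eq. 9)."""
    p_joint = P.mean(axis=1)  # (1/k) sum_j p(y_1:n | w_j) over all configurations
    return entropy(p_joint) - cond_entropy_sum

def greedy_batchbald(pool_probs, b):
    """Algorithm 1 sketch: greedily grow the acquisition batch."""
    n_pool, k, _ = pool_probs.shape
    chosen, cond_sum = [], 0.0
    P_prev = np.ones((1, k))  # empty joint: p() = 1 for every draw w_j
    for _ in range(b):
        best = None
        for x in range(n_pool):
            if x in chosen:
                continue
            P = extend_joint(P_prev, pool_probs[x])
            score = batchbald_score(P, cond_sum + entropy(pool_probs[x]).mean())
            if best is None or score > best[1]:
                best = (x, score, P)
        chosen.append(best[0])
        P_prev = best[2]
        cond_sum += entropy(pool_probs[best[0]]).mean()
    return chosen
```

Because P_{1:n−1} is reused across all candidates x_n, each point's predictions only need to be sampled once, mirroring the caching discussed in this section; a production version would replace the inner loop with the batched matrix product of equation (13).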
We store p(ŷ_{1:n−1} | ω̂_j) in a matrix P̂_{1:n−1} of shape c^{n−1} × k and p(y_n | ω̂_j) in a matrix P̂_n of shape c × k. The sum Σ_{j=1}^{k} p(ŷ_{1:n} | ω̂_j) in (12) can then be turned into a matrix product:\n\n(1/k) Σ_{j=1}^{k} p(ŷ_{1:n} | ω̂_j) = (1/k) Σ_{j=1}^{k} p(ŷ_{1:n−1} | ω̂_j) p(ŷ_n | ω̂_j) = ((1/k) P̂_{1:n−1} P̂_n^T)_{ŷ_{1:n−1}, ŷ_n}.    (13)\n\nThis can be further sped up by using batch matrix multiplication to compute the joint entropy for different x_n. P̂_{1:n−1} only has to be computed once, and we can recursively compute P̂_{1:n} using P̂_{1:n−1} and P̂_n, which allows us to sample p(y | ω̂_j) for each x ∈ Dpool only once at the beginning of the algorithm.\n\nFor larger acquisition sizes, we use m MC samples of y_{1:n−1} as enumerating all possible configurations becomes infeasible. See appendix C for details.\n\nMonte-Carlo sampling bounds the time complexity of the full BatchBALD algorithm to O(bc · min{c^b, m} · |Dpool| · k) compared to O(c^b · |Dpool|^b · k) for naively finding the exact optimal batch and O((b + k) · |Dpool|) for BALD4.\n\n4 Experiments\n\nIn our experiments, we start by showing how a naive application of the BALD algorithm to an image dataset can lead to poor results in a dataset with many (near) duplicate data points, and show that BatchBALD solves this problem in a grounded way while obtaining favourable results (figure 2).\n\n4 b is the acquisition size, c is the number of classes, k is the number of MC dropout samples, and m is the number of sampled configurations of y_{1:n−1}.\n\n(a) BALD    (b) BatchBALD\n\nFigure 4: Performance on MNIST for increasing acquisition sizes. 
BALD\u2019s performance drops\ndrastically as the acquisition size increases. BatchBALD maintains strong performance even with\nincreasing acquisition size.\n\nWe then illustrate BatchBALD\u2019s e\ufb00ectiveness on standard AL datasets: MNIST and EMNIST.\nEMNIST [6] is an extension of MNIST that also includes letters, for a total of 47 classes, and has a\ntwice as large training set. See appendix F for examples of the dataset. We show that BatchBALD\nprovides a substantial performance improvement in these scenarios, too, and has more diverse\nacquisitions. Finally, we look at BatchBALD in the setting of transfer learning, where we \ufb01netune a\nlarge pretrained model on a more di\ufb03cult dataset called CINIC-10 [9], which is a combination of\nCIFAR-10 and downscaled ImageNet.\nIn our experiments, we repeatedly go through active learning loops. One active learning loop consists\nof training the model on the available labelled data and subsequently acquiring new data points using\na chosen acquisition function. As the labelled dataset is small in the beginning, it is important to\navoid over\ufb01tting. We do this by using early stopping after 3 epochs of declining accuracy on the\nvalidation set. We pick the model with the highest validation accuracy. Throughout our experiments,\nwe use the Adam [22] optimiser with learning rate 0.001 and betas 0.9/0.999. All our results report\nthe median of 6 trials, with lower and upper quartiles. We use these quartiles to draw the \ufb01lled error\nbars on our \ufb01gures.\nWe reinitialize the model after each acquisition, similar to Gal et al. [13]. We found this helps\nthe model improve even when very small batches are acquired. It also decorrelates subsequent\nacquisitions as \ufb01nal model performance is dependent on a particular initialization [10].\nWhen computing p(y| x, \u03c9\u03c9\u03c9,Dtrain), it is important to keep the dropout masks in MC dropout consistent\nwhile sampling from the model. 
This is necessary to capture dependencies between the inputs for BatchBALD, and it makes the scores for different points more comparable by removing this source of noise. We do not keep the masks fixed when computing BALD scores because its performance usually benefits from the added noise. We also do not need to keep these masks fixed for training and evaluating the model.\n\nIn all our experiments, we either compute joint entropies exactly by enumerating all configurations, or we estimate them using 10,000 MC samples, picking whichever method is faster. In practice, we compute joint entropies exactly for roughly the first 4 data points in an acquisition batch and use MC sampling thereafter.\n\n4.1 Repeated MNIST\n\nAs demonstrated in the introduction, naively applying BALD to a dataset that contains many (near) replicated data points leads to poor performance. We show how this manifests in practice by taking the MNIST dataset and replicating each data point in the training set two times (obtaining a training set that is three times larger than the original). After normalising the dataset, we add isotropic Gaussian noise with a standard deviation of 0.1 to simulate slight differences between the duplicated data points in the training set. All results are obtained using an acquisition size of 10 and 10 MC dropout samples. The initial dataset was constructed by taking a balanced set of 20 data points5, two of each class (similar to [13]).\n\nFigure 5: Performance on MNIST. BatchBALD outperforms BALD with acquisition size 10 and performs close to the optimum of acquisition size 1.\n\nFigure 6: Relative total time on MNIST. Normalized to training BatchBALD with acquisition size 10 to 95% accuracy. The stars mark when 95% accuracy is reached for each method.\n\nTable 1: Number of required data points on MNIST until 90% and 95% accuracy are reached. 25%-, 50%- and 75%-quartiles for the number of required data points when available.\n\n              90% accuracy      95% accuracy\nBatchBALD     70 / 90 / 110     190 / 200 / 230\nBALD6         120 / 120 / 170   250 / 250 / >300\nBALD [13]     145               335\n\nOur model consists of two blocks of [convolution, dropout, max-pooling, relu], with 32 and 64 5x5 convolution filters. These blocks are followed by a two-layer MLP that includes dropout between the layers and has 128 and 10 hidden units. The dropout probability is 0.5 in all three locations. This architecture achieves 99% accuracy with 10 MC dropout samples during test time on the full MNIST dataset.\n\nThe results can be seen in figure 2. In this illustrative scenario, BALD performs poorly, and even randomly acquiring points performs better. However, BatchBALD is able to cope with the replication perfectly. In appendix D, we look at varying the repetition number and show that as we increase the number of repetitions BALD gradually performs worse. In appendix E, we also compare with Variation Ratios [11] and Mean STD [21], which perform on par with random acquisition.\n\n4.2 MNIST\n\nFor the second experiment, we follow the setup of Gal et al. [13] and perform AL on the MNIST dataset using 100 MC dropout samples. We use the same model architecture and initial dataset as described in section 4.1. Due to differences in model architecture, hyperparameters and model retraining, we significantly outperform the original results in Gal et al. [13] as shown in table 1.\n\nWe first look at BALD for increasing acquisition size in figure 4a. 
As we increase the acquisition size from the ideal of acquiring points individually and fully retraining after each point (acquisition size 1) to 40, there is a substantial performance drop.\n\nBatchBALD, in figure 4b, is able to maintain performance when doubling the acquisition size from 5 to 10. Performance drops only slightly at 40, possibly due to estimator noise.\n\nThe results for acquisition size 10 for both BALD and BatchBALD are compared in figure 5. BatchBALD outperforms BALD. Indeed, BatchBALD with acquisition size 10 performs close to the ideal with acquisition size 1. The total run time of training these three models until 95% accuracy is visualized in figure 6, where we see that BatchBALD with acquisition size 10 is much faster than BALD with acquisition size 1, and only marginally slower than BALD with acquisition size 10.\n\n5 These initial data points were chosen by running BALD 6 times with the initial dataset picked randomly and choosing the set of the median model. They were subsequently held fixed.\n\n6 Reimplementation using the reported experimental setup.\n\nFigure 7: Performance on EMNIST. BatchBALD consistently outperforms both random acquisition and BALD while BALD is unable to beat random acquisition.\n\nFigure 8: Entropy of acquired class labels over acquisition steps on EMNIST. BatchBALD steadily acquires a more diverse set of data points.\n\n4.3 EMNIST\n\nIn this experiment, we show that BatchBALD also provides a significant improvement when we consider the more difficult EMNIST dataset [6] in the Balanced setup, which consists of 47 classes, comprising letters and digits. The training set consists of 112,800 28x28 images balanced by class, of which the last 18,800 images constitute the validation set. We do not use an initial dataset and instead perform the initial acquisition step with the randomly initialized model. 
We use 10 MC dropout samples.\n\nWe use a similar model architecture as before, but with added capacity: three blocks of [convolution, dropout, max-pooling, relu], with 32, 64 and 128 3x3 convolution filters, and 2x2 max pooling. These blocks are followed by a two-layer MLP with 512 and 47 hidden units, with again a dropout layer in between. We use dropout probability 0.5 throughout the model.\n\nThe results for acquisition size 5 can be seen in figure 7. BatchBALD outperforms both random acquisition and BALD while BALD is unable to beat random acquisition. Figure 8 gives some insight into why BatchBALD performs better than BALD. The entropy of the categorical distribution of acquired class labels is consistently higher, meaning that BatchBALD acquires a more diverse set of data points. In figure 15, the classes on the x-axis are sorted by the number of data points that were acquired of that class. We see that BALD undersamples classes while BatchBALD is more consistent.\n\n4.4 CINIC-10\n\nCINIC-10 is an interesting dataset because it is large (270k data points) and its data comes from two different sources: CIFAR-10 and ImageNet. To get strong performance on the test set it is important to obtain data from both sets. Instead of training a very deep model from scratch on a small dataset, we opt to run this experiment in a transfer learning setting, where we use a pretrained model and acquire data only to finetune the original model. This is common practice and suitable in cases where data is abundant for an auxiliary domain, but is expensive to label for the domain of interest.\n\nFor the CINIC-10 experiment, we use 160k training samples for the unlabelled pool, 20k validation samples, and the other 90k as test samples. We use an ImageNet pretrained VGG-16, provided by PyTorch [26], with a dropout layer before a 512 hidden unit (instead of 4096) fully connected layer. We use 50 MC dropout samples, acquisition size 10 and repeat the experiment for 6 trials. The results are in figure 9, with the 59% mark reached at 1170 for BatchBALD and 1330 for BALD (median).\n\nFigure 9: Performance on CINIC-10. BatchBALD outperforms BALD from 500 acquired samples onwards.\n\n5 Related work\n\nAL is closely related to Bayesian Optimisation (BO), which is concerned with finding the global optimum of a function [33], with the fewest number of function evaluations. This is generally done using a Gaussian Process. A common problem in BO is the lack of parallelism, with usually a single worker being responsible for function evaluations. In real-world settings, there are usually many such workers available, and making optimal use of them is an open problem [14, 2], with some work exploring mutual information for optimising a multi-objective problem [17].\n\nMaintaining diversity when acquiring a batch of data has also been attempted using constrained optimisation [15] and in Gaussian Mixture Models [3]. In AL of molecular data, the lack of diversity in batches of data points acquired using the BALD objective has been noted by Janz et al. [20], who propose to resolve it by limiting the number of MC dropout samples and relying on noisy estimates.\n\nA related approach to AL is semi-supervised learning (also sometimes referred to as weakly-supervised), in which the labelled data is commonly assumed to be fixed and the unlabelled data is used for unsupervised learning [23, 27]. Wang et al. [35], Sener and Savarese [29], and Sinha et al. [28] explore combining it with AL.\n\n6 Scope and limitations\n\nUnbalanced datasets  BALD and BatchBALD do not work well when the test set is unbalanced as they aim to learn well about all classes and do not follow the density of the dataset. 
However, if the\ntest set is balanced, but the training set is not, we expect BatchBALD to perform well.\nUnlabelled data BatchBALD does not take into account any information from the unlabelled\ndataset. However, BatchBALD uses the underlying Bayesian model for estimating uncertainty for\nunlabelled data points, and semi-supervised learning could improve these estimates by providing\nmore information about the underlying structure of the feature space. We leave a semi-supervised\nextension of BatchBALD to future work.\nNoisy estimator A signi\ufb01cant amount of noise is introduced by MC-dropout\u2019s variational approxi-\nmation to training BNNs. Sampling of the joint entropies introduces additional noise. The quality of\nlarger acquisition batches would be improved by reducing this noise.\n\n7 Conclusion\n\nWe have introduced a new batch acquisition function, BatchBALD, for Deep Bayesian Active\nLearning, and a greedy algorithm that selects good candidate batches compared to the intractable\noptimal solution. Acquisitions show increased diversity of data points and improved performance\nover BALD and other methods.\nWhile our method comes with additional computational cost during acquisition, BatchBALD is able\nto signi\ufb01cantly reduce the number of data points that need to be labelled and the number of times\nthe model has to be retrained, potentially saving considerable costs and \ufb01lling an important gap in\npractical Deep Bayesian Active Learning.\n\n9\n\n\fAcknowledgements\n\nThe authors want to thank Binxin (Robin) Ru for helpful references to submodularity and the\nappropriate proofs. We would also like to thank the rest of OATML for their feedback at several\nstages of the project. AK is supported by the UK EPSRC CDT in Autonomous Intelligent Machines\nand Systems (grant reference EP/L015897/1). JvA is grateful for funding by the EPSRC (grant\nreference EP/N509711/1) and Google-DeepMind. 
Funding for computational resources was provided by the Alan Turing Institute and Google.

Author contributions

AK derived the original estimator, proved submodularity and bounds, implemented BatchBALD efficiently, and ran the experiments. JvA developed the narrative and experimental design, advised on debugging, structured the paper into its current form, and pushed it forward at difficult times. JvA and AK wrote the paper jointly.

References

[1] Gediminas Adomavicius and Alexander Tuzhilin. Toward the next generation of recommender systems: A survey of the state-of-the-art and possible extensions. IEEE Transactions on Knowledge & Data Engineering, 2005.
[2] Ahsan S Alvi, Binxin Ru, Jan Calliess, Stephen J Roberts, and Michael A Osborne. Asynchronous batch Bayesian optimisation with improved local penalisation. arXiv preprint arXiv:1901.10452, 2019.
[3] Javad Azimi, Alan Fern, Xiaoli Zhang-Fern, Glencora Borradaile, and Brent Heeringa. Batch active learning via coordinated matching. arXiv preprint arXiv:1206.6458, 2012.
[4] Charles Blundell, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra. Weight uncertainty in neural network. In Proceedings of the 32nd International Conference on Machine Learning, Proceedings of Machine Learning Research, pages 1613–1622, 2015.
[5] Sylvain Calinon, Florent Guenter, and Aude Billard. On learning, representing, and generalizing a task in a humanoid robot. IEEE Transactions on Systems, Man, and Cybernetics, Part B (Cybernetics), 37(2):286–298, 2007.
[6] Gregory Cohen, Saeed Afshar, Jonathan Tapson, and André van Schaik. EMNIST: Extending MNIST to handwritten letters. In 2017 International Joint Conference on Neural Networks (IJCNN), pages 2921–2926. IEEE, 2017.
[7] David A Cohn, Zoubin Ghahramani, and Michael I Jordan. Active learning with statistical models.
Journal of Artificial Intelligence Research, 4:129–145, 1996.
[8] Nguyen Viet Cuong, Wee Sun Lee, Nan Ye, Kian Ming A Chai, and Hai Leong Chieu. Active learning for probabilistic hypotheses using the maximum Gibbs error criterion. In Advances in Neural Information Processing Systems, pages 1457–1465, 2013.
[9] Luke N Darlow, Elliot J Crowley, Antreas Antoniou, and Amos J Storkey. CINIC-10 is not ImageNet or CIFAR-10. arXiv preprint arXiv:1810.03505, 2018.
[10] Jonathan Frankle and Michael Carbin. The lottery ticket hypothesis: Finding sparse, trainable neural networks. In International Conference on Learning Representations, 2019.
[11] Linton C Freeman. Elementary applied statistics: for students in behavioral science. John Wiley & Sons, 1965.
[12] Yarin Gal and Zoubin Ghahramani. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pages 1050–1059, 2016.
[13] Yarin Gal, Riashat Islam, and Zoubin Ghahramani. Deep Bayesian active learning with image data. In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pages 1183–1192. JMLR.org, 2017.
[14] Javier González, Zhenwen Dai, Philipp Hennig, and Neil Lawrence. Batch Bayesian optimization via local penalization. In Artificial Intelligence and Statistics, pages 648–657, 2016.
[15] Yuhong Guo and Dale Schuurmans. Discriminative batch mode active learning. In Advances in Neural Information Processing Systems, pages 593–600, 2008.
[16] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.
[17] Daniel Hernández-Lobato, Jose Hernandez-Lobato, Amar Shah, and Ryan Adams. Predictive entropy search for multi-objective Bayesian optimization.
In International Conference on Machine Learning, pages 1492–1501, 2016.
[18] Steven CH Hoi, Rong Jin, Jianke Zhu, and Michael R Lyu. Batch mode active learning and its application to medical image classification. In Proceedings of the 23rd International Conference on Machine Learning, pages 417–424. ACM, 2006.
[19] Neil Houlsby, Ferenc Huszár, Zoubin Ghahramani, and Máté Lengyel. Bayesian active learning for classification and preference learning. arXiv preprint arXiv:1112.5745, 2011.
[20] David Janz, Jos van der Westhuizen, and José Miguel Hernández-Lobato. Actively learning what makes a discrete sequence valid. arXiv preprint arXiv:1708.04465, 2017.
[21] Alex Kendall, Vijay Badrinarayanan, and Roberto Cipolla. Bayesian SegNet: Model uncertainty in deep convolutional encoder-decoder architectures for scene understanding. arXiv preprint arXiv:1511.02680, 2015.
[22] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
[23] Durk P Kingma, Shakir Mohamed, Danilo Jimenez Rezende, and Max Welling. Semi-supervised learning with deep generative models. In Advances in Neural Information Processing Systems, pages 3581–3589, 2014.
[24] Andreas Krause, Ajit Singh, and Carlos Guestrin. Near-optimal sensor placements in Gaussian processes: Theory, efficient algorithms and empirical studies. Journal of Machine Learning Research, 9(Feb):235–284, 2008.
[25] George L Nemhauser, Laurence A Wolsey, and Marshall L Fisher. An analysis of approximations for maximizing submodular set functions—I. Mathematical Programming, 14(1):265–294, 1978.
[26] Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in PyTorch.
In NIPS Autodiff Workshop, 2017.
[27] Antti Rasmus, Mathias Berglund, Mikko Honkala, Harri Valpola, and Tapani Raiko. Semi-supervised learning with ladder networks. In Advances in Neural Information Processing Systems, pages 3546–3554, 2015.
[28] Samarth Sinha, Sayna Ebrahimi, and Trevor Darrell. Variational adversarial active learning. arXiv preprint arXiv:1904.00370, 2019.
[29] Ozan Sener and Silvio Savarese. Active learning for convolutional neural networks: A core-set approach. In International Conference on Learning Representations, 2018.
[30] Yanyao Shen, Hyokun Yun, Zachary C. Lipton, Yakov Kronrod, and Animashree Anandkumar. Deep active learning for named entity recognition. In International Conference on Learning Representations, 2018.
[31] Aditya Siddhant and Zachary C Lipton. Deep Bayesian active learning for natural language processing: Results of a large-scale empirical study. arXiv preprint arXiv:1808.05697, 2018.
[32] K. Simonyan and A. Zisserman. Very deep convolutional networks for large-scale image recognition. In International Conference on Learning Representations, 2015.
[33] Jasper Snoek, Hugo Larochelle, and Ryan P Adams. Practical Bayesian optimization of machine learning algorithms. In Advances in Neural Information Processing Systems, pages 2951–2959, 2012.
[34] Simon Tong. Active learning: theory and applications. PhD thesis, Stanford University, 2001.
[35] Keze Wang, Dongyu Zhang, Ya Li, Ruimao Zhang, and Liang Lin. Cost-effective active learning for deep image classification. IEEE Transactions on Circuits and Systems for Video Technology, 27(12):2591–2600, 2017.
[36] Raymond W Yeung. A new outlook on Shannon's information measures.
IEEE Transactions on Information Theory, 37(3):466–474, 1991.