{"title": "Hybrid 8-bit Floating Point (HFP8) Training and Inference for Deep Neural Networks", "book": "Advances in Neural Information Processing Systems", "page_first": 4900, "page_last": 4909, "abstract": "Reducing the numerical precision of data and computation is extremely effective in accelerating deep learning training workloads. Towards this end, 8-bit floating point representations (FP8) were recently proposed for DNN training. However, its applicability was demonstrated on a few selected models only and significant degradation is observed when popular networks such as MobileNet and Transformer are trained using FP8. This degradation is due to the inherent precision requirement difference in the forward and backward passes of DNN training. Using theoretical insights, we propose a hybrid FP8 (HFP8) format and DNN end-to-end distributed training procedure. We demonstrate, using HFP8, the successful training of deep learning models across a whole spectrum of applications including Image Classification, Object Detection, Language and Speech without accuracy degradation. Finally, we demonstrate that, by using the new 8 bit format, we can directly quantize a pre-trained model down to 8-bits without losing accuracy by simply fine-tuning batch normalization statistics. These novel techniques enable a new generations of 8-bit hardware that are robust for building and deploying neural network models.", "full_text": "Hybrid 8-bit Floating Point (HFP8) Training and\n\nInference for Deep Neural Networks\n\nXiao Sun\n\nJungwook Choi\u2217\n\nChia-Yu Chen\n\nNaigang Wang\n\nSwagath Venkataramani\n\nVijayalakshmi Srinivasan\n\nXiaodong Cui Wei Zhang\n\nKailash Gopalakrishnan\n\nIBM T. J. Watson Research Center\nYorktown Heights, NY 10598, USA\n\n{xsun, * ,cchen,nwang,swagath.venkataramani,viji,cuix,weiz,kailash}@us.ibm.com\n\nAbstract\n\nReducing the numerical precision of data and computation is extremely effective\nin accelerating deep learning training workloads. 
Towards this end, 8-bit floating point representations (FP8) were recently proposed for DNN training. However, their applicability was only demonstrated on a few selected models, and significant degradation is observed when popular networks such as MobileNet and Transformer are trained using FP8. This degradation is due to the inherent difference in precision requirements between the forward and backward passes of DNN training. Using theoretical insights, we propose a hybrid FP8 (HFP8) format and a DNN end-to-end distributed training procedure. We demonstrate, using HFP8, the successful training of deep learning models across a whole spectrum of applications including Image Classification, Object Detection, Language and Speech without accuracy degradation. Finally, we demonstrate that, by using the new 8-bit format, we can directly quantize a pre-trained model down to 8 bits without losing accuracy by simply fine-tuning batch normalization statistics. These novel techniques enable a new generation of 8-bit hardware that is robust for building and deploying neural network models.

1 Introduction

As Deep Neural Networks (DNNs) evolve rapidly and as models get more complex, training times have increased significantly. To mitigate this challenge, efficient training through reduced-precision exploitation has become increasingly important. Using reduced precision for data representations and general matrix multiplications (GEMM) can accelerate DNN training dramatically and save significant computing time and power. Indeed, GPUs can already perform mixed-precision training with 16-bit IEEE Half-Precision floating point formats for deep learning tasks [1]. Recently, a new (1-5-2) (sign-exponent-mantissa) floating-point 8-bit format (FP8) was used to successfully train popular ImageNet models [2] without much accuracy loss.
In addition, 8-bit fixed point formats (INT8) have also been explored to train ResNet50 successfully, although 1 of the 3 GEMM computations was performed in higher precision [3]. In addition to DNN training, efficient low-precision deployment is critical in a wide range of edge inference use cases where cost and energy constraints can limit performance [4].

*Contributed to this work while at IBM; currently with the Electrical Engineering Department, Hanyang University, South Korea. Email: choij@hanyang.ac.kr

33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.

Towards that end, Trans-Precision inference, where models are trained in higher precision and deployed in lower precision formats, has become extremely important [5, 6, 7]. While 8-bit training techniques have progressed rapidly, recent work [2, 3, 8, 9] has only demonstrated their applicability on a small subset of deep learning models—focused around convolution networks such as ResNet [10]. Indeed, a plethora of challenges exists in extending FP8 training to a broader spectrum of applications such as image classification, object detection, speech and natural language processing while preserving model accuracy. Furthermore, in large-scale distributed training systems, FP8 acceleration of GEMM and Convolution operations within each learner makes the communication between learners at the weight update step a critical bottleneck. Alleviating this bottleneck using 8-bit communication schemes could substantially improve the end-to-end training performance for distributed DNN training.
In addition, for low-precision inference, fixed-point techniques involving costly retraining of networks for ultra-short bit-widths [11, 12, 13], as well as post-training quantization for simpler deployment of INT8/INT4 inference models [7, 14, 15, 16], have been extensively explored, but the state-of-the-art techniques still lose significant model accuracy when applied to compact models like MobileNet [17] on large datasets (e.g., ImageNet). In comparison to the fixed-point representation, FP8-based schemes have a wider dynamic range and do not need to find the right quantization range for each layer and channel—making them a more natural fit for post-training quantization.

Figure 1: Challenges in the previous FP8 (1-5-2) format. (a) FP8 training on MobileNetV2 with different width multipliers—showcasing significant accuracy degradation from FP8 training for capacity-constrained models. (b) FP8 training on Transformer-based machine translation—large loss in BLEU scores. (c) Table for Trans-Precision inference from FP32 to FP8—large accuracy loss observed in various domains.

1.1 Challenges and Related Works

Training and inferencing smaller capacity models in 8-bit: A key challenge in (1-5-2) FP8 training is the ability to train smaller capacity models without losing accuracy. While networks such as VGG [18] and ResNet [10] can be trained effectively, models like MobileNetV2, which has 1/7th the capacity of ResNet50, suffer significant degradation (~1%) if trained in FP8, as shown in Fig. 1(a). This training problem is further exacerbated when we reduce the layer width of MobileNetV2 by 0.5 and 0.35—resulting in 2~3% degradation. Furthermore, as discussed in the previous section, inferencing on post-training quantized MobileNetV2 models [7] using INT8 and INT4 formats results in significant accuracy degradation (>2%).
Recent work has identified the small variance of the depthwise convolution layers as the cause of such degradation [14], which has been partly addressed in [15] by retraining the networks with an adaptable dynamic range. Techniques that can resolve this low-precision training challenge for smaller capacity models and simultaneously avoid the issues in post-training quantization can be extremely important for Edge deployment use cases.

Applicability of 8-bit Training to Other Domains: In the Natural Language Processing (NLP) and Speech domains, popular networks built on LSTMs and Transformer blocks perform simple matrix multiplications using fully-connected (FC) layers rather than convolution operations. Training these networks using FP8 has proven to be a significant challenge. As an example, we've observed slow convergence and lower BLEU scores on the Transformer model on the WMT14 dataset trained using (1-5-2) FP8 precision, as shown in Fig. 1(b). Furthermore, in many state-of-the-art Speech and language models, the last FC layer has a very large dimension—corresponding to the vocabulary size (typically 10-100 times larger than ImageNet) [19, 20]. As a result, the last FC layer consumes a significant fraction (>20-30%) of the total computation time. Currently, 8-bit training solutions customized for convolution nets relax the last-layer precision by keeping that layer in 16-bit (FP16), since last layers computed in FP8 have been shown to increase classification error rates [2]. However, this solution is expensive for NLP and Speech tasks, which have large last layers. In addition, Object Detection and semantic segmentation networks such as MaskRCNN [21] and SSD [22], which load a pre-trained ImageNet backbone model and fine-tune it with an Object Detection dataset, have not been investigated within the framework of 8-bit training.
Finally, Trans-Precision inference in (1-5-2) FP8 (directly converted from FP32 trained models) results in significant accuracy degradation in many of these models, as shown in the table of Fig. 1(c). The goal of this paper is to enable an 8-bit training methodology that addresses all of the above challenges.

8-bit weight updates: The weight update phase of low-precision training requires a master copy of weights in higher precision to accumulate gradients across minibatches without losing critical information due to "swamping" [23]. In INT8 training, FP32 is used for this master copy, resulting in increased latency due to the bandwidth needed for communication and AXPY (Y = AX + Y) computation in 32 bits. In FP8 training, even with stochastic rounding techniques, 16-bit (FP16) weights are still needed for the master copy to preserve convergence [2]. As a solution to this problem, weight averaging has been proposed to facilitate the exchange of 8-bit weights (while keeping higher precision weights locally). This scheme, however, results in >4% accuracy degradation on ResNet18/ImageNet [24]. An ideal weight update scheme should compress gradients and only compute and communicate 8-bit weights during the training process.

1.2 Contributions

In this paper, we introduce a new hybrid FP8 format and technique that is applicable to both computations (training and inference) and communication to address all of these challenges. In comparison to the state-of-the-art FP8 training and INT8 inference solutions, our primary contributions include:

1. A novel hybrid FP8 format that uses 4 exponent bits and 3 mantissa bits (1-4-3 with an exponent bias) for forward propagation and 5 exponent bits and 2 mantissa bits (1-5-2) for backward propagation—achieving negligible accuracy degradation on previously problematic models including MobileNetV2 and Transformer.

2.
Demonstrated the robustness of the HFP8 format on a wide spectrum of DNN tasks including Image Classification, Object Detection, NLP and Speech—while fully preserving accuracy.

3. Through theoretical analysis, we've identified BN statistics as the primary reason for accuracy loss in low-precision Trans-Precision inference and show that BN statistics can be fine-tuned to fully recover model accuracy while using our 1-4-3 FP8 precision.

4. Introduced a deterministic FP8 weight update scheme that can converge to baseline accuracies without using stochastic rounding, along with a compatible all-reduce technique that takes advantage of low bit-width weights to speed up distributed learning.

2 New Hybrid FP8 (HFP8) Formats and Computations

2.1 Impact of FP8 Formats on Trans-Precision Inference (Post-Training Quantization)

In this section, we explore how different FP8 precision formats for activations and weights impact Trans-Precision inference accuracy. Towards that end, we adapt the theoretical framework of Sakr et al. [25] to quantify the mismatch probability between a reduced-precision neural network and its full-precision counterpart. Consider a neural network for a classification task such as MobileNetV2, with quantized weights (W + q_W) and activations (A + q_A), where each numerical output (z_i) after the feedforward pass may be corrupted by a quantization noise (q_{z_i}). Using Taylor's expansion and ignoring the cross-products of quantization noise terms, the total quantization noise q_{z_i} can be expressed as [25]:

q_{z_i} = Σ_{a_h ∈ A} q_{a_h} (∂z_i/∂a_h) + Σ_{w_h ∈ W} q_{w_h} (∂z_i/∂w_h),    (1)

where A and W are index sets.
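For a single linear layer z = W·a, the gains in Eqn. 1 reduce to ∂z_i/∂a_h = w_ih and ∂z_i/∂w_ih = a_h, so the first-order noise estimate can be checked numerically. Below is a toy sketch; the quantizer is a simplified FP8 rounding written for illustration, not the paper's exact implementation:

```python
import numpy as np

def fp8(x, ebit, mbit, bias=0):
    """Round x to a simplified (sign, ebit, mbit) FP8 value with an exponent bias."""
    x = np.asarray(x, dtype=np.float64)
    emax = 2 ** (ebit - 1) - bias                 # largest exponent
    emin = -(2 ** (ebit - 1)) - bias + 1          # smallest (sub)normal exponent
    a = np.abs(x)
    e = np.clip(np.floor(np.log2(np.where(a > 0, a, 1.0))), emin, emax)
    step = 2.0 ** (e - mbit)                      # quantization step within each binade
    q = np.round(a / step) * step
    q = np.minimum(q, (2.0 - 2.0 ** -mbit) * 2.0 ** emax)  # clamp to the max value
    return np.sign(x) * q

rng = np.random.default_rng(0)
W = rng.normal(size=(16, 32))
a = rng.normal(size=32)
Wq, aq = fp8(W, 4, 3, bias=4), fp8(a, 4, 3, bias=4)

# Exact output noise vs. the first-order estimate of Eqn. 1
err_exact = Wq @ aq - W @ a
err_eqn1 = W @ (aq - a) + (Wq - W) @ a     # q_a * dz/da terms + q_w * dz/dw terms
ratio = np.linalg.norm(err_exact - err_eqn1) / np.linalg.norm(err_exact)
print(ratio)   # the neglected cross-product term is small
```

The discrepancy between the two error vectors is exactly the cross-product term (Wq − W)(aq − a) that Eqn. 1 drops, which is one quantization-noise factor smaller than the retained terms.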
By evaluating the probability that any pair of outputs (z_i < z_j) is flipped due to quantization errors, Pr(z_i + q_{z_i} > z_j + q_{z_j}), the mismatch probability p_m between the reduced-precision network and its full-precision baseline yields an upper bound—defined by the quantization error of each activation and weight multiplied by the corresponding gradients, called "gain". As the gains are network specific, we can evaluate them empirically using Eqn. 1.

Fig. 2 shows the computed mismatch probability due to activation and weight quantizations for each layer of the MobileNetV2 (CIFAR-10) model. The results clearly show that by moving just one bit from the exponent (1-5-2) to the mantissa (1-4-3), the mismatch probability corresponding to both activations and weights decreases dramatically. This improvement comes from the fact that weights and activations are represented with higher fidelity using the extra mantissa bit.

However, since the total bit-width is limited to 8, reduction in the exponent bit-width can result in clamping of large weights and activations and/or truncation of small values to the minimum representable value in (1-4-3). Given the typical numerical distribution of these tensors during training, we found that underflow represents a more serious concern. To mitigate this effect, we introduce a fixed exponent bias that shifts the coverage range of the (1-4-3) FP8 format to

[2^(−2^(ebit−1) − bias + 1), ((2^(mbit+1) − 1)/2^mbit) × 2^(2^(ebit−1) − bias)].

By choosing an exponent bias of 4, we intend to better align the (1-4-3) format with the distributions of activations and weights seen in a wide range of DNN layers and models. As verified in Fig. 2, introducing an extra bias of 4 on the exponent further reduces the impact of quantization—specifically on the weights in the lower layers, which appear to have much smaller magnitudes.
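The clamping and truncation limits implied by the coverage-range expression above can be made concrete with a small helper (a sketch; the helper name is ours):

```python
def fp8_limits(ebit, mbit, bias=0):
    """Smallest normal and largest representable magnitudes of a (1, ebit, mbit) format."""
    emax = 2 ** (ebit - 1) - bias
    emin = -(2 ** (ebit - 1)) - bias + 1
    max_val = (2 ** (mbit + 1) - 1) / 2 ** mbit * 2.0 ** emax
    min_val = 2.0 ** emin
    return min_val, max_val

# (1-4-3) without and with the extra exponent bias of 4
print(fp8_limits(4, 3, bias=0))   # (0.0078125, 480.0)    i.e. [2**-7, 480]
print(fp8_limits(4, 3, bias=4))   # (0.00048828125, 30.0) i.e. [2**-11, 30]
```

Moving the bias from 0 to 4 trades the 480 clamping ceiling for a lower 2^−11 floor, matching the numbers quoted in the text.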
In the (1-4-3) FP8 with bias=4 format, we reduce the maximum clamping value from 480 down to 30, large enough to cover the wide variety of networks that we have investigated. In exchange, we are able to represent smaller activations and weights down to 2^−11 (much lower than the 2^−7 truncation threshold in 1-4-3 without bias). For simplicity of notation, all the following (1-4-3) FP8 experiments have a default exponent bias of 4. These experiments indicate that the 5-bit exponent range of 2^−15 to 2^16 is overkill for DNN inference, and that 4-bit exponents with a bias of 4 have sufficient range and fidelity to represent activations and weights for both training and Trans-Precision inference. Finally, we've verified that these conclusions extend to a large number of neural network models and topologies.

Figure 2: The layer-wise decomposition of mismatch probability (Eqn. 1) for (a) activation and (b) weight quantizations on a MobileNetV2 (CIFAR-10) model (excluding the first and last layers, which are in full precision). (1-5-2) results in higher activation errors compared to (1-4-3) with or without bias=4. (1-4-3) with bias=4 shows the lowest mismatch thanks to the extra fidelity needed for representing small weight values near the network output.

2.2 Impact of FP8 Formats on Model Training Accuracy

In addition to increasing mismatch probability, we note that quantization noise also degrades the Lipschitz property of loss surfaces; that is, the loss changes at a faster rate, and the magnitudes of the gradients are larger too. In Fig. 3, we plot (a) the loss surfaces of an FP32 trained model and (b) a (1-5-2) FP8 trained model along two random directions, with their coefficients scanned along the x and y axes [26].
The loss surface of the (1-5-2) trained model shows multiple saddle points and appears rougher—making gradient-descent-based training unstable, as evidenced by the kinks in Fig. 3(b). The mitigation of such kinks has also been used to explain the effectiveness of Batch Normalization [27]. In contrast, by increasing the number of mantissa bits from 2 to 3 for the forward pass only (while keeping gradients and errors in 1-5-2), the loss surface appears significantly improved in Fig. 3(c), implying easier optimization. On the other hand, comparing the loss surfaces for training and test, we can see that FP8 quantization does not impact generalization.

Guided by these insights, we propose our Hybrid FP8 (HFP8) formats, utilizing two different FP8 formats to customize the precision separately for the forward and backward passes of DNN training—

Figure 3: The loss surfaces of models trained in different precisions: (a) FP32, (b) FP8 (all GEMMs in 1-5-2, [2]), (c) HFP8 (1-4-3 only for the forward pass). The top row shows loss surfaces from training data, the bottom row from test data. The loss surfaces with HFP8 maintain good Lipschitz properties compared to FP32, while the loss surfaces with FP8 exhibit multiple saddle points which hinder training convergence.

improving the performance on training and Trans-Precision inference. The underlying reason for this choice is that forward and backward passes have different optimal balances between range and precision. While tensors in the forward pass prefer higher precision (and lower representational error), gradients in the backward pass prefer a higher dynamic range. We describe our HFP8 training methodology where weights and activations adopt the (1-4-3) format (bias=4) while tensors used in backpropagation continue to be represented using the (1-5-2) format (in combination with loss scaling techniques pioneered by [28]) (see Fig. 1 in Appendix A).
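The two casts can be sketched as follows (a minimal illustration with hypothetical helper names; the simplified rounding and the fixed loss scale stand in for a real HFP8 kernel and for the loss scaling of [28]):

```python
import math

def quantize(x, ebit, mbit, bias=0):
    """Simplified rounding to a (1, ebit, mbit) floating-point format."""
    if x == 0.0:
        return 0.0
    emax = 2 ** (ebit - 1) - bias
    emin = -(2 ** (ebit - 1)) - bias + 1
    a = abs(x)
    e = min(max(math.floor(math.log2(a)), emin), emax)
    step = 2.0 ** (e - mbit)
    q = min(round(a / step) * step, (2.0 - 2.0 ** -mbit) * 2.0 ** emax)
    return math.copysign(q, x)

def hfp8_forward(x):
    # weights/activations: (1-4-3) with exponent bias 4
    return quantize(x, 4, 3, bias=4)

def hfp8_backward(g, loss_scale=2 ** 13):
    # errors/gradients: (1-5-2), with a loss scale so small gradients survive
    return quantize(g * loss_scale, 5, 2) / loss_scale

tiny_grad = 1e-7
print(quantize(tiny_grad, 5, 2))   # 0.0 -- underflows without scaling
print(hfp8_backward(tiny_grad))    # nonzero: scaling keeps it representable
```

The example shows why the backward format alone is not enough: a 1e-7 gradient underflows even the (1-5-2) range, but survives once the loss is pre-scaled and the gradient is unscaled afterwards.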
Our experiments show that this simple change significantly improves the performance of both MobileNet and Transformer models (as shown in Fig. 4(a) and (b)), in comparison to forward (1-5-2) based results that showed significant degradation in Fig. 1. In the following sections, we will show that this improvement is universally applicable to both training and Trans-Precision inference (training results for Speech and Object Detection models are shown in Fig. 4(c) and (d)).

For errors and gradients in HFP8 back-propagation, we employ the (1-5-2) FP8 format, which has proven to be optimal across various deep learning tasks. However, unlike activations and weights, even 5-bit exponents are insufficient to represent the wide dynamic range seen in activation gradients. Therefore, loss scaling has been adopted to enable gradients to become large enough to be representable using the (1-5-2) format [2]. Nonetheless, it is infeasible to seek a unique scaling factor that fits a wide range of different models and datasets. Towards that end, we adapted auto-adjusted scale factors for gradients and errors during HFP8 training using Apex [28] (details in Appendix B). Finally, through hardware design experiments, we've confirmed that floating-point units (FPUs) that can support both formats are only 5% larger than the original FPUs that support only 1-5-2.

Figure 4: Training curves using HFP8 on (a) MobileNetV2 with different width multipliers (sizes), (b) Transformer-base machine translation, (c) LSTM-based Speech model for the SWB300 dataset and (d) Mask R-CNN model. No significant loss in accuracy is observed across DNN layer types, models and datasets.
Final training results on a diverse set of models are summarized in Table 4.

2.3 Last Layer Precision and the SoftMax Function

For networks with large output dimensions (typically seen in Speech and NLP), the last FC and SoftMax layers contribute a significant fraction of the total computation time due to large matrix multiplications and expensive exponential functions (especially if these need to be computed in FP16). In these layers, it therefore becomes critical to be able to use 8-bit computations.

First, we note that when (1-4-3) FP8 is used along with (1-6-9) FP16 output precision, no degradation on the LSTM-based SWB300 and Transformer-based translation tasks is observed. In contrast, when the output precision of the FC layer is set to (1-4-3) as well, a large loss in accuracy is observed (the network diverges on SWB300, and ~1 BLEU degradation is seen on WMT En-De). This occurs because the largest outputs of the last FC layer may be quantized into the same values (bins) during conversion from 16 to 8 bits and therefore become indistinguishable to the ensuing SoftMax layer. Fig. 5(a) and (b) show the distribution of output activations before and after FP8 (1-4-3) quantization in the Transformer model for WMT En-De translation (d_out = 42720), showing that the largest numbers are poorly represented by 8 bits in Fig. 5(b). Interestingly, we discovered that if the quantization step is performed after the max subtraction step (i.e., x − x_max) in SoftMax, this degradation in accuracy can be fully eliminated. In Fig. 5(c), the x − x_max sub-step of SoftMax moves the largest values closest to 0, where data representation is strongest due to the non-uniform nature of floating point representation. Furthermore, this technique also allows SoftMax to be performed using just 8 bits. Detailed discussions on reduced-precision SoftMax will be a focus of future work.
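The effect of the quantization order can be reproduced with two near-tied logits close to the top of the (1-4-3) range (hypothetical values; the quantizer is a simplified rounding, not the paper's exact kernel):

```python
import math

def fp8_143(x, bias=4, ebit=4, mbit=3):
    """Simplified (1-4-3) FP8 rounding with an exponent bias."""
    if x == 0.0:
        return 0.0
    emax = 2 ** (ebit - 1) - bias
    emin = -(2 ** (ebit - 1)) - bias + 1
    a = abs(x)
    e = min(max(math.floor(math.log2(a)), emin), emax)
    q = min(round(a / 2.0 ** (e - mbit)) * 2.0 ** (e - mbit),
            (2.0 - 2.0 ** -mbit) * 2.0 ** emax)
    return math.copysign(q, x)

logits = [14.0, 14.2, 13.0]            # two near-ties among large FC outputs

# Quantizing raw logits: near 14 the (1-4-3) step is 1.0, so the top-two
# logits collapse into the same bin and become indistinguishable to SoftMax.
direct = [fp8_143(z) for z in logits]
print(direct)                           # [14.0, 14.0, 13.0]

# Quantizing after the x - x_max sub-step: the largest values sit near 0,
# where floating-point spacing is finest, so they stay distinct.
m = max(logits)
shifted = [fp8_143(z - m) for z in logits]
print(shifted)
```

After the max subtraction the winner (0.0) and the runner-up (about −0.2) land in different bins, preserving the SoftMax ordering in 8 bits.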
Overall, the (1-4-3) HFP8 format in the last FC layer, when combined with an output precision of 1-6-9 and the max-subtracted SoftMax function, allows for efficient end-to-end HFP8 computations.

Figure 5: Output activation distributions of the last FC layer and their quantization (to 1-4-3 FP8) in the Transformer model before and after subtracting the max output. Quantization after subtracting the max output allows the largest inputs of SoftMax to be represented with high fidelity in 8 bits and fully preserves model accuracy.

3 Trans-Precision Inference in FP8

Guided by the theoretical framework in Section 2.1, we investigate how inference accuracies are impacted when FP32 trained models are used directly with different FP8 formats for inference (i.e., without any retraining). Using MobileNetV2 trained on ImageNet as an example, we immediately observe that the (1-4-3) FP8 format is significantly more accurate than the (1-5-2) format—as shown in the first two rows of the table in Fig. 6(a). This is consistent with the mismatch-probability-based predictions described earlier. However, even with the right FP8 format, we observe that we lose >5% in model accuracy in comparison to FP32. To reduce this gap, we provide two additional insights. The first key insight comes from a theoretical understanding of how quantization errors in weights and activations directly impact the accuracy of the outputs of the succeeding Batch Normalization (BN) layer. Re-tuning the statistics (mean and variance) of the BN layer for the precision of interest (i.e., the inference precision) has the potential to significantly reduce this error.
As shown in Eqn. 2, the quantization error at the output of a BN layer (Z − Z_Q) can be expressed in terms of the variance of the quantization error at the BN input, σ²_Q, and the variance of the precise input, σ²_Y—assuming Q and Y are not correlated (please see Appendix C for a detailed derivation):

E[‖Z − Z_Q‖²] ≅ γ² σ²_Q/σ²_Y,                       (original BN statistics)
E[‖Z − Z_Q‖²] ≅ 2γ² (1 − 1/√(1 + σ²_Q/σ²_Y)),       (re-tuning BN statistics)    (2)

Table 1: FP8 Trans-Precision inference for FP32 trained models after BN re-tuning (if applicable)

Model (Dataset)                          | Baseline (FP32) | 1-5-2 Inference | 1-4-3 Inference
ResNet18 (ImageNet)                      | 69.32           | 68.93           | 68.99
ResNet50 (ImageNet)                      | 76.44           | 75.89           | 76.46
DenseNet121 (ImageNet)                   | 74.76           | 74.40           | 74.78
AlexNet (ImageNet)                       | 57.10           | 56.87           | 57.07
MobileNetV2 (ImageNet)                   | 71.81           | 70.31           | 71.37
4-bidirectional-LSTM Speech (SWB300)^a   | 9.90            | 10.10           | 9.90
Transformer-base (WMT14 En-De)^b         | 27.53           | 27.06           | 27.47
SSD-Lite (VOC)^c                         | 68.79           | 67.40           | 68.22
MaskRCNN (COCO)^d                        | 33.58/29.27     | 32.83/28.65     | 33.43/29.10

^a Word Error Rate. ^b BLEU score. ^c mean average precision (mAP). ^d Box/Mask average precision.

Plotting this equation in Fig. 6(b), we observe that E[‖Z − Z_Q‖²] increases linearly with σ²_Q when preserving the original BN parameters, but increases only sub-linearly when BN statistics are re-tuned. This reveals the fact that the impact of quantization at the BN output is suppressed once BN statistics are properly tuned. As shown in Rows 5 and 6 of Fig. 6(a), re-tuning BN statistics using just 2% of a single epoch of the training dataset reduces this accuracy gap significantly.

The second key insight obtained using Eqn. 2 is that layers that have very low σ²_Y have vastly magnified output errors.
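Both branches of Eqn. 2 can be reproduced with a short Monte-Carlo check, modeling the quantization noise Q as additive Gaussian noise (γ, σ_Y, σ_Q below are hypothetical values):

```python
import numpy as np

rng = np.random.default_rng(0)
gamma, sig_y, sig_q, n = 1.5, 1.0, 0.5, 500_000

y = rng.normal(0.0, sig_y, n)            # precise BN input Y
yq = y + rng.normal(0.0, sig_q, n)       # quantized input Y + Q

z = gamma * (y - y.mean()) / y.std()               # reference BN output
z_orig = gamma * (yq - y.mean()) / y.std()         # original BN statistics
z_ret = gamma * (yq - yq.mean()) / yq.std()        # re-tuned BN statistics

r = sig_q ** 2 / sig_y ** 2
mse_orig = np.mean((z - z_orig) ** 2)    # ~ gamma^2 * r            (Eqn. 2, top)
mse_ret = np.mean((z - z_ret) ** 2)      # ~ 2*gamma^2*(1-1/sqrt(1+r))  (bottom)
print(mse_orig, gamma ** 2 * r)
print(mse_ret, 2 * gamma ** 2 * (1 - 1 / np.sqrt(1 + r)))
```

Re-tuned statistics always give the smaller error: for small r = σ²_Q/σ²_Y, the second branch expands to 2γ²(r/2 − 3r²/8 + ...) ≈ γ²(r − 3r²/4), which is below the linear γ²r of the first branch.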
Plotting σ²_Y as a function of layer number for MobileNetV2 (Fig. 6(c)) leads us to note that depthwise (DW) convolution layers produce activations that have orders-of-magnitude smaller variance (σ²) in comparison to traditional convolution layers [14]. We therefore expect the precision setting in these layers to strongly impact Trans-Precision inference accuracies. Since DW layers contribute <3% of the overall compute in the MobileNet family of networks [17], we recommend setting the precision in these layers uniformly to FP16. As can be seen from Rows 3, 4 (without BN re-tuning) and Rows 7, 8 (with BN re-tuning), this precision setting in the DW layers substantially improves inference accuracies to within ~1% of the baseline.

A combination of these techniques—(a) picking the right precision format (1-4-3) for the weights and activations of convolution and FC layers, (b) setting the precision for DW layers to FP16, and (c) updating the BN µ and σ² with minimal training data—allows MobileNetV2 to hit accuracies within ~0.5% of the full-precision baseline. Furthermore, we show that these techniques extend very well to other models; as shown in Table 1, FP8 1-4-3 with BN re-tuning can fully recover the baseline inference accuracies for the entire spectrum of networks studied. Note that the data used for BN re-tuning does not need to be labeled, so re-tuning can be done on edge devices.

Figure 6: (a) Trans-Precision inference accuracies using the FP32 MobileNetV2 model in two different FP8 formats. Rows 5-8 further re-tune BN µ and σ² using 2% of the training data while using FP8 precision. (b) Visualization of the quantization error (Eqn. 2) at the BN output vs.
the quantization error variance at the BN input for two different variances of the precise BN input. (c) The variances of the precise BN input for 52 convolution layers in MobileNetV2, showing DW layers having orders-of-magnitude smaller σ².

4 Hybrid FP8 Distributed Training Scheme and Results

As DNN compute functions are accelerated in each learner using HFP8, the communication cost among learners and memory-bandwidth-dominated tasks like weight updates become bottlenecks.

Table 2: Ring-based communication schemes among N learners using Hybrid FP8

Common ring-based weight update:
(1) reduce-scatter gradients
(2) all-gather gradients
(3) locally update full size of weights

Proposed ring-based weight update (data format sent):
(1) reduce-scatter gradients (FP16, 1-6-9)
(2) locally update 1/N gradients to weights (local)
(3) all-gather weights (FP8, 1-4-3)

Table 3: Round-off residual based Hybrid FP8 weight update (per worker)

For each 1/N weight: R_0 ← 0                          (initialize round-off residual)
For timestep t, for each 1/N weight:
  Ŵ_t ← W_{t−1} − α_{t−1} g(W_{t−1}) − R_{t−1}         (apply gradients and carried-on residuals)
  W_t ← Q_W(Ŵ_t)                                       (quantize new weights)
  R̂_t ← W_t − Ŵ_t                                      (overwrite residuals)
  R_t ← Q_R(R̂_t)                                       (quantize residuals; Q_R has higher precision than Q_W)

Table 4: Baseline vs.
Hybrid FP8 training on Image, Language, Speech and Object-Detection models

Model (Dataset) [metric]                       | Baseline (FP32) | HFP8 + Round-off update
AlexNet (ImageNet)                             | 57.28           | 57.21
ResNet18 (ImageNet)                            | 69.38           | 69.39
ResNet50 (ImageNet)                            | 76.44           | 76.22
MobileNetV2 (ImageNet)                         | 71.81           | 71.61
DenseNet121 (ImageNet)                         | 74.76           | 74.65
2-LSTM (PennTreeBank) [Test ppl.]              | 83.66           | 83.86
Transformer-base (WMT14 En-De) [BLEU]          | 27.50           | 27.27
4-bidirectional-LSTM Speech (SWB300) [WER]     | 9.90            | 10.00
MaskRCNN (ResNet50) (COCO) [Box/Mask AP]       | 33.58/29.27     | 33.06/28.86
SSD-Lite (MobileNetV2) (VOC) [mAP]             | 68.79           | 68.72

Hardware performance estimations indicate that this communication could take up to ~41-62% of the end-to-end training time for ResNet50 with HFP8 GEMM (see Appendix D for details). As illustrated in Table 2, the conventional communication pattern used in deep learning algorithms exchanges gradients through ring-based all-reduce [29, 30], and then each learner updates the whole model locally. To take advantage of 8-bit weights in off-chip communication, as well as to minimize local memory bandwidth, we modify the existing distributed learning scheme slightly—so that each of N learners updates only 1/Nth of the model after the reduce-scatter phase, minimizing local memory transactions. When updating the model globally, the final 8-bit weights produced in each learner are distributed in the all-gather phase, thereby improving off-chip communication by 2× compared to conventional 16-bit gradient communication.

To improve the robustness of low-precision weight updates and to prevent "swamping" [2], we propose a deterministic round-off residual update scheme that stores the weights in 8 bits while saving the quantization errors locally as "round-off" residuals in FP16, as illustrated in Table 3.
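The update of Table 3 can be exercised on a toy one-parameter problem. When the learning-rate-times-gradient step is smaller than the FP8 weight spacing, a plain quantized update stalls ("swamping"), while the carried residual keeps making progress. A sketch, with Q_W a simplified (1-4-3, bias 4) rounding and Q_R an FP16-like (1-5-10) rounding:

```python
import math

def quantize(x, ebit, mbit, bias=0):
    """Simplified rounding to a (1, ebit, mbit) floating-point format."""
    if x == 0.0:
        return 0.0
    emax = 2 ** (ebit - 1) - bias
    emin = -(2 ** (ebit - 1)) - bias + 1
    a = abs(x)
    e = min(max(math.floor(math.log2(a)), emin), emax)
    q = min(round(a / 2.0 ** (e - mbit)) * 2.0 ** (e - mbit),
            (2.0 - 2.0 ** -mbit) * 2.0 ** emax)
    return math.copysign(q, x)

q_w = lambda x: quantize(x, 4, 3, bias=4)   # 8-bit weight format Q_W
q_r = lambda x: quantize(x, 5, 10)          # FP16-like residual format Q_R

target, lr = 1.0, 0.01                      # minimize (w - target)^2
w, r = 2.0, 0.0                             # Table 3 scheme: weight + residual
w_naive = 2.0                               # plain FP8 update, no residual

for _ in range(400):
    w_hat = w - lr * 2 * (w - target) - r   # apply gradient and carried residual
    w = q_w(w_hat)                          # W_t <- Q_W(W_hat_t)
    r = q_r(w - w_hat)                      # R_t <- Q_R(W_t - W_hat_t)
    w_naive = q_w(w_naive - lr * 2 * (w_naive - target))

print(w_naive)   # 2.0 -- every small step rounds back (swamping)
print(w)         # close to 1.0 -- the residual carries the lost updates
```

Subtracting R_{t−1} before quantizing means the unrounded trajectory Ŵ_t evolves exactly as a full-precision update would, which is why no stochastic rounding is needed.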
We study this round-off residual scheme on a wide range of DNN applications and show that it does not impact convergence (consistent with the rich body of theoretical work in this space [31, 32]). With 8-bit weight updates and the modified ring-distribution scheme, our technique improves end-to-end training time by 32-38% on ResNet50 (for details, see Appendix D).

Finally, to demonstrate the wide applicability and robustness of the HFP8 formats, 8-bit computations and the round-off residual scheme, we tested them on a wide spectrum of deep learning models and datasets without changes to network architectures, data pre-processing, or hyper-parameters (details in Appendix E). As shown in Table 4, every network tested achieved accuracy very close to the full-precision baseline, including tasks that were problematic for previous FP8 endeavors (such as MobileNet and Transformers). More complex and challenging tasks, such as Object Detection, Speech and Machine Translation, are demonstrated in HFP8 and, for the first time, show performance within 0.5% of the full-precision baseline on large networks and datasets. Given the limited computational complexity of the first and last layers, we set the precision in these layers to FP16, except for the Speech and Transformer networks, where we use the same HFP8 settings on the large final FC layer and find no degradation.

5 Conclusions

We have demonstrated DNN training with a new Hybrid FP8 format that adopts two different FP8 formats for forward and backward propagation. In addition, we introduced a novel round-off residual scheme which significantly improves the robustness of low-precision AXPY and reduces communication bandwidth requirements. We've confirmed the superior accuracy of this approach over previous 8-bit training proposals on a wide range of state-of-the-art DNN models.
In addition, we've presented new insights into Batch Normalization and depthwise Convolutional layers that demonstrate how the same FP8 format can be used for highly accurate Trans-Precision inference (starting from higher-precision FP32 models). These novel techniques enable a new generation of 8-bit hardware systems that are robust for the entire spectrum of DNN training and inference applications.

Acknowledgments

The authors would like to thank Lam Nguyen and Charbel Sakr for helpful theoretical discussions; Anthony Giordano, I-Hsin Chung, Ming-Hung Chen, Paul Crumley, Kaoutar El Maghraoui, and Jeffrey Burns for the computing infrastructure; and Leland Chang, Sunil Shukla, Silvia Mueller, Ankur Agrawal, Jinwook Oh, Marcel Schaal, Mauricio Serrano, Wei Wang, and the team for the chip platform targeted in this work. This research is realized by generous collaborations across IBM Research.

References

[1] Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich Elsen, David Garcia, Boris Ginsburg et al. "Mixed precision training." arXiv preprint arXiv:1710.03740, 2017.

[2] Naigang Wang, Jungwook Choi, Daniel Brand, Chia-Yu Chen, and Kailash Gopalakrishnan. "Training deep neural networks with 8-bit floating point numbers." In Advances in Neural Information Processing Systems, pp. 7675-7684. 2018.

[3] Ron Banner, Itay Hubara, Elad Hoffer, and Daniel Soudry. "Scalable methods for 8-bit training of neural networks." In Advances in Neural Information Processing Systems, pp. 5145-5153. 2018.

[4] Yann LeCun. "1.1 Deep Learning Hardware: Past, Present, and Future." In 2019 IEEE International Solid-State Circuits Conference (ISSCC), pp. 12-19. IEEE, 2019.

[5] Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam, and Dmitry Kalenichenko.
\"Quantization and training of neural networks for ef\ufb01cient integer-\narithmetic-only inference.\" In Proceedings of the IEEE Conference on Computer Vision and Pattern\nRecognition, pp. 2704-2713. 2018.\n\n[6] Norman P. Jouppi, Cliff Young, Nishant Patil, David Patterson, Gaurav Agrawal, Raminder Bajwa, Sarah\nBates et al. \"In-datacenter performance analysis of a tensor processing unit.\" In 2017 ACM/IEEE 44th\nAnnual International Symposium on Computer Architecture (ISCA), pp. 1-12. IEEE, 2017.\n\n[7] Raghuraman Krishnamoorthi. \"Quantizing deep convolutional networks for ef\ufb01cient inference: A whitepa-\n\nper.\" arXiv preprint arXiv:1806.08342, 2018.\n\n[8] Shuang Wu, Guoqi Li, Feng Chen, and Luping Shi. \"Training and inference with integers in deep neural\n\nnetworks.\" International Conference on Learning Representations (ICLR), 2018.\n\n[9] Po-Chen Lin, Mu-Kai Sun, Chu King Kung, and Tzi-Dar Chiueh. \"FloatSD: A New Weight Representation\nand Associated Update Method for Ef\ufb01cient Convolutional Neural Network Training.\" IEEE Journal on\nEmerging and Selected Topics in Circuits and Systems, 2019.\n\n[10] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. \"Deep Residual Learning for Image Recogni-\n\ntion.\" arXiv preprint arXiv:1512.03385, 2015\n\n[11] Jungwook Choi, Swagath Venkataramani, Vijayalakshmi Srinivasan, Kailash Gopalakrishnan, Zhuo Wang,\nand Pierce Chuang. \"ACCURATE AND EFFICIENT 2-BIT QUANTIZED NEURAL NETWORKS.\" In\nSysML 2019.\n\n[12] Asit Mishra, Eriko Nurvitadhi, Jeffrey J. Cook, and Debbie Marr. \"WRPN: wide reduced-precision\n\nnetworks.\" arXiv preprint arXiv:1709.01134, 2017.\n\n9\n\n\f[13] Dongqing Zhang, Jiaolong Yang, Dongqiangzi Ye, and Gang Hua. \"Lq-nets: Learned quantization for\nhighly accurate and compact deep neural networks.\" In Proceedings of the European Conference on\nComputer Vision (ECCV), pp. 365-382. 2018.\n\n[14] Tao Sheng, Chen Feng, Shaojie Zhuo, Xiaopeng Zhang, Liang Shen, and Mickey Aleksic. 
\"A quantization-\nfriendly separable convolution for mobilenets.\" In 2018 1st Workshop on Energy Ef\ufb01cient Machine\nLearning and Cognitive Computing for Embedded Applications (EMC2), pp. 14-18. IEEE, 2018.\n\n[15] Sambhav R. Jain, Albert Gural, Michael Wu, and Chris Dick. \"Trained Uniform Quantization for Accurate\nand Ef\ufb01cient Neural Network Inference on Fixed-Point Hardware.\" arXiv preprint arXiv:1903.08066,\n2019.\n\n[16] Ritchie Zhao, Yuwei Hu, Jordan Dotzel, Christopher De Sa, and Zhiru Zhang. \"Improving Neural Network\n\nQuantization using Outlier Channel Splitting.\" arXiv preprint arXiv:1901.09504, 2019.\n\n[17] Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, and Liang-Chieh Chen. \"Mo-\nbileNetV2: Inverted residuals and linear bottlenecks.\" In Proceedings of the IEEE Conference on Computer\nVision and Pattern Recognition, pp. 4510-4520. 2018.\n\n[18] Karen Simonyan, Andrew Zisserman. \"Very Deep Convolutional Networks for Large-Scale Image Recog-\n\nnition.\" arXiv:1409.1556, 2014.\n\n[19] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, \u0141ukasz\nKaiser, and Illia Polosukhin. \"Attention is all you need.\" In Advances in neural information processing\nsystems, pp. 5998-6008. 2017.\n\n[20] Xiaodong Cui, Vaibhava Goel, and George Saon. \"Embedding-based speaker adaptive training of deep\n\nneural networks.\" arXiv preprint arXiv:1710.06937, 2017.\n\n[21] Kaiming He, Georgia Gkioxari, Piotr Doll\u00e1r, and Ross Girshick. \"Mask R-CNN.\" In Proceedings of the\n\nIEEE international conference on computer vision, pp. 2961-2969. 2017.\n\n[22] Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, Cheng-Yang Fu, and\nAlexander C. Berg. \"SSD: Single shot multibox detector.\" In European conference on computer vision, pp.\n21-37. Springer, Cham, 2016.\n\n[23] Nicholas J. Higham. 
\"The accuracy of \ufb02oating point summation.\" SIAM Journal on Scienti\ufb01c Computing\n\n14, no. 4 (1993): 783-799.\n\n[24] Guandao Yang, Tianyi Zhang, Polina Kirichenko, Junwen Bai, Andrew Gordon Wilson, and Christopher De\nSa. \"SWALP: Stochastic Weight Averaging in Low-Precision Training.\" arXiv preprint arXiv:1904.11943,\n2019.\n\n[25] Charbel Sakr, Yongjune Kim, and Naresh Shanbhag. \"Analytical guarantees on numerical precision of\ndeep neural networks.\" In Proceedings of the 34th International Conference on Machine Learning-Volume\n70, pp. 3007-3016. JMLR. org, 2017.\n\n[26] Hao Li, Zheng Xu, Gavin Taylor, Christoph Studer, and Tom Goldstein. \"Visualizing the loss landscape of\n\nneural nets.\" In Advances in Neural Information Processing Systems, pp. 6389-6399. 2018.\n\n[27] Shibani Santurkar, Dimitris Tsipras, Andrew Ilyas, and Aleksander Madry. \"How does batch normalization\n\nhelp optimization?.\" In Advances in Neural Information Processing Systems, pp. 2483-2493. 2018.\n\n[28] https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training\n\n[29] Baidu. https://github.com/baidu-research/baidu-allreduce. 2017.\n\n[30] Dario Amodei, Sundaram Ananthanarayanan, Rishita Anubhai, Jingliang Bai, Eric Battenberg, Carl\nCase, Jared Casper et al. \"Deep speech 2: End-to-end speech recognition in english and mandarin.\" In\nInternational conference on machine learning, pp. 173-182. 2016.\n\n[31] Dan Alistarh, Torsten Hoe\ufb02er, Mikael Johansson, Nikola Konstantinov, Sarit Khirirat, and C\u00e9dric Renggli.\n\"The convergence of sparsi\ufb01ed gradient methods.\" In Advances in Neural Information Processing Systems,\npp. 5973-5983. 2018.\n\n[32] Li Hao, Soham De, Zheng Xu, Christoph Studer, Hanan Samet, and Tom Goldstein. \"Training quantized\nnets: A deeper understanding.\" In Advances in Neural Information Processing Systems, pp. 
5811-5821. 2017.