Characterizing signal propagation to close the performance gap in unnormalized ResNets

Brock, Andrew, De, Soham, Smith, Samuel L.

arXiv.org Machine Learning 

Batch Normalization is a key component in almost all state-of-the-art image classifiers, but it also introduces practical challenges: it breaks the independence between training examples within a batch, can incur compute and memory overhead, and often results in unexpected bugs. Building on recent theoretical analyses of deep ResNets at initialization, we propose a simple set of analysis tools to characterize signal propagation on the forward pass, and leverage these tools to design highly performant ResNets without activation normalization layers. Crucial to our success is an adapted version of the recently proposed Weight Standardization. Our analysis tools show how this technique preserves the signal in networks with ReLU or Swish activation functions by ensuring that the per-channel activation means do not grow with depth. Across a range of FLOP budgets, our networks attain performance competitive with the state-of-the-art EfficientNets on ImageNet.

BatchNorm has become a core computational primitive in deep learning (Ioffe & Szegedy, 2015), and it is used in almost all state-of-the-art image classifiers (Tan & Le, 2019; Wei et al., 2020). A number of different benefits of BatchNorm have been identified. It smoothens the loss landscape (Santurkar et al., 2018), which allows training with larger learning rates (Bjorck et al., 2018), and the noise arising from the minibatch estimates of the batch statistics introduces implicit regularization (Luo et al., 2019). Crucially, recent theoretical work (Balduzzi et al., 2017; De & Smith, 2020) has demonstrated that BatchNorm ensures good signal propagation at initialization in deep residual networks with identity skip connections (He et al., 2016b;a), and this benefit has enabled practitioners to train deep ResNets with hundreds or even thousands of layers (Zhang et al., 2019).

However, BatchNorm also has many disadvantages. Its behaviour is strongly dependent on the batch size, performing poorly when the per-device batch size is too small or too large (Hoffer et al., 2017), and it introduces a discrepancy between the behaviour of the model during training and at inference time. BatchNorm also adds memory overhead (Rota Bulò et al., 2018), and is a common source of implementation errors (Pham et al., 2019). In addition, it is often difficult to replicate batch-normalized models trained on different hardware. A number of alternative normalization layers have been proposed (Ba et al., 2016; Wu & He, 2018), but these alternatives typically generalize poorly or introduce their own drawbacks, such as added compute costs at inference.
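To make the "analysis tools" mentioned above concrete, the sketch below illustrates the kind of forward-pass measurement the paper describes: pass unit Gaussian noise through a randomly initialized network and record, at the output of each residual block, the squared per-channel means and per-channel variances, averaged over channels, as a function of depth. This is a minimal sketch, not the authors' released code; the function name signal_prop_stats and the assumption that the network is exposed as an iterable of blocks are illustrative choices.

    import torch

    @torch.no_grad()
    def signal_prop_stats(blocks, x):
        """Record forward-pass signal statistics at initialization.

        For each block output, compute the per-channel mean and variance over the
        batch and spatial dimensions, then average over channels. Plotting these
        values against depth shows whether the activation means and variances
        grow or stay controlled as the signal propagates.
        """
        stats = []
        for block in blocks:
            x = block(x)
            mean = x.mean(dim=(0, 2, 3))                 # per-channel mean (batch + spatial reduced)
            var = x.var(dim=(0, 2, 3), unbiased=False)   # per-channel variance
            stats.append({
                "avg_channel_sq_mean": (mean ** 2).mean().item(),
                "avg_channel_var": var.mean().item(),
            })
        return stats

    # Example: probe an untrained stack of residual blocks with unit Gaussian input.
    # stats = signal_prop_stats(model.blocks, torch.randn(64, 3, 224, 224))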
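The "adapted version of Weight Standardization" referred to in the abstract (Scaled Weight Standardization in the paper) standardizes each output channel's weights over its fan-in, scales them by 1/sqrt(fan_in), and multiplies by an activation-dependent gain (for ReLU, sqrt(2 / (1 - 1/pi))), so that per-channel activation means stay near zero and variances stay near one with depth. The PyTorch module below is a minimal sketch under those assumptions; the class name ScaledWSConv2d and the small epsilon are illustrative, not the authors' implementation.

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ScaledWSConv2d(nn.Conv2d):
        """Conv2d whose weights are re-parameterized on every forward pass as
        W_hat = gain * (W - mean) / (std * sqrt(fan_in)),
        with mean/std taken over each output channel's fan-in."""

        def __init__(self, in_channels, out_channels, kernel_size, gain=1.0, **kwargs):
            super().__init__(in_channels, out_channels, kernel_size, **kwargs)
            self.gain = gain  # activation-dependent constant, e.g. sqrt(2 / (1 - 1/pi)) for ReLU

        def standardized_weight(self):
            w = self.weight
            fan_in = w[0].numel()  # (in_channels / groups) * kernel_h * kernel_w
            mean = w.mean(dim=(1, 2, 3), keepdim=True)
            var = w.var(dim=(1, 2, 3), unbiased=False, keepdim=True)
            return self.gain * (w - mean) / torch.sqrt(var * fan_in + 1e-4)

        def forward(self, x):
            return F.conv2d(x, self.standardized_weight(), self.bias,
                            self.stride, self.padding, self.dilation, self.groups)

    # Example: a 3x3 layer with the ReLU gain, used in place of a normalized conv.
    relu_gain = math.sqrt(2.0 / (1.0 - 1.0 / math.pi))
    conv = ScaledWSConv2d(64, 64, kernel_size=3, padding=1, gain=relu_gain)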
