Why Do We Need Weight Decay in Modern Deep Learning?

Andriushchenko, Maksym, D'Angelo, Francesco, Varre, Aditya, Flammarion, Nicolas

arXiv.org Artificial Intelligence 

Weight decay is a broadly used technique for training state-of-the-art deep networks, including large language models. Despite its widespread usage, its role remains poorly understood. In this work, we highlight that the role of weight decay in modern deep learning is different from its regularization effect studied in classical learning theory. For overparameterized deep networks, we show how weight decay modifies the optimization dynamics, enhancing the ever-present implicit regularization of SGD via the loss stabilization mechanism. In contrast, for underparameterized large language models trained with nearly online SGD, we describe how weight decay balances the bias-variance tradeoff in stochastic optimization, leading to lower training loss. Moreover, we show that weight decay also prevents sudden loss divergences in bfloat16 mixed-precision training, which is a crucial tool for LLM training. Overall, we present a unifying perspective from ResNets on vision tasks to LLMs: weight decay is never useful as an explicit regularizer but instead changes the training dynamics in a desirable way.

Weight decay serves to constrain the network capacity (Goodfellow et al., 2016) and acts as a mechanism for suppressing irrelevant weight components, aligning with the principles of Occam's razor (Krogh & Hertz, 1991). It is central in discussions on generalization bounds (Shalev-Shwartz & Ben-David, 2014), although a recent empirical study by Jiang et al. (2020) casts doubt on how well norm-based measures correlate with generalization for deep networks. Weight decay is also known to yield a regularization of the input-output Jacobian (Zhang et al., 2018) and to alter the training dynamics of scale-invariant networks by changing the effective learning rate (Van Laarhoven, 2017). Weight decay is widely used for training most state-of-the-art deep networks such as GPT-3 (Brown et al., 2020), CLIP (Radford et al., 2021), or PaLM (Chowdhery et al., 2022).
We argue that despite its widespread usage, its effect is still poorly understood: in some cases it acts as a regularizer, but in others as a tool for better optimization. Although the regularization effect of weight decay is thoroughly studied in classical learning theory, deep networks are already equipped with strong implicit regularization coming from the parameter initialization, optimization algorithm, and architecture (Zhang et al., 2016). Moreover, recent years have brought along new architectures and settings such as transformers (Vaswani et al., 2017) and nearly one-epoch language modelling (Brown et al., 2020; Hoffmann et al., 2022).
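
For readers who want the mechanics in concrete form, the following is a minimal sketch of a plain SGD step with weight decay, written as an L2 penalty folded into the update. It is illustrative only and is not taken from the paper: the function names `sgd_step_with_weight_decay` and `effective_learning_rate` are hypothetical, and the second function assumes the standard observation for scale-invariant weights that the effective step size on the weight direction scales as the learning rate divided by the squared weight norm.

```python
def sgd_step_with_weight_decay(weights, grads, lr, weight_decay):
    """One SGD step on loss(w) + (weight_decay / 2) * ||w||^2.

    The penalty contributes weight_decay * w to the gradient, so each
    weight is shrunk by (1 - lr * weight_decay) before the usual
    gradient step is applied.
    """
    shrink = 1.0 - lr * weight_decay
    return [shrink * w - lr * g for w, g in zip(weights, grads)]


def effective_learning_rate(weights, lr):
    """Assumed effective learning rate lr / ||w||^2 for scale-invariant weights.

    A smaller weight norm (for example, one kept small by weight decay)
    gives a larger effective step on the weight direction.
    """
    squared_norm = sum(w * w for w in weights)
    return lr / squared_norm


if __name__ == "__main__":
    w = [1.0, -2.0]
    g = [0.0, 0.0]  # zero gradient isolates the decay term
    print(sgd_step_with_weight_decay(w, g, lr=0.1, weight_decay=0.5))
    print(effective_learning_rate([3.0, 4.0], lr=0.1))
```

With a zero gradient, the step reduces to pure multiplicative shrinkage of the weights, which makes it easy to see how weight decay controls the weight norm and, through it, the effective learning rate in the scale-invariant case.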
