Agarwala, Atish
Avoiding spurious sharpness minimization broadens applicability of SAM
Singh, Sidak Pal, Mobahi, Hossein, Agarwala, Atish, Dauphin, Yann
Curvature regularization techniques like Sharpness Aware Minimization (SAM) have shown great promise in improving generalization on vision tasks. However, we find that SAM performs poorly in domains like natural language processing (NLP), often degrading performance -- even with twice the compute budget. We investigate the discrepancy across domains and find that in the NLP setting, SAM is dominated by regularization of the logit statistics -- instead of improving the geometry of the function itself. We use this observation to develop an alternative algorithm we call Functional-SAM, which regularizes curvature only through modification of the statistics of the overall function implemented by the neural network, and avoids spurious minimization through logit manipulation. Furthermore, we argue that preconditioning the SAM perturbation also prevents spurious minimization, and when combined with Functional-SAM, it gives further improvements. Our proposed algorithms show improved performance over AdamW and SAM baselines when trained for an equal number of steps, in both fixed-length and Chinchilla-style training settings, at various model scales (including billion-parameter scale). On the whole, our work highlights the importance of more precise characterizations of sharpness in broadening the applicability of curvature regularization to large language models (LLMs).
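The abstract does not spell out the Functional-SAM update itself, so as a point of reference, here is a minimal sketch of the standard two-step SAM update that it modifies. This is not the paper's Functional-SAM; `loss_grad` and the hyperparameter values are illustrative assumptions.

```python
# Minimal sketch of the standard SAM update (the baseline the abstract refers to).
# `loss_grad` is a hypothetical function returning the gradient of the loss at
# given weights; rho is the SAM perturbation radius.
import numpy as np

def sam_step(w, loss_grad, lr=0.1, rho=0.05):
    g = loss_grad(w)                                  # gradient at current weights
    eps = rho * g / (np.linalg.norm(g) + 1e-12)       # normalized ascent perturbation
    g_perturbed = loss_grad(w + eps)                  # gradient at perturbed weights
    return w - lr * g_perturbed                       # descend using perturbed gradient
```

Functional-SAM, as described above, would differ in which statistics of the perturbed model enter this update; the sketch only fixes the shared two-step structure.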
Exact Risk Curves of signSGD in High-Dimensions: Quantifying Preconditioning and Noise-Compression Effects
Xiao, Ke Liang, Marshall, Noah, Agarwala, Atish, Paquette, Elliot
The success of deep learning has been driven by the effectiveness of relatively simple stochastic optimization algorithms. Stochastic gradient descent (SGD) with momentum can be used to train models like ResNet50 with minimal hyperparameter tuning. The workhorse of modern machine learning is Adam, which was designed to approximate preconditioning using a diagonal, online estimate of the Fisher information matrix (Kingma, 2014). Additional hypotheses for the success of Adam include its ability to maintain balanced updates to parameters across layers and its potential noise-mitigating effects (Zhang et al., 2020; 2024). A quantitative, theoretical understanding of Adam and its variants is hindered by their complexity: while the multiple exponential moving averages are easy to implement, they complicate analysis. The practical desire for simpler, more efficient learning algorithms, as well as the theoretical desire for simpler models to analyze, has led to a resurgence in the study of signSGD.
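For reference, the signSGD update is simple enough to state in a few lines. The sketch below is a generic implementation, not tied to the paper's high-dimensional analysis.

```python
# Minimal sketch of signSGD: only the sign of each gradient coordinate is used,
# which both preconditions (equalizes per-coordinate step sizes) and compresses
# gradient noise -- the two effects quantified in the paper.
import numpy as np

def signsgd_step(w, grad, lr=1e-3):
    return w - lr * np.sign(grad)
```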
Stepping on the Edge: Curvature Aware Learning Rate Tuners
Roulet, Vincent, Agarwala, Atish, Grill, Jean-Bastien, Swirszcz, Grzegorz, Blondel, Mathieu, Pedregosa, Fabian
Curvature information -- particularly, the largest eigenvalue of the loss Hessian, known as the sharpness -- often forms the basis for learning rate tuners. However, recent work has shown that the curvature information undergoes complex dynamics during training, going from a phase of increasing sharpness to eventual stabilization. We analyze the closed-loop feedback effect between learning rate tuning and curvature. We find that classical learning rate tuners may yield greater one-step loss reduction, yet they ultimately underperform in the long term when compared to constant learning rates in the full batch regime. These tuners break the stabilization of the sharpness, which we explain using a simplified model of the joint dynamics of the learning rate and the curvature. To further investigate these effects, we introduce a new learning rate tuning method, Curvature Dynamics Aware Tuning (CDAT), which prioritizes long term curvature stabilization over instantaneous progress on the objective. In the full batch regime, CDAT shows behavior akin to pre-specified warm-up schedules on deep learning objectives, outperforming tuned constant learning rates. In the mini batch regime, we observe that stochasticity introduces confounding effects that explain the previous success of some learning rate tuners at appropriate batch sizes. Our findings highlight the critical role of understanding the joint dynamics of the learning rate and curvature, beyond greedy minimization, to diagnose failures and design effective adaptive learning rate tuners.
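The exact CDAT rule is not given in the abstract; the sketch below is only a generic curvature-aware stepsize in the spirit described, assuming the classical gradient-descent stability threshold of 2 divided by the sharpness.

```python
# Hedged sketch of a curvature-aware learning rate rule (not the paper's exact
# CDAT update). The step size is set relative to the stability threshold
# 2 / sharpness, where `sharpness` is the largest Hessian eigenvalue.
def curvature_aware_lr(sharpness, target_fraction=0.9, eps=1e-12):
    # target_fraction < 1 keeps the step just below the edge of stability;
    # values near 1 operate "at the edge", as described in the abstract.
    return target_fraction * 2.0 / (sharpness + eps)
```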
A Clipped Trip: the Dynamics of SGD with Gradient Clipping in High-Dimensions
Marshall, Noah, Xiao, Ke Liang, Agarwala, Atish, Paquette, Elliot
The success of modern machine learning is due in part to the adaptive optimization methods that have been developed to deal with the difficulties of training large models over complex datasets. One such method is gradient clipping: a practical procedure with limited theoretical underpinnings. In this work, we study clipping in a least squares problem under streaming SGD. We develop a theoretical analysis of the learning dynamics in the limit of large intrinsic dimension -- a model- and dataset-dependent notion of dimensionality. In this limit we find a deterministic equation that describes the evolution of the loss. We show that, with Gaussian noise, clipping cannot improve SGD performance. Yet, in other noisy settings, clipping can provide benefits given appropriate tuning of the clipping threshold. In these cases, clipping biases updates in a way beneficial to training which cannot be recovered by SGD under any schedule. We conclude with a discussion about the links between high-dimensional clipping and neural network training.
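A minimal sketch of the clipped streaming-SGD setup described above, on synthetic least squares data with heavy-tailed noise; the data, threshold, and learning rate are illustrative assumptions, not the paper's experimental configuration.

```python
# Norm-based gradient clipping in streaming SGD on a least squares problem:
# each sample is used once, and the per-sample gradient is rescaled to norm c
# whenever it exceeds the clipping threshold.
import numpy as np

def clipped_sgd_least_squares(X, y, lr=0.01, c=1.0):
    w = np.zeros(X.shape[1])
    for x_t, y_t in zip(X, y):                 # streaming: one pass over the data
        g = (x_t @ w - y_t) * x_t              # per-sample least squares gradient
        norm = np.linalg.norm(g)
        if norm > c:                           # clip: rescale to norm c
            g = g * (c / norm)
        w -= lr * g
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 20))
y = X @ rng.normal(size=20) + rng.standard_t(df=3, size=1000)  # heavy-tailed noise
w_hat = clipped_sgd_least_squares(X, y)
```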
High dimensional analysis reveals conservative sharpening and a stochastic edge of stability
Agarwala, Atish, Pennington, Jeffrey
Recent empirical and theoretical work has shown that the dynamics of the large eigenvalues of the training loss Hessian have some remarkably robust features across models and datasets in the full batch regime. There is often an early period of progressive sharpening where the large eigenvalues increase, followed by stabilization at a predictable value known as the edge of stability. Previous work showed that in the stochastic setting, the eigenvalues increase more slowly -- a phenomenon we call conservative sharpening. We provide a theoretical analysis of a simple high-dimensional model which shows the origin of this slowdown. We also show that there is an alternative stochastic edge of stability which arises at small batch size that is sensitive to the trace of the Neural Tangent Kernel rather than the large Hessian eigenvalues. We conduct an experimental study which highlights the qualitative differences from the full batch phenomenology, and suggests that controlling the stochastic edge of stability can help optimization.
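The sharpness measurements these dynamics refer to can be estimated without forming the Hessian. A minimal sketch, assuming a hypothetical `loss_grad` function, using power iteration on finite-difference Hessian-vector products:

```python
# Estimate the largest Hessian eigenvalue (the sharpness) by power iteration,
# with Hessian-vector products approximated by central differences of the gradient.
import numpy as np

def sharpness(w, loss_grad, iters=50, delta=1e-4, seed=0):
    rng = np.random.default_rng(seed)
    v = rng.normal(size=w.shape)
    v /= np.linalg.norm(v)
    lam = 0.0
    for _ in range(iters):
        hv = (loss_grad(w + delta * v) - loss_grad(w - delta * v)) / (2 * delta)
        lam = float(v @ hv)                    # Rayleigh quotient estimate
        v = hv / (np.linalg.norm(hv) + 1e-12)  # power iteration step
    return lam
```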
Gradient descent induces alignment between weights and the empirical NTK for deep non-linear networks
Beaglehole, Daniel, Mitliagkas, Ioannis, Agarwala, Atish
Understanding the mechanisms through which neural networks extract statistics from input-label pairs is one of the most important unsolved problems in supervised learning. Prior works have identified that the Gram matrices of the weights in trained neural networks of general architectures are proportional to the average gradient outer product of the model, in a statement known as the Neural Feature Ansatz (NFA). However, the reason these quantities become correlated during training is poorly understood. In this work, we explain the emergence of this correlation. We identify that the NFA is equivalent to alignment between the left singular structure of the weight matrices and a significant component of the empirical neural tangent kernels associated with those weights. We establish that the NFA introduced in prior works is driven by a centered NFA that isolates this alignment. We show that the speed of NFA development can be predicted analytically at early training times in terms of simple statistics of the inputs and labels. Finally, we introduce a simple intervention to increase NFA correlation at any given layer, which dramatically improves the quality of features learned.
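As a concrete illustration of the quantities involved, here is a minimal sketch, for a toy two-layer ReLU network rather than the paper's experimental setup, of the NFA correlation: the cosine similarity between the first-layer Gram matrix and the average gradient outer product with respect to the input.

```python
# NFA correlation for the first layer of f(x) = a . relu(W x): compare W^T W
# with the average gradient outer product (AGOP) of f with respect to the input.
# In the paper this correlation develops over training; here we only compute
# the measurement at a given set of weights.
import numpy as np

def nfa_correlation(W, a, X):
    pre = X @ W.T                              # (n, hidden) pre-activations
    grads = (a * (pre > 0)) @ W                # (n, d) input gradients W^T(a * relu'(Wx))
    agop = grads.T @ grads / len(X)            # average gradient outer product
    gram = W.T @ W                             # first-layer Gram matrix
    return np.sum(agop * gram) / (np.linalg.norm(agop) * np.linalg.norm(gram))

rng = np.random.default_rng(0)
d, h, n = 10, 64, 500
W, a, X = rng.normal(size=(h, d)), rng.normal(size=h), rng.normal(size=(n, d))
print(nfa_correlation(W, a, X))
```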
Neglected Hessian component explains mysteries in Sharpness regularization
Dauphin, Yann N., Agarwala, Atish, Mobahi, Hossein
Recent work has shown that methods like SAM, which either explicitly or implicitly penalize second-order information, can improve generalization in deep learning. Seemingly similar methods like weight noise and gradient penalties often fail to provide such benefits. We show that these differences can be explained by the structure of the Hessian of the loss. First, we show that a common decomposition of the Hessian can be quantitatively interpreted as separating the feature exploitation from feature exploration. The feature exploration, which can be described by the Nonlinear Modeling Error matrix (NME), is commonly neglected in the literature since it vanishes at interpolation. Our work shows that the NME is in fact important as it can explain why gradient penalties are sensitive to the choice of activation function. Using this insight we design interventions to improve performance. We also provide evidence that challenges the long-held equivalence of weight noise and gradient penalties. This equivalence relies on the assumption that the NME can be ignored, which we find does not hold for modern networks since they involve significant feature learning. We find that regularizing feature exploitation but not feature exploration yields performance similar to gradient penalties.
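A small numerical illustration of the decomposition referred to above may help: for a model with squared loss, the Hessian splits into a Gauss-Newton ("feature exploitation") piece built from first derivatives and a residual-weighted NME ("feature exploration") piece that vanishes at interpolation. The toy model and the finite-difference check below are illustrative assumptions, not the paper's setup.

```python
# Verify H = Gauss-Newton + NME numerically for f(w; x) = w2 * tanh(w1 * x)
# with squared loss: H = mean(J J^T) + mean(residual * Hessian of f).
import numpy as np

rng = np.random.default_rng(0)
x, y = rng.normal(size=50), rng.normal(size=50)
w = np.array([0.7, -1.3])

def f(w):
    return w[1] * np.tanh(w[0] * x)

def loss(w):
    return 0.5 * np.mean((f(w) - y) ** 2)

def hessian(fun, w, eps=1e-4):
    # full Hessian of a scalar function by central finite differences
    d = len(w)
    H = np.zeros((d, d))
    for i in range(d):
        for j in range(d):
            e_i, e_j = np.eye(d)[i], np.eye(d)[j]
            H[i, j] = (fun(w + eps*e_i + eps*e_j) - fun(w + eps*e_i - eps*e_j)
                       - fun(w - eps*e_i + eps*e_j) + fun(w - eps*e_i - eps*e_j)) / (4 * eps**2)
    return H

t = np.tanh(w[0] * x)
J = np.stack([w[1] * x * (1 - t**2), t])           # df/dw for each sample, shape (2, n)
r = f(w) - y                                       # residuals
ggn = J @ J.T / len(x)                             # Gauss-Newton ("exploitation") term
d2f_11 = -2 * w[1] * x**2 * t * (1 - t**2)         # d^2 f / dw1^2
d2f_12 = x * (1 - t**2)                            # d^2 f / dw1 dw2
nme = np.array([[np.mean(r * d2f_11), np.mean(r * d2f_12)],
                [np.mean(r * d2f_12), 0.0]])       # residual-weighted NME term
print(np.allclose(hessian(loss, w), ggn + nme, atol=1e-4))  # expect True
```

Note that the NME carries the activation function's second derivative, which is why, as the abstract argues, gradient penalties are sensitive to the choice of activation.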
On the Interplay Between Stepsize Tuning and Progressive Sharpening
Roulet, Vincent, Agarwala, Atish, Pedregosa, Fabian
Recent empirical work has revealed an intriguing property of deep learning models by which the sharpness (largest eigenvalue of the Hessian) increases throughout optimization until it stabilizes around a critical value at which the optimizer operates at the edge of stability, given a fixed stepsize (Cohen et al., 2022). We investigate empirically how the sharpness evolves when using stepsize tuners -- the Armijo linesearch and Polyak stepsizes -- which adapt the stepsize along the iterations to local quantities such as, implicitly, the sharpness itself. We find that the surprisingly poor performance of a classical Armijo linesearch in the deterministic setting may be well explained by its tendency to ever-increase the sharpness of the objective. On the other hand, we observe that Polyak stepsizes operate generally at the edge of stability or even slightly beyond, outperforming their Armijo and constant stepsize counterparts in the deterministic setting. We conclude with an analysis that suggests that unlocking the potential of stepsize tuners requires an understanding of the joint dynamics of the step size and the sharpness.
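For reference, a minimal sketch of the Polyak stepsize discussed above, assuming the optimal value f* is known (commonly taken to be 0 for interpolating models); the cap on the stepsize is an illustrative safeguard, not part of the classical rule.

```python
# Polyak stepsize: the step is the current suboptimality divided by the
# squared gradient norm, lr_t = (f(w_t) - f*) / ||g_t||^2.
import numpy as np

def polyak_step(w, loss_value, grad, f_star=0.0, max_lr=10.0):
    lr = (loss_value - f_star) / (np.linalg.norm(grad) ** 2 + 1e-12)
    return w - min(lr, max_lr) * grad
```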
SAM operates far from home: eigenvalue regularization as a dynamical phenomenon
Agarwala, Atish, Dauphin, Yann N.
The Sharpness Aware Minimization (SAM) optimization algorithm has been shown to control large eigenvalues of the loss Hessian and provide generalization benefits in a variety of settings. The original motivation for SAM was a modified loss function which penalized sharp minima; subsequent analyses have also focused on the behavior near minima. However, our work reveals that SAM provides a strong regularization of the eigenvalues throughout the learning trajectory. We show that in a simplified setting, SAM dynamically induces a stabilization related to the edge of stability (EOS) phenomenon observed in large learning rate gradient descent. Our theory predicts the largest eigenvalue as a function of the learning rate and SAM radius parameters. Finally, we show that practical models can also exhibit this EOS stabilization, and that understanding SAM must account for these dynamics far away from any minima.
One Network Fits All? Modular versus Monolithic Task Formulations in Neural Networks
Agarwala, Atish, Das, Abhimanyu, Juba, Brendan, Panigrahy, Rina, Sharan, Vatsal, Wang, Xin, Zhang, Qiuyi
Can deep learning solve multiple tasks simultaneously, even when they are unrelated and very different? We investigate how the representations of the underlying tasks affect the ability of a single neural network to learn them jointly. We present theoretical and empirical findings that a single neural network is capable of simultaneously learning multiple tasks from a combined data set, for a variety of methods for representing tasks--for example, when the distinct tasks are encoded by well-separated clusters or decision trees over certain task-code attributes. More concretely, we present a novel analysis that shows that families of simple programming-like constructs for the codes encoding the tasks are learnable by two-layer neural networks with standard training. We study more generally how the complexity of learning such combined tasks grows with the complexity of the task codes; we find that combining many tasks may incur a sample complexity penalty, even though the individual tasks are easy to learn. We provide empirical support for the usefulness of the learning bounds by training networks on clusters, decision trees, and SQL-style aggregation.

Standard practice in machine learning has long been to only address carefully circumscribed, often very related tasks. For example, we might train a single classifier to label an image as containing objects from a certain predefined set, or to label the words of a sentence with their semantic roles. Indeed, when working with relatively simple classes of functions like linear classifiers, it would be unreasonable to expect to train a classifier that handles more than such a carefully scoped task (or related tasks in standard multitask learning). As techniques for learning with relatively rich classes such as neural networks have been developed, it is natural to ask whether or not such scoping of tasks is inherently necessary. Indeed, many recent works (see Section 1.2) have proposed eschewing this careful scoping of tasks, and instead training a single, "monolithic" function spanning many tasks. Large, deep neural networks can, in principle, represent multiple classifiers in such a monolithic learned function (Hornik, 1991), giving rise to the field of multitask learning. This combined function might be learned by combining all of the training data for all of the tasks into one large batch--see Section 1.2 for some examples. Taken to an extreme, we could consider seeking to learn a universal circuit--that is, a circuit that interprets arbitrary programs in a programming language which can encode various tasks. But, the ability to represent such a monolithic combined function does not necessarily entail that such a function can be efficiently learned by existing methods.
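To make the "monolithic" formulation concrete, here is a minimal sketch of combining unrelated tasks into one dataset via task-code attributes; the tasks and the one-hot coding are illustrative assumptions, not the paper's exact constructions.

```python
# Merge several tasks into one dataset by appending a task-code attribute
# (a one-hot code) to each input, so a single network can be trained jointly.
import numpy as np

def combine_tasks(task_datasets):
    """task_datasets: list of (X, y) pairs, one per task, with a shared input dim."""
    k = len(task_datasets)
    X_all, y_all = [], []
    for task_id, (X, y) in enumerate(task_datasets):
        code = np.zeros((len(X), k))
        code[:, task_id] = 1.0                 # one-hot task code
        X_all.append(np.hstack([X, code]))     # append code to the inputs
        y_all.append(y)
    return np.vstack(X_all), np.concatenate(y_all)

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
tasks = [(X, (X[:, 0] > 0).astype(float)),              # task 1: sign of feature 0
         (X, (X[:, 1] * X[:, 2] > 0).astype(float))]    # task 2: XOR-like rule
X_combined, y_combined = combine_tasks(tasks)
```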