 sharpness


Continual Learning in Modern Hopfield Networks with an Application to Diffusion Models

arXiv.org Machine Learning

Generative models, including diffusion models, are increasingly used as foundation models and adapted through sequential fine-tuning, making continual learning an essential problem setting. However, continual learning in such generative models remains poorly understood: after a task change, what aspects of the learned distribution are most easily lost, and what replay samples should be prioritized? We address these questions through the modern Hopfield energy. Recent links between modern Hopfield networks (MHNs) and diffusion models allow analyses in MHNs to be transferred to diffusion models. We introduce intrinsic forgetting as an increase in Hopfield energy after the task change. In tractable settings in an MHN, we prove that high-energy, outlier-like samples undergo a larger energy increase than cluster-like samples, implying that samples located in sharp, isolated basins are more forgettable. We further analyze memory replay and show that replay is particularly effective for high-energy samples, enabling an energy-based selection of replay samples. We validate these predictions in experiments on MHNs and two diffusion models under continual-learning settings: Stable Diffusion and a pixel-space DDPM. In these diffusion models, Hopfield energy tracks reconstruction-based forgetting, and replay experiments reveal energy-dependent mitigation of forgetting that is consistent with the MHN analysis.
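As a rough illustration of the energy-based view above, the sketch below evaluates a log-sum-exp Hopfield energy at a stored pattern inside a tight cluster and at an isolated stored pattern. The specific energy form (log-sum-exp plus a quadratic term, constants dropped), the inverse temperature `beta`, and the toy patterns on the unit circle are assumptions for illustration; the abstract does not state the exact definition the authors use.

```python
import numpy as np

def hopfield_energy(query, patterns, beta=4.0):
    """Assumed energy: -(1/beta) * log(sum_i exp(beta * <x_i, q>)) + 0.5 * ||q||^2."""
    scores = beta * patterns @ query
    m = scores.max()  # subtract the max for a numerically stable log-sum-exp
    lse = (m + np.log(np.exp(scores - m).sum())) / beta
    return -lse + 0.5 * query @ query

# Hypothetical toy memory: three patterns clustered near angle 0, one outlier at 90 degrees.
angles = np.array([0.0, 0.1, -0.1, np.pi / 2])
patterns = np.stack([np.cos(angles), np.sin(angles)], axis=1)

cluster_energy = hopfield_energy(patterns[0], patterns)
outlier_energy = hopfield_energy(patterns[3], patterns)
print(cluster_energy, outlier_energy)
```

In this toy setup the isolated pattern sits at a higher energy than the clustered one, because fewer stored patterns contribute to its log-sum-exp term. This only shows the static energy gap, not the change in energy after a task switch that the paper analyzes.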


Does Weight Decay Enhance Training Stability?

arXiv.org Machine Learning

In modern deep learning, weight decay is often credited with "stabilizing" training dynamics, diverging from its classical role as a static regularization penalty. We investigate a fundamental question: *does weight decay stabilize training dynamics, and if so, through which mechanism?* Indeed, training stability is understood through different but related notions in the literature. We consider how weight decay affects the parameter-space dynamics and loss sharpness by analyzing its effects at the *Edge of Stability* (EoS). We show that weight decay robustly slows *progressive sharpening*. Furthermore, we uncover a striking architecture-dependent phase transition. In CNNs, weight decay dampens the oscillations at the EoS, while in MLPs, increasing weight decay causes a phase transition in which the sharpness stabilizes at a threshold significantly below the theoretical $2/\eta$ boundary. We develop a mathematical framework that accurately models these phenomena and identify the global alignment of the parameter vector and the sharpness gradient as the mechanistic driver of the phase transition. Importantly, we show that these phenomena translate into stability in terms of search in function-space (NTK). Last, this shows that curvature thresholds obtained from convex/quadratic heuristics may not be reliable stability diagnostics under regularization.
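For context, the boundary of 2 divided by the learning rate comes from the quadratic heuristic: gradient descent on a one-dimensional quadratic with curvature λ multiplies the iterate by (1 − ηλ) each step, so it contracts only when λ < 2/η. The sketch below checks this on a toy problem; the function name and the chosen values are illustrative and not taken from the paper, and no weight decay is modeled.

```python
def gd_on_quadratic(curvature, lr, steps=100, x0=1.0):
    """Run gradient descent on f(x) = 0.5 * curvature * x**2 and return |x| at the end."""
    x = x0
    for _ in range(steps):
        x -= lr * curvature * x  # gradient of f is curvature * x
    return abs(x)

lr = 0.1
threshold = 2.0 / lr  # classical stability boundary for this learning rate

below = gd_on_quadratic(0.95 * threshold, lr)  # per-step factor about -0.9: contracts
above = gd_on_quadratic(1.05 * threshold, lr)  # per-step factor about -1.1: diverges
print(below, above)
```

The abstract's point is that, under regularization, the observed sharpness can settle well below this quadratic threshold, so the heuristic alone may mislead.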


Enhancing Sharpness-Aware Optimization Through Variance Suppression

Neural Information Processing Systems

Sharpness-aware minimization (SAM) has well-documented merits in enhancing generalization of deep neural networks, even without sizable data augmentation. Embracing the geometry of the loss function, where neighborhoods of 'flat minima' heighten generalization ability, SAM seeks 'flat valleys' by minimizing the maximum loss caused by an adversary perturbing parameters within the neighborhood. Although critical to account for sharpness of the loss function, such an 'over-friendly adversary' can curtail the utmost level of generalization. The novel approach of this contribution fosters stabilization of adversaries through variance suppression (VaSSO) to avoid such friendliness.
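A minimal sketch of the baseline SAM update described above, assuming the common first-order approximation in which the adversary moves a distance `rho` along the normalized gradient. This is plain SAM on a toy quadratic, not the VaSSO variant; the function names, the toy loss, and the hyperparameters are hypothetical.

```python
import numpy as np

def sam_step(w, grad_fn, lr=0.05, rho=0.05):
    """One SAM step: perturb toward higher loss, then descend using the perturbed gradient."""
    g = grad_fn(w)
    norm = np.linalg.norm(g)
    if norm == 0.0:
        return w.copy()  # at a stationary point there is no ascent direction
    eps = rho * g / norm          # first-order approximation of the worst-case perturbation
    g_adv = grad_fn(w + eps)      # gradient evaluated at the perturbed parameters
    return w - lr * g_adv

# Toy quadratic loss with one sharp and one flat direction (illustrative only).
A = np.diag([1.0, 10.0])
loss = lambda w: 0.5 * w @ A @ w
grad = lambda w: A @ w

w0 = np.array([1.0, 1.0])
w1 = sam_step(w0, grad)
print(loss(w0), loss(w1))
```

With stochastic gradients, the perturbation direction is computed from a noisy mini-batch gradient, which is the adversary that VaSSO aims to stabilize.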


Normalization Layers Are All That Sharpness-Aware Minimization Needs

Neural Information Processing Systems

Sharpness-aware minimization (SAM) was proposed to reduce sharpness of minima and has been shown to enhance generalization performance in various settings. In this work we show that perturbing only the affine normalization parameters (typically comprising 0.1% of the total parameters) in the adversarial step of SAM can outperform perturbing all of the parameters.


Sharpness Minimization Algorithms Do Not Only Minimize Sharpness To Achieve Better Generalization

Neural Information Processing Systems

Despite extensive studies, the underlying reason as to why overparameterized neural networks can generalize remains elusive. Existing theory shows that common stochastic optimizers prefer flatter minimizers of the training loss, and thus a natural potential explanation is that flatness implies generalization. This work critically examines this explanation. Through theoretical and empirical investigation, we identify the following three scenarios for two-layer ReLU networks: (1) flatness provably implies generalization; (2) there exist non-generalizing flattest models and sharpness minimization algorithms fail to generalize; and (3) perhaps most strikingly, there exist non-generalizing flattest models, but sharpness minimization algorithms still generalize. Our results suggest that the relationship between sharpness and generalization subtly depends on the data distributions and the model architectures, and that sharpness minimization algorithms do not only minimize sharpness to achieve better generalization. This calls for the search for other explanations for the generalization of overparameterized neural networks.


Calibrating Scientific Foundation Models with Inference-Time Stochastic Attention

arXiv.org Machine Learning

Transformer-based scientific foundation models are increasingly deployed in high-stakes settings, but current architectures give deterministic outputs and provide limited support for calibrated predictive uncertainty. We propose Stochastic Attention, a lightweight inference-time modification that randomizes attention by replacing softmax weights with normalized multinomial samples controlled by a single concentration parameter, and produces predictive ensembles without retraining. To set this parameter, we introduce a calibration objective that matches the stochastic attention output with the target, yielding an efficient univariate post-hoc tuning problem. We evaluate this mechanism on two scientific foundation models for weather and timeseries forecasting along with an additional regression task. Across benchmarks against uncertainty-aware baselines, we find that Stochastic Attention achieves the strongest native calibration and the sharpest prediction intervals at comparable coverage, while requiring only minutes of post-hoc tuning versus days of retraining for competitive baselines.
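One way to picture the mechanism is sketched below: the softmax weights of each query are replaced by the normalized counts of a multinomial draw, and repeated forward passes give a predictive ensemble. Treating the number of draws `n_draws` as the concentration parameter is an assumption made here for illustration, as are all function names and the toy shapes; the abstract does not specify the exact parameterization.

```python
import numpy as np

def softmax(scores):
    z = scores - scores.max(axis=-1, keepdims=True)  # stabilize the exponentials
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def stochastic_attention(q, k, v, n_draws, rng):
    """Single-head attention whose weights are normalized multinomial samples (assumed scheme)."""
    probs = softmax(q @ k.T / np.sqrt(q.shape[-1]))
    # One multinomial draw per query row; larger n_draws concentrates around the softmax weights.
    counts = np.stack([rng.multinomial(n_draws, p) for p in probs])
    weights = counts / n_draws
    return weights @ v, weights

rng = np.random.default_rng(0)
q = rng.normal(size=(3, 4))
k = rng.normal(size=(5, 4))
v = rng.normal(size=(5, 2))

out, weights = stochastic_attention(q, k, v, n_draws=50, rng=rng)
# Repeated stochastic passes form an ensemble without any retraining.
ensemble = np.stack([stochastic_attention(q, k, v, 50, rng)[0] for _ in range(20)])
print(ensemble.std(axis=0))
```

The spread across the ensemble is what a calibration objective would tune through the single concentration parameter.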


Generalization at the Edge of Stability

arXiv.org Machine Learning

Training modern neural networks often relies on large learning rates, operating at the edge of stability, where the optimization dynamics exhibit oscillatory and chaotic behavior. Empirically, this regime often yields improved generalization performance, yet the underlying mechanism remains poorly understood. In this work, we represent stochastic optimizers as random dynamical systems, which often converge to a fractal attractor set (rather than a point) with a smaller intrinsic dimension. Building on this connection and inspired by Lyapunov dimension theory, we introduce a novel notion of dimension, coined the 'sharpness dimension', and prove a generalization bound based on this dimension. Our results show that generalization in the chaotic regime depends on the complete Hessian spectrum and the structure of its partial determinants, highlighting a complexity that cannot be captured by the trace or spectral norm considered in prior work. Experiments across various MLPs and transformers validate our theory while also providing new insights into the recently observed phenomenon of grokking.


Deep linear networks for regression are implicitly regularized towards flat minima

Neural Information Processing Systems

The largest eigenvalue of the Hessian, or sharpness, of neural networks is a key quantity to understand their optimization dynamics. In this paper, we study the sharpness of deep linear networks for univariate regression. Minimizers can have arbitrarily large sharpness, but not an arbitrarily small one. Indeed, we show a lower bound on the sharpness of minimizers, which grows linearly with depth. We then study the properties of the minimizer found by gradient flow, which is the limit of gradient descent with vanishing learning rate.
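Since sharpness is defined here as the largest Hessian eigenvalue, it can be estimated with power iteration using only Hessian-vector products. The sketch below assumes a positive semidefinite Hessian (as at a minimizer) and uses an explicit toy matrix in place of a real network; the `sharpness` helper and its arguments are hypothetical.

```python
import numpy as np

def sharpness(hvp, dim, iters=200, seed=0):
    """Estimate the largest eigenvalue via power iteration on a Hessian-vector product."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=dim)
    v /= np.linalg.norm(v)
    for _ in range(iters):
        hv = hvp(v)
        v = hv / np.linalg.norm(hv)  # assumes hv is nonzero
    return v @ hvp(v)  # Rayleigh quotient at the converged direction

# Toy Hessian with known eigenvalues 1, 3, and 7.
H = np.diag([1.0, 3.0, 7.0])
estimate = sharpness(lambda v: H @ v, dim=3)
print(estimate)
```

For a real network, `hvp` would typically come from automatic differentiation rather than an explicit matrix, since the full Hessian is too large to form.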


Stepping on the Edge: Curvature Aware Learning Rate Tuners

Neural Information Processing Systems

Curvature information (particularly the largest eigenvalue of the loss Hessian, known as the sharpness) often forms the basis for learning rate tuners. However, recent work has shown that the curvature information undergoes complex dynamics during training, going from a phase of increasing sharpness to eventual stabilization. We analyze the closed-loop feedback effect between learning rate tuning and curvature. We find that classical learning rate tuners may yield greater one-step loss reduction, yet they ultimately underperform in the long term when compared to constant learning rates in the full batch regime. These models break the stabilization of the sharpness, which we explain using a simplified model of the joint dynamics of the learning rate and the curvature. To further investigate these effects, we introduce a new learning rate tuning method, Curvature Dynamics Aware Tuning (CDAT), which prioritizes long term curvature stabilization over instantaneous progress on the objective. In the full batch regime, CDAT shows behavior akin to prefixed warm-up schedules on deep learning objectives, outperforming tuned constant learning rates. In the mini batch regime, we observe that stochasticity introduces confounding effects that explain the previous success of some learning rate tuners at appropriate batch sizes. Our findings highlight the critical role of understanding the joint dynamics of the learning rate and curvature, beyond greedy minimization, to diagnose failures and design effective adaptive learning rate tuners.