Improving Model-Based Reinforcement Learning by Converging to Flatter Minima

Neural Information Processing Systems

Model-based reinforcement learning (MBRL) hinges on a learned dynamics model whose errors can compound along imagined rollouts. We study how encouraging flatness in the model's training loss affects downstream control, and show that steering optimization toward flatter minima yields a better policy. Concretely, we integrate Sharpness-Aware Minimization (SAM) into world-model training as a drop-in objective, leaving the planner and policy components unchanged. On the theory side, we derive PAC-Bayesian bounds that link first-order sharpness to the value-estimation gap and the performance gap between model-optimal and true-optimal policies, implying that flatter minima tighten both. Empirically, SAM reduces measured sharpness and value-prediction error and improves returns across HumanoidBench, Atari-100k, and high-DoF DeepMind Control tasks. Augmenting existing MBRL algorithms with SAM increases mean return, with especially large gains in settings with high-dimensional state-action spaces. We further observe positive transfer across algorithms and input modalities, including a transformer-based world model.
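To make the "drop-in objective" concrete, the following is a minimal sketch of the generic two-step SAM update (ascend to a nearby worst-case point, then descend using the gradient taken there). It is not the paper's implementation: the toy quadratic loss, `toy_grad`, `toy_loss`, and the values of `lr` and `rho` are all hypothetical stand-ins for a world-model loss and its hyperparameters.

```python
import math

def sam_step(params, grad_fn, lr=0.05, rho=0.05):
    """One generic SAM update on a list of floats (illustrative sketch)."""
    g = grad_fn(params)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12  # avoid division by zero
    # Step 1: move to the approximate worst-case point in an L2 ball of radius rho.
    perturbed = [p + rho * gi / norm for p, gi in zip(params, g)]
    # Step 2: descend from the ORIGINAL params using the gradient at the perturbed point.
    g_adv = grad_fn(perturbed)
    return [p - lr * ga for p, ga in zip(params, g_adv)]

# Hypothetical toy loss standing in for a world-model training loss:
# L(w) = 0.5 * (10 * w0^2 + w1^2), i.e. one sharp and one flat direction.
def toy_loss(w):
    return 0.5 * (10.0 * w[0] ** 2 + w[1] ** 2)

def toy_grad(w):
    return [10.0 * w[0], w[1]]

w = [1.0, 1.0]
initial_loss = toy_loss(w)
for _ in range(50):
    w = sam_step(w, toy_grad)
final_loss = toy_loss(w)
```

Because only the gradient computation changes, the same wrapper could in principle sit around any model-training step while the planner and policy code stay untouched, which is the sense of "drop-in" used above.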


Variational Learning Finds Flatter Solutions at the Edge of Stability

Neural Information Processing Systems

Variational Learning (VL) has recently gained popularity for training deep neural networks. Part of its empirical success can be explained by theories such as PAC-Bayes bounds, minimum description length and marginal likelihood, but little has been done to unravel the implicit regularization at play. Here, we analyze the implicit regularization of VL through the Edge of Stability (EoS) framework. EoS has previously been used to show that gradient descent can find flat solutions and we extend this result to show that VL can find even flatter solutions. This result is obtained by controlling the shape of the variational posterior as well as the number of posterior samples used during training. The derivation follows in a similar fashion as in the standard EoS literature for deep learning, by first deriving a result for a quadratic problem and then extending it to deep neural networks. We empirically validate these findings on a wide variety of large networks, such as ResNet and ViT, to find that the theoretical results closely match the empirical ones. Ours is the first work to analyze the EoS dynamics of VL.
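The quadratic starting point mentioned above can be illustrated for plain gradient descent, where the classical result is that the iterates on L(θ) = ½·h·θ² contract only if the curvature h is below 2/η for learning rate η. The sketch below checks that threshold numerically; it covers gradient descent only and does not reproduce the paper's variational result. The function name and the chosen constants are hypothetical.

```python
def gd_on_quadratic(curvature, lr, steps=100, theta0=1.0):
    """Run gradient descent on L(theta) = 0.5 * curvature * theta^2; return |theta|."""
    theta = theta0
    for _ in range(steps):
        theta -= lr * curvature * theta  # gradient of the quadratic is curvature * theta
    return abs(theta)

lr = 0.1
threshold = 2.0 / lr  # stability limit on curvature (sharpness) for this learning rate

# Slightly below the threshold the iterates shrink; slightly above they blow up.
stable = gd_on_quadratic(0.9 * threshold, lr)
unstable = gd_on_quadratic(1.1 * threshold, lr)
```

Minima sharper than 2/η are therefore unreachable by stable gradient descent, which is why a lower effective threshold translates into flatter solutions.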


Flat Channels to Infinity in Neural Loss Landscapes

Neural Information Processing Systems

The loss landscapes of neural networks contain minima and saddle points that may be connected in flat regions or appear in isolation. We identify and characterize a special structure in the loss landscape: channels along which the loss decreases extremely slowly, while the output weights of at least two neurons, $a_i$ and $a_j$, diverge to infinity, and their input weight vectors, $w_i$ and $w_j$, become equal to each other. At convergence, the two neurons implement a gated linear unit: $a_i\sigma(w_i \cdot x) + a_j\sigma(w_j \cdot x) \to c\,\sigma(w \cdot x) + (v \cdot x)\,\sigma'(w \cdot x)$. Geometrically, these channels to infinity are asymptotically parallel to symmetry-induced lines of critical points. Gradient flow solvers, and related optimization methods like SGD or Adam, reach the channels with high probability in diverse regression settings, but without careful inspection they look like flat local minima with finite parameter values. Our characterization provides a comprehensive picture of these quasi-flat regions in terms of gradient dynamics, geometry, and functional interpretation. The emergence of gated linear units at the end of the channels highlights a surprising aspect of the computational capabilities of fully connected layers.
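A small numerical check can make the limiting behaviour plausible. The sketch below reads the gated-linear-unit limit as c·σ(w·x) + (v·x)·σ′(w·x), with σ′ the derivative of the activation, and uses one hypothetical parametrisation of the channel (not taken from the paper): w_i = w + εv, w_j = w − εv, a_i = c/2 + 1/(2ε), a_j = c/2 − 1/(2ε). As ε shrinks, the output weights grow without bound, the input weights merge, and the two-neuron output approaches the gated unit. The tanh activation and all numeric values are assumptions chosen for illustration.

```python
import math

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def sigma(z):
    return math.tanh(z)  # assumed smooth activation

def sigma_prime(z):
    return 1.0 - math.tanh(z) ** 2  # derivative of tanh

def two_neuron_output(x, w, v, c, eps):
    """Two neurons with nearly equal input weights and large opposite output weights."""
    w_i = [wk + eps * vk for wk, vk in zip(w, v)]
    w_j = [wk - eps * vk for wk, vk in zip(w, v)]
    a_i = c / 2 + 1 / (2 * eps)  # diverges to +infinity as eps -> 0
    a_j = c / 2 - 1 / (2 * eps)  # diverges to -infinity as eps -> 0
    return a_i * sigma(dot(w_i, x)) + a_j * sigma(dot(w_j, x))

def gated_linear_unit(x, w, v, c):
    """Limit function: c * sigma(w.x) + (v.x) * sigma'(w.x)."""
    z = dot(w, x)
    return c * sigma(z) + dot(v, x) * sigma_prime(z)

x, w, v, c = [0.7, -0.3], [1.0, 0.5], [0.4, -1.2], 2.0
target = gated_linear_unit(x, w, v, c)
gap_far = abs(two_neuron_output(x, w, v, c, eps=0.5) - target)
gap_near = abs(two_neuron_output(x, w, v, c, eps=1e-3) - target)
```

Under this parametrisation the gap shrinks quadratically in ε, since the symmetric difference of the two activations is a central finite-difference estimate of the directional derivative.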


A Tale of Two Symmetries: Exploring the Loss Landscape of Equivariant Models

Neural Information Processing Systems

Equivariant neural networks have proven to be effective for tasks with known underlying symmetries. However, optimizing equivariant networks can be tricky and best training practices are less established than for standard networks. In particular, recent works have found small training benefits from relaxing equivariance constraints. This raises the question: do equivariance constraints introduce fundamental obstacles to optimization? Or do they simply require different hyperparameter tuning?


A Unified Stability Analysis of SAM vs SGD: Role of Data Coherence and Emergence of Simplicity Bias

Neural Information Processing Systems

Understanding the dynamics of optimization in deep learning is increasingly important as models scale. While stochastic gradient descent (SGD) and its variants reliably find solutions that generalize well, the mechanisms driving this generalization remain unclear. Notably, these algorithms often prefer flatter or simpler minima--particularly in overparameterized settings. Prior work has linked flatness to generalization, and methods like Sharpness-Aware Minimization (SAM) explicitly encourage flatness, but a unified theory connecting data structure, optimization dynamics, and the nature of learned solutions is still lacking. In this work, we develop a linear stability framework that analyzes the behavior of SGD, random perturbations, and SAM--particularly in two-layer ReLU networks. Central to our analysis is a coherence measure that quantifies how gradient curvature aligns across data points, revealing why certain minima are stable and favored during training.


Asymptotics of SGD in Sequence-Single Index Models and Single-Layer Attention Networks

Neural Information Processing Systems

We study the dynamics of stochastic gradient descent (SGD) for a class of sequence models termed Sequence Single-Index (SSI) models, where the target depends on a single direction in input space applied to a sequence of tokens. This setting generalizes classical single-index models to the sequential domain, encompassing simplified one-layer attention architectures. We derive a closed-form expression for the population loss in terms of a pair of sufficient statistics capturing semantic and positional alignment, and characterize the induced high-dimensional SGD dynamics for these coordinates. Our analysis reveals two distinct training phases, escape from uninformative initialization and alignment with the target subspace, and demonstrates how the sequence length and positional encoding influence convergence speed and learning trajectories. These results provide a rigorous and interpretable foundation for understanding how sequential structure in data can be beneficial for learning with attention-based models. Stochastic Gradient Descent (SGD) is the core optimization tool driving modern machine learning. Recent years have seen substantial progress in understanding its dynamics, particularly in two-layer networks [Saad and Solla, 1995, Mei et al., 2018, Chizat and Bach, 2018, Rotskoff and Vanden-Eijnden, 2022, Sirignano and Spiliopoulos, 2020, Arnaboldi et al., 2023a]. While global convergence is qualitatively well understood when the network is wide enough, quantitative results are scarcer. A particularly fruitful body of recent theoretical work addressing this gap has focused on deriving precise convergence rates for particular model classes on synthetic data, such as high-dimensional Gaussian single- and multi-index models [Ben Arous et al., 2021, Abbe et al., 2022, 2023].


Stable Minima of ReLU Neural Networks Suffer from the Curse of Dimensionality: The Neural Shattering Phenomenon

Neural Information Processing Systems

We study the implicit bias of flatness / low (loss) curvature and its effects on generalization in two-layer overparameterized ReLU networks with multivariate inputs, a problem well motivated by the minima stability and edge-of-stability phenomena in gradient-descent training. Existing work either requires interpolation or focuses only on univariate inputs. This paper presents new and somewhat surprising theoretical results for multivariate inputs. In two natural settings, (1) the generalization gap for flat solutions and (2) the mean-squared error (MSE) in nonparametric function estimation by stable minima, we prove upper and lower bounds, which establish that while flatness does imply generalization, the resulting rates of convergence necessarily deteriorate exponentially as the input dimension grows. This gives an exponential separation between flat solutions and low-norm solutions (i.e., weight decay), which are known not to suffer from the curse of dimensionality. In particular, our minimax lower bound construction, based on a novel packing argument with boundary-localized ReLU neurons, reveals how flat solutions can exploit a kind of neural shattering where neurons rarely activate, but with high weight magnitudes. This leads to poor performance in high dimensions. We corroborate these theoretical findings with extensive numerical simulations. To the best of our knowledge, our analysis provides the first systematic explanation for why flat minima may fail to generalize in high dimensions.


A Unified Stability Analysis of SAM vs SGD: Role of Data Coherence and Emergence of Simplicity Bias

Neural Information Processing Systems

Understanding the dynamics of optimization algorithms in deep learning has become increasingly critical, especially as models grow in scale and complexity. Despite the empirical success of stochastic gradient descent (SGD) and its variants in finding solutions that generalize well, the precise mechanisms underlying this generalization remain poorly understood. A particularly intriguing aspect of this phenomenon is the bias of optimization algorithms towards certain types of minima--often flatter or simpler--especially in overparameterized regimes. While prior works have associated flatness of the loss landscape with better generalization, tools to mechanistically connect data, optimization algorithms, and the nature of the resulting minima are still limited. For instance, methods like Sharpness-Aware Minimization (SAM) have shown practical gains by explicitly promoting flatness, but lack a unified theoretical framework explaining their influence across different data structures and model architectures. In this work, we introduce a comprehensive linear stability analysis framework to dissect the behavior of optimization algorithms--SGD, random perturbations, and SAM--in neural networks, focusing particularly on two-layer ReLU models. Our approach is built upon a novel coherence measure that captures the interaction between data geometry and gradient similarity, providing new insights into why and how certain solutions are favored.