

Improving Model-Based Reinforcement Learning by Converging to Flatter Minima

Neural Information Processing Systems

Model-based reinforcement learning (MBRL) hinges on a learned dynamics model whose errors can compound along imagined rollouts. We study how encouraging flatness in the model's training loss affects downstream control, and show that steering optimization toward flatter minima yields better policies. Concretely, we integrate Sharpness-Aware Minimization (SAM) into world-model training as a drop-in objective, leaving the planner and policy components unchanged. On the theory side, we derive PAC-Bayesian bounds that link first-order sharpness to the value-estimation gap and to the performance gap between model-optimal and true-optimal policies, implying that flatter minima tighten both. Empirically, SAM reduces measured sharpness and value-prediction error and improves returns across HumanoidBench, Atari-100k, and high-DoF DeepMind Control tasks. Augmenting existing MBRL algorithms with SAM increases mean return, with especially large gains in settings with high-dimensional state-action spaces. We further observe positive transfer across algorithms and input modalities, including a transformer-based world model.
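SAM replaces a plain gradient step on the model loss with a two-step update: perturb the parameters by rho * g / |g| toward higher loss, then descend using the gradient taken at the perturbed point (Foret et al., 2021). The sketch below applies that generic update to a hypothetical toy world model, a scalar linear dynamics model s_next = theta * s fit by mean-squared error. The data, the function names, and the values of `lr` and `rho` are illustrative assumptions, not the paper's setup.

```python
# Minimal SAM sketch on a hypothetical scalar world model s_next = theta * s.
# Data, names, lr and rho are illustrative assumptions, not the paper's setup.


def mse_and_grad(theta, states, next_states):
    """Return the MSE of the one-step prediction theta * s and its gradient."""
    n = len(states)
    residuals = [theta * s - s_next for s, s_next in zip(states, next_states)]
    loss = sum(r * r for r in residuals) / n
    grad = 2.0 * sum(r * s for r, s in zip(residuals, states)) / n
    return loss, grad


def sam_step(theta, states, next_states, lr=0.05, rho=0.01):
    """One SAM update: move to theta + eps, then descend with the gradient there."""
    _, g = mse_and_grad(theta, states, next_states)
    # eps = rho * g / |g|; in one dimension this is rho * sign(g).
    eps = rho * g / abs(g) if g != 0.0 else 0.0
    _, g_adv = mse_and_grad(theta + eps, states, next_states)
    return theta - lr * g_adv


states = [1.0, 2.0, 3.0]
next_states = [0.5, 1.0, 1.5]  # generated by a true coefficient of 0.5
theta = 0.0
for _ in range(200):
    theta = sam_step(theta, states, next_states)
print(round(theta, 2))
```

Only the training loss changes; anything that consumes `theta` downstream (a planner or policy) is untouched, which mirrors the drop-in framing above. With a fixed `rho`, the iterate in this toy problem settles in a small neighborhood of the least-squares solution rather than exactly on it.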


. We would like to point out that

Neural Information Processing Systems

We thank the reviewers for their valuable and constructive feedback. AdaReg does not explicitly force the weight matrices to be positively or negatively correlated. Therefore, our method is orthogonal to, rather than in conflict with, Dropout. Inspired by this result, we explored hyperparameter learning by empirical Bayes. Regarding BatchNorm, we do observe that a smaller batch size leads to better generalization.


Gradient Descent Converges Linearly to Flatter Minima than Gradient Flow in Shallow Linear Networks

arXiv.org Machine Learning

We study the gradient descent (GD) dynamics of a depth-2 linear neural network with a single input and output. We show that GD converges at an explicit linear rate to a global minimum of the training loss, even with a large stepsize of about $2/\textrm{sharpness}$. It still converges for even larger stepsizes, but may do so very slowly. We also characterize the solution to which GD converges, which has lower norm and sharpness than the gradient flow solution. Our analysis reveals a trade-off between the speed of convergence and the magnitude of implicit regularization. This sheds light on the benefits of training at the "Edge of Stability", which induces additional regularization by delaying convergence and may have implications for training more complex models.
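The implicit-regularization claim can be checked numerically on the simplest instance of this setting. The sketch below assumes hidden width 1 and a single training example with input 1 and target y = 1, so the loss is L(a, b) = 0.5 * (a*b - y)^2. At any global minimum the Hessian is the rank-one matrix [[b^2, a*b], [a*b, a^2]], whose largest eigenvalue (the sharpness) is a^2 + b^2, which here also equals the squared parameter norm. Gradient flow conserves a^2 - b^2, while one GD step multiplies it by (1 - lr^2 * r^2) with r = a*b - y, so a larger stepsize shrinks the imbalance and therefore the final sharpness. A tiny stepsize stands in for gradient flow; the initialization and stepsizes are illustrative choices.

```python
import math


def gd_depth2_scalar(a, b, y=1.0, lr=0.4, steps=5000):
    """Gradient descent on L(a, b) = 0.5 * (a * b - y) ** 2 (width 1, input 1)."""
    for _ in range(steps):
        r = a * b - y                          # residual
        a, b = a - lr * r * b, b - lr * r * a  # simultaneous update of both layers
    return a, b


def sharpness_at_minimum(a, b):
    """Largest Hessian eigenvalue at a global minimum (where a * b = y)."""
    return a * a + b * b


a0, b0 = 2.0, 0.1  # unbalanced initialization: a0^2 - b0^2 = 3.99
# Gradient flow keeps a^2 - b^2 fixed, so its limit has sharpness sqrt((a^2 - b^2)^2 + 4 y^2).
flow_prediction = math.sqrt((a0 ** 2 - b0 ** 2) ** 2 + 4.0)

a_flow, b_flow = gd_depth2_scalar(a0, b0, lr=1e-3, steps=20000)  # stand-in for gradient flow
a_big, b_big = gd_depth2_scalar(a0, b0, lr=0.4, steps=5000)      # stepsize comparable to 2/sharpness

print(flow_prediction, sharpness_at_minimum(a_flow, b_flow))  # nearly equal
print(sharpness_at_minimum(a_big, b_big))                     # noticeably smaller
```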


Meta Curvature-Aware Minimization for Domain Generalization

arXiv.org Artificial Intelligence

Domain generalization (DG) aims to enhance the ability of models trained on source domains to generalize effectively to unseen domains. Recently, Sharpness-Aware Minimization (SAM) has shown promise in this area by reducing the sharpness of the loss landscape to obtain more generalizable models. However, SAM and its variants sometimes fail to guide the model toward a flat minimum, and their training processes exhibit limitations that hinder further improvements in generalization. In this paper, we first propose an improved training process aimed at encouraging the model to converge to a flat minimum. To achieve this, we design a curvature metric that has a minimal effect when the model is far from convergence but becomes increasingly influential in indicating the curvature of the minimum as the model approaches a local minimum. We then derive a novel algorithm from this metric, called Meta Curvature-Aware Minimization (MeCAM), to minimize the curvature around the local minimum. Specifically, the optimization objective of MeCAM simultaneously minimizes the regular training loss, the surrogate gap of SAM, and the surrogate gap of meta-learning. We provide theoretical analysis of MeCAM's generalization error and convergence rate, and demonstrate its superiority over existing DG methods through extensive experiments on five benchmark DG datasets: PACS, VLCS, OfficeHome, TerraIncognita, and DomainNet. Code will be available on GitHub.
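One term of the MeCAM objective, the surrogate gap of SAM, is commonly defined as L(w + rho * g / |g|) - L(w), the first-order estimate of how much the loss can rise within a rho-ball. The abstract does not give MeCAM's exact formula, so the sketch below only computes this common definition on a hypothetical quadratic loss with diagonal curvature `h`; the meta-learning gap and MeCAM's curvature metric are not reproduced. For a quadratic, the gap equals rho * |g| + 0.5 * rho^2 * (curvature along the unit gradient direction), which makes its dependence on curvature explicit.

```python
import math

# Hypothetical quadratic loss L(w) = 0.5 * sum_i h_i * w_i**2 with diagonal curvature h.


def loss(w, h):
    return 0.5 * sum(hi * wi * wi for hi, wi in zip(h, w))


def grad(w, h):
    return [hi * wi for hi, wi in zip(h, w)]


def sam_surrogate_gap(w, h, rho):
    """L(w + rho * g / ||g||) - L(w): SAM's first-order estimate of the loss rise in a rho-ball."""
    g = grad(w, h)
    norm = math.sqrt(sum(gi * gi for gi in g))
    if norm == 0.0:
        return 0.0  # perturbation direction is undefined at a stationary point
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    return loss(w_adv, h) - loss(w, h)


print(sam_surrogate_gap([1.0], [2.0], rho=0.1))
print(sam_surrogate_gap([1.0, 1.0], [1.0, 3.0], rho=0.1))
```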


A Granger-Causal Perspective on Gradient Descent with Application to Pruning

arXiv.org Artificial Intelligence

Stochastic Gradient Descent (SGD) is the main approach to optimizing neural networks. Several generalization properties of deep networks, such as convergence to flatter minima, are believed to arise from SGD. This article explores the causality aspect of gradient descent. Specifically, we show that the gradient descent procedure has an implicit Granger-causal relationship between the reduction in loss and the change in parameters. With suitable modifications, we make this causal relationship explicit. A causal approach to gradient descent has many significant applications that allow greater control. In this article, we illustrate the significance of the causal approach through the application of pruning. The causal approach to pruning has several interesting properties: (i) we observe a phase shift as the percentage of pruned parameters increases. Such a phase shift is indicative of an optimal pruning strategy.


1st-Order Magic: Analysis of Sharpness-Aware Minimization

arXiv.org Artificial Intelligence

Sharpness-Aware Minimization (SAM) is an optimization technique designed to improve generalization by favoring flatter loss minima. To achieve this, SAM optimizes a modified objective that penalizes sharpness, using computationally efficient approximations. Interestingly, we find that more precise approximations of the proposed SAM objective degrade generalization performance, suggesting that the generalization benefits of SAM are rooted in these approximations rather than in the original intended mechanism. This highlights a gap in our understanding of SAM's effectiveness and calls for further investigation into the role of approximations in optimization.
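To make the distinction between SAM's approximation and its intended objective concrete, the sketch below contrasts SAM's one-step perturbation with a brute-force solution of the inner maximization max over |eps| <= rho of L(w + eps), on a hypothetical non-convex 1-D loss. The abstract does not specify which "more precise approximations" were studied, so the grid search is only a stand-in. In 1-D the first-order perturbation always lands on an endpoint of [-rho, rho], so the grid maximum can only match or exceed it; the sketch says nothing about which of the two generalizes better, which is the paper's question.

```python
import math

# Hypothetical non-convex 1-D loss, chosen only for illustration.


def loss(w):
    return math.sin(3.0 * w) + 0.5 * w * w


def grad(w):
    return 3.0 * math.cos(3.0 * w) + w


def first_order_sam_loss(w, rho):
    """SAM's approximation: evaluate L at w + rho * g / |g| (rho * sign(g) in 1-D)."""
    eps = rho * math.copysign(1.0, grad(w))
    return loss(w + eps)


def exact_sam_loss(w, rho, num=2001):
    """Brute-force max of L(w + eps) over a grid on [-rho, rho], endpoints included."""
    return max(loss(w + rho * (-1.0 + 2.0 * k / (num - 1))) for k in range(num))


for w in (-1.0, 0.0, 0.5, 1.0):
    approx, exact = first_order_sam_loss(w, 0.3), exact_sam_loss(w, 0.3)
    print(f"w={w:+.1f}  first-order={approx:.4f}  grid max={exact:.4f}")
```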


Sharpness-Aware Minimization Efficiently Selects Flatter Minima Late in Training

arXiv.org Machine Learning

Sharpness-Aware Minimization (SAM) has substantially improved the generalization of neural networks under various settings. Despite this success, its effectiveness remains poorly understood. In this work, we discover an intriguing phenomenon in the training dynamics of SAM, shedding light on its implicit bias toward flatter minima compared with Stochastic Gradient Descent (SGD). We conjecture that the optimization method chosen in the late phase of training is more crucial in shaping the final solution's properties. Based on this viewpoint, we extend our findings from SAM to Adversarial Training. We provide source code in the supplementary materials and will release checkpoints in the future. Recently, it has been observed that the generalization of neural networks is closely tied to the sharpness of the loss landscape (Keskar et al., 2017; Zhang et al., 2017; Neyshabur et al., 2017; Jiang et al., 2020). This has led to the development of many gradient-based optimization algorithms that explicitly or implicitly regularize the sharpness of solutions. In particular, Foret et al. (2021) proposed Sharpness-Aware Minimization (SAM), which has substantially improved the generalization and robustness (Zhang et al., 2024) of neural networks across many tasks, including computer vision (Foret et al., 2021; Chen et al., 2022; Kaddour et al., 2022) and natural language processing (Bahri et al., 2022). Despite the empirical success of SAM, its effectiveness is not yet fully understood. Andriushchenko & Flammarion (2022) have shown that existing theoretical justifications based on PAC-Bayes generalization bounds (Foret et al., 2021; Wu et al., 2020a) are incomplete in explaining the superior performance of SAM.
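The conjecture that the late-phase optimizer matters most suggests a simple schedule: ordinary gradient steps for most of training, then SAM for the remainder. The sketch below shows only the mechanics of such a switch on a hypothetical two-parameter quadratic; deterministic gradient steps stand in for SGD, and `switch_step`, the curvatures in `H`, `lr`, and `rho` are illustrative assumptions. A quadratic has a single minimum, so it cannot reproduce the selection of flatter minima the paper reports.

```python
import math

H = [1.0, 2.0]  # hypothetical diagonal curvatures of L(w) = 0.5 * sum(h * w**2)


def loss(w):
    return 0.5 * sum(h * wi * wi for h, wi in zip(H, w))


def grad(w):
    return [h * wi for h, wi in zip(H, w)]


def gd_step(w, lr):
    return [wi - lr * gi for wi, gi in zip(w, grad(w))]


def sam_step(w, lr, rho):
    g = grad(w)
    norm = math.sqrt(sum(gi * gi for gi in g))
    if norm == 0.0:
        return list(w)  # exactly at the minimum: nothing to perturb
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]   # ascent step
    return [wi - lr * gi for wi, gi in zip(w, grad(w_adv))]  # descend with perturbed gradient


def train_with_late_switch(w, total_steps, switch_step, lr=0.1, rho=0.05):
    """Gradient steps before switch_step, SAM steps from switch_step onward."""
    phases = []
    for t in range(total_steps):
        if t < switch_step:
            w = gd_step(w, lr)
            phases.append("gd")
        else:
            w = sam_step(w, lr, rho)
            phases.append("sam")
    return w, phases


w_final, phases = train_with_late_switch([1.0, 1.0], total_steps=200, switch_step=150)
print(phases.count("gd"), phases.count("sam"), loss(w_final))
```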


QT-DoG: Quantization-aware Training for Domain Generalization

arXiv.org Artificial Intelligence

Domain Generalization (DG) aims to train models that perform well not only on the training (source) domains but also on novel, unseen target data distributions. A key challenge in DG is preventing overfitting to source domains, which can be mitigated by finding flatter minima in the loss landscape. In this work, we propose Quantization-aware Training for Domain Generalization (QT-DoG) and demonstrate that weight quantization effectively leads to flatter minima in the loss landscape, thereby enhancing domain generalization. Unlike traditional quantization methods focused on model compression, QT-DoG exploits quantization as an implicit regularizer by inducing noise in model weights, guiding the optimization process toward flatter minima that are less sensitive to perturbations and overfitting. We provide both theoretical insights and empirical evidence demonstrating that quantization inherently encourages flatter minima, leading to better generalization across domains. Moreover, because quantization reduces model size, we demonstrate that an ensemble of multiple quantized models yields higher accuracy than state-of-the-art DG approaches with no computational or memory overhead. Our extensive experiments demonstrate that QT-DoG generalizes across various datasets, architectures, and quantization algorithms, and can be combined with other DG methods, establishing its versatility and robustness.
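The idea of quantization as weight noise can be made concrete with a fake-quantization function, the forward operation used in quantization-aware training: weights are snapped to a uniform grid but kept in floating point. The sketch below assumes symmetric, per-tensor, round-to-nearest uniform quantization; QT-DoG's actual quantizer, bit width, and schedule are not given in the abstract, so `fake_quantize`, `num_bits=4`, and the sample weights are illustrative. The induced perturbation q(w) - w is bounded by half a quantization step, which is the sense in which quantization acts like bounded noise on the weights.

```python
def fake_quantize(weights, num_bits=4):
    """Symmetric per-tensor uniform quantization, returned in float ("fake quant")."""
    qmax = 2 ** (num_bits - 1) - 1  # 4 bits: integer levels -7..7
    max_abs = max(abs(w) for w in weights)
    if max_abs == 0.0:
        return list(weights), 0.0
    scale = max_abs / qmax  # width of one quantization step
    # Round to the nearest level, clip to the representable range, map back to float.
    quantized = [max(-qmax, min(qmax, round(w / scale))) * scale for w in weights]
    return quantized, scale


weights = [0.31, -0.07, 0.92, -0.55, 0.004, -1.2]
q, scale = fake_quantize(weights, num_bits=4)
noise = [qi - wi for qi, wi in zip(q, weights)]  # perturbation injected into the weights
print(scale, max(abs(n) for n in noise))
```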