Can Stability be Detrimental? Better Generalization through Gradient Descent Instabilities

Wang, Lawrence; Roberts, Stephen J.

arXiv.org Machine Learning 

Traditional analyses of gradient descent optimization show that, when the largest eigenvalue of the loss Hessian (often referred to as the sharpness) is below a critical threshold set by the learning rate η (namely 2/η), training is 'stable' and the training loss decreases monotonically. Recent studies, however, have suggested that the majority of modern deep neural networks achieve good performance despite operating outside this stable regime. In this work, we demonstrate that such instabilities, induced by large learning rates, move model parameters toward flatter regions of the loss landscape. Our crucial insight lies in noting that, during these instabilities, the Hessian eigenvectors rotate. This, we conjecture, allows the model to explore regions of the loss landscape that display more desirable geometrical properties for generalization, such as flatness. These rotations are a consequence of network depth, and we prove that for any network with depth > 1, unstable growth in parameters causes rotations in the principal components of the Hessian, which promotes exploration of the parameter space away from unstable directions. Our empirical studies reveal an implicit regularization effect in gradient descent with large learning rates operating beyond the stability threshold, and we find that this effect leads to improved generalization performance on modern benchmark datasets.

Deep neural networks are widely successful across a number of tasks, but their generalization performance depends on careful choices of the hyperparameters that govern the learning process. Gradient descent (including stochastic gradient descent and ADAM (Kingma & Ba, 2017)) is arguably the most widely used learning algorithm due to its simplicity and versatility. For such methods, the descent lemma upper-bounds the learning rate by a quantity inversely proportional to the local curvature (or sharpness), guaranteeing stable optimization trajectories and provable decreases in the loss for convex training objectives.
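The stability threshold can be illustrated with a minimal sketch, assuming the simplest possible setting: a convex quadratic loss L(θ) = ½ Σ λ_i θ_i² whose Hessian eigenvalues λ_i are picked by hand. On such a loss, each gradient step multiplies the coordinate along eigenvalue λ by (1 − ηλ), so gradient descent is stable exactly when the sharpness max λ_i is below 2/η. The function `run_gd` and the chosen eigenvalues and learning rates are illustrative assumptions, not taken from the paper.

```python
import numpy as np


def run_gd(eigenvalues, eta, steps=50):
    """Gradient descent on L(theta) = 0.5 * sum(lam_i * theta_i**2).

    The Hessian is diag(eigenvalues), so the sharpness is max(eigenvalues).
    Returns the loss at every step, including the initial point.
    """
    lam = np.asarray(eigenvalues, dtype=float)
    theta = np.ones_like(lam)  # start at theta = (1, ..., 1)
    losses = []
    for _ in range(steps + 1):
        losses.append(0.5 * np.sum(lam * theta**2))
        grad = lam * theta          # gradient of the quadratic loss
        theta = theta - eta * grad  # each coordinate is scaled by (1 - eta * lam_i)
    return np.array(losses)


eigs = [1.0, 10.0]                 # sharpness = 10, so the threshold is eta < 2 / 10 = 0.2
stable = run_gd(eigs, eta=0.15)    # eta * sharpness = 1.5 < 2
unstable = run_gd(eigs, eta=0.25)  # eta * sharpness = 2.5 > 2

print(np.all(np.diff(stable) < 0))  # True: the loss decreases monotonically
print(unstable[-1] > unstable[0])   # True: the loss blows up
```

Crossing the threshold flips the factor along the sharpest direction from −0.5 to −1.5, so that coordinate oscillates with growing amplitude even though the flatter direction (factor 0.85 or 0.75) still converges.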
Recently, the 'unstable' learning-rate regime has been a focal point for research.
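The abstract's claim that Hessian rotations are a consequence of depth can be illustrated on a hypothetical toy problem (our own construction, not the authors' proof or experiments). A depth-1 linear model with squared loss has Hessian x xᵀ, which does not depend on the weights; a depth-2 scalar network f = w1·w2·x with L = ½(w1·w2·x − y)² and x = 1 has Hessian [[w2², 2·w1·w2 − y], [2·w1·w2 − y, w1²]], which does. The sketch grows one weight (w1 from 1 to 2, standing in for unstable parameter growth) and measures how far the top Hessian eigenvector turns. All function names and values are hypothetical.

```python
import numpy as np


def hessian_depth1(w, x, y):
    """Hessian of L(w) = 0.5 * (w @ x - y)**2 (depth 1): x x^T.

    w and y are unused because this Hessian does not depend on them.
    """
    return np.outer(x, x)


def hessian_depth2(w, y):
    """Hessian of L(w1, w2) = 0.5 * (w1 * w2 * x - y)**2 with a single input x = 1 (depth 2)."""
    w1, w2 = w
    off = 2.0 * w1 * w2 - y
    return np.array([[w2**2, off], [off, w1**2]])


def top_eigvec(H):
    """Eigenvector of the largest eigenvalue (eigh returns eigenvalues in ascending order)."""
    _, vecs = np.linalg.eigh(H)
    return vecs[:, -1]


def angle(u, v):
    """Angle in radians between two eigen-directions, ignoring their sign."""
    c = abs(np.dot(u, v)) / (np.linalg.norm(u) * np.linalg.norm(v))
    return float(np.arccos(np.clip(c, 0.0, 1.0)))


x = np.array([1.0, 2.0])
y = 1.0
before = np.array([1.0, 0.5])
after = np.array([2.0, 0.5])  # w1 has grown, e.g. during an unstable phase

rot1 = angle(top_eigvec(hessian_depth1(before, x, y)), top_eigvec(hessian_depth1(after, x, y)))
rot2 = angle(top_eigvec(hessian_depth2(before, y)), top_eigvec(hessian_depth2(after, y)))

print(rot1)              # ~0, up to floating-point rounding
print(np.degrees(rot2))  # about 14.04
```

In the depth-2 case the top eigenvector moves from (0, 1) to (1, 4)/√17, a rotation of arctan(1/4) ≈ 14°, while the sharpness rises from 1 to 4.25; in the depth-1 case the principal direction stays fixed no matter how the weights change.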