Grokking at the Edge of Numerical Stability
Lucas Prieto, Melih Barsbey, Pedro A. M. Mediano, Tolga Birdal
January 8, 2025
Grokking, or sudden generalization that occurs after prolonged overfitting, is a surprising phenomenon that has challenged our understanding of deep learning. While considerable progress has been made in understanding grokking, it is still not clear why generalization is delayed and why grokking often does not happen without regularization. In this work we argue that without regularization, grokking tasks push models to the edge of numerical stability, introducing floating point errors in the Softmax that we refer to as Softmax Collapse (SC). We show that SC prevents grokking and that mitigating SC leads to grokking without regularization. Investigating the root cause of SC, we find that beyond the point of overfitting, the gradients strongly align with what we call the naïve loss minimization (NLM) direction. This component of the gradient does not change the predictions of the model but decreases the loss by scaling the logits, usually by scaling the weights along their current direction. We show that this scaling of the logits explains the delay in generalization characteristic of grokking and eventually leads to SC, halting learning altogether. To validate these hypotheses, we introduce two key contributions that mitigate the issues faced in grokking tasks: (i) StableMax, a new activation function that prevents SC and enables grokking without regularization, and (ii) ⊥Grad, a training algorithm that leads to quick generalization in grokking tasks by preventing NLM altogether. These contributions provide new insights into grokking, shedding light on its delayed generalization, its reliance on regularization, and the effectiveness of known grokking-inducing methods. Code for this paper can be found at: https://github.com/LucasPrietoAl/

Deep learning has been transformative for a variety of fields such as natural language processing (Devlin et al., 2019), computer vision (Krizhevsky et al., 2012), geometry processing (Qi et al., 2017), and 3D vision (Deng et al., 2018). This rapid proliferation has brought with it surprising phenomena that defy the predictions of classical statistical learning theory. In this paper we explore one such recently observed phenomenon known as grokking, first described by Power et al. (2022) as a sudden and unexpected generalization occurring after prolonged overfitting.
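To make the failure mode described in the abstract concrete, the short sketch below (not the authors' code; the logits and scale values are made-up illustrative numbers) computes a float32 Softmax and cross-entropy gradient while the logits of an already correctly classified example are scaled up: the prediction never changes, the loss steadily drops, and past a certain scale the probabilities saturate so the loss and gradient vanish exactly, which is the kind of floating point error referred to above as Softmax Collapse.

```python
# Minimal numerical sketch (not the paper's code; logits and scales are
# arbitrary illustrative values). It shows two things stated in the abstract:
# (1) once a sample is already classified correctly, simply scaling the logits
#     keeps the prediction fixed while lowering the cross-entropy loss -- the
#     effect of the naive loss minimization (NLM) direction; and
# (2) in float32, large enough logits drive the Softmax output for the correct
#     class to exactly 1.0, so the loss and its gradient become exactly zero
#     and learning stops -- the floating point failure the abstract calls
#     Softmax Collapse (SC).
import numpy as np

def softmax(z):
    z = z - z.max()              # the usual max-subtraction trick
    e = np.exp(z)
    return e / e.sum()

logits = np.array([2.0, 0.0, -1.0], dtype=np.float32)  # correct class: 0
target = 0

for scale in [1.0, 4.0, 16.0, 64.0]:
    p = softmax(scale * logits)              # float32 throughout
    loss = -np.log(p[target])                # cross-entropy loss
    grad = p.copy()
    grad[target] -= 1.0                      # dloss/dlogits = p - onehot
    print(f"scale={scale:5.1f}  pred={int(p.argmax())}  "
          f"p_target={p[target]:.9f}  loss={loss:.3e}  "
          f"max|grad|={np.abs(grad).max():.3e}")

# The prediction never changes, but the loss keeps shrinking as the logits are
# scaled up. At the largest scale the correct-class probability rounds to
# exactly 1.0 and the other probabilities underflow to 0, so the printed loss
# is (signed) zero and every gradient entry is exactly zero: no further
# learning signal reaches the weights.
```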