Dichotomy of Early and Late Phase Implicit Biases Can Provably Induce Grokking

Kaifeng Lyu, Jikai Jin, Zhiyuan Li, Simon S. Du, Jason D. Lee, Wei Hu

arXiv.org Artificial Intelligence 

Recent work by Power et al. (2022) highlighted a surprising "grokking" phenomenon in learning arithmetic tasks: a neural net first "memorizes" the training set, resulting in perfect training accuracy but near-random test accuracy, and after training for sufficiently longer, it suddenly transitions to perfect test accuracy. This paper studies the grokking phenomenon in theoretical setups and shows that it can be induced by a dichotomy of early and late phase implicit biases. Specifically, when training homogeneous neural nets with large initialization and small weight decay on both classification and regression tasks, we prove that the training process gets trapped at a solution corresponding to a kernel predictor for a long time, and then a very sharp transition to min-norm/max-margin predictors occurs, leading to a dramatic change in test accuracy.

The generalization behavior of modern over-parameterized neural nets has been puzzling: these nets have the capacity to overfit the training set, and yet they frequently exhibit a small gap between training and test performance when trained by popular gradient-based optimizers. A common view now is that the network architectures and training pipelines can automatically induce regularization effects to avoid or mitigate overfitting throughout the training trajectory. Recently, Power et al. (2022) discovered an even more perplexing generalization phenomenon called grokking: when training a neural net to learn modular arithmetic operations, it first "memorizes" the training set with zero training error and near-random test error, and then training for much longer leads to a sharp transition from no generalization to perfect generalization. See Section 2 for our reproduction of this phenomenon. Beyond modular arithmetic, grokking has been reported in learning group operations (Chughtai et al., 2023), learning sparse parity (Barak et al., 2022; Bhattamishra et al., 2022), learning greatest common divisor (Charton, 2023), and image classification (Liu et al., 2023; Radhakrishnan et al., 2022).
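To make the analyzed regime concrete, the sketch below (not the authors' code) trains a homogeneous two-layer ReLU net without biases using a large initialization scale and small weight decay, logging train and test accuracy over time to look for a delayed generalization transition. The task (3-sparse parity, one of the settings where grokking has been reported), the scale factor `alpha`, the width, learning rate, and weight decay are illustrative assumptions, not values from the paper.

```python
# Minimal sketch of the "large init + small weight decay" regime described above.
# All hyperparameters are placeholders, chosen only to illustrate the setup.
import torch
import torch.nn as nn

torch.manual_seed(0)

# Synthetic task: 3-sparse parity over +/-1 inputs (labels in {-1, +1}).
d, k, n_train, n_test = 30, 3, 300, 2000
X = torch.randint(0, 2, (n_train + n_test, d)).float() * 2 - 1
y = X[:, :k].prod(dim=1)
X_tr, y_tr, X_te, y_te = X[:n_train], y[:n_train], X[n_train:], y[n_train:]

# Two-layer ReLU net with no biases: a homogeneous model, as in the paper's setting.
width, alpha = 512, 8.0  # alpha: large-initialization scale factor (illustrative)
model = nn.Sequential(
    nn.Linear(d, width, bias=False),
    nn.ReLU(),
    nn.Linear(width, 1, bias=False),
)
with torch.no_grad():
    for p in model.parameters():
        p.mul_(alpha)  # scale up the standard init

# Small weight decay on top of plain gradient descent (full batch).
opt = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=1e-4)

def accuracy(X, y):
    with torch.no_grad():
        return ((model(X).squeeze(-1) * y) > 0).float().mean().item()

for step in range(1, 100_001):
    opt.zero_grad()
    margin = model(X_tr).squeeze(-1) * y_tr
    loss = torch.nn.functional.softplus(-margin).mean()  # logistic loss
    loss.backward()
    opt.step()
    if step % 5000 == 0:
        print(f"step {step:7d}  train acc {accuracy(X_tr, y_tr):.3f}  "
              f"test acc {accuracy(X_te, y_te):.3f}")
```

Under the paper's story, the early phase of such a run behaves like a kernel predictor (training accuracy reaches 1.0 while test accuracy stays near chance), and the small weight decay eventually drives a sharp late-phase transition toward a max-margin solution with much better test accuracy; the exact timing depends on the illustrative hyperparameters above.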