Gromov, Andrey
PARQ: Piecewise-Affine Regularized Quantization
Jin, Lisa, Ma, Jianhao, Liu, Zechun, Gromov, Andrey, Defazio, Aaron, Xiao, Lin
Modern deep learning models exhibit exceptional vision and language processing capabilities, but come with excessive sizes and heavy demands on memory and computing. Quantization is an effective approach to model compression that can significantly reduce memory footprint, computing cost, and inference latency (e.g., Han et al., 2016; Sze et al., 2017). There are two main classes of quantization methods: post-training quantization (PTQ) and quantization-aware training (QAT). Both are widely adopted and extensively studied; see the recent surveys (Gholami et al., 2022; Fournarakis et al., 2022) and references therein. PTQ converts the weights of a pre-trained model directly into lower precision without repeating the training pipeline; it thus has less overhead and is relatively easy to apply (Nagel et al., 2020; Cai et al., 2020; Chee et al., 2024). However, it is mainly limited to regimes of 4 or more bits and can suffer steep performance drops with fewer bits (Yao et al., 2022; Dettmers & Zettlemoyer, 2023). This is especially the case for transformer-based models, which prove harder to quantize (Bai et al., 2021; Qin et al., 2022) than convolutional architectures (Martinez et al., 2019; Qin et al., 2020). On the other hand, QAT integrates quantization into the pre-training and/or fine-tuning process and can produce low-bit (especially binary) models with mild performance degradation.
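The PTQ/QAT distinction is easiest to see against a plain uniform-quantization baseline. Below is a minimal NumPy sketch of symmetric per-tensor quantization (a generic baseline, not the PARQ method itself); the bit widths and tensor shape are illustrative.

```python
import numpy as np

def uniform_quantize(w: np.ndarray, bits: int = 4):
    """Symmetric uniform quantization of a weight tensor to `bits` bits.

    Returns integer codes and the scale needed to dequantize.
    Generic PTQ-style baseline, not the PARQ method.
    """
    qmax = 2 ** (bits - 1) - 1           # e.g. 7 for signed 4-bit codes
    scale = np.abs(w).max() / qmax       # one scale per tensor
    codes = np.clip(np.round(w / scale), -qmax, qmax).astype(np.int8)
    return codes, scale

def dequantize(codes: np.ndarray, scale: float) -> np.ndarray:
    return codes.astype(np.float32) * scale

# Quantization error grows quickly below 4 bits.
rng = np.random.default_rng(0)
w = rng.normal(size=(256, 256)).astype(np.float32)
for bits in (8, 4, 2):
    codes, scale = uniform_quantize(w, bits)
    err = np.abs(dequantize(codes, scale) - w).mean()
    print(f"{bits}-bit mean abs error: {err:.4f}")
```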
Spectral Journey: How Transformers Predict the Shortest Path
Cohen, Andrew, Gromov, Andrey, Yang, Kaiyu, Tian, Yuandong
Decoder-only transformers have led to a step change in the capability of large language models. However, opinions are mixed as to whether they are really planning or reasoning. One path to progress is to study the model's behavior on carefully controlled data, then interpret the learned representations and reverse-engineer the computation performed internally. We study decoder-only transformer language models trained from scratch to predict shortest paths on simple, connected and undirected graphs. In this setting, the representations and the dynamics learned by the model are interpretable. We present three major results: (1) Two-layer decoder-only language models can learn to predict shortest paths on simple, connected graphs containing up to 10 nodes. (2) Models learn a graph embedding that is correlated with the spectral decomposition of the line graph. (3) Following these insights, we discover a novel approximate path-finding algorithm, Spectral Line Navigator (SLN), that finds shortest paths by greedily selecting nodes in the spectral embedding space of the line graph.
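As a rough illustration of the greedy idea behind SLN (this is our own reading of the abstract, not the authors' implementation; the embedding dimension, distance, and target construction are assumptions), one can embed the line graph with its low-frequency Laplacian eigenvectors and walk greedily toward the target in that space.

```python
import numpy as np
import networkx as nx

def spectral_line_embedding(G: nx.Graph, dim: int = 4) -> dict:
    """Embed each edge of G using low-frequency Laplacian eigenvectors of the
    line graph L(G). Illustrative only; not the authors' SLN code."""
    LG = nx.line_graph(G)                            # nodes of L(G) are edges of G
    nodes = list(LG.nodes())
    Lap = nx.laplacian_matrix(LG, nodelist=nodes).toarray().astype(float)
    _, vecs = np.linalg.eigh(Lap)                    # eigenvectors, ascending eigenvalues
    coords = vecs[:, 1:dim + 1]                      # skip the constant eigenvector
    return {frozenset(e): coords[i] for i, e in enumerate(nodes)}

def greedy_path(G: nx.Graph, src, dst, emb: dict, max_steps: int = 50):
    """Greedy walk: at each node, take the incident edge whose embedding is
    closest (Euclidean) to the embeddings of edges incident to the target."""
    target = np.mean([emb[frozenset((dst, v))] for v in G.neighbors(dst)], axis=0)
    path, cur = [src], src
    for _ in range(max_steps):
        if cur == dst:
            return path
        nxt = min(G.neighbors(cur),
                  key=lambda v: np.linalg.norm(emb[frozenset((cur, v))] - target))
        path.append(nxt)
        cur = nxt
    return path  # may fail to reach dst; the heuristic is approximate

G = nx.erdos_renyi_graph(10, 0.4, seed=0)
if nx.is_connected(G):
    emb = spectral_line_embedding(G)
    print(greedy_path(G, 0, 9, emb), nx.shortest_path(G, 0, 9))
```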
Towards an Improved Understanding and Utilization of Maximum Manifold Capacity Representations
Schaeffer, Rylan, Lecomte, Victor, Pai, Dhruv Bhandarkar, Carranza, Andres, Isik, Berivan, Unell, Alyssa, Khona, Mikail, Yerxa, Thomas, LeCun, Yann, Chung, SueYeon, Gromov, Andrey, Shwartz-Ziv, Ravid, Koyejo, Sanmi
Maximum Manifold Capacity Representations (MMCR) is a recent multi-view self-supervised learning (MVSSL) method that matches or surpasses other leading MVSSL methods. MMCR is intriguing because it does not fit neatly into any of the commonplace MVSSL lineages, instead originating from a statistical mechanical perspective on the linear separability of data manifolds. In this paper, we seek to improve our understanding and our utilization of MMCR. To better understand MMCR, we leverage tools from high dimensional probability to demonstrate that MMCR incentivizes alignment and uniformity of learned embeddings. We then leverage tools from information theory to show that such embeddings maximize a well-known lower bound on mutual information between views, thereby connecting the geometric perspective of MMCR to the information-theoretic perspective commonly discussed in MVSSL. To better utilize MMCR, we mathematically predict and experimentally confirm non-monotonic changes in the pretraining loss akin to double descent but with respect to atypical hyperparameters. We also discover compute scaling laws that enable predicting the pretraining loss as a function of gradient steps, batch size, embedding dimension, and number of views. We then show that MMCR, originally applied to image data, is performant on multimodal image-text data. By more deeply understanding the theoretical and empirical behavior of MMCR, our work reveals insights into improving MVSSL methods.
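For concreteness, here is a short PyTorch sketch of an MMCR-style objective as we understand it (maximize the nuclear norm of the matrix of per-sample view centroids of unit-normalized embeddings); normalization and scaling conventions may differ from the reference implementation.

```python
import torch

def mmcr_loss(z: torch.Tensor) -> torch.Tensor:
    """MMCR-style objective (our reading of the method; conventions may differ).

    z: embeddings of shape (batch, views, dim). Each view embedding is projected
    to the unit sphere, per-sample centroids are averaged across views, and the
    loss is the negative nuclear norm of the centroid matrix, so minimizing it
    maximizes manifold capacity.
    """
    z = torch.nn.functional.normalize(z, dim=-1)   # unit-norm view embeddings
    centroids = z.mean(dim=1)                      # (batch, dim) per-sample centroids
    return -torch.linalg.matrix_norm(centroids, ord="nuc")

# Toy usage: random embeddings for 128 samples, 2 views, 64-dim space.
z = torch.randn(128, 2, 64, requires_grad=True)
loss = mmcr_loss(z)
loss.backward()
print(loss.item(), z.grad.shape)
```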
Grokking Modular Polynomials
Doshi, Darshil, He, Tianyu, Das, Aritra, Gromov, Andrey
Neural networks readily learn a subset of modular arithmetic tasks while failing to generalize on the rest. This limitation persists regardless of the choice of architecture and training strategy. On the other hand, an analytical solution for the weights of Multi-layer Perceptron (MLP) networks that generalize on the modular addition task is known in the literature. In this work, we (i) extend the class of analytical solutions to include modular multiplication as well as modular addition with many terms, and (ii) show that real networks trained on these datasets learn similar solutions upon generalization (grokking).
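Two elementary identities make such analytical solutions plausible, sketched numerically below: periodic (Fourier) features single out $z = (x + y) \;\mathrm{mod}\; p$, and modular multiplication reduces to modular addition of discrete logarithms via a primitive root. This is our illustration; the paper's exact construction may differ.

```python
import numpy as np

p = 97  # prime modulus (illustrative)

def fourier_score(x: int, y: int, z: int) -> float:
    """Sum of cosines over all frequencies; equals p iff z == (x + y) mod p, else ~0.
    This identity underlies the analytic 'periodic feature' solutions."""
    k = np.arange(p)
    return np.cos(2 * np.pi * k * (x + y - z) / p).sum()

x, y = 13, 57
scores = np.array([fourier_score(x, y, z) for z in range(p)])
assert scores.argmax() == (x + y) % p

# Modular multiplication maps onto modular addition of discrete logs via a
# primitive root g (one plausible route to the multiplication solutions).
g = 5  # 5 is a primitive root mod 97
dlog = {pow(g, e, p): e for e in range(p - 1)}   # discrete log table on Z_p^*
assert (x * y) % p == pow(g, (dlog[x] + dlog[y]) % (p - 1), p)
print("checks passed")
```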
Learning to grok: Emergence of in-context learning and skill composition in modular arithmetic tasks
He, Tianyu, Doshi, Darshil, Das, Aritra, Gromov, Andrey
Large language models can solve tasks that were not present in the training set. This capability is believed to be due to in-context learning and skill composition. In this work, we study the emergence of in-context learning and skill composition in a collection of modular arithmetic tasks. Specifically, we consider a finite collection of linear modular functions $z = a \, x + b \, y \;\mathrm{mod}\; p$ labeled by the vector $(a, b) \in \mathbb{Z}_p^2$. We use some of these tasks for pre-training and the rest for out-of-distribution testing. We empirically show that a GPT-style transformer exhibits a transition from in-distribution to out-of-distribution generalization as the number of pre-training tasks increases. We find that the smallest model capable of out-of-distribution generalization requires two transformer blocks, while for deeper models, the out-of-distribution generalization phase is \emph{transient}, necessitating early stopping. Finally, we perform an interpretability study of the pre-trained models, revealing highly structured representations in both phases, and we discuss the learnt algorithm.
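A minimal sketch of how such in-context sequences and the task split can be generated (the modulus, the number of in-context examples, and the 50/50 split are illustrative choices; tokenization details are omitted):

```python
import numpy as np

p = 29  # modulus (illustrative; the paper's settings may differ)
rng = np.random.default_rng(0)

def sample_sequence(a: int, b: int, n_examples: int = 8) -> list[int]:
    """One in-context sequence of (x, y, z) triples for the task
    z = a*x + b*y mod p. The task vector (a, b) is never shown explicitly,
    so the model must infer it from the in-context examples."""
    seq = []
    for _ in range(n_examples):
        x, y = rng.integers(0, p, size=2)
        seq += [int(x), int(y), int((a * x + b * y) % p)]
    return seq

# Split the task family Z_p^2 into pre-training and out-of-distribution tasks.
tasks = [(a, b) for a in range(p) for b in range(p)]
perm = rng.permutation(len(tasks))
split = len(tasks) // 2
pretrain_tasks = [tasks[i] for i in perm[:split]]
ood_tasks = [tasks[i] for i in perm[split:]]

a, b = pretrain_tasks[0]
print(sample_sequence(a, b)[:9])  # first three (x, y, z) triples of one sequence
```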
Is Model Collapse Inevitable? Breaking the Curse of Recursion by Accumulating Real and Synthetic Data
Gerstgrasser, Matthias, Schaeffer, Rylan, Dey, Apratim, Rafailov, Rafael, Sleight, Henry, Hughes, John, Korbak, Tomasz, Agrawal, Rajashree, Pai, Dhruv, Gromov, Andrey, Roberts, Daniel A., Yang, Diyi, Donoho, David L., Koyejo, Sanmi
The proliferation of generative models, combined with pretraining on web-scale data, raises a timely question: what happens when these models are trained on their own generated outputs? Recent investigations into model-data feedback loops discovered that such loops can lead to model collapse, a phenomenon where performance progressively degrades with each model-fitting iteration until the latest model becomes useless. However, several recent papers studying model collapse assumed that new data replace old data over time, rather than that data accumulate. In this paper, we compare these two settings and show that accumulating data prevents model collapse. We begin by studying an analytically tractable setup in which a sequence of linear models is fit to the previous models' predictions. Previous work showed that if data are replaced, the test error increases linearly with the number of model-fitting iterations; we extend this result by proving that if data instead accumulate, the test error has a finite upper bound independent of the number of iterations. We next empirically test whether accumulating data similarly prevents model collapse by pretraining sequences of language models on text corpora. We confirm that replacing data does indeed cause model collapse, then demonstrate that accumulating data prevents it; these results hold across a range of model sizes, architectures, and hyperparameters. We further show that similar results hold for other deep generative models on real data: diffusion models for molecule generation and variational autoencoders for image generation. Our work provides consistent theoretical and empirical evidence that data accumulation mitigates model collapse.
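A toy simulation of the analytically tractable linear setting described above (the dimensions, noise level, and iteration count are arbitrary choices, not the paper's experimental settings) shows the qualitative gap between replacing and accumulating data:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, iters, sigma = 20, 200, 30, 0.5
w_true = rng.normal(size=d)

def fit(X, y):
    """Ordinary least squares via the pseudo-inverse."""
    return np.linalg.pinv(X) @ y

def simulate(accumulate: bool):
    """Linear-regression feedback loop (illustrative only): each generation fits
    a model to labels produced by the previous generation's model plus fresh
    noise (the true process at generation 0), either replacing the data or
    accumulating it."""
    X_all, y_all = None, None
    w = w_true
    errors = []
    for _ in range(iters):
        X = rng.normal(size=(n, d))
        y = X @ w + sigma * rng.normal(size=n)   # labels from the previous model
        if accumulate and X_all is not None:
            X_all, y_all = np.vstack([X_all, X]), np.concatenate([y_all, y])
        else:
            X_all, y_all = X, y
        w = fit(X_all, y_all)
        errors.append(float(np.mean((X_test @ w - X_test @ w_true) ** 2)))
    return errors

X_test = rng.normal(size=(1000, d))
print("replace, final test error:   ", round(simulate(accumulate=False)[-1], 3))
print("accumulate, final test error:", round(simulate(accumulate=True)[-1], 3))
```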
The Unreasonable Ineffectiveness of the Deeper Layers
Gromov, Andrey, Tirumala, Kushal, Shapourian, Hassan, Glorioso, Paolo, Roberts, Daniel A.
We empirically study a simple layer-pruning strategy for popular families of open-weight pretrained LLMs, finding minimal degradation of performance on different question-answering benchmarks until after a large fraction (up to half) of the layers are removed. To prune these models, we identify the optimal block of layers to prune by considering similarity across layers; then, to "heal" the damage, we perform a small amount of finetuning. In particular, we use parameter-efficient finetuning (PEFT) methods, specifically quantization and Low Rank Adapters (QLoRA), such that each of our experiments can be performed on a single A100 GPU. From a practical perspective, these results suggest that layer-pruning methods can complement other PEFT strategies to further reduce the computational resources required for finetuning on the one hand, and can improve the memory footprint and latency of inference on the other. From a scientific perspective, the robustness of these LLMs to the deletion of layers implies either that current pretraining methods are not properly leveraging the parameters in the deeper layers of the network or that the shallow layers play a critical role in storing knowledge.
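A hedged sketch of the similarity-based layer-selection step (the exact distance, aggregation over tokens, and healing procedure in the paper may differ; the random tensors below are placeholders rather than real LLM activations):

```python
import torch

def best_block_to_prune(hidden_states: list, n: int) -> int:
    """Pick the start index of the n-layer block whose removal should hurt least,
    scored by mean angular distance between the representations entering layer l
    and layer l + n. Sketch of the similarity-based criterion, not the paper's code.

    hidden_states: per-layer activations of shape (tokens, dim), e.g. flattened
    outputs from a model run with output_hidden_states=True.
    """
    scores = []
    for l in range(len(hidden_states) - n):
        a, b = hidden_states[l], hidden_states[l + n]
        cos = torch.nn.functional.cosine_similarity(a, b, dim=-1).mean()
        scores.append(torch.arccos(cos.clamp(-1.0, 1.0)))  # mean angular distance
    return int(torch.stack(scores).argmin())               # most redundant block

# Toy usage with random activations standing in for a real LLM's hidden states.
torch.manual_seed(0)
fake_hidden = [torch.randn(512, 64) for _ in range(33)]    # 32 layers + embedding output
n = 8
start = best_block_to_prune(fake_hidden, n=n)
print(f"prune layers {start}..{start + n - 1}, then heal with a short QLoRA finetune")
```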
Bridging Associative Memory and Probabilistic Modeling
Schaeffer, Rylan, Zahedi, Nika, Khona, Mikail, Pai, Dhruv, Truong, Sang, Du, Yilun, Ostrow, Mitchell, Chandra, Sarthak, Carranza, Andres, Fiete, Ila Rani, Gromov, Andrey, Koyejo, Sanmi
Associative memory and probabilistic modeling are two fundamental topics in artificial intelligence. The first studies recurrent neural networks designed to denoise, complete and retrieve data, whereas the second studies learning and sampling from probability distributions. Based on the observation that associative memory's energy functions can be seen as probabilistic modeling's negative log likelihoods, we build a bridge between the two that enables useful flow of ideas in both directions. We showcase four examples: First, we propose new energy-based models that flexibly adapt their energy functions to new in-context datasets, an approach we term \textit{in-context learning of energy functions}. Second, we propose two new associative memory models: one that dynamically creates new memories as necessitated by the training data using Bayesian nonparametrics, and another that explicitly computes proportional memory assignments using the evidence lower bound. Third, using tools from associative memory, we analytically and numerically characterize the memory capacity of Gaussian kernel density estimators, a widespread tool in probabilistic modeling. Fourth, we study a widespread implementation choice in transformers -- normalization followed by self-attention -- to show that it performs clustering on the hypersphere. Altogether, this work urges further exchange of useful ideas between these two continents of artificial intelligence.
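To make the bridge concrete, a small PyTorch sketch treats a Gaussian KDE's negative log density as an associative-memory energy and retrieves a stored pattern by gradient descent on it (the bandwidth, dimensions, and retrieval schedule are illustrative choices, not the paper's):

```python
import torch

def kde_energy(x: torch.Tensor, memories: torch.Tensor, bandwidth: float = 0.5):
    """Energy = negative log density of a Gaussian kernel density estimator.
    Reading the KDE's negative log likelihood as an associative-memory energy
    is the bridge described in the abstract (this sketch is illustrative)."""
    sq_dists = ((x - memories) ** 2).sum(dim=-1)     # squared distance to each memory
    return -torch.logsumexp(-sq_dists / (2 * bandwidth ** 2), dim=-1)

def retrieve(query: torch.Tensor, memories: torch.Tensor, steps: int = 200, lr: float = 0.1):
    """Denoise/complete a query by gradient descent on the energy landscape."""
    x = query.clone().requires_grad_(True)
    for _ in range(steps):
        energy = kde_energy(x, memories)
        (grad,) = torch.autograd.grad(energy, x)
        x = (x - lr * grad).detach().requires_grad_(True)
    return x.detach()

torch.manual_seed(0)
memories = torch.randn(5, 16)                  # stored patterns
noisy = memories[2] + 0.3 * torch.randn(16)    # corrupted version of pattern 2
recovered = retrieve(noisy, memories)
print(torch.norm(recovered - memories[2]).item())  # should be small
```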
To grok or not to grok: Disentangling generalization and memorization on corrupted algorithmic datasets
Doshi, Darshil, Das, Aritra, He, Tianyu, Gromov, Andrey
Robust generalization is a major challenge in deep learning, particularly when the number of trainable parameters is very large. In general, it is very difficult to know whether the network has memorized a particular set of examples or understood the underlying rule (or both). Motivated by this challenge, we study an interpretable model where generalizing representations are understood analytically, and are easily distinguishable from the memorizing ones. Namely, we consider two-layer neural networks trained on modular arithmetic tasks where ($\xi \cdot 100\%$) of the labels are corrupted (i.e., the corresponding results of the modular operation are incorrect). We show that (i) it is possible for the network to memorize the corrupted labels and achieve 100% generalization at the same time; (ii) the memorizing neurons can be identified and pruned, lowering the accuracy on corrupted data and improving the accuracy on uncorrupted data; (iii) regularization methods such as weight decay, dropout and BatchNorm force the network to ignore the corrupted data during optimization, and achieve 100% accuracy on the uncorrupted dataset; and (iv) the effect of these regularization methods is ("mechanistically") interpretable: weight decay and dropout force all the neurons to learn generalizing representations, while BatchNorm de-amplifies the output of memorizing neurons and amplifies the output of the generalizing ones. Finally, we show that in the presence of regularization, the training dynamics involves two consecutive stages: first, the network undergoes the grokking dynamics, reaching high train and test accuracy; second, it unlearns the memorizing representations, at which point the train accuracy suddenly jumps from 100% to $100(1-\xi)\%$.

The astounding progress of deep learning in the last decade has been facilitated by massive, high-quality datasets. Annotated real-world datasets inevitably contain noisy labels, due to biases of annotation schemes (Paolacci et al., 2010; Cothey, 2004) or inherent ambiguity (Beyer et al., 2020). A key challenge in training large models is to prevent overfitting the noisy data and attain robust generalization performance. On the other hand, in large models, it is possible for memorization and generalization to coexist (Zhang et al., 2017; 2021). By and large, the tussle between memorization and generalization, especially in the presence of label corruption, remains poorly understood. In generative language models the problem of memorization is even more nuanced. On the one hand, some factual knowledge is critical for the language models to produce accurate information.
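A minimal sketch of the corrupted-label setup (the modulus, corruption fraction, and corruption protocol below are illustrative assumptions); note that the true rule scores exactly $100(1-\xi)\%$ on the corrupted training set, matching the accuracy jump described above.

```python
import numpy as np

p, xi = 97, 0.2          # modulus and corruption fraction (illustrative values)
rng = np.random.default_rng(0)

# Full modular-addition dataset: inputs (x, y), labels (x + y) mod p.
pairs = np.array([(x, y) for x in range(p) for y in range(p)])
labels = (pairs[:, 0] + pairs[:, 1]) % p

# Corrupt xi * 100% of the labels by shifting them to random wrong values.
n_corrupt = int(xi * len(labels))
corrupt_idx = rng.choice(len(labels), size=n_corrupt, replace=False)
offsets = rng.integers(1, p, size=n_corrupt)          # nonzero shift => guaranteed wrong
labels_corrupted = labels.copy()
labels_corrupted[corrupt_idx] = (labels[corrupt_idx] + offsets) % p

print(f"{n_corrupt} of {len(labels)} labels corrupted")
print(f"accuracy of the true rule on the corrupted training set: "
      f"{(labels_corrupted == labels).mean():.0%}")   # = 100 * (1 - xi) percent
```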
Grokking modular arithmetic
Gromov, Andrey
We present a simple neural network that can learn modular arithmetic tasks and exhibits a sudden jump in generalization known as ``grokking''. Concretely, we present (i) fully-connected two-layer networks that exhibit grokking on various modular arithmetic tasks under vanilla gradient descent with the MSE loss function in the absence of any regularization; (ii) evidence that grokking modular arithmetic corresponds to learning specific feature maps whose structure is determined by the task; (iii) analytic expressions for the weights -- and thus for the feature maps -- that solve a large class of modular arithmetic tasks; and (iv) evidence that these feature maps are also found by vanilla gradient descent as well as AdamW, thereby establishing complete interpretability of the representations learnt by the network.
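Below is a numerical check, in the spirit of result (iii), that hand-written periodic weights in a two-layer network with quadratic activation compute modular addition; the phase and normalization conventions here are ours and may differ from the paper's analytic expressions.

```python
import numpy as np

p, N = 97, 4000                      # modulus and number of hidden neurons
rng = np.random.default_rng(0)
f = rng.integers(1, p, size=N)       # one frequency per neuron
phi, chi = rng.uniform(0, 2 * np.pi, size=(2, N))

def forward(x: int, y: int) -> np.ndarray:
    """Two-layer net with quadratic activation and hand-written periodic weights.
    The logit for class q peaks at q = (x + y) mod p once the non-signal terms
    average out over many neurons."""
    a = np.cos(2 * np.pi * f * x / p + phi)          # first-layer response to x
    b = np.cos(2 * np.pi * f * y / p + chi)          # first-layer response to y
    h = (a + b) ** 2                                 # quadratic activation
    q = np.arange(p)
    W_out = np.cos(2 * np.pi * np.outer(q, f) / p + phi + chi)  # (p, N) readout
    return W_out @ h

x, y = 13, 57
assert int(np.argmax(forward(x, y))) == (x + y) % p
print("analytic periodic weights compute modular addition")
```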