Nolte, Niklas
MagicPIG: LSH Sampling for Efficient LLM Generation
Chen, Zhuoming, Sadhukhan, Ranajoy, Ye, Zihao, Zhou, Yang, Zhang, Jianyu, Nolte, Niklas, Tian, Yuandong, Douze, Matthijs, Bottou, Leon, Jia, Zhihao, Chen, Beidi
Large language models (LLMs) with long context windows have gained significant attention. However, the KV cache, stored to avoid re-computation, becomes a bottleneck. Various dynamic sparse or TopK-based attention approximation methods have been proposed to leverage the common insight that attention is sparse. In this paper, we first show that TopK attention itself suffers from quality degradation in certain downstream tasks because attention is not always as sparse as expected. Rather than selecting the keys and values with the highest attention scores, sampling with theoretical guarantees can provide a better estimation for attention output. To make the sampling-based approximation practical in LLM generation, we propose MagicPIG, a heterogeneous system based on Locality Sensitive Hashing (LSH). MagicPIG significantly reduces the workload of attention computation while preserving high accuracy for diverse tasks. MagicPIG stores the LSH hash tables and runs the attention computation on the CPU, which allows it to serve longer contexts and larger batch sizes with high approximation accuracy. MagicPIG can improve decoding throughput by up to $5\times$ across various GPU hardware and achieve 54ms decoding latency on a single RTX 4090 for the Llama-3.1-8B-Instruct model with a context of 96k tokens. The code is available at https://github.com/Infini-AI-Lab/MagicPIG.
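To make the idea of LSH-based key retrieval concrete, the following is a minimal sketch of sign-random-projection hashing (SimHash), one common LSH family for angular similarity. It is an illustration only: the function names, the parameters `n_bits` and `n_tables`, and the rule "keep any key that collides with the query in at least one table" are assumptions made here, not the MagicPIG implementation, and the sketch omits the sampling-probability correction that a sampling-based estimator would need.

```python
import random

def make_hyperplanes(dim, n_bits, n_tables, seed=0):
    # One set of n_bits random Gaussian hyperplanes per hash table.
    rng = random.Random(seed)
    return [[[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_bits)]
            for _ in range(n_tables)]

def simhash(vec, planes):
    # One bit per hyperplane: which side of the hyperplane the vector lies on.
    return tuple(sum(p * v for p, v in zip(plane, vec)) >= 0.0 for plane in planes)

def build_tables(keys, hyperplanes):
    # Each table maps a hash code to the indices of the keys that produced it.
    tables = []
    for planes in hyperplanes:
        table = {}
        for idx, key in enumerate(keys):
            table.setdefault(simhash(key, planes), []).append(idx)
        tables.append(table)
    return tables

def retrieve(query, tables, hyperplanes):
    # Collect every key index that shares a bucket with the query in any table.
    hits = set()
    for table, planes in zip(tables, hyperplanes):
        hits.update(table.get(simhash(query, planes), []))
    return sorted(hits)

# Toy data (hypothetical): key 1 points in the opposite direction to key 0.
keys = [[1.0, 2.0, -1.0, 0.5], [-1.0, -2.0, 1.0, -0.5], [0.9, 2.1, -1.2, 0.4]]
planes = make_hyperplanes(dim=4, n_bits=4, n_tables=8)
tables = build_tables(keys, planes)
candidates = retrieve(keys[0], tables, planes)
```

Attention would then be computed only over `candidates` rather than over all keys. Keys at a small angle to the query collide often, while the opposite-direction key flips every bit and is not retrieved.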
Transformers Can Navigate Mazes With Multi-Step Prediction
Nolte, Niklas, Kitouni, Ouail, Williams, Adina, Rabbat, Mike, Ibrahim, Mark
Despite their remarkable success in language modeling, transformers trained to predict the next token in a sequence struggle with long-term planning. This limitation is particularly evident in tasks requiring foresight to plan multiple steps ahead such as maze navigation. The standard next single token prediction objective, however, offers no explicit mechanism to predict multiple steps ahead - or revisit the path taken so far. Consequently, in this work we study whether explicitly predicting multiple steps ahead (and backwards) can improve transformers' maze navigation. We train parameter-matched transformers from scratch, under identical settings, to navigate mazes of varying types and sizes with standard next token prediction and MLM-U, an objective explicitly predicting multiple steps ahead and backwards. We find that MLM-U considerably improves transformers' ability to navigate mazes compared to standard next token prediction across maze types and complexities. We also find MLM-U training is 4x more sample efficient and converges 2x faster in terms of GPU training hours relative to next token training. Finally, for more complex mazes we find MLM-U benefits from scaling to larger transformers. Remarkably, we find transformers trained with MLM-U outperform larger transformers trained with next token prediction using additional supervision from A* search traces. We hope these findings underscore the promise of learning objectives to advance transformers' capacity for long-term planning. The code can be found at https://github.com/facebookresearch/maze_navigation_MLMU
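As a rough illustration of a masking objective that asks the model to predict tokens both ahead of and behind the visible context, the sketch below masks a sequence at a rate drawn uniformly at random. It assumes that the "U" in MLM-U refers to such a uniformly sampled masking rate; the mask token, the function name, and the toy maze path are hypothetical and are not taken from the linked repository.

```python
import random

MASK = "<mask>"  # hypothetical mask token

def mask_uniform_rate(tokens, rng):
    # Assumption: one masking rate per sequence, drawn uniformly from [0, 1).
    rate = rng.random()
    inputs, targets = [], []
    for tok in tokens:
        if rng.random() < rate:
            inputs.append(MASK)   # hidden from the model
            targets.append(tok)   # the model is trained to recover it
        else:
            inputs.append(tok)    # visible context
            targets.append(None)  # no loss at this position
    return inputs, targets

# Toy maze path (hypothetical encoding of cells as tokens).
path = ["start", "(0,0)", "(0,1)", "(1,1)", "goal"]
inputs, targets = mask_uniform_rate(path, random.Random(0))
```

Because masked positions can fall anywhere in the path, a single training example can require predicting several future steps as well as earlier ones, unlike next-token prediction.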
The Factorization Curse: Which Tokens You Predict Underlie the Reversal Curse and More
Kitouni, Ouail, Nolte, Niklas, Bouchacourt, Diane, Williams, Adina, Rabbat, Mike, Ibrahim, Mark
Today's best language models still struggle with hallucinations: factually incorrect generations, which impede their ability to reliably retrieve information seen during training. The reversal curse, where models cannot recall information when probed in a different order than was encountered during training, exemplifies this in information retrieval. We reframe the reversal curse as a factorization curse - a failure of models to learn the same joint distribution under different factorizations. Through a series of controlled experiments with increasing levels of realism including WikiReversal, a setting we introduce to closely simulate a knowledge intensive finetuning task, we find that the factorization curse is an inherent failure of the next-token prediction objective used in popular large language models. Moreover, we demonstrate reliable information retrieval cannot be solved with scale, reversed tokens, or even naive bidirectional-attention training. Consequently, various approaches to finetuning on specialized data would necessarily provide mixed results on downstream tasks, unless the model has already seen the right sequence of tokens. Across five tasks of varying levels of complexity, our results uncover a promising path forward: factorization-agnostic objectives can significantly mitigate the reversal curse and hint at improved knowledge storage and planning capabilities.
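The reframing rests on the chain rule: every ordering of the tokens factorizes the same joint distribution. As a reminder, in notation assumed here rather than taken from the paper ($x_1,\dots,x_T$ for the tokens and $\sigma$ for any permutation of the positions):

```latex
p(x_1,\dots,x_T)
  = \prod_{t=1}^{T} p\bigl(x_t \mid x_1,\dots,x_{t-1}\bigr)
  = \prod_{t=1}^{T} p\bigl(x_{\sigma(t)} \mid x_{\sigma(1)},\dots,x_{\sigma(t-1)}\bigr)
```

A model trained only on the left-to-right product is never asked to match the right-hand side for any other $\sigma$, which is one way to read the failure the abstract describes.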
From Neurons to Neutrons: A Case Study in Interpretability
Kitouni, Ouail, Nolte, Niklas, Pérez-Díaz, Víctor Samuel, Trifinopoulos, Sokratis, Williams, Mike
Mechanistic Interpretability (MI) promises a path toward fully understanding how neural networks make their predictions. Prior work demonstrates that even when trained to perform simple arithmetic, models can implement a variety of algorithms (sometimes concurrently) depending on initialization and hyperparameters. Does this mean neuron-level interpretability techniques have limited applicability? We argue that high-dimensional neural networks can learn low-dimensional representations of their training data that are useful beyond simply making good predictions. Such representations can be understood through the mechanistic interpretability lens and provide insights that are surprisingly faithful to human-derived domain knowledge. This indicates that such approaches to interpretability can be useful for deriving a new understanding of a problem from models trained to solve it. As a case study, we extract nuclear physics concepts by studying models trained to reproduce nuclear data.
Memory Mosaics
Zhang, Jianyu, Nolte, Niklas, Sadhukhan, Ranajoy, Chen, Beidi, Bottou, Léon
This paper presents a learning system architecture, Memory Mosaics, in which multiple associative memories work in concert to carry out a prediction task of interest. Such systems are closely related to memory networks [Weston et al., 2014, Sukhbaatar et al., 2015] and resemble transformers [Vaswani et al., 2017] despite significant differences. Like transformers, Memory Mosaics possess some of the disentanglement and compositional capabilities that have long eluded machine learning systems [Lake and Baroni, 2018]. Unlike transformers, whose internal mechanisms are hard to decipher [Olsson et al., 2022, Bietti et al., 2024], Memory Mosaics achieve these capabilities in comparatively transparent ways. The three main contributions of this work are (a) recognizing and exploiting a similarity between smoothing associative memories and self-attention, (b) identifying and illustrating the predictive disentanglement principle, which explains how training decomposes the overall task in interesting ways, and (c) showing that this comparatively transparent architecture matches the performance of decoding transformers on a language modeling task. Section 2 describes the basic architecture and outlines its consequences. Section 3 illustrates the predictive disentanglement principle. Section 4 extends these ideas to fully formed memory mosaics.
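The similarity between smoothing associative memories and self-attention in contribution (a) can be illustrated with a small kernel-smoothing sketch. It assumes a Gaussian kernel over stored keys, which yields softmax-style weights; the function name, the bandwidth parameter `beta`, and the toy data are illustrative and are not taken from the paper.

```python
import math

def associative_recall(query, keys, values, beta=1.0):
    # Score each stored key by negative squared distance to the query.
    scores = [-beta * sum((q - k) ** 2 for q, k in zip(query, key)) for key in keys]
    # Softmax over the scores (shifted by the maximum for numerical stability).
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    # Return the weighted average of the stored values.
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[10.0], [20.0]]
sharp = associative_recall([1.0, 0.0], keys, values, beta=50.0)  # close to [10.0]
blend = associative_recall([0.5, 0.5], keys, values, beta=1.0)   # midpoint of the two values
```

When all keys have the same norm, the squared-distance scores differ from scaled dot products only by a constant that the softmax cancels, so the retrieval has the same form as an attention head.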
Transforming the Bootstrap: Using Transformers to Compute Scattering Amplitudes in Planar N = 4 Super Yang-Mills Theory
Cai, Tianji, Merz, Garrett W., Charton, François, Nolte, Niklas, Wilhelm, Matthias, Cranmer, Kyle, Dixon, Lance J.
We pursue the use of deep learning methods to improve state-of-the-art computations in theoretical high-energy physics. Planar N = 4 Super Yang-Mills theory is a close cousin to the theory that describes Higgs boson production at the Large Hadron Collider; its scattering amplitudes are large mathematical expressions containing integer coefficients. In this paper, we apply Transformers to predict these coefficients. The problem can be formulated in a language-like representation amenable to standard cross-entropy training objectives. We design two related experiments and show that the model achieves high accuracy (> 98%) on both tasks. Our work shows that Transformers can be applied successfully to problems in theoretical physics that require exact solutions.
DiSK: A Diffusion Model for Structured Knowledge
Kitouni, Ouail, Nolte, Niklas, Hensman, James, Mitra, Bhaskar
Structured (dictionary-like) data presents challenges for left-to-right language models, as they can struggle with structured entities for a wide variety of reasons such as formatting and sensitivity to the order in which attributes are presented. Tabular generative models suffer from a different set of limitations such as their lack of flexibility. We introduce Diffusion Models of Structured Knowledge (DiSK) - a new architecture and training approach specialized for structured data. DiSK handles text, categorical, and continuous numerical data using a Gaussian mixture model approach, which allows for improved precision when dealing with numbers. It employs diffusion training to model relationships between properties. Experiments demonstrate DiSK's state-of-the-art performance on tabular data modeling, synthesis, and imputation on over 15 datasets across diverse domains. DiSK provides an effective inductive bias for generative modeling and manipulation of structured data. The techniques we propose could open the door to improved knowledge manipulation in future language models.
Salsa Fresca: Angular Embeddings and Pre-Training for ML Attacks on Learning With Errors
Stevens, Samuel, Wenger, Emily, Li, Cathy, Nolte, Niklas, Saxena, Eshika, Charton, François, Lauter, Kristin
Learning with Errors (LWE) is a hard math problem underlying recently standardized post-quantum cryptography (PQC) systems for key exchange and digital signatures. Prior work proposed new machine learning (ML)-based attacks on LWE problems with small, sparse secrets, but these attacks require millions of LWE samples to train on and take days to recover secrets. We propose three key methods -- better preprocessing, angular embeddings and model pre-training -- to improve these attacks, speeding up preprocessing by $25\times$ and improving model sample efficiency by $10\times$. We demonstrate for the first time that pre-training improves and reduces the cost of ML attacks on LWE. Our architecture improvements enable scaling to larger-dimension LWE problems: this work is the first instance of ML attacks recovering sparse binary secrets in dimension $n=1024$, the smallest dimension used in practice for homomorphic encryption applications of LWE where sparse binary secrets are proposed.
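For readers unfamiliar with the problem, an LWE sample is a pair $(a, b)$ with $b = \langle a, s \rangle + e \bmod q$, where $s$ is the secret and $e$ a small error. The sketch below generates such samples for a sparse binary secret. All parameters (dimension 32, modulus 251, Hamming weight 4, error width 3.0) are toy values chosen for illustration, far smaller than the $n=1024$ setting above, and the helper names are hypothetical.

```python
import random

def sparse_binary_secret(n, hamming_weight, rng):
    # Binary secret with exactly `hamming_weight` ones at random positions.
    secret = [0] * n
    for i in rng.sample(range(n), hamming_weight):
        secret[i] = 1
    return secret

def lwe_sample(secret, q, sigma, rng):
    # a is uniform mod q; e is a rounded Gaussian error of width sigma.
    a = [rng.randrange(q) for _ in secret]
    e = round(rng.gauss(0.0, sigma))
    b = (sum(ai * si for ai, si in zip(a, secret)) + e) % q
    return a, b

def centered(x, q):
    # Map x mod q to its representative nearest zero, in (-q/2, q/2].
    x %= q
    return x - q if x > q // 2 else x

rng = random.Random(0)
q, n = 251, 32  # toy parameters, not a practical setting
secret = sparse_binary_secret(n, hamming_weight=4, rng=rng)
samples = [lwe_sample(secret, q, sigma=3.0, rng=rng) for _ in range(100)]

# With the secret known, b - <a, s> recovers the small error term.
residuals = [centered(b - sum(ai * si for ai, si in zip(a, secret)), q)
             for a, b in samples]
```

An attacker sees only the pairs in `samples`; the task is to recover `secret` from them, which is easy to verify (small residuals) but hard to do without the secret.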
NuCLR: Nuclear Co-Learned Representations
Kitouni, Ouail, Nolte, Niklas, Trifinopoulos, Sokratis, Kantamneni, Subhash, Williams, Mike
We introduce Nuclear Co-Learned Representations (NuCLR), a deep learning model that predicts various nuclear observables, including binding and decay energies, and nuclear charge radii. The model is trained using a multi-task approach with shared representations and obtains state-of-the-art performance, achieving levels of precision that are crucial for understanding fundamental phenomena in nuclear (astro)physics. We also report an intriguing finding that the learned representation of NuCLR exhibits the prominent emergence of crucial aspects of the nuclear shell model, namely the shell structure, including the well-known magic numbers, and the Pauli Exclusion Principle. This suggests that the model is capable of capturing the underlying physical principles and that our approach has the potential to offer valuable insights into nuclear theory.
Expressive Monotonic Neural Networks
Kitouni, Ouail, Nolte, Niklas, Williams, Michael
The monotonic dependence of the outputs of a neural network on some of its inputs is a crucial inductive bias in many scenarios where domain knowledge dictates such behavior. This is especially important for interpretability and fairness considerations. In a broader context, scenarios in which monotonicity is important can be found in finance, medicine, physics, and other disciplines. It is thus desirable to build neural network architectures that implement this inductive bias provably. In this work, we propose a weight-constrained architecture with a single residual connection to achieve exact monotonic dependence in any subset of the inputs. The weight constraint scheme directly controls the Lipschitz constant of the neural network and thus provides the additional benefit of robustness. Compared to currently existing techniques used for monotonicity, our method is simpler in implementation and in theory foundations, has negligible computational overhead, is guaranteed to produce monotonic dependence, and is highly expressive. We show how the algorithm is used to train powerful, robust, and interpretable discriminators that achieve competitive performance compared to current state-of-the-art methods across various benchmarks, from social applications to the classification of the decays of subatomic particles produced at the CERN Large Hadron Collider.
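The construction can be sketched as follows: if a network $g$ is Lipschitz with constant $\lambda$ with respect to the L1 norm of its input, every partial derivative of $g$ is bounded by $\lambda$ in magnitude, so $g(x) + \lambda \sum_{i \in I} x_i$ is non-decreasing in each input $i \in I$. The code below is a minimal one-hidden-layer illustration of that argument. The specific normalization, the ReLU activation, and all names are assumptions made for brevity and are not claimed to match the paper's architecture.

```python
import random

def constrain(W1, w2, lam):
    # Rescale so that sum_j |w2_j| * max_i |W1_ji| <= lam, which bounds the
    # Lipschitz constant of g (w.r.t. the L1 input norm) by lam.
    W1c = []
    for row in W1:
        scale = max(1.0, max(abs(w) for w in row))
        W1c.append([w / scale for w in row])      # every entry now in [-1, 1]
    scale = max(1.0, sum(abs(w) for w in w2) / lam)
    w2c = [w / scale for w in w2]                 # L1 norm now at most lam
    return W1c, w2c

def monotonic_net(x, W1, b1, w2, lam, monotone_idx):
    # g: one hidden ReLU layer (ReLU is 1-Lipschitz) followed by a linear readout.
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)
              for row, b in zip(W1, b1)]
    g = sum(w * h for w, h in zip(w2, hidden))
    # Residual term: lam times the sum of the inputs that must be monotonic.
    return g + lam * sum(x[i] for i in monotone_idx)

# Random toy weights (hypothetical), constrained with lam = 1.0.
rng = random.Random(0)
W1 = [[rng.uniform(-3, 3) for _ in range(2)] for _ in range(8)]
b1 = [rng.uniform(-1, 1) for _ in range(8)]
w2 = [rng.uniform(-3, 3) for _ in range(8)]
W1, w2 = constrain(W1, w2, lam=1.0)

y_low = monotonic_net([0.0, 1.0], W1, b1, w2, 1.0, monotone_idx=[0])
y_high = monotonic_net([2.0, 1.0], W1, b1, w2, 1.0, monotone_idx=[0])
```

Increasing input 0 from 0.0 to 2.0 can lower $g$ by at most $2\lambda$, while the residual term rises by exactly $2\lambda$, so `y_high` cannot fall below `y_low` (up to floating-point rounding). Input 1 is left unconstrained.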