Kakade, Sham
Pareto Frontiers in Neural Feature Learning: Data, Compute, Width, and Luck
Edelman, Benjamin L., Goel, Surbhi, Kakade, Sham, Malach, Eran, Zhang, Cyril
In modern deep learning, algorithmic choices (such as width, depth, and learning rate) are known to modulate nuanced resource tradeoffs. This work investigates how these complexities necessarily arise for feature learning in the presence of computational-statistical gaps. We begin by considering offline sparse parity learning, a supervised classification problem which admits a statistical query lower bound for gradient-based training of a multilayer perceptron. This lower bound can be interpreted as a multi-resource tradeoff frontier: successful learning can only occur if one is sufficiently rich (large model), knowledgeable (large dataset), patient (many training iterations), or lucky (many random guesses). We show, theoretically and experimentally, that sparse initialization and increasing network width yield significant improvements in sample efficiency in this setting. Here, width plays the role of parallel search: it amplifies the probability of finding "lottery ticket" neurons, which learn sparse features more sample-efficiently. Finally, we show that the synthetic sparse parity task can be useful as a proxy for real problems requiring axis-aligned feature learning. We demonstrate improved sample efficiency on tabular classification benchmarks by using wide, sparsely-initialized MLP models; these networks sometimes outperform tuned random forests.
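The "parallel search" role of width can be made concrete with a simplified combinatorial model. This is our assumption, not the paper's analysis. Call a neuron a "lottery ticket" if its random s-sparse initial support contains the hidden k-subset S of the n inputs. Assume each neuron draws its support uniformly and independently. Then a single neuron succeeds with probability C(n-k, s-k)/C(n, s), and a layer of width r contains at least one ticket with probability 1 - (1 - p)^r. The names `ticket_prob` and `any_ticket_prob` are hypothetical.

```python
from math import comb


def ticket_prob(n: int, k: int, s: int) -> float:
    """P(a neuron's uniformly random s-sparse support contains a fixed k-subset S of n inputs)."""
    if s > n:
        raise ValueError("support size s cannot exceed input dimension n")
    if s < k:
        return 0.0  # a support smaller than S cannot contain it
    # Supports containing S: choose the remaining s - k coordinates from the other n - k.
    return comb(n - k, s - k) / comb(n, s)


def any_ticket_prob(n: int, k: int, s: int, width: int) -> float:
    """P(at least one of `width` independently initialized neurons is a 'lottery ticket')."""
    p = ticket_prob(n, k, s)
    return 1.0 - (1.0 - p) ** width


# Example: how the chance of containing a ticket neuron grows with width (n=50, k=3, s=6).
n, k, s = 50, 3, 6
for width in (10, 100, 1000, 10000):
    print(f"width={width:>6}: P(some neuron covers S) = {any_ticket_prob(n, k, s, width):.4f}")
```

The toy model captures only the search aspect. It says nothing about how quickly a ticket neuron then learns the parity, which is where sparsity buys sample efficiency in the paper.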
Learning an Inventory Control Policy with General Inventory Arrival Dynamics
Andaz, Sohrab, Eisenach, Carson, Madeka, Dhruv, Torkkola, Kari, Jia, Randy, Foster, Dean, Kakade, Sham
In this paper we address the problem of learning and backtesting inventory control policies in the presence of general arrival dynamics, which we term a quantity-over-time arrivals model (QOT). We also allow order quantities to be modified as a post-processing step to meet vendor constraints such as order minimums and batch sizes, a common practice in real supply chains. To the best of our knowledge, this is the first work to handle either arbitrary arrival dynamics or arbitrary downstream post-processing of order quantities. Building upon recent work (Madeka et al., 2022), we similarly formulate the periodic review inventory control problem as an exogenous decision process, where most of the state is outside the control of the agent. Madeka et al. (2022) show how to construct a simulator that replays historic data to solve this class of problems. In our case, we incorporate a deep generative model for the arrivals process as part of the history replay. By formulating the problem as an exogenous decision process, we can apply results from Madeka et al. (2022) to obtain a reduction to supervised learning. Finally, we show via simulation studies that this approach yields statistically significant improvements in profitability over production baselines. Using data from an ongoing real-world A/B test, we show that Gen-QOT generalizes well to off-policy data.
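Two of the mechanics above, vendor post-processing and quantity-over-time arrivals, can be sketched in a few lines. This is a simplified illustration under stated assumptions. The post-processing rule (raise to the order minimum, then round up to a batch multiple) is a hypothetical choice. A fixed arrival profile stands in for the paper's deep generative arrivals model. The function names `postprocess_order` and `qot_arrivals` are hypothetical.

```python
import math
from typing import List, Sequence


def postprocess_order(q: float, min_qty: int, batch: int) -> int:
    """Hypothetical vendor rule: enforce an order minimum, then round up to a batch multiple."""
    if q <= 0:
        return 0  # no order placed this period
    q = max(q, min_qty)
    return math.ceil(q / batch) * batch


def qot_arrivals(orders: Sequence[float], profile: Sequence[float]) -> List[float]:
    """Quantity-over-time arrivals with a fixed profile (a stand-in for a learned arrival model).

    profile[j] is the fraction of an order that arrives j periods after it is placed, so
    arrivals[t] = sum_j orders[t - j] * profile[j], a discrete convolution.
    """
    arrivals = [0.0] * (len(orders) + len(profile) - 1)
    for t, q in enumerate(orders):
        for j, frac in enumerate(profile):
            arrivals[t + j] += q * frac
    return arrivals


# Example: three periods of raw policy outputs, min order 10, batch size 4,
# with each order arriving half one period later and half two periods later.
orders = [postprocess_order(q, min_qty=10, batch=4) for q in (7.0, 0.0, 13.0)]
arrivals = qot_arrivals(orders, profile=[0.0, 0.5, 0.5])
print(orders, arrivals)
```

Because the policy only chooses `q` and everything downstream is applied mechanically, the simulator can replay history with any post-processing rule or arrival model plugged in.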
AdANNS: A Framework for Adaptive Semantic Search
Rege, Aniket, Kusupati, Aditya, S, Sharan Ranjit, Fan, Alan, Cao, Qingqing, Kakade, Sham, Jain, Prateek, Farhadi, Ali
Web-scale search systems learn an encoder to embed a given query, which is then hooked into an approximate nearest neighbor search (ANNS) pipeline to retrieve similar data points. To accurately capture tail queries and data points, learned representations are typically rigid, high-dimensional vectors that are used as-is throughout the entire ANNS pipeline, which can make retrieval computationally expensive. In this paper, we argue that instead of rigid representations, different stages of ANNS can leverage adaptive representations of varying capacities to achieve significantly better accuracy-compute trade-offs, i.e., stages of ANNS that can get away with more approximate computation should use a lower-capacity representation of the same data point. To this end, we introduce AdANNS, a novel ANNS design framework that explicitly leverages the flexibility of Matryoshka Representations. We demonstrate state-of-the-art accuracy-compute trade-offs using novel AdANNS-based key ANNS building blocks like search data structures (AdANNS-IVF) and quantization (AdANNS-OPQ). For example, on ImageNet retrieval, AdANNS-IVF is up to 1.5% more accurate than the rigid representations-based IVF at the same compute budget, and matches accuracy while being up to 90x faster in wall-clock time. For Natural Questions, 32-byte AdANNS-OPQ matches the accuracy of the 64-byte OPQ baseline constructed using rigid representations: the same accuracy at half the cost. We further show that the gains from AdANNS translate to modern-day composite ANNS indices that combine search structures and quantization. Finally, we demonstrate that AdANNS can enable inference-time adaptivity for compute-aware search on ANNS indices built non-adaptively on Matryoshka representations. Code is open-sourced at https://github.com/RAIVNLab/AdANNS.
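The core idea, using a lower-capacity prefix of a nested embedding where approximation is tolerable, can be sketched as a two-stage search. This is a simpler relative of AdANNS, not AdANNS-IVF or AdANNS-OPQ. We assume Matryoshka-style embeddings, so that the first `d_coarse` dimensions form a usable low-capacity representation. A cheap prefix scan builds a shortlist, and full-dimensional distances rerank only that shortlist. The function name `adaptive_search` is hypothetical.

```python
import numpy as np


def adaptive_search(db: np.ndarray, query: np.ndarray, d_coarse: int,
                    shortlist: int, top_k: int = 1) -> np.ndarray:
    """Two-stage L2 search over nested (prefix-truncatable) embeddings.

    Stage 1: rank every point by squared distance on the first d_coarse dims (cheap).
    Stage 2: rerank only the `shortlist` best candidates with all dimensions.
    """
    shortlist = min(shortlist, len(db))
    coarse = np.sum((db[:, :d_coarse] - query[:d_coarse]) ** 2, axis=1)
    # argpartition places the `shortlist` smallest coarse distances first (unordered).
    cand = np.argpartition(coarse, shortlist - 1)[:shortlist]
    fine = np.sum((db[cand] - query) ** 2, axis=1)
    return cand[np.argsort(fine)[:top_k]]


# Example with random vectors standing in for learned Matryoshka embeddings.
rng = np.random.default_rng(0)
db = rng.standard_normal((1000, 64))
query = db[42] + 0.01 * rng.standard_normal(64)
print(adaptive_search(db, query, d_coarse=8, shortlist=50, top_k=3))
```

The shortlist size trades compute for recall. A larger `d_coarse` or `shortlist` makes stage 1 less likely to drop the true neighbor, at a higher cost.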
MatFormer: Nested Transformer for Elastic Inference
Devvrit, Kudugunta, Sneha, Kusupati, Aditya, Dettmers, Tim, Chen, Kaifeng, Dhillon, Inderjit, Tsvetkov, Yulia, Hajishirzi, Hannaneh, Kakade, Sham, Farhadi, Ali, Jain, Prateek
Transformer models are deployed in a wide range of settings, from multi-accelerator clusters to standalone mobile phones. The diverse inference constraints in these scenarios require practitioners to train foundation models such as PaLM 2, Llama, & ViTs as a series of models of varying sizes. Due to significant training costs, only a select few model sizes are trained and supported, limiting fine-grained control over relevant tradeoffs, including latency, cost, and accuracy. This work introduces MatFormer, a nested Transformer architecture designed to offer elasticity across a variety of deployment constraints. Each Feed Forward Network (FFN) block of a MatFormer model is jointly optimized with a few nested smaller FFN blocks. This training procedure allows for the Mix'n'Match of model granularities across layers, i.e., a trained universal MatFormer model enables extraction of hundreds of accurate smaller models that were never explicitly optimized. We empirically demonstrate MatFormer's effectiveness across different model classes (decoders & encoders), modalities (language & vision), and scales (up to 2.6B parameters). We find that a 2.6B decoder-only MatFormer language model (MatLM) allows us to extract smaller models spanning from 1.5B to 2.6B, each exhibiting comparable validation loss and one-shot downstream evaluations to their independently trained counterparts. Furthermore, we observe that smaller encoders extracted from a universal MatFormer-based ViT (MatViT) encoder preserve the metric-space structure for adaptive large-scale retrieval. Finally, we showcase that speculative decoding with the accurate and consistent submodels extracted from MatFormer can further reduce inference latency.
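The nesting can be pictured as prefix slicing of the FFN hidden dimension: a sub-block of width m uses the first m hidden neurons of the full block. The sketch below shows only extraction and Mix'n'Match at inference time. It omits the joint training that makes the nested sub-blocks accurate, and it leaves out attention, normalization, and biases. `NestedFFN` and `mix_n_match` are hypothetical names, not the MatFormer codebase.

```python
import numpy as np


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


class NestedFFN:
    """A bias-free ReLU FFN whose first m hidden units form a smaller nested FFN."""

    def __init__(self, d_model: int, d_hidden: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.w1 = rng.standard_normal((d_model, d_hidden)) / np.sqrt(d_model)
        self.w2 = rng.standard_normal((d_hidden, d_model)) / np.sqrt(d_hidden)

    def __call__(self, x: np.ndarray, m: int) -> np.ndarray:
        # Use only the first m hidden neurons: a prefix slice of both weight matrices.
        return relu(x @ self.w1[:, :m]) @ self.w2[:m, :]


def mix_n_match(layers, x: np.ndarray, widths) -> np.ndarray:
    """Run a stack of nested FFN blocks with residual connections, one width per layer."""
    for layer, m in zip(layers, widths):
        x = x + layer(x, m)
    return x


# Example: four layers of hidden width 64; extract a model that uses a different width per layer.
d_model, d_hidden = 16, 64
layers = [NestedFFN(d_model, d_hidden, seed=i) for i in range(4)]
x = np.random.default_rng(123).standard_normal((2, d_model))
y = mix_n_match(layers, x, widths=[64, 16, 32, 64])
print(y.shape)
```

Because each width is a prefix of the next, the per-layer choices multiply: with a few granularities per layer across many layers, the number of extractable combinations grows quickly, which is consistent with the "hundreds of smaller models" described above.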
On Provable Copyright Protection for Generative Models
Vyas, Nikhil, Kakade, Sham, Barak, Boaz
There is a growing concern that learned conditional generative models may output samples that are substantially similar to some copyrighted data $C$ that was in their training set. We give a formal definition of $\textit{near access-freeness (NAF)}$ and prove bounds on the probability that a model satisfying this definition outputs a sample similar to $C$, even if $C$ is included in its training set. Roughly speaking, a generative model $p$ is $\textit{$k$-NAF}$ if for every potentially copyrighted data $C$, the output of $p$ diverges by at most $k$ bits from the output of a model $q$ that $\textit{did not access $C$ at all}$. We also give learning algorithms that efficiently modify the original generative model learning algorithm in a black-box manner and output generative models with strong bounds on the probability of sampling protected content. Furthermore, we provide promising experiments for both language (transformers) and image (diffusion) generative models, showing minimal degradation in output quality while ensuring strong protections against sampling protected content.
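The k-NAF condition can be checked directly for small discrete distributions. The sketch makes several assumptions. It takes the divergence to be the max-divergence, $\max_y \log_2 p(y)/q(y)$; the abstract does not say which divergence is used, and KL is another natural choice. It also ignores prompts (conditioning) and represents each "safe" model $q_C$ as an explicit distribution keyed by the protected item $C$. Under the max-divergence, k-NAF implies $p(E) \le 2^k q_C(E)$ for every set of outputs $E$, which is the kind of bound that limits sampling content similar to $C$. `max_divergence_bits` and `is_k_naf` are hypothetical names.

```python
import math
from typing import Dict, Hashable

Dist = Dict[Hashable, float]  # output -> probability


def max_divergence_bits(p: Dist, q: Dist) -> float:
    """max over outputs y with p(y) > 0 of log2(p(y) / q(y)); infinite if q(y) = 0 there."""
    worst = -math.inf
    for y, py in p.items():
        if py == 0.0:
            continue
        qy = q.get(y, 0.0)
        if qy == 0.0:
            return math.inf  # p puts mass where the safe model puts none
        worst = max(worst, math.log2(py / qy))
    return worst


def is_k_naf(p: Dist, safe_models: Dict[Hashable, Dist], k: float) -> bool:
    """p is k-NAF (max-divergence version) if it is within k bits of q_C for every protected C."""
    return all(max_divergence_bits(p, q) <= k for q in safe_models.values())


# Example: a two-output model compared with safe models for two protected items.
p = {"a": 0.5, "b": 0.5}
safe = {"C1": {"a": 0.25, "b": 0.75}, "C2": {"a": 0.5, "b": 0.5}}
print(max_divergence_bits(p, safe["C1"]), is_k_naf(p, safe, k=1.0))
```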
Scaling Laws for Imitation Learning in NetHack
Tuyls, Jens, Madeka, Dhruv, Torkkola, Kari, Foster, Dean, Narasimhan, Karthik, Kakade, Sham
Imitation Learning (IL) is one of the most widely used methods in machine learning. Yet, while powerful, many works find that it often fails to fully recover the underlying expert behavior [1-3]. However, none of these works deeply investigate the role of scaling up the model and data size. Inspired by recent work in Natural Language Processing (NLP) [4, 5] where "scaling up" has resulted in increasingly more capable LLMs, we investigate whether carefully scaling up model and data size can bring similar improvements in the imitation learning setting. To demonstrate our findings, we focus on the game of NetHack, a challenging environment featuring procedural generation, stochasticity, long-term dependencies, and partial observability. We find IL loss and mean return scale smoothly with the compute budget and are strongly correlated, resulting in power laws for training compute-optimal IL agents with respect to model size and number of samples. We forecast and train several NetHack agents with IL and find they outperform prior state-of-the-art by at least 2x in all settings. Our work demonstrates both the scaling behavior of imitation learning in a challenging domain and the viability of scaling up current approaches for increasingly capable agents in NetHack, a game that remains elusively hard for current AI systems.
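A power law $L(C) = a\,C^{b}$ becomes a straight line in log-log space, $\log L = \log a + b \log C$, so it can be fitted with ordinary least squares and then extrapolated to a larger budget. The sketch below fits such a law to synthetic, noiseless numbers chosen for illustration only. These numbers are not data or results from the paper, and the paper's actual fitting procedure may differ. `fit_power_law` and `predict` are hypothetical names.

```python
import numpy as np


def fit_power_law(compute: np.ndarray, value: np.ndarray):
    """Fit value ~= a * compute**b by least squares on (log compute, log value); returns (a, b)."""
    b, log_a = np.polyfit(np.log(compute), np.log(value), deg=1)  # slope first, then intercept
    return float(np.exp(log_a)), float(b)


def predict(a: float, b: float, compute):
    """Evaluate the fitted power law at a (possibly larger) compute budget."""
    return a * np.asarray(compute) ** b


# Synthetic, noiseless example: "loss" follows 3.0 * C^-0.05 over budgets 1e15 .. 1e20.
compute = np.logspace(15, 20, 6)
loss = 3.0 * compute ** -0.05
a, b = fit_power_law(compute, loss)
print(a, b, predict(a, b, 1e21))
```

Fitting in log space weights relative rather than absolute errors, which suits quantities that span many orders of magnitude.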
Beyond Implicit Bias: The Insignificance of SGD Noise in Online Learning
Vyas, Nikhil, Morwani, Depen, Zhao, Rosie, Kaplun, Gal, Kakade, Sham, Barak, Boaz
The success of SGD in deep learning has been ascribed by prior works to the implicit bias induced by high learning rate or small batch size ("SGD noise"). While prior works focused on offline learning (i.e., multiple-epoch training), we study the impact of SGD noise on online (i.e., single-epoch) learning. Through an extensive empirical analysis of image and language data, we demonstrate that large learning rate and small batch size do not confer any implicit bias advantages in online learning. In contrast to offline learning, the benefits of SGD noise in online learning are strictly computational, facilitating larger or more cost-effective gradient steps. Our work suggests that SGD in the online regime can be construed as taking noisy steps along the "golden path" of the noiseless gradient flow algorithm. We provide evidence to support this hypothesis by conducting experiments that reduce SGD noise during training and by measuring the pointwise functional distance between models trained with varying SGD noise levels, but at equivalent loss values. Our findings challenge the prevailing understanding of SGD and offer novel insights into its role in online learning.
Modified Gauss-Newton Algorithms under Noise
Pillutla, Krishna, Roulet, Vincent, Kakade, Sham, Harchaoui, Zaid
The Gauss-Newton method and its variants such as the Levenberg-Marquardt method [15, 16] have been applied successfully in phase retrieval [5, 11, 20], nonlinear control [22, 24], and non-negative matrix factorization [12]. Modern machine learning problems such as deep learning possess a similar compositional structure, which makes Gauss-Newton-like algorithms potentially good candidates [8, 26, 30]. However, in such problems, we are often interested in the generalization performance on unseen data. It is unclear whether the additional cost of solving the subproblems can be amortized by the superior efficiency of Gauss-Newton-like algorithms. In this paper, we investigate whether modified Gauss-Newton methods or prox-linear algorithms with incremental gradient inner loops are superior to direct stochastic subgradient algorithms for nonsmooth problems with a compositional objective and a finite-sum structure in terms of generalization error.
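For orientation, the sketch below shows the classical Levenberg-Marquardt (damped Gauss-Newton) iteration mentioned above, applied to a smooth nonlinear least-squares problem. It is a simpler relative of the paper's setting, not the paper's modified Gauss-Newton or prox-linear algorithm. The paper targets nonsmooth compositional objectives with finite-sum structure and incremental inner loops. Each step linearizes the residual $r(x)$ and solves $(J^\top J + \lambda I)\,\Delta x = -J^\top r$. The damping $\lambda$ shrinks after a successful step and grows after a failed one. The function name and the step-acceptance schedule are illustrative assumptions.

```python
import numpy as np


def levenberg_marquardt(residual, jacobian, x0, lam=1e-3, iters=100, tol=1e-12):
    """Damped Gauss-Newton for min_x 0.5 * ||r(x)||^2 with simple adaptive damping."""
    x = np.asarray(x0, dtype=float)
    cost = 0.5 * float(np.sum(residual(x) ** 2))
    for _ in range(iters):
        r, J = residual(x), jacobian(x)
        g = J.T @ r  # gradient of the cost
        if np.linalg.norm(g) < tol:
            break
        # Damped Gauss-Newton system: (J^T J + lam * I) dx = -g.
        dx = np.linalg.solve(J.T @ J + lam * np.eye(x.size), -g)
        x_new = x + dx
        cost_new = 0.5 * float(np.sum(residual(x_new) ** 2))
        if cost_new < cost:
            x, cost, lam = x_new, cost_new, lam * 0.5  # accept: rely more on the linear model
        else:
            lam *= 10.0  # reject: shrink toward a short gradient-descent step
    return x


# Example: fit y = a * exp(b * t) to noiseless synthetic data with a = 2, b = -1.5.
t = np.linspace(0.0, 1.0, 20)
y = 2.0 * np.exp(-1.5 * t)


def residual(p):
    return p[0] * np.exp(p[1] * t) - y


def jacobian(p):
    e = np.exp(p[1] * t)
    return np.stack([e, p[0] * t * e], axis=1)  # columns: dr/da, dr/db


x_hat = levenberg_marquardt(residual, jacobian, x0=[1.0, 0.0])
print(x_hat)
```

The subproblem solve is the "additional cost" the paragraph refers to. Here it is a tiny 2x2 system, but for deep models it must itself be approximated, for example by an inner iterative loop.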
Hidden Progress in Deep Learning: SGD Learns Parities Near the Computational Limit
Barak, Boaz, Edelman, Benjamin L., Goel, Surbhi, Kakade, Sham, Malach, Eran, Zhang, Cyril
There is mounting evidence of emergent phenomena in the capabilities of deep learning methods as we scale up datasets, model sizes, and training times. While there are some accounts of how these resources modulate statistical capacity, far less is known about their effect on the computational problem of model training. This work conducts such an exploration through the lens of learning a $k$-sparse parity of $n$ bits, a canonical discrete search problem which is statistically easy but computationally hard. Empirically, we find that a variety of neural networks successfully learn sparse parities, with discontinuous phase transitions in the training curves. On small instances, learning abruptly occurs at approximately $n^{O(k)}$ iterations; this nearly matches SQ lower bounds, despite the apparent lack of a sparse prior. Our theoretical analysis shows that these observations are not explained by a Langevin-like mechanism, whereby SGD "stumbles in the dark" until it finds the hidden set of features (a natural algorithm which also runs in $n^{O(k)}$ time). Instead, we show that SGD gradually amplifies the sparse solution via a Fourier gap in the population gradient, making continual progress that is invisible to loss and error metrics.
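The sparse parity problem and the exhaustive $n^{O(k)}$-time search it admits can be written down directly. In this sketch, inputs are uniform on $\{-1,+1\}^n$ and the label is the product of the coordinates in a hidden set $S$ with $|S| = k$. The brute-force learner tries all $\binom{n}{k} \le n^k$ subsets and returns the first one consistent with the data. It is the kind of search baseline the abstract contrasts with SGD's gradual Fourier-gap amplification, not a model of SGD itself. `sample_sparse_parity` and `brute_force_parity` are hypothetical names.

```python
import itertools
import math
import random


def sample_sparse_parity(n: int, support, m: int, seed: int = 0):
    """Draw m examples x uniform in {-1, +1}^n with label y = prod_{i in support} x_i."""
    rng = random.Random(seed)
    data = []
    for _ in range(m):
        x = [rng.choice((-1, 1)) for _ in range(n)]
        data.append((x, math.prod(x[i] for i in support)))
    return data


def brute_force_parity(data, n: int, k: int):
    """Try every k-subset of the n coordinates; return the first one consistent with all labels."""
    for subset in itertools.combinations(range(n), k):
        if all(math.prod(x[i] for i in subset) == y for x, y in data):
            return subset
    return None


# Example: n = 10, hidden S = (1, 4, 7); 200 samples make a wrong subset consistent with
# probability at most C(10, 3) * 2^-200, so the search recovers S.
data = sample_sparse_parity(n=10, support=(1, 4, 7), m=200, seed=0)
found = brute_force_parity(data, n=10, k=3)
print(found)
```

Statistically the task is easy, since few samples pin down $S$. The cost sits entirely in the number of candidate subsets, which grows like $n^k$.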
Recurrent Convolutional Neural Networks Learn Succinct Learning Algorithms
Goel, Surbhi, Kakade, Sham, Kalai, Adam Tauman, Zhang, Cyril
Neural networks (NNs) struggle to efficiently solve certain problems, such as learning parities, even when there are simple learning algorithms for those problems. Can NNs discover learning algorithms on their own? We exhibit a NN architecture that, in polynomial time, learns as well as any efficient learning algorithm describable by a constant-sized program. For example, on parity problems, the NN learns as well as Gaussian elimination, an efficient algorithm that can be succinctly described. Our architecture combines both recurrent weight sharing between layers and convolutional weight sharing to reduce the number of parameters down to a constant, even though the network itself may have trillions of nodes. While in practice the constants in our analysis are too large to be directly meaningful, our work suggests that the synergy of Recurrent and Convolutional NNs (RCNNs) may be more natural and powerful than either alone, particularly for concisely parameterizing discrete algorithms.
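For reference, the Gaussian elimination learner mentioned above solves parity learning exactly when examples are noiseless. Write inputs as bits $x \in \{0,1\}^n$ and labels as $y = \langle s, x \rangle \bmod 2$ for an unknown $s$. Then recovering $s$ is solving a linear system over GF(2), where addition is XOR. The sketch below is a standard implementation of that algorithm, not the RCNN construction from the paper. `learn_parity_gf2` is a hypothetical name, and it assumes at least one example.

```python
import random
from typing import List, Optional


def learn_parity_gf2(X: List[List[int]], y: List[int]) -> Optional[List[int]]:
    """Find s in {0,1}^n with <s, x> = y (mod 2) for all examples, via elimination over GF(2).

    Returns one consistent s (free variables set to 0), or None if no parity fits the data.
    """
    n = len(X[0])
    rows = [list(x) + [b] for x, b in zip(X, y)]  # augmented matrix [X | y]
    pivots = []  # pivots[i] = column of the leading 1 in reduced row i
    r = 0
    for c in range(n):
        # Find a row at or below r with a 1 in column c.
        p = next((i for i in range(r, len(rows)) if rows[i][c]), None)
        if p is None:
            continue  # column c is a free variable
        rows[r], rows[p] = rows[p], rows[r]
        # Clear column c from every other row; XOR is addition mod 2.
        for i in range(len(rows)):
            if i != r and rows[i][c]:
                rows[i] = [a ^ b for a, b in zip(rows[i], rows[r])]
        pivots.append(c)
        r += 1
        if r == len(rows):
            break
    # A leftover row [0 ... 0 | 1] means the labels are not any parity of the inputs.
    if any(row[n] for row in rows[r:]):
        return None
    s = [0] * n
    for i, c in enumerate(pivots):
        s[c] = rows[i][n]
    return s


# Example: identity rows guarantee full rank, so the hidden parity is uniquely recovered.
rng = random.Random(0)
secret = [1, 0, 1, 1, 0, 0, 0, 1]
X = [[int(i == j) for j in range(8)] for i in range(8)]
X += [[rng.randint(0, 1) for _ in range(8)] for _ in range(32)]
y = [sum(si * xi for si, xi in zip(secret, x)) % 2 for x in X]
print(learn_parity_gf2(X, y))
```

The whole procedure is a short, fixed program, which is what makes it an example of a "constant-sized" learning algorithm in the sense of the abstract, even though gradient-trained networks of standard architectures struggle to match it.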