Sahai, Anant
Can Custom Models Learn In-Context? An Exploration of Hybrid Architecture Performance on In-Context Learning Tasks
Campbell, Ryan, Lojo, Nelson, Viswanadha, Kesava, Tryggestad, Christoffer Grondal, Sun, Derrick Han, Vijapurapu, Sriteja, Rolfsen, August, Sahai, Anant
In-Context Learning (ICL) is a phenomenon in which a model learns a task from a prompt sequence alone, without any parameter updates. ICL in Multi-Headed Attention (MHA) models with absolute positional embeddings has been studied far more than in other sequence-model varieties. We examine the implications of the architectural differences between GPT-2 and LLaMa, and between LLaMa and Mamba. We extend the work of Garg et al. (2022) and Park et al. (2024) to GPT-2/LLaMa and LLaMa/Mamba hybrid models, examining the interplay between sequence-transformation blocks and in-context regression performance. We observe that certain architectural changes degrade training efficiency and ICL accuracy, with models converging to suboptimal predictors or converging more slowly. We also find that certain hybrids show promising performance improvements, informing potential future ICL-focused architecture modifications. Additionally, we propose the "ICL regression score", a scalar metric summarizing a model's overall performance on a specific task. Compute limitations restrict our architecture space, training duration, number of training runs, function-class complexity, and benchmark complexity. To foster reproducible and extensible research, we provide a typed, modular, and extensible Python package on which we run all experiments.
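A minimal sketch of how such in-context regression prompts are typically constructed and scored, in the style of Garg et al. (2022). The function and argument names (`make_icl_regression_prompt`, `model_predict`) are illustrative placeholders, not the interface of the released package, and the least-squares predictor stands in for whichever baseline a trained sequence model is compared against.

```python
import numpy as np

def make_icl_regression_prompt(n_points, dim, rng):
    """Sample one in-context linear-regression task: a hidden weight vector w,
    n_points inputs x_i, and their noiseless labels y_i = w . x_i."""
    w = rng.standard_normal(dim)
    xs = rng.standard_normal((n_points, dim))
    return xs, xs @ w

def evaluate_icl(model_predict, n_prompts=100, n_points=20, dim=5, seed=0):
    """Average squared error of the model's prediction at every prefix length.
    `model_predict(xs, ys, x_query)` is a placeholder for whatever sequence
    model (or baseline) is being probed in-context."""
    rng = np.random.default_rng(seed)
    errs = np.zeros(n_points)
    for _ in range(n_prompts):
        xs, ys = make_icl_regression_prompt(n_points, dim, rng)
        for k in range(n_points):
            pred = model_predict(xs[:k], ys[:k], xs[k])
            errs[k] += (pred - ys[k]) ** 2
    return errs / n_prompts

def least_squares_baseline(xs, ys, x_query):
    """Reference predictor: least squares fit to the prompt prefix."""
    if len(xs) == 0:
        return 0.0
    w_hat, *_ = np.linalg.lstsq(xs, ys, rcond=None)
    return float(x_query @ w_hat)

print(evaluate_icl(least_squares_baseline))
```

The resulting per-prefix error curve is the kind of object a single scalar summary (such as the proposed ICL regression score) would aggregate.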
Provable Weak-to-Strong Generalization via Benign Overfitting
Wu, David X., Sahai, Anant
The classic teacher-student model in machine learning posits that a strong teacher supervises a weak student to improve the student's capabilities. We instead consider the inverted situation, where a weak teacher supervises a strong student with imperfect pseudolabels. This paradigm was recently brought forth by Burns et al. '23 and termed weak-to-strong generalization. We theoretically investigate weak-to-strong generalization for binary and multilabel classification in a stylized overparameterized spiked covariance model with Gaussian covariates, where the weak teacher's pseudolabels are asymptotically like random guessing. Under these assumptions, we provably identify two asymptotic phases of the strong student's generalization after weak supervision: (1) successful generalization and (2) random guessing. Our techniques should eventually extend to weak-to-strong multiclass classification. Towards doing so, we prove a tight lower tail inequality for the maximum of correlated Gaussians, which may be of independent interest. Understanding the multilabel setting reinforces the value of using logits for weak supervision when they are available.
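A minimal sketch of the weak-to-strong pipeline described above, under illustrative assumptions: the weak teacher is a crude plug-in linear estimate fit on a handful of labeled points, and the strong student is a minimum-norm linear interpolator of the teacher's pseudolabels. This is not the paper's spiked-covariance construction, and whether the student actually improves on the teacher depends on the covariance structure the paper analyzes.

```python
import numpy as np

rng = np.random.default_rng(0)
n_weak, n_unlabeled, n_test, d = 50, 500, 2000, 1000

# Ground truth: labels are the sign of a one-dimensional projection.
w_star = np.zeros(d)
w_star[0] = 1.0

def sample(n):
    X = rng.standard_normal((n, d))
    return X, np.sign(X @ w_star)

# Weak teacher: a crude plug-in estimate from only a few labeled points,
# so its pseudolabels on fresh data are close to random guessing.
X_weak, y_weak = sample(n_weak)
w_teacher = X_weak.T @ y_weak / n_weak

# Strong student: minimum-norm linear interpolation of the teacher's
# pseudolabels on a larger unlabeled set (d > n, the benign-overfitting regime).
X_unlab, _ = sample(n_unlabeled)
pseudolabels = np.sign(X_unlab @ w_teacher)
w_student = np.linalg.pinv(X_unlab) @ pseudolabels

X_test, y_test = sample(n_test)
print("teacher accuracy:", np.mean(np.sign(X_test @ w_teacher) == y_test))
print("student accuracy:", np.mean(np.sign(X_test @ w_student) == y_test))
```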
Precise Asymptotic Generalization for Multiclass Classification with Overparameterized Linear Models
Wu, David X., Sahai, Anant
We study the asymptotic generalization of an overparameterized linear model for multiclass classification under the Gaussian covariates bi-level model introduced in Subramanian et al. '22, where the number of data points, features, and classes all grow together. We fully resolve the conjecture posed in Subramanian et al. '22, matching the predicted regimes for generalization. Furthermore, our new lower bounds are akin to an information-theoretic strong converse: they establish that the misclassification rate goes to 0 or 1 asymptotically. One surprising consequence of our tight results is that the min-norm interpolating classifier can be asymptotically suboptimal relative to noninterpolating classifiers in the regime where the min-norm interpolating regressor is known to be optimal. The key to our tight analysis is a new variant of the Hanson-Wright inequality which is broadly useful for multiclass problems with sparse labels. As an application, we show that the same type of analysis can be used to analyze the related multilabel classification problem under the same bi-level ensemble.
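For reference, the min-norm interpolating classifier mentioned above can be written in standard notation (mine, not necessarily the paper's): with data matrix X in R^{n x d} of full row rank (n <= d) and one-hot label matrix Y in R^{n x K},

```latex
\hat{W} \;=\; \arg\min_{W \in \mathbb{R}^{d \times K}} \|W\|_F
\quad \text{subject to } XW = Y,
\qquad \text{so that } \hat{W} = X^\top \left( X X^\top \right)^{-1} Y,
```
```latex
\hat{y}(x) \;=\; \arg\max_{k \in [K]} \; x^\top \hat{w}_k,
```

where \hat{w}_k is the k-th column of \hat{W}. The corresponding min-norm interpolating regressor uses the same \hat{W} but is judged on its real-valued outputs rather than the arg-max decision.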
Generalization for multiclass classification with overparameterized linear models
Subramanian, Vignesh, Arya, Rahul, Sahai, Anant
Via an overparameterized linear model with Gaussian features, we provide conditions under which minimum-norm interpolating solutions generalize well for multiclass classification, in an asymptotic setting where both the number of underlying features and the number of classes scale with the number of training points. The survival/contamination analysis framework for understanding the behavior of overparameterized learning problems is adapted to this setting, revealing that multiclass classification qualitatively behaves like binary classification: as long as there are not too many classes (made precise in the paper), it is possible to generalize well even in some settings where the corresponding regression tasks would not generalize. Besides various technical challenges, it turns out that the key difference from the binary classification setting is that there are relatively fewer positive training examples of each class in the multiclass setting as the number of classes increases, making the multiclass problem "harder" than the binary one.
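To make the last point concrete, a small worked count (assuming, purely for illustration, that the n training points are spread roughly evenly over k classes): under a one-vs-rest encoding, each class's binary subproblem sees

```latex
\approx \frac{n}{k} \ \text{positive examples}
\qquad \text{and} \qquad
\approx n\left(1 - \frac{1}{k}\right) \ \text{negative examples},
```

so the positive side of every binary subproblem thins out as k grows with n, which is the sense in which the multiclass problem is "harder" than the binary one.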
Classification and Adversarial Examples in an Overparameterized Linear Model: A Signal Processing Perspective
Narang, Adhyyan, Muthukumar, Vidya, Sahai, Anant
State-of-the-art deep learning classifiers are heavily overparameterized with respect to the number of training examples, and are observed to generalize well on "clean" data while being highly susceptible to infinitesimal adversarial perturbations. In this paper, we identify an overparameterized linear ensemble, built on the "lifted" Fourier feature map, that exhibits both of these behaviors. The input is one-dimensional, and the adversary is only allowed to perturb these inputs, not the non-linear features directly. We find that the learned model is susceptible to adversaries in an intermediate regime where classification generalizes but regression does not. Notably, this susceptibility arises despite the absence of model mis-specification or label noise, which are commonly cited reasons for adversarial susceptibility. These results are extended theoretically to a random-Fourier-sum setup that exhibits double-descent behavior. In both feature setups, the adversarial vulnerability arises because of a phenomenon we term spatial localization: the predictions of the learned model are markedly more sensitive in the vicinity of training points than elsewhere. This sensitivity is a consequence of the feature lifting and is reminiscent of the Gibbs and Runge phenomena from signal processing and functional analysis. Despite the adversarial susceptibility, we find that classification with these features can be easier than in the more commonly studied "independent feature" models.
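A hedged sketch of the kind of experiment this suggests: a one-dimensional input lifted through a Fourier feature map, a minimum-norm fit to binary labels, and a probe of how sharply the prediction varies right next to a training point versus between training points. The specific feature map, sizes, and probe locations are my own illustrative choices, not the paper's exact construction.

```python
import numpy as np

def fourier_lift(x, n_freq):
    """Lift scalar inputs x in [0, 1) to [1, cos(2*pi*j*x), sin(2*pi*j*x)] for j = 1..n_freq."""
    x = np.atleast_1d(x)
    cols = [np.ones_like(x)]
    for j in range(1, n_freq + 1):
        cols += [np.cos(2 * np.pi * j * x), np.sin(2 * np.pi * j * x)]
    return np.stack(cols, axis=1)

rng = np.random.default_rng(0)
n_train, n_freq = 10, 200                          # heavily overparameterized lift
x_train = np.sort(rng.uniform(0, 1, n_train))
y_train = np.sign(rng.standard_normal(n_train))    # arbitrary +/-1 labels

Phi = fourier_lift(x_train, n_freq)
alpha = np.linalg.pinv(Phi) @ y_train              # min-norm interpolating coefficients

def pred(x):
    return (fourier_lift(x, n_freq) @ alpha)[0]

# Compare how fast the prediction changes right next to a training point
# versus between two training points ("spatial localization" probe).
eps = 1e-3
near = x_train[0]
far = (x_train[0] + x_train[1]) / 2
slope_near = abs(pred(near + eps) - pred(near)) / eps
slope_far = abs(pred(far + eps) - pred(far)) / eps
print(f"local slope near a training point: {slope_near:.1f}, between training points: {slope_far:.1f}")
```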
On the Impossibility of Convergence of Mixed Strategies with No Regret Learning
Muthukumar, Vidya, Phade, Soham, Sahai, Anant
We study convergence properties of the mixed strategies that result from a general class of optimal no-regret learning strategies in a repeated game setting where the stage game is any 2 by 2 competitive game (i.e. a game for which all the Nash equilibria (NE) are completely mixed). We consider the class of strategies whose information set at each step is the empirical average of the opponent's realized play (and the step number), which we call mean-based strategies. We first show that there does not exist any optimal no-regret, mean-based strategy for player 1 that would result in the convergence of her mixed strategies (in probability) against an opponent that plays his Nash equilibrium mixed strategy at each step. Next, we show that this last-iterate divergence necessarily occurs if player 2 uses any adaptive strategy with a minimal randomness property. This property is satisfied, for example, by any fixed sequence of mixed strategies for player 2 that converges to NE. We conjecture that this property holds when both players use optimal no-regret learning strategies against each other, leading to the divergence of the mixed strategies with positive probability. Finally, we show that variants of mean-based strategies using recency bias, which have yielded last-iterate convergence in deterministic min-max optimization, continue to lead to this last-iterate divergence. This demonstrates a crucial difference in outcomes between using the opponent's mixtures and using the opponent's realizations to make strategy updates.
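A minimal sketch of the kind of dynamic studied, under illustrative assumptions: matching pennies as the 2 by 2 competitive game, player 2 playing his completely mixed NE strategy every round, and player 1 running multiplicative weights, which is mean-based in the above sense because its cumulative payoffs are just the step number times the empirical average of the opponent's realized play.

```python
import numpy as np

# Matching pennies: player 1's payoff matrix A[i, j]; the unique NE is (1/2, 1/2) for both players.
A = np.array([[1.0, -1.0],
              [-1.0, 1.0]])
ne_p2 = np.array([0.5, 0.5])

rng = np.random.default_rng(0)
T = 5000
eta = np.sqrt(np.log(2) / T)        # standard multiplicative-weights step size

cum_payoff = np.zeros(2)            # step number times the empirical average payoff of each action
trajectory = []
for t in range(T):
    # Player 1's mixed strategy depends only on cumulative (equivalently, mean) payoffs.
    weights = np.exp(eta * cum_payoff)
    p1 = weights / weights.sum()
    trajectory.append(p1.copy())

    # Player 2 realizes a pure action from his NE mixture; player 1 observes only the realization.
    j = rng.choice(2, p=ne_p2)
    cum_payoff += A[:, j]

trajectory = np.array(trajectory)
print("player 1's mixed strategy over the last few rounds:")
print(trajectory[-5:])
```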
Classification vs regression in overparameterized regimes: Does the loss function matter?
Muthukumar, Vidya, Narang, Adhyyan, Subramanian, Vignesh, Belkin, Mikhail, Hsu, Daniel, Sahai, Anant
Paradigmatic problems in supervised machine learning (ML) involve predicting an output response from an input, based on patterns extracted from a (training) dataset. In classification, the output response is (finitely) discrete and we need to classify input data into one of these discrete categories. In regression, the output is continuous, typically a real number or a vector. Owing to this important distinction in output response, the two tasks are typically treated differently. The differences in treatment manifest in two phases of modern ML: optimization (training), which consists of an algorithmic procedure to extract a predictor from the training data, typically by minimizing the training loss (also called empirical risk); and generalization (testing), which consists of an evaluation of the obtained predictor on a separate test, or validation, dataset. Traditionally, the choice of loss functions for both phases is starkly different across classification and regression tasks. The squared-loss function is typically used both for the training and the testing phases in regression. In contrast, the hinge or logistic (cross-entropy for multi-class problems) loss functions are typically used in the training phase of classification, while the very different 0-1 loss function is used for testing.
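For concreteness, the loss functions referred to above, written for a binary label y in {-1, +1} and a real-valued prediction f(x) (notation mine):

```latex
\ell_{\mathrm{sq}}\bigl(y, f(x)\bigr) = \bigl(y - f(x)\bigr)^2,
\qquad
\ell_{\mathrm{hinge}}\bigl(y, f(x)\bigr) = \max\bigl\{0,\; 1 - y f(x)\bigr\},
```
```latex
\ell_{\mathrm{log}}\bigl(y, f(x)\bigr) = \log\bigl(1 + e^{-y f(x)}\bigr),
\qquad
\ell_{0\text{-}1}\bigl(y, f(x)\bigr) = \mathbf{1}\bigl\{\operatorname{sign}(f(x)) \neq y\bigr\}.
```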
Harmless interpolation of noisy data in regression
Muthukumar, Vidya, Vodrahalli, Kailas, Sahai, Anant
In classification problems (i.e. when the labels Y are discrete), the scaling of the test error with respect to n is determined by characterizations of the VC-dimension [2] or Rademacher complexity [3] of the function class, which in the worst case increase with its number of parameters. In regression (i.e. when the labels Y are continuous), the mean-squared error of the ordinary least-squares estimate is characterized by the condition number of the regression matrix, which is reasonable for appropriate ratios of d/n but tends to increase astronomically as d approaches n. The qualitative fear is the same: if the function class is too complex, it starts to overfit noise and can generalize poorly to unseen test data. But there is a gap between "can" and "will" -- and indeed this conventional wisdom has been challenged by the recent advent of deeper and deeper neural networks. In particular, a thought-provoking paper [4] noted that several deep neural networks generalize well despite achieving zero or close to zero training error, even being so expressive that they have the ability to fit pure noise. As they put it, "understanding deep learning requires rethinking generalization". How can we reconcile the existence of good interpolative solutions with the classical bias-variance tradeoff? These phenomena are being actively investigated in a statistical sense [5,6] and a computational sense [7-9] in classification problems and/or noiseless models.
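A minimal numerical sketch of the interpolation phenomenon under discussion, with an illustrative "bi-level" feature design (a few strong directions carrying the signal plus many weak features): the minimum-norm least-squares solution fits the noisy training labels exactly, and the test error is measured separately. The dimensions and scales here are assumptions of mine, not the feature families analyzed in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k_strong, d_weak = 100, 10, 5000
weak_scale, noise_std = 0.04, 0.5

def features(m):
    """Bi-level design: k_strong strong directions plus many weak ones."""
    return np.hstack([rng.standard_normal((m, k_strong)),
                      weak_scale * rng.standard_normal((m, d_weak))])

w_star = np.concatenate([np.ones(k_strong), np.zeros(d_weak)])  # signal lives on strong features only
X = features(n)
y = X @ w_star + noise_std * rng.standard_normal(n)             # noisy training labels

# Minimum-norm least-squares solution; with d >> n it interpolates y exactly.
w_hat = np.linalg.pinv(X) @ y

X_test = features(1000)
y_test = X_test @ w_star
print("train MSE:", np.mean((X @ w_hat - y) ** 2))
print("test MSE: ", np.mean((X_test @ w_hat - y_test) ** 2))
print("null risk:", np.mean(y_test ** 2))
```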