Collaborating Authors

 Shenoy, Pradeep


Universal Model Routing for Efficient LLM Inference

arXiv.org Artificial Intelligence

Large language models' significant advances in capabilities are accompanied by significant increases in inference costs. Model routing is a simple technique for reducing inference cost, wherein one maintains a pool of candidate LLMs, and learns to route each prompt to the smallest feasible LLM. Existing works focus on learning a router for a fixed pool of LLMs. In this paper, we consider the problem of dynamic routing, where new, previously unobserved LLMs are available at test time. We propose a new approach to this problem that relies on representing each LLM as a feature vector, derived based on predictions on a set of representative prompts. Based on this, we detail two effective strategies, relying on cluster-based routing and a learned cluster map respectively. We prove that these strategies are estimates of a theoretically optimal routing rule, and provide an excess risk bound to quantify their errors. Experiments on a range of public benchmarks show the effectiveness of the proposed strategies in routing amongst more than 30 unseen LLMs.
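The cluster-based routing idea can be sketched as follows. This is a minimal illustration under assumptions the abstract does not state: prompt clusters are already assigned, each LLM's feature vector is its per-cluster error rate on the representative prompts, and the router trades estimated error against a hypothetical per-model cost with weight `lam`. The names `llm_feature_vector` and `route` are illustrative, not the paper's.

```python
from collections import defaultdict

def llm_feature_vector(correct, clusters, num_clusters):
    """Per-cluster error rate of one LLM on the representative prompts.

    correct[i] is 1 if the LLM answered prompt i correctly, else 0;
    clusters[i] is the (precomputed, assumed) cluster id of prompt i.
    """
    wrong = defaultdict(int)
    total = defaultdict(int)
    for ok, c in zip(correct, clusters):
        total[c] += 1
        wrong[c] += 1 - ok
    # Clusters with no prompts default to the worst-case error of 1.0.
    return [wrong[c] / total[c] if total[c] else 1.0 for c in range(num_clusters)]

def route(prompt_cluster, features, costs, lam=0.1):
    """Pick the LLM minimizing estimated error + lam * cost for this cluster."""
    return min(features, key=lambda name: features[name][prompt_cluster] + lam * costs[name])

clusters = [0, 0, 1, 1]                 # cluster ids of four representative prompts
features = {
    "small": llm_feature_vector([1, 1, 0, 0], clusters, 2),
    "large": llm_feature_vector([1, 1, 1, 1], clusters, 2),
}
costs = {"small": 1.0, "large": 5.0}    # hypothetical relative inference costs
```

In this sketch, a previously unseen LLM only has to be scored on the representative prompts to obtain its feature vector; the routing rule itself does not need to be retrained.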


Masked Generative Nested Transformers with Decode Time Scaling

arXiv.org Artificial Intelligence

Recent advances in visual generation have made significant strides in producing content of exceptional quality. However, most methods suffer from a fundamental problem: inference is a computational bottleneck. Most of these algorithms involve multiple passes over a transformer model to generate tokens or denoise inputs, yet the model size is kept constant throughout all iterations, which makes them computationally expensive. In this work, we aim to address this issue primarily through two key ideas: (a) not all parts of the generation process need equal compute, and we design a decode time model scaling schedule to utilize compute effectively, and (b) we can cache and reuse some of the computation. Combining these two ideas leads to using smaller models to process more tokens while large models process fewer tokens. These different-sized models do not increase the parameter size, as they share parameters. We rigorously experiment with ImageNet $256\times256$, UCF101, and Kinetics600 to showcase the efficacy of the proposed method for image/video generation and frame prediction. Our experiments show that with almost $3\times$ less compute than baseline, our model obtains competitive performance.


Improving Generalization via Meta-Learning on Hard Samples

arXiv.org Artificial Intelligence

Learned reweighting (LRW) approaches to supervised learning use an optimization criterion to assign weights for training instances, in order to maximize performance on a representative validation dataset. We pose and formalize the problem of optimized selection of the validation set used in LRW training, to improve classifier generalization. In particular, we show that using hard-to-classify instances in the validation set has both a theoretical connection to, and strong empirical evidence of, generalization. We provide an efficient algorithm for training this meta-optimized model, as well as a simple train-twice heuristic for careful comparative study. We demonstrate that LRW with easy validation data performs consistently worse than LRW with hard validation data, establishing the validity of our meta-optimization problem. Our proposed algorithm outperforms a wide range of baselines on a range of datasets and domain shift challenges (Imagenet-1K, CIFAR-100, Clothing-1M, CAMELYON, WILDS, etc.), with ~1% gains using ViT-B on Imagenet. We also show that using naturally hard examples for validation (Imagenet-R / Imagenet-A) in LRW training for Imagenet improves performance on both clean and naturally hard test instances by 1-2%. Secondary analyses show that using hard validation data in an LRW framework improves margins on test data, hinting at the mechanism underlying our empirical gains. We believe this work opens up new research directions for the meta-optimization of meta-learning in a supervised learning context.


Selective classification using a robust meta-learning approach

arXiv.org Artificial Intelligence

Predictive uncertainty (a model's self-awareness regarding its accuracy on an input) is key for both building robust models via training interventions and for test-time applications such as selective classification. We propose a novel instance-conditioned reweighting approach that captures predictive uncertainty using an auxiliary network and unifies these train- and test-time applications. The auxiliary network is trained using a meta-objective in a bilevel optimization framework. A key contribution of our proposal is the meta-objective of minimizing the dropout variance, an approximation of Bayesian predictive uncertainty. We show in controlled experiments that we effectively capture the diverse specific notions of uncertainty through this meta-objective, while previous approaches only capture certain aspects. These results translate to significant gains in real-world settings (selective classification, label noise, domain adaptation, calibration) and across datasets (Imagenet, Cifar100, diabetic retinopathy, Camelyon, WILDS, Imagenet-C, -A, -R, Clothing1M, etc.). For Diabetic Retinopathy, we see up to 3.4%/3.3% accuracy and AUC gains over SOTA in selective classification. We also improve upon large-scale pretrained models such as PLEX.


Instance-Conditional Timescales of Decay for Non-Stationary Learning

arXiv.org Artificial Intelligence

Slow concept drift is a ubiquitous, yet under-studied problem in practical machine learning systems. In such settings, although recent data is more indicative of future data, naively prioritizing recent instances runs the risk of losing valuable information from the past. We propose an optimization-driven approach towards balancing instance importance over large training windows. First, we model instance relevance using a mixture of multiple timescales of decay, allowing us to capture rich temporal trends. Second, we learn an auxiliary scorer model that recovers the appropriate mixture of timescales as a function of the instance itself. Finally, we propose a nested optimization objective for learning the scorer, by which it maximizes forward transfer for the learned model. Experiments on a large real-world dataset of 39M photos over a 9-year period show up to 15% relative gains in accuracy compared to other robust learning baselines. We replicate our gains on two collections of real-world datasets for non-stationary learning, and extend our work to continual learning settings where, too, we beat SOTA methods by large margins.
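A mixture of decay timescales can be illustrated with a small sketch. It assumes exponential decay in the instance's age and a fixed, hand-written mixture per instance; in the described approach the mixture would instead come from the learned scorer. The function name, the exponential form, and the example timescales are all assumptions made for illustration.

```python
import math

def mixture_decay_weight(age, timescales, mixture):
    """Instance weight as a mixture of exponential decays over its age."""
    if len(timescales) != len(mixture):
        raise ValueError("timescales and mixture must have the same length")
    return sum(p * math.exp(-age / tau) for p, tau in zip(mixture, timescales))

timescales = [1.0, 10.0, 100.0]   # hypothetical decay timescales, e.g. in months
fast = [1.0, 0.0, 0.0]            # instance whose relevance fades quickly
slow = [0.0, 0.0, 1.0]            # instance that stays relevant for a long time

# At the same age, the slowly decaying instance keeps far more weight.
old_fast = mixture_decay_weight(10.0, timescales, fast)
old_slow = mixture_decay_weight(10.0, timescales, slow)
```

Because the mixture is set per instance, two training examples of the same age can receive very different weights, which a single global decay rate cannot express.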


Rescuing referral failures during automated diagnosis of domain-shifted medical images

arXiv.org Artificial Intelligence

The success of deep learning models deployed in the real world depends critically on their ability to generalize well across diverse data domains. Here, we address a fundamental challenge with selective classification during automated diagnosis with domain-shifted medical images. In this scenario, models must learn to avoid making predictions when label confidence is low, especially when tested with samples far removed from the training set (covariate shift). Such uncertain cases are typically referred to the clinician for further analysis and evaluation. Yet, we show that even state-of-the-art domain generalization approaches fail severely during referral when tested on medical images acquired from a different demographic or using a different technology. We examine two benchmark diagnostic medical imaging datasets exhibiting strong covariate shifts: i) diabetic retinopathy prediction with retinal fundus images and ii) multilabel disease prediction with chest X-ray images. We show that predictive uncertainty estimates do not generalize well under covariate shifts, leading to non-monotonic referral curves and severe drops in performance (up to 50%) at high referral rates (>70%). We evaluate novel combinations of robust generalization and post hoc referral approaches that rescue these failures and achieve significant performance improvements, typically >10%, over baseline methods. Our study identifies a critical challenge with referral in domain-shifted medical images and finds key applications in reliable, automated disease diagnosis.
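A referral curve of the kind discussed above can be computed with a short sketch: refer the least confident fraction of cases and measure accuracy on the rest. The function name and the toy confidences are hypothetical, and accuracy is only one possible metric for the retained cases.

```python
def accuracy_at_referral_rate(confidences, correct, referral_rate):
    """Accuracy on the cases kept after referring the least confident ones."""
    if not 0.0 <= referral_rate < 1.0:
        raise ValueError("referral_rate must be in [0, 1)")
    # Sort case indices from least to most confident.
    order = sorted(range(len(confidences)), key=lambda i: confidences[i])
    num_referred = int(referral_rate * len(order))
    kept = order[num_referred:]
    return sum(correct[i] for i in kept) / len(kept)

# Toy data: the least confident case (0.55) happens to be correct,
# so referring it first lowers accuracy on the retained cases.
confidences = [0.9, 0.6, 0.8, 0.55]
correct = [1, 0, 1, 1]
curve = [accuracy_at_referral_rate(confidences, correct, r) for r in (0.0, 0.25, 0.5)]
```

On this toy data the curve dips before it rises, a small-scale version of the non-monotonic behavior that arises when confidence is poorly aligned with correctness.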


Using Early Readouts to Mediate Featural Bias in Distillation

arXiv.org Artificial Intelligence

Deep networks tend to learn spurious feature-label correlations in real-world supervised learning tasks. This vulnerability is aggravated in distillation, where a student model may have less representational capacity than the corresponding teacher model. Often, knowledge of specific spurious correlations is used to reweight instances and rebalance the learning process. We propose a novel early readout mechanism whereby we attempt to predict the label using representations from earlier network layers. We show that these early readouts automatically identify problem instances or groups in the form of confident, incorrect predictions. Leveraging these signals to modulate the distillation loss on an instance level allows us to substantially improve not only group fairness measures across benchmark datasets, but also overall accuracy of the student model. We also provide secondary analyses that bring insight into the role of feature learning in supervision and distillation.
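One hypothetical way to turn confident, incorrect early-readout predictions into per-instance weights is sketched below. The abstract does not specify how the distillation loss is modulated, so the weighting rule `1 + alpha * confidence`, the function name, and the parameter `alpha` are assumptions made purely for illustration.

```python
def early_readout_weights(early_probs, labels, alpha=1.0):
    """Upweight instances that the early readout gets wrong with high confidence.

    early_probs[i] is the early readout's class-probability list for instance i.
    """
    weights = []
    for probs, label in zip(early_probs, labels):
        predicted = max(range(len(probs)), key=lambda c: probs[c])
        confidence = probs[predicted]
        # Correct predictions keep weight 1; wrong ones grow with confidence.
        weights.append(1.0 + alpha * confidence if predicted != label else 1.0)
    return weights

early_probs = [[0.9, 0.1], [0.2, 0.8], [0.55, 0.45]]
labels = [0, 0, 1]
weights = early_readout_weights(early_probs, labels)
```

Such weights could then scale each instance's term in the distillation loss, so that instances flagged by the early readout receive more emphasis.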


Sample-Efficient Personalization: Modeling User Parameters as Low Rank Plus Sparse Components

arXiv.org Machine Learning

Personalization of machine learning (ML) predictions for individual users/domains/enterprises is critical for practical recommendation systems. Standard personalization approaches involve learning a user/domain specific embedding that is fed into a fixed global model, which can be limiting. On the other hand, personalizing/fine-tuning the model itself for each user/domain (a.k.a. meta-learning) has high storage/infrastructure cost. Moreover, rigorous theoretical studies of scalable personalization approaches have been very limited. To address the above issues, we propose a novel meta-learning style approach that models network weights as a sum of low-rank and sparse components. This captures common information from multiple individuals/users together in the low-rank part, while the sparse part captures user-specific idiosyncrasies. We then study the framework in the linear setting, where the problem reduces to that of estimating the sum of a rank-$r$ and a $k$-column sparse matrix using a small number of linear measurements. We propose a computationally efficient alternating minimization method with iterative hard thresholding, AMHT-LRS, to learn the low-rank and sparse parts. Theoretically, for the realizable Gaussian data setting, we show that AMHT-LRS solves the problem efficiently with nearly optimal sample complexity. Finally, a significant challenge in personalization is ensuring privacy of each user's sensitive data. We alleviate this problem by proposing a differentially private variant of our method that is also equipped with strong generalization guarantees.
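The hard-thresholding step that gives iterative hard thresholding its name can be sketched in a few lines. This shows only the standard operator (keep the `k` largest-magnitude entries); how AMHT-LRS interleaves it with the low-rank update is not described in the abstract, and the variable names here are hypothetical.

```python
def hard_threshold(vec, k):
    """Keep the k largest-magnitude entries of vec and zero out the rest."""
    if k < 0:
        raise ValueError("k must be non-negative")
    by_magnitude = sorted(range(len(vec)), key=lambda i: abs(vec[i]), reverse=True)
    keep = set(by_magnitude[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(vec)]

# Hypothetical per-user residual left after removing the shared low-rank part.
user_residual = [0.05, -2.0, 0.1, 1.5, -0.02]
sparse_part = hard_threshold(user_residual, 2)
```

Here the two dominant entries survive as the user-specific sparse component, while the small entries are treated as noise and set to zero.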


Overcoming Simplicity Bias in Deep Networks using a Feature Sieve

arXiv.org Artificial Intelligence

Simplicity bias is the concerning tendency of deep networks to over-depend on simple, weakly predictive features, to the exclusion of stronger, more complex features. This is exacerbated in real-world applications by limited training data and spurious feature-label correlations, leading to biased, incorrect predictions. We propose a direct, interventional method for addressing simplicity bias in DNNs, which we call the feature sieve. We aim to automatically identify and suppress easily-computable spurious features in lower layers of the network, thereby allowing the higher network levels to extract and utilize richer, more meaningful representations. We provide concrete evidence of this differential suppression and enhancement of relevant features on both controlled datasets and real-world images, and report substantial gains on many real-world debiasing benchmarks (11.4% relative gain on Imagenet-A; 3.2% on BAR, etc.). Crucially, we do not depend on prior knowledge of spurious attributes or features, and in fact outperform many baselines that explicitly incorporate such information. We believe that our feature sieve work opens up exciting new research directions in automated adversarial feature extraction and representation learning for deep networks.


Shaken, and Stirred: Long-Range Dependencies Enable Robust Outlier Detection with PixelCNN++

arXiv.org Artificial Intelligence

Reliable outlier detection is critical for real-world deployment of deep learning models. Although extensively studied, likelihoods produced by deep generative models have been largely dismissed as being impractical for outlier detection. First, deep generative model likelihoods are readily biased by low-level input statistics. Second, many recent solutions for correcting these biases are computationally expensive, or do not generalize well to complex, natural datasets. Here, we explore outlier detection with a state-of-the-art deep autoregressive model: PixelCNN++. We show that biases in PixelCNN++ likelihoods arise primarily from predictions based on local dependencies. We propose two families of bijective transformations, "stirring" and "shaking", which ameliorate low-level biases and isolate the contribution of long-range dependencies to PixelCNN++ likelihoods. These transformations are inexpensive and readily computed at evaluation time. We test our approaches extensively with five grayscale and six natural image datasets and show that they achieve or exceed state-of-the-art outlier detection, particularly on datasets with complex, natural images. We also show that our solutions work well with other types of generative models (generative flows and variational autoencoders) and that their efficacy is governed by each model's reliance on local dependencies. In sum, lightweight remedies suffice to achieve robust outlier detection on image data with deep generative models.