Chen, Jiasi
Selective Attention: Enhancing Transformer through Principled Context Control
Zhang, Xuechen, Chang, Xiangyu, Li, Mingchen, Roy-Chowdhury, Amit, Chen, Jiasi, Oymak, Samet
The attention mechanism within the transformer architecture enables the model to weigh and combine tokens based on their relevance to the query. While self-attention has enjoyed major success, it notably treats all queries $q$ in the same way by applying the mapping $V^\top\text{softmax}(Kq)$, where $V,K$ are the value and key embeddings respectively. In this work, we argue that this uniform treatment hinders the ability to control contextual sparsity and relevance. As a solution, we introduce the $\textit{Selective Self-Attention}$ (SSA) layer that augments the softmax nonlinearity with a principled temperature scaling strategy. By controlling temperature, SSA adapts the contextual sparsity of the attention map to the query embedding and its position in the context window. Through theory and experiments, we demonstrate that this alleviates attention dilution, aids the optimization process, and enhances the model's ability to control softmax spikiness of individual queries. We also incorporate temperature scaling for value embeddings and show that it boosts the model's ability to suppress irrelevant/noisy tokens. Notably, SSA is a lightweight method which introduces less than 0.5% new parameters through a weight-sharing strategy and can be fine-tuned on existing LLMs. Extensive empirical evaluations demonstrate that SSA-equipped models achieve a noticeable and consistent accuracy improvement on language modeling benchmarks.
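To illustrate the role of temperature in the mapping $V^\top\text{softmax}(Kq)$, the following is a minimal sketch, not the SSA layer itself. It assumes a single hand-picked scalar `temperature` per query, whereas the abstract describes a temperature that adapts to the query embedding and its position. The function names and the toy keys, values, and query are hypothetical.

```python
import math

def softmax(scores, temperature=1.0):
    """Numerically stable softmax of scores / temperature."""
    scaled = [s / temperature for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values, temperature=1.0):
    """Return (output, weights) for V^T softmax(K q / temperature)."""
    # One score per key: the inner product <k, q>.
    scores = [sum(k_i * q_i for k_i, q_i in zip(k, query)) for k in keys]
    weights = softmax(scores, temperature)
    dim = len(values[0])
    output = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return output, weights

def entropy(weights):
    """Shannon entropy of an attention distribution (lower means sparser)."""
    return -sum(w * math.log(w) for w in weights if w > 0)

# Toy data (assumed for illustration only).
keys = [[1.0, 0.0], [0.8, 0.2], [0.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
query = [2.0, 0.0]

# A low temperature concentrates attention; a high one spreads it out.
_, sharp = attend(query, keys, values, temperature=0.25)
_, flat = attend(query, keys, values, temperature=4.0)
```

With the same query, the low-temperature map puts most of its mass on the best-matching key, while the high-temperature map is closer to uniform. This is the spikiness that a per-query temperature would let a model control.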
TREACLE: Thrifty Reasoning via Context-Aware LLM and Prompt Selection
Zhang, Xuechen, Huang, Zijian, Taga, Ege Onur, Joe-Wong, Carlee, Oymak, Samet, Chen, Jiasi
Recent successes in natural language processing have led to the proliferation of large language models (LLMs) by multiple providers. Each LLM offering has different inference accuracy, monetary cost, and latency, and their accuracy further depends on the exact wording of the question (i.e., the specific prompt). At the same time, users often have a limit on monetary budget and latency to answer all their questions, and they do not know which LLMs to choose for each question to meet their accuracy and long-term budget requirements. To navigate this rich design space, we propose TREACLE (Thrifty Reasoning via Context-Aware LLM and Prompt Selection), a reinforcement learning policy that jointly selects the model and prompting scheme while respecting the user's monetary cost and latency constraints. TREACLE uses the problem context, including question text embeddings (reflecting the type or difficulty of a query) and the response history (reflecting the consistency of previous responses) to make smart decisions. Our evaluations on standard reasoning datasets (GSM8K, CSQA, and LLC) with various LLMs and prompts show that TREACLE enables cost savings of up to 85% compared to baselines while maintaining high accuracy. Importantly, it provides the user with the ability to gracefully trade off accuracy for cost.
Class-attribute Priors: Adapting Optimization to Heterogeneity and Fairness Objective
Zhang, Xuechen, Li, Mingchen, Chen, Jiasi, Thrampoulidis, Christos, Oymak, Samet
Modern classification problems exhibit heterogeneities across individual classes: Each class may have unique attributes, such as sample size, label quality, or predictability (easy vs. difficult), and variable importance at test-time. Without care, these heterogeneities impede the learning process, most notably when optimizing fairness objectives. Confirming this, under a Gaussian mixture setting, we show that the optimal SVM classifier for balanced accuracy needs to be adaptive to the class attributes. This motivates us to propose CAP: An effective and general method that generates a class-specific learning strategy (e.g., a hyperparameter) based on the attributes of that class. This way, the optimization process better adapts to heterogeneities. CAP leads to substantial improvements over the naive approach of assigning separate hyperparameters to each class. We instantiate CAP for loss function design and post-hoc logit adjustment, with emphasis on label-imbalanced problems. We show that CAP is competitive with prior art and its flexibility unlocks clear benefits for fairness objectives beyond balanced accuracy. Finally, we evaluate CAP on problems with label noise as well as weighted test objectives to showcase how CAP can jointly adapt to different heterogeneities.
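As a point of reference for the post-hoc logit adjustment mentioned above, here is a minimal sketch of the standard frequency-based variant, not CAP itself. It assumes the only class attribute is the class prior and that the strength `tau` is a fixed, hand-chosen constant, whereas CAP generates the class-specific strategy from the attributes. All names and numbers are hypothetical.

```python
import math

def adjust_logits(logits, class_priors, tau=1.0):
    """Subtract tau * log(prior) from each class logit.

    Rare classes (small prior) receive a larger boost, which counteracts
    the bias toward frequent classes under label imbalance.
    """
    return [z - tau * math.log(p) for z, p in zip(logits, class_priors)]

def predict(logits):
    """Index of the largest logit."""
    return max(range(len(logits)), key=lambda c: logits[c])

# Toy two-class example (assumed): class 0 is frequent, class 1 is rare.
priors = [0.9, 0.1]
logits = [1.0, 0.2]

raw_pred = predict(logits)                                   # favors the frequent class
adj_pred = predict(adjust_logits(logits, priors, tau=1.0))   # favors the rare class
```

Here a single attribute and a single shared `tau` decide the offset. A class-specific strategy would instead let the offset depend on several attributes of each class, such as label quality or test-time importance.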
FedYolo: Augmenting Federated Learning with Pretrained Transformers
Zhang, Xuechen, Li, Mingchen, Chang, Xiangyu, Chen, Jiasi, Roy-Chowdhury, Amit K., Suresh, Ananda Theertha, Oymak, Samet
The growth and diversity of machine learning applications motivate a rethinking of learning with mobile and edge devices. How can we address diverse client goals and learn with scarce heterogeneous data? While federated learning aims to address these issues, it has challenges hindering a unified solution. Large transformer models have been shown to work across a variety of tasks, achieving remarkable few-shot adaptation. This raises the question: Can clients use a single general-purpose model, rather than custom models for each task, while obeying device and network constraints? In this work, we investigate pretrained transformers (PTF) to achieve these on-device learning goals and thoroughly explore the roles of model size and modularity, where the latter refers to adaptation through modules such as prompts or adapters. Focusing on federated learning, we demonstrate that: (1) Larger scale shrinks the accuracy gaps between alternative approaches and improves heterogeneity robustness. Scale allows clients to run more local SGD epochs, which can significantly reduce the number of communication rounds. At the extreme, clients can achieve respectable accuracy locally, highlighting the potential of fully-local learning. (2) Modularity, by design, enables $>100\times$ less communication in bits. Surprisingly, it also boosts the generalization capability of local adaptation methods and the robustness of smaller PTFs. Finally, it enables clients to solve multiple unrelated tasks simultaneously using a single PTF, whereas full updates are prone to catastrophic forgetting. These insights on scale and modularity motivate a new federated learning approach we call "You Only Load Once" (FedYolo): The clients load a full PTF model once and all future updates are accomplished through communication-efficient modules with limited catastrophic forgetting, where each task is assigned to its own module.
Learning Feature Nonlinearities with Non-Convex Regularized Binned Regression
Oymak, Samet, Mahdavi, Mehrdad, Chen, Jiasi
For various applications, the relations between the dependent and independent variables are highly nonlinear. Consequently, for large-scale complex problems, neural networks and regression trees are commonly preferred over linear models such as Lasso. This work proposes learning the feature nonlinearities by binning feature values and finding the best fit in each quantile using non-convex regularized linear regression. The algorithm first captures the dependence between neighboring quantiles by enforcing smoothness via piecewise-constant/linear approximation and then selects a sparse subset of good features. We prove that the proposed algorithm is statistically and computationally efficient. In particular, it achieves a linear rate of convergence while requiring a near-minimal number of samples. Evaluations on synthetic and real datasets demonstrate that the algorithm is competitive with the current state of the art and accurately learns feature nonlinearities. Finally, we explore an interesting connection between the binning stage of our algorithm and sparse Johnson-Lindenstrauss matrices.
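The binning idea can be sketched for a single feature as follows. This is a simplified illustration, not the proposed algorithm: it fits an unregularized piecewise-constant model (the least-squares solution on one-hot bin indicators is the per-bin mean) and omits the non-convex regularization, the smoothness between neighboring quantiles, and the sparse feature selection described above. Function names and data are hypothetical.

```python
def quantile_edges(xs, n_bins):
    """Interior bin edges taken at empirical quantiles of xs."""
    s = sorted(xs)
    return [s[(len(s) * k) // n_bins] for k in range(1, n_bins)]

def bin_index(x, edges):
    """Index of the bin containing x (number of edges that are <= x)."""
    return sum(1 for e in edges if x >= e)

def fit_binned(xs, ys, n_bins):
    """Piecewise-constant fit: one coefficient (the mean of y) per quantile bin."""
    edges = quantile_edges(xs, n_bins)
    sums = [0.0] * n_bins
    counts = [0] * n_bins
    for x, y in zip(xs, ys):
        b = bin_index(x, edges)
        sums[b] += y
        counts[b] += 1
    means = [s / c if c else 0.0 for s, c in zip(sums, counts)]
    return edges, means

def predict_binned(x, edges, means):
    """Predict with the coefficient of the bin that x falls into."""
    return means[bin_index(x, edges)]

# Toy nonlinear target (assumed): y = x^2 on a symmetric grid.
xs = [i / 10 for i in range(-20, 21)]
ys = [x * x for x in xs]
edges, means = fit_binned(xs, ys, n_bins=4)
```

On this symmetric target a purely linear fit in `x` would have zero slope, while the binned fit recovers the U shape: the two outer bins get larger coefficients than the two inner ones.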