Mirrokni, Vahab
PolySketchFormer: Fast Transformers via Sketching Polynomial Kernels
Kacham, Praneeth, Mirrokni, Vahab, Zhong, Peilin
The quadratic time and memory complexity inherent to self-attention mechanisms, with respect to sequence length, presents a critical computational bottleneck in the training and deployment of large-scale Transformer-based language models. Recent theoretical results indicate the intractability of sub-quadratic softmax attention approximation under reasonable complexity assumptions. This paper addresses this challenge by first demonstrating that polynomial attention with high degree can effectively replace softmax without sacrificing model quality. Next, we develop polynomial sketching techniques from numerical linear algebra to achieve linear-time polynomial attention with approximation guarantees. Crucially, our approach achieves this speedup without requiring the sparsification of attention matrices. We also present a block-based algorithm to apply causal masking efficiently. Combining these techniques, we provide \emph{PolySketchFormer}, a practical linear-time Transformer architecture for language modeling that offers provable guarantees. We validate PolySketchFormer empirically by training language models capable of handling long contexts. These experiments utilize both synthetic and real-world datasets (PG19, Wikipedia and C4) on Google Cloud TPUs. For context lengths of 32k and GPT-2 style models, our model achieves a 2.5-4x speedup in training compared to FlashAttention, with no observed degradation in quality across our experiments.
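The linear-time claim for polynomial attention can be illustrated with a minimal sketch. It assumes degree 2, no causal masking, and an exact feature map (the flattened outer product) in place of the sketched feature maps the abstract refers to, so it shows only the reordering of the computation and not the PolySketchFormer algorithm itself. All function names are hypothetical.

```python
from typing import List

Vector = List[float]
Matrix = List[Vector]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def phi(x: Vector) -> Vector:
    """Exact degree-2 feature map (flattened outer product of x with itself).

    It satisfies dot(phi(q), phi(k)) == dot(q, k) ** 2.
    Assumption: the paper replaces this map with a lower-dimensional sketch.
    """
    return [a * b for a in x for b in x]


def naive_poly_attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Quadratic-time reference: form every weight (q . k)^2 explicitly."""
    out = []
    for q in Q:
        weights = [dot(q, k) ** 2 for k in K]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, V)) / total
                    for j in range(len(V[0]))])
    return out


def linear_poly_attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Same output, but keys and values are summarized once, then reused."""
    key_feats = [phi(k) for k in K]
    m, dv = len(key_feats[0]), len(V[0])
    kv_summary = [[0.0] * dv for _ in range(m)]  # sum_i phi(k_i) v_i^T
    key_sum = [0.0] * m                          # sum_i phi(k_i)
    for f, v in zip(key_feats, V):
        for a in range(m):
            key_sum[a] += f[a]
            for j in range(dv):
                kv_summary[a][j] += f[a] * v[j]
    out = []
    for q in Q:
        fq = phi(q)
        denom = dot(fq, key_sum)  # equals sum_i (q . k_i)^2
        out.append([sum(fq[a] * kv_summary[a][j] for a in range(m)) / denom
                    for j in range(dv)])
    return out
```

The second function touches each key and each query once, so its cost grows linearly in the sequence length for a fixed feature dimension; the price is the squared feature dimension, which is what sketching is meant to reduce.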
Learning from Aggregate responses: Instance Level versus Bag Level Loss Functions
Javanmard, Adel, Chen, Lin, Mirrokni, Vahab, Badanidiyuru, Ashwinkumar, Fu, Gang
Due to the rise of privacy concerns, in many practical applications the training data is aggregated before being shared with the learner, in order to protect the privacy of users' sensitive responses. In an aggregate learning framework, the dataset is grouped into bags of samples, where each bag is available only with an aggregate response, providing a summary of individuals' responses in that bag. In this paper, we study two natural loss functions for learning from aggregate responses: the bag-level loss and the instance-level loss. In the former, the model is learnt by minimizing a loss between aggregate responses and aggregate model predictions, while in the latter the model aims to fit individual predictions to the aggregate responses. In this work, we show that the instance-level loss can be perceived as a regularized form of the bag-level loss. This observation lets us compare the two approaches with respect to bias and variance of the resulting estimators, and introduce a novel interpolating estimator which combines the two approaches. For linear regression tasks, we provide a precise characterization of the risk of the interpolating estimator in an asymptotic regime where the size of the training set grows in proportion to the feature dimension. Our analysis allows us to theoretically understand the effect of different factors, such as bag size, on the model's prediction risk. In addition, we propose a mechanism for differentially private learning from aggregate responses and derive the optimal bag size in terms of the prediction risk-privacy trade-off. We also carry out thorough experiments to corroborate our theory and show the efficacy of the interpolating estimator.
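The relationship between the two losses can be made concrete with a small sketch. It assumes a linear model, squared loss, and bag means as the aggregate responses; these choices and all names are illustrative assumptions rather than the paper's exact setup. Under squared loss, the instance-level loss equals the bag-level loss plus the within-bag variance of the predictions, which is one elementary way to see the instance-level loss as a regularized bag-level loss.

```python
from typing import List, Tuple

# A bag holds the feature vectors of its samples and one aggregate response
# (assumed here to be the mean of the individual responses).
Bag = Tuple[List[List[float]], float]


def predict(theta: List[float], x: List[float]) -> float:
    return sum(t * xi for t, xi in zip(theta, x))


def bag_level_loss(theta: List[float], bags: List[Bag]) -> float:
    """Squared gap between the aggregate response and the mean prediction."""
    total = 0.0
    for xs, y_bar in bags:
        preds = [predict(theta, x) for x in xs]
        total += (y_bar - sum(preds) / len(preds)) ** 2
    return total / len(bags)


def instance_level_loss(theta: List[float], bags: List[Bag]) -> float:
    """Every individual prediction is fitted to its bag's aggregate response."""
    total = 0.0
    for xs, y_bar in bags:
        preds = [predict(theta, x) for x in xs]
        total += sum((y_bar - p) ** 2 for p in preds) / len(preds)
    return total / len(bags)


def within_bag_prediction_variance(theta: List[float], bags: List[Bag]) -> float:
    """Average over bags of the variance of the predictions inside each bag."""
    total = 0.0
    for xs, _ in bags:
        preds = [predict(theta, x) for x in xs]
        mean_pred = sum(preds) / len(preds)
        total += sum((p - mean_pred) ** 2 for p in preds) / len(preds)
    return total / len(bags)
```

For any `theta`, `instance_level_loss` equals `bag_level_loss` plus `within_bag_prediction_variance`; scaling the variance term by a weight between 0 and 1 gives a simple family that interpolates between the two losses.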
Analysis of Dual-Based PID Controllers through Convolutional Mirror Descent
Balseiro, Santiago R., Lu, Haihao, Mirrokni, Vahab, Sivan, Balasubramanian
Dual-based proportional-integral-derivative (PID) controllers are often employed in practice to solve online allocation problems with global constraints, such as budget pacing in online advertising. However, these controllers are used in a heuristic fashion and come with no provable guarantees on their performance. This paper provides the first regret bounds on the performance of dual-based PID controllers for online allocation problems. We do so by first establishing a fundamental connection between dual-based PID controllers and a new first-order algorithm for online convex optimization called \emph{Convolutional Mirror Descent} (CMD), which updates iterates based on a weighted moving average of past gradients. CMD recovers, in a special case, online mirror descent with momentum and optimistic mirror descent. We establish sufficient conditions under which CMD attains low regret for general online convex optimization problems with adversarial inputs. We leverage this new result to give the first regret bound for dual-based PID controllers for online allocation problems. As a byproduct of our proofs, we provide the first regret bound for CMD for non-smooth convex optimization, which might be of independent interest.
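The idea of updating iterates with a weighted moving average of past gradients can be sketched as follows. This is one plausible reading of the abstract's description, restricted to the Euclidean case (projected gradient descent on an interval in one dimension); it is not the paper's definition of CMD, and the function names, weights, and step size are assumptions made for illustration.

```python
from collections import deque
from typing import Callable, List, Sequence


def clip(x: float, lo: float, hi: float) -> float:
    """Euclidean projection of a scalar onto the interval [lo, hi]."""
    return max(lo, min(hi, x))


def weighted_average_descent(
    gradient: Callable[[int, float], float],
    x0: float,
    step_size: float,
    weights: Sequence[float],
    lo: float,
    hi: float,
    rounds: int,
) -> List[float]:
    """Projected descent along a weighted combination of recent gradients.

    weights[0] multiplies the newest gradient, weights[1] the one before it,
    and so on. With weights == [1.0] this is plain projected online
    gradient descent.
    """
    history: deque = deque(maxlen=len(weights))  # newest gradient first
    x = x0
    iterates = [x]
    for t in range(rounds):
        history.appendleft(gradient(t, x))
        # In early rounds fewer gradients exist; zip uses those available.
        direction = sum(w * g for w, g in zip(weights, history))
        x = clip(x - step_size * direction, lo, hi)
        iterates.append(x)
    return iterates
```

A usage example under the same assumptions: with the fixed loss `(x - 3) ** 2`, whose gradient is `2 * (x - 3)`, calling `weighted_average_descent(lambda t, x: 2 * (x - 3), 0.0, 0.1, [0.6, 0.4], -10.0, 10.0, 200)` yields iterates that approach 3.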
HyperAttention: Long-context Attention in Near-Linear Time
Han, Insu, Jayaram, Rajesh, Karbasi, Amin, Mirrokni, Vahab, Woodruff, David P., Zandieh, Amir
We present an approximate attention mechanism named HyperAttention to address the computational challenges posed by the growing complexity of long contexts used in Large Language Models (LLMs). Recent work suggests that in the worst-case scenario, quadratic time is necessary unless the entries of the attention matrix are bounded or the matrix has low stable rank. We introduce two parameters which measure: (1) the max column norm in the normalized attention matrix, and (2) the ratio of row norms in the unnormalized attention matrix after detecting and removing large entries. We use these fine-grained parameters to capture the hardness of the problem. Despite previous lower bounds, we are able to achieve a linear time sampling algorithm even when the matrix has unbounded entries or a large stable rank, provided the above parameters are small. HyperAttention features a modular design that easily accommodates integration of other fast low-level implementations, particularly FlashAttention. Empirically, employing Locality Sensitive Hashing (LSH) to identify large entries, HyperAttention outperforms existing methods, giving significant speed improvements compared to state-of-the-art solutions like FlashAttention. We validate the empirical performance of HyperAttention on a variety of long-context datasets. For example, HyperAttention makes the inference time of ChatGLM2 50\% faster on 32k context length while perplexity increases from 5.6 to 6.3. At larger context lengths, e.g., 131k, with causal masking, HyperAttention offers a 5-fold speedup on a single attention layer.
Anonymous Learning via Look-Alike Clustering: A Precise Analysis of Model Generalization
Javanmard, Adel, Mirrokni, Vahab
While personalized recommendation systems have become increasingly popular, ensuring user data protection remains a top concern in the development of these learning systems. A common approach to enhancing privacy involves training models using anonymous data rather than individual data. In this paper, we explore a natural technique called \emph{look-alike clustering}, which involves replacing sensitive features of individuals with the cluster's average values. We provide a precise analysis of how training models using anonymous cluster centers affects their generalization capabilities. We focus on an asymptotic regime where the size of the training set grows in proportion to the feature dimension. Our analysis is based on the Convex Gaussian Minimax Theorem (CGMT) and allows us to theoretically understand the role of different model components on the generalization error. In addition, we demonstrate that in certain high-dimensional regimes, training over anonymous cluster centers acts as a regularization and improves the generalization error of the trained models. Finally, we corroborate our asymptotic theory with finite-sample numerical experiments, where we observe a perfect match when the sample size is only on the order of a few hundred.
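The look-alike clustering step described above can be sketched in a few lines. The sketch assumes the cluster assignment is already given and that sensitive features are identified by column index; how clusters are formed is outside its scope, and the function name is hypothetical.

```python
from typing import Dict, List, Sequence


def look_alike_anonymize(
    rows: List[List[float]],
    cluster_ids: Sequence[int],
    sensitive_cols: Sequence[int],
) -> List[List[float]]:
    """Replace each sensitive feature with its cluster's average value.

    Non-sensitive columns are copied unchanged. Assumption: cluster_ids[i]
    is the precomputed cluster of rows[i].
    """
    sums: Dict[int, Dict[int, float]] = {}
    counts: Dict[int, int] = {}
    for row, cid in zip(rows, cluster_ids):
        counts[cid] = counts.get(cid, 0) + 1
        col_sums = sums.setdefault(cid, {c: 0.0 for c in sensitive_cols})
        for c in sensitive_cols:
            col_sums[c] += row[c]

    anonymized = []
    for row, cid in zip(rows, cluster_ids):
        new_row = list(row)
        for c in sensitive_cols:
            new_row[c] = sums[cid][c] / counts[cid]
        anonymized.append(new_row)
    return anonymized
```

After this step, all users in the same cluster share identical values in the sensitive columns, so a model trained on the output sees only cluster-level information for those features.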
Replicable Clustering
Esfandiari, Hossein, Karbasi, Amin, Mirrokni, Vahab, Velegkas, Grigoris, Zhou, Felix
We design replicable algorithms in the context of statistical clustering under the recently introduced notion of replicability from Impagliazzo et al. [2022]. According to this definition, a clustering algorithm is replicable if, with high probability, its output induces the exact same partition of the sample space after two executions on different inputs drawn from the same distribution, when its internal randomness is shared across the executions. We propose such algorithms for the statistical $k$-medians, statistical $k$-means, and statistical $k$-centers problems by utilizing approximation routines for their combinatorial counterparts in a black-box manner. In particular, we demonstrate a replicable $O(1)$-approximation algorithm for statistical Euclidean $k$-medians ($k$-means) with $\operatorname{poly}(d)$ sample complexity. We also describe an $O(1)$-approximation algorithm with an additional $O(1)$-additive error for statistical Euclidean $k$-centers, albeit with $\exp(d)$ sample complexity. In addition, we provide experiments on synthetic distributions in 2D using the $k$-means++ implementation from sklearn as a black-box that validate our theoretical results.
Learning Rate Schedules in the Presence of Distribution Shift
Fahrbach, Matthew, Javanmard, Adel, Mirrokni, Vahab, Worah, Pratik
We design learning rate schedules that minimize regret for SGD-based online learning in the presence of a changing data distribution. We fully characterize the optimal learning rate schedule for online linear regression via a novel analysis with stochastic differential equations. For general convex loss functions, we propose new learning rate schedules that are robust to distribution shift, and we give upper and lower bounds for the regret that differ only by constants. For non-convex loss functions, we define a notion of regret based on the gradient norm of the estimated models and propose a learning rate schedule that minimizes an upper bound on the total expected regret. Intuitively, one expects changing loss landscapes to require more exploration, and we confirm that optimal learning rate schedules typically increase in the presence of distribution shift. Finally, we provide experiments for high-dimensional regression models and neural networks to illustrate these learning rate schedules and their cumulative regret.
Causal Inference with Differentially Private (Clustered) Outcomes
Javanmard, Adel, Mirrokni, Vahab, Pouget-Abadie, Jean
Estimating causal effects from randomized experiments is only feasible if participants agree to reveal their potentially sensitive responses. Of the many ways of ensuring privacy, label differential privacy is a widely used measure of an algorithm's privacy guarantee, which might encourage participants to share responses without running the risk of de-anonymization. Many differentially private mechanisms inject noise into the original dataset to achieve this privacy guarantee, which increases the variance of most statistical estimators and makes the precise measurement of causal effects difficult: there exists a fundamental privacy-variance trade-off in performing causal analyses from differentially private data. With the aim of achieving lower variance for stronger privacy guarantees, we suggest a new differential privacy mechanism, "Cluster-DP", which leverages any given cluster structure of the data while still allowing for the estimation of causal effects. We show that, depending on an intuitive measure of cluster quality, we can improve the variance loss while maintaining our privacy guarantees. We compare its performance, theoretically and empirically, to that of its unclustered version and a more extreme uniform-prior version which does not use any of the original response distribution, both of which are special cases of the "Cluster-DP" algorithm.
Measuring Re-identification Risk
Carey, CJ, Dick, Travis, Epasto, Alessandro, Javanmard, Adel, Karlin, Josh, Kumar, Shankar, Medina, Andres Munoz, Mirrokni, Vahab, Nunes, Gabriel Henrique, Vassilvitskii, Sergei, Zhong, Peilin
In this work, we present a new theoretical framework to measure re-identification risk in user representations. Our framework, based on hypothesis testing, formally bounds the probability that an attacker may be able to obtain the identity of a user from their representation. As an application, we show how our framework is general enough to model important real-world applications such as Chrome's Topics API for interest-based advertising. We complement our theoretical bounds by showing provably good attack algorithms for re-identification that we use to estimate the re-identification risk in the Topics API. We believe this work provides a rigorous and interpretable notion of re-identification risk and a framework to measure it that can be used to inform real-world applications.
TF-GNN: Graph Neural Networks in TensorFlow
Ferludin, Oleksandr, Eigenwillig, Arno, Blais, Martin, Zelle, Dustin, Pfeifer, Jan, Sanchez-Gonzalez, Alvaro, Li, Wai Lok Sibon, Abu-El-Haija, Sami, Battaglia, Peter, Bulut, Neslihan, Halcrow, Jonathan, de Almeida, Filipe Miguel Gonçalves, Gonnet, Pedro, Jiang, Liangze, Kothari, Parth, Lattanzi, Silvio, Linhares, André, Mayer, Brandon, Mirrokni, Vahab, Palowitch, John, Paradkar, Mihir, She, Jennifer, Tsitsulin, Anton, Villela, Kevin, Wang, Lisa, Wong, David, Perozzi, Bryan
TensorFlow-GNN (TF-GNN) is a scalable library for Graph Neural Networks in TensorFlow. It is designed from the bottom up to support the kinds of rich heterogeneous graph data that occur in today's information ecosystems. In addition to enabling machine learning researchers and advanced developers, TF-GNN offers low-code solutions to empower the broader developer community in graph learning. Many production models at Google use TF-GNN, and it has recently been released as an open source project. In this paper we describe the TF-GNN data model, its Keras message passing API, and relevant capabilities such as graph sampling and distributed training.