Exact Expressive Power of Transformers with Padding

Neural Information Processing Systems

Chain of thought is a natural inference-time method for increasing the computational power of transformer-based large language models (LLMs), but comes at the cost of sequential decoding. Are there more efficient alternatives to expand a transformer's expressive power without adding parameters? We consider transformers with *padding* tokens as a form of parallelizable test-time compute. We show that averaging-hard-attention, masked-pre-norm transformers with polynomial padding recognize precisely the class $\mathsf{FO}$-uniform $\mathsf{TC}^0$ of extremely parallelizable problems. While the $\mathsf{TC}^0$ upper bound was known, proving a matching lower bound had been elusive. Further, our novel analysis reveals the precise expanded power of padded transformers when coupled with another form of inference-time compute, namely dynamically increasing depth via *looping*.


The Complexity of Finding Local Optima in Contrastive Learning

Neural Information Processing Systems

Contrastive learning is a powerful technique for discovering meaningful data representations by optimizing objectives based on $\textit{contrastive information}$, often given as a set of weighted triplets $\{(x_i, y_i^+, z_{i}^-)\}_{i = 1}^m$ indicating that an anchor $x_i$ is more similar to a positive example $y_i$ than to a negative example $z_i$. The goal is to find representations (e.g., embeddings in $\mathbb{R}^d$ or a tree metric) where anchors are placed closer to positive than to negative examples. While finding $\textit{global}$ optima of contrastive objectives is $\mathsf{NP}$-hard, the complexity of finding $\text{\textit{local}}$ optima---representations that do not improve by local search algorithms such as gradient-based methods---remains open. Our work settles the complexity of finding local optima in various contrastive learning problems by proving $\mathsf{PLS}$-hardness in discrete settings (e.g., maximize satisfied triplets) and $\mathsf{CLS}$-hardness in continuous settings (e.g., minimize Triplet Loss), where $\mathsf{PLS}$ (Polynomial Local Search) and $\mathsf{CLS}$ (Continuous Local Search) are well-studied complexity classes capturing local search dynamics in discrete and continuous optimization, respectively.
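The two objective types named in the abstract can be made concrete with a small sketch. It assumes Euclidean embeddings, a hinge-style Triplet Loss with a margin parameter, and weighted triplets stored as `(anchor, positive, negative, weight)` tuples. The function names and the margin form are illustrative assumptions, not the paper's definitions.

```python
import math

def dist(u, v):
    """Euclidean distance between two points given as tuples."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def satisfied_weight(emb, triplets):
    """Discrete objective: total weight of triplets whose anchor is
    strictly closer to the positive than to the negative example."""
    return sum(w for x, y, z, w in triplets
               if dist(emb[x], emb[y]) < dist(emb[x], emb[z]))

def triplet_loss(emb, triplets, margin=1.0):
    """Continuous objective (assumed hinge form with a margin)."""
    return sum(w * max(0.0, dist(emb[x], emb[y]) - dist(emb[x], emb[z]) + margin)
               for x, y, z, w in triplets)

# Hypothetical toy instance: three points on a line, two weighted triplets.
emb = {"a": (0.0, 0.0), "b": (1.0, 0.0), "c": (3.0, 0.0)}
triplets = [("a", "b", "c", 2.0), ("c", "a", "b", 1.0)]
print(satisfied_weight(emb, triplets), triplet_loss(emb, triplets))
```

A local optimum in the discrete setting is an embedding for which no single allowed local move increases `satisfied_weight`. In the continuous setting, the analogous notion is a point where gradient-based steps no longer decrease `triplet_loss`.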


Computational Aspects of Bayesian Persuasion under Approximate Best Response

Neural Information Processing Systems

We study Bayesian persuasion under approximate best response, where the receiver may choose any action that is not too suboptimal given their posterior belief upon receiving the signal. We focus on the computational aspects of the problem, aiming to design algorithms that efficiently compute (almost) optimal strategies for the sender. Despite the absence of the revelation principle, which has been one of the most powerful tools in Bayesian persuasion, we design polynomial-time exact algorithms for the problem when either the state space or the action space is small, as well as a quasi-polynomial-time approximation scheme (QPTAS) for the general problem. On the negative side, we show there is no polynomial-time exact algorithm for the general problem unless $\mathsf{P} = \mathsf{NP}$. Our results build on several new algorithmic ideas, which might be useful in other principal-agent problems where robustness is desired.


On the Complexity of Learning Sparse Functions with Statistical and Gradient Queries

Neural Information Processing Systems

The goal of this paper is to investigate the complexity of gradient algorithms when learning sparse functions (juntas). We introduce a type of Statistical Queries ($\mathsf{SQ}$), which we call Differentiable Learning Queries ($\mathsf{DLQ}$), to model gradient queries on a specified loss with respect to an arbitrary model. We provide a tight characterization of the query complexity of $\mathsf{DLQ}$ for learning the support of a sparse function over generic product distributions. This complexity crucially depends on the loss function. For the squared loss, $\mathsf{DLQ}$ matches the complexity of Correlation Statistical Queries $(\mathsf{CSQ})$--potentially much worse than $\mathsf{SQ}$. But for other simple loss functions, including the $\ell_1$ loss, $\mathsf{DLQ}$ always achieves the same complexity as $\mathsf{SQ}$. We also provide evidence that $\mathsf{DLQ}$ can indeed capture learning with (stochastic) gradient descent by showing it correctly describes the complexity of learning with a two-layer neural network in the mean field regime and linear scaling.


On the Complexity of Identification in Linear Structural Causal Models

Neural Information Processing Systems

Learning the unknown causal parameters of a linear structural causal model is a fundamental task in causal analysis. The task, known as the problem of identification, asks to estimate the parameters of the model from a combination of assumptions on the graphical structure of the model and observational data, represented as a non-causal covariance matrix. In this paper, we give a new sound and complete algorithm for generic identification which runs in polynomial space. By a standard simulation result, namely $\mathsf{PSPACE} \subseteq \mathsf{EXP}$, this algorithm has exponential running time, which vastly improves on the state-of-the-art double-exponential-time method using a Gröbner basis approach. The paper also presents evidence that parameter identification is computationally hard in general. In particular, we prove that the task of deciding whether, for a given feasible correlation matrix, there is exactly one parameter set or there are two or more parameter sets explaining the observed matrix is hard for $\forall \mathbb{R}$, the co-class of the existential theory of the reals. In particular, this problem is $\mathsf{coNP}$-hard. To the best of our knowledge, this is the first hardness result for some notion of identifiability.


Oja's Algorithm for Streaming Sparse PCA

Neural Information Processing Systems

Oja's algorithm for Streaming Principal Component Analysis (PCA) for $n$ data points in a $d$-dimensional space achieves the same sin-squared error $O(r_{\mathsf{eff}}/n)$ as the offline algorithm in $O(d)$ space and $O(nd)$ time and a single pass through the data points. Here $r_{\mathsf{eff}}$ is the effective rank (ratio of the trace and the principal eigenvalue of the population covariance matrix $\Sigma$). Under this computational budget, we consider the problem of sparse PCA, where the principal eigenvector of $\Sigma$ is $s$-sparse, and $r_{\mathsf{eff}}$ can be large. In this setting, to our knowledge, *there are no known single-pass algorithms* that achieve the minimax error bound in $O(d)$ space and $O(nd)$ time without either requiring strong initialization conditions or assuming further structure (e.g., spiked) of the covariance matrix. We show that a simple single-pass procedure that thresholds the output of Oja's algorithm (the Oja vector) can achieve the minimax error bound under some regularity conditions in $O(d)$ space and $O(nd)$ time. We present a nontrivial and novel analysis of the entries of the unnormalized Oja vector, which involves the projection of a product of independent random matrices on a random initial vector. This is completely different from previous analyses of Oja's algorithm and matrix products, which have been done when $r_{\mathsf{eff}}$ is bounded.
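The procedure described above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the paper's exact method. It assumes a fixed learning rate `lr` and a thresholding rule that keeps the `s` largest-magnitude entries. It also renormalizes after every update, which yields the same direction as normalizing the unnormalized Oja vector once at the end. The function names are hypothetical.

```python
import math
import random

def oja_vector(stream, d, lr, rng):
    """One pass of Oja's update w <- normalize(w + lr * (x . w) * x).
    Uses O(d) memory and O(d) work per data point."""
    w = [rng.gauss(0.0, 1.0) for _ in range(d)]  # random initial vector
    for x in stream:
        proj = sum(xi * wi for xi, wi in zip(x, w))
        w = [wi + lr * proj * xi for xi, wi in zip(x, w)]
        norm = math.sqrt(sum(wi * wi for wi in w))
        w = [wi / norm for wi in w]
    return w

def keep_top_s(w, s):
    """Assumed thresholding rule: zero all but the s largest-magnitude
    entries of w, then renormalize to unit length."""
    order = sorted(range(len(w)), key=lambda i: abs(w[i]), reverse=True)
    keep = set(order[:s])
    v = [wi if i in keep else 0.0 for i, wi in enumerate(w)]
    norm = math.sqrt(sum(vi * vi for vi in v))
    return [vi / norm for vi in v]

def demo(n=2000, d=20, s=3, seed=0):
    """Synthetic stream whose leading direction is supported on the first s coordinates."""
    rng = random.Random(seed)
    def points():
        for _ in range(n):
            g = 3.0 * rng.gauss(0.0, 1.0)  # strength along the sparse direction
            yield [(g / math.sqrt(s) if j < s else 0.0) + rng.gauss(0.0, 1.0)
                   for j in range(d)]
    return keep_top_s(oja_vector(points(), d, lr=0.01, rng=rng), s)
```

Because the data arrive through a generator, the sketch never stores more than one point and one `d`-dimensional vector at a time, matching the $O(d)$ space and $O(nd)$ time budget in the abstract.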


Span-Based Optimal Sample Complexity for Weakly Communicating and General Average Reward MDPs

Neural Information Processing Systems

We study the sample complexity of learning an $\varepsilon$-optimal policy in an average-reward Markov decision process (MDP) under a generative model. For weakly communicating MDPs, we establish the complexity bound $\widetilde{O}\left(SA\frac{\mathsf{H}}{\varepsilon^2} \right)$, where $\mathsf{H}$ is the span of the bias function of the optimal policy and $SA$ is the cardinality of the state-action space. Our result is the first that is minimax optimal (up to log factors) in all parameters $S,A,\mathsf{H}$, and $\varepsilon$, improving on existing work that either assumes uniformly bounded mixing times for all policies or has suboptimal dependence on the parameters. We also initiate the study of sample complexity in general (multichain) average-reward MDPs.


Nimbus: Secure and Efficient Two-Party Inference for Transformers

Neural Information Processing Systems

Transformer models have gained significant attention due to their power in machine learning tasks. Their extensive deployment has raised concerns about the potential leakage of sensitive information during inference. However, when applied to Transformers, existing approaches based on secure two-party computation (2PC) suffer from efficiency limitations in two respects: (1) resource-intensive matrix multiplications in linear layers, and (2) complex non-linear activation functions like $\mathsf{GELU}$ and $\mathsf{Softmax}$. This work presents a new two-party inference framework $\mathsf{Nimbus}$ for Transformer models. Specifically, we propose a new 2PC paradigm to securely compute matrix multiplications based on an outer-product insight, which achieves $2.9\times \sim 12.5\times$ performance improvements compared to the state-of-the-art (SOTA) protocol. Furthermore, through a new observation of utilizing the input distribution, we propose an approach of low-degree polynomial approximation for $\mathsf{GELU}$ and $\mathsf{Softmax}$, which improves the performance of the SOTA polynomial approximation by $2.9\times \sim 4.0\times$, where the average accuracy loss of our approach is 0.08\% compared to the non-2PC inference without privacy. Compared with the SOTA two-party inference, $\mathsf{Nimbus}$ improves the end-to-end performance of $\mathrm{BERT}_{\mathrm{base}}$ inference by $2.7\times \sim 4.7\times$ across different network settings.


In-Context Learning of a Linear Transformer Block: Benefits of the MLP Component and One-Step GD Initialization

Neural Information Processing Systems

We study the \emph{in-context learning} (ICL) ability of a \emph{Linear Transformer Block} (LTB) that combines a linear attention component and a linear multi-layer perceptron (MLP) component. For ICL of linear regression with a Gaussian prior and a \emph{non-zero mean}, we show that LTB can achieve nearly Bayes optimal ICL risk. In contrast, using only linear attention must incur an irreducible additive approximation error. Furthermore, we establish a correspondence between LTB and one-step gradient descent estimators with learnable initialization ($\mathsf{GD}-\beta$), in the sense that every $\mathsf{GD}-\beta$ estimator can be implemented by an LTB estimator and every optimal LTB estimator that minimizes the in-class ICL risk is effectively a $\mathsf{GD}-\beta$ estimator. Finally, we show that $\mathsf{GD}-\beta$ estimators can be efficiently optimized with gradient flow, despite a non-convex training objective. Our results reveal that LTB achieves ICL by implementing $\mathsf{GD}-\beta$, and they highlight the role of MLP layers in reducing approximation error.
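A one-step gradient descent estimator with learnable initialization can be sketched as follows. The parameterization is an assumption made for illustration and may differ from the paper's exact definition of $\mathsf{GD}-\beta$: a learnable start point `beta`, a learnable matrix step size `Gamma`, and one step on the average squared loss over the in-context examples.

```python
def gd_beta_predict(X, y, x_query, beta, Gamma):
    """Predict <w, x_query>, where w = beta + Gamma @ (X^T (y - X beta) / n)
    is one preconditioned gradient step from the initialization beta.
    X: n rows of d features; y: n labels; Gamma: d x d matrix (assumed form)."""
    n, d = len(X), len(beta)
    # Residuals of the initial guess on the in-context examples.
    residual = [y[i] - sum(X[i][j] * beta[j] for j in range(d)) for i in range(n)]
    # Negative gradient of the averaged half squared loss at beta.
    step = [sum(X[i][j] * residual[i] for i in range(n)) / n for j in range(d)]
    w = [beta[j] + sum(Gamma[j][k] * step[k] for k in range(d)) for j in range(d)]
    return sum(w[j] * x_query[j] for j in range(d))

# Hypothetical toy prompt: two examples in two dimensions.
X = [[1.0, 0.0], [0.0, 1.0]]
y = [2.0, 3.0]
Gamma = [[2.0, 0.0], [0.0, 2.0]]
print(gd_beta_predict(X, y, [1.0, 1.0], beta=[0.0, 0.0], Gamma=Gamma))
```

With a zero initialization this reduces to a plain one-step GD estimator. A non-zero `beta` lets the estimator start from the prior mean, which is where the abstract locates the benefit when the Gaussian prior has a non-zero mean.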


From Stochastic Mixability to Fast Rates

Neural Information Processing Systems

Empirical risk minimization (ERM) is a fundamental learning rule for statistical learning problems in which the data is generated according to some unknown distribution $\mathsf{P}$; it returns a hypothesis $f$ chosen from a fixed class $\mathcal{F}$ with small loss $\ell$. In the parametric setting, depending upon $(\ell, \mathcal{F},\mathsf{P})$, ERM can have slow $(1/\sqrt{n})$ or fast $(1/n)$ rates of convergence of the excess risk as a function of the sample size $n$. There exist several results that give sufficient conditions for fast rates in terms of joint properties of $\ell$, $\mathcal{F}$, and $\mathsf{P}$, such as the margin condition and the Bernstein condition. In the non-statistical prediction with expert advice setting, there is an analogous slow and fast rate phenomenon, and it is entirely characterized in terms of the mixability of the loss $\ell$ (there being no role there for $\mathcal{F}$ or $\mathsf{P}$). The notion of stochastic mixability builds a bridge between these two models of learning, reducing to classical mixability in a special case. The present paper presents a direct proof of fast rates for ERM in terms of stochastic mixability of $(\ell,\mathcal{F}, \mathsf{P})$, and in so doing provides new insight into the fast-rates phenomenon.