Deora, Puneesh
Implicit Bias and Fast Convergence Rates for Self-attention
Vasudeva, Bhavya, Deora, Puneesh, Thrampoulidis, Christos
Self-attention serves as the fundamental building block of transformers, distinguishing them from traditional neural networks (Vaswani et al., 2017) and driving their outstanding performance across various applications, including natural language processing and generation (Devlin et al., 2019; Brown et al., 2020; Raffel et al., 2020), as well as computer vision (Dosovitskiy et al., 2021; Radford et al., 2021; Touvron et al., 2021). With transformers establishing themselves as the de facto deep-learning architecture, driving advancements in applications seamlessly integrated into society's daily life at an unprecedented pace (OpenAI, 2022), there has been a surge of recent interest in the mathematical study of the fundamental optimization and statistical principles of the self-attention mechanism; see Section 6 on related work for an overview. In pursuit of this objective, Tarzanagh et al. (2023b,a) have initiated an investigation into the implicit bias of gradient descent (GD) in training a self-attention layer with a fixed linear decoder in a binary classification task. Concretely, the implicit-bias paradigm seeks to characterize structural properties of the weights learned by GD when the training objective has multiple solutions. The prototypical instance of this paradigm is GD training of linear logistic regression on separable data: among infinitely many possible solutions to logistic-loss minimization (each linear separator defines one such solution), GD learns weights that converge in direction to the (unique) max-margin class separator (Soudry et al., 2018; Ji and Telgarsky, 2018). Notably, convergence is global, holding irrespective of the initial weights' direction, and comes with explicit rates that characterize its speed with respect to the number of iterations. Drawing an analogy to this prototypical instance, when training self-attention with a linear decoder in a binary classification task, Tarzanagh et al. (2023a) define a hard-margin SVM problem (W-SVM) that separates, with maximal margin, optimal input tokens from non-optimal ones based on their respective softmax logits.
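For concreteness, here is a generic statement of the prototypical result referenced above (notation is illustrative rather than copied from the cited papers): for linear logistic regression on linearly separable data $\{(x_i, y_i)\}_{i=1}^n$ with $y_i \in \{\pm 1\}$, the gradient-descent iterates $w_t$ satisfy

$$\lim_{t\to\infty} \frac{w_t}{\|w_t\|_2} \;=\; \frac{\hat{w}}{\|\hat{w}\|_2}, \qquad \hat{w} \;:=\; \arg\min_{w} \|w\|_2 \quad \text{s.t.} \quad y_i\, w^\top x_i \ge 1 \ \ \forall i,$$

i.e., GD converges in direction to the hard-margin SVM solution (Soudry et al., 2018; Ji and Telgarsky, 2018). The (W-SVM) problem of Tarzanagh et al. (2023a) plays the analogous role for self-attention, with the margin constraints imposed on softmax logits so that each input's optimal tokens are separated from its non-optimal ones.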
On the Optimization and Generalization of Multi-head Attention
Deora, Puneesh, Ghaderi, Rouzbeh, Taheri, Hossein, Thrampoulidis, Christos
The training and generalization dynamics of the Transformer's core mechanism, namely the attention mechanism, remain under-explored. Moreover, existing analyses primarily focus on single-head attention. Inspired by the demonstrated benefits of overparameterization when training fully-connected networks, we investigate the potential optimization and generalization advantages of using multiple attention heads. Towards this goal, we derive convergence and generalization guarantees for gradient-descent training of a single-layer multi-head self-attention model, under a suitable realizability condition on the data. We then establish primitive conditions on the initialization that ensure realizability holds. Finally, we demonstrate that these conditions are satisfied for a simple tokenized-mixture model. We expect that the analysis can be extended to various data-model and architecture variations.
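The sketch below is a minimal, self-contained illustration of the kind of setup described in the abstract, not the paper's exact construction: a single-layer multi-head self-attention model with a fixed linear decoder, trained by plain gradient descent on the logistic loss for binary classification. The dimensions, the synthetic data, and the planted labeling rule are assumptions made purely for illustration.

```python
# Illustrative sketch (assumptions: dimensions, synthetic data, planted labels):
# single-layer multi-head self-attention + fixed linear decoder, trained by GD.
import jax
import jax.numpy as jnp

T, d, H, n = 6, 8, 4, 64                  # tokens per sequence, token dim, heads, samples
k1, k2, k3, k4, k5, k6 = jax.random.split(jax.random.PRNGKey(0), 6)

# Synthetic data: sequences X_i in R^{T x d}; labels planted via a random direction.
X = jax.random.normal(k1, (n, T, d))
w_star = jax.random.normal(k2, (d,))
y = jnp.sign(X.mean(axis=1) @ w_star)     # labels in {-1, +1}

# Trainable per-head key/query/value weights; the linear decoder U stays fixed.
params = {
    "WK": 0.1 * jax.random.normal(k3, (H, d, d)),
    "WQ": 0.1 * jax.random.normal(k4, (H, d, d)),
    "WV": 0.1 * jax.random.normal(k5, (H, d, d)),
}
U = jax.random.normal(k6, (H, d))         # fixed decoder, not part of `params`

def model(params, X):
    """Score f(X) = sum_h u_h^T mean_t [ softmax(X WQ_h WK_h^T X^T) X WV_h ]_t."""
    scores = jnp.zeros(X.shape[0])
    for h in range(H):
        Q = X @ params["WQ"][h]                                   # (n, T, d)
        K = X @ params["WK"][h]                                   # (n, T, d)
        V = X @ params["WV"][h]                                   # (n, T, d)
        attn = jax.nn.softmax(jnp.einsum("ntd,nsd->nts", Q, K), axis=-1)
        pooled = jnp.einsum("nts,nsd->ntd", attn, V).mean(axis=1)  # average-pool tokens
        scores = scores + pooled @ U[h]
    return scores

def loss(params, X, y):
    # Logistic loss for labels in {-1, +1}.
    return jnp.mean(jnp.log1p(jnp.exp(-y * model(params, X))))

grad_fn = jax.jit(jax.grad(loss))

lr = 0.5
for step in range(301):
    grads = grad_fn(params, X, y)
    params = {k: v - lr * grads[k] for k, v in params.items()}
    if step % 100 == 0:
        print(f"step {step:3d}  loss {float(loss(params, X, y)):.4f}")
```

Keeping the decoder U fixed mirrors settings in which only the attention weights are trained; the planted labeling rule is a stand-in for a data model under which training can succeed, loosely in the spirit of a realizability condition.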