Implicit Bias and Fast Convergence Rates for Self-attention
Bhavya Vasudeva, Puneesh Deora, Christos Thrampoulidis
arXiv.org Artificial Intelligence
Self-attention serves as the fundamental building block of transformers, distinguishing them from traditional neural networks (Vaswani et al., 2017) and driving their outstanding performance across various applications, including natural language processing and generation (Devlin et al., 2019; Brown et al., 2020; Raffel et al., 2020), as well as computer vision (Dosovitskiy et al., 2021; Radford et al., 2021; Touvron et al., 2021). With transformers establishing themselves as the de facto deep-learning architecture, driving advancements in applications that are seamlessly integrated into daily life at an unprecedented pace (OpenAI, 2022), there has been a surge of recent interest in the mathematical study of the fundamental optimization and statistical principles of the self-attention mechanism; see Section 6 on related work for an overview. In pursuit of this objective, Tarzanagh et al. (2023b,a) have initiated an investigation into the implicit bias of gradient descent (GD) in training a self-attention layer with a fixed linear decoder in a binary classification task. Concretely, the study paradigm of implicit bias seeks to characterize structural properties of the weights learned by GD when the training objective has multiple solutions. The prototypical instance of this paradigm is GD training of linear logistic regression on separable data: among infinitely many possible solutions to logistic-loss minimization (each linear separator defines one such solution), GD learns weights that converge in direction to the (unique) max-margin class separator (Soudry et al., 2018; Ji and Telgarsky, 2018). Notably, convergence is global, holding irrespective of the initial weights' direction, and comes with explicit rates that characterize its speed with respect to the number of iterations. Drawing an analogy to this prototypical instance, when training self-attention with a linear decoder in a binary classification task, Tarzanagh et al. (2023a) define a hard-margin SVM problem (W-SVM) that separates, with maximal margin, optimal input tokens from non-optimal ones based on their respective softmax logits.
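The prototypical logistic-regression instance described above is easy to reproduce numerically. The sketch below is not from the paper; the toy dataset, initialization, step size, and iteration count are illustrative choices. It runs plain gradient descent on the logistic loss over a separable two-dimensional dataset whose max-margin direction is (1, 0) by construction, and prints how the weight norm diverges while the weight direction aligns with that max-margin separator, as predicted by Soudry et al. (2018) and Ji and Telgarsky (2018).

```python
# Minimal sketch (illustrative, not the paper's attention setting): gradient descent on
# the logistic loss over separable data. The weight norm grows without bound while the
# weight *direction* converges to the hard-margin SVM separator.
import numpy as np

# Separable toy data: positives on the right, negatives on the left.
# By symmetry, the max-margin direction is (1, 0).
X = np.array([[2.0, 1.0], [2.0, -1.0], [-2.0, 1.0], [-2.0, -1.0]])
y = np.array([1.0, 1.0, -1.0, -1.0])
w_svm = np.array([1.0, 0.0])  # known max-margin direction for this toy dataset

def logistic_loss_grad(w):
    """Gradient of the empirical logistic loss (1/n) * sum_i log(1 + exp(-y_i <w, x_i>))."""
    margins = y * (X @ w)
    coeffs = -y / (1.0 + np.exp(margins))   # d/dm log(1 + exp(-m)) = -1 / (1 + exp(m))
    return (coeffs[:, None] * X).mean(axis=0)

w = np.array([0.5, 2.0])   # arbitrary initialization, deliberately misaligned with w_svm
lr = 0.5                   # step size below 1/L for this dataset's smoothness constant
for t in range(1, 200_001):
    w -= lr * logistic_loss_grad(w)
    if t % 50_000 == 0:
        cos = (w @ w_svm) / np.linalg.norm(w)   # cosine of angle to the SVM direction
        print(f"iter {t:7d}  ||w|| = {np.linalg.norm(w):8.3f}  cos(w, w_svm) = {cos:.6f}")
```

Running this prints a slowly growing weight norm together with a cosine approaching 1, which is the directional convergence (with its characteristically slow, logarithmic-in-iterations rate) that motivates the fast-rate analysis pursued for self-attention in this paper.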
Feb-8-2024