Tensor Attention Training: Provably Efficient Learning of Higher-order Transformers
Jiuxiang Gu, Yingyu Liang, Zhenmei Shi, Zhao Song, Yufa Zhou
arXiv.org Artificial Intelligence
Tensor Attention, a multi-view attention mechanism that can capture high-order correlations among multiple modalities, can overcome the representational limitations of classical matrix attention. However, its $\Omega(n^3)$ time complexity, where $n$ is the input sequence length, poses a significant obstacle to its practical implementation in transformers. In this work, we prove that, under a bounded-entries assumption, the backward gradient of tensor attention training can be computed in almost linear $n^{1+o(1)}$ time, matching the complexity of its forward computation. We provide a closed-form solution for the gradient and propose a fast computation method that combines polynomial approximation with tensor algebraic tricks. Furthermore, we prove the necessity and tightness of our assumption through hardness analysis, showing that slightly weakening it renders the gradient problem unsolvable in truly subcubic time. Our theoretical results establish the feasibility of efficient higher-order transformer training and may facilitate practical applications of tensor attention architectures.
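To make the cost gap concrete, here is a minimal NumPy sketch of the *forward* pass only (not the paper's gradient algorithm). It assumes the formulation commonly used for tensor attention: scores $A = \exp(Q (K_1 \oslash K_2)^\top / d) \in \mathbb{R}^{n \times n^2}$ and output $D^{-1} A (V_1 \oslash V_2)$, where $\oslash$ is the column-wise Kronecker (Khatri-Rao) product and $D = \mathrm{diag}(A \mathbf{1})$; the paper's exact notation and scaling may differ. In place of the paper's polynomial method, the sketch swaps in a plain truncated Taylor series for $\exp$, and uses the identity $(U_2 \oslash U_3)^\top (V_1 \oslash V_2) = (U_2^\top V_1) \circ (U_3^\top V_2)$ so the $n^2$-row matrices are never formed. All function names (`row_kron`, `taylor_features`, `tensor_attention_naive`, `tensor_attention_poly`) are hypothetical.

```python
import math

import numpy as np


def row_kron(A, B):
    """Column-wise Kronecker (Khatri-Rao) product: row (j1, j2) is A[j1] * B[j2].

    For n x d inputs this is the n^2 x d matrix the naive algorithm materializes.
    """
    n_a, k = A.shape
    n_b = B.shape[0]
    return (A[:, None, :] * B[None, :, :]).reshape(n_a * n_b, k)


def tensor_attention_naive(Q, K1, K2, V1, V2):
    """Exact softmax tensor attention; builds an n x n^2 score matrix, O(n^3 d) time."""
    d = Q.shape[1]
    A = np.exp(Q @ row_kron(K1, K2).T / d)      # n x n^2 unnormalized scores
    out = A @ row_kron(V1, V2)                   # n x d
    return out / A.sum(axis=1, keepdims=True)    # apply D^{-1} (row normalization)


def taylor_features(X, degree, scale):
    """Map row x to concat_{t=0..degree} scale(t) * (t-fold tensor power of x).

    Inner products of such features reproduce powers <x, y>^t, so the map has
    r = sum_t d^t columns (grows fast with degree; illustration only).
    """
    n = X.shape[0]
    power = np.ones((n, 1))
    blocks = [power * scale(0)]
    for t in range(1, degree + 1):
        power = (power[:, :, None] * X[:, None, :]).reshape(n, -1)  # x^{(tensor t)}
        blocks.append(power * scale(t))
    return np.concatenate(blocks, axis=1)


def tensor_attention_poly(Q, K1, K2, V1, V2, degree=6):
    """Approximate tensor attention in O(n r d) time without any n^2-row matrix.

    exp(s) is replaced by its degree-`degree` Taylor polynomial (an assumed
    stand-in for the paper's polynomial method), giving a low-rank factorization
    A ~= U1 (U2 khatri_rao U3)^T, because (k1*k2)^{(tensor t)} equals
    k1^{(tensor t)} * k2^{(tensor t)} elementwise.
    """
    n, d = Q.shape
    U1 = taylor_features(Q / d, degree, lambda t: 1.0 / math.factorial(t))
    U2 = taylor_features(K1, degree, lambda t: 1.0)
    U3 = taylor_features(K2, degree, lambda t: 1.0)
    # Tensor trick: (U2 kr U3)^T (V1 kr V2) = (U2^T V1) * (U3^T V2), each r x d.
    numer = U1 @ ((U2.T @ V1) * (U3.T @ V2))     # approximates A (V1 kr V2)
    ones = np.ones((n, 1))
    denom = U1 @ ((U2.T @ ones) * (U3.T @ ones))  # approximates row sums A 1
    return numer / denom


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Q, K1, K2, V1, V2 = (0.5 * rng.standard_normal((32, 4)) for _ in range(5))
    err = np.abs(tensor_attention_naive(Q, K1, K2, V1, V2)
                 - tensor_attention_poly(Q, K1, K2, V1, V2))
    print("max abs error:", err.max())
```

The approximation is accurate here only because the scores are small in magnitude, which mirrors why a bounded-entries assumption matters for polynomial approximation of $\exp$. The feature dimension $r = \sum_{t \le p} d^t$ explodes with degree $p$, so this sketch is linear in $n$ but not efficient in general; the $n^{1+o(1)}$ guarantee relies on the paper's own construction and assumptions.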
May 25, 2024