Choromanski, Krzysztof Marcin
Mnemosyne: Learning to Train Transformers with Transformers
Jain, Deepali, Choromanski, Krzysztof Marcin, Dubey, Avinava, Singh, Sumeet, Sindhwani, Vikas, Zhang, Tingnan, Tan, Jie
In this work, we propose a new class of learnable optimizers, called \textit{Mnemosyne}. It is based on novel spatio-temporal low-rank implicit attention Transformers that can learn to train entire neural network architectures, including other Transformers, without any task-specific optimizer tuning. We show that Mnemosyne: (a) outperforms popular LSTM optimizers (also with new feature engineering to mitigate catastrophic forgetting of LSTMs), (b) can successfully train Transformers while using simple meta-training strategies that require minimal computational resources, and (c) matches the accuracy of SOTA hand-designed optimizers with carefully tuned hyper-parameters (often producing top-performing models). Furthermore, Mnemosyne provides space complexity comparable to that of its hand-designed first-order counterparts, which allows it to scale to training larger sets of parameters. We conduct an extensive empirical evaluation of Mnemosyne on: (a) fine-tuning a wide range of Vision Transformers (ViTs), from medium-size architectures to massive ViT-Hs (36 layers, 16 heads), (b) pre-training BERT models, and (c) soft prompt-tuning large 11B+ T5XXL models. We complement our results with a comprehensive theoretical analysis of the compact associative memory used by Mnemosyne, which, to the best of our knowledge, has not been conducted before.
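The abstract describes the high-level recipe (an attention mechanism consuming per-parameter optimization histories and emitting parameter updates) without code. The snippet below is only a minimal NumPy sketch of that recipe, assuming temporal linear attention over a short, explicit gradient history: all names (LearnedOptimizer, linear_attention, feature_map, W_q/W_k/W_v, hist_len) are hypothetical, the elu(x)+1 feature map is a generic low-rank attention choice rather than the paper's implicit attention, and the explicit history buffer is kept only for readability; it is not the paper's compact associative memory.

```python
import numpy as np

rng = np.random.default_rng(0)

def feature_map(X):
    # Generic positive feature map (elu(x) + 1 style) used by many linear-attention
    # variants; a stand-in for the paper's implicit attention mechanism.
    return np.where(X > 0, X + 1.0, np.exp(X))

def linear_attention(Q, K, V):
    # Low-rank attention in O(L * d * d_v): phi(Q) @ (phi(K)^T V), normalized,
    # never forming the L x L attention matrix.
    Qf, Kf = feature_map(Q), feature_map(K)
    KV = Kf.T @ V                      # (d, d_v)
    Z = Qf @ Kf.sum(axis=0)            # (L,)
    return (Qf @ KV) / (Z[:, None] + 1e-6)

class LearnedOptimizer:
    # Hypothetical learned optimizer: attends over a short per-parameter gradient
    # history and maps the read-out to an update. A real meta-learned optimizer
    # would train W_q / W_k / W_v on a distribution of optimization tasks.
    def __init__(self, hist_len=8, d_model=16):
        self.hist_len = hist_len
        self.W_q = rng.normal(scale=0.1, size=(1, d_model))
        self.W_k = rng.normal(scale=0.1, size=(1, d_model))
        self.W_v = rng.normal(scale=0.1, size=(1, 1))
        self.history = None            # explicit buffer for readability only

    def step(self, params, grads, lr=1e-3):
        g = grads.reshape(-1, 1)                                   # (P, 1)
        if self.history is None:
            self.history = np.zeros((g.shape[0], self.hist_len, 1))
        self.history = np.concatenate([self.history[:, 1:], g[:, None, :]], axis=1)
        updates = np.empty_like(g)
        for p in range(g.shape[0]):                # temporal attention, per parameter
            H = self.history[p]                    # (hist_len, 1)
            Q, K, V = H @ self.W_q, H @ self.W_k, H @ self.W_v
            updates[p] = linear_attention(Q, K, V)[-1]   # read out the latest position
        return params - lr * updates.reshape(params.shape)

# Toy usage: a single optimizer step on a random 5-parameter "model".
opt = LearnedOptimizer()
params, grads = rng.normal(size=5), rng.normal(size=5)
params = opt.step(params, grads)
```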
Learning a Fourier Transform for Linear Relative Positional Encodings in Transformers
Choromanski, Krzysztof Marcin, Li, Shanda, Likhosherstov, Valerii, Dubey, Kumar Avinava, Luo, Shengjie, He, Di, Yang, Yiming, Sarlos, Tamas, Weingarten, Thomas, Weller, Adrian
We propose a new class of linear Transformers called FourierLearner-Transformers (FLTs), which incorporate a wide range of relative positional encoding mechanisms (RPEs). These include regular RPE techniques applied to non-geometric data, as well as novel RPEs operating on sequences of tokens embedded in higher-dimensional Euclidean spaces (e.g., point clouds). FLTs construct the optimal RPE mechanism implicitly by learning its spectral representation. Unlike other architectures combining efficient low-rank linear attention with RPEs, FLTs remain practical in terms of memory usage and do not require additional assumptions about the structure of the RPE mask. FLTs also allow for applying certain structural inductive bias techniques to specify masking strategies, e.g., they provide a way to learn the so-called local RPEs introduced in this paper, which yield accuracy gains compared with several other linear Transformers for language modeling. We also thoroughly test FLTs on other data modalities and tasks, such as image classification and 3D molecular modeling. For 3D data, FLTs are, to the best of our knowledge, the first Transformer architectures providing RPE-enhanced linear attention.
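As a hedged illustration rather than the authors' implementation, the snippet below sketches in NumPy the factorization idea the abstract alludes to: if the RPE mask entries f(r_i - r_j) are approximated through learnable frequencies (a Monte Carlo estimate of the Fourier representation of f), they split into per-query and per-key factors, so the mask can be folded into linear attention without materializing the L x L matrix. All names (flt_attention, feature_map, freqs, pos) are hypothetical, and the elu(x)+1 feature map stands in for whatever low-rank attention variant is actually used.

```python
import numpy as np

rng = np.random.default_rng(0)

def feature_map(X):
    # Generic positive feature map (elu(x) + 1 style) standing in for the
    # low-rank linear attention that the RPE mask is combined with.
    return np.where(X > 0, X + 1.0, np.exp(X))

def flt_attention(Q, K, V, pos, freqs):
    # Q, K: (L, d); V: (L, d_v); pos: (L, p) token positions; freqs: (m, p)
    # learnable spectral parameters. The RPE value f(r_i - r_j) is approximated by
    # (1/m) * sum_k exp(i w_k . r_i) * conj(exp(i w_k . r_j)), which factorizes,
    # so the mask folds into linear attention with no L x L matrix.
    Qf, Kf = feature_map(Q), feature_map(K)                      # (L, d)
    Pq = np.exp(1j * pos @ freqs.T) / np.sqrt(freqs.shape[0])    # (L, m)
    Pk = np.conj(Pq)
    Qc = (Qf[:, :, None] * Pq[:, None, :]).reshape(len(Q), -1)   # (L, d*m)
    Kc = (Kf[:, :, None] * Pk[:, None, :]).reshape(len(K), -1)   # (L, d*m)
    num = (Qc @ (Kc.T @ V)).real                                 # (L, d_v)
    den = (Qc @ Kc.sum(axis=0)).real                             # (L,)
    return num / (den[:, None] + 1e-6)

# Toy usage with 1D positions (plain sequences); pos could equally hold
# 3D point-cloud coordinates, in which case p = 3.
L, d, d_v, p, m = 6, 4, 4, 1, 8
Q, K, V = rng.normal(size=(L, d)), rng.normal(size=(L, d)), rng.normal(size=(L, d_v))
pos = np.arange(L, dtype=float)[:, None]
freqs = rng.normal(size=(m, p))   # would be trained jointly with the attention weights
print(flt_attention(Q, K, V, pos, freqs).shape)   # (6, 4)
```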