Flash Inference: Near Linear Time Inference for Long Convolution Sequence Models and Beyond

Costin-Andrei Oncescu, Sanket Purandare, Stratos Idreos, Sham Kakade

arXiv.org Artificial Intelligence 

Much of the recent progress in deep learning, particularly in large language models (LLMs), has been driven by the transformer architecture [Vaswani et al., 2017]. While these models achieve strong quality, their computational cost scales quadratically in sequence length, both during training and inference. This can become prohibitive for very long contexts, and a number of alternative architectures with better computational scaling in context length have been proposed [Gu and Dao, 2023, Poli et al., 2023, Fu et al., 2024]. While most of these works improve computational efficiency for training, some still scale quadratically in sequence length at inference time and thus do not improve asymptotically over transformers. In this work, we propose a framework for optimizing inference efficiency for a general class of such models. As a case study, which inspired the method, we focus on long convolution sequence models (LCSMs) [Poli et al., 2023, Fu et al., 2022, Romero et al., 2021, Li et al., 2022, Karami and Ghodsi, 2024, Fu et al., 2023a]. Our approach is not limited to LCSMs alone, however, and we identify the properties that allow for such inference speedups in the hope of guiding the design of future architectures. In the particular case of LCSMs (including Hyena), the building block of the architecture is the convolution of the input sequence with a sequence-length-long, potentially underparameterized, filter. If we let L be the sequence length (e.g.
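To make the LCSM building block concrete, the sketch below (a minimal illustration, not the authors' implementation) shows a causal convolution of a length-L input with a length-L filter: computed directly it costs O(L^2), while the standard FFT identity brings it to O(L log L). The function name and the single-channel setting are illustrative assumptions.

```python
# Minimal sketch of the LCSM building block: y[t] = sum_{s<=t} h[t-s] * u[s],
# where the filter h is as long as the input u. Computed via FFT in O(L log L).
import numpy as np

def causal_long_convolution(u: np.ndarray, h: np.ndarray) -> np.ndarray:
    """Causal convolution of input u with a filter h of the same length."""
    L = len(u)
    n = 2 * L  # zero-pad so the circular convolution does not wrap around
    y = np.fft.irfft(np.fft.rfft(u, n) * np.fft.rfft(h, n), n)
    return y[:L]  # keep only the causal (first L) outputs

# Usage: compare against the direct O(L^2) computation on a toy example.
rng = np.random.default_rng(0)
L = 8
u, h = rng.standard_normal(L), rng.standard_normal(L)
y_fft = causal_long_convolution(u, h)
y_naive = np.array([sum(h[t - s] * u[s] for s in range(t + 1)) for t in range(L)])
assert np.allclose(y_fft, y_naive)
```

Note that this FFT trick addresses the training-time (full-sequence) cost; the paper's focus is the autoregressive inference setting, where recomputing such convolutions token by token is what leads to quadratic overall cost.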