FlashFormer: Whole-Model Kernels for Efficient Low-Batch Inference

Aniruddha Nrusimha, William Brandon, Mayank Mishra, Yikang Shen, Rameswar Panda, Jonathan Ragan-Kelley, Yoon Kim

arXiv.org Artificial Intelligence 

The size and compute characteristics of modern large language models have led to increased interest in specialized kernels tailored to particular training and inference workloads. Existing kernels primarily optimize for compute utilization, targeting large-batch training and inference. However, low-batch inference, where memory bandwidth and kernel launch overheads are significant factors, remains important for many applications, such as edge deployment and latency-sensitive workloads. This paper describes FlashFormer, which fuses the entire transformer forward pass into a single kernel to accelerate low-batch inference of large language models. Across various model sizes and quantization settings, FlashFormer achieves nontrivial speedups compared to existing inference kernels.
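To make the low-batch argument concrete, here is a back-of-the-envelope sketch of the two costs named above: memory bandwidth and kernel launch overhead. It is not taken from the paper. The function names, the 2-byte fp16 and 0.5-byte 4-bit weight sizes, and every hardware number (bandwidth, per-launch overhead, launch counts) are illustrative assumptions, not FlashFormer measurements.

```python
"""Back-of-the-envelope cost model for low-batch decoding.

All hardware numbers below are illustrative placeholders (assumptions),
not measurements from the FlashFormer paper.
"""


def linear_arithmetic_intensity(batch: int, d_in: int, d_out: int,
                                bytes_per_weight: float = 2.0,
                                bytes_per_act: float = 2.0) -> float:
    """FLOPs per byte moved for y = x @ W, x: (batch, d_in), W: (d_in, d_out)."""
    flops = 2 * batch * d_in * d_out                    # one multiply-add (2 FLOPs) per weight per row
    weight_bytes = d_in * d_out * bytes_per_weight      # W is read once regardless of batch
    act_bytes = batch * (d_in + d_out) * bytes_per_act  # read x, write y
    return flops / (weight_bytes + act_bytes)


def decode_step_latency_us(weight_bytes: float, bandwidth_gb_s: float,
                           num_kernel_launches: int,
                           launch_overhead_us: float) -> float:
    """Crude latency model: fixed per-launch overhead plus time to stream weights."""
    memory_us = weight_bytes / (bandwidth_gb_s * 1e3)   # bytes / (GB/s) -> microseconds
    return num_kernel_launches * launch_overhead_us + memory_us


if __name__ == "__main__":
    # Batch 1 in fp16: roughly 1 FLOP per byte, so the layer is bandwidth-bound.
    print(linear_arithmetic_intensity(1, 4096, 4096))
    # A large batch amortizes the single weight read across many rows.
    print(linear_arithmetic_intensity(256, 4096, 4096))
    # Hypothetical 4-bit weights (0.5 bytes each) cut the bytes moved at batch 1.
    print(linear_arithmetic_intensity(1, 4096, 4096, bytes_per_weight=0.5))

    # Hypothetical setup: 2 GB of weights, 1000 GB/s, 5 us per kernel launch.
    many = decode_step_latency_us(2e9, 1000.0, num_kernel_launches=200,
                                  launch_overhead_us=5.0)
    fused = decode_step_latency_us(2e9, 1000.0, num_kernel_launches=1,
                                   launch_overhead_us=5.0)
    print(many, fused)
```

Under these assumed numbers, batch-1 matrix-vector products do about one FLOP per byte of weights read, and a forward pass split into many small kernels pays a fixed launch cost that is a visible fraction of per-token latency. Collapsing the forward pass into a single kernel removes most of that fixed term, which matches the motivation the abstract gives for FlashFormer.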