Chen, Liangfu
Inference Optimization of Foundation Models on AI Accelerators
Park, Youngsuk, Budhathoki, Kailash, Chen, Liangfu, Kübler, Jonas, Huang, Jiaji, Kleindessner, Matthäus, Huan, Jun, Cevher, Volkan, Wang, Yida, Karypis, George
Powerful foundation models, including large language models (LLMs), with Transformer architectures have ushered in a new era of Generative AI across various industries. Industry and the research community have witnessed a large number of new applications based on these foundation models. Such applications include question answering, customer service, image and video generation, and code completion, among others. However, as the number of model parameters reaches hundreds of billions, deployment incurs prohibitive inference costs and high latency in real-world scenarios. As a result, the demand for cost-effective and fast inference using AI accelerators is higher than ever. To this end, our tutorial offers a comprehensive discussion of complementary inference optimization techniques using AI accelerators. Beginning with an overview of basic Transformer architectures and deep learning system frameworks, we dive deep into system optimization techniques for fast and memory-efficient attention computations and discuss how they can be implemented efficiently on AI accelerators. Next, we describe architectural elements that are key for fast Transformer inference. Finally, we examine various model compression and fast decoding strategies in the same context.
Bifurcated Attention: Accelerating Massively Parallel Decoding with Shared Prefixes in LLMs
Athiwaratkun, Ben, Gonugondla, Sujan Kumar, Gouda, Sanjay Krishna, Qian, Haifeng, Ding, Hantian, Sun, Qing, Wang, Jun, Guo, Jiacheng, Chen, Liangfu, Bhatia, Parminder, Nallapati, Ramesh, Sengupta, Sudipta, Xiang, Bing
This study introduces bifurcated attention, a method designed to enhance language model inference in shared-context batch decoding scenarios. Our approach addresses the challenge of redundant memory IO costs, a critical factor contributing to latency at large batch sizes and extended context lengths. Bifurcated attention achieves this by strategically dividing the attention mechanism during incremental decoding into two separate GEMM operations: one focusing on the KV cache from prefill, and another on the decoding process itself. While maintaining the computational load (FLOPs) of standard attention mechanisms, bifurcated attention ensures precise computation with significantly reduced memory IO. Our empirical results show over 2.1$\times$ speedup when sampling 16 output sequences and more than 6.2$\times$ speedup when sampling 32 sequences at context lengths exceeding 8k tokens on a 7B model that uses multi-head attention. The efficiency gains from bifurcated attention translate into lower latency, making it particularly suitable for real-time applications. For instance, it enables massively parallel answer generation without substantially increasing latency, thus enhancing performance when integrated with post-processing techniques such as re-ranking.
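The split into two GEMMs described in the abstract can be illustrated with a minimal single-head NumPy sketch. This is not the authors' implementation: the function names, tensor shapes, and the single-head, single-decoding-step setup are all assumptions made for illustration. The shared prefix keys and values (`k_ctx`, `v_ctx`) are stored once for the whole batch, while the per-sequence decoded keys and values (`k_dec`, `v_dec`) carry a batch dimension. The softmax is taken jointly over both parts, so the result should match standard attention over the concatenated cache.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def standard_attention(q, k_ctx, v_ctx, k_dec, v_dec):
    # Baseline: replicate the shared prefix for every sequence in the batch,
    # then attend over the concatenated (prefix + decoded) cache.
    b = q.shape[0]
    k = np.concatenate([np.broadcast_to(k_ctx, (b,) + k_ctx.shape), k_dec], axis=1)
    v = np.concatenate([np.broadcast_to(v_ctx, (b,) + v_ctx.shape), v_dec], axis=1)
    scale = 1.0 / np.sqrt(q.shape[-1])
    w = softmax(np.einsum("bd,bmd->bm", q, k) * scale)
    return np.einsum("bm,bmd->bd", w, v)

def bifurcated_attention(q, k_ctx, v_ctx, k_dec, v_dec):
    # Assumed shapes: q (b, d); k_ctx, v_ctx (m_c, d); k_dec, v_dec (b, m_d, d).
    scale = 1.0 / np.sqrt(q.shape[-1])
    # First product: all queries against the single shared prefix cache.
    logits_ctx = (q @ k_ctx.T) * scale
    # Second product: each query against its own decoded tokens.
    logits_dec = np.einsum("bd,bmd->bm", q, k_dec) * scale
    # Joint softmax over both parts keeps the result exact.
    w = softmax(np.concatenate([logits_ctx, logits_dec], axis=1))
    m_c = k_ctx.shape[0]
    w_ctx, w_dec = w[:, :m_c], w[:, m_c:]
    # Weighted sums over values, again split into shared and per-sequence parts.
    return w_ctx @ v_ctx + np.einsum("bm,bmd->bd", w_dec, v_dec)

# Hypothetical sizes: 4 sampled sequences, 32 prefix tokens, 5 decoded tokens.
rng = np.random.default_rng(0)
b, m_c, m_d, d = 4, 32, 5, 8
q = rng.standard_normal((b, d))
k_ctx, v_ctx = rng.standard_normal((m_c, d)), rng.standard_normal((m_c, d))
k_dec, v_dec = rng.standard_normal((b, m_d, d)), rng.standard_normal((b, m_d, d))
print(np.allclose(standard_attention(q, k_ctx, v_ctx, k_dec, v_dec),
                  bifurcated_attention(q, k_ctx, v_ctx, k_dec, v_dec)))
```

In this sketch the prefix cache is read once rather than once per sequence, which is the source of the memory IO reduction the abstract refers to; the arithmetic performed is the same in both functions.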