r-spmm kernel
SPLAT: A framework for optimised GPU code-generation for SParse reguLar ATtention
Gupta, Ahan, Yuan, Yueming, Jain, Devansh, Ge, Yuhao, Aponte, David, Zhou, Yanqi, Mendis, Charith
Multi-head-self-attention (MHSA) mechanisms achieve state-of-the-art (SOTA) performance across natural language processing and vision tasks. However, their quadratic dependence on sequence lengths has bottlenecked inference speeds. To circumvent this bottleneck, researchers have proposed various sparse-MHSA models, where a subset of full attention is computed. Despite their promise, current sparse libraries and compilers do not support high-performance implementations for diverse sparse-MHSA patterns due to the underlying sparse formats they operate on. These formats, which are typically designed for high-performance & scientific computing applications, are either curated for extreme amounts of random sparsity (<1% non-zero values), or specific sparsity patterns. However, the sparsity patterns in sparse-MHSA are moderately sparse (10-50% non-zero values) and varied, resulting in existing sparse-formats trading off generality for performance. We bridge this gap, achieving both generality and performance, by proposing a novel sparse format: affine-compressed-sparse-row (ACSR) and supporting code-generation scheme, SPLAT, that generates high-performance implementations for diverse sparse-MHSA patterns on GPUs. Core to our proposed format and code generation algorithm is the observation that common sparse-MHSA patterns have uniquely regular geometric properties. These properties, which can be analyzed just-in-time, expose novel optimizations and tiling strategies that SPLAT exploits to generate high-performance implementations for diverse patterns. To demonstrate SPLAT's efficacy, we use it to generate code for various sparse-MHSA models, achieving geomean speedups of 2.05x and 4.05x over hand-written kernels written in triton and TVM respectively on A100 GPUs. Moreover, its interfaces are intuitive and easy to use with existing implementations of MHSA in JAX.
- North America > United States > California > San Francisco County > San Francisco (0.14)
- North America > United States > Illinois > Champaign County > Urbana (0.04)
- North America > United States > New York > New York County > New York City (0.04)
- (11 more...)
- Information Technology > Artificial Intelligence > Machine Learning > Neural Networks > Deep Learning (1.00)
- Information Technology > Artificial Intelligence > Representation & Reasoning > Automatic Programming (0.84)
- Information Technology > Artificial Intelligence > Natural Language > Large Language Model (0.67)