Neural Attention Search

Neural Information Processing Systems 

We present Neural Attention Search (NAtS), an end-to-end learnable sparse transformer that automatically evaluates the importance of each token in a sequence and determines whether that token can be dropped after several steps. To this end, we design a search space containing three token types: (i) Global Tokens are preserved and can be queried by all subsequent tokens; (ii) Local Tokens survive until the next global token appears; and (iii) Sliding Window Tokens influence the inference of only a fixed-size window of subsequent tokens. As in One-Shot Neural Architecture Search, this token-type information can be learned jointly with the architecture weights via a learnable attention mask. Experiments on both training a new transformer from scratch and fine-tuning existing large language models show that NAtS can efficiently reduce the KV cache size and inference cost while maintaining model performance.
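To make the three token types concrete, the following is a minimal sketch of how a causal attention mask could be derived once each token has been assigned a type. It is an illustration only, not the NAtS implementation: in NAtS the types are learned, whereas here they are assigned by hand. The names `build_mask` and `cache_sizes`, the labels "G", "L", and "S", and the `window` parameter are hypothetical. The sketch also makes two assumptions the abstract does not settle: a local token remains visible to the next global token itself, and a sliding window of size `window` counts the token's own position.

```python
def build_mask(token_types, window):
    """Return mask[i][j] == True when query i may attend to key j.

    token_types: list of "G" (global), "L" (local), "S" (sliding window).
    window: number of positions (including its own) a sliding-window
            token stays visible for. Assumed convention, not from the paper.
    """
    n = len(token_types)

    # next_global[j] = index of the first global token strictly after j,
    # or n if there is none.
    next_global = [n] * n
    upcoming = n
    for j in range(n - 1, -1, -1):
        next_global[j] = upcoming
        if token_types[j] == "G":
            upcoming = j

    mask = [[False] * n for _ in range(n)]
    for j, kind in enumerate(token_types):
        if kind == "G":
            last = n - 1                          # never dropped
        elif kind == "L":
            last = min(next_global[j], n - 1)     # dropped after next global
        elif kind == "S":
            last = min(j + window - 1, n - 1)     # dropped after the window
        else:
            raise ValueError(f"unknown token type: {kind!r}")
        for i in range(j, last + 1):              # causal: only i >= j
            mask[i][j] = True
    return mask


def cache_sizes(mask):
    """Number of keys/values that must be kept in the cache at each step."""
    return [sum(row) for row in mask]


types = ["G", "L", "S", "G", "L", "S"]
mask = build_mask(types, window=2)
print(cache_sizes(mask))  # dense causal attention would need [1, 2, 3, 4, 5, 6]
```

In this toy sequence the local token at position 1 and the sliding-window token at position 2 are both gone by step 4, so the cache shrinks at that step instead of growing by one entry as it would under dense causal attention.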