Learning How Hard to Think: Input-Adaptive Allocation of LM Computation

Mehul Damani, Idan Shenfeld, Andi Peng, Andreea Bobu, Jacob Andreas

arXiv.org Artificial Intelligence 

Computationally intensive decoding procedures--including search, reranking, and self-critique--can improve the quality of language model (LM) outputs in problems spanning code generation, numerical reasoning, and dialog. Existing work typically applies the same decoding procedure for every input to an LM. But not all inputs require the same amount of computation to process. Can we allocate decoding computation adaptively, using more resources to answer questions whose answers will be harder to compute? We present an approach that predicts the distribution of rewards given an input and computation budget, then allocates additional computation to inputs for which it is predicted to be most useful. We apply this approach in two decoding procedures: first, an adaptive best-of-k procedure that dynamically selects the number of samples to generate as input to a reranker; second, a routing procedure that dynamically responds to a query using a decoding procedure that is expensive but accurate, or one that is cheaper but less capable. Across a suite of programming, mathematics, and dialog tasks, we show that accurate computation-allocation procedures can be learned, and reduce computation by up to 50% at no cost to response quality, or improve quality by up to 10% at a fixed computational budget.

Importantly, computationally intensive problem domains may exhibit considerable variation in the difficulty of individual problem instances: not all problems are equally hard to solve. Maximally efficient use of computational resources thus requires identifying, a priori, the inputs for which additional computation will improve outputs.

Given a set of input queries to a language model, we train a lightweight model to estimate the difficulty of these queries (more precisely, a model that estimates how much each query would benefit from additional computation), and allocate computation to the queries for which it would be most beneficial.
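To make the adaptive best-of-k idea concrete, here is a minimal sketch of difficulty-aware sample allocation. It assumes a hypothetical predictor `predict_reward(query, k)` estimating the expected reward of best-of-k decoding for a query (standing in for the learned lightweight model described above); all names are illustrative, not the authors' implementation. The allocator greedily gives each additional sample to the query with the largest predicted marginal gain.

```python
import heapq

def allocate_samples(queries, predict_reward, total_budget, k_max=16):
    """Greedy marginal-gain allocation of a fixed sample budget.

    predict_reward(q, k) is an assumed interface: the predicted
    expected reward of best-of-k decoding for query q.
    Returns a dict mapping each query to its sample count k.
    """
    counts = {q: 1 for q in queries}          # every query gets one sample
    budget = total_budget - len(queries)

    # Max-heap (via negated keys) on the predicted gain of the *next* sample.
    heap = [(-(predict_reward(q, 2) - predict_reward(q, 1)), q) for q in queries]
    heapq.heapify(heap)

    while budget > 0 and heap:
        neg_gain, q = heapq.heappop(heap)
        if counts[q] >= k_max or -neg_gain <= 0:
            continue                           # no useful gain left for q
        counts[q] += 1
        budget -= 1
        # Re-insert q with the gain of its following sample.
        nxt = predict_reward(q, counts[q] + 1) - predict_reward(q, counts[q])
        heapq.heappush(heap, (-nxt, q))
    return counts
```

With a toy diminishing-returns predictor such as `reward(q, k) = difficulty[q] * (1 - 0.5 ** k)`, harder queries receive more of the budget while easy ones stay near a single sample, which is the qualitative behavior the abstract describes.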