Enabling Approximate Joint Sampling in Diffusion LMs

Parikshit Bansal, Sujay Sanghavi

arXiv.org Artificial Intelligence 

In autoregressive language models, each token is sampled by conditioning on all the past tokens; the overall string is thus sampled from the correct underlying joint distribution represented by the model. In contrast, masked diffusion language models generate text by unmasking tokens out of order and potentially in parallel. Generating a string sampled from the correct underlying joint distribution would (again) require unmasking exactly one token in every full-model forward pass. The more tokens unmasked in parallel, the further the string is from the true joint; this can be seen in the resulting drop in accuracy (but increase in speed). In this paper we devise a way to approximately sample multiple tokens from the joint distribution in a single full-model forward pass; we do so by developing a new lightweight single-layer "sampler" on top of an existing large diffusion LM. One forward pass of the full model can now be followed by multiple forward passes of only this sampler layer, yielding multiple unmasked tokens. Our sampler is trained to mimic exact joint sampling from the (frozen) full model. We show the effectiveness of our approximate joint sampling for both pretrained-only (Dream-7B-Base) and instruction-tuned (Dream-7B-Instruct) models on language modeling and on math & coding tasks. When four tokens are unmasked for each full-model denoising step, our sampling algorithm achieves a MAUVE score of 0.87 (vs. a marginal-sampling baseline of 0.31) with respect to the true joint distribution.

Masked diffusion language models (Sahoo et al., 2024; Austin et al., 2021; Lou et al., 2023) generate text strings by starting from an all-masked sequence of tokens and then iteratively replacing the masked tokens with tokens from the vocabulary, with each "denoising" forward pass unmasking one or a few tokens. In contrast to autoregressive models, which generate tokens left to right, one token per forward pass, masked diffusion models can unmask tokens in any order and potentially several in parallel. The more tokens unmasked in parallel after a single denoising forward pass, the faster and cheaper the overall generation (Sahoo et al., 2024); a toy sketch of the two decoding regimes follows below.
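The following is a minimal, hedged sketch of the two decoding regimes described above: plain parallel unmasking, which samples several positions independently from one full-model pass's marginals, and the paper's idea of following a single full-model pass with several passes of a small sampler layer so that each newly unmasked token can condition on the tokens unmasked just before it. The module names (ToyDiffusionLM, SamplerLayer), interfaces, shapes, and the choice of which position to unmask next are all illustrative assumptions, not the actual Dream-7B code or the authors' implementation.

```python
# Sketch only: contrasts marginal parallel unmasking with a lightweight-sampler
# variant that approximates joint sampling. Not the authors' released code.
import torch
import torch.nn as nn

VOCAB, HIDDEN, MASK_ID = 1000, 64, 0


class ToyDiffusionLM(nn.Module):
    """Stand-in for a large masked diffusion LM (e.g. Dream-7B); assumed interface."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, HIDDEN)
        self.block = nn.TransformerEncoderLayer(HIDDEN, nhead=4, batch_first=True)
        self.lm_head = nn.Linear(HIDDEN, VOCAB)

    def forward(self, ids):                       # ids: (B, L) token ids, MASK_ID = masked
        h = self.block(self.embed(ids))           # the expensive full-model pass
        return h, self.lm_head(h)                 # hidden states + per-position logits


class SamplerLayer(nn.Module):
    """Lightweight single-layer sampler reading the frozen model's hidden states
    plus the tokens unmasked so far in this step (assumed interface)."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, HIDDEN)
        self.block = nn.TransformerEncoderLayer(HIDDEN, nhead=4, batch_first=True)
        self.head = nn.Linear(HIDDEN, VOCAB)

    def forward(self, hidden, ids):
        h = self.block(hidden + self.embed(ids))  # cheap: one layer, rerun per token
        return self.head(h)


def unmask_parallel_marginal(model, ids, k):
    """Baseline: unmask k tokens independently from one pass's marginals."""
    _, logits = model(ids)
    probs = torch.softmax(logits, dim=-1)
    probs[..., MASK_ID] = 0.0                     # never re-emit the mask token
    masked = (ids == MASK_ID).nonzero(as_tuple=False)
    for b, pos in masked[:k]:
        ids[b, pos] = torch.multinomial(probs[b, pos], 1)
    return ids


def unmask_parallel_joint(model, sampler, ids, k):
    """Sketch of the paper's idea: one full-model pass, then up to k cheap
    sampler passes, each conditioning on tokens already unmasked in this step."""
    hidden, logits = model(ids)                   # single expensive pass
    for _ in range(k):
        masked = (ids == MASK_ID).nonzero(as_tuple=False)
        if len(masked) == 0:
            break
        b, pos = masked[0]                        # position choice is illustrative
        probs = torch.softmax(logits[b, pos], dim=-1)
        probs[MASK_ID] = 0.0
        ids[b, pos] = torch.multinomial(probs, 1)
        logits = sampler(hidden, ids)             # refresh remaining positions cheaply
    return ids


if __name__ == "__main__":
    model, sampler = ToyDiffusionLM(), SamplerLayer()
    seq = torch.full((1, 16), MASK_ID, dtype=torch.long)   # all-masked start
    with torch.no_grad():
        while (seq == MASK_ID).any():
            seq = unmask_parallel_joint(model, sampler, seq, k=4)
    print(seq)
```

In this sketch, swapping `unmask_parallel_joint` for `unmask_parallel_marginal` recovers the marginal baseline; the key contrast is that the joint variant pays one full-model pass per denoising step plus k single-layer passes, rather than one full-model pass per unmasked token.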