How Text Diffusion Works


When Google released their latest Gemini family at I/O, they also announced Gemini Diffusion to go along with them. Aptly named[1], it's based on a new diffusion architecture and claims reasoning, math, and code metrics competitive with their other state-of-the-art models.

These metrics certainly surprised me. Diffusion models have long been thought of as something of a novelty in text generation, since their real-world performance always lagged behind other architectures. It's also less theoretically clear how you can produce reasonable output by sampling tokens simultaneously at random positions in a paragraph.

Standard generative transformer architectures (GPT, Claude) are autoregressive in both training and inference. They proceed word by word, left to right, appending the next word (for example, the most likely one) at every step. By definition, the next token is conditioned on everything that came before it. You write a paragraph word by word - one at a time - and so should a language model.

But what if it didn't have to be one at a time?

Autoregressive Generation

In autoregressive generation, each new token is predicted based on all the tokens that came before it. Think of it like writing a sentence where you can only see the words you've already written - you pick the next word based on the context so far, then move to the next position.

This creates a dependency chain: token 3 depends on tokens 1 and 2, token 4 depends on tokens 1, 2, and 3, and so on. Generation is sequential by nature, so you can't parallelize it across positions: producing N tokens takes N forward passes. As models grow in parameter count, each forward pass requires more FLOPs. The total time to generate a full response therefore grows with both model size and response length.
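The loop structure can be sketched in a few lines. This is a minimal illustration that assumes a hypothetical `toy_next_token` function in place of a real model's forward pass. The placeholder samples at random, because the point is only the shape of the loop: one forward pass per generated token, and each pass waits on the previous one.

```python
import random

# Hypothetical toy vocabulary; a real model has tens of thousands of tokens.
VOCAB = ["the", "cat", "sat", "on", "mat", "."]


def toy_next_token(prefix: list[str], rng: random.Random) -> str:
    """Hypothetical stand-in for one forward pass of a causal language model.

    A real model would compute a distribution conditioned on `prefix`; this
    placeholder ignores it and samples uniformly so the loop structure stays visible.
    """
    return rng.choice(VOCAB)


def generate_autoregressive(
    prompt: list[str], max_new_tokens: int, seed: int = 0
) -> tuple[list[str], int]:
    rng = random.Random(seed)
    tokens = list(prompt)
    forward_passes = 0
    for _ in range(max_new_tokens):
        # Each step depends on the token appended by the previous step,
        # so these passes cannot run in parallel.
        tokens.append(toy_next_token(tokens, rng))
        forward_passes += 1
    return tokens, forward_passes


tokens, passes = generate_autoregressive(["the"], max_new_tokens=5)
print(passes, tokens)
```

Five new tokens cost five sequential forward passes. No choice of hardware removes that dependency, because each call needs the previous call's output.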

Autoregressive training is similarly causal - the model can only attend to previous tokens, never future ones. It is pretrained to output the next token in the sequence. A causal attention mask enforces that constraint, and no [MASK] tokens ever appear in the input.
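The sketch below makes the "previous tokens only" rule concrete. It is written in plain Python rather than a real framework. It builds the lower-triangular attention mask and lists the (context, next token) pairs that a single training sequence provides. In an actual transformer, all of these pairs are trained in one parallel forward pass, because the mask already stops each position from seeing its future.

```python
def causal_mask(n: int) -> list[list[bool]]:
    """mask[i][j] is True when position i is allowed to attend to position j."""
    return [[j <= i for j in range(n)] for i in range(n)]


def next_token_pairs(tokens: list[str]) -> list[tuple[list[str], str]]:
    """The (context, target) pairs that one training sequence provides."""
    return [(tokens[: i + 1], tokens[i + 1]) for i in range(len(tokens) - 1)]


for row in causal_mask(4):
    print(" ".join("x" if allowed else "." for allowed in row))

for context, target in next_token_pairs(["the", "cat", "sat", "down"]):
    print(context, "->", target)
```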


Text Diffusion

You've probably seen image diffusion by now via Midjourney and other image generators. Image diffusion works by starting with pure noise[2] and gradually removing that noise over multiple steps to render a coherent image. At each step, the model predicts what the "less noisy" version should look like, converging bit by bit on the final output. The key insight is that you can work backwards from noise to signal because training teaches the model exactly this denoising. Real images are progressively noised, and the model learns to reverse each step.

Text diffusion follows a similar principle, but instead of denoising pixels, we're "denoising" tokens. We start with a sentence full of [MASK] tokens (the text equivalent of noise) and progressively fill them in over multiple steps.
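The text equivalent of "adding noise" can be sketched as follows. This assumes the independent per-token masking described in masked diffusion papers such as LLaDA, where a noise level `t` between 0 and 1 is the probability that each token is hidden. The `mask_tokens` helper is illustrative, not taken from any library.

```python
import random

MASK = "[MASK]"


def mask_tokens(tokens: list[str], t: float, rng: random.Random) -> list[str]:
    """Forward 'noising': hide each token independently with probability t.

    t = 0.0 leaves the text untouched; t = 1.0 masks every token (pure 'noise').
    """
    return [MASK if rng.random() < t else tok for tok in tokens]


rng = random.Random(0)
sentence = "the cat sat on the mat".split()
for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"t={t:.2f}:", " ".join(mask_tokens(sentence, t, rng)))
```

Generation runs this process in reverse. It starts at t = 1 with every token masked and works back toward t = 0.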

Google is still keeping its specific training procedure and architecture under wraps. But we can look to other contemporary papers for a pretty good idea of what they're probably doing. LLaDA was one of the first papers to show that this approach was technically feasible at scale and could match the accuracy of similarly sized autoregressive models.

They speculate that the underlying power of language models comes not from their autoregressive nature, but from their being very good models of the underlying language distribution[3]. If that modeling is robust, you should be able to predict a word at some arbitrary position in the paragraph and then fill in the remaining blanks.

Diffusion Pretraining

Since we no longer want to get only one word out at a time, we train a transformer that takes the entire sequence at once (with masks) and performs a sequence-to-sequence mapping task. Attention is bidirectional, so every position can see every other position. The output layer predicts tokens for all positions, but the loss is only calculated on the masked positions. This is almost identical to how BERT was pretrained. The difference is that BERT, although a transformer, wasn't a generative architecture. It only needed to build an internal representation of a sentence, not generate an arbitrary length of text as output.
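The training objective is sketched below under stated assumptions. The noise level `t` is sampled uniformly and each token is masked with probability `t`. Cross-entropy is then summed over the masked positions only. The 1/t weighting follows the LLaDA paper, and dividing by sequence length is a normalization choice made for this sketch. `MASK_ID` and `uniform_model` are hypothetical placeholders for a real tokenizer's mask id and a real bidirectional transformer.

```python
import math
import random
from typing import Callable

MASK_ID = -1  # hypothetical sentinel id standing in for the [MASK] token


def cross_entropy(logits: list[float], target: int) -> float:
    """-log softmax(logits)[target], computed with the log-sum-exp trick."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def masked_diffusion_loss(
    token_ids: list[int],
    model: Callable[[list[int]], list[list[float]]],
    rng: random.Random,
) -> float:
    # 1. Sample a noise level and mask each token independently with probability t.
    t = rng.uniform(1e-3, 1.0)  # lower bound keeps the 1/t weight finite
    is_masked = [rng.random() < t for _ in token_ids]
    noisy = [MASK_ID if m else tok for tok, m in zip(token_ids, is_masked)]
    # 2. One bidirectional forward pass produces logits at every position.
    logits = model(noisy)
    # 3. Cross-entropy only at masked positions; visible positions add nothing.
    total = sum(
        cross_entropy(logits[i], token_ids[i])
        for i, m in enumerate(is_masked)
        if m
    )
    # 4. Weight by 1/t and normalize by sequence length.
    return total / (t * len(token_ids))


def uniform_model(ids: list[int], vocab_size: int = 8) -> list[list[float]]:
    """Hypothetical placeholder model: equal logits at every position."""
    return [[0.0] * vocab_size for _ in ids]


loss = masked_diffusion_loss([3, 1, 4, 1, 5, 2, 6], uniform_model, random.Random(0))
print(f"loss = {loss:.3f}")
```

Apart from the 1/t weight, the main departure from BERT is the masking ratio. BERT uses a fixed ratio, while here the ratio is resampled for every sequence, so the model sees every noise level from "almost intact" to "fully masked".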

Supervised finetuning[4] is handled the same way, except that the masking procedure only masks response tokens. We assume we'll have access to the full prompt at inference time; it's only the response that we need to predict from scratch.
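The only change for finetuning is which positions are eligible for masking. The helper below is a hypothetical sketch, not a library function. Prompt tokens are never masked, and the returned flags mark the masked response positions, which are the only places the loss would be computed.

```python
import random

MASK = "[MASK]"


def mask_for_sft(
    prompt: list[str], response: list[str], t: float, rng: random.Random
) -> tuple[list[str], list[bool]]:
    """Mask response tokens with probability t; the prompt always stays visible.

    Returns the noisy sequence and a flag per position marking where the loss applies.
    """
    response_masked = [rng.random() < t for _ in response]
    noisy = prompt + [MASK if m else tok for tok, m in zip(response, response_masked)]
    loss_positions = [False] * len(prompt) + response_masked
    return noisy, loss_positions


prompt = "User: name a pet . Assistant:".split()
response = "a small grey cat".split()
noisy, loss_positions = mask_for_sft(prompt, response, t=0.5, rng=random.Random(0))
print(" ".join(noisy))
```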


Diffusion Inference

So far, this looks like a normal sequence-to-sequence task. But other research shows fairly comprehensively that predicting an entire response in a single shot doesn't work that well. That's where diffusion comes in.

Diffusion runs in a loop. You set the number of timesteps, your inference budget, up front. I imagine you could also make this dynamic if the diffusion process isn't converging after some initial number of steps.

You start with everything masked, then run a forward pass with the model. At this point you'll have predictions for every output token - but each position was predicted independently, without seeing what the model chose for the other masked positions. We need to assume that some of these predictions are correct, temporarily at least, so we can feed them forward to the next pass. We therefore keep some of them and remask the rest according to some heuristic: random, lowest confidence, etc. On the next pass, the model sees the kept tokens as part of its input and fills in the remaining blanks.

Whichever heuristic you use, you should remask more aggressively towards the beginning. The model is less sure of its output there, so remasking ~90% of tokens (say, the lowest-confidence ones) gives it more room to correct itself. Towards the end you'll want to hone the final text by remasking maybe ~10%[5]. This mirrors classic continuous diffusion processes (like in image generation), where you denoise more aggressively at the beginning before tuning the final pixel values.
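A simple way to get "aggressive early, gentle late" is a linear schedule, which is roughly what the LLaDA paper describes. After step k of T, a fraction (T - k)/T of the response is left masked. The `remask_schedule` helper is illustrative. With 10 steps, it reproduces the ~90% and ~10% figures above.

```python
def remask_schedule(seq_len: int, num_steps: int) -> list[int]:
    """Positions still masked after each step under a linear schedule.

    After step k of num_steps, a fraction (num_steps - k) / num_steps of the
    response is left (or put back) as [MASK].
    """
    return [seq_len * (num_steps - k) // num_steps for k in range(1, num_steps + 1)]


print(remask_schedule(seq_len=100, num_steps=10))
# [90, 80, 70, 60, 50, 40, 30, 20, 10, 0]
```

The first step keeps only its 10 most confident guesses and remasks 90 of 100 positions. The second-to-last step remasks 10, and the final step commits everything. Other shapes (cosine, or a fixed number of tokens per step) are equally possible; linear is simply the easiest to reason about.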

In sum:

  1. Start fully masked: Begin with every token as [MASK]
  2. Predict all positions: Run the model to get predictions for every masked position
  3. Selective unmasking: Only unmask a fraction of the predictions (the "best" ones according to some strategy)
  4. Remask the rest: Mask the other positions so the model is forced to predict them from scratch, now biased towards the newly unmasked content
  5. Repeat: Continue this process over multiple steps
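Putting the five steps together gives the minimal sampling loop below. It uses low-confidence remasking and the linear schedule. The `Model` type and `make_toy_model` are hypothetical stand-ins: the toy model "knows" the answer and returns a random confidence per position, so the only interesting behavior is the order of unmasking. In a real system, each `model(seq)` call would be one bidirectional transformer forward pass. Committing the top-confidence predictions and leaving the rest as [MASK] is equivalent to predicting everything and then remasking the low-confidence ones.

```python
import random
from typing import Callable

MASK = "[MASK]"

# Hypothetical model interface: for each position of the (partly masked)
# sequence, return a guessed token and the model's confidence in it.
Model = Callable[[list[str]], list[tuple[str, float]]]


def diffusion_generate(
    prompt: list[str], gen_len: int, num_steps: int, model: Model, trace: bool = False
) -> list[str]:
    # 1. Start fully masked: the response region is all [MASK].
    seq = list(prompt) + [MASK] * gen_len
    for k in range(1, num_steps + 1):
        masked = [i for i in range(len(prompt), len(seq)) if seq[i] == MASK]
        if not masked:
            break
        # 2. Predict all positions in a single forward pass.
        preds = model(seq)
        # Linear schedule: how many positions should still be masked after this step.
        still_masked = gen_len * (num_steps - k) // num_steps
        n_unmask = len(masked) - still_masked
        # 3. Selective unmasking: commit the most confident predictions.
        by_confidence = sorted(masked, key=lambda i: preds[i][1], reverse=True)
        for i in by_confidence[:n_unmask]:
            seq[i] = preds[i][0]
        # 4. Everything else stays [MASK] and is re-predicted next step,
        #    now conditioned on the newly committed tokens.
        if trace:
            print(f"step {k}:", " ".join(seq[len(prompt):]))
    return seq[len(prompt):]


def make_toy_model(full_text: list[str], seed: int = 0) -> Model:
    """Hypothetical toy model that always guesses the right token but with random
    confidence, so the only interesting behavior is the unmasking order."""
    rng = random.Random(seed)

    def model(seq: list[str]) -> list[tuple[str, float]]:
        return [(full_text[i], rng.random()) for i in range(len(seq))]

    return model


prompt = "User: name a pet . Assistant:".split()
answer = "a small grey cat".split()
out = diffusion_generate(prompt, len(answer), num_steps=4,
                         model=make_toy_model(prompt + answer), trace=True)
```

With `num_steps=4` and four response tokens, this commits one token per step. Lowering `num_steps` commits more tokens per forward pass, which is where the speedup comes from. It is also where quality can suffer with a real model, since more tokens are chosen without seeing each other.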

Assuming we unmask more than one token per step, the total number of diffusion timesteps is smaller than the number of autoregressive steps (one per output token). In that case we should save on overall prediction time compared to autoregressive models.
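The arithmetic is easy to spell out. The sketch below simply counts forward passes for a hypothetical 256-token response at a few tokens-per-step settings. It says nothing about the cost of each pass.

```python
import math


def forward_passes(num_tokens: int, tokens_per_step: int) -> tuple[int, int]:
    """(autoregressive passes, diffusion passes) for a response of num_tokens tokens."""
    autoregressive = num_tokens  # one pass per generated token
    diffusion = math.ceil(num_tokens / tokens_per_step)  # one pass per denoising step
    return autoregressive, diffusion


for k in (1, 2, 4, 8):
    ar, diff = forward_passes(256, k)
    print(f"{k} tokens/step: {ar} AR passes vs {diff} diffusion passes ({ar / diff:.0f}x fewer)")
```

Counting passes overstates the gain somewhat. A diffusion pass attends over the whole sequence bidirectionally and, in the basic setup, can't directly reuse a KV cache the way autoregressive decoding does. Each pass is therefore more expensive, so the wall-clock saving is smaller than these ratios.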

Conclusion

In their release post, Google mentions that they believe the majority of frontier models will be diffusion-based in the next few years. I think that's certainly possible. They're definitely faster - and as LLMs are increasingly used for looped workflows like agents and code generation, I can see any fundamental speed gains becoming sorely needed.

Companies can't deploy GPUs fast enough, and my sense is that the foundation model labs run their chips at 90%+ utilization most of the time. Faster inference doesn't just mean faster responses for end users. It also means less time on chip per request, higher throughput, and therefore fewer chips required.

What I don't yet have an intuitive feel for is whether diffusion can match the results of autoregressive prediction once the research reaches a steady state. Language modeling becomes a less obvious task when the sequence isn't produced linearly from start to finish. For chain-of-thought reasoning it seems especially unintuitive, since later steps of logic should attend to earlier ones.

We'll see how this one diffuses out.
