Do Transformers Actually Perform Bayesian Inference? New Research Says Yes and Shows Exactly How

Published April 21, 2026

A deep dive into "The Bayesian Geometry of Transformer Attention", a paper that finally gives us a rigorous answer to one of the most debated questions in LLM interpretability.

It's become common to describe large language models as "Bayesian learners": models that somehow update beliefs as they process context, narrowing down the possibilities with each new token. But is that actually true, or is it just a flattering metaphor?

A new paper, "The Bayesian Geometry of Transformer Attention" by Agarwal, Dalal, and Misra (Columbia University / Google DeepMind) , takes a rigorous crack at this question. And the answer is striking: small transformers, trained with nothing but standard cross-entropy loss, can implement exact Bayesian inference down to 10⁻³–10⁻⁴ bit accuracy.

More importantly, the paper doesn't stop at "yes, they do it." It opens up the hood and shows exactly how and why, using geometric diagnostics concrete enough to be tested in frontier LLMs.


The Core Problem: You Can't Verify Bayes on Natural Language

Before this work, evidence for Bayesian behavior in transformers was largely behavioral. Models looked Bayesian, but you could never rule out that they were pattern-matching or exploiting memorized templates.

The problem is simple: for natural language, you don't know the ground-truth posterior. There's no oracle to compare against.

The authors' solution is elegant: don't use natural language at all.


Introducing: Bayesian Wind Tunnels

The key methodological contribution is the concept of a Bayesian wind tunnel: a controlled prediction task satisfying three properties:

  1. The analytic posterior is known exactly at every step
  2. The hypothesis space is so large that memorization is computationally infeasible
  3. Correct prediction requires genuine probabilistic inference

This converts the qualitative question ("does it do Bayes?") into a quantitative test: does the model's predictive entropy match the analytic posterior entropy, position by position?

The authors build four wind tunnels:

  • Bijection learning: the model observes input-output pairs from a random bijection (permutation) and must predict unseen mappings. With a vocabulary of 20, there are 20! ≈ 2.4 × 10¹⁸ possible bijections. The ground-truth posterior entropy at step k is simply log₂(20 − k + 1), a clean descending staircase (see the sketch after this list).
  • HMM state tracking: each sequence is generated from a freshly sampled Hidden Markov Model. The model must track the hidden-state posterior using the forward algorithm.
  • Bayesian regression: continuous inference over Gaussian linear weights, with a closed-form predictive posterior.
  • Associative recall: content-based retrieval of cue-target pairs, testing a specific capability the authors call random-access binding.
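
To make the staircase concrete, here is a minimal sketch (ours, not the paper's code) of the analytic posterior entropy for the bijection task. The vocabulary size `V = 20` matches the setup above; the function name is hypothetical:

```python
import numpy as np

V = 20  # vocabulary size used in the paper's bijection task

def bijection_posterior_entropy(k: int, vocab: int = V) -> float:
    """Exact posterior entropy (bits) at query step k.

    After k - 1 mappings of a random bijection have been revealed,
    the k-th queried input has vocab - k + 1 equally likely remaining
    outputs, so the entropy is log2(vocab - k + 1)."""
    return float(np.log2(vocab - k + 1))

# A clean descending staircase: log2(20), log2(19), ..., log2(1) = 0
staircase = [bijection_posterior_entropy(k) for k in range(1, V + 1)]
print([round(h, 3) for h in staircase])
```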

The Main Result: Transformers Track the Posterior with Machine-Precision Accuracy

On the bijection task, a 2.67M-parameter transformer achieves:

MAE = 3 × 10⁻³ bits, smaller than the single-precision floating-point noise in the analytic posterior itself.

On HMM filtering, a 2.68M-parameter transformer achieves MAE = 7.5 × 10⁻⁵ bits at the training length, and it generalizes gracefully to 2.5× longer sequences with no discontinuity at the training boundary: strong evidence that it learned a recursive algorithm, not a finite-horizon memorized computation.
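
The recursion being matched here is the classic forward algorithm. For reference, here is a minimal NumPy sketch of exact HMM filtering; the variable names (`pi`, `T`, `E`) are our notation, not the paper's:

```python
import numpy as np

def forward_filter(obs, pi, T, E):
    """Exact filtering posteriors p(z_t | x_{1:t}) and their entropies (bits).

    obs : sequence of observed symbol indices
    pi  : (S,) initial state distribution
    T   : (S, S) transition matrix, T[i, j] = p(z_{t+1} = j | z_t = i)
    E   : (S, V) emission matrix,   E[i, v] = p(x_t = v | z_t = i)
    """
    belief = pi * E[:, obs[0]]
    belief /= belief.sum()
    beliefs = [belief]
    for x in obs[1:]:
        belief = (belief @ T) * E[:, x]  # transport through dynamics, then condition
        belief = belief / belief.sum()   # renormalize to a posterior
        beliefs.append(belief)
    beliefs = np.array(beliefs)
    entropies = -(beliefs * np.log2(beliefs + 1e-12)).sum(axis=1)
    return beliefs, entropies
```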

A capacity-matched MLP, given the same data, the same parameter count, and the same training? It fails catastrophically: 5,467× worse on the HMM task.

The key metric throughout is entropy MAE:

$$
\text{MAE} = \frac{1}{L} \sum_{k} \left| H_{\text{model}}(k) - H_{\text{Bayes}}(k) \right|
$$

This is a direct, interpretable, bit-level measure of Bayesian correctness, independent of accuracy or perplexity.
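
In practice the metric takes only a few lines to compute from model logits; the helper names below are our own:

```python
import numpy as np

def predictive_entropy(logits: np.ndarray) -> np.ndarray:
    """Entropy (bits) of the model's next-token distribution at each position.

    logits : (L, vocab) next-token logits."""
    p = np.exp(logits - logits.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return -(p * np.log2(p + 1e-12)).sum(axis=-1)

def entropy_mae(h_model: np.ndarray, h_bayes: np.ndarray) -> float:
    """Position-wise mean absolute entropy error, in bits."""
    return float(np.abs(h_model - h_bayes).mean())
```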


A Taxonomy of Inference Primitives

The most intellectually interesting contribution is a decomposition of Bayesian inference into three primitives, each tested by a different wind tunnel:

| Primitive | Description | Task |
| --- | --- | --- |
| Belief accumulation | Integrating evidence into a running posterior | Bijection (hypothesis elimination) |
| Belief transport | Propagating beliefs through stochastic dynamics | HMM filtering |
| Random-access binding | Retrieving stored hypotheses by content | Associative recall |

Different architectures realize different subsets of these primitives. The authors test four: Transformers, Mamba (SSM), LSTMs, and MLPs.

Here's the full comparison table:

| Architecture | Bijection (accumulation) | HMM (transport) | Assoc. Recall (binding) | Primitives realized |
| --- | --- | --- | --- | --- |
| Transformer | 0.007 bits | 0.049 bits | 100% | All 3 ✅ |
| Mamba | 0.010 bits | 0.024 bits | 97.8% (slow) | 2 of 3 🟡 |
| LSTM | 0.009 bits | 0.411 bits | 0.5% (chance) | 1 of 3 ❌ |
| MLP | 1.85 bits | 0.40 bits | | 0 of 3 ❌ |

The taxonomy is clean and explanatory:

  • Transformers realize all three because attention externalizes belief as a geometric, addressable representation rather than compressing it into a fixed-size state.
  • Mamba excels at HMM filtering (even beating the transformer, 0.024 vs. 0.049 bits) because its selective state-space mechanism is well suited to belief transport. But it struggles with binding, requiring 2.5× more training and still reaching only 97.8% on associative recall.
  • LSTMs succeed on bijection because the sufficient statistic is static (just track which outputs have been seen). They fail on HMM because the belief vector must be transported through a transition matrix, something fixed gating can't do under standard training.
  • MLPs fail uniformly. Without sequence structure, there's no hope.

Opening the Hood: How Transformers Actually Do It

The behavioral results are impressive, but the mechanistic analysis is where the paper really shines.

Using ablations, QK geometry, probe dynamics, and training trajectories, the authors identify a three-stage computational mechanism:

Stage 1: Foundational Binding (Layer 0)

Layer 0 attention constructs an orthogonal hypothesis frame. Key vectors for distinct input tokens form a near-orthogonal basis, a coordinate system over which posterior mass can be represented. Off-diagonal cosine similarities drop to ~0.052, vs. ~0.082 for random vectors (a 37% reduction, p < 0.001).
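
This diagnostic is straightforward to reproduce. A hedged sketch, assuming you've extracted a matrix `K` of Layer-0 key vectors, one row per distinct input token (the function name is ours):

```python
import numpy as np

def mean_offdiag_abs_cosine(K: np.ndarray) -> float:
    """Mean |cosine similarity| between key vectors of distinct tokens.

    K : (n_tokens, d_head) key vectors. A near-orthogonal hypothesis
    frame drives this toward zero; the paper reports ~0.052 for trained
    Layer-0 keys vs ~0.082 for random vectors."""
    Kn = K / np.linalg.norm(K, axis=1, keepdims=True)
    cos = Kn @ Kn.T
    off_diagonal = cos[~np.eye(len(K), dtype=bool)]
    return float(np.abs(off_diagonal).mean())
```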

A single "hypothesis-frame head" in Layer 0 is uniquely indispensable. Ablating just this head severely disrupts calibration. Ablating any other head? Much smaller effect. This is the structural bottleneck: you can't do Bayesian inference without first building the frame.

Stage 2: Progressive Elimination (Middle Layers)

As depth increases, queries align more strongly with the feasible subset of keys: the hypotheses that haven't been ruled out yet. Early layers attend broadly; deeper layers concentrate attention almost exclusively on the remaining candidates.

This geometric sharpening directly mirrors Bayesian conditioning: inconsistent hypotheses receive vanishing attention weight. Layer-wise ablations confirm the computation is non-redundant: removing any single layer increases error by more than an order of magnitude.
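
One way to quantify the sharpening (our framing, not code from the paper) is the fraction of a query's attention mass that lands on still-feasible hypotheses; progressive elimination predicts this ratio climbs toward 1 with depth:

```python
import numpy as np

def feasible_attention_mass(attn_row: np.ndarray, feasible: np.ndarray) -> float:
    """Fraction of one query's attention on hypotheses not yet ruled out.

    attn_row : (n_keys,) post-softmax attention weights (sums to 1)
    feasible : (n_keys,) boolean mask of still-consistent hypotheses
    """
    return float(attn_row[feasible].sum())
```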

Stage 3: Precision Refinement (Late Layers)

After routing stabilizes, the final layers refine the precision of the posterior. Value representations transition from scattered clusters to a smooth one-dimensional manifold parameterized by posterior entropy. This "value manifold unfurling" is what enables fine-grained encoding of residual uncertainty, and it's visible in PCA projections of the attention output across training checkpoints.
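
The manifold claim is likewise checkable. A minimal PCA sketch, assuming you've stacked late-layer attention-output vectors into a matrix (coloring the projection by analytic posterior entropy should trace out the one-dimensional structure, if present):

```python
import numpy as np

def top_pcs(values: np.ndarray, n_components: int = 2) -> np.ndarray:
    """Project value vectors onto their top principal components.

    values : (n_samples, d_model) late-layer attention-output vectors.
    """
    X = values - values.mean(axis=0)
    _, _, Vt = np.linalg.svd(X, full_matrices=False)  # rows of Vt are the PCs
    return X @ Vt[:n_components].T
```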

There's a clean frame–precision dissociation: attention maps stabilize early and change little, while value representations keep improving. Attention decides where information flows; FFN layers handle the heavy Bayesian computation.


Mamba's Mechanism: A Different Road to the Same Geometry

Mamba doesn't use attention, yet it still achieves near-Bayesian performance on HMM filtering. How?

Its final-layer representations self-organize into five discrete clusters, one per HMM hidden state, with within-cluster variation encoding posterior entropy. This is the corner geometry of the belief simplex: the same geometry transformers discover, but reached via input-dependent state selection (Δ, B, C matrices) rather than query-key matching.

Where Mamba breaks down is predictable from the primitives framework: binding requires retrieving an arbitrary past position on demand. Attention does this in O(1) via query-key matching; Mamba must simulate retrieval through its recurrent state, which is slower and less precise. This exactly explains the 2.5× training gap and the 97.8% ceiling on associative recall.


Why This Matters Beyond the Wind Tunnels

The paper is careful to frame its contribution precisely: this is about what the learned function computes (the filtering posterior over hidden states), not about Bayesian uncertainty over network weights.

But the implications reach beyond these controlled settings. The geometric diagnostics uncovered here (key orthogonality, progressive QK sharpening, value-manifold structure) are testable predictions for frontier LLMs. If similar Bayesian manifolds arise in large pretrained models, it would suggest that probabilistic reasoning in LLMs isn't just emergent from scale but is architecturally grounded in the same geometric structures visible in these tiny, verifiable systems.

The paper also gives us a principled lower bound for reasoning capability: if a model can't implement Bayes in a setting with a known posterior and impossible memorization, there's little basis for claiming genuine inference in natural language.


What's Coming Next: The Trilogy

This is Paper I of three. The planned arc:

  • Paper II (already on arXiv: 2512.22473): Shows that the Bayesian geometry arises generically from gradient dynamics under cross-entropy training, explaining why transformers learn this structure.
  • Paper III: Shows how these primitives compose in partially observed settings closer to natural language.

Together, the trilogy aims to characterize when, why, and how neural sequence models implement probabilistic reasoning.


TL;DR

  • Small transformers implement exact Bayesian posteriors in controlled settings, matching the analytic posterior to within 10⁻³–10⁻⁴ bits, even beyond training length
  • Capacity-matched MLPs fail by orders of magnitude: this is a genuine architectural capability gap
  • Bayesian inference decomposes into three primitives (accumulation, transport, binding), and different architectures realize different subsets
  • Mamba beats transformers on belief transport (HMM), but can't fully implement random-access binding
  • LSTMs handle accumulation of static statistics but fail when belief states must evolve
  • Internally, transformers implement a three-stage mechanism: orthogonal hypothesis framing → progressive QK sharpening → value manifold refinement
  • These geometric signatures are concrete, testable predictions for analyzing large pretrained models

Paper: arXiv:2512.22471
Authors: Naman Agarwal (Google DeepMind), Siddhartha R. Dalal (Columbia), Vishal Misra (Columbia)


If you like this article, please support me by upvoting it ⬆️😉
