FLAN-T5-Small-FiD: Fusion-in-Decoder for Multi-Document Question Answering
Model Description
FLAN-T5-Small-FiD is a Fusion-in-Decoder (FiD) model built on google/flan-t5-small for multi-document question answering and retrieval-augmented generation (RAG). It implements the FiD architecture from Izacard & Grave (2021), enabling the model to reason over 100 retrieved passages (each up to 512 tokens) simultaneously to produce accurate, grounded answers.
The model was fine-tuned on Google's Natural Questions dataset (50,000 training examples streamed from HuggingFace) on an NVIDIA RTX 6000 Pro GPU (96 GB VRAM) using BFloat16 mixed precision with gradient checkpointing, allowing the full 100-passage × 512-token context window to fit in memory during training.
What is Fusion-in-Decoder (FiD)?
FiD is an encoder-decoder architecture designed specifically for open-domain question answering with retrieval. It solves the core scaling problem of transformer self-attention by processing passages independently before fusing them:
- Independent Encoding: Each (question, passage) pair is encoded independently by the T5 encoder. For N passages of length L, the encoder processes B × N sequences of length L, with no cross-passage self-attention overhead.
- Fusion: All encoder hidden states are concatenated along the sequence dimension: [h₁; h₂; …; hₙ] → shape (B, N×L, d_model). This is the key innovation: it creates a unified representation of all passages.
- Cross-Attention Decoding: The decoder cross-attends over the full concatenated context from all passages simultaneously. This is where the model learns to synthesize information across multiple documents to generate an answer.
This approach is far more efficient than naively concatenating all passages before encoding (which would require quadratic self-attention over N × L tokens). In FiD, the encoder's self-attention is only over L tokens per passage, and the cross-attention in the decoder handles the fusion.
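The fusion step described above is, mechanically, just a reshape over the independently encoded passages. A minimal PyTorch sketch with toy dimensions (random tensors stand in for real encoder outputs):

```python
import torch

# Toy dimensions: B questions, N passages each, length L, hidden size d_model
B, N, L, d_model = 2, 4, 8, 16

# The encoder runs on (B*N, L) sequences: each passage is encoded independently,
# so self-attention cost is quadratic only in L, never in N*L
enc_hidden = torch.randn(B * N, L, d_model)
attention_mask = torch.ones(B, N, L, dtype=torch.long)

# Fusion: concatenate passages along the sequence dimension
fused_hidden = enc_hidden.view(B, N * L, d_model)  # (B, N*L, d_model)
fused_mask = attention_mask.view(B, N * L)         # (B, N*L)

# The decoder then cross-attends over all N*L positions at once
print(fused_hidden.shape)  # torch.Size([2, 32, 16])
```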
Why FLAN-T5 as the Base?
FLAN-T5 is an instruction-tuned version of T5, fine-tuned on over 1,800 tasks covering chain-of-thought reasoning, dialogue, summarization, and more. Compared to vanilla T5:
- Stronger zero-shot instruction following: better generalization to new QA formats
- Gated-GELU activation: replaces vanilla T5's ReLU FFN with a gated variant, adding a multiplicative gate that improves expressiveness
- Better out-of-the-box performance: FLAN tuning provides a stronger starting point for downstream fine-tuning
Key Features
- ✅ Massive Context Window: Trained on 100 passages × 512 tokens = 51,200 tokens encoded per question
- ✅ Dynamic Input Shapes: The PyTorch model handles any number of passages and any sequence length at inference with no fixed limits
- ✅ Multi-Document Reasoning: Cross-attention decoder synthesizes information across all retrieved passages simultaneously
- ✅ Instruction-Tuned Base: Built on FLAN-T5 for superior instruction following and generalization
- ✅ Production-Grade Training: BFloat16 mixed precision, gradient checkpointing, TF32 acceleration, fault-tolerant checkpointing with mid-epoch resumption
- ✅ Custom Architecture File: Includes modeling_fid.py; drop it into your project and run
- ✅ ONNX Export Support: Static-shape ONNX export supported for deployment (model coming soon)
Model Specifications
| Attribute | Value |
|---|---|
| Base Model | google/flan-t5-small |
| Architecture | Fusion-in-Decoder (T5 Encoder-Decoder) |
| Parameters | ~77M |
| d_model | 512 |
| d_ff | 1024 |
| d_kv | 64 |
| Encoder Layers | 8 |
| Decoder Layers | 8 |
| Attention Heads | 6 |
| Feed-Forward Activation | Gated GELU (gated-gelu) |
| Vocab Size | 32,128 (SentencePiece) |
| Positional Encoding | T5 Relative Position Bias (32 buckets, max distance 128) |
| Word Embedding Tying | Yes (encoder/decoder/LM-head share embeddings) |
| Dropout | 0.1 |
| Max Answer Length | 256 tokens |
| Training Context | 100 passages × 512 tokens (51,200 tokens per question) |
| PyTorch Shapes | Fully dynamic (any N passages × any L tokens) |
| License | Apache 2.0 |
Related Models
This model is part of a planned family of three FiD variants, each built on a different T5 base:
| Model | Base | Status |
|---|---|---|
| FLAN-T5-Small-FiD (this) | google/flan-t5-small | ✅ Available (ONNX coming soon) |
| LaMini-T5-61M-FiD | MBZUAI/LaMini-T5-61M | 🔜 Coming Soon |
| T5-Small-FiD | google-t5/t5-small | 🔜 Coming Soon |
Installation
pip install "torch>=2.0.0" "transformers>=4.30.0"
Clone this repository (includes modeling_fid.py which defines the FiDT5 class):
git lfs install
git clone https://huggingface.co/syedkhalid0/flan-t5-small-fid
cd flan-t5-small-fid
Usage
Basic Inference
import torch
from transformers import AutoTokenizer
from modeling_fid import FiDT5 # from this repo
def prepare_fid_inputs(question, passages, tokenizer, max_length=512):
"""
Tokenize each (question, passage) pair and stack into FiD input shape.
Each passage is prepended with the question using the format:
"question: {question} context: {passage}"
Returns:
Tuple of (input_ids, attention_mask), each shaped (1, num_passages, max_length).
"""
all_input_ids, all_masks = [], []
for p in passages:
tok = tokenizer(
f"question: {question} context: {p}",
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
)
all_input_ids.append(tok["input_ids"]) # (1, L)
all_masks.append(tok["attention_mask"]) # (1, L)
# → (1, N, L)
return torch.stack(all_input_ids, dim=1), torch.stack(all_masks, dim=1)
# Load model and tokenizer
model = FiDT5.from_pretrained("syedkhalid0/flan-t5-small-fid")
tokenizer = AutoTokenizer.from_pretrained("syedkhalid0/flan-t5-small-fid")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device).eval()
# Question + retrieved passages
question = "What is the capital of France?"
passages = [
"Paris is the capital and most populous city of France.",
"France is a country primarily located in Western Europe.",
"The Eiffel Tower is located in Paris, France.",
"Lyon is the third-largest city in France.",
"Marseille is the second-largest city in France.",
]
input_ids, attention_mask = prepare_fid_inputs(question, passages, tokenizer)
input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
print(f"Input shape: {input_ids.shape}") # (1, 5, 512)
with torch.no_grad():
output_ids = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_length=256,
num_beams=4,
early_stopping=True,
)
answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(f"Answer: {answer}")
Dynamic Shapes: Any Number of Passages
The PyTorch model accepts any number of passages with any sequence length. The dimensions are extracted at runtime from input_ids.size():
# 3 passages × 256 tokens: works
ids, mask = prepare_fid_inputs(question, passages[:3], tokenizer, max_length=256)
print(ids.shape) # (1, 3, 256)
# 50 passages × 512 tokens: works
ids, mask = prepare_fid_inputs(question, passages * 10, tokenizer, max_length=512)
print(ids.shape) # (1, 50, 512)
# 100 passages (same as training): works
ids, mask = prepare_fid_inputs(question, passages * 20, tokenizer, max_length=512)
print(ids.shape) # (1, 100, 512)
The only constraint is available GPU/CPU memory. For reference, 100 passages × 512 tokens was trained on a 96 GB GPU.
Batch Inference
# Multiple questions in a single batch
questions = ["What is the capital of France?", "Who wrote Hamlet?"]
all_passages = [
["Paris is the capital of France.", "France is in Europe."],
["William Shakespeare wrote Hamlet.", "Hamlet is a tragedy."],
]
batch_ids, batch_masks = [], []
for q, ps in zip(questions, all_passages):
ids, mask = prepare_fid_inputs(q, ps, tokenizer, max_length=256)
batch_ids.append(ids.squeeze(0)) # (N, L)
batch_masks.append(mask.squeeze(0)) # (N, L)
# Pad to same number of passages and stack
# (assuming same N here for simplicity)
input_ids = torch.stack(batch_ids, dim=0).to(device) # (B, N, L)
attention_mask = torch.stack(batch_masks, dim=0).to(device) # (B, N, L)
with torch.no_grad():
output_ids = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_length=256,
num_beams=4,
early_stopping=True,
)
for i, out in enumerate(output_ids):
print(f"Q: {questions[i]}")
print(f"A: {tokenizer.decode(out, skip_special_tokens=True)}\n")
Integration with a Retriever (FAISS Example)
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
# Build a FAISS index from your corpus
retriever = SentenceTransformer("all-MiniLM-L6-v2")
corpus = [
"Paris is the capital of France.",
"Berlin is the capital of Germany.",
"Tokyo is the capital of Japan.",
# ... your full corpus
]
embeddings = retriever.encode(corpus).astype("float32")
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
def answer_question(question, top_k=10):
"""End-to-end: retrieve passages → FiD → answer."""
# Retrieve
q_emb = retriever.encode([question]).astype("float32")
_, indices = index.search(q_emb, top_k)
retrieved = [corpus[i] for i in indices[0]]
# Prepare FiD inputs
ids, mask = prepare_fid_inputs(question, retrieved, tokenizer, max_length=512)
ids, mask = ids.to(device), mask.to(device)
# Generate
with torch.no_grad():
out = model.generate(
input_ids=ids,
attention_mask=mask,
max_length=256,
num_beams=4,
early_stopping=True,
)
return tokenizer.decode(out[0], skip_special_tokens=True)
print(answer_question("What is the capital of France?"))
Architecture Diagram
┌────────────────────────────────────────────────────────────────────┐
│ T5 ENCODER (Independent) │
│ │
│ "question: Q context: passage₁" → Encoder → h₁ (1, L, 512) │
│ "question: Q context: passage₂" → Encoder → h₂ (1, L, 512) │
│ ... │
│ "question: Q context: passageₙ" → Encoder → hₙ (1, L, 512) │
│ │
│ Internally batched as (B×N, L) for efficiency │
├────────────────────────────────────────────────────────────────────┤
│ FUSION LAYER │
│ │
│ Concatenate: [h₁; h₂; …; hₙ] │
│ Reshape: (B×N, L, 512) → (B, N×L, 512) │
│ │
│ Example (training): N=100, L=512 → (B, 51200, 512) │
│ Attention mask is also reshaped: (B, N, L) → (B, N×L) │
├────────────────────────────────────────────────────────────────────┤
│ T5 DECODER (Cross-Attention) │
│ │
│ Cross-attends over ALL N×L encoder hidden states simultaneously │
│ → Synthesizes information from every passage at every step │
│ → Generates answer autoregressively, token by token │
│ │
│ Output: answer tokens (up to 256 tokens) │
└────────────────────────────────────────────────────────────────────┘
Training Details
Dataset
The model was trained on Google's Natural Questions (NQ) dataset, streamed directly from HuggingFace (google-research-datasets/natural_questions).
| Split | Examples |
|---|---|
| Training | 50,000 |
| Evaluation | 5,000 |
Data preprocessing pipeline:
- Streaming: Dataset is streamed from HuggingFace; no full download required
- Answer extraction: Short answers are extracted from annotator annotations; examples without short answers are skipped
- Document cleaning: HTML tags are stripped from document tokens using the is_html flag, producing clean text passages
- Passage chunking: Each cleaned document is split into overlapping chunks of 100 words with a stride of 50 words, preserving context across chunk boundaries
- Passage padding: If a document yields fewer than 100 passages, existing passages are cycled (not padded with empty strings) to provide redundant relevant context rather than noise
- Passage truncation: Passages are capped at 100 per question (n_context=100)
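The chunking step above can be sketched as a small helper. This is an illustrative implementation of the described sizes (100-word chunks, 50-word stride); `chunk_words` is a hypothetical name, not a function from this repo:

```python
def chunk_words(text, chunk_size=100, stride=50):
    """Split cleaned document text into overlapping word chunks.

    With stride < chunk_size, consecutive chunks overlap, preserving
    context across chunk boundaries as described in the pipeline.
    """
    words = text.split()
    chunks = []
    # Step by `stride`; the final start position still yields a full-or-partial chunk
    for start in range(0, max(len(words) - chunk_size, 0) + 1, stride):
        chunks.append(" ".join(words[start:start + chunk_size]))
    return chunks

doc = " ".join(f"w{i}" for i in range(250))
chunks = chunk_words(doc)
print(len(chunks))  # 4 chunks: words 0-99, 50-149, 100-199, 150-249
```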
Input format (per passage):
question: {question} context: {passage}
Label masking: Padding tokens in target sequences are replaced with -100, ensuring they are ignored by the cross-entropy loss during training.
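A minimal sketch of that masking step (T5's pad token id is 0; -100 is the default `ignore_index` of PyTorch's cross-entropy loss):

```python
import torch

pad_id = 0  # T5/FLAN-T5 pad token id
# A tokenized target answer padded to fixed length (illustrative token ids; 1 = EOS)
labels = torch.tensor([[1842, 19, 8, 1, 0, 0, 0]])
labels[labels == pad_id] = -100  # padding positions are now ignored by the loss
print(labels)  # tensor([[1842, 19, 8, 1, -100, -100, -100]])
```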
Training Configuration
| Hyperparameter | Value |
|---|---|
| Optimizer | AdamW |
| Learning Rate | 5e-5 |
| LR Schedule | Linear warmup (1,000 steps) + linear decay |
| Batch Size (per step) | 6 |
| Gradient Accumulation Steps | 2 |
| Effective Batch Size | 12 |
| Epochs | 3 |
| Weight Decay | 0.01 |
| Max Gradient Norm | 1.0 |
| Passages per Question | 100 |
| Tokens per Passage | 512 |
| Max Answer Length | 256 tokens |
| Total Context per Question | 100 × 512 = 51,200 tokens |
| Total Optimizer Steps | ~12,501 |
| Evaluation Frequency | Every 500 optimizer steps |
Training Infrastructure
| Component | Details |
|---|---|
| GPU | NVIDIA RTX 6000 Pro (96 GB VRAM) |
| Precision | BFloat16 mixed precision (torch.amp.autocast("cuda", dtype=torch.bfloat16)) |
| TF32 | Enabled (torch.backends.cuda.matmul.allow_tf32 = True) |
| Gradient Checkpointing | Enabled (trades recomputation for memory, critical for 100-passage training) |
| DataLoader Workers | 24 (persistent, prefetch_factor=4) |
| Framework | PyTorch + HuggingFace Transformers |
Optimization Techniques
The training pipeline uses several complementary techniques to maximize throughput and stability on the RTX 6000 Pro:
- BFloat16 Mixed Precision: Unlike FP16, BFloat16 has the same exponent range as FP32 (8 bits), virtually eliminating the overflow/underflow issues common with FP16 on T5 architectures. This removes the need for GradScaler entirely while providing ~2× throughput improvement.
- TF32 (TensorFloat-32): Enabled on Ampere+ GPUs, TF32 uses 19-bit precision for FP32 tensor core operations, providing up to 8× more FLOPS than IEEE FP32 with negligible accuracy impact.
- Gradient Checkpointing: Discards intermediate activations during the forward pass and recomputes them during backpropagation. This is critical for fitting 100 passages × 512 tokens in memory; without it, the activation memory for the encoder alone would exceed available VRAM.
- Persistent DataLoader Workers: Workers stay alive between epochs (persistent_workers=True) with aggressive prefetching (prefetch_factor=4), ensuring the GPU is never starved for data. With 24 workers, the CPU-side tokenization and data loading pipeline saturates the GPU's training throughput.
- Gradient Accumulation: Accumulates gradients over 2 micro-batches before each optimizer step, achieving an effective batch size of 12 while keeping per-step memory usage bounded.
Fault-Tolerant Checkpointing
The training pipeline implements full mid-epoch resumption:
- Checkpoint contents: Model weights, tokenizer, optimizer state, LR scheduler state, epoch number, step number, global step counter, and best eval loss — everything needed for bit-exact resumption
- Checkpoint frequency: Every 500 optimizer steps + end of every epoch
- Best model tracking: Separately saves the checkpoint with the lowest eval loss
- Resumption: On restart, the training loop fast-forwards through the DataLoader to the exact batch where training was interrupted, then resumes normally
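A sketch of what such a checkpoint can look like. Field names here are illustrative; the repo's actual checkpoint format may differ:

```python
import torch

def save_checkpoint(path, model, optimizer, scheduler,
                    epoch, step, global_step, best_eval_loss):
    """Save everything needed to resume training mid-epoch."""
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch,
        "step": step,                     # position within the epoch
        "global_step": global_step,       # optimizer steps so far
        "best_eval_loss": best_eval_loss,
    }, path)

def load_checkpoint(path, model, optimizer, scheduler):
    """Restore all training state; caller fast-forwards the DataLoader to `step`."""
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    return ckpt["epoch"], ckpt["step"], ckpt["global_step"], ckpt["best_eval_loss"]
```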
Use Cases
- Open-Domain Question Answering: Answer factual questions using retrieved passages from Wikipedia, knowledge bases, or any corpus
- Retrieval-Augmented Generation (RAG): Ground LLM responses in retrieved evidence for factual accuracy
- Multi-Document Synthesis: Combine information from multiple sources to produce comprehensive answers
- Fact Verification: Cross-check claims against multiple evidence documents simultaneously
- Knowledge-Base QA: Answer customer/support queries from internal documentation
- Research Paper QA: Query across multiple papers or sections for literature review assistance
Limitations
- Retriever dependency: Performance depends heavily on the quality of the retrieved passages. Poor retrieval → poor answers.
- Short-answer focus: Optimized for extractive/factual QA (short answers), not long-form generation or creative writing.
- Memory scaling: GPU memory usage scales linearly with num_passages × seq_length × d_model. Large passage counts require significant VRAM.
- Contradictory evidence: May struggle when passages contain conflicting information; the model has no explicit mechanism for conflict resolution.
- No built-in retriever: This model is the reader component only. You need to pair it with a separate retriever (e.g., FAISS + SentenceTransformers, BM25, Contriever).
- English only: Trained exclusively on English Natural Questions data.
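To make the memory-scaling limitation concrete, a back-of-envelope calculation for just the fused encoder output at the training configuration (bf16, 2 bytes per element):

```python
# Fused encoder hidden states: (B, N*L, d_model) in bf16 (2 bytes/element)
N, L, d_model = 100, 512, 512
bytes_fused = N * L * d_model * 2          # per question (B = 1)
print(f"{bytes_fused / 2**20:.1f} MiB")    # 50.0 MiB
# Intermediate activations across all encoder layers are many times larger,
# which is why gradient checkpointing is needed at N = 100
```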
ONNX Export (Coming Soon)
🔜 A pre-exported ONNX model will be added to this repository soon.
The FiD architecture supports export to ONNX for deployment in environments where PyTorch is not available (e.g., C++, Rust, or edge devices). The ONNX export uses fixed input shapes for maximum runtime compatibility.
Export Instructions
You can export the model to ONNX yourself using the following wrapper:
import torch
import torch.nn as nn
from modeling_fid import FiDT5
from transformers.modeling_outputs import BaseModelOutput
class FiDONNXWrapper(nn.Module):
"""Wrapper with fixed shapes for ONNX export, including attention mask."""
def __init__(self, fid_model, n_contexts, seq_len):
super().__init__()
self.fid_model = fid_model
self.n_contexts = n_contexts
self.seq_len = seq_len
def forward(self, input_ids, attention_mask, decoder_input_ids):
B = input_ids.size(0)
NL = self.n_contexts * self.seq_len
flat_ids = input_ids.view(B * self.n_contexts, self.seq_len)
flat_mask = attention_mask.view(B * self.n_contexts, self.seq_len)
enc_out = self.fid_model.model.encoder(
input_ids=flat_ids, attention_mask=flat_mask,
)
H = enc_out.last_hidden_state.size(-1)
fused_hidden = enc_out.last_hidden_state.view(B, NL, H)
fused_mask = attention_mask.view(B, NL)
dec_out = self.fid_model.model.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=fused_hidden,
encoder_attention_mask=fused_mask, # critical: mask padding
)
logits = self.fid_model.model.lm_head(dec_out.last_hidden_state)
return logits
# Export
model = FiDT5.from_pretrained("syedkhalid0/flan-t5-small-fid")
N, L = 10, 256 # fixed shapes for ONNX
wrapper = FiDONNXWrapper(model, N, L)
wrapper.cpu().eval()
dummy_ids = torch.randint(0, 1000, (1, N, L))
dummy_mask = torch.ones((1, N, L), dtype=torch.long)
dummy_dec = torch.zeros((1, 1), dtype=torch.long)
torch.onnx.export(
wrapper,
(dummy_ids, dummy_mask, dummy_dec),
"model_fid.onnx",
input_names=["input_ids", "attention_mask", "decoder_input_ids"],
output_names=["logits"],
dynamic_axes=None,
opset_version=18,
)
Planned ONNX export shapes:
| Input | Shape |
|---|---|
| input_ids | (1, 10, 256) |
| attention_mask | (1, 10, 256) |
| decoder_input_ids | (1, 1) |

| Output | Shape |
|---|---|
| logits | (1, 1, 32128) |
Note: For flexible input shapes, use the PyTorch model. The ONNX export is intended for fixed-shape deployment scenarios.
Repository Files
syedkhalid0/flan-t5-small-fid/
├── README.md ← This file
├── modeling_fid.py ← FiDT5 class (required for loading)
├── example_usage.py ← Working inference example
├── config.json ← T5 model configuration
├── fid_config.json          ← FiD-specific config (n_contexts, max lengths)
├── generation_config.json ← Generation settings
├── model.safetensors ← Model weights (PyTorch, best checkpoint)
├── tokenizer.json ← Tokenizer vocabulary (SentencePiece, 32K tokens)
├── tokenizer_config.json ← Tokenizer configuration
├── requirements.txt ← Python dependencies
└── .gitattributes ← Git LFS tracking rules
Citation
If you use this model in your research or applications, please cite:
@misc{flan-t5-small-fid-2026,
author = {Syed Khalid Hussain},
title = {FLAN-T5-Small-FiD: Fusion-in-Decoder for Multi-Document Question Answering},
year = {2026},
publisher = {Hugging Face},
howpublished = {\url{https://huggingface.co/syedkhalid0/flan-t5-small-fid}},
note = {Fine-tuned on Natural Questions with 100-passage FiD architecture}
}
Original FiD paper:
@article{izacard2021leveraging,
title = {Leveraging Passage Retrieval with Generative Models for Open Domain Question Answering},
author = {Izacard, Gautier and Grave, Edouard},
journal = {arXiv preprint arXiv:2007.01282},
year = {2021}
}
FLAN-T5 base model:
@article{chung2022scaling,
title = {Scaling Instruction-Finetuned Language Models},
author = {Chung, Hyung Won and Hou, Le and Longpre, Shayne and others},
journal = {arXiv preprint arXiv:2210.11416},
year = {2022}
}
Contact
- Author: Syed Khalid Hussain
- Repository: syedkhalid0/flan-t5-small-fid
- Issues: Use the Community tab on this repo
- Email: hello@syedkhalid.tech
- Support: Buy Me a Coffee ☕
License
Apache 2.0; see LICENSE.