FLAN-T5-Small-FiD: Fusion-in-Decoder for Multi-Document Question Answering

Model Description

FLAN-T5-Small-FiD is a Fusion-in-Decoder (FiD) model built on google/flan-t5-small for multi-document question answering and retrieval-augmented generation (RAG). It implements the FiD architecture from Izacard & Grave (2021), enabling the model to reason over 100 retrieved passages simultaneously, each up to 512 tokens long, to produce accurate, grounded answers.

The model was fine-tuned on Google's Natural Questions dataset (50,000 training examples streamed from HuggingFace) on an NVIDIA RTX 6000 Pro GPU (96 GB VRAM) using BFloat16 mixed precision with gradient checkpointing, allowing the full 100-passage × 512-token context window to fit in memory during training.

What is Fusion-in-Decoder (FiD)?

FiD is an encoder-decoder architecture designed specifically for open-domain question answering with retrieval. It solves the core scaling problem of transformer self-attention by processing passages independently before fusing them:

  1. Independent Encoding: Each (question, passage) pair is encoded independently by the T5 encoder. For a batch of B questions, each with N passages of length L, the encoder processes B × N sequences of length L — no cross-passage self-attention overhead.
  2. Fusion: All encoder hidden states are concatenated along the sequence dimension: [h₁; h₂; …; hₙ] → shape (B, N×L, d_model). This is the key innovation: it creates a unified representation of all passages.
  3. Cross-Attention Decoding: The decoder cross-attends over the full concatenated context from all passages simultaneously. This is where the model learns to synthesize information across multiple documents to generate an answer.

This approach is far more efficient than naively concatenating all passages before encoding (which would require quadratic self-attention over N × L tokens). In FiD, the encoder's self-attention is only over L tokens per passage, and the cross-attention in the decoder handles the fusion.
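
The fusion step can be sketched in a few lines of PyTorch (shapes and variable names are illustrative, not taken from the repository's implementation):

```python
import torch

B, N, L, d_model = 2, 4, 16, 512  # batch, passages, passage length, hidden size

# Encoder runs over passages independently: (B*N, L) tokens -> (B*N, L, d_model) states
encoder_states = torch.randn(B * N, L, d_model)
attention_mask = torch.ones(B, N, L, dtype=torch.long)

# Fusion: concatenate passage representations along the sequence dimension
fused_states = encoder_states.view(B, N * L, d_model)  # (B, N*L, d_model)
fused_mask = attention_mask.view(B, N * L)             # (B, N*L)

# Encoder self-attention costs O(N * L^2) per question;
# naive concatenation before encoding would cost O((N*L)^2).
print(fused_states.shape)  # torch.Size([2, 64, 512])
```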

Why FLAN-T5 as the Base?

FLAN-T5 is an instruction-tuned version of T5, fine-tuned on over 1,800 tasks covering chain-of-thought reasoning, dialogue, summarization, and more. Compared to vanilla T5:

  • Stronger zero-shot instruction following: better generalization to new QA formats
  • Gated-GELU activation: replaces vanilla T5's ReLU FFN with a gated variant, adding a multiplicative gate that improves expressiveness
  • Better out-of-the-box performance: FLAN tuning provides a stronger starting point for downstream fine-tuning
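
The gated variant can be sketched as a two-branch feed-forward block (a minimal illustration of the gating idea using FLAN-T5-Small's dimensions; layer names follow the T5 convention, and the exact GELU approximation in T5 may differ):

```python
import torch
import torch.nn as nn

class GatedGeluFFN(nn.Module):
    """T5-style gated FFN: a GELU-activated branch times a linear gate branch."""
    def __init__(self, d_model=512, d_ff=1024):
        super().__init__()
        self.wi_0 = nn.Linear(d_model, d_ff, bias=False)  # activated branch
        self.wi_1 = nn.Linear(d_model, d_ff, bias=False)  # multiplicative gate
        self.wo = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.wo(nn.functional.gelu(self.wi_0(x)) * self.wi_1(x))

out = GatedGeluFFN()(torch.randn(1, 10, 512))
print(out.shape)  # torch.Size([1, 10, 512])
```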

Key Features

  • Massive Context Window: Trained on 100 passages × 512 tokens = 51,200 tokens encoded per question
  • Dynamic Input Shapes: The PyTorch model handles any number of passages and any sequence length at inference with no fixed limits
  • Multi-Document Reasoning: Cross-attention decoder synthesizes information across all retrieved passages simultaneously
  • Instruction-Tuned Base: Built on FLAN-T5 for superior instruction following and generalization
  • Production-Grade Training: BFloat16 mixed precision, gradient checkpointing, TF32 acceleration, fault-tolerant checkpointing with mid-epoch resumption
  • Custom Architecture File: Includes modeling_fid.py; drop it into your project and run.
  • ONNX Export Support: Static-shape ONNX export supported for deployment (model coming soon)

Model Specifications

Attribute Value
Base Model google/flan-t5-small
Architecture Fusion-in-Decoder (T5 Encoder-Decoder)
Parameters ~77M
d_model 512
d_ff 1024
d_kv 64
Encoder Layers 8
Decoder Layers 8
Attention Heads 6
Feed-Forward Activation Gated GELU (gated-gelu)
Vocab Size 32,128 (SentencePiece)
Positional Encoding T5 Relative Position Bias (32 buckets, max distance 128)
Word Embedding Tying Yes (encoder/decoder/LM-head share embeddings)
Dropout 0.1
Max Answer Length 256 tokens
Training Context 100 passages × 512 tokens (51,200 tokens per question)
PyTorch Shapes Fully dynamic (any N passages × any L tokens)
License Apache 2.0

Related Models

This model is part of a planned family of three FiD variants, each built on a different T5 base:

Model Base Status
FLAN-T5-Small-FiD (this) google/flan-t5-small ✅ Available (ONNX coming soon)
LaMini-T5-61M-FiD MBZUAI/LaMini-T5-61M 🔜 Coming Soon
T5-Small-FiD google-t5/t5-small 🔜 Coming Soon

Installation

pip install "torch>=2.0.0" "transformers>=4.30.0"

Clone this repository (includes modeling_fid.py which defines the FiDT5 class):

git lfs install
git clone https://huggingface.co/syedkhalid0/flan-t5-small-fid
cd flan-t5-small-fid

Usage

Basic Inference

import torch
from transformers import AutoTokenizer
from modeling_fid import FiDT5  # from this repo


def prepare_fid_inputs(question, passages, tokenizer, max_length=512):
    """
    Tokenize each (question, passage) pair and stack into FiD input shape.

    Each passage is prepended with the question using the format:
        "question: {question} context: {passage}"

    Returns:
        Tuple of (input_ids, attention_mask), each shaped (1, num_passages, max_length).
    """
    all_input_ids, all_masks = [], []
    for p in passages:
        tok = tokenizer(
            f"question: {question} context: {p}",
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        all_input_ids.append(tok["input_ids"])       # (1, L)
        all_masks.append(tok["attention_mask"])       # (1, L)

    # → (1, N, L)
    return torch.stack(all_input_ids, dim=1), torch.stack(all_masks, dim=1)


# Load model and tokenizer
model = FiDT5.from_pretrained("syedkhalid0/flan-t5-small-fid")
tokenizer = AutoTokenizer.from_pretrained("syedkhalid0/flan-t5-small-fid")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device).eval()

# Question + retrieved passages
question = "What is the capital of France?"
passages = [
    "Paris is the capital and most populous city of France.",
    "France is a country primarily located in Western Europe.",
    "The Eiffel Tower is located in Paris, France.",
    "Lyon is the third-largest city in France.",
    "Marseille is the second-largest city in France.",
]

input_ids, attention_mask = prepare_fid_inputs(question, passages, tokenizer)
input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)

print(f"Input shape: {input_ids.shape}")  # (1, 5, 512)

with torch.no_grad():
    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=256,
        num_beams=4,
        early_stopping=True,
    )

answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(f"Answer: {answer}")

Dynamic Shapes: Any Number of Passages

The PyTorch model accepts any number of passages with any sequence length. The dimensions are extracted at runtime from input_ids.size():

# 3 passages × 256 tokens: works
ids, mask = prepare_fid_inputs(question, passages[:3], tokenizer, max_length=256)
print(ids.shape)  # (1, 3, 256)

# 50 passages × 512 tokens: works
ids, mask = prepare_fid_inputs(question, passages * 10, tokenizer, max_length=512)
print(ids.shape)  # (1, 50, 512)

# 100 passages (same as training): works
ids, mask = prepare_fid_inputs(question, passages * 20, tokenizer, max_length=512)
print(ids.shape)  # (1, 100, 512)

The only constraint is available GPU/CPU memory. For reference, the 100 passages × 512 tokens training configuration ran on a 96 GB GPU.
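
As a rough back-of-envelope (illustrative: this counts only the fused encoder output in BF16; real peak usage with attention buffers and decoder state is considerably higher):

```python
def fused_state_bytes(n_passages, seq_len, d_model=512, bytes_per_elem=2):
    """Bytes for the fused encoder hidden states at batch size 1 (BF16 = 2 bytes)."""
    return n_passages * seq_len * d_model * bytes_per_elem

# The full training configuration: 100 passages x 512 tokens
print(fused_state_bytes(100, 512) / 2**20)  # 50.0 (MiB)
```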

Batch Inference

# Multiple questions in a single batch
questions = ["What is the capital of France?", "Who wrote Hamlet?"]
all_passages = [
    ["Paris is the capital of France.", "France is in Europe."],
    ["William Shakespeare wrote Hamlet.", "Hamlet is a tragedy."],
]

batch_ids, batch_masks = [], []
for q, ps in zip(questions, all_passages):
    ids, mask = prepare_fid_inputs(q, ps, tokenizer, max_length=256)
    batch_ids.append(ids.squeeze(0))    # (N, L)
    batch_masks.append(mask.squeeze(0)) # (N, L)

# Pad to same number of passages and stack
# (assuming same N here for simplicity)
input_ids = torch.stack(batch_ids, dim=0).to(device)       # (B, N, L)
attention_mask = torch.stack(batch_masks, dim=0).to(device) # (B, N, L)

with torch.no_grad():
    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=256,
        num_beams=4,
        early_stopping=True,
    )

for i, out in enumerate(output_ids):
    print(f"Q: {questions[i]}")
    print(f"A: {tokenizer.decode(out, skip_special_tokens=True)}\n")

Integration with a Retriever (FAISS Example)

from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

# Build a FAISS index from your corpus
retriever = SentenceTransformer("all-MiniLM-L6-v2")
corpus = [
    "Paris is the capital of France.",
    "Berlin is the capital of Germany.",
    "Tokyo is the capital of Japan.",
    # ... your full corpus
]
embeddings = retriever.encode(corpus).astype("float32")
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)


def answer_question(question, top_k=10):
    """End-to-end: retrieve passages → FiD → answer."""
    # Retrieve
    q_emb = retriever.encode([question]).astype("float32")
    _, indices = index.search(q_emb, top_k)
    retrieved = [corpus[i] for i in indices[0]]

    # Prepare FiD inputs
    ids, mask = prepare_fid_inputs(question, retrieved, tokenizer, max_length=512)
    ids, mask = ids.to(device), mask.to(device)

    # Generate
    with torch.no_grad():
        out = model.generate(
            input_ids=ids,
            attention_mask=mask,
            max_length=256,
            num_beams=4,
            early_stopping=True,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)


print(answer_question("What is the capital of France?"))

Architecture Diagram

┌────────────────────────────────────────────────────────────────────┐
│                        T5 ENCODER (Independent)                    │
│                                                                    │
│  "question: Q  context: passage₁"  → Encoder → h₁  (1, L, 512)     │
│  "question: Q  context: passage₂"  → Encoder → h₂  (1, L, 512)     │
│  ...                                                               │
│  "question: Q  context: passageₙ"  → Encoder → hₙ  (1, L, 512)     │
│                                                                    │
│  Internally batched as (B×N, L) for efficiency                     │
├────────────────────────────────────────────────────────────────────┤
│                        FUSION LAYER                                │
│                                                                    │
│  Concatenate: [h₁; h₂; …; hₙ]                                      │
│  Reshape:     (B×N, L, 512) → (B, N×L, 512)                        │
│                                                                    │
│  Example (training): N=100, L=512 → (B, 51200, 512)                │
│  Attention mask is also reshaped: (B, N, L) → (B, N×L)             │
├────────────────────────────────────────────────────────────────────┤
│                     T5 DECODER (Cross-Attention)                   │
│                                                                    │
│  Cross-attends over ALL N×L encoder hidden states simultaneously   │
│  → Synthesizes information from every passage at every step        │
│  → Generates answer autoregressively, token by token               │
│                                                                    │
│  Output: answer tokens (up to 256 tokens)                          │
└────────────────────────────────────────────────────────────────────┘

Training Details

Dataset

The model was trained on Google's Natural Questions (NQ) dataset, streamed directly from HuggingFace (google-research-datasets/natural_questions).

Split Examples
Training 50,000
Evaluation 5,000

Data preprocessing pipeline:

  1. Streaming: Dataset is streamed from HuggingFace, no full download required
  2. Answer extraction: Short answers are extracted from the human annotations; examples without a short answer are skipped
  3. Document cleaning: HTML tags are stripped from document tokens using the is_html flag, producing clean text passages
  4. Passage chunking: Each cleaned document is split into overlapping chunks of 100 words with a stride of 50 words, preserving context across chunk boundaries
  5. Passage padding: If a document yields fewer than 100 passages, existing passages are cycled (not padded with empty strings) to provide redundant relevant context rather than noise
  6. Passage truncation: Passages are capped at 100 per question (n_context=100)
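
Steps 4–6 can be sketched as a single helper (illustrative; the actual training script may differ in details):

```python
def chunk_document(text, chunk_words=100, stride=50, n_context=100):
    """Split a document into overlapping word chunks, then fill to n_context."""
    words = text.split()
    chunks = []
    for start in range(0, len(words), stride):
        chunks.append(" ".join(words[start:start + chunk_words]))
        if start + chunk_words >= len(words):
            break  # last chunk already covers the document tail
    if not chunks:
        return []
    if len(chunks) >= n_context:
        return chunks[:n_context]  # step 6: cap at n_context passages
    # step 5: cycle existing chunks instead of padding with empty strings
    return [chunks[i % len(chunks)] for i in range(n_context)]

doc = " ".join(f"word{i}" for i in range(180))  # a 180-word document
passages = chunk_document(doc, n_context=5)
print(len(passages))  # 5 (3 real chunks + 2 cycled repeats)
```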

Input format (per passage):

question: {question} context: {passage}

Label masking: Padding tokens in target sequences are replaced with -100, ensuring they are ignored by the cross-entropy loss during training.
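
In code, the masking is a one-liner on top of the tokenizer's pad token ID (T5/FLAN-T5 uses pad ID 0; the token IDs below are made up for illustration):

```python
import torch

pad_token_id = 0  # T5/FLAN-T5 SentencePiece <pad>
labels = torch.tensor([[822, 10, 1566, 1, 0, 0]])  # answer tokens + padding

# Replace pad positions with -100 so cross-entropy ignores them
labels = labels.masked_fill(labels == pad_token_id, -100)
print(labels.tolist())  # [[822, 10, 1566, 1, -100, -100]]
```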

Training Configuration

Hyperparameter Value
Optimizer AdamW
Learning Rate 5e-5
LR Schedule Linear warmup (1,000 steps) + linear decay
Batch Size (per step) 6
Gradient Accumulation Steps 2
Effective Batch Size 12
Epochs 3
Weight Decay 0.01
Max Gradient Norm 1.0
Passages per Question 100
Tokens per Passage 512
Max Answer Length 256 tokens
Total Context per Question 100 × 512 = 51,200 tokens
Total Optimizer Steps ~12,501
Evaluation Frequency Every 500 optimizer steps
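
The warmup-plus-decay schedule has the shape produced by transformers' get_linear_schedule_with_warmup; written out as a plain multiplier function (a sketch using the hyperparameters above):

```python
def linear_warmup_decay(step, warmup_steps=1000, total_steps=12501):
    """LR multiplier: linear warmup to 1.0, then linear decay to 0."""
    if step < warmup_steps:
        return step / warmup_steps
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

peak_lr = 5e-5
print(peak_lr * linear_warmup_decay(500))   # 2.5e-05 (halfway through warmup)
print(peak_lr * linear_warmup_decay(1000))  # 5e-05 (peak, start of decay)
```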

Training Infrastructure

Component Details
GPU NVIDIA RTX 6000 Pro (96 GB VRAM)
Precision BFloat16 mixed precision (torch.amp.autocast("cuda", dtype=torch.bfloat16))
TF32 Enabled (torch.backends.cuda.matmul.allow_tf32 = True)
Gradient Checkpointing Enabled (trades recomputation for memory, critical for 100-passage training)
DataLoader Workers 24 (persistent, prefetch_factor=4)
Framework PyTorch + HuggingFace Transformers

Optimization Techniques

The training pipeline uses several complementary techniques to maximize throughput and stability on the RTX 6000 Pro:

  1. BFloat16 Mixed Precision: Unlike FP16, BFloat16 has the same exponent range as FP32 (8 bits), virtually eliminating overflow/underflow issues common with FP16 on T5 architectures. This removes the need for GradScaler entirely while providing ~2× throughput improvement.

  2. TF32 (TensorFloat-32): Enabled on Ampere+ GPUs, TF32 uses 19-bit precision for FP32 tensor core operations, providing 8× more FLOPS than IEEE FP32 with negligible accuracy impact.

  3. Gradient Checkpointing: Discards intermediate activations during the forward pass and recomputes them during backpropagation. This is critical for fitting 100 passages × 512 tokens in memory; without it, the activation memory for the encoder alone would exceed available VRAM.

  4. Persistent DataLoader Workers: Workers stay alive between epochs (persistent_workers=True) with aggressive prefetching (prefetch_factor=4), ensuring the GPU is never starved for data. With 24 workers, the CPU-side tokenization and data loading pipeline saturates the GPU's training throughput.

  5. Gradient Accumulation: Accumulates gradients over 2 micro-batches before each optimizer step, achieving an effective batch size of 12 while keeping per-step memory usage bounded.
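
Techniques 1, 2, and 5 combine into a compact training step. The sketch below uses a stand-in regression loss and a plain nn.Module so it is self-contained and runnable; the real loop computes the FiD model's seq2seq cross-entropy loss instead:

```python
import torch

torch.backends.cuda.matmul.allow_tf32 = True  # technique 2: TF32 matmuls on Ampere+
accum_steps = 2

def train_step(model, optimizer, micro_batches, device_type="cuda"):
    """One optimizer step accumulated over BF16 micro-batches (no GradScaler needed)."""
    optimizer.zero_grad()
    for inputs, targets in micro_batches:  # len(micro_batches) == accum_steps
        with torch.amp.autocast(device_type, dtype=torch.bfloat16):
            # stand-in loss; the real loop uses the FiD model's seq2seq loss
            loss = torch.nn.functional.mse_loss(model(inputs), targets) / accum_steps
        loss.backward()  # gradients accumulate across micro-batches
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

model = torch.nn.Linear(4, 1)
opt = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
batches = [(torch.randn(6, 4), torch.randn(6, 1)) for _ in range(accum_steps)]
train_step(model, opt, batches, device_type="cpu")  # CPU here just for the demo
```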

Fault-Tolerant Checkpointing

The training pipeline implements full mid-epoch resumption:

  • Checkpoint contents: Model weights, tokenizer, optimizer state, LR scheduler state, epoch number, step number, global step counter, and best eval loss — everything needed for bit-exact resumption
  • Checkpoint frequency: Every 500 optimizer steps + end of every epoch
  • Best model tracking: Separately saves the checkpoint with the lowest eval loss
  • Resumption: On restart, the training loop fast-forwards through the DataLoader to the exact batch where training was interrupted, then resumes normally
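
In outline, saving and restoring such a checkpoint looks like this (illustrative; the key names here are assumptions, not the repository's actual format):

```python
import torch

def save_checkpoint(path, model, optimizer, scheduler, epoch, step,
                    global_step, best_eval_loss):
    """Persist everything needed to resume mid-epoch."""
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch,
        "step": step,              # batch index within the epoch
        "global_step": global_step,
        "best_eval_loss": best_eval_loss,
    }, path)

def load_checkpoint(path, model, optimizer, scheduler):
    """Restore states; the caller fast-forwards the DataLoader to ckpt['step']."""
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    return ckpt
```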

Use Cases

  • Open-Domain Question Answering: Answer factual questions using retrieved passages from Wikipedia, knowledge bases, or any corpus
  • Retrieval-Augmented Generation (RAG): Ground LLM responses in retrieved evidence for factual accuracy
  • Multi-Document Synthesis: Combine information from multiple sources to produce comprehensive answers
  • Fact Verification: Cross-check claims against multiple evidence documents simultaneously
  • Knowledge-Base QA: Answer customer/support queries from internal documentation
  • Research Paper QA: Query across multiple papers or sections for literature review assistance

Limitations

  • Retriever dependency: Performance depends heavily on the quality of the retrieved passages. Poor retrieval → poor answers.
  • Short-answer focus: Optimized for extractive/factual QA (short answers), not long-form generation or creative writing.
  • Memory scaling: GPU memory usage scales linearly with num_passages × seq_length × d_model. Large passage counts require significant VRAM.
  • Contradictory evidence: May struggle when passages contain conflicting information; the model has no explicit mechanism for conflict resolution.
  • No built-in retriever: This model is the reader component only. You need to pair it with a separate retriever (e.g., FAISS + SentenceTransformers, BM25, Contriever, etc.).
  • English only: Trained exclusively on English Natural Questions data.

ONNX Export (Coming Soon)

🔜 A pre-exported ONNX model will be added to this repository soon.

The FiD architecture supports export to ONNX for deployment in environments where PyTorch is not available (e.g., C++, Rust, or edge devices). The ONNX export uses fixed input shapes for maximum runtime compatibility.

Export Instructions

You can export the model to ONNX yourself using the following wrapper:

import torch
import torch.nn as nn
from modeling_fid import FiDT5
from transformers.modeling_outputs import BaseModelOutput


class FiDONNXWrapper(nn.Module):
    """Wrapper with fixed shapes for ONNX export, including attention mask."""

    def __init__(self, fid_model, n_contexts, seq_len):
        super().__init__()
        self.fid_model = fid_model
        self.n_contexts = n_contexts
        self.seq_len = seq_len

    def forward(self, input_ids, attention_mask, decoder_input_ids):
        B = input_ids.size(0)
        NL = self.n_contexts * self.seq_len

        flat_ids = input_ids.view(B * self.n_contexts, self.seq_len)
        flat_mask = attention_mask.view(B * self.n_contexts, self.seq_len)

        enc_out = self.fid_model.model.encoder(
            input_ids=flat_ids, attention_mask=flat_mask,
        )
        H = enc_out.last_hidden_state.size(-1)

        fused_hidden = enc_out.last_hidden_state.view(B, NL, H)
        fused_mask = attention_mask.view(B, NL)

        dec_out = self.fid_model.model.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=fused_hidden,
            encoder_attention_mask=fused_mask,  # critical: mask padding
        )

        logits = self.fid_model.model.lm_head(dec_out.last_hidden_state)
        return logits


# Export
model = FiDT5.from_pretrained("syedkhalid0/flan-t5-small-fid")
N, L = 10, 256  # fixed shapes for ONNX
wrapper = FiDONNXWrapper(model, N, L)
wrapper.cpu().eval()

dummy_ids = torch.randint(0, 1000, (1, N, L))
dummy_mask = torch.ones((1, N, L), dtype=torch.long)
dummy_dec = torch.zeros((1, 1), dtype=torch.long)

torch.onnx.export(
    wrapper,
    (dummy_ids, dummy_mask, dummy_dec),
    "model_fid.onnx",
    input_names=["input_ids", "attention_mask", "decoder_input_ids"],
    output_names=["logits"],
    dynamic_axes=None,
    opset_version=18,
)

Planned ONNX export shapes:

Input Shape
input_ids (1, 10, 256)
attention_mask (1, 10, 256)
decoder_input_ids (1, 1)
Output Shape
logits (1, 1, 32128)

Note: For flexible input shapes, use the PyTorch model. The ONNX export is intended for fixed-shape deployment scenarios.


Repository Files

syedkhalid0/flan-t5-small-fid/
├── README.md                  ← This file
├── modeling_fid.py            ← FiDT5 class (required for loading)
├── example_usage.py           ← Working inference example
├── config.json                ← T5 model configuration
├── fid_config.json            ← FiD-specific config (n_contexts, max lengths)
├── generation_config.json     ← Generation settings
├── model.safetensors          ← Model weights (PyTorch, best checkpoint)
├── tokenizer.json             ← Tokenizer vocabulary (SentencePiece, 32K tokens)
├── tokenizer_config.json      ← Tokenizer configuration
├── requirements.txt           ← Python dependencies
└── .gitattributes             ← Git LFS tracking rules

Citation

If you use this model in your research or applications, please cite:

@misc{flan-t5-small-fid-2026,
  author       = {Syed Khalid Hussain},
  title        = {FLAN-T5-Small-FiD: Fusion-in-Decoder for Multi-Document Question Answering},
  year         = {2026},
  publisher    = {Hugging Face},
  howpublished = {\url{https://huggingface.co/syedkhalid0/flan-t5-small-fid}},
  note         = {Fine-tuned on Natural Questions with 100-passage FiD architecture}
}

Original FiD paper:

@article{izacard2021leveraging,
  title   = {Leveraging Passage Retrieval with Generative Models for Open Domain Question Answering},
  author  = {Izacard, Gautier and Grave, Edouard},
  journal = {arXiv preprint arXiv:2007.01282},
  year    = {2021}
}

FLAN-T5 base model:

@article{chung2022scaling,
  title   = {Scaling Instruction-Finetuned Language Models},
  author  = {Chung, Hyung Won and Hou, Le and Longpre, Shayne and others},
  journal = {arXiv preprint arXiv:2210.11416},
  year    = {2022}
}

Contact


License

Apache 2.0, see LICENSE.

