# ColGemma4-E4B-IT-Base

See also: ColGemma4-E2B-IT-Base (smaller 2.3B-effective variant)
ColGemma4 is a visual document retrieval model built on Google's Gemma 4 E4B (9B params, 4.5B effective). It generates ColBERT-style multi-vector representations of document images and text queries for late-interaction retrieval.
Single-seed LoRA adapter trained with ColbertLoss, no hard negatives, no model merging.
Built following the ColPali architecture pattern, adapted for Gemma 4's multimodal architecture.
## Model Description
| Property | Value |
|---|---|
| Base model | google/gemma-4-E4B-it (9B total, 4.5B effective) |
| Architecture | ColBERT late-interaction over Gemma 4 VLM |
| Embedding dim | 128 |
| Visual tokens | 1120 (max soft tokens) |
| Fine-tuning | LoRA (r=32, alpha=32, dropout=0.1) |
| Trainable params | 73.4M (0.95% of total) |
| Projection | Random-init linear (hidden_size -> 128), not trained |
| Training loss | ColbertLoss (temperature=0.02, in-batch negatives only) |
| Precision | BF16 |
## Benchmark Results

All scores are nDCG on the ViDoRe benchmark; each table reports nDCG@5 and nDCG@10 per task.

### ViDoRe V1
| Task | nDCG@5 | nDCG@10 |
|---|---|---|
| ArxivQA | 84.53 | 85.57 |
| DocVQA | 57.70 | 59.87 |
| InfoVQA | 90.84 | 91.12 |
| ShiftProject | 84.45 | 85.17 |
| SyntheticDocQA - AI | 97.89 | 97.89 |
| SyntheticDocQA - Energy | 94.40 | 95.01 |
| SyntheticDocQA - Government | 96.58 | 96.58 |
| SyntheticDocQA - Healthcare | 97.89 | 97.89 |
| Tabfquad | 92.37 | 92.84 |
| Tatdqa | 76.72 | 78.93 |
| Average | 87.34 | 88.09 |
### ViDoRe V2
| Task | nDCG@5 | nDCG@10 |
|---|---|---|
| BioMedical Lectures | 57.06 | 59.48 |
| ESG Reports - HL | 57.96 | 61.07 |
| ESG Reports | 47.51 | 51.59 |
| Economics Reports | 39.76 | 42.46 |
| Average | 50.57 | 53.65 |
### ViDoRe V3
| Task | nDCG@5 | nDCG@10 |
|---|---|---|
| Computer Science | 63.55 | 67.76 |
| Energy | 59.64 | 62.55 |
| Finance En | 51.05 | 53.96 |
| Finance Fr | 41.05 | 44.64 |
| HR | 52.60 | 55.17 |
| Pharmaceuticals | 55.98 | 57.42 |
| Physics | 43.37 | 46.11 |
| Industrial | 41.52 | 42.98 |
| Average | 51.10 | 53.82 |
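For reference, the nDCG@k metric reported in the tables above can be sketched in a few lines of plain Python (binary relevance, standard log2 discount); this is an illustrative sketch, not the benchmark's evaluation code:

```python
import math

def dcg_at_k(relevances, k):
    # Gain at each rank, discounted by log2 of (1-based rank + 1),
    # so the top result gets weight 1/log2(2) = 1.
    return sum(rel / math.log2(i + 2) for i, rel in enumerate(relevances[:k]))

def ndcg_at_k(relevances, k):
    # Normalize by the DCG of the ideal (descending-relevance) ordering.
    ideal = dcg_at_k(sorted(relevances, reverse=True), k)
    return dcg_at_k(relevances, k) / ideal if ideal > 0 else 0.0

# A single relevant page returned at rank 2:
print(round(ndcg_at_k([0, 1, 0, 0, 0], k=5), 4))  # 0.6309
```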
## Usage

### Installation

```bash
pip install colpali-engine transformers torch peft
```
### Loading the Model

```python
import torch
from colgemma4 import ColGemma4, ColGemma4Processor

model = ColGemma4.from_pretrained(
    "athrael-soju/ColGemma4-E4B-IT-Base",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
    ignore_mismatched_sizes=True,
)

processor = ColGemma4Processor.from_pretrained(
    "athrael-soju/ColGemma4-E4B-IT-Base",
    max_num_visual_tokens=1120,
)
```
### Encoding Documents (Images)

```python
from PIL import Image

images = [Image.open("page1.png"), Image.open("page2.png")]

batch_doc = processor.process_images(images)
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}

with torch.no_grad():
    doc_embeddings = model(**batch_doc)
```
### Encoding Queries

```python
queries = ["What is the revenue for Q3 2024?"]

batch_query = processor.process_queries(queries)
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}

with torch.no_grad():
    query_embeddings = model(**batch_query)
```
### Scoring (MaxSim)

```python
scores = processor.score(query_embeddings, doc_embeddings)
```
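Conceptually, the late-interaction score is a MaxSim sum: each query token embedding is matched against its best document token embedding, and the per-token maxima are added up. A minimal pure-Python sketch, with toy 2-D vectors standing in for the 128-dim embeddings:

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def maxsim(query_vecs, doc_vecs):
    # For each query token, take its best-matching document token,
    # then sum those per-token maxima into one relevance score.
    return sum(max(dot(q, d) for d in doc_vecs) for q in query_vecs)

query = [[1.0, 0.0], [0.0, 1.0]]            # 2 query tokens
doc = [[0.9, 0.1], [0.0, 1.0], [0.5, 0.5]]  # 3 document tokens

score = maxsim(query, doc)  # best matches: 0.9 (token 1) + 1.0 (token 2)
```

The real `processor.score` computes this over batched, padded tensors, but the per-pair arithmetic is the same.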
## Training Configuration

- Base model: google/gemma-4-E4B-it
- Loss: ColbertLoss (temperature=0.02)
- Hard negatives: none
- Batch size per GPU: 64
- GPUs: 8
- Gradient accumulation: 1
- Effective batch size: 512 (64 x 8)
- In-batch negatives: 512
- LoRA:
  - r: 32
  - alpha: 32
  - dropout: 0.1
  - target_modules: `language_model.*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj)`
- Learning rate: 2e-4 (cosine schedule, 8% warmup)
- Weight decay: 0.02
- Epochs: 1
- Steps: 1,512
- Visual tokens: 1120
- Attention: bidirectional (all 42 layers patched)
- Gradient checkpointing: enabled
- Precision: BF16
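The loss is in-batch contrastive: query i's positive is document i, every other document in the batch serves as a negative, and scores are softmaxed at temperature 0.02. A minimal pure-Python sketch of this objective (the actual training uses the colpali-engine implementation over MaxSim scores):

```python
import math

def colbert_style_loss(scores, temperature=0.02):
    """In-batch InfoNCE-style loss: scores[i][j] is the late-interaction
    score of query i against document j; pair (i, i) is the positive."""
    n = len(scores)
    total = 0.0
    for i in range(n):
        logits = [s / temperature for s in scores[i]]
        m = max(logits)  # subtract the max for numerical stability
        log_sum = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_sum - logits[i]  # -log softmax at the positive
    return total / n

# With uniform scores, the loss equals log(batch_size): log(3) here.
uniform = [[1.0] * 3 for _ in range(3)]
print(round(colbert_style_loss(uniform), 4))  # 1.0986
```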
## Training Data

Trained on ~774K query-document pairs from publicly available datasets:

- vidore/colpali_train_set
- openbmb/VisRAG-Ret-Train-Synthetic-data
- openbmb/VisRAG-Ret-Train-In-domain-data
- llamaindex/vdr-multilingual-train (en/de/es/fr/it subsets)
- vidore/tatdqa_train
- vidore/tabfquad_train_set
## Troubleshooting & Fine-tuning Guide

If you're building on this model or training your own ColGemma4, here's what we learned along the way.

### Gemma 4 Architecture Gotchas
- **Position embedding memory blow-up**: The vision encoder uses `F.one_hot(positions, num_classes=10240)`, which allocates ~314 GB at batch=32 with 1120 tokens. Replacing it with `F.embedding` is mathematically identical and saves ~106 GB/GPU. Required for batch sizes above 8.

  ```python
  # In Gemma4VisionEncoder._position_embeddings:
  # Replace:
  one_hot = F.one_hot(clamped_positions, num_classes=self.position_embedding_size)
  # With:
  return F.embedding(clamped_positions, self.position_embedding_table)
  ```

- **KV-sharing and `use_cache`**: In Gemma 4 E4B, 18 of the 42 text layers (layers 24-41) reuse K/V from two donor layers (22 and 23) when caching is enabled. During training, always set `use_cache=False` so that every layer computes its own K/V and all LoRA weights stay active. At inference time, set `use_cache=True` so the KV-sharing architecture works as designed.
- **Flash Attention is incompatible**: Gemma 4 uses head_dim 256 (sliding) and 512 (global). FA v2 caps at 256, FA v4 at 128. Use `attn_implementation="sdpa"` instead.
- **Gradient checkpointing is mandatory**: At 1120 visual tokens, activations from 42 layers consume well over 200 GB. Even with the position-embedding patch, disabling gradient checkpointing will OOM.
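The claim that the `F.embedding` swap is mathematically identical is easy to verify: multiplying a one-hot row by the embedding table simply selects one row, without materializing the giant one-hot tensor. A dependency-free toy check:

```python
def one_hot_matmul(position, table):
    # F.one_hot + matmul path: build a one-hot row, multiply by the table.
    num_classes, dim = len(table), len(table[0])
    one_hot = [1.0 if i == position else 0.0 for i in range(num_classes)]
    return [sum(one_hot[i] * table[i][d] for i in range(num_classes))
            for d in range(dim)]

def embedding_lookup(position, table):
    # F.embedding path: direct row selection, no one-hot tensor allocated.
    return list(table[position])

table = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # 3 positions, dim 2
assert all(one_hot_matmul(p, table) == embedding_lookup(p, table)
           for p in range(3))
```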
### Training Tips

- **Leave `custom_text_proj` alone**: The 128-dim projection is randomly initialized and works best untrained. Both LoRA-targeting it and adding it to `modules_to_save` caused regressions in our experiments. The random projection provides a consistent mapping without overfitting.
- **Keep `grad_accum=1` with contrastive losses**: `all_gather` only collects the current micro-batch, so accumulation steps shrink your pool of in-batch negatives. Training loss looks deceptively good, but eval regresses. Use the largest batch that fits with `grad_accum=1`.
- **Avoid `torch.compile`**: It prefixes weight keys with `_orig_mod.`, breaking PEFT adapter loading at eval time. Scores drop to near zero despite a healthy training loss.
- **Watch for silent weight randomization**: `ignore_mismatched_sizes=True` will initialize mismatched weights randomly without raising any error. Sanity check: the loss at step 0 should be near `log(batch_size)` (~6.2 for batch 512), and grad norms should be 5-20. If grad norms are in the thousands, the weights didn't load correctly.
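The step-0 sanity check above follows directly from the contrastive setup: with random weights, the softmax over the in-batch candidates is roughly uniform, so the expected cross-entropy starts near log of the candidate count:

```python
import math

def expected_initial_loss(num_inbatch_candidates):
    # Uniform softmax over N candidates gives cross-entropy log(N).
    return math.log(num_inbatch_candidates)

print(round(expected_initial_loss(512), 2))  # 6.24
```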
## HydraGemma4 (Dual-Head Variant)

This repo includes `lm_head.pt`, the saved language-model head from the base model. Combined with the LoRA adapter, it enables the HydraGemma4 dual-head architecture, which supports both retrieval (ColBERT embeddings) and generation (text output) from the same base model.
## Limitations
- Single-seed, no model merging
- No hard negative mining (relies entirely on in-batch negatives)
- English-centric training data (some multilingual from VDR)
- Visual token budget fixed at 1120
## Citation

```bibtex
@misc{colgemma4,
  title={ColGemma4: Visual Document Retrieval with Gemma 4},
  author={Athrael Soju},
  year={2026},
  url={https://huggingface.co/athrael-soju/ColGemma4-E4B-IT-Base}
}
```
## Acknowledgements
- ColPali by Faysse et al. for the ColBERT-over-VLM architecture
- Google Gemma 4 for the base model
- colpali-engine for the training framework
- ViDoRe benchmark for evaluation