TEG-421M — Trimodal Embeddings Gemma

TEG (Trimodal Embeddings Gemma) maps images and audio into the same embedding space as text, enabling cross-modal retrieval from a single vector index. All three modalities share a unified 768-dim space anchored by Google's embeddinggemma-300M text encoder, with full Matryoshka truncation support down to 128 dims.

Also available in GGUF format for quantized edge deployment.

Architecture

TEG combines lightweight edge encoders with deep projection heads that distill into Gemma's embedding space:

```
Text  ──→ embeddinggemma-300M ──────────────────────→ 768-dim (L2-normalized)
Image ──→ MobileNetV4-Medium (1280-d) ──→ DeepProjectionHead ──→ 768-dim
Audio ──→ EfficientAT mn20_as (1920-d) ──→ DeepProjectionHead ──→ 768-dim
```
| Component | Architecture | Params | Size |
|---|---|---|---|
| Text encoder | embeddinggemma-300M (bfloat16) | 307.6M | 586.7 MB |
| Image encoder | MobileNetV4-Medium (timm) | 8.5M | 32.4 MB |
| Audio encoder | EfficientAT mn20_as | 18.0M | 68.5 MB |
| Image projection | DeepProjectionHead (1280 → 768) | 42.0M | 160.1 MB |
| Audio projection | DeepProjectionHead (1920 → 768) | 44.6M | 170.1 MB |
| **Total** | | 420.6M | 1017.8 MB |

Projection head detail

Each DeepProjectionHead is a residual MLP:

```
Linear(encoder_dim, 4096) → GELU → LayerNorm → Dropout(0.3)
  → Linear(4096, 4096) → GELU → LayerNorm → Dropout(0.3) + residual
  → Linear(4096, 768)
```
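As a concrete reference, here is a minimal PyTorch sketch of this head. Class and attribute names are assumptions; note that the published per-head parameter and tensor counts (42.0M and 14 tensors for the 1280 → 768 image head) are consistent with the 4096 → 4096 residual block appearing twice, so the block count is parameterized below.

```python
import torch
import torch.nn as nn


class DeepProjectionHead(nn.Module):
    """Residual MLP projecting an encoder embedding into the 768-dim Gemma space.

    Names are illustrative; two 4096->4096 residual blocks match the
    published parameter counts (42.0M for 1280->768, 44.6M for 1920->768).
    """

    def __init__(self, encoder_dim: int, embed_dim: int = 768,
                 hidden_dim: int = 4096, num_blocks: int = 2, p_drop: float = 0.3):
        super().__init__()
        self.input_proj = nn.Sequential(
            nn.Linear(encoder_dim, hidden_dim), nn.GELU(),
            nn.LayerNorm(hidden_dim), nn.Dropout(p_drop),
        )
        self.blocks = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim), nn.GELU(),
                nn.LayerNorm(hidden_dim), nn.Dropout(p_drop),
            )
            for _ in range(num_blocks)
        )
        self.output_proj = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.input_proj(x)
        for block in self.blocks:
            h = h + block(h)  # residual connection around each hidden block
        return self.output_proj(h)
```

With `encoder_dim=1280` this sketch has ~42.0M parameters, matching the image projection row in the component table.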

Matryoshka dimensions

Embeddings can be truncated to [768, 512, 256, 128] dimensions while preserving retrieval quality — the model is trained with Matryoshka Representation Learning (MRL) using per-dimension loss weights [1.0, 1.0, 2.0, 4.0].

Benchmarks

All benchmarks run on a single NVIDIA L4 GPU with 5K samples where applicable.

Cross-modal retrieval — SALT (5K trimodal samples)

| Direction | TEG-421M (421M) | LCO-3B (4.7B) | Nemotron-3B (4.7B) | ImageBind (1.2B) | EBind |
|---|---|---|---|---|---|
| Text → Image R@1 | 0.672 | 0.660 | 0.529 | 0.712 | 0.779 |
| Image → Text R@1 | 0.620 | 0.564 | 0.299 | 0.736 | 0.783 |
| Text → Audio R@1 | 0.113 | 0.042 | 0.018 | 0.038 | 0.047 |
| Audio → Text R@1 | 0.115 | 0.032 | 0.010 | 0.039 | 0.035 |
| Audio → Image R@1 | 0.081 | 0.027 | 0.016 | 0.023 | 0.027 |
| Image → Audio R@1 | 0.083 | 0.034 | 0.018 | 0.025 | 0.032 |

TEG leads all audio cross-modal directions by 2-10x over models that are 3-11x larger. Image↔Audio improved ~40% over v1 via joint cross-modal training. Vision-text trails EBind/ImageBind but uses encoders small enough for edge deployment.

Audio retrieval — AudioCaps & Clotho

| Benchmark | Direction | TEG-421M | LCO-3B | Nemotron-3B | CLAP-Small | CLAP-Large | ImageBind | EBind |
|---|---|---|---|---|---|---|---|---|
| AudioCaps | A→T R@1 | 0.159 | 0.250 | 0.050 | 0.425 | 0.420 | 0.116 | 0.225 |
| AudioCaps | T→A R@1 | 0.149 | 0.215 | 0.075 | 0.315 | 0.280 | 0.080 | 0.219 |
| Clotho | A→T R@1 | 0.168 | 0.178 | 0.038 | 0.166 | 0.195 | 0.061 | 0.088 |
| Clotho | T→A R@1 | 0.123 | 0.187 | 0.070 | 0.159 | 0.167 | 0.074 | 0.118 |

CLAP models lead on audio-only benchmarks (audio specialists with no image support). Among trimodal models, TEG is competitive with LCO while being 11x smaller.

Image-text retrieval — MSCOCO & Flickr30k

| Benchmark | Direction | TEG-421M (421M) | EBind (1.78B*) | ImageBind (1.2B) | LCO-3B (4.7B) | Nemotron-3B (4.7B) |
|---|---|---|---|---|---|---|
| MSCOCO 5K | I→T R@1 | 0.248 | 0.743 | 0.658 | 0.533 | 0.225 |
| MSCOCO 5K | T→I R@1 | 0.180 | 0.559 | 0.490 | 0.469 | 0.334 |
| MSCOCO 5K | I→T R@10 | 0.622 | 0.948 | 0.918 | 0.784 | 0.630 |
| Flickr30k | I→T R@1 | 0.498 | — | — | 0.840 | 0.419 |
| Flickr30k | T→I R@1 | 0.358 | — | — | 0.765 | 0.563 |

TEG's image-text retrieval trades accuracy for edge deployability — MobileNetV4-Medium is ~100x smaller than the ViT-H/ViT-L encoders used by competitors. On MSCOCO, TEG outperforms Nemotron-3B on I→T despite being 11x smaller.

Zero-shot classification — ESC-50

| Model | Params | Accuracy |
|---|---|---|
| CLAP-Large | 67.8M | 0.905 |
| LCO-3B | 4.7B | 0.853 |
| TEG-421M | 421M | 0.820 |
| EBind | 1.78B* | 0.770 |
| CLAP-Small | 27.5M | 0.751 |
| Nemotron-3B | 4.7B | 0.727 |
| ImageBind | 1.2B | 0.664 |

Usage

Loading components

```python
from safetensors.torch import load_file

# Load the entire model
tensors = load_file("teg-421m.safetensors")

# Extract each component's state dict by key prefix
def extract(prefix: str) -> dict:
    return {k.removeprefix(prefix): v for k, v in tensors.items() if k.startswith(prefix)}

gemma_sd = extract("gemma.")
image_enc_sd = extract("image_encoder.")
audio_enc_sd = extract("audio_encoder.")
image_proj_sd = extract("image_projection.")
audio_proj_sd = extract("audio_projection.")
```

Reading metadata

```python
from safetensors import safe_open

with safe_open("teg-421m.safetensors", framework="pt") as f:
    metadata = f.metadata()
    print(metadata)
    # Keys: format, version, text_model, embed_dim, image_encoder_name,
    #       image_encoder_dim, audio_encoder_name, audio_encoder_dim,
    #       audio_sample_rate, matryoshka_dims, total_parameters, ...
```

Matryoshka truncation

```python
import torch.nn.functional as F

# Full 768-dim embedding (`model` stands in for any TEG encoder pipeline)
embedding = model(inputs)  # (N, 768)

# Truncate to 256 dims and re-normalize
embedding_256 = F.normalize(embedding[:, :256], dim=-1)
```
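Because all modalities share one space, truncated embeddings remain directly comparable and a single index serves every direction. A retrieval sketch, with random tensors standing in for real TEG outputs:

```python
import torch
import torch.nn.functional as F

# Random stand-ins: e.g. a 256-dim image index and one 256-dim text query.
index = F.normalize(torch.randn(1000, 768)[:, :256], dim=-1)
query = F.normalize(torch.randn(1, 768)[:, :256], dim=-1)

# Rows are unit-norm, so a dot product equals cosine similarity.
scores = query @ index.T
top5 = scores.topk(5, dim=-1).indices  # indices of the 5 nearest neighbors
```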

File layout

teg-421m.safetensors     # All components in one file (~1 GB)

Tensor key prefixes

| Prefix | Component | Tensors |
|---|---|---|
| gemma.* | embeddinggemma-300M (bfloat16) | 316 |
| image_encoder.* | MobileNetV4-Medium | 462 |
| audio_encoder.* | EfficientAT mn20_as | 312 |
| image_projection.* | Deep projection head | 14 |
| audio_projection.* | Deep projection head | 14 |

Training

  • Loss: InfoNCE (contrastive) with Matryoshka Representation Learning
  • Data: ~4.8M synthetically generated trimodal triplets (text, image, audio)
  • Hardware: 2x NVIDIA L4 GPUs
  • Optimizer: AdamW, lr=1e-3 (projections), weight decay=1e-4
  • Epochs: ~22 (early stopping on validation recall)
  • Projection heads only β€” encoders and Gemma are frozen during training

Design decisions

  • Frozen source encoders: MobileNetV4 and EfficientAT are kept frozen; only projection heads are trained via distillation into Gemma's space
  • Deep projection heads: Residual MLPs with dropout outperformed shallow 2-layer heads, especially for audio
  • Matryoshka weighting: Higher weight on smaller dimensions (4x at 128-dim) ensures quality at aggressive truncation levels
  • Edge-first: Source encoders chosen for edge deployment β€” MobileNetV4-Medium and EfficientAT mn20 can run on devices like Raspberry Pi 5

*EBind's HuggingFace checkpoint is 8.93M parameters (bridge heads only), but inference requires frozen backbones (SigLIP ViT-L, CLAP HTSAT, text encoder) totaling 1.78B loaded parameters as measured by our benchmark harness.

Limitations

  • Audio retrieval lags behind specialist models like CLAP on audio-only benchmarks
  • Image-text retrieval trades some accuracy vs larger vision encoders (SigLIP, CLIP ViT-L) for edge deployability
  • Trained primarily on synthetic data β€” real-world distribution shifts may affect performance
  • Text modality requires the full embeddinggemma-300M model (307M params, bfloat16)

License

Apache 2.0
