---
language:
- en
license: apache-2.0
tags:
- multimodal
- embedding
- matryoshka
- trimodal
- image-text-audio
- retrieval
- cross-modal
- edge
- rag
library_name: safetensors
pipeline_tag: feature-extraction
base_model:
- google/embeddinggemma-300m
datasets:
- custom
---

# TEG-421M: Trimodal Embeddings Gemma

**TEG** (Trimodal Embeddings Gemma) maps image and audio into the same embedding space as text, enabling cross-modal retrieval with a single vector index. All three modalities share a unified 768-dim space via [Google's embeddinggemma-300M](https://huggingface.co/google/embeddinggemma-300m), with full Matryoshka truncation support down to 128 dims.

> Also available in [GGUF format](https://huggingface.co/augmem/teg-421m-gguf) for quantized edge deployment.

## Architecture

TEG combines lightweight edge encoders with deep projection heads that distill into Gemma's embedding space:

```
Text  ──→ embeddinggemma-300M ──────────────────────→ 768-dim (L2-normalized)
Image ──→ MobileNetV4-Medium (1280-d) ──→ DeepProjectionHead ──→ 768-dim
Audio ──→ EfficientAT mn20_as (1920-d) ──→ DeepProjectionHead ──→ 768-dim
```

| Component | Architecture | Params | Size |
|---|---|---|---|
| Text encoder | embeddinggemma-300M (bfloat16) | 307.6M | 586.7 MB |
| Image encoder | MobileNetV4-Medium (timm) | 8.5M | 32.4 MB |
| Audio encoder | EfficientAT mn20_as | 18.0M | 68.5 MB |
| Image projection | DeepProjectionHead (1280 → 768) | 42.0M | 160.1 MB |
| Audio projection | DeepProjectionHead (1920 → 768) | 44.6M | 170.1 MB |
| **Total** | | **420.6M** | **1017.8 MB** |

### Projection head detail

Each `DeepProjectionHead` is a residual MLP:

```
Linear(encoder_dim, 4096) → GELU → LayerNorm → Dropout(0.3)
    → Linear(4096, 4096) → GELU → LayerNorm → Dropout(0.3) + residual
    → Linear(4096, 768)
```

### Matryoshka dimensions

Embeddings can be truncated to `[768, 512, 256, 128]` dimensions while preserving retrieval quality. The model is trained with Matryoshka Representation Learning (MRL) using
weights `[1.0, 1.0, 2.0, 4.0]` (one per dimension, in the same order).

## Benchmarks

All benchmarks were run on a single NVIDIA L4 GPU with 5K samples where applicable.

### Cross-modal retrieval: SALT (5K trimodal samples)

| Direction | TEG-421M (421M) | LCO-3B (4.7B) | Nemotron-3B (4.7B) | ImageBind (1.2B) | EBind |
|---|---|---|---|---|---|
| Text → Image R@1 | 0.672 | 0.660 | 0.529 | 0.712 | **0.779** |
| Image → Text R@1 | 0.620 | 0.564 | 0.299 | 0.736 | **0.783** |
| Text → Audio R@1 | **0.113** | 0.042 | 0.018 | 0.038 | 0.047 |
| Audio → Text R@1 | **0.115** | 0.032 | 0.010 | 0.039 | 0.035 |
| Audio → Image R@1 | **0.081** | 0.027 | 0.016 | 0.023 | 0.027 |
| Image → Audio R@1 | **0.083** | 0.034 | 0.018 | 0.025 | 0.032 |

TEG leads all audio cross-modal directions by roughly 2-10x over models that are 3-11x larger. Image↔Audio improved ~40% over v1 via joint cross-modal training. Vision-text trails EBind/ImageBind but uses encoders small enough for edge deployment.

### Audio retrieval: AudioCaps & Clotho

| Benchmark | Direction | TEG-421M | LCO-3B | Nemotron-3B | CLAP-Small | CLAP-Large | ImageBind | EBind |
|---|---|---|---|---|---|---|---|---|
| AudioCaps | A→T R@1 | 0.159 | 0.250 | 0.050 | **0.425** | 0.420 | 0.116 | 0.225 |
| AudioCaps | T→A R@1 | 0.149 | 0.215 | 0.075 | **0.315** | 0.280 | 0.080 | 0.219 |
| Clotho | A→T R@1 | 0.168 | 0.178 | 0.038 | 0.166 | **0.195** | 0.061 | 0.088 |
| Clotho | T→A R@1 | 0.123 | **0.187** | 0.070 | 0.159 | 0.167 | 0.074 | 0.118 |

CLAP models lead on audio-only benchmarks (they are audio specialists with no image support). Among trimodal models, TEG is competitive with LCO on Clotho A→T (0.168 vs. 0.178) while being 11x smaller.

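
The R@1 columns in these tables are recall-at-1 scores. As a rough illustration of the metric (not the benchmark harness that produced the numbers above), the sketch below assumes exactly one ground-truth match per query, stored at the same index in the gallery. The helper names `l2_normalize` and `recall_at_k` are hypothetical, the toy vectors are made up, and the code is plain Python so it runs without dependencies; a real evaluation would normally use a batched matrix product instead.

```python
from math import sqrt


def l2_normalize(vec):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def recall_at_k(queries, gallery, k=1):
    """Fraction of queries whose paired item (same index) ranks in the top k.

    Assumption: queries[i] and gallery[i] are the only matching pair.
    Datasets with several captions per item need a different matching rule.
    """
    q = [l2_normalize(v) for v in queries]
    g = [l2_normalize(v) for v in gallery]
    hits = 0
    for i, qv in enumerate(q):
        # Cosine similarity reduces to a dot product after L2 normalization.
        sims = [sum(a * b for a, b in zip(qv, gv)) for gv in g]
        # Indices of the k most similar gallery items, best first.
        top_k = sorted(range(len(g)), key=lambda j: sims[j], reverse=True)[:k]
        if i in top_k:
            hits += 1
    return hits / len(q)


# Toy data: three "text" queries and three "audio" gallery vectors.
text_emb = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
audio_emb = [[0.9, 0.1], [0.8, 0.3], [0.6, 0.7]]

print("R@1:", recall_at_k(text_emb, audio_emb, k=1))
print("R@2:", recall_at_k(text_emb, audio_emb, k=2))
```

On the toy data, two of the three queries retrieve their pair at rank 1 and all three do so within the top 2. Swapping the `queries` and `gallery` arguments gives the reverse direction (for example A→T instead of T→A).
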
### Image-text retrieval: MSCOCO & Flickr30k

| Benchmark | Direction | TEG-421M (421M) | EBind (1.78B*) | ImageBind (1.2B) | LCO-3B (4.7B) | Nemotron-3B (4.7B) |
|---|---|---|---|---|---|---|
| MSCOCO 5K | I→T R@1 | 0.248 | **0.743** | 0.658 | 0.533 | 0.225 |
| MSCOCO 5K | T→I R@1 | 0.180 | **0.559** | 0.490 | 0.469 | 0.334 |
| MSCOCO 5K | I→T R@10 | 0.622 | **0.948** | 0.918 | 0.784 | 0.630 |
| Flickr30k | I→T R@1 | 0.498 | n/a | n/a | **0.840** | 0.419 |
| Flickr30k | T→I R@1 | 0.358 | n/a | n/a | **0.765** | 0.563 |

TEG's image-text retrieval trades accuracy for edge deployability: MobileNetV4-Medium is ~100x smaller than the ViT-H/ViT-L encoders used by competitors. On MSCOCO, TEG outperforms Nemotron-3B on I→T R@1 (0.248 vs. 0.225) despite being 11x smaller.

### Zero-shot classification: ESC-50

| Model | Params | Accuracy |
|---|---|---|
| CLAP-Large | 67.8M | **0.905** |
| LCO-3B | 4.7B | 0.853 |
| TEG-421M | 421M | 0.820 |
| EBind | 1.78B* | 0.770 |
| CLAP-Small | 27.5M | 0.751 |
| Nemotron-3B | 4.7B | 0.727 |
| ImageBind | 1.2B | 0.664 |

## Usage

### Loading components

```python
from safetensors.torch import load_file

# Load entire model
tensors = load_file("teg-421m.safetensors")


def extract(prefix):
    """Return the tensors under `prefix`, with the prefix stripped from each key."""
    # str.removeprefix requires Python 3.9+
    return {k.removeprefix(prefix): v for k, v in tensors.items() if k.startswith(prefix)}


# Extract components by prefix
gemma_sd = extract("gemma.")
image_enc_sd = extract("image_encoder.")
audio_enc_sd = extract("audio_encoder.")
image_proj_sd = extract("image_projection.")
audio_proj_sd = extract("audio_projection.")
```

### Reading metadata

```python
from safetensors import safe_open

with safe_open("teg-421m.safetensors", framework="pt") as f:
    metadata = f.metadata()
    print(metadata)
    # Keys: format, version, text_model, embed_dim, image_encoder_name,
    # image_encoder_dim, audio_encoder_name, audio_encoder_dim,
    # audio_sample_rate, matryoshka_dims, total_parameters, ...
```

### Matryoshka truncation

```python
import torch
import torch.nn.functional as F

# Stand-in for a batch of full 768-dim TEG embeddings, shape (N, 768).
# In practice this comes from the text encoder or a projection head.
embedding = F.normalize(torch.randn(4, 768), dim=-1)

# Truncate to 256-dim and re-normalize
embedding_256 = F.normalize(embedding[:, :256], dim=-1)
```

## File layout

```
teg-421m.safetensors    # All components in one file (~1 GB)
```

### Tensor key prefixes

| Prefix | Component | Tensors |
|---|---|---|
| `gemma.*` | embeddinggemma-300M (bfloat16) | 316 |
| `image_encoder.*` | MobileNetV4-Medium | 462 |
| `audio_encoder.*` | EfficientAT mn20_as | 312 |
| `image_projection.*` | Deep projection head | 14 |
| `audio_projection.*` | Deep projection head | 14 |

## Training

- **Loss**: InfoNCE (contrastive) with Matryoshka Representation Learning
- **Data**: ~4.8M synthetically generated trimodal triplets (text, image, audio)
- **Hardware**: 2x NVIDIA L4 GPUs
- **Optimizer**: AdamW, lr=1e-3 (projections), weight decay=1e-4
- **Epochs**: ~22 (early stopping on validation recall)
- **Projection heads only**: encoders and Gemma are frozen during training

### Design decisions

- **Frozen source encoders**: MobileNetV4 and EfficientAT are kept frozen; only projection heads are trained via distillation into Gemma's space
- **Deep projection heads**: Residual MLPs with dropout outperformed shallow 2-layer heads, especially for audio
- **Matryoshka weighting**: Higher weight on smaller dimensions (4x at 128-dim) ensures quality at aggressive truncation levels
- **Edge-first**: Source encoders were chosen for edge deployment; MobileNetV4-Medium and EfficientAT mn20_as can run on devices like Raspberry Pi 5

*\*EBind's [HuggingFace checkpoint](https://huggingface.co/encord-team/ebind-full) is 8.93M parameters (bridge heads only), but inference requires frozen backbones (SigLIP ViT-L, CLAP HTSAT, text encoder) totaling 1.78B loaded parameters as measured by our benchmark
harness.*

## Limitations

- Audio retrieval lags behind specialist models like CLAP on audio-only benchmarks
- Image-text retrieval trades some accuracy vs larger vision encoders (SigLIP, CLIP ViT-L) for edge deployability
- Trained primarily on synthetic data; real-world distribution shifts may affect performance
- Text modality requires the full embeddinggemma-300M model (307M params, bfloat16)

## Links

- **Website**: [augmem.ai](https://augmem.ai)
- **GitHub**: [github.com/augmem](https://github.com/augmem)
- **GGUF variant**: [augmem/teg-421m-gguf](https://huggingface.co/augmem/teg-421m-gguf)

## License

Apache 2.0