EmbeddingGemma-300M: Depth + Embedding Width Compressed
This is a compressed version of google/embeddinggemma-300m with 10x depth reduction (24 → 2 layers) and 10x embedding width reduction (768 → 76 → 768).
Model Details
- Base Model: google/embeddinggemma-300m
- Compression (see the sketch after this list):
  - Depth reduction: 10x (24 layers → 2 layers)
  - Embedding width reduction: 10x (768 → 76 → 768)
- Parameters: ~33M (down from ~300M)
- Output Dimension: 768
- Compression Ratio: ~9x
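Conceptually, a 768 → 76 → 768 width reduction replaces one wide projection with two thin ones. The sketch below is hypothetical: the checkpoint's actual module names and placement may differ, but it illustrates where the parameter savings come from.

import torch
import torch.nn as nn

# Hypothetical bottleneck: project 768-dim states down to 76 dims,
# then back up to the original 768-dim output width.
bottleneck = nn.Sequential(
    nn.Linear(768, 76),   # down-projection (768 -> 76)
    nn.Linear(76, 768),   # up-projection (76 -> 768)
)

x = torch.randn(2, 768)       # e.g. two pooled sentence embeddings
print(bottleneck(x).shape)    # torch.Size([2, 768])

Two thin matrices (768×76 and 76×768) hold far fewer parameters than one 768×768 matrix, which is the idea behind the width reduction.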
Installation
pip install torch sentence-transformers transformers
Usage
Simple Usage with SentenceTransformers (Recommended)
The easiest way to use this model is with the sentence-transformers library:
from sentence_transformers import SentenceTransformer
import torch

# Load the model
model = SentenceTransformer(
    "Pieces/embeddinggemma-300m-distilled-depth10pct-10-768dim-best",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Encode texts
texts = ["Hello world", "This is a test"]
embeddings = model.encode(texts, convert_to_tensor=True)
print(f"Embeddings shape: {embeddings.shape}")
# Output: torch.Size([2, 768])
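Continuing from the snippet above, the encoded embeddings can be compared directly with the library's built-in cosine-similarity helper:

from sentence_transformers import util

# Pairwise cosine similarities between the two texts encoded above
similarity_matrix = util.cos_sim(embeddings, embeddings)
print(similarity_matrix)  # shape (2, 2); the diagonal is 1.0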
Advanced Usage (Full Model Access)
If you need access to the full model structure:
import sys
from pathlib import Path

import torch

# Add the parent directory to the path so the project modules can be imported
sys.path.insert(0, str(Path(__file__).parent.parent))

from playground.validate_from_checkpoint import load_trained_model
from tags_model.training.train_distillation import _get_transformer_layers

# Load the model from the checkpoint
model, config = load_trained_model(
    checkpoint_path="Pieces/embeddinggemma-300m-distilled-depth10pct-10-768dim-best",
    device="cuda" if torch.cuda.is_available() else "cpu",
    compile_model=False,
)

# Verify the compressed structure
transformer = model.backbone.transformer
layers = _get_transformer_layers(transformer)
layer_count = len(layers) if layers else 0
total_params = sum(p.numel() for p in model.parameters())
print("Model structure:")
print(f"  - Layers: {layer_count}")
print(f"  - Parameters: {total_params:,}")

# Encode texts
model.eval()
with torch.no_grad():
    texts = ["Hello world", "This is a test"]
    # Handle an Identity projection (used when output_dim == embedding_dim)
    if hasattr(model.backbone, "projection"):
        proj = model.backbone.projection
        is_identity = (
            isinstance(proj, torch.nn.Identity)
            or "Identity" in proj.__class__.__name__
        )
        if is_identity:
            # Identity projection: encode directly from the transformer
            embeddings = model.backbone.transformer.encode_texts(
                texts, max_length=512, return_dict=False
            )
        else:
            # Standard path: encode_texts applies the projection
            embeddings = model.backbone.encode_texts(
                texts, max_length=512, return_dict=False
            )
    else:
        embeddings = model.backbone.encode_texts(
            texts, max_length=512, return_dict=False
        )

print(f"Embeddings shape: {embeddings.shape}")
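As a quick fidelity check, you can compare this model's embeddings against the original teacher, google/embeddinggemma-300m. A minimal sketch, assuming both checkpoints load through SentenceTransformer:

from sentence_transformers import SentenceTransformer, util

student = SentenceTransformer("Pieces/embeddinggemma-300m-distilled-depth10pct-10-768dim-best")
teacher = SentenceTransformer("google/embeddinggemma-300m")

texts = ["Hello world", "This is a test"]
student_emb = student.encode(texts, convert_to_tensor=True)
teacher_emb = teacher.encode(texts, convert_to_tensor=True)

# Per-text cosine similarity between student and teacher embeddings;
# values close to 1.0 indicate the distilled model tracks the teacher.
agreement = util.cos_sim(student_emb, teacher_emb).diagonal()
print(agreement)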
Query-Tag Retrieval Example
from sentence_transformers import SentenceTransformer
import torch

# Load the model
model = SentenceTransformer(
    "Pieces/embeddinggemma-300m-distilled-depth10pct-10-768dim-best"
)

def compute_similarities(query_embeddings, tag_embeddings):
    """Compute cosine similarities between queries and tags."""
    query_norm = query_embeddings / (query_embeddings.norm(dim=1, keepdim=True) + 1e-8)
    tag_norm = tag_embeddings / (tag_embeddings.norm(dim=1, keepdim=True) + 1e-8)
    return torch.mm(query_norm, tag_norm.t())

# Example queries and tags
queries = [
    "How to implement authentication in a web application?",
    "What are the best practices for database optimization?",
]
tags = [
    "authentication", "security", "web-development",
    "database", "sql", "performance", "optimization",
    "machine-learning", "deployment", "production",
]

# Encode queries and tags
query_embeddings = model.encode(queries, convert_to_tensor=True)
tag_embeddings = model.encode(tags, convert_to_tensor=True)

# Compute similarities
similarities = compute_similarities(query_embeddings, tag_embeddings)

# Print the top-3 tags for each query
for query_idx, query in enumerate(queries):
    top_k = 3
    top_similarities, top_indices = torch.topk(
        similarities[query_idx], k=top_k, dim=0
    )
    print(f"\nQuery: {query}")
    for rank, (tag_idx, sim) in enumerate(
        zip(top_indices.cpu().tolist(), top_similarities.cpu().tolist()), start=1
    ):
        print(f"  {rank}. {tags[tag_idx]} (similarity: {sim:.4f})")
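Since cosine similarity is just a dot product of unit-length vectors, the manual normalization in compute_similarities can also be delegated to encode via its normalize_embeddings flag:

# Equivalent result, letting encode() L2-normalize the embeddings
query_embeddings = model.encode(queries, convert_to_tensor=True, normalize_embeddings=True)
tag_embeddings = model.encode(tags, convert_to_tensor=True, normalize_embeddings=True)
similarities = query_embeddings @ tag_embeddings.T  # cosine similarities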
See usage_example.py in this repository for a complete standalone example.
Model Architecture
- Transformer Layers: 2 (pruned from 24, keeping the last 2 layers)
- Hidden Size: 768
- Embedding Width Reduction: Enabled (768 → 76 → 768)
- Output Dimension: 768
Performance
This compressed model maintains reasonable performance while being significantly smaller:
- Parameters: ~33M (vs ~300M for the original)
- Memory: ~132MB in fp32 (33M parameters × 4 bytes, vs ~1.2GB for the original)
- Speed: Faster inference, since only 2 of the original 24 transformer layers are executed
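To verify the speed claim on your own hardware, here is a minimal (unscientific) timing sketch; the batch size and texts are arbitrary placeholders:

import time
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("Pieces/embeddinggemma-300m-distilled-depth10pct-10-768dim-best")
texts = ["example sentence"] * 64  # arbitrary batch for timing

model.encode(texts)  # warm-up so one-time setup doesn't skew the measurement

start = time.perf_counter()
model.encode(texts)
elapsed = time.perf_counter() - start
print(f"Encoded {len(texts)} texts in {elapsed:.3f}s")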
Citation
If you use this model, please cite:
@misc{embeddinggemma-compressed-depth-emb,
  title={EmbeddingGemma-300M: Depth + Embedding Width Compressed},
  author={Pieces},
  year={2024},
  url={https://huggingface.co/Pieces/embeddinggemma-300m-distilled-depth10pct-10-768dim-best}
}