EmbeddingGemma-300M: Depth + Embedding Width Compressed

This is a compressed version of google/embeddinggemma-300m with roughly 10x depth reduction (24 → 2 layers, which is 12x exactly) and roughly 10x embedding width reduction (768 → 76 → 768).

Model Details

  • Base Model: google/embeddinggemma-300m
  • Compression:
    • Depth reduction: ~10x nominal (24 layers → 2 layers, i.e. 12x)
    • Embedding width reduction: ~10x (768 → 76 → 768)
  • Parameters: ~33M (down from ~300M)
  • Output Dimension: 768
  • Compression Ratio: ~9x
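
The ratios listed above follow from simple division. The short sketch below recomputes them from the figures in this card; the variable names are illustrative only, and the parameter counts are the approximate values quoted above rather than measured numbers.

```python
# Figures taken from the model card above (approximate where marked with ~).
base_layers, compressed_layers = 24, 2
hidden_size, bottleneck_width = 768, 76
base_params, compressed_params = 300e6, 33e6  # ~300M and ~33M

depth_ratio = base_layers / compressed_layers   # layers removed by pruning
width_ratio = hidden_size / bottleneck_width    # embedding bottleneck factor
param_ratio = base_params / compressed_params   # overall size reduction

print(f"Depth reduction: {depth_ratio:.1f}x")
print(f"Embedding width reduction: {width_ratio:.1f}x")
print(f"Overall parameter compression: {param_ratio:.1f}x")
```

Note that the depth ratio works out to 12x, slightly more than the nominal 10x in the model name.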

Installation

pip install torch sentence-transformers transformers

Usage

Simple Usage with SentenceTransformers (Recommended)

The easiest way to use this model is with the sentence-transformers library:

from sentence_transformers import SentenceTransformer
import torch

# Load the model
model = SentenceTransformer(
    "Pieces/embeddinggemma-300m-distilled-depth10pct-10-768dim-best",
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Encode texts
texts = ["Hello world", "This is a test"]
embeddings = model.encode(texts, convert_to_tensor=True)

print(f"Embeddings shape: {embeddings.shape}")
# Output: torch.Size([2, 768])

Advanced Usage (Full Model Access)

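If you need access to the full model structure, load the checkpoint with the project's own training code. This requires the repository that provides the playground and tags_model packages; they are not installed by the pip command above: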

import torch
from pathlib import Path
import sys

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from playground.validate_from_checkpoint import load_trained_model
from tags_model.training.train_distillation import _get_transformer_layers

# Load the model
model, config = load_trained_model(
    checkpoint_path="Pieces/embeddinggemma-300m-distilled-depth10pct-10-768dim-best",
    device="cuda" if torch.cuda.is_available() else "cpu",
    compile_model=False,
)

# Verify structure
transformer = model.backbone.transformer
layers = _get_transformer_layers(transformer)
layer_count = len(layers) if layers else 0
total_params = sum(p.numel() for p in model.parameters())

print("Model structure:")
print(f"  - Layers: {layer_count}")
print(f"  - Parameters: {total_params:,}")

# Encode texts
model.eval()
with torch.no_grad():
    texts = ["Hello world", "This is a test"]
    
    # Handle Identity projection (when output_dim == embedding_dim)
    if hasattr(model.backbone, 'projection'):
        proj = model.backbone.projection
        is_identity = (
            isinstance(proj, torch.nn.Identity) or
            'Identity' in proj.__class__.__name__
        )
        
        if is_identity:
            # Identity projection - encode directly from transformer
            embeddings = model.backbone.transformer.encode_texts(
                texts, max_length=512, return_dict=False
            )
        else:
            # Use standard encode_texts which handles projection
            embeddings = model.backbone.encode_texts(
                texts, max_length=512, return_dict=False
            )
    else:
        embeddings = model.backbone.encode_texts(
            texts, max_length=512, return_dict=False
        )

print(f"Embeddings shape: {embeddings.shape}")

Query-Tag Retrieval Example

from sentence_transformers import SentenceTransformer
import torch

# Load model
model = SentenceTransformer(
    "Pieces/embeddinggemma-300m-distilled-depth10pct-10-768dim-best"
)

def compute_similarities(query_embeddings, tag_embeddings):
    """Compute cosine similarities between queries and tags."""
    query_norm = query_embeddings / (query_embeddings.norm(dim=1, keepdim=True) + 1e-8)
    tag_norm = tag_embeddings / (tag_embeddings.norm(dim=1, keepdim=True) + 1e-8)
    return torch.mm(query_norm, tag_norm.t())

# Example queries and tags
queries = [
    "How to implement authentication in a web application?",
    "What are the best practices for database optimization?",
]

tags = [
    "authentication", "security", "web-development",
    "database", "sql", "performance", "optimization",
    "machine-learning", "deployment", "production",
]

# Encode queries and tags
query_embeddings = model.encode(queries, convert_to_tensor=True)
tag_embeddings = model.encode(tags, convert_to_tensor=True)

# Compute similarities
similarities = compute_similarities(query_embeddings, tag_embeddings)

# Get top tags for each query
for query_idx, query in enumerate(queries):
    top_k = 3
    top_similarities, top_indices = torch.topk(
        similarities[query_idx], k=top_k, dim=0
    )
    
    print(f"\nQuery: {query}")
    for rank, (tag_idx, sim) in enumerate(
        zip(top_indices.cpu().tolist(), top_similarities.cpu().tolist()), start=1
    ):
        print(f"  {rank}. {tags[tag_idx]} (similarity: {sim:.4f})")

See usage_example.py in this repository for a complete standalone example.
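
To see what the ranking step in the retrieval example does without downloading the model, here is a minimal standard-library sketch of the same cosine-similarity and top-k logic. The three-dimensional vectors are hand-picked toy values, not real embeddings, and `cosine_similarity` and `top_k_tags` are hypothetical helper names.

```python
import math

def cosine_similarity(a, b, eps=1e-8):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    # eps guards against division by zero, as in compute_similarities above.
    return dot / ((norm_a + eps) * (norm_b + eps))

def top_k_tags(query_vec, tag_vecs, k=2):
    """Return the k (tag, score) pairs with the highest cosine similarity."""
    scored = [(tag, cosine_similarity(query_vec, vec)) for tag, vec in tag_vecs.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]

# Hypothetical toy "embeddings", chosen by hand for illustration only.
toy_tags = {
    "authentication": [1.0, 0.1, 0.0],
    "database": [0.0, 1.0, 0.1],
    "deployment": [0.1, 0.0, 1.0],
}
toy_query = [0.9, 0.2, 0.0]

ranked = top_k_tags(toy_query, toy_tags, k=2)
for rank, (tag, score) in enumerate(ranked, start=1):
    print(f"{rank}. {tag} (similarity: {score:.4f})")
```

With real embeddings the only change is the source of the vectors; the normalization and ranking are the same.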

Model Architecture

  • Transformer Layers: 2 (pruned from 24, keeping last 2 layers)
  • Hidden Size: 768
  • Embedding Width Reduction: Enabled (768 → 76 → 768)
  • Output Dimension: 768
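
This card does not spell out how the 768 → 76 → 768 reduction is implemented. One plausible reading is a factorized embedding: a narrow lookup table of width 76 followed by a linear projection back up to the hidden size of 768. The sketch below counts parameters under that assumption. The vocabulary size of 262,144 is an assumption not stated in this card, and both helper functions are hypothetical.

```python
def full_embedding_params(vocab_size, hidden_size):
    """Parameters in a standard vocab_size x hidden_size embedding table."""
    return vocab_size * hidden_size

def factorized_embedding_params(vocab_size, bottleneck, hidden_size, bias=False):
    """Parameters in a narrow table plus a linear up-projection."""
    table = vocab_size * bottleneck
    projection = bottleneck * hidden_size + (hidden_size if bias else 0)
    return table + projection

VOCAB_SIZE = 262_144  # assumption: not stated in this card

full = full_embedding_params(VOCAB_SIZE, 768)
factorized = factorized_embedding_params(VOCAB_SIZE, 76, 768)

print(f"Full embedding table:      {full:,}")
print(f"Factorized (width 76):     {factorized:,}")
print(f"Embedding size reduction:  {full / factorized:.1f}x")
```

Under these assumptions the token embedding dominates the parameter count, which is why narrowing it has a large effect on total model size.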

Performance

This compressed model maintains reasonable performance at roughly one ninth of the original parameter count:

  • Parameters: 33M (vs 300M original)
  • Memory: ~132MB of FP32 weights (vs ~1.2GB original)
  • Speed: Faster inference due to fewer layers
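
The memory figures above are consistent with storing each parameter as a 32-bit float (4 bytes). The sketch below reproduces that arithmetic; it assumes FP32 weights and decimal megabytes, covers weights only (no activations or framework overhead), and `weight_memory_mb` is a hypothetical helper.

```python
BYTES_PER_FP32 = 4

def weight_memory_mb(num_params, bytes_per_param=BYTES_PER_FP32):
    """Approximate weight storage in decimal megabytes (1 MB = 1e6 bytes)."""
    return num_params * bytes_per_param / 1e6

compressed_mb = weight_memory_mb(33e6)   # ~33M parameters
base_mb = weight_memory_mb(300e6)        # ~300M parameters

print(f"Compressed model: ~{compressed_mb:.0f} MB")
print(f"Base model:       ~{base_mb / 1000:.1f} GB")
```

Loading the weights in a 16-bit format would halve both numbers, at the cost of reduced numerical precision.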

Citation

If you use this model, please cite:

@misc{embeddinggemma-compressed-depth-emb,
  title={EmbeddingGemma-300M: Depth + Embedding Width Compressed},
  author={Pieces},
  year={2024},
  url={https://huggingface.co/Pieces/embeddinggemma-300m-distilled-depth10pct-10-768dim-best}
}