# gemma-frozen-512-step64000
This is a fine-tuned contrastive learning model for tag classification.
## Model Description

Fine-tuned Gemma model (frozen backbone) at training step 64000, trained for persona-conditioned tag classification with separate high- and low-connectivity specialization.
## Training Details

- **Checkpoint Path:** `output/contrastive_gemma_frozen_512/checkpoint_step_64000`
- **Base Model:** `google/embeddinggemma-300m`
## Performance

### Validation Metrics

| Branch | Mean Rank (1st positive) | Hit@1 | Hit@10 | MRR |
|---|---|---|---|---|
| High Connectivity / Base | 1.37 | 82.00% | 100.00% | 0.8860 |
| High Connectivity / Personalized | 1.34 | 83.00% | 100.00% | 0.8945 |
| Low Connectivity / Base | 2.01 | 65.00% | 99.00% | 0.7707 |
| Low Connectivity / Personalized | 2.07 | 63.00% | 99.00% | 0.7564 |
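For reference, the metrics above are all derived from the rank of the first positive tag for each validation query. A minimal sketch of how such metrics are computed (illustrative only, not the project's evaluation code):

```python
def mean_rank(ranks):
    # Average rank of the first positive tag across queries
    return sum(ranks) / len(ranks)

def hit_at_k(ranks, k):
    # Fraction of queries whose first positive tag ranks within the top k
    return sum(r <= k for r in ranks) / len(ranks)

def mrr(ranks):
    # Mean reciprocal rank of the first positive tag
    return sum(1.0 / r for r in ranks) / len(ranks)

# Toy example: first-positive ranks for five queries
ranks = [1, 1, 2, 1, 5]
print(mean_rank(ranks))    # 2.0
print(hit_at_k(ranks, 1))  # 0.6
print(mrr(ranks))          # 0.74
```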
## Usage

See the repository `README.md` for detailed usage examples.
### Usage Example

```python
"""
Example: Using the fine-tuned Gemma model for tag classification.

This example shows how to use the fine-tuned Gemma model for
persona-conditioned tag classification using our module abstractions.

Installation:
    pip install git+https://github.com/Pieces/TAG-module.git@main
    # Or: pip install -e .

To use this model:
    1. Download model.pt and config.json from the Hub
    2. Place them in a directory (e.g., ./checkpoint/)
    3. Update checkpoint_path below
"""
import torch
from pathlib import Path

from playground.validate_from_checkpoint import (
    load_trained_model,
    encode_query,
    encode_general_tags,
    compute_ranked_tags,
)

# Load the fine-tuned model
print("Loading fine-tuned Gemma model...")

# Option 1: Load from a local checkpoint (if you have it)
checkpoint_path = "output/contrastive_gemma_frozen_512/checkpoint_step_64000"

# Option 2: Load from downloaded Hub files:
#   1. Download model.pt and config.json from:
#      https://huggingface.co/Pieces/gemma-frozen-512-step64000
#   2. Place them in a directory and update the path:
#      checkpoint_path = "./downloaded_checkpoint"

if not Path(checkpoint_path).exists():
    print(f"⚠ Checkpoint not found at: {checkpoint_path}")
    print("Please download model.pt and config.json from the Hub and update checkpoint_path")
    raise SystemExit(1)

model, config = load_trained_model(checkpoint_path, device="cpu")
model.eval()
print("✓ Model loaded!")

# Example query with persona
query_text = "How to implement OAuth2 authentication in a Python Flask API?"
persona_text = "I'm a backend developer working on web APIs and microservices"

# Candidate tags to rank
candidate_tags = [
    "python", "flask", "oauth2", "authentication", "api",
    "security", "web-development", "jwt", "rest-api", "backend",
    "microservices", "fastapi", "django",
]

print(f"\nQuery: {query_text}")
print(f"Persona: {persona_text}")
print(f"Candidate tags: {candidate_tags}\n")

# Encode the query using the base branch (no persona conditioning)
print("Encoding query (base branch)...")
with torch.inference_mode():
    query_emb_base = encode_query(
        model=model,
        query_text=query_text,
        persona_text=persona_text,
        connectivity="high",
        branch="base",
        max_length=512,
        use_pretrained_backbone=False,
        extraction_mode="full",
    )

# Encode the query using the personalized branch (with persona conditioning)
print("Encoding query (personalized branch)...")
with torch.inference_mode():
    query_emb_personalized = encode_query(
        model=model,
        query_text=query_text,
        persona_text=persona_text,
        connectivity="high",
        branch="personalized",  # Use the personalized branch
        max_length=512,
        use_pretrained_backbone=False,
        extraction_mode="full",
    )

# Encode the candidate tags with both branches
print("Encoding tags...")
with torch.inference_mode():
    tag_embs_base = encode_general_tags(
        model=model,
        general_tags=candidate_tags,
        connectivity="high",
        branch="base",
        persona_text=persona_text,
        max_length=512,
        use_pretrained_backbone=False,
        extraction_mode="full",
    )
    tag_embs_personalized = encode_general_tags(
        model=model,
        general_tags=candidate_tags,
        connectivity="high",
        branch="personalized",  # Use the personalized branch
        persona_text=persona_text,
        max_length=512,
        use_pretrained_backbone=False,
        extraction_mode="full",
    )

print(f"Query embedding shape: {query_emb_base.shape}")
print(f"Tag embeddings shape: {tag_embs_base.shape}")

# Rank tags using the base branch
print("\n" + "=" * 60)
print("Rankings (Base Branch - No Persona Conditioning):")
print("=" * 60)
ranked_tags_base = compute_ranked_tags(
    query_emb=query_emb_base,
    pos_embs=torch.empty(0, model.config.backbone_embedding_dim),
    neg_embs=torch.empty(0, model.config.backbone_embedding_dim),
    general_embs=tag_embs_base,
    positive_tags=[],
    negative_tags=[],
    general_tags=candidate_tags,
)
for tag, rank, label, score in ranked_tags_base[:5]:
    print(f"{rank:2d}. {tag:20s} (score: {score:.4f})")

# Rank tags using the personalized branch
print("\n" + "=" * 60)
print("Rankings (Personalized Branch - With Persona Conditioning):")
print("=" * 60)
ranked_tags_personalized = compute_ranked_tags(
    query_emb=query_emb_personalized,
    pos_embs=torch.empty(0, model.config.backbone_embedding_dim),
    neg_embs=torch.empty(0, model.config.backbone_embedding_dim),
    general_embs=tag_embs_personalized,
    positive_tags=[],
    negative_tags=[],
    general_tags=candidate_tags,
)
for tag, rank, label, score in ranked_tags_personalized[:5]:
    print(f"{rank:2d}. {tag:20s} (score: {score:.4f})")

print("\n" + "=" * 60)
print("Example complete!")
print("\nNote: The personalized branch adapts tag rankings based on the persona context.")
```
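At its core, ranking candidate tags against a query embedding reduces to a similarity sort. A minimal, dependency-free sketch of that idea (assuming cosine similarity; this is not the module's `compute_ranked_tags`, and the toy vectors are hypothetical):

```python
import math

def cosine(a, b):
    # Cosine similarity between two equal-length vectors
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def rank_tags(query_emb, tag_embs, tags):
    # Sort tags by descending similarity to the query; return (rank, tag, score)
    scored = sorted(zip(tags, (cosine(query_emb, e) for e in tag_embs)),
                    key=lambda item: item[1], reverse=True)
    return [(i + 1, tag, score) for i, (tag, score) in enumerate(scored)]

# Toy 3-d embeddings (illustrative values only)
query = [0.9, 0.1, 0.0]
tag_vectors = {
    "oauth2": [1.0, 0.0, 0.1],
    "flask":  [0.5, 0.5, 0.0],
    "django": [0.0, 1.0, 0.2],
}
for rank, tag, score in rank_tags(query, list(tag_vectors.values()), list(tag_vectors)):
    print(f"{rank}. {tag:8s} {score:.4f}")
```

The real model produces the query and tag embeddings with its base or personalized branch; only the final sort-by-similarity step is shown here.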
### Running the Example

```bash
# Install the repository first
pip install git+https://github.com/Pieces/TAG-module.git@main
# Or for local development:
pip install -e .

# Run the example
python gemma_finetuned_example.py
```
## Citation

If you use this model, please cite:

```bibtex
@software{tag_module,
  title = {TAG Module: Persona-Conditioned Contrastive Learning for Tag Classification},
  author = {Your Name},
  year = {2025},
  url = {https://github.com/yourusername/tag-module}
}
```
## License

Please refer to the license of the backbone model, `google/embeddinggemma-300m`.