gemma-frozen-512-step64000

This is a fine-tuned contrastive learning model for tag classification.

Model Description

A Gemma embedding model fine-tuned with a frozen backbone; this checkpoint was saved at training step 64000. It is trained for persona-conditioned tag classification, with separate specialization for high- and low-connectivity tags.

Training Details

  • Checkpoint Path: output/contrastive_gemma_frozen_512/checkpoint_step_64000
  • Base Model: google/embeddinggemma-300m

Performance

Validation Metrics

High Connectivity

Base Branch:

  • Mean Rank (1st positive): 1.37
  • Hit@1: 82.00%
  • Hit@10: 100.00%
  • MRR: 0.8860

Personalized Branch:

  • Mean Rank (1st positive): 1.34
  • Hit@1: 83.00%
  • Hit@10: 100.00%
  • MRR: 0.8945

Low Connectivity

Base Branch:

  • Mean Rank (1st positive): 2.01
  • Hit@1: 65.00%
  • Hit@10: 99.00%
  • MRR: 0.7707

Personalized Branch:

  • Mean Rank (1st positive): 2.07
  • Hit@1: 63.00%
  • Hit@10: 99.00%
  • MRR: 0.7564
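
For readers unfamiliar with these metrics, the following is a minimal sketch of how Mean Rank, Hit@k and MRR are conventionally derived from the 1-indexed rank of the first positive tag per query. The function name `retrieval_metrics` and the example ranks are hypothetical; the repository's own evaluation code may differ in details such as tie handling.

```python
from typing import Dict, Sequence


def retrieval_metrics(
    first_positive_ranks: Sequence[int], ks: Sequence[int] = (1, 10)
) -> Dict[str, float]:
    """Summarise 1-indexed ranks of the first positive tag, one rank per query."""
    if not first_positive_ranks:
        raise ValueError("need at least one rank")
    n = len(first_positive_ranks)
    metrics = {
        # Average position of the first correct tag (lower is better).
        "mean_rank": sum(first_positive_ranks) / n,
        # Mean reciprocal rank (higher is better, at most 1.0).
        "mrr": sum(1.0 / r for r in first_positive_ranks) / n,
    }
    for k in ks:
        # Fraction of queries whose first correct tag is within the top k.
        metrics[f"hit@{k}"] = sum(r <= k for r in first_positive_ranks) / n
    return metrics


# Hypothetical ranks for five queries (not taken from this model's validation set).
example = retrieval_metrics([1, 1, 2, 1, 12])
print(example)
```

Because MRR weights early ranks heavily while Mean Rank is sensitive to outliers, the two can move in different directions between branches, so it helps to read them together.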

Usage

See the README.md for detailed usage examples; a complete script follows under Usage Example.
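
The usage example in this card expects `model.pt` and `config.json` in a local directory. One possible way to fetch them is with `hf_hub_download` from the `huggingface_hub` package, sketched below. The helper name `download_checkpoint` and the default target directory are illustrative, and the sketch assumes both files sit at the root of the `Pieces/gemma-frozen-512-step64000` repository.

```python
from pathlib import Path

REPO_ID = "Pieces/gemma-frozen-512-step64000"
CHECKPOINT_FILES = ("model.pt", "config.json")


def download_checkpoint(target_dir: str = "./downloaded_checkpoint") -> Path:
    """Download the checkpoint files into target_dir and return that directory."""
    # Imported lazily so this module still loads when huggingface_hub is absent.
    from huggingface_hub import hf_hub_download

    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    for filename in CHECKPOINT_FILES:
        # Assumption: each file lives at the repository root under this name.
        hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=target)
    return target
```

Calling `download_checkpoint()` and passing the returned path as `checkpoint_path` should then match "Option 2" in the script. A gated or private repository would additionally require an access token.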

Usage Example

"""
Example: Using Fine-tuned Gemma Model for Tag Classification

This example shows how to use the fine-tuned Gemma model for
persona-conditioned tag classification using our module abstractions.

Installation:
    pip install git+https://github.com/Pieces/TAG-module.git@main
    # Or: pip install -e .

To use this model:
    1. Download model.pt and config.json from the Hub
    2. Place them in a directory (e.g., ./checkpoint/)
    3. Update checkpoint_path below
"""

import torch
from pathlib import Path
from playground.validate_from_checkpoint import (
    load_trained_model,
    encode_query,
    encode_general_tags,
    compute_ranked_tags,
)

# Load the fine-tuned model
print("Loading fine-tuned Gemma model...")

# Option 1: Load from local checkpoint (if you have it)
checkpoint_path = "output/contrastive_gemma_frozen_512/checkpoint_step_64000"

# Option 2: Load from downloaded Hub files
# 1. Download model.pt and config.json from:
#    https://huggingface.co/Pieces/gemma-frozen-512-step64000
# 2. Place in a directory and update path:
# checkpoint_path = "./downloaded_checkpoint"

if not Path(checkpoint_path).exists():
    print(f"⚠ Checkpoint not found at: {checkpoint_path}")
    print("Please download model.pt and config.json from the Hub and update checkpoint_path")
    raise SystemExit(1)

model, config = load_trained_model(checkpoint_path, device="cpu")
model.eval()
print("✓ Model loaded!")

# Example query with persona
query_text = "How to implement OAuth2 authentication in a Python Flask API?"
persona_text = "I'm a backend developer working on web APIs and microservices"

# Candidate tags to rank
candidate_tags = [
    "python", "flask", "oauth2", "authentication", "api",
    "security", "web-development", "jwt", "rest-api", "backend",
    "microservices", "fastapi", "django"
]

print(f"\nQuery: {query_text}")
print(f"Persona: {persona_text}")
print(f"Candidate tags: {candidate_tags}\n")

# Encode query using base branch (no persona conditioning)
print("Encoding query (base branch)...")
with torch.inference_mode():
    query_emb_base = encode_query(
        model=model,
        query_text=query_text,
        persona_text=persona_text,
        connectivity="high",
        branch="base",
        max_length=512,
        use_pretrained_backbone=False,
        extraction_mode="full",
    )

# Encode query using personalized branch (with persona conditioning)
print("Encoding query (personalized branch)...")
with torch.inference_mode():
    query_emb_personalized = encode_query(
        model=model,
        query_text=query_text,
        persona_text=persona_text,
        connectivity="high",
        branch="personalized",  # Use personalized branch
        max_length=512,
        use_pretrained_backbone=False,
        extraction_mode="full",
    )

# Encode tags
print("Encoding tags...")
with torch.inference_mode():
    tag_embs_base = encode_general_tags(
        model=model,
        general_tags=candidate_tags,
        connectivity="high",
        branch="base",
        persona_text=persona_text,
        max_length=512,
        use_pretrained_backbone=False,
        extraction_mode="full",
    )
    
    tag_embs_personalized = encode_general_tags(
        model=model,
        general_tags=candidate_tags,
        connectivity="high",
        branch="personalized",  # Use personalized branch
        persona_text=persona_text,
        max_length=512,
        use_pretrained_backbone=False,
        extraction_mode="full",
    )

print(f"Query embedding shape: {query_emb_base.shape}")
print(f"Tag embeddings shape: {tag_embs_base.shape}")

# Rank tags using base branch
print("\n" + "="*60)
print("Rankings (Base Branch - No Persona Conditioning):")
print("="*60)
ranked_tags_base = compute_ranked_tags(
    query_emb=query_emb_base,
    pos_embs=torch.empty(0, model.config.backbone_embedding_dim),
    neg_embs=torch.empty(0, model.config.backbone_embedding_dim),
    general_embs=tag_embs_base,
    positive_tags=[],
    negative_tags=[],
    general_tags=candidate_tags,
)

for tag, rank, label, score in ranked_tags_base[:5]:
    print(f"{rank:2d}. {tag:20s} (score: {score:.4f})")

# Rank tags using personalized branch
print("\n" + "="*60)
print("Rankings (Personalized Branch - With Persona Conditioning):")
print("="*60)
ranked_tags_personalized = compute_ranked_tags(
    query_emb=query_emb_personalized,
    pos_embs=torch.empty(0, model.config.backbone_embedding_dim),
    neg_embs=torch.empty(0, model.config.backbone_embedding_dim),
    general_embs=tag_embs_personalized,
    positive_tags=[],
    negative_tags=[],
    general_tags=candidate_tags,
)

for tag, rank, label, score in ranked_tags_personalized[:5]:
    print(f"{rank:2d}. {tag:20s} (score: {score:.4f})")

print("\n" + "="*60)
print("Example complete!")
print("\nNote: Personalized branch adapts tag rankings based on the persona context.")
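
To see how much the persona actually changes the ordering, the two result lists can be compared tag by tag. The sketch below is one way to do that; it assumes each entry is a `(tag, rank, label, score)` tuple, as unpacked in the loops above. The helper `rank_shifts` and the sample rankings (including the `"general"` label value) are made up for illustration.

```python
def rank_shifts(base_ranked, personalized_ranked):
    """Return (tag, base_rank, personalized_rank, shift) rows, largest moves first.

    A positive shift means the tag moved up under the personalized branch.
    """
    base_rank = {tag: rank for tag, rank, _label, _score in base_ranked}
    rows = []
    for tag, rank, _label, _score in personalized_ranked:
        if tag in base_rank:
            rows.append((tag, base_rank[tag], rank, base_rank[tag] - rank))
    return sorted(rows, key=lambda row: abs(row[3]), reverse=True)


# Made-up rankings in the assumed (tag, rank, label, score) layout.
base = [
    ("oauth2", 1, "general", 0.91),
    ("flask", 2, "general", 0.88),
    ("backend", 3, "general", 0.70),
]
personalized = [
    ("oauth2", 1, "general", 0.93),
    ("backend", 2, "general", 0.85),
    ("flask", 3, "general", 0.80),
]

shifts = rank_shifts(base, personalized)
for tag, before, after, shift in shifts:
    print(f"{tag:10s} base={before} personalized={after} shift={shift:+d}")
```

With the real outputs, passing `ranked_tags_base` and `ranked_tags_personalized` to the same helper would list the tags most affected by the persona.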

Running the Example

# Install the repository first
pip install git+https://github.com/Pieces/TAG-module.git@main
# Or for local development:
pip install -e .

# Run the example
python gemma_finetuned_example.py

Citation

If you use this model, please cite:

@software{tag_module,
  title = {TAG Module: Persona-Conditioned Contrastive Learning for Tag Classification},
  author = {Your Name},
  year = {2025},
  url = {https://github.com/yourusername/tag-module}
}

License

Please refer to the license of the backbone model (google/embeddinggemma-300m).
