# sentiment-transformer / example.py
# Uploaded by Impulse2000 ("Upload sentiment-transformer model"), commit 21035f8 (verified).
"""
Example usage of the Sentiment Transformer with HuggingFace Transformers.

This file is included in every HF export directory as a quick-start reference.

Usage::

    python example.py
    python example.py --text "This movie was incredible!"
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
# Built-in demo inputs, classified when ``--text`` is not given.
_EXAMPLE_TEXTS = (
    "This movie was absolutely fantastic! I loved every minute of it.",
    "Terrible film, completely unwatchable garbage.",
    "The movie was okay, nothing special really.",
    "An outstanding performance by the entire cast.",
    "I fell asleep halfway through. Waste of time.",
)


def _parse_args(argv: list[str] | None) -> argparse.Namespace:
    """Parse the command-line options of the quick-start example.

    Args:
        argv: Arguments to parse (without the program name). ``None`` makes
            argparse read ``sys.argv[1:]``.

    Returns:
        Namespace with ``text`` (``str`` or ``None``) and ``model_dir`` (``str``).
    """
    parser = argparse.ArgumentParser(
        description="Quick-start example for the Sentiment Transformer.",
    )
    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="Single text to classify. If omitted, runs built-in examples.",
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default=str(Path(__file__).resolve().parent),
        help="Path to the HF model directory. Defaults to this file's directory.",
    )
    return parser.parse_args(argv)


def _load_pipeline(model_dir: str) -> tuple[object, object]:
    """Load the model and tokenizer from *model_dir* and build a pipeline.

    Args:
        model_dir: Directory (or hub id) passed to ``from_pretrained``.

    Returns:
        ``(model, pipe)``, where ``pipe`` is a ``"text-classification"``
        pipeline wrapping the loaded model and tokenizer.

    Exits the process with status 1 (message on stderr) when
    ``transformers`` cannot be imported.
    """
    try:
        from transformers import (
            AutoModelForSequenceClassification,
            AutoTokenizer,
            pipeline,
        )
    except ImportError:
        # Diagnostics belong on stderr so they are not mixed into results.
        print("ERROR: `transformers` is required. Install with:", file=sys.stderr)
        print(" pip install transformers torch", file=sys.stderr)
        sys.exit(1)

    print(f"Loading model from: {model_dir}")
    # trust_remote_code=True allows transformers to run custom code shipped in
    # model_dir. Pass it to both loaders so an export that also ships a custom
    # tokenizer loads consistently. Only point --model-dir at trusted content.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir, trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
    return model, pipe


def _preview(text: str, width: int = 60) -> str:
    """Return *text* cut to *width* characters, adding "..." only if it was cut."""
    return text if len(text) <= width else text[:width] + "..."


def main(argv: list[str] | None = None) -> None:
    """Run the quick-start demo.

    Classifies ``--text`` (or the built-in example reviews when the option is
    omitted), prints one line per text, then prints a bar chart of the
    per-label scores for the first text.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) reads ``sys.argv[1:]``, so ``main()`` behaves as before.
    """
    args = _parse_args(argv)
    model, pipe = _load_pipeline(args.model_dir)

    print(f"Model: {type(model).__name__}")
    print(f"Labels: {model.config.id2label}")
    print()

    # `is not None`: an explicit `--text ""` was given, so it is classified
    # instead of silently falling back to the built-in examples.
    texts = [args.text] if args.text is not None else list(_EXAMPLE_TEXTS)

    results = pipe(texts)
    for text, result in zip(texts, results):
        print(f" {result['label']:8s} ({result['score']:.4f}) {text}")

    # Top-k example: top_k=None asks the text-classification pipeline for the
    # scores of all labels rather than only the best one.
    print("\n--- Top-k prediction ---")
    sample = texts[0]
    label_scores = pipe(sample, top_k=None)
    print(f' "{_preview(sample)}"')
    for entry in label_scores:
        bar = "█" * int(entry["score"] * 40)
        print(f" {entry['label']:8s} {entry['score']:.4f} {bar}")
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()