# infy/utils.py
# Author: shourya (commit d153152)
# "Fix NER model ID and add resilient fallback loading"
"""
Utility functions for HuggingFace Enabling Sessions Spaces app
"""
import torch
from transformers import (
pipeline,
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM
)
import numpy as np
from functools import lru_cache
import config
# Lazy loading for heavy dependencies: each model is loaded on first use by the
# get_* helpers below and cached in these module-level slots for reuse.
_sbert_model = None              # SentenceTransformer instance for embeddings/similarity
_qa_model = None                 # AutoModelForQuestionAnswering instance
_qa_tokenizer = None             # tokenizer paired with _qa_model
_summarization_model = None      # AutoModelForSeq2SeqLM instance
_summarization_tokenizer = None  # tokenizer paired with _summarization_model
def get_sbert_model():
    """Return the shared Sentence-BERT model, loading it on first call.

    The import happens inside the function so the heavy sentence_transformers
    package is only pulled in when embeddings are actually requested.
    """
    global _sbert_model
    if _sbert_model is not None:
        return _sbert_model
    from sentence_transformers import SentenceTransformer
    _sbert_model = SentenceTransformer(config.EMBEDDINGS_MODEL, device="cpu")
    return _sbert_model
@lru_cache(maxsize=10)
def load_pipeline(task_type: str):
    """Load and cache an inference pipeline for the given task.

    Args:
        task_type: One of "sentiment", "ner", or "summarization".

    Returns:
        A transformers pipeline object (cached per task type by lru_cache).

    Raises:
        RuntimeError: If loading fails or task_type is unknown. The original
            exception is preserved as ``__cause__`` for debugging.
    """
    try:
        device = -1  # CPU only: safest default for Spaces hardware
        if task_type == "sentiment":
            return pipeline("sentiment-analysis", model=config.SENTIMENT_MODEL, device=device)
        elif task_type == "ner":
            try:
                return pipeline("ner", model=config.NER_MODEL, device=device, aggregation_strategy="simple")
            except Exception:
                # Fallback to another public NER model if primary ID fails.
                fallback_ner_model = "dbmdz/bert-large-cased-finetuned-conll03-english"
                return pipeline("ner", model=fallback_ner_model, device=device, aggregation_strategy="simple")
        elif task_type == "summarization":
            # `summarization` alias is not present in some transformers builds,
            # so use the generic text2text-generation task instead.
            return pipeline("text2text-generation", model=config.SUMMARIZATION_MODEL, device=device)
        else:
            raise ValueError(f"Unknown task type: {task_type}")
    except Exception as e:
        # Fix: raise a concrete RuntimeError (still an Exception, so existing
        # callers' `except Exception` handlers keep working) and chain the
        # cause instead of discarding the original traceback.
        raise RuntimeError(f"Error loading {task_type} pipeline: {str(e)}") from e
def get_qa_model():
    """Return ``(model, tokenizer)`` for extractive QA, loading lazily.

    The pair is built once from ``config.QA_MODEL`` and then reused from the
    module-level cache on subsequent calls.
    """
    global _qa_model, _qa_tokenizer
    if _qa_model is not None:
        return _qa_model, _qa_tokenizer
    _qa_tokenizer = AutoTokenizer.from_pretrained(config.QA_MODEL)
    _qa_model = AutoModelForQuestionAnswering.from_pretrained(config.QA_MODEL)
    _qa_model.eval()  # inference mode: disables dropout etc.
    return _qa_model, _qa_tokenizer
def get_summarization_model():
    """Return ``(model, tokenizer)`` for summarization, loading lazily.

    Built once from ``config.SUMMARIZATION_MODEL``, then served from the
    module-level cache.
    """
    global _summarization_model, _summarization_tokenizer
    if _summarization_model is not None:
        return _summarization_model, _summarization_tokenizer
    _summarization_tokenizer = AutoTokenizer.from_pretrained(config.SUMMARIZATION_MODEL)
    _summarization_model = AutoModelForSeq2SeqLM.from_pretrained(config.SUMMARIZATION_MODEL)
    _summarization_model.eval()  # inference mode: disables dropout etc.
    return _summarization_model, _summarization_tokenizer
def run_sentiment_analysis(text: str):
    """Classify the sentiment of *text* using the cached sentiment pipeline."""
    analyzer = load_pipeline("sentiment")
    # Character-level truncation to stay under the model's token limit.
    predictions = analyzer(text[:512])
    if not predictions:
        return {"label": "Unknown", "score": 0}
    return predictions[0]
def run_ner(text: str):
    """Extract named entities from *text*.

    On failure, returns a single-row sentinel list whose entity_group is
    "ERROR" and which carries the exception text under "error".
    """
    try:
        ner_pipe = load_pipeline("ner")
        return ner_pipe(text[:512])
    except Exception as exc:
        return [{"word": "", "entity_group": "ERROR", "score": 0.0, "error": str(exc)}]
def run_qa(context: str, question: str):
    """Answer *question* from *context* using the extractive QA model.

    Args:
        context: Passage to search for the answer.
        question: Question to answer.

    Returns:
        Dict with "answer", "score" (mean of the best start/end logits —
        NOT a calibrated probability), and token indices "start"/"end".
        On any failure, a dict with an "error" key is returned instead of
        raising.
    """
    try:
        model, tokenizer = get_qa_model()
        inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
        answer_start_idx = outputs.start_logits.argmax(dim=1).item()
        answer_end_idx = outputs.end_logits.argmax(dim=1).item() + 1
        # Bug fix: the two argmaxes are independent, so the end index could
        # land before the start, silently slicing an empty answer. Clamp to a
        # minimal one-token span in that case.
        if answer_end_idx <= answer_start_idx:
            answer_end_idx = answer_start_idx + 1
        answer = tokenizer.convert_tokens_to_string(
            tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start_idx:answer_end_idx])
        )
        score = (outputs.start_logits.max().item() + outputs.end_logits.max().item()) / 2
        return {
            "answer": answer.strip(),
            "score": float(score),
            "start": int(answer_start_idx),
            "end": int(answer_end_idx)
        }
    except Exception as e:
        return {"error": str(e), "answer": "Unable to answer", "score": 0}
def run_summarization(text: str):
    """Generate a summary of *text* via direct seq2seq inference.

    Input is truncated twice: a cheap character pre-cut (``text[:1024]``) and
    then token-level truncation by the tokenizer (``max_length=1024``).

    Returns:
        The stripped summary string, or an ``"Error: ..."`` string on failure.
    """
    try:
        model, tokenizer = get_summarization_model()
        inputs = tokenizer(text[:1024], return_tensors="pt", max_length=1024, truncation=True)
        with torch.no_grad():
            # Fix: the previous hard-coded forced_bos_token_id=0 is
            # BART-specific (BART's own generation config already sets it);
            # forcing token 0 on other checkpoints (e.g. T5, where 0 is the
            # pad token) corrupts output. Let the model's generation config
            # decide instead.
            summary_ids = model.generate(
                inputs["input_ids"],
                max_length=150,
                min_length=30,
                num_beams=4,
                length_penalty=2.0,
                early_stopping=True,
            )
        summary = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0]
        return summary.strip()
    except Exception as e:
        return f"Error: {str(e)}"
def compute_similarity(text1: str, text2: str):
    """Cosine similarity between the SBERT embeddings of two texts.

    Returns:
        A float similarity score, or an ``"Error: ..."`` string on failure.
    """
    try:
        from sentence_transformers import util
        encoder = get_sbert_model()
        vectors = encoder.encode([text1, text2], convert_to_tensor=True)
        cosine = util.pytorch_cos_sim(vectors[0], vectors[1])
        return float(cosine.item())
    except Exception as exc:
        return f"Error: {str(exc)}"
@lru_cache(maxsize=8)
def _get_tokenizer(model_name: str):
    """Cache tokenizers so repeated tokenize_text calls don't reload from disk/hub."""
    return AutoTokenizer.from_pretrained(model_name)


def tokenize_text(text: str, model_name: str = config.SENTIMENT_MODEL):
    """Tokenize *text* and report tokens, IDs and attention mask.

    Args:
        text: Input string to tokenize.
        model_name: HuggingFace model ID whose tokenizer to use.

    Returns:
        Dict with "tokens", "token_ids", "attention_mask" and "num_tokens",
        or ``{"error": ...}`` if tokenization fails.
    """
    try:
        # Perf fix: the tokenizer was previously re-created from_pretrained on
        # every call; it is now cached per model name.
        tokenizer = _get_tokenizer(model_name)
        encoding = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])
        return {
            "tokens": tokens,
            "token_ids": encoding["input_ids"][0].tolist(),
            "attention_mask": encoding["attention_mask"][0].tolist(),
            "num_tokens": len(tokens),
        }
    except Exception as e:
        return {"error": str(e)}
def format_tokenizer_output(tokenization_result):
    """Render a tokenize_text() result as a markdown table string.

    If the result carries an "error" key, an ``"Error: ..."`` string is
    returned instead of a table.
    """
    if "error" in tokenization_result:
        return f"Error: {tokenization_result['error']}"
    lines = [
        f"**Total Tokens:** {tokenization_result['num_tokens']}\n",
        "| Token | Token ID | Attention Mask |",
        "|-------|----------|----------------|",
    ]
    row_data = zip(
        tokenization_result["tokens"],
        tokenization_result["token_ids"],
        tokenization_result["attention_mask"],
    )
    lines.extend(f"| {tok} | {tid} | {mask} |" for tok, tid, mask in row_data)
    return "\n".join(lines) + "\n"
def format_ner_output(ner_results):
    """Render NER pipeline output as a markdown table.

    Accepts rows keyed by either "entity_group" (aggregated pipelines) or
    "entity" (raw pipelines); returns "No entities found" for empty input.
    """
    if not ner_results:
        return "No entities found"
    lines = ["| Entity | Type | Score |", "|--------|------|-------|"]
    for entry in ner_results:
        label = entry.get("entity_group", entry.get("entity", ""))
        lines.append(
            f"| {entry.get('word', '')} | {label} | {entry.get('score', 0):.4f} |"
        )
    return "\n".join(lines) + "\n"