Spaces:
Sleeping
Sleeping
# app/models/text_model.py
# RoBERTa-based text toxicity model with ONNX optimization
from pathlib import Path

import numpy as np

from app.config import get_settings
from app.observability.logging import get_logger

# Module-level logger; used below with an event name plus keyword fields.
logger = get_logger(__name__)
class TextToxicityModel:
    """
    Multi-label text toxicity classifier.

    The checkpoint name is read from ``settings.text_model_name`` (the
    intended model is unitary/toxic-bert). Inference runs through an ONNX
    session when ``settings.onnx_enabled`` is set, otherwise through the
    PyTorch model.

    Labels: toxic, severe_toxic, obscene, threat, insult, identity_hate
    """

    LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

    # Sequence length used for both ONNX export and inference, so the
    # exported graph and the runtime inputs always agree.
    MAX_LENGTH = 128

    # A text is flagged toxic when any label probability is strictly above this.
    TOXICITY_THRESHOLD = 0.5

    # Inputs declared at export time and fed at ONNX inference time.
    ONNX_INPUT_NAMES = ["input_ids", "attention_mask"]

    def __init__(self):
        self.settings = get_settings()
        self.tokenizer = None
        self.onnx_session = None
        self.pt_model = None
        self.device = None
        self._loaded = False

    def load(self) -> None:
        """
        Load the tokenizer and the model.

        Backend selection:
          * ONNX enabled and an exported file exists -> load that ONNX file.
          * ONNX enabled but no file yet -> load PyTorch, export to ONNX,
            load the exported file, then drop the PyTorch model.
          * ONNX disabled -> PyTorch only.
        """
        from transformers import AutoTokenizer

        model_name = self.settings.text_model_name
        cache_dir = self.settings.model_cache_path / "roberta"
        onnx_path = cache_dir / "text_toxicity.onnx"

        logger.info("loading_text_model", model=model_name)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

        if self.settings.onnx_enabled and onnx_path.exists():
            # Reuse the previously exported ONNX model.
            from app.models.onnx_utils import load_onnx_session

            self.onnx_session = load_onnx_session(onnx_path)
            logger.info("text_model_loaded", backend="onnx")
        elif self.settings.onnx_enabled:
            # No export on disk yet: build it from the PyTorch weights.
            self._load_pytorch(model_name, cache_dir)
            self._export_onnx(onnx_path)

            from app.models.onnx_utils import load_onnx_session

            self.onnx_session = load_onnx_session(onnx_path)
            self.pt_model = None  # Free PyTorch memory
            logger.info("text_model_loaded", backend="onnx", note="exported_from_pytorch")
        else:
            self._load_pytorch(model_name, cache_dir)
            logger.info("text_model_loaded", backend="pytorch")

        self._loaded = True

    def _load_pytorch(self, model_name: str, cache_dir: Path) -> None:
        """Load the PyTorch model onto CUDA when available, else CPU, in eval mode."""
        import torch
        from transformers import AutoModelForSequenceClassification

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pt_model = AutoModelForSequenceClassification.from_pretrained(
            model_name, cache_dir=cache_dir
        )
        self.pt_model.to(self.device)
        self.pt_model.eval()

    def _export_onnx(self, onnx_path: Path) -> None:
        """
        Export the currently loaded PyTorch model to ``onnx_path``.

        Requires ``self.tokenizer`` and ``self.pt_model`` to be set
        (i.e. ``_load_pytorch`` has already run).
        """
        from app.models.onnx_utils import export_to_onnx

        input_names = list(self.ONNX_INPUT_NAMES)
        encoded = self.tokenizer(
            "test input for export",
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.MAX_LENGTH,
        )
        # The tokenizer may return extra tensors (e.g. token_type_ids). Pass
        # only the declared inputs so the sample lines up one-to-one with
        # input_names and with what _predict_onnx feeds at inference time.
        sample = {name: encoded[name].to(self.device) for name in input_names}

        # Make sure the target directory exists before the exporter writes to it.
        onnx_path.parent.mkdir(parents=True, exist_ok=True)

        export_to_onnx(
            model=self.pt_model,
            sample_input=sample,
            output_path=onnx_path,
            input_names=input_names,
            output_names=["logits"],
        )

    def predict(self, text: str) -> dict:
        """
        Predict toxicity scores for input text.

        Args:
            text: Input text to classify.

        Returns:
            Dict with:
                - labels: list of label names
                - scores: list of per-label probabilities
                - label_scores: dict mapping label name to probability
                - is_toxic: bool (any label > TOXICITY_THRESHOLD)
                - max_score: float (highest toxicity probability)
                - max_label: str (label with highest probability)

        Raises:
            RuntimeError: If load() has not completed.
        """
        if not self._loaded:
            raise RuntimeError("Text model not loaded. Call load() first.")

        use_onnx = self.onnx_session is not None
        encoding = self.tokenizer(
            text,
            return_tensors="np" if use_onnx else "pt",
            padding="max_length",
            truncation=True,
            max_length=self.MAX_LENGTH,
        )

        if use_onnx:
            return self._predict_onnx(encoding)
        return self._predict_pytorch(encoding)

    def _predict_onnx(self, encoding: dict) -> dict:
        """Run ONNX inference on a NumPy-encoded input and format the result."""
        from app.models.onnx_utils import onnx_inference

        inputs = {
            name: encoding[name].astype(np.int64) for name in self.ONNX_INPUT_NAMES
        }
        outputs = onnx_inference(self.onnx_session, inputs)
        # First output, first row; presumably shape (num_labels,) for a batch of one.
        logits = outputs[0][0]
        return self._format_output(logits)

    def _predict_pytorch(self, encoding: dict) -> dict:
        """Run PyTorch inference on a tensor-encoded input and format the result."""
        import torch

        inputs = {key: tensor.to(self.device) for key, tensor in encoding.items()}
        with torch.no_grad():
            outputs = self.pt_model(**inputs)
        logits = outputs.logits[0].cpu().numpy()
        return self._format_output(logits)

    def _format_output(self, logits: np.ndarray) -> dict:
        """
        Convert raw logits to the prediction dict returned by predict().

        Args:
            logits: 1-D sequence of per-label logits.

        Returns:
            Dict with keys labels, scores, label_scores, is_toxic,
            max_score and max_label. ``labels`` and ``scores`` always have
            the same length.

        Raises:
            ValueError: If ``logits`` is empty.
        """
        values = np.asarray(logits)
        if values.size == 0:
            raise ValueError("Cannot format prediction: model returned no logits.")

        # Numerically stable sigmoid (multi-label classification): the
        # exponent is never positive, so np.exp cannot overflow on large
        # negative logits.
        exp_neg = np.exp(-np.abs(values))
        scores = np.where(
            values >= 0, 1 / (1 + exp_neg), exp_neg / (1 + exp_neg)
        ).tolist()

        # Fewer outputs than known labels: use the leading label names.
        # More outputs than known labels: name the extras "label_<index>" so
        # every score keeps a label and the argmax lookup cannot go out of range.
        known = len(self.LABELS)
        labels = [
            self.LABELS[i] if i < known else f"label_{i}" for i in range(len(scores))
        ]

        max_idx = int(np.argmax(scores))
        return {
            "labels": labels,
            "scores": scores,
            "label_scores": dict(zip(labels, scores)),
            "is_toxic": any(s > self.TOXICITY_THRESHOLD for s in scores),
            "max_score": scores[max_idx],
            "max_label": labels[max_idx],
        }

    def is_loaded(self) -> bool:
        """Return True once load() has completed successfully."""
        return self._loaded