| import torch |
| import logging |
| import re |
| from typing import Dict, List, Any |
| from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer |
|
|
| |
# Configure root logging at import time so handler logs surface in the
# Inference Endpoints console; module-level logger per stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
class EndpointHandler:
    """HuggingFace Inference Endpoints handler for RECCON emotional-trigger
    extraction, framed as extractive question answering.

    For each (utterance, emotion) input the handler asks a QA model which
    short phrase of the utterance signals the emotion, then post-processes
    the candidate spans into at most three cleaned trigger phrases.
    """

    def __init__(self, path=""):
        """
        Initialize the RECCON emotional trigger extraction model using native
        transformers.

        Args:
            path: Path to the model directory (provided by HuggingFace
                Inference Endpoints). An empty path falls back to the
                current directory.

        Raises:
            Exception: re-raised if the tokenizer or model fails to load.
        """
        logger.info("Initializing RECCON Trigger Extraction endpoint...")

        cuda_available = torch.cuda.is_available()
        if not cuda_available:
            logger.warning("GPU not detected. Running on CPU. Inference will be slower.")

        # transformers pipelines use -1 for CPU, a CUDA device index otherwise.
        self.device_id = 0 if cuda_available else -1

        # Empty path (e.g. local testing) resolves to the current directory.
        model_path = path or "."
        logger.info(f"Loading model from {model_path}...")

        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model, loading_info = AutoModelForQuestionAnswering.from_pretrained(
                model_path,
                output_loading_info=True
            )

            # Loading diagnostics are logged at WARNING so they show up in the
            # endpoint console even under restrictive log-level filters.
            logger.warning("RECCON load info - missing_keys: %s", loading_info.get("missing_keys"))
            logger.warning("RECCON load info - unexpected_keys: %s", loading_info.get("unexpected_keys"))
            logger.warning("RECCON load info - error_msgs: %s", loading_info.get("error_msgs"))
            logger.warning("Loaded model class: %s", model.__class__.__name__)
            logger.warning("Loaded model name_or_path: %s", getattr(model.config, "_name_or_path", None))

            # top_k=20 keeps many candidate spans per input so that the
            # filtering in _clean_spans has alternatives to choose from.
            self.pipe = pipeline(
                "question-answering",
                model=model,
                tokenizer=tokenizer,
                device=self.device_id,
                top_k=20,
                handle_impossible_answer=False
            )
            logger.info("Model loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

        # {emotion} is substituted per request item in __call__.
        self.question_template = (
            "Extract the exact short phrase (<= 8 words) from the target "
            "utterance that most strongly signals the emotion {emotion}. "
            "Return only a substring of the target utterance."
        )

    @staticmethod
    def _is_good_span(ans: str) -> bool:
        """Reject empty, ultra-short, or punctuation/number-only spans.

        Hoisted out of the per-item loop in __call__ (the original redefined
        this closure on every iteration).
        """
        if not ans:
            return False
        a = ans.strip()
        if len(a) < 3:
            return False
        # Pure punctuation (e.g. "...", "?!") is never a trigger phrase.
        if all(ch in ".,!?;:-—'\"()[]{}" for ch in a):
            return False
        # Require at least one letter so bare numbers are dropped too.
        if not any(ch.isalpha() for ch in a):
            return False
        return True

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process an inference request.

        Args:
            data: Either {"inputs": item_or_items} or the item(s) directly,
                where each item is a dict with "utterance" and "emotion" keys.

        Returns:
            One result dict per input item with keys "utterance", "emotion",
            and "triggers" (up to 3 cleaned phrases); failed items carry an
            "error" message and empty triggers.
        """
        inputs = data.pop("inputs", data)

        # Normalize a single item to a batch of one.
        if isinstance(inputs, dict):
            inputs = [inputs]

        if not inputs:
            return [{"error": "No inputs provided", "triggers": []}]

        # Build QA payloads, remembering which input indices were runnable.
        pipeline_inputs = []
        valid_indices = []
        for i, item in enumerate(inputs):
            utterance = item.get("utterance", "").strip()
            emotion = item.get("emotion", "")

            if not utterance:
                logger.warning(f"Empty utterance at index {i}")
                continue

            question = self.question_template.format(emotion=emotion)
            pipeline_inputs.append({
                'question': question,
                'context': utterance
            })
            valid_indices.append(i)

        results = []

        if not pipeline_inputs:
            # Nothing runnable: report per-item errors instead of raising.
            for item in inputs:
                results.append({
                    "utterance": item.get("utterance", ""),
                    "emotion": item.get("emotion", ""),
                    "error": "Missing or empty utterance",
                    "triggers": []
                })
            return results

        try:
            predictions = self.pipe(pipeline_inputs, batch_size=8)

            # The QA pipeline's output shape varies: a single dict, a flat
            # list of dicts (one input with top_k spans), or a list of lists
            # (multiple inputs). Normalize to one entry per pipeline input.
            if isinstance(predictions, dict):
                predictions = [predictions]
            elif isinstance(predictions, list) and len(predictions) > 0 and isinstance(predictions[0], dict):
                if len(pipeline_inputs) == 1:
                    predictions = [predictions]

            logger.debug(f"Raw predictions: {predictions}")

            # Walk original inputs in order; pred_idx tracks only valid ones.
            pred_idx = 0
            for i, item in enumerate(inputs):
                utterance = item.get("utterance", "").strip()
                emotion = item.get("emotion", "")

                if i not in valid_indices:
                    results.append({
                        "utterance": utterance,
                        "emotion": emotion,
                        "error": "Missing or empty utterance",
                        "triggers": []
                    })
                else:
                    current_preds = predictions[pred_idx]

                    # Defensive: a single span may arrive as a bare dict.
                    if isinstance(current_preds, dict):
                        current_preds = [current_preds]

                    # BUGFIX: the original logged (answer, score, 3) 3-tuples;
                    # the literal 3 was clearly meant as round(score, 3).
                    logger.info(
                        "RECCON raw spans (answer, score): %s",
                        [(p.get("answer"), round(p.get("score", 0.0), 3)) for p in current_preds[:5]]
                    )

                    raw_answers = [p.get("answer", "") for p in current_preds]
                    raw_answers = [a for a in raw_answers if self._is_good_span(a)]
                    triggers = self._clean_spans(raw_answers, utterance)

                    results.append({
                        "utterance": utterance,
                        "emotion": emotion,
                        "triggers": triggers
                    })
                    pred_idx += 1

            logger.debug(f"Cleaned results: {results}")
            return results

        except Exception as e:
            logger.error(f"Model prediction failed: {e}")
            return [{
                "utterance": item.get("utterance", ""),
                "emotion": item.get("emotion", ""),
                "error": str(e),
                "triggers": []
            } for item in inputs]

    def _clean_spans(self, spans: List[str], target_text: str) -> List[str]:
        """
        Clean and filter extracted trigger spans.

        Keeps only spans that occur verbatim in the target utterance, drops
        stopword-only/trivial spans, removes spans that contain or are
        contained by an already-kept span (shortest win), and returns at most
        3 phrases recovered with the target's original casing. If everything
        is filtered out, falls back to the first word n-gram shared by any
        raw span and the target.
        """
        target_text = target_text or ""
        target_lower = target_text.lower()

        def _norm(s: str) -> str:
            # Collapse whitespace and strip surrounding punctuation.
            s = (s or "").strip().lower()
            s = re.sub(r"\s+", " ", s)
            s = re.sub(r"^[^\w]+|[^\w]+$", "", s)
            return s

        def _extract_from_target(target: str, phrase_lower: str) -> str:
            # Recover the original-cased phrase from the target utterance.
            idx = target.lower().find(phrase_lower)
            if idx >= 0:
                return target[idx:idx + len(phrase_lower)]
            return phrase_lower

        STOP = {
            "a", "an", "the", "and", "or", "but", "so", "to", "of", "in", "on", "at",
            "with", "for", "from", "is", "am", "are", "was", "were", "be", "been",
            "being", "i", "you", "he", "she", "it", "we", "they", "my", "your", "his",
            "her", "their", "our", "me", "him", "them", "this", "that", "these",
            "those"
        }

        candidates = []
        for s in spans:
            s = (s or "").strip()
            if not s:
                continue
            s_norm = _norm(s)
            if not s_norm:
                continue
            # Enforce the "substring of the target utterance" contract.
            if target_text and s_norm not in target_lower:
                continue
            tokens = s_norm.split()
            if len(tokens) > 8 or len(s_norm) > 80:
                continue
            # Single stopwords / 1-2 char words are never useful triggers.
            if len(tokens) == 1 and (tokens[0] in STOP or len(tokens[0]) <= 2):
                continue
            candidates.append({
                "norm": s_norm,
                "tokens": tokens,
                "tok_len": len(tokens),
                "char_len": len(s_norm)
            })

        # Prefer concise (1-3 token) trigger phrases when any survive.
        short_candidates = [c for c in candidates if 1 <= c["tok_len"] <= 3]
        if short_candidates:
            candidates = short_candidates

        # Shortest first; drop spans overlapping an already-kept span.
        candidates.sort(key=lambda x: (x["tok_len"], x["char_len"]))
        kept_norms = []
        for c in candidates:
            n = c["norm"]
            if any(n in kn or kn in n for kn in kept_norms):
                continue
            kept_norms.append(n)

        cleaned = [_extract_from_target(target_text, n) for n in kept_norms]

        if not cleaned and spans:
            # Fallback: longest word n-gram shared by a raw span and the target.
            tt_tokens = target_lower.split()
            best = None
            for s in spans:
                words = [w for w in (s or '').lower().strip().split() if w]
                for L in range(min(8, len(words)), 0, -1):
                    for i in range(len(words) - L + 1):
                        phrase = words[i:i + L]
                        for j in range(len(tt_tokens) - L + 1):
                            if tt_tokens[j:j + L] == phrase:
                                best = " ".join(phrase)
                                break
                        if best:
                            break
                    if best:
                        break
                # BUGFIX: the original break cascade stopped the inner loops
                # but kept scanning later spans, letting a later match
                # overwrite the first one. Stop at the first match instead.
                if best:
                    break
            if best:
                return [_extract_from_target(target_text, best)]

        return cleaned[:3]