Rishi2455 commited on
Commit
a281be3
·
verified ·
1 Parent(s): eeb9ddc

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +83 -0
  2. model_loader.py +643 -0
  3. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import asyncio
4
+ import logging
5
+ from model_loader import engine
6
+ from deepgram import DeepgramClient, PrerecordedOptions
7
+ from huggingface_hub import snapshot_download
8
+
9
+ # Setup Logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger("smart_turn_gradio")
12
+
13
+ # --- HF Hub Configuration ---
14
+ # This downloads your model from the hub to a local cache
15
+ MODEL_REPO_ID = "Rishi2455/smart-turn-model"
16
+ local_model_path = snapshot_download(repo_id=MODEL_REPO_ID)
17
+
18
+ # Initialize Deepgram
19
+ DEEPGRAM_API_KEY = os.getenv("DEEPGRAM_API_KEY")
20
+ dg_client = DeepgramClient(DEEPGRAM_API_KEY) if DEEPGRAM_API_KEY else None
21
+
22
+ async def load_model():
23
+ """Ensure model is loaded from the HF Hub path."""
24
+ if not engine.is_loaded:
25
+ # Pass the local path where HF Hub files were downloaded
26
+ await engine.load_model(local_model_path)
27
+
28
+ async def predict_text(text):
29
+ if not text:
30
+ return "Please enter some text."
31
+
32
+ await load_model()
33
+ result = await engine.predict(text)
34
+
35
+ status = "✅ COMPLETE (User finished)" if result["is_complete"] else "⏳ INCOMPLETE (User still talking)"
36
+ return f"**Result:** {status}\n**Confidence:** {result['confidence']:.2%}\n**Prob(Complete):** {result['complete_probability']:.2%}"
37
+
38
+ async def predict_audio(audio_path):
39
+ if not audio_path:
40
+ return "Please record or upload audio."
41
+ if not dg_client:
42
+ return "Deepgram API Key NOT found. Please set DEEPGRAM_API_KEY in Settings."
43
+
44
+ await load_model()
45
+
46
+ with open(audio_path, 'rb') as audio:
47
+ source = {'buffer': audio.read()}
48
+ options = PrerecordedOptions(model="nova-2", smart_format=True)
49
+ response = dg_client.listen.rest.v("1").transcribe_file(source, options)
50
+
51
+ transcript = response.results.channels[0].alternatives[0].transcript
52
+ if not transcript:
53
+ return "No speech detected."
54
+
55
+ result = await engine.predict(transcript)
56
+
57
+ status = "✅ COMPLETE" if result["is_complete"] else "⏳ INCOMPLETE"
58
+ return f"**Transcript:** \"{transcript}\"\n\n**Result:** {status}\n**Confidence:** {result['confidence']:.2%}"
59
+
60
+ # Gradio Interface
61
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
62
+ gr.Markdown("# 🤖 Smart Turn - EOU Detection")
63
+ gr.Markdown(f"Running model from: [{MODEL_REPO_ID}](https://huggingface.co/{MODEL_REPO_ID})")
64
+
65
+ with gr.Tab("📝 Text Prediction"):
66
+ text_input = gr.Textbox(placeholder="Type a sentence...", label="Input Text")
67
+ text_output = gr.Markdown(label="Analysis")
68
+ text_btn = gr.Button("Analyze Text")
69
+ text_btn.click(predict_text, inputs=text_input, outputs=text_output)
70
+
71
+ with gr.Tab("🎙️ Audio Prediction"):
72
+ audio_input = gr.Audio(type="filepath", label="Record/Upload Audio")
73
+ audio_output = gr.Markdown(label="Analysis")
74
+ audio_btn = gr.Button("Transcribe & Analyze")
75
+ audio_btn.click(predict_audio, inputs=audio_input, outputs=audio_output)
76
+
77
+ gr.Examples(
78
+ examples=[["i want to"], ["i want to book a flight"], ["can you help me with"]],
79
+ inputs=text_input
80
+ )
81
+
82
+ if __name__ == "__main__":
83
+ demo.launch()
model_loader.py ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Async Model Loader & Inference Engine for EOU Detection
3
+ Supports ONNX Runtime (fast) with PyTorch fallback.
4
+ """
5
+
6
+ import os
7
+ import re
8
+ import json
9
+ import asyncio
10
+ import time
11
+ import logging
12
+ from typing import List, Dict, Optional, Any
13
+ from dataclasses import dataclass
14
+ from concurrent.futures import ThreadPoolExecutor
15
+ from safetensors.torch import load_file as load_safetensors
16
+
17
+ import numpy as np
18
+
19
+ logger = logging.getLogger("eou_model")
20
+
21
+ # Try importing ONNX Runtime first, then PyTorch as fallback
22
+ try:
23
+ import onnxruntime as ort
24
+ ONNX_AVAILABLE = True
25
+ logger.info("ONNX Runtime available — will use fast inference path")
26
+ except ImportError:
27
+ ONNX_AVAILABLE = False
28
+ logger.warning("onnxruntime not installed — falling back to PyTorch")
29
+
30
+ try:
31
+ import torch
32
+ import torch.nn as nn
33
+ from transformers import DebertaV2Model
34
+ TORCH_AVAILABLE = True
35
+ except ImportError:
36
+ TORCH_AVAILABLE = False
37
+
38
+ from transformers import AutoTokenizer
39
+
40
+
41
+ # ============================================================
42
+ # Config & Feature Extraction
43
+ # ============================================================
44
+
45
+ @dataclass
46
+ class Config:
47
+ model_name: str = "microsoft/deberta-v3-base"
48
+ max_length: int = 128 # Reduced from 256 — EOU utterances are short
49
+ use_aux_features: bool = True
50
+ dropout: float = 0.1
51
+ label_smoothing: float = 0.05
52
+
53
+
54
+ class TextCleaner:
55
+ """Clean text for ASR-trained model (no punctuation expected)"""
56
+
57
+ # Compile regex once for performance
58
+ _PUNCT_RE = re.compile(r'[^\w\s]', re.UNICODE)
59
+ _MULTI_SPACE_RE = re.compile(r'\s+')
60
+
61
+ @classmethod
62
+ def clean(cls, text: str) -> str:
63
+ """Strip punctuation, lowercase, and normalize whitespace."""
64
+ text = text.strip()
65
+ if not text:
66
+ return text
67
+ text = cls._PUNCT_RE.sub('', text) # Remove all punctuation
68
+ text = cls._MULTI_SPACE_RE.sub(' ', text) # Collapse multiple spaces
69
+ text = text.strip().lower() # Lowercase for ASR input
70
+ return text
71
+
72
+
73
+ class SemanticFeatureExtractor:
74
+ """Extract 15 semantic features for EOU detection (punctuation-free).
75
+
76
+ Matches the feature_type='semantic_no_punctuation' training config.
77
+ """
78
+
79
+ CONJUNCTIONS = {'and', 'but', 'or', 'so', 'because', 'since', 'although',
80
+ 'while', 'if', 'when', 'that', 'which', 'who', 'where',
81
+ 'unless', 'until', 'whether', 'though', 'whereas'}
82
+
83
+ PREPOSITIONS = {'to', 'for', 'with', 'at', 'in', 'on', 'of', 'from',
84
+ 'by', 'about', 'into', 'through', 'during', 'before',
85
+ 'after', 'above', 'below', 'between', 'under', 'over'}
86
+
87
+ ARTICLES = {'a', 'an', 'the'}
88
+
89
+ SUBJECT_PRONOUNS = {'i', 'we', 'they', 'he', 'she', 'it', 'you'}
90
+
91
+ AUXILIARIES = {'is', 'am', 'are', 'was', 'were', 'be', 'been', 'being',
92
+ 'have', 'has', 'had', 'do', 'does', 'did',
93
+ 'will', 'would', 'shall', 'should',
94
+ 'can', 'could', 'may', 'might', 'must'}
95
+
96
+ COMMON_TRANSITIVE = {'get', 'got', 'take', 'took', 'make', 'made',
97
+ 'give', 'gave', 'tell', 'told', 'find', 'found',
98
+ 'know', 'knew', 'want', 'need', 'see', 'saw',
99
+ 'put', 'keep', 'kept', 'let', 'say', 'said',
100
+ 'think', 'thought', 'ask', 'asked', 'use', 'used',
101
+ 'show', 'showed', 'try', 'tried', 'buy', 'bought'}
102
+
103
+ # Common verbs for has_verb detection
104
+ COMMON_VERBS = AUXILIARIES | COMMON_TRANSITIVE | {
105
+ 'go', 'went', 'come', 'came', 'run', 'ran', 'look', 'looked',
106
+ 'like', 'liked', 'play', 'played', 'work', 'worked', 'call',
107
+ 'called', 'move', 'moved', 'live', 'lived', 'believe', 'happen',
108
+ 'happened', 'include', 'included', 'turn', 'turned', 'follow',
109
+ 'followed', 'begin', 'began', 'seem', 'seemed', 'help', 'helped',
110
+ 'talk', 'talked', 'start', 'started', 'write', 'wrote', 'read',
111
+ 'feel', 'felt', 'provide', 'hold', 'held', 'stand', 'stood',
112
+ 'set', 'learn', 'learned', 'change', 'changed', 'lead', 'led',
113
+ 'understand', 'understood', 'watch', 'watched', 'pay', 'paid',
114
+ 'bring', 'brought', 'meet', 'met', 'send', 'sent', 'build',
115
+ 'built', 'stay', 'stayed', 'open', 'opened', 'create', 'created'
116
+ }
117
+
118
+ COMMON_NOUNS_SIMPLE = {
119
+ 'time', 'year', 'people', 'way', 'day', 'man', 'woman', 'child',
120
+ 'world', 'life', 'hand', 'part', 'place', 'case', 'week', 'company',
121
+ 'system', 'program', 'question', 'work', 'government', 'number',
122
+ 'night', 'point', 'home', 'water', 'room', 'mother', 'area',
123
+ 'money', 'story', 'fact', 'month', 'lot', 'right', 'study',
124
+ 'book', 'eye', 'job', 'word', 'business', 'issue', 'side', 'kind',
125
+ 'head', 'house', 'service', 'friend', 'father', 'power', 'hour',
126
+ 'game', 'line', 'end', 'members', 'city', 'community',
127
+ 'name', 'president', 'team', 'minute', 'idea', 'body', 'information',
128
+ 'back', 'parent', 'face', 'others', 'level', 'office', 'door',
129
+ 'health', 'person', 'art', 'car', 'food', 'phone', 'thing',
130
+ 'things', 'problem', 'answer', 'account', 'card', 'payment'
131
+ }
132
+
133
+ DISCOURSE_MARKERS = {'well', 'so', 'like', 'okay', 'ok', 'yeah',
134
+ 'yes', 'no', 'right', 'sure', 'actually',
135
+ 'basically', 'honestly', 'anyway', 'alright',
136
+ 'exactly', 'absolutely', 'definitely', 'totally'}
137
+
138
+ ADVERBS = {'very', 'really', 'also', 'just', 'now', 'then', 'still',
139
+ 'already', 'always', 'never', 'often', 'sometimes',
140
+ 'usually', 'quickly', 'slowly', 'well', 'too', 'quite',
141
+ 'almost', 'enough', 'only', 'even', 'probably', 'maybe',
142
+ 'certainly', 'finally', 'recently', 'actually', 'simply',
143
+ 'clearly', 'completely', 'especially', 'generally'}
144
+
145
+ FUNCTION_WORDS = (
146
+ CONJUNCTIONS | PREPOSITIONS | ARTICLES
147
+ | SUBJECT_PRONOUNS | AUXILIARIES
148
+ | {'the', 'a', 'an', 'this', 'that', 'these', 'those',
149
+ 'my', 'your', 'his', 'her', 'its', 'our', 'their',
150
+ 'not', 'no', 'very', 'just', 'also', 'too'}
151
+ )
152
+
153
+ @classmethod
154
+ def extract(cls, text: str) -> List[float]:
155
+ """Extract 15 semantic features (no punctuation features)."""
156
+ text = text.strip()
157
+ words = text.lower().split()
158
+ num_words = len(words)
159
+ last_word = words[-1] if words else ''
160
+
161
+ # Check if text has a verb anywhere
162
+ has_verb = float(any(w in cls.COMMON_VERBS for w in words))
163
+
164
+ # Check if there's a subject followed by a verb (simple heuristic)
165
+ has_subj_verb = 0.0
166
+ for i in range(len(words) - 1):
167
+ if words[i] in cls.SUBJECT_PRONOUNS and words[i + 1] in cls.COMMON_VERBS:
168
+ has_subj_verb = 1.0
169
+ break
170
+
171
+ # Check if a verb appeared earlier and last word is a noun
172
+ verb_seen = any(w in cls.COMMON_VERBS for w in words[:-1]) if num_words > 1 else False
173
+ ends_noun_after_verb = float(
174
+ verb_seen and last_word in cls.COMMON_NOUNS_SIMPLE
175
+ )
176
+
177
+ # Check if last word looks like a complete content word
178
+ # (not a function word, and at least 3 chars)
179
+ ends_complete_word = float(
180
+ last_word not in cls.FUNCTION_WORDS
181
+ and len(last_word) >= 3
182
+ ) if last_word else 0.0
183
+
184
+ # Adverb after verb check
185
+ ends_adverb_after_verb = float(
186
+ verb_seen and last_word in cls.ADVERBS
187
+ )
188
+
189
+ # Content word ratio
190
+ content_words = [w for w in words if w not in cls.FUNCTION_WORDS]
191
+ content_ratio = len(content_words) / max(num_words, 1)
192
+
193
+ features = [
194
+ float(last_word in cls.CONJUNCTIONS), # ends_conjunction
195
+ float(last_word in cls.PREPOSITIONS), # ends_preposition
196
+ float(last_word in cls.ARTICLES), # ends_article
197
+ float(last_word in cls.SUBJECT_PRONOUNS), # ends_subject_pronoun
198
+ float(last_word in cls.AUXILIARIES), # ends_auxiliary
199
+ float(last_word in cls.COMMON_TRANSITIVE), # ends_transitive
200
+ ends_complete_word, # ends_complete_word
201
+ has_verb, # has_verb
202
+ ends_noun_after_verb, # ends_noun_after_verb
203
+ float(last_word in cls.DISCOURSE_MARKERS), # ends_discourse_marker
204
+ min(num_words / 30.0, 1.0), # norm_word_count
205
+ has_subj_verb, # has_subj_verb
206
+ ends_adverb_after_verb, # ends_adverb_after_verb
207
+ float(num_words <= 2), # is_very_short
208
+ round(content_ratio, 4), # content_ratio
209
+ ]
210
+ return features
211
+
212
+ @classmethod
213
+ def feature_names(cls) -> List[str]:
214
+ return [
215
+ 'ends_conjunction', 'ends_preposition', 'ends_article',
216
+ 'ends_subject_pronoun', 'ends_auxiliary', 'ends_transitive',
217
+ 'ends_complete_word', 'has_verb', 'ends_noun_after_verb',
218
+ 'ends_discourse_marker', 'norm_word_count', 'has_subj_verb',
219
+ 'ends_adverb_after_verb', 'is_very_short', 'content_ratio'
220
+ ]
221
+
222
+
223
+ # ============================================================
224
+ # PyTorch Model (fallback only — kept for compatibility)
225
+ # ============================================================
226
+
227
+ if TORCH_AVAILABLE:
228
+ class DeBERTaEOUClassifier(nn.Module):
229
+ """DeBERTa with auxiliary features for End-of-Utterance detection"""
230
+
231
+ def __init__(self, config: Config, num_aux_features: int = 15): # 15 semantic features
232
+ super().__init__()
233
+ self.config = config
234
+ self.use_aux = config.use_aux_features
235
+
236
+ self.deberta = DebertaV2Model.from_pretrained(config.model_name)
237
+ hidden_size = self.deberta.config.hidden_size
238
+
239
+ self.pooler_dropout = nn.Dropout(config.dropout)
240
+
241
+ if self.use_aux:
242
+ self.aux_projection = nn.Sequential(
243
+ nn.Linear(num_aux_features, 32),
244
+ nn.GELU(),
245
+ nn.Dropout(config.dropout),
246
+ )
247
+ classifier_input_size = hidden_size + 32
248
+ else:
249
+ classifier_input_size = hidden_size
250
+
251
+ self.classifier = nn.Sequential(
252
+ nn.Linear(classifier_input_size, 256),
253
+ nn.GELU(),
254
+ nn.LayerNorm(256),
255
+ nn.Dropout(config.dropout),
256
+ nn.Linear(256, 64),
257
+ nn.GELU(),
258
+ nn.Dropout(config.dropout),
259
+ nn.Linear(64, 2),
260
+ )
261
+
262
+ def forward(self, input_ids, attention_mask, token_type_ids=None,
263
+ aux_features=None, labels=None):
264
+
265
+ outputs = self.deberta(
266
+ input_ids=input_ids,
267
+ attention_mask=attention_mask,
268
+ token_type_ids=token_type_ids,
269
+ )
270
+
271
+ cls_output = outputs.last_hidden_state[:, 0, :]
272
+ cls_output = self.pooler_dropout(cls_output)
273
+
274
+ if self.use_aux and aux_features is not None:
275
+ aux_projected = self.aux_projection(aux_features)
276
+ combined = torch.cat([cls_output, aux_projected], dim=-1)
277
+ else:
278
+ combined = cls_output
279
+
280
+ logits = self.classifier(combined)
281
+
282
+ loss = None
283
+ if labels is not None:
284
+ loss_fn = nn.CrossEntropyLoss(
285
+ label_smoothing=self.config.label_smoothing
286
+ )
287
+ loss = loss_fn(logits, labels)
288
+
289
+ return {'loss': loss, 'logits': logits}
290
+
291
+
292
+ # ============================================================
293
+ # Async Inference Engine (ONNX primary, PyTorch fallback)
294
+ # ============================================================
295
+
296
+ class EOUModelEngine:
297
+ """Async model engine — uses ONNX Runtime for fast inference"""
298
+
299
+ def __init__(self):
300
+ self.onnx_session = None # ONNX Runtime session
301
+ self.torch_model = None # PyTorch model (fallback)
302
+ self.tokenizer: Optional[Any] = None
303
+ self.feature_extractor = SemanticFeatureExtractor()
304
+ self.device = None
305
+ self.threshold: float = 0.5
306
+ self.eou_config: Dict = {}
307
+ self.is_loaded: bool = False
308
+ self.model_dir: str = ""
309
+ self.backend: str = "" # "onnx" or "pytorch"
310
+ self.max_length: int = 128 # Reduced default
311
+
312
+ # Thread pool for blocking operations
313
+ self._executor = ThreadPoolExecutor(max_workers=2)
314
+ self._lock = asyncio.Lock()
315
+
316
+ async def load_model(self, model_dir: str) -> Dict:
317
+ """Load model — prefers ONNX, falls back to PyTorch"""
318
+ async with self._lock:
319
+ logger.info(f"Loading model from {model_dir}...")
320
+ start_time = time.time()
321
+
322
+ try:
323
+ # Load config
324
+ config_path = os.path.join(model_dir, 'eou_config.json')
325
+ if os.path.exists(config_path):
326
+ with open(config_path, 'r') as f:
327
+ self.eou_config = json.load(f)
328
+ self.threshold = self.eou_config.get('best_threshold', 0.5)
329
+ else:
330
+ self.eou_config = {}
331
+ self.threshold = 0.5
332
+
333
+ # Use reduced max_length (128) unless config says otherwise
334
+ self.max_length = min(
335
+ self.eou_config.get('max_length', 128), 128
336
+ )
337
+
338
+ # Load tokenizer (in thread to not block event loop)
339
+ loop = asyncio.get_event_loop()
340
+ self.tokenizer = await loop.run_in_executor(
341
+ self._executor,
342
+ lambda: AutoTokenizer.from_pretrained(model_dir)
343
+ )
344
+ # Try ONNX first (prefer quantized)
345
+ onnx_quantized_path = os.path.join(model_dir, 'eou_model_quantized.onnx')
346
+ onnx_original_path = os.path.join(model_dir, 'eou_model.onnx')
347
+
348
+ onnx_path = onnx_quantized_path if os.path.exists(onnx_quantized_path) else onnx_original_path
349
+
350
+ if ONNX_AVAILABLE and os.path.exists(onnx_path):
351
+ self.backend = "onnx"
352
+ if onnx_path == onnx_quantized_path:
353
+ logger.info("✅ Loading INT8 Quantized ONNX model (ultra fast)")
354
+ else:
355
+ logger.info("✅ Loading Original ONNX model (fast path)")
356
+
357
+ self.onnx_session = await loop.run_in_executor(
358
+ self._executor,
359
+ lambda: self._create_onnx_session(onnx_path)
360
+ )
361
+
362
+ elif TORCH_AVAILABLE:
363
+ self.backend = "pytorch"
364
+ logger.info("⚠️ ONNX model not found, using PyTorch fallback")
365
+ self.device = torch.device(
366
+ "cuda" if torch.cuda.is_available() else "cpu"
367
+ )
368
+
369
+ model_config = Config()
370
+ model_config.model_name = self.eou_config.get(
371
+ 'model_name', 'microsoft/deberta-v3-base'
372
+ )
373
+ model_config.use_aux_features = self.eou_config.get(
374
+ 'use_aux_features', True
375
+ )
376
+ num_aux = self.eou_config.get('num_aux_features', 15)
377
+
378
+ def _load_pytorch():
379
+ model = DeBERTaEOUClassifier(
380
+ model_config, num_aux_features=num_aux
381
+ )
382
+
383
+ # Try to find weights
384
+ for alt in ['model.safetensors', 'pytorch_model.bin', 'pytorch_model_full.pt']:
385
+ alt_path = os.path.join(model_dir, alt)
386
+ if os.path.exists(alt_path):
387
+ if alt.endswith('.safetensors'):
388
+ state_dict = load_safetensors(alt_path, device=str(self.device))
389
+ else:
390
+ state_dict = torch.load(alt_path, map_location=self.device, weights_only=True)
391
+ break
392
+ else:
393
+ raise FileNotFoundError(f"No model weights found in {model_dir}")
394
+ model.load_state_dict(state_dict, strict=False)
395
+ model.to(self.device)
396
+ model.eval()
397
+ return model
398
+
399
+ self.torch_model = await loop.run_in_executor(
400
+ self._executor, _load_pytorch
401
+ )
402
+ else:
403
+ raise RuntimeError(
404
+ "Neither onnxruntime nor torch is available!"
405
+ )
406
+
407
+ self.model_dir = model_dir
408
+ self.is_loaded = True
409
+ load_time = time.time() - start_time
410
+
411
+ info = {
412
+ "status": "loaded",
413
+ "backend": self.backend,
414
+ "model_dir": model_dir,
415
+ "device": str(self.device) if self.device else "cpu",
416
+ "threshold": self.threshold,
417
+ "max_length": self.max_length,
418
+ "load_time_seconds": round(load_time, 2),
419
+ "model_name": self.eou_config.get(
420
+ 'model_name', 'microsoft/deberta-v3-base'
421
+ ),
422
+ "use_aux_features": self.eou_config.get(
423
+ 'use_aux_features', True
424
+ ),
425
+ }
426
+ logger.info(
427
+ f"Model loaded in {load_time:.2f}s "
428
+ f"[backend={self.backend}]"
429
+ )
430
+ return info
431
+
432
+ except Exception as e:
433
+ logger.error(f"Model loading failed: {e}")
434
+ self.is_loaded = False
435
+ raise
436
+
437
+ @staticmethod
438
+ def _create_onnx_session(onnx_path: str):
439
+ """Create an optimized ONNX Runtime session"""
440
+ opts = ort.SessionOptions()
441
+ opts.graph_optimization_level = (
442
+ ort.GraphOptimizationLevel.ORT_ENABLE_ALL
443
+ )
444
+ opts.intra_op_num_threads = os.cpu_count() or 4
445
+ opts.inter_op_num_threads = 2
446
+ opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
447
+
448
+ # Use CPUExecutionProvider (add CUDAExecutionProvider if GPU)
449
+ providers = ['CPUExecutionProvider']
450
+ return ort.InferenceSession(
451
+ onnx_path, sess_options=opts, providers=providers
452
+ )
453
+
454
+ # ----------------------------------------------------------
455
+ # Prediction — ONNX path (fast)
456
+ # ----------------------------------------------------------
457
+
458
+ def _predict_onnx(self, text: str) -> Dict:
459
+ """ONNX Runtime prediction — significantly faster on CPU"""
460
+ start_time = time.time()
461
+
462
+ # Clean text for ASR-trained model (strip punctuation)
463
+ clean_text = TextCleaner.clean(text)
464
+
465
+ # Tokenize with DYNAMIC padding (key optimization!)
466
+ encoding = self.tokenizer(
467
+ clean_text,
468
+ truncation=True,
469
+ max_length=self.max_length,
470
+ padding=True, # Dynamic padding
471
+ return_tensors='np',
472
+ )
473
+
474
+ # Build ONNX input feed
475
+ feed = {
476
+ 'input_ids': encoding['input_ids'].astype(np.int64),
477
+ 'attention_mask': encoding['attention_mask'].astype(np.int64),
478
+ }
479
+
480
+ # Add token_type_ids if the model expects it
481
+ onnx_input_names = [inp.name for inp in self.onnx_session.get_inputs()]
482
+ if 'token_type_ids' in onnx_input_names:
483
+ if 'token_type_ids' in encoding:
484
+ feed['token_type_ids'] = (
485
+ encoding['token_type_ids'].astype(np.int64)
486
+ )
487
+ else:
488
+ feed['token_type_ids'] = np.zeros_like(
489
+ encoding['input_ids'], dtype=np.int64
490
+ )
491
+
492
+ # Add auxiliary features if the model expects them
493
+ if 'aux_features' in onnx_input_names:
494
+ aux = np.array(
495
+ [self.feature_extractor.extract(clean_text)], dtype=np.float32
496
+ )
497
+ feed['aux_features'] = aux
498
+
499
+ # Run inference
500
+ outputs = self.onnx_session.run(None, feed)
501
+ logits = outputs[0] # shape: [1, 2]
502
+
503
+ # Softmax
504
+ exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
505
+ probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
506
+ probs = probs[0]
507
+
508
+ complete_prob = float(probs[1])
509
+ incomplete_prob = float(probs[0])
510
+ is_complete = complete_prob >= self.threshold
511
+
512
+ inference_time = time.time() - start_time
513
+
514
+ # Feature analysis
515
+ features = self.feature_extractor.extract(clean_text)
516
+ feature_names = self.feature_extractor.feature_names()
517
+ feature_analysis = {
518
+ name: round(val, 3) for name, val in zip(feature_names, features)
519
+ }
520
+
521
+ return {
522
+ "text": text,
523
+ "is_complete": is_complete,
524
+ "confidence": round(float(max(probs)), 4),
525
+ "complete_probability": round(complete_prob, 4),
526
+ "incomplete_probability": round(incomplete_prob, 4),
527
+ "threshold": self.threshold,
528
+ "inference_time_ms": round(inference_time * 1000, 2),
529
+ "features": feature_analysis,
530
+ }
531
+
532
+ # ----------------------------------------------------------
533
+ # Prediction — PyTorch path (fallback)
534
+ # ----------------------------------------------------------
535
+
536
+ def _predict_pytorch(self, text: str) -> Dict:
537
+ """PyTorch prediction (fallback if ONNX not available)"""
538
+ start_time = time.time()
539
+
540
+ # Clean text for ASR-trained model (strip punctuation)
541
+ clean_text = TextCleaner.clean(text)
542
+
543
+ encoding = self.tokenizer(
544
+ clean_text,
545
+ truncation=True,
546
+ max_length=self.max_length,
547
+ padding=True, # Dynamic padding fix
548
+ return_tensors='pt',
549
+ )
550
+
551
+ input_ids = encoding['input_ids'].to(self.device)
552
+ attention_mask = encoding['attention_mask'].to(self.device)
553
+ token_type_ids = encoding.get('token_type_ids')
554
+ if token_type_ids is not None:
555
+ token_type_ids = token_type_ids.to(self.device)
556
+
557
+ aux_features = torch.tensor(
558
+ [self.feature_extractor.extract(clean_text)], dtype=torch.float
559
+ ).to(self.device)
560
+
561
+ with torch.no_grad():
562
+ outputs = self.torch_model(
563
+ input_ids=input_ids,
564
+ attention_mask=attention_mask,
565
+ token_type_ids=token_type_ids,
566
+ aux_features=aux_features,
567
+ )
568
+
569
+ probs = torch.softmax(outputs['logits'], dim=-1)[0].cpu().numpy()
570
+ complete_prob = float(probs[1])
571
+ incomplete_prob = float(probs[0])
572
+ is_complete = complete_prob >= self.threshold
573
+
574
+ inference_time = time.time() - start_time
575
+
576
+ features = self.feature_extractor.extract(clean_text)
577
+ feature_names = self.feature_extractor.feature_names()
578
+ feature_analysis = {
579
+ name: round(val, 3) for name, val in zip(feature_names, features)
580
+ }
581
+
582
+ return {
583
+ "text": text,
584
+ "is_complete": is_complete,
585
+ "confidence": round(float(max(probs)), 4),
586
+ "complete_probability": round(complete_prob, 4),
587
+ "incomplete_probability": round(incomplete_prob, 4),
588
+ "threshold": self.threshold,
589
+ "inference_time_ms": round(inference_time * 1000, 2),
590
+ "features": feature_analysis,
591
+ }
592
+
593
+ # ----------------------------------------------------------
594
+ # Public async API
595
+ # ----------------------------------------------------------
596
+
597
+ async def predict(self, text: str) -> Dict:
598
+ """Async prediction — dispatches to ONNX or PyTorch"""
599
+ if not self.is_loaded:
600
+ raise RuntimeError("Model not loaded")
601
+
602
+ loop = asyncio.get_event_loop()
603
+ predict_fn = (
604
+ self._predict_onnx if self.backend == "onnx"
605
+ else self._predict_pytorch
606
+ )
607
+ return await loop.run_in_executor(
608
+ self._executor, predict_fn, text
609
+ )
610
+
611
+ async def predict_batch(
612
+ self, texts: List[str]
613
+ ) -> List[Dict]:
614
+ """Async batch prediction"""
615
+ tasks = [
616
+ self.predict(text) for text in texts
617
+ ]
618
+ return await asyncio.gather(*tasks)
619
+
620
+ async def update_threshold(self, new_threshold: float) -> Dict:
621
+ """Update classification threshold"""
622
+ old_threshold = self.threshold
623
+ self.threshold = max(0.0, min(1.0, new_threshold))
624
+ return {
625
+ "old_threshold": old_threshold,
626
+ "new_threshold": self.threshold,
627
+ }
628
+
629
+ def get_status(self) -> Dict:
630
+ """Get model status"""
631
+ return {
632
+ "is_loaded": self.is_loaded,
633
+ "backend": self.backend,
634
+ "model_dir": self.model_dir,
635
+ "device": str(self.device) if self.device else "cpu",
636
+ "threshold": self.threshold,
637
+ "max_length": self.max_length,
638
+ "config": self.eou_config,
639
+ }
640
+
641
+
642
+ # Singleton instance
643
+ engine = EOUModelEngine()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ huggingface_hub
3
+ torch
4
+ transformers
5
+ safetensors
6
+ sentencepiece
7
+ protobuf
8
+ numpy
9
+ deepgram-sdk<4.0.0
10
+ python-dotenv