Rishi2455 committed on
Commit
6b575fc
·
verified ·
1 Parent(s): b2d7984

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +15 -104
model_loader.py CHANGED
@@ -225,75 +225,7 @@ class SemanticFeatureExtractor:
225
  # ============================================================
226
 
227
  if TORCH_AVAILABLE:
228
- class EOUClassifier(nn.Module):
229
- """AutoModel with auxiliary features for End-of-Utterance detection"""
230
-
231
- def __init__(self, model_path: str, use_aux: bool = True, num_aux_features: int = 15, dropout: float = 0.1):
232
- super().__init__()
233
- self.use_aux = use_aux
234
-
235
- # Load the configuration purely from local files
236
- from transformers import AutoConfig, AutoModel
237
- config = AutoConfig.from_pretrained(model_path, local_files_only=True)
238
-
239
- # Initialize model architecture WITHOUT downloading base weights
240
- self.base_model = AutoModel.from_config(config)
241
-
242
- # DistilBert uses 'dim', others use 'hidden_size'
243
- hidden_size = getattr(config, 'hidden_size', getattr(config, 'dim', 768))
244
-
245
- self.pooler_dropout = nn.Dropout(dropout)
246
-
247
- if self.use_aux:
248
- self.aux_projection = nn.Sequential(
249
- nn.Linear(num_aux_features, 32),
250
- nn.GELU(),
251
- nn.Dropout(dropout),
252
- )
253
- classifier_input_size = hidden_size + 32
254
- else:
255
- classifier_input_size = hidden_size
256
-
257
- self.classifier = nn.Sequential(
258
- nn.Linear(classifier_input_size, 256),
259
- nn.GELU(),
260
- nn.LayerNorm(256),
261
- nn.Dropout(dropout),
262
- nn.Linear(256, 64),
263
- nn.GELU(),
264
- nn.Dropout(dropout),
265
- nn.Linear(64, 2),
266
- )
267
-
268
- def forward(self, input_ids, attention_mask, token_type_ids=None,
269
- aux_features=None, labels=None):
270
-
271
- # DistilBert doesn't accept token_type_ids
272
- model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
273
- if token_type_ids is not None and "token_type_ids" in self.base_model.forward.__code__.co_varnames:
274
- model_inputs["token_type_ids"] = token_type_ids
275
-
276
- outputs = self.base_model(**model_inputs)
277
-
278
- # Get the CLS token representation (first token)
279
- cls_output = outputs.last_hidden_state[:, 0, :]
280
- cls_output = self.pooler_dropout(cls_output)
281
-
282
- if self.use_aux and aux_features is not None:
283
- aux_projected = self.aux_projection(aux_features)
284
- combined = torch.cat([cls_output, aux_projected], dim=-1)
285
- else:
286
- combined = cls_output
287
-
288
- logits = self.classifier(combined)
289
-
290
- loss = None
291
- if labels is not None:
292
- # Default smoothing if none configured
293
- loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)
294
- loss = loss_fn(logits, labels)
295
-
296
- return {'loss': loss, 'logits': logits}
297
 
298
 
299
 
@@ -376,32 +308,13 @@ class EOUModelEngine:
376
 
377
  model_config = Config()
378
  model_config.model_name = self.eou_config.get(
379
- 'model_name', 'microsoft/deberta-v3-base'
380
- )
381
- model_config.use_aux_features = self.eou_config.get(
382
- 'use_aux_features', True
383
- )
384
- num_aux = self.eou_config.get('num_aux_features', 15)
385
-
386
  def _load_pytorch():
387
- model = EOUClassifier(
388
- model_path=model_dir,
389
- use_aux=self.eou_config.get('use_aux_features', True),
390
- num_aux_features=num_aux
 
391
  )
392
-
393
- # Try to find weights
394
- for alt in ['model.safetensors', 'pytorch_model.bin', 'pytorch_model_full.pt']:
395
- alt_path = os.path.join(model_dir, alt)
396
- if os.path.exists(alt_path):
397
- if alt.endswith('.safetensors'):
398
- state_dict = load_safetensors(alt_path, device=str(self.device))
399
- else:
400
- state_dict = torch.load(alt_path, map_location=self.device, weights_only=True)
401
- break
402
- else:
403
- raise FileNotFoundError(f"No model weights found in {model_dir}")
404
- model.load_state_dict(state_dict, strict=False)
405
  model.to(self.device)
406
  model.eval()
407
  return model
@@ -564,19 +477,17 @@ class EOUModelEngine:
564
  if token_type_ids is not None:
565
  token_type_ids = token_type_ids.to(self.device)
566
 
567
- aux_features = torch.tensor(
568
- [self.feature_extractor.extract(clean_text)], dtype=torch.float
569
- ).to(self.device)
570
-
571
  with torch.no_grad():
572
- outputs = self.torch_model(
573
- input_ids=input_ids,
574
- attention_mask=attention_mask,
575
- token_type_ids=token_type_ids,
576
- aux_features=aux_features,
577
- )
 
 
578
 
579
- probs = torch.softmax(outputs['logits'], dim=-1)[0].cpu().numpy()
580
  complete_prob = float(probs[1])
581
  incomplete_prob = float(probs[0])
582
  is_complete = complete_prob >= self.threshold
 
225
  # ============================================================
226
 
227
  if TORCH_AVAILABLE:
228
+ from transformers import AutoModelForSequenceClassification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
 
231
 
 
308
 
309
  model_config = Config()
310
  model_config.model_name = self.eou_config.get(
 
 
 
 
 
 
 
311
  def _load_pytorch():
312
+ # The model is a standard HF classification model (like DistilBertForSequenceClassification)
313
+ # This natively handles config.json AND strictly loads your model.safetensors weights!
314
+ model = AutoModelForSequenceClassification.from_pretrained(
315
+ model_dir,
316
+ local_files_only=True
317
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  model.to(self.device)
319
  model.eval()
320
  return model
 
477
  if token_type_ids is not None:
478
  token_type_ids = token_type_ids.to(self.device)
479
 
 
 
 
 
480
  with torch.no_grad():
481
+ model_inputs = {
482
+ "input_ids": input_ids,
483
+ "attention_mask": attention_mask
484
+ }
485
+ if token_type_ids is not None and "token_type_ids" in self.torch_model.forward.__code__.co_varnames:
486
+ model_inputs["token_type_ids"] = token_type_ids
487
+
488
+ outputs = self.torch_model(**model_inputs)
489
 
490
+ probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
491
  complete_prob = float(probs[1])
492
  incomplete_prob = float(probs[0])
493
  is_complete = complete_prob >= self.threshold