Spaces:
Running
Running
Update model_loader.py
Browse files- model_loader.py +15 -104
model_loader.py
CHANGED
|
@@ -225,75 +225,7 @@ class SemanticFeatureExtractor:
|
|
| 225 |
# ============================================================
|
| 226 |
|
| 227 |
if TORCH_AVAILABLE:
|
| 228 |
-
|
| 229 |
-
"""AutoModel with auxiliary features for End-of-Utterance detection"""
|
| 230 |
-
|
| 231 |
-
def __init__(self, model_path: str, use_aux: bool = True, num_aux_features: int = 15, dropout: float = 0.1):
|
| 232 |
-
super().__init__()
|
| 233 |
-
self.use_aux = use_aux
|
| 234 |
-
|
| 235 |
-
# Load the configuration purely from local files
|
| 236 |
-
from transformers import AutoConfig, AutoModel
|
| 237 |
-
config = AutoConfig.from_pretrained(model_path, local_files_only=True)
|
| 238 |
-
|
| 239 |
-
# Initialize model architecture WITHOUT downloading base weights
|
| 240 |
-
self.base_model = AutoModel.from_config(config)
|
| 241 |
-
|
| 242 |
-
# DistilBert uses 'dim', others use 'hidden_size'
|
| 243 |
-
hidden_size = getattr(config, 'hidden_size', getattr(config, 'dim', 768))
|
| 244 |
-
|
| 245 |
-
self.pooler_dropout = nn.Dropout(dropout)
|
| 246 |
-
|
| 247 |
-
if self.use_aux:
|
| 248 |
-
self.aux_projection = nn.Sequential(
|
| 249 |
-
nn.Linear(num_aux_features, 32),
|
| 250 |
-
nn.GELU(),
|
| 251 |
-
nn.Dropout(dropout),
|
| 252 |
-
)
|
| 253 |
-
classifier_input_size = hidden_size + 32
|
| 254 |
-
else:
|
| 255 |
-
classifier_input_size = hidden_size
|
| 256 |
-
|
| 257 |
-
self.classifier = nn.Sequential(
|
| 258 |
-
nn.Linear(classifier_input_size, 256),
|
| 259 |
-
nn.GELU(),
|
| 260 |
-
nn.LayerNorm(256),
|
| 261 |
-
nn.Dropout(dropout),
|
| 262 |
-
nn.Linear(256, 64),
|
| 263 |
-
nn.GELU(),
|
| 264 |
-
nn.Dropout(dropout),
|
| 265 |
-
nn.Linear(64, 2),
|
| 266 |
-
)
|
| 267 |
-
|
| 268 |
-
def forward(self, input_ids, attention_mask, token_type_ids=None,
|
| 269 |
-
aux_features=None, labels=None):
|
| 270 |
-
|
| 271 |
-
# DistilBert doesn't accept token_type_ids
|
| 272 |
-
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
|
| 273 |
-
if token_type_ids is not None and "token_type_ids" in self.base_model.forward.__code__.co_varnames:
|
| 274 |
-
model_inputs["token_type_ids"] = token_type_ids
|
| 275 |
-
|
| 276 |
-
outputs = self.base_model(**model_inputs)
|
| 277 |
-
|
| 278 |
-
# Get the CLS token representation (first token)
|
| 279 |
-
cls_output = outputs.last_hidden_state[:, 0, :]
|
| 280 |
-
cls_output = self.pooler_dropout(cls_output)
|
| 281 |
-
|
| 282 |
-
if self.use_aux and aux_features is not None:
|
| 283 |
-
aux_projected = self.aux_projection(aux_features)
|
| 284 |
-
combined = torch.cat([cls_output, aux_projected], dim=-1)
|
| 285 |
-
else:
|
| 286 |
-
combined = cls_output
|
| 287 |
-
|
| 288 |
-
logits = self.classifier(combined)
|
| 289 |
-
|
| 290 |
-
loss = None
|
| 291 |
-
if labels is not None:
|
| 292 |
-
# Default smoothing if none configured
|
| 293 |
-
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)
|
| 294 |
-
loss = loss_fn(logits, labels)
|
| 295 |
-
|
| 296 |
-
return {'loss': loss, 'logits': logits}
|
| 297 |
|
| 298 |
|
| 299 |
|
|
@@ -376,32 +308,13 @@ class EOUModelEngine:
|
|
| 376 |
|
| 377 |
model_config = Config()
|
| 378 |
model_config.model_name = self.eou_config.get(
|
| 379 |
-
'model_name', 'microsoft/deberta-v3-base'
|
| 380 |
-
)
|
| 381 |
-
model_config.use_aux_features = self.eou_config.get(
|
| 382 |
-
'use_aux_features', True
|
| 383 |
-
)
|
| 384 |
-
num_aux = self.eou_config.get('num_aux_features', 15)
|
| 385 |
-
|
| 386 |
def _load_pytorch():
|
| 387 |
-
model
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
|
|
|
| 391 |
)
|
| 392 |
-
|
| 393 |
-
# Try to find weights
|
| 394 |
-
for alt in ['model.safetensors', 'pytorch_model.bin', 'pytorch_model_full.pt']:
|
| 395 |
-
alt_path = os.path.join(model_dir, alt)
|
| 396 |
-
if os.path.exists(alt_path):
|
| 397 |
-
if alt.endswith('.safetensors'):
|
| 398 |
-
state_dict = load_safetensors(alt_path, device=str(self.device))
|
| 399 |
-
else:
|
| 400 |
-
state_dict = torch.load(alt_path, map_location=self.device, weights_only=True)
|
| 401 |
-
break
|
| 402 |
-
else:
|
| 403 |
-
raise FileNotFoundError(f"No model weights found in {model_dir}")
|
| 404 |
-
model.load_state_dict(state_dict, strict=False)
|
| 405 |
model.to(self.device)
|
| 406 |
model.eval()
|
| 407 |
return model
|
|
@@ -564,19 +477,17 @@ class EOUModelEngine:
|
|
| 564 |
if token_type_ids is not None:
|
| 565 |
token_type_ids = token_type_ids.to(self.device)
|
| 566 |
|
| 567 |
-
aux_features = torch.tensor(
|
| 568 |
-
[self.feature_extractor.extract(clean_text)], dtype=torch.float
|
| 569 |
-
).to(self.device)
|
| 570 |
-
|
| 571 |
with torch.no_grad():
|
| 572 |
-
|
| 573 |
-
input_ids
|
| 574 |
-
attention_mask
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
|
|
|
|
|
|
| 578 |
|
| 579 |
-
probs = torch.softmax(outputs
|
| 580 |
complete_prob = float(probs[1])
|
| 581 |
incomplete_prob = float(probs[0])
|
| 582 |
is_complete = complete_prob >= self.threshold
|
|
|
|
| 225 |
# ============================================================
|
| 226 |
|
| 227 |
if TORCH_AVAILABLE:
|
| 228 |
+
from transformers import AutoModelForSequenceClassification
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
|
| 231 |
|
|
|
|
| 308 |
|
| 309 |
model_config = Config()
|
| 310 |
model_config.model_name = self.eou_config.get(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
def _load_pytorch():
|
| 312 |
+
# The model is a standard HF classification model (like DistilBertForSequenceClassification)
|
| 313 |
+
# This natively handles config.json AND strictly loads your model.safetensors weights!
|
| 314 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
| 315 |
+
model_dir,
|
| 316 |
+
local_files_only=True
|
| 317 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
model.to(self.device)
|
| 319 |
model.eval()
|
| 320 |
return model
|
|
|
|
| 477 |
if token_type_ids is not None:
|
| 478 |
token_type_ids = token_type_ids.to(self.device)
|
| 479 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
with torch.no_grad():
|
| 481 |
+
model_inputs = {
|
| 482 |
+
"input_ids": input_ids,
|
| 483 |
+
"attention_mask": attention_mask
|
| 484 |
+
}
|
| 485 |
+
if token_type_ids is not None and "token_type_ids" in self.torch_model.forward.__code__.co_varnames:
|
| 486 |
+
model_inputs["token_type_ids"] = token_type_ids
|
| 487 |
+
|
| 488 |
+
outputs = self.torch_model(**model_inputs)
|
| 489 |
|
| 490 |
+
probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
|
| 491 |
complete_prob = float(probs[1])
|
| 492 |
incomplete_prob = float(probs[0])
|
| 493 |
is_complete = complete_prob >= self.threshold
|