# DemoApp/src/model/baseline_extractive_model.py
# NOTE(review): the following lines were Hugging Face Hub page residue
# ("Reality8081's picture" / "Update src" / commit 1dde759) accidentally
# pasted above the code; kept here as comments so the module parses.
import torch
import torch.nn as nn
import numpy as np
from transformers import BartModel, BartTokenizer
from huggingface_hub import PyTorchModelHubMixin
class BartExtractiveSummarizer(nn.Module, PyTorchModelHubMixin):
    """Token-level extractive summarizer.

    Reuses the encoder half of a pretrained BART model and adds a linear
    head that produces one saliency logit per token. Kept in float32
    throughout to guard against half-precision checkpoints.
    """

    def __init__(self, model_name="facebook/bart-large"):
        """
        Args:
            model_name: Hugging Face model id whose encoder is reused
                (default "facebook/bart-large").
        """
        super().__init__()
        # Only the encoder is needed for extractive (token-scoring) summarization.
        self.encoder = BartModel.from_pretrained(model_name).encoder
        hidden_size = self.encoder.config.hidden_size
        # One saliency logit per token position.
        self.classifier = nn.Linear(hidden_size, 1)
        # Force float32 from the beginning
        self.to(torch.float32)

    def forward(self, input_ids, attention_mask, saliency_mask=None, **kwargs):
        """Score every token for saliency.

        Args:
            input_ids: (batch, seq) token ids; cast to long on the model's device.
            attention_mask: (batch, seq) with 1 for real tokens, 0 for padding.
            saliency_mask: optional (batch, seq) binary token labels; when given,
                a class-balanced BCE loss over non-padding positions is computed.
            **kwargs: ignored — accepted so Trainer-style callers may pass extras.

        Returns:
            {"logits": (batch, seq)} and additionally {"loss": scalar} when
            ``saliency_mask`` was provided.
        """
        device = next(self.parameters()).device
        input_ids = input_ids.to(torch.long).to(device)
        attention_mask = attention_mask.to(torch.long).to(device)
        if saliency_mask is not None:
            saliency_mask = saliency_mask.to(torch.float32).to(device)

        # Extra safety: ensure encoder stays in float32.
        # (idiom fix: next(iterator) instead of iterator.__next__())
        if next(self.encoder.parameters()).dtype != torch.float32:
            self.encoder = self.encoder.to(torch.float32)

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        hidden_states = encoder_outputs.last_hidden_state.float()
        logits = self.classifier(hidden_states).squeeze(-1)

        loss = None
        if saliency_mask is not None:
            # Restrict the loss to non-padding positions.
            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1)[active_loss]
            active_labels = saliency_mask.view(-1)[active_loss].float()

            # Re-balance the (typically rare) positive class via pos_weight.
            num_pos = active_labels.sum()
            num_neg = active_labels.size(0) - num_pos
            if num_pos > 0:
                # BUGFIX: the original wrapped this existing tensor in
                # torch.tensor(...), which emits a UserWarning and detaches;
                # use the tensor directly, only normalizing dtype/device.
                pos_weight = (num_neg / num_pos).to(dtype=torch.float32,
                                                    device=logits.device)
            else:
                pos_weight = torch.tensor(1.0, dtype=torch.float32,
                                          device=logits.device)
            loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            loss = loss_fct(active_logits, active_labels)

        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
def get_trigrams(text: str):
    """Return the set of consecutive 3-word tuples in *text* (trigram blocking).

    The text is lower-cased and whitespace-tokenized first; texts shorter
    than three words yield an empty set.
    """
    tokens = text.lower().split()
    return set(zip(tokens, tokens[1:], tokens[2:]))