Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from transformers import BartModel, BartTokenizer | |
| from huggingface_hub import PyTorchModelHubMixin | |
class BartExtractiveSummarizer(nn.Module, PyTorchModelHubMixin):
    """Token-level extractive summarizer built on a BART encoder.

    A linear head scores each encoder token; positions flagged in
    ``saliency_mask`` are the positive class for a class-balanced
    BCE-with-logits loss.
    """

    def __init__(self, model_name="facebook/bart-large"):
        super().__init__()
        # Only the encoder half of BART is needed for per-token scoring.
        self.encoder = BartModel.from_pretrained(model_name).encoder
        hidden_size = self.encoder.config.hidden_size
        self.classifier = nn.Linear(hidden_size, 1)
        # Force float32 from the beginning so a half-precision checkpoint
        # cannot leak fp16 weights into the classifier head.
        self.to(torch.float32)

    def forward(self, input_ids, attention_mask, saliency_mask=None, **kwargs):
        """Score every token for saliency.

        Args:
            input_ids: integer token ids; moved to long on the model device.
            attention_mask: 1 for real tokens, 0 for padding; same shape as
                ``input_ids``.
            saliency_mask: optional binary labels (same shape); when given,
                a pos-weighted BCE loss over non-padding positions is
                computed.

        Returns:
            ``{"logits": ...}`` — per-token logits with the trailing
            singleton dim squeezed — plus ``"loss"`` when labels were
            supplied.
        """
        device = next(self.parameters()).device
        input_ids = input_ids.to(torch.long).to(device)
        attention_mask = attention_mask.to(torch.long).to(device)
        if saliency_mask is not None:
            saliency_mask = saliency_mask.to(torch.float32).to(device)

        # Extra safety: ensure encoder stays in float32 even if a caller
        # moved it to half precision after construction.
        # (next(...) is the idiomatic form of parameters().__next__().)
        if next(self.encoder.parameters()).dtype != torch.float32:
            self.encoder = self.encoder.to(torch.float32)

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = encoder_outputs.last_hidden_state.float()
        logits = self.classifier(hidden_states).squeeze(-1)

        loss = None
        if saliency_mask is not None:
            # Restrict the loss to real (non-padding) positions.
            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1)[active_loss]
            active_labels = saliency_mask.view(-1)[active_loss].float()
            # Up-weight the (typically rare) salient tokens: #neg / #pos.
            num_pos = active_labels.sum()
            num_neg = active_labels.size(0) - num_pos
            # .item() extracts a Python float first — torch.tensor(tensor)
            # triggers a UserWarning and an unnecessary copy. Fall back to
            # 1.0 when there are no positives to avoid division by zero.
            ratio = (num_neg / num_pos).item() if num_pos > 0 else 1.0
            weight = torch.tensor(ratio, dtype=torch.float32, device=logits.device)
            loss_fct = nn.BCEWithLogitsLoss(pos_weight=weight)
            loss = loss_fct(active_logits, active_labels)

        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
def get_trigrams(text: str) -> set:
    """Return the set of consecutive 3-word tuples (trigrams) in *text*.

    Used for Trigram Blocking: a candidate sentence is skipped when it
    shares any trigram with the summary built so far. Matching is
    case-insensitive (text is lowercased) and whitespace-tokenized;
    inputs with fewer than 3 words yield an empty set.
    """
    words = text.lower().split()
    # Set comprehension instead of set(generator) — same result, idiomatic.
    return {tuple(words[i:i + 3]) for i in range(len(words) - 2)}