import torch
import torch.nn as nn
import numpy as np
from transformers import BartModel, BartTokenizer
from huggingface_hub import PyTorchModelHubMixin


class BartExtractiveSummarizer(nn.Module, PyTorchModelHubMixin):
    """Token-level extractive summarizer: a BART encoder with a binary
    saliency classification head on top of each token's hidden state."""

    def __init__(self, model_name="facebook/bart-large"):
        super().__init__()
        # Only the encoder half of BART is needed for extractive scoring
        self.encoder = BartModel.from_pretrained(model_name).encoder
        hidden_size = self.encoder.config.hidden_size
        self.classifier = nn.Linear(hidden_size, 1)
        # Force float32 from the beginning
        self.to(torch.float32)

    def forward(self, input_ids, attention_mask, saliency_mask=None, **kwargs):
        device = next(self.parameters()).device
        input_ids = input_ids.to(torch.long).to(device)
        attention_mask = attention_mask.to(torch.long).to(device)
        if saliency_mask is not None:
            saliency_mask = saliency_mask.to(torch.float32).to(device)

        # Extra safety: ensure the encoder stays in float32
        if next(self.encoder.parameters()).dtype != torch.float32:
            self.encoder = self.encoder.to(torch.float32)

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = encoder_outputs.last_hidden_state.float()
        # One saliency logit per token: shape (batch, seq_len)
        logits = self.classifier(hidden_states).squeeze(-1)

        loss = None
        if saliency_mask is not None:
            # Compute the loss only over real tokens; padding is excluded
            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1)[active_loss]
            active_labels = saliency_mask.view(-1)[active_loss].float()

            # Re-weight the positive class to counter label imbalance
            # (salient tokens are typically far rarer than non-salient ones)
            num_pos = active_labels.sum()
            num_neg = active_labels.size(0) - num_pos
            if num_pos > 0:
                pos_weight = (num_neg / num_pos).to(logits.device)
            else:
                pos_weight = torch.tensor(1.0, device=logits.device)

            loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            loss = loss_fct(active_logits, active_labels)

        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}


def get_trigrams(text: str):
    """Build the set of all consecutive 3-word tuples in a text (Trigram Blocking)."""
    words = text.lower().split()
    return set(tuple(words[i:i + 3]) for i in range(len(words) - 2))
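
# --- Usage sketch (illustrative addition, not from the original source) ----
# A minimal end-to-end example of how the pieces above could fit together:
# score each sentence of a document by its mean token-level saliency logit,
# then greedily select high-scoring sentences while skipping any candidate
# that repeats a trigram already chosen (the trigram-blocking idea behind
# get_trigrams). The naive period-based sentence splitter, `max_sentences`,
# and the mean-logit scoring rule are assumptions made for illustration;
# the model above does not dictate them.

def summarize(text: str, model, tokenizer, max_sentences: int = 3) -> str:
    # Naive sentence splitting on periods (assumption; a real pipeline
    # would use a proper sentence tokenizer)
    sentences = [s.strip() for s in text.split(".") if s.strip()]

    scores = []
    model.eval()
    with torch.no_grad():
        for sent in sentences:
            enc = tokenizer(sent, return_tensors="pt", truncation=True, max_length=128)
            out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
            # Mean saliency logit over the sentence's tokens (assumption)
            scores.append(out["logits"].mean().item())

    selected, seen_trigrams = [], set()
    # Greedy selection by score, with trigram blocking to reduce redundancy
    for idx in sorted(range(len(sentences)), key=lambda i: scores[i], reverse=True):
        trigrams = get_trigrams(sentences[idx])
        if trigrams & seen_trigrams:
            continue  # overlaps an already-selected sentence; skip it
        selected.append(idx)
        seen_trigrams |= trigrams
        if len(selected) == max_sentences:
            break

    # Restore original document order for readability
    return ". ".join(sentences[i] for i in sorted(selected)) + "."


# Example call, assuming the same checkpoint the class defaults to:
# tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
# model = BartExtractiveSummarizer()
# print(summarize(long_document, model, tokenizer))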