import torch import torch.nn as nn from transformers import BartForConditionalGeneration, BartTokenizer from huggingface_hub import PyTorchModelHubMixin class EXTABSModel(nn.Module, PyTorchModelHubMixin): def __init__(self, model_name="facebook/bart-large"): super(EXTABSModel, self).__init__() # Load kiến trúc BART gốc self.bart = BartForConditionalGeneration.from_pretrained(model_name) # Extractor Head: Lớp tuyến tính để dự đoán tầm quan trọng của token hidden_size = self.bart.config.hidden_size self.ext_head = nn.Linear(hidden_size, 1) def forward(self, input_ids, attention_mask, labels=None, saliency_mask=None, **kwargs): # 1. Đi qua BART Encoder encoder_outputs = self.bart.model.encoder( input_ids=input_ids, attention_mask=attention_mask, return_dict=True ) hidden_states = encoder_outputs.last_hidden_state # 2. Dự đoán Saliency Score qua Extractor Head ext_logits = self.ext_head(hidden_states).squeeze(-1) # 3. Logic Saliency Masking (Chế độ Inference) # Nếu không trong quá trình training, tạo mask từ các logit dương (> 0) mask_to_use = (ext_logits > 0).float() # Kết hợp với Attention Mask gốc để ép Decoder tập trung vào các token quan trọng cross_attention_mask = attention_mask * mask_to_use # 4. Đi qua BART Decoder để sinh văn bản outputs = self.bart( encoder_outputs=encoder_outputs, attention_mask=cross_attention_mask, labels=labels, return_dict=True, **kwargs ) return outputs.logits, ext_logits def generate_summary(self, input_ids, attention_mask, **gen_kwargs): """Hàm hỗ trợ sinh văn bản tóm tắt trong chế độ Inference""" # Tái hiện luồng forward để lấy cross_attention_mask encoder_outputs = self.bart.model.encoder(input_ids=input_ids, attention_mask=attention_mask) hidden_states = encoder_outputs.last_hidden_state ext_logits = self.ext_head(hidden_states).squeeze(-1) mask_to_use = (ext_logits > 0).float() cross_attention_mask = attention_mask * mask_to_use # Gọi hàm generate của BART với mask tùy chỉnh summary_ids = self.bart.generate( encoder_outputs=encoder_outputs, attention_mask=cross_attention_mask, **gen_kwargs ) return summary_ids