Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from transformers import BartForConditionalGeneration, BartTokenizer | |
| from huggingface_hub import PyTorchModelHubMixin | |
| class EXTABSModel(nn.Module, PyTorchModelHubMixin): | |
| def __init__(self, model_name="facebook/bart-large"): | |
| super(EXTABSModel, self).__init__() | |
| # Load kiến trúc BART gốc | |
| self.bart = BartForConditionalGeneration.from_pretrained(model_name) | |
| # Extractor Head: Lớp tuyến tính để dự đoán tầm quan trọng của token | |
| hidden_size = self.bart.config.hidden_size | |
| self.ext_head = nn.Linear(hidden_size, 1) | |
| def forward(self, input_ids, attention_mask, labels=None, saliency_mask=None, **kwargs): | |
| # 1. Đi qua BART Encoder | |
| encoder_outputs = self.bart.model.encoder( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| return_dict=True | |
| ) | |
| hidden_states = encoder_outputs.last_hidden_state | |
| # 2. Dự đoán Saliency Score qua Extractor Head | |
| ext_logits = self.ext_head(hidden_states).squeeze(-1) | |
| # 3. Logic Saliency Masking (Chế độ Inference) | |
| # Nếu không trong quá trình training, tạo mask từ các logit dương (> 0) | |
| mask_to_use = (ext_logits > 0).float() | |
| # Kết hợp với Attention Mask gốc để ép Decoder tập trung vào các token quan trọng | |
| cross_attention_mask = attention_mask * mask_to_use | |
| # 4. Đi qua BART Decoder để sinh văn bản | |
| outputs = self.bart( | |
| encoder_outputs=encoder_outputs, | |
| attention_mask=cross_attention_mask, | |
| labels=labels, | |
| return_dict=True, | |
| **kwargs | |
| ) | |
| return outputs.logits, ext_logits | |
| def generate_summary(self, input_ids, attention_mask, **gen_kwargs): | |
| """Hàm hỗ trợ sinh văn bản tóm tắt trong chế độ Inference""" | |
| # Tái hiện luồng forward để lấy cross_attention_mask | |
| encoder_outputs = self.bart.model.encoder(input_ids=input_ids, attention_mask=attention_mask) | |
| hidden_states = encoder_outputs.last_hidden_state | |
| ext_logits = self.ext_head(hidden_states).squeeze(-1) | |
| mask_to_use = (ext_logits > 0).float() | |
| cross_attention_mask = attention_mask * mask_to_use | |
| # Gọi hàm generate của BART với mask tùy chỉnh | |
| summary_ids = self.bart.generate( | |
| encoder_outputs=encoder_outputs, | |
| attention_mask=cross_attention_mask, | |
| **gen_kwargs | |
| ) | |
| return summary_ids |