DemoApp / src /model /extabs.py
Reality8081's picture
Update src
1dde759
import torch
import torch.nn as nn
from transformers import BartForConditionalGeneration, BartTokenizer
from huggingface_hub import PyTorchModelHubMixin
class EXTABSModel(nn.Module, PyTorchModelHubMixin):
def __init__(self, model_name="facebook/bart-large"):
super(EXTABSModel, self).__init__()
# Load kiến trúc BART gốc
self.bart = BartForConditionalGeneration.from_pretrained(model_name)
# Extractor Head: Lớp tuyến tính để dự đoán tầm quan trọng của token
hidden_size = self.bart.config.hidden_size
self.ext_head = nn.Linear(hidden_size, 1)
def forward(self, input_ids, attention_mask, labels=None, saliency_mask=None, **kwargs):
# 1. Đi qua BART Encoder
encoder_outputs = self.bart.model.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
)
hidden_states = encoder_outputs.last_hidden_state
# 2. Dự đoán Saliency Score qua Extractor Head
ext_logits = self.ext_head(hidden_states).squeeze(-1)
# 3. Logic Saliency Masking (Chế độ Inference)
# Nếu không trong quá trình training, tạo mask từ các logit dương (> 0)
mask_to_use = (ext_logits > 0).float()
# Kết hợp với Attention Mask gốc để ép Decoder tập trung vào các token quan trọng
cross_attention_mask = attention_mask * mask_to_use
# 4. Đi qua BART Decoder để sinh văn bản
outputs = self.bart(
encoder_outputs=encoder_outputs,
attention_mask=cross_attention_mask,
labels=labels,
return_dict=True,
**kwargs
)
return outputs.logits, ext_logits
def generate_summary(self, input_ids, attention_mask, **gen_kwargs):
"""Hàm hỗ trợ sinh văn bản tóm tắt trong chế độ Inference"""
# Tái hiện luồng forward để lấy cross_attention_mask
encoder_outputs = self.bart.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
hidden_states = encoder_outputs.last_hidden_state
ext_logits = self.ext_head(hidden_states).squeeze(-1)
mask_to_use = (ext_logits > 0).float()
cross_attention_mask = attention_mask * mask_to_use
# Gọi hàm generate của BART với mask tùy chỉnh
summary_ids = self.bart.generate(
encoder_outputs=encoder_outputs,
attention_mask=cross_attention_mask,
**gen_kwargs
)
return summary_ids