Medical-VQA / src /models /medical_vqa_model.py
SpringWang08's picture
Deploy Medical VQA app
d63774a
import torch
import torch.nn as nn
from .encoder import MedicalImageEncoder
from .phobert_encoder import PhoBERTEncoder
from .transformer_decoder import MedicalVQADecoder
class CoAttentionFusion(nn.Module):
    """Fuse image-region and question features with parallel cross-attention.

    Two multi-head attention modules run side by side: one lets the image
    regions attend to the text ("image asks text") and the other lets the
    text attend to the image regions ("text asks image"). The two attended
    results are concatenated and projected back to ``hidden_size``.

    Args:
        hidden_size: Feature width shared by both modalities.
        nhead: Number of heads in each attention module.
        dropout: Dropout probability applied after the fusion projection.
            Defaults to 0.1, the value that used to be hard-coded.
    """

    def __init__(self, hidden_size=768, nhead=8, dropout=0.1):
        super().__init__()
        # NOTE: the attribute names below and the order of layers inside
        # ``fusion_layer`` determine the state_dict keys, so they must stay
        # stable for existing checkpoints to keep loading.
        # Cross-modal attention: image queries text, and text queries image.
        self.v2t_attn = nn.MultiheadAttention(hidden_size, nhead, batch_first=True)
        self.t2v_attn = nn.MultiheadAttention(hidden_size, nhead, batch_first=True)
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, v_feats, t_feats):
        """Return the fused multimodal representation.

        Args:
            v_feats: Image-region features, ``[B, N, hidden_size]`` (the
                original author's note gives ``[B, 49, 768]``). No extra
                sequence axis is added to this input.
            t_feats: Pooled question feature, ``[B, hidden_size]``; a
                length-1 sequence axis is inserted here.

        Returns:
            Tensor of shape ``[B, 1, hidden_size]``.
        """
        # Treat the pooled text vector as a one-token sequence: [B, 1, H].
        t_seq = t_feats.unsqueeze(1)

        # Parallel co-attention in both directions.
        # NOTE(review): the key/value sequence for ``v2t_attn`` has length 1,
        # so the softmax weight is always 1 and every image region receives
        # the same attended vector — worth confirming this is intended.
        v_fused, _ = self.v2t_attn(v_feats, t_seq, t_seq)
        t_fused, _ = self.t2v_attn(t_seq, v_feats, v_feats)

        # Pool the image side over regions ([B, N, H] -> [B, 1, H]) so it can
        # be concatenated with the single-token text side.
        v_fused = v_fused.mean(dim=1, keepdim=True)

        # Combine both directions along the feature axis: [B, 1, 2H].
        combined = torch.cat([v_fused, t_fused], dim=-1)
        return self.fusion_layer(combined)
class MedicalVQAModelA(nn.Module):
    """Modular ("Approach A") architecture for Vietnamese Medical VQA.

    Combines a DenseNet-121 (XRV) image encoder, a PhoBERT text encoder,
    co-attention fusion and a dual-head decoder.

    Args:
        decoder_type: Forwarded to ``MedicalVQADecoder`` (the original notes
            mention LSTM / Transformer variants).
        vocab_size: Currently ignored — the decoder vocabulary size is read
            from the PhoBERT word-embedding matrix instead. Kept so existing
            callers that pass it keep working.
        hidden_size: Feature width used by the fusion layer and the decoder.
        phobert_model: Optional model name passed to ``PhoBERTEncoder``; when
            falsy the encoder's own default is used.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, decoder_type="transformer", vocab_size=30000, hidden_size=768, phobert_model=None, **kwargs):
        super().__init__()
        # 1. Image encoder (DenseNet-121 XRV).
        self.image_encoder = MedicalImageEncoder(pretrained=True)

        # 2. Text encoder (PhoBERT).
        self.text_encoder = PhoBERTEncoder(model_name=phobert_model) if phobert_model else PhoBERTEncoder()

        # 3. Fusion layer (co-attention).
        self.fusion = CoAttentionFusion(hidden_size=hidden_size, nhead=8)

        # 4. Reuse PhoBERT's pretrained word embeddings for the decoder; the
        #    decoder vocabulary size follows the embedding matrix.
        phobert_embeddings = self.text_encoder.bert.embeddings.word_embeddings.weight
        actual_vocab_size = phobert_embeddings.size(0)

        # 5. Decoder (LSTM / Transformer).
        self.decoder = MedicalVQADecoder(
            decoder_type=decoder_type,
            vocab_size=actual_vocab_size,
            hidden_size=hidden_size,
            pretrained_embeddings=phobert_embeddings,
        )

    def _encode_and_fuse(self, images, input_ids, attention_mask):
        """Encode both modalities and return the fused representation."""
        v_feats = self.image_encoder(images)
        t_feats = self.text_encoder(input_ids, attention_mask)
        return self.fusion(v_feats, t_feats)

    def forward(self, images, input_ids, attention_mask, labels_open=None, labels_closed=None):
        """Run the full model and return ``(logits_closed, logits_open)``.

        ``labels_open`` is forwarded to the decoder. ``labels_closed`` is
        accepted but not used by this method.
        """
        fused = self._encode_and_fuse(images, input_ids, attention_mask)
        logits_closed, logits_open = self.decoder(fused, labels_open)
        return logits_closed, logits_open

    def generate(self, images, input_ids, attention_mask, beam_width=1, max_len=10):
        """Inference entry point that returns only the open-ended output.

        Returns whatever ``self.decoder.generate`` produces (token IDs for
        the open-ended answer, per the original notes).
        """
        fused = self._encode_and_fuse(images, input_ids, attention_mask)
        return self.decoder.generate(fused, beam_width=beam_width, max_len=max_len)

    def inference(self, images, input_ids, attention_mask, beam_width=1, max_len=10):
        """Return BOTH dual-head outputs for one batch.

        Returns:
            ``(logits_closed, generated_ids)`` where, per the original notes,
            ``logits_closed`` is ``[B, 2]`` for Yes/No questions (classifier
            head) and ``generated_ids`` is ``[B, max_len]`` for open
            questions (generative head).
        """
        fused = self._encode_and_fuse(images, input_ids, attention_mask)
        # Drop the length-1 sequence axis before the classifier head.
        logits_closed = self.decoder.classifier_head(fused.squeeze(1))
        generated_ids = self.decoder.generate(fused, beam_width=beam_width, max_len=max_len)
        return logits_closed, generated_ids