import torch
import torch.nn as nn

from .encoder import MedicalImageEncoder
from .phobert_encoder import PhoBERTEncoder
from .transformer_decoder import MedicalVQADecoder


class CoAttentionFusion(nn.Module):
    """Parallel co-attention fusing image-region and question features.

    Vision queries text and text queries vision via two cross-modal
    multi-head attention layers; both attended views are concatenated and
    projected back to ``hidden_size``, producing one joint vector per sample.

    Args:
        hidden_size: Shared embedding width of both modalities (default 768,
            matching PhoBERT's hidden size).
        nhead: Number of attention heads in each cross-attention layer.
    """

    def __init__(self, hidden_size=768, nhead=8):
        super().__init__()
        # Cross-modal attention: vision attends to text, text attends to vision.
        self.v2t_attn = nn.MultiheadAttention(hidden_size, nhead, batch_first=True)
        self.t2v_attn = nn.MultiheadAttention(hidden_size, nhead, batch_first=True)
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

    def forward(self, v_feats, t_feats):
        """Fuse visual and textual features.

        Args:
            v_feats: Visual region features, ``[B, R, hidden]`` (the original
                comments assume ``[B, 49, 768]`` — already a sequence, so no
                unsqueeze is needed).
            t_feats: Pooled text feature, ``[B, hidden]``.

        Returns:
            Fused representation of shape ``[B, 1, hidden]``.
        """
        t_seq = t_feats.unsqueeze(1)  # [B, 1, hidden] — lift text to a length-1 sequence

        # Parallel co-attention in both directions.
        v_fused, _ = self.v2t_attn(v_feats, t_seq, t_seq)    # [B, R, hidden]
        t_fused, _ = self.t2v_attn(t_seq, v_feats, v_feats)  # [B, 1, hidden]

        # Mean-pool the attended visual sequence to [B, 1, hidden] before concat.
        v_fused = v_fused.mean(dim=1, keepdim=True)

        # Combine information from both directions.
        combined = torch.cat([v_fused, t_fused], dim=-1)  # [B, 1, 2*hidden]
        return self.fusion_layer(combined)                # [B, 1, hidden]


class MedicalVQAModelA(nn.Module):
    """Discrete architecture (Variant A) for Vietnamese Medical VQA.

    Pipeline: DenseNet-121 (XRV) image encoder + PhoBERT text encoder +
    co-attention fusion + dual-head decoder (closed yes/no classifier and
    open-ended generator).

    Args:
        decoder_type: ``"transformer"`` or ``"lstm"``, forwarded to
            :class:`MedicalVQADecoder`.
        vocab_size: Kept for backward compatibility but IGNORED — the decoder
            vocabulary is always taken from PhoBERT's embedding table so the
            pretrained embeddings line up (see ``actual_vocab_size`` below).
        hidden_size: Shared hidden width across encoders, fusion and decoder.
        phobert_model: Optional PhoBERT checkpoint name; falls back to the
            encoder's default when ``None``.
        **kwargs: Accepted and ignored for forward compatibility.
    """

    def __init__(self, decoder_type="transformer", vocab_size=30000,
                 hidden_size=768, phobert_model=None, **kwargs):
        super().__init__()
        # 1. Image encoder (DenseNet-121, XRV pretrained weights).
        self.image_encoder = MedicalImageEncoder(pretrained=True)
        # 2. Text encoder (PhoBERT).
        self.text_encoder = (
            PhoBERTEncoder(model_name=phobert_model) if phobert_model else PhoBERTEncoder()
        )
        # 3. Co-attention fusion layer.
        self.fusion = CoAttentionFusion(hidden_size=hidden_size, nhead=8)
        # 4. Reuse PhoBERT's pretrained word embeddings for the decoder.
        #    NOTE: this deliberately overrides the `vocab_size` argument so the
        #    decoder's embedding table matches PhoBERT's actual vocabulary.
        phobert_embeddings = self.text_encoder.bert.embeddings.word_embeddings.weight
        actual_vocab_size = phobert_embeddings.size(0)
        # 5. Dual-head decoder (LSTM / Transformer).
        self.decoder = MedicalVQADecoder(
            decoder_type=decoder_type,
            vocab_size=actual_vocab_size,
            hidden_size=hidden_size,
            pretrained_embeddings=phobert_embeddings,
        )

    def forward(self, images, input_ids, attention_mask,
                labels_open=None, labels_closed=None):
        """Training/eval forward pass returning both decoder heads.

        Note: ``labels_closed`` is accepted for interface compatibility but is
        not consumed here — only ``labels_open`` is forwarded to the decoder
        (presumably for teacher forcing; confirm against MedicalVQADecoder).
        """
        v_feats = self.image_encoder(images)
        t_feats = self.text_encoder(input_ids, attention_mask)
        fused = self.fusion(v_feats, t_feats)
        logits_closed, logits_open = self.decoder(fused, labels_open)
        return logits_closed, logits_open

    def generate(self, images, input_ids, attention_mask, beam_width=1, max_len=10):
        """Inference helper returning only generated token IDs (open-ended head)."""
        v_feats = self.image_encoder(images)
        t_feats = self.text_encoder(input_ids, attention_mask)
        fused = self.fusion(v_feats, t_feats)
        return self.decoder.generate(fused, beam_width=beam_width, max_len=max_len)

    def inference(self, images, input_ids, attention_mask, beam_width=1, max_len=10):
        """Return BOTH dual-head outputs.

        Returns:
            logits_closed: ``[B, 2]`` — yes/no classifier head.
            generated_ids: ``[B, max_len]`` — open-ended generative head.
        """
        v_feats = self.image_encoder(images)
        t_feats = self.text_encoder(input_ids, attention_mask)
        fused = self.fusion(v_feats, t_feats)
        logits_closed = self.decoder.classifier_head(fused.squeeze(1))  # [B, 2]
        generated_ids = self.decoder.generate(
            fused, beam_width=beam_width, max_len=max_len
        )  # [B, max_len]
        return logits_closed, generated_ids