import torch
import torch.nn as nn
import torch.nn.functional as F
class MedicalVQADecoder(nn.Module):
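    """Dual-head decoder for medical VQA.

    Branch 1 classifies closed-ended (Yes/No) questions from the fused
    image-question features; branch 2 generates free-text answers with either
    an LSTM or a Transformer decoder, selected via `decoder_type`.
    """
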
def __init__(
self,
decoder_type: str = "transformer",
vocab_size: int = 30000,
hidden_size: int = 768,
pretrained_embeddings=None,
num_layers: int = 3,
nhead: int = 8,
dropout: float = 0.1,
):
super().__init__()
self.decoder_type = decoder_type.lower()
self.vocab_size = vocab_size
self.hidden_size = hidden_size
        # ── Branch 1: classifier for closed-ended (Yes/No) questions ────────
        # [FIX] Added Dropout + GELU per modern best practice; maps the fused
        # [B, H] features to two-way yes/no logits.
self.classifier_head = nn.Sequential(
nn.Linear(hidden_size, 512),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(512, 2),
)
        # ── Branch 2: Generator (free-text answers) ──────────────────────────
self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
if pretrained_embeddings is not None:
self.embedding.weight.data.copy_(pretrained_embeddings)
        if self.decoder_type == "lstm":
            # Single-layer LSTM; the `num_layers` argument only applies to the
            # Transformer branch below.
            self.generator = nn.LSTM(
                hidden_size, hidden_size, num_layers=1, batch_first=True
            )
else:
            # [FIX A2] Pre-LayerNorm (norm_first=True): more stable convergence,
            # narrows the A1-A2 gap.
            # dim_feedforward = 4 * hidden (768 * 4 = 3072) per the original
            # Transformer convention.
decoder_layer = nn.TransformerDecoderLayer(
d_model=hidden_size,
nhead=nhead,
dim_feedforward=hidden_size * 4,
dropout=dropout,
activation="gelu",
batch_first=True,
norm_first=True,
)
self.generator = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
self.output_layer = nn.Linear(hidden_size, vocab_size, bias=False)
        # [OPTIMIZATION] Weight tying: share the Embedding and output-projection
        # weights. Saves ~vocab_size * hidden_size parameters (30000 * 768 ≈ 23M
        # with the defaults) and improves generalization (Press & Wolf, 2017).
self.output_layer.weight = self.embedding.weight
        # [OPTIMIZATION] Cache causal masks, keyed by (size, device), to avoid
        # re-allocating one on every forward pass.
self._causal_mask_cache: dict[tuple, torch.Tensor] = {}
# ── Mask helper ─────────────────────────────────────────────────────────
def _get_causal_mask(self, sz: int, device: torch.device) -> torch.Tensor:
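        """Boolean causal mask; True marks disallowed (future) positions.

        For sz=3 the cached mask looks like:
            [[False,  True,  True],
             [False, False,  True],
             [False, False, False]]
        """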
key = (sz, str(device))
if key not in self._causal_mask_cache:
mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool()
self._causal_mask_cache[key] = mask
return self._causal_mask_cache[key]
# ── Public generate API ──────────────────────────────────────────────────
def generate(self, fused_features, beam_width: int = 1, max_len: int = 10):
"""Sinh cÒu trả lời. Trả về token IDs [B, max_len]."""
if beam_width <= 1:
return self._greedy_search(fused_features, max_len)
return self._beam_search(fused_features, beam_width, max_len)
# ── Greedy Search ────────────────────────────────────────────────────────
def _greedy_search(self, fused_features, max_len: int):
"""
Greedy decoding (beam_width=1).
LSTM: chỉ feed token cuα»‘i, h_state giα»― ngα»― cαΊ£nh β†’ trΓ‘nh O(nΒ²) recompute.
Trả về token IDs [B, max_len].
"""
batch_size = fused_features.size(0)
device = fused_features.device
        generated = torch.zeros((batch_size, 1), dtype=torch.long, device=device)  # BOS = 0 (same id as padding_idx)
h_state = None
for _ in range(max_len):
if self.decoder_type == "lstm":
curr_emb = self.embedding(generated[:, -1:]) # [B,1,H]
if h_state is None:
h0 = fused_features.transpose(0, 1).contiguous()
h_state = (h0, torch.zeros_like(h0))
outputs, h_state = self.generator(curr_emb, h_state)
else:
curr_emb = self.embedding(generated)
tgt_mask = self._get_causal_mask(generated.size(1), device)
outputs = self.generator(curr_emb, fused_features, tgt_mask=tgt_mask)
next_token = self.output_layer(outputs[:, -1:, :]).argmax(dim=-1)
generated = torch.cat([generated, next_token], dim=1)
        return generated[:, 1:]  # drop BOS
# ── Beam Search ──────────────────────────────────────────────────────────
def _beam_search(
self,
fused_features,
beam_width: int,
max_len: int,
repetition_penalty: float = 1.2,
alpha: float = 0.7,
):
"""
Beam Search vα»›i Length Normalization + Vectorised Repetition Penalty.
[FIX] Thay vΓ²ng for Python sang tensor ops để tΔƒng tα»‘c ~3-5Γ— trΓͺn GPU.
Trả về token IDs [B, max_len].
"""
batch_size = fused_features.size(0)
device = fused_features.device
all_results = []
for b in range(batch_size):
feat = fused_features[b:b+1] # [1, 1, H]
            beams = [(torch.zeros((1, 1), dtype=torch.long, device=device), 0.0, None)]

            def _norm_score(beam):
                # Length-normalised score: cumulative log-prob / len ** alpha.
                seq_len = max(beam[0].size(1) - 1, 1)
                return beam[1] / (seq_len ** alpha)
for _ in range(max_len):
new_beams = []
for seq, score, h_state in beams:
if seq[0, -1].item() == 2: # EOS
new_beams.append((seq, score, h_state))
continue
if self.decoder_type == "lstm":
curr_emb = self.embedding(seq[:, -1:])
if h_state is None:
h0 = feat.transpose(0, 1).contiguous()
h_state = (h0, torch.zeros_like(h0))
outputs, next_h = self.generator(curr_emb, h_state)
else:
curr_emb = self.embedding(seq)
tgt_mask = self._get_causal_mask(seq.size(1), device)
outputs = self.generator(curr_emb, feat, tgt_mask=tgt_mask)
next_h = None
logits = self.output_layer(outputs[:, -1, :]).squeeze(0) # [V]
                    # [OPTIMIZED] Vectorised repetition penalty (replaces the
                    # per-token Python loop).
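                    # CTRL-style rule (Keskar et al., 2019): positive logits are
                    # divided by the penalty, negative ones multiplied by it, so
                    # every previously generated token becomes less likely.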
unique_ids = seq[0].unique()
valid_ids = unique_ids[(unique_ids != 0) & (unique_ids != 2)]
if valid_ids.numel() > 0:
neg_mask = logits[valid_ids] < 0
factors = torch.where(
neg_mask,
torch.full_like(logits[valid_ids], repetition_penalty),
torch.full_like(logits[valid_ids], 1.0 / repetition_penalty),
)
logits = logits.clone()
logits[valid_ids] = logits[valid_ids] * factors
log_probs = F.log_softmax(logits, dim=-1)
topk_log_probs, topk_ids = torch.topk(log_probs, beam_width)
for i in range(beam_width):
new_seq = torch.cat([seq, topk_ids[i].view(1, 1)], dim=1)
new_beams.append((new_seq, score + topk_log_probs[i].item(), next_h))
new_beams.sort(key=_norm_score, reverse=True)
beams = new_beams[:beam_width]
if all(bm[0][0, -1].item() == 2 for bm in beams):
break
beams.sort(key=_norm_score, reverse=True)
            best_seq = beams[0][0][:, 1:]  # drop BOS
if best_seq.size(1) < max_len:
pad = torch.zeros((1, max_len - best_seq.size(1)), dtype=torch.long, device=device)
best_seq = torch.cat([best_seq, pad], dim=1)
else:
best_seq = best_seq[:, :max_len]
all_results.append(best_seq)
return torch.cat(all_results, dim=0) # [B, max_len]
# ── Training Forward ─────────────────────────────────────────────────────
def forward(self, fused_features, target_ids=None, beam_width: int = 1):
"""
fused_features: [B, 1, H]
target_ids: [B, SeqLen] β€” Teacher Forcing; None β†’ inference
"""
logits_closed = self.classifier_head(fused_features.squeeze(1))
if target_ids is not None:
target_emb = self.embedding(target_ids)
if self.decoder_type == "lstm":
h0 = fused_features.transpose(0, 1).contiguous()
outputs, _ = self.generator(target_emb, (h0, torch.zeros_like(h0)))
else:
tgt_mask = self._get_causal_mask(target_ids.size(1), target_ids.device)
outputs = self.generator(target_emb, fused_features, tgt_mask=tgt_mask)
logits_open = self.output_layer(outputs)
        else:
            # NOTE: despite the name, `logits_open` holds decoded token IDs here.
            logits_open = self.generate(fused_features, beam_width=beam_width)
return logits_closed, logits_open
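

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: the random `fused` tensor stands in
    # for the [B, 1, H] output of the upstream fusion module (not defined in
    # this file), and the tiny vocab/hidden sizes are hypothetical.
    torch.manual_seed(0)
    decoder = MedicalVQADecoder(vocab_size=100, hidden_size=32, nhead=4)
    fused = torch.randn(2, 1, 32)
    targets = torch.randint(3, 100, (2, 6))  # avoid PAD/BOS (0) and EOS (2)

    logits_closed, logits_open = decoder(fused, target_ids=targets)
    print(logits_closed.shape)  # torch.Size([2, 2])
    print(logits_open.shape)    # torch.Size([2, 6, 100])

    with torch.no_grad():
        ids = decoder.generate(fused, beam_width=3, max_len=5)
    print(ids.shape)            # torch.Size([2, 5])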