Spaces:
Paused
Paused
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class MedicalVQADecoder(nn.Module): | |
| def __init__( | |
| self, | |
| decoder_type: str = "transformer", | |
| vocab_size: int = 30000, | |
| hidden_size: int = 768, | |
| pretrained_embeddings=None, | |
| num_layers: int = 3, | |
| nhead: int = 8, | |
| dropout: float = 0.1, | |
| ): | |
| super().__init__() | |
| self.decoder_type = decoder_type.lower() | |
| self.vocab_size = vocab_size | |
| self.hidden_size = hidden_size | |
| # ββ NhΓ‘nh 1: Classifier cho Yes/No ββββββββββββββββββββββββββββββββββ | |
| # [FIX] ThΓͺm Dropout + GELU theo best-practice hiα»n ΔαΊ‘i | |
| self.classifier_head = nn.Sequential( | |
| nn.Linear(hidden_size, 512), | |
| nn.GELU(), | |
| nn.Dropout(dropout), | |
| nn.Linear(512, 2), | |
| ) | |
| # ββ NhΓ‘nh 2: Generator βββββββββββββββββββββββββββββββββββββββββββββββ | |
| self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0) | |
| if pretrained_embeddings is not None: | |
| self.embedding.weight.data.copy_(pretrained_embeddings) | |
| if self.decoder_type == "lstm": | |
| self.generator = nn.LSTM( | |
| hidden_size, hidden_size, num_layers=1, batch_first=True | |
| ) | |
| else: | |
| # [FIX A2] Pre-LayerNorm (norm_first=True): hα»i tα»₯ α»n Δα»nh hΖ‘n, giαΊ£m gap A1-A2 | |
| # dim_feedforward=4*hidden (768*4=3072) theo chuαΊ©n Transformer gα»c | |
| decoder_layer = nn.TransformerDecoderLayer( | |
| d_model=hidden_size, | |
| nhead=nhead, | |
| dim_feedforward=hidden_size * 4, | |
| dropout=dropout, | |
| activation="gelu", | |
| batch_first=True, | |
| norm_first=True, | |
| ) | |
| self.generator = nn.TransformerDecoder(decoder_layer, num_layers=num_layers) | |
| self.output_layer = nn.Linear(hidden_size, vocab_size, bias=False) | |
| # [OPTIMIZATION] Weight Tying: chia sαΊ» trα»ng sα» Embedding β Output Projection | |
| # GiαΊ£m ~vocab_size * hidden_size params, cαΊ£i thiα»n generalization (Press & Wolf 2017) | |
| self.output_layer.weight = self.embedding.weight | |
| # [OPTIMIZATION] Cache causal mask Δα» trΓ‘nh re-allocate mα»i forward pass | |
| self._causal_mask_cache: dict[tuple, torch.Tensor] = {} | |
| # ββ Mask helper βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _get_causal_mask(self, sz: int, device: torch.device) -> torch.Tensor: | |
| key = (sz, str(device)) | |
| if key not in self._causal_mask_cache: | |
| mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool() | |
| self._causal_mask_cache[key] = mask | |
| return self._causal_mask_cache[key] | |
| # ββ Public generate API ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def generate(self, fused_features, beam_width: int = 1, max_len: int = 10): | |
| """Sinh cΓ’u trαΊ£ lα»i. TrαΊ£ vα» token IDs [B, max_len].""" | |
| if beam_width <= 1: | |
| return self._greedy_search(fused_features, max_len) | |
| return self._beam_search(fused_features, beam_width, max_len) | |
| # ββ Greedy Search ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _greedy_search(self, fused_features, max_len: int): | |
| """ | |
| Greedy decoding (beam_width=1). | |
| LSTM: chα» feed token cuα»i, h_state giα»― ngα»― cαΊ£nh β trΓ‘nh O(nΒ²) recompute. | |
| TrαΊ£ vα» token IDs [B, max_len]. | |
| """ | |
| batch_size = fused_features.size(0) | |
| device = fused_features.device | |
| generated = torch.zeros((batch_size, 1), dtype=torch.long, device=device) # BOS=0 | |
| h_state = None | |
| for _ in range(max_len): | |
| if self.decoder_type == "lstm": | |
| curr_emb = self.embedding(generated[:, -1:]) # [B,1,H] | |
| if h_state is None: | |
| h0 = fused_features.transpose(0, 1).contiguous() | |
| h_state = (h0, torch.zeros_like(h0)) | |
| outputs, h_state = self.generator(curr_emb, h_state) | |
| else: | |
| curr_emb = self.embedding(generated) | |
| tgt_mask = self._get_causal_mask(generated.size(1), device) | |
| outputs = self.generator(curr_emb, fused_features, tgt_mask=tgt_mask) | |
| next_token = self.output_layer(outputs[:, -1:, :]).argmax(dim=-1) | |
| generated = torch.cat([generated, next_token], dim=1) | |
| return generated[:, 1:] # Bα» BOS | |
| # ββ Beam Search ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _beam_search( | |
| self, | |
| fused_features, | |
| beam_width: int, | |
| max_len: int, | |
| repetition_penalty: float = 1.2, | |
| alpha: float = 0.7, | |
| ): | |
| """ | |
| Beam Search vα»i Length Normalization + Vectorised Repetition Penalty. | |
| [FIX] Thay vΓ²ng for Python sang tensor ops Δα» tΔng tα»c ~3-5Γ trΓͺn GPU. | |
| TrαΊ£ vα» token IDs [B, max_len]. | |
| """ | |
| batch_size = fused_features.size(0) | |
| device = fused_features.device | |
| all_results = [] | |
| for b in range(batch_size): | |
| feat = fused_features[b:b+1] # [1, 1, H] | |
| beams = [(torch.zeros((1, 1), dtype=torch.long, device=device), 0.0, None)] | |
| for _ in range(max_len): | |
| new_beams = [] | |
| for seq, score, h_state in beams: | |
| if seq[0, -1].item() == 2: # EOS | |
| new_beams.append((seq, score, h_state)) | |
| continue | |
| if self.decoder_type == "lstm": | |
| curr_emb = self.embedding(seq[:, -1:]) | |
| if h_state is None: | |
| h0 = feat.transpose(0, 1).contiguous() | |
| h_state = (h0, torch.zeros_like(h0)) | |
| outputs, next_h = self.generator(curr_emb, h_state) | |
| else: | |
| curr_emb = self.embedding(seq) | |
| tgt_mask = self._get_causal_mask(seq.size(1), device) | |
| outputs = self.generator(curr_emb, feat, tgt_mask=tgt_mask) | |
| next_h = None | |
| logits = self.output_layer(outputs[:, -1, :]).squeeze(0) # [V] | |
| # [OPTIMIZED] Vectorised Repetition Penalty (thay vΓ²ng for Python) | |
| unique_ids = seq[0].unique() | |
| valid_ids = unique_ids[(unique_ids != 0) & (unique_ids != 2)] | |
| if valid_ids.numel() > 0: | |
| neg_mask = logits[valid_ids] < 0 | |
| factors = torch.where( | |
| neg_mask, | |
| torch.full_like(logits[valid_ids], repetition_penalty), | |
| torch.full_like(logits[valid_ids], 1.0 / repetition_penalty), | |
| ) | |
| logits = logits.clone() | |
| logits[valid_ids] = logits[valid_ids] * factors | |
| log_probs = F.log_softmax(logits, dim=-1) | |
| topk_log_probs, topk_ids = torch.topk(log_probs, beam_width) | |
| for i in range(beam_width): | |
| new_seq = torch.cat([seq, topk_ids[i].view(1, 1)], dim=1) | |
| new_beams.append((new_seq, score + topk_log_probs[i].item(), next_h)) | |
| def _norm_score(beam): | |
| seq_len = max(beam[0].size(1) - 1, 1) | |
| return beam[1] / (seq_len ** alpha) | |
| new_beams.sort(key=_norm_score, reverse=True) | |
| beams = new_beams[:beam_width] | |
| if all(bm[0][0, -1].item() == 2 for bm in beams): | |
| break | |
| beams.sort(key=_norm_score, reverse=True) | |
| best_seq = beams[0][0][:, 1:] # Bα» BOS | |
| if best_seq.size(1) < max_len: | |
| pad = torch.zeros((1, max_len - best_seq.size(1)), dtype=torch.long, device=device) | |
| best_seq = torch.cat([best_seq, pad], dim=1) | |
| else: | |
| best_seq = best_seq[:, :max_len] | |
| all_results.append(best_seq) | |
| return torch.cat(all_results, dim=0) # [B, max_len] | |
| # ββ Training Forward βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def forward(self, fused_features, target_ids=None, beam_width: int = 1): | |
| """ | |
| fused_features: [B, 1, H] | |
| target_ids: [B, SeqLen] β Teacher Forcing; None β inference | |
| """ | |
| logits_closed = self.classifier_head(fused_features.squeeze(1)) | |
| if target_ids is not None: | |
| target_emb = self.embedding(target_ids) | |
| if self.decoder_type == "lstm": | |
| h0 = fused_features.transpose(0, 1).contiguous() | |
| outputs, _ = self.generator(target_emb, (h0, torch.zeros_like(h0))) | |
| else: | |
| tgt_mask = self._get_causal_mask(target_ids.size(1), target_ids.device) | |
| outputs = self.generator(target_emb, fused_features, tgt_mask=tgt_mask) | |
| logits_open = self.output_layer(outputs) | |
| else: | |
| logits_open = self.generate(fused_features, beam_width=beam_width) | |
| return logits_closed, logits_open | |