# Medical-VQA / src/engine/trainer.py
# Deployed to the Medical VQA Hugging Face Space by SpringWang08 (commit d63774a).
import wandb
import torch
import torch.nn as nn
from tqdm import tqdm
import os
import csv
import json
from src.utils.early_stopping import DynamicClassWeights
class MedicalVQATrainer:
    """Training/validation engine for the dual-head (closed + open-ended) medical VQA model."""

    def __init__(self, model, train_loader, val_loader, optimizer, device, config, scheduler=None, pad_token_id=0, beam_width=1):
        """Wire up the model, data loaders, loss criteria and AMP state."""
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.config = config
        self.beam_width = beam_width

        # Class weights are derived from the actual Yes/No distribution of the
        # training data rather than a hard-coded imbalance ratio.
        dynamic_weights = DynamicClassWeights.compute_weights(train_loader, device=device)
        self.criterion_closed = nn.CrossEntropyLoss(weight=dynamic_weights)

        # Label smoothing applies only to the open-ended (generative) head;
        # the closed head keeps sharp 0/1 targets.
        smoothing = config['train'].get('label_smoothing', 0.0)
        self.criterion_open = nn.CrossEntropyLoss(
            ignore_index=pad_token_id,
            label_smoothing=smoothing
        )
        self.criterion_closed_hard = nn.CrossEntropyLoss(weight=dynamic_weights)  # no smoothing

        # Automatic Mixed Precision is enabled only when requested AND on CUDA.
        self.use_amp = config['train'].get('use_amp', False) and device.type == 'cuda'
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

        # Per-epoch records appended by train(); serialized by save_history().
        self.history = []
@staticmethod
def _flatten_dict(data, parent_key="", sep="."):
items = {}
for key, value in data.items():
new_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
if isinstance(value, dict):
items.update(MedicalVQATrainer._flatten_dict(value, new_key, sep=sep))
elif isinstance(value, (list, tuple)):
continue
else:
items[new_key] = value
return items
def save_history(self, output_dir):
os.makedirs(output_dir, exist_ok=True)
json_path = os.path.join(output_dir, "history.json")
csv_path = os.path.join(output_dir, "history.csv")
with open(json_path, "w", encoding="utf-8") as f:
json.dump(self.history, f, ensure_ascii=False, indent=2)
flat_rows = [self._flatten_dict(row) for row in self.history]
if flat_rows:
fieldnames = sorted({key for row in flat_rows for key in row.keys()})
with open(csv_path, "w", encoding="utf-8", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(flat_rows)
@staticmethod
def _compute_closed_weights(train_loader):
"""Đếm phΓ’n phα»‘i Yes/No vΓ  tΓ­nh inverse frequency weights."""
counts = {0: 0, 1: 0} # 0=khΓ΄ng, 1=cΓ³
for batch in train_loader:
labels = batch['label_closed']
for lbl in labels:
v = lbl.item()
if v in counts:
counts[v] += 1
total = counts[0] + counts[1]
if total == 0:
return torch.ones(2)
# Inverse frequency: class Γ­t mαΊ«u β†’ weight cao hΖ‘n
w0 = total / (2 * max(counts[0], 1))
w1 = total / (2 * max(counts[1], 1))
weights = torch.tensor([w0, w1], dtype=torch.float32)
print(f"[INFO] Closed question distribution: khΓ΄ng={counts[0]}, cΓ³={counts[1]}")
return weights
    def train_epoch(self, epoch):
        """
        Run one training epoch with AMP and gradient accumulation.

        Returns the average per-batch loss (un-normalized by accumulation steps).
        """
        self.model.train()
        total_loss = 0
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        # Gradient accumulation simulates a larger effective batch size.
        accumulation_steps = self.config['train'].get('gradient_accumulation_steps', 2)
        for batch_idx, batch in enumerate(pbar):
            images = batch['image'].to(self.device)
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            label_closed = batch['label_closed'].to(self.device)
            target_ids = batch['target_ids'].to(self.device)
            # Zero gradients only at the start of each accumulation window.
            if batch_idx % accumulation_steps == 0:
                self.optimizer.zero_grad()
            # Forward pass under AMP autocast (no-op when use_amp is False).
            with torch.cuda.amp.autocast(enabled=self.use_amp):
                # Teacher forcing: decoder input is "<s> A B", target is "A B </s>".
                decoder_input = target_ids[:, :-1]
                decoder_target = target_ids[:, 1:]
                logits_closed, logits_open = self.model(images, input_ids, attention_mask, decoder_input)
                # Loss calculation. label_closed == -1 marks open-ended samples.
                loss = 0
                mask_closed = (label_closed != -1)
                if mask_closed.any():
                    loss += self.criterion_closed(logits_closed[mask_closed], label_closed[mask_closed])
                # Split the generator loss to fight mode collapse ("lazy" generation).
                vocab_size = logits_open.size(-1)
                mask_open = (label_closed == -1)
                # 1. Yes/No questions: keep the generation-loss weight very low (0.1)
                #    so the generator is not biased toward closed-style answers.
                if mask_closed.any():
                    loss_gen_closed = self.criterion_open(logits_open[mask_closed].reshape(-1, vocab_size), decoder_target[mask_closed].reshape(-1))
                    loss += loss_gen_closed * 0.1
                # 2. Open-ended questions: higher weight + length penalty + coverage penalty.
                if mask_open.any():
                    open_logits = logits_open[mask_open]
                    open_targets = decoder_target[mask_open]
                    loss_gen_open = self.criterion_open(open_logits.reshape(-1, vocab_size), open_targets.reshape(-1))
                    # Length penalty: punish emitting too few meaningful tokens.
                    # NOTE(review): computed from *target* (ground-truth) lengths, not the
                    # model's predictions — presumably intended as a proxy; confirm intent.
                    pred_lengths = (open_targets != self.criterion_open.ignore_index).float().sum(dim=-1).mean()
                    length_penalty = torch.clamp(1.0 - pred_lengths / 15.0, min=0.0)
                    # Entropy penalty in place of a coverage loss: punish the model
                    # for being over-confident in a single token.
                    probs = torch.softmax(open_logits, dim=-1)  # [N, seq, vocab]
                    entropy = -(probs * torch.log(probs + 1e-9)).sum(dim=-1).mean()
                    coverage_loss = torch.clamp(2.0 - entropy, min=0.0)  # penalize when entropy < 2.0
                    # Weight tuned 3.0 -> 2.0: the open head was dominating and
                    # closed-head accuracy plateaued (observed in A1/A2 runs).
                    open_loss_weight = self.config.get('open_loss_weight', 2.0)
                    loss += (loss_gen_open + 0.3 * length_penalty + 0.1 * coverage_loss) * open_loss_weight
            # Normalize by accumulation steps so summed gradients are scaled properly.
            loss = loss / accumulation_steps
            # Backward through the GradScaler (handles fp16 loss scaling).
            self.scaler.scale(loss).backward()
            # Update weights only after accumulating gradients (or at the final batch).
            is_last_batch = (batch_idx + 1) == len(self.train_loader)
            if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
                # Gradient clipping (gradients must be unscaled first under AMP).
                if self.config['train'].get('grad_clip'):
                    self.scaler.unscale_(self.optimizer)
                    nn.utils.clip_grad_norm_(self.model.parameters(), self.config['train']['grad_clip'])
                self.scaler.step(self.optimizer)
                self.scaler.update()
                # Step the scheduler per optimizer update (not per epoch) for
                # smoother warmup. NOTE(review): original comment said "per batch";
                # placement inside the accumulation branch steps once per update — confirm.
                if self.scheduler:
                    self.scheduler.step()
            # Undo the normalization so the logged value is the true batch loss.
            total_loss += loss.item() * accumulation_steps
            # Log the LR of each param group — decoder LR (last group) shown on the bar.
            decoder_lr = self.optimizer.param_groups[-1]['lr']
            vision_lr = self.optimizer.param_groups[0]['lr']
            if wandb.run:
                wandb.log({
                    "batch_loss": loss.item(),
                    "lr_vision": vision_lr,
                    "lr_decoder": decoder_lr,
                })
            pbar.set_postfix({"loss": f"{loss.item():.3f}", "dec_lr": f"{decoder_lr:.1e}", "vis_lr": f"{vision_lr:.1e}"})
        epoch_train_loss = total_loss / len(self.train_loader)
        if wandb.run:
            wandb.log({"train_loss_epoch": epoch_train_loss})
        return epoch_train_loss
    def val_epoch(self, tokenizer, epoch=0):
        """
        Run evaluation on the validation set after each epoch.

        Returns the metrics dict produced by ``evaluate_vqa`` (accuracy, F1,
        BLEU-4, BERTScore, plus normalized variants when available).
        """
        # Imported lazily to avoid a circular import at module load time.
        from src.engine.medical_eval import evaluate_vqa
        max_ans_len = self.config.get('data', {}).get('max_answer_len', 32)
        max_words = self.config.get('data', {}).get('answer_max_words', 10)
        print(f"\n🔍 Đang chạy Validation cho Epoch {epoch} (max_ans_len={max_ans_len})...")
        metrics = evaluate_vqa(
            self.model,
            self.val_loader,
            self.device,
            tokenizer,
            beam_width=self.beam_width,
            max_len=max_ans_len,
            max_words=max_words
        )
        # Print the key metrics (prefer the normalized variants when present).
        print(
            f"[METRICS] Accuracy: {metrics.get('accuracy_normalized', metrics['accuracy']):.4f} | "
            f"F1: {metrics.get('f1_normalized', metrics['f1']):.4f} | "
            f"BLEU-4: {metrics.get('bleu4_normalized', metrics['bleu4']):.4f}"
        )
        if wandb.run:
            wandb.log({
                "epoch": epoch,
                "val_accuracy": metrics["accuracy"],
                "val_accuracy_normalized": metrics.get("accuracy_normalized", metrics["accuracy"]),
                "val_f1": metrics["f1"],
                "val_f1_normalized": metrics.get("f1_normalized", metrics["f1"]),
                "val_bleu4": metrics["bleu4"],
                "val_bleu4_normalized": metrics.get("bleu4_normalized", metrics["bleu4"]),
                "val_bert_score": metrics.get("bert_score", 0),
                "val_bert_score_raw": metrics.get("bert_score_raw", metrics.get("bert_score", 0)),
                "val_semantic_raw": metrics.get("semantic_raw", metrics.get("semantic", 0)),
            })
        return metrics
    def train(self, epochs, tokenizer=None):
        """
        Full training loop: train/validate per epoch, checkpoint the best model,
        apply early stopping, persist history, and plot charts at the end.

        Returns the list of per-epoch history records.
        """
        best_val_acc = 0.0
        patience = self.config['train'].get('patience', 10)
        counter = 0  # epochs since the last improvement (early-stopping counter)
        ckpt_dir = "checkpoints"
        os.makedirs(ckpt_dir, exist_ok=True)
        history_dir = self.config.get("history_dir")
        print(f"[INFO] Bắt đầu huấn luyện trong {epochs} epochs...")
        # Log run configuration to WandB if a run is active.
        if wandb.run is not None:
            wandb.config.update({
                'total_epochs': epochs,
                'patience': patience,
                'variant': self.config.get('variant', 'Unknown'),
                'device': str(self.device),
                'use_amp': self.use_amp,
            })
        for epoch in range(1, epochs + 1):
            train_loss = self.train_epoch(epoch)
            metrics = self.val_epoch(tokenizer, epoch=epoch)
            # Normalized accuracy is the model-selection criterion when available.
            val_acc = metrics.get('accuracy_normalized', metrics.get('accuracy', 0))
            closed_eval = metrics.get("closed_eval", {})
            open_eval = metrics.get("open_eval", {})
            is_best = val_acc > best_val_acc
            # Flat record for history.json/csv; -1 marks metrics that were absent.
            epoch_record = {
                "epoch": epoch,
                "train_loss": float(train_loss),
                "val_accuracy": float(metrics.get("accuracy", 0.0)),
                "val_accuracy_normalized": float(metrics.get("accuracy_normalized", metrics.get("accuracy", 0.0))),
                "val_f1": float(metrics.get("f1", 0.0)),
                "val_f1_normalized": float(metrics.get("f1_normalized", metrics.get("f1", 0.0))),
                "val_bleu4": float(metrics.get("bleu4", 0.0)),
                "val_bleu4_normalized": float(metrics.get("bleu4_normalized", metrics.get("bleu4", 0.0))),
                "val_bert_score": float(metrics.get("bert_score", 0.0)),
                "val_bert_score_raw": float(metrics.get("bert_score_raw", metrics.get("bert_score", 0.0))),
                "val_semantic_raw": float(metrics.get("semantic_raw", metrics.get("semantic", 0.0))),
                "val_closed_accuracy": float(closed_eval.get("accuracy", metrics.get("closed", {}).get("accuracy", -1))),
                "val_closed_em": float(closed_eval.get("em", metrics.get("closed", {}).get("em", -1))),
                "val_closed_f1": float(closed_eval.get("f1", metrics.get("closed", {}).get("f1", -1))),
                "val_open_accuracy": float(metrics.get("open", {}).get("accuracy", -1)),
                "val_open_semantic": float(open_eval.get("semantic", metrics.get("open", {}).get("semantic", -1))),
                "val_open_bertscore": float(open_eval.get("bert_score", metrics.get("open", {}).get("bert_score", -1))),
                "val_open_f1": float(open_eval.get("f1", metrics.get("open", {}).get("f1", -1))),
                "val_open_rouge_l": float(open_eval.get("rouge_l", metrics.get("open", {}).get("rouge_l", -1))),
                "best_so_far": bool(is_best),
                "metrics": metrics,
            }
            self.history.append(epoch_record)
            # On improvement: save weights-only best checkpoint plus a full
            # resume checkpoint (optimizer/scheduler state included).
            if is_best:
                best_val_acc = val_acc
                counter = 0
                variant = self.config.get('variant', 'A')
                save_path = os.path.join(ckpt_dir, f"medical_vqa_{variant}_best.pth")
                torch.save(self.model.state_dict(), save_path)
                resume_path = os.path.join(ckpt_dir, f"medical_vqa_{variant}_resume.pth")
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
                    'best_val_acc': best_val_acc,
                    'train_loss': float(train_loss),
                }
                torch.save(checkpoint, resume_path)
                print(f"🌟 Best model saved with Accuracy: {val_acc:.4f}")
            else:
                counter += 1
            # Persist history every epoch so partial runs are not lost.
            if history_dir:
                self.save_history(history_dir)
            if counter >= patience:
                print(f"🛑 Early stopping tại epoch {epoch}!")
                break
        print("[INFO] Huấn luyện hoàn tất.")
        if history_dir:
            self.save_history(history_dir)
        # Auto-plot once training has finished.
        if history_dir and len(self.history) >= 1:
            chart_paths = self.plot_training_results(history_dir)
            print(f"[INFO] 📊 Đã lưu {len(chart_paths)} biểu đồ tại: {history_dir}")
        return self.history
    # ── Visualization ────────────────────────────────────────────────────────
    def plot_training_results(self, output_dir: str) -> list:
        """
        Automatically draw and save four charts once training finishes:
        1. Train loss per epoch
        2. Val accuracy + F1 + BLEU-4 (multi-metric)
        3. Closed vs open accuracy (bars per epoch)
        4. BERTScore + semantic score
        Returns the list of saved image file paths.
        """
        try:
            import matplotlib
            matplotlib.use("Agg")  # Non-interactive backend (safe on headless servers)
            import matplotlib.pyplot as plt
            import matplotlib.ticker as mticker
        except ImportError:
            # Plotting is best-effort: missing matplotlib just skips the charts.
            print("[WARNING] matplotlib chưa cài — bỏ qua vẽ biểu đồ.")
            return []
        os.makedirs(output_dir, exist_ok=True)
        variant = self.config.get('variant', 'Model')
        epochs = [r["epoch"] for r in self.history]
        saved = []
        # Fixed color palette shared across all charts.
        COLORS = {
            "loss": "#e74c3c",
            "accuracy": "#2ecc71",
            "f1": "#3498db",
            "bleu4": "#9b59b6",
            "bert": "#e67e22",
            "semantic": "#1abc9c",
            "closed": "#2980b9",
            "open": "#e74c3c",
        }
        def _finish(fig, fname):
            # Save the figure, optionally mirror it to WandB, record its path.
            fig.tight_layout()
            path = os.path.join(output_dir, fname)
            fig.savefig(path, dpi=150, bbox_inches="tight")
            plt.close(fig)
            # Upload to WandB if available
            if wandb.run:
                wandb.log({fname.replace(".png", ""): wandb.Image(path)})
            saved.append(path)
        # ── Chart 1: Train Loss ──────────────────────────────────────────────
        fig, ax = plt.subplots(figsize=(9, 5))
        ax.plot(epochs, [r["train_loss"] for r in self.history],
                color=COLORS["loss"], linewidth=2.5, marker="o", markersize=5,
                label="Train Loss")
        ax.set_title(f"[{variant}] Train Loss per Epoch", fontsize=14, fontweight="bold")
        ax.set_xlabel("Epoch"); ax.set_ylabel("Loss")
        ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
        ax.legend(); ax.grid(True, alpha=0.3)
        _finish(fig, f"{variant}_01_train_loss.png")
        # ── Chart 2: Validation Metrics (Acc / F1 / BLEU-4) ─────────────────
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(epochs, [r["val_accuracy_normalized"] for r in self.history],
                color=COLORS["accuracy"], linewidth=2.5, marker="o", label="Accuracy")
        ax.plot(epochs, [r["val_f1_normalized"] for r in self.history],
                color=COLORS["f1"], linewidth=2.5, marker="s", label="F1")
        ax.plot(epochs, [r["val_bleu4_normalized"] for r in self.history],
                color=COLORS["bleu4"], linewidth=2.5, marker="^", label="BLEU-4")
        # Mark the best epoch (by normalized accuracy) with a vertical line.
        best_epoch = max(self.history, key=lambda r: r["val_accuracy_normalized"])
        ax.axvline(x=best_epoch["epoch"], color="gray", linestyle="--", alpha=0.6,
                   label=f"Best epoch {best_epoch['epoch']} ({best_epoch['val_accuracy_normalized']:.2%})")
        ax.set_title(f"[{variant}] Validation Metrics per Epoch", fontsize=14, fontweight="bold")
        ax.set_xlabel("Epoch"); ax.set_ylabel("Score")
        ax.set_ylim(0, 1.05)
        ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
        ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
        ax.legend(loc="lower right"); ax.grid(True, alpha=0.3)
        _finish(fig, f"{variant}_02_val_metrics.png")
        # ── Chart 3: Closed vs Open Accuracy ────────────────────────────────
        closed_vals = [r["val_closed_accuracy"] for r in self.history]
        open_vals = [r["val_open_accuracy"] for r in self.history]
        # A value of -1 means the metric was unavailable for that epoch.
        has_closed = any(v >= 0 for v in closed_vals)
        has_open = any(v >= 0 for v in open_vals)
        if has_closed or has_open:
            fig, ax = plt.subplots(figsize=(10, 5))
            w = 0.35
            x = range(len(epochs))
            if has_closed:
                c_vals = [v if v >= 0 else 0 for v in closed_vals]
                ax.bar([i - w/2 for i in x], c_vals, w, label="Closed (Yes/No)",
                       color=COLORS["closed"], alpha=0.85)
            if has_open:
                o_vals = [v if v >= 0 else 0 for v in open_vals]
                ax.bar([i + w/2 for i in x], o_vals, w, label="Open-ended",
                       color=COLORS["open"], alpha=0.85)
            ax.set_xticks(list(x)); ax.set_xticklabels([f"E{e}" for e in epochs])
            ax.set_title(f"[{variant}] Closed vs Open Accuracy per Epoch",
                         fontsize=14, fontweight="bold")
            ax.set_ylabel("Accuracy")
            ax.set_ylim(0, 1.05)
            ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
            ax.legend(); ax.grid(True, alpha=0.3, axis="y")
            _finish(fig, f"{variant}_03_closed_vs_open.png")
        # ── Chart 4: BERTScore + Semantic Score ──────────────────────────────
        bert_vals = [r["val_bert_score_raw"] for r in self.history]
        semantic_vals = [r["val_semantic_raw"] for r in self.history]
        if any(v > 0 for v in bert_vals + semantic_vals):
            fig, ax = plt.subplots(figsize=(9, 5))
            ax.plot(epochs, bert_vals, color=COLORS["bert"], linewidth=2.5,
                    marker="o", label="BERTScore")
            ax.plot(epochs, semantic_vals, color=COLORS["semantic"], linewidth=2.5,
                    marker="s", label="Semantic Score")
            ax.set_title(f"[{variant}] BERTScore & Semantic Score per Epoch",
                         fontsize=14, fontweight="bold")
            ax.set_xlabel("Epoch"); ax.set_ylabel("Score")
            ax.set_ylim(0, 1.05)
            ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
            ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
            ax.legend(); ax.grid(True, alpha=0.3)
            _finish(fig, f"{variant}_04_bert_semantic.png")
        # ── Print final summary table ─────────────────────────────────────────
        print("\n" + "═" * 72)
        print(f" 📊 TRAINING SUMMARY — {variant}")
        print("═" * 72)
        print(f" {'Epoch':>5} {'TrainLoss':>10} {'Accuracy':>9} {'F1':>7} {'BLEU-4':>7} {'Best':>5}")
        print("─" * 72)
        for r in self.history:
            star = " ★" if r.get("best_so_far") else ""
            print(
                f" {r['epoch']:>5} {r['train_loss']:>10.4f} "
                f"{r['val_accuracy_normalized']:>9.2%} "
                f"{r['val_f1_normalized']:>7.2%} "
                f"{r['val_bleu4_normalized']:>7.2%}{star}"
            )
        print("═" * 72 + "\n")
        return saved