# Medical-VQA / src/engine/trainer.py
# Author: SpringWang08 — commit d63774a ("Deploy Medical VQA app"), 23.2 kB
import wandb
import torch
import torch.nn as nn
from tqdm import tqdm
import os
import csv
import json
from src.utils.early_stopping import DynamicClassWeights
class MedicalVQATrainer:
    """Train and validate a medical VQA model with a closed (yes/no) head and
    an open-ended generative head.

    ``model(images, input_ids, attention_mask, decoder_input)`` must return
    ``(logits_closed, logits_open)``. Training batches are dicts with the keys
    ``image``, ``input_ids``, ``attention_mask``, ``label_closed`` and
    ``target_ids``; ``label_closed == -1`` marks an open-ended question, any
    other value is the closed-head class index.
    """

    # Weight of the generation loss on closed questions. Kept very low so the
    # generator is not biased toward short yes/no answers (mode collapse).
    CLOSED_GEN_LOSS_WEIGHT = 0.1
    # Auxiliary terms added to the open-question generation loss.
    LENGTH_PENALTY_WEIGHT = 0.3
    LENGTH_PENALTY_TARGET = 15.0
    ENTROPY_PENALTY_WEIGHT = 0.1
    ENTROPY_FLOOR = 2.0
    # Multiplier on the whole open-question loss (reduced 3.0 -> 2.0 because the
    # open head was dominating and closed-head accuracy plateaued).
    DEFAULT_OPEN_LOSS_WEIGHT = 2.0

    def __init__(self, model, train_loader, val_loader, optimizer, device, config,
                 scheduler=None, pad_token_id=0, beam_width=1):
        """Store the training components and build the loss functions.

        Args:
            model: network returning ``(logits_closed, logits_open)``.
            train_loader / val_loader: iterables of batch dicts.
            optimizer: torch optimizer; ``param_groups[0]`` / ``[-1]`` are
                logged as the vision / decoder learning rates.
            device: torch device (or anything ``torch.device`` accepts).
            config: dict with a ``'train'`` section (``label_smoothing``,
                ``use_amp``, ``gradient_accumulation_steps``, ``grad_clip``,
                ``patience``, ``open_loss_weight``) and optional top-level
                ``data``, ``variant``, ``history_dir``, ``checkpoint_dir``.
            scheduler: optional LR scheduler, stepped once per optimizer step.
            pad_token_id: target id ignored by the generation loss.
            beam_width: beam width passed to validation decoding.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.config = config
        self.beam_width = beam_width
        train_cfg = config['train']

        # Closed-head class weights derived from the real label distribution
        # (replaces a hard-coded [1.0, 2.5] that may not match the data).
        dynamic_weights = DynamicClassWeights.compute_weights(train_loader, device=device)
        self.criterion_closed = nn.CrossEntropyLoss(weight=dynamic_weights)
        # Label smoothing only on the open-ended head: the closed head needs sharp 0/1 targets.
        self.criterion_open = nn.CrossEntropyLoss(
            ignore_index=pad_token_id,
            label_smoothing=train_cfg.get('label_smoothing', 0.0),
        )
        # Unsmoothed closed-head loss; not used inside this class.
        self.criterion_closed_hard = nn.CrossEntropyLoss(weight=dynamic_weights)

        # AMP is only enabled on CUDA. torch.device() also accepts plain strings
        # such as "cuda", which have no ``.type`` attribute themselves.
        device_type = torch.device(device).type
        self.use_amp = bool(train_cfg.get('use_amp', False)) and device_type == 'cuda'
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)
        self.history = []

    @staticmethod
    def _flatten_dict(data, parent_key="", sep="."):
        """Flatten nested dicts into ``{"a.b": value}``; list/tuple values are dropped."""
        items = {}
        for key, value in data.items():
            new_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
            if isinstance(value, dict):
                items.update(MedicalVQATrainer._flatten_dict(value, new_key, sep=sep))
            elif isinstance(value, (list, tuple)):
                continue
            else:
                items[new_key] = value
        return items

    @staticmethod
    def _json_default(obj):
        """``json.dump`` fallback for tensor / numpy scalars and arrays in metrics."""
        if hasattr(obj, "tolist"):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def save_history(self, output_dir):
        """Write ``self.history`` to ``history.json`` and a flattened ``history.csv``.

        The CSV is only written when there is at least one history row; its
        columns are the sorted union of all flattened keys.
        """
        os.makedirs(output_dir, exist_ok=True)
        json_path = os.path.join(output_dir, "history.json")
        csv_path = os.path.join(output_dir, "history.csv")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(self.history, f, ensure_ascii=False, indent=2, default=self._json_default)
        flat_rows = [self._flatten_dict(row) for row in self.history]
        if flat_rows:
            fieldnames = sorted({key for row in flat_rows for key in row})
            with open(csv_path, "w", encoding="utf-8", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(flat_rows)

    @staticmethod
    def _compute_closed_weights(train_loader):
        """Count yes/no labels and return inverse-frequency class weights.

        Only labels 0 ("no") and 1 ("yes") are counted; other values (e.g. -1
        for open questions) are ignored. Returns ``torch.ones(2)`` when no
        closed labels are found. Not called within this class (``__init__``
        uses ``DynamicClassWeights``).
        """
        counts = {0: 0, 1: 0}  # 0 = no, 1 = yes
        for batch in train_loader:
            for value in torch.as_tensor(batch['label_closed']).reshape(-1).tolist():
                if value in counts:
                    counts[value] += 1
        total = counts[0] + counts[1]
        if total == 0:
            return torch.ones(2)
        # Inverse frequency: the rarer class gets the larger weight.
        w0 = total / (2 * max(counts[0], 1))
        w1 = total / (2 * max(counts[1], 1))
        weights = torch.tensor([w0, w1], dtype=torch.float32)
        print(f"[INFO] Closed question distribution: khΓ΄ng={counts[0]}, cΓ³={counts[1]}")
        return weights

    def _compute_loss(self, logits_closed, logits_open, label_closed, decoder_target):
        """Return the combined (unscaled) training loss for one batch.

        Closed samples contribute the closed-head cross-entropy plus a
        down-weighted generation loss. Open samples contribute the generation
        loss plus length and entropy penalties, scaled by the open-loss weight.
        The generation loss is split by question type to counter mode collapse.
        """
        vocab_size = logits_open.size(-1)
        mask_closed = label_closed != -1
        mask_open = label_closed == -1
        loss = 0

        if mask_closed.any():
            loss = loss + self.criterion_closed(logits_closed[mask_closed], label_closed[mask_closed])
            closed_gen = self.criterion_open(
                logits_open[mask_closed].reshape(-1, vocab_size),
                decoder_target[mask_closed].reshape(-1),
            )
            loss = loss + closed_gen * self.CLOSED_GEN_LOSS_WEIGHT

        if mask_open.any():
            open_logits = logits_open[mask_open]
            open_targets = decoder_target[mask_open]
            open_gen = self.criterion_open(open_logits.reshape(-1, vocab_size), open_targets.reshape(-1))
            # NOTE(review): this counts non-pad tokens in the *targets*, not in
            # the model's predictions, so it is constant w.r.t. the parameters
            # and contributes no gradient; it only shifts the reported loss.
            target_len = (open_targets != self.criterion_open.ignore_index).float().sum(dim=-1).mean()
            length_penalty = torch.clamp(1.0 - target_len / self.LENGTH_PENALTY_TARGET, min=0.0)
            # Entropy penalty (replaces a coverage loss): penalize the model for
            # being over-confident in a single token. Averaged over all decoder
            # positions, padded ones included.
            probs = torch.softmax(open_logits, dim=-1)
            entropy = -(probs * torch.log(probs + 1e-9)).sum(dim=-1).mean()
            entropy_penalty = torch.clamp(self.ENTROPY_FLOOR - entropy, min=0.0)
            # Prefer the 'train' section like every other hyperparameter, but
            # keep honouring the legacy top-level key.
            open_loss_weight = self.config['train'].get(
                'open_loss_weight',
                self.config.get('open_loss_weight', self.DEFAULT_OPEN_LOSS_WEIGHT),
            )
            loss = loss + (
                open_gen
                + self.LENGTH_PENALTY_WEIGHT * length_penalty
                + self.ENTROPY_PENALTY_WEIGHT * entropy_penalty
            ) * open_loss_weight

        return loss

    def train_epoch(self, epoch):
        """Run one training epoch and return the mean per-batch (unscaled) loss.

        Gradients are accumulated over ``train.gradient_accumulation_steps``
        batches (default 2); the scheduler, if any, steps once per optimizer step.
        """
        self.model.train()
        train_cfg = self.config['train']
        num_batches = len(self.train_loader)
        # Clamp so a misconfigured 0 cannot cause a division by zero.
        accumulation_steps = max(1, int(train_cfg.get('gradient_accumulation_steps', 2)))
        grad_clip = train_cfg.get('grad_clip')
        total_loss = 0.0
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")

        for batch_idx, batch in enumerate(pbar):
            images = batch['image'].to(self.device)
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            label_closed = batch['label_closed'].to(self.device)
            target_ids = batch['target_ids'].to(self.device)

            # Zero gradients at the start of each accumulation window.
            if batch_idx % accumulation_steps == 0:
                self.optimizer.zero_grad()
            # The last window may hold fewer batches; divide by its real size so
            # that update is averaged like every other one.
            window_start = batch_idx - batch_idx % accumulation_steps
            window_size = max(1, min(accumulation_steps, num_batches - window_start))

            with torch.cuda.amp.autocast(enabled=self.use_amp):
                # Teacher forcing: decoder input is <s> A B, target is A B </s>.
                decoder_input = target_ids[:, :-1]
                decoder_target = target_ids[:, 1:]
                logits_closed, logits_open = self.model(images, input_ids, attention_mask, decoder_input)
                batch_loss = self._compute_loss(logits_closed, logits_open, label_closed, decoder_target)
                loss = batch_loss / window_size

            self.scaler.scale(loss).backward()

            is_last_batch = (batch_idx + 1) == num_batches
            if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
                if grad_clip:
                    # Unscale first so clipping sees the true gradient norms.
                    self.scaler.unscale_(self.optimizer)
                    nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                # Step per optimizer step (not per epoch) for a smoother warmup.
                if self.scheduler:
                    self.scheduler.step()

            batch_loss_value = batch_loss.item()
            total_loss += batch_loss_value

            # Group 0 is logged as the vision LR, the last group as the decoder LR.
            decoder_lr = self.optimizer.param_groups[-1]['lr']
            vision_lr = self.optimizer.param_groups[0]['lr']
            if wandb.run:
                wandb.log({
                    "batch_loss": batch_loss_value,
                    "lr_vision": vision_lr,
                    "lr_decoder": decoder_lr,
                })
            pbar.set_postfix({"loss": f"{batch_loss_value:.3f}", "dec_lr": f"{decoder_lr:.1e}", "vis_lr": f"{vision_lr:.1e}"})

        epoch_train_loss = total_loss / max(num_batches, 1)
        if wandb.run:
            wandb.log({"train_loss_epoch": epoch_train_loss})
        return epoch_train_loss

    def val_epoch(self, tokenizer, epoch=0):
        """Evaluate on the validation set, print/log the key metrics, and return them.

        Decoding limits come from ``config['data']`` (``max_answer_len``,
        default 32; ``answer_max_words``, default 10).
        """
        from src.engine.medical_eval import evaluate_vqa

        data_cfg = self.config.get('data') or {}
        max_ans_len = data_cfg.get('max_answer_len', 32)
        max_words = data_cfg.get('answer_max_words', 10)
        print(f"\nπŸ” Đang chαΊ‘y Validation cho Epoch {epoch} (max_ans_len={max_ans_len})...")
        metrics = evaluate_vqa(
            self.model,
            self.val_loader,
            self.device,
            tokenizer,
            beam_width=self.beam_width,
            max_len=max_ans_len,
            max_words=max_words,
        )

        # Normalized variants win when present; missing metrics default to 0.
        accuracy = metrics.get("accuracy", 0.0)
        f1 = metrics.get("f1", 0.0)
        bleu4 = metrics.get("bleu4", 0.0)
        accuracy_norm = metrics.get("accuracy_normalized", accuracy)
        f1_norm = metrics.get("f1_normalized", f1)
        bleu4_norm = metrics.get("bleu4_normalized", bleu4)
        print(
            f"[METRICS] Accuracy: {accuracy_norm:.4f} | "
            f"F1: {f1_norm:.4f} | "
            f"BLEU-4: {bleu4_norm:.4f}"
        )
        if wandb.run:
            wandb.log({
                "epoch": epoch,
                "val_accuracy": accuracy,
                "val_accuracy_normalized": accuracy_norm,
                "val_f1": f1,
                "val_f1_normalized": f1_norm,
                "val_bleu4": bleu4,
                "val_bleu4_normalized": bleu4_norm,
                "val_bert_score": metrics.get("bert_score", 0),
                "val_bert_score_raw": metrics.get("bert_score_raw", metrics.get("bert_score", 0)),
                "val_semantic_raw": metrics.get("semantic_raw", metrics.get("semantic", 0)),
            })
        return metrics

    @staticmethod
    def _build_epoch_record(epoch, train_loss, metrics, is_best):
        """Build the history row for one epoch from the validation ``metrics`` dict.

        Normalized metrics fall back to their raw value, then 0.0. Per-split
        closed/open metrics use -1 to mean "not reported".
        """
        closed_eval = metrics.get("closed_eval") or {}
        open_eval = metrics.get("open_eval") or {}
        closed = metrics.get("closed") or {}
        open_ = metrics.get("open") or {}

        def raw(key, fallback=0.0):
            return float(metrics.get(key, fallback))

        return {
            "epoch": epoch,
            "train_loss": float(train_loss),
            "val_accuracy": raw("accuracy"),
            "val_accuracy_normalized": raw("accuracy_normalized", metrics.get("accuracy", 0.0)),
            "val_f1": raw("f1"),
            "val_f1_normalized": raw("f1_normalized", metrics.get("f1", 0.0)),
            "val_bleu4": raw("bleu4"),
            "val_bleu4_normalized": raw("bleu4_normalized", metrics.get("bleu4", 0.0)),
            "val_bert_score": raw("bert_score"),
            "val_bert_score_raw": raw("bert_score_raw", metrics.get("bert_score", 0.0)),
            "val_semantic_raw": raw("semantic_raw", metrics.get("semantic", 0.0)),
            "val_closed_accuracy": float(closed_eval.get("accuracy", closed.get("accuracy", -1))),
            "val_closed_em": float(closed_eval.get("em", closed.get("em", -1))),
            "val_closed_f1": float(closed_eval.get("f1", closed.get("f1", -1))),
            "val_open_accuracy": float(open_.get("accuracy", -1)),
            "val_open_semantic": float(open_eval.get("semantic", open_.get("semantic", -1))),
            "val_open_bertscore": float(open_eval.get("bert_score", open_.get("bert_score", -1))),
            "val_open_f1": float(open_eval.get("f1", open_.get("f1", -1))),
            "val_open_rouge_l": float(open_eval.get("rouge_l", open_.get("rouge_l", -1))),
            "best_so_far": bool(is_best),
            "metrics": metrics,
        }

    def _save_best_checkpoint(self, ckpt_dir, epoch, train_loss, best_val_acc):
        """Save the best-model weights and a resumable training checkpoint."""
        variant = self.config.get('variant', 'A')
        model_state = self.model.state_dict()
        torch.save(model_state, os.path.join(ckpt_dir, f"medical_vqa_{variant}_best.pth"))
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_state,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'best_val_acc': best_val_acc,
            'train_loss': float(train_loss),
        }
        torch.save(checkpoint, os.path.join(ckpt_dir, f"medical_vqa_{variant}_resume.pth"))

    def train(self, epochs, tokenizer=None):
        """Train for up to ``epochs`` epochs with early stopping; return the history.

        The model is checkpointed whenever normalized validation accuracy
        improves. Training stops after ``train.patience`` (default 10) epochs
        without improvement. When ``config['history_dir']`` is set, the
        history is saved every epoch and charts are plotted at the end.
        """
        # Start at -inf so the first epoch always yields a checkpoint, even if
        # its accuracy is 0 (a 0.0 start would never save in that case).
        best_val_acc = float("-inf")
        patience = self.config['train'].get('patience', 10)
        counter = 0
        ckpt_dir = self.config.get("checkpoint_dir", "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)
        history_dir = self.config.get("history_dir")
        print(f"[INFO] BαΊ―t Δ‘αΊ§u huαΊ₯n luyện trong {epochs} epochs...")

        if wandb.run is not None:
            wandb.config.update({
                'total_epochs': epochs,
                'patience': patience,
                'variant': self.config.get('variant', 'Unknown'),
                'device': str(self.device),
                'use_amp': self.use_amp,
            })

        for epoch in range(1, epochs + 1):
            train_loss = self.train_epoch(epoch)
            metrics = self.val_epoch(tokenizer, epoch=epoch)
            val_acc = metrics.get('accuracy_normalized', metrics.get('accuracy', 0))
            is_best = val_acc > best_val_acc
            self.history.append(self._build_epoch_record(epoch, train_loss, metrics, is_best))

            if is_best:
                best_val_acc = val_acc
                counter = 0
                self._save_best_checkpoint(ckpt_dir, epoch, train_loss, best_val_acc)
                print(f"🌟 Best model saved with Accuracy: {val_acc:.4f}")
            else:
                counter += 1

            if history_dir:
                self.save_history(history_dir)
            if counter >= patience:
                print(f"πŸ›‘ Early stopping tαΊ‘i epoch {epoch}!")
                break

        print("[INFO] HuαΊ₯n luyện hoΓ n tαΊ₯t.")
        if history_dir:
            self.save_history(history_dir)
            if self.history:
                chart_paths = self.plot_training_results(history_dir)
                print(f"[INFO] πŸ“Š Đã lΖ°u {len(chart_paths)} biểu Δ‘α»“ tαΊ‘i: {history_dir}")
        return self.history

    def plot_training_results(self, output_dir: str) -> list:
        """Draw and save training charts from ``self.history``, then print a summary.

        Charts: (1) train loss per epoch; (2) normalized val accuracy, F1 and
        BLEU-4 with the best epoch marked; (3) closed vs open accuracy bars,
        only if either was reported; (4) BERTScore and semantic score, only if
        any value is positive. Each chart is also logged to WandB when a run
        is active.

        Returns:
            Paths of the saved PNG files; empty when matplotlib is unavailable
            or there is no history.
        """
        try:
            import matplotlib
            # Non-interactive backend (safe on headless servers). NOTE(review):
            # this switches the process-wide backend, e.g. inside notebooks.
            matplotlib.use("Agg")
            import matplotlib.pyplot as plt
            import matplotlib.ticker as mticker
        except ImportError:
            print("[WARNING] matplotlib chΖ°a cΓ i β€” bỏ qua vαΊ½ biểu Δ‘α»“.")
            return []
        if not self.history:
            return []  # nothing to plot; max() below would fail on empty history

        os.makedirs(output_dir, exist_ok=True)
        variant = self.config.get('variant', 'Model')
        history = self.history
        epochs = [r["epoch"] for r in history]
        saved = []
        colors = {
            "loss": "#e74c3c",
            "accuracy": "#2ecc71",
            "f1": "#3498db",
            "bleu4": "#9b59b6",
            "bert": "#e67e22",
            "semantic": "#1abc9c",
            "closed": "#2980b9",
            "open": "#e74c3c",
        }

        def series(key):
            return [r[key] for r in history]

        def finish(fig, fname):
            path = os.path.join(output_dir, fname)
            try:
                fig.tight_layout()
                fig.savefig(path, dpi=150, bbox_inches="tight")
            finally:
                plt.close(fig)  # release the figure even if saving fails
            if wandb.run:
                wandb.log({fname.replace(".png", ""): wandb.Image(path)})
            saved.append(path)

        # Chart 1: train loss.
        fig, ax = plt.subplots(figsize=(9, 5))
        ax.plot(epochs, series("train_loss"), color=colors["loss"], linewidth=2.5,
                marker="o", markersize=5, label="Train Loss")
        ax.set_title(f"[{variant}] Train Loss per Epoch", fontsize=14, fontweight="bold")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
        ax.legend()
        ax.grid(True, alpha=0.3)
        finish(fig, f"{variant}_01_train_loss.png")

        # Chart 2: validation accuracy / F1 / BLEU-4.
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(epochs, series("val_accuracy_normalized"), color=colors["accuracy"],
                linewidth=2.5, marker="o", label="Accuracy")
        ax.plot(epochs, series("val_f1_normalized"), color=colors["f1"],
                linewidth=2.5, marker="s", label="F1")
        ax.plot(epochs, series("val_bleu4_normalized"), color=colors["bleu4"],
                linewidth=2.5, marker="^", label="BLEU-4")
        best_epoch = max(history, key=lambda r: r["val_accuracy_normalized"])
        ax.axvline(x=best_epoch["epoch"], color="gray", linestyle="--", alpha=0.6,
                   label=f"Best epoch {best_epoch['epoch']} ({best_epoch['val_accuracy_normalized']:.2%})")
        ax.set_title(f"[{variant}] Validation Metrics per Epoch", fontsize=14, fontweight="bold")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Score")
        ax.set_ylim(0, 1.05)
        ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
        ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
        ax.legend(loc="lower right")
        ax.grid(True, alpha=0.3)
        finish(fig, f"{variant}_02_val_metrics.png")

        # Chart 3: closed vs open accuracy (-1 means "not reported").
        closed_vals = series("val_closed_accuracy")
        open_vals = series("val_open_accuracy")
        has_closed = any(v >= 0 for v in closed_vals)
        has_open = any(v >= 0 for v in open_vals)
        if has_closed or has_open:
            fig, ax = plt.subplots(figsize=(10, 5))
            bar_w = 0.35
            positions = range(len(epochs))
            if has_closed:
                ax.bar([i - bar_w / 2 for i in positions], [max(v, 0) if v >= 0 else 0 for v in closed_vals],
                       bar_w, label="Closed (Yes/No)", color=colors["closed"], alpha=0.85)
            if has_open:
                ax.bar([i + bar_w / 2 for i in positions], [max(v, 0) if v >= 0 else 0 for v in open_vals],
                       bar_w, label="Open-ended", color=colors["open"], alpha=0.85)
            ax.set_xticks(list(positions))
            ax.set_xticklabels([f"E{e}" for e in epochs])
            ax.set_title(f"[{variant}] Closed vs Open Accuracy per Epoch",
                         fontsize=14, fontweight="bold")
            ax.set_ylabel("Accuracy")
            ax.set_ylim(0, 1.05)
            ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
            ax.legend()
            ax.grid(True, alpha=0.3, axis="y")
            finish(fig, f"{variant}_03_closed_vs_open.png")

        # Chart 4: BERTScore + semantic score.
        bert_vals = series("val_bert_score_raw")
        semantic_vals = series("val_semantic_raw")
        if any(v > 0 for v in bert_vals + semantic_vals):
            fig, ax = plt.subplots(figsize=(9, 5))
            ax.plot(epochs, bert_vals, color=colors["bert"], linewidth=2.5,
                    marker="o", label="BERTScore")
            ax.plot(epochs, semantic_vals, color=colors["semantic"], linewidth=2.5,
                    marker="s", label="Semantic Score")
            ax.set_title(f"[{variant}] BERTScore & Semantic Score per Epoch",
                         fontsize=14, fontweight="bold")
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Score")
            ax.set_ylim(0, 1.05)
            ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
            ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
            ax.legend()
            ax.grid(True, alpha=0.3)
            finish(fig, f"{variant}_04_bert_semantic.png")

        # Final summary table.
        print("\n" + "═" * 72)
        print(f" πŸ“Š TRAINING SUMMARY β€” {variant}")
        print("═" * 72)
        print(f" {'Epoch':>5} {'TrainLoss':>10} {'Accuracy':>9} {'F1':>7} {'BLEU-4':>7} {'Best':>5}")
        print("─" * 72)
        for r in history:
            star = " β˜…" if r.get("best_so_far") else ""
            print(
                f" {r['epoch']:>5} {r['train_loss']:>10.4f} "
                f"{r['val_accuracy_normalized']:>9.2%} "
                f"{r['val_f1_normalized']:>7.2%} "
                f"{r['val_bleu4_normalized']:>7.2%}{star}"
            )
        print("═" * 72 + "\n")
        return saved