# src/engine/dpo_trainer.py
import torch
import torch.nn.functional as F
from tqdm import tqdm
import json
import random
from src.utils.text_utils import get_target_answer, normalize_answer
def _is_closed_question(question: str, answer: str) -> bool:
    """Heuristic check for closed (yes/no, "có"/"không") questions."""
q = normalize_answer(question)
a = normalize_answer(answer)
return (
a in {"có", "không"}
or q.endswith(" không")
or " bình thường " in f" {q} "
or " có " in f" {q} "
)
def _flip_closed_answer(answer: str) -> str:
a = normalize_answer(answer)
if a == "có":
return "không"
if a == "không":
return "có"
return a
def _answer_category(question: str, answer: str) -> str:
    """Bucket a QA pair into one of: closed, location, plane, organ, finding."""
q = normalize_answer(question)
a = normalize_answer(answer)
if _is_closed_question(question, answer):
return "closed"
if any(term in q for term in ["ở đâu", "vi tri", "where"]):
return "location"
if any(term in a for term in ["trái", "phải", "trên", "dưới", "giữa", "bên"]):
return "location"
if any(term in a for term in ["mặt phẳng", "ngang", "vành", "dọc"]):
return "plane"
if any(term in a for term in ["gan", "phổi", "tim", "não", "thận", "lách", "bàng quang", "khí quản", "trung thất"]):
return "organ"
return "finding"
def _build_answer_pools(data: list[dict], max_words: int) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
    """Collect the distinct normalized answers seen per question and per answer category."""
question_to_answers = {}
category_to_answers = {}
for item in data:
question = item.get("question_vi", item.get("question", ""))
answer = get_target_answer(item, max_words=max_words)
if not question or not answer:
continue
q_norm = normalize_answer(question)
a_norm = normalize_answer(answer)
category = _answer_category(question, answer)
question_to_answers.setdefault(q_norm, [])
if a_norm not in question_to_answers[q_norm]:
question_to_answers[q_norm].append(a_norm)
category_to_answers.setdefault(category, [])
if a_norm not in category_to_answers[category]:
category_to_answers[category].append(a_norm)
return question_to_answers, category_to_answers
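
# Illustration of the pool shapes built above, using made-up (hypothetical) entries:
#   question_to_answers = {"khối u nằm ở đâu": ["gan", "phổi"]}
#   category_to_answers = {"organ": ["gan", "phổi"], "closed": ["có", "không"]}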
def _build_rejected_candidates(
data: list[dict],
idx: int,
chosen: str,
question_to_answers: dict[str, list[str]],
category_to_answers: dict[str, list[str]],
) -> list[str]:
    """Propose plausible-but-wrong answers: flip closed answers, otherwise draw from the question and category pools."""
item = data[idx]
question = item.get("question_vi", item.get("question", ""))
question_norm = normalize_answer(question)
chosen_norm = normalize_answer(chosen)
category = _answer_category(question, chosen)
candidates = []
if _is_closed_question(question, chosen):
flipped = _flip_closed_answer(chosen)
if flipped and flipped != chosen_norm:
candidates.append(flipped)
else:
for answer in question_to_answers.get(question_norm, []):
if answer != chosen_norm:
candidates.append(answer)
for answer in category_to_answers.get(category, []):
if answer != chosen_norm:
candidates.append(answer)
deduped = []
seen = set()
for candidate in candidates:
candidate_norm = normalize_answer(candidate)
if not candidate_norm or candidate_norm == chosen_norm or candidate_norm in seen:
continue
seen.add(candidate_norm)
deduped.append(candidate_norm)
return deduped
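
# Candidate logic in short, using the toy pools above: for a closed question with
# chosen answer "có" the only candidate is the flipped "không"; for an open
# question with chosen answer "gan", candidates come from the same question's
# other answers first, then from the matching category pool (here "organ").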
def _build_pair_record(item: dict, source_idx: int, chosen: str, rejected: str) -> dict:
    question = item.get("question_vi") or item.get("question", "")
    return {
        "image": item.get("image_name") or item.get("image"),
        "source_idx": source_idx,
        "question": question,
        "chosen": chosen,
        "rejected": rejected,
        "answer_type": _answer_category(question, chosen),
    }
def _round_robin_merge(grouped_pairs: dict[str, list[dict]], target_count: int) -> list[dict]:
    """Interleave pairs from each group, in sorted-key order, until target_count is reached."""
ordered_groups = sorted(grouped_pairs.keys())
merged = []
while len(merged) < target_count:
progressed = False
for group in ordered_groups:
if grouped_pairs[group]:
merged.append(grouped_pairs[group].pop())
progressed = True
if len(merged) >= target_count:
break
if not progressed:
break
return merged
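
# Worked example (hypothetical groups): with grouped_pairs = {"a": [1, 2], "b": [3]}
# and target_count=3, each round pops one pair per group in sorted-key order:
# round 1 -> [2, 3], round 2 -> [2, 3, 1], so no single open-question category
# dominates the sample.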
def create_preference_data(
vqa_json_path,
output_path,
num_pairs=400,
closed_ratio=0.6,
max_answer_words=6,
seed=42,
):
"""
Tạo dữ liệu Preference (Chosen vs Rejected) cho DPO.
Trong Medical VQA, 'Rejected' thường là các câu trả lời bị hallucination hoặc sai thuật ngữ y khoa.
"""
with open(vqa_json_path, 'r', encoding='utf-8') as f:
data = json.load(f)
question_to_answers, category_to_answers = _build_answer_pools(data, max_words=max_answer_words)
rng = random.Random(seed)
closed_pairs = []
open_pairs_by_group = {"location": [], "plane": [], "organ": [], "finding": []}
    for i, item in enumerate(data):
chosen = get_target_answer(item, max_words=max_answer_words)
chosen_norm = normalize_answer(chosen)
if not chosen_norm or len(chosen_norm.split()) > max_answer_words:
continue
rejected_candidates = _build_rejected_candidates(
data,
i,
chosen_norm,
question_to_answers=question_to_answers,
category_to_answers=category_to_answers,
)
category = _answer_category(item["question_vi"], chosen_norm)
for rejected in rejected_candidates:
if len(rejected.split()) > max_answer_words:
continue
pair = _build_pair_record(item, i, chosen_norm, rejected)
if category == "closed":
closed_pairs.append(pair)
elif category in open_pairs_by_group:
open_pairs_by_group[category].append(pair)
rng.shuffle(closed_pairs)
for pairs in open_pairs_by_group.values():
rng.shuffle(pairs)
target_closed = min(len(closed_pairs), int(round(num_pairs * closed_ratio)))
target_open = max(0, num_pairs - target_closed)
sampled_closed = closed_pairs[:target_closed]
sampled_open = _round_robin_merge(open_pairs_by_group, target_open)
pref_data = sampled_closed + sampled_open
rng.shuffle(pref_data)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(pref_data, f, ensure_ascii=False, indent=2)
    print(
        f"[SUCCESS] Created {len(pref_data)} preference pairs at {output_path} "
        f"(closed={len(sampled_closed)}, open={len(sampled_open)})"
    )
preview_count = min(30, len(pref_data))
if preview_count:
print(f"[INFO] Preview {preview_count} cặp preference đầu tiên để kiểm tra nhanh:")
for idx, pair in enumerate(pref_data[:preview_count], start=1):
print(
f" [{idx:02d}] type={pair.get('answer_type')} | "
f"Q={pair.get('question')} | chosen={pair.get('chosen')} | rejected={pair.get('rejected')}"
)
return pref_data
class MedicalDPOTrainer:
"""
Trainer cho Direct Preference Optimization (DPO) trên LLaVA-Med.
Giúp tối ưu hóa mô hình dựa trên các cặp preference dữ liệu y tế.
"""
def __init__(self, model, reference_model, train_loader, optimizer, device, config):
self.model = model
self.reference_model = reference_model
self.train_loader = train_loader
self.optimizer = optimizer
self.device = device
self.config = config
self.beta = config.get('dpo_beta', 0.1)
    def get_log_probs(self, logits, labels):
        """
        Compute per-sequence log probabilities.
        logits: [batch, seq_len, vocab]
        labels: [batch, seq_len]
        """
        # Shift logits and labels so position t predicts token t+1 (next-token prediction)
        logits = logits[:, :-1, :]
        labels = labels[:, 1:]
        log_probs = F.log_softmax(logits, dim=-1)
        # Gather the log prob of each target token
        per_token_logps = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(2)
        # Sum over non-padding tokens only (assumes pad token id == 0)
        return (per_token_logps * (labels != 0)).sum(-1)
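
    # Shape walk-through for get_log_probs (assumed sizes): with logits
    # [2, 10, 32000] and labels [2, 10], the shift yields [2, 9, 32000] and
    # [2, 9]; gather picks one log prob per target token -> [2, 9]; masking
    # and summing reduce to one scalar per sequence -> [2].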
    def compute_loss(self, policy_chosen_logps, policy_rejected_logps,
                     reference_chosen_logps, reference_rejected_logps):
        """
        Compute the DPO loss: -log(sigmoid(beta * (log_ratio_chosen - log_ratio_rejected)))
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        logits = pi_logratios - ref_logratios
        loss = -F.logsigmoid(self.beta * logits).mean()
        # Implicit rewards, tracked for monitoring only
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return loss, chosen_rewards, rejected_rewards
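
    # Numeric sanity check (made-up values): if the policy log-ratio is 2.0 and
    # the reference log-ratio is 0.5, then logits = 1.5 and, with beta = 0.1,
    # loss = -log(sigmoid(0.15)) ~= 0.621. Positive logits (the policy prefers
    # the chosen answer more strongly than the reference) push the loss toward 0.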
def train(self, epochs=3):
print(f"[INFO] Bắt đầu huấn luyện DPO (beta={self.beta})...")
self.model.train()
self.reference_model.eval()
        # Freeze the reference model to save VRAM (important on a T4)
for param in self.reference_model.parameters():
param.requires_grad_(False)
print(f"[INFO] DPO Trainer Ready ({self.device})")
for epoch in range(epochs):
self.model.train()
            total_loss = 0.0
pbar = tqdm(self.train_loader, desc=f"DPO Epoch {epoch+1}")
for batch in pbar:
images = batch['image'].to(self.device)
chosen_ids = batch['chosen_ids'].to(self.device)
rejected_ids = batch['rejected_ids'].to(self.device)
            # 1. Forward the policy model on chosen and rejected (duck-typed safe forward)
            try:
                # Case: LLaVA-style multimodal model
                outputs_w = self.model(input_ids=chosen_ids, pixel_values=images, labels=chosen_ids)
                outputs_l = self.model(input_ids=rejected_ids, pixel_values=images, labels=rejected_ids)
                logits_w = outputs_w.logits
                logits_l = outputs_l.logits
            except Exception:
                # Fallback: modular model (A1/A2 style)
                _, logits_w = self.model(images, chosen_ids)
                _, logits_l = self.model(images, rejected_ids)
            # 2. Forward the frozen reference model (no grad)
            with torch.no_grad():
                try:
                    # Multimodal case
                    outputs_ref_w = self.reference_model(input_ids=chosen_ids, pixel_values=images, labels=chosen_ids)
                    outputs_ref_l = self.reference_model(input_ids=rejected_ids, pixel_values=images, labels=rejected_ids)
                    ref_logits_w = outputs_ref_w.logits
                    ref_logits_l = outputs_ref_l.logits
                except Exception:
                    # Modular case
                    _, ref_logits_w = self.reference_model(images, chosen_ids)
                    _, ref_logits_l = self.reference_model(images, rejected_ids)
            # 3. Per-sequence log probs
logps_w = self.get_log_probs(logits_w, chosen_ids)
logps_l = self.get_log_probs(logits_l, rejected_ids)
ref_logps_w = self.get_log_probs(ref_logits_w, chosen_ids)
ref_logps_l = self.get_log_probs(ref_logits_l, rejected_ids)
            # 4. DPO loss
loss, _, _ = self.compute_loss(logps_w, logps_l, ref_logps_w, ref_logps_l)
            # 5. Backward pass and optimizer step
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
total_loss += loss.item()
pbar.set_postfix({"loss": loss.item()})
print(f"Epoch {epoch+1} | DPO Loss: {total_loss/len(self.train_loader):.4f}")