import json
import random

import torch
import torch.nn.functional as F
from tqdm import tqdm

from src.utils.text_utils import get_target_answer, normalize_answer


def _is_closed_question(question: str, answer: str) -> bool:
    """Heuristic check for Vietnamese closed (yes/no) questions."""
    q = normalize_answer(question)
    a = normalize_answer(answer)
    return (
        a in {"có", "không"}  # answer is literally "yes"/"no"
        or q.endswith(" không")
        or " bình thường " in f" {q} "
        or " có " in f" {q} "
    )


def _flip_closed_answer(answer: str) -> str:
    """Flip a yes/no answer ("có" <-> "không") to build a hard negative."""
    a = normalize_answer(answer)
    if a == "có":
        return "không"
    if a == "không":
        return "có"
    return a


def _answer_category(question: str, answer: str) -> str:
    """Bucket a QA pair into closed / location / plane / organ / finding."""
    q = normalize_answer(question)
    a = normalize_answer(answer)
    if _is_closed_question(question, answer):
        return "closed"
    if any(term in q for term in ["ở đâu", "vi tri", "where"]):
        return "location"
    if any(term in a for term in ["trái", "phải", "trên", "dưới", "giữa", "bên"]):
        return "location"
    if any(term in a for term in ["mặt phẳng", "ngang", "vành", "dọc"]):
        return "plane"
    if any(term in a for term in ["gan", "phổi", "tim", "não", "thận", "lách", "bàng quang", "khí quản", "trung thất"]):
        return "organ"
    return "finding"


def _build_answer_pools(data: list[dict], max_words: int) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
    """Index normalized answers by question text and by answer category."""
    question_to_answers: dict[str, list[str]] = {}
    category_to_answers: dict[str, list[str]] = {}
    for item in data:
        question = item.get("question_vi", item.get("question", ""))
        answer = get_target_answer(item, max_words=max_words)
        if not question or not answer:
            continue
        q_norm = normalize_answer(question)
        a_norm = normalize_answer(answer)
        category = _answer_category(question, answer)
        question_to_answers.setdefault(q_norm, [])
        if a_norm not in question_to_answers[q_norm]:
            question_to_answers[q_norm].append(a_norm)
        category_to_answers.setdefault(category, [])
        if a_norm not in category_to_answers[category]:
            category_to_answers[category].append(a_norm)
    return question_to_answers, category_to_answers


def _build_rejected_candidates(
    data: list[dict],
    idx: int,
    chosen: str,
    question_to_answers: dict[str, list[str]],
    category_to_answers: dict[str, list[str]],
) -> list[str]:
    """Mine rejected candidates: flip yes/no answers for closed questions,
    otherwise reuse other answers seen for the same question or category."""
    item = data[idx]
    question = item.get("question_vi", item.get("question", ""))
    question_norm = normalize_answer(question)
    chosen_norm = normalize_answer(chosen)
    category = _answer_category(question, chosen)
    candidates = []
    if _is_closed_question(question, chosen):
        flipped = _flip_closed_answer(chosen)
        if flipped and flipped != chosen_norm:
            candidates.append(flipped)
    else:
        for answer in question_to_answers.get(question_norm, []):
            if answer != chosen_norm:
                candidates.append(answer)
        for answer in category_to_answers.get(category, []):
            if answer != chosen_norm:
                candidates.append(answer)
    deduped = []
    seen = set()
    for candidate in candidates:
        candidate_norm = normalize_answer(candidate)
        if not candidate_norm or candidate_norm == chosen_norm or candidate_norm in seen:
            continue
        seen.add(candidate_norm)
        deduped.append(candidate_norm)
    return deduped


def _build_pair_record(item: dict, source_idx: int, chosen: str, rejected: str) -> dict:
    question = item.get("question_vi", item.get("question", ""))
    return {
        "image": item.get("image_name") or item.get("image"),
        "source_idx": source_idx,
        "question": question,
        "chosen": chosen,
        "rejected": rejected,
        "answer_type": _answer_category(question, chosen),
    }


def _round_robin_merge(grouped_pairs: dict[str, list[dict]], target_count: int) -> list[dict]:
    """Interleave pairs from each group so no single category dominates."""
    ordered_groups = sorted(grouped_pairs.keys())
    merged = []
    while len(merged) < target_count:
        progressed = False
        for group in ordered_groups:
            if grouped_pairs[group]:
                merged.append(grouped_pairs[group].pop())
                progressed = True
                if len(merged) >= target_count:
                    break
        if not progressed:
            break
    return merged
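

# Illustrative sketch of the mining helpers on hypothetical toy data (assuming
# normalize_answer lowercases text and get_target_answer reads the Vietnamese
# answer field; neither is defined in this file):
#
#   toy = [
#       {"question_vi": "có tổn thương không", "answer_vi": "không"},
#       {"question_vi": "tổn thương ở đâu", "answer_vi": "bên trái"},
#       {"question_vi": "tổn thương ở đâu", "answer_vi": "bên phải"},
#   ]
#   q_pool, c_pool = _build_answer_pools(toy, max_words=6)
#   # Closed question -> the rejected candidate is the flipped yes/no answer:
#   _build_rejected_candidates(toy, 0, "không", q_pool, c_pool)      # ["có"]
#   # Open question -> other answers seen for the same question/category:
#   _build_rejected_candidates(toy, 1, "bên trái", q_pool, c_pool)   # ["bên phải"]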
def create_preference_data(
    vqa_json_path,
    output_path,
    num_pairs=400,
    closed_ratio=0.6,
    max_answer_words=6,
    seed=42,
):
    """
    Create preference pairs (chosen vs. rejected) for DPO.

    In Medical VQA, 'rejected' answers are typically hallucinations or
    incorrect medical terminology.
    """
    with open(vqa_json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    question_to_answers, category_to_answers = _build_answer_pools(data, max_words=max_answer_words)
    rng = random.Random(seed)
    closed_pairs = []
    open_pairs_by_group = {"location": [], "plane": [], "organ": [], "finding": []}
    for i, item in enumerate(data):
        question = item.get("question_vi", item.get("question", ""))
        chosen = get_target_answer(item, max_words=max_answer_words)
        chosen_norm = normalize_answer(chosen)
        if not question or not chosen_norm or len(chosen_norm.split()) > max_answer_words:
            continue
        rejected_candidates = _build_rejected_candidates(
            data,
            i,
            chosen_norm,
            question_to_answers=question_to_answers,
            category_to_answers=category_to_answers,
        )
        category = _answer_category(question, chosen_norm)
        for rejected in rejected_candidates:
            if len(rejected.split()) > max_answer_words:
                continue
            pair = _build_pair_record(item, i, chosen_norm, rejected)
            if category == "closed":
                closed_pairs.append(pair)
            elif category in open_pairs_by_group:
                open_pairs_by_group[category].append(pair)
    rng.shuffle(closed_pairs)
    for pairs in open_pairs_by_group.values():
        rng.shuffle(pairs)
    # Cap closed pairs at closed_ratio of the budget, then fill the remainder
    # round-robin from the open-ended groups.
    target_closed = min(len(closed_pairs), int(round(num_pairs * closed_ratio)))
    target_open = max(0, num_pairs - target_closed)
    sampled_closed = closed_pairs[:target_closed]
    sampled_open = _round_robin_merge(open_pairs_by_group, target_open)
    pref_data = sampled_closed + sampled_open
    rng.shuffle(pref_data)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(pref_data, f, ensure_ascii=False, indent=2)
    print(
        f"[SUCCESS] Created {len(pref_data)} preference pairs at {output_path} "
        f"(closed={len(sampled_closed)}, open={len(sampled_open)})"
    )
    preview_count = min(30, len(pref_data))
    if preview_count:
        print(f"[INFO] Previewing the first {preview_count} preference pairs for a quick check:")
        for idx, pair in enumerate(pref_data[:preview_count], start=1):
            print(
                f"  [{idx:02d}] type={pair.get('answer_type')} | "
                f"Q={pair.get('question')} | chosen={pair.get('chosen')} | rejected={pair.get('rejected')}"
            )
    return pref_data
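

# Minimal sketch of the batch format MedicalDPOTrainer expects from its
# train_loader. The real dataset/collate lives elsewhere in the project; this
# hypothetical helper only documents the assumed keys and shapes (equal-length
# id tensors with pad id 0, matching the padding mask in get_log_probs below).
def _example_collate_fn(samples: list[dict]) -> dict:
    return {
        "image": torch.stack([s["image"] for s in samples]),                # [B, C, H, W]
        "chosen_ids": torch.stack([s["chosen_ids"] for s in samples]),      # [B, T]
        "rejected_ids": torch.stack([s["rejected_ids"] for s in samples]),  # [B, T]
    }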
class MedicalDPOTrainer:
    """
    Trainer for Direct Preference Optimization (DPO) on LLaVA-Med.
    Optimizes the model on medical preference pairs.
    """
    def __init__(self, model, reference_model, train_loader, optimizer, device, config):
        self.model = model
        self.reference_model = reference_model
        self.train_loader = train_loader
        self.optimizer = optimizer
        self.device = device
        self.config = config
        self.beta = config.get('dpo_beta', 0.1)

    def get_log_probs(self, logits, labels):
        """
        Compute per-sequence log probabilities.

        logits: [batch, seq_len, vocab]
        labels: [batch, seq_len]
        """
        # Shift logits and labels so position t predicts token t+1
        # (next-token prediction).
        logits = logits[:, :-1, :]
        labels = labels[:, 1:]
        log_probs = F.log_softmax(logits, dim=-1)
        # Gather the log prob of each target token
        per_token_logps = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(2)
        # Sum over non-padding tokens only (assumes pad token id is 0)
        return (per_token_logps * (labels != 0)).sum(-1)

    def compute_loss(self, policy_chosen_logps, policy_rejected_logps,
                     reference_chosen_logps, reference_rejected_logps):
        """
        DPO loss:
            -log(sigmoid(beta * (log_ratio_chosen - log_ratio_rejected)))
        where each log ratio compares the policy to the reference model.
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        logits = pi_logratios - ref_logratios
        loss = -F.logsigmoid(self.beta * logits).mean()
        # Implicit rewards, tracked for monitoring only
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return loss, chosen_rewards, rejected_rewards

    def train(self, epochs=3):
        print(f"[INFO] Starting DPO training (beta={self.beta})...")
        self.model.train()
        self.reference_model.eval()
        # Freeze the reference model to save VRAM (important on a T4)
        for param in self.reference_model.parameters():
            param.requires_grad_(False)
        print(f"[INFO] DPO Trainer ready ({self.device})")
        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0
            pbar = tqdm(self.train_loader, desc=f"DPO Epoch {epoch+1}")
            for batch in pbar:
                images = batch['image'].to(self.device)
                chosen_ids = batch['chosen_ids'].to(self.device)
                rejected_ids = batch['rejected_ids'].to(self.device)
                # 1. Policy logits for chosen and rejected (duck-typed forward)
                try:
                    # Case: LLaVA-style multimodal model
                    outputs_w = self.model(input_ids=chosen_ids, pixel_values=images, labels=chosen_ids)
                    outputs_l = self.model(input_ids=rejected_ids, pixel_values=images, labels=rejected_ids)
                    logits_w = outputs_w.logits
                    logits_l = outputs_l.logits
                except Exception:
                    # Fallback: modular model (A1/A2 style)
                    _, logits_w = self.model(images, chosen_ids)
                    _, logits_l = self.model(images, rejected_ids)
                # 2. Forward reference model (no grad)
                with torch.no_grad():
                    try:
                        # Multimodal case
                        outputs_ref_w = self.reference_model(input_ids=chosen_ids, pixel_values=images, labels=chosen_ids)
                        outputs_ref_l = self.reference_model(input_ids=rejected_ids, pixel_values=images, labels=rejected_ids)
                        ref_logits_w = outputs_ref_w.logits
                        ref_logits_l = outputs_ref_l.logits
                    except Exception:
                        # Modular case
                        _, ref_logits_w = self.reference_model(images, chosen_ids)
                        _, ref_logits_l = self.reference_model(images, rejected_ids)
                # 3. Per-sequence log probs
                logps_w = self.get_log_probs(logits_w, chosen_ids)
                logps_l = self.get_log_probs(logits_l, rejected_ids)
                ref_logps_w = self.get_log_probs(ref_logits_w, chosen_ids)
                ref_logps_l = self.get_log_probs(ref_logits_l, rejected_ids)
                # 4. Loss
                loss, _, _ = self.compute_loss(logps_w, logps_l, ref_logps_w, ref_logps_l)
                # 5. Backward
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
                pbar.set_postfix({"loss": loss.item()})
            print(f"Epoch {epoch+1} | DPO Loss: {total_loss/len(self.train_loader):.4f}")
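

# Example wiring (hypothetical; the paths and the model/loader/optimizer
# objects are assumed to be built elsewhere in the project):
#
#   pairs = create_preference_data(
#       vqa_json_path="data/vqa_train.json",
#       output_path="data/dpo_pairs.json",
#       num_pairs=400,
#       closed_ratio=0.6,
#   )
#   trainer = MedicalDPOTrainer(model, reference_model, train_loader,
#                               optimizer, device, config={"dpo_beta": 0.1})
#   trainer.train(epochs=3)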