import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from transformers import AutoTokenizer
import yaml
import argparse
import os
import random

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# [Bypass CVE-2025-32434] Skip transformers' "upgrade to PyTorch 2.6" requirement
import transformers.utils.import_utils
transformers.utils.import_utils.check_torch_load_is_safe = lambda: None
import transformers.modeling_utils
transformers.modeling_utils.check_torch_load_is_safe = lambda: None

# [Bypass FSDPModule Error] Work around trl importing FSDPModule on older PyTorch
import torch.distributed.fsdp as fsdp
if not hasattr(fsdp, "FSDPModule"):
    fsdp.FSDPModule = fsdp.FullyShardedDataParallel

import csv
import json
from datetime import datetime
from pathlib import Path
from PIL import Image
from datasets import load_dataset

# Import components from the src package
from src.models.medical_vqa_model import MedicalVQAModelA
from src.models.multimodal_vqa import MultimodalVQA
from src.utils.visualization import MedicalImageTransform as MedicalTransform
from src.data.medical_dataset import MedicalVQADataset
from src.utils.text_utils import get_target_answer, normalize_answer, postprocess_answer


def build_training_arguments(training_arguments_cls, **kwargs):
    """Create TrainingArguments across transformers versions."""
    if "evaluation_strategy" in kwargs and "eval_strategy" not in kwargs:
        alias_kwargs = dict(kwargs)
        alias_kwargs["eval_strategy"] = alias_kwargs.pop("evaluation_strategy")
        try:
            return training_arguments_cls(**alias_kwargs)
        except TypeError as exc:
            if "eval_strategy" not in str(exc):
                raise
    return training_arguments_cls(**kwargs)


def vqa_collate_fn(batch):
    """Custom collate function that handles PIL images and raw text."""
    elem = batch[0]
    collated = {}
    for key in elem.keys():
        if key in ['image', 'input_ids', 'attention_mask', 'label_closed',
                   'target_ids', 'chosen_ids', 'rejected_ids']:
            collated[key] = torch.stack([item[key] for item in batch])
        else:
            # Keep PIL images and raw text as plain lists
            collated[key] = [item[key] for item in batch]
    return collated


def flatten_dict(data, parent_key="", sep="."):
    items = {}
    for key, value in data.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            items.update(flatten_dict(value, new_key, sep=sep))
        elif isinstance(value, (list, tuple)):
            continue
        else:
            items[new_key] = value
    return items


def create_history_dir(base_log_dir, variant):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    history_dir = os.path.join(base_log_dir, "history", variant, timestamp)
    os.makedirs(history_dir, exist_ok=True)
    return history_dir


def save_history_records(history_dir, records):
    os.makedirs(history_dir, exist_ok=True)
    json_path = os.path.join(history_dir, "history.json")
    csv_path = os.path.join(history_dir, "history.csv")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False, indent=2)
    flat_rows = [flatten_dict(record) for record in records]
    if flat_rows:
        fieldnames = sorted({key for row in flat_rows for key in row.keys()})
        with open(csv_path, "w", encoding="utf-8", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(flat_rows)
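
# Illustrative behavior of flatten_dict (not executed; values are made up):
#   flatten_dict({"val": {"acc": 0.81, "f1": 0.77}, "tags": ["a", "b"]})
#   -> {"val.acc": 0.81, "val.f1": 0.77}   # lists/tuples are dropped,
# which keeps the CSV written by save_history_records strictly scalar.
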

def select_best_adapter_checkpoint(checkpoint_root: str):
    checkpoint_root = Path(checkpoint_root)
    if not checkpoint_root.exists():
        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_root}")

    def _is_valid_adapter_checkpoint(path: Path) -> bool:
        adapter_cfg = path / "adapter_config.json"
        adapter_weights = path / "adapter_model.safetensors"
        if not adapter_cfg.exists() or not adapter_weights.exists():
            return False
        try:
            from safetensors import safe_open
            with safe_open(str(adapter_weights), framework="pt", device="cpu") as f:
                return len(f.keys()) > 0
        except Exception as exc:
            print(f"[WARN] Skipping broken checkpoint {path}: {exc}")
            return False

    checkpoint_dirs = sorted(
        p for p in checkpoint_root.glob("checkpoint-*")
        if _is_valid_adapter_checkpoint(p)
    )
    if not checkpoint_dirs:
        raise FileNotFoundError(f"No valid adapter checkpoint in {checkpoint_root}")
    for state_file in sorted(checkpoint_root.glob("checkpoint-*/trainer_state.json"), reverse=True):
        try:
            state = json.loads(state_file.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            continue
        best_path = state.get("best_model_checkpoint")
        if best_path:
            best_dir = Path(best_path.replace("./", ""))
            if not best_dir.is_absolute():
                best_dir = Path.cwd() / best_dir
            if _is_valid_adapter_checkpoint(best_dir):
                return best_dir.resolve()
    return checkpoint_dirs[-1].resolve()


def build_dpo_instruction_prompt(question: str, max_words: int = 10) -> str:
    question = str(question or "").strip()
    # The instruction is deliberately unaccented Vietnamese prompt text for the model.
    instruction = (
        "Chi tra loi bang tieng Viet. "
        "Khong dung tieng Anh. "
        "Khong lap lai cau hoi. "
        "Khong mo ta hinh anh chung chung. "
        f"Chi tra loi truc tiep dap an, toi da {max_words} tu."
    )
    # The <image> placeholder is expanded into visual tokens by the LLaVA processor.
    return f"USER: <image>\n{question}\n{instruction} ASSISTANT:"


def load_latest_variant_metrics(history_root: str, variant: str) -> dict | None:
    variant_dir = Path(history_root) / variant
    if not variant_dir.exists():
        return None
    history_files = sorted(variant_dir.glob("*/history.json"))
    if not history_files:
        return None
    for history_file in reversed(history_files):
        try:
            records = json.loads(history_file.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            continue
        if records:
            return records[-1]
    return None
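
# Worked example of the acceptance rule implemented below (made-up numbers):
# if B2 has val_closed_accuracy = 0.80 and the refined run reaches 0.81, then
# closed_ok holds; as long as none of accuracy/F1/BLEU-4 drops by more than
# 1.0 pp, the run is "accepted" even when the open-ended metrics stay flat.
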
return "N/A" if delta is None else f"{delta:.2f}" summary = ( f"DPO vs B2 deltas (pp): Acc={_fmt(deltas['accuracy'])} | F1={_fmt(deltas['f1'])} | " f"BLEU={_fmt(deltas['bleu4'])} | Closed={_fmt(deltas['closed_acc'])} | " f"OpenSem={_fmt(deltas['open_semantic'])} | OpenBERT={_fmt(deltas['open_bert'])}" ) return { "status": "accepted" if accepted else "failed", "reason": "criteria_met" if accepted else "metric_drop_or_no_gain", "summary": summary, "deltas_pp": deltas, "closed_ok": closed_ok, "open_ok": open_ok, } def evaluate_refinement_acceptance(base_metrics: dict | None, rl_metrics: dict) -> dict: return evaluate_dpo_acceptance(base_metrics, rl_metrics) def sanitize_dpo_completion(question: str, answer: str, max_words: int = 10) -> str: question_norm = normalize_answer(question) answer_norm = postprocess_answer(answer, max_words=max_words) if answer_norm in {"yes", "có"}: return "có" if answer_norm in {"no", "không"}: return "không" is_closed = any( pattern in question_norm for pattern in ["bình thường", "bat thuong", "normal", "abnormal"] ) or question_norm.endswith(" không") or " có " in f" {question_norm} " if is_closed: if any(token in answer_norm for token in ["không", "no", "not normal", "abnormal"]): return "không" if any(token in answer_norm for token in ["có", "yes", "bình thường", "normal", "present", "detected"]): return "có" return answer_norm def resolve_dpo_image(item: dict, hf_train_data=None, image_dir: str | None = None): source_idx = item.get("source_idx") if hf_train_data is not None and source_idx is not None and 0 <= int(source_idx) < len(hf_train_data): img = hf_train_data[int(source_idx)].get("image") if img is not None and getattr(img, "mode", None) != "RGB": img = img.convert("RGB") return img image_name = item.get("image") if image_name and image_dir: img_path = os.path.join(image_dir, image_name) if os.path.exists(img_path): return Image.open(img_path).convert("RGB") return None def infer_closed_answer_type(item: dict, answer: str | None = None) -> bool: answer_norm = normalize_answer(answer if answer is not None else get_target_answer(item)) answer_type = str(item.get("answer_type", "")).strip().upper() label_closed = item.get("label_closed", None) if answer_type == "CLOSED" or label_closed in (0, 1): return True return answer_norm in {"có", "không", "yes", "no"} def move_model_batch_to_device(batch: dict, device: torch.device) -> dict: moved = {} for key, value in batch.items(): if hasattr(value, "to"): moved[key] = value.to(device) else: moved[key] = value return moved def build_multimodal_completion_batch(processor, prompts, completions, images, max_length=None): full_texts = [f"{prompt}{completion}" for prompt, completion in zip(prompts, completions)] batch = processor( text=full_texts, images=images, return_tensors="pt", padding=True, truncation=False, ) prompt_batch = processor( text=prompts, images=images, return_tensors="pt", padding=True, truncation=False, ) completion_mask = torch.zeros_like(batch["input_ids"], dtype=torch.long) prompt_lengths = prompt_batch["attention_mask"].sum(dim=1) for i, prompt_len in enumerate(prompt_lengths.tolist()): token_positions = batch["attention_mask"][i].nonzero(as_tuple=True)[0] completion_mask[i, token_positions[prompt_len:]] = 1 if max_length is not None and batch["input_ids"].shape[1] > max_length: batch["input_ids"] = batch["input_ids"][:, :max_length] batch["attention_mask"] = batch["attention_mask"][:, :max_length] completion_mask = completion_mask[:, :max_length] for key in ("token_type_ids", 
"mm_token_type_ids"): if key in batch: batch[key] = batch[key][:, :max_length] return batch, completion_mask def compute_masked_sequence_logprobs(model, batch, completion_mask): model_inputs = move_model_batch_to_device(batch, next(model.parameters()).device) completion_mask = completion_mask.to(model_inputs["input_ids"].device) outputs = model(**model_inputs) logits = outputs.logits[:, :-1, :] labels = model_inputs["input_ids"][:, 1:] token_mask = completion_mask[:, 1:].float() log_probs = F.log_softmax(logits, dim=-1) token_log_probs = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1) masked_log_probs = token_log_probs * token_mask denom = token_mask.sum(dim=1).clamp_min(1.0) seq_log_probs = masked_log_probs.sum(dim=1) / denom probs = log_probs.exp() token_entropy = -(probs * log_probs).sum(dim=-1) seq_entropy = (token_entropy * token_mask).sum(dim=1) / denom return seq_log_probs, seq_entropy def compute_single_open_reward(pred: str, ref: str) -> tuple[float, dict]: from src.utils.metrics import compute_exact_match, compute_f1, compute_rouge_l from src.utils import metrics as metrics_module norm_pred = normalize_answer(pred) or "." norm_ref = normalize_answer(ref) or "." exact = compute_exact_match(norm_pred, norm_ref) f1 = compute_f1(norm_pred, norm_ref) rouge_l = compute_rouge_l(norm_pred, norm_ref) bert = 0.0 scorer = getattr(metrics_module, "bert_scorer", None) if scorer is not None: try: _, _, bert_f1 = scorer.score([norm_pred], [norm_ref]) bert = float(bert_f1.mean().item()) except Exception: bert = 0.0 blended = (0.55 * bert) + (0.30 * f1) + (0.10 * rouge_l) + (0.05 * exact) reward = (2.0 * blended) - 1.0 return reward, { "bert": bert, "f1": f1, "rouge_l": rouge_l, "exact": exact, "blended": blended, } def train(args): # 1. Load Cấu hình with open(args.config, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) # ── WandB Setup ────────────────────────────────────────────────────────── _wandb_cfg = config.get("wandb", {}) _use_wandb = bool(os.environ.get("WANDB_API_KEY") or os.environ.get("WANDB_MODE")) if _use_wandb: _api_key = os.environ.get("WANDB_API_KEY") if _api_key: wandb.login(key=_api_key) # Offline mode: set WANDB_MODE=offline hoặc config wandb.offline: true _offline = _wandb_cfg.get("offline", False) or \ os.environ.get("WANDB_MODE", "").lower() == "offline" if _offline: os.environ["WANDB_MODE"] = "offline" print("[INFO] WandB chạy ở chế độ OFFLINE (sync sau bằng: wandb sync)") # Tags theo variant từ YAML _tags = _wandb_cfg.get("tags", {}).get(args.variant, []) # Rich config ghi đầy đủ thông tin experiment _run_config = { # ── Model architecture ── "variant": args.variant, "decoder_type": config["model_a"].get("decoder_type"), "image_encoder": config["model_a"].get("image_encoder"), "text_encoder": config["model_a"].get("text_encoder"), "hidden_size": config["model_a"].get("hidden_size"), "transformer_heads": config["model_a"].get("transformer_heads"), "transformer_ff_dim": config["model_a"].get("transformer_ff_dim"), "transformer_layers": config["model_a"].get("transformer_decoder_layers"), "norm_first": config["model_a"].get("transformer_norm_first"), "freeze_phobert_layers": config["model_a"].get("freeze_phobert_layers"), # ── Training ── "learning_rate": config["train"].get("learning_rate"), "phobert_lr": config["train"].get("phobert_lr"), "vision_lr": config["train"].get("vision_lr"), "batch_size": config["train"].get("batch_size"), "grad_accum_steps": config["train"].get("gradient_accumulation_steps"), "effective_batch": config["train"].get("batch_size", 
32) * config["train"].get("gradient_accumulation_steps", 1), "label_smoothing": config["train"].get("label_smoothing"), "open_loss_weight": config["train"].get("open_loss_weight"), "warmup_epochs": config["train"].get("warmup_epochs"), "scheduler": config["train"].get("scheduler"), "patience": config["train"].get("patience"), "use_amp": config["train"].get("use_amp"), # ── Data ── "dataset": config["data"].get("dataset_name"), "max_question_len": config["data"].get("max_question_len"), "max_answer_len": config["data"].get("max_answer_len"), # ── Eval ── "beam_width": config["eval"].get("beam_width_a") if args.variant in ("A1", "A2") else config["eval"].get("beam_width_b"), } # Thêm hardware info if torch.cuda.is_available(): _run_config["gpu_name"] = torch.cuda.get_device_name(0) _run_config["gpu_count"] = torch.cuda.device_count() _run_config["vram_gb"] = round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1) _entity = _wandb_cfg.get("entity") or None # None = WandB dùng default entity wandb.init( project=_wandb_cfg.get("project", "MedicalVQA-Vietnam"), entity=_entity, name=f"{args.variant}-{datetime.now().strftime('%m%d-%H%M')}", group=_wandb_cfg.get("group", "DL-Final"), job_type=_wandb_cfg.get("job_type", "train"), tags=_tags, notes=_wandb_cfg.get("notes", ""), config=_run_config, save_code=_wandb_cfg.get("save_code", True), reinit="finish_previous", # Kết thúc run trước nếu chạy nhiều variant liên tiếp ) print(f"[INFO] ✅ WandB run: {wandb.run.url}") # Watch model gradients nếu được bật if _wandb_cfg.get("watch_model", False): # model chưa khởi tạo ở đây — hook sẽ được gọi sau khi model được tạo os.environ["_WANDB_WATCH_PENDING"] = "1" else: print("[INFO] WandB không được cấu hình (thiếu WANDB_API_KEY) — bỏ qua logging.") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"[INFO] Thiết bị sử dụng: {device}") history_dir = create_history_dir(config.get("log_dir", "logs/medical_vqa"), args.variant) print(f"[INFO] Lưu training history tại: {history_dir}") # 2. 

    # 2. Tokenizer & Dataset
    tokenizer = AutoTokenizer.from_pretrained(config['model_a']['phobert_model'])
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
    transform = MedicalTransform(size=config['data']['image_size'])
    answer_max_words = int(config['data'].get('answer_max_words', 10))

    # Load data from the HuggingFace Hub or from local files
    hf_repo = config['data'].get('hf_dataset')
    use_hf_splits = bool(config['data'].get('use_hf_splits', True))
    if hf_repo and use_hf_splits:
        print(f"[INFO] Loading data from the Hub: {hf_repo}")
        dataset_dict = load_dataset(hf_repo)
        if args.debug:
            print("[WARNING] DEBUG MODE: using only 20 samples for a dry run.")
            dataset_dict['train'] = dataset_dict['train'].select(range(min(20, len(dataset_dict['train']))))
            config['train']['epochs'] = 2
            config['train']['batch_size'] = 2
        train_ds = MedicalVQADataset(
            hf_dataset=dataset_dict['train'],
            tokenizer=tokenizer,
            transform=transform,
            max_seq_len=config['data']['max_question_len'],
            max_ans_len=config['data']['max_answer_len'],
            answer_max_words=answer_max_words
        )
        val_ds = MedicalVQADataset(
            hf_dataset=dataset_dict['validation'],
            tokenizer=tokenizer,
            transform=transform,
            max_seq_len=config['data']['max_question_len'],
            max_ans_len=config['data']['max_answer_len'],
            answer_max_words=answer_max_words
        )
    else:
        vqa_path = config['data']['vqa_json']
        print(f"[INFO] Loading local data from: {vqa_path}")
        full_dataset = MedicalVQADataset(
            json_path=vqa_path,
            image_dir=config['data']['image_dir'],
            tokenizer=tokenizer,
            transform=transform,
            max_seq_len=config['data']['max_question_len'],
            max_ans_len=config['data']['max_answer_len'],
            answer_max_words=answer_max_words
        )
        train_size = int(0.8 * len(full_dataset))
        val_size = len(full_dataset) - train_size
        train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    train_loader = DataLoader(
        train_ds,
        batch_size=config['train']['batch_size'],
        shuffle=True,
        collate_fn=vqa_collate_fn,
        num_workers=config['train'].get('num_workers', 0),
        pin_memory=config['train'].get('pin_memory', False)
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=config['train'].get('eval_batch_size', 8),
        collate_fn=vqa_collate_fn
    )
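
    # Rough shape of one collated batch (assuming the dataset yields these keys):
    #   batch["input_ids"]  -> LongTensor [B, max_question_len]
    #   batch["image"]      -> stacked image tensors [B, 3, H, W]
    #   raw-text fields     -> plain list[str] of length B (see vqa_collate_fn)
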

    # 3. Instantiate the model for the selected variant
    if args.variant in ['A1', 'A2']:
        decoder_type = "lstm" if args.variant == 'A1' else "transformer"
        model = MedicalVQAModelA(
            decoder_type=decoder_type,
            vocab_size=len(tokenizer),
            hidden_size=config['model_a'].get('hidden_size', 768),
            phobert_model=config['model_a'].get('phobert_model', "vinai/phobert-base")
        ).to(device)

        # Log model parameter counts to WandB
        if wandb.run:
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            wandb.config.update({
                "total_params_M": round(total_params / 1e6, 2),
                "trainable_params_M": round(trainable_params / 1e6, 2),
            })
            print(f"[INFO] Total params: {total_params/1e6:.1f}M | Trainable: {trainable_params/1e6:.1f}M")
            # wandb.watch: only enabled when log_gradients: true
            if _wandb_cfg.get("log_gradients", False):
                wandb.watch(model, log="gradients", log_freq=_wandb_cfg.get("log_freq", 50))

        # Optimizer with differential learning rates per module
        optimizer = optim.AdamW([
            {'params': model.image_encoder.parameters(), 'lr': float(config['train']['vision_lr'])},
            {'params': model.text_encoder.parameters(), 'lr': float(config['train']['phobert_lr'])},
            {'params': model.fusion.parameters(), 'lr': float(config['train']['learning_rate'])},
            {'params': model.decoder.parameters(), 'lr': float(config['train']['learning_rate'])}
        ])

        # [CRITICAL FIX] Cosine schedule with warmup, stepped per batch instead of per epoch
        from transformers import get_cosine_schedule_with_warmup
        # Use a_epochs for Direction A models (A1, A2), otherwise use default epochs
        if args.variant in ['A1', 'A2']:
            epochs = config['train'].get('a_epochs', config['train']['epochs'])
        else:
            epochs = config['train']['epochs']
        warmup_epochs = config['train'].get('warmup_epochs', 5)
        accumulation_steps = config['train'].get('gradient_accumulation_steps', 2)
        total_steps = epochs * len(train_loader) // max(accumulation_steps, 1)
        warmup_steps = warmup_epochs * len(train_loader) // max(accumulation_steps, 1)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
        )

        # Instantiate the trainer with pad_token_id and beam_width from config
        beam_width = config['eval'].get('beam_width_a', 5)
        from src.engine.trainer import MedicalVQATrainer
        trainer = MedicalVQATrainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            config={
                **config,
                'variant': args.variant,
                'history_dir': history_dir,
                # Pass the tunable open-loss weight so the trainer doesn't use a hardcoded value
                'open_loss_weight': config['train'].get('open_loss_weight', 2.0),
            },
            pad_token_id=tokenizer.pad_token_id,
            beam_width=beam_width
        )
        print(f"[INFO] Beam width for Direction A: {beam_width}")
        print(f"[INFO] Starting training for variant {args.variant} ({epochs} epochs)...")
        trainer.train(epochs, tokenizer=tokenizer)
        if wandb.run:
            wandb.finish()
        return
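
    # Note: the 'PPO' branch below is a lightweight PPO-style refinement; it uses a
    # clipped-ratio surrogate with batch-normalized advantages but, unlike full PPO,
    # has no learned value critic and no KL penalty against a reference model.
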
    elif args.variant == 'PPO':
        from src.engine.medical_eval import evaluate_multimodal_vqa
        ppo_cfg = config.get('ppo', {})
        ppo_answer_max_words = int(ppo_cfg.get('max_answer_words', min(answer_max_words, 6)))
        wrapper = MultimodalVQA(
            model_id=config['model_b']['model_name'],
            lora_r=int(config['model_b'].get('lora_r', 16)),
            lora_alpha=int(config['model_b'].get('lora_alpha', 32)),
            lora_dropout=float(config['model_b'].get('lora_dropout', 0.05)),
            lora_target_modules=config['model_b'].get('lora_target_modules'),
        )
        b2_checkpoint = select_best_adapter_checkpoint(config['train'].get('b2_output_dir', './checkpoints/B2'))
        print(f"[INFO] PPO will initialize from B2 checkpoint: {b2_checkpoint}")
        model, processor = wrapper.load_model(adapter_path=str(b2_checkpoint), is_trainable=True)
        if not ppo_cfg.get('train_mlp_lora', False):
            frozen_lora = 0
            for name, param in model.named_parameters():
                if "lora_" in name and any(proj in name for proj in ("gate_proj", "up_proj", "down_proj")):
                    param.requires_grad = False
                    frozen_lora += param.numel()
            print(f"[INFO] PPO freezes the LoRA MLP projections to save VRAM: {frozen_lora:,} parameters")
        model.print_trainable_parameters()

        def _build_ppo_source():
            if hf_repo:
                return dataset_dict['train'], dataset_dict['train']
            if hasattr(train_ds, "dataset") and hasattr(train_ds.dataset, "data"):
                subset_indices = getattr(train_ds, "indices", list(range(len(train_ds.dataset.data))))
                local_items = [train_ds.dataset.data[i] for i in subset_indices]
                return local_items, None
            raise ValueError("Cannot access the raw train data to build the PPO rollout set.")

        def _prepare_ppo_records(raw_items, num_samples: int, closed_ratio: float):
            closed_records = []
            open_records = []
            for idx in range(len(raw_items)):
                item = raw_items[idx]
                question = str(item.get("question_vi", item.get("question", ""))).strip()
                target = get_target_answer(item, max_words=ppo_answer_max_words)
                if not question or not target:
                    continue
                record = {
                    "question": question,
                    "target": target,
                    "source_idx": idx,
                    "image": item.get("image_name"),
                    "is_closed": infer_closed_answer_type(item, target),
                }
                if record["is_closed"]:
                    closed_records.append(record)
                else:
                    open_records.append(record)
            rng = random.Random(int(config.get("seed", 42)))
            rng.shuffle(closed_records)
            rng.shuffle(open_records)
            target_closed = min(len(closed_records), int(round(num_samples * closed_ratio)))
            target_open = min(len(open_records), max(0, num_samples - target_closed))
            selected = closed_records[:target_closed] + open_records[:target_open]
            rng.shuffle(selected)
            return selected

        raw_train_source, hf_train_source = _build_ppo_source()
        ppo_records = _prepare_ppo_records(
            raw_train_source,
            num_samples=int(ppo_cfg.get('num_samples', 192)),
            closed_ratio=float(ppo_cfg.get('closed_ratio', 0.5)),
        )
        if not ppo_records:
            raise ValueError("Could not build a valid PPO rollout set.")
        print(f"[INFO] PPO rollout set: {len(ppo_records)} samples")

        trainable_params = [param for param in model.parameters() if param.requires_grad]
        optimizer = optim.AdamW(
            trainable_params,
            lr=float(ppo_cfg.get('learning_rate', 5.0e-7)),
            weight_decay=float(ppo_cfg.get('weight_decay', 0.0)),
        )
        rollout_batch_size = max(1, int(ppo_cfg.get('rollout_batch_size', 2)))
        total_updates = max(1, (len(ppo_records) + rollout_batch_size - 1) // rollout_batch_size)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_updates)
        ppo_history = []
        eos = processor.tokenizer.eos_token or ""
        max_seq_length = max(int(config['train'].get('dpo_max_length', 768)), 768)
        grad_clip = float(config['train'].get('grad_clip', 1.0))
        entropy_coef = float(ppo_cfg.get('entropy_coef', 0.001))
        clip_range = float(ppo_cfg.get('clip_range', 0.2))
        max_new_tokens = int(ppo_cfg.get('max_new_tokens', 12))
        temperature = float(ppo_cfg.get('temperature', 0.8))
        top_p = float(ppo_cfg.get('top_p', 0.9))
        closed_positive = float(ppo_cfg.get('closed_positive_reward', 1.0))
        closed_negative = float(ppo_cfg.get('closed_negative_reward', -1.0))

        print("[INFO] Starting PPO-style refinement training...")
        model.train()
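
        # Reward scheme: closed questions earn closed_positive/closed_negative for an
        # exact normalized match, while open questions get the blended lexical/semantic
        # reward in [-1, 1] from compute_single_open_reward.
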
        for update_idx in range(total_updates):
            batch_records = ppo_records[update_idx * rollout_batch_size:(update_idx + 1) * rollout_batch_size]
            prompts, images, questions, targets, closed_flags = [], [], [], [], []
            for record in batch_records:
                image = resolve_dpo_image(
                    record,
                    hf_train_data=hf_train_source,
                    image_dir=config['data'].get('image_dir'),
                )
                if image is None:
                    continue
                prompts.append(build_dpo_instruction_prompt(record["question"], max_words=ppo_answer_max_words))
                images.append(image)
                questions.append(record["question"])
                targets.append(record["target"])
                closed_flags.append(record["is_closed"])
            if not prompts:
                continue

            generation_inputs = processor(
                text=prompts,
                images=images,
                return_tensors="pt",
                padding=True,
            )
            generation_inputs = move_model_batch_to_device(generation_inputs, next(model.parameters()).device)
            if "pixel_values" in generation_inputs:
                generation_inputs["pixel_values"] = generation_inputs["pixel_values"].to(torch.bfloat16)
            with torch.no_grad():
                generated_ids = model.generate(
                    **generation_inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    num_beams=1,
                    pad_token_id=processor.tokenizer.pad_token_id,
                    eos_token_id=processor.tokenizer.eos_token_id,
                )
            prompt_token_len = generation_inputs["input_ids"].shape[1]
            generated_texts = processor.batch_decode(
                generated_ids[:, prompt_token_len:],
                skip_special_tokens=True,
            )

            sampled_answers = []
            rewards = []
            reward_breakdown = []
            for question, target, is_closed, raw_output in zip(questions, targets, closed_flags, generated_texts):
                pred = sanitize_dpo_completion(question, raw_output, max_words=ppo_answer_max_words)
                if not pred:
                    pred = "không" if is_closed else "không rõ"
                sampled_answers.append(pred)
                if is_closed:
                    reward = closed_positive if normalize_answer(pred) == normalize_answer(target) else closed_negative
                    rewards.append(reward)
                    reward_breakdown.append({"exact": float(reward > 0), "reward": reward})
                else:
                    reward, details = compute_single_open_reward(pred, target)
                    rewards.append(reward)
                    reward_breakdown.append(details | {"reward": reward})

            completion_texts = [f" {pred}{eos}" for pred in sampled_answers]
            rollout_batch, rollout_mask = build_multimodal_completion_batch(
                processor,
                prompts,
                completion_texts,
                images,
                max_length=max_seq_length,
            )
            with torch.no_grad():
                old_seq_log_probs, _ = compute_masked_sequence_logprobs(model, rollout_batch, rollout_mask)
            reward_tensor = torch.tensor(rewards, dtype=torch.float32, device=old_seq_log_probs.device)
            if reward_tensor.numel() > 1:
                advantages = reward_tensor - reward_tensor.mean()
                advantages = advantages / advantages.std(unbiased=False).clamp_min(1e-6)
            else:
                advantages = reward_tensor
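
            # The update below minimizes the clipped surrogate loss
            #   L = -E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)] - c_H * H,
            # with r = exp(logp_new - logp_old), eps = clip_range, c_H = entropy_coef.
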
            optimizer.zero_grad(set_to_none=True)
            new_seq_log_probs, entropy = compute_masked_sequence_logprobs(model, rollout_batch, rollout_mask)
            ratios = torch.exp(new_seq_log_probs - old_seq_log_probs.detach())
            clipped_ratios = torch.clamp(ratios, 1.0 - clip_range, 1.0 + clip_range)
            surrogate_1 = ratios * advantages
            surrogate_2 = clipped_ratios * advantages
            policy_loss = -torch.min(surrogate_1, surrogate_2).mean()
            entropy_bonus = entropy.mean()
            loss = policy_loss - (entropy_coef * entropy_bonus)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable_params, grad_clip)
            optimizer.step()
            scheduler.step()

            closed_rewards = [r for r, is_closed in zip(rewards, closed_flags) if is_closed]
            open_rewards = [r for r, is_closed in zip(rewards, closed_flags) if not is_closed]
            log_record = {
                "epoch": 1,
                "update": update_idx + 1,
                "train_loss": float(loss.detach().cpu().item()),
                "policy_loss": float(policy_loss.detach().cpu().item()),
                "entropy": float(entropy_bonus.detach().cpu().item()),
                "avg_reward": float(sum(rewards) / len(rewards)),
                "avg_closed_reward": float(sum(closed_rewards) / len(closed_rewards)) if closed_rewards else None,
                "avg_open_reward": float(sum(open_rewards) / len(open_rewards)) if open_rewards else None,
                "learning_rate": float(scheduler.get_last_lr()[0]),
                "sample_predictions": sampled_answers[:2],
                "sample_targets": targets[:2],
                "reward_breakdown": reward_breakdown[:2],
            }
            ppo_history.append(log_record)
            if wandb.run:
                wandb.log({
                    "ppo/train_loss": log_record["train_loss"],
                    "ppo/policy_loss": log_record["policy_loss"],
                    "ppo/entropy": log_record["entropy"],
                    "ppo/avg_reward": log_record["avg_reward"],
                    "ppo/avg_closed_reward": log_record["avg_closed_reward"],
                    "ppo/avg_open_reward": log_record["avg_open_reward"],
                    "ppo/learning_rate": log_record["learning_rate"],
                    "ppo/update": log_record["update"],
                })
            del generation_inputs, generated_ids
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        final_ppo_dir = Path("checkpoints/PPO/final_adapter")
        final_ppo_dir.mkdir(parents=True, exist_ok=True)
        model.save_pretrained(str(final_ppo_dir))
        processor.save_pretrained(str(final_ppo_dir))
        with open("checkpoints/medical_vqa_ppo_from.txt", "w", encoding="utf-8") as f:
            f.write(str(b2_checkpoint))

        print("[INFO] Running acceptance evaluation on the validation set for PPO...")
        model.eval()
        metrics = evaluate_multimodal_vqa(
            model, val_loader, device, processor,
            beam_width=config['eval'].get('beam_width_b', 1),
            beam_width_closed=config['eval'].get('beam_width_b_closed', 1),
            beam_width_open=config['eval'].get('beam_width_b_open', config['eval'].get('beam_width_b', 1)),
            max_new_tokens_closed=config['eval'].get('max_new_tokens_b_closed', 4),
            max_new_tokens_open=config['eval'].get('max_new_tokens_b_open', answer_max_words + 6),
            generation_batch_size=config['eval'].get('generation_batch_size_b', 1),
            max_words=answer_max_words,
            variant='PPO'
        )
        closed_eval = metrics.get('closed_eval', {})
        open_eval = metrics.get('open_eval', {})
        ppo_history.append({
            "epoch": 1,
            "val_accuracy_normalized": metrics.get('accuracy_normalized'),
            "val_f1_normalized": metrics.get('f1_normalized'),
            "val_bleu4_normalized": metrics.get('bleu4_normalized'),
            "val_bert_score_raw": metrics.get('bert_score_raw'),
            "val_semantic_raw": metrics.get('semantic_raw'),
            "val_closed_accuracy": closed_eval.get('accuracy', 0),
            "val_closed_em": closed_eval.get('em', 0),
            "val_closed_f1": closed_eval.get('f1', 0),
            "val_open_semantic": open_eval.get('semantic', 0),
            "val_open_bertscore": open_eval.get('bert_score', 0),
            "val_open_f1": open_eval.get('f1', 0),
            "val_open_rouge_l": open_eval.get('rouge_l', 0),
        })
        b2_metrics = load_latest_variant_metrics(os.path.join(config['log_dir'], "history"), "B2")
        ppo_acceptance = evaluate_refinement_acceptance(b2_metrics, ppo_history[-1])
        ppo_history[-1]["ppo_acceptance"] = ppo_acceptance
        print(f"[INFO] {ppo_acceptance['summary']}")
        if ppo_acceptance["status"] == "accepted":
            print("[SUCCESS] PPO accepted: meets the light-refinement criteria over B2.")
        elif ppo_acceptance["status"] == "failed":
            print("[WARN] PPO failed, keeping B2. Further tuning is not recommended.")
        os.makedirs("checkpoints/PPO", exist_ok=True)
        with open("checkpoints/PPO/acceptance_summary.json", "w", encoding="utf-8") as f:
            json.dump(ppo_acceptance, f, ensure_ascii=False, indent=2)
        save_history_records(history_dir, ppo_history)
        print("[SUCCESS] Saved PPO checkpoint and metrics.")
        return

    elif args.variant == 'DPO':
        from trl import DPOTrainer
        try:
            from trl import DPOConfig
        except ImportError:
            DPOConfig = None
        from transformers import TrainingArguments
        from datasets import Dataset as HFDataset
        import inspect

        dpo_answer_max_words = int(config.get('dpo', {}).get('max_answer_words', min(answer_max_words, 6)))
        wrapper = MultimodalVQA(
            model_id=config['model_b']['model_name'],
            lora_r=int(config['model_b'].get('lora_r', 16)),
            lora_alpha=int(config['model_b'].get('lora_alpha', 32)),
            lora_dropout=float(config['model_b'].get('lora_dropout', 0.05)),
            lora_target_modules=config['model_b'].get('lora_target_modules'),
        )
        explicit_b2_checkpoint = (
            config.get('train', {}).get('b2_checkpoint')
            or os.environ.get('B2_CHECKPOINT_PATH')
        )
        if explicit_b2_checkpoint:
            b2_checkpoint = Path(explicit_b2_checkpoint).expanduser().resolve()
            if not b2_checkpoint.exists():
                raise FileNotFoundError(f"Specified B2 checkpoint not found: {b2_checkpoint}")
            print(f"[INFO] DPO will initialize from the specified B2 checkpoint: {b2_checkpoint}")
        else:
            b2_checkpoint = select_best_adapter_checkpoint(config['train'].get('b2_output_dir', './checkpoints/B2'))
            print(f"[INFO] DPO will initialize from B2 checkpoint: {b2_checkpoint}")
        try:
            model, processor = wrapper.load_model(adapter_path=str(b2_checkpoint), is_trainable=True)
        except Exception as exc:
            print(f"[WARNING] Could not load the B2 checkpoint, falling back to base LLaVA-Med + fresh LoRA: {exc}")
            model, processor = wrapper.load_model(adapter_path=None, is_trainable=True)
        if not config['train'].get('dpo_train_mlp_lora', False):
            frozen_lora = 0
            for name, param in model.named_parameters():
                if "lora_" in name and any(proj in name for proj in ("gate_proj", "up_proj", "down_proj")):
                    param.requires_grad = False
                    frozen_lora += param.numel()
            print(f"[INFO] DPO freezes the LoRA MLP projections to save VRAM: {frozen_lora:,} parameters")
        model.print_trainable_parameters()
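
        # A preference record is expected to look roughly like this (illustrative):
        #   {"question": "...", "chosen": "có", "rejected": "mô tả dài dòng ...",
        #    "image": "<file name>.jpg", "source_idx": 12}
        # create_preference_data emits such pairs from the training answers.
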
        # Create/load preference data
        pref_json = config.get('dpo', {}).get('preference_data', 'data/preference_data_slake.json')
        force_rebuild_pref = bool(config.get('dpo', {}).get('force_rebuild_preference_data', False))
        if force_rebuild_pref and os.path.exists(pref_json):
            print(f"[INFO] Deleting old preference data so it is rebuilt with the current config: {pref_json}")
            os.remove(pref_json)
        if not os.path.exists(pref_json):
            print("[INFO] No preference data found. Generating it automatically from the training data...")
            from src.engine.dpo_trainer import create_preference_data
            if hf_repo:
                raw_data = [
                    {
                        "question_vi": item["question_vi"],
                        "answer_vi": get_target_answer(item, max_words=dpo_answer_max_words),
                        "image_name": item.get("image_name"),
                        "source_idx": i,
                    }
                    for i, item in enumerate(dataset_dict['train'])
                ]
                tmp_json = "data/tmp_train_for_dpo.json"
                os.makedirs("data", exist_ok=True)
                with open(tmp_json, 'w', encoding='utf-8') as f:
                    json.dump(raw_data, f, ensure_ascii=False, indent=2)
                create_preference_data(
                    tmp_json,
                    pref_json,
                    num_pairs=int(config.get('dpo', {}).get('num_pairs', 400)),
                    closed_ratio=float(config.get('dpo', {}).get('closed_ratio', 0.6)),
                    max_answer_words=dpo_answer_max_words,
                )
            else:
                create_preference_data(
                    config['data']['vqa_json'],
                    pref_json,
                    num_pairs=int(config.get('dpo', {}).get('num_pairs', 400)),
                    closed_ratio=float(config.get('dpo', {}).get('closed_ratio', 0.6)),
                    max_answer_words=dpo_answer_max_words,
                )

        # Read the preference data JSON
        with open(pref_json, 'r', encoding='utf-8') as f:
            pref_data = json.load(f)
        if hf_repo and any("source_idx" not in item for item in pref_data):
            print("[INFO] Old preference data lacks source_idx. Rebuilding it to keep the image link for DPO...")
            from src.engine.dpo_trainer import create_preference_data
            raw_data = [
                {
                    "question_vi": item["question_vi"],
                    "answer_vi": get_target_answer(item, max_words=dpo_answer_max_words),
                    "image_name": item.get("image_name"),
                    "source_idx": i,
                }
                for i, item in enumerate(dataset_dict['train'])
            ]
            tmp_json = "data/tmp_train_for_dpo.json"
            with open(tmp_json, 'w', encoding='utf-8') as f:
                json.dump(raw_data, f, ensure_ascii=False, indent=2)
            create_preference_data(
                tmp_json,
                pref_json,
                num_pairs=int(config.get('dpo', {}).get('num_pairs', 400)),
                closed_ratio=float(config.get('dpo', {}).get('closed_ratio', 0.6)),
                max_answer_words=dpo_answer_max_words,
            )
            with open(pref_json, 'r', encoding='utf-8') as f:
                pref_data = json.load(f)

        # Build the HF Dataset for DPOTrainer (requires columns: prompt, chosen, rejected)
        prompts, chosens, rejecteds, images = [], [], [], []
        eos = processor.tokenizer.eos_token or ""
        filtered_pairs = 0
        for item in pref_data:
            q = item.get("question", "")
            chosen = sanitize_dpo_completion(q, item.get("chosen", ""), max_words=dpo_answer_max_words)
            rejected = sanitize_dpo_completion(q, item.get("rejected", ""), max_words=dpo_answer_max_words)
            image = resolve_dpo_image(
                item,
                hf_train_data=dataset_dict['train'] if hf_repo else None,
                image_dir=config['data'].get('image_dir'),
            )
            if not chosen or not rejected or chosen == rejected or image is None:
                filtered_pairs += 1
                continue
            prompts.append(build_dpo_instruction_prompt(q, max_words=dpo_answer_max_words))
            chosens.append(f" {chosen}{eos}")
            rejecteds.append(f" {rejected}{eos}")
            images.append(image)
        if not prompts:
            raise ValueError("No valid preference pairs remain after sanitizing the DPO data.")
        if filtered_pairs:
            print(f"[INFO] Skipped {filtered_pairs} invalid preference pairs after sanitizing.")
        dpo_hf_dataset = HFDataset.from_dict({
            "prompt": prompts,
            "chosen": chosens,
            "rejected": rejecteds,
            "image": images,
        })

        class MultimodalDPODataCollator:
            def __init__(self, processor, max_length=None):
                self.processor = processor
                self.tokenizer = processor.tokenizer
                # LLaVA expands a single placeholder into hundreds of visual tokens.
                # If max_length is too small, the processor truncates those tokens and raises
                # an "image token count" mismatch. Keep a safe floor for multimodal DPO.
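                # Rows are packed as 2B examples per batch: [0, B) holds
                # prompt+chosen, [B, 2B) holds prompt+rejected, with the image
                # list duplicated to match (see __call__ below).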
                self.max_length = max(max_length or 0, 768) if max_length is not None else None

            def __call__(self, examples):
                prompts = [example["prompt"] for example in examples]
                chosens = [example["chosen"] for example in examples]
                rejecteds = [example["rejected"] for example in examples]
                images = [example["image"] for example in examples]
                full_texts = [f"{prompt}{chosen}" for prompt, chosen in zip(prompts, chosens)]
                full_texts.extend(f"{prompt}{rejected}" for prompt, rejected in zip(prompts, rejecteds))
                repeated_prompts = prompts + prompts
                repeated_images = images + images
                batch = self.processor(
                    text=full_texts,
                    images=repeated_images,
                    return_tensors="pt",
                    padding=True,
                    truncation=False,
                )
                prompt_batch = self.processor(
                    text=repeated_prompts,
                    images=repeated_images,
                    return_tensors="pt",
                    padding=True,
                    truncation=False,
                )
                completion_mask = torch.zeros_like(batch["input_ids"], dtype=torch.long)
                prompt_lengths = prompt_batch["attention_mask"].sum(dim=1)
                for i, prompt_len in enumerate(prompt_lengths.tolist()):
                    token_positions = batch["attention_mask"][i].nonzero(as_tuple=True)[0]
                    completion_mask[i, token_positions[prompt_len:]] = 1
                if self.max_length is not None and batch["input_ids"].shape[1] > self.max_length:
                    batch["input_ids"] = batch["input_ids"][:, :self.max_length]
                    batch["attention_mask"] = batch["attention_mask"][:, :self.max_length]
                    completion_mask = completion_mask[:, :self.max_length]
                    for key in ("token_type_ids", "mm_token_type_ids"):
                        if key in batch:
                            batch[key] = batch[key][:, :self.max_length]
                batch["completion_mask"] = completion_mask
                return batch

        dpo_sequence_limits = {
            "max_length": max(int(config['train'].get('dpo_max_length', 768)), 768),
            "max_prompt_length": int(config['train'].get('dpo_max_prompt_length', 96)),
            "max_completion_length": int(config['train'].get('dpo_max_completion_length', 24)),
        }
        training_args_dict = {
            "output_dir": "./checkpoints/DPO",
            "per_device_train_batch_size": int(config['train'].get('dpo_batch_size', 1)),
            "gradient_accumulation_steps": int(config['train'].get('dpo_gradient_accumulation_steps', 8)),
            "num_train_epochs": config['train'].get('dpo_epochs', 1),
            "learning_rate": float(config.get('dpo', {}).get('learning_rate', 1.0e-6)),
            "lr_scheduler_type": "cosine",  # [OPTIMIZED] Smoother convergence
            "warmup_ratio": 0.1,  # [OPTIMIZED] Avoid gradient shock in the first epoch
            "bf16": True,
            "remove_unused_columns": False,
            "logging_steps": 10,
            "save_strategy": "epoch",
            "save_total_limit": 1,
            "optim": config['train'].get('dpo_optim', 'paged_adamw_8bit'),
            "gradient_checkpointing": True,
        }
        if DPOConfig is not None:
            training_args_dict["beta"] = float(config.get('dpo', {}).get('beta', 0.1))
            dpo_config_params = set(inspect.signature(DPOConfig.__init__).parameters)
            for key, value in dpo_sequence_limits.items():
                if key in dpo_config_params:
                    training_args_dict[key] = value
            training_args = DPOConfig(**training_args_dict)
        else:
            training_args = build_training_arguments(TrainingArguments, **training_args_dict)
        training_args.model_init_kwargs = None

        dpo_kwargs = {
            "model": model,
            "args": training_args,
            "train_dataset": dpo_hf_dataset,
            "data_collator": MultimodalDPODataCollator(processor, max_length=dpo_sequence_limits["max_length"]),
        }
        dpo_trainer_params = set(inspect.signature(DPOTrainer.__init__).parameters)
        for key, value in dpo_sequence_limits.items():
            if key in dpo_trainer_params:
                dpo_kwargs[key] = value
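
        # trl has renamed this argument across releases (tokenizer -> processing_class),
        # so the cascade below tries the newest signature first and then falls back.
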
        try:
            print("[INFO] Trying to initialize DPOTrainer with processing_class...")
            trainer = DPOTrainer(**dpo_kwargs, processing_class=processor)
        except TypeError:
            try:
                trainer = DPOTrainer(**dpo_kwargs, tokenizer=processor)
            except TypeError:
                trainer = DPOTrainer(**dpo_kwargs, tokenizer=processor.tokenizer)

        print("[INFO] Starting DPO training...")
        trainer.train()
        os.makedirs("checkpoints", exist_ok=True)
        final_dpo_dir = Path("checkpoints/DPO/final_adapter")
        final_dpo_dir.mkdir(parents=True, exist_ok=True)
        model.save_pretrained(str(final_dpo_dir))
        processor.save_pretrained(str(final_dpo_dir))
        with open("checkpoints/medical_vqa_dpo_from.txt", "w", encoding="utf-8") as f:
            f.write(str(b2_checkpoint))

        # [FIX] Evaluate DPO after training to get Accuracy, F1, BLEU for the comparison charts
        from src.engine.medical_eval import evaluate_multimodal_vqa
        print("[INFO] Running acceptance evaluation on the validation set for DPO...")
        model.eval()
        metrics = evaluate_multimodal_vqa(
            model, val_loader, device, processor,
            beam_width=config['eval'].get('beam_width_b', 1),
            beam_width_closed=config['eval'].get('beam_width_b_closed', 1),
            beam_width_open=config['eval'].get('beam_width_b_open', config['eval'].get('beam_width_b', 1)),
            max_new_tokens_closed=config['eval'].get('max_new_tokens_b_closed', 4),
            max_new_tokens_open=config['eval'].get('max_new_tokens_b_open', answer_max_words + 6),
            generation_batch_size=config['eval'].get('generation_batch_size_b', 1),
            max_words=answer_max_words,
            variant='DPO'
        )
        closed_eval = metrics.get('closed_eval', {})
        open_eval = metrics.get('open_eval', {})
        print("\n[RESULT DPO - CLOSED QUESTIONS]")
        print(f"Count: {closed_eval.get('count', 0)}")
        print(f"Accuracy: {closed_eval.get('accuracy', 0):.4f}")
        print(f"EM: {closed_eval.get('em', 0):.4f}")
        print(f"F1: {closed_eval.get('f1', 0):.4f}")
        print("\n[RESULT DPO - OPEN QUESTIONS]")
        print(f"Count: {open_eval.get('count', 0)}")
        print(f"Semantic: {open_eval.get('semantic', 0):.4f}")
        print(f"BERTScore: {open_eval.get('bert_score', 0):.4f}")
        print(f"F1: {open_eval.get('f1', 0):.4f}")
        print(f"ROUGE-L: {open_eval.get('rouge_l', 0):.4f}")

        final_epoch = training_args.num_train_epochs
        trainer.state.log_history.append({
            "epoch": final_epoch,
            "val_accuracy_normalized": metrics.get('accuracy_normalized'),
            "val_f1_normalized": metrics.get('f1_normalized'),
            "val_bleu4_normalized": metrics.get('bleu4_normalized'),
            "val_bert_score_raw": metrics.get('bert_score_raw'),
            "val_semantic_raw": metrics.get('semantic_raw'),
            "val_closed_accuracy": closed_eval.get('accuracy', 0),
            "val_closed_em": closed_eval.get('em', 0),
            "val_closed_f1": closed_eval.get('f1', 0),
            "val_open_semantic": open_eval.get('semantic', 0),
            "val_open_bertscore": open_eval.get('bert_score', 0),
            "val_open_f1": open_eval.get('f1', 0),
            "val_open_rouge_l": open_eval.get('rouge_l', 0),
        })
        b2_metrics = load_latest_variant_metrics(os.path.join(config['log_dir'], "history"), "B2")
        dpo_acceptance = evaluate_dpo_acceptance(b2_metrics, trainer.state.log_history[-1])
        trainer.state.log_history[-1]["dpo_acceptance"] = dpo_acceptance
        print(f"[INFO] {dpo_acceptance['summary']}")
        if dpo_acceptance["status"] == "accepted":
            print("[SUCCESS] DPO accepted: meets the light-refinement criteria over B2.")
        elif dpo_acceptance["status"] == "failed":
            print("[WARN] DPO failed, keeping B2. Further tuning is not recommended.")
        os.makedirs("checkpoints/DPO", exist_ok=True)
        with open("checkpoints/DPO/acceptance_summary.json", "w", encoding="utf-8") as f:
            json.dump(dpo_acceptance, f, ensure_ascii=False, indent=2)
        save_history_records(history_dir, trainer.state.log_history)
        print("[SUCCESS] Saved DPO checkpoint and metrics.")
        return

    elif args.variant == 'B2':
        # Fine-tuning LLaVA-Med
        from transformers import TrainingArguments, Trainer
        from datasets import Dataset as HFDataset
        wrapper = MultimodalVQA(
            model_id=config['model_b']['model_name'],
            lora_r=int(config['model_b'].get('lora_r', 16)),
            lora_alpha=int(config['model_b'].get('lora_alpha', 32)),
            lora_dropout=float(config['model_b'].get('lora_dropout', 0.05)),
            lora_target_modules=config['model_b'].get('lora_target_modules'),
        )
        model, processor = wrapper.load_model()

        def make_sft_dataset(raw_ds):
            prompts = []
            answers = []
            texts = []
            images = []
            eos = processor.tokenizer.eos_token or ""
            for i in range(len(raw_ds)):
                item = raw_ds[i]
                if not isinstance(item, dict):
                    continue
                q = item.get("question_vi", item.get("question", item.get("raw_questions", "")))
                a = get_target_answer(item, max_words=answer_max_words)
                answer_type = str(item.get("answer_type", "")).upper()
                label_closed = item.get("label_closed", None)
                if answer_type == "CLOSED" or label_closed in (0, 1) or a in {"có", "không", "yes", "no"}:
                    a_norm = str(a).strip().lower()
                    a = "không" if a_norm in {"không", "khong", "no", "false", "absent"} else "có"
                # Skip samples without an image so all columns stay aligned
                # (HFDataset.from_dict requires equal-length columns).
                img = item.get("image", None)
                if img is None:
                    continue
                if img.mode != "RGB":
                    img = img.convert("RGB")
                prompt = wrapper.build_instruction_prompt(q, language="vi", include_answer=False)
                prompts.append(prompt)
                answers.append(a)
                texts.append(f"{prompt} {a}{eos}")
                images.append(img)
            return HFDataset.from_dict({"prompt": prompts, "answer": answers, "text": texts, "image": images})

        if hf_repo:
            sft_train = make_sft_dataset(dataset_dict['train'])
            sft_val = make_sft_dataset(dataset_dict['validation'])
        else:
            sft_train = make_sft_dataset(train_ds)
            sft_val = make_sft_dataset(val_ds)

        class MultimodalDataCollator:
            def __init__(self, processor, max_length=None):
                self.processor = processor
                self.tokenizer = processor.tokenizer
                self.max_length = max_length

            def __call__(self, examples):
                texts = [example["text"] for example in examples]
                prompts = [example["prompt"] for example in examples]
                images = [example["image"] for example in examples]
                batch = self.processor(
                    text=texts,
                    images=images,
                    return_tensors="pt",
                    padding=True,
                )
                labels = batch["input_ids"].clone()
                labels[labels == self.tokenizer.pad_token_id] = -100
                # Mask the full prompt so SFT loss is computed only on the answer.
                # Searching for "ASSISTANT:" token ids is brittle because tokenization can
                # split the separator differently across models.
                prompt_batch = self.processor(
                    text=prompts,
                    images=images,
                    return_tensors="pt",
                    padding=True,
                )
                prompt_lengths = prompt_batch["attention_mask"].sum(dim=1)
                for i, prompt_len in enumerate(prompt_lengths.tolist()):
                    token_positions = batch["attention_mask"][i].nonzero(as_tuple=True)[0]
                    labels[i, token_positions[:prompt_len]] = -100
                batch["labels"] = labels
                # Trainer only needs the tensor fields from here on
                return batch

        b2_micro_batch = int(config['train'].get('b2_batch_size', 1))
        b2_grad_accum = int(config['train'].get('b2_gradient_accumulation_steps',
                                                max(config['train'].get('gradient_accumulation_steps', 2), 1)))
        b2_max_length = int(config['train'].get('b2_max_length',
                                                config['data'].get('max_question_len', 64)
                                                + config['data'].get('max_answer_len', 20) + 32))
        training_args = build_training_arguments(
            TrainingArguments,
            output_dir="./checkpoints/B2",
            per_device_train_batch_size=b2_micro_batch,
            per_device_eval_batch_size=int(config['train'].get('b2_eval_batch_size', 1)),
            gradient_accumulation_steps=b2_grad_accum,
            num_train_epochs=config['train'].get('epochs', 3),
            learning_rate=float(config['train'].get('b2_lr', 2.0e-5)),
            lr_scheduler_type="cosine",
            warmup_steps=int(config['train'].get('b2_warmup_steps', 50)),
            bf16=True,
            fp16=False,
            gradient_checkpointing=True,
            remove_unused_columns=False,
            logging_steps=10,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            save_total_limit=2,
            optim=config['train'].get('b2_optim', 'paged_adamw_8bit'),
            max_grad_norm=float(config['train'].get('grad_clip', 1.0)),
            dataloader_num_workers=int(config['train'].get('b2_num_workers', 4)),
            dataloader_pin_memory=bool(config['train'].get('pin_memory', True)),
            load_best_model_at_end=config['train'].get('b2_load_best_model_at_end', True),
            metric_for_best_model=config['train'].get('b2_metric_for_best', 'eval_loss'),
            greater_is_better=False,
        )
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=sft_train,
            eval_dataset=sft_val,
            data_collator=MultimodalDataCollator(processor, max_length=b2_max_length)
        )
        trainer.train()

        # [FIX] Evaluate B2 after training to get Accuracy, F1, BLEU for the comparison charts
        from src.engine.medical_eval import evaluate_multimodal_vqa
        print("[INFO] Running acceptance evaluation on the validation set for B2...")
        # Put the model into evaluation mode
        model.eval()
        metrics = evaluate_multimodal_vqa(
            model, val_loader, device, processor,
            beam_width=config['eval'].get('beam_width_b', 1),
            beam_width_closed=config['eval'].get('beam_width_b_closed', 1),
            beam_width_open=config['eval'].get('beam_width_b_open', config['eval'].get('beam_width_b', 1)),
            max_new_tokens_closed=config['eval'].get('max_new_tokens_b_closed', 4),
            max_new_tokens_open=config['eval'].get('max_new_tokens_b_open', answer_max_words + 6),
            generation_batch_size=config['eval'].get('generation_batch_size_b', 1),
            max_words=answer_max_words,
            variant='B2'
        )
        closed_eval = metrics.get('closed_eval', {})
        open_eval = metrics.get('open_eval', {})
        print("\n[RESULT B2 - CLOSED QUESTIONS]")
        print(f"Count: {closed_eval.get('count', 0)}")
        print(f"Accuracy: {closed_eval.get('accuracy', 0):.4f}")
        print(f"EM: {closed_eval.get('em', 0):.4f}")
        print(f"F1: {closed_eval.get('f1', 0):.4f}")
        print("\n[RESULT B2 - OPEN QUESTIONS]")
        print(f"Count: {open_eval.get('count', 0)}")
        print(f"Semantic: {open_eval.get('semantic', 0):.4f}")
        print(f"BERTScore: {open_eval.get('bert_score', 0):.4f}")
        print(f"F1: {open_eval.get('f1', 0):.4f}")
        print(f"ROUGE-L: {open_eval.get('rouge_l', 0):.4f}")
        if 'long_answers_eval' in metrics:
            print("\n[RESULT B2 - LONG METRICS]")
            print(f"Accuracy: {metrics['long_answers_eval'].get('accuracy', 0):.4f}")
            print(f"F1: {metrics['long_answers_eval'].get('f1', 0):.4f}")
            print(f"Semantic: {metrics['long_answers_eval'].get('semantic', 0):.4f}")
            print(f"BERTScore: {metrics['long_answers_eval'].get('bert_score', 0):.4f}")
            # Append to log_history for wandb
            trainer.state.log_history.append({
                "epoch": training_args.num_train_epochs,
                "val_long_accuracy": metrics['long_answers_eval'].get('accuracy', 0),
                "val_long_f1": metrics['long_answers_eval'].get('f1', 0),
                "val_long_semantic": metrics['long_answers_eval'].get('semantic', 0),
                "val_long_bertscore": metrics['long_answers_eval'].get('bert_score', 0),
            })

        # Append the results to the history so compare_models.py can read them
        final_epoch = training_args.num_train_epochs
        trainer.state.log_history.append({
            "epoch": final_epoch,
            "val_accuracy_normalized": metrics.get('accuracy_normalized'),
            "val_f1_normalized": metrics.get('f1_normalized'),
            "val_bleu4_normalized": metrics.get('bleu4_normalized'),
            "val_bert_score_raw": metrics.get('bert_score_raw'),
            "val_semantic_raw": metrics.get('semantic_raw'),
            "val_closed_accuracy": closed_eval.get('accuracy', 0),
            "val_closed_em": closed_eval.get('em', 0),
            "val_closed_f1": closed_eval.get('f1', 0),
            "val_open_semantic": open_eval.get('semantic', 0),
            "val_open_bertscore": open_eval.get('bert_score', 0),
            "val_open_f1": open_eval.get('f1', 0),
            "val_open_rouge_l": open_eval.get('rouge_l', 0),
        })
        save_history_records(history_dir, trainer.state.log_history)
        return

    elif args.variant == 'B1':
        # Zero-shot evaluation for Direction B
        from src.engine.medical_eval import evaluate_multimodal_vqa
        wrapper = MultimodalVQA(model_id=config['model_b']['model_name'])
        model, processor = wrapper.load_model()
        beam_width = config['eval'].get('beam_width_b', 1)
        print(f"[INFO] Starting B1 evaluation with beam width = {beam_width}...")
        metrics = evaluate_multimodal_vqa(
            model, val_loader, device, processor,
            beam_width=beam_width,
            beam_width_closed=config['eval'].get('beam_width_b_closed', beam_width),
            beam_width_open=config['eval'].get('beam_width_b_open', beam_width),
            max_new_tokens_closed=config['eval'].get('max_new_tokens_b_closed', 4),
            max_new_tokens_open=config['eval'].get('max_new_tokens_b_open', answer_max_words + 6),
            generation_batch_size=config['eval'].get('generation_batch_size_b', 1),
            max_words=answer_max_words,
            variant='B1'
        )
        closed_eval = metrics.get('closed_eval', {})
        open_eval = metrics.get('open_eval', {})
        print("\n[RESULT B1 - CLOSED QUESTIONS]")
        print(f"Count: {closed_eval.get('count', 0)}")
        print(f"Accuracy: {closed_eval.get('accuracy', 0):.4f}")
        print(f"EM: {closed_eval.get('em', 0):.4f}")
        print(f"F1: {closed_eval.get('f1', 0):.4f}")
        print("\n[RESULT B1 - OPEN QUESTIONS]")
        print(f"Count: {open_eval.get('count', 0)}")
        print(f"Semantic: {open_eval.get('semantic', 0):.4f}")
        print(f"BERTScore: {open_eval.get('bert_score', 0):.4f}")
        print(f"F1: {open_eval.get('f1', 0):.4f}")
        print(f"ROUGE-L: {open_eval.get('rouge_l', 0):.4f}")
        if 'long_answers_eval' in metrics:
            print("\n[RESULT B1 - LONG METRICS]")
            print(f"Accuracy: {metrics['long_answers_eval'].get('accuracy', 0):.4f}")
            print(f"F1: {metrics['long_answers_eval'].get('f1', 0):.4f}")
            print(f"Semantic: {metrics['long_answers_eval'].get('semantic', 0):.4f}")
            print(f"BERTScore: {metrics['long_answers_eval'].get('bert_score', 0):.4f}")

        # [FIX] Save as a record with an 'epoch' key so compare_models.py can parse it
        save_history_records(history_dir, [{
            "epoch": 1,
            "variant": "B1",
"B1", "beam_width": beam_width, "train_loss": 0.0, # zero-shot không có train loss "val_accuracy_normalized": float(metrics.get('accuracy_normalized', metrics.get('accuracy', 0))), "val_f1_normalized": float(metrics.get('f1_normalized', metrics.get('f1', 0))), "val_bleu4_normalized": float(metrics.get('bleu4_normalized', metrics.get('bleu4', 0))), "val_bert_score_raw": float(metrics.get('bert_score_raw', metrics.get('bert_score', 0))), "val_semantic_raw": float(metrics.get('semantic_raw', metrics.get('semantic', 0))), "val_closed_accuracy": float(closed_eval.get('accuracy', 0)), "val_closed_em": float(closed_eval.get('em', 0)), "val_closed_f1": float(closed_eval.get('f1', 0)), "val_open_semantic": float(open_eval.get('semantic', 0)), "val_open_bertscore": float(open_eval.get('bert_score', 0)), "val_open_f1": float(open_eval.get('f1', 0)), "val_open_rouge_l": float(open_eval.get('rouge_l', 0)), "metrics": metrics, }]) return if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default="configs/medical_vqa.yaml") parser.add_argument("--variant", type=str, choices=['A1', 'A2', 'B1', 'B2', 'DPO', 'PPO'], required=True) parser.add_argument("--debug", action="store_true") parser.add_argument("--no_compare", action="store_true", help="Bỏ qua vẽ chart so sánh 5 model sau khi train xong") args = parser.parse_args() train(args) # Auto-generate comparison charts after training if not args.no_compare: import subprocess, sys log_dir = "logs/medical_vqa/history" out_dir = "results/charts" print(f"\n[INFO] 📊 Tự động vẽ biểu đồ so sánh 5 model → {out_dir}/") try: subprocess.run( [sys.executable, "scripts/compare_models.py", "--log_dir", log_dir, "--out", out_dir], check=False ) except Exception as e: print(f"[WARNING] compare_models.py thất bại: {e}") print(" Chạy thủ công: python scripts/compare_models.py")