Spaces:
Paused
Paused
| import wandb | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader, random_split | |
| from transformers import AutoTokenizer | |
| import yaml | |
| import argparse | |
| import os | |
| import random | |
# Process-wide environment setup: force TOKENIZERS_PARALLELISM to "false" and
# default PYTORCH_CUDA_ALLOC_CONF to expandable segments (an existing value set
# by the caller is kept because setdefault is used).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# [Bypass CVE-2025-32434] Skip transformers' requirement to upgrade to PyTorch 2.6.
# NOTE(review): both check_torch_load_is_safe hooks are replaced with no-ops,
# so the safety check never fires — presumably only trusted checkpoints are
# loaded by this script; confirm before reusing it elsewhere.
import transformers.utils.import_utils
transformers.utils.import_utils.check_torch_load_is_safe = lambda: None
import transformers.modeling_utils
transformers.modeling_utils.check_torch_load_is_safe = lambda: None
# [Bypass FSDPModule Error] Work around the trl library importing FSDPModule on
# older PyTorch: when the attribute is missing, alias it to
# FullyShardedDataParallel so the import succeeds.
import torch.distributed.fsdp as fsdp
if not hasattr(fsdp, "FSDPModule"):
    fsdp.FSDPModule = fsdp.FullyShardedDataParallel
| import csv | |
| import json | |
| from datetime import datetime | |
| from pathlib import Path | |
| from PIL import Image | |
| from datasets import load_dataset | |
| # Import các thành phần từ thư mục src | |
| from src.models.medical_vqa_model import MedicalVQAModelA | |
| from src.models.multimodal_vqa import MultimodalVQA | |
| from src.utils.visualization import MedicalImageTransform as MedicalTransform | |
| from src.data.medical_dataset import MedicalVQADataset | |
| from src.utils.text_utils import get_target_answer, normalize_answer, postprocess_answer | |
def build_training_arguments(training_arguments_cls, **kwargs):
    """Create TrainingArguments across transformers versions.

    The evaluation-strategy keyword is spelled ``evaluation_strategy`` in some
    transformers releases and ``eval_strategy`` in others. When the caller
    supplies exactly one of the two spellings, the constructor is first tried
    with ``eval_strategy`` (or with the caller's kwargs as given, when they
    already use that spelling); if it rejects that keyword, the call is
    retried with the other spelling.

    Args:
        training_arguments_cls: Callable (normally a ``TrainingArguments``
            class) invoked with the keyword arguments.
        **kwargs: Keyword arguments forwarded to ``training_arguments_cls``.

    Returns:
        Whatever ``training_arguments_cls`` returns.

    Raises:
        TypeError: Re-raised unchanged when the constructor fails for a reason
            that does not mention ``eval_strategy``, or when the retry with
            the alternate spelling fails as well.
    """
    old_name, new_name = "evaluation_strategy", "eval_strategy"
    has_old = old_name in kwargs
    has_new = new_name in kwargs

    # Neither spelling (nothing to translate) or both (caller was explicit):
    # forward the arguments untouched.
    if has_old == has_new:
        return training_arguments_cls(**kwargs)

    renamed = dict(kwargs)
    if has_old:
        # Prefer the newer spelling, keep the caller's kwargs as the fallback.
        renamed[new_name] = renamed.pop(old_name)
        first_try, fallback = renamed, kwargs
    else:
        # Caller already uses the newer spelling; fall back to the older one.
        renamed[old_name] = renamed.pop(new_name)
        first_try, fallback = kwargs, renamed

    try:
        return training_arguments_cls(**first_try)
    except TypeError as exc:
        # Only a complaint about the `eval_strategy` keyword justifies a retry;
        # any other TypeError is a genuine error and must surface.
        if new_name not in str(exc):
            raise
    return training_arguments_cls(**fallback)
def vqa_collate_fn(batch):
    """Collate dataset items into one batch dict.

    Fields named in ``stackable`` are combined with ``torch.stack``; every
    other field is returned as a plain Python list with one entry per item
    (used for values such as PIL images and raw text). The set of fields is
    taken from the first item of the batch.

    Args:
        batch: Non-empty sequence of per-sample dicts sharing the same keys.

    Returns:
        Dict mapping each field name to a stacked tensor or a list.
    """
    # NOTE(review): 'image' is stacked, so it is presumably already a tensor
    # at this point — confirm against the dataset implementation.
    stackable = (
        'image', 'input_ids', 'attention_mask', 'label_closed',
        'target_ids', 'chosen_ids', 'rejected_ids',
    )
    first_item = batch[0]
    collated = {}
    for name in first_item.keys():
        column = [sample[name] for sample in batch]
        collated[name] = torch.stack(column) if name in stackable else column
    return collated
def flatten_dict(data, parent_key="", sep="."):
    """Flatten a nested dict into a single-level dict with joined keys.

    Nested dict keys are joined with ``sep``; values that are lists or tuples
    are dropped; every other value is kept as-is. A key is prefixed only when
    the prefix accumulated so far is non-empty, otherwise ``str(key)`` is used.

    Args:
        data: Mapping to flatten.
        parent_key: Prefix prepended to every produced key.
        sep: Separator placed between key segments.

    Returns:
        A new flat dict, in traversal order.
    """
    def _walk(mapping, prefix):
        # Depth-first traversal yielding (flat_key, scalar_value) pairs.
        for raw_key, raw_value in mapping.items():
            path = f"{prefix}{sep}{raw_key}" if prefix else str(raw_key)
            if isinstance(raw_value, dict):
                yield from _walk(raw_value, path)
            elif not isinstance(raw_value, (list, tuple)):
                yield path, raw_value

    return dict(_walk(data, parent_key))
def create_history_dir(base_log_dir, variant):
    """Create and return a timestamped history directory for one run.

    The directory is ``<base_log_dir>/history/<variant>/<YYYYmmdd_HHMMSS>``,
    stamped with the current local time. Creating it is idempotent.

    Args:
        base_log_dir: Root logging directory.
        variant: Name of the model variant being trained.

    Returns:
        Path of the created directory, as a string.
    """
    run_stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    target = os.path.join(base_log_dir, "history", variant, run_stamp)
    os.makedirs(target, exist_ok=True)
    return target
def save_history_records(history_dir, records):
    """Write training history to ``history.json`` and ``history.csv``.

    The JSON file receives ``records`` verbatim. The CSV file receives one row
    per record after flattening with ``flatten_dict``; its header is the
    sorted union of all flattened keys. The CSV is not written when
    ``records`` is empty.

    Args:
        history_dir: Destination directory (created when missing).
        records: List of per-step/per-epoch metric dicts.
    """
    os.makedirs(history_dir, exist_ok=True)

    with open(os.path.join(history_dir, "history.json"), "w", encoding="utf-8") as json_file:
        json.dump(records, json_file, ensure_ascii=False, indent=2)

    rows = list(map(flatten_dict, records))
    if not rows:
        return

    # Records may carry different metrics, so the header is the union of keys.
    columns = sorted(set().union(*rows))
    with open(os.path.join(history_dir, "history.csv"), "w", encoding="utf-8", newline="") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=columns)
        writer.writeheader()
        writer.writerows(rows)
def select_best_adapter_checkpoint(checkpoint_root: str):
    """Pick the adapter checkpoint to resume from under ``checkpoint_root``.

    A ``checkpoint-*`` directory counts as valid when it contains both
    ``adapter_config.json`` and a readable, non-empty
    ``adapter_model.safetensors``. The ``trainer_state.json`` files are read
    from the highest step downwards; the first one whose
    ``best_model_checkpoint`` points at a valid checkpoint wins. Otherwise the
    valid checkpoint with the highest step number is returned.

    Checkpoints are ordered by the numeric step in their name, so
    ``checkpoint-1000`` ranks above ``checkpoint-200``. Names without a
    numeric suffix rank below all numbered ones.

    Args:
        checkpoint_root: Directory that holds the ``checkpoint-*`` folders.

    Returns:
        Resolved ``Path`` of the selected checkpoint directory.

    Raises:
        FileNotFoundError: If ``checkpoint_root`` does not exist or contains
            no valid adapter checkpoint.
    """
    checkpoint_root = Path(checkpoint_root)
    if not checkpoint_root.exists():
        raise FileNotFoundError(f"Không tìm thấy thư mục checkpoint: {checkpoint_root}")

    def _is_valid_adapter_checkpoint(path: Path) -> bool:
        # Both the adapter config and its weights must be present.
        adapter_cfg = path / "adapter_config.json"
        adapter_weights = path / "adapter_model.safetensors"
        if not adapter_cfg.exists() or not adapter_weights.exists():
            return False
        try:
            from safetensors import safe_open
            with safe_open(str(adapter_weights), framework="pt", device="cpu") as f:
                return len(f.keys()) > 0
        except Exception as exc:
            # Deliberately broad: any unreadable checkpoint is reported and
            # skipped instead of aborting the selection.
            print(f"[WARN] Bỏ qua checkpoint lỗi {path}: {exc}")
            return False

    def _step_key(path: Path):
        # Sort key: numeric step first, name as a deterministic tie-breaker.
        suffix = path.name.rpartition("-")[2]
        step = int(suffix) if suffix.isdecimal() else -1
        return (step, path.name)

    checkpoint_dirs = sorted(
        (p for p in checkpoint_root.glob("checkpoint-*") if _is_valid_adapter_checkpoint(p)),
        key=_step_key,
    )
    if not checkpoint_dirs:
        raise FileNotFoundError(f"Không có adapter checkpoint hợp lệ trong {checkpoint_root}")

    state_files = sorted(
        checkpoint_root.glob("checkpoint-*/trainer_state.json"),
        key=lambda state_path: _step_key(state_path.parent),
        reverse=True,
    )
    for state_file in state_files:
        try:
            state = json.loads(state_file.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Unreadable or malformed state file: try the next one.
            continue
        best_path = state.get("best_model_checkpoint") if isinstance(state, dict) else None
        if not isinstance(best_path, str) or not best_path:
            continue
        # Path() already drops a leading "./" without mangling "../" segments.
        best_dir = Path(best_path)
        if not best_dir.is_absolute():
            best_dir = Path.cwd() / best_dir
        if _is_valid_adapter_checkpoint(best_dir):
            return best_dir.resolve()

    return checkpoint_dirs[-1].resolve()
def build_dpo_instruction_prompt(question: str, max_words: int = 10) -> str:
    """Build the ``USER: ... ASSISTANT:`` prompt sent to the multimodal model.

    The prompt consists of the ``<image>`` placeholder, the stripped question
    (falsy input becomes an empty string) and a fixed instruction block
    written in unaccented Vietnamese that asks for a short, direct answer of
    at most ``max_words`` words.

    Args:
        question: Question text; ``None`` or other falsy values are allowed.
        max_words: Word limit quoted inside the instruction.

    Returns:
        The fully formatted prompt string.
    """
    cleaned_question = str(question or "").strip()
    rules = "".join((
        "Chi tra loi bang tieng Viet. ",
        "Khong dung tieng Anh. ",
        "Khong lap lai cau hoi. ",
        "Khong mo ta hinh anh chung chung. ",
        f"Chi tra loi truc tiep dap an, toi da {max_words} tu.",
    ))
    return f"USER: <image>\n{cleaned_question}\n{rules} ASSISTANT:"
| def load_latest_variant_metrics(history_root: str, variant: str) -> dict | None: | |
| variant_dir = Path(history_root) / variant | |
| if not variant_dir.exists(): | |
| return None | |
| history_files = sorted(variant_dir.glob("*/history.json")) | |
| if not history_files: | |
| return None | |
| for history_file in reversed(history_files): | |
| try: | |
| records = json.loads(history_file.read_text(encoding="utf-8")) | |
| except (OSError, json.JSONDecodeError): | |
| continue | |
| if records: | |
| return records[-1] | |
| return None | |
| def evaluate_dpo_acceptance(b2_metrics: dict | None, dpo_metrics: dict) -> dict: | |
| if not b2_metrics: | |
| return { | |
| "status": "unknown", | |
| "reason": "missing_b2_metrics", | |
| "summary": "Khong tim thay metrics B2 de doi chieu.", | |
| } | |
| def pct_delta(key: str) -> float | None: | |
| b2_val = b2_metrics.get(key) | |
| dpo_val = dpo_metrics.get(key) | |
| if b2_val is None or dpo_val is None: | |
| return None | |
| return (dpo_val - b2_val) * 100.0 | |
| deltas = { | |
| "accuracy": pct_delta("val_accuracy_normalized"), | |
| "f1": pct_delta("val_f1_normalized"), | |
| "bleu4": pct_delta("val_bleu4_normalized"), | |
| "closed_acc": pct_delta("val_closed_accuracy"), | |
| "open_semantic": pct_delta("val_open_semantic"), | |
| "open_bert": pct_delta("val_open_bertscore"), | |
| } | |
| failed_drop = any( | |
| delta is not None and delta < -1.0 | |
| for delta in (deltas["accuracy"], deltas["f1"], deltas["bleu4"]) | |
| ) | |
| closed_ok = ( | |
| b2_metrics.get("val_closed_accuracy") is not None | |
| and dpo_metrics.get("val_closed_accuracy") is not None | |
| and dpo_metrics["val_closed_accuracy"] >= b2_metrics["val_closed_accuracy"] | |
| ) | |
| open_ok = ( | |
| b2_metrics.get("val_open_semantic") is not None | |
| and dpo_metrics.get("val_open_semantic") is not None | |
| and b2_metrics.get("val_open_bertscore") is not None | |
| and dpo_metrics.get("val_open_bertscore") is not None | |
| and dpo_metrics["val_open_semantic"] >= b2_metrics["val_open_semantic"] | |
| and (dpo_metrics["val_open_bertscore"] - b2_metrics["val_open_bertscore"]) * 100.0 >= -0.3 | |
| ) | |
| accepted = (not failed_drop) and (closed_ok or open_ok) | |
| def _fmt(delta: float | None) -> str: | |
| return "N/A" if delta is None else f"{delta:.2f}" | |
| summary = ( | |
| f"DPO vs B2 deltas (pp): Acc={_fmt(deltas['accuracy'])} | F1={_fmt(deltas['f1'])} | " | |
| f"BLEU={_fmt(deltas['bleu4'])} | Closed={_fmt(deltas['closed_acc'])} | " | |
| f"OpenSem={_fmt(deltas['open_semantic'])} | OpenBERT={_fmt(deltas['open_bert'])}" | |
| ) | |
| return { | |
| "status": "accepted" if accepted else "failed", | |
| "reason": "criteria_met" if accepted else "metric_drop_or_no_gain", | |
| "summary": summary, | |
| "deltas_pp": deltas, | |
| "closed_ok": closed_ok, | |
| "open_ok": open_ok, | |
| } | |
def evaluate_refinement_acceptance(base_metrics: dict | None, rl_metrics: dict) -> dict:
    """Judge a refinement run against its baseline.

    Thin alias of ``evaluate_dpo_acceptance``: the same acceptance rules are
    applied, with ``base_metrics`` as the baseline and ``rl_metrics`` as the
    run under evaluation.

    Args:
        base_metrics: Baseline metrics, or ``None`` when unavailable.
        rl_metrics: Metrics of the refined model.

    Returns:
        The verdict dict produced by ``evaluate_dpo_acceptance``.
    """
    verdict = evaluate_dpo_acceptance(b2_metrics=base_metrics, dpo_metrics=rl_metrics)
    return verdict
def sanitize_dpo_completion(question: str, answer: str, max_words: int = 10) -> str:
    """Normalise a generated answer, collapsing yes/no answers to "có"/"không".

    The answer is first passed through ``postprocess_answer``. A bare
    yes/no answer is mapped straight to "có" / "không". For questions that
    look closed (they mention normal/abnormal, end with " không", or contain
    the word "có"), an answer that contains a negative cue becomes "không"
    and one that contains a positive cue becomes "có"; negative cues are
    checked first. Anything else is returned as post-processed.

    Cues are matched as whole words/phrases. Plain substring matching would
    find "no" inside "normal" or "pneumonia" and wrongly flip such answers
    to "không".

    Args:
        question: Question text, used only to detect closed questions.
        answer: Raw generated answer.
        max_words: Word limit forwarded to ``postprocess_answer``.

    Returns:
        "có", "không", or the post-processed answer.
    """
    import re  # local import: `re` is not among the module-level imports seen here

    question_norm = normalize_answer(question)
    answer_norm = postprocess_answer(answer, max_words=max_words)
    if answer_norm in {"yes", "có"}:
        return "có"
    if answer_norm in {"no", "không"}:
        return "không"

    def _mentions_any(text: str, phrases) -> bool:
        # True when any phrase occurs in `text` delimited by non-word characters.
        return any(
            re.search(rf"(?<!\w){re.escape(phrase)}(?!\w)", text) is not None
            for phrase in phrases
        )

    is_closed = (
        any(
            pattern in question_norm
            for pattern in ["bình thường", "bat thuong", "normal", "abnormal"]
        )
        or question_norm.endswith(" không")
        or " có " in f" {question_norm} "
    )
    if is_closed:
        # Negative cues take precedence so "không bình thường" / "not normal"
        # resolve to "không" even though they also contain a positive cue.
        if _mentions_any(answer_norm, ["không", "no", "not normal", "abnormal"]):
            return "không"
        if _mentions_any(answer_norm, ["có", "yes", "bình thường", "normal", "present", "detected"]):
            return "có"
    return answer_norm
| def resolve_dpo_image(item: dict, hf_train_data=None, image_dir: str | None = None): | |
| source_idx = item.get("source_idx") | |
| if hf_train_data is not None and source_idx is not None and 0 <= int(source_idx) < len(hf_train_data): | |
| img = hf_train_data[int(source_idx)].get("image") | |
| if img is not None and getattr(img, "mode", None) != "RGB": | |
| img = img.convert("RGB") | |
| return img | |
| image_name = item.get("image") | |
| if image_name and image_dir: | |
| img_path = os.path.join(image_dir, image_name) | |
| if os.path.exists(img_path): | |
| return Image.open(img_path).convert("RGB") | |
| return None | |
def infer_closed_answer_type(item: dict, answer: str | None = None) -> bool:
    """Tell whether a record should be treated as a closed (yes/no) question.

    A record is closed when its ``answer_type`` is "CLOSED" (case- and
    whitespace-insensitive), when its ``label_closed`` equals 0 or 1, or when
    the normalised answer is one of "có", "không", "yes", "no".

    Args:
        item: Raw dataset record.
        answer: Answer to inspect; when ``None`` the target answer is taken
            from ``item`` via ``get_target_answer``.

    Returns:
        ``True`` for closed questions, ``False`` otherwise.
    """
    raw_answer = get_target_answer(item) if answer is None else answer
    answer_norm = normalize_answer(raw_answer)

    declared_type = str(item.get("answer_type", "")).strip().upper()
    if declared_type == "CLOSED":
        return True
    if item.get("label_closed", None) in (0, 1):
        return True
    return answer_norm in {"có", "không", "yes", "no"}
def move_model_batch_to_device(batch: dict, device: torch.device) -> dict:
    """Return a new dict with every movable value of ``batch`` on ``device``.

    Values exposing a ``.to`` method are replaced by ``value.to(device)``;
    all other values are carried over unchanged. ``batch`` itself is not
    modified.

    Args:
        batch: Mapping of input names to tensors or arbitrary values.
        device: Target device.

    Returns:
        A plain ``dict`` with the same keys as ``batch``.
    """
    return {
        name: (payload.to(device) if hasattr(payload, "to") else payload)
        for name, payload in batch.items()
    }
def build_multimodal_completion_batch(processor, prompts, completions, images, max_length=None):
    """Encode prompt+completion pairs and mark which tokens are completion.

    Each prompt is concatenated with its completion and the result is encoded
    by ``processor`` together with ``images``. The prompts alone are encoded a
    second time to count prompt tokens per row. In every row the first
    ``prompt_len`` attended positions are treated as prompt and the remaining
    attended positions are flagged with 1 in the returned mask.

    When ``max_length`` is given and the encoded width exceeds it,
    ``input_ids``, ``attention_mask``, the mask and (when present)
    ``token_type_ids`` / ``mm_token_type_ids`` are cut to the first
    ``max_length`` columns.

    Args:
        processor: Callable accepting ``text``, ``images``, ``return_tensors``,
            ``padding`` and ``truncation`` and returning a mapping with
            ``input_ids`` and ``attention_mask`` tensors.
        prompts: Prompt strings.
        completions: Completion strings, aligned with ``prompts``.
        images: Images forwarded to ``processor``.
        max_length: Optional cap on the sequence width.

    Returns:
        Tuple ``(batch, completion_mask)``; the mask is a long tensor shaped
        like ``batch["input_ids"]``.
    """
    # NOTE(review): assumes the prompt tokenises to the same number of tokens
    # on its own as it does as a prefix of prompt+completion — confirm for the
    # tokenizer in use.
    def _encode(texts):
        return processor(
            text=texts,
            images=images,
            return_tensors="pt",
            padding=True,
            truncation=False,
        )

    batch = _encode([f"{prompt}{completion}" for prompt, completion in zip(prompts, completions)])
    prompt_only = _encode(prompts)

    completion_mask = torch.zeros_like(batch["input_ids"], dtype=torch.long)
    prompt_token_counts = prompt_only["attention_mask"].sum(dim=1).tolist()
    for row, n_prompt in enumerate(prompt_token_counts):
        # Indices of real (non-padding) tokens, so padding on either side works.
        real_positions = torch.nonzero(batch["attention_mask"][row], as_tuple=True)[0]
        completion_mask[row, real_positions[n_prompt:]] = 1

    too_long = max_length is not None and batch["input_ids"].shape[1] > max_length
    if too_long:
        batch["input_ids"] = batch["input_ids"][:, :max_length]
        batch["attention_mask"] = batch["attention_mask"][:, :max_length]
        completion_mask = completion_mask[:, :max_length]
        for optional_key in ("token_type_ids", "mm_token_type_ids"):
            if optional_key in batch:
                batch[optional_key] = batch[optional_key][:, :max_length]
    return batch, completion_mask
def compute_masked_sequence_logprobs(model, batch, completion_mask):
    """Score the masked tokens of each sequence under ``model``.

    The batch is moved to the device of the model's first parameter and run
    through the model. Logits at position ``t`` are compared with the token at
    position ``t + 1``. For each sequence, the log-probabilities of the tokens
    flagged in ``completion_mask`` are averaged, as is the entropy of the
    predictive distribution at those positions. The divisor is clamped to at
    least 1 so rows without flagged tokens do not divide by zero.

    Args:
        model: Callable model whose output exposes ``.logits``.
        batch: Mapping of model inputs containing ``input_ids``.
        completion_mask: Tensor shaped like ``input_ids`` with 1 on the tokens
            to score.

    Returns:
        Tuple ``(seq_log_probs, seq_entropy)`` with one value per sequence.
    """
    target_device = next(model.parameters()).device
    model_inputs = move_model_batch_to_device(batch, target_device)
    input_ids = model_inputs["input_ids"]
    completion_mask = completion_mask.to(input_ids.device)

    # Shift by one: the last position has no next token to predict and the
    # first token is never predicted.
    next_token_logits = model(**model_inputs).logits[:, :-1, :]
    next_tokens = input_ids[:, 1:]
    weights = completion_mask[:, 1:].float()
    scored_count = weights.sum(dim=1).clamp_min(1.0)

    log_probs = F.log_softmax(next_token_logits, dim=-1)
    chosen_log_probs = torch.gather(log_probs, -1, next_tokens.unsqueeze(-1)).squeeze(-1)
    seq_log_probs = (chosen_log_probs * weights).sum(dim=1) / scored_count

    token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
    seq_entropy = (token_entropy * weights).sum(dim=1) / scored_count
    return seq_log_probs, seq_entropy
def compute_single_open_reward(pred: str, ref: str) -> tuple[float, dict]:
    """Compute the reward of one open-ended prediction against its reference.

    Both strings are normalised (an empty result is replaced by "."). Exact
    match, token F1 and ROUGE-L come from ``src.utils.metrics``; the BERT
    score uses that module's ``bert_scorer`` attribute when it is set and
    falls back to 0.0 when it is missing or scoring fails. The scores are
    blended as ``0.55*bert + 0.30*f1 + 0.10*rouge_l + 0.05*exact`` and the
    reward is ``2 * blended - 1``.

    Args:
        pred: Predicted answer.
        ref: Reference answer.

    Returns:
        Tuple ``(reward, details)`` where ``details`` holds the individual
        scores and the blended value.
    """
    from src.utils.metrics import compute_exact_match, compute_f1, compute_rouge_l
    from src.utils import metrics as metrics_module

    norm_pred = normalize_answer(pred) or "."
    norm_ref = normalize_answer(ref) or "."

    exact = compute_exact_match(norm_pred, norm_ref)
    f1 = compute_f1(norm_pred, norm_ref)
    rouge_l = compute_rouge_l(norm_pred, norm_ref)

    bert = 0.0
    scorer = getattr(metrics_module, "bert_scorer", None)
    if scorer is not None:
        try:
            _precision, _recall, bert_f1 = scorer.score([norm_pred], [norm_ref])
            bert = float(bert_f1.mean().item())
        except Exception:
            # Best effort: a scorer failure must not abort reward computation.
            bert = 0.0

    blended = (0.55 * bert) + (0.30 * f1) + (0.10 * rouge_l) + (0.05 * exact)
    details = {
        "bert": bert,
        "f1": f1,
        "rouge_l": rouge_l,
        "exact": exact,
        "blended": blended,
    }
    # Map the blended score linearly: 0 -> -1, 1 -> +1.
    return (2.0 * blended) - 1.0, details
| def train(args): | |
| # 1. Load Cấu hình | |
| with open(args.config, 'r', encoding='utf-8') as f: | |
| config = yaml.safe_load(f) | |
| # ── WandB Setup ────────────────────────────────────────────────────────── | |
| _wandb_cfg = config.get("wandb", {}) | |
| _use_wandb = bool(os.environ.get("WANDB_API_KEY") or os.environ.get("WANDB_MODE")) | |
| if _use_wandb: | |
| _api_key = os.environ.get("WANDB_API_KEY") | |
| if _api_key: | |
| wandb.login(key=_api_key) | |
| # Offline mode: set WANDB_MODE=offline hoặc config wandb.offline: true | |
| _offline = _wandb_cfg.get("offline", False) or \ | |
| os.environ.get("WANDB_MODE", "").lower() == "offline" | |
| if _offline: | |
| os.environ["WANDB_MODE"] = "offline" | |
| print("[INFO] WandB chạy ở chế độ OFFLINE (sync sau bằng: wandb sync)") | |
| # Tags theo variant từ YAML | |
| _tags = _wandb_cfg.get("tags", {}).get(args.variant, []) | |
| # Rich config ghi đầy đủ thông tin experiment | |
| _run_config = { | |
| # ── Model architecture ── | |
| "variant": args.variant, | |
| "decoder_type": config["model_a"].get("decoder_type"), | |
| "image_encoder": config["model_a"].get("image_encoder"), | |
| "text_encoder": config["model_a"].get("text_encoder"), | |
| "hidden_size": config["model_a"].get("hidden_size"), | |
| "transformer_heads": config["model_a"].get("transformer_heads"), | |
| "transformer_ff_dim": config["model_a"].get("transformer_ff_dim"), | |
| "transformer_layers": config["model_a"].get("transformer_decoder_layers"), | |
| "norm_first": config["model_a"].get("transformer_norm_first"), | |
| "freeze_phobert_layers": config["model_a"].get("freeze_phobert_layers"), | |
| # ── Training ── | |
| "learning_rate": config["train"].get("learning_rate"), | |
| "phobert_lr": config["train"].get("phobert_lr"), | |
| "vision_lr": config["train"].get("vision_lr"), | |
| "batch_size": config["train"].get("batch_size"), | |
| "grad_accum_steps": config["train"].get("gradient_accumulation_steps"), | |
| "effective_batch": config["train"].get("batch_size", 32) * | |
| config["train"].get("gradient_accumulation_steps", 1), | |
| "label_smoothing": config["train"].get("label_smoothing"), | |
| "open_loss_weight": config["train"].get("open_loss_weight"), | |
| "warmup_epochs": config["train"].get("warmup_epochs"), | |
| "scheduler": config["train"].get("scheduler"), | |
| "patience": config["train"].get("patience"), | |
| "use_amp": config["train"].get("use_amp"), | |
| # ── Data ── | |
| "dataset": config["data"].get("dataset_name"), | |
| "max_question_len": config["data"].get("max_question_len"), | |
| "max_answer_len": config["data"].get("max_answer_len"), | |
| # ── Eval ── | |
| "beam_width": config["eval"].get("beam_width_a") if args.variant in ("A1", "A2") | |
| else config["eval"].get("beam_width_b"), | |
| } | |
| # Thêm hardware info | |
| if torch.cuda.is_available(): | |
| _run_config["gpu_name"] = torch.cuda.get_device_name(0) | |
| _run_config["gpu_count"] = torch.cuda.device_count() | |
| _run_config["vram_gb"] = round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1) | |
| _entity = _wandb_cfg.get("entity") or None # None = WandB dùng default entity | |
| wandb.init( | |
| project=_wandb_cfg.get("project", "MedicalVQA-Vietnam"), | |
| entity=_entity, | |
| name=f"{args.variant}-{datetime.now().strftime('%m%d-%H%M')}", | |
| group=_wandb_cfg.get("group", "DL-Final"), | |
| job_type=_wandb_cfg.get("job_type", "train"), | |
| tags=_tags, | |
| notes=_wandb_cfg.get("notes", ""), | |
| config=_run_config, | |
| save_code=_wandb_cfg.get("save_code", True), | |
| reinit="finish_previous", # Kết thúc run trước nếu chạy nhiều variant liên tiếp | |
| ) | |
| print(f"[INFO] ✅ WandB run: {wandb.run.url}") | |
| # Watch model gradients nếu được bật | |
| if _wandb_cfg.get("watch_model", False): | |
| # model chưa khởi tạo ở đây — hook sẽ được gọi sau khi model được tạo | |
| os.environ["_WANDB_WATCH_PENDING"] = "1" | |
| else: | |
| print("[INFO] WandB không được cấu hình (thiếu WANDB_API_KEY) — bỏ qua logging.") | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print(f"[INFO] Thiết bị sử dụng: {device}") | |
| history_dir = create_history_dir(config.get("log_dir", "logs/medical_vqa"), args.variant) | |
| print(f"[INFO] Lưu training history tại: {history_dir}") | |
| # 2. Tokenizer & Dataset | |
| tokenizer = AutoTokenizer.from_pretrained(config['model_a']['phobert_model']) | |
| if tokenizer.pad_token_id is None: | |
| tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token | |
| transform = MedicalTransform(size=config['data']['image_size']) | |
| answer_max_words = int(config['data'].get('answer_max_words', 10)) | |
| # Nạp dữ liệu từ HuggingFace Hub hoặc cục bộ | |
| hf_repo = config['data'].get('hf_dataset') | |
| use_hf_splits = bool(config['data'].get('use_hf_splits', True)) | |
| if hf_repo and use_hf_splits: | |
| print(f"[INFO] Đang tải dữ liệu từ Hub: {hf_repo}") | |
| dataset_dict = load_dataset(hf_repo) | |
| if args.debug: | |
| print("[WARNING] DEBUG MODE: Chỉ lấy 20 mẫu để chạy thử.") | |
| dataset_dict['train'] = dataset_dict['train'].select(range(min(20, len(dataset_dict['train'])))) | |
| config['train']['epochs'] = 2 | |
| config['train']['batch_size'] = 2 | |
| train_ds = MedicalVQADataset( | |
| hf_dataset=dataset_dict['train'], | |
| tokenizer=tokenizer, | |
| transform=transform, | |
| max_seq_len=config['data']['max_question_len'], | |
| max_ans_len=config['data']['max_answer_len'], | |
| answer_max_words=answer_max_words | |
| ) | |
| val_ds = MedicalVQADataset( | |
| hf_dataset=dataset_dict['validation'], | |
| tokenizer=tokenizer, | |
| transform=transform, | |
| max_seq_len=config['data']['max_question_len'], | |
| max_ans_len=config['data']['max_answer_len'], | |
| answer_max_words=answer_max_words | |
| ) | |
| else: | |
| vqa_path = config['data']['vqa_json'] | |
| print(f"[INFO] Đang tải dữ liệu cục bộ từ: {vqa_path}") | |
| full_dataset = MedicalVQADataset( | |
| json_path=vqa_path, | |
| image_dir=config['data']['image_dir'], | |
| tokenizer=tokenizer, | |
| transform=transform, | |
| max_seq_len=config['data']['max_question_len'], | |
| max_ans_len=config['data']['max_answer_len'], | |
| answer_max_words=answer_max_words | |
| ) | |
| train_size = int(0.8 * len(full_dataset)) | |
| val_size = len(full_dataset) - train_size | |
| train_ds, val_ds = random_split(full_dataset, [train_size, val_size]) | |
| train_loader = DataLoader( | |
| train_ds, | |
| batch_size=config['train']['batch_size'], | |
| shuffle=True, | |
| collate_fn=vqa_collate_fn, | |
| num_workers=config['train'].get('num_workers', 0), | |
| pin_memory=config['train'].get('pin_memory', False) | |
| ) | |
| val_loader = DataLoader( | |
| val_ds, | |
| batch_size=config['train']['eval_batch_size'] if 'eval_batch_size' in config['train'] else 8, | |
| collate_fn=vqa_collate_fn | |
| ) | |
| # 3. Khởi tạo Mô hình dựa trên Variant | |
| if args.variant in ['A1', 'A2']: | |
| decoder_type = "lstm" if args.variant == 'A1' else "transformer" | |
| model = MedicalVQAModelA( | |
| decoder_type=decoder_type, | |
| vocab_size=len(tokenizer), | |
| hidden_size=config['model_a'].get('hidden_size', 768), | |
| phobert_model=config['model_a'].get('phobert_model', "vinai/phobert-base") | |
| ).to(device) | |
| # Log model param count lên WandB | |
| if wandb.run: | |
| total_params = sum(p.numel() for p in model.parameters()) | |
| trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
| wandb.config.update({ | |
| "total_params_M": round(total_params / 1e6, 2), | |
| "trainable_params_M": round(trainable_params / 1e6, 2), | |
| }) | |
| print(f"[INFO] Tổng params: {total_params/1e6:.1f}M | Trainable: {trainable_params/1e6:.1f}M") | |
| # wandb.watch: chỉ bật nếu log_gradients: true | |
| if _wandb_cfg.get("log_gradients", False): | |
| wandb.watch(model, log="gradients", | |
| log_freq=_wandb_cfg.get("log_freq", 50)) | |
| # Thiết lập Optimizer với Differential Learning Rate | |
| optimizer = optim.AdamW([ | |
| {'params': model.image_encoder.parameters(), 'lr': float(config['train']['vision_lr'])}, | |
| {'params': model.text_encoder.parameters(), 'lr': float(config['train']['phobert_lr'])}, | |
| {'params': model.fusion.parameters(), 'lr': float(config['train']['learning_rate'])}, | |
| {'params': model.decoder.parameters(), 'lr': float(config['train']['learning_rate'])} | |
| ]) | |
| # [CRITICAL FIX] Dùng Cosine Schedule với Warmup, step theo batch thay vì epoch | |
| from transformers import get_cosine_schedule_with_warmup | |
| # Use a_epochs for Direction A models (A1, A2), otherwise use default epochs | |
| if args.variant in ['A1', 'A2']: | |
| epochs = config['train'].get('a_epochs', config['train']['epochs']) | |
| else: | |
| epochs = config['train']['epochs'] | |
| warmup_epochs = config['train'].get('warmup_epochs', 5) | |
| accumulation_steps = config['train'].get('gradient_accumulation_steps', 2) | |
| total_steps = epochs * len(train_loader) // max(accumulation_steps, 1) | |
| warmup_steps = warmup_epochs * len(train_loader) // max(accumulation_steps, 1) | |
| scheduler = get_cosine_schedule_with_warmup( | |
| optimizer, | |
| num_warmup_steps=warmup_steps, | |
| num_training_steps=total_steps | |
| ) | |
| # Khởi tạo Trainer với pad_token_id và beam_width từ config | |
| beam_width = config['eval'].get('beam_width_a', 5) | |
| from src.engine.trainer import MedicalVQATrainer | |
| trainer = MedicalVQATrainer( | |
| model=model, | |
| train_loader=train_loader, | |
| val_loader=val_loader, | |
| optimizer=optimizer, | |
| scheduler=scheduler, | |
| device=device, | |
| config={ | |
| **config, | |
| 'variant': args.variant, | |
| 'history_dir': history_dir, | |
| # Pass tunable open-loss weight so trainer doesn't use hardcoded value | |
| 'open_loss_weight': config['train'].get('open_loss_weight', 2.0), | |
| }, | |
| pad_token_id=tokenizer.pad_token_id, | |
| beam_width=beam_width | |
| ) | |
| print(f"[INFO] Beam Width cho Hướng A: {beam_width}") | |
| print(f"[INFO] Bắt đầu huấn luyện cấu hình {args.variant} ({epochs} epochs)...") | |
| trainer.train(epochs, tokenizer=tokenizer) | |
| if wandb.run: | |
| wandb.finish() | |
| return | |
| elif args.variant == 'PPO': | |
| from src.engine.medical_eval import evaluate_multimodal_vqa | |
| ppo_cfg = config.get('ppo', {}) | |
| ppo_answer_max_words = int(ppo_cfg.get('max_answer_words', min(answer_max_words, 6))) | |
| wrapper = MultimodalVQA( | |
| model_id=config['model_b']['model_name'], | |
| lora_r=int(config['model_b'].get('lora_r', 16)), | |
| lora_alpha=int(config['model_b'].get('lora_alpha', 32)), | |
| lora_dropout=float(config['model_b'].get('lora_dropout', 0.05)), | |
| lora_target_modules=config['model_b'].get('lora_target_modules'), | |
| ) | |
| b2_checkpoint = select_best_adapter_checkpoint(config['train'].get('b2_output_dir', './checkpoints/B2')) | |
| print(f"[INFO] PPO sẽ khởi tạo từ B2 checkpoint: {b2_checkpoint}") | |
| model, processor = wrapper.load_model(adapter_path=str(b2_checkpoint), is_trainable=True) | |
| if not ppo_cfg.get('train_mlp_lora', False): | |
| frozen_lora = 0 | |
| for name, param in model.named_parameters(): | |
| if "lora_" in name and any(proj in name for proj in ("gate_proj", "up_proj", "down_proj")): | |
| param.requires_grad = False | |
| frozen_lora += param.numel() | |
| print(f"[INFO] PPO đang freeze LoRA MLP để giảm VRAM: {frozen_lora:,} tham số") | |
| model.print_trainable_parameters() | |
| def _build_ppo_source(): | |
| if hf_repo: | |
| return dataset_dict['train'], dataset_dict['train'] | |
| if hasattr(train_ds, "dataset") and hasattr(train_ds.dataset, "data"): | |
| subset_indices = getattr(train_ds, "indices", list(range(len(train_ds.dataset.data)))) | |
| local_items = [train_ds.dataset.data[i] for i in subset_indices] | |
| return local_items, None | |
| raise ValueError("Khong the truy cap raw train data de tao PPO rollout set.") | |
| def _prepare_ppo_records(raw_items, num_samples: int, closed_ratio: float): | |
| closed_records = [] | |
| open_records = [] | |
| for idx in range(len(raw_items)): | |
| item = raw_items[idx] | |
| question = str(item.get("question_vi", item.get("question", ""))).strip() | |
| target = get_target_answer(item, max_words=ppo_answer_max_words) | |
| if not question or not target: | |
| continue | |
| record = { | |
| "question": question, | |
| "target": target, | |
| "source_idx": idx, | |
| "image": item.get("image_name"), | |
| "is_closed": infer_closed_answer_type(item, target), | |
| } | |
| if record["is_closed"]: | |
| closed_records.append(record) | |
| else: | |
| open_records.append(record) | |
| rng = random.Random(int(config.get("seed", 42))) | |
| rng.shuffle(closed_records) | |
| rng.shuffle(open_records) | |
| target_closed = min(len(closed_records), int(round(num_samples * closed_ratio))) | |
| target_open = min(len(open_records), max(0, num_samples - target_closed)) | |
| selected = closed_records[:target_closed] + open_records[:target_open] | |
| rng.shuffle(selected) | |
| return selected | |
| raw_train_source, hf_train_source = _build_ppo_source() | |
| ppo_records = _prepare_ppo_records( | |
| raw_train_source, | |
| num_samples=int(ppo_cfg.get('num_samples', 192)), | |
| closed_ratio=float(ppo_cfg.get('closed_ratio', 0.5)), | |
| ) | |
| if not ppo_records: | |
| raise ValueError("Khong tao duoc PPO rollout set hop le.") | |
| print(f"[INFO] PPO rollout set: {len(ppo_records)} mau") | |
| trainable_params = [param for param in model.parameters() if param.requires_grad] | |
| optimizer = optim.AdamW( | |
| trainable_params, | |
| lr=float(ppo_cfg.get('learning_rate', 5.0e-7)), | |
| weight_decay=float(ppo_cfg.get('weight_decay', 0.0)), | |
| ) | |
| rollout_batch_size = max(1, int(ppo_cfg.get('rollout_batch_size', 2))) | |
| total_updates = max(1, (len(ppo_records) + rollout_batch_size - 1) // rollout_batch_size) | |
| scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_updates) | |
| ppo_history = [] | |
| eos = processor.tokenizer.eos_token or "" | |
| max_seq_length = max(int(config['train'].get('dpo_max_length', 768)), 768) | |
| grad_clip = float(config['train'].get('grad_clip', 1.0)) | |
| entropy_coef = float(ppo_cfg.get('entropy_coef', 0.001)) | |
| clip_range = float(ppo_cfg.get('clip_range', 0.2)) | |
| max_new_tokens = int(ppo_cfg.get('max_new_tokens', 12)) | |
| temperature = float(ppo_cfg.get('temperature', 0.8)) | |
| top_p = float(ppo_cfg.get('top_p', 0.9)) | |
| closed_positive = float(ppo_cfg.get('closed_positive_reward', 1.0)) | |
| closed_negative = float(ppo_cfg.get('closed_negative_reward', -1.0)) | |
| print("[INFO] Bắt đầu huấn luyện PPO-style refinement...") | |
| model.train() | |
| for update_idx in range(total_updates): | |
| batch_records = ppo_records[update_idx * rollout_batch_size:(update_idx + 1) * rollout_batch_size] | |
| prompts, images, questions, targets, closed_flags = [], [], [], [], [] | |
| for record in batch_records: | |
| image = resolve_dpo_image( | |
| record, | |
| hf_train_data=hf_train_source, | |
| image_dir=config['data'].get('image_dir'), | |
| ) | |
| if image is None: | |
| continue | |
| prompts.append(build_dpo_instruction_prompt(record["question"], max_words=ppo_answer_max_words)) | |
| images.append(image) | |
| questions.append(record["question"]) | |
| targets.append(record["target"]) | |
| closed_flags.append(record["is_closed"]) | |
| if not prompts: | |
| continue | |
| generation_inputs = processor( | |
| text=prompts, | |
| images=images, | |
| return_tensors="pt", | |
| padding=True, | |
| ) | |
| generation_inputs = move_model_batch_to_device(generation_inputs, next(model.parameters()).device) | |
| if "pixel_values" in generation_inputs: | |
| generation_inputs["pixel_values"] = generation_inputs["pixel_values"].to(torch.bfloat16) | |
| with torch.no_grad(): | |
| generated_ids = model.generate( | |
| **generation_inputs, | |
| max_new_tokens=max_new_tokens, | |
| do_sample=True, | |
| temperature=temperature, | |
| top_p=top_p, | |
| num_beams=1, | |
| pad_token_id=processor.tokenizer.pad_token_id, | |
| eos_token_id=processor.tokenizer.eos_token_id, | |
| ) | |
| prompt_token_len = generation_inputs["input_ids"].shape[1] | |
| generated_texts = processor.batch_decode( | |
| generated_ids[:, prompt_token_len:], | |
| skip_special_tokens=True, | |
| ) | |
| sampled_answers = [] | |
| rewards = [] | |
| reward_breakdown = [] | |
| for question, target, is_closed, raw_output in zip(questions, targets, closed_flags, generated_texts): | |
| pred = sanitize_dpo_completion(question, raw_output, max_words=ppo_answer_max_words) | |
| if not pred: | |
| pred = "không" if is_closed else "không rõ" | |
| sampled_answers.append(pred) | |
| if is_closed: | |
| reward = closed_positive if normalize_answer(pred) == normalize_answer(target) else closed_negative | |
| rewards.append(reward) | |
| reward_breakdown.append({"exact": float(reward > 0), "reward": reward}) | |
| else: | |
| reward, details = compute_single_open_reward(pred, target) | |
| rewards.append(reward) | |
| reward_breakdown.append(details | {"reward": reward}) | |
| completion_texts = [f" {pred}{eos}" for pred in sampled_answers] | |
| rollout_batch, rollout_mask = build_multimodal_completion_batch( | |
| processor, | |
| prompts, | |
| completion_texts, | |
| images, | |
| max_length=max_seq_length, | |
| ) | |
| with torch.no_grad(): | |
| old_seq_log_probs, _ = compute_masked_sequence_logprobs(model, rollout_batch, rollout_mask) | |
| reward_tensor = torch.tensor(rewards, dtype=torch.float32, device=old_seq_log_probs.device) | |
| if reward_tensor.numel() > 1: | |
| advantages = reward_tensor - reward_tensor.mean() | |
| advantages = advantages / advantages.std(unbiased=False).clamp_min(1e-6) | |
| else: | |
| advantages = reward_tensor | |
| optimizer.zero_grad(set_to_none=True) | |
| new_seq_log_probs, entropy = compute_masked_sequence_logprobs(model, rollout_batch, rollout_mask) | |
| ratios = torch.exp(new_seq_log_probs - old_seq_log_probs.detach()) | |
| clipped_ratios = torch.clamp(ratios, 1.0 - clip_range, 1.0 + clip_range) | |
| surrogate_1 = ratios * advantages | |
| surrogate_2 = clipped_ratios * advantages | |
| policy_loss = -torch.min(surrogate_1, surrogate_2).mean() | |
| entropy_bonus = entropy.mean() | |
| loss = policy_loss - (entropy_coef * entropy_bonus) | |
| loss.backward() | |
| torch.nn.utils.clip_grad_norm_(trainable_params, grad_clip) | |
| optimizer.step() | |
| scheduler.step() | |
| closed_rewards = [r for r, is_closed in zip(rewards, closed_flags) if is_closed] | |
| open_rewards = [r for r, is_closed in zip(rewards, closed_flags) if not is_closed] | |
| log_record = { | |
| "epoch": 1, | |
| "update": update_idx + 1, | |
| "train_loss": float(loss.detach().cpu().item()), | |
| "policy_loss": float(policy_loss.detach().cpu().item()), | |
| "entropy": float(entropy_bonus.detach().cpu().item()), | |
| "avg_reward": float(sum(rewards) / len(rewards)), | |
| "avg_closed_reward": float(sum(closed_rewards) / len(closed_rewards)) if closed_rewards else None, | |
| "avg_open_reward": float(sum(open_rewards) / len(open_rewards)) if open_rewards else None, | |
| "learning_rate": float(scheduler.get_last_lr()[0]), | |
| "sample_predictions": sampled_answers[:2], | |
| "sample_targets": targets[:2], | |
| "reward_breakdown": reward_breakdown[:2], | |
| } | |
| ppo_history.append(log_record) | |
| if wandb.run: | |
| wandb.log({ | |
| "ppo/train_loss": log_record["train_loss"], | |
| "ppo/policy_loss": log_record["policy_loss"], | |
| "ppo/entropy": log_record["entropy"], | |
| "ppo/avg_reward": log_record["avg_reward"], | |
| "ppo/avg_closed_reward": log_record["avg_closed_reward"], | |
| "ppo/avg_open_reward": log_record["avg_open_reward"], | |
| "ppo/learning_rate": log_record["learning_rate"], | |
| "ppo/update": log_record["update"], | |
| }) | |
| del generation_inputs, generated_ids | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| final_ppo_dir = Path("checkpoints/PPO/final_adapter") | |
| final_ppo_dir.mkdir(parents=True, exist_ok=True) | |
| model.save_pretrained(str(final_ppo_dir)) | |
| processor.save_pretrained(str(final_ppo_dir)) | |
| with open("checkpoints/medical_vqa_ppo_from.txt", "w", encoding="utf-8") as f: | |
| f.write(str(b2_checkpoint)) | |
| print("[INFO] Đang chạy đánh giá nghiệm thu trên tập Validation cho PPO...") | |
| model.eval() | |
| metrics = evaluate_multimodal_vqa( | |
| model, | |
| val_loader, | |
| device, | |
| processor, | |
| beam_width=config['eval'].get('beam_width_b', 1), | |
| beam_width_closed=config['eval'].get('beam_width_b_closed', 1), | |
| beam_width_open=config['eval'].get('beam_width_b_open', config['eval'].get('beam_width_b', 1)), | |
| max_new_tokens_closed=config['eval'].get('max_new_tokens_b_closed', 4), | |
| max_new_tokens_open=config['eval'].get('max_new_tokens_b_open', answer_max_words + 6), | |
| generation_batch_size=config['eval'].get('generation_batch_size_b', 1), | |
| max_words=answer_max_words, | |
| variant='PPO' | |
| ) | |
| closed_eval = metrics.get('closed_eval', {}) | |
| open_eval = metrics.get('open_eval', {}) | |
| ppo_history.append({ | |
| "epoch": 1, | |
| "val_accuracy_normalized": metrics.get('accuracy_normalized'), | |
| "val_f1_normalized": metrics.get('f1_normalized'), | |
| "val_bleu4_normalized": metrics.get('bleu4_normalized'), | |
| "val_bert_score_raw": metrics.get('bert_score_raw'), | |
| "val_semantic_raw": metrics.get('semantic_raw'), | |
| "val_closed_accuracy": closed_eval.get('accuracy', 0), | |
| "val_closed_em": closed_eval.get('em', 0), | |
| "val_closed_f1": closed_eval.get('f1', 0), | |
| "val_open_semantic": open_eval.get('semantic', 0), | |
| "val_open_bertscore": open_eval.get('bert_score', 0), | |
| "val_open_f1": open_eval.get('f1', 0), | |
| "val_open_rouge_l": open_eval.get('rouge_l', 0), | |
| }) | |
| b2_metrics = load_latest_variant_metrics(os.path.join(config['log_dir'], "history"), "B2") | |
| ppo_acceptance = evaluate_refinement_acceptance(b2_metrics, ppo_history[-1]) | |
| ppo_history[-1]["ppo_acceptance"] = ppo_acceptance | |
| print(f"[INFO] {ppo_acceptance['summary']}") | |
| if ppo_acceptance["status"] == "accepted": | |
| print("[SUCCESS] PPO accepted: dat tieu chi refinement nhe tren B2.") | |
| elif ppo_acceptance["status"] == "failed": | |
| print("[WARN] PPO failed, keep B2. Khong khuyen nghi tiep tuc tuning them.") | |
| os.makedirs("checkpoints/PPO", exist_ok=True) | |
| with open("checkpoints/PPO/acceptance_summary.json", "w", encoding="utf-8") as f: | |
| json.dump(ppo_acceptance, f, ensure_ascii=False, indent=2) | |
| save_history_records(history_dir, ppo_history) | |
| print("[SUCCESS] Đã lưu checkpoint và metrics PPO.") | |
| return | |
| elif args.variant == 'DPO': | |
| from trl import DPOTrainer | |
| try: | |
| from trl import DPOConfig | |
| except ImportError: | |
| DPOConfig = None | |
| from transformers import TrainingArguments | |
| from datasets import Dataset as HFDataset | |
| import inspect | |
| dpo_answer_max_words = int(config.get('dpo', {}).get('max_answer_words', min(answer_max_words, 6))) | |
| wrapper = MultimodalVQA( | |
| model_id=config['model_b']['model_name'], | |
| lora_r=int(config['model_b'].get('lora_r', 16)), | |
| lora_alpha=int(config['model_b'].get('lora_alpha', 32)), | |
| lora_dropout=float(config['model_b'].get('lora_dropout', 0.05)), | |
| lora_target_modules=config['model_b'].get('lora_target_modules'), | |
| ) | |
| explicit_b2_checkpoint = ( | |
| config.get('train', {}).get('b2_checkpoint') | |
| or os.environ.get('B2_CHECKPOINT_PATH') | |
| ) | |
| if explicit_b2_checkpoint: | |
| b2_checkpoint = Path(explicit_b2_checkpoint).expanduser().resolve() | |
| if not b2_checkpoint.exists(): | |
| raise FileNotFoundError(f"Không tìm thấy B2 checkpoint được chỉ định: {b2_checkpoint}") | |
| print(f"[INFO] DPO sẽ khởi tạo từ B2 checkpoint chỉ định: {b2_checkpoint}") | |
| else: | |
| b2_checkpoint = select_best_adapter_checkpoint(config['train'].get('b2_output_dir', './checkpoints/B2')) | |
| print(f"[INFO] DPO sẽ khởi tạo từ B2 checkpoint: {b2_checkpoint}") | |
| try: | |
| model, processor = wrapper.load_model(adapter_path=str(b2_checkpoint), is_trainable=True) | |
| except Exception as exc: | |
| print(f"[WARNING] Không load được B2 checkpoint, fallback sang base LLaVA-Med + LoRA mới: {exc}") | |
| model, processor = wrapper.load_model(adapter_path=None, is_trainable=True) | |
| if not config['train'].get('dpo_train_mlp_lora', False): | |
| frozen_lora = 0 | |
| for name, param in model.named_parameters(): | |
| if "lora_" in name and any(proj in name for proj in ("gate_proj", "up_proj", "down_proj")): | |
| param.requires_grad = False | |
| frozen_lora += param.numel() | |
| print(f"[INFO] DPO đang freeze LoRA MLP để giảm VRAM: {frozen_lora:,} tham số") | |
| model.print_trainable_parameters() | |
| # Tạo/Load Preference Data | |
| pref_json = config.get('dpo', {}).get('preference_data', 'data/preference_data_slake.json') | |
| force_rebuild_pref = bool(config.get('dpo', {}).get('force_rebuild_preference_data', False)) | |
| if force_rebuild_pref and os.path.exists(pref_json): | |
| print(f"[INFO] Dang xoa preference data cu de tao lai theo cau hinh hien tai: {pref_json}") | |
| os.remove(pref_json) | |
| if not os.path.exists(pref_json): | |
| print(f"[INFO] Chưa có preference data. Đang tự động tạo từ training data...") | |
| from src.engine.dpo_trainer import create_preference_data | |
| if hf_repo: | |
| raw_data = [{"question_vi": item["question_vi"], "answer_vi": get_target_answer(item, max_words=dpo_answer_max_words), | |
| "image_name": item.get("image_name"), | |
| "source_idx": i} | |
| for i, item in enumerate(dataset_dict['train'])] | |
| tmp_json = "data/tmp_train_for_dpo.json" | |
| os.makedirs("data", exist_ok=True) | |
| with open(tmp_json, 'w', encoding='utf-8') as f: | |
| json.dump(raw_data, f, ensure_ascii=False, indent=2) | |
| create_preference_data( | |
| tmp_json, | |
| pref_json, | |
| num_pairs=int(config.get('dpo', {}).get('num_pairs', 400)), | |
| closed_ratio=float(config.get('dpo', {}).get('closed_ratio', 0.6)), | |
| max_answer_words=dpo_answer_max_words, | |
| ) | |
| else: | |
| create_preference_data( | |
| config['data']['vqa_json'], | |
| pref_json, | |
| num_pairs=int(config.get('dpo', {}).get('num_pairs', 400)), | |
| closed_ratio=float(config.get('dpo', {}).get('closed_ratio', 0.6)), | |
| max_answer_words=dpo_answer_max_words, | |
| ) | |
| # Đọc file JSON preference data | |
| with open(pref_json, 'r', encoding='utf-8') as f: | |
| pref_data = json.load(f) | |
| if hf_repo and any("source_idx" not in item for item in pref_data): | |
| print("[INFO] Preference data cu khong co source_idx. Dang tao lai de giu lien ket image cho DPO...") | |
| from src.engine.dpo_trainer import create_preference_data | |
| raw_data = [{"question_vi": item["question_vi"], "answer_vi": get_target_answer(item, max_words=dpo_answer_max_words), | |
| "image_name": item.get("image_name"), "source_idx": i} | |
| for i, item in enumerate(dataset_dict['train'])] | |
| tmp_json = "data/tmp_train_for_dpo.json" | |
| with open(tmp_json, 'w', encoding='utf-8') as f: | |
| json.dump(raw_data, f, ensure_ascii=False, indent=2) | |
| create_preference_data( | |
| tmp_json, | |
| pref_json, | |
| num_pairs=int(config.get('dpo', {}).get('num_pairs', 400)), | |
| closed_ratio=float(config.get('dpo', {}).get('closed_ratio', 0.6)), | |
| max_answer_words=dpo_answer_max_words, | |
| ) | |
| with open(pref_json, 'r', encoding='utf-8') as f: | |
| pref_data = json.load(f) | |
| # Chuẩn bị HF Dataset cho DPOTrainer (yêu cầu cột: prompt, chosen, rejected) | |
| prompts, chosens, rejecteds, images = [], [], [], [] | |
| eos = processor.tokenizer.eos_token or "" | |
| filtered_pairs = 0 | |
| for item in pref_data: | |
| q = item.get("question", "") | |
| chosen = sanitize_dpo_completion(q, item.get("chosen", ""), max_words=dpo_answer_max_words) | |
| rejected = sanitize_dpo_completion(q, item.get("rejected", ""), max_words=dpo_answer_max_words) | |
| image = resolve_dpo_image( | |
| item, | |
| hf_train_data=dataset_dict['train'] if hf_repo else None, | |
| image_dir=config['data'].get('image_dir'), | |
| ) | |
| if not chosen or not rejected or chosen == rejected or image is None: | |
| filtered_pairs += 1 | |
| continue | |
| prompts.append(build_dpo_instruction_prompt(q, max_words=dpo_answer_max_words)) | |
| chosens.append(f" {chosen}{eos}") | |
| rejecteds.append(f" {rejected}{eos}") | |
| images.append(image) | |
| if not prompts: | |
| raise ValueError("Khong con cap preference hop le sau khi sanitize DPO data.") | |
| if filtered_pairs: | |
| print(f"[INFO] Da bo qua {filtered_pairs} cap preference khong hop le sau sanitize.") | |
| dpo_hf_dataset = HFDataset.from_dict({ | |
| "prompt": prompts, | |
| "chosen": chosens, | |
| "rejected": rejecteds, | |
| "image": images, | |
| }) | |
class MultimodalDPODataCollator:
    """Collate preference examples into one chosen-then-rejected multimodal batch.

    Every example must provide ``prompt``, ``chosen``, ``rejected`` and ``image``.
    For ``n`` examples the returned batch has ``2 * n`` rows: rows ``0..n-1`` are
    ``prompt + chosen`` and rows ``n..2n-1`` are ``prompt + rejected``, in the
    same example order. A ``completion_mask`` entry (1 on completion tokens,
    0 elsewhere) is added to the processor output.

    Args:
        processor: Callable multimodal processor exposing a ``tokenizer``
            attribute. It is invoked with ``text=`` / ``images=`` and must
            return a mapping containing ``input_ids`` and ``attention_mask``.
        max_length: Upper bound for the sequence dimension of the batch.
            ``None`` disables truncation; any other value is raised to at
            least ``min_length``.
        min_length: Floor applied to ``max_length``. Defaults to 768, the
            value that was previously hard-coded.
    """

    def __init__(self, processor, max_length=None, min_length=768):
        self.processor = processor
        self.tokenizer = processor.tokenizer
        self.min_length = min_length
        # LLaVA expands a single <image> placeholder into hundreds of visual tokens.
        # If max_length is too small, the processor truncates those tokens and raises
        # "image token count" mismatch. Keep a safe floor for multimodal DPO.
        self.max_length = None if max_length is None else max(max_length, min_length)

    def __call__(self, examples):
        """Build the padded batch and its completion mask.

        Args:
            examples: Sequence of mappings with ``prompt``, ``chosen``,
                ``rejected`` and ``image`` keys.

        Returns:
            The processor output for the ``2 * len(examples)`` full sequences,
            with ``input_ids`` / ``attention_mask`` (and ``token_type_ids`` /
            ``mm_token_type_ids`` when present) cut to ``max_length`` columns
            and an extra ``completion_mask`` tensor of the same shape as
            ``input_ids``.
        """
        prompts = [example["prompt"] for example in examples]
        chosens = [example["chosen"] for example in examples]
        rejecteds = [example["rejected"] for example in examples]
        images = [example["image"] for example in examples]

        # Chosen sequences first, then rejected ones, sharing prompts and images.
        full_texts = [f"{prompt}{chosen}" for prompt, chosen in zip(prompts, chosens)]
        full_texts.extend(f"{prompt}{rejected}" for prompt, rejected in zip(prompts, rejecteds))
        repeated_prompts = prompts + prompts
        repeated_images = images + images

        batch = self.processor(
            text=full_texts,
            images=repeated_images,
            return_tensors="pt",
            padding=True,
            truncation=False,
        )
        # The prompts are encoded on their own only to learn how many
        # non-padding tokens each prompt occupies.
        prompt_batch = self.processor(
            text=repeated_prompts,
            images=repeated_images,
            return_tensors="pt",
            padding=True,
            truncation=False,
        )

        completion_mask = torch.zeros_like(batch["input_ids"], dtype=torch.long)
        prompt_lengths = prompt_batch["attention_mask"].sum(dim=1)
        for row, prompt_len in enumerate(prompt_lengths.tolist()):
            # Positions of real (attended) tokens; indexing through them keeps
            # the mask correct regardless of where the padding sits.
            token_positions = batch["attention_mask"][row].nonzero(as_tuple=True)[0]
            completion_mask[row, token_positions[prompt_len:]] = 1

        if self.max_length is not None and batch["input_ids"].shape[1] > self.max_length:
            # NOTE(review): this keeps the first max_length columns, so an
            # over-long row loses tokens from its tail, which may include
            # completion tokens. Confirm this is acceptable for the DPO loss.
            batch["input_ids"] = batch["input_ids"][:, :self.max_length]
            batch["attention_mask"] = batch["attention_mask"][:, :self.max_length]
            completion_mask = completion_mask[:, :self.max_length]
            for key in ("token_type_ids", "mm_token_type_ids"):
                if key in batch:
                    batch[key] = batch[key][:, :self.max_length]

        batch["completion_mask"] = completion_mask
        return batch
| dpo_sequence_limits = { | |
| "max_length": max(int(config['train'].get('dpo_max_length', 768)), 768), | |
| "max_prompt_length": int(config['train'].get('dpo_max_prompt_length', 96)), | |
| "max_completion_length": int(config['train'].get('dpo_max_completion_length', 24)), | |
| } | |
| training_args_dict = { | |
| "output_dir": "./checkpoints/DPO", | |
| "per_device_train_batch_size": int(config['train'].get('dpo_batch_size', 1)), | |
| "gradient_accumulation_steps": int(config['train'].get('dpo_gradient_accumulation_steps', 8)), | |
| "num_train_epochs": config['train'].get('dpo_epochs', 1), | |
| "learning_rate": float(config.get('dpo', {}).get('learning_rate', 1.0e-6)), | |
| "lr_scheduler_type": "cosine", # [OPTIMIZED] Giúp hội tụ mượt mà hơn | |
| "warmup_ratio": 0.1, # [OPTIMIZED] Tránh sốc gradient ở epoch đầu | |
| "bf16": True, | |
| "remove_unused_columns": False, | |
| "logging_steps": 10, | |
| "save_strategy": "epoch", | |
| "save_total_limit": 1, | |
| "optim": config['train'].get('dpo_optim', 'paged_adamw_8bit'), | |
| "gradient_checkpointing": True, | |
| } | |
| if DPOConfig is not None: | |
| training_args_dict["beta"] = float(config.get('dpo', {}).get('beta', 0.1)) | |
| dpo_config_params = set(inspect.signature(DPOConfig.__init__).parameters) | |
| for key, value in dpo_sequence_limits.items(): | |
| if key in dpo_config_params: | |
| training_args_dict[key] = value | |
| training_args = DPOConfig(**training_args_dict) | |
| else: | |
| training_args = build_training_arguments(TrainingArguments, **training_args_dict) | |
| training_args.model_init_kwargs = None | |
| dpo_kwargs = { | |
| "model": model, | |
| "args": training_args, | |
| "train_dataset": dpo_hf_dataset, | |
| "data_collator": MultimodalDPODataCollator(processor, max_length=dpo_sequence_limits["max_length"]), | |
| } | |
| dpo_trainer_params = set(inspect.signature(DPOTrainer.__init__).parameters) | |
| for key, value in dpo_sequence_limits.items(): | |
| if key in dpo_trainer_params: | |
| dpo_kwargs[key] = value | |
| try: | |
| print("[INFO] Thử khởi tạo DPOTrainer với processing_class...") | |
| trainer = DPOTrainer(**dpo_kwargs, processing_class=processor) | |
| except TypeError: | |
| try: | |
| trainer = DPOTrainer(**dpo_kwargs, tokenizer=processor) | |
| except TypeError: | |
| trainer = DPOTrainer(**dpo_kwargs, tokenizer=processor.tokenizer) | |
| print("[INFO] Bắt đầu huấn luyện DPO...") | |
| trainer.train() | |
| os.makedirs("checkpoints", exist_ok=True) | |
| final_dpo_dir = Path("checkpoints/DPO/final_adapter") | |
| final_dpo_dir.mkdir(parents=True, exist_ok=True) | |
| model.save_pretrained(str(final_dpo_dir)) | |
| processor.save_pretrained(str(final_dpo_dir)) | |
| with open("checkpoints/medical_vqa_dpo_from.txt", "w", encoding="utf-8") as f: | |
| f.write(str(b2_checkpoint)) | |
| # [FIX] Đánh giá DPO sau khi train xong để có Accuracy, F1, BLEU cho biểu đồ so sánh | |
| from src.engine.medical_eval import evaluate_multimodal_vqa | |
| print("[INFO] Đang chạy đánh giá nghiệm thu trên tập Validation cho DPO...") | |
| model.eval() | |
| metrics = evaluate_multimodal_vqa( | |
| model, | |
| val_loader, | |
| device, | |
| processor, | |
| beam_width=config['eval'].get('beam_width_b', 1), | |
| beam_width_closed=config['eval'].get('beam_width_b_closed', 1), | |
| beam_width_open=config['eval'].get('beam_width_b_open', config['eval'].get('beam_width_b', 1)), | |
| max_new_tokens_closed=config['eval'].get('max_new_tokens_b_closed', 4), | |
| max_new_tokens_open=config['eval'].get('max_new_tokens_b_open', answer_max_words + 6), | |
| generation_batch_size=config['eval'].get('generation_batch_size_b', 1), | |
| max_words=answer_max_words, | |
| variant='DPO' | |
| ) | |
| closed_eval = metrics.get('closed_eval', {}) | |
| open_eval = metrics.get('open_eval', {}) | |
| print(f"\n[RESULT DPO - CLOSED QUESTIONS]") | |
| print(f"Count: {closed_eval.get('count', 0)}") | |
| print(f"Accuracy: {closed_eval.get('accuracy', 0):.4f}") | |
| print(f"EM: {closed_eval.get('em', 0):.4f}") | |
| print(f"F1: {closed_eval.get('f1', 0):.4f}") | |
| print(f"\n[RESULT DPO - OPEN QUESTIONS]") | |
| print(f"Count: {open_eval.get('count', 0)}") | |
| print(f"Semantic: {open_eval.get('semantic', 0):.4f}") | |
| print(f"BERTScore: {open_eval.get('bert_score', 0):.4f}") | |
| print(f"F1: {open_eval.get('f1', 0):.4f}") | |
| print(f"ROUGE-L: {open_eval.get('rouge_l', 0):.4f}") | |
| final_epoch = training_args.num_train_epochs | |
| trainer.state.log_history.append({ | |
| "epoch": final_epoch, | |
| "val_accuracy_normalized": metrics.get('accuracy_normalized'), | |
| "val_f1_normalized": metrics.get('f1_normalized'), | |
| "val_bleu4_normalized": metrics.get('bleu4_normalized'), | |
| "val_bert_score_raw": metrics.get('bert_score_raw'), | |
| "val_semantic_raw": metrics.get('semantic_raw'), | |
| "val_closed_accuracy": closed_eval.get('accuracy', 0), | |
| "val_closed_em": closed_eval.get('em', 0), | |
| "val_closed_f1": closed_eval.get('f1', 0), | |
| "val_open_semantic": open_eval.get('semantic', 0), | |
| "val_open_bertscore": open_eval.get('bert_score', 0), | |
| "val_open_f1": open_eval.get('f1', 0), | |
| "val_open_rouge_l": open_eval.get('rouge_l', 0), | |
| }) | |
| b2_metrics = load_latest_variant_metrics(os.path.join(config['log_dir'], "history"), "B2") | |
| dpo_acceptance = evaluate_dpo_acceptance(b2_metrics, trainer.state.log_history[-1]) | |
| trainer.state.log_history[-1]["dpo_acceptance"] = dpo_acceptance | |
| print(f"[INFO] {dpo_acceptance['summary']}") | |
| if dpo_acceptance["status"] == "accepted": | |
| print("[SUCCESS] DPO accepted: dat tieu chi refinement nhe tren B2.") | |
| elif dpo_acceptance["status"] == "failed": | |
| print("[WARN] DPO failed, keep B2. Khong khuyen nghi tiep tuc tuning them.") | |
| os.makedirs("checkpoints/DPO", exist_ok=True) | |
| with open("checkpoints/DPO/acceptance_summary.json", "w", encoding="utf-8") as f: | |
| json.dump(dpo_acceptance, f, ensure_ascii=False, indent=2) | |
| save_history_records(history_dir, trainer.state.log_history) | |
| print("[SUCCESS] Đã lưu checkpoint và metrics DPO.") | |
| return | |
| elif args.variant == 'B2': | |
| # Fine-tuning LLaVA-Med | |
| from transformers import TrainingArguments, Trainer | |
| from datasets import Dataset as HFDataset | |
| wrapper = MultimodalVQA( | |
| model_id=config['model_b']['model_name'], | |
| lora_r=int(config['model_b'].get('lora_r', 16)), | |
| lora_alpha=int(config['model_b'].get('lora_alpha', 32)), | |
| lora_dropout=float(config['model_b'].get('lora_dropout', 0.05)), | |
| lora_target_modules=config['model_b'].get('lora_target_modules'), | |
| ) | |
| model, processor = wrapper.load_model() | |
def make_sft_dataset(raw_ds):
    """Convert a raw VQA dataset into the column layout used for B2 fine-tuning.

    Args:
        raw_ds: Dataset supporting ``len()`` and integer indexing. Only items
            that are ``dict`` instances are used; any other item is skipped.

    Returns:
        A Hugging Face ``Dataset`` with the columns ``prompt`` (instruction
        prompt without the answer), ``answer`` (target answer), ``text``
        (``prompt``, a space, the answer and the tokenizer's EOS token) and
        ``image`` (converted to RGB when it is not already).
    """
    yes_no_answers = {"có", "không", "yes", "no"}
    negative_answers = {"không", "khong", "no", "false", "absent"}
    # Loop-invariant: the EOS string appended to every target.
    eos = processor.tokenizer.eos_token or ""

    prompts = []
    answers = []
    texts = []
    images = []
    # Index-based on purpose: raw_ds is only assumed to provide __len__ and
    # integer __getitem__, not a well-behaved iterator.
    for idx in range(len(raw_ds)):
        item = raw_ds[idx]
        if not isinstance(item, dict):
            continue

        question = item.get("question_vi", item.get("question", item.get("raw_questions", "")))
        answer = get_target_answer(item, max_words=answer_max_words)
        answer_norm = str(answer).strip().lower()
        answer_type = str(item.get("answer_type", "")).upper()
        label_closed = item.get("label_closed", None)

        # Detect yes/no answers on the normalised form so that variants such as
        # "Yes" or "Không " are canonicalised like their exact counterparts.
        is_closed = (
            answer_type == "CLOSED"
            or label_closed in (0, 1)
            or answer_norm in yes_no_answers
        )
        if is_closed:
            # NOTE(review): every closed answer outside the negative set maps
            # to "có"; confirm closed questions are always yes/no.
            answer = "không" if answer_norm in negative_answers else "có"

        prompt = wrapper.build_instruction_prompt(question, language="vi", include_answer=False)
        prompts.append(prompt)
        answers.append(answer)
        texts.append(f"{prompt} {answer}{eos}")

        # NOTE(review): a missing image is stored as None; verify that the
        # collator / processor downstream tolerates that.
        img = item.get("image", None)
        if img is not None and img.mode != "RGB":
            img = img.convert("RGB")
        images.append(img)

    return HFDataset.from_dict({"prompt": prompts, "answer": answers, "text": texts, "image": images})
| if hf_repo: | |
| sft_train = make_sft_dataset(dataset_dict['train']) | |
| sft_val = make_sft_dataset(dataset_dict['validation']) | |
| else: | |
| sft_train = make_sft_dataset(train_ds) | |
| sft_val = make_sft_dataset(val_ds) | |
class MultimodalDataCollator:
    """Collate SFT examples into a processor batch with answer-only labels.

    Args:
        processor: Callable multimodal processor exposing a ``tokenizer``
            attribute. It is invoked with ``text=`` / ``images=`` and must
            return a mapping containing ``input_ids`` and ``attention_mask``.
        max_length: Stored on the instance; ``__call__`` does not use it.
    """

    def __init__(self, processor, max_length=None):
        self.processor = processor
        self.tokenizer = processor.tokenizer
        self.max_length = max_length

    def __call__(self, examples):
        """Build the padded batch and its ``labels`` tensor.

        Args:
            examples: Sequence of mappings with ``text`` (prompt plus answer),
                ``prompt`` and ``image`` keys.

        Returns:
            The processor output for the full texts with an added ``labels``
            tensor: a copy of ``input_ids`` in which padding positions and all
            prompt tokens are set to -100.
        """
        texts = [example["text"] for example in examples]
        prompts = [example["prompt"] for example in examples]
        images = [example["image"] for example in examples]

        batch = self.processor(
            text=texts,
            images=images,
            return_tensors="pt",
            padding=True,
        )

        labels = batch["input_ids"].clone()
        # Mask padding through the attention mask rather than by token id: when
        # the pad token shares its id with EOS, an id comparison would also
        # hide the EOS that ends every answer, and it would never be trained.
        labels[batch["attention_mask"] == 0] = -100

        # Mask the full prompt so SFT loss is computed only on the answer.
        # Searching for "ASSISTANT:" token ids is brittle because tokenization can
        # split the separator differently across models.
        prompt_batch = self.processor(
            text=prompts,
            images=images,
            return_tensors="pt",
            padding=True,
        )
        prompt_lengths = prompt_batch["attention_mask"].sum(dim=1)
        for row, prompt_len in enumerate(prompt_lengths.tolist()):
            # The first prompt_len attended positions of the row are the prompt.
            token_positions = batch["attention_mask"][row].nonzero(as_tuple=True)[0]
            labels[row, token_positions[:prompt_len]] = -100

        batch["labels"] = labels
        # The batch holds only the processor output plus labels; the raw text
        # and image lists are not forwarded.
        return batch
| b2_micro_batch = int(config['train'].get('b2_batch_size', 1)) | |
| b2_grad_accum = int(config['train'].get('b2_gradient_accumulation_steps', max(config['train'].get('gradient_accumulation_steps', 2), 1))) | |
| b2_max_length = int(config['train'].get('b2_max_length', config['data'].get('max_question_len', 64) + config['data'].get('max_answer_len', 20) + 32)) | |
| training_args = build_training_arguments( | |
| TrainingArguments, | |
| output_dir="./checkpoints/B2", | |
| per_device_train_batch_size=b2_micro_batch, | |
| per_device_eval_batch_size=int(config['train'].get('b2_eval_batch_size', 1)), | |
| gradient_accumulation_steps=b2_grad_accum, | |
| num_train_epochs=config['train'].get('epochs', 3), | |
| learning_rate=float(config['train'].get('b2_lr', 2.0e-5)), | |
| lr_scheduler_type="cosine", | |
| warmup_steps=int(config['train'].get('b2_warmup_steps', 50)), | |
| bf16=True, | |
| fp16=False, | |
| gradient_checkpointing=True, | |
| remove_unused_columns=False, | |
| logging_steps=10, | |
| evaluation_strategy="epoch", | |
| save_strategy="epoch", | |
| save_total_limit=2, | |
| optim=config['train'].get('b2_optim', 'paged_adamw_8bit'), | |
| max_grad_norm=float(config['train'].get('grad_clip', 1.0)), | |
| dataloader_num_workers=int(config['train'].get('b2_num_workers', 4)), | |
| dataloader_pin_memory=bool(config['train'].get('pin_memory', True)), | |
| load_best_model_at_end=config['train'].get('b2_load_best_model_at_end', True), | |
| metric_for_best_model=config['train'].get('b2_metric_for_best', 'eval_loss'), | |
| greater_is_better=False, | |
| ) | |
| training_args.gradient_checkpointing_kwargs = {"use_reentrant": False} | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=sft_train, | |
| eval_dataset=sft_val, | |
| data_collator=MultimodalDataCollator(processor, max_length=b2_max_length) | |
| ) | |
| trainer.train() | |
| # [FIX] Đánh giá B2 sau khi train xong để có Accuracy, F1, BLEU cho biểu đồ so sánh | |
| from src.engine.medical_eval import evaluate_multimodal_vqa | |
| print("[INFO] Đang chạy đánh giá nghiệm thu trên tập Validation cho B2...") | |
| # Đưa model về evaluation mode | |
| model.eval() | |
| metrics = evaluate_multimodal_vqa( | |
| model, | |
| val_loader, | |
| device, | |
| processor, | |
| beam_width=config['eval'].get('beam_width_b', 1), | |
| beam_width_closed=config['eval'].get('beam_width_b_closed', 1), | |
| beam_width_open=config['eval'].get('beam_width_b_open', config['eval'].get('beam_width_b', 1)), | |
| max_new_tokens_closed=config['eval'].get('max_new_tokens_b_closed', 4), | |
| max_new_tokens_open=config['eval'].get('max_new_tokens_b_open', answer_max_words + 6), | |
| generation_batch_size=config['eval'].get('generation_batch_size_b', 1), | |
| max_words=answer_max_words, | |
| variant='B2' | |
| ) | |
| closed_eval = metrics.get('closed_eval', {}) | |
| open_eval = metrics.get('open_eval', {}) | |
| print(f"\n[RESULT B2 - CLOSED QUESTIONS]") | |
| print(f"Count: {closed_eval.get('count', 0)}") | |
| print(f"Accuracy: {closed_eval.get('accuracy', 0):.4f}") | |
| print(f"EM: {closed_eval.get('em', 0):.4f}") | |
| print(f"F1: {closed_eval.get('f1', 0):.4f}") | |
| print(f"\n[RESULT B2 - OPEN QUESTIONS]") | |
| print(f"Count: {open_eval.get('count', 0)}") | |
| print(f"Semantic: {open_eval.get('semantic', 0):.4f}") | |
| print(f"BERTScore: {open_eval.get('bert_score', 0):.4f}") | |
| print(f"F1: {open_eval.get('f1', 0):.4f}") | |
| print(f"ROUGE-L: {open_eval.get('rouge_l', 0):.4f}") | |
| if 'long_answers_eval' in metrics: | |
| print(f"\n[RESULT B2 - LONG METRICS]") | |
| print(f"Accuracy: {metrics['long_answers_eval'].get('accuracy', 0):.4f}") | |
| print(f"F1: {metrics['long_answers_eval'].get('f1', 0):.4f}") | |
| print(f"Semantic: {metrics['long_answers_eval'].get('semantic', 0):.4f}") | |
| print(f"BERTScore: {metrics['long_answers_eval'].get('bert_score', 0):.4f}") | |
| # Gắn thêm vào log_history cho wandb | |
| trainer.state.log_history.append({ | |
| "epoch": training_args.num_train_epochs, | |
| "val_long_accuracy": metrics['long_answers_eval'].get('accuracy', 0), | |
| "val_long_f1": metrics['long_answers_eval'].get('f1', 0), | |
| "val_long_semantic": metrics['long_answers_eval'].get('semantic', 0), | |
| "val_long_bertscore": metrics['long_answers_eval'].get('bert_score', 0), | |
| }) | |
| # Gắn kết quả vào history để compare_models.py đọc được | |
| final_epoch = training_args.num_train_epochs | |
| trainer.state.log_history.append({ | |
| "epoch": final_epoch, | |
| "val_accuracy_normalized": metrics.get('accuracy_normalized'), | |
| "val_f1_normalized": metrics.get('f1_normalized'), | |
| "val_bleu4_normalized": metrics.get('bleu4_normalized'), | |
| "val_bert_score_raw": metrics.get('bert_score_raw'), | |
| "val_semantic_raw": metrics.get('semantic_raw'), | |
| "val_closed_accuracy": closed_eval.get('accuracy', 0), | |
| "val_closed_em": closed_eval.get('em', 0), | |
| "val_closed_f1": closed_eval.get('f1', 0), | |
| "val_open_semantic": open_eval.get('semantic', 0), | |
| "val_open_bertscore": open_eval.get('bert_score', 0), | |
| "val_open_f1": open_eval.get('f1', 0), | |
| "val_open_rouge_l": open_eval.get('rouge_l', 0), | |
| }) | |
| save_history_records(history_dir, trainer.state.log_history) | |
| return | |
| elif args.variant == 'B1': | |
| # Zero-shot Evaluation cho Hướng B | |
| from src.engine.medical_eval import evaluate_multimodal_vqa | |
| wrapper = MultimodalVQA(model_id=config['model_b']['model_name']) | |
| model, processor = wrapper.load_model() | |
| beam_width = config['eval'].get('beam_width_b', 1) | |
| print(f"[INFO] Bắt đầu đánh giá B1 với Beam Width = {beam_width}...") | |
| metrics = evaluate_multimodal_vqa( | |
| model, | |
| val_loader, | |
| device, | |
| processor, | |
| beam_width=beam_width, | |
| beam_width_closed=config['eval'].get('beam_width_b_closed', beam_width), | |
| beam_width_open=config['eval'].get('beam_width_b_open', beam_width), | |
| max_new_tokens_closed=config['eval'].get('max_new_tokens_b_closed', 4), | |
| max_new_tokens_open=config['eval'].get('max_new_tokens_b_open', answer_max_words + 6), | |
| generation_batch_size=config['eval'].get('generation_batch_size_b', 1), | |
| max_words=answer_max_words, | |
| variant='B1' | |
| ) | |
| closed_eval = metrics.get('closed_eval', {}) | |
| open_eval = metrics.get('open_eval', {}) | |
| print(f"\n[RESULT B1 - CLOSED QUESTIONS]") | |
| print(f"Count: {closed_eval.get('count', 0)}") | |
| print(f"Accuracy: {closed_eval.get('accuracy', 0):.4f}") | |
| print(f"EM: {closed_eval.get('em', 0):.4f}") | |
| print(f"F1: {closed_eval.get('f1', 0):.4f}") | |
| print(f"\n[RESULT B1 - OPEN QUESTIONS]") | |
| print(f"Count: {open_eval.get('count', 0)}") | |
| print(f"Semantic: {open_eval.get('semantic', 0):.4f}") | |
| print(f"BERTScore: {open_eval.get('bert_score', 0):.4f}") | |
| print(f"F1: {open_eval.get('f1', 0):.4f}") | |
| print(f"ROUGE-L: {open_eval.get('rouge_l', 0):.4f}") | |
| if 'long_answers_eval' in metrics: | |
| print(f"\n[RESULT B1 - LONG METRICS]") | |
| print(f"Accuracy: {metrics['long_answers_eval'].get('accuracy', 0):.4f}") | |
| print(f"F1: {metrics['long_answers_eval'].get('f1', 0):.4f}") | |
| print(f"Semantic: {metrics['long_answers_eval'].get('semantic', 0):.4f}") | |
| print(f"BERTScore: {metrics['long_answers_eval'].get('bert_score', 0):.4f}") | |
| # [FIX] Lưu dưới dạng record có 'epoch' để compare_models.py có thể parse | |
| save_history_records(history_dir, [{ | |
| "epoch": 1, | |
| "variant": "B1", | |
| "beam_width": beam_width, | |
| "train_loss": 0.0, # zero-shot không có train loss | |
| "val_accuracy_normalized": float(metrics.get('accuracy_normalized', metrics.get('accuracy', 0))), | |
| "val_f1_normalized": float(metrics.get('f1_normalized', metrics.get('f1', 0))), | |
| "val_bleu4_normalized": float(metrics.get('bleu4_normalized', metrics.get('bleu4', 0))), | |
| "val_bert_score_raw": float(metrics.get('bert_score_raw', metrics.get('bert_score', 0))), | |
| "val_semantic_raw": float(metrics.get('semantic_raw', metrics.get('semantic', 0))), | |
| "val_closed_accuracy": float(closed_eval.get('accuracy', 0)), | |
| "val_closed_em": float(closed_eval.get('em', 0)), | |
| "val_closed_f1": float(closed_eval.get('f1', 0)), | |
| "val_open_semantic": float(open_eval.get('semantic', 0)), | |
| "val_open_bertscore": float(open_eval.get('bert_score', 0)), | |
| "val_open_f1": float(open_eval.get('f1', 0)), | |
| "val_open_rouge_l": float(open_eval.get('rouge_l', 0)), | |
| "metrics": metrics, | |
| }]) | |
| return | |
def run_model_comparison(log_dir, out_dir, script="scripts/compare_models.py"):
    """Run the comparison-chart script in a child interpreter.

    This is a best-effort step: a failure is reported on stdout and never raised.

    Args:
        log_dir: Directory passed to the script as ``--log_dir``.
        out_dir: Directory passed to the script as ``--out``.
        script: Path of the script to execute.

    Returns:
        ``True`` when the script exited with status 0, ``False`` when it could
        not be launched or exited with a non-zero status.
    """
    import subprocess
    import sys

    try:
        result = subprocess.run(
            [sys.executable, script, "--log_dir", log_dir, "--out", out_dir],
            check=False,
        )
    except Exception as e:
        print(f"[WARNING] compare_models.py thất bại: {e}")
        print(" Chạy thủ công: python scripts/compare_models.py")
        return False

    # check=False never raises on a bad exit status, so inspect it explicitly;
    # otherwise a crashing script would go unreported.
    if result.returncode != 0:
        print(f"[WARNING] compare_models.py thất bại: exit code {result.returncode}")
        print(" Chạy thủ công: python scripts/compare_models.py")
        return False
    return True


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/medical_vqa.yaml")
    parser.add_argument("--variant", type=str, choices=['A1', 'A2', 'B1', 'B2', 'DPO', 'PPO'], required=True)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--no_compare", action="store_true",
                        help="Bỏ qua vẽ chart so sánh 5 model sau khi train xong")
    args = parser.parse_args()
    train(args)

    # Auto-generate comparison charts after training.
    if not args.no_compare:
        # NOTE(review): this path is hard-coded, while train() reads earlier
        # histories from config['log_dir']/history; confirm the two agree.
        log_dir = "logs/medical_vqa/history"
        out_dir = "results/charts"
        print(f"\n[INFO] 📊 Tự động vẽ biểu đồ so sánh 5 model → {out_dir}/")
        run_model_comparison(log_dir, out_dir)