"""Export predictions of several medical-VQA model variants for comparison.

For every requested variant (A1/A2 custom models, B1 zero-shot LLaVA,
B2/DPO/PPO LoRA adapters) the script runs inference on one dataset split and
writes:

* ``<variant>_<split>_predictions.json`` - one row per sample,
* ``compare_<split>_<variants>.json``    - all variants side by side,
* ``summary_<split>_<variants>.json``    - checkpoint used per variant,
* ``compare_<split>_<variants>.html``    - browsable page with image previews.
"""

import argparse
import html
import json
import re
from pathlib import Path

import torch
import yaml
from datasets import load_dataset
from peft import PeftModel
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, LlavaForConditionalGeneration, LlavaProcessor

from src.data.medical_dataset import MedicalVQADataset
from src.models.medical_vqa_model import MedicalVQAModelA
from src.models.multimodal_vqa import MultimodalVQA
from src.utils.text_utils import normalize_answer, postprocess_answer
from src.utils.translator import MedicalTranslator
from src.utils.visualization import MedicalImageTransform as MedicalTransform


# Batch fields that hold tensors and are stacked along a new batch dimension;
# every other field is collected into a plain Python list.
_STACKED_BATCH_KEYS = frozenset({
    "image",
    "input_ids",
    "attention_mask",
    "label_closed",
    "target_ids",
    "chosen_ids",
    "rejected_ids",
})


def vqa_collate_fn(batch):
    """Collate dataset items into one batch dict.

    Keys listed in ``_STACKED_BATCH_KEYS`` are combined with ``torch.stack``;
    all remaining keys become lists with one entry per item. The key set is
    taken from the first item of ``batch``.
    """
    collated = {}
    for key in batch[0]:
        values = [item[key] for item in batch]
        collated[key] = torch.stack(values) if key in _STACKED_BATCH_KEYS else values
    return collated


def normalize_for_metric(text: str) -> str:
    """Return ``text`` as a stripped, lower-cased string."""
    return str(text).strip().lower()


def _term_regex(terms, *, whole_word: bool):
    """Compile an alternation of literal ``terms`` that must start at a word boundary.

    With ``whole_word=True`` the match must also end at a word boundary, so
    ``"no"`` does not fire inside ``"nodule"``. With ``whole_word=False`` only
    the left edge is anchored, so ``"no abnormal"`` also covers
    ``"no abnormalities"`` while ``"normal"`` no longer fires inside
    ``"abnormal"``.
    """
    # Longest alternatives first so multi-word phrases win over their prefixes.
    ordered = sorted(terms, key=len, reverse=True)
    body = "|".join(re.escape(term) for term in ordered)
    tail = r"(?!\w)" if whole_word else ""
    return re.compile(rf"(?<!\w)(?:{body}){tail}")


# Substrings that mark a question as asking about normality.
_NORMALITY_QUESTION_TERMS = ("bình thường", "normal", "abnormal", "bat thuong")

_NOT_NORMAL_RE = _term_regex(
    ["không bình thường", "bất bình thường", "not normal"],
    whole_word=False,
)
_YES_WORD_RE = _term_regex(["có", "yes"], whole_word=True)
_NORMAL_RE = _term_regex(
    [
        "bình thường",
        "normal",
        "no significant abnormalities",
        "no significant abnormal",
        "no abnormality",
        "no abnormal",
        "unremarkable",
        "appears to be normal",
        "without significant abnormalities",
        "without significant abnormal",
        "without abnormal",
        "không phát hiện bất thường",
        "không có bất thường",
    ],
    whole_word=False,
)
# Findings that imply an abnormal study; matched as plain substrings so that
# plural and derived forms ("fractures", "abnormality") are still caught.
_ABNORMAL_TERMS = (
    "bất thường",
    "abnormal",
    "abnormality detected",
    "fracture",
    "lesion",
    "mass",
    "effusion",
    "pneumothorax",
)
_NEGATIVE_RE = _term_regex(
    ["không", "no", "not", "absent", "not seen", "negative", "none"],
    whole_word=True,
)
_POSITIVE_RE = _term_regex(
    ["có", "yes", "present", "detected", "positive"],
    whole_word=True,
)


def _normalize_closed_answer(question_vi: str, question_en: str, pred_vi: str, pred_en: str = "") -> str:
    """Map a free-form prediction for a closed question to ``"có"`` / ``"không"``.

    Questions and predictions are first passed through ``normalize_answer``.
    For normality questions ``"có"`` means "the study is normal". When no rule
    fires, the normalized Vietnamese prediction is returned (or the English one
    if the Vietnamese one is empty).

    Matching is word-boundary aware: earlier substring checks made ``"no"``
    fire inside ``"pneumonia"`` / ``"nodule"`` and ``"normal"`` fire inside
    ``"abnormal"``, which flipped the answers.
    """
    question_vi_norm = normalize_answer(question_vi)
    question_en_norm = normalize_answer(question_en)
    pred_vi_norm = normalize_answer(pred_vi)
    pred_en_norm = normalize_answer(pred_en)

    combined = " ".join(part for part in (pred_vi_norm, pred_en_norm) if part).strip()
    question_text = " ".join([question_vi_norm, question_en_norm])
    is_normality_question = any(term in question_text for term in _NORMALITY_QUESTION_TERMS)

    if is_normality_question:
        # Explicit negations of "normal" must win over the plain "normal" rule.
        if _NOT_NORMAL_RE.search(combined):
            return "không"
        if _YES_WORD_RE.search(combined):
            return "có"
        if _NORMAL_RE.search(combined):
            return "có"
        if any(term in combined for term in _ABNORMAL_TERMS):
            return "không"
    else:
        # Negatives are checked first so "not present" resolves to "không".
        if _NEGATIVE_RE.search(combined):
            return "không"
        if _POSITIVE_RE.search(combined):
            return "có"

    return pred_vi_norm or pred_en_norm


_B1_FEW_SHOT = (
    "Q: Is there cardiomegaly? A: yes\n"
    "Q: What organ is shown? A: lung\n"
    "Q: Is the aorta normal? A: no\n"
    "Q: What abnormality is present? A: pleural effusion\n"
)


def _build_b1_prompt(question_en: str, max_words: int) -> str:
    """Build the zero-shot few-shot prompt used for the B1 (base LLaVA) variant.

    The ``<image>`` placeholder is required: the LLaVA processor replaces it
    with the image tokens, and without it the pixel values cannot be matched
    to the prompt.
    """
    return (
        "USER: <image>\n"
        f"Answer each question with medical terminology only, "
        f"no more than {max_words} words, no full sentences.\n"
        f"{_B1_FEW_SHOT}"
        f"Q: {question_en} A: ASSISTANT:"
    )


# Direct English -> Vietnamese lookup for short, frequent answers; used before
# falling back to the (slower) machine translator.
_EN_VI_DIRECT = {
    "yes": "có",
    "no": "không",
    "present": "có",
    "absent": "không",
    "normal": "bình thường",
    "abnormal": "bất thường",
    "true": "có",
    "false": "không",
    "positive": "có",
    "negative": "không",
    "lung": "phổi",
    "lungs": "phổi",
    "heart": "tim",
    "liver": "gan",
    "spleen": "lách",
    "kidney": "thận",
    "brain": "não",
    "bladder": "bàng quang",
    "chest": "ngực",
    "abdomen": "bụng",
    "pelvis": "xương chậu",
    "spine": "cột sống",
    "rib": "xương sườn",
    "ribs": "xương sườn",
    "trachea": "khí quản",
    "aorta": "động mạch chủ",
    "diaphragm": "cơ hoành",
    "mediastinum": "trung thất",
    "chest x-ray": "x-quang ngực",
    "x-ray": "x-quang",
    "xray": "x-quang",
    "mri": "mri",
    "ct": "ct",
    "ultrasound": "siêu âm",
    "ct scan": "ct",
    "mri scan": "mri",
    "axial": "mặt phẳng ngang",
    "coronal": "mặt phẳng vành",
    "sagittal": "mặt phẳng dọc",
    "transverse": "mặt phẳng ngang",
    "cardiomegaly": "tim to",
    "pneumonia": "viêm phổi",
    "pleural effusion": "tràn dịch màng phổi",
    "pneumothorax": "tràn khí màng phổi",
    "fracture": "gãy xương",
    "edema": "phù nề",
    "pulmonary edema": "phù phổi",
    "consolidation": "đông đặc",
    "atelectasis": "xẹp phổi",
    "opacity": "mờ đục",
    "mass": "khối u",
    "nodule": "nốt",
    "lesion": "tổn thương",
    "tumor": "khối u",
    "effusion": "tràn dịch",
    "infiltrate": "thâm nhiễm",
    "fibrosis": "xơ hóa",
    "calcification": "vôi hóa",
    "carcinoma": "ung thư",
    "metastasis": "di căn",
    "bilateral": "hai bên",
    "unilateral": "một bên",
    "left": "trái",
    "right": "phải",
    "upper": "trên",
    "lower": "dưới",
    "upper left": "phía trên bên trái",
    "upper right": "phía trên bên phải",
    "lower left": "phía dưới bên trái",
    "lower right": "phía dưới bên phải",
}

# Sentence openers stripped from generated English answers, applied in order.
# Compiled once at import time instead of on every call.
_ANSWER_PREFIX_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        r"^the (image|scan|x-ray|xray|mri|ct|picture|photo|radiograph) (shows?|depicts?|demonstrates?|reveals?|indicates?|presents?)\s+",
        r"^based on the (image|scan|x-ray|mri|ct)\s*,?\s*",
        r"^in (this|the) (image|scan|x-ray|mri|ct)\s*,?\s*",
        r"^i (can see|observe|notice|see)\s+",
        r"^there (is|are)\s+(a |an |some )?",
        r"^(it |this )(shows?|is|appears?|looks?)\s+(like\s+)?",
        r"^the (patient|subject)\s+(has|shows?|presents?)\s+",
        r"^(a|an|the)\s+",
    )
)
_TRAILING_PUNCT_RE = re.compile(r"[.!?,;:]+$")
_WHITESPACE_RE = re.compile(r"\s+")


def _extract_key_medical_term(raw_en: str, max_words: int) -> str:
    """Reduce a generated English sentence to its leading key phrase.

    Lower-cases the text, removes common sentence openers and trailing
    punctuation, collapses whitespace and keeps at most ``max_words`` words.
    If nothing is left, the stripped original text is returned.
    """
    text = raw_en.strip().lower()
    for pattern in _ANSWER_PREFIX_PATTERNS:
        text = pattern.sub("", text)
    text = _TRAILING_PUNCT_RE.sub("", text).strip()
    text = _WHITESPACE_RE.sub(" ", text).strip()

    words = text.split()
    if not words:
        return raw_en.strip()
    return " ".join(words[:max_words])


def _en_to_vi_direct(en_text: str):
    """Return the dictionary translation of ``en_text`` or ``None`` if unknown."""
    if en_text is None:
        return None
    return _EN_VI_DIRECT.get(en_text.strip().lower())


def predict_direction_a(model, dataloader, device, tokenizer, beam_width=1, max_len=32, max_words=10):
    """Run a direction-A model (A1/A2) over ``dataloader`` and collect result rows.

    For samples whose ``label_closed`` is not ``-1`` the answer comes from the
    closed-question classifier head (index 0 -> "không", 1 -> "có"); for all
    other samples the decoded generation is used.

    Returns:
        A list of dicts with the keys ``ground_truth``, ``ground_truth_en``,
        ``predicted``, ``predicted_raw``, ``predicted_display`` and
        ``predicted_en`` (always empty for direction A).
    """
    model.eval()
    closed_map = {0: "không", 1: "có"}
    rows = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Predicting A"):
            images = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            label_values = batch["label_closed"].tolist()

            logits_closed, pred_ids = model.inference(
                images,
                input_ids,
                attention_mask,
                beam_width=beam_width,
                max_len=max_len,
            )
            decoded = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            preds_text_raw = [postprocess_answer(text, max_words=max_words) for text in decoded]
            closed_choice = torch.argmax(logits_closed, dim=-1).tolist()

            preds_text = []
            for raw_text, label, choice in zip(preds_text_raw, label_values, closed_choice):
                answer = closed_map[choice] if label != -1 else raw_text
                preds_text.append(postprocess_answer(answer, max_words=max_words))

            refs_vi = batch["raw_answer"]
            refs_en = batch.get("raw_answer_en") or []
            for i, (pred, pred_raw) in enumerate(zip(preds_text, preds_text_raw)):
                rows.append({
                    "ground_truth": normalize_for_metric(postprocess_answer(refs_vi[i], max_words=max_words)),
                    "ground_truth_en": normalize_for_metric(refs_en[i] if i < len(refs_en) else ""),
                    "predicted": normalize_for_metric(pred),
                    "predicted_raw": normalize_for_metric(pred_raw),
                    "predicted_display": normalize_for_metric(pred_raw),
                    "predicted_en": "",
                })

    return rows


def predict_direction_b(
    model,
    dataloader,
    device,
    processor,
    variant="B1",
    beam_width=1,
    beam_width_closed=1,
    beam_width_open=1,
    max_new_tokens_closed=4,
    max_new_tokens_open=16,
    generation_batch_size=1,
    max_words=10,
):
    """Run a LLaVA-based variant (B1/B2/DPO/PPO) and collect result rows.

    * ``B1`` prompts the base model in English (translating the question when
      no English version is available), extracts the key term from the answer
      and maps it back to Vietnamese via the lookup table or the translator.
    * Every other variant is prompted in Vietnamese through
      ``MultimodalVQA.build_instruction_prompt``; closed and open questions
      are generated separately with their own beam width / token budget.

    ``beam_width`` is accepted for backward compatibility but not used; the
    closed/open specific widths control generation.

    Returns:
        A list of dicts with the same keys as ``predict_direction_a``.
    """
    model.eval()
    rows = []
    # Only the B1 path translates, so avoid building the translator otherwise.
    translator = MedicalTranslator(device=device.type) if variant == "B1" else None
    wrapper = MultimodalVQA()

    def _run_generation(raw_images, prompts, sample_indices, num_beams, max_new_tokens):
        """Generate answers for ``sample_indices`` in chunks; returns decoded strings in order."""
        if not sample_indices:
            return []
        decoded_outputs = []
        chunk_size = generation_batch_size if num_beams > 1 else max(generation_batch_size, 2)
        chunk_size = max(1, chunk_size)  # a zero step would make range() raise
        for start in range(0, len(sample_indices), chunk_size):
            chunk_indices = sample_indices[start:start + chunk_size]
            text_subset = [prompts[i] for i in chunk_indices]
            image_subset = [raw_images[i] for i in chunk_indices]
            inputs = processor(
                text=text_subset,
                images=image_subset,
                return_tensors="pt",
                padding=True,
            ).to(device)
            if "pixel_values" in inputs:
                inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                num_beams=num_beams,
                early_stopping=num_beams > 1,
            )
            # Prompts share one padded length, so everything after it is the
            # newly generated continuation.
            input_token_len = inputs.input_ids.shape[1]
            decoded_outputs.extend(
                processor.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
            )
            del inputs, output_ids
            if device.type == "cuda":
                torch.cuda.empty_cache()
        return decoded_outputs

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Predicting {variant}"):
            raw_images = batch["raw_image"]
            questions_vi = batch.get("raw_questions", [])
            questions_en = batch.get("raw_questions_en", [])
            refs_vi_raw = batch.get("raw_answer", [])
            refs_en_raw = batch.get("raw_answer_en", [])
            label_values = batch["label_closed"].tolist()

            if variant == "B1":
                if not questions_en or any(not str(q).strip() for q in questions_en):
                    questions_en = translator.translate_vi2en(questions_vi)
                    if isinstance(questions_en, str):
                        questions_en = [questions_en]
                prompts = [_build_b1_prompt(q, max_words) for q in questions_en]
            else:
                prompts = [
                    wrapper.build_instruction_prompt(q, language="vi", include_answer=False)
                    for q in questions_vi
                ]

            preds_raw = [""] * len(prompts)
            closed_idx = [i for i, lbl in enumerate(label_values) if lbl != -1]
            open_idx = [i for i, lbl in enumerate(label_values) if lbl == -1]

            if variant == "B1":
                preds_raw = _run_generation(
                    raw_images, prompts, list(range(len(prompts))), beam_width_open, max_new_tokens_open
                )
            else:
                closed_out = _run_generation(
                    raw_images, prompts, closed_idx, beam_width_closed, max_new_tokens_closed
                )
                for idx, pred in zip(closed_idx, closed_out):
                    preds_raw[idx] = pred
                open_out = _run_generation(
                    raw_images, prompts, open_idx, beam_width_open, max_new_tokens_open
                )
                for idx, pred in zip(open_idx, open_out):
                    preds_raw[idx] = pred

            preds_vi = []
            if variant == "B1":
                preds_en_clean = [_extract_key_medical_term(p, 50) for p in preds_raw]
                needs_translate_idx = []
                needs_translate_txt = []
                for i, pred_en in enumerate(preds_en_clean):
                    if label_values[i] != -1:
                        preds_vi.append(
                            _normalize_closed_answer(questions_vi[i], questions_en[i], pred_en, pred_en)
                        )
                        continue
                    vi_direct = _en_to_vi_direct(pred_en)
                    if vi_direct is not None:
                        preds_vi.append(postprocess_answer(vi_direct, max_words=max_words))
                    else:
                        # Placeholder, filled in after the batched translation below.
                        preds_vi.append(None)
                        needs_translate_idx.append(i)
                        needs_translate_txt.append(pred_en)

                if needs_translate_txt:
                    translated = translator.translate_en2vi(needs_translate_txt)
                    if isinstance(translated, str):
                        translated = [translated]
                    for idx, vi in zip(needs_translate_idx, translated):
                        preds_vi[idx] = postprocess_answer(vi, max_words=max_words)
                preds_vi_display = list(preds_vi)
            else:
                preds_vi_display = [
                    postprocess_answer(p, max_words=max_words) if p else "" for p in preds_raw
                ]
                for i, pred_vi in enumerate(preds_raw):
                    if label_values[i] != -1:
                        question_en = questions_en[i] if i < len(questions_en) else ""
                        preds_vi.append(_normalize_closed_answer(questions_vi[i], question_en, pred_vi))
                    else:
                        preds_vi.append(pred_vi)
                preds_en_clean = [""] * len(preds_raw)

            preds_vi = [postprocess_answer(p, max_words=max_words) if p else "" for p in preds_vi]
            preds_vi_display = [
                postprocess_answer(p, max_words=max_words) if p else "" for p in preds_vi_display
            ]
            preds_vi_raw = list(preds_vi_display)
            refs_vi = [postprocess_answer(r, max_words=max_words) for r in refs_vi_raw]
            refs_en = [postprocess_answer(r, max_words=max_words) if r else "" for r in refs_en_raw]

            for i, pred in enumerate(preds_vi):
                rows.append({
                    # References may be missing from the batch; fall back to ""
                    # instead of raising IndexError (direction A does the same).
                    "ground_truth": normalize_for_metric(refs_vi[i] if i < len(refs_vi) else ""),
                    "ground_truth_en": normalize_for_metric(refs_en[i] if i < len(refs_en) else ""),
                    "predicted": normalize_for_metric(pred),
                    "predicted_raw": normalize_for_metric(preds_vi_raw[i]),
                    "predicted_display": normalize_for_metric(preds_vi_display[i]),
                    "predicted_en": normalize_for_metric(
                        preds_en_clean[i] if i < len(preds_en_clean) else ""
                    ),
                })

    return rows


def _checkpoint_step(checkpoint_dir: Path) -> int:
    """Return the numeric step of a ``checkpoint-<step>`` directory, or -1 if unparsable."""
    match = re.fullmatch(r"checkpoint-(\d+)", checkpoint_dir.name)
    return int(match.group(1)) if match else -1


def select_best_adapter_checkpoint(checkpoint_root: str):
    """Pick the adapter checkpoint directory to evaluate under ``checkpoint_root``.

    The newest readable ``trainer_state.json`` that names a
    ``best_model_checkpoint`` containing ``adapter_config.json`` wins;
    otherwise the checkpoint with the highest step number is returned.

    Checkpoints are ordered by their numeric step. A plain lexicographic sort
    would place ``checkpoint-900`` after ``checkpoint-1000`` and select a stale
    checkpoint.

    Raises:
        FileNotFoundError: if the root does not exist or holds no adapter
            checkpoint.
    """
    checkpoint_root = Path(checkpoint_root)
    if not checkpoint_root.exists():
        raise FileNotFoundError(f"Không tìm thấy thư mục checkpoint: {checkpoint_root}")

    checkpoint_dirs = sorted(
        (p for p in checkpoint_root.glob("checkpoint-*") if (p / "adapter_config.json").exists()),
        key=lambda p: (_checkpoint_step(p), p.name),
    )
    if not checkpoint_dirs:
        raise FileNotFoundError(f"Không có adapter checkpoint trong {checkpoint_root}")

    state_files = sorted(
        checkpoint_root.glob("checkpoint-*/trainer_state.json"),
        key=lambda f: (_checkpoint_step(f.parent), f.parent.name),
        reverse=True,
    )
    for state_file in state_files:
        try:
            state = json.loads(state_file.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Unreadable or corrupt state file: try the next older one.
            continue
        best_path = state.get("best_model_checkpoint") if isinstance(state, dict) else None
        if not best_path:
            continue

        # Path() already normalizes a leading "./"; the previous
        # str.replace("./", "") also mangled "../" segments.
        recorded = Path(best_path)
        candidates = [
            recorded if recorded.is_absolute() else Path.cwd() / recorded,
            # Fallback for runs whose output directory was moved after training.
            checkpoint_root / recorded.name,
        ]
        for candidate in candidates:
            if (candidate / "adapter_config.json").exists():
                return candidate.resolve()

    return checkpoint_dirs[-1].resolve()


def load_config(config_path: str):
    """Load the YAML configuration file at ``config_path``."""
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def build_dataset_and_loader(config, split: str, tokenizer):
    """Load ``split`` from the configured Hugging Face dataset and wrap it in a DataLoader.

    Returns:
        ``(hf_split, loader)`` - the raw Hugging Face split and a non-shuffled
        ``DataLoader`` over the corresponding ``MedicalVQADataset``.

    Raises:
        ValueError: if no ``hf_dataset`` is configured or the split is missing.
    """
    data_cfg = config["data"]
    hf_repo = data_cfg.get("hf_dataset")
    if not hf_repo:
        raise ValueError("Script này hiện yêu cầu dataset từ Hugging Face Hub.")

    dataset_dict = load_dataset(hf_repo)
    if split not in dataset_dict:
        raise ValueError(
            f"Dataset không có split '{split}'. Các split hiện có: {list(dataset_dict.keys())}"
        )

    dataset = MedicalVQADataset(
        hf_dataset=dataset_dict[split],
        tokenizer=tokenizer,
        transform=MedicalTransform(size=data_cfg["image_size"]),
        max_seq_len=data_cfg["max_question_len"],
        max_ans_len=data_cfg["max_answer_len"],
        answer_max_words=int(data_cfg.get("answer_max_words", 10)),
    )
    loader = DataLoader(
        dataset,
        batch_size=int(config["train"].get("eval_batch_size", 8)),
        shuffle=False,  # keep dataset order so rows line up with hf_split
        collate_fn=vqa_collate_fn,
    )
    return dataset_dict[split], loader


def load_direction_a_model(variant: str, config, tokenizer, device):
    """Build the A1/A2 model and load its weights from ``checkpoints/``.

    The ``*_best.pth`` checkpoint is preferred, ``*_resume.pth`` is the
    fallback. ``A1`` uses the LSTM decoder, any other variant the transformer
    decoder.

    Returns:
        ``(model, checkpoint_path_as_str)`` with the model in eval mode.

    Raises:
        FileNotFoundError: if neither checkpoint file exists.
    """
    candidates = (
        Path(f"checkpoints/medical_vqa_{variant}_best.pth"),
        Path(f"checkpoints/medical_vqa_{variant}_resume.pth"),
    )
    ckpt_path = next((path for path in candidates if path.exists()), None)
    if ckpt_path is None:
        raise FileNotFoundError(f"Không tìm thấy checkpoint cho {variant}")

    decoder_type = "lstm" if variant == "A1" else "transformer"
    model = MedicalVQAModelA(
        decoder_type=decoder_type,
        vocab_size=len(tokenizer),
        hidden_size=config["model_a"].get("hidden_size", 768),
        phobert_model=config["model_a"].get("phobert_model", "vinai/phobert-base"),
    ).to(device)

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    payload = torch.load(ckpt_path, map_location=device)
    if isinstance(payload, dict) and "model_state_dict" in payload:
        state_dict = payload.get("model_state_dict")
    else:
        state_dict = payload

    # strict=False tolerates architecture drift, but a silent mismatch would
    # leave layers randomly initialised - report it.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing = list(getattr(load_result, "missing_keys", None) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
    if missing or unexpected:
        print(
            f"[WARN] Checkpoint {ckpt_path} không khớp hoàn toàn với mô hình: "
            f"thiếu {len(missing)} key, thừa {len(unexpected)} key."
        )

    model.eval()
    return model, str(ckpt_path)


def build_llava_base_and_processor(config):
    """Create the ``MultimodalVQA`` wrapper, the LLaVA processor and the quantized base model.

    The tokenizer is switched to left padding so that generated tokens start
    at the same position for every sample in a batch.

    Returns:
        ``(wrapper, processor, base_model)``.
    """
    model_cfg = config["model_b"]
    wrapper = MultimodalVQA(
        model_id=model_cfg["model_name"],
        lora_r=int(model_cfg.get("lora_r", 16)),
        lora_alpha=int(model_cfg.get("lora_alpha", 32)),
        lora_dropout=float(model_cfg.get("lora_dropout", 0.05)),
        lora_target_modules=model_cfg.get("lora_target_modules"),
    )
    processor = LlavaProcessor.from_pretrained(wrapper.model_id)
    processor.tokenizer.padding_side = "left"
    base_model = LlavaForConditionalGeneration.from_pretrained(
        wrapper.model_id,
        quantization_config=wrapper.bnb_config,
        device_map="auto",
    )
    base_model.config.use_cache = False
    return wrapper, processor, base_model


# Adapter locations of the variants whose checkpoint directory is fixed.
_FIXED_ADAPTER_DIRS = {
    "DPO": Path("checkpoints/DPO/final_adapter"),
    "PPO": Path("checkpoints/PPO/final_adapter"),
}


def load_direction_b_model(variant: str, config):
    """Load the LLaVA model for ``variant`` (B1 base model, or base + LoRA adapter).

    The variant and its adapter directory are validated before the base model
    is loaded, so a typo or a missing checkpoint fails fast instead of after
    loading a multi-gigabyte model.

    Returns:
        ``(model, processor, checkpoint)`` where ``checkpoint`` is the model
        name for B1 and the adapter directory otherwise.

    Raises:
        ValueError: for an unsupported variant.
        FileNotFoundError: if the adapter directory cannot be found.
    """
    if variant == "B1":
        ckpt_dir = None
    elif variant == "B2":
        ckpt_dir = select_best_adapter_checkpoint(
            config["train"].get("b2_output_dir", "./checkpoints/B2")
        )
    elif variant in _FIXED_ADAPTER_DIRS:
        ckpt_dir = _FIXED_ADAPTER_DIRS[variant]
        if not ckpt_dir.exists():
            raise FileNotFoundError(f"Không tìm thấy adapter checkpoint: {ckpt_dir}")
    else:
        raise ValueError(f"Variant không hỗ trợ trong script này: {variant}")

    _wrapper, processor, base_model = build_llava_base_and_processor(config)
    if ckpt_dir is None:
        model = base_model
        checkpoint = config["model_b"]["model_name"]
    else:
        model = PeftModel.from_pretrained(base_model, str(ckpt_dir), is_trainable=False)
        checkpoint = str(ckpt_dir)

    model.eval()
    return model, processor, checkpoint


_PREDICTION_FIELDS = (
    "ground_truth",
    "ground_truth_en",
    "predicted",
    "predicted_raw",
    "predicted_display",
    "predicted_en",
)


def convert_prediction_rows(hf_split, prediction_rows, variant: str, checkpoint: str):
    """Merge dataset metadata with prediction rows, one output row per dataset item.

    Items without a matching prediction row get empty strings for all
    prediction fields.
    """
    rows = []
    for idx, item in enumerate(hf_split):
        pred_row = prediction_rows[idx] if idx < len(prediction_rows) else {}
        row = {
            "idx": idx,
            "variant": variant,
            "checkpoint": checkpoint,
            "id": item.get("id"),
            "source": item.get("source"),
            "image_name": item.get("image_name"),
            "answer_type": item.get("answer_type"),
            "question": item.get("question"),
            "question_vi": item.get("question_vi"),
        }
        row.update({field: pred_row.get(field, "") for field in _PREDICTION_FIELDS})
        rows.append(row)
    return rows


def build_side_by_side(hf_split, prediction_map):
    """Build one row per dataset item holding every variant's prediction.

    ``prediction_map`` maps a variant name to its list of converted rows; the
    result contains ``<variant>_predicted`` and ``<variant>_predicted_raw``
    columns (empty when a variant has no row for that index).
    """
    combined = []
    for idx, item in enumerate(hf_split):
        row = {
            "idx": idx,
            "id": item.get("id"),
            "source": item.get("source"),
            "image_name": item.get("image_name"),
            "answer_type": item.get("answer_type"),
            "question": item.get("question"),
            "question_vi": item.get("question_vi"),
            "ground_truth": item.get("answer_vi"),
            "ground_truth_full_vi": item.get("answer_full_vi"),
        }
        for variant, preds in prediction_map.items():
            has_row = idx < len(preds)
            row[f"{variant}_predicted"] = preds[idx]["predicted"] if has_row else ""
            row[f"{variant}_predicted_raw"] = preds[idx]["predicted_raw"] if has_row else ""
        combined.append(row)
    return combined


def export_preview_images(hf_split, output_dir: Path, split: str, image_size: int = 256):
    """Save a JPEG thumbnail of every image in ``hf_split``.

    Thumbnails go to ``<output_dir>/<split>_images`` and fit inside an
    ``image_size`` x ``image_size`` box.

    Returns:
        A list of POSIX-style paths relative to ``output_dir``, one per item.
    """
    image_dir = output_dir / f"{split}_images"
    image_dir.mkdir(parents=True, exist_ok=True)

    image_refs = []
    for idx, item in enumerate(hf_split):
        image = item["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")
        preview = image.copy()
        preview.thumbnail((image_size, image_size))

        image_name = Path(str(item.get("image_name") or f"{idx}.jpg")).name or f"{idx}.jpg"
        # The file is always encoded as JPEG, so give it a matching extension
        # rather than keeping e.g. ".png" on JPEG bytes.
        image_name = Path(image_name).with_suffix(".jpg").name
        save_path = image_dir / f"{idx:04d}_{image_name}"
        preview.save(save_path, format="JPEG", quality=90)
        image_refs.append(save_path.relative_to(output_dir).as_posix())

    return image_refs


# Kept outside the f-strings below so the CSS braces need no escaping.
_COMPARE_PAGE_STYLE = """
body { font-family: sans-serif; margin: 24px; background: #f6f7f9; color: #1c1e21; }
.grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 16px; }
.card { background: #fff; border: 1px solid #d9dce1; border-radius: 8px; padding: 12px; }
.card img { max-width: 100%; height: auto; display: block; margin-bottom: 8px; }
.meta div { margin: 2px 0; }
.pred { border-top: 1px solid #e4e6ea; margin-top: 8px; padding-top: 6px; }
.variant { font-weight: bold; }
"""


def render_compare_html(compare_rows, variants, output_dir: Path, split: str):
    """Write an HTML page showing each sample's image, question and all predictions.

    Every dynamic value is HTML-escaped. The image of a row is taken from its
    ``image_preview`` entry (a path relative to ``output_dir``).

    Returns:
        The path of the written HTML file.
    """
    html_path = output_dir / f"compare_{split}_{'_'.join(variants)}.html"
    safe_split = html.escape(str(split))

    cards = []
    for row in compare_rows:
        img_src = html.escape(str(row.get("image_preview") or ""))
        question_vi = html.escape(str(row.get("question_vi", "")))
        question_en = html.escape(str(row.get("question", "")))
        answer_type = html.escape(str(row.get("answer_type", "")))
        ground_truth = html.escape(str(row.get("ground_truth", "")))
        image_name = html.escape(str(row.get("image_name", "")))
        row_idx = html.escape(str(row.get("idx", "")))

        preds_html = []
        for variant in variants:
            pred = html.escape(str(row.get(f"{variant}_predicted", "")))
            raw = html.escape(str(row.get(f"{variant}_predicted_raw", "")))
            preds_html.append(
                '<div class="pred">\n'
                f'<div class="variant">{html.escape(str(variant))}</div>\n'
                f"<div>Pred: {pred}</div>\n"
                f"<div>Raw: {raw}</div>\n"
                "</div>\n"
            )

        image_tag = f'<img src="{img_src}" alt="{image_name}" loading="lazy">\n' if img_src else ""
        cards.append(
            '<div class="card">\n'
            f"{image_tag}"
            '<div class="meta">\n'
            f"<div>Idx: {row_idx}</div>\n"
            f"<div>Image: {image_name}</div>\n"
            f"<div>Type: {answer_type}</div>\n"
            f"<div>Q (VI): {question_vi}</div>\n"
            f"<div>Q (EN): {question_en}</div>\n"
            f"<div>GT: {ground_truth}</div>\n"
            "</div>\n"
            f"{''.join(preds_html)}"
            "</div>\n"
        )

    cards_html = "".join(cards)
    models_line = html.escape(", ".join(str(v) for v in variants))
    page = (
        "<!DOCTYPE html>\n"
        '<html lang="vi">\n'
        "<head>\n"
        '<meta charset="utf-8">\n'
        f"<title>Compare Predictions - {safe_split}</title>\n"
        f"<style>{_COMPARE_PAGE_STYLE}</style>\n"
        "</head>\n"
        "<body>\n"
        f"<h1>So sánh prediction {safe_split}</h1>\n"
        f"<p>Models: {models_line}</p>\n"
        f'<div class="grid">\n{cards_html}</div>\n'
        "</body>\n"
        "</html>\n"
    )
    html_path.write_text(page, encoding="utf-8")
    return html_path


_DIRECTION_A_VARIANTS = frozenset({"A1", "A2"})
_SUPPORTED_VARIANTS = ("A1", "A2", "B1", "B2", "DPO", "PPO")


def main():
    """CLI entry point: run every requested variant and write JSON/HTML comparison files."""
    parser = argparse.ArgumentParser(description="Xuất prediction của A1/A2/B1/B2/DPO/PPO để so sánh.")
    parser.add_argument("--config", default="configs/medical_vqa.yaml")
    parser.add_argument("--split", default="test", choices=["train", "validation", "test"])
    # Validate up front so a typo is rejected before any model is loaded.
    parser.add_argument(
        "--variants",
        nargs="+",
        default=["A1", "A2", "B1", "B2"],
        choices=list(_SUPPORTED_VARIANTS),
    )
    parser.add_argument("--output-dir", default="results/predictions")
    args = parser.parse_args()

    config = load_config(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(config["model_a"]["phobert_model"])
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token

    hf_split, dataloader = build_dataset_and_loader(config, args.split, tokenizer)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    image_refs = export_preview_images(hf_split, output_dir, args.split)

    eval_cfg = config["eval"]
    max_words = int(config["data"].get("answer_max_words", 10))
    summary = {}
    prediction_map = {}

    for variant in args.variants:
        print(f"[INFO] Đang chạy prediction cho {variant} trên split '{args.split}'...")
        is_direction_a = variant in _DIRECTION_A_VARIANTS
        if is_direction_a:
            model, checkpoint = load_direction_a_model(variant, config, tokenizer, device)
            prediction_rows = predict_direction_a(
                model,
                dataloader,
                device,
                tokenizer,
                beam_width=int(eval_cfg.get("beam_width_a", 5)),
                max_len=int(config["data"].get("max_answer_len", 20)),
                max_words=max_words,
            )
        else:
            model, processor, checkpoint = load_direction_b_model(variant, config)
            prediction_rows = predict_direction_b(
                model,
                dataloader,
                device,
                processor,
                beam_width=int(eval_cfg.get("beam_width_b", 5)),
                beam_width_closed=int(eval_cfg.get("beam_width_b_closed", 1)),
                beam_width_open=int(eval_cfg.get("beam_width_b_open", eval_cfg.get("beam_width_b", 5))),
                max_new_tokens_closed=int(eval_cfg.get("max_new_tokens_b_closed", 4)),
                max_new_tokens_open=int(eval_cfg.get("max_new_tokens_b_open", max_words + 6)),
                generation_batch_size=int(eval_cfg.get("generation_batch_size_b", 1)),
                max_words=max_words,
                variant=variant,
            )

        rows = convert_prediction_rows(hf_split, prediction_rows, variant, checkpoint)
        prediction_map[variant] = rows
        out_path = output_dir / f"{variant}_{args.split}_predictions.json"
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(rows, f, ensure_ascii=False, indent=2)
        summary[variant] = {
            "checkpoint": checkpoint,
            "num_predictions": len(rows),
        }
        print(f"[SUCCESS] Đã lưu {out_path}")

        # Drop the references before the next variant is loaded to keep peak
        # GPU memory down.
        del model
        if not is_direction_a:
            del processor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    compare_rows = build_side_by_side(hf_split, prediction_map)
    for idx, row in enumerate(compare_rows):
        row["image_preview"] = image_refs[idx] if idx < len(image_refs) else ""

    variants_tag = "_".join(args.variants)
    compare_path = output_dir / f"compare_{args.split}_{variants_tag}.json"
    with open(compare_path, "w", encoding="utf-8") as f:
        json.dump(compare_rows, f, ensure_ascii=False, indent=2)

    summary_path = output_dir / f"summary_{args.split}_{variants_tag}.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    html_path = render_compare_html(compare_rows, args.variants, output_dir, args.split)
    print(f"[SUCCESS] Đã lưu file so sánh tại {compare_path}")
    print(f"[SUCCESS] Đã lưu summary tại {summary_path}")
    print(f"[SUCCESS] Đã lưu HTML hiển thị ảnh tại {html_path}")


if __name__ == "__main__":
    main()