# Medical-VQA — scripts/export_predictions.py
# Export predictions from the A1/A2/B1/B2/DPO/PPO variants for side-by-side comparison.
import argparse
import html
import json
from pathlib import Path
import torch
import yaml
from datasets import load_dataset
from peft import PeftModel
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, LlavaForConditionalGeneration, LlavaProcessor
from src.data.medical_dataset import MedicalVQADataset
from src.models.medical_vqa_model import MedicalVQAModelA
from src.models.multimodal_vqa import MultimodalVQA
from src.utils.text_utils import normalize_answer, postprocess_answer
from src.utils.translator import MedicalTranslator
from src.utils.visualization import MedicalImageTransform as MedicalTransform
def vqa_collate_fn(batch):
    """Collate a list of dataset samples into a single batch dict.

    Tensor-valued fields are stacked along a new batch dimension; every
    other field is gathered into a plain Python list.
    """
    stackable = {
        "image", "input_ids", "attention_mask", "label_closed",
        "target_ids", "chosen_ids", "rejected_ids",
    }
    first = batch[0]
    return {
        key: torch.stack([sample[key] for sample in batch])
        if key in stackable
        else [sample[key] for sample in batch]
        for key in first
    }
def normalize_for_metric(text: str) -> str:
    """Canonicalise *text* for exact-match metrics: stringified, lowercased, trimmed."""
    normalized = str(text)
    return normalized.lower().strip()
def _normalize_closed_answer(question_vi: str, question_en: str, pred_vi: str, pred_en: str = "") -> str:
    """Collapse a free-text prediction for a closed question into 'có'/'không'.

    Combines the normalized Vietnamese and English predictions and matches
    keyword patterns. "Normality" questions get special handling because a
    'normal' finding means the answer is yes ('có') while an abnormality
    means no ('không').

    Falls back to the normalized prediction when no pattern matches.
    """
    question_vi_norm = normalize_answer(question_vi)
    question_en_norm = normalize_answer(question_en)
    pred_vi_norm = normalize_answer(pred_vi)
    pred_en_norm = normalize_answer(pred_en)
    combined = " ".join(part for part in [pred_vi_norm, pred_en_norm] if part).strip()
    # Questions that ask whether the image is normal/abnormal invert the
    # usual yes/no mapping: "normal" -> yes, "abnormal" -> no.
    is_normality_question = any(
        pattern in " ".join([question_vi_norm, question_en_norm])
        for pattern in ["bình thường", "normal", "abnormal", "bat thuong"]
    )
    if is_normality_question:
        # Check negations first so "không bình thường" is not matched as "bình thường".
        if any(pattern in combined for pattern in ["không bình thường", "not normal"]):
            return "không"
        # Whole-word match so "có"/"yes" is not found inside longer tokens.
        if any(pattern in combined.split() for pattern in ["có", "yes"]):
            return "có"
        if any(pattern in combined for pattern in [
            "bình thường", "normal", "no significant abnormalities", "no abnormality",
            "unremarkable", "appears to be normal", "without significant abnormalities",
            "không phát hiện bất thường",
        ]):
            return "có"
        if any(pattern in combined for pattern in [
            "bất thường", "abnormal", "abnormality detected", "fracture", "lesion",
            "mass", "effusion", "pneumothorax",
        ]):
            return "không"
    else:
        if any(pattern in combined for pattern in ["không", "no", "absent", "not seen", "negative", "none"]):
            return "không"
        if any(pattern in combined for pattern in ["có", "yes", "present", "detected", "positive"]):
            return "có"
    # No pattern matched: fall back to whichever normalized prediction exists.
    return pred_vi_norm or pred_en_norm
_B1_FEW_SHOT = (
"Q: Is there cardiomegaly? A: yes\n"
"Q: What organ is shown? A: lung\n"
"Q: Is the aorta normal? A: no\n"
"Q: What abnormality is present? A: pleural effusion\n"
)
def _build_b1_prompt(question_en: str, max_words: int) -> str:
return (
f"USER: <image>\n"
f"Answer each question with medical terminology only, "
f"no more than {max_words} words, no full sentences.\n"
f"{_B1_FEW_SHOT}"
f"Q: {question_en} A: ASSISTANT:"
)
# Direct English -> Vietnamese lookup for common short medical answers,
# used to bypass the neural translator for B1 predictions when possible.
# Keys are lowercase English terms; values are the Vietnamese equivalents.
_EN_VI_DIRECT = {
    "yes": "có", "no": "không", "present": "có", "absent": "không",
    "normal": "bình thường", "abnormal": "bất thường", "true": "có", "false": "không",
    "positive": "có", "negative": "không", "lung": "phổi", "lungs": "phổi",
    "heart": "tim", "liver": "gan", "spleen": "lách", "kidney": "thận", "brain": "não",
    "bladder": "bàng quang", "chest": "ngực", "abdomen": "bụng", "pelvis": "xương chậu",
    "spine": "cột sống", "rib": "xương sườn", "ribs": "xương sườn", "trachea": "khí quản",
    "aorta": "động mạch chủ", "diaphragm": "cơ hoành", "mediastinum": "trung thất",
    "chest x-ray": "x-quang ngực", "x-ray": "x-quang", "xray": "x-quang", "mri": "mri",
    "ct": "ct", "ultrasound": "siêu âm", "ct scan": "ct", "mri scan": "mri",
    "axial": "mặt phẳng ngang", "coronal": "mặt phẳng vành", "sagittal": "mặt phẳng dọc",
    "transverse": "mặt phẳng ngang", "cardiomegaly": "tim to", "pneumonia": "viêm phổi",
    "pleural effusion": "tràn dịch màng phổi", "pneumothorax": "tràn khí màng phổi",
    "fracture": "gãy xương", "edema": "phù nề", "pulmonary edema": "phù phổi",
    "consolidation": "đông đặc", "atelectasis": "xẹp phổi", "opacity": "mờ đục",
    "mass": "khối u", "nodule": "nốt", "lesion": "tổn thương", "tumor": "khối u",
    "effusion": "tràn dịch", "infiltrate": "thâm nhiễm", "fibrosis": "xơ hóa",
    "calcification": "vôi hóa", "carcinoma": "ung thư", "metastasis": "di căn",
    "bilateral": "hai bên", "unilateral": "một bên", "left": "trái", "right": "phải",
    "upper": "trên", "lower": "dưới", "upper left": "phía trên bên trái", "upper right": "phía trên bên phải",
    "lower left": "phía dưới bên trái", "lower right": "phía dưới bên phải",
}
def _extract_key_medical_term(raw_en: str, max_words: int) -> str:
import re
text = raw_en.strip().lower()
prefixes = [
r"^the (image|scan|x-ray|xray|mri|ct|picture|photo|radiograph) (shows?|depicts?|demonstrates?|reveals?|indicates?|presents?)\s+",
r"^based on the (image|scan|x-ray|mri|ct)\s*,?\s*",
r"^in (this|the) (image|scan|x-ray|mri|ct)\s*,?\s*",
r"^i (can see|observe|notice|see)\s+",
r"^there (is|are)\s+(a |an |some )?",
r"^(it |this )(shows?|is|appears?|looks?)\s+(like\s+)?",
r"^the (patient|subject)\s+(has|shows?|presents?)\s+",
r"^(a|an|the)\s+",
]
for pat in prefixes:
text = re.sub(pat, "", text)
text = re.sub(r"[.!?,;:]+$", "", text).strip()
text = re.sub(r"\s+", " ", text).strip()
words = text.split()
return " ".join(words[:max_words]) if words else raw_en.strip()
def _en_to_vi_direct(en_text: str):
    """Return the direct English->Vietnamese translation, or None when unmapped."""
    key = en_text.strip().lower()
    return _EN_VI_DIRECT.get(key)
def predict_direction_a(model, dataloader, device, tokenizer, beam_width=1, max_len=32, max_words=10):
    """Run inference for direction-A models (encoder + LSTM/Transformer decoder).

    For closed (yes/no) questions the dedicated classification head overrides
    the generated text; open questions keep the decoded sequence.

    Returns a list of per-sample dicts with normalized ground-truth and
    prediction strings (keys: ground_truth, ground_truth_en, predicted,
    predicted_raw, predicted_display, predicted_en).
    """
    model.eval()
    rows = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Predicting A"):
            images = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # -1 marks an open-ended question; 0/1 are closed-question labels.
            labels = batch["label_closed"]
            logits_closed, pred_ids = model.inference(images, input_ids, attention_mask, beam_width=beam_width, max_len=max_len)
            preds_text_raw = [postprocess_answer(t, max_words=max_words) for t in tokenizer.batch_decode(pred_ids, skip_special_tokens=True)]
            preds_text = list(preds_text_raw)
            # Map the binary head's argmax to Vietnamese yes/no answers.
            closed_map = {0: "không", 1: "có"}
            closed_preds_idx = torch.argmax(logits_closed, dim=-1)
            for i in range(len(preds_text)):
                if labels[i].item() != -1:
                    # Closed question: trust the classification head over free text.
                    preds_text[i] = closed_map[closed_preds_idx[i].item()]
                preds_text[i] = postprocess_answer(preds_text[i], max_words=max_words)
            for i in range(len(preds_text)):
                rows.append({
                    "ground_truth": normalize_for_metric(postprocess_answer(batch["raw_answer"][i], max_words=max_words)),
                    "ground_truth_en": normalize_for_metric(batch.get("raw_answer_en", [""])[i] if "raw_answer_en" in batch else ""),
                    "predicted": normalize_for_metric(preds_text[i]),
                    "predicted_raw": normalize_for_metric(preds_text_raw[i]),
                    "predicted_display": normalize_for_metric(preds_text_raw[i]),
                    "predicted_en": "",
                })
    return rows
def predict_direction_b(model, dataloader, device, processor, variant="B1", beam_width=1, beam_width_closed=1, beam_width_open=1, max_new_tokens_closed=4, max_new_tokens_open=16, generation_batch_size=1, max_words=10):
    """Run inference for direction-B (LLaVA-based) variants and post-process answers.

    B1 prompts in English with few-shot examples and translates predictions
    back to Vietnamese; other variants (B2/DPO/PPO) prompt in Vietnamese
    directly. Closed (yes/no) questions get a smaller generation budget and
    are normalized to 'có'/'không'.

    Returns a list of per-sample dicts with normalized ground-truth and
    prediction strings.
    """
    model.eval()
    rows = []
    translator = MedicalTranslator(device=device.type)
    wrapper = MultimodalVQA()

    def _run_generation(raw_images, prompts, sample_indices, num_beams, max_new_tokens):
        # Generate answers for the selected sample indices in small chunks to
        # bound GPU memory; returns decoded strings in the same order.
        if not sample_indices:
            return []
        decoded_outputs = []
        # Beam search is memory-hungry, so keep the configured chunk size; for
        # greedy decoding allow at least 2 samples per chunk.
        chunk_size = generation_batch_size if num_beams > 1 else max(generation_batch_size, 2)
        for start in range(0, len(sample_indices), chunk_size):
            chunk_indices = sample_indices[start:start + chunk_size]
            text_subset = [prompts[i] for i in chunk_indices]
            image_subset = [raw_images[i] for i in chunk_indices]
            inputs = processor(text=text_subset, images=image_subset, return_tensors="pt", padding=True).to(device)
            if "pixel_values" in inputs:
                inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                num_beams=num_beams,
                early_stopping=num_beams > 1,
            )
            # Drop the prompt tokens; keep only the newly generated tail.
            input_token_len = inputs.input_ids.shape[1]
            decoded_outputs.extend(processor.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True))
            del inputs, output_ids
            if device.type == "cuda":
                torch.cuda.empty_cache()
        return decoded_outputs

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Predicting {variant}"):
            raw_images = batch["raw_image"]
            questions_vi = batch.get("raw_questions", [])
            questions_en = batch.get("raw_questions_en", [])
            refs_vi_raw = batch.get("raw_answer", [])
            refs_en_raw = batch.get("raw_answer_en", [])
            # -1 marks an open-ended question; 0/1 are closed-question labels.
            labels = batch["label_closed"]
            if variant == "B1":
                # B1 prompts in English; translate questions when EN text is missing.
                if not questions_en or any(not str(q).strip() for q in questions_en):
                    questions_en = translator.translate_vi2en(questions_vi)
                prompts = [_build_b1_prompt(q, max_words) for q in questions_en]
            else:
                prompts = [wrapper.build_instruction_prompt(q, language="vi", include_answer=False) for q in questions_vi]
            preds_raw = [""] * len(prompts)
            closed_idx = [i for i, lbl in enumerate(labels.tolist()) if lbl != -1]
            open_idx = [i for i, lbl in enumerate(labels.tolist()) if lbl == -1]
            if variant == "B1":
                preds_raw = _run_generation(raw_images, prompts, list(range(len(prompts))), beam_width_open, max_new_tokens_open)
            else:
                # Closed and open questions use different beam widths / token budgets.
                for idx, pred in zip(closed_idx, _run_generation(raw_images, prompts, closed_idx, beam_width_closed, max_new_tokens_closed)):
                    preds_raw[idx] = pred
                for idx, pred in zip(open_idx, _run_generation(raw_images, prompts, open_idx, beam_width_open, max_new_tokens_open)):
                    preds_raw[idx] = pred
            preds_vi = []
            preds_vi_display = []
            preds_en_clean = []
            if variant == "B1":
                # Clean English boilerplate, then map or translate to Vietnamese.
                preds_en_clean = [_extract_key_medical_term(p, 50) for p in preds_raw]
                needs_translate_idx = []
                needs_translate_txt = []
                for i, pred_en in enumerate(preds_en_clean):
                    if labels[i].item() != -1:
                        preds_vi.append(_normalize_closed_answer(questions_vi[i], questions_en[i], pred_en, pred_en))
                    else:
                        vi_direct = _en_to_vi_direct(pred_en)
                        if vi_direct is not None:
                            preds_vi.append(postprocess_answer(vi_direct, max_words=max_words))
                        else:
                            # Placeholder; filled by the batched translator below.
                            preds_vi.append(None)
                            needs_translate_idx.append(i)
                            needs_translate_txt.append(pred_en)
                if needs_translate_txt:
                    translated = translator.translate_en2vi(needs_translate_txt)
                    if isinstance(translated, str):
                        translated = [translated]
                    for idx, vi in zip(needs_translate_idx, translated):
                        preds_vi[idx] = postprocess_answer(vi, max_words=max_words)
                preds_vi_display = list(preds_vi)
            else:
                preds_vi_display = [postprocess_answer(p, max_words=max_words) if p else "" for p in preds_raw]
                for i, pred_vi in enumerate(preds_raw):
                    if labels[i].item() != -1:
                        preds_vi.append(_normalize_closed_answer(questions_vi[i], questions_en[i] if i < len(questions_en) else "", pred_vi))
                    else:
                        preds_vi.append(pred_vi)
                preds_en_clean = [""] * len(preds_raw)
            preds_vi = [postprocess_answer(p, max_words=max_words) if p else "" for p in preds_vi]
            preds_vi_display = [postprocess_answer(p, max_words=max_words) if p else "" for p in preds_vi_display]
            preds_vi_raw = list(preds_vi_display)
            refs_vi = [postprocess_answer(r, max_words=max_words) for r in refs_vi_raw]
            refs_en = [postprocess_answer(r, max_words=max_words) if r else "" for r in refs_en_raw]
            for i in range(len(preds_vi)):
                rows.append({
                    "ground_truth": normalize_for_metric(refs_vi[i]),
                    "ground_truth_en": normalize_for_metric(refs_en[i]),
                    "predicted": normalize_for_metric(preds_vi[i]),
                    "predicted_raw": normalize_for_metric(preds_vi_raw[i]),
                    "predicted_display": normalize_for_metric(preds_vi_display[i]),
                    "predicted_en": normalize_for_metric(preds_en_clean[i] if i < len(preds_en_clean) else ""),
                })
    return rows
def select_best_adapter_checkpoint(checkpoint_root: str):
    """Pick the best LoRA adapter checkpoint under *checkpoint_root*.

    Preference order:
    1. The ``best_model_checkpoint`` recorded in the newest
       ``trainer_state.json`` (when it still exists and holds an adapter).
    2. Otherwise the adapter checkpoint with the highest training step.

    Raises:
        FileNotFoundError: when the root directory is missing or contains
            no directory with an ``adapter_config.json``.
    """
    def _step(path: Path) -> int:
        # "checkpoint-1234" -> 1234; non-numeric names sort before real steps.
        suffix = path.name.rsplit("-", 1)[-1]
        return int(suffix) if suffix.isdigit() else -1

    root = Path(checkpoint_root)
    if not root.exists():
        raise FileNotFoundError(f"Không tìm thấy thư mục checkpoint: {root}")
    # Sort numerically by step: plain sorted() is lexicographic and would rank
    # "checkpoint-9" after "checkpoint-10", picking the wrong latest dir.
    checkpoint_dirs = sorted(
        (p for p in root.glob("checkpoint-*") if (p / "adapter_config.json").exists()),
        key=_step,
    )
    if not checkpoint_dirs:
        raise FileNotFoundError(f"Không có adapter checkpoint trong {root}")
    # Prefer the trainer's own record of the best checkpoint, newest state first.
    state_files = sorted(
        root.glob("checkpoint-*/trainer_state.json"),
        key=lambda f: _step(f.parent),
        reverse=True,
    )
    for state_file in state_files:
        try:
            state = json.loads(state_file.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            continue
        best_path = state.get("best_model_checkpoint")
        if best_path:
            best_dir = Path(best_path.replace("./", ""))
            if not best_dir.is_absolute():
                best_dir = Path.cwd() / best_dir
            if (best_dir / "adapter_config.json").exists():
                return best_dir.resolve()
    # Fall back to the highest-step adapter checkpoint.
    return checkpoint_dirs[-1].resolve()
def load_config(config_path: str):
    """Read and parse the YAML configuration file at *config_path*."""
    with open(config_path, encoding="utf-8") as stream:
        return yaml.safe_load(stream)
def build_dataset_and_loader(config, split: str, tokenizer):
    """Load the Hugging Face dataset split and wrap it in a DataLoader.

    Returns (raw_hf_split, dataloader): the untouched HF split (used for
    metadata/images) plus a DataLoader over the project's MedicalVQADataset.

    Raises:
        ValueError: when the config declares no `hf_dataset` repo or the
            requested split is missing from the downloaded dataset.
    """
    hf_repo = config["data"].get("hf_dataset")
    if not hf_repo:
        raise ValueError("Script này hiện yêu cầu dataset từ Hugging Face Hub.")
    dataset_dict = load_dataset(hf_repo)
    if split not in dataset_dict:
        raise ValueError(f"Dataset không có split '{split}'. Các split hiện có: {list(dataset_dict.keys())}")
    answer_max_words = int(config["data"].get("answer_max_words", 10))
    transform = MedicalTransform(size=config["data"]["image_size"])
    dataset = MedicalVQADataset(
        hf_dataset=dataset_dict[split],
        tokenizer=tokenizer,
        transform=transform,
        max_seq_len=config["data"]["max_question_len"],
        max_ans_len=config["data"]["max_answer_len"],
        answer_max_words=answer_max_words,
    )
    loader = DataLoader(
        dataset,
        batch_size=int(config["train"].get("eval_batch_size", 8)),
        # Keep dataset order so predictions align with hf_split rows downstream.
        shuffle=False,
        collate_fn=vqa_collate_fn,
    )
    return dataset_dict[split], loader
def load_direction_a_model(variant: str, config, tokenizer, device):
    """Build a direction-A model (A1 = LSTM decoder, A2 = Transformer decoder)
    and load its best (or resume) checkpoint from ./checkpoints.

    Returns (model_in_eval_mode, checkpoint_path_str).

    Raises:
        FileNotFoundError: when neither the best nor the resume checkpoint exists.
    """
    ckpt_path = Path(f"checkpoints/medical_vqa_{variant}_best.pth")
    if not ckpt_path.exists():
        # Fall back to the resume snapshot when no "best" checkpoint was saved.
        resume_path = Path(f"checkpoints/medical_vqa_{variant}_resume.pth")
        ckpt_path = resume_path if resume_path.exists() else None
    if ckpt_path is None or not ckpt_path.exists():
        raise FileNotFoundError(f"Không tìm thấy checkpoint cho {variant}")
    decoder_type = "lstm" if variant == "A1" else "transformer"
    model = MedicalVQAModelA(
        decoder_type=decoder_type,
        vocab_size=len(tokenizer),
        hidden_size=config["model_a"].get("hidden_size", 768),
        phobert_model=config["model_a"].get("phobert_model", "vinai/phobert-base"),
    ).to(device)
    payload = torch.load(ckpt_path, map_location=device)
    # Checkpoints may store either a raw state dict or a dict wrapping one.
    state_dict = payload.get("model_state_dict") if isinstance(payload, dict) and "model_state_dict" in payload else payload
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model, str(ckpt_path)
def build_llava_base_and_processor(config):
    """Instantiate the quantised LLaVA base model plus its processor.

    Returns (wrapper, processor, base_model); the MultimodalVQA wrapper
    carries the LoRA/quantisation configuration consumed by the B variants.
    """
    wrapper = MultimodalVQA(
        model_id=config["model_b"]["model_name"],
        lora_r=int(config["model_b"].get("lora_r", 16)),
        lora_alpha=int(config["model_b"].get("lora_alpha", 32)),
        lora_dropout=float(config["model_b"].get("lora_dropout", 0.05)),
        lora_target_modules=config["model_b"].get("lora_target_modules"),
    )
    processor = LlavaProcessor.from_pretrained(wrapper.model_id)
    # Left-pad so generation starts immediately after each prompt in a batch.
    processor.tokenizer.padding_side = "left"
    base_model = LlavaForConditionalGeneration.from_pretrained(
        wrapper.model_id,
        quantization_config=wrapper.bnb_config,
        device_map="auto",
    )
    # Inference-only script, but disable cache to match the training setup.
    base_model.config.use_cache = False
    return wrapper, processor, base_model
def load_direction_b_model(variant: str, config):
    """Load the LLaVA model (optionally with a LoRA adapter) for a B variant.

    B1 uses the bare quantised base model; B2 loads the best adapter found
    under the configured output dir; DPO/PPO load their final adapters.

    Returns (model_in_eval_mode, processor, checkpoint_identifier).

    Raises:
        ValueError: for any variant outside {B1, B2, DPO, PPO}.
    """
    wrapper, processor, base_model = build_llava_base_and_processor(config)
    if variant == "B1":
        # Zero-shot: the quantised base model, no adapter attached.
        model, checkpoint = base_model, config["model_b"]["model_name"]
    else:
        if variant == "B2":
            ckpt_dir = select_best_adapter_checkpoint(config["train"].get("b2_output_dir", "./checkpoints/B2"))
        elif variant in ("DPO", "PPO"):
            ckpt_dir = Path(f"checkpoints/{variant}/final_adapter")
        else:
            raise ValueError(f"Variant không hỗ trợ trong script này: {variant}")
        model = PeftModel.from_pretrained(base_model, str(ckpt_dir), is_trainable=False)
        checkpoint = str(ckpt_dir)
    model.eval()
    return model, processor, checkpoint
def convert_prediction_rows(hf_split, prediction_rows, variant: str, checkpoint: str):
    """Merge dataset metadata with per-sample predictions into flat export rows.

    Samples beyond the available predictions get empty prediction fields.
    """
    metadata_fields = ("id", "source", "image_name", "answer_type", "question", "question_vi")
    prediction_fields = (
        "ground_truth", "ground_truth_en", "predicted",
        "predicted_raw", "predicted_display", "predicted_en",
    )
    merged = []
    for idx, item in enumerate(hf_split):
        preds = prediction_rows[idx] if idx < len(prediction_rows) else {}
        row = {"idx": idx, "variant": variant, "checkpoint": checkpoint}
        for field in metadata_fields:
            row[field] = item.get(field)
        for field in prediction_fields:
            row[field] = preds.get(field, "")
        merged.append(row)
    return merged
def build_side_by_side(hf_split, prediction_map):
    """Align every variant's predictions per sample for side-by-side comparison.

    Each output row carries the sample's metadata plus, for every variant,
    `<variant>_predicted` / `<variant>_predicted_raw` columns (empty when a
    variant produced fewer predictions than samples).
    """
    combined = []
    for idx, item in enumerate(hf_split):
        row = {
            "idx": idx,
            "id": item.get("id"),
            "source": item.get("source"),
            "image_name": item.get("image_name"),
            "answer_type": item.get("answer_type"),
            "question": item.get("question"),
            "question_vi": item.get("question_vi"),
            "ground_truth": item.get("answer_vi"),
            "ground_truth_full_vi": item.get("answer_full_vi"),
        }
        for variant, preds in prediction_map.items():
            in_range = idx < len(preds)
            row[f"{variant}_predicted"] = preds[idx]["predicted"] if in_range else ""
            row[f"{variant}_predicted_raw"] = preds[idx]["predicted_raw"] if in_range else ""
        combined.append(row)
    return combined
def export_preview_images(hf_split, output_dir: Path, split: str, image_size: int = 256):
    """Save small JPEG previews of every sample image under `<split>_images/`.

    Returns the previews' paths relative to *output_dir* (POSIX style), in
    dataset order, for embedding in the comparison HTML/JSON.
    """
    image_dir = output_dir / f"{split}_images"
    image_dir.mkdir(parents=True, exist_ok=True)
    refs = []
    for idx, item in enumerate(hf_split):
        source = item["image"]
        if source.mode != "RGB":
            # JPEG requires RGB; convert grayscale/palette images first.
            source = source.convert("RGB")
        thumb = source.copy()
        thumb.thumbnail((image_size, image_size))
        base_name = Path(str(item.get("image_name") or f"{idx}.jpg")).name
        target = image_dir / f"{idx:04d}_{base_name}"
        thumb.save(target, format="JPEG", quality=90)
        refs.append(target.relative_to(output_dir).as_posix())
    return refs
def render_compare_html(compare_rows, variants, output_dir: Path, split: str):
    """Write a self-contained HTML page showing image previews, questions,
    ground truth, and each variant's predictions side by side.

    Returns the path of the written HTML file.
    """
    html_path = output_dir / f"compare_{split}_{'_'.join(variants)}.html"
    cards = []
    for row in compare_rows:
        # Escape all user/model text so it cannot break the generated markup.
        img_src = html.escape(row.get("image_preview", ""))
        question_vi = html.escape(str(row.get("question_vi", "")))
        question_en = html.escape(str(row.get("question", "")))
        answer_type = html.escape(str(row.get("answer_type", "")))
        ground_truth = html.escape(str(row.get("ground_truth", "")))
        image_name = html.escape(str(row.get("image_name", "")))
        preds_html = []
        for variant in variants:
            pred = html.escape(str(row.get(f"{variant}_predicted", "")))
            raw = html.escape(str(row.get(f"{variant}_predicted_raw", "")))
            preds_html.append(
                f"""
<div class="pred">
<div class="pred-title">{variant}</div>
<div><strong>Pred:</strong> {pred}</div>
<div class="muted"><strong>Raw:</strong> {raw}</div>
</div>
"""
            )
        cards.append(
            f"""
<article class="card">
<div class="media">
<img src="{img_src}" alt="{image_name}" loading="lazy" />
<div class="meta">
<div><strong>Idx:</strong> {row.get("idx", "")}</div>
<div><strong>Image:</strong> {image_name}</div>
<div><strong>Type:</strong> {answer_type}</div>
</div>
</div>
<div class="content">
<div><strong>Q (VI):</strong> {question_vi}</div>
<div class="muted"><strong>Q (EN):</strong> {question_en}</div>
<div class="gt"><strong>GT:</strong> {ground_truth}</div>
<div class="pred-grid">
{''.join(preds_html)}
</div>
</div>
</article>
"""
        )
    # Full page template; literal CSS braces are doubled ({{ }}) inside the f-string.
    page = f"""<!DOCTYPE html>
<html lang="vi">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Compare Predictions - {split}</title>
<style>
:root {{
--bg: #f5f1e8;
--panel: #fffdf8;
--ink: #1d1b16;
--muted: #6e675c;
--line: #d8cfbf;
--accent: #8f3d2e;
}}
* {{ box-sizing: border-box; }}
body {{
margin: 0;
font-family: Georgia, "Times New Roman", serif;
background: linear-gradient(180deg, #efe7d7 0%, var(--bg) 100%);
color: var(--ink);
}}
.wrap {{
width: min(1200px, calc(100vw - 32px));
margin: 24px auto 40px;
}}
h1 {{
margin: 0 0 8px;
font-size: 32px;
}}
.sub {{
color: var(--muted);
margin-bottom: 24px;
}}
.card {{
display: grid;
grid-template-columns: 260px 1fr;
gap: 18px;
background: var(--panel);
border: 1px solid var(--line);
border-radius: 18px;
padding: 16px;
margin-bottom: 16px;
box-shadow: 0 10px 30px rgba(40, 28, 12, 0.06);
}}
.media img {{
width: 100%;
border-radius: 12px;
display: block;
border: 1px solid var(--line);
background: #fff;
}}
.meta {{
margin-top: 10px;
color: var(--muted);
font-size: 14px;
line-height: 1.5;
}}
.content {{
display: flex;
flex-direction: column;
gap: 8px;
line-height: 1.5;
}}
.muted {{
color: var(--muted);
}}
.gt {{
padding: 10px 12px;
background: #f6efe4;
border-left: 4px solid var(--accent);
border-radius: 8px;
}}
.pred-grid {{
display: grid;
grid-template-columns: repeat(2, minmax(0, 1fr));
gap: 12px;
margin-top: 8px;
}}
.pred {{
border: 1px solid var(--line);
border-radius: 12px;
padding: 12px;
background: #fff;
}}
.pred-title {{
font-weight: 700;
margin-bottom: 6px;
color: var(--accent);
}}
@media (max-width: 820px) {{
.card {{
grid-template-columns: 1fr;
}}
.pred-grid {{
grid-template-columns: 1fr;
}}
}}
</style>
</head>
<body>
<main class="wrap">
<h1>So sánh prediction {html.escape(split)}</h1>
<div class="sub">Models: {html.escape(', '.join(variants))}</div>
{''.join(cards)}
</main>
</body>
</html>
"""
    html_path.write_text(page, encoding="utf-8")
    return html_path
def main():
    """CLI entry point: export per-variant predictions, a side-by-side JSON,
    a run summary, preview images, and a static comparison HTML page."""
    parser = argparse.ArgumentParser(description="Xuất prediction của A1/A2/B1/B2/DPO/PPO để so sánh.")
    parser.add_argument("--config", default="configs/medical_vqa.yaml")
    parser.add_argument("--split", default="test", choices=["train", "validation", "test"])
    parser.add_argument("--variants", nargs="+", default=["A1", "A2", "B1", "B2"])
    parser.add_argument("--output-dir", default="results/predictions")
    args = parser.parse_args()
    config = load_config(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(config["model_a"]["phobert_model"])
    if tokenizer.pad_token_id is None:
        # PhoBERT-style tokenizers may lack a pad token; reuse eos/unk.
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
    hf_split, dataloader = build_dataset_and_loader(config, args.split, tokenizer)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    image_refs = export_preview_images(hf_split, output_dir, args.split)
    summary = {}
    prediction_map = {}
    for variant in args.variants:
        print(f"[INFO] Đang chạy prediction cho {variant} trên split '{args.split}'...")
        if variant in {"A1", "A2"}:
            model, checkpoint = load_direction_a_model(variant, config, tokenizer, device)
            prediction_rows = predict_direction_a(
                model,
                dataloader,
                device,
                tokenizer,
                beam_width=int(config["eval"].get("beam_width_a", 5)),
                max_len=int(config["data"].get("max_answer_len", 20)),
                max_words=int(config["data"].get("answer_max_words", 10)),
            )
        else:
            model, processor, checkpoint = load_direction_b_model(variant, config)
            prediction_rows = predict_direction_b(
                model,
                dataloader,
                device,
                processor,
                beam_width=int(config["eval"].get("beam_width_b", 5)),
                beam_width_closed=int(config["eval"].get("beam_width_b_closed", 1)),
                beam_width_open=int(config["eval"].get("beam_width_b_open", config["eval"].get("beam_width_b", 5))),
                max_new_tokens_closed=int(config["eval"].get("max_new_tokens_b_closed", 4)),
                max_new_tokens_open=int(config["eval"].get("max_new_tokens_b_open", int(config["data"].get("answer_max_words", 10)) + 6)),
                generation_batch_size=int(config["eval"].get("generation_batch_size_b", 1)),
                max_words=int(config["data"].get("answer_max_words", 10)),
                variant=variant,
            )
        rows = convert_prediction_rows(hf_split, prediction_rows, variant, checkpoint)
        prediction_map[variant] = rows
        out_path = output_dir / f"{variant}_{args.split}_predictions.json"
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(rows, f, ensure_ascii=False, indent=2)
        summary[variant] = {
            "checkpoint": checkpoint,
            "num_predictions": len(rows),
        }
        print(f"[SUCCESS] Đã lưu {out_path}")
        # Free GPU memory before loading the next variant's model.
        del model
        if variant in {"B1", "B2", "DPO", "PPO"}:
            del processor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    compare_rows = build_side_by_side(hf_split, prediction_map)
    for idx, row in enumerate(compare_rows):
        row["image_preview"] = image_refs[idx] if idx < len(image_refs) else ""
    compare_path = output_dir / f"compare_{args.split}_{'_'.join(args.variants)}.json"
    with open(compare_path, "w", encoding="utf-8") as f:
        json.dump(compare_rows, f, ensure_ascii=False, indent=2)
    summary_path = output_dir / f"summary_{args.split}_{'_'.join(args.variants)}.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    html_path = render_compare_html(compare_rows, args.variants, output_dir, args.split)
    print(f"[SUCCESS] Đã lưu file so sánh tại {compare_path}")
    print(f"[SUCCESS] Đã lưu summary tại {summary_path}")
    print(f"[SUCCESS] Đã lưu HTML hiển thị ảnh tại {html_path}")
if __name__ == "__main__":
main()