Prediction comparison {html.escape(split)}
Models: {html.escape(', '.join(variants))}
{''.join(cards)}
"""
    html_path.write_text(page, encoding="utf-8")
    return html_path


def main():
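    """Export predictions for each requested variant, then write per-variant JSON,
    a side-by-side comparison JSON, a run summary, and an HTML preview page.
    """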
    parser = argparse.ArgumentParser(description="Export predictions from A1/A2/B1/B2/DPO/PPO for side-by-side comparison.")
    parser.add_argument("--config", default="configs/medical_vqa.yaml")
    parser.add_argument("--split", default="test", choices=["train", "validation", "test"])
    parser.add_argument("--variants", nargs="+", default=["A1", "A2", "B1", "B2"])
    parser.add_argument("--output-dir", default="results/predictions")
    args = parser.parse_args()
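
    # Load the experiment config and pick the compute device.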
    config = load_config(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(config["model_a"]["phobert_model"])
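    # Guard against tokenizers that ship without a pad token: reuse EOS (or UNK)
    # so batched inputs can be padded.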
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
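
    # Build the requested split plus its dataloader, prepare the output directory,
    # and export the preview images referenced by the comparison HTML.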
    hf_split, dataloader = build_dataset_and_loader(config, args.split, tokenizer)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    image_refs = export_preview_images(hf_split, output_dir, args.split)

    summary = {}
    prediction_map = {}
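
    # Run variants one at a time so that only a single model is resident in memory.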
    for variant in args.variants:
        print(f"[INFO] Running predictions for {variant} on split '{args.split}'...")
        if variant in {"A1", "A2"}:
            model, checkpoint = load_direction_a_model(variant, config, tokenizer, device)
            prediction_rows = predict_direction_a(
                model,
                dataloader,
                device,
                tokenizer,
                beam_width=int(config["eval"].get("beam_width_a", 5)),
                max_len=int(config["data"].get("max_answer_len", 20)),
                max_words=int(config["data"].get("answer_max_words", 10)),
            )
        else:
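            # Direction B uses separate beam widths and max-new-token budgets
            # for closed versus open questions.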
            model, processor, checkpoint = load_direction_b_model(variant, config)
            prediction_rows = predict_direction_b(
                model,
                dataloader,
                device,
                processor,
                beam_width=int(config["eval"].get("beam_width_b", 5)),
                beam_width_closed=int(config["eval"].get("beam_width_b_closed", 1)),
                beam_width_open=int(config["eval"].get("beam_width_b_open", config["eval"].get("beam_width_b", 5))),
                max_new_tokens_closed=int(config["eval"].get("max_new_tokens_b_closed", 4)),
                max_new_tokens_open=int(config["eval"].get("max_new_tokens_b_open", int(config["data"].get("answer_max_words", 10)) + 6)),
                generation_batch_size=int(config["eval"].get("generation_batch_size_b", 1)),
                max_words=int(config["data"].get("answer_max_words", 10)),
                variant=variant,
            )
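        # Normalize raw outputs into the shared row format and write per-variant JSON.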
        rows = convert_prediction_rows(hf_split, prediction_rows, variant, checkpoint)
        prediction_map[variant] = rows
        out_path = output_dir / f"{variant}_{args.split}_predictions.json"
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(rows, f, ensure_ascii=False, indent=2)
        summary[variant] = {
            "checkpoint": checkpoint,
            "num_predictions": len(rows),
        }
        print(f"[SUCCESS] Saved {out_path}")
        del model
        if variant not in {"A1", "A2"}:
            del processor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
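
    # Assemble the side-by-side comparison and attach image previews by row index.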
    compare_rows = build_side_by_side(hf_split, prediction_map)
    for idx, row in enumerate(compare_rows):
        row["image_preview"] = image_refs[idx] if idx < len(image_refs) else ""

    compare_path = output_dir / f"compare_{args.split}_{'_'.join(args.variants)}.json"
    with open(compare_path, "w", encoding="utf-8") as f:
        json.dump(compare_rows, f, ensure_ascii=False, indent=2)

    summary_path = output_dir / f"summary_{args.split}_{'_'.join(args.variants)}.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    html_path = render_compare_html(compare_rows, args.variants, output_dir, args.split)
    print(f"[SUCCESS] Saved comparison file to {compare_path}")
    print(f"[SUCCESS] Saved summary to {summary_path}")
    print(f"[SUCCESS] Saved image-preview HTML to {html_path}")


if __name__ == "__main__":
    main()