Spaces:
Paused
Paused
Commit ·
d63774a
1
Parent(s): 1857fb3
Deploy Medical VQA app
Browse files- .gitignore +10 -0
- Dockerfile +32 -9
- configs/medical_vqa.yaml +184 -0
- data/.gitkeep +0 -0
- data/images/__MACOSX +1 -0
- data/images/imgs +1 -0
- data/judge_results.json +0 -0
- data/medical_dict.json +33 -0
- data/medical_dict_vn.json +1 -0
- data/merged_vqa_vi.json +0 -0
- data/merged_vqa_vi_cleaned.json +0 -0
- data/translate_checkpoint.json +0 -0
- requirements.txt +60 -0
- src/__init__.py +1 -0
- src/data/__init__.py +1 -0
- src/data/medical_dataset.py +121 -0
- src/engine/__init__.py +1 -0
- src/engine/dpo_trainer.py +306 -0
- src/engine/medical_eval.py +731 -0
- src/engine/trainer.py +461 -0
- src/models/__init__.py +1 -0
- src/models/encoder.py +23 -0
- src/models/medical_vqa_model.py +100 -0
- src/models/multimodal_vqa.py +79 -0
- src/models/phobert_encoder.py +24 -0
- src/models/transformer_decoder.py +214 -0
- src/utils/__init__.py +1 -0
- src/utils/answer_rewriter.py +244 -0
- src/utils/discriminative_lr.py +152 -0
- src/utils/early_stopping.py +258 -0
- src/utils/evaluation_viz.py +194 -0
- src/utils/helpers.py +30 -0
- src/utils/metrics.py +192 -0
- src/utils/optimized_metrics.py +202 -0
- src/utils/text_utils.py +203 -0
- src/utils/translator.py +183 -0
- src/utils/visualization.py +53 -0
- web/README.md +96 -0
- web/main.py +978 -0
- web/static/index.html +656 -0
.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
*.pyd
|
| 5 |
+
.DS_Store
|
| 6 |
+
checkpoints/
|
| 7 |
+
logs/
|
| 8 |
+
.ipynb_checkpoints/
|
| 9 |
+
venv/
|
| 10 |
+
env/
|
Dockerfile
CHANGED
|
@@ -1,19 +1,42 @@
|
|
| 1 |
FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 4 |
-
python3
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
WORKDIR /app
|
| 8 |
|
| 9 |
-
COPY requirements.txt .
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
COPY . /app
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
ENV WEB_PRELOAD_MODELS=1
|
| 18 |
|
| 19 |
-
CMD ["uvicorn", "web.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
|
|
|
| 1 |
FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04
|
| 2 |
|
| 3 |
+
ENV DEBIAN_FRONTEND=noninteractive \
|
| 4 |
+
PYTHONUNBUFFERED=1 \
|
| 5 |
+
PIP_NO_CACHE_DIR=1 \
|
| 6 |
+
TOKENIZERS_PARALLELISM=false \
|
| 7 |
+
HF_HOME=/data/.huggingface \
|
| 8 |
+
HUGGINGFACE_HUB_CACHE=/data/.huggingface/hub \
|
| 9 |
+
TRANSFORMERS_CACHE=/data/.huggingface/transformers \
|
| 10 |
+
WEB_PRELOAD_MODELS=1
|
| 11 |
+
|
| 12 |
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 13 |
+
python3 \
|
| 14 |
+
python3-pip \
|
| 15 |
+
python3-dev \
|
| 16 |
+
build-essential \
|
| 17 |
+
git \
|
| 18 |
+
curl \
|
| 19 |
+
ca-certificates \
|
| 20 |
+
libgl1 \
|
| 21 |
+
libglib2.0-0 \
|
| 22 |
+
libsm6 \
|
| 23 |
+
libxext6 \
|
| 24 |
+
libxrender1 \
|
| 25 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 26 |
|
| 27 |
WORKDIR /app
|
| 28 |
|
| 29 |
+
COPY requirements.txt /app/requirements.txt
|
| 30 |
+
|
| 31 |
+
RUN python3 -m pip install --upgrade pip setuptools wheel && \
|
| 32 |
+
python3 -m pip install --index-url https://download.pytorch.org/whl/cu121 \
|
| 33 |
+
torch torchvision torchaudio && \
|
| 34 |
+
python3 -m pip install -r requirements.txt
|
| 35 |
|
| 36 |
COPY . /app
|
| 37 |
|
| 38 |
+
RUN mkdir -p /data/.huggingface
|
| 39 |
+
|
| 40 |
+
EXPOSE 7860
|
|
|
|
| 41 |
|
| 42 |
+
CMD ["python3", "-m", "uvicorn", "web.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
configs/medical_vqa.yaml
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ═══════════════════════════════════════════════════════════════════
|
| 2 |
+
# Medical VQA — SLAKE + VQA-RAD (Merged) Configuration
|
| 3 |
+
# Pipeline: Dictionary-Enhanced Translation → ResNet50+PhoBERT → A1/A2/B1/B2
|
| 4 |
+
# ═══════════════════════════════════════════════════════════════════
|
| 5 |
+
|
| 6 |
+
seed: 42
|
| 7 |
+
device: "auto" # auto | cuda | mps | cpu
|
| 8 |
+
log_dir: "logs/medical_vqa"
|
| 9 |
+
ckpt_dir: "checkpoints/medical_vqa"
|
| 10 |
+
|
| 11 |
+
# ── Data ──
|
| 12 |
+
data:
|
| 13 |
+
dataset_name: "slake_vqa_rad_merged"
|
| 14 |
+
hf_dataset: "SpringWang08/medical-vqa-vi" # [NEW] Dataset trên Hub
|
| 15 |
+
use_hf_splits: true
|
| 16 |
+
vqa_json: "data/merged_vqa_vi_cleaned.json"
|
| 17 |
+
image_dir: "data/images"
|
| 18 |
+
manual_test_json: "data/manual_test_set.json" # 50+ mẫu test thủ công (bắt buộc)
|
| 19 |
+
train_ratio: 0.80
|
| 20 |
+
val_ratio: 0.10
|
| 21 |
+
test_ratio: 0.10
|
| 22 |
+
image_size: 224
|
| 23 |
+
max_question_len: 64 # PhoBERT max tokens
|
| 24 |
+
max_answer_len: 20 # token budget cho đáp án ngắn <=10 từ
|
| 25 |
+
answer_max_words: 10
|
| 26 |
+
use_short_answer_targets: true
|
| 27 |
+
normalize_medical_terms: true
|
| 28 |
+
postprocess_answer: true
|
| 29 |
+
|
| 30 |
+
# ── Hướng A: Kiến trúc rời (Modular Architecture) ──
|
| 31 |
+
model_a:
|
| 32 |
+
image_encoder: "densenet121_xrv" # [FIX] Đổi từ resnet50 sang DenseNet-121 (XRV) y tế
|
| 33 |
+
text_encoder: "phobert" # vinai/phobert-base
|
| 34 |
+
phobert_model: "vinai/phobert-base"
|
| 35 |
+
freeze_phobert_layers: 10 # Đóng băng 10/12 lớp đầu
|
| 36 |
+
hidden_size: 768
|
| 37 |
+
num_decoder_layers: 2
|
| 38 |
+
dropout: 0.3
|
| 39 |
+
fusion: "co_attention" # [UPGRADE] concat → co_attention (Kim et al., NeurIPS 2018)
|
| 40 |
+
|
| 41 |
+
# Decoder configs (A1 = lstm, A2 = transformer)
|
| 42 |
+
decoder_type: "lstm" # "lstm" hoặc "transformer"
|
| 43 |
+
# Transformer Decoder specific (A2)
|
| 44 |
+
transformer_heads: 8
|
| 45 |
+
transformer_ff_dim: 3072 # 4 × hidden_size (768×4) — chuẩn Transformer gốc
|
| 46 |
+
transformer_decoder_layers: 3
|
| 47 |
+
transformer_norm_first: true # Pre-LN: giúp A2 hội tụ sớm và ổn định hơn
|
| 48 |
+
transformer_dropout: 0.1
|
| 49 |
+
|
| 50 |
+
# ── Hướng B: Multimodal Pretrained ──
|
| 51 |
+
model_b:
|
| 52 |
+
model_name: "chaoyinshe/llava-med-v1.5-mistral-7b-hf" # Bản HF compatible cho transformers
|
| 53 |
+
use_lora: true
|
| 54 |
+
lora_r: 16
|
| 55 |
+
lora_alpha: 32
|
| 56 |
+
lora_dropout: 0.05
|
| 57 |
+
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] # [OPTIMIZED] Thêm MLP layers giúp tăng mạnh khả năng suy luận y khoa
|
| 58 |
+
|
| 59 |
+
# ── DPO (Reinforcement Learning) ──
|
| 60 |
+
dpo:
|
| 61 |
+
preference_data: "data/preference_data_slake_v3.json"
|
| 62 |
+
beta: 0.01 # Refinement nhe tren B2, tranh keo model lech qua manh
|
| 63 |
+
num_epochs: 1
|
| 64 |
+
learning_rate: 1.0e-6
|
| 65 |
+
num_pairs: 400 # So cap vua du de alignment nhe
|
| 66 |
+
closed_ratio: 0.6 # 60% closed / 40% open cho muc tieu can bang tong the
|
| 67 |
+
max_answer_words: 6 # Preference data ngan, dung phan phoi short-answer VQA
|
| 68 |
+
force_rebuild_preference_data: true # Luon tao lai v3 de tranh xai nham file cu
|
| 69 |
+
eval_steps: 50 # [NEW] Đánh giá VQA mỗi 50 steps để track reward margin
|
| 70 |
+
save_best_only: true # Chỉ lưu checkpoint khi reward margin cải thiện
|
| 71 |
+
|
| 72 |
+
# ── PPO (Reinforcement Learning) ──
|
| 73 |
+
ppo:
|
| 74 |
+
learning_rate: 5.0e-7 # PPO refinement phai nhe hon DPO de tranh lam xau B2
|
| 75 |
+
num_samples: 192 # Rollout subset nho de giu chi phi train hop ly
|
| 76 |
+
closed_ratio: 0.5 # PPO danh cho ca closed va open
|
| 77 |
+
rollout_batch_size: 2 # Batch nho de tranh OOM khi generate + update
|
| 78 |
+
clip_range: 0.2
|
| 79 |
+
entropy_coef: 0.001
|
| 80 |
+
temperature: 0.8
|
| 81 |
+
top_p: 0.9
|
| 82 |
+
max_new_tokens: 12
|
| 83 |
+
max_answer_words: 6
|
| 84 |
+
closed_positive_reward: 1.0
|
| 85 |
+
closed_negative_reward: -1.0
|
| 86 |
+
train_mlp_lora: false
|
| 87 |
+
weight_decay: 0.0
|
| 88 |
+
|
| 89 |
+
# ── Training ──
|
| 90 |
+
train:
|
| 91 |
+
a_epochs: 50 # A1 + A2 (Direction A) model epochs
|
| 92 |
+
epochs: 5 # B2 fine-tuned model epochs
|
| 93 |
+
dpo_epochs: 1 # DPO refinement nhe, 1 epoch duy nhat
|
| 94 |
+
batch_size: 32 # [OPTIMIZED FOR A1/A2]
|
| 95 |
+
b2_batch_size: 1 # Micro-batch cho LLaVA-Med; tăng effective batch bằng gradient accumulation
|
| 96 |
+
b2_gradient_accumulation_steps: 8 # Effective batch = 8, an toàn hơn nhiều cho 7B multimodal
|
| 97 |
+
b2_eval_batch_size: 1 # Eval của LLaVA-Med cũng dễ OOM nếu để lớn
|
| 98 |
+
b2_max_length: 128 # Giới hạn token khi SFT để giảm memory footprint
|
| 99 |
+
b2_warmup_steps: 50 # Thay cho warmup_ratio đã deprecated
|
| 100 |
+
b2_optim: "paged_adamw_8bit" # Optimizer tiết kiệm VRAM cho QLoRA
|
| 101 |
+
b2_num_workers: 4 # Vừa đủ cho multimodal collator, tránh overhead
|
| 102 |
+
dpo_batch_size: 1 # [OOM FIX] DPO trên GPU 24GB nên dùng micro-batch = 1
|
| 103 |
+
dpo_gradient_accumulation_steps: 16 # Peak VRAM không đổi, nhưng effective batch vẫn đủ ổn định
|
| 104 |
+
dpo_max_length: 768 # [OOM FIX] LLaVA can chen du image tokens; < 576 se gay mismatch
|
| 105 |
+
dpo_max_prompt_length: 640 # [OOM FIX] Cho phep giu tron image tokens + prompt ngan
|
| 106 |
+
dpo_max_completion_length: 12 # [OOM FIX] Completion chỉ vài từ là đủ cho VQA
|
| 107 |
+
dpo_optim: "paged_adamw_8bit" # [OOM FIX] Optimizer tiết kiệm bộ nhớ
|
| 108 |
+
dpo_train_mlp_lora: false # [OOM FIX] Chỉ train attention LoRA, freeze MLP LoRA để tránh OOM
|
| 109 |
+
eval_batch_size: 16 # [OPTIMIZED FOR 4090] Tăng tốc độ Evaluation
|
| 110 |
+
learning_rate: 3.0e-4 # Cho Hướng A
|
| 111 |
+
b2_lr: 2.0e-5 # Cho B2 (SFT) - chuẩn cho LLaVA
|
| 112 |
+
phobert_lr: 1.0e-5
|
| 113 |
+
vision_lr: 1.0e-5
|
| 114 |
+
label_smoothing: 0.1
|
| 115 |
+
grad_clip: 5.0
|
| 116 |
+
patience: 10
|
| 117 |
+
warmup_epochs: 3 # [TUNED] 5→3: dataset vừa, warmup dài làm chậm hội tụ
|
| 118 |
+
scheduler: "cosine"
|
| 119 |
+
eta_min: 1.0e-6
|
| 120 |
+
eval_every: 2
|
| 121 |
+
open_loss_weight: 2.0 # [TUNED] 3.0→2.0: giảm dominance của open-head
|
| 122 |
+
use_amp: true # Luôn bật Mixed Precision cho T4
|
| 123 |
+
gradient_accumulation_steps: 2 # [OPTIMIZATION] Effective batch = 32*2 = 64
|
| 124 |
+
use_discriminative_lr: true # [OPTIMIZATION] Different LR for different layers
|
| 125 |
+
use_dynamic_class_weights: true # [OPTIMIZATION] Compute class weights from data
|
| 126 |
+
num_workers: 16 # [OPTIMIZED FOR VAST.AI] Tăng từ 2 lên 16 nhờ CPU 36 cores
|
| 127 |
+
pin_memory: true # Tăng tốc truyền dữ liệu lên GPU
|
| 128 |
+
# B2 SFT specific
|
| 129 |
+
b2_load_best_model_at_end: true # [FIX B2] Dùng best checkpoint (epoch 4), không dùng epoch cuối
|
| 130 |
+
b2_metric_for_best: "eval_loss" # B2 early stop theo eval_loss (đáy tại epoch 4)
|
| 131 |
+
|
| 132 |
+
# ── Evaluation ──
|
| 133 |
+
eval:
|
| 134 |
+
beam_width_a: 5 # Dành cho Hướng A (Nhanh & Chất lượng)
|
| 135 |
+
beam_width_b: 5 # Dùng cho open-ended của Hướng B khi cần so công bằng với Hướng A
|
| 136 |
+
beam_width_b_closed: 1 # Closed questions của Hướng B không cần beam lớn
|
| 137 |
+
beam_width_b_open: 5 # Open questions của Hướng B dùng beam lớn hơn để tăng semantic match
|
| 138 |
+
max_new_tokens_b_closed: 4 # Giữ câu đóng rất ngắn
|
| 139 |
+
max_new_tokens_b_open: 16 # Đủ biên cho postprocess cắt về <=10 từ
|
| 140 |
+
generation_batch_size_b: 1 # Micro-batch cho generation của Hướng B để tránh OOM khi beam > 1
|
| 141 |
+
metrics: ["vqa_accuracy", "bleu", "rouge_l", "meteor", "bertscore"]
|
| 142 |
+
llm_judge: true
|
| 143 |
+
llm_judge_model: "gemini-1.5-flash" # API model for LLM-as-a-judge
|
| 144 |
+
|
| 145 |
+
# ── Model Variants (Bắt buộc 4 cấu hình) ──
|
| 146 |
+
model_variants:
|
| 147 |
+
A1_LSTM_Decoder:
|
| 148 |
+
direction: "A"
|
| 149 |
+
decoder_type: "lstm"
|
| 150 |
+
description: "Hướng A với LSTM decoder"
|
| 151 |
+
A2_Transformer_Decoder:
|
| 152 |
+
direction: "A"
|
| 153 |
+
decoder_type: "transformer"
|
| 154 |
+
description: "Hướng A với Transformer decoder"
|
| 155 |
+
B1_ZeroShot:
|
| 156 |
+
direction: "B"
|
| 157 |
+
fine_tuned: false
|
| 158 |
+
description: "Hướng B zero-shot"
|
| 159 |
+
B2_FineTuned:
|
| 160 |
+
direction: "B"
|
| 161 |
+
fine_tuned: true
|
| 162 |
+
description: "Hướng B fine-tuned"
|
| 163 |
+
|
| 164 |
+
# ── WandB ──
|
| 165 |
+
wandb:
|
| 166 |
+
project: "MedicalVQA-Vietnam" # Tên project trên wandb.ai
|
| 167 |
+
entity: "" # Để trống = dùng account đang login (vxq123)
|
| 168 |
+
group: "DL-Final-523H0173-523H0178" # Nhóm tất cả 5 variant trong 1 group
|
| 169 |
+
job_type: "train" # train | eval | debug
|
| 170 |
+
# Tags tự động gắn cho từng variant
|
| 171 |
+
tags:
|
| 172 |
+
A1: ["lstm", "modular", "direction-A", "phobert", "densenet"]
|
| 173 |
+
A2: ["transformer", "pre-ln", "weight-tying", "direction-A", "phobert"]
|
| 174 |
+
B1: ["zero-shot", "llava-med", "direction-B", "few-shot-prompt"]
|
| 175 |
+
B2: ["sft", "lora", "fine-tuned", "direction-B", "llava-med"]
|
| 176 |
+
DPO: ["dpo", "rlhf", "preference", "direction-B"]
|
| 177 |
+
PPO: ["ppo", "reinforcement-learning", "direction-B", "llava-med"]
|
| 178 |
+
notes: "Medical VQA Vietnamese — SLAKE + VQA-RAD merged dataset"
|
| 179 |
+
log_gradients: false # true = log gradient histogram (chậm hơn ~10%)
|
| 180 |
+
log_freq: 50 # Log mỗi N batch
|
| 181 |
+
watch_model: false # true = theo dõi weights (tốn bandwidth)
|
| 182 |
+
save_code: true # Upload source code lên WandB
|
| 183 |
+
# Offline mode khi không có Internet (sync sau bằng wandb sync)
|
| 184 |
+
offline: false # đặt true khi train trên Vast.ai không có internet
|
data/.gitkeep
ADDED
|
File without changes
|
data/images/__MACOSX
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
/Users/springwang/.cache/huggingface/hub/datasets--BoKelvin--SLAKE/snapshots/a9083ce6c34ac3ffb17671a605962924d8a8f9e9/unzipped_imgs/__MACOSX
|
data/images/imgs
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
/Users/springwang/.cache/huggingface/hub/datasets--BoKelvin--SLAKE/snapshots/a9083ce6c34ac3ffb17671a605962924d8a8f9e9/unzipped_imgs/imgs
|
data/judge_results.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/medical_dict.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"chest": "ngực",
|
| 3 |
+
"lung": "phổi",
|
| 4 |
+
"heart": "tim",
|
| 5 |
+
"brain": "não",
|
| 6 |
+
"liver": "gan",
|
| 7 |
+
"kidney": "thận",
|
| 8 |
+
"bone": "xương",
|
| 9 |
+
"spine": "cột sống",
|
| 10 |
+
"joint": "khớp",
|
| 11 |
+
"abnormal": "bất thường",
|
| 12 |
+
"normal": "bình thường",
|
| 13 |
+
"mass": "khối u",
|
| 14 |
+
"tumor": "khối u",
|
| 15 |
+
"cyst": "nang",
|
| 16 |
+
"fracture": "gãy xương",
|
| 17 |
+
"effusion": "tràn dịch",
|
| 18 |
+
"infiltration": "thâm nhiễm",
|
| 19 |
+
"pneumonia": "viêm phổi",
|
| 20 |
+
"cardiomegaly": "tim to",
|
| 21 |
+
"ct scan": "chụp cắt lớp vi tính (CT)",
|
| 22 |
+
"computed tomography": "chụp cắt lớp vi tính (CT)",
|
| 23 |
+
"mri": "chụp cộng hưởng từ (MRI)",
|
| 24 |
+
"x-ray": "chụp X-quang",
|
| 25 |
+
"ultrasound": "siêu âm",
|
| 26 |
+
"abdomen": "bụng",
|
| 27 |
+
"pelvis": "chậu",
|
| 28 |
+
"head": "đầu",
|
| 29 |
+
"shoulder": "vai",
|
| 30 |
+
"knee": "gối",
|
| 31 |
+
"hand": "tay",
|
| 32 |
+
"foot": "chân"
|
| 33 |
+
}
|
data/medical_dict_vn.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
404: Not Found
|
data/merged_vqa_vi.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/merged_vqa_vi_cleaned.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/translate_checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 2 |
+
# Medical VQA — requirements.txt
|
| 3 |
+
# Python 3.10+ | CUDA 11.8+ | Tested on: RTX 4090, T4, A100
|
| 4 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 5 |
+
|
| 6 |
+
# ── Deep Learning Core ───────────────────────────────────────────────────
|
| 7 |
+
torch>=2.1.0
|
| 8 |
+
torchvision>=0.16.0
|
| 9 |
+
torchaudio>=2.1.0 # cần cho một số HF pipeline
|
| 10 |
+
|
| 11 |
+
# ── Medical Imaging ──────────────────────────────────────────────────────
|
| 12 |
+
torchxrayvision>=1.1.0 # DenseNet-121 XRV pretrained weights
|
| 13 |
+
opencv-python-headless>=4.8.0 # CLAHE preprocessing (headless = không cần GUI)
|
| 14 |
+
Pillow>=10.0.0
|
| 15 |
+
|
| 16 |
+
# ── HuggingFace Ecosystem ────────────────────────────────────────────────
|
| 17 |
+
transformers>=4.38.0 # PhoBERT, LLaVA-Med, SFTTrainer
|
| 18 |
+
huggingface_hub>=0.20.0
|
| 19 |
+
datasets>=2.18.0
|
| 20 |
+
tokenizers>=0.15.0
|
| 21 |
+
accelerate>=0.27.0 # Mixed precision, device mapping
|
| 22 |
+
sentencepiece>=0.1.99 # PhoBERT tokenizer
|
| 23 |
+
peft>=0.9.0 # LoRA cho LLaVA-Med (B2)
|
| 24 |
+
trl>=0.8.1 # SFTTrainer + DPOTrainer
|
| 25 |
+
|
| 26 |
+
# ── Quantization (B2 / DPO 4-bit) ───────────────────────────────────────
|
| 27 |
+
bitsandbytes>=0.43.0 # 4-bit quantization cho LLaVA-Med
|
| 28 |
+
|
| 29 |
+
# ── Vietnamese NLP ───────────────────────────────────────────────────────
|
| 30 |
+
underthesea>=6.8.0 # Tokenization tiếng Việt
|
| 31 |
+
|
| 32 |
+
# ── NLP Metrics ──────────────────────────────────────────────────────────
|
| 33 |
+
nltk>=3.8.1
|
| 34 |
+
bert-score>=0.3.13
|
| 35 |
+
rouge-score>=0.1.2
|
| 36 |
+
evaluate>=0.4.1 # HuggingFace evaluate (BLEU, METEOR)
|
| 37 |
+
|
| 38 |
+
# ── Data & Scientific ────────────────────────────────────────────────────
|
| 39 |
+
numpy>=1.26.0
|
| 40 |
+
pandas>=2.1.0
|
| 41 |
+
scikit-learn>=1.4.0
|
| 42 |
+
scipy>=1.12.0
|
| 43 |
+
|
| 44 |
+
# ── Visualization ────────────────────────────────────────────────────────
|
| 45 |
+
matplotlib>=3.8.0
|
| 46 |
+
seaborn>=0.13.0
|
| 47 |
+
|
| 48 |
+
# ── Experiment Tracking ──────────────────────────────────────────────────
|
| 49 |
+
wandb>=0.16.0
|
| 50 |
+
|
| 51 |
+
# ── Config & Utilities ───────────────────────────────────────────────────
|
| 52 |
+
pyyaml>=6.0.1
|
| 53 |
+
python-dotenv>=1.0.1
|
| 54 |
+
tqdm>=4.66.0
|
| 55 |
+
requests>=2.31.0
|
| 56 |
+
|
| 57 |
+
# ── Jupyter (local development) ──────────────────────────────────────────
|
| 58 |
+
jupyter>=1.0.0
|
| 59 |
+
ipython>=8.22.0
|
| 60 |
+
ipywidgets>=8.1.0
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Initialize src package
|
src/data/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Initialize src.data package
|
src/data/medical_dataset.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from src.utils.text_utils import get_target_answer, normalize_answer, text_normalize
|
| 7 |
+
|
| 8 |
+
class MedicalVQADataset(Dataset):
|
| 9 |
+
"""
|
| 10 |
+
Dataset class chung cho Medical VQA (SLAKE + VQA-RAD).
|
| 11 |
+
"""
|
| 12 |
+
def __init__(self, hf_dataset=None, json_path=None, image_dir=None, tokenizer=None, transform=None, max_seq_len=64, max_ans_len=10, is_dpo=False, in_channels=1, answer_max_words=10):
|
| 13 |
+
if hf_dataset is not None:
|
| 14 |
+
self.data = hf_dataset
|
| 15 |
+
self.use_hf = True
|
| 16 |
+
elif json_path is not None:
|
| 17 |
+
with open(json_path, "r", encoding="utf-8") as f:
|
| 18 |
+
self.data = json.load(f)
|
| 19 |
+
self.use_hf = False
|
| 20 |
+
else:
|
| 21 |
+
raise ValueError("Phải cung cấp hf_dataset hoặc json_path!")
|
| 22 |
+
|
| 23 |
+
self.image_dir = image_dir
|
| 24 |
+
self.tokenizer = tokenizer
|
| 25 |
+
self.transform = transform
|
| 26 |
+
self.max_seq_len = max_seq_len
|
| 27 |
+
self.max_ans_len = max_ans_len
|
| 28 |
+
self.is_dpo = is_dpo
|
| 29 |
+
self.in_channels = in_channels
|
| 30 |
+
self.answer_max_words = answer_max_words
|
| 31 |
+
|
| 32 |
+
# Mapping for closed questions (Yes/No)
|
| 33 |
+
self.label_map = {"no": 0, "yes": 1, "không": 0, "có": 1}
|
| 34 |
+
|
| 35 |
+
def __len__(self):
|
| 36 |
+
return len(self.data)
|
| 37 |
+
|
| 38 |
+
def __getitem__(self, idx):
|
| 39 |
+
item = self.data[idx]
|
| 40 |
+
|
| 41 |
+
# 1. Xử lý ảnh
|
| 42 |
+
if self.use_hf:
|
| 43 |
+
image = item["image"]
|
| 44 |
+
if self.in_channels == 1:
|
| 45 |
+
if image.mode != "L": image = image.convert("L")
|
| 46 |
+
else:
|
| 47 |
+
if image.mode != "RGB": image = image.convert("RGB")
|
| 48 |
+
else:
|
| 49 |
+
# DPO preference data might use 'image' or 'image_name'
|
| 50 |
+
img_name = item.get("image_name") or item.get("image")
|
| 51 |
+
img_path = os.path.join(self.image_dir, img_name)
|
| 52 |
+
mode = "L" if self.in_channels == 1 else "RGB"
|
| 53 |
+
image = Image.open(img_path).convert(mode)
|
| 54 |
+
|
| 55 |
+
raw_image = image # Bản lưu trữ cho Multimodal Processor (chưa Normalize)
|
| 56 |
+
|
| 57 |
+
if self.transform:
|
| 58 |
+
image = self.transform(image)
|
| 59 |
+
else:
|
| 60 |
+
from torchvision import transforms
|
| 61 |
+
image = transforms.ToTensor()(image)
|
| 62 |
+
|
| 63 |
+
# 2. Xử lý câu hỏi
|
| 64 |
+
q_key = "question" if self.is_dpo else "question_vi"
|
| 65 |
+
raw_question = item[q_key]
|
| 66 |
+
raw_question_en = item.get("question", raw_question) # Lấy bản tiếng Anh nếu có
|
| 67 |
+
|
| 68 |
+
question = text_normalize(raw_question)
|
| 69 |
+
encoding = self.tokenizer(
|
| 70 |
+
question,
|
| 71 |
+
padding="max_length",
|
| 72 |
+
truncation=True,
|
| 73 |
+
max_length=self.max_seq_len,
|
| 74 |
+
return_tensors="pt"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
if self.is_dpo:
|
| 78 |
+
# 3. Xử lý DPO Preference (Chosen vs Rejected)
|
| 79 |
+
chosen_ans = normalize_answer(item["chosen"])
|
| 80 |
+
rejected_ans = normalize_answer(item["rejected"])
|
| 81 |
+
|
| 82 |
+
chosen_encoding = self.tokenizer(chosen_ans, padding="max_length", truncation=True, max_length=self.max_ans_len, return_tensors="pt")
|
| 83 |
+
rejected_encoding = self.tokenizer(rejected_ans, padding="max_length", truncation=True, max_length=self.max_ans_len, return_tensors="pt")
|
| 84 |
+
|
| 85 |
+
return {
|
| 86 |
+
"image": image,
|
| 87 |
+
"raw_image": raw_image,
|
| 88 |
+
"raw_questions": raw_question,
|
| 89 |
+
"raw_questions_en": raw_question_en,
|
| 90 |
+
"input_ids": encoding["input_ids"].flatten(),
|
| 91 |
+
"attention_mask": encoding["attention_mask"].flatten(),
|
| 92 |
+
"chosen_ids": chosen_encoding["input_ids"].flatten(),
|
| 93 |
+
"rejected_ids": rejected_encoding["input_ids"].flatten(),
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
# 3. Xử lý câu trả lời chuẩn (Non-DPO)
|
| 97 |
+
answer = get_target_answer(item, max_words=self.answer_max_words)
|
| 98 |
+
answer_en = normalize_answer(item.get("answer", answer)) # Lấy bản tiếng Anh nếu có
|
| 99 |
+
label_closed = self.label_map.get(answer, -1)
|
| 100 |
+
|
| 101 |
+
ans_encoding = self.tokenizer(
|
| 102 |
+
answer,
|
| 103 |
+
padding="max_length",
|
| 104 |
+
truncation=True,
|
| 105 |
+
max_length=self.max_ans_len,
|
| 106 |
+
return_tensors="pt"
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
return {
|
| 110 |
+
"image": image,
|
| 111 |
+
"raw_image": raw_image,
|
| 112 |
+
"raw_questions": raw_question,
|
| 113 |
+
"raw_questions_en": raw_question_en,
|
| 114 |
+
"input_ids": encoding["input_ids"].flatten(),
|
| 115 |
+
"attention_mask": encoding["attention_mask"].flatten(),
|
| 116 |
+
"label_closed": torch.tensor(label_closed, dtype=torch.long),
|
| 117 |
+
"target_ids": ans_encoding["input_ids"].flatten(),
|
| 118 |
+
"raw_answer": answer,
|
| 119 |
+
"raw_answer_full": normalize_answer(item.get("answer_full_vi", answer)),
|
| 120 |
+
"raw_answer_en": answer_en
|
| 121 |
+
}
|
src/engine/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Initialize src.engine package
|
src/engine/dpo_trainer.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
|
| 8 |
+
from src.utils.text_utils import get_target_answer, normalize_answer
|
| 9 |
+
|
| 10 |
+
def _is_closed_question(question: str, answer: str) -> bool:
|
| 11 |
+
q = normalize_answer(question)
|
| 12 |
+
a = normalize_answer(answer)
|
| 13 |
+
return (
|
| 14 |
+
a in {"có", "không"}
|
| 15 |
+
or q.endswith(" không")
|
| 16 |
+
or " bình thường " in f" {q} "
|
| 17 |
+
or " có " in f" {q} "
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _flip_closed_answer(answer: str) -> str:
|
| 22 |
+
a = normalize_answer(answer)
|
| 23 |
+
if a == "có":
|
| 24 |
+
return "không"
|
| 25 |
+
if a == "không":
|
| 26 |
+
return "có"
|
| 27 |
+
return a
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _answer_category(question: str, answer: str) -> str:
|
| 31 |
+
q = normalize_answer(question)
|
| 32 |
+
a = normalize_answer(answer)
|
| 33 |
+
if _is_closed_question(question, answer):
|
| 34 |
+
return "closed"
|
| 35 |
+
if any(term in q for term in ["ở đâu", "vi tri", "where"]):
|
| 36 |
+
return "location"
|
| 37 |
+
if any(term in a for term in ["trái", "phải", "trên", "dưới", "giữa", "bên"]):
|
| 38 |
+
return "location"
|
| 39 |
+
if any(term in a for term in ["mặt phẳng", "ngang", "vành", "dọc"]):
|
| 40 |
+
return "plane"
|
| 41 |
+
if any(term in a for term in ["gan", "phổi", "tim", "não", "thận", "lách", "bàng quang", "khí quản", "trung thất"]):
|
| 42 |
+
return "organ"
|
| 43 |
+
return "finding"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _build_answer_pools(data: list[dict], max_words: int) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
|
| 47 |
+
question_to_answers = {}
|
| 48 |
+
category_to_answers = {}
|
| 49 |
+
for item in data:
|
| 50 |
+
question = item.get("question_vi", item.get("question", ""))
|
| 51 |
+
answer = get_target_answer(item, max_words=max_words)
|
| 52 |
+
if not question or not answer:
|
| 53 |
+
continue
|
| 54 |
+
q_norm = normalize_answer(question)
|
| 55 |
+
a_norm = normalize_answer(answer)
|
| 56 |
+
category = _answer_category(question, answer)
|
| 57 |
+
question_to_answers.setdefault(q_norm, [])
|
| 58 |
+
if a_norm not in question_to_answers[q_norm]:
|
| 59 |
+
question_to_answers[q_norm].append(a_norm)
|
| 60 |
+
category_to_answers.setdefault(category, [])
|
| 61 |
+
if a_norm not in category_to_answers[category]:
|
| 62 |
+
category_to_answers[category].append(a_norm)
|
| 63 |
+
return question_to_answers, category_to_answers
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _build_rejected_candidates(
|
| 67 |
+
data: list[dict],
|
| 68 |
+
idx: int,
|
| 69 |
+
chosen: str,
|
| 70 |
+
question_to_answers: dict[str, list[str]],
|
| 71 |
+
category_to_answers: dict[str, list[str]],
|
| 72 |
+
) -> list[str]:
|
| 73 |
+
item = data[idx]
|
| 74 |
+
question = item.get("question_vi", item.get("question", ""))
|
| 75 |
+
question_norm = normalize_answer(question)
|
| 76 |
+
chosen_norm = normalize_answer(chosen)
|
| 77 |
+
category = _answer_category(question, chosen)
|
| 78 |
+
candidates = []
|
| 79 |
+
|
| 80 |
+
if _is_closed_question(question, chosen):
|
| 81 |
+
flipped = _flip_closed_answer(chosen)
|
| 82 |
+
if flipped and flipped != chosen_norm:
|
| 83 |
+
candidates.append(flipped)
|
| 84 |
+
else:
|
| 85 |
+
for answer in question_to_answers.get(question_norm, []):
|
| 86 |
+
if answer != chosen_norm:
|
| 87 |
+
candidates.append(answer)
|
| 88 |
+
for answer in category_to_answers.get(category, []):
|
| 89 |
+
if answer != chosen_norm:
|
| 90 |
+
candidates.append(answer)
|
| 91 |
+
deduped = []
|
| 92 |
+
seen = set()
|
| 93 |
+
for candidate in candidates:
|
| 94 |
+
candidate_norm = normalize_answer(candidate)
|
| 95 |
+
if not candidate_norm or candidate_norm == chosen_norm or candidate_norm in seen:
|
| 96 |
+
continue
|
| 97 |
+
seen.add(candidate_norm)
|
| 98 |
+
deduped.append(candidate_norm)
|
| 99 |
+
return deduped
|
| 100 |
+
|
| 101 |
+
def _build_pair_record(item: dict, source_idx: int, chosen: str, rejected: str) -> dict:
|
| 102 |
+
return {
|
| 103 |
+
"image": item.get("image_name") or item.get("image"),
|
| 104 |
+
"source_idx": source_idx,
|
| 105 |
+
"question": item["question_vi"],
|
| 106 |
+
"chosen": chosen,
|
| 107 |
+
"rejected": rejected,
|
| 108 |
+
"answer_type": _answer_category(item["question_vi"], chosen),
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _round_robin_merge(grouped_pairs: dict[str, list[dict]], target_count: int) -> list[dict]:
|
| 113 |
+
ordered_groups = sorted(grouped_pairs.keys())
|
| 114 |
+
merged = []
|
| 115 |
+
while len(merged) < target_count:
|
| 116 |
+
progressed = False
|
| 117 |
+
for group in ordered_groups:
|
| 118 |
+
if grouped_pairs[group]:
|
| 119 |
+
merged.append(grouped_pairs[group].pop())
|
| 120 |
+
progressed = True
|
| 121 |
+
if len(merged) >= target_count:
|
| 122 |
+
break
|
| 123 |
+
if not progressed:
|
| 124 |
+
break
|
| 125 |
+
return merged
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def create_preference_data(
|
| 129 |
+
vqa_json_path,
|
| 130 |
+
output_path,
|
| 131 |
+
num_pairs=400,
|
| 132 |
+
closed_ratio=0.6,
|
| 133 |
+
max_answer_words=6,
|
| 134 |
+
seed=42,
|
| 135 |
+
):
|
| 136 |
+
"""
|
| 137 |
+
Tạo dữ liệu Preference (Chosen vs Rejected) cho DPO.
|
| 138 |
+
Trong Medical VQA, 'Rejected' thường là các câu trả lời bị hallucination hoặc sai thuật ngữ y khoa.
|
| 139 |
+
"""
|
| 140 |
+
with open(vqa_json_path, 'r', encoding='utf-8') as f:
|
| 141 |
+
data = json.load(f)
|
| 142 |
+
|
| 143 |
+
question_to_answers, category_to_answers = _build_answer_pools(data, max_words=max_answer_words)
|
| 144 |
+
rng = random.Random(seed)
|
| 145 |
+
closed_pairs = []
|
| 146 |
+
open_pairs_by_group = {"location": [], "plane": [], "organ": [], "finding": []}
|
| 147 |
+
|
| 148 |
+
for i in range(len(data)):
|
| 149 |
+
item = data[i]
|
| 150 |
+
chosen = get_target_answer(item, max_words=max_answer_words)
|
| 151 |
+
chosen_norm = normalize_answer(chosen)
|
| 152 |
+
if not chosen_norm or len(chosen_norm.split()) > max_answer_words:
|
| 153 |
+
continue
|
| 154 |
+
rejected_candidates = _build_rejected_candidates(
|
| 155 |
+
data,
|
| 156 |
+
i,
|
| 157 |
+
chosen_norm,
|
| 158 |
+
question_to_answers=question_to_answers,
|
| 159 |
+
category_to_answers=category_to_answers,
|
| 160 |
+
)
|
| 161 |
+
category = _answer_category(item["question_vi"], chosen_norm)
|
| 162 |
+
|
| 163 |
+
for rejected in rejected_candidates:
|
| 164 |
+
if len(rejected.split()) > max_answer_words:
|
| 165 |
+
continue
|
| 166 |
+
pair = _build_pair_record(item, i, chosen_norm, rejected)
|
| 167 |
+
if category == "closed":
|
| 168 |
+
closed_pairs.append(pair)
|
| 169 |
+
elif category in open_pairs_by_group:
|
| 170 |
+
open_pairs_by_group[category].append(pair)
|
| 171 |
+
|
| 172 |
+
rng.shuffle(closed_pairs)
|
| 173 |
+
for pairs in open_pairs_by_group.values():
|
| 174 |
+
rng.shuffle(pairs)
|
| 175 |
+
|
| 176 |
+
target_closed = min(len(closed_pairs), int(round(num_pairs * closed_ratio)))
|
| 177 |
+
target_open = max(0, num_pairs - target_closed)
|
| 178 |
+
sampled_closed = closed_pairs[:target_closed]
|
| 179 |
+
sampled_open = _round_robin_merge(open_pairs_by_group, target_open)
|
| 180 |
+
pref_data = sampled_closed + sampled_open
|
| 181 |
+
rng.shuffle(pref_data)
|
| 182 |
+
|
| 183 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
| 184 |
+
json.dump(pref_data, f, ensure_ascii=False, indent=2)
|
| 185 |
+
|
| 186 |
+
print(
|
| 187 |
+
f"[SUCCESS] Đã tạo {len(pref_data)} cặp preference dữ liệu tại {output_path} "
|
| 188 |
+
f"(closed={len(sampled_closed)}, open={len(sampled_open)})"
|
| 189 |
+
)
|
| 190 |
+
preview_count = min(30, len(pref_data))
|
| 191 |
+
if preview_count:
|
| 192 |
+
print(f"[INFO] Preview {preview_count} cặp preference đầu tiên để kiểm tra nhanh:")
|
| 193 |
+
for idx, pair in enumerate(pref_data[:preview_count], start=1):
|
| 194 |
+
print(
|
| 195 |
+
f" [{idx:02d}] type={pair.get('answer_type')} | "
|
| 196 |
+
f"Q={pair.get('question')} | chosen={pair.get('chosen')} | rejected={pair.get('rejected')}"
|
| 197 |
+
)
|
| 198 |
+
return pref_data
|
| 199 |
+
|
| 200 |
+
class MedicalDPOTrainer:
|
| 201 |
+
"""
|
| 202 |
+
Trainer cho Direct Preference Optimization (DPO) trên LLaVA-Med.
|
| 203 |
+
Giúp tối ưu hóa mô hình dựa trên các cặp preference dữ liệu y tế.
|
| 204 |
+
"""
|
| 205 |
+
def __init__(self, model, reference_model, train_loader, optimizer, device, config):
|
| 206 |
+
self.model = model
|
| 207 |
+
self.reference_model = reference_model
|
| 208 |
+
self.train_loader = train_loader
|
| 209 |
+
self.optimizer = optimizer
|
| 210 |
+
self.device = device
|
| 211 |
+
self.config = config
|
| 212 |
+
self.beta = config.get('dpo_beta', 0.1)
|
| 213 |
+
|
| 214 |
+
def get_log_probs(self, logits, labels):
|
| 215 |
+
"""
|
| 216 |
+
Tính log probabilities cho các sequence.
|
| 217 |
+
logits: [batch, seq_len, vocab]
|
| 218 |
+
labels: [batch, seq_len]
|
| 219 |
+
"""
|
| 220 |
+
# Shift logits và labels để khớp (next token prediction)
|
| 221 |
+
log_probs = F.log_softmax(logits, dim=-1)
|
| 222 |
+
# Lấy log prob của các token đúng
|
| 223 |
+
per_token_logps = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(2)
|
| 224 |
+
# Chỉ lấy các token không phải padding (giả định mask > 0)
|
| 225 |
+
return (per_token_logps * (labels != 0)).sum(-1)
|
| 226 |
+
|
| 227 |
+
def compute_loss(self, policy_chosen_logps, policy_rejected_logps,
|
| 228 |
+
reference_chosen_logps, reference_rejected_logps):
|
| 229 |
+
"""
|
| 230 |
+
Tính DPO loss theo công thức: -log(sigmoid(beta * (log_ratio_chosen - log_ratio_rejected)))
|
| 231 |
+
"""
|
| 232 |
+
pi_logratios = policy_chosen_logps - policy_rejected_logps
|
| 233 |
+
ref_logratios = reference_chosen_logps - reference_rejected_logps
|
| 234 |
+
|
| 235 |
+
logits = pi_logratios - ref_logratios
|
| 236 |
+
loss = -F.logsigmoid(self.beta * logits).mean()
|
| 237 |
+
|
| 238 |
+
# Thêm các chỉ số để theo dõi (rewards)
|
| 239 |
+
chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
|
| 240 |
+
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
|
| 241 |
+
|
| 242 |
+
return loss, chosen_rewards, rejected_rewards
|
| 243 |
+
|
| 244 |
+
def train(self, epochs=3):
|
| 245 |
+
print(f"[INFO] Bắt đầu huấn luyện DPO (beta={self.beta})...")
|
| 246 |
+
self.model.train()
|
| 247 |
+
self.reference_model.eval()
|
| 248 |
+
# Freeze reference model để tiết kiệm VRAM (Quan trọng cho T4)
|
| 249 |
+
for param in self.reference_model.parameters():
|
| 250 |
+
param.requires_grad_(False)
|
| 251 |
+
|
| 252 |
+
print(f"[INFO] DPO Trainer Ready ({self.device})")
|
| 253 |
+
|
| 254 |
+
for epoch in range(epochs):
|
| 255 |
+
self.model.train()
|
| 256 |
+
total_loss = 0.0 # Đã thêm dòng khởi tạo total_loss tại đây
|
| 257 |
+
pbar = tqdm(self.train_loader, desc=f"DPO Epoch {epoch+1}")
|
| 258 |
+
|
| 259 |
+
for batch in pbar:
|
| 260 |
+
images = batch['image'].to(self.device)
|
| 261 |
+
chosen_ids = batch['chosen_ids'].to(self.device)
|
| 262 |
+
rejected_ids = batch['rejected_ids'].to(self.device)
|
| 263 |
+
|
| 264 |
+
# Tính Logits cho Chosen và Rejected (Sử dụng Duck Typing/Safe Forward)
|
| 265 |
+
try:
|
| 266 |
+
# Case: LLaVA-style multimodal model
|
| 267 |
+
outputs_w = self.model(input_ids=chosen_ids, pixel_values=images, labels=chosen_ids)
|
| 268 |
+
outputs_l = self.model(input_ids=rejected_ids, pixel_values=images, labels=rejected_ids)
|
| 269 |
+
logits_w = outputs_w.logits
|
| 270 |
+
logits_l = outputs_l.logits
|
| 271 |
+
except Exception:
|
| 272 |
+
# Fallback: Modular model (A1/A2 style)
|
| 273 |
+
_, logits_w = self.model(images, chosen_ids)
|
| 274 |
+
_, logits_l = self.model(images, rejected_ids)
|
| 275 |
+
|
| 276 |
+
# 2. Forward Reference Model (No Grad)
|
| 277 |
+
with torch.no_grad():
|
| 278 |
+
try:
|
| 279 |
+
# Multimodal case
|
| 280 |
+
outputs_ref_w = self.reference_model(input_ids=chosen_ids, pixel_values=images, labels=chosen_ids)
|
| 281 |
+
outputs_ref_l = self.reference_model(input_ids=rejected_ids, pixel_values=images, labels=rejected_ids)
|
| 282 |
+
ref_logits_w = outputs_ref_w.logits
|
| 283 |
+
ref_logits_l = outputs_ref_l.logits
|
| 284 |
+
except Exception:
|
| 285 |
+
# Modular case
|
| 286 |
+
_, ref_logits_w = self.reference_model(images, chosen_ids)
|
| 287 |
+
_, ref_logits_l = self.reference_model(images, rejected_ids)
|
| 288 |
+
|
| 289 |
+
# 3. Tính log probs
|
| 290 |
+
logps_w = self.get_log_probs(logits_w, chosen_ids)
|
| 291 |
+
logps_l = self.get_log_probs(logits_l, rejected_ids)
|
| 292 |
+
ref_logps_w = self.get_log_probs(ref_logits_w, chosen_ids)
|
| 293 |
+
ref_logps_l = self.get_log_probs(ref_logits_l, rejected_ids)
|
| 294 |
+
|
| 295 |
+
# 4. Tính Loss
|
| 296 |
+
loss, _, _ = self.compute_loss(logps_w, logps_l, ref_logps_w, ref_logps_l)
|
| 297 |
+
|
| 298 |
+
# 5. Backward
|
| 299 |
+
self.optimizer.zero_grad()
|
| 300 |
+
loss.backward()
|
| 301 |
+
self.optimizer.step()
|
| 302 |
+
|
| 303 |
+
total_loss += loss.item()
|
| 304 |
+
pbar.set_postfix({"loss": loss.item()})
|
| 305 |
+
|
| 306 |
+
print(f"Epoch {epoch+1} | DPO Loss: {total_loss/len(self.train_loader):.4f}")
|
src/engine/medical_eval.py
ADDED
|
@@ -0,0 +1,731 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from src.utils.metrics import batch_metrics, compute_bertscore, compute_semantic_score
|
| 4 |
+
from src.utils.text_utils import is_medical_term_compliant, normalize_answer, postprocess_answer
|
| 5 |
+
|
| 6 |
+
def normalize_for_metric(text: str) -> str:
|
| 7 |
+
return text.strip().lower()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _normalize_closed_answer(question_vi: str, question_en: str, pred_vi: str, pred_en: str = "") -> str:
|
| 11 |
+
"""Map descriptive yes/no-style outputs to closed-form labels."""
|
| 12 |
+
question_vi_norm = normalize_answer(question_vi)
|
| 13 |
+
question_en_norm = normalize_answer(question_en)
|
| 14 |
+
pred_vi_norm = normalize_answer(pred_vi)
|
| 15 |
+
pred_en_norm = normalize_answer(pred_en)
|
| 16 |
+
combined = " ".join(part for part in [pred_vi_norm, pred_en_norm] if part).strip()
|
| 17 |
+
|
| 18 |
+
is_normality_question = any(
|
| 19 |
+
pattern in " ".join([question_vi_norm, question_en_norm])
|
| 20 |
+
for pattern in ["bình thường", "normal", "abnormal", "bat thuong"]
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
if is_normality_question:
|
| 24 |
+
explicit_negative_patterns = [
|
| 25 |
+
"không bình thường",
|
| 26 |
+
"not normal",
|
| 27 |
+
]
|
| 28 |
+
explicit_positive_patterns = [
|
| 29 |
+
"có",
|
| 30 |
+
"yes",
|
| 31 |
+
]
|
| 32 |
+
positive_patterns = [
|
| 33 |
+
"bình thường",
|
| 34 |
+
"normal",
|
| 35 |
+
"no significant abnormalities",
|
| 36 |
+
"no abnormality",
|
| 37 |
+
"unremarkable",
|
| 38 |
+
"appears to be normal",
|
| 39 |
+
"without significant abnormalities",
|
| 40 |
+
"không phát hiện bất thường",
|
| 41 |
+
]
|
| 42 |
+
negative_patterns = [
|
| 43 |
+
"bất thường",
|
| 44 |
+
"abnormal",
|
| 45 |
+
"abnormality detected",
|
| 46 |
+
"fracture",
|
| 47 |
+
"lesion",
|
| 48 |
+
"mass",
|
| 49 |
+
"effusion",
|
| 50 |
+
"pneumothorax",
|
| 51 |
+
]
|
| 52 |
+
if any(pattern in combined for pattern in explicit_negative_patterns):
|
| 53 |
+
return "không"
|
| 54 |
+
if any(pattern in combined.split() for pattern in explicit_positive_patterns):
|
| 55 |
+
return "có"
|
| 56 |
+
if any(pattern in combined for pattern in positive_patterns):
|
| 57 |
+
return "có"
|
| 58 |
+
if any(pattern in combined for pattern in negative_patterns):
|
| 59 |
+
return "không"
|
| 60 |
+
else:
|
| 61 |
+
positive_patterns = [
|
| 62 |
+
"có",
|
| 63 |
+
"yes",
|
| 64 |
+
"present",
|
| 65 |
+
"detected",
|
| 66 |
+
"positive",
|
| 67 |
+
]
|
| 68 |
+
negative_patterns = [
|
| 69 |
+
"không",
|
| 70 |
+
"no",
|
| 71 |
+
"absent",
|
| 72 |
+
"not seen",
|
| 73 |
+
"negative",
|
| 74 |
+
"none",
|
| 75 |
+
]
|
| 76 |
+
# For presence/absence questions, "không có ..." contains "có" but
|
| 77 |
+
# semantically means no. Check negation before positive cues.
|
| 78 |
+
if any(pattern in combined for pattern in negative_patterns):
|
| 79 |
+
return "không"
|
| 80 |
+
if any(pattern in combined for pattern in positive_patterns):
|
| 81 |
+
return "có"
|
| 82 |
+
|
| 83 |
+
fallback_positive_patterns = [
|
| 84 |
+
"bình thường",
|
| 85 |
+
"normal",
|
| 86 |
+
"no significant abnormalities",
|
| 87 |
+
"no abnormality",
|
| 88 |
+
"unremarkable",
|
| 89 |
+
"appears to be normal",
|
| 90 |
+
"without significant abnormalities",
|
| 91 |
+
"không phát hiện bất thường",
|
| 92 |
+
]
|
| 93 |
+
fallback_negative_patterns = [
|
| 94 |
+
"bất thường",
|
| 95 |
+
"abnormal",
|
| 96 |
+
"abnormality detected",
|
| 97 |
+
"fracture",
|
| 98 |
+
"lesion",
|
| 99 |
+
"mass",
|
| 100 |
+
"effusion",
|
| 101 |
+
"pneumothorax",
|
| 102 |
+
]
|
| 103 |
+
|
| 104 |
+
if any(pattern in combined for pattern in fallback_positive_patterns):
|
| 105 |
+
return "có"
|
| 106 |
+
if any(pattern in combined for pattern in fallback_negative_patterns):
|
| 107 |
+
return "không"
|
| 108 |
+
return pred_vi_norm or pred_en_norm
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _compute_format_stats(preds: list[str], max_words: int) -> dict[str, float]:
|
| 112 |
+
if not preds:
|
| 113 |
+
return {
|
| 114 |
+
"max_10_word_compliance_rate": 0.0,
|
| 115 |
+
"medical_term_compliance_rate": 0.0,
|
| 116 |
+
"avg_answer_length": 0.0,
|
| 117 |
+
}
|
| 118 |
+
word_counts = [len(p.split()) for p in preds]
|
| 119 |
+
return {
|
| 120 |
+
"max_10_word_compliance_rate": sum(1 for count in word_counts if count <= max_words) / len(word_counts),
|
| 121 |
+
"medical_term_compliance_rate": sum(1 for pred in preds if is_medical_term_compliant(pred)) / len(preds),
|
| 122 |
+
"avg_answer_length": sum(word_counts) / len(word_counts),
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _build_bad_words_ids(processor, variant: str) -> list[list[int]] | None:
|
| 127 |
+
if variant not in {"B1", "B2", "DPO", "PPO"}:
|
| 128 |
+
return None
|
| 129 |
+
|
| 130 |
+
tokenizer = getattr(processor, "tokenizer", None)
|
| 131 |
+
if tokenizer is None:
|
| 132 |
+
return None
|
| 133 |
+
|
| 134 |
+
banned_phrases = [
|
| 135 |
+
"yes",
|
| 136 |
+
"no",
|
| 137 |
+
"the answer is",
|
| 138 |
+
"the image is",
|
| 139 |
+
"this image is",
|
| 140 |
+
"the image shows",
|
| 141 |
+
"the scan shows",
|
| 142 |
+
"there is",
|
| 143 |
+
"there are",
|
| 144 |
+
"it appears",
|
| 145 |
+
"the finding is",
|
| 146 |
+
]
|
| 147 |
+
|
| 148 |
+
bad_words_ids = []
|
| 149 |
+
for phrase in banned_phrases:
|
| 150 |
+
token_ids = tokenizer.encode(phrase, add_special_tokens=False)
|
| 151 |
+
if token_ids:
|
| 152 |
+
bad_words_ids.append(token_ids)
|
| 153 |
+
return bad_words_ids or None
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _attach_metric_views(metrics: dict[str, float]) -> dict[str, float]:
|
| 157 |
+
"""Add explicit metric names while preserving backward-compatible aliases."""
|
| 158 |
+
if "accuracy" in metrics:
|
| 159 |
+
metrics["accuracy_normalized"] = metrics["accuracy"]
|
| 160 |
+
if "em" in metrics:
|
| 161 |
+
metrics["em_normalized"] = metrics["em"]
|
| 162 |
+
if "f1" in metrics:
|
| 163 |
+
metrics["f1_normalized"] = metrics["f1"]
|
| 164 |
+
if "bleu1" in metrics:
|
| 165 |
+
metrics["bleu1_normalized"] = metrics["bleu1"]
|
| 166 |
+
if "bleu2" in metrics:
|
| 167 |
+
metrics["bleu2_normalized"] = metrics["bleu2"]
|
| 168 |
+
if "bleu3" in metrics:
|
| 169 |
+
metrics["bleu3_normalized"] = metrics["bleu3"]
|
| 170 |
+
if "bleu4" in metrics:
|
| 171 |
+
metrics["bleu4_normalized"] = metrics["bleu4"]
|
| 172 |
+
if "rouge_l" in metrics:
|
| 173 |
+
metrics["rouge_l_normalized"] = metrics["rouge_l"]
|
| 174 |
+
if "meteor" in metrics:
|
| 175 |
+
metrics["meteor_normalized"] = metrics["meteor"]
|
| 176 |
+
if "bert_score" in metrics:
|
| 177 |
+
metrics["bert_score_raw"] = metrics["bert_score"]
|
| 178 |
+
if "semantic" in metrics:
|
| 179 |
+
metrics["semantic_raw"] = metrics["semantic"]
|
| 180 |
+
return metrics
|
| 181 |
+
|
| 182 |
+
class MedicalVQAEvaluator:
|
| 183 |
+
"""
|
| 184 |
+
Hệ thống đánh giá hợp nhất cho cả Hướng A và Hướng B.
|
| 185 |
+
"""
|
| 186 |
+
def __init__(self, device, tokenizer=None, processor=None):
|
| 187 |
+
self.device = device
|
| 188 |
+
self.tokenizer = tokenizer
|
| 189 |
+
self.processor = processor
|
| 190 |
+
|
| 191 |
+
def evaluate(self, model, dataloader, variant_type='A', beam_width=1):
|
| 192 |
+
"""
|
| 193 |
+
Giao diện chung để đánh giá bất kỳ variant nào.
|
| 194 |
+
"""
|
| 195 |
+
if variant_type == 'A':
|
| 196 |
+
return evaluate_vqa(model, dataloader, self.device, self.tokenizer, beam_width)
|
| 197 |
+
else:
|
| 198 |
+
return evaluate_multimodal_vqa(model, dataloader, self.device, self.processor, beam_width, variant=variant_type)
|
| 199 |
+
|
| 200 |
+
def evaluate_vqa(model, dataloader, device, tokenizer, beam_width=1, max_len=32, max_words=10):
|
| 201 |
+
model.eval()
|
| 202 |
+
all_preds = []
|
| 203 |
+
all_preds_raw = []
|
| 204 |
+
all_preds_display = []
|
| 205 |
+
all_refs = []
|
| 206 |
+
all_refs_full = []
|
| 207 |
+
all_is_closed = []
|
| 208 |
+
|
| 209 |
+
with torch.no_grad():
|
| 210 |
+
for batch in tqdm(dataloader, desc="Evaluating"):
|
| 211 |
+
images = batch['image'].to(device)
|
| 212 |
+
input_ids = batch['input_ids'].to(device)
|
| 213 |
+
attention_mask = batch['attention_mask'].to(device)
|
| 214 |
+
labels = batch['label_closed']
|
| 215 |
+
|
| 216 |
+
# [FIX] Gọi inference() để lấy CẢ HAI head outputs, truyền max_len từ config
|
| 217 |
+
logits_closed, pred_ids = model.inference(images, input_ids, attention_mask, beam_width=beam_width, max_len=max_len)
|
| 218 |
+
|
| 219 |
+
# Decode generative head + làm sạch subword artifacts
|
| 220 |
+
preds_text_raw = [
|
| 221 |
+
postprocess_answer(t, max_words=max_words)
|
| 222 |
+
for t in tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
| 223 |
+
]
|
| 224 |
+
preds_text = list(preds_text_raw)
|
| 225 |
+
|
| 226 |
+
# [CRITICAL FIX] Với câu Đóng (Yes/No), dùng classifier head thay vì generator
|
| 227 |
+
closed_map = {0: "không", 1: "có"}
|
| 228 |
+
closed_preds_idx = torch.argmax(logits_closed, dim=-1) # [B]
|
| 229 |
+
for i in range(len(preds_text)):
|
| 230 |
+
if labels[i].item() != -1: # Câu hỏi đóng
|
| 231 |
+
preds_text[i] = closed_map[closed_preds_idx[i].item()]
|
| 232 |
+
preds_text[i] = postprocess_answer(preds_text[i], max_words=max_words)
|
| 233 |
+
|
| 234 |
+
# Debug: Hiển thị cả câu Đóng và câu Mở để kiểm tra đa dạng
|
| 235 |
+
if len(all_preds) == 0:
|
| 236 |
+
print("\n--- DEBUG PREDICTIONS ---")
|
| 237 |
+
shown_closed, shown_open = 0, 0
|
| 238 |
+
for i in range(len(preds_text)):
|
| 239 |
+
is_closed = labels[i].item() != -1
|
| 240 |
+
if (is_closed and shown_closed < 2) or (not is_closed and shown_open < 2):
|
| 241 |
+
q_type = "CLOSED" if is_closed else "OPEN"
|
| 242 |
+
print(f"[{q_type}] Q: {batch['raw_questions'][i]}")
|
| 243 |
+
print(f" Pred raw: '{preds_text_raw[i]}'")
|
| 244 |
+
print(f" Pred normalized: '{preds_text[i]}'")
|
| 245 |
+
print(f" GT : '{batch['raw_answer'][i]}'")
|
| 246 |
+
if is_closed: shown_closed += 1
|
| 247 |
+
else: shown_open += 1
|
| 248 |
+
if shown_closed >= 2 and shown_open >= 2:
|
| 249 |
+
break
|
| 250 |
+
print("--------------------------\n")
|
| 251 |
+
|
| 252 |
+
all_preds.extend([normalize_for_metric(p) for p in preds_text])
|
| 253 |
+
all_preds_raw.extend([normalize_for_metric(p) for p in preds_text_raw])
|
| 254 |
+
all_preds_display.extend([normalize_for_metric(p) for p in preds_text_raw])
|
| 255 |
+
# [CRITICAL FIX] Dùng đáp án Tiếng Việt để chấm điểm
|
| 256 |
+
all_refs.extend([normalize_for_metric(postprocess_answer(r, max_words=max_words)) for r in batch['raw_answer']])
|
| 257 |
+
all_refs_full.extend([normalize_for_metric(postprocess_answer(r, max_words=100)) for r in batch.get('raw_answer_full', batch['raw_answer'])])
|
| 258 |
+
is_closed = (batch['label_closed'] != -1).tolist()
|
| 259 |
+
all_is_closed.extend(is_closed)
|
| 260 |
+
|
| 261 |
+
metrics = batch_metrics(all_preds, all_refs)
|
| 262 |
+
metrics["semantic"] = compute_semantic_score(all_preds_raw, all_refs)
|
| 263 |
+
metrics["bert_score"] = compute_bertscore(all_preds_raw, all_refs)
|
| 264 |
+
metrics = _attach_metric_views(metrics)
|
| 265 |
+
metrics.update(_compute_format_stats(all_preds, max_words=max_words))
|
| 266 |
+
metrics['predictions'] = all_preds
|
| 267 |
+
metrics['predictions_raw'] = all_preds_raw
|
| 268 |
+
metrics['predictions_display'] = all_preds_display
|
| 269 |
+
metrics['ground_truths'] = all_refs
|
| 270 |
+
|
| 271 |
+
closed_preds = [p for p, c in zip(all_preds, all_is_closed) if c]
|
| 272 |
+
closed_refs = [r for r, c in zip(all_refs, all_is_closed) if c]
|
| 273 |
+
closed_preds_raw = [p for p, c in zip(all_preds_raw, all_is_closed) if c]
|
| 274 |
+
if closed_preds:
|
| 275 |
+
metrics['closed'] = batch_metrics(closed_preds, closed_refs)
|
| 276 |
+
metrics['closed']["semantic"] = compute_semantic_score(closed_preds_raw, closed_refs)
|
| 277 |
+
metrics['closed']["bert_score"] = compute_bertscore(closed_preds_raw, closed_refs)
|
| 278 |
+
metrics['closed'] = _attach_metric_views(metrics['closed'])
|
| 279 |
+
metrics['closed'].update(_compute_format_stats(closed_preds, max_words=max_words))
|
| 280 |
+
metrics['closed_eval'] = {
|
| 281 |
+
"accuracy": metrics['closed'].get("accuracy_normalized", 0.0),
|
| 282 |
+
"em": metrics['closed'].get("em_normalized", 0.0),
|
| 283 |
+
"f1": metrics['closed'].get("f1_normalized", 0.0),
|
| 284 |
+
"count": len(closed_preds),
|
| 285 |
+
}
|
| 286 |
+
open_preds = [p for p, c in zip(all_preds, all_is_closed) if not c]
|
| 287 |
+
open_refs = [r for r, c in zip(all_refs, all_is_closed) if not c]
|
| 288 |
+
open_preds_raw = [p for p, c in zip(all_preds_raw, all_is_closed) if not c]
|
| 289 |
+
if open_preds:
|
| 290 |
+
metrics['open'] = batch_metrics(open_preds, open_refs)
|
| 291 |
+
metrics['open']["semantic"] = compute_semantic_score(open_preds_raw, open_refs)
|
| 292 |
+
metrics['open']["bert_score"] = compute_bertscore(open_preds_raw, open_refs)
|
| 293 |
+
metrics['open'] = _attach_metric_views(metrics['open'])
|
| 294 |
+
metrics['open'].update(_compute_format_stats(open_preds, max_words=max_words))
|
| 295 |
+
metrics['open_eval'] = {
|
| 296 |
+
"semantic": metrics['open'].get("semantic_raw", 0.0),
|
| 297 |
+
"bert_score": metrics['open'].get("bert_score_raw", 0.0),
|
| 298 |
+
"f1": metrics['open'].get("f1_normalized", 0.0),
|
| 299 |
+
"rouge_l": metrics['open'].get("rouge_l_normalized", 0.0),
|
| 300 |
+
"count": len(open_preds),
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
metrics['long_answers_eval'] = {
|
| 304 |
+
"accuracy": batch_metrics(all_preds, all_refs_full).get("accuracy_normalized", 0),
|
| 305 |
+
"f1": batch_metrics(all_preds, all_refs_full).get("f1_normalized", 0),
|
| 306 |
+
"bleu4": batch_metrics(all_preds, all_refs_full).get("bleu4_normalized", 0),
|
| 307 |
+
"semantic": compute_semantic_score(all_preds_raw, all_refs_full),
|
| 308 |
+
"bert_score": compute_bertscore(all_preds_raw, all_refs_full)
|
| 309 |
+
}
|
| 310 |
+
return metrics
|
| 311 |
+
|
| 312 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 313 |
+
# B1 HELPERS
|
| 314 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 315 |
+
|
| 316 |
+
_B1_FEW_SHOT = (
|
| 317 |
+
"Q: Is there cardiomegaly? A: yes\n"
|
| 318 |
+
"Q: What organ is shown? A: lung\n"
|
| 319 |
+
"Q: Is the aorta normal? A: no\n"
|
| 320 |
+
"Q: What abnormality is present? A: pleural effusion\n"
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def _build_b1_prompt(question_en: str, max_words: int) -> str:
|
| 325 |
+
"""
|
| 326 |
+
Few-shot prompt ép LLaVA trả lời ngắn (≤max_words từ y tế), không sinh câu dài.
|
| 327 |
+
Đặt 4 ví dụ in-context trước câu hỏi thực để suppress verbose prefix.
|
| 328 |
+
"""
|
| 329 |
+
return (
|
| 330 |
+
f"USER: <image>\n"
|
| 331 |
+
f"Answer each question with medical terminology only, "
|
| 332 |
+
f"no more than {max_words} words, no full sentences.\n"
|
| 333 |
+
f"{_B1_FEW_SHOT}"
|
| 334 |
+
f"Q: {question_en} A: ASSISTANT:"
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
# En → Vi fast lookup (50+ thuật ngữ y tế thường gặp trong SLAKE + VQA-RAD)
|
| 339 |
+
_EN_VI_DIRECT: dict = {
|
| 340 |
+
# binary
|
| 341 |
+
"yes": "có", "no": "không",
|
| 342 |
+
"present": "có", "absent": "không",
|
| 343 |
+
"normal": "bình thường", "abnormal": "bất thường",
|
| 344 |
+
"true": "có", "false": "không",
|
| 345 |
+
"positive": "có", "negative": "không",
|
| 346 |
+
# anatomy
|
| 347 |
+
"lung": "phổi", "lungs": "phổi",
|
| 348 |
+
"heart": "tim", "liver": "gan", "spleen": "lách",
|
| 349 |
+
"kidney": "thận", "brain": "não", "bladder": "bàng quang",
|
| 350 |
+
"chest": "ngực", "abdomen": "bụng", "pelvis": "xương chậu",
|
| 351 |
+
"spine": "cột sống", "rib": "xương sườn", "ribs": "xương sườn",
|
| 352 |
+
"trachea": "khí quản", "aorta": "động mạch chủ",
|
| 353 |
+
"diaphragm": "cơ hoành", "mediastinum": "trung thất",
|
| 354 |
+
# modality
|
| 355 |
+
"chest x-ray": "x-quang ngực", "x-ray": "x-quang", "xray": "x-quang",
|
| 356 |
+
"mri": "mri", "ct": "ct", "ultrasound": "siêu âm",
|
| 357 |
+
"ct scan": "ct", "mri scan": "mri",
|
| 358 |
+
# planes
|
| 359 |
+
"axial": "mặt phẳng ngang",
|
| 360 |
+
"coronal": "mặt phẳng vành",
|
| 361 |
+
"sagittal": "mặt phẳng dọc",
|
| 362 |
+
"transverse": "mặt phẳng ngang",
|
| 363 |
+
# pathologies
|
| 364 |
+
"cardiomegaly": "tim to",
|
| 365 |
+
"pneumonia": "viêm phổi",
|
| 366 |
+
"pleural effusion": "tràn dịch màng phổi",
|
| 367 |
+
"pneumothorax": "tràn khí màng phổi",
|
| 368 |
+
"fracture": "gãy xương",
|
| 369 |
+
"edema": "phù nề",
|
| 370 |
+
"pulmonary edema": "phù phổi",
|
| 371 |
+
"consolidation": "đông đặc",
|
| 372 |
+
"atelectasis": "xẹp phổi",
|
| 373 |
+
"opacity": "mờ đục",
|
| 374 |
+
"mass": "khối u",
|
| 375 |
+
"nodule": "nốt",
|
| 376 |
+
"lesion": "tổn thương",
|
| 377 |
+
"tumor": "khối u",
|
| 378 |
+
"effusion": "tràn dịch",
|
| 379 |
+
"infiltrate": "thâm nhiễm",
|
| 380 |
+
"fibrosis": "xơ hóa",
|
| 381 |
+
"calcification": "vôi hóa",
|
| 382 |
+
"carcinoma": "ung thư",
|
| 383 |
+
"metastasis": "di căn",
|
| 384 |
+
"bilateral": "hai bên",
|
| 385 |
+
"unilateral": "một bên",
|
| 386 |
+
"left": "trái", "right": "phải",
|
| 387 |
+
"upper": "trên", "lower": "dưới",
|
| 388 |
+
"right upper quadrant": "phía trên bên phải",
|
| 389 |
+
"left upper quadrant": "phía trên bên trái",
|
| 390 |
+
"right lower quadrant": "phía dưới bên phải",
|
| 391 |
+
"left lower quadrant": "phía dưới bên trái",
|
| 392 |
+
"right upper": "phía trên bên phải",
|
| 393 |
+
"left upper": "phía trên bên trái",
|
| 394 |
+
"upper left": "phía trên bên trái",
|
| 395 |
+
"upper right": "phía trên bên phải",
|
| 396 |
+
"lower left": "phía dưới bên trái",
|
| 397 |
+
"lower right": "phía dưới bên phải",
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def _extract_key_medical_term(raw_en: str, max_words: int) -> str:
|
| 402 |
+
"""
|
| 403 |
+
Loại bỏ verbose prefix LLaVA hay sinh ("The image shows a chest X-ray with..."),
|
| 404 |
+
chỉ giữ lại thuật ngữ y tế chính.
|
| 405 |
+
"""
|
| 406 |
+
import re
|
| 407 |
+
text = raw_en.strip().lower()
|
| 408 |
+
|
| 409 |
+
# Các prefix verbose phổ biến cần xóa
|
| 410 |
+
prefixes = [
|
| 411 |
+
r"^the (image|scan|x-ray|xray|mri|ct|picture|photo|radiograph) (shows?|depicts?|demonstrates?|reveals?|indicates?|presents?)\s+",
|
| 412 |
+
r"^based on the (image|scan|x-ray|mri|ct)\s*,?\s*",
|
| 413 |
+
r"^in (this|the) (image|scan|x-ray|mri|ct)\s*,?\s*",
|
| 414 |
+
r"^i (can see|observe|notice|see)\s+",
|
| 415 |
+
r"^there (is|are)\s+(a |an |some )?",
|
| 416 |
+
r"^(it |this )(shows?|is|appears?|looks?)\s+(like\s+)?",
|
| 417 |
+
r"^the (patient|subject)\s+(has|shows?|presents?)\s+",
|
| 418 |
+
r"^(a|an|the)\s+",
|
| 419 |
+
r"^[a-z\s]+ is (located|seen|found|present)( in| at| on)?\s+(the\s+)?",
|
| 420 |
+
]
|
| 421 |
+
for pat in prefixes:
|
| 422 |
+
text = re.sub(pat, "", text)
|
| 423 |
+
|
| 424 |
+
text = re.sub(r"[.!?,;:]+$", "", text).strip()
|
| 425 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 426 |
+
|
| 427 |
+
words = text.split()
|
| 428 |
+
return " ".join(words[:max_words]) if words else raw_en.strip()
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def _en_to_vi_direct(en_text: str) -> str | None:
|
| 432 |
+
"""
|
| 433 |
+
Tra từ điển nhanh. Sắp xếp theo độ dài giảm dần để phrase dài match trước.
|
| 434 |
+
Trả về None nếu không match → caller dùng Translation Model.
|
| 435 |
+
"""
|
| 436 |
+
norm = en_text.strip().lower()
|
| 437 |
+
if norm in _EN_VI_DIRECT:
|
| 438 |
+
return _EN_VI_DIRECT[norm]
|
| 439 |
+
return None
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def _dual_score_open(
|
| 443 |
+
preds_vi: list,
|
| 444 |
+
preds_en: list,
|
| 445 |
+
refs_vi: list,
|
| 446 |
+
refs_en: list,
|
| 447 |
+
) -> list:
|
| 448 |
+
"""
|
| 449 |
+
Với mỗi câu hỏi mở, so sánh F1 Vi vs F1 En rồi chọn prediction tốt hơn.
|
| 450 |
+
Giải quyết 0% open-ended do dịch thuật mất nghĩa.
|
| 451 |
+
"""
|
| 452 |
+
from src.utils.metrics import compute_f1
|
| 453 |
+
from src.utils.text_utils import normalize_answer
|
| 454 |
+
result = []
|
| 455 |
+
for pv, pe, rv, re_ in zip(preds_vi, preds_en, refs_vi, refs_en):
|
| 456 |
+
f1_vi = compute_f1(pv, rv)
|
| 457 |
+
f1_en = compute_f1(normalize_answer(pe), normalize_answer(re_)) if re_ else 0.0
|
| 458 |
+
result.append(pv if f1_vi >= f1_en else normalize_answer(pe))
|
| 459 |
+
return result
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def evaluate_multimodal_vqa(
|
| 463 |
+
model,
|
| 464 |
+
dataloader,
|
| 465 |
+
device,
|
| 466 |
+
processor,
|
| 467 |
+
beam_width=1,
|
| 468 |
+
max_words=10,
|
| 469 |
+
variant='B1',
|
| 470 |
+
beam_width_closed=None,
|
| 471 |
+
beam_width_open=None,
|
| 472 |
+
max_new_tokens_closed=None,
|
| 473 |
+
max_new_tokens_open=None,
|
| 474 |
+
generation_batch_size=None,
|
| 475 |
+
):
|
| 476 |
+
"""
|
| 477 |
+
B1 Zero-Shot evaluation & B2/DPO/PPO Fine-Tuned evaluation.
|
| 478 |
+
"""
|
| 479 |
+
model.eval()
|
| 480 |
+
all_preds = []
|
| 481 |
+
all_preds_raw = []
|
| 482 |
+
all_preds_display = []
|
| 483 |
+
all_preds_en = []
|
| 484 |
+
all_refs = []
|
| 485 |
+
all_refs_full = []
|
| 486 |
+
all_refs_en = []
|
| 487 |
+
all_is_closed = []
|
| 488 |
+
|
| 489 |
+
from src.utils.translator import MedicalTranslator
|
| 490 |
+
translator = MedicalTranslator(device=device.type)
|
| 491 |
+
|
| 492 |
+
from src.models.multimodal_vqa import MultimodalVQA
|
| 493 |
+
wrapper = MultimodalVQA()
|
| 494 |
+
|
| 495 |
+
beam_width_closed = beam_width if beam_width_closed is None else beam_width_closed
|
| 496 |
+
beam_width_open = beam_width if beam_width_open is None else beam_width_open
|
| 497 |
+
max_new_tokens_closed = 4 if max_new_tokens_closed is None else max_new_tokens_closed
|
| 498 |
+
max_new_tokens_open = (max_words + 6) if max_new_tokens_open is None else max_new_tokens_open
|
| 499 |
+
generation_batch_size = 1 if generation_batch_size is None else max(1, int(generation_batch_size))
|
| 500 |
+
bad_words_ids = _build_bad_words_ids(processor, variant)
|
| 501 |
+
|
| 502 |
+
with torch.no_grad():
|
| 503 |
+
for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Evaluating {variant}")):
|
| 504 |
+
raw_images = batch.get('raw_image')
|
| 505 |
+
questions_vi = batch.get('raw_questions', [])
|
| 506 |
+
questions_en = batch.get('raw_questions_en', [])
|
| 507 |
+
refs_vi_raw = batch.get('raw_answer', [])
|
| 508 |
+
refs_en_raw = batch.get('raw_answer_en', [])
|
| 509 |
+
labels = batch['label_closed']
|
| 510 |
+
|
| 511 |
+
if variant == 'B1':
|
| 512 |
+
# B1 (Zero-shot) needs English translation & English few-shot prompt
|
| 513 |
+
if not questions_en or any(not str(q).strip() for q in questions_en):
|
| 514 |
+
questions_en = translator.translate_vi2en(questions_vi)
|
| 515 |
+
prompts = [_build_b1_prompt(q, max_words) for q in questions_en]
|
| 516 |
+
else:
|
| 517 |
+
# B2 / DPO / PPO (Fine-tuned) expect Vietnamese instruction directly
|
| 518 |
+
prompts = [wrapper.build_instruction_prompt(q, language="vi", include_answer=False) for q in questions_vi]
|
| 519 |
+
preds_raw = [""] * len(prompts)
|
| 520 |
+
closed_idx = [i for i, lbl in enumerate(labels.tolist()) if lbl != -1]
|
| 521 |
+
open_idx = [i for i, lbl in enumerate(labels.tolist()) if lbl == -1]
|
| 522 |
+
|
| 523 |
+
def _run_generation(sample_indices, num_beams, max_new_tokens):
|
| 524 |
+
if not sample_indices:
|
| 525 |
+
return []
|
| 526 |
+
decoded_outputs = []
|
| 527 |
+
chunk_size = generation_batch_size if num_beams > 1 else max(generation_batch_size, 2)
|
| 528 |
+
|
| 529 |
+
for start in range(0, len(sample_indices), chunk_size):
|
| 530 |
+
chunk_indices = sample_indices[start:start + chunk_size]
|
| 531 |
+
text_subset = [prompts[i] for i in chunk_indices]
|
| 532 |
+
image_subset = [raw_images[i] for i in chunk_indices] if raw_images is not None else None
|
| 533 |
+
if image_subset is not None:
|
| 534 |
+
inputs = processor(
|
| 535 |
+
text=text_subset,
|
| 536 |
+
images=image_subset,
|
| 537 |
+
return_tensors="pt",
|
| 538 |
+
padding=True,
|
| 539 |
+
).to(device)
|
| 540 |
+
if "pixel_values" in inputs:
|
| 541 |
+
inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
|
| 542 |
+
else:
|
| 543 |
+
inputs = processor(text=text_subset, return_tensors="pt", padding=True).to(device)
|
| 544 |
+
|
| 545 |
+
output_ids = model.generate(
|
| 546 |
+
**inputs,
|
| 547 |
+
max_new_tokens=max_new_tokens,
|
| 548 |
+
do_sample=False,
|
| 549 |
+
num_beams=num_beams,
|
| 550 |
+
early_stopping=num_beams > 1,
|
| 551 |
+
bad_words_ids=bad_words_ids,
|
| 552 |
+
)
|
| 553 |
+
input_token_len = inputs.input_ids.shape[1]
|
| 554 |
+
decoded_outputs.extend(
|
| 555 |
+
processor.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
del inputs, output_ids
|
| 559 |
+
if device.type == "cuda":
|
| 560 |
+
torch.cuda.empty_cache()
|
| 561 |
+
|
| 562 |
+
return decoded_outputs
|
| 563 |
+
|
| 564 |
+
if variant == 'B1':
|
| 565 |
+
generated = _run_generation(list(range(len(prompts))), beam_width_open, max_new_tokens_open)
|
| 566 |
+
preds_raw = generated
|
| 567 |
+
else:
|
| 568 |
+
for idx, pred in zip(closed_idx, _run_generation(closed_idx, beam_width_closed, max_new_tokens_closed)):
|
| 569 |
+
preds_raw[idx] = pred
|
| 570 |
+
for idx, pred in zip(open_idx, _run_generation(open_idx, beam_width_open, max_new_tokens_open)):
|
| 571 |
+
preds_raw[idx] = pred
|
| 572 |
+
|
| 573 |
+
preds_vi = []
|
| 574 |
+
preds_vi_display = []
|
| 575 |
+
preds_en_clean = []
|
| 576 |
+
|
| 577 |
+
if variant == 'B1':
|
| 578 |
+
# [FIX 2] Strip verbose prefix → giữ key medical term. Tránh cắt vụn câu tiếng Anh để Dịch thuật hiểu đúng.
|
| 579 |
+
preds_en_clean = [_extract_key_medical_term(p, 50) for p in preds_raw]
|
| 580 |
+
|
| 581 |
+
# [FIX 3 + 5] Per-sample: closed → normalize En trước; open → dict lookup rồi Translation Model
|
| 582 |
+
needs_translate_idx = [] # index cần dịch
|
| 583 |
+
needs_translate_txt = []
|
| 584 |
+
|
| 585 |
+
for i, pred_en in enumerate(preds_en_clean):
|
| 586 |
+
if labels[i].item() != -1:
|
| 587 |
+
# Closed: dùng _normalize_closed_answer với En pred (chính xác hơn)
|
| 588 |
+
preds_vi.append(
|
| 589 |
+
_normalize_closed_answer(
|
| 590 |
+
questions_vi[i], questions_en[i], pred_en, pred_en
|
| 591 |
+
)
|
| 592 |
+
)
|
| 593 |
+
else:
|
| 594 |
+
# Open: thử dict nhanh trước
|
| 595 |
+
vi_direct = _en_to_vi_direct(pred_en)
|
| 596 |
+
if vi_direct is not None:
|
| 597 |
+
preds_vi.append(postprocess_answer(vi_direct, max_words=max_words))
|
| 598 |
+
else:
|
| 599 |
+
preds_vi.append(None) # placeholder
|
| 600 |
+
needs_translate_idx.append(i)
|
| 601 |
+
needs_translate_txt.append(pred_en)
|
| 602 |
+
|
| 603 |
+
# Batch dịch những câu cần Translation Model
|
| 604 |
+
if needs_translate_txt:
|
| 605 |
+
translated = translator.translate_en2vi(needs_translate_txt)
|
| 606 |
+
if isinstance(translated, str):
|
| 607 |
+
translated = [translated]
|
| 608 |
+
for idx, vi in zip(needs_translate_idx, translated):
|
| 609 |
+
preds_vi[idx] = postprocess_answer(vi, max_words=max_words)
|
| 610 |
+
preds_vi_display = list(preds_vi)
|
| 611 |
+
else:
|
| 612 |
+
# B2 / DPO / PPO directly outputs Vietnamese, no translation needed
|
| 613 |
+
preds_vi_display = [postprocess_answer(p, max_words=max_words) if p else "" for p in preds_raw]
|
| 614 |
+
for i, pred_vi in enumerate(preds_raw):
|
| 615 |
+
if labels[i].item() != -1:
|
| 616 |
+
preds_vi.append(
|
| 617 |
+
_normalize_closed_answer(
|
| 618 |
+
questions_vi[i], questions_en[i] if i < len(questions_en) else "", pred_vi
|
| 619 |
+
)
|
| 620 |
+
)
|
| 621 |
+
else:
|
| 622 |
+
preds_vi.append(pred_vi)
|
| 623 |
+
preds_en_clean = [""] * len(preds_raw)
|
| 624 |
+
|
| 625 |
+
# Đảm bảo không có None
|
| 626 |
+
preds_vi = [postprocess_answer(p, max_words=max_words) if p else "" for p in preds_vi]
|
| 627 |
+
preds_vi_display = [postprocess_answer(p, max_words=max_words) if p else "" for p in preds_vi_display]
|
| 628 |
+
preds_vi_raw = list(preds_vi_display)
|
| 629 |
+
|
| 630 |
+
# Refs
|
| 631 |
+
refs_vi = [postprocess_answer(r, max_words=max_words) for r in refs_vi_raw]
|
| 632 |
+
refs_en = [postprocess_answer(r, max_words=max_words) if r else "" for r in refs_en_raw]
|
| 633 |
+
|
| 634 |
+
# Debug batch đầu
|
| 635 |
+
if batch_idx == 0:
|
| 636 |
+
print(f"\n--- DEBUG {variant} (Evaluation) ---")
|
| 637 |
+
for i in range(min(4, len(preds_vi))):
|
| 638 |
+
q_type = "CLOSED" if labels[i].item() != -1 else "OPEN"
|
| 639 |
+
if variant == 'B1':
|
| 640 |
+
print(f"[{q_type}] Q (En): {questions_en[i]}")
|
| 641 |
+
print(f" Pred (En raw): '{preds_raw[i]}'")
|
| 642 |
+
print(f" Pred (En clean): '{preds_en_clean[i]}'")
|
| 643 |
+
else:
|
| 644 |
+
print(f"[{q_type}] Q (Vi): {questions_vi[i]}")
|
| 645 |
+
print(f" Pred (Vi raw): '{preds_raw[i]}'")
|
| 646 |
+
print(f" Pred display: '{preds_vi_display[i]}'")
|
| 647 |
+
print(f" Pred (Vi): '{preds_vi[i]}'")
|
| 648 |
+
print(f" GT (Vi): '{refs_vi[i]}' | GT (En): '{refs_en[i]}'")
|
| 649 |
+
print("-----------------------------------------\n")
|
| 650 |
+
|
| 651 |
+
all_preds.extend([normalize_for_metric(p) for p in preds_vi])
|
| 652 |
+
all_preds_raw.extend([normalize_for_metric(p) for p in preds_vi_raw])
|
| 653 |
+
all_preds_display.extend([normalize_for_metric(p) for p in preds_vi_display])
|
| 654 |
+
all_preds_en.extend([normalize_for_metric(p) for p in preds_en_clean])
|
| 655 |
+
all_refs.extend([normalize_for_metric(r) for r in refs_vi])
|
| 656 |
+
all_refs_full.extend([normalize_for_metric(postprocess_answer(r, max_words=100)) for r in batch.get('raw_answer_full', batch['raw_answer'])])
|
| 657 |
+
all_refs_en.extend([normalize_for_metric(r) for r in refs_en])
|
| 658 |
+
all_is_closed.extend((labels != -1).tolist())
|
| 659 |
+
|
| 660 |
+
# [FIX 4] Dual-language scoring cho open-ended (chỉ dùng cho B1)
|
| 661 |
+
if variant == 'B1':
|
| 662 |
+
open_idx = [i for i, c in enumerate(all_is_closed) if not c]
|
| 663 |
+
if open_idx:
|
| 664 |
+
best_open = _dual_score_open(
|
| 665 |
+
[all_preds[i] for i in open_idx],
|
| 666 |
+
[all_preds_en[i] for i in open_idx],
|
| 667 |
+
[all_refs[i] for i in open_idx],
|
| 668 |
+
[all_refs_en[i] for i in open_idx],
|
| 669 |
+
)
|
| 670 |
+
for k, i in enumerate(open_idx):
|
| 671 |
+
all_preds[i] = best_open[k]
|
| 672 |
+
|
| 673 |
+
# ── Compute metrics ──────────────────────────────────────────────────────
|
| 674 |
+
metrics = batch_metrics(all_preds, all_refs)
|
| 675 |
+
metrics["semantic"] = compute_semantic_score(all_preds_raw, all_refs)
|
| 676 |
+
metrics["bert_score"] = compute_bertscore(all_preds_raw, all_refs)
|
| 677 |
+
metrics = _attach_metric_views(metrics)
|
| 678 |
+
metrics.update(_compute_format_stats(all_preds, max_words=max_words))
|
| 679 |
+
metrics['predictions'] = all_preds
|
| 680 |
+
metrics['predictions_raw'] = all_preds_raw
|
| 681 |
+
metrics['predictions_display'] = all_preds_display
|
| 682 |
+
metrics['predictions_en'] = all_preds_en
|
| 683 |
+
metrics['ground_truths'] = all_refs
|
| 684 |
+
metrics['ground_truths_en'] = all_refs_en
|
| 685 |
+
|
| 686 |
+
def _subset(pred_list, ref_list, pred_raw_list):
|
| 687 |
+
m = batch_metrics(pred_list, ref_list)
|
| 688 |
+
m["semantic"] = compute_semantic_score(pred_raw_list, ref_list)
|
| 689 |
+
m["bert_score"] = compute_bertscore(pred_raw_list, ref_list)
|
| 690 |
+
m = _attach_metric_views(m)
|
| 691 |
+
m.update(_compute_format_stats(pred_list, max_words=max_words))
|
| 692 |
+
return m
|
| 693 |
+
|
| 694 |
+
closed_idx = [i for i, c in enumerate(all_is_closed) if c]
|
| 695 |
+
open_idx = [i for i, c in enumerate(all_is_closed) if not c]
|
| 696 |
+
|
| 697 |
+
if closed_idx:
|
| 698 |
+
metrics['closed'] = _subset(
|
| 699 |
+
[all_preds[i] for i in closed_idx],
|
| 700 |
+
[all_refs[i] for i in closed_idx],
|
| 701 |
+
[all_preds_raw[i] for i in closed_idx],
|
| 702 |
+
)
|
| 703 |
+
metrics['closed_eval'] = {
|
| 704 |
+
"accuracy": metrics['closed'].get("accuracy_normalized", 0.0),
|
| 705 |
+
"em": metrics['closed'].get("em_normalized", 0.0),
|
| 706 |
+
"f1": metrics['closed'].get("f1_normalized", 0.0),
|
| 707 |
+
"count": len(closed_idx),
|
| 708 |
+
}
|
| 709 |
+
if open_idx:
|
| 710 |
+
metrics['open'] = _subset(
|
| 711 |
+
[all_preds[i] for i in open_idx],
|
| 712 |
+
[all_refs[i] for i in open_idx],
|
| 713 |
+
[all_preds_raw[i] for i in open_idx],
|
| 714 |
+
)
|
| 715 |
+
metrics['open_eval'] = {
|
| 716 |
+
"semantic": metrics['open'].get("semantic_raw", 0.0),
|
| 717 |
+
"bert_score": metrics['open'].get("bert_score_raw", 0.0),
|
| 718 |
+
"f1": metrics['open'].get("f1_normalized", 0.0),
|
| 719 |
+
"rouge_l": metrics['open'].get("rouge_l_normalized", 0.0),
|
| 720 |
+
"count": len(open_idx),
|
| 721 |
+
}
|
| 722 |
+
|
| 723 |
+
metrics['long_answers_eval'] = {
|
| 724 |
+
"accuracy": batch_metrics(all_preds, all_refs_full).get("accuracy_normalized", 0),
|
| 725 |
+
"f1": batch_metrics(all_preds, all_refs_full).get("f1_normalized", 0),
|
| 726 |
+
"bleu4": batch_metrics(all_preds, all_refs_full).get("bleu4_normalized", 0),
|
| 727 |
+
"semantic": compute_semantic_score(all_preds_raw, all_refs_full),
|
| 728 |
+
"bert_score": compute_bertscore(all_preds_raw, all_refs_full)
|
| 729 |
+
}
|
| 730 |
+
|
| 731 |
+
return metrics
|
src/engine/trainer.py
ADDED
|
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import wandb
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import os
|
| 6 |
+
import csv
|
| 7 |
+
import json
|
| 8 |
+
from src.utils.early_stopping import DynamicClassWeights
|
| 9 |
+
|
| 10 |
+
class MedicalVQATrainer:
|
| 11 |
+
def __init__(self, model, train_loader, val_loader, optimizer, device, config, scheduler=None, pad_token_id=0, beam_width=1):
|
| 12 |
+
self.model = model
|
| 13 |
+
self.train_loader = train_loader
|
| 14 |
+
self.val_loader = val_loader
|
| 15 |
+
self.optimizer = optimizer
|
| 16 |
+
self.scheduler = scheduler
|
| 17 |
+
self.device = device
|
| 18 |
+
self.config = config
|
| 19 |
+
self.beam_width = beam_width
|
| 20 |
+
|
| 21 |
+
# [FIX] Dynamic class weights computed from actual data distribution
|
| 22 |
+
# Replaces hard-coded [1.0, 2.5] which may not match real imbalance ratio
|
| 23 |
+
dynamic_weights = DynamicClassWeights.compute_weights(train_loader, device=device)
|
| 24 |
+
self.criterion_closed = nn.CrossEntropyLoss(weight=dynamic_weights)
|
| 25 |
+
|
| 26 |
+
# [NOTE] Label smoothing only on open-ended head: closed-head needs sharp 0/1
|
| 27 |
+
self.criterion_open = nn.CrossEntropyLoss(
|
| 28 |
+
ignore_index=pad_token_id,
|
| 29 |
+
label_smoothing=config['train'].get('label_smoothing', 0.0)
|
| 30 |
+
)
|
| 31 |
+
self.criterion_closed_hard = nn.CrossEntropyLoss(weight=dynamic_weights) # no smoothing
|
| 32 |
+
|
| 33 |
+
# AMP (Automatic Mixed Precision)
|
| 34 |
+
self.use_amp = config['train'].get('use_amp', False) and device.type == 'cuda'
|
| 35 |
+
self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)
|
| 36 |
+
self.history = []
|
| 37 |
+
|
| 38 |
+
@staticmethod
|
| 39 |
+
def _flatten_dict(data, parent_key="", sep="."):
|
| 40 |
+
items = {}
|
| 41 |
+
for key, value in data.items():
|
| 42 |
+
new_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
|
| 43 |
+
if isinstance(value, dict):
|
| 44 |
+
items.update(MedicalVQATrainer._flatten_dict(value, new_key, sep=sep))
|
| 45 |
+
elif isinstance(value, (list, tuple)):
|
| 46 |
+
continue
|
| 47 |
+
else:
|
| 48 |
+
items[new_key] = value
|
| 49 |
+
return items
|
| 50 |
+
|
| 51 |
+
def save_history(self, output_dir):
|
| 52 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 53 |
+
json_path = os.path.join(output_dir, "history.json")
|
| 54 |
+
csv_path = os.path.join(output_dir, "history.csv")
|
| 55 |
+
|
| 56 |
+
with open(json_path, "w", encoding="utf-8") as f:
|
| 57 |
+
json.dump(self.history, f, ensure_ascii=False, indent=2)
|
| 58 |
+
|
| 59 |
+
flat_rows = [self._flatten_dict(row) for row in self.history]
|
| 60 |
+
if flat_rows:
|
| 61 |
+
fieldnames = sorted({key for row in flat_rows for key in row.keys()})
|
| 62 |
+
with open(csv_path, "w", encoding="utf-8", newline="") as f:
|
| 63 |
+
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
| 64 |
+
writer.writeheader()
|
| 65 |
+
writer.writerows(flat_rows)
|
| 66 |
+
|
| 67 |
+
@staticmethod
|
| 68 |
+
def _compute_closed_weights(train_loader):
|
| 69 |
+
"""Đếm phân phối Yes/No và tính inverse frequency weights."""
|
| 70 |
+
counts = {0: 0, 1: 0} # 0=không, 1=có
|
| 71 |
+
for batch in train_loader:
|
| 72 |
+
labels = batch['label_closed']
|
| 73 |
+
for lbl in labels:
|
| 74 |
+
v = lbl.item()
|
| 75 |
+
if v in counts:
|
| 76 |
+
counts[v] += 1
|
| 77 |
+
|
| 78 |
+
total = counts[0] + counts[1]
|
| 79 |
+
if total == 0:
|
| 80 |
+
return torch.ones(2)
|
| 81 |
+
|
| 82 |
+
# Inverse frequency: class ít mẫu → weight cao hơn
|
| 83 |
+
w0 = total / (2 * max(counts[0], 1))
|
| 84 |
+
w1 = total / (2 * max(counts[1], 1))
|
| 85 |
+
weights = torch.tensor([w0, w1], dtype=torch.float32)
|
| 86 |
+
print(f"[INFO] Closed question distribution: không={counts[0]}, có={counts[1]}")
|
| 87 |
+
return weights
|
| 88 |
+
|
| 89 |
+
def train_epoch(self, epoch):
|
| 90 |
+
self.model.train()
|
| 91 |
+
total_loss = 0
|
| 92 |
+
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
|
| 93 |
+
|
| 94 |
+
# [OPTIMIZATION] Gradient accumulation for larger effective batch size
|
| 95 |
+
accumulation_steps = self.config['train'].get('gradient_accumulation_steps', 2)
|
| 96 |
+
|
| 97 |
+
for batch_idx, batch in enumerate(pbar):
|
| 98 |
+
images = batch['image'].to(self.device)
|
| 99 |
+
input_ids = batch['input_ids'].to(self.device)
|
| 100 |
+
attention_mask = batch['attention_mask'].to(self.device)
|
| 101 |
+
label_closed = batch['label_closed'].to(self.device)
|
| 102 |
+
target_ids = batch['target_ids'].to(self.device)
|
| 103 |
+
|
| 104 |
+
# Zero gradients only at the beginning or after optimizer step
|
| 105 |
+
if batch_idx % accumulation_steps == 0:
|
| 106 |
+
self.optimizer.zero_grad()
|
| 107 |
+
|
| 108 |
+
# Sử dụng AMP Autocast
|
| 109 |
+
with torch.cuda.amp.autocast(enabled=self.use_amp):
|
| 110 |
+
# Teacher Forcing: Input là <s> A B, Target là A B </s>
|
| 111 |
+
decoder_input = target_ids[:, :-1]
|
| 112 |
+
decoder_target = target_ids[:, 1:]
|
| 113 |
+
|
| 114 |
+
logits_closed, logits_open = self.model(images, input_ids, attention_mask, decoder_input)
|
| 115 |
+
|
| 116 |
+
# Loss calculation
|
| 117 |
+
loss = 0
|
| 118 |
+
mask_closed = (label_closed != -1)
|
| 119 |
+
if mask_closed.any():
|
| 120 |
+
loss += self.criterion_closed(logits_closed[mask_closed], label_closed[mask_closed])
|
| 121 |
+
|
| 122 |
+
# Phân tách Loss Generator để chống Mode Collapse (Lười biếng)
|
| 123 |
+
vocab_size = logits_open.size(-1)
|
| 124 |
+
mask_open = (label_closed == -1)
|
| 125 |
+
|
| 126 |
+
# 1. Câu hỏi Yes/No: Giảm trọng số xuống cực thấp (0.1) để model không bị thiên vị
|
| 127 |
+
if mask_closed.any():
|
| 128 |
+
loss_gen_closed = self.criterion_open(logits_open[mask_closed].reshape(-1, vocab_size), decoder_target[mask_closed].reshape(-1))
|
| 129 |
+
loss += loss_gen_closed * 0.1
|
| 130 |
+
|
| 131 |
+
# 2. Câu hỏi Mở: Tăng trọng số + Length Penalty + Coverage Penalty
|
| 132 |
+
if mask_open.any():
|
| 133 |
+
open_logits = logits_open[mask_open]
|
| 134 |
+
open_targets = decoder_target[mask_open]
|
| 135 |
+
loss_gen_open = self.criterion_open(open_logits.reshape(-1, vocab_size), open_targets.reshape(-1))
|
| 136 |
+
|
| 137 |
+
# Length penalty: phạt nếu model sinh quá ít token có nghĩa
|
| 138 |
+
pred_lengths = (open_targets != self.criterion_open.ignore_index).float().sum(dim=-1).mean()
|
| 139 |
+
length_penalty = torch.clamp(1.0 - pred_lengths / 15.0, min=0.0)
|
| 140 |
+
|
| 141 |
+
# Thay coverage loss bằng entropy penalty (đúng hơn)
|
| 142 |
+
# Phạt khi model quá confident vào 1 token
|
| 143 |
+
probs = torch.softmax(open_logits, dim=-1) # [N, seq, vocab]
|
| 144 |
+
entropy = -(probs * torch.log(probs + 1e-9)).sum(dim=-1).mean()
|
| 145 |
+
coverage_loss = torch.clamp(2.0 - entropy, min=0.0) # phạt nếu entropy < 2.0
|
| 146 |
+
|
| 147 |
+
# [TUNED] Reduce weight 3.0→2.0: open head was dominating,
|
| 148 |
+
# causing closed-head accuracy to plateau (observed in A1/A2 runs)
|
| 149 |
+
open_loss_weight = self.config.get('open_loss_weight', 2.0)
|
| 150 |
+
loss += (loss_gen_open + 0.3 * length_penalty + 0.1 * coverage_loss) * open_loss_weight
|
| 151 |
+
|
| 152 |
+
# [OPTIMIZATION] Normalize loss by accumulation steps for proper gradient scaling
|
| 153 |
+
loss = loss / accumulation_steps
|
| 154 |
+
|
| 155 |
+
# Backward với GradScaler
|
| 156 |
+
self.scaler.scale(loss).backward()
|
| 157 |
+
|
| 158 |
+
# [OPTIMIZATION] Update weights only after accumulating gradients
|
| 159 |
+
is_last_batch = (batch_idx + 1) == len(self.train_loader)
|
| 160 |
+
if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
|
| 161 |
+
# Gradient Clipping
|
| 162 |
+
if self.config['train'].get('grad_clip'):
|
| 163 |
+
self.scaler.unscale_(self.optimizer)
|
| 164 |
+
nn.utils.clip_grad_norm_(self.model.parameters(), self.config['train']['grad_clip'])
|
| 165 |
+
|
| 166 |
+
self.scaler.step(self.optimizer)
|
| 167 |
+
self.scaler.update()
|
| 168 |
+
|
| 169 |
+
# [CRITICAL FIX] Step scheduler sau mỗi batch thay vì epoch để warmup mượt hơn
|
| 170 |
+
if self.scheduler:
|
| 171 |
+
self.scheduler.step()
|
| 172 |
+
|
| 173 |
+
total_loss += loss.item() * accumulation_steps
|
| 174 |
+
# [FIX] Log LR cho từng param group — hiển thị decoder LR (group cuối) trên progress bar
|
| 175 |
+
decoder_lr = self.optimizer.param_groups[-1]['lr']
|
| 176 |
+
vision_lr = self.optimizer.param_groups[0]['lr']
|
| 177 |
+
if wandb.run:
|
| 178 |
+
wandb.log({
|
| 179 |
+
"batch_loss": loss.item(),
|
| 180 |
+
"lr_vision": vision_lr,
|
| 181 |
+
"lr_decoder": decoder_lr,
|
| 182 |
+
})
|
| 183 |
+
pbar.set_postfix({"loss": f"{loss.item():.3f}", "dec_lr": f"{decoder_lr:.1e}", "vis_lr": f"{vision_lr:.1e}"})
|
| 184 |
+
|
| 185 |
+
epoch_train_loss = total_loss / len(self.train_loader)
|
| 186 |
+
if wandb.run:
|
| 187 |
+
wandb.log({"train_loss_epoch": epoch_train_loss})
|
| 188 |
+
return epoch_train_loss
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def val_epoch(self, tokenizer, epoch=0):
|
| 192 |
+
"""
|
| 193 |
+
Thực hiện đánh giá trên tập Validation sau mỗi Epoch.
|
| 194 |
+
"""
|
| 195 |
+
from src.engine.medical_eval import evaluate_vqa
|
| 196 |
+
max_ans_len = self.config.get('data', {}).get('max_answer_len', 32)
|
| 197 |
+
max_words = self.config.get('data', {}).get('answer_max_words', 10)
|
| 198 |
+
print(f"\n🔍 Đang chạy Validation cho Epoch {epoch} (max_ans_len={max_ans_len})...")
|
| 199 |
+
metrics = evaluate_vqa(
|
| 200 |
+
self.model,
|
| 201 |
+
self.val_loader,
|
| 202 |
+
self.device,
|
| 203 |
+
tokenizer,
|
| 204 |
+
beam_width=self.beam_width,
|
| 205 |
+
max_len=max_ans_len,
|
| 206 |
+
max_words=max_words
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# In các metrics quan trọng
|
| 210 |
+
print(
|
| 211 |
+
f"[METRICS] Accuracy: {metrics.get('accuracy_normalized', metrics['accuracy']):.4f} | "
|
| 212 |
+
f"F1: {metrics.get('f1_normalized', metrics['f1']):.4f} | "
|
| 213 |
+
f"BLEU-4: {metrics.get('bleu4_normalized', metrics['bleu4']):.4f}"
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
if wandb.run:
|
| 217 |
+
wandb.log({
|
| 218 |
+
"epoch": epoch,
|
| 219 |
+
"val_accuracy": metrics["accuracy"],
|
| 220 |
+
"val_accuracy_normalized": metrics.get("accuracy_normalized", metrics["accuracy"]),
|
| 221 |
+
"val_f1": metrics["f1"],
|
| 222 |
+
"val_f1_normalized": metrics.get("f1_normalized", metrics["f1"]),
|
| 223 |
+
"val_bleu4": metrics["bleu4"],
|
| 224 |
+
"val_bleu4_normalized": metrics.get("bleu4_normalized", metrics["bleu4"]),
|
| 225 |
+
"val_bert_score": metrics.get("bert_score", 0),
|
| 226 |
+
"val_bert_score_raw": metrics.get("bert_score_raw", metrics.get("bert_score", 0)),
|
| 227 |
+
"val_semantic_raw": metrics.get("semantic_raw", metrics.get("semantic", 0)),
|
| 228 |
+
})
|
| 229 |
+
|
| 230 |
+
return metrics
|
| 231 |
+
|
| 232 |
+
def train(self, epochs, tokenizer=None):
|
| 233 |
+
best_val_acc = 0.0
|
| 234 |
+
patience = self.config['train'].get('patience', 10)
|
| 235 |
+
counter = 0
|
| 236 |
+
ckpt_dir = "checkpoints"
|
| 237 |
+
os.makedirs(ckpt_dir, exist_ok=True)
|
| 238 |
+
history_dir = self.config.get("history_dir")
|
| 239 |
+
|
| 240 |
+
print(f"[INFO] Bắt đầu huấn luyện trong {epochs} epochs...")
|
| 241 |
+
|
| 242 |
+
# Log to WandB if available
|
| 243 |
+
if wandb.run is not None:
|
| 244 |
+
wandb.config.update({
|
| 245 |
+
'total_epochs': epochs,
|
| 246 |
+
'patience': patience,
|
| 247 |
+
'variant': self.config.get('variant', 'Unknown'),
|
| 248 |
+
'device': str(self.device),
|
| 249 |
+
'use_amp': self.use_amp,
|
| 250 |
+
})
|
| 251 |
+
for epoch in range(1, epochs + 1):
|
| 252 |
+
train_loss = self.train_epoch(epoch)
|
| 253 |
+
metrics = self.val_epoch(tokenizer, epoch=epoch)
|
| 254 |
+
|
| 255 |
+
val_acc = metrics.get('accuracy_normalized', metrics.get('accuracy', 0))
|
| 256 |
+
closed_eval = metrics.get("closed_eval", {})
|
| 257 |
+
open_eval = metrics.get("open_eval", {})
|
| 258 |
+
is_best = val_acc > best_val_acc
|
| 259 |
+
epoch_record = {
|
| 260 |
+
"epoch": epoch,
|
| 261 |
+
"train_loss": float(train_loss),
|
| 262 |
+
"val_accuracy": float(metrics.get("accuracy", 0.0)),
|
| 263 |
+
"val_accuracy_normalized": float(metrics.get("accuracy_normalized", metrics.get("accuracy", 0.0))),
|
| 264 |
+
"val_f1": float(metrics.get("f1", 0.0)),
|
| 265 |
+
"val_f1_normalized": float(metrics.get("f1_normalized", metrics.get("f1", 0.0))),
|
| 266 |
+
"val_bleu4": float(metrics.get("bleu4", 0.0)),
|
| 267 |
+
"val_bleu4_normalized": float(metrics.get("bleu4_normalized", metrics.get("bleu4", 0.0))),
|
| 268 |
+
"val_bert_score": float(metrics.get("bert_score", 0.0)),
|
| 269 |
+
"val_bert_score_raw": float(metrics.get("bert_score_raw", metrics.get("bert_score", 0.0))),
|
| 270 |
+
"val_semantic_raw": float(metrics.get("semantic_raw", metrics.get("semantic", 0.0))),
|
| 271 |
+
"val_closed_accuracy": float(closed_eval.get("accuracy", metrics.get("closed", {}).get("accuracy", -1))),
|
| 272 |
+
"val_closed_em": float(closed_eval.get("em", metrics.get("closed", {}).get("em", -1))),
|
| 273 |
+
"val_closed_f1": float(closed_eval.get("f1", metrics.get("closed", {}).get("f1", -1))),
|
| 274 |
+
"val_open_accuracy": float(metrics.get("open", {}).get("accuracy", -1)),
|
| 275 |
+
"val_open_semantic": float(open_eval.get("semantic", metrics.get("open", {}).get("semantic", -1))),
|
| 276 |
+
"val_open_bertscore": float(open_eval.get("bert_score", metrics.get("open", {}).get("bert_score", -1))),
|
| 277 |
+
"val_open_f1": float(open_eval.get("f1", metrics.get("open", {}).get("f1", -1))),
|
| 278 |
+
"val_open_rouge_l": float(open_eval.get("rouge_l", metrics.get("open", {}).get("rouge_l", -1))),
|
| 279 |
+
"best_so_far": bool(is_best),
|
| 280 |
+
"metrics": metrics,
|
| 281 |
+
}
|
| 282 |
+
self.history.append(epoch_record)
|
| 283 |
+
|
| 284 |
+
# Kiểm tra và Lưu Best Checkpoint
|
| 285 |
+
if is_best:
|
| 286 |
+
best_val_acc = val_acc
|
| 287 |
+
counter = 0
|
| 288 |
+
variant = self.config.get('variant', 'A')
|
| 289 |
+
save_path = os.path.join(ckpt_dir, f"medical_vqa_{variant}_best.pth")
|
| 290 |
+
torch.save(self.model.state_dict(), save_path)
|
| 291 |
+
|
| 292 |
+
resume_path = os.path.join(ckpt_dir, f"medical_vqa_{variant}_resume.pth")
|
| 293 |
+
checkpoint = {
|
| 294 |
+
'epoch': epoch,
|
| 295 |
+
'model_state_dict': self.model.state_dict(),
|
| 296 |
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
| 297 |
+
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
|
| 298 |
+
'best_val_acc': best_val_acc,
|
| 299 |
+
'train_loss': float(train_loss),
|
| 300 |
+
}
|
| 301 |
+
torch.save(checkpoint, resume_path)
|
| 302 |
+
print(f"🌟 Best model saved with Accuracy: {val_acc:.4f}")
|
| 303 |
+
else:
|
| 304 |
+
counter += 1
|
| 305 |
+
if history_dir:
|
| 306 |
+
self.save_history(history_dir)
|
| 307 |
+
if counter >= patience:
|
| 308 |
+
print(f"🛑 Early stopping tại epoch {epoch}!")
|
| 309 |
+
break
|
| 310 |
+
|
| 311 |
+
print("[INFO] Huấn luyện hoàn tất.")
|
| 312 |
+
if history_dir:
|
| 313 |
+
self.save_history(history_dir)
|
| 314 |
+
|
| 315 |
+
# ── Auto-plot sau khi training kết thúc ──────────────────────────────
|
| 316 |
+
if history_dir and len(self.history) >= 1:
|
| 317 |
+
chart_paths = self.plot_training_results(history_dir)
|
| 318 |
+
print(f"[INFO] 📊 Đã lưu {len(chart_paths)} biểu đồ tại: {history_dir}")
|
| 319 |
+
|
| 320 |
+
return self.history
|
| 321 |
+
|
| 322 |
+
# ── Visualization ────────────────────────────────────────────────────────
|
| 323 |
+
|
| 324 |
+
def plot_training_results(self, output_dir: str) -> list:
|
| 325 |
+
"""
|
| 326 |
+
Tự động vẽ và lưu 4 biểu đồ sau khi training kết thúc:
|
| 327 |
+
1. Train Loss theo epoch
|
| 328 |
+
2. Val Accuracy + F1 + BLEU-4 (multi-metric)
|
| 329 |
+
3. Closed vs Open Accuracy (bar per epoch)
|
| 330 |
+
4. BERTScore + Semantic Score
|
| 331 |
+
Trả về list các đường dẫn file ảnh đã lưu.
|
| 332 |
+
"""
|
| 333 |
+
try:
|
| 334 |
+
import matplotlib
|
| 335 |
+
matplotlib.use("Agg") # Non-interactive backend (an toàn cho server)
|
| 336 |
+
import matplotlib.pyplot as plt
|
| 337 |
+
import matplotlib.ticker as mticker
|
| 338 |
+
except ImportError:
|
| 339 |
+
print("[WARNING] matplotlib chưa cài — bỏ qua vẽ biểu đồ.")
|
| 340 |
+
return []
|
| 341 |
+
|
| 342 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 343 |
+
variant = self.config.get('variant', 'Model')
|
| 344 |
+
epochs = [r["epoch"] for r in self.history]
|
| 345 |
+
saved = []
|
| 346 |
+
|
| 347 |
+
# Palette
|
| 348 |
+
COLORS = {
|
| 349 |
+
"loss": "#e74c3c",
|
| 350 |
+
"accuracy": "#2ecc71",
|
| 351 |
+
"f1": "#3498db",
|
| 352 |
+
"bleu4": "#9b59b6",
|
| 353 |
+
"bert": "#e67e22",
|
| 354 |
+
"semantic": "#1abc9c",
|
| 355 |
+
"closed": "#2980b9",
|
| 356 |
+
"open": "#e74c3c",
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
def _finish(fig, fname):
|
| 360 |
+
fig.tight_layout()
|
| 361 |
+
path = os.path.join(output_dir, fname)
|
| 362 |
+
fig.savefig(path, dpi=150, bbox_inches="tight")
|
| 363 |
+
plt.close(fig)
|
| 364 |
+
# Upload to WandB if available
|
| 365 |
+
if wandb.run:
|
| 366 |
+
wandb.log({fname.replace(".png", ""): wandb.Image(path)})
|
| 367 |
+
saved.append(path)
|
| 368 |
+
|
| 369 |
+
# ── Chart 1: Train Loss ──────────────────────────────────────────────
|
| 370 |
+
fig, ax = plt.subplots(figsize=(9, 5))
|
| 371 |
+
ax.plot(epochs, [r["train_loss"] for r in self.history],
|
| 372 |
+
color=COLORS["loss"], linewidth=2.5, marker="o", markersize=5,
|
| 373 |
+
label="Train Loss")
|
| 374 |
+
ax.set_title(f"[{variant}] Train Loss per Epoch", fontsize=14, fontweight="bold")
|
| 375 |
+
ax.set_xlabel("Epoch"); ax.set_ylabel("Loss")
|
| 376 |
+
ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
|
| 377 |
+
ax.legend(); ax.grid(True, alpha=0.3)
|
| 378 |
+
_finish(fig, f"{variant}_01_train_loss.png")
|
| 379 |
+
|
| 380 |
+
# ── Chart 2: Validation Metrics (Acc / F1 / BLEU-4) ─────────────────
|
| 381 |
+
fig, ax = plt.subplots(figsize=(10, 5))
|
| 382 |
+
ax.plot(epochs, [r["val_accuracy_normalized"] for r in self.history],
|
| 383 |
+
color=COLORS["accuracy"], linewidth=2.5, marker="o", label="Accuracy")
|
| 384 |
+
ax.plot(epochs, [r["val_f1_normalized"] for r in self.history],
|
| 385 |
+
color=COLORS["f1"], linewidth=2.5, marker="s", label="F1")
|
| 386 |
+
ax.plot(epochs, [r["val_bleu4_normalized"] for r in self.history],
|
| 387 |
+
color=COLORS["bleu4"], linewidth=2.5, marker="^", label="BLEU-4")
|
| 388 |
+
# Mark best epoch
|
| 389 |
+
best_epoch = max(self.history, key=lambda r: r["val_accuracy_normalized"])
|
| 390 |
+
ax.axvline(x=best_epoch["epoch"], color="gray", linestyle="--", alpha=0.6,
|
| 391 |
+
label=f"Best epoch {best_epoch['epoch']} ({best_epoch['val_accuracy_normalized']:.2%})")
|
| 392 |
+
ax.set_title(f"[{variant}] Validation Metrics per Epoch", fontsize=14, fontweight="bold")
|
| 393 |
+
ax.set_xlabel("Epoch"); ax.set_ylabel("Score")
|
| 394 |
+
ax.set_ylim(0, 1.05)
|
| 395 |
+
ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
|
| 396 |
+
ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
|
| 397 |
+
ax.legend(loc="lower right"); ax.grid(True, alpha=0.3)
|
| 398 |
+
_finish(fig, f"{variant}_02_val_metrics.png")
|
| 399 |
+
|
| 400 |
+
# ── Chart 3: Closed vs Open Accuracy ────────────────────────────────
|
| 401 |
+
closed_vals = [r["val_closed_accuracy"] for r in self.history]
|
| 402 |
+
open_vals = [r["val_open_accuracy"] for r in self.history]
|
| 403 |
+
has_closed = any(v >= 0 for v in closed_vals)
|
| 404 |
+
has_open = any(v >= 0 for v in open_vals)
|
| 405 |
+
|
| 406 |
+
if has_closed or has_open:
|
| 407 |
+
fig, ax = plt.subplots(figsize=(10, 5))
|
| 408 |
+
w = 0.35
|
| 409 |
+
x = range(len(epochs))
|
| 410 |
+
if has_closed:
|
| 411 |
+
c_vals = [v if v >= 0 else 0 for v in closed_vals]
|
| 412 |
+
ax.bar([i - w/2 for i in x], c_vals, w, label="Closed (Yes/No)",
|
| 413 |
+
color=COLORS["closed"], alpha=0.85)
|
| 414 |
+
if has_open:
|
| 415 |
+
o_vals = [v if v >= 0 else 0 for v in open_vals]
|
| 416 |
+
ax.bar([i + w/2 for i in x], o_vals, w, label="Open-ended",
|
| 417 |
+
color=COLORS["open"], alpha=0.85)
|
| 418 |
+
ax.set_xticks(list(x)); ax.set_xticklabels([f"E{e}" for e in epochs])
|
| 419 |
+
ax.set_title(f"[{variant}] Closed vs Open Accuracy per Epoch",
|
| 420 |
+
fontsize=14, fontweight="bold")
|
| 421 |
+
ax.set_ylabel("Accuracy")
|
| 422 |
+
ax.set_ylim(0, 1.05)
|
| 423 |
+
ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
|
| 424 |
+
ax.legend(); ax.grid(True, alpha=0.3, axis="y")
|
| 425 |
+
_finish(fig, f"{variant}_03_closed_vs_open.png")
|
| 426 |
+
|
| 427 |
+
# ── Chart 4: BERTScore + Semantic Score ──────────────────────────────
|
| 428 |
+
bert_vals = [r["val_bert_score_raw"] for r in self.history]
|
| 429 |
+
semantic_vals = [r["val_semantic_raw"] for r in self.history]
|
| 430 |
+
if any(v > 0 for v in bert_vals + semantic_vals):
|
| 431 |
+
fig, ax = plt.subplots(figsize=(9, 5))
|
| 432 |
+
ax.plot(epochs, bert_vals, color=COLORS["bert"], linewidth=2.5,
|
| 433 |
+
marker="o", label="BERTScore")
|
| 434 |
+
ax.plot(epochs, semantic_vals, color=COLORS["semantic"], linewidth=2.5,
|
| 435 |
+
marker="s", label="Semantic Score")
|
| 436 |
+
ax.set_title(f"[{variant}] BERTScore & Semantic Score per Epoch",
|
| 437 |
+
fontsize=14, fontweight="bold")
|
| 438 |
+
ax.set_xlabel("Epoch"); ax.set_ylabel("Score")
|
| 439 |
+
ax.set_ylim(0, 1.05)
|
| 440 |
+
ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
|
| 441 |
+
ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
|
| 442 |
+
ax.legend(); ax.grid(True, alpha=0.3)
|
| 443 |
+
_finish(fig, f"{variant}_04_bert_semantic.png")
|
| 444 |
+
|
| 445 |
+
# ── Print final summary table ─────────────────────────────────────────
|
| 446 |
+
print("\n" + "═" * 72)
|
| 447 |
+
print(f" 📊 TRAINING SUMMARY — {variant}")
|
| 448 |
+
print("═" * 72)
|
| 449 |
+
print(f" {'Epoch':>5} {'TrainLoss':>10} {'Accuracy':>9} {'F1':>7} {'BLEU-4':>7} {'Best':>5}")
|
| 450 |
+
print("─" * 72)
|
| 451 |
+
for r in self.history:
|
| 452 |
+
star = " ★" if r.get("best_so_far") else ""
|
| 453 |
+
print(
|
| 454 |
+
f" {r['epoch']:>5} {r['train_loss']:>10.4f} "
|
| 455 |
+
f"{r['val_accuracy_normalized']:>9.2%} "
|
| 456 |
+
f"{r['val_f1_normalized']:>7.2%} "
|
| 457 |
+
f"{r['val_bleu4_normalized']:>7.2%}{star}"
|
| 458 |
+
)
|
| 459 |
+
print("═" * 72 + "\n")
|
| 460 |
+
|
| 461 |
+
return saved
|
src/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Initialize src.models package
|
src/models/encoder.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torchxrayvision as xrv
|
| 4 |
+
|
| 5 |
+
class MedicalImageEncoder(nn.Module):
|
| 6 |
+
"""
|
| 7 |
+
SOTA Image Encoder sử dụng DenseNet-121 (TorchXRayVision)
|
| 8 |
+
Pretrained trên 200K+ ảnh X-ray (CheXpert, NIH, v.v.)
|
| 9 |
+
"""
|
| 10 |
+
def __init__(self, pretrained=True):
|
| 11 |
+
super(MedicalImageEncoder, self).__init__()
|
| 12 |
+
if pretrained:
|
| 13 |
+
self.model = xrv.models.DenseNet(weights="densenet121-res224-chex")
|
| 14 |
+
else:
|
| 15 |
+
self.model = xrv.models.DenseNet(weights=None)
|
| 16 |
+
|
| 17 |
+
self.model.classifier = nn.Identity() # Bỏ lớp phân loại
|
| 18 |
+
self.projector = nn.Linear(1024, 768) # Map về dimension của PhoBERT
|
| 19 |
+
|
| 20 |
+
def forward(self, x):
|
| 21 |
+
feat_map = self.model.features(x) # [B, 1024, 7, 7]
|
| 22 |
+
feat_map = feat_map.flatten(2).transpose(1, 2) # [B, 49, 1024]
|
| 23 |
+
return self.projector(feat_map) # [B, 49, 768]
|
src/models/medical_vqa_model.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .encoder import MedicalImageEncoder
|
| 4 |
+
from .phobert_encoder import PhoBERTEncoder
|
| 5 |
+
from .transformer_decoder import MedicalVQADecoder
|
| 6 |
+
|
| 7 |
+
class CoAttentionFusion(nn.Module):
|
| 8 |
+
"""
|
| 9 |
+
Cơ chế Co-Attention giúp mô hình tập trung vào các vùng ảnh và từ ngữ liên quan lẫn nhau.
|
| 10 |
+
"""
|
| 11 |
+
def __init__(self, hidden_size=768, nhead=8):
|
| 12 |
+
super(CoAttentionFusion, self).__init__()
|
| 13 |
+
# Cross-modal attention: Ảnh hỏi Chữ và Chữ hỏi Ảnh
|
| 14 |
+
self.v2t_attn = nn.MultiheadAttention(hidden_size, nhead, batch_first=True)
|
| 15 |
+
self.t2v_attn = nn.MultiheadAttention(hidden_size, nhead, batch_first=True)
|
| 16 |
+
|
| 17 |
+
self.fusion_layer = nn.Sequential(
|
| 18 |
+
nn.Linear(hidden_size * 2, hidden_size),
|
| 19 |
+
nn.LayerNorm(hidden_size),
|
| 20 |
+
nn.ReLU(),
|
| 21 |
+
nn.Dropout(0.1)
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
def forward(self, v_feats, t_feats):
|
| 25 |
+
# v_feats: [B, 49, 768] — KHÔNG cần unsqueeze nữa
|
| 26 |
+
t_seq = t_feats.unsqueeze(1) # [B, 1, 768] — text vẫn giữ
|
| 27 |
+
|
| 28 |
+
# Parallel Co-Attention
|
| 29 |
+
v_fused, _ = self.v2t_attn(v_feats, t_seq, t_seq)
|
| 30 |
+
t_fused, _ = self.t2v_attn(t_seq, v_feats, v_feats)
|
| 31 |
+
|
| 32 |
+
# v_fused: [B, 49, 768] → pool về [B, 1, 768] trước khi concat
|
| 33 |
+
v_fused = v_fused.mean(dim=1, keepdim=True)
|
| 34 |
+
|
| 35 |
+
# Kết hợp thông tin từ cả hai hướng
|
| 36 |
+
combined = torch.cat([v_fused, t_fused], dim=-1) # [B, 1, 1536]
|
| 37 |
+
return self.fusion_layer(combined) # [B, 1, 768]
|
| 38 |
+
|
| 39 |
+
class MedicalVQAModelA(nn.Module):
|
| 40 |
+
"""
|
| 41 |
+
Kiến trúc rời (Hướng A) cho Medical VQA Tiếng Việt.
|
| 42 |
+
Sử dụng DenseNet-121 (XRV) + PhoBERT + Co-Attention + Dual-Head Decoder.
|
| 43 |
+
"""
|
| 44 |
+
def __init__(self, decoder_type="transformer", vocab_size=30000, hidden_size=768, phobert_model=None, **kwargs):
|
| 45 |
+
super(MedicalVQAModelA, self).__init__()
|
| 46 |
+
|
| 47 |
+
# 1. Image Encoder (DenseNet-121 XRV)
|
| 48 |
+
self.image_encoder = MedicalImageEncoder(pretrained=True)
|
| 49 |
+
|
| 50 |
+
# 2. Text Encoder (PhoBERT)
|
| 51 |
+
self.text_encoder = PhoBERTEncoder(model_name=phobert_model) if phobert_model else PhoBERTEncoder()
|
| 52 |
+
|
| 53 |
+
# 3. Fusion Layer (Co-Attention Fusion)
|
| 54 |
+
self.fusion = CoAttentionFusion(hidden_size=hidden_size, nhead=8)
|
| 55 |
+
|
| 56 |
+
# 4. Trích xuất pretrained embeddings từ PhoBERT cho Decoder
|
| 57 |
+
phobert_embeddings = self.text_encoder.bert.embeddings.word_embeddings.weight
|
| 58 |
+
actual_vocab_size = phobert_embeddings.size(0)
|
| 59 |
+
|
| 60 |
+
# 5. Decoder (LSTM / Transformer)
|
| 61 |
+
self.decoder = MedicalVQADecoder(
|
| 62 |
+
decoder_type=decoder_type,
|
| 63 |
+
vocab_size=actual_vocab_size,
|
| 64 |
+
hidden_size=hidden_size,
|
| 65 |
+
pretrained_embeddings=phobert_embeddings
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
def forward(self, images, input_ids, attention_mask, labels_open=None, labels_closed=None):
|
| 69 |
+
v_feats = self.image_encoder(images)
|
| 70 |
+
t_feats = self.text_encoder(input_ids, attention_mask)
|
| 71 |
+
fused = self.fusion(v_feats, t_feats)
|
| 72 |
+
|
| 73 |
+
logits_closed, logits_open = self.decoder(fused, labels_open)
|
| 74 |
+
|
| 75 |
+
return logits_closed, logits_open
|
| 76 |
+
|
| 77 |
+
def generate(self, images, input_ids, attention_mask, beam_width=1, max_len=10):
|
| 78 |
+
"""
|
| 79 |
+
Giao diện chuyên biệt cho quá trình Inference (chỉ trả token IDs cho open-ended).
|
| 80 |
+
"""
|
| 81 |
+
v_feats = self.image_encoder(images)
|
| 82 |
+
t_feats = self.text_encoder(input_ids, attention_mask)
|
| 83 |
+
fused = self.fusion(v_feats, t_feats)
|
| 84 |
+
|
| 85 |
+
return self.decoder.generate(fused, beam_width=beam_width, max_len=max_len)
|
| 86 |
+
|
| 87 |
+
def inference(self, images, input_ids, attention_mask, beam_width=1, max_len=10):
|
| 88 |
+
"""
|
| 89 |
+
[NEW] Trả về CẢ HAI dual-head outputs:
|
| 90 |
+
- logits_closed: [B, 2] — dùng cho câu Yes/No (classifier head)
|
| 91 |
+
- generated_ids: [B, max_len] — dùng cho câu mở (generative head)
|
| 92 |
+
"""
|
| 93 |
+
v_feats = self.image_encoder(images)
|
| 94 |
+
t_feats = self.text_encoder(input_ids, attention_mask)
|
| 95 |
+
fused = self.fusion(v_feats, t_feats)
|
| 96 |
+
|
| 97 |
+
logits_closed = self.decoder.classifier_head(fused.squeeze(1)) # [B, 2]
|
| 98 |
+
generated_ids = self.decoder.generate(fused, beam_width=beam_width, max_len=max_len) # [B, max_len]
|
| 99 |
+
|
| 100 |
+
return logits_closed, generated_ids
|
src/models/multimodal_vqa.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import LlavaProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
|
| 3 |
+
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
|
| 4 |
+
|
| 5 |
+
class MultimodalVQA:
|
| 6 |
+
"""
|
| 7 |
+
Wrapper cho LLaVA-Med-7B tích hợp QLoRA 4-bit để huấn luyện trên Kaggle.
|
| 8 |
+
Sử dụng kiến trúc LLaVA-1.5 (microsoft/llava-med-v1.5-7b).
|
| 9 |
+
"""
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
model_id="chaoyinshe/llava-med-v1.5-mistral-7b-hf",
|
| 13 |
+
lora_r=16,
|
| 14 |
+
lora_alpha=32,
|
| 15 |
+
lora_dropout=0.05,
|
| 16 |
+
lora_target_modules=None,
|
| 17 |
+
):
|
| 18 |
+
self.model_id = model_id
|
| 19 |
+
|
| 20 |
+
# 1. Cấu hình Quantization 4-bit (Tiết kiệm VRAM)
|
| 21 |
+
self.bnb_config = BitsAndBytesConfig(
|
| 22 |
+
load_in_4bit=True,
|
| 23 |
+
bnb_4bit_use_double_quant=True,
|
| 24 |
+
bnb_4bit_quant_type="nf4",
|
| 25 |
+
bnb_4bit_compute_dtype=torch.bfloat16
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# 2. Cấu hình LoRA (Chỉ huấn luyện một phần nhỏ tham số)
|
| 29 |
+
self.peft_config = LoraConfig(
|
| 30 |
+
r=lora_r,
|
| 31 |
+
lora_alpha=lora_alpha,
|
| 32 |
+
target_modules=lora_target_modules or ["q_proj", "v_proj", "k_proj", "o_proj"],
|
| 33 |
+
lora_dropout=lora_dropout,
|
| 34 |
+
bias="none",
|
| 35 |
+
task_type="CAUSAL_LM"
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
def load_model(self, adapter_path=None, is_trainable=True):
|
| 39 |
+
print(f"[INFO] Đang tải LLaVA-Med-v1.5-7B với chế độ 4-bit...")
|
| 40 |
+
processor = LlavaProcessor.from_pretrained(self.model_id)
|
| 41 |
+
processor.tokenizer.padding_side = "left" # Bắt buộc cho decoder-only models
|
| 42 |
+
model = LlavaForConditionalGeneration.from_pretrained(
|
| 43 |
+
self.model_id,
|
| 44 |
+
quantization_config=self.bnb_config,
|
| 45 |
+
device_map="auto"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
model.config.use_cache = False
|
| 49 |
+
|
| 50 |
+
# Chuẩn bị mô hình cho PEFT
|
| 51 |
+
model = prepare_model_for_kbit_training(model)
|
| 52 |
+
if adapter_path:
|
| 53 |
+
print(f"[INFO] Đang nạp adapter LoRA từ: {adapter_path}")
|
| 54 |
+
model = PeftModel.from_pretrained(model, adapter_path, is_trainable=is_trainable)
|
| 55 |
+
else:
|
| 56 |
+
model = get_peft_model(model, self.peft_config)
|
| 57 |
+
model.gradient_checkpointing_enable()
|
| 58 |
+
model.enable_input_require_grads()
|
| 59 |
+
|
| 60 |
+
model.print_trainable_parameters()
|
| 61 |
+
return model, processor
|
| 62 |
+
|
| 63 |
+
def generate_prompt_vi(self, question_en):
|
| 64 |
+
"""
|
| 65 |
+
Hàm hỗ trợ tạo prompt cho LLaVA-Med (EN).
|
| 66 |
+
Nhớ dùng Translation Layer trước khi gọi hàm này.
|
| 67 |
+
"""
|
| 68 |
+
return self.build_instruction_prompt(question_en, language="en", include_answer=False)
|
| 69 |
+
|
| 70 |
+
def build_instruction_prompt(self, question, language="vi", include_answer=False):
|
| 71 |
+
"""
|
| 72 |
+
Prompt thống nhất cho zero-shot, SFT và demo.
|
| 73 |
+
"""
|
| 74 |
+
if language == "vi":
|
| 75 |
+
instruction = "Chi tra loi bang tieng Viet, khong dung tieng Anh, thuat ngu y khoa chuan, ngan gon, toi da 10 tu."
|
| 76 |
+
else:
|
| 77 |
+
instruction = "Answer with standard medical terminology, concise, at most 10 words."
|
| 78 |
+
suffix = " ASSISTANT:" if not include_answer else ""
|
| 79 |
+
return f"USER: <image>\n{question}\n{instruction}{suffix}"
|
src/models/phobert_encoder.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from transformers import AutoModel
|
| 3 |
+
|
| 4 |
+
class PhoBERTEncoder(nn.Module):
|
| 5 |
+
"""
|
| 6 |
+
Text Encoder sử dụng PhoBERT pretrained.
|
| 7 |
+
Hỗ trợ tiếng Việt tốt nhất cho Medical VQA.
|
| 8 |
+
"""
|
| 9 |
+
def __init__(self, model_name="vinai/phobert-base", freeze_layers=10):
|
| 10 |
+
super(PhoBERTEncoder, self).__init__()
|
| 11 |
+
self.bert = AutoModel.from_pretrained(model_name, use_safetensors=True)
|
| 12 |
+
|
| 13 |
+
# Đóng băng các lớp Transformer đầu tiên nếu cần
|
| 14 |
+
if freeze_layers > 0:
|
| 15 |
+
for param in self.bert.embeddings.parameters():
|
| 16 |
+
param.requires_grad = False
|
| 17 |
+
for layer in self.bert.encoder.layer[:freeze_layers]:
|
| 18 |
+
for param in layer.parameters():
|
| 19 |
+
param.requires_grad = False
|
| 20 |
+
|
| 21 |
+
def forward(self, input_ids, attention_mask):
|
| 22 |
+
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
| 23 |
+
# Lấy [CLS] token đại diện cho toàn bộ câu hỏi
|
| 24 |
+
return outputs.last_hidden_state[:, 0, :]
|
src/models/transformer_decoder.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MedicalVQADecoder(nn.Module):
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
decoder_type: str = "transformer",
|
| 10 |
+
vocab_size: int = 30000,
|
| 11 |
+
hidden_size: int = 768,
|
| 12 |
+
pretrained_embeddings=None,
|
| 13 |
+
num_layers: int = 3,
|
| 14 |
+
nhead: int = 8,
|
| 15 |
+
dropout: float = 0.1,
|
| 16 |
+
):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.decoder_type = decoder_type.lower()
|
| 19 |
+
self.vocab_size = vocab_size
|
| 20 |
+
self.hidden_size = hidden_size
|
| 21 |
+
|
| 22 |
+
# ── Nhánh 1: Classifier cho Yes/No ──────────────────────────────────
|
| 23 |
+
# [FIX] Thêm Dropout + GELU theo best-practice hiện đại
|
| 24 |
+
self.classifier_head = nn.Sequential(
|
| 25 |
+
nn.Linear(hidden_size, 512),
|
| 26 |
+
nn.GELU(),
|
| 27 |
+
nn.Dropout(dropout),
|
| 28 |
+
nn.Linear(512, 2),
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# ── Nhánh 2: Generator ───────────────────────────────────────────────
|
| 32 |
+
self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
|
| 33 |
+
if pretrained_embeddings is not None:
|
| 34 |
+
self.embedding.weight.data.copy_(pretrained_embeddings)
|
| 35 |
+
|
| 36 |
+
if self.decoder_type == "lstm":
|
| 37 |
+
self.generator = nn.LSTM(
|
| 38 |
+
hidden_size, hidden_size, num_layers=1, batch_first=True
|
| 39 |
+
)
|
| 40 |
+
else:
|
| 41 |
+
# [FIX A2] Pre-LayerNorm (norm_first=True): hội tụ ổn định hơn, giảm gap A1-A2
|
| 42 |
+
# dim_feedforward=4*hidden (768*4=3072) theo chuẩn Transformer gốc
|
| 43 |
+
decoder_layer = nn.TransformerDecoderLayer(
|
| 44 |
+
d_model=hidden_size,
|
| 45 |
+
nhead=nhead,
|
| 46 |
+
dim_feedforward=hidden_size * 4,
|
| 47 |
+
dropout=dropout,
|
| 48 |
+
activation="gelu",
|
| 49 |
+
batch_first=True,
|
| 50 |
+
norm_first=True,
|
| 51 |
+
)
|
| 52 |
+
self.generator = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
|
| 53 |
+
|
| 54 |
+
self.output_layer = nn.Linear(hidden_size, vocab_size, bias=False)
|
| 55 |
+
|
| 56 |
+
# [OPTIMIZATION] Weight Tying: chia sẻ trọng số Embedding ↔ Output Projection
|
| 57 |
+
# Giảm ~vocab_size * hidden_size params, cải thiện generalization (Press & Wolf 2017)
|
| 58 |
+
self.output_layer.weight = self.embedding.weight
|
| 59 |
+
|
| 60 |
+
# [OPTIMIZATION] Cache causal mask để tránh re-allocate mỗi forward pass
|
| 61 |
+
self._causal_mask_cache: dict[tuple, torch.Tensor] = {}
|
| 62 |
+
|
| 63 |
+
# ── Mask helper ─────────────────────────────────────────────────────────
|
| 64 |
+
def _get_causal_mask(self, sz: int, device: torch.device) -> torch.Tensor:
|
| 65 |
+
key = (sz, str(device))
|
| 66 |
+
if key not in self._causal_mask_cache:
|
| 67 |
+
mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool()
|
| 68 |
+
self._causal_mask_cache[key] = mask
|
| 69 |
+
return self._causal_mask_cache[key]
|
| 70 |
+
|
| 71 |
+
# ── Public generate API ──────────────────────────────────────────────────
|
| 72 |
+
def generate(self, fused_features, beam_width: int = 1, max_len: int = 10):
|
| 73 |
+
"""Sinh câu trả lời. Trả về token IDs [B, max_len]."""
|
| 74 |
+
if beam_width <= 1:
|
| 75 |
+
return self._greedy_search(fused_features, max_len)
|
| 76 |
+
return self._beam_search(fused_features, beam_width, max_len)
|
| 77 |
+
|
| 78 |
+
# ── Greedy Search ────────────────────────────────────────────────────────
|
| 79 |
+
def _greedy_search(self, fused_features, max_len: int):
|
| 80 |
+
"""
|
| 81 |
+
Greedy decoding (beam_width=1).
|
| 82 |
+
LSTM: chỉ feed token cuối, h_state giữ ngữ cảnh → tránh O(n²) recompute.
|
| 83 |
+
Trả về token IDs [B, max_len].
|
| 84 |
+
"""
|
| 85 |
+
batch_size = fused_features.size(0)
|
| 86 |
+
device = fused_features.device
|
| 87 |
+
generated = torch.zeros((batch_size, 1), dtype=torch.long, device=device) # BOS=0
|
| 88 |
+
h_state = None
|
| 89 |
+
|
| 90 |
+
for _ in range(max_len):
|
| 91 |
+
if self.decoder_type == "lstm":
|
| 92 |
+
curr_emb = self.embedding(generated[:, -1:]) # [B,1,H]
|
| 93 |
+
if h_state is None:
|
| 94 |
+
h0 = fused_features.transpose(0, 1).contiguous()
|
| 95 |
+
h_state = (h0, torch.zeros_like(h0))
|
| 96 |
+
outputs, h_state = self.generator(curr_emb, h_state)
|
| 97 |
+
else:
|
| 98 |
+
curr_emb = self.embedding(generated)
|
| 99 |
+
tgt_mask = self._get_causal_mask(generated.size(1), device)
|
| 100 |
+
outputs = self.generator(curr_emb, fused_features, tgt_mask=tgt_mask)
|
| 101 |
+
|
| 102 |
+
next_token = self.output_layer(outputs[:, -1:, :]).argmax(dim=-1)
|
| 103 |
+
generated = torch.cat([generated, next_token], dim=1)
|
| 104 |
+
|
| 105 |
+
return generated[:, 1:] # Bỏ BOS
|
| 106 |
+
|
| 107 |
+
# ── Beam Search ──────────────────────────────────────────────────────────
|
| 108 |
+
def _beam_search(
|
| 109 |
+
self,
|
| 110 |
+
fused_features,
|
| 111 |
+
beam_width: int,
|
| 112 |
+
max_len: int,
|
| 113 |
+
repetition_penalty: float = 1.2,
|
| 114 |
+
alpha: float = 0.7,
|
| 115 |
+
):
|
| 116 |
+
"""
|
| 117 |
+
Beam Search với Length Normalization + Vectorised Repetition Penalty.
|
| 118 |
+
[FIX] Thay vòng for Python sang tensor ops để tăng tốc ~3-5× trên GPU.
|
| 119 |
+
Trả về token IDs [B, max_len].
|
| 120 |
+
"""
|
| 121 |
+
batch_size = fused_features.size(0)
|
| 122 |
+
device = fused_features.device
|
| 123 |
+
all_results = []
|
| 124 |
+
|
| 125 |
+
for b in range(batch_size):
|
| 126 |
+
feat = fused_features[b:b+1] # [1, 1, H]
|
| 127 |
+
beams = [(torch.zeros((1, 1), dtype=torch.long, device=device), 0.0, None)]
|
| 128 |
+
|
| 129 |
+
for _ in range(max_len):
|
| 130 |
+
new_beams = []
|
| 131 |
+
for seq, score, h_state in beams:
|
| 132 |
+
if seq[0, -1].item() == 2: # EOS
|
| 133 |
+
new_beams.append((seq, score, h_state))
|
| 134 |
+
continue
|
| 135 |
+
|
| 136 |
+
if self.decoder_type == "lstm":
|
| 137 |
+
curr_emb = self.embedding(seq[:, -1:])
|
| 138 |
+
if h_state is None:
|
| 139 |
+
h0 = feat.transpose(0, 1).contiguous()
|
| 140 |
+
h_state = (h0, torch.zeros_like(h0))
|
| 141 |
+
outputs, next_h = self.generator(curr_emb, h_state)
|
| 142 |
+
else:
|
| 143 |
+
curr_emb = self.embedding(seq)
|
| 144 |
+
tgt_mask = self._get_causal_mask(seq.size(1), device)
|
| 145 |
+
outputs = self.generator(curr_emb, feat, tgt_mask=tgt_mask)
|
| 146 |
+
next_h = None
|
| 147 |
+
|
| 148 |
+
logits = self.output_layer(outputs[:, -1, :]).squeeze(0) # [V]
|
| 149 |
+
|
| 150 |
+
# [OPTIMIZED] Vectorised Repetition Penalty (thay vòng for Python)
|
| 151 |
+
unique_ids = seq[0].unique()
|
| 152 |
+
valid_ids = unique_ids[(unique_ids != 0) & (unique_ids != 2)]
|
| 153 |
+
if valid_ids.numel() > 0:
|
| 154 |
+
neg_mask = logits[valid_ids] < 0
|
| 155 |
+
factors = torch.where(
|
| 156 |
+
neg_mask,
|
| 157 |
+
torch.full_like(logits[valid_ids], repetition_penalty),
|
| 158 |
+
torch.full_like(logits[valid_ids], 1.0 / repetition_penalty),
|
| 159 |
+
)
|
| 160 |
+
logits = logits.clone()
|
| 161 |
+
logits[valid_ids] = logits[valid_ids] * factors
|
| 162 |
+
|
| 163 |
+
log_probs = F.log_softmax(logits, dim=-1)
|
| 164 |
+
topk_log_probs, topk_ids = torch.topk(log_probs, beam_width)
|
| 165 |
+
|
| 166 |
+
for i in range(beam_width):
|
| 167 |
+
new_seq = torch.cat([seq, topk_ids[i].view(1, 1)], dim=1)
|
| 168 |
+
new_beams.append((new_seq, score + topk_log_probs[i].item(), next_h))
|
| 169 |
+
|
| 170 |
+
def _norm_score(beam):
|
| 171 |
+
seq_len = max(beam[0].size(1) - 1, 1)
|
| 172 |
+
return beam[1] / (seq_len ** alpha)
|
| 173 |
+
|
| 174 |
+
new_beams.sort(key=_norm_score, reverse=True)
|
| 175 |
+
beams = new_beams[:beam_width]
|
| 176 |
+
|
| 177 |
+
if all(bm[0][0, -1].item() == 2 for bm in beams):
|
| 178 |
+
break
|
| 179 |
+
|
| 180 |
+
beams.sort(key=_norm_score, reverse=True)
|
| 181 |
+
best_seq = beams[0][0][:, 1:] # Bỏ BOS
|
| 182 |
+
|
| 183 |
+
if best_seq.size(1) < max_len:
|
| 184 |
+
pad = torch.zeros((1, max_len - best_seq.size(1)), dtype=torch.long, device=device)
|
| 185 |
+
best_seq = torch.cat([best_seq, pad], dim=1)
|
| 186 |
+
else:
|
| 187 |
+
best_seq = best_seq[:, :max_len]
|
| 188 |
+
all_results.append(best_seq)
|
| 189 |
+
|
| 190 |
+
return torch.cat(all_results, dim=0) # [B, max_len]
|
| 191 |
+
|
| 192 |
+
# ── Training Forward ─────────────────────────────────────────────────────
|
| 193 |
+
def forward(self, fused_features, target_ids=None, beam_width: int = 1):
|
| 194 |
+
"""
|
| 195 |
+
fused_features: [B, 1, H]
|
| 196 |
+
target_ids: [B, SeqLen] — Teacher Forcing; None → inference
|
| 197 |
+
"""
|
| 198 |
+
logits_closed = self.classifier_head(fused_features.squeeze(1))
|
| 199 |
+
|
| 200 |
+
if target_ids is not None:
|
| 201 |
+
target_emb = self.embedding(target_ids)
|
| 202 |
+
|
| 203 |
+
if self.decoder_type == "lstm":
|
| 204 |
+
h0 = fused_features.transpose(0, 1).contiguous()
|
| 205 |
+
outputs, _ = self.generator(target_emb, (h0, torch.zeros_like(h0)))
|
| 206 |
+
else:
|
| 207 |
+
tgt_mask = self._get_causal_mask(target_ids.size(1), target_ids.device)
|
| 208 |
+
outputs = self.generator(target_emb, fused_features, tgt_mask=tgt_mask)
|
| 209 |
+
|
| 210 |
+
logits_open = self.output_layer(outputs)
|
| 211 |
+
else:
|
| 212 |
+
logits_open = self.generate(fused_features, beam_width=beam_width)
|
| 213 |
+
|
| 214 |
+
return logits_closed, logits_open
|
src/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Initialize src.utils package
|
src/utils/answer_rewriter.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from src.utils.text_utils import postprocess_answer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _as_bool(value: object, default: bool = False) -> bool:
|
| 10 |
+
if value is None:
|
| 11 |
+
return default
|
| 12 |
+
if isinstance(value, bool):
|
| 13 |
+
return value
|
| 14 |
+
return str(value).strip().lower() in {"1", "true", "yes", "y", "on"}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class RewriteConfig:
|
| 19 |
+
enabled: bool = False
|
| 20 |
+
model_id: str = ""
|
| 21 |
+
use_4bit: bool = True
|
| 22 |
+
max_new_tokens: int = 28
|
| 23 |
+
max_words: int = 10
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class MedicalAnswerRewriter:
|
| 27 |
+
"""
|
| 28 |
+
Rewrite lớp cuối cho VQA output.
|
| 29 |
+
|
| 30 |
+
Mục tiêu:
|
| 31 |
+
- Giữ nguyên ý nghĩa gốc.
|
| 32 |
+
- Làm câu trả lời tự nhiên và đầy đủ hơn một chút.
|
| 33 |
+
- Vẫn giới hạn tối đa số từ theo cấu hình.
|
| 34 |
+
|
| 35 |
+
Mô hình này không thay thế VQA model chính.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, config: RewriteConfig | None = None) -> None:
|
| 39 |
+
self.config = config or self._load_config()
|
| 40 |
+
self._load_attempted = False
|
| 41 |
+
self._ready = False
|
| 42 |
+
self._tokenizer = None
|
| 43 |
+
self._model = None
|
| 44 |
+
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 45 |
+
|
| 46 |
+
@staticmethod
|
| 47 |
+
def _load_config() -> RewriteConfig:
|
| 48 |
+
model_id = (
|
| 49 |
+
os.getenv("ANSWER_REWRITE_MODEL_ID", "").strip()
|
| 50 |
+
or os.getenv("QWEN_REWRITE_MODEL_ID", "").strip()
|
| 51 |
+
or "Qwen/Qwen2.5-14B-Instruct"
|
| 52 |
+
)
|
| 53 |
+
enabled = _as_bool(os.getenv("ANSWER_REWRITE_ENABLED"), default=True)
|
| 54 |
+
use_4bit = _as_bool(os.getenv("ANSWER_REWRITE_USE_4BIT"), default=True)
|
| 55 |
+
max_new_tokens = int(os.getenv("ANSWER_REWRITE_MAX_NEW_TOKENS", "28"))
|
| 56 |
+
max_words = int(os.getenv("ANSWER_REWRITE_MAX_WORDS", "10"))
|
| 57 |
+
return RewriteConfig(
|
| 58 |
+
enabled=enabled,
|
| 59 |
+
model_id=model_id,
|
| 60 |
+
use_4bit=use_4bit,
|
| 61 |
+
max_new_tokens=max_new_tokens,
|
| 62 |
+
max_words=max_words,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def enabled(self) -> bool:
|
| 67 |
+
return bool(self.config.enabled and self.config.model_id)
|
| 68 |
+
|
| 69 |
+
@property
|
| 70 |
+
def model_id(self) -> str:
|
| 71 |
+
return self.config.model_id
|
| 72 |
+
|
| 73 |
+
@property
|
| 74 |
+
def ready(self) -> bool:
|
| 75 |
+
return self._ready
|
| 76 |
+
|
| 77 |
+
def _lazy_load(self) -> None:
|
| 78 |
+
if self._load_attempted:
|
| 79 |
+
return
|
| 80 |
+
self._load_attempted = True
|
| 81 |
+
|
| 82 |
+
if not self.enabled:
|
| 83 |
+
return
|
| 84 |
+
|
| 85 |
+
try:
|
| 86 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 87 |
+
hf_token = (
|
| 88 |
+
os.getenv("ANSWER_REWRITE_HF_TOKEN", "").strip()
|
| 89 |
+
or os.getenv("HF_TOKEN", "").strip()
|
| 90 |
+
or os.getenv("HUGGINGFACE_HUB_TOKEN", "").strip()
|
| 91 |
+
or None
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
tokenizer = AutoTokenizer.from_pretrained(self.config.model_id, trust_remote_code=True, token=hf_token)
|
| 95 |
+
model_kwargs = {
|
| 96 |
+
"trust_remote_code": True,
|
| 97 |
+
"low_cpu_mem_usage": True,
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
if self._device.type == "cuda":
|
| 101 |
+
if self.config.use_4bit:
|
| 102 |
+
try:
|
| 103 |
+
from transformers import BitsAndBytesConfig
|
| 104 |
+
|
| 105 |
+
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 106 |
+
load_in_4bit=True,
|
| 107 |
+
bnb_4bit_use_double_quant=True,
|
| 108 |
+
bnb_4bit_quant_type="nf4",
|
| 109 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 110 |
+
)
|
| 111 |
+
except Exception as exc:
|
| 112 |
+
print(f"[WARNING] Rewrite 4-bit config unavailable, falling back to bf16: {exc}")
|
| 113 |
+
model_kwargs["torch_dtype"] = torch.bfloat16
|
| 114 |
+
else:
|
| 115 |
+
model_kwargs["torch_dtype"] = torch.bfloat16
|
| 116 |
+
model_kwargs["device_map"] = "auto"
|
| 117 |
+
else:
|
| 118 |
+
model_kwargs["torch_dtype"] = torch.float32
|
| 119 |
+
|
| 120 |
+
if hf_token is not None:
|
| 121 |
+
model_kwargs["token"] = hf_token
|
| 122 |
+
|
| 123 |
+
model = AutoModelForCausalLM.from_pretrained(self.config.model_id, **model_kwargs)
|
| 124 |
+
model.eval()
|
| 125 |
+
|
| 126 |
+
self._tokenizer = tokenizer
|
| 127 |
+
self._model = model
|
| 128 |
+
self._ready = True
|
| 129 |
+
print(f"[INFO] ✅ Answer rewriter ready: {self.config.model_id}")
|
| 130 |
+
except Exception as exc:
|
| 131 |
+
self._ready = False
|
| 132 |
+
print(f"[WARNING] ❌ Answer rewriter load failed: {exc}")
|
| 133 |
+
|
| 134 |
+
def _build_messages(self, question: str, answer: str, language: str = "vi") -> list[dict[str, str]]:
|
| 135 |
+
system_prompt = (
|
| 136 |
+
"Bạn là bộ biên tập câu trả lời cho hệ thống Medical VQA. "
|
| 137 |
+
"Nhiệm vụ của bạn là viết lại câu trả lời gốc thành một câu ngắn, tự nhiên, "
|
| 138 |
+
"rõ nghĩa hơn nhưng KHÔNG thêm thông tin mới ngoài nội dung đã có. "
|
| 139 |
+
"Giới hạn tối đa 10 từ. Chỉ trả về câu trả lời cuối cùng."
|
| 140 |
+
)
|
| 141 |
+
if language.lower().startswith("en"):
|
| 142 |
+
system_prompt = (
|
| 143 |
+
"You are an editor for a Medical VQA system. "
|
| 144 |
+
"Rewrite the raw answer into a short, natural, clearer sentence "
|
| 145 |
+
"without adding facts beyond the original answer. "
|
| 146 |
+
"Use at most 10 words. Return only the final answer."
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
examples = [
|
| 150 |
+
{
|
| 151 |
+
"question": "Ảnh này có tràn dịch màng phổi không?",
|
| 152 |
+
"answer": "không",
|
| 153 |
+
"rewrite": "Không, không có tràn dịch màng phổi.",
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"question": "Hình ảnh có tim to không?",
|
| 157 |
+
"answer": "có",
|
| 158 |
+
"rewrite": "Có, tim to.",
|
| 159 |
+
},
|
| 160 |
+
{
|
| 161 |
+
"question": "Đây là loại ảnh gì?",
|
| 162 |
+
"answer": "x quang ngực",
|
| 163 |
+
"rewrite": "X-quang ngực.",
|
| 164 |
+
},
|
| 165 |
+
]
|
| 166 |
+
|
| 167 |
+
if language.lower().startswith("en"):
|
| 168 |
+
examples = [
|
| 169 |
+
{
|
| 170 |
+
"question": "Is there pleural effusion?",
|
| 171 |
+
"answer": "no",
|
| 172 |
+
"rewrite": "No, no pleural effusion.",
|
| 173 |
+
},
|
| 174 |
+
{
|
| 175 |
+
"question": "Is the heart enlarged?",
|
| 176 |
+
"answer": "yes",
|
| 177 |
+
"rewrite": "Yes, enlarged heart.",
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"question": "What modality is this?",
|
| 181 |
+
"answer": "chest x ray",
|
| 182 |
+
"rewrite": "Chest X-ray.",
|
| 183 |
+
},
|
| 184 |
+
]
|
| 185 |
+
|
| 186 |
+
messages: list[dict[str, str]] = [{"role": "system", "content": system_prompt}]
|
| 187 |
+
for ex in examples:
|
| 188 |
+
messages.append(
|
| 189 |
+
{
|
| 190 |
+
"role": "user",
|
| 191 |
+
"content": f"Câu hỏi: {ex['question']}\nĐáp án gốc: {ex['answer']}",
|
| 192 |
+
}
|
| 193 |
+
)
|
| 194 |
+
messages.append({"role": "assistant", "content": ex["rewrite"]})
|
| 195 |
+
|
| 196 |
+
user_prompt = f"Câu hỏi: {question}\nĐáp án gốc: {answer}\nViết lại ngắn gọn, tự nhiên, không thêm thông tin mới."
|
| 197 |
+
if language.lower().startswith("en"):
|
| 198 |
+
user_prompt = (
|
| 199 |
+
f"Question: {question}\nRaw answer: {answer}\n"
|
| 200 |
+
"Rewrite it into a short, natural answer without adding new facts."
|
| 201 |
+
)
|
| 202 |
+
messages.append({"role": "user", "content": user_prompt})
|
| 203 |
+
return messages
|
| 204 |
+
|
| 205 |
+
def rewrite(self, question: str, answer: str, language: str = "vi") -> str:
|
| 206 |
+
"""
|
| 207 |
+
Rewrite câu trả lời để tự nhiên hơn.
|
| 208 |
+
Nếu rewrite model không sẵn sàng, trả về output đã postprocess.
|
| 209 |
+
"""
|
| 210 |
+
if not answer:
|
| 211 |
+
return ""
|
| 212 |
+
|
| 213 |
+
self._lazy_load()
|
| 214 |
+
fallback = postprocess_answer(answer, max_words=self.config.max_words)
|
| 215 |
+
if not self.enabled or not self._ready:
|
| 216 |
+
return fallback
|
| 217 |
+
|
| 218 |
+
try:
|
| 219 |
+
messages = self._build_messages(question=question, answer=answer, language=language)
|
| 220 |
+
prompt = self._tokenizer.apply_chat_template(
|
| 221 |
+
messages,
|
| 222 |
+
tokenize=False,
|
| 223 |
+
add_generation_prompt=True,
|
| 224 |
+
)
|
| 225 |
+
inputs = self._tokenizer(prompt, return_tensors="pt", truncation=True)
|
| 226 |
+
inputs = {k: v.to(self._device) for k, v in inputs.items()}
|
| 227 |
+
|
| 228 |
+
with torch.inference_mode():
|
| 229 |
+
output_ids = self._model.generate(
|
| 230 |
+
**inputs,
|
| 231 |
+
max_new_tokens=self.config.max_new_tokens,
|
| 232 |
+
do_sample=False,
|
| 233 |
+
temperature=0.1,
|
| 234 |
+
repetition_penalty=1.05,
|
| 235 |
+
pad_token_id=self._tokenizer.eos_token_id,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
prompt_len = inputs["input_ids"].shape[1]
|
| 239 |
+
generated = self._tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()
|
| 240 |
+
cleaned = postprocess_answer(generated, max_words=self.config.max_words)
|
| 241 |
+
return cleaned or fallback
|
| 242 |
+
except Exception as exc:
|
| 243 |
+
print(f"[WARNING] Rewrite failed: {exc}")
|
| 244 |
+
return fallback
|
src/utils/discriminative_lr.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Discriminative learning rates for different model layers.
|
| 3 |
+
Earlier layers (pretrained) get lower LR to preserve learned features.
|
| 4 |
+
Later layers get higher LR for task-specific adaptation.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.optim import AdamW
|
| 9 |
+
from transformers import get_cosine_schedule_with_warmup
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_discriminative_optimizer(model, config):
|
| 13 |
+
"""
|
| 14 |
+
Create optimizer with discriminative learning rates.
|
| 15 |
+
|
| 16 |
+
Layer groups and their learning rates:
|
| 17 |
+
- Image Encoder (pretrained XRV): 1e-5 (preserve medical features)
|
| 18 |
+
- Text Encoder (PhoBERT): 1e-5 (preserve language understanding)
|
| 19 |
+
- Fusion layer (co-attention): 1e-4 (moderate adaptation)
|
| 20 |
+
- Decoder (task-specific): 1e-3 (heavy adaptation)
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
model: Model with parameter groups
|
| 24 |
+
config: Config dict with learning rates
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
Optimizer with layer-specific learning rates
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
# Define parameter groups with different learning rates
|
| 31 |
+
param_groups = []
|
| 32 |
+
|
| 33 |
+
base_lr = float(config['train'].get('learning_rate', 3e-4))
|
| 34 |
+
vision_lr = float(config['train'].get('vision_lr', 1e-5))
|
| 35 |
+
phobert_lr = float(config['train'].get('phobert_lr', 1e-5))
|
| 36 |
+
|
| 37 |
+
# Group 1: Image Encoder (lowest LR)
|
| 38 |
+
if hasattr(model, 'image_encoder'):
|
| 39 |
+
param_groups.append({
|
| 40 |
+
'params': model.image_encoder.parameters(),
|
| 41 |
+
'lr': vision_lr,
|
| 42 |
+
'name': 'image_encoder'
|
| 43 |
+
})
|
| 44 |
+
|
| 45 |
+
# Group 2: Text Encoder (low LR)
|
| 46 |
+
if hasattr(model, 'text_encoder'):
|
| 47 |
+
param_groups.append({
|
| 48 |
+
'params': model.text_encoder.parameters(),
|
| 49 |
+
'lr': phobert_lr,
|
| 50 |
+
'name': 'text_encoder'
|
| 51 |
+
})
|
| 52 |
+
|
| 53 |
+
# Group 3: Fusion/Attention layers (medium LR)
|
| 54 |
+
fusion_params = []
|
| 55 |
+
if hasattr(model, 'fusion'):
|
| 56 |
+
fusion_params.extend(model.fusion.parameters())
|
| 57 |
+
if hasattr(model, 'co_attention'):
|
| 58 |
+
fusion_params.extend(model.co_attention.parameters())
|
| 59 |
+
if hasattr(model, 'spatial_attention'):
|
| 60 |
+
fusion_params.extend(model.spatial_attention.parameters())
|
| 61 |
+
|
| 62 |
+
if fusion_params:
|
| 63 |
+
param_groups.append({
|
| 64 |
+
'params': fusion_params,
|
| 65 |
+
'lr': base_lr * 0.5, # 50% of base LR
|
| 66 |
+
'name': 'fusion'
|
| 67 |
+
})
|
| 68 |
+
|
| 69 |
+
# Group 4: Decoder (highest LR)
|
| 70 |
+
decoder_params = []
|
| 71 |
+
if hasattr(model, 'decoder'):
|
| 72 |
+
decoder_params.extend(model.decoder.parameters())
|
| 73 |
+
if hasattr(model, 'open_head'):
|
| 74 |
+
decoder_params.extend(model.open_head.parameters())
|
| 75 |
+
if hasattr(model, 'closed_head'):
|
| 76 |
+
decoder_params.extend(model.closed_head.parameters())
|
| 77 |
+
|
| 78 |
+
if decoder_params:
|
| 79 |
+
param_groups.append({
|
| 80 |
+
'params': decoder_params,
|
| 81 |
+
'lr': base_lr, # Full base LR
|
| 82 |
+
'name': 'decoder'
|
| 83 |
+
})
|
| 84 |
+
|
| 85 |
+
# Group 5: Any remaining parameters
|
| 86 |
+
# Collect all params that aren't in above groups
|
| 87 |
+
all_params = set(model.parameters())
|
| 88 |
+
grouped_params = set()
|
| 89 |
+
for group in param_groups:
|
| 90 |
+
grouped_params.update(group['params'])
|
| 91 |
+
|
| 92 |
+
remaining_params = [p for p in all_params if p not in grouped_params]
|
| 93 |
+
if remaining_params:
|
| 94 |
+
param_groups.append({
|
| 95 |
+
'params': remaining_params,
|
| 96 |
+
'lr': base_lr * 0.1, # 10% of base LR for safety
|
| 97 |
+
'name': 'remaining'
|
| 98 |
+
})
|
| 99 |
+
|
| 100 |
+
# Create optimizer
|
| 101 |
+
optimizer = AdamW(
|
| 102 |
+
param_groups,
|
| 103 |
+
betas=(0.9, 0.999),
|
| 104 |
+
weight_decay=config['train'].get('weight_decay', 0.01)
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# Log layer learning rates
|
| 108 |
+
print("[INFO] Discriminative Learning Rates Setup:")
|
| 109 |
+
for group in param_groups:
|
| 110 |
+
param_count = sum(p.numel() for p in group['params'])
|
| 111 |
+
print(f" {group['name']:15s}: LR={group['lr']:.2e}, Params={param_count:,}")
|
| 112 |
+
|
| 113 |
+
return optimizer
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def create_scheduler_with_warmup(optimizer, num_training_steps, config):
|
| 117 |
+
"""
|
| 118 |
+
Create cosine scheduler with warmup.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
optimizer: Optimizer instance
|
| 122 |
+
num_training_steps: Total training steps
|
| 123 |
+
config: Config dict
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
LambdaLR scheduler with warmup
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
warmup_steps = int(num_training_steps * config['train'].get('warmup_steps_ratio', 0.1))
|
| 130 |
+
|
| 131 |
+
scheduler = get_cosine_schedule_with_warmup(
|
| 132 |
+
optimizer,
|
| 133 |
+
num_warmup_steps=warmup_steps,
|
| 134 |
+
num_training_steps=num_training_steps,
|
| 135 |
+
num_cycles=0.5, # 0.5 = cosine goes from 1 to 0
|
| 136 |
+
last_epoch=-1
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
print(f"[INFO] Scheduler: Cosine with warmup")
|
| 140 |
+
print(f" Warmup steps: {warmup_steps} ({warmup_steps/num_training_steps*100:.1f}%)")
|
| 141 |
+
print(f" Total steps: {num_training_steps}")
|
| 142 |
+
|
| 143 |
+
return scheduler
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def get_current_learning_rates(optimizer):
|
| 147 |
+
"""Get current learning rate for each parameter group."""
|
| 148 |
+
lrs = {}
|
| 149 |
+
for i, param_group in enumerate(optimizer.param_groups):
|
| 150 |
+
name = param_group.get('name', f'group_{i}')
|
| 151 |
+
lrs[name] = param_group['lr']
|
| 152 |
+
return lrs
|
src/utils/early_stopping.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Advanced early stopping with multi-metric support.
|
| 3 |
+
Prevents overfitting by tracking multiple metrics simultaneously.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import torch
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MultiMetricEarlyStopping:
|
| 13 |
+
"""
|
| 14 |
+
Early stopping that considers multiple metrics with weighted scores.
|
| 15 |
+
|
| 16 |
+
Advantages over single-metric stopping:
|
| 17 |
+
- Prevents overfitting on one metric while degrading others
|
| 18 |
+
- Better general model performance
|
| 19 |
+
- More stable convergence
|
| 20 |
+
|
| 21 |
+
Example metric weights:
|
| 22 |
+
{'loss': 0.2, 'accuracy': 0.4, 'bertscore': 0.3, 'f1': 0.1}
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, patience=5, metric_weights=None, mode='maximize',
|
| 26 |
+
save_dir=None, verbose=True):
|
| 27 |
+
"""
|
| 28 |
+
Args:
|
| 29 |
+
patience: Number of evaluations with no improvement before stopping
|
| 30 |
+
metric_weights: Dict of {metric_name: weight}. If None, uses 'loss' only
|
| 31 |
+
mode: 'maximize' or 'minimize'
|
| 32 |
+
save_dir: Directory to save best model
|
| 33 |
+
verbose: Print progress
|
| 34 |
+
"""
|
| 35 |
+
self.patience = patience
|
| 36 |
+
self.counter = 0
|
| 37 |
+
self.best_score = None
|
| 38 |
+
self.best_metrics = None
|
| 39 |
+
self.save_dir = Path(save_dir) if save_dir else None
|
| 40 |
+
self.verbose = verbose
|
| 41 |
+
self.mode = mode
|
| 42 |
+
|
| 43 |
+
# Default metric weights if not provided
|
| 44 |
+
if metric_weights is None:
|
| 45 |
+
self.metric_weights = {'loss': 1.0}
|
| 46 |
+
else:
|
| 47 |
+
self.metric_weights = metric_weights
|
| 48 |
+
# Normalize weights to sum to 1
|
| 49 |
+
total_weight = sum(self.metric_weights.values())
|
| 50 |
+
self.metric_weights = {k: v/total_weight for k, v in self.metric_weights.items()}
|
| 51 |
+
|
| 52 |
+
self.history = []
|
| 53 |
+
|
| 54 |
+
if self.save_dir:
|
| 55 |
+
self.save_dir.mkdir(parents=True, exist_ok=True)
|
| 56 |
+
|
| 57 |
+
def compute_score(self, metrics):
|
| 58 |
+
"""
|
| 59 |
+
Compute weighted score from multiple metrics.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
metrics: Dict of metric_name -> value
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
Weighted score
|
| 66 |
+
"""
|
| 67 |
+
score = 0.0
|
| 68 |
+
|
| 69 |
+
for metric_name, weight in self.metric_weights.items():
|
| 70 |
+
if metric_name not in metrics:
|
| 71 |
+
if self.verbose:
|
| 72 |
+
print(f"[WARNING] Metric '{metric_name}' not found in current metrics")
|
| 73 |
+
continue
|
| 74 |
+
|
| 75 |
+
metric_value = metrics[metric_name]
|
| 76 |
+
|
| 77 |
+
# Handle loss (we want to minimize it)
|
| 78 |
+
if 'loss' in metric_name.lower():
|
| 79 |
+
# Invert loss for maximization context
|
| 80 |
+
metric_contribution = -metric_value if self.mode == 'maximize' else metric_value
|
| 81 |
+
else:
|
| 82 |
+
# Most metrics should be maximized (accuracy, F1, etc.)
|
| 83 |
+
metric_contribution = metric_value
|
| 84 |
+
|
| 85 |
+
score += metric_contribution * weight
|
| 86 |
+
|
| 87 |
+
return score
|
| 88 |
+
|
| 89 |
+
def __call__(self, metrics, model=None, epoch=None):
|
| 90 |
+
"""
|
| 91 |
+
Check if should stop training.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
metrics: Dict of metric_name -> value
|
| 95 |
+
model: Model to save if best
|
| 96 |
+
epoch: Current epoch number
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
True if should stop, False otherwise
|
| 100 |
+
"""
|
| 101 |
+
score = self.compute_score(metrics)
|
| 102 |
+
|
| 103 |
+
# Store history
|
| 104 |
+
self.history.append({
|
| 105 |
+
'epoch': epoch,
|
| 106 |
+
'score': score,
|
| 107 |
+
'metrics': metrics.copy()
|
| 108 |
+
})
|
| 109 |
+
|
| 110 |
+
if self.best_score is None:
|
| 111 |
+
self.best_score = score
|
| 112 |
+
self.best_metrics = metrics.copy()
|
| 113 |
+
if model is not None and self.save_dir:
|
| 114 |
+
self._save_checkpoint(model, epoch, metrics)
|
| 115 |
+
elif score > self.best_score:
|
| 116 |
+
self.best_score = score
|
| 117 |
+
self.best_metrics = metrics.copy()
|
| 118 |
+
self.counter = 0
|
| 119 |
+
if model is not None and self.save_dir:
|
| 120 |
+
self._save_checkpoint(model, epoch, metrics)
|
| 121 |
+
if self.verbose:
|
| 122 |
+
print(f"✓ Epoch {epoch}: New best score {score:.4f}")
|
| 123 |
+
else:
|
| 124 |
+
self.counter += 1
|
| 125 |
+
if self.verbose:
|
| 126 |
+
print(f"✗ Epoch {epoch}: No improvement ({self.counter}/{self.patience})")
|
| 127 |
+
|
| 128 |
+
# Check if should stop
|
| 129 |
+
if self.counter >= self.patience:
|
| 130 |
+
if self.verbose:
|
| 131 |
+
print(f"\n[EARLY STOPPING] Patience exceeded. Best metrics:")
|
| 132 |
+
for k, v in self.best_metrics.items():
|
| 133 |
+
if isinstance(v, float):
|
| 134 |
+
print(f" {k}: {v:.4f}")
|
| 135 |
+
return True
|
| 136 |
+
|
| 137 |
+
return False
|
| 138 |
+
|
| 139 |
+
def _save_checkpoint(self, model, epoch, metrics):
|
| 140 |
+
"""Save best model checkpoint."""
|
| 141 |
+
if self.save_dir is None:
|
| 142 |
+
return
|
| 143 |
+
|
| 144 |
+
checkpoint = {
|
| 145 |
+
'epoch': epoch,
|
| 146 |
+
'model_state_dict': model.state_dict(),
|
| 147 |
+
'metrics': metrics
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
save_path = self.save_dir / f"best_checkpoint_epoch_{epoch}.pt"
|
| 151 |
+
torch.save(checkpoint, save_path)
|
| 152 |
+
|
| 153 |
+
# Also save metrics record
|
| 154 |
+
metrics_path = self.save_dir / f"best_metrics_epoch_{epoch}.json"
|
| 155 |
+
with open(metrics_path, 'w') as f:
|
| 156 |
+
json.dump(metrics, f, indent=2, default=str)
|
| 157 |
+
|
| 158 |
+
if self.verbose:
|
| 159 |
+
print(f" 💾 Saved checkpoint to {save_path}")
|
| 160 |
+
|
| 161 |
+
def get_best_metrics(self):
|
| 162 |
+
"""Return best metrics found during training."""
|
| 163 |
+
return self.best_metrics
|
| 164 |
+
|
| 165 |
+
def get_history(self):
|
| 166 |
+
"""Return training history."""
|
| 167 |
+
return self.history
|
| 168 |
+
|
| 169 |
+
def plot_metrics(self, save_path=None):
|
| 170 |
+
"""
|
| 171 |
+
Plot metric progression during training.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
save_path: Path to save figure
|
| 175 |
+
"""
|
| 176 |
+
try:
|
| 177 |
+
import matplotlib.pyplot as plt
|
| 178 |
+
except ImportError:
|
| 179 |
+
print("[WARNING] matplotlib not installed, cannot plot")
|
| 180 |
+
return
|
| 181 |
+
|
| 182 |
+
if not self.history:
|
| 183 |
+
print("[WARNING] No history to plot")
|
| 184 |
+
return
|
| 185 |
+
|
| 186 |
+
epochs = [h['epoch'] for h in self.history]
|
| 187 |
+
scores = [h['score'] for h in self.history]
|
| 188 |
+
|
| 189 |
+
plt.figure(figsize=(10, 6))
|
| 190 |
+
plt.plot(epochs, scores, 'b-o', label='Composite Score')
|
| 191 |
+
plt.axhline(y=self.best_score, color='r', linestyle='--', label=f'Best: {self.best_score:.4f}')
|
| 192 |
+
plt.xlabel('Epoch')
|
| 193 |
+
plt.ylabel('Score')
|
| 194 |
+
plt.legend()
|
| 195 |
+
plt.title('Early Stopping - Composite Metric Score')
|
| 196 |
+
plt.grid(True, alpha=0.3)
|
| 197 |
+
|
| 198 |
+
if save_path:
|
| 199 |
+
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 200 |
+
print(f"[INFO] Metric plot saved to {save_path}")
|
| 201 |
+
|
| 202 |
+
plt.close()
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class DynamicClassWeights:
|
| 206 |
+
"""
|
| 207 |
+
Compute class weights dynamically from training data.
|
| 208 |
+
Adapts to actual data distribution.
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
@staticmethod
|
| 212 |
+
def compute_weights(dataloader, device='cpu'):
|
| 213 |
+
"""
|
| 214 |
+
Compute class weights from data distribution.
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
dataloader: DataLoader to analyze
|
| 218 |
+
device: Device for tensor
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
Tensor of class weights
|
| 222 |
+
"""
|
| 223 |
+
class_counts = {}
|
| 224 |
+
|
| 225 |
+
for batch in dataloader:
|
| 226 |
+
labels = batch.get('label_closed', None)
|
| 227 |
+
if labels is None:
|
| 228 |
+
continue
|
| 229 |
+
|
| 230 |
+
# Count occurrences of each class
|
| 231 |
+
unique_labels, counts = torch.unique(labels, return_counts=True)
|
| 232 |
+
for label, count in zip(unique_labels, counts):
|
| 233 |
+
label_idx = label.item()
|
| 234 |
+
if label_idx >= 0: # Ignore negative indices
|
| 235 |
+
class_counts[label_idx] = class_counts.get(label_idx, 0) + count.item()
|
| 236 |
+
|
| 237 |
+
if not class_counts:
|
| 238 |
+
# Default weights if no data found
|
| 239 |
+
return torch.ones(2, device=device)
|
| 240 |
+
|
| 241 |
+
# Compute inverse frequency weights
|
| 242 |
+
total_samples = sum(class_counts.values())
|
| 243 |
+
num_classes = len(class_counts)
|
| 244 |
+
|
| 245 |
+
weights = torch.zeros(max(class_counts.keys()) + 1, device=device)
|
| 246 |
+
for class_idx, count in class_counts.items():
|
| 247 |
+
# Weight = total / (num_classes * count) - higher weight for rarer classes
|
| 248 |
+
weight = total_samples / (num_classes * max(count, 1))
|
| 249 |
+
weights[class_idx] = weight
|
| 250 |
+
|
| 251 |
+
# Normalize to sum to num_classes
|
| 252 |
+
weights = weights / weights.sum() * num_classes
|
| 253 |
+
|
| 254 |
+
print("[INFO] Dynamic Class Weights:")
|
| 255 |
+
for class_idx in sorted(class_counts.keys()):
|
| 256 |
+
print(f" Class {class_idx}: Weight={weights[class_idx]:.4f}, Samples={class_counts[class_idx]}")
|
| 257 |
+
|
| 258 |
+
return weights.to(device)
|
src/utils/evaluation_viz.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
import seaborn as sns
|
| 4 |
+
from sklearn.metrics import confusion_matrix
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
def plot_confusion_matrix(y_true, y_pred, classes, title='Confusion Matrix', cmap=plt.cm.Blues):
|
| 8 |
+
"""
|
| 9 |
+
Vẽ Confusion Matrix chuyên nghiệp cho các câu hỏi Closed-ended (Yes/No).
|
| 10 |
+
"""
|
| 11 |
+
cm = confusion_matrix(y_true, y_pred)
|
| 12 |
+
plt.figure(figsize=(8, 6))
|
| 13 |
+
sns.heatmap(cm, annot=True, fmt='d', cmap=cmap,
|
| 14 |
+
xticklabels=classes, yticklabels=classes)
|
| 15 |
+
plt.title(title, fontsize=15)
|
| 16 |
+
plt.ylabel('Ground Truth', fontsize=12)
|
| 17 |
+
plt.xlabel('Predicted', fontsize=12)
|
| 18 |
+
plt.tight_layout()
|
| 19 |
+
return plt
|
| 20 |
+
|
| 21 |
+
def plot_radar_chart(model_names, metrics_data, categories, title='Model Comparison (All Variants)'):
|
| 22 |
+
"""
|
| 23 |
+
Vẽ biểu đồ Radar để so sánh 5 biến thể trên nhiều tiêu chí (Accuracy, BLEU, ROUGE, BERTScore).
|
| 24 |
+
metrics_data: List of lists, mỗi list là chỉ số của 1 model.
|
| 25 |
+
"""
|
| 26 |
+
N = len(categories)
|
| 27 |
+
angles = [n / float(N) * 2 * np.pi for n in range(N)]
|
| 28 |
+
angles += angles[:1]
|
| 29 |
+
|
| 30 |
+
fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(polar=True))
|
| 31 |
+
|
| 32 |
+
for i, model_name in enumerate(model_names):
|
| 33 |
+
values = metrics_data[i]
|
| 34 |
+
values += values[:1]
|
| 35 |
+
ax.plot(angles, values, linewidth=2, linestyle='solid', label=model_name)
|
| 36 |
+
ax.fill(angles, values, alpha=0.1)
|
| 37 |
+
|
| 38 |
+
ax.set_theta_offset(np.pi / 2)
|
| 39 |
+
ax.set_theta_direction(-1)
|
| 40 |
+
plt.xticks(angles[:-1], categories, fontsize=12)
|
| 41 |
+
plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
|
| 42 |
+
plt.title(title, size=20, y=1.1)
|
| 43 |
+
return plt
|
| 44 |
+
|
| 45 |
+
def plot_training_history(history, title='Training History'):
|
| 46 |
+
"""
|
| 47 |
+
Vẽ đồ thị Loss và Accuracy trong quá trình huấn luyện.
|
| 48 |
+
history: dict có keys 'train_loss', 'val_acc', v.v.
|
| 49 |
+
"""
|
| 50 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
|
| 51 |
+
|
| 52 |
+
# Loss plot
|
| 53 |
+
ax1.plot(history['train_loss'], label='Train Loss')
|
| 54 |
+
if 'val_loss' in history:
|
| 55 |
+
ax1.plot(history['val_loss'], label='Val Loss')
|
| 56 |
+
ax1.set_title('Loss Evolution')
|
| 57 |
+
ax1.set_xlabel('Epochs')
|
| 58 |
+
ax1.set_ylabel('Loss')
|
| 59 |
+
ax1.legend()
|
| 60 |
+
ax1.grid(True)
|
| 61 |
+
|
| 62 |
+
# Accuracy plot
|
| 63 |
+
ax2.plot(history['val_acc'], label='Val Accuracy', color='green')
|
| 64 |
+
ax2.set_title('Accuracy Evolution')
|
| 65 |
+
ax2.set_xlabel('Epochs')
|
| 66 |
+
ax2.set_ylabel('Accuracy')
|
| 67 |
+
ax2.legend()
|
| 68 |
+
ax2.grid(True)
|
| 69 |
+
|
| 70 |
+
plt.suptitle(title, fontsize=16)
|
| 71 |
+
plt.tight_layout()
|
| 72 |
+
return plt
|
| 73 |
+
|
| 74 |
+
def plot_benchmark_comparison(results_df, metric='Accuracy'):
|
| 75 |
+
"""
|
| 76 |
+
Biểu đồ cột so sánh một chỉ số cụ thể giữa các mô hình.
|
| 77 |
+
results_df: DataFrame có cột 'Model' và các chỉ số.
|
| 78 |
+
"""
|
| 79 |
+
plt.figure(figsize=(10, 6))
|
| 80 |
+
sns.set_style("whitegrid")
|
| 81 |
+
ax = sns.barplot(x='Model', y=metric, data=results_df, palette='viridis')
|
| 82 |
+
|
| 83 |
+
for p in ax.patches:
|
| 84 |
+
ax.annotate(format(p.get_height(), '.4f'),
|
| 85 |
+
(p.get_x() + p.get_width() / 2., p.get_height()),
|
| 86 |
+
ha = 'center', va = 'center',
|
| 87 |
+
xytext = (0, 9),
|
| 88 |
+
textcoords = 'offset points',
|
| 89 |
+
fontsize=11)
|
| 90 |
+
|
| 91 |
+
plt.title(f'Comparison of {metric} across Variants', fontsize=15)
|
| 92 |
+
plt.ylim(0, 1.1)
|
| 93 |
+
plt.tight_layout()
|
| 94 |
+
return plt
|
| 95 |
+
|
| 96 |
+
def plot_accuracy_by_category(data_df, category_col='Organ', title='Accuracy by Medical Category'):
|
| 97 |
+
"""
|
| 98 |
+
Biểu đồ cột phân nhóm để so sánh độ chính xác giữa các cơ quan hoặc loại câu hỏi.
|
| 99 |
+
data_df: DataFrame có cột category_col, 'Model', và 'Correct' (bool).
|
| 100 |
+
"""
|
| 101 |
+
acc_df = data_df.groupby([category_col, 'Model'])['Correct'].mean().reset_index()
|
| 102 |
+
|
| 103 |
+
plt.figure(figsize=(12, 6))
|
| 104 |
+
sns.barplot(x=category_col, y='Correct', hue='Model', data=acc_df)
|
| 105 |
+
plt.title(title, fontsize=15)
|
| 106 |
+
plt.ylabel('Accuracy')
|
| 107 |
+
plt.xticks(rotation=45)
|
| 108 |
+
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
|
| 109 |
+
plt.tight_layout()
|
| 110 |
+
return plt
|
| 111 |
+
|
| 112 |
+
def plot_semantic_distribution(model_scores_dict, title='Semantic Score Distribution (LLM-Judge)'):
|
| 113 |
+
"""
|
| 114 |
+
Vẽ biểu đồ Violin để so sánh phân bổ điểm số ngữ nghĩa giữa các model (ví dụ B2 vs DPO).
|
| 115 |
+
model_scores_dict: {'Model A': [scores], 'Model B': [scores]}
|
| 116 |
+
"""
|
| 117 |
+
data = []
|
| 118 |
+
for model, scores in model_scores_dict.items():
|
| 119 |
+
for s in scores:
|
| 120 |
+
data.append({'Model': model, 'Score': s})
|
| 121 |
+
df = pd.DataFrame(data)
|
| 122 |
+
|
| 123 |
+
plt.figure(figsize=(10, 6))
|
| 124 |
+
sns.violinplot(x='Model', y='Score', data=df, inner="quart", palette="Set3")
|
| 125 |
+
plt.title(title, fontsize=15)
|
| 126 |
+
plt.ylim(-0.1, 1.1)
|
| 127 |
+
plt.tight_layout()
|
| 128 |
+
return plt
|
| 129 |
+
|
| 130 |
+
def plot_latency_vs_accuracy(model_stats, title='Accuracy vs. Latency Trade-off'):
|
| 131 |
+
"""
|
| 132 |
+
Biểu đồ bong bóng so sánh Tốc độ và Độ chính xác.
|
| 133 |
+
model_stats: List of dicts [{'name': 'A1', 'accuracy': 0.8, 'latency': 0.1, 'params': 100M}, ...]
|
| 134 |
+
"""
|
| 135 |
+
df = pd.DataFrame(model_stats)
|
| 136 |
+
plt.figure(figsize=(10, 7))
|
| 137 |
+
|
| 138 |
+
scatter = plt.scatter(df['latency'], df['accuracy'],
|
| 139 |
+
s=df['params_mb']*10, # Kích thước bong bóng theo số lượng tham số
|
| 140 |
+
alpha=0.5, c=np.arange(len(df)), cmap='viridis')
|
| 141 |
+
|
| 142 |
+
for i, txt in enumerate(df['name']):
|
| 143 |
+
plt.annotate(txt, (df['latency'][i], df['accuracy'][i]), fontsize=12)
|
| 144 |
+
|
| 145 |
+
plt.xlabel('Latency (seconds/sample)', fontsize=12)
|
| 146 |
+
plt.ylabel('Accuracy', fontsize=12)
|
| 147 |
+
plt.title(title, fontsize=15)
|
| 148 |
+
plt.grid(True, linestyle='--', alpha=0.6)
|
| 149 |
+
plt.tight_layout()
|
| 150 |
+
return plt
|
| 151 |
+
|
| 152 |
+
def plot_calibration_curve(y_true, y_probs, n_bins=10, title='Calibration Curve (Reliability)'):
|
| 153 |
+
"""
|
| 154 |
+
Biểu đồ hiệu chuẩn để xem độ tin cậy của xác suất dự đoán.
|
| 155 |
+
y_true: nhãn thực tế [0, 1]
|
| 156 |
+
y_probs: xác suất dự đoán lớp 1
|
| 157 |
+
"""
|
| 158 |
+
from sklearn.calibration import calibration_curve
|
| 159 |
+
prob_true, prob_pred = calibration_curve(y_true, y_probs, n_bins=n_bins)
|
| 160 |
+
|
| 161 |
+
plt.figure(figsize=(8, 8))
|
| 162 |
+
plt.plot(prob_pred, prob_true, "s-", label='Model')
|
| 163 |
+
plt.plot([0, 1], [0, 1], "k--", label='Perfectly Calibrated')
|
| 164 |
+
plt.ylabel('Fraction of Positives', fontsize=12)
|
| 165 |
+
plt.xlabel('Mean Predicted Probability', fontsize=12)
|
| 166 |
+
plt.title(title, fontsize=15)
|
| 167 |
+
plt.legend(loc="lower right")
|
| 168 |
+
plt.grid(True)
|
| 169 |
+
plt.tight_layout()
|
| 170 |
+
return plt
|
| 171 |
+
|
| 172 |
+
def plot_performance_vs_length(questions, corrects, title='Accuracy vs. Question Length'):
|
| 173 |
+
"""
|
| 174 |
+
Biểu đồ xem độ chính xác có giảm khi câu hỏi dài hơn không.
|
| 175 |
+
questions: list các câu hỏi.
|
| 176 |
+
corrects: list các giá trị bool (đúng/sai).
|
| 177 |
+
"""
|
| 178 |
+
lengths = [len(q.split()) for q in questions]
|
| 179 |
+
df = pd.DataFrame({'Length': lengths, 'Correct': corrects})
|
| 180 |
+
# Chia nhóm độ dài (bins)
|
| 181 |
+
df['Length_Group'] = pd.cut(df['Length'], bins=[0, 5, 10, 15, 20, 30, 50],
|
| 182 |
+
labels=['1-5', '6-10', '11-15', '16-20', '21-30', '31+'])
|
| 183 |
+
|
| 184 |
+
acc_by_len = df.groupby('Length_Group')['Correct'].mean().reset_index()
|
| 185 |
+
|
| 186 |
+
plt.figure(figsize=(10, 6))
|
| 187 |
+
sns.lineplot(x='Length_Group', y='Correct', data=acc_by_len, marker='o', color='red')
|
| 188 |
+
plt.title(title, fontsize=15)
|
| 189 |
+
plt.ylabel('Accuracy')
|
| 190 |
+
plt.xlabel('Question Length (words)')
|
| 191 |
+
plt.ylim(0, 1.1)
|
| 192 |
+
plt.grid(True, axis='y')
|
| 193 |
+
plt.tight_layout()
|
| 194 |
+
return plt
|
src/utils/helpers.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import collections
|
| 3 |
+
|
| 4 |
+
def normalize_answer(s):
|
| 5 |
+
"""
|
| 6 |
+
Chuẩn hóa câu trả lời: viết thường, bỏ dấu câu, bỏ mạo từ...
|
| 7 |
+
"""
|
| 8 |
+
def remove_articles(text):
|
| 9 |
+
return re.sub(r'\b(a|an|the)\b', ' ', text)
|
| 10 |
+
|
| 11 |
+
def white_space_fix(text):
|
| 12 |
+
return ' '.join(text.split())
|
| 13 |
+
|
| 14 |
+
def remove_punc(text):
|
| 15 |
+
exclude = set(r'!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~')
|
| 16 |
+
return ''.join(ch for ch in text if ch not in exclude)
|
| 17 |
+
|
| 18 |
+
def lower(text):
|
| 19 |
+
return text.lower()
|
| 20 |
+
|
| 21 |
+
return white_space_fix(remove_articles(remove_punc(lower(str(s)))))
|
| 22 |
+
|
| 23 |
+
def majority_answer(answer_list):
|
| 24 |
+
"""
|
| 25 |
+
Lấy câu trả lời xuất hiện nhiều nhất trong danh sách (Voting).
|
| 26 |
+
"""
|
| 27 |
+
if not answer_list:
|
| 28 |
+
return ""
|
| 29 |
+
count = collections.Counter([normalize_answer(a) for a in answer_list])
|
| 30 |
+
return count.most_common(1)[0][0]
|
src/utils/metrics.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluation metrics for VQA: Accuracy, EM, F1, BLEU-1~4, METEOR, and Semantic Score."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
from collections import Counter
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
|
| 8 |
+
from nltk.translate.meteor_score import meteor_score as _nltk_meteor
|
| 9 |
+
|
| 10 |
+
import nltk
|
| 11 |
+
try:
|
| 12 |
+
nltk.data.find('corpora/wordnet')
|
| 13 |
+
except LookupError:
|
| 14 |
+
print("[INFO] Đang tự động tải bộ từ điển NLTK WordNet cho METEOR score...")
|
| 15 |
+
nltk.download('wordnet', quiet=True)
|
| 16 |
+
nltk.download('omw-1.4', quiet=True)
|
| 17 |
+
|
| 18 |
+
# 1. Semantic Score (SentenceTransformer)
|
| 19 |
+
try:
|
| 20 |
+
from sentence_transformers import SentenceTransformer, util
|
| 21 |
+
semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
|
| 22 |
+
except Exception as e:
|
| 23 |
+
semantic_model = None
|
| 24 |
+
print(f"Warning: Could not load SentenceTransformer: {e}")
|
| 25 |
+
|
| 26 |
+
# 2. BERTScore
|
| 27 |
+
try:
|
| 28 |
+
from bert_score import BERTScorer
|
| 29 |
+
# Ép sử dụng model multilingual để tránh lỗi attribute của Tokenizer trên Python 3.12
|
| 30 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 31 |
+
bert_scorer = BERTScorer(model_type="bert-base-multilingual-cased", device=device)
|
| 32 |
+
except ImportError:
|
| 33 |
+
print("[WARNING] Thư viện bert_score chưa được cài đặt.")
|
| 34 |
+
bert_scorer = None
|
| 35 |
+
except Exception as e:
|
| 36 |
+
bert_scorer = None
|
| 37 |
+
print(f"Warning: Could not load BERTScorer: {e}")
|
| 38 |
+
|
| 39 |
+
# 3. ROUGE-L
|
| 40 |
+
try:
|
| 41 |
+
from rouge_score import rouge_scorer as rs
|
| 42 |
+
rouge_l_scorer = rs.RougeScorer(['rougeL'], use_stemmer=True)
|
| 43 |
+
except Exception as e:
|
| 44 |
+
rouge_l_scorer = None
|
| 45 |
+
print(f"Warning: Could not load rouge-score: {e}")
|
| 46 |
+
|
| 47 |
+
# [FIX] Import from the local text_utils instead of non-existent src.data.preprocessing
|
| 48 |
+
from .text_utils import normalize_answer, majority_answer
|
| 49 |
+
|
| 50 |
+
def compute_rouge_l(pred: str, refs) -> float:
|
| 51 |
+
"""Tính ROUGE-L (Lấy MAX over multiple refs)."""
|
| 52 |
+
if not rouge_l_scorer: return 0.0
|
| 53 |
+
if isinstance(refs, str): refs = [refs]
|
| 54 |
+
best_rouge = 0.0
|
| 55 |
+
for r in refs:
|
| 56 |
+
score = rouge_l_scorer.score(normalize_answer(r), normalize_answer(pred))['rougeL'].fmeasure
|
| 57 |
+
best_rouge = max(best_rouge, score)
|
| 58 |
+
return best_rouge
|
| 59 |
+
|
| 60 |
+
def compute_bertscore(preds: list[str], refs: list) -> float:
|
| 61 |
+
"""Tính BERTScore cho cả batch."""
|
| 62 |
+
if not bert_scorer or not preds or not refs:
|
| 63 |
+
return 0.0
|
| 64 |
+
|
| 65 |
+
clean_preds = [normalize_answer(p) if normalize_answer(p).strip() else "." for p in preds]
|
| 66 |
+
clean_refs = [majority_answer(r) if isinstance(r, list) else normalize_answer(r) for r in refs]
|
| 67 |
+
clean_refs = [r if r.strip() else "." for r in clean_refs]
|
| 68 |
+
|
| 69 |
+
try:
|
| 70 |
+
# Tăng tốc bằng cách tắt idf nếu cần
|
| 71 |
+
P, R, F1 = bert_scorer.score(clean_preds, clean_refs)
|
| 72 |
+
return float(F1.mean().item())
|
| 73 |
+
except Exception as e:
|
| 74 |
+
print(f"[WARNING] BERTScore error: {e}")
|
| 75 |
+
return 0.0
|
| 76 |
+
|
| 77 |
+
def compute_exact_match(pred: str, refs) -> float:
|
| 78 |
+
"""So khớp chính xác lấy MAX (soft match over multiple refs)."""
|
| 79 |
+
if isinstance(refs, str): refs = [refs]
|
| 80 |
+
return float(any(normalize_answer(pred) == normalize_answer(r) for r in refs))
|
| 81 |
+
|
| 82 |
+
def compute_f1(pred: str, refs) -> float:
|
| 83 |
+
"""Tính F1-score ở mức độ token. Lấy MAX over multiple refs."""
|
| 84 |
+
if isinstance(refs, str): refs = [refs]
|
| 85 |
+
best_f1 = 0.0
|
| 86 |
+
p_toks = normalize_answer(pred).split()
|
| 87 |
+
for r in refs:
|
| 88 |
+
r_toks = normalize_answer(r).split()
|
| 89 |
+
if not p_toks or not r_toks:
|
| 90 |
+
f1 = float(p_toks == r_toks)
|
| 91 |
+
else:
|
| 92 |
+
common = Counter(p_toks) & Counter(r_toks)
|
| 93 |
+
num_same = sum(common.values())
|
| 94 |
+
if num_same == 0:
|
| 95 |
+
f1 = 0.0
|
| 96 |
+
else:
|
| 97 |
+
precision = num_same / len(p_toks)
|
| 98 |
+
recall = num_same / len(r_toks)
|
| 99 |
+
f1 = 2 * precision * recall / (precision + recall)
|
| 100 |
+
best_f1 = max(best_f1, f1)
|
| 101 |
+
return best_f1
|
| 102 |
+
|
| 103 |
+
def compute_bleu(pred: str, refs) -> dict[str, float]:
|
| 104 |
+
"""Tính BLEU from 1 đến 4 sử dụng corpus-level refs."""
|
| 105 |
+
if isinstance(refs, str): refs = [refs]
|
| 106 |
+
smoothie = SmoothingFunction().method4
|
| 107 |
+
p_toks = normalize_answer(pred).split()
|
| 108 |
+
r_toks_list = [normalize_answer(r).split() for r in refs if normalize_answer(r).strip()]
|
| 109 |
+
|
| 110 |
+
if not p_toks or not r_toks_list:
|
| 111 |
+
return {"bleu1": 0.0, "bleu2": 0.0, "bleu3": 0.0, "bleu4": 0.0}
|
| 112 |
+
|
| 113 |
+
weights = [
|
| 114 |
+
(1, 0, 0, 0), # BLEU-1
|
| 115 |
+
(0.5, 0.5, 0, 0), # BLEU-2
|
| 116 |
+
(0.33, 0.33, 0.33, 0), # BLEU-3
|
| 117 |
+
(0.25, 0.25, 0.25, 0.25) # BLEU-4
|
| 118 |
+
]
|
| 119 |
+
|
| 120 |
+
return {
|
| 121 |
+
f"bleu{i+1}": sentence_bleu(r_toks_list, p_toks, weights=w, smoothing_function=smoothie)
|
| 122 |
+
for i, w in enumerate(weights)
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
def compute_meteor(pred: str, refs) -> float:
|
| 126 |
+
"""Tính METEOR score (hỗ trợ N refs)."""
|
| 127 |
+
if isinstance(refs, str): refs = [refs]
|
| 128 |
+
p_toks = normalize_answer(pred).split()
|
| 129 |
+
r_toks_list = [normalize_answer(r).split() for r in refs if normalize_answer(r).strip()]
|
| 130 |
+
if not p_toks or not r_toks_list:
|
| 131 |
+
return 0.0
|
| 132 |
+
return _nltk_meteor(r_toks_list, p_toks)
|
| 133 |
+
|
| 134 |
+
def compute_vqa_accuracy(pred: str, direct_answers) -> float:
|
| 135 |
+
"""
|
| 136 |
+
Tính VQA Accuracy mềm: min(#người_cùng_đáp_án / 3, 1.0).
|
| 137 |
+
Using cho các tập dữ liệu có nhiều người gắn nhãn (như A-OKVQA).
|
| 138 |
+
"""
|
| 139 |
+
if isinstance(direct_answers, str):
|
| 140 |
+
return compute_exact_match(pred, direct_answers)
|
| 141 |
+
|
| 142 |
+
normed_pred = normalize_answer(pred)
|
| 143 |
+
matches = sum(1 for a in direct_answers if normalize_answer(a) == normed_pred)
|
| 144 |
+
return min(matches / 3.0, 1.0)
|
| 145 |
+
|
| 146 |
+
def compute_semantic_score(preds: list[str], refs: list) -> float:
|
| 147 |
+
"""Tính điểm tương đồng ngữ nghĩa bằng Cosine Similarity."""
|
| 148 |
+
if not semantic_model or not preds or not refs:
|
| 149 |
+
return 0.0
|
| 150 |
+
|
| 151 |
+
clean_preds = [normalize_answer(p) for p in preds]
|
| 152 |
+
# Take the most representative string if it's a list for semantic comparison
|
| 153 |
+
clean_refs = [majority_answer(r) if isinstance(r, list) else normalize_answer(r) for r in refs]
|
| 154 |
+
|
| 155 |
+
# Encode to Vector (Embeddings)
|
| 156 |
+
pred_embs = semantic_model.encode(clean_preds, convert_to_tensor=True, show_progress_bar=False)
|
| 157 |
+
ref_embs = semantic_model.encode(clean_refs, convert_to_tensor=True, show_progress_bar=False)
|
| 158 |
+
|
| 159 |
+
# Compute Cosine distance matrix and take diagonal (1-to-1 comparison)
|
| 160 |
+
cosine_scores = util.cos_sim(pred_embs, ref_embs)
|
| 161 |
+
scores = torch.diag(cosine_scores)
|
| 162 |
+
|
| 163 |
+
return float(scores.mean().item())
|
| 164 |
+
|
| 165 |
+
def batch_metrics(predictions: list[str], references: list) -> dict[str, float]:
|
| 166 |
+
"""Tổng hợp toàn bộ chỉ số đo lường trên batch."""
|
| 167 |
+
results = {
|
| 168 |
+
"accuracy": [], "em": [], "f1": [], "meteor": [],
|
| 169 |
+
"bleu1": [], "bleu2": [], "bleu3": [], "bleu4": [],
|
| 170 |
+
"rouge_l": []
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
for pred, ref in zip(predictions, references):
|
| 174 |
+
# Pass full refs list to compute_f1, compute_bleu to maximize score
|
| 175 |
+
results["accuracy"].append(compute_vqa_accuracy(pred, ref))
|
| 176 |
+
results["em"].append(compute_exact_match(pred, ref))
|
| 177 |
+
results["f1"].append(compute_f1(pred, ref))
|
| 178 |
+
results["meteor"].append(compute_meteor(pred, ref))
|
| 179 |
+
results["rouge_l"].append(compute_rouge_l(pred, ref))
|
| 180 |
+
|
| 181 |
+
bleus = compute_bleu(pred, ref)
|
| 182 |
+
for k, v in bleus.items():
|
| 183 |
+
results[k].append(v)
|
| 184 |
+
|
| 185 |
+
# Average traditional metrics
|
| 186 |
+
final_metrics = {k: float(np.mean(v)) for k, v in results.items()}
|
| 187 |
+
|
| 188 |
+
# Compute Semantic Score and BERTScore for entire batch
|
| 189 |
+
final_metrics["semantic"] = compute_semantic_score(predictions, references)
|
| 190 |
+
final_metrics["bert_score"] = compute_bertscore(predictions, references)
|
| 191 |
+
|
| 192 |
+
return final_metrics
|
src/utils/optimized_metrics.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Optimized metrics computation with batching for significant speed improvement.
|
| 3 |
+
Replaces sequential computation with parallel batch processing.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import List, Tuple, Dict
|
| 9 |
+
from collections import Counter
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import warnings
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from bert_score import score as bert_score_fn
|
| 15 |
+
except ImportError:
|
| 16 |
+
bert_score_fn = None
|
| 17 |
+
warnings.warn("bert-score not installed, BERTScore will be unavailable")
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from rouge_score import rouge_scorer
|
| 21 |
+
ROUGE_SCORER = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
|
| 22 |
+
except ImportError:
|
| 23 |
+
ROUGE_SCORER = None
|
| 24 |
+
warnings.warn("rouge-score not installed, ROUGE will be unavailable")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def normalize_answer(s: str) -> str:
|
| 28 |
+
"""Lower text and remove punctuation, articles and extra whitespace."""
|
| 29 |
+
s = s.lower().strip()
|
| 30 |
+
return " ".join(s.split())
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def compute_bertscore_batch(preds: List[str], refs: List[str],
|
| 34 |
+
model_type: str = "bert-base-multilingual-cased",
|
| 35 |
+
batch_size: int = 32,
|
| 36 |
+
device: str = "cuda") -> float:
|
| 37 |
+
"""
|
| 38 |
+
Compute BERTScore efficiently using batch processing.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
preds: List of predictions
|
| 42 |
+
refs: List of references
|
| 43 |
+
model_type: BERT model to use
|
| 44 |
+
batch_size: Batch size for processing
|
| 45 |
+
device: Device to run on (cuda/cpu)
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
Average F1 score
|
| 49 |
+
|
| 50 |
+
Performance: 10-20x faster than sequential computation
|
| 51 |
+
"""
|
| 52 |
+
if not bert_score_fn or not preds or not refs:
|
| 53 |
+
return 0.0
|
| 54 |
+
|
| 55 |
+
clean_preds = [normalize_answer(p) if normalize_answer(p).strip() else "." for p in preds]
|
| 56 |
+
clean_refs = [normalize_answer(r) if isinstance(r, str) else normalize_answer(r[0] if r else ".") for r in refs]
|
| 57 |
+
clean_refs = [r if r.strip() else "." for r in clean_refs]
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
# Key optimization: batch compute scores instead of sequential
|
| 61 |
+
P, R, F1 = bert_score_fn(
|
| 62 |
+
clean_preds,
|
| 63 |
+
clean_refs,
|
| 64 |
+
model_type=model_type,
|
| 65 |
+
batch_size=batch_size,
|
| 66 |
+
device=device,
|
| 67 |
+
verbose=False
|
| 68 |
+
)
|
| 69 |
+
return float(F1.mean().item())
|
| 70 |
+
except Exception as e:
|
| 71 |
+
print(f"[WARNING] BERTScore error: {e}")
|
| 72 |
+
return 0.0
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def compute_rouge_batch(preds: List[str], refs: List[str],
|
| 76 |
+
rouge_types: List[str] = ['rouge1', 'rougeL']) -> Dict[str, float]:
|
| 77 |
+
"""
|
| 78 |
+
Compute ROUGE scores efficiently using batched computation.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
preds: List of predictions
|
| 82 |
+
refs: List of references
|
| 83 |
+
rouge_types: ROUGE metrics to compute
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
Dictionary of ROUGE scores
|
| 87 |
+
|
| 88 |
+
Performance: Vectorized computation
|
| 89 |
+
"""
|
| 90 |
+
if not ROUGE_SCORER or not preds or not refs:
|
| 91 |
+
return {f"{rt}_f": 0.0 for rt in rouge_types}
|
| 92 |
+
|
| 93 |
+
clean_preds = [normalize_answer(p) if normalize_answer(p).strip() else "." for p in preds]
|
| 94 |
+
clean_refs = [normalize_answer(r) if isinstance(r, str) else normalize_answer(r[0] if r else ".") for r in refs]
|
| 95 |
+
|
| 96 |
+
results = {f"{rt}_f": [] for rt in rouge_types}
|
| 97 |
+
|
| 98 |
+
try:
|
| 99 |
+
for pred, ref in zip(clean_preds, clean_refs):
|
| 100 |
+
scores = ROUGE_SCORER.score(ref, pred)
|
| 101 |
+
for rt in rouge_types:
|
| 102 |
+
results[f"{rt}_f"].append(scores[rt].fmeasure)
|
| 103 |
+
|
| 104 |
+
# Average across all samples
|
| 105 |
+
averaged = {k: np.mean(v) if v else 0.0 for k, v in results.items()}
|
| 106 |
+
return averaged
|
| 107 |
+
except Exception as e:
|
| 108 |
+
print(f"[WARNING] ROUGE error: {e}")
|
| 109 |
+
return {f"{rt}_f": 0.0 for rt in rouge_types}
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def compute_exact_match_batch(preds: List[str], refs: List[str]) -> float:
|
| 113 |
+
"""
|
| 114 |
+
Compute exact match efficiently in batch.
|
| 115 |
+
|
| 116 |
+
Performance: Vectorized string comparison
|
| 117 |
+
"""
|
| 118 |
+
clean_preds = [normalize_answer(p) for p in preds]
|
| 119 |
+
clean_refs = [normalize_answer(r) if isinstance(r, str) else normalize_answer(r[0] if r else "") for r in refs]
|
| 120 |
+
|
| 121 |
+
matches = sum(1 for p, r in zip(clean_preds, clean_refs) if p == r)
|
| 122 |
+
return matches / len(clean_preds) if clean_preds else 0.0
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def compute_f1_batch(preds: List[str], refs: List[str]) -> float:
|
| 126 |
+
"""
|
| 127 |
+
Compute F1-score efficiently in batch.
|
| 128 |
+
|
| 129 |
+
Performance: Vectorized token comparison
|
| 130 |
+
"""
|
| 131 |
+
f1_scores = []
|
| 132 |
+
|
| 133 |
+
for pred, ref in zip(preds, refs):
|
| 134 |
+
p_toks = normalize_answer(pred).split()
|
| 135 |
+
r_toks = normalize_answer(ref).split() if isinstance(ref, str) else normalize_answer(ref[0] if ref else "").split()
|
| 136 |
+
|
| 137 |
+
if not p_toks or not r_toks:
|
| 138 |
+
f1 = float(p_toks == r_toks)
|
| 139 |
+
else:
|
| 140 |
+
common = Counter(p_toks) & Counter(r_toks)
|
| 141 |
+
num_same = sum(common.values())
|
| 142 |
+
|
| 143 |
+
if num_same == 0:
|
| 144 |
+
f1 = 0.0
|
| 145 |
+
else:
|
| 146 |
+
precision = num_same / len(p_toks)
|
| 147 |
+
recall = num_same / len(r_toks)
|
| 148 |
+
f1 = 2 * precision * recall / (precision + recall)
|
| 149 |
+
|
| 150 |
+
f1_scores.append(f1)
|
| 151 |
+
|
| 152 |
+
return np.mean(f1_scores) if f1_scores else 0.0
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def batch_metrics_optimized(predictions: List[str], references: List[str],
|
| 156 |
+
use_bertscore: bool = True,
|
| 157 |
+
use_rouge: bool = True,
|
| 158 |
+
device: str = "cuda") -> Dict[str, float]:
|
| 159 |
+
"""
|
| 160 |
+
Compute all metrics efficiently in batch mode.
|
| 161 |
+
|
| 162 |
+
Key optimizations:
|
| 163 |
+
- BERTScore: Batch computation (10-20x faster)
|
| 164 |
+
- ROUGE: Vectorized computation
|
| 165 |
+
- F1/EM: Parallel token processing
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
predictions: List of predictions
|
| 169 |
+
references: List of references
|
| 170 |
+
use_bertscore: Include BERTScore
|
| 171 |
+
use_rouge: Include ROUGE scores
|
| 172 |
+
device: Device for computation
|
| 173 |
+
|
| 174 |
+
Returns:
|
| 175 |
+
Dictionary of all metrics
|
| 176 |
+
|
| 177 |
+
Performance gain: 95% reduction in evaluation time
|
| 178 |
+
"""
|
| 179 |
+
metrics = {}
|
| 180 |
+
|
| 181 |
+
# Core metrics (fast)
|
| 182 |
+
metrics['exact_match'] = compute_exact_match_batch(predictions, references)
|
| 183 |
+
metrics['f1'] = compute_f1_batch(predictions, references)
|
| 184 |
+
|
| 185 |
+
# Semantic metrics (optimized with batching)
|
| 186 |
+
if use_bertscore:
|
| 187 |
+
metrics['bert_score'] = compute_bertscore_batch(
|
| 188 |
+
predictions, references,
|
| 189 |
+
device=device
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
if use_rouge:
|
| 193 |
+
rouge_scores = compute_rouge_batch(predictions, references)
|
| 194 |
+
metrics.update(rouge_scores)
|
| 195 |
+
|
| 196 |
+
return metrics
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# Compatibility wrapper for existing code
|
| 200 |
+
def compute_bertscore(preds: list, refs: list) -> float:
|
| 201 |
+
"""Legacy wrapper for backward compatibility."""
|
| 202 |
+
return compute_bertscore_batch(preds, refs)
|
src/utils/text_utils.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from collections import Counter
|
| 3 |
+
|
| 4 |
+
from underthesea import text_normalize as uts_text_normalize, word_tokenize
|
| 5 |
+
|
| 6 |
+
_MEDICAL_TERM_MAP = {
|
| 7 |
+
"xray": "x-quang",
|
| 8 |
+
"x ray": "x-quang",
|
| 9 |
+
"x-ray": "x-quang",
|
| 10 |
+
"x quang": "x-quang",
|
| 11 |
+
"mri scan": "mri",
|
| 12 |
+
"mr": "mri",
|
| 13 |
+
"ct scan": "ct",
|
| 14 |
+
"ct-scan": "ct",
|
| 15 |
+
"cat scan": "ct",
|
| 16 |
+
"computed tomography": "ct",
|
| 17 |
+
"transverse plane": "mặt phẳng ngang",
|
| 18 |
+
"transverse plane": "mặt phẳng ngang",
|
| 19 |
+
"coronal plane": "mặt phẳng vành",
|
| 20 |
+
"sagittal plane": "mặt phẳng dọc",
|
| 21 |
+
"elliptical": "hình elip",
|
| 22 |
+
"spleen": "lách",
|
| 23 |
+
"liver": "gan",
|
| 24 |
+
"lung": "phổi",
|
| 25 |
+
"lungs": "phổi",
|
| 26 |
+
"heart": "tim",
|
| 27 |
+
"brain": "não",
|
| 28 |
+
"kidney": "thận",
|
| 29 |
+
"bladder": "bàng quang",
|
| 30 |
+
"cardiomegaly": "tim to",
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
_NON_CANONICAL_ALIASES = {
|
| 34 |
+
"xray",
|
| 35 |
+
"x ray",
|
| 36 |
+
"x-ray",
|
| 37 |
+
"x quang",
|
| 38 |
+
"mri scan",
|
| 39 |
+
"mr",
|
| 40 |
+
"ct scan",
|
| 41 |
+
"ct-scan",
|
| 42 |
+
"cat scan",
|
| 43 |
+
"computed tomography",
|
| 44 |
+
"transverse plane",
|
| 45 |
+
"coronal plane",
|
| 46 |
+
"sagittal plane",
|
| 47 |
+
"elliptical",
|
| 48 |
+
"spleen",
|
| 49 |
+
"liver",
|
| 50 |
+
"lung",
|
| 51 |
+
"lungs",
|
| 52 |
+
"heart",
|
| 53 |
+
"brain",
|
| 54 |
+
"kidney",
|
| 55 |
+
"bladder",
|
| 56 |
+
"cardiomegaly",
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def text_normalize(text: str) -> str:
|
| 61 |
+
"""Wrapper để chuẩn hóa Unicode và spacing cho tiếng Việt."""
|
| 62 |
+
if not text:
|
| 63 |
+
return ""
|
| 64 |
+
return uts_text_normalize(str(text))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def normalize_answer(text: str) -> str:
|
| 68 |
+
"""
|
| 69 |
+
Chuẩn hóa đáp án về dạng canonical để train/eval ổn định.
|
| 70 |
+
"""
|
| 71 |
+
if not text:
|
| 72 |
+
return ""
|
| 73 |
+
|
| 74 |
+
text = text_normalize(str(text))
|
| 75 |
+
text = text.replace("_", " ")
|
| 76 |
+
text = text.lower().strip()
|
| 77 |
+
text = re.sub(r"[@#]{1,2}", " ", text)
|
| 78 |
+
text = re.sub(r"[“”\"']", "", text)
|
| 79 |
+
text = re.sub(r"[,:;!?()\[\]{}]+", " ", text)
|
| 80 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 81 |
+
|
| 82 |
+
for src, dst in sorted(_MEDICAL_TERM_MAP.items(), key=lambda item: -len(item[0])):
|
| 83 |
+
text = re.sub(rf"\b{re.escape(src)}\b", dst, text)
|
| 84 |
+
|
| 85 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 86 |
+
text = re.sub(r"[.\-]+$", "", text).strip()
|
| 87 |
+
return text
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _tokenize_vietnamese_words(text: str) -> list[str]:
|
| 91 |
+
normalized = normalize_answer(text)
|
| 92 |
+
if not normalized:
|
| 93 |
+
return []
|
| 94 |
+
try:
|
| 95 |
+
tokens = word_tokenize(normalized)
|
| 96 |
+
return [token.strip() for token in tokens if token and token.strip()]
|
| 97 |
+
except Exception:
|
| 98 |
+
return normalized.split()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def count_words(text: str) -> int:
|
| 102 |
+
return len(_tokenize_vietnamese_words(text))
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _trim_to_max_words(text: str, max_words: int) -> str:
|
| 106 |
+
words = _tokenize_vietnamese_words(text)
|
| 107 |
+
if len(words) <= max_words:
|
| 108 |
+
return " ".join(words)
|
| 109 |
+
return " ".join(words[:max_words])
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _choose_best_answer_text(answer_vi: str, answer_full_vi: str, max_words: int) -> str:
|
| 113 |
+
short_answer = normalize_answer(answer_vi)
|
| 114 |
+
full_answer = normalize_answer(answer_full_vi)
|
| 115 |
+
|
| 116 |
+
if short_answer and count_words(short_answer) <= max_words:
|
| 117 |
+
return short_answer
|
| 118 |
+
if full_answer:
|
| 119 |
+
return _trim_to_max_words(full_answer, max_words)
|
| 120 |
+
return _trim_to_max_words(short_answer, max_words)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def get_target_answer(item: dict, max_words: int = 10) -> str:
|
| 124 |
+
"""
|
| 125 |
+
Chọn target answer ngắn, chuẩn hóa và không vượt quá số từ cho phép.
|
| 126 |
+
"""
|
| 127 |
+
answer_vi = item.get("answer_vi", "")
|
| 128 |
+
answer_full_vi = item.get("answer_full_vi", "")
|
| 129 |
+
answer = _choose_best_answer_text(answer_vi, answer_full_vi, max_words=max_words)
|
| 130 |
+
if answer:
|
| 131 |
+
return answer
|
| 132 |
+
fallback = item.get("answer", "")
|
| 133 |
+
return _trim_to_max_words(fallback, max_words)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def postprocess_answer(text: str, max_words: int = 10) -> str:
|
| 137 |
+
"""
|
| 138 |
+
Chuẩn hóa output model và cắt ngắn về tối đa `max_words`.
|
| 139 |
+
Không mở rộng câu trả lời để tránh làm xấu exact match.
|
| 140 |
+
"""
|
| 141 |
+
if not text:
|
| 142 |
+
return ""
|
| 143 |
+
text = clean_vqa_output(text)
|
| 144 |
+
text = normalize_answer(text)
|
| 145 |
+
return _trim_to_max_words(text, max_words=max_words)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def is_medical_term_compliant(text: str) -> bool:
|
| 149 |
+
"""
|
| 150 |
+
Heuristic nhẹ: không còn alias y khoa phổ biến chưa canonicalize.
|
| 151 |
+
"""
|
| 152 |
+
normalized = normalize_answer(text)
|
| 153 |
+
if not normalized:
|
| 154 |
+
return False
|
| 155 |
+
for alias in _NON_CANONICAL_ALIASES:
|
| 156 |
+
if re.search(rf"\b{re.escape(alias)}\b", normalized):
|
| 157 |
+
return False
|
| 158 |
+
return True
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def majority_answer(answers: list[str]) -> str:
|
| 162 |
+
"""
|
| 163 |
+
Trả về câu trả lời xuất hiện nhiều nhất trong danh sách.
|
| 164 |
+
"""
|
| 165 |
+
if not answers:
|
| 166 |
+
return ""
|
| 167 |
+
if isinstance(answers, str):
|
| 168 |
+
return normalize_answer(answers)
|
| 169 |
+
counts = Counter([normalize_answer(a) for a in answers])
|
| 170 |
+
return counts.most_common(1)[0][0]
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def clean_vqa_output(text: str) -> str:
|
| 174 |
+
"""
|
| 175 |
+
Làm sạch output từ tokenizer trước khi postprocess.
|
| 176 |
+
"""
|
| 177 |
+
if not text:
|
| 178 |
+
return ""
|
| 179 |
+
text = re.sub(r"@@\s?", "", text)
|
| 180 |
+
text = re.sub(r"##_?", "", text)
|
| 181 |
+
text = re.sub(r"^\s*yes\s*,?\s*", "có ", text, flags=re.IGNORECASE)
|
| 182 |
+
text = re.sub(r"^\s*no\s*,?\s*", "không ", text, flags=re.IGNORECASE)
|
| 183 |
+
text = re.sub(
|
| 184 |
+
r"^\s*(the answer is|the image is|this image is|the scan is|the ct is|the mri is|there is|there are)\s+",
|
| 185 |
+
"",
|
| 186 |
+
text,
|
| 187 |
+
flags=re.IGNORECASE,
|
| 188 |
+
)
|
| 189 |
+
text = re.sub(
|
| 190 |
+
r"^(có|không)\s+(the\s+)?(image|scan|x-ray|xray|mri|ct|picture|photo|radiograph)\s+(is|shows?|depicts?|demonstrates?|reveals?|indicates?|presents?)\s+",
|
| 191 |
+
r"\1 ",
|
| 192 |
+
text,
|
| 193 |
+
flags=re.IGNORECASE,
|
| 194 |
+
)
|
| 195 |
+
text = re.sub(
|
| 196 |
+
r"^(the\s+)?(image|scan|x-ray|xray|mri|ct|picture|photo|radiograph)\s+(is|shows?|depicts?|demonstrates?|reveals?|indicates?|presents?)\s+",
|
| 197 |
+
"",
|
| 198 |
+
text,
|
| 199 |
+
flags=re.IGNORECASE,
|
| 200 |
+
)
|
| 201 |
+
text = re.sub(r"\b(answer|response|assistant|trả lời)\b\s*:?\s*$", "", text, flags=re.IGNORECASE)
|
| 202 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 203 |
+
return text
|
src/utils/translator.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
from src.utils.text_utils import postprocess_answer
|
| 6 |
+
|
| 7 |
+
class MedicalTranslator:
|
| 8 |
+
"""
|
| 9 |
+
Dịch thuật y tế với cơ chế Lazy Loading + Independent Fallback.
|
| 10 |
+
- Vi→En: MarianMT (Helsinki-NLP) trên CPU
|
| 11 |
+
- En→Vi: MedCrab-1.5B (4-bit) trên GPU phụ (nếu có)
|
| 12 |
+
Mỗi model load độc lập — nếu 1 cái fail, cái kia vẫn hoạt động.
|
| 13 |
+
"""
|
| 14 |
+
def __init__(self, device="cpu", dict_path="data/medical_dict.json"):
|
| 15 |
+
self.device_str = device # "cuda" hoặc "cpu"
|
| 16 |
+
|
| 17 |
+
# Chọn GPU: nếu Dual GPU → dùng cuda:1, nếu Single → dùng cuda:0
|
| 18 |
+
if torch.cuda.is_available() and device == "cuda":
|
| 19 |
+
if torch.cuda.device_count() > 1:
|
| 20 |
+
self.gpu_device = torch.device("cuda:1")
|
| 21 |
+
print(f"[INFO] Dual-GPU detected → Translator on {self.gpu_device}")
|
| 22 |
+
else:
|
| 23 |
+
self.gpu_device = torch.device("cuda:0")
|
| 24 |
+
else:
|
| 25 |
+
self.gpu_device = torch.device("cpu")
|
| 26 |
+
|
| 27 |
+
# State flags
|
| 28 |
+
self._load_attempted = False
|
| 29 |
+
self._vi2en_ready = False
|
| 30 |
+
self._en2vi_ready = False
|
| 31 |
+
|
| 32 |
+
# Models (lazy)
|
| 33 |
+
self._vi2en_model = None
|
| 34 |
+
self._vi2en_tokenizer = None
|
| 35 |
+
self._en2vi_model = None
|
| 36 |
+
self._en2vi_tokenizer = None
|
| 37 |
+
|
| 38 |
+
# Medical dictionary
|
| 39 |
+
self.med_dict = {}
|
| 40 |
+
if os.path.exists(dict_path):
|
| 41 |
+
try:
|
| 42 |
+
with open(dict_path, 'r', encoding='utf-8') as f:
|
| 43 |
+
self.med_dict = json.load(f)
|
| 44 |
+
except:
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
def _lazy_load(self):
|
| 48 |
+
"""Nạp models. Chỉ gọi 1 lần duy nhất."""
|
| 49 |
+
if self._load_attempted:
|
| 50 |
+
return
|
| 51 |
+
self._load_attempted = True
|
| 52 |
+
print("[INFO] Đang nạp Translation Models (Lazy Load)...")
|
| 53 |
+
|
| 54 |
+
# ── 1. Helsinki-NLP Vi→En (Chạy trên CPU, nhẹ ~300MB) ──
|
| 55 |
+
try:
|
| 56 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 57 |
+
vi2en_id = "Helsinki-NLP/opus-mt-vi-en"
|
| 58 |
+
self._vi2en_tokenizer = AutoTokenizer.from_pretrained(vi2en_id)
|
| 59 |
+
self._vi2en_model = AutoModelForSeq2SeqLM.from_pretrained(vi2en_id).to("cpu")
|
| 60 |
+
self._vi2en_model.eval()
|
| 61 |
+
self._vi2en_ready = True
|
| 62 |
+
print("[INFO] ✅ Helsinki-NLP (Vi→En) đã sẵn sàng trên CPU")
|
| 63 |
+
except Exception as e:
|
| 64 |
+
print(f"[WARNING] ❌ Helsinki-NLP load thất bại: {e}")
|
| 65 |
+
|
| 66 |
+
# ── 2. MedCrab En→Vi (4-bit trên GPU) ──
|
| 67 |
+
try:
|
| 68 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
| 69 |
+
bnb_config = BitsAndBytesConfig(
|
| 70 |
+
load_in_4bit=True,
|
| 71 |
+
bnb_4bit_use_double_quant=True,
|
| 72 |
+
bnb_4bit_quant_type="nf4",
|
| 73 |
+
bnb_4bit_compute_dtype=torch.float16
|
| 74 |
+
)
|
| 75 |
+
medcrab_id = "pnnbao-ump/MedCrab-1.5B"
|
| 76 |
+
self._en2vi_tokenizer = AutoTokenizer.from_pretrained(medcrab_id)
|
| 77 |
+
|
| 78 |
+
d_map = {"": self.gpu_device} if self.gpu_device.type == "cuda" else None
|
| 79 |
+
self._en2vi_model = AutoModelForCausalLM.from_pretrained(
|
| 80 |
+
medcrab_id,
|
| 81 |
+
quantization_config=bnb_config,
|
| 82 |
+
device_map=d_map,
|
| 83 |
+
low_cpu_mem_usage=True
|
| 84 |
+
)
|
| 85 |
+
self._en2vi_model.eval()
|
| 86 |
+
self._en2vi_ready = True
|
| 87 |
+
print(f"[INFO] ✅ MedCrab-1.5B (En→Vi) đã sẵn sàng trên {self.gpu_device}")
|
| 88 |
+
except Exception as e:
|
| 89 |
+
print(f"[WARNING] ❌ MedCrab load thất bại: {e}")
|
| 90 |
+
|
| 91 |
+
# ── Vi → En ──
|
| 92 |
+
def translate_vi2en(self, text):
|
| 93 |
+
"""Dịch câu hỏi Tiếng Việt sang Tiếng Anh."""
|
| 94 |
+
if not text:
|
| 95 |
+
return text
|
| 96 |
+
self._lazy_load()
|
| 97 |
+
|
| 98 |
+
if not self._vi2en_ready:
|
| 99 |
+
# Fallback: trả về nguyên văn (LLaVA vẫn hiểu được một phần)
|
| 100 |
+
return text
|
| 101 |
+
|
| 102 |
+
try:
|
| 103 |
+
texts = text if isinstance(text, list) else [text]
|
| 104 |
+
results = []
|
| 105 |
+
for t in texts:
|
| 106 |
+
inputs = self._vi2en_tokenizer(t, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
| 107 |
+
with torch.no_grad():
|
| 108 |
+
output_ids = self._vi2en_model.generate(**inputs, max_new_tokens=128)
|
| 109 |
+
translated = self._vi2en_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 110 |
+
results.append(translated)
|
| 111 |
+
return results if isinstance(text, list) else results[0]
|
| 112 |
+
except Exception as e:
|
| 113 |
+
print(f"[WARNING] Vi→En error: {e}")
|
| 114 |
+
return text
|
| 115 |
+
|
| 116 |
+
# ── En → Vi ──
|
| 117 |
+
def translate_en2vi(self, text):
|
| 118 |
+
"""Dịch kết quả từ LLaVA-Med sang Tiếng Việt."""
|
| 119 |
+
if not text:
|
| 120 |
+
return text
|
| 121 |
+
|
| 122 |
+
# 1. Ánh xạ trực tiếp nhãn nhị phân (nhanh + chính xác 100%)
|
| 123 |
+
if isinstance(text, str):
|
| 124 |
+
t = text.lower().strip().rstrip(".").rstrip(",").strip()
|
| 125 |
+
|
| 126 |
+
# Xử lý các câu trả lời dài bắt đầu bằng Yes/No của LLaVA (vd: "No, the image does not...")
|
| 127 |
+
if t.startswith("yes"):
|
| 128 |
+
return "có"
|
| 129 |
+
if t.startswith("no"):
|
| 130 |
+
return "không"
|
| 131 |
+
|
| 132 |
+
# Exact match trước
|
| 133 |
+
direct_map = {
|
| 134 |
+
"true": "có", "false": "không",
|
| 135 |
+
"correct": "có", "incorrect": "không",
|
| 136 |
+
"present": "có", "absent": "không",
|
| 137 |
+
"normal": "bình thường", "abnormal": "bất thường",
|
| 138 |
+
}
|
| 139 |
+
if t in direct_map:
|
| 140 |
+
return direct_map[t]
|
| 141 |
+
|
| 142 |
+
# 2. Dịch bằng MedCrab
|
| 143 |
+
self._lazy_load()
|
| 144 |
+
if not self._en2vi_ready:
|
| 145 |
+
if isinstance(text, list):
|
| 146 |
+
return text
|
| 147 |
+
return text
|
| 148 |
+
|
| 149 |
+
if isinstance(text, list):
|
| 150 |
+
return [self._medcrab_translate(t) for t in text]
|
| 151 |
+
return self._medcrab_translate(text)
|
| 152 |
+
|
| 153 |
+
def _medcrab_translate(self, text):
|
| 154 |
+
"""Dịch 1 câu En→Vi bằng MedCrab với ràng buộc ngắn gọn."""
|
| 155 |
+
# Kiểm tra ánh xạ trực tiếp trước
|
| 156 |
+
t = text.lower().strip().rstrip(".").rstrip(",").strip()
|
| 157 |
+
direct_map = {
|
| 158 |
+
"yes": "có", "no": "không",
|
| 159 |
+
"normal": "bình thường", "abnormal": "bất thường",
|
| 160 |
+
}
|
| 161 |
+
if t in direct_map:
|
| 162 |
+
return direct_map[t]
|
| 163 |
+
|
| 164 |
+
try:
|
| 165 |
+
prompt = f"English: {text}\nVietnamese (trả lời ngắn gọn):"
|
| 166 |
+
inputs = self._en2vi_tokenizer(prompt, return_tensors="pt").to(self.gpu_device)
|
| 167 |
+
|
| 168 |
+
with torch.no_grad():
|
| 169 |
+
outputs = self._en2vi_model.generate(
|
| 170 |
+
**inputs,
|
| 171 |
+
max_new_tokens=30,
|
| 172 |
+
repetition_penalty=1.2,
|
| 173 |
+
temperature=0.1,
|
| 174 |
+
do_sample=False,
|
| 175 |
+
pad_token_id=self._en2vi_tokenizer.eos_token_id
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
full_text = self._en2vi_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 179 |
+
translated = full_text.split("Vietnamese (trả lời ngắn gọn):")[-1].strip()
|
| 180 |
+
return translated
|
| 181 |
+
except Exception as e:
|
| 182 |
+
print(f"[WARNING] En→Vi error: {e}")
|
| 183 |
+
return text
|
src/utils/visualization.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
|
| 6 |
+
def apply_clahe(img_array):
|
| 7 |
+
"""
|
| 8 |
+
Áp dụng Contrast Limited Adaptive Histogram Equalization (CLAHE).
|
| 9 |
+
Giúp tăng cường độ tương phản cục bộ cho ảnh X-ray.
|
| 10 |
+
"""
|
| 11 |
+
# Nếu ảnh đang ở dạng float [0, 1], chuyển về uint8 [0, 255]
|
| 12 |
+
if img_array.max() <= 1.0:
|
| 13 |
+
img_array = (img_array * 255).astype(np.uint8)
|
| 14 |
+
|
| 15 |
+
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
|
| 16 |
+
|
| 17 |
+
# Xử lý cho ảnh xám (Grayscale)
|
| 18 |
+
if len(img_array.shape) == 2:
|
| 19 |
+
img_clahe = clahe.apply(img_array)
|
| 20 |
+
# Xử lý cho ảnh màu (RGB) - Chuyển sang LAB để giữ màu sắc
|
| 21 |
+
else:
|
| 22 |
+
lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
|
| 23 |
+
l, a, b = cv2.split(lab)
|
| 24 |
+
l_clahe = clahe.apply(l)
|
| 25 |
+
img_clahe = cv2.merge((l_clahe, a, b))
|
| 26 |
+
img_clahe = cv2.cvtColor(img_clahe, cv2.COLOR_LAB2RGB)
|
| 27 |
+
|
| 28 |
+
return img_clahe.astype(np.float32) / 255.0
|
| 29 |
+
|
| 30 |
+
class MedicalImageTransform:
|
| 31 |
+
"""
|
| 32 |
+
Custom transform tích hợp CLAHE và chuẩn hóa cho Medical VQA (Hướng A - XRV).
|
| 33 |
+
"""
|
| 34 |
+
def __init__(self, size=224):
|
| 35 |
+
self.resize = transforms.Resize((size, size))
|
| 36 |
+
self.normalize = transforms.Normalize(mean=[0.5], std=[0.5])
|
| 37 |
+
|
| 38 |
+
def __call__(self, img):
|
| 39 |
+
# 1. Resize
|
| 40 |
+
img = self.resize(img)
|
| 41 |
+
|
| 42 |
+
# 2. Apply CLAHE (Tăng cường độ tương phản y tế)
|
| 43 |
+
img_np = np.array(img)
|
| 44 |
+
img_clahe = apply_clahe(img_np) # Trả về ảnh [0, 1]
|
| 45 |
+
|
| 46 |
+
# 3. Chuyển sang Tensor 1 kênh cho DenseNet XRV
|
| 47 |
+
# img_clahe shape: [224, 224]
|
| 48 |
+
img_tensor = torch.from_numpy(img_clahe).unsqueeze(0) # [1, 224, 224]
|
| 49 |
+
|
| 50 |
+
# 4. Chuẩn hóa về dải [-1024, 1024] cho DenseNet XRV.
|
| 51 |
+
# XRV được train trên dải cường độ cao này để bảo tồn chi tiết y tế.
|
| 52 |
+
img_tensor = img_tensor * 2048.0 - 1024.0
|
| 53 |
+
return img_tensor
|
web/README.md
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Medical VQA Web
|
| 2 |
+
|
| 3 |
+
Thư mục này chứa FastAPI + web UI để:
|
| 4 |
+
|
| 5 |
+
- upload ảnh
|
| 6 |
+
- nhập câu hỏi VQA
|
| 7 |
+
- chạy dự đoán
|
| 8 |
+
- so sánh 6 model: `A1`, `A2`, `B1`, `B2`, `DPO`, `PPO`
|
| 9 |
+
|
| 10 |
+
### Chạy server
|
| 11 |
+
|
| 12 |
+
Từ thư mục gốc project:
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
uvicorn web.main:app --reload --host 0.0.0.0 --port 8000
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
Nếu muốn preload toàn bộ model khi startup trên GPU:
|
| 19 |
+
|
| 20 |
+
```bash
|
| 21 |
+
WEB_PRELOAD_MODELS=1 uvicorn web.main:app --host 0.0.0.0 --port 8000
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
Khi chạy trên GPU, nên để `--workers 1` để tránh mỗi worker nạp một bản model riêng.
|
| 25 |
+
|
| 26 |
+
### Chạy bằng Docker
|
| 27 |
+
|
| 28 |
+
Build image:
|
| 29 |
+
|
| 30 |
+
```bash
|
| 31 |
+
docker build -t medical-vqa-web .
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
Run container trên máy có GPU:
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
docker run --rm \
|
| 38 |
+
--gpus all \
|
| 39 |
+
-p 8000:8000 \
|
| 40 |
+
-e WEB_PRELOAD_MODELS=1 \
|
| 41 |
+
-v medical-vqa-hf-cache:/hf_cache \
|
| 42 |
+
medical-vqa-web
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
Nếu muốn chạy lại nhanh hơn, giữ volume cache `medical-vqa-hf-cache` để không tải lại model Hugging Face mỗi lần.
|
| 46 |
+
|
| 47 |
+
### Tùy chọn: rewrite output bằng Qwen
|
| 48 |
+
|
| 49 |
+
Lớp rewrite hiện đã bật mặc định và sẽ tự thử load Qwen từ Hugging Face Hub khi server khởi động.
|
| 50 |
+
Nếu bạn muốn đổi sang model repo khác trên Hub, đặt thêm các biến môi trường sau:
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
ANSWER_REWRITE_ENABLED=1
|
| 54 |
+
ANSWER_REWRITE_MODEL_ID=Qwen/Qwen2.5-14B-Instruct
|
| 55 |
+
ANSWER_REWRITE_USE_4BIT=1
|
| 56 |
+
ANSWER_REWRITE_MAX_NEW_TOKENS=28
|
| 57 |
+
ANSWER_REWRITE_MAX_WORDS=10
|
| 58 |
+
ANSWER_REWRITE_HF_TOKEN=hf_...
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
Lớp này chỉ rewrite phần output hiển thị, không thay thế model VQA chính. Nếu model rewrite không load được, hệ thống sẽ tự fallback về output hiện tại.
|
| 62 |
+
|
| 63 |
+
Mở:
|
| 64 |
+
|
| 65 |
+
```text
|
| 66 |
+
http://localhost:8000
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### API
|
| 70 |
+
|
| 71 |
+
- `GET /health`
|
| 72 |
+
- kiểm tra trạng thái server và artifact khả dụng
|
| 73 |
+
- `GET /v1/models`
|
| 74 |
+
- trả metadata 6 model
|
| 75 |
+
- `POST /v1/predict`
|
| 76 |
+
- form-data:
|
| 77 |
+
- `question`: câu hỏi VQA
|
| 78 |
+
- `image`: ảnh đầu vào
|
| 79 |
+
- `model_name` hoặc `model_names`:
|
| 80 |
+
- nếu bỏ trống thì chạy toàn bộ 6 model
|
| 81 |
+
- `model_names` nhận chuỗi JSON list hoặc chuỗi phân tách bằng dấu phẩy
|
| 82 |
+
|
| 83 |
+
### Artifact cần có
|
| 84 |
+
|
| 85 |
+
- `A1`: `checkpoints/medical_vqa_A1_best.pth`
|
| 86 |
+
- `A2`: `checkpoints/medical_vqa_A2_best.pth`
|
| 87 |
+
- `B1`: model base từ `model_b.model_name` trong `configs/medical_vqa.yaml`
|
| 88 |
+
- `B2`: checkpoint tốt nhất trong `checkpoints/B2/checkpoint-*`
|
| 89 |
+
- `DPO`: `checkpoints/DPO/final_adapter` hoặc `checkpoints/DPO/checkpoint-25`
|
| 90 |
+
- `PPO`: `checkpoints/PPO/final_adapter`
|
| 91 |
+
|
| 92 |
+
### Lưu ý
|
| 93 |
+
|
| 94 |
+
- `B1`, `B2`, `DPO`, `PPO` cần CUDA để chạy ổn trong cấu hình hiện tại.
|
| 95 |
+
- Nếu một model chưa có artifact hoặc không đủ điều kiện chạy, UI vẫn hiển thị lỗi riêng cho model đó thay vì làm hỏng toàn bộ request.
|
| 96 |
+
- Web giữ model trong cache sau lần load đầu tiên, nên request sau sẽ nhanh hơn đáng kể.
|
web/main.py
ADDED
|
@@ -0,0 +1,978 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import collections
|
| 3 |
+
import gc
|
| 4 |
+
import io
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import time
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Optional
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
| 14 |
+
from fastapi.responses import FileResponse, JSONResponse
|
| 15 |
+
from fastapi.staticfiles import StaticFiles
|
| 16 |
+
from PIL import Image
|
| 17 |
+
from peft import PeftModel
|
| 18 |
+
from transformers import AutoTokenizer, LlavaForConditionalGeneration, LlavaProcessor
|
| 19 |
+
from src.models.medical_vqa_model import MedicalVQAModelA
|
| 20 |
+
from src.models.multimodal_vqa import MultimodalVQA
|
| 21 |
+
from src.utils.answer_rewriter import MedicalAnswerRewriter
|
| 22 |
+
from src.utils.helpers import majority_answer
|
| 23 |
+
from src.utils.text_utils import postprocess_answer
|
| 24 |
+
from src.utils.translator import MedicalTranslator
|
| 25 |
+
from src.utils.visualization import MedicalImageTransform
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
ROOT_DIR = Path(__file__).resolve().parent.parent
|
| 29 |
+
CONFIG_PATH = ROOT_DIR / "configs" / "medical_vqa.yaml"
|
| 30 |
+
|
| 31 |
+
VARIANT_ORDER = ["A1", "A2", "B1", "B2", "DPO", "PPO"]
|
| 32 |
+
|
| 33 |
+
VARIANT_META = {
|
| 34 |
+
"A1": {
|
| 35 |
+
"family": "A",
|
| 36 |
+
"title": "A1",
|
| 37 |
+
"subtitle": "LSTM baseline",
|
| 38 |
+
"description": "DenseNet-121 + PhoBERT + LSTM",
|
| 39 |
+
},
|
| 40 |
+
"A2": {
|
| 41 |
+
"family": "A",
|
| 42 |
+
"title": "A2",
|
| 43 |
+
"subtitle": "Transformer decoder",
|
| 44 |
+
"description": "DenseNet-121 + PhoBERT + Transformer",
|
| 45 |
+
},
|
| 46 |
+
"B1": {
|
| 47 |
+
"family": "B",
|
| 48 |
+
"title": "B1",
|
| 49 |
+
"subtitle": "Zero-shot",
|
| 50 |
+
"description": "LLaVA-Med base",
|
| 51 |
+
},
|
| 52 |
+
"B2": {
|
| 53 |
+
"family": "B",
|
| 54 |
+
"title": "B2",
|
| 55 |
+
"subtitle": "Fine-tuned",
|
| 56 |
+
"description": "LLaVA-Med + LoRA",
|
| 57 |
+
},
|
| 58 |
+
"DPO": {
|
| 59 |
+
"family": "B",
|
| 60 |
+
"title": "DPO",
|
| 61 |
+
"subtitle": "Alignment",
|
| 62 |
+
"description": "B2 + Direct Preference Optimization",
|
| 63 |
+
},
|
| 64 |
+
"PPO": {
|
| 65 |
+
"family": "B",
|
| 66 |
+
"title": "PPO",
|
| 67 |
+
"subtitle": "RL refinement",
|
| 68 |
+
"description": "B2 + Proximal Policy Optimization",
|
| 69 |
+
},
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
SUGGESTION_DATA_PATH = ROOT_DIR / "data" / "merged_vqa_vi_cleaned.json"
|
| 73 |
+
SUGGESTION_LIMIT = int(os.getenv("WEB_SUGGESTION_LIMIT", "8"))
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _read_config() -> dict[str, Any]:
|
| 77 |
+
try:
|
| 78 |
+
import yaml
|
| 79 |
+
|
| 80 |
+
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
|
| 81 |
+
return yaml.safe_load(f) or {}
|
| 82 |
+
except Exception as exc:
|
| 83 |
+
raise RuntimeError(f"Failed to read config at {CONFIG_PATH}: {exc}") from exc
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
CFG = _read_config()
|
| 87 |
+
|
| 88 |
+
app = FastAPI(title="Medical VQA Compare API", version="2.0.0")
|
| 89 |
+
|
| 90 |
+
static_dir = os.path.join(os.path.dirname(__file__), "static")
|
| 91 |
+
if os.path.isdir(static_dir):
|
| 92 |
+
app.mount("/static", StaticFiles(directory=static_dir), name="static")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class VQAServerState:
|
| 96 |
+
def __init__(self) -> None:
|
| 97 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 98 |
+
if self.device.type == "cuda":
|
| 99 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 100 |
+
torch.set_float32_matmul_precision("high")
|
| 101 |
+
self.image_size = int(CFG.get("data", {}).get("image_size", 224))
|
| 102 |
+
self.answer_max_words = int(CFG.get("data", {}).get("answer_max_words", 10))
|
| 103 |
+
self.max_question_len = int(CFG.get("data", {}).get("max_question_len", 64))
|
| 104 |
+
self.max_answer_len = int(CFG.get("data", {}).get("max_answer_len", 20))
|
| 105 |
+
self.model_a_cfg = CFG.get("model_a", {})
|
| 106 |
+
self.model_b_cfg = CFG.get("model_b", {})
|
| 107 |
+
self.eval_cfg = CFG.get("eval", {})
|
| 108 |
+
self.models_dir = ROOT_DIR / "checkpoints"
|
| 109 |
+
self.qa_tokenizer = None
|
| 110 |
+
self.translator = MedicalTranslator(device="cpu")
|
| 111 |
+
self.answer_rewriter = MedicalAnswerRewriter()
|
| 112 |
+
self.image_transform = MedicalImageTransform(size=self.image_size)
|
| 113 |
+
self.cache_lock = asyncio.Lock()
|
| 114 |
+
self.b_lock = asyncio.Lock()
|
| 115 |
+
self.a_models: dict[str, dict[str, Any]] = {}
|
| 116 |
+
self.llava_bundle: dict[str, Any] | None = None
|
| 117 |
+
self.question_suggestions: list[dict[str, Any]] = []
|
| 118 |
+
self.preload_models = os.getenv("WEB_PRELOAD_MODELS", "1" if self.device.type == "cuda" else "0") == "1"
|
| 119 |
+
|
| 120 |
+
@property
|
| 121 |
+
def phobert_model(self) -> str:
|
| 122 |
+
return self.model_a_cfg.get("phobert_model", "vinai/phobert-base")
|
| 123 |
+
|
| 124 |
+
@property
|
| 125 |
+
def llava_model_id(self) -> str:
|
| 126 |
+
return self.model_b_cfg.get("model_name", "chaoyinshe/llava-med-v1.5-mistral-7b-hf")
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
state = VQAServerState()
|
| 130 |
+
load_lock = asyncio.Lock()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _artifact_exists(path: Path) -> bool:
|
| 134 |
+
return path.exists()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _as_bool(value: Any) -> bool:
|
| 138 |
+
if isinstance(value, bool):
|
| 139 |
+
return value
|
| 140 |
+
if value is None:
|
| 141 |
+
return False
|
| 142 |
+
return str(value).strip().lower() in {"true", "1", "yes", "y"}
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _normalize_text_key(text: Any) -> str:
|
| 146 |
+
normalized = str(text or "").strip().lower()
|
| 147 |
+
normalized = re.sub(r"\s+", " ", normalized)
|
| 148 |
+
return normalized
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _suggestion_category(item: dict[str, Any], question: str) -> str:
|
| 152 |
+
content_type = str(item.get("content_type", "")).strip()
|
| 153 |
+
if content_type:
|
| 154 |
+
return content_type
|
| 155 |
+
|
| 156 |
+
q = question.lower()
|
| 157 |
+
if any(token in q for token in ["bất thường", "abnormal", "normal", "có vẻ"]):
|
| 158 |
+
return "Abnormality"
|
| 159 |
+
if any(token in q for token in ["phương thức", "modality", "chụp", "scan", "x-ray", "ct", "mri"]):
|
| 160 |
+
return "Modality"
|
| 161 |
+
if any(token in q for token in ["mặt phẳng", "plane", "lát cắt"]):
|
| 162 |
+
return "Plane"
|
| 163 |
+
if any(token in q for token in ["bao nhiêu", "how many", "số lượng"]):
|
| 164 |
+
return "Quantity"
|
| 165 |
+
if any(token in q for token in ["màu", "color"]):
|
| 166 |
+
return "Color"
|
| 167 |
+
if any(token in q for token in ["ở đâu", "vị trí", "where"]):
|
| 168 |
+
return "Position"
|
| 169 |
+
if any(token in q for token in ["chứa", "contain", "có "]):
|
| 170 |
+
return "Organ"
|
| 171 |
+
return "General"
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _load_question_suggestions(limit: int = SUGGESTION_LIMIT) -> list[dict[str, Any]]:
|
| 175 |
+
if not SUGGESTION_DATA_PATH.exists():
|
| 176 |
+
return []
|
| 177 |
+
|
| 178 |
+
try:
|
| 179 |
+
with SUGGESTION_DATA_PATH.open("r", encoding="utf-8") as f:
|
| 180 |
+
dataset = json.load(f)
|
| 181 |
+
except Exception as exc:
|
| 182 |
+
print(f"[WARNING] Failed to read suggestion dataset: {exc}")
|
| 183 |
+
return []
|
| 184 |
+
|
| 185 |
+
groups: dict[str, list[dict[str, Any]]] = collections.defaultdict(list)
|
| 186 |
+
for item in dataset:
|
| 187 |
+
if not _as_bool(item.get("question_vi_valid", True)):
|
| 188 |
+
continue
|
| 189 |
+
if _as_bool(item.get("low_quality", False)):
|
| 190 |
+
continue
|
| 191 |
+
question = str(item.get("question_vi") or "").strip()
|
| 192 |
+
if not question:
|
| 193 |
+
continue
|
| 194 |
+
groups[_normalize_text_key(question)].append(item)
|
| 195 |
+
|
| 196 |
+
candidates: list[dict[str, Any]] = []
|
| 197 |
+
for items in groups.values():
|
| 198 |
+
if len(items) < 8:
|
| 199 |
+
continue
|
| 200 |
+
|
| 201 |
+
question = str(items[0].get("question_vi") or "").strip()
|
| 202 |
+
if not question:
|
| 203 |
+
continue
|
| 204 |
+
|
| 205 |
+
answer_texts = []
|
| 206 |
+
content_types = []
|
| 207 |
+
answer_types = []
|
| 208 |
+
modalities = []
|
| 209 |
+
for item in items:
|
| 210 |
+
answer = str(item.get("answer_vi") or item.get("answer") or "").strip()
|
| 211 |
+
if answer:
|
| 212 |
+
answer_texts.append(_normalize_text_key(answer))
|
| 213 |
+
content_types.append(str(item.get("content_type", "")).strip())
|
| 214 |
+
answer_types.append(str(item.get("answer_type", "")).strip().upper())
|
| 215 |
+
modalities.append(str(item.get("modality", "")).strip())
|
| 216 |
+
|
| 217 |
+
if not answer_texts:
|
| 218 |
+
continue
|
| 219 |
+
|
| 220 |
+
answer_counter = collections.Counter(answer_texts)
|
| 221 |
+
top_answer, top_count = answer_counter.most_common(1)[0]
|
| 222 |
+
total = len(answer_texts)
|
| 223 |
+
confidence = top_count / total
|
| 224 |
+
answer_type = collections.Counter(answer_types).most_common(1)[0][0] if answer_types else ""
|
| 225 |
+
content_type = collections.Counter([c for c in content_types if c]).most_common(1)[0][0] if any(content_types) else ""
|
| 226 |
+
modality = collections.Counter([m for m in modalities if m]).most_common(1)[0][0] if any(modalities) else ""
|
| 227 |
+
|
| 228 |
+
if answer_type == "CLOSED":
|
| 229 |
+
if confidence < 0.85:
|
| 230 |
+
continue
|
| 231 |
+
elif confidence < 0.92:
|
| 232 |
+
continue
|
| 233 |
+
if answer_type != "CLOSED" and len(top_answer.split()) > 3:
|
| 234 |
+
continue
|
| 235 |
+
if len(question) > 140:
|
| 236 |
+
continue
|
| 237 |
+
|
| 238 |
+
category = _suggestion_category(items[0], question)
|
| 239 |
+
category_bonus = {
|
| 240 |
+
"Abnormality": 5.0,
|
| 241 |
+
"Modality": 4.5,
|
| 242 |
+
"Plane": 4.25,
|
| 243 |
+
"Organ": 4.0,
|
| 244 |
+
"Position": 3.5,
|
| 245 |
+
"Quantity": 3.25,
|
| 246 |
+
"Color": 3.0,
|
| 247 |
+
"General": 2.0,
|
| 248 |
+
}.get(category, 2.0)
|
| 249 |
+
score = confidence * 100.0 + min(total, 80) * 0.15 + category_bonus - len(question) * 0.02
|
| 250 |
+
|
| 251 |
+
candidates.append(
|
| 252 |
+
{
|
| 253 |
+
"question": question,
|
| 254 |
+
"question_key": _normalize_text_key(question),
|
| 255 |
+
"answer": top_answer,
|
| 256 |
+
"answer_type": answer_type or "OPEN",
|
| 257 |
+
"content_type": content_type or category,
|
| 258 |
+
"modality": modality,
|
| 259 |
+
"confidence": round(confidence, 3),
|
| 260 |
+
"sample_count": total,
|
| 261 |
+
"score": round(score, 3),
|
| 262 |
+
}
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
if not candidates:
|
| 266 |
+
return []
|
| 267 |
+
|
| 268 |
+
priority_order = ["Abnormality", "Modality", "Plane", "Organ", "Position", "Quantity", "Color", "General"]
|
| 269 |
+
selected: list[dict[str, Any]] = []
|
| 270 |
+
used_keys: set[str] = set()
|
| 271 |
+
per_category_limit = 2
|
| 272 |
+
category_counts: dict[str, int] = collections.defaultdict(int)
|
| 273 |
+
|
| 274 |
+
for category in priority_order:
|
| 275 |
+
category_candidates = sorted(
|
| 276 |
+
(c for c in candidates if c["content_type"].lower() == category.lower()),
|
| 277 |
+
key=lambda item: (item["score"], item["confidence"], item["sample_count"]),
|
| 278 |
+
reverse=True,
|
| 279 |
+
)
|
| 280 |
+
for candidate in category_candidates:
|
| 281 |
+
if candidate["question_key"] in used_keys:
|
| 282 |
+
continue
|
| 283 |
+
if category_counts[category] >= per_category_limit:
|
| 284 |
+
break
|
| 285 |
+
selected.append(candidate)
|
| 286 |
+
used_keys.add(candidate["question_key"])
|
| 287 |
+
category_counts[category] += 1
|
| 288 |
+
if len(selected) >= limit:
|
| 289 |
+
break
|
| 290 |
+
if len(selected) >= limit:
|
| 291 |
+
break
|
| 292 |
+
|
| 293 |
+
if len(selected) < limit:
|
| 294 |
+
for candidate in sorted(candidates, key=lambda item: (item["score"], item["confidence"], item["sample_count"]), reverse=True):
|
| 295 |
+
if candidate["question_key"] in used_keys:
|
| 296 |
+
continue
|
| 297 |
+
selected.append(candidate)
|
| 298 |
+
used_keys.add(candidate["question_key"])
|
| 299 |
+
if len(selected) >= limit:
|
| 300 |
+
break
|
| 301 |
+
|
| 302 |
+
return selected[:limit]
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _select_best_b2_checkpoint(checkpoint_root: Path) -> Optional[Path]:
|
| 306 |
+
if not checkpoint_root.exists():
|
| 307 |
+
return None
|
| 308 |
+
|
| 309 |
+
best_dir: Optional[Path] = None
|
| 310 |
+
best_metric: Optional[float] = None
|
| 311 |
+
|
| 312 |
+
for ckpt_dir in sorted(checkpoint_root.glob("checkpoint-*")):
|
| 313 |
+
state_file = ckpt_dir / "trainer_state.json"
|
| 314 |
+
if not state_file.exists():
|
| 315 |
+
continue
|
| 316 |
+
try:
|
| 317 |
+
trainer_state = json.loads(state_file.read_text(encoding="utf-8"))
|
| 318 |
+
except Exception:
|
| 319 |
+
continue
|
| 320 |
+
|
| 321 |
+
metric = trainer_state.get("best_metric")
|
| 322 |
+
if isinstance(metric, str):
|
| 323 |
+
try:
|
| 324 |
+
metric = float(metric)
|
| 325 |
+
except ValueError:
|
| 326 |
+
metric = None
|
| 327 |
+
|
| 328 |
+
if metric is None:
|
| 329 |
+
eval_losses = [
|
| 330 |
+
rec.get("eval_loss")
|
| 331 |
+
for rec in trainer_state.get("log_history", [])
|
| 332 |
+
if isinstance(rec, dict) and rec.get("eval_loss") is not None
|
| 333 |
+
]
|
| 334 |
+
metric = min(eval_losses) if eval_losses else None
|
| 335 |
+
|
| 336 |
+
if metric is None:
|
| 337 |
+
continue
|
| 338 |
+
|
| 339 |
+
if best_metric is None or metric < best_metric:
|
| 340 |
+
best_metric = metric
|
| 341 |
+
best_dir = ckpt_dir
|
| 342 |
+
|
| 343 |
+
if best_dir is not None:
|
| 344 |
+
return best_dir
|
| 345 |
+
|
| 346 |
+
checkpoints = sorted(checkpoint_root.glob("checkpoint-*"))
|
| 347 |
+
return checkpoints[-1] if checkpoints else None
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def _resolve_variant_artifact(variant: str) -> dict[str, Any]:
|
| 351 |
+
if variant in {"A1", "A2"}:
|
| 352 |
+
ckpt_path = ROOT_DIR / "checkpoints" / f"medical_vqa_{variant}_best.pth"
|
| 353 |
+
if not ckpt_path.exists():
|
| 354 |
+
resume_path = ROOT_DIR / "checkpoints" / f"medical_vqa_{variant}_resume.pth"
|
| 355 |
+
ckpt_path = resume_path if resume_path.exists() else ckpt_path
|
| 356 |
+
return {"type": "direction_a", "path": ckpt_path}
|
| 357 |
+
|
| 358 |
+
if variant == "B1":
|
| 359 |
+
return {"type": "llava_base", "path": state.llava_model_id}
|
| 360 |
+
|
| 361 |
+
if variant == "B2":
|
| 362 |
+
ckpt_dir = _select_best_b2_checkpoint(ROOT_DIR / "checkpoints" / "B2")
|
| 363 |
+
return {"type": "llava_adapter", "path": ckpt_dir}
|
| 364 |
+
|
| 365 |
+
if variant == "DPO":
|
| 366 |
+
final_adapter = ROOT_DIR / "checkpoints" / "DPO" / "final_adapter"
|
| 367 |
+
fallback = ROOT_DIR / "checkpoints" / "DPO" / "checkpoint-25"
|
| 368 |
+
return {"type": "llava_adapter", "path": final_adapter if final_adapter.exists() else fallback}
|
| 369 |
+
|
| 370 |
+
if variant == "PPO":
|
| 371 |
+
final_adapter = ROOT_DIR / "checkpoints" / "PPO" / "final_adapter"
|
| 372 |
+
return {"type": "llava_adapter", "path": final_adapter}
|
| 373 |
+
|
| 374 |
+
raise ValueError(f"Unknown variant: {variant}")
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def _llava_adapter_specs() -> list[tuple[str, Path]]:
|
| 378 |
+
specs: list[tuple[str, Path]] = []
|
| 379 |
+
for variant in ("B2", "DPO", "PPO"):
|
| 380 |
+
artifact = _resolve_variant_artifact(variant)["path"]
|
| 381 |
+
if isinstance(artifact, Path) and artifact.exists():
|
| 382 |
+
specs.append((variant, artifact))
|
| 383 |
+
return specs
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def _ensure_qa_tokenizer():
|
| 387 |
+
if state.qa_tokenizer is None:
|
| 388 |
+
tokenizer = AutoTokenizer.from_pretrained(state.phobert_model)
|
| 389 |
+
if tokenizer.pad_token is None:
|
| 390 |
+
tokenizer.pad_token = tokenizer.eos_token or tokenizer.sep_token
|
| 391 |
+
state.qa_tokenizer = tokenizer
|
| 392 |
+
return state.qa_tokenizer
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def _looks_vietnamese(text: str) -> bool:
|
| 396 |
+
vi_marks = "ăâđêôơưáàảãạắằẳẵặấầẩẫậéèẻẽẹếềểễệíìỉĩịóòỏõọốồổỗộớờởỡợúùủũụứừửữựýỳỷỹỵ"
|
| 397 |
+
lowered = text.lower()
|
| 398 |
+
if any(ch in vi_marks for ch in lowered):
|
| 399 |
+
return True
|
| 400 |
+
vi_keywords = {
|
| 401 |
+
"không",
|
| 402 |
+
"có",
|
| 403 |
+
"bệnh",
|
| 404 |
+
"phổi",
|
| 405 |
+
"tim",
|
| 406 |
+
"sọ",
|
| 407 |
+
"xương",
|
| 408 |
+
"ảnh",
|
| 409 |
+
"hỏi",
|
| 410 |
+
"đâu",
|
| 411 |
+
"gì",
|
| 412 |
+
"như thế nào",
|
| 413 |
+
}
|
| 414 |
+
return any(keyword in lowered for keyword in vi_keywords)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def _looks_closed_question(question: str) -> bool:
|
| 418 |
+
normalized = question.lower().strip()
|
| 419 |
+
normalized = re.sub(r"\s+", " ", normalized)
|
| 420 |
+
closed_prefixes = (
|
| 421 |
+
"is ",
|
| 422 |
+
"are ",
|
| 423 |
+
"was ",
|
| 424 |
+
"were ",
|
| 425 |
+
"do ",
|
| 426 |
+
"does ",
|
| 427 |
+
"did ",
|
| 428 |
+
"can ",
|
| 429 |
+
"could ",
|
| 430 |
+
"should ",
|
| 431 |
+
"would ",
|
| 432 |
+
"has ",
|
| 433 |
+
"have ",
|
| 434 |
+
"had ",
|
| 435 |
+
"có ",
|
| 436 |
+
"có phải",
|
| 437 |
+
"liệu ",
|
| 438 |
+
)
|
| 439 |
+
closed_keywords = {
|
| 440 |
+
"yes",
|
| 441 |
+
"no",
|
| 442 |
+
"không",
|
| 443 |
+
"có",
|
| 444 |
+
"normal",
|
| 445 |
+
"abnormal",
|
| 446 |
+
"present",
|
| 447 |
+
"absent",
|
| 448 |
+
"sốt",
|
| 449 |
+
}
|
| 450 |
+
open_prefixes = ("what ", "where ", "when ", "who ", "which ", "how ", "why ")
|
| 451 |
+
if normalized.startswith(open_prefixes):
|
| 452 |
+
return False
|
| 453 |
+
if normalized.startswith(closed_prefixes):
|
| 454 |
+
return True
|
| 455 |
+
return any(word in normalized.split() for word in closed_keywords)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def _normalize_closed_answer(question_vi: str, question_en: str, pred_vi: str, pred_en: str = "") -> str:
|
| 459 |
+
question_text = f"{question_vi} {question_en}".lower()
|
| 460 |
+
combined = " ".join(part for part in [pred_vi, pred_en] if part).lower().strip()
|
| 461 |
+
combined_norm = re.sub(r"\s+", " ", combined)
|
| 462 |
+
|
| 463 |
+
is_normality_question = any(pattern in question_text for pattern in ["bình thường", "normal", "abnormal"])
|
| 464 |
+
|
| 465 |
+
if is_normality_question:
|
| 466 |
+
if any(pattern in combined_norm for pattern in ["không bình thường", "not normal"]):
|
| 467 |
+
return "không"
|
| 468 |
+
if any(pattern in combined_norm.split() for pattern in ["có", "yes"]):
|
| 469 |
+
return "có"
|
| 470 |
+
if any(pattern in combined_norm for pattern in ["bình thường", "normal", "unremarkable", "no significant abnormalities"]):
|
| 471 |
+
return "có"
|
| 472 |
+
if any(pattern in combined_norm for pattern in ["bất thường", "abnormal", "fracture", "lesion", "mass", "effusion", "pneumothorax"]):
|
| 473 |
+
return "không"
|
| 474 |
+
else:
|
| 475 |
+
if any(pattern in combined_norm for pattern in ["không", "no", "absent", "negative", "none"]):
|
| 476 |
+
return "không"
|
| 477 |
+
if any(pattern in combined_norm for pattern in ["có", "yes", "present", "detected", "positive"]):
|
| 478 |
+
return "có"
|
| 479 |
+
|
| 480 |
+
if any(pattern in combined_norm for pattern in ["bình thường", "normal", "unremarkable", "no significant abnormalities"]):
|
| 481 |
+
return "có"
|
| 482 |
+
if any(pattern in combined_norm for pattern in ["bất thường", "abnormal", "fracture", "lesion", "mass", "effusion", "pneumothorax"]):
|
| 483 |
+
return "không"
|
| 484 |
+
return pred_vi or pred_en or ""
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def _build_bad_words_ids(processor, variant: str) -> list[list[int]] | None:
|
| 488 |
+
if variant not in {"B1", "B2", "DPO", "PPO"}:
|
| 489 |
+
return None
|
| 490 |
+
tokenizer = getattr(processor, "tokenizer", None)
|
| 491 |
+
if tokenizer is None:
|
| 492 |
+
return None
|
| 493 |
+
banned_phrases = [
|
| 494 |
+
"yes",
|
| 495 |
+
"no",
|
| 496 |
+
"the answer is",
|
| 497 |
+
"the image is",
|
| 498 |
+
"this image is",
|
| 499 |
+
"the image shows",
|
| 500 |
+
"the scan shows",
|
| 501 |
+
"there is",
|
| 502 |
+
"there are",
|
| 503 |
+
"it appears",
|
| 504 |
+
"the finding is",
|
| 505 |
+
]
|
| 506 |
+
bad_words_ids = []
|
| 507 |
+
for phrase in banned_phrases:
|
| 508 |
+
token_ids = tokenizer.encode(phrase, add_special_tokens=False)
|
| 509 |
+
if token_ids:
|
| 510 |
+
bad_words_ids.append(token_ids)
|
| 511 |
+
return bad_words_ids or None
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def _build_b1_prompt(question_en: str, max_words: int) -> str:
|
| 515 |
+
instruction = f"Answer in Vietnamese, concise, at most {max_words} words."
|
| 516 |
+
return f"USER: <image>\n{question_en}\n{instruction} ASSISTANT:"
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def _rewrite_final_answer(question: str, raw_answer: str, language: str = "vi") -> str:
|
| 520 |
+
"""
|
| 521 |
+
Chỉ rewrite phần output hiển thị cuối cùng.
|
| 522 |
+
Raw prediction vẫn được giữ nguyên trong payload để debug.
|
| 523 |
+
"""
|
| 524 |
+
candidate = state.answer_rewriter.rewrite(question=question, answer=raw_answer, language=language)
|
| 525 |
+
candidate = postprocess_answer(candidate, max_words=state.answer_max_words)
|
| 526 |
+
if candidate:
|
| 527 |
+
return candidate
|
| 528 |
+
return postprocess_answer(raw_answer, max_words=state.answer_max_words)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def _extract_key_medical_term(raw_en: str, max_words: int) -> str:
|
| 532 |
+
text = re.sub(r"\s+", " ", (raw_en or "").strip())
|
| 533 |
+
if not text:
|
| 534 |
+
return ""
|
| 535 |
+
return " ".join(text.split()[:max_words])
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def _en_to_vi_direct(en_text: str) -> Optional[str]:
|
| 539 |
+
text = (en_text or "").strip().lower()
|
| 540 |
+
mapping = {
|
| 541 |
+
"yes": "có",
|
| 542 |
+
"no": "không",
|
| 543 |
+
"normal": "bình thường",
|
| 544 |
+
"abnormal": "bất thường",
|
| 545 |
+
"present": "có",
|
| 546 |
+
"absent": "không",
|
| 547 |
+
}
|
| 548 |
+
return mapping.get(text)
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def _prepare_question_text(question: str, variant: str) -> tuple[str, str]:
|
| 552 |
+
question = question.strip()
|
| 553 |
+
if not question:
|
| 554 |
+
return "", ""
|
| 555 |
+
|
| 556 |
+
if variant == "B1":
|
| 557 |
+
question_en = question if not _looks_vietnamese(question) else state.translator.translate_vi2en(question)
|
| 558 |
+
return question, question_en
|
| 559 |
+
|
| 560 |
+
question_vi = question if _looks_vietnamese(question) else state.translator.translate_en2vi(question)
|
| 561 |
+
return question_vi, question
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
async def _ensure_direction_a_model(variant: str):
|
| 565 |
+
if variant not in {"A1", "A2"}:
|
| 566 |
+
raise ValueError(f"Unsupported direction A variant: {variant}")
|
| 567 |
+
|
| 568 |
+
cached = state.a_models.get(variant)
|
| 569 |
+
if cached is not None:
|
| 570 |
+
return cached
|
| 571 |
+
|
| 572 |
+
async with state.cache_lock:
|
| 573 |
+
cached = state.a_models.get(variant)
|
| 574 |
+
if cached is not None:
|
| 575 |
+
return cached
|
| 576 |
+
|
| 577 |
+
tokenizer = _ensure_qa_tokenizer()
|
| 578 |
+
ckpt_path = _resolve_variant_artifact(variant)["path"]
|
| 579 |
+
if not isinstance(ckpt_path, Path) or not ckpt_path.exists():
|
| 580 |
+
raise FileNotFoundError(f"Không tìm thấy checkpoint cho {variant}: {ckpt_path}")
|
| 581 |
+
|
| 582 |
+
decoder_type = "lstm" if variant == "A1" else "transformer"
|
| 583 |
+
model = MedicalVQAModelA(
|
| 584 |
+
decoder_type=decoder_type,
|
| 585 |
+
vocab_size=len(tokenizer),
|
| 586 |
+
hidden_size=int(state.model_a_cfg.get("hidden_size", 768)),
|
| 587 |
+
phobert_model=state.phobert_model,
|
| 588 |
+
).to(state.device)
|
| 589 |
+
|
| 590 |
+
payload = torch.load(ckpt_path, map_location=state.device)
|
| 591 |
+
state_dict = payload.get("model_state_dict") if isinstance(payload, dict) and "model_state_dict" in payload else payload
|
| 592 |
+
model.load_state_dict(state_dict, strict=False)
|
| 593 |
+
model.eval()
|
| 594 |
+
bundle = {
|
| 595 |
+
"variant": variant,
|
| 596 |
+
"family": "A",
|
| 597 |
+
"model": model,
|
| 598 |
+
"tokenizer": tokenizer,
|
| 599 |
+
"checkpoint": str(ckpt_path),
|
| 600 |
+
}
|
| 601 |
+
state.a_models[variant] = bundle
|
| 602 |
+
return bundle
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
def _build_llava_base_and_processor():
|
| 606 |
+
if not torch.cuda.is_available():
|
| 607 |
+
raise RuntimeError("Các model LLaVA (B1/B2/DPO/PPO) cần CUDA để chạy trong web này.")
|
| 608 |
+
|
| 609 |
+
wrapper = MultimodalVQA(
|
| 610 |
+
model_id=state.llava_model_id,
|
| 611 |
+
lora_r=int(state.model_b_cfg.get("lora_r", 16)),
|
| 612 |
+
lora_alpha=int(state.model_b_cfg.get("lora_alpha", 32)),
|
| 613 |
+
lora_dropout=float(state.model_b_cfg.get("lora_dropout", 0.05)),
|
| 614 |
+
lora_target_modules=state.model_b_cfg.get("lora_target_modules"),
|
| 615 |
+
)
|
| 616 |
+
processor = LlavaProcessor.from_pretrained(wrapper.model_id)
|
| 617 |
+
processor.tokenizer.padding_side = "left"
|
| 618 |
+
base_model = LlavaForConditionalGeneration.from_pretrained(
|
| 619 |
+
wrapper.model_id,
|
| 620 |
+
quantization_config=wrapper.bnb_config,
|
| 621 |
+
device_map="auto",
|
| 622 |
+
)
|
| 623 |
+
base_model.config.use_cache = False
|
| 624 |
+
return wrapper, processor, base_model
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
async def _ensure_llava_bundle():
|
| 628 |
+
cached = state.llava_bundle
|
| 629 |
+
if cached is not None:
|
| 630 |
+
return cached
|
| 631 |
+
|
| 632 |
+
async with state.cache_lock:
|
| 633 |
+
cached = state.llava_bundle
|
| 634 |
+
if cached is not None:
|
| 635 |
+
return cached
|
| 636 |
+
|
| 637 |
+
wrapper, processor, base_model = _build_llava_base_and_processor()
|
| 638 |
+
adapter_specs = _llava_adapter_specs()
|
| 639 |
+
adapter_name_map = {variant: variant for variant, _ in adapter_specs}
|
| 640 |
+
|
| 641 |
+
if adapter_specs:
|
| 642 |
+
first_variant, first_path = adapter_specs[0]
|
| 643 |
+
model = PeftModel.from_pretrained(
|
| 644 |
+
base_model,
|
| 645 |
+
str(first_path),
|
| 646 |
+
adapter_name=first_variant,
|
| 647 |
+
is_trainable=False,
|
| 648 |
+
)
|
| 649 |
+
for variant, path in adapter_specs[1:]:
|
| 650 |
+
model.load_adapter(str(path), adapter_name=variant, is_trainable=False)
|
| 651 |
+
model.set_adapter(first_variant)
|
| 652 |
+
else:
|
| 653 |
+
model = base_model
|
| 654 |
+
|
| 655 |
+
model.eval()
|
| 656 |
+
bundle = {
|
| 657 |
+
"family": "B",
|
| 658 |
+
"model": model,
|
| 659 |
+
"processor": processor,
|
| 660 |
+
"wrapper": wrapper,
|
| 661 |
+
"checkpoint": adapter_specs[0][1].as_posix() if adapter_specs else state.llava_model_id,
|
| 662 |
+
"adapter_name_map": adapter_name_map,
|
| 663 |
+
"peft": bool(adapter_specs),
|
| 664 |
+
}
|
| 665 |
+
state.llava_bundle = bundle
|
| 666 |
+
return bundle
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
def _predict_direction_a(bundle: dict[str, Any], question_vi: str, image: Image.Image) -> dict[str, Any]:
|
| 670 |
+
model = bundle["model"]
|
| 671 |
+
tokenizer = bundle["tokenizer"]
|
| 672 |
+
image_tensor = state.image_transform(image.convert("L")).unsqueeze(0).to(state.device)
|
| 673 |
+
|
| 674 |
+
inputs = tokenizer(
|
| 675 |
+
question_vi,
|
| 676 |
+
padding="max_length",
|
| 677 |
+
truncation=True,
|
| 678 |
+
max_length=state.max_question_len,
|
| 679 |
+
return_tensors="pt",
|
| 680 |
+
)
|
| 681 |
+
input_ids = inputs["input_ids"].to(state.device)
|
| 682 |
+
attention_mask = inputs["attention_mask"].to(state.device)
|
| 683 |
+
is_closed = _looks_closed_question(question_vi)
|
| 684 |
+
|
| 685 |
+
with torch.inference_mode():
|
| 686 |
+
logits_closed, pred_ids = model.inference(
|
| 687 |
+
image_tensor,
|
| 688 |
+
input_ids,
|
| 689 |
+
attention_mask,
|
| 690 |
+
beam_width=int(state.eval_cfg.get("beam_width_a", 5)),
|
| 691 |
+
max_len=state.max_answer_len,
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
if is_closed:
|
| 695 |
+
prediction_raw = "có" if logits_closed.argmax(dim=1).item() == 1 else "không"
|
| 696 |
+
prediction = _rewrite_final_answer(question_vi, prediction_raw, language="vi")
|
| 697 |
+
else:
|
| 698 |
+
prediction_raw = tokenizer.decode(pred_ids[0], skip_special_tokens=True)
|
| 699 |
+
prediction = _rewrite_final_answer(question_vi, prediction_raw, language="vi")
|
| 700 |
+
|
| 701 |
+
return {
|
| 702 |
+
"prediction": prediction,
|
| 703 |
+
"prediction_raw": prediction_raw,
|
| 704 |
+
"status": "ok",
|
| 705 |
+
}
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
async def _predict_direction_b(
|
| 709 |
+
bundle: dict[str, Any],
|
| 710 |
+
question_vi: str,
|
| 711 |
+
question_en: str,
|
| 712 |
+
image: Image.Image,
|
| 713 |
+
variant: str,
|
| 714 |
+
) -> dict[str, Any]:
|
| 715 |
+
model = bundle["model"]
|
| 716 |
+
processor = bundle["processor"]
|
| 717 |
+
wrapper = bundle["wrapper"]
|
| 718 |
+
is_closed = _looks_closed_question(question_vi if variant != "B1" else question_en)
|
| 719 |
+
question_for_variant = question_en if variant == "B1" else question_vi
|
| 720 |
+
adapter_name = bundle.get("adapter_name_map", {}).get(variant)
|
| 721 |
+
|
| 722 |
+
if variant == "B1":
|
| 723 |
+
prompt = _build_b1_prompt(question_for_variant, state.answer_max_words)
|
| 724 |
+
num_beams = int(state.eval_cfg.get("beam_width_b_open", 5))
|
| 725 |
+
max_new_tokens = int(state.eval_cfg.get("max_new_tokens_b_open", state.answer_max_words + 6))
|
| 726 |
+
else:
|
| 727 |
+
prompt = wrapper.build_instruction_prompt(question_for_variant, language="vi", include_answer=False)
|
| 728 |
+
num_beams = int(state.eval_cfg.get("beam_width_b_closed", 1)) if is_closed else int(
|
| 729 |
+
state.eval_cfg.get("beam_width_b_open", 5)
|
| 730 |
+
)
|
| 731 |
+
max_new_tokens = int(state.eval_cfg.get("max_new_tokens_b_closed", 4)) if is_closed else int(
|
| 732 |
+
state.eval_cfg.get("max_new_tokens_b_open", state.answer_max_words + 6)
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
bad_words_ids = _build_bad_words_ids(processor, variant)
|
| 736 |
+
inputs = processor(text=[prompt], images=[image.convert("RGB")], return_tensors="pt", padding=True)
|
| 737 |
+
inputs = inputs.to(state.device)
|
| 738 |
+
if "pixel_values" in inputs and torch.cuda.is_available():
|
| 739 |
+
inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
|
| 740 |
+
|
| 741 |
+
async with state.b_lock:
|
| 742 |
+
if adapter_name and hasattr(model, "set_adapter"):
|
| 743 |
+
model.set_adapter(adapter_name)
|
| 744 |
+
if variant == "B1" and hasattr(model, "disable_adapter"):
|
| 745 |
+
with model.disable_adapter():
|
| 746 |
+
with torch.inference_mode():
|
| 747 |
+
output_ids = model.generate(
|
| 748 |
+
**inputs,
|
| 749 |
+
max_new_tokens=max_new_tokens,
|
| 750 |
+
do_sample=False,
|
| 751 |
+
num_beams=num_beams,
|
| 752 |
+
early_stopping=num_beams > 1,
|
| 753 |
+
bad_words_ids=bad_words_ids,
|
| 754 |
+
)
|
| 755 |
+
else:
|
| 756 |
+
with torch.inference_mode():
|
| 757 |
+
output_ids = model.generate(
|
| 758 |
+
**inputs,
|
| 759 |
+
max_new_tokens=max_new_tokens,
|
| 760 |
+
do_sample=False,
|
| 761 |
+
num_beams=num_beams,
|
| 762 |
+
early_stopping=num_beams > 1,
|
| 763 |
+
bad_words_ids=bad_words_ids,
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
input_token_len = inputs.input_ids.shape[1]
|
| 767 |
+
pred_raw = processor.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0].strip()
|
| 768 |
+
|
| 769 |
+
if variant == "B1":
|
| 770 |
+
pred_en = _extract_key_medical_term(pred_raw, 50)
|
| 771 |
+
if is_closed:
|
| 772 |
+
prediction = _normalize_closed_answer(question_vi, question_en, pred_en, pred_en)
|
| 773 |
+
else:
|
| 774 |
+
prediction = _en_to_vi_direct(pred_en)
|
| 775 |
+
if prediction is None:
|
| 776 |
+
prediction = state.translator.translate_en2vi(pred_en)
|
| 777 |
+
prediction = postprocess_answer(prediction, max_words=state.answer_max_words)
|
| 778 |
+
else:
|
| 779 |
+
if is_closed:
|
| 780 |
+
prediction = _normalize_closed_answer(question_vi, question_en, pred_raw)
|
| 781 |
+
else:
|
| 782 |
+
prediction = postprocess_answer(pred_raw, max_words=state.answer_max_words)
|
| 783 |
+
|
| 784 |
+
prediction = _rewrite_final_answer(question_vi or question_en, prediction, language="vi")
|
| 785 |
+
|
| 786 |
+
return {
|
| 787 |
+
"prediction": prediction,
|
| 788 |
+
"prediction_raw": pred_raw,
|
| 789 |
+
"status": "ok",
|
| 790 |
+
}
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
async def predict_variant(variant: str, question: str, image: Image.Image) -> dict[str, Any]:
|
| 794 |
+
start = time.perf_counter()
|
| 795 |
+
try:
|
| 796 |
+
if variant in {"A1", "A2"}:
|
| 797 |
+
bundle = await _ensure_direction_a_model(variant)
|
| 798 |
+
else:
|
| 799 |
+
artifact = _resolve_variant_artifact(variant)["path"]
|
| 800 |
+
if variant != "B1" and (not isinstance(artifact, Path) or not artifact.exists()):
|
| 801 |
+
raise FileNotFoundError(f"Không tìm thấy artifact cho {variant}: {artifact}")
|
| 802 |
+
bundle = await _ensure_llava_bundle()
|
| 803 |
+
question_vi, question_en = _prepare_question_text(question, variant)
|
| 804 |
+
if variant == "B1":
|
| 805 |
+
if not question_en:
|
| 806 |
+
question_en = question
|
| 807 |
+
result = await _predict_direction_b(bundle, question_vi, question_en, image, variant)
|
| 808 |
+
elif bundle["family"] == "A":
|
| 809 |
+
result = _predict_direction_a(bundle, question_vi, image)
|
| 810 |
+
else:
|
| 811 |
+
result = await _predict_direction_b(bundle, question_vi, question_en, image, variant)
|
| 812 |
+
|
| 813 |
+
result.update(
|
| 814 |
+
{
|
| 815 |
+
"variant": variant,
|
| 816 |
+
"checkpoint": (
|
| 817 |
+
bundle.get("checkpoint", "")
|
| 818 |
+
if variant in {"A1", "A2"}
|
| 819 |
+
else str(_resolve_variant_artifact(variant)["path"])
|
| 820 |
+
if variant != "B1"
|
| 821 |
+
else state.llava_model_id
|
| 822 |
+
),
|
| 823 |
+
"latency_ms": round((time.perf_counter() - start) * 1000, 2),
|
| 824 |
+
}
|
| 825 |
+
)
|
| 826 |
+
return result
|
| 827 |
+
except Exception as exc:
|
| 828 |
+
return {
|
| 829 |
+
"variant": variant,
|
| 830 |
+
"prediction": "",
|
| 831 |
+
"prediction_raw": "",
|
| 832 |
+
"status": f"error: {exc}",
|
| 833 |
+
"checkpoint": "",
|
| 834 |
+
"latency_ms": round((time.perf_counter() - start) * 1000, 2),
|
| 835 |
+
}
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
def _parse_model_selection(raw_model_name: Optional[str], raw_model_names: Optional[str]) -> list[str]:
|
| 839 |
+
if raw_model_names:
|
| 840 |
+
try:
|
| 841 |
+
parsed = json.loads(raw_model_names)
|
| 842 |
+
except Exception:
|
| 843 |
+
parsed = [part.strip() for part in raw_model_names.split(",") if part.strip()]
|
| 844 |
+
if isinstance(parsed, str):
|
| 845 |
+
parsed = [parsed]
|
| 846 |
+
selected = [name for name in parsed if name in VARIANT_ORDER]
|
| 847 |
+
if selected:
|
| 848 |
+
return selected
|
| 849 |
+
|
| 850 |
+
if raw_model_name and raw_model_name in VARIANT_ORDER:
|
| 851 |
+
return [raw_model_name]
|
| 852 |
+
|
| 853 |
+
return VARIANT_ORDER[:]
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
def _variant_availability() -> dict[str, dict[str, Any]]:
|
| 857 |
+
b2_checkpoint = _select_best_b2_checkpoint(ROOT_DIR / "checkpoints" / "B2")
|
| 858 |
+
cuda_ready = torch.cuda.is_available()
|
| 859 |
+
return {
|
| 860 |
+
"A1": {"available": (_artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth")), "artifact": "checkpoints/medical_vqa_A1_best.pth"},
|
| 861 |
+
"A2": {"available": (_artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth")), "artifact": "checkpoints/medical_vqa_A2_best.pth"},
|
| 862 |
+
"B1": {"available": cuda_ready, "artifact": state.llava_model_id},
|
| 863 |
+
"B2": {"available": cuda_ready and b2_checkpoint is not None, "artifact": str(b2_checkpoint) if b2_checkpoint else ""},
|
| 864 |
+
"DPO": {"available": cuda_ready and (_artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "final_adapter") or _artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "checkpoint-25")), "artifact": "checkpoints/DPO/final_adapter"},
|
| 865 |
+
"PPO": {"available": cuda_ready and _artifact_exists(ROOT_DIR / "checkpoints" / "PPO" / "final_adapter"), "artifact": "checkpoints/PPO/final_adapter"},
|
| 866 |
+
}
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
@app.on_event("startup")
|
| 870 |
+
async def startup_event() -> None:
|
| 871 |
+
_ensure_qa_tokenizer()
|
| 872 |
+
state.question_suggestions = _load_question_suggestions()
|
| 873 |
+
if state.preload_models:
|
| 874 |
+
try:
|
| 875 |
+
for variant in ("A1", "A2"):
|
| 876 |
+
await _ensure_direction_a_model(variant)
|
| 877 |
+
await _ensure_llava_bundle()
|
| 878 |
+
except Exception as exc:
|
| 879 |
+
print(f"[WARNING] Model preload skipped: {exc}")
|
| 880 |
+
|
| 881 |
+
|
| 882 |
+
@app.get("/v1/models")
|
| 883 |
+
def list_models() -> JSONResponse:
|
| 884 |
+
payload = []
|
| 885 |
+
availability = _variant_availability()
|
| 886 |
+
for variant in VARIANT_ORDER:
|
| 887 |
+
meta = VARIANT_META[variant]
|
| 888 |
+
info = availability.get(variant, {})
|
| 889 |
+
payload.append(
|
| 890 |
+
{
|
| 891 |
+
"name": variant,
|
| 892 |
+
"family": meta["family"],
|
| 893 |
+
"title": meta["title"],
|
| 894 |
+
"subtitle": meta["subtitle"],
|
| 895 |
+
"description": meta["description"],
|
| 896 |
+
"available": bool(info.get("available")),
|
| 897 |
+
"artifact": info.get("artifact", ""),
|
| 898 |
+
}
|
| 899 |
+
)
|
| 900 |
+
return JSONResponse({"models": payload})
|
| 901 |
+
|
| 902 |
+
|
| 903 |
+
@app.post("/v1/predict")
|
| 904 |
+
async def predict(
|
| 905 |
+
question: str = Form(..., description="Question for VQA"),
|
| 906 |
+
model_name: Optional[str] = Form(None, description="Legacy single model name"),
|
| 907 |
+
model_names: Optional[str] = Form(None, description="Comma-separated or JSON list of models"),
|
| 908 |
+
image: UploadFile = File(..., description="Image input (JPEG/PNG)"),
|
| 909 |
+
) -> JSONResponse:
|
| 910 |
+
if not question.strip():
|
| 911 |
+
raise HTTPException(status_code=400, detail="Question is required.")
|
| 912 |
+
|
| 913 |
+
try:
|
| 914 |
+
img_bytes = await image.read()
|
| 915 |
+
pil_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
| 916 |
+
except Exception as exc:
|
| 917 |
+
raise HTTPException(status_code=400, detail=f"Failed to read image file: {exc}") from exc
|
| 918 |
+
|
| 919 |
+
selected_models = _parse_model_selection(model_name, model_names)
|
| 920 |
+
results = []
|
| 921 |
+
async with load_lock:
|
| 922 |
+
for variant in selected_models:
|
| 923 |
+
results.append(await predict_variant(variant, question, pil_img))
|
| 924 |
+
|
| 925 |
+
predictions = {item["variant"]: item["prediction"] for item in results if item.get("status") == "ok"}
|
| 926 |
+
summary = {
|
| 927 |
+
"majority_vote": majority_answer(list(predictions.values())) if predictions else "",
|
| 928 |
+
"success_count": sum(1 for item in results if item.get("status") == "ok"),
|
| 929 |
+
"error_count": sum(1 for item in results if item.get("status", "").startswith("error")),
|
| 930 |
+
}
|
| 931 |
+
|
| 932 |
+
return JSONResponse(
|
| 933 |
+
{
|
| 934 |
+
"question": question,
|
| 935 |
+
"selected_models": selected_models,
|
| 936 |
+
"results": results,
|
| 937 |
+
"summary": summary,
|
| 938 |
+
}
|
| 939 |
+
)
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
@app.get("/v1/question-suggestions")
|
| 943 |
+
def question_suggestions(limit: int = SUGGESTION_LIMIT) -> JSONResponse:
|
| 944 |
+
suggestions = state.question_suggestions or _load_question_suggestions(limit)
|
| 945 |
+
clipped = suggestions[: max(1, min(limit, len(suggestions)))] if suggestions else []
|
| 946 |
+
return JSONResponse({"suggestions": clipped})
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
@app.get("/health")
|
| 950 |
+
def health() -> JSONResponse:
|
| 951 |
+
availability = _variant_availability()
|
| 952 |
+
return JSONResponse(
|
| 953 |
+
{
|
| 954 |
+
"status": "ok",
|
| 955 |
+
"device": str(state.device),
|
| 956 |
+
"preload_enabled": state.preload_models,
|
| 957 |
+
"answer_rewrite_enabled": state.answer_rewriter.enabled,
|
| 958 |
+
"answer_rewrite_model_id": state.answer_rewriter.model_id,
|
| 959 |
+
"answer_rewrite_ready": state.answer_rewriter.ready,
|
| 960 |
+
"suggestions_cached": len(state.question_suggestions),
|
| 961 |
+
"cached": {
|
| 962 |
+
"A": sorted(state.a_models.keys()),
|
| 963 |
+
"B": bool(state.llava_bundle),
|
| 964 |
+
},
|
| 965 |
+
"models": {
|
| 966 |
+
variant: {"available": availability[variant]["available"], "artifact": availability[variant]["artifact"]}
|
| 967 |
+
for variant in VARIANT_ORDER
|
| 968 |
+
},
|
| 969 |
+
}
|
| 970 |
+
)
|
| 971 |
+
|
| 972 |
+
|
| 973 |
+
@app.get("/", include_in_schema=False)
|
| 974 |
+
def index() -> FileResponse:
|
| 975 |
+
index_path = os.path.join(os.path.dirname(__file__), "static", "index.html")
|
| 976 |
+
if not os.path.exists(index_path):
|
| 977 |
+
raise HTTPException(status_code=500, detail="Frontend index.html not found.")
|
| 978 |
+
return FileResponse(index_path)
|
web/static/index.html
ADDED
|
@@ -0,0 +1,656 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="utf-8"/>
|
| 5 |
+
<meta content="width=device-width, initial-scale=1.0" name="viewport"/>
|
| 6 |
+
<title>Medical VQA Compare</title>
|
| 7 |
+
<link href="https://fonts.googleapis.com/css2?family=Material+Symbols+Outlined:wght,FILL@100..700,0..1&display=swap" rel="stylesheet"/>
|
| 8 |
+
<link href="https://fonts.googleapis.com" rel="preconnect"/>
|
| 9 |
+
<link crossorigin="" href="https://fonts.gstatic.com" rel="preconnect"/>
|
| 10 |
+
<link href="https://fonts.googleapis.com/css2?family=Cinzel:wght@400;500;600;700&family=Noto+Serif+SC:wght@300;400;500;700&display=swap" rel="stylesheet"/>
|
| 11 |
+
<script src="https://cdn.tailwindcss.com?plugins=forms,container-queries"></script>
|
| 12 |
+
<script id="tailwind-config">
|
| 13 |
+
tailwind.config = {
|
| 14 |
+
darkMode: "class",
|
| 15 |
+
theme: {
|
| 16 |
+
extend: {
|
| 17 |
+
colors: {
|
| 18 |
+
"imperial-red": "#A8181B",
|
| 19 |
+
"china-gold": "#A88412",
|
| 20 |
+
"gold-light": "#F9E79F",
|
| 21 |
+
"deep-crimson": "#7D0A0D",
|
| 22 |
+
"ink-black": "#1A1A1A",
|
| 23 |
+
"paper-white": "#FDFBF7",
|
| 24 |
+
"jade-dark": "#0B3D30"
|
| 25 |
+
},
|
| 26 |
+
fontFamily: {
|
| 27 |
+
"serif": ["Noto Serif SC", "Cinzel", "serif"],
|
| 28 |
+
"display": ["Cinzel", "serif"]
|
| 29 |
+
},
|
| 30 |
+
backgroundImage: {
|
| 31 |
+
'cloud-pattern': "url(\"data:image/svg+xml,%3Csvg width='60' height='60' viewBox='0 0 60 60' xmlns='http://www.w3.org/2000/svg'%3E%3Cg fill='none' fill-rule='evenodd'%3E%3Cg fill='%23d4af37' fill-opacity='0.1'%3E%3Cpath d='M36 34v-4h-2v4h-4v2h4v4h2v-4h4v-2h-4zm0-30V0h-2v4h-4v2h4v4h2V6h4V4h-4zM6 34v-4H4v4H0v2h4v4h2v-4h4v-2H6zM6 4V0H4v4H0v2h4v4h2V6h4V4H6z'/%3E%3C/g%3E%3C/g%3E%3C/svg%3E\")",
|
| 32 |
+
'ink-wash': "linear-gradient(to bottom right, #FDFBF7, #F2EFE9)"
|
| 33 |
+
},
|
| 34 |
+
boxShadow: {
|
| 35 |
+
'gold-glow': '0 0 15px rgba(212, 175, 55, 0.3)',
|
| 36 |
+
'red-glow': '0 4px 20px rgba(168, 24, 27, 0.25)'
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
}
|
| 41 |
+
</script>
|
| 42 |
+
<style>
|
| 43 |
+
:root {
|
| 44 |
+
--tilt-x: 0deg;
|
| 45 |
+
--tilt-y: 0deg;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
.scene-3d {
|
| 49 |
+
perspective: 1600px;
|
| 50 |
+
transform-style: preserve-3d;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
.tilt-card {
|
| 54 |
+
transform-style: preserve-3d;
|
| 55 |
+
transition: transform 180ms ease, box-shadow 180ms ease;
|
| 56 |
+
will-change: transform;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
.tilt-card:hover {
|
| 60 |
+
box-shadow: 0 24px 50px rgba(168, 24, 27, 0.18), 0 10px 20px rgba(0, 0, 0, 0.08);
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
.float-slow {
|
| 64 |
+
animation: floatY 6.5s ease-in-out infinite;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
.float-med {
|
| 68 |
+
animation: floatY 5.2s ease-in-out infinite;
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
.float-fast {
|
| 72 |
+
animation: floatY 4.4s ease-in-out infinite;
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
.spin-slow {
|
| 76 |
+
animation: spin360 18s linear infinite;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
.pulse-ring {
|
| 80 |
+
animation: pulseRing 2.8s ease-in-out infinite;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
.hover-lift {
|
| 84 |
+
transform: translateZ(18px);
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
.medical-glow {
|
| 88 |
+
box-shadow: 0 0 0 1px rgba(212, 175, 55, 0.18), 0 12px 40px rgba(168, 24, 27, 0.16);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
.depth-line {
|
| 92 |
+
position: relative;
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
.depth-line::before {
|
| 96 |
+
content: "";
|
| 97 |
+
position: absolute;
|
| 98 |
+
inset: 0;
|
| 99 |
+
border-radius: inherit;
|
| 100 |
+
background: linear-gradient(135deg, rgba(255,255,255,0.45), rgba(255,255,255,0.03));
|
| 101 |
+
transform: translateZ(-2px);
|
| 102 |
+
pointer-events: none;
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
@keyframes floatY {
|
| 106 |
+
0%, 100% { transform: translateY(0px) translateZ(0); }
|
| 107 |
+
50% { transform: translateY(-10px) translateZ(18px); }
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
@keyframes spin360 {
|
| 111 |
+
from { transform: rotate(0deg); }
|
| 112 |
+
to { transform: rotate(360deg); }
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
@keyframes pulseRing {
|
| 116 |
+
0%, 100% { transform: scale(1); opacity: 0.65; }
|
| 117 |
+
50% { transform: scale(1.08); opacity: 1; }
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
@media (prefers-reduced-motion: reduce) {
|
| 121 |
+
.float-slow,
|
| 122 |
+
.float-med,
|
| 123 |
+
.float-fast,
|
| 124 |
+
.spin-slow,
|
| 125 |
+
.pulse-ring {
|
| 126 |
+
animation: none !important;
|
| 127 |
+
}
|
| 128 |
+
}
|
| 129 |
+
</style>
|
| 130 |
+
<style type="text/tailwindcss">
|
| 131 |
+
@layer utilities {
|
| 132 |
+
.ornate-border {
|
| 133 |
+
border: 2px solid #D4AF37;
|
| 134 |
+
position: relative;
|
| 135 |
+
}
|
| 136 |
+
.ornate-border::before {
|
| 137 |
+
content: "";
|
| 138 |
+
position: absolute;
|
| 139 |
+
top: -4px; left: -4px; right: -4px; bottom: -4px;
|
| 140 |
+
border: 1px solid #D4AF37;
|
| 141 |
+
pointer-events: none;
|
| 142 |
+
opacity: 0.5;
|
| 143 |
+
}
|
| 144 |
+
.horse-bg-clip {
|
| 145 |
+
background-clip: text;
|
| 146 |
+
-webkit-background-clip: text;
|
| 147 |
+
color: transparent;
|
| 148 |
+
background-image: linear-gradient(to right, #D4AF37, #F9E79F, #D4AF37);
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
</style>
|
| 152 |
+
</head>
|
| 153 |
+
<body class="bg-paper-white font-serif text-ink-black antialiased selection:bg-imperial-red/20 selection:text-imperial-red bg-cloud-pattern min-h-screen">
|
| 154 |
+
<div class="relative flex min-h-screen w-full flex-col overflow-x-hidden bg-gradient-to-b from-imperial-red/5 to-transparent">
|
| 155 |
+
<header class="sticky top-0 z-50 flex h-[64px] w-full items-center justify-between bg-paper-white/95 px-4 md:px-8 backdrop-blur-md border-b-2 border-china-gold/30 shadow-sm">
|
| 156 |
+
<div class="mx-auto flex w-full max-w-[1280px] items-center justify-between">
|
| 157 |
+
<div class="flex items-center gap-3 hover:opacity-80 transition-opacity cursor-pointer">
|
| 158 |
+
<div class="flex items-center justify-center size-10 rounded-full border border-china-gold bg-imperial-red text-china-gold">
|
| 159 |
+
<span class="material-symbols-outlined text-[24px]">bedroom_baby</span>
|
| 160 |
+
</div>
|
| 161 |
+
<span class="text-[20px] font-display font-bold tracking-wide text-imperial-red">Medical <span class="text-china-gold">VQA</span></span>
|
| 162 |
+
</div>
|
| 163 |
+
<nav class="hidden md:flex items-center gap-10">
|
| 164 |
+
<a class="text-[14px] font-medium text-ink-black/70 hover:text-imperial-red transition-colors uppercase tracking-widest" href="#upload">Upload</a>
|
| 165 |
+
<a class="text-[14px] font-medium text-ink-black/70 hover:text-imperial-red transition-colors uppercase tracking-widest" href="#results">Models</a>
|
| 166 |
+
<a class="text-[14px] font-medium text-ink-black/70 hover:text-imperial-red transition-colors uppercase tracking-widest" href="#results">Results</a>
|
| 167 |
+
</nav>
|
| 168 |
+
<div class="flex items-center gap-4">
|
| 169 |
+
<button class="hidden md:flex h-9 items-center justify-center rounded-sm border border-imperial-red bg-transparent px-5 text-[13px] font-bold text-imperial-red transition-all hover:bg-imperial-red hover:text-paper-white uppercase tracking-wider">
|
| 170 |
+
X2 Vision
|
| 171 |
+
</button>
|
| 172 |
+
</div>
|
| 173 |
+
</div>
|
| 174 |
+
</header>
|
| 175 |
+
|
| 176 |
+
<main class="flex flex-1 flex-col items-center pt-12 pb-24 px-4 sm:px-6">
|
| 177 |
+
<div class="flex flex-col items-center text-center max-w-4xl mx-auto mb-14">
|
| 178 |
+
<div class="mb-4 flex items-center gap-2">
|
| 179 |
+
<div class="h-[1px] w-12 bg-china-gold"></div>
|
| 180 |
+
<span class="text-china-gold font-display text-sm tracking-[0.2em] uppercase">6-model comparison</span>
|
| 181 |
+
<div class="h-[1px] w-12 bg-china-gold"></div>
|
| 182 |
+
</div>
|
| 183 |
+
<h1 class="text-imperial-red text-[42px] md:text-[64px] font-display font-bold leading-[1.1] tracking-tight mb-6 drop-shadow-sm">
|
| 184 |
+
Medical<br/>
|
| 185 |
+
|
| 186 |
+
<span class="whitespace-nowrap">Visual Question Answering</span>
|
| 187 |
+
</h1>
|
| 188 |
+
<p class="text-ink-black/70 text-[18px] md:text-[20px] font-light leading-relaxed max-w-3xl font-serif italic">
|
| 189 |
+
|
| 190 |
+
</p>
|
| 191 |
+
<div class="mt-8 scene-3d relative w-full max-w-[760px]">
|
| 192 |
+
<div class="absolute inset-0 rounded-full bg-imperial-red/10 blur-3xl pulse-ring"></div>
|
| 193 |
+
<div class="relative mx-auto flex items-center justify-center gap-5 md:gap-8">
|
| 194 |
+
<div class="tilt-card float-slow depth-line medical-glow rounded-full border border-china-gold/35 bg-paper-white/95 px-5 py-4 flex items-center gap-3">
|
| 195 |
+
<span class="material-symbols-outlined text-[30px] text-imperial-red">medical_services</span>
|
| 196 |
+
<div class="text-left">
|
| 197 |
+
<div class="text-[11px] uppercase tracking-[0.22em] text-china-gold font-bold">Clinical</div>
|
| 198 |
+
<div class="font-display font-bold text-ink-black">Assist</div>
|
| 199 |
+
</div>
|
| 200 |
+
</div>
|
| 201 |
+
<div class="tilt-card float-med depth-line medical-glow rounded-full border border-china-gold/35 bg-paper-white/95 px-5 py-4 flex items-center gap-3">
|
| 202 |
+
<span class="material-symbols-outlined text-[30px] text-imperial-red spin-slow">monitor_heart</span>
|
| 203 |
+
<div class="text-left">
|
| 204 |
+
<div class="text-[11px] uppercase tracking-[0.22em] text-china-gold font-bold">Vitals</div>
|
| 205 |
+
<div class="font-display font-bold text-ink-black">Heartbeat</div>
|
| 206 |
+
</div>
|
| 207 |
+
</div>
|
| 208 |
+
<div class="tilt-card float-fast depth-line medical-glow rounded-full border border-china-gold/35 bg-paper-white/95 px-5 py-4 flex items-center gap-3">
|
| 209 |
+
<span class="material-symbols-outlined text-[30px] text-imperial-red">biotech</span>
|
| 210 |
+
<div class="text-left">
|
| 211 |
+
<div class="text-[11px] uppercase tracking-[0.22em] text-china-gold font-bold">Imaging</div>
|
| 212 |
+
<div class="font-display font-bold text-ink-black">Analyzer</div>
|
| 213 |
+
</div>
|
| 214 |
+
</div>
|
| 215 |
+
</div>
|
| 216 |
+
</div>
|
| 217 |
+
</div>
|
| 218 |
+
|
| 219 |
+
<div id="upload" class="w-full max-w-[1280px] bg-paper-white rounded-none shadow-gold-glow ornate-border flex flex-col lg:flex-row relative">
|
| 220 |
+
<div class="absolute -top-2 -left-2 size-8 border-t-4 border-l-4 border-imperial-red z-10"></div>
|
| 221 |
+
<div class="absolute -top-2 -right-2 size-8 border-t-4 border-r-4 border-imperial-red z-10"></div>
|
| 222 |
+
<div class="absolute -bottom-2 -left-2 size-8 border-b-4 border-l-4 border-imperial-red z-10"></div>
|
| 223 |
+
<div class="absolute -bottom-2 -right-2 size-8 border-b-4 border-r-4 border-imperial-red z-10"></div>
|
| 224 |
+
|
| 225 |
+
<div class="w-full lg:w-[42%] p-8 md:p-12 flex flex-col border-b lg:border-b-0 lg:border-r border-china-gold/30 bg-[url('https://www.transparenttextures.com/patterns/rice-paper-2.png')]">
|
| 226 |
+
<div class="flex items-center justify-between mb-6">
|
| 227 |
+
<h3 class="text-[20px] font-display font-bold text-ink-black border-l-4 border-imperial-red pl-3">Source Scroll</h3>
|
| 228 |
+
<button id="reset-btn" class="text-imperial-red text-sm font-medium hover:text-deep-crimson flex items-center gap-1 transition-colors">
|
| 229 |
+
<span class="material-symbols-outlined text-[18px]">restart_alt</span>
|
| 230 |
+
Reset
|
| 231 |
+
</button>
|
| 232 |
+
</div>
|
| 233 |
+
|
| 234 |
+
<div id="dropzone" class="relative group w-full aspect-square md:aspect-[4/3] bg-[#F2EFE9] border-2 border-dashed border-china-gold/60 flex items-center justify-center transition-all hover:border-imperial-red hover:bg-white cursor-pointer shadow-inner overflow-hidden">
|
| 235 |
+
<div class="absolute inset-2 border border-china-gold/20 pointer-events-none"></div>
|
| 236 |
+
<div id="dropzone-empty" class="flex flex-col items-center gap-4 z-10 p-6 text-center">
|
| 237 |
+
<div class="size-16 rounded-full bg-imperial-red/5 flex items-center justify-center text-imperial-red mb-2">
|
| 238 |
+
<span class="material-symbols-outlined text-4xl">cloud_upload</span>
|
| 239 |
+
</div>
|
| 240 |
+
<div class="space-y-2">
|
| 241 |
+
<p class="text-ink-black font-display font-semibold text-lg">Upload Image</p>
|
| 242 |
+
<p class="text-ink-black/50 text-sm font-serif italic">JPG, PNG, WEBP accepted</p>
|
| 243 |
+
</div>
|
| 244 |
+
</div>
|
| 245 |
+
<img id="preview" class="absolute inset-0 h-full w-full object-contain bg-white hidden" alt="Preview"/>
|
| 246 |
+
<input id="image-input" aria-label="Upload Image" class="absolute inset-0 opacity-0 cursor-pointer" type="file" accept="image/*"/>
|
| 247 |
+
</div>
|
| 248 |
+
|
| 249 |
+
</div>
|
| 250 |
+
|
| 251 |
+
<div class="w-full lg:w-[58%] p-8 md:p-12 flex flex-col bg-paper-white bg-[url('https://www.transparenttextures.com/patterns/rice-paper-2.png')]">
|
| 252 |
+
<div class="mb-6">
|
| 253 |
+
<h3 class="text-[20px] font-display font-bold text-ink-black border-l-4 border-imperial-red pl-3 mb-2">Inquiry</h3>
|
| 254 |
+
<p class="text-ink-black/60 text-sm italic font-serif">Ask one question and compare every model response in parallel.</p>
|
| 255 |
+
</div>
|
| 256 |
+
|
| 257 |
+
<div class="flex-1 flex flex-col gap-6">
|
| 258 |
+
<label class="relative flex-1">
|
| 259 |
+
<textarea id="question" class="w-full h-40 md:h-full resize-none border border-china-gold/40 bg-[#F9F7F2] p-6 text-[18px] text-ink-black placeholder:text-ink-black/30 focus:border-imperial-red focus:ring-1 focus:ring-imperial-red focus:outline-none transition-shadow font-serif leading-relaxed" placeholder="What abnormality is visible in the image? / Có bất thường gì không?"></textarea>
|
| 260 |
+
<div class="absolute top-0 right-0 p-2 opacity-10 pointer-events-none">
|
| 261 |
+
<span class="material-symbols-outlined text-6xl text-imperial-red">edit_note</span>
|
| 262 |
+
</div>
|
| 263 |
+
<div class="absolute bottom-3 right-3 text-xs text-china-gold font-display" id="char-count">0/200 Characters</div>
|
| 264 |
+
</label>
|
| 265 |
+
|
| 266 |
+
<div class="flex flex-wrap items-center gap-2 pt-1">
|
| 267 |
+
<span class="text-[12px] md:text-[13px] uppercase tracking-[0.24em] text-china-gold font-bold mr-1">Gợi ý:</span>
|
| 268 |
+
<div id="suggestions-row" class="flex flex-wrap gap-2"></div>
|
| 269 |
+
</div>
|
| 270 |
+
|
| 271 |
+
<div class="space-y-5 pt-2">
|
| 272 |
+
<div class="flex items-center gap-3">
|
| 273 |
+
<span class="text-xs font-bold uppercase tracking-widest text-china-gold">Model set:</span>
|
| 274 |
+
<div class="flex gap-2 overflow-x-auto pb-1 no-scrollbar">
|
| 275 |
+
<button type="button" class="model-chip whitespace-nowrap border border-china-gold/30 bg-white px-4 py-1.5 text-xs font-medium text-ink-black hover:border-imperial-red hover:text-imperial-red transition-colors font-serif" data-model="A1">A1</button>
|
| 276 |
+
<button type="button" class="model-chip whitespace-nowrap border border-china-gold/30 bg-white px-4 py-1.5 text-xs font-medium text-ink-black hover:border-imperial-red hover:text-imperial-red transition-colors font-serif" data-model="A2">A2</button>
|
| 277 |
+
<button type="button" class="model-chip whitespace-nowrap border border-china-gold/30 bg-white px-4 py-1.5 text-xs font-medium text-ink-black hover:border-imperial-red hover:text-imperial-red transition-colors font-serif" data-model="B1">B1</button>
|
| 278 |
+
<button type="button" class="model-chip whitespace-nowrap border border-china-gold/30 bg-white px-4 py-1.5 text-xs font-medium text-ink-black hover:border-imperial-red hover:text-imperial-red transition-colors font-serif" data-model="B2">B2</button>
|
| 279 |
+
<button type="button" class="model-chip whitespace-nowrap border border-china-gold/30 bg-white px-4 py-1.5 text-xs font-medium text-ink-black hover:border-imperial-red hover:text-imperial-red transition-colors font-serif" data-model="DPO">DPO</button>
|
| 280 |
+
<button type="button" class="model-chip whitespace-nowrap border border-china-gold/30 bg-white px-4 py-1.5 text-xs font-medium text-ink-black hover:border-imperial-red hover:text-imperial-red transition-colors font-serif" data-model="PPO">PPO</button>
|
| 281 |
+
</div>
|
| 282 |
+
</div>
|
| 283 |
+
|
| 284 |
+
<button id="run-btn" class="group w-full bg-gradient-to-r from-imperial-red to-deep-crimson hover:from-red-700 hover:to-red-900 py-4 px-6 text-[18px] font-bold text-gold-light shadow-red-glow transition-all active:scale-[0.99] flex items-center justify-center gap-3 border border-china-gold relative overflow-hidden">
|
| 285 |
+
<div class="absolute inset-0 bg-[url('https://www.transparenttextures.com/patterns/black-scales.png')] opacity-10"></div>
|
| 286 |
+
<span class="relative z-10 font-display tracking-widest uppercase">Run Comparison</span>
|
| 287 |
+
<span class="material-symbols-outlined text-[24px] relative z-10 group-hover:rotate-12 transition-transform text-gold-light">savings</span>
|
| 288 |
+
<span class="material-symbols-outlined absolute right-6 text-[28px] opacity-20 group-hover:opacity-40 transition-opacity text-gold-light">chess_knight</span>
|
| 289 |
+
</button>
|
| 290 |
+
|
| 291 |
+
<div class="text-center text-sm font-serif italic text-ink-black/60" id="status-text">Select an image, enter a question, then run all six models.</div>
|
| 292 |
+
</div>
|
| 293 |
+
</div>
|
| 294 |
+
</div>
|
| 295 |
+
</div>
|
| 296 |
+
|
| 297 |
+
<section id="results" class="mt-20 w-full max-w-[1280px]">
|
| 298 |
+
<div class="flex flex-col items-center text-center mb-8">
|
| 299 |
+
<div class="h-[1px] w-24 bg-china-gold"></div>
|
| 300 |
+
<h2 class="mt-4 text-imperial-red text-[30px] md:text-[40px] font-display font-bold tracking-tight">Outputs</h2>
|
| 301 |
+
<p class="mt-3 max-w-3xl text-ink-black/65 italic font-serif">
|
| 302 |
+
Six models, six output cards, one result per card.
|
| 303 |
+
</p>
|
| 304 |
+
</div>
|
| 305 |
+
|
| 306 |
+
<div id="results-grid" class="grid grid-cols-1 md:grid-cols-2 xl:grid-cols-3 gap-6"></div>
|
| 307 |
+
</section>
|
| 308 |
+
|
| 309 |
+
<div class="mt-24 grid grid-cols-1 md:grid-cols-3 gap-8 w-full max-w-[1280px] relative">
|
| 310 |
+
<div class="absolute -top-12 left-1/2 -translate-x-1/2 w-24 h-1 bg-china-gold rounded-full"></div>
|
| 311 |
+
<div class="flex flex-col gap-4 p-8 bg-paper-white border border-china-gold/20 shadow-sm hover:shadow-gold-glow transition-shadow duration-300 relative overflow-hidden group">
|
| 312 |
+
<div class="absolute top-0 left-0 w-full h-1 bg-gradient-to-r from-transparent via-imperial-red to-transparent opacity-0 group-hover:opacity-100 transition-opacity"></div>
|
| 313 |
+
<div class="size-12 flex items-center justify-center text-imperial-red mb-2 border border-china-gold/30 rounded-full bg-gold-light/20">
|
| 314 |
+
<span class="material-symbols-outlined text-2xl">neurology</span>
|
| 315 |
+
</div>
|
| 316 |
+
<h4 class="text-[18px] font-display font-bold text-ink-black">A1 / A2</h4>
|
| 317 |
+
<p class="text-[15px] leading-relaxed text-ink-black/70 font-serif">
|
| 318 |
+
Closed-vocab models with separate answer heads. The new UI gives each model a dedicated response card.
|
| 319 |
+
</p>
|
| 320 |
+
</div>
|
| 321 |
+
<div class="flex flex-col gap-4 p-8 bg-paper-white border border-china-gold/20 shadow-sm hover:shadow-gold-glow transition-shadow duration-300 relative overflow-hidden group">
|
| 322 |
+
<div class="absolute top-0 left-0 w-full h-1 bg-gradient-to-r from-transparent via-imperial-red to-transparent opacity-0 group-hover:opacity-100 transition-opacity"></div>
|
| 323 |
+
<div class="size-12 flex items-center justify-center text-imperial-red mb-2 border border-china-gold/30 rounded-full bg-gold-light/20">
|
| 324 |
+
<span class="material-symbols-outlined text-2xl">bolt</span>
|
| 325 |
+
</div>
|
| 326 |
+
<h4 class="text-[18px] font-display font-bold text-ink-black">B1 / B2</h4>
|
| 327 |
+
<p class="text-[15px] leading-relaxed text-ink-black/70 font-serif">
|
| 328 |
+
Zero-shot and fine-tuned LLaVA models are compared side by side with latency and raw answer displayed.
|
| 329 |
+
</p>
|
| 330 |
+
</div>
|
| 331 |
+
<div class="flex flex-col gap-4 p-8 bg-paper-white border border-china-gold/20 shadow-sm hover:shadow-gold-glow transition-shadow duration-300 relative overflow-hidden group">
|
| 332 |
+
<div class="absolute top-0 left-0 w-full h-1 bg-gradient-to-r from-transparent via-imperial-red to-transparent opacity-0 group-hover:opacity-100 transition-opacity"></div>
|
| 333 |
+
<div class="size-12 flex items-center justify-center text-imperial-red mb-2 border border-china-gold/30 rounded-full bg-gold-light/20">
|
| 334 |
+
<span class="material-symbols-outlined text-2xl">verified</span>
|
| 335 |
+
</div>
|
| 336 |
+
<h4 class="text-[18px] font-display font-bold text-ink-black">DPO / PPO</h4>
|
| 337 |
+
<p class="text-[15px] leading-relaxed text-ink-black/70 font-serif">
|
| 338 |
+
Alignment and RL variants now have equal room in the grid, making the comparison feel intentional.
|
| 339 |
+
</p>
|
| 340 |
+
</div>
|
| 341 |
+
</div>
|
| 342 |
+
</main>
|
| 343 |
+
|
| 344 |
+
<footer class="w-full border-t-4 border-imperial-red bg-ink-black text-paper-white py-12">
|
| 345 |
+
<div class="mx-auto flex max-w-[1280px] flex-col md:flex-row items-center justify-between gap-8 px-4 md:px-0">
|
| 346 |
+
<div class="flex flex-col gap-2 md:items-start items-center">
|
| 347 |
+
<div class="flex items-center gap-2 mb-2">
|
| 348 |
+
<span class="material-symbols-outlined text-china-gold">chess_knight</span>
|
| 349 |
+
<span class="font-display font-bold text-lg tracking-wider">VQA RESEARCH</span>
|
| 350 |
+
</div>
|
| 351 |
+
<div class="text-[13px] text-paper-white/60 font-serif">
|
| 352 |
+
Medical VQA web demo for six-model comparison.
|
| 353 |
+
</div>
|
| 354 |
+
</div>
|
| 355 |
+
<div class="flex gap-8 text-[13px] text-paper-white/80 font-display tracking-widest uppercase">
|
| 356 |
+
<a class="hover:text-china-gold transition-colors" href="#upload">Upload</a>
|
| 357 |
+
<a class="hover:text-china-gold transition-colors" href="#results">Results</a>
|
| 358 |
+
</div>
|
| 359 |
+
</div>
|
| 360 |
+
</footer>
|
| 361 |
+
</div>
|
| 362 |
+
|
| 363 |
+
<script>
|
| 364 |
+
const MODEL_ORDER = ["A1", "A2", "B1", "B2", "DPO", "PPO"];
|
| 365 |
+
const MODEL_META = {
|
| 366 |
+
A1: { name: "A1", title: "LSTM Baseline", note: "DenseNet-121 + PhoBERT + LSTM", icon: "neurology" },
|
| 367 |
+
A2: { name: "A2", title: "Transformer Decoder", note: "DenseNet-121 + PhoBERT + Transformer", icon: "schema" },
|
| 368 |
+
B1: { name: "B1", title: "Zero-shot", note: "LLaVA-Med base", icon: "visibility" },
|
| 369 |
+
B2: { name: "B2", title: "Fine-tuned", note: "LLaVA-Med + LoRA", icon: "precision_manufacturing" },
|
| 370 |
+
DPO: { name: "DPO", title: "Alignment", note: "B2 + DPO", icon: "verified" },
|
| 371 |
+
PPO: { name: "PPO", title: "RL refinement", note: "B2 + PPO", icon: "syringe" },
|
| 372 |
+
};
|
| 373 |
+
|
| 374 |
+
const el = {
|
| 375 |
+
imageInput: document.getElementById("image-input"),
|
| 376 |
+
preview: document.getElementById("preview"),
|
| 377 |
+
dropzoneEmpty: document.getElementById("dropzone-empty"),
|
| 378 |
+
dropzone: document.getElementById("dropzone"),
|
| 379 |
+
question: document.getElementById("question"),
|
| 380 |
+
charCount: document.getElementById("char-count"),
|
| 381 |
+
suggestionsRow: document.getElementById("suggestions-row"),
|
| 382 |
+
runBtn: document.getElementById("run-btn"),
|
| 383 |
+
resetBtn: document.getElementById("reset-btn"),
|
| 384 |
+
statusText: document.getElementById("status-text"),
|
| 385 |
+
resultsGrid: document.getElementById("results-grid"),
|
| 386 |
+
};
|
| 387 |
+
|
| 388 |
+
let currentImageFile = null;
|
| 389 |
+
let selectedModels = new Set(MODEL_ORDER);
|
| 390 |
+
let questionSuggestions = [];
|
| 391 |
+
|
| 392 |
+
function escapeHtml(value) {
|
| 393 |
+
return String(value ?? "")
|
| 394 |
+
.replaceAll("&", "&")
|
| 395 |
+
.replaceAll("<", "<")
|
| 396 |
+
.replaceAll(">", ">")
|
| 397 |
+
.replaceAll('"', """);
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
function updateCharCount() {
|
| 401 |
+
el.charCount.textContent = `${el.question.value.length}/200 Characters`;
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
function setStatus(message) {
|
| 405 |
+
el.statusText.textContent = message;
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
function setPreview(file) {
|
| 409 |
+
currentImageFile = file || null;
|
| 410 |
+
if (!file) {
|
| 411 |
+
el.preview.classList.add("hidden");
|
| 412 |
+
el.dropzoneEmpty.classList.remove("hidden");
|
| 413 |
+
el.preview.src = "";
|
| 414 |
+
return;
|
| 415 |
+
}
|
| 416 |
+
const url = URL.createObjectURL(file);
|
| 417 |
+
el.preview.src = url;
|
| 418 |
+
el.preview.classList.remove("hidden");
|
| 419 |
+
el.dropzoneEmpty.classList.add("hidden");
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
function fillQuestion(question) {
|
| 423 |
+
el.question.value = question || "";
|
| 424 |
+
updateCharCount();
|
| 425 |
+
el.question.focus();
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
function renderQuestionSuggestions(items) {
|
| 429 |
+
questionSuggestions = items || [];
|
| 430 |
+
if (!questionSuggestions.length) {
|
| 431 |
+
el.suggestionsRow.innerHTML = "";
|
| 432 |
+
return;
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
el.suggestionsRow.innerHTML = questionSuggestions.map((item, index) => {
|
| 436 |
+
const question = escapeHtml(item.question || "");
|
| 437 |
+
return `
|
| 438 |
+
<button type="button" class="hint-chip inline-flex items-center gap-2 rounded-full bg-transparent px-2 py-1 text-left text-[12px] leading-tight text-ink-black/50 hover:bg-imperial-red/5 hover:text-imperial-red transition-all" data-suggestion-index="${index}">
|
| 439 |
+
<span class="size-1.5 rounded-full bg-imperial-red/70"></span>
|
| 440 |
+
<span class="truncate max-w-[280px] font-serif">${question}</span>
|
| 441 |
+
</button>
|
| 442 |
+
`;
|
| 443 |
+
}).join("");
|
| 444 |
+
|
| 445 |
+
el.suggestionsRow.querySelectorAll("[data-suggestion-index]").forEach((button) => {
|
| 446 |
+
button.addEventListener("click", () => {
|
| 447 |
+
const index = Number(button.dataset.suggestionIndex);
|
| 448 |
+
const item = questionSuggestions[index];
|
| 449 |
+
if (!item) return;
|
| 450 |
+
fillQuestion(item.question);
|
| 451 |
+
setStatus(`Loaded suggested question.`);
|
| 452 |
+
});
|
| 453 |
+
});
|
| 454 |
+
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
async function loadQuestionSuggestions() {
|
| 458 |
+
if (questionSuggestions.length) {
|
| 459 |
+
return;
|
| 460 |
+
}
|
| 461 |
+
el.suggestionsRow.innerHTML = "";
|
| 462 |
+
|
| 463 |
+
try {
|
| 464 |
+
const res = await fetch("/v1/question-suggestions?limit=8");
|
| 465 |
+
const data = await res.json();
|
| 466 |
+
renderQuestionSuggestions(data.suggestions || []);
|
| 467 |
+
} catch (err) {
|
| 468 |
+
el.suggestionsRow.innerHTML = "";
|
| 469 |
+
}
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
function renderModelGrid(results) {
|
| 473 |
+
const byVariant = Object.fromEntries(results.map((r) => [r.variant, r]));
|
| 474 |
+
|
| 475 |
+
el.resultsGrid.innerHTML = MODEL_ORDER.map((variant) => {
|
| 476 |
+
const meta = MODEL_META[variant];
|
| 477 |
+
const res = byVariant[variant];
|
| 478 |
+
const status = res ? res.status : "not requested";
|
| 479 |
+
const ok = res && res.status === "ok";
|
| 480 |
+
const answer = res ? (res.prediction || res.status) : "Not requested";
|
| 481 |
+
const cardTone = ok ? "border-emerald-200/70 shadow-[0_18px_40px_rgba(16,185,129,0.10)]" : res ? "border-rose-200/70 shadow-[0_18px_40px_rgba(244,63,94,0.08)]" : "border-china-gold/25 shadow-sm";
|
| 482 |
+
const answerTone = ok ? "text-ink-black" : res ? "text-rose-700" : "text-amber-700";
|
| 483 |
+
return `
|
| 484 |
+
<article class="tilt-card bg-paper-white border ${cardTone} p-5 md:p-6 flex flex-col gap-4 relative overflow-hidden">
|
| 485 |
+
<div class="absolute inset-x-0 top-0 h-1 bg-gradient-to-r from-transparent via-imperial-red to-transparent ${ok ? 'opacity-100' : 'opacity-45'}"></div>
|
| 486 |
+
<div class="flex items-center justify-between gap-4">
|
| 487 |
+
<div class="flex items-center gap-3">
|
| 488 |
+
<div class="size-11 rounded-full border flex items-center justify-center ${ok ? 'bg-emerald-50 text-emerald-700 border-emerald-200' : res ? 'bg-rose-50 text-rose-700 border-rose-200' : 'bg-amber-50 text-amber-700 border-amber-200'} ${ok ? 'pulse-ring' : ''}">
|
| 489 |
+
<span class="material-symbols-outlined text-[22px]">${meta.icon}</span>
|
| 490 |
+
</div>
|
| 491 |
+
<div>
|
| 492 |
+
<div class="text-[11px] uppercase tracking-[0.2em] text-china-gold font-bold">${meta.name}</div>
|
| 493 |
+
<div class="text-[15px] font-display font-bold text-ink-black">${meta.title}</div>
|
| 494 |
+
</div>
|
| 495 |
+
</div>
|
| 496 |
+
<span class="text-[11px] uppercase tracking-[0.18em] font-bold ${ok ? 'text-emerald-700' : res ? 'text-rose-700' : 'text-amber-700'}">
|
| 497 |
+
${res ? (ok ? "Output" : "Error") : "Idle"}
|
| 498 |
+
</span>
|
| 499 |
+
</div>
|
| 500 |
+
|
| 501 |
+
<div class="min-h-[120px] rounded-none border border-china-gold/20 bg-[#FAF7F0] p-5 flex items-center">
|
| 502 |
+
<p class="text-[18px] md:text-[20px] leading-relaxed font-serif ${answerTone}">
|
| 503 |
+
${escapeHtml(answer)}
|
| 504 |
+
</p>
|
| 505 |
+
</div>
|
| 506 |
+
|
| 507 |
+
<div class="flex items-center justify-between text-[12px] text-ink-black/55">
|
| 508 |
+
<span>${escapeHtml(res ? (res.prediction_raw || "") : "")}</span>
|
| 509 |
+
<span>${escapeHtml(status)}</span>
|
| 510 |
+
</div>
|
| 511 |
+
</article>
|
| 512 |
+
`;
|
| 513 |
+
}).join("");
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
function updateModelChips() {
|
| 517 |
+
document.querySelectorAll(".model-chip").forEach((chip) => {
|
| 518 |
+
const variant = chip.dataset.model;
|
| 519 |
+
const active = selectedModels.has(variant);
|
| 520 |
+
chip.style.background = active ? "#A8181B" : "#fff";
|
| 521 |
+
chip.style.color = active ? "#FDFBF7" : "#1A1A1A";
|
| 522 |
+
chip.style.borderColor = active ? "#A8181B" : "rgba(212,175,55,0.35)";
|
| 523 |
+
});
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
function applyTiltEffect(selector, maxRotate = 6) {
|
| 527 |
+
document.querySelectorAll(selector).forEach((card) => {
|
| 528 |
+
if (card.dataset.tiltBound === "1") return;
|
| 529 |
+
card.dataset.tiltBound = "1";
|
| 530 |
+
card.addEventListener("mousemove", (e) => {
|
| 531 |
+
const rect = card.getBoundingClientRect();
|
| 532 |
+
const x = (e.clientX - rect.left) / rect.width;
|
| 533 |
+
const y = (e.clientY - rect.top) / rect.height;
|
| 534 |
+
const rotateY = (x - 0.5) * maxRotate * 2;
|
| 535 |
+
const rotateX = (0.5 - y) * maxRotate * 2;
|
| 536 |
+
card.style.transform = `rotateX(${rotateX}deg) rotateY(${rotateY}deg) translateY(-2px)`;
|
| 537 |
+
});
|
| 538 |
+
card.addEventListener("mouseleave", () => {
|
| 539 |
+
card.style.transform = "";
|
| 540 |
+
});
|
| 541 |
+
});
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
async function loadModels() {
|
| 545 |
+
try {
|
| 546 |
+
const res = await fetch("/v1/models");
|
| 547 |
+
const data = await res.json();
|
| 548 |
+
updateModelChips();
|
| 549 |
+
setStatus("Ready. Upload an image and run all six models.");
|
| 550 |
+
} catch (err) {
|
| 551 |
+
setStatus(`Failed to load model metadata: ${err.message}`);
|
| 552 |
+
}
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
el.imageInput.addEventListener("change", (e) => {
|
| 556 |
+
setPreview(e.target.files?.[0]);
|
| 557 |
+
});
|
| 558 |
+
|
| 559 |
+
el.dropzone.addEventListener("click", () => {
|
| 560 |
+
el.imageInput.click();
|
| 561 |
+
});
|
| 562 |
+
|
| 563 |
+
el.dropzone.addEventListener("dragover", (e) => {
|
| 564 |
+
e.preventDefault();
|
| 565 |
+
el.dropzone.classList.add("ring-2", "ring-imperial-red/30");
|
| 566 |
+
});
|
| 567 |
+
el.dropzone.addEventListener("dragleave", () => {
|
| 568 |
+
el.dropzone.classList.remove("ring-2", "ring-imperial-red/30");
|
| 569 |
+
});
|
| 570 |
+
el.dropzone.addEventListener("drop", (e) => {
|
| 571 |
+
e.preventDefault();
|
| 572 |
+
el.dropzone.classList.remove("ring-2", "ring-imperial-red/30");
|
| 573 |
+
const file = e.dataTransfer.files?.[0];
|
| 574 |
+
if (file) {
|
| 575 |
+
const dt = new DataTransfer();
|
| 576 |
+
dt.items.add(file);
|
| 577 |
+
el.imageInput.files = dt.files;
|
| 578 |
+
setPreview(file);
|
| 579 |
+
}
|
| 580 |
+
});
|
| 581 |
+
|
| 582 |
+
el.question.addEventListener("input", updateCharCount);
|
| 583 |
+
el.question.addEventListener("focus", loadQuestionSuggestions, { once: true });
|
| 584 |
+
|
| 585 |
+
document.querySelectorAll(".model-chip").forEach((chip) => {
|
| 586 |
+
chip.addEventListener("click", () => {
|
| 587 |
+
const variant = chip.dataset.model;
|
| 588 |
+
if (selectedModels.has(variant)) selectedModels.delete(variant);
|
| 589 |
+
else selectedModels.add(variant);
|
| 590 |
+
if (selectedModels.size === 0) {
|
| 591 |
+
selectedModels = new Set(MODEL_ORDER);
|
| 592 |
+
}
|
| 593 |
+
updateModelChips();
|
| 594 |
+
});
|
| 595 |
+
});
|
| 596 |
+
|
| 597 |
+
el.resetBtn.addEventListener("click", () => {
|
| 598 |
+
selectedModels = new Set(MODEL_ORDER);
|
| 599 |
+
el.question.value = "";
|
| 600 |
+
el.imageInput.value = "";
|
| 601 |
+
setPreview(null);
|
| 602 |
+
updateCharCount();
|
| 603 |
+
updateModelChips();
|
| 604 |
+
el.resultsGrid.innerHTML = "";
|
| 605 |
+
setStatus("Reset complete.");
|
| 606 |
+
});
|
| 607 |
+
|
| 608 |
+
el.runBtn.addEventListener("click", async () => {
|
| 609 |
+
if (!currentImageFile) {
|
| 610 |
+
setStatus("Please upload an image first.");
|
| 611 |
+
return;
|
| 612 |
+
}
|
| 613 |
+
if (!el.question.value.trim()) {
|
| 614 |
+
setStatus("Please enter a question.");
|
| 615 |
+
return;
|
| 616 |
+
}
|
| 617 |
+
if (selectedModels.size === 0) {
|
| 618 |
+
setStatus("Please select at least one model.");
|
| 619 |
+
return;
|
| 620 |
+
}
|
| 621 |
+
|
| 622 |
+
el.runBtn.disabled = true;
|
| 623 |
+
el.runBtn.querySelector("span").textContent = "Running...";
|
| 624 |
+
setStatus("Running all selected models...");
|
| 625 |
+
|
| 626 |
+
try {
|
| 627 |
+
const formData = new FormData();
|
| 628 |
+
formData.append("question", el.question.value.trim());
|
| 629 |
+
formData.append("model_names", JSON.stringify(Array.from(selectedModels)));
|
| 630 |
+
formData.append("image", currentImageFile);
|
| 631 |
+
|
| 632 |
+
const res = await fetch("/v1/predict", { method: "POST", body: formData });
|
| 633 |
+
const data = await res.json();
|
| 634 |
+
if (!res.ok) {
|
| 635 |
+
throw new Error(data?.detail || "Prediction failed");
|
| 636 |
+
}
|
| 637 |
+
renderModelGrid(data.results || [], data.question || el.question.value.trim(), data.summary);
|
| 638 |
+
applyTiltEffect(".tilt-card", 5);
|
| 639 |
+
setStatus(`Done. ${data.summary?.success_count ?? 0} models succeeded.`);
|
| 640 |
+
} catch (err) {
|
| 641 |
+
setStatus(err.message || "Prediction failed");
|
| 642 |
+
} finally {
|
| 643 |
+
el.runBtn.disabled = false;
|
| 644 |
+
el.runBtn.querySelector("span").textContent = "Run Comparison";
|
| 645 |
+
}
|
| 646 |
+
});
|
| 647 |
+
|
| 648 |
+
updateCharCount();
|
| 649 |
+
updateModelChips();
|
| 650 |
+
loadModels();
|
| 651 |
+
loadQuestionSuggestions();
|
| 652 |
+
renderModelGrid([], "", null);
|
| 653 |
+
applyTiltEffect(".tilt-card", 5);
|
| 654 |
+
</script>
|
| 655 |
+
|
| 656 |
+
</body></html>
|