Medical-VQA / src /utils /answer_rewriter.py
SpringWang08's picture
Deploy Medical VQA app
d63774a
raw
history blame
9.19 kB
import os
from dataclasses import dataclass
import torch
from src.utils.text_utils import postprocess_answer
def _as_bool(value: object, default: bool = False) -> bool:
if value is None:
return default
if isinstance(value, bool):
return value
return str(value).strip().lower() in {"1", "true", "yes", "y", "on"}
@dataclass
class RewriteConfig:
enabled: bool = False
model_id: str = ""
use_4bit: bool = True
max_new_tokens: int = 28
max_words: int = 10
class MedicalAnswerRewriter:
"""
Rewrite lớp cuối cho VQA output.
Mục tiêu:
- Giữ nguyên ý nghĩa gốc.
- Làm câu trả lời tự nhiên và đầy đủ hơn một chút.
- Vẫn giới hạn tối đa số từ theo cấu hình.
Mô hình này không thay thế VQA model chính.
"""
def __init__(self, config: RewriteConfig | None = None) -> None:
self.config = config or self._load_config()
self._load_attempted = False
self._ready = False
self._tokenizer = None
self._model = None
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@staticmethod
def _load_config() -> RewriteConfig:
model_id = (
os.getenv("ANSWER_REWRITE_MODEL_ID", "").strip()
or os.getenv("QWEN_REWRITE_MODEL_ID", "").strip()
or "Qwen/Qwen2.5-14B-Instruct"
)
enabled = _as_bool(os.getenv("ANSWER_REWRITE_ENABLED"), default=True)
use_4bit = _as_bool(os.getenv("ANSWER_REWRITE_USE_4BIT"), default=True)
max_new_tokens = int(os.getenv("ANSWER_REWRITE_MAX_NEW_TOKENS", "28"))
max_words = int(os.getenv("ANSWER_REWRITE_MAX_WORDS", "10"))
return RewriteConfig(
enabled=enabled,
model_id=model_id,
use_4bit=use_4bit,
max_new_tokens=max_new_tokens,
max_words=max_words,
)
@property
def enabled(self) -> bool:
return bool(self.config.enabled and self.config.model_id)
@property
def model_id(self) -> str:
return self.config.model_id
@property
def ready(self) -> bool:
return self._ready
def _lazy_load(self) -> None:
if self._load_attempted:
return
self._load_attempted = True
if not self.enabled:
return
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
hf_token = (
os.getenv("ANSWER_REWRITE_HF_TOKEN", "").strip()
or os.getenv("HF_TOKEN", "").strip()
or os.getenv("HUGGINGFACE_HUB_TOKEN", "").strip()
or None
)
tokenizer = AutoTokenizer.from_pretrained(self.config.model_id, trust_remote_code=True, token=hf_token)
model_kwargs = {
"trust_remote_code": True,
"low_cpu_mem_usage": True,
}
if self._device.type == "cuda":
if self.config.use_4bit:
try:
from transformers import BitsAndBytesConfig
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
except Exception as exc:
print(f"[WARNING] Rewrite 4-bit config unavailable, falling back to bf16: {exc}")
model_kwargs["torch_dtype"] = torch.bfloat16
else:
model_kwargs["torch_dtype"] = torch.bfloat16
model_kwargs["device_map"] = "auto"
else:
model_kwargs["torch_dtype"] = torch.float32
if hf_token is not None:
model_kwargs["token"] = hf_token
model = AutoModelForCausalLM.from_pretrained(self.config.model_id, **model_kwargs)
model.eval()
self._tokenizer = tokenizer
self._model = model
self._ready = True
print(f"[INFO] ✅ Answer rewriter ready: {self.config.model_id}")
except Exception as exc:
self._ready = False
print(f"[WARNING] ❌ Answer rewriter load failed: {exc}")
def _build_messages(self, question: str, answer: str, language: str = "vi") -> list[dict[str, str]]:
system_prompt = (
"Bạn là bộ biên tập câu trả lời cho hệ thống Medical VQA. "
"Nhiệm vụ của bạn là viết lại câu trả lời gốc thành một câu ngắn, tự nhiên, "
"rõ nghĩa hơn nhưng KHÔNG thêm thông tin mới ngoài nội dung đã có. "
"Giới hạn tối đa 10 từ. Chỉ trả về câu trả lời cuối cùng."
)
if language.lower().startswith("en"):
system_prompt = (
"You are an editor for a Medical VQA system. "
"Rewrite the raw answer into a short, natural, clearer sentence "
"without adding facts beyond the original answer. "
"Use at most 10 words. Return only the final answer."
)
examples = [
{
"question": "Ảnh này có tràn dịch màng phổi không?",
"answer": "không",
"rewrite": "Không, không có tràn dịch màng phổi.",
},
{
"question": "Hình ảnh có tim to không?",
"answer": "có",
"rewrite": "Có, tim to.",
},
{
"question": "Đây là loại ảnh gì?",
"answer": "x quang ngực",
"rewrite": "X-quang ngực.",
},
]
if language.lower().startswith("en"):
examples = [
{
"question": "Is there pleural effusion?",
"answer": "no",
"rewrite": "No, no pleural effusion.",
},
{
"question": "Is the heart enlarged?",
"answer": "yes",
"rewrite": "Yes, enlarged heart.",
},
{
"question": "What modality is this?",
"answer": "chest x ray",
"rewrite": "Chest X-ray.",
},
]
messages: list[dict[str, str]] = [{"role": "system", "content": system_prompt}]
for ex in examples:
messages.append(
{
"role": "user",
"content": f"Câu hỏi: {ex['question']}\nĐáp án gốc: {ex['answer']}",
}
)
messages.append({"role": "assistant", "content": ex["rewrite"]})
user_prompt = f"Câu hỏi: {question}\nĐáp án gốc: {answer}\nViết lại ngắn gọn, tự nhiên, không thêm thông tin mới."
if language.lower().startswith("en"):
user_prompt = (
f"Question: {question}\nRaw answer: {answer}\n"
"Rewrite it into a short, natural answer without adding new facts."
)
messages.append({"role": "user", "content": user_prompt})
return messages
def rewrite(self, question: str, answer: str, language: str = "vi") -> str:
"""
Rewrite câu trả lời để tự nhiên hơn.
Nếu rewrite model không sẵn sàng, trả về output đã postprocess.
"""
if not answer:
return ""
self._lazy_load()
fallback = postprocess_answer(answer, max_words=self.config.max_words)
if not self.enabled or not self._ready:
return fallback
try:
messages = self._build_messages(question=question, answer=answer, language=language)
prompt = self._tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
inputs = self._tokenizer(prompt, return_tensors="pt", truncation=True)
inputs = {k: v.to(self._device) for k, v in inputs.items()}
with torch.inference_mode():
output_ids = self._model.generate(
**inputs,
max_new_tokens=self.config.max_new_tokens,
do_sample=False,
temperature=0.1,
repetition_penalty=1.05,
pad_token_id=self._tokenizer.eos_token_id,
)
prompt_len = inputs["input_ids"].shape[1]
generated = self._tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()
cleaned = postprocess_answer(generated, max_words=self.config.max_words)
return cleaned or fallback
except Exception as exc:
print(f"[WARNING] Rewrite failed: {exc}")
return fallback