Medical-VQA / src /utils /text_utils.py
SpringWang08's picture
Deploy Medical VQA app
d63774a
raw
history blame
5.96 kB
import re
from collections import Counter
from underthesea import text_normalize as uts_text_normalize, word_tokenize
_MEDICAL_TERM_MAP = {
"xray": "x-quang",
"x ray": "x-quang",
"x-ray": "x-quang",
"x quang": "x-quang",
"mri scan": "mri",
"mr": "mri",
"ct scan": "ct",
"ct-scan": "ct",
"cat scan": "ct",
"computed tomography": "ct",
"transverse plane": "mặt phẳng ngang",
"transverse plane": "mặt phẳng ngang",
"coronal plane": "mặt phẳng vành",
"sagittal plane": "mặt phẳng dọc",
"elliptical": "hình elip",
"spleen": "lách",
"liver": "gan",
"lung": "phổi",
"lungs": "phổi",
"heart": "tim",
"brain": "não",
"kidney": "thận",
"bladder": "bàng quang",
"cardiomegaly": "tim to",
}
_NON_CANONICAL_ALIASES = {
"xray",
"x ray",
"x-ray",
"x quang",
"mri scan",
"mr",
"ct scan",
"ct-scan",
"cat scan",
"computed tomography",
"transverse plane",
"coronal plane",
"sagittal plane",
"elliptical",
"spleen",
"liver",
"lung",
"lungs",
"heart",
"brain",
"kidney",
"bladder",
"cardiomegaly",
}
def text_normalize(text: str) -> str:
"""Wrapper để chuẩn hóa Unicode và spacing cho tiếng Việt."""
if not text:
return ""
return uts_text_normalize(str(text))
def normalize_answer(text: str) -> str:
"""
Chuẩn hóa đáp án về dạng canonical để train/eval ổn định.
"""
if not text:
return ""
text = text_normalize(str(text))
text = text.replace("_", " ")
text = text.lower().strip()
text = re.sub(r"[@#]{1,2}", " ", text)
text = re.sub(r"[“”\"']", "", text)
text = re.sub(r"[,:;!?()\[\]{}]+", " ", text)
text = re.sub(r"\s+", " ", text).strip()
for src, dst in sorted(_MEDICAL_TERM_MAP.items(), key=lambda item: -len(item[0])):
text = re.sub(rf"\b{re.escape(src)}\b", dst, text)
text = re.sub(r"\s+", " ", text).strip()
text = re.sub(r"[.\-]+$", "", text).strip()
return text
def _tokenize_vietnamese_words(text: str) -> list[str]:
normalized = normalize_answer(text)
if not normalized:
return []
try:
tokens = word_tokenize(normalized)
return [token.strip() for token in tokens if token and token.strip()]
except Exception:
return normalized.split()
def count_words(text: str) -> int:
return len(_tokenize_vietnamese_words(text))
def _trim_to_max_words(text: str, max_words: int) -> str:
words = _tokenize_vietnamese_words(text)
if len(words) <= max_words:
return " ".join(words)
return " ".join(words[:max_words])
def _choose_best_answer_text(answer_vi: str, answer_full_vi: str, max_words: int) -> str:
short_answer = normalize_answer(answer_vi)
full_answer = normalize_answer(answer_full_vi)
if short_answer and count_words(short_answer) <= max_words:
return short_answer
if full_answer:
return _trim_to_max_words(full_answer, max_words)
return _trim_to_max_words(short_answer, max_words)
def get_target_answer(item: dict, max_words: int = 10) -> str:
"""
Chọn target answer ngắn, chuẩn hóa và không vượt quá số từ cho phép.
"""
answer_vi = item.get("answer_vi", "")
answer_full_vi = item.get("answer_full_vi", "")
answer = _choose_best_answer_text(answer_vi, answer_full_vi, max_words=max_words)
if answer:
return answer
fallback = item.get("answer", "")
return _trim_to_max_words(fallback, max_words)
def postprocess_answer(text: str, max_words: int = 10) -> str:
"""
Chuẩn hóa output model và cắt ngắn về tối đa `max_words`.
Không mở rộng câu trả lời để tránh làm xấu exact match.
"""
if not text:
return ""
text = clean_vqa_output(text)
text = normalize_answer(text)
return _trim_to_max_words(text, max_words=max_words)
def is_medical_term_compliant(text: str) -> bool:
"""
Heuristic nhẹ: không còn alias y khoa phổ biến chưa canonicalize.
"""
normalized = normalize_answer(text)
if not normalized:
return False
for alias in _NON_CANONICAL_ALIASES:
if re.search(rf"\b{re.escape(alias)}\b", normalized):
return False
return True
def majority_answer(answers: list[str]) -> str:
"""
Trả về câu trả lời xuất hiện nhiều nhất trong danh sách.
"""
if not answers:
return ""
if isinstance(answers, str):
return normalize_answer(answers)
counts = Counter([normalize_answer(a) for a in answers])
return counts.most_common(1)[0][0]
def clean_vqa_output(text: str) -> str:
"""
Làm sạch output từ tokenizer trước khi postprocess.
"""
if not text:
return ""
text = re.sub(r"@@\s?", "", text)
text = re.sub(r"##_?", "", text)
text = re.sub(r"^\s*yes\s*,?\s*", "có ", text, flags=re.IGNORECASE)
text = re.sub(r"^\s*no\s*,?\s*", "không ", text, flags=re.IGNORECASE)
text = re.sub(
r"^\s*(the answer is|the image is|this image is|the scan is|the ct is|the mri is|there is|there are)\s+",
"",
text,
flags=re.IGNORECASE,
)
text = re.sub(
r"^(có|không)\s+(the\s+)?(image|scan|x-ray|xray|mri|ct|picture|photo|radiograph)\s+(is|shows?|depicts?|demonstrates?|reveals?|indicates?|presents?)\s+",
r"\1 ",
text,
flags=re.IGNORECASE,
)
text = re.sub(
r"^(the\s+)?(image|scan|x-ray|xray|mri|ct|picture|photo|radiograph)\s+(is|shows?|depicts?|demonstrates?|reveals?|indicates?|presents?)\s+",
"",
text,
flags=re.IGNORECASE,
)
text = re.sub(r"\b(answer|response|assistant|trả lời)\b\s*:?\s*$", "", text, flags=re.IGNORECASE)
text = re.sub(r"\s+", " ", text).strip()
return text