Spaces:
Paused
Paused
| import os | |
| from dataclasses import dataclass | |
| import torch | |
| from src.utils.text_utils import postprocess_answer | |
| def _as_bool(value: object, default: bool = False) -> bool: | |
| if value is None: | |
| return default | |
| if isinstance(value, bool): | |
| return value | |
| return str(value).strip().lower() in {"1", "true", "yes", "y", "on"} | |
| class RewriteConfig: | |
| enabled: bool = False | |
| model_id: str = "" | |
| use_4bit: bool = True | |
| max_new_tokens: int = 28 | |
| max_words: int = 10 | |
| class MedicalAnswerRewriter: | |
| """ | |
| Rewrite lớp cuối cho VQA output. | |
| Mục tiêu: | |
| - Giữ nguyên ý nghĩa gốc. | |
| - Làm câu trả lời tự nhiên và đầy đủ hơn một chút. | |
| - Vẫn giới hạn tối đa số từ theo cấu hình. | |
| Mô hình này không thay thế VQA model chính. | |
| """ | |
| def __init__(self, config: RewriteConfig | None = None) -> None: | |
| self.config = config or self._load_config() | |
| self._load_attempted = False | |
| self._ready = False | |
| self._tokenizer = None | |
| self._model = None | |
| self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| def _load_config() -> RewriteConfig: | |
| model_id = ( | |
| os.getenv("ANSWER_REWRITE_MODEL_ID", "").strip() | |
| or os.getenv("QWEN_REWRITE_MODEL_ID", "").strip() | |
| or "Qwen/Qwen2.5-14B-Instruct" | |
| ) | |
| enabled = _as_bool(os.getenv("ANSWER_REWRITE_ENABLED"), default=True) | |
| use_4bit = _as_bool(os.getenv("ANSWER_REWRITE_USE_4BIT"), default=True) | |
| max_new_tokens = int(os.getenv("ANSWER_REWRITE_MAX_NEW_TOKENS", "28")) | |
| max_words = int(os.getenv("ANSWER_REWRITE_MAX_WORDS", "10")) | |
| return RewriteConfig( | |
| enabled=enabled, | |
| model_id=model_id, | |
| use_4bit=use_4bit, | |
| max_new_tokens=max_new_tokens, | |
| max_words=max_words, | |
| ) | |
| def enabled(self) -> bool: | |
| return bool(self.config.enabled and self.config.model_id) | |
| def model_id(self) -> str: | |
| return self.config.model_id | |
| def ready(self) -> bool: | |
| return self._ready | |
| def _lazy_load(self) -> None: | |
| if self._load_attempted: | |
| return | |
| self._load_attempted = True | |
| if not self.enabled: | |
| return | |
| try: | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| hf_token = ( | |
| os.getenv("ANSWER_REWRITE_HF_TOKEN", "").strip() | |
| or os.getenv("HF_TOKEN", "").strip() | |
| or os.getenv("HUGGINGFACE_HUB_TOKEN", "").strip() | |
| or None | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(self.config.model_id, trust_remote_code=True, token=hf_token) | |
| model_kwargs = { | |
| "trust_remote_code": True, | |
| "low_cpu_mem_usage": True, | |
| } | |
| if self._device.type == "cuda": | |
| if self.config.use_4bit: | |
| try: | |
| from transformers import BitsAndBytesConfig | |
| model_kwargs["quantization_config"] = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_use_double_quant=True, | |
| bnb_4bit_quant_type="nf4", | |
| bnb_4bit_compute_dtype=torch.bfloat16, | |
| ) | |
| except Exception as exc: | |
| print(f"[WARNING] Rewrite 4-bit config unavailable, falling back to bf16: {exc}") | |
| model_kwargs["torch_dtype"] = torch.bfloat16 | |
| else: | |
| model_kwargs["torch_dtype"] = torch.bfloat16 | |
| model_kwargs["device_map"] = "auto" | |
| else: | |
| model_kwargs["torch_dtype"] = torch.float32 | |
| if hf_token is not None: | |
| model_kwargs["token"] = hf_token | |
| model = AutoModelForCausalLM.from_pretrained(self.config.model_id, **model_kwargs) | |
| model.eval() | |
| self._tokenizer = tokenizer | |
| self._model = model | |
| self._ready = True | |
| print(f"[INFO] ✅ Answer rewriter ready: {self.config.model_id}") | |
| except Exception as exc: | |
| self._ready = False | |
| print(f"[WARNING] ❌ Answer rewriter load failed: {exc}") | |
| def _build_messages(self, question: str, answer: str, language: str = "vi") -> list[dict[str, str]]: | |
| system_prompt = ( | |
| "Bạn là bộ biên tập câu trả lời cho hệ thống Medical VQA. " | |
| "Nhiệm vụ của bạn là viết lại câu trả lời gốc thành một câu ngắn, tự nhiên, " | |
| "rõ nghĩa hơn nhưng KHÔNG thêm thông tin mới ngoài nội dung đã có. " | |
| "Giới hạn tối đa 10 từ. Chỉ trả về câu trả lời cuối cùng." | |
| ) | |
| if language.lower().startswith("en"): | |
| system_prompt = ( | |
| "You are an editor for a Medical VQA system. " | |
| "Rewrite the raw answer into a short, natural, clearer sentence " | |
| "without adding facts beyond the original answer. " | |
| "Use at most 10 words. Return only the final answer." | |
| ) | |
| examples = [ | |
| { | |
| "question": "Ảnh này có tràn dịch màng phổi không?", | |
| "answer": "không", | |
| "rewrite": "Không, không có tràn dịch màng phổi.", | |
| }, | |
| { | |
| "question": "Hình ảnh có tim to không?", | |
| "answer": "có", | |
| "rewrite": "Có, tim to.", | |
| }, | |
| { | |
| "question": "Đây là loại ảnh gì?", | |
| "answer": "x quang ngực", | |
| "rewrite": "X-quang ngực.", | |
| }, | |
| ] | |
| if language.lower().startswith("en"): | |
| examples = [ | |
| { | |
| "question": "Is there pleural effusion?", | |
| "answer": "no", | |
| "rewrite": "No, no pleural effusion.", | |
| }, | |
| { | |
| "question": "Is the heart enlarged?", | |
| "answer": "yes", | |
| "rewrite": "Yes, enlarged heart.", | |
| }, | |
| { | |
| "question": "What modality is this?", | |
| "answer": "chest x ray", | |
| "rewrite": "Chest X-ray.", | |
| }, | |
| ] | |
| messages: list[dict[str, str]] = [{"role": "system", "content": system_prompt}] | |
| for ex in examples: | |
| messages.append( | |
| { | |
| "role": "user", | |
| "content": f"Câu hỏi: {ex['question']}\nĐáp án gốc: {ex['answer']}", | |
| } | |
| ) | |
| messages.append({"role": "assistant", "content": ex["rewrite"]}) | |
| user_prompt = f"Câu hỏi: {question}\nĐáp án gốc: {answer}\nViết lại ngắn gọn, tự nhiên, không thêm thông tin mới." | |
| if language.lower().startswith("en"): | |
| user_prompt = ( | |
| f"Question: {question}\nRaw answer: {answer}\n" | |
| "Rewrite it into a short, natural answer without adding new facts." | |
| ) | |
| messages.append({"role": "user", "content": user_prompt}) | |
| return messages | |
| def rewrite(self, question: str, answer: str, language: str = "vi") -> str: | |
| """ | |
| Rewrite câu trả lời để tự nhiên hơn. | |
| Nếu rewrite model không sẵn sàng, trả về output đã postprocess. | |
| """ | |
| if not answer: | |
| return "" | |
| self._lazy_load() | |
| fallback = postprocess_answer(answer, max_words=self.config.max_words) | |
| if not self.enabled or not self._ready: | |
| return fallback | |
| try: | |
| messages = self._build_messages(question=question, answer=answer, language=language) | |
| prompt = self._tokenizer.apply_chat_template( | |
| messages, | |
| tokenize=False, | |
| add_generation_prompt=True, | |
| ) | |
| inputs = self._tokenizer(prompt, return_tensors="pt", truncation=True) | |
| inputs = {k: v.to(self._device) for k, v in inputs.items()} | |
| with torch.inference_mode(): | |
| output_ids = self._model.generate( | |
| **inputs, | |
| max_new_tokens=self.config.max_new_tokens, | |
| do_sample=False, | |
| temperature=0.1, | |
| repetition_penalty=1.05, | |
| pad_token_id=self._tokenizer.eos_token_id, | |
| ) | |
| prompt_len = inputs["input_ids"].shape[1] | |
| generated = self._tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip() | |
| cleaned = postprocess_answer(generated, max_words=self.config.max_words) | |
| return cleaned or fallback | |
| except Exception as exc: | |
| print(f"[WARNING] Rewrite failed: {exc}") | |
| return fallback | |