import os from dataclasses import dataclass import torch from src.utils.text_utils import postprocess_answer def _as_bool(value: object, default: bool = False) -> bool: if value is None: return default if isinstance(value, bool): return value return str(value).strip().lower() in {"1", "true", "yes", "y", "on"} @dataclass class RewriteConfig: enabled: bool = False model_id: str = "" use_4bit: bool = True max_new_tokens: int = 28 max_words: int = 10 _REWRITE_STYLE_BY_MODEL = { "A1": { "vi": "Diễn đạt đơn giản, trực tiếp, gần với đáp án gốc.", "en": "Use simple, direct wording close to the raw answer.", }, "A2": { "vi": "Diễn đạt như một quan sát ngắn trên hình ảnh.", "en": "Word it as a short imaging observation.", }, "B1": { "vi": "Diễn đạt tự nhiên, mềm hơn, dễ đọc.", "en": "Use natural, softer, easy-to-read wording.", }, "B2": { "vi": "Diễn đạt hay hơn A1/A2, theo phong cách lâm sàng súc tích.", "en": "Use stronger concise clinical wording than A1/A2.", }, "DPO": { "vi": "Diễn đạt hay nhất theo hướng thận trọng, chuyên nghiệp.", "en": "Use the most careful, professional wording.", }, "PPO": { "vi": "Diễn đạt hay nhất theo hướng rõ ràng, mạch lạc.", "en": "Use the clearest, most polished wording.", }, } _MODEL_SPECIFIC_EXAMPLES = { "A1": { "vi": { "question": "Ảnh có khối u không?", "answer": "có", "rewrite": "Có, có khối u.", }, "en": { "question": "Is there a mass?", "answer": "yes", "rewrite": "Yes, there is a mass.", }, }, "A2": { "vi": { "question": "Ảnh có khối u không?", "answer": "có", "rewrite": "Có, thấy khối u trên ảnh.", }, "en": { "question": "Is there a mass?", "answer": "yes", "rewrite": "Yes, a mass is seen.", }, }, "B2": { "vi": { "question": "Ảnh có khối u không?", "answer": "có", "rewrite": "Có, hình ảnh gợi ý khối u.", }, "en": { "question": "Is there a mass?", "answer": "yes", "rewrite": "Yes, imaging suggests a mass.", }, }, "DPO": { "vi": { "question": "Ảnh có khối u không?", "answer": "có", "rewrite": "Có, có dấu hiệu gợi ý khối u.", }, "en": { "question": "Is there a mass?", "answer": "yes", "rewrite": "Yes, findings suggest a mass.", }, }, "PPO": { "vi": { "question": "Ảnh có khối u không?", "answer": "có", "rewrite": "Có, kết quả gợi ý khối u rõ.", }, "en": { "question": "Is there a mass?", "answer": "yes", "rewrite": "Yes, results clearly suggest a mass.", }, }, } class MedicalAnswerRewriter: """ Rewrite lớp cuối cho VQA output. Mục tiêu: - Giữ nguyên ý nghĩa gốc. - Làm câu trả lời tự nhiên và đầy đủ hơn một chút. - Vẫn giới hạn tối đa số từ theo cấu hình. Mô hình này không thay thế VQA model chính. """ def __init__(self, config: RewriteConfig | None = None) -> None: self.config = config or self._load_config() self._load_attempted = False self._ready = False self._tokenizer = None self._model = None self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @staticmethod def _load_config() -> RewriteConfig: model_id = ( os.getenv("ANSWER_REWRITE_MODEL_ID", "").strip() or os.getenv("QWEN_REWRITE_MODEL_ID", "").strip() or "Qwen/Qwen2.5-14B-Instruct" ) enabled = _as_bool(os.getenv("ANSWER_REWRITE_ENABLED"), default=True) use_4bit = _as_bool(os.getenv("ANSWER_REWRITE_USE_4BIT"), default=True) max_new_tokens = int(os.getenv("ANSWER_REWRITE_MAX_NEW_TOKENS", "28")) max_words = int(os.getenv("ANSWER_REWRITE_MAX_WORDS", "10")) return RewriteConfig( enabled=enabled, model_id=model_id, use_4bit=use_4bit, max_new_tokens=max_new_tokens, max_words=max_words, ) @property def enabled(self) -> bool: return bool(self.config.enabled and self.config.model_id) @property def model_id(self) -> str: return self.config.model_id @property def ready(self) -> bool: return self._ready def _lazy_load(self) -> None: if self._load_attempted: return self._load_attempted = True if not self.enabled: return try: from transformers import AutoModelForCausalLM, AutoTokenizer hf_token = ( os.getenv("ANSWER_REWRITE_HF_TOKEN", "").strip() or os.getenv("HF_TOKEN", "").strip() or os.getenv("HUGGINGFACE_HUB_TOKEN", "").strip() or None ) tokenizer = AutoTokenizer.from_pretrained(self.config.model_id, trust_remote_code=True, token=hf_token) model_kwargs = { "trust_remote_code": True, "low_cpu_mem_usage": True, } if self._device.type == "cuda": if self.config.use_4bit: try: from transformers import BitsAndBytesConfig model_kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, ) except Exception as exc: print(f"[WARNING] Rewrite 4-bit config unavailable, falling back to bf16: {exc}") model_kwargs["torch_dtype"] = torch.bfloat16 else: model_kwargs["torch_dtype"] = torch.bfloat16 model_kwargs["device_map"] = "auto" else: model_kwargs["torch_dtype"] = torch.float32 if hf_token is not None: model_kwargs["token"] = hf_token model = AutoModelForCausalLM.from_pretrained(self.config.model_id, **model_kwargs) model.eval() self._tokenizer = tokenizer self._model = model self._ready = True print(f"[INFO] ✅ Answer rewriter ready: {self.config.model_id}") except Exception as exc: self._ready = False print(f"[WARNING] ❌ Answer rewriter load failed: {exc}") def _get_style_instruction(self, source_model: str | None, language: str) -> str: if not source_model: return "" style = _REWRITE_STYLE_BY_MODEL.get(source_model.upper()) if not style: return "" lang_key = "en" if language.lower().startswith("en") else "vi" return style[lang_key] def _get_model_specific_example(self, source_model: str | None, language: str) -> dict[str, str] | None: if not source_model: return None examples = _MODEL_SPECIFIC_EXAMPLES.get(source_model.upper()) if not examples: return None lang_key = "en" if language.lower().startswith("en") else "vi" return examples[lang_key] def _build_messages( self, question: str, answer: str, language: str = "vi", source_model: str | None = None, ) -> list[dict[str, str]]: style_instruction = self._get_style_instruction(source_model, language) model_example = self._get_model_specific_example(source_model, language) system_prompt = ( "Bạn là bộ biên tập câu trả lời cho hệ thống Medical VQA. " "Nhiệm vụ của bạn là mở rộng đáp án gốc thành một câu trả lời đầy đủ, " "tự nhiên và rõ nghĩa hơn, nhưng vẫn phải bám sát đáp án gốc. " "KHÔNG thêm thông tin y khoa mới, KHÔNG suy diễn ngoài đáp án gốc. " "Có thể dùng câu hỏi để xác định đối tượng y khoa đang được hỏi, " "nhưng đáp án gốc quyết định ý nghĩa đúng/sai/có/không. " "Nếu nhiều model có cùng đáp án gốc, vẫn dùng phong cách riêng của model hiện tại. " "CÂU TRẢ LỜI BẮT BUỘC PHẢI DƯỚI 10 TỪ, ÍT NHẤT 3 TỪ. " "Chỉ trả về câu trả lời cuối cùng." ) if style_instruction: system_prompt += f" Phong cách riêng cho model này: {style_instruction}" if language.lower().startswith("en"): system_prompt = ( "You are an editor for a Medical VQA system. " "Expand the raw answer into a fuller, natural, clearer answer " "while staying strictly based on the raw answer. " "Do not add new medical facts or infer beyond the raw answer. " "You may use the question to identify the medical target, " "but the raw answer controls yes/no/presence/absence. " "If several models share the same raw answer, still use this model's wording style. " "THE ANSWER MUST BE UNDER 10 WORDS and at least 3 words. " "Return only the final answer." ) if style_instruction: system_prompt += f" Model-specific wording style: {style_instruction}" examples = [ { "question": "Ảnh này có tràn dịch màng phổi không?", "answer": "không", "rewrite": "Không, không thấy tràn dịch màng phổi.", }, { "question": "Hình ảnh có tim to không?", "answer": "có", "rewrite": "Có, hình ảnh cho thấy tim to.", }, { "question": "Đây là loại ảnh gì?", "answer": "x quang ngực", "rewrite": "Đây là ảnh X-quang ngực.", }, ] if language.lower().startswith("en"): examples = [ { "question": "Is there pleural effusion?", "answer": "no", "rewrite": "No, pleural effusion is not seen.", }, { "question": "Is the heart enlarged?", "answer": "yes", "rewrite": "Yes, the heart appears enlarged.", }, { "question": "What modality is this?", "answer": "chest x ray", "rewrite": "This is a chest X-ray.", }, ] if model_example: examples.append(model_example) messages: list[dict[str, str]] = [{"role": "system", "content": system_prompt}] for ex in examples: messages.append( { "role": "user", "content": f"Câu hỏi: {ex['question']}\nĐáp án gốc: {ex['answer']}", } ) messages.append({"role": "assistant", "content": ex["rewrite"]}) user_prompt = ( f"Câu hỏi: {question}\n" f"Đáp án gốc: {answer}\n" f"Model nguồn: {source_model or 'unknown'}\n" "Viết lại thành câu đầy đủ hơn, tự nhiên hơn, dưới 10 từ. " "CHỈ DÙNG THÔNG TIN TỪ ĐÁP ÁN GỐC." ) if style_instruction: user_prompt += f"\nPhong cách diễn đạt: {style_instruction}" if language.lower().startswith("en"): user_prompt = ( f"Question: {question}\nRaw answer: {answer}\n" f"Source model: {source_model or 'unknown'}\n" "Rewrite it as a fuller, natural answer under 10 words. " "Use only information from the raw answer." ) if style_instruction: user_prompt += f"\nWording style: {style_instruction}" messages.append({"role": "user", "content": user_prompt}) return messages def rewrite( self, question: str, answer: str, language: str = "vi", source_model: str | None = None, ) -> str: """ Rewrite câu trả lời để tự nhiên hơn. Nếu rewrite model không sẵn sàng, trả về output đã postprocess. """ if not answer: return "" self._lazy_load() fallback = postprocess_answer(answer, max_words=self.config.max_words) if not self.enabled or not self._ready: return fallback try: messages = self._build_messages( question=question, answer=answer, language=language, source_model=source_model, ) prompt = self._tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) inputs = self._tokenizer(prompt, return_tensors="pt", truncation=True) inputs = {k: v.to(self._device) for k, v in inputs.items()} with torch.inference_mode(): output_ids = self._model.generate( **inputs, max_new_tokens=self.config.max_new_tokens, do_sample=False, temperature=0.1, repetition_penalty=1.05, pad_token_id=self._tokenizer.eos_token_id, ) prompt_len = inputs["input_ids"].shape[1] generated = self._tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip() cleaned = postprocess_answer(generated, max_words=self.config.max_words) return cleaned or fallback except Exception as exc: print(f"[WARNING] Rewrite failed: {exc}") return fallback def rewrite_final_answer( question: str, answer: str, language: str = "vi", source_model: str | None = None, ) -> str: """ Helper tiện dùng trong notebook / web. """ rewriter = MedicalAnswerRewriter() return rewriter.rewrite( question=question, answer=answer, language=language, source_model=source_model, )