import torch
from torch.utils.data import Dataset
from PIL import Image
import json
import os

from src.utils.text_utils import get_target_answer, normalize_answer, text_normalize


class MedicalVQADataset(Dataset):
    """
    Shared dataset class for Medical VQA (SLAKE + VQA-RAD).

    Supports either a Hugging Face dataset or a local JSON annotation file,
    and can return DPO preference pairs (chosen vs. rejected) when is_dpo=True.
    """

    def __init__(self, hf_dataset=None, json_path=None, image_dir=None, tokenizer=None,
                 transform=None, max_seq_len=64, max_ans_len=10, is_dpo=False,
                 in_channels=1, answer_max_words=10):
        if hf_dataset is not None:
            self.data = hf_dataset
            self.use_hf = True
        elif json_path is not None:
            with open(json_path, "r", encoding="utf-8") as f:
                self.data = json.load(f)
            self.use_hf = False
        else:
            raise ValueError("Either hf_dataset or json_path must be provided!")

        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_seq_len = max_seq_len
        self.max_ans_len = max_ans_len
        self.is_dpo = is_dpo
        self.in_channels = in_channels
        self.answer_max_words = answer_max_words

        # Mapping for closed questions (yes/no, including the Vietnamese "có"/"không")
        self.label_map = {"no": 0, "yes": 1, "không": 0, "có": 1}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # 1. Process the image
        if self.use_hf:
            image = item["image"]
            if self.in_channels == 1:
                if image.mode != "L":
                    image = image.convert("L")
            else:
                if image.mode != "RGB":
                    image = image.convert("RGB")
        else:
            # DPO preference data might use 'image' or 'image_name'
            img_name = item.get("image_name") or item.get("image")
            img_path = os.path.join(self.image_dir, img_name)
            mode = "L" if self.in_channels == 1 else "RGB"
            image = Image.open(img_path).convert(mode)

        # Keep an un-normalized copy for the multimodal processor
        raw_image = image

        if self.transform:
            image = self.transform(image)
        else:
            from torchvision import transforms
            image = transforms.ToTensor()(image)

        # 2. Process the question
        q_key = "question" if self.is_dpo else "question_vi"
        raw_question = item[q_key]
        raw_question_en = item.get("question", raw_question)  # use the English version if available
        question = text_normalize(raw_question)

        encoding = self.tokenizer(
            question,
            padding="max_length",
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt"
        )

        if self.is_dpo:
            # 3. Process DPO preference pairs (chosen vs. rejected)
            chosen_ans = normalize_answer(item["chosen"])
            rejected_ans = normalize_answer(item["rejected"])

            chosen_encoding = self.tokenizer(chosen_ans, padding="max_length", truncation=True,
                                             max_length=self.max_ans_len, return_tensors="pt")
            rejected_encoding = self.tokenizer(rejected_ans, padding="max_length", truncation=True,
                                               max_length=self.max_ans_len, return_tensors="pt")

            return {
                "image": image,
                "raw_image": raw_image,
                "raw_questions": raw_question,
                "raw_questions_en": raw_question_en,
                "input_ids": encoding["input_ids"].flatten(),
                "attention_mask": encoding["attention_mask"].flatten(),
                "chosen_ids": chosen_encoding["input_ids"].flatten(),
                "rejected_ids": rejected_encoding["input_ids"].flatten(),
            }

        # 3. Process the standard (non-DPO) answer
        answer = get_target_answer(item, max_words=self.answer_max_words)
        answer_en = normalize_answer(item.get("answer", answer))  # use the English version if available
        label_closed = self.label_map.get(answer, -1)

        ans_encoding = self.tokenizer(
            answer,
            padding="max_length",
            truncation=True,
            max_length=self.max_ans_len,
            return_tensors="pt"
        )

        return {
            "image": image,
            "raw_image": raw_image,
            "raw_questions": raw_question,
            "raw_questions_en": raw_question_en,
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "label_closed": torch.tensor(label_closed, dtype=torch.long),
            "target_ids": ans_encoding["input_ids"].flatten(),
            "raw_answer": answer,
            "raw_answer_full": normalize_answer(item.get("answer_full_vi", answer)),
            "raw_answer_en": answer_en
        }
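

# --- Usage sketch (illustration only, not part of the class) ---
# Minimal example of constructing the dataset and reading one sample. The
# tokenizer checkpoint ("vinai/phobert-base") and the JSON/image paths below
# are assumptions for illustration; substitute whatever this project actually
# uses. Note that batching with a plain DataLoader would need a custom
# collate_fn, since "raw_image" is a PIL Image that the default collate
# cannot stack.
if __name__ == "__main__":
    from torchvision import transforms
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")  # assumed checkpoint
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    dataset = MedicalVQADataset(
        json_path="data/slake_train.json",  # assumed path
        image_dir="data/imgs",              # assumed path
        tokenizer=tokenizer,
        transform=transform,
        max_seq_len=64,
        in_channels=1,
    )

    sample = dataset[0]
    print(sample["image"].shape, sample["input_ids"].shape, sample["raw_answer"])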