import torch
from torch.utils.data import Dataset
from PIL import Image
import json
import os

from src.utils.text_utils import get_target_answer, normalize_answer, text_normalize
class MedicalVQADataset(Dataset):
    """Shared Dataset class for Medical VQA (SLAKE + VQA-RAD).

    Supports two data sources:
      * ``hf_dataset`` — an indexable HuggingFace-style dataset whose items
        carry a PIL image under the ``"image"`` key; or
      * ``json_path`` + ``image_dir`` — a JSON annotation list whose items
        reference image files via ``"image_name"`` (or ``"image"``).

    When ``is_dpo`` is True, items are preference pairs and each sample
    contains tokenized ``chosen_ids``/``rejected_ids`` instead of the
    supervised ``target_ids``/``label_closed`` fields.
    """

    def __init__(self, hf_dataset=None, json_path=None, image_dir=None,
                 tokenizer=None, transform=None, max_seq_len=64,
                 max_ans_len=10, is_dpo=False, in_channels=1,
                 answer_max_words=10):
        """Build the dataset from either ``hf_dataset`` or ``json_path``.

        Args:
            hf_dataset: Indexable dataset of dict items (takes precedence).
            json_path: Path to a JSON annotation file (used if no hf_dataset).
            image_dir: Directory holding image files for the JSON source.
            tokenizer: HF-style tokenizer used for questions and answers.
            transform: Callable applied to the PIL image; defaults to
                ``torchvision.transforms.ToTensor()``.
            max_seq_len: Token length questions are padded/truncated to.
            max_ans_len: Token length answers are padded/truncated to.
            is_dpo: Whether items are DPO preference pairs.
            in_channels: 1 → grayscale ("L") images, otherwise RGB.
            answer_max_words: Word cap passed to ``get_target_answer``.

        Raises:
            ValueError: If neither ``hf_dataset`` nor ``json_path`` is given.
        """
        if hf_dataset is not None:
            self.data = hf_dataset
            self.use_hf = True
        elif json_path is not None:
            with open(json_path, "r", encoding="utf-8") as f:
                self.data = json.load(f)
            self.use_hf = False
        else:
            raise ValueError("Phải cung cấp hf_dataset hoặc json_path!")
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        if transform is not None:
            self.transform = transform
        else:
            # Hoisted from __getitem__: build the fallback ToTensor once
            # instead of importing torchvision on every item access.
            from torchvision import transforms
            self.transform = transforms.ToTensor()
        self.max_seq_len = max_seq_len
        self.max_ans_len = max_ans_len
        self.is_dpo = is_dpo
        self.in_channels = in_channels
        self.answer_max_words = answer_max_words
        # Mapping for closed (Yes/No) questions; covers Vietnamese variants.
        self.label_map = {"no": 0, "yes": 1, "không": 0, "có": 1}

    def __len__(self):
        """Return the number of items in the underlying data source."""
        return len(self.data)

    def _load_image(self, item):
        """Return the item's PIL image in the mode implied by in_channels."""
        mode = "L" if self.in_channels == 1 else "RGB"
        if self.use_hf:
            image = item["image"]
        else:
            # DPO preference data might use 'image' or 'image_name'.
            img_name = item.get("image_name") or item.get("image")
            if not img_name:
                # Fail loudly instead of letting os.path.join raise an
                # opaque TypeError on None.
                raise KeyError("item has neither 'image_name' nor 'image'")
            image = Image.open(os.path.join(self.image_dir, img_name))
        # Skip the (copying) convert when the image is already in `mode`.
        return image if image.mode == mode else image.convert(mode)

    def _encode_answer(self, text):
        """Tokenize an answer string to a flat id tensor of max_ans_len."""
        enc = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_ans_len,
            return_tensors="pt",
        )
        return enc["input_ids"].flatten()

    def _dpo_targets(self, item):
        """Target fields for DPO preference pairs (chosen vs. rejected)."""
        return {
            "chosen_ids": self._encode_answer(normalize_answer(item["chosen"])),
            "rejected_ids": self._encode_answer(normalize_answer(item["rejected"])),
        }

    def _supervised_targets(self, item):
        """Target fields for standard supervised (non-DPO) training."""
        answer = get_target_answer(item, max_words=self.answer_max_words)
        # English answer if present, otherwise fall back to the target.
        answer_en = normalize_answer(item.get("answer", answer))
        # -1 marks open-ended questions (not in the yes/no label map).
        label_closed = self.label_map.get(answer, -1)
        return {
            "label_closed": torch.tensor(label_closed, dtype=torch.long),
            "target_ids": self._encode_answer(answer),
            "raw_answer": answer,
            "raw_answer_full": normalize_answer(item.get("answer_full_vi", answer)),
            "raw_answer_en": answer_en,
        }

    def __getitem__(self, idx):
        """Return one sample dict: image tensors, tokenized question, targets."""
        item = self.data[idx]
        # 1. Image handling.
        image = self._load_image(item)
        raw_image = image  # un-normalized copy kept for multimodal processors
        image = self.transform(image)
        # 2. Question handling (Vietnamese by default; DPO data is English).
        q_key = "question" if self.is_dpo else "question_vi"
        raw_question = item[q_key]
        raw_question_en = item.get("question", raw_question)  # English version if present
        encoding = self.tokenizer(
            text_normalize(raw_question),
            padding="max_length",
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt",
        )
        sample = {
            "image": image,
            "raw_image": raw_image,
            "raw_questions": raw_question,
            "raw_questions_en": raw_question_en,
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
        }
        # 3. Answer handling: preference pair (DPO) or supervised target.
        if self.is_dpo:
            sample.update(self._dpo_targets(item))
        else:
            sample.update(self._supervised_targets(item))
        return sample