File size: 4,930 Bytes
d63774a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
from torch.utils.data import Dataset
from PIL import Image
import json
import os
from src.utils.text_utils import get_target_answer, normalize_answer, text_normalize

class MedicalVQADataset(Dataset):
    """
    Shared Dataset class for Medical VQA (SLAKE + VQA-RAD).

    Accepts either an in-memory HuggingFace dataset (`hf_dataset`) or a JSON
    annotation file (`json_path`) paired with an image directory, and yields
    either standard supervised samples or DPO preference pairs
    (chosen vs. rejected answers) depending on `is_dpo`.
    """
    def __init__(self, hf_dataset=None, json_path=None, image_dir=None, tokenizer=None, transform=None, max_seq_len=64, max_ans_len=10, is_dpo=False, in_channels=1, answer_max_words=10):
        if hf_dataset is not None:
            self.use_hf = True
            self.data = hf_dataset
        elif json_path is not None:
            self.use_hf = False
            with open(json_path, "r", encoding="utf-8") as f:
                self.data = json.load(f)
        else:
            raise ValueError("Phải cung cấp hf_dataset hoặc json_path!")

        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_seq_len = max_seq_len
        self.max_ans_len = max_ans_len
        self.is_dpo = is_dpo
        self.in_channels = in_channels
        self.answer_max_words = answer_max_words

        # Mapping for closed (Yes/No) questions, English and Vietnamese.
        self.label_map = {"no": 0, "yes": 1, "không": 0, "có": 1}

    def __len__(self):
        return len(self.data)

    def _load_image(self, item):
        """Return the item's PIL image in the mode implied by `in_channels`."""
        target_mode = "L" if self.in_channels == 1 else "RGB"
        if self.use_hf:
            pil_img = item["image"]
            if pil_img.mode != target_mode:
                pil_img = pil_img.convert(target_mode)
            return pil_img
        # DPO preference data might use 'image' or 'image_name'.
        file_name = item.get("image_name") or item.get("image")
        return Image.open(os.path.join(self.image_dir, file_name)).convert(target_mode)

    def _encode(self, text, max_len):
        """Tokenize `text` into fixed-length padded/truncated pt tensors."""
        return self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=max_len,
            return_tensors="pt"
        )

    def __getitem__(self, idx):
        item = self.data[idx]

        # 1. Image: keep an untransformed copy for a multimodal processor.
        image = self._load_image(item)
        raw_image = image
        if self.transform:
            image = self.transform(image)
        else:
            from torchvision import transforms
            image = transforms.ToTensor()(image)

        # 2. Question: DPO data is keyed "question", otherwise "question_vi".
        raw_question = item["question" if self.is_dpo else "question_vi"]
        raw_question_en = item.get("question", raw_question)  # English version if available
        encoding = self._encode(text_normalize(raw_question), self.max_seq_len)

        sample = {
            "image": image,
            "raw_image": raw_image,
            "raw_questions": raw_question,
            "raw_questions_en": raw_question_en,
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
        }

        if self.is_dpo:
            # 3a. DPO preference pair (chosen vs. rejected answers).
            chosen_enc = self._encode(normalize_answer(item["chosen"]), self.max_ans_len)
            rejected_enc = self._encode(normalize_answer(item["rejected"]), self.max_ans_len)
            sample["chosen_ids"] = chosen_enc["input_ids"].flatten()
            sample["rejected_ids"] = rejected_enc["input_ids"].flatten()
            return sample

        # 3b. Standard supervised target (non-DPO).
        answer = get_target_answer(item, max_words=self.answer_max_words)
        answer_en = normalize_answer(item.get("answer", answer))  # English version if available
        ans_encoding = self._encode(answer, self.max_ans_len)

        sample["label_closed"] = torch.tensor(self.label_map.get(answer, -1), dtype=torch.long)
        sample["target_ids"] = ans_encoding["input_ids"].flatten()
        sample["raw_answer"] = answer
        sample["raw_answer_full"] = normalize_answer(item.get("answer_full_vi", answer))
        sample["raw_answer_en"] = answer_en
        return sample