# MyDataset_rib.py
import os, json
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image

# Fixed disease indices: background = 0, diseases start at 1
DISEASES = ["Atelectasis", "Calcification", "Cardiomegaly", "Consolidation",
            "Diffuse Nodule", "Effusion", "Emphysema", "Fibrosis", "Fracture",
            "Mass", "Nodule", "Pleural Thickening", "Pneumothorax"]
DISEASE_TO_IDX = {name: i + 1 for i, name in enumerate(DISEASES)}  # 1..len(DISEASES)


def _load_gray_int(path):
    # Load an image as a single-channel (grayscale) integer array
    arr = np.array(Image.open(path).convert("L"))
    return arr.astype(np.int32)


def _remap_to_sequential(label_map_int):
    # Remap an arbitrary set of pixel values to 0..K, keeping 0 as background
    vals = np.unique(label_map_int)
    vals = vals[vals != 0]
    if len(vals) == 0:
        return label_map_int.astype(np.int32), 0
    out = np.zeros_like(label_map_int, dtype=np.int32)
    for i, v in enumerate(vals, start=1):
        out[label_map_int == v] = i
    return out, int(out.max())


def _merge_diseases_from_attn_list(attn_list, base_dir, hw):
    # Merge per-disease binary masks into one label map of shape (H, W).
    # Each mask is assumed to already match the target (H, W).
    H, W = hw
    items = []
    for name, rel_or_abs in attn_list:
        idx = DISEASE_TO_IDX.get(name, None)
        if idx is None:
            continue
        path = rel_or_abs if os.path.isabs(rel_or_abs) else os.path.join(base_dir, rel_or_abs)
        m = (_load_gray_int(path) > 0)
        area = int(m.sum())
        items.append((area, idx, m))
    # Paint large masks first, small masks last: the later (smaller) masks overwrite
    # the earlier (larger) ones, so a small lesion is never hidden by a larger one
    items.sort(key=lambda x: (-x[0], x[1]))
    disease_map = np.zeros((H, W), dtype=np.int32)
    for area, idx, m in items:
        if area == 0:
            continue
        disease_map[m] = idx
    max_idx = int(disease_map.max())
    return disease_map, max_idx


class MyDataset_rib(Dataset):
    def __init__(self, args, tokenizer):
        self.data_dir = args.train_data_dir
        self.prompt_dir = args.train_data_prompt
        self.tokenizer = tokenizer
        # One JSON object per line (JSONL); skip blank lines
        with open(self.prompt_dir, 'rt') as f:
            self.data = [json.loads(l) for l in f if l.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Target image
        img_path = item['file_name']
        img_path = img_path if os.path.isabs(img_path) else os.path.join(self.data_dir, img_path)
        img = np.array(Image.open(img_path).convert('RGB'), dtype=np.float32)
        img = img / 127.5 - 1.0  # [-1, 1]
        pixel_values = torch.from_numpy(img).permute(2, 0, 1)  # (3, H, W)

        # Organ label (single map, multiple organs: pixel value = organ ID)
        organ_path = item['organ']
        organ_path = organ_path if os.path.isabs(organ_path) else os.path.join(self.data_dir, organ_path)
        organ_map_raw = _load_gray_int(organ_path)
        organ_map, organ_max = _remap_to_sequential(organ_map_raw)

        # Rib label (single map, multiple ribs: pixel value = rib ID), taken from overlap
        rib_path = os.path.join(self.data_dir, "rib",
                                item['file_name'].split('/')[0],
                                os.path.basename(item['file_name']))
        # rib_path = item['file_name']
        # rib_path = rib_path if os.path.isabs(rib_path) else os.path.join(self.data_dir, rib_path)
        rib_map_raw = _load_gray_int(rib_path)
        # rib_map, rib_max = _remap_to_sequential(rib_map_raw)
        rib_bin = (rib_map_raw > 0).astype(np.float32)

        H, W = organ_map.shape

        # Disease: read the binary masks listed in attn_list and merge them into a single channel
        attn_list = item.get('attn_list', [])
        disease_map, disease_max = (np.zeros((H, W), dtype=np.int32), 0)
        if len(attn_list) > 0:
            disease_map, disease_max = _merge_diseases_from_attn_list(attn_list, self.data_dir, (H, W))

        # Normalize each channel to [0, 1] (pure scaling; label relations are preserved)
        def _norm(ch, maxv):
            return ch.astype(np.float32) / float(maxv) if maxv > 0 else ch.astype(np.float32)

        organ_ch = _norm(organ_map, organ_max)
        # rib_ch = _norm(rib_map, rib_max)
        rib_ch = rib_bin
        disease_ch = _norm(disease_map, disease_max)

        cond = np.stack([disease_ch, organ_ch, rib_ch], axis=0).astype(np.float32)  # (3, H, W)
        conditioning_pixel_values = torch.from_numpy(cond)

        enc = self.tokenizer(
            item['prompt'],
            max_length=self.tokenizer.model_max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            "conditioning_pixel_values": conditioning_pixel_values,  # (3, H, W) ∈ [0, 1]
            "pixel_values": pixel_values,                            # (3, H, W) ∈ [-1, 1]
            "input_ids": enc.input_ids.squeeze(0),
        }
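

# --- Minimal usage sketch (illustrative only, not part of the training pipeline) ---
# Assumptions: `args` only needs `train_data_dir` and `train_data_prompt`, the tokenizer
# is a Hugging Face CLIPTokenizer, and the JSONL file at `train_data_prompt` has one
# object per line shaped roughly like:
#   {"file_name": "caseA/img_0001.png",
#    "organ": "organ/caseA/img_0001.png",
#    "prompt": "chest X-ray with ...",
#    "attn_list": [["Nodule", "disease/caseA/nodule_0001.png"]]}
# The paths below are placeholders; point them at real data before running.
if __name__ == "__main__":
    from types import SimpleNamespace
    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer

    args = SimpleNamespace(
        train_data_dir="/path/to/dataset",           # placeholder
        train_data_prompt="/path/to/prompts.jsonl",  # placeholder
    )
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    ds = MyDataset_rib(args, tokenizer)
    loader = DataLoader(ds, batch_size=2, shuffle=True)
    batch = next(iter(loader))
    print(batch["pixel_values"].shape)               # (B, 3, H, W), values in [-1, 1]
    print(batch["conditioning_pixel_values"].shape)  # (B, 3, H, W), values in [0, 1]
    print(batch["input_ids"].shape)                  # (B, tokenizer.model_max_length)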