# ours / layout2image_multi / MyDataset_rib.py
# (Hugging Face upload header: uploaded by "diing" via upload-large-folder tool,
#  commit 75b1a45, verified)
# MyDataset_rib.py
import os, json
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
# Fixed disease indices: background = 0, diseases start at 1.
# Order matters — it defines each disease's integer label in the condition map.
DISEASES = ["Atelectasis", "Calcification", "Cardiomegaly", "Consolidation",
            "Diffuse Nodule", "Effusion", "Emphysema", "Fibrosis", "Fracture",
            "Mass", "Nodule", "Pleural Thickening", "Pneumothorax"]
DISEASE_TO_IDX = {name: i+1 for i, name in enumerate(DISEASES)}  # 1..len(DISEASES)
def _load_gray_int(path):
    """Load the image at *path* as a single-channel (grayscale) int32 array."""
    gray = Image.open(path).convert("L")
    return np.asarray(gray, dtype=np.int32)
def _remap_to_sequential(label_map_int):
    """Compress an arbitrary integer label map to sequential ids 0..K.

    0 is kept as background; the remaining distinct pixel values are
    renumbered 1..K in ascending order of their original value.
    Returns ``(remapped_map_int32, K)``.
    """
    present = np.unique(label_map_int)
    present = present[present != 0]
    if present.size == 0:
        # Nothing but background — return unchanged (as int32) with K = 0.
        return label_map_int.astype(np.int32), 0

    remapped = np.zeros(label_map_int.shape, dtype=np.int32)
    new_id = 0
    for old_val in present:
        new_id += 1
        remapped[label_map_int == old_val] = new_id
    # Every value in `present` occurs in the map, so the last id is the max.
    return remapped, new_id
def _merge_diseases_from_attn_list(attn_list, base_dir, hw):
    """Merge per-disease binary masks into one (H, W) int32 label map.

    *attn_list* is a sequence of ``(disease_name, mask_path)`` pairs; names
    not in ``DISEASE_TO_IDX`` are skipped. Relative paths are resolved
    against *base_dir*. Returns ``(label_map, max_label)``.
    """
    H, W = hw
    loaded = []
    for disease_name, mask_path in attn_list:
        class_id = DISEASE_TO_IDX.get(disease_name, None)
        if class_id is None:
            continue
        full_path = mask_path if os.path.isabs(mask_path) else os.path.join(base_dir, mask_path)
        mask = _load_gray_int(full_path) > 0
        loaded.append((int(mask.sum()), class_id, mask))

    # Paint larger masks first so smaller lesions, written later, overwrite
    # them — i.e. a small lesion can never be hidden under a big one.
    merged = np.zeros((H, W), dtype=np.int32)
    for area, class_id, mask in sorted(loaded, key=lambda t: (-t[0], t[1])):
        if area == 0:
            continue
        merged[mask] = class_id
    return merged, int(merged.max())
class MyDataset_rib(Dataset):
    """Layout-to-image dataset for chest X-rays.

    Each non-blank line of ``args.train_data_prompt`` is a JSON record with
    keys ``file_name`` (target image), ``organ`` (organ label map image),
    ``prompt`` (text caption) and optionally ``attn_list`` (list of
    ``[disease_name, mask_path]`` pairs).

    ``__getitem__`` returns a dict:
      * ``pixel_values``: (3, H, W) float32 RGB target in [-1, 1]
      * ``conditioning_pixel_values``: (3, H, W) float32 condition map in
        [0, 1], channels ordered [disease, organ, rib]
      * ``input_ids``: tokenized prompt, 1-D tensor
    """

    def __init__(self, args, tokenizer):
        self.data_dir = args.train_data_dir
        self.prompt_dir = args.train_data_prompt
        self.tokenizer = tokenizer
        # Fix: read the JSONL with an explicit encoding — the default is
        # platform-dependent and breaks on non-ASCII prompts (e.g. cp1252).
        with open(self.prompt_dir, 'rt', encoding='utf-8') as f:
            self.data = [json.loads(l) for l in f if l.strip()]

    def __len__(self):
        return len(self.data)

    def _resolve(self, path):
        # Resolve a record path: absolute paths pass through, relative
        # paths are taken relative to the dataset root.
        return path if os.path.isabs(path) else os.path.join(self.data_dir, path)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Target image, scaled from [0, 255] to [-1, 1].
        img_path = self._resolve(item['file_name'])
        img = np.array(Image.open(img_path).convert('RGB'), dtype=np.float32)
        img = img / 127.5 - 1.0
        pixel_values = torch.from_numpy(img).permute(2, 0, 1)  # (3, H, W)

        # Organ label map (single image, pixel value = organ id), remapped
        # to sequential ids 0..K with 0 = background.
        organ_map_raw = _load_gray_int(self._resolve(item['organ']))
        organ_map, organ_max = _remap_to_sequential(organ_map_raw)

        # Rib mask lives under <data_dir>/rib/<first path component>/<basename>.
        # NOTE(review): assumes 'file_name' uses '/' separators — confirm if
        # records are ever produced on Windows.
        rib_path = os.path.join(self.data_dir, "rib",
                                item['file_name'].split('/')[0],
                                os.path.basename(item['file_name']))
        rib_bin = (_load_gray_int(rib_path) > 0).astype(np.float32)

        H, W = organ_map.shape

        # Disease channel: merge the per-disease binary masks (if any) into
        # a single int label map.
        attn_list = item.get('attn_list', [])
        if attn_list:
            disease_map, disease_max = _merge_diseases_from_attn_list(
                attn_list, self.data_dir, (H, W))
        else:
            disease_map, disease_max = np.zeros((H, W), dtype=np.int32), 0

        # Rescale each label channel into [0, 1]; pure scaling, so the
        # ordering of labels within a channel is preserved.
        def _norm(ch, maxv):
            return ch.astype(np.float32) / float(maxv) if maxv > 0 else ch.astype(np.float32)

        cond = np.stack([_norm(disease_map, disease_max),
                         _norm(organ_map, organ_max),
                         rib_bin], axis=0).astype(np.float32)  # (3, H, W)
        conditioning_pixel_values = torch.from_numpy(cond)

        enc = self.tokenizer(
            item['prompt'],
            max_length=self.tokenizer.model_max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            "conditioning_pixel_values": conditioning_pixel_values,  # (3,H,W) in [0,1]
            "pixel_values": pixel_values,                            # (3,H,W) in [-1,1]
            "input_ids": enc.input_ids.squeeze(0),
        }