# Source: ours/text2layout/MyDataset_rib.py
# Uploaded by diing ("Add files using upload-large-folder tool"), commit 75b1a45 (verified).
# MyDataset_rib.py
import os
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
# Finding names a record's first attn_list entry must match for the record to
# be kept. Matching is case-insensitive, via the lower-cased set below.
DISEASES = [
    "Atelectasis",
    "Calcification",
    "Cardiomegaly",
    "Consolidation",
    "Diffuse Nodule",
    "Edema",
    "Effusion",
    "Emphysema",
    "Enlarged Cardiomediastinum",
    "Fibrosis",
    "Fracture",
    "Mass",
    "Nodule",
    "Opacity",
    "Pleural Thickening",
    "Pneumothorax",
    "Pacemaker",
]
# Lower-cased copy of DISEASES for O(1) case-insensitive membership tests.
DISEASES_LOWER = set(map(str.lower, DISEASES))
def _resolve_path(rel_or_abs, base_dir):
if not rel_or_abs:
return None
return rel_or_abs if os.path.isabs(rel_or_abs) else os.path.join(base_dir, rel_or_abs)
def _load_gray_int(path):
    """Load the image at *path* as a single-channel ("L") int32 array.

    Args:
        path: filesystem path of the image.

    Returns:
        np.ndarray of dtype int32 with shape (H, W).
    """
    # Close the underlying file explicitly instead of relying on Pillow's
    # lazy-loading to release it; this matters when DataLoader workers open
    # many files.
    with Image.open(path) as im:
        gray = im.convert("L")
    return np.array(gray).astype(np.int32)
def _load_bin_mask_resize(path, hw):
    """Load a mask image as a binary float32 array of shape *hw*.

    Args:
        path: filesystem path of the mask image.
        hw: target (height, width). The mask is resized with nearest-neighbour
            sampling when its size differs, so no intermediate values are
            introduced.

    Returns:
        np.ndarray of dtype float32, shape (H, W), holding 1.0 where the gray
        value is > 0 and 0.0 elsewhere.
    """
    H, W = hw
    # Release the file handle deterministically (see Pillow file handling).
    with Image.open(path) as src:
        im = src.convert("L")
    if im.size != (W, H):  # Pillow sizes are (width, height)
        im = im.resize((W, H), resample=Image.NEAREST)
    return (np.array(im) > 0).astype(np.float32)
def _remap_to_sequential(label_map_int):
vals = np.unique(label_map_int)
vals = vals[vals != 0]
if len(vals) == 0:
return label_map_int.astype(np.int32), 0
out = np.zeros_like(label_map_int, dtype=np.int32)
for i, v in enumerate(vals, start=1):
out[label_map_int == v] = i
return out, int(out.max())
def _norm(ch, maxv):
return ch.astype(np.float32) / float(maxv) if maxv > 0 else ch.astype(np.float32)
def _get_first_disease_name(item):
attn_list = item.get("attn_list", [])
if (
len(attn_list) > 0
and isinstance(attn_list[0], (list, tuple))
and len(attn_list[0]) >= 1
):
return str(attn_list[0][0]).strip()
return None
def _is_target_record(item):
    """Return True if the record's first attn_list disease is in DISEASES.

    The comparison is case-insensitive. Records without a usable disease
    name are rejected.
    """
    name = _get_first_disease_name(item)
    return name is not None and name.lower() in DISEASES_LOWER
class MyDataset_rib(Dataset):
    """Paired layout dataset built from organ / rib label maps and a disease mask.

    Records are read from the JSON-lines file ``args.train_data_prompt``
    (one JSON object per non-blank line). Only records whose first
    ``attn_list`` entry names one of ``DISEASES`` (case-insensitively) are
    kept. Relative ``organ`` / ``rib`` / ``mask`` paths are resolved against
    ``args.train_data_dir``.

    Each item is a dict with:
        original_pixel_values: float32 tensor (3, H, W) = [zeros, organ, rib]
        edited_pixel_values:   float32 tensor (3, H, W) = [disease, organ, rib]
        input_ids:             the tokenized ``prompt``, padded/truncated to
                               ``tokenizer.model_max_length``
    H and W come from the organ map. Organ and rib channels hold label ids
    relabelled to 1..K and scaled to [0, 1]; the disease channel is binary.
    """

    def __init__(self, args, tokenizer):
        self.data_dir = args.train_data_dir
        self.prompt_dir = args.train_data_prompt
        # NOTE(review): assumed to be a Hugging Face-style tokenizer (callable,
        # exposes model_max_length, returns .input_ids) — confirm with caller.
        self.tokenizer = tokenizer
        # JSON text is UTF-8; don't let the platform locale pick the codec.
        with open(self.prompt_dir, "rt", encoding="utf-8") as f:
            raw_data = [json.loads(line) for line in f if line.strip()]
        self.data = [item for item in raw_data if _is_target_record(item)]
        print(f"Loaded {len(self.data)} valid records from {len(raw_data)} total records.")

    def __len__(self):
        return len(self.data)

    def _required_path(self, item, key, idx):
        """Resolve the mandatory path field *key* of *item*.

        Raises:
            ValueError: if the field is missing or empty.
        """
        path = _resolve_path(item.get(key, ""), self.data_dir)
        if path is None:
            raise ValueError(f"record {idx} has no '{key}' path")
        return path

    def __getitem__(self, idx):
        item = self.data[idx]
        organ_path = self._required_path(item, "organ", idx)
        rib_path = self._required_path(item, "rib", idx)
        mask_path = _resolve_path(item.get("mask", ""), self.data_dir)

        # organ / rib: integer label maps relabelled to 1..K
        organ_map, organ_max = _remap_to_sequential(_load_gray_int(organ_path))
        rib_map, rib_max = _remap_to_sequential(_load_gray_int(rib_path))
        H, W = organ_map.shape
        # Unlike the mask, the rib map is not resized, so it must already
        # match the organ map; report the offending files instead of a bare
        # np.stack shape error.
        if rib_map.shape != (H, W):
            raise ValueError(
                f"rib map {rib_path} has shape {rib_map.shape}, "
                f"expected {(H, W)} to match organ map {organ_path}"
            )

        # disease channel: binary 0/1, resized to the organ map's size.
        # A missing mask file is tolerated and yields an all-zero channel.
        if mask_path is not None and os.path.exists(mask_path):
            disease_ch = _load_bin_mask_resize(mask_path, (H, W))
        else:
            disease_ch = np.zeros((H, W), dtype=np.float32)

        organ_ch = _norm(organ_map, organ_max)
        rib_ch = _norm(rib_map, rib_max)
        empty_ch = np.zeros((H, W), dtype=np.float32)
        # "original" and "edited" differ only in channel 0 (empty vs disease).
        original = np.stack([empty_ch, organ_ch, rib_ch], axis=0).astype(np.float32, copy=False)
        edited = np.stack([disease_ch, organ_ch, rib_ch], axis=0).astype(np.float32, copy=False)

        enc = self.tokenizer(
            item["prompt"],
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "original_pixel_values": torch.from_numpy(original),
            "edited_pixel_values": torch.from_numpy(edited),
            "input_ids": enc.input_ids.squeeze(0),
        }