| |
| import os, json |
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
| from PIL import Image |
|
|
| |
# Closed catalogue of chest-X-ray findings recognised by this dataset.
DISEASES = ["Atelectasis", "Calcification", "Cardiomegaly", "Consolidation",
            "Diffuse Nodule", "Effusion", "Emphysema", "Fibrosis", "Fracture",
            "Mass", "Nodule", "Pleural Thickening", "Pneumothorax"]
# 1-based label per disease; index 0 is implicitly reserved for background.
DISEASE_TO_IDX = {name: i+1 for i, name in enumerate(DISEASES)}
|
|
def _load_gray_int(path):
    """Open *path*, convert it to 8-bit grayscale, and return an int32 array."""
    gray = Image.open(path).convert("L")
    return np.array(gray, dtype=np.int32)
|
|
| def _remap_to_sequential(label_map_int): |
| |
| vals = np.unique(label_map_int) |
| vals = vals[vals != 0] |
| if len(vals) == 0: |
| return label_map_int.astype(np.int32), 0 |
| out = np.zeros_like(label_map_int, dtype=np.int32) |
| for i, v in enumerate(vals, start=1): |
| out[label_map_int == v] = i |
| return out, int(out.max()) |
|
|
def _merge_diseases_from_attn_list(attn_list, base_dir, hw):
    """Fuse per-disease attention masks into a single integer label map.

    Each (name, path) pair in *attn_list* is resolved against *base_dir*
    (unless already absolute), binarized, and painted onto an (H, W) map
    with its disease index from DISEASE_TO_IDX.  Larger masks are painted
    first so that smaller regions end up on top; unknown disease names are
    skipped.  Returns the map and the largest index present (0 if empty).
    """
    height, width = hw
    layers = []
    for disease_name, mask_path in attn_list:
        label = DISEASE_TO_IDX.get(disease_name)
        if label is None:
            # Name not in the catalogue — ignore this entry.
            continue
        if not os.path.isabs(mask_path):
            mask_path = os.path.join(base_dir, mask_path)
        binary = _load_gray_int(mask_path) > 0
        layers.append((int(binary.sum()), label, binary))

    merged = np.zeros((height, width), dtype=np.int32)
    # Descending area, ties broken by ascending disease index.
    for pixel_count, label, binary in sorted(layers, key=lambda t: (-t[0], t[1])):
        if pixel_count == 0:
            continue
        merged[binary] = label
    return merged, int(merged.max())
|
|
class MyDataset_rib(Dataset):
    """Dataset pairing chest X-ray images with a 3-channel conditioning map
    (disease attention / organ segmentation / rib mask) and a tokenized text
    prompt — suitable for ControlNet-style conditional training.
    """

    def __init__(self, args, tokenizer):
        # Root directory that holds images, organ maps and the "rib/" subtree.
        self.data_dir = args.train_data_dir
        # JSONL file: one record per sample (file_name, organ, prompt, ...).
        self.prompt_dir = args.train_data_prompt
        self.tokenizer = tokenizer
        with open(self.prompt_dir, 'rt') as f:
            # Blank lines are skipped; each remaining line is one JSON record.
            self.data = [json.loads(l) for l in f if l.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # --- target image: RGB, scaled from [0, 255] to [-1, 1], CHW float32.
        img_path = item['file_name']
        img_path = img_path if os.path.isabs(img_path) else os.path.join(self.data_dir, img_path)
        img = np.array(Image.open(img_path).convert('RGB'), dtype=np.float32)
        img = img / 127.5 - 1.0
        pixel_values = torch.from_numpy(img).permute(2,0,1)

        # --- organ segmentation: raw label map remapped to sequential 1..K.
        organ_path = item['organ']
        organ_path = organ_path if os.path.isabs(organ_path) else os.path.join(self.data_dir, organ_path)
        organ_map_raw = _load_gray_int(organ_path)
        organ_map, organ_max = _remap_to_sequential(organ_map_raw)

        # --- rib mask, located by convention at
        # data_dir/rib/<first component of file_name>/<basename>, binarized.
        # NOTE(review): assumes file_name uses '/' separators — confirm on
        # non-POSIX platforms.
        rib_path = os.path.join(self.data_dir, "rib", item['file_name'].split('/')[0], os.path.basename(item['file_name']))
        rib_map_raw = _load_gray_int(rib_path)
        rib_bin = (rib_map_raw > 0).astype(np.float32)

        H, W = organ_map.shape

        # --- disease attention masks merged into one label map (all-zero when
        # the record carries no 'attn_list').
        attn_list = item.get('attn_list', [])
        disease_map, disease_max = (np.zeros((H,W), dtype=np.int32), 0)
        if len(attn_list) > 0:
            disease_map, disease_max = _merge_diseases_from_attn_list(attn_list, self.data_dir, (H, W))

        # Scale an integer label map into [0, 1] by its own max label; an
        # all-zero map is returned unscaled (still float32).
        def _norm(ch, maxv):
            return ch.astype(np.float32) / float(maxv) if maxv > 0 else ch.astype(np.float32)

        organ_ch = _norm(organ_map, organ_max)
        rib_ch = rib_bin
        disease_ch = _norm(disease_map, disease_max)

        # Conditioning tensor: channel order is (disease, organ, rib).
        # NOTE(review): assumes image, organ, rib and attention maps all share
        # the same HxW — no resize happens here; verify upstream preprocessing.
        cond = np.stack([disease_ch, organ_ch, rib_ch], axis=0).astype(np.float32)
        conditioning_pixel_values = torch.from_numpy(cond)

        # Tokenize the prompt to a fixed-length id sequence (pad + truncate
        # to the tokenizer's model_max_length).
        enc = self.tokenizer(
            item['prompt'],
            max_length=self.tokenizer.model_max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            "conditioning_pixel_values": conditioning_pixel_values,
            "pixel_values": pixel_values,
            "input_ids": enc.input_ids.squeeze(0),
        }
|
|