fix visualization bug
Browse files- app.py +98 -386
- dataset/dataset_seg.py +53 -15
app.py
CHANGED
|
@@ -346,70 +346,8 @@ def create_single_sample_dataloader(patient_idx: int, task: str):
|
|
| 346 |
raise
|
| 347 |
|
| 348 |
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
try:
|
| 352 |
-
sample_patients, dataset_root = load_sample_data()
|
| 353 |
-
|
| 354 |
-
if task == 'classification':
|
| 355 |
-
if 'risk' not in sample_patients:
|
| 356 |
-
raise ValueError("No risk assessment data found")
|
| 357 |
-
patient_row = sample_patients['risk'].iloc[patient_idx]
|
| 358 |
-
# Risk dataset uses 'highb' instead of 'dwi'
|
| 359 |
-
t2w_path = dataset_root / patient_row["t2w"]
|
| 360 |
-
dwi_path = dataset_root / patient_row["highb"]
|
| 361 |
-
adc_path = dataset_root / patient_row["adc"]
|
| 362 |
-
else:
|
| 363 |
-
if 'UCL' not in sample_patients:
|
| 364 |
-
raise ValueError("No UCL segmentation data found")
|
| 365 |
-
patient_row = sample_patients['UCL'].iloc[patient_idx]
|
| 366 |
-
# UCL dataset uses 'dwi'
|
| 367 |
-
t2w_path = dataset_root / patient_row["t2w"]
|
| 368 |
-
dwi_path = dataset_root / patient_row["dwi"]
|
| 369 |
-
adc_path = dataset_root / patient_row["adc"]
|
| 370 |
-
|
| 371 |
-
# Read raw data for visualization (without transforms)
|
| 372 |
-
spatial_index = [2, 1, 0] # DHW format as in dataset
|
| 373 |
-
|
| 374 |
-
def read_nifti(path):
|
| 375 |
-
vol = nib.load(path)
|
| 376 |
-
vol = vol.get_fdata().astype(np.float32).transpose(spatial_index)
|
| 377 |
-
return torch.from_numpy(vol)
|
| 378 |
-
|
| 379 |
-
t2w = read_nifti(str(t2w_path))
|
| 380 |
-
dwi = read_nifti(str(dwi_path))
|
| 381 |
-
adc = read_nifti(str(adc_path))
|
| 382 |
-
|
| 383 |
-
# Load prostate mask from existing file
|
| 384 |
-
prostate_mask = None
|
| 385 |
-
prostate_mask_path = t2w_path.parent / "prostate_mask.nii.gz"
|
| 386 |
-
if prostate_mask_path.exists():
|
| 387 |
-
prostate_mask = read_nifti(str(prostate_mask_path))
|
| 388 |
-
logger.info(f"Loaded prostate mask from {prostate_mask_path}")
|
| 389 |
-
else:
|
| 390 |
-
logger.warning(f"Prostate mask not found at {prostate_mask_path}")
|
| 391 |
-
|
| 392 |
-
# Load ground truth if available
|
| 393 |
-
ground_truth = None
|
| 394 |
-
if task == 'classification' and 'pirads' in patient_row:
|
| 395 |
-
ground_truth = int(patient_row['pirads']) - 2 # Convert to 0-3 range
|
| 396 |
-
elif task == 'segmentation' and 'lesion' in patient_row:
|
| 397 |
-
lesion_path = dataset_root / patient_row["lesion"]
|
| 398 |
-
ground_truth = read_nifti(str(lesion_path))
|
| 399 |
-
ground_truth = ground_truth > 0 # Binarize
|
| 400 |
-
|
| 401 |
-
return {
|
| 402 |
-
't2w': t2w,
|
| 403 |
-
'dwi': dwi,
|
| 404 |
-
'adc': adc,
|
| 405 |
-
'ground_truth': ground_truth,
|
| 406 |
-
'prostate_mask': prostate_mask,
|
| 407 |
-
'patient_info': patient_row
|
| 408 |
-
}
|
| 409 |
-
|
| 410 |
-
except Exception as e:
|
| 411 |
-
logger.error(f"Error loading raw patient data: {e}")
|
| 412 |
-
raise
|
| 413 |
|
| 414 |
|
| 415 |
def visualize_multimodal_results(preprocessed_data: Dict, prediction: torch.Tensor, task: str) -> plt.Figure:
|
|
@@ -1006,7 +944,9 @@ def run_classification_inference(patient_id: str, progress=gr.Progress()):
|
|
| 1006 |
logits = []
|
| 1007 |
model.eval()
|
| 1008 |
with torch.no_grad():
|
| 1009 |
-
for idx,
|
|
|
|
|
|
|
| 1010 |
img, gt = img.to(args.device), gt.to(args.device)
|
| 1011 |
logit = model(img)
|
| 1012 |
logits.append(logit)
|
|
@@ -1079,7 +1019,14 @@ def run_segmentation_inference(patient_id: str, progress=gr.Progress()):
|
|
| 1079 |
# Exact demo inference loop
|
| 1080 |
model.eval()
|
| 1081 |
with torch.no_grad():
|
| 1082 |
-
for idx,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1083 |
img, gt = img.to(args.device), gt.to(args.device)
|
| 1084 |
if args.sliding_window:
|
| 1085 |
pred = sliding_window_inference(
|
|
@@ -1178,346 +1125,111 @@ def run_segmentation_inference(patient_id: str, progress=gr.Progress()):
|
|
| 1178 |
return fig, empty_3d, f"Error: {str(e)}", gr.update(maximum=63, value=32)
|
| 1179 |
|
| 1180 |
|
| 1181 |
-
# 请用这个新版本完全替换掉你代码中旧的 get_preprocessed_patient_data 函数
|
| 1182 |
-
|
| 1183 |
def get_preprocessed_patient_data(patient_idx: int, task: str):
|
| 1184 |
"""
|
| 1185 |
-
Get preprocessed patient data that matches exactly what the model uses
|
| 1186 |
-
This
|
| 1187 |
-
as the lesion mask (label) by reusing the dataset's transform pipeline.
|
| 1188 |
"""
|
| 1189 |
try:
|
| 1190 |
logger.info(f"Loading preprocessed patient data for patient {patient_idx}, task {task}")
|
| 1191 |
|
| 1192 |
-
# Create dataloader to get the preprocessed data
|
| 1193 |
-
data_loader, args,
|
| 1194 |
|
| 1195 |
-
#
|
| 1196 |
-
|
| 1197 |
-
|
| 1198 |
-
|
| 1199 |
-
|
| 1200 |
-
|
| 1201 |
-
|
| 1202 |
-
|
| 1203 |
-
|
| 1204 |
-
|
| 1205 |
-
|
| 1206 |
-
|
| 1207 |
-
|
| 1208 |
-
|
| 1209 |
-
|
| 1210 |
-
|
| 1211 |
-
|
| 1212 |
-
|
| 1213 |
-
|
| 1214 |
-
# 分离出各个模态
|
| 1215 |
-
if img.shape[1] >= 3:
|
| 1216 |
-
preprocessed_data['t2w_preprocessed'] = img[0, 0]
|
| 1217 |
-
preprocessed_data['dwi_preprocessed'] = img[0, 1]
|
| 1218 |
-
preprocessed_data['adc_preprocessed'] = img[0, 2]
|
| 1219 |
-
else: # Fallback
|
| 1220 |
-
preprocessed_data['t2w_preprocessed'] = img[0, 0]
|
| 1221 |
-
preprocessed_data['dwi_preprocessed'] = img[0, 0]
|
| 1222 |
-
preprocessed_data['adc_preprocessed'] = img[0, 0]
|
| 1223 |
-
|
| 1224 |
-
# 处理真实标签 (lesion mask)
|
| 1225 |
-
if gt is not None and torch.is_tensor(gt) and gt.numel() > 0:
|
| 1226 |
-
gt_tensor = gt[0]
|
| 1227 |
-
if gt_tensor.ndim == 4: # [C, D, H, W]
|
| 1228 |
-
gt_tensor = gt_tensor[0] # [D, H, W]
|
| 1229 |
-
preprocessed_data['ground_truth_preprocessed'] = gt_tensor
|
| 1230 |
-
else:
|
| 1231 |
-
preprocessed_data['ground_truth_preprocessed'] = None
|
| 1232 |
-
|
| 1233 |
-
# --- 新的、更简洁的前列腺蒙版处理逻辑 ---
|
| 1234 |
-
preprocessed_data['prostate_mask_preprocessed'] = None
|
| 1235 |
-
if 'prostate_mask' in raw_data and raw_data['prostate_mask'] is not None:
|
| 1236 |
-
logger.info("Applying identical transforms to prostate mask as the lesion mask...")
|
| 1237 |
-
|
| 1238 |
-
# 准备一个和Dataset输出结构类似的字典,但把prostate_mask放在'label'键上
|
| 1239 |
-
# 'image'键也需要,因为某些空间变换可能需要参考图像的affine信息
|
| 1240 |
-
prostate_mask_input_dict = {
|
| 1241 |
-
'image': raw_data['t2w'], # 使用原始T2W作为参考图像
|
| 1242 |
-
'label': raw_data['prostate_mask']
|
| 1243 |
-
}
|
| 1244 |
-
|
| 1245 |
-
# 使用从Dataset中获取的完全相同的transforms管道
|
| 1246 |
-
transformed_dict = transforms(prostate_mask_input_dict)
|
| 1247 |
|
| 1248 |
-
#
|
| 1249 |
-
|
|
|
|
| 1250 |
|
| 1251 |
-
|
| 1252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1253 |
|
| 1254 |
-
|
| 1255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1256 |
|
| 1257 |
-
|
| 1258 |
-
|
| 1259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1260 |
else:
|
| 1261 |
-
|
| 1262 |
-
else:
|
| 1263 |
-
logger.warning("No raw prostate mask found.")
|
| 1264 |
|
| 1265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1266 |
return preprocessed_data
|
| 1267 |
|
| 1268 |
except Exception as e:
|
| 1269 |
logger.error(f"Error loading preprocessed patient data: {e}")
|
| 1270 |
-
import traceback
|
| 1271 |
-
logger.error(f"Full traceback: {traceback.format_exc()}")
|
| 1272 |
raise
|
| 1273 |
|
| 1274 |
-
# def apply_preprocessing_to_mask(prostate_mask: torch.Tensor, args, transforms) -> torch.Tensor:
|
| 1275 |
-
# """
|
| 1276 |
-
# Apply the same preprocessing transformations to prostate mask as used for image data
|
| 1277 |
-
# This ensures consistent spatial alignment between gland mask and lesion predictions
|
| 1278 |
-
# """
|
| 1279 |
-
# try:
|
| 1280 |
-
# logger.info("Applying consistent preprocessing to prostate mask...")
|
| 1281 |
-
# logger.info(f"Input mask: shape={prostate_mask.shape}, non-zero voxels={torch.count_nonzero(prostate_mask)}")
|
| 1282 |
-
|
| 1283 |
-
# # Create a dummy data dict similar to what the dataset would create
|
| 1284 |
-
# # We need to mimic the dataset structure for transforms to work
|
| 1285 |
-
# mask_data = {
|
| 1286 |
-
# 'image': prostate_mask.unsqueeze(0), # Add channel dimension [C, D, H, W]
|
| 1287 |
-
# 'label': prostate_mask.unsqueeze(0) # Use same mask as label for processing
|
| 1288 |
-
# }
|
| 1289 |
-
|
| 1290 |
-
# logger.info(f"Created data dict with shapes: image={mask_data['image'].shape}, label={mask_data['label'].shape}")
|
| 1291 |
-
|
| 1292 |
-
# # Apply the same transforms that are used for the actual data
|
| 1293 |
-
# try:
|
| 1294 |
-
# logger.info("Applying dataset transforms...")
|
| 1295 |
-
# transformed_data = transforms(mask_data)
|
| 1296 |
-
# preprocessed_mask = transformed_data['label'][0] # Remove channel dimension
|
| 1297 |
-
|
| 1298 |
-
# non_zero_after = torch.count_nonzero(preprocessed_mask)
|
| 1299 |
-
# non_zero_original = torch.count_nonzero(prostate_mask)
|
| 1300 |
-
# logger.info(f"Transform result: shape={preprocessed_mask.shape}, non-zero voxels={non_zero_after}")
|
| 1301 |
-
# logger.info(f"Original mask had {non_zero_original} non-zero voxels")
|
| 1302 |
-
|
| 1303 |
-
# # Check if transforms removed all data
|
| 1304 |
-
# if non_zero_after == 0 and non_zero_original > 0:
|
| 1305 |
-
# logger.warning("Transforms removed all mask data, falling back to manual preprocessing")
|
| 1306 |
-
# return apply_manual_preprocessing_to_mask(prostate_mask, args)
|
| 1307 |
-
|
| 1308 |
-
# # Ensure binary values
|
| 1309 |
-
# preprocessed_mask = (preprocessed_mask > 0.5).float()
|
| 1310 |
-
# final_non_zero = torch.count_nonzero(preprocessed_mask)
|
| 1311 |
-
|
| 1312 |
-
# logger.info(f"✅ Successfully applied preprocessing to prostate mask: shape={preprocessed_mask.shape}, non-zero voxels={final_non_zero}")
|
| 1313 |
-
# return preprocessed_mask
|
| 1314 |
-
|
| 1315 |
-
# except Exception as e:
|
| 1316 |
-
# logger.warning(f"Could not apply transforms to prostate mask: {e}")
|
| 1317 |
-
# # Fallback to manual preprocessing
|
| 1318 |
-
# logger.info("Falling back to manual preprocessing...")
|
| 1319 |
-
# return apply_manual_preprocessing_to_mask(prostate_mask, args)
|
| 1320 |
-
|
| 1321 |
-
# except Exception as e:
|
| 1322 |
-
# logger.error(f"Error applying preprocessing to prostate mask: {e}")
|
| 1323 |
-
# import traceback
|
| 1324 |
-
# logger.error(f"Full traceback: {traceback.format_exc()}")
|
| 1325 |
-
|
| 1326 |
-
# # Final fallback - just return the original mask if everything fails
|
| 1327 |
-
# logger.warning("Returning original mask as final fallback")
|
| 1328 |
-
# return prostate_mask
|
| 1329 |
-
|
| 1330 |
-
# def apply_preprocessing_to_mask(prostate_mask: torch.Tensor, args, transforms) -> torch.Tensor:
|
| 1331 |
-
# """
|
| 1332 |
-
# Apply ONLY SPATIAL preprocessing transformations to the prostate mask.
|
| 1333 |
-
# This ensures consistent spatial alignment without corrupting the mask's binary values.
|
| 1334 |
-
# """
|
| 1335 |
-
# try:
|
| 1336 |
-
# logger.info("Applying consistent SPATIAL preprocessing to prostate mask...")
|
| 1337 |
-
# logger.info(f"Input mask: shape={prostate_mask.shape}, non-zero voxels={torch.count_nonzero(prostate_mask)}")
|
| 1338 |
-
|
| 1339 |
-
# # Define MONAI spatial transforms to filter for
|
| 1340 |
-
|
| 1341 |
-
# # Filter for spatial transforms ONLY
|
| 1342 |
-
# spatial_transforms = [
|
| 1343 |
-
# t for t in transforms.transforms
|
| 1344 |
-
# if isinstance(t, (Spacing, CenterSpatialCrop, SpatialPad, Resize))
|
| 1345 |
-
# ]
|
| 1346 |
-
|
| 1347 |
-
# if not spatial_transforms:
|
| 1348 |
-
# logger.warning("No spatial transforms found in the pipeline. Falling back to manual preprocessing.")
|
| 1349 |
-
# return apply_manual_preprocessing_to_mask(prostate_mask, args)
|
| 1350 |
|
| 1351 |
-
#
|
| 1352 |
-
#
|
| 1353 |
-
#
|
| 1354 |
-
|
| 1355 |
-
# # Create a data dictionary. Use a key like 'label' or 'mask'
|
| 1356 |
-
# # to avoid intensity normalization which is often applied to 'image' key.
|
| 1357 |
-
# mask_data = {'label': prostate_mask.unsqueeze(0)} # Add channel dim: [1, D, H, W]
|
| 1358 |
-
|
| 1359 |
-
# # Apply the filtered spatial transformations
|
| 1360 |
-
# transformed_data = spatial_pipeline(mask_data)
|
| 1361 |
-
# preprocessed_mask = transformed_data['label'][0] # Remove channel dim
|
| 1362 |
-
|
| 1363 |
-
# # Ensure mask remains binary after interpolation
|
| 1364 |
-
# preprocessed_mask = (preprocessed_mask > 0.5).float()
|
| 1365 |
-
|
| 1366 |
-
# final_non_zero = torch.count_nonzero(preprocessed_mask)
|
| 1367 |
-
# logger.info(f"✅ Successfully applied SPATIAL preprocessing: shape={preprocessed_mask.shape}, non-zero voxels={final_non_zero}")
|
| 1368 |
-
|
| 1369 |
-
# if final_non_zero == 0 and torch.count_nonzero(prostate_mask) > 0:
|
| 1370 |
-
# logger.warning("Spatial transforms resulted in an empty mask. Trying manual fallback.")
|
| 1371 |
-
# return apply_manual_preprocessing_to_mask(prostate_mask, args)
|
| 1372 |
-
|
| 1373 |
-
# return preprocessed_mask
|
| 1374 |
-
|
| 1375 |
-
# except Exception as e:
|
| 1376 |
-
# logger.error(f"Error applying filtered spatial preprocessing: {e}. Falling back to manual method.")
|
| 1377 |
-
# import traceback
|
| 1378 |
-
# logger.error(f"Full traceback: {traceback.format_exc()}")
|
| 1379 |
-
# return apply_manual_preprocessing_to_mask(prostate_mask, args)
|
| 1380 |
-
|
| 1381 |
-
# def apply_manual_preprocessing_to_mask(prostate_mask: torch.Tensor, args) -> torch.Tensor:
|
| 1382 |
-
# """
|
| 1383 |
-
# Apply manual preprocessing to prostate mask to match the spatial dimensions
|
| 1384 |
-
# This includes the key transformations: resampling, cropping, and padding
|
| 1385 |
-
# """
|
| 1386 |
-
# try:
|
| 1387 |
-
# logger.info("Applying manual preprocessing to prostate mask...")
|
| 1388 |
-
# logger.info(f"Input mask: shape={prostate_mask.shape}, non-zero voxels={torch.count_nonzero(prostate_mask)}")
|
| 1389 |
-
|
| 1390 |
-
# # Add batch and channel dimensions for MONAI transforms
|
| 1391 |
-
# mask_tensor = prostate_mask.unsqueeze(0).unsqueeze(0) # [1, 1, D, H, W]
|
| 1392 |
-
# logger.info(f"After adding dimensions: {mask_tensor.shape}")
|
| 1393 |
-
|
| 1394 |
-
# # Apply spatial resampling to match the target spacing - be conservative with large masks
|
| 1395 |
-
# if hasattr(args, 'spacing') and args.spacing is not None:
|
| 1396 |
-
# logger.info(f"Applying spacing transform with spacing: {args.spacing}")
|
| 1397 |
-
# # For large original masks, use more conservative resampling
|
| 1398 |
-
# original_size = prostate_mask.shape
|
| 1399 |
-
# if max(original_size) > 300: # Large original mask
|
| 1400 |
-
# logger.info("Large original mask detected, using conservative resampling")
|
| 1401 |
-
# # Use linear interpolation for large masks to preserve more structure
|
| 1402 |
-
# spacing_transform = Spacing(pixdim=args.spacing, mode='bilinear')
|
| 1403 |
-
# else:
|
| 1404 |
-
# spacing_transform = Spacing(pixdim=args.spacing, mode='nearest')
|
| 1405 |
-
|
| 1406 |
-
# mask_tensor_before = mask_tensor.clone()
|
| 1407 |
-
# mask_tensor = spacing_transform(mask_tensor)
|
| 1408 |
-
|
| 1409 |
-
# logger.info(f"After spacing transform: {mask_tensor.shape}")
|
| 1410 |
-
# logger.info(f"Non-zero voxels after spacing: {torch.count_nonzero(mask_tensor)}")
|
| 1411 |
-
|
| 1412 |
-
# # Check if spacing transform removed all data
|
| 1413 |
-
# if torch.count_nonzero(mask_tensor) == 0 and torch.count_nonzero(mask_tensor_before) > 0:
|
| 1414 |
-
# logger.warning("Spacing transform removed all mask data, reverting")
|
| 1415 |
-
# mask_tensor = mask_tensor_before
|
| 1416 |
-
|
| 1417 |
-
# # Apply cropping to match the input size - be conservative to preserve mask content
|
| 1418 |
-
# if hasattr(args, 'crop_spatial_size') and args.crop_spatial_size is not None:
|
| 1419 |
-
# current_shape = mask_tensor.shape[2:] # [D, H, W]
|
| 1420 |
-
# target_shape = args.crop_spatial_size
|
| 1421 |
-
|
| 1422 |
-
# logger.info(f"Current shape after spacing: {current_shape}, target crop size: {target_shape}")
|
| 1423 |
-
|
| 1424 |
-
# # Check if we need to crop at all
|
| 1425 |
-
# needs_crop = any(c > t for c, t in zip(current_shape, target_shape))
|
| 1426 |
-
|
| 1427 |
-
# if needs_crop:
|
| 1428 |
-
# logger.info(f"Applying center crop to size: {target_shape}")
|
| 1429 |
-
|
| 1430 |
-
# # For large reductions, be more conservative
|
| 1431 |
-
# crop_size_ratios = [t/c for c, t in zip(current_shape, target_shape)]
|
| 1432 |
-
# min_ratio = min(crop_size_ratios)
|
| 1433 |
-
|
| 1434 |
-
# if min_ratio < 0.6: # Very aggressive crop (>40% reduction)
|
| 1435 |
-
# logger.warning(f"Very aggressive crop detected (min ratio: {min_ratio:.2f}), using conservative approach")
|
| 1436 |
-
# # Use a safer crop size that's not too aggressive
|
| 1437 |
-
# safe_crop_size = tuple(min(c, max(t, int(c * 0.7))) for c, t in zip(current_shape, target_shape))
|
| 1438 |
-
# logger.info(f"Using safer crop size: {safe_crop_size}")
|
| 1439 |
-
# crop_transform = CenterSpatialCrop(roi_size=safe_crop_size)
|
| 1440 |
-
# else:
|
| 1441 |
-
# crop_transform = CenterSpatialCrop(roi_size=target_shape)
|
| 1442 |
-
|
| 1443 |
-
# mask_tensor_before = mask_tensor.clone()
|
| 1444 |
-
# mask_tensor = crop_transform(mask_tensor)
|
| 1445 |
-
|
| 1446 |
-
# logger.info(f"After crop: {mask_tensor.shape}")
|
| 1447 |
-
# logger.info(f"Non-zero voxels after crop: {torch.count_nonzero(mask_tensor)}")
|
| 1448 |
-
|
| 1449 |
-
# # Check if cropping removed all data
|
| 1450 |
-
# if torch.count_nonzero(mask_tensor) == 0 and torch.count_nonzero(mask_tensor_before) > 0:
|
| 1451 |
-
# logger.warning("Crop removed all mask data, skipping crop")
|
| 1452 |
-
# mask_tensor = mask_tensor_before
|
| 1453 |
-
# else:
|
| 1454 |
-
# logger.info("No cropping needed, current size is within target")
|
| 1455 |
-
|
| 1456 |
-
# # Apply padding if needed to match exact target size
|
| 1457 |
-
# if hasattr(args, 'crop_spatial_size') and args.crop_spatial_size is not None:
|
| 1458 |
-
# current_shape = mask_tensor.shape[2:] # [D, H, W]
|
| 1459 |
-
# target_shape = args.crop_spatial_size
|
| 1460 |
-
|
| 1461 |
-
# if current_shape != target_shape:
|
| 1462 |
-
# logger.info(f"Applying padding from {current_shape} to {target_shape}")
|
| 1463 |
-
# pad_transform = SpatialPad(spatial_size=target_shape, mode='constant')
|
| 1464 |
-
# mask_tensor = pad_transform(mask_tensor)
|
| 1465 |
-
# logger.info(f"After padding: {mask_tensor.shape}")
|
| 1466 |
-
# logger.info(f"Non-zero voxels after padding: {torch.count_nonzero(mask_tensor)}")
|
| 1467 |
-
|
| 1468 |
-
# # Remove batch and channel dimensions
|
| 1469 |
-
# preprocessed_mask = mask_tensor[0, 0] # [D, H, W]
|
| 1470 |
-
|
| 1471 |
-
# # Ensure binary values (thresholding at 0.5)
|
| 1472 |
-
# preprocessed_mask = (preprocessed_mask > 0.5).float()
|
| 1473 |
-
|
| 1474 |
-
# logger.info(f"Manual preprocessing completed: shape={preprocessed_mask.shape}, non-zero voxels={torch.count_nonzero(preprocessed_mask)}")
|
| 1475 |
-
|
| 1476 |
-
# return preprocessed_mask
|
| 1477 |
-
|
| 1478 |
-
# except Exception as e:
|
| 1479 |
-
# logger.error(f"Error in manual preprocessing: {e}")
|
| 1480 |
-
# import traceback
|
| 1481 |
-
# logger.error(f"Full traceback: {traceback.format_exc()}")
|
| 1482 |
-
|
| 1483 |
-
# # Final fallback - simple resize
|
| 1484 |
-
# logger.warning("Falling back to simple resize")
|
| 1485 |
-
# target_shape = args.crop_spatial_size if hasattr(args, 'crop_spatial_size') else (64, 224, 224)
|
| 1486 |
-
|
| 1487 |
-
# try:
|
| 1488 |
-
# if prostate_mask.shape != target_shape:
|
| 1489 |
-
# logger.info(f"Simple resize from {prostate_mask.shape} to {target_shape}")
|
| 1490 |
-
# zoom_factors = tuple(target_shape[i] / prostate_mask.shape[i] for i in range(3))
|
| 1491 |
-
# logger.info(f"Zoom factors: {zoom_factors}")
|
| 1492 |
-
|
| 1493 |
-
# # Use order=1 (linear) for better preservation of structures, then threshold
|
| 1494 |
-
# mask_np = prostate_mask.cpu().numpy().astype(np.float32)
|
| 1495 |
-
# resized_mask = ndimage.zoom(mask_np, zoom_factors, order=1, prefilter=False)
|
| 1496 |
-
|
| 1497 |
-
# # Apply threshold to maintain binary nature
|
| 1498 |
-
# resized_mask = (resized_mask > 0.3).astype(np.float32) # Lower threshold to preserve more
|
| 1499 |
-
|
| 1500 |
-
# result = torch.from_numpy(resized_mask)
|
| 1501 |
-
|
| 1502 |
-
# logger.info(f"Simple resize result: shape={result.shape}, non-zero voxels={torch.count_nonzero(result)}")
|
| 1503 |
-
|
| 1504 |
-
# if torch.count_nonzero(result) == 0:
|
| 1505 |
-
# logger.warning("Simple resize resulted in empty mask, trying with very low threshold")
|
| 1506 |
-
# # Try with an even lower threshold
|
| 1507 |
-
# resized_mask_low = ndimage.zoom(mask_np, zoom_factors, order=1, prefilter=False)
|
| 1508 |
-
# resized_mask_low = (resized_mask_low > 0.1).astype(np.float32)
|
| 1509 |
-
# result = torch.from_numpy(resized_mask_low)
|
| 1510 |
-
# logger.info(f"Low threshold result: non-zero voxels={torch.count_nonzero(result)}")
|
| 1511 |
-
|
| 1512 |
-
# return result
|
| 1513 |
-
# else:
|
| 1514 |
-
# # Ensure binary
|
| 1515 |
-
# result = (prostate_mask > 0.5).float()
|
| 1516 |
-
# return result
|
| 1517 |
-
# except Exception as final_error:
|
| 1518 |
-
# logger.error(f"Simple resize failed: {final_error}")
|
| 1519 |
-
# # Return empty mask if everything fails
|
| 1520 |
-
# return torch.zeros(target_shape, dtype=torch.float32)
|
| 1521 |
|
| 1522 |
|
| 1523 |
def create_interface():
|
|
|
|
| 346 |
raise
|
| 347 |
|
| 348 |
|
| 349 |
+
# Removed get_raw_patient_data function as it's no longer needed
|
| 350 |
+
# All data is now processed through the unified preprocessing pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
|
| 352 |
|
| 353 |
def visualize_multimodal_results(preprocessed_data: Dict, prediction: torch.Tensor, task: str) -> plt.Figure:
|
|
|
|
| 944 |
logits = []
|
| 945 |
model.eval()
|
| 946 |
with torch.no_grad():
|
| 947 |
+
for idx, data in enumerate(data_loader):
|
| 948 |
+
# Classification only returns (img, gt, pid)
|
| 949 |
+
img, gt, pid = data
|
| 950 |
img, gt = img.to(args.device), gt.to(args.device)
|
| 951 |
logit = model(img)
|
| 952 |
logits.append(logit)
|
|
|
|
| 1019 |
# Exact demo inference loop
|
| 1020 |
model.eval()
|
| 1021 |
with torch.no_grad():
|
| 1022 |
+
for idx, data in enumerate(data_loader):
|
| 1023 |
+
# Segmentation returns (img, gt, pid, gland) or (img, gt, pid)
|
| 1024 |
+
if len(data) == 4:
|
| 1025 |
+
img, gt, pid, gland = data
|
| 1026 |
+
else:
|
| 1027 |
+
img, gt, pid = data
|
| 1028 |
+
gland = None
|
| 1029 |
+
|
| 1030 |
img, gt = img.to(args.device), gt.to(args.device)
|
| 1031 |
if args.sliding_window:
|
| 1032 |
pred = sliding_window_inference(
|
|
|
|
| 1125 |
return fig, empty_3d, f"Error: {str(e)}", gr.update(maximum=63, value=32)
|
| 1126 |
|
| 1127 |
|
|
|
|
|
|
|
| 1128 |
def get_preprocessed_patient_data(patient_idx: int, task: str):
|
| 1129 |
"""
|
| 1130 |
+
Get preprocessed patient data that matches exactly what the model uses
|
| 1131 |
+
This ensures spatial consistency between displayed images and predictions
|
|
|
|
| 1132 |
"""
|
| 1133 |
try:
|
| 1134 |
logger.info(f"Loading preprocessed patient data for patient {patient_idx}, task {task}")
|
| 1135 |
|
| 1136 |
+
# Create dataloader to get the exact preprocessed data that the model uses
|
| 1137 |
+
data_loader, args, dataset = create_single_sample_dataloader(patient_idx, task)
|
| 1138 |
|
| 1139 |
+
# Extract the preprocessed data from the dataloader
|
| 1140 |
+
for idx, data in enumerate(data_loader):
|
| 1141 |
+
# Handle different return formats based on task
|
| 1142 |
+
if task == 'classification':
|
| 1143 |
+
# Classification returns: (img, gt, pid)
|
| 1144 |
+
img, gt, pid = data
|
| 1145 |
+
gland = None
|
| 1146 |
+
else:
|
| 1147 |
+
# Segmentation returns: (img, gt, pid, gland)
|
| 1148 |
+
if len(data) == 4:
|
| 1149 |
+
img, gt, pid, gland = data
|
| 1150 |
+
# Remove batch dimension from gland if it exists
|
| 1151 |
+
if gland is not None and len(gland) > 0:
|
| 1152 |
+
gland = gland[0] # Remove batch dimension
|
| 1153 |
+
else:
|
| 1154 |
+
# Fallback for old format
|
| 1155 |
+
img, gt, pid = data
|
| 1156 |
+
gland = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1157 |
|
| 1158 |
+
# img is the preprocessed multi-modal image [B, C, D, H, W]
|
| 1159 |
+
# gt is the preprocessed ground truth [B, C, D, H, W] or [B, D, H, W]
|
| 1160 |
+
# gland is the preprocessed prostate mask [B, C, D, H, W] or [B, D, H, W] or None
|
| 1161 |
|
| 1162 |
+
preprocessed_data = {
|
| 1163 |
+
'preprocessed_image': img[0], # Remove batch dimension [C, D, H, W]
|
| 1164 |
+
'preprocessed_gt': gt[0] if gt is not None else None, # Remove batch dimension
|
| 1165 |
+
'patient_id': pid[0] if isinstance(pid, (list, tuple)) else pid,
|
| 1166 |
+
'spatial_shape': img.shape[2:], # [D, H, W]
|
| 1167 |
+
'args': args
|
| 1168 |
+
}
|
| 1169 |
|
| 1170 |
+
# Extract individual modalities from the preprocessed image
|
| 1171 |
+
# Assuming the model input has 3 channels: [T2W, DWI, ADC]
|
| 1172 |
+
if img.shape[1] >= 3: # Check if we have at least 3 channels
|
| 1173 |
+
preprocessed_data['t2w_preprocessed'] = img[0, 0] # [D, H, W]
|
| 1174 |
+
preprocessed_data['dwi_preprocessed'] = img[0, 1] # [D, H, W]
|
| 1175 |
+
preprocessed_data['adc_preprocessed'] = img[0, 2] # [D, H, W]
|
| 1176 |
+
else:
|
| 1177 |
+
logger.warning(f"Unexpected number of input channels: {img.shape[1]}")
|
| 1178 |
+
# Fallback: use the same channel for all modalities
|
| 1179 |
+
preprocessed_data['t2w_preprocessed'] = img[0, 0]
|
| 1180 |
+
preprocessed_data['dwi_preprocessed'] = img[0, 0]
|
| 1181 |
+
preprocessed_data['adc_preprocessed'] = img[0, 0]
|
| 1182 |
|
| 1183 |
+
# Convert ground truth to proper format
|
| 1184 |
+
if gt is not None:
|
| 1185 |
+
if len(gt.shape) == 4: # [B, D, H, W]
|
| 1186 |
+
preprocessed_data['ground_truth_preprocessed'] = gt[0] # [D, H, W]
|
| 1187 |
+
elif len(gt.shape) == 5: # [B, C, D, H, W]
|
| 1188 |
+
preprocessed_data['ground_truth_preprocessed'] = gt[0, 0] # [D, H, W]
|
| 1189 |
+
else:
|
| 1190 |
+
logger.warning(f"Unexpected ground truth shape: {gt.shape}")
|
| 1191 |
+
preprocessed_data['ground_truth_preprocessed'] = None
|
| 1192 |
else:
|
| 1193 |
+
preprocessed_data['ground_truth_preprocessed'] = None
|
|
|
|
|
|
|
| 1194 |
|
| 1195 |
+
# Handle prostate gland mask - now from preprocessed data
|
| 1196 |
+
if gland is not None:
|
| 1197 |
+
# Convert gland mask to proper format
|
| 1198 |
+
if len(gland.shape) == 3: # [D, H, W]
|
| 1199 |
+
preprocessed_gland = gland
|
| 1200 |
+
elif len(gland.shape) == 4: # [C, D, H, W]
|
| 1201 |
+
preprocessed_gland = gland[0] # [D, H, W]
|
| 1202 |
+
else:
|
| 1203 |
+
logger.warning(f"Unexpected gland mask shape: {gland.shape}")
|
| 1204 |
+
preprocessed_gland = None
|
| 1205 |
+
|
| 1206 |
+
if preprocessed_gland is not None:
|
| 1207 |
+
# Ensure binary values
|
| 1208 |
+
preprocessed_gland = (preprocessed_gland > 0.5).float()
|
| 1209 |
+
non_zero_voxels = torch.count_nonzero(preprocessed_gland)
|
| 1210 |
+
|
| 1211 |
+
logger.info(f"✅ Preprocessed prostate gland mask: shape={preprocessed_gland.shape}, non-zero voxels={non_zero_voxels}")
|
| 1212 |
+
preprocessed_data['prostate_mask_preprocessed'] = preprocessed_gland if non_zero_voxels > 0 else None
|
| 1213 |
+
else:
|
| 1214 |
+
logger.warning("Could not process gland mask")
|
| 1215 |
+
preprocessed_data['prostate_mask_preprocessed'] = None
|
| 1216 |
+
else:
|
| 1217 |
+
logger.info("No prostate gland mask found in preprocessed data")
|
| 1218 |
+
preprocessed_data['prostate_mask_preprocessed'] = None
|
| 1219 |
+
|
| 1220 |
+
break # Only process the first (and only) sample
|
| 1221 |
+
|
| 1222 |
+
logger.info(f"Successfully loaded preprocessed data with spatial shape: {preprocessed_data['spatial_shape']}")
|
| 1223 |
return preprocessed_data
|
| 1224 |
|
| 1225 |
except Exception as e:
|
| 1226 |
logger.error(f"Error loading preprocessed patient data: {e}")
|
|
|
|
|
|
|
| 1227 |
raise
|
| 1228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1229 |
|
| 1230 |
+
# Removed manual prostate mask preprocessing functions
|
| 1231 |
+
# These are no longer needed as the gland mask is now processed
|
| 1232 |
+
# together with image and lesion data in the same pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1233 |
|
| 1234 |
|
| 1235 |
def create_interface():
|
dataset/dataset_seg.py
CHANGED
|
@@ -73,14 +73,44 @@ class UCLSet(BaseVolumeDataset):
|
|
| 73 |
img = torch.stack([t2w, dwi, adc], 0)
|
| 74 |
seg = self.read(path["lesion"]).unsqueeze(0)
|
| 75 |
seg = seg > 0
|
| 76 |
-
|
| 77 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
if self.transforms is not None:
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
if type(trans_dict) == list:
|
| 81 |
trans_dict = trans_dict[0]
|
|
|
|
| 82 |
img, seg = trans_dict["image"], trans_dict["label"]
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
# TODO: need to update; unfinished
|
| 86 |
"""
|
|
@@ -204,22 +234,25 @@ def get_transforms(args):
|
|
| 204 |
train_transforms = [
|
| 205 |
NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
|
| 206 |
RandRotated(
|
| 207 |
-
keys=["image", "label"],
|
| 208 |
prob=0.3,
|
| 209 |
range_x=30 / 180 * np.pi,
|
| 210 |
keep_size=False,
|
| 211 |
-
mode=["bilinear", "nearest"],
|
|
|
|
| 212 |
),
|
| 213 |
RandZoomd(
|
| 214 |
-
keys=["image", "label"],
|
| 215 |
prob=0.3,
|
| 216 |
min_zoom=[1, 0.9, 0.9],
|
| 217 |
max_zoom=[1, 1.1, 1.1],
|
| 218 |
-
mode=["trilinear", "nearest"],
|
|
|
|
| 219 |
),
|
| 220 |
SpatialPadd(
|
| 221 |
-
keys=["image", "label"],
|
| 222 |
spatial_size=[round(i * 1.2) for i in args.crop_spatial_size],
|
|
|
|
| 223 |
),
|
| 224 |
# RandCropByPosNegLabeld(
|
| 225 |
# keys=["image", "label"],
|
|
@@ -230,11 +263,12 @@ def get_transforms(args):
|
|
| 230 |
# num_samples=1,
|
| 231 |
# ),
|
| 232 |
RandSpatialCropd(
|
| 233 |
-
keys=["image", "label"],
|
| 234 |
roi_size=args.crop_spatial_size,
|
| 235 |
random_size=False,
|
|
|
|
| 236 |
),
|
| 237 |
-
RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
|
| 238 |
# BinarizeLabeld(keys=["label"])
|
| 239 |
RandScaleIntensityd(keys="image", factors=0.1, prob=0.8),
|
| 240 |
RandShiftIntensityd(keys="image", offsets=0.1, prob=0.8),
|
|
@@ -247,11 +281,13 @@ def get_transforms(args):
|
|
| 247 |
[
|
| 248 |
NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
|
| 249 |
CenterSpatialCropd(
|
| 250 |
-
keys=["image", "label"], roi_size=args.crop_spatial_size
|
|
|
|
| 251 |
),
|
| 252 |
SpatialPadd(
|
| 253 |
-
keys=["image", "label"],
|
| 254 |
spatial_size=[i for i in args.crop_spatial_size],
|
|
|
|
| 255 |
),
|
| 256 |
# BinarizeLabeld(keys=["label"])
|
| 257 |
]
|
|
@@ -260,11 +296,13 @@ def get_transforms(args):
|
|
| 260 |
[
|
| 261 |
NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
|
| 262 |
CenterSpatialCropd(
|
| 263 |
-
keys=["image", "label"], roi_size=args.crop_spatial_size
|
|
|
|
| 264 |
),
|
| 265 |
SpatialPadd(
|
| 266 |
-
keys=["image", "label"],
|
| 267 |
spatial_size=[i for i in args.crop_spatial_size],
|
|
|
|
| 268 |
),
|
| 269 |
# BinarizeLabeld(keys=["label"])
|
| 270 |
]
|
|
|
|
| 73 |
img = torch.stack([t2w, dwi, adc], 0)
|
| 74 |
seg = self.read(path["lesion"]).unsqueeze(0)
|
| 75 |
seg = seg > 0
|
| 76 |
+
|
| 77 |
+
# Load prostate gland mask if available
|
| 78 |
+
gland = None
|
| 79 |
+
try:
|
| 80 |
+
# Try to find prostate_mask.nii.gz in the same directory as t2w
|
| 81 |
+
t2w_path = os.path.join(self.root, path["t2w"])
|
| 82 |
+
prostate_mask_path = os.path.join(os.path.dirname(t2w_path), "prostate_mask.nii.gz")
|
| 83 |
+
if os.path.exists(prostate_mask_path):
|
| 84 |
+
# Read prostate mask relative to the t2w directory
|
| 85 |
+
relative_mask_path = os.path.relpath(prostate_mask_path, self.root)
|
| 86 |
+
gland = self.read(relative_mask_path).unsqueeze(0)
|
| 87 |
+
gland = gland > 0 # Binarize
|
| 88 |
+
print(f"Loaded prostate mask: {relative_mask_path}, shape: {gland.shape}, non-zero: {torch.count_nonzero(gland)}")
|
| 89 |
+
else:
|
| 90 |
+
print(f"Prostate mask not found at: {prostate_mask_path}")
|
| 91 |
+
except Exception as e:
|
| 92 |
+
print(f"Failed to load prostate mask for {path['t2w']}: {e}")
|
| 93 |
+
|
| 94 |
if self.transforms is not None:
|
| 95 |
+
# Include gland in transforms if available
|
| 96 |
+
if gland is not None:
|
| 97 |
+
trans_dict = self.transforms({"image": img, "label": seg, "gland": gland})
|
| 98 |
+
else:
|
| 99 |
+
trans_dict = self.transforms({"image": img, "label": seg})
|
| 100 |
+
|
| 101 |
if type(trans_dict) == list:
|
| 102 |
trans_dict = trans_dict[0]
|
| 103 |
+
|
| 104 |
img, seg = trans_dict["image"], trans_dict["label"]
|
| 105 |
+
|
| 106 |
+
# Extract processed gland if it was included
|
| 107 |
+
if gland is not None and "gland" in trans_dict:
|
| 108 |
+
gland = trans_dict["gland"]
|
| 109 |
+
else:
|
| 110 |
+
gland = None
|
| 111 |
+
|
| 112 |
+
# Return gland as part of the data (we'll modify the app to handle this)
|
| 113 |
+
return img, seg, torch.tensor(idx, dtype=torch.long), gland
|
| 114 |
|
| 115 |
# TODO: need to update; unfinished
|
| 116 |
"""
|
|
|
|
| 234 |
train_transforms = [
|
| 235 |
NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
|
| 236 |
RandRotated(
|
| 237 |
+
keys=["image", "label", "gland"],
|
| 238 |
prob=0.3,
|
| 239 |
range_x=30 / 180 * np.pi,
|
| 240 |
keep_size=False,
|
| 241 |
+
mode=["bilinear", "nearest", "nearest"],
|
| 242 |
+
allow_missing_keys=True,
|
| 243 |
),
|
| 244 |
RandZoomd(
|
| 245 |
+
keys=["image", "label", "gland"],
|
| 246 |
prob=0.3,
|
| 247 |
min_zoom=[1, 0.9, 0.9],
|
| 248 |
max_zoom=[1, 1.1, 1.1],
|
| 249 |
+
mode=["trilinear", "nearest", "nearest"],
|
| 250 |
+
allow_missing_keys=True,
|
| 251 |
),
|
| 252 |
SpatialPadd(
|
| 253 |
+
keys=["image", "label", "gland"],
|
| 254 |
spatial_size=[round(i * 1.2) for i in args.crop_spatial_size],
|
| 255 |
+
allow_missing_keys=True,
|
| 256 |
),
|
| 257 |
# RandCropByPosNegLabeld(
|
| 258 |
# keys=["image", "label"],
|
|
|
|
| 263 |
# num_samples=1,
|
| 264 |
# ),
|
| 265 |
RandSpatialCropd(
|
| 266 |
+
keys=["image", "label", "gland"],
|
| 267 |
roi_size=args.crop_spatial_size,
|
| 268 |
random_size=False,
|
| 269 |
+
allow_missing_keys=True,
|
| 270 |
),
|
| 271 |
+
RandFlipd(keys=["image", "label", "gland"], prob=0.5, spatial_axis=2, allow_missing_keys=True),
|
| 272 |
# BinarizeLabeld(keys=["label"])
|
| 273 |
RandScaleIntensityd(keys="image", factors=0.1, prob=0.8),
|
| 274 |
RandShiftIntensityd(keys="image", offsets=0.1, prob=0.8),
|
|
|
|
| 281 |
[
|
| 282 |
NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
|
| 283 |
CenterSpatialCropd(
|
| 284 |
+
keys=["image", "label", "gland"], roi_size=args.crop_spatial_size,
|
| 285 |
+
allow_missing_keys=True,
|
| 286 |
),
|
| 287 |
SpatialPadd(
|
| 288 |
+
keys=["image", "label", "gland"],
|
| 289 |
spatial_size=[i for i in args.crop_spatial_size],
|
| 290 |
+
allow_missing_keys=True,
|
| 291 |
),
|
| 292 |
# BinarizeLabeld(keys=["label"])
|
| 293 |
]
|
|
|
|
| 296 |
[
|
| 297 |
NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
|
| 298 |
CenterSpatialCropd(
|
| 299 |
+
keys=["image", "label", "gland"], roi_size=args.crop_spatial_size,
|
| 300 |
+
allow_missing_keys=True,
|
| 301 |
),
|
| 302 |
SpatialPadd(
|
| 303 |
+
keys=["image", "label", "gland"],
|
| 304 |
spatial_size=[i for i in args.crop_spatial_size],
|
| 305 |
+
allow_missing_keys=True,
|
| 306 |
),
|
| 307 |
# BinarizeLabeld(keys=["label"])
|
| 308 |
]
|