Anonymise committed on
Commit
0da921e
·
1 Parent(s): 2098a50

fix visualization bug

Browse files
Files changed (2) hide show
  1. app.py +98 -386
  2. dataset/dataset_seg.py +53 -15
app.py CHANGED
@@ -346,70 +346,8 @@ def create_single_sample_dataloader(patient_idx: int, task: str):
346
  raise
347
 
348
 
349
- def get_raw_patient_data(patient_idx: int, task: str):
350
- """Get raw patient data for visualization"""
351
- try:
352
- sample_patients, dataset_root = load_sample_data()
353
-
354
- if task == 'classification':
355
- if 'risk' not in sample_patients:
356
- raise ValueError("No risk assessment data found")
357
- patient_row = sample_patients['risk'].iloc[patient_idx]
358
- # Risk dataset uses 'highb' instead of 'dwi'
359
- t2w_path = dataset_root / patient_row["t2w"]
360
- dwi_path = dataset_root / patient_row["highb"]
361
- adc_path = dataset_root / patient_row["adc"]
362
- else:
363
- if 'UCL' not in sample_patients:
364
- raise ValueError("No UCL segmentation data found")
365
- patient_row = sample_patients['UCL'].iloc[patient_idx]
366
- # UCL dataset uses 'dwi'
367
- t2w_path = dataset_root / patient_row["t2w"]
368
- dwi_path = dataset_root / patient_row["dwi"]
369
- adc_path = dataset_root / patient_row["adc"]
370
-
371
- # Read raw data for visualization (without transforms)
372
- spatial_index = [2, 1, 0] # DHW format as in dataset
373
-
374
- def read_nifti(path):
375
- vol = nib.load(path)
376
- vol = vol.get_fdata().astype(np.float32).transpose(spatial_index)
377
- return torch.from_numpy(vol)
378
-
379
- t2w = read_nifti(str(t2w_path))
380
- dwi = read_nifti(str(dwi_path))
381
- adc = read_nifti(str(adc_path))
382
-
383
- # Load prostate mask from existing file
384
- prostate_mask = None
385
- prostate_mask_path = t2w_path.parent / "prostate_mask.nii.gz"
386
- if prostate_mask_path.exists():
387
- prostate_mask = read_nifti(str(prostate_mask_path))
388
- logger.info(f"Loaded prostate mask from {prostate_mask_path}")
389
- else:
390
- logger.warning(f"Prostate mask not found at {prostate_mask_path}")
391
-
392
- # Load ground truth if available
393
- ground_truth = None
394
- if task == 'classification' and 'pirads' in patient_row:
395
- ground_truth = int(patient_row['pirads']) - 2 # Convert to 0-3 range
396
- elif task == 'segmentation' and 'lesion' in patient_row:
397
- lesion_path = dataset_root / patient_row["lesion"]
398
- ground_truth = read_nifti(str(lesion_path))
399
- ground_truth = ground_truth > 0 # Binarize
400
-
401
- return {
402
- 't2w': t2w,
403
- 'dwi': dwi,
404
- 'adc': adc,
405
- 'ground_truth': ground_truth,
406
- 'prostate_mask': prostate_mask,
407
- 'patient_info': patient_row
408
- }
409
-
410
- except Exception as e:
411
- logger.error(f"Error loading raw patient data: {e}")
412
- raise
413
 
414
 
415
  def visualize_multimodal_results(preprocessed_data: Dict, prediction: torch.Tensor, task: str) -> plt.Figure:
@@ -1006,7 +944,9 @@ def run_classification_inference(patient_id: str, progress=gr.Progress()):
1006
  logits = []
1007
  model.eval()
1008
  with torch.no_grad():
1009
- for idx, (img, gt, pid) in enumerate(data_loader):
 
 
1010
  img, gt = img.to(args.device), gt.to(args.device)
1011
  logit = model(img)
1012
  logits.append(logit)
@@ -1079,7 +1019,14 @@ def run_segmentation_inference(patient_id: str, progress=gr.Progress()):
1079
  # Exact demo inference loop
1080
  model.eval()
1081
  with torch.no_grad():
1082
- for idx, (img, gt, pid) in enumerate(data_loader):
 
 
 
 
 
 
 
1083
  img, gt = img.to(args.device), gt.to(args.device)
1084
  if args.sliding_window:
1085
  pred = sliding_window_inference(
@@ -1178,346 +1125,111 @@ def run_segmentation_inference(patient_id: str, progress=gr.Progress()):
1178
  return fig, empty_3d, f"Error: {str(e)}", gr.update(maximum=63, value=32)
1179
 
1180
 
1181
- # 请用这个新版本完全替换掉你代码中旧的 get_preprocessed_patient_data 函数
1182
-
1183
  def get_preprocessed_patient_data(patient_idx: int, task: str):
1184
  """
1185
- Get preprocessed patient data that matches exactly what the model uses.
1186
- This new version ensures the prostate mask undergoes the IDENTICAL transform
1187
- as the lesion mask (label) by reusing the dataset's transform pipeline.
1188
  """
1189
  try:
1190
  logger.info(f"Loading preprocessed patient data for patient {patient_idx}, task {task}")
1191
 
1192
- # Create dataloader to get the preprocessed data and the transform pipeline
1193
- data_loader, args, dataset_subset = create_single_sample_dataloader(patient_idx, task)
1194
 
1195
- # Get the original dataset and its transformations
1196
- # Note: data_loader wraps a Subset, which wraps the original Dataset (UCLSet)
1197
- original_dataset = dataset_subset.dataset
1198
- transforms = original_dataset.transform
1199
-
1200
- # Also get the original raw data to access the raw prostate mask
1201
- raw_data = get_raw_patient_data(patient_idx, task)
1202
-
1203
- # --- 主逻辑:从dataloader中提取模型输入和真实标签 ---
1204
- preprocessed_data_from_loader = next(iter(data_loader))
1205
- img, gt, pid = preprocessed_data_from_loader
1206
-
1207
- preprocessed_data = {
1208
- 'preprocessed_image': img[0],
1209
- 'patient_id': pid[0] if isinstance(pid, (list, tuple)) else pid,
1210
- 'spatial_shape': img.shape[2:],
1211
- 'args': args
1212
- }
1213
-
1214
- # 分离出各个模态
1215
- if img.shape[1] >= 3:
1216
- preprocessed_data['t2w_preprocessed'] = img[0, 0]
1217
- preprocessed_data['dwi_preprocessed'] = img[0, 1]
1218
- preprocessed_data['adc_preprocessed'] = img[0, 2]
1219
- else: # Fallback
1220
- preprocessed_data['t2w_preprocessed'] = img[0, 0]
1221
- preprocessed_data['dwi_preprocessed'] = img[0, 0]
1222
- preprocessed_data['adc_preprocessed'] = img[0, 0]
1223
-
1224
- # 处理真实标签 (lesion mask)
1225
- if gt is not None and torch.is_tensor(gt) and gt.numel() > 0:
1226
- gt_tensor = gt[0]
1227
- if gt_tensor.ndim == 4: # [C, D, H, W]
1228
- gt_tensor = gt_tensor[0] # [D, H, W]
1229
- preprocessed_data['ground_truth_preprocessed'] = gt_tensor
1230
- else:
1231
- preprocessed_data['ground_truth_preprocessed'] = None
1232
-
1233
- # --- 新的、更简洁的前列腺蒙版处理逻辑 ---
1234
- preprocessed_data['prostate_mask_preprocessed'] = None
1235
- if 'prostate_mask' in raw_data and raw_data['prostate_mask'] is not None:
1236
- logger.info("Applying identical transforms to prostate mask as the lesion mask...")
1237
-
1238
- # 准备一个和Dataset输出结构类似的字典,但把prostate_mask放在'label'键上
1239
- # 'image'键也需要,因为某些空间变换可能需要参考图像的affine信息
1240
- prostate_mask_input_dict = {
1241
- 'image': raw_data['t2w'], # 使用原始T2W作为参考图像
1242
- 'label': raw_data['prostate_mask']
1243
- }
1244
-
1245
- # 使用从Dataset中获取的完全相同的transforms管道
1246
- transformed_dict = transforms(prostate_mask_input_dict)
1247
 
1248
- # 提取处理后的蒙版
1249
- processed_mask = transformed_dict['label']
 
1250
 
1251
- # 确保二值化
1252
- processed_mask = (processed_mask > 0.5).float()
 
 
 
 
 
1253
 
1254
- non_zero_after = torch.count_nonzero(processed_mask)
1255
- logger.info(f"Prostate mask after preprocessing: shape={processed_mask.shape}, non-zero voxels={non_zero_after}")
 
 
 
 
 
 
 
 
 
 
1256
 
1257
- if non_zero_after > 0:
1258
- preprocessed_data['prostate_mask_preprocessed'] = processed_mask
1259
- logger.info("✅ Successfully applied consistent preprocessing to prostate mask.")
 
 
 
 
 
 
1260
  else:
1261
- logger.warning("⚠️ Preprocessing still resulted in an empty prostate mask, even with the correct pipeline.")
1262
- else:
1263
- logger.warning("No raw prostate mask found.")
1264
 
1265
- logger.info(f"Successfully loaded and processed data with spatial shape: {preprocessed_data['spatial_shape']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1266
  return preprocessed_data
1267
 
1268
  except Exception as e:
1269
  logger.error(f"Error loading preprocessed patient data: {e}")
1270
- import traceback
1271
- logger.error(f"Full traceback: {traceback.format_exc()}")
1272
  raise
1273
 
1274
- # def apply_preprocessing_to_mask(prostate_mask: torch.Tensor, args, transforms) -> torch.Tensor:
1275
- # """
1276
- # Apply the same preprocessing transformations to prostate mask as used for image data
1277
- # This ensures consistent spatial alignment between gland mask and lesion predictions
1278
- # """
1279
- # try:
1280
- # logger.info("Applying consistent preprocessing to prostate mask...")
1281
- # logger.info(f"Input mask: shape={prostate_mask.shape}, non-zero voxels={torch.count_nonzero(prostate_mask)}")
1282
-
1283
- # # Create a dummy data dict similar to what the dataset would create
1284
- # # We need to mimic the dataset structure for transforms to work
1285
- # mask_data = {
1286
- # 'image': prostate_mask.unsqueeze(0), # Add channel dimension [C, D, H, W]
1287
- # 'label': prostate_mask.unsqueeze(0) # Use same mask as label for processing
1288
- # }
1289
-
1290
- # logger.info(f"Created data dict with shapes: image={mask_data['image'].shape}, label={mask_data['label'].shape}")
1291
-
1292
- # # Apply the same transforms that are used for the actual data
1293
- # try:
1294
- # logger.info("Applying dataset transforms...")
1295
- # transformed_data = transforms(mask_data)
1296
- # preprocessed_mask = transformed_data['label'][0] # Remove channel dimension
1297
-
1298
- # non_zero_after = torch.count_nonzero(preprocessed_mask)
1299
- # non_zero_original = torch.count_nonzero(prostate_mask)
1300
- # logger.info(f"Transform result: shape={preprocessed_mask.shape}, non-zero voxels={non_zero_after}")
1301
- # logger.info(f"Original mask had {non_zero_original} non-zero voxels")
1302
-
1303
- # # Check if transforms removed all data
1304
- # if non_zero_after == 0 and non_zero_original > 0:
1305
- # logger.warning("Transforms removed all mask data, falling back to manual preprocessing")
1306
- # return apply_manual_preprocessing_to_mask(prostate_mask, args)
1307
-
1308
- # # Ensure binary values
1309
- # preprocessed_mask = (preprocessed_mask > 0.5).float()
1310
- # final_non_zero = torch.count_nonzero(preprocessed_mask)
1311
-
1312
- # logger.info(f"✅ Successfully applied preprocessing to prostate mask: shape={preprocessed_mask.shape}, non-zero voxels={final_non_zero}")
1313
- # return preprocessed_mask
1314
-
1315
- # except Exception as e:
1316
- # logger.warning(f"Could not apply transforms to prostate mask: {e}")
1317
- # # Fallback to manual preprocessing
1318
- # logger.info("Falling back to manual preprocessing...")
1319
- # return apply_manual_preprocessing_to_mask(prostate_mask, args)
1320
-
1321
- # except Exception as e:
1322
- # logger.error(f"Error applying preprocessing to prostate mask: {e}")
1323
- # import traceback
1324
- # logger.error(f"Full traceback: {traceback.format_exc()}")
1325
-
1326
- # # Final fallback - just return the original mask if everything fails
1327
- # logger.warning("Returning original mask as final fallback")
1328
- # return prostate_mask
1329
-
1330
- # def apply_preprocessing_to_mask(prostate_mask: torch.Tensor, args, transforms) -> torch.Tensor:
1331
- # """
1332
- # Apply ONLY SPATIAL preprocessing transformations to the prostate mask.
1333
- # This ensures consistent spatial alignment without corrupting the mask's binary values.
1334
- # """
1335
- # try:
1336
- # logger.info("Applying consistent SPATIAL preprocessing to prostate mask...")
1337
- # logger.info(f"Input mask: shape={prostate_mask.shape}, non-zero voxels={torch.count_nonzero(prostate_mask)}")
1338
-
1339
- # # Define MONAI spatial transforms to filter for
1340
-
1341
- # # Filter for spatial transforms ONLY
1342
- # spatial_transforms = [
1343
- # t for t in transforms.transforms
1344
- # if isinstance(t, (Spacing, CenterSpatialCrop, SpatialPad, Resize))
1345
- # ]
1346
-
1347
- # if not spatial_transforms:
1348
- # logger.warning("No spatial transforms found in the pipeline. Falling back to manual preprocessing.")
1349
- # return apply_manual_preprocessing_to_mask(prostate_mask, args)
1350
 
1351
- # # Create a new pipeline with only the spatial transforms
1352
- # # Important: MONAI transforms often expect dictionary inputs
1353
- # spatial_pipeline = Compose(spatial_transforms)
1354
-
1355
- # # Create a data dictionary. Use a key like 'label' or 'mask'
1356
- # # to avoid intensity normalization which is often applied to 'image' key.
1357
- # mask_data = {'label': prostate_mask.unsqueeze(0)} # Add channel dim: [1, D, H, W]
1358
-
1359
- # # Apply the filtered spatial transformations
1360
- # transformed_data = spatial_pipeline(mask_data)
1361
- # preprocessed_mask = transformed_data['label'][0] # Remove channel dim
1362
-
1363
- # # Ensure mask remains binary after interpolation
1364
- # preprocessed_mask = (preprocessed_mask > 0.5).float()
1365
-
1366
- # final_non_zero = torch.count_nonzero(preprocessed_mask)
1367
- # logger.info(f"✅ Successfully applied SPATIAL preprocessing: shape={preprocessed_mask.shape}, non-zero voxels={final_non_zero}")
1368
-
1369
- # if final_non_zero == 0 and torch.count_nonzero(prostate_mask) > 0:
1370
- # logger.warning("Spatial transforms resulted in an empty mask. Trying manual fallback.")
1371
- # return apply_manual_preprocessing_to_mask(prostate_mask, args)
1372
-
1373
- # return preprocessed_mask
1374
-
1375
- # except Exception as e:
1376
- # logger.error(f"Error applying filtered spatial preprocessing: {e}. Falling back to manual method.")
1377
- # import traceback
1378
- # logger.error(f"Full traceback: {traceback.format_exc()}")
1379
- # return apply_manual_preprocessing_to_mask(prostate_mask, args)
1380
-
1381
- # def apply_manual_preprocessing_to_mask(prostate_mask: torch.Tensor, args) -> torch.Tensor:
1382
- # """
1383
- # Apply manual preprocessing to prostate mask to match the spatial dimensions
1384
- # This includes the key transformations: resampling, cropping, and padding
1385
- # """
1386
- # try:
1387
- # logger.info("Applying manual preprocessing to prostate mask...")
1388
- # logger.info(f"Input mask: shape={prostate_mask.shape}, non-zero voxels={torch.count_nonzero(prostate_mask)}")
1389
-
1390
- # # Add batch and channel dimensions for MONAI transforms
1391
- # mask_tensor = prostate_mask.unsqueeze(0).unsqueeze(0) # [1, 1, D, H, W]
1392
- # logger.info(f"After adding dimensions: {mask_tensor.shape}")
1393
-
1394
- # # Apply spatial resampling to match the target spacing - be conservative with large masks
1395
- # if hasattr(args, 'spacing') and args.spacing is not None:
1396
- # logger.info(f"Applying spacing transform with spacing: {args.spacing}")
1397
- # # For large original masks, use more conservative resampling
1398
- # original_size = prostate_mask.shape
1399
- # if max(original_size) > 300: # Large original mask
1400
- # logger.info("Large original mask detected, using conservative resampling")
1401
- # # Use linear interpolation for large masks to preserve more structure
1402
- # spacing_transform = Spacing(pixdim=args.spacing, mode='bilinear')
1403
- # else:
1404
- # spacing_transform = Spacing(pixdim=args.spacing, mode='nearest')
1405
-
1406
- # mask_tensor_before = mask_tensor.clone()
1407
- # mask_tensor = spacing_transform(mask_tensor)
1408
-
1409
- # logger.info(f"After spacing transform: {mask_tensor.shape}")
1410
- # logger.info(f"Non-zero voxels after spacing: {torch.count_nonzero(mask_tensor)}")
1411
-
1412
- # # Check if spacing transform removed all data
1413
- # if torch.count_nonzero(mask_tensor) == 0 and torch.count_nonzero(mask_tensor_before) > 0:
1414
- # logger.warning("Spacing transform removed all mask data, reverting")
1415
- # mask_tensor = mask_tensor_before
1416
-
1417
- # # Apply cropping to match the input size - be conservative to preserve mask content
1418
- # if hasattr(args, 'crop_spatial_size') and args.crop_spatial_size is not None:
1419
- # current_shape = mask_tensor.shape[2:] # [D, H, W]
1420
- # target_shape = args.crop_spatial_size
1421
-
1422
- # logger.info(f"Current shape after spacing: {current_shape}, target crop size: {target_shape}")
1423
-
1424
- # # Check if we need to crop at all
1425
- # needs_crop = any(c > t for c, t in zip(current_shape, target_shape))
1426
-
1427
- # if needs_crop:
1428
- # logger.info(f"Applying center crop to size: {target_shape}")
1429
-
1430
- # # For large reductions, be more conservative
1431
- # crop_size_ratios = [t/c for c, t in zip(current_shape, target_shape)]
1432
- # min_ratio = min(crop_size_ratios)
1433
-
1434
- # if min_ratio < 0.6: # Very aggressive crop (>40% reduction)
1435
- # logger.warning(f"Very aggressive crop detected (min ratio: {min_ratio:.2f}), using conservative approach")
1436
- # # Use a safer crop size that's not too aggressive
1437
- # safe_crop_size = tuple(min(c, max(t, int(c * 0.7))) for c, t in zip(current_shape, target_shape))
1438
- # logger.info(f"Using safer crop size: {safe_crop_size}")
1439
- # crop_transform = CenterSpatialCrop(roi_size=safe_crop_size)
1440
- # else:
1441
- # crop_transform = CenterSpatialCrop(roi_size=target_shape)
1442
-
1443
- # mask_tensor_before = mask_tensor.clone()
1444
- # mask_tensor = crop_transform(mask_tensor)
1445
-
1446
- # logger.info(f"After crop: {mask_tensor.shape}")
1447
- # logger.info(f"Non-zero voxels after crop: {torch.count_nonzero(mask_tensor)}")
1448
-
1449
- # # Check if cropping removed all data
1450
- # if torch.count_nonzero(mask_tensor) == 0 and torch.count_nonzero(mask_tensor_before) > 0:
1451
- # logger.warning("Crop removed all mask data, skipping crop")
1452
- # mask_tensor = mask_tensor_before
1453
- # else:
1454
- # logger.info("No cropping needed, current size is within target")
1455
-
1456
- # # Apply padding if needed to match exact target size
1457
- # if hasattr(args, 'crop_spatial_size') and args.crop_spatial_size is not None:
1458
- # current_shape = mask_tensor.shape[2:] # [D, H, W]
1459
- # target_shape = args.crop_spatial_size
1460
-
1461
- # if current_shape != target_shape:
1462
- # logger.info(f"Applying padding from {current_shape} to {target_shape}")
1463
- # pad_transform = SpatialPad(spatial_size=target_shape, mode='constant')
1464
- # mask_tensor = pad_transform(mask_tensor)
1465
- # logger.info(f"After padding: {mask_tensor.shape}")
1466
- # logger.info(f"Non-zero voxels after padding: {torch.count_nonzero(mask_tensor)}")
1467
-
1468
- # # Remove batch and channel dimensions
1469
- # preprocessed_mask = mask_tensor[0, 0] # [D, H, W]
1470
-
1471
- # # Ensure binary values (thresholding at 0.5)
1472
- # preprocessed_mask = (preprocessed_mask > 0.5).float()
1473
-
1474
- # logger.info(f"Manual preprocessing completed: shape={preprocessed_mask.shape}, non-zero voxels={torch.count_nonzero(preprocessed_mask)}")
1475
-
1476
- # return preprocessed_mask
1477
-
1478
- # except Exception as e:
1479
- # logger.error(f"Error in manual preprocessing: {e}")
1480
- # import traceback
1481
- # logger.error(f"Full traceback: {traceback.format_exc()}")
1482
-
1483
- # # Final fallback - simple resize
1484
- # logger.warning("Falling back to simple resize")
1485
- # target_shape = args.crop_spatial_size if hasattr(args, 'crop_spatial_size') else (64, 224, 224)
1486
-
1487
- # try:
1488
- # if prostate_mask.shape != target_shape:
1489
- # logger.info(f"Simple resize from {prostate_mask.shape} to {target_shape}")
1490
- # zoom_factors = tuple(target_shape[i] / prostate_mask.shape[i] for i in range(3))
1491
- # logger.info(f"Zoom factors: {zoom_factors}")
1492
-
1493
- # # Use order=1 (linear) for better preservation of structures, then threshold
1494
- # mask_np = prostate_mask.cpu().numpy().astype(np.float32)
1495
- # resized_mask = ndimage.zoom(mask_np, zoom_factors, order=1, prefilter=False)
1496
-
1497
- # # Apply threshold to maintain binary nature
1498
- # resized_mask = (resized_mask > 0.3).astype(np.float32) # Lower threshold to preserve more
1499
-
1500
- # result = torch.from_numpy(resized_mask)
1501
-
1502
- # logger.info(f"Simple resize result: shape={result.shape}, non-zero voxels={torch.count_nonzero(result)}")
1503
-
1504
- # if torch.count_nonzero(result) == 0:
1505
- # logger.warning("Simple resize resulted in empty mask, trying with very low threshold")
1506
- # # Try with an even lower threshold
1507
- # resized_mask_low = ndimage.zoom(mask_np, zoom_factors, order=1, prefilter=False)
1508
- # resized_mask_low = (resized_mask_low > 0.1).astype(np.float32)
1509
- # result = torch.from_numpy(resized_mask_low)
1510
- # logger.info(f"Low threshold result: non-zero voxels={torch.count_nonzero(result)}")
1511
-
1512
- # return result
1513
- # else:
1514
- # # Ensure binary
1515
- # result = (prostate_mask > 0.5).float()
1516
- # return result
1517
- # except Exception as final_error:
1518
- # logger.error(f"Simple resize failed: {final_error}")
1519
- # # Return empty mask if everything fails
1520
- # return torch.zeros(target_shape, dtype=torch.float32)
1521
 
1522
 
1523
  def create_interface():
 
346
  raise
347
 
348
 
349
+ # Removed get_raw_patient_data function as it's no longer needed
350
+ # All data is now processed through the unified preprocessing pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
 
352
 
353
  def visualize_multimodal_results(preprocessed_data: Dict, prediction: torch.Tensor, task: str) -> plt.Figure:
 
944
  logits = []
945
  model.eval()
946
  with torch.no_grad():
947
+ for idx, data in enumerate(data_loader):
948
+ # Classification only returns (img, gt, pid)
949
+ img, gt, pid = data
950
  img, gt = img.to(args.device), gt.to(args.device)
951
  logit = model(img)
952
  logits.append(logit)
 
1019
  # Exact demo inference loop
1020
  model.eval()
1021
  with torch.no_grad():
1022
+ for idx, data in enumerate(data_loader):
1023
+ # Segmentation returns (img, gt, pid, gland) or (img, gt, pid)
1024
+ if len(data) == 4:
1025
+ img, gt, pid, gland = data
1026
+ else:
1027
+ img, gt, pid = data
1028
+ gland = None
1029
+
1030
  img, gt = img.to(args.device), gt.to(args.device)
1031
  if args.sliding_window:
1032
  pred = sliding_window_inference(
 
1125
  return fig, empty_3d, f"Error: {str(e)}", gr.update(maximum=63, value=32)
1126
 
1127
 
 
 
1128
  def get_preprocessed_patient_data(patient_idx: int, task: str):
1129
  """
1130
+ Get preprocessed patient data that matches exactly what the model uses
1131
+ This ensures spatial consistency between displayed images and predictions
 
1132
  """
1133
  try:
1134
  logger.info(f"Loading preprocessed patient data for patient {patient_idx}, task {task}")
1135
 
1136
+ # Create dataloader to get the exact preprocessed data that the model uses
1137
+ data_loader, args, dataset = create_single_sample_dataloader(patient_idx, task)
1138
 
1139
+ # Extract the preprocessed data from the dataloader
1140
+ for idx, data in enumerate(data_loader):
1141
+ # Handle different return formats based on task
1142
+ if task == 'classification':
1143
+ # Classification returns: (img, gt, pid)
1144
+ img, gt, pid = data
1145
+ gland = None
1146
+ else:
1147
+ # Segmentation returns: (img, gt, pid, gland)
1148
+ if len(data) == 4:
1149
+ img, gt, pid, gland = data
1150
+ # Remove batch dimension from gland if it exists
1151
+ if gland is not None and len(gland) > 0:
1152
+ gland = gland[0] # Remove batch dimension
1153
+ else:
1154
+ # Fallback for old format
1155
+ img, gt, pid = data
1156
+ gland = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1157
 
1158
+ # img is the preprocessed multi-modal image [B, C, D, H, W]
1159
+ # gt is the preprocessed ground truth [B, C, D, H, W] or [B, D, H, W]
1160
+ # gland is the preprocessed prostate mask [B, C, D, H, W] or [B, D, H, W] or None
1161
 
1162
+ preprocessed_data = {
1163
+ 'preprocessed_image': img[0], # Remove batch dimension [C, D, H, W]
1164
+ 'preprocessed_gt': gt[0] if gt is not None else None, # Remove batch dimension
1165
+ 'patient_id': pid[0] if isinstance(pid, (list, tuple)) else pid,
1166
+ 'spatial_shape': img.shape[2:], # [D, H, W]
1167
+ 'args': args
1168
+ }
1169
 
1170
+ # Extract individual modalities from the preprocessed image
1171
+ # Assuming the model input has 3 channels: [T2W, DWI, ADC]
1172
+ if img.shape[1] >= 3: # Check if we have at least 3 channels
1173
+ preprocessed_data['t2w_preprocessed'] = img[0, 0] # [D, H, W]
1174
+ preprocessed_data['dwi_preprocessed'] = img[0, 1] # [D, H, W]
1175
+ preprocessed_data['adc_preprocessed'] = img[0, 2] # [D, H, W]
1176
+ else:
1177
+ logger.warning(f"Unexpected number of input channels: {img.shape[1]}")
1178
+ # Fallback: use the same channel for all modalities
1179
+ preprocessed_data['t2w_preprocessed'] = img[0, 0]
1180
+ preprocessed_data['dwi_preprocessed'] = img[0, 0]
1181
+ preprocessed_data['adc_preprocessed'] = img[0, 0]
1182
 
1183
+ # Convert ground truth to proper format
1184
+ if gt is not None:
1185
+ if len(gt.shape) == 4: # [B, D, H, W]
1186
+ preprocessed_data['ground_truth_preprocessed'] = gt[0] # [D, H, W]
1187
+ elif len(gt.shape) == 5: # [B, C, D, H, W]
1188
+ preprocessed_data['ground_truth_preprocessed'] = gt[0, 0] # [D, H, W]
1189
+ else:
1190
+ logger.warning(f"Unexpected ground truth shape: {gt.shape}")
1191
+ preprocessed_data['ground_truth_preprocessed'] = None
1192
  else:
1193
+ preprocessed_data['ground_truth_preprocessed'] = None
 
 
1194
 
1195
+ # Handle prostate gland mask - now from preprocessed data
1196
+ if gland is not None:
1197
+ # Convert gland mask to proper format
1198
+ if len(gland.shape) == 3: # [D, H, W]
1199
+ preprocessed_gland = gland
1200
+ elif len(gland.shape) == 4: # [C, D, H, W]
1201
+ preprocessed_gland = gland[0] # [D, H, W]
1202
+ else:
1203
+ logger.warning(f"Unexpected gland mask shape: {gland.shape}")
1204
+ preprocessed_gland = None
1205
+
1206
+ if preprocessed_gland is not None:
1207
+ # Ensure binary values
1208
+ preprocessed_gland = (preprocessed_gland > 0.5).float()
1209
+ non_zero_voxels = torch.count_nonzero(preprocessed_gland)
1210
+
1211
+ logger.info(f"✅ Preprocessed prostate gland mask: shape={preprocessed_gland.shape}, non-zero voxels={non_zero_voxels}")
1212
+ preprocessed_data['prostate_mask_preprocessed'] = preprocessed_gland if non_zero_voxels > 0 else None
1213
+ else:
1214
+ logger.warning("Could not process gland mask")
1215
+ preprocessed_data['prostate_mask_preprocessed'] = None
1216
+ else:
1217
+ logger.info("No prostate gland mask found in preprocessed data")
1218
+ preprocessed_data['prostate_mask_preprocessed'] = None
1219
+
1220
+ break # Only process the first (and only) sample
1221
+
1222
+ logger.info(f"Successfully loaded preprocessed data with spatial shape: {preprocessed_data['spatial_shape']}")
1223
  return preprocessed_data
1224
 
1225
  except Exception as e:
1226
  logger.error(f"Error loading preprocessed patient data: {e}")
 
 
1227
  raise
1228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1229
 
1230
+ # Removed manual prostate mask preprocessing functions
1231
+ # These are no longer needed as the gland mask is now processed
1232
+ # together with image and lesion data in the same pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1233
 
1234
 
1235
  def create_interface():
dataset/dataset_seg.py CHANGED
@@ -73,14 +73,44 @@ class UCLSet(BaseVolumeDataset):
73
  img = torch.stack([t2w, dwi, adc], 0)
74
  seg = self.read(path["lesion"]).unsqueeze(0)
75
  seg = seg > 0
76
- # print(img.shape)
77
- # seg = (seg == self.target_class).float()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  if self.transforms is not None:
79
- trans_dict = self.transforms({"image": img, "label": seg})
 
 
 
 
 
80
  if type(trans_dict) == list:
81
  trans_dict = trans_dict[0]
 
82
  img, seg = trans_dict["image"], trans_dict["label"]
83
- return img, seg, torch.tensor(idx, dtype=torch.long)
 
 
 
 
 
 
 
 
84
 
85
  # TODO: need to update; unfinished
86
  """
@@ -204,22 +234,25 @@ def get_transforms(args):
204
  train_transforms = [
205
  NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
206
  RandRotated(
207
- keys=["image", "label"],
208
  prob=0.3,
209
  range_x=30 / 180 * np.pi,
210
  keep_size=False,
211
- mode=["bilinear", "nearest"],
 
212
  ),
213
  RandZoomd(
214
- keys=["image", "label"],
215
  prob=0.3,
216
  min_zoom=[1, 0.9, 0.9],
217
  max_zoom=[1, 1.1, 1.1],
218
- mode=["trilinear", "nearest"],
 
219
  ),
220
  SpatialPadd(
221
- keys=["image", "label"],
222
  spatial_size=[round(i * 1.2) for i in args.crop_spatial_size],
 
223
  ),
224
  # RandCropByPosNegLabeld(
225
  # keys=["image", "label"],
@@ -230,11 +263,12 @@ def get_transforms(args):
230
  # num_samples=1,
231
  # ),
232
  RandSpatialCropd(
233
- keys=["image", "label"],
234
  roi_size=args.crop_spatial_size,
235
  random_size=False,
 
236
  ),
237
- RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
238
  # BinarizeLabeld(keys=["label"])
239
  RandScaleIntensityd(keys="image", factors=0.1, prob=0.8),
240
  RandShiftIntensityd(keys="image", offsets=0.1, prob=0.8),
@@ -247,11 +281,13 @@ def get_transforms(args):
247
  [
248
  NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
249
  CenterSpatialCropd(
250
- keys=["image", "label"], roi_size=args.crop_spatial_size
 
251
  ),
252
  SpatialPadd(
253
- keys=["image", "label"],
254
  spatial_size=[i for i in args.crop_spatial_size],
 
255
  ),
256
  # BinarizeLabeld(keys=["label"])
257
  ]
@@ -260,11 +296,13 @@ def get_transforms(args):
260
  [
261
  NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
262
  CenterSpatialCropd(
263
- keys=["image", "label"], roi_size=args.crop_spatial_size
 
264
  ),
265
  SpatialPadd(
266
- keys=["image", "label"],
267
  spatial_size=[i for i in args.crop_spatial_size],
 
268
  ),
269
  # BinarizeLabeld(keys=["label"])
270
  ]
 
73
  img = torch.stack([t2w, dwi, adc], 0)
74
  seg = self.read(path["lesion"]).unsqueeze(0)
75
  seg = seg > 0
76
+
77
+ # Load prostate gland mask if available
78
+ gland = None
79
+ try:
80
+ # Try to find prostate_mask.nii.gz in the same directory as t2w
81
+ t2w_path = os.path.join(self.root, path["t2w"])
82
+ prostate_mask_path = os.path.join(os.path.dirname(t2w_path), "prostate_mask.nii.gz")
83
+ if os.path.exists(prostate_mask_path):
84
+ # Read prostate mask relative to the t2w directory
85
+ relative_mask_path = os.path.relpath(prostate_mask_path, self.root)
86
+ gland = self.read(relative_mask_path).unsqueeze(0)
87
+ gland = gland > 0 # Binarize
88
+ print(f"Loaded prostate mask: {relative_mask_path}, shape: {gland.shape}, non-zero: {torch.count_nonzero(gland)}")
89
+ else:
90
+ print(f"Prostate mask not found at: {prostate_mask_path}")
91
+ except Exception as e:
92
+ print(f"Failed to load prostate mask for {path['t2w']}: {e}")
93
+
94
  if self.transforms is not None:
95
+ # Include gland in transforms if available
96
+ if gland is not None:
97
+ trans_dict = self.transforms({"image": img, "label": seg, "gland": gland})
98
+ else:
99
+ trans_dict = self.transforms({"image": img, "label": seg})
100
+
101
  if type(trans_dict) == list:
102
  trans_dict = trans_dict[0]
103
+
104
  img, seg = trans_dict["image"], trans_dict["label"]
105
+
106
+ # Extract processed gland if it was included
107
+ if gland is not None and "gland" in trans_dict:
108
+ gland = trans_dict["gland"]
109
+ else:
110
+ gland = None
111
+
112
+ # Return gland as part of the data (we'll modify the app to handle this)
113
+ return img, seg, torch.tensor(idx, dtype=torch.long), gland
114
 
115
  # TODO: need to update; unfinished
116
  """
 
234
  train_transforms = [
235
  NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
236
  RandRotated(
237
+ keys=["image", "label", "gland"],
238
  prob=0.3,
239
  range_x=30 / 180 * np.pi,
240
  keep_size=False,
241
+ mode=["bilinear", "nearest", "nearest"],
242
+ allow_missing_keys=True,
243
  ),
244
  RandZoomd(
245
+ keys=["image", "label", "gland"],
246
  prob=0.3,
247
  min_zoom=[1, 0.9, 0.9],
248
  max_zoom=[1, 1.1, 1.1],
249
+ mode=["trilinear", "nearest", "nearest"],
250
+ allow_missing_keys=True,
251
  ),
252
  SpatialPadd(
253
+ keys=["image", "label", "gland"],
254
  spatial_size=[round(i * 1.2) for i in args.crop_spatial_size],
255
+ allow_missing_keys=True,
256
  ),
257
  # RandCropByPosNegLabeld(
258
  # keys=["image", "label"],
 
263
  # num_samples=1,
264
  # ),
265
  RandSpatialCropd(
266
+ keys=["image", "label", "gland"],
267
  roi_size=args.crop_spatial_size,
268
  random_size=False,
269
+ allow_missing_keys=True,
270
  ),
271
+ RandFlipd(keys=["image", "label", "gland"], prob=0.5, spatial_axis=2, allow_missing_keys=True),
272
  # BinarizeLabeld(keys=["label"])
273
  RandScaleIntensityd(keys="image", factors=0.1, prob=0.8),
274
  RandShiftIntensityd(keys="image", offsets=0.1, prob=0.8),
 
281
  [
282
  NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
283
  CenterSpatialCropd(
284
+ keys=["image", "label", "gland"], roi_size=args.crop_spatial_size,
285
+ allow_missing_keys=True,
286
  ),
287
  SpatialPadd(
288
+ keys=["image", "label", "gland"],
289
  spatial_size=[i for i in args.crop_spatial_size],
290
+ allow_missing_keys=True,
291
  ),
292
  # BinarizeLabeld(keys=["label"])
293
  ]
 
296
  [
297
  NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
298
  CenterSpatialCropd(
299
+ keys=["image", "label", "gland"], roi_size=args.crop_spatial_size,
300
+ allow_missing_keys=True,
301
  ),
302
  SpatialPadd(
303
+ keys=["image", "label", "gland"],
304
  spatial_size=[i for i in args.crop_spatial_size],
305
+ allow_missing_keys=True,
306
  ),
307
  # BinarizeLabeld(keys=["label"])
308
  ]