dreamlessx commited on
Commit
c1396a6
·
verified ·
1 Parent(s): 8816082

Update landmarkdiff/data.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/data.py +20 -10
landmarkdiff/data.py CHANGED
@@ -38,7 +38,6 @@ logger = logging.getLogger(__name__)
38
  # Core dataset
39
  # ---------------------------------------------------------------------------
40
 
41
-
42
  class SurgicalPairDataset(Dataset):
43
  """Dataset for loading surgical before/after training pairs.
44
 
@@ -163,7 +162,9 @@ class SurgicalPairDataset(Dataset):
163
  img = cv2.imread(str(path))
164
  if img is None:
165
  logger.warning("Failed to load %s, using blank", path)
166
- return np.zeros((self.resolution, self.resolution, 3), dtype=np.uint8)
 
 
167
  if img.shape[:2] != (self.resolution, self.resolution):
168
  img = cv2.resize(img, (self.resolution, self.resolution))
169
  return img
@@ -172,10 +173,14 @@ class SurgicalPairDataset(Dataset):
172
  """Load a mask as float32 [0,1], resized to resolution."""
173
  path = self.data_dir / filename
174
  if not path.exists():
175
- return np.ones((self.resolution, self.resolution), dtype=np.float32)
 
 
176
  mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
177
  if mask is None:
178
- return np.ones((self.resolution, self.resolution), dtype=np.float32)
 
 
179
  mask = cv2.resize(mask, (self.resolution, self.resolution))
180
  return mask.astype(np.float32) / 255.0
181
 
@@ -184,7 +189,6 @@ class SurgicalPairDataset(Dataset):
184
  # Evaluation dataset (input + ground truth)
185
  # ---------------------------------------------------------------------------
186
 
187
-
188
  class EvalPairDataset(Dataset):
189
  """Dataset for evaluation: loads input/target pairs with procedure labels.
190
 
@@ -231,7 +235,9 @@ class EvalPairDataset(Dataset):
231
  path = self.data_dir / filename
232
  img = cv2.imread(str(path))
233
  if img is None:
234
- return np.zeros((self.resolution, self.resolution, 3), dtype=np.uint8)
 
 
235
  if img.shape[:2] != (self.resolution, self.resolution):
236
  img = cv2.resize(img, (self.resolution, self.resolution))
237
  return img
@@ -241,7 +247,6 @@ class EvalPairDataset(Dataset):
241
  # Conversion utilities
242
  # ---------------------------------------------------------------------------
243
 
244
-
245
  def bgr_to_tensor(bgr: np.ndarray) -> torch.Tensor:
246
  """Convert BGR uint8 image to RGB [0,1] tensor (C, H, W)."""
247
  rgb = bgr[:, :, ::-1].astype(np.float32) / 255.0
@@ -266,7 +271,6 @@ def mask_to_tensor(mask: np.ndarray) -> torch.Tensor:
266
  # Samplers
267
  # ---------------------------------------------------------------------------
268
 
269
-
270
  def create_procedure_sampler(
271
  dataset: SurgicalPairDataset,
272
  balance_procedures: bool = True,
@@ -305,7 +309,6 @@ def create_procedure_sampler(
305
  # DataLoader factory
306
  # ---------------------------------------------------------------------------
307
 
308
-
309
  def create_dataloader(
310
  dataset: Dataset,
311
  batch_size: int = 4,
@@ -350,7 +353,6 @@ def create_dataloader(
350
  # Multi-directory dataset
351
  # ---------------------------------------------------------------------------
352
 
353
-
354
  class CombinedDataset(Dataset):
355
  """Combine multiple SurgicalPairDatasets into one.
356
 
@@ -372,6 +374,10 @@ class CombinedDataset(Dataset):
372
  return self._cumulative_sizes[-1] if self._cumulative_sizes else 0
373
 
374
  def __getitem__(self, idx: int) -> dict:
 
 
 
 
375
  dataset_idx = 0
376
  for i, size in enumerate(self._cumulative_sizes):
377
  if idx < size:
@@ -382,6 +388,10 @@ class CombinedDataset(Dataset):
382
  return self.datasets[dataset_idx][idx]
383
 
384
  def get_procedure(self, idx: int) -> str:
 
 
 
 
385
  dataset_idx = 0
386
  for i, size in enumerate(self._cumulative_sizes):
387
  if idx < size:
 
38
  # Core dataset
39
  # ---------------------------------------------------------------------------
40
 
 
41
  class SurgicalPairDataset(Dataset):
42
  """Dataset for loading surgical before/after training pairs.
43
 
 
162
  img = cv2.imread(str(path))
163
  if img is None:
164
  logger.warning("Failed to load %s, using blank", path)
165
+ return np.zeros(
166
+ (self.resolution, self.resolution, 3), dtype=np.uint8
167
+ )
168
  if img.shape[:2] != (self.resolution, self.resolution):
169
  img = cv2.resize(img, (self.resolution, self.resolution))
170
  return img
 
173
  """Load a mask as float32 [0,1], resized to resolution."""
174
  path = self.data_dir / filename
175
  if not path.exists():
176
+ return np.ones(
177
+ (self.resolution, self.resolution), dtype=np.float32
178
+ )
179
  mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
180
  if mask is None:
181
+ return np.ones(
182
+ (self.resolution, self.resolution), dtype=np.float32
183
+ )
184
  mask = cv2.resize(mask, (self.resolution, self.resolution))
185
  return mask.astype(np.float32) / 255.0
186
 
 
189
  # Evaluation dataset (input + ground truth)
190
  # ---------------------------------------------------------------------------
191
 
 
192
  class EvalPairDataset(Dataset):
193
  """Dataset for evaluation: loads input/target pairs with procedure labels.
194
 
 
235
  path = self.data_dir / filename
236
  img = cv2.imread(str(path))
237
  if img is None:
238
+ return np.zeros(
239
+ (self.resolution, self.resolution, 3), dtype=np.uint8
240
+ )
241
  if img.shape[:2] != (self.resolution, self.resolution):
242
  img = cv2.resize(img, (self.resolution, self.resolution))
243
  return img
 
247
  # Conversion utilities
248
  # ---------------------------------------------------------------------------
249
 
 
250
  def bgr_to_tensor(bgr: np.ndarray) -> torch.Tensor:
251
  """Convert BGR uint8 image to RGB [0,1] tensor (C, H, W)."""
252
  rgb = bgr[:, :, ::-1].astype(np.float32) / 255.0
 
271
  # Samplers
272
  # ---------------------------------------------------------------------------
273
 
 
274
  def create_procedure_sampler(
275
  dataset: SurgicalPairDataset,
276
  balance_procedures: bool = True,
 
309
  # DataLoader factory
310
  # ---------------------------------------------------------------------------
311
 
 
312
  def create_dataloader(
313
  dataset: Dataset,
314
  batch_size: int = 4,
 
353
  # Multi-directory dataset
354
  # ---------------------------------------------------------------------------
355
 
 
356
  class CombinedDataset(Dataset):
357
  """Combine multiple SurgicalPairDatasets into one.
358
 
 
374
  return self._cumulative_sizes[-1] if self._cumulative_sizes else 0
375
 
376
  def __getitem__(self, idx: int) -> dict:
377
+ if idx < 0 or idx >= len(self):
378
+ raise IndexError(
379
+ f"CombinedDataset index {idx} out of range [0, {len(self)})"
380
+ )
381
  dataset_idx = 0
382
  for i, size in enumerate(self._cumulative_sizes):
383
  if idx < size:
 
388
  return self.datasets[dataset_idx][idx]
389
 
390
  def get_procedure(self, idx: int) -> str:
391
+ if idx < 0 or idx >= len(self):
392
+ raise IndexError(
393
+ f"CombinedDataset index {idx} out of range [0, {len(self)})"
394
+ )
395
  dataset_idx = 0
396
  for i, size in enumerate(self._cumulative_sizes):
397
  if idx < size: