dreamlessx committed on
Commit
4fe4688
·
verified ·
1 Parent(s): 74d5d6f

Upload landmarkdiff/data.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/data.py +400 -0
landmarkdiff/data.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reusable data loading utilities for LandmarkDiff training and evaluation.
2
+
3
+ Provides PyTorch Dataset implementations for loading synthetic training pairs,
4
+ manifest-based datasets, and evaluation datasets. Extracted from the training
5
+ script for reuse across training, evaluation, and testing pipelines.
6
+
7
+ Usage::
8
+
9
+ from landmarkdiff.data import SurgicalPairDataset, create_dataloader
10
+
11
+ dataset = SurgicalPairDataset("data/training_combined", resolution=512)
12
+ loader = create_dataloader(dataset, batch_size=4, num_workers=4)
13
+
14
+ for batch in loader:
15
+ input_img = batch["input"] # (B, 3, H, W) RGB [0,1]
16
+ target_img = batch["target"] # (B, 3, H, W) RGB [0,1]
17
+ conditioning = batch["conditioning"] # (B, 3, H, W) RGB [0,1]
18
+ mask = batch["mask"] # (B, 1, H, W) [0,1]
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import csv
24
+ import json
25
+ import logging
26
+ from pathlib import Path
27
+ from typing import Callable
28
+
29
+ import cv2
30
+ import numpy as np
31
+ import torch
32
+ from torch.utils.data import DataLoader, Dataset, Sampler, WeightedRandomSampler
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # Core dataset
39
+ # ---------------------------------------------------------------------------
40
+
41
class SurgicalPairDataset(Dataset):
    """Dataset for loading surgical before/after training pairs.

    Each sample has four components, all stored in ``data_dir`` and sharing
    a common filename prefix:

    - ``{prefix}_input.png``: original face image (before surgery)
    - ``{prefix}_target.png``: modified face image (after surgery)
    - ``{prefix}_conditioning.png``: 3-channel landmark mesh visualization
    - ``{prefix}_mask.png``: surgical region mask (soft float)

    Pairs are either auto-discovered from ``*_input.png`` files or listed
    in a manifest CSV.

    Args:
        data_dir: Directory containing training pair images.
        resolution: Target image resolution (square).
        manifest_path: Optional CSV with columns [prefix, procedure, ...].
            If None, auto-discovers pairs from ``*_input.png`` files.
        transform: Optional callable for custom augmentation. Receives and
            returns a dict with numpy arrays under the keys
            ``input_image``, ``target_image``, ``conditioning`` and ``mask``
            (plus ``procedure`` and ``idx``).

    Raises:
        FileNotFoundError: If no input images are found.
    """

    # Stem suffix that marks the "input" image of a pair.
    _INPUT_SUFFIX = "_input"

    def __init__(
        self,
        data_dir: str | Path,
        resolution: int = 512,
        manifest_path: str | Path | None = None,
        transform: Callable[[dict], dict] | None = None,
    ):
        self.data_dir = Path(data_dir)
        self.resolution = resolution
        self.transform = transform

        # Discover pairs
        if manifest_path is not None:
            self.pairs, self.metadata = self._load_manifest(Path(manifest_path))
        else:
            self.pairs = sorted(self.data_dir.glob("*_input.png"))
            self.metadata = self._load_metadata()

        if not self.pairs:
            raise FileNotFoundError(f"No training pairs found in {data_dir}")

        logger.info("Loaded %d training pairs from %s", len(self.pairs), data_dir)

    def _load_manifest(self, path: Path) -> tuple[list[Path], dict[str, dict]]:
        """Load pairs from a manifest CSV.

        The prefix of each row is read from the ``prefix`` column, falling
        back to ``name``. Rows whose ``{prefix}_input.png`` does not exist in
        ``data_dir`` are skipped (and counted in a warning).

        Returns:
            ``(pairs, metadata)`` where ``pairs`` lists the input image paths
            in manifest order and ``metadata`` maps prefix -> CSV row dict.
        """
        pairs: list[Path] = []
        metadata: dict[str, dict] = {}
        skipped = 0
        # newline="" lets the csv module handle embedded/quoted newlines itself.
        with open(path, newline="") as f:
            for row in csv.DictReader(f):
                prefix = row.get("prefix", row.get("name", ""))
                input_path = self.data_dir / f"{prefix}_input.png"
                if input_path.exists():
                    pairs.append(input_path)
                    metadata[prefix] = dict(row)
                else:
                    skipped += 1
        if skipped:
            logger.warning(
                "Skipped %d manifest rows from %s with no input image in %s",
                skipped,
                path,
                self.data_dir,
            )
        return pairs, metadata

    def _load_metadata(self) -> dict[str, dict]:
        """Load per-pair metadata from ``metadata.json`` if present.

        Returns the ``"pairs"`` mapping of the JSON document, or an empty
        dict when the file is missing, unreadable, or not shaped as expected.
        Metadata is best-effort: failures are logged, never raised.
        """
        meta_path = self.data_dir / "metadata.json"
        if not meta_path.exists():
            return {}
        try:
            with open(meta_path) as f:
                data = json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning("Ignoring unreadable metadata %s: %s", meta_path, exc)
            return {}
        pairs = data.get("pairs", {}) if isinstance(data, dict) else {}
        return pairs if isinstance(pairs, dict) else {}

    def get_procedure(self, idx: int) -> str:
        """Get the surgical procedure type for a sample (``"unknown"`` if absent)."""
        info = self.metadata.get(self._prefix(idx), {})
        procedure = info.get("procedure")
        # csv.DictReader fills missing trailing fields with None; never leak
        # that to callers, which expect a string label.
        return "unknown" if procedure is None else procedure

    def get_procedures(self) -> list[str]:
        """Get procedure types for all samples, in index order."""
        return [self.get_procedure(i) for i in range(len(self))]

    def _prefix(self, idx: int) -> str:
        """Return the filename prefix of sample ``idx``.

        Only the trailing ``_input`` is removed, so prefixes that themselves
        contain ``_input`` are preserved.
        """
        stem = self.pairs[idx].stem
        if stem.endswith(self._INPUT_SUFFIX):
            return stem[: -len(self._INPUT_SUFFIX)]
        return stem

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> dict:
        """Load sample ``idx``.

        Returns:
            Dict with tensors ``input``, ``target``, ``conditioning``
            (converted by ``bgr_to_tensor``), ``mask`` (converted by
            ``mask_to_tensor``), plus ``procedure`` (str) and ``idx``.
        """
        prefix = self._prefix(idx)

        # Load images as BGR uint8
        sample = {
            "input_image": self._load_image(f"{prefix}_input.png"),
            "target_image": self._load_image(f"{prefix}_target.png"),
            "conditioning": self._load_image(f"{prefix}_conditioning.png"),
            "mask": self._load_mask(f"{prefix}_mask.png"),
            "procedure": self.get_procedure(idx),
            "idx": idx,
        }

        # Apply custom transform
        if self.transform is not None:
            sample = self.transform(sample)

        # Convert to tensors
        return {
            "input": bgr_to_tensor(sample["input_image"]),
            "target": bgr_to_tensor(sample["target_image"]),
            "conditioning": bgr_to_tensor(sample["conditioning"]),
            "mask": mask_to_tensor(sample["mask"]),
            "procedure": sample["procedure"],
            "idx": sample["idx"],
        }

    def _load_image(self, filename: str) -> np.ndarray:
        """Load an image as BGR uint8, resized to resolution.

        An unreadable or missing file yields an all-zero image (with a
        warning) so one bad file does not abort a run.
        """
        path = self.data_dir / filename
        img = cv2.imread(str(path))
        if img is None:
            logger.warning("Failed to load %s, using blank", path)
            return np.zeros(
                (self.resolution, self.resolution, 3), dtype=np.uint8
            )
        if img.shape[:2] != (self.resolution, self.resolution):
            img = cv2.resize(img, (self.resolution, self.resolution))
        return img

    def _load_mask(self, filename: str) -> np.ndarray:
        """Load a mask as float32 [0,1], resized to resolution.

        A missing mask file means "no region restriction" and yields an
        all-ones mask; an existing but unreadable file does the same, with a
        warning.
        """
        size = (self.resolution, self.resolution)
        path = self.data_dir / filename
        if not path.exists():
            return np.ones(size, dtype=np.float32)
        mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            logger.warning("Failed to load mask %s, using all-ones", path)
            return np.ones(size, dtype=np.float32)
        if mask.shape[:2] != size:
            mask = cv2.resize(mask, size)
        return mask.astype(np.float32) / 255.0
186
+
187
+
188
+ # ---------------------------------------------------------------------------
189
+ # Evaluation dataset (input + ground truth)
190
+ # ---------------------------------------------------------------------------
191
+
192
class EvalPairDataset(Dataset):
    """Dataset for evaluation: loads input/target pairs with procedure labels.

    Pairs are discovered from ``*_input.png`` files in ``data_dir``; the
    matching ground truth is ``{prefix}_target.png``. Procedure labels come
    from the ``"pairs"`` mapping of an optional ``metadata.json``.

    Args:
        data_dir: Directory with evaluation pairs.
        resolution: Target resolution.
    """

    # Stem suffix that marks the "input" image of a pair.
    _INPUT_SUFFIX = "_input"

    def __init__(self, data_dir: str | Path, resolution: int = 512):
        self.data_dir = Path(data_dir)
        self.resolution = resolution
        self.pairs = sorted(self.data_dir.glob("*_input.png"))
        self._meta = self._load_metadata()

    def _load_metadata(self) -> dict:
        """Return the ``"pairs"`` mapping of ``metadata.json``, or ``{}``.

        Metadata is best-effort: a missing, unreadable or unexpectedly
        shaped file results in an empty mapping (logged, never raised).
        """
        meta_path = self.data_dir / "metadata.json"
        if not meta_path.exists():
            return {}
        try:
            with open(meta_path) as f:
                data = json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning("Ignoring unreadable metadata %s: %s", meta_path, exc)
            return {}
        pairs = data.get("pairs", {}) if isinstance(data, dict) else {}
        return pairs if isinstance(pairs, dict) else {}

    def _prefix(self, idx: int) -> str:
        """Return the filename prefix of pair ``idx``.

        Only the trailing ``_input`` is removed, so prefixes that themselves
        contain ``_input`` are preserved.
        """
        stem = self.pairs[idx].stem
        if stem.endswith(self._INPUT_SUFFIX):
            return stem[: -len(self._INPUT_SUFFIX)]
        return stem

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> dict:
        """Load pair ``idx``.

        Returns:
            Dict with tensors ``input`` and ``target`` (converted by
            ``bgr_to_tensor``), the ``procedure`` label (``"unknown"`` when
            no metadata exists for the pair) and the filename ``prefix``.
        """
        prefix = self._prefix(idx)

        input_img = self._load(f"{prefix}_input.png")
        target_img = self._load(f"{prefix}_target.png")

        info = self._meta.get(prefix, {})
        procedure = info.get("procedure", "unknown")

        return {
            "input": bgr_to_tensor(input_img),
            "target": bgr_to_tensor(target_img),
            "procedure": procedure,
            "prefix": prefix,
        }

    def _load(self, filename: str) -> np.ndarray:
        """Load an image as BGR uint8, resized to resolution.

        An unreadable or missing file yields an all-zero image. This is
        logged because a silent black frame would skew evaluation metrics.
        """
        path = self.data_dir / filename
        img = cv2.imread(str(path))
        if img is None:
            logger.warning("Failed to load %s, using blank", path)
            return np.zeros(
                (self.resolution, self.resolution, 3), dtype=np.uint8
            )
        if img.shape[:2] != (self.resolution, self.resolution):
            img = cv2.resize(img, (self.resolution, self.resolution))
        return img
244
+
245
+
246
+ # ---------------------------------------------------------------------------
247
+ # Conversion utilities
248
+ # ---------------------------------------------------------------------------
249
+
250
def bgr_to_tensor(bgr: np.ndarray) -> torch.Tensor:
    """Turn an (H, W, C) BGR uint8 image into a (C, H, W) RGB tensor in [0, 1]."""
    # Reverse the channel axis (BGR -> RGB), then scale to unit range.
    channel_swapped = bgr[:, :, ::-1]
    scaled = channel_swapped.astype(np.float32) / 255.0
    # The reversed view can be negatively strided; make it contiguous
    # before handing it to torch.
    hwc = torch.from_numpy(np.ascontiguousarray(scaled))
    return hwc.permute(2, 0, 1)
254
+
255
+
256
def tensor_to_bgr(tensor: torch.Tensor) -> np.ndarray:
    """Convert RGB [0,1] tensor (C, H, W) to BGR uint8 image (H, W, C).

    Values are clamped to [0, 1] and rounded to the nearest 8-bit level, so
    a uint8 -> float -> uint8 round trip is lossless. The tensor is cast to
    float32 first, so reduced-precision inputs (e.g. bfloat16, which NumPy
    cannot represent) are accepted.
    """
    rgb = tensor.detach().cpu().float().clamp(0, 1).permute(1, 2, 0).numpy()
    # Round rather than truncate: float32 error can put k/255*255 just below k.
    bgr = np.rint(rgb[:, :, ::-1] * 255.0).astype(np.uint8)
    return np.ascontiguousarray(bgr)
261
+
262
+
263
def mask_to_tensor(mask: np.ndarray) -> torch.Tensor:
    """Convert a mask (H, W) or (H, W, C) to a tensor (1, H, W).

    For a 3-D input only the first channel is kept. The array's dtype is
    preserved. Negatively strided views (e.g. a mask flipped with
    ``mask[:, ::-1]`` by an augmentation) are copied into contiguous memory
    first, because ``torch.from_numpy`` rejects negative strides.
    """
    if mask.ndim == 3:
        mask = mask[:, :, 0]
    # No copy is made when the array is already C-contiguous.
    return torch.from_numpy(np.ascontiguousarray(mask)).unsqueeze(0)
268
+
269
+
270
+ # ---------------------------------------------------------------------------
271
+ # Samplers
272
+ # ---------------------------------------------------------------------------
273
+
274
def create_procedure_sampler(
    dataset: SurgicalPairDataset,
    balance_procedures: bool = True,
) -> Sampler | None:
    """Create a weighted sampler that balances procedure types.

    Each sample is weighted by ``total / (num_procedures * count)``, where
    ``count`` is the number of samples sharing its procedure label, so every
    procedure is drawn with equal probability overall.

    Args:
        dataset: Any dataset exposing ``get_procedures()`` and ``__len__``.
        balance_procedures: Set to False to disable balancing.

    Returns:
        A ``WeightedRandomSampler`` drawing ``len(dataset)`` samples with
        replacement, or None if balancing is disabled or the dataset has at
        most one distinct procedure.
    """
    from collections import Counter

    if not balance_procedures:
        return None

    procedures = dataset.get_procedures()
    # Single pass over the labels instead of one list.count() per procedure.
    counts = Counter(procedures)
    num_procedures = len(counts)

    if num_procedures <= 1:
        return None

    total = len(procedures)

    # Weight inversely proportional to count
    weights = [total / (num_procedures * counts[proc]) for proc in procedures]

    return WeightedRandomSampler(
        weights=weights,
        num_samples=len(dataset),
        replacement=True,
    )
306
+
307
+
308
+ # ---------------------------------------------------------------------------
309
+ # DataLoader factory
310
+ # ---------------------------------------------------------------------------
311
+
312
def create_dataloader(
    dataset: Dataset,
    batch_size: int = 4,
    num_workers: int = 4,
    shuffle: bool = True,
    sampler: Sampler | None = None,
    pin_memory: bool = True,
    drop_last: bool = True,
    persistent_workers: bool = False,
) -> DataLoader:
    """Build a DataLoader using training-oriented defaults.

    Args:
        dataset: PyTorch Dataset.
        batch_size: Batch size.
        num_workers: Number of data loading workers.
        shuffle: Shuffle data (forced off when a sampler is given).
        sampler: Custom sampler (e.g., from create_procedure_sampler).
        pin_memory: Pin memory for faster GPU transfer; only takes effect
            when CUDA is available.
        drop_last: Drop last incomplete batch.
        persistent_workers: Keep workers alive between epochs; only takes
            effect when ``num_workers`` is greater than zero.

    Returns:
        Configured DataLoader.
    """
    # DataLoader rejects shuffle=True together with an explicit sampler.
    effective_shuffle = shuffle if sampler is None else False

    loader_options = {
        "batch_size": batch_size,
        "shuffle": effective_shuffle,
        "sampler": sampler,
        "num_workers": num_workers,
        "pin_memory": pin_memory and torch.cuda.is_available(),
        "drop_last": drop_last,
        "persistent_workers": persistent_workers and num_workers > 0,
    }
    return DataLoader(dataset, **loader_options)
350
+
351
+
352
+ # ---------------------------------------------------------------------------
353
+ # Multi-directory dataset
354
+ # ---------------------------------------------------------------------------
355
+
356
class CombinedDataset(Dataset):
    """Combine multiple SurgicalPairDatasets into one.

    Useful for combining synthetic v1, v2, v3 data and real pairs. Samples
    are indexed in the order of ``datasets``: all of the first dataset, then
    all of the second, and so on. Negative indices count from the end of the
    combined dataset.

    Args:
        datasets: List of SurgicalPairDataset instances.
    """

    def __init__(self, datasets: list[SurgicalPairDataset]):
        self.datasets = datasets
        # Running totals: _cumulative_sizes[i] is the number of samples in
        # datasets[0..i], i.e. the exclusive end index of dataset i.
        self._cumulative_sizes = []
        total = 0
        for ds in datasets:
            total += len(ds)
            self._cumulative_sizes.append(total)

    def __len__(self) -> int:
        return self._cumulative_sizes[-1] if self._cumulative_sizes else 0

    def _locate(self, idx: int) -> tuple[int, int]:
        """Map a combined index to ``(dataset_index, index_within_dataset)``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        position = idx + total if idx < 0 else idx
        if not 0 <= position < total:
            raise IndexError(
                f"index {idx} out of range for CombinedDataset of length {total}"
            )
        start = 0
        for dataset_idx, end in enumerate(self._cumulative_sizes):
            # Empty datasets have end == start and can never match.
            if position < end:
                return dataset_idx, position - start
            start = end
        # Unreachable: position < total == last cumulative size.
        raise IndexError(f"index {idx} out of range")

    def __getitem__(self, idx: int) -> dict:
        """Return the sample at combined index ``idx``."""
        dataset_idx, local_idx = self._locate(idx)
        return self.datasets[dataset_idx][local_idx]

    def get_procedure(self, idx: int) -> str:
        """Return the procedure label of the sample at combined index ``idx``."""
        dataset_idx, local_idx = self._locate(idx)
        return self.datasets[dataset_idx].get_procedure(local_idx)

    def get_procedures(self) -> list[str]:
        """Return procedure labels for all samples, in combined index order."""
        procs = []
        for ds in self.datasets:
            procs.extend(ds.get_procedures())
        return procs