| """ |
| DataLoader builders with production-ready configuration. |
| """ |
|
|
| import torch |
| from torch.utils.data import DataLoader, DistributedSampler |
| from typing import Optional |
|
|
| from .widerface import WiderFaceDataset |
| from .augmentations import TrainAugmentation, ValAugmentation |
|
|
|
|
def build_train_loader(
    data_root: str,
    batch_size: int = 8,
    target_size: int = 640,
    num_workers: int = 4,
    use_landmarks: bool = False,
    enable_robustness: bool = True,
    distributed: bool = False,
    rank: int = 0,
    world_size: int = 1,
    pin_memory: Optional[bool] = None,
) -> DataLoader:
    """Build the training DataLoader with the SCRFD augmentation pipeline.

    Args:
        data_root: Dataset root directory, forwarded to ``WiderFaceDataset``
            as ``root_dir``.
        batch_size: Samples per batch (per process when ``distributed``).
        target_size: Forwarded to ``TrainAugmentation``.
        num_workers: Number of DataLoader worker processes (0 loads data in
            the main process).
        use_landmarks: Forwarded to ``WiderFaceDataset``.
        enable_robustness: Forwarded to ``TrainAugmentation``.
        distributed: If True, shard the dataset across processes with a
            ``DistributedSampler`` instead of letting the DataLoader shuffle.
        rank: Index of this process; only used when ``distributed``.
        world_size: Total number of processes; only used when ``distributed``.
        pin_memory: Whether the DataLoader pins host memory. ``None`` (the
            default) enables pinning only when CUDA is available.

    Returns:
        A DataLoader over the ``'train'`` split that shuffles (directly, or
        via the distributed sampler), batches with
        ``WiderFaceDataset.collate_fn``, and drops the last incomplete batch.

    Note:
        When ``distributed`` is True, call ``loader.sampler.set_epoch(epoch)``
        at the start of each epoch. Otherwise ``DistributedSampler`` yields
        the same ordering every epoch.
    """
    if pin_memory is None:
        # Pinned memory only speeds up host-to-CUDA copies; without CUDA it
        # gives no benefit and PyTorch emits a warning, so auto-detect.
        pin_memory = torch.cuda.is_available()

    transform = TrainAugmentation(
        target_size=target_size,
        enable_robustness=enable_robustness,
    )

    dataset = WiderFaceDataset(
        root_dir=data_root,
        split='train',
        transform=transform,
        use_landmarks=use_landmarks,
        # NOTE(review): presumably filters out tiny annotated faces; the
        # threshold (2 here vs. 1 for validation) is interpreted by
        # WiderFaceDataset -- confirm its units there.
        min_face_size=2,
    )

    sampler = None
    if distributed:
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        # DataLoader forbids shuffle=True together with a sampler; the
        # DistributedSampler does its own shuffling.
        shuffle=(sampler is None),
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=WiderFaceDataset.collate_fn,
        # Keep every training batch full-size.
        drop_last=True,
    )

    return loader
|
|
|
|
def build_val_loader(
    data_root: str,
    batch_size: int = 1,
    target_size: int = 640,
    num_workers: int = 4,
    use_landmarks: bool = False,
    pin_memory: Optional[bool] = None,
) -> DataLoader:
    """Build the validation DataLoader.

    Args:
        data_root: Dataset root directory, forwarded to ``WiderFaceDataset``
            as ``root_dir``.
        batch_size: Samples per batch.
        target_size: Forwarded to ``ValAugmentation``.
        num_workers: Number of DataLoader worker processes (0 loads data in
            the main process).
        use_landmarks: Forwarded to ``WiderFaceDataset``.
        pin_memory: Whether the DataLoader pins host memory. ``None`` (the
            default) enables pinning only when CUDA is available.

    Returns:
        A DataLoader over the ``'val'`` split in fixed order (no shuffling)
        that batches with ``WiderFaceDataset.collate_fn``. The last partial
        batch is kept, so every sample is evaluated.
    """
    if pin_memory is None:
        # Pinned memory only speeds up host-to-CUDA copies; without CUDA it
        # gives no benefit and PyTorch emits a warning, so auto-detect.
        pin_memory = torch.cuda.is_available()

    transform = ValAugmentation(target_size=target_size)

    dataset = WiderFaceDataset(
        root_dir=data_root,
        split='val',
        transform=transform,
        use_landmarks=use_landmarks,
        # NOTE(review): looser threshold than training (2); its exact meaning
        # is defined by WiderFaceDataset -- confirm there.
        min_face_size=1,
    )

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        # Deterministic order keeps evaluation reproducible.
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=WiderFaceDataset.collate_fn,
    )

    return loader
|
|