""" DataLoader builders with production-ready configuration. """ import torch from torch.utils.data import DataLoader, DistributedSampler from typing import Optional from .widerface import WiderFaceDataset from .augmentations import TrainAugmentation, ValAugmentation def build_train_loader( data_root: str, batch_size: int = 8, target_size: int = 640, num_workers: int = 4, use_landmarks: bool = False, enable_robustness: bool = True, distributed: bool = False, rank: int = 0, world_size: int = 1, ) -> DataLoader: """Build training data loader with SCRFD augmentation pipeline.""" transform = TrainAugmentation( target_size=target_size, enable_robustness=enable_robustness, ) dataset = WiderFaceDataset( root_dir=data_root, split='train', transform=transform, use_landmarks=use_landmarks, min_face_size=2, ) sampler = None if distributed: sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) loader = DataLoader( dataset, batch_size=batch_size, shuffle=(sampler is None), sampler=sampler, num_workers=num_workers, pin_memory=True, collate_fn=WiderFaceDataset.collate_fn, drop_last=True, ) return loader def build_val_loader( data_root: str, batch_size: int = 1, target_size: int = 640, num_workers: int = 4, use_landmarks: bool = False, ) -> DataLoader: """Build validation data loader.""" transform = ValAugmentation(target_size=target_size) dataset = WiderFaceDataset( root_dir=data_root, split='val', transform=transform, use_landmarks=use_landmarks, min_face_size=1, ) loader = DataLoader( dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, collate_fn=WiderFaceDataset.collate_fn, ) return loader