"""
DataLoader builders with production-ready configuration.
"""
import torch
from torch.utils.data import DataLoader, DistributedSampler
from typing import Optional
from .widerface import WiderFaceDataset
from .augmentations import TrainAugmentation, ValAugmentation
def build_train_loader(
    data_root: str,
    batch_size: int = 8,
    target_size: int = 640,
    num_workers: int = 4,
    use_landmarks: bool = False,
    enable_robustness: bool = True,
    distributed: bool = False,
    rank: int = 0,
    world_size: int = 1,
) -> DataLoader:
    """Create the training DataLoader (SCRFD augmentation pipeline).

    Args:
        data_root: Passed to ``WiderFaceDataset`` as ``root_dir``.
        batch_size: Number of samples per batch.
        target_size: Forwarded to ``TrainAugmentation`` as ``target_size``.
        num_workers: Number of DataLoader worker processes.
        use_landmarks: Forwarded to ``WiderFaceDataset``.
        enable_robustness: Forwarded to ``TrainAugmentation``.
        distributed: When true, a ``DistributedSampler`` selects the samples
            and the loader's own shuffling is switched off.
        rank: Rank handed to the ``DistributedSampler``; only used when
            ``distributed`` is true.
        world_size: Replica count handed to the ``DistributedSampler``; only
            used when ``distributed`` is true.

    Returns:
        A ``DataLoader`` over the ``'train'`` split that pins memory, drops
        the last incomplete batch, and collates with
        ``WiderFaceDataset.collate_fn``.
    """
    train_set = WiderFaceDataset(
        root_dir=data_root,
        split='train',
        transform=TrainAugmentation(
            target_size=target_size,
            enable_robustness=enable_robustness,
        ),
        use_landmarks=use_landmarks,
        min_face_size=2,
    )

    # DataLoader rejects shuffle=True together with an explicit sampler, so
    # exactly one of the two mechanisms is enabled.
    train_sampler: Optional[DistributedSampler]
    if distributed:
        train_sampler = DistributedSampler(
            train_set,
            num_replicas=world_size,
            rank=rank,
        )
        shuffle_in_loader = False
    else:
        train_sampler = None
        shuffle_in_loader = True

    return DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=shuffle_in_loader,
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=WiderFaceDataset.collate_fn,
        drop_last=True,
    )
def build_val_loader(
    data_root: str,
    batch_size: int = 1,
    target_size: int = 640,
    num_workers: int = 4,
    use_landmarks: bool = False,
) -> DataLoader:
    """Create the validation DataLoader.

    Args:
        data_root: Passed to ``WiderFaceDataset`` as ``root_dir``.
        batch_size: Number of samples per batch.
        target_size: Forwarded to ``ValAugmentation`` as ``target_size``.
        num_workers: Number of DataLoader worker processes.
        use_landmarks: Forwarded to ``WiderFaceDataset``.

    Returns:
        A non-shuffling ``DataLoader`` over the ``'val'`` split that pins
        memory and collates with ``WiderFaceDataset.collate_fn``.
    """
    val_set = WiderFaceDataset(
        root_dir=data_root,
        split='val',
        transform=ValAugmentation(target_size=target_size),
        use_landmarks=use_landmarks,
        min_face_size=1,
    )

    # Loader options gathered in one place; shuffling stays off for validation.
    loader_options = {
        'batch_size': batch_size,
        'shuffle': False,
        'num_workers': num_workers,
        'pin_memory': True,
        'collate_fn': WiderFaceDataset.collate_fn,
    }
    return DataLoader(val_set, **loader_options)
|