# facedet/data/dataloader.py
# Uploaded with huggingface_hub by cledouxluma (commit 8499cad, verified).
"""
DataLoader builders with production-ready configuration.
"""
import torch
from torch.utils.data import DataLoader, DistributedSampler
from typing import Optional
from .widerface import WiderFaceDataset
from .augmentations import TrainAugmentation, ValAugmentation
def build_train_loader(
    data_root: str,
    batch_size: int = 8,
    target_size: int = 640,
    num_workers: int = 4,
    use_landmarks: bool = False,
    enable_robustness: bool = True,
    distributed: bool = False,
    rank: int = 0,
    world_size: int = 1,
) -> DataLoader:
    """Build the training data loader with the SCRFD augmentation pipeline.

    Args:
        data_root: Root directory of the WIDER FACE dataset.
        batch_size: Samples per batch (per process when ``distributed``).
        target_size: Target image size passed to the augmentation pipeline.
        num_workers: Number of data-loading worker processes.
        use_landmarks: Whether samples should include facial landmarks.
        enable_robustness: Enable the robustness augmentations in training.
        distributed: Shard the dataset across processes with DistributedSampler.
        rank: Rank of this process (used only when ``distributed``).
        world_size: Total number of processes (used only when ``distributed``).

    Returns:
        DataLoader yielding batches built by ``WiderFaceDataset.collate_fn``.

    Note:
        When ``distributed`` is True, callers must invoke
        ``loader.sampler.set_epoch(epoch)`` at the start of each epoch,
        otherwise DistributedSampler reuses the same shuffle order every epoch.
    """
    transform = TrainAugmentation(
        target_size=target_size,
        enable_robustness=enable_robustness,
    )
    dataset = WiderFaceDataset(
        root_dir=data_root,
        split='train',
        transform=transform,
        use_landmarks=use_landmarks,
        min_face_size=2,  # filter out degenerate, near-zero-size face boxes
    )
    sampler = None
    if distributed:
        # Each replica iterates a disjoint shard; the sampler handles shuffling.
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(sampler is None),  # DataLoader forbids shuffle together with a sampler
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
        # Keep workers alive between epochs to avoid re-spawning them each epoch.
        persistent_workers=(num_workers > 0),
        collate_fn=WiderFaceDataset.collate_fn,
        drop_last=True,  # stable batch size (helps BN statistics and step counts)
    )
    return loader
def build_val_loader(
    data_root: str,
    batch_size: int = 1,
    target_size: int = 640,
    num_workers: int = 4,
    use_landmarks: bool = False,
) -> DataLoader:
    """Build the validation data loader (deterministic, no shuffling).

    Args:
        data_root: Root directory of the WIDER FACE dataset.
        batch_size: Samples per batch (defaults to 1 for per-image evaluation).
        target_size: Target image size passed to the validation transform.
        num_workers: Number of data-loading worker processes.
        use_landmarks: Whether samples should include facial landmarks.

    Returns:
        DataLoader yielding batches built by ``WiderFaceDataset.collate_fn``.
    """
    transform = ValAugmentation(target_size=target_size)
    dataset = WiderFaceDataset(
        root_dir=data_root,
        split='val',
        transform=transform,
        use_landmarks=use_landmarks,
        min_face_size=1,  # keep even the smallest faces for evaluation
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # fixed order so validation metrics are reproducible
        num_workers=num_workers,
        pin_memory=True,
        # Keep workers alive between validation passes to avoid re-spawn cost.
        persistent_workers=(num_workers > 0),
        collate_fn=WiderFaceDataset.collate_fn,
    )
    return loader