cledouxluma committed on
Commit
8499cad
·
verified ·
1 Parent(s): 4931970

Upload data/dataloader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. data/dataloader.py +85 -0
data/dataloader.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DataLoader builders with production-ready configuration.
3
+ """
4
+
5
+ import torch
6
+ from torch.utils.data import DataLoader, DistributedSampler
7
+ from typing import Optional
8
+
9
+ from .widerface import WiderFaceDataset
10
+ from .augmentations import TrainAugmentation, ValAugmentation
11
+
12
+
13
def build_train_loader(
    data_root: str,
    batch_size: int = 8,
    target_size: int = 640,
    num_workers: int = 4,
    use_landmarks: bool = False,
    enable_robustness: bool = True,
    distributed: bool = False,
    rank: int = 0,
    world_size: int = 1,
) -> DataLoader:
    """Build the training DataLoader (SCRFD augmentation pipeline).

    Args:
        data_root: Forwarded to ``WiderFaceDataset`` as ``root_dir``.
        batch_size: Number of samples per batch.
        target_size: Forwarded to ``TrainAugmentation`` as ``target_size``.
        num_workers: Number of DataLoader worker processes.
        use_landmarks: Forwarded to ``WiderFaceDataset``.
        enable_robustness: Forwarded to ``TrainAugmentation``.
        distributed: When true, a ``DistributedSampler`` is attached.
        rank: Rank of the current process; only used when ``distributed``.
        world_size: Passed to the sampler as ``num_replicas``; only used
            when ``distributed``.

    Returns:
        A ``DataLoader`` over the ``'train'`` split with pinned memory,
        ``WiderFaceDataset.collate_fn`` as collate function, and incomplete
        trailing batches dropped.
    """
    train_set = WiderFaceDataset(
        root_dir=data_root,
        split='train',
        transform=TrainAugmentation(
            target_size=target_size,
            enable_robustness=enable_robustness,
        ),
        use_landmarks=use_landmarks,
        min_face_size=2,
    )

    # NOTE: with a DistributedSampler the caller is expected to call
    # ``loader.sampler.set_epoch(epoch)`` each epoch to vary the shuffle.
    dist_sampler = (
        DistributedSampler(train_set, num_replicas=world_size, rank=rank)
        if distributed
        else None
    )

    return DataLoader(
        train_set,
        batch_size=batch_size,
        # DataLoader rejects shuffle=True together with an explicit sampler,
        # so shuffling is delegated to the sampler whenever one is present.
        shuffle=dist_sampler is None,
        sampler=dist_sampler,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=WiderFaceDataset.collate_fn,
        drop_last=True,
    )
55
+
56
+
57
def build_val_loader(
    data_root: str,
    batch_size: int = 1,
    target_size: int = 640,
    num_workers: int = 4,
    use_landmarks: bool = False,
) -> DataLoader:
    """Build the validation DataLoader.

    Args:
        data_root: Forwarded to ``WiderFaceDataset`` as ``root_dir``.
        batch_size: Number of samples per batch.
        target_size: Forwarded to ``ValAugmentation`` as ``target_size``.
        num_workers: Number of DataLoader worker processes.
        use_landmarks: Forwarded to ``WiderFaceDataset``.

    Returns:
        A non-shuffling ``DataLoader`` over the ``'val'`` split with pinned
        memory and ``WiderFaceDataset.collate_fn`` as collate function.
    """
    val_set = WiderFaceDataset(
        root_dir=data_root,
        split='val',
        transform=ValAugmentation(target_size=target_size),
        use_landmarks=use_landmarks,
        min_face_size=1,
    )

    # Fixed iteration order: no shuffling and no sampler for evaluation.
    return DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=WiderFaceDataset.collate_fn,
    )