rakib72642 committed on
Commit
08ec965
·
1 Parent(s): 985cbbd

ready init project

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. INSTALL.md +34 -0
  3. README.md +405 -0
  4. cog.yaml +24 -0
  5. cutler/__init__.py +15 -0
  6. cutler/config/__init__.py +3 -0
  7. cutler/config/__pycache__/__init__.cpython-312.pyc +0 -0
  8. cutler/config/__pycache__/cutler_config.cpython-312.pyc +0 -0
  9. cutler/config/cutler_config.py +19 -0
  10. cutler/data/__init__.py +15 -0
  11. cutler/data/__pycache__/__init__.cpython-312.pyc +0 -0
  12. cutler/data/__pycache__/build.cpython-312.pyc +0 -0
  13. cutler/data/__pycache__/dataset_mapper.cpython-312.pyc +0 -0
  14. cutler/data/__pycache__/detection_utils.cpython-312.pyc +0 -0
  15. cutler/data/build.py +561 -0
  16. cutler/data/dataset_mapper.py +193 -0
  17. cutler/data/datasets/__init__.py +16 -0
  18. cutler/data/datasets/__pycache__/__init__.cpython-312.pyc +0 -0
  19. cutler/data/datasets/__pycache__/builtin.cpython-312.pyc +0 -0
  20. cutler/data/datasets/__pycache__/builtin_meta.cpython-312.pyc +0 -0
  21. cutler/data/datasets/__pycache__/coco.cpython-312.pyc +0 -0
  22. cutler/data/datasets/builtin.py +216 -0
  23. cutler/data/datasets/builtin_meta.py +389 -0
  24. cutler/data/datasets/coco.py +544 -0
  25. cutler/data/detection_utils.py +650 -0
  26. cutler/data/transforms/__init__.py +15 -0
  27. cutler/data/transforms/__pycache__/__init__.cpython-312.pyc +0 -0
  28. cutler/data/transforms/__pycache__/augmentation_impl.cpython-312.pyc +0 -0
  29. cutler/data/transforms/__pycache__/transform.cpython-312.pyc +0 -0
  30. cutler/data/transforms/augmentation_impl.py +616 -0
  31. cutler/data/transforms/transform.py +355 -0
  32. cutler/demo/__init__.py +5 -0
  33. cutler/demo/__pycache__/predictor.cpython-312.pyc +0 -0
  34. cutler/demo/cutler_cascade_final.pth +3 -0
  35. cutler/demo/demo.py +197 -0
  36. cutler/demo/imgs/demo1.jpg +3 -0
  37. cutler/demo/imgs/demo2.jpg +3 -0
  38. cutler/demo/imgs/demo3.jpg +3 -0
  39. cutler/demo/imgs/demo4.jpg +3 -0
  40. cutler/demo/imgs/demo5.jpg +3 -0
  41. cutler/demo/imgs/demo6.jpg +3 -0
  42. cutler/demo/imgs/demo7.jpg +3 -0
  43. cutler/demo/imgs/demo8.jpg +3 -0
  44. cutler/demo/predictor.py +219 -0
  45. cutler/demo/wget-log +0 -0
  46. cutler/engine/__init__.py +7 -0
  47. cutler/engine/__pycache__/__init__.cpython-312.pyc +0 -0
  48. cutler/engine/__pycache__/defaults.cpython-312.pyc +0 -0
  49. cutler/engine/__pycache__/train_loop.cpython-312.pyc +0 -0
  50. cutler/engine/defaults.py +726 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.txt filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
38
+ *.gif filter=lfs diff=lfs merge=lfs -text
39
+ *.png filter=lfs diff=lfs merge=lfs -text
INSTALL.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Installation
3
+
4
+ ## Requirements
5
+ - Linux or macOS with Python ≥ 3.8
6
+ - PyTorch ≥ 1.8 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation.
7
+ Install them together at [pytorch.org](https://pytorch.org) to make sure of this.
8
+ Note: please check that the PyTorch version matches the one required by Detectron2.
9
+ - Detectron2: follow Detectron2 installation instructions.
10
+ - OpenCV ≥ 4.6 is needed by demo and visualization.
11
+
12
+ ## Example conda environment setup
13
+
14
+ ```bash
15
+ conda create --name cutler python=3.8 -y
16
+ conda activate cutler
17
+ conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 -c pytorch
18
+ pip install git+https://github.com/lucasb-eyer/pydensecrf.git
19
+
20
+ # under your working directory
21
+ git clone git@github.com:facebookresearch/detectron2.git
22
+ cd detectron2
23
+ pip install -e .
24
+ pip install git+https://github.com/cocodataset/panopticapi.git
25
+ pip install git+https://github.com/mcordts/cityscapesScripts.git
26
+
27
+ cd ..
28
+ git clone --recursive git@github.com:facebookresearch/CutLER.git
29
+ cd CutLER
30
+ pip install -r requirements.txt
31
+ ```
32
+
33
+ ## datasets
34
+ If you want to train/evaluate on the datasets, please see [datasets/README.md](datasets/README.md) for how we prepare datasets for this project.
README.md ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Cut and Learn for Unsupervised Image & Video Object Detection and Instance Segmentation
2
+
3
+ **Cut**-and-**LE**a**R**n (**CutLER**) is a simple approach for training object detection and instance segmentation models without human annotations.
4
+ It outperforms previous SOTA by **2.7 times** for AP50 and **2.6 times** for AR on **11 benchmarks**.
5
+
6
+ <p align="center"> <img src='docs/teaser_img.jpg' align="center" > </p>
7
+
8
+ > [**Cut and Learn for Unsupervised Object Detection and Instance Segmentation**](http://people.eecs.berkeley.edu/~xdwang/projects/CutLER/)
9
+ > [Xudong Wang](https://people.eecs.berkeley.edu/~xdwang/), [Rohit Girdhar](https://rohitgirdhar.github.io/), [Stella X. Yu](https://www1.icsi.berkeley.edu/~stellayu/), [Ishan Misra](https://imisra.github.io/)
10
+ > FAIR, Meta AI; UC Berkeley
11
+ > CVPR 2023
12
+
13
+ [[`project page`](http://people.eecs.berkeley.edu/~xdwang/projects/CutLER/)] [[`arxiv`](https://arxiv.org/abs/2301.11320)] [[`colab`](https://colab.research.google.com/drive/1NgEyFHvOfuA2MZZnfNPWg1w5gSr3HOBb?usp=sharing)] [[`bibtex`](#citation)]
14
+
15
+ Unsupervised video instance segmentation (**VideoCutLER**) is also supported. ***We demonstrate that video instance segmentation models can be learned without using any human annotations, without relying on natural videos (ImageNet data alone is sufficient), and even without motion estimations!*** The code is available [here](videocutler).
16
+
17
+ <p align="center">
18
+ <img src="docs/demos_videocutler.gif" width=100%>
19
+ </p>
20
+
21
+ > [**VideoCutLER: Surprisingly Simple Unsupervised Video Instance Segmentation**](https://people.eecs.berkeley.edu/~xdwang/projects/VideoCutLER/videocutler.pdf)
22
+ > [Xudong Wang](https://people.eecs.berkeley.edu/~xdwang/), [Ishan Misra](https://imisra.github.io/), Ziyun Zeng, [Rohit Girdhar](https://rohitgirdhar.github.io/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/)
23
+ > UC Berkeley; FAIR, Meta AI
24
+ > CVPR 2024
25
+
26
+ [[`code`](videocutler/README.md)] [[`PDF`](https://people.eecs.berkeley.edu/~xdwang/projects/VideoCutLER/videocutler.pdf)] [[`arxiv`](https://arxiv.org/abs/2308.14710)] [[`bibtex`](#citation)]
27
+
28
+ ## Features
29
+ - We propose MaskCut approach to generate pseudo-masks for multiple objects in an image.
30
+ - CutLER can learn unsupervised object detectors and instance segmentors solely on ImageNet-1K.
31
+ - CutLER exhibits strong robustness to domain shifts when evaluated on 11 different benchmarks across domains like natural images, video frames, paintings, sketches, etc.
32
+ - CutLER can serve as a pretrained model for fully/semi-supervised detection and segmentation tasks.
33
+ - We also propose VideoCutLER, a surprisingly simple unsupervised video instance segmentation (UVIS) method without relying on optical flows. ImageNet-1K is all we need for training a SOTA UVIS model!
34
+
35
+ ## Installation
36
+ See [installation instructions](INSTALL.md).
37
+
38
+ ## Dataset Preparation
39
+ See [Preparing Datasets for CutLER](datasets/README.md).
40
+
41
+ ## Method Overview
42
+ <p align="center">
43
+ <img src="docs/pipeline.jpg" width=55%>
44
+ </p>
45
+ Cut-and-Learn has two stages: 1) generating pseudo-masks with MaskCut and 2) learning unsupervised detectors from pseudo-masks of unlabeled data.
46
+
47
+ ### 1. MaskCut
48
+
49
+ MaskCut can be used to provide segmentation masks for multiple instances of each image.
50
+ <p align="center">
51
+ <img src="docs/maskcut.gif" width=100%>
52
+ </p>
53
+
54
+ ### MaskCut Demo
55
+
56
+ Try out the MaskCut demo using Colab (no GPU needed): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1X05lKL_IBRvZB7q6n6pb4w00_tIYjGlf?usp=sharing)
57
+
58
+ Try out the web demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/facebook/MaskCut) (thanks to [@hysts](https://github.com/hysts)!)
59
+
60
+
61
+
62
+
63
+ If you want to run MaskCut locally, we provide `demo.py` that is able to visualize the pseudo-masks produced by MaskCut.
64
+ Run it with:
65
+ ```
66
+ cd maskcut
67
+ python demo.py --img-path imgs/demo2.jpg \
68
+ --N 3 --tau 0.15 --vit-arch base --patch-size 8 \
69
+ [--other-options]
70
+ ```
71
+ We give a few demo images in maskcut/imgs/. If you want to run demo.py with cpu, simply add "--cpu" when running the demo script.
72
+ For imgs/demo4.jpg, you need to use "--N 6" to segment all six instances in the image.
73
+ Below, we give some visualizations of the pseudo-masks on the demo images.
74
+ <p align="center">
75
+ <img src="docs/maskcut-demo.jpg" width=100%>
76
+ </p>
77
+
78
+ ### Generating Annotations for ImageNet-1K with MaskCut
79
+ To generate pseudo-masks for ImageNet-1K using MaskCut, first set up the ImageNet-1K dataset according to the instructions in [datasets/README.md](datasets/README.md), then execute the following command:
80
+ ```
81
+ cd maskcut
82
+ python maskcut.py \
83
+ --vit-arch base --patch-size 8 \
84
+ --tau 0.15 --fixed_size 480 --N 3 \
85
+ --num-folder-per-job 1000 --job-index 0 \
86
+ --dataset-path /path/to/dataset/traindir \
87
+ --out-dir /path/to/save/annotations
88
+ ```
89
+ As the process of generating pseudo-masks for all 1.3 million images in 1,000 folders takes a significant amount of time, it is recommended to use multiple runs. Each run should process the pseudo-mask generation for a smaller number of image folders by setting "--num-folder-per-job" and "--job-index". Once all runs are completed, you can merge all the resulting json files by using the following command:
90
+ ```
91
+ python merge_jsons.py \
92
+ --base-dir /path/to/save/annotations \
93
+ --num-folder-per-job 2 --fixed-size 480 \
94
+ --tau 0.15 --N 3 \
95
+ --save-path imagenet_train_fixsize480_tau0.15_N3.json
96
+ ```
97
+ The "--num-folder-per-job", "--fixed-size", "--tau" and "--N" of merge_jsons.py should match the ones used to run maskcut.py.
98
+
99
+ We also provide a submitit script to launch the pseudo-mask generation process with multiple nodes.
100
+ ```
101
+ cd maskcut
102
+ bash run_maskcut_with_submitit.sh
103
+ ```
104
+ After that, you can use "merge_jsons.py" to merge all these json files as described above.
105
+
106
+ ### 2. CutLER
107
+
108
+ ### Inference Demo for CutLER with Pre-trained Models
109
+ Try out the CutLER demo using Colab (no GPU needed): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1NgEyFHvOfuA2MZZnfNPWg1w5gSr3HOBb?usp=sharing)
110
+
111
+ Try out the web demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/facebook/CutLER) (thanks to [@hysts](https://github.com/hysts)!)
112
+
113
+
114
+ Try out Replicate demo and the API: [![Replicate](https://replicate.com/cjwbw/cutler/badge)](https://replicate.com/cjwbw/cutler)
115
+
116
+
117
+ If you want to run CutLER demos locally,
118
+ 1. Pick a model and its config file from [model zoo](#model-zoo),
119
+ for example, `model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml`.
120
+ 2. We provide `demo.py` that is able to demo builtin configs. Run it with:
121
+ ```
122
+ cd cutler
123
+ python demo/demo.py --config-file model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN_demo.yaml \
124
+ --input demo/imgs/*.jpg \
125
+ [--other-options]
126
+ --opts MODEL.WEIGHTS /path/to/cutler_w_cascade_checkpoint
127
+ ```
128
+ The configs are made for training; therefore, we need to set `MODEL.WEIGHTS` to a model from the model zoo for evaluation.
129
+ This command will run the inference and show visualizations in an OpenCV window.
130
+ <!-- For details of the command line arguments, see `demo.py -h` or look at its source code
131
+ to understand its behavior. Some common arguments are: -->
132
+ * To run __on cpu__, add `MODEL.DEVICE cpu` after `--opts`.
133
+ * To save outputs to a directory (for images) or a file (for webcam or video), use `--output`.
134
+
135
+ Below, we give some visualizations of the model predictions on the demo images.
136
+ <p align="center">
137
+ <img src="docs/cutler-demo.jpg" width=100%>
138
+ </p>
139
+
140
+ ### Unsupervised Model Learning
141
+ Before training the detector, it is necessary to use MaskCut to generate pseudo-masks for all ImageNet data.
142
+ You can either use the pre-generated json file directly by downloading it from [here](http://dl.fbaipublicfiles.com/cutler/maskcut/imagenet_train_fixsize480_tau0.15_N3.json) and placing it under "DETECTRON2_DATASETS/imagenet/annotations/", or generate your own pseudo-masks by following the instructions in [MaskCut](#1-maskcut).
143
+
144
+ We provide a script `train_net.py`, that is made to train all the configs provided in CutLER.
145
+ To train a model with "train_net.py", first setup the ImageNet-1K dataset following [datasets/README.md](datasets/README.md), then run:
146
+ ```
147
+ cd cutler
148
+ export DETECTRON2_DATASETS=/path/to/DETECTRON2_DATASETS/
149
+ python train_net.py --num-gpus 8 \
150
+ --config-file model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml
151
+ ```
152
+
153
+ If you want to train a model using multiple nodes, you may need to adjust [some model parameters](https://arxiv.org/abs/1706.02677) and some SBATCH command options in "tools/train-1node.sh" and "tools/single-node_run.sh", then run:
154
+ ```
155
+ cd cutler
156
+ sbatch tools/train-1node.sh \
157
+ --config-file model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml \
158
+ MODEL.WEIGHTS /path/to/dino/d2format/model \
159
+ OUTPUT_DIR output/
160
+ ```
161
+ You can also convert a pre-trained DINO model to detectron2's format by yourself following [this link](https://github.com/facebookresearch/moco/tree/main/detection).
162
+
163
+ ### Self-training
164
+ We further improve performance by self-training the model on its predictions.
165
+
166
+ First, we can get model predictions on ImageNet by running:
167
+ ```
168
+ python train_net.py --num-gpus 8 \
169
+ --config-file model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml \
170
+ --test-dataset imagenet_train \
171
+ --eval-only TEST.DETECTIONS_PER_IMAGE 30 \
172
+ MODEL.WEIGHTS output/model_final.pth \ # load previous stage/round checkpoints
173
+ OUTPUT_DIR output/ # path to save model predictions
174
+ ```
175
+ Secondly, we can run the following command to generate the json file for the first round of self-training:
176
+ ```
177
+ python tools/get_self_training_ann.py \
178
+ --new-pred output/inference/coco_instances_results.json \ # load model predictions
179
+ --prev-ann DETECTRON2_DATASETS/imagenet/annotations/imagenet_train_fixsize480_tau0.15_N3.json \ # path to the old annotation file.
180
+ --save-path DETECTRON2_DATASETS/imagenet/annotations/cutler_imagenet1k_train_r1.json \ # path to save a new annotation file.
181
+ --threshold 0.7
182
+ ```
183
+ Finally, place "cutler_imagenet1k_train_r1.json" under "DETECTRON2_DATASETS/imagenet/annotations/", then launch the self-training process:
184
+ ```
185
+ python train_net.py --num-gpus 8 \
186
+ --config-file model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN_self_train.yaml \
187
+ --train-dataset imagenet_train_r1 \
188
+ MODEL.WEIGHTS output/model_final.pth \ # load previous stage/round checkpoints
189
+ OUTPUT_DIR output/self-train-r1/ # path to save checkpoints
190
+ ```
191
+
192
+ You can repeat the steps above to perform multiple rounds of self-training and adjust some arguments as needed (e.g., "--threshold" for round 1 and 2 can be set to 0.7 and 0.65, respectively; "--train-dataset" for round 1 and 2 can be set to "imagenet_train_r1" and "imagenet_train_r2", respectively; MODEL.WEIGHTS for round 1 and 2 should point to the previous stage/round checkpoints). Ensure that all annotation files are placed under DETECTRON2_DATASETS/imagenet/annotations/.
193
+ Please ensure that "--train-dataset", json file names and locations match the ones specified in "cutler/data/datasets/builtin.py".
194
+ Please refer to this [instruction](https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html) for guidance on using custom datasets.
195
+
196
+ You can also directly download the MODEL.WEIGHTS and annotations used for each round of self-training:
197
+ <table><tbody>
198
+ <!-- START TABLE -->
199
+ <!-- TABLE BODY -->
200
+ <!-- ROW: round 1 -->
201
+ <tr><td align="center">round 1</td>
202
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_r1.pth">cutler_cascade_r1.pth</a></td>
203
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/maskcut/cutler_imagenet1k_train_r1.json">cutler_imagenet1k_train_r1.json</a></td>
204
+ </tr>
205
+ <!-- ROW: round 2 -->
206
+ <tr><td align="center">round 2</td>
207
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_r2.pth">cutler_cascade_r2.pth</a></td>
208
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/maskcut/cutler_imagenet1k_train_r2.json">cutler_imagenet1k_train_r2.json</a></td>
209
+ </tr>
210
+ </tbody></table>
211
+
212
+ ### Unsupervised Zero-shot Evaluation
213
+ To evaluate a model's performance on 11 different datasets, please refer to [datasets/README.md](datasets/README.md) for instructions on preparing the datasets. Next, select a model from the model zoo, specify the "model_weights", "config_file" and the path to "DETECTRON2_DATASETS" in `tools/eval.sh`, then run the script.
214
+ ```
215
+ bash tools/eval.sh
216
+ ```
217
+
218
+ ### Model Zoo
219
+ We show zero-shot unsupervised object detection performance (AP50&nbsp;|&nbsp;AR) on 11 different datasets spanning a variety of domains. ^: CutLER using Mask R-CNN as a detector; *: CutLER using Cascade Mask R-CNN as a detector.
220
+ <table><tbody>
221
+ <!-- START TABLE -->
222
+ <!-- TABLE HEADER -->
223
+ <tr><th valign="bottom">Methods</th>
224
+ <th valign="bottom">Models</th>
225
+ <th valign="bottom">COCO</th>
226
+ <th valign="bottom">COCO20K</th>
227
+ <th valign="bottom">VOC</th>
228
+ <th valign="bottom">LVIS</th>
229
+ <th valign="bottom">UVO</th>
230
+ <th valign="bottom">Clipart</th>
231
+ <th valign="bottom">Comic</th>
232
+ <th valign="bottom">Watercolor</th>
233
+ <th valign="bottom">KITTI</th>
234
+ <th valign="bottom">Objects365</th>
235
+ <th valign="bottom">OpenImages</th>
236
+ <!-- TABLE BODY -->
237
+ </tr>
238
+ <tr><td align="center">Prev. SOTA</td>
239
+ <td valign="bottom">-</td>
240
+ <td align="center">9.6&nbsp;|&nbsp;12.6</td>
241
+ <td align="center">9.7&nbsp;|&nbsp;12.6</td>
242
+ <td align="center">15.9&nbsp;|&nbsp;21.3</td>
243
+ <td align="center">3.8&nbsp;|&nbsp;6.4</td>
244
+ <td align="center">10.0&nbsp;|&nbsp;14.2</td>
245
+ <td align="center">7.9&nbsp;|&nbsp;15.1</td>
246
+ <td align="center">9.9&nbsp;|&nbsp;16.3</td>
247
+ <td align="center">6.7&nbsp;|&nbsp;16.2</td>
248
+ <td align="center">7.7&nbsp;|&nbsp;7.1</td>
249
+ <td align="center">8.1&nbsp;|&nbsp;10.2</td>
250
+ <td align="center">9.9&nbsp;|&nbsp;14.9</td>
251
+ </tr>
252
+ <!-- ROW: Box/Mask AP for CutLER -->
253
+ </tr>
254
+ <tr><td align="center">CutLER^</td>
255
+ <td valign="bottom"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_mrcnn_final.pth">download</a></td>
256
+ <td align="center">21.1&nbsp;|&nbsp;29.6</td>
257
+ <td align="center">21.6&nbsp;|&nbsp;30.0</td>
258
+ <td align="center">36.6&nbsp;|&nbsp;41.0</td>
259
+ <td align="center">7.7&nbsp;|&nbsp;18.7</td>
260
+ <td align="center">29.8&nbsp;|&nbsp;38.4</td>
261
+ <td align="center">20.9&nbsp;|&nbsp;38.5</td>
262
+ <td align="center">31.2&nbsp;|&nbsp;37.1</td>
263
+ <td align="center">37.3&nbsp;|&nbsp;39.9</td>
264
+ <td align="center">15.3&nbsp;|&nbsp;25.4</td>
265
+ <td align="center">19.5&nbsp;|&nbsp;30.0</td>
266
+ <td align="center">17.1&nbsp;|&nbsp;26.4</td>
267
+ </tr>
268
+ <!-- ROW: Box/Mask AP for CutLER -->
269
+ </tr>
270
+ <tr><td align="center">CutLER*</td>
271
+ <td valign="bottom"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth">download</a></td>
272
+ <td align="center">21.9&nbsp;|&nbsp;32.7</td>
273
+ <td align="center">22.4&nbsp;|&nbsp;33.1</td>
274
+ <td align="center">36.9&nbsp;|&nbsp;44.3</td>
275
+ <td align="center">8.4&nbsp;|&nbsp;21.8</td>
276
+ <td align="center">31.7&nbsp;|&nbsp;42.8</td>
277
+ <td align="center">21.1&nbsp;|&nbsp;41.3</td>
278
+ <td align="center">30.4&nbsp;|&nbsp;38.6</td>
279
+ <td align="center">37.5&nbsp;|&nbsp;44.6</td>
280
+ <td align="center">18.4&nbsp;|&nbsp;27.5</td>
281
+ <td align="center">21.6&nbsp;|&nbsp;34.2</td>
282
+ <td align="center">17.3&nbsp;|&nbsp;29.6</td>
283
+ </tr>
284
+ </tbody></table>
285
+
286
+ ## Semi-supervised and Fully-supervised Learning
287
+ CutLER can also serve as a pretrained model for training fully supervised object detection and instance segmentation models and improves performance on COCO, including on few-shot benchmarks.
288
+
289
+ ### Training & Evaluation in Command Line
290
+ You can find all the semi-supervised and fully-supervised learning configs provided in CutLER under `model_zoo/configs/COCO-Semisupervised`.
291
+
292
+ To train a model using K% labels with `train_net.py`, first set up the COCO dataset according to [datasets/README.md](datasets/README.md) and specify K value in the config file, then run:
293
+ ```
294
+ python train_net.py --num-gpus 8 \
295
+ --config-file model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_{K}perc.yaml \
296
+ MODEL.WEIGHTS /path/to/cutler_pretrained_model
297
+ ```
298
+
299
+ You can find all config files used to train supervised models under `model_zoo/configs/COCO-Semisupervised`.
300
+ The configs are made for 8-GPU training. To train on 1 GPU, you may need to [change some parameters](https://arxiv.org/abs/1706.02677), e.g. number of GPUs (num-gpus your_num_gpus), learning rates (SOLVER.BASE_LR your_base_lr) and batch size (SOLVER.IMS_PER_BATCH your_batch_size).
301
+
302
+ ### Evaluation
303
+ To evaluate a model's performance, use
304
+ ```
305
+ python train_net.py \
306
+ --config-file model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_{K}perc.yaml \
307
+ --eval-only MODEL.WEIGHTS /path/to/checkpoint_file
308
+ ```
309
+ For more options, see `python train_net.py -h`.
310
+
311
+ ### Model Zoo
312
+ We fine-tune a Cascade R-CNN model initialized with CutLER or MoCo-v2 on varying amounts of labeled COCO data, and show results (Box&nbsp;|&nbsp;Mask AP) on the val2017 split below:
313
+
314
+ <table><tbody>
315
+ <!-- START TABLE -->
316
+ <!-- TABLE HEADER -->
317
+ <th valign="bottom">% of labels</th>
318
+ <th valign="bottom">1%</th>
319
+ <th valign="bottom">2%</th>
320
+ <th valign="bottom">5%</th>
321
+ <th valign="bottom">10%</th>
322
+ <th valign="bottom">20%</th>
323
+ <th valign="bottom">30%</th>
324
+ <th valign="bottom">40%</th>
325
+ <th valign="bottom">50%</th>
326
+ <th valign="bottom">60%</th>
327
+ <th valign="bottom">80%</th>
328
+ <th valign="bottom">100%</th>
329
+ <!-- TABLE BODY -->
330
+ <!-- ROW: Box/Mask AP for MoCo-v2 -->
331
+ <tr><td align="center">MoCo-v2</td>
332
+ <td align="center">11.8&nbsp;|&nbsp;10.0</td>
333
+ <td align="center">16.2&nbsp;|&nbsp;13.8</td>
334
+ <td align="center">20.5&nbsp;|&nbsp;17.8</td>
335
+ <td align="center">26.5&nbsp;|&nbsp;23.0</td>
336
+ <td align="center">32.5&nbsp;|&nbsp;28.2</td>
337
+ <td align="center">35.5&nbsp;|&nbsp;30.8</td>
338
+ <td align="center">37.3&nbsp;|&nbsp;32.3</td>
339
+ <td align="center">38.7&nbsp;|&nbsp;33.6</td>
340
+ <td align="center">39.9&nbsp;|&nbsp;34.6</td>
341
+ <td align="center">41.6&nbsp;|&nbsp;36.0</td>
342
+ <td align="center">42.8&nbsp;|&nbsp;37.0</td>
343
+ </tr>
344
+ <!-- ROW: Box/Mask AP for CutLER -->
345
+ <tr><td align="center">CutLER</td>
346
+ <td align="center">16.8&nbsp;|&nbsp;14.6</td>
347
+ <td align="center">21.6&nbsp;|&nbsp;18.9</td>
348
+ <td align="center">27.8&nbsp;|&nbsp;24.3</td>
349
+ <td align="center">32.2&nbsp;|&nbsp;28.1</td>
350
+ <td align="center">36.6&nbsp;|&nbsp;31.7</td>
351
+ <td align="center">38.2&nbsp;|&nbsp;33.3</td>
352
+ <td align="center">39.9&nbsp;|&nbsp;34.7</td>
353
+ <td align="center">41.5&nbsp;|&nbsp;35.9</td>
354
+ <td align="center">42.3&nbsp;|&nbsp;36.7</td>
355
+ <td align="center">43.8&nbsp;|&nbsp;37.9</td>
356
+ <td align="center">44.7&nbsp;|&nbsp;38.5</td>
357
+ </tr>
358
+ <!-- ROW: Model Downloads -->
359
+ <tr><td align="center">Download</td>
360
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_1perc.pth">model</a></td>
361
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_2perc.pth">model</a></td>
362
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_5perc.pth">model</a></td>
363
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_10perc.pth">model</a></td>
364
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_20perc.pth">model</a></td>
365
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_30perc.pth">model</a></td>
366
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_40perc.pth">model</a></td>
367
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_50perc.pth">model</a></td>
368
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_60perc.pth">model</a></td>
369
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_80perc.pth">model</a></td>
370
+ <td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_fully_100perc.pth">model</a></td>
371
+ </tr>
372
+ </tbody></table>
373
+
374
+ Both MoCo-v2 and our CutLER are trained for the 1x schedule using Detectron2, except for extremely low-shot settings with 1% or 2% labels. When training with 1% or 2% labels, we train both MoCo-v2 and our model for 3,600 iterations with a batch size of 16.
375
+
376
+ ## License
377
+ The majority of CutLER, Detectron2 and DINO are licensed under the [CC-BY-NC license](LICENSE); however, portions of the project are available under separate license terms: TokenCut, Bilateral Solver and CRF are licensed under the MIT license. If you later add other third-party code, please keep this license info updated, and please let us know if that component is licensed under something other than CC-BY-NC, MIT, or CC0.
378
+
379
+ ## Ethical Considerations
380
+ CutLER's wide range of detection capabilities may introduce similar challenges to many other visual recognition methods.
381
+ As the image can contain arbitrary instances, it may impact the model output.
382
+
383
+ ## How to get support from us?
384
+ If you have any general questions, feel free to email us at [Xudong Wang](mailto:xdwang@eecs.berkeley.edu), [Ishan Misra](mailto:imisra@meta.com) and [Rohit Girdhar](mailto:rgirdhar@meta.com). If you have code or implementation-related questions, please feel free to send emails to us or open an issue in this codebase (We recommend that you open an issue in this codebase, because your questions may help others).
385
+
386
+ ## Citation
387
+ If you find our work inspiring or use our codebase in your research, please consider giving a star ⭐ and a citation.
388
+ ```
389
+ @inproceedings{wang2023cut,
390
+ title={Cut and learn for unsupervised object detection and instance segmentation},
391
+ author={Wang, Xudong and Girdhar, Rohit and Yu, Stella X and Misra, Ishan},
392
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
393
+ pages={3124--3134},
394
+ year={2023}
395
+ }
396
+ ```
397
+
398
+ ```
399
+ @article{wang2023videocutler,
400
+ title={VideoCutLER: Surprisingly Simple Unsupervised Video Instance Segmentation},
401
+ author={Wang, Xudong and Misra, Ishan and Zeng, Ziyun and Girdhar, Rohit and Darrell, Trevor},
402
+ journal={arXiv preprint arXiv:2308.14710},
403
+ year={2023}
404
+ }
405
+ ```
cog.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ build:
2
+ gpu: true
3
+ cuda: "11.6"
4
+ python_version: "3.8"
5
+ python_packages:
6
+ - "torch==1.11.0"
7
+ - "torchvision==0.12.0"
8
+ - "faiss-gpu==1.7.2"
9
+ - "opencv-python==4.6.0.66"
10
+ - "scikit-image==0.19.2"
11
+ - "scikit-learn==1.1.1"
12
+ - "shapely==1.8.2"
13
+ - "timm==0.5.4"
14
+ - "pyyaml==6.0"
15
+ - "colored==1.4.4"
16
+ - "fvcore==0.1.5.post20220512"
17
+ - "gdown==4.5.4"
18
+ - "pycocotools==2.0.6"
19
+ - "numpy==1.20.0"
20
+
21
+ run:
22
+ - pip install git+https://github.com/lucasb-eyer/pydensecrf.git
23
+
24
+ predict: "maskcut/predict.py:Predictor"
cutler/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import config
4
+ import engine
5
+ import modeling
6
+ import structures
7
+ import tools
8
+ import demo
9
+
10
+ # dataset loading
11
+ from . import data # register all new datasets
12
+ from data import datasets # register all new datasets
13
+ from solver import *
14
+
15
+ # from .data import register_all_imagenet
cutler/config/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from .cutler_config import add_cutler_config
cutler/config/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (267 Bytes). View file
 
cutler/config/__pycache__/cutler_config.cpython-312.pyc ADDED
Binary file (1.22 kB). View file
 
cutler/config/cutler_config.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from detectron2.config import CfgNode as CN
4
+
5
+ def add_cutler_config(cfg):
6
+ cfg.DATALOADER.COPY_PASTE = False
7
+ cfg.DATALOADER.COPY_PASTE_RATE = 0.0
8
+ cfg.DATALOADER.COPY_PASTE_MIN_RATIO = 0.5
9
+ cfg.DATALOADER.COPY_PASTE_MAX_RATIO = 1.0
10
+ cfg.DATALOADER.COPY_PASTE_RANDOM_NUM = True
11
+ cfg.DATALOADER.VISUALIZE_COPY_PASTE = False
12
+
13
+ cfg.MODEL.ROI_HEADS.USE_DROPLOSS = False
14
+ cfg.MODEL.ROI_HEADS.DROPLOSS_IOU_THRESH = 0.0
15
+
16
+ cfg.SOLVER.BASE_LR_MULTIPLIER = 1
17
+ cfg.SOLVER.BASE_LR_MULTIPLIER_NAMES = []
18
+
19
+ cfg.TEST.NO_SEGM = False
cutler/data/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from . import datasets # ensure the builtin datasets are registered
4
+ from .detection_utils import * # isort:skip
5
+ from .build import (
6
+ build_batch_data_loader,
7
+ build_detection_train_loader,
8
+ build_detection_test_loader,
9
+ get_detection_dataset_dicts,
10
+ load_proposals_into_dataset,
11
+ print_instances_class_histogram,
12
+ )
13
+ from detectron2.data.common import *
14
+
15
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
cutler/data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (793 Bytes). View file
 
cutler/data/__pycache__/build.cpython-312.pyc ADDED
Binary file (24.3 kB). View file
 
cutler/data/__pycache__/dataset_mapper.cpython-312.pyc ADDED
Binary file (8.69 kB). View file
 
cutler/data/__pycache__/detection_utils.cpython-312.pyc ADDED
Binary file (27.6 kB). View file
 
cutler/data/build.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/build.py
3
+
4
+ import itertools
5
+ import logging
6
+ import numpy as np
7
+ import operator
8
+ import pickle
9
+ from typing import Any, Callable, Dict, List, Optional, Union
10
+ import torch
11
+ import torch.utils.data as torchdata
12
+ from tabulate import tabulate
13
+ from termcolor import colored
14
+
15
+ from detectron2.config import configurable
16
+ from detectron2.structures import BoxMode
17
+ from detectron2.utils.comm import get_world_size
18
+ from detectron2.utils.env import seed_all_rng
19
+ from detectron2.utils.file_io import PathManager
20
+ from detectron2.utils.logger import _log_api_usage, log_first_n
21
+
22
+ from detectron2.data.catalog import DatasetCatalog, MetadataCatalog
23
+ from detectron2.data.common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset
24
+ from data.dataset_mapper import DatasetMapper
25
+ from data.detection_utils import check_metadata_consistency
26
+ from detectron2.data.samplers import (
27
+ InferenceSampler,
28
+ RandomSubsetTrainingSampler,
29
+ RepeatFactorTrainingSampler,
30
+ TrainingSampler,
31
+ )
32
+
33
+ """
34
+ This file contains the default logic to build a dataloader for training or testing.
35
+ """
36
+
37
+ __all__ = [
38
+ "build_batch_data_loader",
39
+ "build_detection_train_loader",
40
+ "build_detection_test_loader",
41
+ "get_detection_dataset_dicts",
42
+ "load_proposals_into_dataset",
43
+ "print_instances_class_histogram",
44
+ ]
45
+
46
+
47
def filter_images_with_only_crowd_annotations(dataset_dicts):
    """
    Filter out images with no annotations or only crowd annotations
    (i.e., images without non-crowd annotations).
    A common training-time preprocessing on COCO dataset.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
            Every dict must have an "annotations" key.

    Returns:
        list[dict]: the same format, but filtered.
    """
    num_before = len(dataset_dicts)

    def valid(anns):
        # An annotation without an "iscrowd" field counts as non-crowd.
        return any(ann.get("iscrowd", 0) == 0 for ann in anns)

    dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
    num_after = len(dataset_dicts)
    # Report through the logger only: a bare print() would duplicate the
    # message on stdout and ignore the configured log level and handlers.
    logger = logging.getLogger(__name__)
    logger.info(
        "Removed {} images with no usable annotations. {} images left.".format(
            num_before - num_after, num_after
        )
    )
    return dataset_dicts
79
+
80
+
81
def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image):
    """
    Drop the images whose annotations carry too few visible keypoints.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
        min_keypoints_per_image (int): an image is kept only when the number
            of visible keypoints over all of its annotations is at least
            this value.

    Returns:
        list[dict]: the same format as dataset_dicts, but filtered.
    """

    def count_visible(record):
        # A "keypoints" field is laid out as [x1, y1, v1, ...]; every third
        # value is the visibility flag and a keypoint counts when it is > 0.
        total = 0
        for ann in record["annotations"]:
            if "keypoints" not in ann:
                continue
            visibility = np.array(ann["keypoints"][2::3])
            total = total + (visibility > 0).sum()
        return total

    kept = []
    for record in dataset_dicts:
        if count_visible(record) >= min_keypoints_per_image:
            kept.append(record)

    num_removed = len(dataset_dicts) - len(kept)
    logging.getLogger(__name__).info(
        "Removed {} images with fewer than {} keypoints.".format(
            num_removed, min_keypoints_per_image
        )
    )
    return kept
113
+
114
+
115
def load_proposals_into_dataset(dataset_dicts, proposal_file):
    """
    Attach precomputed object proposals to every record of a dataset.

    The proposal file should be a pickled dict with the following keys:

    - "ids": list[int] or list[str], the image ids
    - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
    - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores
      corresponding to the boxes.
    - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
        proposal_file (str): file path of pre-computed proposals, in pkl format.

    Returns:
        list[dict]: the same records (modified in place), each with the added
        "proposal_boxes", "proposal_objectness_logits" and "proposal_bbox_mode"
        fields.

    Raises:
        KeyError: if a record's image id has no entry in the proposal file.
    """
    log = logging.getLogger(__name__)
    log.info("Loading proposals from: {}".format(proposal_file))

    # NOTE(review): unpickling runs arbitrary code; only load trusted files.
    with PathManager.open(proposal_file, "rb") as fh:
        proposals = pickle.load(fh, encoding="latin1")

    # Map the key names used by D1 proposal files onto the current ones.
    for old_key, new_key in (("indexes", "ids"), ("scores", "objectness_logits")):
        if old_key in proposals:
            proposals[new_key] = proposals.pop(old_key)

    # Image ids may be int or str on either side, so compare them as strings.
    wanted_ids = {str(record["image_id"]) for record in dataset_dicts}
    index_of = {}
    for position, image_id in enumerate(proposals["ids"]):
        key = str(image_id)
        if key in wanted_ids:
            index_of[key] = position

    # Proposals without an explicit mode are taken to be 'XYXY_ABS'.
    if "bbox_mode" in proposals:
        bbox_mode = BoxMode(proposals["bbox_mode"])
    else:
        bbox_mode = BoxMode.XYXY_ABS

    for record in dataset_dicts:
        position = index_of[str(record["image_id"])]

        boxes = proposals["boxes"][position]
        scores = proposals["objectness_logits"][position]
        # Highest-scoring proposals first.
        order = scores.argsort()[::-1]
        record["proposal_boxes"] = boxes[order]
        record["proposal_objectness_logits"] = scores[order]
        record["proposal_bbox_mode"] = bbox_mode

    return dataset_dicts
167
+
168
+
169
def print_instances_class_histogram(dataset_dicts, class_names):
    """
    Log a table with the number of non-crowd instances of every category.

    Args:
        dataset_dicts (list[dict]): list of dataset dicts.
        class_names (list[str]): list of class names (zero-indexed).

    Raises:
        AssertionError: if a non-crowd annotation has a ``category_id``
            outside ``[0, len(class_names))``.
    """
    num_classes = len(class_names)
    hist_bins = np.arange(num_classes + 1)
    # Use the builtin ``int`` as dtype: the ``np.int`` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    histogram = np.zeros((num_classes,), dtype=int)
    for entry in dataset_dicts:
        annos = entry["annotations"]
        classes = np.asarray(
            [x["category_id"] for x in annos if not x.get("iscrowd", 0)], dtype=int
        )
        if len(classes):
            assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
            assert (
                classes.max() < num_classes
            ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
        histogram += np.histogram(classes, bins=hist_bins)[0]

    # Each category takes two columns (name, count); at most three per row.
    N_COLS = min(6, len(class_names) * 2)

    def short_name(x):
        # make long class names shorter. useful for lvis
        if len(x) > 13:
            return x[:11] + ".."
        return x

    # Flat list [name0, count0, name1, count1, ...].
    data = list(
        itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)])
    )
    total_num_instances = sum(data[1::2])
    # Pad to a row boundary so that the "total" entry starts a new row.
    data.extend([None] * (N_COLS - (len(data) % N_COLS)))
    if num_classes > 1:
        data.extend(["total", total_num_instances])
    # Re-shape the flat list into rows of N_COLS cells.
    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
    table = tabulate(
        data,
        headers=["category", "#instances"] * (N_COLS // 2),
        tablefmt="pipe",
        numalign="left",
        stralign="center",
    )
    log_first_n(
        logging.INFO,
        "Distribution of instances among all {} categories:\n".format(num_classes)
        + colored(table, "cyan"),
        key="message",
    )
219
+
220
+
221
def get_detection_dataset_dicts(
    names,
    filter_empty=True,
    min_keypoints=0,
    proposal_files=None,
    check_consistency=True,
):
    """
    Fetch one or more registered datasets and merge them into a single list
    of dataset dicts, optionally attaching proposals and filtering images.

    Args:
        names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to filter out images without instance annotations
        min_keypoints (int): filter out images with fewer keypoints than
            `min_keypoints`. Set to 0 to do nothing.
        proposal_files (list[str]): if given, a list of object proposal files
            that match each dataset in `names`.
        check_consistency (bool): whether to check if datasets have consistent metadata.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.
        If the first registered dataset is a ``torch.utils.data.Dataset``, the
        dataset itself (or a ``ConcatDataset`` of all of them) is returned
        instead, without any filtering.
    """
    if isinstance(names, str):
        names = [names]
    assert len(names), names
    per_dataset = [DatasetCatalog.get(name) for name in names]

    if isinstance(per_dataset[0], torchdata.Dataset):
        if len(per_dataset) == 1:
            return per_dataset[0]
        # ConcatDataset does not work for iterable style dataset.
        # We could support concat for iterable as well, but it's often
        # not a good idea to concat iterables anyway.
        return torchdata.ConcatDataset(per_dataset)

    for name, records in zip(names, per_dataset):
        assert len(records), "Dataset '{}' is empty!".format(name)

    if proposal_files is not None:
        assert len(names) == len(proposal_files)
        # Attach the precomputed proposals, one proposal file per dataset.
        per_dataset = [
            load_proposals_into_dataset(records, path)
            for records, path in zip(per_dataset, proposal_files)
        ]

    merged = list(itertools.chain.from_iterable(per_dataset))

    # Only the first record is inspected to decide whether instance
    # annotations are present.
    has_instances = "annotations" in merged[0]
    if has_instances:
        if filter_empty:
            merged = filter_images_with_only_crowd_annotations(merged)
        if min_keypoints > 0:
            merged = filter_images_with_few_keypoints(merged, min_keypoints)

        if check_consistency:
            try:
                class_names = MetadataCatalog.get(names[0]).thing_classes
                check_metadata_consistency("thing_classes", names)
                print_instances_class_histogram(merged, class_names)
            except AttributeError:  # class names are not available for this dataset
                pass

    assert len(merged), "No valid data found in {}.".format(",".join(names))
    return merged
285
+
286
+
287
def build_batch_data_loader(
    dataset,
    sampler,
    total_batch_size,
    *,
    aspect_ratio_grouping=False,
    num_workers=0,
    collate_fn=None,
):
    """
    Build a batched dataloader. Compared with `torch.utils.data.DataLoader` it
    1. optionally groups images by aspect ratio, and
    2. by default performs no batch collation (a batch is a plain list).

    Args:
        dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices.
            Must be provided iff. ``dataset`` is a map-style dataset.
        total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see
            :func:`build_detection_train_loader`.

    Returns:
        iterable[list]. Length of each list is the batch size of the current
            GPU. Each element in the list comes from the dataset.
    """
    world_size = get_world_size()
    assert (
        total_batch_size > 0 and total_batch_size % world_size == 0
    ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
        total_batch_size, world_size
    )
    # Every process gets an equal share of the total batch.
    per_process_batch = total_batch_size // world_size

    if not isinstance(dataset, torchdata.IterableDataset):
        dataset = ToIterableDataset(dataset, sampler)
    else:
        assert sampler is None, "sampler must be None if dataset is IterableDataset"

    if not aspect_ratio_grouping:
        batch_collate = collate_fn if collate_fn is not None else trivial_batch_collator
        return torchdata.DataLoader(
            dataset,
            batch_size=per_process_batch,
            drop_last=True,
            num_workers=num_workers,
            collate_fn=batch_collate,
            worker_init_fn=worker_init_reset_seed,
        )

    # Yield individual mapped dicts (no batching here); the grouping wrapper
    # below is what assembles the batches.
    single_item_loader = torchdata.DataLoader(
        dataset,
        num_workers=num_workers,
        collate_fn=operator.itemgetter(0),
        worker_init_fn=worker_init_reset_seed,
    )
    grouped = AspectRatioGroupedDataset(single_item_loader, per_process_batch)
    if collate_fn is not None:
        return MapDataset(grouped, collate_fn)
    return grouped
345
+
346
+
347
def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    """
    Translate ``cfg`` into the keyword arguments of
    :func:`build_detection_train_loader`.

    Args:
        cfg: the config node to read dataset, mapper and sampler options from.
        mapper: if None, ``DatasetMapper(cfg, True)`` is used.
        dataset: if None, the datasets named in ``cfg.DATASETS.TRAIN`` are loaded.
        sampler: if None, it is chosen from ``cfg.DATALOADER.SAMPLER_TRAIN``
            (left as None for an iterable dataset).

    Returns:
        dict: keyword arguments for :func:`build_detection_train_loader`.

    Raises:
        ValueError: if ``cfg.DATALOADER.SAMPLER_TRAIN`` names an unknown sampler.
    """
    if dataset is None:
        if cfg.MODEL.KEYPOINT_ON:
            min_keypoints = cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
        else:
            min_keypoints = 0
        if cfg.MODEL.LOAD_PROPOSALS:
            proposal_files = cfg.DATASETS.PROPOSAL_FILES_TRAIN
        else:
            proposal_files = None
        dataset = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=min_keypoints,
            proposal_files=proposal_files,
        )
        _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is None:
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        log = logging.getLogger(__name__)
        if isinstance(dataset, torchdata.IterableDataset):
            # An iterable dataset defines its own order; keep sampler as None.
            log.info("Not using any sampler since the dataset is IterableDataset.")
        else:
            log.info("Using training sampler {}".format(sampler_name))
            if sampler_name == "TrainingSampler":
                sampler = TrainingSampler(len(dataset))
            elif sampler_name == "RepeatFactorTrainingSampler":
                factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                    dataset, cfg.DATALOADER.REPEAT_THRESHOLD
                )
                sampler = RepeatFactorTrainingSampler(factors)
            elif sampler_name == "RandomSubsetTrainingSampler":
                sampler = RandomSubsetTrainingSampler(
                    len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
                )
            else:
                raise ValueError("Unknown training sampler: {}".format(sampler_name))

    loader_kwargs = {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }
    return loader_kwargs
392
+
393
+
394
@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset,
    *,
    mapper,
    sampler=None,
    total_batch_size,
    aspect_ratio_grouping=True,
    num_workers=0,
    collate_fn=None,
):
    """
    Build a training dataloader for object detection.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a pytorch dataset (either map-style or iterable). It can be obtained
            by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model. Pass None to skip mapping.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
            indices to be applied on ``dataset``.
            If ``dataset`` is map-style and no sampler is given, a
            :class:`TrainingSampler` over the whole dataset is created.
            Sampler must be None if ``dataset`` is iterable.
        total_batch_size (int): total batch size across all training processes.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers
        collate_fn: a function that determines how to do batching, same as the argument of
            `torch.utils.data.DataLoader`. Defaults to do no collation and return a list of
            data. No collation is OK for small batch size and simple data structures.
            If your batch size is large and each sample contains too many small tensors,
            it's more efficient to collate them in data loader.

    Returns:
        torch.utils.data.DataLoader:
            a dataloader. Each output from it is a ``list[mapped_element]`` whose
            length is ``total_batch_size`` divided by the number of training
            processes (see :func:`build_batch_data_loader`), where
            ``mapped_element`` is produced by the ``mapper``.
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)

    is_iterable = isinstance(dataset, torchdata.IterableDataset)
    if is_iterable:
        assert sampler is None, "sampler must be None if dataset is IterableDataset"
    else:
        if sampler is None:
            sampler = TrainingSampler(len(dataset))
        assert isinstance(sampler, torchdata.Sampler), f"Expect a Sampler but got {type(sampler)}"

    loader = build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )
    return loader
456
+
457
+
458
def _test_loader_from_config(cfg, dataset_name, mapper=None):
    """
    Translate ``cfg`` into the keyword arguments of
    :func:`build_detection_test_loader`.

    Uses the given `dataset_name` argument (instead of the names in cfg), because the
    standard practice is to evaluate each test set individually (not combining them).
    """
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]

    proposal_files = None
    if cfg.MODEL.LOAD_PROPOSALS:
        # A proposal file is matched to its dataset by the dataset's position
        # in cfg.DATASETS.TEST.
        test_names = list(cfg.DATASETS.TEST)
        proposal_files = [
            cfg.DATASETS.PROPOSAL_FILES_TEST[test_names.index(name)] for name in dataset_name
        ]

    dataset = get_detection_dataset_dicts(
        dataset_name,
        filter_empty=False,
        proposal_files=proposal_files,
    )
    if mapper is None:
        mapper = DatasetMapper(cfg, False)

    loader_kwargs = {
        "dataset": dataset,
        "mapper": mapper,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }
    # An iterable dataset cannot be indexed, so it gets no sampler.
    if isinstance(dataset, torchdata.IterableDataset):
        loader_kwargs["sampler"] = None
    else:
        loader_kwargs["sampler"] = InferenceSampler(len(dataset))
    return loader_kwargs
485
+
486
+
487
@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(
    dataset: Union[List[Any], torchdata.Dataset],
    *,
    mapper: Callable[[Dict[str, Any]], Any],
    sampler: Optional[torchdata.Sampler] = None,
    batch_size: int = 1,
    num_workers: int = 0,
    collate_fn: Optional[Callable[[List[Any]], Any]] = None,
) -> torchdata.DataLoader:
    """
    Build an inference dataloader. It mirrors `build_detection_train_loader`,
    but the batch size defaults to 1, the last incomplete batch is kept, and
    the default sampler is :class:`InferenceSampler`.

    Args:
        dataset: a list of dataset dicts,
            or a pytorch dataset (either map-style or iterable). They can be obtained
            by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper: a callable which takes a sample (dict) from dataset
           and returns the format to be consumed by the model. Pass None to skip mapping.
           When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
        sampler: a sampler that produces
            indices to be applied on ``dataset``. Default to :class:`InferenceSampler`
            over the whole dataset. Sampler must be None
            if `dataset` is iterable.
        batch_size: the batch size of the data loader to be created.
            Default to 1 image per worker since this is the standard when reporting
            inference time in papers.
        num_workers: number of parallel data loading workers
        collate_fn: same as the argument of `torch.utils.data.DataLoader`.
            Defaults to do no collation and return a list of data.

    Returns:
        DataLoader: a torch DataLoader, that loads the given detection
           dataset, with test-time transformation and batching.

    Examples:
    ::
        data_loader = build_detection_test_loader(
            DatasetCatalog.get("my_test"),
            mapper=DatasetMapper(...))

        # or, instantiate with a CfgNode:
        data_loader = build_detection_test_loader(cfg, "my_test")
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)

    if not isinstance(dataset, torchdata.IterableDataset):
        if sampler is None:
            sampler = InferenceSampler(len(dataset))
    else:
        assert sampler is None, "sampler must be None if dataset is IterableDataset"

    batch_collate = collate_fn if collate_fn is not None else trivial_batch_collator
    return torchdata.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=False,
        num_workers=num_workers,
        collate_fn=batch_collate,
    )
550
+
551
+
552
def trivial_batch_collator(batch):
    """
    Identity collate function: hand the batch back exactly as received.

    Args:
        batch: the samples gathered for one batch.

    Returns:
        The very same ``batch`` object, untouched.
    """
    return batch
557
+
558
+
559
def worker_init_reset_seed(worker_id):
    """
    Re-seed the random number generators of a data loading worker.

    The seed is torch's initial seed reduced modulo 2**31, offset by the
    worker id so that each worker gets a different value.

    Args:
        worker_id (int): id of the worker, as passed to ``worker_init_fn``.
    """
    base_seed = torch.initial_seed() % (1 << 31)
    seed_all_rng(base_seed + worker_id)
cutler/data/dataset_mapper.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/dataset_mapper.py
3
+
4
+ import copy
5
+ import logging
6
+ import numpy as np
7
+ from typing import List, Optional, Union
8
+ import torch
9
+
10
+ from detectron2.config import configurable
11
+
12
+ import data.detection_utils as utils
13
+ import data.transforms as T
14
+
15
+ """
16
+ This file contains the default mapping that's applied to "dataset dicts".
17
+ """
18
+
19
+ __all__ = ["DatasetMapper"]
20
+
21
+
22
class DatasetMapper:
    """
    A callable that turns one dataset dict (Detectron2 Dataset format) into
    the format consumed by the model.

    This is the default mapper for training data. Use it as a template when
    you need customized logic, such as a different way to read or transform
    images. See :doc:`/tutorials/data_loading` for details.

    Calling the mapper performs these steps:

    1. Read the image from "file_name"
    2. Apply the configured augmentations/transforms to the image and annotations
    3. Convert the image and annotations to Tensor and :class:`Instances`
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        use_instance_mask: bool = False,
        use_keypoint: bool = False,
        instance_mask_format: str = "polygon",
        keypoint_hflip_indices: Optional[np.ndarray] = None,
        precomputed_proposal_topk: Optional[int] = None,
        recompute_boxes: bool = False,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: whether it's used in training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            use_instance_mask: whether to process instance segmentation annotations, if available
            use_keypoint: whether to process keypoint annotations if available
            instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
                masks into this format.
            keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
            precomputed_proposal_topk: if given, will load pre-computed
                proposals from dataset_dict and keep the top k proposals for each image.
            recompute_boxes: whether to overwrite bounding box annotations
                by computing tight bounding boxes from instance mask annotations.
        """
        if recompute_boxes:
            assert use_instance_mask, "recompute_boxes requires instance masks"

        self.is_train = is_train
        self.augmentations = T.AugmentationList(augmentations)
        self.image_format = image_format
        self.use_instance_mask = use_instance_mask
        self.instance_mask_format = instance_mask_format
        self.use_keypoint = use_keypoint
        self.keypoint_hflip_indices = keypoint_hflip_indices
        self.proposal_topk = precomputed_proposal_topk
        self.recompute_boxes = recompute_boxes

        mode = "training" if is_train else "inference"
        logging.getLogger(__name__).info(
            f"[DatasetMapper] Augmentations used in {mode}: {augmentations}"
        )

    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        """
        Translate ``cfg`` into the keyword arguments of ``__init__``.

        Args:
            cfg: the config node to read the mapper options from.
            is_train: whether the mapper is built for training or inference.

        Returns:
            dict: keyword arguments for the constructor.
        """
        augs = utils.build_augmentation(cfg, is_train)
        recompute_boxes = False
        if cfg.INPUT.CROP.ENABLED and is_train:
            # Random cropping goes first; with masks enabled the boxes are
            # then recomputed from the cropped masks.
            augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
            recompute_boxes = cfg.MODEL.MASK_ON

        init_kwargs = {
            "is_train": is_train,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "use_instance_mask": cfg.MODEL.MASK_ON,
            "instance_mask_format": cfg.INPUT.MASK_FORMAT,
            "use_keypoint": cfg.MODEL.KEYPOINT_ON,
            "recompute_boxes": recompute_boxes,
        }

        if cfg.MODEL.KEYPOINT_ON:
            init_kwargs["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(
                cfg.DATASETS.TRAIN
            )

        if cfg.MODEL.LOAD_PROPOSALS:
            if is_train:
                topk = cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
            else:
                topk = cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            init_kwargs["precomputed_proposal_topk"] = topk
        return init_kwargs

    def _transform_annotations(self, dataset_dict, transforms, image_shape):
        """
        Replace ``dataset_dict["annotations"]`` with ``dataset_dict["instances"]``.

        Args:
            dataset_dict (dict): the record being mapped; modified in place.
            transforms: the transforms that were applied to the image.
            image_shape: (height, width) of the transformed image.
        """
        # USER: Modify this if you want to keep them for some reason.
        unused_fields = []
        if not self.use_instance_mask:
            unused_fields.append("segmentation")
        if not self.use_keypoint:
            unused_fields.append("keypoints")
        for anno in dataset_dict["annotations"]:
            for field in unused_fields:
                anno.pop(field, None)

        # USER: Implement additional transformations if you have other types of data
        annos = []
        for obj in dataset_dict.pop("annotations"):
            # Crowd annotations are not turned into instances.
            if obj.get("iscrowd", 0) != 0:
                continue
            annos.append(
                utils.transform_instance_annotations(
                    obj,
                    transforms,
                    image_shape,
                    keypoint_hflip_indices=self.keypoint_hflip_indices,
                )
            )
        instances = utils.annotations_to_instances(
            annos, image_shape, mask_format=self.instance_mask_format
        )

        # After transforms such as cropping are applied, the bounding box may no longer
        # tightly bound the object. As an example, imagine a triangle object
        # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
        # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
        # the intersection of original bounding box and the cropping box.
        if self.recompute_boxes:
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
        dataset_dict["instances"] = utils.filter_empty_instances(instances)

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        # Work on a deep copy: the record is modified by the code below.
        dataset_dict = copy.deepcopy(dataset_dict)
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)

        # USER: Remove if you don't do semantic/panoptic segmentation.
        sem_seg_gt = None
        if "sem_seg_file_name" in dataset_dict:
            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)

        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
        transforms = self.augmentations(aug_input)
        image = aug_input.image
        sem_seg_gt = aug_input.sem_seg

        image_shape = image.shape[:2]  # h, w
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        chw_image = np.ascontiguousarray(image.transpose(2, 0, 1))
        dataset_dict["image"] = torch.as_tensor(chw_image)
        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

        # USER: Remove if you don't use pre-computed proposals.
        # Most users would not need this feature.
        if self.proposal_topk is not None:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
            )

        if self.is_train:
            if "annotations" in dataset_dict:
                self._transform_annotations(dataset_dict, transforms, image_shape)
        else:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)

        return dataset_dict
cutler/data/datasets/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ from .coco import load_coco_json, load_sem_seg, register_coco_instances, convert_to_coco_json
3
+ from .builtin import (
4
+ register_all_imagenet,
5
+ register_all_uvo,
6
+ register_all_coco_ca,
7
+ register_all_coco_semi,
8
+ register_all_lvis,
9
+ register_all_voc,
10
+ register_all_cross_domain,
11
+ register_all_kitti,
12
+ register_all_objects365,
13
+ register_all_openimages,
14
+ )
15
+
16
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
cutler/data/datasets/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (913 Bytes). View file
 
cutler/data/datasets/__pycache__/builtin.cpython-312.pyc ADDED
Binary file (9.74 kB). View file
 
cutler/data/datasets/__pycache__/builtin_meta.cpython-312.pyc ADDED
Binary file (20 kB). View file
 
cutler/data/datasets/__pycache__/coco.cpython-312.pyc ADDED
Binary file (24.8 kB). View file
 
cutler/data/datasets/builtin.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/datasets/builtin.py
3
+
4
+ """
5
+ This file registers pre-defined datasets at hard-coded paths, and their metadata.
6
+
7
+ We hard-code metadata for common datasets. This will enable:
8
+ 1. Consistency check when loading the datasets
9
+ 2. Use models on these standard datasets directly and run demos,
10
+ without having to download the dataset annotations
11
+
12
+ We hard-code some paths to the dataset that's assumed to
13
+ exist in "./datasets/".
14
+
15
+ Users SHOULD NOT use this file to create new datasets or metadata for new datasets.
16
+ To add a new dataset, refer to the tutorial "docs/DATASETS.md".
17
+ """
18
+
19
+ import os
20
+
21
+ from .builtin_meta import _get_builtin_metadata
22
+ from .coco import register_coco_instances
23
+
24
# ==== Predefined datasets and splits ==========
#
# Every table below has the same two-level layout:
#   {metadata family name: {registered dataset name: (image_root, json_file)}}
# The family name is what the register_all_* functions pass to
# `_get_builtin_metadata`; both paths are joined onto the dataset root there.

_PREDEFINED_SPLITS_COCO_SEMI = {
    "coco_semi": {
        # we use seed 42 to be consistent with previous works on SSL detection and segmentation
        "coco_semi_1perc": ("coco/train2017", "coco/annotations/1perc_instances_train2017.json"),
        "coco_semi_2perc": ("coco/train2017", "coco/annotations/2perc_instances_train2017.json"),
        "coco_semi_5perc": ("coco/train2017", "coco/annotations/5perc_instances_train2017.json"),
        "coco_semi_10perc": ("coco/train2017", "coco/annotations/10perc_instances_train2017.json"),
        "coco_semi_20perc": ("coco/train2017", "coco/annotations/20perc_instances_train2017.json"),
        "coco_semi_30perc": ("coco/train2017", "coco/annotations/30perc_instances_train2017.json"),
        "coco_semi_40perc": ("coco/train2017", "coco/annotations/40perc_instances_train2017.json"),
        "coco_semi_50perc": ("coco/train2017", "coco/annotations/50perc_instances_train2017.json"),
        "coco_semi_60perc": ("coco/train2017", "coco/annotations/60perc_instances_train2017.json"),
        "coco_semi_80perc": ("coco/train2017", "coco/annotations/80perc_instances_train2017.json"),
    },
}

_PREDEFINED_SPLITS_COCO_CA = {
    "coco_cls_agnostic": {
        "cls_agnostic_coco": (
            "coco/val2017",
            "coco/annotations/coco_cls_agnostic_instances_val2017.json",
        ),
        "cls_agnostic_coco20k": (
            "coco/train2014",
            "coco/annotations/coco20k_trainval_gt.json",
        ),
    },
}

_PREDEFINED_SPLITS_IMAGENET = {
    "imagenet": {
        # maskcut annotations
        "imagenet_train": (
            "imagenet/train",
            "imagenet/annotations/imagenet_train_fixsize480_tau0.15_N3.json",
        ),
        # self-training round 1
        "imagenet_train_r1": (
            "imagenet/train",
            "imagenet/annotations/cutler_imagenet1k_train_r1.json",
        ),
        # self-training round 2
        "imagenet_train_r2": (
            "imagenet/train",
            "imagenet/annotations/cutler_imagenet1k_train_r2.json",
        ),
        # self-training round 3
        "imagenet_train_r3": (
            "imagenet/train",
            "imagenet/annotations/cutler_imagenet1k_train_r3.json",
        ),
    },
}

_PREDEFINED_SPLITS_VOC = {
    "voc": {
        "cls_agnostic_voc": (
            "voc/",
            "voc/annotations/trainvaltest_2007_cls_agnostic.json",
        ),
    },
}

_PREDEFINED_SPLITS_CROSSDOMAIN = {
    "cross_domain": {
        "cls_agnostic_clipart": (
            "clipart/",
            "clipart/annotations/traintest_cls_agnostic.json",
        ),
        "cls_agnostic_watercolor": (
            "watercolor/",
            "watercolor/annotations/traintest_cls_agnostic.json",
        ),
        "cls_agnostic_comic": (
            "comic/",
            "comic/annotations/traintest_cls_agnostic.json",
        ),
    },
}

_PREDEFINED_SPLITS_KITTI = {
    "kitti": {
        "cls_agnostic_kitti": (
            "kitti/",
            "kitti/annotations/trainval_cls_agnostic.json",
        ),
    },
}

_PREDEFINED_SPLITS_LVIS = {
    "lvis": {
        "cls_agnostic_lvis": (
            "coco/",
            "coco/annotations/lvis1.0_cocofied_val_cls_agnostic.json",
        ),
    },
}

_PREDEFINED_SPLITS_OBJECTS365 = {
    "objects365": {
        "cls_agnostic_objects365": (
            "objects365/val",
            "objects365/annotations/zhiyuan_objv2_val_cls_agnostic.json",
        ),
    },
}

_PREDEFINED_SPLITS_OpenImages = {
    "openimages": {
        "cls_agnostic_openimages": (
            "openImages/validation",
            "openImages/annotations/openimages_val_cls_agnostic.json",
        ),
    },
}

_PREDEFINED_SPLITS_UVO = {
    "uvo": {
        "cls_agnostic_uvo": (
            "uvo/all_UVO_frames",
            "uvo/annotations/val_sparse_cleaned_cls_agnostic.json",
        ),
    },
}
95
+
96
def _register_predefined_splits(predefined_splits, root):
    """Register every split of a ``_PREDEFINED_SPLITS_*`` table as a COCO-format dataset.

    Args:
        predefined_splits (dict): maps a metadata family name to a dict of
            ``{registered dataset name: (image_root, json_file)}``.
        root (str): directory that ``image_root`` and ``json_file`` are joined onto.
    """
    for dataset_name, splits_per_dataset in predefined_splits.items():
        for key, (image_root, json_file) in splits_per_dataset.items():
            # Annotation paths that already contain a URI scheme ("://") are
            # passed through untouched; everything else is relative to `root`.
            annotation_path = json_file if "://" in json_file else os.path.join(root, json_file)
            register_coco_instances(
                key,
                # Looked up once per split (not once per family) so that every
                # registered dataset receives its own metadata dict.
                _get_builtin_metadata(dataset_name),
                annotation_path,
                os.path.join(root, image_root),
            )


def register_all_imagenet(root):
    """Register the ImageNet splits found under `root`."""
    _register_predefined_splits(_PREDEFINED_SPLITS_IMAGENET, root)


def register_all_voc(root):
    """Register the class-agnostic VOC split found under `root`."""
    _register_predefined_splits(_PREDEFINED_SPLITS_VOC, root)


def register_all_cross_domain(root):
    """Register the clipart / watercolor / comic splits found under `root`."""
    _register_predefined_splits(_PREDEFINED_SPLITS_CROSSDOMAIN, root)


def register_all_kitti(root):
    """Register the class-agnostic KITTI split found under `root`."""
    _register_predefined_splits(_PREDEFINED_SPLITS_KITTI, root)


def register_all_objects365(root):
    """Register the class-agnostic Objects365 split found under `root`."""
    _register_predefined_splits(_PREDEFINED_SPLITS_OBJECTS365, root)


def register_all_openimages(root):
    """Register the class-agnostic OpenImages split found under `root`."""
    _register_predefined_splits(_PREDEFINED_SPLITS_OpenImages, root)


def register_all_lvis(root):
    """Register the class-agnostic LVIS split found under `root`."""
    _register_predefined_splits(_PREDEFINED_SPLITS_LVIS, root)


def register_all_uvo(root):
    """Register the class-agnostic UVO split found under `root`."""
    _register_predefined_splits(_PREDEFINED_SPLITS_UVO, root)


def register_all_coco_semi(root):
    """Register the semi-supervised COCO splits found under `root`."""
    _register_predefined_splits(_PREDEFINED_SPLITS_COCO_SEMI, root)


def register_all_coco_ca(root):
    """Register the class-agnostic COCO splits found under `root`."""
    _register_predefined_splits(_PREDEFINED_SPLITS_COCO_CA, root)
205
+
206
# Register all predefined datasets as an import-time side effect. The dataset
# root is read from the DETECTRON2_DATASETS environment variable and falls
# back to the relative directory "datasets".
_root = os.path.expanduser(os.environ.get("DETECTRON2_DATASETS", "datasets"))
for _register_fn in (
    register_all_coco_semi,
    register_all_coco_ca,
    register_all_imagenet,
    register_all_uvo,
    register_all_voc,
    register_all_cross_domain,
    register_all_kitti,
    register_all_openimages,
    register_all_objects365,
    register_all_lvis,
):
    _register_fn(_root)
del _register_fn  # do not leave the loop variable behind in the module namespace
cutler/data/datasets/builtin_meta.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/datasets/builtin_meta.py
3
+
4
+ """
5
+ Note:
6
+ For your custom dataset, there is no need to hard-code metadata anywhere in the code.
7
+ For example, for COCO-format dataset, metadata will be obtained automatically
8
+ when calling `load_coco_json`. For other dataset, metadata may also be obtained in other ways
9
+ during loading.
10
+
11
+ However, we hard-code metadata for a few common datasets here.
12
+ The only goal is to allow users who don't have these datasets to use pre-trained models.
13
+ Users don't have to download a COCO json (which contains metadata), in order to visualize a
14
+ COCO model (with correct class names and colors).
15
+ """
16
+
17
+
18
+ # All coco categories, together with their nice-looking visualization colors
19
+ # It's from https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json
20
+ COCO_CATEGORIES = [
21
+ {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
22
+ {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
23
+ {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
24
+ {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
25
+ {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
26
+ {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
27
+ {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
28
+ {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
29
+ {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
30
+ {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
31
+ {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
32
+ {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
33
+ {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
34
+ {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
35
+ {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
36
+ {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
37
+ {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
38
+ {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
39
+ {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
40
+ {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
41
+ {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
42
+ {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
43
+ {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
44
+ {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
45
+ {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
46
+ {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
47
+ {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
48
+ {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
49
+ {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
50
+ {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
51
+ {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
52
+ {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
53
+ {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
54
+ {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
55
+ {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
56
+ {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
57
+ {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
58
+ {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
59
+ {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
60
+ {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
61
+ {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
62
+ {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
63
+ {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
64
+ {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
65
+ {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
66
+ {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
67
+ {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
68
+ {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
69
+ {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
70
+ {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
71
+ {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
72
+ {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
73
+ {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
74
+ {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
75
+ {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
76
+ {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
77
+ {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
78
+ {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
79
+ {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
80
+ {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
81
+ {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
82
+ {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
83
+ {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
84
+ {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
85
+ {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
86
+ {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
87
+ {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
88
+ {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
89
+ {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
90
+ {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
91
+ {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
92
+ {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
93
+ {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
94
+ {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
95
+ {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
96
+ {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
97
+ {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
98
+ {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
99
+ {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
100
+ {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
101
+ {"color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"},
102
+ {"color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"},
103
+ {"color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"},
104
+ {"color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"},
105
+ {"color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"},
106
+ {"color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"},
107
+ {"color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"},
108
+ {"color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"},
109
+ {"color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"},
110
+ {"color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"},
111
+ {"color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"},
112
+ {"color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"},
113
+ {"color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"},
114
+ {"color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"},
115
+ {"color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"},
116
+ {"color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"},
117
+ {"color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"},
118
+ {"color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"},
119
+ {"color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"},
120
+ {"color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"},
121
+ {"color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"},
122
+ {"color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"},
123
+ {"color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"},
124
+ {"color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"},
125
+ {"color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"},
126
+ {"color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"},
127
+ {"color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"},
128
+ {"color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"},
129
+ {"color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"},
130
+ {"color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"},
131
+ {"color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"},
132
+ {"color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"},
133
+ {"color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"},
134
+ {"color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"},
135
+ {"color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"},
136
+ {"color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"},
137
+ {"color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"},
138
+ {"color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"},
139
+ {"color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"},
140
+ {"color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"},
141
+ {"color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"},
142
+ {"color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"},
143
+ {"color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"},
144
+ {"color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"},
145
+ {"color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"},
146
+ {"color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"},
147
+ {"color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"},
148
+ {"color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"},
149
+ {"color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"},
150
+ {"color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"},
151
+ {"color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"},
152
+ {"color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"},
153
+ {"color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"},
154
+ ]
155
+
156
+ IMAGENET_CATEGORIES = [
157
+ {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "fg"},
158
+ ]
159
+
160
+ UVO_CATEGORIES = [
161
+ {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "object"},
162
+ ]
163
+
164
+ # fmt: off
165
+ COCO_PERSON_KEYPOINT_NAMES = (
166
+ "nose",
167
+ "left_eye", "right_eye",
168
+ "left_ear", "right_ear",
169
+ "left_shoulder", "right_shoulder",
170
+ "left_elbow", "right_elbow",
171
+ "left_wrist", "right_wrist",
172
+ "left_hip", "right_hip",
173
+ "left_knee", "right_knee",
174
+ "left_ankle", "right_ankle",
175
+ )
176
+ # fmt: on
177
+
178
+ # Pairs of keypoints that should be exchanged under horizontal flipping
179
+ COCO_PERSON_KEYPOINT_FLIP_MAP = (
180
+ ("left_eye", "right_eye"),
181
+ ("left_ear", "right_ear"),
182
+ ("left_shoulder", "right_shoulder"),
183
+ ("left_elbow", "right_elbow"),
184
+ ("left_wrist", "right_wrist"),
185
+ ("left_hip", "right_hip"),
186
+ ("left_knee", "right_knee"),
187
+ ("left_ankle", "right_ankle"),
188
+ )
189
+
190
+ # rules for pairs of keypoints to draw a line between, and the line color to use.
191
+ KEYPOINT_CONNECTION_RULES = [
192
+ # face
193
+ ("left_ear", "left_eye", (102, 204, 255)),
194
+ ("right_ear", "right_eye", (51, 153, 255)),
195
+ ("left_eye", "nose", (102, 0, 204)),
196
+ ("nose", "right_eye", (51, 102, 255)),
197
+ # upper-body
198
+ ("left_shoulder", "right_shoulder", (255, 128, 0)),
199
+ ("left_shoulder", "left_elbow", (153, 255, 204)),
200
+ ("right_shoulder", "right_elbow", (128, 229, 255)),
201
+ ("left_elbow", "left_wrist", (153, 255, 153)),
202
+ ("right_elbow", "right_wrist", (102, 255, 224)),
203
+ # lower-body
204
+ ("left_hip", "right_hip", (255, 102, 0)),
205
+ ("left_hip", "left_knee", (255, 255, 77)),
206
+ ("right_hip", "right_knee", (153, 255, 204)),
207
+ ("left_knee", "left_ankle", (191, 255, 128)),
208
+ ("right_knee", "right_ankle", (255, 195, 77)),
209
+ ]
210
+
211
+ # All Cityscapes categories, together with their nice-looking visualization colors
212
+ # It's from https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py # noqa
213
+ CITYSCAPES_CATEGORIES = [
214
+ {"color": (128, 64, 128), "isthing": 0, "id": 7, "trainId": 0, "name": "road"},
215
+ {"color": (244, 35, 232), "isthing": 0, "id": 8, "trainId": 1, "name": "sidewalk"},
216
+ {"color": (70, 70, 70), "isthing": 0, "id": 11, "trainId": 2, "name": "building"},
217
+ {"color": (102, 102, 156), "isthing": 0, "id": 12, "trainId": 3, "name": "wall"},
218
+ {"color": (190, 153, 153), "isthing": 0, "id": 13, "trainId": 4, "name": "fence"},
219
+ {"color": (153, 153, 153), "isthing": 0, "id": 17, "trainId": 5, "name": "pole"},
220
+ {"color": (250, 170, 30), "isthing": 0, "id": 19, "trainId": 6, "name": "traffic light"},
221
+ {"color": (220, 220, 0), "isthing": 0, "id": 20, "trainId": 7, "name": "traffic sign"},
222
+ {"color": (107, 142, 35), "isthing": 0, "id": 21, "trainId": 8, "name": "vegetation"},
223
+ {"color": (152, 251, 152), "isthing": 0, "id": 22, "trainId": 9, "name": "terrain"},
224
+ {"color": (70, 130, 180), "isthing": 0, "id": 23, "trainId": 10, "name": "sky"},
225
+ {"color": (220, 20, 60), "isthing": 1, "id": 24, "trainId": 11, "name": "person"},
226
+ {"color": (255, 0, 0), "isthing": 1, "id": 25, "trainId": 12, "name": "rider"},
227
+ {"color": (0, 0, 142), "isthing": 1, "id": 26, "trainId": 13, "name": "car"},
228
+ {"color": (0, 0, 70), "isthing": 1, "id": 27, "trainId": 14, "name": "truck"},
229
+ {"color": (0, 60, 100), "isthing": 1, "id": 28, "trainId": 15, "name": "bus"},
230
+ {"color": (0, 80, 100), "isthing": 1, "id": 31, "trainId": 16, "name": "train"},
231
+ {"color": (0, 0, 230), "isthing": 1, "id": 32, "trainId": 17, "name": "motorcycle"},
232
+ {"color": (119, 11, 32), "isthing": 1, "id": 33, "trainId": 18, "name": "bicycle"},
233
+ ]
234
+
235
+ # fmt: off
236
+ ADE20K_SEM_SEG_CATEGORIES = [
237
+ "wall", "building", "sky", "floor", "tree", "ceiling", "road, route", "bed", "window ", "grass", "cabinet", "sidewalk, pavement", "person", "earth, ground", "door", "table", "mountain, mount", "plant", "curtain", "chair", "car", "water", "painting, picture", "sofa", "shelf", "house", "sea", "mirror", "rug", "field", "armchair", "seat", "fence", "desk", "rock, stone", "wardrobe, closet, press", "lamp", "tub", "rail", "cushion", "base, pedestal, stand", "box", "column, pillar", "signboard, sign", "chest of drawers, chest, bureau, dresser", "counter", "sand", "sink", "skyscraper", "fireplace", "refrigerator, icebox", "grandstand, covered stand", "path", "stairs", "runway", "case, display case, showcase, vitrine", "pool table, billiard table, snooker table", "pillow", "screen door, screen", "stairway, staircase", "river", "bridge, span", "bookcase", "blind, screen", "coffee table", "toilet, can, commode, crapper, pot, potty, stool, throne", "flower", "book", "hill", "bench", "countertop", "stove", "palm, palm tree", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", "hovel, hut, hutch, shack, shanty", "bus", "towel", "light", "truck", "tower", "chandelier", "awning, sunshade, sunblind", "street lamp", "booth", "tv", "plane", "dirt track", "clothes", "pole", "land, ground, soil", "bannister, banister, balustrade, balusters, handrail", "escalator, moving staircase, moving stairway", "ottoman, pouf, pouffe, puff, hassock", "bottle", "buffet, counter, sideboard", "poster, posting, placard, notice, bill, card", "stage", "van", "ship", "fountain", "conveyer belt, conveyor belt, conveyer, conveyor, transporter", "canopy", "washer, automatic washer, washing machine", "plaything, toy", "pool", "stool", "barrel, cask", "basket, handbasket", "falls", "tent", "bag", "minibike, motorbike", "cradle", "oven", "ball", "food, solid food", "step, stair", "tank, storage tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", 
"screen", "blanket, cover", "sculpture", "hood, exhaust hood", "sconce", "vase", "traffic light", "tray", "trash can", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", "glass, drinking glass", "clock", "flag", # noqa
238
+ ]
239
+ # After processed by `prepare_ade20k_sem_seg.py`, id 255 means ignore
240
+ # fmt: on
241
+
242
+
243
def _get_coco_instances_meta():
    """Build instance metadata from the "thing" entries of ``COCO_CATEGORIES``.

    Returns:
        dict: with keys ``thing_dataset_id_to_contiguous_id``, ``thing_classes``
        and ``thing_colors``, all ordered like ``COCO_CATEGORIES``.
    """
    things = [cat for cat in COCO_CATEGORIES if cat["isthing"] == 1]
    dataset_ids = [cat["id"] for cat in things]
    colors = [cat["color"] for cat in things]
    assert len(dataset_ids) == 80, len(dataset_ids)

    # COCO category ids have gaps; remap them onto the contiguous range [0, 79].
    id_remap = {dataset_id: idx for idx, dataset_id in enumerate(dataset_ids)}
    names = [cat["name"] for cat in things]

    return {
        "thing_dataset_id_to_contiguous_id": id_remap,
        "thing_classes": names,
        "thing_colors": colors,
    }
256
+
257
def _get_imagenet_instances_meta():
    """Build instance metadata from the "thing" entries of ``IMAGENET_CATEGORIES``.

    Returns:
        dict: with keys ``thing_dataset_id_to_contiguous_id``, ``thing_classes``,
        ``thing_colors`` and a hard-coded ``class_image_count`` list.
    """
    things = [cat for cat in IMAGENET_CATEGORIES if cat["isthing"] == 1]
    dataset_ids = [cat["id"] for cat in things]
    colors = [cat["color"] for cat in things]
    # Exactly one thing category is expected in the table.
    assert len(dataset_ids) == 1, len(dataset_ids)

    id_remap = {dataset_id: idx for idx, dataset_id in enumerate(dataset_ids)}
    names = [cat["name"] for cat in things]

    return {
        "thing_dataset_id_to_contiguous_id": id_remap,
        "thing_classes": names,
        "thing_colors": colors,
        "class_image_count": [{'id': 1, 'image_count': 116986}],
    }
270
+
271
def _get_UVO_instances_meta():
    """Build instance metadata from the "thing" entries of ``UVO_CATEGORIES``.

    Returns:
        dict: with keys ``thing_dataset_id_to_contiguous_id``, ``thing_classes``,
        ``thing_colors`` and a hard-coded ``class_image_count`` list.
    """
    things = [cat for cat in UVO_CATEGORIES if cat["isthing"] == 1]
    dataset_ids = [cat["id"] for cat in things]
    colors = [cat["color"] for cat in things]
    # Exactly one thing category is expected in the table.
    assert len(dataset_ids) == 1, len(dataset_ids)

    id_remap = {dataset_id: idx for idx, dataset_id in enumerate(dataset_ids)}
    names = [cat["name"] for cat in things]

    return {
        "thing_dataset_id_to_contiguous_id": id_remap,
        "thing_classes": names,
        "thing_colors": colors,
        # NOTE(review): same literal as in _get_imagenet_instances_meta --
        # confirm this count is really intended for UVO.
        "class_image_count": [{'id': 1, 'image_count': 116986}],
    }
284
+
285
def _get_coco_panoptic_separated_meta():
    """
    Returns metadata for "separated" version of the panoptic segmentation dataset.

    The result holds the ``stuff_*`` keys built here followed by the ``thing_*``
    keys returned by `_get_coco_instances_meta`.
    """
    stuff_cats = [cat for cat in COCO_CATEGORIES if cat["isthing"] == 0]
    stuff_ids = [cat["id"] for cat in stuff_cats]
    assert len(stuff_ids) == 53, len(stuff_ids)

    # Map each dataset stuff id to a contiguous id in [1, 53]; contiguous id 0
    # is reserved for the extra "things" category, which dataset id 0 maps to.
    id_remap = {dataset_id: idx for idx, dataset_id in enumerate(stuff_ids, start=1)}
    id_remap[0] = 0

    # 54 names for COCO stuff categories (including "things"), with the
    # "-other" / "-merged" suffixes stripped.
    names = ["things"]
    for cat in stuff_cats:
        names.append(cat["name"].replace("-other", "").replace("-merged", ""))

    # The first color belongs to "things"; the original author picked it arbitrarily.
    colors = [[82, 18, 128]]
    colors += [cat["color"] for cat in stuff_cats]

    meta = {
        "stuff_dataset_id_to_contiguous_id": id_remap,
        "stuff_classes": names,
        "stuff_colors": colors,
    }
    meta.update(_get_coco_instances_meta())
    return meta
316
+
317
+
318
+ def _get_builtin_metadata(dataset_name):
319
+ if dataset_name in ["coco", "coco_semi"]:
320
+ return _get_coco_instances_meta()
321
+ if dataset_name == "coco_panoptic_separated":
322
+ return _get_coco_panoptic_separated_meta()
323
+ elif dataset_name in ["imagenet", "kitti", "cross_domain", "lvis", "voc", "coco_cls_agnostic", "objects365", 'openimages']:
324
+ return _get_imagenet_instances_meta()
325
+ elif dataset_name == "uvo":
326
+ return _get_UVO_instances_meta()
327
+ elif dataset_name == "coco_panoptic_standard":
328
+ meta = {}
329
+ # The following metadata maps contiguous id from [0, #thing categories +
330
+ # #stuff categories) to their names and colors. We have to replica of the
331
+ # same name and color under "thing_*" and "stuff_*" because the current
332
+ # visualization function in D2 handles thing and class classes differently
333
+ # due to some heuristic used in Panoptic FPN. We keep the same naming to
334
+ # enable reusing existing visualization functions.
335
+ thing_classes = [k["name"] for k in COCO_CATEGORIES]
336
+ thing_colors = [k["color"] for k in COCO_CATEGORIES]
337
+ stuff_classes = [k["name"] for k in COCO_CATEGORIES]
338
+ stuff_colors = [k["color"] for k in COCO_CATEGORIES]
339
+
340
+ meta["thing_classes"] = thing_classes
341
+ meta["thing_colors"] = thing_colors
342
+ meta["stuff_classes"] = stuff_classes
343
+ meta["stuff_colors"] = stuff_colors
344
+
345
+ # Convert category id for training:
346
+ # category id: like semantic segmentation, it is the class id for each
347
+ # pixel. Since there are some classes not used in evaluation, the category
348
+ # id is not always contiguous and thus we have two set of category ids:
349
+ # - original category id: category id in the original dataset, mainly
350
+ # used for evaluation.
351
+ # - contiguous category id: [0, #classes), in order to train the linear
352
+ # softmax classifier.
353
+ thing_dataset_id_to_contiguous_id = {}
354
+ stuff_dataset_id_to_contiguous_id = {}
355
+
356
+ for i, cat in enumerate(COCO_CATEGORIES):
357
+ if cat["isthing"]:
358
+ thing_dataset_id_to_contiguous_id[cat["id"]] = i
359
+ else:
360
+ stuff_dataset_id_to_contiguous_id[cat["id"]] = i
361
+
362
+ meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
363
+ meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
364
+
365
+ return meta
366
+ elif dataset_name == "coco_person":
367
+ return {
368
+ "thing_classes": ["person"],
369
+ "keypoint_names": COCO_PERSON_KEYPOINT_NAMES,
370
+ "keypoint_flip_map": COCO_PERSON_KEYPOINT_FLIP_MAP,
371
+ "keypoint_connection_rules": KEYPOINT_CONNECTION_RULES,
372
+ }
373
+ elif dataset_name == "cityscapes":
374
+ # fmt: off
375
+ CITYSCAPES_THING_CLASSES = [
376
+ "person", "rider", "car", "truck",
377
+ "bus", "train", "motorcycle", "bicycle",
378
+ ]
379
+ CITYSCAPES_STUFF_CLASSES = [
380
+ "road", "sidewalk", "building", "wall", "fence", "pole", "traffic light",
381
+ "traffic sign", "vegetation", "terrain", "sky", "person", "rider", "car",
382
+ "truck", "bus", "train", "motorcycle", "bicycle",
383
+ ]
384
+ # fmt: on
385
+ return {
386
+ "thing_classes": CITYSCAPES_THING_CLASSES,
387
+ "stuff_classes": CITYSCAPES_STUFF_CLASSES,
388
+ }
389
+ raise KeyError("No built-in metadata for dataset {}".format(dataset_name))
cutler/data/datasets/coco.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/datasets/coco.py
3
+
4
+ import contextlib
5
+ import datetime
6
+ import io
7
+ import json
8
+ import logging
9
+ import numpy as np
10
+ import os
11
+ import shutil
12
+ import pycocotools.mask as mask_util
13
+ from fvcore.common.timer import Timer
14
+ from iopath.common.file_io import file_lock
15
+ from PIL import Image
16
+
17
+ from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes
18
+ from detectron2.utils.file_io import PathManager
19
+
20
+ from detectron2.data import DatasetCatalog, MetadataCatalog
21
+
22
+ """
23
+ This file contains functions to parse COCO-format annotations into dicts in "Detectron2 format".
24
+ """
25
+
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ __all__ = ["load_coco_json", "load_sem_seg", "convert_to_coco_json", "register_coco_instances"]
30
+
31
+
32
def load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
    """
    Load a json file with COCO's instances annotation format.
    Currently supports instance detection, instance segmentation,
    and person keypoints annotations.

    Args:
        json_file (str): full path to the json file in COCO instances annotation format.
        image_root (str or path-like): the directory where the images in this json file exist.
        dataset_name (str or None): the name of the dataset (e.g., coco_2017_train).
            When provided, this function will also do the following:

            * Put "thing_classes" into the metadata associated with this dataset
              (skipped when the name contains "imagenet" or "cls_agnostic").
            * Map the category ids into a contiguous range (needed by standard dataset format),
              and add "thing_dataset_id_to_contiguous_id" to the metadata associated
              with this dataset when it has to be computed here.

            This option should usually be provided, unless users need to load
            the original json content and apply more processing manually.
        extra_annotation_keys (list[str]): list of per-annotation keys that should also be
            loaded into the dataset dict (besides "iscrowd", "bbox", "keypoints",
            "category_id", "segmentation"). The values for these keys will be returned as-is.
            For example, the densepose annotations are loaded in this way.

    Returns:
        list[dict]: a list of dicts in Detectron2 standard dataset dicts format (See
        `Using Custom Datasets </tutorials/datasets.html>`_ ) when `dataset_name` is not None.
        If `dataset_name` is None, the returned `category_ids` may be
        incontiguous and may not conform to the Detectron2 standard format.

    Raises:
        ValueError: if an annotation carries an empty "bbox".
        KeyError: if an annotation uses a category id missing from the id mapping.

    Notes:
        1. This function does not read the image files.
           The results do not have the "image" field.
    """
    from pycocotools.coco import COCO

    timer = Timer()
    json_file = PathManager.get_local_path(json_file)
    # The COCO constructor prints to stdout; keep it quiet.
    with contextlib.redirect_stdout(io.StringIO()):
        coco_api = COCO(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))

    id_map = None
    if dataset_name is not None:
        meta = MetadataCatalog.get(dataset_name)
        cat_ids = sorted(coco_api.getCatIds())
        cats = coco_api.loadCats(cat_ids)
        # The categories in a custom json file may not be sorted.
        thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
        if "imagenet" not in dataset_name and "cls_agnostic" not in dataset_name:
            meta.thing_classes = thing_classes

        # In COCO, certain category ids are artificially removed,
        # and by convention they are always ignored.
        # We deal with COCO's id issue and translate
        # the category ids to contiguous ids in [0, 80).

        # It works by looking at the "categories" field in the json, therefore
        # if users' own json also have incontiguous ids, we'll
        # apply this mapping as well but print a warning.
        #
        # NOTE(review): indentation was lost in the extracted source; the original
        # `else` is assumed to pair with the contiguity check below. Confirm
        # against the repository before merging.
        ids_are_contiguous = min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)
        preset_id_map = meta.get("thing_dataset_id_to_contiguous_id", None)
        if ids_are_contiguous and preset_id_map is not None:
            # Reuse the mapping that was registered together with the dataset.
            id_map = preset_id_map
        else:
            # Either the ids have holes, or no mapping was pre-registered (the
            # latter used to fail with AttributeError): derive it from the json.
            if not ids_are_contiguous and "coco" not in dataset_name:
                logger.warning(
                    "\nCategory ids in annotations are not in [1, #categories]! "
                    "We'll apply a mapping for you.\n"
                )
            id_map = {v: i for i, v in enumerate(cat_ids)}
            meta.thing_dataset_id_to_contiguous_id = id_map

    # sort indices for reproducible results
    img_ids = sorted(coco_api.imgs.keys())
    # imgs is a list of dicts, each looks something like:
    # {'license': 4,
    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
    #  'file_name': 'COCO_val2014_000000001268.jpg',
    #  'height': 427,
    #  'width': 640,
    #  'date_captured': '2013-11-17 05:57:24',
    #  'id': 1268}
    imgs = coco_api.loadImgs(img_ids)
    # anns is a list[list[dict]], where each dict is an annotation
    # record for an object. The inner list enumerates the objects in an image
    # and the outer list enumerates over images. Example of anns[0]:
    # [{'segmentation': [[192.81,
    #     247.09,
    #     ...
    #     219.03,
    #     249.06]],
    #   'area': 1035.749,
    #   'iscrowd': 0,
    #   'image_id': 1268,
    #   'bbox': [192.81, 224.8, 74.73, 33.43],
    #   'category_id': 16,
    #   'id': 42986},
    #  ...]
    anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
    total_num_valid_anns = sum(len(x) for x in anns)
    total_num_anns = len(coco_api.anns)
    if total_num_valid_anns < total_num_anns:
        logger.warning(
            f"{json_file} contains {total_num_anns} annotations, but only "
            f"{total_num_valid_anns} of them match to images in the file."
        )

    if "minival" not in json_file:
        # The popular valminusminival & minival annotations for COCO2014 contain this bug.
        # However the ratio of buggy annotations there is tiny and does not affect accuracy.
        # Therefore we explicitly white-list them.
        ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
        assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
            json_file
        )

    imgs_anns = list(zip(imgs, anns))
    logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file))

    dataset_dicts = []

    ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"] + (extra_annotation_keys or [])

    num_instances_without_valid_segmentation = 0

    for (img_dict, anno_dict_list) in imgs_anns:
        record = {}
        record["file_name"] = os.path.join(image_root, img_dict["file_name"])
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        image_id = record["image_id"] = img_dict["id"]

        objs = []
        for anno in anno_dict_list:
            # Check that the image_id in this annotation is the same as
            # the image_id we're looking at.
            # This fails only when the data parsing logic or the annotation file is buggy.

            # The original COCO valminusminival2014 & minival2014 annotation files
            # actually contains bugs that, together with certain ways of using COCO API,
            # can trigger this assertion.
            assert anno["image_id"] == image_id

            assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.'

            obj = {key: anno[key] for key in ann_keys if key in anno}
            if "bbox" in obj and len(obj["bbox"]) == 0:
                raise ValueError(
                    f"One annotation of image {image_id} contains empty 'bbox' value! "
                    "This json does not have valid COCO format."
                )

            segm = anno.get("segmentation", None)
            if segm:  # either list[list[float]] or dict(RLE)
                if isinstance(segm, dict):
                    if isinstance(segm["counts"], list):
                        # convert to compressed RLE
                        segm = mask_util.frPyObjects(segm, *segm["size"])
                else:
                    # filter out invalid polygons (< 3 points)
                    segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
                    if len(segm) == 0:
                        num_instances_without_valid_segmentation += 1
                        continue  # ignore this instance
                obj["segmentation"] = segm

            keypts = anno.get("keypoints", None)
            if keypts:  # list[int]
                # COCO's segmentation coordinates are floating points in [0, H or W],
                # but keypoint coordinates are integers in [0, H-1 or W-1]
                # Therefore we assume the coordinates are "pixel indices" and
                # add 0.5 to convert to floating point coordinates.
                # Every third value is left untouched. A new list is built so the
                # annotation held by the COCO API object is not modified.
                obj["keypoints"] = [
                    v if idx % 3 == 2 else v + 0.5 for idx, v in enumerate(keypts)
                ]

            obj["bbox_mode"] = BoxMode.XYWH_ABS
            if id_map:
                annotation_category_id = obj["category_id"]
                try:
                    obj["category_id"] = id_map[annotation_category_id]
                except KeyError as e:
                    raise KeyError(
                        f"Encountered category_id={annotation_category_id} "
                        "but this id does not exist in 'categories' of the json file."
                    ) from e
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)

    if num_instances_without_valid_segmentation > 0:
        logger.warning(
            "Filtered out {} instances without valid segmentation. ".format(
                num_instances_without_valid_segmentation
            )
            + "There might be issues in your dataset generation process.  Please "
            "check https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html carefully"
        )
    return dataset_dicts
233
+
234
+
235
def load_sem_seg(gt_root, image_root, gt_ext="png", image_ext="jpg"):
    """
    Load semantic segmentation datasets. All files under "gt_root" with "gt_ext" extension are
    treated as ground truth annotations and all files under "image_root" with "image_ext" extension
    as input images. Ground truth and input images are matched using file paths relative to
    "gt_root" and "image_root" respectively without taking into account file extensions.
    This works for COCO as well as some other datasets.

    Args:
        gt_root (str): full path to ground truth semantic segmentation files. Semantic segmentation
            annotations are stored as images with integer values in pixels that represent
            corresponding semantic labels.
        image_root (str): the directory where the input images are.
        gt_ext (str): file extension for ground truth annotations.
        image_ext (str): file extension for input images.

    Returns:
        list[dict]:
            a list of dicts in detectron2 standard format without instance-level
            annotation.

    Notes:
        1. This function does not read the image and ground truth files.
           The results do not have the "image" and "sem_seg" fields.
    """

    # We match input images with ground truth based on their relative filepaths (without file
    # extensions) starting from 'image_root' and 'gt_root' respectively.
    def file2id(folder_path, file_path):
        # extract relative path starting from `folder_path`
        image_id = os.path.normpath(os.path.relpath(file_path, start=folder_path))
        # remove file extension
        image_id = os.path.splitext(image_id)[0]
        return image_id

    input_files = sorted(
        (os.path.join(image_root, f) for f in PathManager.ls(image_root) if f.endswith(image_ext)),
        key=lambda file_path: file2id(image_root, file_path),
    )
    gt_files = sorted(
        (os.path.join(gt_root, f) for f in PathManager.ls(gt_root) if f.endswith(gt_ext)),
        key=lambda file_path: file2id(gt_root, file_path),
    )

    assert len(gt_files) > 0, "No annotations found in {}.".format(gt_root)

    # Use the intersection, so that val2017_100 annotations can run smoothly with val2017 images
    if len(input_files) != len(gt_files):
        # `Logger.warn` is a deprecated alias; use `Logger.warning`.
        logger.warning(
            "Directory {} and {} has {} and {} files, respectively.".format(
                image_root, gt_root, len(input_files), len(gt_files)
            )
        )
        # Only the extension is cut off here; whatever precedes it (including the
        # separating dot) stays, so `name + ext` below rebuilds the file name.
        input_basenames = [os.path.basename(f)[: -len(image_ext)] for f in input_files]
        gt_basenames = [os.path.basename(f)[: -len(gt_ext)] for f in gt_files]
        # sort, otherwise each worker may obtain a list[dict] in different order
        intersect = sorted(set(input_basenames) & set(gt_basenames))
        logger.warning("Will use their intersection of {} files.".format(len(intersect)))
        input_files = [os.path.join(image_root, f + image_ext) for f in intersect]
        gt_files = [os.path.join(gt_root, f + gt_ext) for f in intersect]

    logger.info(
        "Loaded {} images with semantic segmentation from {}".format(len(input_files), image_root)
    )

    dataset_dicts = []
    for (img_path, gt_path) in zip(input_files, gt_files):
        record = {}
        record["file_name"] = img_path
        record["sem_seg_file_name"] = gt_path
        dataset_dicts.append(record)

    return dataset_dicts
309
+
310
+
311
def convert_to_coco_dict(dataset_name):
    """
    Convert an instance detection/segmentation or keypoint detection dataset
    in detectron2's standard format into COCO json format.

    Generic dataset description can be found here:
    https://detectron2.readthedocs.io/tutorials/datasets.html#register-a-dataset

    COCO data format description can be found here:
    http://cocodataset.org/#format-data

    The dataset dicts obtained from ``DatasetCatalog`` are not modified: shifted
    keypoints and decoded RLE counts are written to fresh objects.

    Args:
        dataset_name (str):
            name of the source dataset
            Must be registered in DatasetCatalog and in detectron2's standard format.
            Must have corresponding metadata "thing_classes"
    Returns:
        coco_dict: serializable dict in COCO json format
    Raises:
        ValueError: if a bbox is not 1-dimensional or does not have length 4 or 5.
        TypeError: if a segmentation is neither a list (polygons) nor a dict (RLE).
    """

    dataset_dicts = DatasetCatalog.get(dataset_name)
    metadata = MetadataCatalog.get(dataset_name)

    # unmap the category mapping ids for COCO
    if hasattr(metadata, "thing_dataset_id_to_contiguous_id"):
        reverse_id_mapping = {v: k for k, v in metadata.thing_dataset_id_to_contiguous_id.items()}
        reverse_id_mapper = lambda contiguous_id: reverse_id_mapping[contiguous_id]  # noqa
    else:
        reverse_id_mapper = lambda contiguous_id: contiguous_id  # noqa

    categories = [
        {"id": reverse_id_mapper(contiguous_id), "name": name}
        for contiguous_id, name in enumerate(metadata.thing_classes)
    ]

    logger.info("Converting dataset dicts into COCO format")
    coco_images = []
    coco_annotations = []

    for image_id, image_dict in enumerate(dataset_dicts):
        coco_image = {
            "id": image_dict.get("image_id", image_id),
            "width": int(image_dict["width"]),
            "height": int(image_dict["height"]),
            "file_name": str(image_dict["file_name"]),
        }
        coco_images.append(coco_image)

        anns_per_image = image_dict.get("annotations", [])
        for annotation in anns_per_image:
            # create a new dict with only COCO fields
            coco_annotation = {}

            # COCO requirement: XYWH box format for axis-align and XYWHA for rotated
            bbox = annotation["bbox"]
            if isinstance(bbox, np.ndarray):
                if bbox.ndim != 1:
                    raise ValueError(f"bbox has to be 1-dimensional. Got shape={bbox.shape}.")
                bbox = bbox.tolist()
            if len(bbox) not in [4, 5]:
                raise ValueError(f"bbox has to has length 4 or 5. Got {bbox}.")
            from_bbox_mode = annotation["bbox_mode"]
            to_bbox_mode = BoxMode.XYWH_ABS if len(bbox) == 4 else BoxMode.XYWHA_ABS
            bbox = BoxMode.convert(bbox, from_bbox_mode, to_bbox_mode)

            # COCO requirement: instance area
            if "segmentation" in annotation:
                # Computing areas for instances by counting the pixels
                segmentation = annotation["segmentation"]
                # TODO: check segmentation type: RLE, BinaryMask or Polygon
                if isinstance(segmentation, list):
                    polygons = PolygonMasks([segmentation])
                    area = polygons.area()[0].item()
                elif isinstance(segmentation, dict):  # RLE
                    area = mask_util.area(segmentation).item()
                else:
                    raise TypeError(f"Unknown segmentation type {type(segmentation)}!")
            else:
                # Computing areas using bounding boxes
                if to_bbox_mode == BoxMode.XYWH_ABS:
                    bbox_xy = BoxMode.convert(bbox, to_bbox_mode, BoxMode.XYXY_ABS)
                    area = Boxes([bbox_xy]).area()[0].item()
                else:
                    area = RotatedBoxes([bbox]).area()[0].item()

            if "keypoints" in annotation:
                # COCO's segmentation coordinates are floating points in [0, H or W],
                # but keypoint coordinates are integers in [0, H-1 or W-1]
                # For COCO format consistency we substract 0.5
                # https://github.com/facebookresearch/detectron2/pull/175#issuecomment-551202163
                # Every third value is kept as-is. A new list is built so that the
                # registered dataset is not shifted again on every conversion.
                keypoints = [
                    v if idx % 3 == 2 else v - 0.5
                    for idx, v in enumerate(annotation["keypoints"])
                ]
                if "num_keypoints" in annotation:
                    num_keypoints = annotation["num_keypoints"]
                else:
                    num_keypoints = sum(kp > 0 for kp in keypoints[2::3])

            # COCO requirement:
            #   linking annotations to images
            #   "id" field must start with 1
            coco_annotation["id"] = len(coco_annotations) + 1
            coco_annotation["image_id"] = coco_image["id"]
            coco_annotation["bbox"] = [round(float(x), 3) for x in bbox]
            coco_annotation["area"] = float(area)
            coco_annotation["iscrowd"] = int(annotation.get("iscrowd", 0))
            coco_annotation["category_id"] = int(reverse_id_mapper(annotation["category_id"]))

            # Add optional fields
            if "keypoints" in annotation:
                coco_annotation["keypoints"] = keypoints
                coco_annotation["num_keypoints"] = num_keypoints

            if "segmentation" in annotation:
                seg = annotation["segmentation"]
                if isinstance(seg, dict):  # RLE
                    counts = seg["counts"]
                    if not isinstance(counts, str):
                        # make it json-serializable, on a shallow copy so the
                        # source annotation keeps its original counts
                        seg = dict(seg, counts=counts.decode("ascii"))
                coco_annotation["segmentation"] = seg

            coco_annotations.append(coco_annotation)

    logger.info(
        "Conversion finished, "
        f"#images: {len(coco_images)}, #annotations: {len(coco_annotations)}"
    )

    info = {
        "date_created": str(datetime.datetime.now()),
        "description": "Automatically generated COCO json file for Detectron2.",
    }
    coco_dict = {"info": info, "images": coco_images, "categories": categories, "licenses": None}
    # "annotations" is only present when at least one annotation was converted.
    if len(coco_annotations) > 0:
        coco_dict["annotations"] = coco_annotations
    return coco_dict
448
+
449
+
450
def convert_to_coco_json(dataset_name, output_file, allow_cached=True):
    """
    Converts dataset into COCO format and saves it to a json file.
    dataset_name must be registered in DatasetCatalog and in detectron2's standard format.

    Args:
        dataset_name:
            reference from the config file to the catalogs
            must be registered in DatasetCatalog and in detectron2's standard format
        output_file: path of json file that will be saved to
        allow_cached: if json file is already present then skip conversion
    """

    # TODO: The dataset or the conversion script *may* change,
    # a checksum would be useful for validating the cached data

    # A bare file name has an empty dirname; creating "" fails, and the current
    # directory already exists, so only create the directory when there is one.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        PathManager.mkdirs(output_dir)
    with file_lock(output_file):
        if PathManager.exists(output_file) and allow_cached:
            logger.warning(
                f"Using previously cached COCO format annotations at '{output_file}'. "
                "You need to clear the cache file if your dataset has been modified."
            )
        else:
            logger.info(f"Converting annotations of dataset '{dataset_name}' to COCO format ...)")
            coco_dict = convert_to_coco_dict(dataset_name)

            logger.info(f"Caching COCO format annotations at '{output_file}' ...")
            # Write to a temporary file first and move it into place afterwards,
            # so the final path only appears once the dump has finished.
            tmp_file = output_file + ".tmp"
            with PathManager.open(tmp_file, "w") as f:
                json.dump(coco_dict, f)
            shutil.move(tmp_file, output_file)
482
+
483
+
484
def register_coco_instances(name, metadata, json_file, image_root):
    """
    Register a COCO-json-annotated dataset for instance detection,
    instance segmentation and keypoint detection
    (i.e., Type 1 and 2 in http://cocodataset.org/#format-data:
    `instances*.json` and `person_keypoints*.json` in the dataset).

    This doubles as an example of how to register a new dataset; similar
    functions can be written for other formats.

    Args:
        name (str): the name that identifies a dataset, e.g. "coco_2014_train".
        metadata (dict): extra metadata associated with this dataset. You can
            leave it as an empty dict.
        json_file (str): path to the json instance annotation file.
        image_root (str or path-like): directory which contains all the images.
    """
    path_types = (str, os.PathLike)
    assert isinstance(name, str), name
    assert isinstance(json_file, path_types), json_file
    assert isinstance(image_root, path_types), image_root

    # Step 1: register a loader; the json is only parsed when the dataset is requested.
    def _load_dataset():
        return load_coco_json(json_file, image_root, name)

    DatasetCatalog.register(name, _load_dataset)

    # Step 2: attach metadata, which may be useful for evaluation,
    # visualization or logging.
    dataset_meta = MetadataCatalog.get(name)
    dataset_meta.set(
        json_file=json_file,
        image_root=image_root,
        evaluator_type="coco",
        **metadata,
    )
512
+
513
+
514
if __name__ == "__main__":
    # Test the COCO json dataset loader.
    #
    # Usage:
    #     python -m detectron2.data.datasets.coco path/to/json path/to/image_root dataset_name
    #
    # "dataset_name" can be "coco_2014_minival_100", or other
    # pre-registered ones
    from detectron2.utils.logger import setup_logger
    from detectron2.utils.visualizer import Visualizer
    import detectron2.data.datasets  # noqa # add pre-defined metadata
    import sys

    logger = setup_logger(name=__name__)
    # The dataset name (third argument) has to be registered already.
    assert sys.argv[3] in DatasetCatalog.list()
    meta = MetadataCatalog.get(sys.argv[3])

    dicts = load_coco_json(sys.argv[1], sys.argv[2], sys.argv[3])
    logger.info("Done loading {} samples.".format(len(dicts)))

    # Draw every loaded record and save it under its original file name.
    dirname = "coco-data-vis"
    os.makedirs(dirname, exist_ok=True)
    for d in dicts:
        pixels = np.array(Image.open(d["file_name"]))
        drawn = Visualizer(pixels, metadata=meta).draw_dataset_dict(d)
        target_path = os.path.join(dirname, os.path.basename(d["file_name"]))
        drawn.save(target_path)
cutler/data/detection_utils.py ADDED
@@ -0,0 +1,650 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/detection_utils.py
3
+
4
+ """
5
+ Common data processing utilities that are used in a
6
+ typical object detection data pipeline.
7
+ """
8
+ import logging
9
+ import numpy as np
10
+ from typing import List, Union
11
+ import pycocotools.mask as mask_util
12
+ import torch
13
+ from PIL import Image
14
+
15
+ from detectron2.structures import (
16
+ Boxes,
17
+ BoxMode,
18
+ BitMasks,
19
+ Instances,
20
+ Keypoints,
21
+ PolygonMasks,
22
+ RotatedBoxes,
23
+ polygons_to_bitmask,
24
+ )
25
+
26
+ from detectron2.utils.file_io import PathManager
27
+
28
+ from data import transforms as T
29
+ from detectron2.data.catalog import MetadataCatalog
30
+
31
+ __all__ = [
32
+ "SizeMismatchError",
33
+ "convert_image_to_rgb",
34
+ "check_image_size",
35
+ "transform_proposals",
36
+ "transform_instance_annotations",
37
+ "annotations_to_instances",
38
+ "annotations_to_instances_rotated",
39
+ "build_augmentation",
40
+ "build_transform_gen",
41
+ "create_keypoint_hflip_indices",
42
+ "filter_empty_instances",
43
+ "read_image",
44
+ ]
45
+
46
+
47
class SizeMismatchError(ValueError):
    """
    Error for a loaded image whose width/height differs from the size
    given in its annotation.
    """
51
+
52
+
53
+ # https://en.wikipedia.org/wiki/YUV#SDTV_with_BT.601
54
+ _M_RGB2YUV = [[0.299, 0.587, 0.114], [-0.14713, -0.28886, 0.436], [0.615, -0.51499, -0.10001]]
55
+ _M_YUV2RGB = [[1.0, 0.0, 1.13983], [1.0, -0.39465, -0.58060], [1.0, 2.03211, 0.0]]
56
+
57
+ # https://www.exiv2.org/tags.html
58
+ _EXIF_ORIENT = 274 # exif 'Orientation' tag
59
+
60
+
61
def convert_PIL_to_numpy(image, format):
    """
    Turn a PIL image into a numpy array laid out in the requested format.

    Args:
        image (PIL.Image): a PIL image
        format (str): the format of output image; ``None`` skips the mode conversion.

    Returns:
        (np.ndarray): also see `read_image`
    """
    if format is not None:
        # "BGR" and "YUV-BT.601" are not PIL modes: go through RGB first and
        # finish the conversion on the array below.
        pil_mode = "RGB" if format in ("BGR", "YUV-BT.601") else format
        image = image.convert(pil_mode)

    pixels = np.asarray(image)

    if format == "L":
        # PIL squeezes out the channel dimension for "L", so make it HWC
        return np.expand_dims(pixels, -1)
    if format == "BGR":
        # reverse the channel axis: RGB -> BGR
        return pixels[:, :, ::-1]
    if format == "YUV-BT.601":
        scaled = pixels / 255.0
        return np.dot(scaled, np.array(_M_RGB2YUV).T)
    return pixels
92
+
93
+
94
def convert_image_to_rgb(image, format):
    """
    Bring an image stored in `format` back to RGB.

    Args:
        image (np.ndarray or Tensor): an HWC image
        format (str): the format of input image, also see `read_image`

    Returns:
        (np.ndarray): (H,W,3) RGB image in 0-255 range, can be either float or uint8
    """
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()

    if format == "BGR":
        # Reorder channels B,G,R -> R,G,B.
        return image[:, :, [2, 1, 0]]

    if format == "YUV-BT.601":
        rgb = np.dot(image, np.array(_M_YUV2RGB).T)
        return rgb * 255.0

    # Every other format is handed to PIL as an image mode.
    if format == "L":
        image = image[:, :, 0]
    pil_image = Image.fromarray(image.astype(np.uint8), mode=format)
    return np.asarray(pil_image.convert("RGB"))
118
+
119
+
120
+ def _apply_exif_orientation(image):
121
+ """
122
+ Applies the exif orientation correctly.
123
+
124
+ This code exists per the bug:
125
+ https://github.com/python-pillow/Pillow/issues/3973
126
+ with the function `ImageOps.exif_transpose`. The Pillow source raises errors with
127
+ various methods, especially `tobytes`
128
+
129
+ Function based on:
130
+ https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59
131
+ https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527
132
+
133
+ Args:
134
+ image (PIL.Image): a PIL image
135
+
136
+ Returns:
137
+ (PIL.Image): the PIL image with exif orientation applied, if applicable
138
+ """
139
+ if not hasattr(image, "getexif"):
140
+ return image
141
+
142
+ try:
143
+ exif = image.getexif()
144
+ except Exception: # https://github.com/facebookresearch/detectron2/issues/1885
145
+ exif = None
146
+
147
+ if exif is None:
148
+ return image
149
+
150
+ orientation = exif.get(_EXIF_ORIENT)
151
+
152
+ method = {
153
+ 2: Image.FLIP_LEFT_RIGHT,
154
+ 3: Image.ROTATE_180,
155
+ 4: Image.FLIP_TOP_BOTTOM,
156
+ 5: Image.TRANSPOSE,
157
+ 6: Image.ROTATE_270,
158
+ 7: Image.TRANSVERSE,
159
+ 8: Image.ROTATE_90,
160
+ }.get(orientation)
161
+
162
+ if method is not None:
163
+ return image.transpose(method)
164
+ return image
165
+
166
+
167
def read_image(file_name, format=None):
    """
    Load an image file and return it as an array in the requested format.
    Rotation and flipping are applied if the image carries such exif information.

    Args:
        file_name (str): image file path
        format (str): one of the supported image modes in PIL, or "BGR" or "YUV-BT.601".

    Returns:
        image (np.ndarray):
            an HWC image in the given format, which is 0-255, uint8 for
            supported image modes in PIL or "BGR"; float (0-1 for Y) for YUV-BT.601.
    """
    with PathManager.open(file_name, "rb") as stream:
        # work around this bug: https://github.com/python-pillow/Pillow/issues/3973
        pil_image = _apply_exif_orientation(Image.open(stream))
        return convert_PIL_to_numpy(pil_image, format)
187
+
188
+
189
def check_image_size(dataset_dict, image):
    """
    Raise an error if the image does not match the size specified in the dict.

    If the dict records the size transposed (width and height exchanged
    relative to the loaded image), the two entries are swapped in place instead
    of raising. Missing "width"/"height" entries are filled in from the image.

    Args:
        dataset_dict (dict): dataset record; may contain "width", "height" and
            "file_name". Modified in place.
        image: an object whose ``shape`` starts with (height, width).

    Raises:
        SizeMismatchError: if a recorded dimension disagrees with the image and
            the mismatch is not a plain width/height transposition.
    """
    image_wh = (image.shape[1], image.shape[0])
    has_width = "width" in dataset_dict
    has_height = "height" in dataset_dict

    if has_width or has_height:
        # Only compare the dimensions the record actually provides; a missing
        # one is treated as matching (it is filled in below).
        expected_wh = (
            dataset_dict["width"] if has_width else image_wh[0],
            dataset_dict["height"] if has_height else image_wh[1],
        )
        if image_wh != expected_wh and has_width and has_height:
            if image_wh == (expected_wh[1], expected_wh[0]):
                # The annotation stores the size transposed: swap the record
                # rather than failing. The dict is left untouched when the
                # swapped size would not match either.
                dataset_dict["width"], dataset_dict["height"] = (
                    dataset_dict["height"],
                    dataset_dict["width"],
                )
                expected_wh = image_wh
        if image_wh != expected_wh:
            raise SizeMismatchError(
                "Mismatched image shape{}, got {}, expect {}.".format(
                    " for image " + dataset_dict["file_name"]
                    if "file_name" in dataset_dict
                    else "",
                    image_wh,
                    expected_wh,
                )
                + " Please check the width/height in your annotation."
            )

    # To ensure bbox always remap to original image size
    if "width" not in dataset_dict:
        dataset_dict["width"] = image.shape[1]
    if "height" not in dataset_dict:
        dataset_dict["height"] = image.shape[0]
216
+
217
+
218
def transform_proposals(dataset_dict, image_shape, transforms, *, proposal_topk, min_box_size=0):
    """
    Apply transformations to the proposals in dataset_dict, if any.

    Args:
        dataset_dict (dict): a dict read from the dataset, possibly
            contains fields "proposal_boxes", "proposal_objectness_logits", "proposal_bbox_mode"
        image_shape (tuple): height, width
        transforms (TransformList):
        proposal_topk (int): only keep top-K scoring proposals
        min_box_size (int): proposals with either side smaller than this
            threshold are removed

    The input dict is modified in-place, with abovementioned keys removed. A new
    key "proposals" will be added. Its value is an `Instances`
    object which contains the transformed proposals in its field
    "proposal_boxes" and "objectness_logits".
    """
    if "proposal_boxes" not in dataset_dict:
        return

    # Move the proposal boxes into XYXY_ABS, then run them through the transforms.
    raw_boxes = dataset_dict.pop("proposal_boxes")
    raw_mode = dataset_dict.pop("proposal_bbox_mode")
    xyxy = BoxMode.convert(raw_boxes, raw_mode, BoxMode.XYXY_ABS)
    boxes = Boxes(transforms.apply_box(xyxy))
    logits = torch.as_tensor(dataset_dict.pop("proposal_objectness_logits").astype("float32"))

    # Clip to the image and drop boxes that fail the min_box_size threshold.
    boxes.clip(image_shape)
    valid = boxes.nonempty(threshold=min_box_size)
    boxes, logits = boxes[valid], logits[valid]

    proposals = Instances(image_shape)
    proposals.proposal_boxes = boxes[:proposal_topk]
    proposals.objectness_logits = logits[:proposal_topk]
    dataset_dict["proposals"] = proposals
259
+
260
+
261
def transform_instance_annotations(
    annotation, transforms, image_size, *, keypoint_hflip_indices=None
):
    """
    Transform the box, segmentation and keypoints of one instance annotation.

    It will use `transforms.apply_box` for the box, and
    `transforms.apply_coords` for segmentation polygons & keypoints.
    If you need anything more specially designed for each data structure,
    you'll need to implement your own version of this function or the transforms.

    Args:
        annotation (dict): dict of instance annotations for a single instance.
            It will be modified in-place.
        transforms (TransformList or list[Transform]):
        image_size (tuple): the height, width of the transformed image
        keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.

    Returns:
        dict:
            the same input dict with fields "bbox", "segmentation", "keypoints"
            transformed according to `transforms`.
            The "bbox_mode" field will be set to XYXY_ABS.
    """
    if isinstance(transforms, (tuple, list)):
        transforms = T.TransformList(transforms)

    # bbox is 1d (per-instance bounding box): convert, transform, clamp at 0 ...
    xyxy = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
    moved = transforms.apply_box(np.array([xyxy]))[0].clip(min=0)
    # ... and clamp from above; (h, w, h, w) reversed gives [w, h, w, h].
    upper_bound = list(image_size + image_size)[::-1]
    annotation["bbox"] = np.minimum(moved, upper_bound)
    annotation["bbox_mode"] = BoxMode.XYXY_ABS

    if "segmentation" in annotation:
        # each instance contains 1 or more polygons
        segm = annotation["segmentation"]
        if isinstance(segm, list):
            # polygons: flat coordinate lists -> (K, 2) arrays -> transformed -> flat again
            as_points = [np.asarray(poly).reshape(-1, 2) for poly in segm]
            annotation["segmentation"] = [
                out.reshape(-1) for out in transforms.apply_polygons(as_points)
            ]
        elif isinstance(segm, dict):
            # RLE: decode, then transform the decoded mask
            bitmask = transforms.apply_segmentation(mask_util.decode(segm))
            assert tuple(bitmask.shape[:2]) == image_size
            annotation["segmentation"] = bitmask
        else:
            raise ValueError(
                "Cannot transform segmentation of type '{}'!"
                "Supported types are: polygons as list[list[float] or ndarray],"
                " COCO-style RLE as a dict.".format(type(segm))
            )

    if "keypoints" in annotation:
        annotation["keypoints"] = transform_keypoint_annotations(
            annotation["keypoints"], transforms, image_size, keypoint_hflip_indices
        )

    return annotation
323
+
324
+
325
def transform_keypoint_annotations(keypoints, transforms, image_size, keypoint_hflip_indices=None):
    """
    Transform keypoint annotations of an image.
    If a keypoint is transformed out of image boundary, it will be marked "unlabeled" (visibility=0)

    Args:
        keypoints (list[float]): Nx3 float in Detectron2's Dataset format.
            Each point is represented by (x, y, visibility).
        transforms (TransformList):
        image_size (tuple): the height, width of the transformed image
        keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
            When `transforms` includes horizontal flip, will use the index
            mapping to flip keypoints.
    """
    # (N*3,) -> (N, 3)
    kpts = np.asarray(keypoints, dtype="float64").reshape(-1, 3)
    xy = transforms.apply_coords(kpts[:, :2])

    # Points that land outside [0, W] x [0, H] become "unlabeled".
    above_min = xy >= np.array([0, 0])
    below_max = xy <= np.array(image_size[::-1])
    in_bounds = (above_min & below_max).all(axis=1)
    kpts[:, :2] = xy
    kpts[:, 2][~in_bounds] = 0

    # This assumes that HorizFlipTransform is the only one that does flip.
    # An odd number of horizontal flips means the image ends up mirrored.
    hflip_count = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms)

    # Alternative way: check if probe points was horizontally flipped.
    # probe = np.asarray([[0.0, 0.0], [image_width, 0.0]])
    # probe_aug = transforms.apply_coords(probe.copy())
    # do_hflip = np.sign(probe[1][0] - probe[0][0]) != np.sign(probe_aug[1][0] - probe_aug[0][0])  # noqa

    if hflip_count % 2 == 1:
        # Swap each keypoint with its opposite-handed equivalent.
        if keypoint_hflip_indices is None:
            raise ValueError("Cannot flip keypoints without providing flip indices!")
        if len(kpts) != len(keypoint_hflip_indices):
            raise ValueError(
                "Keypoint data has {} points, but metadata "
                "contains {} points!".format(len(kpts), len(keypoint_hflip_indices))
            )
        kpts = kpts[np.asarray(keypoint_hflip_indices, dtype=np.int32), :]

    # Maintain COCO convention that if visibility == 0 (unlabeled), then x, y = 0
    kpts[kpts[:, 2] == 0] = 0
    return kpts
371
+
372
+
373
def annotations_to_instances(annos, image_size, mask_format="polygon"):
    """
    Build the :class:`Instances` object consumed by the models
    from the instance annotations of one image.

    Args:
        annos (list[dict]): a list of instance annotations in one image, each
            element for one instance.
        image_size (tuple): height, width
        mask_format (str): "polygon" or "bitmask"; selects how the
            "segmentation" entries are stored in "gt_masks".

    Returns:
        Instances:
            It will contain fields "gt_boxes", "gt_classes",
            "gt_masks", "gt_keypoints", if they can be obtained from `annos`.
            This is the format that builtin models expect.
    """
    if len(annos):
        boxes = np.stack(
            [BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos]
        )
    else:
        boxes = np.zeros((0, 4))
    target = Instances(image_size)
    target.gt_boxes = Boxes(boxes)

    target.gt_classes = torch.tensor(
        [int(obj["category_id"]) for obj in annos], dtype=torch.int64
    )

    # Only the first annotation is inspected to decide whether masks exist.
    if len(annos) and "segmentation" in annos[0]:
        segms = [obj["segmentation"] for obj in annos]
        if mask_format == "polygon":
            try:
                masks = PolygonMasks(segms)
            except ValueError as e:
                raise ValueError(
                    "Failed to use mask_format=='polygon' from the given annotations!"
                ) from e
        else:
            assert mask_format == "bitmask", mask_format
            bitmaps = []
            for segm in segms:
                if isinstance(segm, list):
                    # polygon
                    bitmaps.append(polygons_to_bitmask(segm, *image_size))
                elif isinstance(segm, dict):
                    # COCO RLE
                    bitmaps.append(mask_util.decode(segm))
                elif isinstance(segm, np.ndarray):
                    # mask array
                    assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format(
                        segm.ndim
                    )
                    bitmaps.append(segm)
                else:
                    raise ValueError(
                        "Cannot convert segmentation of type '{}' to BitMasks!"
                        "Supported types are: polygons as list[list[float] or ndarray],"
                        " COCO-style RLE as a dict, or a binary segmentation mask "
                        " in a 2D numpy array of shape HxW.".format(type(segm))
                    )
            # torch.from_numpy does not support array with negative stride.
            as_tensors = [torch.from_numpy(np.ascontiguousarray(x)) for x in bitmaps]
            masks = BitMasks(torch.stack(as_tensors))
        target.gt_masks = masks

    # Likewise, keypoints are attached only when the first annotation has them.
    if len(annos) and "keypoints" in annos[0]:
        target.gt_keypoints = Keypoints([obj.get("keypoints", []) for obj in annos])

    return target
446
+
447
+
448
def annotations_to_instances_rotated(annos, image_size):
    """
    Build the :class:`Instances` object consumed by the models
    from the instance annotations of one image.
    Compared to `annotations_to_instances`, this function is for rotated boxes only

    Args:
        annos (list[dict]): a list of instance annotations in one image, each
            element for one instance.
        image_size (tuple): height, width

    Returns:
        Instances:
            Containing fields "gt_boxes", "gt_classes",
            if they can be obtained from `annos`.
            This is the format that builtin models expect.
    """
    raw_boxes = [obj["bbox"] for obj in annos]
    target = Instances(image_size)

    rotated = RotatedBoxes(raw_boxes)
    target.gt_boxes = rotated
    # Clip through the same object that was stored on `target`.
    rotated.clip(image_size)

    target.gt_classes = torch.tensor(
        [obj["category_id"] for obj in annos], dtype=torch.int64
    )
    return target
475
+
476
+
477
def filter_empty_instances(
    instances, by_box=True, by_mask=True, box_threshold=1e-5, return_mask=False
):
    """
    Filter out empty instances in an `Instances` object.

    Args:
        instances (Instances):
        by_box (bool): whether to filter out instances with empty boxes
        by_mask (bool): whether to filter out instances with empty masks
        box_threshold (float): minimum width and height to be considered non-empty
        return_mask (bool): whether to return boolean mask of filtered instances

    Returns:
        Instances: the filtered instances.
        tensor[bool], optional: boolean mask of filtered instances. Returned
            (as the second element of a tuple) whenever `return_mask` is True,
            including when no filter applies; in that case it is all-True.
    """
    assert by_box or by_mask
    criteria = []
    if by_box:
        criteria.append(instances.gt_boxes.nonempty(threshold=box_threshold))
    if instances.has("gt_masks") and by_mask:
        criteria.append(instances.gt_masks.nonempty())

    # TODO: can also filter visible keypoints

    if not criteria:
        # Nothing to filter on (e.g. by_box=False and no "gt_masks" field).
        if return_mask:
            # Keep the return shape consistent so callers can always unpack.
            # NOTE(review): the mask is created on the default device; confirm
            # callers do not need it on the instances' device.
            return instances, torch.ones(len(instances), dtype=torch.bool)
        return instances

    # An instance is kept only if every active criterion keeps it.
    keep = criteria[0]
    for extra in criteria[1:]:
        keep = keep & extra
    if return_mask:
        return instances[keep], keep
    return instances[keep]
511
+
512
+
513
def create_keypoint_hflip_indices(dataset_names: Union[str, List[str]]) -> List[int]:
    """
    Compute, for every keypoint, the index of the keypoint it maps to under a
    horizontal flip.

    Args:
        dataset_names: list of dataset names

    Returns:
        list[int]: a list of size=#keypoints, storing the
        horizontally-flipped keypoint indices.
    """
    if isinstance(dataset_names, str):
        dataset_names = [dataset_names]

    for key in ("keypoint_names", "keypoint_flip_map"):
        check_metadata_consistency(key, dataset_names)

    meta = MetadataCatalog.get(dataset_names[0])
    names = meta.keypoint_names
    # TODO flip -> hflip
    partner = dict(meta.keypoint_flip_map)
    # Make the mapping symmetric so it can be looked up from either side.
    partner.update({right: left for left, right in partner.items()})
    # Keypoints without a partner map to themselves.
    return [names.index(partner.get(name, name)) for name in names]
536
+
537
+
538
def get_fed_loss_cls_weights(dataset_names: Union[str, List[str]], freq_weight_power=1.0):
    """
    Get the frequency weight of each class, ordered by class id.
    The weight is the class's image_count raised to the power freq_weight_power.

    Args:
        dataset_names: list of dataset names
        freq_weight_power: power value
    """
    if isinstance(dataset_names, str):
        dataset_names = [dataset_names]

    check_metadata_consistency("class_image_count", dataset_names)

    # Metadata is read from the first dataset only (the check above ensures
    # all of them agree).
    counts_meta = MetadataCatalog.get(dataset_names[0]).class_image_count
    ordered = sorted(counts_meta, key=lambda x: x["id"])
    image_counts = torch.tensor([c["image_count"] for c in ordered])
    return image_counts.float() ** freq_weight_power
559
+
560
+
561
def gen_crop_transform_with_instance(crop_size, image_size, instance):
    """
    Sample a CropTransform whose cropping region contains
    the center of the given instance.

    Args:
        crop_size (tuple): h, w in pixels
        image_size (tuple): h, w
        instance (dict): an annotation dict of one instance, in Detectron2's
            dataset format.
    """
    crop_hw = np.asarray(crop_size, dtype=np.int32)
    box = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS)
    center_y = (box[1] + box[3]) * 0.5
    center_x = (box[0] + box[2]) * 0.5
    assert (
        image_size[0] >= center_y and image_size[1] >= center_x
    ), "The annotation bounding box is outside of the image!"
    assert (
        image_size[0] >= crop_hw[0] and image_size[1] >= crop_hw[1]
    ), "Crop size is larger than image size!"

    center_yx = (center_y, center_x)
    # Range of top-left corners: the crop must still reach the center (lower
    # bound) and must stay inside the image and start at or before the center
    # (upper bound).
    lowest_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_hw, 0)
    highest_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_hw, 0)
    highest_yx = np.minimum(highest_yx, np.ceil(center_yx).astype(np.int32))

    y0 = np.random.randint(lowest_yx[0], highest_yx[0] + 1)
    x0 = np.random.randint(lowest_yx[1], highest_yx[1] + 1)
    return T.CropTransform(x0, y0, crop_hw[1], crop_hw[0])
589
+
590
+
591
def check_metadata_consistency(key, dataset_names):
    """
    Verify that all given datasets share the same value for one metadata key.

    Args:
        key (str): a metadata key
        dataset_names (list[str]): a list of dataset names

    Raises:
        AttributeError: if the key does not exist in the metadata
        ValueError: if the given datasets do not have the same metadata values defined by key
    """
    if len(dataset_names) == 0:
        return
    logger = logging.getLogger(__name__)
    entries = [getattr(MetadataCatalog.get(d), key) for d in dataset_names]
    # Every dataset is compared against the first one.
    reference = entries[0]
    for name, entry in zip(dataset_names, entries):
        if entry != reference:
            logger.error(
                "Metadata '{}' for dataset '{}' is '{}'".format(key, name, str(entry))
            )
            logger.error(
                "Metadata '{}' for dataset '{}' is '{}'".format(
                    key, dataset_names[0], str(reference)
                )
            )
            raise ValueError("Datasets have different metadata '{}'!".format(key))
618
+
619
+
620
def build_augmentation(cfg, is_train):
    """
    Build the default list of :class:`Augmentation` from config.
    Now it includes resizing and flipping.

    Returns:
        list[Augmentation]
    """
    if is_train:
        min_size, max_size, sample_style = (
            cfg.INPUT.MIN_SIZE_TRAIN,
            cfg.INPUT.MAX_SIZE_TRAIN,
            cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
        )
    else:
        min_size, max_size, sample_style = (
            cfg.INPUT.MIN_SIZE_TEST,
            cfg.INPUT.MAX_SIZE_TEST,
            "choice",
        )
    augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]

    # Random flipping is a training-only augmentation.
    if is_train:
        flip_mode = cfg.INPUT.RANDOM_FLIP
        if flip_mode != "none":
            augmentation.append(
                T.RandomFlip(
                    horizontal=flip_mode == "horizontal",
                    vertical=flip_mode == "vertical",
                )
            )
    return augmentation


build_transform_gen = build_augmentation
"""
Alias for backward-compatibility.
"""
cutler/data/transforms/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/__init__.py
3
+
4
+ from fvcore.transforms.transform import *
5
+ from .transform import *
6
+ from detectron2.data.transforms.augmentation import *
7
+ from .augmentation_impl import *
8
+
9
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
10
+
11
+
12
+ from detectron2.utils.env import fixup_module_metadata
13
+
14
+ fixup_module_metadata(__name__, globals(), __all__)
15
+ del fixup_module_metadata
cutler/data/transforms/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (728 Bytes). View file
 
cutler/data/transforms/__pycache__/augmentation_impl.cpython-312.pyc ADDED
Binary file (32.9 kB). View file
 
cutler/data/transforms/__pycache__/transform.cpython-312.pyc ADDED
Binary file (19.7 kB). View file
 
cutler/data/transforms/augmentation_impl.py ADDED
@@ -0,0 +1,616 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py
4
+
5
+ """
6
+ Implement many useful :class:`Augmentation`.
7
+ """
8
+ import numpy as np
9
+ import sys
10
+ from typing import Tuple
11
+ import torch
12
+ from fvcore.transforms.transform import (
13
+ BlendTransform,
14
+ CropTransform,
15
+ HFlipTransform,
16
+ NoOpTransform,
17
+ PadTransform,
18
+ Transform,
19
+ TransformList,
20
+ VFlipTransform,
21
+ )
22
+ from PIL import Image
23
+
24
+ from detectron2.data.transforms.augmentation import Augmentation, _transform_to_aug
25
+ from .transform import ExtentTransform, ResizeTransform, RotationTransform
26
+
27
+ __all__ = [
28
+ "FixedSizeCrop",
29
+ "RandomApply",
30
+ "RandomBrightness",
31
+ "RandomContrast",
32
+ "RandomCrop",
33
+ "RandomExtent",
34
+ "RandomFlip",
35
+ "RandomSaturation",
36
+ "RandomLighting",
37
+ "RandomRotation",
38
+ "Resize",
39
+ "ResizeScale",
40
+ "ResizeShortestEdge",
41
+ "RandomCrop_CategoryAreaConstraint",
42
+ ]
43
+
44
+
45
class RandomApply(Augmentation):
    """
    Apply a wrapped augmentation only some of the time, with a given probability.
    """

    def __init__(self, tfm_or_aug, prob=0.5):
        """
        Args:
            tfm_or_aug (Transform, Augmentation): the transform or augmentation
                to be applied. It can either be a `Transform` or `Augmentation`
                instance.
            prob (float): probability between 0.0 and 1.0 that
                the wrapper transformation is applied
        """
        super().__init__()
        self.aug = _transform_to_aug(tfm_or_aug)
        assert 0.0 <= prob <= 1.0, f"Probablity must be between 0.0 and 1.0 (given: {prob})"
        self.prob = prob

    def get_transform(self, *args):
        # One random draw per call decides whether the wrapped augmentation runs.
        if self._rand_range() < self.prob:
            return self.aug.get_transform(*args)
        return NoOpTransform()

    def __call__(self, aug_input):
        if self._rand_range() < self.prob:
            return self.aug(aug_input)
        return NoOpTransform()
77
+
78
+
79
class RandomFlip(Augmentation):
    """
    Flip the image horizontally or vertically with the given probability.
    """

    def __init__(self, prob=0.5, *, horizontal=True, vertical=False):
        """
        Args:
            prob (float): probability of flip.
            horizontal (boolean): whether to apply horizontal flipping
            vertical (boolean): whether to apply vertical flipping
        """
        super().__init__()

        # Exactly one of the two directions must be selected.
        if horizontal and vertical:
            raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
        if not horizontal and not vertical:
            raise ValueError("At least one of horiz or vert has to be True!")
        self._init(locals())

    def get_transform(self, image):
        h, w = image.shape[:2]
        if not self._rand_range() < self.prob:
            return NoOpTransform()
        if self.horizontal:
            return HFlipTransform(w)
        if self.vertical:
            return VFlipTransform(h)
109
+
110
+
111
class Resize(Augmentation):
    """Resize image to a fixed target size"""

    def __init__(self, shape, interp=Image.BILINEAR):
        """
        Args:
            shape: (h, w) tuple or a int. A single int is used for both
                height and width.
            interp: PIL interpolation method
        """
        # Initialise the base class like the other augmentations in this file.
        super().__init__()
        if isinstance(shape, int):
            shape = (shape, shape)
        shape = tuple(shape)
        self._init(locals())

    def get_transform(self, image):
        """
        Return a ResizeTransform from the image's current (h, w) to `self.shape`.

        Args:
            image: array whose first two dimensions are height and width.
        """
        return ResizeTransform(
            image.shape[0], image.shape[1], self.shape[0], self.shape[1], self.interp
        )
129
+
130
+
131
class ResizeShortestEdge(Augmentation):
    """
    Resize the image while keeping the aspect ratio unchanged.
    It attempts to scale the shorter edge to the given `short_edge_length`,
    as long as the longer edge does not exceed `max_size`.
    If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
    """

    @torch.jit.unused
    def __init__(
        self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR
    ):
        """
        Args:
            short_edge_length (list[int]): If ``sample_style=="range"``,
                a [min, max] interval from which to sample the shortest edge length.
                If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
            max_size (int): maximum allowed longest edge length.
            sample_style (str): either "range" or "choice".
            interp: PIL interpolation method passed on to the resize transform.
        """
        super().__init__()
        assert sample_style in ["range", "choice"], sample_style

        self.is_range = sample_style == "range"
        # A bare int stands for the degenerate interval / single choice.
        if isinstance(short_edge_length, int):
            short_edge_length = (short_edge_length, short_edge_length)
        if self.is_range:
            assert len(short_edge_length) == 2, (
                "short_edge_length must be two values using 'range' sample style."
                f" Got {short_edge_length}!"
            )
        self._init(locals())

    @torch.jit.unused
    def get_transform(self, image):
        h, w = image.shape[:2]
        if self.is_range:
            low, high = self.short_edge_length[0], self.short_edge_length[1]
            size = np.random.randint(low, high + 1)  # inclusive upper end
        else:
            size = np.random.choice(self.short_edge_length)
        # A sampled size of 0 means "leave the image as is".
        if size == 0:
            return NoOpTransform()

        newh, neww = ResizeShortestEdge.get_output_shape(h, w, size, self.max_size)
        return ResizeTransform(h, w, newh, neww, self.interp)

    @staticmethod
    def get_output_shape(
        oldh: int, oldw: int, short_edge_length: int, max_size: int
    ) -> Tuple[int, int]:
        """
        Compute the output size given input size and target short edge length.
        """
        size = short_edge_length * 1.0
        scale = size / min(oldh, oldw)
        if oldh < oldw:
            newh, neww = size, scale * oldw
        else:
            newh, neww = scale * oldh, size
        # Shrink further if the longer edge would exceed max_size.
        if max(newh, neww) > max_size:
            scale = max_size * 1.0 / max(newh, neww)
            newh, neww = newh * scale, neww * scale
        # Round half up to integer pixel sizes.
        return (int(newh + 0.5), int(neww + 0.5))
198
+
199
+
200
class ResizeScale(Augmentation):
    """
    Takes target size as input and randomly scales the given target size between `min_scale`
    and `max_scale`. It then scales the input image such that it fits inside the scaled target
    box, keeping the aspect ratio constant.
    This implements the resize part of the Google's 'resize_and_crop' data augmentation:
    https://github.com/tensorflow/tpu/blob/master/models/official/detection/utils/input_utils.py#L127
    """

    def __init__(
        self,
        min_scale: float,
        max_scale: float,
        target_height: int,
        target_width: int,
        interp: int = Image.BILINEAR,
    ):
        """
        Args:
            min_scale: minimum image scale range.
            max_scale: maximum image scale range.
            target_height: target image height.
            target_width: target image width.
            interp: image interpolation method.
        """
        super().__init__()
        self._init(locals())

    def _get_resize(self, image: np.ndarray, scale: float) -> Transform:
        in_h, in_w = image.shape[:2]

        # Target box after applying the sampled scale.
        scaled_target = np.multiply((self.target_height, self.target_width), scale)

        # The tighter of the two per-axis ratios keeps the image inside the box.
        ratio = np.minimum(scaled_target[0] / in_h, scaled_target[1] / in_w)
        out_h, out_w = np.round(np.multiply((in_h, in_w), ratio)).astype(int)

        return ResizeTransform(in_h, in_w, out_h, out_w, self.interp)

    def get_transform(self, image: np.ndarray) -> Transform:
        return self._get_resize(image, np.random.uniform(self.min_scale, self.max_scale))
248
+
249
+
250
class RandomRotation(Augmentation):
    """
    This method returns a copy of this image, rotated the given
    number of degrees counter clockwise around the given center.
    """

    def __init__(self, angle, expand=True, center=None, sample_style="range", interp=None):
        """
        Args:
            angle (list[float]): If ``sample_style=="range"``,
                a [min, max] interval from which to sample the angle (in degrees).
                If ``sample_style=="choice"``, a list of angles to sample from
            expand (bool): choose if the image should be resized to fit the whole
                rotated image (default), or simply cropped
            center (list[[float, float]]):  If ``sample_style=="range"``,
                a [[minx, miny], [maxx, maxy]] relative interval from which to sample the center,
                [0, 0] being the top left of the image and [1, 1] the bottom right.
                If ``sample_style=="choice"``, a list of centers to sample from
                Default: None, which means that the center of rotation is the center of the image
                center has no effect if expand=True because it only affects shifting
            sample_style (str): either "range" or "choice".
            interp: interpolation method forwarded unchanged to `RotationTransform`.
        """
        super().__init__()
        assert sample_style in ["range", "choice"], sample_style
        self.is_range = sample_style == "range"
        # A single angle / a single (x, y) center is expanded to a degenerate pair.
        if isinstance(angle, (float, int)):
            angle = (angle, angle)
        if center is not None and isinstance(center[0], (float, int)):
            center = (center, center)
        self._init(locals())

    def get_transform(self, image):
        """
        Sample an angle (and optionally a center) and build the rotation transform.

        Returns a `NoOpTransform` when the sampled angle is a multiple of 360.
        """
        h, w = image.shape[:2]
        center = None
        if self.is_range:
            angle = np.random.uniform(self.angle[0], self.angle[1])
            if self.center is not None:
                center = (
                    np.random.uniform(self.center[0][0], self.center[1][0]),
                    np.random.uniform(self.center[0][1], self.center[1][1]),
                )
        else:
            angle = np.random.choice(self.angle)
            if self.center is not None:
                # np.random.choice only accepts 1-D input and raises on a list
                # of (x, y) pairs, so sample an index and pick that pair.
                center = self.center[np.random.randint(len(self.center))]

        if center is not None:
            center = (w * center[0], h * center[1])  # Convert to absolute coordinates

        if angle % 360 == 0:
            return NoOpTransform()

        return RotationTransform(h, w, angle, expand=self.expand, center=center, interp=self.interp)
302
+
303
+
304
class FixedSizeCrop(Augmentation):
    """
    Bring an image to a fixed `crop_size`. Dimensions larger than `crop_size`
    are reduced with a random crop. Dimensions smaller than `crop_size` are
    padded on the right and bottom when `pad` is True; otherwise the smaller
    image is returned.
    """

    def __init__(self, crop_size: Tuple[int], pad: bool = True, pad_value: float = 128.0):
        """
        Args:
            crop_size: target image (height, width).
            pad: if True, will pad images smaller than `crop_size` up to `crop_size`
            pad_value: the padding value.
        """
        super().__init__()
        self._init(locals())

    def _get_crop(self, image: np.ndarray) -> Transform:
        """Build the random crop; the offset is zero along axes that already fit."""
        in_hw = image.shape[:2]
        out_hw = self.crop_size

        # Room available for moving the crop window, clamped at zero.
        slack = np.maximum(np.subtract(in_hw, out_hw), 0)
        # A single random fraction is shared by both axes.
        fraction = np.random.uniform(0.0, 1.0)
        top, left = np.round(np.multiply(slack, fraction)).astype(int)
        return CropTransform(left, top, out_hw[1], out_hw[0], in_hw[1], in_hw[0])

    def _get_pad(self, image: np.ndarray) -> Transform:
        """Build the right/bottom padding that fills the image up to `crop_size`."""
        in_hw = image.shape[:2]
        out_hw = self.crop_size

        pad_h, pad_w = np.maximum(np.subtract(out_hw, in_hw), 0)
        kept_h, kept_w = np.minimum(in_hw, out_hw)
        return PadTransform(0, 0, pad_w, pad_h, kept_w, kept_h, self.pad_value)

    def get_transform(self, image: np.ndarray) -> TransformList:
        """Return the crop, followed by the pad when padding is enabled."""
        crop = self._get_crop(image)
        if not self.pad:
            return TransformList([crop])
        return TransformList([crop, self._get_pad(image)])
354
+
355
+
356
class RandomCrop(Augmentation):
    """
    Crop a randomly placed rectangle out of an image.
    """

    def __init__(self, crop_type: str, crop_size):
        """
        Args:
            crop_type (str): one of "relative_range", "relative", "absolute", "absolute_range".
            crop_size (tuple[float, float]): two floats, explained below.

        - "relative": crop a (H * crop_size[0], W * crop_size[1]) region from an input image of
          size (H, W). crop size should be in (0, 1]
        - "relative_range": uniformly sample two values from [crop_size[0], 1]
          and [crop_size[1], 1], and use them as in "relative" crop type.
        - "absolute": crop a (crop_size[0], crop_size[1]) region from input image.
          Each side is clamped to the input image size.
        - "absolute_range": for an input of size (H, W), uniformly sample H_crop in
          [crop_size[0], min(H, crop_size[1])] and W_crop in [crop_size[0], min(W, crop_size[1])].
          Then crop a region (H_crop, W_crop).
        """
        # TODO style of relative_range and absolute_range are not consistent:
        # one takes (h, w) but another takes (min, max)
        super().__init__()
        assert crop_type in ["relative_range", "relative", "absolute", "absolute_range"]
        self._init(locals())

    def get_transform(self, image):
        """Pick a crop size, then a uniformly random top-left corner that fits it."""
        img_h, img_w = image.shape[:2]
        crop_h, crop_w = self.get_crop_size((img_h, img_w))
        assert img_h >= crop_h and img_w >= crop_w, "Shape computation in {} has bugs.".format(
            self
        )
        top = np.random.randint(img_h - crop_h + 1)
        left = np.random.randint(img_w - crop_w + 1)
        return CropTransform(left, top, crop_w, crop_h)

    def get_crop_size(self, image_size):
        """
        Args:
            image_size (tuple): height, width

        Returns:
            crop_size (tuple): height, width in absolute pixels
        """
        h, w = image_size
        kind = self.crop_type

        if kind == "relative":
            frac_h, frac_w = self.crop_size
            return int(h * frac_h + 0.5), int(w * frac_w + 0.5)

        if kind == "relative_range":
            lower = np.asarray(self.crop_size, dtype=np.float32)
            frac_h, frac_w = lower + np.random.rand(2) * (1 - lower)
            return int(h * frac_h + 0.5), int(w * frac_w + 0.5)

        if kind == "absolute":
            return (min(self.crop_size[0], h), min(self.crop_size[1], w))

        if kind == "absolute_range":
            assert self.crop_size[0] <= self.crop_size[1]
            crop_h = np.random.randint(min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1)
            crop_w = np.random.randint(min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1)
            return crop_h, crop_w

        raise NotImplementedError("Unknown crop type {}".format(self.crop_type))
416
+
417
+
418
class RandomCrop_CategoryAreaConstraint(Augmentation):
    """
    Similar to :class:`RandomCrop`, but searches for a cropping window in which no
    single category occupies a ratio of more than `single_category_max_area` of the
    semantic segmentation ground truth, since such crops can cause instability in
    training. At most 10 candidate windows are tried; the last one is used if none
    qualifies.
    """

    def __init__(
        self,
        crop_type: str,
        crop_size,
        single_category_max_area: float = 1.0,
        ignored_category: int = None,
    ):
        """
        Args:
            crop_type, crop_size: same as in :class:`RandomCrop`
            single_category_max_area: the maximum allowed area ratio of a
                category. Set to 1.0 to disable
            ignored_category: allow this category in the semantic segmentation
                ground truth to exceed the area ratio. Usually set to the category
                that's ignored in training.
        """
        self.crop_aug = RandomCrop(crop_type, crop_size)
        self._init(locals())

    def get_transform(self, image, sem_seg):
        """Return a crop whose label distribution satisfies the area constraint."""
        # Constraint disabled: behave exactly like the wrapped RandomCrop.
        if self.single_category_max_area >= 1.0:
            return self.crop_aug.get_transform(image)

        h, w = sem_seg.shape
        for _ in range(10):
            crop_size = self.crop_aug.get_crop_size((h, w))
            y0 = np.random.randint(h - crop_size[0] + 1)
            x0 = np.random.randint(w - crop_size[1] + 1)
            window = sem_seg[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
            labels, cnt = np.unique(window, return_counts=True)
            if self.ignored_category is not None:
                cnt = cnt[labels != self.ignored_category]
            dominant_ok = np.max(cnt) < np.sum(cnt) * self.single_category_max_area if len(
                cnt
            ) > 1 else False
            if dominant_ok:
                break
        return CropTransform(x0, y0, crop_size[1], crop_size[0])
462
+
463
+
464
class RandomExtent(Augmentation):
    """
    Outputs an image by cropping a random "subrect" of the source image.

    The subrect can be parameterized to include pixels outside the source image,
    in which case they will be set to zeros (i.e. black). The size of the output
    image will vary with the size of the random subrect.
    """

    def __init__(self, scale_range, shift_range):
        """
        Args:
            scale_range (l, h): Range of input-to-output size scaling factor
            shift_range (x, y): Range of shifts of the cropped subrect. The rect
                is shifted by [w / 2 * Uniform(-x, x), h / 2 * Uniform(-y, y)],
                where (w, h) is the (width, height) of the input image. Set each
                component to zero to crop at the image's center.
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, image):
        """Sample the scale and shift, and return the matching ExtentTransform."""
        img_h, img_w = image.shape[:2]
        half_w, half_h = 0.5 * img_w, 0.5 * img_h

        # Start from the full image, expressed around an origin at the image center.
        rect = np.array([-half_w, -half_h, half_w, half_h])

        # Random scaling about the center.
        rect *= np.random.uniform(self.scale_range[0], self.scale_range[1])

        # Random shift: x is drawn first, then y.
        shift_x = self.shift_range[0] * img_w * (np.random.rand() - 0.5)
        rect[0::2] += shift_x
        shift_y = self.shift_range[1] * img_h * (np.random.rand() - 0.5)
        rect[1::2] += shift_y

        # Move the origin back to the top-left corner of the image.
        rect[0::2] += half_w
        rect[1::2] += half_h

        x0, y0, x1, y1 = rect
        return ExtentTransform(
            src_rect=(x0, y0, x1, y1),
            output_size=(int(y1 - y0), int(x1 - x0)),
        )
507
+
508
+
509
class RandomContrast(Augmentation):
    """
    Randomly transforms image contrast by blending the image with its mean value.

    Contrast intensity is uniformly sampled in (intensity_min, intensity_max).
    - intensity < 1 will reduce contrast
    - intensity = 1 will preserve the input image
    - intensity > 1 will increase contrast

    See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
    """

    def __init__(self, intensity_min, intensity_max):
        """
        Args:
            intensity_min (float): Minimum augmentation
            intensity_max (float): Maximum augmentation
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, image):
        """Sample an intensity and blend the image against its own mean."""
        intensity = np.random.uniform(self.intensity_min, self.intensity_max)
        return BlendTransform(
            src_image=image.mean(),
            src_weight=1 - intensity,
            dst_weight=intensity,
        )
533
+
534
+
535
class RandomBrightness(Augmentation):
    """
    Randomly transforms image brightness by blending the image with black.

    Brightness intensity is uniformly sampled in (intensity_min, intensity_max).
    - intensity < 1 will reduce brightness
    - intensity = 1 will preserve the input image
    - intensity > 1 will increase brightness

    See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
    """

    def __init__(self, intensity_min, intensity_max):
        """
        Args:
            intensity_min (float): Minimum augmentation
            intensity_max (float): Maximum augmentation
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, image):
        """Sample an intensity and blend the image against a zero (black) source."""
        intensity = np.random.uniform(self.intensity_min, self.intensity_max)
        return BlendTransform(
            src_image=0,
            src_weight=1 - intensity,
            dst_weight=intensity,
        )
559
+
560
+
561
class RandomSaturation(Augmentation):
    """
    Randomly transforms saturation of an RGB image by blending it with its
    grayscale version. Input images are assumed to have 'RGB' channel order.

    Saturation intensity is uniformly sampled in (intensity_min, intensity_max).
    - intensity < 1 will reduce saturation (make the image more grayscale)
    - intensity = 1 will preserve the input image
    - intensity > 1 will increase saturation

    See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
    """

    def __init__(self, intensity_min, intensity_max):
        """
        Args:
            intensity_min (float): Minimum augmentation (1 preserves input).
            intensity_max (float): Maximum augmentation (1 preserves input).
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, image):
        """Sample an intensity and blend the image against its luma channel."""
        assert image.shape[-1] == 3, "RandomSaturation only works on RGB images"
        intensity = np.random.uniform(self.intensity_min, self.intensity_max)
        # Weighted sum of the three channels, kept as a trailing singleton channel.
        luma = image.dot([0.299, 0.587, 0.114])
        grayscale = luma[:, :, np.newaxis]
        return BlendTransform(
            src_image=grayscale,
            src_weight=1 - intensity,
            dst_weight=intensity,
        )
588
+
589
+
590
class RandomLighting(Augmentation):
    """
    The "lighting" augmentation described in AlexNet, using fixed PCA over ImageNet.
    Input images are assumed to have 'RGB' channel order.

    The degree of color jittering is randomly sampled via a normal distribution,
    with standard deviation given by the scale parameter.
    """

    def __init__(self, scale):
        """
        Args:
            scale (float): Standard deviation of principal component weighting.
        """
        super().__init__()
        self._init(locals())
        # Fixed principal components (rows) and their eigenvalues.
        self.eigen_vecs = np.array(
            [
                [-0.5675, 0.7192, 0.4009],
                [-0.5808, -0.0045, -0.8140],
                [-0.5836, -0.6948, 0.4203],
            ]
        )
        self.eigen_vals = np.array([0.2175, 0.0188, 0.0045])

    def get_transform(self, image):
        """Sample three component weights and add the resulting color offset."""
        assert image.shape[-1] == 3, "RandomLighting only works on RGB images"
        weights = np.random.normal(scale=self.scale, size=3)
        color_shift = self.eigen_vecs.dot(weights * self.eigen_vals)
        return BlendTransform(src_image=color_shift, src_weight=1.0, dst_weight=1.0)
cutler/data/transforms/transform.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/transform.py
3
+
4
+ """
5
+ See "Data Augmentation" tutorial for an overview of the system:
6
+ https://detectron2.readthedocs.io/tutorials/augmentation.html
7
+ """
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from fvcore.transforms.transform import (
13
+ CropTransform,
14
+ HFlipTransform,
15
+ NoOpTransform,
16
+ Transform,
17
+ TransformList,
18
+ )
19
+ from PIL import Image
20
+
21
+ try:
22
+ import cv2 # noqa
23
+ except ImportError:
24
+ # OpenCV is an optional dependency at the moment
25
+ pass
26
+
27
+ __all__ = [
28
+ "ExtentTransform",
29
+ "ResizeTransform",
30
+ "RotationTransform",
31
+ "ColorTransform",
32
+ "PILColorTransform",
33
+ ]
34
+
35
+
36
class ExtentTransform(Transform):
    """
    Extracts a subregion from the source image and scales it to the output size.

    The fill color is used to map pixels from the source rect that fall outside
    the source image.

    See: https://pillow.readthedocs.io/en/latest/PIL.html#PIL.ImageTransform.ExtentTransform
    """

    def __init__(self, src_rect, output_size, interp=Image.BILINEAR, fill=0):
        """
        Args:
            src_rect (x0, y0, x1, y1): src coordinates
            output_size (h, w): dst image size
            interp: PIL interpolation methods
            fill: Fill color used when src_rect extends outside image
        """
        super().__init__()
        self._set_attributes(locals())

    def apply_image(self, img, interp=None):
        """
        Extract ``src_rect`` from ``img`` and resample it to ``output_size``.

        Args:
            img (ndarray): image array; a trailing singleton channel is handled
                by converting through a single-channel ("L") PIL image.
            interp: PIL interpolation method overriding ``self.interp``.
                ``None`` means "use ``self.interp``".

        Returns:
            ndarray: the resampled image, with the singleton channel restored
            if the input had one.
        """
        h, w = self.output_size
        single_channel = len(img.shape) > 2 and img.shape[2] == 1
        if single_channel:
            pil_image = Image.fromarray(img[:, :, 0], mode="L")
        else:
            pil_image = Image.fromarray(img)
        pil_image = pil_image.transform(
            size=(w, h),
            method=Image.EXTENT,
            data=self.src_rect,
            # Compare against None: Image.NEAREST has the value 0, so a truthiness
            # test would silently replace it with self.interp.
            resample=self.interp if interp is None else interp,
            fill=self.fill,
        )
        ret = np.asarray(pil_image)
        if single_channel:
            ret = np.expand_dims(ret, -1)
        return ret

    def apply_coords(self, coords):
        """
        Map (x, y) points from source image coordinates to output coordinates.

        Returns a float32 copy; the input array is not modified.
        """
        # Transform image center from source coordinates into output coordinates
        # and then map the new origin to the corner of the output image.
        h, w = self.output_size
        x0, y0, x1, y1 = self.src_rect
        new_coords = coords.astype(np.float32)
        new_coords[:, 0] -= 0.5 * (x0 + x1)
        new_coords[:, 1] -= 0.5 * (y0 + y1)
        new_coords[:, 0] *= w / (x1 - x0)
        new_coords[:, 1] *= h / (y1 - y0)
        new_coords[:, 0] += 0.5 * w
        new_coords[:, 1] += 0.5 * h
        return new_coords

    def apply_segmentation(self, segmentation):
        """Apply the transform to a label map using nearest-neighbour resampling."""
        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
        return segmentation
92
+
93
+
94
class ResizeTransform(Transform):
    """
    Resize the image to a target size.
    """

    def __init__(self, h, w, new_h, new_w, interp=None):
        """
        Args:
            h, w (int): original image size
            new_h, new_w (int): new image size
            interp: PIL interpolation methods, defaults to bilinear.
        """
        # TODO decide on PIL vs opencv
        super().__init__()
        if interp is None:
            interp = Image.BILINEAR
        self._set_attributes(locals())

    def apply_image(self, img, interp=None):
        """
        Resize ``img`` to ``(new_h, new_w)``.

        uint8 arrays are resized through PIL; every other dtype goes through
        ``torch.nn.functional.interpolate``.

        The input shape is intentionally NOT validated against ``(h, w)``. The
        previous code evaluated ``img.shape[:2] == (self.h, self.w)`` inside a
        ``try`` with a bare ``except`` and discarded the result, so it never
        rejected any array; that dead code is removed and the permissive
        behaviour callers see is kept.

        Args:
            img (ndarray): array of at most 4 dimensions, height and width first.
            interp: PIL interpolation method overriding ``self.interp``.

        Returns:
            ndarray: the resized array.
        """
        assert len(img.shape) <= 4
        interp_method = interp if interp is not None else self.interp

        if img.dtype == np.uint8:
            single_channel = len(img.shape) > 2 and img.shape[2] == 1
            if single_channel:
                pil_image = Image.fromarray(img[:, :, 0], mode="L")
            else:
                pil_image = Image.fromarray(img)
            pil_image = pil_image.resize((self.new_w, self.new_h), interp_method)
            ret = np.asarray(pil_image)
            if single_channel:
                ret = np.expand_dims(ret, -1)
        else:
            # PIL only supports uint8
            if any(x < 0 for x in img.strides):
                # torch.from_numpy cannot wrap arrays with negative strides.
                img = np.ascontiguousarray(img)
            img = torch.from_numpy(img)
            shape = list(img.shape)
            shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
            img = img.view(shape_4d).permute(2, 3, 0, 1)  # hw(c) -> nchw
            _PIL_RESIZE_TO_INTERPOLATE_MODE = {
                Image.NEAREST: "nearest",
                Image.BILINEAR: "bilinear",
                Image.BICUBIC: "bicubic",
            }
            mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[interp_method]
            align_corners = None if mode == "nearest" else False
            img = F.interpolate(
                img, (self.new_h, self.new_w), mode=mode, align_corners=align_corners
            )
            shape[:2] = (self.new_h, self.new_w)
            ret = img.permute(2, 3, 0, 1).view(shape).numpy()  # nchw -> hw(c)

        return ret

    def apply_coords(self, coords):
        """Scale (x, y) points by the resize ratio. Modifies ``coords`` in place."""
        coords[:, 0] = coords[:, 0] * (self.new_w * 1.0 / self.w)
        coords[:, 1] = coords[:, 1] * (self.new_h * 1.0 / self.h)
        return coords

    def apply_segmentation(self, segmentation):
        """Resize a label map using nearest-neighbour interpolation."""
        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
        return segmentation

    def inverse(self):
        """Return the transform that resizes back to the original size."""
        return ResizeTransform(self.new_h, self.new_w, self.h, self.w, self.interp)
164
+
165
+
166
class RotationTransform(Transform):
    """
    Rotates an image (and its coordinates) by a given number of degrees
    counter clockwise around a center, which defaults to the image center.
    """

    def __init__(self, h, w, angle, expand=True, center=None, interp=None):
        """
        Args:
            h, w (int): original image size
            angle (float): degrees for rotation
            expand (bool): choose if the image should be resized to fit the whole
                rotated image (default), or simply cropped
            center (tuple (width, height)): coordinates of the rotation center
                if left to None, the center will be fit to the center of each image
                center has no effect if expand=True because it only affects shifting
            interp: cv2 interpolation method, default cv2.INTER_LINEAR
        """
        super().__init__()
        # NOTE: every local defined here is stored as an attribute by
        # _set_attributes(locals()); keep the names stable.
        image_center = np.array((w / 2, h / 2))
        if center is None:
            center = image_center
        if interp is None:
            interp = cv2.INTER_LINEAR
        abs_cos, abs_sin = (abs(np.cos(np.deg2rad(angle))), abs(np.sin(np.deg2rad(angle))))
        if expand:
            # find the new width and height bounds
            bound_w, bound_h = np.rint(
                [h * abs_sin + w * abs_cos, h * abs_cos + w * abs_sin]
            ).astype(int)
        else:
            bound_w, bound_h = w, h

        self._set_attributes(locals())
        self.rm_coords = self.create_rotation_matrix()
        # Needed because of this problem https://github.com/opencv/opencv/issues/11784
        self.rm_image = self.create_rotation_matrix(offset=-0.5)

    def apply_image(self, img, interp=None):
        """
        img should be a numpy array, formatted as Height * Width * Nchannels
        """
        is_identity = self.angle % 360 == 0
        if len(img) == 0 or is_identity:
            return img
        assert img.shape[:2] == (self.h, self.w)
        flags = self.interp if interp is None else interp
        out_size = (self.bound_w, self.bound_h)
        return cv2.warpAffine(img, self.rm_image, out_size, flags=flags)

    def apply_coords(self, coords):
        """
        coords should be a N * 2 array-like, containing N couples of (x, y) points
        """
        points = np.asarray(coords, dtype=float)
        if len(points) == 0 or self.angle % 360 == 0:
            return points
        # cv2.transform expects an (N, 1, 2) array.
        rotated = cv2.transform(points[:, np.newaxis, :], self.rm_coords)
        return rotated[:, 0, :]

    def apply_segmentation(self, segmentation):
        """Rotate a label map using nearest-neighbour interpolation."""
        return self.apply_image(segmentation, interp=cv2.INTER_NEAREST)

    def create_rotation_matrix(self, offset=0):
        """Build the 2x3 affine matrix, with ``offset`` added to the center coordinates."""
        pivot = (self.center[0] + offset, self.center[1] + offset)
        rm = cv2.getRotationMatrix2D(pivot, self.angle, 1)
        if not self.expand:
            return rm

        # Find the coordinates of the center of rotation in the new image
        # The only point for which we know the future coordinates is the center of the image
        shifted_center = self.image_center[None, None, :] + offset
        rot_im_center = cv2.transform(shifted_center, rm)[0, 0, :]
        new_center = np.array([self.bound_w / 2, self.bound_h / 2]) + offset - rot_im_center
        # shift the rotation center to the new coordinates
        rm[:, 2] += new_center
        return rm

    def inverse(self):
        """
        The inverse is to rotate it back with expand, and crop to get the original shape.
        """
        if not self.expand:  # Not possible to inverse if a part of the image is lost
            raise NotImplementedError()
        rotation = RotationTransform(
            h=self.bound_h,
            w=self.bound_w,
            angle=-self.angle,
            expand=True,
            center=None,
            interp=self.interp,
        )
        left = (rotation.bound_w - self.w) // 2
        top = (rotation.bound_h - self.h) // 2
        crop = CropTransform(left, top, self.w, self.h)
        return TransformList([rotation, crop])
252
+
253
+
254
class ColorTransform(Transform):
    """
    Generic wrapper for any photometric transforms.
    Such a transform changes pixel values only; coordinates, boxes and
    segmentation maps pass through untouched.
    """

    def __init__(self, op):
        """
        Args:
            op (Callable): operation to be applied to the image,
                which takes in an ndarray and returns an ndarray.

        Raises:
            ValueError: if ``op`` is not callable.
        """
        if callable(op):
            super().__init__()
            self._set_attributes(locals())
        else:
            raise ValueError("op parameter should be callable")

    def apply_image(self, img):
        """Run the wrapped operation on the image."""
        return self.op(img)

    def apply_coords(self, coords):
        """Coordinates are unaffected by a color transform."""
        return coords

    def inverse(self):
        """Return a no-op transform as the inverse."""
        return NoOpTransform()

    def apply_segmentation(self, segmentation):
        """Segmentation maps are unaffected by a color transform."""
        return segmentation
284
+
285
+
286
class PILColorTransform(ColorTransform):
    """
    Generic wrapper for PIL Photometric image transforms,
    which affect the color space and not the coordinate
    space of the image
    """

    def __init__(self, op):
        """
        Args:
            op (Callable): operation to be applied to the image,
                which takes in a PIL Image and returns a transformed
                PIL Image.
                For reference on possible operations see:
                - https://pillow.readthedocs.io/en/stable/

        Raises:
            ValueError: if ``op`` is not callable.
        """
        if callable(op):
            super().__init__(op)
        else:
            raise ValueError("op parameter should be callable")

    def apply_image(self, img):
        """Convert the array to a PIL image, run the operation, and convert back."""
        pil_input = Image.fromarray(img)
        pil_output = super().apply_image(pil_input)
        return np.asarray(pil_output)
309
+
310
+
311
def HFlip_rotated_box(transform, rotated_boxes):
    """
    Apply the horizontal flip transform on rotated boxes.

    The array is modified in place and also returned.

    Args:
        transform: the flip transform; only its ``width`` attribute is read.
        rotated_boxes (ndarray): Nx5 floating point array of
            (x_center, y_center, width, height, angle_degrees) format
            in absolute coordinates.
    """
    x_center = rotated_boxes[:, 0]
    angle = rotated_boxes[:, 4]
    # Mirror the center about the vertical axis of the image.
    rotated_boxes[:, 0] = transform.width - x_center
    # A horizontal flip reverses the rotation direction.
    rotated_boxes[:, 4] = np.negative(angle)
    return rotated_boxes
325
+
326
+
327
def Resize_rotated_box(transform, rotated_boxes):
    """
    Apply the resizing transform on rotated boxes. For details of how these (approximation)
    formulas are derived, please refer to :meth:`RotatedBoxes.scale`.

    The array is modified in place and also returned.

    Args:
        transform: the resize transform; ``w``, ``h``, ``new_w`` and ``new_h`` are read.
        rotated_boxes (ndarray): Nx5 floating point array of
            (x_center, y_center, width, height, angle_degrees) format
            in absolute coordinates.
    """
    sx = transform.new_w * 1.0 / transform.w
    sy = transform.new_h * 1.0 / transform.h

    # Centers scale independently along each axis.
    rotated_boxes[:, 0] *= sx
    rotated_boxes[:, 1] *= sy

    # Angles are read before the angle column is overwritten below.
    angle_rad = rotated_boxes[:, 4] * np.pi / 180.0
    cos_a = np.cos(angle_rad)
    sin_a = np.sin(angle_rad)

    width_gain = np.sqrt(np.square(sx * cos_a) + np.square(sy * sin_a))
    height_gain = np.sqrt(np.square(sx * sin_a) + np.square(sy * cos_a))
    rotated_boxes[:, 2] *= width_gain
    rotated_boxes[:, 3] *= height_gain
    rotated_boxes[:, 4] = np.arctan2(sx * sin_a, sy * cos_a) * 180 / np.pi

    return rotated_boxes
349
+
350
+
351
+ HFlipTransform.register_type("rotated_box", HFlip_rotated_box)
352
+ ResizeTransform.register_type("rotated_box", Resize_rotated_box)
353
+
354
+ # not necessary any more with latest fvcore
355
+ NoOpTransform.register_type("rotated_box", lambda t, x: x)
cutler/demo/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ from demo import *
3
+ from predictor import *
4
+
5
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
cutler/demo/__pycache__/predictor.cpython-312.pyc ADDED
Binary file (11.8 kB). View file
 
cutler/demo/cutler_cascade_final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24fdc4544a348e5c27f6e334e4ff6557d8f2d50de94380a741bc91c2226b86bb
3
+ size 574672112
cutler/demo/demo.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/demo/demo.py
3
+
4
+ import argparse
5
+ import glob
6
+ import multiprocessing as mp
7
+ import numpy as np
8
+ import os
9
+ import tempfile
10
+ import time
11
+ import warnings
12
+ import cv2
13
+ import tqdm
14
+
15
+ from detectron2.config import get_cfg
16
+ from detectron2.data.detection_utils import read_image
17
+ from detectron2.utils.logger import setup_logger
18
+ import sys
19
+ sys.path.append('./')
20
+ sys.path.append('../')
21
+ from config import add_cutler_config
22
+
23
+ from predictor import VisualizationDemo
24
+
25
+ # constants
26
+ WINDOW_NAME = "CutLER detections"
27
+
28
+
29
def setup_cfg(args):
    """
    Build the frozen config for the demo.

    Starts from the default config extended with the CutLER options, then merges
    ``args.config_file`` and the ``args.opts`` overrides, and finally applies
    ``args.confidence_threshold`` to the score thresholds of the builtin models.
    """
    cfg = get_cfg()
    add_cutler_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # SyncBN is not supported on CPU and can cause errors, so fall back to plain BN.
    model = cfg.MODEL
    if model.DEVICE == 'cpu' and model.RESNETS.NORM == 'SyncBN':
        model.RESNETS.NORM = "BN"
        model.FPN.NORM = "BN"

    # Set score_threshold for builtin models
    threshold = args.confidence_threshold
    model.RETINANET.SCORE_THRESH_TEST = threshold
    model.ROI_HEADS.SCORE_THRESH_TEST = threshold
    model.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = threshold

    cfg.freeze()
    return cfg
46
+
47
+
48
def get_parser():
    """
    Create the command-line parser for the demo script.

    Options are declared in a table, in the order they are registered, and
    added to the parser in a single loop.
    """
    option_table = (
        (
            "--config-file",
            dict(
                default="model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml",
                metavar="FILE",
                help="path to config file",
            ),
        ),
        ("--webcam", dict(action="store_true", help="Take inputs from webcam.")),
        ("--video-input", dict(help="Path to video file.")),
        (
            "--input",
            dict(
                nargs="+",
                help="A list of space separated input images; "
                "or a single glob pattern such as 'directory/*.jpg'",
            ),
        ),
        (
            "--output",
            dict(
                help="A file or directory to save output visualizations. "
                "If not given, will show output in an OpenCV window.",
            ),
        ),
        (
            "--confidence-threshold",
            dict(
                type=float,
                default=0.35,
                help="Minimum score for instance predictions to be shown",
            ),
        ),
        (
            "--opts",
            dict(
                help="Modify config options using the command-line 'KEY VALUE' pairs",
                default=[],
                nargs=argparse.REMAINDER,
            ),
        ),
    )

    parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser
83
+
84
+
85
def test_opencv_video_format(codec, file_ext):
    """
    Check whether this OpenCV build can write a video with the given codec.

    Writes 30 black 10x10 frames to a file in a temporary directory and
    reports whether the file exists afterwards.

    Args:
        codec (str): FOURCC code, unpacked character by character into
            ``cv2.VideoWriter_fourcc``.
        file_ext (str): file extension including the dot, e.g. ".mkv".

    Returns:
        bool: True if the output file was created, False otherwise.
    """
    with tempfile.TemporaryDirectory(prefix="video_format_test") as tmp_dir:
        filename = os.path.join(tmp_dir, "test_file" + file_ext)
        writer = cv2.VideoWriter(
            filename=filename,
            fourcc=cv2.VideoWriter_fourcc(*codec),
            fps=float(30),
            frameSize=(10, 10),
            isColor=True,
        )
        try:
            for _ in range(30):
                writer.write(np.zeros((10, 10, 3), np.uint8))
        finally:
            # Always release so the file handle is closed before the directory is removed.
            writer.release()
        return os.path.isfile(filename)
100
+
101
+
102
if __name__ == "__main__":
    # Script entry point: parse the command line, build the visualization
    # demo, then run it on whichever source was requested (args.input,
    # args.webcam or args.video_input, checked in that order).
    mp.set_start_method("spawn", force=True)
    args = get_parser().parse_args()
    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup_cfg(args)

    demo = VisualizationDemo(cfg)

    if args.input:
        # A single argument is treated as a glob pattern and expanded here.
        if len(args.input) == 1:
            args.input = glob.glob(os.path.expanduser(args.input[0]))
            assert args.input, "The input path(s) was not found"
        for path in tqdm.tqdm(args.input, disable=not args.output):
            # use PIL, to be consistent with evaluation
            img = read_image(path, format="BGR")
            start_time = time.time()
            predictions, visualized_output = demo.run_on_image(img)
            logger.info(
                "{}: {} in {:.2f}s".format(
                    path,
                    "detected {} instances".format(len(predictions["instances"]))
                    if "instances" in predictions
                    else "finished",
                    time.time() - start_time,
                )
            )

            if args.output:
                if os.path.isdir(args.output):
                    # Directory output: keep the input file's base name.
                    out_filename = os.path.join(args.output, os.path.basename(path))
                else:
                    # File output only makes sense for exactly one input.
                    assert len(args.input) == 1, "Please specify a directory with args.output"
                    out_filename = args.output
                visualized_output.save(out_filename)
            else:
                cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
                # get_image() is reversed along the channel axis before being
                # handed to cv2.imshow.
                cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
                if cv2.waitKey(0) == 27:
                    break  # esc to quit
    elif args.webcam:
        assert args.input is None, "Cannot have both --input and --webcam!"
        assert args.output is None, "output not yet supported with --webcam!"
        cam = cv2.VideoCapture(0)
        for vis in tqdm.tqdm(demo.run_on_video(cam)):
            cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
            cv2.imshow(WINDOW_NAME, vis)
            if cv2.waitKey(1) == 27:
                break  # esc to quit
        cam.release()
        cv2.destroyAllWindows()
    elif args.video_input:
        video = cv2.VideoCapture(args.video_input)
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames_per_second = video.get(cv2.CAP_PROP_FPS)
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        basename = os.path.basename(args.video_input)
        # Prefer x264/.mkv when this OpenCV build can write it; otherwise
        # fall back to mp4v/.mp4.
        codec, file_ext = (
            ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4")
        )
        # The fallback codec is "mp4v" (no leading dot); comparing against
        # ".mp4v" meant this warning could never be emitted.
        if codec == "mp4v":
            warnings.warn("x264 codec not available, switching to mp4v")
        if args.output:
            if os.path.isdir(args.output):
                output_fname = os.path.join(args.output, basename)
                output_fname = os.path.splitext(output_fname)[0] + file_ext
            else:
                output_fname = args.output
            # Refuse to overwrite an existing file.
            assert not os.path.isfile(output_fname), output_fname
            output_file = cv2.VideoWriter(
                filename=output_fname,
                # some installation of opencv may not support x264 (due to its license),
                # you can try other format (e.g. MPEG)
                fourcc=cv2.VideoWriter_fourcc(*codec),
                fps=float(frames_per_second),
                frameSize=(width, height),
                isColor=True,
            )
        assert os.path.isfile(args.video_input)
        for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
            if args.output:
                output_file.write(vis_frame)
            else:
                cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
                cv2.imshow(basename, vis_frame)
                if cv2.waitKey(1) == 27:
                    break  # esc to quit
        video.release()
        if args.output:
            output_file.release()
        else:
            cv2.destroyAllWindows()
cutler/demo/imgs/demo1.jpg ADDED

Git LFS Details

  • SHA256: 6036356bf683ac920fd9eb048864c26b5fdbd546299228c2a4d638c7e649c59a
  • Pointer size: 131 Bytes
  • Size of remote file: 539 kB
cutler/demo/imgs/demo2.jpg ADDED

Git LFS Details

  • SHA256: cc29e55f4d2de0a1e11440b4e406bb8b405c71a0276ea6844f109729c470e287
  • Pointer size: 131 Bytes
  • Size of remote file: 382 kB
cutler/demo/imgs/demo3.jpg ADDED

Git LFS Details

  • SHA256: 93d53ca4db286b2fe51847ef91793910f869c210e20169447d0bce14320fb4a0
  • Pointer size: 130 Bytes
  • Size of remote file: 72.8 kB
cutler/demo/imgs/demo4.jpg ADDED

Git LFS Details

  • SHA256: 61e98ea9aa327b108f6151958f773c6c20804a139173a7346e0a262837e34f8f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.13 MB
cutler/demo/imgs/demo5.jpg ADDED

Git LFS Details

  • SHA256: 1889fb1ea44b9f1b6041d861ff4d738c21306715e47b51126c8768d0a24c66bf
  • Pointer size: 131 Bytes
  • Size of remote file: 327 kB
cutler/demo/imgs/demo6.jpg ADDED

Git LFS Details

  • SHA256: 0105d25030c21e7b1f6337ee21b197c97db6acbd9aceebc944bb5913330571ac
  • Pointer size: 131 Bytes
  • Size of remote file: 262 kB
cutler/demo/imgs/demo7.jpg ADDED

Git LFS Details

  • SHA256: 9b456257e519cfa2653a2f443a7467a3bc18d5a93a96e307f89a71af150c2cf2
  • Pointer size: 131 Bytes
  • Size of remote file: 184 kB
cutler/demo/imgs/demo8.jpg ADDED

Git LFS Details

  • SHA256: 54aad1e602f8c0a7e412e10a197c6d42540d6cef763eed13388b9f6b7958fda4
  • Pointer size: 130 Bytes
  • Size of remote file: 47.3 kB
cutler/demo/predictor.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import atexit
3
+ import bisect
4
+ import multiprocessing as mp
5
+ from collections import deque
6
+ import cv2
7
+ import torch
8
+
9
+ from detectron2.data import MetadataCatalog
10
+ import sys
11
+ sys.path.append('./')
12
+ from engine.defaults import DefaultPredictor
13
+ from detectron2.utils.video_visualizer import VideoVisualizer
14
+ from detectron2.utils.visualizer import ColorMode, Visualizer
15
+
16
+
17
class VisualizationDemo(object):
    """Runs a predictor on images or video frames and draws its predictions."""

    def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
        """
        Args:
            cfg (CfgNode):
            instance_mode (ColorMode):
            parallel (bool): whether to run the model in different processes from visualization.
                Useful since the visualization logic can be slow.
        """
        # Fall back to a placeholder dataset name when no test set is configured.
        dataset_name = cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
        self.metadata = MetadataCatalog.get(dataset_name)
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode

        self.parallel = parallel
        if not parallel:
            self.predictor = DefaultPredictor(cfg)
        else:
            # One asynchronous worker per visible GPU.
            self.predictor = AsyncPredictor(cfg, num_gpus=torch.cuda.device_count())

    def run_on_image(self, image):
        """
        Args:
            image (np.ndarray): an image of shape (H, W, C) (in BGR order).
                This is the format used by OpenCV.
        Returns:
            predictions (dict): the output of the model.
            vis_output (VisImage): the visualized image output; None when the
                predictions contain none of the recognised keys.
        """
        predictions = self.predictor(image)
        # Convert image from OpenCV BGR format to Matplotlib RGB format.
        rgb_image = image[:, :, ::-1]
        visualizer = Visualizer(rgb_image, self.metadata, instance_mode=self.instance_mode)

        if "panoptic_seg" in predictions:
            panoptic_seg, segments_info = predictions["panoptic_seg"]
            drawn = visualizer.draw_panoptic_seg_predictions(
                panoptic_seg.to(self.cpu_device), segments_info
            )
            return predictions, drawn

        # Without panoptic output, semantic and instance results are both
        # drawn when present; the instance drawing is the one returned.
        vis_output = None
        if "sem_seg" in predictions:
            sem_seg = predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
            vis_output = visualizer.draw_sem_seg(sem_seg)
        if "instances" in predictions:
            instances = predictions["instances"].to(self.cpu_device)
            vis_output = visualizer.draw_instance_predictions(predictions=instances)
        return predictions, vis_output

    def _frame_from_video(self, video):
        """Yield frames from ``video`` until it closes or a read fails."""
        while video.isOpened():
            ok, frame = video.read()
            if not ok:
                return
            yield frame

    def run_on_video(self, video):
        """
        Visualizes predictions on frames of the input video.
        Args:
            video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
                either a webcam or a video file.
        Yields:
            ndarray: BGR visualizations of each video frame.
        """
        video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)

        def process_predictions(frame, predictions):
            # The visualizer is given RGB frames; OpenCV delivers BGR.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if "panoptic_seg" in predictions:
                panoptic_seg, segments_info = predictions["panoptic_seg"]
                vis_frame = video_visualizer.draw_panoptic_seg_predictions(
                    rgb_frame, panoptic_seg.to(self.cpu_device), segments_info
                )
            elif "instances" in predictions:
                instances = predictions["instances"].to(self.cpu_device)
                vis_frame = video_visualizer.draw_instance_predictions(rgb_frame, instances)
            elif "sem_seg" in predictions:
                sem_seg = predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
                vis_frame = video_visualizer.draw_sem_seg(rgb_frame, sem_seg)

            # Converts Matplotlib RGB format to OpenCV BGR format
            return cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)

        frame_gen = self._frame_from_video(video)
        if not self.parallel:
            for frame in frame_gen:
                yield process_predictions(frame, self.predictor(frame))
            return

        # Parallel mode: keep up to `buffer_size` frames in flight, pairing each
        # result with the oldest pending frame.
        buffer_size = self.predictor.default_buffer_size
        pending = deque()

        for count, frame in enumerate(frame_gen):
            pending.append(frame)
            self.predictor.put(frame)

            if count >= buffer_size:
                oldest = pending.popleft()
                yield process_predictions(oldest, self.predictor.get())

        # Drain whatever is still queued once the video is exhausted.
        while pending:
            oldest = pending.popleft()
            yield process_predictions(oldest, self.predictor.get())
129
+
130
+
131
class AsyncPredictor:
    """
    A predictor that runs the model asynchronously, possibly on >1 GPUs.
    Because rendering the visualization takes a considerable amount of time,
    this helps improve throughput a little bit when rendering videos.
    """

    class _StopToken:
        """Marker object put on the task queue to make a worker exit its loop."""

    class _PredictWorker(mp.Process):
        """Process that consumes (idx, image) tasks and emits (idx, result) pairs."""

        def __init__(self, cfg, task_queue, result_queue):
            self.cfg = cfg
            self.task_queue = task_queue
            self.result_queue = result_queue
            super().__init__()

        def run(self):
            predictor = DefaultPredictor(self.cfg)

            task = self.task_queue.get()
            while not isinstance(task, AsyncPredictor._StopToken):
                idx, data = task
                self.result_queue.put((idx, predictor(data)))
                task = self.task_queue.get()

    def __init__(self, cfg, num_gpus: int = 1):
        """
        Args:
            cfg (CfgNode):
            num_gpus (int): if 0, will run on CPU
        """
        num_workers = max(num_gpus, 1)
        queue_size = num_workers * 3
        self.task_queue = mp.Queue(maxsize=queue_size)
        self.result_queue = mp.Queue(maxsize=queue_size)
        self.procs = []
        for gpuid in range(num_workers):
            # Each worker gets its own config copy pinned to one device.
            worker_cfg = cfg.clone()
            worker_cfg.defrost()
            worker_cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
            worker = AsyncPredictor._PredictWorker(
                worker_cfg, self.task_queue, self.result_queue
            )
            self.procs.append(worker)

        # Counters of submitted / retrieved tasks, plus two parallel lists
        # (sorted by index) holding results that arrived ahead of their turn.
        self.put_idx = 0
        self.get_idx = 0
        self.result_rank = []
        self.result_data = []

        for proc in self.procs:
            proc.start()
        atexit.register(self.shutdown)

    def put(self, image):
        """Queue ``image`` for prediction under the next task index."""
        self.put_idx += 1
        self.task_queue.put((self.put_idx, image))

    def get(self):
        """Return the result of the oldest task not yet retrieved."""
        self.get_idx += 1  # the index needed for this request
        wanted = self.get_idx
        if self.result_rank and self.result_rank[0] == wanted:
            self.result_rank.pop(0)
            return self.result_data.pop(0)

        while True:
            # make sure the results are returned in the correct order
            idx, res = self.result_queue.get()
            if idx == wanted:
                return res
            # Out-of-order result: stash it, keeping the index list sorted.
            position = bisect.bisect(self.result_rank, idx)
            self.result_rank.insert(position, idx)
            self.result_data.insert(position, res)

    def __len__(self):
        """Number of tasks submitted but not yet retrieved."""
        return self.put_idx - self.get_idx

    def __call__(self, image):
        """Submit ``image`` and return its result."""
        self.put(image)
        return self.get()

    def shutdown(self):
        """Put one stop token per worker on the task queue."""
        for _ in self.procs:
            self.task_queue.put(AsyncPredictor._StopToken())

    @property
    def default_buffer_size(self):
        """Five times the number of worker processes."""
        return len(self.procs) * 5
cutler/demo/wget-log ADDED
The diff for this file is too large to render. See raw diff
 
cutler/engine/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from .train_loop import *
4
+
5
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
6
+
7
+ from .defaults import *
cutler/engine/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (481 Bytes). View file
 
cutler/engine/__pycache__/defaults.cpython-312.pyc ADDED
Binary file (34.6 kB). View file
 
cutler/engine/__pycache__/train_loop.cpython-312.pyc ADDED
Binary file (20.3 kB). View file
 
cutler/engine/defaults.py ADDED
@@ -0,0 +1,726 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/defaults.py
4
+
5
+ """
6
+ This file contains components with some default boilerplate logic user may need
7
+ in training / testing. They will not work for everyone, but many users may find them useful.
8
+
9
+ The behavior of functions/classes in this file is subject to change,
10
+ since they are meant to represent the "common default behavior" people need in their projects.
11
+ """
12
+
13
+ import argparse
14
+ import logging
15
+ import os
16
+ import sys
17
+ import weakref
18
+ from collections import OrderedDict
19
+ from typing import Optional
20
+ import torch
21
+ from fvcore.nn.precise_bn import get_bn_modules
22
+ from omegaconf import OmegaConf
23
+ from torch.nn.parallel import DistributedDataParallel
24
+
25
+ import data.transforms as T
26
+ from detectron2.checkpoint import DetectionCheckpointer
27
+ from detectron2.config import CfgNode, LazyConfig
28
+ from detectron2.data import (
29
+ MetadataCatalog,
30
+ )
31
+ from data import (
32
+ build_detection_test_loader,
33
+ build_detection_train_loader,
34
+ )
35
+ from detectron2.evaluation import (
36
+ DatasetEvaluator,
37
+ inference_on_dataset,
38
+ print_csv_format,
39
+ verify_results,
40
+ )
41
+ from modeling import build_model
42
+ from solver import build_lr_scheduler, build_optimizer
43
+ from detectron2.utils import comm
44
+ from detectron2.utils.collect_env import collect_env_info
45
+ from detectron2.utils.env import seed_all_rng
46
+ from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
47
+ from detectron2.utils.file_io import PathManager
48
+ from detectron2.utils.logger import setup_logger
49
+
50
+ from detectron2.engine import hooks
51
+ from detectron2.engine import TrainerBase
52
+ from .train_loop import CustomAMPTrainer, CustomSimpleTrainer
53
+
54
+ __all__ = [
55
+ "create_ddp_model",
56
+ "default_argument_parser",
57
+ "default_setup",
58
+ "default_writers",
59
+ "DefaultPredictor",
60
+ "DefaultTrainer",
61
+ ]
62
+
63
+
64
def create_ddp_model(model, *, fp16_compression=False, **kwargs):
    """
    Wrap ``model`` in DistributedDataParallel when more than one process is running.

    Args:
        model: a torch.nn.Module
        fp16_compression: add fp16 compression hooks to the ddp object.
            See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
        kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.

    Returns:
        ``model`` itself when the world size is 1, otherwise the
        DistributedDataParallel wrapper around it.
    """  # noqa
    single_process = comm.get_world_size() == 1
    if single_process:
        return model

    # Default to the local rank's device unless the caller chose one.
    if "device_ids" in kwargs:
        ddp_kwargs = kwargs
    else:
        ddp_kwargs = {**kwargs, "device_ids": [comm.get_local_rank()]}
    wrapped = DistributedDataParallel(model, **ddp_kwargs)

    if not fp16_compression:
        return wrapped

    from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks

    wrapped.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    return wrapped
84
+
85
+
86
def default_argument_parser(epilog=None):
    """
    Build an argument parser with the options commonly used by detectron2 users.

    Args:
        epilog (str): epilog passed to ArgumentParser describing the usage.
            When empty or None, a set of usage examples is used instead.

    Returns:
        argparse.ArgumentParser:
    """
    if not epilog:
        epilog = f"""
Examples:

Run on single machine:
$ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml

Change some config options:
$ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001

Run on multiple machines:
(machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
(machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
"""
    parser = argparse.ArgumentParser(
        epilog=epilog,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Whether to attempt to resume from the checkpoint directory. "
        "See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
    )
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument(
        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    )
    parser.add_argument(
        "--test-dataset", type=str, default="", help="the dataset used for evaluation"
    )
    parser.add_argument(
        "--train-dataset", type=str, default="", help="the dataset used for training"
    )
    parser.add_argument("--no-segm", action="store_true", help="perform evaluation on detection only")

    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    uid = os.getuid() if sys.platform != "win32" else 1
    port = 2**15 + 2**14 + hash(uid) % 2**14
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:{}".format(port),
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )

    # Everything left over on the command line is collected as config overrides.
    parser.add_argument(
        "opts",
        help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
        """.strip(),
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
154
+
155
+
156
def _try_get_key(cfg, *keys, default=None):
    """
    Return the value of the first key in ``keys`` that exists in ``cfg``.

    A ``CfgNode`` is first converted to an OmegaConf object via its dump so
    that both config flavours can be queried with ``OmegaConf.select``.
    If none of the keys exists, ``default`` is returned.
    """
    if isinstance(cfg, CfgNode):
        cfg = OmegaConf.create(cfg.dump())

    # Unique sentinel, so that any stored value is told apart from "key absent".
    missing = object()
    for key in keys:
        found = OmegaConf.select(cfg, key, default=missing)
        if found is not missing:
            return found
    return default
168
+
169
+
170
+ def _highlight(code, filename):
171
+ try:
172
+ import pygments
173
+ except ImportError:
174
+ return code
175
+
176
+ from pygments.lexers import Python3Lexer, YamlLexer
177
+ from pygments.formatters import Terminal256Formatter
178
+
179
+ lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
180
+ code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai"))
181
+ return code
182
+
183
+
184
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the detectron2 logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory

    It also seeds the random number generators (offset by the process rank
    when a non-negative seed is configured) and, unless ``args.eval_only`` is
    set, applies the configured cudnn benchmark flag.

    Args:
        cfg (CfgNode or omegaconf.DictConfig): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        # Read through a context manager so the file handle is closed right
        # away instead of lingering until garbage collection.
        with PathManager.open(args.config_file, "r") as f:
            config_text = f.read()
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                args.config_file,
                _highlight(config_text, args.config_file),
            )
        )

    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        if isinstance(cfg, CfgNode):
            logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
            with PathManager.open(path, "w") as f:
                f.write(cfg.dump())
        else:
            LazyConfig.save(cfg, path)
        logger.info("Full config saved to {}".format(path))

    # make sure each worker has a different, yet deterministic seed if specified
    seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
    seed_all_rng(None if seed < 0 else seed + rank)

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = _try_get_key(
            cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
        )
238
+
239
+
240
def default_writers(output_dir: str, max_iter: Optional[int] = None):
    """
    Create the default list of :class:`EventWriter` objects:
    a :class:`CommonMetricPrinter`, a :class:`JSONWriter` and a
    :class:`TensorboardXWriter`, in that order.

    Args:
        output_dir: directory to store JSON metrics and tensorboard events
        max_iter: the total number of iterations

    Returns:
        list[EventWriter]: a list of :class:`EventWriter` objects.
    """
    PathManager.mkdirs(output_dir)
    metrics_path = os.path.join(output_dir, "metrics.json")

    # It may not always print what you want to see, since it prints "common" metrics only.
    printer = CommonMetricPrinter(max_iter)
    json_writer = JSONWriter(metrics_path)
    tensorboard_writer = TensorboardXWriter(output_dir)
    return [printer, json_writer, tensorboard_writer]
260
+
261
+
262
class DefaultPredictor:
    """
    A simple end-to-end predictor built from a config, which runs on a single
    device for a single input image.

    Compared to using the model directly, this class does the following additions:

    1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
    2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
    3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
    4. Take one input image and produce a single output, instead of a batch.

    This is meant for simple demo purposes, so it does the above steps automatically.
    This is not meant for benchmarks or running complicated inference logic.
    If you'd like to do anything more complicated, please refer to its source code as
    examples to build and use the model manually.

    Attributes:
        metadata (Metadata): the metadata of the underlying dataset, obtained from
            cfg.DATASETS.TEST. Only set when cfg.DATASETS.TEST is non-empty.

    Examples:
    ::
        pred = DefaultPredictor(cfg)
        inputs = cv2.imread("input.jpg")
        outputs = pred(inputs)
    """

    def __init__(self, cfg):
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.model = build_model(self.cfg)
        self.model.eval()
        if len(cfg.DATASETS.TEST):
            self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])

        DetectionCheckpointer(self.model).load(cfg.MODEL.WEIGHTS)

        # The same value is used for both ends of the shortest-edge range.
        min_size = cfg.INPUT.MIN_SIZE_TEST
        self.aug = T.ResizeShortestEdge([min_size, min_size], cfg.INPUT.MAX_SIZE_TEST)

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image):
        """
        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).

        Returns:
            predictions (dict):
                the output of the model for one image only.
                See :doc:`/tutorials/models` for details about the format.
        """
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]
            # Height and width are taken before resizing and passed alongside the image.
            height, width = original_image.shape[:2]
            transform = self.aug.get_transform(original_image)
            resized = transform.apply_image(original_image)
            # (H, W, C) -> (C, H, W) float32 tensor.
            chw = resized.astype("float32").transpose(2, 0, 1)
            image = torch.as_tensor(chw)

            inputs = {"image": image, "height": height, "width": width}
            return self.model([inputs])[0]
329
+
330
+
331
+ class DefaultTrainer(TrainerBase):
332
+ """
333
+ A trainer with default training logic. It does the following:
334
+
335
+ 1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
336
+ defined by the given config. Create a LR scheduler defined by the config.
337
+ 2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
338
+ `resume_or_load` is called.
339
+ 3. Register a few common hooks defined by the config.
340
+
341
+ It is created to simplify the **standard model training workflow** and reduce code boilerplate
342
+ for users who only need the standard training workflow, with standard features.
343
+ It means this class makes *many assumptions* about your training logic that
344
+ may easily become invalid in a new research. In fact, any assumptions beyond those made in the
345
+ :class:`SimpleTrainer` are too much for research.
346
+
347
+ The code of this class has been annotated about restrictive assumptions it makes.
348
+ When they do not work for you, you're encouraged to:
349
+
350
+ 1. Overwrite methods of this class, OR:
351
+ 2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
352
+ nothing else. You can then add your own hooks if needed. OR:
353
+ 3. Write your own training loop similar to `tools/plain_train_net.py`.
354
+
355
+ See the :doc:`/tutorials/training` tutorials for more details.
356
+
357
+ Note that the behavior of this class, like other functions/classes in
358
+ this file, is not stable, since it is meant to represent the "common default behavior".
359
+ It is only guaranteed to work well with the standard models and training workflow in detectron2.
360
+ To obtain more stable behavior, write your own training logic with other public APIs.
361
+
362
+ Examples:
363
+ ::
364
+ trainer = DefaultTrainer(cfg)
365
+ trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS
366
+ trainer.train()
367
+
368
+ Attributes:
369
+ scheduler:
370
+ checkpointer (DetectionCheckpointer):
371
+ cfg (CfgNode):
372
+ """
373
+
374
+ def __init__(self, cfg):
375
+ """
376
+ Args:
377
+ cfg (CfgNode):
378
+ """
379
+ super().__init__()
380
+ logger = logging.getLogger("detectron2")
381
+ if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
382
+ setup_logger()
383
+ cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
384
+
385
+ # Assume these objects must be constructed in this order.
386
+ model = self.build_model(cfg)
387
+ optimizer = self.build_optimizer(cfg, model)
388
+ data_loader = self.build_train_loader(cfg)
389
+
390
+ model = create_ddp_model(model, broadcast_buffers=False)
391
+ if cfg.SOLVER.AMP.ENABLED:
392
+ self._trainer = CustomAMPTrainer(model, data_loader, optimizer, cfg=cfg)
393
+ else:
394
+ self._trainer = CustomSimpleTrainer(model, data_loader, optimizer, cfg=cfg)
395
+
396
+ self.scheduler = self.build_lr_scheduler(cfg, optimizer)
397
+ self.checkpointer = DetectionCheckpointer(
398
+ # Assume you want to save checkpoints together with logs/statistics
399
+ model,
400
+ cfg.OUTPUT_DIR,
401
+ trainer=weakref.proxy(self),
402
+ )
403
+ self.start_iter = 0
404
+ self.max_iter = cfg.SOLVER.MAX_ITER
405
+ self.cfg = cfg
406
+
407
+ self.register_hooks(self.build_hooks())
408
+
409
+ def resume_or_load(self, resume=True):
410
+ """
411
+ If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
412
+ a `last_checkpoint` file), resume from the file. Resuming means loading all
413
+ available states (eg. optimizer and scheduler) and update iteration counter
414
+ from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
415
+
416
+ Otherwise, this is considered as an independent training. The method will load model
417
+ weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
418
+ from iteration 0.
419
+
420
+ Args:
421
+ resume (bool): whether to do resume or not
422
+ """
423
+ self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
424
+ if resume and self.checkpointer.has_checkpoint():
425
+ # The checkpoint stores the training iteration that just finished, thus we start
426
+ # at the next iteration
427
+ self.start_iter = self.iter + 1
428
+
429
def build_hooks(self):
    """
    Assemble the default hooks: iteration timing, LR scheduling, precise BN,
    periodic checkpointing, evaluation and event writing.

    Returns:
        list[HookBase]: the hooks in the order they were appended. The
        precise-BN slot holds ``None`` when that hook is not built.
    """
    hook_cfg = self.cfg.clone()
    hook_cfg.defrost()
    # PreciseBN uses its own loader; zero workers saves memory and time.
    hook_cfg.DATALOADER.NUM_WORKERS = 0

    hook_list = [hooks.IterationTimer(), hooks.LRScheduler()]

    if hook_cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model):
        hook_list.append(
            hooks.PreciseBN(
                # Same frequency as evaluation, but scheduled before it.
                hook_cfg.TEST.EVAL_PERIOD,
                self.model,
                # A separate loader so the training loader is left untouched.
                self.build_train_loader(hook_cfg),
                hook_cfg.TEST.PRECISE_BN.NUM_ITER,
            )
        )
    else:
        hook_list.append(None)

    # PreciseBN comes before the checkpointer because it updates the model,
    # and those updates need to be saved. This is not always ideal: with a
    # different checkpoint frequency, some checkpoints may carry more precise
    # statistics than others.
    if comm.is_main_process():
        checkpoint_hook = hooks.PeriodicCheckpointer(
            self.checkpointer, hook_cfg.SOLVER.CHECKPOINT_PERIOD
        )
        hook_list.append(checkpoint_hook)

    def _evaluate_and_remember():
        outcome = self.test(self.cfg, self.model)
        self._last_eval_results = outcome
        return outcome

    # Evaluation comes after checkpointing, so that if it fails the saved
    # checkpoint can be used to debug.
    hook_list.append(hooks.EvalHook(hook_cfg.TEST.EVAL_PERIOD, _evaluate_and_remember))

    if comm.is_main_process():
        # Writers run last so that evaluation metrics get written.
        # Each writer's default print/log frequency is used.
        hook_list.append(hooks.PeriodicWriter(self.build_writers(), period=20))
    return hook_list
476
+
477
def build_writers(self):
    """
    Create the event writers for this trainer through :func:`default_writers()`.
    Override this in your trainer if you want a different set of writers.

    Returns:
        list[EventWriter]: a list of :class:`EventWriter` objects.
    """
    output_dir = self.cfg.OUTPUT_DIR
    total_iters = self.max_iter
    return default_writers(output_dir, total_iters)
487
+
488
def train(self):
    """
    Run training from ``self.start_iter`` up to ``self.max_iter``.

    Returns:
        The last evaluation results when ``cfg.TEST.EXPECTED_RESULTS`` is
        non-empty and this is the main process (after verifying them).
        Otherwise None.
    """
    super().train(self.start_iter, self.max_iter)

    must_verify = len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process()
    if not must_verify:
        return None

    has_results = hasattr(self, "_last_eval_results")
    assert has_results, "No evaluation results obtained during training!"
    final_results = self._last_eval_results
    verify_results(self.cfg, final_results)
    return final_results
502
+
503
def run_step(self):
    """
    Copy this trainer's iteration counter onto the wrapped trainer, then let
    the wrapped trainer execute one step.
    """
    inner = self._trainer
    inner.iter = self.iter
    inner.run_step()
506
+
507
def state_dict(self):
    """
    Return the parent class's state with the wrapped trainer's state stored
    under the ``"_trainer"`` key.
    """
    state = super().state_dict()
    inner_state = self._trainer.state_dict()
    state["_trainer"] = inner_state
    return state
511
+
512
def load_state_dict(self, state_dict):
    """
    Restore the parent class's state, then hand the ``"_trainer"`` entry to
    the wrapped trainer.

    Args:
        state_dict (dict): state to load; must contain a ``"_trainer"`` key.
    """
    super().load_state_dict(state_dict)
    inner = self._trainer
    inner.load_state_dict(state_dict["_trainer"])
515
+
516
@classmethod
def build_model(cls, cfg):
    """
    Build the model through the module-level ``build_model`` and log its
    structure. Override it if you'd like a different model.

    Returns:
        torch.nn.Module:
    """
    net = build_model(cfg)
    logging.getLogger(__name__).info("Model:\n{}".format(net))
    return net
529
+
530
@classmethod
def build_optimizer(cls, cfg, model):
    """
    Build the optimizer through the module-level ``build_optimizer``.
    Override it if you'd like a different optimizer.

    Returns:
        torch.optim.Optimizer:
    """
    optimizer = build_optimizer(cfg, model)
    return optimizer
540
+
541
@classmethod
def build_lr_scheduler(cls, cfg, optimizer):
    """
    Build the LR scheduler through the module-level ``build_lr_scheduler``.
    Override it if you'd like a different scheduler.
    """
    scheduler = build_lr_scheduler(cfg, optimizer)
    return scheduler
548
+
549
@classmethod
def build_train_loader(cls, cfg):
    """
    Build the training data loader through ``build_detection_train_loader``.
    Override it if you'd like a different data loader.

    Returns:
        iterable
    """
    train_loader = build_detection_train_loader(cfg)
    return train_loader
559
+
560
@classmethod
def build_test_loader(cls, cfg, dataset_name):
    """
    Build the test data loader for ``dataset_name`` through
    ``build_detection_test_loader``.
    Override it if you'd like a different data loader.

    Returns:
        iterable
    """
    test_loader = build_detection_test_loader(cfg, dataset_name)
    return test_loader
570
+
571
@classmethod
def build_evaluator(cls, cfg, dataset_name):
    """
    Extension point for subclasses: create the evaluator for ``dataset_name``.

    Returns:
        DatasetEvaluator or None

    Raises:
        NotImplementedError: always; there is no default implementation.
    """
    message = """
If you want DefaultTrainer to automatically run evaluation,
please implement `build_evaluator()` in subclasses (see train_net.py for example).
Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example).
"""
    raise NotImplementedError(message)
586
+
587
@classmethod
def test(cls, cfg, model, evaluators=None):
    """
    Evaluate the given model on every dataset in ``cfg.DATASETS.TEST``.
    The given model is expected to already contain weights to evaluate.

    Args:
        cfg (CfgNode):
        model (nn.Module):
        evaluators (list[DatasetEvaluator] or None): if None, will call
            :meth:`build_evaluator`. Otherwise, must have the same length as
            ``cfg.DATASETS.TEST``. A single ``DatasetEvaluator`` is wrapped
            into a one-element list.

    Returns:
        dict: a dict of result metrics. With exactly one entry in the
        results, that entry's metrics are returned directly; otherwise the
        results are keyed by dataset name. A dataset for which no evaluator
        could be built maps to an empty dict.
    """
    logger = logging.getLogger(__name__)
    if isinstance(evaluators, DatasetEvaluator):
        evaluators = [evaluators]
    if evaluators is not None:
        assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
            len(cfg.DATASETS.TEST), len(evaluators)
        )

    results = OrderedDict()
    for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
        data_loader = cls.build_test_loader(cfg, dataset_name)
        # When evaluators are passed in as arguments,
        # implicitly assume that evaluators can be created before data_loader.
        if evaluators is not None:
            evaluator = evaluators[idx]
        else:
            try:
                evaluator = cls.build_evaluator(cfg, dataset_name)
            except NotImplementedError:
                # `Logger.warn` is a deprecated alias; `warning` is the
                # supported spelling.
                logger.warning(
                    "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                    "or implement its `build_evaluator` method."
                )
                results[dataset_name] = {}
                continue
        results_i = inference_on_dataset(model, data_loader, evaluator)
        results[dataset_name] = results_i
        if comm.is_main_process():
            assert isinstance(
                results_i, dict
            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                results_i
            )
            logger.info("Evaluation results for {} in csv format:".format(dataset_name))
            print_csv_format(results_i)

    # Unwrap the single-dataset case so callers get the metrics dict directly.
    if len(results) == 1:
        results = next(iter(results.values()))
    return results
642
+
643
+ @staticmethod
644
+ def auto_scale_workers(cfg, num_workers: int):
645
+ """
646
+ When the config is defined for certain number of workers (according to
647
+ ``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
648
+ workers currently in use, returns a new cfg where the total batch size
649
+ is scaled so that the per-GPU batch size stays the same as the
650
+ original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.
651
+
652
+ Other config options are also scaled accordingly:
653
+ * training steps and warmup steps are scaled inverse proportionally.
654
+ * learning rate are scaled proportionally, following :paper:`ImageNet in 1h`.
655
+
656
+ For example, with the original config like the following:
657
+
658
+ .. code-block:: yaml
659
+
660
+ IMS_PER_BATCH: 16
661
+ BASE_LR: 0.1
662
+ REFERENCE_WORLD_SIZE: 8
663
+ MAX_ITER: 5000
664
+ STEPS: (4000,)
665
+ CHECKPOINT_PERIOD: 1000
666
+
667
+ When this config is used on 16 GPUs instead of the reference number 8,
668
+ calling this method will return a new config with:
669
+
670
+ .. code-block:: yaml
671
+
672
+ IMS_PER_BATCH: 32
673
+ BASE_LR: 0.2
674
+ REFERENCE_WORLD_SIZE: 16
675
+ MAX_ITER: 2500
676
+ STEPS: (2000,)
677
+ CHECKPOINT_PERIOD: 500
678
+
679
+ Note that both the original config and this new config can be trained on 16 GPUs.
680
+ It's up to user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``).
681
+
682
+ Returns:
683
+ CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
684
+ """
685
+ old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
686
+ if old_world_size == 0 or old_world_size == num_workers:
687
+ return cfg
688
+ cfg = cfg.clone()
689
+ frozen = cfg.is_frozen()
690
+ cfg.defrost()
691
+
692
+ assert (
693
+ cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
694
+ ), "Invalid REFERENCE_WORLD_SIZE in config!"
695
+ scale = num_workers / old_world_size
696
+ bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
697
+ lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
698
+ max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
699
+ warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
700
+ cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
701
+ cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
702
+ cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
703
+ cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers # maintain invariant
704
+ logger = logging.getLogger(__name__)
705
+ logger.info(
706
+ f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
707
+ f"max_iter={max_iter}, warmup={warmup_iter}."
708
+ )
709
+
710
+ if frozen:
711
+ cfg.freeze()
712
+ return cfg
713
+
714
+
715
# Expose the wrapped trainer's basic attributes directly on DefaultTrainer:
# reads and writes of each name are forwarded to ``self._trainer``.
for _attr in ("model", "data_loader", "optimizer"):
    setattr(
        DefaultTrainer,
        _attr,
        property(
            # ``x=_attr`` binds the current name at definition time.
            fget=lambda self, x=_attr: getattr(self._trainer, x),
            fset=lambda self, value, x=_attr: setattr(self._trainer, x, value),
        ),
    )