Commit ·
08ec965
1
Parent(s): 985cbbd
ready init project
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +4 -0
- INSTALL.md +34 -0
- README.md +405 -0
- cog.yaml +24 -0
- cutler/__init__.py +15 -0
- cutler/config/__init__.py +3 -0
- cutler/config/__pycache__/__init__.cpython-312.pyc +0 -0
- cutler/config/__pycache__/cutler_config.cpython-312.pyc +0 -0
- cutler/config/cutler_config.py +19 -0
- cutler/data/__init__.py +15 -0
- cutler/data/__pycache__/__init__.cpython-312.pyc +0 -0
- cutler/data/__pycache__/build.cpython-312.pyc +0 -0
- cutler/data/__pycache__/dataset_mapper.cpython-312.pyc +0 -0
- cutler/data/__pycache__/detection_utils.cpython-312.pyc +0 -0
- cutler/data/build.py +561 -0
- cutler/data/dataset_mapper.py +193 -0
- cutler/data/datasets/__init__.py +16 -0
- cutler/data/datasets/__pycache__/__init__.cpython-312.pyc +0 -0
- cutler/data/datasets/__pycache__/builtin.cpython-312.pyc +0 -0
- cutler/data/datasets/__pycache__/builtin_meta.cpython-312.pyc +0 -0
- cutler/data/datasets/__pycache__/coco.cpython-312.pyc +0 -0
- cutler/data/datasets/builtin.py +216 -0
- cutler/data/datasets/builtin_meta.py +389 -0
- cutler/data/datasets/coco.py +544 -0
- cutler/data/detection_utils.py +650 -0
- cutler/data/transforms/__init__.py +15 -0
- cutler/data/transforms/__pycache__/__init__.cpython-312.pyc +0 -0
- cutler/data/transforms/__pycache__/augmentation_impl.cpython-312.pyc +0 -0
- cutler/data/transforms/__pycache__/transform.cpython-312.pyc +0 -0
- cutler/data/transforms/augmentation_impl.py +616 -0
- cutler/data/transforms/transform.py +355 -0
- cutler/demo/__init__.py +5 -0
- cutler/demo/__pycache__/predictor.cpython-312.pyc +0 -0
- cutler/demo/cutler_cascade_final.pth +3 -0
- cutler/demo/demo.py +197 -0
- cutler/demo/imgs/demo1.jpg +3 -0
- cutler/demo/imgs/demo2.jpg +3 -0
- cutler/demo/imgs/demo3.jpg +3 -0
- cutler/demo/imgs/demo4.jpg +3 -0
- cutler/demo/imgs/demo5.jpg +3 -0
- cutler/demo/imgs/demo6.jpg +3 -0
- cutler/demo/imgs/demo7.jpg +3 -0
- cutler/demo/imgs/demo8.jpg +3 -0
- cutler/demo/predictor.py +219 -0
- cutler/demo/wget-log +0 -0
- cutler/engine/__init__.py +7 -0
- cutler/engine/__pycache__/__init__.cpython-312.pyc +0 -0
- cutler/engine/__pycache__/defaults.cpython-312.pyc +0 -0
- cutler/engine/__pycache__/train_loop.cpython-312.pyc +0 -0
- cutler/engine/defaults.py +726 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.txt filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
INSTALL.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Installation
|
| 3 |
+
|
| 4 |
+
## Requirements
|
| 5 |
+
- Linux or macOS with Python ≥ 3.8
|
| 6 |
+
- PyTorch ≥ 1.8 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation.
|
| 7 |
+
Install them together at [pytorch.org](https://pytorch.org) to make sure of this.
|
| 8 |
+
Note, please check PyTorch version matches that is required by Detectron2.
|
| 9 |
+
- Detectron2: follow Detectron2 installation instructions.
|
| 10 |
+
- OpenCV ≥ 4.6 is needed by demo and visualization.
|
| 11 |
+
|
| 12 |
+
## Example conda environment setup
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
conda create --name cutler python=3.8 -y
|
| 16 |
+
conda activate cutler
|
| 17 |
+
conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 -c pytorch
|
| 18 |
+
pip install git+https://github.com/lucasb-eyer/pydensecrf.git
|
| 19 |
+
|
| 20 |
+
# under your working directory
|
| 21 |
+
git clone git@github.com:facebookresearch/detectron2.git
|
| 22 |
+
cd detectron2
|
| 23 |
+
pip install -e .
|
| 24 |
+
pip install git+https://github.com/cocodataset/panopticapi.git
|
| 25 |
+
pip install git+https://github.com/mcordts/cityscapesScripts.git
|
| 26 |
+
|
| 27 |
+
cd ..
|
| 28 |
+
git clone --recursive git@github.com:facebookresearch/CutLER.git
|
| 29 |
+
cd CutLER
|
| 30 |
+
pip install -r requirements.txt
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
## datasets
|
| 34 |
+
If you want to train/evaluate on the datasets, please see [datasets/README.md](datasets/README.md) to see how we prepare datasets for this project.
|
README.md
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Cut and Learn for Unsupervised Image & Video Object Detection and Instance Segmentation
|
| 2 |
+
|
| 3 |
+
**Cut**-and-**LE**a**R**n (**CutLER**) is a simple approach for training object detection and instance segmentation models without human annotations.
|
| 4 |
+
It outperforms previous SOTA by **2.7 times** for AP50 and **2.6 times** for AR on **11 benchmarks**.
|
| 5 |
+
|
| 6 |
+
<p align="center"> <img src='docs/teaser_img.jpg' align="center" > </p>
|
| 7 |
+
|
| 8 |
+
> [**Cut and Learn for Unsupervised Object Detection and Instance Segmentation**](http://people.eecs.berkeley.edu/~xdwang/projects/CutLER/)
|
| 9 |
+
> [Xudong Wang](https://people.eecs.berkeley.edu/~xdwang/), [Rohit Girdhar](https://rohitgirdhar.github.io/), [Stella X. Yu](https://www1.icsi.berkeley.edu/~stellayu/), [Ishan Misra](https://imisra.github.io/)
|
| 10 |
+
> FAIR, Meta AI; UC Berkeley
|
| 11 |
+
> CVPR 2023
|
| 12 |
+
|
| 13 |
+
[[`project page`](http://people.eecs.berkeley.edu/~xdwang/projects/CutLER/)] [[`arxiv`](https://arxiv.org/abs/2301.11320)] [[`colab`](https://colab.research.google.com/drive/1NgEyFHvOfuA2MZZnfNPWg1w5gSr3HOBb?usp=sharing)] [[`bibtex`](#citation)]
|
| 14 |
+
|
| 15 |
+
Unsupervised video instance segmentation (**VideoCutLER**) is also supported. ***We demonstrate that video instance segmentation models can be learned without using any human annotations, without relying on natural videos (ImageNet data alone is sufficient), and even without motion estimations!*** The code is available [here](videocutler).
|
| 16 |
+
|
| 17 |
+
<p align="center">
|
| 18 |
+
<img src="docs/demos_videocutler.gif" width=100%>
|
| 19 |
+
</p>
|
| 20 |
+
|
| 21 |
+
> [**VideoCutLER: Surprisingly Simple Unsupervised Video Instance Segmentation**](https://people.eecs.berkeley.edu/~xdwang/projects/VideoCutLER/videocutler.pdf)
|
| 22 |
+
> [Xudong Wang](https://people.eecs.berkeley.edu/~xdwang/), [Ishan Misra](https://imisra.github.io/), Ziyun Zeng, [Rohit Girdhar](https://rohitgirdhar.github.io/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/)
|
| 23 |
+
> UC Berkeley; FAIR, Meta AI
|
| 24 |
+
> CVPR 2024
|
| 25 |
+
|
| 26 |
+
[[`code`](videocutler/README.md)] [[`PDF`](https://people.eecs.berkeley.edu/~xdwang/projects/VideoCutLER/videocutler.pdf)] [[`arxiv`](https://arxiv.org/abs/2308.14710)] [[`bibtex`](#citation)]
|
| 27 |
+
|
| 28 |
+
## Features
|
| 29 |
+
- We propose MaskCut approach to generate pseudo-masks for multiple objects in an image.
|
| 30 |
+
- CutLER can learn unsupervised object detectors and instance segmentors solely on ImageNet-1K.
|
| 31 |
+
- CutLER exhibits strong robustness to domain shifts when evaluated on 11 different benchmarks across domains like natural images, video frames, paintings, sketches, etc.
|
| 32 |
+
- CutLER can serve as a pretrained model for fully/semi-supervised detection and segmentation tasks.
|
| 33 |
+
- We also propose VideoCutLER, a surprisingly simple unsupervised video instance segmentation (UVIS) method without relying on optical flows. ImaegNet-1K is all we need for training a SOTA UVIS model!
|
| 34 |
+
|
| 35 |
+
## Installation
|
| 36 |
+
See [installation instructions](INSTALL.md).
|
| 37 |
+
|
| 38 |
+
## Dataset Preparation
|
| 39 |
+
See [Preparing Datasets for CutLER](datasets/README.md).
|
| 40 |
+
|
| 41 |
+
## Method Overview
|
| 42 |
+
<p align="center">
|
| 43 |
+
<img src="docs/pipeline.jpg" width=55%>
|
| 44 |
+
</p>
|
| 45 |
+
Cut-and-Learn has two stages: 1) generating pseudo-masks with MaskCut and 2) learning unsupervised detectors from pseudo-masks of unlabeled data.
|
| 46 |
+
|
| 47 |
+
### 1. MaskCut
|
| 48 |
+
|
| 49 |
+
MaskCut can be used to provide segmentation masks for multiple instances of each image.
|
| 50 |
+
<p align="center">
|
| 51 |
+
<img src="docs/maskcut.gif" width=100%>
|
| 52 |
+
</p>
|
| 53 |
+
|
| 54 |
+
### MaskCut Demo
|
| 55 |
+
|
| 56 |
+
Try out the MaskCut demo using Colab (no GPU needed): [](https://colab.research.google.com/drive/1X05lKL_IBRvZB7q6n6pb4w00_tIYjGlf?usp=sharing)
|
| 57 |
+
|
| 58 |
+
Try out the web demo: [](https://huggingface.co/spaces/facebook/MaskCut) (thanks to [@hysts](https://github.com/hysts)!)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
If you want to run MaskCut locally, we provide `demo.py` that is able to visualize the pseudo-masks produced by MaskCut.
|
| 64 |
+
Run it with:
|
| 65 |
+
```
|
| 66 |
+
cd maskcut
|
| 67 |
+
python demo.py --img-path imgs/demo2.jpg \
|
| 68 |
+
--N 3 --tau 0.15 --vit-arch base --patch-size 8 \
|
| 69 |
+
[--other-options]
|
| 70 |
+
```
|
| 71 |
+
We give a few demo images in maskcut/imgs/. If you want to run demo.py with cpu, simply add "--cpu" when running the demo script.
|
| 72 |
+
For imgs/demo4.jpg, you need to use "--N 6" to segment all six instances in the image.
|
| 73 |
+
Following, we give some visualizations of the pseudo-masks on the demo images.
|
| 74 |
+
<p align="center">
|
| 75 |
+
<img src="docs/maskcut-demo.jpg" width=100%>
|
| 76 |
+
</p>
|
| 77 |
+
|
| 78 |
+
### Generating Annotations for ImageNet-1K with MaskCut
|
| 79 |
+
To generate pseudo-masks for ImageNet-1K using MaskCut, first set up the ImageNet-1K dataset according to the instructions in [datasets/README.md](datasets/README.md), then execute the following command:
|
| 80 |
+
```
|
| 81 |
+
cd maskcut
|
| 82 |
+
python maskcut.py \
|
| 83 |
+
--vit-arch base --patch-size 8 \
|
| 84 |
+
--tau 0.15 --fixed_size 480 --N 3 \
|
| 85 |
+
--num-folder-per-job 1000 --job-index 0 \
|
| 86 |
+
--dataset-path /path/to/dataset/traindir \
|
| 87 |
+
--out-dir /path/to/save/annotations \
|
| 88 |
+
```
|
| 89 |
+
As the process of generating pseudo-masks for all 1.3 million images in 1,000 folders takes a significant amount of time, it is recommended to use multiple runs. Each run should process the pseudo-mask generation for a smaller number of image folders by setting "--num-folder-per-job" and "--job-index". Once all runs are completed, you can merge all the resulting json files by using the following command:
|
| 90 |
+
```
|
| 91 |
+
python merge_jsons.py \
|
| 92 |
+
--base-dir /path/to/save/annotations \
|
| 93 |
+
--num-folder-per-job 2 --fixed-size 480 \
|
| 94 |
+
--tau 0.15 --N 3 \
|
| 95 |
+
--save-path imagenet_train_fixsize480_tau0.15_N3.json
|
| 96 |
+
```
|
| 97 |
+
The "--num-folder-per-job", "--fixed-size", "--tau" and "--N" of merge_jsons.py should match the ones used to run maskcut.py.
|
| 98 |
+
|
| 99 |
+
We also provide a submitit script to launch the pseudo-mask generation process with multiple nodes.
|
| 100 |
+
```
|
| 101 |
+
cd maskcut
|
| 102 |
+
bash run_maskcut_with_submitit.sh
|
| 103 |
+
```
|
| 104 |
+
After that, you can use "merge_jsons.py" to merge all these json files as described above.
|
| 105 |
+
|
| 106 |
+
### 2. CutLER
|
| 107 |
+
|
| 108 |
+
### Inference Demo for CutLER with Pre-trained Models
|
| 109 |
+
Try out the CutLER demo using Colab (no GPU needed): [](https://colab.research.google.com/drive/1NgEyFHvOfuA2MZZnfNPWg1w5gSr3HOBb?usp=sharing)
|
| 110 |
+
|
| 111 |
+
Try out the web demo: [](https://huggingface.co/spaces/facebook/CutLER) (thanks to [@hysts](https://github.com/hysts)!)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
Try out Replicate demo and the API: [](https://replicate.com/cjwbw/cutler)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
If you want to run CutLER demos locally,
|
| 118 |
+
1. Pick a model and its config file from [model zoo](#model-zoo),
|
| 119 |
+
for example, `model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml`.
|
| 120 |
+
2. We provide `demo.py` that is able to demo builtin configs. Run it with:
|
| 121 |
+
```
|
| 122 |
+
cd cutler
|
| 123 |
+
python demo/demo.py --config-file model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN_demo.yaml \
|
| 124 |
+
--input demo/imgs/*.jpg \
|
| 125 |
+
[--other-options]
|
| 126 |
+
--opts MODEL.WEIGHTS /path/to/cutler_w_cascade_checkpoint
|
| 127 |
+
```
|
| 128 |
+
The configs are made for training, therefore we need to specify `MODEL.WEIGHTS` to a model from model zoo for evaluation.
|
| 129 |
+
This command will run the inference and show visualizations in an OpenCV window.
|
| 130 |
+
<!-- For details of the command line arguments, see `demo.py -h` or look at its source code
|
| 131 |
+
to understand its behavior. Some common arguments are: -->
|
| 132 |
+
* To run __on cpu__, add `MODEL.DEVICE cpu` after `--opts`.
|
| 133 |
+
* To save outputs to a directory (for images) or a file (for webcam or video), use `--output`.
|
| 134 |
+
|
| 135 |
+
Following, we give some visualizations of the model predictions on the demo images.
|
| 136 |
+
<p align="center">
|
| 137 |
+
<img src="docs/cutler-demo.jpg" width=100%>
|
| 138 |
+
</p>
|
| 139 |
+
|
| 140 |
+
### Unsupervised Model Learning
|
| 141 |
+
Before training the detector, it is necessary to use MaskCut to generate pseudo-masks for all ImageNet data.
|
| 142 |
+
You can either use the pre-generated json file directly by downloading it from [here](http://dl.fbaipublicfiles.com/cutler/maskcut/imagenet_train_fixsize480_tau0.15_N3.json) and placing it under "DETECTRON2_DATASETS/imagenet/annotations/", or generate your own pseudo-masks by following the instructions in [MaskCut](#1-maskcut).
|
| 143 |
+
|
| 144 |
+
We provide a script `train_net.py`, that is made to train all the configs provided in CutLER.
|
| 145 |
+
To train a model with "train_net.py", first setup the ImageNet-1K dataset following [datasets/README.md](datasets/README.md), then run:
|
| 146 |
+
```
|
| 147 |
+
cd cutler
|
| 148 |
+
export DETECTRON2_DATASETS=/path/to/DETECTRON2_DATASETS/
|
| 149 |
+
python train_net.py --num-gpus 8 \
|
| 150 |
+
--config-file model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
If you want to train a model using multiple nodes, you may need to adjust [some model parameters](https://arxiv.org/abs/1706.02677) and some SBATCH command options in "tools/train-1node.sh" and "tools/single-node_run.sh", then run:
|
| 154 |
+
```
|
| 155 |
+
cd cutler
|
| 156 |
+
sbatch tools/train-1node.sh \
|
| 157 |
+
--config-file model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml \
|
| 158 |
+
MODEL.WEIGHTS /path/to/dino/d2format/model \
|
| 159 |
+
OUTPUT_DIR output/
|
| 160 |
+
```
|
| 161 |
+
You can also convert a pre-trained DINO model to detectron2's format by yourself following [this link](https://github.com/facebookresearch/moco/tree/main/detection).
|
| 162 |
+
|
| 163 |
+
### Self-training
|
| 164 |
+
We further improve performance by self-training the model on its predictions.
|
| 165 |
+
|
| 166 |
+
Firstly, we can get model predictions on ImageNet via running:
|
| 167 |
+
```
|
| 168 |
+
python train_net.py --num-gpus 8 \
|
| 169 |
+
--config-file model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml \
|
| 170 |
+
--test-dataset imagenet_train \
|
| 171 |
+
--eval-only TEST.DETECTIONS_PER_IMAGE 30 \
|
| 172 |
+
MODEL.WEIGHTS output/model_final.pth \ # load previous stage/round checkpoints
|
| 173 |
+
OUTPUT_DIR output/ # path to save model predictions
|
| 174 |
+
```
|
| 175 |
+
Secondly, we can run the following command to generate the json file for the first round of self-training:
|
| 176 |
+
```
|
| 177 |
+
python tools/get_self_training_ann.py \
|
| 178 |
+
--new-pred output/inference/coco_instances_results.json \ # load model predictions
|
| 179 |
+
--prev-ann DETECTRON2_DATASETS/imagenet/annotations/imagenet_train_fixsize480_tau0.15_N3.json \ # path to the old annotation file.
|
| 180 |
+
--save-path DETECTRON2_DATASETS/imagenet/annotations/cutler_imagenet1k_train_r1.json \ # path to save a new annotation file.
|
| 181 |
+
--threshold 0.7
|
| 182 |
+
```
|
| 183 |
+
Finally, place "cutler_imagenet1k_train_r1.json" under "DETECTRON2_DATASETS/imagenet/annotations/", then launch the self-training process:
|
| 184 |
+
```
|
| 185 |
+
python train_net.py --num-gpus 8 \
|
| 186 |
+
--config-file model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN_self_train.yaml \
|
| 187 |
+
--train-dataset imagenet_train_r1 \
|
| 188 |
+
MODEL.WEIGHTS output/model_final.pth \ # load previous stage/round checkpoints
|
| 189 |
+
OUTPUT_DIR output/self-train-r1/ # path to save checkpoints
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
You can repeat the steps above to perform multiple rounds of self-training and adjust some arguments as needed (e.g., "--threshold" for round 1 and 2 can be set to 0.7 and 0.65, respectively; "--train-dataset" for round 1 and 2 can be set to "imagenet_train_r1" and "imagenet_train_r2", respectively; MODEL.WEIGHTS for round 1 and 2 should point to the previous stage/round checkpoints). Ensure that all annotation files are placed under DETECTRON2_DATASETS/imagenet/annotations/.
|
| 193 |
+
Please ensure that "--train-dataset", json file names and locations match the ones specified in "cutler/data/datasets/builtin.py".
|
| 194 |
+
Please refer to this [instruction](https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html) for guidance on using custom datasets.
|
| 195 |
+
|
| 196 |
+
You can also directly download the MODEL.WEIGHTS and annotations used for each round of self-training:
|
| 197 |
+
<table><tbody>
|
| 198 |
+
<!-- START TABLE -->
|
| 199 |
+
<!-- TABLE BODY -->
|
| 200 |
+
<!-- ROW: round 1 -->
|
| 201 |
+
<tr><td align="center">round 1</td>
|
| 202 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_r1.pth">cutler_cascade_r1.pth</a></td>
|
| 203 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/maskcut/cutler_imagenet1k_train_r1.json">cutler_imagenet1k_train_r1.json</a></td>
|
| 204 |
+
</tr>
|
| 205 |
+
<!-- ROW: round 2 -->
|
| 206 |
+
<tr><td align="center">round 2</td>
|
| 207 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_r2.pth">cutler_cascade_r2.pth</a></td>
|
| 208 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/maskcut/cutler_imagenet1k_train_r2.json">cutler_imagenet1k_train_r2.json</a></td>
|
| 209 |
+
</tr>
|
| 210 |
+
</tbody></table>
|
| 211 |
+
|
| 212 |
+
### Unsupervised Zero-shot Evaluation
|
| 213 |
+
To evaluate a model's performance on 11 different datasets, please refer to [datasets/README.md](datasets/README.md) for instructions on preparing the datasets. Next, select a model from the model zoo, specify the "model_weights", "config_file" and the path to "DETECTRON2_DATASETS" in `tools/eval.sh`, then run the script.
|
| 214 |
+
```
|
| 215 |
+
bash tools/eval.sh
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
### Model Zoo
|
| 219 |
+
We show zero-shot unsupervised object detection performance (AP50 | AR) on 11 different datasets spanning a variety of domains. ^: CutLER using Mask R-CNN as a detector; *: CutLER using Cascade Mask R-CNN as a detector.
|
| 220 |
+
<table><tbody>
|
| 221 |
+
<!-- START TABLE -->
|
| 222 |
+
<!-- TABLE HEADER -->
|
| 223 |
+
<th valign="bottom">Methods</th>
|
| 224 |
+
<th valign="bottom">Models</th>
|
| 225 |
+
<th valign="bottom">COCO</th>
|
| 226 |
+
<th valign="bottom">COCO20K</th>
|
| 227 |
+
<th valign="bottom">VOC</th>
|
| 228 |
+
<th valign="bottom">LVIS</th>
|
| 229 |
+
<th valign="bottom">UVO</th>
|
| 230 |
+
<th valign="bottom">Clipart</th>
|
| 231 |
+
<th valign="bottom">Comic</th>
|
| 232 |
+
<th valign="bottom">Watercolor</th>
|
| 233 |
+
<th valign="bottom">KITTI</th>
|
| 234 |
+
<th valign="bottom">Objects365</th>
|
| 235 |
+
<th valign="bottom">OpenImages</th>
|
| 236 |
+
<!-- TABLE BODY -->
|
| 237 |
+
</tr>
|
| 238 |
+
<tr><td align="center">Prev. SOTA</td>
|
| 239 |
+
<td valign="bottom">-</td>
|
| 240 |
+
<td align="center">9.6 | 12.6</td>
|
| 241 |
+
<td align="center">9.7 | 12.6</td>
|
| 242 |
+
<td align="center">15.9 | 21.3</td>
|
| 243 |
+
<td align="center">3.8 | 6.4</td>
|
| 244 |
+
<td align="center">10.0 | 14.2</td>
|
| 245 |
+
<td align="center">7.9 | 15.1</td>
|
| 246 |
+
<td align="center">9.9 | 16.3</td>
|
| 247 |
+
<td align="center">6.7 | 16.2</td>
|
| 248 |
+
<td align="center">7.7 | 7.1</td>
|
| 249 |
+
<td align="center">8.1 | 10.2</td>
|
| 250 |
+
<td align="center">9.9 | 14.9</td>
|
| 251 |
+
</tr>
|
| 252 |
+
<!-- ROW: Box/Mask AP for CutLER -->
|
| 253 |
+
</tr>
|
| 254 |
+
<tr><td align="center">CutLER^</td>
|
| 255 |
+
<td valign="bottom"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_mrcnn_final.pth">download</a></td>
|
| 256 |
+
<td align="center">21.1 | 29.6</td>
|
| 257 |
+
<td align="center">21.6 | 30.0</td>
|
| 258 |
+
<td align="center">36.6 | 41.0</td>
|
| 259 |
+
<td align="center">7.7 | 18.7</td>
|
| 260 |
+
<td align="center">29.8 | 38.4</td>
|
| 261 |
+
<td align="center">20.9 | 38.5</td>
|
| 262 |
+
<td align="center">31.2 | 37.1</td>
|
| 263 |
+
<td align="center">37.3 | 39.9</td>
|
| 264 |
+
<td align="center">15.3 | 25.4</td>
|
| 265 |
+
<td align="center">19.5 | 30.0</td>
|
| 266 |
+
<td align="center">17.1 | 26.4</td>
|
| 267 |
+
</tr>
|
| 268 |
+
<!-- ROW: Box/Mask AP for CutLER -->
|
| 269 |
+
</tr>
|
| 270 |
+
<tr><td align="center">CutLER*</td>
|
| 271 |
+
<td valign="bottom"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth">download</a></td>
|
| 272 |
+
<td align="center">21.9 | 32.7</td>
|
| 273 |
+
<td align="center">22.4 | 33.1</td>
|
| 274 |
+
<td align="center">36.9 | 44.3</td>
|
| 275 |
+
<td align="center">8.4 | 21.8</td>
|
| 276 |
+
<td align="center">31.7 | 42.8</td>
|
| 277 |
+
<td align="center">21.1 | 41.3</td>
|
| 278 |
+
<td align="center">30.4 | 38.6</td>
|
| 279 |
+
<td align="center">37.5 | 44.6</td>
|
| 280 |
+
<td align="center">18.4 | 27.5</td>
|
| 281 |
+
<td align="center">21.6 | 34.2</td>
|
| 282 |
+
<td align="center">17.3 | 29.6</td>
|
| 283 |
+
</tr>
|
| 284 |
+
</tbody></table>
|
| 285 |
+
|
| 286 |
+
## Semi-supervised and Fully-supervised Learning
|
| 287 |
+
CutLER can also serve as a pretrained model for training fully supervised object detection and instance segmentation models and improves performance on COCO, including on few-shot benchmarks.
|
| 288 |
+
|
| 289 |
+
### Training & Evaluation in Command Line
|
| 290 |
+
You can find all the semi-supervised and fully-supervised learning configs provided in CutLER under `model_zoo/configs/COCO-Semisupervised`.
|
| 291 |
+
|
| 292 |
+
To train a model using K% labels with `train_net.py`, first set up the COCO dataset according to [datasets/README.md](datasets/README.md) and specify K value in the config file, then run:
|
| 293 |
+
```
|
| 294 |
+
python train_net.py --num-gpus 8 \
|
| 295 |
+
--config-file model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_{K}perc.yaml \
|
| 296 |
+
MODEL.WEIGHTS /path/to/cutler_pretrained_model
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
You can find all config files used to train supervised models under `model_zoo/configs/COCO-Semisupervised`.
|
| 300 |
+
The configs are made for 8-GPU training. To train on 1 GPU, you may need to [change some parameters](https://arxiv.org/abs/1706.02677), e.g. number of GPUs (num-gpus your_num_gpus), learning rates (SOLVER.BASE_LR your_base_lr) and batch size (SOLVER.IMS_PER_BATCH your_batch_size).
|
| 301 |
+
|
| 302 |
+
### Evaluation
|
| 303 |
+
To evaluate a model's performance, use
|
| 304 |
+
```
|
| 305 |
+
python train_net.py \
|
| 306 |
+
--config-file model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_{K}perc.yaml \
|
| 307 |
+
--eval-only MODEL.WEIGHTS /path/to/checkpoint_file
|
| 308 |
+
```
|
| 309 |
+
For more options, see `python train_net.py -h`.
|
| 310 |
+
|
| 311 |
+
### Model Zoo
|
| 312 |
+
We fine-tune a Cascade R-CNN model initialized with CutLER or MoCo-v2 on varying amounts of labeled COCO data, and show results (Box | Mask AP) on the val2017 split below:
|
| 313 |
+
|
| 314 |
+
<table><tbody>
|
| 315 |
+
<!-- START TABLE -->
|
| 316 |
+
<!-- TABLE HEADER -->
|
| 317 |
+
<th valign="bottom">% of labels</th>
|
| 318 |
+
<th valign="bottom">1%</th>
|
| 319 |
+
<th valign="bottom">2%</th>
|
| 320 |
+
<th valign="bottom">5%</th>
|
| 321 |
+
<th valign="bottom">10%</th>
|
| 322 |
+
<th valign="bottom">20%</th>
|
| 323 |
+
<th valign="bottom">30%</th>
|
| 324 |
+
<th valign="bottom">40%</th>
|
| 325 |
+
<th valign="bottom">50%</th>
|
| 326 |
+
<th valign="bottom">60%</th>
|
| 327 |
+
<th valign="bottom">80%</th>
|
| 328 |
+
<th valign="bottom">100%</th>
|
| 329 |
+
<!-- TABLE BODY -->
|
| 330 |
+
<!-- ROW: Box/Mask AP for CutLER -->
|
| 331 |
+
<tr><td align="center">MoCo-v2</td>
|
| 332 |
+
<td align="center">11.8 | 10.0</td>
|
| 333 |
+
<td align="center">16.2 | 13.8</td>
|
| 334 |
+
<td align="center">20.5 | 17.8</td>
|
| 335 |
+
<td align="center">26.5 | 23.0</td>
|
| 336 |
+
<td align="center">32.5 | 28.2</td>
|
| 337 |
+
<td align="center">35.5 | 30.8</td>
|
| 338 |
+
<td align="center">37.3 | 32.3</td>
|
| 339 |
+
<td align="center">38.7 | 33.6</td>
|
| 340 |
+
<td align="center">39.9 | 34.6</td>
|
| 341 |
+
<td align="center">41.6 | 36.0</td>
|
| 342 |
+
<td align="center">42.8 | 37.0</td>
|
| 343 |
+
</tr>
|
| 344 |
+
<!-- ROW: Mask AP -->
|
| 345 |
+
<tr><td align="center">CutLER</td>
|
| 346 |
+
<td align="center">16.8 | 14.6</td>
|
| 347 |
+
<td align="center">21.6 | 18.9</td>
|
| 348 |
+
<td align="center">27.8 | 24.3</td>
|
| 349 |
+
<td align="center">32.2 | 28.1</td>
|
| 350 |
+
<td align="center">36.6 | 31.7</td>
|
| 351 |
+
<td align="center">38.2 | 33.3</td>
|
| 352 |
+
<td align="center">39.9 | 34.7</td>
|
| 353 |
+
<td align="center">41.5 | 35.9</td>
|
| 354 |
+
<td align="center">42.3 | 36.7</td>
|
| 355 |
+
<td align="center">43.8 | 37.9</td>
|
| 356 |
+
<td align="center">44.7 | 38.5</td>
|
| 357 |
+
</tr>
|
| 358 |
+
<!-- ROW: Model Downloads -->
|
| 359 |
+
<tr><td align="center">Download</td>
|
| 360 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_1perc.pth">model</a></td>
|
| 361 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_2perc.pth">model</a></td>
|
| 362 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_5perc.pth">model</a></td>
|
| 363 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_10perc.pth">model</a></td>
|
| 364 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_20perc.pth">model</a></td>
|
| 365 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_30perc.pth">model</a></td>
|
| 366 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_40perc.pth">model</a></td>
|
| 367 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_50perc.pth">model</a></td>
|
| 368 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_60perc.pth">model</a></td>
|
| 369 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_80perc.pth">model</a></td>
|
| 370 |
+
<td align="center"><a href="http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_fully_100perc.pth">model</a></td>
|
| 371 |
+
</tr>
|
| 372 |
+
</tbody></table>
|
| 373 |
+
|
| 374 |
+
Both MoCo-v2 and our CutLER are trained for the 1x schedule using Detectron2, except for extremely low-shot settings with 1% or 2% labels. When training with 1% or 2% labels, we train both MoCo-v2 and our model for 3,600 iterations with a batch size of 16.
|
| 375 |
+
|
| 376 |
+
## License
|
| 377 |
+
The majority of CutLER, Detectron2 and DINO are licensed under the [CC-BY-NC license](LICENSE), however portions of the project are available under separate license terms: TokenCut, Bilateral Solver and CRF are licensed under the MIT license; If you later add other third party code, please keep this license info updated, and please let us know if that component is licensed under something other than CC-BY-NC, MIT, or CC0.
|
| 378 |
+
|
| 379 |
+
## Ethical Considerations
|
| 380 |
+
CutLER's wide range of detection capabilities may introduce similar challenges to many other visual recognition methods.
|
| 381 |
+
As the image can contain arbitrary instances, it may impact the model output.
|
| 382 |
+
|
| 383 |
+
## How to get support from us?
|
| 384 |
+
If you have any general questions, feel free to email us at [Xudong Wang](mailto:xdwang@eecs.berkeley.edu), [Ishan Misra](mailto:imisra@meta.com) and [Rohit Girdhar](mailto:rgirdhar@meta.com). If you have code or implementation-related questions, please feel free to send emails to us or open an issue in this codebase (We recommend that you open an issue in this codebase, because your questions may help others).
|
| 385 |
+
|
| 386 |
+
## Citation
|
| 387 |
+
If you find our work inspiring or use our codebase in your research, please consider giving a star ⭐ and a citation.
|
| 388 |
+
```
|
| 389 |
+
@inproceedings{wang2023cut,
|
| 390 |
+
title={Cut and learn for unsupervised object detection and instance segmentation},
|
| 391 |
+
author={Wang, Xudong and Girdhar, Rohit and Yu, Stella X and Misra, Ishan},
|
| 392 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
| 393 |
+
pages={3124--3134},
|
| 394 |
+
year={2023}
|
| 395 |
+
}
|
| 396 |
+
```
|
| 397 |
+
|
| 398 |
+
```
|
| 399 |
+
@article{wang2023videocutler,
|
| 400 |
+
title={VideoCutLER: Surprisingly Simple Unsupervised Video Instance Segmentation},
|
| 401 |
+
author={Wang, Xudong and Misra, Ishan and Zeng, Ziyun and Girdhar, Rohit and Darrell, Trevor},
|
| 402 |
+
journal={arXiv preprint arXiv:2308.14710},
|
| 403 |
+
year={2023}
|
| 404 |
+
}
|
| 405 |
+
```
|
cog.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
build:
|
| 2 |
+
gpu: true
|
| 3 |
+
cuda: "11.6"
|
| 4 |
+
python_version: "3.8"
|
| 5 |
+
python_packages:
|
| 6 |
+
- "torch==1.11.0"
|
| 7 |
+
- "torchvision==0.12.0"
|
| 8 |
+
- "faiss-gpu==1.7.2"
|
| 9 |
+
- "opencv-python==4.6.0.66"
|
| 10 |
+
- "scikit-image==0.19.2"
|
| 11 |
+
- "scikit-learn==1.1.1"
|
| 12 |
+
- "shapely==1.8.2"
|
| 13 |
+
- "timm==0.5.4"
|
| 14 |
+
- "pyyaml==6.0"
|
| 15 |
+
- "colored==1.4.4"
|
| 16 |
+
- "fvcore==0.1.5.post20220512"
|
| 17 |
+
- "gdown==4.5.4"
|
| 18 |
+
- "pycocotools==2.0.6"
|
| 19 |
+
- "numpy==1.20.0"
|
| 20 |
+
|
| 21 |
+
run:
|
| 22 |
+
- pip install git+https://github.com/lucasb-eyer/pydensecrf.git
|
| 23 |
+
|
| 24 |
+
predict: "maskcut/predict.py:Predictor"
|
cutler/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
import config
|
| 4 |
+
import engine
|
| 5 |
+
import modeling
|
| 6 |
+
import structures
|
| 7 |
+
import tools
|
| 8 |
+
import demo
|
| 9 |
+
|
| 10 |
+
# dataset loading
|
| 11 |
+
from . import data # register all new datasets
|
| 12 |
+
from data import datasets # register all new datasets
|
| 13 |
+
from solver import *
|
| 14 |
+
|
| 15 |
+
# from .data import register_all_imagenet
|
cutler/config/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
from .cutler_config import add_cutler_config
|
cutler/config/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (267 Bytes). View file
|
|
|
cutler/config/__pycache__/cutler_config.cpython-312.pyc
ADDED
|
Binary file (1.22 kB). View file
|
|
|
cutler/config/cutler_config.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
from detectron2.config import CfgNode as CN
|
| 4 |
+
|
| 5 |
+
def add_cutler_config(cfg):
|
| 6 |
+
cfg.DATALOADER.COPY_PASTE = False
|
| 7 |
+
cfg.DATALOADER.COPY_PASTE_RATE = 0.0
|
| 8 |
+
cfg.DATALOADER.COPY_PASTE_MIN_RATIO = 0.5
|
| 9 |
+
cfg.DATALOADER.COPY_PASTE_MAX_RATIO = 1.0
|
| 10 |
+
cfg.DATALOADER.COPY_PASTE_RANDOM_NUM = True
|
| 11 |
+
cfg.DATALOADER.VISUALIZE_COPY_PASTE = False
|
| 12 |
+
|
| 13 |
+
cfg.MODEL.ROI_HEADS.USE_DROPLOSS = False
|
| 14 |
+
cfg.MODEL.ROI_HEADS.DROPLOSS_IOU_THRESH = 0.0
|
| 15 |
+
|
| 16 |
+
cfg.SOLVER.BASE_LR_MULTIPLIER = 1
|
| 17 |
+
cfg.SOLVER.BASE_LR_MULTIPLIER_NAMES = []
|
| 18 |
+
|
| 19 |
+
cfg.TEST.NO_SEGM = False
|
cutler/data/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
from . import datasets # ensure the builtin datasets are registered
|
| 4 |
+
from .detection_utils import * # isort:skip
|
| 5 |
+
from .build import (
|
| 6 |
+
build_batch_data_loader,
|
| 7 |
+
build_detection_train_loader,
|
| 8 |
+
build_detection_test_loader,
|
| 9 |
+
get_detection_dataset_dicts,
|
| 10 |
+
load_proposals_into_dataset,
|
| 11 |
+
print_instances_class_histogram,
|
| 12 |
+
)
|
| 13 |
+
from detectron2.data.common import *
|
| 14 |
+
|
| 15 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
cutler/data/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (793 Bytes). View file
|
|
|
cutler/data/__pycache__/build.cpython-312.pyc
ADDED
|
Binary file (24.3 kB). View file
|
|
|
cutler/data/__pycache__/dataset_mapper.cpython-312.pyc
ADDED
|
Binary file (8.69 kB). View file
|
|
|
cutler/data/__pycache__/detection_utils.cpython-312.pyc
ADDED
|
Binary file (27.6 kB). View file
|
|
|
cutler/data/build.py
ADDED
|
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/build.py
|
| 3 |
+
|
| 4 |
+
import itertools
|
| 5 |
+
import logging
|
| 6 |
+
import numpy as np
|
| 7 |
+
import operator
|
| 8 |
+
import pickle
|
| 9 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 10 |
+
import torch
|
| 11 |
+
import torch.utils.data as torchdata
|
| 12 |
+
from tabulate import tabulate
|
| 13 |
+
from termcolor import colored
|
| 14 |
+
|
| 15 |
+
from detectron2.config import configurable
|
| 16 |
+
from detectron2.structures import BoxMode
|
| 17 |
+
from detectron2.utils.comm import get_world_size
|
| 18 |
+
from detectron2.utils.env import seed_all_rng
|
| 19 |
+
from detectron2.utils.file_io import PathManager
|
| 20 |
+
from detectron2.utils.logger import _log_api_usage, log_first_n
|
| 21 |
+
|
| 22 |
+
from detectron2.data.catalog import DatasetCatalog, MetadataCatalog
|
| 23 |
+
from detectron2.data.common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset
|
| 24 |
+
from data.dataset_mapper import DatasetMapper
|
| 25 |
+
from data.detection_utils import check_metadata_consistency
|
| 26 |
+
from detectron2.data.samplers import (
|
| 27 |
+
InferenceSampler,
|
| 28 |
+
RandomSubsetTrainingSampler,
|
| 29 |
+
RepeatFactorTrainingSampler,
|
| 30 |
+
TrainingSampler,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
This file contains the default logic to build a dataloader for training or testing.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
__all__ = [
|
| 38 |
+
"build_batch_data_loader",
|
| 39 |
+
"build_detection_train_loader",
|
| 40 |
+
"build_detection_test_loader",
|
| 41 |
+
"get_detection_dataset_dicts",
|
| 42 |
+
"load_proposals_into_dataset",
|
| 43 |
+
"print_instances_class_histogram",
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def filter_images_with_only_crowd_annotations(dataset_dicts):
|
| 48 |
+
"""
|
| 49 |
+
Filter out images with none annotations or only crowd annotations
|
| 50 |
+
(i.e., images without non-crowd annotations).
|
| 51 |
+
A common training-time preprocessing on COCO dataset.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
list[dict]: the same format, but filtered.
|
| 58 |
+
"""
|
| 59 |
+
num_before = len(dataset_dicts)
|
| 60 |
+
|
| 61 |
+
def valid(anns):
|
| 62 |
+
for ann in anns:
|
| 63 |
+
if ann.get("iscrowd", 0) == 0:
|
| 64 |
+
return True
|
| 65 |
+
return False
|
| 66 |
+
|
| 67 |
+
dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
|
| 68 |
+
num_after = len(dataset_dicts)
|
| 69 |
+
logger = logging.getLogger(__name__)
|
| 70 |
+
logger.info(
|
| 71 |
+
"Removed {} images with no usable annotations. {} images left.".format(
|
| 72 |
+
num_before - num_after, num_after
|
| 73 |
+
)
|
| 74 |
+
)
|
| 75 |
+
print("Removed {} images with no usable annotations. {} images left.".format(
|
| 76 |
+
num_before - num_after, num_after
|
| 77 |
+
))
|
| 78 |
+
return dataset_dicts
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image):
|
| 82 |
+
"""
|
| 83 |
+
Filter out images with too few number of keypoints.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
list[dict]: the same format as dataset_dicts, but filtered.
|
| 90 |
+
"""
|
| 91 |
+
num_before = len(dataset_dicts)
|
| 92 |
+
|
| 93 |
+
def visible_keypoints_in_image(dic):
|
| 94 |
+
# Each keypoints field has the format [x1, y1, v1, ...], where v is visibility
|
| 95 |
+
annotations = dic["annotations"]
|
| 96 |
+
return sum(
|
| 97 |
+
(np.array(ann["keypoints"][2::3]) > 0).sum()
|
| 98 |
+
for ann in annotations
|
| 99 |
+
if "keypoints" in ann
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
dataset_dicts = [
|
| 103 |
+
x for x in dataset_dicts if visible_keypoints_in_image(x) >= min_keypoints_per_image
|
| 104 |
+
]
|
| 105 |
+
num_after = len(dataset_dicts)
|
| 106 |
+
logger = logging.getLogger(__name__)
|
| 107 |
+
logger.info(
|
| 108 |
+
"Removed {} images with fewer than {} keypoints.".format(
|
| 109 |
+
num_before - num_after, min_keypoints_per_image
|
| 110 |
+
)
|
| 111 |
+
)
|
| 112 |
+
return dataset_dicts
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def load_proposals_into_dataset(dataset_dicts, proposal_file):
|
| 116 |
+
"""
|
| 117 |
+
Load precomputed object proposals into the dataset.
|
| 118 |
+
|
| 119 |
+
The proposal file should be a pickled dict with the following keys:
|
| 120 |
+
|
| 121 |
+
- "ids": list[int] or list[str], the image ids
|
| 122 |
+
- "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
|
| 123 |
+
- "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores
|
| 124 |
+
corresponding to the boxes.
|
| 125 |
+
- "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
|
| 129 |
+
proposal_file (str): file path of pre-computed proposals, in pkl format.
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
list[dict]: the same format as dataset_dicts, but added proposal field.
|
| 133 |
+
"""
|
| 134 |
+
logger = logging.getLogger(__name__)
|
| 135 |
+
logger.info("Loading proposals from: {}".format(proposal_file))
|
| 136 |
+
|
| 137 |
+
with PathManager.open(proposal_file, "rb") as f:
|
| 138 |
+
proposals = pickle.load(f, encoding="latin1")
|
| 139 |
+
|
| 140 |
+
# Rename the key names in D1 proposal files
|
| 141 |
+
rename_keys = {"indexes": "ids", "scores": "objectness_logits"}
|
| 142 |
+
for key in rename_keys:
|
| 143 |
+
if key in proposals:
|
| 144 |
+
proposals[rename_keys[key]] = proposals.pop(key)
|
| 145 |
+
|
| 146 |
+
# Fetch the indexes of all proposals that are in the dataset
|
| 147 |
+
# Convert image_id to str since they could be int.
|
| 148 |
+
img_ids = set({str(record["image_id"]) for record in dataset_dicts})
|
| 149 |
+
id_to_index = {str(id): i for i, id in enumerate(proposals["ids"]) if str(id) in img_ids}
|
| 150 |
+
|
| 151 |
+
# Assuming default bbox_mode of precomputed proposals are 'XYXY_ABS'
|
| 152 |
+
bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS
|
| 153 |
+
|
| 154 |
+
for record in dataset_dicts:
|
| 155 |
+
# Get the index of the proposal
|
| 156 |
+
i = id_to_index[str(record["image_id"])]
|
| 157 |
+
|
| 158 |
+
boxes = proposals["boxes"][i]
|
| 159 |
+
objectness_logits = proposals["objectness_logits"][i]
|
| 160 |
+
# Sort the proposals in descending order of the scores
|
| 161 |
+
inds = objectness_logits.argsort()[::-1]
|
| 162 |
+
record["proposal_boxes"] = boxes[inds]
|
| 163 |
+
record["proposal_objectness_logits"] = objectness_logits[inds]
|
| 164 |
+
record["proposal_bbox_mode"] = bbox_mode
|
| 165 |
+
|
| 166 |
+
return dataset_dicts
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def print_instances_class_histogram(dataset_dicts, class_names):
|
| 170 |
+
"""
|
| 171 |
+
Args:
|
| 172 |
+
dataset_dicts (list[dict]): list of dataset dicts.
|
| 173 |
+
class_names (list[str]): list of class names (zero-indexed).
|
| 174 |
+
"""
|
| 175 |
+
num_classes = len(class_names)
|
| 176 |
+
hist_bins = np.arange(num_classes + 1)
|
| 177 |
+
histogram = np.zeros((num_classes,), dtype=np.int)
|
| 178 |
+
for entry in dataset_dicts:
|
| 179 |
+
annos = entry["annotations"]
|
| 180 |
+
classes = np.asarray(
|
| 181 |
+
[x["category_id"] for x in annos if not x.get("iscrowd", 0)], dtype=np.int
|
| 182 |
+
)
|
| 183 |
+
if len(classes):
|
| 184 |
+
assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
|
| 185 |
+
assert (
|
| 186 |
+
classes.max() < num_classes
|
| 187 |
+
), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
|
| 188 |
+
histogram += np.histogram(classes, bins=hist_bins)[0]
|
| 189 |
+
|
| 190 |
+
N_COLS = min(6, len(class_names) * 2)
|
| 191 |
+
|
| 192 |
+
def short_name(x):
|
| 193 |
+
# make long class names shorter. useful for lvis
|
| 194 |
+
if len(x) > 13:
|
| 195 |
+
return x[:11] + ".."
|
| 196 |
+
return x
|
| 197 |
+
|
| 198 |
+
data = list(
|
| 199 |
+
itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)])
|
| 200 |
+
)
|
| 201 |
+
total_num_instances = sum(data[1::2])
|
| 202 |
+
data.extend([None] * (N_COLS - (len(data) % N_COLS)))
|
| 203 |
+
if num_classes > 1:
|
| 204 |
+
data.extend(["total", total_num_instances])
|
| 205 |
+
data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
|
| 206 |
+
table = tabulate(
|
| 207 |
+
data,
|
| 208 |
+
headers=["category", "#instances"] * (N_COLS // 2),
|
| 209 |
+
tablefmt="pipe",
|
| 210 |
+
numalign="left",
|
| 211 |
+
stralign="center",
|
| 212 |
+
)
|
| 213 |
+
log_first_n(
|
| 214 |
+
logging.INFO,
|
| 215 |
+
"Distribution of instances among all {} categories:\n".format(num_classes)
|
| 216 |
+
+ colored(table, "cyan"),
|
| 217 |
+
key="message",
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def get_detection_dataset_dicts(
|
| 222 |
+
names,
|
| 223 |
+
filter_empty=True,
|
| 224 |
+
min_keypoints=0,
|
| 225 |
+
proposal_files=None,
|
| 226 |
+
check_consistency=True,
|
| 227 |
+
):
|
| 228 |
+
"""
|
| 229 |
+
Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
|
| 230 |
+
|
| 231 |
+
Args:
|
| 232 |
+
names (str or list[str]): a dataset name or a list of dataset names
|
| 233 |
+
filter_empty (bool): whether to filter out images without instance annotations
|
| 234 |
+
min_keypoints (int): filter out images with fewer keypoints than
|
| 235 |
+
`min_keypoints`. Set to 0 to do nothing.
|
| 236 |
+
proposal_files (list[str]): if given, a list of object proposal files
|
| 237 |
+
that match each dataset in `names`.
|
| 238 |
+
check_consistency (bool): whether to check if datasets have consistent metadata.
|
| 239 |
+
|
| 240 |
+
Returns:
|
| 241 |
+
list[dict]: a list of dicts following the standard dataset dict format.
|
| 242 |
+
"""
|
| 243 |
+
if isinstance(names, str):
|
| 244 |
+
names = [names]
|
| 245 |
+
assert len(names), names
|
| 246 |
+
dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]
|
| 247 |
+
|
| 248 |
+
if isinstance(dataset_dicts[0], torchdata.Dataset):
|
| 249 |
+
if len(dataset_dicts) > 1:
|
| 250 |
+
# ConcatDataset does not work for iterable style dataset.
|
| 251 |
+
# We could support concat for iterable as well, but it's often
|
| 252 |
+
# not a good idea to concat iterables anyway.
|
| 253 |
+
return torchdata.ConcatDataset(dataset_dicts)
|
| 254 |
+
return dataset_dicts[0]
|
| 255 |
+
|
| 256 |
+
for dataset_name, dicts in zip(names, dataset_dicts):
|
| 257 |
+
assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
|
| 258 |
+
|
| 259 |
+
if proposal_files is not None:
|
| 260 |
+
assert len(names) == len(proposal_files)
|
| 261 |
+
# load precomputed proposals from proposal files
|
| 262 |
+
dataset_dicts = [
|
| 263 |
+
load_proposals_into_dataset(dataset_i_dicts, proposal_file)
|
| 264 |
+
for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
|
| 265 |
+
]
|
| 266 |
+
|
| 267 |
+
dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
|
| 268 |
+
|
| 269 |
+
has_instances = "annotations" in dataset_dicts[0]
|
| 270 |
+
if filter_empty and has_instances:
|
| 271 |
+
dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
|
| 272 |
+
if min_keypoints > 0 and has_instances:
|
| 273 |
+
dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
|
| 274 |
+
|
| 275 |
+
if check_consistency and has_instances:
|
| 276 |
+
try:
|
| 277 |
+
class_names = MetadataCatalog.get(names[0]).thing_classes
|
| 278 |
+
check_metadata_consistency("thing_classes", names)
|
| 279 |
+
print_instances_class_histogram(dataset_dicts, class_names)
|
| 280 |
+
except AttributeError: # class names are not available for this dataset
|
| 281 |
+
pass
|
| 282 |
+
|
| 283 |
+
assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
|
| 284 |
+
return dataset_dicts
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def build_batch_data_loader(
|
| 288 |
+
dataset,
|
| 289 |
+
sampler,
|
| 290 |
+
total_batch_size,
|
| 291 |
+
*,
|
| 292 |
+
aspect_ratio_grouping=False,
|
| 293 |
+
num_workers=0,
|
| 294 |
+
collate_fn=None,
|
| 295 |
+
):
|
| 296 |
+
"""
|
| 297 |
+
Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are:
|
| 298 |
+
1. support aspect ratio grouping options
|
| 299 |
+
2. use no "batch collation", because this is common for detection training
|
| 300 |
+
|
| 301 |
+
Args:
|
| 302 |
+
dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset.
|
| 303 |
+
sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices.
|
| 304 |
+
Must be provided iff. ``dataset`` is a map-style dataset.
|
| 305 |
+
total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see
|
| 306 |
+
:func:`build_detection_train_loader`.
|
| 307 |
+
|
| 308 |
+
Returns:
|
| 309 |
+
iterable[list]. Length of each list is the batch size of the current
|
| 310 |
+
GPU. Each element in the list comes from the dataset.
|
| 311 |
+
"""
|
| 312 |
+
world_size = get_world_size()
|
| 313 |
+
assert (
|
| 314 |
+
total_batch_size > 0 and total_batch_size % world_size == 0
|
| 315 |
+
), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
|
| 316 |
+
total_batch_size, world_size
|
| 317 |
+
)
|
| 318 |
+
batch_size = total_batch_size // world_size
|
| 319 |
+
|
| 320 |
+
if isinstance(dataset, torchdata.IterableDataset):
|
| 321 |
+
assert sampler is None, "sampler must be None if dataset is IterableDataset"
|
| 322 |
+
else:
|
| 323 |
+
dataset = ToIterableDataset(dataset, sampler)
|
| 324 |
+
|
| 325 |
+
if aspect_ratio_grouping:
|
| 326 |
+
data_loader = torchdata.DataLoader(
|
| 327 |
+
dataset,
|
| 328 |
+
num_workers=num_workers,
|
| 329 |
+
collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements
|
| 330 |
+
worker_init_fn=worker_init_reset_seed,
|
| 331 |
+
) # yield individual mapped dict
|
| 332 |
+
data_loader = AspectRatioGroupedDataset(data_loader, batch_size)
|
| 333 |
+
if collate_fn is None:
|
| 334 |
+
return data_loader
|
| 335 |
+
return MapDataset(data_loader, collate_fn)
|
| 336 |
+
else:
|
| 337 |
+
return torchdata.DataLoader(
|
| 338 |
+
dataset,
|
| 339 |
+
batch_size=batch_size,
|
| 340 |
+
drop_last=True,
|
| 341 |
+
num_workers=num_workers,
|
| 342 |
+
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
|
| 343 |
+
worker_init_fn=worker_init_reset_seed,
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
|
| 348 |
+
if dataset is None:
|
| 349 |
+
dataset = get_detection_dataset_dicts(
|
| 350 |
+
cfg.DATASETS.TRAIN,
|
| 351 |
+
filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
|
| 352 |
+
min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
|
| 353 |
+
if cfg.MODEL.KEYPOINT_ON
|
| 354 |
+
else 0,
|
| 355 |
+
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
|
| 356 |
+
)
|
| 357 |
+
_log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])
|
| 358 |
+
|
| 359 |
+
if mapper is None:
|
| 360 |
+
mapper = DatasetMapper(cfg, True)
|
| 361 |
+
|
| 362 |
+
if sampler is None:
|
| 363 |
+
sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
|
| 364 |
+
logger = logging.getLogger(__name__)
|
| 365 |
+
if isinstance(dataset, torchdata.IterableDataset):
|
| 366 |
+
logger.info("Not using any sampler since the dataset is IterableDataset.")
|
| 367 |
+
sampler = None
|
| 368 |
+
else:
|
| 369 |
+
logger.info("Using training sampler {}".format(sampler_name))
|
| 370 |
+
if sampler_name == "TrainingSampler":
|
| 371 |
+
sampler = TrainingSampler(len(dataset))
|
| 372 |
+
elif sampler_name == "RepeatFactorTrainingSampler":
|
| 373 |
+
repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
|
| 374 |
+
dataset, cfg.DATALOADER.REPEAT_THRESHOLD
|
| 375 |
+
)
|
| 376 |
+
sampler = RepeatFactorTrainingSampler(repeat_factors)
|
| 377 |
+
elif sampler_name == "RandomSubsetTrainingSampler":
|
| 378 |
+
sampler = RandomSubsetTrainingSampler(
|
| 379 |
+
len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
|
| 380 |
+
)
|
| 381 |
+
else:
|
| 382 |
+
raise ValueError("Unknown training sampler: {}".format(sampler_name))
|
| 383 |
+
|
| 384 |
+
return {
|
| 385 |
+
"dataset": dataset,
|
| 386 |
+
"sampler": sampler,
|
| 387 |
+
"mapper": mapper,
|
| 388 |
+
"total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
|
| 389 |
+
"aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
|
| 390 |
+
"num_workers": cfg.DATALOADER.NUM_WORKERS,
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
@configurable(from_config=_train_loader_from_config)
|
| 395 |
+
def build_detection_train_loader(
|
| 396 |
+
dataset,
|
| 397 |
+
*,
|
| 398 |
+
mapper,
|
| 399 |
+
sampler=None,
|
| 400 |
+
total_batch_size,
|
| 401 |
+
aspect_ratio_grouping=True,
|
| 402 |
+
num_workers=0,
|
| 403 |
+
collate_fn=None,
|
| 404 |
+
):
|
| 405 |
+
"""
|
| 406 |
+
Build a dataloader for object detection with some default features.
|
| 407 |
+
|
| 408 |
+
Args:
|
| 409 |
+
dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
|
| 410 |
+
or a pytorch dataset (either map-style or iterable). It can be obtained
|
| 411 |
+
by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
|
| 412 |
+
mapper (callable): a callable which takes a sample (dict) from dataset and
|
| 413 |
+
returns the format to be consumed by the model.
|
| 414 |
+
When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
|
| 415 |
+
sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
|
| 416 |
+
indices to be applied on ``dataset``.
|
| 417 |
+
If ``dataset`` is map-style, the default sampler is a :class:`TrainingSampler`,
|
| 418 |
+
which coordinates an infinite random shuffle sequence across all workers.
|
| 419 |
+
Sampler must be None if ``dataset`` is iterable.
|
| 420 |
+
total_batch_size (int): total batch size across all workers.
|
| 421 |
+
aspect_ratio_grouping (bool): whether to group images with similar
|
| 422 |
+
aspect ratio for efficiency. When enabled, it requires each
|
| 423 |
+
element in dataset be a dict with keys "width" and "height".
|
| 424 |
+
num_workers (int): number of parallel data loading workers
|
| 425 |
+
collate_fn: a function that determines how to do batching, same as the argument of
|
| 426 |
+
`torch.utils.data.DataLoader`. Defaults to do no collation and return a list of
|
| 427 |
+
data. No collation is OK for small batch size and simple data structures.
|
| 428 |
+
If your batch size is large and each sample contains too many small tensors,
|
| 429 |
+
it's more efficient to collate them in data loader.
|
| 430 |
+
|
| 431 |
+
Returns:
|
| 432 |
+
torch.utils.data.DataLoader:
|
| 433 |
+
a dataloader. Each output from it is a ``list[mapped_element]`` of length
|
| 434 |
+
``total_batch_size / num_workers``, where ``mapped_element`` is produced
|
| 435 |
+
by the ``mapper``.
|
| 436 |
+
"""
|
| 437 |
+
if isinstance(dataset, list):
|
| 438 |
+
dataset = DatasetFromList(dataset, copy=False)
|
| 439 |
+
if mapper is not None:
|
| 440 |
+
dataset = MapDataset(dataset, mapper)
|
| 441 |
+
|
| 442 |
+
if isinstance(dataset, torchdata.IterableDataset):
|
| 443 |
+
assert sampler is None, "sampler must be None if dataset is IterableDataset"
|
| 444 |
+
else:
|
| 445 |
+
if sampler is None:
|
| 446 |
+
sampler = TrainingSampler(len(dataset))
|
| 447 |
+
assert isinstance(sampler, torchdata.Sampler), f"Expect a Sampler but got {type(sampler)}"
|
| 448 |
+
return build_batch_data_loader(
|
| 449 |
+
dataset,
|
| 450 |
+
sampler,
|
| 451 |
+
total_batch_size,
|
| 452 |
+
aspect_ratio_grouping=aspect_ratio_grouping,
|
| 453 |
+
num_workers=num_workers,
|
| 454 |
+
collate_fn=collate_fn,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def _test_loader_from_config(cfg, dataset_name, mapper=None):
|
| 459 |
+
"""
|
| 460 |
+
Uses the given `dataset_name` argument (instead of the names in cfg), because the
|
| 461 |
+
standard practice is to evaluate each test set individually (not combining them).
|
| 462 |
+
"""
|
| 463 |
+
if isinstance(dataset_name, str):
|
| 464 |
+
dataset_name = [dataset_name]
|
| 465 |
+
|
| 466 |
+
dataset = get_detection_dataset_dicts(
|
| 467 |
+
dataset_name,
|
| 468 |
+
filter_empty=False,
|
| 469 |
+
proposal_files=[
|
| 470 |
+
cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name
|
| 471 |
+
]
|
| 472 |
+
if cfg.MODEL.LOAD_PROPOSALS
|
| 473 |
+
else None,
|
| 474 |
+
)
|
| 475 |
+
if mapper is None:
|
| 476 |
+
mapper = DatasetMapper(cfg, False)
|
| 477 |
+
return {
|
| 478 |
+
"dataset": dataset,
|
| 479 |
+
"mapper": mapper,
|
| 480 |
+
"num_workers": cfg.DATALOADER.NUM_WORKERS,
|
| 481 |
+
"sampler": InferenceSampler(len(dataset))
|
| 482 |
+
if not isinstance(dataset, torchdata.IterableDataset)
|
| 483 |
+
else None,
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
@configurable(from_config=_test_loader_from_config)
|
| 488 |
+
def build_detection_test_loader(
|
| 489 |
+
dataset: Union[List[Any], torchdata.Dataset],
|
| 490 |
+
*,
|
| 491 |
+
mapper: Callable[[Dict[str, Any]], Any],
|
| 492 |
+
sampler: Optional[torchdata.Sampler] = None,
|
| 493 |
+
batch_size: int = 1,
|
| 494 |
+
num_workers: int = 0,
|
| 495 |
+
collate_fn: Optional[Callable[[List[Any]], Any]] = None,
|
| 496 |
+
) -> torchdata.DataLoader:
|
| 497 |
+
"""
|
| 498 |
+
Similar to `build_detection_train_loader`, with default batch size = 1,
|
| 499 |
+
and sampler = :class:`InferenceSampler`. This sampler coordinates all workers
|
| 500 |
+
to produce the exact set of all samples.
|
| 501 |
+
|
| 502 |
+
Args:
|
| 503 |
+
dataset: a list of dataset dicts,
|
| 504 |
+
or a pytorch dataset (either map-style or iterable). They can be obtained
|
| 505 |
+
by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
|
| 506 |
+
mapper: a callable which takes a sample (dict) from dataset
|
| 507 |
+
and returns the format to be consumed by the model.
|
| 508 |
+
When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
|
| 509 |
+
sampler: a sampler that produces
|
| 510 |
+
indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
|
| 511 |
+
which splits the dataset across all workers. Sampler must be None
|
| 512 |
+
if `dataset` is iterable.
|
| 513 |
+
batch_size: the batch size of the data loader to be created.
|
| 514 |
+
Default to 1 image per worker since this is the standard when reporting
|
| 515 |
+
inference time in papers.
|
| 516 |
+
num_workers: number of parallel data loading workers
|
| 517 |
+
collate_fn: same as the argument of `torch.utils.data.DataLoader`.
|
| 518 |
+
Defaults to do no collation and return a list of data.
|
| 519 |
+
|
| 520 |
+
Returns:
|
| 521 |
+
DataLoader: a torch DataLoader, that loads the given detection
|
| 522 |
+
dataset, with test-time transformation and batching.
|
| 523 |
+
|
| 524 |
+
Examples:
|
| 525 |
+
::
|
| 526 |
+
data_loader = build_detection_test_loader(
|
| 527 |
+
DatasetRegistry.get("my_test"),
|
| 528 |
+
mapper=DatasetMapper(...))
|
| 529 |
+
|
| 530 |
+
# or, instantiate with a CfgNode:
|
| 531 |
+
data_loader = build_detection_test_loader(cfg, "my_test")
|
| 532 |
+
"""
|
| 533 |
+
if isinstance(dataset, list):
|
| 534 |
+
dataset = DatasetFromList(dataset, copy=False)
|
| 535 |
+
if mapper is not None:
|
| 536 |
+
dataset = MapDataset(dataset, mapper)
|
| 537 |
+
if isinstance(dataset, torchdata.IterableDataset):
|
| 538 |
+
assert sampler is None, "sampler must be None if dataset is IterableDataset"
|
| 539 |
+
else:
|
| 540 |
+
if sampler is None:
|
| 541 |
+
sampler = InferenceSampler(len(dataset))
|
| 542 |
+
return torchdata.DataLoader(
|
| 543 |
+
dataset,
|
| 544 |
+
batch_size=batch_size,
|
| 545 |
+
sampler=sampler,
|
| 546 |
+
drop_last=False,
|
| 547 |
+
num_workers=num_workers,
|
| 548 |
+
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def trivial_batch_collator(batch):
|
| 553 |
+
"""
|
| 554 |
+
A batch collator that does nothing.
|
| 555 |
+
"""
|
| 556 |
+
return batch
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def worker_init_reset_seed(worker_id):
|
| 560 |
+
initial_seed = torch.initial_seed() % 2**31
|
| 561 |
+
seed_all_rng(initial_seed + worker_id)
|
cutler/data/dataset_mapper.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/dataset_mapper.py
|
| 3 |
+
|
| 4 |
+
import copy
|
| 5 |
+
import logging
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import List, Optional, Union
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from detectron2.config import configurable
|
| 11 |
+
|
| 12 |
+
import data.detection_utils as utils
|
| 13 |
+
import data.transforms as T
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
This file contains the default mapping that's applied to "dataset dicts".
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
__all__ = ["DatasetMapper"]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DatasetMapper:
|
| 23 |
+
"""
|
| 24 |
+
A callable which takes a dataset dict in Detectron2 Dataset format,
|
| 25 |
+
and map it into a format used by the model.
|
| 26 |
+
|
| 27 |
+
This is the default callable to be used to map your dataset dict into training data.
|
| 28 |
+
You may need to follow it to implement your own one for customized logic,
|
| 29 |
+
such as a different way to read or transform images.
|
| 30 |
+
See :doc:`/tutorials/data_loading` for details.
|
| 31 |
+
|
| 32 |
+
The callable currently does the following:
|
| 33 |
+
|
| 34 |
+
1. Read the image from "file_name"
|
| 35 |
+
2. Applies cropping/geometric transforms to the image and annotations
|
| 36 |
+
3. Prepare data and annotations to Tensor and :class:`Instances`
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
@configurable
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
is_train: bool,
|
| 43 |
+
*,
|
| 44 |
+
augmentations: List[Union[T.Augmentation, T.Transform]],
|
| 45 |
+
image_format: str,
|
| 46 |
+
use_instance_mask: bool = False,
|
| 47 |
+
use_keypoint: bool = False,
|
| 48 |
+
instance_mask_format: str = "polygon",
|
| 49 |
+
keypoint_hflip_indices: Optional[np.ndarray] = None,
|
| 50 |
+
precomputed_proposal_topk: Optional[int] = None,
|
| 51 |
+
recompute_boxes: bool = False,
|
| 52 |
+
):
|
| 53 |
+
"""
|
| 54 |
+
NOTE: this interface is experimental.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
is_train: whether it's used in training or inference
|
| 58 |
+
augmentations: a list of augmentations or deterministic transforms to apply
|
| 59 |
+
image_format: an image format supported by :func:`detection_utils.read_image`.
|
| 60 |
+
use_instance_mask: whether to process instance segmentation annotations, if available
|
| 61 |
+
use_keypoint: whether to process keypoint annotations if available
|
| 62 |
+
instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
|
| 63 |
+
masks into this format.
|
| 64 |
+
keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
|
| 65 |
+
precomputed_proposal_topk: if given, will load pre-computed
|
| 66 |
+
proposals from dataset_dict and keep the top k proposals for each image.
|
| 67 |
+
recompute_boxes: whether to overwrite bounding box annotations
|
| 68 |
+
by computing tight bounding boxes from instance mask annotations.
|
| 69 |
+
"""
|
| 70 |
+
if recompute_boxes:
|
| 71 |
+
assert use_instance_mask, "recompute_boxes requires instance masks"
|
| 72 |
+
# fmt: off
|
| 73 |
+
self.is_train = is_train
|
| 74 |
+
self.augmentations = T.AugmentationList(augmentations)
|
| 75 |
+
self.image_format = image_format
|
| 76 |
+
self.use_instance_mask = use_instance_mask
|
| 77 |
+
self.instance_mask_format = instance_mask_format
|
| 78 |
+
self.use_keypoint = use_keypoint
|
| 79 |
+
self.keypoint_hflip_indices = keypoint_hflip_indices
|
| 80 |
+
self.proposal_topk = precomputed_proposal_topk
|
| 81 |
+
self.recompute_boxes = recompute_boxes
|
| 82 |
+
# fmt: on
|
| 83 |
+
logger = logging.getLogger(__name__)
|
| 84 |
+
mode = "training" if is_train else "inference"
|
| 85 |
+
logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
|
| 86 |
+
|
| 87 |
+
@classmethod
|
| 88 |
+
def from_config(cls, cfg, is_train: bool = True):
|
| 89 |
+
augs = utils.build_augmentation(cfg, is_train)
|
| 90 |
+
if cfg.INPUT.CROP.ENABLED and is_train:
|
| 91 |
+
augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
|
| 92 |
+
recompute_boxes = cfg.MODEL.MASK_ON
|
| 93 |
+
else:
|
| 94 |
+
recompute_boxes = False
|
| 95 |
+
|
| 96 |
+
ret = {
|
| 97 |
+
"is_train": is_train,
|
| 98 |
+
"augmentations": augs,
|
| 99 |
+
"image_format": cfg.INPUT.FORMAT,
|
| 100 |
+
"use_instance_mask": cfg.MODEL.MASK_ON,
|
| 101 |
+
"instance_mask_format": cfg.INPUT.MASK_FORMAT,
|
| 102 |
+
"use_keypoint": cfg.MODEL.KEYPOINT_ON,
|
| 103 |
+
"recompute_boxes": recompute_boxes,
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
if cfg.MODEL.KEYPOINT_ON:
|
| 107 |
+
ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
|
| 108 |
+
|
| 109 |
+
if cfg.MODEL.LOAD_PROPOSALS:
|
| 110 |
+
ret["precomputed_proposal_topk"] = (
|
| 111 |
+
cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
|
| 112 |
+
if is_train
|
| 113 |
+
else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
|
| 114 |
+
)
|
| 115 |
+
return ret
|
| 116 |
+
|
| 117 |
+
def _transform_annotations(self, dataset_dict, transforms, image_shape):
|
| 118 |
+
# USER: Modify this if you want to keep them for some reason.
|
| 119 |
+
for anno in dataset_dict["annotations"]:
|
| 120 |
+
if not self.use_instance_mask:
|
| 121 |
+
anno.pop("segmentation", None)
|
| 122 |
+
if not self.use_keypoint:
|
| 123 |
+
anno.pop("keypoints", None)
|
| 124 |
+
|
| 125 |
+
# USER: Implement additional transformations if you have other types of data
|
| 126 |
+
annos = [
|
| 127 |
+
utils.transform_instance_annotations(
|
| 128 |
+
obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
|
| 129 |
+
)
|
| 130 |
+
for obj in dataset_dict.pop("annotations")
|
| 131 |
+
if obj.get("iscrowd", 0) == 0
|
| 132 |
+
]
|
| 133 |
+
instances = utils.annotations_to_instances(
|
| 134 |
+
annos, image_shape, mask_format=self.instance_mask_format
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# After transforms such as cropping are applied, the bounding box may no longer
|
| 138 |
+
# tightly bound the object. As an example, imagine a triangle object
|
| 139 |
+
# [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
|
| 140 |
+
# bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
|
| 141 |
+
# the intersection of original bounding box and the cropping box.
|
| 142 |
+
if self.recompute_boxes:
|
| 143 |
+
instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
|
| 144 |
+
dataset_dict["instances"] = utils.filter_empty_instances(instances)
|
| 145 |
+
|
| 146 |
+
def __call__(self, dataset_dict):
|
| 147 |
+
"""
|
| 148 |
+
Args:
|
| 149 |
+
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
dict: a format that builtin models in detectron2 accept
|
| 153 |
+
"""
|
| 154 |
+
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
|
| 155 |
+
# USER: Write your own image loading if it's not from a file
|
| 156 |
+
image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
|
| 157 |
+
utils.check_image_size(dataset_dict, image)
|
| 158 |
+
|
| 159 |
+
# USER: Remove if you don't do semantic/panoptic segmentation.
|
| 160 |
+
if "sem_seg_file_name" in dataset_dict:
|
| 161 |
+
sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
|
| 162 |
+
else:
|
| 163 |
+
sem_seg_gt = None
|
| 164 |
+
|
| 165 |
+
aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
|
| 166 |
+
transforms = self.augmentations(aug_input)
|
| 167 |
+
image, sem_seg_gt = aug_input.image, aug_input.sem_seg
|
| 168 |
+
|
| 169 |
+
image_shape = image.shape[:2] # h, w
|
| 170 |
+
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
|
| 171 |
+
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
|
| 172 |
+
# Therefore it's important to use torch.Tensor.
|
| 173 |
+
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
|
| 174 |
+
if sem_seg_gt is not None:
|
| 175 |
+
dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
|
| 176 |
+
|
| 177 |
+
# USER: Remove if you don't use pre-computed proposals.
|
| 178 |
+
# Most users would not need this feature.
|
| 179 |
+
if self.proposal_topk is not None:
|
| 180 |
+
utils.transform_proposals(
|
| 181 |
+
dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
if not self.is_train:
|
| 185 |
+
# USER: Modify this if you want to keep them for some reason.
|
| 186 |
+
dataset_dict.pop("annotations", None)
|
| 187 |
+
dataset_dict.pop("sem_seg_file_name", None)
|
| 188 |
+
return dataset_dict
|
| 189 |
+
|
| 190 |
+
if "annotations" in dataset_dict:
|
| 191 |
+
self._transform_annotations(dataset_dict, transforms, image_shape)
|
| 192 |
+
|
| 193 |
+
return dataset_dict
|
cutler/data/datasets/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
from .coco import load_coco_json, load_sem_seg, register_coco_instances, convert_to_coco_json
|
| 3 |
+
from .builtin import (
|
| 4 |
+
register_all_imagenet,
|
| 5 |
+
register_all_uvo,
|
| 6 |
+
register_all_coco_ca,
|
| 7 |
+
register_all_coco_semi,
|
| 8 |
+
register_all_lvis,
|
| 9 |
+
register_all_voc,
|
| 10 |
+
register_all_cross_domain,
|
| 11 |
+
register_all_kitti,
|
| 12 |
+
register_all_objects365,
|
| 13 |
+
register_all_openimages,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
cutler/data/datasets/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (913 Bytes). View file
|
|
|
cutler/data/datasets/__pycache__/builtin.cpython-312.pyc
ADDED
|
Binary file (9.74 kB). View file
|
|
|
cutler/data/datasets/__pycache__/builtin_meta.cpython-312.pyc
ADDED
|
Binary file (20 kB). View file
|
|
|
cutler/data/datasets/__pycache__/coco.cpython-312.pyc
ADDED
|
Binary file (24.8 kB). View file
|
|
|
cutler/data/datasets/builtin.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/datasets/builtin.py
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
This file registers pre-defined datasets at hard-coded paths, and their metadata.
|
| 6 |
+
|
| 7 |
+
We hard-code metadata for common datasets. This will enable:
|
| 8 |
+
1. Consistency check when loading the datasets
|
| 9 |
+
2. Use models on these standard datasets directly and run demos,
|
| 10 |
+
without having to download the dataset annotations
|
| 11 |
+
|
| 12 |
+
We hard-code some paths to the dataset that's assumed to
|
| 13 |
+
exist in "./datasets/".
|
| 14 |
+
|
| 15 |
+
Users SHOULD NOT use this file to create new dataset / metadata for new dataset.
|
| 16 |
+
To add new dataset, refer to the tutorial "docs/DATASETS.md".
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
from .builtin_meta import _get_builtin_metadata
|
| 22 |
+
from .coco import register_coco_instances
|
| 23 |
+
|
| 24 |
+
# ==== Predefined datasets and splits for COCO ==========
|
| 25 |
+
|
| 26 |
+
_PREDEFINED_SPLITS_COCO_SEMI = {}
|
| 27 |
+
_PREDEFINED_SPLITS_COCO_SEMI["coco_semi"] = {
|
| 28 |
+
# we use seed 42 to be consistent with previous works on SSL detection and segmentation
|
| 29 |
+
"coco_semi_1perc": ("coco/train2017", "coco/annotations/1perc_instances_train2017.json"),
|
| 30 |
+
"coco_semi_2perc": ("coco/train2017", "coco/annotations/2perc_instances_train2017.json"),
|
| 31 |
+
"coco_semi_5perc": ("coco/train2017", "coco/annotations/5perc_instances_train2017.json"),
|
| 32 |
+
"coco_semi_10perc": ("coco/train2017", "coco/annotations/10perc_instances_train2017.json"),
|
| 33 |
+
"coco_semi_20perc": ("coco/train2017", "coco/annotations/20perc_instances_train2017.json"),
|
| 34 |
+
"coco_semi_30perc": ("coco/train2017", "coco/annotations/30perc_instances_train2017.json"),
|
| 35 |
+
"coco_semi_40perc": ("coco/train2017", "coco/annotations/40perc_instances_train2017.json"),
|
| 36 |
+
"coco_semi_50perc": ("coco/train2017", "coco/annotations/50perc_instances_train2017.json"),
|
| 37 |
+
"coco_semi_60perc": ("coco/train2017", "coco/annotations/60perc_instances_train2017.json"),
|
| 38 |
+
"coco_semi_80perc": ("coco/train2017", "coco/annotations/80perc_instances_train2017.json"),
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
_PREDEFINED_SPLITS_COCO_CA = {}
|
| 42 |
+
_PREDEFINED_SPLITS_COCO_CA["coco_cls_agnostic"] = {
|
| 43 |
+
"cls_agnostic_coco": ("coco/val2017", "coco/annotations/coco_cls_agnostic_instances_val2017.json"),
|
| 44 |
+
"cls_agnostic_coco20k": ("coco/train2014", "coco/annotations/coco20k_trainval_gt.json"),
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
_PREDEFINED_SPLITS_IMAGENET = {}
|
| 48 |
+
_PREDEFINED_SPLITS_IMAGENET["imagenet"] = {
|
| 49 |
+
# maskcut annotations
|
| 50 |
+
"imagenet_train": ("imagenet/train", "imagenet/annotations/imagenet_train_fixsize480_tau0.15_N3.json"),
|
| 51 |
+
# self-training round 1
|
| 52 |
+
"imagenet_train_r1": ("imagenet/train", "imagenet/annotations/cutler_imagenet1k_train_r1.json"),
|
| 53 |
+
# self-training round 2
|
| 54 |
+
"imagenet_train_r2": ("imagenet/train", "imagenet/annotations/cutler_imagenet1k_train_r2.json"),
|
| 55 |
+
# self-training round 3
|
| 56 |
+
"imagenet_train_r3": ("imagenet/train", "imagenet/annotations/cutler_imagenet1k_train_r3.json"),
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
_PREDEFINED_SPLITS_VOC = {}
|
| 60 |
+
_PREDEFINED_SPLITS_VOC["voc"] = {
|
| 61 |
+
'cls_agnostic_voc': ("voc/", "voc/annotations/trainvaltest_2007_cls_agnostic.json"),
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
_PREDEFINED_SPLITS_CROSSDOMAIN = {}
|
| 65 |
+
_PREDEFINED_SPLITS_CROSSDOMAIN["cross_domain"] = {
|
| 66 |
+
'cls_agnostic_clipart': ("clipart/", "clipart/annotations/traintest_cls_agnostic.json"),
|
| 67 |
+
'cls_agnostic_watercolor': ("watercolor/", "watercolor/annotations/traintest_cls_agnostic.json"),
|
| 68 |
+
'cls_agnostic_comic': ("comic/", "comic/annotations/traintest_cls_agnostic.json"),
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
_PREDEFINED_SPLITS_KITTI = {}
|
| 72 |
+
_PREDEFINED_SPLITS_KITTI["kitti"] = {
|
| 73 |
+
'cls_agnostic_kitti': ("kitti/", "kitti/annotations/trainval_cls_agnostic.json"),
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
_PREDEFINED_SPLITS_LVIS = {}
|
| 77 |
+
_PREDEFINED_SPLITS_LVIS["lvis"] = {
|
| 78 |
+
"cls_agnostic_lvis": ("coco/", "coco/annotations/lvis1.0_cocofied_val_cls_agnostic.json"),
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
_PREDEFINED_SPLITS_OBJECTS365 = {}
|
| 82 |
+
_PREDEFINED_SPLITS_OBJECTS365["objects365"] = {
|
| 83 |
+
'cls_agnostic_objects365': ("objects365/val", "objects365/annotations/zhiyuan_objv2_val_cls_agnostic.json"),
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
_PREDEFINED_SPLITS_OpenImages = {}
|
| 87 |
+
_PREDEFINED_SPLITS_OpenImages["openimages"] = {
|
| 88 |
+
'cls_agnostic_openimages': ("openImages/validation", "openImages/annotations/openimages_val_cls_agnostic.json"),
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
_PREDEFINED_SPLITS_UVO = {}
|
| 92 |
+
_PREDEFINED_SPLITS_UVO["uvo"] = {
|
| 93 |
+
"cls_agnostic_uvo": ("uvo/all_UVO_frames", "uvo/annotations/val_sparse_cleaned_cls_agnostic.json"),
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
def register_all_imagenet(root):
|
| 97 |
+
for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_IMAGENET.items():
|
| 98 |
+
for key, (image_root, json_file) in splits_per_dataset.items():
|
| 99 |
+
# Assume pre-defined datasets live in `./datasets`.
|
| 100 |
+
register_coco_instances(
|
| 101 |
+
key,
|
| 102 |
+
_get_builtin_metadata(dataset_name),
|
| 103 |
+
os.path.join(root, json_file) if "://" not in json_file else json_file,
|
| 104 |
+
os.path.join(root, image_root),
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
def register_all_voc(root):
|
| 108 |
+
for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_VOC.items():
|
| 109 |
+
for key, (image_root, json_file) in splits_per_dataset.items():
|
| 110 |
+
# Assume pre-defined datasets live in `./datasets`.
|
| 111 |
+
register_coco_instances(
|
| 112 |
+
key,
|
| 113 |
+
_get_builtin_metadata(dataset_name),
|
| 114 |
+
os.path.join(root, json_file) if "://" not in json_file else json_file,
|
| 115 |
+
os.path.join(root, image_root),
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def register_all_cross_domain(root):
|
| 119 |
+
for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_CROSSDOMAIN.items():
|
| 120 |
+
for key, (image_root, json_file) in splits_per_dataset.items():
|
| 121 |
+
# Assume pre-defined datasets live in `./datasets`.
|
| 122 |
+
register_coco_instances(
|
| 123 |
+
key,
|
| 124 |
+
_get_builtin_metadata(dataset_name),
|
| 125 |
+
os.path.join(root, json_file) if "://" not in json_file else json_file,
|
| 126 |
+
os.path.join(root, image_root),
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
def register_all_kitti(root):
|
| 130 |
+
for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_KITTI.items():
|
| 131 |
+
for key, (image_root, json_file) in splits_per_dataset.items():
|
| 132 |
+
# Assume pre-defined datasets live in `./datasets`.
|
| 133 |
+
register_coco_instances(
|
| 134 |
+
key,
|
| 135 |
+
_get_builtin_metadata(dataset_name),
|
| 136 |
+
os.path.join(root, json_file) if "://" not in json_file else json_file,
|
| 137 |
+
os.path.join(root, image_root),
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
def register_all_objects365(root):
|
| 141 |
+
for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_OBJECTS365.items():
|
| 142 |
+
for key, (image_root, json_file) in splits_per_dataset.items():
|
| 143 |
+
# Assume pre-defined datasets live in `./datasets`.
|
| 144 |
+
register_coco_instances(
|
| 145 |
+
key,
|
| 146 |
+
_get_builtin_metadata(dataset_name),
|
| 147 |
+
os.path.join(root, json_file) if "://" not in json_file else json_file,
|
| 148 |
+
os.path.join(root, image_root),
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
def register_all_openimages(root):
|
| 152 |
+
for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_OpenImages.items():
|
| 153 |
+
for key, (image_root, json_file) in splits_per_dataset.items():
|
| 154 |
+
# Assume pre-defined datasets live in `./datasets`.
|
| 155 |
+
register_coco_instances(
|
| 156 |
+
key,
|
| 157 |
+
_get_builtin_metadata(dataset_name),
|
| 158 |
+
os.path.join(root, json_file) if "://" not in json_file else json_file,
|
| 159 |
+
os.path.join(root, image_root),
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
def register_all_lvis(root):
|
| 163 |
+
for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_LVIS.items():
|
| 164 |
+
for key, (image_root, json_file) in splits_per_dataset.items():
|
| 165 |
+
# Assume pre-defined datasets live in `./datasets`.
|
| 166 |
+
register_coco_instances(
|
| 167 |
+
key,
|
| 168 |
+
_get_builtin_metadata(dataset_name),
|
| 169 |
+
os.path.join(root, json_file) if "://" not in json_file else json_file,
|
| 170 |
+
os.path.join(root, image_root),
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
def register_all_uvo(root):
|
| 174 |
+
for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_UVO.items():
|
| 175 |
+
for key, (image_root, json_file) in splits_per_dataset.items():
|
| 176 |
+
# Assume pre-defined datasets live in `./datasets`.
|
| 177 |
+
register_coco_instances(
|
| 178 |
+
key,
|
| 179 |
+
_get_builtin_metadata(dataset_name),
|
| 180 |
+
os.path.join(root, json_file) if "://" not in json_file else json_file,
|
| 181 |
+
os.path.join(root, image_root),
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
def register_all_coco_semi(root):
|
| 185 |
+
for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_COCO_SEMI.items():
|
| 186 |
+
for key, (image_root, json_file) in splits_per_dataset.items():
|
| 187 |
+
# Assume pre-defined datasets live in `./datasets`.
|
| 188 |
+
register_coco_instances(
|
| 189 |
+
key,
|
| 190 |
+
_get_builtin_metadata(dataset_name),
|
| 191 |
+
os.path.join(root, json_file) if "://" not in json_file else json_file,
|
| 192 |
+
os.path.join(root, image_root),
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
def register_all_coco_ca(root):
|
| 196 |
+
for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_COCO_CA.items():
|
| 197 |
+
for key, (image_root, json_file) in splits_per_dataset.items():
|
| 198 |
+
# Assume pre-defined datasets live in `./datasets`.
|
| 199 |
+
register_coco_instances(
|
| 200 |
+
key,
|
| 201 |
+
_get_builtin_metadata(dataset_name),
|
| 202 |
+
os.path.join(root, json_file) if "://" not in json_file else json_file,
|
| 203 |
+
os.path.join(root, image_root),
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
_root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
|
| 207 |
+
register_all_coco_semi(_root)
|
| 208 |
+
register_all_coco_ca(_root)
|
| 209 |
+
register_all_imagenet(_root)
|
| 210 |
+
register_all_uvo(_root)
|
| 211 |
+
register_all_voc(_root)
|
| 212 |
+
register_all_cross_domain(_root)
|
| 213 |
+
register_all_kitti(_root)
|
| 214 |
+
register_all_openimages(_root)
|
| 215 |
+
register_all_objects365(_root)
|
| 216 |
+
register_all_lvis(_root)
|
cutler/data/datasets/builtin_meta.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/datasets/builtin_meta.py
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Note:
|
| 6 |
+
For your custom dataset, there is no need to hard-code metadata anywhere in the code.
|
| 7 |
+
For example, for COCO-format dataset, metadata will be obtained automatically
|
| 8 |
+
when calling `load_coco_json`. For other dataset, metadata may also be obtained in other ways
|
| 9 |
+
during loading.
|
| 10 |
+
|
| 11 |
+
However, we hard-coded metadata for a few common dataset here.
|
| 12 |
+
The only goal is to allow users who don't have these dataset to use pre-trained models.
|
| 13 |
+
Users don't have to download a COCO json (which contains metadata), in order to visualize a
|
| 14 |
+
COCO model (with correct class names and colors).
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# All coco categories, together with their nice-looking visualization colors
|
| 19 |
+
# It's from https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json
|
| 20 |
+
COCO_CATEGORIES = [
|
| 21 |
+
{"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
|
| 22 |
+
{"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
|
| 23 |
+
{"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
|
| 24 |
+
{"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
|
| 25 |
+
{"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
|
| 26 |
+
{"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
|
| 27 |
+
{"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
|
| 28 |
+
{"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
|
| 29 |
+
{"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
|
| 30 |
+
{"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
|
| 31 |
+
{"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
|
| 32 |
+
{"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
|
| 33 |
+
{"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
|
| 34 |
+
{"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
|
| 35 |
+
{"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
|
| 36 |
+
{"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
|
| 37 |
+
{"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
|
| 38 |
+
{"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
|
| 39 |
+
{"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
|
| 40 |
+
{"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
|
| 41 |
+
{"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
|
| 42 |
+
{"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
|
| 43 |
+
{"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
|
| 44 |
+
{"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
|
| 45 |
+
{"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
|
| 46 |
+
{"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
|
| 47 |
+
{"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
|
| 48 |
+
{"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
|
| 49 |
+
{"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
|
| 50 |
+
{"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
|
| 51 |
+
{"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
|
| 52 |
+
{"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
|
| 53 |
+
{"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
|
| 54 |
+
{"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
|
| 55 |
+
{"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
|
| 56 |
+
{"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
|
| 57 |
+
{"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
|
| 58 |
+
{"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
|
| 59 |
+
{"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
|
| 60 |
+
{"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
|
| 61 |
+
{"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
|
| 62 |
+
{"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
|
| 63 |
+
{"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
|
| 64 |
+
{"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
|
| 65 |
+
{"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
|
| 66 |
+
{"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
|
| 67 |
+
{"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
|
| 68 |
+
{"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
|
| 69 |
+
{"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
|
| 70 |
+
{"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
|
| 71 |
+
{"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
|
| 72 |
+
{"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
|
| 73 |
+
{"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
|
| 74 |
+
{"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
|
| 75 |
+
{"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
|
| 76 |
+
{"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
|
| 77 |
+
{"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
|
| 78 |
+
{"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
|
| 79 |
+
{"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
|
| 80 |
+
{"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
|
| 81 |
+
{"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
|
| 82 |
+
{"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
|
| 83 |
+
{"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
|
| 84 |
+
{"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
|
| 85 |
+
{"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
|
| 86 |
+
{"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
|
| 87 |
+
{"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
|
| 88 |
+
{"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
|
| 89 |
+
{"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
|
| 90 |
+
{"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
|
| 91 |
+
{"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
|
| 92 |
+
{"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
|
| 93 |
+
{"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
|
| 94 |
+
{"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
|
| 95 |
+
{"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
|
| 96 |
+
{"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
|
| 97 |
+
{"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
|
| 98 |
+
{"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
|
| 99 |
+
{"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
|
| 100 |
+
{"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
|
| 101 |
+
{"color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"},
|
| 102 |
+
{"color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"},
|
| 103 |
+
{"color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"},
|
| 104 |
+
{"color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"},
|
| 105 |
+
{"color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"},
|
| 106 |
+
{"color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"},
|
| 107 |
+
{"color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"},
|
| 108 |
+
{"color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"},
|
| 109 |
+
{"color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"},
|
| 110 |
+
{"color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"},
|
| 111 |
+
{"color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"},
|
| 112 |
+
{"color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"},
|
| 113 |
+
{"color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"},
|
| 114 |
+
{"color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"},
|
| 115 |
+
{"color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"},
|
| 116 |
+
{"color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"},
|
| 117 |
+
{"color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"},
|
| 118 |
+
{"color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"},
|
| 119 |
+
{"color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"},
|
| 120 |
+
{"color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"},
|
| 121 |
+
{"color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"},
|
| 122 |
+
{"color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"},
|
| 123 |
+
{"color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"},
|
| 124 |
+
{"color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"},
|
| 125 |
+
{"color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"},
|
| 126 |
+
{"color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"},
|
| 127 |
+
{"color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"},
|
| 128 |
+
{"color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"},
|
| 129 |
+
{"color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"},
|
| 130 |
+
{"color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"},
|
| 131 |
+
{"color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"},
|
| 132 |
+
{"color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"},
|
| 133 |
+
{"color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"},
|
| 134 |
+
{"color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"},
|
| 135 |
+
{"color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"},
|
| 136 |
+
{"color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"},
|
| 137 |
+
{"color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"},
|
| 138 |
+
{"color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"},
|
| 139 |
+
{"color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"},
|
| 140 |
+
{"color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"},
|
| 141 |
+
{"color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"},
|
| 142 |
+
{"color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"},
|
| 143 |
+
{"color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"},
|
| 144 |
+
{"color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"},
|
| 145 |
+
{"color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"},
|
| 146 |
+
{"color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"},
|
| 147 |
+
{"color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"},
|
| 148 |
+
{"color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"},
|
| 149 |
+
{"color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"},
|
| 150 |
+
{"color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"},
|
| 151 |
+
{"color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"},
|
| 152 |
+
{"color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"},
|
| 153 |
+
{"color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"},
|
| 154 |
+
]
|
| 155 |
+
|
| 156 |
+
IMAGENET_CATEGORIES = [
|
| 157 |
+
{"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "fg"},
|
| 158 |
+
]
|
| 159 |
+
|
| 160 |
+
UVO_CATEGORIES = [
|
| 161 |
+
{"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "object"},
|
| 162 |
+
]
|
| 163 |
+
|
| 164 |
+
# fmt: off
|
| 165 |
+
COCO_PERSON_KEYPOINT_NAMES = (
|
| 166 |
+
"nose",
|
| 167 |
+
"left_eye", "right_eye",
|
| 168 |
+
"left_ear", "right_ear",
|
| 169 |
+
"left_shoulder", "right_shoulder",
|
| 170 |
+
"left_elbow", "right_elbow",
|
| 171 |
+
"left_wrist", "right_wrist",
|
| 172 |
+
"left_hip", "right_hip",
|
| 173 |
+
"left_knee", "right_knee",
|
| 174 |
+
"left_ankle", "right_ankle",
|
| 175 |
+
)
|
| 176 |
+
# fmt: on
|
| 177 |
+
|
| 178 |
+
# Pairs of keypoints that should be exchanged under horizontal flipping
|
| 179 |
+
COCO_PERSON_KEYPOINT_FLIP_MAP = (
|
| 180 |
+
("left_eye", "right_eye"),
|
| 181 |
+
("left_ear", "right_ear"),
|
| 182 |
+
("left_shoulder", "right_shoulder"),
|
| 183 |
+
("left_elbow", "right_elbow"),
|
| 184 |
+
("left_wrist", "right_wrist"),
|
| 185 |
+
("left_hip", "right_hip"),
|
| 186 |
+
("left_knee", "right_knee"),
|
| 187 |
+
("left_ankle", "right_ankle"),
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
# rules for pairs of keypoints to draw a line between, and the line color to use.
|
| 191 |
+
KEYPOINT_CONNECTION_RULES = [
|
| 192 |
+
# face
|
| 193 |
+
("left_ear", "left_eye", (102, 204, 255)),
|
| 194 |
+
("right_ear", "right_eye", (51, 153, 255)),
|
| 195 |
+
("left_eye", "nose", (102, 0, 204)),
|
| 196 |
+
("nose", "right_eye", (51, 102, 255)),
|
| 197 |
+
# upper-body
|
| 198 |
+
("left_shoulder", "right_shoulder", (255, 128, 0)),
|
| 199 |
+
("left_shoulder", "left_elbow", (153, 255, 204)),
|
| 200 |
+
("right_shoulder", "right_elbow", (128, 229, 255)),
|
| 201 |
+
("left_elbow", "left_wrist", (153, 255, 153)),
|
| 202 |
+
("right_elbow", "right_wrist", (102, 255, 224)),
|
| 203 |
+
# lower-body
|
| 204 |
+
("left_hip", "right_hip", (255, 102, 0)),
|
| 205 |
+
("left_hip", "left_knee", (255, 255, 77)),
|
| 206 |
+
("right_hip", "right_knee", (153, 255, 204)),
|
| 207 |
+
("left_knee", "left_ankle", (191, 255, 128)),
|
| 208 |
+
("right_knee", "right_ankle", (255, 195, 77)),
|
| 209 |
+
]
|
| 210 |
+
|
| 211 |
+
# All Cityscapes categories, together with their nice-looking visualization colors
|
| 212 |
+
# It's from https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py # noqa
|
| 213 |
+
CITYSCAPES_CATEGORIES = [
|
| 214 |
+
{"color": (128, 64, 128), "isthing": 0, "id": 7, "trainId": 0, "name": "road"},
|
| 215 |
+
{"color": (244, 35, 232), "isthing": 0, "id": 8, "trainId": 1, "name": "sidewalk"},
|
| 216 |
+
{"color": (70, 70, 70), "isthing": 0, "id": 11, "trainId": 2, "name": "building"},
|
| 217 |
+
{"color": (102, 102, 156), "isthing": 0, "id": 12, "trainId": 3, "name": "wall"},
|
| 218 |
+
{"color": (190, 153, 153), "isthing": 0, "id": 13, "trainId": 4, "name": "fence"},
|
| 219 |
+
{"color": (153, 153, 153), "isthing": 0, "id": 17, "trainId": 5, "name": "pole"},
|
| 220 |
+
{"color": (250, 170, 30), "isthing": 0, "id": 19, "trainId": 6, "name": "traffic light"},
|
| 221 |
+
{"color": (220, 220, 0), "isthing": 0, "id": 20, "trainId": 7, "name": "traffic sign"},
|
| 222 |
+
{"color": (107, 142, 35), "isthing": 0, "id": 21, "trainId": 8, "name": "vegetation"},
|
| 223 |
+
{"color": (152, 251, 152), "isthing": 0, "id": 22, "trainId": 9, "name": "terrain"},
|
| 224 |
+
{"color": (70, 130, 180), "isthing": 0, "id": 23, "trainId": 10, "name": "sky"},
|
| 225 |
+
{"color": (220, 20, 60), "isthing": 1, "id": 24, "trainId": 11, "name": "person"},
|
| 226 |
+
{"color": (255, 0, 0), "isthing": 1, "id": 25, "trainId": 12, "name": "rider"},
|
| 227 |
+
{"color": (0, 0, 142), "isthing": 1, "id": 26, "trainId": 13, "name": "car"},
|
| 228 |
+
{"color": (0, 0, 70), "isthing": 1, "id": 27, "trainId": 14, "name": "truck"},
|
| 229 |
+
{"color": (0, 60, 100), "isthing": 1, "id": 28, "trainId": 15, "name": "bus"},
|
| 230 |
+
{"color": (0, 80, 100), "isthing": 1, "id": 31, "trainId": 16, "name": "train"},
|
| 231 |
+
{"color": (0, 0, 230), "isthing": 1, "id": 32, "trainId": 17, "name": "motorcycle"},
|
| 232 |
+
{"color": (119, 11, 32), "isthing": 1, "id": 33, "trainId": 18, "name": "bicycle"},
|
| 233 |
+
]
|
| 234 |
+
|
| 235 |
+
# fmt: off
|
| 236 |
+
ADE20K_SEM_SEG_CATEGORIES = [
|
| 237 |
+
"wall", "building", "sky", "floor", "tree", "ceiling", "road, route", "bed", "window ", "grass", "cabinet", "sidewalk, pavement", "person", "earth, ground", "door", "table", "mountain, mount", "plant", "curtain", "chair", "car", "water", "painting, picture", "sofa", "shelf", "house", "sea", "mirror", "rug", "field", "armchair", "seat", "fence", "desk", "rock, stone", "wardrobe, closet, press", "lamp", "tub", "rail", "cushion", "base, pedestal, stand", "box", "column, pillar", "signboard, sign", "chest of drawers, chest, bureau, dresser", "counter", "sand", "sink", "skyscraper", "fireplace", "refrigerator, icebox", "grandstand, covered stand", "path", "stairs", "runway", "case, display case, showcase, vitrine", "pool table, billiard table, snooker table", "pillow", "screen door, screen", "stairway, staircase", "river", "bridge, span", "bookcase", "blind, screen", "coffee table", "toilet, can, commode, crapper, pot, potty, stool, throne", "flower", "book", "hill", "bench", "countertop", "stove", "palm, palm tree", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", "hovel, hut, hutch, shack, shanty", "bus", "towel", "light", "truck", "tower", "chandelier", "awning, sunshade, sunblind", "street lamp", "booth", "tv", "plane", "dirt track", "clothes", "pole", "land, ground, soil", "bannister, banister, balustrade, balusters, handrail", "escalator, moving staircase, moving stairway", "ottoman, pouf, pouffe, puff, hassock", "bottle", "buffet, counter, sideboard", "poster, posting, placard, notice, bill, card", "stage", "van", "ship", "fountain", "conveyer belt, conveyor belt, conveyer, conveyor, transporter", "canopy", "washer, automatic washer, washing machine", "plaything, toy", "pool", "stool", "barrel, cask", "basket, handbasket", "falls", "tent", "bag", "minibike, motorbike", "cradle", "oven", "ball", "food, solid food", "step, stair", "tank, storage tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", "screen", "blanket, cover", "sculpture", "hood, exhaust hood", "sconce", "vase", "traffic light", "tray", "trash can", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", "glass, drinking glass", "clock", "flag", # noqa
|
| 238 |
+
]
|
| 239 |
+
# After processed by `prepare_ade20k_sem_seg.py`, id 255 means ignore
|
| 240 |
+
# fmt: on
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def _get_coco_instances_meta():
|
| 244 |
+
thing_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 1]
|
| 245 |
+
thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1]
|
| 246 |
+
assert len(thing_ids) == 80, len(thing_ids)
|
| 247 |
+
# Mapping from the incontiguous COCO category id to an id in [0, 79]
|
| 248 |
+
thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
|
| 249 |
+
thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1]
|
| 250 |
+
ret = {
|
| 251 |
+
"thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
|
| 252 |
+
"thing_classes": thing_classes,
|
| 253 |
+
"thing_colors": thing_colors,
|
| 254 |
+
}
|
| 255 |
+
return ret
|
| 256 |
+
|
| 257 |
+
def _get_imagenet_instances_meta():
|
| 258 |
+
thing_ids = [k["id"] for k in IMAGENET_CATEGORIES if k["isthing"] == 1]
|
| 259 |
+
thing_colors = [k["color"] for k in IMAGENET_CATEGORIES if k["isthing"] == 1]
|
| 260 |
+
assert len(thing_ids) == 1, len(thing_ids)
|
| 261 |
+
thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
|
| 262 |
+
thing_classes = [k["name"] for k in IMAGENET_CATEGORIES if k["isthing"] == 1]
|
| 263 |
+
ret = {
|
| 264 |
+
"thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
|
| 265 |
+
"thing_classes": thing_classes,
|
| 266 |
+
"thing_colors": thing_colors,
|
| 267 |
+
"class_image_count": [{'id': 1, 'image_count': 116986}]
|
| 268 |
+
}
|
| 269 |
+
return ret
|
| 270 |
+
|
| 271 |
+
def _get_UVO_instances_meta():
|
| 272 |
+
thing_ids = [k["id"] for k in UVO_CATEGORIES if k["isthing"] == 1]
|
| 273 |
+
thing_colors = [k["color"] for k in UVO_CATEGORIES if k["isthing"] == 1]
|
| 274 |
+
assert len(thing_ids) == 1, len(thing_ids)
|
| 275 |
+
thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
|
| 276 |
+
thing_classes = [k["name"] for k in UVO_CATEGORIES if k["isthing"] == 1]
|
| 277 |
+
ret = {
|
| 278 |
+
"thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
|
| 279 |
+
"thing_classes": thing_classes,
|
| 280 |
+
"thing_colors": thing_colors,
|
| 281 |
+
"class_image_count": [{'id': 1, 'image_count': 116986}]
|
| 282 |
+
}
|
| 283 |
+
return ret
|
| 284 |
+
|
| 285 |
+
def _get_coco_panoptic_separated_meta():
|
| 286 |
+
"""
|
| 287 |
+
Returns metadata for "separated" version of the panoptic segmentation dataset.
|
| 288 |
+
"""
|
| 289 |
+
stuff_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 0]
|
| 290 |
+
assert len(stuff_ids) == 53, len(stuff_ids)
|
| 291 |
+
|
| 292 |
+
# For semantic segmentation, this mapping maps from contiguous stuff id
|
| 293 |
+
# (in [0, 53], used in models) to ids in the dataset (used for processing results)
|
| 294 |
+
# The id 0 is mapped to an extra category "thing".
|
| 295 |
+
stuff_dataset_id_to_contiguous_id = {k: i + 1 for i, k in enumerate(stuff_ids)}
|
| 296 |
+
# When converting COCO panoptic annotations to semantic annotations
|
| 297 |
+
# We label the "thing" category to 0
|
| 298 |
+
stuff_dataset_id_to_contiguous_id[0] = 0
|
| 299 |
+
|
| 300 |
+
# 54 names for COCO stuff categories (including "things")
|
| 301 |
+
stuff_classes = ["things"] + [
|
| 302 |
+
k["name"].replace("-other", "").replace("-merged", "")
|
| 303 |
+
for k in COCO_CATEGORIES
|
| 304 |
+
if k["isthing"] == 0
|
| 305 |
+
]
|
| 306 |
+
|
| 307 |
+
# NOTE: I randomly picked a color for things
|
| 308 |
+
stuff_colors = [[82, 18, 128]] + [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 0]
|
| 309 |
+
ret = {
|
| 310 |
+
"stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
|
| 311 |
+
"stuff_classes": stuff_classes,
|
| 312 |
+
"stuff_colors": stuff_colors,
|
| 313 |
+
}
|
| 314 |
+
ret.update(_get_coco_instances_meta())
|
| 315 |
+
return ret
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def _get_builtin_metadata(dataset_name):
|
| 319 |
+
if dataset_name in ["coco", "coco_semi"]:
|
| 320 |
+
return _get_coco_instances_meta()
|
| 321 |
+
if dataset_name == "coco_panoptic_separated":
|
| 322 |
+
return _get_coco_panoptic_separated_meta()
|
| 323 |
+
elif dataset_name in ["imagenet", "kitti", "cross_domain", "lvis", "voc", "coco_cls_agnostic", "objects365", 'openimages']:
|
| 324 |
+
return _get_imagenet_instances_meta()
|
| 325 |
+
elif dataset_name == "uvo":
|
| 326 |
+
return _get_UVO_instances_meta()
|
| 327 |
+
elif dataset_name == "coco_panoptic_standard":
|
| 328 |
+
meta = {}
|
| 329 |
+
# The following metadata maps contiguous id from [0, #thing categories +
|
| 330 |
+
# #stuff categories) to their names and colors. We have to replica of the
|
| 331 |
+
# same name and color under "thing_*" and "stuff_*" because the current
|
| 332 |
+
# visualization function in D2 handles thing and class classes differently
|
| 333 |
+
# due to some heuristic used in Panoptic FPN. We keep the same naming to
|
| 334 |
+
# enable reusing existing visualization functions.
|
| 335 |
+
thing_classes = [k["name"] for k in COCO_CATEGORIES]
|
| 336 |
+
thing_colors = [k["color"] for k in COCO_CATEGORIES]
|
| 337 |
+
stuff_classes = [k["name"] for k in COCO_CATEGORIES]
|
| 338 |
+
stuff_colors = [k["color"] for k in COCO_CATEGORIES]
|
| 339 |
+
|
| 340 |
+
meta["thing_classes"] = thing_classes
|
| 341 |
+
meta["thing_colors"] = thing_colors
|
| 342 |
+
meta["stuff_classes"] = stuff_classes
|
| 343 |
+
meta["stuff_colors"] = stuff_colors
|
| 344 |
+
|
| 345 |
+
# Convert category id for training:
|
| 346 |
+
# category id: like semantic segmentation, it is the class id for each
|
| 347 |
+
# pixel. Since there are some classes not used in evaluation, the category
|
| 348 |
+
# id is not always contiguous and thus we have two set of category ids:
|
| 349 |
+
# - original category id: category id in the original dataset, mainly
|
| 350 |
+
# used for evaluation.
|
| 351 |
+
# - contiguous category id: [0, #classes), in order to train the linear
|
| 352 |
+
# softmax classifier.
|
| 353 |
+
thing_dataset_id_to_contiguous_id = {}
|
| 354 |
+
stuff_dataset_id_to_contiguous_id = {}
|
| 355 |
+
|
| 356 |
+
for i, cat in enumerate(COCO_CATEGORIES):
|
| 357 |
+
if cat["isthing"]:
|
| 358 |
+
thing_dataset_id_to_contiguous_id[cat["id"]] = i
|
| 359 |
+
else:
|
| 360 |
+
stuff_dataset_id_to_contiguous_id[cat["id"]] = i
|
| 361 |
+
|
| 362 |
+
meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
|
| 363 |
+
meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
|
| 364 |
+
|
| 365 |
+
return meta
|
| 366 |
+
elif dataset_name == "coco_person":
|
| 367 |
+
return {
|
| 368 |
+
"thing_classes": ["person"],
|
| 369 |
+
"keypoint_names": COCO_PERSON_KEYPOINT_NAMES,
|
| 370 |
+
"keypoint_flip_map": COCO_PERSON_KEYPOINT_FLIP_MAP,
|
| 371 |
+
"keypoint_connection_rules": KEYPOINT_CONNECTION_RULES,
|
| 372 |
+
}
|
| 373 |
+
elif dataset_name == "cityscapes":
|
| 374 |
+
# fmt: off
|
| 375 |
+
CITYSCAPES_THING_CLASSES = [
|
| 376 |
+
"person", "rider", "car", "truck",
|
| 377 |
+
"bus", "train", "motorcycle", "bicycle",
|
| 378 |
+
]
|
| 379 |
+
CITYSCAPES_STUFF_CLASSES = [
|
| 380 |
+
"road", "sidewalk", "building", "wall", "fence", "pole", "traffic light",
|
| 381 |
+
"traffic sign", "vegetation", "terrain", "sky", "person", "rider", "car",
|
| 382 |
+
"truck", "bus", "train", "motorcycle", "bicycle",
|
| 383 |
+
]
|
| 384 |
+
# fmt: on
|
| 385 |
+
return {
|
| 386 |
+
"thing_classes": CITYSCAPES_THING_CLASSES,
|
| 387 |
+
"stuff_classes": CITYSCAPES_STUFF_CLASSES,
|
| 388 |
+
}
|
| 389 |
+
raise KeyError("No built-in metadata for dataset {}".format(dataset_name))
|
cutler/data/datasets/coco.py
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/datasets/coco.py
|
| 3 |
+
|
| 4 |
+
import contextlib
|
| 5 |
+
import datetime
|
| 6 |
+
import io
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import numpy as np
|
| 10 |
+
import os
|
| 11 |
+
import shutil
|
| 12 |
+
import pycocotools.mask as mask_util
|
| 13 |
+
from fvcore.common.timer import Timer
|
| 14 |
+
from iopath.common.file_io import file_lock
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes
|
| 18 |
+
from detectron2.utils.file_io import PathManager
|
| 19 |
+
|
| 20 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
| 21 |
+
|
| 22 |
+
"""
|
| 23 |
+
This file contains functions to parse COCO-format annotations into dicts in "Detectron2 format".
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
__all__ = ["load_coco_json", "load_sem_seg", "convert_to_coco_json", "register_coco_instances"]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
|
| 33 |
+
"""
|
| 34 |
+
Load a json file with COCO's instances annotation format.
|
| 35 |
+
Currently supports instance detection, instance segmentation,
|
| 36 |
+
and person keypoints annotations.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
json_file (str): full path to the json file in COCO instances annotation format.
|
| 40 |
+
image_root (str or path-like): the directory where the images in this json file exists.
|
| 41 |
+
dataset_name (str or None): the name of the dataset (e.g., coco_2017_train).
|
| 42 |
+
When provided, this function will also do the following:
|
| 43 |
+
|
| 44 |
+
* Put "thing_classes" into the metadata associated with this dataset.
|
| 45 |
+
* Map the category ids into a contiguous range (needed by standard dataset format),
|
| 46 |
+
and add "thing_dataset_id_to_contiguous_id" to the metadata associated
|
| 47 |
+
with this dataset.
|
| 48 |
+
|
| 49 |
+
This option should usually be provided, unless users need to load
|
| 50 |
+
the original json content and apply more processing manually.
|
| 51 |
+
extra_annotation_keys (list[str]): list of per-annotation keys that should also be
|
| 52 |
+
loaded into the dataset dict (besides "iscrowd", "bbox", "keypoints",
|
| 53 |
+
"category_id", "segmentation"). The values for these keys will be returned as-is.
|
| 54 |
+
For example, the densepose annotations are loaded in this way.
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
list[dict]: a list of dicts in Detectron2 standard dataset dicts format (See
|
| 58 |
+
`Using Custom Datasets </tutorials/datasets.html>`_ ) when `dataset_name` is not None.
|
| 59 |
+
If `dataset_name` is None, the returned `category_ids` may be
|
| 60 |
+
incontiguous and may not conform to the Detectron2 standard format.
|
| 61 |
+
|
| 62 |
+
Notes:
|
| 63 |
+
1. This function does not read the image files.
|
| 64 |
+
The results do not have the "image" field.
|
| 65 |
+
"""
|
| 66 |
+
from pycocotools.coco import COCO
|
| 67 |
+
|
| 68 |
+
timer = Timer()
|
| 69 |
+
json_file = PathManager.get_local_path(json_file)
|
| 70 |
+
with contextlib.redirect_stdout(io.StringIO()):
|
| 71 |
+
coco_api = COCO(json_file)
|
| 72 |
+
if timer.seconds() > 1:
|
| 73 |
+
logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
|
| 74 |
+
|
| 75 |
+
id_map = None
|
| 76 |
+
if dataset_name is not None:
|
| 77 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 78 |
+
cat_ids = sorted(coco_api.getCatIds())
|
| 79 |
+
cats = coco_api.loadCats(cat_ids)
|
| 80 |
+
# The categories in a custom json file may not be sorted.
|
| 81 |
+
thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
|
| 82 |
+
if "imagenet" not in dataset_name and "cls_agnostic" not in dataset_name:
|
| 83 |
+
meta.thing_classes = thing_classes
|
| 84 |
+
|
| 85 |
+
# In COCO, certain category ids are artificially removed,
|
| 86 |
+
# and by convention they are always ignored.
|
| 87 |
+
# We deal with COCO's id issue and translate
|
| 88 |
+
# the category ids to contiguous ids in [0, 80).
|
| 89 |
+
|
| 90 |
+
# It works by looking at the "categories" field in the json, therefore
|
| 91 |
+
# if users' own json also have incontiguous ids, we'll
|
| 92 |
+
# apply this mapping as well but print a warning.
|
| 93 |
+
if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
|
| 94 |
+
if "coco" not in dataset_name:
|
| 95 |
+
logger.warning(
|
| 96 |
+
"""
|
| 97 |
+
Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
|
| 98 |
+
"""
|
| 99 |
+
)
|
| 100 |
+
id_map = {v: i for i, v in enumerate(cat_ids)}
|
| 101 |
+
meta.thing_dataset_id_to_contiguous_id = id_map
|
| 102 |
+
else:
|
| 103 |
+
id_map = meta.thing_dataset_id_to_contiguous_id
|
| 104 |
+
|
| 105 |
+
# sort indices for reproducible results
|
| 106 |
+
img_ids = sorted(coco_api.imgs.keys())
|
| 107 |
+
# imgs is a list of dicts, each looks something like:
|
| 108 |
+
# {'license': 4,
|
| 109 |
+
# 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
|
| 110 |
+
# 'file_name': 'COCO_val2014_000000001268.jpg',
|
| 111 |
+
# 'height': 427,
|
| 112 |
+
# 'width': 640,
|
| 113 |
+
# 'date_captured': '2013-11-17 05:57:24',
|
| 114 |
+
# 'id': 1268}
|
| 115 |
+
imgs = coco_api.loadImgs(img_ids)
|
| 116 |
+
# anns is a list[list[dict]], where each dict is an annotation
|
| 117 |
+
# record for an object. The inner list enumerates the objects in an image
|
| 118 |
+
# and the outer list enumerates over images. Example of anns[0]:
|
| 119 |
+
# [{'segmentation': [[192.81,
|
| 120 |
+
# 247.09,
|
| 121 |
+
# ...
|
| 122 |
+
# 219.03,
|
| 123 |
+
# 249.06]],
|
| 124 |
+
# 'area': 1035.749,
|
| 125 |
+
# 'iscrowd': 0,
|
| 126 |
+
# 'image_id': 1268,
|
| 127 |
+
# 'bbox': [192.81, 224.8, 74.73, 33.43],
|
| 128 |
+
# 'category_id': 16,
|
| 129 |
+
# 'id': 42986},
|
| 130 |
+
# ...]
|
| 131 |
+
anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
|
| 132 |
+
total_num_valid_anns = sum([len(x) for x in anns])
|
| 133 |
+
total_num_anns = len(coco_api.anns)
|
| 134 |
+
if total_num_valid_anns < total_num_anns:
|
| 135 |
+
logger.warning(
|
| 136 |
+
f"{json_file} contains {total_num_anns} annotations, but only "
|
| 137 |
+
f"{total_num_valid_anns} of them match to images in the file."
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
if "minival" not in json_file:
|
| 141 |
+
# The popular valminusminival & minival annotations for COCO2014 contain this bug.
|
| 142 |
+
# However the ratio of buggy annotations there is tiny and does not affect accuracy.
|
| 143 |
+
# Therefore we explicitly white-list them.
|
| 144 |
+
ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
|
| 145 |
+
assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
|
| 146 |
+
json_file
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
imgs_anns = list(zip(imgs, anns))
|
| 150 |
+
logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file))
|
| 151 |
+
|
| 152 |
+
dataset_dicts = []
|
| 153 |
+
|
| 154 |
+
ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"] + (extra_annotation_keys or [])
|
| 155 |
+
|
| 156 |
+
num_instances_without_valid_segmentation = 0
|
| 157 |
+
|
| 158 |
+
for (img_dict, anno_dict_list) in imgs_anns:
|
| 159 |
+
record = {}
|
| 160 |
+
record["file_name"] = os.path.join(image_root, img_dict["file_name"])
|
| 161 |
+
record["height"] = img_dict["height"]
|
| 162 |
+
record["width"] = img_dict["width"]
|
| 163 |
+
image_id = record["image_id"] = img_dict["id"]
|
| 164 |
+
|
| 165 |
+
objs = []
|
| 166 |
+
for anno in anno_dict_list:
|
| 167 |
+
# Check that the image_id in this annotation is the same as
|
| 168 |
+
# the image_id we're looking at.
|
| 169 |
+
# This fails only when the data parsing logic or the annotation file is buggy.
|
| 170 |
+
|
| 171 |
+
# The original COCO valminusminival2014 & minival2014 annotation files
|
| 172 |
+
# actually contains bugs that, together with certain ways of using COCO API,
|
| 173 |
+
# can trigger this assertion.
|
| 174 |
+
assert anno["image_id"] == image_id
|
| 175 |
+
|
| 176 |
+
assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.'
|
| 177 |
+
|
| 178 |
+
obj = {key: anno[key] for key in ann_keys if key in anno}
|
| 179 |
+
if "bbox" in obj and len(obj["bbox"]) == 0:
|
| 180 |
+
raise ValueError(
|
| 181 |
+
f"One annotation of image {image_id} contains empty 'bbox' value! "
|
| 182 |
+
"This json does not have valid COCO format."
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
segm = anno.get("segmentation", None)
|
| 186 |
+
if segm: # either list[list[float]] or dict(RLE)
|
| 187 |
+
if isinstance(segm, dict):
|
| 188 |
+
if isinstance(segm["counts"], list):
|
| 189 |
+
# convert to compressed RLE
|
| 190 |
+
segm = mask_util.frPyObjects(segm, *segm["size"])
|
| 191 |
+
else:
|
| 192 |
+
# filter out invalid polygons (< 3 points)
|
| 193 |
+
segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
|
| 194 |
+
if len(segm) == 0:
|
| 195 |
+
num_instances_without_valid_segmentation += 1
|
| 196 |
+
continue # ignore this instance
|
| 197 |
+
obj["segmentation"] = segm
|
| 198 |
+
|
| 199 |
+
keypts = anno.get("keypoints", None)
|
| 200 |
+
if keypts: # list[int]
|
| 201 |
+
for idx, v in enumerate(keypts):
|
| 202 |
+
if idx % 3 != 2:
|
| 203 |
+
# COCO's segmentation coordinates are floating points in [0, H or W],
|
| 204 |
+
# but keypoint coordinates are integers in [0, H-1 or W-1]
|
| 205 |
+
# Therefore we assume the coordinates are "pixel indices" and
|
| 206 |
+
# add 0.5 to convert to floating point coordinates.
|
| 207 |
+
keypts[idx] = v + 0.5
|
| 208 |
+
obj["keypoints"] = keypts
|
| 209 |
+
|
| 210 |
+
obj["bbox_mode"] = BoxMode.XYWH_ABS
|
| 211 |
+
if id_map:
|
| 212 |
+
annotation_category_id = obj["category_id"]
|
| 213 |
+
try:
|
| 214 |
+
obj["category_id"] = id_map[annotation_category_id]
|
| 215 |
+
except KeyError as e:
|
| 216 |
+
raise KeyError(
|
| 217 |
+
f"Encountered category_id={annotation_category_id} "
|
| 218 |
+
"but this id does not exist in 'categories' of the json file."
|
| 219 |
+
) from e
|
| 220 |
+
objs.append(obj)
|
| 221 |
+
record["annotations"] = objs
|
| 222 |
+
dataset_dicts.append(record)
|
| 223 |
+
|
| 224 |
+
if num_instances_without_valid_segmentation > 0:
|
| 225 |
+
logger.warning(
|
| 226 |
+
"Filtered out {} instances without valid segmentation. ".format(
|
| 227 |
+
num_instances_without_valid_segmentation
|
| 228 |
+
)
|
| 229 |
+
+ "There might be issues in your dataset generation process. Please "
|
| 230 |
+
"check https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html carefully"
|
| 231 |
+
)
|
| 232 |
+
return dataset_dicts
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def load_sem_seg(gt_root, image_root, gt_ext="png", image_ext="jpg"):
|
| 236 |
+
"""
|
| 237 |
+
Load semantic segmentation datasets. All files under "gt_root" with "gt_ext" extension are
|
| 238 |
+
treated as ground truth annotations and all files under "image_root" with "image_ext" extension
|
| 239 |
+
as input images. Ground truth and input images are matched using file paths relative to
|
| 240 |
+
"gt_root" and "image_root" respectively without taking into account file extensions.
|
| 241 |
+
This works for COCO as well as some other datasets.
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
gt_root (str): full path to ground truth semantic segmentation files. Semantic segmentation
|
| 245 |
+
annotations are stored as images with integer values in pixels that represent
|
| 246 |
+
corresponding semantic labels.
|
| 247 |
+
image_root (str): the directory where the input images are.
|
| 248 |
+
gt_ext (str): file extension for ground truth annotations.
|
| 249 |
+
image_ext (str): file extension for input images.
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
list[dict]:
|
| 253 |
+
a list of dicts in detectron2 standard format without instance-level
|
| 254 |
+
annotation.
|
| 255 |
+
|
| 256 |
+
Notes:
|
| 257 |
+
1. This function does not read the image and ground truth files.
|
| 258 |
+
The results do not have the "image" and "sem_seg" fields.
|
| 259 |
+
"""
|
| 260 |
+
|
| 261 |
+
# We match input images with ground truth based on their relative filepaths (without file
|
| 262 |
+
# extensions) starting from 'image_root' and 'gt_root' respectively.
|
| 263 |
+
def file2id(folder_path, file_path):
|
| 264 |
+
# extract relative path starting from `folder_path`
|
| 265 |
+
image_id = os.path.normpath(os.path.relpath(file_path, start=folder_path))
|
| 266 |
+
# remove file extension
|
| 267 |
+
image_id = os.path.splitext(image_id)[0]
|
| 268 |
+
return image_id
|
| 269 |
+
|
| 270 |
+
input_files = sorted(
|
| 271 |
+
(os.path.join(image_root, f) for f in PathManager.ls(image_root) if f.endswith(image_ext)),
|
| 272 |
+
key=lambda file_path: file2id(image_root, file_path),
|
| 273 |
+
)
|
| 274 |
+
gt_files = sorted(
|
| 275 |
+
(os.path.join(gt_root, f) for f in PathManager.ls(gt_root) if f.endswith(gt_ext)),
|
| 276 |
+
key=lambda file_path: file2id(gt_root, file_path),
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
assert len(gt_files) > 0, "No annotations found in {}.".format(gt_root)
|
| 280 |
+
|
| 281 |
+
# Use the intersection, so that val2017_100 annotations can run smoothly with val2017 images
|
| 282 |
+
if len(input_files) != len(gt_files):
|
| 283 |
+
logger.warn(
|
| 284 |
+
"Directory {} and {} has {} and {} files, respectively.".format(
|
| 285 |
+
image_root, gt_root, len(input_files), len(gt_files)
|
| 286 |
+
)
|
| 287 |
+
)
|
| 288 |
+
input_basenames = [os.path.basename(f)[: -len(image_ext)] for f in input_files]
|
| 289 |
+
gt_basenames = [os.path.basename(f)[: -len(gt_ext)] for f in gt_files]
|
| 290 |
+
intersect = list(set(input_basenames) & set(gt_basenames))
|
| 291 |
+
# sort, otherwise each worker may obtain a list[dict] in different order
|
| 292 |
+
intersect = sorted(intersect)
|
| 293 |
+
logger.warn("Will use their intersection of {} files.".format(len(intersect)))
|
| 294 |
+
input_files = [os.path.join(image_root, f + image_ext) for f in intersect]
|
| 295 |
+
gt_files = [os.path.join(gt_root, f + gt_ext) for f in intersect]
|
| 296 |
+
|
| 297 |
+
logger.info(
|
| 298 |
+
"Loaded {} images with semantic segmentation from {}".format(len(input_files), image_root)
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
dataset_dicts = []
|
| 302 |
+
for (img_path, gt_path) in zip(input_files, gt_files):
|
| 303 |
+
record = {}
|
| 304 |
+
record["file_name"] = img_path
|
| 305 |
+
record["sem_seg_file_name"] = gt_path
|
| 306 |
+
dataset_dicts.append(record)
|
| 307 |
+
|
| 308 |
+
return dataset_dicts
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def convert_to_coco_dict(dataset_name):
|
| 312 |
+
"""
|
| 313 |
+
Convert an instance detection/segmentation or keypoint detection dataset
|
| 314 |
+
in detectron2's standard format into COCO json format.
|
| 315 |
+
|
| 316 |
+
Generic dataset description can be found here:
|
| 317 |
+
https://detectron2.readthedocs.io/tutorials/datasets.html#register-a-dataset
|
| 318 |
+
|
| 319 |
+
COCO data format description can be found here:
|
| 320 |
+
http://cocodataset.org/#format-data
|
| 321 |
+
|
| 322 |
+
Args:
|
| 323 |
+
dataset_name (str):
|
| 324 |
+
name of the source dataset
|
| 325 |
+
Must be registered in DatastCatalog and in detectron2's standard format.
|
| 326 |
+
Must have corresponding metadata "thing_classes"
|
| 327 |
+
Returns:
|
| 328 |
+
coco_dict: serializable dict in COCO json format
|
| 329 |
+
"""
|
| 330 |
+
|
| 331 |
+
dataset_dicts = DatasetCatalog.get(dataset_name)
|
| 332 |
+
metadata = MetadataCatalog.get(dataset_name)
|
| 333 |
+
|
| 334 |
+
# unmap the category mapping ids for COCO
|
| 335 |
+
if hasattr(metadata, "thing_dataset_id_to_contiguous_id"):
|
| 336 |
+
reverse_id_mapping = {v: k for k, v in metadata.thing_dataset_id_to_contiguous_id.items()}
|
| 337 |
+
reverse_id_mapper = lambda contiguous_id: reverse_id_mapping[contiguous_id] # noqa
|
| 338 |
+
else:
|
| 339 |
+
reverse_id_mapper = lambda contiguous_id: contiguous_id # noqa
|
| 340 |
+
|
| 341 |
+
categories = [
|
| 342 |
+
{"id": reverse_id_mapper(id), "name": name}
|
| 343 |
+
for id, name in enumerate(metadata.thing_classes)
|
| 344 |
+
]
|
| 345 |
+
|
| 346 |
+
logger.info("Converting dataset dicts into COCO format")
|
| 347 |
+
coco_images = []
|
| 348 |
+
coco_annotations = []
|
| 349 |
+
|
| 350 |
+
for image_id, image_dict in enumerate(dataset_dicts):
|
| 351 |
+
coco_image = {
|
| 352 |
+
"id": image_dict.get("image_id", image_id),
|
| 353 |
+
"width": int(image_dict["width"]),
|
| 354 |
+
"height": int(image_dict["height"]),
|
| 355 |
+
"file_name": str(image_dict["file_name"]),
|
| 356 |
+
}
|
| 357 |
+
coco_images.append(coco_image)
|
| 358 |
+
|
| 359 |
+
anns_per_image = image_dict.get("annotations", [])
|
| 360 |
+
for annotation in anns_per_image:
|
| 361 |
+
# create a new dict with only COCO fields
|
| 362 |
+
coco_annotation = {}
|
| 363 |
+
|
| 364 |
+
# COCO requirement: XYWH box format for axis-align and XYWHA for rotated
|
| 365 |
+
bbox = annotation["bbox"]
|
| 366 |
+
if isinstance(bbox, np.ndarray):
|
| 367 |
+
if bbox.ndim != 1:
|
| 368 |
+
raise ValueError(f"bbox has to be 1-dimensional. Got shape={bbox.shape}.")
|
| 369 |
+
bbox = bbox.tolist()
|
| 370 |
+
if len(bbox) not in [4, 5]:
|
| 371 |
+
raise ValueError(f"bbox has to has length 4 or 5. Got {bbox}.")
|
| 372 |
+
from_bbox_mode = annotation["bbox_mode"]
|
| 373 |
+
to_bbox_mode = BoxMode.XYWH_ABS if len(bbox) == 4 else BoxMode.XYWHA_ABS
|
| 374 |
+
bbox = BoxMode.convert(bbox, from_bbox_mode, to_bbox_mode)
|
| 375 |
+
|
| 376 |
+
# COCO requirement: instance area
|
| 377 |
+
if "segmentation" in annotation:
|
| 378 |
+
# Computing areas for instances by counting the pixels
|
| 379 |
+
segmentation = annotation["segmentation"]
|
| 380 |
+
# TODO: check segmentation type: RLE, BinaryMask or Polygon
|
| 381 |
+
if isinstance(segmentation, list):
|
| 382 |
+
polygons = PolygonMasks([segmentation])
|
| 383 |
+
area = polygons.area()[0].item()
|
| 384 |
+
elif isinstance(segmentation, dict): # RLE
|
| 385 |
+
area = mask_util.area(segmentation).item()
|
| 386 |
+
else:
|
| 387 |
+
raise TypeError(f"Unknown segmentation type {type(segmentation)}!")
|
| 388 |
+
else:
|
| 389 |
+
# Computing areas using bounding boxes
|
| 390 |
+
if to_bbox_mode == BoxMode.XYWH_ABS:
|
| 391 |
+
bbox_xy = BoxMode.convert(bbox, to_bbox_mode, BoxMode.XYXY_ABS)
|
| 392 |
+
area = Boxes([bbox_xy]).area()[0].item()
|
| 393 |
+
else:
|
| 394 |
+
area = RotatedBoxes([bbox]).area()[0].item()
|
| 395 |
+
|
| 396 |
+
if "keypoints" in annotation:
|
| 397 |
+
keypoints = annotation["keypoints"] # list[int]
|
| 398 |
+
for idx, v in enumerate(keypoints):
|
| 399 |
+
if idx % 3 != 2:
|
| 400 |
+
# COCO's segmentation coordinates are floating points in [0, H or W],
|
| 401 |
+
# but keypoint coordinates are integers in [0, H-1 or W-1]
|
| 402 |
+
# For COCO format consistency we substract 0.5
|
| 403 |
+
# https://github.com/facebookresearch/detectron2/pull/175#issuecomment-551202163
|
| 404 |
+
keypoints[idx] = v - 0.5
|
| 405 |
+
if "num_keypoints" in annotation:
|
| 406 |
+
num_keypoints = annotation["num_keypoints"]
|
| 407 |
+
else:
|
| 408 |
+
num_keypoints = sum(kp > 0 for kp in keypoints[2::3])
|
| 409 |
+
|
| 410 |
+
# COCO requirement:
|
| 411 |
+
# linking annotations to images
|
| 412 |
+
# "id" field must start with 1
|
| 413 |
+
coco_annotation["id"] = len(coco_annotations) + 1
|
| 414 |
+
coco_annotation["image_id"] = coco_image["id"]
|
| 415 |
+
coco_annotation["bbox"] = [round(float(x), 3) for x in bbox]
|
| 416 |
+
coco_annotation["area"] = float(area)
|
| 417 |
+
coco_annotation["iscrowd"] = int(annotation.get("iscrowd", 0))
|
| 418 |
+
coco_annotation["category_id"] = int(reverse_id_mapper(annotation["category_id"]))
|
| 419 |
+
|
| 420 |
+
# Add optional fields
|
| 421 |
+
if "keypoints" in annotation:
|
| 422 |
+
coco_annotation["keypoints"] = keypoints
|
| 423 |
+
coco_annotation["num_keypoints"] = num_keypoints
|
| 424 |
+
|
| 425 |
+
if "segmentation" in annotation:
|
| 426 |
+
seg = coco_annotation["segmentation"] = annotation["segmentation"]
|
| 427 |
+
if isinstance(seg, dict): # RLE
|
| 428 |
+
counts = seg["counts"]
|
| 429 |
+
if not isinstance(counts, str):
|
| 430 |
+
# make it json-serializable
|
| 431 |
+
seg["counts"] = counts.decode("ascii")
|
| 432 |
+
|
| 433 |
+
coco_annotations.append(coco_annotation)
|
| 434 |
+
|
| 435 |
+
logger.info(
|
| 436 |
+
"Conversion finished, "
|
| 437 |
+
f"#images: {len(coco_images)}, #annotations: {len(coco_annotations)}"
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
info = {
|
| 441 |
+
"date_created": str(datetime.datetime.now()),
|
| 442 |
+
"description": "Automatically generated COCO json file for Detectron2.",
|
| 443 |
+
}
|
| 444 |
+
coco_dict = {"info": info, "images": coco_images, "categories": categories, "licenses": None}
|
| 445 |
+
if len(coco_annotations) > 0:
|
| 446 |
+
coco_dict["annotations"] = coco_annotations
|
| 447 |
+
return coco_dict
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def convert_to_coco_json(dataset_name, output_file, allow_cached=True):
|
| 451 |
+
"""
|
| 452 |
+
Converts dataset into COCO format and saves it to a json file.
|
| 453 |
+
dataset_name must be registered in DatasetCatalog and in detectron2's standard format.
|
| 454 |
+
|
| 455 |
+
Args:
|
| 456 |
+
dataset_name:
|
| 457 |
+
reference from the config file to the catalogs
|
| 458 |
+
must be registered in DatasetCatalog and in detectron2's standard format
|
| 459 |
+
output_file: path of json file that will be saved to
|
| 460 |
+
allow_cached: if json file is already present then skip conversion
|
| 461 |
+
"""
|
| 462 |
+
|
| 463 |
+
# TODO: The dataset or the conversion script *may* change,
|
| 464 |
+
# a checksum would be useful for validating the cached data
|
| 465 |
+
|
| 466 |
+
PathManager.mkdirs(os.path.dirname(output_file))
|
| 467 |
+
with file_lock(output_file):
|
| 468 |
+
if PathManager.exists(output_file) and allow_cached:
|
| 469 |
+
logger.warning(
|
| 470 |
+
f"Using previously cached COCO format annotations at '{output_file}'. "
|
| 471 |
+
"You need to clear the cache file if your dataset has been modified."
|
| 472 |
+
)
|
| 473 |
+
else:
|
| 474 |
+
logger.info(f"Converting annotations of dataset '{dataset_name}' to COCO format ...)")
|
| 475 |
+
coco_dict = convert_to_coco_dict(dataset_name)
|
| 476 |
+
|
| 477 |
+
logger.info(f"Caching COCO format annotations at '{output_file}' ...")
|
| 478 |
+
tmp_file = output_file + ".tmp"
|
| 479 |
+
with PathManager.open(tmp_file, "w") as f:
|
| 480 |
+
json.dump(coco_dict, f)
|
| 481 |
+
shutil.move(tmp_file, output_file)
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def register_coco_instances(name, metadata, json_file, image_root):
|
| 485 |
+
"""
|
| 486 |
+
Register a dataset in COCO's json annotation format for
|
| 487 |
+
instance detection, instance segmentation and keypoint detection.
|
| 488 |
+
(i.e., Type 1 and 2 in http://cocodataset.org/#format-data.
|
| 489 |
+
`instances*.json` and `person_keypoints*.json` in the dataset).
|
| 490 |
+
|
| 491 |
+
This is an example of how to register a new dataset.
|
| 492 |
+
You can do something similar to this function, to register new datasets.
|
| 493 |
+
|
| 494 |
+
Args:
|
| 495 |
+
name (str): the name that identifies a dataset, e.g. "coco_2014_train".
|
| 496 |
+
metadata (dict): extra metadata associated with this dataset. You can
|
| 497 |
+
leave it as an empty dict.
|
| 498 |
+
json_file (str): path to the json instance annotation file.
|
| 499 |
+
image_root (str or path-like): directory which contains all the images.
|
| 500 |
+
"""
|
| 501 |
+
assert isinstance(name, str), name
|
| 502 |
+
assert isinstance(json_file, (str, os.PathLike)), json_file
|
| 503 |
+
assert isinstance(image_root, (str, os.PathLike)), image_root
|
| 504 |
+
# 1. register a function which returns dicts
|
| 505 |
+
DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name))
|
| 506 |
+
|
| 507 |
+
# 2. Optionally, add metadata about this dataset,
|
| 508 |
+
# since they might be useful in evaluation, visualization or logging
|
| 509 |
+
MetadataCatalog.get(name).set(
|
| 510 |
+
json_file=json_file, image_root=image_root, evaluator_type="coco", **metadata
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
if __name__ == "__main__":
|
| 515 |
+
"""
|
| 516 |
+
Test the COCO json dataset loader.
|
| 517 |
+
|
| 518 |
+
Usage:
|
| 519 |
+
python -m detectron2.data.datasets.coco \
|
| 520 |
+
path/to/json path/to/image_root dataset_name
|
| 521 |
+
|
| 522 |
+
"dataset_name" can be "coco_2014_minival_100", or other
|
| 523 |
+
pre-registered ones
|
| 524 |
+
"""
|
| 525 |
+
from detectron2.utils.logger import setup_logger
|
| 526 |
+
from detectron2.utils.visualizer import Visualizer
|
| 527 |
+
import detectron2.data.datasets # noqa # add pre-defined metadata
|
| 528 |
+
import sys
|
| 529 |
+
|
| 530 |
+
logger = setup_logger(name=__name__)
|
| 531 |
+
assert sys.argv[3] in DatasetCatalog.list()
|
| 532 |
+
meta = MetadataCatalog.get(sys.argv[3])
|
| 533 |
+
|
| 534 |
+
dicts = load_coco_json(sys.argv[1], sys.argv[2], sys.argv[3])
|
| 535 |
+
logger.info("Done loading {} samples.".format(len(dicts)))
|
| 536 |
+
|
| 537 |
+
dirname = "coco-data-vis"
|
| 538 |
+
os.makedirs(dirname, exist_ok=True)
|
| 539 |
+
for d in dicts:
|
| 540 |
+
img = np.array(Image.open(d["file_name"]))
|
| 541 |
+
visualizer = Visualizer(img, metadata=meta)
|
| 542 |
+
vis = visualizer.draw_dataset_dict(d)
|
| 543 |
+
fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
|
| 544 |
+
vis.save(fpath)
|
cutler/data/detection_utils.py
ADDED
|
@@ -0,0 +1,650 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/detection_utils.py
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Common data processing utilities that are used in a
|
| 6 |
+
typical object detection data pipeline.
|
| 7 |
+
"""
|
| 8 |
+
import logging
|
| 9 |
+
import numpy as np
|
| 10 |
+
from typing import List, Union
|
| 11 |
+
import pycocotools.mask as mask_util
|
| 12 |
+
import torch
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
from detectron2.structures import (
|
| 16 |
+
Boxes,
|
| 17 |
+
BoxMode,
|
| 18 |
+
BitMasks,
|
| 19 |
+
Instances,
|
| 20 |
+
Keypoints,
|
| 21 |
+
PolygonMasks,
|
| 22 |
+
RotatedBoxes,
|
| 23 |
+
polygons_to_bitmask,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from detectron2.utils.file_io import PathManager
|
| 27 |
+
|
| 28 |
+
from data import transforms as T
|
| 29 |
+
from detectron2.data.catalog import MetadataCatalog
|
| 30 |
+
|
| 31 |
+
__all__ = [
|
| 32 |
+
"SizeMismatchError",
|
| 33 |
+
"convert_image_to_rgb",
|
| 34 |
+
"check_image_size",
|
| 35 |
+
"transform_proposals",
|
| 36 |
+
"transform_instance_annotations",
|
| 37 |
+
"annotations_to_instances",
|
| 38 |
+
"annotations_to_instances_rotated",
|
| 39 |
+
"build_augmentation",
|
| 40 |
+
"build_transform_gen",
|
| 41 |
+
"create_keypoint_hflip_indices",
|
| 42 |
+
"filter_empty_instances",
|
| 43 |
+
"read_image",
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class SizeMismatchError(ValueError):
|
| 48 |
+
"""
|
| 49 |
+
When loaded image has difference width/height compared with annotation.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# https://en.wikipedia.org/wiki/YUV#SDTV_with_BT.601
|
| 54 |
+
_M_RGB2YUV = [[0.299, 0.587, 0.114], [-0.14713, -0.28886, 0.436], [0.615, -0.51499, -0.10001]]
|
| 55 |
+
_M_YUV2RGB = [[1.0, 0.0, 1.13983], [1.0, -0.39465, -0.58060], [1.0, 2.03211, 0.0]]
|
| 56 |
+
|
| 57 |
+
# https://www.exiv2.org/tags.html
|
| 58 |
+
_EXIF_ORIENT = 274 # exif 'Orientation' tag
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def convert_PIL_to_numpy(image, format):
|
| 62 |
+
"""
|
| 63 |
+
Convert PIL image to numpy array of target format.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
image (PIL.Image): a PIL image
|
| 67 |
+
format (str): the format of output image
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
(np.ndarray): also see `read_image`
|
| 71 |
+
"""
|
| 72 |
+
if format is not None:
|
| 73 |
+
# PIL only supports RGB, so convert to RGB and flip channels over below
|
| 74 |
+
conversion_format = format
|
| 75 |
+
if format in ["BGR", "YUV-BT.601"]:
|
| 76 |
+
conversion_format = "RGB"
|
| 77 |
+
image = image.convert(conversion_format)
|
| 78 |
+
image = np.asarray(image)
|
| 79 |
+
# PIL squeezes out the channel dimension for "L", so make it HWC
|
| 80 |
+
if format == "L":
|
| 81 |
+
image = np.expand_dims(image, -1)
|
| 82 |
+
|
| 83 |
+
# handle formats not supported by PIL
|
| 84 |
+
elif format == "BGR":
|
| 85 |
+
# flip channels if needed
|
| 86 |
+
image = image[:, :, ::-1]
|
| 87 |
+
elif format == "YUV-BT.601":
|
| 88 |
+
image = image / 255.0
|
| 89 |
+
image = np.dot(image, np.array(_M_RGB2YUV).T)
|
| 90 |
+
|
| 91 |
+
return image
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def convert_image_to_rgb(image, format):
|
| 95 |
+
"""
|
| 96 |
+
Convert an image from given format to RGB.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
image (np.ndarray or Tensor): an HWC image
|
| 100 |
+
format (str): the format of input image, also see `read_image`
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
(np.ndarray): (H,W,3) RGB image in 0-255 range, can be either float or uint8
|
| 104 |
+
"""
|
| 105 |
+
if isinstance(image, torch.Tensor):
|
| 106 |
+
image = image.cpu().numpy()
|
| 107 |
+
if format == "BGR":
|
| 108 |
+
image = image[:, :, [2, 1, 0]]
|
| 109 |
+
elif format == "YUV-BT.601":
|
| 110 |
+
image = np.dot(image, np.array(_M_YUV2RGB).T)
|
| 111 |
+
image = image * 255.0
|
| 112 |
+
else:
|
| 113 |
+
if format == "L":
|
| 114 |
+
image = image[:, :, 0]
|
| 115 |
+
image = image.astype(np.uint8)
|
| 116 |
+
image = np.asarray(Image.fromarray(image, mode=format).convert("RGB"))
|
| 117 |
+
return image
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _apply_exif_orientation(image):
|
| 121 |
+
"""
|
| 122 |
+
Applies the exif orientation correctly.
|
| 123 |
+
|
| 124 |
+
This code exists per the bug:
|
| 125 |
+
https://github.com/python-pillow/Pillow/issues/3973
|
| 126 |
+
with the function `ImageOps.exif_transpose`. The Pillow source raises errors with
|
| 127 |
+
various methods, especially `tobytes`
|
| 128 |
+
|
| 129 |
+
Function based on:
|
| 130 |
+
https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59
|
| 131 |
+
https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
image (PIL.Image): a PIL image
|
| 135 |
+
|
| 136 |
+
Returns:
|
| 137 |
+
(PIL.Image): the PIL image with exif orientation applied, if applicable
|
| 138 |
+
"""
|
| 139 |
+
if not hasattr(image, "getexif"):
|
| 140 |
+
return image
|
| 141 |
+
|
| 142 |
+
try:
|
| 143 |
+
exif = image.getexif()
|
| 144 |
+
except Exception: # https://github.com/facebookresearch/detectron2/issues/1885
|
| 145 |
+
exif = None
|
| 146 |
+
|
| 147 |
+
if exif is None:
|
| 148 |
+
return image
|
| 149 |
+
|
| 150 |
+
orientation = exif.get(_EXIF_ORIENT)
|
| 151 |
+
|
| 152 |
+
method = {
|
| 153 |
+
2: Image.FLIP_LEFT_RIGHT,
|
| 154 |
+
3: Image.ROTATE_180,
|
| 155 |
+
4: Image.FLIP_TOP_BOTTOM,
|
| 156 |
+
5: Image.TRANSPOSE,
|
| 157 |
+
6: Image.ROTATE_270,
|
| 158 |
+
7: Image.TRANSVERSE,
|
| 159 |
+
8: Image.ROTATE_90,
|
| 160 |
+
}.get(orientation)
|
| 161 |
+
|
| 162 |
+
if method is not None:
|
| 163 |
+
return image.transpose(method)
|
| 164 |
+
return image
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def read_image(file_name, format=None):
|
| 168 |
+
"""
|
| 169 |
+
Read an image into the given format.
|
| 170 |
+
Will apply rotation and flipping if the image has such exif information.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
file_name (str): image file path
|
| 174 |
+
format (str): one of the supported image modes in PIL, or "BGR" or "YUV-BT.601".
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
image (np.ndarray):
|
| 178 |
+
an HWC image in the given format, which is 0-255, uint8 for
|
| 179 |
+
supported image modes in PIL or "BGR"; float (0-1 for Y) for YUV-BT.601.
|
| 180 |
+
"""
|
| 181 |
+
with PathManager.open(file_name, "rb") as f:
|
| 182 |
+
image = Image.open(f)
|
| 183 |
+
|
| 184 |
+
# work around this bug: https://github.com/python-pillow/Pillow/issues/3973
|
| 185 |
+
image = _apply_exif_orientation(image)
|
| 186 |
+
return convert_PIL_to_numpy(image, format)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def check_image_size(dataset_dict, image):
|
| 190 |
+
"""
|
| 191 |
+
Raise an error if the image does not match the size specified in the dict.
|
| 192 |
+
"""
|
| 193 |
+
if "width" in dataset_dict or "height" in dataset_dict:
|
| 194 |
+
image_wh = (image.shape[1], image.shape[0])
|
| 195 |
+
expected_wh = (dataset_dict["width"], dataset_dict["height"])
|
| 196 |
+
if not image_wh == expected_wh:
|
| 197 |
+
expected_wh = (dataset_dict["height"], dataset_dict["width"])
|
| 198 |
+
dataset_dict["height"], dataset_dict["width"] = dataset_dict["width"], dataset_dict["height"]
|
| 199 |
+
if image_wh != expected_wh:
|
| 200 |
+
raise SizeMismatchError(
|
| 201 |
+
"Mismatched image shape{}, got {}, expect {}.".format(
|
| 202 |
+
" for image " + dataset_dict["file_name"]
|
| 203 |
+
if "file_name" in dataset_dict
|
| 204 |
+
else "",
|
| 205 |
+
image_wh,
|
| 206 |
+
expected_wh,
|
| 207 |
+
)
|
| 208 |
+
+ " Please check the width/height in your annotation."
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
# To ensure bbox always remap to original image size
|
| 212 |
+
if "width" not in dataset_dict:
|
| 213 |
+
dataset_dict["width"] = image.shape[1]
|
| 214 |
+
if "height" not in dataset_dict:
|
| 215 |
+
dataset_dict["height"] = image.shape[0]
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def transform_proposals(dataset_dict, image_shape, transforms, *, proposal_topk, min_box_size=0):
|
| 219 |
+
"""
|
| 220 |
+
Apply transformations to the proposals in dataset_dict, if any.
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
dataset_dict (dict): a dict read from the dataset, possibly
|
| 224 |
+
contains fields "proposal_boxes", "proposal_objectness_logits", "proposal_bbox_mode"
|
| 225 |
+
image_shape (tuple): height, width
|
| 226 |
+
transforms (TransformList):
|
| 227 |
+
proposal_topk (int): only keep top-K scoring proposals
|
| 228 |
+
min_box_size (int): proposals with either side smaller than this
|
| 229 |
+
threshold are removed
|
| 230 |
+
|
| 231 |
+
The input dict is modified in-place, with abovementioned keys removed. A new
|
| 232 |
+
key "proposals" will be added. Its value is an `Instances`
|
| 233 |
+
object which contains the transformed proposals in its field
|
| 234 |
+
"proposal_boxes" and "objectness_logits".
|
| 235 |
+
"""
|
| 236 |
+
if "proposal_boxes" in dataset_dict:
|
| 237 |
+
# Transform proposal boxes
|
| 238 |
+
boxes = transforms.apply_box(
|
| 239 |
+
BoxMode.convert(
|
| 240 |
+
dataset_dict.pop("proposal_boxes"),
|
| 241 |
+
dataset_dict.pop("proposal_bbox_mode"),
|
| 242 |
+
BoxMode.XYXY_ABS,
|
| 243 |
+
)
|
| 244 |
+
)
|
| 245 |
+
boxes = Boxes(boxes)
|
| 246 |
+
objectness_logits = torch.as_tensor(
|
| 247 |
+
dataset_dict.pop("proposal_objectness_logits").astype("float32")
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
boxes.clip(image_shape)
|
| 251 |
+
keep = boxes.nonempty(threshold=min_box_size)
|
| 252 |
+
boxes = boxes[keep]
|
| 253 |
+
objectness_logits = objectness_logits[keep]
|
| 254 |
+
|
| 255 |
+
proposals = Instances(image_shape)
|
| 256 |
+
proposals.proposal_boxes = boxes[:proposal_topk]
|
| 257 |
+
proposals.objectness_logits = objectness_logits[:proposal_topk]
|
| 258 |
+
dataset_dict["proposals"] = proposals
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def transform_instance_annotations(
|
| 262 |
+
annotation, transforms, image_size, *, keypoint_hflip_indices=None
|
| 263 |
+
):
|
| 264 |
+
"""
|
| 265 |
+
Apply transforms to box, segmentation and keypoints annotations of a single instance.
|
| 266 |
+
|
| 267 |
+
It will use `transforms.apply_box` for the box, and
|
| 268 |
+
`transforms.apply_coords` for segmentation polygons & keypoints.
|
| 269 |
+
If you need anything more specially designed for each data structure,
|
| 270 |
+
you'll need to implement your own version of this function or the transforms.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
annotation (dict): dict of instance annotations for a single instance.
|
| 274 |
+
It will be modified in-place.
|
| 275 |
+
transforms (TransformList or list[Transform]):
|
| 276 |
+
image_size (tuple): the height, width of the transformed image
|
| 277 |
+
keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
|
| 278 |
+
|
| 279 |
+
Returns:
|
| 280 |
+
dict:
|
| 281 |
+
the same input dict with fields "bbox", "segmentation", "keypoints"
|
| 282 |
+
transformed according to `transforms`.
|
| 283 |
+
The "bbox_mode" field will be set to XYXY_ABS.
|
| 284 |
+
"""
|
| 285 |
+
if isinstance(transforms, (tuple, list)):
|
| 286 |
+
transforms = T.TransformList(transforms)
|
| 287 |
+
# bbox is 1d (per-instance bounding box)
|
| 288 |
+
bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
|
| 289 |
+
# clip transformed bbox to image size
|
| 290 |
+
bbox = transforms.apply_box(np.array([bbox]))[0].clip(min=0)
|
| 291 |
+
annotation["bbox"] = np.minimum(bbox, list(image_size + image_size)[::-1])
|
| 292 |
+
annotation["bbox_mode"] = BoxMode.XYXY_ABS
|
| 293 |
+
|
| 294 |
+
if "segmentation" in annotation:
|
| 295 |
+
# each instance contains 1 or more polygons
|
| 296 |
+
segm = annotation["segmentation"]
|
| 297 |
+
if isinstance(segm, list):
|
| 298 |
+
# polygons
|
| 299 |
+
polygons = [np.asarray(p).reshape(-1, 2) for p in segm]
|
| 300 |
+
annotation["segmentation"] = [
|
| 301 |
+
p.reshape(-1) for p in transforms.apply_polygons(polygons)
|
| 302 |
+
]
|
| 303 |
+
elif isinstance(segm, dict):
|
| 304 |
+
# RLE
|
| 305 |
+
mask = mask_util.decode(segm)
|
| 306 |
+
mask = transforms.apply_segmentation(mask)
|
| 307 |
+
assert tuple(mask.shape[:2]) == image_size
|
| 308 |
+
annotation["segmentation"] = mask
|
| 309 |
+
else:
|
| 310 |
+
raise ValueError(
|
| 311 |
+
"Cannot transform segmentation of type '{}'!"
|
| 312 |
+
"Supported types are: polygons as list[list[float] or ndarray],"
|
| 313 |
+
" COCO-style RLE as a dict.".format(type(segm))
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
if "keypoints" in annotation:
|
| 317 |
+
keypoints = transform_keypoint_annotations(
|
| 318 |
+
annotation["keypoints"], transforms, image_size, keypoint_hflip_indices
|
| 319 |
+
)
|
| 320 |
+
annotation["keypoints"] = keypoints
|
| 321 |
+
|
| 322 |
+
return annotation
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def transform_keypoint_annotations(keypoints, transforms, image_size, keypoint_hflip_indices=None):
|
| 326 |
+
"""
|
| 327 |
+
Transform keypoint annotations of an image.
|
| 328 |
+
If a keypoint is transformed out of image boundary, it will be marked "unlabeled" (visibility=0)
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
keypoints (list[float]): Nx3 float in Detectron2's Dataset format.
|
| 332 |
+
Each point is represented by (x, y, visibility).
|
| 333 |
+
transforms (TransformList):
|
| 334 |
+
image_size (tuple): the height, width of the transformed image
|
| 335 |
+
keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
|
| 336 |
+
When `transforms` includes horizontal flip, will use the index
|
| 337 |
+
mapping to flip keypoints.
|
| 338 |
+
"""
|
| 339 |
+
# (N*3,) -> (N, 3)
|
| 340 |
+
keypoints = np.asarray(keypoints, dtype="float64").reshape(-1, 3)
|
| 341 |
+
keypoints_xy = transforms.apply_coords(keypoints[:, :2])
|
| 342 |
+
|
| 343 |
+
# Set all out-of-boundary points to "unlabeled"
|
| 344 |
+
inside = (keypoints_xy >= np.array([0, 0])) & (keypoints_xy <= np.array(image_size[::-1]))
|
| 345 |
+
inside = inside.all(axis=1)
|
| 346 |
+
keypoints[:, :2] = keypoints_xy
|
| 347 |
+
keypoints[:, 2][~inside] = 0
|
| 348 |
+
|
| 349 |
+
# This assumes that HorizFlipTransform is the only one that does flip
|
| 350 |
+
do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
|
| 351 |
+
|
| 352 |
+
# Alternative way: check if probe points was horizontally flipped.
|
| 353 |
+
# probe = np.asarray([[0.0, 0.0], [image_width, 0.0]])
|
| 354 |
+
# probe_aug = transforms.apply_coords(probe.copy())
|
| 355 |
+
# do_hflip = np.sign(probe[1][0] - probe[0][0]) != np.sign(probe_aug[1][0] - probe_aug[0][0]) # noqa
|
| 356 |
+
|
| 357 |
+
# If flipped, swap each keypoint with its opposite-handed equivalent
|
| 358 |
+
if do_hflip:
|
| 359 |
+
if keypoint_hflip_indices is None:
|
| 360 |
+
raise ValueError("Cannot flip keypoints without providing flip indices!")
|
| 361 |
+
if len(keypoints) != len(keypoint_hflip_indices):
|
| 362 |
+
raise ValueError(
|
| 363 |
+
"Keypoint data has {} points, but metadata "
|
| 364 |
+
"contains {} points!".format(len(keypoints), len(keypoint_hflip_indices))
|
| 365 |
+
)
|
| 366 |
+
keypoints = keypoints[np.asarray(keypoint_hflip_indices, dtype=np.int32), :]
|
| 367 |
+
|
| 368 |
+
# Maintain COCO convention that if visibility == 0 (unlabeled), then x, y = 0
|
| 369 |
+
keypoints[keypoints[:, 2] == 0] = 0
|
| 370 |
+
return keypoints
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def annotations_to_instances(annos, image_size, mask_format="polygon"):
|
| 374 |
+
"""
|
| 375 |
+
Create an :class:`Instances` object used by the models,
|
| 376 |
+
from instance annotations in the dataset dict.
|
| 377 |
+
|
| 378 |
+
Args:
|
| 379 |
+
annos (list[dict]): a list of instance annotations in one image, each
|
| 380 |
+
element for one instance.
|
| 381 |
+
image_size (tuple): height, width
|
| 382 |
+
|
| 383 |
+
Returns:
|
| 384 |
+
Instances:
|
| 385 |
+
It will contain fields "gt_boxes", "gt_classes",
|
| 386 |
+
"gt_masks", "gt_keypoints", if they can be obtained from `annos`.
|
| 387 |
+
This is the format that builtin models expect.
|
| 388 |
+
"""
|
| 389 |
+
boxes = (
|
| 390 |
+
np.stack(
|
| 391 |
+
[BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos]
|
| 392 |
+
)
|
| 393 |
+
if len(annos)
|
| 394 |
+
else np.zeros((0, 4))
|
| 395 |
+
)
|
| 396 |
+
target = Instances(image_size)
|
| 397 |
+
target.gt_boxes = Boxes(boxes)
|
| 398 |
+
|
| 399 |
+
classes = [int(obj["category_id"]) for obj in annos]
|
| 400 |
+
classes = torch.tensor(classes, dtype=torch.int64)
|
| 401 |
+
target.gt_classes = classes
|
| 402 |
+
|
| 403 |
+
if len(annos) and "segmentation" in annos[0]:
|
| 404 |
+
segms = [obj["segmentation"] for obj in annos]
|
| 405 |
+
if mask_format == "polygon":
|
| 406 |
+
try:
|
| 407 |
+
masks = PolygonMasks(segms)
|
| 408 |
+
except ValueError as e:
|
| 409 |
+
raise ValueError(
|
| 410 |
+
"Failed to use mask_format=='polygon' from the given annotations!"
|
| 411 |
+
) from e
|
| 412 |
+
else:
|
| 413 |
+
assert mask_format == "bitmask", mask_format
|
| 414 |
+
masks = []
|
| 415 |
+
for segm in segms:
|
| 416 |
+
if isinstance(segm, list):
|
| 417 |
+
# polygon
|
| 418 |
+
masks.append(polygons_to_bitmask(segm, *image_size))
|
| 419 |
+
elif isinstance(segm, dict):
|
| 420 |
+
# COCO RLE
|
| 421 |
+
masks.append(mask_util.decode(segm))
|
| 422 |
+
elif isinstance(segm, np.ndarray):
|
| 423 |
+
assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format(
|
| 424 |
+
segm.ndim
|
| 425 |
+
)
|
| 426 |
+
# mask array
|
| 427 |
+
masks.append(segm)
|
| 428 |
+
else:
|
| 429 |
+
raise ValueError(
|
| 430 |
+
"Cannot convert segmentation of type '{}' to BitMasks!"
|
| 431 |
+
"Supported types are: polygons as list[list[float] or ndarray],"
|
| 432 |
+
" COCO-style RLE as a dict, or a binary segmentation mask "
|
| 433 |
+
" in a 2D numpy array of shape HxW.".format(type(segm))
|
| 434 |
+
)
|
| 435 |
+
# torch.from_numpy does not support array with negative stride.
|
| 436 |
+
masks = BitMasks(
|
| 437 |
+
torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
|
| 438 |
+
)
|
| 439 |
+
target.gt_masks = masks
|
| 440 |
+
|
| 441 |
+
if len(annos) and "keypoints" in annos[0]:
|
| 442 |
+
kpts = [obj.get("keypoints", []) for obj in annos]
|
| 443 |
+
target.gt_keypoints = Keypoints(kpts)
|
| 444 |
+
|
| 445 |
+
return target
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def annotations_to_instances_rotated(annos, image_size):
|
| 449 |
+
"""
|
| 450 |
+
Create an :class:`Instances` object used by the models,
|
| 451 |
+
from instance annotations in the dataset dict.
|
| 452 |
+
Compared to `annotations_to_instances`, this function is for rotated boxes only
|
| 453 |
+
|
| 454 |
+
Args:
|
| 455 |
+
annos (list[dict]): a list of instance annotations in one image, each
|
| 456 |
+
element for one instance.
|
| 457 |
+
image_size (tuple): height, width
|
| 458 |
+
|
| 459 |
+
Returns:
|
| 460 |
+
Instances:
|
| 461 |
+
Containing fields "gt_boxes", "gt_classes",
|
| 462 |
+
if they can be obtained from `annos`.
|
| 463 |
+
This is the format that builtin models expect.
|
| 464 |
+
"""
|
| 465 |
+
boxes = [obj["bbox"] for obj in annos]
|
| 466 |
+
target = Instances(image_size)
|
| 467 |
+
boxes = target.gt_boxes = RotatedBoxes(boxes)
|
| 468 |
+
boxes.clip(image_size)
|
| 469 |
+
|
| 470 |
+
classes = [obj["category_id"] for obj in annos]
|
| 471 |
+
classes = torch.tensor(classes, dtype=torch.int64)
|
| 472 |
+
target.gt_classes = classes
|
| 473 |
+
|
| 474 |
+
return target
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def filter_empty_instances(
|
| 478 |
+
instances, by_box=True, by_mask=True, box_threshold=1e-5, return_mask=False
|
| 479 |
+
):
|
| 480 |
+
"""
|
| 481 |
+
Filter out empty instances in an `Instances` object.
|
| 482 |
+
|
| 483 |
+
Args:
|
| 484 |
+
instances (Instances):
|
| 485 |
+
by_box (bool): whether to filter out instances with empty boxes
|
| 486 |
+
by_mask (bool): whether to filter out instances with empty masks
|
| 487 |
+
box_threshold (float): minimum width and height to be considered non-empty
|
| 488 |
+
return_mask (bool): whether to return boolean mask of filtered instances
|
| 489 |
+
|
| 490 |
+
Returns:
|
| 491 |
+
Instances: the filtered instances.
|
| 492 |
+
tensor[bool], optional: boolean mask of filtered instances
|
| 493 |
+
"""
|
| 494 |
+
assert by_box or by_mask
|
| 495 |
+
r = []
|
| 496 |
+
if by_box:
|
| 497 |
+
r.append(instances.gt_boxes.nonempty(threshold=box_threshold))
|
| 498 |
+
if instances.has("gt_masks") and by_mask:
|
| 499 |
+
r.append(instances.gt_masks.nonempty())
|
| 500 |
+
|
| 501 |
+
# TODO: can also filter visible keypoints
|
| 502 |
+
|
| 503 |
+
if not r:
|
| 504 |
+
return instances
|
| 505 |
+
m = r[0]
|
| 506 |
+
for x in r[1:]:
|
| 507 |
+
m = m & x
|
| 508 |
+
if return_mask:
|
| 509 |
+
return instances[m], m
|
| 510 |
+
return instances[m]
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def create_keypoint_hflip_indices(dataset_names: Union[str, List[str]]) -> List[int]:
|
| 514 |
+
"""
|
| 515 |
+
Args:
|
| 516 |
+
dataset_names: list of dataset names
|
| 517 |
+
|
| 518 |
+
Returns:
|
| 519 |
+
list[int]: a list of size=#keypoints, storing the
|
| 520 |
+
horizontally-flipped keypoint indices.
|
| 521 |
+
"""
|
| 522 |
+
if isinstance(dataset_names, str):
|
| 523 |
+
dataset_names = [dataset_names]
|
| 524 |
+
|
| 525 |
+
check_metadata_consistency("keypoint_names", dataset_names)
|
| 526 |
+
check_metadata_consistency("keypoint_flip_map", dataset_names)
|
| 527 |
+
|
| 528 |
+
meta = MetadataCatalog.get(dataset_names[0])
|
| 529 |
+
names = meta.keypoint_names
|
| 530 |
+
# TODO flip -> hflip
|
| 531 |
+
flip_map = dict(meta.keypoint_flip_map)
|
| 532 |
+
flip_map.update({v: k for k, v in flip_map.items()})
|
| 533 |
+
flipped_names = [i if i not in flip_map else flip_map[i] for i in names]
|
| 534 |
+
flip_indices = [names.index(i) for i in flipped_names]
|
| 535 |
+
return flip_indices
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def get_fed_loss_cls_weights(dataset_names: Union[str, List[str]], freq_weight_power=1.0):
|
| 539 |
+
"""
|
| 540 |
+
Get frequency weight for each class sorted by class id.
|
| 541 |
+
We now calcualte freqency weight using image_count to the power freq_weight_power.
|
| 542 |
+
|
| 543 |
+
Args:
|
| 544 |
+
dataset_names: list of dataset names
|
| 545 |
+
freq_weight_power: power value
|
| 546 |
+
"""
|
| 547 |
+
if isinstance(dataset_names, str):
|
| 548 |
+
dataset_names = [dataset_names]
|
| 549 |
+
|
| 550 |
+
check_metadata_consistency("class_image_count", dataset_names)
|
| 551 |
+
|
| 552 |
+
meta = MetadataCatalog.get(dataset_names[0])
|
| 553 |
+
class_freq_meta = meta.class_image_count
|
| 554 |
+
class_freq = torch.tensor(
|
| 555 |
+
[c["image_count"] for c in sorted(class_freq_meta, key=lambda x: x["id"])]
|
| 556 |
+
)
|
| 557 |
+
class_freq_weight = class_freq.float() ** freq_weight_power
|
| 558 |
+
return class_freq_weight
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
def gen_crop_transform_with_instance(crop_size, image_size, instance):
|
| 562 |
+
"""
|
| 563 |
+
Generate a CropTransform so that the cropping region contains
|
| 564 |
+
the center of the given instance.
|
| 565 |
+
|
| 566 |
+
Args:
|
| 567 |
+
crop_size (tuple): h, w in pixels
|
| 568 |
+
image_size (tuple): h, w
|
| 569 |
+
instance (dict): an annotation dict of one instance, in Detectron2's
|
| 570 |
+
dataset format.
|
| 571 |
+
"""
|
| 572 |
+
crop_size = np.asarray(crop_size, dtype=np.int32)
|
| 573 |
+
bbox = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS)
|
| 574 |
+
center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5
|
| 575 |
+
assert (
|
| 576 |
+
image_size[0] >= center_yx[0] and image_size[1] >= center_yx[1]
|
| 577 |
+
), "The annotation bounding box is outside of the image!"
|
| 578 |
+
assert (
|
| 579 |
+
image_size[0] >= crop_size[0] and image_size[1] >= crop_size[1]
|
| 580 |
+
), "Crop size is larger than image size!"
|
| 581 |
+
|
| 582 |
+
min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0)
|
| 583 |
+
max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0)
|
| 584 |
+
max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32))
|
| 585 |
+
|
| 586 |
+
y0 = np.random.randint(min_yx[0], max_yx[0] + 1)
|
| 587 |
+
x0 = np.random.randint(min_yx[1], max_yx[1] + 1)
|
| 588 |
+
return T.CropTransform(x0, y0, crop_size[1], crop_size[0])
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
def check_metadata_consistency(key, dataset_names):
|
| 592 |
+
"""
|
| 593 |
+
Check that the datasets have consistent metadata.
|
| 594 |
+
|
| 595 |
+
Args:
|
| 596 |
+
key (str): a metadata key
|
| 597 |
+
dataset_names (list[str]): a list of dataset names
|
| 598 |
+
|
| 599 |
+
Raises:
|
| 600 |
+
AttributeError: if the key does not exist in the metadata
|
| 601 |
+
ValueError: if the given datasets do not have the same metadata values defined by key
|
| 602 |
+
"""
|
| 603 |
+
if len(dataset_names) == 0:
|
| 604 |
+
return
|
| 605 |
+
logger = logging.getLogger(__name__)
|
| 606 |
+
entries_per_dataset = [getattr(MetadataCatalog.get(d), key) for d in dataset_names]
|
| 607 |
+
for idx, entry in enumerate(entries_per_dataset):
|
| 608 |
+
if entry != entries_per_dataset[0]:
|
| 609 |
+
logger.error(
|
| 610 |
+
"Metadata '{}' for dataset '{}' is '{}'".format(key, dataset_names[idx], str(entry))
|
| 611 |
+
)
|
| 612 |
+
logger.error(
|
| 613 |
+
"Metadata '{}' for dataset '{}' is '{}'".format(
|
| 614 |
+
key, dataset_names[0], str(entries_per_dataset[0])
|
| 615 |
+
)
|
| 616 |
+
)
|
| 617 |
+
raise ValueError("Datasets have different metadata '{}'!".format(key))
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def build_augmentation(cfg, is_train):
|
| 621 |
+
"""
|
| 622 |
+
Create a list of default :class:`Augmentation` from config.
|
| 623 |
+
Now it includes resizing and flipping.
|
| 624 |
+
|
| 625 |
+
Returns:
|
| 626 |
+
list[Augmentation]
|
| 627 |
+
"""
|
| 628 |
+
if is_train:
|
| 629 |
+
min_size = cfg.INPUT.MIN_SIZE_TRAIN
|
| 630 |
+
max_size = cfg.INPUT.MAX_SIZE_TRAIN
|
| 631 |
+
sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
|
| 632 |
+
else:
|
| 633 |
+
min_size = cfg.INPUT.MIN_SIZE_TEST
|
| 634 |
+
max_size = cfg.INPUT.MAX_SIZE_TEST
|
| 635 |
+
sample_style = "choice"
|
| 636 |
+
augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
|
| 637 |
+
if is_train and cfg.INPUT.RANDOM_FLIP != "none":
|
| 638 |
+
augmentation.append(
|
| 639 |
+
T.RandomFlip(
|
| 640 |
+
horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
|
| 641 |
+
vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
|
| 642 |
+
)
|
| 643 |
+
)
|
| 644 |
+
return augmentation
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
build_transform_gen = build_augmentation
|
| 648 |
+
"""
|
| 649 |
+
Alias for backward-compatibility.
|
| 650 |
+
"""
|
cutler/data/transforms/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/__init__.py
|
| 3 |
+
|
| 4 |
+
from fvcore.transforms.transform import *
|
| 5 |
+
from .transform import *
|
| 6 |
+
from detectron2.data.transforms.augmentation import *
|
| 7 |
+
from .augmentation_impl import *
|
| 8 |
+
|
| 9 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from detectron2.utils.env import fixup_module_metadata
|
| 13 |
+
|
| 14 |
+
fixup_module_metadata(__name__, globals(), __all__)
|
| 15 |
+
del fixup_module_metadata
|
cutler/data/transforms/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (728 Bytes). View file
|
|
|
cutler/data/transforms/__pycache__/augmentation_impl.cpython-312.pyc
ADDED
|
Binary file (32.9 kB). View file
|
|
|
cutler/data/transforms/__pycache__/transform.cpython-312.pyc
ADDED
|
Binary file (19.7 kB). View file
|
|
|
cutler/data/transforms/augmentation_impl.py
ADDED
|
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 3 |
+
# Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
Implement many useful :class:`Augmentation`.
|
| 7 |
+
"""
|
| 8 |
+
import numpy as np
|
| 9 |
+
import sys
|
| 10 |
+
from typing import Tuple
|
| 11 |
+
import torch
|
| 12 |
+
from fvcore.transforms.transform import (
|
| 13 |
+
BlendTransform,
|
| 14 |
+
CropTransform,
|
| 15 |
+
HFlipTransform,
|
| 16 |
+
NoOpTransform,
|
| 17 |
+
PadTransform,
|
| 18 |
+
Transform,
|
| 19 |
+
TransformList,
|
| 20 |
+
VFlipTransform,
|
| 21 |
+
)
|
| 22 |
+
from PIL import Image
|
| 23 |
+
|
| 24 |
+
from detectron2.data.transforms.augmentation import Augmentation, _transform_to_aug
|
| 25 |
+
from .transform import ExtentTransform, ResizeTransform, RotationTransform
|
| 26 |
+
|
| 27 |
+
__all__ = [
|
| 28 |
+
"FixedSizeCrop",
|
| 29 |
+
"RandomApply",
|
| 30 |
+
"RandomBrightness",
|
| 31 |
+
"RandomContrast",
|
| 32 |
+
"RandomCrop",
|
| 33 |
+
"RandomExtent",
|
| 34 |
+
"RandomFlip",
|
| 35 |
+
"RandomSaturation",
|
| 36 |
+
"RandomLighting",
|
| 37 |
+
"RandomRotation",
|
| 38 |
+
"Resize",
|
| 39 |
+
"ResizeScale",
|
| 40 |
+
"ResizeShortestEdge",
|
| 41 |
+
"RandomCrop_CategoryAreaConstraint",
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class RandomApply(Augmentation):
|
| 46 |
+
"""
|
| 47 |
+
Randomly apply an augmentation with a given probability.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(self, tfm_or_aug, prob=0.5):
|
| 51 |
+
"""
|
| 52 |
+
Args:
|
| 53 |
+
tfm_or_aug (Transform, Augmentation): the transform or augmentation
|
| 54 |
+
to be applied. It can either be a `Transform` or `Augmentation`
|
| 55 |
+
instance.
|
| 56 |
+
prob (float): probability between 0.0 and 1.0 that
|
| 57 |
+
the wrapper transformation is applied
|
| 58 |
+
"""
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.aug = _transform_to_aug(tfm_or_aug)
|
| 61 |
+
assert 0.0 <= prob <= 1.0, f"Probablity must be between 0.0 and 1.0 (given: {prob})"
|
| 62 |
+
self.prob = prob
|
| 63 |
+
|
| 64 |
+
def get_transform(self, *args):
|
| 65 |
+
do = self._rand_range() < self.prob
|
| 66 |
+
if do:
|
| 67 |
+
return self.aug.get_transform(*args)
|
| 68 |
+
else:
|
| 69 |
+
return NoOpTransform()
|
| 70 |
+
|
| 71 |
+
def __call__(self, aug_input):
|
| 72 |
+
do = self._rand_range() < self.prob
|
| 73 |
+
if do:
|
| 74 |
+
return self.aug(aug_input)
|
| 75 |
+
else:
|
| 76 |
+
return NoOpTransform()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class RandomFlip(Augmentation):
|
| 80 |
+
"""
|
| 81 |
+
Flip the image horizontally or vertically with the given probability.
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
def __init__(self, prob=0.5, *, horizontal=True, vertical=False):
|
| 85 |
+
"""
|
| 86 |
+
Args:
|
| 87 |
+
prob (float): probability of flip.
|
| 88 |
+
horizontal (boolean): whether to apply horizontal flipping
|
| 89 |
+
vertical (boolean): whether to apply vertical flipping
|
| 90 |
+
"""
|
| 91 |
+
super().__init__()
|
| 92 |
+
|
| 93 |
+
if horizontal and vertical:
|
| 94 |
+
raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
|
| 95 |
+
if not horizontal and not vertical:
|
| 96 |
+
raise ValueError("At least one of horiz or vert has to be True!")
|
| 97 |
+
self._init(locals())
|
| 98 |
+
|
| 99 |
+
def get_transform(self, image):
|
| 100 |
+
h, w = image.shape[:2]
|
| 101 |
+
do = self._rand_range() < self.prob
|
| 102 |
+
if do:
|
| 103 |
+
if self.horizontal:
|
| 104 |
+
return HFlipTransform(w)
|
| 105 |
+
elif self.vertical:
|
| 106 |
+
return VFlipTransform(h)
|
| 107 |
+
else:
|
| 108 |
+
return NoOpTransform()
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class Resize(Augmentation):
|
| 112 |
+
"""Resize image to a fixed target size"""
|
| 113 |
+
|
| 114 |
+
def __init__(self, shape, interp=Image.BILINEAR):
|
| 115 |
+
"""
|
| 116 |
+
Args:
|
| 117 |
+
shape: (h, w) tuple or a int
|
| 118 |
+
interp: PIL interpolation method
|
| 119 |
+
"""
|
| 120 |
+
if isinstance(shape, int):
|
| 121 |
+
shape = (shape, shape)
|
| 122 |
+
shape = tuple(shape)
|
| 123 |
+
self._init(locals())
|
| 124 |
+
|
| 125 |
+
def get_transform(self, image):
|
| 126 |
+
return ResizeTransform(
|
| 127 |
+
image.shape[0], image.shape[1], self.shape[0], self.shape[1], self.interp
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class ResizeShortestEdge(Augmentation):
|
| 132 |
+
"""
|
| 133 |
+
Resize the image while keeping the aspect ratio unchanged.
|
| 134 |
+
It attempts to scale the shorter edge to the given `short_edge_length`,
|
| 135 |
+
as long as the longer edge does not exceed `max_size`.
|
| 136 |
+
If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
@torch.jit.unused
|
| 140 |
+
def __init__(
|
| 141 |
+
self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR
|
| 142 |
+
):
|
| 143 |
+
"""
|
| 144 |
+
Args:
|
| 145 |
+
short_edge_length (list[int]): If ``sample_style=="range"``,
|
| 146 |
+
a [min, max] interval from which to sample the shortest edge length.
|
| 147 |
+
If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
|
| 148 |
+
max_size (int): maximum allowed longest edge length.
|
| 149 |
+
sample_style (str): either "range" or "choice".
|
| 150 |
+
"""
|
| 151 |
+
super().__init__()
|
| 152 |
+
assert sample_style in ["range", "choice"], sample_style
|
| 153 |
+
|
| 154 |
+
self.is_range = sample_style == "range"
|
| 155 |
+
if isinstance(short_edge_length, int):
|
| 156 |
+
short_edge_length = (short_edge_length, short_edge_length)
|
| 157 |
+
if self.is_range:
|
| 158 |
+
assert len(short_edge_length) == 2, (
|
| 159 |
+
"short_edge_length must be two values using 'range' sample style."
|
| 160 |
+
f" Got {short_edge_length}!"
|
| 161 |
+
)
|
| 162 |
+
self._init(locals())
|
| 163 |
+
|
| 164 |
+
@torch.jit.unused
|
| 165 |
+
def get_transform(self, image):
|
| 166 |
+
h, w = image.shape[:2]
|
| 167 |
+
if self.is_range:
|
| 168 |
+
size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
|
| 169 |
+
else:
|
| 170 |
+
size = np.random.choice(self.short_edge_length)
|
| 171 |
+
if size == 0:
|
| 172 |
+
return NoOpTransform()
|
| 173 |
+
|
| 174 |
+
newh, neww = ResizeShortestEdge.get_output_shape(h, w, size, self.max_size)
|
| 175 |
+
return ResizeTransform(h, w, newh, neww, self.interp)
|
| 176 |
+
|
| 177 |
+
@staticmethod
|
| 178 |
+
def get_output_shape(
|
| 179 |
+
oldh: int, oldw: int, short_edge_length: int, max_size: int
|
| 180 |
+
) -> Tuple[int, int]:
|
| 181 |
+
"""
|
| 182 |
+
Compute the output size given input size and target short edge length.
|
| 183 |
+
"""
|
| 184 |
+
h, w = oldh, oldw
|
| 185 |
+
size = short_edge_length * 1.0
|
| 186 |
+
scale = size / min(h, w)
|
| 187 |
+
if h < w:
|
| 188 |
+
newh, neww = size, scale * w
|
| 189 |
+
else:
|
| 190 |
+
newh, neww = scale * h, size
|
| 191 |
+
if max(newh, neww) > max_size:
|
| 192 |
+
scale = max_size * 1.0 / max(newh, neww)
|
| 193 |
+
newh = newh * scale
|
| 194 |
+
neww = neww * scale
|
| 195 |
+
neww = int(neww + 0.5)
|
| 196 |
+
newh = int(newh + 0.5)
|
| 197 |
+
return (newh, neww)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class ResizeScale(Augmentation):
|
| 201 |
+
"""
|
| 202 |
+
Takes target size as input and randomly scales the given target size between `min_scale`
|
| 203 |
+
and `max_scale`. It then scales the input image such that it fits inside the scaled target
|
| 204 |
+
box, keeping the aspect ratio constant.
|
| 205 |
+
This implements the resize part of the Google's 'resize_and_crop' data augmentation:
|
| 206 |
+
https://github.com/tensorflow/tpu/blob/master/models/official/detection/utils/input_utils.py#L127
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
def __init__(
|
| 210 |
+
self,
|
| 211 |
+
min_scale: float,
|
| 212 |
+
max_scale: float,
|
| 213 |
+
target_height: int,
|
| 214 |
+
target_width: int,
|
| 215 |
+
interp: int = Image.BILINEAR,
|
| 216 |
+
):
|
| 217 |
+
"""
|
| 218 |
+
Args:
|
| 219 |
+
min_scale: minimum image scale range.
|
| 220 |
+
max_scale: maximum image scale range.
|
| 221 |
+
target_height: target image height.
|
| 222 |
+
target_width: target image width.
|
| 223 |
+
interp: image interpolation method.
|
| 224 |
+
"""
|
| 225 |
+
super().__init__()
|
| 226 |
+
self._init(locals())
|
| 227 |
+
|
| 228 |
+
def _get_resize(self, image: np.ndarray, scale: float) -> Transform:
|
| 229 |
+
input_size = image.shape[:2]
|
| 230 |
+
|
| 231 |
+
# Compute new target size given a scale.
|
| 232 |
+
target_size = (self.target_height, self.target_width)
|
| 233 |
+
target_scale_size = np.multiply(target_size, scale)
|
| 234 |
+
|
| 235 |
+
# Compute actual rescaling applied to input image and output size.
|
| 236 |
+
output_scale = np.minimum(
|
| 237 |
+
target_scale_size[0] / input_size[0], target_scale_size[1] / input_size[1]
|
| 238 |
+
)
|
| 239 |
+
output_size = np.round(np.multiply(input_size, output_scale)).astype(int)
|
| 240 |
+
|
| 241 |
+
return ResizeTransform(
|
| 242 |
+
input_size[0], input_size[1], output_size[0], output_size[1], self.interp
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
def get_transform(self, image: np.ndarray) -> Transform:
|
| 246 |
+
random_scale = np.random.uniform(self.min_scale, self.max_scale)
|
| 247 |
+
return self._get_resize(image, random_scale)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class RandomRotation(Augmentation):
|
| 251 |
+
"""
|
| 252 |
+
This method returns a copy of this image, rotated the given
|
| 253 |
+
number of degrees counter clockwise around the given center.
|
| 254 |
+
"""
|
| 255 |
+
|
| 256 |
+
def __init__(self, angle, expand=True, center=None, sample_style="range", interp=None):
|
| 257 |
+
"""
|
| 258 |
+
Args:
|
| 259 |
+
angle (list[float]): If ``sample_style=="range"``,
|
| 260 |
+
a [min, max] interval from which to sample the angle (in degrees).
|
| 261 |
+
If ``sample_style=="choice"``, a list of angles to sample from
|
| 262 |
+
expand (bool): choose if the image should be resized to fit the whole
|
| 263 |
+
rotated image (default), or simply cropped
|
| 264 |
+
center (list[[float, float]]): If ``sample_style=="range"``,
|
| 265 |
+
a [[minx, miny], [maxx, maxy]] relative interval from which to sample the center,
|
| 266 |
+
[0, 0] being the top left of the image and [1, 1] the bottom right.
|
| 267 |
+
If ``sample_style=="choice"``, a list of centers to sample from
|
| 268 |
+
Default: None, which means that the center of rotation is the center of the image
|
| 269 |
+
center has no effect if expand=True because it only affects shifting
|
| 270 |
+
"""
|
| 271 |
+
super().__init__()
|
| 272 |
+
assert sample_style in ["range", "choice"], sample_style
|
| 273 |
+
self.is_range = sample_style == "range"
|
| 274 |
+
if isinstance(angle, (float, int)):
|
| 275 |
+
angle = (angle, angle)
|
| 276 |
+
if center is not None and isinstance(center[0], (float, int)):
|
| 277 |
+
center = (center, center)
|
| 278 |
+
self._init(locals())
|
| 279 |
+
|
| 280 |
+
def get_transform(self, image):
|
| 281 |
+
h, w = image.shape[:2]
|
| 282 |
+
center = None
|
| 283 |
+
if self.is_range:
|
| 284 |
+
angle = np.random.uniform(self.angle[0], self.angle[1])
|
| 285 |
+
if self.center is not None:
|
| 286 |
+
center = (
|
| 287 |
+
np.random.uniform(self.center[0][0], self.center[1][0]),
|
| 288 |
+
np.random.uniform(self.center[0][1], self.center[1][1]),
|
| 289 |
+
)
|
| 290 |
+
else:
|
| 291 |
+
angle = np.random.choice(self.angle)
|
| 292 |
+
if self.center is not None:
|
| 293 |
+
center = np.random.choice(self.center)
|
| 294 |
+
|
| 295 |
+
if center is not None:
|
| 296 |
+
center = (w * center[0], h * center[1]) # Convert to absolute coordinates
|
| 297 |
+
|
| 298 |
+
if angle % 360 == 0:
|
| 299 |
+
return NoOpTransform()
|
| 300 |
+
|
| 301 |
+
return RotationTransform(h, w, angle, expand=self.expand, center=center, interp=self.interp)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class FixedSizeCrop(Augmentation):
|
| 305 |
+
"""
|
| 306 |
+
If `crop_size` is smaller than the input image size, then it uses a random crop of
|
| 307 |
+
the crop size. If `crop_size` is larger than the input image size, then it pads
|
| 308 |
+
the right and the bottom of the image to the crop size if `pad` is True, otherwise
|
| 309 |
+
it returns the smaller image.
|
| 310 |
+
"""
|
| 311 |
+
|
| 312 |
+
def __init__(self, crop_size: Tuple[int], pad: bool = True, pad_value: float = 128.0):
|
| 313 |
+
"""
|
| 314 |
+
Args:
|
| 315 |
+
crop_size: target image (height, width).
|
| 316 |
+
pad: if True, will pad images smaller than `crop_size` up to `crop_size`
|
| 317 |
+
pad_value: the padding value.
|
| 318 |
+
"""
|
| 319 |
+
super().__init__()
|
| 320 |
+
self._init(locals())
|
| 321 |
+
|
| 322 |
+
def _get_crop(self, image: np.ndarray) -> Transform:
|
| 323 |
+
# Compute the image scale and scaled size.
|
| 324 |
+
input_size = image.shape[:2]
|
| 325 |
+
output_size = self.crop_size
|
| 326 |
+
|
| 327 |
+
# Add random crop if the image is scaled up.
|
| 328 |
+
max_offset = np.subtract(input_size, output_size)
|
| 329 |
+
max_offset = np.maximum(max_offset, 0)
|
| 330 |
+
offset = np.multiply(max_offset, np.random.uniform(0.0, 1.0))
|
| 331 |
+
offset = np.round(offset).astype(int)
|
| 332 |
+
return CropTransform(
|
| 333 |
+
offset[1], offset[0], output_size[1], output_size[0], input_size[1], input_size[0]
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
def _get_pad(self, image: np.ndarray) -> Transform:
|
| 337 |
+
# Compute the image scale and scaled size.
|
| 338 |
+
input_size = image.shape[:2]
|
| 339 |
+
output_size = self.crop_size
|
| 340 |
+
|
| 341 |
+
# Add padding if the image is scaled down.
|
| 342 |
+
pad_size = np.subtract(output_size, input_size)
|
| 343 |
+
pad_size = np.maximum(pad_size, 0)
|
| 344 |
+
original_size = np.minimum(input_size, output_size)
|
| 345 |
+
return PadTransform(
|
| 346 |
+
0, 0, pad_size[1], pad_size[0], original_size[1], original_size[0], self.pad_value
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
def get_transform(self, image: np.ndarray) -> TransformList:
|
| 350 |
+
transforms = [self._get_crop(image)]
|
| 351 |
+
if self.pad:
|
| 352 |
+
transforms.append(self._get_pad(image))
|
| 353 |
+
return TransformList(transforms)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class RandomCrop(Augmentation):
|
| 357 |
+
"""
|
| 358 |
+
Randomly crop a rectangle region out of an image.
|
| 359 |
+
"""
|
| 360 |
+
|
| 361 |
+
def __init__(self, crop_type: str, crop_size):
|
| 362 |
+
"""
|
| 363 |
+
Args:
|
| 364 |
+
crop_type (str): one of "relative_range", "relative", "absolute", "absolute_range".
|
| 365 |
+
crop_size (tuple[float, float]): two floats, explained below.
|
| 366 |
+
|
| 367 |
+
- "relative": crop a (H * crop_size[0], W * crop_size[1]) region from an input image of
|
| 368 |
+
size (H, W). crop size should be in (0, 1]
|
| 369 |
+
- "relative_range": uniformly sample two values from [crop_size[0], 1]
|
| 370 |
+
and [crop_size[1]], 1], and use them as in "relative" crop type.
|
| 371 |
+
- "absolute" crop a (crop_size[0], crop_size[1]) region from input image.
|
| 372 |
+
crop_size must be smaller than the input image size.
|
| 373 |
+
- "absolute_range", for an input of size (H, W), uniformly sample H_crop in
|
| 374 |
+
[crop_size[0], min(H, crop_size[1])] and W_crop in [crop_size[0], min(W, crop_size[1])].
|
| 375 |
+
Then crop a region (H_crop, W_crop).
|
| 376 |
+
"""
|
| 377 |
+
# TODO style of relative_range and absolute_range are not consistent:
|
| 378 |
+
# one takes (h, w) but another takes (min, max)
|
| 379 |
+
super().__init__()
|
| 380 |
+
assert crop_type in ["relative_range", "relative", "absolute", "absolute_range"]
|
| 381 |
+
self._init(locals())
|
| 382 |
+
|
| 383 |
+
def get_transform(self, image):
|
| 384 |
+
h, w = image.shape[:2]
|
| 385 |
+
croph, cropw = self.get_crop_size((h, w))
|
| 386 |
+
assert h >= croph and w >= cropw, "Shape computation in {} has bugs.".format(self)
|
| 387 |
+
h0 = np.random.randint(h - croph + 1)
|
| 388 |
+
w0 = np.random.randint(w - cropw + 1)
|
| 389 |
+
return CropTransform(w0, h0, cropw, croph)
|
| 390 |
+
|
| 391 |
+
def get_crop_size(self, image_size):
|
| 392 |
+
"""
|
| 393 |
+
Args:
|
| 394 |
+
image_size (tuple): height, width
|
| 395 |
+
|
| 396 |
+
Returns:
|
| 397 |
+
crop_size (tuple): height, width in absolute pixels
|
| 398 |
+
"""
|
| 399 |
+
h, w = image_size
|
| 400 |
+
if self.crop_type == "relative":
|
| 401 |
+
ch, cw = self.crop_size
|
| 402 |
+
return int(h * ch + 0.5), int(w * cw + 0.5)
|
| 403 |
+
elif self.crop_type == "relative_range":
|
| 404 |
+
crop_size = np.asarray(self.crop_size, dtype=np.float32)
|
| 405 |
+
ch, cw = crop_size + np.random.rand(2) * (1 - crop_size)
|
| 406 |
+
return int(h * ch + 0.5), int(w * cw + 0.5)
|
| 407 |
+
elif self.crop_type == "absolute":
|
| 408 |
+
return (min(self.crop_size[0], h), min(self.crop_size[1], w))
|
| 409 |
+
elif self.crop_type == "absolute_range":
|
| 410 |
+
assert self.crop_size[0] <= self.crop_size[1]
|
| 411 |
+
ch = np.random.randint(min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1)
|
| 412 |
+
cw = np.random.randint(min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1)
|
| 413 |
+
return ch, cw
|
| 414 |
+
else:
|
| 415 |
+
raise NotImplementedError("Unknown crop type {}".format(self.crop_type))
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
class RandomCrop_CategoryAreaConstraint(Augmentation):
|
| 419 |
+
"""
|
| 420 |
+
Similar to :class:`RandomCrop`, but find a cropping window such that no single category
|
| 421 |
+
occupies a ratio of more than `single_category_max_area` in semantic segmentation ground
|
| 422 |
+
truth, which can cause unstability in training. The function attempts to find such a valid
|
| 423 |
+
cropping window for at most 10 times.
|
| 424 |
+
"""
|
| 425 |
+
|
| 426 |
+
def __init__(
|
| 427 |
+
self,
|
| 428 |
+
crop_type: str,
|
| 429 |
+
crop_size,
|
| 430 |
+
single_category_max_area: float = 1.0,
|
| 431 |
+
ignored_category: int = None,
|
| 432 |
+
):
|
| 433 |
+
"""
|
| 434 |
+
Args:
|
| 435 |
+
crop_type, crop_size: same as in :class:`RandomCrop`
|
| 436 |
+
single_category_max_area: the maximum allowed area ratio of a
|
| 437 |
+
category. Set to 1.0 to disable
|
| 438 |
+
ignored_category: allow this category in the semantic segmentation
|
| 439 |
+
ground truth to exceed the area ratio. Usually set to the category
|
| 440 |
+
that's ignored in training.
|
| 441 |
+
"""
|
| 442 |
+
self.crop_aug = RandomCrop(crop_type, crop_size)
|
| 443 |
+
self._init(locals())
|
| 444 |
+
|
| 445 |
+
def get_transform(self, image, sem_seg):
|
| 446 |
+
if self.single_category_max_area >= 1.0:
|
| 447 |
+
return self.crop_aug.get_transform(image)
|
| 448 |
+
else:
|
| 449 |
+
h, w = sem_seg.shape
|
| 450 |
+
for _ in range(10):
|
| 451 |
+
crop_size = self.crop_aug.get_crop_size((h, w))
|
| 452 |
+
y0 = np.random.randint(h - crop_size[0] + 1)
|
| 453 |
+
x0 = np.random.randint(w - crop_size[1] + 1)
|
| 454 |
+
sem_seg_temp = sem_seg[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
|
| 455 |
+
labels, cnt = np.unique(sem_seg_temp, return_counts=True)
|
| 456 |
+
if self.ignored_category is not None:
|
| 457 |
+
cnt = cnt[labels != self.ignored_category]
|
| 458 |
+
if len(cnt) > 1 and np.max(cnt) < np.sum(cnt) * self.single_category_max_area:
|
| 459 |
+
break
|
| 460 |
+
crop_tfm = CropTransform(x0, y0, crop_size[1], crop_size[0])
|
| 461 |
+
return crop_tfm
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
class RandomExtent(Augmentation):
|
| 465 |
+
"""
|
| 466 |
+
Outputs an image by cropping a random "subrect" of the source image.
|
| 467 |
+
|
| 468 |
+
The subrect can be parameterized to include pixels outside the source image,
|
| 469 |
+
in which case they will be set to zeros (i.e. black). The size of the output
|
| 470 |
+
image will vary with the size of the random subrect.
|
| 471 |
+
"""
|
| 472 |
+
|
| 473 |
+
def __init__(self, scale_range, shift_range):
|
| 474 |
+
"""
|
| 475 |
+
Args:
|
| 476 |
+
output_size (h, w): Dimensions of output image
|
| 477 |
+
scale_range (l, h): Range of input-to-output size scaling factor
|
| 478 |
+
shift_range (x, y): Range of shifts of the cropped subrect. The rect
|
| 479 |
+
is shifted by [w / 2 * Uniform(-x, x), h / 2 * Uniform(-y, y)],
|
| 480 |
+
where (w, h) is the (width, height) of the input image. Set each
|
| 481 |
+
component to zero to crop at the image's center.
|
| 482 |
+
"""
|
| 483 |
+
super().__init__()
|
| 484 |
+
self._init(locals())
|
| 485 |
+
|
| 486 |
+
def get_transform(self, image):
|
| 487 |
+
img_h, img_w = image.shape[:2]
|
| 488 |
+
|
| 489 |
+
# Initialize src_rect to fit the input image.
|
| 490 |
+
src_rect = np.array([-0.5 * img_w, -0.5 * img_h, 0.5 * img_w, 0.5 * img_h])
|
| 491 |
+
|
| 492 |
+
# Apply a random scaling to the src_rect.
|
| 493 |
+
src_rect *= np.random.uniform(self.scale_range[0], self.scale_range[1])
|
| 494 |
+
|
| 495 |
+
# Apply a random shift to the coordinates origin.
|
| 496 |
+
src_rect[0::2] += self.shift_range[0] * img_w * (np.random.rand() - 0.5)
|
| 497 |
+
src_rect[1::2] += self.shift_range[1] * img_h * (np.random.rand() - 0.5)
|
| 498 |
+
|
| 499 |
+
# Map src_rect coordinates into image coordinates (center at corner).
|
| 500 |
+
src_rect[0::2] += 0.5 * img_w
|
| 501 |
+
src_rect[1::2] += 0.5 * img_h
|
| 502 |
+
|
| 503 |
+
return ExtentTransform(
|
| 504 |
+
src_rect=(src_rect[0], src_rect[1], src_rect[2], src_rect[3]),
|
| 505 |
+
output_size=(int(src_rect[3] - src_rect[1]), int(src_rect[2] - src_rect[0])),
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
class RandomContrast(Augmentation):
|
| 510 |
+
"""
|
| 511 |
+
Randomly transforms image contrast.
|
| 512 |
+
|
| 513 |
+
Contrast intensity is uniformly sampled in (intensity_min, intensity_max).
|
| 514 |
+
- intensity < 1 will reduce contrast
|
| 515 |
+
- intensity = 1 will preserve the input image
|
| 516 |
+
- intensity > 1 will increase contrast
|
| 517 |
+
|
| 518 |
+
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
|
| 519 |
+
"""
|
| 520 |
+
|
| 521 |
+
def __init__(self, intensity_min, intensity_max):
|
| 522 |
+
"""
|
| 523 |
+
Args:
|
| 524 |
+
intensity_min (float): Minimum augmentation
|
| 525 |
+
intensity_max (float): Maximum augmentation
|
| 526 |
+
"""
|
| 527 |
+
super().__init__()
|
| 528 |
+
self._init(locals())
|
| 529 |
+
|
| 530 |
+
def get_transform(self, image):
|
| 531 |
+
w = np.random.uniform(self.intensity_min, self.intensity_max)
|
| 532 |
+
return BlendTransform(src_image=image.mean(), src_weight=1 - w, dst_weight=w)
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
class RandomBrightness(Augmentation):
|
| 536 |
+
"""
|
| 537 |
+
Randomly transforms image brightness.
|
| 538 |
+
|
| 539 |
+
Brightness intensity is uniformly sampled in (intensity_min, intensity_max).
|
| 540 |
+
- intensity < 1 will reduce brightness
|
| 541 |
+
- intensity = 1 will preserve the input image
|
| 542 |
+
- intensity > 1 will increase brightness
|
| 543 |
+
|
| 544 |
+
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
|
| 545 |
+
"""
|
| 546 |
+
|
| 547 |
+
def __init__(self, intensity_min, intensity_max):
|
| 548 |
+
"""
|
| 549 |
+
Args:
|
| 550 |
+
intensity_min (float): Minimum augmentation
|
| 551 |
+
intensity_max (float): Maximum augmentation
|
| 552 |
+
"""
|
| 553 |
+
super().__init__()
|
| 554 |
+
self._init(locals())
|
| 555 |
+
|
| 556 |
+
def get_transform(self, image):
|
| 557 |
+
w = np.random.uniform(self.intensity_min, self.intensity_max)
|
| 558 |
+
return BlendTransform(src_image=0, src_weight=1 - w, dst_weight=w)
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
class RandomSaturation(Augmentation):
|
| 562 |
+
"""
|
| 563 |
+
Randomly transforms saturation of an RGB image.
|
| 564 |
+
Input images are assumed to have 'RGB' channel order.
|
| 565 |
+
|
| 566 |
+
Saturation intensity is uniformly sampled in (intensity_min, intensity_max).
|
| 567 |
+
- intensity < 1 will reduce saturation (make the image more grayscale)
|
| 568 |
+
- intensity = 1 will preserve the input image
|
| 569 |
+
- intensity > 1 will increase saturation
|
| 570 |
+
|
| 571 |
+
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
|
| 572 |
+
"""
|
| 573 |
+
|
| 574 |
+
def __init__(self, intensity_min, intensity_max):
|
| 575 |
+
"""
|
| 576 |
+
Args:
|
| 577 |
+
intensity_min (float): Minimum augmentation (1 preserves input).
|
| 578 |
+
intensity_max (float): Maximum augmentation (1 preserves input).
|
| 579 |
+
"""
|
| 580 |
+
super().__init__()
|
| 581 |
+
self._init(locals())
|
| 582 |
+
|
| 583 |
+
def get_transform(self, image):
|
| 584 |
+
assert image.shape[-1] == 3, "RandomSaturation only works on RGB images"
|
| 585 |
+
w = np.random.uniform(self.intensity_min, self.intensity_max)
|
| 586 |
+
grayscale = image.dot([0.299, 0.587, 0.114])[:, :, np.newaxis]
|
| 587 |
+
return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w)
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
class RandomLighting(Augmentation):
|
| 591 |
+
"""
|
| 592 |
+
The "lighting" augmentation described in AlexNet, using fixed PCA over ImageNet.
|
| 593 |
+
Input images are assumed to have 'RGB' channel order.
|
| 594 |
+
|
| 595 |
+
The degree of color jittering is randomly sampled via a normal distribution,
|
| 596 |
+
with standard deviation given by the scale parameter.
|
| 597 |
+
"""
|
| 598 |
+
|
| 599 |
+
def __init__(self, scale):
|
| 600 |
+
"""
|
| 601 |
+
Args:
|
| 602 |
+
scale (float): Standard deviation of principal component weighting.
|
| 603 |
+
"""
|
| 604 |
+
super().__init__()
|
| 605 |
+
self._init(locals())
|
| 606 |
+
self.eigen_vecs = np.array(
|
| 607 |
+
[[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]]
|
| 608 |
+
)
|
| 609 |
+
self.eigen_vals = np.array([0.2175, 0.0188, 0.0045])
|
| 610 |
+
|
| 611 |
+
def get_transform(self, image):
|
| 612 |
+
assert image.shape[-1] == 3, "RandomLighting only works on RGB images"
|
| 613 |
+
weights = np.random.normal(scale=self.scale, size=3)
|
| 614 |
+
return BlendTransform(
|
| 615 |
+
src_image=self.eigen_vecs.dot(weights * self.eigen_vals), src_weight=1.0, dst_weight=1.0
|
| 616 |
+
)
|
cutler/data/transforms/transform.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/transform.py
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
See "Data Augmentation" tutorial for an overview of the system:
|
| 6 |
+
https://detectron2.readthedocs.io/tutorials/augmentation.html
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from fvcore.transforms.transform import (
|
| 13 |
+
CropTransform,
|
| 14 |
+
HFlipTransform,
|
| 15 |
+
NoOpTransform,
|
| 16 |
+
Transform,
|
| 17 |
+
TransformList,
|
| 18 |
+
)
|
| 19 |
+
from PIL import Image
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
import cv2 # noqa
|
| 23 |
+
except ImportError:
|
| 24 |
+
# OpenCV is an optional dependency at the moment
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
__all__ = [
|
| 28 |
+
"ExtentTransform",
|
| 29 |
+
"ResizeTransform",
|
| 30 |
+
"RotationTransform",
|
| 31 |
+
"ColorTransform",
|
| 32 |
+
"PILColorTransform",
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ExtentTransform(Transform):
|
| 37 |
+
"""
|
| 38 |
+
Extracts a subregion from the source image and scales it to the output size.
|
| 39 |
+
|
| 40 |
+
The fill color is used to map pixels from the source rect that fall outside
|
| 41 |
+
the source image.
|
| 42 |
+
|
| 43 |
+
See: https://pillow.readthedocs.io/en/latest/PIL.html#PIL.ImageTransform.ExtentTransform
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(self, src_rect, output_size, interp=Image.BILINEAR, fill=0):
|
| 47 |
+
"""
|
| 48 |
+
Args:
|
| 49 |
+
src_rect (x0, y0, x1, y1): src coordinates
|
| 50 |
+
output_size (h, w): dst image size
|
| 51 |
+
interp: PIL interpolation methods
|
| 52 |
+
fill: Fill color used when src_rect extends outside image
|
| 53 |
+
"""
|
| 54 |
+
super().__init__()
|
| 55 |
+
self._set_attributes(locals())
|
| 56 |
+
|
| 57 |
+
def apply_image(self, img, interp=None):
|
| 58 |
+
h, w = self.output_size
|
| 59 |
+
if len(img.shape) > 2 and img.shape[2] == 1:
|
| 60 |
+
pil_image = Image.fromarray(img[:, :, 0], mode="L")
|
| 61 |
+
else:
|
| 62 |
+
pil_image = Image.fromarray(img)
|
| 63 |
+
pil_image = pil_image.transform(
|
| 64 |
+
size=(w, h),
|
| 65 |
+
method=Image.EXTENT,
|
| 66 |
+
data=self.src_rect,
|
| 67 |
+
resample=interp if interp else self.interp,
|
| 68 |
+
fill=self.fill,
|
| 69 |
+
)
|
| 70 |
+
ret = np.asarray(pil_image)
|
| 71 |
+
if len(img.shape) > 2 and img.shape[2] == 1:
|
| 72 |
+
ret = np.expand_dims(ret, -1)
|
| 73 |
+
return ret
|
| 74 |
+
|
| 75 |
+
def apply_coords(self, coords):
|
| 76 |
+
# Transform image center from source coordinates into output coordinates
|
| 77 |
+
# and then map the new origin to the corner of the output image.
|
| 78 |
+
h, w = self.output_size
|
| 79 |
+
x0, y0, x1, y1 = self.src_rect
|
| 80 |
+
new_coords = coords.astype(np.float32)
|
| 81 |
+
new_coords[:, 0] -= 0.5 * (x0 + x1)
|
| 82 |
+
new_coords[:, 1] -= 0.5 * (y0 + y1)
|
| 83 |
+
new_coords[:, 0] *= w / (x1 - x0)
|
| 84 |
+
new_coords[:, 1] *= h / (y1 - y0)
|
| 85 |
+
new_coords[:, 0] += 0.5 * w
|
| 86 |
+
new_coords[:, 1] += 0.5 * h
|
| 87 |
+
return new_coords
|
| 88 |
+
|
| 89 |
+
def apply_segmentation(self, segmentation):
|
| 90 |
+
segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
|
| 91 |
+
return segmentation
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class ResizeTransform(Transform):
|
| 95 |
+
"""
|
| 96 |
+
Resize the image to a target size.
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
def __init__(self, h, w, new_h, new_w, interp=None):
|
| 100 |
+
"""
|
| 101 |
+
Args:
|
| 102 |
+
h, w (int): original image size
|
| 103 |
+
new_h, new_w (int): new image size
|
| 104 |
+
interp: PIL interpolation methods, defaults to bilinear.
|
| 105 |
+
"""
|
| 106 |
+
# TODO decide on PIL vs opencv
|
| 107 |
+
super().__init__()
|
| 108 |
+
if interp is None:
|
| 109 |
+
interp = Image.BILINEAR
|
| 110 |
+
self._set_attributes(locals())
|
| 111 |
+
|
| 112 |
+
def apply_image(self, img, interp=None):
|
| 113 |
+
try:
|
| 114 |
+
img.shape[:2] == (self.h, self.w)
|
| 115 |
+
except:
|
| 116 |
+
(self.h, self.w) = (self.w, self.h)
|
| 117 |
+
assert img.shape[:2] == (self.h, self.w)
|
| 118 |
+
assert len(img.shape) <= 4
|
| 119 |
+
interp_method = interp if interp is not None else self.interp
|
| 120 |
+
|
| 121 |
+
if img.dtype == np.uint8:
|
| 122 |
+
if len(img.shape) > 2 and img.shape[2] == 1:
|
| 123 |
+
pil_image = Image.fromarray(img[:, :, 0], mode="L")
|
| 124 |
+
else:
|
| 125 |
+
pil_image = Image.fromarray(img)
|
| 126 |
+
pil_image = pil_image.resize((self.new_w, self.new_h), interp_method)
|
| 127 |
+
ret = np.asarray(pil_image)
|
| 128 |
+
if len(img.shape) > 2 and img.shape[2] == 1:
|
| 129 |
+
ret = np.expand_dims(ret, -1)
|
| 130 |
+
else:
|
| 131 |
+
# PIL only supports uint8
|
| 132 |
+
if any(x < 0 for x in img.strides):
|
| 133 |
+
img = np.ascontiguousarray(img)
|
| 134 |
+
img = torch.from_numpy(img)
|
| 135 |
+
shape = list(img.shape)
|
| 136 |
+
shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
|
| 137 |
+
img = img.view(shape_4d).permute(2, 3, 0, 1) # hw(c) -> nchw
|
| 138 |
+
_PIL_RESIZE_TO_INTERPOLATE_MODE = {
|
| 139 |
+
Image.NEAREST: "nearest",
|
| 140 |
+
Image.BILINEAR: "bilinear",
|
| 141 |
+
Image.BICUBIC: "bicubic",
|
| 142 |
+
}
|
| 143 |
+
mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[interp_method]
|
| 144 |
+
align_corners = None if mode == "nearest" else False
|
| 145 |
+
img = F.interpolate(
|
| 146 |
+
img, (self.new_h, self.new_w), mode=mode, align_corners=align_corners
|
| 147 |
+
)
|
| 148 |
+
shape[:2] = (self.new_h, self.new_w)
|
| 149 |
+
ret = img.permute(2, 3, 0, 1).view(shape).numpy() # nchw -> hw(c)
|
| 150 |
+
|
| 151 |
+
return ret
|
| 152 |
+
|
| 153 |
+
def apply_coords(self, coords):
|
| 154 |
+
coords[:, 0] = coords[:, 0] * (self.new_w * 1.0 / self.w)
|
| 155 |
+
coords[:, 1] = coords[:, 1] * (self.new_h * 1.0 / self.h)
|
| 156 |
+
return coords
|
| 157 |
+
|
| 158 |
+
def apply_segmentation(self, segmentation):
|
| 159 |
+
segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
|
| 160 |
+
return segmentation
|
| 161 |
+
|
| 162 |
+
def inverse(self):
|
| 163 |
+
return ResizeTransform(self.new_h, self.new_w, self.h, self.w, self.interp)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class RotationTransform(Transform):
|
| 167 |
+
"""
|
| 168 |
+
This method returns a copy of this image, rotated the given
|
| 169 |
+
number of degrees counter clockwise around its center.
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
def __init__(self, h, w, angle, expand=True, center=None, interp=None):
|
| 173 |
+
"""
|
| 174 |
+
Args:
|
| 175 |
+
h, w (int): original image size
|
| 176 |
+
angle (float): degrees for rotation
|
| 177 |
+
expand (bool): choose if the image should be resized to fit the whole
|
| 178 |
+
rotated image (default), or simply cropped
|
| 179 |
+
center (tuple (width, height)): coordinates of the rotation center
|
| 180 |
+
if left to None, the center will be fit to the center of each image
|
| 181 |
+
center has no effect if expand=True because it only affects shifting
|
| 182 |
+
interp: cv2 interpolation method, default cv2.INTER_LINEAR
|
| 183 |
+
"""
|
| 184 |
+
super().__init__()
|
| 185 |
+
image_center = np.array((w / 2, h / 2))
|
| 186 |
+
if center is None:
|
| 187 |
+
center = image_center
|
| 188 |
+
if interp is None:
|
| 189 |
+
interp = cv2.INTER_LINEAR
|
| 190 |
+
abs_cos, abs_sin = (abs(np.cos(np.deg2rad(angle))), abs(np.sin(np.deg2rad(angle))))
|
| 191 |
+
if expand:
|
| 192 |
+
# find the new width and height bounds
|
| 193 |
+
bound_w, bound_h = np.rint(
|
| 194 |
+
[h * abs_sin + w * abs_cos, h * abs_cos + w * abs_sin]
|
| 195 |
+
).astype(int)
|
| 196 |
+
else:
|
| 197 |
+
bound_w, bound_h = w, h
|
| 198 |
+
|
| 199 |
+
self._set_attributes(locals())
|
| 200 |
+
self.rm_coords = self.create_rotation_matrix()
|
| 201 |
+
# Needed because of this problem https://github.com/opencv/opencv/issues/11784
|
| 202 |
+
self.rm_image = self.create_rotation_matrix(offset=-0.5)
|
| 203 |
+
|
| 204 |
+
def apply_image(self, img, interp=None):
|
| 205 |
+
"""
|
| 206 |
+
img should be a numpy array, formatted as Height * Width * Nchannels
|
| 207 |
+
"""
|
| 208 |
+
if len(img) == 0 or self.angle % 360 == 0:
|
| 209 |
+
return img
|
| 210 |
+
assert img.shape[:2] == (self.h, self.w)
|
| 211 |
+
interp = interp if interp is not None else self.interp
|
| 212 |
+
return cv2.warpAffine(img, self.rm_image, (self.bound_w, self.bound_h), flags=interp)
|
| 213 |
+
|
| 214 |
+
def apply_coords(self, coords):
|
| 215 |
+
"""
|
| 216 |
+
coords should be a N * 2 array-like, containing N couples of (x, y) points
|
| 217 |
+
"""
|
| 218 |
+
coords = np.asarray(coords, dtype=float)
|
| 219 |
+
if len(coords) == 0 or self.angle % 360 == 0:
|
| 220 |
+
return coords
|
| 221 |
+
return cv2.transform(coords[:, np.newaxis, :], self.rm_coords)[:, 0, :]
|
| 222 |
+
|
| 223 |
+
def apply_segmentation(self, segmentation):
|
| 224 |
+
segmentation = self.apply_image(segmentation, interp=cv2.INTER_NEAREST)
|
| 225 |
+
return segmentation
|
| 226 |
+
|
| 227 |
+
def create_rotation_matrix(self, offset=0):
|
| 228 |
+
center = (self.center[0] + offset, self.center[1] + offset)
|
| 229 |
+
rm = cv2.getRotationMatrix2D(tuple(center), self.angle, 1)
|
| 230 |
+
if self.expand:
|
| 231 |
+
# Find the coordinates of the center of rotation in the new image
|
| 232 |
+
# The only point for which we know the future coordinates is the center of the image
|
| 233 |
+
rot_im_center = cv2.transform(self.image_center[None, None, :] + offset, rm)[0, 0, :]
|
| 234 |
+
new_center = np.array([self.bound_w / 2, self.bound_h / 2]) + offset - rot_im_center
|
| 235 |
+
# shift the rotation center to the new coordinates
|
| 236 |
+
rm[:, 2] += new_center
|
| 237 |
+
return rm
|
| 238 |
+
|
| 239 |
+
def inverse(self):
|
| 240 |
+
"""
|
| 241 |
+
The inverse is to rotate it back with expand, and crop to get the original shape.
|
| 242 |
+
"""
|
| 243 |
+
if not self.expand: # Not possible to inverse if a part of the image is lost
|
| 244 |
+
raise NotImplementedError()
|
| 245 |
+
rotation = RotationTransform(
|
| 246 |
+
self.bound_h, self.bound_w, -self.angle, True, None, self.interp
|
| 247 |
+
)
|
| 248 |
+
crop = CropTransform(
|
| 249 |
+
(rotation.bound_w - self.w) // 2, (rotation.bound_h - self.h) // 2, self.w, self.h
|
| 250 |
+
)
|
| 251 |
+
return TransformList([rotation, crop])
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class ColorTransform(Transform):
|
| 255 |
+
"""
|
| 256 |
+
Generic wrapper for any photometric transforms.
|
| 257 |
+
These transformations should only affect the color space and
|
| 258 |
+
not the coordinate space of the image (e.g. annotation
|
| 259 |
+
coordinates such as bounding boxes should not be changed)
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
def __init__(self, op):
|
| 263 |
+
"""
|
| 264 |
+
Args:
|
| 265 |
+
op (Callable): operation to be applied to the image,
|
| 266 |
+
which takes in an ndarray and returns an ndarray.
|
| 267 |
+
"""
|
| 268 |
+
if not callable(op):
|
| 269 |
+
raise ValueError("op parameter should be callable")
|
| 270 |
+
super().__init__()
|
| 271 |
+
self._set_attributes(locals())
|
| 272 |
+
|
| 273 |
+
def apply_image(self, img):
|
| 274 |
+
return self.op(img)
|
| 275 |
+
|
| 276 |
+
def apply_coords(self, coords):
|
| 277 |
+
return coords
|
| 278 |
+
|
| 279 |
+
def inverse(self):
|
| 280 |
+
return NoOpTransform()
|
| 281 |
+
|
| 282 |
+
def apply_segmentation(self, segmentation):
|
| 283 |
+
return segmentation
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class PILColorTransform(ColorTransform):
|
| 287 |
+
"""
|
| 288 |
+
Generic wrapper for PIL Photometric image transforms,
|
| 289 |
+
which affect the color space and not the coordinate
|
| 290 |
+
space of the image
|
| 291 |
+
"""
|
| 292 |
+
|
| 293 |
+
def __init__(self, op):
|
| 294 |
+
"""
|
| 295 |
+
Args:
|
| 296 |
+
op (Callable): operation to be applied to the image,
|
| 297 |
+
which takes in a PIL Image and returns a transformed
|
| 298 |
+
PIL Image.
|
| 299 |
+
For reference on possible operations see:
|
| 300 |
+
- https://pillow.readthedocs.io/en/stable/
|
| 301 |
+
"""
|
| 302 |
+
if not callable(op):
|
| 303 |
+
raise ValueError("op parameter should be callable")
|
| 304 |
+
super().__init__(op)
|
| 305 |
+
|
| 306 |
+
def apply_image(self, img):
|
| 307 |
+
img = Image.fromarray(img)
|
| 308 |
+
return np.asarray(super().apply_image(img))
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def HFlip_rotated_box(transform, rotated_boxes):
|
| 312 |
+
"""
|
| 313 |
+
Apply the horizontal flip transform on rotated boxes.
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
rotated_boxes (ndarray): Nx5 floating point array of
|
| 317 |
+
(x_center, y_center, width, height, angle_degrees) format
|
| 318 |
+
in absolute coordinates.
|
| 319 |
+
"""
|
| 320 |
+
# Transform x_center
|
| 321 |
+
rotated_boxes[:, 0] = transform.width - rotated_boxes[:, 0]
|
| 322 |
+
# Transform angle
|
| 323 |
+
rotated_boxes[:, 4] = -rotated_boxes[:, 4]
|
| 324 |
+
return rotated_boxes
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def Resize_rotated_box(transform, rotated_boxes):
|
| 328 |
+
"""
|
| 329 |
+
Apply the resizing transform on rotated boxes. For details of how these (approximation)
|
| 330 |
+
formulas are derived, please refer to :meth:`RotatedBoxes.scale`.
|
| 331 |
+
|
| 332 |
+
Args:
|
| 333 |
+
rotated_boxes (ndarray): Nx5 floating point array of
|
| 334 |
+
(x_center, y_center, width, height, angle_degrees) format
|
| 335 |
+
in absolute coordinates.
|
| 336 |
+
"""
|
| 337 |
+
scale_factor_x = transform.new_w * 1.0 / transform.w
|
| 338 |
+
scale_factor_y = transform.new_h * 1.0 / transform.h
|
| 339 |
+
rotated_boxes[:, 0] *= scale_factor_x
|
| 340 |
+
rotated_boxes[:, 1] *= scale_factor_y
|
| 341 |
+
theta = rotated_boxes[:, 4] * np.pi / 180.0
|
| 342 |
+
c = np.cos(theta)
|
| 343 |
+
s = np.sin(theta)
|
| 344 |
+
rotated_boxes[:, 2] *= np.sqrt(np.square(scale_factor_x * c) + np.square(scale_factor_y * s))
|
| 345 |
+
rotated_boxes[:, 3] *= np.sqrt(np.square(scale_factor_x * s) + np.square(scale_factor_y * c))
|
| 346 |
+
rotated_boxes[:, 4] = np.arctan2(scale_factor_x * s, scale_factor_y * c) * 180 / np.pi
|
| 347 |
+
|
| 348 |
+
return rotated_boxes
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
HFlipTransform.register_type("rotated_box", HFlip_rotated_box)
|
| 352 |
+
ResizeTransform.register_type("rotated_box", Resize_rotated_box)
|
| 353 |
+
|
| 354 |
+
# not necessary any more with latest fvcore
|
| 355 |
+
NoOpTransform.register_type("rotated_box", lambda t, x: x)
|
cutler/demo/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
from demo import *
|
| 3 |
+
from predictor import *
|
| 4 |
+
|
| 5 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
cutler/demo/__pycache__/predictor.cpython-312.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
cutler/demo/cutler_cascade_final.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24fdc4544a348e5c27f6e334e4ff6557d8f2d50de94380a741bc91c2226b86bb
|
| 3 |
+
size 574672112
|
cutler/demo/demo.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/demo/demo.py
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import glob
|
| 6 |
+
import multiprocessing as mp
|
| 7 |
+
import numpy as np
|
| 8 |
+
import os
|
| 9 |
+
import tempfile
|
| 10 |
+
import time
|
| 11 |
+
import warnings
|
| 12 |
+
import cv2
|
| 13 |
+
import tqdm
|
| 14 |
+
|
| 15 |
+
from detectron2.config import get_cfg
|
| 16 |
+
from detectron2.data.detection_utils import read_image
|
| 17 |
+
from detectron2.utils.logger import setup_logger
|
| 18 |
+
import sys
|
| 19 |
+
sys.path.append('./')
|
| 20 |
+
sys.path.append('../')
|
| 21 |
+
from config import add_cutler_config
|
| 22 |
+
|
| 23 |
+
from predictor import VisualizationDemo
|
| 24 |
+
|
| 25 |
+
# constants
|
| 26 |
+
WINDOW_NAME = "CutLER detections"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def setup_cfg(args):
|
| 30 |
+
# load config from file and command-line arguments
|
| 31 |
+
cfg = get_cfg()
|
| 32 |
+
add_cutler_config(cfg)
|
| 33 |
+
cfg.merge_from_file(args.config_file)
|
| 34 |
+
cfg.merge_from_list(args.opts)
|
| 35 |
+
# Disable the use of SyncBN normalization when running on a CPU
|
| 36 |
+
# SyncBN is not supported on CPU and can cause errors, so we switch to BN instead
|
| 37 |
+
if cfg.MODEL.DEVICE == 'cpu' and cfg.MODEL.RESNETS.NORM == 'SyncBN':
|
| 38 |
+
cfg.MODEL.RESNETS.NORM = "BN"
|
| 39 |
+
cfg.MODEL.FPN.NORM = "BN"
|
| 40 |
+
# Set score_threshold for builtin models
|
| 41 |
+
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold
|
| 42 |
+
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
|
| 43 |
+
cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold
|
| 44 |
+
cfg.freeze()
|
| 45 |
+
return cfg
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_parser():
|
| 49 |
+
parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
|
| 50 |
+
parser.add_argument(
|
| 51 |
+
"--config-file",
|
| 52 |
+
default="model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml",
|
| 53 |
+
metavar="FILE",
|
| 54 |
+
help="path to config file",
|
| 55 |
+
)
|
| 56 |
+
parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
|
| 57 |
+
parser.add_argument("--video-input", help="Path to video file.")
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--input",
|
| 60 |
+
nargs="+",
|
| 61 |
+
help="A list of space separated input images; "
|
| 62 |
+
"or a single glob pattern such as 'directory/*.jpg'",
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--output",
|
| 66 |
+
help="A file or directory to save output visualizations. "
|
| 67 |
+
"If not given, will show output in an OpenCV window.",
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--confidence-threshold",
|
| 72 |
+
type=float,
|
| 73 |
+
default=0.35,
|
| 74 |
+
help="Minimum score for instance predictions to be shown",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--opts",
|
| 78 |
+
help="Modify config options using the command-line 'KEY VALUE' pairs",
|
| 79 |
+
default=[],
|
| 80 |
+
nargs=argparse.REMAINDER,
|
| 81 |
+
)
|
| 82 |
+
return parser
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def test_opencv_video_format(codec, file_ext):
|
| 86 |
+
with tempfile.TemporaryDirectory(prefix="video_format_test") as dir:
|
| 87 |
+
filename = os.path.join(dir, "test_file" + file_ext)
|
| 88 |
+
writer = cv2.VideoWriter(
|
| 89 |
+
filename=filename,
|
| 90 |
+
fourcc=cv2.VideoWriter_fourcc(*codec),
|
| 91 |
+
fps=float(30),
|
| 92 |
+
frameSize=(10, 10),
|
| 93 |
+
isColor=True,
|
| 94 |
+
)
|
| 95 |
+
[writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)]
|
| 96 |
+
writer.release()
|
| 97 |
+
if os.path.isfile(filename):
|
| 98 |
+
return True
|
| 99 |
+
return False
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
|
| 103 |
+
mp.set_start_method("spawn", force=True)
|
| 104 |
+
args = get_parser().parse_args()
|
| 105 |
+
setup_logger(name="fvcore")
|
| 106 |
+
logger = setup_logger()
|
| 107 |
+
logger.info("Arguments: " + str(args))
|
| 108 |
+
|
| 109 |
+
cfg = setup_cfg(args)
|
| 110 |
+
|
| 111 |
+
demo = VisualizationDemo(cfg)
|
| 112 |
+
|
| 113 |
+
if args.input:
|
| 114 |
+
if len(args.input) == 1:
|
| 115 |
+
args.input = glob.glob(os.path.expanduser(args.input[0]))
|
| 116 |
+
assert args.input, "The input path(s) was not found"
|
| 117 |
+
for path in tqdm.tqdm(args.input, disable=not args.output):
|
| 118 |
+
# use PIL, to be consistent with evaluation
|
| 119 |
+
img = read_image(path, format="BGR")
|
| 120 |
+
start_time = time.time()
|
| 121 |
+
predictions, visualized_output = demo.run_on_image(img)
|
| 122 |
+
logger.info(
|
| 123 |
+
"{}: {} in {:.2f}s".format(
|
| 124 |
+
path,
|
| 125 |
+
"detected {} instances".format(len(predictions["instances"]))
|
| 126 |
+
if "instances" in predictions
|
| 127 |
+
else "finished",
|
| 128 |
+
time.time() - start_time,
|
| 129 |
+
)
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
if args.output:
|
| 133 |
+
if os.path.isdir(args.output):
|
| 134 |
+
assert os.path.isdir(args.output), args.output
|
| 135 |
+
out_filename = os.path.join(args.output, os.path.basename(path))
|
| 136 |
+
else:
|
| 137 |
+
assert len(args.input) == 1, "Please specify a directory with args.output"
|
| 138 |
+
out_filename = args.output
|
| 139 |
+
visualized_output.save(out_filename)
|
| 140 |
+
else:
|
| 141 |
+
cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
|
| 142 |
+
cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
|
| 143 |
+
if cv2.waitKey(0) == 27:
|
| 144 |
+
break # esc to quit
|
| 145 |
+
elif args.webcam:
|
| 146 |
+
assert args.input is None, "Cannot have both --input and --webcam!"
|
| 147 |
+
assert args.output is None, "output not yet supported with --webcam!"
|
| 148 |
+
cam = cv2.VideoCapture(0)
|
| 149 |
+
for vis in tqdm.tqdm(demo.run_on_video(cam)):
|
| 150 |
+
cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
|
| 151 |
+
cv2.imshow(WINDOW_NAME, vis)
|
| 152 |
+
if cv2.waitKey(1) == 27:
|
| 153 |
+
break # esc to quit
|
| 154 |
+
cam.release()
|
| 155 |
+
cv2.destroyAllWindows()
|
| 156 |
+
elif args.video_input:
|
| 157 |
+
video = cv2.VideoCapture(args.video_input)
|
| 158 |
+
width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 159 |
+
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 160 |
+
frames_per_second = video.get(cv2.CAP_PROP_FPS)
|
| 161 |
+
num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 162 |
+
basename = os.path.basename(args.video_input)
|
| 163 |
+
codec, file_ext = (
|
| 164 |
+
("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4")
|
| 165 |
+
)
|
| 166 |
+
if codec == ".mp4v":
|
| 167 |
+
warnings.warn("x264 codec not available, switching to mp4v")
|
| 168 |
+
if args.output:
|
| 169 |
+
if os.path.isdir(args.output):
|
| 170 |
+
output_fname = os.path.join(args.output, basename)
|
| 171 |
+
output_fname = os.path.splitext(output_fname)[0] + file_ext
|
| 172 |
+
else:
|
| 173 |
+
output_fname = args.output
|
| 174 |
+
assert not os.path.isfile(output_fname), output_fname
|
| 175 |
+
output_file = cv2.VideoWriter(
|
| 176 |
+
filename=output_fname,
|
| 177 |
+
# some installation of opencv may not support x264 (due to its license),
|
| 178 |
+
# you can try other format (e.g. MPEG)
|
| 179 |
+
fourcc=cv2.VideoWriter_fourcc(*codec),
|
| 180 |
+
fps=float(frames_per_second),
|
| 181 |
+
frameSize=(width, height),
|
| 182 |
+
isColor=True,
|
| 183 |
+
)
|
| 184 |
+
assert os.path.isfile(args.video_input)
|
| 185 |
+
for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
|
| 186 |
+
if args.output:
|
| 187 |
+
output_file.write(vis_frame)
|
| 188 |
+
else:
|
| 189 |
+
cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
|
| 190 |
+
cv2.imshow(basename, vis_frame)
|
| 191 |
+
if cv2.waitKey(1) == 27:
|
| 192 |
+
break # esc to quit
|
| 193 |
+
video.release()
|
| 194 |
+
if args.output:
|
| 195 |
+
output_file.release()
|
| 196 |
+
else:
|
| 197 |
+
cv2.destroyAllWindows()
|
cutler/demo/imgs/demo1.jpg
ADDED
|
Git LFS Details
|
cutler/demo/imgs/demo2.jpg
ADDED
|
Git LFS Details
|
cutler/demo/imgs/demo3.jpg
ADDED
|
Git LFS Details
|
cutler/demo/imgs/demo4.jpg
ADDED
|
Git LFS Details
|
cutler/demo/imgs/demo5.jpg
ADDED
|
Git LFS Details
|
cutler/demo/imgs/demo6.jpg
ADDED
|
Git LFS Details
|
cutler/demo/imgs/demo7.jpg
ADDED
|
Git LFS Details
|
cutler/demo/imgs/demo8.jpg
ADDED
|
Git LFS Details
|
cutler/demo/predictor.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
import atexit
|
| 3 |
+
import bisect
|
| 4 |
+
import multiprocessing as mp
|
| 5 |
+
from collections import deque
|
| 6 |
+
import cv2
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from detectron2.data import MetadataCatalog
|
| 10 |
+
import sys
|
| 11 |
+
sys.path.append('./')
|
| 12 |
+
from engine.defaults import DefaultPredictor
|
| 13 |
+
from detectron2.utils.video_visualizer import VideoVisualizer
|
| 14 |
+
from detectron2.utils.visualizer import ColorMode, Visualizer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class VisualizationDemo(object):
|
| 18 |
+
def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
|
| 19 |
+
"""
|
| 20 |
+
Args:
|
| 21 |
+
cfg (CfgNode):
|
| 22 |
+
instance_mode (ColorMode):
|
| 23 |
+
parallel (bool): whether to run the model in different processes from visualization.
|
| 24 |
+
Useful since the visualization logic can be slow.
|
| 25 |
+
"""
|
| 26 |
+
self.metadata = MetadataCatalog.get(
|
| 27 |
+
cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
|
| 28 |
+
)
|
| 29 |
+
self.cpu_device = torch.device("cpu")
|
| 30 |
+
self.instance_mode = instance_mode
|
| 31 |
+
|
| 32 |
+
self.parallel = parallel
|
| 33 |
+
if parallel:
|
| 34 |
+
num_gpu = torch.cuda.device_count()
|
| 35 |
+
self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
|
| 36 |
+
else:
|
| 37 |
+
self.predictor = DefaultPredictor(cfg)
|
| 38 |
+
|
| 39 |
+
def run_on_image(self, image):
|
| 40 |
+
"""
|
| 41 |
+
Args:
|
| 42 |
+
image (np.ndarray): an image of shape (H, W, C) (in BGR order).
|
| 43 |
+
This is the format used by OpenCV.
|
| 44 |
+
Returns:
|
| 45 |
+
predictions (dict): the output of the model.
|
| 46 |
+
vis_output (VisImage): the visualized image output.
|
| 47 |
+
"""
|
| 48 |
+
vis_output = None
|
| 49 |
+
predictions = self.predictor(image)
|
| 50 |
+
# Convert image from OpenCV BGR format to Matplotlib RGB format.
|
| 51 |
+
image = image[:, :, ::-1]
|
| 52 |
+
visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
|
| 53 |
+
if "panoptic_seg" in predictions:
|
| 54 |
+
panoptic_seg, segments_info = predictions["panoptic_seg"]
|
| 55 |
+
vis_output = visualizer.draw_panoptic_seg_predictions(
|
| 56 |
+
panoptic_seg.to(self.cpu_device), segments_info
|
| 57 |
+
)
|
| 58 |
+
else:
|
| 59 |
+
if "sem_seg" in predictions:
|
| 60 |
+
vis_output = visualizer.draw_sem_seg(
|
| 61 |
+
predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
|
| 62 |
+
)
|
| 63 |
+
if "instances" in predictions:
|
| 64 |
+
instances = predictions["instances"].to(self.cpu_device)
|
| 65 |
+
vis_output = visualizer.draw_instance_predictions(predictions=instances)
|
| 66 |
+
|
| 67 |
+
return predictions, vis_output
|
| 68 |
+
|
| 69 |
+
def _frame_from_video(self, video):
|
| 70 |
+
while video.isOpened():
|
| 71 |
+
success, frame = video.read()
|
| 72 |
+
if success:
|
| 73 |
+
yield frame
|
| 74 |
+
else:
|
| 75 |
+
break
|
| 76 |
+
|
| 77 |
+
def run_on_video(self, video):
|
| 78 |
+
"""
|
| 79 |
+
Visualizes predictions on frames of the input video.
|
| 80 |
+
Args:
|
| 81 |
+
video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
|
| 82 |
+
either a webcam or a video file.
|
| 83 |
+
Yields:
|
| 84 |
+
ndarray: BGR visualizations of each video frame.
|
| 85 |
+
"""
|
| 86 |
+
video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)
|
| 87 |
+
|
| 88 |
+
def process_predictions(frame, predictions):
|
| 89 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 90 |
+
if "panoptic_seg" in predictions:
|
| 91 |
+
panoptic_seg, segments_info = predictions["panoptic_seg"]
|
| 92 |
+
vis_frame = video_visualizer.draw_panoptic_seg_predictions(
|
| 93 |
+
frame, panoptic_seg.to(self.cpu_device), segments_info
|
| 94 |
+
)
|
| 95 |
+
elif "instances" in predictions:
|
| 96 |
+
predictions = predictions["instances"].to(self.cpu_device)
|
| 97 |
+
vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
|
| 98 |
+
elif "sem_seg" in predictions:
|
| 99 |
+
vis_frame = video_visualizer.draw_sem_seg(
|
| 100 |
+
frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# Converts Matplotlib RGB format to OpenCV BGR format
|
| 104 |
+
vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
|
| 105 |
+
return vis_frame
|
| 106 |
+
|
| 107 |
+
frame_gen = self._frame_from_video(video)
|
| 108 |
+
if self.parallel:
|
| 109 |
+
buffer_size = self.predictor.default_buffer_size
|
| 110 |
+
|
| 111 |
+
frame_data = deque()
|
| 112 |
+
|
| 113 |
+
for cnt, frame in enumerate(frame_gen):
|
| 114 |
+
frame_data.append(frame)
|
| 115 |
+
self.predictor.put(frame)
|
| 116 |
+
|
| 117 |
+
if cnt >= buffer_size:
|
| 118 |
+
frame = frame_data.popleft()
|
| 119 |
+
predictions = self.predictor.get()
|
| 120 |
+
yield process_predictions(frame, predictions)
|
| 121 |
+
|
| 122 |
+
while len(frame_data):
|
| 123 |
+
frame = frame_data.popleft()
|
| 124 |
+
predictions = self.predictor.get()
|
| 125 |
+
yield process_predictions(frame, predictions)
|
| 126 |
+
else:
|
| 127 |
+
for frame in frame_gen:
|
| 128 |
+
yield process_predictions(frame, self.predictor(frame))
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class AsyncPredictor:
|
| 132 |
+
"""
|
| 133 |
+
A predictor that runs the model asynchronously, possibly on >1 GPUs.
|
| 134 |
+
Because rendering the visualization takes considerably amount of time,
|
| 135 |
+
this helps improve throughput a little bit when rendering videos.
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
class _StopToken:
|
| 139 |
+
pass
|
| 140 |
+
|
| 141 |
+
class _PredictWorker(mp.Process):
|
| 142 |
+
def __init__(self, cfg, task_queue, result_queue):
|
| 143 |
+
self.cfg = cfg
|
| 144 |
+
self.task_queue = task_queue
|
| 145 |
+
self.result_queue = result_queue
|
| 146 |
+
super().__init__()
|
| 147 |
+
|
| 148 |
+
def run(self):
|
| 149 |
+
predictor = DefaultPredictor(self.cfg)
|
| 150 |
+
|
| 151 |
+
while True:
|
| 152 |
+
task = self.task_queue.get()
|
| 153 |
+
if isinstance(task, AsyncPredictor._StopToken):
|
| 154 |
+
break
|
| 155 |
+
idx, data = task
|
| 156 |
+
result = predictor(data)
|
| 157 |
+
self.result_queue.put((idx, result))
|
| 158 |
+
|
| 159 |
+
def __init__(self, cfg, num_gpus: int = 1):
|
| 160 |
+
"""
|
| 161 |
+
Args:
|
| 162 |
+
cfg (CfgNode):
|
| 163 |
+
num_gpus (int): if 0, will run on CPU
|
| 164 |
+
"""
|
| 165 |
+
num_workers = max(num_gpus, 1)
|
| 166 |
+
self.task_queue = mp.Queue(maxsize=num_workers * 3)
|
| 167 |
+
self.result_queue = mp.Queue(maxsize=num_workers * 3)
|
| 168 |
+
self.procs = []
|
| 169 |
+
for gpuid in range(max(num_gpus, 1)):
|
| 170 |
+
cfg = cfg.clone()
|
| 171 |
+
cfg.defrost()
|
| 172 |
+
cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
|
| 173 |
+
self.procs.append(
|
| 174 |
+
AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
self.put_idx = 0
|
| 178 |
+
self.get_idx = 0
|
| 179 |
+
self.result_rank = []
|
| 180 |
+
self.result_data = []
|
| 181 |
+
|
| 182 |
+
for p in self.procs:
|
| 183 |
+
p.start()
|
| 184 |
+
atexit.register(self.shutdown)
|
| 185 |
+
|
| 186 |
+
def put(self, image):
|
| 187 |
+
self.put_idx += 1
|
| 188 |
+
self.task_queue.put((self.put_idx, image))
|
| 189 |
+
|
| 190 |
+
def get(self):
|
| 191 |
+
self.get_idx += 1 # the index needed for this request
|
| 192 |
+
if len(self.result_rank) and self.result_rank[0] == self.get_idx:
|
| 193 |
+
res = self.result_data[0]
|
| 194 |
+
del self.result_data[0], self.result_rank[0]
|
| 195 |
+
return res
|
| 196 |
+
|
| 197 |
+
while True:
|
| 198 |
+
# make sure the results are returned in the correct order
|
| 199 |
+
idx, res = self.result_queue.get()
|
| 200 |
+
if idx == self.get_idx:
|
| 201 |
+
return res
|
| 202 |
+
insert = bisect.bisect(self.result_rank, idx)
|
| 203 |
+
self.result_rank.insert(insert, idx)
|
| 204 |
+
self.result_data.insert(insert, res)
|
| 205 |
+
|
| 206 |
+
def __len__(self):
|
| 207 |
+
return self.put_idx - self.get_idx
|
| 208 |
+
|
| 209 |
+
def __call__(self, image):
|
| 210 |
+
self.put(image)
|
| 211 |
+
return self.get()
|
| 212 |
+
|
| 213 |
+
def shutdown(self):
|
| 214 |
+
for _ in self.procs:
|
| 215 |
+
self.task_queue.put(AsyncPredictor._StopToken())
|
| 216 |
+
|
| 217 |
+
@property
|
| 218 |
+
def default_buffer_size(self):
|
| 219 |
+
return len(self.procs) * 5
|
cutler/demo/wget-log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cutler/engine/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
from .train_loop import *
|
| 4 |
+
|
| 5 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
| 6 |
+
|
| 7 |
+
from .defaults import *
|
cutler/engine/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (481 Bytes). View file
|
|
|
cutler/engine/__pycache__/defaults.cpython-312.pyc
ADDED
|
Binary file (34.6 kB). View file
|
|
|
cutler/engine/__pycache__/train_loop.cpython-312.pyc
ADDED
|
Binary file (20.3 kB). View file
|
|
|
cutler/engine/defaults.py
ADDED
|
@@ -0,0 +1,726 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 3 |
+
# Modified by XuDong Wang from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/defaults.py
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
This file contains components with some default boilerplate logic user may need
|
| 7 |
+
in training / testing. They will not work for everyone, but many users may find them useful.
|
| 8 |
+
|
| 9 |
+
The behavior of functions/classes in this file is subject to change,
|
| 10 |
+
since they are meant to represent the "common default behavior" people need in their projects.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
import weakref
|
| 18 |
+
from collections import OrderedDict
|
| 19 |
+
from typing import Optional
|
| 20 |
+
import torch
|
| 21 |
+
from fvcore.nn.precise_bn import get_bn_modules
|
| 22 |
+
from omegaconf import OmegaConf
|
| 23 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 24 |
+
|
| 25 |
+
import data.transforms as T
|
| 26 |
+
from detectron2.checkpoint import DetectionCheckpointer
|
| 27 |
+
from detectron2.config import CfgNode, LazyConfig
|
| 28 |
+
from detectron2.data import (
|
| 29 |
+
MetadataCatalog,
|
| 30 |
+
)
|
| 31 |
+
from data import (
|
| 32 |
+
build_detection_test_loader,
|
| 33 |
+
build_detection_train_loader,
|
| 34 |
+
)
|
| 35 |
+
from detectron2.evaluation import (
|
| 36 |
+
DatasetEvaluator,
|
| 37 |
+
inference_on_dataset,
|
| 38 |
+
print_csv_format,
|
| 39 |
+
verify_results,
|
| 40 |
+
)
|
| 41 |
+
from modeling import build_model
|
| 42 |
+
from solver import build_lr_scheduler, build_optimizer
|
| 43 |
+
from detectron2.utils import comm
|
| 44 |
+
from detectron2.utils.collect_env import collect_env_info
|
| 45 |
+
from detectron2.utils.env import seed_all_rng
|
| 46 |
+
from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
|
| 47 |
+
from detectron2.utils.file_io import PathManager
|
| 48 |
+
from detectron2.utils.logger import setup_logger
|
| 49 |
+
|
| 50 |
+
from detectron2.engine import hooks
|
| 51 |
+
from detectron2.engine import TrainerBase
|
| 52 |
+
from .train_loop import CustomAMPTrainer, CustomSimpleTrainer
|
| 53 |
+
|
| 54 |
+
__all__ = [
|
| 55 |
+
"create_ddp_model",
|
| 56 |
+
"default_argument_parser",
|
| 57 |
+
"default_setup",
|
| 58 |
+
"default_writers",
|
| 59 |
+
"DefaultPredictor",
|
| 60 |
+
"DefaultTrainer",
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def create_ddp_model(model, *, fp16_compression=False, **kwargs):
|
| 65 |
+
"""
|
| 66 |
+
Create a DistributedDataParallel model if there are >1 processes.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
model: a torch.nn.Module
|
| 70 |
+
fp16_compression: add fp16 compression hooks to the ddp object.
|
| 71 |
+
See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
|
| 72 |
+
kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
|
| 73 |
+
""" # noqa
|
| 74 |
+
if comm.get_world_size() == 1:
|
| 75 |
+
return model
|
| 76 |
+
if "device_ids" not in kwargs:
|
| 77 |
+
kwargs["device_ids"] = [comm.get_local_rank()]
|
| 78 |
+
ddp = DistributedDataParallel(model, **kwargs)
|
| 79 |
+
if fp16_compression:
|
| 80 |
+
from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
|
| 81 |
+
|
| 82 |
+
ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
|
| 83 |
+
return ddp
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def default_argument_parser(epilog=None):
|
| 87 |
+
"""
|
| 88 |
+
Create a parser with some common arguments used by detectron2 users.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
epilog (str): epilog passed to ArgumentParser describing the usage.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
argparse.ArgumentParser:
|
| 95 |
+
"""
|
| 96 |
+
parser = argparse.ArgumentParser(
|
| 97 |
+
epilog=epilog
|
| 98 |
+
or f"""
|
| 99 |
+
Examples:
|
| 100 |
+
|
| 101 |
+
Run on single machine:
|
| 102 |
+
$ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml
|
| 103 |
+
|
| 104 |
+
Change some config options:
|
| 105 |
+
$ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001
|
| 106 |
+
|
| 107 |
+
Run on multiple machines:
|
| 108 |
+
(machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
|
| 109 |
+
(machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
|
| 110 |
+
""",
|
| 111 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 112 |
+
)
|
| 113 |
+
parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--resume",
|
| 116 |
+
action="store_true",
|
| 117 |
+
help="Whether to attempt to resume from the checkpoint directory. "
|
| 118 |
+
"See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
|
| 119 |
+
)
|
| 120 |
+
parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
|
| 121 |
+
parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
|
| 122 |
+
parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
|
| 123 |
+
parser.add_argument(
|
| 124 |
+
"--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
|
| 125 |
+
)
|
| 126 |
+
parser.add_argument(
|
| 127 |
+
"--test-dataset", type=str, default="", help="the dataset used for evaluation"
|
| 128 |
+
)
|
| 129 |
+
parser.add_argument(
|
| 130 |
+
"--train-dataset", type=str, default="", help="the dataset used for training"
|
| 131 |
+
)
|
| 132 |
+
parser.add_argument("--no-segm", action="store_true", help="perform evaluation on detection only")
|
| 133 |
+
# PyTorch still may leave orphan processes in multi-gpu training.
|
| 134 |
+
# Therefore we use a deterministic way to obtain port,
|
| 135 |
+
# so that users are aware of orphan processes by seeing the port occupied.
|
| 136 |
+
port = 2**15 + 2**14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2**14
|
| 137 |
+
parser.add_argument(
|
| 138 |
+
"--dist-url",
|
| 139 |
+
default="tcp://127.0.0.1:{}".format(port),
|
| 140 |
+
help="initialization URL for pytorch distributed backend. See "
|
| 141 |
+
"https://pytorch.org/docs/stable/distributed.html for details.",
|
| 142 |
+
)
|
| 143 |
+
parser.add_argument(
|
| 144 |
+
"opts",
|
| 145 |
+
help="""
|
| 146 |
+
Modify config options at the end of the command. For Yacs configs, use
|
| 147 |
+
space-separated "PATH.KEY VALUE" pairs.
|
| 148 |
+
For python-based LazyConfig, use "path.key=value".
|
| 149 |
+
""".strip(),
|
| 150 |
+
default=None,
|
| 151 |
+
nargs=argparse.REMAINDER,
|
| 152 |
+
)
|
| 153 |
+
return parser
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _try_get_key(cfg, *keys, default=None):
|
| 157 |
+
"""
|
| 158 |
+
Try select keys from cfg until the first key that exists. Otherwise return default.
|
| 159 |
+
"""
|
| 160 |
+
if isinstance(cfg, CfgNode):
|
| 161 |
+
cfg = OmegaConf.create(cfg.dump())
|
| 162 |
+
for k in keys:
|
| 163 |
+
none = object()
|
| 164 |
+
p = OmegaConf.select(cfg, k, default=none)
|
| 165 |
+
if p is not none:
|
| 166 |
+
return p
|
| 167 |
+
return default
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _highlight(code, filename):
|
| 171 |
+
try:
|
| 172 |
+
import pygments
|
| 173 |
+
except ImportError:
|
| 174 |
+
return code
|
| 175 |
+
|
| 176 |
+
from pygments.lexers import Python3Lexer, YamlLexer
|
| 177 |
+
from pygments.formatters import Terminal256Formatter
|
| 178 |
+
|
| 179 |
+
lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
|
| 180 |
+
code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai"))
|
| 181 |
+
return code
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def default_setup(cfg, args):
|
| 185 |
+
"""
|
| 186 |
+
Perform some basic common setups at the beginning of a job, including:
|
| 187 |
+
|
| 188 |
+
1. Set up the detectron2 logger
|
| 189 |
+
2. Log basic information about environment, cmdline arguments, and config
|
| 190 |
+
3. Backup the config to the output directory
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
cfg (CfgNode or omegaconf.DictConfig): the full config to be used
|
| 194 |
+
args (argparse.NameSpace): the command line arguments to be logged
|
| 195 |
+
"""
|
| 196 |
+
output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
|
| 197 |
+
if comm.is_main_process() and output_dir:
|
| 198 |
+
PathManager.mkdirs(output_dir)
|
| 199 |
+
|
| 200 |
+
rank = comm.get_rank()
|
| 201 |
+
setup_logger(output_dir, distributed_rank=rank, name="fvcore")
|
| 202 |
+
logger = setup_logger(output_dir, distributed_rank=rank)
|
| 203 |
+
|
| 204 |
+
logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
|
| 205 |
+
logger.info("Environment info:\n" + collect_env_info())
|
| 206 |
+
|
| 207 |
+
logger.info("Command line arguments: " + str(args))
|
| 208 |
+
if hasattr(args, "config_file") and args.config_file != "":
|
| 209 |
+
logger.info(
|
| 210 |
+
"Contents of args.config_file={}:\n{}".format(
|
| 211 |
+
args.config_file,
|
| 212 |
+
_highlight(PathManager.open(args.config_file, "r").read(), args.config_file),
|
| 213 |
+
)
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
if comm.is_main_process() and output_dir:
|
| 217 |
+
# Note: some of our scripts may expect the existence of
|
| 218 |
+
# config.yaml in output directory
|
| 219 |
+
path = os.path.join(output_dir, "config.yaml")
|
| 220 |
+
if isinstance(cfg, CfgNode):
|
| 221 |
+
logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
|
| 222 |
+
with PathManager.open(path, "w") as f:
|
| 223 |
+
f.write(cfg.dump())
|
| 224 |
+
else:
|
| 225 |
+
LazyConfig.save(cfg, path)
|
| 226 |
+
logger.info("Full config saved to {}".format(path))
|
| 227 |
+
|
| 228 |
+
# make sure each worker has a different, yet deterministic seed if specified
|
| 229 |
+
seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
|
| 230 |
+
seed_all_rng(None if seed < 0 else seed + rank)
|
| 231 |
+
|
| 232 |
+
# cudnn benchmark has large overhead. It shouldn't be used considering the small size of
|
| 233 |
+
# typical validation set.
|
| 234 |
+
if not (hasattr(args, "eval_only") and args.eval_only):
|
| 235 |
+
torch.backends.cudnn.benchmark = _try_get_key(
|
| 236 |
+
cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def default_writers(output_dir: str, max_iter: Optional[int] = None):
|
| 241 |
+
"""
|
| 242 |
+
Build a list of :class:`EventWriter` to be used.
|
| 243 |
+
It now consists of a :class:`CommonMetricPrinter`,
|
| 244 |
+
:class:`TensorboardXWriter` and :class:`JSONWriter`.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
output_dir: directory to store JSON metrics and tensorboard events
|
| 248 |
+
max_iter: the total number of iterations
|
| 249 |
+
|
| 250 |
+
Returns:
|
| 251 |
+
list[EventWriter]: a list of :class:`EventWriter` objects.
|
| 252 |
+
"""
|
| 253 |
+
PathManager.mkdirs(output_dir)
|
| 254 |
+
return [
|
| 255 |
+
# It may not always print what you want to see, since it prints "common" metrics only.
|
| 256 |
+
CommonMetricPrinter(max_iter),
|
| 257 |
+
JSONWriter(os.path.join(output_dir, "metrics.json")),
|
| 258 |
+
TensorboardXWriter(output_dir),
|
| 259 |
+
]
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class DefaultPredictor:
|
| 263 |
+
"""
|
| 264 |
+
Create a simple end-to-end predictor with the given config that runs on
|
| 265 |
+
single device for a single input image.
|
| 266 |
+
|
| 267 |
+
Compared to using the model directly, this class does the following additions:
|
| 268 |
+
|
| 269 |
+
1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
|
| 270 |
+
2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
|
| 271 |
+
3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
|
| 272 |
+
4. Take one input image and produce a single output, instead of a batch.
|
| 273 |
+
|
| 274 |
+
This is meant for simple demo purposes, so it does the above steps automatically.
|
| 275 |
+
This is not meant for benchmarks or running complicated inference logic.
|
| 276 |
+
If you'd like to do anything more complicated, please refer to its source code as
|
| 277 |
+
examples to build and use the model manually.
|
| 278 |
+
|
| 279 |
+
Attributes:
|
| 280 |
+
metadata (Metadata): the metadata of the underlying dataset, obtained from
|
| 281 |
+
cfg.DATASETS.TEST.
|
| 282 |
+
|
| 283 |
+
Examples:
|
| 284 |
+
::
|
| 285 |
+
pred = DefaultPredictor(cfg)
|
| 286 |
+
inputs = cv2.imread("input.jpg")
|
| 287 |
+
outputs = pred(inputs)
|
| 288 |
+
"""
|
| 289 |
+
|
| 290 |
+
def __init__(self, cfg):
|
| 291 |
+
self.cfg = cfg.clone() # cfg can be modified by model
|
| 292 |
+
self.model = build_model(self.cfg)
|
| 293 |
+
self.model.eval()
|
| 294 |
+
if len(cfg.DATASETS.TEST):
|
| 295 |
+
self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
|
| 296 |
+
|
| 297 |
+
checkpointer = DetectionCheckpointer(self.model)
|
| 298 |
+
checkpointer.load(cfg.MODEL.WEIGHTS)
|
| 299 |
+
|
| 300 |
+
self.aug = T.ResizeShortestEdge(
|
| 301 |
+
[cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
self.input_format = cfg.INPUT.FORMAT
|
| 305 |
+
assert self.input_format in ["RGB", "BGR"], self.input_format
|
| 306 |
+
|
| 307 |
+
def __call__(self, original_image):
|
| 308 |
+
"""
|
| 309 |
+
Args:
|
| 310 |
+
original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
|
| 311 |
+
|
| 312 |
+
Returns:
|
| 313 |
+
predictions (dict):
|
| 314 |
+
the output of the model for one image only.
|
| 315 |
+
See :doc:`/tutorials/models` for details about the format.
|
| 316 |
+
"""
|
| 317 |
+
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
|
| 318 |
+
# Apply pre-processing to image.
|
| 319 |
+
if self.input_format == "RGB":
|
| 320 |
+
# whether the model expects BGR inputs or RGB
|
| 321 |
+
original_image = original_image[:, :, ::-1]
|
| 322 |
+
height, width = original_image.shape[:2]
|
| 323 |
+
image = self.aug.get_transform(original_image).apply_image(original_image)
|
| 324 |
+
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
|
| 325 |
+
|
| 326 |
+
inputs = {"image": image, "height": height, "width": width}
|
| 327 |
+
predictions = self.model([inputs])[0]
|
| 328 |
+
return predictions
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class DefaultTrainer(TrainerBase):
|
| 332 |
+
"""
|
| 333 |
+
A trainer with default training logic. It does the following:
|
| 334 |
+
|
| 335 |
+
1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
|
| 336 |
+
defined by the given config. Create a LR scheduler defined by the config.
|
| 337 |
+
2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
|
| 338 |
+
`resume_or_load` is called.
|
| 339 |
+
3. Register a few common hooks defined by the config.
|
| 340 |
+
|
| 341 |
+
It is created to simplify the **standard model training workflow** and reduce code boilerplate
|
| 342 |
+
for users who only need the standard training workflow, with standard features.
|
| 343 |
+
It means this class makes *many assumptions* about your training logic that
|
| 344 |
+
may easily become invalid in a new research. In fact, any assumptions beyond those made in the
|
| 345 |
+
:class:`SimpleTrainer` are too much for research.
|
| 346 |
+
|
| 347 |
+
The code of this class has been annotated about restrictive assumptions it makes.
|
| 348 |
+
When they do not work for you, you're encouraged to:
|
| 349 |
+
|
| 350 |
+
1. Overwrite methods of this class, OR:
|
| 351 |
+
2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
|
| 352 |
+
nothing else. You can then add your own hooks if needed. OR:
|
| 353 |
+
3. Write your own training loop similar to `tools/plain_train_net.py`.
|
| 354 |
+
|
| 355 |
+
See the :doc:`/tutorials/training` tutorials for more details.
|
| 356 |
+
|
| 357 |
+
Note that the behavior of this class, like other functions/classes in
|
| 358 |
+
this file, is not stable, since it is meant to represent the "common default behavior".
|
| 359 |
+
It is only guaranteed to work well with the standard models and training workflow in detectron2.
|
| 360 |
+
To obtain more stable behavior, write your own training logic with other public APIs.
|
| 361 |
+
|
| 362 |
+
Examples:
|
| 363 |
+
::
|
| 364 |
+
trainer = DefaultTrainer(cfg)
|
| 365 |
+
trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS
|
| 366 |
+
trainer.train()
|
| 367 |
+
|
| 368 |
+
Attributes:
|
| 369 |
+
scheduler:
|
| 370 |
+
checkpointer (DetectionCheckpointer):
|
| 371 |
+
cfg (CfgNode):
|
| 372 |
+
"""
|
| 373 |
+
|
| 374 |
+
def __init__(self, cfg):
|
| 375 |
+
"""
|
| 376 |
+
Args:
|
| 377 |
+
cfg (CfgNode):
|
| 378 |
+
"""
|
| 379 |
+
super().__init__()
|
| 380 |
+
logger = logging.getLogger("detectron2")
|
| 381 |
+
if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
|
| 382 |
+
setup_logger()
|
| 383 |
+
cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
|
| 384 |
+
|
| 385 |
+
# Assume these objects must be constructed in this order.
|
| 386 |
+
model = self.build_model(cfg)
|
| 387 |
+
optimizer = self.build_optimizer(cfg, model)
|
| 388 |
+
data_loader = self.build_train_loader(cfg)
|
| 389 |
+
|
| 390 |
+
model = create_ddp_model(model, broadcast_buffers=False)
|
| 391 |
+
if cfg.SOLVER.AMP.ENABLED:
|
| 392 |
+
self._trainer = CustomAMPTrainer(model, data_loader, optimizer, cfg=cfg)
|
| 393 |
+
else:
|
| 394 |
+
self._trainer = CustomSimpleTrainer(model, data_loader, optimizer, cfg=cfg)
|
| 395 |
+
|
| 396 |
+
self.scheduler = self.build_lr_scheduler(cfg, optimizer)
|
| 397 |
+
self.checkpointer = DetectionCheckpointer(
|
| 398 |
+
# Assume you want to save checkpoints together with logs/statistics
|
| 399 |
+
model,
|
| 400 |
+
cfg.OUTPUT_DIR,
|
| 401 |
+
trainer=weakref.proxy(self),
|
| 402 |
+
)
|
| 403 |
+
self.start_iter = 0
|
| 404 |
+
self.max_iter = cfg.SOLVER.MAX_ITER
|
| 405 |
+
self.cfg = cfg
|
| 406 |
+
|
| 407 |
+
self.register_hooks(self.build_hooks())
|
| 408 |
+
|
| 409 |
+
def resume_or_load(self, resume=True):
|
| 410 |
+
"""
|
| 411 |
+
If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
|
| 412 |
+
a `last_checkpoint` file), resume from the file. Resuming means loading all
|
| 413 |
+
available states (eg. optimizer and scheduler) and update iteration counter
|
| 414 |
+
from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
|
| 415 |
+
|
| 416 |
+
Otherwise, this is considered as an independent training. The method will load model
|
| 417 |
+
weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
|
| 418 |
+
from iteration 0.
|
| 419 |
+
|
| 420 |
+
Args:
|
| 421 |
+
resume (bool): whether to do resume or not
|
| 422 |
+
"""
|
| 423 |
+
self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
|
| 424 |
+
if resume and self.checkpointer.has_checkpoint():
|
| 425 |
+
# The checkpoint stores the training iteration that just finished, thus we start
|
| 426 |
+
# at the next iteration
|
| 427 |
+
self.start_iter = self.iter + 1
|
| 428 |
+
|
| 429 |
+
def build_hooks(self):
|
| 430 |
+
"""
|
| 431 |
+
Build a list of default hooks, including timing, evaluation,
|
| 432 |
+
checkpointing, lr scheduling, precise BN, writing events.
|
| 433 |
+
|
| 434 |
+
Returns:
|
| 435 |
+
list[HookBase]:
|
| 436 |
+
"""
|
| 437 |
+
cfg = self.cfg.clone()
|
| 438 |
+
cfg.defrost()
|
| 439 |
+
cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
|
| 440 |
+
|
| 441 |
+
ret = [
|
| 442 |
+
hooks.IterationTimer(),
|
| 443 |
+
hooks.LRScheduler(),
|
| 444 |
+
hooks.PreciseBN(
|
| 445 |
+
# Run at the same freq as (but before) evaluation.
|
| 446 |
+
cfg.TEST.EVAL_PERIOD,
|
| 447 |
+
self.model,
|
| 448 |
+
# Build a new data loader to not affect training
|
| 449 |
+
self.build_train_loader(cfg),
|
| 450 |
+
cfg.TEST.PRECISE_BN.NUM_ITER,
|
| 451 |
+
)
|
| 452 |
+
if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
|
| 453 |
+
else None,
|
| 454 |
+
]
|
| 455 |
+
|
| 456 |
+
# Do PreciseBN before checkpointer, because it updates the model and need to
|
| 457 |
+
# be saved by checkpointer.
|
| 458 |
+
# This is not always the best: if checkpointing has a different frequency,
|
| 459 |
+
# some checkpoints may have more precise statistics than others.
|
| 460 |
+
if comm.is_main_process():
|
| 461 |
+
ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
|
| 462 |
+
|
| 463 |
+
def test_and_save_results():
|
| 464 |
+
self._last_eval_results = self.test(self.cfg, self.model)
|
| 465 |
+
return self._last_eval_results
|
| 466 |
+
|
| 467 |
+
# Do evaluation after checkpointer, because then if it fails,
|
| 468 |
+
# we can use the saved checkpoint to debug.
|
| 469 |
+
ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
|
| 470 |
+
|
| 471 |
+
if comm.is_main_process():
|
| 472 |
+
# Here the default print/log frequency of each writer is used.
|
| 473 |
+
# run writers in the end, so that evaluation metrics are written
|
| 474 |
+
ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
|
| 475 |
+
return ret
|
| 476 |
+
|
| 477 |
+
def build_writers(self):
|
| 478 |
+
"""
|
| 479 |
+
Build a list of writers to be used using :func:`default_writers()`.
|
| 480 |
+
If you'd like a different list of writers, you can overwrite it in
|
| 481 |
+
your trainer.
|
| 482 |
+
|
| 483 |
+
Returns:
|
| 484 |
+
list[EventWriter]: a list of :class:`EventWriter` objects.
|
| 485 |
+
"""
|
| 486 |
+
return default_writers(self.cfg.OUTPUT_DIR, self.max_iter)
|
| 487 |
+
|
| 488 |
+
def train(self):
|
| 489 |
+
"""
|
| 490 |
+
Run training.
|
| 491 |
+
|
| 492 |
+
Returns:
|
| 493 |
+
OrderedDict of results, if evaluation is enabled. Otherwise None.
|
| 494 |
+
"""
|
| 495 |
+
super().train(self.start_iter, self.max_iter)
|
| 496 |
+
if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
|
| 497 |
+
assert hasattr(
|
| 498 |
+
self, "_last_eval_results"
|
| 499 |
+
), "No evaluation results obtained during training!"
|
| 500 |
+
verify_results(self.cfg, self._last_eval_results)
|
| 501 |
+
return self._last_eval_results
|
| 502 |
+
|
| 503 |
+
def run_step(self):
|
| 504 |
+
self._trainer.iter = self.iter
|
| 505 |
+
self._trainer.run_step()
|
| 506 |
+
|
| 507 |
+
def state_dict(self):
|
| 508 |
+
ret = super().state_dict()
|
| 509 |
+
ret["_trainer"] = self._trainer.state_dict()
|
| 510 |
+
return ret
|
| 511 |
+
|
| 512 |
+
def load_state_dict(self, state_dict):
|
| 513 |
+
super().load_state_dict(state_dict)
|
| 514 |
+
self._trainer.load_state_dict(state_dict["_trainer"])
|
| 515 |
+
|
| 516 |
+
@classmethod
|
| 517 |
+
def build_model(cls, cfg):
|
| 518 |
+
"""
|
| 519 |
+
Returns:
|
| 520 |
+
torch.nn.Module:
|
| 521 |
+
|
| 522 |
+
It now calls :func:`detectron2.modeling.build_model`.
|
| 523 |
+
Overwrite it if you'd like a different model.
|
| 524 |
+
"""
|
| 525 |
+
model = build_model(cfg)
|
| 526 |
+
logger = logging.getLogger(__name__)
|
| 527 |
+
logger.info("Model:\n{}".format(model))
|
| 528 |
+
return model
|
| 529 |
+
|
| 530 |
+
@classmethod
|
| 531 |
+
def build_optimizer(cls, cfg, model):
|
| 532 |
+
"""
|
| 533 |
+
Returns:
|
| 534 |
+
torch.optim.Optimizer:
|
| 535 |
+
|
| 536 |
+
It now calls :func:`detectron2.solver.build_optimizer`.
|
| 537 |
+
Overwrite it if you'd like a different optimizer.
|
| 538 |
+
"""
|
| 539 |
+
return build_optimizer(cfg, model)
|
| 540 |
+
|
| 541 |
+
@classmethod
|
| 542 |
+
def build_lr_scheduler(cls, cfg, optimizer):
|
| 543 |
+
"""
|
| 544 |
+
It now calls :func:`detectron2.solver.build_lr_scheduler`.
|
| 545 |
+
Overwrite it if you'd like a different scheduler.
|
| 546 |
+
"""
|
| 547 |
+
return build_lr_scheduler(cfg, optimizer)
|
| 548 |
+
|
| 549 |
+
@classmethod
|
| 550 |
+
def build_train_loader(cls, cfg):
|
| 551 |
+
"""
|
| 552 |
+
Returns:
|
| 553 |
+
iterable
|
| 554 |
+
|
| 555 |
+
It now calls :func:`detectron2.data.build_detection_train_loader`.
|
| 556 |
+
Overwrite it if you'd like a different data loader.
|
| 557 |
+
"""
|
| 558 |
+
return build_detection_train_loader(cfg)
|
| 559 |
+
|
| 560 |
+
@classmethod
|
| 561 |
+
def build_test_loader(cls, cfg, dataset_name):
|
| 562 |
+
"""
|
| 563 |
+
Returns:
|
| 564 |
+
iterable
|
| 565 |
+
|
| 566 |
+
It now calls :func:`detectron2.data.build_detection_test_loader`.
|
| 567 |
+
Overwrite it if you'd like a different data loader.
|
| 568 |
+
"""
|
| 569 |
+
return build_detection_test_loader(cfg, dataset_name)
|
| 570 |
+
|
| 571 |
+
@classmethod
|
| 572 |
+
def build_evaluator(cls, cfg, dataset_name):
|
| 573 |
+
"""
|
| 574 |
+
Returns:
|
| 575 |
+
DatasetEvaluator or None
|
| 576 |
+
|
| 577 |
+
It is not implemented by default.
|
| 578 |
+
"""
|
| 579 |
+
raise NotImplementedError(
|
| 580 |
+
"""
|
| 581 |
+
If you want DefaultTrainer to automatically run evaluation,
|
| 582 |
+
please implement `build_evaluator()` in subclasses (see train_net.py for example).
|
| 583 |
+
Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example).
|
| 584 |
+
"""
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
@classmethod
|
| 588 |
+
def test(cls, cfg, model, evaluators=None):
|
| 589 |
+
"""
|
| 590 |
+
Evaluate the given model. The given model is expected to already contain
|
| 591 |
+
weights to evaluate.
|
| 592 |
+
|
| 593 |
+
Args:
|
| 594 |
+
cfg (CfgNode):
|
| 595 |
+
model (nn.Module):
|
| 596 |
+
evaluators (list[DatasetEvaluator] or None): if None, will call
|
| 597 |
+
:meth:`build_evaluator`. Otherwise, must have the same length as
|
| 598 |
+
``cfg.DATASETS.TEST``.
|
| 599 |
+
|
| 600 |
+
Returns:
|
| 601 |
+
dict: a dict of result metrics
|
| 602 |
+
"""
|
| 603 |
+
logger = logging.getLogger(__name__)
|
| 604 |
+
if isinstance(evaluators, DatasetEvaluator):
|
| 605 |
+
evaluators = [evaluators]
|
| 606 |
+
if evaluators is not None:
|
| 607 |
+
assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
|
| 608 |
+
len(cfg.DATASETS.TEST), len(evaluators)
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
results = OrderedDict()
|
| 612 |
+
for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
|
| 613 |
+
data_loader = cls.build_test_loader(cfg, dataset_name)
|
| 614 |
+
# When evaluators are passed in as arguments,
|
| 615 |
+
# implicitly assume that evaluators can be created before data_loader.
|
| 616 |
+
if evaluators is not None:
|
| 617 |
+
evaluator = evaluators[idx]
|
| 618 |
+
else:
|
| 619 |
+
try:
|
| 620 |
+
evaluator = cls.build_evaluator(cfg, dataset_name)
|
| 621 |
+
except NotImplementedError:
|
| 622 |
+
logger.warn(
|
| 623 |
+
"No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
|
| 624 |
+
"or implement its `build_evaluator` method."
|
| 625 |
+
)
|
| 626 |
+
results[dataset_name] = {}
|
| 627 |
+
continue
|
| 628 |
+
results_i = inference_on_dataset(model, data_loader, evaluator)
|
| 629 |
+
results[dataset_name] = results_i
|
| 630 |
+
if comm.is_main_process():
|
| 631 |
+
assert isinstance(
|
| 632 |
+
results_i, dict
|
| 633 |
+
), "Evaluator must return a dict on the main process. Got {} instead.".format(
|
| 634 |
+
results_i
|
| 635 |
+
)
|
| 636 |
+
logger.info("Evaluation results for {} in csv format:".format(dataset_name))
|
| 637 |
+
print_csv_format(results_i)
|
| 638 |
+
|
| 639 |
+
if len(results) == 1:
|
| 640 |
+
results = list(results.values())[0]
|
| 641 |
+
return results
|
| 642 |
+
|
| 643 |
+
@staticmethod
|
| 644 |
+
def auto_scale_workers(cfg, num_workers: int):
|
| 645 |
+
"""
|
| 646 |
+
When the config is defined for certain number of workers (according to
|
| 647 |
+
``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
|
| 648 |
+
workers currently in use, returns a new cfg where the total batch size
|
| 649 |
+
is scaled so that the per-GPU batch size stays the same as the
|
| 650 |
+
original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.
|
| 651 |
+
|
| 652 |
+
Other config options are also scaled accordingly:
|
| 653 |
+
* training steps and warmup steps are scaled inverse proportionally.
|
| 654 |
+
* learning rate are scaled proportionally, following :paper:`ImageNet in 1h`.
|
| 655 |
+
|
| 656 |
+
For example, with the original config like the following:
|
| 657 |
+
|
| 658 |
+
.. code-block:: yaml
|
| 659 |
+
|
| 660 |
+
IMS_PER_BATCH: 16
|
| 661 |
+
BASE_LR: 0.1
|
| 662 |
+
REFERENCE_WORLD_SIZE: 8
|
| 663 |
+
MAX_ITER: 5000
|
| 664 |
+
STEPS: (4000,)
|
| 665 |
+
CHECKPOINT_PERIOD: 1000
|
| 666 |
+
|
| 667 |
+
When this config is used on 16 GPUs instead of the reference number 8,
|
| 668 |
+
calling this method will return a new config with:
|
| 669 |
+
|
| 670 |
+
.. code-block:: yaml
|
| 671 |
+
|
| 672 |
+
IMS_PER_BATCH: 32
|
| 673 |
+
BASE_LR: 0.2
|
| 674 |
+
REFERENCE_WORLD_SIZE: 16
|
| 675 |
+
MAX_ITER: 2500
|
| 676 |
+
STEPS: (2000,)
|
| 677 |
+
CHECKPOINT_PERIOD: 500
|
| 678 |
+
|
| 679 |
+
Note that both the original config and this new config can be trained on 16 GPUs.
|
| 680 |
+
It's up to user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``).
|
| 681 |
+
|
| 682 |
+
Returns:
|
| 683 |
+
CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
|
| 684 |
+
"""
|
| 685 |
+
old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
|
| 686 |
+
if old_world_size == 0 or old_world_size == num_workers:
|
| 687 |
+
return cfg
|
| 688 |
+
cfg = cfg.clone()
|
| 689 |
+
frozen = cfg.is_frozen()
|
| 690 |
+
cfg.defrost()
|
| 691 |
+
|
| 692 |
+
assert (
|
| 693 |
+
cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
|
| 694 |
+
), "Invalid REFERENCE_WORLD_SIZE in config!"
|
| 695 |
+
scale = num_workers / old_world_size
|
| 696 |
+
bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
|
| 697 |
+
lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
|
| 698 |
+
max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
|
| 699 |
+
warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
|
| 700 |
+
cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
|
| 701 |
+
cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
|
| 702 |
+
cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
|
| 703 |
+
cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers # maintain invariant
|
| 704 |
+
logger = logging.getLogger(__name__)
|
| 705 |
+
logger.info(
|
| 706 |
+
f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
|
| 707 |
+
f"max_iter={max_iter}, warmup={warmup_iter}."
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
if frozen:
|
| 711 |
+
cfg.freeze()
|
| 712 |
+
return cfg
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
# Access basic attributes from the underlying trainer
|
| 716 |
+
for _attr in ["model", "data_loader", "optimizer"]:
|
| 717 |
+
setattr(
|
| 718 |
+
DefaultTrainer,
|
| 719 |
+
_attr,
|
| 720 |
+
property(
|
| 721 |
+
# getter
|
| 722 |
+
lambda self, x=_attr: getattr(self._trainer, x),
|
| 723 |
+
# setter
|
| 724 |
+
lambda self, value, x=_attr: setattr(self._trainer, x, value),
|
| 725 |
+
),
|
| 726 |
+
)
|