Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +14 -0
- .gitignore +30 -0
- ComfyUI_AEMatter/AEMatter.py +1248 -0
- ComfyUI_AEMatter/AEMatter.run.sh +3 -0
- ComfyUI_AEMatter/README.org +1357 -0
- ComfyUI_AEMatter/__init__.py +1248 -0
- ComfyUI_MVANet/MVANet_inference.py +1548 -0
- ComfyUI_MVANet/MVANet_inference.run.sh +3 -0
- ComfyUI_MVANet/README.org +1694 -0
- ComfyUI_MVANet/__init__.py +1548 -0
- ComfyUI_MVANet/download.sh +13 -0
- ComfyUI_MVANet/requirements.txt +3 -0
- LICENSE +21 -0
- MVANet_Inference/README.org +2179 -0
- README.md +131 -0
- checkpoints/AEMatter/AEM_RWA.ckpt +3 -0
- checkpoints/MVANet/garment.pth +3 -0
- checkpoints/MVANet/skin.pth +3 -0
- checkpoints/Model_80.pth +3 -0
- checkpoints/StableDiffusion/90c7c97574f8db765509b6a5d2e7b2551b430a10cac03e37d368654eac5e8169cd149644d188be4b5b2f1b9f29e66b64a02535f622f2bf284c319b076224cb2b +3 -0
- checkpoints/StableDiffusion/b970812225cfb95427c13e73b75eef66430e2a525876dddac494d70fe4ed0524cb197043e0ac3dc3026b32a45cd1d6d126ec2fe74a5bc3ef5df21836ca022b30 +3 -0
- checkpoints/StableDiffusion/hash +2 -0
- checkpoints/atr.pth +3 -0
- checkpoints/lip.pth +3 -0
- checkpoints/pascal.pth +3 -0
- datasets/__init__.py +0 -0
- datasets/datasets.py +201 -0
- datasets/simple_extractor_dataset.py +78 -0
- datasets/target_generation.py +40 -0
- demo/demo.jpg +3 -0
- demo/demo_atr.png +0 -0
- demo/demo_lip.png +0 -0
- demo/demo_pascal.png +0 -0
- demo/lip-visualization.jpg +3 -0
- environment.yaml +49 -0
- evaluate.py +209 -0
- main.org +663 -0
- mhp_extension/README.md +38 -0
- mhp_extension/coco_style_annotation_creator/__pycache__/pycococreatortools.cpython-37.pyc +0 -0
- mhp_extension/coco_style_annotation_creator/human_to_coco.py +166 -0
- mhp_extension/coco_style_annotation_creator/pycococreatortools.py +114 -0
- mhp_extension/coco_style_annotation_creator/test_human2coco_format.py +74 -0
- mhp_extension/demo.ipynb +0 -0
- mhp_extension/detectron2/.circleci/config.yml +179 -0
- mhp_extension/detectron2/.clang-format +85 -0
- mhp_extension/detectron2/.flake8 +9 -0
- mhp_extension/detectron2/.gitignore +46 -0
- mhp_extension/detectron2/GETTING_STARTED.md +79 -0
- mhp_extension/detectron2/INSTALL.md +184 -0
- mhp_extension/detectron2/LICENSE +201 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
checkpoints/atr.pth filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
checkpoints/lip.pth filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
checkpoints/pascal.pth filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
checkpoints/Model_80.pth filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
checkpoints/AEMatter/AEM_RWA.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
checkpoints/StableDiffusion/90c7c97574f8db765509b6a5d2e7b2551b430a10cac03e37d368654eac5e8169cd149644d188be4b5b2f1b9f29e66b64a02535f622f2bf284c319b076224cb2b filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
checkpoints/StableDiffusion/b970812225cfb95427c13e73b75eef66430e2a525876dddac494d70fe4ed0524cb197043e0ac3dc3026b32a45cd1d6d126ec2fe74a5bc3ef5df21836ca022b30 filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
checkpoints/MVANet/skin.pth filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
checkpoints/MVANet/garment.pth filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
demo/demo_lip.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
demo/lip-visualization.jpg filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
demo/demo_pascal.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
demo/demo_atr.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
demo/demo.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/ComfyUI_MVANet/__pycache__/__init__.cpython-310.pyc
|
| 2 |
+
/ComfyUI_MVANet/#README.org#
|
| 3 |
+
/ComfyUI_MVANet/.#README.org
|
| 4 |
+
/ComfyUI_MVANet/README.org~
|
| 5 |
+
/ComfyUI_MVANet/.README.org.~undo-tree~
|
| 6 |
+
/#main.org#
|
| 7 |
+
/.#main.org
|
| 8 |
+
/main.org~
|
| 9 |
+
/.main.org.~undo-tree~
|
| 10 |
+
/.README.md.~undo-tree~
|
| 11 |
+
/ComfyUI_MVANet/.#README.org
|
| 12 |
+
/ComfyUI_AEMatter/__pycache__/__init__.cpython-310.pyc
|
| 13 |
+
/ComfyUI_AEMatter/AEMatter.class.py
|
| 14 |
+
/ComfyUI_AEMatter/AEMatter.execute.py
|
| 15 |
+
/ComfyUI_AEMatter/AEMatter.function.py
|
| 16 |
+
/ComfyUI_AEMatter/AEMatter.import.py
|
| 17 |
+
/ComfyUI_MVANet/MVANet_inference.class.py
|
| 18 |
+
/ComfyUI_MVANet/MVANet_inference.execute.py
|
| 19 |
+
/ComfyUI_MVANet/MVANet_inference.function.py
|
| 20 |
+
/ComfyUI_MVANet/MVANet_inference.import.py
|
| 21 |
+
/ComfyUI_MVANet/MVANet_inference.unify.sh
|
| 22 |
+
/ComfyUI_AEMatter/AEMatter.unify.sh
|
| 23 |
+
/git_add.txt
|
| 24 |
+
/git_lfs_track.txt
|
| 25 |
+
/gitignore.txt
|
| 26 |
+
/rm.txt
|
| 27 |
+
/work.sh
|
| 28 |
+
log/
|
| 29 |
+
pretrain_model/
|
| 30 |
+
commit_and_push.sh
|
ComfyUI_AEMatter/AEMatter.py
ADDED
|
@@ -0,0 +1,1248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
import cv2
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import wget
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import init
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torch.utils.checkpoint as checkpoint
|
| 14 |
+
|
| 15 |
+
from collections import OrderedDict
|
| 16 |
+
from einops import rearrange, repeat
|
| 17 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 18 |
+
|
| 19 |
+
import folder_paths
|
| 20 |
+
from folder_paths import models_dir
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
#!/usr/bin/python3
|
| 24 |
+
def mkdir_safe(out_path):
|
| 25 |
+
if type(out_path) == str:
|
| 26 |
+
if len(out_path) > 0:
|
| 27 |
+
if not os.path.exists(out_path):
|
| 28 |
+
os.mkdir(out_path)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_model_path():
|
| 32 |
+
import folder_paths
|
| 33 |
+
from folder_paths import models_dir
|
| 34 |
+
|
| 35 |
+
path_file_model = models_dir
|
| 36 |
+
mkdir_safe(out_path=path_file_model)
|
| 37 |
+
|
| 38 |
+
path_file_model = os.path.join(path_file_model, 'AEMatter')
|
| 39 |
+
mkdir_safe(out_path=path_file_model)
|
| 40 |
+
|
| 41 |
+
path_file_model = os.path.join(path_file_model, 'AEM_RWA.ckpt')
|
| 42 |
+
|
| 43 |
+
return path_file_model
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def download_model(path):
|
| 47 |
+
if not os.path.exists(path):
|
| 48 |
+
wget.download(
|
| 49 |
+
'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/AEMatter/AEM_RWA.ckpt?download=true',
|
| 50 |
+
out=path)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def from_torch_image(image):
|
| 54 |
+
image = image.cpu().numpy() * 255.0
|
| 55 |
+
image = np.clip(image, 0, 255).astype(np.uint8)
|
| 56 |
+
return image
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def to_torch_image(image):
|
| 60 |
+
image = image.astype(dtype=np.float32)
|
| 61 |
+
image /= 255.0
|
| 62 |
+
image = torch.from_numpy(image)
|
| 63 |
+
return image
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def window_partition(x, window_size):
|
| 67 |
+
"""
|
| 68 |
+
Args:
|
| 69 |
+
x: (B, H, W, C)
|
| 70 |
+
window_size (int): window size
|
| 71 |
+
Returns:
|
| 72 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 73 |
+
"""
|
| 74 |
+
B, H, W, C = x.shape
|
| 75 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size,
|
| 76 |
+
C)
|
| 77 |
+
windows = x.permute(0, 1, 3, 2, 4,
|
| 78 |
+
5).contiguous().view(-1, window_size, window_size, C)
|
| 79 |
+
return windows
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def window_reverse(windows, window_size, H, W):
|
| 83 |
+
"""
|
| 84 |
+
Args:
|
| 85 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 86 |
+
window_size (int): Window size
|
| 87 |
+
H (int): Height of image
|
| 88 |
+
W (int): Width of image
|
| 89 |
+
Returns:
|
| 90 |
+
x: (B, H, W, C)
|
| 91 |
+
"""
|
| 92 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
| 93 |
+
x = windows.view(B, H // window_size, W // window_size, window_size,
|
| 94 |
+
window_size, -1)
|
| 95 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
| 96 |
+
return x
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def get_AEMatter_model(path_model_checkpoint):
|
| 100 |
+
|
| 101 |
+
download_model(path=path_model_checkpoint)
|
| 102 |
+
|
| 103 |
+
matmodel = AEMatter()
|
| 104 |
+
matmodel.load_state_dict(
|
| 105 |
+
torch.load(path_model_checkpoint, map_location='cpu')['model'])
|
| 106 |
+
|
| 107 |
+
matmodel = matmodel.cuda()
|
| 108 |
+
matmodel.eval()
|
| 109 |
+
|
| 110 |
+
return matmodel
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def do_infer(rawimg, trimap, matmodel):
|
| 114 |
+
trimap_nonp = trimap.copy()
|
| 115 |
+
h, w, c = rawimg.shape
|
| 116 |
+
nonph, nonpw, _ = rawimg.shape
|
| 117 |
+
newh = (((h - 1) // 32) + 1) * 32
|
| 118 |
+
neww = (((w - 1) // 32) + 1) * 32
|
| 119 |
+
padh = newh - h
|
| 120 |
+
padh1 = int(padh / 2)
|
| 121 |
+
padh2 = padh - padh1
|
| 122 |
+
padw = neww - w
|
| 123 |
+
padw1 = int(padw / 2)
|
| 124 |
+
padw2 = padw - padw1
|
| 125 |
+
|
| 126 |
+
rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
|
| 127 |
+
cv2.BORDER_REFLECT)
|
| 128 |
+
|
| 129 |
+
trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
|
| 130 |
+
cv2.BORDER_REFLECT)
|
| 131 |
+
|
| 132 |
+
h_pad, w_pad, _ = rawimg_pad.shape
|
| 133 |
+
tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
|
| 134 |
+
tritemp[:, :, 0] = (trimap_pad == 0)
|
| 135 |
+
tritemp[:, :, 1] = (trimap_pad == 128)
|
| 136 |
+
tritemp[:, :, 2] = (trimap_pad == 255)
|
| 137 |
+
tritempimgs = np.transpose(tritemp, (2, 0, 1))
|
| 138 |
+
tritempimgs = tritempimgs[np.newaxis, :, :, :]
|
| 139 |
+
img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
|
| 140 |
+
img = np.array(img, np.float32)
|
| 141 |
+
img = img / 255.
|
| 142 |
+
img = torch.from_numpy(img).cuda()
|
| 143 |
+
tritempimgs = torch.from_numpy(tritempimgs).cuda()
|
| 144 |
+
with torch.no_grad():
|
| 145 |
+
pred = matmodel(img, tritempimgs)
|
| 146 |
+
pred = pred.detach().cpu().numpy()[0]
|
| 147 |
+
pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
|
| 148 |
+
preda = pred[
|
| 149 |
+
0:1,
|
| 150 |
+
] * 255
|
| 151 |
+
preda = np.transpose(preda, (1, 2, 0))
|
| 152 |
+
preda = preda * (trimap_nonp[:, :, None]
|
| 153 |
+
== 128) + (trimap_nonp[:, :, None] == 255) * 255
|
| 154 |
+
preda = np.array(preda, np.uint8)
|
| 155 |
+
return preda
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def main():
|
| 159 |
+
ptrimap = '/home/asd/Desktop/demo/retriever_trimap.png'
|
| 160 |
+
pimgs = '/home/asd/Desktop/demo/retriever_rgb.png'
|
| 161 |
+
p_outs = 'alpha.png'
|
| 162 |
+
|
| 163 |
+
matmodel = get_AEMatter_model(
|
| 164 |
+
path_model_checkpoint='/home/asd/Desktop/AEM_RWA.ckpt')
|
| 165 |
+
|
| 166 |
+
# matmodel = AEMatter()
|
| 167 |
+
# matmodel.load_state_dict(
|
| 168 |
+
# torch.load('/home/asd/Desktop/AEM_RWA.ckpt',
|
| 169 |
+
# map_location='cpu')['model'])
|
| 170 |
+
|
| 171 |
+
# matmodel = matmodel.cuda()
|
| 172 |
+
# matmodel.eval()
|
| 173 |
+
|
| 174 |
+
rawimg = pimgs
|
| 175 |
+
trimap = ptrimap
|
| 176 |
+
rawimg = cv2.imread(rawimg, cv2.IMREAD_COLOR)
|
| 177 |
+
trimap = cv2.imread(trimap, cv2.IMREAD_GRAYSCALE)
|
| 178 |
+
trimap_nonp = trimap.copy()
|
| 179 |
+
h, w, c = rawimg.shape
|
| 180 |
+
nonph, nonpw, _ = rawimg.shape
|
| 181 |
+
newh = (((h - 1) // 32) + 1) * 32
|
| 182 |
+
neww = (((w - 1) // 32) + 1) * 32
|
| 183 |
+
padh = newh - h
|
| 184 |
+
padh1 = int(padh / 2)
|
| 185 |
+
padh2 = padh - padh1
|
| 186 |
+
padw = neww - w
|
| 187 |
+
padw1 = int(padw / 2)
|
| 188 |
+
padw2 = padw - padw1
|
| 189 |
+
rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
|
| 190 |
+
cv2.BORDER_REFLECT)
|
| 191 |
+
trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
|
| 192 |
+
cv2.BORDER_REFLECT)
|
| 193 |
+
h_pad, w_pad, _ = rawimg_pad.shape
|
| 194 |
+
tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
|
| 195 |
+
tritemp[:, :, 0] = (trimap_pad == 0)
|
| 196 |
+
tritemp[:, :, 1] = (trimap_pad == 128)
|
| 197 |
+
tritemp[:, :, 2] = (trimap_pad == 255)
|
| 198 |
+
tritempimgs = np.transpose(tritemp, (2, 0, 1))
|
| 199 |
+
tritempimgs = tritempimgs[np.newaxis, :, :, :]
|
| 200 |
+
img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
|
| 201 |
+
img = np.array(img, np.float32)
|
| 202 |
+
img = img / 255.
|
| 203 |
+
img = torch.from_numpy(img).cuda()
|
| 204 |
+
tritempimgs = torch.from_numpy(tritempimgs).cuda()
|
| 205 |
+
with torch.no_grad():
|
| 206 |
+
pred = matmodel(img, tritempimgs)
|
| 207 |
+
pred = pred.detach().cpu().numpy()[0]
|
| 208 |
+
pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
|
| 209 |
+
preda = pred[
|
| 210 |
+
0:1,
|
| 211 |
+
] * 255
|
| 212 |
+
preda = np.transpose(preda, (1, 2, 0))
|
| 213 |
+
preda = preda * (trimap_nonp[:, :, None]
|
| 214 |
+
== 128) + (trimap_nonp[:, :, None] == 255) * 255
|
| 215 |
+
preda = np.array(preda, np.uint8)
|
| 216 |
+
cv2.imwrite(p_outs, preda)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
#!/usr/bin/python3
|
| 220 |
+
class WindowAttention(nn.Module):
|
| 221 |
+
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
| 222 |
+
It supports both of shifted and non-shifted window.
|
| 223 |
+
Args:
|
| 224 |
+
dim (int): Number of input channels.
|
| 225 |
+
window_size (tuple[int]): The height and width of the window.
|
| 226 |
+
num_heads (int): Number of attention heads.
|
| 227 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 228 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
| 229 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
| 230 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
| 231 |
+
"""
|
| 232 |
+
|
| 233 |
+
def __init__(self,
|
| 234 |
+
dim,
|
| 235 |
+
window_size,
|
| 236 |
+
num_heads,
|
| 237 |
+
qkv_bias=True,
|
| 238 |
+
qk_scale=None,
|
| 239 |
+
attn_drop=0.,
|
| 240 |
+
proj_drop=0.):
|
| 241 |
+
|
| 242 |
+
super().__init__()
|
| 243 |
+
self.dim = dim
|
| 244 |
+
self.window_size = window_size # Wh, Ww
|
| 245 |
+
self.num_heads = num_heads
|
| 246 |
+
head_dim = dim // num_heads
|
| 247 |
+
self.scale = qk_scale or head_dim**-0.5
|
| 248 |
+
|
| 249 |
+
# define a parameter table of relative position bias
|
| 250 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 251 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
|
| 252 |
+
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 253 |
+
|
| 254 |
+
# get pair-wise relative position index for each token inside the window
|
| 255 |
+
coords_h = torch.arange(self.window_size[0])
|
| 256 |
+
coords_w = torch.arange(self.window_size[1])
|
| 257 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 258 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 259 |
+
relative_coords = coords_flatten[:, :,
|
| 260 |
+
None] - coords_flatten[:,
|
| 261 |
+
None, :] # 2, Wh*Ww, Wh*Ww
|
| 262 |
+
relative_coords = relative_coords.permute(
|
| 263 |
+
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 264 |
+
relative_coords[:, :,
|
| 265 |
+
0] += self.window_size[0] - 1 # shift to start from 0
|
| 266 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 267 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
| 268 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 269 |
+
self.register_buffer("relative_position_index",
|
| 270 |
+
relative_position_index)
|
| 271 |
+
|
| 272 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 273 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 274 |
+
self.proj = nn.Linear(dim, dim)
|
| 275 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 276 |
+
|
| 277 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
| 278 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 279 |
+
|
| 280 |
+
def forward(self, x, mask=None):
|
| 281 |
+
""" Forward function.
|
| 282 |
+
Args:
|
| 283 |
+
x: input features with shape of (num_windows*B, N, C)
|
| 284 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
| 285 |
+
"""
|
| 286 |
+
B_, N, C = x.shape
|
| 287 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
|
| 288 |
+
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 289 |
+
q, k, v = qkv[0], qkv[1], qkv[
|
| 290 |
+
2] # make torchscript happy (cannot use tensor as tuple)
|
| 291 |
+
|
| 292 |
+
q = q * self.scale
|
| 293 |
+
attn = (q @ k.transpose(-2, -1))
|
| 294 |
+
|
| 295 |
+
relative_position_bias = self.relative_position_bias_table[
|
| 296 |
+
self.relative_position_index.view(-1)].view(
|
| 297 |
+
self.window_size[0] * self.window_size[1],
|
| 298 |
+
self.window_size[0] * self.window_size[1],
|
| 299 |
+
-1) # Wh*Ww,Wh*Ww,nH
|
| 300 |
+
relative_position_bias = relative_position_bias.permute(
|
| 301 |
+
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 302 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
| 303 |
+
|
| 304 |
+
if mask is not None:
|
| 305 |
+
nW = mask.shape[0]
|
| 306 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N,
|
| 307 |
+
N) + mask.unsqueeze(1).unsqueeze(0)
|
| 308 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
| 309 |
+
attn = self.softmax(attn)
|
| 310 |
+
else:
|
| 311 |
+
attn = self.softmax(attn)
|
| 312 |
+
|
| 313 |
+
attn = self.attn_drop(attn)
|
| 314 |
+
|
| 315 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
| 316 |
+
x = self.proj(x)
|
| 317 |
+
x = self.proj_drop(x)
|
| 318 |
+
return x
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class SwinTransformerBlock(nn.Module):
|
| 322 |
+
""" Swin Transformer Block.
|
| 323 |
+
Args:
|
| 324 |
+
dim (int): Number of input channels.
|
| 325 |
+
num_heads (int): Number of attention heads.
|
| 326 |
+
window_size (int): Window size.
|
| 327 |
+
shift_size (int): Shift size for SW-MSA.
|
| 328 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 329 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 330 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 331 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 332 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 333 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
| 334 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
| 335 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 336 |
+
"""
|
| 337 |
+
|
| 338 |
+
def __init__(self,
|
| 339 |
+
dim,
|
| 340 |
+
num_heads,
|
| 341 |
+
window_size=7,
|
| 342 |
+
shift_size=0,
|
| 343 |
+
mlp_ratio=4.,
|
| 344 |
+
qkv_bias=True,
|
| 345 |
+
qk_scale=None,
|
| 346 |
+
drop=0.,
|
| 347 |
+
attn_drop=0.,
|
| 348 |
+
drop_path=0.,
|
| 349 |
+
act_layer=nn.GELU,
|
| 350 |
+
norm_layer=nn.LayerNorm):
|
| 351 |
+
super().__init__()
|
| 352 |
+
self.dim = dim
|
| 353 |
+
self.num_heads = num_heads
|
| 354 |
+
self.window_size = window_size
|
| 355 |
+
self.shift_size = shift_size
|
| 356 |
+
self.mlp_ratio = mlp_ratio
|
| 357 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
| 358 |
+
|
| 359 |
+
self.norm1 = norm_layer(dim)
|
| 360 |
+
self.attn = WindowAttention(dim,
|
| 361 |
+
window_size=to_2tuple(self.window_size),
|
| 362 |
+
num_heads=num_heads,
|
| 363 |
+
qkv_bias=qkv_bias,
|
| 364 |
+
qk_scale=qk_scale,
|
| 365 |
+
attn_drop=attn_drop,
|
| 366 |
+
proj_drop=drop)
|
| 367 |
+
|
| 368 |
+
self.drop_path = DropPath(
|
| 369 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
| 370 |
+
self.norm2 = norm_layer(dim)
|
| 371 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 372 |
+
self.mlp = Mlp(in_features=dim,
|
| 373 |
+
hidden_features=mlp_hidden_dim,
|
| 374 |
+
act_layer=act_layer,
|
| 375 |
+
drop=drop)
|
| 376 |
+
|
| 377 |
+
self.H = None
|
| 378 |
+
self.W = None
|
| 379 |
+
|
| 380 |
+
def forward(self, x, mask_matrix):
|
| 381 |
+
""" Forward function.
|
| 382 |
+
Args:
|
| 383 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 384 |
+
H, W: Spatial resolution of the input feature.
|
| 385 |
+
mask_matrix: Attention mask for cyclic shift.
|
| 386 |
+
"""
|
| 387 |
+
B, L, C = x.shape
|
| 388 |
+
H, W = self.H, self.W
|
| 389 |
+
assert L == H * W, "input feature has wrong size"
|
| 390 |
+
|
| 391 |
+
shortcut = x
|
| 392 |
+
x = self.norm1(x)
|
| 393 |
+
x = x.view(B, H, W, C)
|
| 394 |
+
|
| 395 |
+
# pad feature maps to multiples of window size
|
| 396 |
+
pad_l = pad_t = 0
|
| 397 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
| 398 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
| 399 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
| 400 |
+
_, Hp, Wp, _ = x.shape
|
| 401 |
+
|
| 402 |
+
# cyclic shift
|
| 403 |
+
if self.shift_size > 0:
|
| 404 |
+
shifted_x = torch.roll(x,
|
| 405 |
+
shifts=(-self.shift_size, -self.shift_size),
|
| 406 |
+
dims=(1, 2))
|
| 407 |
+
attn_mask = mask_matrix
|
| 408 |
+
else:
|
| 409 |
+
shifted_x = x
|
| 410 |
+
attn_mask = None
|
| 411 |
+
|
| 412 |
+
# partition windows
|
| 413 |
+
x_windows = window_partition(
|
| 414 |
+
shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
| 415 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size,
|
| 416 |
+
C) # nW*B, window_size*window_size, C
|
| 417 |
+
|
| 418 |
+
# W-MSA/SW-MSA
|
| 419 |
+
attn_windows = self.attn(
|
| 420 |
+
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
| 421 |
+
|
| 422 |
+
# merge windows
|
| 423 |
+
attn_windows = attn_windows.view(-1, self.window_size,
|
| 424 |
+
self.window_size, C)
|
| 425 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp,
|
| 426 |
+
Wp) # B H' W' C
|
| 427 |
+
|
| 428 |
+
# reverse cyclic shift
|
| 429 |
+
if self.shift_size > 0:
|
| 430 |
+
x = torch.roll(shifted_x,
|
| 431 |
+
shifts=(self.shift_size, self.shift_size),
|
| 432 |
+
dims=(1, 2))
|
| 433 |
+
else:
|
| 434 |
+
x = shifted_x
|
| 435 |
+
|
| 436 |
+
if pad_r > 0 or pad_b > 0:
|
| 437 |
+
x = x[:, :H, :W, :].contiguous()
|
| 438 |
+
|
| 439 |
+
x = x.view(B, H * W, C)
|
| 440 |
+
|
| 441 |
+
# FFN
|
| 442 |
+
x = shortcut + self.drop_path(x)
|
| 443 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 444 |
+
|
| 445 |
+
return x
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
class PatchMerging(nn.Module):
|
| 449 |
+
""" Patch Merging Layer
|
| 450 |
+
Args:
|
| 451 |
+
dim (int): Number of input channels.
|
| 452 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 453 |
+
"""
|
| 454 |
+
|
| 455 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
| 456 |
+
super().__init__()
|
| 457 |
+
self.dim = dim
|
| 458 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
| 459 |
+
self.norm = norm_layer(4 * dim)
|
| 460 |
+
|
| 461 |
+
def forward(self, x, H, W):
|
| 462 |
+
""" Forward function.
|
| 463 |
+
Args:
|
| 464 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 465 |
+
H, W: Spatial resolution of the input feature.
|
| 466 |
+
"""
|
| 467 |
+
B, L, C = x.shape
|
| 468 |
+
assert L == H * W, "input feature has wrong size"
|
| 469 |
+
|
| 470 |
+
x = x.view(B, H, W, C)
|
| 471 |
+
|
| 472 |
+
# padding
|
| 473 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
| 474 |
+
if pad_input:
|
| 475 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
| 476 |
+
|
| 477 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
| 478 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
| 479 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
| 480 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
| 481 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
| 482 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
| 483 |
+
|
| 484 |
+
x = self.norm(x)
|
| 485 |
+
x = self.reduction(x)
|
| 486 |
+
|
| 487 |
+
return x
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
class BasicLayer(nn.Module):
|
| 491 |
+
""" A basic Swin Transformer layer for one stage.
|
| 492 |
+
Args:
|
| 493 |
+
dim (int): Number of feature channels
|
| 494 |
+
depth (int): Depths of this stage.
|
| 495 |
+
num_heads (int): Number of attention head.
|
| 496 |
+
window_size (int): Local window size. Default: 7.
|
| 497 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 498 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 499 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 500 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 501 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 502 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
| 503 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 504 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
| 505 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 506 |
+
"""
|
| 507 |
+
|
| 508 |
+
def __init__(self,
|
| 509 |
+
dim,
|
| 510 |
+
depth,
|
| 511 |
+
num_heads,
|
| 512 |
+
window_size=7,
|
| 513 |
+
mlp_ratio=4.,
|
| 514 |
+
qkv_bias=True,
|
| 515 |
+
qk_scale=None,
|
| 516 |
+
drop=0.,
|
| 517 |
+
attn_drop=0.,
|
| 518 |
+
drop_path=0.,
|
| 519 |
+
norm_layer=nn.LayerNorm,
|
| 520 |
+
downsample=None,
|
| 521 |
+
use_checkpoint=False):
|
| 522 |
+
|
| 523 |
+
super().__init__()
|
| 524 |
+
self.window_size = window_size
|
| 525 |
+
self.shift_size = window_size // 2
|
| 526 |
+
self.depth = depth
|
| 527 |
+
self.use_checkpoint = use_checkpoint
|
| 528 |
+
|
| 529 |
+
# build blocks
|
| 530 |
+
self.blocks = nn.ModuleList([
|
| 531 |
+
SwinTransformerBlock(dim=dim,
|
| 532 |
+
num_heads=num_heads,
|
| 533 |
+
window_size=window_size,
|
| 534 |
+
shift_size=0 if
|
| 535 |
+
(i % 2 == 0) else window_size // 2,
|
| 536 |
+
mlp_ratio=mlp_ratio,
|
| 537 |
+
qkv_bias=qkv_bias,
|
| 538 |
+
qk_scale=qk_scale,
|
| 539 |
+
drop=drop,
|
| 540 |
+
attn_drop=attn_drop,
|
| 541 |
+
drop_path=drop_path[i] if isinstance(
|
| 542 |
+
drop_path, list) else drop_path,
|
| 543 |
+
norm_layer=norm_layer) for i in range(depth)
|
| 544 |
+
])
|
| 545 |
+
|
| 546 |
+
# patch merging layer
|
| 547 |
+
if downsample is not None:
|
| 548 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
| 549 |
+
else:
|
| 550 |
+
self.downsample = None
|
| 551 |
+
|
| 552 |
+
def forward(self, x, H, W):
|
| 553 |
+
""" Forward function.
|
| 554 |
+
Args:
|
| 555 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 556 |
+
H, W: Spatial resolution of the input feature.
|
| 557 |
+
"""
|
| 558 |
+
# print(x.shape,H,W)
|
| 559 |
+
# calculate attention mask for SW-MSA
|
| 560 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
| 561 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
| 562 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
| 563 |
+
h_slices = (slice(0, -self.window_size),
|
| 564 |
+
slice(-self.window_size,
|
| 565 |
+
-self.shift_size), slice(-self.shift_size, None))
|
| 566 |
+
w_slices = (slice(0, -self.window_size),
|
| 567 |
+
slice(-self.window_size,
|
| 568 |
+
-self.shift_size), slice(-self.shift_size, None))
|
| 569 |
+
cnt = 0
|
| 570 |
+
for h in h_slices:
|
| 571 |
+
for w in w_slices:
|
| 572 |
+
img_mask[:, h, w, :] = cnt
|
| 573 |
+
cnt += 1
|
| 574 |
+
|
| 575 |
+
mask_windows = window_partition(
|
| 576 |
+
img_mask, self.window_size) # nW, window_size, window_size, 1
|
| 577 |
+
|
| 578 |
+
mask_windows = mask_windows.view(-1,
|
| 579 |
+
self.window_size * self.window_size)
|
| 580 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(
|
| 581 |
+
2) # nW, ww window_size*window_size
|
| 582 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
| 583 |
+
float(-100.0)).masked_fill(
|
| 584 |
+
attn_mask == 0, float(0.0))
|
| 585 |
+
|
| 586 |
+
for blk in self.blocks:
|
| 587 |
+
blk.H, blk.W = H, W
|
| 588 |
+
if self.use_checkpoint:
|
| 589 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
| 590 |
+
else:
|
| 591 |
+
x = blk(x, attn_mask)
|
| 592 |
+
|
| 593 |
+
if self.downsample is not None:
|
| 594 |
+
x_down = self.downsample(x, H, W)
|
| 595 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
| 596 |
+
return x, H, W, x_down, Wh, Ww
|
| 597 |
+
else:
|
| 598 |
+
return x, H, W, x, H, W
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
class PatchEmbed(nn.Module):
|
| 602 |
+
""" Image to Patch Embedding
|
| 603 |
+
Args:
|
| 604 |
+
patch_size (int): Patch token size. Default: 4.
|
| 605 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 606 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 607 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
| 608 |
+
"""
|
| 609 |
+
|
| 610 |
+
def __init__(self,
|
| 611 |
+
patch_size=4,
|
| 612 |
+
in_chans=3,
|
| 613 |
+
embed_dim=96,
|
| 614 |
+
norm_layer=None):
|
| 615 |
+
|
| 616 |
+
super().__init__()
|
| 617 |
+
patch_size = to_2tuple(patch_size)
|
| 618 |
+
self.patch_size = patch_size
|
| 619 |
+
|
| 620 |
+
self.in_chans = in_chans
|
| 621 |
+
self.embed_dim = embed_dim
|
| 622 |
+
|
| 623 |
+
self.proj = nn.Conv2d(in_chans,
|
| 624 |
+
embed_dim,
|
| 625 |
+
kernel_size=patch_size,
|
| 626 |
+
stride=patch_size)
|
| 627 |
+
if norm_layer is not None:
|
| 628 |
+
self.norm = norm_layer(embed_dim)
|
| 629 |
+
else:
|
| 630 |
+
self.norm = None
|
| 631 |
+
|
| 632 |
+
def forward(self, x):
|
| 633 |
+
"""Forward function."""
|
| 634 |
+
# padding
|
| 635 |
+
_, _, H, W = x.size()
|
| 636 |
+
if W % self.patch_size[1] != 0:
|
| 637 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
| 638 |
+
if H % self.patch_size[0] != 0:
|
| 639 |
+
x = F.pad(x,
|
| 640 |
+
(0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
| 641 |
+
|
| 642 |
+
x = self.proj(x) # B C Wh Ww
|
| 643 |
+
if self.norm is not None:
|
| 644 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 645 |
+
x = x.flatten(2).transpose(1, 2)
|
| 646 |
+
x = self.norm(x)
|
| 647 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
| 648 |
+
|
| 649 |
+
return x
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
class SwinTransformer(nn.Module):
|
| 653 |
+
""" Swin Transformer backbone.
|
| 654 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
| 655 |
+
https://arxiv.org/pdf/2103.14030
|
| 656 |
+
Args:
|
| 657 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
| 658 |
+
used in absolute postion embedding. Default 224.
|
| 659 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
| 660 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 661 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 662 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
| 663 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
| 664 |
+
window_size (int): Window size. Default: 7.
|
| 665 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 666 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
| 667 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
| 668 |
+
drop_rate (float): Dropout rate.
|
| 669 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
| 670 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
| 671 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
| 672 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
| 673 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
| 674 |
+
out_indices (Sequence[int]): Output from which stages.
|
| 675 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
| 676 |
+
-1 means not freezing any parameters.
|
| 677 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 678 |
+
"""
|
| 679 |
+
|
| 680 |
+
def __init__(self,
|
| 681 |
+
pretrain_img_size=224,
|
| 682 |
+
patch_size=4,
|
| 683 |
+
in_chans=3,
|
| 684 |
+
embed_dim=96,
|
| 685 |
+
depths=[2, 2, 6, 2],
|
| 686 |
+
num_heads=[3, 6, 12, 24],
|
| 687 |
+
window_size=7,
|
| 688 |
+
mlp_ratio=4.,
|
| 689 |
+
qkv_bias=True,
|
| 690 |
+
qk_scale=None,
|
| 691 |
+
drop_rate=0.,
|
| 692 |
+
attn_drop_rate=0.,
|
| 693 |
+
drop_path_rate=0.2,
|
| 694 |
+
norm_layer=nn.LayerNorm,
|
| 695 |
+
ape=False,
|
| 696 |
+
patch_norm=True,
|
| 697 |
+
out_indices=(0, 1, 2, 3),
|
| 698 |
+
frozen_stages=-1,
|
| 699 |
+
use_checkpoint=False):
|
| 700 |
+
|
| 701 |
+
super().__init__()
|
| 702 |
+
|
| 703 |
+
self.pretrain_img_size = pretrain_img_size
|
| 704 |
+
self.num_layers = len(depths)
|
| 705 |
+
self.embed_dim = embed_dim
|
| 706 |
+
self.ape = ape
|
| 707 |
+
self.patch_norm = patch_norm
|
| 708 |
+
self.out_indices = out_indices
|
| 709 |
+
self.frozen_stages = frozen_stages
|
| 710 |
+
|
| 711 |
+
# split image into non-overlapping patches
|
| 712 |
+
self.patch_embed = PatchEmbed(
|
| 713 |
+
patch_size=patch_size,
|
| 714 |
+
in_chans=in_chans,
|
| 715 |
+
embed_dim=embed_dim,
|
| 716 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
| 717 |
+
|
| 718 |
+
# absolute position embedding
|
| 719 |
+
if self.ape:
|
| 720 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
| 721 |
+
patch_size = to_2tuple(patch_size)
|
| 722 |
+
patches_resolution = [
|
| 723 |
+
pretrain_img_size[0] // patch_size[0],
|
| 724 |
+
pretrain_img_size[1] // patch_size[1]
|
| 725 |
+
]
|
| 726 |
+
|
| 727 |
+
self.absolute_pos_embed = nn.Parameter(
|
| 728 |
+
torch.zeros(1, embed_dim, patches_resolution[0],
|
| 729 |
+
patches_resolution[1]))
|
| 730 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
| 731 |
+
|
| 732 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 733 |
+
|
| 734 |
+
# stochastic depth
|
| 735 |
+
dpr = [
|
| 736 |
+
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
| 737 |
+
] # stochastic depth decay rule
|
| 738 |
+
|
| 739 |
+
# build layers
|
| 740 |
+
self.layers = nn.ModuleList()
|
| 741 |
+
for i_layer in range(self.num_layers):
|
| 742 |
+
layer = BasicLayer(
|
| 743 |
+
dim=int(embed_dim * 2**i_layer),
|
| 744 |
+
depth=depths[i_layer],
|
| 745 |
+
num_heads=num_heads[i_layer],
|
| 746 |
+
window_size=window_size,
|
| 747 |
+
mlp_ratio=mlp_ratio,
|
| 748 |
+
qkv_bias=qkv_bias,
|
| 749 |
+
qk_scale=qk_scale,
|
| 750 |
+
drop=drop_rate,
|
| 751 |
+
attn_drop=attn_drop_rate,
|
| 752 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
| 753 |
+
norm_layer=norm_layer,
|
| 754 |
+
downsample=PatchMerging if
|
| 755 |
+
(i_layer < self.num_layers - 1) else None,
|
| 756 |
+
use_checkpoint=use_checkpoint)
|
| 757 |
+
self.layers.append(layer)
|
| 758 |
+
|
| 759 |
+
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
|
| 760 |
+
self.num_features = num_features
|
| 761 |
+
|
| 762 |
+
# add a norm layer for each output
|
| 763 |
+
for i_layer in out_indices:
|
| 764 |
+
layer = norm_layer(num_features[i_layer])
|
| 765 |
+
layer_name = f'norm{i_layer}'
|
| 766 |
+
self.add_module(layer_name, layer)
|
| 767 |
+
|
| 768 |
+
self._freeze_stages()
|
| 769 |
+
|
| 770 |
+
def _freeze_stages(self):
|
| 771 |
+
if self.frozen_stages >= 0:
|
| 772 |
+
self.patch_embed.eval()
|
| 773 |
+
for param in self.patch_embed.parameters():
|
| 774 |
+
param.requires_grad = False
|
| 775 |
+
|
| 776 |
+
if self.frozen_stages >= 1 and self.ape:
|
| 777 |
+
self.absolute_pos_embed.requires_grad = False
|
| 778 |
+
|
| 779 |
+
if self.frozen_stages >= 2:
|
| 780 |
+
self.pos_drop.eval()
|
| 781 |
+
for i in range(0, self.frozen_stages - 1):
|
| 782 |
+
m = self.layers[i]
|
| 783 |
+
m.eval()
|
| 784 |
+
for param in m.parameters():
|
| 785 |
+
param.requires_grad = False
|
| 786 |
+
|
| 787 |
+
def init_weights(self, pretrained=None):
|
| 788 |
+
"""Initialize the weights in backbone.
|
| 789 |
+
Args:
|
| 790 |
+
pretrained (str, optional): Path to pre-trained weights.
|
| 791 |
+
Defaults to None.
|
| 792 |
+
"""
|
| 793 |
+
|
| 794 |
+
def forward(self, x):
|
| 795 |
+
"""Forward function."""
|
| 796 |
+
x = self.patch_embed(x)
|
| 797 |
+
|
| 798 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 799 |
+
if self.ape:
|
| 800 |
+
# interpolate the position embedding to the corresponding size
|
| 801 |
+
absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
|
| 802 |
+
size=(Wh, Ww),
|
| 803 |
+
mode='bicubic')
|
| 804 |
+
x = (x + absolute_pos_embed).flatten(2).transpose(1,
|
| 805 |
+
2) # B Wh*Ww C
|
| 806 |
+
else:
|
| 807 |
+
x = x.flatten(2).transpose(1, 2)
|
| 808 |
+
x = self.pos_drop(x)
|
| 809 |
+
|
| 810 |
+
outs = []
|
| 811 |
+
for i in range(self.num_layers):
|
| 812 |
+
layer = self.layers[i]
|
| 813 |
+
|
| 814 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
| 815 |
+
|
| 816 |
+
if i in self.out_indices:
|
| 817 |
+
norm_layer = getattr(self, f'norm{i}')
|
| 818 |
+
x_out = norm_layer(x_out)
|
| 819 |
+
|
| 820 |
+
out = x_out.view(-1, H, W,
|
| 821 |
+
self.num_features[i]).permute(0, 3, 1,
|
| 822 |
+
2).contiguous()
|
| 823 |
+
outs.append(out)
|
| 824 |
+
|
| 825 |
+
return tuple(outs)
|
| 826 |
+
|
| 827 |
+
def train(self, mode=True):
|
| 828 |
+
"""Convert the model into training mode while keep layers freezed."""
|
| 829 |
+
super(SwinTransformer, self).train(mode)
|
| 830 |
+
self._freeze_stages()
|
| 831 |
+
|
| 832 |
+
|
| 833 |
+
class Mlp(nn.Module):
|
| 834 |
+
""" Multilayer perceptron."""
|
| 835 |
+
|
| 836 |
+
def __init__(self,
|
| 837 |
+
in_features,
|
| 838 |
+
hidden_features=None,
|
| 839 |
+
out_features=None,
|
| 840 |
+
act_layer=nn.GELU,
|
| 841 |
+
drop=0.):
|
| 842 |
+
super().__init__()
|
| 843 |
+
out_features = out_features or in_features
|
| 844 |
+
hidden_features = hidden_features or in_features
|
| 845 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 846 |
+
self.act = act_layer()
|
| 847 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 848 |
+
self.drop = nn.Dropout(drop)
|
| 849 |
+
|
| 850 |
+
def forward(self, x):
|
| 851 |
+
x = self.fc1(x)
|
| 852 |
+
x = self.act(x)
|
| 853 |
+
x = self.drop(x)
|
| 854 |
+
x = self.fc2(x)
|
| 855 |
+
x = self.drop(x)
|
| 856 |
+
return x
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
class ResBlock(nn.Module):
|
| 860 |
+
|
| 861 |
+
def __init__(self, inc, midc):
|
| 862 |
+
super(ResBlock, self).__init__()
|
| 863 |
+
self.conv1 = nn.Conv2d(inc,
|
| 864 |
+
midc,
|
| 865 |
+
kernel_size=1,
|
| 866 |
+
stride=1,
|
| 867 |
+
padding=0,
|
| 868 |
+
bias=True)
|
| 869 |
+
self.gn1 = nn.GroupNorm(16, midc)
|
| 870 |
+
self.conv2 = nn.Conv2d(midc,
|
| 871 |
+
midc,
|
| 872 |
+
kernel_size=3,
|
| 873 |
+
stride=1,
|
| 874 |
+
padding=1,
|
| 875 |
+
bias=True)
|
| 876 |
+
self.gn2 = nn.GroupNorm(16, midc)
|
| 877 |
+
self.conv3 = nn.Conv2d(midc,
|
| 878 |
+
inc,
|
| 879 |
+
kernel_size=1,
|
| 880 |
+
stride=1,
|
| 881 |
+
padding=0,
|
| 882 |
+
bias=True)
|
| 883 |
+
self.relu = nn.LeakyReLU(0.1)
|
| 884 |
+
|
| 885 |
+
def forward(self, x):
|
| 886 |
+
x_ = x
|
| 887 |
+
x = self.conv1(x)
|
| 888 |
+
x = self.gn1(x)
|
| 889 |
+
x = self.relu(x)
|
| 890 |
+
x = self.conv2(x)
|
| 891 |
+
x = self.gn2(x)
|
| 892 |
+
x = self.relu(x)
|
| 893 |
+
x = self.conv3(x)
|
| 894 |
+
x = x + x_
|
| 895 |
+
x = self.relu(x)
|
| 896 |
+
return x
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
class AEALblock(nn.Module):
|
| 900 |
+
|
| 901 |
+
def __init__(self,
|
| 902 |
+
d_model,
|
| 903 |
+
nhead,
|
| 904 |
+
dim_feedforward=512,
|
| 905 |
+
dropout=0.0,
|
| 906 |
+
layer_norm_eps=1e-5,
|
| 907 |
+
batch_first=True,
|
| 908 |
+
norm_first=False,
|
| 909 |
+
width=5):
|
| 910 |
+
super(AEALblock, self).__init__()
|
| 911 |
+
self.self_attn2 = nn.MultiheadAttention(d_model // 2,
|
| 912 |
+
nhead // 2,
|
| 913 |
+
dropout=dropout,
|
| 914 |
+
batch_first=batch_first)
|
| 915 |
+
self.self_attn1 = nn.MultiheadAttention(d_model // 2,
|
| 916 |
+
nhead // 2,
|
| 917 |
+
dropout=dropout,
|
| 918 |
+
batch_first=batch_first)
|
| 919 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 920 |
+
self.dropout = nn.Dropout(dropout)
|
| 921 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 922 |
+
self.norm_first = norm_first
|
| 923 |
+
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
| 924 |
+
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
| 925 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 926 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 927 |
+
self.activation = nn.ReLU()
|
| 928 |
+
self.width = width
|
| 929 |
+
self.trans = nn.Sequential(
|
| 930 |
+
nn.Conv2d(d_model + 512, d_model // 2, 1, 1, 0),
|
| 931 |
+
ResBlock(d_model // 2, d_model // 4),
|
| 932 |
+
nn.Conv2d(d_model // 2, d_model, 1, 1, 0))
|
| 933 |
+
self.gamma = nn.Parameter(torch.zeros(1))
|
| 934 |
+
|
| 935 |
+
def forward(
|
| 936 |
+
self,
|
| 937 |
+
src,
|
| 938 |
+
feats,
|
| 939 |
+
):
|
| 940 |
+
src = self.gamma * self.trans(torch.cat([src, feats], 1)) + src
|
| 941 |
+
b, c, h, w = src.shape
|
| 942 |
+
x1 = src[:, 0:c // 2]
|
| 943 |
+
x1_ = rearrange(x1, 'b c (h1 h2) w -> b c h1 h2 w', h2=self.width)
|
| 944 |
+
x1_ = rearrange(x1_, 'b c h1 h2 w -> (b h1) (h2 w) c')
|
| 945 |
+
x2 = src[:, c // 2:]
|
| 946 |
+
x2_ = rearrange(x2, 'b c h (w1 w2) -> b c h w1 w2', w2=self.width)
|
| 947 |
+
x2_ = rearrange(x2_, 'b c h w1 w2 -> (b w1) (h w2) c')
|
| 948 |
+
x = rearrange(src, 'b c h w-> b (h w) c')
|
| 949 |
+
x = self.norm1(x + self._sa_block(x1_, x2_, h, w))
|
| 950 |
+
x = self.norm2(x + self._ff_block(x))
|
| 951 |
+
x = rearrange(x, 'b (h w) c->b c h w', h=h, w=w)
|
| 952 |
+
return x
|
| 953 |
+
|
| 954 |
+
def _sa_block(self, x1, x2, h, w):
|
| 955 |
+
x1 = self.self_attn1(x1,
|
| 956 |
+
x1,
|
| 957 |
+
x1,
|
| 958 |
+
attn_mask=None,
|
| 959 |
+
key_padding_mask=None,
|
| 960 |
+
need_weights=False)[0]
|
| 961 |
+
|
| 962 |
+
x2 = self.self_attn2(x2,
|
| 963 |
+
x2,
|
| 964 |
+
x2,
|
| 965 |
+
attn_mask=None,
|
| 966 |
+
key_padding_mask=None,
|
| 967 |
+
need_weights=False)[0]
|
| 968 |
+
|
| 969 |
+
x1 = rearrange(x1,
|
| 970 |
+
'(b h1) (h2 w) c-> b (h1 h2 w) c',
|
| 971 |
+
h2=self.width,
|
| 972 |
+
h1=h // self.width)
|
| 973 |
+
x2 = rearrange(x2,
|
| 974 |
+
' (b w1) (h w2) c-> b (h w1 w2) c',
|
| 975 |
+
w2=self.width,
|
| 976 |
+
w1=w // self.width)
|
| 977 |
+
x = torch.cat([x1, x2], dim=2)
|
| 978 |
+
return self.dropout1(x)
|
| 979 |
+
|
| 980 |
+
def _ff_block(self, x):
|
| 981 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
| 982 |
+
return self.dropout2(x)
|
| 983 |
+
|
| 984 |
+
|
| 985 |
+
class AEMatter(nn.Module):
|
| 986 |
+
|
| 987 |
+
def __init__(self):
|
| 988 |
+
super(AEMatter, self).__init__()
|
| 989 |
+
trans = SwinTransformer(pretrain_img_size=224,
|
| 990 |
+
embed_dim=96,
|
| 991 |
+
depths=[2, 2, 6, 2],
|
| 992 |
+
num_heads=[3, 6, 12, 24],
|
| 993 |
+
window_size=7,
|
| 994 |
+
ape=False,
|
| 995 |
+
drop_path_rate=0.2,
|
| 996 |
+
patch_norm=True,
|
| 997 |
+
use_checkpoint=False)
|
| 998 |
+
|
| 999 |
+
# trans.load_state_dict(torch.load(
|
| 1000 |
+
# '/home/asd/Desktop/swin_tiny_patch4_window7_224.pth',
|
| 1001 |
+
# map_location="cpu")["model"],
|
| 1002 |
+
# strict=False)
|
| 1003 |
+
|
| 1004 |
+
trans.patch_embed.proj = nn.Conv2d(64, 96, 3, 2, 1)
|
| 1005 |
+
|
| 1006 |
+
self.start_conv0 = nn.Sequential(nn.Conv2d(6, 48, 3, 1, 1),
|
| 1007 |
+
nn.PReLU(48))
|
| 1008 |
+
|
| 1009 |
+
self.start_conv = nn.Sequential(nn.Conv2d(48, 64, 3, 2,
|
| 1010 |
+
1), nn.PReLU(64),
|
| 1011 |
+
nn.Conv2d(64, 64, 3, 1, 1),
|
| 1012 |
+
nn.PReLU(64))
|
| 1013 |
+
|
| 1014 |
+
self.trans = trans
|
| 1015 |
+
self.conv1 = nn.Sequential(
|
| 1016 |
+
nn.Conv2d(in_channels=640 + 768,
|
| 1017 |
+
out_channels=256,
|
| 1018 |
+
kernel_size=1,
|
| 1019 |
+
stride=1,
|
| 1020 |
+
padding=0,
|
| 1021 |
+
bias=True))
|
| 1022 |
+
self.conv2 = nn.Sequential(
|
| 1023 |
+
nn.Conv2d(in_channels=256 + 384,
|
| 1024 |
+
out_channels=256,
|
| 1025 |
+
kernel_size=1,
|
| 1026 |
+
stride=1,
|
| 1027 |
+
padding=0,
|
| 1028 |
+
bias=True), )
|
| 1029 |
+
self.conv3 = nn.Sequential(
|
| 1030 |
+
nn.Conv2d(in_channels=256 + 192,
|
| 1031 |
+
out_channels=192,
|
| 1032 |
+
kernel_size=1,
|
| 1033 |
+
stride=1,
|
| 1034 |
+
padding=0,
|
| 1035 |
+
bias=True), )
|
| 1036 |
+
self.conv4 = nn.Sequential(
|
| 1037 |
+
nn.Conv2d(in_channels=192 + 96,
|
| 1038 |
+
out_channels=128,
|
| 1039 |
+
kernel_size=1,
|
| 1040 |
+
stride=1,
|
| 1041 |
+
padding=0,
|
| 1042 |
+
bias=True), )
|
| 1043 |
+
self.ctran0 = BasicLayer(256, 3, 8, 7, drop_path=0.09)
|
| 1044 |
+
self.ctran1 = BasicLayer(256, 3, 8, 7, drop_path=0.07)
|
| 1045 |
+
self.ctran2 = BasicLayer(192, 3, 6, 7, drop_path=0.05)
|
| 1046 |
+
self.ctran3 = BasicLayer(128, 3, 4, 7, drop_path=0.03)
|
| 1047 |
+
self.conv5 = nn.Sequential(
|
| 1048 |
+
nn.Conv2d(in_channels=192,
|
| 1049 |
+
out_channels=64,
|
| 1050 |
+
kernel_size=3,
|
| 1051 |
+
stride=1,
|
| 1052 |
+
padding=1,
|
| 1053 |
+
bias=True), nn.PReLU(64),
|
| 1054 |
+
nn.Conv2d(in_channels=64,
|
| 1055 |
+
out_channels=64,
|
| 1056 |
+
kernel_size=3,
|
| 1057 |
+
stride=1,
|
| 1058 |
+
padding=1,
|
| 1059 |
+
bias=True), nn.PReLU(64),
|
| 1060 |
+
nn.Conv2d(in_channels=64,
|
| 1061 |
+
out_channels=48,
|
| 1062 |
+
kernel_size=3,
|
| 1063 |
+
stride=1,
|
| 1064 |
+
padding=1,
|
| 1065 |
+
bias=True), nn.PReLU(48))
|
| 1066 |
+
self.convo = nn.Sequential(
|
| 1067 |
+
nn.Conv2d(in_channels=48 + 48 + 6,
|
| 1068 |
+
out_channels=32,
|
| 1069 |
+
kernel_size=3,
|
| 1070 |
+
stride=1,
|
| 1071 |
+
padding=1,
|
| 1072 |
+
bias=True), nn.PReLU(32),
|
| 1073 |
+
nn.Conv2d(in_channels=32,
|
| 1074 |
+
out_channels=32,
|
| 1075 |
+
kernel_size=3,
|
| 1076 |
+
stride=1,
|
| 1077 |
+
padding=1,
|
| 1078 |
+
bias=True), nn.PReLU(32),
|
| 1079 |
+
nn.Conv2d(in_channels=32,
|
| 1080 |
+
out_channels=1,
|
| 1081 |
+
kernel_size=3,
|
| 1082 |
+
stride=1,
|
| 1083 |
+
padding=1,
|
| 1084 |
+
bias=True))
|
| 1085 |
+
self.up = nn.Upsample(scale_factor=2,
|
| 1086 |
+
mode='bilinear',
|
| 1087 |
+
align_corners=False)
|
| 1088 |
+
self.upn = nn.Upsample(scale_factor=2, mode='nearest')
|
| 1089 |
+
self.apptrans = nn.Sequential(
|
| 1090 |
+
nn.Conv2d(256 + 384, 256, 1, 1, bias=True), ResBlock(256, 128),
|
| 1091 |
+
ResBlock(256, 128), nn.Conv2d(256, 512, 2, 2, bias=True),
|
| 1092 |
+
ResBlock(512, 128))
|
| 1093 |
+
self.emb = nn.Sequential(nn.Conv2d(768, 640, 1, 1, 0),
|
| 1094 |
+
ResBlock(640, 160))
|
| 1095 |
+
self.embdp = nn.Sequential(nn.Conv2d(640, 640, 1, 1, 0))
|
| 1096 |
+
self.h2l = nn.Conv2d(768, 256, 1, 1, 0)
|
| 1097 |
+
self.width = 5
|
| 1098 |
+
self.trans1 = AEALblock(d_model=640,
|
| 1099 |
+
nhead=20,
|
| 1100 |
+
dim_feedforward=2048,
|
| 1101 |
+
dropout=0.2,
|
| 1102 |
+
width=self.width)
|
| 1103 |
+
self.trans2 = AEALblock(d_model=640,
|
| 1104 |
+
nhead=20,
|
| 1105 |
+
dim_feedforward=2048,
|
| 1106 |
+
dropout=0.2,
|
| 1107 |
+
width=self.width)
|
| 1108 |
+
self.trans3 = AEALblock(d_model=640,
|
| 1109 |
+
nhead=20,
|
| 1110 |
+
dim_feedforward=2048,
|
| 1111 |
+
dropout=0.2,
|
| 1112 |
+
width=self.width)
|
| 1113 |
+
|
| 1114 |
+
def aeal(self, x, sem):
|
| 1115 |
+
xe = self.emb(x)
|
| 1116 |
+
x_ = xe
|
| 1117 |
+
x_ = self.embdp(x_)
|
| 1118 |
+
b, c, h1, w1 = x_.shape
|
| 1119 |
+
bnew_ph = int(np.ceil(h1 / self.width) * self.width) - h1
|
| 1120 |
+
bnew_pw = int(np.ceil(w1 / self.width) * self.width) - w1
|
| 1121 |
+
newph1 = bnew_ph // 2
|
| 1122 |
+
newph2 = bnew_ph - newph1
|
| 1123 |
+
newpw1 = bnew_pw // 2
|
| 1124 |
+
newpw2 = bnew_pw - newpw1
|
| 1125 |
+
x_ = F.pad(x_, (newpw1, newpw2, newph1, newph2))
|
| 1126 |
+
sem = F.pad(sem, (newpw1, newpw2, newph1, newph2))
|
| 1127 |
+
x_ = self.trans1(x_, sem)
|
| 1128 |
+
x_ = self.trans2(x_, sem)
|
| 1129 |
+
x_ = self.trans3(x_, sem)
|
| 1130 |
+
x_ = x_[:, :, newph1:h1 + newph1, newpw1:w1 + newpw1]
|
| 1131 |
+
return x_
|
| 1132 |
+
|
| 1133 |
+
def forward(self, x, y):
|
| 1134 |
+
inputs = torch.cat((x, y), 1)
|
| 1135 |
+
x = self.start_conv0(inputs)
|
| 1136 |
+
x_ = self.start_conv(x)
|
| 1137 |
+
x1, x2, x3, x4 = self.trans(x_)
|
| 1138 |
+
x4h = self.h2l(x4)
|
| 1139 |
+
x3s = self.apptrans(torch.cat([x3, self.upn(x4h)], 1))
|
| 1140 |
+
x4_ = self.aeal(x4, x3s)
|
| 1141 |
+
x4 = torch.cat((x4, x4_), 1)
|
| 1142 |
+
X4 = self.conv1(x4)
|
| 1143 |
+
wh, ww = X4.shape[2], X4.shape[3]
|
| 1144 |
+
X4 = rearrange(X4, 'b c h w -> b (h w) c')
|
| 1145 |
+
X4, _, _, _, _, _ = self.ctran0(X4, wh, ww)
|
| 1146 |
+
X4 = rearrange(X4, 'b (h w) c -> b c h w', h=wh, w=ww)
|
| 1147 |
+
X3 = self.up(X4)
|
| 1148 |
+
X3 = torch.cat((x3, X3), 1)
|
| 1149 |
+
X3 = self.conv2(X3)
|
| 1150 |
+
wh, ww = X3.shape[2], X3.shape[3]
|
| 1151 |
+
X3 = rearrange(X3, 'b c h w -> b (h w) c')
|
| 1152 |
+
X3, _, _, _, _, _ = self.ctran1(X3, wh, ww)
|
| 1153 |
+
X3 = rearrange(X3, 'b (h w) c -> b c h w', h=wh, w=ww)
|
| 1154 |
+
X2 = self.up(X3)
|
| 1155 |
+
X2 = torch.cat((x2, X2), 1)
|
| 1156 |
+
X2 = self.conv3(X2)
|
| 1157 |
+
wh, ww = X2.shape[2], X2.shape[3]
|
| 1158 |
+
X2 = rearrange(X2, 'b c h w -> b (h w) c')
|
| 1159 |
+
X2, _, _, _, _, _ = self.ctran2(X2, wh, ww)
|
| 1160 |
+
X2 = rearrange(X2, 'b (h w) c -> b c h w', h=wh, w=ww)
|
| 1161 |
+
X1 = self.up(X2)
|
| 1162 |
+
X1 = torch.cat((x1, X1), 1)
|
| 1163 |
+
X1 = self.conv4(X1)
|
| 1164 |
+
wh, ww = X1.shape[2], X1.shape[3]
|
| 1165 |
+
X1 = rearrange(X1, 'b c h w -> b (h w) c')
|
| 1166 |
+
X1, _, _, _, _, _ = self.ctran3(X1, wh, ww)
|
| 1167 |
+
X1 = rearrange(X1, 'b (h w) c -> b c h w', h=wh, w=ww)
|
| 1168 |
+
X0 = self.up(X1)
|
| 1169 |
+
X0 = torch.cat((x_, X0), 1)
|
| 1170 |
+
X0 = self.conv5(X0)
|
| 1171 |
+
X = self.up(X0)
|
| 1172 |
+
X = torch.cat((inputs, x, X), 1)
|
| 1173 |
+
alpha = self.convo(X)
|
| 1174 |
+
alpha = torch.clamp(alpha, min=0, max=1)
|
| 1175 |
+
return alpha
|
| 1176 |
+
|
| 1177 |
+
|
| 1178 |
+
class load_AEMatter_Model:
|
| 1179 |
+
|
| 1180 |
+
def __init__(self):
|
| 1181 |
+
pass
|
| 1182 |
+
|
| 1183 |
+
@classmethod
|
| 1184 |
+
def INPUT_TYPES(s):
|
| 1185 |
+
return {
|
| 1186 |
+
"required": {},
|
| 1187 |
+
}
|
| 1188 |
+
|
| 1189 |
+
RETURN_TYPES = ("AEMatter_Model", )
|
| 1190 |
+
FUNCTION = "test"
|
| 1191 |
+
CATEGORY = "AEMatter"
|
| 1192 |
+
|
| 1193 |
+
def test(self):
|
| 1194 |
+
return (get_AEMatter_model(get_model_path()), )
|
| 1195 |
+
|
| 1196 |
+
|
| 1197 |
+
class run_AEMatter_inference:
|
| 1198 |
+
|
| 1199 |
+
def __init__(self):
|
| 1200 |
+
pass
|
| 1201 |
+
|
| 1202 |
+
@classmethod
|
| 1203 |
+
def INPUT_TYPES(s):
|
| 1204 |
+
return {
|
| 1205 |
+
"required": {
|
| 1206 |
+
"image": ("IMAGE", ),
|
| 1207 |
+
"trimap": ("MASK", ),
|
| 1208 |
+
"AEMatter_Model": ("AEMatter_Model", ),
|
| 1209 |
+
},
|
| 1210 |
+
}
|
| 1211 |
+
|
| 1212 |
+
RETURN_TYPES = ("MASK", )
|
| 1213 |
+
FUNCTION = "test"
|
| 1214 |
+
CATEGORY = "AEMatter"
|
| 1215 |
+
|
| 1216 |
+
def test(
|
| 1217 |
+
self,
|
| 1218 |
+
image,
|
| 1219 |
+
trimap,
|
| 1220 |
+
AEMatter_Model,
|
| 1221 |
+
):
|
| 1222 |
+
|
| 1223 |
+
ret = []
|
| 1224 |
+
batch_size = image.shape[0]
|
| 1225 |
+
|
| 1226 |
+
for i in range(batch_size):
|
| 1227 |
+
tmp_i = from_torch_image(image[i])
|
| 1228 |
+
tmp_m = from_torch_image(trimap[i])
|
| 1229 |
+
tmp = do_infer(tmp_i, tmp_m, AEMatter_Model)
|
| 1230 |
+
ret.append(tmp)
|
| 1231 |
+
|
| 1232 |
+
ret = to_torch_image(np.array(ret))
|
| 1233 |
+
ret = ret.squeeze(-1)
|
| 1234 |
+
print(ret.shape)
|
| 1235 |
+
|
| 1236 |
+
return ret
|
| 1237 |
+
|
| 1238 |
+
|
| 1239 |
+
#!/usr/bin/python3
|
| 1240 |
+
NODE_CLASS_MAPPINGS = {
|
| 1241 |
+
'load_AEMatter_Model': load_AEMatter_Model,
|
| 1242 |
+
'run_AEMatter_inference': run_AEMatter_inference,
|
| 1243 |
+
}
|
| 1244 |
+
|
| 1245 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
| 1246 |
+
'load_AEMatter_Model': 'load_AEMatter_Model',
|
| 1247 |
+
'run_AEMatter_inference': 'run_AEMatter_inference',
|
| 1248 |
+
}
|
ComfyUI_AEMatter/AEMatter.run.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
. "${HOME}/dbnew.sh"
|
| 3 |
+
python3 './AEMatter.py'
|
ComfyUI_AEMatter/README.org
ADDED
|
@@ -0,0 +1,1357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
* COMMENT SAMPLE
|
| 2 |
+
|
| 3 |
+
** AEMatter.import.py
|
| 4 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.import.py
|
| 5 |
+
#+end_src
|
| 6 |
+
|
| 7 |
+
** AEMatter.function.py
|
| 8 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
|
| 9 |
+
#+end_src
|
| 10 |
+
|
| 11 |
+
** AEMatter.class.py
|
| 12 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
|
| 13 |
+
#+end_src
|
| 14 |
+
|
| 15 |
+
** AEMatter.execute.py
|
| 16 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.execute.py
|
| 17 |
+
#+end_src
|
| 18 |
+
|
| 19 |
+
** AEMatter.unify.sh
|
| 20 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./AEMatter.unify.sh
|
| 21 |
+
#+end_src
|
| 22 |
+
|
| 23 |
+
** AEMatter.run.sh
|
| 24 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./AEMatter.run.sh
|
| 25 |
+
#+end_src
|
| 26 |
+
|
| 27 |
+
* Code for AEMatter inference
|
| 28 |
+
|
| 29 |
+
** AEMatter.import.py
|
| 30 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.import.py
|
| 31 |
+
import cv2
|
| 32 |
+
import math
|
| 33 |
+
import numpy as np
|
| 34 |
+
import os
|
| 35 |
+
import random
|
| 36 |
+
import wget
|
| 37 |
+
|
| 38 |
+
import torch
|
| 39 |
+
import torch.nn as nn
|
| 40 |
+
from torch.nn import init
|
| 41 |
+
import torch.nn.functional as F
|
| 42 |
+
import torch.utils.checkpoint as checkpoint
|
| 43 |
+
|
| 44 |
+
from collections import OrderedDict
|
| 45 |
+
from einops import rearrange, repeat
|
| 46 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 47 |
+
|
| 48 |
+
import folder_paths
|
| 49 |
+
from folder_paths import models_dir
|
| 50 |
+
#+end_src
|
| 51 |
+
|
| 52 |
+
** Functions to prepare directory structure and download models
|
| 53 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
|
| 54 |
+
def mkdir_safe(out_path):
|
| 55 |
+
if type(out_path) == str:
|
| 56 |
+
if len(out_path) > 0:
|
| 57 |
+
if not os.path.exists(out_path):
|
| 58 |
+
os.mkdir(out_path)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_model_path():
|
| 62 |
+
import folder_paths
|
| 63 |
+
from folder_paths import models_dir
|
| 64 |
+
|
| 65 |
+
path_file_model = models_dir
|
| 66 |
+
mkdir_safe(out_path=path_file_model)
|
| 67 |
+
|
| 68 |
+
path_file_model = os.path.join(path_file_model, 'AEMatter')
|
| 69 |
+
mkdir_safe(out_path=path_file_model)
|
| 70 |
+
|
| 71 |
+
path_file_model = os.path.join(path_file_model, 'AEM_RWA.ckpt')
|
| 72 |
+
|
| 73 |
+
return path_file_model
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def download_model(path):
|
| 77 |
+
if not os.path.exists(path):
|
| 78 |
+
wget.download(
|
| 79 |
+
'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/AEMatter/AEM_RWA.ckpt?download=true',
|
| 80 |
+
out=path)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def from_torch_image(image):
|
| 84 |
+
image = image.cpu().numpy() * 255.0
|
| 85 |
+
image = np.clip(image, 0, 255).astype(np.uint8)
|
| 86 |
+
return image
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def to_torch_image(image):
|
| 90 |
+
image = image.astype(dtype=np.float32)
|
| 91 |
+
image /= 255.0
|
| 92 |
+
image = torch.from_numpy(image)
|
| 93 |
+
return image
|
| 94 |
+
#+end_src
|
| 95 |
+
|
| 96 |
+
** AEMatter.function.py
|
| 97 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
|
| 98 |
+
def window_partition(x, window_size):
|
| 99 |
+
"""
|
| 100 |
+
Args:
|
| 101 |
+
x: (B, H, W, C)
|
| 102 |
+
window_size (int): window size
|
| 103 |
+
Returns:
|
| 104 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 105 |
+
"""
|
| 106 |
+
B, H, W, C = x.shape
|
| 107 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
| 108 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
| 109 |
+
return windows
|
| 110 |
+
#+end_src
|
| 111 |
+
|
| 112 |
+
** AEMatter.function.py
|
| 113 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
|
| 114 |
+
def window_reverse(windows, window_size, H, W):
|
| 115 |
+
"""
|
| 116 |
+
Args:
|
| 117 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 118 |
+
window_size (int): Window size
|
| 119 |
+
H (int): Height of image
|
| 120 |
+
W (int): Width of image
|
| 121 |
+
Returns:
|
| 122 |
+
x: (B, H, W, C)
|
| 123 |
+
"""
|
| 124 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
| 125 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
| 126 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
| 127 |
+
return x
|
| 128 |
+
#+end_src
|
| 129 |
+
|
| 130 |
+
** AEMatter.class.py
|
| 131 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
|
| 132 |
+
class WindowAttention(nn.Module):
|
| 133 |
+
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
| 134 |
+
It supports both of shifted and non-shifted window.
|
| 135 |
+
Args:
|
| 136 |
+
dim (int): Number of input channels.
|
| 137 |
+
window_size (tuple[int]): The height and width of the window.
|
| 138 |
+
num_heads (int): Number of attention heads.
|
| 139 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 140 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
| 141 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
| 142 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
def __init__(self,
|
| 146 |
+
dim,
|
| 147 |
+
window_size,
|
| 148 |
+
num_heads,
|
| 149 |
+
qkv_bias=True,
|
| 150 |
+
qk_scale=None,
|
| 151 |
+
attn_drop=0.,
|
| 152 |
+
proj_drop=0.):
|
| 153 |
+
|
| 154 |
+
super().__init__()
|
| 155 |
+
self.dim = dim
|
| 156 |
+
self.window_size = window_size # Wh, Ww
|
| 157 |
+
self.num_heads = num_heads
|
| 158 |
+
head_dim = dim // num_heads
|
| 159 |
+
self.scale = qk_scale or head_dim**-0.5
|
| 160 |
+
|
| 161 |
+
# define a parameter table of relative position bias
|
| 162 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 163 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
|
| 164 |
+
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 165 |
+
|
| 166 |
+
# get pair-wise relative position index for each token inside the window
|
| 167 |
+
coords_h = torch.arange(self.window_size[0])
|
| 168 |
+
coords_w = torch.arange(self.window_size[1])
|
| 169 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 170 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 171 |
+
relative_coords = coords_flatten[:, :,
|
| 172 |
+
None] - coords_flatten[:,
|
| 173 |
+
None, :] # 2, Wh*Ww, Wh*Ww
|
| 174 |
+
relative_coords = relative_coords.permute(
|
| 175 |
+
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 176 |
+
relative_coords[:, :,
|
| 177 |
+
0] += self.window_size[0] - 1 # shift to start from 0
|
| 178 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 179 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
| 180 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 181 |
+
self.register_buffer("relative_position_index",
|
| 182 |
+
relative_position_index)
|
| 183 |
+
|
| 184 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 185 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 186 |
+
self.proj = nn.Linear(dim, dim)
|
| 187 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 188 |
+
|
| 189 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
| 190 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 191 |
+
|
| 192 |
+
def forward(self, x, mask=None):
|
| 193 |
+
""" Forward function.
|
| 194 |
+
Args:
|
| 195 |
+
x: input features with shape of (num_windows*B, N, C)
|
| 196 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
| 197 |
+
"""
|
| 198 |
+
B_, N, C = x.shape
|
| 199 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
|
| 200 |
+
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 201 |
+
q, k, v = qkv[0], qkv[1], qkv[
|
| 202 |
+
2] # make torchscript happy (cannot use tensor as tuple)
|
| 203 |
+
|
| 204 |
+
q = q * self.scale
|
| 205 |
+
attn = (q @ k.transpose(-2, -1))
|
| 206 |
+
|
| 207 |
+
relative_position_bias = self.relative_position_bias_table[
|
| 208 |
+
self.relative_position_index.view(-1)].view(
|
| 209 |
+
self.window_size[0] * self.window_size[1],
|
| 210 |
+
self.window_size[0] * self.window_size[1],
|
| 211 |
+
-1) # Wh*Ww,Wh*Ww,nH
|
| 212 |
+
relative_position_bias = relative_position_bias.permute(
|
| 213 |
+
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 214 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
| 215 |
+
|
| 216 |
+
if mask is not None:
|
| 217 |
+
nW = mask.shape[0]
|
| 218 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N,
|
| 219 |
+
N) + mask.unsqueeze(1).unsqueeze(0)
|
| 220 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
| 221 |
+
attn = self.softmax(attn)
|
| 222 |
+
else:
|
| 223 |
+
attn = self.softmax(attn)
|
| 224 |
+
|
| 225 |
+
attn = self.attn_drop(attn)
|
| 226 |
+
|
| 227 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
| 228 |
+
x = self.proj(x)
|
| 229 |
+
x = self.proj_drop(x)
|
| 230 |
+
return x
|
| 231 |
+
#+end_src
|
| 232 |
+
|
| 233 |
+
** AEMatter.class.py
|
| 234 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
|
| 235 |
+
class SwinTransformerBlock(nn.Module):
|
| 236 |
+
""" Swin Transformer Block.
|
| 237 |
+
Args:
|
| 238 |
+
dim (int): Number of input channels.
|
| 239 |
+
num_heads (int): Number of attention heads.
|
| 240 |
+
window_size (int): Window size.
|
| 241 |
+
shift_size (int): Shift size for SW-MSA.
|
| 242 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 243 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 244 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 245 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 246 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 247 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
| 248 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
| 249 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
def __init__(self,
|
| 253 |
+
dim,
|
| 254 |
+
num_heads,
|
| 255 |
+
window_size=7,
|
| 256 |
+
shift_size=0,
|
| 257 |
+
mlp_ratio=4.,
|
| 258 |
+
qkv_bias=True,
|
| 259 |
+
qk_scale=None,
|
| 260 |
+
drop=0.,
|
| 261 |
+
attn_drop=0.,
|
| 262 |
+
drop_path=0.,
|
| 263 |
+
act_layer=nn.GELU,
|
| 264 |
+
norm_layer=nn.LayerNorm):
|
| 265 |
+
super().__init__()
|
| 266 |
+
self.dim = dim
|
| 267 |
+
self.num_heads = num_heads
|
| 268 |
+
self.window_size = window_size
|
| 269 |
+
self.shift_size = shift_size
|
| 270 |
+
self.mlp_ratio = mlp_ratio
|
| 271 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
| 272 |
+
|
| 273 |
+
self.norm1 = norm_layer(dim)
|
| 274 |
+
self.attn = WindowAttention(dim,
|
| 275 |
+
window_size=to_2tuple(self.window_size),
|
| 276 |
+
num_heads=num_heads,
|
| 277 |
+
qkv_bias=qkv_bias,
|
| 278 |
+
qk_scale=qk_scale,
|
| 279 |
+
attn_drop=attn_drop,
|
| 280 |
+
proj_drop=drop)
|
| 281 |
+
|
| 282 |
+
self.drop_path = DropPath(
|
| 283 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
| 284 |
+
self.norm2 = norm_layer(dim)
|
| 285 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 286 |
+
self.mlp = Mlp(in_features=dim,
|
| 287 |
+
hidden_features=mlp_hidden_dim,
|
| 288 |
+
act_layer=act_layer,
|
| 289 |
+
drop=drop)
|
| 290 |
+
|
| 291 |
+
self.H = None
|
| 292 |
+
self.W = None
|
| 293 |
+
|
| 294 |
+
def forward(self, x, mask_matrix):
|
| 295 |
+
""" Forward function.
|
| 296 |
+
Args:
|
| 297 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 298 |
+
H, W: Spatial resolution of the input feature.
|
| 299 |
+
mask_matrix: Attention mask for cyclic shift.
|
| 300 |
+
"""
|
| 301 |
+
B, L, C = x.shape
|
| 302 |
+
H, W = self.H, self.W
|
| 303 |
+
assert L == H * W, "input feature has wrong size"
|
| 304 |
+
|
| 305 |
+
shortcut = x
|
| 306 |
+
x = self.norm1(x)
|
| 307 |
+
x = x.view(B, H, W, C)
|
| 308 |
+
|
| 309 |
+
# pad feature maps to multiples of window size
|
| 310 |
+
pad_l = pad_t = 0
|
| 311 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
| 312 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
| 313 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
| 314 |
+
_, Hp, Wp, _ = x.shape
|
| 315 |
+
|
| 316 |
+
# cyclic shift
|
| 317 |
+
if self.shift_size > 0:
|
| 318 |
+
shifted_x = torch.roll(x,
|
| 319 |
+
shifts=(-self.shift_size, -self.shift_size),
|
| 320 |
+
dims=(1, 2))
|
| 321 |
+
attn_mask = mask_matrix
|
| 322 |
+
else:
|
| 323 |
+
shifted_x = x
|
| 324 |
+
attn_mask = None
|
| 325 |
+
|
| 326 |
+
# partition windows
|
| 327 |
+
x_windows = window_partition(
|
| 328 |
+
shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
| 329 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size,
|
| 330 |
+
C) # nW*B, window_size*window_size, C
|
| 331 |
+
|
| 332 |
+
# W-MSA/SW-MSA
|
| 333 |
+
attn_windows = self.attn(
|
| 334 |
+
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
| 335 |
+
|
| 336 |
+
# merge windows
|
| 337 |
+
attn_windows = attn_windows.view(-1, self.window_size,
|
| 338 |
+
self.window_size, C)
|
| 339 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp,
|
| 340 |
+
Wp) # B H' W' C
|
| 341 |
+
|
| 342 |
+
# reverse cyclic shift
|
| 343 |
+
if self.shift_size > 0:
|
| 344 |
+
x = torch.roll(shifted_x,
|
| 345 |
+
shifts=(self.shift_size, self.shift_size),
|
| 346 |
+
dims=(1, 2))
|
| 347 |
+
else:
|
| 348 |
+
x = shifted_x
|
| 349 |
+
|
| 350 |
+
if pad_r > 0 or pad_b > 0:
|
| 351 |
+
x = x[:, :H, :W, :].contiguous()
|
| 352 |
+
|
| 353 |
+
x = x.view(B, H * W, C)
|
| 354 |
+
|
| 355 |
+
# FFN
|
| 356 |
+
x = shortcut + self.drop_path(x)
|
| 357 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 358 |
+
|
| 359 |
+
return x
|
| 360 |
+
#+end_src
|
| 361 |
+
|
| 362 |
+
** AEMatter.class.py
|
| 363 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
|
| 364 |
+
class PatchMerging(nn.Module):
|
| 365 |
+
""" Patch Merging Layer
|
| 366 |
+
Args:
|
| 367 |
+
dim (int): Number of input channels.
|
| 368 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 369 |
+
"""
|
| 370 |
+
|
| 371 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
| 372 |
+
super().__init__()
|
| 373 |
+
self.dim = dim
|
| 374 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
| 375 |
+
self.norm = norm_layer(4 * dim)
|
| 376 |
+
|
| 377 |
+
def forward(self, x, H, W):
|
| 378 |
+
""" Forward function.
|
| 379 |
+
Args:
|
| 380 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 381 |
+
H, W: Spatial resolution of the input feature.
|
| 382 |
+
"""
|
| 383 |
+
B, L, C = x.shape
|
| 384 |
+
assert L == H * W, "input feature has wrong size"
|
| 385 |
+
|
| 386 |
+
x = x.view(B, H, W, C)
|
| 387 |
+
|
| 388 |
+
# padding
|
| 389 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
| 390 |
+
if pad_input:
|
| 391 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
| 392 |
+
|
| 393 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
| 394 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
| 395 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
| 396 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
| 397 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
| 398 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
| 399 |
+
|
| 400 |
+
x = self.norm(x)
|
| 401 |
+
x = self.reduction(x)
|
| 402 |
+
|
| 403 |
+
return x
|
| 404 |
+
#+end_src
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
** AEMatter.class.py
|
| 408 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
|
| 409 |
+
class BasicLayer(nn.Module):
|
| 410 |
+
""" A basic Swin Transformer layer for one stage.
|
| 411 |
+
Args:
|
| 412 |
+
dim (int): Number of feature channels
|
| 413 |
+
depth (int): Depths of this stage.
|
| 414 |
+
num_heads (int): Number of attention head.
|
| 415 |
+
window_size (int): Local window size. Default: 7.
|
| 416 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 417 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 418 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 419 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 420 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 421 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
| 422 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 423 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
| 424 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 425 |
+
"""
|
| 426 |
+
|
| 427 |
+
def __init__(self,
|
| 428 |
+
dim,
|
| 429 |
+
depth,
|
| 430 |
+
num_heads,
|
| 431 |
+
window_size=7,
|
| 432 |
+
mlp_ratio=4.,
|
| 433 |
+
qkv_bias=True,
|
| 434 |
+
qk_scale=None,
|
| 435 |
+
drop=0.,
|
| 436 |
+
attn_drop=0.,
|
| 437 |
+
drop_path=0.,
|
| 438 |
+
norm_layer=nn.LayerNorm,
|
| 439 |
+
downsample=None,
|
| 440 |
+
use_checkpoint=False):
|
| 441 |
+
|
| 442 |
+
super().__init__()
|
| 443 |
+
self.window_size = window_size
|
| 444 |
+
self.shift_size = window_size // 2
|
| 445 |
+
self.depth = depth
|
| 446 |
+
self.use_checkpoint = use_checkpoint
|
| 447 |
+
|
| 448 |
+
# build blocks
|
| 449 |
+
self.blocks = nn.ModuleList([
|
| 450 |
+
SwinTransformerBlock(dim=dim,
|
| 451 |
+
num_heads=num_heads,
|
| 452 |
+
window_size=window_size,
|
| 453 |
+
shift_size=0 if
|
| 454 |
+
(i % 2 == 0) else window_size // 2,
|
| 455 |
+
mlp_ratio=mlp_ratio,
|
| 456 |
+
qkv_bias=qkv_bias,
|
| 457 |
+
qk_scale=qk_scale,
|
| 458 |
+
drop=drop,
|
| 459 |
+
attn_drop=attn_drop,
|
| 460 |
+
drop_path=drop_path[i] if isinstance(
|
| 461 |
+
drop_path, list) else drop_path,
|
| 462 |
+
norm_layer=norm_layer) for i in range(depth)
|
| 463 |
+
])
|
| 464 |
+
|
| 465 |
+
# patch merging layer
|
| 466 |
+
if downsample is not None:
|
| 467 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
| 468 |
+
else:
|
| 469 |
+
self.downsample = None
|
| 470 |
+
|
| 471 |
+
def forward(self, x, H, W):
|
| 472 |
+
""" Forward function.
|
| 473 |
+
Args:
|
| 474 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 475 |
+
H, W: Spatial resolution of the input feature.
|
| 476 |
+
"""
|
| 477 |
+
# print(x.shape,H,W)
|
| 478 |
+
# calculate attention mask for SW-MSA
|
| 479 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
| 480 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
| 481 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
| 482 |
+
h_slices = (slice(0, -self.window_size),
|
| 483 |
+
slice(-self.window_size,
|
| 484 |
+
-self.shift_size), slice(-self.shift_size, None))
|
| 485 |
+
w_slices = (slice(0, -self.window_size),
|
| 486 |
+
slice(-self.window_size,
|
| 487 |
+
-self.shift_size), slice(-self.shift_size, None))
|
| 488 |
+
cnt = 0
|
| 489 |
+
for h in h_slices:
|
| 490 |
+
for w in w_slices:
|
| 491 |
+
img_mask[:, h, w, :] = cnt
|
| 492 |
+
cnt += 1
|
| 493 |
+
|
| 494 |
+
mask_windows = window_partition(
|
| 495 |
+
img_mask, self.window_size) # nW, window_size, window_size, 1
|
| 496 |
+
|
| 497 |
+
mask_windows = mask_windows.view(-1,
|
| 498 |
+
self.window_size * self.window_size)
|
| 499 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(
|
| 500 |
+
2) # nW, ww window_size*window_size
|
| 501 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
| 502 |
+
float(-100.0)).masked_fill(
|
| 503 |
+
attn_mask == 0, float(0.0))
|
| 504 |
+
|
| 505 |
+
for blk in self.blocks:
|
| 506 |
+
blk.H, blk.W = H, W
|
| 507 |
+
if self.use_checkpoint:
|
| 508 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
| 509 |
+
else:
|
| 510 |
+
x = blk(x, attn_mask)
|
| 511 |
+
|
| 512 |
+
if self.downsample is not None:
|
| 513 |
+
x_down = self.downsample(x, H, W)
|
| 514 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
| 515 |
+
return x, H, W, x_down, Wh, Ww
|
| 516 |
+
else:
|
| 517 |
+
return x, H, W, x, H, W
|
| 518 |
+
#+end_src
|
| 519 |
+
|
| 520 |
+
** AEMatter.class.py
|
| 521 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
|
| 522 |
+
class PatchEmbed(nn.Module):
|
| 523 |
+
""" Image to Patch Embedding
|
| 524 |
+
Args:
|
| 525 |
+
patch_size (int): Patch token size. Default: 4.
|
| 526 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 527 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 528 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
| 529 |
+
"""
|
| 530 |
+
|
| 531 |
+
def __init__(self,
|
| 532 |
+
patch_size=4,
|
| 533 |
+
in_chans=3,
|
| 534 |
+
embed_dim=96,
|
| 535 |
+
norm_layer=None):
|
| 536 |
+
|
| 537 |
+
super().__init__()
|
| 538 |
+
patch_size = to_2tuple(patch_size)
|
| 539 |
+
self.patch_size = patch_size
|
| 540 |
+
|
| 541 |
+
self.in_chans = in_chans
|
| 542 |
+
self.embed_dim = embed_dim
|
| 543 |
+
|
| 544 |
+
self.proj = nn.Conv2d(in_chans,
|
| 545 |
+
embed_dim,
|
| 546 |
+
kernel_size=patch_size,
|
| 547 |
+
stride=patch_size)
|
| 548 |
+
if norm_layer is not None:
|
| 549 |
+
self.norm = norm_layer(embed_dim)
|
| 550 |
+
else:
|
| 551 |
+
self.norm = None
|
| 552 |
+
|
| 553 |
+
def forward(self, x):
|
| 554 |
+
"""Forward function."""
|
| 555 |
+
# padding
|
| 556 |
+
_, _, H, W = x.size()
|
| 557 |
+
if W % self.patch_size[1] != 0:
|
| 558 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
| 559 |
+
if H % self.patch_size[0] != 0:
|
| 560 |
+
x = F.pad(x,
|
| 561 |
+
(0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
| 562 |
+
|
| 563 |
+
x = self.proj(x) # B C Wh Ww
|
| 564 |
+
if self.norm is not None:
|
| 565 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 566 |
+
x = x.flatten(2).transpose(1, 2)
|
| 567 |
+
x = self.norm(x)
|
| 568 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
| 569 |
+
|
| 570 |
+
return x
|
| 571 |
+
#+end_src
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
** AEMatter.class.py
|
| 575 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
|
| 576 |
+
class SwinTransformer(nn.Module):
|
| 577 |
+
""" Swin Transformer backbone.
|
| 578 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
| 579 |
+
https://arxiv.org/pdf/2103.14030
|
| 580 |
+
Args:
|
| 581 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
| 582 |
+
used in absolute postion embedding. Default 224.
|
| 583 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
| 584 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 585 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 586 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
| 587 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
| 588 |
+
window_size (int): Window size. Default: 7.
|
| 589 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 590 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
| 591 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
| 592 |
+
drop_rate (float): Dropout rate.
|
| 593 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
| 594 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
| 595 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
| 596 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
| 597 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
| 598 |
+
out_indices (Sequence[int]): Output from which stages.
|
| 599 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
| 600 |
+
-1 means not freezing any parameters.
|
| 601 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 602 |
+
"""
|
| 603 |
+
|
| 604 |
+
def __init__(self,
|
| 605 |
+
pretrain_img_size=224,
|
| 606 |
+
patch_size=4,
|
| 607 |
+
in_chans=3,
|
| 608 |
+
embed_dim=96,
|
| 609 |
+
depths=[2, 2, 6, 2],
|
| 610 |
+
num_heads=[3, 6, 12, 24],
|
| 611 |
+
window_size=7,
|
| 612 |
+
mlp_ratio=4.,
|
| 613 |
+
qkv_bias=True,
|
| 614 |
+
qk_scale=None,
|
| 615 |
+
drop_rate=0.,
|
| 616 |
+
attn_drop_rate=0.,
|
| 617 |
+
drop_path_rate=0.2,
|
| 618 |
+
norm_layer=nn.LayerNorm,
|
| 619 |
+
ape=False,
|
| 620 |
+
patch_norm=True,
|
| 621 |
+
out_indices=(0, 1, 2, 3),
|
| 622 |
+
frozen_stages=-1,
|
| 623 |
+
use_checkpoint=False):
|
| 624 |
+
|
| 625 |
+
super().__init__()
|
| 626 |
+
|
| 627 |
+
self.pretrain_img_size = pretrain_img_size
|
| 628 |
+
self.num_layers = len(depths)
|
| 629 |
+
self.embed_dim = embed_dim
|
| 630 |
+
self.ape = ape
|
| 631 |
+
self.patch_norm = patch_norm
|
| 632 |
+
self.out_indices = out_indices
|
| 633 |
+
self.frozen_stages = frozen_stages
|
| 634 |
+
|
| 635 |
+
# split image into non-overlapping patches
|
| 636 |
+
self.patch_embed = PatchEmbed(
|
| 637 |
+
patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
| 638 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
| 639 |
+
|
| 640 |
+
# absolute position embedding
|
| 641 |
+
if self.ape:
|
| 642 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
| 643 |
+
patch_size = to_2tuple(patch_size)
|
| 644 |
+
patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
|
| 645 |
+
|
| 646 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
|
| 647 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
| 648 |
+
|
| 649 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 650 |
+
|
| 651 |
+
# stochastic depth
|
| 652 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
| 653 |
+
|
| 654 |
+
# build layers
|
| 655 |
+
self.layers = nn.ModuleList()
|
| 656 |
+
for i_layer in range(self.num_layers):
|
| 657 |
+
layer = BasicLayer(
|
| 658 |
+
dim=int(embed_dim * 2 ** i_layer),
|
| 659 |
+
depth=depths[i_layer],
|
| 660 |
+
num_heads=num_heads[i_layer],
|
| 661 |
+
window_size=window_size,
|
| 662 |
+
mlp_ratio=mlp_ratio,
|
| 663 |
+
qkv_bias=qkv_bias,
|
| 664 |
+
qk_scale=qk_scale,
|
| 665 |
+
drop=drop_rate,
|
| 666 |
+
attn_drop=attn_drop_rate,
|
| 667 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
| 668 |
+
norm_layer=norm_layer,
|
| 669 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
| 670 |
+
use_checkpoint=use_checkpoint)
|
| 671 |
+
self.layers.append(layer)
|
| 672 |
+
|
| 673 |
+
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
| 674 |
+
self.num_features = num_features
|
| 675 |
+
|
| 676 |
+
# add a norm layer for each output
|
| 677 |
+
for i_layer in out_indices:
|
| 678 |
+
layer = norm_layer(num_features[i_layer])
|
| 679 |
+
layer_name = f'norm{i_layer}'
|
| 680 |
+
self.add_module(layer_name, layer)
|
| 681 |
+
|
| 682 |
+
self._freeze_stages()
|
| 683 |
+
|
| 684 |
+
def _freeze_stages(self):
|
| 685 |
+
if self.frozen_stages >= 0:
|
| 686 |
+
self.patch_embed.eval()
|
| 687 |
+
for param in self.patch_embed.parameters():
|
| 688 |
+
param.requires_grad = False
|
| 689 |
+
|
| 690 |
+
if self.frozen_stages >= 1 and self.ape:
|
| 691 |
+
self.absolute_pos_embed.requires_grad = False
|
| 692 |
+
|
| 693 |
+
if self.frozen_stages >= 2:
|
| 694 |
+
self.pos_drop.eval()
|
| 695 |
+
for i in range(0, self.frozen_stages - 1):
|
| 696 |
+
m = self.layers[i]
|
| 697 |
+
m.eval()
|
| 698 |
+
for param in m.parameters():
|
| 699 |
+
param.requires_grad = False
|
| 700 |
+
|
| 701 |
+
def init_weights(self, pretrained=None):
|
| 702 |
+
"""Initialize the weights in backbone.
|
| 703 |
+
Args:
|
| 704 |
+
pretrained (str, optional): Path to pre-trained weights.
|
| 705 |
+
Defaults to None.
|
| 706 |
+
"""
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
def forward(self, x):
|
| 710 |
+
"""Forward function."""
|
| 711 |
+
x = self.patch_embed(x)
|
| 712 |
+
|
| 713 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 714 |
+
if self.ape:
|
| 715 |
+
# interpolate the position embedding to the corresponding size
|
| 716 |
+
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
| 717 |
+
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
|
| 718 |
+
else:
|
| 719 |
+
x = x.flatten(2).transpose(1, 2)
|
| 720 |
+
x = self.pos_drop(x)
|
| 721 |
+
|
| 722 |
+
outs = []
|
| 723 |
+
for i in range(self.num_layers):
|
| 724 |
+
layer = self.layers[i]
|
| 725 |
+
|
| 726 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
| 727 |
+
|
| 728 |
+
if i in self.out_indices:
|
| 729 |
+
norm_layer = getattr(self, f'norm{i}')
|
| 730 |
+
x_out = norm_layer(x_out)
|
| 731 |
+
|
| 732 |
+
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
| 733 |
+
outs.append(out)
|
| 734 |
+
|
| 735 |
+
return tuple(outs)
|
| 736 |
+
|
| 737 |
+
def train(self, mode=True):
|
| 738 |
+
"""Convert the model into training mode while keep layers freezed."""
|
| 739 |
+
super(SwinTransformer, self).train(mode)
|
| 740 |
+
self._freeze_stages()
|
| 741 |
+
#+end_src
|
| 742 |
+
|
| 743 |
+
** AEMatter.class.py
|
| 744 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
|
| 745 |
+
class Mlp(nn.Module):
|
| 746 |
+
""" Multilayer perceptron."""
|
| 747 |
+
|
| 748 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 749 |
+
super().__init__()
|
| 750 |
+
out_features = out_features or in_features
|
| 751 |
+
hidden_features = hidden_features or in_features
|
| 752 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 753 |
+
self.act = act_layer()
|
| 754 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 755 |
+
self.drop = nn.Dropout(drop)
|
| 756 |
+
|
| 757 |
+
def forward(self, x):
|
| 758 |
+
x = self.fc1(x)
|
| 759 |
+
x = self.act(x)
|
| 760 |
+
x = self.drop(x)
|
| 761 |
+
x = self.fc2(x)
|
| 762 |
+
x = self.drop(x)
|
| 763 |
+
return x
|
| 764 |
+
#+end_src
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
** AEMatter.class.py
|
| 768 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
|
| 769 |
+
class ResBlock(nn.Module):
|
| 770 |
+
|
| 771 |
+
def __init__(self, inc, midc):
|
| 772 |
+
super(ResBlock, self).__init__()
|
| 773 |
+
self.conv1 = nn.Conv2d(inc,
|
| 774 |
+
midc,
|
| 775 |
+
kernel_size=1,
|
| 776 |
+
stride=1,
|
| 777 |
+
padding=0,
|
| 778 |
+
bias=True)
|
| 779 |
+
self.gn1 = nn.GroupNorm(16, midc)
|
| 780 |
+
self.conv2 = nn.Conv2d(midc,
|
| 781 |
+
midc,
|
| 782 |
+
kernel_size=3,
|
| 783 |
+
stride=1,
|
| 784 |
+
padding=1,
|
| 785 |
+
bias=True)
|
| 786 |
+
self.gn2 = nn.GroupNorm(16, midc)
|
| 787 |
+
self.conv3 = nn.Conv2d(midc,
|
| 788 |
+
inc,
|
| 789 |
+
kernel_size=1,
|
| 790 |
+
stride=1,
|
| 791 |
+
padding=0,
|
| 792 |
+
bias=True)
|
| 793 |
+
self.relu = nn.LeakyReLU(0.1)
|
| 794 |
+
|
| 795 |
+
def forward(self, x):
|
| 796 |
+
x_ = x
|
| 797 |
+
x = self.conv1(x)
|
| 798 |
+
x = self.gn1(x)
|
| 799 |
+
x = self.relu(x)
|
| 800 |
+
x = self.conv2(x)
|
| 801 |
+
x = self.gn2(x)
|
| 802 |
+
x = self.relu(x)
|
| 803 |
+
x = self.conv3(x)
|
| 804 |
+
x = x + x_
|
| 805 |
+
x = self.relu(x)
|
| 806 |
+
return x
|
| 807 |
+
#+end_src
|
| 808 |
+
|
| 809 |
+
** AEMatter.class.py
|
| 810 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
|
| 811 |
+
class AEALblock(nn.Module):
|
| 812 |
+
|
| 813 |
+
def __init__(self,
|
| 814 |
+
d_model,
|
| 815 |
+
nhead,
|
| 816 |
+
dim_feedforward=512,
|
| 817 |
+
dropout=0.0,
|
| 818 |
+
layer_norm_eps=1e-5,
|
| 819 |
+
batch_first=True,
|
| 820 |
+
norm_first=False,
|
| 821 |
+
width=5):
|
| 822 |
+
super(AEALblock, self).__init__()
|
| 823 |
+
self.self_attn2 = nn.MultiheadAttention(d_model // 2,
|
| 824 |
+
nhead // 2,
|
| 825 |
+
dropout=dropout,
|
| 826 |
+
batch_first=batch_first)
|
| 827 |
+
self.self_attn1 = nn.MultiheadAttention(d_model // 2,
|
| 828 |
+
nhead // 2,
|
| 829 |
+
dropout=dropout,
|
| 830 |
+
batch_first=batch_first)
|
| 831 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 832 |
+
self.dropout = nn.Dropout(dropout)
|
| 833 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 834 |
+
self.norm_first = norm_first
|
| 835 |
+
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
| 836 |
+
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
| 837 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 838 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 839 |
+
self.activation = nn.ReLU()
|
| 840 |
+
self.width = width
|
| 841 |
+
self.trans = nn.Sequential(
|
| 842 |
+
nn.Conv2d(d_model + 512, d_model // 2, 1, 1, 0),
|
| 843 |
+
ResBlock(d_model // 2, d_model // 4),
|
| 844 |
+
nn.Conv2d(d_model // 2, d_model, 1, 1, 0))
|
| 845 |
+
self.gamma = nn.Parameter(torch.zeros(1))
|
| 846 |
+
|
| 847 |
+
def forward(
|
| 848 |
+
self,
|
| 849 |
+
src,
|
| 850 |
+
feats,
|
| 851 |
+
):
|
| 852 |
+
src = self.gamma * self.trans(torch.cat([src, feats], 1)) + src
|
| 853 |
+
b, c, h, w = src.shape
|
| 854 |
+
x1 = src[:, 0:c // 2]
|
| 855 |
+
x1_ = rearrange(x1, 'b c (h1 h2) w -> b c h1 h2 w', h2=self.width)
|
| 856 |
+
x1_ = rearrange(x1_, 'b c h1 h2 w -> (b h1) (h2 w) c')
|
| 857 |
+
x2 = src[:, c // 2:]
|
| 858 |
+
x2_ = rearrange(x2, 'b c h (w1 w2) -> b c h w1 w2', w2=self.width)
|
| 859 |
+
x2_ = rearrange(x2_, 'b c h w1 w2 -> (b w1) (h w2) c')
|
| 860 |
+
x = rearrange(src, 'b c h w-> b (h w) c')
|
| 861 |
+
x = self.norm1(x + self._sa_block(x1_, x2_, h, w))
|
| 862 |
+
x = self.norm2(x + self._ff_block(x))
|
| 863 |
+
x = rearrange(x, 'b (h w) c->b c h w', h=h, w=w)
|
| 864 |
+
return x
|
| 865 |
+
|
| 866 |
+
def _sa_block(self, x1, x2, h, w):
|
| 867 |
+
x1 = self.self_attn1(x1,
|
| 868 |
+
x1,
|
| 869 |
+
x1,
|
| 870 |
+
attn_mask=None,
|
| 871 |
+
key_padding_mask=None,
|
| 872 |
+
need_weights=False)[0]
|
| 873 |
+
|
| 874 |
+
x2 = self.self_attn2(x2,
|
| 875 |
+
x2,
|
| 876 |
+
x2,
|
| 877 |
+
attn_mask=None,
|
| 878 |
+
key_padding_mask=None,
|
| 879 |
+
need_weights=False)[0]
|
| 880 |
+
|
| 881 |
+
x1 = rearrange(x1,
|
| 882 |
+
'(b h1) (h2 w) c-> b (h1 h2 w) c',
|
| 883 |
+
h2=self.width,
|
| 884 |
+
h1=h // self.width)
|
| 885 |
+
x2 = rearrange(x2,
|
| 886 |
+
' (b w1) (h w2) c-> b (h w1 w2) c',
|
| 887 |
+
w2=self.width,
|
| 888 |
+
w1=w // self.width)
|
| 889 |
+
x = torch.cat([x1, x2], dim=2)
|
| 890 |
+
return self.dropout1(x)
|
| 891 |
+
|
| 892 |
+
def _ff_block(self, x):
|
| 893 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
| 894 |
+
return self.dropout2(x)
|
| 895 |
+
#+end_src
|
| 896 |
+
|
| 897 |
+
** AEMatter.class.py
|
| 898 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
|
| 899 |
+
class AEMatter(nn.Module):
|
| 900 |
+
|
| 901 |
+
def __init__(self):
|
| 902 |
+
super(AEMatter, self).__init__()
|
| 903 |
+
trans = SwinTransformer(pretrain_img_size=224,
|
| 904 |
+
embed_dim=96,
|
| 905 |
+
depths=[2, 2, 6, 2],
|
| 906 |
+
num_heads=[3, 6, 12, 24],
|
| 907 |
+
window_size=7,
|
| 908 |
+
ape=False,
|
| 909 |
+
drop_path_rate=0.2,
|
| 910 |
+
patch_norm=True,
|
| 911 |
+
use_checkpoint=False)
|
| 912 |
+
|
| 913 |
+
# trans.load_state_dict(torch.load(
|
| 914 |
+
# '/home/asd/Desktop/swin_tiny_patch4_window7_224.pth',
|
| 915 |
+
# map_location="cpu")["model"],
|
| 916 |
+
# strict=False)
|
| 917 |
+
|
| 918 |
+
trans.patch_embed.proj = nn.Conv2d(64, 96, 3, 2, 1)
|
| 919 |
+
|
| 920 |
+
self.start_conv0 = nn.Sequential(nn.Conv2d(6, 48, 3, 1, 1),
|
| 921 |
+
nn.PReLU(48))
|
| 922 |
+
|
| 923 |
+
self.start_conv = nn.Sequential(nn.Conv2d(48, 64, 3, 2,
|
| 924 |
+
1), nn.PReLU(64),
|
| 925 |
+
nn.Conv2d(64, 64, 3, 1, 1),
|
| 926 |
+
nn.PReLU(64))
|
| 927 |
+
|
| 928 |
+
self.trans = trans
|
| 929 |
+
self.conv1 = nn.Sequential(
|
| 930 |
+
nn.Conv2d(in_channels=640 + 768,
|
| 931 |
+
out_channels=256,
|
| 932 |
+
kernel_size=1,
|
| 933 |
+
stride=1,
|
| 934 |
+
padding=0,
|
| 935 |
+
bias=True))
|
| 936 |
+
self.conv2 = nn.Sequential(
|
| 937 |
+
nn.Conv2d(in_channels=256 + 384,
|
| 938 |
+
out_channels=256,
|
| 939 |
+
kernel_size=1,
|
| 940 |
+
stride=1,
|
| 941 |
+
padding=0,
|
| 942 |
+
bias=True), )
|
| 943 |
+
self.conv3 = nn.Sequential(
|
| 944 |
+
nn.Conv2d(in_channels=256 + 192,
|
| 945 |
+
out_channels=192,
|
| 946 |
+
kernel_size=1,
|
| 947 |
+
stride=1,
|
| 948 |
+
padding=0,
|
| 949 |
+
bias=True), )
|
| 950 |
+
self.conv4 = nn.Sequential(
|
| 951 |
+
nn.Conv2d(in_channels=192 + 96,
|
| 952 |
+
out_channels=128,
|
| 953 |
+
kernel_size=1,
|
| 954 |
+
stride=1,
|
| 955 |
+
padding=0,
|
| 956 |
+
bias=True), )
|
| 957 |
+
self.ctran0 = BasicLayer(256, 3, 8, 7, drop_path=0.09)
|
| 958 |
+
self.ctran1 = BasicLayer(256, 3, 8, 7, drop_path=0.07)
|
| 959 |
+
self.ctran2 = BasicLayer(192, 3, 6, 7, drop_path=0.05)
|
| 960 |
+
self.ctran3 = BasicLayer(128, 3, 4, 7, drop_path=0.03)
|
| 961 |
+
self.conv5 = nn.Sequential(
|
| 962 |
+
nn.Conv2d(in_channels=192,
|
| 963 |
+
out_channels=64,
|
| 964 |
+
kernel_size=3,
|
| 965 |
+
stride=1,
|
| 966 |
+
padding=1,
|
| 967 |
+
bias=True), nn.PReLU(64),
|
| 968 |
+
nn.Conv2d(in_channels=64,
|
| 969 |
+
out_channels=64,
|
| 970 |
+
kernel_size=3,
|
| 971 |
+
stride=1,
|
| 972 |
+
padding=1,
|
| 973 |
+
bias=True), nn.PReLU(64),
|
| 974 |
+
nn.Conv2d(in_channels=64,
|
| 975 |
+
out_channels=48,
|
| 976 |
+
kernel_size=3,
|
| 977 |
+
stride=1,
|
| 978 |
+
padding=1,
|
| 979 |
+
bias=True), nn.PReLU(48))
|
| 980 |
+
self.convo = nn.Sequential(
|
| 981 |
+
nn.Conv2d(in_channels=48 + 48 + 6,
|
| 982 |
+
out_channels=32,
|
| 983 |
+
kernel_size=3,
|
| 984 |
+
stride=1,
|
| 985 |
+
padding=1,
|
| 986 |
+
bias=True), nn.PReLU(32),
|
| 987 |
+
nn.Conv2d(in_channels=32,
|
| 988 |
+
out_channels=32,
|
| 989 |
+
kernel_size=3,
|
| 990 |
+
stride=1,
|
| 991 |
+
padding=1,
|
| 992 |
+
bias=True), nn.PReLU(32),
|
| 993 |
+
nn.Conv2d(in_channels=32,
|
| 994 |
+
out_channels=1,
|
| 995 |
+
kernel_size=3,
|
| 996 |
+
stride=1,
|
| 997 |
+
padding=1,
|
| 998 |
+
bias=True))
|
| 999 |
+
self.up = nn.Upsample(scale_factor=2,
|
| 1000 |
+
mode='bilinear',
|
| 1001 |
+
align_corners=False)
|
| 1002 |
+
self.upn = nn.Upsample(scale_factor=2, mode='nearest')
|
| 1003 |
+
self.apptrans = nn.Sequential(
|
| 1004 |
+
nn.Conv2d(256 + 384, 256, 1, 1, bias=True), ResBlock(256, 128),
|
| 1005 |
+
ResBlock(256, 128), nn.Conv2d(256, 512, 2, 2, bias=True),
|
| 1006 |
+
ResBlock(512, 128))
|
| 1007 |
+
self.emb = nn.Sequential(nn.Conv2d(768, 640, 1, 1, 0),
|
| 1008 |
+
ResBlock(640, 160))
|
| 1009 |
+
self.embdp = nn.Sequential(nn.Conv2d(640, 640, 1, 1, 0))
|
| 1010 |
+
self.h2l = nn.Conv2d(768, 256, 1, 1, 0)
|
| 1011 |
+
self.width = 5
|
| 1012 |
+
self.trans1 = AEALblock(d_model=640,
|
| 1013 |
+
nhead=20,
|
| 1014 |
+
dim_feedforward=2048,
|
| 1015 |
+
dropout=0.2,
|
| 1016 |
+
width=self.width)
|
| 1017 |
+
self.trans2 = AEALblock(d_model=640,
|
| 1018 |
+
nhead=20,
|
| 1019 |
+
dim_feedforward=2048,
|
| 1020 |
+
dropout=0.2,
|
| 1021 |
+
width=self.width)
|
| 1022 |
+
self.trans3 = AEALblock(d_model=640,
|
| 1023 |
+
nhead=20,
|
| 1024 |
+
dim_feedforward=2048,
|
| 1025 |
+
dropout=0.2,
|
| 1026 |
+
width=self.width)
|
| 1027 |
+
|
| 1028 |
+
def aeal(self, x, sem):
|
| 1029 |
+
xe = self.emb(x)
|
| 1030 |
+
x_ = xe
|
| 1031 |
+
x_ = self.embdp(x_)
|
| 1032 |
+
b, c, h1, w1 = x_.shape
|
| 1033 |
+
bnew_ph = int(np.ceil(h1 / self.width) * self.width) - h1
|
| 1034 |
+
bnew_pw = int(np.ceil(w1 / self.width) * self.width) - w1
|
| 1035 |
+
newph1 = bnew_ph // 2
|
| 1036 |
+
newph2 = bnew_ph - newph1
|
| 1037 |
+
newpw1 = bnew_pw // 2
|
| 1038 |
+
newpw2 = bnew_pw - newpw1
|
| 1039 |
+
x_ = F.pad(x_, (newpw1, newpw2, newph1, newph2))
|
| 1040 |
+
sem = F.pad(sem, (newpw1, newpw2, newph1, newph2))
|
| 1041 |
+
x_ = self.trans1(x_, sem)
|
| 1042 |
+
x_ = self.trans2(x_, sem)
|
| 1043 |
+
x_ = self.trans3(x_, sem)
|
| 1044 |
+
x_ = x_[:, :, newph1:h1 + newph1, newpw1:w1 + newpw1]
|
| 1045 |
+
return x_
|
| 1046 |
+
|
| 1047 |
+
def forward(self, x, y):
|
| 1048 |
+
inputs = torch.cat((x, y), 1)
|
| 1049 |
+
x = self.start_conv0(inputs)
|
| 1050 |
+
x_ = self.start_conv(x)
|
| 1051 |
+
x1, x2, x3, x4 = self.trans(x_)
|
| 1052 |
+
x4h = self.h2l(x4)
|
| 1053 |
+
x3s = self.apptrans(torch.cat([x3, self.upn(x4h)], 1))
|
| 1054 |
+
x4_ = self.aeal(x4, x3s)
|
| 1055 |
+
x4 = torch.cat((x4, x4_), 1)
|
| 1056 |
+
X4 = self.conv1(x4)
|
| 1057 |
+
wh, ww = X4.shape[2], X4.shape[3]
|
| 1058 |
+
X4 = rearrange(X4, 'b c h w -> b (h w) c')
|
| 1059 |
+
X4, _, _, _, _, _ = self.ctran0(X4, wh, ww)
|
| 1060 |
+
X4 = rearrange(X4, 'b (h w) c -> b c h w', h=wh, w=ww)
|
| 1061 |
+
X3 = self.up(X4)
|
| 1062 |
+
X3 = torch.cat((x3, X3), 1)
|
| 1063 |
+
X3 = self.conv2(X3)
|
| 1064 |
+
wh, ww = X3.shape[2], X3.shape[3]
|
| 1065 |
+
X3 = rearrange(X3, 'b c h w -> b (h w) c')
|
| 1066 |
+
X3, _, _, _, _, _ = self.ctran1(X3, wh, ww)
|
| 1067 |
+
X3 = rearrange(X3, 'b (h w) c -> b c h w', h=wh, w=ww)
|
| 1068 |
+
X2 = self.up(X3)
|
| 1069 |
+
X2 = torch.cat((x2, X2), 1)
|
| 1070 |
+
X2 = self.conv3(X2)
|
| 1071 |
+
wh, ww = X2.shape[2], X2.shape[3]
|
| 1072 |
+
X2 = rearrange(X2, 'b c h w -> b (h w) c')
|
| 1073 |
+
X2, _, _, _, _, _ = self.ctran2(X2, wh, ww)
|
| 1074 |
+
X2 = rearrange(X2, 'b (h w) c -> b c h w', h=wh, w=ww)
|
| 1075 |
+
X1 = self.up(X2)
|
| 1076 |
+
X1 = torch.cat((x1, X1), 1)
|
| 1077 |
+
X1 = self.conv4(X1)
|
| 1078 |
+
wh, ww = X1.shape[2], X1.shape[3]
|
| 1079 |
+
X1 = rearrange(X1, 'b c h w -> b (h w) c')
|
| 1080 |
+
X1, _, _, _, _, _ = self.ctran3(X1, wh, ww)
|
| 1081 |
+
X1 = rearrange(X1, 'b (h w) c -> b c h w', h=wh, w=ww)
|
| 1082 |
+
X0 = self.up(X1)
|
| 1083 |
+
X0 = torch.cat((x_, X0), 1)
|
| 1084 |
+
X0 = self.conv5(X0)
|
| 1085 |
+
X = self.up(X0)
|
| 1086 |
+
X = torch.cat((inputs, x, X), 1)
|
| 1087 |
+
alpha = self.convo(X)
|
| 1088 |
+
alpha = torch.clamp(alpha, min=0, max=1)
|
| 1089 |
+
return alpha
|
| 1090 |
+
#+end_src
|
| 1091 |
+
|
| 1092 |
+
** Function to load model
|
| 1093 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
|
| 1094 |
+
def get_AEMatter_model(path_model_checkpoint):
|
| 1095 |
+
|
| 1096 |
+
download_model(path=path_model_checkpoint)
|
| 1097 |
+
|
| 1098 |
+
matmodel = AEMatter()
|
| 1099 |
+
matmodel.load_state_dict(
|
| 1100 |
+
torch.load(path_model_checkpoint, map_location='cpu')['model'])
|
| 1101 |
+
|
| 1102 |
+
matmodel = matmodel.cuda()
|
| 1103 |
+
matmodel.eval()
|
| 1104 |
+
|
| 1105 |
+
return matmodel
|
| 1106 |
+
#+end_src
|
| 1107 |
+
|
| 1108 |
+
** Function to do inference
|
| 1109 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
|
| 1110 |
+
def do_infer(rawimg, trimap, matmodel):
|
| 1111 |
+
trimap_nonp = trimap.copy()
|
| 1112 |
+
h, w, c = rawimg.shape
|
| 1113 |
+
nonph, nonpw, _ = rawimg.shape
|
| 1114 |
+
newh = (((h - 1) // 32) + 1) * 32
|
| 1115 |
+
neww = (((w - 1) // 32) + 1) * 32
|
| 1116 |
+
padh = newh - h
|
| 1117 |
+
padh1 = int(padh / 2)
|
| 1118 |
+
padh2 = padh - padh1
|
| 1119 |
+
padw = neww - w
|
| 1120 |
+
padw1 = int(padw / 2)
|
| 1121 |
+
padw2 = padw - padw1
|
| 1122 |
+
|
| 1123 |
+
rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
|
| 1124 |
+
cv2.BORDER_REFLECT)
|
| 1125 |
+
|
| 1126 |
+
trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
|
| 1127 |
+
cv2.BORDER_REFLECT)
|
| 1128 |
+
|
| 1129 |
+
h_pad, w_pad, _ = rawimg_pad.shape
|
| 1130 |
+
tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
|
| 1131 |
+
tritemp[:, :, 0] = (trimap_pad == 0)
|
| 1132 |
+
tritemp[:, :, 1] = (trimap_pad == 128)
|
| 1133 |
+
tritemp[:, :, 2] = (trimap_pad == 255)
|
| 1134 |
+
tritempimgs = np.transpose(tritemp, (2, 0, 1))
|
| 1135 |
+
tritempimgs = tritempimgs[np.newaxis, :, :, :]
|
| 1136 |
+
img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
|
| 1137 |
+
img = np.array(img, np.float32)
|
| 1138 |
+
img = img / 255.
|
| 1139 |
+
img = torch.from_numpy(img).cuda()
|
| 1140 |
+
tritempimgs = torch.from_numpy(tritempimgs).cuda()
|
| 1141 |
+
with torch.no_grad():
|
| 1142 |
+
pred = matmodel(img, tritempimgs)
|
| 1143 |
+
pred = pred.detach().cpu().numpy()[0]
|
| 1144 |
+
pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
|
| 1145 |
+
preda = pred[
|
| 1146 |
+
0:1,
|
| 1147 |
+
] * 255
|
| 1148 |
+
preda = np.transpose(preda, (1, 2, 0))
|
| 1149 |
+
preda = preda * (trimap_nonp[:, :, None]
|
| 1150 |
+
== 128) + (trimap_nonp[:, :, None] == 255) * 255
|
| 1151 |
+
preda = np.array(preda, np.uint8)
|
| 1152 |
+
return preda
|
| 1153 |
+
#+end_src
|
| 1154 |
+
|
| 1155 |
+
** Load ComfyUI AEMatter model
|
| 1156 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
|
| 1157 |
+
class load_AEMatter_Model:
|
| 1158 |
+
|
| 1159 |
+
def __init__(self):
|
| 1160 |
+
pass
|
| 1161 |
+
|
| 1162 |
+
@classmethod
|
| 1163 |
+
def INPUT_TYPES(s):
|
| 1164 |
+
return {
|
| 1165 |
+
"required": {},
|
| 1166 |
+
}
|
| 1167 |
+
|
| 1168 |
+
RETURN_TYPES = ("AEMatter_Model", )
|
| 1169 |
+
FUNCTION = "test"
|
| 1170 |
+
CATEGORY = "AEMatter"
|
| 1171 |
+
|
| 1172 |
+
def test(self):
|
| 1173 |
+
return (get_AEMatter_model(get_model_path()), )
|
| 1174 |
+
|
| 1175 |
+
|
| 1176 |
+
class run_AEMatter_inference:
|
| 1177 |
+
|
| 1178 |
+
def __init__(self):
|
| 1179 |
+
pass
|
| 1180 |
+
|
| 1181 |
+
@classmethod
|
| 1182 |
+
def INPUT_TYPES(s):
|
| 1183 |
+
return {
|
| 1184 |
+
"required": {
|
| 1185 |
+
"image": ("IMAGE", ),
|
| 1186 |
+
"trimap": ("MASK", ),
|
| 1187 |
+
"AEMatter_Model": ("AEMatter_Model", ),
|
| 1188 |
+
},
|
| 1189 |
+
}
|
| 1190 |
+
|
| 1191 |
+
RETURN_TYPES = ("MASK", )
|
| 1192 |
+
FUNCTION = "test"
|
| 1193 |
+
CATEGORY = "AEMatter"
|
| 1194 |
+
|
| 1195 |
+
def test(
|
| 1196 |
+
self,
|
| 1197 |
+
image,
|
| 1198 |
+
trimap,
|
| 1199 |
+
AEMatter_Model,
|
| 1200 |
+
):
|
| 1201 |
+
|
| 1202 |
+
ret = []
|
| 1203 |
+
batch_size = image.shape[0]
|
| 1204 |
+
|
| 1205 |
+
for i in range(batch_size):
|
| 1206 |
+
tmp_i = from_torch_image(image[i])
|
| 1207 |
+
tmp_m = from_torch_image(trimap[i])
|
| 1208 |
+
tmp = do_infer(tmp_i, tmp_m, AEMatter_Model)
|
| 1209 |
+
ret.append(tmp)
|
| 1210 |
+
|
| 1211 |
+
ret = to_torch_image(np.array(ret))
|
| 1212 |
+
ret = ret.squeeze(-1)
|
| 1213 |
+
print(ret.shape)
|
| 1214 |
+
|
| 1215 |
+
return ret
|
| 1216 |
+
#+end_src
|
| 1217 |
+
|
| 1218 |
+
** Main function
|
| 1219 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
|
| 1220 |
+
def main():
|
| 1221 |
+
ptrimap = '/home/asd/Desktop/demo/retriever_trimap.png'
|
| 1222 |
+
pimgs = '/home/asd/Desktop/demo/retriever_rgb.png'
|
| 1223 |
+
p_outs = 'alpha.png'
|
| 1224 |
+
|
| 1225 |
+
matmodel = get_AEMatter_model(
|
| 1226 |
+
path_model_checkpoint='/home/asd/Desktop/AEM_RWA.ckpt')
|
| 1227 |
+
|
| 1228 |
+
# matmodel = AEMatter()
|
| 1229 |
+
# matmodel.load_state_dict(
|
| 1230 |
+
# torch.load('/home/asd/Desktop/AEM_RWA.ckpt',
|
| 1231 |
+
# map_location='cpu')['model'])
|
| 1232 |
+
|
| 1233 |
+
# matmodel = matmodel.cuda()
|
| 1234 |
+
# matmodel.eval()
|
| 1235 |
+
|
| 1236 |
+
rawimg = pimgs
|
| 1237 |
+
trimap = ptrimap
|
| 1238 |
+
rawimg = cv2.imread(rawimg, cv2.IMREAD_COLOR)
|
| 1239 |
+
trimap = cv2.imread(trimap, cv2.IMREAD_GRAYSCALE)
|
| 1240 |
+
trimap_nonp = trimap.copy()
|
| 1241 |
+
h, w, c = rawimg.shape
|
| 1242 |
+
nonph, nonpw, _ = rawimg.shape
|
| 1243 |
+
newh = (((h - 1) // 32) + 1) * 32
|
| 1244 |
+
neww = (((w - 1) // 32) + 1) * 32
|
| 1245 |
+
padh = newh - h
|
| 1246 |
+
padh1 = int(padh / 2)
|
| 1247 |
+
padh2 = padh - padh1
|
| 1248 |
+
padw = neww - w
|
| 1249 |
+
padw1 = int(padw / 2)
|
| 1250 |
+
padw2 = padw - padw1
|
| 1251 |
+
rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
|
| 1252 |
+
cv2.BORDER_REFLECT)
|
| 1253 |
+
trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
|
| 1254 |
+
cv2.BORDER_REFLECT)
|
| 1255 |
+
h_pad, w_pad, _ = rawimg_pad.shape
|
| 1256 |
+
tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
|
| 1257 |
+
tritemp[:, :, 0] = (trimap_pad == 0)
|
| 1258 |
+
tritemp[:, :, 1] = (trimap_pad == 128)
|
| 1259 |
+
tritemp[:, :, 2] = (trimap_pad == 255)
|
| 1260 |
+
tritempimgs = np.transpose(tritemp, (2, 0, 1))
|
| 1261 |
+
tritempimgs = tritempimgs[np.newaxis, :, :, :]
|
| 1262 |
+
img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
|
| 1263 |
+
img = np.array(img, np.float32)
|
| 1264 |
+
img = img / 255.
|
| 1265 |
+
img = torch.from_numpy(img).cuda()
|
| 1266 |
+
tritempimgs = torch.from_numpy(tritempimgs).cuda()
|
| 1267 |
+
with torch.no_grad():
|
| 1268 |
+
pred = matmodel(img, tritempimgs)
|
| 1269 |
+
pred = pred.detach().cpu().numpy()[0]
|
| 1270 |
+
pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
|
| 1271 |
+
preda = pred[
|
| 1272 |
+
0:1,
|
| 1273 |
+
] * 255
|
| 1274 |
+
preda = np.transpose(preda, (1, 2, 0))
|
| 1275 |
+
preda = preda * (trimap_nonp[:, :, None]
|
| 1276 |
+
== 128) + (trimap_nonp[:, :, None] == 255) * 255
|
| 1277 |
+
preda = np.array(preda, np.uint8)
|
| 1278 |
+
cv2.imwrite(p_outs, preda)
|
| 1279 |
+
|
| 1280 |
+
#+end_src
|
| 1281 |
+
|
| 1282 |
+
** Comfyui Dictionary
|
| 1283 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.execute.py
|
| 1284 |
+
NODE_CLASS_MAPPINGS = {
|
| 1285 |
+
'load_AEMatter_Model': load_AEMatter_Model,
|
| 1286 |
+
'run_AEMatter_inference': run_AEMatter_inference,
|
| 1287 |
+
}
|
| 1288 |
+
|
| 1289 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
| 1290 |
+
'load_AEMatter_Model': 'load_AEMatter_Model',
|
| 1291 |
+
'run_AEMatter_inference': 'run_AEMatter_inference',
|
| 1292 |
+
}
|
| 1293 |
+
#+end_src
|
| 1294 |
+
|
| 1295 |
+
** COMMENT AEMatter.execute.py
|
| 1296 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.execute.py
|
| 1297 |
+
if __name__ == '__main__':
|
| 1298 |
+
# main()
|
| 1299 |
+
|
| 1300 |
+
rawimg = cv2.imread('/home/asd/Desktop/demo/retriever_rgb.png',
|
| 1301 |
+
cv2.IMREAD_COLOR)
|
| 1302 |
+
|
| 1303 |
+
trimap = cv2.imread('/home/asd/Desktop/demo/retriever_trimap.png',
|
| 1304 |
+
cv2.IMREAD_GRAYSCALE)
|
| 1305 |
+
|
| 1306 |
+
do_infer(rawimg, trimap,
|
| 1307 |
+
get_AEMatter_model('/home/asd/Desktop/AEM_RWA.ckpt'))
|
| 1308 |
+
#+end_src
|
| 1309 |
+
|
| 1310 |
+
** AEMatter.unify.sh
|
| 1311 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./AEMatter.unify.sh
|
| 1312 |
+
. "${HOME}/dbnew.sh"
|
| 1313 |
+
|
| 1314 |
+
cat \
|
| 1315 |
+
'AEMatter.import.py' \
|
| 1316 |
+
'AEMatter.function.py' \
|
| 1317 |
+
'AEMatter.class.py' \
|
| 1318 |
+
'AEMatter.execute.py' \
|
| 1319 |
+
| expand | yapf3 \
|
| 1320 |
+
> 'AEMatter.py' \
|
| 1321 |
+
;
|
| 1322 |
+
|
| 1323 |
+
cp 'AEMatter.py' '__init__.py'
|
| 1324 |
+
#+end_src
|
| 1325 |
+
|
| 1326 |
+
** AEMatter.run.sh
|
| 1327 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./AEMatter.run.sh
|
| 1328 |
+
. "${HOME}/dbnew.sh"
|
| 1329 |
+
python3 './AEMatter.py'
|
| 1330 |
+
#+end_src
|
| 1331 |
+
|
| 1332 |
+
#+RESULTS:
|
| 1333 |
+
|
| 1334 |
+
* COMMENT WORK SPACE
|
| 1335 |
+
|
| 1336 |
+
** ESHELL
|
| 1337 |
+
#+begin_src elisp
|
| 1338 |
+
(save-buffer)
|
| 1339 |
+
(org-babel-tangle)
|
| 1340 |
+
(shell-command "./AEMatter.unify.sh")
|
| 1341 |
+
#+end_src
|
| 1342 |
+
|
| 1343 |
+
#+RESULTS:
|
| 1344 |
+
: 0
|
| 1345 |
+
|
| 1346 |
+
** SHELL
|
| 1347 |
+
#+begin_src sh :shebang #!/bin/sh :results output
|
| 1348 |
+
realpath .
|
| 1349 |
+
cd /home/asd/GITHUB/aravind-h-v/dreambooth_experiments/AEMatter
|
| 1350 |
+
#+end_src
|
| 1351 |
+
|
| 1352 |
+
#+RESULTS:
|
| 1353 |
+
|
| 1354 |
+
** SHELL
|
| 1355 |
+
#+begin_src sh :shebang #!/bin/sh :results output
|
| 1356 |
+
ls
|
| 1357 |
+
#+end_src
|
ComfyUI_AEMatter/__init__.py
ADDED
|
@@ -0,0 +1,1248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
import cv2
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import wget
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import init
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torch.utils.checkpoint as checkpoint
|
| 14 |
+
|
| 15 |
+
from collections import OrderedDict
|
| 16 |
+
from einops import rearrange, repeat
|
| 17 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 18 |
+
|
| 19 |
+
import folder_paths
|
| 20 |
+
from folder_paths import models_dir
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
#!/usr/bin/python3
|
| 24 |
+
def mkdir_safe(out_path):
|
| 25 |
+
if type(out_path) == str:
|
| 26 |
+
if len(out_path) > 0:
|
| 27 |
+
if not os.path.exists(out_path):
|
| 28 |
+
os.mkdir(out_path)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_model_path():
|
| 32 |
+
import folder_paths
|
| 33 |
+
from folder_paths import models_dir
|
| 34 |
+
|
| 35 |
+
path_file_model = models_dir
|
| 36 |
+
mkdir_safe(out_path=path_file_model)
|
| 37 |
+
|
| 38 |
+
path_file_model = os.path.join(path_file_model, 'AEMatter')
|
| 39 |
+
mkdir_safe(out_path=path_file_model)
|
| 40 |
+
|
| 41 |
+
path_file_model = os.path.join(path_file_model, 'AEM_RWA.ckpt')
|
| 42 |
+
|
| 43 |
+
return path_file_model
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def download_model(path):
|
| 47 |
+
if not os.path.exists(path):
|
| 48 |
+
wget.download(
|
| 49 |
+
'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/AEMatter/AEM_RWA.ckpt?download=true',
|
| 50 |
+
out=path)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def from_torch_image(image):
|
| 54 |
+
image = image.cpu().numpy() * 255.0
|
| 55 |
+
image = np.clip(image, 0, 255).astype(np.uint8)
|
| 56 |
+
return image
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def to_torch_image(image):
|
| 60 |
+
image = image.astype(dtype=np.float32)
|
| 61 |
+
image /= 255.0
|
| 62 |
+
image = torch.from_numpy(image)
|
| 63 |
+
return image
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def window_partition(x, window_size):
|
| 67 |
+
"""
|
| 68 |
+
Args:
|
| 69 |
+
x: (B, H, W, C)
|
| 70 |
+
window_size (int): window size
|
| 71 |
+
Returns:
|
| 72 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 73 |
+
"""
|
| 74 |
+
B, H, W, C = x.shape
|
| 75 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size,
|
| 76 |
+
C)
|
| 77 |
+
windows = x.permute(0, 1, 3, 2, 4,
|
| 78 |
+
5).contiguous().view(-1, window_size, window_size, C)
|
| 79 |
+
return windows
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def window_reverse(windows, window_size, H, W):
|
| 83 |
+
"""
|
| 84 |
+
Args:
|
| 85 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 86 |
+
window_size (int): Window size
|
| 87 |
+
H (int): Height of image
|
| 88 |
+
W (int): Width of image
|
| 89 |
+
Returns:
|
| 90 |
+
x: (B, H, W, C)
|
| 91 |
+
"""
|
| 92 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
| 93 |
+
x = windows.view(B, H // window_size, W // window_size, window_size,
|
| 94 |
+
window_size, -1)
|
| 95 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
| 96 |
+
return x
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def get_AEMatter_model(path_model_checkpoint):
|
| 100 |
+
|
| 101 |
+
download_model(path=path_model_checkpoint)
|
| 102 |
+
|
| 103 |
+
matmodel = AEMatter()
|
| 104 |
+
matmodel.load_state_dict(
|
| 105 |
+
torch.load(path_model_checkpoint, map_location='cpu')['model'])
|
| 106 |
+
|
| 107 |
+
matmodel = matmodel.cuda()
|
| 108 |
+
matmodel.eval()
|
| 109 |
+
|
| 110 |
+
return matmodel
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def do_infer(rawimg, trimap, matmodel):
|
| 114 |
+
trimap_nonp = trimap.copy()
|
| 115 |
+
h, w, c = rawimg.shape
|
| 116 |
+
nonph, nonpw, _ = rawimg.shape
|
| 117 |
+
newh = (((h - 1) // 32) + 1) * 32
|
| 118 |
+
neww = (((w - 1) // 32) + 1) * 32
|
| 119 |
+
padh = newh - h
|
| 120 |
+
padh1 = int(padh / 2)
|
| 121 |
+
padh2 = padh - padh1
|
| 122 |
+
padw = neww - w
|
| 123 |
+
padw1 = int(padw / 2)
|
| 124 |
+
padw2 = padw - padw1
|
| 125 |
+
|
| 126 |
+
rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
|
| 127 |
+
cv2.BORDER_REFLECT)
|
| 128 |
+
|
| 129 |
+
trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
|
| 130 |
+
cv2.BORDER_REFLECT)
|
| 131 |
+
|
| 132 |
+
h_pad, w_pad, _ = rawimg_pad.shape
|
| 133 |
+
tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
|
| 134 |
+
tritemp[:, :, 0] = (trimap_pad == 0)
|
| 135 |
+
tritemp[:, :, 1] = (trimap_pad == 128)
|
| 136 |
+
tritemp[:, :, 2] = (trimap_pad == 255)
|
| 137 |
+
tritempimgs = np.transpose(tritemp, (2, 0, 1))
|
| 138 |
+
tritempimgs = tritempimgs[np.newaxis, :, :, :]
|
| 139 |
+
img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
|
| 140 |
+
img = np.array(img, np.float32)
|
| 141 |
+
img = img / 255.
|
| 142 |
+
img = torch.from_numpy(img).cuda()
|
| 143 |
+
tritempimgs = torch.from_numpy(tritempimgs).cuda()
|
| 144 |
+
with torch.no_grad():
|
| 145 |
+
pred = matmodel(img, tritempimgs)
|
| 146 |
+
pred = pred.detach().cpu().numpy()[0]
|
| 147 |
+
pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
|
| 148 |
+
preda = pred[
|
| 149 |
+
0:1,
|
| 150 |
+
] * 255
|
| 151 |
+
preda = np.transpose(preda, (1, 2, 0))
|
| 152 |
+
preda = preda * (trimap_nonp[:, :, None]
|
| 153 |
+
== 128) + (trimap_nonp[:, :, None] == 255) * 255
|
| 154 |
+
preda = np.array(preda, np.uint8)
|
| 155 |
+
return preda
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def main():
|
| 159 |
+
ptrimap = '/home/asd/Desktop/demo/retriever_trimap.png'
|
| 160 |
+
pimgs = '/home/asd/Desktop/demo/retriever_rgb.png'
|
| 161 |
+
p_outs = 'alpha.png'
|
| 162 |
+
|
| 163 |
+
matmodel = get_AEMatter_model(
|
| 164 |
+
path_model_checkpoint='/home/asd/Desktop/AEM_RWA.ckpt')
|
| 165 |
+
|
| 166 |
+
# matmodel = AEMatter()
|
| 167 |
+
# matmodel.load_state_dict(
|
| 168 |
+
# torch.load('/home/asd/Desktop/AEM_RWA.ckpt',
|
| 169 |
+
# map_location='cpu')['model'])
|
| 170 |
+
|
| 171 |
+
# matmodel = matmodel.cuda()
|
| 172 |
+
# matmodel.eval()
|
| 173 |
+
|
| 174 |
+
rawimg = pimgs
|
| 175 |
+
trimap = ptrimap
|
| 176 |
+
rawimg = cv2.imread(rawimg, cv2.IMREAD_COLOR)
|
| 177 |
+
trimap = cv2.imread(trimap, cv2.IMREAD_GRAYSCALE)
|
| 178 |
+
trimap_nonp = trimap.copy()
|
| 179 |
+
h, w, c = rawimg.shape
|
| 180 |
+
nonph, nonpw, _ = rawimg.shape
|
| 181 |
+
newh = (((h - 1) // 32) + 1) * 32
|
| 182 |
+
neww = (((w - 1) // 32) + 1) * 32
|
| 183 |
+
padh = newh - h
|
| 184 |
+
padh1 = int(padh / 2)
|
| 185 |
+
padh2 = padh - padh1
|
| 186 |
+
padw = neww - w
|
| 187 |
+
padw1 = int(padw / 2)
|
| 188 |
+
padw2 = padw - padw1
|
| 189 |
+
rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
|
| 190 |
+
cv2.BORDER_REFLECT)
|
| 191 |
+
trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
|
| 192 |
+
cv2.BORDER_REFLECT)
|
| 193 |
+
h_pad, w_pad, _ = rawimg_pad.shape
|
| 194 |
+
tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
|
| 195 |
+
tritemp[:, :, 0] = (trimap_pad == 0)
|
| 196 |
+
tritemp[:, :, 1] = (trimap_pad == 128)
|
| 197 |
+
tritemp[:, :, 2] = (trimap_pad == 255)
|
| 198 |
+
tritempimgs = np.transpose(tritemp, (2, 0, 1))
|
| 199 |
+
tritempimgs = tritempimgs[np.newaxis, :, :, :]
|
| 200 |
+
img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
|
| 201 |
+
img = np.array(img, np.float32)
|
| 202 |
+
img = img / 255.
|
| 203 |
+
img = torch.from_numpy(img).cuda()
|
| 204 |
+
tritempimgs = torch.from_numpy(tritempimgs).cuda()
|
| 205 |
+
with torch.no_grad():
|
| 206 |
+
pred = matmodel(img, tritempimgs)
|
| 207 |
+
pred = pred.detach().cpu().numpy()[0]
|
| 208 |
+
pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
|
| 209 |
+
preda = pred[
|
| 210 |
+
0:1,
|
| 211 |
+
] * 255
|
| 212 |
+
preda = np.transpose(preda, (1, 2, 0))
|
| 213 |
+
preda = preda * (trimap_nonp[:, :, None]
|
| 214 |
+
== 128) + (trimap_nonp[:, :, None] == 255) * 255
|
| 215 |
+
preda = np.array(preda, np.uint8)
|
| 216 |
+
cv2.imwrite(p_outs, preda)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
#!/usr/bin/python3
|
| 220 |
+
class WindowAttention(nn.Module):
|
| 221 |
+
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
| 222 |
+
It supports both of shifted and non-shifted window.
|
| 223 |
+
Args:
|
| 224 |
+
dim (int): Number of input channels.
|
| 225 |
+
window_size (tuple[int]): The height and width of the window.
|
| 226 |
+
num_heads (int): Number of attention heads.
|
| 227 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 228 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
| 229 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
| 230 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
| 231 |
+
"""
|
| 232 |
+
|
| 233 |
+
def __init__(self,
|
| 234 |
+
dim,
|
| 235 |
+
window_size,
|
| 236 |
+
num_heads,
|
| 237 |
+
qkv_bias=True,
|
| 238 |
+
qk_scale=None,
|
| 239 |
+
attn_drop=0.,
|
| 240 |
+
proj_drop=0.):
|
| 241 |
+
|
| 242 |
+
super().__init__()
|
| 243 |
+
self.dim = dim
|
| 244 |
+
self.window_size = window_size # Wh, Ww
|
| 245 |
+
self.num_heads = num_heads
|
| 246 |
+
head_dim = dim // num_heads
|
| 247 |
+
self.scale = qk_scale or head_dim**-0.5
|
| 248 |
+
|
| 249 |
+
# define a parameter table of relative position bias
|
| 250 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 251 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
|
| 252 |
+
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 253 |
+
|
| 254 |
+
# get pair-wise relative position index for each token inside the window
|
| 255 |
+
coords_h = torch.arange(self.window_size[0])
|
| 256 |
+
coords_w = torch.arange(self.window_size[1])
|
| 257 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 258 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 259 |
+
relative_coords = coords_flatten[:, :,
|
| 260 |
+
None] - coords_flatten[:,
|
| 261 |
+
None, :] # 2, Wh*Ww, Wh*Ww
|
| 262 |
+
relative_coords = relative_coords.permute(
|
| 263 |
+
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 264 |
+
relative_coords[:, :,
|
| 265 |
+
0] += self.window_size[0] - 1 # shift to start from 0
|
| 266 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 267 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
| 268 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 269 |
+
self.register_buffer("relative_position_index",
|
| 270 |
+
relative_position_index)
|
| 271 |
+
|
| 272 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 273 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 274 |
+
self.proj = nn.Linear(dim, dim)
|
| 275 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 276 |
+
|
| 277 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
| 278 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 279 |
+
|
| 280 |
+
def forward(self, x, mask=None):
|
| 281 |
+
""" Forward function.
|
| 282 |
+
Args:
|
| 283 |
+
x: input features with shape of (num_windows*B, N, C)
|
| 284 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
| 285 |
+
"""
|
| 286 |
+
B_, N, C = x.shape
|
| 287 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
|
| 288 |
+
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 289 |
+
q, k, v = qkv[0], qkv[1], qkv[
|
| 290 |
+
2] # make torchscript happy (cannot use tensor as tuple)
|
| 291 |
+
|
| 292 |
+
q = q * self.scale
|
| 293 |
+
attn = (q @ k.transpose(-2, -1))
|
| 294 |
+
|
| 295 |
+
relative_position_bias = self.relative_position_bias_table[
|
| 296 |
+
self.relative_position_index.view(-1)].view(
|
| 297 |
+
self.window_size[0] * self.window_size[1],
|
| 298 |
+
self.window_size[0] * self.window_size[1],
|
| 299 |
+
-1) # Wh*Ww,Wh*Ww,nH
|
| 300 |
+
relative_position_bias = relative_position_bias.permute(
|
| 301 |
+
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 302 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
| 303 |
+
|
| 304 |
+
if mask is not None:
|
| 305 |
+
nW = mask.shape[0]
|
| 306 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N,
|
| 307 |
+
N) + mask.unsqueeze(1).unsqueeze(0)
|
| 308 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
| 309 |
+
attn = self.softmax(attn)
|
| 310 |
+
else:
|
| 311 |
+
attn = self.softmax(attn)
|
| 312 |
+
|
| 313 |
+
attn = self.attn_drop(attn)
|
| 314 |
+
|
| 315 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
| 316 |
+
x = self.proj(x)
|
| 317 |
+
x = self.proj_drop(x)
|
| 318 |
+
return x
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class SwinTransformerBlock(nn.Module):
|
| 322 |
+
""" Swin Transformer Block.
|
| 323 |
+
Args:
|
| 324 |
+
dim (int): Number of input channels.
|
| 325 |
+
num_heads (int): Number of attention heads.
|
| 326 |
+
window_size (int): Window size.
|
| 327 |
+
shift_size (int): Shift size for SW-MSA.
|
| 328 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 329 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 330 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 331 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 332 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 333 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
| 334 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
| 335 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 336 |
+
"""
|
| 337 |
+
|
| 338 |
+
def __init__(self,
|
| 339 |
+
dim,
|
| 340 |
+
num_heads,
|
| 341 |
+
window_size=7,
|
| 342 |
+
shift_size=0,
|
| 343 |
+
mlp_ratio=4.,
|
| 344 |
+
qkv_bias=True,
|
| 345 |
+
qk_scale=None,
|
| 346 |
+
drop=0.,
|
| 347 |
+
attn_drop=0.,
|
| 348 |
+
drop_path=0.,
|
| 349 |
+
act_layer=nn.GELU,
|
| 350 |
+
norm_layer=nn.LayerNorm):
|
| 351 |
+
super().__init__()
|
| 352 |
+
self.dim = dim
|
| 353 |
+
self.num_heads = num_heads
|
| 354 |
+
self.window_size = window_size
|
| 355 |
+
self.shift_size = shift_size
|
| 356 |
+
self.mlp_ratio = mlp_ratio
|
| 357 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
| 358 |
+
|
| 359 |
+
self.norm1 = norm_layer(dim)
|
| 360 |
+
self.attn = WindowAttention(dim,
|
| 361 |
+
window_size=to_2tuple(self.window_size),
|
| 362 |
+
num_heads=num_heads,
|
| 363 |
+
qkv_bias=qkv_bias,
|
| 364 |
+
qk_scale=qk_scale,
|
| 365 |
+
attn_drop=attn_drop,
|
| 366 |
+
proj_drop=drop)
|
| 367 |
+
|
| 368 |
+
self.drop_path = DropPath(
|
| 369 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
| 370 |
+
self.norm2 = norm_layer(dim)
|
| 371 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 372 |
+
self.mlp = Mlp(in_features=dim,
|
| 373 |
+
hidden_features=mlp_hidden_dim,
|
| 374 |
+
act_layer=act_layer,
|
| 375 |
+
drop=drop)
|
| 376 |
+
|
| 377 |
+
self.H = None
|
| 378 |
+
self.W = None
|
| 379 |
+
|
| 380 |
+
def forward(self, x, mask_matrix):
|
| 381 |
+
""" Forward function.
|
| 382 |
+
Args:
|
| 383 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 384 |
+
H, W: Spatial resolution of the input feature.
|
| 385 |
+
mask_matrix: Attention mask for cyclic shift.
|
| 386 |
+
"""
|
| 387 |
+
B, L, C = x.shape
|
| 388 |
+
H, W = self.H, self.W
|
| 389 |
+
assert L == H * W, "input feature has wrong size"
|
| 390 |
+
|
| 391 |
+
shortcut = x
|
| 392 |
+
x = self.norm1(x)
|
| 393 |
+
x = x.view(B, H, W, C)
|
| 394 |
+
|
| 395 |
+
# pad feature maps to multiples of window size
|
| 396 |
+
pad_l = pad_t = 0
|
| 397 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
| 398 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
| 399 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
| 400 |
+
_, Hp, Wp, _ = x.shape
|
| 401 |
+
|
| 402 |
+
# cyclic shift
|
| 403 |
+
if self.shift_size > 0:
|
| 404 |
+
shifted_x = torch.roll(x,
|
| 405 |
+
shifts=(-self.shift_size, -self.shift_size),
|
| 406 |
+
dims=(1, 2))
|
| 407 |
+
attn_mask = mask_matrix
|
| 408 |
+
else:
|
| 409 |
+
shifted_x = x
|
| 410 |
+
attn_mask = None
|
| 411 |
+
|
| 412 |
+
# partition windows
|
| 413 |
+
x_windows = window_partition(
|
| 414 |
+
shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
| 415 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size,
|
| 416 |
+
C) # nW*B, window_size*window_size, C
|
| 417 |
+
|
| 418 |
+
# W-MSA/SW-MSA
|
| 419 |
+
attn_windows = self.attn(
|
| 420 |
+
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
| 421 |
+
|
| 422 |
+
# merge windows
|
| 423 |
+
attn_windows = attn_windows.view(-1, self.window_size,
|
| 424 |
+
self.window_size, C)
|
| 425 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp,
|
| 426 |
+
Wp) # B H' W' C
|
| 427 |
+
|
| 428 |
+
# reverse cyclic shift
|
| 429 |
+
if self.shift_size > 0:
|
| 430 |
+
x = torch.roll(shifted_x,
|
| 431 |
+
shifts=(self.shift_size, self.shift_size),
|
| 432 |
+
dims=(1, 2))
|
| 433 |
+
else:
|
| 434 |
+
x = shifted_x
|
| 435 |
+
|
| 436 |
+
if pad_r > 0 or pad_b > 0:
|
| 437 |
+
x = x[:, :H, :W, :].contiguous()
|
| 438 |
+
|
| 439 |
+
x = x.view(B, H * W, C)
|
| 440 |
+
|
| 441 |
+
# FFN
|
| 442 |
+
x = shortcut + self.drop_path(x)
|
| 443 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 444 |
+
|
| 445 |
+
return x
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
class PatchMerging(nn.Module):
|
| 449 |
+
""" Patch Merging Layer
|
| 450 |
+
Args:
|
| 451 |
+
dim (int): Number of input channels.
|
| 452 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 453 |
+
"""
|
| 454 |
+
|
| 455 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
| 456 |
+
super().__init__()
|
| 457 |
+
self.dim = dim
|
| 458 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
| 459 |
+
self.norm = norm_layer(4 * dim)
|
| 460 |
+
|
| 461 |
+
def forward(self, x, H, W):
|
| 462 |
+
""" Forward function.
|
| 463 |
+
Args:
|
| 464 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 465 |
+
H, W: Spatial resolution of the input feature.
|
| 466 |
+
"""
|
| 467 |
+
B, L, C = x.shape
|
| 468 |
+
assert L == H * W, "input feature has wrong size"
|
| 469 |
+
|
| 470 |
+
x = x.view(B, H, W, C)
|
| 471 |
+
|
| 472 |
+
# padding
|
| 473 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
| 474 |
+
if pad_input:
|
| 475 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
| 476 |
+
|
| 477 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
| 478 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
| 479 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
| 480 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
| 481 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
| 482 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
| 483 |
+
|
| 484 |
+
x = self.norm(x)
|
| 485 |
+
x = self.reduction(x)
|
| 486 |
+
|
| 487 |
+
return x
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
class BasicLayer(nn.Module):
|
| 491 |
+
""" A basic Swin Transformer layer for one stage.
|
| 492 |
+
Args:
|
| 493 |
+
dim (int): Number of feature channels
|
| 494 |
+
depth (int): Depths of this stage.
|
| 495 |
+
num_heads (int): Number of attention head.
|
| 496 |
+
window_size (int): Local window size. Default: 7.
|
| 497 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 498 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 499 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 500 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 501 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 502 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
| 503 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 504 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
| 505 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 506 |
+
"""
|
| 507 |
+
|
| 508 |
+
def __init__(self,
|
| 509 |
+
dim,
|
| 510 |
+
depth,
|
| 511 |
+
num_heads,
|
| 512 |
+
window_size=7,
|
| 513 |
+
mlp_ratio=4.,
|
| 514 |
+
qkv_bias=True,
|
| 515 |
+
qk_scale=None,
|
| 516 |
+
drop=0.,
|
| 517 |
+
attn_drop=0.,
|
| 518 |
+
drop_path=0.,
|
| 519 |
+
norm_layer=nn.LayerNorm,
|
| 520 |
+
downsample=None,
|
| 521 |
+
use_checkpoint=False):
|
| 522 |
+
|
| 523 |
+
super().__init__()
|
| 524 |
+
self.window_size = window_size
|
| 525 |
+
self.shift_size = window_size // 2
|
| 526 |
+
self.depth = depth
|
| 527 |
+
self.use_checkpoint = use_checkpoint
|
| 528 |
+
|
| 529 |
+
# build blocks
|
| 530 |
+
self.blocks = nn.ModuleList([
|
| 531 |
+
SwinTransformerBlock(dim=dim,
|
| 532 |
+
num_heads=num_heads,
|
| 533 |
+
window_size=window_size,
|
| 534 |
+
shift_size=0 if
|
| 535 |
+
(i % 2 == 0) else window_size // 2,
|
| 536 |
+
mlp_ratio=mlp_ratio,
|
| 537 |
+
qkv_bias=qkv_bias,
|
| 538 |
+
qk_scale=qk_scale,
|
| 539 |
+
drop=drop,
|
| 540 |
+
attn_drop=attn_drop,
|
| 541 |
+
drop_path=drop_path[i] if isinstance(
|
| 542 |
+
drop_path, list) else drop_path,
|
| 543 |
+
norm_layer=norm_layer) for i in range(depth)
|
| 544 |
+
])
|
| 545 |
+
|
| 546 |
+
# patch merging layer
|
| 547 |
+
if downsample is not None:
|
| 548 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
| 549 |
+
else:
|
| 550 |
+
self.downsample = None
|
| 551 |
+
|
| 552 |
+
def forward(self, x, H, W):
|
| 553 |
+
""" Forward function.
|
| 554 |
+
Args:
|
| 555 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 556 |
+
H, W: Spatial resolution of the input feature.
|
| 557 |
+
"""
|
| 558 |
+
# print(x.shape,H,W)
|
| 559 |
+
# calculate attention mask for SW-MSA
|
| 560 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
| 561 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
| 562 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
| 563 |
+
h_slices = (slice(0, -self.window_size),
|
| 564 |
+
slice(-self.window_size,
|
| 565 |
+
-self.shift_size), slice(-self.shift_size, None))
|
| 566 |
+
w_slices = (slice(0, -self.window_size),
|
| 567 |
+
slice(-self.window_size,
|
| 568 |
+
-self.shift_size), slice(-self.shift_size, None))
|
| 569 |
+
cnt = 0
|
| 570 |
+
for h in h_slices:
|
| 571 |
+
for w in w_slices:
|
| 572 |
+
img_mask[:, h, w, :] = cnt
|
| 573 |
+
cnt += 1
|
| 574 |
+
|
| 575 |
+
mask_windows = window_partition(
|
| 576 |
+
img_mask, self.window_size) # nW, window_size, window_size, 1
|
| 577 |
+
|
| 578 |
+
mask_windows = mask_windows.view(-1,
|
| 579 |
+
self.window_size * self.window_size)
|
| 580 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(
|
| 581 |
+
2) # nW, ww window_size*window_size
|
| 582 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
| 583 |
+
float(-100.0)).masked_fill(
|
| 584 |
+
attn_mask == 0, float(0.0))
|
| 585 |
+
|
| 586 |
+
for blk in self.blocks:
|
| 587 |
+
blk.H, blk.W = H, W
|
| 588 |
+
if self.use_checkpoint:
|
| 589 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
| 590 |
+
else:
|
| 591 |
+
x = blk(x, attn_mask)
|
| 592 |
+
|
| 593 |
+
if self.downsample is not None:
|
| 594 |
+
x_down = self.downsample(x, H, W)
|
| 595 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
| 596 |
+
return x, H, W, x_down, Wh, Ww
|
| 597 |
+
else:
|
| 598 |
+
return x, H, W, x, H, W
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
class PatchEmbed(nn.Module):
|
| 602 |
+
""" Image to Patch Embedding
|
| 603 |
+
Args:
|
| 604 |
+
patch_size (int): Patch token size. Default: 4.
|
| 605 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 606 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 607 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
| 608 |
+
"""
|
| 609 |
+
|
| 610 |
+
def __init__(self,
|
| 611 |
+
patch_size=4,
|
| 612 |
+
in_chans=3,
|
| 613 |
+
embed_dim=96,
|
| 614 |
+
norm_layer=None):
|
| 615 |
+
|
| 616 |
+
super().__init__()
|
| 617 |
+
patch_size = to_2tuple(patch_size)
|
| 618 |
+
self.patch_size = patch_size
|
| 619 |
+
|
| 620 |
+
self.in_chans = in_chans
|
| 621 |
+
self.embed_dim = embed_dim
|
| 622 |
+
|
| 623 |
+
self.proj = nn.Conv2d(in_chans,
|
| 624 |
+
embed_dim,
|
| 625 |
+
kernel_size=patch_size,
|
| 626 |
+
stride=patch_size)
|
| 627 |
+
if norm_layer is not None:
|
| 628 |
+
self.norm = norm_layer(embed_dim)
|
| 629 |
+
else:
|
| 630 |
+
self.norm = None
|
| 631 |
+
|
| 632 |
+
def forward(self, x):
|
| 633 |
+
"""Forward function."""
|
| 634 |
+
# padding
|
| 635 |
+
_, _, H, W = x.size()
|
| 636 |
+
if W % self.patch_size[1] != 0:
|
| 637 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
| 638 |
+
if H % self.patch_size[0] != 0:
|
| 639 |
+
x = F.pad(x,
|
| 640 |
+
(0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
| 641 |
+
|
| 642 |
+
x = self.proj(x) # B C Wh Ww
|
| 643 |
+
if self.norm is not None:
|
| 644 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 645 |
+
x = x.flatten(2).transpose(1, 2)
|
| 646 |
+
x = self.norm(x)
|
| 647 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
| 648 |
+
|
| 649 |
+
return x
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
class SwinTransformer(nn.Module):
|
| 653 |
+
""" Swin Transformer backbone.
|
| 654 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
| 655 |
+
https://arxiv.org/pdf/2103.14030
|
| 656 |
+
Args:
|
| 657 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
| 658 |
+
used in absolute postion embedding. Default 224.
|
| 659 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
| 660 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 661 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 662 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
| 663 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
| 664 |
+
window_size (int): Window size. Default: 7.
|
| 665 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 666 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
| 667 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
| 668 |
+
drop_rate (float): Dropout rate.
|
| 669 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
| 670 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
| 671 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
| 672 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
| 673 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
| 674 |
+
out_indices (Sequence[int]): Output from which stages.
|
| 675 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
| 676 |
+
-1 means not freezing any parameters.
|
| 677 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 678 |
+
"""
|
| 679 |
+
|
| 680 |
+
def __init__(self,
|
| 681 |
+
pretrain_img_size=224,
|
| 682 |
+
patch_size=4,
|
| 683 |
+
in_chans=3,
|
| 684 |
+
embed_dim=96,
|
| 685 |
+
depths=[2, 2, 6, 2],
|
| 686 |
+
num_heads=[3, 6, 12, 24],
|
| 687 |
+
window_size=7,
|
| 688 |
+
mlp_ratio=4.,
|
| 689 |
+
qkv_bias=True,
|
| 690 |
+
qk_scale=None,
|
| 691 |
+
drop_rate=0.,
|
| 692 |
+
attn_drop_rate=0.,
|
| 693 |
+
drop_path_rate=0.2,
|
| 694 |
+
norm_layer=nn.LayerNorm,
|
| 695 |
+
ape=False,
|
| 696 |
+
patch_norm=True,
|
| 697 |
+
out_indices=(0, 1, 2, 3),
|
| 698 |
+
frozen_stages=-1,
|
| 699 |
+
use_checkpoint=False):
|
| 700 |
+
|
| 701 |
+
super().__init__()
|
| 702 |
+
|
| 703 |
+
self.pretrain_img_size = pretrain_img_size
|
| 704 |
+
self.num_layers = len(depths)
|
| 705 |
+
self.embed_dim = embed_dim
|
| 706 |
+
self.ape = ape
|
| 707 |
+
self.patch_norm = patch_norm
|
| 708 |
+
self.out_indices = out_indices
|
| 709 |
+
self.frozen_stages = frozen_stages
|
| 710 |
+
|
| 711 |
+
# split image into non-overlapping patches
|
| 712 |
+
self.patch_embed = PatchEmbed(
|
| 713 |
+
patch_size=patch_size,
|
| 714 |
+
in_chans=in_chans,
|
| 715 |
+
embed_dim=embed_dim,
|
| 716 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
| 717 |
+
|
| 718 |
+
# absolute position embedding
|
| 719 |
+
if self.ape:
|
| 720 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
| 721 |
+
patch_size = to_2tuple(patch_size)
|
| 722 |
+
patches_resolution = [
|
| 723 |
+
pretrain_img_size[0] // patch_size[0],
|
| 724 |
+
pretrain_img_size[1] // patch_size[1]
|
| 725 |
+
]
|
| 726 |
+
|
| 727 |
+
self.absolute_pos_embed = nn.Parameter(
|
| 728 |
+
torch.zeros(1, embed_dim, patches_resolution[0],
|
| 729 |
+
patches_resolution[1]))
|
| 730 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
| 731 |
+
|
| 732 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 733 |
+
|
| 734 |
+
# stochastic depth
|
| 735 |
+
dpr = [
|
| 736 |
+
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
| 737 |
+
] # stochastic depth decay rule
|
| 738 |
+
|
| 739 |
+
# build layers
|
| 740 |
+
self.layers = nn.ModuleList()
|
| 741 |
+
for i_layer in range(self.num_layers):
|
| 742 |
+
layer = BasicLayer(
|
| 743 |
+
dim=int(embed_dim * 2**i_layer),
|
| 744 |
+
depth=depths[i_layer],
|
| 745 |
+
num_heads=num_heads[i_layer],
|
| 746 |
+
window_size=window_size,
|
| 747 |
+
mlp_ratio=mlp_ratio,
|
| 748 |
+
qkv_bias=qkv_bias,
|
| 749 |
+
qk_scale=qk_scale,
|
| 750 |
+
drop=drop_rate,
|
| 751 |
+
attn_drop=attn_drop_rate,
|
| 752 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
| 753 |
+
norm_layer=norm_layer,
|
| 754 |
+
downsample=PatchMerging if
|
| 755 |
+
(i_layer < self.num_layers - 1) else None,
|
| 756 |
+
use_checkpoint=use_checkpoint)
|
| 757 |
+
self.layers.append(layer)
|
| 758 |
+
|
| 759 |
+
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
|
| 760 |
+
self.num_features = num_features
|
| 761 |
+
|
| 762 |
+
# add a norm layer for each output
|
| 763 |
+
for i_layer in out_indices:
|
| 764 |
+
layer = norm_layer(num_features[i_layer])
|
| 765 |
+
layer_name = f'norm{i_layer}'
|
| 766 |
+
self.add_module(layer_name, layer)
|
| 767 |
+
|
| 768 |
+
self._freeze_stages()
|
| 769 |
+
|
| 770 |
+
def _freeze_stages(self):
|
| 771 |
+
if self.frozen_stages >= 0:
|
| 772 |
+
self.patch_embed.eval()
|
| 773 |
+
for param in self.patch_embed.parameters():
|
| 774 |
+
param.requires_grad = False
|
| 775 |
+
|
| 776 |
+
if self.frozen_stages >= 1 and self.ape:
|
| 777 |
+
self.absolute_pos_embed.requires_grad = False
|
| 778 |
+
|
| 779 |
+
if self.frozen_stages >= 2:
|
| 780 |
+
self.pos_drop.eval()
|
| 781 |
+
for i in range(0, self.frozen_stages - 1):
|
| 782 |
+
m = self.layers[i]
|
| 783 |
+
m.eval()
|
| 784 |
+
for param in m.parameters():
|
| 785 |
+
param.requires_grad = False
|
| 786 |
+
|
| 787 |
+
def init_weights(self, pretrained=None):
|
| 788 |
+
"""Initialize the weights in backbone.
|
| 789 |
+
Args:
|
| 790 |
+
pretrained (str, optional): Path to pre-trained weights.
|
| 791 |
+
Defaults to None.
|
| 792 |
+
"""
|
| 793 |
+
|
| 794 |
+
def forward(self, x):
|
| 795 |
+
"""Forward function."""
|
| 796 |
+
x = self.patch_embed(x)
|
| 797 |
+
|
| 798 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 799 |
+
if self.ape:
|
| 800 |
+
# interpolate the position embedding to the corresponding size
|
| 801 |
+
absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
|
| 802 |
+
size=(Wh, Ww),
|
| 803 |
+
mode='bicubic')
|
| 804 |
+
x = (x + absolute_pos_embed).flatten(2).transpose(1,
|
| 805 |
+
2) # B Wh*Ww C
|
| 806 |
+
else:
|
| 807 |
+
x = x.flatten(2).transpose(1, 2)
|
| 808 |
+
x = self.pos_drop(x)
|
| 809 |
+
|
| 810 |
+
outs = []
|
| 811 |
+
for i in range(self.num_layers):
|
| 812 |
+
layer = self.layers[i]
|
| 813 |
+
|
| 814 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
| 815 |
+
|
| 816 |
+
if i in self.out_indices:
|
| 817 |
+
norm_layer = getattr(self, f'norm{i}')
|
| 818 |
+
x_out = norm_layer(x_out)
|
| 819 |
+
|
| 820 |
+
out = x_out.view(-1, H, W,
|
| 821 |
+
self.num_features[i]).permute(0, 3, 1,
|
| 822 |
+
2).contiguous()
|
| 823 |
+
outs.append(out)
|
| 824 |
+
|
| 825 |
+
return tuple(outs)
|
| 826 |
+
|
| 827 |
+
def train(self, mode=True):
|
| 828 |
+
"""Convert the model into training mode while keep layers freezed."""
|
| 829 |
+
super(SwinTransformer, self).train(mode)
|
| 830 |
+
self._freeze_stages()
|
| 831 |
+
|
| 832 |
+
|
| 833 |
+
class Mlp(nn.Module):
|
| 834 |
+
""" Multilayer perceptron."""
|
| 835 |
+
|
| 836 |
+
def __init__(self,
|
| 837 |
+
in_features,
|
| 838 |
+
hidden_features=None,
|
| 839 |
+
out_features=None,
|
| 840 |
+
act_layer=nn.GELU,
|
| 841 |
+
drop=0.):
|
| 842 |
+
super().__init__()
|
| 843 |
+
out_features = out_features or in_features
|
| 844 |
+
hidden_features = hidden_features or in_features
|
| 845 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 846 |
+
self.act = act_layer()
|
| 847 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 848 |
+
self.drop = nn.Dropout(drop)
|
| 849 |
+
|
| 850 |
+
def forward(self, x):
|
| 851 |
+
x = self.fc1(x)
|
| 852 |
+
x = self.act(x)
|
| 853 |
+
x = self.drop(x)
|
| 854 |
+
x = self.fc2(x)
|
| 855 |
+
x = self.drop(x)
|
| 856 |
+
return x
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
class ResBlock(nn.Module):
|
| 860 |
+
|
| 861 |
+
def __init__(self, inc, midc):
|
| 862 |
+
super(ResBlock, self).__init__()
|
| 863 |
+
self.conv1 = nn.Conv2d(inc,
|
| 864 |
+
midc,
|
| 865 |
+
kernel_size=1,
|
| 866 |
+
stride=1,
|
| 867 |
+
padding=0,
|
| 868 |
+
bias=True)
|
| 869 |
+
self.gn1 = nn.GroupNorm(16, midc)
|
| 870 |
+
self.conv2 = nn.Conv2d(midc,
|
| 871 |
+
midc,
|
| 872 |
+
kernel_size=3,
|
| 873 |
+
stride=1,
|
| 874 |
+
padding=1,
|
| 875 |
+
bias=True)
|
| 876 |
+
self.gn2 = nn.GroupNorm(16, midc)
|
| 877 |
+
self.conv3 = nn.Conv2d(midc,
|
| 878 |
+
inc,
|
| 879 |
+
kernel_size=1,
|
| 880 |
+
stride=1,
|
| 881 |
+
padding=0,
|
| 882 |
+
bias=True)
|
| 883 |
+
self.relu = nn.LeakyReLU(0.1)
|
| 884 |
+
|
| 885 |
+
def forward(self, x):
|
| 886 |
+
x_ = x
|
| 887 |
+
x = self.conv1(x)
|
| 888 |
+
x = self.gn1(x)
|
| 889 |
+
x = self.relu(x)
|
| 890 |
+
x = self.conv2(x)
|
| 891 |
+
x = self.gn2(x)
|
| 892 |
+
x = self.relu(x)
|
| 893 |
+
x = self.conv3(x)
|
| 894 |
+
x = x + x_
|
| 895 |
+
x = self.relu(x)
|
| 896 |
+
return x
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
class AEALblock(nn.Module):
|
| 900 |
+
|
| 901 |
+
def __init__(self,
|
| 902 |
+
d_model,
|
| 903 |
+
nhead,
|
| 904 |
+
dim_feedforward=512,
|
| 905 |
+
dropout=0.0,
|
| 906 |
+
layer_norm_eps=1e-5,
|
| 907 |
+
batch_first=True,
|
| 908 |
+
norm_first=False,
|
| 909 |
+
width=5):
|
| 910 |
+
super(AEALblock, self).__init__()
|
| 911 |
+
self.self_attn2 = nn.MultiheadAttention(d_model // 2,
|
| 912 |
+
nhead // 2,
|
| 913 |
+
dropout=dropout,
|
| 914 |
+
batch_first=batch_first)
|
| 915 |
+
self.self_attn1 = nn.MultiheadAttention(d_model // 2,
|
| 916 |
+
nhead // 2,
|
| 917 |
+
dropout=dropout,
|
| 918 |
+
batch_first=batch_first)
|
| 919 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 920 |
+
self.dropout = nn.Dropout(dropout)
|
| 921 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 922 |
+
self.norm_first = norm_first
|
| 923 |
+
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
| 924 |
+
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
| 925 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 926 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 927 |
+
self.activation = nn.ReLU()
|
| 928 |
+
self.width = width
|
| 929 |
+
self.trans = nn.Sequential(
|
| 930 |
+
nn.Conv2d(d_model + 512, d_model // 2, 1, 1, 0),
|
| 931 |
+
ResBlock(d_model // 2, d_model // 4),
|
| 932 |
+
nn.Conv2d(d_model // 2, d_model, 1, 1, 0))
|
| 933 |
+
self.gamma = nn.Parameter(torch.zeros(1))
|
| 934 |
+
|
| 935 |
+
def forward(
|
| 936 |
+
self,
|
| 937 |
+
src,
|
| 938 |
+
feats,
|
| 939 |
+
):
|
| 940 |
+
src = self.gamma * self.trans(torch.cat([src, feats], 1)) + src
|
| 941 |
+
b, c, h, w = src.shape
|
| 942 |
+
x1 = src[:, 0:c // 2]
|
| 943 |
+
x1_ = rearrange(x1, 'b c (h1 h2) w -> b c h1 h2 w', h2=self.width)
|
| 944 |
+
x1_ = rearrange(x1_, 'b c h1 h2 w -> (b h1) (h2 w) c')
|
| 945 |
+
x2 = src[:, c // 2:]
|
| 946 |
+
x2_ = rearrange(x2, 'b c h (w1 w2) -> b c h w1 w2', w2=self.width)
|
| 947 |
+
x2_ = rearrange(x2_, 'b c h w1 w2 -> (b w1) (h w2) c')
|
| 948 |
+
x = rearrange(src, 'b c h w-> b (h w) c')
|
| 949 |
+
x = self.norm1(x + self._sa_block(x1_, x2_, h, w))
|
| 950 |
+
x = self.norm2(x + self._ff_block(x))
|
| 951 |
+
x = rearrange(x, 'b (h w) c->b c h w', h=h, w=w)
|
| 952 |
+
return x
|
| 953 |
+
|
| 954 |
+
def _sa_block(self, x1, x2, h, w):
|
| 955 |
+
x1 = self.self_attn1(x1,
|
| 956 |
+
x1,
|
| 957 |
+
x1,
|
| 958 |
+
attn_mask=None,
|
| 959 |
+
key_padding_mask=None,
|
| 960 |
+
need_weights=False)[0]
|
| 961 |
+
|
| 962 |
+
x2 = self.self_attn2(x2,
|
| 963 |
+
x2,
|
| 964 |
+
x2,
|
| 965 |
+
attn_mask=None,
|
| 966 |
+
key_padding_mask=None,
|
| 967 |
+
need_weights=False)[0]
|
| 968 |
+
|
| 969 |
+
x1 = rearrange(x1,
|
| 970 |
+
'(b h1) (h2 w) c-> b (h1 h2 w) c',
|
| 971 |
+
h2=self.width,
|
| 972 |
+
h1=h // self.width)
|
| 973 |
+
x2 = rearrange(x2,
|
| 974 |
+
' (b w1) (h w2) c-> b (h w1 w2) c',
|
| 975 |
+
w2=self.width,
|
| 976 |
+
w1=w // self.width)
|
| 977 |
+
x = torch.cat([x1, x2], dim=2)
|
| 978 |
+
return self.dropout1(x)
|
| 979 |
+
|
| 980 |
+
def _ff_block(self, x):
|
| 981 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
| 982 |
+
return self.dropout2(x)
|
| 983 |
+
|
| 984 |
+
|
| 985 |
+
class AEMatter(nn.Module):
|
| 986 |
+
|
| 987 |
+
def __init__(self):
|
| 988 |
+
super(AEMatter, self).__init__()
|
| 989 |
+
trans = SwinTransformer(pretrain_img_size=224,
|
| 990 |
+
embed_dim=96,
|
| 991 |
+
depths=[2, 2, 6, 2],
|
| 992 |
+
num_heads=[3, 6, 12, 24],
|
| 993 |
+
window_size=7,
|
| 994 |
+
ape=False,
|
| 995 |
+
drop_path_rate=0.2,
|
| 996 |
+
patch_norm=True,
|
| 997 |
+
use_checkpoint=False)
|
| 998 |
+
|
| 999 |
+
# trans.load_state_dict(torch.load(
|
| 1000 |
+
# '/home/asd/Desktop/swin_tiny_patch4_window7_224.pth',
|
| 1001 |
+
# map_location="cpu")["model"],
|
| 1002 |
+
# strict=False)
|
| 1003 |
+
|
| 1004 |
+
trans.patch_embed.proj = nn.Conv2d(64, 96, 3, 2, 1)
|
| 1005 |
+
|
| 1006 |
+
self.start_conv0 = nn.Sequential(nn.Conv2d(6, 48, 3, 1, 1),
|
| 1007 |
+
nn.PReLU(48))
|
| 1008 |
+
|
| 1009 |
+
self.start_conv = nn.Sequential(nn.Conv2d(48, 64, 3, 2,
|
| 1010 |
+
1), nn.PReLU(64),
|
| 1011 |
+
nn.Conv2d(64, 64, 3, 1, 1),
|
| 1012 |
+
nn.PReLU(64))
|
| 1013 |
+
|
| 1014 |
+
self.trans = trans
|
| 1015 |
+
self.conv1 = nn.Sequential(
|
| 1016 |
+
nn.Conv2d(in_channels=640 + 768,
|
| 1017 |
+
out_channels=256,
|
| 1018 |
+
kernel_size=1,
|
| 1019 |
+
stride=1,
|
| 1020 |
+
padding=0,
|
| 1021 |
+
bias=True))
|
| 1022 |
+
self.conv2 = nn.Sequential(
|
| 1023 |
+
nn.Conv2d(in_channels=256 + 384,
|
| 1024 |
+
out_channels=256,
|
| 1025 |
+
kernel_size=1,
|
| 1026 |
+
stride=1,
|
| 1027 |
+
padding=0,
|
| 1028 |
+
bias=True), )
|
| 1029 |
+
self.conv3 = nn.Sequential(
|
| 1030 |
+
nn.Conv2d(in_channels=256 + 192,
|
| 1031 |
+
out_channels=192,
|
| 1032 |
+
kernel_size=1,
|
| 1033 |
+
stride=1,
|
| 1034 |
+
padding=0,
|
| 1035 |
+
bias=True), )
|
| 1036 |
+
self.conv4 = nn.Sequential(
|
| 1037 |
+
nn.Conv2d(in_channels=192 + 96,
|
| 1038 |
+
out_channels=128,
|
| 1039 |
+
kernel_size=1,
|
| 1040 |
+
stride=1,
|
| 1041 |
+
padding=0,
|
| 1042 |
+
bias=True), )
|
| 1043 |
+
self.ctran0 = BasicLayer(256, 3, 8, 7, drop_path=0.09)
|
| 1044 |
+
self.ctran1 = BasicLayer(256, 3, 8, 7, drop_path=0.07)
|
| 1045 |
+
self.ctran2 = BasicLayer(192, 3, 6, 7, drop_path=0.05)
|
| 1046 |
+
self.ctran3 = BasicLayer(128, 3, 4, 7, drop_path=0.03)
|
| 1047 |
+
self.conv5 = nn.Sequential(
|
| 1048 |
+
nn.Conv2d(in_channels=192,
|
| 1049 |
+
out_channels=64,
|
| 1050 |
+
kernel_size=3,
|
| 1051 |
+
stride=1,
|
| 1052 |
+
padding=1,
|
| 1053 |
+
bias=True), nn.PReLU(64),
|
| 1054 |
+
nn.Conv2d(in_channels=64,
|
| 1055 |
+
out_channels=64,
|
| 1056 |
+
kernel_size=3,
|
| 1057 |
+
stride=1,
|
| 1058 |
+
padding=1,
|
| 1059 |
+
bias=True), nn.PReLU(64),
|
| 1060 |
+
nn.Conv2d(in_channels=64,
|
| 1061 |
+
out_channels=48,
|
| 1062 |
+
kernel_size=3,
|
| 1063 |
+
stride=1,
|
| 1064 |
+
padding=1,
|
| 1065 |
+
bias=True), nn.PReLU(48))
|
| 1066 |
+
self.convo = nn.Sequential(
|
| 1067 |
+
nn.Conv2d(in_channels=48 + 48 + 6,
|
| 1068 |
+
out_channels=32,
|
| 1069 |
+
kernel_size=3,
|
| 1070 |
+
stride=1,
|
| 1071 |
+
padding=1,
|
| 1072 |
+
bias=True), nn.PReLU(32),
|
| 1073 |
+
nn.Conv2d(in_channels=32,
|
| 1074 |
+
out_channels=32,
|
| 1075 |
+
kernel_size=3,
|
| 1076 |
+
stride=1,
|
| 1077 |
+
padding=1,
|
| 1078 |
+
bias=True), nn.PReLU(32),
|
| 1079 |
+
nn.Conv2d(in_channels=32,
|
| 1080 |
+
out_channels=1,
|
| 1081 |
+
kernel_size=3,
|
| 1082 |
+
stride=1,
|
| 1083 |
+
padding=1,
|
| 1084 |
+
bias=True))
|
| 1085 |
+
self.up = nn.Upsample(scale_factor=2,
|
| 1086 |
+
mode='bilinear',
|
| 1087 |
+
align_corners=False)
|
| 1088 |
+
self.upn = nn.Upsample(scale_factor=2, mode='nearest')
|
| 1089 |
+
self.apptrans = nn.Sequential(
|
| 1090 |
+
nn.Conv2d(256 + 384, 256, 1, 1, bias=True), ResBlock(256, 128),
|
| 1091 |
+
ResBlock(256, 128), nn.Conv2d(256, 512, 2, 2, bias=True),
|
| 1092 |
+
ResBlock(512, 128))
|
| 1093 |
+
self.emb = nn.Sequential(nn.Conv2d(768, 640, 1, 1, 0),
|
| 1094 |
+
ResBlock(640, 160))
|
| 1095 |
+
self.embdp = nn.Sequential(nn.Conv2d(640, 640, 1, 1, 0))
|
| 1096 |
+
self.h2l = nn.Conv2d(768, 256, 1, 1, 0)
|
| 1097 |
+
self.width = 5
|
| 1098 |
+
self.trans1 = AEALblock(d_model=640,
|
| 1099 |
+
nhead=20,
|
| 1100 |
+
dim_feedforward=2048,
|
| 1101 |
+
dropout=0.2,
|
| 1102 |
+
width=self.width)
|
| 1103 |
+
self.trans2 = AEALblock(d_model=640,
|
| 1104 |
+
nhead=20,
|
| 1105 |
+
dim_feedforward=2048,
|
| 1106 |
+
dropout=0.2,
|
| 1107 |
+
width=self.width)
|
| 1108 |
+
self.trans3 = AEALblock(d_model=640,
|
| 1109 |
+
nhead=20,
|
| 1110 |
+
dim_feedforward=2048,
|
| 1111 |
+
dropout=0.2,
|
| 1112 |
+
width=self.width)
|
| 1113 |
+
|
| 1114 |
+
def aeal(self, x, sem):
|
| 1115 |
+
xe = self.emb(x)
|
| 1116 |
+
x_ = xe
|
| 1117 |
+
x_ = self.embdp(x_)
|
| 1118 |
+
b, c, h1, w1 = x_.shape
|
| 1119 |
+
bnew_ph = int(np.ceil(h1 / self.width) * self.width) - h1
|
| 1120 |
+
bnew_pw = int(np.ceil(w1 / self.width) * self.width) - w1
|
| 1121 |
+
newph1 = bnew_ph // 2
|
| 1122 |
+
newph2 = bnew_ph - newph1
|
| 1123 |
+
newpw1 = bnew_pw // 2
|
| 1124 |
+
newpw2 = bnew_pw - newpw1
|
| 1125 |
+
x_ = F.pad(x_, (newpw1, newpw2, newph1, newph2))
|
| 1126 |
+
sem = F.pad(sem, (newpw1, newpw2, newph1, newph2))
|
| 1127 |
+
x_ = self.trans1(x_, sem)
|
| 1128 |
+
x_ = self.trans2(x_, sem)
|
| 1129 |
+
x_ = self.trans3(x_, sem)
|
| 1130 |
+
x_ = x_[:, :, newph1:h1 + newph1, newpw1:w1 + newpw1]
|
| 1131 |
+
return x_
|
| 1132 |
+
|
| 1133 |
+
def forward(self, x, y):
|
| 1134 |
+
inputs = torch.cat((x, y), 1)
|
| 1135 |
+
x = self.start_conv0(inputs)
|
| 1136 |
+
x_ = self.start_conv(x)
|
| 1137 |
+
x1, x2, x3, x4 = self.trans(x_)
|
| 1138 |
+
x4h = self.h2l(x4)
|
| 1139 |
+
x3s = self.apptrans(torch.cat([x3, self.upn(x4h)], 1))
|
| 1140 |
+
x4_ = self.aeal(x4, x3s)
|
| 1141 |
+
x4 = torch.cat((x4, x4_), 1)
|
| 1142 |
+
X4 = self.conv1(x4)
|
| 1143 |
+
wh, ww = X4.shape[2], X4.shape[3]
|
| 1144 |
+
X4 = rearrange(X4, 'b c h w -> b (h w) c')
|
| 1145 |
+
X4, _, _, _, _, _ = self.ctran0(X4, wh, ww)
|
| 1146 |
+
X4 = rearrange(X4, 'b (h w) c -> b c h w', h=wh, w=ww)
|
| 1147 |
+
X3 = self.up(X4)
|
| 1148 |
+
X3 = torch.cat((x3, X3), 1)
|
| 1149 |
+
X3 = self.conv2(X3)
|
| 1150 |
+
wh, ww = X3.shape[2], X3.shape[3]
|
| 1151 |
+
X3 = rearrange(X3, 'b c h w -> b (h w) c')
|
| 1152 |
+
X3, _, _, _, _, _ = self.ctran1(X3, wh, ww)
|
| 1153 |
+
X3 = rearrange(X3, 'b (h w) c -> b c h w', h=wh, w=ww)
|
| 1154 |
+
X2 = self.up(X3)
|
| 1155 |
+
X2 = torch.cat((x2, X2), 1)
|
| 1156 |
+
X2 = self.conv3(X2)
|
| 1157 |
+
wh, ww = X2.shape[2], X2.shape[3]
|
| 1158 |
+
X2 = rearrange(X2, 'b c h w -> b (h w) c')
|
| 1159 |
+
X2, _, _, _, _, _ = self.ctran2(X2, wh, ww)
|
| 1160 |
+
X2 = rearrange(X2, 'b (h w) c -> b c h w', h=wh, w=ww)
|
| 1161 |
+
X1 = self.up(X2)
|
| 1162 |
+
X1 = torch.cat((x1, X1), 1)
|
| 1163 |
+
X1 = self.conv4(X1)
|
| 1164 |
+
wh, ww = X1.shape[2], X1.shape[3]
|
| 1165 |
+
X1 = rearrange(X1, 'b c h w -> b (h w) c')
|
| 1166 |
+
X1, _, _, _, _, _ = self.ctran3(X1, wh, ww)
|
| 1167 |
+
X1 = rearrange(X1, 'b (h w) c -> b c h w', h=wh, w=ww)
|
| 1168 |
+
X0 = self.up(X1)
|
| 1169 |
+
X0 = torch.cat((x_, X0), 1)
|
| 1170 |
+
X0 = self.conv5(X0)
|
| 1171 |
+
X = self.up(X0)
|
| 1172 |
+
X = torch.cat((inputs, x, X), 1)
|
| 1173 |
+
alpha = self.convo(X)
|
| 1174 |
+
alpha = torch.clamp(alpha, min=0, max=1)
|
| 1175 |
+
return alpha
|
| 1176 |
+
|
| 1177 |
+
|
| 1178 |
+
class load_AEMatter_Model:
|
| 1179 |
+
|
| 1180 |
+
def __init__(self):
|
| 1181 |
+
pass
|
| 1182 |
+
|
| 1183 |
+
@classmethod
|
| 1184 |
+
def INPUT_TYPES(s):
|
| 1185 |
+
return {
|
| 1186 |
+
"required": {},
|
| 1187 |
+
}
|
| 1188 |
+
|
| 1189 |
+
RETURN_TYPES = ("AEMatter_Model", )
|
| 1190 |
+
FUNCTION = "test"
|
| 1191 |
+
CATEGORY = "AEMatter"
|
| 1192 |
+
|
| 1193 |
+
def test(self):
|
| 1194 |
+
return (get_AEMatter_model(get_model_path()), )
|
| 1195 |
+
|
| 1196 |
+
|
| 1197 |
+
class run_AEMatter_inference:
|
| 1198 |
+
|
| 1199 |
+
def __init__(self):
|
| 1200 |
+
pass
|
| 1201 |
+
|
| 1202 |
+
@classmethod
|
| 1203 |
+
def INPUT_TYPES(s):
|
| 1204 |
+
return {
|
| 1205 |
+
"required": {
|
| 1206 |
+
"image": ("IMAGE", ),
|
| 1207 |
+
"trimap": ("MASK", ),
|
| 1208 |
+
"AEMatter_Model": ("AEMatter_Model", ),
|
| 1209 |
+
},
|
| 1210 |
+
}
|
| 1211 |
+
|
| 1212 |
+
RETURN_TYPES = ("MASK", )
|
| 1213 |
+
FUNCTION = "test"
|
| 1214 |
+
CATEGORY = "AEMatter"
|
| 1215 |
+
|
| 1216 |
+
def test(
|
| 1217 |
+
self,
|
| 1218 |
+
image,
|
| 1219 |
+
trimap,
|
| 1220 |
+
AEMatter_Model,
|
| 1221 |
+
):
|
| 1222 |
+
|
| 1223 |
+
ret = []
|
| 1224 |
+
batch_size = image.shape[0]
|
| 1225 |
+
|
| 1226 |
+
for i in range(batch_size):
|
| 1227 |
+
tmp_i = from_torch_image(image[i])
|
| 1228 |
+
tmp_m = from_torch_image(trimap[i])
|
| 1229 |
+
tmp = do_infer(tmp_i, tmp_m, AEMatter_Model)
|
| 1230 |
+
ret.append(tmp)
|
| 1231 |
+
|
| 1232 |
+
ret = to_torch_image(np.array(ret))
|
| 1233 |
+
ret = ret.squeeze(-1)
|
| 1234 |
+
print(ret.shape)
|
| 1235 |
+
|
| 1236 |
+
return ret
|
| 1237 |
+
|
| 1238 |
+
|
| 1239 |
+
#!/usr/bin/python3
|
| 1240 |
+
NODE_CLASS_MAPPINGS = {
|
| 1241 |
+
'load_AEMatter_Model': load_AEMatter_Model,
|
| 1242 |
+
'run_AEMatter_inference': run_AEMatter_inference,
|
| 1243 |
+
}
|
| 1244 |
+
|
| 1245 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
| 1246 |
+
'load_AEMatter_Model': 'load_AEMatter_Model',
|
| 1247 |
+
'run_AEMatter_inference': 'run_AEMatter_inference',
|
| 1248 |
+
}
|
ComfyUI_MVANet/MVANet_inference.py
ADDED
|
@@ -0,0 +1,1548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
HOME_DIR = os.environ.get('HOME', '/root')
|
| 6 |
+
MVANET_SOURCE_DIR = HOME_DIR + '/GITHUB/qianyu-dlut/MVANet'
|
| 7 |
+
finetuned_MVANet_model_path = MVANET_SOURCE_DIR + '/model/Model_80.pth'
|
| 8 |
+
pretrained_SwinB_model_path = MVANET_SOURCE_DIR + '/model/swin_base_patch4_window12_384_22kto1k.pth'
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import numpy as np
|
| 12 |
+
import cv2
|
| 13 |
+
import wget
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import torch.utils.checkpoint as checkpoint
|
| 19 |
+
from torch.autograd import Variable
|
| 20 |
+
from torch import nn
|
| 21 |
+
from torchvision import transforms
|
| 22 |
+
|
| 23 |
+
from einops import rearrange
|
| 24 |
+
|
| 25 |
+
from timm.models import load_checkpoint
|
| 26 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 27 |
+
|
| 28 |
+
torch_device = 'cuda'
|
| 29 |
+
torch_dtype = torch.float16
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def check_mkdir(dir_name):
|
| 33 |
+
if not os.path.isdir(dir_name):
|
| 34 |
+
os.makedirs(dir_name)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def SwinT(pretrained=True):
|
| 38 |
+
model = SwinTransformer(embed_dim=96,
|
| 39 |
+
depths=[2, 2, 6, 2],
|
| 40 |
+
num_heads=[3, 6, 12, 24],
|
| 41 |
+
window_size=7)
|
| 42 |
+
if pretrained is True:
|
| 43 |
+
model.load_state_dict(torch.load(
|
| 44 |
+
'data/backbone_ckpt/swin_tiny_patch4_window7_224.pth',
|
| 45 |
+
map_location='cpu')['model'],
|
| 46 |
+
strict=False)
|
| 47 |
+
|
| 48 |
+
return model
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def SwinS(pretrained=True):
|
| 52 |
+
model = SwinTransformer(embed_dim=96,
|
| 53 |
+
depths=[2, 2, 18, 2],
|
| 54 |
+
num_heads=[3, 6, 12, 24],
|
| 55 |
+
window_size=7)
|
| 56 |
+
if pretrained is True:
|
| 57 |
+
model.load_state_dict(torch.load(
|
| 58 |
+
'data/backbone_ckpt/swin_small_patch4_window7_224.pth',
|
| 59 |
+
map_location='cpu')['model'],
|
| 60 |
+
strict=False)
|
| 61 |
+
|
| 62 |
+
return model
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def SwinB(pretrained=True):
|
| 66 |
+
model = SwinTransformer(embed_dim=128,
|
| 67 |
+
depths=[2, 2, 18, 2],
|
| 68 |
+
num_heads=[4, 8, 16, 32],
|
| 69 |
+
window_size=12)
|
| 70 |
+
if pretrained is True:
|
| 71 |
+
import os
|
| 72 |
+
model.load_state_dict(torch.load(pretrained_SwinB_model_path,
|
| 73 |
+
map_location='cpu')['model'],
|
| 74 |
+
strict=False)
|
| 75 |
+
return model
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def SwinL(pretrained=True):
|
| 79 |
+
model = SwinTransformer(embed_dim=192,
|
| 80 |
+
depths=[2, 2, 18, 2],
|
| 81 |
+
num_heads=[6, 12, 24, 48],
|
| 82 |
+
window_size=12)
|
| 83 |
+
if pretrained is True:
|
| 84 |
+
model.load_state_dict(torch.load(
|
| 85 |
+
'data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth',
|
| 86 |
+
map_location='cpu')['model'],
|
| 87 |
+
strict=False)
|
| 88 |
+
|
| 89 |
+
return model
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def get_activation_fn(activation):
|
| 93 |
+
"""Return an activation function given a string"""
|
| 94 |
+
if activation == "relu":
|
| 95 |
+
return F.relu
|
| 96 |
+
if activation == "gelu":
|
| 97 |
+
return F.gelu
|
| 98 |
+
if activation == "glu":
|
| 99 |
+
return F.glu
|
| 100 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def make_cbr(in_dim, out_dim):
|
| 104 |
+
return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
|
| 105 |
+
nn.BatchNorm2d(out_dim), nn.PReLU())
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def make_cbg(in_dim, out_dim):
|
| 109 |
+
return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
|
| 110 |
+
nn.BatchNorm2d(out_dim), nn.GELU())
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
|
| 114 |
+
return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def resize_as(x, y, interpolation='bilinear'):
|
| 118 |
+
return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def image2patches(x):
|
| 122 |
+
"""b c (hg h) (wg w) -> (hg wg b) c h w"""
|
| 123 |
+
x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
| 124 |
+
return x
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def patches2image(x):
|
| 128 |
+
"""(hg wg b) c h w -> b c (hg h) (wg w)"""
|
| 129 |
+
x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def window_partition(x, window_size):
|
| 134 |
+
"""
|
| 135 |
+
Args:
|
| 136 |
+
x: (B, H, W, C)
|
| 137 |
+
window_size (int): window size
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 141 |
+
"""
|
| 142 |
+
B, H, W, C = x.shape
|
| 143 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size,
|
| 144 |
+
C)
|
| 145 |
+
windows = x.permute(0, 1, 3, 2, 4,
|
| 146 |
+
5).contiguous().view(-1, window_size, window_size, C)
|
| 147 |
+
return windows
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def window_reverse(windows, window_size, H, W):
|
| 151 |
+
"""
|
| 152 |
+
Args:
|
| 153 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 154 |
+
window_size (int): Window size
|
| 155 |
+
H (int): Height of image
|
| 156 |
+
W (int): Width of image
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
x: (B, H, W, C)
|
| 160 |
+
"""
|
| 161 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
| 162 |
+
x = windows.view(B, H // window_size, W // window_size, window_size,
|
| 163 |
+
window_size, -1)
|
| 164 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
| 165 |
+
return x
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def mkdir_safe(out_path):
|
| 169 |
+
if type(out_path) == str:
|
| 170 |
+
if len(out_path) > 0:
|
| 171 |
+
if not os.path.exists(out_path):
|
| 172 |
+
os.mkdir(out_path)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def get_model_path():
|
| 176 |
+
import folder_paths
|
| 177 |
+
from folder_paths import models_dir
|
| 178 |
+
|
| 179 |
+
path_file_model = models_dir
|
| 180 |
+
mkdir_safe(out_path=path_file_model)
|
| 181 |
+
|
| 182 |
+
path_file_model = os.path.join(path_file_model, 'MVANet')
|
| 183 |
+
mkdir_safe(out_path=path_file_model)
|
| 184 |
+
|
| 185 |
+
path_file_model = os.path.join(path_file_model, 'Model_80.pth')
|
| 186 |
+
|
| 187 |
+
return path_file_model
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def download_model(path):
|
| 191 |
+
if not os.path.exists(path):
|
| 192 |
+
wget.download(
|
| 193 |
+
'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/Model_80.pth',
|
| 194 |
+
out=path)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def load_model(model_checkpoint_path):
|
| 198 |
+
download_model(path=model_checkpoint_path)
|
| 199 |
+
torch.cuda.set_device(0)
|
| 200 |
+
|
| 201 |
+
net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
|
| 202 |
+
|
| 203 |
+
pretrained_dict = torch.load(finetuned_MVANet_model_path,
|
| 204 |
+
map_location=torch_device)
|
| 205 |
+
|
| 206 |
+
model_dict = net.state_dict()
|
| 207 |
+
pretrained_dict = {
|
| 208 |
+
k: v
|
| 209 |
+
for k, v in pretrained_dict.items() if k in model_dict
|
| 210 |
+
}
|
| 211 |
+
model_dict.update(pretrained_dict)
|
| 212 |
+
net.load_state_dict(model_dict)
|
| 213 |
+
net = net.to(dtype=torch_dtype, device=torch_device)
|
| 214 |
+
net.eval()
|
| 215 |
+
return net
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def do_infer_tensor2tensor(img, net):
|
| 219 |
+
|
| 220 |
+
img_transform = transforms.Compose(
|
| 221 |
+
[transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
| 222 |
+
|
| 223 |
+
h_, w_ = img.shape[1], img.shape[2]
|
| 224 |
+
|
| 225 |
+
with torch.no_grad():
|
| 226 |
+
|
| 227 |
+
img = rearrange(img, 'B H W C -> B C H W')
|
| 228 |
+
|
| 229 |
+
img_resize = torch.nn.functional.interpolate(input=img,
|
| 230 |
+
size=(1024, 1024),
|
| 231 |
+
mode='bicubic',
|
| 232 |
+
antialias=True)
|
| 233 |
+
|
| 234 |
+
img_var = img_transform(img_resize)
|
| 235 |
+
img_var = Variable(img_var)
|
| 236 |
+
img_var = img_var.to(dtype=torch_dtype, device=torch_device)
|
| 237 |
+
|
| 238 |
+
mask = []
|
| 239 |
+
|
| 240 |
+
mask.append(net(img_var))
|
| 241 |
+
|
| 242 |
+
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
|
| 243 |
+
prediction = prediction.sigmoid()
|
| 244 |
+
|
| 245 |
+
prediction = torch.nn.functional.interpolate(input=prediction,
|
| 246 |
+
size=(h_, w_),
|
| 247 |
+
mode='bicubic',
|
| 248 |
+
antialias=True)
|
| 249 |
+
|
| 250 |
+
prediction = prediction.squeeze(0)
|
| 251 |
+
prediction = prediction.clamp(0, 1)
|
| 252 |
+
prediction = prediction.detach()
|
| 253 |
+
prediction = prediction.to(dtype=torch.float32, device='cpu')
|
| 254 |
+
|
| 255 |
+
return prediction
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class Mlp(nn.Module):
|
| 259 |
+
""" Multilayer perceptron."""
|
| 260 |
+
|
| 261 |
+
def __init__(self,
|
| 262 |
+
in_features,
|
| 263 |
+
hidden_features=None,
|
| 264 |
+
out_features=None,
|
| 265 |
+
act_layer=nn.GELU,
|
| 266 |
+
drop=0.):
|
| 267 |
+
super().__init__()
|
| 268 |
+
out_features = out_features or in_features
|
| 269 |
+
hidden_features = hidden_features or in_features
|
| 270 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 271 |
+
self.act = act_layer()
|
| 272 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 273 |
+
self.drop = nn.Dropout(drop)
|
| 274 |
+
|
| 275 |
+
def forward(self, x):
|
| 276 |
+
x = self.fc1(x)
|
| 277 |
+
x = self.act(x)
|
| 278 |
+
x = self.drop(x)
|
| 279 |
+
x = self.fc2(x)
|
| 280 |
+
x = self.drop(x)
|
| 281 |
+
return x
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class WindowAttention(nn.Module):
|
| 285 |
+
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
| 286 |
+
It supports both of shifted and non-shifted window.
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
dim (int): Number of input channels.
|
| 290 |
+
window_size (tuple[int]): The height and width of the window.
|
| 291 |
+
num_heads (int): Number of attention heads.
|
| 292 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 293 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
| 294 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
| 295 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
| 296 |
+
"""
|
| 297 |
+
|
| 298 |
+
def __init__(self,
|
| 299 |
+
dim,
|
| 300 |
+
window_size,
|
| 301 |
+
num_heads,
|
| 302 |
+
qkv_bias=True,
|
| 303 |
+
qk_scale=None,
|
| 304 |
+
attn_drop=0.,
|
| 305 |
+
proj_drop=0.):
|
| 306 |
+
|
| 307 |
+
super().__init__()
|
| 308 |
+
self.dim = dim
|
| 309 |
+
self.window_size = window_size # Wh, Ww
|
| 310 |
+
self.num_heads = num_heads
|
| 311 |
+
head_dim = dim // num_heads
|
| 312 |
+
self.scale = qk_scale or head_dim**-0.5
|
| 313 |
+
|
| 314 |
+
# define a parameter table of relative position bias
|
| 315 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 316 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
|
| 317 |
+
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 318 |
+
|
| 319 |
+
# get pair-wise relative position index for each token inside the window
|
| 320 |
+
coords_h = torch.arange(self.window_size[0])
|
| 321 |
+
coords_w = torch.arange(self.window_size[1])
|
| 322 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 323 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 324 |
+
relative_coords = coords_flatten[:, :,
|
| 325 |
+
None] - coords_flatten[:,
|
| 326 |
+
None, :] # 2, Wh*Ww, Wh*Ww
|
| 327 |
+
relative_coords = relative_coords.permute(
|
| 328 |
+
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 329 |
+
relative_coords[:, :,
|
| 330 |
+
0] += self.window_size[0] - 1 # shift to start from 0
|
| 331 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 332 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
| 333 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 334 |
+
self.register_buffer("relative_position_index",
|
| 335 |
+
relative_position_index)
|
| 336 |
+
|
| 337 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 338 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 339 |
+
self.proj = nn.Linear(dim, dim)
|
| 340 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 341 |
+
|
| 342 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
| 343 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 344 |
+
|
| 345 |
+
def forward(self, x, mask=None):
|
| 346 |
+
""" Forward function.
|
| 347 |
+
|
| 348 |
+
Args:
|
| 349 |
+
x: input features with shape of (num_windows*B, N, C)
|
| 350 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
| 351 |
+
"""
|
| 352 |
+
x = x.to(dtype=torch_dtype, device=torch_device)
|
| 353 |
+
B_, N, C = x.shape
|
| 354 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
|
| 355 |
+
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 356 |
+
q, k, v = qkv[0], qkv[1], qkv[
|
| 357 |
+
2] # make torchscript happy (cannot use tensor as tuple)
|
| 358 |
+
|
| 359 |
+
q = q * self.scale
|
| 360 |
+
attn = (q @ k.transpose(-2, -1))
|
| 361 |
+
|
| 362 |
+
relative_position_bias = self.relative_position_bias_table[
|
| 363 |
+
self.relative_position_index.view(-1)].view(
|
| 364 |
+
self.window_size[0] * self.window_size[1],
|
| 365 |
+
self.window_size[0] * self.window_size[1],
|
| 366 |
+
-1) # Wh*Ww,Wh*Ww,nH
|
| 367 |
+
relative_position_bias = relative_position_bias.permute(
|
| 368 |
+
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 369 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
| 370 |
+
|
| 371 |
+
if mask is not None:
|
| 372 |
+
nW = mask.shape[0]
|
| 373 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N,
|
| 374 |
+
N) + mask.unsqueeze(1).unsqueeze(0)
|
| 375 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
| 376 |
+
attn = self.softmax(attn)
|
| 377 |
+
else:
|
| 378 |
+
attn = self.softmax(attn)
|
| 379 |
+
|
| 380 |
+
attn = self.attn_drop(attn)
|
| 381 |
+
attn = attn.to(dtype=torch_dtype, device=torch_device)
|
| 382 |
+
v = v.to(dtype=torch_dtype, device=torch_device)
|
| 383 |
+
|
| 384 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
| 385 |
+
x = self.proj(x)
|
| 386 |
+
x = self.proj_drop(x)
|
| 387 |
+
return x
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
class SwinTransformerBlock(nn.Module):
|
| 391 |
+
""" Swin Transformer Block.
|
| 392 |
+
|
| 393 |
+
Args:
|
| 394 |
+
dim (int): Number of input channels.
|
| 395 |
+
num_heads (int): Number of attention heads.
|
| 396 |
+
window_size (int): Window size.
|
| 397 |
+
shift_size (int): Shift size for SW-MSA.
|
| 398 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 399 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 400 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 401 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 402 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 403 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
| 404 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
| 405 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 406 |
+
"""
|
| 407 |
+
|
| 408 |
+
def __init__(self,
|
| 409 |
+
dim,
|
| 410 |
+
num_heads,
|
| 411 |
+
window_size=7,
|
| 412 |
+
shift_size=0,
|
| 413 |
+
mlp_ratio=4.,
|
| 414 |
+
qkv_bias=True,
|
| 415 |
+
qk_scale=None,
|
| 416 |
+
drop=0.,
|
| 417 |
+
attn_drop=0.,
|
| 418 |
+
drop_path=0.,
|
| 419 |
+
act_layer=nn.GELU,
|
| 420 |
+
norm_layer=nn.LayerNorm):
|
| 421 |
+
super().__init__()
|
| 422 |
+
self.dim = dim
|
| 423 |
+
self.num_heads = num_heads
|
| 424 |
+
self.window_size = window_size
|
| 425 |
+
self.shift_size = shift_size
|
| 426 |
+
self.mlp_ratio = mlp_ratio
|
| 427 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
| 428 |
+
|
| 429 |
+
self.norm1 = norm_layer(dim)
|
| 430 |
+
self.attn = WindowAttention(dim,
|
| 431 |
+
window_size=to_2tuple(self.window_size),
|
| 432 |
+
num_heads=num_heads,
|
| 433 |
+
qkv_bias=qkv_bias,
|
| 434 |
+
qk_scale=qk_scale,
|
| 435 |
+
attn_drop=attn_drop,
|
| 436 |
+
proj_drop=drop)
|
| 437 |
+
|
| 438 |
+
self.drop_path = DropPath(
|
| 439 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
| 440 |
+
self.norm2 = norm_layer(dim)
|
| 441 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 442 |
+
self.mlp = Mlp(in_features=dim,
|
| 443 |
+
hidden_features=mlp_hidden_dim,
|
| 444 |
+
act_layer=act_layer,
|
| 445 |
+
drop=drop)
|
| 446 |
+
|
| 447 |
+
self.H = None
|
| 448 |
+
self.W = None
|
| 449 |
+
|
| 450 |
+
def forward(self, x, mask_matrix):
|
| 451 |
+
""" Forward function.
|
| 452 |
+
|
| 453 |
+
Args:
|
| 454 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 455 |
+
H, W: Spatial resolution of the input feature.
|
| 456 |
+
mask_matrix: Attention mask for cyclic shift.
|
| 457 |
+
"""
|
| 458 |
+
B, L, C = x.shape
|
| 459 |
+
H, W = self.H, self.W
|
| 460 |
+
assert L == H * W, "input feature has wrong size"
|
| 461 |
+
|
| 462 |
+
shortcut = x
|
| 463 |
+
x = self.norm1(x)
|
| 464 |
+
x = x.view(B, H, W, C)
|
| 465 |
+
|
| 466 |
+
# pad feature maps to multiples of window size
|
| 467 |
+
pad_l = pad_t = 0
|
| 468 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
| 469 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
| 470 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
| 471 |
+
_, Hp, Wp, _ = x.shape
|
| 472 |
+
|
| 473 |
+
# cyclic shift
|
| 474 |
+
if self.shift_size > 0:
|
| 475 |
+
shifted_x = torch.roll(x,
|
| 476 |
+
shifts=(-self.shift_size, -self.shift_size),
|
| 477 |
+
dims=(1, 2))
|
| 478 |
+
attn_mask = mask_matrix
|
| 479 |
+
else:
|
| 480 |
+
shifted_x = x
|
| 481 |
+
attn_mask = None
|
| 482 |
+
|
| 483 |
+
# partition windows
|
| 484 |
+
x_windows = window_partition(
|
| 485 |
+
shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
| 486 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size,
|
| 487 |
+
C) # nW*B, window_size*window_size, C
|
| 488 |
+
|
| 489 |
+
# W-MSA/SW-MSA
|
| 490 |
+
attn_windows = self.attn(
|
| 491 |
+
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
| 492 |
+
|
| 493 |
+
# merge windows
|
| 494 |
+
attn_windows = attn_windows.view(-1, self.window_size,
|
| 495 |
+
self.window_size, C)
|
| 496 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp,
|
| 497 |
+
Wp) # B H' W' C
|
| 498 |
+
|
| 499 |
+
# reverse cyclic shift
|
| 500 |
+
if self.shift_size > 0:
|
| 501 |
+
x = torch.roll(shifted_x,
|
| 502 |
+
shifts=(self.shift_size, self.shift_size),
|
| 503 |
+
dims=(1, 2))
|
| 504 |
+
else:
|
| 505 |
+
x = shifted_x
|
| 506 |
+
|
| 507 |
+
if pad_r > 0 or pad_b > 0:
|
| 508 |
+
x = x[:, :H, :W, :].contiguous()
|
| 509 |
+
|
| 510 |
+
x = x.view(B, H * W, C)
|
| 511 |
+
|
| 512 |
+
# FFN
|
| 513 |
+
x = shortcut + self.drop_path(x)
|
| 514 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 515 |
+
|
| 516 |
+
return x
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class PatchMerging(nn.Module):
|
| 520 |
+
""" Patch Merging Layer
|
| 521 |
+
|
| 522 |
+
Args:
|
| 523 |
+
dim (int): Number of input channels.
|
| 524 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 525 |
+
"""
|
| 526 |
+
|
| 527 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
| 528 |
+
super().__init__()
|
| 529 |
+
self.dim = dim
|
| 530 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
| 531 |
+
self.norm = norm_layer(4 * dim)
|
| 532 |
+
|
| 533 |
+
def forward(self, x, H, W):
|
| 534 |
+
""" Forward function.
|
| 535 |
+
|
| 536 |
+
Args:
|
| 537 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 538 |
+
H, W: Spatial resolution of the input feature.
|
| 539 |
+
"""
|
| 540 |
+
B, L, C = x.shape
|
| 541 |
+
assert L == H * W, "input feature has wrong size"
|
| 542 |
+
|
| 543 |
+
x = x.view(B, H, W, C)
|
| 544 |
+
|
| 545 |
+
# padding
|
| 546 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
| 547 |
+
if pad_input:
|
| 548 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
| 549 |
+
|
| 550 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
| 551 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
| 552 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
| 553 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
| 554 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
| 555 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
| 556 |
+
|
| 557 |
+
x = self.norm(x)
|
| 558 |
+
x = self.reduction(x)
|
| 559 |
+
|
| 560 |
+
return x
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
class BasicLayer(nn.Module):
|
| 564 |
+
""" A basic Swin Transformer layer for one stage.
|
| 565 |
+
|
| 566 |
+
Args:
|
| 567 |
+
dim (int): Number of feature channels
|
| 568 |
+
depth (int): Depths of this stage.
|
| 569 |
+
num_heads (int): Number of attention head.
|
| 570 |
+
window_size (int): Local window size. Default: 7.
|
| 571 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 572 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 573 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 574 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 575 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 576 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
| 577 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 578 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
| 579 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 580 |
+
"""
|
| 581 |
+
|
| 582 |
+
def __init__(self,
|
| 583 |
+
dim,
|
| 584 |
+
depth,
|
| 585 |
+
num_heads,
|
| 586 |
+
window_size=7,
|
| 587 |
+
mlp_ratio=4.,
|
| 588 |
+
qkv_bias=True,
|
| 589 |
+
qk_scale=None,
|
| 590 |
+
drop=0.,
|
| 591 |
+
attn_drop=0.,
|
| 592 |
+
drop_path=0.,
|
| 593 |
+
norm_layer=nn.LayerNorm,
|
| 594 |
+
downsample=None,
|
| 595 |
+
use_checkpoint=False):
|
| 596 |
+
super().__init__()
|
| 597 |
+
self.window_size = window_size
|
| 598 |
+
self.shift_size = window_size // 2
|
| 599 |
+
self.depth = depth
|
| 600 |
+
self.use_checkpoint = use_checkpoint
|
| 601 |
+
|
| 602 |
+
# build blocks
|
| 603 |
+
self.blocks = nn.ModuleList([
|
| 604 |
+
SwinTransformerBlock(dim=dim,
|
| 605 |
+
num_heads=num_heads,
|
| 606 |
+
window_size=window_size,
|
| 607 |
+
shift_size=0 if
|
| 608 |
+
(i % 2 == 0) else window_size // 2,
|
| 609 |
+
mlp_ratio=mlp_ratio,
|
| 610 |
+
qkv_bias=qkv_bias,
|
| 611 |
+
qk_scale=qk_scale,
|
| 612 |
+
drop=drop,
|
| 613 |
+
attn_drop=attn_drop,
|
| 614 |
+
drop_path=drop_path[i] if isinstance(
|
| 615 |
+
drop_path, list) else drop_path,
|
| 616 |
+
norm_layer=norm_layer) for i in range(depth)
|
| 617 |
+
])
|
| 618 |
+
|
| 619 |
+
# patch merging layer
|
| 620 |
+
if downsample is not None:
|
| 621 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
| 622 |
+
else:
|
| 623 |
+
self.downsample = None
|
| 624 |
+
|
| 625 |
+
def forward(self, x, H, W):
|
| 626 |
+
""" Forward function.
|
| 627 |
+
|
| 628 |
+
Args:
|
| 629 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 630 |
+
H, W: Spatial resolution of the input feature.
|
| 631 |
+
"""
|
| 632 |
+
|
| 633 |
+
# calculate attention mask for SW-MSA
|
| 634 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
| 635 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
| 636 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
| 637 |
+
h_slices = (slice(0, -self.window_size),
|
| 638 |
+
slice(-self.window_size,
|
| 639 |
+
-self.shift_size), slice(-self.shift_size, None))
|
| 640 |
+
w_slices = (slice(0, -self.window_size),
|
| 641 |
+
slice(-self.window_size,
|
| 642 |
+
-self.shift_size), slice(-self.shift_size, None))
|
| 643 |
+
cnt = 0
|
| 644 |
+
for h in h_slices:
|
| 645 |
+
for w in w_slices:
|
| 646 |
+
img_mask[:, h, w, :] = cnt
|
| 647 |
+
cnt += 1
|
| 648 |
+
|
| 649 |
+
mask_windows = window_partition(
|
| 650 |
+
img_mask, self.window_size) # nW, window_size, window_size, 1
|
| 651 |
+
mask_windows = mask_windows.view(-1,
|
| 652 |
+
self.window_size * self.window_size)
|
| 653 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 654 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
| 655 |
+
float(-100.0)).masked_fill(
|
| 656 |
+
attn_mask == 0, float(0.0))
|
| 657 |
+
|
| 658 |
+
for blk in self.blocks:
|
| 659 |
+
blk.H, blk.W = H, W
|
| 660 |
+
if self.use_checkpoint:
|
| 661 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
| 662 |
+
else:
|
| 663 |
+
x = blk(x, attn_mask)
|
| 664 |
+
if self.downsample is not None:
|
| 665 |
+
x_down = self.downsample(x, H, W)
|
| 666 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
| 667 |
+
return x, H, W, x_down, Wh, Ww
|
| 668 |
+
else:
|
| 669 |
+
return x, H, W, x, H, W
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
class PatchEmbed(nn.Module):
|
| 673 |
+
""" Image to Patch Embedding
|
| 674 |
+
|
| 675 |
+
Args:
|
| 676 |
+
patch_size (int): Patch token size. Default: 4.
|
| 677 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 678 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 679 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
| 680 |
+
"""
|
| 681 |
+
|
| 682 |
+
def __init__(self,
|
| 683 |
+
patch_size=4,
|
| 684 |
+
in_chans=3,
|
| 685 |
+
embed_dim=96,
|
| 686 |
+
norm_layer=None):
|
| 687 |
+
super().__init__()
|
| 688 |
+
patch_size = to_2tuple(patch_size)
|
| 689 |
+
self.patch_size = patch_size
|
| 690 |
+
|
| 691 |
+
self.in_chans = in_chans
|
| 692 |
+
self.embed_dim = embed_dim
|
| 693 |
+
|
| 694 |
+
self.proj = nn.Conv2d(in_chans,
|
| 695 |
+
embed_dim,
|
| 696 |
+
kernel_size=patch_size,
|
| 697 |
+
stride=patch_size)
|
| 698 |
+
if norm_layer is not None:
|
| 699 |
+
self.norm = norm_layer(embed_dim)
|
| 700 |
+
else:
|
| 701 |
+
self.norm = None
|
| 702 |
+
|
| 703 |
+
def forward(self, x):
|
| 704 |
+
"""Forward function."""
|
| 705 |
+
# padding
|
| 706 |
+
_, _, H, W = x.size()
|
| 707 |
+
if W % self.patch_size[1] != 0:
|
| 708 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
| 709 |
+
if H % self.patch_size[0] != 0:
|
| 710 |
+
x = F.pad(x,
|
| 711 |
+
(0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
| 712 |
+
|
| 713 |
+
x = self.proj(x) # B C Wh Ww
|
| 714 |
+
if self.norm is not None:
|
| 715 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 716 |
+
x = x.flatten(2).transpose(1, 2)
|
| 717 |
+
x = self.norm(x)
|
| 718 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
| 719 |
+
|
| 720 |
+
return x
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
class SwinTransformer(nn.Module):
|
| 724 |
+
""" Swin Transformer backbone.
|
| 725 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
| 726 |
+
https://arxiv.org/pdf/2103.14030
|
| 727 |
+
|
| 728 |
+
Args:
|
| 729 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
| 730 |
+
used in absolute postion embedding. Default 224.
|
| 731 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
| 732 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 733 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 734 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
| 735 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
| 736 |
+
window_size (int): Window size. Default: 7.
|
| 737 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 738 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
| 739 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
| 740 |
+
drop_rate (float): Dropout rate.
|
| 741 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
| 742 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
| 743 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
| 744 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
| 745 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
| 746 |
+
out_indices (Sequence[int]): Output from which stages.
|
| 747 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
| 748 |
+
-1 means not freezing any parameters.
|
| 749 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 750 |
+
"""
|
| 751 |
+
|
| 752 |
+
def __init__(self,
|
| 753 |
+
pretrain_img_size=224,
|
| 754 |
+
patch_size=4,
|
| 755 |
+
in_chans=3,
|
| 756 |
+
embed_dim=96,
|
| 757 |
+
depths=[2, 2, 6, 2],
|
| 758 |
+
num_heads=[3, 6, 12, 24],
|
| 759 |
+
window_size=7,
|
| 760 |
+
mlp_ratio=4.,
|
| 761 |
+
qkv_bias=True,
|
| 762 |
+
qk_scale=None,
|
| 763 |
+
drop_rate=0.,
|
| 764 |
+
attn_drop_rate=0.,
|
| 765 |
+
drop_path_rate=0.2,
|
| 766 |
+
norm_layer=nn.LayerNorm,
|
| 767 |
+
ape=False,
|
| 768 |
+
patch_norm=True,
|
| 769 |
+
out_indices=(0, 1, 2, 3),
|
| 770 |
+
frozen_stages=-1,
|
| 771 |
+
use_checkpoint=False):
|
| 772 |
+
super().__init__()
|
| 773 |
+
|
| 774 |
+
self.pretrain_img_size = pretrain_img_size
|
| 775 |
+
self.num_layers = len(depths)
|
| 776 |
+
self.embed_dim = embed_dim
|
| 777 |
+
self.ape = ape
|
| 778 |
+
self.patch_norm = patch_norm
|
| 779 |
+
self.out_indices = out_indices
|
| 780 |
+
self.frozen_stages = frozen_stages
|
| 781 |
+
|
| 782 |
+
# split image into non-overlapping patches
|
| 783 |
+
self.patch_embed = PatchEmbed(
|
| 784 |
+
patch_size=patch_size,
|
| 785 |
+
in_chans=in_chans,
|
| 786 |
+
embed_dim=embed_dim,
|
| 787 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
| 788 |
+
|
| 789 |
+
# absolute position embedding
|
| 790 |
+
if self.ape:
|
| 791 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
| 792 |
+
patch_size = to_2tuple(patch_size)
|
| 793 |
+
patches_resolution = [
|
| 794 |
+
pretrain_img_size[0] // patch_size[0],
|
| 795 |
+
pretrain_img_size[1] // patch_size[1]
|
| 796 |
+
]
|
| 797 |
+
|
| 798 |
+
self.absolute_pos_embed = nn.Parameter(
|
| 799 |
+
torch.zeros(1, embed_dim, patches_resolution[0],
|
| 800 |
+
patches_resolution[1]))
|
| 801 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
| 802 |
+
|
| 803 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 804 |
+
|
| 805 |
+
# stochastic depth
|
| 806 |
+
dpr = [
|
| 807 |
+
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
| 808 |
+
] # stochastic depth decay rule
|
| 809 |
+
|
| 810 |
+
# build layers
|
| 811 |
+
self.layers = nn.ModuleList()
|
| 812 |
+
for i_layer in range(self.num_layers):
|
| 813 |
+
layer = BasicLayer(
|
| 814 |
+
dim=int(embed_dim * 2**i_layer),
|
| 815 |
+
depth=depths[i_layer],
|
| 816 |
+
num_heads=num_heads[i_layer],
|
| 817 |
+
window_size=window_size,
|
| 818 |
+
mlp_ratio=mlp_ratio,
|
| 819 |
+
qkv_bias=qkv_bias,
|
| 820 |
+
qk_scale=qk_scale,
|
| 821 |
+
drop=drop_rate,
|
| 822 |
+
attn_drop=attn_drop_rate,
|
| 823 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
| 824 |
+
norm_layer=norm_layer,
|
| 825 |
+
downsample=PatchMerging if
|
| 826 |
+
(i_layer < self.num_layers - 1) else None,
|
| 827 |
+
use_checkpoint=use_checkpoint)
|
| 828 |
+
self.layers.append(layer)
|
| 829 |
+
|
| 830 |
+
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
|
| 831 |
+
self.num_features = num_features
|
| 832 |
+
|
| 833 |
+
# add a norm layer for each output
|
| 834 |
+
for i_layer in out_indices:
|
| 835 |
+
layer = norm_layer(num_features[i_layer])
|
| 836 |
+
layer_name = f'norm{i_layer}'
|
| 837 |
+
self.add_module(layer_name, layer)
|
| 838 |
+
|
| 839 |
+
self._freeze_stages()
|
| 840 |
+
|
| 841 |
+
def _freeze_stages(self):
|
| 842 |
+
if self.frozen_stages >= 0:
|
| 843 |
+
self.patch_embed.eval()
|
| 844 |
+
for param in self.patch_embed.parameters():
|
| 845 |
+
param.requires_grad = False
|
| 846 |
+
|
| 847 |
+
if self.frozen_stages >= 1 and self.ape:
|
| 848 |
+
self.absolute_pos_embed.requires_grad = False
|
| 849 |
+
|
| 850 |
+
if self.frozen_stages >= 2:
|
| 851 |
+
self.pos_drop.eval()
|
| 852 |
+
for i in range(0, self.frozen_stages - 1):
|
| 853 |
+
m = self.layers[i]
|
| 854 |
+
m.eval()
|
| 855 |
+
for param in m.parameters():
|
| 856 |
+
param.requires_grad = False
|
| 857 |
+
|
| 858 |
+
def init_weights(self, pretrained=None):
|
| 859 |
+
"""Initialize the weights in backbone.
|
| 860 |
+
|
| 861 |
+
Args:
|
| 862 |
+
pretrained (str, optional): Path to pre-trained weights.
|
| 863 |
+
Defaults to None.
|
| 864 |
+
"""
|
| 865 |
+
|
| 866 |
+
def _init_weights(m):
|
| 867 |
+
if isinstance(m, nn.Linear):
|
| 868 |
+
trunc_normal_(m.weight, std=.02)
|
| 869 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 870 |
+
nn.init.constant_(m.bias, 0)
|
| 871 |
+
elif isinstance(m, nn.LayerNorm):
|
| 872 |
+
nn.init.constant_(m.bias, 0)
|
| 873 |
+
nn.init.constant_(m.weight, 1.0)
|
| 874 |
+
|
| 875 |
+
if isinstance(pretrained, str):
|
| 876 |
+
self.apply(_init_weights)
|
| 877 |
+
load_checkpoint(self, pretrained, strict=False, logger=None)
|
| 878 |
+
elif pretrained is None:
|
| 879 |
+
self.apply(_init_weights)
|
| 880 |
+
else:
|
| 881 |
+
raise TypeError('pretrained must be a str or None')
|
| 882 |
+
|
| 883 |
+
def forward(self, x):
|
| 884 |
+
x = self.patch_embed(x)
|
| 885 |
+
|
| 886 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 887 |
+
if self.ape:
|
| 888 |
+
# interpolate the position embedding to the corresponding size
|
| 889 |
+
absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
|
| 890 |
+
size=(Wh, Ww),
|
| 891 |
+
mode='bicubic')
|
| 892 |
+
x = (x + absolute_pos_embed) # B Wh*Ww C
|
| 893 |
+
|
| 894 |
+
outs = [x.contiguous()]
|
| 895 |
+
x = x.flatten(2).transpose(1, 2)
|
| 896 |
+
x = self.pos_drop(x)
|
| 897 |
+
for i in range(self.num_layers):
|
| 898 |
+
layer = self.layers[i]
|
| 899 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
| 900 |
+
|
| 901 |
+
if i in self.out_indices:
|
| 902 |
+
norm_layer = getattr(self, f'norm{i}')
|
| 903 |
+
x_out = norm_layer(x_out)
|
| 904 |
+
|
| 905 |
+
out = x_out.view(-1, H, W,
|
| 906 |
+
self.num_features[i]).permute(0, 3, 1,
|
| 907 |
+
2).contiguous()
|
| 908 |
+
outs.append(out)
|
| 909 |
+
|
| 910 |
+
return tuple(outs)
|
| 911 |
+
|
| 912 |
+
def train(self, mode=True):
|
| 913 |
+
"""Convert the model into training mode while keep layers freezed."""
|
| 914 |
+
super(SwinTransformer, self).train(mode)
|
| 915 |
+
self._freeze_stages()
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
class PositionEmbeddingSine:
|
| 919 |
+
|
| 920 |
+
def __init__(self,
|
| 921 |
+
num_pos_feats=64,
|
| 922 |
+
temperature=10000,
|
| 923 |
+
normalize=False,
|
| 924 |
+
scale=None):
|
| 925 |
+
super().__init__()
|
| 926 |
+
self.num_pos_feats = num_pos_feats
|
| 927 |
+
self.temperature = temperature
|
| 928 |
+
self.normalize = normalize
|
| 929 |
+
if scale is not None and normalize is False:
|
| 930 |
+
raise ValueError("normalize should be True if scale is passed")
|
| 931 |
+
if scale is None:
|
| 932 |
+
scale = 2 * math.pi
|
| 933 |
+
self.scale = scale
|
| 934 |
+
self.dim_t = torch.arange(0,
|
| 935 |
+
self.num_pos_feats,
|
| 936 |
+
dtype=torch_dtype,
|
| 937 |
+
device=torch_device)
|
| 938 |
+
|
| 939 |
+
def __call__(self, b, h, w):
|
| 940 |
+
mask = torch.zeros([b, h, w], dtype=torch.bool, device=torch_device)
|
| 941 |
+
assert mask is not None
|
| 942 |
+
not_mask = ~mask
|
| 943 |
+
y_embed = not_mask.cumsum(dim=1, dtype=torch_dtype)
|
| 944 |
+
x_embed = not_mask.cumsum(dim=2, dtype=torch_dtype)
|
| 945 |
+
if self.normalize:
|
| 946 |
+
eps = 1e-6
|
| 947 |
+
y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) *
|
| 948 |
+
self.scale).to(device=torch_device, dtype=torch_dtype)
|
| 949 |
+
x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) *
|
| 950 |
+
self.scale).to(device=torch_device, dtype=torch_dtype)
|
| 951 |
+
|
| 952 |
+
dim_t = self.temperature**(2 * (self.dim_t // 2) / self.num_pos_feats)
|
| 953 |
+
|
| 954 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
| 955 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
| 956 |
+
pos_x = torch.stack(
|
| 957 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
|
| 958 |
+
dim=4).flatten(3)
|
| 959 |
+
pos_y = torch.stack(
|
| 960 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
|
| 961 |
+
dim=4).flatten(3)
|
| 962 |
+
return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
class MCLM(nn.Module):
|
| 966 |
+
|
| 967 |
+
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
|
| 968 |
+
super(MCLM, self).__init__()
|
| 969 |
+
self.attention = nn.ModuleList([
|
| 970 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 971 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 972 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 973 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 974 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 975 |
+
])
|
| 976 |
+
|
| 977 |
+
self.linear1 = nn.Linear(d_model, d_model * 2)
|
| 978 |
+
self.linear2 = nn.Linear(d_model * 2, d_model)
|
| 979 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 980 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 981 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 982 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 983 |
+
self.dropout = nn.Dropout(0.1)
|
| 984 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 985 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 986 |
+
self.activation = get_activation_fn('relu')
|
| 987 |
+
self.pool_ratios = pool_ratios
|
| 988 |
+
self.p_poses = []
|
| 989 |
+
self.g_pos = None
|
| 990 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 991 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 992 |
+
|
| 993 |
+
def forward(self, l, g):
|
| 994 |
+
"""
|
| 995 |
+
l: 4,c,h,w
|
| 996 |
+
g: 1,c,h,w
|
| 997 |
+
"""
|
| 998 |
+
b, c, h, w = l.size()
|
| 999 |
+
# 4,c,h,w -> 1,c,2h,2w
|
| 1000 |
+
concated_locs = rearrange(l,
|
| 1001 |
+
'(hg wg b) c h w -> b c (hg h) (wg w)',
|
| 1002 |
+
hg=2,
|
| 1003 |
+
wg=2)
|
| 1004 |
+
|
| 1005 |
+
pools = []
|
| 1006 |
+
for pool_ratio in self.pool_ratios:
|
| 1007 |
+
# b,c,h,w
|
| 1008 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1009 |
+
pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
|
| 1010 |
+
pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
|
| 1011 |
+
if self.g_pos is None:
|
| 1012 |
+
pos_emb = self.positional_encoding(pool.shape[0],
|
| 1013 |
+
pool.shape[2],
|
| 1014 |
+
pool.shape[3])
|
| 1015 |
+
pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1016 |
+
self.p_poses.append(pos_emb)
|
| 1017 |
+
pools = torch.cat(pools, 0)
|
| 1018 |
+
if self.g_pos is None:
|
| 1019 |
+
self.p_poses = torch.cat(self.p_poses, dim=0)
|
| 1020 |
+
pos_emb = self.positional_encoding(g.shape[0], g.shape[2],
|
| 1021 |
+
g.shape[3])
|
| 1022 |
+
self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1023 |
+
|
| 1024 |
+
# attention between glb (q) & multisensory concated-locs (k,v)
|
| 1025 |
+
g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
|
| 1026 |
+
g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
|
| 1027 |
+
g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
|
| 1028 |
+
g_hw_b_c = self.norm1(g_hw_b_c)
|
| 1029 |
+
g_hw_b_c = g_hw_b_c + self.dropout2(
|
| 1030 |
+
self.linear2(
|
| 1031 |
+
self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
|
| 1032 |
+
g_hw_b_c = self.norm2(g_hw_b_c)
|
| 1033 |
+
|
| 1034 |
+
# attention between origin locs (q) & freashed glb (k,v)
|
| 1035 |
+
l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
|
| 1036 |
+
_g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
|
| 1037 |
+
_g_hw_b_c = rearrange(_g_hw_b_c,
|
| 1038 |
+
"(ng h) (nw w) b c -> (h w) (ng nw b) c",
|
| 1039 |
+
ng=2,
|
| 1040 |
+
nw=2)
|
| 1041 |
+
outputs_re = []
|
| 1042 |
+
for i, (_l, _g) in enumerate(
|
| 1043 |
+
zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
|
| 1044 |
+
outputs_re.append(self.attention[i + 1](_l, _g,
|
| 1045 |
+
_g)[0]) # (h w) 1 c
|
| 1046 |
+
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
|
| 1047 |
+
|
| 1048 |
+
l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
|
| 1049 |
+
l_hw_b_c = self.norm1(l_hw_b_c)
|
| 1050 |
+
l_hw_b_c = l_hw_b_c + self.dropout2(
|
| 1051 |
+
self.linear4(
|
| 1052 |
+
self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
|
| 1053 |
+
l_hw_b_c = self.norm2(l_hw_b_c)
|
| 1054 |
+
|
| 1055 |
+
l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
|
| 1056 |
+
return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
class inf_MCLM(nn.Module):
|
| 1060 |
+
|
| 1061 |
+
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
|
| 1062 |
+
super(inf_MCLM, self).__init__()
|
| 1063 |
+
self.attention = nn.ModuleList([
|
| 1064 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1065 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1066 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1067 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1068 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 1069 |
+
])
|
| 1070 |
+
|
| 1071 |
+
self.linear1 = nn.Linear(d_model, d_model * 2)
|
| 1072 |
+
self.linear2 = nn.Linear(d_model * 2, d_model)
|
| 1073 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 1074 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 1075 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 1076 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 1077 |
+
self.dropout = nn.Dropout(0.1)
|
| 1078 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 1079 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 1080 |
+
self.activation = get_activation_fn('relu')
|
| 1081 |
+
self.pool_ratios = pool_ratios
|
| 1082 |
+
self.p_poses = []
|
| 1083 |
+
self.g_pos = None
|
| 1084 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 1085 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 1086 |
+
|
| 1087 |
+
def forward(self, l, g):
|
| 1088 |
+
"""
|
| 1089 |
+
l: 4,c,h,w
|
| 1090 |
+
g: 1,c,h,w
|
| 1091 |
+
"""
|
| 1092 |
+
b, c, h, w = l.size()
|
| 1093 |
+
# 4,c,h,w -> 1,c,2h,2w
|
| 1094 |
+
concated_locs = rearrange(l,
|
| 1095 |
+
'(hg wg b) c h w -> b c (hg h) (wg w)',
|
| 1096 |
+
hg=2,
|
| 1097 |
+
wg=2)
|
| 1098 |
+
self.p_poses = []
|
| 1099 |
+
pools = []
|
| 1100 |
+
for pool_ratio in self.pool_ratios:
|
| 1101 |
+
# b,c,h,w
|
| 1102 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1103 |
+
pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
|
| 1104 |
+
pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
|
| 1105 |
+
# if self.g_pos is None:
|
| 1106 |
+
pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2],
|
| 1107 |
+
pool.shape[3])
|
| 1108 |
+
pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1109 |
+
self.p_poses.append(pos_emb)
|
| 1110 |
+
pools = torch.cat(pools, 0)
|
| 1111 |
+
# if self.g_pos is None:
|
| 1112 |
+
self.p_poses = torch.cat(self.p_poses, dim=0)
|
| 1113 |
+
pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
|
| 1114 |
+
self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1115 |
+
|
| 1116 |
+
# attention between glb (q) & multisensory concated-locs (k,v)
|
| 1117 |
+
g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
|
| 1118 |
+
g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
|
| 1119 |
+
g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
|
| 1120 |
+
g_hw_b_c = self.norm1(g_hw_b_c)
|
| 1121 |
+
g_hw_b_c = g_hw_b_c + self.dropout2(
|
| 1122 |
+
self.linear2(
|
| 1123 |
+
self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
|
| 1124 |
+
g_hw_b_c = self.norm2(g_hw_b_c)
|
| 1125 |
+
|
| 1126 |
+
# attention between origin locs (q) & freashed glb (k,v)
|
| 1127 |
+
l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
|
| 1128 |
+
_g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
|
| 1129 |
+
_g_hw_b_c = rearrange(_g_hw_b_c,
|
| 1130 |
+
"(ng h) (nw w) b c -> (h w) (ng nw b) c",
|
| 1131 |
+
ng=2,
|
| 1132 |
+
nw=2)
|
| 1133 |
+
outputs_re = []
|
| 1134 |
+
for i, (_l, _g) in enumerate(
|
| 1135 |
+
zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
|
| 1136 |
+
outputs_re.append(self.attention[i + 1](_l, _g,
|
| 1137 |
+
_g)[0]) # (h w) 1 c
|
| 1138 |
+
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
|
| 1139 |
+
|
| 1140 |
+
l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
|
| 1141 |
+
l_hw_b_c = self.norm1(l_hw_b_c)
|
| 1142 |
+
l_hw_b_c = l_hw_b_c + self.dropout2(
|
| 1143 |
+
self.linear4(
|
| 1144 |
+
self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
|
| 1145 |
+
l_hw_b_c = self.norm2(l_hw_b_c)
|
| 1146 |
+
|
| 1147 |
+
l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
|
| 1148 |
+
return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
|
| 1149 |
+
|
| 1150 |
+
|
| 1151 |
+
class MCRM(nn.Module):
|
| 1152 |
+
|
| 1153 |
+
def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
|
| 1154 |
+
super(MCRM, self).__init__()
|
| 1155 |
+
self.attention = nn.ModuleList([
|
| 1156 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1157 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1158 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1159 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 1160 |
+
])
|
| 1161 |
+
|
| 1162 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 1163 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 1164 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 1165 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 1166 |
+
self.dropout = nn.Dropout(0.1)
|
| 1167 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 1168 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 1169 |
+
self.sigmoid = nn.Sigmoid()
|
| 1170 |
+
self.activation = get_activation_fn('relu')
|
| 1171 |
+
self.sal_conv = nn.Conv2d(d_model, 1, 1)
|
| 1172 |
+
self.pool_ratios = pool_ratios
|
| 1173 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 1174 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 1175 |
+
|
| 1176 |
+
def forward(self, x):
|
| 1177 |
+
b, c, h, w = x.size()
|
| 1178 |
+
loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
|
| 1179 |
+
# b(4),c,h,w
|
| 1180 |
+
patched_glb = rearrange(glb,
|
| 1181 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1182 |
+
hg=2,
|
| 1183 |
+
wg=2)
|
| 1184 |
+
|
| 1185 |
+
# generate token attention map
|
| 1186 |
+
token_attention_map = self.sigmoid(self.sal_conv(glb))
|
| 1187 |
+
token_attention_map = F.interpolate(token_attention_map,
|
| 1188 |
+
size=patches2image(loc).shape[-2:],
|
| 1189 |
+
mode='nearest')
|
| 1190 |
+
loc = loc * rearrange(token_attention_map,
|
| 1191 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1192 |
+
hg=2,
|
| 1193 |
+
wg=2)
|
| 1194 |
+
pools = []
|
| 1195 |
+
for pool_ratio in self.pool_ratios:
|
| 1196 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1197 |
+
pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
|
| 1198 |
+
pools.append(rearrange(pool,
|
| 1199 |
+
'nl c h w -> nl c (h w)')) # nl(4),c,hw
|
| 1200 |
+
# nl(4),c,nphw -> nl(4),nphw,1,c
|
| 1201 |
+
pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
|
| 1202 |
+
loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
|
| 1203 |
+
outputs = []
|
| 1204 |
+
for i, q in enumerate(
|
| 1205 |
+
loc_.unbind(dim=0)): # traverse all local patches
|
| 1206 |
+
# np*hw,1,c
|
| 1207 |
+
v = pools[i]
|
| 1208 |
+
k = v
|
| 1209 |
+
outputs.append(self.attention[i](q, k, v)[0])
|
| 1210 |
+
outputs = torch.cat(outputs, 1)
|
| 1211 |
+
src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
|
| 1212 |
+
src = self.norm1(src)
|
| 1213 |
+
src = src + self.dropout2(
|
| 1214 |
+
self.linear4(
|
| 1215 |
+
self.dropout(self.activation(self.linear3(src)).clone())))
|
| 1216 |
+
src = self.norm2(src)
|
| 1217 |
+
|
| 1218 |
+
src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
|
| 1219 |
+
glb = glb + F.interpolate(patches2image(src),
|
| 1220 |
+
size=glb.shape[-2:],
|
| 1221 |
+
mode='nearest') # freshed glb
|
| 1222 |
+
return torch.cat((src, glb), 0), token_attention_map
|
| 1223 |
+
|
| 1224 |
+
|
| 1225 |
+
class inf_MCRM(nn.Module):
|
| 1226 |
+
|
| 1227 |
+
def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
|
| 1228 |
+
super(inf_MCRM, self).__init__()
|
| 1229 |
+
self.attention = nn.ModuleList([
|
| 1230 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1231 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1232 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1233 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 1234 |
+
])
|
| 1235 |
+
|
| 1236 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 1237 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 1238 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 1239 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 1240 |
+
self.dropout = nn.Dropout(0.1)
|
| 1241 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 1242 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 1243 |
+
self.sigmoid = nn.Sigmoid()
|
| 1244 |
+
self.activation = get_activation_fn('relu')
|
| 1245 |
+
self.sal_conv = nn.Conv2d(d_model, 1, 1)
|
| 1246 |
+
self.pool_ratios = pool_ratios
|
| 1247 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 1248 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 1249 |
+
|
| 1250 |
+
def forward(self, x):
|
| 1251 |
+
b, c, h, w = x.size()
|
| 1252 |
+
loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
|
| 1253 |
+
# b(4),c,h,w
|
| 1254 |
+
patched_glb = rearrange(glb,
|
| 1255 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1256 |
+
hg=2,
|
| 1257 |
+
wg=2)
|
| 1258 |
+
|
| 1259 |
+
# generate token attention map
|
| 1260 |
+
token_attention_map = self.sigmoid(self.sal_conv(glb))
|
| 1261 |
+
token_attention_map = F.interpolate(token_attention_map,
|
| 1262 |
+
size=patches2image(loc).shape[-2:],
|
| 1263 |
+
mode='nearest')
|
| 1264 |
+
loc = loc * rearrange(token_attention_map,
|
| 1265 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1266 |
+
hg=2,
|
| 1267 |
+
wg=2)
|
| 1268 |
+
pools = []
|
| 1269 |
+
for pool_ratio in self.pool_ratios:
|
| 1270 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1271 |
+
pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
|
| 1272 |
+
pools.append(rearrange(pool,
|
| 1273 |
+
'nl c h w -> nl c (h w)')) # nl(4),c,hw
|
| 1274 |
+
# nl(4),c,nphw -> nl(4),nphw,1,c
|
| 1275 |
+
pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
|
| 1276 |
+
loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
|
| 1277 |
+
outputs = []
|
| 1278 |
+
for i, q in enumerate(
|
| 1279 |
+
loc_.unbind(dim=0)): # traverse all local patches
|
| 1280 |
+
# np*hw,1,c
|
| 1281 |
+
v = pools[i]
|
| 1282 |
+
k = v
|
| 1283 |
+
outputs.append(self.attention[i](q, k, v)[0])
|
| 1284 |
+
outputs = torch.cat(outputs, 1)
|
| 1285 |
+
src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
|
| 1286 |
+
src = self.norm1(src)
|
| 1287 |
+
src = src + self.dropout2(
|
| 1288 |
+
self.linear4(
|
| 1289 |
+
self.dropout(self.activation(self.linear3(src)).clone())))
|
| 1290 |
+
src = self.norm2(src)
|
| 1291 |
+
|
| 1292 |
+
src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
|
| 1293 |
+
glb = glb + F.interpolate(patches2image(src),
|
| 1294 |
+
size=glb.shape[-2:],
|
| 1295 |
+
mode='nearest') # freshed glb
|
| 1296 |
+
return torch.cat((src, glb), 0)
|
| 1297 |
+
|
| 1298 |
+
|
| 1299 |
+
# model for single-scale training
|
| 1300 |
+
class MVANet(nn.Module):
|
| 1301 |
+
|
| 1302 |
+
def __init__(self):
|
| 1303 |
+
super().__init__()
|
| 1304 |
+
self.backbone = SwinB(pretrained=True)
|
| 1305 |
+
emb_dim = 128
|
| 1306 |
+
self.sideout5 = nn.Sequential(
|
| 1307 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1308 |
+
self.sideout4 = nn.Sequential(
|
| 1309 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1310 |
+
self.sideout3 = nn.Sequential(
|
| 1311 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1312 |
+
self.sideout2 = nn.Sequential(
|
| 1313 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1314 |
+
self.sideout1 = nn.Sequential(
|
| 1315 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1316 |
+
|
| 1317 |
+
self.output5 = make_cbr(1024, emb_dim)
|
| 1318 |
+
self.output4 = make_cbr(512, emb_dim)
|
| 1319 |
+
self.output3 = make_cbr(256, emb_dim)
|
| 1320 |
+
self.output2 = make_cbr(128, emb_dim)
|
| 1321 |
+
self.output1 = make_cbr(128, emb_dim)
|
| 1322 |
+
|
| 1323 |
+
self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
|
| 1324 |
+
self.conv1 = make_cbr(emb_dim, emb_dim)
|
| 1325 |
+
self.conv2 = make_cbr(emb_dim, emb_dim)
|
| 1326 |
+
self.conv3 = make_cbr(emb_dim, emb_dim)
|
| 1327 |
+
self.conv4 = make_cbr(emb_dim, emb_dim)
|
| 1328 |
+
self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1329 |
+
self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1330 |
+
self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1331 |
+
self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1332 |
+
|
| 1333 |
+
self.insmask_head = nn.Sequential(
|
| 1334 |
+
nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
|
| 1335 |
+
nn.BatchNorm2d(384), nn.PReLU(),
|
| 1336 |
+
nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
|
| 1337 |
+
nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
|
| 1338 |
+
|
| 1339 |
+
self.shallow = nn.Sequential(
|
| 1340 |
+
nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
|
| 1341 |
+
self.upsample1 = make_cbg(emb_dim, emb_dim)
|
| 1342 |
+
self.upsample2 = make_cbg(emb_dim, emb_dim)
|
| 1343 |
+
self.output = nn.Sequential(
|
| 1344 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1345 |
+
|
| 1346 |
+
for m in self.modules():
|
| 1347 |
+
if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
|
| 1348 |
+
m.inplace = True
|
| 1349 |
+
|
| 1350 |
+
def forward(self, x):
|
| 1351 |
+
x = x.to(dtype=torch_dtype, device=torch_device)
|
| 1352 |
+
shallow = self.shallow(x)
|
| 1353 |
+
glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
|
| 1354 |
+
loc = image2patches(x)
|
| 1355 |
+
input = torch.cat((loc, glb), dim=0)
|
| 1356 |
+
feature = self.backbone(input)
|
| 1357 |
+
e5 = self.output5(feature[4]) # (5,128,16,16)
|
| 1358 |
+
e4 = self.output4(feature[3]) # (5,128,32,32)
|
| 1359 |
+
e3 = self.output3(feature[2]) # (5,128,64,64)
|
| 1360 |
+
e2 = self.output2(feature[1]) # (5,128,128,128)
|
| 1361 |
+
e1 = self.output1(feature[0]) # (5,128,128,128)
|
| 1362 |
+
loc_e5, glb_e5 = e5.split([4, 1], dim=0)
|
| 1363 |
+
e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
|
| 1364 |
+
|
| 1365 |
+
e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
|
| 1366 |
+
e4 = self.conv4(e4)
|
| 1367 |
+
e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
|
| 1368 |
+
e3 = self.conv3(e3)
|
| 1369 |
+
e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
|
| 1370 |
+
e2 = self.conv2(e2)
|
| 1371 |
+
e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
|
| 1372 |
+
e1 = self.conv1(e1)
|
| 1373 |
+
loc_e1, glb_e1 = e1.split([4, 1], dim=0)
|
| 1374 |
+
output1_cat = patches2image(loc_e1) # (1,128,256,256)
|
| 1375 |
+
# add glb feat in
|
| 1376 |
+
output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
|
| 1377 |
+
# merge
|
| 1378 |
+
final_output = self.insmask_head(output1_cat) # (1,128,256,256)
|
| 1379 |
+
# shallow feature merge
|
| 1380 |
+
final_output = final_output + resize_as(shallow, final_output)
|
| 1381 |
+
final_output = self.upsample1(rescale_to(final_output))
|
| 1382 |
+
final_output = rescale_to(final_output +
|
| 1383 |
+
resize_as(shallow, final_output))
|
| 1384 |
+
final_output = self.upsample2(final_output)
|
| 1385 |
+
final_output = self.output(final_output)
|
| 1386 |
+
####
|
| 1387 |
+
sideout5 = self.sideout5(e5).to(dtype=torch_dtype, device=torch_device)
|
| 1388 |
+
sideout4 = self.sideout4(e4)
|
| 1389 |
+
sideout3 = self.sideout3(e3)
|
| 1390 |
+
sideout2 = self.sideout2(e2)
|
| 1391 |
+
sideout1 = self.sideout1(e1)
|
| 1392 |
+
#######glb_sideouts ######
|
| 1393 |
+
glb5 = self.sideout5(glb_e5)
|
| 1394 |
+
glb4 = sideout4[-1, :, :, :].unsqueeze(0)
|
| 1395 |
+
glb3 = sideout3[-1, :, :, :].unsqueeze(0)
|
| 1396 |
+
glb2 = sideout2[-1, :, :, :].unsqueeze(0)
|
| 1397 |
+
glb1 = sideout1[-1, :, :, :].unsqueeze(0)
|
| 1398 |
+
####### concat 4 to 1 #######
|
| 1399 |
+
sideout1 = patches2image(sideout1[:-1]).to(dtype=torch_dtype,
|
| 1400 |
+
device=torch_device)
|
| 1401 |
+
sideout2 = patches2image(sideout2[:-1]).to(
|
| 1402 |
+
dtype=torch_dtype,
|
| 1403 |
+
device=torch_device) ####(5,c,h,w) -> (1 c 2h,2w)
|
| 1404 |
+
sideout3 = patches2image(sideout3[:-1]).to(dtype=torch_dtype,
|
| 1405 |
+
device=torch_device)
|
| 1406 |
+
sideout4 = patches2image(sideout4[:-1]).to(dtype=torch_dtype,
|
| 1407 |
+
device=torch_device)
|
| 1408 |
+
sideout5 = patches2image(sideout5[:-1]).to(dtype=torch_dtype,
|
| 1409 |
+
device=torch_device)
|
| 1410 |
+
if self.training:
|
| 1411 |
+
return sideout5, sideout4, sideout3, sideout2, sideout1, final_output, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1
|
| 1412 |
+
else:
|
| 1413 |
+
return final_output
|
| 1414 |
+
|
| 1415 |
+
|
| 1416 |
+
# model for multi-scale testing
|
| 1417 |
+
class inf_MVANet(nn.Module):
|
| 1418 |
+
|
| 1419 |
+
def __init__(self):
|
| 1420 |
+
super().__init__()
|
| 1421 |
+
# self.backbone = SwinB(pretrained=True)
|
| 1422 |
+
self.backbone = SwinB(pretrained=False)
|
| 1423 |
+
|
| 1424 |
+
emb_dim = 128
|
| 1425 |
+
self.output5 = make_cbr(1024, emb_dim)
|
| 1426 |
+
self.output4 = make_cbr(512, emb_dim)
|
| 1427 |
+
self.output3 = make_cbr(256, emb_dim)
|
| 1428 |
+
self.output2 = make_cbr(128, emb_dim)
|
| 1429 |
+
self.output1 = make_cbr(128, emb_dim)
|
| 1430 |
+
|
| 1431 |
+
self.multifieldcrossatt = inf_MCLM(emb_dim, 1, [1, 4, 8])
|
| 1432 |
+
self.conv1 = make_cbr(emb_dim, emb_dim)
|
| 1433 |
+
self.conv2 = make_cbr(emb_dim, emb_dim)
|
| 1434 |
+
self.conv3 = make_cbr(emb_dim, emb_dim)
|
| 1435 |
+
self.conv4 = make_cbr(emb_dim, emb_dim)
|
| 1436 |
+
self.dec_blk1 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1437 |
+
self.dec_blk2 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1438 |
+
self.dec_blk3 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1439 |
+
self.dec_blk4 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1440 |
+
|
| 1441 |
+
self.insmask_head = nn.Sequential(
|
| 1442 |
+
nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
|
| 1443 |
+
nn.BatchNorm2d(384), nn.PReLU(),
|
| 1444 |
+
nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
|
| 1445 |
+
nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
|
| 1446 |
+
|
| 1447 |
+
self.shallow = nn.Sequential(
|
| 1448 |
+
nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
|
| 1449 |
+
self.upsample1 = make_cbg(emb_dim, emb_dim)
|
| 1450 |
+
self.upsample2 = make_cbg(emb_dim, emb_dim)
|
| 1451 |
+
self.output = nn.Sequential(
|
| 1452 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1453 |
+
|
| 1454 |
+
for m in self.modules():
|
| 1455 |
+
if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
|
| 1456 |
+
m.inplace = True
|
| 1457 |
+
|
| 1458 |
+
def forward(self, x):
|
| 1459 |
+
shallow = self.shallow(x)
|
| 1460 |
+
glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
|
| 1461 |
+
loc = image2patches(x)
|
| 1462 |
+
input = torch.cat((loc, glb), dim=0)
|
| 1463 |
+
feature = self.backbone(input)
|
| 1464 |
+
e5 = self.output5(feature[4])
|
| 1465 |
+
e4 = self.output4(feature[3])
|
| 1466 |
+
e3 = self.output3(feature[2])
|
| 1467 |
+
e2 = self.output2(feature[1])
|
| 1468 |
+
e1 = self.output1(feature[0])
|
| 1469 |
+
loc_e5, glb_e5 = e5.split([4, 1], dim=0)
|
| 1470 |
+
e5_cat = self.multifieldcrossatt(loc_e5, glb_e5)
|
| 1471 |
+
|
| 1472 |
+
e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
|
| 1473 |
+
e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
|
| 1474 |
+
e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
|
| 1475 |
+
e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))
|
| 1476 |
+
loc_e1, glb_e1 = e1.split([4, 1], dim=0)
|
| 1477 |
+
# after decoder, concat loc features to a whole one, and merge
|
| 1478 |
+
output1_cat = patches2image(loc_e1)
|
| 1479 |
+
# add glb feat in
|
| 1480 |
+
output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
|
| 1481 |
+
# merge
|
| 1482 |
+
final_output = self.insmask_head(output1_cat)
|
| 1483 |
+
# shallow feature merge
|
| 1484 |
+
final_output = final_output + resize_as(shallow, final_output)
|
| 1485 |
+
final_output = self.upsample1(rescale_to(final_output))
|
| 1486 |
+
final_output = rescale_to(final_output +
|
| 1487 |
+
resize_as(shallow, final_output))
|
| 1488 |
+
final_output = self.upsample2(final_output)
|
| 1489 |
+
final_output = self.output(final_output)
|
| 1490 |
+
return final_output
|
| 1491 |
+
|
| 1492 |
+
|
| 1493 |
+
class load_MVANet_Model:
|
| 1494 |
+
|
| 1495 |
+
def __init__(self):
|
| 1496 |
+
pass
|
| 1497 |
+
|
| 1498 |
+
@classmethod
|
| 1499 |
+
def INPUT_TYPES(s):
|
| 1500 |
+
return {
|
| 1501 |
+
"required": {},
|
| 1502 |
+
}
|
| 1503 |
+
|
| 1504 |
+
RETURN_TYPES = ("MVANet_Model", )
|
| 1505 |
+
FUNCTION = "test"
|
| 1506 |
+
CATEGORY = "MVANet"
|
| 1507 |
+
|
| 1508 |
+
def test(self):
|
| 1509 |
+
return (load_model(get_model_path()), )
|
| 1510 |
+
|
| 1511 |
+
|
| 1512 |
+
class run_MVANet_inference:
|
| 1513 |
+
|
| 1514 |
+
def __init__(self):
|
| 1515 |
+
pass
|
| 1516 |
+
|
| 1517 |
+
@classmethod
|
| 1518 |
+
def INPUT_TYPES(s):
|
| 1519 |
+
return {
|
| 1520 |
+
"required": {
|
| 1521 |
+
"image": ("IMAGE", ),
|
| 1522 |
+
"MVANet_Model": ("MVANet_Model", ),
|
| 1523 |
+
},
|
| 1524 |
+
}
|
| 1525 |
+
|
| 1526 |
+
RETURN_TYPES = ("MASK", )
|
| 1527 |
+
FUNCTION = "test"
|
| 1528 |
+
CATEGORY = "MVANet"
|
| 1529 |
+
|
| 1530 |
+
def test(
|
| 1531 |
+
self,
|
| 1532 |
+
image,
|
| 1533 |
+
MVANet_Model,
|
| 1534 |
+
):
|
| 1535 |
+
ret = do_infer_tensor2tensor(img=image, net=MVANet_Model)
|
| 1536 |
+
|
| 1537 |
+
return (ret, )
|
| 1538 |
+
|
| 1539 |
+
|
| 1540 |
+
NODE_CLASS_MAPPINGS = {
|
| 1541 |
+
"load_MVANet_Model": load_MVANet_Model,
|
| 1542 |
+
"run_MVANet_inference": run_MVANet_inference
|
| 1543 |
+
}
|
| 1544 |
+
|
| 1545 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
| 1546 |
+
"load_MVANet_Model": "load MVANet Model",
|
| 1547 |
+
"run_MVANet_inference": "run_MVANet_inference"
|
| 1548 |
+
}
|
ComfyUI_MVANet/MVANet_inference.run.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
. "${HOME}/dbnew.sh"
|
| 3 |
+
python3 './MVANet_inference.py'
|
ComfyUI_MVANet/README.org
ADDED
|
@@ -0,0 +1,1694 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
* COMMENT Sample
|
| 2 |
+
|
| 3 |
+
** Shell script to download
|
| 4 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
|
| 5 |
+
#+end_src
|
| 6 |
+
|
| 7 |
+
** MVANet_inference import
|
| 8 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
|
| 9 |
+
#+end_src
|
| 10 |
+
|
| 11 |
+
** MVANet_inference function
|
| 12 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
|
| 13 |
+
#+end_src
|
| 14 |
+
|
| 15 |
+
** MVANet_inference class
|
| 16 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
|
| 17 |
+
#+end_src
|
| 18 |
+
|
| 19 |
+
** MVANet_inference execute
|
| 20 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
|
| 21 |
+
#+end_src
|
| 22 |
+
|
| 23 |
+
** MVANet_inference unify
|
| 24 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.unify.sh
|
| 25 |
+
#+end_src
|
| 26 |
+
|
| 27 |
+
* Download the code:
|
| 28 |
+
|
| 29 |
+
** Function to download
|
| 30 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
|
| 31 |
+
get_repo(){
|
| 32 |
+
DIR_REPO="${HOME}/GITHUB/$('echo' "${1}" | 'sed' 's/^git@github.com://g ; s@^https://github.com/@@g ; s@.git$@@g' )"
|
| 33 |
+
DIR_BASE="$('dirname' '--' "${DIR_REPO}")"
|
| 34 |
+
mkdir -pv -- "${DIR_BASE}"
|
| 35 |
+
cd "${DIR_BASE}"
|
| 36 |
+
git clone "${1}"
|
| 37 |
+
cd "${DIR_REPO}"
|
| 38 |
+
git pull
|
| 39 |
+
git submodule update --recursive --init
|
| 40 |
+
}
|
| 41 |
+
#+end_src
|
| 42 |
+
|
| 43 |
+
** Download
|
| 44 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
|
| 45 |
+
get_repo 'https://github.com/qianyu-dlut/MVANet.git'
|
| 46 |
+
#+end_src
|
| 47 |
+
|
| 48 |
+
* Dependencies
|
| 49 |
+
#+begin_src conf :tangle ./requirements.txt
|
| 50 |
+
timm
|
| 51 |
+
einops
|
| 52 |
+
wget
|
| 53 |
+
#+end_src
|
| 54 |
+
|
| 55 |
+
* Python inference
|
| 56 |
+
|
| 57 |
+
** Important configs
|
| 58 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
|
| 59 |
+
import os
|
| 60 |
+
import sys
|
| 61 |
+
|
| 62 |
+
HOME_DIR = os.environ.get('HOME', '/root')
|
| 63 |
+
MVANET_SOURCE_DIR = HOME_DIR + '/GITHUB/qianyu-dlut/MVANet'
|
| 64 |
+
finetuned_MVANet_model_path = MVANET_SOURCE_DIR + '/model/Model_80.pth'
|
| 65 |
+
pretrained_SwinB_model_path = MVANET_SOURCE_DIR + '/model/swin_base_patch4_window12_384_22kto1k.pth'
|
| 66 |
+
#+end_src
|
| 67 |
+
|
| 68 |
+
** MVANet_inference import
|
| 69 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
|
| 70 |
+
import math
|
| 71 |
+
import numpy as np
|
| 72 |
+
import cv2
|
| 73 |
+
import wget
|
| 74 |
+
|
| 75 |
+
import torch
|
| 76 |
+
import torch.nn as nn
|
| 77 |
+
import torch.nn.functional as F
|
| 78 |
+
import torch.utils.checkpoint as checkpoint
|
| 79 |
+
from torch.autograd import Variable
|
| 80 |
+
from torch import nn
|
| 81 |
+
from torchvision import transforms
|
| 82 |
+
|
| 83 |
+
from einops import rearrange
|
| 84 |
+
|
| 85 |
+
from timm.models import load_checkpoint
|
| 86 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 87 |
+
|
| 88 |
+
torch_device = 'cuda'
|
| 89 |
+
torch_dtype = torch.float16
|
| 90 |
+
#+end_src
|
| 91 |
+
|
| 92 |
+
** COMMENT Load image using CV
|
| 93 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
|
| 94 |
+
def load_image(input_image_path):
|
| 95 |
+
img = cv2.imread(input_image_path, cv2.IMREAD_COLOR)
|
| 96 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 97 |
+
return img
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def load_image_torch(input_image_path):
|
| 101 |
+
img = cv2.imread(input_image_path, cv2.IMREAD_COLOR)
|
| 102 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 103 |
+
img = torch.from_numpy(img)
|
| 104 |
+
img = img.to(dtype=torch.float32)
|
| 105 |
+
img /= 255.0
|
| 106 |
+
img = img.unsqueeze(0)
|
| 107 |
+
return img
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def save_mask(output_image_path, mask):
|
| 111 |
+
cv2.imwrite(output_image_path, mask)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def save_mask_torch(output_image_path, mask):
|
| 115 |
+
mask = mask.detach().cpu()
|
| 116 |
+
mask *= 255.0
|
| 117 |
+
mask = mask.clamp(0, 255)
|
| 118 |
+
print(mask.shape)
|
| 119 |
+
mask = mask.squeeze(0)
|
| 120 |
+
mask = mask.to(dtype=torch.uint8)
|
| 121 |
+
print(mask.shape)
|
| 122 |
+
mask = mask.numpy()
|
| 123 |
+
print(mask.shape)
|
| 124 |
+
cv2.imwrite(output_image_path, mask)
|
| 125 |
+
#+end_src
|
| 126 |
+
|
| 127 |
+
** MVANet_inference function
|
| 128 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
|
| 129 |
+
def check_mkdir(dir_name):
|
| 130 |
+
if not os.path.isdir(dir_name):
|
| 131 |
+
os.makedirs(dir_name)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def SwinT(pretrained=True):
|
| 135 |
+
model = SwinTransformer(embed_dim=96,
|
| 136 |
+
depths=[2, 2, 6, 2],
|
| 137 |
+
num_heads=[3, 6, 12, 24],
|
| 138 |
+
window_size=7)
|
| 139 |
+
if pretrained is True:
|
| 140 |
+
model.load_state_dict(torch.load(
|
| 141 |
+
'data/backbone_ckpt/swin_tiny_patch4_window7_224.pth',
|
| 142 |
+
map_location='cpu')['model'],
|
| 143 |
+
strict=False)
|
| 144 |
+
|
| 145 |
+
return model
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def SwinS(pretrained=True):
|
| 149 |
+
model = SwinTransformer(embed_dim=96,
|
| 150 |
+
depths=[2, 2, 18, 2],
|
| 151 |
+
num_heads=[3, 6, 12, 24],
|
| 152 |
+
window_size=7)
|
| 153 |
+
if pretrained is True:
|
| 154 |
+
model.load_state_dict(torch.load(
|
| 155 |
+
'data/backbone_ckpt/swin_small_patch4_window7_224.pth',
|
| 156 |
+
map_location='cpu')['model'],
|
| 157 |
+
strict=False)
|
| 158 |
+
|
| 159 |
+
return model
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def SwinB(pretrained=True):
|
| 163 |
+
model = SwinTransformer(embed_dim=128,
|
| 164 |
+
depths=[2, 2, 18, 2],
|
| 165 |
+
num_heads=[4, 8, 16, 32],
|
| 166 |
+
window_size=12)
|
| 167 |
+
if pretrained is True:
|
| 168 |
+
import os
|
| 169 |
+
model.load_state_dict(torch.load(pretrained_SwinB_model_path,
|
| 170 |
+
map_location='cpu')['model'],
|
| 171 |
+
strict=False)
|
| 172 |
+
return model
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def SwinL(pretrained=True):
|
| 176 |
+
model = SwinTransformer(embed_dim=192,
|
| 177 |
+
depths=[2, 2, 18, 2],
|
| 178 |
+
num_heads=[6, 12, 24, 48],
|
| 179 |
+
window_size=12)
|
| 180 |
+
if pretrained is True:
|
| 181 |
+
model.load_state_dict(torch.load(
|
| 182 |
+
'data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth',
|
| 183 |
+
map_location='cpu')['model'],
|
| 184 |
+
strict=False)
|
| 185 |
+
|
| 186 |
+
return model
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def get_activation_fn(activation):
|
| 190 |
+
"""Return an activation function given a string"""
|
| 191 |
+
if activation == "relu":
|
| 192 |
+
return F.relu
|
| 193 |
+
if activation == "gelu":
|
| 194 |
+
return F.gelu
|
| 195 |
+
if activation == "glu":
|
| 196 |
+
return F.glu
|
| 197 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def make_cbr(in_dim, out_dim):
|
| 201 |
+
return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
|
| 202 |
+
nn.BatchNorm2d(out_dim), nn.PReLU())
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def make_cbg(in_dim, out_dim):
|
| 206 |
+
return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
|
| 207 |
+
nn.BatchNorm2d(out_dim), nn.GELU())
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
|
| 211 |
+
return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def resize_as(x, y, interpolation='bilinear'):
|
| 215 |
+
return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def image2patches(x):
|
| 219 |
+
"""b c (hg h) (wg w) -> (hg wg b) c h w"""
|
| 220 |
+
x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
| 221 |
+
return x
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def patches2image(x):
|
| 225 |
+
"""(hg wg b) c h w -> b c (hg h) (wg w)"""
|
| 226 |
+
x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
|
| 227 |
+
return x
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def window_partition(x, window_size):
|
| 231 |
+
"""
|
| 232 |
+
Args:
|
| 233 |
+
x: (B, H, W, C)
|
| 234 |
+
window_size (int): window size
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 238 |
+
"""
|
| 239 |
+
B, H, W, C = x.shape
|
| 240 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size,
|
| 241 |
+
C)
|
| 242 |
+
windows = x.permute(0, 1, 3, 2, 4,
|
| 243 |
+
5).contiguous().view(-1, window_size, window_size, C)
|
| 244 |
+
return windows
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def window_reverse(windows, window_size, H, W):
|
| 248 |
+
"""
|
| 249 |
+
Args:
|
| 250 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 251 |
+
window_size (int): Window size
|
| 252 |
+
H (int): Height of image
|
| 253 |
+
W (int): Width of image
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
x: (B, H, W, C)
|
| 257 |
+
"""
|
| 258 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
| 259 |
+
x = windows.view(B, H // window_size, W // window_size, window_size,
|
| 260 |
+
window_size, -1)
|
| 261 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
| 262 |
+
return x
|
| 263 |
+
#+end_src
|
| 264 |
+
|
| 265 |
+
** MVANet_inference class
|
| 266 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
|
| 267 |
+
class Mlp(nn.Module):
|
| 268 |
+
""" Multilayer perceptron."""
|
| 269 |
+
|
| 270 |
+
def __init__(self,
|
| 271 |
+
in_features,
|
| 272 |
+
hidden_features=None,
|
| 273 |
+
out_features=None,
|
| 274 |
+
act_layer=nn.GELU,
|
| 275 |
+
drop=0.):
|
| 276 |
+
super().__init__()
|
| 277 |
+
out_features = out_features or in_features
|
| 278 |
+
hidden_features = hidden_features or in_features
|
| 279 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 280 |
+
self.act = act_layer()
|
| 281 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 282 |
+
self.drop = nn.Dropout(drop)
|
| 283 |
+
|
| 284 |
+
def forward(self, x):
|
| 285 |
+
x = self.fc1(x)
|
| 286 |
+
x = self.act(x)
|
| 287 |
+
x = self.drop(x)
|
| 288 |
+
x = self.fc2(x)
|
| 289 |
+
x = self.drop(x)
|
| 290 |
+
return x
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class WindowAttention(nn.Module):
|
| 294 |
+
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
| 295 |
+
It supports both of shifted and non-shifted window.
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
dim (int): Number of input channels.
|
| 299 |
+
window_size (tuple[int]): The height and width of the window.
|
| 300 |
+
num_heads (int): Number of attention heads.
|
| 301 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 302 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
| 303 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
| 304 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
| 305 |
+
"""
|
| 306 |
+
|
| 307 |
+
def __init__(self,
|
| 308 |
+
dim,
|
| 309 |
+
window_size,
|
| 310 |
+
num_heads,
|
| 311 |
+
qkv_bias=True,
|
| 312 |
+
qk_scale=None,
|
| 313 |
+
attn_drop=0.,
|
| 314 |
+
proj_drop=0.):
|
| 315 |
+
|
| 316 |
+
super().__init__()
|
| 317 |
+
self.dim = dim
|
| 318 |
+
self.window_size = window_size # Wh, Ww
|
| 319 |
+
self.num_heads = num_heads
|
| 320 |
+
head_dim = dim // num_heads
|
| 321 |
+
self.scale = qk_scale or head_dim**-0.5
|
| 322 |
+
|
| 323 |
+
# define a parameter table of relative position bias
|
| 324 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 325 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
|
| 326 |
+
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 327 |
+
|
| 328 |
+
# get pair-wise relative position index for each token inside the window
|
| 329 |
+
coords_h = torch.arange(self.window_size[0])
|
| 330 |
+
coords_w = torch.arange(self.window_size[1])
|
| 331 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 332 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 333 |
+
relative_coords = coords_flatten[:, :,
|
| 334 |
+
None] - coords_flatten[:,
|
| 335 |
+
None, :] # 2, Wh*Ww, Wh*Ww
|
| 336 |
+
relative_coords = relative_coords.permute(
|
| 337 |
+
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 338 |
+
relative_coords[:, :,
|
| 339 |
+
0] += self.window_size[0] - 1 # shift to start from 0
|
| 340 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 341 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
| 342 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 343 |
+
self.register_buffer("relative_position_index",
|
| 344 |
+
relative_position_index)
|
| 345 |
+
|
| 346 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 347 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 348 |
+
self.proj = nn.Linear(dim, dim)
|
| 349 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 350 |
+
|
| 351 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
| 352 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 353 |
+
|
| 354 |
+
def forward(self, x, mask=None):
|
| 355 |
+
""" Forward function.
|
| 356 |
+
|
| 357 |
+
Args:
|
| 358 |
+
x: input features with shape of (num_windows*B, N, C)
|
| 359 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
| 360 |
+
"""
|
| 361 |
+
x = x.to(dtype=torch_dtype, device=torch_device)
|
| 362 |
+
B_, N, C = x.shape
|
| 363 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
|
| 364 |
+
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 365 |
+
q, k, v = qkv[0], qkv[1], qkv[
|
| 366 |
+
2] # make torchscript happy (cannot use tensor as tuple)
|
| 367 |
+
|
| 368 |
+
q = q * self.scale
|
| 369 |
+
attn = (q @ k.transpose(-2, -1))
|
| 370 |
+
|
| 371 |
+
relative_position_bias = self.relative_position_bias_table[
|
| 372 |
+
self.relative_position_index.view(-1)].view(
|
| 373 |
+
self.window_size[0] * self.window_size[1],
|
| 374 |
+
self.window_size[0] * self.window_size[1],
|
| 375 |
+
-1) # Wh*Ww,Wh*Ww,nH
|
| 376 |
+
relative_position_bias = relative_position_bias.permute(
|
| 377 |
+
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 378 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
| 379 |
+
|
| 380 |
+
if mask is not None:
|
| 381 |
+
nW = mask.shape[0]
|
| 382 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N,
|
| 383 |
+
N) + mask.unsqueeze(1).unsqueeze(0)
|
| 384 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
| 385 |
+
attn = self.softmax(attn)
|
| 386 |
+
else:
|
| 387 |
+
attn = self.softmax(attn)
|
| 388 |
+
|
| 389 |
+
attn = self.attn_drop(attn)
|
| 390 |
+
attn = attn.to(dtype=torch_dtype, device=torch_device)
|
| 391 |
+
v = v.to(dtype=torch_dtype, device=torch_device)
|
| 392 |
+
|
| 393 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
| 394 |
+
x = self.proj(x)
|
| 395 |
+
x = self.proj_drop(x)
|
| 396 |
+
return x
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class SwinTransformerBlock(nn.Module):
|
| 400 |
+
""" Swin Transformer Block.
|
| 401 |
+
|
| 402 |
+
Args:
|
| 403 |
+
dim (int): Number of input channels.
|
| 404 |
+
num_heads (int): Number of attention heads.
|
| 405 |
+
window_size (int): Window size.
|
| 406 |
+
shift_size (int): Shift size for SW-MSA.
|
| 407 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 408 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 409 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 410 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 411 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 412 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
| 413 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
| 414 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 415 |
+
"""
|
| 416 |
+
|
| 417 |
+
def __init__(self,
|
| 418 |
+
dim,
|
| 419 |
+
num_heads,
|
| 420 |
+
window_size=7,
|
| 421 |
+
shift_size=0,
|
| 422 |
+
mlp_ratio=4.,
|
| 423 |
+
qkv_bias=True,
|
| 424 |
+
qk_scale=None,
|
| 425 |
+
drop=0.,
|
| 426 |
+
attn_drop=0.,
|
| 427 |
+
drop_path=0.,
|
| 428 |
+
act_layer=nn.GELU,
|
| 429 |
+
norm_layer=nn.LayerNorm):
|
| 430 |
+
super().__init__()
|
| 431 |
+
self.dim = dim
|
| 432 |
+
self.num_heads = num_heads
|
| 433 |
+
self.window_size = window_size
|
| 434 |
+
self.shift_size = shift_size
|
| 435 |
+
self.mlp_ratio = mlp_ratio
|
| 436 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
| 437 |
+
|
| 438 |
+
self.norm1 = norm_layer(dim)
|
| 439 |
+
self.attn = WindowAttention(dim,
|
| 440 |
+
window_size=to_2tuple(self.window_size),
|
| 441 |
+
num_heads=num_heads,
|
| 442 |
+
qkv_bias=qkv_bias,
|
| 443 |
+
qk_scale=qk_scale,
|
| 444 |
+
attn_drop=attn_drop,
|
| 445 |
+
proj_drop=drop)
|
| 446 |
+
|
| 447 |
+
self.drop_path = DropPath(
|
| 448 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
| 449 |
+
self.norm2 = norm_layer(dim)
|
| 450 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 451 |
+
self.mlp = Mlp(in_features=dim,
|
| 452 |
+
hidden_features=mlp_hidden_dim,
|
| 453 |
+
act_layer=act_layer,
|
| 454 |
+
drop=drop)
|
| 455 |
+
|
| 456 |
+
self.H = None
|
| 457 |
+
self.W = None
|
| 458 |
+
|
| 459 |
+
def forward(self, x, mask_matrix):
|
| 460 |
+
""" Forward function.
|
| 461 |
+
|
| 462 |
+
Args:
|
| 463 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 464 |
+
H, W: Spatial resolution of the input feature.
|
| 465 |
+
mask_matrix: Attention mask for cyclic shift.
|
| 466 |
+
"""
|
| 467 |
+
B, L, C = x.shape
|
| 468 |
+
H, W = self.H, self.W
|
| 469 |
+
assert L == H * W, "input feature has wrong size"
|
| 470 |
+
|
| 471 |
+
shortcut = x
|
| 472 |
+
x = self.norm1(x)
|
| 473 |
+
x = x.view(B, H, W, C)
|
| 474 |
+
|
| 475 |
+
# pad feature maps to multiples of window size
|
| 476 |
+
pad_l = pad_t = 0
|
| 477 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
| 478 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
| 479 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
| 480 |
+
_, Hp, Wp, _ = x.shape
|
| 481 |
+
|
| 482 |
+
# cyclic shift
|
| 483 |
+
if self.shift_size > 0:
|
| 484 |
+
shifted_x = torch.roll(x,
|
| 485 |
+
shifts=(-self.shift_size, -self.shift_size),
|
| 486 |
+
dims=(1, 2))
|
| 487 |
+
attn_mask = mask_matrix
|
| 488 |
+
else:
|
| 489 |
+
shifted_x = x
|
| 490 |
+
attn_mask = None
|
| 491 |
+
|
| 492 |
+
# partition windows
|
| 493 |
+
x_windows = window_partition(
|
| 494 |
+
shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
| 495 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size,
|
| 496 |
+
C) # nW*B, window_size*window_size, C
|
| 497 |
+
|
| 498 |
+
# W-MSA/SW-MSA
|
| 499 |
+
attn_windows = self.attn(
|
| 500 |
+
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
| 501 |
+
|
| 502 |
+
# merge windows
|
| 503 |
+
attn_windows = attn_windows.view(-1, self.window_size,
|
| 504 |
+
self.window_size, C)
|
| 505 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp,
|
| 506 |
+
Wp) # B H' W' C
|
| 507 |
+
|
| 508 |
+
# reverse cyclic shift
|
| 509 |
+
if self.shift_size > 0:
|
| 510 |
+
x = torch.roll(shifted_x,
|
| 511 |
+
shifts=(self.shift_size, self.shift_size),
|
| 512 |
+
dims=(1, 2))
|
| 513 |
+
else:
|
| 514 |
+
x = shifted_x
|
| 515 |
+
|
| 516 |
+
if pad_r > 0 or pad_b > 0:
|
| 517 |
+
x = x[:, :H, :W, :].contiguous()
|
| 518 |
+
|
| 519 |
+
x = x.view(B, H * W, C)
|
| 520 |
+
|
| 521 |
+
# FFN
|
| 522 |
+
x = shortcut + self.drop_path(x)
|
| 523 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 524 |
+
|
| 525 |
+
return x
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
class PatchMerging(nn.Module):
|
| 529 |
+
""" Patch Merging Layer
|
| 530 |
+
|
| 531 |
+
Args:
|
| 532 |
+
dim (int): Number of input channels.
|
| 533 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 534 |
+
"""
|
| 535 |
+
|
| 536 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
| 537 |
+
super().__init__()
|
| 538 |
+
self.dim = dim
|
| 539 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
| 540 |
+
self.norm = norm_layer(4 * dim)
|
| 541 |
+
|
| 542 |
+
def forward(self, x, H, W):
|
| 543 |
+
""" Forward function.
|
| 544 |
+
|
| 545 |
+
Args:
|
| 546 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 547 |
+
H, W: Spatial resolution of the input feature.
|
| 548 |
+
"""
|
| 549 |
+
B, L, C = x.shape
|
| 550 |
+
assert L == H * W, "input feature has wrong size"
|
| 551 |
+
|
| 552 |
+
x = x.view(B, H, W, C)
|
| 553 |
+
|
| 554 |
+
# padding
|
| 555 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
| 556 |
+
if pad_input:
|
| 557 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
| 558 |
+
|
| 559 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
| 560 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
| 561 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
| 562 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
| 563 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
| 564 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
| 565 |
+
|
| 566 |
+
x = self.norm(x)
|
| 567 |
+
x = self.reduction(x)
|
| 568 |
+
|
| 569 |
+
return x
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
class BasicLayer(nn.Module):
|
| 573 |
+
""" A basic Swin Transformer layer for one stage.
|
| 574 |
+
|
| 575 |
+
Args:
|
| 576 |
+
dim (int): Number of feature channels
|
| 577 |
+
depth (int): Depths of this stage.
|
| 578 |
+
num_heads (int): Number of attention head.
|
| 579 |
+
window_size (int): Local window size. Default: 7.
|
| 580 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 581 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 582 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 583 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 584 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 585 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
| 586 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 587 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
| 588 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 589 |
+
"""
|
| 590 |
+
|
| 591 |
+
def __init__(self,
|
| 592 |
+
dim,
|
| 593 |
+
depth,
|
| 594 |
+
num_heads,
|
| 595 |
+
window_size=7,
|
| 596 |
+
mlp_ratio=4.,
|
| 597 |
+
qkv_bias=True,
|
| 598 |
+
qk_scale=None,
|
| 599 |
+
drop=0.,
|
| 600 |
+
attn_drop=0.,
|
| 601 |
+
drop_path=0.,
|
| 602 |
+
norm_layer=nn.LayerNorm,
|
| 603 |
+
downsample=None,
|
| 604 |
+
use_checkpoint=False):
|
| 605 |
+
super().__init__()
|
| 606 |
+
self.window_size = window_size
|
| 607 |
+
self.shift_size = window_size // 2
|
| 608 |
+
self.depth = depth
|
| 609 |
+
self.use_checkpoint = use_checkpoint
|
| 610 |
+
|
| 611 |
+
# build blocks
|
| 612 |
+
self.blocks = nn.ModuleList([
|
| 613 |
+
SwinTransformerBlock(dim=dim,
|
| 614 |
+
num_heads=num_heads,
|
| 615 |
+
window_size=window_size,
|
| 616 |
+
shift_size=0 if
|
| 617 |
+
(i % 2 == 0) else window_size // 2,
|
| 618 |
+
mlp_ratio=mlp_ratio,
|
| 619 |
+
qkv_bias=qkv_bias,
|
| 620 |
+
qk_scale=qk_scale,
|
| 621 |
+
drop=drop,
|
| 622 |
+
attn_drop=attn_drop,
|
| 623 |
+
drop_path=drop_path[i] if isinstance(
|
| 624 |
+
drop_path, list) else drop_path,
|
| 625 |
+
norm_layer=norm_layer) for i in range(depth)
|
| 626 |
+
])
|
| 627 |
+
|
| 628 |
+
# patch merging layer
|
| 629 |
+
if downsample is not None:
|
| 630 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
| 631 |
+
else:
|
| 632 |
+
self.downsample = None
|
| 633 |
+
|
| 634 |
+
def forward(self, x, H, W):
|
| 635 |
+
""" Forward function.
|
| 636 |
+
|
| 637 |
+
Args:
|
| 638 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 639 |
+
H, W: Spatial resolution of the input feature.
|
| 640 |
+
"""
|
| 641 |
+
|
| 642 |
+
# calculate attention mask for SW-MSA
|
| 643 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
| 644 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
| 645 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
| 646 |
+
h_slices = (slice(0, -self.window_size),
|
| 647 |
+
slice(-self.window_size,
|
| 648 |
+
-self.shift_size), slice(-self.shift_size, None))
|
| 649 |
+
w_slices = (slice(0, -self.window_size),
|
| 650 |
+
slice(-self.window_size,
|
| 651 |
+
-self.shift_size), slice(-self.shift_size, None))
|
| 652 |
+
cnt = 0
|
| 653 |
+
for h in h_slices:
|
| 654 |
+
for w in w_slices:
|
| 655 |
+
img_mask[:, h, w, :] = cnt
|
| 656 |
+
cnt += 1
|
| 657 |
+
|
| 658 |
+
mask_windows = window_partition(
|
| 659 |
+
img_mask, self.window_size) # nW, window_size, window_size, 1
|
| 660 |
+
mask_windows = mask_windows.view(-1,
|
| 661 |
+
self.window_size * self.window_size)
|
| 662 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 663 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
| 664 |
+
float(-100.0)).masked_fill(
|
| 665 |
+
attn_mask == 0, float(0.0))
|
| 666 |
+
|
| 667 |
+
for blk in self.blocks:
|
| 668 |
+
blk.H, blk.W = H, W
|
| 669 |
+
if self.use_checkpoint:
|
| 670 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
| 671 |
+
else:
|
| 672 |
+
x = blk(x, attn_mask)
|
| 673 |
+
if self.downsample is not None:
|
| 674 |
+
x_down = self.downsample(x, H, W)
|
| 675 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
| 676 |
+
return x, H, W, x_down, Wh, Ww
|
| 677 |
+
else:
|
| 678 |
+
return x, H, W, x, H, W
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
class PatchEmbed(nn.Module):
|
| 682 |
+
""" Image to Patch Embedding
|
| 683 |
+
|
| 684 |
+
Args:
|
| 685 |
+
patch_size (int): Patch token size. Default: 4.
|
| 686 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 687 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 688 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
| 689 |
+
"""
|
| 690 |
+
|
| 691 |
+
def __init__(self,
|
| 692 |
+
patch_size=4,
|
| 693 |
+
in_chans=3,
|
| 694 |
+
embed_dim=96,
|
| 695 |
+
norm_layer=None):
|
| 696 |
+
super().__init__()
|
| 697 |
+
patch_size = to_2tuple(patch_size)
|
| 698 |
+
self.patch_size = patch_size
|
| 699 |
+
|
| 700 |
+
self.in_chans = in_chans
|
| 701 |
+
self.embed_dim = embed_dim
|
| 702 |
+
|
| 703 |
+
self.proj = nn.Conv2d(in_chans,
|
| 704 |
+
embed_dim,
|
| 705 |
+
kernel_size=patch_size,
|
| 706 |
+
stride=patch_size)
|
| 707 |
+
if norm_layer is not None:
|
| 708 |
+
self.norm = norm_layer(embed_dim)
|
| 709 |
+
else:
|
| 710 |
+
self.norm = None
|
| 711 |
+
|
| 712 |
+
def forward(self, x):
|
| 713 |
+
"""Forward function."""
|
| 714 |
+
# padding
|
| 715 |
+
_, _, H, W = x.size()
|
| 716 |
+
if W % self.patch_size[1] != 0:
|
| 717 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
| 718 |
+
if H % self.patch_size[0] != 0:
|
| 719 |
+
x = F.pad(x,
|
| 720 |
+
(0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
| 721 |
+
|
| 722 |
+
x = self.proj(x) # B C Wh Ww
|
| 723 |
+
if self.norm is not None:
|
| 724 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 725 |
+
x = x.flatten(2).transpose(1, 2)
|
| 726 |
+
x = self.norm(x)
|
| 727 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
| 728 |
+
|
| 729 |
+
return x
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
class SwinTransformer(nn.Module):
|
| 733 |
+
""" Swin Transformer backbone.
|
| 734 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
| 735 |
+
https://arxiv.org/pdf/2103.14030
|
| 736 |
+
|
| 737 |
+
Args:
|
| 738 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
| 739 |
+
used in absolute postion embedding. Default 224.
|
| 740 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
| 741 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 742 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 743 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
| 744 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
| 745 |
+
window_size (int): Window size. Default: 7.
|
| 746 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 747 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
| 748 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
| 749 |
+
drop_rate (float): Dropout rate.
|
| 750 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
| 751 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
| 752 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
| 753 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
| 754 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
| 755 |
+
out_indices (Sequence[int]): Output from which stages.
|
| 756 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
| 757 |
+
-1 means not freezing any parameters.
|
| 758 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 759 |
+
"""
|
| 760 |
+
|
| 761 |
+
def __init__(self,
|
| 762 |
+
pretrain_img_size=224,
|
| 763 |
+
patch_size=4,
|
| 764 |
+
in_chans=3,
|
| 765 |
+
embed_dim=96,
|
| 766 |
+
depths=[2, 2, 6, 2],
|
| 767 |
+
num_heads=[3, 6, 12, 24],
|
| 768 |
+
window_size=7,
|
| 769 |
+
mlp_ratio=4.,
|
| 770 |
+
qkv_bias=True,
|
| 771 |
+
qk_scale=None,
|
| 772 |
+
drop_rate=0.,
|
| 773 |
+
attn_drop_rate=0.,
|
| 774 |
+
drop_path_rate=0.2,
|
| 775 |
+
norm_layer=nn.LayerNorm,
|
| 776 |
+
ape=False,
|
| 777 |
+
patch_norm=True,
|
| 778 |
+
out_indices=(0, 1, 2, 3),
|
| 779 |
+
frozen_stages=-1,
|
| 780 |
+
use_checkpoint=False):
|
| 781 |
+
super().__init__()
|
| 782 |
+
|
| 783 |
+
self.pretrain_img_size = pretrain_img_size
|
| 784 |
+
self.num_layers = len(depths)
|
| 785 |
+
self.embed_dim = embed_dim
|
| 786 |
+
self.ape = ape
|
| 787 |
+
self.patch_norm = patch_norm
|
| 788 |
+
self.out_indices = out_indices
|
| 789 |
+
self.frozen_stages = frozen_stages
|
| 790 |
+
|
| 791 |
+
# split image into non-overlapping patches
|
| 792 |
+
self.patch_embed = PatchEmbed(
|
| 793 |
+
patch_size=patch_size,
|
| 794 |
+
in_chans=in_chans,
|
| 795 |
+
embed_dim=embed_dim,
|
| 796 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
| 797 |
+
|
| 798 |
+
# absolute position embedding
|
| 799 |
+
if self.ape:
|
| 800 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
| 801 |
+
patch_size = to_2tuple(patch_size)
|
| 802 |
+
patches_resolution = [
|
| 803 |
+
pretrain_img_size[0] // patch_size[0],
|
| 804 |
+
pretrain_img_size[1] // patch_size[1]
|
| 805 |
+
]
|
| 806 |
+
|
| 807 |
+
self.absolute_pos_embed = nn.Parameter(
|
| 808 |
+
torch.zeros(1, embed_dim, patches_resolution[0],
|
| 809 |
+
patches_resolution[1]))
|
| 810 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
| 811 |
+
|
| 812 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 813 |
+
|
| 814 |
+
# stochastic depth
|
| 815 |
+
dpr = [
|
| 816 |
+
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
| 817 |
+
] # stochastic depth decay rule
|
| 818 |
+
|
| 819 |
+
# build layers
|
| 820 |
+
self.layers = nn.ModuleList()
|
| 821 |
+
for i_layer in range(self.num_layers):
|
| 822 |
+
layer = BasicLayer(
|
| 823 |
+
dim=int(embed_dim * 2**i_layer),
|
| 824 |
+
depth=depths[i_layer],
|
| 825 |
+
num_heads=num_heads[i_layer],
|
| 826 |
+
window_size=window_size,
|
| 827 |
+
mlp_ratio=mlp_ratio,
|
| 828 |
+
qkv_bias=qkv_bias,
|
| 829 |
+
qk_scale=qk_scale,
|
| 830 |
+
drop=drop_rate,
|
| 831 |
+
attn_drop=attn_drop_rate,
|
| 832 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
| 833 |
+
norm_layer=norm_layer,
|
| 834 |
+
downsample=PatchMerging if
|
| 835 |
+
(i_layer < self.num_layers - 1) else None,
|
| 836 |
+
use_checkpoint=use_checkpoint)
|
| 837 |
+
self.layers.append(layer)
|
| 838 |
+
|
| 839 |
+
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
|
| 840 |
+
self.num_features = num_features
|
| 841 |
+
|
| 842 |
+
# add a norm layer for each output
|
| 843 |
+
for i_layer in out_indices:
|
| 844 |
+
layer = norm_layer(num_features[i_layer])
|
| 845 |
+
layer_name = f'norm{i_layer}'
|
| 846 |
+
self.add_module(layer_name, layer)
|
| 847 |
+
|
| 848 |
+
self._freeze_stages()
|
| 849 |
+
|
| 850 |
+
def _freeze_stages(self):
|
| 851 |
+
if self.frozen_stages >= 0:
|
| 852 |
+
self.patch_embed.eval()
|
| 853 |
+
for param in self.patch_embed.parameters():
|
| 854 |
+
param.requires_grad = False
|
| 855 |
+
|
| 856 |
+
if self.frozen_stages >= 1 and self.ape:
|
| 857 |
+
self.absolute_pos_embed.requires_grad = False
|
| 858 |
+
|
| 859 |
+
if self.frozen_stages >= 2:
|
| 860 |
+
self.pos_drop.eval()
|
| 861 |
+
for i in range(0, self.frozen_stages - 1):
|
| 862 |
+
m = self.layers[i]
|
| 863 |
+
m.eval()
|
| 864 |
+
for param in m.parameters():
|
| 865 |
+
param.requires_grad = False
|
| 866 |
+
|
| 867 |
+
def init_weights(self, pretrained=None):
|
| 868 |
+
"""Initialize the weights in backbone.
|
| 869 |
+
|
| 870 |
+
Args:
|
| 871 |
+
pretrained (str, optional): Path to pre-trained weights.
|
| 872 |
+
Defaults to None.
|
| 873 |
+
"""
|
| 874 |
+
|
| 875 |
+
def _init_weights(m):
|
| 876 |
+
if isinstance(m, nn.Linear):
|
| 877 |
+
trunc_normal_(m.weight, std=.02)
|
| 878 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 879 |
+
nn.init.constant_(m.bias, 0)
|
| 880 |
+
elif isinstance(m, nn.LayerNorm):
|
| 881 |
+
nn.init.constant_(m.bias, 0)
|
| 882 |
+
nn.init.constant_(m.weight, 1.0)
|
| 883 |
+
|
| 884 |
+
if isinstance(pretrained, str):
|
| 885 |
+
self.apply(_init_weights)
|
| 886 |
+
load_checkpoint(self, pretrained, strict=False, logger=None)
|
| 887 |
+
elif pretrained is None:
|
| 888 |
+
self.apply(_init_weights)
|
| 889 |
+
else:
|
| 890 |
+
raise TypeError('pretrained must be a str or None')
|
| 891 |
+
|
| 892 |
+
def forward(self, x):
|
| 893 |
+
x = self.patch_embed(x)
|
| 894 |
+
|
| 895 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 896 |
+
if self.ape:
|
| 897 |
+
# interpolate the position embedding to the corresponding size
|
| 898 |
+
absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
|
| 899 |
+
size=(Wh, Ww),
|
| 900 |
+
mode='bicubic')
|
| 901 |
+
x = (x + absolute_pos_embed) # B Wh*Ww C
|
| 902 |
+
|
| 903 |
+
outs = [x.contiguous()]
|
| 904 |
+
x = x.flatten(2).transpose(1, 2)
|
| 905 |
+
x = self.pos_drop(x)
|
| 906 |
+
for i in range(self.num_layers):
|
| 907 |
+
layer = self.layers[i]
|
| 908 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
| 909 |
+
|
| 910 |
+
if i in self.out_indices:
|
| 911 |
+
norm_layer = getattr(self, f'norm{i}')
|
| 912 |
+
x_out = norm_layer(x_out)
|
| 913 |
+
|
| 914 |
+
out = x_out.view(-1, H, W,
|
| 915 |
+
self.num_features[i]).permute(0, 3, 1,
|
| 916 |
+
2).contiguous()
|
| 917 |
+
outs.append(out)
|
| 918 |
+
|
| 919 |
+
return tuple(outs)
|
| 920 |
+
|
| 921 |
+
def train(self, mode=True):
|
| 922 |
+
"""Convert the model into training mode while keep layers freezed."""
|
| 923 |
+
super(SwinTransformer, self).train(mode)
|
| 924 |
+
self._freeze_stages()
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
class PositionEmbeddingSine:
|
| 928 |
+
|
| 929 |
+
def __init__(self,
|
| 930 |
+
num_pos_feats=64,
|
| 931 |
+
temperature=10000,
|
| 932 |
+
normalize=False,
|
| 933 |
+
scale=None):
|
| 934 |
+
super().__init__()
|
| 935 |
+
self.num_pos_feats = num_pos_feats
|
| 936 |
+
self.temperature = temperature
|
| 937 |
+
self.normalize = normalize
|
| 938 |
+
if scale is not None and normalize is False:
|
| 939 |
+
raise ValueError("normalize should be True if scale is passed")
|
| 940 |
+
if scale is None:
|
| 941 |
+
scale = 2 * math.pi
|
| 942 |
+
self.scale = scale
|
| 943 |
+
self.dim_t = torch.arange(0,
|
| 944 |
+
self.num_pos_feats,
|
| 945 |
+
dtype=torch_dtype,
|
| 946 |
+
device=torch_device)
|
| 947 |
+
|
| 948 |
+
def __call__(self, b, h, w):
|
| 949 |
+
mask = torch.zeros([b, h, w], dtype=torch.bool, device=torch_device)
|
| 950 |
+
assert mask is not None
|
| 951 |
+
not_mask = ~mask
|
| 952 |
+
y_embed = not_mask.cumsum(dim=1, dtype=torch_dtype)
|
| 953 |
+
x_embed = not_mask.cumsum(dim=2, dtype=torch_dtype)
|
| 954 |
+
if self.normalize:
|
| 955 |
+
eps = 1e-6
|
| 956 |
+
y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) *
|
| 957 |
+
self.scale).to(device=torch_device, dtype=torch_dtype)
|
| 958 |
+
x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) *
|
| 959 |
+
self.scale).to(device=torch_device, dtype=torch_dtype)
|
| 960 |
+
|
| 961 |
+
dim_t = self.temperature**(2 * (self.dim_t // 2) / self.num_pos_feats)
|
| 962 |
+
|
| 963 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
| 964 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
| 965 |
+
pos_x = torch.stack(
|
| 966 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
|
| 967 |
+
dim=4).flatten(3)
|
| 968 |
+
pos_y = torch.stack(
|
| 969 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
|
| 970 |
+
dim=4).flatten(3)
|
| 971 |
+
return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 972 |
+
|
| 973 |
+
|
| 974 |
+
class MCLM(nn.Module):
|
| 975 |
+
|
| 976 |
+
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
|
| 977 |
+
super(MCLM, self).__init__()
|
| 978 |
+
self.attention = nn.ModuleList([
|
| 979 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 980 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 981 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 982 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 983 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 984 |
+
])
|
| 985 |
+
|
| 986 |
+
self.linear1 = nn.Linear(d_model, d_model * 2)
|
| 987 |
+
self.linear2 = nn.Linear(d_model * 2, d_model)
|
| 988 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 989 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 990 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 991 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 992 |
+
self.dropout = nn.Dropout(0.1)
|
| 993 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 994 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 995 |
+
self.activation = get_activation_fn('relu')
|
| 996 |
+
self.pool_ratios = pool_ratios
|
| 997 |
+
self.p_poses = []
|
| 998 |
+
self.g_pos = None
|
| 999 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 1000 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 1001 |
+
|
| 1002 |
+
def forward(self, l, g):
|
| 1003 |
+
"""
|
| 1004 |
+
l: 4,c,h,w
|
| 1005 |
+
g: 1,c,h,w
|
| 1006 |
+
"""
|
| 1007 |
+
b, c, h, w = l.size()
|
| 1008 |
+
# 4,c,h,w -> 1,c,2h,2w
|
| 1009 |
+
concated_locs = rearrange(l,
|
| 1010 |
+
'(hg wg b) c h w -> b c (hg h) (wg w)',
|
| 1011 |
+
hg=2,
|
| 1012 |
+
wg=2)
|
| 1013 |
+
|
| 1014 |
+
pools = []
|
| 1015 |
+
for pool_ratio in self.pool_ratios:
|
| 1016 |
+
# b,c,h,w
|
| 1017 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1018 |
+
pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
|
| 1019 |
+
pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
|
| 1020 |
+
if self.g_pos is None:
|
| 1021 |
+
pos_emb = self.positional_encoding(pool.shape[0],
|
| 1022 |
+
pool.shape[2],
|
| 1023 |
+
pool.shape[3])
|
| 1024 |
+
pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1025 |
+
self.p_poses.append(pos_emb)
|
| 1026 |
+
pools = torch.cat(pools, 0)
|
| 1027 |
+
if self.g_pos is None:
|
| 1028 |
+
self.p_poses = torch.cat(self.p_poses, dim=0)
|
| 1029 |
+
pos_emb = self.positional_encoding(g.shape[0], g.shape[2],
|
| 1030 |
+
g.shape[3])
|
| 1031 |
+
self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1032 |
+
|
| 1033 |
+
# attention between glb (q) & multisensory concated-locs (k,v)
|
| 1034 |
+
g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
|
| 1035 |
+
g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
|
| 1036 |
+
g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
|
| 1037 |
+
g_hw_b_c = self.norm1(g_hw_b_c)
|
| 1038 |
+
g_hw_b_c = g_hw_b_c + self.dropout2(
|
| 1039 |
+
self.linear2(
|
| 1040 |
+
self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
|
| 1041 |
+
g_hw_b_c = self.norm2(g_hw_b_c)
|
| 1042 |
+
|
| 1043 |
+
# attention between origin locs (q) & freashed glb (k,v)
|
| 1044 |
+
l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
|
| 1045 |
+
_g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
|
| 1046 |
+
_g_hw_b_c = rearrange(_g_hw_b_c,
|
| 1047 |
+
"(ng h) (nw w) b c -> (h w) (ng nw b) c",
|
| 1048 |
+
ng=2,
|
| 1049 |
+
nw=2)
|
| 1050 |
+
outputs_re = []
|
| 1051 |
+
for i, (_l, _g) in enumerate(
|
| 1052 |
+
zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
|
| 1053 |
+
outputs_re.append(self.attention[i + 1](_l, _g,
|
| 1054 |
+
_g)[0]) # (h w) 1 c
|
| 1055 |
+
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
|
| 1056 |
+
|
| 1057 |
+
l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
|
| 1058 |
+
l_hw_b_c = self.norm1(l_hw_b_c)
|
| 1059 |
+
l_hw_b_c = l_hw_b_c + self.dropout2(
|
| 1060 |
+
self.linear4(
|
| 1061 |
+
self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
|
| 1062 |
+
l_hw_b_c = self.norm2(l_hw_b_c)
|
| 1063 |
+
|
| 1064 |
+
l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
|
| 1065 |
+
return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
|
| 1066 |
+
|
| 1067 |
+
|
| 1068 |
+
class inf_MCLM(nn.Module):
|
| 1069 |
+
|
| 1070 |
+
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
|
| 1071 |
+
super(inf_MCLM, self).__init__()
|
| 1072 |
+
self.attention = nn.ModuleList([
|
| 1073 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1074 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1075 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1076 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1077 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 1078 |
+
])
|
| 1079 |
+
|
| 1080 |
+
self.linear1 = nn.Linear(d_model, d_model * 2)
|
| 1081 |
+
self.linear2 = nn.Linear(d_model * 2, d_model)
|
| 1082 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 1083 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 1084 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 1085 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 1086 |
+
self.dropout = nn.Dropout(0.1)
|
| 1087 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 1088 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 1089 |
+
self.activation = get_activation_fn('relu')
|
| 1090 |
+
self.pool_ratios = pool_ratios
|
| 1091 |
+
self.p_poses = []
|
| 1092 |
+
self.g_pos = None
|
| 1093 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 1094 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 1095 |
+
|
| 1096 |
+
def forward(self, l, g):
|
| 1097 |
+
"""
|
| 1098 |
+
l: 4,c,h,w
|
| 1099 |
+
g: 1,c,h,w
|
| 1100 |
+
"""
|
| 1101 |
+
b, c, h, w = l.size()
|
| 1102 |
+
# 4,c,h,w -> 1,c,2h,2w
|
| 1103 |
+
concated_locs = rearrange(l,
|
| 1104 |
+
'(hg wg b) c h w -> b c (hg h) (wg w)',
|
| 1105 |
+
hg=2,
|
| 1106 |
+
wg=2)
|
| 1107 |
+
self.p_poses = []
|
| 1108 |
+
pools = []
|
| 1109 |
+
for pool_ratio in self.pool_ratios:
|
| 1110 |
+
# b,c,h,w
|
| 1111 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1112 |
+
pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
|
| 1113 |
+
pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
|
| 1114 |
+
# if self.g_pos is None:
|
| 1115 |
+
pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2],
|
| 1116 |
+
pool.shape[3])
|
| 1117 |
+
pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1118 |
+
self.p_poses.append(pos_emb)
|
| 1119 |
+
pools = torch.cat(pools, 0)
|
| 1120 |
+
# if self.g_pos is None:
|
| 1121 |
+
self.p_poses = torch.cat(self.p_poses, dim=0)
|
| 1122 |
+
pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
|
| 1123 |
+
self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1124 |
+
|
| 1125 |
+
# attention between glb (q) & multisensory concated-locs (k,v)
|
| 1126 |
+
g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
|
| 1127 |
+
g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
|
| 1128 |
+
g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
|
| 1129 |
+
g_hw_b_c = self.norm1(g_hw_b_c)
|
| 1130 |
+
g_hw_b_c = g_hw_b_c + self.dropout2(
|
| 1131 |
+
self.linear2(
|
| 1132 |
+
self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
|
| 1133 |
+
g_hw_b_c = self.norm2(g_hw_b_c)
|
| 1134 |
+
|
| 1135 |
+
# attention between origin locs (q) & freashed glb (k,v)
|
| 1136 |
+
l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
|
| 1137 |
+
_g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
|
| 1138 |
+
_g_hw_b_c = rearrange(_g_hw_b_c,
|
| 1139 |
+
"(ng h) (nw w) b c -> (h w) (ng nw b) c",
|
| 1140 |
+
ng=2,
|
| 1141 |
+
nw=2)
|
| 1142 |
+
outputs_re = []
|
| 1143 |
+
for i, (_l, _g) in enumerate(
|
| 1144 |
+
zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
|
| 1145 |
+
outputs_re.append(self.attention[i + 1](_l, _g,
|
| 1146 |
+
_g)[0]) # (h w) 1 c
|
| 1147 |
+
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
|
| 1148 |
+
|
| 1149 |
+
l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
|
| 1150 |
+
l_hw_b_c = self.norm1(l_hw_b_c)
|
| 1151 |
+
l_hw_b_c = l_hw_b_c + self.dropout2(
|
| 1152 |
+
self.linear4(
|
| 1153 |
+
self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
|
| 1154 |
+
l_hw_b_c = self.norm2(l_hw_b_c)
|
| 1155 |
+
|
| 1156 |
+
l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
|
| 1157 |
+
return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
|
| 1158 |
+
|
| 1159 |
+
|
| 1160 |
+
class MCRM(nn.Module):
|
| 1161 |
+
|
| 1162 |
+
def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
|
| 1163 |
+
super(MCRM, self).__init__()
|
| 1164 |
+
self.attention = nn.ModuleList([
|
| 1165 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1166 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1167 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1168 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 1169 |
+
])
|
| 1170 |
+
|
| 1171 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 1172 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 1173 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 1174 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 1175 |
+
self.dropout = nn.Dropout(0.1)
|
| 1176 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 1177 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 1178 |
+
self.sigmoid = nn.Sigmoid()
|
| 1179 |
+
self.activation = get_activation_fn('relu')
|
| 1180 |
+
self.sal_conv = nn.Conv2d(d_model, 1, 1)
|
| 1181 |
+
self.pool_ratios = pool_ratios
|
| 1182 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 1183 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 1184 |
+
|
| 1185 |
+
def forward(self, x):
|
| 1186 |
+
b, c, h, w = x.size()
|
| 1187 |
+
loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
|
| 1188 |
+
# b(4),c,h,w
|
| 1189 |
+
patched_glb = rearrange(glb,
|
| 1190 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1191 |
+
hg=2,
|
| 1192 |
+
wg=2)
|
| 1193 |
+
|
| 1194 |
+
# generate token attention map
|
| 1195 |
+
token_attention_map = self.sigmoid(self.sal_conv(glb))
|
| 1196 |
+
token_attention_map = F.interpolate(token_attention_map,
|
| 1197 |
+
size=patches2image(loc).shape[-2:],
|
| 1198 |
+
mode='nearest')
|
| 1199 |
+
loc = loc * rearrange(token_attention_map,
|
| 1200 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1201 |
+
hg=2,
|
| 1202 |
+
wg=2)
|
| 1203 |
+
pools = []
|
| 1204 |
+
for pool_ratio in self.pool_ratios:
|
| 1205 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1206 |
+
pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
|
| 1207 |
+
pools.append(rearrange(pool,
|
| 1208 |
+
'nl c h w -> nl c (h w)')) # nl(4),c,hw
|
| 1209 |
+
# nl(4),c,nphw -> nl(4),nphw,1,c
|
| 1210 |
+
pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
|
| 1211 |
+
loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
|
| 1212 |
+
outputs = []
|
| 1213 |
+
for i, q in enumerate(
|
| 1214 |
+
loc_.unbind(dim=0)): # traverse all local patches
|
| 1215 |
+
# np*hw,1,c
|
| 1216 |
+
v = pools[i]
|
| 1217 |
+
k = v
|
| 1218 |
+
outputs.append(self.attention[i](q, k, v)[0])
|
| 1219 |
+
outputs = torch.cat(outputs, 1)
|
| 1220 |
+
src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
|
| 1221 |
+
src = self.norm1(src)
|
| 1222 |
+
src = src + self.dropout2(
|
| 1223 |
+
self.linear4(
|
| 1224 |
+
self.dropout(self.activation(self.linear3(src)).clone())))
|
| 1225 |
+
src = self.norm2(src)
|
| 1226 |
+
|
| 1227 |
+
src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
|
| 1228 |
+
glb = glb + F.interpolate(patches2image(src),
|
| 1229 |
+
size=glb.shape[-2:],
|
| 1230 |
+
mode='nearest') # freshed glb
|
| 1231 |
+
return torch.cat((src, glb), 0), token_attention_map
|
| 1232 |
+
|
| 1233 |
+
|
| 1234 |
+
class inf_MCRM(nn.Module):
|
| 1235 |
+
|
| 1236 |
+
def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
|
| 1237 |
+
super(inf_MCRM, self).__init__()
|
| 1238 |
+
self.attention = nn.ModuleList([
|
| 1239 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1240 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1241 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1242 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 1243 |
+
])
|
| 1244 |
+
|
| 1245 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 1246 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 1247 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 1248 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 1249 |
+
self.dropout = nn.Dropout(0.1)
|
| 1250 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 1251 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 1252 |
+
self.sigmoid = nn.Sigmoid()
|
| 1253 |
+
self.activation = get_activation_fn('relu')
|
| 1254 |
+
self.sal_conv = nn.Conv2d(d_model, 1, 1)
|
| 1255 |
+
self.pool_ratios = pool_ratios
|
| 1256 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 1257 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 1258 |
+
|
| 1259 |
+
def forward(self, x):
|
| 1260 |
+
b, c, h, w = x.size()
|
| 1261 |
+
loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
|
| 1262 |
+
# b(4),c,h,w
|
| 1263 |
+
patched_glb = rearrange(glb,
|
| 1264 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1265 |
+
hg=2,
|
| 1266 |
+
wg=2)
|
| 1267 |
+
|
| 1268 |
+
# generate token attention map
|
| 1269 |
+
token_attention_map = self.sigmoid(self.sal_conv(glb))
|
| 1270 |
+
token_attention_map = F.interpolate(token_attention_map,
|
| 1271 |
+
size=patches2image(loc).shape[-2:],
|
| 1272 |
+
mode='nearest')
|
| 1273 |
+
loc = loc * rearrange(token_attention_map,
|
| 1274 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1275 |
+
hg=2,
|
| 1276 |
+
wg=2)
|
| 1277 |
+
pools = []
|
| 1278 |
+
for pool_ratio in self.pool_ratios:
|
| 1279 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1280 |
+
pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
|
| 1281 |
+
pools.append(rearrange(pool,
|
| 1282 |
+
'nl c h w -> nl c (h w)')) # nl(4),c,hw
|
| 1283 |
+
# nl(4),c,nphw -> nl(4),nphw,1,c
|
| 1284 |
+
pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
|
| 1285 |
+
loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
|
| 1286 |
+
outputs = []
|
| 1287 |
+
for i, q in enumerate(
|
| 1288 |
+
loc_.unbind(dim=0)): # traverse all local patches
|
| 1289 |
+
# np*hw,1,c
|
| 1290 |
+
v = pools[i]
|
| 1291 |
+
k = v
|
| 1292 |
+
outputs.append(self.attention[i](q, k, v)[0])
|
| 1293 |
+
outputs = torch.cat(outputs, 1)
|
| 1294 |
+
src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
|
| 1295 |
+
src = self.norm1(src)
|
| 1296 |
+
src = src + self.dropout2(
|
| 1297 |
+
self.linear4(
|
| 1298 |
+
self.dropout(self.activation(self.linear3(src)).clone())))
|
| 1299 |
+
src = self.norm2(src)
|
| 1300 |
+
|
| 1301 |
+
src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
|
| 1302 |
+
glb = glb + F.interpolate(patches2image(src),
|
| 1303 |
+
size=glb.shape[-2:],
|
| 1304 |
+
mode='nearest') # freshed glb
|
| 1305 |
+
return torch.cat((src, glb), 0)
|
| 1306 |
+
|
| 1307 |
+
|
| 1308 |
+
# model for single-scale training
|
| 1309 |
+
class MVANet(nn.Module):
|
| 1310 |
+
|
| 1311 |
+
def __init__(self):
|
| 1312 |
+
super().__init__()
|
| 1313 |
+
self.backbone = SwinB(pretrained=True)
|
| 1314 |
+
emb_dim = 128
|
| 1315 |
+
self.sideout5 = nn.Sequential(
|
| 1316 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1317 |
+
self.sideout4 = nn.Sequential(
|
| 1318 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1319 |
+
self.sideout3 = nn.Sequential(
|
| 1320 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1321 |
+
self.sideout2 = nn.Sequential(
|
| 1322 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1323 |
+
self.sideout1 = nn.Sequential(
|
| 1324 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1325 |
+
|
| 1326 |
+
self.output5 = make_cbr(1024, emb_dim)
|
| 1327 |
+
self.output4 = make_cbr(512, emb_dim)
|
| 1328 |
+
self.output3 = make_cbr(256, emb_dim)
|
| 1329 |
+
self.output2 = make_cbr(128, emb_dim)
|
| 1330 |
+
self.output1 = make_cbr(128, emb_dim)
|
| 1331 |
+
|
| 1332 |
+
self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
|
| 1333 |
+
self.conv1 = make_cbr(emb_dim, emb_dim)
|
| 1334 |
+
self.conv2 = make_cbr(emb_dim, emb_dim)
|
| 1335 |
+
self.conv3 = make_cbr(emb_dim, emb_dim)
|
| 1336 |
+
self.conv4 = make_cbr(emb_dim, emb_dim)
|
| 1337 |
+
self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1338 |
+
self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1339 |
+
self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1340 |
+
self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1341 |
+
|
| 1342 |
+
self.insmask_head = nn.Sequential(
|
| 1343 |
+
nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
|
| 1344 |
+
nn.BatchNorm2d(384), nn.PReLU(),
|
| 1345 |
+
nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
|
| 1346 |
+
nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
|
| 1347 |
+
|
| 1348 |
+
self.shallow = nn.Sequential(
|
| 1349 |
+
nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
|
| 1350 |
+
self.upsample1 = make_cbg(emb_dim, emb_dim)
|
| 1351 |
+
self.upsample2 = make_cbg(emb_dim, emb_dim)
|
| 1352 |
+
self.output = nn.Sequential(
|
| 1353 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1354 |
+
|
| 1355 |
+
for m in self.modules():
|
| 1356 |
+
if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
|
| 1357 |
+
m.inplace = True
|
| 1358 |
+
|
| 1359 |
+
def forward(self, x):
|
| 1360 |
+
x = x.to(dtype=torch_dtype, device=torch_device)
|
| 1361 |
+
shallow = self.shallow(x)
|
| 1362 |
+
glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
|
| 1363 |
+
loc = image2patches(x)
|
| 1364 |
+
input = torch.cat((loc, glb), dim=0)
|
| 1365 |
+
feature = self.backbone(input)
|
| 1366 |
+
e5 = self.output5(feature[4]) # (5,128,16,16)
|
| 1367 |
+
e4 = self.output4(feature[3]) # (5,128,32,32)
|
| 1368 |
+
e3 = self.output3(feature[2]) # (5,128,64,64)
|
| 1369 |
+
e2 = self.output2(feature[1]) # (5,128,128,128)
|
| 1370 |
+
e1 = self.output1(feature[0]) # (5,128,128,128)
|
| 1371 |
+
loc_e5, glb_e5 = e5.split([4, 1], dim=0)
|
| 1372 |
+
e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
|
| 1373 |
+
|
| 1374 |
+
e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
|
| 1375 |
+
e4 = self.conv4(e4)
|
| 1376 |
+
e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
|
| 1377 |
+
e3 = self.conv3(e3)
|
| 1378 |
+
e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
|
| 1379 |
+
e2 = self.conv2(e2)
|
| 1380 |
+
e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
|
| 1381 |
+
e1 = self.conv1(e1)
|
| 1382 |
+
loc_e1, glb_e1 = e1.split([4, 1], dim=0)
|
| 1383 |
+
output1_cat = patches2image(loc_e1) # (1,128,256,256)
|
| 1384 |
+
# add glb feat in
|
| 1385 |
+
output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
|
| 1386 |
+
# merge
|
| 1387 |
+
final_output = self.insmask_head(output1_cat) # (1,128,256,256)
|
| 1388 |
+
# shallow feature merge
|
| 1389 |
+
final_output = final_output + resize_as(shallow, final_output)
|
| 1390 |
+
final_output = self.upsample1(rescale_to(final_output))
|
| 1391 |
+
final_output = rescale_to(final_output +
|
| 1392 |
+
resize_as(shallow, final_output))
|
| 1393 |
+
final_output = self.upsample2(final_output)
|
| 1394 |
+
final_output = self.output(final_output)
|
| 1395 |
+
####
|
| 1396 |
+
sideout5 = self.sideout5(e5).to(dtype=torch_dtype, device=torch_device)
|
| 1397 |
+
sideout4 = self.sideout4(e4)
|
| 1398 |
+
sideout3 = self.sideout3(e3)
|
| 1399 |
+
sideout2 = self.sideout2(e2)
|
| 1400 |
+
sideout1 = self.sideout1(e1)
|
| 1401 |
+
#######glb_sideouts ######
|
| 1402 |
+
glb5 = self.sideout5(glb_e5)
|
| 1403 |
+
glb4 = sideout4[-1, :, :, :].unsqueeze(0)
|
| 1404 |
+
glb3 = sideout3[-1, :, :, :].unsqueeze(0)
|
| 1405 |
+
glb2 = sideout2[-1, :, :, :].unsqueeze(0)
|
| 1406 |
+
glb1 = sideout1[-1, :, :, :].unsqueeze(0)
|
| 1407 |
+
####### concat 4 to 1 #######
|
| 1408 |
+
sideout1 = patches2image(sideout1[:-1]).to(dtype=torch_dtype,
|
| 1409 |
+
device=torch_device)
|
| 1410 |
+
sideout2 = patches2image(sideout2[:-1]).to(
|
| 1411 |
+
dtype=torch_dtype,
|
| 1412 |
+
device=torch_device) ####(5,c,h,w) -> (1 c 2h,2w)
|
| 1413 |
+
sideout3 = patches2image(sideout3[:-1]).to(dtype=torch_dtype,
|
| 1414 |
+
device=torch_device)
|
| 1415 |
+
sideout4 = patches2image(sideout4[:-1]).to(dtype=torch_dtype,
|
| 1416 |
+
device=torch_device)
|
| 1417 |
+
sideout5 = patches2image(sideout5[:-1]).to(dtype=torch_dtype,
|
| 1418 |
+
device=torch_device)
|
| 1419 |
+
if self.training:
|
| 1420 |
+
return sideout5, sideout4, sideout3, sideout2, sideout1, final_output, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1
|
| 1421 |
+
else:
|
| 1422 |
+
return final_output
|
| 1423 |
+
|
| 1424 |
+
|
| 1425 |
+
# model for multi-scale testing
|
| 1426 |
+
class inf_MVANet(nn.Module):
|
| 1427 |
+
|
| 1428 |
+
def __init__(self):
|
| 1429 |
+
super().__init__()
|
| 1430 |
+
# self.backbone = SwinB(pretrained=True)
|
| 1431 |
+
self.backbone = SwinB(pretrained=False)
|
| 1432 |
+
|
| 1433 |
+
emb_dim = 128
|
| 1434 |
+
self.output5 = make_cbr(1024, emb_dim)
|
| 1435 |
+
self.output4 = make_cbr(512, emb_dim)
|
| 1436 |
+
self.output3 = make_cbr(256, emb_dim)
|
| 1437 |
+
self.output2 = make_cbr(128, emb_dim)
|
| 1438 |
+
self.output1 = make_cbr(128, emb_dim)
|
| 1439 |
+
|
| 1440 |
+
self.multifieldcrossatt = inf_MCLM(emb_dim, 1, [1, 4, 8])
|
| 1441 |
+
self.conv1 = make_cbr(emb_dim, emb_dim)
|
| 1442 |
+
self.conv2 = make_cbr(emb_dim, emb_dim)
|
| 1443 |
+
self.conv3 = make_cbr(emb_dim, emb_dim)
|
| 1444 |
+
self.conv4 = make_cbr(emb_dim, emb_dim)
|
| 1445 |
+
self.dec_blk1 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1446 |
+
self.dec_blk2 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1447 |
+
self.dec_blk3 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1448 |
+
self.dec_blk4 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1449 |
+
|
| 1450 |
+
self.insmask_head = nn.Sequential(
|
| 1451 |
+
nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
|
| 1452 |
+
nn.BatchNorm2d(384), nn.PReLU(),
|
| 1453 |
+
nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
|
| 1454 |
+
nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
|
| 1455 |
+
|
| 1456 |
+
self.shallow = nn.Sequential(
|
| 1457 |
+
nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
|
| 1458 |
+
self.upsample1 = make_cbg(emb_dim, emb_dim)
|
| 1459 |
+
self.upsample2 = make_cbg(emb_dim, emb_dim)
|
| 1460 |
+
self.output = nn.Sequential(
|
| 1461 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1462 |
+
|
| 1463 |
+
for m in self.modules():
|
| 1464 |
+
if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
|
| 1465 |
+
m.inplace = True
|
| 1466 |
+
|
| 1467 |
+
def forward(self, x):
|
| 1468 |
+
shallow = self.shallow(x)
|
| 1469 |
+
glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
|
| 1470 |
+
loc = image2patches(x)
|
| 1471 |
+
input = torch.cat((loc, glb), dim=0)
|
| 1472 |
+
feature = self.backbone(input)
|
| 1473 |
+
e5 = self.output5(feature[4])
|
| 1474 |
+
e4 = self.output4(feature[3])
|
| 1475 |
+
e3 = self.output3(feature[2])
|
| 1476 |
+
e2 = self.output2(feature[1])
|
| 1477 |
+
e1 = self.output1(feature[0])
|
| 1478 |
+
loc_e5, glb_e5 = e5.split([4, 1], dim=0)
|
| 1479 |
+
e5_cat = self.multifieldcrossatt(loc_e5, glb_e5)
|
| 1480 |
+
|
| 1481 |
+
e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
|
| 1482 |
+
e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
|
| 1483 |
+
e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
|
| 1484 |
+
e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))
|
| 1485 |
+
loc_e1, glb_e1 = e1.split([4, 1], dim=0)
|
| 1486 |
+
# after decoder, concat loc features to a whole one, and merge
|
| 1487 |
+
output1_cat = patches2image(loc_e1)
|
| 1488 |
+
# add glb feat in
|
| 1489 |
+
output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
|
| 1490 |
+
# merge
|
| 1491 |
+
final_output = self.insmask_head(output1_cat)
|
| 1492 |
+
# shallow feature merge
|
| 1493 |
+
final_output = final_output + resize_as(shallow, final_output)
|
| 1494 |
+
final_output = self.upsample1(rescale_to(final_output))
|
| 1495 |
+
final_output = rescale_to(final_output +
|
| 1496 |
+
resize_as(shallow, final_output))
|
| 1497 |
+
final_output = self.upsample2(final_output)
|
| 1498 |
+
final_output = self.output(final_output)
|
| 1499 |
+
return final_output
|
| 1500 |
+
#+end_src
|
| 1501 |
+
|
| 1502 |
+
** Function to load model
|
| 1503 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
|
| 1504 |
+
def mkdir_safe(out_path):
|
| 1505 |
+
if type(out_path) == str:
|
| 1506 |
+
if len(out_path) > 0:
|
| 1507 |
+
if not os.path.exists(out_path):
|
| 1508 |
+
os.mkdir(out_path)
|
| 1509 |
+
|
| 1510 |
+
|
| 1511 |
+
def get_model_path():
|
| 1512 |
+
import folder_paths
|
| 1513 |
+
from folder_paths import models_dir
|
| 1514 |
+
|
| 1515 |
+
path_file_model = models_dir
|
| 1516 |
+
mkdir_safe(out_path=path_file_model)
|
| 1517 |
+
|
| 1518 |
+
path_file_model = os.path.join(path_file_model, 'MVANet')
|
| 1519 |
+
mkdir_safe(out_path=path_file_model)
|
| 1520 |
+
|
| 1521 |
+
path_file_model = os.path.join(path_file_model, 'Model_80.pth')
|
| 1522 |
+
|
| 1523 |
+
return path_file_model
|
| 1524 |
+
|
| 1525 |
+
|
| 1526 |
+
def download_model(path):
|
| 1527 |
+
if not os.path.exists(path):
|
| 1528 |
+
wget.download(
|
| 1529 |
+
'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/Model_80.pth',
|
| 1530 |
+
out=path)
|
| 1531 |
+
|
| 1532 |
+
|
| 1533 |
+
def load_model(model_checkpoint_path):
|
| 1534 |
+
download_model(path=model_checkpoint_path)
|
| 1535 |
+
torch.cuda.set_device(0)
|
| 1536 |
+
|
| 1537 |
+
net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
|
| 1538 |
+
|
| 1539 |
+
pretrained_dict = torch.load(finetuned_MVANet_model_path,
|
| 1540 |
+
map_location=torch_device)
|
| 1541 |
+
|
| 1542 |
+
model_dict = net.state_dict()
|
| 1543 |
+
pretrained_dict = {
|
| 1544 |
+
k: v
|
| 1545 |
+
for k, v in pretrained_dict.items() if k in model_dict
|
| 1546 |
+
}
|
| 1547 |
+
model_dict.update(pretrained_dict)
|
| 1548 |
+
net.load_state_dict(model_dict)
|
| 1549 |
+
net = net.to(dtype=torch_dtype, device=torch_device)
|
| 1550 |
+
net.eval()
|
| 1551 |
+
return net
|
| 1552 |
+
#+end_src
|
| 1553 |
+
|
| 1554 |
+
** Function for modular inference CV
|
| 1555 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
|
| 1556 |
+
def do_infer_tensor2tensor(img, net):
|
| 1557 |
+
|
| 1558 |
+
img_transform = transforms.Compose(
|
| 1559 |
+
[transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
| 1560 |
+
|
| 1561 |
+
h_, w_ = img.shape[1], img.shape[2]
|
| 1562 |
+
|
| 1563 |
+
with torch.no_grad():
|
| 1564 |
+
|
| 1565 |
+
img = rearrange(img, 'B H W C -> B C H W')
|
| 1566 |
+
|
| 1567 |
+
img_resize = torch.nn.functional.interpolate(input=img,
|
| 1568 |
+
size=(1024, 1024),
|
| 1569 |
+
mode='bicubic',
|
| 1570 |
+
antialias=True)
|
| 1571 |
+
|
| 1572 |
+
img_var = img_transform(img_resize)
|
| 1573 |
+
img_var = Variable(img_var)
|
| 1574 |
+
img_var = img_var.to(dtype=torch_dtype, device=torch_device)
|
| 1575 |
+
|
| 1576 |
+
mask = []
|
| 1577 |
+
|
| 1578 |
+
mask.append(net(img_var))
|
| 1579 |
+
|
| 1580 |
+
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
|
| 1581 |
+
prediction = prediction.sigmoid()
|
| 1582 |
+
|
| 1583 |
+
prediction = torch.nn.functional.interpolate(input=prediction,
|
| 1584 |
+
size=(h_, w_),
|
| 1585 |
+
mode='bicubic',
|
| 1586 |
+
antialias=True)
|
| 1587 |
+
|
| 1588 |
+
prediction = prediction.squeeze(0)
|
| 1589 |
+
prediction = prediction.clamp(0, 1)
|
| 1590 |
+
prediction = prediction.detach()
|
| 1591 |
+
prediction = prediction.to(dtype=torch.float32, device='cpu')
|
| 1592 |
+
|
| 1593 |
+
return prediction
|
| 1594 |
+
#+end_src
|
| 1595 |
+
|
| 1596 |
+
** Comfyui wrapper classes
|
| 1597 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
|
| 1598 |
+
class load_MVANet_Model:
|
| 1599 |
+
|
| 1600 |
+
def __init__(self):
|
| 1601 |
+
pass
|
| 1602 |
+
|
| 1603 |
+
@classmethod
|
| 1604 |
+
def INPUT_TYPES(s):
|
| 1605 |
+
return {
|
| 1606 |
+
"required": {},
|
| 1607 |
+
}
|
| 1608 |
+
|
| 1609 |
+
RETURN_TYPES = ("MVANet_Model", )
|
| 1610 |
+
FUNCTION = "test"
|
| 1611 |
+
CATEGORY = "MVANet"
|
| 1612 |
+
|
| 1613 |
+
def test(self):
|
| 1614 |
+
return (load_model(get_model_path()), )
|
| 1615 |
+
|
| 1616 |
+
|
| 1617 |
+
class run_MVANet_inference:
|
| 1618 |
+
|
| 1619 |
+
def __init__(self):
|
| 1620 |
+
pass
|
| 1621 |
+
|
| 1622 |
+
@classmethod
|
| 1623 |
+
def INPUT_TYPES(s):
|
| 1624 |
+
return {
|
| 1625 |
+
"required": {
|
| 1626 |
+
"image": ("IMAGE", ),
|
| 1627 |
+
"MVANet_Model": ("MVANet_Model", ),
|
| 1628 |
+
},
|
| 1629 |
+
}
|
| 1630 |
+
|
| 1631 |
+
RETURN_TYPES = ("MASK", )
|
| 1632 |
+
FUNCTION = "test"
|
| 1633 |
+
CATEGORY = "MVANet"
|
| 1634 |
+
|
| 1635 |
+
def test(
|
| 1636 |
+
self,
|
| 1637 |
+
image,
|
| 1638 |
+
MVANet_Model,
|
| 1639 |
+
):
|
| 1640 |
+
ret = do_infer_tensor2tensor(img=image, net=MVANet_Model)
|
| 1641 |
+
|
| 1642 |
+
return (ret, )
|
| 1643 |
+
#+end_src
|
| 1644 |
+
|
| 1645 |
+
** MVANet_inference execute
|
| 1646 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
|
| 1647 |
+
NODE_CLASS_MAPPINGS = {
|
| 1648 |
+
"load_MVANet_Model": load_MVANet_Model,
|
| 1649 |
+
"run_MVANet_inference": run_MVANet_inference
|
| 1650 |
+
}
|
| 1651 |
+
|
| 1652 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
| 1653 |
+
"load_MVANet_Model": "load MVANet Model",
|
| 1654 |
+
"run_MVANet_inference": "run_MVANet_inference"
|
| 1655 |
+
}
|
| 1656 |
+
#+end_src
|
| 1657 |
+
|
| 1658 |
+
** MVANet_inference unify
|
| 1659 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.unify.sh
|
| 1660 |
+
. "${HOME}/dbnew.sh"
|
| 1661 |
+
|
| 1662 |
+
(
|
| 1663 |
+
echo '#!/usr/bin/python3'
|
| 1664 |
+
cat \
|
| 1665 |
+
'./MVANet_inference.import.py' \
|
| 1666 |
+
'./MVANet_inference.function.py' \
|
| 1667 |
+
'./MVANet_inference.class.py' \
|
| 1668 |
+
'./MVANet_inference.execute.py' \
|
| 1669 |
+
| expand | yapf3 \
|
| 1670 |
+
| grep -v '#!/usr/bin/python3' \
|
| 1671 |
+
;
|
| 1672 |
+
) > './MVANet_inference.py' \
|
| 1673 |
+
;
|
| 1674 |
+
|
| 1675 |
+
cp './MVANet_inference.py' '__init__.py'
|
| 1676 |
+
#+end_src
|
| 1677 |
+
|
| 1678 |
+
* WORK SPACE
|
| 1679 |
+
|
| 1680 |
+
** elisp
|
| 1681 |
+
#+begin_src elisp
|
| 1682 |
+
(save-buffer)
|
| 1683 |
+
(org-babel-tangle)
|
| 1684 |
+
(shell-command "./MVANet_inference.unify.sh")
|
| 1685 |
+
#+end_src
|
| 1686 |
+
|
| 1687 |
+
#+RESULTS:
|
| 1688 |
+
: 0
|
| 1689 |
+
|
| 1690 |
+
** sh
|
| 1691 |
+
#+begin_src sh :shebang #!/bin/sh :results output
|
| 1692 |
+
realpath .
|
| 1693 |
+
cd /home/asd/GITHUB/aravind-h-v/dreambooth_experiments/MVANet
|
| 1694 |
+
#+end_src
|
ComfyUI_MVANet/__init__.py
ADDED
|
@@ -0,0 +1,1548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
HOME_DIR = os.environ.get('HOME', '/root')
|
| 6 |
+
MVANET_SOURCE_DIR = HOME_DIR + '/GITHUB/qianyu-dlut/MVANet'
|
| 7 |
+
finetuned_MVANet_model_path = MVANET_SOURCE_DIR + '/model/Model_80.pth'
|
| 8 |
+
pretrained_SwinB_model_path = MVANET_SOURCE_DIR + '/model/swin_base_patch4_window12_384_22kto1k.pth'
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import numpy as np
|
| 12 |
+
import cv2
|
| 13 |
+
import wget
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import torch.utils.checkpoint as checkpoint
|
| 19 |
+
from torch.autograd import Variable
|
| 20 |
+
from torch import nn
|
| 21 |
+
from torchvision import transforms
|
| 22 |
+
|
| 23 |
+
from einops import rearrange
|
| 24 |
+
|
| 25 |
+
from timm.models import load_checkpoint
|
| 26 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 27 |
+
|
| 28 |
+
torch_device = 'cuda'
|
| 29 |
+
torch_dtype = torch.float16
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def check_mkdir(dir_name):
|
| 33 |
+
if not os.path.isdir(dir_name):
|
| 34 |
+
os.makedirs(dir_name)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def SwinT(pretrained=True):
|
| 38 |
+
model = SwinTransformer(embed_dim=96,
|
| 39 |
+
depths=[2, 2, 6, 2],
|
| 40 |
+
num_heads=[3, 6, 12, 24],
|
| 41 |
+
window_size=7)
|
| 42 |
+
if pretrained is True:
|
| 43 |
+
model.load_state_dict(torch.load(
|
| 44 |
+
'data/backbone_ckpt/swin_tiny_patch4_window7_224.pth',
|
| 45 |
+
map_location='cpu')['model'],
|
| 46 |
+
strict=False)
|
| 47 |
+
|
| 48 |
+
return model
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def SwinS(pretrained=True):
|
| 52 |
+
model = SwinTransformer(embed_dim=96,
|
| 53 |
+
depths=[2, 2, 18, 2],
|
| 54 |
+
num_heads=[3, 6, 12, 24],
|
| 55 |
+
window_size=7)
|
| 56 |
+
if pretrained is True:
|
| 57 |
+
model.load_state_dict(torch.load(
|
| 58 |
+
'data/backbone_ckpt/swin_small_patch4_window7_224.pth',
|
| 59 |
+
map_location='cpu')['model'],
|
| 60 |
+
strict=False)
|
| 61 |
+
|
| 62 |
+
return model
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def SwinB(pretrained=True):
|
| 66 |
+
model = SwinTransformer(embed_dim=128,
|
| 67 |
+
depths=[2, 2, 18, 2],
|
| 68 |
+
num_heads=[4, 8, 16, 32],
|
| 69 |
+
window_size=12)
|
| 70 |
+
if pretrained is True:
|
| 71 |
+
import os
|
| 72 |
+
model.load_state_dict(torch.load(pretrained_SwinB_model_path,
|
| 73 |
+
map_location='cpu')['model'],
|
| 74 |
+
strict=False)
|
| 75 |
+
return model
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def SwinL(pretrained=True):
|
| 79 |
+
model = SwinTransformer(embed_dim=192,
|
| 80 |
+
depths=[2, 2, 18, 2],
|
| 81 |
+
num_heads=[6, 12, 24, 48],
|
| 82 |
+
window_size=12)
|
| 83 |
+
if pretrained is True:
|
| 84 |
+
model.load_state_dict(torch.load(
|
| 85 |
+
'data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth',
|
| 86 |
+
map_location='cpu')['model'],
|
| 87 |
+
strict=False)
|
| 88 |
+
|
| 89 |
+
return model
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def get_activation_fn(activation):
|
| 93 |
+
"""Return an activation function given a string"""
|
| 94 |
+
if activation == "relu":
|
| 95 |
+
return F.relu
|
| 96 |
+
if activation == "gelu":
|
| 97 |
+
return F.gelu
|
| 98 |
+
if activation == "glu":
|
| 99 |
+
return F.glu
|
| 100 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def make_cbr(in_dim, out_dim):
|
| 104 |
+
return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
|
| 105 |
+
nn.BatchNorm2d(out_dim), nn.PReLU())
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def make_cbg(in_dim, out_dim):
|
| 109 |
+
return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
|
| 110 |
+
nn.BatchNorm2d(out_dim), nn.GELU())
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
|
| 114 |
+
return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def resize_as(x, y, interpolation='bilinear'):
|
| 118 |
+
return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def image2patches(x):
|
| 122 |
+
"""b c (hg h) (wg w) -> (hg wg b) c h w"""
|
| 123 |
+
x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
| 124 |
+
return x
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def patches2image(x):
|
| 128 |
+
"""(hg wg b) c h w -> b c (hg h) (wg w)"""
|
| 129 |
+
x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def window_partition(x, window_size):
|
| 134 |
+
"""
|
| 135 |
+
Args:
|
| 136 |
+
x: (B, H, W, C)
|
| 137 |
+
window_size (int): window size
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 141 |
+
"""
|
| 142 |
+
B, H, W, C = x.shape
|
| 143 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size,
|
| 144 |
+
C)
|
| 145 |
+
windows = x.permute(0, 1, 3, 2, 4,
|
| 146 |
+
5).contiguous().view(-1, window_size, window_size, C)
|
| 147 |
+
return windows
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def window_reverse(windows, window_size, H, W):
|
| 151 |
+
"""
|
| 152 |
+
Args:
|
| 153 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 154 |
+
window_size (int): Window size
|
| 155 |
+
H (int): Height of image
|
| 156 |
+
W (int): Width of image
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
x: (B, H, W, C)
|
| 160 |
+
"""
|
| 161 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
| 162 |
+
x = windows.view(B, H // window_size, W // window_size, window_size,
|
| 163 |
+
window_size, -1)
|
| 164 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
| 165 |
+
return x
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def mkdir_safe(out_path):
|
| 169 |
+
if type(out_path) == str:
|
| 170 |
+
if len(out_path) > 0:
|
| 171 |
+
if not os.path.exists(out_path):
|
| 172 |
+
os.mkdir(out_path)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def get_model_path():
|
| 176 |
+
import folder_paths
|
| 177 |
+
from folder_paths import models_dir
|
| 178 |
+
|
| 179 |
+
path_file_model = models_dir
|
| 180 |
+
mkdir_safe(out_path=path_file_model)
|
| 181 |
+
|
| 182 |
+
path_file_model = os.path.join(path_file_model, 'MVANet')
|
| 183 |
+
mkdir_safe(out_path=path_file_model)
|
| 184 |
+
|
| 185 |
+
path_file_model = os.path.join(path_file_model, 'Model_80.pth')
|
| 186 |
+
|
| 187 |
+
return path_file_model
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def download_model(path):
|
| 191 |
+
if not os.path.exists(path):
|
| 192 |
+
wget.download(
|
| 193 |
+
'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/Model_80.pth',
|
| 194 |
+
out=path)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def load_model(model_checkpoint_path):
|
| 198 |
+
download_model(path=model_checkpoint_path)
|
| 199 |
+
torch.cuda.set_device(0)
|
| 200 |
+
|
| 201 |
+
net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
|
| 202 |
+
|
| 203 |
+
pretrained_dict = torch.load(finetuned_MVANet_model_path,
|
| 204 |
+
map_location=torch_device)
|
| 205 |
+
|
| 206 |
+
model_dict = net.state_dict()
|
| 207 |
+
pretrained_dict = {
|
| 208 |
+
k: v
|
| 209 |
+
for k, v in pretrained_dict.items() if k in model_dict
|
| 210 |
+
}
|
| 211 |
+
model_dict.update(pretrained_dict)
|
| 212 |
+
net.load_state_dict(model_dict)
|
| 213 |
+
net = net.to(dtype=torch_dtype, device=torch_device)
|
| 214 |
+
net.eval()
|
| 215 |
+
return net
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def do_infer_tensor2tensor(img, net):
|
| 219 |
+
|
| 220 |
+
img_transform = transforms.Compose(
|
| 221 |
+
[transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
| 222 |
+
|
| 223 |
+
h_, w_ = img.shape[1], img.shape[2]
|
| 224 |
+
|
| 225 |
+
with torch.no_grad():
|
| 226 |
+
|
| 227 |
+
img = rearrange(img, 'B H W C -> B C H W')
|
| 228 |
+
|
| 229 |
+
img_resize = torch.nn.functional.interpolate(input=img,
|
| 230 |
+
size=(1024, 1024),
|
| 231 |
+
mode='bicubic',
|
| 232 |
+
antialias=True)
|
| 233 |
+
|
| 234 |
+
img_var = img_transform(img_resize)
|
| 235 |
+
img_var = Variable(img_var)
|
| 236 |
+
img_var = img_var.to(dtype=torch_dtype, device=torch_device)
|
| 237 |
+
|
| 238 |
+
mask = []
|
| 239 |
+
|
| 240 |
+
mask.append(net(img_var))
|
| 241 |
+
|
| 242 |
+
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
|
| 243 |
+
prediction = prediction.sigmoid()
|
| 244 |
+
|
| 245 |
+
prediction = torch.nn.functional.interpolate(input=prediction,
|
| 246 |
+
size=(h_, w_),
|
| 247 |
+
mode='bicubic',
|
| 248 |
+
antialias=True)
|
| 249 |
+
|
| 250 |
+
prediction = prediction.squeeze(0)
|
| 251 |
+
prediction = prediction.clamp(0, 1)
|
| 252 |
+
prediction = prediction.detach()
|
| 253 |
+
prediction = prediction.to(dtype=torch.float32, device='cpu')
|
| 254 |
+
|
| 255 |
+
return prediction
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class Mlp(nn.Module):
|
| 259 |
+
""" Multilayer perceptron."""
|
| 260 |
+
|
| 261 |
+
def __init__(self,
|
| 262 |
+
in_features,
|
| 263 |
+
hidden_features=None,
|
| 264 |
+
out_features=None,
|
| 265 |
+
act_layer=nn.GELU,
|
| 266 |
+
drop=0.):
|
| 267 |
+
super().__init__()
|
| 268 |
+
out_features = out_features or in_features
|
| 269 |
+
hidden_features = hidden_features or in_features
|
| 270 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 271 |
+
self.act = act_layer()
|
| 272 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 273 |
+
self.drop = nn.Dropout(drop)
|
| 274 |
+
|
| 275 |
+
def forward(self, x):
|
| 276 |
+
x = self.fc1(x)
|
| 277 |
+
x = self.act(x)
|
| 278 |
+
x = self.drop(x)
|
| 279 |
+
x = self.fc2(x)
|
| 280 |
+
x = self.drop(x)
|
| 281 |
+
return x
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class WindowAttention(nn.Module):
|
| 285 |
+
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
| 286 |
+
It supports both of shifted and non-shifted window.
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
dim (int): Number of input channels.
|
| 290 |
+
window_size (tuple[int]): The height and width of the window.
|
| 291 |
+
num_heads (int): Number of attention heads.
|
| 292 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 293 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
| 294 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
| 295 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
| 296 |
+
"""
|
| 297 |
+
|
| 298 |
+
def __init__(self,
|
| 299 |
+
dim,
|
| 300 |
+
window_size,
|
| 301 |
+
num_heads,
|
| 302 |
+
qkv_bias=True,
|
| 303 |
+
qk_scale=None,
|
| 304 |
+
attn_drop=0.,
|
| 305 |
+
proj_drop=0.):
|
| 306 |
+
|
| 307 |
+
super().__init__()
|
| 308 |
+
self.dim = dim
|
| 309 |
+
self.window_size = window_size # Wh, Ww
|
| 310 |
+
self.num_heads = num_heads
|
| 311 |
+
head_dim = dim // num_heads
|
| 312 |
+
self.scale = qk_scale or head_dim**-0.5
|
| 313 |
+
|
| 314 |
+
# define a parameter table of relative position bias
|
| 315 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 316 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
|
| 317 |
+
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 318 |
+
|
| 319 |
+
# get pair-wise relative position index for each token inside the window
|
| 320 |
+
coords_h = torch.arange(self.window_size[0])
|
| 321 |
+
coords_w = torch.arange(self.window_size[1])
|
| 322 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 323 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 324 |
+
relative_coords = coords_flatten[:, :,
|
| 325 |
+
None] - coords_flatten[:,
|
| 326 |
+
None, :] # 2, Wh*Ww, Wh*Ww
|
| 327 |
+
relative_coords = relative_coords.permute(
|
| 328 |
+
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 329 |
+
relative_coords[:, :,
|
| 330 |
+
0] += self.window_size[0] - 1 # shift to start from 0
|
| 331 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 332 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
| 333 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 334 |
+
self.register_buffer("relative_position_index",
|
| 335 |
+
relative_position_index)
|
| 336 |
+
|
| 337 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 338 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 339 |
+
self.proj = nn.Linear(dim, dim)
|
| 340 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 341 |
+
|
| 342 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
| 343 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 344 |
+
|
| 345 |
+
def forward(self, x, mask=None):
|
| 346 |
+
""" Forward function.
|
| 347 |
+
|
| 348 |
+
Args:
|
| 349 |
+
x: input features with shape of (num_windows*B, N, C)
|
| 350 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
| 351 |
+
"""
|
| 352 |
+
x = x.to(dtype=torch_dtype, device=torch_device)
|
| 353 |
+
B_, N, C = x.shape
|
| 354 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
|
| 355 |
+
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 356 |
+
q, k, v = qkv[0], qkv[1], qkv[
|
| 357 |
+
2] # make torchscript happy (cannot use tensor as tuple)
|
| 358 |
+
|
| 359 |
+
q = q * self.scale
|
| 360 |
+
attn = (q @ k.transpose(-2, -1))
|
| 361 |
+
|
| 362 |
+
relative_position_bias = self.relative_position_bias_table[
|
| 363 |
+
self.relative_position_index.view(-1)].view(
|
| 364 |
+
self.window_size[0] * self.window_size[1],
|
| 365 |
+
self.window_size[0] * self.window_size[1],
|
| 366 |
+
-1) # Wh*Ww,Wh*Ww,nH
|
| 367 |
+
relative_position_bias = relative_position_bias.permute(
|
| 368 |
+
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 369 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
| 370 |
+
|
| 371 |
+
if mask is not None:
|
| 372 |
+
nW = mask.shape[0]
|
| 373 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N,
|
| 374 |
+
N) + mask.unsqueeze(1).unsqueeze(0)
|
| 375 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
| 376 |
+
attn = self.softmax(attn)
|
| 377 |
+
else:
|
| 378 |
+
attn = self.softmax(attn)
|
| 379 |
+
|
| 380 |
+
attn = self.attn_drop(attn)
|
| 381 |
+
attn = attn.to(dtype=torch_dtype, device=torch_device)
|
| 382 |
+
v = v.to(dtype=torch_dtype, device=torch_device)
|
| 383 |
+
|
| 384 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
| 385 |
+
x = self.proj(x)
|
| 386 |
+
x = self.proj_drop(x)
|
| 387 |
+
return x
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
class SwinTransformerBlock(nn.Module):
|
| 391 |
+
""" Swin Transformer Block.
|
| 392 |
+
|
| 393 |
+
Args:
|
| 394 |
+
dim (int): Number of input channels.
|
| 395 |
+
num_heads (int): Number of attention heads.
|
| 396 |
+
window_size (int): Window size.
|
| 397 |
+
shift_size (int): Shift size for SW-MSA.
|
| 398 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 399 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 400 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 401 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 402 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 403 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
| 404 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
| 405 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 406 |
+
"""
|
| 407 |
+
|
| 408 |
+
def __init__(self,
|
| 409 |
+
dim,
|
| 410 |
+
num_heads,
|
| 411 |
+
window_size=7,
|
| 412 |
+
shift_size=0,
|
| 413 |
+
mlp_ratio=4.,
|
| 414 |
+
qkv_bias=True,
|
| 415 |
+
qk_scale=None,
|
| 416 |
+
drop=0.,
|
| 417 |
+
attn_drop=0.,
|
| 418 |
+
drop_path=0.,
|
| 419 |
+
act_layer=nn.GELU,
|
| 420 |
+
norm_layer=nn.LayerNorm):
|
| 421 |
+
super().__init__()
|
| 422 |
+
self.dim = dim
|
| 423 |
+
self.num_heads = num_heads
|
| 424 |
+
self.window_size = window_size
|
| 425 |
+
self.shift_size = shift_size
|
| 426 |
+
self.mlp_ratio = mlp_ratio
|
| 427 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
| 428 |
+
|
| 429 |
+
self.norm1 = norm_layer(dim)
|
| 430 |
+
self.attn = WindowAttention(dim,
|
| 431 |
+
window_size=to_2tuple(self.window_size),
|
| 432 |
+
num_heads=num_heads,
|
| 433 |
+
qkv_bias=qkv_bias,
|
| 434 |
+
qk_scale=qk_scale,
|
| 435 |
+
attn_drop=attn_drop,
|
| 436 |
+
proj_drop=drop)
|
| 437 |
+
|
| 438 |
+
self.drop_path = DropPath(
|
| 439 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
| 440 |
+
self.norm2 = norm_layer(dim)
|
| 441 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 442 |
+
self.mlp = Mlp(in_features=dim,
|
| 443 |
+
hidden_features=mlp_hidden_dim,
|
| 444 |
+
act_layer=act_layer,
|
| 445 |
+
drop=drop)
|
| 446 |
+
|
| 447 |
+
self.H = None
|
| 448 |
+
self.W = None
|
| 449 |
+
|
| 450 |
+
def forward(self, x, mask_matrix):
|
| 451 |
+
""" Forward function.
|
| 452 |
+
|
| 453 |
+
Args:
|
| 454 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 455 |
+
H, W: Spatial resolution of the input feature.
|
| 456 |
+
mask_matrix: Attention mask for cyclic shift.
|
| 457 |
+
"""
|
| 458 |
+
B, L, C = x.shape
|
| 459 |
+
H, W = self.H, self.W
|
| 460 |
+
assert L == H * W, "input feature has wrong size"
|
| 461 |
+
|
| 462 |
+
shortcut = x
|
| 463 |
+
x = self.norm1(x)
|
| 464 |
+
x = x.view(B, H, W, C)
|
| 465 |
+
|
| 466 |
+
# pad feature maps to multiples of window size
|
| 467 |
+
pad_l = pad_t = 0
|
| 468 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
| 469 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
| 470 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
| 471 |
+
_, Hp, Wp, _ = x.shape
|
| 472 |
+
|
| 473 |
+
# cyclic shift
|
| 474 |
+
if self.shift_size > 0:
|
| 475 |
+
shifted_x = torch.roll(x,
|
| 476 |
+
shifts=(-self.shift_size, -self.shift_size),
|
| 477 |
+
dims=(1, 2))
|
| 478 |
+
attn_mask = mask_matrix
|
| 479 |
+
else:
|
| 480 |
+
shifted_x = x
|
| 481 |
+
attn_mask = None
|
| 482 |
+
|
| 483 |
+
# partition windows
|
| 484 |
+
x_windows = window_partition(
|
| 485 |
+
shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
| 486 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size,
|
| 487 |
+
C) # nW*B, window_size*window_size, C
|
| 488 |
+
|
| 489 |
+
# W-MSA/SW-MSA
|
| 490 |
+
attn_windows = self.attn(
|
| 491 |
+
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
| 492 |
+
|
| 493 |
+
# merge windows
|
| 494 |
+
attn_windows = attn_windows.view(-1, self.window_size,
|
| 495 |
+
self.window_size, C)
|
| 496 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp,
|
| 497 |
+
Wp) # B H' W' C
|
| 498 |
+
|
| 499 |
+
# reverse cyclic shift
|
| 500 |
+
if self.shift_size > 0:
|
| 501 |
+
x = torch.roll(shifted_x,
|
| 502 |
+
shifts=(self.shift_size, self.shift_size),
|
| 503 |
+
dims=(1, 2))
|
| 504 |
+
else:
|
| 505 |
+
x = shifted_x
|
| 506 |
+
|
| 507 |
+
if pad_r > 0 or pad_b > 0:
|
| 508 |
+
x = x[:, :H, :W, :].contiguous()
|
| 509 |
+
|
| 510 |
+
x = x.view(B, H * W, C)
|
| 511 |
+
|
| 512 |
+
# FFN
|
| 513 |
+
x = shortcut + self.drop_path(x)
|
| 514 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 515 |
+
|
| 516 |
+
return x
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class PatchMerging(nn.Module):
|
| 520 |
+
""" Patch Merging Layer
|
| 521 |
+
|
| 522 |
+
Args:
|
| 523 |
+
dim (int): Number of input channels.
|
| 524 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 525 |
+
"""
|
| 526 |
+
|
| 527 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
| 528 |
+
super().__init__()
|
| 529 |
+
self.dim = dim
|
| 530 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
| 531 |
+
self.norm = norm_layer(4 * dim)
|
| 532 |
+
|
| 533 |
+
def forward(self, x, H, W):
|
| 534 |
+
""" Forward function.
|
| 535 |
+
|
| 536 |
+
Args:
|
| 537 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 538 |
+
H, W: Spatial resolution of the input feature.
|
| 539 |
+
"""
|
| 540 |
+
B, L, C = x.shape
|
| 541 |
+
assert L == H * W, "input feature has wrong size"
|
| 542 |
+
|
| 543 |
+
x = x.view(B, H, W, C)
|
| 544 |
+
|
| 545 |
+
# padding
|
| 546 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
| 547 |
+
if pad_input:
|
| 548 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
| 549 |
+
|
| 550 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
| 551 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
| 552 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
| 553 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
| 554 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
| 555 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
| 556 |
+
|
| 557 |
+
x = self.norm(x)
|
| 558 |
+
x = self.reduction(x)
|
| 559 |
+
|
| 560 |
+
return x
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
class BasicLayer(nn.Module):
|
| 564 |
+
""" A basic Swin Transformer layer for one stage.
|
| 565 |
+
|
| 566 |
+
Args:
|
| 567 |
+
dim (int): Number of feature channels
|
| 568 |
+
depth (int): Depths of this stage.
|
| 569 |
+
num_heads (int): Number of attention head.
|
| 570 |
+
window_size (int): Local window size. Default: 7.
|
| 571 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 572 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 573 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 574 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 575 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 576 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
| 577 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 578 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
| 579 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 580 |
+
"""
|
| 581 |
+
|
| 582 |
+
def __init__(self,
|
| 583 |
+
dim,
|
| 584 |
+
depth,
|
| 585 |
+
num_heads,
|
| 586 |
+
window_size=7,
|
| 587 |
+
mlp_ratio=4.,
|
| 588 |
+
qkv_bias=True,
|
| 589 |
+
qk_scale=None,
|
| 590 |
+
drop=0.,
|
| 591 |
+
attn_drop=0.,
|
| 592 |
+
drop_path=0.,
|
| 593 |
+
norm_layer=nn.LayerNorm,
|
| 594 |
+
downsample=None,
|
| 595 |
+
use_checkpoint=False):
|
| 596 |
+
super().__init__()
|
| 597 |
+
self.window_size = window_size
|
| 598 |
+
self.shift_size = window_size // 2
|
| 599 |
+
self.depth = depth
|
| 600 |
+
self.use_checkpoint = use_checkpoint
|
| 601 |
+
|
| 602 |
+
# build blocks
|
| 603 |
+
self.blocks = nn.ModuleList([
|
| 604 |
+
SwinTransformerBlock(dim=dim,
|
| 605 |
+
num_heads=num_heads,
|
| 606 |
+
window_size=window_size,
|
| 607 |
+
shift_size=0 if
|
| 608 |
+
(i % 2 == 0) else window_size // 2,
|
| 609 |
+
mlp_ratio=mlp_ratio,
|
| 610 |
+
qkv_bias=qkv_bias,
|
| 611 |
+
qk_scale=qk_scale,
|
| 612 |
+
drop=drop,
|
| 613 |
+
attn_drop=attn_drop,
|
| 614 |
+
drop_path=drop_path[i] if isinstance(
|
| 615 |
+
drop_path, list) else drop_path,
|
| 616 |
+
norm_layer=norm_layer) for i in range(depth)
|
| 617 |
+
])
|
| 618 |
+
|
| 619 |
+
# patch merging layer
|
| 620 |
+
if downsample is not None:
|
| 621 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
| 622 |
+
else:
|
| 623 |
+
self.downsample = None
|
| 624 |
+
|
| 625 |
+
def forward(self, x, H, W):
|
| 626 |
+
""" Forward function.
|
| 627 |
+
|
| 628 |
+
Args:
|
| 629 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 630 |
+
H, W: Spatial resolution of the input feature.
|
| 631 |
+
"""
|
| 632 |
+
|
| 633 |
+
# calculate attention mask for SW-MSA
|
| 634 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
| 635 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
| 636 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
| 637 |
+
h_slices = (slice(0, -self.window_size),
|
| 638 |
+
slice(-self.window_size,
|
| 639 |
+
-self.shift_size), slice(-self.shift_size, None))
|
| 640 |
+
w_slices = (slice(0, -self.window_size),
|
| 641 |
+
slice(-self.window_size,
|
| 642 |
+
-self.shift_size), slice(-self.shift_size, None))
|
| 643 |
+
cnt = 0
|
| 644 |
+
for h in h_slices:
|
| 645 |
+
for w in w_slices:
|
| 646 |
+
img_mask[:, h, w, :] = cnt
|
| 647 |
+
cnt += 1
|
| 648 |
+
|
| 649 |
+
mask_windows = window_partition(
|
| 650 |
+
img_mask, self.window_size) # nW, window_size, window_size, 1
|
| 651 |
+
mask_windows = mask_windows.view(-1,
|
| 652 |
+
self.window_size * self.window_size)
|
| 653 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 654 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
| 655 |
+
float(-100.0)).masked_fill(
|
| 656 |
+
attn_mask == 0, float(0.0))
|
| 657 |
+
|
| 658 |
+
for blk in self.blocks:
|
| 659 |
+
blk.H, blk.W = H, W
|
| 660 |
+
if self.use_checkpoint:
|
| 661 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
| 662 |
+
else:
|
| 663 |
+
x = blk(x, attn_mask)
|
| 664 |
+
if self.downsample is not None:
|
| 665 |
+
x_down = self.downsample(x, H, W)
|
| 666 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
| 667 |
+
return x, H, W, x_down, Wh, Ww
|
| 668 |
+
else:
|
| 669 |
+
return x, H, W, x, H, W
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
class PatchEmbed(nn.Module):
|
| 673 |
+
""" Image to Patch Embedding
|
| 674 |
+
|
| 675 |
+
Args:
|
| 676 |
+
patch_size (int): Patch token size. Default: 4.
|
| 677 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 678 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 679 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
| 680 |
+
"""
|
| 681 |
+
|
| 682 |
+
def __init__(self,
|
| 683 |
+
patch_size=4,
|
| 684 |
+
in_chans=3,
|
| 685 |
+
embed_dim=96,
|
| 686 |
+
norm_layer=None):
|
| 687 |
+
super().__init__()
|
| 688 |
+
patch_size = to_2tuple(patch_size)
|
| 689 |
+
self.patch_size = patch_size
|
| 690 |
+
|
| 691 |
+
self.in_chans = in_chans
|
| 692 |
+
self.embed_dim = embed_dim
|
| 693 |
+
|
| 694 |
+
self.proj = nn.Conv2d(in_chans,
|
| 695 |
+
embed_dim,
|
| 696 |
+
kernel_size=patch_size,
|
| 697 |
+
stride=patch_size)
|
| 698 |
+
if norm_layer is not None:
|
| 699 |
+
self.norm = norm_layer(embed_dim)
|
| 700 |
+
else:
|
| 701 |
+
self.norm = None
|
| 702 |
+
|
| 703 |
+
def forward(self, x):
|
| 704 |
+
"""Forward function."""
|
| 705 |
+
# padding
|
| 706 |
+
_, _, H, W = x.size()
|
| 707 |
+
if W % self.patch_size[1] != 0:
|
| 708 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
| 709 |
+
if H % self.patch_size[0] != 0:
|
| 710 |
+
x = F.pad(x,
|
| 711 |
+
(0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
| 712 |
+
|
| 713 |
+
x = self.proj(x) # B C Wh Ww
|
| 714 |
+
if self.norm is not None:
|
| 715 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 716 |
+
x = x.flatten(2).transpose(1, 2)
|
| 717 |
+
x = self.norm(x)
|
| 718 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
| 719 |
+
|
| 720 |
+
return x
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
class SwinTransformer(nn.Module):
|
| 724 |
+
""" Swin Transformer backbone.
|
| 725 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
| 726 |
+
https://arxiv.org/pdf/2103.14030
|
| 727 |
+
|
| 728 |
+
Args:
|
| 729 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
| 730 |
+
used in absolute postion embedding. Default 224.
|
| 731 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
| 732 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 733 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 734 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
| 735 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
| 736 |
+
window_size (int): Window size. Default: 7.
|
| 737 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 738 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
| 739 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
| 740 |
+
drop_rate (float): Dropout rate.
|
| 741 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
| 742 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
| 743 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
| 744 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
| 745 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
| 746 |
+
out_indices (Sequence[int]): Output from which stages.
|
| 747 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
| 748 |
+
-1 means not freezing any parameters.
|
| 749 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 750 |
+
"""
|
| 751 |
+
|
| 752 |
+
def __init__(self,
|
| 753 |
+
pretrain_img_size=224,
|
| 754 |
+
patch_size=4,
|
| 755 |
+
in_chans=3,
|
| 756 |
+
embed_dim=96,
|
| 757 |
+
depths=[2, 2, 6, 2],
|
| 758 |
+
num_heads=[3, 6, 12, 24],
|
| 759 |
+
window_size=7,
|
| 760 |
+
mlp_ratio=4.,
|
| 761 |
+
qkv_bias=True,
|
| 762 |
+
qk_scale=None,
|
| 763 |
+
drop_rate=0.,
|
| 764 |
+
attn_drop_rate=0.,
|
| 765 |
+
drop_path_rate=0.2,
|
| 766 |
+
norm_layer=nn.LayerNorm,
|
| 767 |
+
ape=False,
|
| 768 |
+
patch_norm=True,
|
| 769 |
+
out_indices=(0, 1, 2, 3),
|
| 770 |
+
frozen_stages=-1,
|
| 771 |
+
use_checkpoint=False):
|
| 772 |
+
super().__init__()
|
| 773 |
+
|
| 774 |
+
self.pretrain_img_size = pretrain_img_size
|
| 775 |
+
self.num_layers = len(depths)
|
| 776 |
+
self.embed_dim = embed_dim
|
| 777 |
+
self.ape = ape
|
| 778 |
+
self.patch_norm = patch_norm
|
| 779 |
+
self.out_indices = out_indices
|
| 780 |
+
self.frozen_stages = frozen_stages
|
| 781 |
+
|
| 782 |
+
# split image into non-overlapping patches
|
| 783 |
+
self.patch_embed = PatchEmbed(
|
| 784 |
+
patch_size=patch_size,
|
| 785 |
+
in_chans=in_chans,
|
| 786 |
+
embed_dim=embed_dim,
|
| 787 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
| 788 |
+
|
| 789 |
+
# absolute position embedding
|
| 790 |
+
if self.ape:
|
| 791 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
| 792 |
+
patch_size = to_2tuple(patch_size)
|
| 793 |
+
patches_resolution = [
|
| 794 |
+
pretrain_img_size[0] // patch_size[0],
|
| 795 |
+
pretrain_img_size[1] // patch_size[1]
|
| 796 |
+
]
|
| 797 |
+
|
| 798 |
+
self.absolute_pos_embed = nn.Parameter(
|
| 799 |
+
torch.zeros(1, embed_dim, patches_resolution[0],
|
| 800 |
+
patches_resolution[1]))
|
| 801 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
| 802 |
+
|
| 803 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 804 |
+
|
| 805 |
+
# stochastic depth
|
| 806 |
+
dpr = [
|
| 807 |
+
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
| 808 |
+
] # stochastic depth decay rule
|
| 809 |
+
|
| 810 |
+
# build layers
|
| 811 |
+
self.layers = nn.ModuleList()
|
| 812 |
+
for i_layer in range(self.num_layers):
|
| 813 |
+
layer = BasicLayer(
|
| 814 |
+
dim=int(embed_dim * 2**i_layer),
|
| 815 |
+
depth=depths[i_layer],
|
| 816 |
+
num_heads=num_heads[i_layer],
|
| 817 |
+
window_size=window_size,
|
| 818 |
+
mlp_ratio=mlp_ratio,
|
| 819 |
+
qkv_bias=qkv_bias,
|
| 820 |
+
qk_scale=qk_scale,
|
| 821 |
+
drop=drop_rate,
|
| 822 |
+
attn_drop=attn_drop_rate,
|
| 823 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
| 824 |
+
norm_layer=norm_layer,
|
| 825 |
+
downsample=PatchMerging if
|
| 826 |
+
(i_layer < self.num_layers - 1) else None,
|
| 827 |
+
use_checkpoint=use_checkpoint)
|
| 828 |
+
self.layers.append(layer)
|
| 829 |
+
|
| 830 |
+
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
|
| 831 |
+
self.num_features = num_features
|
| 832 |
+
|
| 833 |
+
# add a norm layer for each output
|
| 834 |
+
for i_layer in out_indices:
|
| 835 |
+
layer = norm_layer(num_features[i_layer])
|
| 836 |
+
layer_name = f'norm{i_layer}'
|
| 837 |
+
self.add_module(layer_name, layer)
|
| 838 |
+
|
| 839 |
+
self._freeze_stages()
|
| 840 |
+
|
| 841 |
+
def _freeze_stages(self):
|
| 842 |
+
if self.frozen_stages >= 0:
|
| 843 |
+
self.patch_embed.eval()
|
| 844 |
+
for param in self.patch_embed.parameters():
|
| 845 |
+
param.requires_grad = False
|
| 846 |
+
|
| 847 |
+
if self.frozen_stages >= 1 and self.ape:
|
| 848 |
+
self.absolute_pos_embed.requires_grad = False
|
| 849 |
+
|
| 850 |
+
if self.frozen_stages >= 2:
|
| 851 |
+
self.pos_drop.eval()
|
| 852 |
+
for i in range(0, self.frozen_stages - 1):
|
| 853 |
+
m = self.layers[i]
|
| 854 |
+
m.eval()
|
| 855 |
+
for param in m.parameters():
|
| 856 |
+
param.requires_grad = False
|
| 857 |
+
|
| 858 |
+
def init_weights(self, pretrained=None):
|
| 859 |
+
"""Initialize the weights in backbone.
|
| 860 |
+
|
| 861 |
+
Args:
|
| 862 |
+
pretrained (str, optional): Path to pre-trained weights.
|
| 863 |
+
Defaults to None.
|
| 864 |
+
"""
|
| 865 |
+
|
| 866 |
+
def _init_weights(m):
|
| 867 |
+
if isinstance(m, nn.Linear):
|
| 868 |
+
trunc_normal_(m.weight, std=.02)
|
| 869 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 870 |
+
nn.init.constant_(m.bias, 0)
|
| 871 |
+
elif isinstance(m, nn.LayerNorm):
|
| 872 |
+
nn.init.constant_(m.bias, 0)
|
| 873 |
+
nn.init.constant_(m.weight, 1.0)
|
| 874 |
+
|
| 875 |
+
if isinstance(pretrained, str):
|
| 876 |
+
self.apply(_init_weights)
|
| 877 |
+
load_checkpoint(self, pretrained, strict=False, logger=None)
|
| 878 |
+
elif pretrained is None:
|
| 879 |
+
self.apply(_init_weights)
|
| 880 |
+
else:
|
| 881 |
+
raise TypeError('pretrained must be a str or None')
|
| 882 |
+
|
| 883 |
+
def forward(self, x):
|
| 884 |
+
x = self.patch_embed(x)
|
| 885 |
+
|
| 886 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 887 |
+
if self.ape:
|
| 888 |
+
# interpolate the position embedding to the corresponding size
|
| 889 |
+
absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
|
| 890 |
+
size=(Wh, Ww),
|
| 891 |
+
mode='bicubic')
|
| 892 |
+
x = (x + absolute_pos_embed) # B Wh*Ww C
|
| 893 |
+
|
| 894 |
+
outs = [x.contiguous()]
|
| 895 |
+
x = x.flatten(2).transpose(1, 2)
|
| 896 |
+
x = self.pos_drop(x)
|
| 897 |
+
for i in range(self.num_layers):
|
| 898 |
+
layer = self.layers[i]
|
| 899 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
| 900 |
+
|
| 901 |
+
if i in self.out_indices:
|
| 902 |
+
norm_layer = getattr(self, f'norm{i}')
|
| 903 |
+
x_out = norm_layer(x_out)
|
| 904 |
+
|
| 905 |
+
out = x_out.view(-1, H, W,
|
| 906 |
+
self.num_features[i]).permute(0, 3, 1,
|
| 907 |
+
2).contiguous()
|
| 908 |
+
outs.append(out)
|
| 909 |
+
|
| 910 |
+
return tuple(outs)
|
| 911 |
+
|
| 912 |
+
def train(self, mode=True):
|
| 913 |
+
"""Convert the model into training mode while keep layers freezed."""
|
| 914 |
+
super(SwinTransformer, self).train(mode)
|
| 915 |
+
self._freeze_stages()
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
class PositionEmbeddingSine:
|
| 919 |
+
|
| 920 |
+
def __init__(self,
|
| 921 |
+
num_pos_feats=64,
|
| 922 |
+
temperature=10000,
|
| 923 |
+
normalize=False,
|
| 924 |
+
scale=None):
|
| 925 |
+
super().__init__()
|
| 926 |
+
self.num_pos_feats = num_pos_feats
|
| 927 |
+
self.temperature = temperature
|
| 928 |
+
self.normalize = normalize
|
| 929 |
+
if scale is not None and normalize is False:
|
| 930 |
+
raise ValueError("normalize should be True if scale is passed")
|
| 931 |
+
if scale is None:
|
| 932 |
+
scale = 2 * math.pi
|
| 933 |
+
self.scale = scale
|
| 934 |
+
self.dim_t = torch.arange(0,
|
| 935 |
+
self.num_pos_feats,
|
| 936 |
+
dtype=torch_dtype,
|
| 937 |
+
device=torch_device)
|
| 938 |
+
|
| 939 |
+
def __call__(self, b, h, w):
|
| 940 |
+
mask = torch.zeros([b, h, w], dtype=torch.bool, device=torch_device)
|
| 941 |
+
assert mask is not None
|
| 942 |
+
not_mask = ~mask
|
| 943 |
+
y_embed = not_mask.cumsum(dim=1, dtype=torch_dtype)
|
| 944 |
+
x_embed = not_mask.cumsum(dim=2, dtype=torch_dtype)
|
| 945 |
+
if self.normalize:
|
| 946 |
+
eps = 1e-6
|
| 947 |
+
y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) *
|
| 948 |
+
self.scale).to(device=torch_device, dtype=torch_dtype)
|
| 949 |
+
x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) *
|
| 950 |
+
self.scale).to(device=torch_device, dtype=torch_dtype)
|
| 951 |
+
|
| 952 |
+
dim_t = self.temperature**(2 * (self.dim_t // 2) / self.num_pos_feats)
|
| 953 |
+
|
| 954 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
| 955 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
| 956 |
+
pos_x = torch.stack(
|
| 957 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
|
| 958 |
+
dim=4).flatten(3)
|
| 959 |
+
pos_y = torch.stack(
|
| 960 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
|
| 961 |
+
dim=4).flatten(3)
|
| 962 |
+
return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
class MCLM(nn.Module):
|
| 966 |
+
|
| 967 |
+
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
|
| 968 |
+
super(MCLM, self).__init__()
|
| 969 |
+
self.attention = nn.ModuleList([
|
| 970 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 971 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 972 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 973 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 974 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 975 |
+
])
|
| 976 |
+
|
| 977 |
+
self.linear1 = nn.Linear(d_model, d_model * 2)
|
| 978 |
+
self.linear2 = nn.Linear(d_model * 2, d_model)
|
| 979 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 980 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 981 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 982 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 983 |
+
self.dropout = nn.Dropout(0.1)
|
| 984 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 985 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 986 |
+
self.activation = get_activation_fn('relu')
|
| 987 |
+
self.pool_ratios = pool_ratios
|
| 988 |
+
self.p_poses = []
|
| 989 |
+
self.g_pos = None
|
| 990 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 991 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 992 |
+
|
| 993 |
+
def forward(self, l, g):
|
| 994 |
+
"""
|
| 995 |
+
l: 4,c,h,w
|
| 996 |
+
g: 1,c,h,w
|
| 997 |
+
"""
|
| 998 |
+
b, c, h, w = l.size()
|
| 999 |
+
# 4,c,h,w -> 1,c,2h,2w
|
| 1000 |
+
concated_locs = rearrange(l,
|
| 1001 |
+
'(hg wg b) c h w -> b c (hg h) (wg w)',
|
| 1002 |
+
hg=2,
|
| 1003 |
+
wg=2)
|
| 1004 |
+
|
| 1005 |
+
pools = []
|
| 1006 |
+
for pool_ratio in self.pool_ratios:
|
| 1007 |
+
# b,c,h,w
|
| 1008 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1009 |
+
pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
|
| 1010 |
+
pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
|
| 1011 |
+
if self.g_pos is None:
|
| 1012 |
+
pos_emb = self.positional_encoding(pool.shape[0],
|
| 1013 |
+
pool.shape[2],
|
| 1014 |
+
pool.shape[3])
|
| 1015 |
+
pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1016 |
+
self.p_poses.append(pos_emb)
|
| 1017 |
+
pools = torch.cat(pools, 0)
|
| 1018 |
+
if self.g_pos is None:
|
| 1019 |
+
self.p_poses = torch.cat(self.p_poses, dim=0)
|
| 1020 |
+
pos_emb = self.positional_encoding(g.shape[0], g.shape[2],
|
| 1021 |
+
g.shape[3])
|
| 1022 |
+
self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1023 |
+
|
| 1024 |
+
# attention between glb (q) & multisensory concated-locs (k,v)
|
| 1025 |
+
g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
|
| 1026 |
+
g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
|
| 1027 |
+
g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
|
| 1028 |
+
g_hw_b_c = self.norm1(g_hw_b_c)
|
| 1029 |
+
g_hw_b_c = g_hw_b_c + self.dropout2(
|
| 1030 |
+
self.linear2(
|
| 1031 |
+
self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
|
| 1032 |
+
g_hw_b_c = self.norm2(g_hw_b_c)
|
| 1033 |
+
|
| 1034 |
+
# attention between origin locs (q) & freashed glb (k,v)
|
| 1035 |
+
l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
|
| 1036 |
+
_g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
|
| 1037 |
+
_g_hw_b_c = rearrange(_g_hw_b_c,
|
| 1038 |
+
"(ng h) (nw w) b c -> (h w) (ng nw b) c",
|
| 1039 |
+
ng=2,
|
| 1040 |
+
nw=2)
|
| 1041 |
+
outputs_re = []
|
| 1042 |
+
for i, (_l, _g) in enumerate(
|
| 1043 |
+
zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
|
| 1044 |
+
outputs_re.append(self.attention[i + 1](_l, _g,
|
| 1045 |
+
_g)[0]) # (h w) 1 c
|
| 1046 |
+
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
|
| 1047 |
+
|
| 1048 |
+
l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
|
| 1049 |
+
l_hw_b_c = self.norm1(l_hw_b_c)
|
| 1050 |
+
l_hw_b_c = l_hw_b_c + self.dropout2(
|
| 1051 |
+
self.linear4(
|
| 1052 |
+
self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
|
| 1053 |
+
l_hw_b_c = self.norm2(l_hw_b_c)
|
| 1054 |
+
|
| 1055 |
+
l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
|
| 1056 |
+
return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
class inf_MCLM(nn.Module):
|
| 1060 |
+
|
| 1061 |
+
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
|
| 1062 |
+
super(inf_MCLM, self).__init__()
|
| 1063 |
+
self.attention = nn.ModuleList([
|
| 1064 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1065 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1066 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1067 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1068 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 1069 |
+
])
|
| 1070 |
+
|
| 1071 |
+
self.linear1 = nn.Linear(d_model, d_model * 2)
|
| 1072 |
+
self.linear2 = nn.Linear(d_model * 2, d_model)
|
| 1073 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 1074 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 1075 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 1076 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 1077 |
+
self.dropout = nn.Dropout(0.1)
|
| 1078 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 1079 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 1080 |
+
self.activation = get_activation_fn('relu')
|
| 1081 |
+
self.pool_ratios = pool_ratios
|
| 1082 |
+
self.p_poses = []
|
| 1083 |
+
self.g_pos = None
|
| 1084 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 1085 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 1086 |
+
|
| 1087 |
+
def forward(self, l, g):
|
| 1088 |
+
"""
|
| 1089 |
+
l: 4,c,h,w
|
| 1090 |
+
g: 1,c,h,w
|
| 1091 |
+
"""
|
| 1092 |
+
b, c, h, w = l.size()
|
| 1093 |
+
# 4,c,h,w -> 1,c,2h,2w
|
| 1094 |
+
concated_locs = rearrange(l,
|
| 1095 |
+
'(hg wg b) c h w -> b c (hg h) (wg w)',
|
| 1096 |
+
hg=2,
|
| 1097 |
+
wg=2)
|
| 1098 |
+
self.p_poses = []
|
| 1099 |
+
pools = []
|
| 1100 |
+
for pool_ratio in self.pool_ratios:
|
| 1101 |
+
# b,c,h,w
|
| 1102 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1103 |
+
pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
|
| 1104 |
+
pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
|
| 1105 |
+
# if self.g_pos is None:
|
| 1106 |
+
pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2],
|
| 1107 |
+
pool.shape[3])
|
| 1108 |
+
pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1109 |
+
self.p_poses.append(pos_emb)
|
| 1110 |
+
pools = torch.cat(pools, 0)
|
| 1111 |
+
# if self.g_pos is None:
|
| 1112 |
+
self.p_poses = torch.cat(self.p_poses, dim=0)
|
| 1113 |
+
pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
|
| 1114 |
+
self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1115 |
+
|
| 1116 |
+
# attention between glb (q) & multisensory concated-locs (k,v)
|
| 1117 |
+
g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
|
| 1118 |
+
g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
|
| 1119 |
+
g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
|
| 1120 |
+
g_hw_b_c = self.norm1(g_hw_b_c)
|
| 1121 |
+
g_hw_b_c = g_hw_b_c + self.dropout2(
|
| 1122 |
+
self.linear2(
|
| 1123 |
+
self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
|
| 1124 |
+
g_hw_b_c = self.norm2(g_hw_b_c)
|
| 1125 |
+
|
| 1126 |
+
# attention between origin locs (q) & freashed glb (k,v)
|
| 1127 |
+
l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
|
| 1128 |
+
_g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
|
| 1129 |
+
_g_hw_b_c = rearrange(_g_hw_b_c,
|
| 1130 |
+
"(ng h) (nw w) b c -> (h w) (ng nw b) c",
|
| 1131 |
+
ng=2,
|
| 1132 |
+
nw=2)
|
| 1133 |
+
outputs_re = []
|
| 1134 |
+
for i, (_l, _g) in enumerate(
|
| 1135 |
+
zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
|
| 1136 |
+
outputs_re.append(self.attention[i + 1](_l, _g,
|
| 1137 |
+
_g)[0]) # (h w) 1 c
|
| 1138 |
+
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
|
| 1139 |
+
|
| 1140 |
+
l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
|
| 1141 |
+
l_hw_b_c = self.norm1(l_hw_b_c)
|
| 1142 |
+
l_hw_b_c = l_hw_b_c + self.dropout2(
|
| 1143 |
+
self.linear4(
|
| 1144 |
+
self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
|
| 1145 |
+
l_hw_b_c = self.norm2(l_hw_b_c)
|
| 1146 |
+
|
| 1147 |
+
l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
|
| 1148 |
+
return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
|
| 1149 |
+
|
| 1150 |
+
|
| 1151 |
+
class MCRM(nn.Module):
|
| 1152 |
+
|
| 1153 |
+
def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
|
| 1154 |
+
super(MCRM, self).__init__()
|
| 1155 |
+
self.attention = nn.ModuleList([
|
| 1156 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1157 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1158 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1159 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 1160 |
+
])
|
| 1161 |
+
|
| 1162 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 1163 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 1164 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 1165 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 1166 |
+
self.dropout = nn.Dropout(0.1)
|
| 1167 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 1168 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 1169 |
+
self.sigmoid = nn.Sigmoid()
|
| 1170 |
+
self.activation = get_activation_fn('relu')
|
| 1171 |
+
self.sal_conv = nn.Conv2d(d_model, 1, 1)
|
| 1172 |
+
self.pool_ratios = pool_ratios
|
| 1173 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 1174 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 1175 |
+
|
| 1176 |
+
def forward(self, x):
|
| 1177 |
+
b, c, h, w = x.size()
|
| 1178 |
+
loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
|
| 1179 |
+
# b(4),c,h,w
|
| 1180 |
+
patched_glb = rearrange(glb,
|
| 1181 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1182 |
+
hg=2,
|
| 1183 |
+
wg=2)
|
| 1184 |
+
|
| 1185 |
+
# generate token attention map
|
| 1186 |
+
token_attention_map = self.sigmoid(self.sal_conv(glb))
|
| 1187 |
+
token_attention_map = F.interpolate(token_attention_map,
|
| 1188 |
+
size=patches2image(loc).shape[-2:],
|
| 1189 |
+
mode='nearest')
|
| 1190 |
+
loc = loc * rearrange(token_attention_map,
|
| 1191 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1192 |
+
hg=2,
|
| 1193 |
+
wg=2)
|
| 1194 |
+
pools = []
|
| 1195 |
+
for pool_ratio in self.pool_ratios:
|
| 1196 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1197 |
+
pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
|
| 1198 |
+
pools.append(rearrange(pool,
|
| 1199 |
+
'nl c h w -> nl c (h w)')) # nl(4),c,hw
|
| 1200 |
+
# nl(4),c,nphw -> nl(4),nphw,1,c
|
| 1201 |
+
pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
|
| 1202 |
+
loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
|
| 1203 |
+
outputs = []
|
| 1204 |
+
for i, q in enumerate(
|
| 1205 |
+
loc_.unbind(dim=0)): # traverse all local patches
|
| 1206 |
+
# np*hw,1,c
|
| 1207 |
+
v = pools[i]
|
| 1208 |
+
k = v
|
| 1209 |
+
outputs.append(self.attention[i](q, k, v)[0])
|
| 1210 |
+
outputs = torch.cat(outputs, 1)
|
| 1211 |
+
src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
|
| 1212 |
+
src = self.norm1(src)
|
| 1213 |
+
src = src + self.dropout2(
|
| 1214 |
+
self.linear4(
|
| 1215 |
+
self.dropout(self.activation(self.linear3(src)).clone())))
|
| 1216 |
+
src = self.norm2(src)
|
| 1217 |
+
|
| 1218 |
+
src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
|
| 1219 |
+
glb = glb + F.interpolate(patches2image(src),
|
| 1220 |
+
size=glb.shape[-2:],
|
| 1221 |
+
mode='nearest') # freshed glb
|
| 1222 |
+
return torch.cat((src, glb), 0), token_attention_map
|
| 1223 |
+
|
| 1224 |
+
|
| 1225 |
+
class inf_MCRM(nn.Module):
|
| 1226 |
+
|
| 1227 |
+
def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
|
| 1228 |
+
super(inf_MCRM, self).__init__()
|
| 1229 |
+
self.attention = nn.ModuleList([
|
| 1230 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1231 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1232 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1233 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 1234 |
+
])
|
| 1235 |
+
|
| 1236 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 1237 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 1238 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 1239 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 1240 |
+
self.dropout = nn.Dropout(0.1)
|
| 1241 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 1242 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 1243 |
+
self.sigmoid = nn.Sigmoid()
|
| 1244 |
+
self.activation = get_activation_fn('relu')
|
| 1245 |
+
self.sal_conv = nn.Conv2d(d_model, 1, 1)
|
| 1246 |
+
self.pool_ratios = pool_ratios
|
| 1247 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 1248 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 1249 |
+
|
| 1250 |
+
def forward(self, x):
|
| 1251 |
+
b, c, h, w = x.size()
|
| 1252 |
+
loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
|
| 1253 |
+
# b(4),c,h,w
|
| 1254 |
+
patched_glb = rearrange(glb,
|
| 1255 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1256 |
+
hg=2,
|
| 1257 |
+
wg=2)
|
| 1258 |
+
|
| 1259 |
+
# generate token attention map
|
| 1260 |
+
token_attention_map = self.sigmoid(self.sal_conv(glb))
|
| 1261 |
+
token_attention_map = F.interpolate(token_attention_map,
|
| 1262 |
+
size=patches2image(loc).shape[-2:],
|
| 1263 |
+
mode='nearest')
|
| 1264 |
+
loc = loc * rearrange(token_attention_map,
|
| 1265 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1266 |
+
hg=2,
|
| 1267 |
+
wg=2)
|
| 1268 |
+
pools = []
|
| 1269 |
+
for pool_ratio in self.pool_ratios:
|
| 1270 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1271 |
+
pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
|
| 1272 |
+
pools.append(rearrange(pool,
|
| 1273 |
+
'nl c h w -> nl c (h w)')) # nl(4),c,hw
|
| 1274 |
+
# nl(4),c,nphw -> nl(4),nphw,1,c
|
| 1275 |
+
pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
|
| 1276 |
+
loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
|
| 1277 |
+
outputs = []
|
| 1278 |
+
for i, q in enumerate(
|
| 1279 |
+
loc_.unbind(dim=0)): # traverse all local patches
|
| 1280 |
+
# np*hw,1,c
|
| 1281 |
+
v = pools[i]
|
| 1282 |
+
k = v
|
| 1283 |
+
outputs.append(self.attention[i](q, k, v)[0])
|
| 1284 |
+
outputs = torch.cat(outputs, 1)
|
| 1285 |
+
src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
|
| 1286 |
+
src = self.norm1(src)
|
| 1287 |
+
src = src + self.dropout2(
|
| 1288 |
+
self.linear4(
|
| 1289 |
+
self.dropout(self.activation(self.linear3(src)).clone())))
|
| 1290 |
+
src = self.norm2(src)
|
| 1291 |
+
|
| 1292 |
+
src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
|
| 1293 |
+
glb = glb + F.interpolate(patches2image(src),
|
| 1294 |
+
size=glb.shape[-2:],
|
| 1295 |
+
mode='nearest') # freshed glb
|
| 1296 |
+
return torch.cat((src, glb), 0)
|
| 1297 |
+
|
| 1298 |
+
|
| 1299 |
+
# model for single-scale training
|
| 1300 |
+
class MVANet(nn.Module):
|
| 1301 |
+
|
| 1302 |
+
def __init__(self):
|
| 1303 |
+
super().__init__()
|
| 1304 |
+
self.backbone = SwinB(pretrained=True)
|
| 1305 |
+
emb_dim = 128
|
| 1306 |
+
self.sideout5 = nn.Sequential(
|
| 1307 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1308 |
+
self.sideout4 = nn.Sequential(
|
| 1309 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1310 |
+
self.sideout3 = nn.Sequential(
|
| 1311 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1312 |
+
self.sideout2 = nn.Sequential(
|
| 1313 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1314 |
+
self.sideout1 = nn.Sequential(
|
| 1315 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1316 |
+
|
| 1317 |
+
self.output5 = make_cbr(1024, emb_dim)
|
| 1318 |
+
self.output4 = make_cbr(512, emb_dim)
|
| 1319 |
+
self.output3 = make_cbr(256, emb_dim)
|
| 1320 |
+
self.output2 = make_cbr(128, emb_dim)
|
| 1321 |
+
self.output1 = make_cbr(128, emb_dim)
|
| 1322 |
+
|
| 1323 |
+
self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
|
| 1324 |
+
self.conv1 = make_cbr(emb_dim, emb_dim)
|
| 1325 |
+
self.conv2 = make_cbr(emb_dim, emb_dim)
|
| 1326 |
+
self.conv3 = make_cbr(emb_dim, emb_dim)
|
| 1327 |
+
self.conv4 = make_cbr(emb_dim, emb_dim)
|
| 1328 |
+
self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1329 |
+
self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1330 |
+
self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1331 |
+
self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1332 |
+
|
| 1333 |
+
self.insmask_head = nn.Sequential(
|
| 1334 |
+
nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
|
| 1335 |
+
nn.BatchNorm2d(384), nn.PReLU(),
|
| 1336 |
+
nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
|
| 1337 |
+
nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
|
| 1338 |
+
|
| 1339 |
+
self.shallow = nn.Sequential(
|
| 1340 |
+
nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
|
| 1341 |
+
self.upsample1 = make_cbg(emb_dim, emb_dim)
|
| 1342 |
+
self.upsample2 = make_cbg(emb_dim, emb_dim)
|
| 1343 |
+
self.output = nn.Sequential(
|
| 1344 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1345 |
+
|
| 1346 |
+
for m in self.modules():
|
| 1347 |
+
if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
|
| 1348 |
+
m.inplace = True
|
| 1349 |
+
|
| 1350 |
+
def forward(self, x):
|
| 1351 |
+
x = x.to(dtype=torch_dtype, device=torch_device)
|
| 1352 |
+
shallow = self.shallow(x)
|
| 1353 |
+
glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
|
| 1354 |
+
loc = image2patches(x)
|
| 1355 |
+
input = torch.cat((loc, glb), dim=0)
|
| 1356 |
+
feature = self.backbone(input)
|
| 1357 |
+
e5 = self.output5(feature[4]) # (5,128,16,16)
|
| 1358 |
+
e4 = self.output4(feature[3]) # (5,128,32,32)
|
| 1359 |
+
e3 = self.output3(feature[2]) # (5,128,64,64)
|
| 1360 |
+
e2 = self.output2(feature[1]) # (5,128,128,128)
|
| 1361 |
+
e1 = self.output1(feature[0]) # (5,128,128,128)
|
| 1362 |
+
loc_e5, glb_e5 = e5.split([4, 1], dim=0)
|
| 1363 |
+
e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
|
| 1364 |
+
|
| 1365 |
+
e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
|
| 1366 |
+
e4 = self.conv4(e4)
|
| 1367 |
+
e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
|
| 1368 |
+
e3 = self.conv3(e3)
|
| 1369 |
+
e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
|
| 1370 |
+
e2 = self.conv2(e2)
|
| 1371 |
+
e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
|
| 1372 |
+
e1 = self.conv1(e1)
|
| 1373 |
+
loc_e1, glb_e1 = e1.split([4, 1], dim=0)
|
| 1374 |
+
output1_cat = patches2image(loc_e1) # (1,128,256,256)
|
| 1375 |
+
# add glb feat in
|
| 1376 |
+
output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
|
| 1377 |
+
# merge
|
| 1378 |
+
final_output = self.insmask_head(output1_cat) # (1,128,256,256)
|
| 1379 |
+
# shallow feature merge
|
| 1380 |
+
final_output = final_output + resize_as(shallow, final_output)
|
| 1381 |
+
final_output = self.upsample1(rescale_to(final_output))
|
| 1382 |
+
final_output = rescale_to(final_output +
|
| 1383 |
+
resize_as(shallow, final_output))
|
| 1384 |
+
final_output = self.upsample2(final_output)
|
| 1385 |
+
final_output = self.output(final_output)
|
| 1386 |
+
####
|
| 1387 |
+
sideout5 = self.sideout5(e5).to(dtype=torch_dtype, device=torch_device)
|
| 1388 |
+
sideout4 = self.sideout4(e4)
|
| 1389 |
+
sideout3 = self.sideout3(e3)
|
| 1390 |
+
sideout2 = self.sideout2(e2)
|
| 1391 |
+
sideout1 = self.sideout1(e1)
|
| 1392 |
+
#######glb_sideouts ######
|
| 1393 |
+
glb5 = self.sideout5(glb_e5)
|
| 1394 |
+
glb4 = sideout4[-1, :, :, :].unsqueeze(0)
|
| 1395 |
+
glb3 = sideout3[-1, :, :, :].unsqueeze(0)
|
| 1396 |
+
glb2 = sideout2[-1, :, :, :].unsqueeze(0)
|
| 1397 |
+
glb1 = sideout1[-1, :, :, :].unsqueeze(0)
|
| 1398 |
+
####### concat 4 to 1 #######
|
| 1399 |
+
sideout1 = patches2image(sideout1[:-1]).to(dtype=torch_dtype,
|
| 1400 |
+
device=torch_device)
|
| 1401 |
+
sideout2 = patches2image(sideout2[:-1]).to(
|
| 1402 |
+
dtype=torch_dtype,
|
| 1403 |
+
device=torch_device) ####(5,c,h,w) -> (1 c 2h,2w)
|
| 1404 |
+
sideout3 = patches2image(sideout3[:-1]).to(dtype=torch_dtype,
|
| 1405 |
+
device=torch_device)
|
| 1406 |
+
sideout4 = patches2image(sideout4[:-1]).to(dtype=torch_dtype,
|
| 1407 |
+
device=torch_device)
|
| 1408 |
+
sideout5 = patches2image(sideout5[:-1]).to(dtype=torch_dtype,
|
| 1409 |
+
device=torch_device)
|
| 1410 |
+
if self.training:
|
| 1411 |
+
return sideout5, sideout4, sideout3, sideout2, sideout1, final_output, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1
|
| 1412 |
+
else:
|
| 1413 |
+
return final_output
|
| 1414 |
+
|
| 1415 |
+
|
| 1416 |
+
# model for multi-scale testing
|
| 1417 |
+
class inf_MVANet(nn.Module):
|
| 1418 |
+
|
| 1419 |
+
def __init__(self):
|
| 1420 |
+
super().__init__()
|
| 1421 |
+
# self.backbone = SwinB(pretrained=True)
|
| 1422 |
+
self.backbone = SwinB(pretrained=False)
|
| 1423 |
+
|
| 1424 |
+
emb_dim = 128
|
| 1425 |
+
self.output5 = make_cbr(1024, emb_dim)
|
| 1426 |
+
self.output4 = make_cbr(512, emb_dim)
|
| 1427 |
+
self.output3 = make_cbr(256, emb_dim)
|
| 1428 |
+
self.output2 = make_cbr(128, emb_dim)
|
| 1429 |
+
self.output1 = make_cbr(128, emb_dim)
|
| 1430 |
+
|
| 1431 |
+
self.multifieldcrossatt = inf_MCLM(emb_dim, 1, [1, 4, 8])
|
| 1432 |
+
self.conv1 = make_cbr(emb_dim, emb_dim)
|
| 1433 |
+
self.conv2 = make_cbr(emb_dim, emb_dim)
|
| 1434 |
+
self.conv3 = make_cbr(emb_dim, emb_dim)
|
| 1435 |
+
self.conv4 = make_cbr(emb_dim, emb_dim)
|
| 1436 |
+
self.dec_blk1 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1437 |
+
self.dec_blk2 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1438 |
+
self.dec_blk3 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1439 |
+
self.dec_blk4 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1440 |
+
|
| 1441 |
+
self.insmask_head = nn.Sequential(
|
| 1442 |
+
nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
|
| 1443 |
+
nn.BatchNorm2d(384), nn.PReLU(),
|
| 1444 |
+
nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
|
| 1445 |
+
nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
|
| 1446 |
+
|
| 1447 |
+
self.shallow = nn.Sequential(
|
| 1448 |
+
nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
|
| 1449 |
+
self.upsample1 = make_cbg(emb_dim, emb_dim)
|
| 1450 |
+
self.upsample2 = make_cbg(emb_dim, emb_dim)
|
| 1451 |
+
self.output = nn.Sequential(
|
| 1452 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1453 |
+
|
| 1454 |
+
for m in self.modules():
|
| 1455 |
+
if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
|
| 1456 |
+
m.inplace = True
|
| 1457 |
+
|
| 1458 |
+
def forward(self, x):
|
| 1459 |
+
shallow = self.shallow(x)
|
| 1460 |
+
glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
|
| 1461 |
+
loc = image2patches(x)
|
| 1462 |
+
input = torch.cat((loc, glb), dim=0)
|
| 1463 |
+
feature = self.backbone(input)
|
| 1464 |
+
e5 = self.output5(feature[4])
|
| 1465 |
+
e4 = self.output4(feature[3])
|
| 1466 |
+
e3 = self.output3(feature[2])
|
| 1467 |
+
e2 = self.output2(feature[1])
|
| 1468 |
+
e1 = self.output1(feature[0])
|
| 1469 |
+
loc_e5, glb_e5 = e5.split([4, 1], dim=0)
|
| 1470 |
+
e5_cat = self.multifieldcrossatt(loc_e5, glb_e5)
|
| 1471 |
+
|
| 1472 |
+
e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
|
| 1473 |
+
e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
|
| 1474 |
+
e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
|
| 1475 |
+
e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))
|
| 1476 |
+
loc_e1, glb_e1 = e1.split([4, 1], dim=0)
|
| 1477 |
+
# after decoder, concat loc features to a whole one, and merge
|
| 1478 |
+
output1_cat = patches2image(loc_e1)
|
| 1479 |
+
# add glb feat in
|
| 1480 |
+
output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
|
| 1481 |
+
# merge
|
| 1482 |
+
final_output = self.insmask_head(output1_cat)
|
| 1483 |
+
# shallow feature merge
|
| 1484 |
+
final_output = final_output + resize_as(shallow, final_output)
|
| 1485 |
+
final_output = self.upsample1(rescale_to(final_output))
|
| 1486 |
+
final_output = rescale_to(final_output +
|
| 1487 |
+
resize_as(shallow, final_output))
|
| 1488 |
+
final_output = self.upsample2(final_output)
|
| 1489 |
+
final_output = self.output(final_output)
|
| 1490 |
+
return final_output
|
| 1491 |
+
|
| 1492 |
+
|
| 1493 |
+
class load_MVANet_Model:
|
| 1494 |
+
|
| 1495 |
+
def __init__(self):
|
| 1496 |
+
pass
|
| 1497 |
+
|
| 1498 |
+
@classmethod
|
| 1499 |
+
def INPUT_TYPES(s):
|
| 1500 |
+
return {
|
| 1501 |
+
"required": {},
|
| 1502 |
+
}
|
| 1503 |
+
|
| 1504 |
+
RETURN_TYPES = ("MVANet_Model", )
|
| 1505 |
+
FUNCTION = "test"
|
| 1506 |
+
CATEGORY = "MVANet"
|
| 1507 |
+
|
| 1508 |
+
def test(self):
|
| 1509 |
+
return (load_model(get_model_path()), )
|
| 1510 |
+
|
| 1511 |
+
|
| 1512 |
+
class run_MVANet_inference:
|
| 1513 |
+
|
| 1514 |
+
def __init__(self):
|
| 1515 |
+
pass
|
| 1516 |
+
|
| 1517 |
+
@classmethod
|
| 1518 |
+
def INPUT_TYPES(s):
|
| 1519 |
+
return {
|
| 1520 |
+
"required": {
|
| 1521 |
+
"image": ("IMAGE", ),
|
| 1522 |
+
"MVANet_Model": ("MVANet_Model", ),
|
| 1523 |
+
},
|
| 1524 |
+
}
|
| 1525 |
+
|
| 1526 |
+
RETURN_TYPES = ("MASK", )
|
| 1527 |
+
FUNCTION = "test"
|
| 1528 |
+
CATEGORY = "MVANet"
|
| 1529 |
+
|
| 1530 |
+
def test(
|
| 1531 |
+
self,
|
| 1532 |
+
image,
|
| 1533 |
+
MVANet_Model,
|
| 1534 |
+
):
|
| 1535 |
+
ret = do_infer_tensor2tensor(img=image, net=MVANet_Model)
|
| 1536 |
+
|
| 1537 |
+
return (ret, )
|
| 1538 |
+
|
| 1539 |
+
|
| 1540 |
+
NODE_CLASS_MAPPINGS = {
|
| 1541 |
+
"load_MVANet_Model": load_MVANet_Model,
|
| 1542 |
+
"run_MVANet_inference": run_MVANet_inference
|
| 1543 |
+
}
|
| 1544 |
+
|
| 1545 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
| 1546 |
+
"load_MVANet_Model": "load MVANet Model",
|
| 1547 |
+
"run_MVANet_inference": "run_MVANet_inference"
|
| 1548 |
+
}
|
ComfyUI_MVANet/download.sh
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
get_repo(){
|
| 3 |
+
DIR_REPO="${HOME}/GITHUB/$('echo' "${1}" | 'sed' 's/^git@github.com://g ; s@^https://github.com/@@g ; s@.git$@@g' )"
|
| 4 |
+
DIR_BASE="$('dirname' '--' "${DIR_REPO}")"
|
| 5 |
+
mkdir -pv -- "${DIR_BASE}"
|
| 6 |
+
cd "${DIR_BASE}"
|
| 7 |
+
git clone "${1}"
|
| 8 |
+
cd "${DIR_REPO}"
|
| 9 |
+
git pull
|
| 10 |
+
git submodule update --recursive --init
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
get_repo 'https://github.com/qianyu-dlut/MVANet.git'
|
ComfyUI_MVANet/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
timm
|
| 2 |
+
einops
|
| 3 |
+
wget
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2020 Peike Li
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
MVANet_Inference/README.org
ADDED
|
@@ -0,0 +1,2179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
* COMMENT Sample
|
| 2 |
+
|
| 3 |
+
** Shell script to download
|
| 4 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
|
| 5 |
+
#+end_src
|
| 6 |
+
|
| 7 |
+
** MVANet_inference import
|
| 8 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
|
| 9 |
+
#+end_src
|
| 10 |
+
|
| 11 |
+
** MVANet_inference function
|
| 12 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
|
| 13 |
+
#+end_src
|
| 14 |
+
|
| 15 |
+
** MVANet_inference class
|
| 16 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
|
| 17 |
+
#+end_src
|
| 18 |
+
|
| 19 |
+
** MVANet_inference execute
|
| 20 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
|
| 21 |
+
#+end_src
|
| 22 |
+
|
| 23 |
+
** MVANet_inference unify
|
| 24 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.unify.sh
|
| 25 |
+
#+end_src
|
| 26 |
+
|
| 27 |
+
** MVANet_inference run
|
| 28 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.run.sh
|
| 29 |
+
#+end_src
|
| 30 |
+
|
| 31 |
+
* Download the code:
|
| 32 |
+
|
| 33 |
+
** Function to download
|
| 34 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
|
| 35 |
+
get_repo(){
|
| 36 |
+
DIR_REPO="${HOME}/GITHUB/$('echo' "${1}" | 'sed' 's/^git@github.com://g ; s@^https://github.com/@@g ; s@.git$@@g' )"
|
| 37 |
+
DIR_BASE="$('dirname' '--' "${DIR_REPO}")"
|
| 38 |
+
mkdir -pv -- "${DIR_BASE}"
|
| 39 |
+
cd "${DIR_BASE}"
|
| 40 |
+
git clone "${1}"
|
| 41 |
+
cd "${DIR_REPO}"
|
| 42 |
+
git pull
|
| 43 |
+
git submodule update --recursive --init
|
| 44 |
+
}
|
| 45 |
+
#+end_src
|
| 46 |
+
|
| 47 |
+
** Download
|
| 48 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
|
| 49 |
+
get_repo 'https://github.com/qianyu-dlut/MVANet.git'
|
| 50 |
+
#+end_src
|
| 51 |
+
|
| 52 |
+
* Dependencies
|
| 53 |
+
pip3 install mmdet==2.23.0
|
| 54 |
+
pip3 install mmcv==1.4.8
|
| 55 |
+
pip3 install ttach
|
| 56 |
+
|
| 57 |
+
* Python inference
|
| 58 |
+
|
| 59 |
+
** Important configs
|
| 60 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
|
| 61 |
+
import os
|
| 62 |
+
import sys
|
| 63 |
+
|
| 64 |
+
HOME_DIR = os.environ.get('HOME', '/root')
|
| 65 |
+
MVANET_SOURCE_DIR = HOME_DIR + '/GITHUB/qianyu-dlut/MVANet'
|
| 66 |
+
finetuned_MVANet_model_path = MVANET_SOURCE_DIR + '/model/Model_80.pth'
|
| 67 |
+
pretrained_SwinB_model_path = MVANET_SOURCE_DIR + '/model/swin_base_patch4_window12_384_22kto1k.pth'
|
| 68 |
+
#+end_src
|
| 69 |
+
|
| 70 |
+
** MVANet_inference import
|
| 71 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
|
| 72 |
+
import math
|
| 73 |
+
import numpy as np
|
| 74 |
+
from PIL import Image
|
| 75 |
+
import time
|
| 76 |
+
# import ttach as tta
|
| 77 |
+
import cv2
|
| 78 |
+
|
| 79 |
+
import torch
|
| 80 |
+
import torch.nn as nn
|
| 81 |
+
import torch.nn.functional as F
|
| 82 |
+
import torch.utils.checkpoint as checkpoint
|
| 83 |
+
from torch.autograd import Variable
|
| 84 |
+
from torch import nn
|
| 85 |
+
from torchvision import transforms
|
| 86 |
+
|
| 87 |
+
from einops import rearrange
|
| 88 |
+
|
| 89 |
+
from timm.models import load_checkpoint
|
| 90 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 91 |
+
#+end_src
|
| 92 |
+
|
| 93 |
+
** Load image using CV
|
| 94 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
|
| 95 |
+
def load_image(input_image_path):
|
| 96 |
+
img = cv2.imread(input_image_path, cv2.IMREAD_COLOR)
|
| 97 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 98 |
+
return img
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def load_image_torch(input_image_path):
|
| 102 |
+
img = cv2.imread(input_image_path, cv2.IMREAD_COLOR)
|
| 103 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 104 |
+
img = torch.from_numpy(img)
|
| 105 |
+
img = img.to(dtype=torch.float32)
|
| 106 |
+
img /= 255.0
|
| 107 |
+
img = img.unsqueeze(0)
|
| 108 |
+
return img
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def save_mask(output_image_path, mask):
|
| 112 |
+
cv2.imwrite(output_image_path, mask)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def save_mask_torch(output_image_path, mask):
|
| 116 |
+
mask = mask.detach().cpu()
|
| 117 |
+
mask *= 255.0
|
| 118 |
+
mask = mask.clamp(0, 255)
|
| 119 |
+
print(mask.shape)
|
| 120 |
+
mask = mask.squeeze(0)
|
| 121 |
+
mask = mask.to(dtype=torch.uint8)
|
| 122 |
+
print(mask.shape)
|
| 123 |
+
mask = mask.numpy()
|
| 124 |
+
print(mask.shape)
|
| 125 |
+
cv2.imwrite(output_image_path, mask)
|
| 126 |
+
#+end_src
|
| 127 |
+
|
| 128 |
+
** Device configs
|
| 129 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
|
| 130 |
+
torch_device = 'cuda'
|
| 131 |
+
torch_dtype = torch.float16
|
| 132 |
+
#+end_src
|
| 133 |
+
to(dtype=torch_dtype, device=torch_device)
|
| 134 |
+
|
| 135 |
+
** MVANet_inference function
|
| 136 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
|
| 137 |
+
def check_mkdir(dir_name):
|
| 138 |
+
if not os.path.isdir(dir_name):
|
| 139 |
+
os.makedirs(dir_name)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def SwinT(pretrained=True):
|
| 143 |
+
model = SwinTransformer(embed_dim=96,
|
| 144 |
+
depths=[2, 2, 6, 2],
|
| 145 |
+
num_heads=[3, 6, 12, 24],
|
| 146 |
+
window_size=7)
|
| 147 |
+
if pretrained is True:
|
| 148 |
+
model.load_state_dict(torch.load(
|
| 149 |
+
'data/backbone_ckpt/swin_tiny_patch4_window7_224.pth',
|
| 150 |
+
map_location='cpu')['model'],
|
| 151 |
+
strict=False)
|
| 152 |
+
|
| 153 |
+
return model
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def SwinS(pretrained=True):
|
| 157 |
+
model = SwinTransformer(embed_dim=96,
|
| 158 |
+
depths=[2, 2, 18, 2],
|
| 159 |
+
num_heads=[3, 6, 12, 24],
|
| 160 |
+
window_size=7)
|
| 161 |
+
if pretrained is True:
|
| 162 |
+
model.load_state_dict(torch.load(
|
| 163 |
+
'data/backbone_ckpt/swin_small_patch4_window7_224.pth',
|
| 164 |
+
map_location='cpu')['model'],
|
| 165 |
+
strict=False)
|
| 166 |
+
|
| 167 |
+
return model
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def SwinB(pretrained=True):
|
| 171 |
+
model = SwinTransformer(embed_dim=128,
|
| 172 |
+
depths=[2, 2, 18, 2],
|
| 173 |
+
num_heads=[4, 8, 16, 32],
|
| 174 |
+
window_size=12)
|
| 175 |
+
if pretrained is True:
|
| 176 |
+
import os
|
| 177 |
+
model.load_state_dict(torch.load(pretrained_SwinB_model_path,
|
| 178 |
+
map_location='cpu')['model'],
|
| 179 |
+
strict=False)
|
| 180 |
+
return model
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def SwinL(pretrained=True):
|
| 184 |
+
model = SwinTransformer(embed_dim=192,
|
| 185 |
+
depths=[2, 2, 18, 2],
|
| 186 |
+
num_heads=[6, 12, 24, 48],
|
| 187 |
+
window_size=12)
|
| 188 |
+
if pretrained is True:
|
| 189 |
+
model.load_state_dict(torch.load(
|
| 190 |
+
'data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth',
|
| 191 |
+
map_location='cpu')['model'],
|
| 192 |
+
strict=False)
|
| 193 |
+
|
| 194 |
+
return model
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def get_activation_fn(activation):
|
| 198 |
+
"""Return an activation function given a string"""
|
| 199 |
+
if activation == "relu":
|
| 200 |
+
return F.relu
|
| 201 |
+
if activation == "gelu":
|
| 202 |
+
return F.gelu
|
| 203 |
+
if activation == "glu":
|
| 204 |
+
return F.glu
|
| 205 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def make_cbr(in_dim, out_dim):
|
| 209 |
+
return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
|
| 210 |
+
nn.BatchNorm2d(out_dim), nn.PReLU())
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def make_cbg(in_dim, out_dim):
|
| 214 |
+
return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
|
| 215 |
+
nn.BatchNorm2d(out_dim), nn.GELU())
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
|
| 219 |
+
return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def resize_as(x, y, interpolation='bilinear'):
|
| 223 |
+
return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def image2patches(x):
|
| 227 |
+
"""b c (hg h) (wg w) -> (hg wg b) c h w"""
|
| 228 |
+
x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
| 229 |
+
return x
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def patches2image(x):
|
| 233 |
+
"""(hg wg b) c h w -> b c (hg h) (wg w)"""
|
| 234 |
+
x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
|
| 235 |
+
return x
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def window_partition(x, window_size):
|
| 239 |
+
"""
|
| 240 |
+
Args:
|
| 241 |
+
x: (B, H, W, C)
|
| 242 |
+
window_size (int): window size
|
| 243 |
+
|
| 244 |
+
Returns:
|
| 245 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 246 |
+
"""
|
| 247 |
+
B, H, W, C = x.shape
|
| 248 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size,
|
| 249 |
+
C)
|
| 250 |
+
windows = x.permute(0, 1, 3, 2, 4,
|
| 251 |
+
5).contiguous().view(-1, window_size, window_size, C)
|
| 252 |
+
return windows
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def window_reverse(windows, window_size, H, W):
|
| 256 |
+
"""
|
| 257 |
+
Args:
|
| 258 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 259 |
+
window_size (int): Window size
|
| 260 |
+
H (int): Height of image
|
| 261 |
+
W (int): Width of image
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
x: (B, H, W, C)
|
| 265 |
+
"""
|
| 266 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
| 267 |
+
x = windows.view(B, H // window_size, W // window_size, window_size,
|
| 268 |
+
window_size, -1)
|
| 269 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
| 270 |
+
return x
|
| 271 |
+
#+end_src
|
| 272 |
+
|
| 273 |
+
** MVANet_inference class
|
| 274 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
|
| 275 |
+
class Mlp(nn.Module):
|
| 276 |
+
""" Multilayer perceptron."""
|
| 277 |
+
|
| 278 |
+
def __init__(self,
|
| 279 |
+
in_features,
|
| 280 |
+
hidden_features=None,
|
| 281 |
+
out_features=None,
|
| 282 |
+
act_layer=nn.GELU,
|
| 283 |
+
drop=0.):
|
| 284 |
+
super().__init__()
|
| 285 |
+
out_features = out_features or in_features
|
| 286 |
+
hidden_features = hidden_features or in_features
|
| 287 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 288 |
+
self.act = act_layer()
|
| 289 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 290 |
+
self.drop = nn.Dropout(drop)
|
| 291 |
+
|
| 292 |
+
def forward(self, x):
|
| 293 |
+
x = self.fc1(x)
|
| 294 |
+
x = self.act(x)
|
| 295 |
+
x = self.drop(x)
|
| 296 |
+
x = self.fc2(x)
|
| 297 |
+
x = self.drop(x)
|
| 298 |
+
return x
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class WindowAttention(nn.Module):
|
| 302 |
+
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
| 303 |
+
It supports both of shifted and non-shifted window.
|
| 304 |
+
|
| 305 |
+
Args:
|
| 306 |
+
dim (int): Number of input channels.
|
| 307 |
+
window_size (tuple[int]): The height and width of the window.
|
| 308 |
+
num_heads (int): Number of attention heads.
|
| 309 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 310 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
| 311 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
| 312 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
| 313 |
+
"""
|
| 314 |
+
|
| 315 |
+
def __init__(self,
|
| 316 |
+
dim,
|
| 317 |
+
window_size,
|
| 318 |
+
num_heads,
|
| 319 |
+
qkv_bias=True,
|
| 320 |
+
qk_scale=None,
|
| 321 |
+
attn_drop=0.,
|
| 322 |
+
proj_drop=0.):
|
| 323 |
+
|
| 324 |
+
super().__init__()
|
| 325 |
+
self.dim = dim
|
| 326 |
+
self.window_size = window_size # Wh, Ww
|
| 327 |
+
self.num_heads = num_heads
|
| 328 |
+
head_dim = dim // num_heads
|
| 329 |
+
self.scale = qk_scale or head_dim**-0.5
|
| 330 |
+
|
| 331 |
+
# define a parameter table of relative position bias
|
| 332 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 333 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
|
| 334 |
+
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 335 |
+
|
| 336 |
+
# get pair-wise relative position index for each token inside the window
|
| 337 |
+
coords_h = torch.arange(self.window_size[0])
|
| 338 |
+
coords_w = torch.arange(self.window_size[1])
|
| 339 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 340 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 341 |
+
relative_coords = coords_flatten[:, :,
|
| 342 |
+
None] - coords_flatten[:,
|
| 343 |
+
None, :] # 2, Wh*Ww, Wh*Ww
|
| 344 |
+
relative_coords = relative_coords.permute(
|
| 345 |
+
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 346 |
+
relative_coords[:, :,
|
| 347 |
+
0] += self.window_size[0] - 1 # shift to start from 0
|
| 348 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 349 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
| 350 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 351 |
+
self.register_buffer("relative_position_index",
|
| 352 |
+
relative_position_index)
|
| 353 |
+
|
| 354 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 355 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 356 |
+
self.proj = nn.Linear(dim, dim)
|
| 357 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 358 |
+
|
| 359 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
| 360 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 361 |
+
|
| 362 |
+
def forward(self, x, mask=None):
|
| 363 |
+
""" Forward function.
|
| 364 |
+
|
| 365 |
+
Args:
|
| 366 |
+
x: input features with shape of (num_windows*B, N, C)
|
| 367 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
| 368 |
+
"""
|
| 369 |
+
x = x.to(dtype=torch_dtype, device=torch_device)
|
| 370 |
+
B_, N, C = x.shape
|
| 371 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
|
| 372 |
+
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 373 |
+
q, k, v = qkv[0], qkv[1], qkv[
|
| 374 |
+
2] # make torchscript happy (cannot use tensor as tuple)
|
| 375 |
+
|
| 376 |
+
q = q * self.scale
|
| 377 |
+
attn = (q @ k.transpose(-2, -1))
|
| 378 |
+
|
| 379 |
+
relative_position_bias = self.relative_position_bias_table[
|
| 380 |
+
self.relative_position_index.view(-1)].view(
|
| 381 |
+
self.window_size[0] * self.window_size[1],
|
| 382 |
+
self.window_size[0] * self.window_size[1],
|
| 383 |
+
-1) # Wh*Ww,Wh*Ww,nH
|
| 384 |
+
relative_position_bias = relative_position_bias.permute(
|
| 385 |
+
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 386 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
| 387 |
+
|
| 388 |
+
if mask is not None:
|
| 389 |
+
nW = mask.shape[0]
|
| 390 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N,
|
| 391 |
+
N) + mask.unsqueeze(1).unsqueeze(0)
|
| 392 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
| 393 |
+
attn = self.softmax(attn)
|
| 394 |
+
else:
|
| 395 |
+
attn = self.softmax(attn)
|
| 396 |
+
|
| 397 |
+
attn = self.attn_drop(attn)
|
| 398 |
+
attn = attn.to(dtype=torch_dtype, device=torch_device)
|
| 399 |
+
v = v.to(dtype=torch_dtype, device=torch_device)
|
| 400 |
+
|
| 401 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
| 402 |
+
x = self.proj(x)
|
| 403 |
+
x = self.proj_drop(x)
|
| 404 |
+
return x
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class SwinTransformerBlock(nn.Module):
|
| 408 |
+
""" Swin Transformer Block.
|
| 409 |
+
|
| 410 |
+
Args:
|
| 411 |
+
dim (int): Number of input channels.
|
| 412 |
+
num_heads (int): Number of attention heads.
|
| 413 |
+
window_size (int): Window size.
|
| 414 |
+
shift_size (int): Shift size for SW-MSA.
|
| 415 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 416 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 417 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 418 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 419 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 420 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
| 421 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
| 422 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 423 |
+
"""
|
| 424 |
+
|
| 425 |
+
def __init__(self,
|
| 426 |
+
dim,
|
| 427 |
+
num_heads,
|
| 428 |
+
window_size=7,
|
| 429 |
+
shift_size=0,
|
| 430 |
+
mlp_ratio=4.,
|
| 431 |
+
qkv_bias=True,
|
| 432 |
+
qk_scale=None,
|
| 433 |
+
drop=0.,
|
| 434 |
+
attn_drop=0.,
|
| 435 |
+
drop_path=0.,
|
| 436 |
+
act_layer=nn.GELU,
|
| 437 |
+
norm_layer=nn.LayerNorm):
|
| 438 |
+
super().__init__()
|
| 439 |
+
self.dim = dim
|
| 440 |
+
self.num_heads = num_heads
|
| 441 |
+
self.window_size = window_size
|
| 442 |
+
self.shift_size = shift_size
|
| 443 |
+
self.mlp_ratio = mlp_ratio
|
| 444 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
| 445 |
+
|
| 446 |
+
self.norm1 = norm_layer(dim)
|
| 447 |
+
self.attn = WindowAttention(dim,
|
| 448 |
+
window_size=to_2tuple(self.window_size),
|
| 449 |
+
num_heads=num_heads,
|
| 450 |
+
qkv_bias=qkv_bias,
|
| 451 |
+
qk_scale=qk_scale,
|
| 452 |
+
attn_drop=attn_drop,
|
| 453 |
+
proj_drop=drop)
|
| 454 |
+
|
| 455 |
+
self.drop_path = DropPath(
|
| 456 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
| 457 |
+
self.norm2 = norm_layer(dim)
|
| 458 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 459 |
+
self.mlp = Mlp(in_features=dim,
|
| 460 |
+
hidden_features=mlp_hidden_dim,
|
| 461 |
+
act_layer=act_layer,
|
| 462 |
+
drop=drop)
|
| 463 |
+
|
| 464 |
+
self.H = None
|
| 465 |
+
self.W = None
|
| 466 |
+
|
| 467 |
+
def forward(self, x, mask_matrix):
|
| 468 |
+
""" Forward function.
|
| 469 |
+
|
| 470 |
+
Args:
|
| 471 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 472 |
+
H, W: Spatial resolution of the input feature.
|
| 473 |
+
mask_matrix: Attention mask for cyclic shift.
|
| 474 |
+
"""
|
| 475 |
+
B, L, C = x.shape
|
| 476 |
+
H, W = self.H, self.W
|
| 477 |
+
assert L == H * W, "input feature has wrong size"
|
| 478 |
+
|
| 479 |
+
shortcut = x
|
| 480 |
+
x = self.norm1(x)
|
| 481 |
+
x = x.view(B, H, W, C)
|
| 482 |
+
|
| 483 |
+
# pad feature maps to multiples of window size
|
| 484 |
+
pad_l = pad_t = 0
|
| 485 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
| 486 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
| 487 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
| 488 |
+
_, Hp, Wp, _ = x.shape
|
| 489 |
+
|
| 490 |
+
# cyclic shift
|
| 491 |
+
if self.shift_size > 0:
|
| 492 |
+
shifted_x = torch.roll(x,
|
| 493 |
+
shifts=(-self.shift_size, -self.shift_size),
|
| 494 |
+
dims=(1, 2))
|
| 495 |
+
attn_mask = mask_matrix
|
| 496 |
+
else:
|
| 497 |
+
shifted_x = x
|
| 498 |
+
attn_mask = None
|
| 499 |
+
|
| 500 |
+
# partition windows
|
| 501 |
+
x_windows = window_partition(
|
| 502 |
+
shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
| 503 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size,
|
| 504 |
+
C) # nW*B, window_size*window_size, C
|
| 505 |
+
|
| 506 |
+
# W-MSA/SW-MSA
|
| 507 |
+
attn_windows = self.attn(
|
| 508 |
+
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
| 509 |
+
|
| 510 |
+
# merge windows
|
| 511 |
+
attn_windows = attn_windows.view(-1, self.window_size,
|
| 512 |
+
self.window_size, C)
|
| 513 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp,
|
| 514 |
+
Wp) # B H' W' C
|
| 515 |
+
|
| 516 |
+
# reverse cyclic shift
|
| 517 |
+
if self.shift_size > 0:
|
| 518 |
+
x = torch.roll(shifted_x,
|
| 519 |
+
shifts=(self.shift_size, self.shift_size),
|
| 520 |
+
dims=(1, 2))
|
| 521 |
+
else:
|
| 522 |
+
x = shifted_x
|
| 523 |
+
|
| 524 |
+
if pad_r > 0 or pad_b > 0:
|
| 525 |
+
x = x[:, :H, :W, :].contiguous()
|
| 526 |
+
|
| 527 |
+
x = x.view(B, H * W, C)
|
| 528 |
+
|
| 529 |
+
# FFN
|
| 530 |
+
x = shortcut + self.drop_path(x)
|
| 531 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 532 |
+
|
| 533 |
+
return x
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
class PatchMerging(nn.Module):
|
| 537 |
+
""" Patch Merging Layer
|
| 538 |
+
|
| 539 |
+
Args:
|
| 540 |
+
dim (int): Number of input channels.
|
| 541 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 542 |
+
"""
|
| 543 |
+
|
| 544 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
| 545 |
+
super().__init__()
|
| 546 |
+
self.dim = dim
|
| 547 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
| 548 |
+
self.norm = norm_layer(4 * dim)
|
| 549 |
+
|
| 550 |
+
def forward(self, x, H, W):
|
| 551 |
+
""" Forward function.
|
| 552 |
+
|
| 553 |
+
Args:
|
| 554 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 555 |
+
H, W: Spatial resolution of the input feature.
|
| 556 |
+
"""
|
| 557 |
+
B, L, C = x.shape
|
| 558 |
+
assert L == H * W, "input feature has wrong size"
|
| 559 |
+
|
| 560 |
+
x = x.view(B, H, W, C)
|
| 561 |
+
|
| 562 |
+
# padding
|
| 563 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
| 564 |
+
if pad_input:
|
| 565 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
| 566 |
+
|
| 567 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
| 568 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
| 569 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
| 570 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
| 571 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
| 572 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
| 573 |
+
|
| 574 |
+
x = self.norm(x)
|
| 575 |
+
x = self.reduction(x)
|
| 576 |
+
|
| 577 |
+
return x
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
class BasicLayer(nn.Module):
|
| 581 |
+
""" A basic Swin Transformer layer for one stage.
|
| 582 |
+
|
| 583 |
+
Args:
|
| 584 |
+
dim (int): Number of feature channels
|
| 585 |
+
depth (int): Depths of this stage.
|
| 586 |
+
num_heads (int): Number of attention head.
|
| 587 |
+
window_size (int): Local window size. Default: 7.
|
| 588 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 589 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 590 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 591 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 592 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 593 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
| 594 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 595 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
| 596 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 597 |
+
"""
|
| 598 |
+
|
| 599 |
+
def __init__(self,
|
| 600 |
+
dim,
|
| 601 |
+
depth,
|
| 602 |
+
num_heads,
|
| 603 |
+
window_size=7,
|
| 604 |
+
mlp_ratio=4.,
|
| 605 |
+
qkv_bias=True,
|
| 606 |
+
qk_scale=None,
|
| 607 |
+
drop=0.,
|
| 608 |
+
attn_drop=0.,
|
| 609 |
+
drop_path=0.,
|
| 610 |
+
norm_layer=nn.LayerNorm,
|
| 611 |
+
downsample=None,
|
| 612 |
+
use_checkpoint=False):
|
| 613 |
+
super().__init__()
|
| 614 |
+
self.window_size = window_size
|
| 615 |
+
self.shift_size = window_size // 2
|
| 616 |
+
self.depth = depth
|
| 617 |
+
self.use_checkpoint = use_checkpoint
|
| 618 |
+
|
| 619 |
+
# build blocks
|
| 620 |
+
self.blocks = nn.ModuleList([
|
| 621 |
+
SwinTransformerBlock(dim=dim,
|
| 622 |
+
num_heads=num_heads,
|
| 623 |
+
window_size=window_size,
|
| 624 |
+
shift_size=0 if
|
| 625 |
+
(i % 2 == 0) else window_size // 2,
|
| 626 |
+
mlp_ratio=mlp_ratio,
|
| 627 |
+
qkv_bias=qkv_bias,
|
| 628 |
+
qk_scale=qk_scale,
|
| 629 |
+
drop=drop,
|
| 630 |
+
attn_drop=attn_drop,
|
| 631 |
+
drop_path=drop_path[i] if isinstance(
|
| 632 |
+
drop_path, list) else drop_path,
|
| 633 |
+
norm_layer=norm_layer) for i in range(depth)
|
| 634 |
+
])
|
| 635 |
+
|
| 636 |
+
# patch merging layer
|
| 637 |
+
if downsample is not None:
|
| 638 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
| 639 |
+
else:
|
| 640 |
+
self.downsample = None
|
| 641 |
+
|
| 642 |
+
def forward(self, x, H, W):
|
| 643 |
+
""" Forward function.
|
| 644 |
+
|
| 645 |
+
Args:
|
| 646 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 647 |
+
H, W: Spatial resolution of the input feature.
|
| 648 |
+
"""
|
| 649 |
+
|
| 650 |
+
# calculate attention mask for SW-MSA
|
| 651 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
| 652 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
| 653 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
| 654 |
+
h_slices = (slice(0, -self.window_size),
|
| 655 |
+
slice(-self.window_size,
|
| 656 |
+
-self.shift_size), slice(-self.shift_size, None))
|
| 657 |
+
w_slices = (slice(0, -self.window_size),
|
| 658 |
+
slice(-self.window_size,
|
| 659 |
+
-self.shift_size), slice(-self.shift_size, None))
|
| 660 |
+
cnt = 0
|
| 661 |
+
for h in h_slices:
|
| 662 |
+
for w in w_slices:
|
| 663 |
+
img_mask[:, h, w, :] = cnt
|
| 664 |
+
cnt += 1
|
| 665 |
+
|
| 666 |
+
mask_windows = window_partition(
|
| 667 |
+
img_mask, self.window_size) # nW, window_size, window_size, 1
|
| 668 |
+
mask_windows = mask_windows.view(-1,
|
| 669 |
+
self.window_size * self.window_size)
|
| 670 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 671 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
| 672 |
+
float(-100.0)).masked_fill(
|
| 673 |
+
attn_mask == 0, float(0.0))
|
| 674 |
+
|
| 675 |
+
for blk in self.blocks:
|
| 676 |
+
blk.H, blk.W = H, W
|
| 677 |
+
if self.use_checkpoint:
|
| 678 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
| 679 |
+
else:
|
| 680 |
+
x = blk(x, attn_mask)
|
| 681 |
+
if self.downsample is not None:
|
| 682 |
+
x_down = self.downsample(x, H, W)
|
| 683 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
| 684 |
+
return x, H, W, x_down, Wh, Ww
|
| 685 |
+
else:
|
| 686 |
+
return x, H, W, x, H, W
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
class PatchEmbed(nn.Module):
|
| 690 |
+
""" Image to Patch Embedding
|
| 691 |
+
|
| 692 |
+
Args:
|
| 693 |
+
patch_size (int): Patch token size. Default: 4.
|
| 694 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 695 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 696 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
| 697 |
+
"""
|
| 698 |
+
|
| 699 |
+
def __init__(self,
|
| 700 |
+
patch_size=4,
|
| 701 |
+
in_chans=3,
|
| 702 |
+
embed_dim=96,
|
| 703 |
+
norm_layer=None):
|
| 704 |
+
super().__init__()
|
| 705 |
+
patch_size = to_2tuple(patch_size)
|
| 706 |
+
self.patch_size = patch_size
|
| 707 |
+
|
| 708 |
+
self.in_chans = in_chans
|
| 709 |
+
self.embed_dim = embed_dim
|
| 710 |
+
|
| 711 |
+
self.proj = nn.Conv2d(in_chans,
|
| 712 |
+
embed_dim,
|
| 713 |
+
kernel_size=patch_size,
|
| 714 |
+
stride=patch_size)
|
| 715 |
+
if norm_layer is not None:
|
| 716 |
+
self.norm = norm_layer(embed_dim)
|
| 717 |
+
else:
|
| 718 |
+
self.norm = None
|
| 719 |
+
|
| 720 |
+
def forward(self, x):
|
| 721 |
+
"""Forward function."""
|
| 722 |
+
# padding
|
| 723 |
+
_, _, H, W = x.size()
|
| 724 |
+
if W % self.patch_size[1] != 0:
|
| 725 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
| 726 |
+
if H % self.patch_size[0] != 0:
|
| 727 |
+
x = F.pad(x,
|
| 728 |
+
(0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
| 729 |
+
|
| 730 |
+
x = self.proj(x) # B C Wh Ww
|
| 731 |
+
if self.norm is not None:
|
| 732 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 733 |
+
x = x.flatten(2).transpose(1, 2)
|
| 734 |
+
x = self.norm(x)
|
| 735 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
| 736 |
+
|
| 737 |
+
return x
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
class SwinTransformer(nn.Module):
|
| 741 |
+
""" Swin Transformer backbone.
|
| 742 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
| 743 |
+
https://arxiv.org/pdf/2103.14030
|
| 744 |
+
|
| 745 |
+
Args:
|
| 746 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
| 747 |
+
used in absolute postion embedding. Default 224.
|
| 748 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
| 749 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 750 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 751 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
| 752 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
| 753 |
+
window_size (int): Window size. Default: 7.
|
| 754 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 755 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
| 756 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
| 757 |
+
drop_rate (float): Dropout rate.
|
| 758 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
| 759 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
| 760 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
| 761 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
| 762 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
| 763 |
+
out_indices (Sequence[int]): Output from which stages.
|
| 764 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
| 765 |
+
-1 means not freezing any parameters.
|
| 766 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 767 |
+
"""
|
| 768 |
+
|
| 769 |
+
def __init__(self,
|
| 770 |
+
pretrain_img_size=224,
|
| 771 |
+
patch_size=4,
|
| 772 |
+
in_chans=3,
|
| 773 |
+
embed_dim=96,
|
| 774 |
+
depths=[2, 2, 6, 2],
|
| 775 |
+
num_heads=[3, 6, 12, 24],
|
| 776 |
+
window_size=7,
|
| 777 |
+
mlp_ratio=4.,
|
| 778 |
+
qkv_bias=True,
|
| 779 |
+
qk_scale=None,
|
| 780 |
+
drop_rate=0.,
|
| 781 |
+
attn_drop_rate=0.,
|
| 782 |
+
drop_path_rate=0.2,
|
| 783 |
+
norm_layer=nn.LayerNorm,
|
| 784 |
+
ape=False,
|
| 785 |
+
patch_norm=True,
|
| 786 |
+
out_indices=(0, 1, 2, 3),
|
| 787 |
+
frozen_stages=-1,
|
| 788 |
+
use_checkpoint=False):
|
| 789 |
+
super().__init__()
|
| 790 |
+
|
| 791 |
+
self.pretrain_img_size = pretrain_img_size
|
| 792 |
+
self.num_layers = len(depths)
|
| 793 |
+
self.embed_dim = embed_dim
|
| 794 |
+
self.ape = ape
|
| 795 |
+
self.patch_norm = patch_norm
|
| 796 |
+
self.out_indices = out_indices
|
| 797 |
+
self.frozen_stages = frozen_stages
|
| 798 |
+
|
| 799 |
+
# split image into non-overlapping patches
|
| 800 |
+
self.patch_embed = PatchEmbed(
|
| 801 |
+
patch_size=patch_size,
|
| 802 |
+
in_chans=in_chans,
|
| 803 |
+
embed_dim=embed_dim,
|
| 804 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
| 805 |
+
|
| 806 |
+
# absolute position embedding
|
| 807 |
+
if self.ape:
|
| 808 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
| 809 |
+
patch_size = to_2tuple(patch_size)
|
| 810 |
+
patches_resolution = [
|
| 811 |
+
pretrain_img_size[0] // patch_size[0],
|
| 812 |
+
pretrain_img_size[1] // patch_size[1]
|
| 813 |
+
]
|
| 814 |
+
|
| 815 |
+
self.absolute_pos_embed = nn.Parameter(
|
| 816 |
+
torch.zeros(1, embed_dim, patches_resolution[0],
|
| 817 |
+
patches_resolution[1]))
|
| 818 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
| 819 |
+
|
| 820 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 821 |
+
|
| 822 |
+
# stochastic depth
|
| 823 |
+
dpr = [
|
| 824 |
+
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
| 825 |
+
] # stochastic depth decay rule
|
| 826 |
+
|
| 827 |
+
# build layers
|
| 828 |
+
self.layers = nn.ModuleList()
|
| 829 |
+
for i_layer in range(self.num_layers):
|
| 830 |
+
layer = BasicLayer(
|
| 831 |
+
dim=int(embed_dim * 2**i_layer),
|
| 832 |
+
depth=depths[i_layer],
|
| 833 |
+
num_heads=num_heads[i_layer],
|
| 834 |
+
window_size=window_size,
|
| 835 |
+
mlp_ratio=mlp_ratio,
|
| 836 |
+
qkv_bias=qkv_bias,
|
| 837 |
+
qk_scale=qk_scale,
|
| 838 |
+
drop=drop_rate,
|
| 839 |
+
attn_drop=attn_drop_rate,
|
| 840 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
| 841 |
+
norm_layer=norm_layer,
|
| 842 |
+
downsample=PatchMerging if
|
| 843 |
+
(i_layer < self.num_layers - 1) else None,
|
| 844 |
+
use_checkpoint=use_checkpoint)
|
| 845 |
+
self.layers.append(layer)
|
| 846 |
+
|
| 847 |
+
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
|
| 848 |
+
self.num_features = num_features
|
| 849 |
+
|
| 850 |
+
# add a norm layer for each output
|
| 851 |
+
for i_layer in out_indices:
|
| 852 |
+
layer = norm_layer(num_features[i_layer])
|
| 853 |
+
layer_name = f'norm{i_layer}'
|
| 854 |
+
self.add_module(layer_name, layer)
|
| 855 |
+
|
| 856 |
+
self._freeze_stages()
|
| 857 |
+
|
| 858 |
+
def _freeze_stages(self):
|
| 859 |
+
if self.frozen_stages >= 0:
|
| 860 |
+
self.patch_embed.eval()
|
| 861 |
+
for param in self.patch_embed.parameters():
|
| 862 |
+
param.requires_grad = False
|
| 863 |
+
|
| 864 |
+
if self.frozen_stages >= 1 and self.ape:
|
| 865 |
+
self.absolute_pos_embed.requires_grad = False
|
| 866 |
+
|
| 867 |
+
if self.frozen_stages >= 2:
|
| 868 |
+
self.pos_drop.eval()
|
| 869 |
+
for i in range(0, self.frozen_stages - 1):
|
| 870 |
+
m = self.layers[i]
|
| 871 |
+
m.eval()
|
| 872 |
+
for param in m.parameters():
|
| 873 |
+
param.requires_grad = False
|
| 874 |
+
|
| 875 |
+
def init_weights(self, pretrained=None):
|
| 876 |
+
"""Initialize the weights in backbone.
|
| 877 |
+
|
| 878 |
+
Args:
|
| 879 |
+
pretrained (str, optional): Path to pre-trained weights.
|
| 880 |
+
Defaults to None.
|
| 881 |
+
"""
|
| 882 |
+
|
| 883 |
+
def _init_weights(m):
|
| 884 |
+
if isinstance(m, nn.Linear):
|
| 885 |
+
trunc_normal_(m.weight, std=.02)
|
| 886 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 887 |
+
nn.init.constant_(m.bias, 0)
|
| 888 |
+
elif isinstance(m, nn.LayerNorm):
|
| 889 |
+
nn.init.constant_(m.bias, 0)
|
| 890 |
+
nn.init.constant_(m.weight, 1.0)
|
| 891 |
+
|
| 892 |
+
if isinstance(pretrained, str):
|
| 893 |
+
self.apply(_init_weights)
|
| 894 |
+
load_checkpoint(self, pretrained, strict=False, logger=None)
|
| 895 |
+
elif pretrained is None:
|
| 896 |
+
self.apply(_init_weights)
|
| 897 |
+
else:
|
| 898 |
+
raise TypeError('pretrained must be a str or None')
|
| 899 |
+
|
| 900 |
+
def forward(self, x):
|
| 901 |
+
x = self.patch_embed(x)
|
| 902 |
+
|
| 903 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 904 |
+
if self.ape:
|
| 905 |
+
# interpolate the position embedding to the corresponding size
|
| 906 |
+
absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
|
| 907 |
+
size=(Wh, Ww),
|
| 908 |
+
mode='bicubic')
|
| 909 |
+
x = (x + absolute_pos_embed) # B Wh*Ww C
|
| 910 |
+
|
| 911 |
+
outs = [x.contiguous()]
|
| 912 |
+
x = x.flatten(2).transpose(1, 2)
|
| 913 |
+
x = self.pos_drop(x)
|
| 914 |
+
for i in range(self.num_layers):
|
| 915 |
+
layer = self.layers[i]
|
| 916 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
| 917 |
+
|
| 918 |
+
if i in self.out_indices:
|
| 919 |
+
norm_layer = getattr(self, f'norm{i}')
|
| 920 |
+
x_out = norm_layer(x_out)
|
| 921 |
+
|
| 922 |
+
out = x_out.view(-1, H, W,
|
| 923 |
+
self.num_features[i]).permute(0, 3, 1,
|
| 924 |
+
2).contiguous()
|
| 925 |
+
outs.append(out)
|
| 926 |
+
|
| 927 |
+
return tuple(outs)
|
| 928 |
+
|
| 929 |
+
def train(self, mode=True):
|
| 930 |
+
"""Convert the model into training mode while keep layers freezed."""
|
| 931 |
+
super(SwinTransformer, self).train(mode)
|
| 932 |
+
self._freeze_stages()
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
class PositionEmbeddingSine:
|
| 936 |
+
|
| 937 |
+
def __init__(self,
|
| 938 |
+
num_pos_feats=64,
|
| 939 |
+
temperature=10000,
|
| 940 |
+
normalize=False,
|
| 941 |
+
scale=None):
|
| 942 |
+
super().__init__()
|
| 943 |
+
self.num_pos_feats = num_pos_feats
|
| 944 |
+
self.temperature = temperature
|
| 945 |
+
self.normalize = normalize
|
| 946 |
+
if scale is not None and normalize is False:
|
| 947 |
+
raise ValueError("normalize should be True if scale is passed")
|
| 948 |
+
if scale is None:
|
| 949 |
+
scale = 2 * math.pi
|
| 950 |
+
self.scale = scale
|
| 951 |
+
self.dim_t = torch.arange(0,
|
| 952 |
+
self.num_pos_feats,
|
| 953 |
+
dtype=torch_dtype,
|
| 954 |
+
device=torch_device)
|
| 955 |
+
|
| 956 |
+
def __call__(self, b, h, w):
|
| 957 |
+
mask = torch.zeros([b, h, w], dtype=torch.bool, device=torch_device)
|
| 958 |
+
assert mask is not None
|
| 959 |
+
not_mask = ~mask
|
| 960 |
+
y_embed = not_mask.cumsum(dim=1, dtype=torch_dtype)
|
| 961 |
+
x_embed = not_mask.cumsum(dim=2, dtype=torch_dtype)
|
| 962 |
+
if self.normalize:
|
| 963 |
+
eps = 1e-6
|
| 964 |
+
y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) *
|
| 965 |
+
self.scale).to(device=torch_device, dtype=torch_dtype)
|
| 966 |
+
x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) *
|
| 967 |
+
self.scale).to(device=torch_device, dtype=torch_dtype)
|
| 968 |
+
|
| 969 |
+
dim_t = self.temperature**(2 * (self.dim_t // 2) / self.num_pos_feats)
|
| 970 |
+
|
| 971 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
| 972 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
| 973 |
+
pos_x = torch.stack(
|
| 974 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
|
| 975 |
+
dim=4).flatten(3)
|
| 976 |
+
pos_y = torch.stack(
|
| 977 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
|
| 978 |
+
dim=4).flatten(3)
|
| 979 |
+
return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
class MCLM(nn.Module):
|
| 983 |
+
|
| 984 |
+
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
|
| 985 |
+
super(MCLM, self).__init__()
|
| 986 |
+
self.attention = nn.ModuleList([
|
| 987 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 988 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 989 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 990 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 991 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 992 |
+
])
|
| 993 |
+
|
| 994 |
+
self.linear1 = nn.Linear(d_model, d_model * 2)
|
| 995 |
+
self.linear2 = nn.Linear(d_model * 2, d_model)
|
| 996 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 997 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 998 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 999 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 1000 |
+
self.dropout = nn.Dropout(0.1)
|
| 1001 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 1002 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 1003 |
+
self.activation = get_activation_fn('relu')
|
| 1004 |
+
self.pool_ratios = pool_ratios
|
| 1005 |
+
self.p_poses = []
|
| 1006 |
+
self.g_pos = None
|
| 1007 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 1008 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 1009 |
+
|
| 1010 |
+
def forward(self, l, g):
|
| 1011 |
+
"""
|
| 1012 |
+
l: 4,c,h,w
|
| 1013 |
+
g: 1,c,h,w
|
| 1014 |
+
"""
|
| 1015 |
+
b, c, h, w = l.size()
|
| 1016 |
+
# 4,c,h,w -> 1,c,2h,2w
|
| 1017 |
+
concated_locs = rearrange(l,
|
| 1018 |
+
'(hg wg b) c h w -> b c (hg h) (wg w)',
|
| 1019 |
+
hg=2,
|
| 1020 |
+
wg=2)
|
| 1021 |
+
|
| 1022 |
+
pools = []
|
| 1023 |
+
for pool_ratio in self.pool_ratios:
|
| 1024 |
+
# b,c,h,w
|
| 1025 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1026 |
+
pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
|
| 1027 |
+
pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
|
| 1028 |
+
if self.g_pos is None:
|
| 1029 |
+
pos_emb = self.positional_encoding(pool.shape[0],
|
| 1030 |
+
pool.shape[2],
|
| 1031 |
+
pool.shape[3])
|
| 1032 |
+
pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1033 |
+
self.p_poses.append(pos_emb)
|
| 1034 |
+
pools = torch.cat(pools, 0)
|
| 1035 |
+
if self.g_pos is None:
|
| 1036 |
+
self.p_poses = torch.cat(self.p_poses, dim=0)
|
| 1037 |
+
pos_emb = self.positional_encoding(g.shape[0], g.shape[2],
|
| 1038 |
+
g.shape[3])
|
| 1039 |
+
self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1040 |
+
|
| 1041 |
+
# attention between glb (q) & multisensory concated-locs (k,v)
|
| 1042 |
+
g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
|
| 1043 |
+
g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
|
| 1044 |
+
g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
|
| 1045 |
+
g_hw_b_c = self.norm1(g_hw_b_c)
|
| 1046 |
+
g_hw_b_c = g_hw_b_c + self.dropout2(
|
| 1047 |
+
self.linear2(
|
| 1048 |
+
self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
|
| 1049 |
+
g_hw_b_c = self.norm2(g_hw_b_c)
|
| 1050 |
+
|
| 1051 |
+
# attention between origin locs (q) & freashed glb (k,v)
|
| 1052 |
+
l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
|
| 1053 |
+
_g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
|
| 1054 |
+
_g_hw_b_c = rearrange(_g_hw_b_c,
|
| 1055 |
+
"(ng h) (nw w) b c -> (h w) (ng nw b) c",
|
| 1056 |
+
ng=2,
|
| 1057 |
+
nw=2)
|
| 1058 |
+
outputs_re = []
|
| 1059 |
+
for i, (_l, _g) in enumerate(
|
| 1060 |
+
zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
|
| 1061 |
+
outputs_re.append(self.attention[i + 1](_l, _g,
|
| 1062 |
+
_g)[0]) # (h w) 1 c
|
| 1063 |
+
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
|
| 1064 |
+
|
| 1065 |
+
l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
|
| 1066 |
+
l_hw_b_c = self.norm1(l_hw_b_c)
|
| 1067 |
+
l_hw_b_c = l_hw_b_c + self.dropout2(
|
| 1068 |
+
self.linear4(
|
| 1069 |
+
self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
|
| 1070 |
+
l_hw_b_c = self.norm2(l_hw_b_c)
|
| 1071 |
+
|
| 1072 |
+
l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
|
| 1073 |
+
return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
|
| 1074 |
+
|
| 1075 |
+
|
| 1076 |
+
class inf_MCLM(nn.Module):
|
| 1077 |
+
|
| 1078 |
+
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
|
| 1079 |
+
super(inf_MCLM, self).__init__()
|
| 1080 |
+
self.attention = nn.ModuleList([
|
| 1081 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1082 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1083 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1084 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1085 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 1086 |
+
])
|
| 1087 |
+
|
| 1088 |
+
self.linear1 = nn.Linear(d_model, d_model * 2)
|
| 1089 |
+
self.linear2 = nn.Linear(d_model * 2, d_model)
|
| 1090 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 1091 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 1092 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 1093 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 1094 |
+
self.dropout = nn.Dropout(0.1)
|
| 1095 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 1096 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 1097 |
+
self.activation = get_activation_fn('relu')
|
| 1098 |
+
self.pool_ratios = pool_ratios
|
| 1099 |
+
self.p_poses = []
|
| 1100 |
+
self.g_pos = None
|
| 1101 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 1102 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 1103 |
+
|
| 1104 |
+
def forward(self, l, g):
|
| 1105 |
+
"""
|
| 1106 |
+
l: 4,c,h,w
|
| 1107 |
+
g: 1,c,h,w
|
| 1108 |
+
"""
|
| 1109 |
+
b, c, h, w = l.size()
|
| 1110 |
+
# 4,c,h,w -> 1,c,2h,2w
|
| 1111 |
+
concated_locs = rearrange(l,
|
| 1112 |
+
'(hg wg b) c h w -> b c (hg h) (wg w)',
|
| 1113 |
+
hg=2,
|
| 1114 |
+
wg=2)
|
| 1115 |
+
self.p_poses = []
|
| 1116 |
+
pools = []
|
| 1117 |
+
for pool_ratio in self.pool_ratios:
|
| 1118 |
+
# b,c,h,w
|
| 1119 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1120 |
+
pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
|
| 1121 |
+
pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
|
| 1122 |
+
# if self.g_pos is None:
|
| 1123 |
+
pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2],
|
| 1124 |
+
pool.shape[3])
|
| 1125 |
+
pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1126 |
+
self.p_poses.append(pos_emb)
|
| 1127 |
+
pools = torch.cat(pools, 0)
|
| 1128 |
+
# if self.g_pos is None:
|
| 1129 |
+
self.p_poses = torch.cat(self.p_poses, dim=0)
|
| 1130 |
+
pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
|
| 1131 |
+
self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
| 1132 |
+
|
| 1133 |
+
# attention between glb (q) & multisensory concated-locs (k,v)
|
| 1134 |
+
g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
|
| 1135 |
+
g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
|
| 1136 |
+
g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
|
| 1137 |
+
g_hw_b_c = self.norm1(g_hw_b_c)
|
| 1138 |
+
g_hw_b_c = g_hw_b_c + self.dropout2(
|
| 1139 |
+
self.linear2(
|
| 1140 |
+
self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
|
| 1141 |
+
g_hw_b_c = self.norm2(g_hw_b_c)
|
| 1142 |
+
|
| 1143 |
+
# attention between origin locs (q) & freashed glb (k,v)
|
| 1144 |
+
l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
|
| 1145 |
+
_g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
|
| 1146 |
+
_g_hw_b_c = rearrange(_g_hw_b_c,
|
| 1147 |
+
"(ng h) (nw w) b c -> (h w) (ng nw b) c",
|
| 1148 |
+
ng=2,
|
| 1149 |
+
nw=2)
|
| 1150 |
+
outputs_re = []
|
| 1151 |
+
for i, (_l, _g) in enumerate(
|
| 1152 |
+
zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
|
| 1153 |
+
outputs_re.append(self.attention[i + 1](_l, _g,
|
| 1154 |
+
_g)[0]) # (h w) 1 c
|
| 1155 |
+
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
|
| 1156 |
+
|
| 1157 |
+
l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
|
| 1158 |
+
l_hw_b_c = self.norm1(l_hw_b_c)
|
| 1159 |
+
l_hw_b_c = l_hw_b_c + self.dropout2(
|
| 1160 |
+
self.linear4(
|
| 1161 |
+
self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
|
| 1162 |
+
l_hw_b_c = self.norm2(l_hw_b_c)
|
| 1163 |
+
|
| 1164 |
+
l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
|
| 1165 |
+
return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
|
| 1166 |
+
|
| 1167 |
+
|
| 1168 |
+
class MCRM(nn.Module):
|
| 1169 |
+
|
| 1170 |
+
def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
|
| 1171 |
+
super(MCRM, self).__init__()
|
| 1172 |
+
self.attention = nn.ModuleList([
|
| 1173 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1174 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1175 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1176 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 1177 |
+
])
|
| 1178 |
+
|
| 1179 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 1180 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 1181 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 1182 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 1183 |
+
self.dropout = nn.Dropout(0.1)
|
| 1184 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 1185 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 1186 |
+
self.sigmoid = nn.Sigmoid()
|
| 1187 |
+
self.activation = get_activation_fn('relu')
|
| 1188 |
+
self.sal_conv = nn.Conv2d(d_model, 1, 1)
|
| 1189 |
+
self.pool_ratios = pool_ratios
|
| 1190 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 1191 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 1192 |
+
|
| 1193 |
+
def forward(self, x):
|
| 1194 |
+
b, c, h, w = x.size()
|
| 1195 |
+
loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
|
| 1196 |
+
# b(4),c,h,w
|
| 1197 |
+
patched_glb = rearrange(glb,
|
| 1198 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1199 |
+
hg=2,
|
| 1200 |
+
wg=2)
|
| 1201 |
+
|
| 1202 |
+
# generate token attention map
|
| 1203 |
+
token_attention_map = self.sigmoid(self.sal_conv(glb))
|
| 1204 |
+
token_attention_map = F.interpolate(token_attention_map,
|
| 1205 |
+
size=patches2image(loc).shape[-2:],
|
| 1206 |
+
mode='nearest')
|
| 1207 |
+
loc = loc * rearrange(token_attention_map,
|
| 1208 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1209 |
+
hg=2,
|
| 1210 |
+
wg=2)
|
| 1211 |
+
pools = []
|
| 1212 |
+
for pool_ratio in self.pool_ratios:
|
| 1213 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1214 |
+
pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
|
| 1215 |
+
pools.append(rearrange(pool,
|
| 1216 |
+
'nl c h w -> nl c (h w)')) # nl(4),c,hw
|
| 1217 |
+
# nl(4),c,nphw -> nl(4),nphw,1,c
|
| 1218 |
+
pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
|
| 1219 |
+
loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
|
| 1220 |
+
outputs = []
|
| 1221 |
+
for i, q in enumerate(
|
| 1222 |
+
loc_.unbind(dim=0)): # traverse all local patches
|
| 1223 |
+
# np*hw,1,c
|
| 1224 |
+
v = pools[i]
|
| 1225 |
+
k = v
|
| 1226 |
+
outputs.append(self.attention[i](q, k, v)[0])
|
| 1227 |
+
outputs = torch.cat(outputs, 1)
|
| 1228 |
+
src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
|
| 1229 |
+
src = self.norm1(src)
|
| 1230 |
+
src = src + self.dropout2(
|
| 1231 |
+
self.linear4(
|
| 1232 |
+
self.dropout(self.activation(self.linear3(src)).clone())))
|
| 1233 |
+
src = self.norm2(src)
|
| 1234 |
+
|
| 1235 |
+
src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
|
| 1236 |
+
glb = glb + F.interpolate(patches2image(src),
|
| 1237 |
+
size=glb.shape[-2:],
|
| 1238 |
+
mode='nearest') # freshed glb
|
| 1239 |
+
return torch.cat((src, glb), 0), token_attention_map
|
| 1240 |
+
|
| 1241 |
+
|
| 1242 |
+
class inf_MCRM(nn.Module):
|
| 1243 |
+
|
| 1244 |
+
def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
|
| 1245 |
+
super(inf_MCRM, self).__init__()
|
| 1246 |
+
self.attention = nn.ModuleList([
|
| 1247 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1248 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1249 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
| 1250 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
| 1251 |
+
])
|
| 1252 |
+
|
| 1253 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
| 1254 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
| 1255 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 1256 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 1257 |
+
self.dropout = nn.Dropout(0.1)
|
| 1258 |
+
self.dropout1 = nn.Dropout(0.1)
|
| 1259 |
+
self.dropout2 = nn.Dropout(0.1)
|
| 1260 |
+
self.sigmoid = nn.Sigmoid()
|
| 1261 |
+
self.activation = get_activation_fn('relu')
|
| 1262 |
+
self.sal_conv = nn.Conv2d(d_model, 1, 1)
|
| 1263 |
+
self.pool_ratios = pool_ratios
|
| 1264 |
+
self.positional_encoding = PositionEmbeddingSine(
|
| 1265 |
+
num_pos_feats=d_model // 2, normalize=True)
|
| 1266 |
+
|
| 1267 |
+
def forward(self, x):
|
| 1268 |
+
b, c, h, w = x.size()
|
| 1269 |
+
loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
|
| 1270 |
+
# b(4),c,h,w
|
| 1271 |
+
patched_glb = rearrange(glb,
|
| 1272 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1273 |
+
hg=2,
|
| 1274 |
+
wg=2)
|
| 1275 |
+
|
| 1276 |
+
# generate token attention map
|
| 1277 |
+
token_attention_map = self.sigmoid(self.sal_conv(glb))
|
| 1278 |
+
token_attention_map = F.interpolate(token_attention_map,
|
| 1279 |
+
size=patches2image(loc).shape[-2:],
|
| 1280 |
+
mode='nearest')
|
| 1281 |
+
loc = loc * rearrange(token_attention_map,
|
| 1282 |
+
'b c (hg h) (wg w) -> (hg wg b) c h w',
|
| 1283 |
+
hg=2,
|
| 1284 |
+
wg=2)
|
| 1285 |
+
pools = []
|
| 1286 |
+
for pool_ratio in self.pool_ratios:
|
| 1287 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
| 1288 |
+
pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
|
| 1289 |
+
pools.append(rearrange(pool,
|
| 1290 |
+
'nl c h w -> nl c (h w)')) # nl(4),c,hw
|
| 1291 |
+
# nl(4),c,nphw -> nl(4),nphw,1,c
|
| 1292 |
+
pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
|
| 1293 |
+
loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
|
| 1294 |
+
outputs = []
|
| 1295 |
+
for i, q in enumerate(
|
| 1296 |
+
loc_.unbind(dim=0)): # traverse all local patches
|
| 1297 |
+
# np*hw,1,c
|
| 1298 |
+
v = pools[i]
|
| 1299 |
+
k = v
|
| 1300 |
+
outputs.append(self.attention[i](q, k, v)[0])
|
| 1301 |
+
outputs = torch.cat(outputs, 1)
|
| 1302 |
+
src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
|
| 1303 |
+
src = self.norm1(src)
|
| 1304 |
+
src = src + self.dropout2(
|
| 1305 |
+
self.linear4(
|
| 1306 |
+
self.dropout(self.activation(self.linear3(src)).clone())))
|
| 1307 |
+
src = self.norm2(src)
|
| 1308 |
+
|
| 1309 |
+
src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
|
| 1310 |
+
glb = glb + F.interpolate(patches2image(src),
|
| 1311 |
+
size=glb.shape[-2:],
|
| 1312 |
+
mode='nearest') # freshed glb
|
| 1313 |
+
return torch.cat((src, glb), 0)
|
| 1314 |
+
|
| 1315 |
+
|
| 1316 |
+
# model for single-scale training
|
| 1317 |
+
class MVANet(nn.Module):
|
| 1318 |
+
|
| 1319 |
+
def __init__(self):
|
| 1320 |
+
super().__init__()
|
| 1321 |
+
self.backbone = SwinB(pretrained=True)
|
| 1322 |
+
emb_dim = 128
|
| 1323 |
+
self.sideout5 = nn.Sequential(
|
| 1324 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1325 |
+
self.sideout4 = nn.Sequential(
|
| 1326 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1327 |
+
self.sideout3 = nn.Sequential(
|
| 1328 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1329 |
+
self.sideout2 = nn.Sequential(
|
| 1330 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1331 |
+
self.sideout1 = nn.Sequential(
|
| 1332 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1333 |
+
|
| 1334 |
+
self.output5 = make_cbr(1024, emb_dim)
|
| 1335 |
+
self.output4 = make_cbr(512, emb_dim)
|
| 1336 |
+
self.output3 = make_cbr(256, emb_dim)
|
| 1337 |
+
self.output2 = make_cbr(128, emb_dim)
|
| 1338 |
+
self.output1 = make_cbr(128, emb_dim)
|
| 1339 |
+
|
| 1340 |
+
self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
|
| 1341 |
+
self.conv1 = make_cbr(emb_dim, emb_dim)
|
| 1342 |
+
self.conv2 = make_cbr(emb_dim, emb_dim)
|
| 1343 |
+
self.conv3 = make_cbr(emb_dim, emb_dim)
|
| 1344 |
+
self.conv4 = make_cbr(emb_dim, emb_dim)
|
| 1345 |
+
self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1346 |
+
self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1347 |
+
self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1348 |
+
self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
|
| 1349 |
+
|
| 1350 |
+
self.insmask_head = nn.Sequential(
|
| 1351 |
+
nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
|
| 1352 |
+
nn.BatchNorm2d(384), nn.PReLU(),
|
| 1353 |
+
nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
|
| 1354 |
+
nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
|
| 1355 |
+
|
| 1356 |
+
self.shallow = nn.Sequential(
|
| 1357 |
+
nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
|
| 1358 |
+
self.upsample1 = make_cbg(emb_dim, emb_dim)
|
| 1359 |
+
self.upsample2 = make_cbg(emb_dim, emb_dim)
|
| 1360 |
+
self.output = nn.Sequential(
|
| 1361 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1362 |
+
|
| 1363 |
+
for m in self.modules():
|
| 1364 |
+
if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
|
| 1365 |
+
m.inplace = True
|
| 1366 |
+
|
| 1367 |
+
def forward(self, x):
|
| 1368 |
+
x = x.to(dtype=torch_dtype, device=torch_device)
|
| 1369 |
+
shallow = self.shallow(x)
|
| 1370 |
+
glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
|
| 1371 |
+
loc = image2patches(x)
|
| 1372 |
+
input = torch.cat((loc, glb), dim=0)
|
| 1373 |
+
feature = self.backbone(input)
|
| 1374 |
+
e5 = self.output5(feature[4]) # (5,128,16,16)
|
| 1375 |
+
e4 = self.output4(feature[3]) # (5,128,32,32)
|
| 1376 |
+
e3 = self.output3(feature[2]) # (5,128,64,64)
|
| 1377 |
+
e2 = self.output2(feature[1]) # (5,128,128,128)
|
| 1378 |
+
e1 = self.output1(feature[0]) # (5,128,128,128)
|
| 1379 |
+
loc_e5, glb_e5 = e5.split([4, 1], dim=0)
|
| 1380 |
+
e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
|
| 1381 |
+
|
| 1382 |
+
e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
|
| 1383 |
+
e4 = self.conv4(e4)
|
| 1384 |
+
e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
|
| 1385 |
+
e3 = self.conv3(e3)
|
| 1386 |
+
e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
|
| 1387 |
+
e2 = self.conv2(e2)
|
| 1388 |
+
e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
|
| 1389 |
+
e1 = self.conv1(e1)
|
| 1390 |
+
loc_e1, glb_e1 = e1.split([4, 1], dim=0)
|
| 1391 |
+
output1_cat = patches2image(loc_e1) # (1,128,256,256)
|
| 1392 |
+
# add glb feat in
|
| 1393 |
+
output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
|
| 1394 |
+
# merge
|
| 1395 |
+
final_output = self.insmask_head(output1_cat) # (1,128,256,256)
|
| 1396 |
+
# shallow feature merge
|
| 1397 |
+
final_output = final_output + resize_as(shallow, final_output)
|
| 1398 |
+
final_output = self.upsample1(rescale_to(final_output))
|
| 1399 |
+
final_output = rescale_to(final_output +
|
| 1400 |
+
resize_as(shallow, final_output))
|
| 1401 |
+
final_output = self.upsample2(final_output)
|
| 1402 |
+
final_output = self.output(final_output)
|
| 1403 |
+
####
|
| 1404 |
+
sideout5 = self.sideout5(e5).to(dtype=torch_dtype, device=torch_device)
|
| 1405 |
+
sideout4 = self.sideout4(e4)
|
| 1406 |
+
sideout3 = self.sideout3(e3)
|
| 1407 |
+
sideout2 = self.sideout2(e2)
|
| 1408 |
+
sideout1 = self.sideout1(e1)
|
| 1409 |
+
#######glb_sideouts ######
|
| 1410 |
+
glb5 = self.sideout5(glb_e5)
|
| 1411 |
+
glb4 = sideout4[-1, :, :, :].unsqueeze(0)
|
| 1412 |
+
glb3 = sideout3[-1, :, :, :].unsqueeze(0)
|
| 1413 |
+
glb2 = sideout2[-1, :, :, :].unsqueeze(0)
|
| 1414 |
+
glb1 = sideout1[-1, :, :, :].unsqueeze(0)
|
| 1415 |
+
####### concat 4 to 1 #######
|
| 1416 |
+
sideout1 = patches2image(sideout1[:-1]).to(dtype=torch_dtype,
|
| 1417 |
+
device=torch_device)
|
| 1418 |
+
sideout2 = patches2image(sideout2[:-1]).to(
|
| 1419 |
+
dtype=torch_dtype,
|
| 1420 |
+
device=torch_device) ####(5,c,h,w) -> (1 c 2h,2w)
|
| 1421 |
+
sideout3 = patches2image(sideout3[:-1]).to(dtype=torch_dtype,
|
| 1422 |
+
device=torch_device)
|
| 1423 |
+
sideout4 = patches2image(sideout4[:-1]).to(dtype=torch_dtype,
|
| 1424 |
+
device=torch_device)
|
| 1425 |
+
sideout5 = patches2image(sideout5[:-1]).to(dtype=torch_dtype,
|
| 1426 |
+
device=torch_device)
|
| 1427 |
+
if self.training:
|
| 1428 |
+
return sideout5, sideout4, sideout3, sideout2, sideout1, final_output, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1
|
| 1429 |
+
else:
|
| 1430 |
+
return final_output
|
| 1431 |
+
|
| 1432 |
+
|
| 1433 |
+
# model for multi-scale testing
|
| 1434 |
+
class inf_MVANet(nn.Module):
|
| 1435 |
+
|
| 1436 |
+
def __init__(self):
|
| 1437 |
+
super().__init__()
|
| 1438 |
+
# self.backbone = SwinB(pretrained=True)
|
| 1439 |
+
self.backbone = SwinB(pretrained=False)
|
| 1440 |
+
|
| 1441 |
+
emb_dim = 128
|
| 1442 |
+
self.output5 = make_cbr(1024, emb_dim)
|
| 1443 |
+
self.output4 = make_cbr(512, emb_dim)
|
| 1444 |
+
self.output3 = make_cbr(256, emb_dim)
|
| 1445 |
+
self.output2 = make_cbr(128, emb_dim)
|
| 1446 |
+
self.output1 = make_cbr(128, emb_dim)
|
| 1447 |
+
|
| 1448 |
+
self.multifieldcrossatt = inf_MCLM(emb_dim, 1, [1, 4, 8])
|
| 1449 |
+
self.conv1 = make_cbr(emb_dim, emb_dim)
|
| 1450 |
+
self.conv2 = make_cbr(emb_dim, emb_dim)
|
| 1451 |
+
self.conv3 = make_cbr(emb_dim, emb_dim)
|
| 1452 |
+
self.conv4 = make_cbr(emb_dim, emb_dim)
|
| 1453 |
+
self.dec_blk1 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1454 |
+
self.dec_blk2 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1455 |
+
self.dec_blk3 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1456 |
+
self.dec_blk4 = inf_MCRM(emb_dim, 1, [2, 4, 8])
|
| 1457 |
+
|
| 1458 |
+
self.insmask_head = nn.Sequential(
|
| 1459 |
+
nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
|
| 1460 |
+
nn.BatchNorm2d(384), nn.PReLU(),
|
| 1461 |
+
nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
|
| 1462 |
+
nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
|
| 1463 |
+
|
| 1464 |
+
self.shallow = nn.Sequential(
|
| 1465 |
+
nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
|
| 1466 |
+
self.upsample1 = make_cbg(emb_dim, emb_dim)
|
| 1467 |
+
self.upsample2 = make_cbg(emb_dim, emb_dim)
|
| 1468 |
+
self.output = nn.Sequential(
|
| 1469 |
+
nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
| 1470 |
+
|
| 1471 |
+
for m in self.modules():
|
| 1472 |
+
if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
|
| 1473 |
+
m.inplace = True
|
| 1474 |
+
|
| 1475 |
+
def forward(self, x):
|
| 1476 |
+
shallow = self.shallow(x)
|
| 1477 |
+
glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
|
| 1478 |
+
loc = image2patches(x)
|
| 1479 |
+
input = torch.cat((loc, glb), dim=0)
|
| 1480 |
+
feature = self.backbone(input)
|
| 1481 |
+
e5 = self.output5(feature[4])
|
| 1482 |
+
e4 = self.output4(feature[3])
|
| 1483 |
+
e3 = self.output3(feature[2])
|
| 1484 |
+
e2 = self.output2(feature[1])
|
| 1485 |
+
e1 = self.output1(feature[0])
|
| 1486 |
+
loc_e5, glb_e5 = e5.split([4, 1], dim=0)
|
| 1487 |
+
e5_cat = self.multifieldcrossatt(loc_e5, glb_e5)
|
| 1488 |
+
|
| 1489 |
+
e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
|
| 1490 |
+
e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
|
| 1491 |
+
e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
|
| 1492 |
+
e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))
|
| 1493 |
+
loc_e1, glb_e1 = e1.split([4, 1], dim=0)
|
| 1494 |
+
# after decoder, concat loc features to a whole one, and merge
|
| 1495 |
+
output1_cat = patches2image(loc_e1)
|
| 1496 |
+
# add glb feat in
|
| 1497 |
+
output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
|
| 1498 |
+
# merge
|
| 1499 |
+
final_output = self.insmask_head(output1_cat)
|
| 1500 |
+
# shallow feature merge
|
| 1501 |
+
final_output = final_output + resize_as(shallow, final_output)
|
| 1502 |
+
final_output = self.upsample1(rescale_to(final_output))
|
| 1503 |
+
final_output = rescale_to(final_output +
|
| 1504 |
+
resize_as(shallow, final_output))
|
| 1505 |
+
final_output = self.upsample2(final_output)
|
| 1506 |
+
final_output = self.output(final_output)
|
| 1507 |
+
return final_output
|
| 1508 |
+
#+end_src
|
| 1509 |
+
|
| 1510 |
+
** Function to load model
|
| 1511 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
|
| 1512 |
+
def load_model(model_checkpoint_path):
|
| 1513 |
+
torch.cuda.set_device(0)
|
| 1514 |
+
|
| 1515 |
+
net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
|
| 1516 |
+
|
| 1517 |
+
pretrained_dict = torch.load(model_checkpoint_path,
|
| 1518 |
+
map_location=torch_device)
|
| 1519 |
+
|
| 1520 |
+
model_dict = net.state_dict()
|
| 1521 |
+
pretrained_dict = {
|
| 1522 |
+
k: v
|
| 1523 |
+
for k, v in pretrained_dict.items() if k in model_dict
|
| 1524 |
+
}
|
| 1525 |
+
model_dict.update(pretrained_dict)
|
| 1526 |
+
net.load_state_dict(model_dict)
|
| 1527 |
+
net = net.to(dtype=torch_dtype, device=torch_device)
|
| 1528 |
+
net.eval()
|
| 1529 |
+
return net
|
| 1530 |
+
|
| 1531 |
+
|
| 1532 |
+
def load_transforms_stripped():
|
| 1533 |
+
img_transform = transforms.Compose([
|
| 1534 |
+
# transforms.ToTensor(),
|
| 1535 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 1536 |
+
])
|
| 1537 |
+
|
| 1538 |
+
return img_transform
|
| 1539 |
+
|
| 1540 |
+
|
| 1541 |
+
def load_transforms():
|
| 1542 |
+
img_transform = transforms.Compose([
|
| 1543 |
+
# transforms.ToTensor(),
|
| 1544 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 1545 |
+
])
|
| 1546 |
+
|
| 1547 |
+
depth_transform = transforms.ToTensor()
|
| 1548 |
+
target_transform = transforms.ToTensor()
|
| 1549 |
+
to_pil = transforms.ToPILImage()
|
| 1550 |
+
|
| 1551 |
+
transforms_var = tta.Compose([
|
| 1552 |
+
tta.HorizontalFlip(),
|
| 1553 |
+
tta.Scale(scales=[0.75, 1, 1.25],
|
| 1554 |
+
interpolation='bilinear',
|
| 1555 |
+
align_corners=False),
|
| 1556 |
+
])
|
| 1557 |
+
|
| 1558 |
+
return (img_transform, depth_transform, target_transform, to_pil,
|
| 1559 |
+
transforms_var)
|
| 1560 |
+
#+end_src
|
| 1561 |
+
|
| 1562 |
+
** Function for modular inference CV
|
| 1563 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
|
| 1564 |
+
def do_infer_tensor2tensor(img, net):
|
| 1565 |
+
|
| 1566 |
+
img_transform = transforms.Compose(
|
| 1567 |
+
[transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
| 1568 |
+
|
| 1569 |
+
h_, w_ = img.shape[1], img.shape[2]
|
| 1570 |
+
|
| 1571 |
+
with torch.no_grad():
|
| 1572 |
+
|
| 1573 |
+
img = rearrange(img, 'B H W C -> B C H W')
|
| 1574 |
+
|
| 1575 |
+
img_resize = torch.nn.functional.interpolate(input=img,
|
| 1576 |
+
size=(1024, 1024),
|
| 1577 |
+
mode='bicubic',
|
| 1578 |
+
antialias=True)
|
| 1579 |
+
|
| 1580 |
+
img_var = img_transform(img_resize)
|
| 1581 |
+
img_var = Variable(img_var)
|
| 1582 |
+
img_var = img_var.to(dtype=torch_dtype, device=torch_device)
|
| 1583 |
+
|
| 1584 |
+
mask = []
|
| 1585 |
+
|
| 1586 |
+
mask.append(net(img_var))
|
| 1587 |
+
|
| 1588 |
+
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
|
| 1589 |
+
prediction = prediction.sigmoid()
|
| 1590 |
+
|
| 1591 |
+
prediction = torch.nn.functional.interpolate(input=prediction,
|
| 1592 |
+
size=(h_, w_),
|
| 1593 |
+
mode='bicubic',
|
| 1594 |
+
antialias=True)
|
| 1595 |
+
|
| 1596 |
+
prediction = prediction.squeeze(0)
|
| 1597 |
+
prediction = prediction.clamp(0, 1)
|
| 1598 |
+
|
| 1599 |
+
return prediction
|
| 1600 |
+
|
| 1601 |
+
|
| 1602 |
+
def do_infer_modular_cv(input_image_path, output_mask_path, net,
|
| 1603 |
+
all_transforms):
|
| 1604 |
+
|
| 1605 |
+
(img_transform, depth_transform, target_transform, to_pil,
|
| 1606 |
+
transforms_var) = all_transforms
|
| 1607 |
+
|
| 1608 |
+
img = load_image_torch(input_image_path)
|
| 1609 |
+
|
| 1610 |
+
h_, w_ = img.shape[1], img.shape[2]
|
| 1611 |
+
|
| 1612 |
+
with torch.no_grad():
|
| 1613 |
+
|
| 1614 |
+
img = rearrange(img, 'B H W C -> B C H W')
|
| 1615 |
+
|
| 1616 |
+
img_resize = torch.nn.functional.interpolate(input=img,
|
| 1617 |
+
size=(1024, 1024),
|
| 1618 |
+
mode='bicubic',
|
| 1619 |
+
antialias=True)
|
| 1620 |
+
|
| 1621 |
+
img_var = img_transform(img_resize)
|
| 1622 |
+
img_var = Variable(img_var)
|
| 1623 |
+
img_var = img_var.to(dtype=torch_dtype, device=torch_device)
|
| 1624 |
+
|
| 1625 |
+
mask = []
|
| 1626 |
+
|
| 1627 |
+
for transformer in transforms_var:
|
| 1628 |
+
rgb_trans = img_var.to(dtype=torch_dtype, device=torch_device)
|
| 1629 |
+
mask.append(net(rgb_trans))
|
| 1630 |
+
|
| 1631 |
+
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
|
| 1632 |
+
prediction = prediction.sigmoid()
|
| 1633 |
+
|
| 1634 |
+
prediction = torch.nn.functional.interpolate(input=prediction,
|
| 1635 |
+
size=(h_, w_),
|
| 1636 |
+
mode='bicubic',
|
| 1637 |
+
antialias=True)
|
| 1638 |
+
|
| 1639 |
+
prediction = prediction.squeeze(0)
|
| 1640 |
+
prediction = prediction.clamp(0, 1)
|
| 1641 |
+
|
| 1642 |
+
save_mask_torch(output_image_path=output_mask_path, mask=prediction)
|
| 1643 |
+
|
| 1644 |
+
|
| 1645 |
+
def do_infer_modular_cv_2(input_image_path, output_mask_path, net,
|
| 1646 |
+
all_transforms):
|
| 1647 |
+
|
| 1648 |
+
(img_transform, depth_transform, target_transform, to_pil,
|
| 1649 |
+
transforms_var) = all_transforms
|
| 1650 |
+
|
| 1651 |
+
img = load_image(input_image_path)
|
| 1652 |
+
w_, h_ = img.shape[0], img.shape[1]
|
| 1653 |
+
img_resize = cv2.resize(img, (1024, 1024), cv2.INTER_CUBIC)
|
| 1654 |
+
|
| 1655 |
+
with torch.no_grad():
|
| 1656 |
+
|
| 1657 |
+
# rgb_png_path = input_image_path
|
| 1658 |
+
# img = Image.open(rgb_png_path).convert('RGB')
|
| 1659 |
+
# w_, h_ = img.size
|
| 1660 |
+
|
| 1661 |
+
# img_resize = img.resize([256 * 4, 256 * 4], Image.BILINEAR)
|
| 1662 |
+
|
| 1663 |
+
# img_var = Variable(img_transform(img_resize).unsqueeze(0)).to(
|
| 1664 |
+
# dtype=torch_dtype, device=torch_device)
|
| 1665 |
+
|
| 1666 |
+
img_resize = torch.from_numpy(img_resize)
|
| 1667 |
+
img_resize = img_resize.to(dtype=torch.float32)
|
| 1668 |
+
img_resize /= 255.0
|
| 1669 |
+
img_resize = rearrange(img_resize, 'H W C -> C H W')
|
| 1670 |
+
img_var = img_transform(img_resize)
|
| 1671 |
+
img_var = img_var.unsqueeze(0)
|
| 1672 |
+
img_var = Variable(img_var)
|
| 1673 |
+
img_var = img_var.to(dtype=torch_dtype, device=torch_device)
|
| 1674 |
+
|
| 1675 |
+
mask = []
|
| 1676 |
+
|
| 1677 |
+
for transformer in transforms_var:
|
| 1678 |
+
rgb_trans = transformer.augment_image(img_var)
|
| 1679 |
+
rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
|
| 1680 |
+
model_output = net(rgb_trans)
|
| 1681 |
+
deaug_mask = transformer.deaugment_mask(model_output)
|
| 1682 |
+
mask.append(deaug_mask)
|
| 1683 |
+
|
| 1684 |
+
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
|
| 1685 |
+
prediction = prediction.sigmoid()
|
| 1686 |
+
prediction = to_pil(prediction.data.squeeze(0).cpu())
|
| 1687 |
+
prediction = prediction.resize((w_, h_), Image.BILINEAR)
|
| 1688 |
+
prediction.save(output_mask_path)
|
| 1689 |
+
|
| 1690 |
+
|
| 1691 |
+
def do_infer_modular_cv_3(input_image_path, output_mask_path, net,
|
| 1692 |
+
all_transforms):
|
| 1693 |
+
|
| 1694 |
+
(img_transform, depth_transform, target_transform, to_pil,
|
| 1695 |
+
transforms_var) = all_transforms
|
| 1696 |
+
|
| 1697 |
+
img = load_image(input_image_path)
|
| 1698 |
+
w_, h_ = img.shape[0], img.shape[1]
|
| 1699 |
+
|
| 1700 |
+
with torch.no_grad():
|
| 1701 |
+
|
| 1702 |
+
# rgb_png_path = input_image_path
|
| 1703 |
+
# img = Image.open(rgb_png_path).convert('RGB')
|
| 1704 |
+
# w_, h_ = img.size
|
| 1705 |
+
|
| 1706 |
+
# img_resize = img.resize([256 * 4, 256 * 4], Image.BILINEAR)
|
| 1707 |
+
|
| 1708 |
+
# img_var = Variable(img_transform(img_resize).unsqueeze(0)).to(
|
| 1709 |
+
# dtype=torch_dtype, device=torch_device)
|
| 1710 |
+
|
| 1711 |
+
img_resize = torch.from_numpy(img)
|
| 1712 |
+
img_resize = img_resize.to(dtype=torch.float32)
|
| 1713 |
+
img_resize = rearrange(img_resize, 'H W C -> C H W')
|
| 1714 |
+
img_resize = img_resize.unsqueeze(0)
|
| 1715 |
+
|
| 1716 |
+
img_resize = torch.nn.functional.interpolate(input=img_resize,
|
| 1717 |
+
size=(1024, 1024),
|
| 1718 |
+
mode='bicubic',
|
| 1719 |
+
antialias=True)
|
| 1720 |
+
|
| 1721 |
+
img_resize = img_resize.squeeze(0)
|
| 1722 |
+
img_resize = rearrange(img_resize, 'C H W -> H W C')
|
| 1723 |
+
|
| 1724 |
+
img_resize = img_resize.to(dtype=torch.float32)
|
| 1725 |
+
img_resize /= 255.0
|
| 1726 |
+
img_resize = rearrange(img_resize, 'H W C -> C H W')
|
| 1727 |
+
img_var = img_transform(img_resize)
|
| 1728 |
+
img_var = img_var.unsqueeze(0)
|
| 1729 |
+
img_var = Variable(img_var)
|
| 1730 |
+
img_var = img_var.to(dtype=torch_dtype, device=torch_device)
|
| 1731 |
+
|
| 1732 |
+
mask = []
|
| 1733 |
+
|
| 1734 |
+
for transformer in transforms_var:
|
| 1735 |
+
rgb_trans = transformer.augment_image(img_var)
|
| 1736 |
+
rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
|
| 1737 |
+
model_output = net(rgb_trans)
|
| 1738 |
+
deaug_mask = transformer.deaugment_mask(model_output)
|
| 1739 |
+
mask.append(deaug_mask)
|
| 1740 |
+
|
| 1741 |
+
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
|
| 1742 |
+
prediction = prediction.sigmoid()
|
| 1743 |
+
prediction = to_pil(prediction.data.squeeze(0).cpu())
|
| 1744 |
+
prediction = prediction.resize((w_, h_), Image.BILINEAR)
|
| 1745 |
+
prediction.save(output_mask_path)
|
| 1746 |
+
|
| 1747 |
+
|
| 1748 |
+
def do_infer_modular_cv_4(input_image_path, output_mask_path, net,
|
| 1749 |
+
all_transforms):
|
| 1750 |
+
|
| 1751 |
+
(img_transform, depth_transform, target_transform, to_pil,
|
| 1752 |
+
transforms_var) = all_transforms
|
| 1753 |
+
|
| 1754 |
+
img = load_image(input_image_path)
|
| 1755 |
+
w_, h_ = img.shape[0], img.shape[1]
|
| 1756 |
+
|
| 1757 |
+
with torch.no_grad():
|
| 1758 |
+
|
| 1759 |
+
img_resize = torch.from_numpy(img)
|
| 1760 |
+
img_resize = img_resize.to(dtype=torch.float32)
|
| 1761 |
+
img_resize /= 255.0
|
| 1762 |
+
img_resize = img_resize.unsqueeze(0)
|
| 1763 |
+
|
| 1764 |
+
img_resize = rearrange(img_resize, 'B H W C -> B C H W')
|
| 1765 |
+
|
| 1766 |
+
img_resize = torch.nn.functional.interpolate(input=img_resize,
|
| 1767 |
+
size=(1024, 1024),
|
| 1768 |
+
mode='bicubic',
|
| 1769 |
+
antialias=True)
|
| 1770 |
+
|
| 1771 |
+
img_resize = img_resize.squeeze(0)
|
| 1772 |
+
img_var = img_transform(img_resize)
|
| 1773 |
+
img_var = img_var.unsqueeze(0)
|
| 1774 |
+
img_var = Variable(img_var)
|
| 1775 |
+
img_var = img_var.to(dtype=torch_dtype, device=torch_device)
|
| 1776 |
+
|
| 1777 |
+
mask = []
|
| 1778 |
+
|
| 1779 |
+
for transformer in transforms_var:
|
| 1780 |
+
rgb_trans = transformer.augment_image(img_var)
|
| 1781 |
+
rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
|
| 1782 |
+
model_output = net(rgb_trans)
|
| 1783 |
+
deaug_mask = transformer.deaugment_mask(model_output)
|
| 1784 |
+
mask.append(deaug_mask)
|
| 1785 |
+
|
| 1786 |
+
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
|
| 1787 |
+
prediction = prediction.sigmoid()
|
| 1788 |
+
prediction = to_pil(prediction.data.squeeze(0).cpu())
|
| 1789 |
+
prediction = prediction.resize((w_, h_), Image.BILINEAR)
|
| 1790 |
+
prediction.save(output_mask_path)
|
| 1791 |
+
|
| 1792 |
+
|
| 1793 |
+
def do_infer_modular_cv_5(input_image_path, output_mask_path, net,
|
| 1794 |
+
all_transforms):
|
| 1795 |
+
|
| 1796 |
+
(img_transform, depth_transform, target_transform, to_pil,
|
| 1797 |
+
transforms_var) = all_transforms
|
| 1798 |
+
|
| 1799 |
+
img = load_image(input_image_path)
|
| 1800 |
+
w_, h_ = img.shape[0], img.shape[1]
|
| 1801 |
+
|
| 1802 |
+
with torch.no_grad():
|
| 1803 |
+
|
| 1804 |
+
img_resize = torch.from_numpy(img)
|
| 1805 |
+
img_resize = img_resize.to(dtype=torch.float32)
|
| 1806 |
+
img_resize /= 255.0
|
| 1807 |
+
img_resize = img_resize.unsqueeze(0)
|
| 1808 |
+
|
| 1809 |
+
img_resize = rearrange(img_resize, 'B H W C -> B C H W')
|
| 1810 |
+
|
| 1811 |
+
img_resize = torch.nn.functional.interpolate(input=img_resize,
|
| 1812 |
+
size=(1024, 1024),
|
| 1813 |
+
mode='bicubic',
|
| 1814 |
+
antialias=True)
|
| 1815 |
+
|
| 1816 |
+
img_var = img_transform(img_resize)
|
| 1817 |
+
img_var = Variable(img_var)
|
| 1818 |
+
img_var = img_var.to(dtype=torch_dtype, device=torch_device)
|
| 1819 |
+
|
| 1820 |
+
mask = []
|
| 1821 |
+
|
| 1822 |
+
for transformer in transforms_var:
|
| 1823 |
+
rgb_trans = transformer.augment_image(img_var)
|
| 1824 |
+
rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
|
| 1825 |
+
model_output = net(rgb_trans)
|
| 1826 |
+
deaug_mask = transformer.deaugment_mask(model_output)
|
| 1827 |
+
mask.append(deaug_mask)
|
| 1828 |
+
|
| 1829 |
+
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
|
| 1830 |
+
prediction = prediction.sigmoid()
|
| 1831 |
+
prediction = to_pil(prediction.data.squeeze(0).cpu())
|
| 1832 |
+
prediction = prediction.resize((w_, h_), Image.BILINEAR)
|
| 1833 |
+
prediction.save(output_mask_path)
|
| 1834 |
+
|
| 1835 |
+
|
| 1836 |
+
def do_infer_modular_cv_6(input_image_path, output_mask_path, net,
|
| 1837 |
+
all_transforms):
|
| 1838 |
+
|
| 1839 |
+
(img_transform, depth_transform, target_transform, to_pil,
|
| 1840 |
+
transforms_var) = all_transforms
|
| 1841 |
+
|
| 1842 |
+
img = load_image(input_image_path)
|
| 1843 |
+
w_, h_ = img.shape[0], img.shape[1]
|
| 1844 |
+
|
| 1845 |
+
with torch.no_grad():
|
| 1846 |
+
|
| 1847 |
+
img_resize = torch.from_numpy(img)
|
| 1848 |
+
img_resize = img_resize.to(dtype=torch.float32)
|
| 1849 |
+
img_resize /= 255.0
|
| 1850 |
+
img_resize = img_resize.unsqueeze(0)
|
| 1851 |
+
|
| 1852 |
+
img_resize = rearrange(img_resize, 'B H W C -> B C H W')
|
| 1853 |
+
|
| 1854 |
+
img_resize = torch.nn.functional.interpolate(input=img_resize,
|
| 1855 |
+
size=(1024, 1024),
|
| 1856 |
+
mode='bicubic',
|
| 1857 |
+
antialias=True)
|
| 1858 |
+
|
| 1859 |
+
img_var = img_transform(img_resize)
|
| 1860 |
+
img_var = Variable(img_var)
|
| 1861 |
+
img_var = img_var.to(dtype=torch_dtype, device=torch_device)
|
| 1862 |
+
|
| 1863 |
+
mask = []
|
| 1864 |
+
|
| 1865 |
+
for transformer in transforms_var:
|
| 1866 |
+
rgb_trans = img_var.to(dtype=torch_dtype, device=torch_device)
|
| 1867 |
+
mask.append(net(rgb_trans))
|
| 1868 |
+
|
| 1869 |
+
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
|
| 1870 |
+
prediction = prediction.sigmoid()
|
| 1871 |
+
prediction = to_pil(prediction.data.squeeze(0).cpu())
|
| 1872 |
+
prediction = prediction.resize((w_, h_), Image.BILINEAR)
|
| 1873 |
+
prediction.save(output_mask_path)
|
| 1874 |
+
|
| 1875 |
+
|
| 1876 |
+
def do_infer_modular_cv_7(input_image_path, output_mask_path, net,
|
| 1877 |
+
all_transforms):
|
| 1878 |
+
|
| 1879 |
+
(img_transform, depth_transform, target_transform, to_pil,
|
| 1880 |
+
transforms_var) = all_transforms
|
| 1881 |
+
|
| 1882 |
+
img = load_image_torch(input_image_path)
|
| 1883 |
+
|
| 1884 |
+
h_, w_ = img.shape[1], img.shape[2]
|
| 1885 |
+
|
| 1886 |
+
with torch.no_grad():
|
| 1887 |
+
|
| 1888 |
+
img = rearrange(img, 'B H W C -> B C H W')
|
| 1889 |
+
|
| 1890 |
+
img_resize = torch.nn.functional.interpolate(input=img,
|
| 1891 |
+
size=(1024, 1024),
|
| 1892 |
+
mode='bicubic',
|
| 1893 |
+
antialias=True)
|
| 1894 |
+
|
| 1895 |
+
img_var = img_transform(img_resize)
|
| 1896 |
+
img_var = Variable(img_var)
|
| 1897 |
+
img_var = img_var.to(dtype=torch_dtype, device=torch_device)
|
| 1898 |
+
|
| 1899 |
+
mask = []
|
| 1900 |
+
|
| 1901 |
+
for transformer in transforms_var:
|
| 1902 |
+
rgb_trans = img_var.to(dtype=torch_dtype, device=torch_device)
|
| 1903 |
+
mask.append(net(rgb_trans))
|
| 1904 |
+
|
| 1905 |
+
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
|
| 1906 |
+
prediction = prediction.sigmoid()
|
| 1907 |
+
prediction = to_pil(prediction.data.squeeze(0).cpu())
|
| 1908 |
+
prediction = prediction.resize((w_, h_), Image.BILINEAR)
|
| 1909 |
+
prediction.save(output_mask_path)
|
| 1910 |
+
#+end_src
|
| 1911 |
+
|
| 1912 |
+
** Function for modular inference
|
| 1913 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
|
| 1914 |
+
def do_infer_modular(input_image_path, output_mask_path, net, all_transforms):
|
| 1915 |
+
# net = load_model(finetuned_MVANet_model_path)
|
| 1916 |
+
|
| 1917 |
+
(img_transform, depth_transform, target_transform, to_pil,
|
| 1918 |
+
transforms_var) = all_transforms
|
| 1919 |
+
|
| 1920 |
+
with torch.no_grad():
|
| 1921 |
+
rgb_png_path = input_image_path
|
| 1922 |
+
img = Image.open(rgb_png_path).convert('RGB')
|
| 1923 |
+
|
| 1924 |
+
w_, h_ = img.size
|
| 1925 |
+
# img_resize = img.resize([(w_ // 2) * 2, (h_ // 2) * 2], Image.BILINEAR)
|
| 1926 |
+
img_resize = img.resize([256 * 4, 256 * 4], Image.BILINEAR)
|
| 1927 |
+
# img_resize = img
|
| 1928 |
+
img_var = Variable(img_transform(img_resize).unsqueeze(0)).to(
|
| 1929 |
+
dtype=torch_dtype, device=torch_device)
|
| 1930 |
+
mask = []
|
| 1931 |
+
for transformer in transforms_var:
|
| 1932 |
+
rgb_trans = transformer.augment_image(img_var)
|
| 1933 |
+
rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
|
| 1934 |
+
model_output = net(rgb_trans)
|
| 1935 |
+
deaug_mask = transformer.deaugment_mask(model_output)
|
| 1936 |
+
mask.append(deaug_mask)
|
| 1937 |
+
|
| 1938 |
+
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
|
| 1939 |
+
prediction = prediction.sigmoid()
|
| 1940 |
+
prediction = to_pil(prediction.data.squeeze(0).cpu())
|
| 1941 |
+
prediction = prediction.resize((w_, h_), Image.BILINEAR)
|
| 1942 |
+
prediction.save(output_mask_path)
|
| 1943 |
+
#+end_src
|
| 1944 |
+
|
| 1945 |
+
** Function for inference
|
| 1946 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
|
| 1947 |
+
def do_infer():
|
| 1948 |
+
torch.cuda.set_device(0)
|
| 1949 |
+
args = {'crf_refine': True, 'save_results': True}
|
| 1950 |
+
|
| 1951 |
+
img_transform = transforms.Compose([
|
| 1952 |
+
transforms.ToTensor(),
|
| 1953 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 1954 |
+
])
|
| 1955 |
+
|
| 1956 |
+
depth_transform = transforms.ToTensor()
|
| 1957 |
+
target_transform = transforms.ToTensor()
|
| 1958 |
+
to_pil = transforms.ToPILImage()
|
| 1959 |
+
|
| 1960 |
+
transforms_var = tta.Compose([
|
| 1961 |
+
tta.HorizontalFlip(),
|
| 1962 |
+
tta.Scale(scales=[0.75, 1, 1.25],
|
| 1963 |
+
interpolation='bilinear',
|
| 1964 |
+
align_corners=False),
|
| 1965 |
+
])
|
| 1966 |
+
|
| 1967 |
+
net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
|
| 1968 |
+
pretrained_dict = torch.load(finetuned_MVANet_model_path,
|
| 1969 |
+
map_location=torch_device)
|
| 1970 |
+
model_dict = net.state_dict()
|
| 1971 |
+
pretrained_dict = {
|
| 1972 |
+
k: v
|
| 1973 |
+
for k, v in pretrained_dict.items() if k in model_dict
|
| 1974 |
+
}
|
| 1975 |
+
model_dict.update(pretrained_dict)
|
| 1976 |
+
net.load_state_dict(model_dict)
|
| 1977 |
+
net = net.to(dtype=torch_dtype, device=torch_device)
|
| 1978 |
+
net.eval()
|
| 1979 |
+
with torch.no_grad():
|
| 1980 |
+
rgb_png_path = '/home/asd/DATASETS/SD_BG_SWAP_TEST/comfyui_outputs/4/output_fooocus/bgswap-output.png'
|
| 1981 |
+
img = Image.open(rgb_png_path).convert('RGB')
|
| 1982 |
+
w_, h_ = img.size
|
| 1983 |
+
# img_resize = img.resize([(w_ // 2) * 2, (h_ // 2) * 2], Image.BILINEAR)
|
| 1984 |
+
img_resize = img.resize([256 * 4 , 256 * 4 ], Image.BILINEAR)
|
| 1985 |
+
# img_resize = img
|
| 1986 |
+
img_var = Variable(img_transform(img_resize).unsqueeze(0),
|
| 1987 |
+
volatile=True).cuda()
|
| 1988 |
+
mask = []
|
| 1989 |
+
for transformer in transforms_var:
|
| 1990 |
+
rgb_trans = transformer.augment_image(img_var)
|
| 1991 |
+
rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
|
| 1992 |
+
model_output = net(rgb_trans)
|
| 1993 |
+
deaug_mask = transformer.deaugment_mask(model_output)
|
| 1994 |
+
mask.append(deaug_mask)
|
| 1995 |
+
|
| 1996 |
+
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
|
| 1997 |
+
prediction = prediction.sigmoid()
|
| 1998 |
+
prediction = to_pil(prediction.data.squeeze(0).cpu())
|
| 1999 |
+
prediction = prediction.resize((w_, h_), Image.BILINEAR)
|
| 2000 |
+
prediction.save('./tmp.png')
|
| 2001 |
+
#+end_src
|
| 2002 |
+
|
| 2003 |
+
** MVANet_inference function
|
| 2004 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
|
| 2005 |
+
def main(item):
|
| 2006 |
+
net = inf_MVANet().cuda()
|
| 2007 |
+
pretrained_dict = torch.load(os.path.join(ckpt_path, item + '.pth'),
|
| 2008 |
+
map_location='cuda')
|
| 2009 |
+
model_dict = net.state_dict()
|
| 2010 |
+
pretrained_dict = {
|
| 2011 |
+
k: v
|
| 2012 |
+
for k, v in pretrained_dict.items() if k in model_dict
|
| 2013 |
+
}
|
| 2014 |
+
model_dict.update(pretrained_dict)
|
| 2015 |
+
net.load_state_dict(model_dict)
|
| 2016 |
+
net.eval()
|
| 2017 |
+
with torch.no_grad():
|
| 2018 |
+
for name, root in to_test.items():
|
| 2019 |
+
root1 = os.path.join(root, 'images')
|
| 2020 |
+
img_list = [os.path.splitext(f) for f in os.listdir(root1)]
|
| 2021 |
+
for idx, img_name in enumerate(img_list):
|
| 2022 |
+
|
| 2023 |
+
print('predicting for %s: %d / %d' %
|
| 2024 |
+
(name, idx + 1, len(img_list)))
|
| 2025 |
+
rgb_png_path = os.path.join(root, 'images',
|
| 2026 |
+
img_name[0] + '.png')
|
| 2027 |
+
rgb_jpg_path = os.path.join(root, 'images',
|
| 2028 |
+
img_name[0] + '.jpg')
|
| 2029 |
+
if os.path.exists(rgb_png_path):
|
| 2030 |
+
img = Image.open(rgb_png_path).convert('RGB')
|
| 2031 |
+
else:
|
| 2032 |
+
img = Image.open(rgb_jpg_path).convert('RGB')
|
| 2033 |
+
w_, h_ = img.size
|
| 2034 |
+
img_resize = img.resize([1024, 1024], Image.BILINEAR)
|
| 2035 |
+
img_var = Variable(img_transform(img_resize).unsqueeze(0),
|
| 2036 |
+
volatile=True).cuda()
|
| 2037 |
+
mask = []
|
| 2038 |
+
for transformer in transforms_var:
|
| 2039 |
+
rgb_trans = transformer.augment_image(img_var)
|
| 2040 |
+
model_output = net(rgb_trans)
|
| 2041 |
+
deaug_mask = transformer.deaugment_mask(model_output)
|
| 2042 |
+
mask.append(deaug_mask)
|
| 2043 |
+
|
| 2044 |
+
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
|
| 2045 |
+
prediction = prediction.sigmoid()
|
| 2046 |
+
prediction = to_pil(prediction.data.squeeze(0))
|
| 2047 |
+
prediction = prediction.resize((w_, h_), Image.BILINEAR)
|
| 2048 |
+
if args['save_results']:
|
| 2049 |
+
check_mkdir(os.path.join(ckpt_path, item, name))
|
| 2050 |
+
prediction.save(
|
| 2051 |
+
os.path.join(ckpt_path, item, name,
|
| 2052 |
+
img_name[0] + '.png'))
|
| 2053 |
+
#+end_src
|
| 2054 |
+
|
| 2055 |
+
** MVANet_inference execute
|
| 2056 |
+
#+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
|
| 2057 |
+
def do_merge(path_image, path_mask, path_out):
|
| 2058 |
+
image = cv2.imread(path_image, cv2.IMREAD_COLOR)
|
| 2059 |
+
mask = cv2.imread(path_mask, cv2.IMREAD_GRAYSCALE)
|
| 2060 |
+
mask = (mask > 127).astype(dtype=np.uint8) * 255
|
| 2061 |
+
out = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
|
| 2062 |
+
out[:, :, 0:3] = image
|
| 2063 |
+
out[:, :, 3] = mask
|
| 2064 |
+
cv2.imwrite(path_out, out)
|
| 2065 |
+
|
| 2066 |
+
|
| 2067 |
+
if __name__ == '__main__':
|
| 2068 |
+
|
| 2069 |
+
# do_infer_modular_cv(
|
| 2070 |
+
# input_image_path=
|
| 2071 |
+
# '/home/asd/DATASETS/SD_BG_SWAP_TEST/comfyui_outputs/4/output_fooocus/bgswap-output.png',
|
| 2072 |
+
# output_mask_path='./tmp.png',
|
| 2073 |
+
# net=load_model(finetuned_MVANet_model_path),
|
| 2074 |
+
# all_transforms=load_transforms(),
|
| 2075 |
+
# )
|
| 2076 |
+
|
| 2077 |
+
# net = load_model(
|
| 2078 |
+
# HOME_DIR + '/dreambooth_experiments/MVANet/MVANet_cloth_segment_14.pth')
|
| 2079 |
+
|
| 2080 |
+
# net = load_model(
|
| 2081 |
+
# HOME_DIR +
|
| 2082 |
+
# '/dreambooth_experiments/MVANet/new_type_crop_with_midshot.pth')
|
| 2083 |
+
|
| 2084 |
+
# net = load_model('/home/asd/MODEL_CHECKPOINTS/MVANet/SKIN_SEGMENTATION/1/Model_4.pth')
|
| 2085 |
+
|
| 2086 |
+
net = load_model('/home/asd/MODEL_CHECKPOINTS/MVANet/SKIN_SEGMENTATION/3/Model_14.pth')
|
| 2087 |
+
|
| 2088 |
+
|
| 2089 |
+
# net = load_model(HOME_DIR +
|
| 2090 |
+
# '/dreambooth_experiments/MVANet/mvanet_normal_crop_2.pth')
|
| 2091 |
+
|
| 2092 |
+
DATA_DIR_BASE = HOME_DIR + '/DATASETS/cloth_segmentation_test_images.dir/cloth_segmentation_test_images/'
|
| 2093 |
+
|
| 2094 |
+
images = (
|
| 2095 |
+
'1370', '1371', '1372', '1373', '1374', '1375', '1376', '1377', '1378',
|
| 2096 |
+
'1379', '1380', '1381', '1382', '1383', '1384', '1385', '1386', '1387',
|
| 2097 |
+
'1388', '1389', '1390', '1391', '1392', '1393', '1394', '1395', '1396',
|
| 2098 |
+
'1397', '1398', '1399', '1400', '1401', '1402', '1403', '1404', '1405',
|
| 2099 |
+
'1406', '1407', '1408', '1409', '1410', '1411', '1412', '1413', '1414',
|
| 2100 |
+
'1415', '1539', '1541', '1542', '1543', '17320', '4129', '4190',
|
| 2101 |
+
'4191', '4192', '4193', '4202', '4203', '4204', '4207', '4208', '4209',
|
| 2102 |
+
'4210', '4213', '4214', '4221', '4222', '4223', '4224', '4225', '4226',
|
| 2103 |
+
'4227', '4228', '4229', '4230', '4231', '4232', '4233', '4234', '4235',
|
| 2104 |
+
'4236', '4237', '4238', '4239', '4240', '4241', '4242', '4251', '4252',
|
| 2105 |
+
'4253', '4254', '4255', '4256', '4257', '4258', '4259', '4260', '4261',
|
| 2106 |
+
'4262', '4263', '4264', '6581', '6642', '6647', '6656', '6660', '6690',
|
| 2107 |
+
'6696', '6724', '6767', '6771', '6788', '6791', '6807', '6821', '6824',
|
| 2108 |
+
'6833', '6847', '6850', '6879', '6941', '7001', '7070', '7083', '7092',
|
| 2109 |
+
'7093', '7119', '7191', '7220', '7252', '7264', '7276', '7278', '7281',
|
| 2110 |
+
'7290', '7301', '7312', '7340', '7398', '7404', '7412', '7429', '7439',
|
| 2111 |
+
'7478', '7491', '7631', '7687', '7699', '7719', '7770', '7784', '7793',
|
| 2112 |
+
'7811', '7829', '7861', '7864', '7868', '7980', '7987', '7990', '8069',
|
| 2113 |
+
'8083', '8100', '8108', '8227', '8323', '8329', '8358', '8383', '8401',
|
| 2114 |
+
'8415', '8488', '8515', '8518', '8560', '8565', '8595', '8639', '8676',
|
| 2115 |
+
'8690', '8691', '8701', '8703', '8723', '8726', '8756', '8783', '8801',
|
| 2116 |
+
'8820', '8826', '8842', '8865', '8874', '8875', '8882', '8911', '8946',
|
| 2117 |
+
'8947', '8969', '8979', '8983')
|
| 2118 |
+
|
| 2119 |
+
masks = [DATA_DIR_BASE + i + '/garment_mask.png' for i in images]
|
| 2120 |
+
out = [DATA_DIR_BASE + i + '/garment_transparent.png' for i in images]
|
| 2121 |
+
|
| 2122 |
+
images = [DATA_DIR_BASE + i + '/original.jpg' for i in images]
|
| 2123 |
+
|
| 2124 |
+
for i in range(len(images)):
|
| 2125 |
+
image = images[i]
|
| 2126 |
+
image = load_image_torch(image)
|
| 2127 |
+
mask = do_infer_tensor2tensor(image, net)
|
| 2128 |
+
save_mask_torch(output_image_path=masks[i], mask=mask)
|
| 2129 |
+
do_merge(path_image=images[i], path_mask=masks[i], path_out=out[i])
|
| 2130 |
+
|
| 2131 |
+
# img = load_image_torch(
|
| 2132 |
+
# '/home/asd/DATASETS/SD_BG_SWAP_TEST/comfyui_outputs/4/output_fooocus/bgswap-output.png'
|
| 2133 |
+
# )
|
| 2134 |
+
# # all_transforms = load_transforms()
|
| 2135 |
+
# masks = do_infer_tensor2tensor(img, net)
|
| 2136 |
+
# save_mask_torch(output_image_path='./tmp.png', mask=masks)
|
| 2137 |
+
#+end_src
|
| 2138 |
+
|
| 2139 |
+
** MVANet_inference unify
|
| 2140 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.unify.sh
|
| 2141 |
+
. "${HOME}/dbnew.sh"
|
| 2142 |
+
|
| 2143 |
+
(
|
| 2144 |
+
echo '#!/usr/bin/python3'
|
| 2145 |
+
cat \
|
| 2146 |
+
'./MVANet_inference.import.py' \
|
| 2147 |
+
'./MVANet_inference.function.py' \
|
| 2148 |
+
'./MVANet_inference.class.py' \
|
| 2149 |
+
'./MVANet_inference.execute.py' \
|
| 2150 |
+
| expand | yapf3 \
|
| 2151 |
+
| grep -v '#!/usr/bin/python3' \
|
| 2152 |
+
;
|
| 2153 |
+
) > './MVANet_inference.py' \
|
| 2154 |
+
;
|
| 2155 |
+
#+end_src
|
| 2156 |
+
|
| 2157 |
+
** MVANet_inference run
|
| 2158 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.run.sh
|
| 2159 |
+
. "${HOME}/dbnew.sh"
|
| 2160 |
+
python3 './MVANet_inference.py'
|
| 2161 |
+
#+end_src
|
| 2162 |
+
|
| 2163 |
+
* WORK SPACE
|
| 2164 |
+
|
| 2165 |
+
** elisp
|
| 2166 |
+
#+begin_src elisp
|
| 2167 |
+
(save-buffer)
|
| 2168 |
+
(org-babel-tangle)
|
| 2169 |
+
(shell-command "./MVANet_inference.unify.sh")
|
| 2170 |
+
#+end_src
|
| 2171 |
+
|
| 2172 |
+
#+RESULTS:
|
| 2173 |
+
: 0
|
| 2174 |
+
|
| 2175 |
+
** sh
|
| 2176 |
+
#+begin_src sh :shebang #!/bin/sh :results output
|
| 2177 |
+
realpath .
|
| 2178 |
+
cd /home/asd/GITHUB/aravind-h-v/dreambooth_experiments/MVANet
|
| 2179 |
+
#+end_src
|
README.md
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Self Correction for Human Parsing
|
| 2 |
+
This is a copy of https://github.com/GoGoDuck912/Self-Correction-Human-Parsing
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
[](https://opensource.org/licenses/MIT)
|
| 7 |
+
|
| 8 |
+
An out-of-box human parsing representation extractor.
|
| 9 |
+
|
| 10 |
+
Our solution ranks 1st for all human parsing tracks (including single, multiple and video) in the third LIP challenge!
|
| 11 |
+
|
| 12 |
+

|
| 13 |
+
|
| 14 |
+
Features:
|
| 15 |
+
- [x] Out-of-box human parsing extractor for other downstream applications.
|
| 16 |
+
- [x] Pretrained model on three popular single person human parsing datasets.
|
| 17 |
+
- [x] Training and inferecne code.
|
| 18 |
+
- [x] Simple yet effective extension on multi-person and video human parsing tasks.
|
| 19 |
+
|
| 20 |
+
## Requirements
|
| 21 |
+
|
| 22 |
+
```
|
| 23 |
+
conda env create -f environment.yaml
|
| 24 |
+
conda activate schp
|
| 25 |
+
pip install -r requirements.txt
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## Simple Out-of-Box Extractor
|
| 29 |
+
|
| 30 |
+
The easiest way to get started is to use our trained SCHP models on your own images to extract human parsing representations. Here we provided state-of-the-art [trained models](https://drive.google.com/drive/folders/1uOaQCpNtosIjEL2phQKEdiYd0Td18jNo?usp=sharing) on three popular datasets. Theses three datasets have different label system, you can choose the best one to fit on your own task.
|
| 31 |
+
|
| 32 |
+
**LIP** ([exp-schp-201908261155-lip.pth](https://drive.google.com/file/d/1k4dllHpu0bdx38J7H28rVVLpU-kOHmnH/view?usp=sharing))
|
| 33 |
+
|
| 34 |
+
* mIoU on LIP validation: **59.36 %**.
|
| 35 |
+
|
| 36 |
+
* LIP is the largest single person human parsing dataset with 50000+ images. This dataset focus more on the complicated real scenarios. LIP has 20 labels, including 'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat', 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm', 'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe'.
|
| 37 |
+
|
| 38 |
+
**ATR** ([exp-schp-201908301523-atr.pth](https://drive.google.com/file/d/1ruJg4lqR_jgQPj-9K0PP-L2vJERYOxLP/view?usp=sharing))
|
| 39 |
+
|
| 40 |
+
* mIoU on ATR test: **82.29%**.
|
| 41 |
+
|
| 42 |
+
* ATR is a large single person human parsing dataset with 17000+ images. This dataset focus more on fashion AI. ATR has 18 labels, including 'Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt', 'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf'.
|
| 43 |
+
|
| 44 |
+
**Pascal-Person-Part** ([exp-schp-201908270938-pascal-person-part.pth](https://drive.google.com/file/d/1E5YwNKW2VOEayK9mWCS3Kpsxf-3z04ZE/view?usp=sharing))
|
| 45 |
+
|
| 46 |
+
* mIoU on Pascal-Person-Part validation: **71.46** %.
|
| 47 |
+
|
| 48 |
+
* Pascal Person Part is a tiny single person human parsing dataset with 3000+ images. This dataset focus more on body parts segmentation. Pascal Person Part has 7 labels, including 'Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms', 'Upper Legs', 'Lower Legs'.
|
| 49 |
+
|
| 50 |
+
Choose one and have fun on your own task!
|
| 51 |
+
|
| 52 |
+
To extract the human parsing representation, simply put your own image in the `INPUT_PATH` folder, then download a pretrained model and run the following command. The output images with the same file name will be saved in `OUTPUT_PATH`
|
| 53 |
+
|
| 54 |
+
```
|
| 55 |
+
python simple_extractor.py --dataset [DATASET] --model-restore [CHECKPOINT_PATH] --input-dir [INPUT_PATH] --output-dir [OUTPUT_PATH]
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
**[Updated]** Here is also a [colab demo example](https://colab.research.google.com/drive/1JOwOPaChoc9GzyBi5FUEYTSaP2qxJl10?usp=sharing) for quick inference provided by [@levindabhi](https://github.com/levindabhi).
|
| 59 |
+
|
| 60 |
+
The `DATASET` command has three options, including 'lip', 'atr' and 'pascal'. Note each pixel in the output images denotes the predicted label number. The output images have the same size as the input ones. To better visualization, we put a palette with the output images. We suggest you to read the image with `PIL`.
|
| 61 |
+
|
| 62 |
+
If you need not only the final parsing images, but also the feature map representations. Add `--logits` command to save the output feature maps. These feature maps are the logits before softmax layer.
|
| 63 |
+
|
| 64 |
+
## Dataset Preparation
|
| 65 |
+
|
| 66 |
+
Please download the [LIP](http://sysu-hcp.net/lip/) dataset following the below structure.
|
| 67 |
+
|
| 68 |
+
```commandline
|
| 69 |
+
data/LIP
|
| 70 |
+
|--- train_imgaes # 30462 training single person images
|
| 71 |
+
|--- val_images # 10000 validation single person images
|
| 72 |
+
|--- train_segmentations # 30462 training annotations
|
| 73 |
+
|--- val_segmentations # 10000 training annotations
|
| 74 |
+
|--- train_id.txt # training image list
|
| 75 |
+
|--- val_id.txt # validation image list
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
## Training
|
| 79 |
+
|
| 80 |
+
```
|
| 81 |
+
python train.py
|
| 82 |
+
```
|
| 83 |
+
By default, the trained model will be saved in `./log` directory. Please read the arguments for more details.
|
| 84 |
+
|
| 85 |
+
## Evaluation
|
| 86 |
+
```
|
| 87 |
+
python evaluate.py --model-restore [CHECKPOINT_PATH]
|
| 88 |
+
```
|
| 89 |
+
CHECKPOINT_PATH should be the path of trained model.
|
| 90 |
+
|
| 91 |
+
## Extension on Multiple Human Parsing
|
| 92 |
+
|
| 93 |
+
Please read [MultipleHumanParsing.md](./mhp_extension/README.md) for more details.
|
| 94 |
+
|
| 95 |
+
## Citation
|
| 96 |
+
|
| 97 |
+
Please cite our work if you find this repo useful in your research.
|
| 98 |
+
|
| 99 |
+
```latex
|
| 100 |
+
@article{li2020self,
|
| 101 |
+
title={Self-Correction for Human Parsing},
|
| 102 |
+
author={Li, Peike and Xu, Yunqiu and Wei, Yunchao and Yang, Yi},
|
| 103 |
+
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
|
| 104 |
+
year={2020},
|
| 105 |
+
doi={10.1109/TPAMI.2020.3048039}}
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
## Visualization
|
| 109 |
+
|
| 110 |
+
* Source Image.
|
| 111 |
+

|
| 112 |
+
* LIP Parsing Result.
|
| 113 |
+

|
| 114 |
+
* ATR Parsing Result.
|
| 115 |
+

|
| 116 |
+
* Pascal-Person-Part Parsing Result.
|
| 117 |
+

|
| 118 |
+
* Source Image.
|
| 119 |
+

|
| 120 |
+
* Instance Human Mask.
|
| 121 |
+

|
| 122 |
+
* Global Human Parsing Result.
|
| 123 |
+

|
| 124 |
+
* Multiple Human Parsing Result.
|
| 125 |
+

|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
## Related
|
| 129 |
+
Our code adopts the [InplaceSyncBN](https://github.com/mapillary/inplace_abn) to save gpu memory cost.
|
| 130 |
+
|
| 131 |
+
There is also a [PaddlePaddle](https://github.com/PaddlePaddle/PaddleSeg/tree/develop/contrib/ACE2P) Implementation of this project.
|
checkpoints/AEMatter/AEM_RWA.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a475193549365ff3c892a85d2f4ca90ece2ac8dc4de4a39df250c76ca870d280
|
| 3 |
+
size 205399637
|
checkpoints/MVANet/garment.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7604ed46e06fbcff3b8f38c8934d253617171d02aecdd028f0f01086d9344893
|
| 3 |
+
size 380785263
|
checkpoints/MVANet/skin.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c71afcdd9cb1be73e43d84f5ffc2ae12b4964cc13c8460fc0adb6d52a0603cd4
|
| 3 |
+
size 380782803
|
checkpoints/Model_80.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ffec20a382b0a1832786438475e8b912a03be727a0e3197e7ab039153fb3bc46
|
| 3 |
+
size 386621643
|
checkpoints/StableDiffusion/90c7c97574f8db765509b6a5d2e7b2551b430a10cac03e37d368654eac5e8169cd149644d188be4b5b2f1b9f29e66b64a02535f622f2bf284c319b076224cb2b
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:010be7341cd98a136da775330ba3eb4e87025c6cfd2f5455dc64daee2200ae98
|
| 3 |
+
size 7105348616
|
checkpoints/StableDiffusion/b970812225cfb95427c13e73b75eef66430e2a525876dddac494d70fe4ed0524cb197043e0ac3dc3026b32a45cd1d6d126ec2fe74a5bc3ef5df21836ca022b30
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b1689257e6e1b2e61544b1a41fc114e7d798f68854b3f875cd52070bfe1fbc00
|
| 3 |
+
size 6938072258
|
checkpoints/StableDiffusion/hash
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
90c7c97574f8db765509b6a5d2e7b2551b430a10cac03e37d368654eac5e8169cd149644d188be4b5b2f1b9f29e66b64a02535f622f2bf284c319b076224cb2b Juggernaut_X_RunDiffusion_Hyper.safetensors
|
| 2 |
+
b970812225cfb95427c13e73b75eef66430e2a525876dddac494d70fe4ed0524cb197043e0ac3dc3026b32a45cd1d6d126ec2fe74a5bc3ef5df21836ca022b30 juggernautXL_versionXInpaint.safetensors
|
checkpoints/atr.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e9d7c91ce3b4e7133df56b599fc817b533e3439c5e8d282a59126d2fda339a2a
|
| 3 |
+
size 267445237
|
checkpoints/lip.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24fa3254ceeb74c8435458994a64b522fb439a3635b7b86ff470457e0413da00
|
| 3 |
+
size 267449349
|
checkpoints/pascal.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3b03d343c39fb0696f75d45c44c67b2fc23f5d0bf0925a82c0465e415799fa85
|
| 3 |
+
size 267422621
|
datasets/__init__.py
ADDED
|
File without changes
|
datasets/datasets.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
@Author : Peike Li
|
| 6 |
+
@Contact : peike.li@yahoo.com
|
| 7 |
+
@File : datasets.py
|
| 8 |
+
@Time : 8/4/19 3:35 PM
|
| 9 |
+
@Desc :
|
| 10 |
+
@License : This source code is licensed under the license found in the
|
| 11 |
+
LICENSE file in the root directory of this source tree.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import numpy as np
|
| 16 |
+
import random
|
| 17 |
+
import torch
|
| 18 |
+
import cv2
|
| 19 |
+
from torch.utils import data
|
| 20 |
+
from utils.transforms import get_affine_transform
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class LIPDataSet(data.Dataset):
|
| 24 |
+
def __init__(self, root, dataset, crop_size=[473, 473], scale_factor=0.25,
|
| 25 |
+
rotation_factor=30, ignore_label=255, transform=None):
|
| 26 |
+
self.root = root
|
| 27 |
+
self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
|
| 28 |
+
self.crop_size = np.asarray(crop_size)
|
| 29 |
+
self.ignore_label = ignore_label
|
| 30 |
+
self.scale_factor = scale_factor
|
| 31 |
+
self.rotation_factor = rotation_factor
|
| 32 |
+
self.flip_prob = 0.5
|
| 33 |
+
self.transform = transform
|
| 34 |
+
self.dataset = dataset
|
| 35 |
+
|
| 36 |
+
list_path = os.path.join(self.root, self.dataset + '_id.txt')
|
| 37 |
+
train_list = [i_id.strip() for i_id in open(list_path)]
|
| 38 |
+
|
| 39 |
+
self.train_list = train_list
|
| 40 |
+
self.number_samples = len(self.train_list)
|
| 41 |
+
|
| 42 |
+
def __len__(self):
|
| 43 |
+
return self.number_samples
|
| 44 |
+
|
| 45 |
+
def _box2cs(self, box):
|
| 46 |
+
x, y, w, h = box[:4]
|
| 47 |
+
return self._xywh2cs(x, y, w, h)
|
| 48 |
+
|
| 49 |
+
def _xywh2cs(self, x, y, w, h):
|
| 50 |
+
center = np.zeros((2), dtype=np.float32)
|
| 51 |
+
center[0] = x + w * 0.5
|
| 52 |
+
center[1] = y + h * 0.5
|
| 53 |
+
if w > self.aspect_ratio * h:
|
| 54 |
+
h = w * 1.0 / self.aspect_ratio
|
| 55 |
+
elif w < self.aspect_ratio * h:
|
| 56 |
+
w = h * self.aspect_ratio
|
| 57 |
+
scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
|
| 58 |
+
return center, scale
|
| 59 |
+
|
| 60 |
+
def __getitem__(self, index):
|
| 61 |
+
train_item = self.train_list[index]
|
| 62 |
+
|
| 63 |
+
im_path = os.path.join(self.root, self.dataset + '_images', train_item + '.jpg')
|
| 64 |
+
parsing_anno_path = os.path.join(self.root, self.dataset + '_segmentations', train_item + '.png')
|
| 65 |
+
|
| 66 |
+
im = cv2.imread(im_path, cv2.IMREAD_COLOR)
|
| 67 |
+
h, w, _ = im.shape
|
| 68 |
+
parsing_anno = np.zeros((h, w), dtype=np.long)
|
| 69 |
+
|
| 70 |
+
# Get person center and scale
|
| 71 |
+
person_center, s = self._box2cs([0, 0, w - 1, h - 1])
|
| 72 |
+
r = 0
|
| 73 |
+
|
| 74 |
+
if self.dataset != 'test':
|
| 75 |
+
# Get pose annotation
|
| 76 |
+
parsing_anno = cv2.imread(parsing_anno_path, cv2.IMREAD_GRAYSCALE)
|
| 77 |
+
if self.dataset == 'train' or self.dataset == 'trainval':
|
| 78 |
+
sf = self.scale_factor
|
| 79 |
+
rf = self.rotation_factor
|
| 80 |
+
s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
|
| 81 |
+
r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) if random.random() <= 0.6 else 0
|
| 82 |
+
|
| 83 |
+
if random.random() <= self.flip_prob:
|
| 84 |
+
im = im[:, ::-1, :]
|
| 85 |
+
parsing_anno = parsing_anno[:, ::-1]
|
| 86 |
+
person_center[0] = im.shape[1] - person_center[0] - 1
|
| 87 |
+
right_idx = [15, 17, 19]
|
| 88 |
+
left_idx = [14, 16, 18]
|
| 89 |
+
for i in range(0, 3):
|
| 90 |
+
right_pos = np.where(parsing_anno == right_idx[i])
|
| 91 |
+
left_pos = np.where(parsing_anno == left_idx[i])
|
| 92 |
+
parsing_anno[right_pos[0], right_pos[1]] = left_idx[i]
|
| 93 |
+
parsing_anno[left_pos[0], left_pos[1]] = right_idx[i]
|
| 94 |
+
|
| 95 |
+
trans = get_affine_transform(person_center, s, r, self.crop_size)
|
| 96 |
+
input = cv2.warpAffine(
|
| 97 |
+
im,
|
| 98 |
+
trans,
|
| 99 |
+
(int(self.crop_size[1]), int(self.crop_size[0])),
|
| 100 |
+
flags=cv2.INTER_LINEAR,
|
| 101 |
+
borderMode=cv2.BORDER_CONSTANT,
|
| 102 |
+
borderValue=(0, 0, 0))
|
| 103 |
+
|
| 104 |
+
if self.transform:
|
| 105 |
+
input = self.transform(input)
|
| 106 |
+
|
| 107 |
+
meta = {
|
| 108 |
+
'name': train_item,
|
| 109 |
+
'center': person_center,
|
| 110 |
+
'height': h,
|
| 111 |
+
'width': w,
|
| 112 |
+
'scale': s,
|
| 113 |
+
'rotation': r
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
if self.dataset == 'val' or self.dataset == 'test':
|
| 117 |
+
return input, meta
|
| 118 |
+
else:
|
| 119 |
+
label_parsing = cv2.warpAffine(
|
| 120 |
+
parsing_anno,
|
| 121 |
+
trans,
|
| 122 |
+
(int(self.crop_size[1]), int(self.crop_size[0])),
|
| 123 |
+
flags=cv2.INTER_NEAREST,
|
| 124 |
+
borderMode=cv2.BORDER_CONSTANT,
|
| 125 |
+
borderValue=(255))
|
| 126 |
+
|
| 127 |
+
label_parsing = torch.from_numpy(label_parsing)
|
| 128 |
+
|
| 129 |
+
return input, label_parsing, meta
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class LIPDataValSet(data.Dataset):
|
| 133 |
+
def __init__(self, root, dataset='val', crop_size=[473, 473], transform=None, flip=False):
|
| 134 |
+
self.root = root
|
| 135 |
+
self.crop_size = crop_size
|
| 136 |
+
self.transform = transform
|
| 137 |
+
self.flip = flip
|
| 138 |
+
self.dataset = dataset
|
| 139 |
+
self.root = root
|
| 140 |
+
self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
|
| 141 |
+
self.crop_size = np.asarray(crop_size)
|
| 142 |
+
|
| 143 |
+
list_path = os.path.join(self.root, self.dataset + '_id.txt')
|
| 144 |
+
val_list = [i_id.strip() for i_id in open(list_path)]
|
| 145 |
+
|
| 146 |
+
self.val_list = val_list
|
| 147 |
+
self.number_samples = len(self.val_list)
|
| 148 |
+
|
| 149 |
+
def __len__(self):
|
| 150 |
+
return len(self.val_list)
|
| 151 |
+
|
| 152 |
+
def _box2cs(self, box):
|
| 153 |
+
x, y, w, h = box[:4]
|
| 154 |
+
return self._xywh2cs(x, y, w, h)
|
| 155 |
+
|
| 156 |
+
def _xywh2cs(self, x, y, w, h):
|
| 157 |
+
center = np.zeros((2), dtype=np.float32)
|
| 158 |
+
center[0] = x + w * 0.5
|
| 159 |
+
center[1] = y + h * 0.5
|
| 160 |
+
if w > self.aspect_ratio * h:
|
| 161 |
+
h = w * 1.0 / self.aspect_ratio
|
| 162 |
+
elif w < self.aspect_ratio * h:
|
| 163 |
+
w = h * self.aspect_ratio
|
| 164 |
+
scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
|
| 165 |
+
|
| 166 |
+
return center, scale
|
| 167 |
+
|
| 168 |
+
def __getitem__(self, index):
|
| 169 |
+
val_item = self.val_list[index]
|
| 170 |
+
# Load training image
|
| 171 |
+
im_path = os.path.join(self.root, self.dataset + '_images', val_item + '.jpg')
|
| 172 |
+
im = cv2.imread(im_path, cv2.IMREAD_COLOR)
|
| 173 |
+
h, w, _ = im.shape
|
| 174 |
+
# Get person center and scale
|
| 175 |
+
person_center, s = self._box2cs([0, 0, w - 1, h - 1])
|
| 176 |
+
r = 0
|
| 177 |
+
trans = get_affine_transform(person_center, s, r, self.crop_size)
|
| 178 |
+
input = cv2.warpAffine(
|
| 179 |
+
im,
|
| 180 |
+
trans,
|
| 181 |
+
(int(self.crop_size[1]), int(self.crop_size[0])),
|
| 182 |
+
flags=cv2.INTER_LINEAR,
|
| 183 |
+
borderMode=cv2.BORDER_CONSTANT,
|
| 184 |
+
borderValue=(0, 0, 0))
|
| 185 |
+
input = self.transform(input)
|
| 186 |
+
flip_input = input.flip(dims=[-1])
|
| 187 |
+
if self.flip:
|
| 188 |
+
batch_input_im = torch.stack([input, flip_input])
|
| 189 |
+
else:
|
| 190 |
+
batch_input_im = input
|
| 191 |
+
|
| 192 |
+
meta = {
|
| 193 |
+
'name': val_item,
|
| 194 |
+
'center': person_center,
|
| 195 |
+
'height': h,
|
| 196 |
+
'width': w,
|
| 197 |
+
'scale': s,
|
| 198 |
+
'rotation': r
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
return batch_input_im, meta
|
datasets/simple_extractor_dataset.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
@Author : Peike Li
|
| 6 |
+
@Contact : peike.li@yahoo.com
|
| 7 |
+
@File : dataset.py
|
| 8 |
+
@Time : 8/30/19 9:12 PM
|
| 9 |
+
@Desc : Dataset Definition
|
| 10 |
+
@License : This source code is licensed under the license found in the
|
| 11 |
+
LICENSE file in the root directory of this source tree.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import cv2
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
from torch.utils import data
|
| 19 |
+
from utils.transforms import get_affine_transform
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SimpleFolderDataset(data.Dataset):
|
| 23 |
+
def __init__(self, root, input_size=[512, 512], transform=None):
|
| 24 |
+
self.root = root
|
| 25 |
+
self.input_size = input_size
|
| 26 |
+
self.transform = transform
|
| 27 |
+
self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
|
| 28 |
+
self.input_size = np.asarray(input_size)
|
| 29 |
+
|
| 30 |
+
self.file_list = os.listdir(self.root)
|
| 31 |
+
|
| 32 |
+
def __len__(self):
|
| 33 |
+
return len(self.file_list)
|
| 34 |
+
|
| 35 |
+
def _box2cs(self, box):
|
| 36 |
+
x, y, w, h = box[:4]
|
| 37 |
+
return self._xywh2cs(x, y, w, h)
|
| 38 |
+
|
| 39 |
+
def _xywh2cs(self, x, y, w, h):
|
| 40 |
+
center = np.zeros((2), dtype=np.float32)
|
| 41 |
+
center[0] = x + w * 0.5
|
| 42 |
+
center[1] = y + h * 0.5
|
| 43 |
+
if w > self.aspect_ratio * h:
|
| 44 |
+
h = w * 1.0 / self.aspect_ratio
|
| 45 |
+
elif w < self.aspect_ratio * h:
|
| 46 |
+
w = h * self.aspect_ratio
|
| 47 |
+
scale = np.array([w, h], dtype=np.float32)
|
| 48 |
+
return center, scale
|
| 49 |
+
|
| 50 |
+
def __getitem__(self, index):
|
| 51 |
+
img_name = self.file_list[index]
|
| 52 |
+
img_path = os.path.join(self.root, img_name)
|
| 53 |
+
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
|
| 54 |
+
h, w, _ = img.shape
|
| 55 |
+
|
| 56 |
+
# Get person center and scale
|
| 57 |
+
person_center, s = self._box2cs([0, 0, w - 1, h - 1])
|
| 58 |
+
r = 0
|
| 59 |
+
trans = get_affine_transform(person_center, s, r, self.input_size)
|
| 60 |
+
input = cv2.warpAffine(
|
| 61 |
+
img,
|
| 62 |
+
trans,
|
| 63 |
+
(int(self.input_size[1]), int(self.input_size[0])),
|
| 64 |
+
flags=cv2.INTER_LINEAR,
|
| 65 |
+
borderMode=cv2.BORDER_CONSTANT,
|
| 66 |
+
borderValue=(0, 0, 0))
|
| 67 |
+
|
| 68 |
+
input = self.transform(input)
|
| 69 |
+
meta = {
|
| 70 |
+
'name': img_name,
|
| 71 |
+
'center': person_center,
|
| 72 |
+
'height': h,
|
| 73 |
+
'width': w,
|
| 74 |
+
'scale': s,
|
| 75 |
+
'rotation': r
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
return input, meta
|
datasets/target_generation.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.nn import functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def generate_edge_tensor(label, edge_width=3):
|
| 6 |
+
label = label.type(torch.cuda.FloatTensor)
|
| 7 |
+
if len(label.shape) == 2:
|
| 8 |
+
label = label.unsqueeze(0)
|
| 9 |
+
n, h, w = label.shape
|
| 10 |
+
edge = torch.zeros(label.shape, dtype=torch.float).cuda()
|
| 11 |
+
# right
|
| 12 |
+
edge_right = edge[:, 1:h, :]
|
| 13 |
+
edge_right[(label[:, 1:h, :] != label[:, :h - 1, :]) & (label[:, 1:h, :] != 255)
|
| 14 |
+
& (label[:, :h - 1, :] != 255)] = 1
|
| 15 |
+
|
| 16 |
+
# up
|
| 17 |
+
edge_up = edge[:, :, :w - 1]
|
| 18 |
+
edge_up[(label[:, :, :w - 1] != label[:, :, 1:w])
|
| 19 |
+
& (label[:, :, :w - 1] != 255)
|
| 20 |
+
& (label[:, :, 1:w] != 255)] = 1
|
| 21 |
+
|
| 22 |
+
# upright
|
| 23 |
+
edge_upright = edge[:, :h - 1, :w - 1]
|
| 24 |
+
edge_upright[(label[:, :h - 1, :w - 1] != label[:, 1:h, 1:w])
|
| 25 |
+
& (label[:, :h - 1, :w - 1] != 255)
|
| 26 |
+
& (label[:, 1:h, 1:w] != 255)] = 1
|
| 27 |
+
|
| 28 |
+
# bottomright
|
| 29 |
+
edge_bottomright = edge[:, :h - 1, 1:w]
|
| 30 |
+
edge_bottomright[(label[:, :h - 1, 1:w] != label[:, 1:h, :w - 1])
|
| 31 |
+
& (label[:, :h - 1, 1:w] != 255)
|
| 32 |
+
& (label[:, 1:h, :w - 1] != 255)] = 1
|
| 33 |
+
|
| 34 |
+
kernel = torch.ones((1, 1, edge_width, edge_width), dtype=torch.float).cuda()
|
| 35 |
+
with torch.no_grad():
|
| 36 |
+
edge = edge.unsqueeze(1)
|
| 37 |
+
edge = F.conv2d(edge, kernel, stride=1, padding=1)
|
| 38 |
+
edge[edge!=0] = 1
|
| 39 |
+
edge = edge.squeeze()
|
| 40 |
+
return edge
|
demo/demo.jpg
ADDED
|
Git LFS Details
|
demo/demo_atr.png
ADDED
|
demo/demo_lip.png
ADDED
|
demo/demo_pascal.png
ADDED
|
demo/lip-visualization.jpg
ADDED
|
Git LFS Details
|
environment.yaml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: schp
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- _libgcc_mutex=0.1=main
|
| 7 |
+
- blas=1.0=mkl
|
| 8 |
+
- ca-certificates=2020.12.8=h06a4308_0
|
| 9 |
+
- certifi=2020.12.5=py38h06a4308_0
|
| 10 |
+
- cudatoolkit=10.1.243=h6bb024c_0
|
| 11 |
+
- freetype=2.10.4=h5ab3b9f_0
|
| 12 |
+
- intel-openmp=2020.2=254
|
| 13 |
+
- jpeg=9b=h024ee3a_2
|
| 14 |
+
- lcms2=2.11=h396b838_0
|
| 15 |
+
- ld_impl_linux-64=2.33.1=h53a641e_7
|
| 16 |
+
- libedit=3.1.20191231=h14c3975_1
|
| 17 |
+
- libffi=3.3=he6710b0_2
|
| 18 |
+
- libgcc-ng=9.1.0=hdf63c60_0
|
| 19 |
+
- libpng=1.6.37=hbc83047_0
|
| 20 |
+
- libstdcxx-ng=9.1.0=hdf63c60_0
|
| 21 |
+
- libtiff=4.1.0=h2733197_1
|
| 22 |
+
- lz4-c=1.9.2=heb0550a_3
|
| 23 |
+
- mkl=2020.2=256
|
| 24 |
+
- mkl-service=2.3.0=py38he904b0f_0
|
| 25 |
+
- mkl_fft=1.2.0=py38h23d657b_0
|
| 26 |
+
- mkl_random=1.1.1=py38h0573a6f_0
|
| 27 |
+
- ncurses=6.2=he6710b0_1
|
| 28 |
+
- ninja=1.10.2=py38hff7bd54_0
|
| 29 |
+
- numpy=1.19.2=py38h54aff64_0
|
| 30 |
+
- numpy-base=1.19.2=py38hfa32c7d_0
|
| 31 |
+
- olefile=0.46=py_0
|
| 32 |
+
- openssl=1.1.1i=h27cfd23_0
|
| 33 |
+
- pillow=8.0.1=py38he98fc37_0
|
| 34 |
+
- pip=20.3.3=py38h06a4308_0
|
| 35 |
+
- python=3.8.5=h7579374_1
|
| 36 |
+
- readline=8.0=h7b6447c_0
|
| 37 |
+
- setuptools=51.0.0=py38h06a4308_2
|
| 38 |
+
- six=1.15.0=py38h06a4308_0
|
| 39 |
+
- sqlite=3.33.0=h62c20be_0
|
| 40 |
+
- tk=8.6.10=hbc83047_0
|
| 41 |
+
- tqdm=4.55.0=pyhd3eb1b0_0
|
| 42 |
+
- wheel=0.36.2=pyhd3eb1b0_0
|
| 43 |
+
- xz=5.2.5=h7b6447c_0
|
| 44 |
+
- zlib=1.2.11=h7b6447c_3
|
| 45 |
+
- zstd=1.4.5=h9ceee32_0
|
| 46 |
+
- pytorch=1.5.1=py3.8_cuda10.1.243_cudnn7.6.3_0
|
| 47 |
+
- torchvision=0.6.1=py38_cu101
|
| 48 |
+
prefix: /home/peike/opt/anaconda3/envs/schp
|
| 49 |
+
|
evaluate.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
@Author : Peike Li
|
| 6 |
+
@Contact : peike.li@yahoo.com
|
| 7 |
+
@File : evaluate.py
|
| 8 |
+
@Time : 8/4/19 3:36 PM
|
| 9 |
+
@Desc :
|
| 10 |
+
@License : This source code is licensed under the license found in the
|
| 11 |
+
LICENSE file in the root directory of this source tree.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import argparse
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
from torch.utils import data
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
from PIL import Image as PILImage
|
| 22 |
+
import torchvision.transforms as transforms
|
| 23 |
+
import torch.backends.cudnn as cudnn
|
| 24 |
+
|
| 25 |
+
import networks
|
| 26 |
+
from datasets.datasets import LIPDataValSet
|
| 27 |
+
from utils.miou import compute_mean_ioU
|
| 28 |
+
from utils.transforms import BGR2RGB_transform
|
| 29 |
+
from utils.transforms import transform_parsing
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_arguments():
|
| 33 |
+
"""Parse all the arguments provided from the CLI.
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
A list of parsed arguments.
|
| 37 |
+
"""
|
| 38 |
+
parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")
|
| 39 |
+
|
| 40 |
+
# Network Structure
|
| 41 |
+
parser.add_argument("--arch", type=str, default='resnet101')
|
| 42 |
+
# Data Preference
|
| 43 |
+
parser.add_argument("--data-dir", type=str, default='./data/LIP')
|
| 44 |
+
parser.add_argument("--batch-size", type=int, default=1)
|
| 45 |
+
parser.add_argument("--input-size", type=str, default='473,473')
|
| 46 |
+
parser.add_argument("--num-classes", type=int, default=20)
|
| 47 |
+
parser.add_argument("--ignore-label", type=int, default=255)
|
| 48 |
+
parser.add_argument("--random-mirror", action="store_true")
|
| 49 |
+
parser.add_argument("--random-scale", action="store_true")
|
| 50 |
+
# Evaluation Preference
|
| 51 |
+
parser.add_argument("--log-dir", type=str, default='./log')
|
| 52 |
+
parser.add_argument("--model-restore", type=str, default='./log/checkpoint.pth.tar')
|
| 53 |
+
parser.add_argument("--gpu", type=str, default='0', help="choose gpu device.")
|
| 54 |
+
parser.add_argument("--save-results", action="store_true", help="whether to save the results.")
|
| 55 |
+
parser.add_argument("--flip", action="store_true", help="random flip during the test.")
|
| 56 |
+
parser.add_argument("--multi-scales", type=str, default='1', help="multiple scales during the test")
|
| 57 |
+
return parser.parse_args()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_palette(num_cls):
|
| 61 |
+
""" Returns the color map for visualizing the segmentation mask.
|
| 62 |
+
Args:
|
| 63 |
+
num_cls: Number of classes
|
| 64 |
+
Returns:
|
| 65 |
+
The color map
|
| 66 |
+
"""
|
| 67 |
+
n = num_cls
|
| 68 |
+
palette = [0] * (n * 3)
|
| 69 |
+
for j in range(0, n):
|
| 70 |
+
lab = j
|
| 71 |
+
palette[j * 3 + 0] = 0
|
| 72 |
+
palette[j * 3 + 1] = 0
|
| 73 |
+
palette[j * 3 + 2] = 0
|
| 74 |
+
i = 0
|
| 75 |
+
while lab:
|
| 76 |
+
palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
|
| 77 |
+
palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
|
| 78 |
+
palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
|
| 79 |
+
i += 1
|
| 80 |
+
lab >>= 3
|
| 81 |
+
return palette
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def multi_scale_testing(model, batch_input_im, crop_size=[473, 473], flip=True, multi_scales=[1]):
|
| 85 |
+
flipped_idx = (15, 14, 17, 16, 19, 18)
|
| 86 |
+
if len(batch_input_im.shape) > 4:
|
| 87 |
+
batch_input_im = batch_input_im.squeeze()
|
| 88 |
+
if len(batch_input_im.shape) == 3:
|
| 89 |
+
batch_input_im = batch_input_im.unsqueeze(0)
|
| 90 |
+
|
| 91 |
+
interp = torch.nn.Upsample(size=crop_size, mode='bilinear', align_corners=True)
|
| 92 |
+
ms_outputs = []
|
| 93 |
+
for s in multi_scales:
|
| 94 |
+
interp_im = torch.nn.Upsample(scale_factor=s, mode='bilinear', align_corners=True)
|
| 95 |
+
scaled_im = interp_im(batch_input_im)
|
| 96 |
+
parsing_output = model(scaled_im)
|
| 97 |
+
parsing_output = parsing_output[0][-1]
|
| 98 |
+
output = parsing_output[0]
|
| 99 |
+
if flip:
|
| 100 |
+
flipped_output = parsing_output[1]
|
| 101 |
+
flipped_output[14:20, :, :] = flipped_output[flipped_idx, :, :]
|
| 102 |
+
output += flipped_output.flip(dims=[-1])
|
| 103 |
+
output *= 0.5
|
| 104 |
+
output = interp(output.unsqueeze(0))
|
| 105 |
+
ms_outputs.append(output[0])
|
| 106 |
+
ms_fused_parsing_output = torch.stack(ms_outputs)
|
| 107 |
+
ms_fused_parsing_output = ms_fused_parsing_output.mean(0)
|
| 108 |
+
ms_fused_parsing_output = ms_fused_parsing_output.permute(1, 2, 0) # HWC
|
| 109 |
+
parsing = torch.argmax(ms_fused_parsing_output, dim=2)
|
| 110 |
+
parsing = parsing.data.cpu().numpy()
|
| 111 |
+
ms_fused_parsing_output = ms_fused_parsing_output.data.cpu().numpy()
|
| 112 |
+
return parsing, ms_fused_parsing_output
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def main():
|
| 116 |
+
"""Create the model and start the evaluation process."""
|
| 117 |
+
args = get_arguments()
|
| 118 |
+
multi_scales = [float(i) for i in args.multi_scales.split(',')]
|
| 119 |
+
gpus = [int(i) for i in args.gpu.split(',')]
|
| 120 |
+
assert len(gpus) == 1
|
| 121 |
+
if not args.gpu == 'None':
|
| 122 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
|
| 123 |
+
|
| 124 |
+
cudnn.benchmark = True
|
| 125 |
+
cudnn.enabled = True
|
| 126 |
+
|
| 127 |
+
h, w = map(int, args.input_size.split(','))
|
| 128 |
+
input_size = [h, w]
|
| 129 |
+
|
| 130 |
+
model = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=None)
|
| 131 |
+
|
| 132 |
+
IMAGE_MEAN = model.mean
|
| 133 |
+
IMAGE_STD = model.std
|
| 134 |
+
INPUT_SPACE = model.input_space
|
| 135 |
+
print('image mean: {}'.format(IMAGE_MEAN))
|
| 136 |
+
print('image std: {}'.format(IMAGE_STD))
|
| 137 |
+
print('input space:{}'.format(INPUT_SPACE))
|
| 138 |
+
if INPUT_SPACE == 'BGR':
|
| 139 |
+
print('BGR Transformation')
|
| 140 |
+
transform = transforms.Compose([
|
| 141 |
+
transforms.ToTensor(),
|
| 142 |
+
transforms.Normalize(mean=IMAGE_MEAN,
|
| 143 |
+
std=IMAGE_STD),
|
| 144 |
+
|
| 145 |
+
])
|
| 146 |
+
if INPUT_SPACE == 'RGB':
|
| 147 |
+
print('RGB Transformation')
|
| 148 |
+
transform = transforms.Compose([
|
| 149 |
+
transforms.ToTensor(),
|
| 150 |
+
BGR2RGB_transform(),
|
| 151 |
+
transforms.Normalize(mean=IMAGE_MEAN,
|
| 152 |
+
std=IMAGE_STD),
|
| 153 |
+
])
|
| 154 |
+
|
| 155 |
+
# Data loader
|
| 156 |
+
lip_test_dataset = LIPDataValSet(args.data_dir, 'val', crop_size=input_size, transform=transform, flip=args.flip)
|
| 157 |
+
num_samples = len(lip_test_dataset)
|
| 158 |
+
print('Totoal testing sample numbers: {}'.format(num_samples))
|
| 159 |
+
testloader = data.DataLoader(lip_test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)
|
| 160 |
+
|
| 161 |
+
# Load model weight
|
| 162 |
+
state_dict = torch.load(args.model_restore)['state_dict']
|
| 163 |
+
from collections import OrderedDict
|
| 164 |
+
new_state_dict = OrderedDict()
|
| 165 |
+
for k, v in state_dict.items():
|
| 166 |
+
name = k[7:] # remove `module.`
|
| 167 |
+
new_state_dict[name] = v
|
| 168 |
+
model.load_state_dict(new_state_dict)
|
| 169 |
+
model.cuda()
|
| 170 |
+
model.eval()
|
| 171 |
+
|
| 172 |
+
sp_results_dir = os.path.join(args.log_dir, 'sp_results')
|
| 173 |
+
if not os.path.exists(sp_results_dir):
|
| 174 |
+
os.makedirs(sp_results_dir)
|
| 175 |
+
|
| 176 |
+
palette = get_palette(20)
|
| 177 |
+
parsing_preds = []
|
| 178 |
+
scales = np.zeros((num_samples, 2), dtype=np.float32)
|
| 179 |
+
centers = np.zeros((num_samples, 2), dtype=np.int32)
|
| 180 |
+
with torch.no_grad():
|
| 181 |
+
for idx, batch in enumerate(tqdm(testloader)):
|
| 182 |
+
image, meta = batch
|
| 183 |
+
if (len(image.shape) > 4):
|
| 184 |
+
image = image.squeeze()
|
| 185 |
+
im_name = meta['name'][0]
|
| 186 |
+
c = meta['center'].numpy()[0]
|
| 187 |
+
s = meta['scale'].numpy()[0]
|
| 188 |
+
w = meta['width'].numpy()[0]
|
| 189 |
+
h = meta['height'].numpy()[0]
|
| 190 |
+
scales[idx, :] = s
|
| 191 |
+
centers[idx, :] = c
|
| 192 |
+
parsing, logits = multi_scale_testing(model, image.cuda(), crop_size=input_size, flip=args.flip,
|
| 193 |
+
multi_scales=multi_scales)
|
| 194 |
+
if args.save_results:
|
| 195 |
+
parsing_result = transform_parsing(parsing, c, s, w, h, input_size)
|
| 196 |
+
parsing_result_path = os.path.join(sp_results_dir, im_name + '.png')
|
| 197 |
+
output_im = PILImage.fromarray(np.asarray(parsing_result, dtype=np.uint8))
|
| 198 |
+
output_im.putpalette(palette)
|
| 199 |
+
output_im.save(parsing_result_path)
|
| 200 |
+
|
| 201 |
+
parsing_preds.append(parsing)
|
| 202 |
+
assert len(parsing_preds) == num_samples
|
| 203 |
+
mIoU = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, input_size)
|
| 204 |
+
print(mIoU)
|
| 205 |
+
return
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
if __name__ == '__main__':
|
| 209 |
+
main()
|
main.org
ADDED
|
@@ -0,0 +1,663 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
* COMMENT WORK SPACE
|
| 2 |
+
cd $HOME/HUGGINGFACE/aravindhv10/Self-Correction-Human-Parsing
|
| 3 |
+
|
| 4 |
+
** ELISP
|
| 5 |
+
#+begin_src elisp
|
| 6 |
+
(save-buffer)
|
| 7 |
+
(save-some-buffers)
|
| 8 |
+
(org-babel-tangle)
|
| 9 |
+
(shell-command "./work.sh" "output_log_work")
|
| 10 |
+
#+end_src
|
| 11 |
+
|
| 12 |
+
#+RESULTS:
|
| 13 |
+
: 0
|
| 14 |
+
|
| 15 |
+
** ELISP
|
| 16 |
+
#+begin_src elisp
|
| 17 |
+
(shell-command "git status" "output_log_git_status")
|
| 18 |
+
#+end_src
|
| 19 |
+
|
| 20 |
+
#+RESULTS:
|
| 21 |
+
: 0
|
| 22 |
+
|
| 23 |
+
** ELISP
|
| 24 |
+
#+begin_src elisp
|
| 25 |
+
(shell-command "./commit_and_push.sh" "output_log_commit_and_push")
|
| 26 |
+
#+end_src
|
| 27 |
+
|
| 28 |
+
#+RESULTS:
|
| 29 |
+
: 0
|
| 30 |
+
|
| 31 |
+
* Commit and push
|
| 32 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./commit_and_push.sh
|
| 33 |
+
git commit -m 'Routine updates'
|
| 34 |
+
git push
|
| 35 |
+
#+end_src
|
| 36 |
+
|
| 37 |
+
* Main script to do everything
|
| 38 |
+
#+begin_src sh :shebang #!/bin/sh :results output :tangle ./work.sh
|
| 39 |
+
do_ignore(){
|
| 40 |
+
'sed' 's@^@/@g' './rm.txt';
|
| 41 |
+
'cat' './gitignore.txt';
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
do_add(){
|
| 45 |
+
'sed' 's@^@("git" "lfs" "track" "./@g;s@$@");@g' './git_lfs_track.txt' ;
|
| 46 |
+
'cat' './git_add.txt' './git_lfs_track.txt' | \
|
| 47 |
+
'sed' 's@^@("git" "add" "./@g;s@$@");@g' ;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
do_rm(){
|
| 51 |
+
'sed' 's@^@("rm" "-vf" "--" "./@g ; s@$@");@g' './rm.txt' ;
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
all_commands(){
|
| 55 |
+
do_add
|
| 56 |
+
do_rm
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
do_all(){
|
| 60 |
+
do_ignore > './.gitignore'
|
| 61 |
+
all_commands | sh
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
do_all
|
| 65 |
+
#+end_src
|
| 66 |
+
|
| 67 |
+
* List of large files
|
| 68 |
+
#+begin_src conf :tangle ./git_lfs_track.txt
|
| 69 |
+
checkpoints/AEMatter/AEM_RWA.ckpt
|
| 70 |
+
checkpoints/atr.pth
|
| 71 |
+
checkpoints/lip.pth
|
| 72 |
+
checkpoints/Model_80.pth
|
| 73 |
+
checkpoints/MVANet/garment.pth
|
| 74 |
+
checkpoints/MVANet/skin.pth
|
| 75 |
+
checkpoints/pascal.pth
|
| 76 |
+
checkpoints/StableDiffusion/90c7c97574f8db765509b6a5d2e7b2551b430a10cac03e37d368654eac5e8169cd149644d188be4b5b2f1b9f29e66b64a02535f622f2bf284c319b076224cb2b
|
| 77 |
+
checkpoints/StableDiffusion/b970812225cfb95427c13e73b75eef66430e2a525876dddac494d70fe4ed0524cb197043e0ac3dc3026b32a45cd1d6d126ec2fe74a5bc3ef5df21836ca022b30
|
| 78 |
+
demo/demo_atr.png
|
| 79 |
+
demo/demo.jpg
|
| 80 |
+
demo/demo_lip.png
|
| 81 |
+
demo/demo_pascal.png
|
| 82 |
+
demo/lip-visualization.jpg
|
| 83 |
+
#+end_src
|
| 84 |
+
|
| 85 |
+
* List of source files to add
|
| 86 |
+
#+begin_src conf :tangle ./git_add.txt
|
| 87 |
+
checkpoints/StableDiffusion/hash
|
| 88 |
+
ComfyUI_AEMatter/AEMatter.py
|
| 89 |
+
ComfyUI_AEMatter/AEMatter.run.sh
|
| 90 |
+
ComfyUI_AEMatter/__init__.py
|
| 91 |
+
ComfyUI_AEMatter/README.org
|
| 92 |
+
ComfyUI_MVANet/download.sh
|
| 93 |
+
ComfyUI_MVANet/__init__.py
|
| 94 |
+
ComfyUI_MVANet/MVANet_inference.py
|
| 95 |
+
ComfyUI_MVANet/MVANet_inference.run.sh
|
| 96 |
+
ComfyUI_MVANet/README.org
|
| 97 |
+
ComfyUI_MVANet/requirements.txt
|
| 98 |
+
datasets/datasets.py
|
| 99 |
+
datasets/__init__.py
|
| 100 |
+
datasets/simple_extractor_dataset.py
|
| 101 |
+
datasets/target_generation.py
|
| 102 |
+
environment.yaml
|
| 103 |
+
evaluate.py
|
| 104 |
+
.gitattributes
|
| 105 |
+
.gitignore
|
| 106 |
+
LICENSE
|
| 107 |
+
main.org
|
| 108 |
+
mhp_extension/coco_style_annotation_creator/human_to_coco.py
|
| 109 |
+
mhp_extension/coco_style_annotation_creator/pycococreatortools.py
|
| 110 |
+
mhp_extension/coco_style_annotation_creator/test_human2coco_format.py
|
| 111 |
+
mhp_extension/demo.ipynb
|
| 112 |
+
mhp_extension/detectron2/.circleci/config.yml
|
| 113 |
+
mhp_extension/detectron2/.clang-format
|
| 114 |
+
mhp_extension/detectron2/configs/Base-RCNN-C4.yaml
|
| 115 |
+
mhp_extension/detectron2/configs/Base-RCNN-DilatedC5.yaml
|
| 116 |
+
mhp_extension/detectron2/configs/Base-RCNN-FPN.yaml
|
| 117 |
+
mhp_extension/detectron2/configs/Base-RetinaNet.yaml
|
| 118 |
+
mhp_extension/detectron2/configs/Cityscapes/mask_rcnn_R_50_FPN.yaml
|
| 119 |
+
mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_101_C4_3x.yaml
|
| 120 |
+
mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_101_DC5_3x.yaml
|
| 121 |
+
mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml
|
| 122 |
+
mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_50_C4_1x.yaml
|
| 123 |
+
mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_50_C4_3x.yaml
|
| 124 |
+
mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_50_DC5_1x.yaml
|
| 125 |
+
mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_50_DC5_3x.yaml
|
| 126 |
+
mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml
|
| 127 |
+
mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml
|
| 128 |
+
mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml
|
| 129 |
+
mhp_extension/detectron2/configs/COCO-Detection/fast_rcnn_R_50_FPN_1x.yaml
|
| 130 |
+
mhp_extension/detectron2/configs/COCO-Detection/retinanet_R_101_FPN_3x.yaml
|
| 131 |
+
mhp_extension/detectron2/configs/COCO-Detection/retinanet_R_50_FPN_1x.yaml
|
| 132 |
+
mhp_extension/detectron2/configs/COCO-Detection/retinanet_R_50_FPN_3x.yaml
|
| 133 |
+
mhp_extension/detectron2/configs/COCO-Detection/rpn_R_50_C4_1x.yaml
|
| 134 |
+
mhp_extension/detectron2/configs/COCO-Detection/rpn_R_50_FPN_1x.yaml
|
| 135 |
+
mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x.yaml
|
| 136 |
+
mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_DC5_3x.yaml
|
| 137 |
+
mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml
|
| 138 |
+
mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_1x.yaml
|
| 139 |
+
mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x.yaml
|
| 140 |
+
mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_1x.yaml
|
| 141 |
+
mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x.yaml
|
| 142 |
+
mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
|
| 143 |
+
mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml
|
| 144 |
+
mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml
|
| 145 |
+
mhp_extension/detectron2/configs/COCO-Keypoints/Base-Keypoint-RCNN-FPN.yaml
|
| 146 |
+
mhp_extension/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x.yaml
|
| 147 |
+
mhp_extension/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_1x.yaml
|
| 148 |
+
mhp_extension/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml
|
| 149 |
+
mhp_extension/detectron2/configs/COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x.yaml
|
| 150 |
+
mhp_extension/detectron2/configs/COCO-PanopticSegmentation/Base-Panoptic-FPN.yaml
|
| 151 |
+
mhp_extension/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml
|
| 152 |
+
mhp_extension/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_1x.yaml
|
| 153 |
+
mhp_extension/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml
|
| 154 |
+
mhp_extension/detectron2/configs/Detectron1-Comparisons/faster_rcnn_R_50_FPN_noaug_1x.yaml
|
| 155 |
+
mhp_extension/detectron2/configs/Detectron1-Comparisons/keypoint_rcnn_R_50_FPN_1x.yaml
|
| 156 |
+
mhp_extension/detectron2/configs/Detectron1-Comparisons/mask_rcnn_R_50_FPN_noaug_1x.yaml
|
| 157 |
+
mhp_extension/detectron2/configs/Detectron1-Comparisons/README.md
|
| 158 |
+
mhp_extension/detectron2/configs/LVIS-InstanceSegmentation/mask_rcnn_R_101_FPN_1x.yaml
|
| 159 |
+
mhp_extension/detectron2/configs/LVIS-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
|
| 160 |
+
mhp_extension/detectron2/configs/LVIS-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x.yaml
|
| 161 |
+
mhp_extension/detectron2/configs/Misc/cascade_mask_rcnn_R_50_FPN_1x.yaml
|
| 162 |
+
mhp_extension/detectron2/configs/Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml
|
| 163 |
+
mhp_extension/detectron2/configs/Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv_parsing.yaml
|
| 164 |
+
mhp_extension/detectron2/configs/Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml
|
| 165 |
+
mhp_extension/detectron2/configs/Misc/demo.yaml
|
| 166 |
+
mhp_extension/detectron2/configs/Misc/mask_rcnn_R_50_FPN_1x_cls_agnostic.yaml
|
| 167 |
+
mhp_extension/detectron2/configs/Misc/mask_rcnn_R_50_FPN_1x_dconv_c3-c5.yaml
|
| 168 |
+
mhp_extension/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_dconv_c3-c5.yaml
|
| 169 |
+
mhp_extension/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_gn.yaml
|
| 170 |
+
mhp_extension/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_syncbn.yaml
|
| 171 |
+
mhp_extension/detectron2/configs/Misc/panoptic_fpn_R_101_dconv_cascade_gn_3x.yaml
|
| 172 |
+
mhp_extension/detectron2/configs/Misc/parsing_finetune_cihp.yaml
|
| 173 |
+
mhp_extension/detectron2/configs/Misc/parsing_inference.yaml
|
| 174 |
+
mhp_extension/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_3x_gn.yaml
|
| 175 |
+
mhp_extension/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_9x_gn.yaml
|
| 176 |
+
mhp_extension/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_9x_syncbn.yaml
|
| 177 |
+
mhp_extension/detectron2/configs/Misc/semantic_R_50_FPN_1x.yaml
|
| 178 |
+
mhp_extension/detectron2/configs/my_Base-RCNN-FPN.yaml
|
| 179 |
+
mhp_extension/detectron2/configs/PascalVOC-Detection/faster_rcnn_R_50_C4.yaml
|
| 180 |
+
mhp_extension/detectron2/configs/PascalVOC-Detection/faster_rcnn_R_50_FPN.yaml
|
| 181 |
+
mhp_extension/detectron2/configs/quick_schedules/cascade_mask_rcnn_R_50_FPN_inference_acc_test.yaml
|
| 182 |
+
mhp_extension/detectron2/configs/quick_schedules/cascade_mask_rcnn_R_50_FPN_instant_test.yaml
|
| 183 |
+
mhp_extension/detectron2/configs/quick_schedules/fast_rcnn_R_50_FPN_inference_acc_test.yaml
|
| 184 |
+
mhp_extension/detectron2/configs/quick_schedules/fast_rcnn_R_50_FPN_instant_test.yaml
|
| 185 |
+
mhp_extension/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_inference_acc_test.yaml
|
| 186 |
+
mhp_extension/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_instant_test.yaml
|
| 187 |
+
mhp_extension/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_normalized_training_acc_test.yaml
|
| 188 |
+
mhp_extension/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_training_acc_test.yaml
|
| 189 |
+
mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_GCV_instant_test.yaml
|
| 190 |
+
mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_inference_acc_test.yaml
|
| 191 |
+
mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_instant_test.yaml
|
| 192 |
+
mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_training_acc_test.yaml
|
| 193 |
+
mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_DC5_inference_acc_test.yaml
|
| 194 |
+
mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml
|
| 195 |
+
mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_instant_test.yaml
|
| 196 |
+
mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_training_acc_test.yaml
|
| 197 |
+
mhp_extension/detectron2/configs/quick_schedules/panoptic_fpn_R_50_inference_acc_test.yaml
|
| 198 |
+
mhp_extension/detectron2/configs/quick_schedules/panoptic_fpn_R_50_instant_test.yaml
|
| 199 |
+
mhp_extension/detectron2/configs/quick_schedules/panoptic_fpn_R_50_training_acc_test.yaml
|
| 200 |
+
mhp_extension/detectron2/configs/quick_schedules/README.md
|
| 201 |
+
mhp_extension/detectron2/configs/quick_schedules/retinanet_R_50_FPN_inference_acc_test.yaml
|
| 202 |
+
mhp_extension/detectron2/configs/quick_schedules/retinanet_R_50_FPN_instant_test.yaml
|
| 203 |
+
mhp_extension/detectron2/configs/quick_schedules/rpn_R_50_FPN_inference_acc_test.yaml
|
| 204 |
+
mhp_extension/detectron2/configs/quick_schedules/rpn_R_50_FPN_instant_test.yaml
|
| 205 |
+
mhp_extension/detectron2/configs/quick_schedules/semantic_R_50_FPN_inference_acc_test.yaml
|
| 206 |
+
mhp_extension/detectron2/configs/quick_schedules/semantic_R_50_FPN_instant_test.yaml
|
| 207 |
+
mhp_extension/detectron2/configs/quick_schedules/semantic_R_50_FPN_training_acc_test.yaml
|
| 208 |
+
mhp_extension/detectron2/demo/demo.py
|
| 209 |
+
mhp_extension/detectron2/demo/predictor.py
|
| 210 |
+
mhp_extension/detectron2/demo/README.md
|
| 211 |
+
mhp_extension/detectron2/detectron2/checkpoint/c2_model_loading.py
|
| 212 |
+
mhp_extension/detectron2/detectron2/checkpoint/catalog.py
|
| 213 |
+
mhp_extension/detectron2/detectron2/checkpoint/detection_checkpoint.py
|
| 214 |
+
mhp_extension/detectron2/detectron2/checkpoint/__init__.py
|
| 215 |
+
mhp_extension/detectron2/detectron2/config/compat.py
|
| 216 |
+
mhp_extension/detectron2/detectron2/config/config.py
|
| 217 |
+
mhp_extension/detectron2/detectron2/config/defaults.py
|
| 218 |
+
mhp_extension/detectron2/detectron2/config/__init__.py
|
| 219 |
+
mhp_extension/detectron2/detectron2/data/build.py
|
| 220 |
+
mhp_extension/detectron2/detectron2/data/catalog.py
|
| 221 |
+
mhp_extension/detectron2/detectron2/data/common.py
|
| 222 |
+
mhp_extension/detectron2/detectron2/data/dataset_mapper.py
|
| 223 |
+
mhp_extension/detectron2/detectron2/data/datasets/builtin_meta.py
|
| 224 |
+
mhp_extension/detectron2/detectron2/data/datasets/builtin.py
|
| 225 |
+
mhp_extension/detectron2/detectron2/data/datasets/cityscapes.py
|
| 226 |
+
mhp_extension/detectron2/detectron2/data/datasets/coco.py
|
| 227 |
+
mhp_extension/detectron2/detectron2/data/datasets/__init__.py
|
| 228 |
+
mhp_extension/detectron2/detectron2/data/datasets/lvis.py
|
| 229 |
+
mhp_extension/detectron2/detectron2/data/datasets/lvis_v0_5_categories.py
|
| 230 |
+
mhp_extension/detectron2/detectron2/data/datasets/pascal_voc.py
|
| 231 |
+
mhp_extension/detectron2/detectron2/data/datasets/README.md
|
| 232 |
+
mhp_extension/detectron2/detectron2/data/datasets/register_coco.py
|
| 233 |
+
mhp_extension/detectron2/detectron2/data/detection_utils.py
|
| 234 |
+
mhp_extension/detectron2/detectron2/data/__init__.py
|
| 235 |
+
mhp_extension/detectron2/detectron2/data/samplers/distributed_sampler.py
|
| 236 |
+
mhp_extension/detectron2/detectron2/data/samplers/grouped_batch_sampler.py
|
| 237 |
+
mhp_extension/detectron2/detectron2/data/samplers/__init__.py
|
| 238 |
+
mhp_extension/detectron2/detectron2/data/transforms/__init__.py
|
| 239 |
+
mhp_extension/detectron2/detectron2/data/transforms/transform_gen.py
|
| 240 |
+
mhp_extension/detectron2/detectron2/data/transforms/transform.py
|
| 241 |
+
mhp_extension/detectron2/detectron2/engine/defaults.py
|
| 242 |
+
mhp_extension/detectron2/detectron2/engine/hooks.py
|
| 243 |
+
mhp_extension/detectron2/detectron2/engine/__init__.py
|
| 244 |
+
mhp_extension/detectron2/detectron2/engine/launch.py
|
| 245 |
+
mhp_extension/detectron2/detectron2/engine/train_loop.py
|
| 246 |
+
mhp_extension/detectron2/detectron2/evaluation/cityscapes_evaluation.py
|
| 247 |
+
mhp_extension/detectron2/detectron2/evaluation/coco_evaluation.py
|
| 248 |
+
mhp_extension/detectron2/detectron2/evaluation/evaluator.py
|
| 249 |
+
mhp_extension/detectron2/detectron2/evaluation/__init__.py
|
| 250 |
+
mhp_extension/detectron2/detectron2/evaluation/lvis_evaluation.py
|
| 251 |
+
mhp_extension/detectron2/detectron2/evaluation/panoptic_evaluation.py
|
| 252 |
+
mhp_extension/detectron2/detectron2/evaluation/pascal_voc_evaluation.py
|
| 253 |
+
mhp_extension/detectron2/detectron2/evaluation/rotated_coco_evaluation.py
|
| 254 |
+
mhp_extension/detectron2/detectron2/evaluation/sem_seg_evaluation.py
|
| 255 |
+
mhp_extension/detectron2/detectron2/evaluation/testing.py
|
| 256 |
+
mhp_extension/detectron2/detectron2/export/api.py
|
| 257 |
+
mhp_extension/detectron2/detectron2/export/c10.py
|
| 258 |
+
mhp_extension/detectron2/detectron2/export/caffe2_export.py
|
| 259 |
+
mhp_extension/detectron2/detectron2/export/caffe2_inference.py
|
| 260 |
+
mhp_extension/detectron2/detectron2/export/caffe2_modeling.py
|
| 261 |
+
mhp_extension/detectron2/detectron2/export/__init__.py
|
| 262 |
+
mhp_extension/detectron2/detectron2/export/patcher.py
|
| 263 |
+
mhp_extension/detectron2/detectron2/export/README.md
|
| 264 |
+
mhp_extension/detectron2/detectron2/export/shared.py
|
| 265 |
+
mhp_extension/detectron2/detectron2/__init__.py
|
| 266 |
+
mhp_extension/detectron2/detectron2/layers/batch_norm.py
|
| 267 |
+
mhp_extension/detectron2/detectron2/layers/blocks.py
|
| 268 |
+
mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp
|
| 269 |
+
mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
|
| 270 |
+
mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
|
| 271 |
+
mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
|
| 272 |
+
mhp_extension/detectron2/detectron2/layers/csrc/cuda_version.cu
|
| 273 |
+
mhp_extension/detectron2/detectron2/layers/csrc/deformable/deform_conv_cuda.cu
|
| 274 |
+
mhp_extension/detectron2/detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu
|
| 275 |
+
mhp_extension/detectron2/detectron2/layers/csrc/deformable/deform_conv.h
|
| 276 |
+
mhp_extension/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp
|
| 277 |
+
mhp_extension/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu
|
| 278 |
+
mhp_extension/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated.h
|
| 279 |
+
mhp_extension/detectron2/detectron2/layers/csrc/README.md
|
| 280 |
+
mhp_extension/detectron2/detectron2/layers/csrc/ROIAlign/ROIAlign_cpu.cpp
|
| 281 |
+
mhp_extension/detectron2/detectron2/layers/csrc/ROIAlign/ROIAlign_cuda.cu
|
| 282 |
+
mhp_extension/detectron2/detectron2/layers/csrc/ROIAlign/ROIAlign.h
|
| 283 |
+
mhp_extension/detectron2/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp
|
| 284 |
+
mhp_extension/detectron2/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu
|
| 285 |
+
mhp_extension/detectron2/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h
|
| 286 |
+
mhp_extension/detectron2/detectron2/layers/csrc/vision.cpp
|
| 287 |
+
mhp_extension/detectron2/detectron2/layers/deform_conv.py
|
| 288 |
+
mhp_extension/detectron2/detectron2/layers/__init__.py
|
| 289 |
+
mhp_extension/detectron2/detectron2/layers/mask_ops.py
|
| 290 |
+
mhp_extension/detectron2/detectron2/layers/nms.py
|
| 291 |
+
mhp_extension/detectron2/detectron2/layers/roi_align.py
|
| 292 |
+
mhp_extension/detectron2/detectron2/layers/roi_align_rotated.py
|
| 293 |
+
mhp_extension/detectron2/detectron2/layers/rotated_boxes.py
|
| 294 |
+
mhp_extension/detectron2/detectron2/layers/shape_spec.py
|
| 295 |
+
mhp_extension/detectron2/detectron2/layers/wrappers.py
|
| 296 |
+
mhp_extension/detectron2/detectron2/modeling/anchor_generator.py
|
| 297 |
+
mhp_extension/detectron2/detectron2/modeling/backbone/backbone.py
|
| 298 |
+
mhp_extension/detectron2/detectron2/modeling/backbone/build.py
|
| 299 |
+
mhp_extension/detectron2/detectron2/modeling/backbone/fpn.py
|
| 300 |
+
mhp_extension/detectron2/detectron2/modeling/backbone/__init__.py
|
| 301 |
+
mhp_extension/detectron2/detectron2/modeling/backbone/resnet.py
|
| 302 |
+
mhp_extension/detectron2/detectron2/modeling/box_regression.py
|
| 303 |
+
mhp_extension/detectron2/detectron2/modeling/__init__.py
|
| 304 |
+
mhp_extension/detectron2/detectron2/modeling/matcher.py
|
| 305 |
+
mhp_extension/detectron2/detectron2/modeling/meta_arch/build.py
|
| 306 |
+
mhp_extension/detectron2/detectron2/modeling/meta_arch/__init__.py
|
| 307 |
+
mhp_extension/detectron2/detectron2/modeling/meta_arch/panoptic_fpn.py
|
| 308 |
+
mhp_extension/detectron2/detectron2/modeling/meta_arch/rcnn.py
|
| 309 |
+
mhp_extension/detectron2/detectron2/modeling/meta_arch/retinanet.py
|
| 310 |
+
mhp_extension/detectron2/detectron2/modeling/meta_arch/semantic_seg.py
|
| 311 |
+
mhp_extension/detectron2/detectron2/modeling/poolers.py
|
| 312 |
+
mhp_extension/detectron2/detectron2/modeling/postprocessing.py
|
| 313 |
+
mhp_extension/detectron2/detectron2/modeling/proposal_generator/build.py
|
| 314 |
+
mhp_extension/detectron2/detectron2/modeling/proposal_generator/__init__.py
|
| 315 |
+
mhp_extension/detectron2/detectron2/modeling/proposal_generator/proposal_utils.py
|
| 316 |
+
mhp_extension/detectron2/detectron2/modeling/proposal_generator/rpn_outputs.py
|
| 317 |
+
mhp_extension/detectron2/detectron2/modeling/proposal_generator/rpn.py
|
| 318 |
+
mhp_extension/detectron2/detectron2/modeling/proposal_generator/rrpn.py
|
| 319 |
+
mhp_extension/detectron2/detectron2/modeling/roi_heads/box_head.py
|
| 320 |
+
mhp_extension/detectron2/detectron2/modeling/roi_heads/cascade_rcnn.py
|
| 321 |
+
mhp_extension/detectron2/detectron2/modeling/roi_heads/fast_rcnn.py
|
| 322 |
+
mhp_extension/detectron2/detectron2/modeling/roi_heads/__init__.py
|
| 323 |
+
mhp_extension/detectron2/detectron2/modeling/roi_heads/keypoint_head.py
|
| 324 |
+
mhp_extension/detectron2/detectron2/modeling/roi_heads/mask_head.py
|
| 325 |
+
mhp_extension/detectron2/detectron2/modeling/roi_heads/roi_heads.py
|
| 326 |
+
mhp_extension/detectron2/detectron2/modeling/roi_heads/rotated_fast_rcnn.py
|
| 327 |
+
mhp_extension/detectron2/detectron2/modeling/sampling.py
|
| 328 |
+
mhp_extension/detectron2/detectron2/modeling/test_time_augmentation.py
|
| 329 |
+
mhp_extension/detectron2/detectron2/model_zoo/__init__.py
|
| 330 |
+
mhp_extension/detectron2/detectron2/model_zoo/model_zoo.py
|
| 331 |
+
mhp_extension/detectron2/detectron2/solver/build.py
|
| 332 |
+
mhp_extension/detectron2/detectron2/solver/__init__.py
|
| 333 |
+
mhp_extension/detectron2/detectron2/solver/lr_scheduler.py
|
| 334 |
+
mhp_extension/detectron2/detectron2/structures/boxes.py
|
| 335 |
+
mhp_extension/detectron2/detectron2/structures/image_list.py
|
| 336 |
+
mhp_extension/detectron2/detectron2/structures/__init__.py
|
| 337 |
+
mhp_extension/detectron2/detectron2/structures/instances.py
|
| 338 |
+
mhp_extension/detectron2/detectron2/structures/keypoints.py
|
| 339 |
+
mhp_extension/detectron2/detectron2/structures/masks.py
|
| 340 |
+
mhp_extension/detectron2/detectron2/structures/rotated_boxes.py
|
| 341 |
+
mhp_extension/detectron2/detectron2/utils/analysis.py
|
| 342 |
+
mhp_extension/detectron2/detectron2/utils/collect_env.py
|
| 343 |
+
mhp_extension/detectron2/detectron2/utils/colormap.py
|
| 344 |
+
mhp_extension/detectron2/detectron2/utils/comm.py
|
| 345 |
+
mhp_extension/detectron2/detectron2/utils/env.py
|
| 346 |
+
mhp_extension/detectron2/detectron2/utils/events.py
|
| 347 |
+
mhp_extension/detectron2/detectron2/utils/__init__.py
|
| 348 |
+
mhp_extension/detectron2/detectron2/utils/logger.py
|
| 349 |
+
mhp_extension/detectron2/detectron2/utils/memory.py
|
| 350 |
+
mhp_extension/detectron2/detectron2/utils/README.md
|
| 351 |
+
mhp_extension/detectron2/detectron2/utils/registry.py
|
| 352 |
+
mhp_extension/detectron2/detectron2/utils/serialize.py
|
| 353 |
+
mhp_extension/detectron2/detectron2/utils/video_visualizer.py
|
| 354 |
+
mhp_extension/detectron2/detectron2/utils/visualizer.py
|
| 355 |
+
mhp_extension/detectron2/dev/linter.sh
|
| 356 |
+
mhp_extension/detectron2/dev/packaging/build_all_wheels.sh
|
| 357 |
+
mhp_extension/detectron2/dev/packaging/build_wheel.sh
|
| 358 |
+
mhp_extension/detectron2/dev/packaging/gen_wheel_index.sh
|
| 359 |
+
mhp_extension/detectron2/dev/packaging/pkg_helpers.bash
|
| 360 |
+
mhp_extension/detectron2/dev/packaging/README.md
|
| 361 |
+
mhp_extension/detectron2/dev/parse_results.sh
|
| 362 |
+
mhp_extension/detectron2/dev/README.md
|
| 363 |
+
mhp_extension/detectron2/dev/run_inference_tests.sh
|
| 364 |
+
mhp_extension/detectron2/dev/run_instant_tests.sh
|
| 365 |
+
mhp_extension/detectron2/docker/docker-compose.yml
|
| 366 |
+
mhp_extension/detectron2/docker/Dockerfile
|
| 367 |
+
mhp_extension/detectron2/docker/Dockerfile-circleci
|
| 368 |
+
mhp_extension/detectron2/docker/README.md
|
| 369 |
+
mhp_extension/detectron2/docs/conf.py
|
| 370 |
+
mhp_extension/detectron2/docs/.gitignore
|
| 371 |
+
mhp_extension/detectron2/docs/index.rst
|
| 372 |
+
mhp_extension/detectron2/docs/Makefile
|
| 373 |
+
mhp_extension/detectron2/docs/modules/checkpoint.rst
|
| 374 |
+
mhp_extension/detectron2/docs/modules/config.rst
|
| 375 |
+
mhp_extension/detectron2/docs/modules/data.rst
|
| 376 |
+
mhp_extension/detectron2/docs/modules/engine.rst
|
| 377 |
+
mhp_extension/detectron2/docs/modules/evaluation.rst
|
| 378 |
+
mhp_extension/detectron2/docs/modules/export.rst
|
| 379 |
+
mhp_extension/detectron2/docs/modules/index.rst
|
| 380 |
+
mhp_extension/detectron2/docs/modules/layers.rst
|
| 381 |
+
mhp_extension/detectron2/docs/modules/modeling.rst
|
| 382 |
+
mhp_extension/detectron2/docs/modules/model_zoo.rst
|
| 383 |
+
mhp_extension/detectron2/docs/modules/solver.rst
|
| 384 |
+
mhp_extension/detectron2/docs/modules/structures.rst
|
| 385 |
+
mhp_extension/detectron2/docs/modules/utils.rst
|
| 386 |
+
mhp_extension/detectron2/docs/notes/benchmarks.md
|
| 387 |
+
mhp_extension/detectron2/docs/notes/changelog.md
|
| 388 |
+
mhp_extension/detectron2/docs/notes/compatibility.md
|
| 389 |
+
mhp_extension/detectron2/docs/notes/contributing.md
|
| 390 |
+
mhp_extension/detectron2/docs/notes/index.rst
|
| 391 |
+
mhp_extension/detectron2/docs/README.md
|
| 392 |
+
mhp_extension/detectron2/docs/tutorials/builtin_datasets.md
|
| 393 |
+
mhp_extension/detectron2/docs/tutorials/configs.md
|
| 394 |
+
mhp_extension/detectron2/docs/tutorials/data_loading.md
|
| 395 |
+
mhp_extension/detectron2/docs/tutorials/datasets.md
|
| 396 |
+
mhp_extension/detectron2/docs/tutorials/deployment.md
|
| 397 |
+
mhp_extension/detectron2/docs/tutorials/evaluation.md
|
| 398 |
+
mhp_extension/detectron2/docs/tutorials/extend.md
|
| 399 |
+
mhp_extension/detectron2/docs/tutorials/getting_started.md
|
| 400 |
+
mhp_extension/detectron2/docs/tutorials/index.rst
|
| 401 |
+
mhp_extension/detectron2/docs/tutorials/install.md
|
| 402 |
+
mhp_extension/detectron2/docs/tutorials/models.md
|
| 403 |
+
mhp_extension/detectron2/docs/tutorials/README.md
|
| 404 |
+
mhp_extension/detectron2/docs/tutorials/training.md
|
| 405 |
+
mhp_extension/detectron2/docs/tutorials/write-models.md
|
| 406 |
+
mhp_extension/detectron2/.flake8
|
| 407 |
+
mhp_extension/detectron2/GETTING_STARTED.md
|
| 408 |
+
mhp_extension/detectron2/.gitignore
|
| 409 |
+
mhp_extension/detectron2/INSTALL.md
|
| 410 |
+
mhp_extension/detectron2/LICENSE
|
| 411 |
+
mhp_extension/detectron2/MODEL_ZOO.md
|
| 412 |
+
mhp_extension/detectron2/projects/DensePose/apply_net.py
|
| 413 |
+
mhp_extension/detectron2/projects/DensePose/configs/Base-DensePose-RCNN-FPN.yaml
|
| 414 |
+
mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_s1x.yaml
|
| 415 |
+
mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC1_s1x.yaml
|
| 416 |
+
mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC2_s1x.yaml
|
| 417 |
+
mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_s1x_legacy.yaml
|
| 418 |
+
mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_s1x.yaml
|
| 419 |
+
mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC1_s1x.yaml
|
| 420 |
+
mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC2_s1x.yaml
|
| 421 |
+
mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_s1x.yaml
|
| 422 |
+
mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC1_s1x.yaml
|
| 423 |
+
mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC2_s1x.yaml
|
| 424 |
+
mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x_legacy.yaml
|
| 425 |
+
mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x.yaml
|
| 426 |
+
mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC1_s1x.yaml
|
| 427 |
+
mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC2_s1x.yaml
|
| 428 |
+
mhp_extension/detectron2/projects/DensePose/configs/evolution/Base-RCNN-FPN-MC.yaml
|
| 429 |
+
mhp_extension/detectron2/projects/DensePose/configs/evolution/faster_rcnn_R_50_FPN_1x_MC.yaml
|
| 430 |
+
mhp_extension/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_DL_instant_test.yaml
|
| 431 |
+
mhp_extension/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_inference_acc_test.yaml
|
| 432 |
+
mhp_extension/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_instant_test.yaml
|
| 433 |
+
mhp_extension/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_training_acc_test.yaml
|
| 434 |
+
mhp_extension/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_TTA_inference_acc_test.yaml
|
| 435 |
+
mhp_extension/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_WC1_instant_test.yaml
|
| 436 |
+
mhp_extension/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_WC2_instant_test.yaml
|
| 437 |
+
mhp_extension/detectron2/projects/DensePose/densepose/config.py
|
| 438 |
+
mhp_extension/detectron2/projects/DensePose/densepose/data/build.py
|
| 439 |
+
mhp_extension/detectron2/projects/DensePose/densepose/data/dataset_mapper.py
|
| 440 |
+
mhp_extension/detectron2/projects/DensePose/densepose/data/datasets/builtin.py
|
| 441 |
+
mhp_extension/detectron2/projects/DensePose/densepose/data/datasets/coco.py
|
| 442 |
+
mhp_extension/detectron2/projects/DensePose/densepose/data/datasets/__init__.py
|
| 443 |
+
mhp_extension/detectron2/projects/DensePose/densepose/data/__init__.py
|
| 444 |
+
mhp_extension/detectron2/projects/DensePose/densepose/data/structures.py
|
| 445 |
+
mhp_extension/detectron2/projects/DensePose/densepose/densepose_coco_evaluation.py
|
| 446 |
+
mhp_extension/detectron2/projects/DensePose/densepose/densepose_head.py
|
| 447 |
+
mhp_extension/detectron2/projects/DensePose/densepose/evaluator.py
|
| 448 |
+
mhp_extension/detectron2/projects/DensePose/densepose/__init__.py
|
| 449 |
+
mhp_extension/detectron2/projects/DensePose/densepose/modeling/test_time_augmentation.py
|
| 450 |
+
mhp_extension/detectron2/projects/DensePose/densepose/roi_head.py
|
| 451 |
+
mhp_extension/detectron2/projects/DensePose/densepose/utils/dbhelper.py
|
| 452 |
+
mhp_extension/detectron2/projects/DensePose/densepose/utils/logger.py
|
| 453 |
+
mhp_extension/detectron2/projects/DensePose/densepose/utils/transform.py
|
| 454 |
+
mhp_extension/detectron2/projects/DensePose/densepose/vis/base.py
|
| 455 |
+
mhp_extension/detectron2/projects/DensePose/densepose/vis/bounding_box.py
|
| 456 |
+
mhp_extension/detectron2/projects/DensePose/densepose/vis/densepose.py
|
| 457 |
+
mhp_extension/detectron2/projects/DensePose/densepose/vis/extractor.py
|
| 458 |
+
mhp_extension/detectron2/projects/DensePose/dev/README.md
|
| 459 |
+
mhp_extension/detectron2/projects/DensePose/dev/run_inference_tests.sh
|
| 460 |
+
mhp_extension/detectron2/projects/DensePose/dev/run_instant_tests.sh
|
| 461 |
+
mhp_extension/detectron2/projects/DensePose/doc/GETTING_STARTED.md
|
| 462 |
+
mhp_extension/detectron2/projects/DensePose/doc/MODEL_ZOO.md
|
| 463 |
+
mhp_extension/detectron2/projects/DensePose/doc/TOOL_APPLY_NET.md
|
| 464 |
+
mhp_extension/detectron2/projects/DensePose/doc/TOOL_QUERY_DB.md
|
| 465 |
+
mhp_extension/detectron2/projects/DensePose/query_db.py
|
| 466 |
+
mhp_extension/detectron2/projects/DensePose/README.md
|
| 467 |
+
mhp_extension/detectron2/projects/DensePose/tests/common.py
|
| 468 |
+
mhp_extension/detectron2/projects/DensePose/tests/test_model_e2e.py
|
| 469 |
+
mhp_extension/detectron2/projects/DensePose/tests/test_setup.py
|
| 470 |
+
mhp_extension/detectron2/projects/DensePose/tests/test_structures.py
|
| 471 |
+
mhp_extension/detectron2/projects/DensePose/train_net.py
|
| 472 |
+
mhp_extension/detectron2/projects/PointRend/configs/InstanceSegmentation/Base-PointRend-RCNN-FPN.yaml
|
| 473 |
+
mhp_extension/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes.yaml
|
| 474 |
+
mhp_extension/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml
|
| 475 |
+
mhp_extension/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml
|
| 476 |
+
mhp_extension/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_parsing.yaml
|
| 477 |
+
mhp_extension/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_parsing.yaml
|
| 478 |
+
mhp_extension/detectron2/projects/PointRend/configs/SemanticSegmentation/Base-PointRend-Semantic-FPN.yaml
|
| 479 |
+
mhp_extension/detectron2/projects/PointRend/configs/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes.yaml
|
| 480 |
+
mhp_extension/detectron2/projects/PointRend/configs/SemanticSegmentation/pointrend_semantic_R_50_FPN_1x_coco.yaml
|
| 481 |
+
mhp_extension/detectron2/projects/PointRend/finetune_net.py
|
| 482 |
+
mhp_extension/detectron2/projects/PointRend/logs/hadoop.kylin.libdfs.log
|
| 483 |
+
mhp_extension/detectron2/projects/PointRend/point_rend/coarse_mask_head.py
|
| 484 |
+
mhp_extension/detectron2/projects/PointRend/point_rend/color_augmentation.py
|
| 485 |
+
mhp_extension/detectron2/projects/PointRend/point_rend/config.py
|
| 486 |
+
mhp_extension/detectron2/projects/PointRend/point_rend/dataset_mapper.py
|
| 487 |
+
mhp_extension/detectron2/projects/PointRend/point_rend/__init__.py
|
| 488 |
+
mhp_extension/detectron2/projects/PointRend/point_rend/point_features.py
|
| 489 |
+
mhp_extension/detectron2/projects/PointRend/point_rend/point_head.py
|
| 490 |
+
mhp_extension/detectron2/projects/PointRend/point_rend/roi_heads.py
|
| 491 |
+
mhp_extension/detectron2/projects/PointRend/point_rend/semantic_seg.py
|
| 492 |
+
mhp_extension/detectron2/projects/PointRend/README.md
|
| 493 |
+
mhp_extension/detectron2/projects/PointRend/run.sh
|
| 494 |
+
mhp_extension/detectron2/projects/PointRend/train_net.py
|
| 495 |
+
mhp_extension/detectron2/projects/README.md
|
| 496 |
+
mhp_extension/detectron2/projects/TensorMask/configs/Base-TensorMask.yaml
|
| 497 |
+
mhp_extension/detectron2/projects/TensorMask/configs/tensormask_R_50_FPN_1x.yaml
|
| 498 |
+
mhp_extension/detectron2/projects/TensorMask/configs/tensormask_R_50_FPN_6x.yaml
|
| 499 |
+
mhp_extension/detectron2/projects/TensorMask/README.md
|
| 500 |
+
mhp_extension/detectron2/projects/TensorMask/setup.py
|
| 501 |
+
mhp_extension/detectron2/projects/TensorMask/tensormask/arch.py
|
| 502 |
+
mhp_extension/detectron2/projects/TensorMask/tensormask/config.py
|
| 503 |
+
mhp_extension/detectron2/projects/TensorMask/tensormask/__init__.py
|
| 504 |
+
mhp_extension/detectron2/projects/TensorMask/tensormask/layers/csrc/SwapAlign2Nat/SwapAlign2Nat_cuda.cu
|
| 505 |
+
mhp_extension/detectron2/projects/TensorMask/tensormask/layers/csrc/SwapAlign2Nat/SwapAlign2Nat.h
|
| 506 |
+
mhp_extension/detectron2/projects/TensorMask/tensormask/layers/csrc/vision.cpp
|
| 507 |
+
mhp_extension/detectron2/projects/TensorMask/tensormask/layers/__init__.py
|
| 508 |
+
mhp_extension/detectron2/projects/TensorMask/tensormask/layers/swap_align2nat.py
|
| 509 |
+
mhp_extension/detectron2/projects/TensorMask/tests/__init__.py
|
| 510 |
+
mhp_extension/detectron2/projects/TensorMask/tests/test_swap_align2nat.py
|
| 511 |
+
mhp_extension/detectron2/projects/TensorMask/train_net.py
|
| 512 |
+
mhp_extension/detectron2/projects/TridentNet/configs/Base-TridentNet-Fast-C4.yaml
|
| 513 |
+
mhp_extension/detectron2/projects/TridentNet/configs/tridentnet_fast_R_101_C4_3x.yaml
|
| 514 |
+
mhp_extension/detectron2/projects/TridentNet/configs/tridentnet_fast_R_50_C4_1x.yaml
|
| 515 |
+
mhp_extension/detectron2/projects/TridentNet/configs/tridentnet_fast_R_50_C4_3x.yaml
|
| 516 |
+
mhp_extension/detectron2/projects/TridentNet/README.md
|
| 517 |
+
mhp_extension/detectron2/projects/TridentNet/train_net.py
|
| 518 |
+
mhp_extension/detectron2/projects/TridentNet/tridentnet/config.py
|
| 519 |
+
mhp_extension/detectron2/projects/TridentNet/tridentnet/__init__.py
|
| 520 |
+
mhp_extension/detectron2/projects/TridentNet/tridentnet/trident_backbone.py
|
| 521 |
+
mhp_extension/detectron2/projects/TridentNet/tridentnet/trident_conv.py
|
| 522 |
+
mhp_extension/detectron2/projects/TridentNet/tridentnet/trident_rcnn.py
|
| 523 |
+
mhp_extension/detectron2/projects/TridentNet/tridentnet/trident_rpn.py
|
| 524 |
+
mhp_extension/detectron2/README.md
|
| 525 |
+
mhp_extension/detectron2/setup.cfg
|
| 526 |
+
mhp_extension/detectron2/setup.py
|
| 527 |
+
mhp_extension/detectron2/tests/data/__init__.py
|
| 528 |
+
mhp_extension/detectron2/tests/data/test_coco.py
|
| 529 |
+
mhp_extension/detectron2/tests/data/test_detection_utils.py
|
| 530 |
+
mhp_extension/detectron2/tests/data/test_rotation_transform.py
|
| 531 |
+
mhp_extension/detectron2/tests/data/test_sampler.py
|
| 532 |
+
mhp_extension/detectron2/tests/data/test_transforms.py
|
| 533 |
+
mhp_extension/detectron2/tests/__init__.py
|
| 534 |
+
mhp_extension/detectron2/tests/layers/__init__.py
|
| 535 |
+
mhp_extension/detectron2/tests/layers/test_mask_ops.py
|
| 536 |
+
mhp_extension/detectron2/tests/layers/test_nms_rotated.py
|
| 537 |
+
mhp_extension/detectron2/tests/layers/test_roi_align.py
|
| 538 |
+
mhp_extension/detectron2/tests/layers/test_roi_align_rotated.py
|
| 539 |
+
mhp_extension/detectron2/tests/modeling/__init__.py
|
| 540 |
+
mhp_extension/detectron2/tests/modeling/test_anchor_generator.py
|
| 541 |
+
mhp_extension/detectron2/tests/modeling/test_box2box_transform.py
|
| 542 |
+
mhp_extension/detectron2/tests/modeling/test_fast_rcnn.py
|
| 543 |
+
mhp_extension/detectron2/tests/modeling/test_model_e2e.py
|
| 544 |
+
mhp_extension/detectron2/tests/modeling/test_roi_heads.py
|
| 545 |
+
mhp_extension/detectron2/tests/modeling/test_roi_pooler.py
|
| 546 |
+
mhp_extension/detectron2/tests/modeling/test_rpn.py
|
| 547 |
+
mhp_extension/detectron2/tests/README.md
|
| 548 |
+
mhp_extension/detectron2/tests/structures/__init__.py
|
| 549 |
+
mhp_extension/detectron2/tests/structures/test_boxes.py
|
| 550 |
+
mhp_extension/detectron2/tests/structures/test_imagelist.py
|
| 551 |
+
mhp_extension/detectron2/tests/structures/test_instances.py
|
| 552 |
+
mhp_extension/detectron2/tests/structures/test_rotated_boxes.py
|
| 553 |
+
mhp_extension/detectron2/tests/test_checkpoint.py
|
| 554 |
+
mhp_extension/detectron2/tests/test_config.py
|
| 555 |
+
mhp_extension/detectron2/tests/test_export_caffe2.py
|
| 556 |
+
mhp_extension/detectron2/tests/test_model_analysis.py
|
| 557 |
+
mhp_extension/detectron2/tests/test_model_zoo.py
|
| 558 |
+
mhp_extension/detectron2/tests/test_visualizer.py
|
| 559 |
+
mhp_extension/detectron2/tools/analyze_model.py
|
| 560 |
+
mhp_extension/detectron2/tools/benchmark.py
|
| 561 |
+
mhp_extension/detectron2/tools/convert-torchvision-to-d2.py
|
| 562 |
+
mhp_extension/detectron2/tools/deploy/caffe2_converter.py
|
| 563 |
+
mhp_extension/detectron2/tools/deploy/caffe2_mask_rcnn.cpp
|
| 564 |
+
mhp_extension/detectron2/tools/deploy/README.md
|
| 565 |
+
mhp_extension/detectron2/tools/deploy/torchscript_traced_mask_rcnn.cpp
|
| 566 |
+
mhp_extension/detectron2/tools/finetune_net.py
|
| 567 |
+
mhp_extension/detectron2/tools/inference.sh
|
| 568 |
+
mhp_extension/detectron2/tools/plain_train_net.py
|
| 569 |
+
mhp_extension/detectron2/tools/README.md
|
| 570 |
+
mhp_extension/detectron2/tools/run.sh
|
| 571 |
+
mhp_extension/detectron2/tools/train_net.py
|
| 572 |
+
mhp_extension/detectron2/tools/visualize_data.py
|
| 573 |
+
mhp_extension/detectron2/tools/visualize_json_results.py
|
| 574 |
+
mhp_extension/global_local_parsing/global_local_datasets.py
|
| 575 |
+
mhp_extension/global_local_parsing/global_local_evaluate.py
|
| 576 |
+
mhp_extension/global_local_parsing/global_local_train.py
|
| 577 |
+
mhp_extension/global_local_parsing/make_id_list.py
|
| 578 |
+
mhp_extension/logits_fusion.py
|
| 579 |
+
mhp_extension/make_crop_and_mask_w_mask_nms.py
|
| 580 |
+
mhp_extension/README.md
|
| 581 |
+
mhp_extension/scripts/make_coco_style_annotation.sh
|
| 582 |
+
mhp_extension/scripts/make_crop.sh
|
| 583 |
+
mhp_extension/scripts/parsing_fusion.sh
|
| 584 |
+
modules/bn.py
|
| 585 |
+
modules/deeplab.py
|
| 586 |
+
modules/dense.py
|
| 587 |
+
modules/functions.py
|
| 588 |
+
modules/__init__.py
|
| 589 |
+
modules/misc.py
|
| 590 |
+
modules/residual.py
|
| 591 |
+
modules/src/checks.h
|
| 592 |
+
modules/src/inplace_abn.cpp
|
| 593 |
+
modules/src/inplace_abn_cpu.cpp
|
| 594 |
+
modules/src/inplace_abn_cuda.cu
|
| 595 |
+
modules/src/inplace_abn_cuda_half.cu
|
| 596 |
+
modules/src/inplace_abn.h
|
| 597 |
+
modules/src/utils/checks.h
|
| 598 |
+
modules/src/utils/common.h
|
| 599 |
+
modules/src/utils/cuda.cuh
|
| 600 |
+
networks/AugmentCE2P.py
|
| 601 |
+
networks/backbone/mobilenetv2.py
|
| 602 |
+
networks/backbone/resnet.py
|
| 603 |
+
networks/backbone/resnext.py
|
| 604 |
+
networks/context_encoding/aspp.py
|
| 605 |
+
networks/context_encoding/ocnet.py
|
| 606 |
+
networks/context_encoding/psp.py
|
| 607 |
+
networks/__init__.py
|
| 608 |
+
README.md
|
| 609 |
+
requirements.txt
|
| 610 |
+
simple_extractor.py
|
| 611 |
+
training_code/MVANet/README.org
|
| 612 |
+
train.py
|
| 613 |
+
utils/consistency_loss.py
|
| 614 |
+
utils/criterion.py
|
| 615 |
+
utils/encoding.py
|
| 616 |
+
utils/__init__.py
|
| 617 |
+
utils/kl_loss.py
|
| 618 |
+
utils/lovasz_softmax.py
|
| 619 |
+
utils/miou.py
|
| 620 |
+
utils/schp.py
|
| 621 |
+
utils/soft_dice_loss.py
|
| 622 |
+
utils/transforms.py
|
| 623 |
+
utils/warmup_scheduler.py
|
| 624 |
+
MVANet_Inference/README.org
|
| 625 |
+
#+end_src
|
| 626 |
+
|
| 627 |
+
* List of files to remove
|
| 628 |
+
#+begin_src conf :tangle ./rm.txt
|
| 629 |
+
ComfyUI_MVANet/__pycache__/__init__.cpython-310.pyc
|
| 630 |
+
ComfyUI_MVANet/#README.org#
|
| 631 |
+
ComfyUI_MVANet/.#README.org
|
| 632 |
+
ComfyUI_MVANet/README.org~
|
| 633 |
+
ComfyUI_MVANet/.README.org.~undo-tree~
|
| 634 |
+
#main.org#
|
| 635 |
+
.#main.org
|
| 636 |
+
main.org~
|
| 637 |
+
.main.org.~undo-tree~
|
| 638 |
+
.README.md.~undo-tree~
|
| 639 |
+
ComfyUI_MVANet/.#README.org
|
| 640 |
+
ComfyUI_AEMatter/__pycache__/__init__.cpython-310.pyc
|
| 641 |
+
ComfyUI_AEMatter/AEMatter.class.py
|
| 642 |
+
ComfyUI_AEMatter/AEMatter.execute.py
|
| 643 |
+
ComfyUI_AEMatter/AEMatter.function.py
|
| 644 |
+
ComfyUI_AEMatter/AEMatter.import.py
|
| 645 |
+
ComfyUI_MVANet/MVANet_inference.class.py
|
| 646 |
+
ComfyUI_MVANet/MVANet_inference.execute.py
|
| 647 |
+
ComfyUI_MVANet/MVANet_inference.function.py
|
| 648 |
+
ComfyUI_MVANet/MVANet_inference.import.py
|
| 649 |
+
ComfyUI_MVANet/MVANet_inference.unify.sh
|
| 650 |
+
ComfyUI_AEMatter/AEMatter.unify.sh
|
| 651 |
+
git_add.txt
|
| 652 |
+
git_lfs_track.txt
|
| 653 |
+
gitignore.txt
|
| 654 |
+
rm.txt
|
| 655 |
+
work.sh
|
| 656 |
+
#+end_src
|
| 657 |
+
|
| 658 |
+
* List of patterns to ignore
|
| 659 |
+
#+begin_src conf :tangle ./gitignore.txt
|
| 660 |
+
log/
|
| 661 |
+
pretrain_model/
|
| 662 |
+
commit_and_push.sh
|
| 663 |
+
#+end_src
|
mhp_extension/README.md
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Self Correction for Human Parsing
|
| 2 |
+
|
| 3 |
+
We propose a simple yet effective multiple human parsing framework by extending our self-correction network.
|
| 4 |
+
|
| 5 |
+
Here we show an example usage jupyter notebook in [demo.ipynb](./demo.ipynb).
|
| 6 |
+
|
| 7 |
+
## Requirements
|
| 8 |
+
|
| 9 |
+
Please see [INSTALL.md](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md) for further requirements.
|
| 10 |
+
|
| 11 |
+
## Citation
|
| 12 |
+
|
| 13 |
+
Please cite our work if you find this repo useful in your research.
|
| 14 |
+
|
| 15 |
+
```latex
|
| 16 |
+
@article{li2019self,
|
| 17 |
+
title={Self-Correction for Human Parsing},
|
| 18 |
+
author={Li, Peike and Xu, Yunqiu and Wei, Yunchao and Yang, Yi},
|
| 19 |
+
journal={arXiv preprint arXiv:1910.09777},
|
| 20 |
+
year={2019}
|
| 21 |
+
}
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
## Visualization
|
| 25 |
+
|
| 26 |
+
* Source Image.
|
| 27 |
+

|
| 28 |
+
* Instance Human Mask.
|
| 29 |
+

|
| 30 |
+
* Global Human Parsing Result.
|
| 31 |
+

|
| 32 |
+
* Multiple Human Parsing Result.
|
| 33 |
+

|
| 34 |
+
|
| 35 |
+
## Related
|
| 36 |
+
|
| 37 |
+
Our implementation is based on the [Detectron2](https://github.com/facebookresearch/detectron2).
|
| 38 |
+
|
mhp_extension/coco_style_annotation_creator/__pycache__/pycococreatortools.cpython-37.pyc
ADDED
|
Binary file (3.6 kB). View file
|
|
|
mhp_extension/coco_style_annotation_creator/human_to_coco.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import datetime
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
import pycococreatortools
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_arguments():
|
| 12 |
+
parser = argparse.ArgumentParser(description="transform mask annotation to coco annotation")
|
| 13 |
+
parser.add_argument("--dataset", type=str, default='CIHP', help="name of dataset (CIHP, MHPv2 or VIP)")
|
| 14 |
+
parser.add_argument("--json_save_dir", type=str, default='../data/msrcnn_finetune_annotations',
|
| 15 |
+
help="path to save coco-style annotation json file")
|
| 16 |
+
parser.add_argument("--use_val", type=bool, default=False,
|
| 17 |
+
help="use train+val set for finetuning or not")
|
| 18 |
+
parser.add_argument("--train_img_dir", type=str, default='../data/instance-level_human_parsing/Training/Images',
|
| 19 |
+
help="train image path")
|
| 20 |
+
parser.add_argument("--train_anno_dir", type=str,
|
| 21 |
+
default='../data/instance-level_human_parsing/Training/Human_ids',
|
| 22 |
+
help="train human mask path")
|
| 23 |
+
parser.add_argument("--val_img_dir", type=str, default='../data/instance-level_human_parsing/Validation/Images',
|
| 24 |
+
help="val image path")
|
| 25 |
+
parser.add_argument("--val_anno_dir", type=str,
|
| 26 |
+
default='../data/instance-level_human_parsing/Validation/Human_ids',
|
| 27 |
+
help="val human mask path")
|
| 28 |
+
return parser.parse_args()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def main(args):
|
| 32 |
+
INFO = {
|
| 33 |
+
"description": args.split_name + " Dataset",
|
| 34 |
+
"url": "",
|
| 35 |
+
"version": "",
|
| 36 |
+
"year": 2019,
|
| 37 |
+
"contributor": "xyq",
|
| 38 |
+
"date_created": datetime.datetime.utcnow().isoformat(' ')
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
LICENSES = [
|
| 42 |
+
{
|
| 43 |
+
"id": 1,
|
| 44 |
+
"name": "",
|
| 45 |
+
"url": ""
|
| 46 |
+
}
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
CATEGORIES = [
|
| 50 |
+
{
|
| 51 |
+
'id': 1,
|
| 52 |
+
'name': 'person',
|
| 53 |
+
'supercategory': 'person',
|
| 54 |
+
},
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
coco_output = {
|
| 58 |
+
"info": INFO,
|
| 59 |
+
"licenses": LICENSES,
|
| 60 |
+
"categories": CATEGORIES,
|
| 61 |
+
"images": [],
|
| 62 |
+
"annotations": []
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
image_id = 1
|
| 66 |
+
segmentation_id = 1
|
| 67 |
+
|
| 68 |
+
for image_name in os.listdir(args.train_img_dir):
|
| 69 |
+
image = Image.open(os.path.join(args.train_img_dir, image_name))
|
| 70 |
+
image_info = pycococreatortools.create_image_info(
|
| 71 |
+
image_id, image_name, image.size
|
| 72 |
+
)
|
| 73 |
+
coco_output["images"].append(image_info)
|
| 74 |
+
|
| 75 |
+
human_mask_name = os.path.splitext(image_name)[0] + '.png'
|
| 76 |
+
human_mask = np.asarray(Image.open(os.path.join(args.train_anno_dir, human_mask_name)))
|
| 77 |
+
human_gt_labels = np.unique(human_mask)
|
| 78 |
+
|
| 79 |
+
for i in range(1, len(human_gt_labels)):
|
| 80 |
+
category_info = {'id': 1, 'is_crowd': 0}
|
| 81 |
+
binary_mask = np.uint8(human_mask == i)
|
| 82 |
+
annotation_info = pycococreatortools.create_annotation_info(
|
| 83 |
+
segmentation_id, image_id, category_info, binary_mask,
|
| 84 |
+
image.size, tolerance=10
|
| 85 |
+
)
|
| 86 |
+
if annotation_info is not None:
|
| 87 |
+
coco_output["annotations"].append(annotation_info)
|
| 88 |
+
|
| 89 |
+
segmentation_id += 1
|
| 90 |
+
image_id += 1
|
| 91 |
+
|
| 92 |
+
if not os.path.exists(args.json_save_dir):
|
| 93 |
+
os.makedirs(args.json_save_dir)
|
| 94 |
+
if not args.use_val:
|
| 95 |
+
with open('{}/{}_train.json'.format(args.json_save_dir, args.split_name), 'w') as output_json_file:
|
| 96 |
+
json.dump(coco_output, output_json_file)
|
| 97 |
+
else:
|
| 98 |
+
for image_name in os.listdir(args.val_img_dir):
|
| 99 |
+
image = Image.open(os.path.join(args.val_img_dir, image_name))
|
| 100 |
+
image_info = pycococreatortools.create_image_info(
|
| 101 |
+
image_id, image_name, image.size
|
| 102 |
+
)
|
| 103 |
+
coco_output["images"].append(image_info)
|
| 104 |
+
|
| 105 |
+
human_mask_name = os.path.splitext(image_name)[0] + '.png'
|
| 106 |
+
human_mask = np.asarray(Image.open(os.path.join(args.val_anno_dir, human_mask_name)))
|
| 107 |
+
human_gt_labels = np.unique(human_mask)
|
| 108 |
+
|
| 109 |
+
for i in range(1, len(human_gt_labels)):
|
| 110 |
+
category_info = {'id': 1, 'is_crowd': 0}
|
| 111 |
+
binary_mask = np.uint8(human_mask == i)
|
| 112 |
+
annotation_info = pycococreatortools.create_annotation_info(
|
| 113 |
+
segmentation_id, image_id, category_info, binary_mask,
|
| 114 |
+
image.size, tolerance=10
|
| 115 |
+
)
|
| 116 |
+
if annotation_info is not None:
|
| 117 |
+
coco_output["annotations"].append(annotation_info)
|
| 118 |
+
|
| 119 |
+
segmentation_id += 1
|
| 120 |
+
image_id += 1
|
| 121 |
+
|
| 122 |
+
with open('{}/{}_trainval.json'.format(args.json_save_dir, args.split_name), 'w') as output_json_file:
|
| 123 |
+
json.dump(coco_output, output_json_file)
|
| 124 |
+
|
| 125 |
+
coco_output_val = {
|
| 126 |
+
"info": INFO,
|
| 127 |
+
"licenses": LICENSES,
|
| 128 |
+
"categories": CATEGORIES,
|
| 129 |
+
"images": [],
|
| 130 |
+
"annotations": []
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
image_id_val = 1
|
| 134 |
+
segmentation_id_val = 1
|
| 135 |
+
|
| 136 |
+
for image_name in os.listdir(args.val_img_dir):
|
| 137 |
+
image = Image.open(os.path.join(args.val_img_dir, image_name))
|
| 138 |
+
image_info = pycococreatortools.create_image_info(
|
| 139 |
+
image_id_val, image_name, image.size
|
| 140 |
+
)
|
| 141 |
+
coco_output_val["images"].append(image_info)
|
| 142 |
+
|
| 143 |
+
human_mask_name = os.path.splitext(image_name)[0] + '.png'
|
| 144 |
+
human_mask = np.asarray(Image.open(os.path.join(args.val_anno_dir, human_mask_name)))
|
| 145 |
+
human_gt_labels = np.unique(human_mask)
|
| 146 |
+
|
| 147 |
+
for i in range(1, len(human_gt_labels)):
|
| 148 |
+
category_info = {'id': 1, 'is_crowd': 0}
|
| 149 |
+
binary_mask = np.uint8(human_mask == i)
|
| 150 |
+
annotation_info = pycococreatortools.create_annotation_info(
|
| 151 |
+
segmentation_id_val, image_id_val, category_info, binary_mask,
|
| 152 |
+
image.size, tolerance=10
|
| 153 |
+
)
|
| 154 |
+
if annotation_info is not None:
|
| 155 |
+
coco_output_val["annotations"].append(annotation_info)
|
| 156 |
+
|
| 157 |
+
segmentation_id_val += 1
|
| 158 |
+
image_id_val += 1
|
| 159 |
+
|
| 160 |
+
with open('{}/{}_val.json'.format(args.json_save_dir, args.split_name), 'w') as output_json_file_val:
|
| 161 |
+
json.dump(coco_output_val, output_json_file_val)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
if __name__ == "__main__":
|
| 165 |
+
args = get_arguments()
|
| 166 |
+
main(args)
|
mhp_extension/coco_style_annotation_creator/pycococreatortools.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import datetime
|
| 3 |
+
import numpy as np
|
| 4 |
+
from itertools import groupby
|
| 5 |
+
from skimage import measure
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from pycocotools import mask
|
| 8 |
+
|
| 9 |
+
convert = lambda text: int(text) if text.isdigit() else text.lower()
|
| 10 |
+
natrual_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def resize_binary_mask(array, new_size):
|
| 14 |
+
image = Image.fromarray(array.astype(np.uint8) * 255)
|
| 15 |
+
image = image.resize(new_size)
|
| 16 |
+
return np.asarray(image).astype(np.bool_)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def close_contour(contour):
|
| 20 |
+
if not np.array_equal(contour[0], contour[-1]):
|
| 21 |
+
contour = np.vstack((contour, contour[0]))
|
| 22 |
+
return contour
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def binary_mask_to_rle(binary_mask):
|
| 26 |
+
rle = {'counts': [], 'size': list(binary_mask.shape)}
|
| 27 |
+
counts = rle.get('counts')
|
| 28 |
+
for i, (value, elements) in enumerate(groupby(binary_mask.ravel(order='F'))):
|
| 29 |
+
if i == 0 and value == 1:
|
| 30 |
+
counts.append(0)
|
| 31 |
+
counts.append(len(list(elements)))
|
| 32 |
+
|
| 33 |
+
return rle
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def binary_mask_to_polygon(binary_mask, tolerance=0):
|
| 37 |
+
"""Converts a binary mask to COCO polygon representation
|
| 38 |
+
Args:
|
| 39 |
+
binary_mask: a 2D binary numpy array where '1's represent the object
|
| 40 |
+
tolerance: Maximum distance from original points of polygon to approximated
|
| 41 |
+
polygonal chain. If tolerance is 0, the original coordinate array is returned.
|
| 42 |
+
"""
|
| 43 |
+
polygons = []
|
| 44 |
+
# pad mask to close contours of shapes which start and end at an edge
|
| 45 |
+
padded_binary_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)
|
| 46 |
+
contours = measure.find_contours(padded_binary_mask, 0.5)
|
| 47 |
+
contours = np.subtract(contours, 1)
|
| 48 |
+
for contour in contours:
|
| 49 |
+
contour = close_contour(contour)
|
| 50 |
+
contour = measure.approximate_polygon(contour, tolerance)
|
| 51 |
+
if len(contour) < 3:
|
| 52 |
+
continue
|
| 53 |
+
contour = np.flip(contour, axis=1)
|
| 54 |
+
segmentation = contour.ravel().tolist()
|
| 55 |
+
# after padding and subtracting 1 we may get -0.5 points in our segmentation
|
| 56 |
+
segmentation = [0 if i < 0 else i for i in segmentation]
|
| 57 |
+
polygons.append(segmentation)
|
| 58 |
+
|
| 59 |
+
return polygons
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def create_image_info(image_id, file_name, image_size,
|
| 63 |
+
date_captured=datetime.datetime.utcnow().isoformat(' '),
|
| 64 |
+
license_id=1, coco_url="", flickr_url=""):
|
| 65 |
+
image_info = {
|
| 66 |
+
"id": image_id,
|
| 67 |
+
"file_name": file_name,
|
| 68 |
+
"width": image_size[0],
|
| 69 |
+
"height": image_size[1],
|
| 70 |
+
"date_captured": date_captured,
|
| 71 |
+
"license": license_id,
|
| 72 |
+
"coco_url": coco_url,
|
| 73 |
+
"flickr_url": flickr_url
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
return image_info
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def create_annotation_info(annotation_id, image_id, category_info, binary_mask,
|
| 80 |
+
image_size=None, tolerance=2, bounding_box=None):
|
| 81 |
+
if image_size is not None:
|
| 82 |
+
binary_mask = resize_binary_mask(binary_mask, image_size)
|
| 83 |
+
|
| 84 |
+
binary_mask_encoded = mask.encode(np.asfortranarray(binary_mask.astype(np.uint8)))
|
| 85 |
+
|
| 86 |
+
area = mask.area(binary_mask_encoded)
|
| 87 |
+
if area < 1:
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
if bounding_box is None:
|
| 91 |
+
bounding_box = mask.toBbox(binary_mask_encoded)
|
| 92 |
+
|
| 93 |
+
if category_info["is_crowd"]:
|
| 94 |
+
is_crowd = 1
|
| 95 |
+
segmentation = binary_mask_to_rle(binary_mask)
|
| 96 |
+
else:
|
| 97 |
+
is_crowd = 0
|
| 98 |
+
segmentation = binary_mask_to_polygon(binary_mask, tolerance)
|
| 99 |
+
if not segmentation:
|
| 100 |
+
return None
|
| 101 |
+
|
| 102 |
+
annotation_info = {
|
| 103 |
+
"id": annotation_id,
|
| 104 |
+
"image_id": image_id,
|
| 105 |
+
"category_id": category_info["id"],
|
| 106 |
+
"iscrowd": is_crowd,
|
| 107 |
+
"area": area.tolist(),
|
| 108 |
+
"bbox": bounding_box.tolist(),
|
| 109 |
+
"segmentation": segmentation,
|
| 110 |
+
"width": binary_mask.shape[1],
|
| 111 |
+
"height": binary_mask.shape[0],
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
return annotation_info
|
mhp_extension/coco_style_annotation_creator/test_human2coco_format.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import datetime
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
import pycococreatortools
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_arguments():
|
| 11 |
+
parser = argparse.ArgumentParser(description="transform mask annotation to coco annotation")
|
| 12 |
+
parser.add_argument("--dataset", type=str, default='CIHP', help="name of dataset (CIHP, MHPv2 or VIP)")
|
| 13 |
+
parser.add_argument("--json_save_dir", type=str, default='../data/CIHP/annotations',
|
| 14 |
+
help="path to save coco-style annotation json file")
|
| 15 |
+
parser.add_argument("--test_img_dir", type=str, default='../data/CIHP/Testing/Images',
|
| 16 |
+
help="test image path")
|
| 17 |
+
return parser.parse_args()
|
| 18 |
+
|
| 19 |
+
args = get_arguments()
|
| 20 |
+
|
| 21 |
+
INFO = {
|
| 22 |
+
"description": args.dataset + "Dataset",
|
| 23 |
+
"url": "",
|
| 24 |
+
"version": "",
|
| 25 |
+
"year": 2020,
|
| 26 |
+
"contributor": "yunqiuxu",
|
| 27 |
+
"date_created": datetime.datetime.utcnow().isoformat(' ')
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
LICENSES = [
|
| 31 |
+
{
|
| 32 |
+
"id": 1,
|
| 33 |
+
"name": "",
|
| 34 |
+
"url": ""
|
| 35 |
+
}
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
CATEGORIES = [
|
| 39 |
+
{
|
| 40 |
+
'id': 1,
|
| 41 |
+
'name': 'person',
|
| 42 |
+
'supercategory': 'person',
|
| 43 |
+
},
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def main(args):
|
| 48 |
+
coco_output = {
|
| 49 |
+
"info": INFO,
|
| 50 |
+
"licenses": LICENSES,
|
| 51 |
+
"categories": CATEGORIES,
|
| 52 |
+
"images": [],
|
| 53 |
+
"annotations": []
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
image_id = 1
|
| 57 |
+
|
| 58 |
+
for image_name in os.listdir(args.test_img_dir):
|
| 59 |
+
image = Image.open(os.path.join(args.test_img_dir, image_name))
|
| 60 |
+
image_info = pycococreatortools.create_image_info(
|
| 61 |
+
image_id, image_name, image.size
|
| 62 |
+
)
|
| 63 |
+
coco_output["images"].append(image_info)
|
| 64 |
+
image_id += 1
|
| 65 |
+
|
| 66 |
+
if not os.path.exists(os.path.join(args.json_save_dir)):
|
| 67 |
+
os.mkdir(os.path.join(args.json_save_dir))
|
| 68 |
+
|
| 69 |
+
with open('{}/{}.json'.format(args.json_save_dir, args.dataset), 'w') as output_json_file:
|
| 70 |
+
json.dump(coco_output, output_json_file)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
if __name__ == "__main__":
|
| 74 |
+
main(args)
|
mhp_extension/demo.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
mhp_extension/detectron2/.circleci/config.yml
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python CircleCI 2.0 configuration file
|
| 2 |
+
#
|
| 3 |
+
# Check https://circleci.com/docs/2.0/language-python/ for more details
|
| 4 |
+
#
|
| 5 |
+
version: 2
|
| 6 |
+
|
| 7 |
+
# -------------------------------------------------------------------------------------
|
| 8 |
+
# Environments to run the jobs in
|
| 9 |
+
# -------------------------------------------------------------------------------------
|
| 10 |
+
cpu: &cpu
|
| 11 |
+
docker:
|
| 12 |
+
- image: circleci/python:3.6.8-stretch
|
| 13 |
+
resource_class: medium
|
| 14 |
+
|
| 15 |
+
gpu: &gpu
|
| 16 |
+
machine:
|
| 17 |
+
image: ubuntu-1604:201903-01
|
| 18 |
+
docker_layer_caching: true
|
| 19 |
+
resource_class: gpu.small
|
| 20 |
+
|
| 21 |
+
# -------------------------------------------------------------------------------------
|
| 22 |
+
# Re-usable commands
|
| 23 |
+
# -------------------------------------------------------------------------------------
|
| 24 |
+
install_python: &install_python
|
| 25 |
+
- run:
|
| 26 |
+
name: Install Python
|
| 27 |
+
working_directory: ~/
|
| 28 |
+
command: |
|
| 29 |
+
pyenv install 3.6.1
|
| 30 |
+
pyenv global 3.6.1
|
| 31 |
+
|
| 32 |
+
setup_venv: &setup_venv
|
| 33 |
+
- run:
|
| 34 |
+
name: Setup Virtual Env
|
| 35 |
+
working_directory: ~/
|
| 36 |
+
command: |
|
| 37 |
+
python -m venv ~/venv
|
| 38 |
+
echo ". ~/venv/bin/activate" >> $BASH_ENV
|
| 39 |
+
. ~/venv/bin/activate
|
| 40 |
+
python --version
|
| 41 |
+
which python
|
| 42 |
+
which pip
|
| 43 |
+
pip install --upgrade pip
|
| 44 |
+
|
| 45 |
+
install_dep: &install_dep
|
| 46 |
+
- run:
|
| 47 |
+
name: Install Dependencies
|
| 48 |
+
command: |
|
| 49 |
+
pip install --progress-bar off -U 'git+https://github.com/facebookresearch/fvcore'
|
| 50 |
+
pip install --progress-bar off cython opencv-python
|
| 51 |
+
pip install --progress-bar off 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
|
| 52 |
+
pip install --progress-bar off torch torchvision
|
| 53 |
+
|
| 54 |
+
install_detectron2: &install_detectron2
|
| 55 |
+
- run:
|
| 56 |
+
name: Install Detectron2
|
| 57 |
+
command: |
|
| 58 |
+
gcc --version
|
| 59 |
+
pip install -U --progress-bar off -e .[dev]
|
| 60 |
+
python -m detectron2.utils.collect_env
|
| 61 |
+
|
| 62 |
+
install_nvidia_driver: &install_nvidia_driver
|
| 63 |
+
- run:
|
| 64 |
+
name: Install nvidia driver
|
| 65 |
+
working_directory: ~/
|
| 66 |
+
command: |
|
| 67 |
+
wget -q 'https://s3.amazonaws.com/ossci-linux/nvidia_driver/NVIDIA-Linux-x86_64-430.40.run'
|
| 68 |
+
sudo /bin/bash ./NVIDIA-Linux-x86_64-430.40.run -s --no-drm
|
| 69 |
+
nvidia-smi
|
| 70 |
+
|
| 71 |
+
run_unittests: &run_unittests
|
| 72 |
+
- run:
|
| 73 |
+
name: Run Unit Tests
|
| 74 |
+
command: |
|
| 75 |
+
python -m unittest discover -v -s tests
|
| 76 |
+
|
| 77 |
+
# -------------------------------------------------------------------------------------
|
| 78 |
+
# Jobs to run
|
| 79 |
+
# -------------------------------------------------------------------------------------
|
| 80 |
+
jobs:
|
| 81 |
+
cpu_tests:
|
| 82 |
+
<<: *cpu
|
| 83 |
+
|
| 84 |
+
working_directory: ~/detectron2
|
| 85 |
+
|
| 86 |
+
steps:
|
| 87 |
+
- checkout
|
| 88 |
+
- <<: *setup_venv
|
| 89 |
+
|
| 90 |
+
# Cache the venv directory that contains dependencies
|
| 91 |
+
- restore_cache:
|
| 92 |
+
keys:
|
| 93 |
+
- cache-key-{{ .Branch }}-ID-20200425
|
| 94 |
+
|
| 95 |
+
- <<: *install_dep
|
| 96 |
+
|
| 97 |
+
- save_cache:
|
| 98 |
+
paths:
|
| 99 |
+
- ~/venv
|
| 100 |
+
key: cache-key-{{ .Branch }}-ID-20200425
|
| 101 |
+
|
| 102 |
+
- <<: *install_detectron2
|
| 103 |
+
|
| 104 |
+
- run:
|
| 105 |
+
name: isort
|
| 106 |
+
command: |
|
| 107 |
+
isort -c -sp .
|
| 108 |
+
- run:
|
| 109 |
+
name: black
|
| 110 |
+
command: |
|
| 111 |
+
black --check -l 100 .
|
| 112 |
+
- run:
|
| 113 |
+
name: flake8
|
| 114 |
+
command: |
|
| 115 |
+
flake8 .
|
| 116 |
+
|
| 117 |
+
- <<: *run_unittests
|
| 118 |
+
|
| 119 |
+
gpu_tests:
|
| 120 |
+
<<: *gpu
|
| 121 |
+
|
| 122 |
+
working_directory: ~/detectron2
|
| 123 |
+
|
| 124 |
+
steps:
|
| 125 |
+
- checkout
|
| 126 |
+
- <<: *install_nvidia_driver
|
| 127 |
+
|
| 128 |
+
- run:
|
| 129 |
+
name: Install nvidia-docker
|
| 130 |
+
working_directory: ~/
|
| 131 |
+
command: |
|
| 132 |
+
curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add -
|
| 133 |
+
distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
|
| 134 |
+
curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | \
|
| 135 |
+
sudo tee /etc/apt/sources.list.d/nvidia-docker.list
|
| 136 |
+
sudo apt-get update && sudo apt-get install -y nvidia-docker2
|
| 137 |
+
# reload the docker daemon configuration
|
| 138 |
+
sudo pkill -SIGHUP dockerd
|
| 139 |
+
|
| 140 |
+
- run:
|
| 141 |
+
name: Launch docker
|
| 142 |
+
working_directory: ~/detectron2/docker
|
| 143 |
+
command: |
|
| 144 |
+
nvidia-docker build -t detectron2:v0 -f Dockerfile-circleci .
|
| 145 |
+
nvidia-docker run -itd --name d2 detectron2:v0
|
| 146 |
+
docker exec -it d2 nvidia-smi
|
| 147 |
+
|
| 148 |
+
- run:
|
| 149 |
+
name: Build Detectron2
|
| 150 |
+
command: |
|
| 151 |
+
docker exec -it d2 pip install 'git+https://github.com/facebookresearch/fvcore'
|
| 152 |
+
docker cp ~/detectron2 d2:/detectron2
|
| 153 |
+
# This will build d2 for the target GPU arch only
|
| 154 |
+
docker exec -it d2 pip install -e /detectron2
|
| 155 |
+
docker exec -it d2 python3 -m detectron2.utils.collect_env
|
| 156 |
+
docker exec -it d2 python3 -c 'import torch; assert(torch.cuda.is_available())'
|
| 157 |
+
|
| 158 |
+
- run:
|
| 159 |
+
name: Run Unit Tests
|
| 160 |
+
command: |
|
| 161 |
+
docker exec -e CIRCLECI=true -it d2 python3 -m unittest discover -v -s /detectron2/tests
|
| 162 |
+
|
| 163 |
+
workflows:
|
| 164 |
+
version: 2
|
| 165 |
+
regular_test:
|
| 166 |
+
jobs:
|
| 167 |
+
- cpu_tests
|
| 168 |
+
- gpu_tests
|
| 169 |
+
|
| 170 |
+
#nightly_test:
|
| 171 |
+
#jobs:
|
| 172 |
+
#- gpu_tests
|
| 173 |
+
#triggers:
|
| 174 |
+
#- schedule:
|
| 175 |
+
#cron: "0 0 * * *"
|
| 176 |
+
#filters:
|
| 177 |
+
#branches:
|
| 178 |
+
#only:
|
| 179 |
+
#- master
|
mhp_extension/detectron2/.clang-format
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
AccessModifierOffset: -1
|
| 2 |
+
AlignAfterOpenBracket: AlwaysBreak
|
| 3 |
+
AlignConsecutiveAssignments: false
|
| 4 |
+
AlignConsecutiveDeclarations: false
|
| 5 |
+
AlignEscapedNewlinesLeft: true
|
| 6 |
+
AlignOperands: false
|
| 7 |
+
AlignTrailingComments: false
|
| 8 |
+
AllowAllParametersOfDeclarationOnNextLine: false
|
| 9 |
+
AllowShortBlocksOnASingleLine: false
|
| 10 |
+
AllowShortCaseLabelsOnASingleLine: false
|
| 11 |
+
AllowShortFunctionsOnASingleLine: Empty
|
| 12 |
+
AllowShortIfStatementsOnASingleLine: false
|
| 13 |
+
AllowShortLoopsOnASingleLine: false
|
| 14 |
+
AlwaysBreakAfterReturnType: None
|
| 15 |
+
AlwaysBreakBeforeMultilineStrings: true
|
| 16 |
+
AlwaysBreakTemplateDeclarations: true
|
| 17 |
+
BinPackArguments: false
|
| 18 |
+
BinPackParameters: false
|
| 19 |
+
BraceWrapping:
|
| 20 |
+
AfterClass: false
|
| 21 |
+
AfterControlStatement: false
|
| 22 |
+
AfterEnum: false
|
| 23 |
+
AfterFunction: false
|
| 24 |
+
AfterNamespace: false
|
| 25 |
+
AfterObjCDeclaration: false
|
| 26 |
+
AfterStruct: false
|
| 27 |
+
AfterUnion: false
|
| 28 |
+
BeforeCatch: false
|
| 29 |
+
BeforeElse: false
|
| 30 |
+
IndentBraces: false
|
| 31 |
+
BreakBeforeBinaryOperators: None
|
| 32 |
+
BreakBeforeBraces: Attach
|
| 33 |
+
BreakBeforeTernaryOperators: true
|
| 34 |
+
BreakConstructorInitializersBeforeComma: false
|
| 35 |
+
BreakAfterJavaFieldAnnotations: false
|
| 36 |
+
BreakStringLiterals: false
|
| 37 |
+
ColumnLimit: 80
|
| 38 |
+
CommentPragmas: '^ IWYU pragma:'
|
| 39 |
+
ConstructorInitializerAllOnOneLineOrOnePerLine: true
|
| 40 |
+
ConstructorInitializerIndentWidth: 4
|
| 41 |
+
ContinuationIndentWidth: 4
|
| 42 |
+
Cpp11BracedListStyle: true
|
| 43 |
+
DerivePointerAlignment: false
|
| 44 |
+
DisableFormat: false
|
| 45 |
+
ForEachMacros: [ FOR_EACH, FOR_EACH_ENUMERATE, FOR_EACH_KV, FOR_EACH_R, FOR_EACH_RANGE, ]
|
| 46 |
+
IncludeCategories:
|
| 47 |
+
- Regex: '^<.*\.h(pp)?>'
|
| 48 |
+
Priority: 1
|
| 49 |
+
- Regex: '^<.*'
|
| 50 |
+
Priority: 2
|
| 51 |
+
- Regex: '.*'
|
| 52 |
+
Priority: 3
|
| 53 |
+
IndentCaseLabels: true
|
| 54 |
+
IndentWidth: 2
|
| 55 |
+
IndentWrappedFunctionNames: false
|
| 56 |
+
KeepEmptyLinesAtTheStartOfBlocks: false
|
| 57 |
+
MacroBlockBegin: ''
|
| 58 |
+
MacroBlockEnd: ''
|
| 59 |
+
MaxEmptyLinesToKeep: 1
|
| 60 |
+
NamespaceIndentation: None
|
| 61 |
+
ObjCBlockIndentWidth: 2
|
| 62 |
+
ObjCSpaceAfterProperty: false
|
| 63 |
+
ObjCSpaceBeforeProtocolList: false
|
| 64 |
+
PenaltyBreakBeforeFirstCallParameter: 1
|
| 65 |
+
PenaltyBreakComment: 300
|
| 66 |
+
PenaltyBreakFirstLessLess: 120
|
| 67 |
+
PenaltyBreakString: 1000
|
| 68 |
+
PenaltyExcessCharacter: 1000000
|
| 69 |
+
PenaltyReturnTypeOnItsOwnLine: 200
|
| 70 |
+
PointerAlignment: Left
|
| 71 |
+
ReflowComments: true
|
| 72 |
+
SortIncludes: true
|
| 73 |
+
SpaceAfterCStyleCast: false
|
| 74 |
+
SpaceBeforeAssignmentOperators: true
|
| 75 |
+
SpaceBeforeParens: ControlStatements
|
| 76 |
+
SpaceInEmptyParentheses: false
|
| 77 |
+
SpacesBeforeTrailingComments: 1
|
| 78 |
+
SpacesInAngles: false
|
| 79 |
+
SpacesInContainerLiterals: true
|
| 80 |
+
SpacesInCStyleCastParentheses: false
|
| 81 |
+
SpacesInParentheses: false
|
| 82 |
+
SpacesInSquareBrackets: false
|
| 83 |
+
Standard: Cpp11
|
| 84 |
+
TabWidth: 8
|
| 85 |
+
UseTab: Never
|
mhp_extension/detectron2/.flake8
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This is an example .flake8 config, used when developing *Black* itself.
|
| 2 |
+
# Keep in sync with setup.cfg which is used for source packages.
|
| 3 |
+
|
| 4 |
+
[flake8]
|
| 5 |
+
ignore = W503, E203, E221, C901, C408, E741
|
| 6 |
+
max-line-length = 100
|
| 7 |
+
max-complexity = 18
|
| 8 |
+
select = B,C,E,F,W,T4,B9
|
| 9 |
+
exclude = build,__init__.py
|
mhp_extension/detectron2/.gitignore
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# output dir
|
| 2 |
+
output
|
| 3 |
+
instant_test_output
|
| 4 |
+
inference_test_output
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
*.jpg
|
| 8 |
+
*.png
|
| 9 |
+
*.txt
|
| 10 |
+
*.json
|
| 11 |
+
*.diff
|
| 12 |
+
|
| 13 |
+
# compilation and distribution
|
| 14 |
+
__pycache__
|
| 15 |
+
_ext
|
| 16 |
+
*.pyc
|
| 17 |
+
*.so
|
| 18 |
+
detectron2.egg-info/
|
| 19 |
+
build/
|
| 20 |
+
dist/
|
| 21 |
+
wheels/
|
| 22 |
+
|
| 23 |
+
# pytorch/python/numpy formats
|
| 24 |
+
*.pth
|
| 25 |
+
*.pkl
|
| 26 |
+
*.npy
|
| 27 |
+
|
| 28 |
+
# ipython/jupyter notebooks
|
| 29 |
+
*.ipynb
|
| 30 |
+
**/.ipynb_checkpoints/
|
| 31 |
+
|
| 32 |
+
# Editor temporaries
|
| 33 |
+
*.swn
|
| 34 |
+
*.swo
|
| 35 |
+
*.swp
|
| 36 |
+
*~
|
| 37 |
+
|
| 38 |
+
# editor settings
|
| 39 |
+
.idea
|
| 40 |
+
.vscode
|
| 41 |
+
|
| 42 |
+
# project dirs
|
| 43 |
+
/detectron2/model_zoo/configs
|
| 44 |
+
/datasets
|
| 45 |
+
/projects/*/datasets
|
| 46 |
+
/models
|
mhp_extension/detectron2/GETTING_STARTED.md
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Getting Started with Detectron2
|
| 2 |
+
|
| 3 |
+
This document provides a brief intro of the usage of builtin command-line tools in detectron2.
|
| 4 |
+
|
| 5 |
+
For a tutorial that involves actual coding with the API,
|
| 6 |
+
see our [Colab Notebook](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5)
|
| 7 |
+
which covers how to run inference with an
|
| 8 |
+
existing model, and how to train a builtin model on a custom dataset.
|
| 9 |
+
|
| 10 |
+
For more advanced tutorials, refer to our [documentation](https://detectron2.readthedocs.io/tutorials/extend.html).
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
### Inference Demo with Pre-trained Models
|
| 14 |
+
|
| 15 |
+
1. Pick a model and its config file from
|
| 16 |
+
[model zoo](MODEL_ZOO.md),
|
| 17 |
+
for example, `mask_rcnn_R_50_FPN_3x.yaml`.
|
| 18 |
+
2. We provide `demo.py` that is able to run builtin standard models. Run it with:
|
| 19 |
+
```
|
| 20 |
+
cd demo/
|
| 21 |
+
python demo.py --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \
|
| 22 |
+
--input input1.jpg input2.jpg \
|
| 23 |
+
[--other-options]
|
| 24 |
+
--opts MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl
|
| 25 |
+
```
|
| 26 |
+
The configs are made for training, therefore we need to specify `MODEL.WEIGHTS` to a model from model zoo for evaluation.
|
| 27 |
+
This command will run the inference and show visualizations in an OpenCV window.
|
| 28 |
+
|
| 29 |
+
For details of the command line arguments, see `demo.py -h` or look at its source code
|
| 30 |
+
to understand its behavior. Some common arguments are:
|
| 31 |
+
* To run __on your webcam__, replace `--input files` with `--webcam`.
|
| 32 |
+
* To run __on a video__, replace `--input files` with `--video-input video.mp4`.
|
| 33 |
+
* To run __on cpu__, add `MODEL.DEVICE cpu` after `--opts`.
|
| 34 |
+
* To save outputs to a directory (for images) or a file (for webcam or video), use `--output`.
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
### Training & Evaluation in Command Line
|
| 38 |
+
|
| 39 |
+
We provide a script in "tools/{,plain_}train_net.py", that is made to train
|
| 40 |
+
all the configs provided in detectron2.
|
| 41 |
+
You may want to use it as a reference to write your own training script.
|
| 42 |
+
|
| 43 |
+
To train a model with "train_net.py", first
|
| 44 |
+
setup the corresponding datasets following
|
| 45 |
+
[datasets/README.md](./datasets/README.md),
|
| 46 |
+
then run:
|
| 47 |
+
```
|
| 48 |
+
cd tools/
|
| 49 |
+
./train_net.py --num-gpus 8 \
|
| 50 |
+
--config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
The configs are made for 8-GPU training.
|
| 54 |
+
To train on 1 GPU, you may need to [change some parameters](https://arxiv.org/abs/1706.02677), e.g.:
|
| 55 |
+
```
|
| 56 |
+
./train_net.py \
|
| 57 |
+
--config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
|
| 58 |
+
--num-gpus 1 SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
For most models, CPU training is not supported.
|
| 62 |
+
|
| 63 |
+
To evaluate a model's performance, use
|
| 64 |
+
```
|
| 65 |
+
./train_net.py \
|
| 66 |
+
--config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
|
| 67 |
+
--eval-only MODEL.WEIGHTS /path/to/checkpoint_file
|
| 68 |
+
```
|
| 69 |
+
For more options, see `./train_net.py -h`.
|
| 70 |
+
|
| 71 |
+
### Use Detectron2 APIs in Your Code
|
| 72 |
+
|
| 73 |
+
See our [Colab Notebook](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5)
|
| 74 |
+
to learn how to use detectron2 APIs to:
|
| 75 |
+
1. run inference with an existing model
|
| 76 |
+
2. train a builtin model on a custom dataset
|
| 77 |
+
|
| 78 |
+
See [detectron2/projects](https://github.com/facebookresearch/detectron2/tree/master/projects)
|
| 79 |
+
for more ways to build your project on detectron2.
|
mhp_extension/detectron2/INSTALL.md
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Installation
|
| 2 |
+
|
| 3 |
+
Our [Colab Notebook](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5)
|
| 4 |
+
has step-by-step instructions that install detectron2.
|
| 5 |
+
The [Dockerfile](docker)
|
| 6 |
+
also installs detectron2 with a few simple commands.
|
| 7 |
+
|
| 8 |
+
### Requirements
|
| 9 |
+
- Linux or macOS with Python ≥ 3.6
|
| 10 |
+
- PyTorch ≥ 1.4
|
| 11 |
+
- [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation.
|
| 12 |
+
You can install them together at [pytorch.org](https://pytorch.org) to make sure of this.
|
| 13 |
+
- OpenCV, optional, needed by demo and visualization
|
| 14 |
+
- pycocotools: `pip install cython; pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'`
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
### Build Detectron2 from Source
|
| 18 |
+
|
| 19 |
+
gcc & g++ ≥ 5 are required. [ninja](https://ninja-build.org/) is recommended for faster build.
|
| 20 |
+
After having them, run:
|
| 21 |
+
```
|
| 22 |
+
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
|
| 23 |
+
# (add --user if you don't have permission)
|
| 24 |
+
|
| 25 |
+
# Or, to install it from a local clone:
|
| 26 |
+
git clone https://github.com/facebookresearch/detectron2.git
|
| 27 |
+
python -m pip install -e detectron2
|
| 28 |
+
|
| 29 |
+
# Or if you are on macOS
|
| 30 |
+
# CC=clang CXX=clang++ python -m pip install -e .
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
To __rebuild__ detectron2 that's built from a local clone, use `rm -rf build/ **/*.so` to clean the
|
| 34 |
+
old build first. You often need to rebuild detectron2 after reinstalling PyTorch.
|
| 35 |
+
|
| 36 |
+
### Install Pre-Built Detectron2 (Linux only)
|
| 37 |
+
```
|
| 38 |
+
# for CUDA 10.1:
|
| 39 |
+
python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/index.html
|
| 40 |
+
```
|
| 41 |
+
You can replace cu101 with "cu{100,92}" or "cpu".
|
| 42 |
+
|
| 43 |
+
Note that:
|
| 44 |
+
1. Such installation has to be used with certain version of official PyTorch release.
|
| 45 |
+
See [releases](https://github.com/facebookresearch/detectron2/releases) for requirements.
|
| 46 |
+
It will not work with a different version of PyTorch or a non-official build of PyTorch.
|
| 47 |
+
2. Such installation is out-of-date w.r.t. master branch of detectron2. It may not be
|
| 48 |
+
compatible with the master branch of a research project that uses detectron2 (e.g. those in
|
| 49 |
+
[projects](projects) or [meshrcnn](https://github.com/facebookresearch/meshrcnn/)).
|
| 50 |
+
|
| 51 |
+
### Common Installation Issues
|
| 52 |
+
|
| 53 |
+
If you met issues using the pre-built detectron2, please uninstall it and try building it from source.
|
| 54 |
+
|
| 55 |
+
Click each issue for its solutions:
|
| 56 |
+
|
| 57 |
+
<details>
|
| 58 |
+
<summary>
|
| 59 |
+
Undefined torch/aten/caffe2 symbols, or segmentation fault immediately when running the library.
|
| 60 |
+
</summary>
|
| 61 |
+
<br/>
|
| 62 |
+
|
| 63 |
+
This usually happens when detectron2 or torchvision is not
|
| 64 |
+
compiled with the version of PyTorch you're running.
|
| 65 |
+
|
| 66 |
+
Pre-built torchvision or detectron2 has to work with the corresponding official release of pytorch.
|
| 67 |
+
If the error comes from a pre-built torchvision, uninstall torchvision and pytorch and reinstall them
|
| 68 |
+
following [pytorch.org](http://pytorch.org). So the versions will match.
|
| 69 |
+
|
| 70 |
+
If the error comes from a pre-built detectron2, check [release notes](https://github.com/facebookresearch/detectron2/releases)
|
| 71 |
+
to see the corresponding pytorch version required for each pre-built detectron2.
|
| 72 |
+
|
| 73 |
+
If the error comes from detectron2 or torchvision that you built manually from source,
|
| 74 |
+
remove files you built (`build/`, `**/*.so`) and rebuild it so it can pick up the version of pytorch currently in your environment.
|
| 75 |
+
|
| 76 |
+
If you cannot resolve this problem, please include the output of `gdb -ex "r" -ex "bt" -ex "quit" --args python -m detectron2.utils.collect_env`
|
| 77 |
+
in your issue.
|
| 78 |
+
</details>
|
| 79 |
+
|
| 80 |
+
<details>
|
| 81 |
+
<summary>
|
| 82 |
+
Undefined C++ symbols (e.g. `GLIBCXX`) or C++ symbols not found.
|
| 83 |
+
</summary>
|
| 84 |
+
<br/>
|
| 85 |
+
Usually it's because the library is compiled with a newer C++ compiler but run with an old C++ runtime.
|
| 86 |
+
|
| 87 |
+
This often happens with old anaconda.
|
| 88 |
+
Try `conda update libgcc`. Then rebuild detectron2.
|
| 89 |
+
|
| 90 |
+
The fundamental solution is to run the code with proper C++ runtime.
|
| 91 |
+
One way is to use `LD_PRELOAD=/path/to/libstdc++.so`.
|
| 92 |
+
|
| 93 |
+
</details>
|
| 94 |
+
|
| 95 |
+
<details>
|
| 96 |
+
<summary>
|
| 97 |
+
"Not compiled with GPU support" or "Detectron2 CUDA Compiler: not available".
|
| 98 |
+
</summary>
|
| 99 |
+
<br/>
|
| 100 |
+
CUDA is not found when building detectron2.
|
| 101 |
+
You should make sure
|
| 102 |
+
|
| 103 |
+
```
|
| 104 |
+
python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(torch.cuda.is_available(), CUDA_HOME)'
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
print valid outputs at the time you build detectron2.
|
| 108 |
+
|
| 109 |
+
Most models can run inference (but not training) without GPU support. To use CPUs, set `MODEL.DEVICE='cpu'` in the config.
|
| 110 |
+
</details>
|
| 111 |
+
|
| 112 |
+
<details>
|
| 113 |
+
<summary>
|
| 114 |
+
"invalid device function" or "no kernel image is available for execution".
|
| 115 |
+
</summary>
|
| 116 |
+
<br/>
|
| 117 |
+
Two possibilities:
|
| 118 |
+
|
| 119 |
+
* You build detectron2 with one version of CUDA but run it with a different version.
|
| 120 |
+
|
| 121 |
+
To check whether it is the case,
|
| 122 |
+
use `python -m detectron2.utils.collect_env` to find out inconsistent CUDA versions.
|
| 123 |
+
In the output of this command, you should expect "Detectron2 CUDA Compiler", "CUDA_HOME", "PyTorch built with - CUDA"
|
| 124 |
+
to contain cuda libraries of the same version.
|
| 125 |
+
|
| 126 |
+
When they are inconsistent,
|
| 127 |
+
you need to either install a different build of PyTorch (or build by yourself)
|
| 128 |
+
to match your local CUDA installation, or install a different version of CUDA to match PyTorch.
|
| 129 |
+
|
| 130 |
+
* Detectron2 or PyTorch/torchvision is not built for the correct GPU architecture (compute compatibility).
|
| 131 |
+
|
| 132 |
+
The GPU architecture for PyTorch/detectron2/torchvision is available in the "architecture flags" in
|
| 133 |
+
`python -m detectron2.utils.collect_env`.
|
| 134 |
+
|
| 135 |
+
The GPU architecture flags of detectron2/torchvision by default matches the GPU model detected
|
| 136 |
+
during compilation. This means the compiled code may not work on a different GPU model.
|
| 137 |
+
To overwrite the GPU architecture for detectron2/torchvision, use `TORCH_CUDA_ARCH_LIST` environment variable during compilation.
|
| 138 |
+
|
| 139 |
+
For example, `export TORCH_CUDA_ARCH_LIST=6.0,7.0` makes it compile for both P100s and V100s.
|
| 140 |
+
Visit [developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus) to find out
|
| 141 |
+
the correct compute compatibility number for your device.
|
| 142 |
+
|
| 143 |
+
</details>
|
| 144 |
+
|
| 145 |
+
<details>
|
| 146 |
+
<summary>
|
| 147 |
+
Undefined CUDA symbols; cannot open libcudart.so; other nvcc failures.
|
| 148 |
+
</summary>
|
| 149 |
+
<br/>
|
| 150 |
+
The version of NVCC you use to build detectron2 or torchvision does
|
| 151 |
+
not match the version of CUDA you are running with.
|
| 152 |
+
This often happens when using anaconda's CUDA runtime.
|
| 153 |
+
|
| 154 |
+
Use `python -m detectron2.utils.collect_env` to find out inconsistent CUDA versions.
|
| 155 |
+
In the output of this command, you should expect "Detectron2 CUDA Compiler", "CUDA_HOME", "PyTorch built with - CUDA"
|
| 156 |
+
to contain cuda libraries of the same version.
|
| 157 |
+
|
| 158 |
+
When they are inconsistent,
|
| 159 |
+
you need to either install a different build of PyTorch (or build by yourself)
|
| 160 |
+
to match your local CUDA installation, or install a different version of CUDA to match PyTorch.
|
| 161 |
+
</details>
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
<details>
|
| 165 |
+
<summary>
|
| 166 |
+
"ImportError: cannot import name '_C'".
|
| 167 |
+
</summary>
|
| 168 |
+
<br/>
|
| 169 |
+
Please build and install detectron2 following the instructions above.
|
| 170 |
+
|
| 171 |
+
If you are running code from detectron2's root directory, `cd` to a different one.
|
| 172 |
+
Otherwise you may not import the code that you installed.
|
| 173 |
+
</details>
|
| 174 |
+
|
| 175 |
+
<details>
|
| 176 |
+
<summary>
|
| 177 |
+
ONNX conversion segfault after some "TraceWarning".
|
| 178 |
+
</summary>
|
| 179 |
+
<br/>
|
| 180 |
+
The ONNX package is compiled with too old compiler.
|
| 181 |
+
|
| 182 |
+
Please build and install ONNX from its source code using a compiler
|
| 183 |
+
whose version is closer to what's used by PyTorch (available in `torch.__config__.show()`).
|
| 184 |
+
</details>
|
mhp_extension/detectron2/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2019 - present, Facebook, Inc
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|