Jashan887 committed
Commit 234f949 · verified
1 Parent(s): 2c37abe

Upload folder using huggingface_hub
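The commit message indicates the folder was pushed with the huggingface_hub client. A minimal sketch of such an upload, assuming a placeholder repo id and local path (neither is recorded in this diff):

#+begin_src python
# Hedged sketch: repo_id and folder_path below are placeholders, not values taken from this commit.
from huggingface_hub import HfApi

api = HfApi()  # authenticates via `huggingface-cli login` or the HF_TOKEN environment variable
api.upload_folder(
    folder_path=".",                      # local folder to upload
    repo_id="your-username/your-repo",    # hypothetical target repository
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)
#+end_src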

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +14 -0
  2. .gitignore +30 -0
  3. ComfyUI_AEMatter/AEMatter.py +1248 -0
  4. ComfyUI_AEMatter/AEMatter.run.sh +3 -0
  5. ComfyUI_AEMatter/README.org +1357 -0
  6. ComfyUI_AEMatter/__init__.py +1248 -0
  7. ComfyUI_MVANet/MVANet_inference.py +1548 -0
  8. ComfyUI_MVANet/MVANet_inference.run.sh +3 -0
  9. ComfyUI_MVANet/README.org +1694 -0
  10. ComfyUI_MVANet/__init__.py +1548 -0
  11. ComfyUI_MVANet/download.sh +13 -0
  12. ComfyUI_MVANet/requirements.txt +3 -0
  13. LICENSE +21 -0
  14. MVANet_Inference/README.org +2179 -0
  15. README.md +131 -0
  16. checkpoints/AEMatter/AEM_RWA.ckpt +3 -0
  17. checkpoints/MVANet/garment.pth +3 -0
  18. checkpoints/MVANet/skin.pth +3 -0
  19. checkpoints/Model_80.pth +3 -0
  20. checkpoints/StableDiffusion/90c7c97574f8db765509b6a5d2e7b2551b430a10cac03e37d368654eac5e8169cd149644d188be4b5b2f1b9f29e66b64a02535f622f2bf284c319b076224cb2b +3 -0
  21. checkpoints/StableDiffusion/b970812225cfb95427c13e73b75eef66430e2a525876dddac494d70fe4ed0524cb197043e0ac3dc3026b32a45cd1d6d126ec2fe74a5bc3ef5df21836ca022b30 +3 -0
  22. checkpoints/StableDiffusion/hash +2 -0
  23. checkpoints/atr.pth +3 -0
  24. checkpoints/lip.pth +3 -0
  25. checkpoints/pascal.pth +3 -0
  26. datasets/__init__.py +0 -0
  27. datasets/datasets.py +201 -0
  28. datasets/simple_extractor_dataset.py +78 -0
  29. datasets/target_generation.py +40 -0
  30. demo/demo.jpg +3 -0
  31. demo/demo_atr.png +0 -0
  32. demo/demo_lip.png +0 -0
  33. demo/demo_pascal.png +0 -0
  34. demo/lip-visualization.jpg +3 -0
  35. environment.yaml +49 -0
  36. evaluate.py +209 -0
  37. main.org +663 -0
  38. mhp_extension/README.md +38 -0
  39. mhp_extension/coco_style_annotation_creator/__pycache__/pycococreatortools.cpython-37.pyc +0 -0
  40. mhp_extension/coco_style_annotation_creator/human_to_coco.py +166 -0
  41. mhp_extension/coco_style_annotation_creator/pycococreatortools.py +114 -0
  42. mhp_extension/coco_style_annotation_creator/test_human2coco_format.py +74 -0
  43. mhp_extension/demo.ipynb +0 -0
  44. mhp_extension/detectron2/.circleci/config.yml +179 -0
  45. mhp_extension/detectron2/.clang-format +85 -0
  46. mhp_extension/detectron2/.flake8 +9 -0
  47. mhp_extension/detectron2/.gitignore +46 -0
  48. mhp_extension/detectron2/GETTING_STARTED.md +79 -0
  49. mhp_extension/detectron2/INSTALL.md +184 -0
  50. mhp_extension/detectron2/LICENSE +201 -0
.gitattributes CHANGED
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ checkpoints/atr.pth filter=lfs diff=lfs merge=lfs -text
+ checkpoints/lip.pth filter=lfs diff=lfs merge=lfs -text
+ checkpoints/pascal.pth filter=lfs diff=lfs merge=lfs -text
+ checkpoints/Model_80.pth filter=lfs diff=lfs merge=lfs -text
+ checkpoints/AEMatter/AEM_RWA.ckpt filter=lfs diff=lfs merge=lfs -text
+ checkpoints/StableDiffusion/90c7c97574f8db765509b6a5d2e7b2551b430a10cac03e37d368654eac5e8169cd149644d188be4b5b2f1b9f29e66b64a02535f622f2bf284c319b076224cb2b filter=lfs diff=lfs merge=lfs -text
+ checkpoints/StableDiffusion/b970812225cfb95427c13e73b75eef66430e2a525876dddac494d70fe4ed0524cb197043e0ac3dc3026b32a45cd1d6d126ec2fe74a5bc3ef5df21836ca022b30 filter=lfs diff=lfs merge=lfs -text
+ checkpoints/MVANet/skin.pth filter=lfs diff=lfs merge=lfs -text
+ checkpoints/MVANet/garment.pth filter=lfs diff=lfs merge=lfs -text
+ demo/demo_lip.png filter=lfs diff=lfs merge=lfs -text
+ demo/lip-visualization.jpg filter=lfs diff=lfs merge=lfs -text
+ demo/demo_pascal.png filter=lfs diff=lfs merge=lfs -text
+ demo/demo_atr.png filter=lfs diff=lfs merge=lfs -text
+ demo/demo.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,30 @@
+ /ComfyUI_MVANet/__pycache__/__init__.cpython-310.pyc
+ /ComfyUI_MVANet/#README.org#
+ /ComfyUI_MVANet/.#README.org
+ /ComfyUI_MVANet/README.org~
+ /ComfyUI_MVANet/.README.org.~undo-tree~
+ /#main.org#
+ /.#main.org
+ /main.org~
+ /.main.org.~undo-tree~
+ /.README.md.~undo-tree~
+ /ComfyUI_MVANet/.#README.org
+ /ComfyUI_AEMatter/__pycache__/__init__.cpython-310.pyc
+ /ComfyUI_AEMatter/AEMatter.class.py
+ /ComfyUI_AEMatter/AEMatter.execute.py
+ /ComfyUI_AEMatter/AEMatter.function.py
+ /ComfyUI_AEMatter/AEMatter.import.py
+ /ComfyUI_MVANet/MVANet_inference.class.py
+ /ComfyUI_MVANet/MVANet_inference.execute.py
+ /ComfyUI_MVANet/MVANet_inference.function.py
+ /ComfyUI_MVANet/MVANet_inference.import.py
+ /ComfyUI_MVANet/MVANet_inference.unify.sh
+ /ComfyUI_AEMatter/AEMatter.unify.sh
+ /git_add.txt
+ /git_lfs_track.txt
+ /gitignore.txt
+ /rm.txt
+ /work.sh
+ log/
+ pretrain_model/
+ commit_and_push.sh
ComfyUI_AEMatter/AEMatter.py ADDED
@@ -0,0 +1,1248 @@
1
+ #!/usr/bin/python3
2
+ import cv2
3
+ import math
4
+ import numpy as np
5
+ import os
6
+ import random
7
+ import wget
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import init
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint as checkpoint
14
+
15
+ from collections import OrderedDict
16
+ from einops import rearrange, repeat
17
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
18
+
19
+ import folder_paths
20
+ from folder_paths import models_dir
21
+
22
+
23
+ #!/usr/bin/python3
24
+ def mkdir_safe(out_path):
25
+ if type(out_path) == str:
26
+ if len(out_path) > 0:
27
+ if not os.path.exists(out_path):
28
+ os.mkdir(out_path)
29
+
30
+
31
+ def get_model_path():
32
+ import folder_paths
33
+ from folder_paths import models_dir
34
+
35
+ path_file_model = models_dir
36
+ mkdir_safe(out_path=path_file_model)
37
+
38
+ path_file_model = os.path.join(path_file_model, 'AEMatter')
39
+ mkdir_safe(out_path=path_file_model)
40
+
41
+ path_file_model = os.path.join(path_file_model, 'AEM_RWA.ckpt')
42
+
43
+ return path_file_model
44
+
45
+
46
+ def download_model(path):
47
+ if not os.path.exists(path):
48
+ wget.download(
49
+ 'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/AEMatter/AEM_RWA.ckpt?download=true',
50
+ out=path)
51
+
52
+
53
+ def from_torch_image(image):
54
+ image = image.cpu().numpy() * 255.0
55
+ image = np.clip(image, 0, 255).astype(np.uint8)
56
+ return image
57
+
58
+
59
+ def to_torch_image(image):
60
+ image = image.astype(dtype=np.float32)
61
+ image /= 255.0
62
+ image = torch.from_numpy(image)
63
+ return image
64
+
65
+
66
+ def window_partition(x, window_size):
67
+ """
68
+ Args:
69
+ x: (B, H, W, C)
70
+ window_size (int): window size
71
+ Returns:
72
+ windows: (num_windows*B, window_size, window_size, C)
73
+ """
74
+ B, H, W, C = x.shape
75
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
76
+ C)
77
+ windows = x.permute(0, 1, 3, 2, 4,
78
+ 5).contiguous().view(-1, window_size, window_size, C)
79
+ return windows
80
+
81
+
82
+ def window_reverse(windows, window_size, H, W):
83
+ """
84
+ Args:
85
+ windows: (num_windows*B, window_size, window_size, C)
86
+ window_size (int): Window size
87
+ H (int): Height of image
88
+ W (int): Width of image
89
+ Returns:
90
+ x: (B, H, W, C)
91
+ """
92
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
93
+ x = windows.view(B, H // window_size, W // window_size, window_size,
94
+ window_size, -1)
95
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
96
+ return x
97
+
98
+
99
+ def get_AEMatter_model(path_model_checkpoint):
100
+
101
+ download_model(path=path_model_checkpoint)
102
+
103
+ matmodel = AEMatter()
104
+ matmodel.load_state_dict(
105
+ torch.load(path_model_checkpoint, map_location='cpu')['model'])
106
+
107
+ matmodel = matmodel.cuda()
108
+ matmodel.eval()
109
+
110
+ return matmodel
111
+
112
+
113
+ def do_infer(rawimg, trimap, matmodel):
114
+ trimap_nonp = trimap.copy()
115
+ h, w, c = rawimg.shape
116
+ nonph, nonpw, _ = rawimg.shape
117
+ newh = (((h - 1) // 32) + 1) * 32
118
+ neww = (((w - 1) // 32) + 1) * 32
119
+ padh = newh - h
120
+ padh1 = int(padh / 2)
121
+ padh2 = padh - padh1
122
+ padw = neww - w
123
+ padw1 = int(padw / 2)
124
+ padw2 = padw - padw1
125
+
126
+ rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
127
+ cv2.BORDER_REFLECT)
128
+
129
+ trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
130
+ cv2.BORDER_REFLECT)
131
+
132
+ h_pad, w_pad, _ = rawimg_pad.shape
133
+ tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
134
+ tritemp[:, :, 0] = (trimap_pad == 0)
135
+ tritemp[:, :, 1] = (trimap_pad == 128)
136
+ tritemp[:, :, 2] = (trimap_pad == 255)
137
+ tritempimgs = np.transpose(tritemp, (2, 0, 1))
138
+ tritempimgs = tritempimgs[np.newaxis, :, :, :]
139
+ img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
140
+ img = np.array(img, np.float32)
141
+ img = img / 255.
142
+ img = torch.from_numpy(img).cuda()
143
+ tritempimgs = torch.from_numpy(tritempimgs).cuda()
144
+ with torch.no_grad():
145
+ pred = matmodel(img, tritempimgs)
146
+ pred = pred.detach().cpu().numpy()[0]
147
+ pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
148
+ preda = pred[
149
+ 0:1,
150
+ ] * 255
151
+ preda = np.transpose(preda, (1, 2, 0))
152
+ preda = preda * (trimap_nonp[:, :, None]
153
+ == 128) + (trimap_nonp[:, :, None] == 255) * 255
154
+ preda = np.array(preda, np.uint8)
155
+ return preda
156
+
157
+
158
+ def main():
159
+ ptrimap = '/home/asd/Desktop/demo/retriever_trimap.png'
160
+ pimgs = '/home/asd/Desktop/demo/retriever_rgb.png'
161
+ p_outs = 'alpha.png'
162
+
163
+ matmodel = get_AEMatter_model(
164
+ path_model_checkpoint='/home/asd/Desktop/AEM_RWA.ckpt')
165
+
166
+ # matmodel = AEMatter()
167
+ # matmodel.load_state_dict(
168
+ # torch.load('/home/asd/Desktop/AEM_RWA.ckpt',
169
+ # map_location='cpu')['model'])
170
+
171
+ # matmodel = matmodel.cuda()
172
+ # matmodel.eval()
173
+
174
+ rawimg = pimgs
175
+ trimap = ptrimap
176
+ rawimg = cv2.imread(rawimg, cv2.IMREAD_COLOR)
177
+ trimap = cv2.imread(trimap, cv2.IMREAD_GRAYSCALE)
178
+ trimap_nonp = trimap.copy()
179
+ h, w, c = rawimg.shape
180
+ nonph, nonpw, _ = rawimg.shape
181
+ newh = (((h - 1) // 32) + 1) * 32
182
+ neww = (((w - 1) // 32) + 1) * 32
183
+ padh = newh - h
184
+ padh1 = int(padh / 2)
185
+ padh2 = padh - padh1
186
+ padw = neww - w
187
+ padw1 = int(padw / 2)
188
+ padw2 = padw - padw1
189
+ rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
190
+ cv2.BORDER_REFLECT)
191
+ trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
192
+ cv2.BORDER_REFLECT)
193
+ h_pad, w_pad, _ = rawimg_pad.shape
194
+ tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
195
+ tritemp[:, :, 0] = (trimap_pad == 0)
196
+ tritemp[:, :, 1] = (trimap_pad == 128)
197
+ tritemp[:, :, 2] = (trimap_pad == 255)
198
+ tritempimgs = np.transpose(tritemp, (2, 0, 1))
199
+ tritempimgs = tritempimgs[np.newaxis, :, :, :]
200
+ img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
201
+ img = np.array(img, np.float32)
202
+ img = img / 255.
203
+ img = torch.from_numpy(img).cuda()
204
+ tritempimgs = torch.from_numpy(tritempimgs).cuda()
205
+ with torch.no_grad():
206
+ pred = matmodel(img, tritempimgs)
207
+ pred = pred.detach().cpu().numpy()[0]
208
+ pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
209
+ preda = pred[
210
+ 0:1,
211
+ ] * 255
212
+ preda = np.transpose(preda, (1, 2, 0))
213
+ preda = preda * (trimap_nonp[:, :, None]
214
+ == 128) + (trimap_nonp[:, :, None] == 255) * 255
215
+ preda = np.array(preda, np.uint8)
216
+ cv2.imwrite(p_outs, preda)
217
+
218
+
219
+ #!/usr/bin/python3
220
+ class WindowAttention(nn.Module):
221
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
222
+ It supports both of shifted and non-shifted window.
223
+ Args:
224
+ dim (int): Number of input channels.
225
+ window_size (tuple[int]): The height and width of the window.
226
+ num_heads (int): Number of attention heads.
227
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
228
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
229
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
230
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
231
+ """
232
+
233
+ def __init__(self,
234
+ dim,
235
+ window_size,
236
+ num_heads,
237
+ qkv_bias=True,
238
+ qk_scale=None,
239
+ attn_drop=0.,
240
+ proj_drop=0.):
241
+
242
+ super().__init__()
243
+ self.dim = dim
244
+ self.window_size = window_size # Wh, Ww
245
+ self.num_heads = num_heads
246
+ head_dim = dim // num_heads
247
+ self.scale = qk_scale or head_dim**-0.5
248
+
249
+ # define a parameter table of relative position bias
250
+ self.relative_position_bias_table = nn.Parameter(
251
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
252
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
253
+
254
+ # get pair-wise relative position index for each token inside the window
255
+ coords_h = torch.arange(self.window_size[0])
256
+ coords_w = torch.arange(self.window_size[1])
257
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
258
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
259
+ relative_coords = coords_flatten[:, :,
260
+ None] - coords_flatten[:,
261
+ None, :] # 2, Wh*Ww, Wh*Ww
262
+ relative_coords = relative_coords.permute(
263
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
264
+ relative_coords[:, :,
265
+ 0] += self.window_size[0] - 1 # shift to start from 0
266
+ relative_coords[:, :, 1] += self.window_size[1] - 1
267
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
268
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
269
+ self.register_buffer("relative_position_index",
270
+ relative_position_index)
271
+
272
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
273
+ self.attn_drop = nn.Dropout(attn_drop)
274
+ self.proj = nn.Linear(dim, dim)
275
+ self.proj_drop = nn.Dropout(proj_drop)
276
+
277
+ trunc_normal_(self.relative_position_bias_table, std=.02)
278
+ self.softmax = nn.Softmax(dim=-1)
279
+
280
+ def forward(self, x, mask=None):
281
+ """ Forward function.
282
+ Args:
283
+ x: input features with shape of (num_windows*B, N, C)
284
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
285
+ """
286
+ B_, N, C = x.shape
287
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
288
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
289
+ q, k, v = qkv[0], qkv[1], qkv[
290
+ 2] # make torchscript happy (cannot use tensor as tuple)
291
+
292
+ q = q * self.scale
293
+ attn = (q @ k.transpose(-2, -1))
294
+
295
+ relative_position_bias = self.relative_position_bias_table[
296
+ self.relative_position_index.view(-1)].view(
297
+ self.window_size[0] * self.window_size[1],
298
+ self.window_size[0] * self.window_size[1],
299
+ -1) # Wh*Ww,Wh*Ww,nH
300
+ relative_position_bias = relative_position_bias.permute(
301
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
302
+ attn = attn + relative_position_bias.unsqueeze(0)
303
+
304
+ if mask is not None:
305
+ nW = mask.shape[0]
306
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
307
+ N) + mask.unsqueeze(1).unsqueeze(0)
308
+ attn = attn.view(-1, self.num_heads, N, N)
309
+ attn = self.softmax(attn)
310
+ else:
311
+ attn = self.softmax(attn)
312
+
313
+ attn = self.attn_drop(attn)
314
+
315
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
316
+ x = self.proj(x)
317
+ x = self.proj_drop(x)
318
+ return x
319
+
320
+
321
+ class SwinTransformerBlock(nn.Module):
322
+ """ Swin Transformer Block.
323
+ Args:
324
+ dim (int): Number of input channels.
325
+ num_heads (int): Number of attention heads.
326
+ window_size (int): Window size.
327
+ shift_size (int): Shift size for SW-MSA.
328
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
329
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
330
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
331
+ drop (float, optional): Dropout rate. Default: 0.0
332
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
333
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
334
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
335
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
336
+ """
337
+
338
+ def __init__(self,
339
+ dim,
340
+ num_heads,
341
+ window_size=7,
342
+ shift_size=0,
343
+ mlp_ratio=4.,
344
+ qkv_bias=True,
345
+ qk_scale=None,
346
+ drop=0.,
347
+ attn_drop=0.,
348
+ drop_path=0.,
349
+ act_layer=nn.GELU,
350
+ norm_layer=nn.LayerNorm):
351
+ super().__init__()
352
+ self.dim = dim
353
+ self.num_heads = num_heads
354
+ self.window_size = window_size
355
+ self.shift_size = shift_size
356
+ self.mlp_ratio = mlp_ratio
357
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
358
+
359
+ self.norm1 = norm_layer(dim)
360
+ self.attn = WindowAttention(dim,
361
+ window_size=to_2tuple(self.window_size),
362
+ num_heads=num_heads,
363
+ qkv_bias=qkv_bias,
364
+ qk_scale=qk_scale,
365
+ attn_drop=attn_drop,
366
+ proj_drop=drop)
367
+
368
+ self.drop_path = DropPath(
369
+ drop_path) if drop_path > 0. else nn.Identity()
370
+ self.norm2 = norm_layer(dim)
371
+ mlp_hidden_dim = int(dim * mlp_ratio)
372
+ self.mlp = Mlp(in_features=dim,
373
+ hidden_features=mlp_hidden_dim,
374
+ act_layer=act_layer,
375
+ drop=drop)
376
+
377
+ self.H = None
378
+ self.W = None
379
+
380
+ def forward(self, x, mask_matrix):
381
+ """ Forward function.
382
+ Args:
383
+ x: Input feature, tensor size (B, H*W, C).
384
+ H, W: Spatial resolution of the input feature.
385
+ mask_matrix: Attention mask for cyclic shift.
386
+ """
387
+ B, L, C = x.shape
388
+ H, W = self.H, self.W
389
+ assert L == H * W, "input feature has wrong size"
390
+
391
+ shortcut = x
392
+ x = self.norm1(x)
393
+ x = x.view(B, H, W, C)
394
+
395
+ # pad feature maps to multiples of window size
396
+ pad_l = pad_t = 0
397
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
398
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
399
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
400
+ _, Hp, Wp, _ = x.shape
401
+
402
+ # cyclic shift
403
+ if self.shift_size > 0:
404
+ shifted_x = torch.roll(x,
405
+ shifts=(-self.shift_size, -self.shift_size),
406
+ dims=(1, 2))
407
+ attn_mask = mask_matrix
408
+ else:
409
+ shifted_x = x
410
+ attn_mask = None
411
+
412
+ # partition windows
413
+ x_windows = window_partition(
414
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
415
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
416
+ C) # nW*B, window_size*window_size, C
417
+
418
+ # W-MSA/SW-MSA
419
+ attn_windows = self.attn(
420
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
421
+
422
+ # merge windows
423
+ attn_windows = attn_windows.view(-1, self.window_size,
424
+ self.window_size, C)
425
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp,
426
+ Wp) # B H' W' C
427
+
428
+ # reverse cyclic shift
429
+ if self.shift_size > 0:
430
+ x = torch.roll(shifted_x,
431
+ shifts=(self.shift_size, self.shift_size),
432
+ dims=(1, 2))
433
+ else:
434
+ x = shifted_x
435
+
436
+ if pad_r > 0 or pad_b > 0:
437
+ x = x[:, :H, :W, :].contiguous()
438
+
439
+ x = x.view(B, H * W, C)
440
+
441
+ # FFN
442
+ x = shortcut + self.drop_path(x)
443
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
444
+
445
+ return x
446
+
447
+
448
+ class PatchMerging(nn.Module):
449
+ """ Patch Merging Layer
450
+ Args:
451
+ dim (int): Number of input channels.
452
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
453
+ """
454
+
455
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
456
+ super().__init__()
457
+ self.dim = dim
458
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
459
+ self.norm = norm_layer(4 * dim)
460
+
461
+ def forward(self, x, H, W):
462
+ """ Forward function.
463
+ Args:
464
+ x: Input feature, tensor size (B, H*W, C).
465
+ H, W: Spatial resolution of the input feature.
466
+ """
467
+ B, L, C = x.shape
468
+ assert L == H * W, "input feature has wrong size"
469
+
470
+ x = x.view(B, H, W, C)
471
+
472
+ # padding
473
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
474
+ if pad_input:
475
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
476
+
477
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
478
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
479
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
480
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
481
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
482
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
483
+
484
+ x = self.norm(x)
485
+ x = self.reduction(x)
486
+
487
+ return x
488
+
489
+
490
+ class BasicLayer(nn.Module):
491
+ """ A basic Swin Transformer layer for one stage.
492
+ Args:
493
+ dim (int): Number of feature channels
494
+ depth (int): Depths of this stage.
495
+ num_heads (int): Number of attention head.
496
+ window_size (int): Local window size. Default: 7.
497
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
498
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
499
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
500
+ drop (float, optional): Dropout rate. Default: 0.0
501
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
502
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
503
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
504
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
505
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
506
+ """
507
+
508
+ def __init__(self,
509
+ dim,
510
+ depth,
511
+ num_heads,
512
+ window_size=7,
513
+ mlp_ratio=4.,
514
+ qkv_bias=True,
515
+ qk_scale=None,
516
+ drop=0.,
517
+ attn_drop=0.,
518
+ drop_path=0.,
519
+ norm_layer=nn.LayerNorm,
520
+ downsample=None,
521
+ use_checkpoint=False):
522
+
523
+ super().__init__()
524
+ self.window_size = window_size
525
+ self.shift_size = window_size // 2
526
+ self.depth = depth
527
+ self.use_checkpoint = use_checkpoint
528
+
529
+ # build blocks
530
+ self.blocks = nn.ModuleList([
531
+ SwinTransformerBlock(dim=dim,
532
+ num_heads=num_heads,
533
+ window_size=window_size,
534
+ shift_size=0 if
535
+ (i % 2 == 0) else window_size // 2,
536
+ mlp_ratio=mlp_ratio,
537
+ qkv_bias=qkv_bias,
538
+ qk_scale=qk_scale,
539
+ drop=drop,
540
+ attn_drop=attn_drop,
541
+ drop_path=drop_path[i] if isinstance(
542
+ drop_path, list) else drop_path,
543
+ norm_layer=norm_layer) for i in range(depth)
544
+ ])
545
+
546
+ # patch merging layer
547
+ if downsample is not None:
548
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
549
+ else:
550
+ self.downsample = None
551
+
552
+ def forward(self, x, H, W):
553
+ """ Forward function.
554
+ Args:
555
+ x: Input feature, tensor size (B, H*W, C).
556
+ H, W: Spatial resolution of the input feature.
557
+ """
558
+ # print(x.shape,H,W)
559
+ # calculate attention mask for SW-MSA
560
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
561
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
562
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
563
+ h_slices = (slice(0, -self.window_size),
564
+ slice(-self.window_size,
565
+ -self.shift_size), slice(-self.shift_size, None))
566
+ w_slices = (slice(0, -self.window_size),
567
+ slice(-self.window_size,
568
+ -self.shift_size), slice(-self.shift_size, None))
569
+ cnt = 0
570
+ for h in h_slices:
571
+ for w in w_slices:
572
+ img_mask[:, h, w, :] = cnt
573
+ cnt += 1
574
+
575
+ mask_windows = window_partition(
576
+ img_mask, self.window_size) # nW, window_size, window_size, 1
577
+
578
+ mask_windows = mask_windows.view(-1,
579
+ self.window_size * self.window_size)
580
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(
581
+ 2) # nW, ww window_size*window_size
582
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
583
+ float(-100.0)).masked_fill(
584
+ attn_mask == 0, float(0.0))
585
+
586
+ for blk in self.blocks:
587
+ blk.H, blk.W = H, W
588
+ if self.use_checkpoint:
589
+ x = checkpoint.checkpoint(blk, x, attn_mask)
590
+ else:
591
+ x = blk(x, attn_mask)
592
+
593
+ if self.downsample is not None:
594
+ x_down = self.downsample(x, H, W)
595
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
596
+ return x, H, W, x_down, Wh, Ww
597
+ else:
598
+ return x, H, W, x, H, W
599
+
600
+
601
+ class PatchEmbed(nn.Module):
602
+ """ Image to Patch Embedding
603
+ Args:
604
+ patch_size (int): Patch token size. Default: 4.
605
+ in_chans (int): Number of input image channels. Default: 3.
606
+ embed_dim (int): Number of linear projection output channels. Default: 96.
607
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
608
+ """
609
+
610
+ def __init__(self,
611
+ patch_size=4,
612
+ in_chans=3,
613
+ embed_dim=96,
614
+ norm_layer=None):
615
+
616
+ super().__init__()
617
+ patch_size = to_2tuple(patch_size)
618
+ self.patch_size = patch_size
619
+
620
+ self.in_chans = in_chans
621
+ self.embed_dim = embed_dim
622
+
623
+ self.proj = nn.Conv2d(in_chans,
624
+ embed_dim,
625
+ kernel_size=patch_size,
626
+ stride=patch_size)
627
+ if norm_layer is not None:
628
+ self.norm = norm_layer(embed_dim)
629
+ else:
630
+ self.norm = None
631
+
632
+ def forward(self, x):
633
+ """Forward function."""
634
+ # padding
635
+ _, _, H, W = x.size()
636
+ if W % self.patch_size[1] != 0:
637
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
638
+ if H % self.patch_size[0] != 0:
639
+ x = F.pad(x,
640
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
641
+
642
+ x = self.proj(x) # B C Wh Ww
643
+ if self.norm is not None:
644
+ Wh, Ww = x.size(2), x.size(3)
645
+ x = x.flatten(2).transpose(1, 2)
646
+ x = self.norm(x)
647
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
648
+
649
+ return x
650
+
651
+
652
+ class SwinTransformer(nn.Module):
653
+ """ Swin Transformer backbone.
654
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
655
+ https://arxiv.org/pdf/2103.14030
656
+ Args:
657
+ pretrain_img_size (int): Input image size for training the pretrained model,
658
+ used in absolute postion embedding. Default 224.
659
+ patch_size (int | tuple(int)): Patch size. Default: 4.
660
+ in_chans (int): Number of input image channels. Default: 3.
661
+ embed_dim (int): Number of linear projection output channels. Default: 96.
662
+ depths (tuple[int]): Depths of each Swin Transformer stage.
663
+ num_heads (tuple[int]): Number of attention head of each stage.
664
+ window_size (int): Window size. Default: 7.
665
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
666
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
667
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
668
+ drop_rate (float): Dropout rate.
669
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
670
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
671
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
672
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
673
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
674
+ out_indices (Sequence[int]): Output from which stages.
675
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
676
+ -1 means not freezing any parameters.
677
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
678
+ """
679
+
680
+ def __init__(self,
681
+ pretrain_img_size=224,
682
+ patch_size=4,
683
+ in_chans=3,
684
+ embed_dim=96,
685
+ depths=[2, 2, 6, 2],
686
+ num_heads=[3, 6, 12, 24],
687
+ window_size=7,
688
+ mlp_ratio=4.,
689
+ qkv_bias=True,
690
+ qk_scale=None,
691
+ drop_rate=0.,
692
+ attn_drop_rate=0.,
693
+ drop_path_rate=0.2,
694
+ norm_layer=nn.LayerNorm,
695
+ ape=False,
696
+ patch_norm=True,
697
+ out_indices=(0, 1, 2, 3),
698
+ frozen_stages=-1,
699
+ use_checkpoint=False):
700
+
701
+ super().__init__()
702
+
703
+ self.pretrain_img_size = pretrain_img_size
704
+ self.num_layers = len(depths)
705
+ self.embed_dim = embed_dim
706
+ self.ape = ape
707
+ self.patch_norm = patch_norm
708
+ self.out_indices = out_indices
709
+ self.frozen_stages = frozen_stages
710
+
711
+ # split image into non-overlapping patches
712
+ self.patch_embed = PatchEmbed(
713
+ patch_size=patch_size,
714
+ in_chans=in_chans,
715
+ embed_dim=embed_dim,
716
+ norm_layer=norm_layer if self.patch_norm else None)
717
+
718
+ # absolute position embedding
719
+ if self.ape:
720
+ pretrain_img_size = to_2tuple(pretrain_img_size)
721
+ patch_size = to_2tuple(patch_size)
722
+ patches_resolution = [
723
+ pretrain_img_size[0] // patch_size[0],
724
+ pretrain_img_size[1] // patch_size[1]
725
+ ]
726
+
727
+ self.absolute_pos_embed = nn.Parameter(
728
+ torch.zeros(1, embed_dim, patches_resolution[0],
729
+ patches_resolution[1]))
730
+ trunc_normal_(self.absolute_pos_embed, std=.02)
731
+
732
+ self.pos_drop = nn.Dropout(p=drop_rate)
733
+
734
+ # stochastic depth
735
+ dpr = [
736
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
737
+ ] # stochastic depth decay rule
738
+
739
+ # build layers
740
+ self.layers = nn.ModuleList()
741
+ for i_layer in range(self.num_layers):
742
+ layer = BasicLayer(
743
+ dim=int(embed_dim * 2**i_layer),
744
+ depth=depths[i_layer],
745
+ num_heads=num_heads[i_layer],
746
+ window_size=window_size,
747
+ mlp_ratio=mlp_ratio,
748
+ qkv_bias=qkv_bias,
749
+ qk_scale=qk_scale,
750
+ drop=drop_rate,
751
+ attn_drop=attn_drop_rate,
752
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
753
+ norm_layer=norm_layer,
754
+ downsample=PatchMerging if
755
+ (i_layer < self.num_layers - 1) else None,
756
+ use_checkpoint=use_checkpoint)
757
+ self.layers.append(layer)
758
+
759
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
760
+ self.num_features = num_features
761
+
762
+ # add a norm layer for each output
763
+ for i_layer in out_indices:
764
+ layer = norm_layer(num_features[i_layer])
765
+ layer_name = f'norm{i_layer}'
766
+ self.add_module(layer_name, layer)
767
+
768
+ self._freeze_stages()
769
+
770
+ def _freeze_stages(self):
771
+ if self.frozen_stages >= 0:
772
+ self.patch_embed.eval()
773
+ for param in self.patch_embed.parameters():
774
+ param.requires_grad = False
775
+
776
+ if self.frozen_stages >= 1 and self.ape:
777
+ self.absolute_pos_embed.requires_grad = False
778
+
779
+ if self.frozen_stages >= 2:
780
+ self.pos_drop.eval()
781
+ for i in range(0, self.frozen_stages - 1):
782
+ m = self.layers[i]
783
+ m.eval()
784
+ for param in m.parameters():
785
+ param.requires_grad = False
786
+
787
+ def init_weights(self, pretrained=None):
788
+ """Initialize the weights in backbone.
789
+ Args:
790
+ pretrained (str, optional): Path to pre-trained weights.
791
+ Defaults to None.
792
+ """
793
+
794
+ def forward(self, x):
795
+ """Forward function."""
796
+ x = self.patch_embed(x)
797
+
798
+ Wh, Ww = x.size(2), x.size(3)
799
+ if self.ape:
800
+ # interpolate the position embedding to the corresponding size
801
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
802
+ size=(Wh, Ww),
803
+ mode='bicubic')
804
+ x = (x + absolute_pos_embed).flatten(2).transpose(1,
805
+ 2) # B Wh*Ww C
806
+ else:
807
+ x = x.flatten(2).transpose(1, 2)
808
+ x = self.pos_drop(x)
809
+
810
+ outs = []
811
+ for i in range(self.num_layers):
812
+ layer = self.layers[i]
813
+
814
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
815
+
816
+ if i in self.out_indices:
817
+ norm_layer = getattr(self, f'norm{i}')
818
+ x_out = norm_layer(x_out)
819
+
820
+ out = x_out.view(-1, H, W,
821
+ self.num_features[i]).permute(0, 3, 1,
822
+ 2).contiguous()
823
+ outs.append(out)
824
+
825
+ return tuple(outs)
826
+
827
+ def train(self, mode=True):
828
+ """Convert the model into training mode while keep layers freezed."""
829
+ super(SwinTransformer, self).train(mode)
830
+ self._freeze_stages()
831
+
832
+
833
+ class Mlp(nn.Module):
834
+ """ Multilayer perceptron."""
835
+
836
+ def __init__(self,
837
+ in_features,
838
+ hidden_features=None,
839
+ out_features=None,
840
+ act_layer=nn.GELU,
841
+ drop=0.):
842
+ super().__init__()
843
+ out_features = out_features or in_features
844
+ hidden_features = hidden_features or in_features
845
+ self.fc1 = nn.Linear(in_features, hidden_features)
846
+ self.act = act_layer()
847
+ self.fc2 = nn.Linear(hidden_features, out_features)
848
+ self.drop = nn.Dropout(drop)
849
+
850
+ def forward(self, x):
851
+ x = self.fc1(x)
852
+ x = self.act(x)
853
+ x = self.drop(x)
854
+ x = self.fc2(x)
855
+ x = self.drop(x)
856
+ return x
857
+
858
+
859
+ class ResBlock(nn.Module):
860
+
861
+ def __init__(self, inc, midc):
862
+ super(ResBlock, self).__init__()
863
+ self.conv1 = nn.Conv2d(inc,
864
+ midc,
865
+ kernel_size=1,
866
+ stride=1,
867
+ padding=0,
868
+ bias=True)
869
+ self.gn1 = nn.GroupNorm(16, midc)
870
+ self.conv2 = nn.Conv2d(midc,
871
+ midc,
872
+ kernel_size=3,
873
+ stride=1,
874
+ padding=1,
875
+ bias=True)
876
+ self.gn2 = nn.GroupNorm(16, midc)
877
+ self.conv3 = nn.Conv2d(midc,
878
+ inc,
879
+ kernel_size=1,
880
+ stride=1,
881
+ padding=0,
882
+ bias=True)
883
+ self.relu = nn.LeakyReLU(0.1)
884
+
885
+ def forward(self, x):
886
+ x_ = x
887
+ x = self.conv1(x)
888
+ x = self.gn1(x)
889
+ x = self.relu(x)
890
+ x = self.conv2(x)
891
+ x = self.gn2(x)
892
+ x = self.relu(x)
893
+ x = self.conv3(x)
894
+ x = x + x_
895
+ x = self.relu(x)
896
+ return x
897
+
898
+
899
+ class AEALblock(nn.Module):
900
+
901
+ def __init__(self,
902
+ d_model,
903
+ nhead,
904
+ dim_feedforward=512,
905
+ dropout=0.0,
906
+ layer_norm_eps=1e-5,
907
+ batch_first=True,
908
+ norm_first=False,
909
+ width=5):
910
+ super(AEALblock, self).__init__()
911
+ self.self_attn2 = nn.MultiheadAttention(d_model // 2,
912
+ nhead // 2,
913
+ dropout=dropout,
914
+ batch_first=batch_first)
915
+ self.self_attn1 = nn.MultiheadAttention(d_model // 2,
916
+ nhead // 2,
917
+ dropout=dropout,
918
+ batch_first=batch_first)
919
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
920
+ self.dropout = nn.Dropout(dropout)
921
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
922
+ self.norm_first = norm_first
923
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
924
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
925
+ self.dropout1 = nn.Dropout(dropout)
926
+ self.dropout2 = nn.Dropout(dropout)
927
+ self.activation = nn.ReLU()
928
+ self.width = width
929
+ self.trans = nn.Sequential(
930
+ nn.Conv2d(d_model + 512, d_model // 2, 1, 1, 0),
931
+ ResBlock(d_model // 2, d_model // 4),
932
+ nn.Conv2d(d_model // 2, d_model, 1, 1, 0))
933
+ self.gamma = nn.Parameter(torch.zeros(1))
934
+
935
+ def forward(
936
+ self,
937
+ src,
938
+ feats,
939
+ ):
940
+ src = self.gamma * self.trans(torch.cat([src, feats], 1)) + src
941
+ b, c, h, w = src.shape
942
+ x1 = src[:, 0:c // 2]
943
+ x1_ = rearrange(x1, 'b c (h1 h2) w -> b c h1 h2 w', h2=self.width)
944
+ x1_ = rearrange(x1_, 'b c h1 h2 w -> (b h1) (h2 w) c')
945
+ x2 = src[:, c // 2:]
946
+ x2_ = rearrange(x2, 'b c h (w1 w2) -> b c h w1 w2', w2=self.width)
947
+ x2_ = rearrange(x2_, 'b c h w1 w2 -> (b w1) (h w2) c')
948
+ x = rearrange(src, 'b c h w-> b (h w) c')
949
+ x = self.norm1(x + self._sa_block(x1_, x2_, h, w))
950
+ x = self.norm2(x + self._ff_block(x))
951
+ x = rearrange(x, 'b (h w) c->b c h w', h=h, w=w)
952
+ return x
953
+
954
+ def _sa_block(self, x1, x2, h, w):
955
+ x1 = self.self_attn1(x1,
956
+ x1,
957
+ x1,
958
+ attn_mask=None,
959
+ key_padding_mask=None,
960
+ need_weights=False)[0]
961
+
962
+ x2 = self.self_attn2(x2,
963
+ x2,
964
+ x2,
965
+ attn_mask=None,
966
+ key_padding_mask=None,
967
+ need_weights=False)[0]
968
+
969
+ x1 = rearrange(x1,
970
+ '(b h1) (h2 w) c-> b (h1 h2 w) c',
971
+ h2=self.width,
972
+ h1=h // self.width)
973
+ x2 = rearrange(x2,
974
+ ' (b w1) (h w2) c-> b (h w1 w2) c',
975
+ w2=self.width,
976
+ w1=w // self.width)
977
+ x = torch.cat([x1, x2], dim=2)
978
+ return self.dropout1(x)
979
+
980
+ def _ff_block(self, x):
981
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
982
+ return self.dropout2(x)
983
+
984
+
985
+ class AEMatter(nn.Module):
986
+
987
+ def __init__(self):
988
+ super(AEMatter, self).__init__()
989
+ trans = SwinTransformer(pretrain_img_size=224,
990
+ embed_dim=96,
991
+ depths=[2, 2, 6, 2],
992
+ num_heads=[3, 6, 12, 24],
993
+ window_size=7,
994
+ ape=False,
995
+ drop_path_rate=0.2,
996
+ patch_norm=True,
997
+ use_checkpoint=False)
998
+
999
+ # trans.load_state_dict(torch.load(
1000
+ # '/home/asd/Desktop/swin_tiny_patch4_window7_224.pth',
1001
+ # map_location="cpu")["model"],
1002
+ # strict=False)
1003
+
1004
+ trans.patch_embed.proj = nn.Conv2d(64, 96, 3, 2, 1)
1005
+
1006
+ self.start_conv0 = nn.Sequential(nn.Conv2d(6, 48, 3, 1, 1),
1007
+ nn.PReLU(48))
1008
+
1009
+ self.start_conv = nn.Sequential(nn.Conv2d(48, 64, 3, 2,
1010
+ 1), nn.PReLU(64),
1011
+ nn.Conv2d(64, 64, 3, 1, 1),
1012
+ nn.PReLU(64))
1013
+
1014
+ self.trans = trans
1015
+ self.conv1 = nn.Sequential(
1016
+ nn.Conv2d(in_channels=640 + 768,
1017
+ out_channels=256,
1018
+ kernel_size=1,
1019
+ stride=1,
1020
+ padding=0,
1021
+ bias=True))
1022
+ self.conv2 = nn.Sequential(
1023
+ nn.Conv2d(in_channels=256 + 384,
1024
+ out_channels=256,
1025
+ kernel_size=1,
1026
+ stride=1,
1027
+ padding=0,
1028
+ bias=True), )
1029
+ self.conv3 = nn.Sequential(
1030
+ nn.Conv2d(in_channels=256 + 192,
1031
+ out_channels=192,
1032
+ kernel_size=1,
1033
+ stride=1,
1034
+ padding=0,
1035
+ bias=True), )
1036
+ self.conv4 = nn.Sequential(
1037
+ nn.Conv2d(in_channels=192 + 96,
1038
+ out_channels=128,
1039
+ kernel_size=1,
1040
+ stride=1,
1041
+ padding=0,
1042
+ bias=True), )
1043
+ self.ctran0 = BasicLayer(256, 3, 8, 7, drop_path=0.09)
1044
+ self.ctran1 = BasicLayer(256, 3, 8, 7, drop_path=0.07)
1045
+ self.ctran2 = BasicLayer(192, 3, 6, 7, drop_path=0.05)
1046
+ self.ctran3 = BasicLayer(128, 3, 4, 7, drop_path=0.03)
1047
+ self.conv5 = nn.Sequential(
1048
+ nn.Conv2d(in_channels=192,
1049
+ out_channels=64,
1050
+ kernel_size=3,
1051
+ stride=1,
1052
+ padding=1,
1053
+ bias=True), nn.PReLU(64),
1054
+ nn.Conv2d(in_channels=64,
1055
+ out_channels=64,
1056
+ kernel_size=3,
1057
+ stride=1,
1058
+ padding=1,
1059
+ bias=True), nn.PReLU(64),
1060
+ nn.Conv2d(in_channels=64,
1061
+ out_channels=48,
1062
+ kernel_size=3,
1063
+ stride=1,
1064
+ padding=1,
1065
+ bias=True), nn.PReLU(48))
1066
+ self.convo = nn.Sequential(
1067
+ nn.Conv2d(in_channels=48 + 48 + 6,
1068
+ out_channels=32,
1069
+ kernel_size=3,
1070
+ stride=1,
1071
+ padding=1,
1072
+ bias=True), nn.PReLU(32),
1073
+ nn.Conv2d(in_channels=32,
1074
+ out_channels=32,
1075
+ kernel_size=3,
1076
+ stride=1,
1077
+ padding=1,
1078
+ bias=True), nn.PReLU(32),
1079
+ nn.Conv2d(in_channels=32,
1080
+ out_channels=1,
1081
+ kernel_size=3,
1082
+ stride=1,
1083
+ padding=1,
1084
+ bias=True))
1085
+ self.up = nn.Upsample(scale_factor=2,
1086
+ mode='bilinear',
1087
+ align_corners=False)
1088
+ self.upn = nn.Upsample(scale_factor=2, mode='nearest')
1089
+ self.apptrans = nn.Sequential(
1090
+ nn.Conv2d(256 + 384, 256, 1, 1, bias=True), ResBlock(256, 128),
1091
+ ResBlock(256, 128), nn.Conv2d(256, 512, 2, 2, bias=True),
1092
+ ResBlock(512, 128))
1093
+ self.emb = nn.Sequential(nn.Conv2d(768, 640, 1, 1, 0),
1094
+ ResBlock(640, 160))
1095
+ self.embdp = nn.Sequential(nn.Conv2d(640, 640, 1, 1, 0))
1096
+ self.h2l = nn.Conv2d(768, 256, 1, 1, 0)
1097
+ self.width = 5
1098
+ self.trans1 = AEALblock(d_model=640,
1099
+ nhead=20,
1100
+ dim_feedforward=2048,
1101
+ dropout=0.2,
1102
+ width=self.width)
1103
+ self.trans2 = AEALblock(d_model=640,
1104
+ nhead=20,
1105
+ dim_feedforward=2048,
1106
+ dropout=0.2,
1107
+ width=self.width)
1108
+ self.trans3 = AEALblock(d_model=640,
1109
+ nhead=20,
1110
+ dim_feedforward=2048,
1111
+ dropout=0.2,
1112
+ width=self.width)
1113
+
1114
+ def aeal(self, x, sem):
1115
+ xe = self.emb(x)
1116
+ x_ = xe
1117
+ x_ = self.embdp(x_)
1118
+ b, c, h1, w1 = x_.shape
1119
+ bnew_ph = int(np.ceil(h1 / self.width) * self.width) - h1
1120
+ bnew_pw = int(np.ceil(w1 / self.width) * self.width) - w1
1121
+ newph1 = bnew_ph // 2
1122
+ newph2 = bnew_ph - newph1
1123
+ newpw1 = bnew_pw // 2
1124
+ newpw2 = bnew_pw - newpw1
1125
+ x_ = F.pad(x_, (newpw1, newpw2, newph1, newph2))
1126
+ sem = F.pad(sem, (newpw1, newpw2, newph1, newph2))
1127
+ x_ = self.trans1(x_, sem)
1128
+ x_ = self.trans2(x_, sem)
1129
+ x_ = self.trans3(x_, sem)
1130
+ x_ = x_[:, :, newph1:h1 + newph1, newpw1:w1 + newpw1]
1131
+ return x_
1132
+
1133
+ def forward(self, x, y):
1134
+ inputs = torch.cat((x, y), 1)
1135
+ x = self.start_conv0(inputs)
1136
+ x_ = self.start_conv(x)
1137
+ x1, x2, x3, x4 = self.trans(x_)
1138
+ x4h = self.h2l(x4)
1139
+ x3s = self.apptrans(torch.cat([x3, self.upn(x4h)], 1))
1140
+ x4_ = self.aeal(x4, x3s)
1141
+ x4 = torch.cat((x4, x4_), 1)
1142
+ X4 = self.conv1(x4)
1143
+ wh, ww = X4.shape[2], X4.shape[3]
1144
+ X4 = rearrange(X4, 'b c h w -> b (h w) c')
1145
+ X4, _, _, _, _, _ = self.ctran0(X4, wh, ww)
1146
+ X4 = rearrange(X4, 'b (h w) c -> b c h w', h=wh, w=ww)
1147
+ X3 = self.up(X4)
1148
+ X3 = torch.cat((x3, X3), 1)
1149
+ X3 = self.conv2(X3)
1150
+ wh, ww = X3.shape[2], X3.shape[3]
1151
+ X3 = rearrange(X3, 'b c h w -> b (h w) c')
1152
+ X3, _, _, _, _, _ = self.ctran1(X3, wh, ww)
1153
+ X3 = rearrange(X3, 'b (h w) c -> b c h w', h=wh, w=ww)
1154
+ X2 = self.up(X3)
1155
+ X2 = torch.cat((x2, X2), 1)
1156
+ X2 = self.conv3(X2)
1157
+ wh, ww = X2.shape[2], X2.shape[3]
1158
+ X2 = rearrange(X2, 'b c h w -> b (h w) c')
1159
+ X2, _, _, _, _, _ = self.ctran2(X2, wh, ww)
1160
+ X2 = rearrange(X2, 'b (h w) c -> b c h w', h=wh, w=ww)
1161
+ X1 = self.up(X2)
1162
+ X1 = torch.cat((x1, X1), 1)
1163
+ X1 = self.conv4(X1)
1164
+ wh, ww = X1.shape[2], X1.shape[3]
1165
+ X1 = rearrange(X1, 'b c h w -> b (h w) c')
1166
+ X1, _, _, _, _, _ = self.ctran3(X1, wh, ww)
1167
+ X1 = rearrange(X1, 'b (h w) c -> b c h w', h=wh, w=ww)
1168
+ X0 = self.up(X1)
1169
+ X0 = torch.cat((x_, X0), 1)
1170
+ X0 = self.conv5(X0)
1171
+ X = self.up(X0)
1172
+ X = torch.cat((inputs, x, X), 1)
1173
+ alpha = self.convo(X)
1174
+ alpha = torch.clamp(alpha, min=0, max=1)
1175
+ return alpha
1176
+
1177
+
1178
+ class load_AEMatter_Model:
1179
+
1180
+ def __init__(self):
1181
+ pass
1182
+
1183
+ @classmethod
1184
+ def INPUT_TYPES(s):
1185
+ return {
1186
+ "required": {},
1187
+ }
1188
+
1189
+ RETURN_TYPES = ("AEMatter_Model", )
1190
+ FUNCTION = "test"
1191
+ CATEGORY = "AEMatter"
1192
+
1193
+ def test(self):
1194
+ return (get_AEMatter_model(get_model_path()), )
1195
+
1196
+
1197
+ class run_AEMatter_inference:
1198
+
1199
+ def __init__(self):
1200
+ pass
1201
+
1202
+ @classmethod
1203
+ def INPUT_TYPES(s):
1204
+ return {
1205
+ "required": {
1206
+ "image": ("IMAGE", ),
1207
+ "trimap": ("MASK", ),
1208
+ "AEMatter_Model": ("AEMatter_Model", ),
1209
+ },
1210
+ }
1211
+
1212
+ RETURN_TYPES = ("MASK", )
1213
+ FUNCTION = "test"
1214
+ CATEGORY = "AEMatter"
1215
+
1216
+ def test(
1217
+ self,
1218
+ image,
1219
+ trimap,
1220
+ AEMatter_Model,
1221
+ ):
1222
+
1223
+ ret = []
1224
+ batch_size = image.shape[0]
1225
+
1226
+ for i in range(batch_size):
1227
+ tmp_i = from_torch_image(image[i])
1228
+ tmp_m = from_torch_image(trimap[i])
1229
+ tmp = do_infer(tmp_i, tmp_m, AEMatter_Model)
1230
+ ret.append(tmp)
1231
+
1232
+ ret = to_torch_image(np.array(ret))
1233
+ ret = ret.squeeze(-1)
1234
+ print(ret.shape)
1235
+
1236
+ return (ret, )
1237
+
1238
+
1239
+ #!/usr/bin/python3
1240
+ NODE_CLASS_MAPPINGS = {
1241
+ 'load_AEMatter_Model': load_AEMatter_Model,
1242
+ 'run_AEMatter_inference': run_AEMatter_inference,
1243
+ }
1244
+
1245
+ NODE_DISPLAY_NAME_MAPPINGS = {
1246
+ 'load_AEMatter_Model': 'load_AEMatter_Model',
1247
+ 'run_AEMatter_inference': 'run_AEMatter_inference',
1248
+ }
ComfyUI_AEMatter/AEMatter.run.sh ADDED
@@ -0,0 +1,3 @@
+ #!/bin/sh
+ . "${HOME}/dbnew.sh"
+ python3 './AEMatter.py'
ComfyUI_AEMatter/README.org ADDED
@@ -0,0 +1,1357 @@
1
+ * COMMENT SAMPLE
2
+
3
+ ** AEMatter.import.py
4
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.import.py
5
+ #+end_src
6
+
7
+ ** AEMatter.function.py
8
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
9
+ #+end_src
10
+
11
+ ** AEMatter.class.py
12
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
13
+ #+end_src
14
+
15
+ ** AEMatter.execute.py
16
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.execute.py
17
+ #+end_src
18
+
19
+ ** AEMatter.unify.sh
20
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./AEMatter.unify.sh
21
+ #+end_src
22
+
23
+ ** AEMatter.run.sh
24
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./AEMatter.run.sh
25
+ #+end_src
26
+
27
+ * Code for AEMatter inference
28
+
29
+ ** AEMatter.import.py
30
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.import.py
31
+ import cv2
32
+ import math
33
+ import numpy as np
34
+ import os
35
+ import random
36
+ import wget
37
+
38
+ import torch
39
+ import torch.nn as nn
40
+ from torch.nn import init
41
+ import torch.nn.functional as F
42
+ import torch.utils.checkpoint as checkpoint
43
+
44
+ from collections import OrderedDict
45
+ from einops import rearrange, repeat
46
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
47
+
48
+ import folder_paths
49
+ from folder_paths import models_dir
50
+ #+end_src
51
+
52
+ ** Functions to prepare directory structure and download models
53
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
54
+ def mkdir_safe(out_path):
55
+ if type(out_path) == str:
56
+ if len(out_path) > 0:
57
+ if not os.path.exists(out_path):
58
+ os.mkdir(out_path)
59
+
60
+
61
+ def get_model_path():
62
+ import folder_paths
63
+ from folder_paths import models_dir
64
+
65
+ path_file_model = models_dir
66
+ mkdir_safe(out_path=path_file_model)
67
+
68
+ path_file_model = os.path.join(path_file_model, 'AEMatter')
69
+ mkdir_safe(out_path=path_file_model)
70
+
71
+ path_file_model = os.path.join(path_file_model, 'AEM_RWA.ckpt')
72
+
73
+ return path_file_model
74
+
75
+
76
+ def download_model(path):
77
+ if not os.path.exists(path):
78
+ wget.download(
79
+ 'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/AEMatter/AEM_RWA.ckpt?download=true',
80
+ out=path)
81
+
82
+
83
+ def from_torch_image(image):
84
+ image = image.cpu().numpy() * 255.0
85
+ image = np.clip(image, 0, 255).astype(np.uint8)
86
+ return image
87
+
88
+
89
+ def to_torch_image(image):
90
+ image = image.astype(dtype=np.float32)
91
+ image /= 255.0
92
+ image = torch.from_numpy(image)
93
+ return image
94
+ #+end_src
95
+
96
+ ** AEMatter.function.py
97
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
98
+ def window_partition(x, window_size):
99
+ """
100
+ Args:
101
+ x: (B, H, W, C)
102
+ window_size (int): window size
103
+ Returns:
104
+ windows: (num_windows*B, window_size, window_size, C)
105
+ """
106
+ B, H, W, C = x.shape
107
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
108
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
109
+ return windows
110
+ #+end_src
111
+
112
+ ** AEMatter.function.py
113
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
114
+ def window_reverse(windows, window_size, H, W):
115
+ """
116
+ Args:
117
+ windows: (num_windows*B, window_size, window_size, C)
118
+ window_size (int): Window size
119
+ H (int): Height of image
120
+ W (int): Width of image
121
+ Returns:
122
+ x: (B, H, W, C)
123
+ """
124
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
125
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
126
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
127
+ return x
128
+ #+end_src
129
+
130
+ ** AEMatter.class.py
131
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
132
+ class WindowAttention(nn.Module):
133
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
134
+ It supports both of shifted and non-shifted window.
135
+ Args:
136
+ dim (int): Number of input channels.
137
+ window_size (tuple[int]): The height and width of the window.
138
+ num_heads (int): Number of attention heads.
139
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
140
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
141
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
142
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
143
+ """
144
+
145
+ def __init__(self,
146
+ dim,
147
+ window_size,
148
+ num_heads,
149
+ qkv_bias=True,
150
+ qk_scale=None,
151
+ attn_drop=0.,
152
+ proj_drop=0.):
153
+
154
+ super().__init__()
155
+ self.dim = dim
156
+ self.window_size = window_size # Wh, Ww
157
+ self.num_heads = num_heads
158
+ head_dim = dim // num_heads
159
+ self.scale = qk_scale or head_dim**-0.5
160
+
161
+ # define a parameter table of relative position bias
162
+ self.relative_position_bias_table = nn.Parameter(
163
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
164
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
165
+
166
+ # get pair-wise relative position index for each token inside the window
167
+ coords_h = torch.arange(self.window_size[0])
168
+ coords_w = torch.arange(self.window_size[1])
169
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
170
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
171
+ relative_coords = coords_flatten[:, :,
172
+ None] - coords_flatten[:,
173
+ None, :] # 2, Wh*Ww, Wh*Ww
174
+ relative_coords = relative_coords.permute(
175
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
176
+ relative_coords[:, :,
177
+ 0] += self.window_size[0] - 1 # shift to start from 0
178
+ relative_coords[:, :, 1] += self.window_size[1] - 1
179
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
180
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
181
+ self.register_buffer("relative_position_index",
182
+ relative_position_index)
183
+
184
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
185
+ self.attn_drop = nn.Dropout(attn_drop)
186
+ self.proj = nn.Linear(dim, dim)
187
+ self.proj_drop = nn.Dropout(proj_drop)
188
+
189
+ trunc_normal_(self.relative_position_bias_table, std=.02)
190
+ self.softmax = nn.Softmax(dim=-1)
191
+
192
+ def forward(self, x, mask=None):
193
+ """ Forward function.
194
+ Args:
195
+ x: input features with shape of (num_windows*B, N, C)
196
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
197
+ """
198
+ B_, N, C = x.shape
199
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
200
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
201
+ q, k, v = qkv[0], qkv[1], qkv[
202
+ 2] # make torchscript happy (cannot use tensor as tuple)
203
+
204
+ q = q * self.scale
205
+ attn = (q @ k.transpose(-2, -1))
206
+
207
+ relative_position_bias = self.relative_position_bias_table[
208
+ self.relative_position_index.view(-1)].view(
209
+ self.window_size[0] * self.window_size[1],
210
+ self.window_size[0] * self.window_size[1],
211
+ -1) # Wh*Ww,Wh*Ww,nH
212
+ relative_position_bias = relative_position_bias.permute(
213
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
214
+ attn = attn + relative_position_bias.unsqueeze(0)
215
+
216
+ if mask is not None:
217
+ nW = mask.shape[0]
218
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
219
+ N) + mask.unsqueeze(1).unsqueeze(0)
220
+ attn = attn.view(-1, self.num_heads, N, N)
221
+ attn = self.softmax(attn)
222
+ else:
223
+ attn = self.softmax(attn)
224
+
225
+ attn = self.attn_drop(attn)
226
+
227
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
228
+ x = self.proj(x)
229
+ x = self.proj_drop(x)
230
+ return x
231
+ #+end_src
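+ 
+ A minimal shape check for the attention module above (assuming the imports from AEMatter.import.py, including timm's trunc_normal_, are available). It maps window tokens to window tokens of the same shape:
+ #+begin_src python :results output :tangle no
+ import torch
+ 
+ attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
+ tokens = torch.randn(8, 49, 96)   # (num_windows*B, Wh*Ww, C)
+ out = attn(tokens)                # no mask: plain W-MSA
+ assert out.shape == (8, 49, 96)
+ #+end_src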
232
+
233
+ ** AEMatter.class.py
234
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
235
+ class SwinTransformerBlock(nn.Module):
236
+ """ Swin Transformer Block.
237
+ Args:
238
+ dim (int): Number of input channels.
239
+ num_heads (int): Number of attention heads.
240
+ window_size (int): Window size.
241
+ shift_size (int): Shift size for SW-MSA.
242
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
243
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
244
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
245
+ drop (float, optional): Dropout rate. Default: 0.0
246
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
247
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
248
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
249
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
250
+ """
251
+
252
+ def __init__(self,
253
+ dim,
254
+ num_heads,
255
+ window_size=7,
256
+ shift_size=0,
257
+ mlp_ratio=4.,
258
+ qkv_bias=True,
259
+ qk_scale=None,
260
+ drop=0.,
261
+ attn_drop=0.,
262
+ drop_path=0.,
263
+ act_layer=nn.GELU,
264
+ norm_layer=nn.LayerNorm):
265
+ super().__init__()
266
+ self.dim = dim
267
+ self.num_heads = num_heads
268
+ self.window_size = window_size
269
+ self.shift_size = shift_size
270
+ self.mlp_ratio = mlp_ratio
271
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
272
+
273
+ self.norm1 = norm_layer(dim)
274
+ self.attn = WindowAttention(dim,
275
+ window_size=to_2tuple(self.window_size),
276
+ num_heads=num_heads,
277
+ qkv_bias=qkv_bias,
278
+ qk_scale=qk_scale,
279
+ attn_drop=attn_drop,
280
+ proj_drop=drop)
281
+
282
+ self.drop_path = DropPath(
283
+ drop_path) if drop_path > 0. else nn.Identity()
284
+ self.norm2 = norm_layer(dim)
285
+ mlp_hidden_dim = int(dim * mlp_ratio)
286
+ self.mlp = Mlp(in_features=dim,
287
+ hidden_features=mlp_hidden_dim,
288
+ act_layer=act_layer,
289
+ drop=drop)
290
+
291
+ self.H = None
292
+ self.W = None
293
+
294
+ def forward(self, x, mask_matrix):
295
+ """ Forward function.
296
+ Args:
297
+ x: Input feature, tensor size (B, H*W, C).
298
+ H, W: Spatial resolution of the input feature.
299
+ mask_matrix: Attention mask for cyclic shift.
300
+ """
301
+ B, L, C = x.shape
302
+ H, W = self.H, self.W
303
+ assert L == H * W, "input feature has wrong size"
304
+
305
+ shortcut = x
306
+ x = self.norm1(x)
307
+ x = x.view(B, H, W, C)
308
+
309
+ # pad feature maps to multiples of window size
310
+ pad_l = pad_t = 0
311
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
312
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
313
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
314
+ _, Hp, Wp, _ = x.shape
315
+
316
+ # cyclic shift
317
+ if self.shift_size > 0:
318
+ shifted_x = torch.roll(x,
319
+ shifts=(-self.shift_size, -self.shift_size),
320
+ dims=(1, 2))
321
+ attn_mask = mask_matrix
322
+ else:
323
+ shifted_x = x
324
+ attn_mask = None
325
+
326
+ # partition windows
327
+ x_windows = window_partition(
328
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
329
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
330
+ C) # nW*B, window_size*window_size, C
331
+
332
+ # W-MSA/SW-MSA
333
+ attn_windows = self.attn(
334
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
335
+
336
+ # merge windows
337
+ attn_windows = attn_windows.view(-1, self.window_size,
338
+ self.window_size, C)
339
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp,
340
+ Wp) # B H' W' C
341
+
342
+ # reverse cyclic shift
343
+ if self.shift_size > 0:
344
+ x = torch.roll(shifted_x,
345
+ shifts=(self.shift_size, self.shift_size),
346
+ dims=(1, 2))
347
+ else:
348
+ x = shifted_x
349
+
350
+ if pad_r > 0 or pad_b > 0:
351
+ x = x[:, :H, :W, :].contiguous()
352
+
353
+ x = x.view(B, H * W, C)
354
+
355
+ # FFN
356
+ x = shortcut + self.drop_path(x)
357
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
358
+
359
+ return x
360
+ #+end_src
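+ 
+ The block reads its spatial resolution from the H and W attributes and only needs an attention mask when shift_size > 0. A minimal sketch (assuming all classes tangled into AEMatter.class.py are defined, since the block uses Mlp internally):
+ #+begin_src python :results output :tangle no
+ import torch
+ 
+ blk = SwinTransformerBlock(dim=96, num_heads=3, window_size=7, shift_size=0)
+ blk.H, blk.W = 14, 14                  # spatial size must be set before forward
+ x = torch.randn(1, 14 * 14, 96)        # (B, H*W, C)
+ out = blk(x, mask_matrix=None)         # shift_size == 0, so no mask is needed
+ assert out.shape == (1, 196, 96)
+ #+end_src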
361
+
362
+ ** AEMatter.class.py
363
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
364
+ class PatchMerging(nn.Module):
365
+ """ Patch Merging Layer
366
+ Args:
367
+ dim (int): Number of input channels.
368
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
369
+ """
370
+
371
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
372
+ super().__init__()
373
+ self.dim = dim
374
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
375
+ self.norm = norm_layer(4 * dim)
376
+
377
+ def forward(self, x, H, W):
378
+ """ Forward function.
379
+ Args:
380
+ x: Input feature, tensor size (B, H*W, C).
381
+ H, W: Spatial resolution of the input feature.
382
+ """
383
+ B, L, C = x.shape
384
+ assert L == H * W, "input feature has wrong size"
385
+
386
+ x = x.view(B, H, W, C)
387
+
388
+ # padding
389
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
390
+ if pad_input:
391
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
392
+
393
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
394
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
395
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
396
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
397
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
398
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
399
+
400
+ x = self.norm(x)
401
+ x = self.reduction(x)
402
+
403
+ return x
404
+ #+end_src
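+ 
+ Patch merging halves the spatial resolution and doubles the channel count by concatenating each 2x2 neighbourhood and projecting 4C to 2C. A minimal shape check under the same assumptions as above:
+ #+begin_src python :results output :tangle no
+ import torch
+ 
+ pm = PatchMerging(dim=96)
+ x = torch.randn(1, 14 * 14, 96)    # (B, H*W, C)
+ y = pm(x, 14, 14)
+ assert y.shape == (1, 7 * 7, 192)  # H/2 * W/2 tokens, 2*C channels
+ #+end_src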
405
+
406
+
407
+ ** AEMatter.class.py
408
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
409
+ class BasicLayer(nn.Module):
410
+ """ A basic Swin Transformer layer for one stage.
411
+ Args:
412
+ dim (int): Number of feature channels
413
+ depth (int): Number of blocks in this stage.
414
+ num_heads (int): Number of attention heads.
415
+ window_size (int): Local window size. Default: 7.
416
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
417
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
418
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
419
+ drop (float, optional): Dropout rate. Default: 0.0
420
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
421
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
422
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
423
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
424
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
425
+ """
426
+
427
+ def __init__(self,
428
+ dim,
429
+ depth,
430
+ num_heads,
431
+ window_size=7,
432
+ mlp_ratio=4.,
433
+ qkv_bias=True,
434
+ qk_scale=None,
435
+ drop=0.,
436
+ attn_drop=0.,
437
+ drop_path=0.,
438
+ norm_layer=nn.LayerNorm,
439
+ downsample=None,
440
+ use_checkpoint=False):
441
+
442
+ super().__init__()
443
+ self.window_size = window_size
444
+ self.shift_size = window_size // 2
445
+ self.depth = depth
446
+ self.use_checkpoint = use_checkpoint
447
+
448
+ # build blocks
449
+ self.blocks = nn.ModuleList([
450
+ SwinTransformerBlock(dim=dim,
451
+ num_heads=num_heads,
452
+ window_size=window_size,
453
+ shift_size=0 if
454
+ (i % 2 == 0) else window_size // 2,
455
+ mlp_ratio=mlp_ratio,
456
+ qkv_bias=qkv_bias,
457
+ qk_scale=qk_scale,
458
+ drop=drop,
459
+ attn_drop=attn_drop,
460
+ drop_path=drop_path[i] if isinstance(
461
+ drop_path, list) else drop_path,
462
+ norm_layer=norm_layer) for i in range(depth)
463
+ ])
464
+
465
+ # patch merging layer
466
+ if downsample is not None:
467
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
468
+ else:
469
+ self.downsample = None
470
+
471
+ def forward(self, x, H, W):
472
+ """ Forward function.
473
+ Args:
474
+ x: Input feature, tensor size (B, H*W, C).
475
+ H, W: Spatial resolution of the input feature.
476
+ """
477
+ # print(x.shape,H,W)
478
+ # calculate attention mask for SW-MSA
479
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
480
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
481
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
482
+ h_slices = (slice(0, -self.window_size),
483
+ slice(-self.window_size,
484
+ -self.shift_size), slice(-self.shift_size, None))
485
+ w_slices = (slice(0, -self.window_size),
486
+ slice(-self.window_size,
487
+ -self.shift_size), slice(-self.shift_size, None))
488
+ cnt = 0
489
+ for h in h_slices:
490
+ for w in w_slices:
491
+ img_mask[:, h, w, :] = cnt
492
+ cnt += 1
493
+
494
+ mask_windows = window_partition(
495
+ img_mask, self.window_size) # nW, window_size, window_size, 1
496
+
497
+ mask_windows = mask_windows.view(-1,
498
+ self.window_size * self.window_size)
499
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(
500
+ 2)  # nW, window_size*window_size, window_size*window_size
501
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
502
+ float(-100.0)).masked_fill(
503
+ attn_mask == 0, float(0.0))
504
+
505
+ for blk in self.blocks:
506
+ blk.H, blk.W = H, W
507
+ if self.use_checkpoint:
508
+ x = checkpoint.checkpoint(blk, x, attn_mask)
509
+ else:
510
+ x = blk(x, attn_mask)
511
+
512
+ if self.downsample is not None:
513
+ x_down = self.downsample(x, H, W)
514
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
515
+ return x, H, W, x_down, Wh, Ww
516
+ else:
517
+ return x, H, W, x, H, W
518
+ #+end_src
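+ 
+ One stage alternates W-MSA and SW-MSA blocks and can end with a PatchMerging downsample; the forward pass returns both the pre-downsample and post-downsample features together with their resolutions. A minimal sketch:
+ #+begin_src python :results output :tangle no
+ import torch
+ 
+ layer = BasicLayer(dim=96, depth=2, num_heads=3, window_size=7, downsample=PatchMerging)
+ x = torch.randn(1, 14 * 14, 96)
+ x_out, H, W, x_down, Wh, Ww = layer(x, 14, 14)
+ assert x_out.shape == (1, 196, 96)      # features at the input resolution
+ assert x_down.shape == (1, 49, 192)     # features after patch merging
+ assert (H, W, Wh, Ww) == (14, 14, 7, 7)
+ #+end_src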
519
+
520
+ ** AEMatter.class.py
521
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
522
+ class PatchEmbed(nn.Module):
523
+ """ Image to Patch Embedding
524
+ Args:
525
+ patch_size (int): Patch token size. Default: 4.
526
+ in_chans (int): Number of input image channels. Default: 3.
527
+ embed_dim (int): Number of linear projection output channels. Default: 96.
528
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
529
+ """
530
+
531
+ def __init__(self,
532
+ patch_size=4,
533
+ in_chans=3,
534
+ embed_dim=96,
535
+ norm_layer=None):
536
+
537
+ super().__init__()
538
+ patch_size = to_2tuple(patch_size)
539
+ self.patch_size = patch_size
540
+
541
+ self.in_chans = in_chans
542
+ self.embed_dim = embed_dim
543
+
544
+ self.proj = nn.Conv2d(in_chans,
545
+ embed_dim,
546
+ kernel_size=patch_size,
547
+ stride=patch_size)
548
+ if norm_layer is not None:
549
+ self.norm = norm_layer(embed_dim)
550
+ else:
551
+ self.norm = None
552
+
553
+ def forward(self, x):
554
+ """Forward function."""
555
+ # padding
556
+ _, _, H, W = x.size()
557
+ if W % self.patch_size[1] != 0:
558
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
559
+ if H % self.patch_size[0] != 0:
560
+ x = F.pad(x,
561
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
562
+
563
+ x = self.proj(x) # B C Wh Ww
564
+ if self.norm is not None:
565
+ Wh, Ww = x.size(2), x.size(3)
566
+ x = x.flatten(2).transpose(1, 2)
567
+ x = self.norm(x)
568
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
569
+
570
+ return x
571
+ #+end_src
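+ 
+ The patch embedding is just a strided convolution, so with the default patch_size=4 a 224x224 image becomes a 56x56 grid of 96-channel tokens. A minimal shape check:
+ #+begin_src python :results output :tangle no
+ import torch
+ import torch.nn as nn
+ 
+ pe = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
+ img = torch.randn(1, 3, 224, 224)
+ feat = pe(img)
+ assert feat.shape == (1, 96, 56, 56)
+ #+end_src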
572
+
573
+
574
+ ** AEMatter.class.py
575
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
576
+ class SwinTransformer(nn.Module):
577
+ """ Swin Transformer backbone.
578
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
579
+ https://arxiv.org/pdf/2103.14030
580
+ Args:
581
+ pretrain_img_size (int): Input image size for training the pretrained model,
582
+ used in absolute position embedding. Default 224.
583
+ patch_size (int | tuple(int)): Patch size. Default: 4.
584
+ in_chans (int): Number of input image channels. Default: 3.
585
+ embed_dim (int): Number of linear projection output channels. Default: 96.
586
+ depths (tuple[int]): Depths of each Swin Transformer stage.
587
+ num_heads (tuple[int]): Number of attention heads in each stage.
588
+ window_size (int): Window size. Default: 7.
589
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
590
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
591
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
592
+ drop_rate (float): Dropout rate.
593
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
594
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
595
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
596
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
597
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
598
+ out_indices (Sequence[int]): Output from which stages.
599
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
600
+ -1 means not freezing any parameters.
601
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
602
+ """
603
+
604
+ def __init__(self,
605
+ pretrain_img_size=224,
606
+ patch_size=4,
607
+ in_chans=3,
608
+ embed_dim=96,
609
+ depths=[2, 2, 6, 2],
610
+ num_heads=[3, 6, 12, 24],
611
+ window_size=7,
612
+ mlp_ratio=4.,
613
+ qkv_bias=True,
614
+ qk_scale=None,
615
+ drop_rate=0.,
616
+ attn_drop_rate=0.,
617
+ drop_path_rate=0.2,
618
+ norm_layer=nn.LayerNorm,
619
+ ape=False,
620
+ patch_norm=True,
621
+ out_indices=(0, 1, 2, 3),
622
+ frozen_stages=-1,
623
+ use_checkpoint=False):
624
+
625
+ super().__init__()
626
+
627
+ self.pretrain_img_size = pretrain_img_size
628
+ self.num_layers = len(depths)
629
+ self.embed_dim = embed_dim
630
+ self.ape = ape
631
+ self.patch_norm = patch_norm
632
+ self.out_indices = out_indices
633
+ self.frozen_stages = frozen_stages
634
+
635
+ # split image into non-overlapping patches
636
+ self.patch_embed = PatchEmbed(
637
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
638
+ norm_layer=norm_layer if self.patch_norm else None)
639
+
640
+ # absolute position embedding
641
+ if self.ape:
642
+ pretrain_img_size = to_2tuple(pretrain_img_size)
643
+ patch_size = to_2tuple(patch_size)
644
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
645
+
646
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
647
+ trunc_normal_(self.absolute_pos_embed, std=.02)
648
+
649
+ self.pos_drop = nn.Dropout(p=drop_rate)
650
+
651
+ # stochastic depth
652
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
653
+
654
+ # build layers
655
+ self.layers = nn.ModuleList()
656
+ for i_layer in range(self.num_layers):
657
+ layer = BasicLayer(
658
+ dim=int(embed_dim * 2 ** i_layer),
659
+ depth=depths[i_layer],
660
+ num_heads=num_heads[i_layer],
661
+ window_size=window_size,
662
+ mlp_ratio=mlp_ratio,
663
+ qkv_bias=qkv_bias,
664
+ qk_scale=qk_scale,
665
+ drop=drop_rate,
666
+ attn_drop=attn_drop_rate,
667
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
668
+ norm_layer=norm_layer,
669
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
670
+ use_checkpoint=use_checkpoint)
671
+ self.layers.append(layer)
672
+
673
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
674
+ self.num_features = num_features
675
+
676
+ # add a norm layer for each output
677
+ for i_layer in out_indices:
678
+ layer = norm_layer(num_features[i_layer])
679
+ layer_name = f'norm{i_layer}'
680
+ self.add_module(layer_name, layer)
681
+
682
+ self._freeze_stages()
683
+
684
+ def _freeze_stages(self):
685
+ if self.frozen_stages >= 0:
686
+ self.patch_embed.eval()
687
+ for param in self.patch_embed.parameters():
688
+ param.requires_grad = False
689
+
690
+ if self.frozen_stages >= 1 and self.ape:
691
+ self.absolute_pos_embed.requires_grad = False
692
+
693
+ if self.frozen_stages >= 2:
694
+ self.pos_drop.eval()
695
+ for i in range(0, self.frozen_stages - 1):
696
+ m = self.layers[i]
697
+ m.eval()
698
+ for param in m.parameters():
699
+ param.requires_grad = False
700
+
701
+ def init_weights(self, pretrained=None):
702
+ """Initialize the weights in backbone.
703
+ Args:
704
+ pretrained (str, optional): Path to pre-trained weights.
705
+ Defaults to None.
706
+ """
707
+
708
+
709
+ def forward(self, x):
710
+ """Forward function."""
711
+ x = self.patch_embed(x)
712
+
713
+ Wh, Ww = x.size(2), x.size(3)
714
+ if self.ape:
715
+ # interpolate the position embedding to the corresponding size
716
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
717
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
718
+ else:
719
+ x = x.flatten(2).transpose(1, 2)
720
+ x = self.pos_drop(x)
721
+
722
+ outs = []
723
+ for i in range(self.num_layers):
724
+ layer = self.layers[i]
725
+
726
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
727
+
728
+ if i in self.out_indices:
729
+ norm_layer = getattr(self, f'norm{i}')
730
+ x_out = norm_layer(x_out)
731
+
732
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
733
+ outs.append(out)
734
+
735
+ return tuple(outs)
736
+
737
+ def train(self, mode=True):
738
+ """Convert the model into training mode while keep layers freezed."""
739
+ super(SwinTransformer, self).train(mode)
740
+ self._freeze_stages()
741
+ #+end_src
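+ 
+ With the default configuration the backbone returns four feature maps at strides 4, 8, 16 and 32 with 96, 192, 384 and 768 channels. A minimal forward-pass sketch on a 3-channel image (note that AEMatter below swaps the patch-embedding convolution for a 64-channel input):
+ #+begin_src python :results output :tangle no
+ import torch
+ 
+ backbone = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
+ img = torch.randn(1, 3, 224, 224)
+ with torch.no_grad():
+     feats = backbone(img)
+ print([tuple(f.shape) for f in feats])
+ # expected: [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]
+ #+end_src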
742
+
743
+ ** AEMatter.class.py
744
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
745
+ class Mlp(nn.Module):
746
+ """ Multilayer perceptron."""
747
+
748
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
749
+ super().__init__()
750
+ out_features = out_features or in_features
751
+ hidden_features = hidden_features or in_features
752
+ self.fc1 = nn.Linear(in_features, hidden_features)
753
+ self.act = act_layer()
754
+ self.fc2 = nn.Linear(hidden_features, out_features)
755
+ self.drop = nn.Dropout(drop)
756
+
757
+ def forward(self, x):
758
+ x = self.fc1(x)
759
+ x = self.act(x)
760
+ x = self.drop(x)
761
+ x = self.fc2(x)
762
+ x = self.drop(x)
763
+ return x
764
+ #+end_src
765
+
766
+
767
+ ** AEMatter.class.py
768
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
769
+ class ResBlock(nn.Module):
770
+
771
+ def __init__(self, inc, midc):
772
+ super(ResBlock, self).__init__()
773
+ self.conv1 = nn.Conv2d(inc,
774
+ midc,
775
+ kernel_size=1,
776
+ stride=1,
777
+ padding=0,
778
+ bias=True)
779
+ self.gn1 = nn.GroupNorm(16, midc)
780
+ self.conv2 = nn.Conv2d(midc,
781
+ midc,
782
+ kernel_size=3,
783
+ stride=1,
784
+ padding=1,
785
+ bias=True)
786
+ self.gn2 = nn.GroupNorm(16, midc)
787
+ self.conv3 = nn.Conv2d(midc,
788
+ inc,
789
+ kernel_size=1,
790
+ stride=1,
791
+ padding=0,
792
+ bias=True)
793
+ self.relu = nn.LeakyReLU(0.1)
794
+
795
+ def forward(self, x):
796
+ x_ = x
797
+ x = self.conv1(x)
798
+ x = self.gn1(x)
799
+ x = self.relu(x)
800
+ x = self.conv2(x)
801
+ x = self.gn2(x)
802
+ x = self.relu(x)
803
+ x = self.conv3(x)
804
+ x = x + x_
805
+ x = self.relu(x)
806
+ return x
807
+ #+end_src
808
+
809
+ ** AEMatter.class.py
810
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
811
+ class AEALblock(nn.Module):
812
+
813
+ def __init__(self,
814
+ d_model,
815
+ nhead,
816
+ dim_feedforward=512,
817
+ dropout=0.0,
818
+ layer_norm_eps=1e-5,
819
+ batch_first=True,
820
+ norm_first=False,
821
+ width=5):
822
+ super(AEALblock, self).__init__()
823
+ self.self_attn2 = nn.MultiheadAttention(d_model // 2,
824
+ nhead // 2,
825
+ dropout=dropout,
826
+ batch_first=batch_first)
827
+ self.self_attn1 = nn.MultiheadAttention(d_model // 2,
828
+ nhead // 2,
829
+ dropout=dropout,
830
+ batch_first=batch_first)
831
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
832
+ self.dropout = nn.Dropout(dropout)
833
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
834
+ self.norm_first = norm_first
835
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
836
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
837
+ self.dropout1 = nn.Dropout(dropout)
838
+ self.dropout2 = nn.Dropout(dropout)
839
+ self.activation = nn.ReLU()
840
+ self.width = width
841
+ self.trans = nn.Sequential(
842
+ nn.Conv2d(d_model + 512, d_model // 2, 1, 1, 0),
843
+ ResBlock(d_model // 2, d_model // 4),
844
+ nn.Conv2d(d_model // 2, d_model, 1, 1, 0))
845
+ self.gamma = nn.Parameter(torch.zeros(1))
846
+
847
+ def forward(
848
+ self,
849
+ src,
850
+ feats,
851
+ ):
852
+ src = self.gamma * self.trans(torch.cat([src, feats], 1)) + src
853
+ b, c, h, w = src.shape
854
+ x1 = src[:, 0:c // 2]
855
+ x1_ = rearrange(x1, 'b c (h1 h2) w -> b c h1 h2 w', h2=self.width)
856
+ x1_ = rearrange(x1_, 'b c h1 h2 w -> (b h1) (h2 w) c')
857
+ x2 = src[:, c // 2:]
858
+ x2_ = rearrange(x2, 'b c h (w1 w2) -> b c h w1 w2', w2=self.width)
859
+ x2_ = rearrange(x2_, 'b c h w1 w2 -> (b w1) (h w2) c')
860
+ x = rearrange(src, 'b c h w-> b (h w) c')
861
+ x = self.norm1(x + self._sa_block(x1_, x2_, h, w))
862
+ x = self.norm2(x + self._ff_block(x))
863
+ x = rearrange(x, 'b (h w) c->b c h w', h=h, w=w)
864
+ return x
865
+
866
+ def _sa_block(self, x1, x2, h, w):
867
+ x1 = self.self_attn1(x1,
868
+ x1,
869
+ x1,
870
+ attn_mask=None,
871
+ key_padding_mask=None,
872
+ need_weights=False)[0]
873
+
874
+ x2 = self.self_attn2(x2,
875
+ x2,
876
+ x2,
877
+ attn_mask=None,
878
+ key_padding_mask=None,
879
+ need_weights=False)[0]
880
+
881
+ x1 = rearrange(x1,
882
+ '(b h1) (h2 w) c-> b (h1 h2 w) c',
883
+ h2=self.width,
884
+ h1=h // self.width)
885
+ x2 = rearrange(x2,
886
+ ' (b w1) (h w2) c-> b (h w1 w2) c',
887
+ w2=self.width,
888
+ w1=w // self.width)
889
+ x = torch.cat([x1, x2], dim=2)
890
+ return self.dropout1(x)
891
+
892
+ def _ff_block(self, x):
893
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
894
+ return self.dropout2(x)
895
+ #+end_src
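+ 
+ AEALblock first injects the 512-channel context features through a convolutional branch scaled by a zero-initialised gamma, then splits the channels in half and runs axial self-attention: the first half attends within horizontal strips of height width, the second half within vertical strips of width width, followed by a shared feed-forward layer. A minimal shape sketch (H and W must be divisible by width; assumes torch and einops are imported as in AEMatter.import.py):
+ #+begin_src python :results output :tangle no
+ import torch
+ 
+ block = AEALblock(d_model=640, nhead=20, dim_feedforward=2048, dropout=0.0, width=5)
+ src = torch.randn(1, 640, 20, 20)     # transformer features
+ feats = torch.randn(1, 512, 20, 20)   # context features concatenated inside the block
+ out = block(src, feats)
+ assert out.shape == (1, 640, 20, 20)
+ #+end_src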
896
+
897
+ ** AEMatter.class.py
898
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
899
+ class AEMatter(nn.Module):
900
+
901
+ def __init__(self):
902
+ super(AEMatter, self).__init__()
903
+ trans = SwinTransformer(pretrain_img_size=224,
904
+ embed_dim=96,
905
+ depths=[2, 2, 6, 2],
906
+ num_heads=[3, 6, 12, 24],
907
+ window_size=7,
908
+ ape=False,
909
+ drop_path_rate=0.2,
910
+ patch_norm=True,
911
+ use_checkpoint=False)
912
+
913
+ # trans.load_state_dict(torch.load(
914
+ # '/home/asd/Desktop/swin_tiny_patch4_window7_224.pth',
915
+ # map_location="cpu")["model"],
916
+ # strict=False)
917
+
918
+ trans.patch_embed.proj = nn.Conv2d(64, 96, 3, 2, 1)
919
+
920
+ self.start_conv0 = nn.Sequential(nn.Conv2d(6, 48, 3, 1, 1),
921
+ nn.PReLU(48))
922
+
923
+ self.start_conv = nn.Sequential(nn.Conv2d(48, 64, 3, 2,
924
+ 1), nn.PReLU(64),
925
+ nn.Conv2d(64, 64, 3, 1, 1),
926
+ nn.PReLU(64))
927
+
928
+ self.trans = trans
929
+ self.conv1 = nn.Sequential(
930
+ nn.Conv2d(in_channels=640 + 768,
931
+ out_channels=256,
932
+ kernel_size=1,
933
+ stride=1,
934
+ padding=0,
935
+ bias=True))
936
+ self.conv2 = nn.Sequential(
937
+ nn.Conv2d(in_channels=256 + 384,
938
+ out_channels=256,
939
+ kernel_size=1,
940
+ stride=1,
941
+ padding=0,
942
+ bias=True), )
943
+ self.conv3 = nn.Sequential(
944
+ nn.Conv2d(in_channels=256 + 192,
945
+ out_channels=192,
946
+ kernel_size=1,
947
+ stride=1,
948
+ padding=0,
949
+ bias=True), )
950
+ self.conv4 = nn.Sequential(
951
+ nn.Conv2d(in_channels=192 + 96,
952
+ out_channels=128,
953
+ kernel_size=1,
954
+ stride=1,
955
+ padding=0,
956
+ bias=True), )
957
+ self.ctran0 = BasicLayer(256, 3, 8, 7, drop_path=0.09)
958
+ self.ctran1 = BasicLayer(256, 3, 8, 7, drop_path=0.07)
959
+ self.ctran2 = BasicLayer(192, 3, 6, 7, drop_path=0.05)
960
+ self.ctran3 = BasicLayer(128, 3, 4, 7, drop_path=0.03)
961
+ self.conv5 = nn.Sequential(
962
+ nn.Conv2d(in_channels=192,
963
+ out_channels=64,
964
+ kernel_size=3,
965
+ stride=1,
966
+ padding=1,
967
+ bias=True), nn.PReLU(64),
968
+ nn.Conv2d(in_channels=64,
969
+ out_channels=64,
970
+ kernel_size=3,
971
+ stride=1,
972
+ padding=1,
973
+ bias=True), nn.PReLU(64),
974
+ nn.Conv2d(in_channels=64,
975
+ out_channels=48,
976
+ kernel_size=3,
977
+ stride=1,
978
+ padding=1,
979
+ bias=True), nn.PReLU(48))
980
+ self.convo = nn.Sequential(
981
+ nn.Conv2d(in_channels=48 + 48 + 6,
982
+ out_channels=32,
983
+ kernel_size=3,
984
+ stride=1,
985
+ padding=1,
986
+ bias=True), nn.PReLU(32),
987
+ nn.Conv2d(in_channels=32,
988
+ out_channels=32,
989
+ kernel_size=3,
990
+ stride=1,
991
+ padding=1,
992
+ bias=True), nn.PReLU(32),
993
+ nn.Conv2d(in_channels=32,
994
+ out_channels=1,
995
+ kernel_size=3,
996
+ stride=1,
997
+ padding=1,
998
+ bias=True))
999
+ self.up = nn.Upsample(scale_factor=2,
1000
+ mode='bilinear',
1001
+ align_corners=False)
1002
+ self.upn = nn.Upsample(scale_factor=2, mode='nearest')
1003
+ self.apptrans = nn.Sequential(
1004
+ nn.Conv2d(256 + 384, 256, 1, 1, bias=True), ResBlock(256, 128),
1005
+ ResBlock(256, 128), nn.Conv2d(256, 512, 2, 2, bias=True),
1006
+ ResBlock(512, 128))
1007
+ self.emb = nn.Sequential(nn.Conv2d(768, 640, 1, 1, 0),
1008
+ ResBlock(640, 160))
1009
+ self.embdp = nn.Sequential(nn.Conv2d(640, 640, 1, 1, 0))
1010
+ self.h2l = nn.Conv2d(768, 256, 1, 1, 0)
1011
+ self.width = 5
1012
+ self.trans1 = AEALblock(d_model=640,
1013
+ nhead=20,
1014
+ dim_feedforward=2048,
1015
+ dropout=0.2,
1016
+ width=self.width)
1017
+ self.trans2 = AEALblock(d_model=640,
1018
+ nhead=20,
1019
+ dim_feedforward=2048,
1020
+ dropout=0.2,
1021
+ width=self.width)
1022
+ self.trans3 = AEALblock(d_model=640,
1023
+ nhead=20,
1024
+ dim_feedforward=2048,
1025
+ dropout=0.2,
1026
+ width=self.width)
1027
+
1028
+ def aeal(self, x, sem):
1029
+ xe = self.emb(x)
1030
+ x_ = xe
1031
+ x_ = self.embdp(x_)
1032
+ b, c, h1, w1 = x_.shape
1033
+ bnew_ph = int(np.ceil(h1 / self.width) * self.width) - h1
1034
+ bnew_pw = int(np.ceil(w1 / self.width) * self.width) - w1
1035
+ newph1 = bnew_ph // 2
1036
+ newph2 = bnew_ph - newph1
1037
+ newpw1 = bnew_pw // 2
1038
+ newpw2 = bnew_pw - newpw1
1039
+ x_ = F.pad(x_, (newpw1, newpw2, newph1, newph2))
1040
+ sem = F.pad(sem, (newpw1, newpw2, newph1, newph2))
1041
+ x_ = self.trans1(x_, sem)
1042
+ x_ = self.trans2(x_, sem)
1043
+ x_ = self.trans3(x_, sem)
1044
+ x_ = x_[:, :, newph1:h1 + newph1, newpw1:w1 + newpw1]
1045
+ return x_
1046
+
1047
+ def forward(self, x, y):
1048
+ inputs = torch.cat((x, y), 1)
1049
+ x = self.start_conv0(inputs)
1050
+ x_ = self.start_conv(x)
1051
+ x1, x2, x3, x4 = self.trans(x_)
1052
+ x4h = self.h2l(x4)
1053
+ x3s = self.apptrans(torch.cat([x3, self.upn(x4h)], 1))
1054
+ x4_ = self.aeal(x4, x3s)
1055
+ x4 = torch.cat((x4, x4_), 1)
1056
+ X4 = self.conv1(x4)
1057
+ wh, ww = X4.shape[2], X4.shape[3]
1058
+ X4 = rearrange(X4, 'b c h w -> b (h w) c')
1059
+ X4, _, _, _, _, _ = self.ctran0(X4, wh, ww)
1060
+ X4 = rearrange(X4, 'b (h w) c -> b c h w', h=wh, w=ww)
1061
+ X3 = self.up(X4)
1062
+ X3 = torch.cat((x3, X3), 1)
1063
+ X3 = self.conv2(X3)
1064
+ wh, ww = X3.shape[2], X3.shape[3]
1065
+ X3 = rearrange(X3, 'b c h w -> b (h w) c')
1066
+ X3, _, _, _, _, _ = self.ctran1(X3, wh, ww)
1067
+ X3 = rearrange(X3, 'b (h w) c -> b c h w', h=wh, w=ww)
1068
+ X2 = self.up(X3)
1069
+ X2 = torch.cat((x2, X2), 1)
1070
+ X2 = self.conv3(X2)
1071
+ wh, ww = X2.shape[2], X2.shape[3]
1072
+ X2 = rearrange(X2, 'b c h w -> b (h w) c')
1073
+ X2, _, _, _, _, _ = self.ctran2(X2, wh, ww)
1074
+ X2 = rearrange(X2, 'b (h w) c -> b c h w', h=wh, w=ww)
1075
+ X1 = self.up(X2)
1076
+ X1 = torch.cat((x1, X1), 1)
1077
+ X1 = self.conv4(X1)
1078
+ wh, ww = X1.shape[2], X1.shape[3]
1079
+ X1 = rearrange(X1, 'b c h w -> b (h w) c')
1080
+ X1, _, _, _, _, _ = self.ctran3(X1, wh, ww)
1081
+ X1 = rearrange(X1, 'b (h w) c -> b c h w', h=wh, w=ww)
1082
+ X0 = self.up(X1)
1083
+ X0 = torch.cat((x_, X0), 1)
1084
+ X0 = self.conv5(X0)
1085
+ X = self.up(X0)
1086
+ X = torch.cat((inputs, x, X), 1)
1087
+ alpha = self.convo(X)
1088
+ alpha = torch.clamp(alpha, min=0, max=1)
1089
+ return alpha
1090
+ #+end_src
1091
+
1092
+ ** Function to load model
1093
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
1094
+ def get_AEMatter_model(path_model_checkpoint):
1095
+
1096
+ download_model(path=path_model_checkpoint)
1097
+
1098
+ matmodel = AEMatter()
1099
+ matmodel.load_state_dict(
1100
+ torch.load(path_model_checkpoint, map_location='cpu')['model'])
1101
+
1102
+ matmodel = matmodel.cuda()
1103
+ matmodel.eval()
1104
+
1105
+ return matmodel
1106
+ #+end_src
1107
+
1108
+ ** Function to do inference
1109
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
1110
+ def do_infer(rawimg, trimap, matmodel):
1111
+ trimap_nonp = trimap.copy()
1112
+ h, w, c = rawimg.shape
1113
+ nonph, nonpw, _ = rawimg.shape
1114
+ newh = (((h - 1) // 32) + 1) * 32
1115
+ neww = (((w - 1) // 32) + 1) * 32
1116
+ padh = newh - h
1117
+ padh1 = int(padh / 2)
1118
+ padh2 = padh - padh1
1119
+ padw = neww - w
1120
+ padw1 = int(padw / 2)
1121
+ padw2 = padw - padw1
1122
+
1123
+ rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
1124
+ cv2.BORDER_REFLECT)
1125
+
1126
+ trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
1127
+ cv2.BORDER_REFLECT)
1128
+
1129
+ h_pad, w_pad, _ = rawimg_pad.shape
1130
+ tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
1131
+ tritemp[:, :, 0] = (trimap_pad == 0)
1132
+ tritemp[:, :, 1] = (trimap_pad == 128)
1133
+ tritemp[:, :, 2] = (trimap_pad == 255)
1134
+ tritempimgs = np.transpose(tritemp, (2, 0, 1))
1135
+ tritempimgs = tritempimgs[np.newaxis, :, :, :]
1136
+ img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
1137
+ img = np.array(img, np.float32)
1138
+ img = img / 255.
1139
+ img = torch.from_numpy(img).cuda()
1140
+ tritempimgs = torch.from_numpy(tritempimgs).cuda()
1141
+ with torch.no_grad():
1142
+ pred = matmodel(img, tritempimgs)
1143
+ pred = pred.detach().cpu().numpy()[0]
1144
+ pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
1145
+ preda = pred[
1146
+ 0:1,
1147
+ ] * 255
1148
+ preda = np.transpose(preda, (1, 2, 0))
1149
+ preda = preda * (trimap_nonp[:, :, None]
1150
+ == 128) + (trimap_nonp[:, :, None] == 255) * 255
1151
+ preda = np.array(preda, np.uint8)
1152
+ return preda
1153
+ #+end_src
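+ 
+ A minimal standalone usage sketch (the image and trimap paths below are placeholders; the trimap must contain only the values 0, 128 and 255, and a CUDA device is required since the model is moved to the GPU):
+ #+begin_src python :results output :tangle no
+ import cv2
+ 
+ rawimg = cv2.imread('input_rgb.png', cv2.IMREAD_COLOR)         # BGR uint8
+ trimap = cv2.imread('input_trimap.png', cv2.IMREAD_GRAYSCALE)  # 0 / 128 / 255
+ 
+ # download_model() fetches the checkpoint to this path if it is missing
+ matmodel = get_AEMatter_model('checkpoints/AEMatter/AEM_RWA.ckpt')
+ 
+ alpha = do_infer(rawimg, trimap, matmodel)  # uint8 alpha matte, shape (H, W, 1)
+ cv2.imwrite('alpha.png', alpha)
+ #+end_src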
1154
+
1155
+ ** Load ComfyUI AEMatter model
1156
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.class.py
1157
+ class load_AEMatter_Model:
1158
+
1159
+ def __init__(self):
1160
+ pass
1161
+
1162
+ @classmethod
1163
+ def INPUT_TYPES(s):
1164
+ return {
1165
+ "required": {},
1166
+ }
1167
+
1168
+ RETURN_TYPES = ("AEMatter_Model", )
1169
+ FUNCTION = "test"
1170
+ CATEGORY = "AEMatter"
1171
+
1172
+ def test(self):
1173
+ return (get_AEMatter_model(get_model_path()), )
1174
+
1175
+
1176
+ class run_AEMatter_inference:
1177
+
1178
+ def __init__(self):
1179
+ pass
1180
+
1181
+ @classmethod
1182
+ def INPUT_TYPES(s):
1183
+ return {
1184
+ "required": {
1185
+ "image": ("IMAGE", ),
1186
+ "trimap": ("MASK", ),
1187
+ "AEMatter_Model": ("AEMatter_Model", ),
1188
+ },
1189
+ }
1190
+
1191
+ RETURN_TYPES = ("MASK", )
1192
+ FUNCTION = "test"
1193
+ CATEGORY = "AEMatter"
1194
+
1195
+ def test(
1196
+ self,
1197
+ image,
1198
+ trimap,
1199
+ AEMatter_Model,
1200
+ ):
1201
+
1202
+ ret = []
1203
+ batch_size = image.shape[0]
1204
+
1205
+ for i in range(batch_size):
1206
+ tmp_i = from_torch_image(image[i])
1207
+ tmp_m = from_torch_image(trimap[i])
1208
+ tmp = do_infer(tmp_i, tmp_m, AEMatter_Model)
1209
+ ret.append(tmp)
1210
+
1211
+ ret = to_torch_image(np.array(ret))
1212
+ ret = ret.squeeze(-1)
1213
+ print(ret.shape)
1214
+
1215
+ return (ret, )
1216
+ #+end_src
1217
+
1218
+ ** Main function
1219
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.function.py
1220
+ def main():
1221
+ ptrimap = '/home/asd/Desktop/demo/retriever_trimap.png'
1222
+ pimgs = '/home/asd/Desktop/demo/retriever_rgb.png'
1223
+ p_outs = 'alpha.png'
1224
+
1225
+ matmodel = get_AEMatter_model(
1226
+ path_model_checkpoint='/home/asd/Desktop/AEM_RWA.ckpt')
1227
+
1228
+ # matmodel = AEMatter()
1229
+ # matmodel.load_state_dict(
1230
+ # torch.load('/home/asd/Desktop/AEM_RWA.ckpt',
1231
+ # map_location='cpu')['model'])
1232
+
1233
+ # matmodel = matmodel.cuda()
1234
+ # matmodel.eval()
1235
+
1236
+ rawimg = pimgs
1237
+ trimap = ptrimap
1238
+ rawimg = cv2.imread(rawimg, cv2.IMREAD_COLOR)
1239
+ trimap = cv2.imread(trimap, cv2.IMREAD_GRAYSCALE)
1240
+ trimap_nonp = trimap.copy()
1241
+ h, w, c = rawimg.shape
1242
+ nonph, nonpw, _ = rawimg.shape
1243
+ newh = (((h - 1) // 32) + 1) * 32
1244
+ neww = (((w - 1) // 32) + 1) * 32
1245
+ padh = newh - h
1246
+ padh1 = int(padh / 2)
1247
+ padh2 = padh - padh1
1248
+ padw = neww - w
1249
+ padw1 = int(padw / 2)
1250
+ padw2 = padw - padw1
1251
+ rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
1252
+ cv2.BORDER_REFLECT)
1253
+ trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
1254
+ cv2.BORDER_REFLECT)
1255
+ h_pad, w_pad, _ = rawimg_pad.shape
1256
+ tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
1257
+ tritemp[:, :, 0] = (trimap_pad == 0)
1258
+ tritemp[:, :, 1] = (trimap_pad == 128)
1259
+ tritemp[:, :, 2] = (trimap_pad == 255)
1260
+ tritempimgs = np.transpose(tritemp, (2, 0, 1))
1261
+ tritempimgs = tritempimgs[np.newaxis, :, :, :]
1262
+ img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
1263
+ img = np.array(img, np.float32)
1264
+ img = img / 255.
1265
+ img = torch.from_numpy(img).cuda()
1266
+ tritempimgs = torch.from_numpy(tritempimgs).cuda()
1267
+ with torch.no_grad():
1268
+ pred = matmodel(img, tritempimgs)
1269
+ pred = pred.detach().cpu().numpy()[0]
1270
+ pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
1271
+ preda = pred[
1272
+ 0:1,
1273
+ ] * 255
1274
+ preda = np.transpose(preda, (1, 2, 0))
1275
+ preda = preda * (trimap_nonp[:, :, None]
1276
+ == 128) + (trimap_nonp[:, :, None] == 255) * 255
1277
+ preda = np.array(preda, np.uint8)
1278
+ cv2.imwrite(p_outs, preda)
1279
+
1280
+ #+end_src
1281
+
1282
+ ** ComfyUI Dictionary
1283
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.execute.py
1284
+ NODE_CLASS_MAPPINGS = {
1285
+ 'load_AEMatter_Model': load_AEMatter_Model,
1286
+ 'run_AEMatter_inference': run_AEMatter_inference,
1287
+ }
1288
+
1289
+ NODE_DISPLAY_NAME_MAPPINGS = {
1290
+ 'load_AEMatter_Model': 'load_AEMatter_Model',
1291
+ 'run_AEMatter_inference': 'run_AEMatter_inference',
1292
+ }
1293
+ #+end_src
1294
+
1295
+ ** COMMENT AEMatter.execute.py
1296
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./AEMatter.execute.py
1297
+ if __name__ == '__main__':
1298
+ # main()
1299
+
1300
+ rawimg = cv2.imread('/home/asd/Desktop/demo/retriever_rgb.png',
1301
+ cv2.IMREAD_COLOR)
1302
+
1303
+ trimap = cv2.imread('/home/asd/Desktop/demo/retriever_trimap.png',
1304
+ cv2.IMREAD_GRAYSCALE)
1305
+
1306
+ do_infer(rawimg, trimap,
1307
+ get_AEMatter_model('/home/asd/Desktop/AEM_RWA.ckpt'))
1308
+ #+end_src
1309
+
1310
+ ** AEMatter.unify.sh
1311
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./AEMatter.unify.sh
1312
+ . "${HOME}/dbnew.sh"
1313
+
1314
+ cat \
1315
+ 'AEMatter.import.py' \
1316
+ 'AEMatter.function.py' \
1317
+ 'AEMatter.class.py' \
1318
+ 'AEMatter.execute.py' \
1319
+ | expand | yapf3 \
1320
+ > 'AEMatter.py' \
1321
+ ;
1322
+
1323
+ cp 'AEMatter.py' '__init__.py'
1324
+ #+end_src
1325
+
1326
+ ** AEMatter.run.sh
1327
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./AEMatter.run.sh
1328
+ . "${HOME}/dbnew.sh"
1329
+ python3 './AEMatter.py'
1330
+ #+end_src
1331
+
1332
+ #+RESULTS:
1333
+
1334
+ * COMMENT WORK SPACE
1335
+
1336
+ ** ESHELL
1337
+ #+begin_src elisp
1338
+ (save-buffer)
1339
+ (org-babel-tangle)
1340
+ (shell-command "./AEMatter.unify.sh")
1341
+ #+end_src
1342
+
1343
+ #+RESULTS:
1344
+ : 0
1345
+
1346
+ ** SHELL
1347
+ #+begin_src sh :shebang #!/bin/sh :results output
1348
+ realpath .
1349
+ cd /home/asd/GITHUB/aravind-h-v/dreambooth_experiments/AEMatter
1350
+ #+end_src
1351
+
1352
+ #+RESULTS:
1353
+
1354
+ ** SHELL
1355
+ #+begin_src sh :shebang #!/bin/sh :results output
1356
+ ls
1357
+ #+end_src
ComfyUI_AEMatter/__init__.py ADDED
@@ -0,0 +1,1248 @@
1
+ #!/usr/bin/python3
2
+ import cv2
3
+ import math
4
+ import numpy as np
5
+ import os
6
+ import random
7
+ import wget
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import init
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint as checkpoint
14
+
15
+ from collections import OrderedDict
16
+ from einops import rearrange, repeat
17
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
18
+
19
+ import folder_paths
20
+ from folder_paths import models_dir
21
+
22
+
23
+ #!/usr/bin/python3
24
+ def mkdir_safe(out_path):
25
+ if type(out_path) == str:
26
+ if len(out_path) > 0:
27
+ if not os.path.exists(out_path):
28
+ os.mkdir(out_path)
29
+
30
+
31
+ def get_model_path():
32
+ import folder_paths
33
+ from folder_paths import models_dir
34
+
35
+ path_file_model = models_dir
36
+ mkdir_safe(out_path=path_file_model)
37
+
38
+ path_file_model = os.path.join(path_file_model, 'AEMatter')
39
+ mkdir_safe(out_path=path_file_model)
40
+
41
+ path_file_model = os.path.join(path_file_model, 'AEM_RWA.ckpt')
42
+
43
+ return path_file_model
44
+
45
+
46
+ def download_model(path):
47
+ if not os.path.exists(path):
48
+ wget.download(
49
+ 'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/AEMatter/AEM_RWA.ckpt?download=true',
50
+ out=path)
51
+
52
+
53
+ def from_torch_image(image):
54
+ image = image.cpu().numpy() * 255.0
55
+ image = np.clip(image, 0, 255).astype(np.uint8)
56
+ return image
57
+
58
+
59
+ def to_torch_image(image):
60
+ image = image.astype(dtype=np.float32)
61
+ image /= 255.0
62
+ image = torch.from_numpy(image)
63
+ return image
64
+
65
+
66
+ def window_partition(x, window_size):
67
+ """
68
+ Args:
69
+ x: (B, H, W, C)
70
+ window_size (int): window size
71
+ Returns:
72
+ windows: (num_windows*B, window_size, window_size, C)
73
+ """
74
+ B, H, W, C = x.shape
75
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
76
+ C)
77
+ windows = x.permute(0, 1, 3, 2, 4,
78
+ 5).contiguous().view(-1, window_size, window_size, C)
79
+ return windows
80
+
81
+
82
+ def window_reverse(windows, window_size, H, W):
83
+ """
84
+ Args:
85
+ windows: (num_windows*B, window_size, window_size, C)
86
+ window_size (int): Window size
87
+ H (int): Height of image
88
+ W (int): Width of image
89
+ Returns:
90
+ x: (B, H, W, C)
91
+ """
92
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
93
+ x = windows.view(B, H // window_size, W // window_size, window_size,
94
+ window_size, -1)
95
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
96
+ return x
97
+
98
+
99
+ def get_AEMatter_model(path_model_checkpoint):
100
+
101
+ download_model(path=path_model_checkpoint)
102
+
103
+ matmodel = AEMatter()
104
+ matmodel.load_state_dict(
105
+ torch.load(path_model_checkpoint, map_location='cpu')['model'])
106
+
107
+ matmodel = matmodel.cuda()
108
+ matmodel.eval()
109
+
110
+ return matmodel
111
+
112
+
113
+ def do_infer(rawimg, trimap, matmodel):
114
+ trimap_nonp = trimap.copy()
115
+ h, w, c = rawimg.shape
116
+ nonph, nonpw, _ = rawimg.shape
117
+ newh = (((h - 1) // 32) + 1) * 32
118
+ neww = (((w - 1) // 32) + 1) * 32
119
+ padh = newh - h
120
+ padh1 = int(padh / 2)
121
+ padh2 = padh - padh1
122
+ padw = neww - w
123
+ padw1 = int(padw / 2)
124
+ padw2 = padw - padw1
125
+
126
+ rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
127
+ cv2.BORDER_REFLECT)
128
+
129
+ trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
130
+ cv2.BORDER_REFLECT)
131
+
132
+ h_pad, w_pad, _ = rawimg_pad.shape
133
+ tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
134
+ tritemp[:, :, 0] = (trimap_pad == 0)
135
+ tritemp[:, :, 1] = (trimap_pad == 128)
136
+ tritemp[:, :, 2] = (trimap_pad == 255)
137
+ tritempimgs = np.transpose(tritemp, (2, 0, 1))
138
+ tritempimgs = tritempimgs[np.newaxis, :, :, :]
139
+ img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
140
+ img = np.array(img, np.float32)
141
+ img = img / 255.
142
+ img = torch.from_numpy(img).cuda()
143
+ tritempimgs = torch.from_numpy(tritempimgs).cuda()
144
+ with torch.no_grad():
145
+ pred = matmodel(img, tritempimgs)
146
+ pred = pred.detach().cpu().numpy()[0]
147
+ pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
148
+ preda = pred[
149
+ 0:1,
150
+ ] * 255
151
+ preda = np.transpose(preda, (1, 2, 0))
152
+ preda = preda * (trimap_nonp[:, :, None]
153
+ == 128) + (trimap_nonp[:, :, None] == 255) * 255
154
+ preda = np.array(preda, np.uint8)
155
+ return preda
156
+
157
+
158
+ def main():
159
+ ptrimap = '/home/asd/Desktop/demo/retriever_trimap.png'
160
+ pimgs = '/home/asd/Desktop/demo/retriever_rgb.png'
161
+ p_outs = 'alpha.png'
162
+
163
+ matmodel = get_AEMatter_model(
164
+ path_model_checkpoint='/home/asd/Desktop/AEM_RWA.ckpt')
165
+
166
+ # matmodel = AEMatter()
167
+ # matmodel.load_state_dict(
168
+ # torch.load('/home/asd/Desktop/AEM_RWA.ckpt',
169
+ # map_location='cpu')['model'])
170
+
171
+ # matmodel = matmodel.cuda()
172
+ # matmodel.eval()
173
+
174
+ rawimg = pimgs
175
+ trimap = ptrimap
176
+ rawimg = cv2.imread(rawimg, cv2.IMREAD_COLOR)
177
+ trimap = cv2.imread(trimap, cv2.IMREAD_GRAYSCALE)
178
+ trimap_nonp = trimap.copy()
179
+ h, w, c = rawimg.shape
180
+ nonph, nonpw, _ = rawimg.shape
181
+ newh = (((h - 1) // 32) + 1) * 32
182
+ neww = (((w - 1) // 32) + 1) * 32
183
+ padh = newh - h
184
+ padh1 = int(padh / 2)
185
+ padh2 = padh - padh1
186
+ padw = neww - w
187
+ padw1 = int(padw / 2)
188
+ padw2 = padw - padw1
189
+ rawimg_pad = cv2.copyMakeBorder(rawimg, padh1, padh2, padw1, padw2,
190
+ cv2.BORDER_REFLECT)
191
+ trimap_pad = cv2.copyMakeBorder(trimap, padh1, padh2, padw1, padw2,
192
+ cv2.BORDER_REFLECT)
193
+ h_pad, w_pad, _ = rawimg_pad.shape
194
+ tritemp = np.zeros([*trimap_pad.shape, 3], np.float32)
195
+ tritemp[:, :, 0] = (trimap_pad == 0)
196
+ tritemp[:, :, 1] = (trimap_pad == 128)
197
+ tritemp[:, :, 2] = (trimap_pad == 255)
198
+ tritempimgs = np.transpose(tritemp, (2, 0, 1))
199
+ tritempimgs = tritempimgs[np.newaxis, :, :, :]
200
+ img = np.transpose(rawimg_pad, (2, 0, 1))[np.newaxis, ::-1, :, :]
201
+ img = np.array(img, np.float32)
202
+ img = img / 255.
203
+ img = torch.from_numpy(img).cuda()
204
+ tritempimgs = torch.from_numpy(tritempimgs).cuda()
205
+ with torch.no_grad():
206
+ pred = matmodel(img, tritempimgs)
207
+ pred = pred.detach().cpu().numpy()[0]
208
+ pred = pred[:, padh1:padh1 + h, padw1:padw1 + w]
209
+ preda = pred[
210
+ 0:1,
211
+ ] * 255
212
+ preda = np.transpose(preda, (1, 2, 0))
213
+ preda = preda * (trimap_nonp[:, :, None]
214
+ == 128) + (trimap_nonp[:, :, None] == 255) * 255
215
+ preda = np.array(preda, np.uint8)
216
+ cv2.imwrite(p_outs, preda)
217
+
218
+
219
+ #!/usr/bin/python3
220
+ class WindowAttention(nn.Module):
221
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
222
+ It supports both of shifted and non-shifted window.
223
+ Args:
224
+ dim (int): Number of input channels.
225
+ window_size (tuple[int]): The height and width of the window.
226
+ num_heads (int): Number of attention heads.
227
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
228
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
229
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
230
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
231
+ """
232
+
233
+ def __init__(self,
234
+ dim,
235
+ window_size,
236
+ num_heads,
237
+ qkv_bias=True,
238
+ qk_scale=None,
239
+ attn_drop=0.,
240
+ proj_drop=0.):
241
+
242
+ super().__init__()
243
+ self.dim = dim
244
+ self.window_size = window_size # Wh, Ww
245
+ self.num_heads = num_heads
246
+ head_dim = dim // num_heads
247
+ self.scale = qk_scale or head_dim**-0.5
248
+
249
+ # define a parameter table of relative position bias
250
+ self.relative_position_bias_table = nn.Parameter(
251
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
252
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
253
+
254
+ # get pair-wise relative position index for each token inside the window
255
+ coords_h = torch.arange(self.window_size[0])
256
+ coords_w = torch.arange(self.window_size[1])
257
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
258
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
259
+ relative_coords = coords_flatten[:, :,
260
+ None] - coords_flatten[:,
261
+ None, :] # 2, Wh*Ww, Wh*Ww
262
+ relative_coords = relative_coords.permute(
263
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
264
+ relative_coords[:, :,
265
+ 0] += self.window_size[0] - 1 # shift to start from 0
266
+ relative_coords[:, :, 1] += self.window_size[1] - 1
267
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
268
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
269
+ self.register_buffer("relative_position_index",
270
+ relative_position_index)
271
+
272
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
273
+ self.attn_drop = nn.Dropout(attn_drop)
274
+ self.proj = nn.Linear(dim, dim)
275
+ self.proj_drop = nn.Dropout(proj_drop)
276
+
277
+ trunc_normal_(self.relative_position_bias_table, std=.02)
278
+ self.softmax = nn.Softmax(dim=-1)
279
+
280
+ def forward(self, x, mask=None):
281
+ """ Forward function.
282
+ Args:
283
+ x: input features with shape of (num_windows*B, N, C)
284
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
285
+ """
286
+ B_, N, C = x.shape
287
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
288
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
289
+ q, k, v = qkv[0], qkv[1], qkv[
290
+ 2] # make torchscript happy (cannot use tensor as tuple)
291
+
292
+ q = q * self.scale
293
+ attn = (q @ k.transpose(-2, -1))
294
+
295
+ relative_position_bias = self.relative_position_bias_table[
296
+ self.relative_position_index.view(-1)].view(
297
+ self.window_size[0] * self.window_size[1],
298
+ self.window_size[0] * self.window_size[1],
299
+ -1) # Wh*Ww,Wh*Ww,nH
300
+ relative_position_bias = relative_position_bias.permute(
301
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
302
+ attn = attn + relative_position_bias.unsqueeze(0)
303
+
304
+ if mask is not None:
305
+ nW = mask.shape[0]
306
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
307
+ N) + mask.unsqueeze(1).unsqueeze(0)
308
+ attn = attn.view(-1, self.num_heads, N, N)
309
+ attn = self.softmax(attn)
310
+ else:
311
+ attn = self.softmax(attn)
312
+
313
+ attn = self.attn_drop(attn)
314
+
315
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
316
+ x = self.proj(x)
317
+ x = self.proj_drop(x)
318
+ return x
319
+
320
+
321
+ class SwinTransformerBlock(nn.Module):
322
+ """ Swin Transformer Block.
323
+ Args:
324
+ dim (int): Number of input channels.
325
+ num_heads (int): Number of attention heads.
326
+ window_size (int): Window size.
327
+ shift_size (int): Shift size for SW-MSA.
328
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
329
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
330
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
331
+ drop (float, optional): Dropout rate. Default: 0.0
332
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
333
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
334
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
335
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
336
+ """
337
+
338
+ def __init__(self,
339
+ dim,
340
+ num_heads,
341
+ window_size=7,
342
+ shift_size=0,
343
+ mlp_ratio=4.,
344
+ qkv_bias=True,
345
+ qk_scale=None,
346
+ drop=0.,
347
+ attn_drop=0.,
348
+ drop_path=0.,
349
+ act_layer=nn.GELU,
350
+ norm_layer=nn.LayerNorm):
351
+ super().__init__()
352
+ self.dim = dim
353
+ self.num_heads = num_heads
354
+ self.window_size = window_size
355
+ self.shift_size = shift_size
356
+ self.mlp_ratio = mlp_ratio
357
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
358
+
359
+ self.norm1 = norm_layer(dim)
360
+ self.attn = WindowAttention(dim,
361
+ window_size=to_2tuple(self.window_size),
362
+ num_heads=num_heads,
363
+ qkv_bias=qkv_bias,
364
+ qk_scale=qk_scale,
365
+ attn_drop=attn_drop,
366
+ proj_drop=drop)
367
+
368
+ self.drop_path = DropPath(
369
+ drop_path) if drop_path > 0. else nn.Identity()
370
+ self.norm2 = norm_layer(dim)
371
+ mlp_hidden_dim = int(dim * mlp_ratio)
372
+ self.mlp = Mlp(in_features=dim,
373
+ hidden_features=mlp_hidden_dim,
374
+ act_layer=act_layer,
375
+ drop=drop)
376
+
377
+ self.H = None
378
+ self.W = None
379
+
380
+ def forward(self, x, mask_matrix):
381
+ """ Forward function.
382
+ Args:
383
+ x: Input feature, tensor size (B, H*W, C).
384
+ H, W: Spatial resolution of the input feature.
385
+ mask_matrix: Attention mask for cyclic shift.
386
+ """
387
+ B, L, C = x.shape
388
+ H, W = self.H, self.W
389
+ assert L == H * W, "input feature has wrong size"
390
+
391
+ shortcut = x
392
+ x = self.norm1(x)
393
+ x = x.view(B, H, W, C)
394
+
395
+ # pad feature maps to multiples of window size
396
+ pad_l = pad_t = 0
397
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
398
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
399
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
400
+ _, Hp, Wp, _ = x.shape
401
+
402
+ # cyclic shift
403
+ if self.shift_size > 0:
404
+ shifted_x = torch.roll(x,
405
+ shifts=(-self.shift_size, -self.shift_size),
406
+ dims=(1, 2))
407
+ attn_mask = mask_matrix
408
+ else:
409
+ shifted_x = x
410
+ attn_mask = None
411
+
412
+ # partition windows
413
+ x_windows = window_partition(
414
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
415
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
416
+ C) # nW*B, window_size*window_size, C
417
+
418
+ # W-MSA/SW-MSA
419
+ attn_windows = self.attn(
420
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
421
+
422
+ # merge windows
423
+ attn_windows = attn_windows.view(-1, self.window_size,
424
+ self.window_size, C)
425
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp,
426
+ Wp) # B H' W' C
427
+
428
+ # reverse cyclic shift
429
+ if self.shift_size > 0:
430
+ x = torch.roll(shifted_x,
431
+ shifts=(self.shift_size, self.shift_size),
432
+ dims=(1, 2))
433
+ else:
434
+ x = shifted_x
435
+
436
+ if pad_r > 0 or pad_b > 0:
437
+ x = x[:, :H, :W, :].contiguous()
438
+
439
+ x = x.view(B, H * W, C)
440
+
441
+ # FFN
442
+ x = shortcut + self.drop_path(x)
443
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
444
+
445
+ return x
446
+
447
+
448
+ class PatchMerging(nn.Module):
449
+ """ Patch Merging Layer
450
+ Args:
451
+ dim (int): Number of input channels.
452
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
453
+ """
454
+
455
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
456
+ super().__init__()
457
+ self.dim = dim
458
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
459
+ self.norm = norm_layer(4 * dim)
460
+
461
+ def forward(self, x, H, W):
462
+ """ Forward function.
463
+ Args:
464
+ x: Input feature, tensor size (B, H*W, C).
465
+ H, W: Spatial resolution of the input feature.
466
+ """
467
+ B, L, C = x.shape
468
+ assert L == H * W, "input feature has wrong size"
469
+
470
+ x = x.view(B, H, W, C)
471
+
472
+ # padding
473
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
474
+ if pad_input:
475
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
476
+
477
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
478
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
479
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
480
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
481
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
482
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
483
+
484
+ x = self.norm(x)
485
+ x = self.reduction(x)
486
+
487
+ return x
488
+
489
+
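# Illustrative sketch, separate from the file contents above: PatchMerging
# concatenates each 2x2 token neighbourhood and projects 4*C channels down to
# 2*C, halving the spatial resolution while doubling the channels. Shapes are
# assumptions chosen only for demonstration and rely on the imports at the top
# of this file.
pm = PatchMerging(dim=96)
tokens = torch.randn(1, 56 * 56, 96)   # (B, H*W, C) token layout used above
merged = pm(tokens, 56, 56)            # -> (1, 28 * 28, 192)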
490
+ class BasicLayer(nn.Module):
491
+ """ A basic Swin Transformer layer for one stage.
492
+ Args:
493
+ dim (int): Number of feature channels
494
+ depth (int): Depths of this stage.
495
+ num_heads (int): Number of attention head.
496
+ window_size (int): Local window size. Default: 7.
497
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
498
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
499
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
500
+ drop (float, optional): Dropout rate. Default: 0.0
501
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
502
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
503
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
504
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
505
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
506
+ """
507
+
508
+ def __init__(self,
509
+ dim,
510
+ depth,
511
+ num_heads,
512
+ window_size=7,
513
+ mlp_ratio=4.,
514
+ qkv_bias=True,
515
+ qk_scale=None,
516
+ drop=0.,
517
+ attn_drop=0.,
518
+ drop_path=0.,
519
+ norm_layer=nn.LayerNorm,
520
+ downsample=None,
521
+ use_checkpoint=False):
522
+
523
+ super().__init__()
524
+ self.window_size = window_size
525
+ self.shift_size = window_size // 2
526
+ self.depth = depth
527
+ self.use_checkpoint = use_checkpoint
528
+
529
+ # build blocks
530
+ self.blocks = nn.ModuleList([
531
+ SwinTransformerBlock(dim=dim,
532
+ num_heads=num_heads,
533
+ window_size=window_size,
534
+ shift_size=0 if
535
+ (i % 2 == 0) else window_size // 2,
536
+ mlp_ratio=mlp_ratio,
537
+ qkv_bias=qkv_bias,
538
+ qk_scale=qk_scale,
539
+ drop=drop,
540
+ attn_drop=attn_drop,
541
+ drop_path=drop_path[i] if isinstance(
542
+ drop_path, list) else drop_path,
543
+ norm_layer=norm_layer) for i in range(depth)
544
+ ])
545
+
546
+ # patch merging layer
547
+ if downsample is not None:
548
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
549
+ else:
550
+ self.downsample = None
551
+
552
+ def forward(self, x, H, W):
553
+ """ Forward function.
554
+ Args:
555
+ x: Input feature, tensor size (B, H*W, C).
556
+ H, W: Spatial resolution of the input feature.
557
+ """
558
+ # print(x.shape,H,W)
559
+ # calculate attention mask for SW-MSA
560
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
561
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
562
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
563
+ h_slices = (slice(0, -self.window_size),
564
+ slice(-self.window_size,
565
+ -self.shift_size), slice(-self.shift_size, None))
566
+ w_slices = (slice(0, -self.window_size),
567
+ slice(-self.window_size,
568
+ -self.shift_size), slice(-self.shift_size, None))
569
+ cnt = 0
570
+ for h in h_slices:
571
+ for w in w_slices:
572
+ img_mask[:, h, w, :] = cnt
573
+ cnt += 1
574
+
575
+ mask_windows = window_partition(
576
+ img_mask, self.window_size) # nW, window_size, window_size, 1
577
+
578
+ mask_windows = mask_windows.view(-1,
579
+ self.window_size * self.window_size)
580
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(
581
+ 2)  # nW, window_size*window_size, window_size*window_size
582
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
583
+ float(-100.0)).masked_fill(
584
+ attn_mask == 0, float(0.0))
585
+
586
+ for blk in self.blocks:
587
+ blk.H, blk.W = H, W
588
+ if self.use_checkpoint:
589
+ x = checkpoint.checkpoint(blk, x, attn_mask)
590
+ else:
591
+ x = blk(x, attn_mask)
592
+
593
+ if self.downsample is not None:
594
+ x_down = self.downsample(x, H, W)
595
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
596
+ return x, H, W, x_down, Wh, Ww
597
+ else:
598
+ return x, H, W, x, H, W
599
+
600
+
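# Sketch of one stage in isolation (an assumption-laden example, not part of
# the file; it presumes the attention modules defined earlier in this file run
# on the default device): BasicLayer returns the stage output and, when a
# downsample module is given, the merged tokens for the next stage.
stage = BasicLayer(dim=96, depth=2, num_heads=3, window_size=7,
                   downsample=PatchMerging)
x_out, H, W, x_down, Wh, Ww = stage(torch.randn(1, 56 * 56, 96), 56, 56)
# x_out: (1, 3136, 96) at 56x56; x_down: (1, 784, 192) at Wh = Ww = 28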
601
+ class PatchEmbed(nn.Module):
602
+ """ Image to Patch Embedding
603
+ Args:
604
+ patch_size (int): Patch token size. Default: 4.
605
+ in_chans (int): Number of input image channels. Default: 3.
606
+ embed_dim (int): Number of linear projection output channels. Default: 96.
607
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
608
+ """
609
+
610
+ def __init__(self,
611
+ patch_size=4,
612
+ in_chans=3,
613
+ embed_dim=96,
614
+ norm_layer=None):
615
+
616
+ super().__init__()
617
+ patch_size = to_2tuple(patch_size)
618
+ self.patch_size = patch_size
619
+
620
+ self.in_chans = in_chans
621
+ self.embed_dim = embed_dim
622
+
623
+ self.proj = nn.Conv2d(in_chans,
624
+ embed_dim,
625
+ kernel_size=patch_size,
626
+ stride=patch_size)
627
+ if norm_layer is not None:
628
+ self.norm = norm_layer(embed_dim)
629
+ else:
630
+ self.norm = None
631
+
632
+ def forward(self, x):
633
+ """Forward function."""
634
+ # padding
635
+ _, _, H, W = x.size()
636
+ if W % self.patch_size[1] != 0:
637
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
638
+ if H % self.patch_size[0] != 0:
639
+ x = F.pad(x,
640
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
641
+
642
+ x = self.proj(x) # B C Wh Ww
643
+ if self.norm is not None:
644
+ Wh, Ww = x.size(2), x.size(3)
645
+ x = x.flatten(2).transpose(1, 2)
646
+ x = self.norm(x)
647
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
648
+
649
+ return x
650
+
651
+
652
+ class SwinTransformer(nn.Module):
653
+ """ Swin Transformer backbone.
654
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
655
+ https://arxiv.org/pdf/2103.14030
656
+ Args:
657
+ pretrain_img_size (int): Input image size for training the pretrained model,
658
+ used in absolute position embedding. Default 224.
659
+ patch_size (int | tuple(int)): Patch size. Default: 4.
660
+ in_chans (int): Number of input image channels. Default: 3.
661
+ embed_dim (int): Number of linear projection output channels. Default: 96.
662
+ depths (tuple[int]): Depths of each Swin Transformer stage.
663
+ num_heads (tuple[int]): Number of attention head of each stage.
664
+ window_size (int): Window size. Default: 7.
665
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
666
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
667
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
668
+ drop_rate (float): Dropout rate.
669
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
670
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
671
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
672
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
673
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
674
+ out_indices (Sequence[int]): Output from which stages.
675
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
676
+ -1 means not freezing any parameters.
677
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
678
+ """
679
+
680
+ def __init__(self,
681
+ pretrain_img_size=224,
682
+ patch_size=4,
683
+ in_chans=3,
684
+ embed_dim=96,
685
+ depths=[2, 2, 6, 2],
686
+ num_heads=[3, 6, 12, 24],
687
+ window_size=7,
688
+ mlp_ratio=4.,
689
+ qkv_bias=True,
690
+ qk_scale=None,
691
+ drop_rate=0.,
692
+ attn_drop_rate=0.,
693
+ drop_path_rate=0.2,
694
+ norm_layer=nn.LayerNorm,
695
+ ape=False,
696
+ patch_norm=True,
697
+ out_indices=(0, 1, 2, 3),
698
+ frozen_stages=-1,
699
+ use_checkpoint=False):
700
+
701
+ super().__init__()
702
+
703
+ self.pretrain_img_size = pretrain_img_size
704
+ self.num_layers = len(depths)
705
+ self.embed_dim = embed_dim
706
+ self.ape = ape
707
+ self.patch_norm = patch_norm
708
+ self.out_indices = out_indices
709
+ self.frozen_stages = frozen_stages
710
+
711
+ # split image into non-overlapping patches
712
+ self.patch_embed = PatchEmbed(
713
+ patch_size=patch_size,
714
+ in_chans=in_chans,
715
+ embed_dim=embed_dim,
716
+ norm_layer=norm_layer if self.patch_norm else None)
717
+
718
+ # absolute position embedding
719
+ if self.ape:
720
+ pretrain_img_size = to_2tuple(pretrain_img_size)
721
+ patch_size = to_2tuple(patch_size)
722
+ patches_resolution = [
723
+ pretrain_img_size[0] // patch_size[0],
724
+ pretrain_img_size[1] // patch_size[1]
725
+ ]
726
+
727
+ self.absolute_pos_embed = nn.Parameter(
728
+ torch.zeros(1, embed_dim, patches_resolution[0],
729
+ patches_resolution[1]))
730
+ trunc_normal_(self.absolute_pos_embed, std=.02)
731
+
732
+ self.pos_drop = nn.Dropout(p=drop_rate)
733
+
734
+ # stochastic depth
735
+ dpr = [
736
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
737
+ ] # stochastic depth decay rule
738
+
739
+ # build layers
740
+ self.layers = nn.ModuleList()
741
+ for i_layer in range(self.num_layers):
742
+ layer = BasicLayer(
743
+ dim=int(embed_dim * 2**i_layer),
744
+ depth=depths[i_layer],
745
+ num_heads=num_heads[i_layer],
746
+ window_size=window_size,
747
+ mlp_ratio=mlp_ratio,
748
+ qkv_bias=qkv_bias,
749
+ qk_scale=qk_scale,
750
+ drop=drop_rate,
751
+ attn_drop=attn_drop_rate,
752
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
753
+ norm_layer=norm_layer,
754
+ downsample=PatchMerging if
755
+ (i_layer < self.num_layers - 1) else None,
756
+ use_checkpoint=use_checkpoint)
757
+ self.layers.append(layer)
758
+
759
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
760
+ self.num_features = num_features
761
+
762
+ # add a norm layer for each output
763
+ for i_layer in out_indices:
764
+ layer = norm_layer(num_features[i_layer])
765
+ layer_name = f'norm{i_layer}'
766
+ self.add_module(layer_name, layer)
767
+
768
+ self._freeze_stages()
769
+
770
+ def _freeze_stages(self):
771
+ if self.frozen_stages >= 0:
772
+ self.patch_embed.eval()
773
+ for param in self.patch_embed.parameters():
774
+ param.requires_grad = False
775
+
776
+ if self.frozen_stages >= 1 and self.ape:
777
+ self.absolute_pos_embed.requires_grad = False
778
+
779
+ if self.frozen_stages >= 2:
780
+ self.pos_drop.eval()
781
+ for i in range(0, self.frozen_stages - 1):
782
+ m = self.layers[i]
783
+ m.eval()
784
+ for param in m.parameters():
785
+ param.requires_grad = False
786
+
787
+ def init_weights(self, pretrained=None):
788
+ """Initialize the weights in backbone.
789
+ Args:
790
+ pretrained (str, optional): Path to pre-trained weights.
791
+ Defaults to None.
792
+ """
793
+
794
+ def forward(self, x):
795
+ """Forward function."""
796
+ x = self.patch_embed(x)
797
+
798
+ Wh, Ww = x.size(2), x.size(3)
799
+ if self.ape:
800
+ # interpolate the position embedding to the corresponding size
801
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
802
+ size=(Wh, Ww),
803
+ mode='bicubic')
804
+ x = (x + absolute_pos_embed).flatten(2).transpose(1,
805
+ 2) # B Wh*Ww C
806
+ else:
807
+ x = x.flatten(2).transpose(1, 2)
808
+ x = self.pos_drop(x)
809
+
810
+ outs = []
811
+ for i in range(self.num_layers):
812
+ layer = self.layers[i]
813
+
814
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
815
+
816
+ if i in self.out_indices:
817
+ norm_layer = getattr(self, f'norm{i}')
818
+ x_out = norm_layer(x_out)
819
+
820
+ out = x_out.view(-1, H, W,
821
+ self.num_features[i]).permute(0, 3, 1,
822
+ 2).contiguous()
823
+ outs.append(out)
824
+
825
+ return tuple(outs)
826
+
827
+ def train(self, mode=True):
828
+ """Convert the model into training mode while keep layers freezed."""
829
+ super(SwinTransformer, self).train(mode)
830
+ self._freeze_stages()
831
+
832
+
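# Sketch of the backbone contract for a Swin-T style configuration (shapes are
# illustrative and assume the attention modules above run on the default
# device): the forward pass returns one feature map per stage at strides
# 4/8/16/32 with 96/192/384/768 channels.
backbone = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2],
                           num_heads=[3, 6, 12, 24], window_size=7)
backbone = backbone.eval()
feats = backbone(torch.randn(1, 3, 224, 224))
# feats[i].shape == (1, 96 * 2 ** i, 224 // (4 * 2 ** i), 224 // (4 * 2 ** i))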
833
+ class Mlp(nn.Module):
834
+ """ Multilayer perceptron."""
835
+
836
+ def __init__(self,
837
+ in_features,
838
+ hidden_features=None,
839
+ out_features=None,
840
+ act_layer=nn.GELU,
841
+ drop=0.):
842
+ super().__init__()
843
+ out_features = out_features or in_features
844
+ hidden_features = hidden_features or in_features
845
+ self.fc1 = nn.Linear(in_features, hidden_features)
846
+ self.act = act_layer()
847
+ self.fc2 = nn.Linear(hidden_features, out_features)
848
+ self.drop = nn.Dropout(drop)
849
+
850
+ def forward(self, x):
851
+ x = self.fc1(x)
852
+ x = self.act(x)
853
+ x = self.drop(x)
854
+ x = self.fc2(x)
855
+ x = self.drop(x)
856
+ return x
857
+
858
+
859
+ class ResBlock(nn.Module):
860
+
861
+ def __init__(self, inc, midc):
862
+ super(ResBlock, self).__init__()
863
+ self.conv1 = nn.Conv2d(inc,
864
+ midc,
865
+ kernel_size=1,
866
+ stride=1,
867
+ padding=0,
868
+ bias=True)
869
+ self.gn1 = nn.GroupNorm(16, midc)
870
+ self.conv2 = nn.Conv2d(midc,
871
+ midc,
872
+ kernel_size=3,
873
+ stride=1,
874
+ padding=1,
875
+ bias=True)
876
+ self.gn2 = nn.GroupNorm(16, midc)
877
+ self.conv3 = nn.Conv2d(midc,
878
+ inc,
879
+ kernel_size=1,
880
+ stride=1,
881
+ padding=0,
882
+ bias=True)
883
+ self.relu = nn.LeakyReLU(0.1)
884
+
885
+ def forward(self, x):
886
+ x_ = x
887
+ x = self.conv1(x)
888
+ x = self.gn1(x)
889
+ x = self.relu(x)
890
+ x = self.conv2(x)
891
+ x = self.gn2(x)
892
+ x = self.relu(x)
893
+ x = self.conv3(x)
894
+ x = x + x_
895
+ x = self.relu(x)
896
+ return x
897
+
898
+
899
+ class AEALblock(nn.Module):
900
+
901
+ def __init__(self,
902
+ d_model,
903
+ nhead,
904
+ dim_feedforward=512,
905
+ dropout=0.0,
906
+ layer_norm_eps=1e-5,
907
+ batch_first=True,
908
+ norm_first=False,
909
+ width=5):
910
+ super(AEALblock, self).__init__()
911
+ self.self_attn2 = nn.MultiheadAttention(d_model // 2,
912
+ nhead // 2,
913
+ dropout=dropout,
914
+ batch_first=batch_first)
915
+ self.self_attn1 = nn.MultiheadAttention(d_model // 2,
916
+ nhead // 2,
917
+ dropout=dropout,
918
+ batch_first=batch_first)
919
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
920
+ self.dropout = nn.Dropout(dropout)
921
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
922
+ self.norm_first = norm_first
923
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
924
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
925
+ self.dropout1 = nn.Dropout(dropout)
926
+ self.dropout2 = nn.Dropout(dropout)
927
+ self.activation = nn.ReLU()
928
+ self.width = width
929
+ self.trans = nn.Sequential(
930
+ nn.Conv2d(d_model + 512, d_model // 2, 1, 1, 0),
931
+ ResBlock(d_model // 2, d_model // 4),
932
+ nn.Conv2d(d_model // 2, d_model, 1, 1, 0))
933
+ self.gamma = nn.Parameter(torch.zeros(1))
934
+
935
+ def forward(
936
+ self,
937
+ src,
938
+ feats,
939
+ ):
940
+ src = self.gamma * self.trans(torch.cat([src, feats], 1)) + src
941
+ b, c, h, w = src.shape
942
+ x1 = src[:, 0:c // 2]
943
+ x1_ = rearrange(x1, 'b c (h1 h2) w -> b c h1 h2 w', h2=self.width)
944
+ x1_ = rearrange(x1_, 'b c h1 h2 w -> (b h1) (h2 w) c')
945
+ x2 = src[:, c // 2:]
946
+ x2_ = rearrange(x2, 'b c h (w1 w2) -> b c h w1 w2', w2=self.width)
947
+ x2_ = rearrange(x2_, 'b c h w1 w2 -> (b w1) (h w2) c')
948
+ x = rearrange(src, 'b c h w-> b (h w) c')
949
+ x = self.norm1(x + self._sa_block(x1_, x2_, h, w))
950
+ x = self.norm2(x + self._ff_block(x))
951
+ x = rearrange(x, 'b (h w) c->b c h w', h=h, w=w)
952
+ return x
953
+
954
+ def _sa_block(self, x1, x2, h, w):
955
+ x1 = self.self_attn1(x1,
956
+ x1,
957
+ x1,
958
+ attn_mask=None,
959
+ key_padding_mask=None,
960
+ need_weights=False)[0]
961
+
962
+ x2 = self.self_attn2(x2,
963
+ x2,
964
+ x2,
965
+ attn_mask=None,
966
+ key_padding_mask=None,
967
+ need_weights=False)[0]
968
+
969
+ x1 = rearrange(x1,
970
+ '(b h1) (h2 w) c-> b (h1 h2 w) c',
971
+ h2=self.width,
972
+ h1=h // self.width)
973
+ x2 = rearrange(x2,
974
+ ' (b w1) (h w2) c-> b (h w1 w2) c',
975
+ w2=self.width,
976
+ w1=w // self.width)
977
+ x = torch.cat([x1, x2], dim=2)
978
+ return self.dropout1(x)
979
+
980
+ def _ff_block(self, x):
981
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
982
+ return self.dropout2(x)
983
+
984
+
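# Sketch of the axis-split attention block above: one half of the channels
# attends within horizontal strips of height `width`, the other half within
# vertical strips of width `width`, so h and w must be multiples of `width`
# (AEMatter.aeal below pads to guarantee this). Shapes are illustrative only.
blk = AEALblock(d_model=640, nhead=20, width=5)
src = torch.randn(1, 640, 40, 40)    # features being refined
ctx = torch.randn(1, 512, 40, 40)    # appearance context passed as `feats`
out = blk(src, ctx)                  # -> (1, 640, 40, 40)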
985
+ class AEMatter(nn.Module):
986
+
987
+ def __init__(self):
988
+ super(AEMatter, self).__init__()
989
+ trans = SwinTransformer(pretrain_img_size=224,
990
+ embed_dim=96,
991
+ depths=[2, 2, 6, 2],
992
+ num_heads=[3, 6, 12, 24],
993
+ window_size=7,
994
+ ape=False,
995
+ drop_path_rate=0.2,
996
+ patch_norm=True,
997
+ use_checkpoint=False)
998
+
999
+ # trans.load_state_dict(torch.load(
1000
+ # '/home/asd/Desktop/swin_tiny_patch4_window7_224.pth',
1001
+ # map_location="cpu")["model"],
1002
+ # strict=False)
1003
+
1004
+ trans.patch_embed.proj = nn.Conv2d(64, 96, 3, 2, 1)
1005
+
1006
+ self.start_conv0 = nn.Sequential(nn.Conv2d(6, 48, 3, 1, 1),
1007
+ nn.PReLU(48))
1008
+
1009
+ self.start_conv = nn.Sequential(nn.Conv2d(48, 64, 3, 2,
1010
+ 1), nn.PReLU(64),
1011
+ nn.Conv2d(64, 64, 3, 1, 1),
1012
+ nn.PReLU(64))
1013
+
1014
+ self.trans = trans
1015
+ self.conv1 = nn.Sequential(
1016
+ nn.Conv2d(in_channels=640 + 768,
1017
+ out_channels=256,
1018
+ kernel_size=1,
1019
+ stride=1,
1020
+ padding=0,
1021
+ bias=True))
1022
+ self.conv2 = nn.Sequential(
1023
+ nn.Conv2d(in_channels=256 + 384,
1024
+ out_channels=256,
1025
+ kernel_size=1,
1026
+ stride=1,
1027
+ padding=0,
1028
+ bias=True), )
1029
+ self.conv3 = nn.Sequential(
1030
+ nn.Conv2d(in_channels=256 + 192,
1031
+ out_channels=192,
1032
+ kernel_size=1,
1033
+ stride=1,
1034
+ padding=0,
1035
+ bias=True), )
1036
+ self.conv4 = nn.Sequential(
1037
+ nn.Conv2d(in_channels=192 + 96,
1038
+ out_channels=128,
1039
+ kernel_size=1,
1040
+ stride=1,
1041
+ padding=0,
1042
+ bias=True), )
1043
+ self.ctran0 = BasicLayer(256, 3, 8, 7, drop_path=0.09)
1044
+ self.ctran1 = BasicLayer(256, 3, 8, 7, drop_path=0.07)
1045
+ self.ctran2 = BasicLayer(192, 3, 6, 7, drop_path=0.05)
1046
+ self.ctran3 = BasicLayer(128, 3, 4, 7, drop_path=0.03)
1047
+ self.conv5 = nn.Sequential(
1048
+ nn.Conv2d(in_channels=192,
1049
+ out_channels=64,
1050
+ kernel_size=3,
1051
+ stride=1,
1052
+ padding=1,
1053
+ bias=True), nn.PReLU(64),
1054
+ nn.Conv2d(in_channels=64,
1055
+ out_channels=64,
1056
+ kernel_size=3,
1057
+ stride=1,
1058
+ padding=1,
1059
+ bias=True), nn.PReLU(64),
1060
+ nn.Conv2d(in_channels=64,
1061
+ out_channels=48,
1062
+ kernel_size=3,
1063
+ stride=1,
1064
+ padding=1,
1065
+ bias=True), nn.PReLU(48))
1066
+ self.convo = nn.Sequential(
1067
+ nn.Conv2d(in_channels=48 + 48 + 6,
1068
+ out_channels=32,
1069
+ kernel_size=3,
1070
+ stride=1,
1071
+ padding=1,
1072
+ bias=True), nn.PReLU(32),
1073
+ nn.Conv2d(in_channels=32,
1074
+ out_channels=32,
1075
+ kernel_size=3,
1076
+ stride=1,
1077
+ padding=1,
1078
+ bias=True), nn.PReLU(32),
1079
+ nn.Conv2d(in_channels=32,
1080
+ out_channels=1,
1081
+ kernel_size=3,
1082
+ stride=1,
1083
+ padding=1,
1084
+ bias=True))
1085
+ self.up = nn.Upsample(scale_factor=2,
1086
+ mode='bilinear',
1087
+ align_corners=False)
1088
+ self.upn = nn.Upsample(scale_factor=2, mode='nearest')
1089
+ self.apptrans = nn.Sequential(
1090
+ nn.Conv2d(256 + 384, 256, 1, 1, bias=True), ResBlock(256, 128),
1091
+ ResBlock(256, 128), nn.Conv2d(256, 512, 2, 2, bias=True),
1092
+ ResBlock(512, 128))
1093
+ self.emb = nn.Sequential(nn.Conv2d(768, 640, 1, 1, 0),
1094
+ ResBlock(640, 160))
1095
+ self.embdp = nn.Sequential(nn.Conv2d(640, 640, 1, 1, 0))
1096
+ self.h2l = nn.Conv2d(768, 256, 1, 1, 0)
1097
+ self.width = 5
1098
+ self.trans1 = AEALblock(d_model=640,
1099
+ nhead=20,
1100
+ dim_feedforward=2048,
1101
+ dropout=0.2,
1102
+ width=self.width)
1103
+ self.trans2 = AEALblock(d_model=640,
1104
+ nhead=20,
1105
+ dim_feedforward=2048,
1106
+ dropout=0.2,
1107
+ width=self.width)
1108
+ self.trans3 = AEALblock(d_model=640,
1109
+ nhead=20,
1110
+ dim_feedforward=2048,
1111
+ dropout=0.2,
1112
+ width=self.width)
1113
+
1114
+ def aeal(self, x, sem):
1115
+ xe = self.emb(x)
1116
+ x_ = xe
1117
+ x_ = self.embdp(x_)
1118
+ b, c, h1, w1 = x_.shape
1119
+ bnew_ph = int(np.ceil(h1 / self.width) * self.width) - h1
1120
+ bnew_pw = int(np.ceil(w1 / self.width) * self.width) - w1
1121
+ newph1 = bnew_ph // 2
1122
+ newph2 = bnew_ph - newph1
1123
+ newpw1 = bnew_pw // 2
1124
+ newpw2 = bnew_pw - newpw1
1125
+ x_ = F.pad(x_, (newpw1, newpw2, newph1, newph2))
1126
+ sem = F.pad(sem, (newpw1, newpw2, newph1, newph2))
1127
+ x_ = self.trans1(x_, sem)
1128
+ x_ = self.trans2(x_, sem)
1129
+ x_ = self.trans3(x_, sem)
1130
+ x_ = x_[:, :, newph1:h1 + newph1, newpw1:w1 + newpw1]
1131
+ return x_
1132
+
1133
+ def forward(self, x, y):
1134
+ inputs = torch.cat((x, y), 1)
1135
+ x = self.start_conv0(inputs)
1136
+ x_ = self.start_conv(x)
1137
+ x1, x2, x3, x4 = self.trans(x_)
1138
+ x4h = self.h2l(x4)
1139
+ x3s = self.apptrans(torch.cat([x3, self.upn(x4h)], 1))
1140
+ x4_ = self.aeal(x4, x3s)
1141
+ x4 = torch.cat((x4, x4_), 1)
1142
+ X4 = self.conv1(x4)
1143
+ wh, ww = X4.shape[2], X4.shape[3]
1144
+ X4 = rearrange(X4, 'b c h w -> b (h w) c')
1145
+ X4, _, _, _, _, _ = self.ctran0(X4, wh, ww)
1146
+ X4 = rearrange(X4, 'b (h w) c -> b c h w', h=wh, w=ww)
1147
+ X3 = self.up(X4)
1148
+ X3 = torch.cat((x3, X3), 1)
1149
+ X3 = self.conv2(X3)
1150
+ wh, ww = X3.shape[2], X3.shape[3]
1151
+ X3 = rearrange(X3, 'b c h w -> b (h w) c')
1152
+ X3, _, _, _, _, _ = self.ctran1(X3, wh, ww)
1153
+ X3 = rearrange(X3, 'b (h w) c -> b c h w', h=wh, w=ww)
1154
+ X2 = self.up(X3)
1155
+ X2 = torch.cat((x2, X2), 1)
1156
+ X2 = self.conv3(X2)
1157
+ wh, ww = X2.shape[2], X2.shape[3]
1158
+ X2 = rearrange(X2, 'b c h w -> b (h w) c')
1159
+ X2, _, _, _, _, _ = self.ctran2(X2, wh, ww)
1160
+ X2 = rearrange(X2, 'b (h w) c -> b c h w', h=wh, w=ww)
1161
+ X1 = self.up(X2)
1162
+ X1 = torch.cat((x1, X1), 1)
1163
+ X1 = self.conv4(X1)
1164
+ wh, ww = X1.shape[2], X1.shape[3]
1165
+ X1 = rearrange(X1, 'b c h w -> b (h w) c')
1166
+ X1, _, _, _, _, _ = self.ctran3(X1, wh, ww)
1167
+ X1 = rearrange(X1, 'b (h w) c -> b c h w', h=wh, w=ww)
1168
+ X0 = self.up(X1)
1169
+ X0 = torch.cat((x_, X0), 1)
1170
+ X0 = self.conv5(X0)
1171
+ X = self.up(X0)
1172
+ X = torch.cat((inputs, x, X), 1)
1173
+ alpha = self.convo(X)
1174
+ alpha = torch.clamp(alpha, min=0, max=1)
1175
+ return alpha
1176
+
1177
+
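# Sketch of a raw forward pass (randomly initialised here, since the pretrained
# Swin checkpoint load is commented out above): the network takes an RGB image
# and a 3-channel trimap encoding, concatenates them internally to 6 channels,
# and returns a single-channel alpha matte clamped to [0, 1]. Running the full
# model on CPU like this is slow and only meant to illustrate the shapes.
net = AEMatter().eval()
rgb = torch.randn(1, 3, 512, 512)
trimap = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    alpha = net(rgb, trimap)   # -> (1, 1, 512, 512)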
1178
+ class load_AEMatter_Model:
1179
+
1180
+ def __init__(self):
1181
+ pass
1182
+
1183
+ @classmethod
1184
+ def INPUT_TYPES(s):
1185
+ return {
1186
+ "required": {},
1187
+ }
1188
+
1189
+ RETURN_TYPES = ("AEMatter_Model", )
1190
+ FUNCTION = "test"
1191
+ CATEGORY = "AEMatter"
1192
+
1193
+ def test(self):
1194
+ return (get_AEMatter_model(get_model_path()), )
1195
+
1196
+
1197
+ class run_AEMatter_inference:
1198
+
1199
+ def __init__(self):
1200
+ pass
1201
+
1202
+ @classmethod
1203
+ def INPUT_TYPES(s):
1204
+ return {
1205
+ "required": {
1206
+ "image": ("IMAGE", ),
1207
+ "trimap": ("MASK", ),
1208
+ "AEMatter_Model": ("AEMatter_Model", ),
1209
+ },
1210
+ }
1211
+
1212
+ RETURN_TYPES = ("MASK", )
1213
+ FUNCTION = "test"
1214
+ CATEGORY = "AEMatter"
1215
+
1216
+ def test(
1217
+ self,
1218
+ image,
1219
+ trimap,
1220
+ AEMatter_Model,
1221
+ ):
1222
+
1223
+ ret = []
1224
+ batch_size = image.shape[0]
1225
+
1226
+ for i in range(batch_size):
1227
+ tmp_i = from_torch_image(image[i])
1228
+ tmp_m = from_torch_image(trimap[i])
1229
+ tmp = do_infer(tmp_i, tmp_m, AEMatter_Model)
1230
+ ret.append(tmp)
1231
+
1232
+ ret = to_torch_image(np.array(ret))
1233
+ ret = ret.squeeze(-1)
1234
+ print(ret.shape)
1235
+
1236
+ return (ret, )
1237
+
1238
+
1239
+ #!/usr/bin/python3
1240
+ NODE_CLASS_MAPPINGS = {
1241
+ 'load_AEMatter_Model': load_AEMatter_Model,
1242
+ 'run_AEMatter_inference': run_AEMatter_inference,
1243
+ }
1244
+
1245
+ NODE_DISPLAY_NAME_MAPPINGS = {
1246
+ 'load_AEMatter_Model': 'load_AEMatter_Model',
1247
+ 'run_AEMatter_inference': 'run_AEMatter_inference',
1248
+ }
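# Sketch of manual use outside the ComfyUI graph. ComfyUI imports this package
# and registers the two node classes through NODE_CLASS_MAPPINGS above; the
# helpers relied on here (get_AEMatter_model, get_model_path, do_infer,
# from_torch_image, to_torch_image) are defined earlier in this file, and
# image_batch / trimap_batch are hypothetical tensors in ComfyUI's IMAGE
# (B, H, W, C) and MASK (B, H, W) layouts.
loader = load_AEMatter_Model()
(model, ) = loader.test()
runner = run_AEMatter_inference()
(mask, ) = runner.test(image=image_batch, trimap=trimap_batch,
                       AEMatter_Model=model)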
ComfyUI_MVANet/MVANet_inference.py ADDED
@@ -0,0 +1,1548 @@
1
+ #!/usr/bin/python3
2
+ import os
3
+ import sys
4
+
5
+ HOME_DIR = os.environ.get('HOME', '/root')
6
+ MVANET_SOURCE_DIR = HOME_DIR + '/GITHUB/qianyu-dlut/MVANet'
7
+ finetuned_MVANet_model_path = MVANET_SOURCE_DIR + '/model/Model_80.pth'
8
+ pretrained_SwinB_model_path = MVANET_SOURCE_DIR + '/model/swin_base_patch4_window12_384_22kto1k.pth'
9
+
10
+ import math
11
+ import numpy as np
12
+ import cv2
13
+ import wget
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.utils.checkpoint as checkpoint
19
+ from torch.autograd import Variable
20
+ from torch import nn
21
+ from torchvision import transforms
22
+
23
+ from einops import rearrange
24
+
25
+ from timm.models import load_checkpoint
26
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
27
+
28
+ torch_device = 'cuda'
29
+ torch_dtype = torch.float16
30
+
31
+
32
+ def check_mkdir(dir_name):
33
+ if not os.path.isdir(dir_name):
34
+ os.makedirs(dir_name)
35
+
36
+
37
+ def SwinT(pretrained=True):
38
+ model = SwinTransformer(embed_dim=96,
39
+ depths=[2, 2, 6, 2],
40
+ num_heads=[3, 6, 12, 24],
41
+ window_size=7)
42
+ if pretrained is True:
43
+ model.load_state_dict(torch.load(
44
+ 'data/backbone_ckpt/swin_tiny_patch4_window7_224.pth',
45
+ map_location='cpu')['model'],
46
+ strict=False)
47
+
48
+ return model
49
+
50
+
51
+ def SwinS(pretrained=True):
52
+ model = SwinTransformer(embed_dim=96,
53
+ depths=[2, 2, 18, 2],
54
+ num_heads=[3, 6, 12, 24],
55
+ window_size=7)
56
+ if pretrained is True:
57
+ model.load_state_dict(torch.load(
58
+ 'data/backbone_ckpt/swin_small_patch4_window7_224.pth',
59
+ map_location='cpu')['model'],
60
+ strict=False)
61
+
62
+ return model
63
+
64
+
65
+ def SwinB(pretrained=True):
66
+ model = SwinTransformer(embed_dim=128,
67
+ depths=[2, 2, 18, 2],
68
+ num_heads=[4, 8, 16, 32],
69
+ window_size=12)
70
+ if pretrained is True:
71
+ import os
72
+ model.load_state_dict(torch.load(pretrained_SwinB_model_path,
73
+ map_location='cpu')['model'],
74
+ strict=False)
75
+ return model
76
+
77
+
78
+ def SwinL(pretrained=True):
79
+ model = SwinTransformer(embed_dim=192,
80
+ depths=[2, 2, 18, 2],
81
+ num_heads=[6, 12, 24, 48],
82
+ window_size=12)
83
+ if pretrained is True:
84
+ model.load_state_dict(torch.load(
85
+ 'data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth',
86
+ map_location='cpu')['model'],
87
+ strict=False)
88
+
89
+ return model
90
+
91
+
92
+ def get_activation_fn(activation):
93
+ """Return an activation function given a string"""
94
+ if activation == "relu":
95
+ return F.relu
96
+ if activation == "gelu":
97
+ return F.gelu
98
+ if activation == "glu":
99
+ return F.glu
100
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
101
+
102
+
103
+ def make_cbr(in_dim, out_dim):
104
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
105
+ nn.BatchNorm2d(out_dim), nn.PReLU())
106
+
107
+
108
+ def make_cbg(in_dim, out_dim):
109
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
110
+ nn.BatchNorm2d(out_dim), nn.GELU())
111
+
112
+
113
+ def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
114
+ return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
115
+
116
+
117
+ def resize_as(x, y, interpolation='bilinear'):
118
+ return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
119
+
120
+
121
+ def image2patches(x):
122
+ """b c (hg h) (wg w) -> (hg wg b) c h w"""
123
+ x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
124
+ return x
125
+
126
+
127
+ def patches2image(x):
128
+ """(hg wg b) c h w -> b c (hg h) (wg w)"""
129
+ x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
130
+ return x
131
+
132
+
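# Sketch: image2patches splits an image into a 2x2 grid of crops stacked along
# the batch axis, and patches2image is its exact inverse. Sizes below are
# illustrative; both spatial dimensions must be even.
img = torch.randn(1, 3, 1024, 1024)
patches = image2patches(img)        # -> (4, 3, 512, 512)
restored = patches2image(patches)   # -> (1, 3, 1024, 1024), equal to img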
133
+ def window_partition(x, window_size):
134
+ """
135
+ Args:
136
+ x: (B, H, W, C)
137
+ window_size (int): window size
138
+
139
+ Returns:
140
+ windows: (num_windows*B, window_size, window_size, C)
141
+ """
142
+ B, H, W, C = x.shape
143
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
144
+ C)
145
+ windows = x.permute(0, 1, 3, 2, 4,
146
+ 5).contiguous().view(-1, window_size, window_size, C)
147
+ return windows
148
+
149
+
150
+ def window_reverse(windows, window_size, H, W):
151
+ """
152
+ Args:
153
+ windows: (num_windows*B, window_size, window_size, C)
154
+ window_size (int): Window size
155
+ H (int): Height of image
156
+ W (int): Width of image
157
+
158
+ Returns:
159
+ x: (B, H, W, C)
160
+ """
161
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
162
+ x = windows.view(B, H // window_size, W // window_size, window_size,
163
+ window_size, -1)
164
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
165
+ return x
166
+
167
+
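# Sketch: window_partition and window_reverse are inverses whenever H and W
# are multiples of the window size (the Swin blocks below pad to make sure of
# that). Shapes are illustrative.
x = torch.randn(2, 14, 14, 96)            # (B, H, W, C)
wins = window_partition(x, 7)             # -> (8, 7, 7, 96): 4 windows per image
x_back = window_reverse(wins, 7, 14, 14)  # -> (2, 14, 14, 96)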
168
+ def mkdir_safe(out_path):
169
+ if type(out_path) == str:
170
+ if len(out_path) > 0:
171
+ if not os.path.exists(out_path):
172
+ os.mkdir(out_path)
173
+
174
+
175
+ def get_model_path():
176
+ import folder_paths
177
+ from folder_paths import models_dir
178
+
179
+ path_file_model = models_dir
180
+ mkdir_safe(out_path=path_file_model)
181
+
182
+ path_file_model = os.path.join(path_file_model, 'MVANet')
183
+ mkdir_safe(out_path=path_file_model)
184
+
185
+ path_file_model = os.path.join(path_file_model, 'Model_80.pth')
186
+
187
+ return path_file_model
188
+
189
+
190
+ def download_model(path):
191
+ if not os.path.exists(path):
192
+ wget.download(
193
+ 'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/Model_80.pth',
194
+ out=path)
195
+
196
+
197
+ def load_model(model_checkpoint_path):
198
+ download_model(path=model_checkpoint_path)
199
+ torch.cuda.set_device(0)
200
+
201
+ net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
202
+
203
+ pretrained_dict = torch.load(model_checkpoint_path,
204
+ map_location=torch_device)
205
+
206
+ model_dict = net.state_dict()
207
+ pretrained_dict = {
208
+ k: v
209
+ for k, v in pretrained_dict.items() if k in model_dict
210
+ }
211
+ model_dict.update(pretrained_dict)
212
+ net.load_state_dict(model_dict)
213
+ net = net.to(dtype=torch_dtype, device=torch_device)
214
+ net.eval()
215
+ return net
216
+
217
+
218
+ def do_infer_tensor2tensor(img, net):
219
+
220
+ img_transform = transforms.Compose(
221
+ [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
222
+
223
+ h_, w_ = img.shape[1], img.shape[2]
224
+
225
+ with torch.no_grad():
226
+
227
+ img = rearrange(img, 'B H W C -> B C H W')
228
+
229
+ img_resize = torch.nn.functional.interpolate(input=img,
230
+ size=(1024, 1024),
231
+ mode='bicubic',
232
+ antialias=True)
233
+
234
+ img_var = img_transform(img_resize)
235
+ img_var = Variable(img_var)
236
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
237
+
238
+ mask = []
239
+
240
+ mask.append(net(img_var))
241
+
242
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
243
+ prediction = prediction.sigmoid()
244
+
245
+ prediction = torch.nn.functional.interpolate(input=prediction,
246
+ size=(h_, w_),
247
+ mode='bicubic',
248
+ antialias=True)
249
+
250
+ prediction = prediction.squeeze(0)
251
+ prediction = prediction.clamp(0, 1)
252
+ prediction = prediction.detach()
253
+ prediction = prediction.to(dtype=torch.float32, device='cpu')
254
+
255
+ return prediction
256
+
257
+
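# Sketch of the intended call, under several assumptions: a ComfyUI
# environment (get_model_path imports folder_paths), a CUDA device (load_model
# and torch_device above require one), and an inf_MVANet whose eval-mode
# forward returns a single logit map as the averaging code above expects. The
# input follows ComfyUI's IMAGE layout (B, H, W, C) with values in [0, 1].
net = load_model(get_model_path())
img = torch.rand(1, 720, 1280, 3)
mask = do_infer_tensor2tensor(img, net)   # -> (1, 720, 1280) float32 on CPU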
258
+ class Mlp(nn.Module):
259
+ """ Multilayer perceptron."""
260
+
261
+ def __init__(self,
262
+ in_features,
263
+ hidden_features=None,
264
+ out_features=None,
265
+ act_layer=nn.GELU,
266
+ drop=0.):
267
+ super().__init__()
268
+ out_features = out_features or in_features
269
+ hidden_features = hidden_features or in_features
270
+ self.fc1 = nn.Linear(in_features, hidden_features)
271
+ self.act = act_layer()
272
+ self.fc2 = nn.Linear(hidden_features, out_features)
273
+ self.drop = nn.Dropout(drop)
274
+
275
+ def forward(self, x):
276
+ x = self.fc1(x)
277
+ x = self.act(x)
278
+ x = self.drop(x)
279
+ x = self.fc2(x)
280
+ x = self.drop(x)
281
+ return x
282
+
283
+
284
+ class WindowAttention(nn.Module):
285
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
286
+ It supports both of shifted and non-shifted window.
287
+
288
+ Args:
289
+ dim (int): Number of input channels.
290
+ window_size (tuple[int]): The height and width of the window.
291
+ num_heads (int): Number of attention heads.
292
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
293
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
294
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
295
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
296
+ """
297
+
298
+ def __init__(self,
299
+ dim,
300
+ window_size,
301
+ num_heads,
302
+ qkv_bias=True,
303
+ qk_scale=None,
304
+ attn_drop=0.,
305
+ proj_drop=0.):
306
+
307
+ super().__init__()
308
+ self.dim = dim
309
+ self.window_size = window_size # Wh, Ww
310
+ self.num_heads = num_heads
311
+ head_dim = dim // num_heads
312
+ self.scale = qk_scale or head_dim**-0.5
313
+
314
+ # define a parameter table of relative position bias
315
+ self.relative_position_bias_table = nn.Parameter(
316
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
317
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
318
+
319
+ # get pair-wise relative position index for each token inside the window
320
+ coords_h = torch.arange(self.window_size[0])
321
+ coords_w = torch.arange(self.window_size[1])
322
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
323
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
324
+ relative_coords = coords_flatten[:, :,
325
+ None] - coords_flatten[:,
326
+ None, :] # 2, Wh*Ww, Wh*Ww
327
+ relative_coords = relative_coords.permute(
328
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
329
+ relative_coords[:, :,
330
+ 0] += self.window_size[0] - 1 # shift to start from 0
331
+ relative_coords[:, :, 1] += self.window_size[1] - 1
332
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
333
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
334
+ self.register_buffer("relative_position_index",
335
+ relative_position_index)
336
+
337
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
338
+ self.attn_drop = nn.Dropout(attn_drop)
339
+ self.proj = nn.Linear(dim, dim)
340
+ self.proj_drop = nn.Dropout(proj_drop)
341
+
342
+ trunc_normal_(self.relative_position_bias_table, std=.02)
343
+ self.softmax = nn.Softmax(dim=-1)
344
+
345
+ def forward(self, x, mask=None):
346
+ """ Forward function.
347
+
348
+ Args:
349
+ x: input features with shape of (num_windows*B, N, C)
350
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
351
+ """
352
+ x = x.to(dtype=torch_dtype, device=torch_device)
353
+ B_, N, C = x.shape
354
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
355
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
356
+ q, k, v = qkv[0], qkv[1], qkv[
357
+ 2] # make torchscript happy (cannot use tensor as tuple)
358
+
359
+ q = q * self.scale
360
+ attn = (q @ k.transpose(-2, -1))
361
+
362
+ relative_position_bias = self.relative_position_bias_table[
363
+ self.relative_position_index.view(-1)].view(
364
+ self.window_size[0] * self.window_size[1],
365
+ self.window_size[0] * self.window_size[1],
366
+ -1) # Wh*Ww,Wh*Ww,nH
367
+ relative_position_bias = relative_position_bias.permute(
368
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
369
+ attn = attn + relative_position_bias.unsqueeze(0)
370
+
371
+ if mask is not None:
372
+ nW = mask.shape[0]
373
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
374
+ N) + mask.unsqueeze(1).unsqueeze(0)
375
+ attn = attn.view(-1, self.num_heads, N, N)
376
+ attn = self.softmax(attn)
377
+ else:
378
+ attn = self.softmax(attn)
379
+
380
+ attn = self.attn_drop(attn)
381
+ attn = attn.to(dtype=torch_dtype, device=torch_device)
382
+ v = v.to(dtype=torch_dtype, device=torch_device)
383
+
384
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
385
+ x = self.proj(x)
386
+ x = self.proj_drop(x)
387
+ return x
388
+
389
+
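# Sketch: for a 7x7 window the relative-position bias table holds
# (2*7 - 1) * (2*7 - 1) = 169 learnable entries per head, gathered through the
# 49x49 relative_position_index buffer built above. Construction alone does
# not touch the GPU, so this inspection runs anywhere.
wa = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
print(wa.relative_position_bias_table.shape)   # torch.Size([169, 3])
print(wa.relative_position_index.shape)        # torch.Size([49, 49])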
390
+ class SwinTransformerBlock(nn.Module):
391
+ """ Swin Transformer Block.
392
+
393
+ Args:
394
+ dim (int): Number of input channels.
395
+ num_heads (int): Number of attention heads.
396
+ window_size (int): Window size.
397
+ shift_size (int): Shift size for SW-MSA.
398
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
399
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
400
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
401
+ drop (float, optional): Dropout rate. Default: 0.0
402
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
403
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
404
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
405
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
406
+ """
407
+
408
+ def __init__(self,
409
+ dim,
410
+ num_heads,
411
+ window_size=7,
412
+ shift_size=0,
413
+ mlp_ratio=4.,
414
+ qkv_bias=True,
415
+ qk_scale=None,
416
+ drop=0.,
417
+ attn_drop=0.,
418
+ drop_path=0.,
419
+ act_layer=nn.GELU,
420
+ norm_layer=nn.LayerNorm):
421
+ super().__init__()
422
+ self.dim = dim
423
+ self.num_heads = num_heads
424
+ self.window_size = window_size
425
+ self.shift_size = shift_size
426
+ self.mlp_ratio = mlp_ratio
427
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
428
+
429
+ self.norm1 = norm_layer(dim)
430
+ self.attn = WindowAttention(dim,
431
+ window_size=to_2tuple(self.window_size),
432
+ num_heads=num_heads,
433
+ qkv_bias=qkv_bias,
434
+ qk_scale=qk_scale,
435
+ attn_drop=attn_drop,
436
+ proj_drop=drop)
437
+
438
+ self.drop_path = DropPath(
439
+ drop_path) if drop_path > 0. else nn.Identity()
440
+ self.norm2 = norm_layer(dim)
441
+ mlp_hidden_dim = int(dim * mlp_ratio)
442
+ self.mlp = Mlp(in_features=dim,
443
+ hidden_features=mlp_hidden_dim,
444
+ act_layer=act_layer,
445
+ drop=drop)
446
+
447
+ self.H = None
448
+ self.W = None
449
+
450
+ def forward(self, x, mask_matrix):
451
+ """ Forward function.
452
+
453
+ Args:
454
+ x: Input feature, tensor size (B, H*W, C).
455
+ H, W: Spatial resolution of the input feature.
456
+ mask_matrix: Attention mask for cyclic shift.
457
+ """
458
+ B, L, C = x.shape
459
+ H, W = self.H, self.W
460
+ assert L == H * W, "input feature has wrong size"
461
+
462
+ shortcut = x
463
+ x = self.norm1(x)
464
+ x = x.view(B, H, W, C)
465
+
466
+ # pad feature maps to multiples of window size
467
+ pad_l = pad_t = 0
468
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
469
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
470
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
471
+ _, Hp, Wp, _ = x.shape
472
+
473
+ # cyclic shift
474
+ if self.shift_size > 0:
475
+ shifted_x = torch.roll(x,
476
+ shifts=(-self.shift_size, -self.shift_size),
477
+ dims=(1, 2))
478
+ attn_mask = mask_matrix
479
+ else:
480
+ shifted_x = x
481
+ attn_mask = None
482
+
483
+ # partition windows
484
+ x_windows = window_partition(
485
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
486
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
487
+ C) # nW*B, window_size*window_size, C
488
+
489
+ # W-MSA/SW-MSA
490
+ attn_windows = self.attn(
491
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
492
+
493
+ # merge windows
494
+ attn_windows = attn_windows.view(-1, self.window_size,
495
+ self.window_size, C)
496
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp,
497
+ Wp) # B H' W' C
498
+
499
+ # reverse cyclic shift
500
+ if self.shift_size > 0:
501
+ x = torch.roll(shifted_x,
502
+ shifts=(self.shift_size, self.shift_size),
503
+ dims=(1, 2))
504
+ else:
505
+ x = shifted_x
506
+
507
+ if pad_r > 0 or pad_b > 0:
508
+ x = x[:, :H, :W, :].contiguous()
509
+
510
+ x = x.view(B, H * W, C)
511
+
512
+ # FFN
513
+ x = shortcut + self.drop_path(x)
514
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
515
+
516
+ return x
517
+
518
+
519
+ class PatchMerging(nn.Module):
520
+ """ Patch Merging Layer
521
+
522
+ Args:
523
+ dim (int): Number of input channels.
524
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
525
+ """
526
+
527
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
528
+ super().__init__()
529
+ self.dim = dim
530
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
531
+ self.norm = norm_layer(4 * dim)
532
+
533
+ def forward(self, x, H, W):
534
+ """ Forward function.
535
+
536
+ Args:
537
+ x: Input feature, tensor size (B, H*W, C).
538
+ H, W: Spatial resolution of the input feature.
539
+ """
540
+ B, L, C = x.shape
541
+ assert L == H * W, "input feature has wrong size"
542
+
543
+ x = x.view(B, H, W, C)
544
+
545
+ # padding
546
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
547
+ if pad_input:
548
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
549
+
550
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
551
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
552
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
553
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
554
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
555
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
556
+
557
+ x = self.norm(x)
558
+ x = self.reduction(x)
559
+
560
+ return x
561
+
562
+
563
+ class BasicLayer(nn.Module):
564
+ """ A basic Swin Transformer layer for one stage.
565
+
566
+ Args:
567
+ dim (int): Number of feature channels
568
+ depth (int): Depths of this stage.
569
+ num_heads (int): Number of attention head.
570
+ window_size (int): Local window size. Default: 7.
571
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
572
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
573
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
574
+ drop (float, optional): Dropout rate. Default: 0.0
575
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
576
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
577
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
578
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
579
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
580
+ """
581
+
582
+ def __init__(self,
583
+ dim,
584
+ depth,
585
+ num_heads,
586
+ window_size=7,
587
+ mlp_ratio=4.,
588
+ qkv_bias=True,
589
+ qk_scale=None,
590
+ drop=0.,
591
+ attn_drop=0.,
592
+ drop_path=0.,
593
+ norm_layer=nn.LayerNorm,
594
+ downsample=None,
595
+ use_checkpoint=False):
596
+ super().__init__()
597
+ self.window_size = window_size
598
+ self.shift_size = window_size // 2
599
+ self.depth = depth
600
+ self.use_checkpoint = use_checkpoint
601
+
602
+ # build blocks
603
+ self.blocks = nn.ModuleList([
604
+ SwinTransformerBlock(dim=dim,
605
+ num_heads=num_heads,
606
+ window_size=window_size,
607
+ shift_size=0 if
608
+ (i % 2 == 0) else window_size // 2,
609
+ mlp_ratio=mlp_ratio,
610
+ qkv_bias=qkv_bias,
611
+ qk_scale=qk_scale,
612
+ drop=drop,
613
+ attn_drop=attn_drop,
614
+ drop_path=drop_path[i] if isinstance(
615
+ drop_path, list) else drop_path,
616
+ norm_layer=norm_layer) for i in range(depth)
617
+ ])
618
+
619
+ # patch merging layer
620
+ if downsample is not None:
621
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
622
+ else:
623
+ self.downsample = None
624
+
625
+ def forward(self, x, H, W):
626
+ """ Forward function.
627
+
628
+ Args:
629
+ x: Input feature, tensor size (B, H*W, C).
630
+ H, W: Spatial resolution of the input feature.
631
+ """
632
+
633
+ # calculate attention mask for SW-MSA
634
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
635
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
636
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
637
+ h_slices = (slice(0, -self.window_size),
638
+ slice(-self.window_size,
639
+ -self.shift_size), slice(-self.shift_size, None))
640
+ w_slices = (slice(0, -self.window_size),
641
+ slice(-self.window_size,
642
+ -self.shift_size), slice(-self.shift_size, None))
643
+ cnt = 0
644
+ for h in h_slices:
645
+ for w in w_slices:
646
+ img_mask[:, h, w, :] = cnt
647
+ cnt += 1
648
+
649
+ mask_windows = window_partition(
650
+ img_mask, self.window_size) # nW, window_size, window_size, 1
651
+ mask_windows = mask_windows.view(-1,
652
+ self.window_size * self.window_size)
653
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
654
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
655
+ float(-100.0)).masked_fill(
656
+ attn_mask == 0, float(0.0))
657
+
658
+ for blk in self.blocks:
659
+ blk.H, blk.W = H, W
660
+ if self.use_checkpoint:
661
+ x = checkpoint.checkpoint(blk, x, attn_mask)
662
+ else:
663
+ x = blk(x, attn_mask)
664
+ if self.downsample is not None:
665
+ x_down = self.downsample(x, H, W)
666
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
667
+ return x, H, W, x_down, Wh, Ww
668
+ else:
669
+ return x, H, W, x, H, W
670
+
671
+
672
+ class PatchEmbed(nn.Module):
673
+ """ Image to Patch Embedding
674
+
675
+ Args:
676
+ patch_size (int): Patch token size. Default: 4.
677
+ in_chans (int): Number of input image channels. Default: 3.
678
+ embed_dim (int): Number of linear projection output channels. Default: 96.
679
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
680
+ """
681
+
682
+ def __init__(self,
683
+ patch_size=4,
684
+ in_chans=3,
685
+ embed_dim=96,
686
+ norm_layer=None):
687
+ super().__init__()
688
+ patch_size = to_2tuple(patch_size)
689
+ self.patch_size = patch_size
690
+
691
+ self.in_chans = in_chans
692
+ self.embed_dim = embed_dim
693
+
694
+ self.proj = nn.Conv2d(in_chans,
695
+ embed_dim,
696
+ kernel_size=patch_size,
697
+ stride=patch_size)
698
+ if norm_layer is not None:
699
+ self.norm = norm_layer(embed_dim)
700
+ else:
701
+ self.norm = None
702
+
703
+ def forward(self, x):
704
+ """Forward function."""
705
+ # padding
706
+ _, _, H, W = x.size()
707
+ if W % self.patch_size[1] != 0:
708
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
709
+ if H % self.patch_size[0] != 0:
710
+ x = F.pad(x,
711
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
712
+
713
+ x = self.proj(x) # B C Wh Ww
714
+ if self.norm is not None:
715
+ Wh, Ww = x.size(2), x.size(3)
716
+ x = x.flatten(2).transpose(1, 2)
717
+ x = self.norm(x)
718
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
719
+
720
+ return x
721
+
722
+
723
+ class SwinTransformer(nn.Module):
724
+ """ Swin Transformer backbone.
725
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
726
+ https://arxiv.org/pdf/2103.14030
727
+
728
+ Args:
729
+ pretrain_img_size (int): Input image size for training the pretrained model,
730
+ used in absolute position embedding. Default 224.
731
+ patch_size (int | tuple(int)): Patch size. Default: 4.
732
+ in_chans (int): Number of input image channels. Default: 3.
733
+ embed_dim (int): Number of linear projection output channels. Default: 96.
734
+ depths (tuple[int]): Depths of each Swin Transformer stage.
735
+ num_heads (tuple[int]): Number of attention head of each stage.
736
+ window_size (int): Window size. Default: 7.
737
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
738
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
739
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
740
+ drop_rate (float): Dropout rate.
741
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
742
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
743
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
744
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
745
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
746
+ out_indices (Sequence[int]): Output from which stages.
747
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
748
+ -1 means not freezing any parameters.
749
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
750
+ """
751
+
752
+ def __init__(self,
753
+ pretrain_img_size=224,
754
+ patch_size=4,
755
+ in_chans=3,
756
+ embed_dim=96,
757
+ depths=[2, 2, 6, 2],
758
+ num_heads=[3, 6, 12, 24],
759
+ window_size=7,
760
+ mlp_ratio=4.,
761
+ qkv_bias=True,
762
+ qk_scale=None,
763
+ drop_rate=0.,
764
+ attn_drop_rate=0.,
765
+ drop_path_rate=0.2,
766
+ norm_layer=nn.LayerNorm,
767
+ ape=False,
768
+ patch_norm=True,
769
+ out_indices=(0, 1, 2, 3),
770
+ frozen_stages=-1,
771
+ use_checkpoint=False):
772
+ super().__init__()
773
+
774
+ self.pretrain_img_size = pretrain_img_size
775
+ self.num_layers = len(depths)
776
+ self.embed_dim = embed_dim
777
+ self.ape = ape
778
+ self.patch_norm = patch_norm
779
+ self.out_indices = out_indices
780
+ self.frozen_stages = frozen_stages
781
+
782
+ # split image into non-overlapping patches
783
+ self.patch_embed = PatchEmbed(
784
+ patch_size=patch_size,
785
+ in_chans=in_chans,
786
+ embed_dim=embed_dim,
787
+ norm_layer=norm_layer if self.patch_norm else None)
788
+
789
+ # absolute position embedding
790
+ if self.ape:
791
+ pretrain_img_size = to_2tuple(pretrain_img_size)
792
+ patch_size = to_2tuple(patch_size)
793
+ patches_resolution = [
794
+ pretrain_img_size[0] // patch_size[0],
795
+ pretrain_img_size[1] // patch_size[1]
796
+ ]
797
+
798
+ self.absolute_pos_embed = nn.Parameter(
799
+ torch.zeros(1, embed_dim, patches_resolution[0],
800
+ patches_resolution[1]))
801
+ trunc_normal_(self.absolute_pos_embed, std=.02)
802
+
803
+ self.pos_drop = nn.Dropout(p=drop_rate)
804
+
805
+ # stochastic depth
806
+ dpr = [
807
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
808
+ ] # stochastic depth decay rule
809
+
810
+ # build layers
811
+ self.layers = nn.ModuleList()
812
+ for i_layer in range(self.num_layers):
813
+ layer = BasicLayer(
814
+ dim=int(embed_dim * 2**i_layer),
815
+ depth=depths[i_layer],
816
+ num_heads=num_heads[i_layer],
817
+ window_size=window_size,
818
+ mlp_ratio=mlp_ratio,
819
+ qkv_bias=qkv_bias,
820
+ qk_scale=qk_scale,
821
+ drop=drop_rate,
822
+ attn_drop=attn_drop_rate,
823
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
824
+ norm_layer=norm_layer,
825
+ downsample=PatchMerging if
826
+ (i_layer < self.num_layers - 1) else None,
827
+ use_checkpoint=use_checkpoint)
828
+ self.layers.append(layer)
829
+
830
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
831
+ self.num_features = num_features
832
+
833
+ # add a norm layer for each output
834
+ for i_layer in out_indices:
835
+ layer = norm_layer(num_features[i_layer])
836
+ layer_name = f'norm{i_layer}'
837
+ self.add_module(layer_name, layer)
838
+
839
+ self._freeze_stages()
840
+
841
+ def _freeze_stages(self):
842
+ if self.frozen_stages >= 0:
843
+ self.patch_embed.eval()
844
+ for param in self.patch_embed.parameters():
845
+ param.requires_grad = False
846
+
847
+ if self.frozen_stages >= 1 and self.ape:
848
+ self.absolute_pos_embed.requires_grad = False
849
+
850
+ if self.frozen_stages >= 2:
851
+ self.pos_drop.eval()
852
+ for i in range(0, self.frozen_stages - 1):
853
+ m = self.layers[i]
854
+ m.eval()
855
+ for param in m.parameters():
856
+ param.requires_grad = False
857
+
858
+ def init_weights(self, pretrained=None):
859
+ """Initialize the weights in backbone.
860
+
861
+ Args:
862
+ pretrained (str, optional): Path to pre-trained weights.
863
+ Defaults to None.
864
+ """
865
+
866
+ def _init_weights(m):
867
+ if isinstance(m, nn.Linear):
868
+ trunc_normal_(m.weight, std=.02)
869
+ if isinstance(m, nn.Linear) and m.bias is not None:
870
+ nn.init.constant_(m.bias, 0)
871
+ elif isinstance(m, nn.LayerNorm):
872
+ nn.init.constant_(m.bias, 0)
873
+ nn.init.constant_(m.weight, 1.0)
874
+
875
+ if isinstance(pretrained, str):
876
+ self.apply(_init_weights)
877
+ load_checkpoint(self, pretrained, strict=False, logger=None)
878
+ elif pretrained is None:
879
+ self.apply(_init_weights)
880
+ else:
881
+ raise TypeError('pretrained must be a str or None')
882
+
883
+ def forward(self, x):
884
+ x = self.patch_embed(x)
885
+
886
+ Wh, Ww = x.size(2), x.size(3)
887
+ if self.ape:
888
+ # interpolate the position embedding to the corresponding size
889
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
890
+ size=(Wh, Ww),
891
+ mode='bicubic')
892
+ x = (x + absolute_pos_embed) # B C Wh Ww
893
+
894
+ outs = [x.contiguous()]
895
+ x = x.flatten(2).transpose(1, 2)
896
+ x = self.pos_drop(x)
897
+ for i in range(self.num_layers):
898
+ layer = self.layers[i]
899
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
900
+
901
+ if i in self.out_indices:
902
+ norm_layer = getattr(self, f'norm{i}')
903
+ x_out = norm_layer(x_out)
904
+
905
+ out = x_out.view(-1, H, W,
906
+ self.num_features[i]).permute(0, 3, 1,
907
+ 2).contiguous()
908
+ outs.append(out)
909
+
910
+ return tuple(outs)
911
+
912
+ def train(self, mode=True):
913
+ """Convert the model into training mode while keep layers freezed."""
914
+ super(SwinTransformer, self).train(mode)
915
+ self._freeze_stages()
916
+
917
+
918
+ class PositionEmbeddingSine:
919
+
920
+ def __init__(self,
921
+ num_pos_feats=64,
922
+ temperature=10000,
923
+ normalize=False,
924
+ scale=None):
925
+ super().__init__()
926
+ self.num_pos_feats = num_pos_feats
927
+ self.temperature = temperature
928
+ self.normalize = normalize
929
+ if scale is not None and normalize is False:
930
+ raise ValueError("normalize should be True if scale is passed")
931
+ if scale is None:
932
+ scale = 2 * math.pi
933
+ self.scale = scale
934
+ self.dim_t = torch.arange(0,
935
+ self.num_pos_feats,
936
+ dtype=torch_dtype,
937
+ device=torch_device)
938
+
939
+ def __call__(self, b, h, w):
940
+ mask = torch.zeros([b, h, w], dtype=torch.bool, device=torch_device)
941
+ assert mask is not None
942
+ not_mask = ~mask
943
+ y_embed = not_mask.cumsum(dim=1, dtype=torch_dtype)
944
+ x_embed = not_mask.cumsum(dim=2, dtype=torch_dtype)
945
+ if self.normalize:
946
+ eps = 1e-6
947
+ y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) *
948
+ self.scale).to(device=torch_device, dtype=torch_dtype)
949
+ x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) *
950
+ self.scale).to(device=torch_device, dtype=torch_dtype)
951
+
952
+ dim_t = self.temperature**(2 * (self.dim_t // 2) / self.num_pos_feats)
953
+
954
+ pos_x = x_embed[:, :, :, None] / dim_t
955
+ pos_y = y_embed[:, :, :, None] / dim_t
956
+ pos_x = torch.stack(
957
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
958
+ dim=4).flatten(3)
959
+ pos_y = torch.stack(
960
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
961
+ dim=4).flatten(3)
962
+ return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
963
+
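# Editor's note (hedged sketch): PositionEmbeddingSine maps a (b, h, w) grid to a
# dense sine/cosine encoding of shape (b, 2 * num_pos_feats, h, w); with
# num_pos_feats=d_model // 2, as used by MCLM/MCRM below, the channel count equals
# d_model, e.g.
#     pe = PositionEmbeddingSine(num_pos_feats=64, normalize=True)
#     pos = pe(1, 16, 16)    # -> (1, 128, 16, 16): y-terms first, then x-terms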
964
+
965
+ class MCLM(nn.Module):
966
+
967
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
968
+ super(MCLM, self).__init__()
969
+ self.attention = nn.ModuleList([
970
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
971
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
972
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
973
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
974
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
975
+ ])
976
+
977
+ self.linear1 = nn.Linear(d_model, d_model * 2)
978
+ self.linear2 = nn.Linear(d_model * 2, d_model)
979
+ self.linear3 = nn.Linear(d_model, d_model * 2)
980
+ self.linear4 = nn.Linear(d_model * 2, d_model)
981
+ self.norm1 = nn.LayerNorm(d_model)
982
+ self.norm2 = nn.LayerNorm(d_model)
983
+ self.dropout = nn.Dropout(0.1)
984
+ self.dropout1 = nn.Dropout(0.1)
985
+ self.dropout2 = nn.Dropout(0.1)
986
+ self.activation = get_activation_fn('relu')
987
+ self.pool_ratios = pool_ratios
988
+ self.p_poses = []
989
+ self.g_pos = None
990
+ self.positional_encoding = PositionEmbeddingSine(
991
+ num_pos_feats=d_model // 2, normalize=True)
992
+
993
+ def forward(self, l, g):
994
+ """
995
+ l: 4,c,h,w
996
+ g: 1,c,h,w
997
+ """
998
+ b, c, h, w = l.size()
999
+ # 4,c,h,w -> 1,c,2h,2w
1000
+ concated_locs = rearrange(l,
1001
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1002
+ hg=2,
1003
+ wg=2)
1004
+
1005
+ pools = []
1006
+ for pool_ratio in self.pool_ratios:
1007
+ # b,c,h,w
1008
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1009
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1010
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1011
+ if self.g_pos is None:
1012
+ pos_emb = self.positional_encoding(pool.shape[0],
1013
+ pool.shape[2],
1014
+ pool.shape[3])
1015
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1016
+ self.p_poses.append(pos_emb)
1017
+ pools = torch.cat(pools, 0)
1018
+ if self.g_pos is None:
1019
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1020
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2],
1021
+ g.shape[3])
1022
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1023
+
1024
+ # attention between glb (q) & multi-scale pooled concatenated locs (k, v)
1025
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1026
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1027
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1028
+ g_hw_b_c = self.norm1(g_hw_b_c)
1029
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1030
+ self.linear2(
1031
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1032
+ g_hw_b_c = self.norm2(g_hw_b_c)
1033
+
1034
+ # attention between original locs (q) & refreshed glb (k, v)
1035
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1036
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1037
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1038
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1039
+ ng=2,
1040
+ nw=2)
1041
+ outputs_re = []
1042
+ for i, (_l, _g) in enumerate(
1043
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1044
+ outputs_re.append(self.attention[i + 1](_l, _g,
1045
+ _g)[0]) # (h w) 1 c
1046
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1047
+
1048
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1049
+ l_hw_b_c = self.norm1(l_hw_b_c)
1050
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1051
+ self.linear4(
1052
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1053
+ l_hw_b_c = self.norm2(l_hw_b_c)
1054
+
1055
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1056
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
1057
+
1058
+
1059
+ class inf_MCLM(nn.Module):
1060
+
1061
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
1062
+ super(inf_MCLM, self).__init__()
1063
+ self.attention = nn.ModuleList([
1064
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1065
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1066
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1067
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1068
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1069
+ ])
1070
+
1071
+ self.linear1 = nn.Linear(d_model, d_model * 2)
1072
+ self.linear2 = nn.Linear(d_model * 2, d_model)
1073
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1074
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1075
+ self.norm1 = nn.LayerNorm(d_model)
1076
+ self.norm2 = nn.LayerNorm(d_model)
1077
+ self.dropout = nn.Dropout(0.1)
1078
+ self.dropout1 = nn.Dropout(0.1)
1079
+ self.dropout2 = nn.Dropout(0.1)
1080
+ self.activation = get_activation_fn('relu')
1081
+ self.pool_ratios = pool_ratios
1082
+ self.p_poses = []
1083
+ self.g_pos = None
1084
+ self.positional_encoding = PositionEmbeddingSine(
1085
+ num_pos_feats=d_model // 2, normalize=True)
1086
+
1087
+ def forward(self, l, g):
1088
+ """
1089
+ l: 4,c,h,w
1090
+ g: 1,c,h,w
1091
+ """
1092
+ b, c, h, w = l.size()
1093
+ # 4,c,h,w -> 1,c,2h,2w
1094
+ concated_locs = rearrange(l,
1095
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1096
+ hg=2,
1097
+ wg=2)
1098
+ self.p_poses = []
1099
+ pools = []
1100
+ for pool_ratio in self.pool_ratios:
1101
+ # b,c,h,w
1102
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1103
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1104
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1105
+ # if self.g_pos is None:
1106
+ pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2],
1107
+ pool.shape[3])
1108
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1109
+ self.p_poses.append(pos_emb)
1110
+ pools = torch.cat(pools, 0)
1111
+ # if self.g_pos is None:
1112
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1113
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
1114
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1115
+
1116
+ # attention between glb (q) & multi-scale pooled concatenated locs (k, v)
1117
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1118
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1119
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1120
+ g_hw_b_c = self.norm1(g_hw_b_c)
1121
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1122
+ self.linear2(
1123
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1124
+ g_hw_b_c = self.norm2(g_hw_b_c)
1125
+
1126
+ # attention between original locs (q) & refreshed glb (k, v)
1127
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1128
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1129
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1130
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1131
+ ng=2,
1132
+ nw=2)
1133
+ outputs_re = []
1134
+ for i, (_l, _g) in enumerate(
1135
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1136
+ outputs_re.append(self.attention[i + 1](_l, _g,
1137
+ _g)[0]) # (h w) 1 c
1138
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1139
+
1140
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1141
+ l_hw_b_c = self.norm1(l_hw_b_c)
1142
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1143
+ self.linear4(
1144
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1145
+ l_hw_b_c = self.norm2(l_hw_b_c)
1146
+
1147
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1148
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
1149
+
1150
+
1151
+ class MCRM(nn.Module):
1152
+
1153
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1154
+ super(MCRM, self).__init__()
1155
+ self.attention = nn.ModuleList([
1156
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1157
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1158
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1159
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1160
+ ])
1161
+
1162
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1163
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1164
+ self.norm1 = nn.LayerNorm(d_model)
1165
+ self.norm2 = nn.LayerNorm(d_model)
1166
+ self.dropout = nn.Dropout(0.1)
1167
+ self.dropout1 = nn.Dropout(0.1)
1168
+ self.dropout2 = nn.Dropout(0.1)
1169
+ self.sigmoid = nn.Sigmoid()
1170
+ self.activation = get_activation_fn('relu')
1171
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1172
+ self.pool_ratios = pool_ratios
1173
+ self.positional_encoding = PositionEmbeddingSine(
1174
+ num_pos_feats=d_model // 2, normalize=True)
1175
+
1176
+ def forward(self, x):
1177
+ b, c, h, w = x.size()
1178
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1179
+ # b(4),c,h,w
1180
+ patched_glb = rearrange(glb,
1181
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1182
+ hg=2,
1183
+ wg=2)
1184
+
1185
+ # generate token attention map
1186
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1187
+ token_attention_map = F.interpolate(token_attention_map,
1188
+ size=patches2image(loc).shape[-2:],
1189
+ mode='nearest')
1190
+ loc = loc * rearrange(token_attention_map,
1191
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1192
+ hg=2,
1193
+ wg=2)
1194
+ pools = []
1195
+ for pool_ratio in self.pool_ratios:
1196
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1197
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1198
+ pools.append(rearrange(pool,
1199
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1200
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1201
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1202
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1203
+ outputs = []
1204
+ for i, q in enumerate(
1205
+ loc_.unbind(dim=0)): # traverse all local patches
1206
+ # np*hw,1,c
1207
+ v = pools[i]
1208
+ k = v
1209
+ outputs.append(self.attention[i](q, k, v)[0])
1210
+ outputs = torch.cat(outputs, 1)
1211
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1212
+ src = self.norm1(src)
1213
+ src = src + self.dropout2(
1214
+ self.linear4(
1215
+ self.dropout(self.activation(self.linear3(src)).clone())))
1216
+ src = self.norm2(src)
1217
+
1218
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed loc
1219
+ glb = glb + F.interpolate(patches2image(src),
1220
+ size=glb.shape[-2:],
1221
+ mode='nearest') # refreshed glb
1222
+ return torch.cat((src, glb), 0), token_attention_map
1223
+
1224
+
1225
+ class inf_MCRM(nn.Module):
1226
+
1227
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1228
+ super(inf_MCRM, self).__init__()
1229
+ self.attention = nn.ModuleList([
1230
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1231
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1232
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1233
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1234
+ ])
1235
+
1236
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1237
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1238
+ self.norm1 = nn.LayerNorm(d_model)
1239
+ self.norm2 = nn.LayerNorm(d_model)
1240
+ self.dropout = nn.Dropout(0.1)
1241
+ self.dropout1 = nn.Dropout(0.1)
1242
+ self.dropout2 = nn.Dropout(0.1)
1243
+ self.sigmoid = nn.Sigmoid()
1244
+ self.activation = get_activation_fn('relu')
1245
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1246
+ self.pool_ratios = pool_ratios
1247
+ self.positional_encoding = PositionEmbeddingSine(
1248
+ num_pos_feats=d_model // 2, normalize=True)
1249
+
1250
+ def forward(self, x):
1251
+ b, c, h, w = x.size()
1252
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1253
+ # b(4),c,h,w
1254
+ patched_glb = rearrange(glb,
1255
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1256
+ hg=2,
1257
+ wg=2)
1258
+
1259
+ # generate token attention map
1260
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1261
+ token_attention_map = F.interpolate(token_attention_map,
1262
+ size=patches2image(loc).shape[-2:],
1263
+ mode='nearest')
1264
+ loc = loc * rearrange(token_attention_map,
1265
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1266
+ hg=2,
1267
+ wg=2)
1268
+ pools = []
1269
+ for pool_ratio in self.pool_ratios:
1270
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1271
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1272
+ pools.append(rearrange(pool,
1273
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1274
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1275
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1276
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1277
+ outputs = []
1278
+ for i, q in enumerate(
1279
+ loc_.unbind(dim=0)): # traverse all local patches
1280
+ # np*hw,1,c
1281
+ v = pools[i]
1282
+ k = v
1283
+ outputs.append(self.attention[i](q, k, v)[0])
1284
+ outputs = torch.cat(outputs, 1)
1285
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1286
+ src = self.norm1(src)
1287
+ src = src + self.dropout2(
1288
+ self.linear4(
1289
+ self.dropout(self.activation(self.linear3(src)).clone())))
1290
+ src = self.norm2(src)
1291
+
1292
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed loc
1293
+ glb = glb + F.interpolate(patches2image(src),
1294
+ size=glb.shape[-2:],
1295
+ mode='nearest') # refreshed glb
1296
+ return torch.cat((src, glb), 0)
1297
+
1298
+
1299
+ # model for single-scale training
1300
+ class MVANet(nn.Module):
1301
+
1302
+ def __init__(self):
1303
+ super().__init__()
1304
+ self.backbone = SwinB(pretrained=True)
1305
+ emb_dim = 128
1306
+ self.sideout5 = nn.Sequential(
1307
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1308
+ self.sideout4 = nn.Sequential(
1309
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1310
+ self.sideout3 = nn.Sequential(
1311
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1312
+ self.sideout2 = nn.Sequential(
1313
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1314
+ self.sideout1 = nn.Sequential(
1315
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1316
+
1317
+ self.output5 = make_cbr(1024, emb_dim)
1318
+ self.output4 = make_cbr(512, emb_dim)
1319
+ self.output3 = make_cbr(256, emb_dim)
1320
+ self.output2 = make_cbr(128, emb_dim)
1321
+ self.output1 = make_cbr(128, emb_dim)
1322
+
1323
+ self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
1324
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1325
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1326
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1327
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1328
+ self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
1329
+ self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
1330
+ self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
1331
+ self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
1332
+
1333
+ self.insmask_head = nn.Sequential(
1334
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1335
+ nn.BatchNorm2d(384), nn.PReLU(),
1336
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1337
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1338
+
1339
+ self.shallow = nn.Sequential(
1340
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1341
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1342
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1343
+ self.output = nn.Sequential(
1344
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1345
+
1346
+ for m in self.modules():
1347
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1348
+ m.inplace = True
1349
+
1350
+ def forward(self, x):
1351
+ x = x.to(dtype=torch_dtype, device=torch_device)
1352
+ shallow = self.shallow(x)
1353
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
1354
+ loc = image2patches(x)
1355
+ input = torch.cat((loc, glb), dim=0)
1356
+ feature = self.backbone(input)
1357
+ e5 = self.output5(feature[4]) # (5,128,16,16)
1358
+ e4 = self.output4(feature[3]) # (5,128,32,32)
1359
+ e3 = self.output3(feature[2]) # (5,128,64,64)
1360
+ e2 = self.output2(feature[1]) # (5,128,128,128)
1361
+ e1 = self.output1(feature[0]) # (5,128,128,128)
1362
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
1363
+ e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
1364
+
1365
+ e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
1366
+ e4 = self.conv4(e4)
1367
+ e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
1368
+ e3 = self.conv3(e3)
1369
+ e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
1370
+ e2 = self.conv2(e2)
1371
+ e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
1372
+ e1 = self.conv1(e1)
1373
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
1374
+ output1_cat = patches2image(loc_e1) # (1,128,256,256)
1375
+ # add glb feat in
1376
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
1377
+ # merge
1378
+ final_output = self.insmask_head(output1_cat) # (1,128,256,256)
1379
+ # shallow feature merge
1380
+ final_output = final_output + resize_as(shallow, final_output)
1381
+ final_output = self.upsample1(rescale_to(final_output))
1382
+ final_output = rescale_to(final_output +
1383
+ resize_as(shallow, final_output))
1384
+ final_output = self.upsample2(final_output)
1385
+ final_output = self.output(final_output)
1386
+ ####
1387
+ sideout5 = self.sideout5(e5).to(dtype=torch_dtype, device=torch_device)
1388
+ sideout4 = self.sideout4(e4)
1389
+ sideout3 = self.sideout3(e3)
1390
+ sideout2 = self.sideout2(e2)
1391
+ sideout1 = self.sideout1(e1)
1392
+ #######glb_sideouts ######
1393
+ glb5 = self.sideout5(glb_e5)
1394
+ glb4 = sideout4[-1, :, :, :].unsqueeze(0)
1395
+ glb3 = sideout3[-1, :, :, :].unsqueeze(0)
1396
+ glb2 = sideout2[-1, :, :, :].unsqueeze(0)
1397
+ glb1 = sideout1[-1, :, :, :].unsqueeze(0)
1398
+ ####### concat 4 to 1 #######
1399
+ sideout1 = patches2image(sideout1[:-1]).to(dtype=torch_dtype,
1400
+ device=torch_device)
1401
+ sideout2 = patches2image(sideout2[:-1]).to(
1402
+ dtype=torch_dtype,
1403
+ device=torch_device) ####(5,c,h,w) -> (1 c 2h,2w)
1404
+ sideout3 = patches2image(sideout3[:-1]).to(dtype=torch_dtype,
1405
+ device=torch_device)
1406
+ sideout4 = patches2image(sideout4[:-1]).to(dtype=torch_dtype,
1407
+ device=torch_device)
1408
+ sideout5 = patches2image(sideout5[:-1]).to(dtype=torch_dtype,
1409
+ device=torch_device)
1410
+ if self.training:
1411
+ return sideout5, sideout4, sideout3, sideout2, sideout1, final_output, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1
1412
+ else:
1413
+ return final_output
1414
+
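# Editor's note (hedged sketch): for a 1024x1024 input -- consistent with the
# shape comments above -- the forward pass builds a batch of five half-resolution
# views before calling the backbone:
#     glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')  # (1, 3, 512, 512) global view
#     loc = image2patches(x)                                           # (4, 3, 512, 512) four quadrants
#     torch.cat((loc, glb), dim=0)                                     # (5, 3, 512, 512) backbone input
# MCLM/MCRM then exchange information between the local and global views before
# patches2image stitches the quadrants back together.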
1415
+
1416
+ # model for multi-scale testing
1417
+ class inf_MVANet(nn.Module):
1418
+
1419
+ def __init__(self):
1420
+ super().__init__()
1421
+ # self.backbone = SwinB(pretrained=True)
1422
+ self.backbone = SwinB(pretrained=False)
1423
+
1424
+ emb_dim = 128
1425
+ self.output5 = make_cbr(1024, emb_dim)
1426
+ self.output4 = make_cbr(512, emb_dim)
1427
+ self.output3 = make_cbr(256, emb_dim)
1428
+ self.output2 = make_cbr(128, emb_dim)
1429
+ self.output1 = make_cbr(128, emb_dim)
1430
+
1431
+ self.multifieldcrossatt = inf_MCLM(emb_dim, 1, [1, 4, 8])
1432
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1433
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1434
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1435
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1436
+ self.dec_blk1 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1437
+ self.dec_blk2 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1438
+ self.dec_blk3 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1439
+ self.dec_blk4 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1440
+
1441
+ self.insmask_head = nn.Sequential(
1442
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1443
+ nn.BatchNorm2d(384), nn.PReLU(),
1444
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1445
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1446
+
1447
+ self.shallow = nn.Sequential(
1448
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1449
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1450
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1451
+ self.output = nn.Sequential(
1452
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1453
+
1454
+ for m in self.modules():
1455
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1456
+ m.inplace = True
1457
+
1458
+ def forward(self, x):
1459
+ shallow = self.shallow(x)
1460
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
1461
+ loc = image2patches(x)
1462
+ input = torch.cat((loc, glb), dim=0)
1463
+ feature = self.backbone(input)
1464
+ e5 = self.output5(feature[4])
1465
+ e4 = self.output4(feature[3])
1466
+ e3 = self.output3(feature[2])
1467
+ e2 = self.output2(feature[1])
1468
+ e1 = self.output1(feature[0])
1469
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
1470
+ e5_cat = self.multifieldcrossatt(loc_e5, glb_e5)
1471
+
1472
+ e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
1473
+ e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
1474
+ e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
1475
+ e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))
1476
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
1477
+ # after decoder, concat loc features to a whole one, and merge
1478
+ output1_cat = patches2image(loc_e1)
1479
+ # add glb feat in
1480
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
1481
+ # merge
1482
+ final_output = self.insmask_head(output1_cat)
1483
+ # shallow feature merge
1484
+ final_output = final_output + resize_as(shallow, final_output)
1485
+ final_output = self.upsample1(rescale_to(final_output))
1486
+ final_output = rescale_to(final_output +
1487
+ resize_as(shallow, final_output))
1488
+ final_output = self.upsample2(final_output)
1489
+ final_output = self.output(final_output)
1490
+ return final_output
1491
+
1492
+
1493
+ class load_MVANet_Model:
1494
+
1495
+ def __init__(self):
1496
+ pass
1497
+
1498
+ @classmethod
1499
+ def INPUT_TYPES(s):
1500
+ return {
1501
+ "required": {},
1502
+ }
1503
+
1504
+ RETURN_TYPES = ("MVANet_Model", )
1505
+ FUNCTION = "test"
1506
+ CATEGORY = "MVANet"
1507
+
1508
+ def test(self):
1509
+ return (load_model(get_model_path()), )
1510
+
1511
+
1512
+ class run_MVANet_inference:
1513
+
1514
+ def __init__(self):
1515
+ pass
1516
+
1517
+ @classmethod
1518
+ def INPUT_TYPES(s):
1519
+ return {
1520
+ "required": {
1521
+ "image": ("IMAGE", ),
1522
+ "MVANet_Model": ("MVANet_Model", ),
1523
+ },
1524
+ }
1525
+
1526
+ RETURN_TYPES = ("MASK", )
1527
+ FUNCTION = "test"
1528
+ CATEGORY = "MVANet"
1529
+
1530
+ def test(
1531
+ self,
1532
+ image,
1533
+ MVANet_Model,
1534
+ ):
1535
+ ret = do_infer_tensor2tensor(img=image, net=MVANet_Model)
1536
+
1537
+ return (ret, )
1538
+
1539
+
1540
+ NODE_CLASS_MAPPINGS = {
1541
+ "load_MVANet_Model": load_MVANet_Model,
1542
+ "run_MVANet_inference": run_MVANet_inference
1543
+ }
1544
+
1545
+ NODE_DISPLAY_NAME_MAPPINGS = {
1546
+ "load_MVANet_Model": "load MVANet Model",
1547
+ "run_MVANet_inference": "run_MVANet_inference"
1548
+ }
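A minimal driver sketch for the two nodes above when called outside the ComfyUI graph (assumptions: this module is importable as MVANet_inference, the checkpoint located by get_model_path()/load_model() exists, and the input follows ComfyUI's IMAGE convention of a B x H x W x C float tensor in [0, 1]):

import torch
import MVANet_inference as mva    # hypothetical import name for this file

(net,) = mva.load_MVANet_Model().test()    # build inf_MVANet and load weights
image = torch.rand(1, 1024, 1024, 3)       # ComfyUI-style IMAGE tensor
(mask,) = mva.run_MVANet_inference().test(image=image, MVANet_Model=net)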
ComfyUI_MVANet/MVANet_inference.run.sh ADDED
@@ -0,0 +1,3 @@
1
+ #!/bin/sh
2
+ . "${HOME}/dbnew.sh"
3
+ python3 './MVANet_inference.py'
ComfyUI_MVANet/README.org ADDED
@@ -0,0 +1,1694 @@
1
+ * COMMENT Sample
2
+
3
+ ** Shell script to download
4
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
5
+ #+end_src
6
+
7
+ ** MVANet_inference import
8
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
9
+ #+end_src
10
+
11
+ ** MVANet_inference function
12
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
13
+ #+end_src
14
+
15
+ ** MVANet_inference class
16
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
17
+ #+end_src
18
+
19
+ ** MVANet_inference execute
20
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
21
+ #+end_src
22
+
23
+ ** MVANet_inference unify
24
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.unify.sh
25
+ #+end_src
26
+
27
+ * Download the code:
28
+
29
+ ** Function to download
30
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
31
+ get_repo(){
32
+ DIR_REPO="${HOME}/GITHUB/$('echo' "${1}" | 'sed' 's/^git@github.com://g ; s@^https://github.com/@@g ; s@.git$@@g' )"
33
+ DIR_BASE="$('dirname' '--' "${DIR_REPO}")"
34
+ mkdir -pv -- "${DIR_BASE}"
35
+ cd "${DIR_BASE}"
36
+ git clone "${1}"
37
+ cd "${DIR_REPO}"
38
+ git pull
39
+ git submodule update --recursive --init
40
+ }
41
+ #+end_src
42
+
43
+ ** Download
44
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
45
+ get_repo 'https://github.com/qianyu-dlut/MVANet.git'
46
+ #+end_src
47
+
48
+ * Dependencies
49
+ #+begin_src conf :tangle ./requirements.txt
50
+ timm
51
+ einops
52
+ wget
53
+ #+end_src
54
+
55
+ * Python inference
56
+
57
+ ** Important configs
58
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
59
+ import os
60
+ import sys
61
+
62
+ HOME_DIR = os.environ.get('HOME', '/root')
63
+ MVANET_SOURCE_DIR = HOME_DIR + '/GITHUB/qianyu-dlut/MVANet'
64
+ finetuned_MVANet_model_path = MVANET_SOURCE_DIR + '/model/Model_80.pth'
65
+ pretrained_SwinB_model_path = MVANET_SOURCE_DIR + '/model/swin_base_patch4_window12_384_22kto1k.pth'
66
+ #+end_src
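These paths assume the finetuned Model_80.pth and the pretrained SwinB checkpoint have already been placed under the cloned repository; a sketch of the expected layout (the checkpoint files themselves must be obtained from the MVANet release, which is not covered here):

#+begin_src sh
mkdir -pv -- "${HOME}/GITHUB/qianyu-dlut/MVANet/model"
# place Model_80.pth and swin_base_patch4_window12_384_22kto1k.pth in that directory
#+end_src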
67
+
68
+ ** MVANet_inference import
69
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
70
+ import math
71
+ import numpy as np
72
+ import cv2
73
+ import wget
74
+
75
+ import torch
76
+ import torch.nn as nn
77
+ import torch.nn.functional as F
78
+ import torch.utils.checkpoint as checkpoint
79
+ from torch.autograd import Variable
80
+ from torch import nn
81
+ from torchvision import transforms
82
+
83
+ from einops import rearrange
84
+
85
+ from timm.models import load_checkpoint
86
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
87
+
88
+ torch_device = 'cuda'
89
+ torch_dtype = torch.float16
90
+ #+end_src
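The two settings above pin inference to a CUDA device in half precision. If CUDA may be unavailable, a fallback along these lines could be used instead (a sketch, not part of the original configuration; float16 on CPU is generally impractical):

#+begin_src python
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch_dtype = torch.float16 if torch_device == 'cuda' else torch.float32
#+end_src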
91
+
92
+ ** COMMENT Load image using CV
93
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
94
+ def load_image(input_image_path):
95
+ img = cv2.imread(input_image_path, cv2.IMREAD_COLOR)
96
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
97
+ return img
98
+
99
+
100
+ def load_image_torch(input_image_path):
101
+ img = cv2.imread(input_image_path, cv2.IMREAD_COLOR)
102
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
103
+ img = torch.from_numpy(img)
104
+ img = img.to(dtype=torch.float32)
105
+ img /= 255.0
106
+ img = img.unsqueeze(0)
107
+ return img
108
+
109
+
110
+ def save_mask(output_image_path, mask):
111
+ cv2.imwrite(output_image_path, mask)
112
+
113
+
114
+ def save_mask_torch(output_image_path, mask):
115
+ mask = mask.detach().cpu()
116
+ mask *= 255.0
117
+ mask = mask.clamp(0, 255)
118
+ print(mask.shape)
119
+ mask = mask.squeeze(0)
120
+ mask = mask.to(dtype=torch.uint8)
121
+ print(mask.shape)
122
+ mask = mask.numpy()
123
+ print(mask.shape)
124
+ cv2.imwrite(output_image_path, mask)
125
+ #+end_src
126
+
127
+ ** MVANet_inference function
128
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
129
+ def check_mkdir(dir_name):
130
+ if not os.path.isdir(dir_name):
131
+ os.makedirs(dir_name)
132
+
133
+
134
+ def SwinT(pretrained=True):
135
+ model = SwinTransformer(embed_dim=96,
136
+ depths=[2, 2, 6, 2],
137
+ num_heads=[3, 6, 12, 24],
138
+ window_size=7)
139
+ if pretrained is True:
140
+ model.load_state_dict(torch.load(
141
+ 'data/backbone_ckpt/swin_tiny_patch4_window7_224.pth',
142
+ map_location='cpu')['model'],
143
+ strict=False)
144
+
145
+ return model
146
+
147
+
148
+ def SwinS(pretrained=True):
149
+ model = SwinTransformer(embed_dim=96,
150
+ depths=[2, 2, 18, 2],
151
+ num_heads=[3, 6, 12, 24],
152
+ window_size=7)
153
+ if pretrained is True:
154
+ model.load_state_dict(torch.load(
155
+ 'data/backbone_ckpt/swin_small_patch4_window7_224.pth',
156
+ map_location='cpu')['model'],
157
+ strict=False)
158
+
159
+ return model
160
+
161
+
162
+ def SwinB(pretrained=True):
163
+ model = SwinTransformer(embed_dim=128,
164
+ depths=[2, 2, 18, 2],
165
+ num_heads=[4, 8, 16, 32],
166
+ window_size=12)
167
+ if pretrained is True:
168
+ import os
169
+ model.load_state_dict(torch.load(pretrained_SwinB_model_path,
170
+ map_location='cpu')['model'],
171
+ strict=False)
172
+ return model
173
+
174
+
175
+ def SwinL(pretrained=True):
176
+ model = SwinTransformer(embed_dim=192,
177
+ depths=[2, 2, 18, 2],
178
+ num_heads=[6, 12, 24, 48],
179
+ window_size=12)
180
+ if pretrained is True:
181
+ model.load_state_dict(torch.load(
182
+ 'data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth',
183
+ map_location='cpu')['model'],
184
+ strict=False)
185
+
186
+ return model
187
+
188
+
189
+ def get_activation_fn(activation):
190
+ """Return an activation function given a string"""
191
+ if activation == "relu":
192
+ return F.relu
193
+ if activation == "gelu":
194
+ return F.gelu
195
+ if activation == "glu":
196
+ return F.glu
197
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
198
+
199
+
200
+ def make_cbr(in_dim, out_dim):
201
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
202
+ nn.BatchNorm2d(out_dim), nn.PReLU())
203
+
204
+
205
+ def make_cbg(in_dim, out_dim):
206
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
207
+ nn.BatchNorm2d(out_dim), nn.GELU())
208
+
209
+
210
+ def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
211
+ return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
212
+
213
+
214
+ def resize_as(x, y, interpolation='bilinear'):
215
+ return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
216
+
217
+
218
+ def image2patches(x):
219
+ """b c (hg h) (wg w) -> (hg wg b) c h w"""
220
+ x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
221
+ return x
222
+
223
+
224
+ def patches2image(x):
225
+ """(hg wg b) c h w -> b c (hg h) (wg w)"""
226
+ x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
227
+ return x
228
+
229
+
230
+ def window_partition(x, window_size):
231
+ """
232
+ Args:
233
+ x: (B, H, W, C)
234
+ window_size (int): window size
235
+
236
+ Returns:
237
+ windows: (num_windows*B, window_size, window_size, C)
238
+ """
239
+ B, H, W, C = x.shape
240
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
241
+ C)
242
+ windows = x.permute(0, 1, 3, 2, 4,
243
+ 5).contiguous().view(-1, window_size, window_size, C)
244
+ return windows
245
+
246
+
247
+ def window_reverse(windows, window_size, H, W):
248
+ """
249
+ Args:
250
+ windows: (num_windows*B, window_size, window_size, C)
251
+ window_size (int): Window size
252
+ H (int): Height of image
253
+ W (int): Width of image
254
+
255
+ Returns:
256
+ x: (B, H, W, C)
257
+ """
258
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
259
+ x = windows.view(B, H // window_size, W // window_size, window_size,
260
+ window_size, -1)
261
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
262
+ return x
263
+ #+end_src
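As a quick sanity check of the two window helpers above, partitioning and reversing are exact inverses whenever H and W are multiples of the window size (a standalone sketch):

#+begin_src python
x = torch.rand(2, 14, 14, 96)           # B, H, W, C
windows = window_partition(x, 7)        # -> (8, 7, 7, 96): 2x2 windows per image
y = window_reverse(windows, 7, 14, 14)  # -> (2, 14, 14, 96)
assert torch.equal(x, y)
#+end_src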
264
+
265
+ ** MVANet_inference class
266
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
267
+ class Mlp(nn.Module):
268
+ """ Multilayer perceptron."""
269
+
270
+ def __init__(self,
271
+ in_features,
272
+ hidden_features=None,
273
+ out_features=None,
274
+ act_layer=nn.GELU,
275
+ drop=0.):
276
+ super().__init__()
277
+ out_features = out_features or in_features
278
+ hidden_features = hidden_features or in_features
279
+ self.fc1 = nn.Linear(in_features, hidden_features)
280
+ self.act = act_layer()
281
+ self.fc2 = nn.Linear(hidden_features, out_features)
282
+ self.drop = nn.Dropout(drop)
283
+
284
+ def forward(self, x):
285
+ x = self.fc1(x)
286
+ x = self.act(x)
287
+ x = self.drop(x)
288
+ x = self.fc2(x)
289
+ x = self.drop(x)
290
+ return x
291
+
292
+
293
+ class WindowAttention(nn.Module):
294
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
295
+ It supports both of shifted and non-shifted window.
296
+
297
+ Args:
298
+ dim (int): Number of input channels.
299
+ window_size (tuple[int]): The height and width of the window.
300
+ num_heads (int): Number of attention heads.
301
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
302
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
303
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
304
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
305
+ """
306
+
307
+ def __init__(self,
308
+ dim,
309
+ window_size,
310
+ num_heads,
311
+ qkv_bias=True,
312
+ qk_scale=None,
313
+ attn_drop=0.,
314
+ proj_drop=0.):
315
+
316
+ super().__init__()
317
+ self.dim = dim
318
+ self.window_size = window_size # Wh, Ww
319
+ self.num_heads = num_heads
320
+ head_dim = dim // num_heads
321
+ self.scale = qk_scale or head_dim**-0.5
322
+
323
+ # define a parameter table of relative position bias
324
+ self.relative_position_bias_table = nn.Parameter(
325
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
326
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
327
+
328
+ # get pair-wise relative position index for each token inside the window
329
+ coords_h = torch.arange(self.window_size[0])
330
+ coords_w = torch.arange(self.window_size[1])
331
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
332
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
333
+ relative_coords = coords_flatten[:, :,
334
+ None] - coords_flatten[:,
335
+ None, :] # 2, Wh*Ww, Wh*Ww
336
+ relative_coords = relative_coords.permute(
337
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
338
+ relative_coords[:, :,
339
+ 0] += self.window_size[0] - 1 # shift to start from 0
340
+ relative_coords[:, :, 1] += self.window_size[1] - 1
341
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
342
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
343
+ self.register_buffer("relative_position_index",
344
+ relative_position_index)
345
+
346
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
347
+ self.attn_drop = nn.Dropout(attn_drop)
348
+ self.proj = nn.Linear(dim, dim)
349
+ self.proj_drop = nn.Dropout(proj_drop)
350
+
351
+ trunc_normal_(self.relative_position_bias_table, std=.02)
352
+ self.softmax = nn.Softmax(dim=-1)
353
+
354
+ def forward(self, x, mask=None):
355
+ """ Forward function.
356
+
357
+ Args:
358
+ x: input features with shape of (num_windows*B, N, C)
359
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
360
+ """
361
+ x = x.to(dtype=torch_dtype, device=torch_device)
362
+ B_, N, C = x.shape
363
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
364
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
365
+ q, k, v = qkv[0], qkv[1], qkv[
366
+ 2] # make torchscript happy (cannot use tensor as tuple)
367
+
368
+ q = q * self.scale
369
+ attn = (q @ k.transpose(-2, -1))
370
+
371
+ relative_position_bias = self.relative_position_bias_table[
372
+ self.relative_position_index.view(-1)].view(
373
+ self.window_size[0] * self.window_size[1],
374
+ self.window_size[0] * self.window_size[1],
375
+ -1) # Wh*Ww,Wh*Ww,nH
376
+ relative_position_bias = relative_position_bias.permute(
377
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
378
+ attn = attn + relative_position_bias.unsqueeze(0)
379
+
380
+ if mask is not None:
381
+ nW = mask.shape[0]
382
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
383
+ N) + mask.unsqueeze(1).unsqueeze(0)
384
+ attn = attn.view(-1, self.num_heads, N, N)
385
+ attn = self.softmax(attn)
386
+ else:
387
+ attn = self.softmax(attn)
388
+
389
+ attn = self.attn_drop(attn)
390
+ attn = attn.to(dtype=torch_dtype, device=torch_device)
391
+ v = v.to(dtype=torch_dtype, device=torch_device)
392
+
393
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
394
+ x = self.proj(x)
395
+ x = self.proj_drop(x)
396
+ return x
397
+
398
+
399
+ class SwinTransformerBlock(nn.Module):
400
+ """ Swin Transformer Block.
401
+
402
+ Args:
403
+ dim (int): Number of input channels.
404
+ num_heads (int): Number of attention heads.
405
+ window_size (int): Window size.
406
+ shift_size (int): Shift size for SW-MSA.
407
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
408
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
409
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
410
+ drop (float, optional): Dropout rate. Default: 0.0
411
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
412
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
413
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
414
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
415
+ """
416
+
417
+ def __init__(self,
418
+ dim,
419
+ num_heads,
420
+ window_size=7,
421
+ shift_size=0,
422
+ mlp_ratio=4.,
423
+ qkv_bias=True,
424
+ qk_scale=None,
425
+ drop=0.,
426
+ attn_drop=0.,
427
+ drop_path=0.,
428
+ act_layer=nn.GELU,
429
+ norm_layer=nn.LayerNorm):
430
+ super().__init__()
431
+ self.dim = dim
432
+ self.num_heads = num_heads
433
+ self.window_size = window_size
434
+ self.shift_size = shift_size
435
+ self.mlp_ratio = mlp_ratio
436
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
437
+
438
+ self.norm1 = norm_layer(dim)
439
+ self.attn = WindowAttention(dim,
440
+ window_size=to_2tuple(self.window_size),
441
+ num_heads=num_heads,
442
+ qkv_bias=qkv_bias,
443
+ qk_scale=qk_scale,
444
+ attn_drop=attn_drop,
445
+ proj_drop=drop)
446
+
447
+ self.drop_path = DropPath(
448
+ drop_path) if drop_path > 0. else nn.Identity()
449
+ self.norm2 = norm_layer(dim)
450
+ mlp_hidden_dim = int(dim * mlp_ratio)
451
+ self.mlp = Mlp(in_features=dim,
452
+ hidden_features=mlp_hidden_dim,
453
+ act_layer=act_layer,
454
+ drop=drop)
455
+
456
+ self.H = None
457
+ self.W = None
458
+
459
+ def forward(self, x, mask_matrix):
460
+ """ Forward function.
461
+
462
+ Args:
463
+ x: Input feature, tensor size (B, H*W, C).
464
+ H, W: Spatial resolution of the input feature.
465
+ mask_matrix: Attention mask for cyclic shift.
466
+ """
467
+ B, L, C = x.shape
468
+ H, W = self.H, self.W
469
+ assert L == H * W, "input feature has wrong size"
470
+
471
+ shortcut = x
472
+ x = self.norm1(x)
473
+ x = x.view(B, H, W, C)
474
+
475
+ # pad feature maps to multiples of window size
476
+ pad_l = pad_t = 0
477
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
478
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
479
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
480
+ _, Hp, Wp, _ = x.shape
481
+
482
+ # cyclic shift
483
+ if self.shift_size > 0:
484
+ shifted_x = torch.roll(x,
485
+ shifts=(-self.shift_size, -self.shift_size),
486
+ dims=(1, 2))
487
+ attn_mask = mask_matrix
488
+ else:
489
+ shifted_x = x
490
+ attn_mask = None
491
+
492
+ # partition windows
493
+ x_windows = window_partition(
494
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
495
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
496
+ C) # nW*B, window_size*window_size, C
497
+
498
+ # W-MSA/SW-MSA
499
+ attn_windows = self.attn(
500
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
501
+
502
+ # merge windows
503
+ attn_windows = attn_windows.view(-1, self.window_size,
504
+ self.window_size, C)
505
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp,
506
+ Wp) # B H' W' C
507
+
508
+ # reverse cyclic shift
509
+ if self.shift_size > 0:
510
+ x = torch.roll(shifted_x,
511
+ shifts=(self.shift_size, self.shift_size),
512
+ dims=(1, 2))
513
+ else:
514
+ x = shifted_x
515
+
516
+ if pad_r > 0 or pad_b > 0:
517
+ x = x[:, :H, :W, :].contiguous()
518
+
519
+ x = x.view(B, H * W, C)
520
+
521
+ # FFN
522
+ x = shortcut + self.drop_path(x)
523
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
524
+
525
+ return x
526
+
527
+
528
+ class PatchMerging(nn.Module):
529
+ """ Patch Merging Layer
530
+
531
+ Args:
532
+ dim (int): Number of input channels.
533
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
534
+ """
535
+
536
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
537
+ super().__init__()
538
+ self.dim = dim
539
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
540
+ self.norm = norm_layer(4 * dim)
541
+
542
+ def forward(self, x, H, W):
543
+ """ Forward function.
544
+
545
+ Args:
546
+ x: Input feature, tensor size (B, H*W, C).
547
+ H, W: Spatial resolution of the input feature.
548
+ """
549
+ B, L, C = x.shape
550
+ assert L == H * W, "input feature has wrong size"
551
+
552
+ x = x.view(B, H, W, C)
553
+
554
+ # padding
555
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
556
+ if pad_input:
557
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
558
+
559
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
560
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
561
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
562
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
563
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
564
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
565
+
566
+ x = self.norm(x)
567
+ x = self.reduction(x)
568
+
569
+ return x
570
+
571
+
572
+ class BasicLayer(nn.Module):
573
+ """ A basic Swin Transformer layer for one stage.
574
+
575
+ Args:
576
+ dim (int): Number of feature channels
577
+ depth (int): Depths of this stage.
578
+ num_heads (int): Number of attention head.
579
+ window_size (int): Local window size. Default: 7.
580
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
581
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
582
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
583
+ drop (float, optional): Dropout rate. Default: 0.0
584
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
585
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
586
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
587
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
588
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
589
+ """
590
+
591
+ def __init__(self,
592
+ dim,
593
+ depth,
594
+ num_heads,
595
+ window_size=7,
596
+ mlp_ratio=4.,
597
+ qkv_bias=True,
598
+ qk_scale=None,
599
+ drop=0.,
600
+ attn_drop=0.,
601
+ drop_path=0.,
602
+ norm_layer=nn.LayerNorm,
603
+ downsample=None,
604
+ use_checkpoint=False):
605
+ super().__init__()
606
+ self.window_size = window_size
607
+ self.shift_size = window_size // 2
608
+ self.depth = depth
609
+ self.use_checkpoint = use_checkpoint
610
+
611
+ # build blocks
612
+ self.blocks = nn.ModuleList([
613
+ SwinTransformerBlock(dim=dim,
614
+ num_heads=num_heads,
615
+ window_size=window_size,
616
+ shift_size=0 if
617
+ (i % 2 == 0) else window_size // 2,
618
+ mlp_ratio=mlp_ratio,
619
+ qkv_bias=qkv_bias,
620
+ qk_scale=qk_scale,
621
+ drop=drop,
622
+ attn_drop=attn_drop,
623
+ drop_path=drop_path[i] if isinstance(
624
+ drop_path, list) else drop_path,
625
+ norm_layer=norm_layer) for i in range(depth)
626
+ ])
627
+
628
+ # patch merging layer
629
+ if downsample is not None:
630
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
631
+ else:
632
+ self.downsample = None
633
+
634
+ def forward(self, x, H, W):
635
+ """ Forward function.
636
+
637
+ Args:
638
+ x: Input feature, tensor size (B, H*W, C).
639
+ H, W: Spatial resolution of the input feature.
640
+ """
641
+
642
+ # calculate attention mask for SW-MSA
643
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
644
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
645
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
646
+ h_slices = (slice(0, -self.window_size),
647
+ slice(-self.window_size,
648
+ -self.shift_size), slice(-self.shift_size, None))
649
+ w_slices = (slice(0, -self.window_size),
650
+ slice(-self.window_size,
651
+ -self.shift_size), slice(-self.shift_size, None))
652
+ cnt = 0
653
+ for h in h_slices:
654
+ for w in w_slices:
655
+ img_mask[:, h, w, :] = cnt
656
+ cnt += 1
657
+
658
+ mask_windows = window_partition(
659
+ img_mask, self.window_size) # nW, window_size, window_size, 1
660
+ mask_windows = mask_windows.view(-1,
661
+ self.window_size * self.window_size)
662
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
663
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
664
+ float(-100.0)).masked_fill(
665
+ attn_mask == 0, float(0.0))
666
+
667
+ for blk in self.blocks:
668
+ blk.H, blk.W = H, W
669
+ if self.use_checkpoint:
670
+ x = checkpoint.checkpoint(blk, x, attn_mask)
671
+ else:
672
+ x = blk(x, attn_mask)
673
+ if self.downsample is not None:
674
+ x_down = self.downsample(x, H, W)
675
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
676
+ return x, H, W, x_down, Wh, Ww
677
+ else:
678
+ return x, H, W, x, H, W
679
+
680
+
681
+ class PatchEmbed(nn.Module):
682
+ """ Image to Patch Embedding
683
+
684
+ Args:
685
+ patch_size (int): Patch token size. Default: 4.
686
+ in_chans (int): Number of input image channels. Default: 3.
687
+ embed_dim (int): Number of linear projection output channels. Default: 96.
688
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
689
+ """
690
+
691
+ def __init__(self,
692
+ patch_size=4,
693
+ in_chans=3,
694
+ embed_dim=96,
695
+ norm_layer=None):
696
+ super().__init__()
697
+ patch_size = to_2tuple(patch_size)
698
+ self.patch_size = patch_size
699
+
700
+ self.in_chans = in_chans
701
+ self.embed_dim = embed_dim
702
+
703
+ self.proj = nn.Conv2d(in_chans,
704
+ embed_dim,
705
+ kernel_size=patch_size,
706
+ stride=patch_size)
707
+ if norm_layer is not None:
708
+ self.norm = norm_layer(embed_dim)
709
+ else:
710
+ self.norm = None
711
+
712
+ def forward(self, x):
713
+ """Forward function."""
714
+ # padding
715
+ _, _, H, W = x.size()
716
+ if W % self.patch_size[1] != 0:
717
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
718
+ if H % self.patch_size[0] != 0:
719
+ x = F.pad(x,
720
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
721
+
722
+ x = self.proj(x) # B C Wh Ww
723
+ if self.norm is not None:
724
+ Wh, Ww = x.size(2), x.size(3)
725
+ x = x.flatten(2).transpose(1, 2)
726
+ x = self.norm(x)
727
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
728
+
729
+ return x
730
+
731
+
732
+ class SwinTransformer(nn.Module):
733
+ """ Swin Transformer backbone.
734
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
735
+ https://arxiv.org/pdf/2103.14030
736
+
737
+ Args:
738
+ pretrain_img_size (int): Input image size for training the pretrained model,
739
+ used in absolute postion embedding. Default 224.
740
+ patch_size (int | tuple(int)): Patch size. Default: 4.
741
+ in_chans (int): Number of input image channels. Default: 3.
742
+ embed_dim (int): Number of linear projection output channels. Default: 96.
743
+ depths (tuple[int]): Depths of each Swin Transformer stage.
744
+ num_heads (tuple[int]): Number of attention head of each stage.
745
+ window_size (int): Window size. Default: 7.
746
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
747
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
748
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
749
+ drop_rate (float): Dropout rate.
750
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
751
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
752
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
753
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
754
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
755
+ out_indices (Sequence[int]): Output from which stages.
756
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
757
+ -1 means not freezing any parameters.
758
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
759
+ """
760
+
761
+ def __init__(self,
762
+ pretrain_img_size=224,
763
+ patch_size=4,
764
+ in_chans=3,
765
+ embed_dim=96,
766
+ depths=[2, 2, 6, 2],
767
+ num_heads=[3, 6, 12, 24],
768
+ window_size=7,
769
+ mlp_ratio=4.,
770
+ qkv_bias=True,
771
+ qk_scale=None,
772
+ drop_rate=0.,
773
+ attn_drop_rate=0.,
774
+ drop_path_rate=0.2,
775
+ norm_layer=nn.LayerNorm,
776
+ ape=False,
777
+ patch_norm=True,
778
+ out_indices=(0, 1, 2, 3),
779
+ frozen_stages=-1,
780
+ use_checkpoint=False):
781
+ super().__init__()
782
+
783
+ self.pretrain_img_size = pretrain_img_size
784
+ self.num_layers = len(depths)
785
+ self.embed_dim = embed_dim
786
+ self.ape = ape
787
+ self.patch_norm = patch_norm
788
+ self.out_indices = out_indices
789
+ self.frozen_stages = frozen_stages
790
+
791
+ # split image into non-overlapping patches
792
+ self.patch_embed = PatchEmbed(
793
+ patch_size=patch_size,
794
+ in_chans=in_chans,
795
+ embed_dim=embed_dim,
796
+ norm_layer=norm_layer if self.patch_norm else None)
797
+
798
+ # absolute position embedding
799
+ if self.ape:
800
+ pretrain_img_size = to_2tuple(pretrain_img_size)
801
+ patch_size = to_2tuple(patch_size)
802
+ patches_resolution = [
803
+ pretrain_img_size[0] // patch_size[0],
804
+ pretrain_img_size[1] // patch_size[1]
805
+ ]
806
+
807
+ self.absolute_pos_embed = nn.Parameter(
808
+ torch.zeros(1, embed_dim, patches_resolution[0],
809
+ patches_resolution[1]))
810
+ trunc_normal_(self.absolute_pos_embed, std=.02)
811
+
812
+ self.pos_drop = nn.Dropout(p=drop_rate)
813
+
814
+ # stochastic depth
815
+ dpr = [
816
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
817
+ ] # stochastic depth decay rule
818
+
819
+ # build layers
820
+ self.layers = nn.ModuleList()
821
+ for i_layer in range(self.num_layers):
822
+ layer = BasicLayer(
823
+ dim=int(embed_dim * 2**i_layer),
824
+ depth=depths[i_layer],
825
+ num_heads=num_heads[i_layer],
826
+ window_size=window_size,
827
+ mlp_ratio=mlp_ratio,
828
+ qkv_bias=qkv_bias,
829
+ qk_scale=qk_scale,
830
+ drop=drop_rate,
831
+ attn_drop=attn_drop_rate,
832
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
833
+ norm_layer=norm_layer,
834
+ downsample=PatchMerging if
835
+ (i_layer < self.num_layers - 1) else None,
836
+ use_checkpoint=use_checkpoint)
837
+ self.layers.append(layer)
838
+
839
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
840
+ self.num_features = num_features
841
+
842
+ # add a norm layer for each output
843
+ for i_layer in out_indices:
844
+ layer = norm_layer(num_features[i_layer])
845
+ layer_name = f'norm{i_layer}'
846
+ self.add_module(layer_name, layer)
847
+
848
+ self._freeze_stages()
849
+
850
+ def _freeze_stages(self):
851
+ if self.frozen_stages >= 0:
852
+ self.patch_embed.eval()
853
+ for param in self.patch_embed.parameters():
854
+ param.requires_grad = False
855
+
856
+ if self.frozen_stages >= 1 and self.ape:
857
+ self.absolute_pos_embed.requires_grad = False
858
+
859
+ if self.frozen_stages >= 2:
860
+ self.pos_drop.eval()
861
+ for i in range(0, self.frozen_stages - 1):
862
+ m = self.layers[i]
863
+ m.eval()
864
+ for param in m.parameters():
865
+ param.requires_grad = False
866
+
867
+ def init_weights(self, pretrained=None):
868
+ """Initialize the weights in backbone.
869
+
870
+ Args:
871
+ pretrained (str, optional): Path to pre-trained weights.
872
+ Defaults to None.
873
+ """
874
+
875
+ def _init_weights(m):
876
+ if isinstance(m, nn.Linear):
877
+ trunc_normal_(m.weight, std=.02)
878
+ if isinstance(m, nn.Linear) and m.bias is not None:
879
+ nn.init.constant_(m.bias, 0)
880
+ elif isinstance(m, nn.LayerNorm):
881
+ nn.init.constant_(m.bias, 0)
882
+ nn.init.constant_(m.weight, 1.0)
883
+
884
+ if isinstance(pretrained, str):
885
+ self.apply(_init_weights)
886
+ load_checkpoint(self, pretrained, strict=False, logger=None)
887
+ elif pretrained is None:
888
+ self.apply(_init_weights)
889
+ else:
890
+ raise TypeError('pretrained must be a str or None')
891
+
892
+ def forward(self, x):
893
+ x = self.patch_embed(x)
894
+
895
+ Wh, Ww = x.size(2), x.size(3)
896
+ if self.ape:
897
+ # interpolate the position embedding to the corresponding size
898
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
899
+ size=(Wh, Ww),
900
+ mode='bicubic')
901
+ x = (x + absolute_pos_embed) # B Wh*Ww C
902
+
903
+ outs = [x.contiguous()]
904
+ x = x.flatten(2).transpose(1, 2)
905
+ x = self.pos_drop(x)
906
+ for i in range(self.num_layers):
907
+ layer = self.layers[i]
908
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
909
+
910
+ if i in self.out_indices:
911
+ norm_layer = getattr(self, f'norm{i}')
912
+ x_out = norm_layer(x_out)
913
+
914
+ out = x_out.view(-1, H, W,
915
+ self.num_features[i]).permute(0, 3, 1,
916
+ 2).contiguous()
917
+ outs.append(out)
918
+
919
+ return tuple(outs)
920
+
921
+ def train(self, mode=True):
922
+ """Convert the model into training mode while keeping frozen stages frozen."""
923
+ super(SwinTransformer, self).train(mode)
924
+ self._freeze_stages()
925
+
926
+
927
+ class PositionEmbeddingSine:
928
+
929
+ def __init__(self,
930
+ num_pos_feats=64,
931
+ temperature=10000,
932
+ normalize=False,
933
+ scale=None):
934
+ super().__init__()
935
+ self.num_pos_feats = num_pos_feats
936
+ self.temperature = temperature
937
+ self.normalize = normalize
938
+ if scale is not None and normalize is False:
939
+ raise ValueError("normalize should be True if scale is passed")
940
+ if scale is None:
941
+ scale = 2 * math.pi
942
+ self.scale = scale
943
+ self.dim_t = torch.arange(0,
944
+ self.num_pos_feats,
945
+ dtype=torch_dtype,
946
+ device=torch_device)
947
+
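+ # Added commentary (not in the original MVANet source): this is the
+ # DETR-style sine embedding. For a normalized position p and channel
+ # pair index i,
+ #     PE(p, 2i)   = sin(p / temperature^(2i / num_pos_feats))
+ #     PE(p, 2i+1) = cos(p / temperature^(2i / num_pos_feats))
+ # so low channels oscillate quickly with position and high channels slowly.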
948
+ def __call__(self, b, h, w):
949
+ mask = torch.zeros([b, h, w], dtype=torch.bool, device=torch_device)
950
+ assert mask is not None
951
+ not_mask = ~mask
952
+ y_embed = not_mask.cumsum(dim=1, dtype=torch_dtype)
953
+ x_embed = not_mask.cumsum(dim=2, dtype=torch_dtype)
954
+ if self.normalize:
955
+ eps = 1e-6
956
+ y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) *
957
+ self.scale).to(device=torch_device, dtype=torch_dtype)
958
+ x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) *
959
+ self.scale).to(device=torch_device, dtype=torch_dtype)
960
+
961
+ dim_t = self.temperature**(2 * (self.dim_t // 2) / self.num_pos_feats)
962
+
963
+ pos_x = x_embed[:, :, :, None] / dim_t
964
+ pos_y = y_embed[:, :, :, None] / dim_t
965
+ pos_x = torch.stack(
966
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
967
+ dim=4).flatten(3)
968
+ pos_y = torch.stack(
969
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
970
+ dim=4).flatten(3)
971
+ return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
972
+
973
+
974
+ class MCLM(nn.Module):
975
+
976
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
977
+ super(MCLM, self).__init__()
978
+ self.attention = nn.ModuleList([
979
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
980
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
981
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
982
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
983
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
984
+ ])
985
+
986
+ self.linear1 = nn.Linear(d_model, d_model * 2)
987
+ self.linear2 = nn.Linear(d_model * 2, d_model)
988
+ self.linear3 = nn.Linear(d_model, d_model * 2)
989
+ self.linear4 = nn.Linear(d_model * 2, d_model)
990
+ self.norm1 = nn.LayerNorm(d_model)
991
+ self.norm2 = nn.LayerNorm(d_model)
992
+ self.dropout = nn.Dropout(0.1)
993
+ self.dropout1 = nn.Dropout(0.1)
994
+ self.dropout2 = nn.Dropout(0.1)
995
+ self.activation = get_activation_fn('relu')
996
+ self.pool_ratios = pool_ratios
997
+ self.p_poses = []
998
+ self.g_pos = None
999
+ self.positional_encoding = PositionEmbeddingSine(
1000
+ num_pos_feats=d_model // 2, normalize=True)
1001
+
1002
+ def forward(self, l, g):
1003
+ """
1004
+ l: 4,c,h,w
1005
+ g: 1,c,h,w
1006
+ """
1007
+ b, c, h, w = l.size()
1008
+ # 4,c,h,w -> 1,c,2h,2w
1009
+ concated_locs = rearrange(l,
1010
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1011
+ hg=2,
1012
+ wg=2)
1013
+
1014
+ pools = []
1015
+ for pool_ratio in self.pool_ratios:
1016
+ # b,c,h,w
1017
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1018
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1019
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1020
+ if self.g_pos is None:
1021
+ pos_emb = self.positional_encoding(pool.shape[0],
1022
+ pool.shape[2],
1023
+ pool.shape[3])
1024
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1025
+ self.p_poses.append(pos_emb)
1026
+ pools = torch.cat(pools, 0)
1027
+ if self.g_pos is None:
1028
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1029
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2],
1030
+ g.shape[3])
1031
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1032
+
1033
+ # attention between the global feature (q) & the multi-scale pooled local patches (k, v)
1034
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1035
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1036
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1037
+ g_hw_b_c = self.norm1(g_hw_b_c)
1038
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1039
+ self.linear2(
1040
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1041
+ g_hw_b_c = self.norm2(g_hw_b_c)
1042
+
1043
+ # attention between the original local patches (q) & the refreshed global feature (k, v)
1044
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1045
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1046
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1047
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1048
+ ng=2,
1049
+ nw=2)
1050
+ outputs_re = []
1051
+ for i, (_l, _g) in enumerate(
1052
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1053
+ outputs_re.append(self.attention[i + 1](_l, _g,
1054
+ _g)[0]) # (h w) 1 c
1055
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1056
+
1057
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1058
+ l_hw_b_c = self.norm1(l_hw_b_c)
1059
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1060
+ self.linear4(
1061
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1062
+ l_hw_b_c = self.norm2(l_hw_b_c)
1063
+
1064
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1065
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
1066
+
1067
+
1068
+ class inf_MCLM(nn.Module):
1069
+
1070
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
1071
+ super(inf_MCLM, self).__init__()
1072
+ self.attention = nn.ModuleList([
1073
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1074
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1075
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1076
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1077
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1078
+ ])
1079
+
1080
+ self.linear1 = nn.Linear(d_model, d_model * 2)
1081
+ self.linear2 = nn.Linear(d_model * 2, d_model)
1082
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1083
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1084
+ self.norm1 = nn.LayerNorm(d_model)
1085
+ self.norm2 = nn.LayerNorm(d_model)
1086
+ self.dropout = nn.Dropout(0.1)
1087
+ self.dropout1 = nn.Dropout(0.1)
1088
+ self.dropout2 = nn.Dropout(0.1)
1089
+ self.activation = get_activation_fn('relu')
1090
+ self.pool_ratios = pool_ratios
1091
+ self.p_poses = []
1092
+ self.g_pos = None
1093
+ self.positional_encoding = PositionEmbeddingSine(
1094
+ num_pos_feats=d_model // 2, normalize=True)
1095
+
1096
+ def forward(self, l, g):
1097
+ """
1098
+ l: 4,c,h,w
1099
+ g: 1,c,h,w
1100
+ """
1101
+ b, c, h, w = l.size()
1102
+ # 4,c,h,w -> 1,c,2h,2w
1103
+ concated_locs = rearrange(l,
1104
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1105
+ hg=2,
1106
+ wg=2)
1107
+ self.p_poses = []
1108
+ pools = []
1109
+ for pool_ratio in self.pool_ratios:
1110
+ # b,c,h,w
1111
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1112
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1113
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1114
+ # if self.g_pos is None:
1115
+ pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2],
1116
+ pool.shape[3])
1117
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1118
+ self.p_poses.append(pos_emb)
1119
+ pools = torch.cat(pools, 0)
1120
+ # if self.g_pos is None:
1121
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1122
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
1123
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1124
+
1125
+ # attention between the global feature (q) & the multi-scale pooled local patches (k, v)
1126
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1127
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1128
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1129
+ g_hw_b_c = self.norm1(g_hw_b_c)
1130
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1131
+ self.linear2(
1132
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1133
+ g_hw_b_c = self.norm2(g_hw_b_c)
1134
+
1135
+ # attention between the original local patches (q) & the refreshed global feature (k, v)
1136
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1137
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1138
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1139
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1140
+ ng=2,
1141
+ nw=2)
1142
+ outputs_re = []
1143
+ for i, (_l, _g) in enumerate(
1144
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1145
+ outputs_re.append(self.attention[i + 1](_l, _g,
1146
+ _g)[0]) # (h w) 1 c
1147
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1148
+
1149
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1150
+ l_hw_b_c = self.norm1(l_hw_b_c)
1151
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1152
+ self.linear4(
1153
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1154
+ l_hw_b_c = self.norm2(l_hw_b_c)
1155
+
1156
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1157
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
1158
+
1159
+
1160
+ class MCRM(nn.Module):
1161
+
1162
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1163
+ super(MCRM, self).__init__()
1164
+ self.attention = nn.ModuleList([
1165
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1166
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1167
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1168
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1169
+ ])
1170
+
1171
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1172
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1173
+ self.norm1 = nn.LayerNorm(d_model)
1174
+ self.norm2 = nn.LayerNorm(d_model)
1175
+ self.dropout = nn.Dropout(0.1)
1176
+ self.dropout1 = nn.Dropout(0.1)
1177
+ self.dropout2 = nn.Dropout(0.1)
1178
+ self.sigmoid = nn.Sigmoid()
1179
+ self.activation = get_activation_fn('relu')
1180
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1181
+ self.pool_ratios = pool_ratios
1182
+ self.positional_encoding = PositionEmbeddingSine(
1183
+ num_pos_feats=d_model // 2, normalize=True)
1184
+
1185
+ def forward(self, x):
1186
+ b, c, h, w = x.size()
1187
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1188
+ # b(4),c,h,w
1189
+ patched_glb = rearrange(glb,
1190
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1191
+ hg=2,
1192
+ wg=2)
1193
+
1194
+ # generate token attention map
1195
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1196
+ token_attention_map = F.interpolate(token_attention_map,
1197
+ size=patches2image(loc).shape[-2:],
1198
+ mode='nearest')
1199
+ loc = loc * rearrange(token_attention_map,
1200
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1201
+ hg=2,
1202
+ wg=2)
1203
+ pools = []
1204
+ for pool_ratio in self.pool_ratios:
1205
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1206
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1207
+ pools.append(rearrange(pool,
1208
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1209
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1210
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1211
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1212
+ outputs = []
1213
+ for i, q in enumerate(
1214
+ loc_.unbind(dim=0)): # traverse all local patches
1215
+ # np*hw,1,c
1216
+ v = pools[i]
1217
+ k = v
1218
+ outputs.append(self.attention[i](q, k, v)[0])
1219
+ outputs = torch.cat(outputs, 1)
1220
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1221
+ src = self.norm1(src)
1222
+ src = src + self.dropout2(
1223
+ self.linear4(
1224
+ self.dropout(self.activation(self.linear3(src)).clone())))
1225
+ src = self.norm2(src)
1226
+
1227
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed local patches
1228
+ glb = glb + F.interpolate(patches2image(src),
1229
+ size=glb.shape[-2:],
1230
+ mode='nearest') # refreshed global feature
1231
+ return torch.cat((src, glb), 0), token_attention_map
1232
+
1233
+
1234
+ class inf_MCRM(nn.Module):
1235
+
1236
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1237
+ super(inf_MCRM, self).__init__()
1238
+ self.attention = nn.ModuleList([
1239
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1240
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1241
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1242
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1243
+ ])
1244
+
1245
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1246
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1247
+ self.norm1 = nn.LayerNorm(d_model)
1248
+ self.norm2 = nn.LayerNorm(d_model)
1249
+ self.dropout = nn.Dropout(0.1)
1250
+ self.dropout1 = nn.Dropout(0.1)
1251
+ self.dropout2 = nn.Dropout(0.1)
1252
+ self.sigmoid = nn.Sigmoid()
1253
+ self.activation = get_activation_fn('relu')
1254
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1255
+ self.pool_ratios = pool_ratios
1256
+ self.positional_encoding = PositionEmbeddingSine(
1257
+ num_pos_feats=d_model // 2, normalize=True)
1258
+
1259
+ def forward(self, x):
1260
+ b, c, h, w = x.size()
1261
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1262
+ # b(4),c,h,w
1263
+ patched_glb = rearrange(glb,
1264
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1265
+ hg=2,
1266
+ wg=2)
1267
+
1268
+ # generate token attention map
1269
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1270
+ token_attention_map = F.interpolate(token_attention_map,
1271
+ size=patches2image(loc).shape[-2:],
1272
+ mode='nearest')
1273
+ loc = loc * rearrange(token_attention_map,
1274
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1275
+ hg=2,
1276
+ wg=2)
1277
+ pools = []
1278
+ for pool_ratio in self.pool_ratios:
1279
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1280
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1281
+ pools.append(rearrange(pool,
1282
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1283
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1284
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1285
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1286
+ outputs = []
1287
+ for i, q in enumerate(
1288
+ loc_.unbind(dim=0)): # traverse all local patches
1289
+ # np*hw,1,c
1290
+ v = pools[i]
1291
+ k = v
1292
+ outputs.append(self.attention[i](q, k, v)[0])
1293
+ outputs = torch.cat(outputs, 1)
1294
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1295
+ src = self.norm1(src)
1296
+ src = src + self.dropout2(
1297
+ self.linear4(
1298
+ self.dropout(self.activation(self.linear3(src)).clone())))
1299
+ src = self.norm2(src)
1300
+
1301
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed local patches
1302
+ glb = glb + F.interpolate(patches2image(src),
1303
+ size=glb.shape[-2:],
1304
+ mode='nearest') # refreshed global feature
1305
+ return torch.cat((src, glb), 0)
1306
+
1307
+
1308
+ # model for single-scale training
1309
+ class MVANet(nn.Module):
1310
+
1311
+ def __init__(self):
1312
+ super().__init__()
1313
+ self.backbone = SwinB(pretrained=True)
1314
+ emb_dim = 128
1315
+ self.sideout5 = nn.Sequential(
1316
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1317
+ self.sideout4 = nn.Sequential(
1318
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1319
+ self.sideout3 = nn.Sequential(
1320
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1321
+ self.sideout2 = nn.Sequential(
1322
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1323
+ self.sideout1 = nn.Sequential(
1324
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1325
+
1326
+ self.output5 = make_cbr(1024, emb_dim)
1327
+ self.output4 = make_cbr(512, emb_dim)
1328
+ self.output3 = make_cbr(256, emb_dim)
1329
+ self.output2 = make_cbr(128, emb_dim)
1330
+ self.output1 = make_cbr(128, emb_dim)
1331
+
1332
+ self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
1333
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1334
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1335
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1336
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1337
+ self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
1338
+ self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
1339
+ self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
1340
+ self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
1341
+
1342
+ self.insmask_head = nn.Sequential(
1343
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1344
+ nn.BatchNorm2d(384), nn.PReLU(),
1345
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1346
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1347
+
1348
+ self.shallow = nn.Sequential(
1349
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1350
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1351
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1352
+ self.output = nn.Sequential(
1353
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1354
+
1355
+ for m in self.modules():
1356
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1357
+ m.inplace = True
1358
+
1359
+ def forward(self, x):
1360
+ x = x.to(dtype=torch_dtype, device=torch_device)
1361
+ shallow = self.shallow(x)
1362
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
1363
+ loc = image2patches(x)
1364
+ input = torch.cat((loc, glb), dim=0)
1365
+ feature = self.backbone(input)
1366
+ e5 = self.output5(feature[4]) # (5,128,16,16)
1367
+ e4 = self.output4(feature[3]) # (5,128,32,32)
1368
+ e3 = self.output3(feature[2]) # (5,128,64,64)
1369
+ e2 = self.output2(feature[1]) # (5,128,128,128)
1370
+ e1 = self.output1(feature[0]) # (5,128,128,128)
1371
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
1372
+ e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
1373
+
1374
+ e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
1375
+ e4 = self.conv4(e4)
1376
+ e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
1377
+ e3 = self.conv3(e3)
1378
+ e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
1379
+ e2 = self.conv2(e2)
1380
+ e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
1381
+ e1 = self.conv1(e1)
1382
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
1383
+ output1_cat = patches2image(loc_e1) # (1,128,256,256)
1384
+ # add glb feat in
1385
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
1386
+ # merge
1387
+ final_output = self.insmask_head(output1_cat) # (1,128,256,256)
1388
+ # shallow feature merge
1389
+ final_output = final_output + resize_as(shallow, final_output)
1390
+ final_output = self.upsample1(rescale_to(final_output))
1391
+ final_output = rescale_to(final_output +
1392
+ resize_as(shallow, final_output))
1393
+ final_output = self.upsample2(final_output)
1394
+ final_output = self.output(final_output)
1395
+ ####
1396
+ sideout5 = self.sideout5(e5).to(dtype=torch_dtype, device=torch_device)
1397
+ sideout4 = self.sideout4(e4)
1398
+ sideout3 = self.sideout3(e3)
1399
+ sideout2 = self.sideout2(e2)
1400
+ sideout1 = self.sideout1(e1)
1401
+ ####### global sideouts #######
1402
+ glb5 = self.sideout5(glb_e5)
1403
+ glb4 = sideout4[-1, :, :, :].unsqueeze(0)
1404
+ glb3 = sideout3[-1, :, :, :].unsqueeze(0)
1405
+ glb2 = sideout2[-1, :, :, :].unsqueeze(0)
1406
+ glb1 = sideout1[-1, :, :, :].unsqueeze(0)
1407
+ ####### concat 4 to 1 #######
1408
+ sideout1 = patches2image(sideout1[:-1]).to(dtype=torch_dtype,
1409
+ device=torch_device)
1410
+ sideout2 = patches2image(sideout2[:-1]).to(
1411
+ dtype=torch_dtype,
1412
+ device=torch_device) ####(5,c,h,w) -> (1 c 2h,2w)
1413
+ sideout3 = patches2image(sideout3[:-1]).to(dtype=torch_dtype,
1414
+ device=torch_device)
1415
+ sideout4 = patches2image(sideout4[:-1]).to(dtype=torch_dtype,
1416
+ device=torch_device)
1417
+ sideout5 = patches2image(sideout5[:-1]).to(dtype=torch_dtype,
1418
+ device=torch_device)
1419
+ if self.training:
1420
+ return sideout5, sideout4, sideout3, sideout2, sideout1, final_output, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1
1421
+ else:
1422
+ return final_output
1423
+
1424
+
1425
+ # model for multi-scale testing
1426
+ class inf_MVANet(nn.Module):
1427
+
1428
+ def __init__(self):
1429
+ super().__init__()
1430
+ # self.backbone = SwinB(pretrained=True)
1431
+ self.backbone = SwinB(pretrained=False)
1432
+
1433
+ emb_dim = 128
1434
+ self.output5 = make_cbr(1024, emb_dim)
1435
+ self.output4 = make_cbr(512, emb_dim)
1436
+ self.output3 = make_cbr(256, emb_dim)
1437
+ self.output2 = make_cbr(128, emb_dim)
1438
+ self.output1 = make_cbr(128, emb_dim)
1439
+
1440
+ self.multifieldcrossatt = inf_MCLM(emb_dim, 1, [1, 4, 8])
1441
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1442
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1443
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1444
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1445
+ self.dec_blk1 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1446
+ self.dec_blk2 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1447
+ self.dec_blk3 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1448
+ self.dec_blk4 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1449
+
1450
+ self.insmask_head = nn.Sequential(
1451
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1452
+ nn.BatchNorm2d(384), nn.PReLU(),
1453
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1454
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1455
+
1456
+ self.shallow = nn.Sequential(
1457
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1458
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1459
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1460
+ self.output = nn.Sequential(
1461
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1462
+
1463
+ for m in self.modules():
1464
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1465
+ m.inplace = True
1466
+
1467
+ def forward(self, x):
1468
+ shallow = self.shallow(x)
1469
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
1470
+ loc = image2patches(x)
1471
+ input = torch.cat((loc, glb), dim=0)
1472
+ feature = self.backbone(input)
1473
+ e5 = self.output5(feature[4])
1474
+ e4 = self.output4(feature[3])
1475
+ e3 = self.output3(feature[2])
1476
+ e2 = self.output2(feature[1])
1477
+ e1 = self.output1(feature[0])
1478
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
1479
+ e5_cat = self.multifieldcrossatt(loc_e5, glb_e5)
1480
+
1481
+ e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
1482
+ e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
1483
+ e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
1484
+ e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))
1485
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
1486
+ # after the decoder, stitch the local patch features back into one full map and merge
1487
+ output1_cat = patches2image(loc_e1)
1488
+ # add glb feat in
1489
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
1490
+ # merge
1491
+ final_output = self.insmask_head(output1_cat)
1492
+ # shallow feature merge
1493
+ final_output = final_output + resize_as(shallow, final_output)
1494
+ final_output = self.upsample1(rescale_to(final_output))
1495
+ final_output = rescale_to(final_output +
1496
+ resize_as(shallow, final_output))
1497
+ final_output = self.upsample2(final_output)
1498
+ final_output = self.output(final_output)
1499
+ return final_output
1500
+ #+end_src
1501
+
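+ 
+ A quick sanity check of the tiling scheme the model assumes (a sketch only,
+ not tangled into the generated sources): =image2patches= splits a 1024x1024
+ input into four 512x512 local crops, =rescale_to(x, 0.5)= yields one 512x512
+ global view, and the backbone always receives the two stacked as a batch of 5.
+ #+begin_src python
+ # Sketch: verify the 4-local + 1-global batch layout the decoder blocks assume.
+ import torch
+ from einops import rearrange
+ 
+ x = torch.randn(1, 3, 1024, 1024)  # one RGB image
+ loc = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
+ glb = torch.nn.functional.interpolate(x, scale_factor=0.5, mode='bilinear')
+ batch = torch.cat((loc, glb), dim=0)
+ print(loc.shape, glb.shape, batch.shape)  # (4,3,512,512) (1,3,512,512) (5,3,512,512)
+ #+end_src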
1502
+ ** Function to load model
1503
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
1504
+ def mkdir_safe(out_path):
1505
+ if type(out_path) == str:
1506
+ if len(out_path) > 0:
1507
+ if not os.path.exists(out_path):
1508
+ os.mkdir(out_path)
1509
+
1510
+
1511
+ def get_model_path():
1512
+ import folder_paths
1513
+ from folder_paths import models_dir
1514
+
1515
+ path_file_model = models_dir
1516
+ mkdir_safe(out_path=path_file_model)
1517
+
1518
+ path_file_model = os.path.join(path_file_model, 'MVANet')
1519
+ mkdir_safe(out_path=path_file_model)
1520
+
1521
+ path_file_model = os.path.join(path_file_model, 'Model_80.pth')
1522
+
1523
+ return path_file_model
1524
+
1525
+
1526
+ def download_model(path):
1527
+ if not os.path.exists(path):
1528
+ wget.download(
1529
+ 'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/Model_80.pth',
1530
+ out=path)
1531
+
1532
+
1533
+ def load_model(model_checkpoint_path):
1534
+ download_model(path=model_checkpoint_path)
1535
+ torch.cuda.set_device(0)
1536
+
1537
+ net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
1538
+
1539
+ pretrained_dict = torch.load(model_checkpoint_path,
1540
+ map_location=torch_device)
1541
+
1542
+ model_dict = net.state_dict()
1543
+ pretrained_dict = {
1544
+ k: v
1545
+ for k, v in pretrained_dict.items() if k in model_dict
1546
+ }
1547
+ model_dict.update(pretrained_dict)
1548
+ net.load_state_dict(model_dict)
1549
+ net = net.to(dtype=torch_dtype, device=torch_device)
1550
+ net.eval()
1551
+ return net
1552
+ #+end_src
1553
+
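+ 
+ For reference, a minimal way to exercise the loader (a sketch; it assumes a
+ ComfyUI runtime, since =get_model_path()= imports =folder_paths= to resolve
+ the =models/MVANet/Model_80.pth= location and downloads the checkpoint if it
+ is missing):
+ #+begin_src python
+ # Sketch: load the network once, keep it in eval mode, and reuse it.
+ net = load_model(get_model_path())
+ print(sum(p.numel() for p in net.parameters()), 'parameters loaded')
+ #+end_src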
1554
+ ** Function for modular inference CV
1555
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
1556
+ def do_infer_tensor2tensor(img, net):
1557
+
1558
+ img_transform = transforms.Compose(
1559
+ [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
1560
+
1561
+ h_, w_ = img.shape[1], img.shape[2]
1562
+
1563
+ with torch.no_grad():
1564
+
1565
+ img = rearrange(img, 'B H W C -> B C H W')
1566
+
1567
+ img_resize = torch.nn.functional.interpolate(input=img,
1568
+ size=(1024, 1024),
1569
+ mode='bicubic',
1570
+ antialias=True)
1571
+
1572
+ img_var = img_transform(img_resize)
1573
+ img_var = Variable(img_var)
1574
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1575
+
1576
+ mask = []
1577
+
1578
+ mask.append(net(img_var))
1579
+
1580
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1581
+ prediction = prediction.sigmoid()
1582
+
1583
+ prediction = torch.nn.functional.interpolate(input=prediction,
1584
+ size=(h_, w_),
1585
+ mode='bicubic',
1586
+ antialias=True)
1587
+
1588
+ prediction = prediction.squeeze(0)
1589
+ prediction = prediction.clamp(0, 1)
1590
+ prediction = prediction.detach()
1591
+ prediction = prediction.to(dtype=torch.float32, device='cpu')
1592
+
1593
+ return prediction
1594
+ #+end_src
1595
+
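+ 
+ A usage sketch (assumptions flagged inline): ComfyUI supplies images as float
+ tensors shaped =B H W C= with values in [0, 1], which is the layout this
+ function rearranges to =B C H W= and resizes to 1024x1024 before running the
+ network.
+ #+begin_src python
+ # Sketch: run the network on a dummy ComfyUI-style image tensor.
+ dummy = torch.rand(1, 768, 512, 3)  # B H W C, values in [0, 1] (assumed layout)
+ net = load_model(get_model_path())
+ mask = do_infer_tensor2tensor(img=dummy, net=net)
+ print(mask.shape)  # expected (1, 768, 512): a single-channel mask on the CPU
+ #+end_src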
1596
+ ** Comfyui wrapper classes
1597
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
1598
+ class load_MVANet_Model:
1599
+
1600
+ def __init__(self):
1601
+ pass
1602
+
1603
+ @classmethod
1604
+ def INPUT_TYPES(s):
1605
+ return {
1606
+ "required": {},
1607
+ }
1608
+
1609
+ RETURN_TYPES = ("MVANet_Model", )
1610
+ FUNCTION = "test"
1611
+ CATEGORY = "MVANet"
1612
+
1613
+ def test(self):
1614
+ return (load_model(get_model_path()), )
1615
+
1616
+
1617
+ class run_MVANet_inference:
1618
+
1619
+ def __init__(self):
1620
+ pass
1621
+
1622
+ @classmethod
1623
+ def INPUT_TYPES(s):
1624
+ return {
1625
+ "required": {
1626
+ "image": ("IMAGE", ),
1627
+ "MVANet_Model": ("MVANet_Model", ),
1628
+ },
1629
+ }
1630
+
1631
+ RETURN_TYPES = ("MASK", )
1632
+ FUNCTION = "test"
1633
+ CATEGORY = "MVANet"
1634
+
1635
+ def test(
1636
+ self,
1637
+ image,
1638
+ MVANet_Model,
1639
+ ):
1640
+ ret = do_infer_tensor2tensor(img=image, net=MVANet_Model)
1641
+
1642
+ return (ret, )
1643
+ #+end_src
1644
+
1645
+ ** MVANet_inference execute
1646
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
1647
+ NODE_CLASS_MAPPINGS = {
1648
+ "load_MVANet_Model": load_MVANet_Model,
1649
+ "run_MVANet_inference": run_MVANet_inference
1650
+ }
1651
+
1652
+ NODE_DISPLAY_NAME_MAPPINGS = {
1653
+ "load_MVANet_Model": "load MVANet Model",
1654
+ "run_MVANet_inference": "run_MVANet_inference"
1655
+ }
1656
+ #+end_src
1657
+
1658
+ ** MVANet_inference unify
1659
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.unify.sh
1660
+ . "${HOME}/dbnew.sh"
1661
+
1662
+ (
1663
+ echo '#!/usr/bin/python3'
1664
+ cat \
1665
+ './MVANet_inference.import.py' \
1666
+ './MVANet_inference.function.py' \
1667
+ './MVANet_inference.class.py' \
1668
+ './MVANet_inference.execute.py' \
1669
+ | expand | yapf3 \
1670
+ | grep -v '#!/usr/bin/python3' \
1671
+ ;
1672
+ ) > './MVANet_inference.py' \
1673
+ ;
1674
+
1675
+ cp './MVANet_inference.py' '__init__.py'
1676
+ #+end_src
1677
+
1678
+ * WORK SPACE
1679
+
1680
+ ** elisp
1681
+ #+begin_src elisp
1682
+ (save-buffer)
1683
+ (org-babel-tangle)
1684
+ (shell-command "./MVANet_inference.unify.sh")
1685
+ #+end_src
1686
+
1687
+ #+RESULTS:
1688
+ : 0
1689
+
1690
+ ** sh
1691
+ #+begin_src sh :shebang #!/bin/sh :results output
1692
+ realpath .
1693
+ cd /home/asd/GITHUB/aravind-h-v/dreambooth_experiments/MVANet
1694
+ #+end_src
ComfyUI_MVANet/__init__.py ADDED
@@ -0,0 +1,1548 @@
1
+ #!/usr/bin/python3
2
+ import os
3
+ import sys
4
+
5
+ HOME_DIR = os.environ.get('HOME', '/root')
6
+ MVANET_SOURCE_DIR = HOME_DIR + '/GITHUB/qianyu-dlut/MVANet'
7
+ finetuned_MVANet_model_path = MVANET_SOURCE_DIR + '/model/Model_80.pth'
8
+ pretrained_SwinB_model_path = MVANET_SOURCE_DIR + '/model/swin_base_patch4_window12_384_22kto1k.pth'
9
+
10
+ import math
11
+ import numpy as np
12
+ import cv2
13
+ import wget
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.utils.checkpoint as checkpoint
19
+ from torch.autograd import Variable
20
+ from torch import nn
21
+ from torchvision import transforms
22
+
23
+ from einops import rearrange
24
+
25
+ from timm.models import load_checkpoint
26
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
27
+
28
+ torch_device = 'cuda'
29
+ torch_dtype = torch.float16
30
+
31
+
32
+ def check_mkdir(dir_name):
33
+ if not os.path.isdir(dir_name):
34
+ os.makedirs(dir_name)
35
+
36
+
37
+ def SwinT(pretrained=True):
38
+ model = SwinTransformer(embed_dim=96,
39
+ depths=[2, 2, 6, 2],
40
+ num_heads=[3, 6, 12, 24],
41
+ window_size=7)
42
+ if pretrained is True:
43
+ model.load_state_dict(torch.load(
44
+ 'data/backbone_ckpt/swin_tiny_patch4_window7_224.pth',
45
+ map_location='cpu')['model'],
46
+ strict=False)
47
+
48
+ return model
49
+
50
+
51
+ def SwinS(pretrained=True):
52
+ model = SwinTransformer(embed_dim=96,
53
+ depths=[2, 2, 18, 2],
54
+ num_heads=[3, 6, 12, 24],
55
+ window_size=7)
56
+ if pretrained is True:
57
+ model.load_state_dict(torch.load(
58
+ 'data/backbone_ckpt/swin_small_patch4_window7_224.pth',
59
+ map_location='cpu')['model'],
60
+ strict=False)
61
+
62
+ return model
63
+
64
+
65
+ def SwinB(pretrained=True):
66
+ model = SwinTransformer(embed_dim=128,
67
+ depths=[2, 2, 18, 2],
68
+ num_heads=[4, 8, 16, 32],
69
+ window_size=12)
70
+ if pretrained is True:
71
+ import os
72
+ model.load_state_dict(torch.load(pretrained_SwinB_model_path,
73
+ map_location='cpu')['model'],
74
+ strict=False)
75
+ return model
76
+
77
+
78
+ def SwinL(pretrained=True):
79
+ model = SwinTransformer(embed_dim=192,
80
+ depths=[2, 2, 18, 2],
81
+ num_heads=[6, 12, 24, 48],
82
+ window_size=12)
83
+ if pretrained is True:
84
+ model.load_state_dict(torch.load(
85
+ 'data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth',
86
+ map_location='cpu')['model'],
87
+ strict=False)
88
+
89
+ return model
90
+
91
+
92
+ def get_activation_fn(activation):
93
+ """Return an activation function given a string"""
94
+ if activation == "relu":
95
+ return F.relu
96
+ if activation == "gelu":
97
+ return F.gelu
98
+ if activation == "glu":
99
+ return F.glu
100
+ raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
101
+
102
+
103
+ def make_cbr(in_dim, out_dim):
104
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
105
+ nn.BatchNorm2d(out_dim), nn.PReLU())
106
+
107
+
108
+ def make_cbg(in_dim, out_dim):
109
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
110
+ nn.BatchNorm2d(out_dim), nn.GELU())
111
+
112
+
113
+ def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
114
+ return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
115
+
116
+
117
+ def resize_as(x, y, interpolation='bilinear'):
118
+ return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
119
+
120
+
121
+ def image2patches(x):
122
+ """b c (hg h) (wg w) -> (hg wg b) c h w"""
123
+ x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
124
+ return x
125
+
126
+
127
+ def patches2image(x):
128
+ """(hg wg b) c h w -> b c (hg h) (wg w)"""
129
+ x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
130
+ return x
131
+
132
+
133
+ def window_partition(x, window_size):
134
+ """
135
+ Args:
136
+ x: (B, H, W, C)
137
+ window_size (int): window size
138
+
139
+ Returns:
140
+ windows: (num_windows*B, window_size, window_size, C)
141
+ """
142
+ B, H, W, C = x.shape
143
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
144
+ C)
145
+ windows = x.permute(0, 1, 3, 2, 4,
146
+ 5).contiguous().view(-1, window_size, window_size, C)
147
+ return windows
148
+
149
+
150
+ def window_reverse(windows, window_size, H, W):
151
+ """
152
+ Args:
153
+ windows: (num_windows*B, window_size, window_size, C)
154
+ window_size (int): Window size
155
+ H (int): Height of image
156
+ W (int): Width of image
157
+
158
+ Returns:
159
+ x: (B, H, W, C)
160
+ """
161
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
162
+ x = windows.view(B, H // window_size, W // window_size, window_size,
163
+ window_size, -1)
164
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
165
+ return x
166
+
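+ # Added commentary: window_partition and window_reverse are exact inverses
+ # whenever H and W are multiples of window_size. For example, a (B, 56, 56, C)
+ # feature map with window_size=7 becomes (B*64, 7, 7, C) windows, and
+ # window_reverse maps those windows back to (B, 56, 56, C).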
167
+
168
+ def mkdir_safe(out_path):
169
+ if type(out_path) == str:
170
+ if len(out_path) > 0:
171
+ if not os.path.exists(out_path):
172
+ os.mkdir(out_path)
173
+
174
+
175
+ def get_model_path():
176
+ import folder_paths
177
+ from folder_paths import models_dir
178
+
179
+ path_file_model = models_dir
180
+ mkdir_safe(out_path=path_file_model)
181
+
182
+ path_file_model = os.path.join(path_file_model, 'MVANet')
183
+ mkdir_safe(out_path=path_file_model)
184
+
185
+ path_file_model = os.path.join(path_file_model, 'Model_80.pth')
186
+
187
+ return path_file_model
188
+
189
+
190
+ def download_model(path):
191
+ if not os.path.exists(path):
192
+ wget.download(
193
+ 'https://huggingface.co/aravindhv10/Self-Correction-Human-Parsing/resolve/main/checkpoints/Model_80.pth',
194
+ out=path)
195
+
196
+
197
+ def load_model(model_checkpoint_path):
198
+ download_model(path=model_checkpoint_path)
199
+ torch.cuda.set_device(0)
200
+
201
+ net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
202
+
203
+ pretrained_dict = torch.load(model_checkpoint_path,
204
+ map_location=torch_device)
205
+
206
+ model_dict = net.state_dict()
207
+ pretrained_dict = {
208
+ k: v
209
+ for k, v in pretrained_dict.items() if k in model_dict
210
+ }
211
+ model_dict.update(pretrained_dict)
212
+ net.load_state_dict(model_dict)
213
+ net = net.to(dtype=torch_dtype, device=torch_device)
214
+ net.eval()
215
+ return net
216
+
217
+
218
+ def do_infer_tensor2tensor(img, net):
219
+
220
+ img_transform = transforms.Compose(
221
+ [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
222
+
223
+ h_, w_ = img.shape[1], img.shape[2]
224
+
225
+ with torch.no_grad():
226
+
227
+ img = rearrange(img, 'B H W C -> B C H W')
228
+
229
+ img_resize = torch.nn.functional.interpolate(input=img,
230
+ size=(1024, 1024),
231
+ mode='bicubic',
232
+ antialias=True)
233
+
234
+ img_var = img_transform(img_resize)
235
+ img_var = Variable(img_var)
236
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
237
+
238
+ mask = []
239
+
240
+ mask.append(net(img_var))
241
+
242
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
243
+ prediction = prediction.sigmoid()
244
+
245
+ prediction = torch.nn.functional.interpolate(input=prediction,
246
+ size=(h_, w_),
247
+ mode='bicubic',
248
+ antialias=True)
249
+
250
+ prediction = prediction.squeeze(0)
251
+ prediction = prediction.clamp(0, 1)
252
+ prediction = prediction.detach()
253
+ prediction = prediction.to(dtype=torch.float32, device='cpu')
254
+
255
+ return prediction
256
+
257
+
258
+ class Mlp(nn.Module):
259
+ """ Multilayer perceptron."""
260
+
261
+ def __init__(self,
262
+ in_features,
263
+ hidden_features=None,
264
+ out_features=None,
265
+ act_layer=nn.GELU,
266
+ drop=0.):
267
+ super().__init__()
268
+ out_features = out_features or in_features
269
+ hidden_features = hidden_features or in_features
270
+ self.fc1 = nn.Linear(in_features, hidden_features)
271
+ self.act = act_layer()
272
+ self.fc2 = nn.Linear(hidden_features, out_features)
273
+ self.drop = nn.Dropout(drop)
274
+
275
+ def forward(self, x):
276
+ x = self.fc1(x)
277
+ x = self.act(x)
278
+ x = self.drop(x)
279
+ x = self.fc2(x)
280
+ x = self.drop(x)
281
+ return x
282
+
283
+
284
+ class WindowAttention(nn.Module):
285
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
286
+ It supports both of shifted and non-shifted window.
287
+
288
+ Args:
289
+ dim (int): Number of input channels.
290
+ window_size (tuple[int]): The height and width of the window.
291
+ num_heads (int): Number of attention heads.
292
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
293
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
294
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
295
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
296
+ """
297
+
298
+ def __init__(self,
299
+ dim,
300
+ window_size,
301
+ num_heads,
302
+ qkv_bias=True,
303
+ qk_scale=None,
304
+ attn_drop=0.,
305
+ proj_drop=0.):
306
+
307
+ super().__init__()
308
+ self.dim = dim
309
+ self.window_size = window_size # Wh, Ww
310
+ self.num_heads = num_heads
311
+ head_dim = dim // num_heads
312
+ self.scale = qk_scale or head_dim**-0.5
313
+
314
+ # define a parameter table of relative position bias
315
+ self.relative_position_bias_table = nn.Parameter(
316
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
317
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
318
+
319
+ # get pair-wise relative position index for each token inside the window
320
+ coords_h = torch.arange(self.window_size[0])
321
+ coords_w = torch.arange(self.window_size[1])
322
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
323
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
324
+ relative_coords = coords_flatten[:, :,
325
+ None] - coords_flatten[:,
326
+ None, :] # 2, Wh*Ww, Wh*Ww
327
+ relative_coords = relative_coords.permute(
328
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
329
+ relative_coords[:, :,
330
+ 0] += self.window_size[0] - 1 # shift to start from 0
331
+ relative_coords[:, :, 1] += self.window_size[1] - 1
332
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
333
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
334
+ self.register_buffer("relative_position_index",
335
+ relative_position_index)
336
+
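+ # Added commentary: relative_position_index maps every (query, key) pair
+ # inside a window to a row of relative_position_bias_table, so token pairs
+ # with the same spatial offset share one learned bias value per head.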
337
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
338
+ self.attn_drop = nn.Dropout(attn_drop)
339
+ self.proj = nn.Linear(dim, dim)
340
+ self.proj_drop = nn.Dropout(proj_drop)
341
+
342
+ trunc_normal_(self.relative_position_bias_table, std=.02)
343
+ self.softmax = nn.Softmax(dim=-1)
344
+
345
+ def forward(self, x, mask=None):
346
+ """ Forward function.
347
+
348
+ Args:
349
+ x: input features with shape of (num_windows*B, N, C)
350
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
351
+ """
352
+ x = x.to(dtype=torch_dtype, device=torch_device)
353
+ B_, N, C = x.shape
354
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
355
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
356
+ q, k, v = qkv[0], qkv[1], qkv[
357
+ 2] # make torchscript happy (cannot use tensor as tuple)
358
+
359
+ q = q * self.scale
360
+ attn = (q @ k.transpose(-2, -1))
361
+
362
+ relative_position_bias = self.relative_position_bias_table[
363
+ self.relative_position_index.view(-1)].view(
364
+ self.window_size[0] * self.window_size[1],
365
+ self.window_size[0] * self.window_size[1],
366
+ -1) # Wh*Ww,Wh*Ww,nH
367
+ relative_position_bias = relative_position_bias.permute(
368
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
369
+ attn = attn + relative_position_bias.unsqueeze(0)
370
+
371
+ if mask is not None:
372
+ nW = mask.shape[0]
373
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
374
+ N) + mask.unsqueeze(1).unsqueeze(0)
375
+ attn = attn.view(-1, self.num_heads, N, N)
376
+ attn = self.softmax(attn)
377
+ else:
378
+ attn = self.softmax(attn)
379
+
380
+ attn = self.attn_drop(attn)
381
+ attn = attn.to(dtype=torch_dtype, device=torch_device)
382
+ v = v.to(dtype=torch_dtype, device=torch_device)
383
+
384
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
385
+ x = self.proj(x)
386
+ x = self.proj_drop(x)
387
+ return x
388
+
389
+
390
+ class SwinTransformerBlock(nn.Module):
391
+ """ Swin Transformer Block.
392
+
393
+ Args:
394
+ dim (int): Number of input channels.
395
+ num_heads (int): Number of attention heads.
396
+ window_size (int): Window size.
397
+ shift_size (int): Shift size for SW-MSA.
398
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
399
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
400
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
401
+ drop (float, optional): Dropout rate. Default: 0.0
402
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
403
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
404
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
405
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
406
+ """
407
+
408
+ def __init__(self,
409
+ dim,
410
+ num_heads,
411
+ window_size=7,
412
+ shift_size=0,
413
+ mlp_ratio=4.,
414
+ qkv_bias=True,
415
+ qk_scale=None,
416
+ drop=0.,
417
+ attn_drop=0.,
418
+ drop_path=0.,
419
+ act_layer=nn.GELU,
420
+ norm_layer=nn.LayerNorm):
421
+ super().__init__()
422
+ self.dim = dim
423
+ self.num_heads = num_heads
424
+ self.window_size = window_size
425
+ self.shift_size = shift_size
426
+ self.mlp_ratio = mlp_ratio
427
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
428
+
429
+ self.norm1 = norm_layer(dim)
430
+ self.attn = WindowAttention(dim,
431
+ window_size=to_2tuple(self.window_size),
432
+ num_heads=num_heads,
433
+ qkv_bias=qkv_bias,
434
+ qk_scale=qk_scale,
435
+ attn_drop=attn_drop,
436
+ proj_drop=drop)
437
+
438
+ self.drop_path = DropPath(
439
+ drop_path) if drop_path > 0. else nn.Identity()
440
+ self.norm2 = norm_layer(dim)
441
+ mlp_hidden_dim = int(dim * mlp_ratio)
442
+ self.mlp = Mlp(in_features=dim,
443
+ hidden_features=mlp_hidden_dim,
444
+ act_layer=act_layer,
445
+ drop=drop)
446
+
447
+ self.H = None
448
+ self.W = None
449
+
450
+ def forward(self, x, mask_matrix):
451
+ """ Forward function.
452
+
453
+ Args:
454
+ x: Input feature, tensor size (B, H*W, C).
455
+ H, W: Spatial resolution of the input feature.
456
+ mask_matrix: Attention mask for cyclic shift.
457
+ """
458
+ B, L, C = x.shape
459
+ H, W = self.H, self.W
460
+ assert L == H * W, "input feature has wrong size"
461
+
462
+ shortcut = x
463
+ x = self.norm1(x)
464
+ x = x.view(B, H, W, C)
465
+
466
+ # pad feature maps to multiples of window size
467
+ pad_l = pad_t = 0
468
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
469
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
470
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
471
+ _, Hp, Wp, _ = x.shape
472
+
473
+ # cyclic shift
474
+ if self.shift_size > 0:
475
+ shifted_x = torch.roll(x,
476
+ shifts=(-self.shift_size, -self.shift_size),
477
+ dims=(1, 2))
478
+ attn_mask = mask_matrix
479
+ else:
480
+ shifted_x = x
481
+ attn_mask = None
482
+
483
+ # partition windows
484
+ x_windows = window_partition(
485
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
486
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
487
+ C) # nW*B, window_size*window_size, C
488
+
489
+ # W-MSA/SW-MSA
490
+ attn_windows = self.attn(
491
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
492
+
493
+ # merge windows
494
+ attn_windows = attn_windows.view(-1, self.window_size,
495
+ self.window_size, C)
496
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp,
497
+ Wp) # B H' W' C
498
+
499
+ # reverse cyclic shift
500
+ if self.shift_size > 0:
501
+ x = torch.roll(shifted_x,
502
+ shifts=(self.shift_size, self.shift_size),
503
+ dims=(1, 2))
504
+ else:
505
+ x = shifted_x
506
+
507
+ if pad_r > 0 or pad_b > 0:
508
+ x = x[:, :H, :W, :].contiguous()
509
+
510
+ x = x.view(B, H * W, C)
511
+
512
+ # FFN
513
+ x = shortcut + self.drop_path(x)
514
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
515
+
516
+ return x
517
+
518
+
519
+ class PatchMerging(nn.Module):
520
+ """ Patch Merging Layer
521
+
522
+ Args:
523
+ dim (int): Number of input channels.
524
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
525
+ """
526
+
527
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
528
+ super().__init__()
529
+ self.dim = dim
530
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
531
+ self.norm = norm_layer(4 * dim)
532
+
533
+ def forward(self, x, H, W):
534
+ """ Forward function.
535
+
536
+ Args:
537
+ x: Input feature, tensor size (B, H*W, C).
538
+ H, W: Spatial resolution of the input feature.
539
+ """
540
+ B, L, C = x.shape
541
+ assert L == H * W, "input feature has wrong size"
542
+
543
+ x = x.view(B, H, W, C)
544
+
545
+ # padding
546
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
547
+ if pad_input:
548
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
549
+
550
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
551
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
552
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
553
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
554
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
555
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
556
+
557
+ x = self.norm(x)
558
+ x = self.reduction(x)
559
+
560
+ return x
561
+
562
+
563
+ class BasicLayer(nn.Module):
564
+ """ A basic Swin Transformer layer for one stage.
565
+
566
+ Args:
567
+ dim (int): Number of feature channels
568
+ depth (int): Depths of this stage.
569
+ num_heads (int): Number of attention head.
570
+ window_size (int): Local window size. Default: 7.
571
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
572
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
573
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
574
+ drop (float, optional): Dropout rate. Default: 0.0
575
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
576
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
577
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
578
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
579
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
580
+ """
581
+
582
+ def __init__(self,
583
+ dim,
584
+ depth,
585
+ num_heads,
586
+ window_size=7,
587
+ mlp_ratio=4.,
588
+ qkv_bias=True,
589
+ qk_scale=None,
590
+ drop=0.,
591
+ attn_drop=0.,
592
+ drop_path=0.,
593
+ norm_layer=nn.LayerNorm,
594
+ downsample=None,
595
+ use_checkpoint=False):
596
+ super().__init__()
597
+ self.window_size = window_size
598
+ self.shift_size = window_size // 2
599
+ self.depth = depth
600
+ self.use_checkpoint = use_checkpoint
601
+
602
+ # build blocks
603
+ self.blocks = nn.ModuleList([
604
+ SwinTransformerBlock(dim=dim,
605
+ num_heads=num_heads,
606
+ window_size=window_size,
607
+ shift_size=0 if
608
+ (i % 2 == 0) else window_size // 2,
609
+ mlp_ratio=mlp_ratio,
610
+ qkv_bias=qkv_bias,
611
+ qk_scale=qk_scale,
612
+ drop=drop,
613
+ attn_drop=attn_drop,
614
+ drop_path=drop_path[i] if isinstance(
615
+ drop_path, list) else drop_path,
616
+ norm_layer=norm_layer) for i in range(depth)
617
+ ])
618
+
619
+ # patch merging layer
620
+ if downsample is not None:
621
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
622
+ else:
623
+ self.downsample = None
624
+
625
+ def forward(self, x, H, W):
626
+ """ Forward function.
627
+
628
+ Args:
629
+ x: Input feature, tensor size (B, H*W, C).
630
+ H, W: Spatial resolution of the input feature.
631
+ """
632
+
633
+ # calculate attention mask for SW-MSA
634
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
635
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
636
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
637
+ h_slices = (slice(0, -self.window_size),
638
+ slice(-self.window_size,
639
+ -self.shift_size), slice(-self.shift_size, None))
640
+ w_slices = (slice(0, -self.window_size),
641
+ slice(-self.window_size,
642
+ -self.shift_size), slice(-self.shift_size, None))
643
+ cnt = 0
644
+ for h in h_slices:
645
+ for w in w_slices:
646
+ img_mask[:, h, w, :] = cnt
647
+ cnt += 1
648
+
649
+ mask_windows = window_partition(
650
+ img_mask, self.window_size) # nW, window_size, window_size, 1
651
+ mask_windows = mask_windows.view(-1,
652
+ self.window_size * self.window_size)
653
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
654
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
655
+ float(-100.0)).masked_fill(
656
+ attn_mask == 0, float(0.0))
657
+
658
+ for blk in self.blocks:
659
+ blk.H, blk.W = H, W
660
+ if self.use_checkpoint:
661
+ x = checkpoint.checkpoint(blk, x, attn_mask)
662
+ else:
663
+ x = blk(x, attn_mask)
664
+ if self.downsample is not None:
665
+ x_down = self.downsample(x, H, W)
666
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
667
+ return x, H, W, x_down, Wh, Ww
668
+ else:
669
+ return x, H, W, x, H, W
670
+
671
+
672
+ class PatchEmbed(nn.Module):
673
+ """ Image to Patch Embedding
674
+
675
+ Args:
676
+ patch_size (int): Patch token size. Default: 4.
677
+ in_chans (int): Number of input image channels. Default: 3.
678
+ embed_dim (int): Number of linear projection output channels. Default: 96.
679
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
680
+ """
681
+
682
+ def __init__(self,
683
+ patch_size=4,
684
+ in_chans=3,
685
+ embed_dim=96,
686
+ norm_layer=None):
687
+ super().__init__()
688
+ patch_size = to_2tuple(patch_size)
689
+ self.patch_size = patch_size
690
+
691
+ self.in_chans = in_chans
692
+ self.embed_dim = embed_dim
693
+
694
+ self.proj = nn.Conv2d(in_chans,
695
+ embed_dim,
696
+ kernel_size=patch_size,
697
+ stride=patch_size)
698
+ if norm_layer is not None:
699
+ self.norm = norm_layer(embed_dim)
700
+ else:
701
+ self.norm = None
702
+
703
+ def forward(self, x):
704
+ """Forward function."""
705
+ # padding
706
+ _, _, H, W = x.size()
707
+ if W % self.patch_size[1] != 0:
708
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
709
+ if H % self.patch_size[0] != 0:
710
+ x = F.pad(x,
711
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
712
+
713
+ x = self.proj(x) # B C Wh Ww
714
+ if self.norm is not None:
715
+ Wh, Ww = x.size(2), x.size(3)
716
+ x = x.flatten(2).transpose(1, 2)
717
+ x = self.norm(x)
718
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
719
+
720
+ return x
721
+
722
+
723
+ class SwinTransformer(nn.Module):
724
+ """ Swin Transformer backbone.
725
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
726
+ https://arxiv.org/pdf/2103.14030
727
+
728
+ Args:
729
+ pretrain_img_size (int): Input image size for training the pretrained model,
730
+ used in absolute postion embedding. Default 224.
731
+ patch_size (int | tuple(int)): Patch size. Default: 4.
732
+ in_chans (int): Number of input image channels. Default: 3.
733
+ embed_dim (int): Number of linear projection output channels. Default: 96.
734
+ depths (tuple[int]): Depths of each Swin Transformer stage.
735
+ num_heads (tuple[int]): Number of attention head of each stage.
736
+ window_size (int): Window size. Default: 7.
737
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
738
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
739
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
740
+ drop_rate (float): Dropout rate.
741
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
742
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
743
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
744
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
745
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
746
+ out_indices (Sequence[int]): Output from which stages.
747
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
748
+ -1 means not freezing any parameters.
749
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
750
+ """
751
+
752
+ def __init__(self,
753
+ pretrain_img_size=224,
754
+ patch_size=4,
755
+ in_chans=3,
756
+ embed_dim=96,
757
+ depths=[2, 2, 6, 2],
758
+ num_heads=[3, 6, 12, 24],
759
+ window_size=7,
760
+ mlp_ratio=4.,
761
+ qkv_bias=True,
762
+ qk_scale=None,
763
+ drop_rate=0.,
764
+ attn_drop_rate=0.,
765
+ drop_path_rate=0.2,
766
+ norm_layer=nn.LayerNorm,
767
+ ape=False,
768
+ patch_norm=True,
769
+ out_indices=(0, 1, 2, 3),
770
+ frozen_stages=-1,
771
+ use_checkpoint=False):
772
+ super().__init__()
773
+
774
+ self.pretrain_img_size = pretrain_img_size
775
+ self.num_layers = len(depths)
776
+ self.embed_dim = embed_dim
777
+ self.ape = ape
778
+ self.patch_norm = patch_norm
779
+ self.out_indices = out_indices
780
+ self.frozen_stages = frozen_stages
781
+
782
+ # split image into non-overlapping patches
783
+ self.patch_embed = PatchEmbed(
784
+ patch_size=patch_size,
785
+ in_chans=in_chans,
786
+ embed_dim=embed_dim,
787
+ norm_layer=norm_layer if self.patch_norm else None)
788
+
789
+ # absolute position embedding
790
+ if self.ape:
791
+ pretrain_img_size = to_2tuple(pretrain_img_size)
792
+ patch_size = to_2tuple(patch_size)
793
+ patches_resolution = [
794
+ pretrain_img_size[0] // patch_size[0],
795
+ pretrain_img_size[1] // patch_size[1]
796
+ ]
797
+
798
+ self.absolute_pos_embed = nn.Parameter(
799
+ torch.zeros(1, embed_dim, patches_resolution[0],
800
+ patches_resolution[1]))
801
+ trunc_normal_(self.absolute_pos_embed, std=.02)
802
+
803
+ self.pos_drop = nn.Dropout(p=drop_rate)
804
+
805
+ # stochastic depth
806
+ dpr = [
807
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
808
+ ] # stochastic depth decay rule
809
+
810
+ # build layers
811
+ self.layers = nn.ModuleList()
812
+ for i_layer in range(self.num_layers):
813
+ layer = BasicLayer(
814
+ dim=int(embed_dim * 2**i_layer),
815
+ depth=depths[i_layer],
816
+ num_heads=num_heads[i_layer],
817
+ window_size=window_size,
818
+ mlp_ratio=mlp_ratio,
819
+ qkv_bias=qkv_bias,
820
+ qk_scale=qk_scale,
821
+ drop=drop_rate,
822
+ attn_drop=attn_drop_rate,
823
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
824
+ norm_layer=norm_layer,
825
+ downsample=PatchMerging if
826
+ (i_layer < self.num_layers - 1) else None,
827
+ use_checkpoint=use_checkpoint)
828
+ self.layers.append(layer)
829
+
830
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
831
+ self.num_features = num_features
832
+
833
+ # add a norm layer for each output
834
+ for i_layer in out_indices:
835
+ layer = norm_layer(num_features[i_layer])
836
+ layer_name = f'norm{i_layer}'
837
+ self.add_module(layer_name, layer)
838
+
839
+ self._freeze_stages()
840
+
841
+ def _freeze_stages(self):
842
+ if self.frozen_stages >= 0:
843
+ self.patch_embed.eval()
844
+ for param in self.patch_embed.parameters():
845
+ param.requires_grad = False
846
+
847
+ if self.frozen_stages >= 1 and self.ape:
848
+ self.absolute_pos_embed.requires_grad = False
849
+
850
+ if self.frozen_stages >= 2:
851
+ self.pos_drop.eval()
852
+ for i in range(0, self.frozen_stages - 1):
853
+ m = self.layers[i]
854
+ m.eval()
855
+ for param in m.parameters():
856
+ param.requires_grad = False
857
+
858
+ def init_weights(self, pretrained=None):
859
+ """Initialize the weights in backbone.
860
+
861
+ Args:
862
+ pretrained (str, optional): Path to pre-trained weights.
863
+ Defaults to None.
864
+ """
865
+
866
+ def _init_weights(m):
867
+ if isinstance(m, nn.Linear):
868
+ trunc_normal_(m.weight, std=.02)
869
+ if isinstance(m, nn.Linear) and m.bias is not None:
870
+ nn.init.constant_(m.bias, 0)
871
+ elif isinstance(m, nn.LayerNorm):
872
+ nn.init.constant_(m.bias, 0)
873
+ nn.init.constant_(m.weight, 1.0)
874
+
875
+ if isinstance(pretrained, str):
876
+ self.apply(_init_weights)
877
+ load_checkpoint(self, pretrained, strict=False, logger=None)
878
+ elif pretrained is None:
879
+ self.apply(_init_weights)
880
+ else:
881
+ raise TypeError('pretrained must be a str or None')
882
+
883
+ def forward(self, x):
884
+ x = self.patch_embed(x)
885
+
886
+ Wh, Ww = x.size(2), x.size(3)
887
+ if self.ape:
888
+ # interpolate the position embedding to the corresponding size
889
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
890
+ size=(Wh, Ww),
891
+ mode='bicubic')
892
+ x = (x + absolute_pos_embed) # B Wh*Ww C
893
+
894
+ outs = [x.contiguous()]
895
+ x = x.flatten(2).transpose(1, 2)
896
+ x = self.pos_drop(x)
897
+ for i in range(self.num_layers):
898
+ layer = self.layers[i]
899
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
900
+
901
+ if i in self.out_indices:
902
+ norm_layer = getattr(self, f'norm{i}')
903
+ x_out = norm_layer(x_out)
904
+
905
+ out = x_out.view(-1, H, W,
906
+ self.num_features[i]).permute(0, 3, 1,
907
+ 2).contiguous()
908
+ outs.append(out)
909
+
910
+ return tuple(outs)
911
+
912
+ def train(self, mode=True):
913
+ """Convert the model into training mode while keep layers freezed."""
914
+ super(SwinTransformer, self).train(mode)
915
+ self._freeze_stages()
916
+
917
+
918
+ class PositionEmbeddingSine:
919
+
920
+ def __init__(self,
921
+ num_pos_feats=64,
922
+ temperature=10000,
923
+ normalize=False,
924
+ scale=None):
925
+ super().__init__()
926
+ self.num_pos_feats = num_pos_feats
927
+ self.temperature = temperature
928
+ self.normalize = normalize
929
+ if scale is not None and normalize is False:
930
+ raise ValueError("normalize should be True if scale is passed")
931
+ if scale is None:
932
+ scale = 2 * math.pi
933
+ self.scale = scale
934
+ self.dim_t = torch.arange(0,
935
+ self.num_pos_feats,
936
+ dtype=torch_dtype,
937
+ device=torch_device)
938
+
939
+ def __call__(self, b, h, w):
940
+ mask = torch.zeros([b, h, w], dtype=torch.bool, device=torch_device)
941
+ assert mask is not None
942
+ not_mask = ~mask
943
+ y_embed = not_mask.cumsum(dim=1, dtype=torch_dtype)
944
+ x_embed = not_mask.cumsum(dim=2, dtype=torch_dtype)
945
+ if self.normalize:
946
+ eps = 1e-6
947
+ y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) *
948
+ self.scale).to(device=torch_device, dtype=torch_dtype)
949
+ x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) *
950
+ self.scale).to(device=torch_device, dtype=torch_dtype)
951
+
952
+ dim_t = self.temperature**(2 * (self.dim_t // 2) / self.num_pos_feats)
953
+
954
+ pos_x = x_embed[:, :, :, None] / dim_t
955
+ pos_y = y_embed[:, :, :, None] / dim_t
956
+ pos_x = torch.stack(
957
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
958
+ dim=4).flatten(3)
959
+ pos_y = torch.stack(
960
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
961
+ dim=4).flatten(3)
962
+ return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
963
+
964
+
965
+ class MCLM(nn.Module):
966
+
967
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
968
+ super(MCLM, self).__init__()
969
+ self.attention = nn.ModuleList([
970
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
971
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
972
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
973
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
974
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
975
+ ])
976
+
977
+ self.linear1 = nn.Linear(d_model, d_model * 2)
978
+ self.linear2 = nn.Linear(d_model * 2, d_model)
979
+ self.linear3 = nn.Linear(d_model, d_model * 2)
980
+ self.linear4 = nn.Linear(d_model * 2, d_model)
981
+ self.norm1 = nn.LayerNorm(d_model)
982
+ self.norm2 = nn.LayerNorm(d_model)
983
+ self.dropout = nn.Dropout(0.1)
984
+ self.dropout1 = nn.Dropout(0.1)
985
+ self.dropout2 = nn.Dropout(0.1)
986
+ self.activation = get_activation_fn('relu')
987
+ self.pool_ratios = pool_ratios
988
+ self.p_poses = []
989
+ self.g_pos = None
990
+ self.positional_encoding = PositionEmbeddingSine(
991
+ num_pos_feats=d_model // 2, normalize=True)
992
+
993
+ def forward(self, l, g):
994
+ """
995
+ l: 4,c,h,w
996
+ g: 1,c,h,w
997
+ """
998
+ b, c, h, w = l.size()
999
+ # 4,c,h,w -> 1,c,2h,2w
1000
+ concated_locs = rearrange(l,
1001
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1002
+ hg=2,
1003
+ wg=2)
1004
+
1005
+ pools = []
1006
+ for pool_ratio in self.pool_ratios:
1007
+ # b,c,h,w
1008
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1009
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1010
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1011
+ if self.g_pos is None:
1012
+ pos_emb = self.positional_encoding(pool.shape[0],
1013
+ pool.shape[2],
1014
+ pool.shape[3])
1015
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1016
+ self.p_poses.append(pos_emb)
1017
+ pools = torch.cat(pools, 0)
1018
+ if self.g_pos is None:
1019
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1020
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2],
1021
+ g.shape[3])
1022
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1023
+
1024
+ # attention between glb (q) & multi-scale concatenated locs (k,v)
1025
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1026
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1027
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1028
+ g_hw_b_c = self.norm1(g_hw_b_c)
1029
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1030
+ self.linear2(
1031
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1032
+ g_hw_b_c = self.norm2(g_hw_b_c)
1033
+
1034
+ # attention between original locs (q) & refreshed glb (k,v)
1035
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1036
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1037
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1038
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1039
+ ng=2,
1040
+ nw=2)
1041
+ outputs_re = []
1042
+ for i, (_l, _g) in enumerate(
1043
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1044
+ outputs_re.append(self.attention[i + 1](_l, _g,
1045
+ _g)[0]) # (h w) 1 c
1046
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1047
+
1048
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1049
+ l_hw_b_c = self.norm1(l_hw_b_c)
1050
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1051
+ self.linear4(
1052
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1053
+ l_hw_b_c = self.norm2(l_hw_b_c)
1054
+
1055
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1056
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
1057
+
1058
+
1059
+ class inf_MCLM(nn.Module):
1060
+
1061
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
1062
+ super(inf_MCLM, self).__init__()
1063
+ self.attention = nn.ModuleList([
1064
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1065
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1066
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1067
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1068
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1069
+ ])
1070
+
1071
+ self.linear1 = nn.Linear(d_model, d_model * 2)
1072
+ self.linear2 = nn.Linear(d_model * 2, d_model)
1073
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1074
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1075
+ self.norm1 = nn.LayerNorm(d_model)
1076
+ self.norm2 = nn.LayerNorm(d_model)
1077
+ self.dropout = nn.Dropout(0.1)
1078
+ self.dropout1 = nn.Dropout(0.1)
1079
+ self.dropout2 = nn.Dropout(0.1)
1080
+ self.activation = get_activation_fn('relu')
1081
+ self.pool_ratios = pool_ratios
1082
+ self.p_poses = []
1083
+ self.g_pos = None
1084
+ self.positional_encoding = PositionEmbeddingSine(
1085
+ num_pos_feats=d_model // 2, normalize=True)
1086
+
1087
+ def forward(self, l, g):
1088
+ """
1089
+ l: 4,c,h,w
1090
+ g: 1,c,h,w
1091
+ """
1092
+ b, c, h, w = l.size()
1093
+ # 4,c,h,w -> 1,c,2h,2w
1094
+ concated_locs = rearrange(l,
1095
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1096
+ hg=2,
1097
+ wg=2)
1098
+ self.p_poses = []
1099
+ pools = []
1100
+ for pool_ratio in self.pool_ratios:
1101
+ # b,c,h,w
1102
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1103
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1104
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1105
+ # if self.g_pos is None:
1106
+ pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2],
1107
+ pool.shape[3])
1108
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1109
+ self.p_poses.append(pos_emb)
1110
+ pools = torch.cat(pools, 0)
1111
+ # if self.g_pos is None:
1112
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1113
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
1114
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1115
+
1116
+ # attention between glb (q) & multi-scale concatenated locs (k,v)
1117
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1118
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1119
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1120
+ g_hw_b_c = self.norm1(g_hw_b_c)
1121
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1122
+ self.linear2(
1123
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1124
+ g_hw_b_c = self.norm2(g_hw_b_c)
1125
+
1126
+ # attention between original locs (q) & refreshed glb (k,v)
1127
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1128
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1129
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1130
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1131
+ ng=2,
1132
+ nw=2)
1133
+ outputs_re = []
1134
+ for i, (_l, _g) in enumerate(
1135
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1136
+ outputs_re.append(self.attention[i + 1](_l, _g,
1137
+ _g)[0]) # (h w) 1 c
1138
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1139
+
1140
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1141
+ l_hw_b_c = self.norm1(l_hw_b_c)
1142
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1143
+ self.linear4(
1144
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1145
+ l_hw_b_c = self.norm2(l_hw_b_c)
1146
+
1147
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1148
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
1149
+
1150
+
1151
+ class MCRM(nn.Module):
1152
+
1153
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1154
+ super(MCRM, self).__init__()
1155
+ self.attention = nn.ModuleList([
1156
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1157
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1158
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1159
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1160
+ ])
1161
+
1162
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1163
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1164
+ self.norm1 = nn.LayerNorm(d_model)
1165
+ self.norm2 = nn.LayerNorm(d_model)
1166
+ self.dropout = nn.Dropout(0.1)
1167
+ self.dropout1 = nn.Dropout(0.1)
1168
+ self.dropout2 = nn.Dropout(0.1)
1169
+ self.sigmoid = nn.Sigmoid()
1170
+ self.activation = get_activation_fn('relu')
1171
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1172
+ self.pool_ratios = pool_ratios
1173
+ self.positional_encoding = PositionEmbeddingSine(
1174
+ num_pos_feats=d_model // 2, normalize=True)
1175
+
1176
+ def forward(self, x):
1177
+ b, c, h, w = x.size()
1178
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1179
+ # b(4),c,h,w
1180
+ patched_glb = rearrange(glb,
1181
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1182
+ hg=2,
1183
+ wg=2)
1184
+
1185
+ # generate token attention map
1186
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1187
+ token_attention_map = F.interpolate(token_attention_map,
1188
+ size=patches2image(loc).shape[-2:],
1189
+ mode='nearest')
1190
+ loc = loc * rearrange(token_attention_map,
1191
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1192
+ hg=2,
1193
+ wg=2)
1194
+ pools = []
1195
+ for pool_ratio in self.pool_ratios:
1196
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1197
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1198
+ pools.append(rearrange(pool,
1199
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1200
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1201
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1202
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1203
+ outputs = []
1204
+ for i, q in enumerate(
1205
+ loc_.unbind(dim=0)): # traverse all local patches
1206
+ # np*hw,1,c
1207
+ v = pools[i]
1208
+ k = v
1209
+ outputs.append(self.attention[i](q, k, v)[0])
1210
+ outputs = torch.cat(outputs, 1)
1211
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1212
+ src = self.norm1(src)
1213
+ src = src + self.dropout2(
1214
+ self.linear4(
1215
+ self.dropout(self.activation(self.linear3(src)).clone())))
1216
+ src = self.norm2(src)
1217
+
1218
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed loc
1219
+ glb = glb + F.interpolate(patches2image(src),
1220
+ size=glb.shape[-2:],
1221
+ mode='nearest') # refreshed glb
1222
+ return torch.cat((src, glb), 0), token_attention_map
1223
+
1224
+
1225
+ class inf_MCRM(nn.Module):
1226
+
1227
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1228
+ super(inf_MCRM, self).__init__()
1229
+ self.attention = nn.ModuleList([
1230
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1231
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1232
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1233
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1234
+ ])
1235
+
1236
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1237
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1238
+ self.norm1 = nn.LayerNorm(d_model)
1239
+ self.norm2 = nn.LayerNorm(d_model)
1240
+ self.dropout = nn.Dropout(0.1)
1241
+ self.dropout1 = nn.Dropout(0.1)
1242
+ self.dropout2 = nn.Dropout(0.1)
1243
+ self.sigmoid = nn.Sigmoid()
1244
+ self.activation = get_activation_fn('relu')
1245
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1246
+ self.pool_ratios = pool_ratios
1247
+ self.positional_encoding = PositionEmbeddingSine(
1248
+ num_pos_feats=d_model // 2, normalize=True)
1249
+
1250
+ def forward(self, x):
1251
+ b, c, h, w = x.size()
1252
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1253
+ # b(4),c,h,w
1254
+ patched_glb = rearrange(glb,
1255
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1256
+ hg=2,
1257
+ wg=2)
1258
+
1259
+ # generate token attention map
1260
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1261
+ token_attention_map = F.interpolate(token_attention_map,
1262
+ size=patches2image(loc).shape[-2:],
1263
+ mode='nearest')
1264
+ loc = loc * rearrange(token_attention_map,
1265
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1266
+ hg=2,
1267
+ wg=2)
1268
+ pools = []
1269
+ for pool_ratio in self.pool_ratios:
1270
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1271
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1272
+ pools.append(rearrange(pool,
1273
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1274
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1275
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1276
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1277
+ outputs = []
1278
+ for i, q in enumerate(
1279
+ loc_.unbind(dim=0)): # traverse all local patches
1280
+ # np*hw,1,c
1281
+ v = pools[i]
1282
+ k = v
1283
+ outputs.append(self.attention[i](q, k, v)[0])
1284
+ outputs = torch.cat(outputs, 1)
1285
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1286
+ src = self.norm1(src)
1287
+ src = src + self.dropout2(
1288
+ self.linear4(
1289
+ self.dropout(self.activation(self.linear3(src)).clone())))
1290
+ src = self.norm2(src)
1291
+
1292
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed loc
1293
+ glb = glb + F.interpolate(patches2image(src),
1294
+ size=glb.shape[-2:],
1295
+ mode='nearest') # refreshed glb
1296
+ return torch.cat((src, glb), 0)
1297
+
1298
+
1299
+ # model for single-scale training
1300
+ class MVANet(nn.Module):
1301
+
1302
+ def __init__(self):
1303
+ super().__init__()
1304
+ self.backbone = SwinB(pretrained=True)
1305
+ emb_dim = 128
1306
+ self.sideout5 = nn.Sequential(
1307
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1308
+ self.sideout4 = nn.Sequential(
1309
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1310
+ self.sideout3 = nn.Sequential(
1311
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1312
+ self.sideout2 = nn.Sequential(
1313
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1314
+ self.sideout1 = nn.Sequential(
1315
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1316
+
1317
+ self.output5 = make_cbr(1024, emb_dim)
1318
+ self.output4 = make_cbr(512, emb_dim)
1319
+ self.output3 = make_cbr(256, emb_dim)
1320
+ self.output2 = make_cbr(128, emb_dim)
1321
+ self.output1 = make_cbr(128, emb_dim)
1322
+
1323
+ self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
1324
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1325
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1326
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1327
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1328
+ self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
1329
+ self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
1330
+ self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
1331
+ self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
1332
+
1333
+ self.insmask_head = nn.Sequential(
1334
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1335
+ nn.BatchNorm2d(384), nn.PReLU(),
1336
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1337
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1338
+
1339
+ self.shallow = nn.Sequential(
1340
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1341
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1342
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1343
+ self.output = nn.Sequential(
1344
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1345
+
1346
+ for m in self.modules():
1347
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1348
+ m.inplace = True
1349
+
1350
+ def forward(self, x):
1351
+ x = x.to(dtype=torch_dtype, device=torch_device)
1352
+ shallow = self.shallow(x)
1353
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
1354
+ loc = image2patches(x)
1355
+ input = torch.cat((loc, glb), dim=0)
1356
+ feature = self.backbone(input)
1357
+ e5 = self.output5(feature[4]) # (5,128,16,16)
1358
+ e4 = self.output4(feature[3]) # (5,128,32,32)
1359
+ e3 = self.output3(feature[2]) # (5,128,64,64)
1360
+ e2 = self.output2(feature[1]) # (5,128,128,128)
1361
+ e1 = self.output1(feature[0]) # (5,128,128,128)
1362
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
1363
+ e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
1364
+
1365
+ e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
1366
+ e4 = self.conv4(e4)
1367
+ e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
1368
+ e3 = self.conv3(e3)
1369
+ e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
1370
+ e2 = self.conv2(e2)
1371
+ e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
1372
+ e1 = self.conv1(e1)
1373
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
1374
+ output1_cat = patches2image(loc_e1) # (1,128,256,256)
1375
+ # add glb feat in
1376
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
1377
+ # merge
1378
+ final_output = self.insmask_head(output1_cat) # (1,128,256,256)
1379
+ # shallow feature merge
1380
+ final_output = final_output + resize_as(shallow, final_output)
1381
+ final_output = self.upsample1(rescale_to(final_output))
1382
+ final_output = rescale_to(final_output +
1383
+ resize_as(shallow, final_output))
1384
+ final_output = self.upsample2(final_output)
1385
+ final_output = self.output(final_output)
1386
+ ####
1387
+ sideout5 = self.sideout5(e5).to(dtype=torch_dtype, device=torch_device)
1388
+ sideout4 = self.sideout4(e4)
1389
+ sideout3 = self.sideout3(e3)
1390
+ sideout2 = self.sideout2(e2)
1391
+ sideout1 = self.sideout1(e1)
1392
+ #######glb_sideouts ######
1393
+ glb5 = self.sideout5(glb_e5)
1394
+ glb4 = sideout4[-1, :, :, :].unsqueeze(0)
1395
+ glb3 = sideout3[-1, :, :, :].unsqueeze(0)
1396
+ glb2 = sideout2[-1, :, :, :].unsqueeze(0)
1397
+ glb1 = sideout1[-1, :, :, :].unsqueeze(0)
1398
+ ####### concat 4 to 1 #######
1399
+ sideout1 = patches2image(sideout1[:-1]).to(dtype=torch_dtype,
1400
+ device=torch_device)
1401
+ sideout2 = patches2image(sideout2[:-1]).to(
1402
+ dtype=torch_dtype,
1403
+ device=torch_device) ####(5,c,h,w) -> (1 c 2h,2w)
1404
+ sideout3 = patches2image(sideout3[:-1]).to(dtype=torch_dtype,
1405
+ device=torch_device)
1406
+ sideout4 = patches2image(sideout4[:-1]).to(dtype=torch_dtype,
1407
+ device=torch_device)
1408
+ sideout5 = patches2image(sideout5[:-1]).to(dtype=torch_dtype,
1409
+ device=torch_device)
1410
+ if self.training:
1411
+ return sideout5, sideout4, sideout3, sideout2, sideout1, final_output, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1
1412
+ else:
1413
+ return final_output
1414
+
1415
+
1416
+ # model for multi-scale testing
1417
+ class inf_MVANet(nn.Module):
1418
+
1419
+ def __init__(self):
1420
+ super().__init__()
1421
+ # self.backbone = SwinB(pretrained=True)
1422
+ self.backbone = SwinB(pretrained=False)
1423
+
1424
+ emb_dim = 128
1425
+ self.output5 = make_cbr(1024, emb_dim)
1426
+ self.output4 = make_cbr(512, emb_dim)
1427
+ self.output3 = make_cbr(256, emb_dim)
1428
+ self.output2 = make_cbr(128, emb_dim)
1429
+ self.output1 = make_cbr(128, emb_dim)
1430
+
1431
+ self.multifieldcrossatt = inf_MCLM(emb_dim, 1, [1, 4, 8])
1432
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1433
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1434
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1435
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1436
+ self.dec_blk1 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1437
+ self.dec_blk2 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1438
+ self.dec_blk3 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1439
+ self.dec_blk4 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1440
+
1441
+ self.insmask_head = nn.Sequential(
1442
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1443
+ nn.BatchNorm2d(384), nn.PReLU(),
1444
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1445
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1446
+
1447
+ self.shallow = nn.Sequential(
1448
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1449
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1450
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1451
+ self.output = nn.Sequential(
1452
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1453
+
1454
+ for m in self.modules():
1455
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1456
+ m.inplace = True
1457
+
1458
+ def forward(self, x):
1459
+ shallow = self.shallow(x)
1460
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
1461
+ loc = image2patches(x)
1462
+ input = torch.cat((loc, glb), dim=0)
1463
+ feature = self.backbone(input)
1464
+ e5 = self.output5(feature[4])
1465
+ e4 = self.output4(feature[3])
1466
+ e3 = self.output3(feature[2])
1467
+ e2 = self.output2(feature[1])
1468
+ e1 = self.output1(feature[0])
1469
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
1470
+ e5_cat = self.multifieldcrossatt(loc_e5, glb_e5)
1471
+
1472
+ e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
1473
+ e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
1474
+ e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
1475
+ e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))
1476
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
1477
+ # after decoder, concat loc features to a whole one, and merge
1478
+ output1_cat = patches2image(loc_e1)
1479
+ # add glb feat in
1480
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
1481
+ # merge
1482
+ final_output = self.insmask_head(output1_cat)
1483
+ # shallow feature merge
1484
+ final_output = final_output + resize_as(shallow, final_output)
1485
+ final_output = self.upsample1(rescale_to(final_output))
1486
+ final_output = rescale_to(final_output +
1487
+ resize_as(shallow, final_output))
1488
+ final_output = self.upsample2(final_output)
1489
+ final_output = self.output(final_output)
1490
+ return final_output
1491
+
1492
+
1493
+ class load_MVANet_Model:
1494
+
1495
+ def __init__(self):
1496
+ pass
1497
+
1498
+ @classmethod
1499
+ def INPUT_TYPES(s):
1500
+ return {
1501
+ "required": {},
1502
+ }
1503
+
1504
+ RETURN_TYPES = ("MVANet_Model", )
1505
+ FUNCTION = "test"
1506
+ CATEGORY = "MVANet"
1507
+
1508
+ def test(self):
1509
+ return (load_model(get_model_path()), )
1510
+
1511
+
1512
+ class run_MVANet_inference:
1513
+
1514
+ def __init__(self):
1515
+ pass
1516
+
1517
+ @classmethod
1518
+ def INPUT_TYPES(s):
1519
+ return {
1520
+ "required": {
1521
+ "image": ("IMAGE", ),
1522
+ "MVANet_Model": ("MVANet_Model", ),
1523
+ },
1524
+ }
1525
+
1526
+ RETURN_TYPES = ("MASK", )
1527
+ FUNCTION = "test"
1528
+ CATEGORY = "MVANet"
1529
+
1530
+ def test(
1531
+ self,
1532
+ image,
1533
+ MVANet_Model,
1534
+ ):
1535
+ ret = do_infer_tensor2tensor(img=image, net=MVANet_Model)
1536
+
1537
+ return (ret, )
1538
+
1539
+
1540
+ NODE_CLASS_MAPPINGS = {
1541
+ "load_MVANet_Model": load_MVANet_Model,
1542
+ "run_MVANet_inference": run_MVANet_inference
1543
+ }
1544
+
1545
+ NODE_DISPLAY_NAME_MAPPINGS = {
1546
+ "load_MVANet_Model": "load MVANet Model",
1547
+ "run_MVANet_inference": "run_MVANet_inference"
1548
+ }
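+
+
+ # A minimal, hypothetical smoke test for the two nodes above, run outside of
+ # ComfyUI (not required by the extension itself). It assumes a checkpoint can
+ # be resolved by get_model_path(), and follows ComfyUI's IMAGE convention of
+ # (batch, height, width, channel) float tensors in [0, 1]; the 1024x1024 size
+ # is only an illustration.
+ if __name__ == '__main__':
+     net = load_MVANet_Model().test()[0]
+     image = torch.rand(1, 1024, 1024, 3)
+     (mask, ) = run_MVANet_inference().test(image=image, MVANet_Model=net)
+     print(mask.shape)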
ComfyUI_MVANet/download.sh ADDED
@@ -0,0 +1,13 @@
1
+ #!/bin/sh
2
+ get_repo(){
3
+ DIR_REPO="${HOME}/GITHUB/$('echo' "${1}" | 'sed' 's/^git@github.com://g ; s@^https://github.com/@@g ; s@.git$@@g' )"
4
+ DIR_BASE="$('dirname' '--' "${DIR_REPO}")"
5
+ mkdir -pv -- "${DIR_BASE}"
6
+ cd "${DIR_BASE}"
7
+ git clone "${1}"
8
+ cd "${DIR_REPO}"
9
+ git pull
10
+ git submodule update --recursive --init
11
+ }
12
+
13
+ get_repo 'https://github.com/qianyu-dlut/MVANet.git'
ComfyUI_MVANet/requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ timm
2
+ einops
3
+ wget
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Peike Li
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
MVANet_Inference/README.org ADDED
@@ -0,0 +1,2179 @@
1
+ * COMMENT Sample
2
+
3
+ ** Shell script to download
4
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
5
+ #+end_src
6
+
7
+ ** MVANet_inference import
8
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
9
+ #+end_src
10
+
11
+ ** MVANet_inference function
12
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
13
+ #+end_src
14
+
15
+ ** MVANet_inference class
16
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
17
+ #+end_src
18
+
19
+ ** MVANet_inference execute
20
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
21
+ #+end_src
22
+
23
+ ** MVANet_inference unify
24
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.unify.sh
25
+ #+end_src
26
+
27
+ ** MVANet_inference run
28
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.run.sh
29
+ #+end_src
30
+
31
+ * Download the code:
32
+
33
+ ** Function to download
34
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
35
+ get_repo(){
36
+ DIR_REPO="${HOME}/GITHUB/$('echo' "${1}" | 'sed' 's/^git@github.com://g ; s@^https://github.com/@@g ; s@.git$@@g' )"
37
+ DIR_BASE="$('dirname' '--' "${DIR_REPO}")"
38
+ mkdir -pv -- "${DIR_BASE}"
39
+ cd "${DIR_BASE}"
40
+ git clone "${1}"
41
+ cd "${DIR_REPO}"
42
+ git pull
43
+ git submodule update --recursive --init
44
+ }
45
+ #+end_src
46
+
47
+ ** Download
48
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./download.sh
49
+ get_repo 'https://github.com/qianyu-dlut/MVANet.git'
50
+ #+end_src
51
+
52
+ * Dependencies
53
+ pip3 install mmdet==2.23.0
54
+ pip3 install mmcv==1.4.8
55
+ pip3 install ttach
56
+
57
+ * Python inference
58
+
59
+ ** Important configs
60
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
61
+ import os
62
+ import sys
63
+
64
+ HOME_DIR = os.environ.get('HOME', '/root')
65
+ MVANET_SOURCE_DIR = HOME_DIR + '/GITHUB/qianyu-dlut/MVANet'
66
+ finetuned_MVANet_model_path = MVANET_SOURCE_DIR + '/model/Model_80.pth'
67
+ pretrained_SwinB_model_path = MVANET_SOURCE_DIR + '/model/swin_base_patch4_window12_384_22kto1k.pth'
68
+ #+end_src
69
+
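+ A quick, optional sanity check (a sketch, not part of the tangled inference
+ script) that the two checkpoint paths configured above actually exist before
+ running anything:
+ #+begin_src python :shebang #!/usr/bin/python3 :results output
+ for ckpt in (finetuned_MVANet_model_path, pretrained_SwinB_model_path):
+     if not os.path.isfile(ckpt):
+         raise FileNotFoundError('missing checkpoint: ' + ckpt)
+ #+end_src
+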
70
+ ** MVANet_inference import
71
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.import.py
72
+ import math
73
+ import numpy as np
74
+ from PIL import Image
75
+ import time
76
+ # import ttach as tta
77
+ import cv2
78
+
79
+ import torch
80
+ import torch.nn as nn
81
+ import torch.nn.functional as F
82
+ import torch.utils.checkpoint as checkpoint
83
+ from torch.autograd import Variable
84
+ from torch import nn
85
+ from torchvision import transforms
86
+
87
+ from einops import rearrange
88
+
89
+ from timm.models import load_checkpoint
90
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
91
+ #+end_src
92
+
93
+ ** Load image using CV
94
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
95
+ def load_image(input_image_path):
96
+ img = cv2.imread(input_image_path, cv2.IMREAD_COLOR)
97
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
98
+ return img
99
+
100
+
101
+ def load_image_torch(input_image_path):
102
+ img = cv2.imread(input_image_path, cv2.IMREAD_COLOR)
103
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
104
+ img = torch.from_numpy(img)
105
+ img = img.to(dtype=torch.float32)
106
+ img /= 255.0
107
+ img = img.unsqueeze(0)
108
+ return img
109
+
110
+
111
+ def save_mask(output_image_path, mask):
112
+ cv2.imwrite(output_image_path, mask)
113
+
114
+
115
+ def save_mask_torch(output_image_path, mask):
116
+ mask = mask.detach().cpu()
117
+ mask *= 255.0
118
+ mask = mask.clamp(0, 255)
119
+ print(mask.shape)
120
+ mask = mask.squeeze(0)
121
+ mask = mask.to(dtype=torch.uint8)
122
+ print(mask.shape)
123
+ mask = mask.numpy()
124
+ print(mask.shape)
125
+ cv2.imwrite(output_image_path, mask)
126
+ #+end_src
127
+
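+ A small usage sketch (not tangled) of the helpers above; =demo/demo.jpg= and
+ =demo_mask.png= are assumed example paths:
+ #+begin_src python :shebang #!/usr/bin/python3 :results output
+ img = load_image_torch('demo/demo.jpg')            # (1, H, W, 3) float32 in [0, 1]
+ mask = torch.ones(1, img.shape[1], img.shape[2])   # placeholder mask in [0, 1]
+ save_mask_torch('demo_mask.png', mask)             # written as an 8-bit grayscale image
+ #+end_src
+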
128
+ ** Device configs
129
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
130
+ torch_device = 'cuda'
131
+ torch_dtype = torch.float16
132
+ #+end_src
133
+ Throughout the code, tensors and model weights are moved onto this configuration with =.to(dtype=torch_dtype, device=torch_device)=.
134
+
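+ A minimal sketch (not tangled) of applying this configuration, assuming the
+ =inf_MVANet= class defined later in this document is in scope:
+ #+begin_src python :shebang #!/usr/bin/python3 :results output
+ net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
+ net.eval()
+ x = torch.zeros(1, 3, 1024, 1024).to(dtype=torch_dtype, device=torch_device)
+ with torch.no_grad():
+     mask = net(x)    # single-channel mask logits, same spatial size as the input
+ #+end_src
+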
135
+ ** MVANet_inference function
136
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
137
+ def check_mkdir(dir_name):
138
+ if not os.path.isdir(dir_name):
139
+ os.makedirs(dir_name)
140
+
141
+
142
+ def SwinT(pretrained=True):
143
+ model = SwinTransformer(embed_dim=96,
144
+ depths=[2, 2, 6, 2],
145
+ num_heads=[3, 6, 12, 24],
146
+ window_size=7)
147
+ if pretrained is True:
148
+ model.load_state_dict(torch.load(
149
+ 'data/backbone_ckpt/swin_tiny_patch4_window7_224.pth',
150
+ map_location='cpu')['model'],
151
+ strict=False)
152
+
153
+ return model
154
+
155
+
156
+ def SwinS(pretrained=True):
157
+ model = SwinTransformer(embed_dim=96,
158
+ depths=[2, 2, 18, 2],
159
+ num_heads=[3, 6, 12, 24],
160
+ window_size=7)
161
+ if pretrained is True:
162
+ model.load_state_dict(torch.load(
163
+ 'data/backbone_ckpt/swin_small_patch4_window7_224.pth',
164
+ map_location='cpu')['model'],
165
+ strict=False)
166
+
167
+ return model
168
+
169
+
170
+ def SwinB(pretrained=True):
171
+ model = SwinTransformer(embed_dim=128,
172
+ depths=[2, 2, 18, 2],
173
+ num_heads=[4, 8, 16, 32],
174
+ window_size=12)
175
+ if pretrained is True:
176
+ import os
177
+ model.load_state_dict(torch.load(pretrained_SwinB_model_path,
178
+ map_location='cpu')['model'],
179
+ strict=False)
180
+ return model
181
+
182
+
183
+ def SwinL(pretrained=True):
184
+ model = SwinTransformer(embed_dim=192,
185
+ depths=[2, 2, 18, 2],
186
+ num_heads=[6, 12, 24, 48],
187
+ window_size=12)
188
+ if pretrained is True:
189
+ model.load_state_dict(torch.load(
190
+ 'data/backbone_ckpt/swin_large_patch4_window12_384_22kto1k.pth',
191
+ map_location='cpu')['model'],
192
+ strict=False)
193
+
194
+ return model
195
+
196
+
197
+ def get_activation_fn(activation):
198
+ """Return an activation function given a string"""
199
+ if activation == "relu":
200
+ return F.relu
201
+ if activation == "gelu":
202
+ return F.gelu
203
+ if activation == "glu":
204
+ return F.glu
205
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
206
+
207
+
208
+ def make_cbr(in_dim, out_dim):
209
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
210
+ nn.BatchNorm2d(out_dim), nn.PReLU())
211
+
212
+
213
+ def make_cbg(in_dim, out_dim):
214
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
215
+ nn.BatchNorm2d(out_dim), nn.GELU())
216
+
217
+
218
+ def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
219
+ return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
220
+
221
+
222
+ def resize_as(x, y, interpolation='bilinear'):
223
+ return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
224
+
225
+
226
+ def image2patches(x):
227
+ """b c (hg h) (wg w) -> (hg wg b) c h w"""
228
+ x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
229
+ return x
230
+
231
+
232
+ def patches2image(x):
233
+ """(hg wg b) c h w -> b c (hg h) (wg w)"""
234
+ x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
235
+ return x
236
+
237
+
238
+ def window_partition(x, window_size):
239
+ """
240
+ Args:
241
+ x: (B, H, W, C)
242
+ window_size (int): window size
243
+
244
+ Returns:
245
+ windows: (num_windows*B, window_size, window_size, C)
246
+ """
247
+ B, H, W, C = x.shape
248
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
249
+ C)
250
+ windows = x.permute(0, 1, 3, 2, 4,
251
+ 5).contiguous().view(-1, window_size, window_size, C)
252
+ return windows
253
+
254
+
255
+ def window_reverse(windows, window_size, H, W):
256
+ """
257
+ Args:
258
+ windows: (num_windows*B, window_size, window_size, C)
259
+ window_size (int): Window size
260
+ H (int): Height of image
261
+ W (int): Width of image
262
+
263
+ Returns:
264
+ x: (B, H, W, C)
265
+ """
266
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
267
+ x = windows.view(B, H // window_size, W // window_size, window_size,
268
+ window_size, -1)
269
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
270
+ return x
271
+ #+end_src
272
+
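+ A small sanity check (a sketch, not tangled) showing that the patch and
+ window helpers above are mutual inverses when the spatial size divides
+ evenly:
+ #+begin_src python :shebang #!/usr/bin/python3 :results output
+ x = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).reshape(2, 3, 8, 8)
+ p = image2patches(x)                      # (8, 3, 4, 4): a 2x2 grid of patches per image
+ assert torch.equal(patches2image(p), x)   # exact round trip
+
+ y = torch.rand(2, 8, 8, 3)                # (B, H, W, C) layout used inside the Swin blocks
+ w = window_partition(y, 4)                # (8, 4, 4, 3): non-overlapping 4x4 windows
+ assert torch.equal(window_reverse(w, 4, 8, 8), y)
+ #+end_src
+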
273
+ ** MVANet_inference class
274
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.class.py
275
+ class Mlp(nn.Module):
276
+ """ Multilayer perceptron."""
277
+
278
+ def __init__(self,
279
+ in_features,
280
+ hidden_features=None,
281
+ out_features=None,
282
+ act_layer=nn.GELU,
283
+ drop=0.):
284
+ super().__init__()
285
+ out_features = out_features or in_features
286
+ hidden_features = hidden_features or in_features
287
+ self.fc1 = nn.Linear(in_features, hidden_features)
288
+ self.act = act_layer()
289
+ self.fc2 = nn.Linear(hidden_features, out_features)
290
+ self.drop = nn.Dropout(drop)
291
+
292
+ def forward(self, x):
293
+ x = self.fc1(x)
294
+ x = self.act(x)
295
+ x = self.drop(x)
296
+ x = self.fc2(x)
297
+ x = self.drop(x)
298
+ return x
299
+
300
+
301
+ class WindowAttention(nn.Module):
302
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
303
+ It supports both of shifted and non-shifted window.
304
+
305
+ Args:
306
+ dim (int): Number of input channels.
307
+ window_size (tuple[int]): The height and width of the window.
308
+ num_heads (int): Number of attention heads.
309
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
310
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
311
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
312
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
313
+ """
314
+
315
+ def __init__(self,
316
+ dim,
317
+ window_size,
318
+ num_heads,
319
+ qkv_bias=True,
320
+ qk_scale=None,
321
+ attn_drop=0.,
322
+ proj_drop=0.):
323
+
324
+ super().__init__()
325
+ self.dim = dim
326
+ self.window_size = window_size # Wh, Ww
327
+ self.num_heads = num_heads
328
+ head_dim = dim // num_heads
329
+ self.scale = qk_scale or head_dim**-0.5
330
+
331
+ # define a parameter table of relative position bias
332
+ self.relative_position_bias_table = nn.Parameter(
333
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
334
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
335
+
336
+ # get pair-wise relative position index for each token inside the window
337
+ coords_h = torch.arange(self.window_size[0])
338
+ coords_w = torch.arange(self.window_size[1])
339
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
340
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
341
+ relative_coords = coords_flatten[:, :,
342
+ None] - coords_flatten[:,
343
+ None, :] # 2, Wh*Ww, Wh*Ww
344
+ relative_coords = relative_coords.permute(
345
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
346
+ relative_coords[:, :,
347
+ 0] += self.window_size[0] - 1 # shift to start from 0
348
+ relative_coords[:, :, 1] += self.window_size[1] - 1
349
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
350
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
351
+ self.register_buffer("relative_position_index",
352
+ relative_position_index)
353
+
354
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
355
+ self.attn_drop = nn.Dropout(attn_drop)
356
+ self.proj = nn.Linear(dim, dim)
357
+ self.proj_drop = nn.Dropout(proj_drop)
358
+
359
+ trunc_normal_(self.relative_position_bias_table, std=.02)
360
+ self.softmax = nn.Softmax(dim=-1)
361
+
362
+ def forward(self, x, mask=None):
363
+ """ Forward function.
364
+
365
+ Args:
366
+ x: input features with shape of (num_windows*B, N, C)
367
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
368
+ """
369
+ x = x.to(dtype=torch_dtype, device=torch_device)
370
+ B_, N, C = x.shape
371
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
372
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
373
+ q, k, v = qkv[0], qkv[1], qkv[
374
+ 2] # make torchscript happy (cannot use tensor as tuple)
375
+
376
+ q = q * self.scale
377
+ attn = (q @ k.transpose(-2, -1))
378
+
379
+ relative_position_bias = self.relative_position_bias_table[
380
+ self.relative_position_index.view(-1)].view(
381
+ self.window_size[0] * self.window_size[1],
382
+ self.window_size[0] * self.window_size[1],
383
+ -1) # Wh*Ww,Wh*Ww,nH
384
+ relative_position_bias = relative_position_bias.permute(
385
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
386
+ attn = attn + relative_position_bias.unsqueeze(0)
387
+
388
+ if mask is not None:
389
+ nW = mask.shape[0]
390
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
391
+ N) + mask.unsqueeze(1).unsqueeze(0)
392
+ attn = attn.view(-1, self.num_heads, N, N)
393
+ attn = self.softmax(attn)
394
+ else:
395
+ attn = self.softmax(attn)
396
+
397
+ attn = self.attn_drop(attn)
398
+ attn = attn.to(dtype=torch_dtype, device=torch_device)
399
+ v = v.to(dtype=torch_dtype, device=torch_device)
400
+
401
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
402
+ x = self.proj(x)
403
+ x = self.proj_drop(x)
404
+ return x
405
+
406
+
407
+ class SwinTransformerBlock(nn.Module):
408
+ """ Swin Transformer Block.
409
+
410
+ Args:
411
+ dim (int): Number of input channels.
412
+ num_heads (int): Number of attention heads.
413
+ window_size (int): Window size.
414
+ shift_size (int): Shift size for SW-MSA.
415
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
416
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
417
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
418
+ drop (float, optional): Dropout rate. Default: 0.0
419
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
420
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
421
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
422
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
423
+ """
424
+
425
+ def __init__(self,
426
+ dim,
427
+ num_heads,
428
+ window_size=7,
429
+ shift_size=0,
430
+ mlp_ratio=4.,
431
+ qkv_bias=True,
432
+ qk_scale=None,
433
+ drop=0.,
434
+ attn_drop=0.,
435
+ drop_path=0.,
436
+ act_layer=nn.GELU,
437
+ norm_layer=nn.LayerNorm):
438
+ super().__init__()
439
+ self.dim = dim
440
+ self.num_heads = num_heads
441
+ self.window_size = window_size
442
+ self.shift_size = shift_size
443
+ self.mlp_ratio = mlp_ratio
444
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
445
+
446
+ self.norm1 = norm_layer(dim)
447
+ self.attn = WindowAttention(dim,
448
+ window_size=to_2tuple(self.window_size),
449
+ num_heads=num_heads,
450
+ qkv_bias=qkv_bias,
451
+ qk_scale=qk_scale,
452
+ attn_drop=attn_drop,
453
+ proj_drop=drop)
454
+
455
+ self.drop_path = DropPath(
456
+ drop_path) if drop_path > 0. else nn.Identity()
457
+ self.norm2 = norm_layer(dim)
458
+ mlp_hidden_dim = int(dim * mlp_ratio)
459
+ self.mlp = Mlp(in_features=dim,
460
+ hidden_features=mlp_hidden_dim,
461
+ act_layer=act_layer,
462
+ drop=drop)
463
+
464
+ self.H = None
465
+ self.W = None
466
+
467
+ def forward(self, x, mask_matrix):
468
+ """ Forward function.
469
+
470
+ Args:
471
+ x: Input feature, tensor size (B, H*W, C).
472
+ H, W: Spatial resolution of the input feature.
473
+ mask_matrix: Attention mask for cyclic shift.
474
+ """
475
+ B, L, C = x.shape
476
+ H, W = self.H, self.W
477
+ assert L == H * W, "input feature has wrong size"
478
+
479
+ shortcut = x
480
+ x = self.norm1(x)
481
+ x = x.view(B, H, W, C)
482
+
483
+ # pad feature maps to multiples of window size
484
+ pad_l = pad_t = 0
485
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
486
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
487
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
488
+ _, Hp, Wp, _ = x.shape
489
+
490
+ # cyclic shift
491
+ if self.shift_size > 0:
492
+ shifted_x = torch.roll(x,
493
+ shifts=(-self.shift_size, -self.shift_size),
494
+ dims=(1, 2))
495
+ attn_mask = mask_matrix
496
+ else:
497
+ shifted_x = x
498
+ attn_mask = None
499
+
500
+ # partition windows
501
+ x_windows = window_partition(
502
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
503
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
504
+ C) # nW*B, window_size*window_size, C
505
+
506
+ # W-MSA/SW-MSA
507
+ attn_windows = self.attn(
508
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
509
+
510
+ # merge windows
511
+ attn_windows = attn_windows.view(-1, self.window_size,
512
+ self.window_size, C)
513
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp,
514
+ Wp) # B H' W' C
515
+
516
+ # reverse cyclic shift
517
+ if self.shift_size > 0:
518
+ x = torch.roll(shifted_x,
519
+ shifts=(self.shift_size, self.shift_size),
520
+ dims=(1, 2))
521
+ else:
522
+ x = shifted_x
523
+
524
+ if pad_r > 0 or pad_b > 0:
525
+ x = x[:, :H, :W, :].contiguous()
526
+
527
+ x = x.view(B, H * W, C)
528
+
529
+ # FFN
530
+ x = shortcut + self.drop_path(x)
531
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
532
+
533
+ return x
534
+
535
+
536
+ class PatchMerging(nn.Module):
537
+ """ Patch Merging Layer
538
+
539
+ Args:
540
+ dim (int): Number of input channels.
541
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
542
+ """
543
+
544
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
545
+ super().__init__()
546
+ self.dim = dim
547
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
548
+ self.norm = norm_layer(4 * dim)
549
+
550
+ def forward(self, x, H, W):
551
+ """ Forward function.
552
+
553
+ Args:
554
+ x: Input feature, tensor size (B, H*W, C).
555
+ H, W: Spatial resolution of the input feature.
556
+ """
557
+ B, L, C = x.shape
558
+ assert L == H * W, "input feature has wrong size"
559
+
560
+ x = x.view(B, H, W, C)
561
+
562
+ # padding
563
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
564
+ if pad_input:
565
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
566
+
567
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
568
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
569
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
570
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
571
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
572
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
573
+
574
+ x = self.norm(x)
575
+ x = self.reduction(x)
576
+
577
+ return x
578
+
579
+
580
+ class BasicLayer(nn.Module):
581
+ """ A basic Swin Transformer layer for one stage.
582
+
583
+ Args:
584
+ dim (int): Number of feature channels
585
+ depth (int): Depth of this stage.
586
+ num_heads (int): Number of attention heads.
587
+ window_size (int): Local window size. Default: 7.
588
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
589
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
590
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
591
+ drop (float, optional): Dropout rate. Default: 0.0
592
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
593
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
594
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
595
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
596
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
597
+ """
598
+
599
+ def __init__(self,
600
+ dim,
601
+ depth,
602
+ num_heads,
603
+ window_size=7,
604
+ mlp_ratio=4.,
605
+ qkv_bias=True,
606
+ qk_scale=None,
607
+ drop=0.,
608
+ attn_drop=0.,
609
+ drop_path=0.,
610
+ norm_layer=nn.LayerNorm,
611
+ downsample=None,
612
+ use_checkpoint=False):
613
+ super().__init__()
614
+ self.window_size = window_size
615
+ self.shift_size = window_size // 2
616
+ self.depth = depth
617
+ self.use_checkpoint = use_checkpoint
618
+
619
+ # build blocks
620
+ self.blocks = nn.ModuleList([
621
+ SwinTransformerBlock(dim=dim,
622
+ num_heads=num_heads,
623
+ window_size=window_size,
624
+ shift_size=0 if
625
+ (i % 2 == 0) else window_size // 2,
626
+ mlp_ratio=mlp_ratio,
627
+ qkv_bias=qkv_bias,
628
+ qk_scale=qk_scale,
629
+ drop=drop,
630
+ attn_drop=attn_drop,
631
+ drop_path=drop_path[i] if isinstance(
632
+ drop_path, list) else drop_path,
633
+ norm_layer=norm_layer) for i in range(depth)
634
+ ])
635
+
636
+ # patch merging layer
637
+ if downsample is not None:
638
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
639
+ else:
640
+ self.downsample = None
641
+
642
+ def forward(self, x, H, W):
643
+ """ Forward function.
644
+
645
+ Args:
646
+ x: Input feature, tensor size (B, H*W, C).
647
+ H, W: Spatial resolution of the input feature.
648
+ """
649
+
650
+ # calculate attention mask for SW-MSA
651
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
652
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
653
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
654
+ h_slices = (slice(0, -self.window_size),
655
+ slice(-self.window_size,
656
+ -self.shift_size), slice(-self.shift_size, None))
657
+ w_slices = (slice(0, -self.window_size),
658
+ slice(-self.window_size,
659
+ -self.shift_size), slice(-self.shift_size, None))
660
+ cnt = 0
661
+ for h in h_slices:
662
+ for w in w_slices:
663
+ img_mask[:, h, w, :] = cnt
664
+ cnt += 1
665
+
666
+ mask_windows = window_partition(
667
+ img_mask, self.window_size) # nW, window_size, window_size, 1
668
+ mask_windows = mask_windows.view(-1,
669
+ self.window_size * self.window_size)
670
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
671
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
672
+ float(-100.0)).masked_fill(
673
+ attn_mask == 0, float(0.0))
674
+
675
+ for blk in self.blocks:
676
+ blk.H, blk.W = H, W
677
+ if self.use_checkpoint:
678
+ x = checkpoint.checkpoint(blk, x, attn_mask)
679
+ else:
680
+ x = blk(x, attn_mask)
681
+ if self.downsample is not None:
682
+ x_down = self.downsample(x, H, W)
683
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
684
+ return x, H, W, x_down, Wh, Ww
685
+ else:
686
+ return x, H, W, x, H, W
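# The mask built in BasicLayer.forward implements Swin's shifted-window (SW-MSA)
# trick: the feature map is padded to a multiple of window_size, regions are
# labelled via `cnt`, and token pairs coming from different regions get -100 added
# before the softmax so their attention weight is effectively zero. Even-indexed
# blocks use shift_size=0 and odd-indexed blocks use window_size // 2.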
687
+
688
+
689
+ class PatchEmbed(nn.Module):
690
+ """ Image to Patch Embedding
691
+
692
+ Args:
693
+ patch_size (int): Patch token size. Default: 4.
694
+ in_chans (int): Number of input image channels. Default: 3.
695
+ embed_dim (int): Number of linear projection output channels. Default: 96.
696
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
697
+ """
698
+
699
+ def __init__(self,
700
+ patch_size=4,
701
+ in_chans=3,
702
+ embed_dim=96,
703
+ norm_layer=None):
704
+ super().__init__()
705
+ patch_size = to_2tuple(patch_size)
706
+ self.patch_size = patch_size
707
+
708
+ self.in_chans = in_chans
709
+ self.embed_dim = embed_dim
710
+
711
+ self.proj = nn.Conv2d(in_chans,
712
+ embed_dim,
713
+ kernel_size=patch_size,
714
+ stride=patch_size)
715
+ if norm_layer is not None:
716
+ self.norm = norm_layer(embed_dim)
717
+ else:
718
+ self.norm = None
719
+
720
+ def forward(self, x):
721
+ """Forward function."""
722
+ # padding
723
+ _, _, H, W = x.size()
724
+ if W % self.patch_size[1] != 0:
725
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
726
+ if H % self.patch_size[0] != 0:
727
+ x = F.pad(x,
728
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
729
+
730
+ x = self.proj(x) # B C Wh Ww
731
+ if self.norm is not None:
732
+ Wh, Ww = x.size(2), x.size(3)
733
+ x = x.flatten(2).transpose(1, 2)
734
+ x = self.norm(x)
735
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
736
+
737
+ return x
738
+
739
+
740
+ class SwinTransformer(nn.Module):
741
+ """ Swin Transformer backbone.
742
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
743
+ https://arxiv.org/pdf/2103.14030
744
+
745
+ Args:
746
+ pretrain_img_size (int): Input image size for training the pretrained model,
747
+ used in absolute position embedding. Default: 224.
748
+ patch_size (int | tuple(int)): Patch size. Default: 4.
749
+ in_chans (int): Number of input image channels. Default: 3.
750
+ embed_dim (int): Number of linear projection output channels. Default: 96.
751
+ depths (tuple[int]): Depths of each Swin Transformer stage.
752
+ num_heads (tuple[int]): Number of attention heads of each stage.
753
+ window_size (int): Window size. Default: 7.
754
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
755
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
756
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
757
+ drop_rate (float): Dropout rate.
758
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
759
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
760
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
761
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
762
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
763
+ out_indices (Sequence[int]): Output from which stages.
764
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
765
+ -1 means not freezing any parameters.
766
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
767
+ """
768
+
769
+ def __init__(self,
770
+ pretrain_img_size=224,
771
+ patch_size=4,
772
+ in_chans=3,
773
+ embed_dim=96,
774
+ depths=[2, 2, 6, 2],
775
+ num_heads=[3, 6, 12, 24],
776
+ window_size=7,
777
+ mlp_ratio=4.,
778
+ qkv_bias=True,
779
+ qk_scale=None,
780
+ drop_rate=0.,
781
+ attn_drop_rate=0.,
782
+ drop_path_rate=0.2,
783
+ norm_layer=nn.LayerNorm,
784
+ ape=False,
785
+ patch_norm=True,
786
+ out_indices=(0, 1, 2, 3),
787
+ frozen_stages=-1,
788
+ use_checkpoint=False):
789
+ super().__init__()
790
+
791
+ self.pretrain_img_size = pretrain_img_size
792
+ self.num_layers = len(depths)
793
+ self.embed_dim = embed_dim
794
+ self.ape = ape
795
+ self.patch_norm = patch_norm
796
+ self.out_indices = out_indices
797
+ self.frozen_stages = frozen_stages
798
+
799
+ # split image into non-overlapping patches
800
+ self.patch_embed = PatchEmbed(
801
+ patch_size=patch_size,
802
+ in_chans=in_chans,
803
+ embed_dim=embed_dim,
804
+ norm_layer=norm_layer if self.patch_norm else None)
805
+
806
+ # absolute position embedding
807
+ if self.ape:
808
+ pretrain_img_size = to_2tuple(pretrain_img_size)
809
+ patch_size = to_2tuple(patch_size)
810
+ patches_resolution = [
811
+ pretrain_img_size[0] // patch_size[0],
812
+ pretrain_img_size[1] // patch_size[1]
813
+ ]
814
+
815
+ self.absolute_pos_embed = nn.Parameter(
816
+ torch.zeros(1, embed_dim, patches_resolution[0],
817
+ patches_resolution[1]))
818
+ trunc_normal_(self.absolute_pos_embed, std=.02)
819
+
820
+ self.pos_drop = nn.Dropout(p=drop_rate)
821
+
822
+ # stochastic depth
823
+ dpr = [
824
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
825
+ ] # stochastic depth decay rule
826
+
827
+ # build layers
828
+ self.layers = nn.ModuleList()
829
+ for i_layer in range(self.num_layers):
830
+ layer = BasicLayer(
831
+ dim=int(embed_dim * 2**i_layer),
832
+ depth=depths[i_layer],
833
+ num_heads=num_heads[i_layer],
834
+ window_size=window_size,
835
+ mlp_ratio=mlp_ratio,
836
+ qkv_bias=qkv_bias,
837
+ qk_scale=qk_scale,
838
+ drop=drop_rate,
839
+ attn_drop=attn_drop_rate,
840
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
841
+ norm_layer=norm_layer,
842
+ downsample=PatchMerging if
843
+ (i_layer < self.num_layers - 1) else None,
844
+ use_checkpoint=use_checkpoint)
845
+ self.layers.append(layer)
846
+
847
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
848
+ self.num_features = num_features
849
+
850
+ # add a norm layer for each output
851
+ for i_layer in out_indices:
852
+ layer = norm_layer(num_features[i_layer])
853
+ layer_name = f'norm{i_layer}'
854
+ self.add_module(layer_name, layer)
855
+
856
+ self._freeze_stages()
857
+
858
+ def _freeze_stages(self):
859
+ if self.frozen_stages >= 0:
860
+ self.patch_embed.eval()
861
+ for param in self.patch_embed.parameters():
862
+ param.requires_grad = False
863
+
864
+ if self.frozen_stages >= 1 and self.ape:
865
+ self.absolute_pos_embed.requires_grad = False
866
+
867
+ if self.frozen_stages >= 2:
868
+ self.pos_drop.eval()
869
+ for i in range(0, self.frozen_stages - 1):
870
+ m = self.layers[i]
871
+ m.eval()
872
+ for param in m.parameters():
873
+ param.requires_grad = False
874
+
875
+ def init_weights(self, pretrained=None):
876
+ """Initialize the weights in backbone.
877
+
878
+ Args:
879
+ pretrained (str, optional): Path to pre-trained weights.
880
+ Defaults to None.
881
+ """
882
+
883
+ def _init_weights(m):
884
+ if isinstance(m, nn.Linear):
885
+ trunc_normal_(m.weight, std=.02)
886
+ if isinstance(m, nn.Linear) and m.bias is not None:
887
+ nn.init.constant_(m.bias, 0)
888
+ elif isinstance(m, nn.LayerNorm):
889
+ nn.init.constant_(m.bias, 0)
890
+ nn.init.constant_(m.weight, 1.0)
891
+
892
+ if isinstance(pretrained, str):
893
+ self.apply(_init_weights)
894
+ load_checkpoint(self, pretrained, strict=False, logger=None)
895
+ elif pretrained is None:
896
+ self.apply(_init_weights)
897
+ else:
898
+ raise TypeError('pretrained must be a str or None')
899
+
900
+ def forward(self, x):
901
+ x = self.patch_embed(x)
902
+
903
+ Wh, Ww = x.size(2), x.size(3)
904
+ if self.ape:
905
+ # interpolate the position embedding to the corresponding size
906
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed,
907
+ size=(Wh, Ww),
908
+ mode='bicubic')
909
+ x = (x + absolute_pos_embed) # B C Wh Ww
910
+
911
+ outs = [x.contiguous()]
912
+ x = x.flatten(2).transpose(1, 2)
913
+ x = self.pos_drop(x)
914
+ for i in range(self.num_layers):
915
+ layer = self.layers[i]
916
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
917
+
918
+ if i in self.out_indices:
919
+ norm_layer = getattr(self, f'norm{i}')
920
+ x_out = norm_layer(x_out)
921
+
922
+ out = x_out.view(-1, H, W,
923
+ self.num_features[i]).permute(0, 3, 1,
924
+ 2).contiguous()
925
+ outs.append(out)
926
+
927
+ return tuple(outs)
928
+
929
+ def train(self, mode=True):
930
+ """Convert the model into training mode while keep layers freezed."""
931
+ super(SwinTransformer, self).train(mode)
932
+ self._freeze_stages()
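# SwinTransformer.forward returns a tuple of len(out_indices) + 1 feature maps:
# outs[0] is the stride-4 patch embedding (with the optional absolute position
# embedding added), followed by one normalised map per stage listed in out_indices,
# each with embed_dim * 2**i channels. For the SwinB configuration used by MVANet
# below (embed_dim presumably 128), that gives 128/128/256/512/1024 channels, which
# is why output1..output5 project exactly those widths down to emb_dim.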
933
+
934
+
935
+ class PositionEmbeddingSine:
936
+
937
+ def __init__(self,
938
+ num_pos_feats=64,
939
+ temperature=10000,
940
+ normalize=False,
941
+ scale=None):
942
+ super().__init__()
943
+ self.num_pos_feats = num_pos_feats
944
+ self.temperature = temperature
945
+ self.normalize = normalize
946
+ if scale is not None and normalize is False:
947
+ raise ValueError("normalize should be True if scale is passed")
948
+ if scale is None:
949
+ scale = 2 * math.pi
950
+ self.scale = scale
951
+ self.dim_t = torch.arange(0,
952
+ self.num_pos_feats,
953
+ dtype=torch_dtype,
954
+ device=torch_device)
955
+
956
+ def __call__(self, b, h, w):
957
+ mask = torch.zeros([b, h, w], dtype=torch.bool, device=torch_device)
958
+ assert mask is not None
959
+ not_mask = ~mask
960
+ y_embed = not_mask.cumsum(dim=1, dtype=torch_dtype)
961
+ x_embed = not_mask.cumsum(dim=2, dtype=torch_dtype)
962
+ if self.normalize:
963
+ eps = 1e-6
964
+ y_embed = ((y_embed - 0.5) / (y_embed[:, -1:, :] + eps) *
965
+ self.scale).to(device=torch_device, dtype=torch_dtype)
966
+ x_embed = ((x_embed - 0.5) / (x_embed[:, :, -1:] + eps) *
967
+ self.scale).to(device=torch_device, dtype=torch_dtype)
968
+
969
+ dim_t = self.temperature**(2 * (self.dim_t // 2) / self.num_pos_feats)
970
+
971
+ pos_x = x_embed[:, :, :, None] / dim_t
972
+ pos_y = y_embed[:, :, :, None] / dim_t
973
+ pos_x = torch.stack(
974
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
975
+ dim=4).flatten(3)
976
+ pos_y = torch.stack(
977
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
978
+ dim=4).flatten(3)
979
+ return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
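# Calling PositionEmbeddingSine(b, h, w) returns a fixed sinusoidal encoding of
# shape (b, 2 * num_pos_feats, h, w); with num_pos_feats = d_model // 2 (as
# instantiated in MCLM/MCRM below) the channel count equals d_model, so the
# encoding can be added directly to the flattened feature tokens.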
980
+
981
+
982
+ class MCLM(nn.Module):
983
+
984
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
985
+ super(MCLM, self).__init__()
986
+ self.attention = nn.ModuleList([
987
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
988
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
989
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
990
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
991
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
992
+ ])
993
+
994
+ self.linear1 = nn.Linear(d_model, d_model * 2)
995
+ self.linear2 = nn.Linear(d_model * 2, d_model)
996
+ self.linear3 = nn.Linear(d_model, d_model * 2)
997
+ self.linear4 = nn.Linear(d_model * 2, d_model)
998
+ self.norm1 = nn.LayerNorm(d_model)
999
+ self.norm2 = nn.LayerNorm(d_model)
1000
+ self.dropout = nn.Dropout(0.1)
1001
+ self.dropout1 = nn.Dropout(0.1)
1002
+ self.dropout2 = nn.Dropout(0.1)
1003
+ self.activation = get_activation_fn('relu')
1004
+ self.pool_ratios = pool_ratios
1005
+ self.p_poses = []
1006
+ self.g_pos = None
1007
+ self.positional_encoding = PositionEmbeddingSine(
1008
+ num_pos_feats=d_model // 2, normalize=True)
1009
+
1010
+ def forward(self, l, g):
1011
+ """
1012
+ l: 4,c,h,w
1013
+ g: 1,c,h,w
1014
+ """
1015
+ b, c, h, w = l.size()
1016
+ # 4,c,h,w -> 1,c,2h,2w
1017
+ concated_locs = rearrange(l,
1018
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1019
+ hg=2,
1020
+ wg=2)
1021
+
1022
+ pools = []
1023
+ for pool_ratio in self.pool_ratios:
1024
+ # b,c,h,w
1025
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1026
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1027
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1028
+ if self.g_pos is None:
1029
+ pos_emb = self.positional_encoding(pool.shape[0],
1030
+ pool.shape[2],
1031
+ pool.shape[3])
1032
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1033
+ self.p_poses.append(pos_emb)
1034
+ pools = torch.cat(pools, 0)
1035
+ if self.g_pos is None:
1036
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1037
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2],
1038
+ g.shape[3])
1039
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1040
+
1041
+ # attention between glb (q) & multi-scale pooled concated-locs (k, v)
1042
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1043
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1044
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1045
+ g_hw_b_c = self.norm1(g_hw_b_c)
1046
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1047
+ self.linear2(
1048
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1049
+ g_hw_b_c = self.norm2(g_hw_b_c)
1050
+
1051
+ # attention between original locs (q) & refreshed glb (k, v)
1052
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1053
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1054
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1055
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1056
+ ng=2,
1057
+ nw=2)
1058
+ outputs_re = []
1059
+ for i, (_l, _g) in enumerate(
1060
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1061
+ outputs_re.append(self.attention[i + 1](_l, _g,
1062
+ _g)[0]) # (h w) 1 c
1063
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1064
+
1065
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1066
+ l_hw_b_c = self.norm1(l_hw_b_c)
1067
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1068
+ self.linear4(
1069
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1070
+ l_hw_b_c = self.norm2(l_hw_b_c)
1071
+
1072
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1073
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
1074
+
1075
+
1076
+ class inf_MCLM(nn.Module):
1077
+
1078
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
1079
+ super(inf_MCLM, self).__init__()
1080
+ self.attention = nn.ModuleList([
1081
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1082
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1083
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1084
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1085
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1086
+ ])
1087
+
1088
+ self.linear1 = nn.Linear(d_model, d_model * 2)
1089
+ self.linear2 = nn.Linear(d_model * 2, d_model)
1090
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1091
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1092
+ self.norm1 = nn.LayerNorm(d_model)
1093
+ self.norm2 = nn.LayerNorm(d_model)
1094
+ self.dropout = nn.Dropout(0.1)
1095
+ self.dropout1 = nn.Dropout(0.1)
1096
+ self.dropout2 = nn.Dropout(0.1)
1097
+ self.activation = get_activation_fn('relu')
1098
+ self.pool_ratios = pool_ratios
1099
+ self.p_poses = []
1100
+ self.g_pos = None
1101
+ self.positional_encoding = PositionEmbeddingSine(
1102
+ num_pos_feats=d_model // 2, normalize=True)
1103
+
1104
+ def forward(self, l, g):
1105
+ """
1106
+ l: 4,c,h,w
1107
+ g: 1,c,h,w
1108
+ """
1109
+ b, c, h, w = l.size()
1110
+ # 4,c,h,w -> 1,c,2h,2w
1111
+ concated_locs = rearrange(l,
1112
+ '(hg wg b) c h w -> b c (hg h) (wg w)',
1113
+ hg=2,
1114
+ wg=2)
1115
+ self.p_poses = []
1116
+ pools = []
1117
+ for pool_ratio in self.pool_ratios:
1118
+ # b,c,h,w
1119
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1120
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
1121
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
1122
+ # if self.g_pos is None:
1123
+ pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2],
1124
+ pool.shape[3])
1125
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
1126
+ self.p_poses.append(pos_emb)
1127
+ pools = torch.cat(pools, 0)
1128
+ # if self.g_pos is None:
1129
+ self.p_poses = torch.cat(self.p_poses, dim=0)
1130
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
1131
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
1132
+
1133
+ # attention between glb (q) & multi-scale pooled concated-locs (k, v)
1134
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
1135
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](
1136
+ g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
1137
+ g_hw_b_c = self.norm1(g_hw_b_c)
1138
+ g_hw_b_c = g_hw_b_c + self.dropout2(
1139
+ self.linear2(
1140
+ self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
1141
+ g_hw_b_c = self.norm2(g_hw_b_c)
1142
+
1143
+ # attention between original locs (q) & refreshed glb (k, v)
1144
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
1145
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
1146
+ _g_hw_b_c = rearrange(_g_hw_b_c,
1147
+ "(ng h) (nw w) b c -> (h w) (ng nw b) c",
1148
+ ng=2,
1149
+ nw=2)
1150
+ outputs_re = []
1151
+ for i, (_l, _g) in enumerate(
1152
+ zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
1153
+ outputs_re.append(self.attention[i + 1](_l, _g,
1154
+ _g)[0]) # (h w) 1 c
1155
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
1156
+
1157
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
1158
+ l_hw_b_c = self.norm1(l_hw_b_c)
1159
+ l_hw_b_c = l_hw_b_c + self.dropout2(
1160
+ self.linear4(
1161
+ self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
1162
+ l_hw_b_c = self.norm2(l_hw_b_c)
1163
+
1164
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
1165
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
1166
+
1167
+
1168
+ class MCRM(nn.Module):
1169
+
1170
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1171
+ super(MCRM, self).__init__()
1172
+ self.attention = nn.ModuleList([
1173
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1174
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1175
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1176
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1177
+ ])
1178
+
1179
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1180
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1181
+ self.norm1 = nn.LayerNorm(d_model)
1182
+ self.norm2 = nn.LayerNorm(d_model)
1183
+ self.dropout = nn.Dropout(0.1)
1184
+ self.dropout1 = nn.Dropout(0.1)
1185
+ self.dropout2 = nn.Dropout(0.1)
1186
+ self.sigmoid = nn.Sigmoid()
1187
+ self.activation = get_activation_fn('relu')
1188
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1189
+ self.pool_ratios = pool_ratios
1190
+ self.positional_encoding = PositionEmbeddingSine(
1191
+ num_pos_feats=d_model // 2, normalize=True)
1192
+
1193
+ def forward(self, x):
1194
+ b, c, h, w = x.size()
1195
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1196
+ # b(4),c,h,w
1197
+ patched_glb = rearrange(glb,
1198
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1199
+ hg=2,
1200
+ wg=2)
1201
+
1202
+ # generate token attention map
1203
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1204
+ token_attention_map = F.interpolate(token_attention_map,
1205
+ size=patches2image(loc).shape[-2:],
1206
+ mode='nearest')
1207
+ loc = loc * rearrange(token_attention_map,
1208
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1209
+ hg=2,
1210
+ wg=2)
1211
+ pools = []
1212
+ for pool_ratio in self.pool_ratios:
1213
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1214
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1215
+ pools.append(rearrange(pool,
1216
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1217
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1218
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1219
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1220
+ outputs = []
1221
+ for i, q in enumerate(
1222
+ loc_.unbind(dim=0)): # traverse all local patches
1223
+ # np*hw,1,c
1224
+ v = pools[i]
1225
+ k = v
1226
+ outputs.append(self.attention[i](q, k, v)[0])
1227
+ outputs = torch.cat(outputs, 1)
1228
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1229
+ src = self.norm1(src)
1230
+ src = src + self.dropout2(
1231
+ self.linear4(
1232
+ self.dropout(self.activation(self.linear3(src)).clone())))
1233
+ src = self.norm2(src)
1234
+
1235
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed loc
1236
+ glb = glb + F.interpolate(patches2image(src),
1237
+ size=glb.shape[-2:],
1238
+ mode='nearest') # refreshed glb
1239
+ return torch.cat((src, glb), 0), token_attention_map
1240
+
1241
+
1242
+ class inf_MCRM(nn.Module):
1243
+
1244
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
1245
+ super(inf_MCRM, self).__init__()
1246
+ self.attention = nn.ModuleList([
1247
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1248
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1249
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
1250
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
1251
+ ])
1252
+
1253
+ self.linear3 = nn.Linear(d_model, d_model * 2)
1254
+ self.linear4 = nn.Linear(d_model * 2, d_model)
1255
+ self.norm1 = nn.LayerNorm(d_model)
1256
+ self.norm2 = nn.LayerNorm(d_model)
1257
+ self.dropout = nn.Dropout(0.1)
1258
+ self.dropout1 = nn.Dropout(0.1)
1259
+ self.dropout2 = nn.Dropout(0.1)
1260
+ self.sigmoid = nn.Sigmoid()
1261
+ self.activation = get_activation_fn('relu')
1262
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
1263
+ self.pool_ratios = pool_ratios
1264
+ self.positional_encoding = PositionEmbeddingSine(
1265
+ num_pos_feats=d_model // 2, normalize=True)
1266
+
1267
+ def forward(self, x):
1268
+ b, c, h, w = x.size()
1269
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
1270
+ # b(4),c,h,w
1271
+ patched_glb = rearrange(glb,
1272
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1273
+ hg=2,
1274
+ wg=2)
1275
+
1276
+ # generate token attention map
1277
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
1278
+ token_attention_map = F.interpolate(token_attention_map,
1279
+ size=patches2image(loc).shape[-2:],
1280
+ mode='nearest')
1281
+ loc = loc * rearrange(token_attention_map,
1282
+ 'b c (hg h) (wg w) -> (hg wg b) c h w',
1283
+ hg=2,
1284
+ wg=2)
1285
+ pools = []
1286
+ for pool_ratio in self.pool_ratios:
1287
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
1288
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
1289
+ pools.append(rearrange(pool,
1290
+ 'nl c h w -> nl c (h w)')) # nl(4),c,hw
1291
+ # nl(4),c,nphw -> nl(4),nphw,1,c
1292
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
1293
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
1294
+ outputs = []
1295
+ for i, q in enumerate(
1296
+ loc_.unbind(dim=0)): # traverse all local patches
1297
+ # np*hw,1,c
1298
+ v = pools[i]
1299
+ k = v
1300
+ outputs.append(self.attention[i](q, k, v)[0])
1301
+ outputs = torch.cat(outputs, 1)
1302
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
1303
+ src = self.norm1(src)
1304
+ src = src + self.dropout2(
1305
+ self.linear4(
1306
+ self.dropout(self.activation(self.linear3(src)).clone())))
1307
+ src = self.norm2(src)
1308
+
1309
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed loc
1310
+ glb = glb + F.interpolate(patches2image(src),
1311
+ size=glb.shape[-2:],
1312
+ mode='nearest') # refreshed glb
1313
+ return torch.cat((src, glb), 0)
1314
+
1315
+
1316
+ # model for single-scale training
1317
+ class MVANet(nn.Module):
1318
+
1319
+ def __init__(self):
1320
+ super().__init__()
1321
+ self.backbone = SwinB(pretrained=True)
1322
+ emb_dim = 128
1323
+ self.sideout5 = nn.Sequential(
1324
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1325
+ self.sideout4 = nn.Sequential(
1326
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1327
+ self.sideout3 = nn.Sequential(
1328
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1329
+ self.sideout2 = nn.Sequential(
1330
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1331
+ self.sideout1 = nn.Sequential(
1332
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1333
+
1334
+ self.output5 = make_cbr(1024, emb_dim)
1335
+ self.output4 = make_cbr(512, emb_dim)
1336
+ self.output3 = make_cbr(256, emb_dim)
1337
+ self.output2 = make_cbr(128, emb_dim)
1338
+ self.output1 = make_cbr(128, emb_dim)
1339
+
1340
+ self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
1341
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1342
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1343
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1344
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1345
+ self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
1346
+ self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
1347
+ self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
1348
+ self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
1349
+
1350
+ self.insmask_head = nn.Sequential(
1351
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1352
+ nn.BatchNorm2d(384), nn.PReLU(),
1353
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1354
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1355
+
1356
+ self.shallow = nn.Sequential(
1357
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1358
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1359
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1360
+ self.output = nn.Sequential(
1361
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1362
+
1363
+ for m in self.modules():
1364
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1365
+ m.inplace = True
1366
+
1367
+ def forward(self, x):
1368
+ x = x.to(dtype=torch_dtype, device=torch_device)
1369
+ shallow = self.shallow(x)
1370
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
1371
+ loc = image2patches(x)
1372
+ input = torch.cat((loc, glb), dim=0)
1373
+ feature = self.backbone(input)
1374
+ e5 = self.output5(feature[4]) # (5,128,16,16)
1375
+ e4 = self.output4(feature[3]) # (5,128,32,32)
1376
+ e3 = self.output3(feature[2]) # (5,128,64,64)
1377
+ e2 = self.output2(feature[1]) # (5,128,128,128)
1378
+ e1 = self.output1(feature[0]) # (5,128,128,128)
1379
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
1380
+ e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
1381
+
1382
+ e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
1383
+ e4 = self.conv4(e4)
1384
+ e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
1385
+ e3 = self.conv3(e3)
1386
+ e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
1387
+ e2 = self.conv2(e2)
1388
+ e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
1389
+ e1 = self.conv1(e1)
1390
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
1391
+ output1_cat = patches2image(loc_e1) # (1,128,256,256)
1392
+ # add glb feat in
1393
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
1394
+ # merge
1395
+ final_output = self.insmask_head(output1_cat) # (1,128,256,256)
1396
+ # shallow feature merge
1397
+ final_output = final_output + resize_as(shallow, final_output)
1398
+ final_output = self.upsample1(rescale_to(final_output))
1399
+ final_output = rescale_to(final_output +
1400
+ resize_as(shallow, final_output))
1401
+ final_output = self.upsample2(final_output)
1402
+ final_output = self.output(final_output)
1403
+ ####
1404
+ sideout5 = self.sideout5(e5).to(dtype=torch_dtype, device=torch_device)
1405
+ sideout4 = self.sideout4(e4)
1406
+ sideout3 = self.sideout3(e3)
1407
+ sideout2 = self.sideout2(e2)
1408
+ sideout1 = self.sideout1(e1)
1409
+ #######glb_sideouts ######
1410
+ glb5 = self.sideout5(glb_e5)
1411
+ glb4 = sideout4[-1, :, :, :].unsqueeze(0)
1412
+ glb3 = sideout3[-1, :, :, :].unsqueeze(0)
1413
+ glb2 = sideout2[-1, :, :, :].unsqueeze(0)
1414
+ glb1 = sideout1[-1, :, :, :].unsqueeze(0)
1415
+ ####### concat 4 to 1 #######
1416
+ sideout1 = patches2image(sideout1[:-1]).to(dtype=torch_dtype,
1417
+ device=torch_device)
1418
+ sideout2 = patches2image(sideout2[:-1]).to(
1419
+ dtype=torch_dtype,
1420
+ device=torch_device) ####(5,c,h,w) -> (1 c 2h,2w)
1421
+ sideout3 = patches2image(sideout3[:-1]).to(dtype=torch_dtype,
1422
+ device=torch_device)
1423
+ sideout4 = patches2image(sideout4[:-1]).to(dtype=torch_dtype,
1424
+ device=torch_device)
1425
+ sideout5 = patches2image(sideout5[:-1]).to(dtype=torch_dtype,
1426
+ device=torch_device)
1427
+ if self.training:
1428
+ return sideout5, sideout4, sideout3, sideout2, sideout1, final_output, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1
1429
+ else:
1430
+ return final_output
1431
+
1432
+
1433
+ # model for multi-scale testing
1434
+ class inf_MVANet(nn.Module):
1435
+
1436
+ def __init__(self):
1437
+ super().__init__()
1438
+ # self.backbone = SwinB(pretrained=True)
1439
+ self.backbone = SwinB(pretrained=False)
1440
+
1441
+ emb_dim = 128
1442
+ self.output5 = make_cbr(1024, emb_dim)
1443
+ self.output4 = make_cbr(512, emb_dim)
1444
+ self.output3 = make_cbr(256, emb_dim)
1445
+ self.output2 = make_cbr(128, emb_dim)
1446
+ self.output1 = make_cbr(128, emb_dim)
1447
+
1448
+ self.multifieldcrossatt = inf_MCLM(emb_dim, 1, [1, 4, 8])
1449
+ self.conv1 = make_cbr(emb_dim, emb_dim)
1450
+ self.conv2 = make_cbr(emb_dim, emb_dim)
1451
+ self.conv3 = make_cbr(emb_dim, emb_dim)
1452
+ self.conv4 = make_cbr(emb_dim, emb_dim)
1453
+ self.dec_blk1 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1454
+ self.dec_blk2 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1455
+ self.dec_blk3 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1456
+ self.dec_blk4 = inf_MCRM(emb_dim, 1, [2, 4, 8])
1457
+
1458
+ self.insmask_head = nn.Sequential(
1459
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
1460
+ nn.BatchNorm2d(384), nn.PReLU(),
1461
+ nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.BatchNorm2d(384),
1462
+ nn.PReLU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1))
1463
+
1464
+ self.shallow = nn.Sequential(
1465
+ nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
1466
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
1467
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
1468
+ self.output = nn.Sequential(
1469
+ nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
1470
+
1471
+ for m in self.modules():
1472
+ if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
1473
+ m.inplace = True
1474
+
1475
+ def forward(self, x):
1476
+ shallow = self.shallow(x)
1477
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
1478
+ loc = image2patches(x)
1479
+ input = torch.cat((loc, glb), dim=0)
1480
+ feature = self.backbone(input)
1481
+ e5 = self.output5(feature[4])
1482
+ e4 = self.output4(feature[3])
1483
+ e3 = self.output3(feature[2])
1484
+ e2 = self.output2(feature[1])
1485
+ e1 = self.output1(feature[0])
1486
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
1487
+ e5_cat = self.multifieldcrossatt(loc_e5, glb_e5)
1488
+
1489
+ e4 = self.conv4(self.dec_blk4(e4 + resize_as(e5_cat, e4)))
1490
+ e3 = self.conv3(self.dec_blk3(e3 + resize_as(e4, e3)))
1491
+ e2 = self.conv2(self.dec_blk2(e2 + resize_as(e3, e2)))
1492
+ e1 = self.conv1(self.dec_blk1(e1 + resize_as(e2, e1)))
1493
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
1494
+ # after decoder, concat loc features to a whole one, and merge
1495
+ output1_cat = patches2image(loc_e1)
1496
+ # add glb feat in
1497
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
1498
+ # merge
1499
+ final_output = self.insmask_head(output1_cat)
1500
+ # shallow feature merge
1501
+ final_output = final_output + resize_as(shallow, final_output)
1502
+ final_output = self.upsample1(rescale_to(final_output))
1503
+ final_output = rescale_to(final_output +
1504
+ resize_as(shallow, final_output))
1505
+ final_output = self.upsample2(final_output)
1506
+ final_output = self.output(final_output)
1507
+ return final_output
1508
+ #+end_src
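Before wiring the model into the loaders below, it may help to see the patch layout it assumes: the batch fed to the backbone is always four non-overlapping quadrants plus one half-scale global view. The helpers image2patches / patches2image are defined earlier in this document; the sketch below is only a hypothetical stand-in using the same einops patterns that appear in MCLM/MCRM, to illustrate that the split and the reassembly are exact inverses.

#+begin_src python :results output
import torch
from einops import rearrange

def image2patches_sketch(x):
    # (B, C, 2h, 2w) -> (4B, C, h, w): four non-overlapping quadrants stacked on batch
    return rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)

def patches2image_sketch(x):
    # (4B, C, h, w) -> (B, C, 2h, 2w): exact inverse of the split above
    return rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)

x = torch.randn(1, 3, 1024, 1024)
patches = image2patches_sketch(x)         # torch.Size([4, 3, 512, 512])
restored = patches2image_sketch(patches)  # torch.Size([1, 3, 1024, 1024])
print(patches.shape, restored.shape, bool(torch.equal(x, restored)))
#+end_src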
1509
+
1510
+ ** Function to load model
1511
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
1512
+ def load_model(model_checkpoint_path):
1513
+ torch.cuda.set_device(0)
1514
+
1515
+ net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
1516
+
1517
+ pretrained_dict = torch.load(model_checkpoint_path,
1518
+ map_location=torch_device)
1519
+
1520
+ model_dict = net.state_dict()
1521
+ pretrained_dict = {
1522
+ k: v
1523
+ for k, v in pretrained_dict.items() if k in model_dict
1524
+ }
1525
+ model_dict.update(pretrained_dict)
1526
+ net.load_state_dict(model_dict)
1527
+ net = net.to(dtype=torch_dtype, device=torch_device)
1528
+ net.eval()
1529
+ return net
1530
+
1531
+
1532
+ def load_transforms_stripped():
1533
+ img_transform = transforms.Compose([
1534
+ # transforms.ToTensor(),
1535
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
1536
+ ])
1537
+
1538
+ return img_transform
1539
+
1540
+
1541
+ def load_transforms():
1542
+ img_transform = transforms.Compose([
1543
+ # transforms.ToTensor(),
1544
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
1545
+ ])
1546
+
1547
+ depth_transform = transforms.ToTensor()
1548
+ target_transform = transforms.ToTensor()
1549
+ to_pil = transforms.ToPILImage()
1550
+
1551
+ transforms_var = tta.Compose([
1552
+ tta.HorizontalFlip(),
1553
+ tta.Scale(scales=[0.75, 1, 1.25],
1554
+ interpolation='bilinear',
1555
+ align_corners=False),
1556
+ ])
1557
+
1558
+ return (img_transform, depth_transform, target_transform, to_pil,
1559
+ transforms_var)
1560
+ #+end_src
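Because load_model silently drops any checkpoint entry whose key is not present in the freshly built network, a mismatched checkpoint fails quietly. Below is a minimal sketch for auditing what was actually loaded; the checkpoint path is hypothetical, and torch_device is the global defined in the import block of this document.

#+begin_src python :results output
def report_checkpoint_coverage(net, model_checkpoint_path):
    # Compare checkpoint keys with the network's state_dict and report mismatches.
    pretrained_dict = torch.load(model_checkpoint_path, map_location=torch_device)
    model_keys = set(net.state_dict().keys())
    ckpt_keys = set(pretrained_dict.keys())
    print('matching keys :', len(model_keys & ckpt_keys))
    print('missing keys  :', sorted(model_keys - ckpt_keys)[:10])
    print('unexpected    :', sorted(ckpt_keys - model_keys)[:10])

# Hypothetical usage:
# net = load_model('/path/to/finetuned_MVANet.pth')
# report_checkpoint_coverage(net, '/path/to/finetuned_MVANet.pth')
#+end_src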
1561
+
1562
+ ** Function for modular inference CV
1563
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
1564
+ def do_infer_tensor2tensor(img, net):
1565
+
1566
+ img_transform = transforms.Compose(
1567
+ [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
1568
+
1569
+ h_, w_ = img.shape[1], img.shape[2]
1570
+
1571
+ with torch.no_grad():
1572
+
1573
+ img = rearrange(img, 'B H W C -> B C H W')
1574
+
1575
+ img_resize = torch.nn.functional.interpolate(input=img,
1576
+ size=(1024, 1024),
1577
+ mode='bicubic',
1578
+ antialias=True)
1579
+
1580
+ img_var = img_transform(img_resize)
1581
+ img_var = Variable(img_var)
1582
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1583
+
1584
+ mask = []
1585
+
1586
+ mask.append(net(img_var))
1587
+
1588
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1589
+ prediction = prediction.sigmoid()
1590
+
1591
+ prediction = torch.nn.functional.interpolate(input=prediction,
1592
+ size=(h_, w_),
1593
+ mode='bicubic',
1594
+ antialias=True)
1595
+
1596
+ prediction = prediction.squeeze(0)
1597
+ prediction = prediction.clamp(0, 1)
1598
+
1599
+ return prediction
1600
+
1601
+
1602
+ def do_infer_modular_cv(input_image_path, output_mask_path, net,
1603
+ all_transforms):
1604
+
1605
+ (img_transform, depth_transform, target_transform, to_pil,
1606
+ transforms_var) = all_transforms
1607
+
1608
+ img = load_image_torch(input_image_path)
1609
+
1610
+ h_, w_ = img.shape[1], img.shape[2]
1611
+
1612
+ with torch.no_grad():
1613
+
1614
+ img = rearrange(img, 'B H W C -> B C H W')
1615
+
1616
+ img_resize = torch.nn.functional.interpolate(input=img,
1617
+ size=(1024, 1024),
1618
+ mode='bicubic',
1619
+ antialias=True)
1620
+
1621
+ img_var = img_transform(img_resize)
1622
+ img_var = Variable(img_var)
1623
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1624
+
1625
+ mask = []
1626
+
1627
+ for transformer in transforms_var:
1628
+ rgb_trans = img_var.to(dtype=torch_dtype, device=torch_device)
1629
+ mask.append(net(rgb_trans))
1630
+
1631
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1632
+ prediction = prediction.sigmoid()
1633
+
1634
+ prediction = torch.nn.functional.interpolate(input=prediction,
1635
+ size=(h_, w_),
1636
+ mode='bicubic',
1637
+ antialias=True)
1638
+
1639
+ prediction = prediction.squeeze(0)
1640
+ prediction = prediction.clamp(0, 1)
1641
+
1642
+ save_mask_torch(output_image_path=output_mask_path, mask=prediction)
1643
+
1644
+
1645
+ def do_infer_modular_cv_2(input_image_path, output_mask_path, net,
1646
+ all_transforms):
1647
+
1648
+ (img_transform, depth_transform, target_transform, to_pil,
1649
+ transforms_var) = all_transforms
1650
+
1651
+ img = load_image(input_image_path)
1652
+ w_, h_ = img.shape[0], img.shape[1]
1653
+ img_resize = cv2.resize(img, (1024, 1024), cv2.INTER_CUBIC)
1654
+
1655
+ with torch.no_grad():
1656
+
1657
+ # rgb_png_path = input_image_path
1658
+ # img = Image.open(rgb_png_path).convert('RGB')
1659
+ # w_, h_ = img.size
1660
+
1661
+ # img_resize = img.resize([256 * 4, 256 * 4], Image.BILINEAR)
1662
+
1663
+ # img_var = Variable(img_transform(img_resize).unsqueeze(0)).to(
1664
+ # dtype=torch_dtype, device=torch_device)
1665
+
1666
+ img_resize = torch.from_numpy(img_resize)
1667
+ img_resize = img_resize.to(dtype=torch.float32)
1668
+ img_resize /= 255.0
1669
+ img_resize = rearrange(img_resize, 'H W C -> C H W')
1670
+ img_var = img_transform(img_resize)
1671
+ img_var = img_var.unsqueeze(0)
1672
+ img_var = Variable(img_var)
1673
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1674
+
1675
+ mask = []
1676
+
1677
+ for transformer in transforms_var:
1678
+ rgb_trans = transformer.augment_image(img_var)
1679
+ rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
1680
+ model_output = net(rgb_trans)
1681
+ deaug_mask = transformer.deaugment_mask(model_output)
1682
+ mask.append(deaug_mask)
1683
+
1684
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1685
+ prediction = prediction.sigmoid()
1686
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1687
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
1688
+ prediction.save(output_mask_path)
1689
+
1690
+
1691
+ def do_infer_modular_cv_3(input_image_path, output_mask_path, net,
1692
+ all_transforms):
1693
+
1694
+ (img_transform, depth_transform, target_transform, to_pil,
1695
+ transforms_var) = all_transforms
1696
+
1697
+ img = load_image(input_image_path)
1698
+ w_, h_ = img.shape[0], img.shape[1]
1699
+
1700
+ with torch.no_grad():
1701
+
1702
+ # rgb_png_path = input_image_path
1703
+ # img = Image.open(rgb_png_path).convert('RGB')
1704
+ # w_, h_ = img.size
1705
+
1706
+ # img_resize = img.resize([256 * 4, 256 * 4], Image.BILINEAR)
1707
+
1708
+ # img_var = Variable(img_transform(img_resize).unsqueeze(0)).to(
1709
+ # dtype=torch_dtype, device=torch_device)
1710
+
1711
+ img_resize = torch.from_numpy(img)
1712
+ img_resize = img_resize.to(dtype=torch.float32)
1713
+ img_resize = rearrange(img_resize, 'H W C -> C H W')
1714
+ img_resize = img_resize.unsqueeze(0)
1715
+
1716
+ img_resize = torch.nn.functional.interpolate(input=img_resize,
1717
+ size=(1024, 1024),
1718
+ mode='bicubic',
1719
+ antialias=True)
1720
+
1721
+ img_resize = img_resize.squeeze(0)
1722
+ img_resize = rearrange(img_resize, 'C H W -> H W C')
1723
+
1724
+ img_resize = img_resize.to(dtype=torch.float32)
1725
+ img_resize /= 255.0
1726
+ img_resize = rearrange(img_resize, 'H W C -> C H W')
1727
+ img_var = img_transform(img_resize)
1728
+ img_var = img_var.unsqueeze(0)
1729
+ img_var = Variable(img_var)
1730
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1731
+
1732
+ mask = []
1733
+
1734
+ for transformer in transforms_var:
1735
+ rgb_trans = transformer.augment_image(img_var)
1736
+ rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
1737
+ model_output = net(rgb_trans)
1738
+ deaug_mask = transformer.deaugment_mask(model_output)
1739
+ mask.append(deaug_mask)
1740
+
1741
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1742
+ prediction = prediction.sigmoid()
1743
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1744
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
1745
+ prediction.save(output_mask_path)
1746
+
1747
+
1748
+ def do_infer_modular_cv_4(input_image_path, output_mask_path, net,
1749
+ all_transforms):
1750
+
1751
+ (img_transform, depth_transform, target_transform, to_pil,
1752
+ transforms_var) = all_transforms
1753
+
1754
+ img = load_image(input_image_path)
1755
+ w_, h_ = img.shape[0], img.shape[1]
1756
+
1757
+ with torch.no_grad():
1758
+
1759
+ img_resize = torch.from_numpy(img)
1760
+ img_resize = img_resize.to(dtype=torch.float32)
1761
+ img_resize /= 255.0
1762
+ img_resize = img_resize.unsqueeze(0)
1763
+
1764
+ img_resize = rearrange(img_resize, 'B H W C -> B C H W')
1765
+
1766
+ img_resize = torch.nn.functional.interpolate(input=img_resize,
1767
+ size=(1024, 1024),
1768
+ mode='bicubic',
1769
+ antialias=True)
1770
+
1771
+ img_resize = img_resize.squeeze(0)
1772
+ img_var = img_transform(img_resize)
1773
+ img_var = img_var.unsqueeze(0)
1774
+ img_var = Variable(img_var)
1775
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1776
+
1777
+ mask = []
1778
+
1779
+ for transformer in transforms_var:
1780
+ rgb_trans = transformer.augment_image(img_var)
1781
+ rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
1782
+ model_output = net(rgb_trans)
1783
+ deaug_mask = transformer.deaugment_mask(model_output)
1784
+ mask.append(deaug_mask)
1785
+
1786
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1787
+ prediction = prediction.sigmoid()
1788
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1789
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
1790
+ prediction.save(output_mask_path)
1791
+
1792
+
1793
+ def do_infer_modular_cv_5(input_image_path, output_mask_path, net,
1794
+ all_transforms):
1795
+
1796
+ (img_transform, depth_transform, target_transform, to_pil,
1797
+ transforms_var) = all_transforms
1798
+
1799
+ img = load_image(input_image_path)
1800
+ w_, h_ = img.shape[0], img.shape[1]
1801
+
1802
+ with torch.no_grad():
1803
+
1804
+ img_resize = torch.from_numpy(img)
1805
+ img_resize = img_resize.to(dtype=torch.float32)
1806
+ img_resize /= 255.0
1807
+ img_resize = img_resize.unsqueeze(0)
1808
+
1809
+ img_resize = rearrange(img_resize, 'B H W C -> B C H W')
1810
+
1811
+ img_resize = torch.nn.functional.interpolate(input=img_resize,
1812
+ size=(1024, 1024),
1813
+ mode='bicubic',
1814
+ antialias=True)
1815
+
1816
+ img_var = img_transform(img_resize)
1817
+ img_var = Variable(img_var)
1818
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1819
+
1820
+ mask = []
1821
+
1822
+ for transformer in transforms_var:
1823
+ rgb_trans = transformer.augment_image(img_var)
1824
+ rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
1825
+ model_output = net(rgb_trans)
1826
+ deaug_mask = transformer.deaugment_mask(model_output)
1827
+ mask.append(deaug_mask)
1828
+
1829
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1830
+ prediction = prediction.sigmoid()
1831
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1832
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
1833
+ prediction.save(output_mask_path)
1834
+
1835
+
1836
+ def do_infer_modular_cv_6(input_image_path, output_mask_path, net,
1837
+ all_transforms):
1838
+
1839
+ (img_transform, depth_transform, target_transform, to_pil,
1840
+ transforms_var) = all_transforms
1841
+
1842
+ img = load_image(input_image_path)
1843
+ w_, h_ = img.shape[0], img.shape[1]
1844
+
1845
+ with torch.no_grad():
1846
+
1847
+ img_resize = torch.from_numpy(img)
1848
+ img_resize = img_resize.to(dtype=torch.float32)
1849
+ img_resize /= 255.0
1850
+ img_resize = img_resize.unsqueeze(0)
1851
+
1852
+ img_resize = rearrange(img_resize, 'B H W C -> B C H W')
1853
+
1854
+ img_resize = torch.nn.functional.interpolate(input=img_resize,
1855
+ size=(1024, 1024),
1856
+ mode='bicubic',
1857
+ antialias=True)
1858
+
1859
+ img_var = img_transform(img_resize)
1860
+ img_var = Variable(img_var)
1861
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1862
+
1863
+ mask = []
1864
+
1865
+ for transformer in transforms_var:
1866
+ rgb_trans = img_var.to(dtype=torch_dtype, device=torch_device)
1867
+ mask.append(net(rgb_trans))
1868
+
1869
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1870
+ prediction = prediction.sigmoid()
1871
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1872
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
1873
+ prediction.save(output_mask_path)
1874
+
1875
+
1876
+ def do_infer_modular_cv_7(input_image_path, output_mask_path, net,
1877
+ all_transforms):
1878
+
1879
+ (img_transform, depth_transform, target_transform, to_pil,
1880
+ transforms_var) = all_transforms
1881
+
1882
+ img = load_image_torch(input_image_path)
1883
+
1884
+ h_, w_ = img.shape[1], img.shape[2]
1885
+
1886
+ with torch.no_grad():
1887
+
1888
+ img = rearrange(img, 'B H W C -> B C H W')
1889
+
1890
+ img_resize = torch.nn.functional.interpolate(input=img,
1891
+ size=(1024, 1024),
1892
+ mode='bicubic',
1893
+ antialias=True)
1894
+
1895
+ img_var = img_transform(img_resize)
1896
+ img_var = Variable(img_var)
1897
+ img_var = img_var.to(dtype=torch_dtype, device=torch_device)
1898
+
1899
+ mask = []
1900
+
1901
+ for transformer in transforms_var:
1902
+ rgb_trans = img_var.to(dtype=torch_dtype, device=torch_device)
1903
+ mask.append(net(rgb_trans))
1904
+
1905
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1906
+ prediction = prediction.sigmoid()
1907
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1908
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
1909
+ prediction.save(output_mask_path)
1910
+ #+end_src
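For reference, a minimal end-to-end call of do_infer_tensor2tensor on an image read with OpenCV. The paths are hypothetical, and the BGR-to-RGB conversion plus the [0, 1] scaling are assumptions about the B H W C layout the function normalises internally (load_image / load_image_torch earlier in this document are the canonical loaders).

#+begin_src python :results output
import cv2
import numpy as np
import torch

bgr = cv2.imread('/path/to/input.jpg', cv2.IMREAD_COLOR)              # hypothetical path
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
img = torch.from_numpy(rgb).unsqueeze(0)                               # (1, H, W, 3) in [0, 1]

net = load_model('/path/to/finetuned_MVANet.pth')                      # hypothetical checkpoint
mask = do_infer_tensor2tensor(img, net)                                # (1, H, W), values in [0, 1]
cv2.imwrite('/path/to/mask.png',
            (mask.squeeze(0).float().cpu().numpy() * 255).astype(np.uint8))
#+end_src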
1911
+
1912
+ ** Function for modular inference
1913
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
1914
+ def do_infer_modular(input_image_path, output_mask_path, net, all_transforms):
1915
+ # net = load_model(finetuned_MVANet_model_path)
1916
+
1917
+ (img_transform, depth_transform, target_transform, to_pil,
1918
+ transforms_var) = all_transforms
1919
+
1920
+ with torch.no_grad():
1921
+ rgb_png_path = input_image_path
1922
+ img = Image.open(rgb_png_path).convert('RGB')
1923
+
1924
+ w_, h_ = img.size
1925
+ # img_resize = img.resize([(w_ // 2) * 2, (h_ // 2) * 2], Image.BILINEAR)
1926
+ img_resize = img.resize([256 * 4, 256 * 4], Image.BILINEAR)
1927
+ # img_resize = img
1928
+ img_var = Variable(img_transform(img_resize).unsqueeze(0)).to(
1929
+ dtype=torch_dtype, device=torch_device)
1930
+ mask = []
1931
+ for transformer in transforms_var:
1932
+ rgb_trans = transformer.augment_image(img_var)
1933
+ rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
1934
+ model_output = net(rgb_trans)
1935
+ deaug_mask = transformer.deaugment_mask(model_output)
1936
+ mask.append(deaug_mask)
1937
+
1938
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1939
+ prediction = prediction.sigmoid()
1940
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1941
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
1942
+ prediction.save(output_mask_path)
1943
+ #+end_src
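The do_infer_modular* variants average predictions over test-time augmentation. With the tta.Compose used here (horizontal flip crossed with three scales), each input image goes through six forward passes; a quick way to confirm that, assuming the ttach package is imported as tta as in the import block of this document:

#+begin_src python :results output
import ttach as tta

transforms_var = tta.Compose([
    tta.HorizontalFlip(),
    tta.Scale(scales=[0.75, 1, 1.25], interpolation='bilinear', align_corners=False),
])
# Each element is one augment/deaugment pair: 2 flips x 3 scales = 6 passes per image.
print(len(list(transforms_var)))
#+end_src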
1944
+
1945
+ ** Function for inference
1946
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
1947
+ def do_infer():
1948
+ torch.cuda.set_device(0)
1949
+ args = {'crf_refine': True, 'save_results': True}
1950
+
1951
+ img_transform = transforms.Compose([
1952
+ transforms.ToTensor(),
1953
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
1954
+ ])
1955
+
1956
+ depth_transform = transforms.ToTensor()
1957
+ target_transform = transforms.ToTensor()
1958
+ to_pil = transforms.ToPILImage()
1959
+
1960
+ transforms_var = tta.Compose([
1961
+ tta.HorizontalFlip(),
1962
+ tta.Scale(scales=[0.75, 1, 1.25],
1963
+ interpolation='bilinear',
1964
+ align_corners=False),
1965
+ ])
1966
+
1967
+ net = inf_MVANet().to(dtype=torch_dtype, device=torch_device)
1968
+ pretrained_dict = torch.load(finetuned_MVANet_model_path,
1969
+ map_location=torch_device)
1970
+ model_dict = net.state_dict()
1971
+ pretrained_dict = {
1972
+ k: v
1973
+ for k, v in pretrained_dict.items() if k in model_dict
1974
+ }
1975
+ model_dict.update(pretrained_dict)
1976
+ net.load_state_dict(model_dict)
1977
+ net = net.to(dtype=torch_dtype, device=torch_device)
1978
+ net.eval()
1979
+ with torch.no_grad():
1980
+ rgb_png_path = '/home/asd/DATASETS/SD_BG_SWAP_TEST/comfyui_outputs/4/output_fooocus/bgswap-output.png'
1981
+ img = Image.open(rgb_png_path).convert('RGB')
1982
+ w_, h_ = img.size
1983
+ # img_resize = img.resize([(w_ // 2) * 2, (h_ // 2) * 2], Image.BILINEAR)
1984
+ img_resize = img.resize([256 * 4 , 256 * 4 ], Image.BILINEAR)
1985
+ # img_resize = img
1986
+ img_var = Variable(img_transform(img_resize).unsqueeze(0))
1987
+ img_var = img_var.cuda()
1988
+ mask = []
1989
+ for transformer in transforms_var:
1990
+ rgb_trans = transformer.augment_image(img_var)
1991
+ rgb_trans = rgb_trans.to(dtype=torch_dtype, device=torch_device)
1992
+ model_output = net(rgb_trans)
1993
+ deaug_mask = transformer.deaugment_mask(model_output)
1994
+ mask.append(deaug_mask)
1995
+
1996
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
1997
+ prediction = prediction.sigmoid()
1998
+ prediction = to_pil(prediction.data.squeeze(0).cpu())
1999
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
2000
+ prediction.save('./tmp.png')
2001
+ #+end_src
2002
+
2003
+ ** MVANet_inference function
2004
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.function.py
2005
+ def main(item):
2006
+ net = inf_MVANet().cuda()
2007
+ pretrained_dict = torch.load(os.path.join(ckpt_path, item + '.pth'),
2008
+ map_location='cuda')
2009
+ model_dict = net.state_dict()
2010
+ pretrained_dict = {
2011
+ k: v
2012
+ for k, v in pretrained_dict.items() if k in model_dict
2013
+ }
2014
+ model_dict.update(pretrained_dict)
2015
+ net.load_state_dict(model_dict)
2016
+ net.eval()
2017
+ with torch.no_grad():
2018
+ for name, root in to_test.items():
2019
+ root1 = os.path.join(root, 'images')
2020
+ img_list = [os.path.splitext(f) for f in os.listdir(root1)]
2021
+ for idx, img_name in enumerate(img_list):
2022
+
2023
+ print('predicting for %s: %d / %d' %
2024
+ (name, idx + 1, len(img_list)))
2025
+ rgb_png_path = os.path.join(root, 'images',
2026
+ img_name[0] + '.png')
2027
+ rgb_jpg_path = os.path.join(root, 'images',
2028
+ img_name[0] + '.jpg')
2029
+ if os.path.exists(rgb_png_path):
2030
+ img = Image.open(rgb_png_path).convert('RGB')
2031
+ else:
2032
+ img = Image.open(rgb_jpg_path).convert('RGB')
2033
+ w_, h_ = img.size
2034
+ img_resize = img.resize([1024, 1024], Image.BILINEAR)
2035
+ img_var = Variable(img_transform(img_resize).unsqueeze(0))
2036
+ img_var = img_var.cuda()
2037
+ mask = []
2038
+ for transformer in transforms_var:
2039
+ rgb_trans = transformer.augment_image(img_var)
2040
+ model_output = net(rgb_trans)
2041
+ deaug_mask = transformer.deaugment_mask(model_output)
2042
+ mask.append(deaug_mask)
2043
+
2044
+ prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
2045
+ prediction = prediction.sigmoid()
2046
+ prediction = to_pil(prediction.data.squeeze(0))
2047
+ prediction = prediction.resize((w_, h_), Image.BILINEAR)
2048
+ if args['save_results']:
2049
+ check_mkdir(os.path.join(ckpt_path, item, name))
2050
+ prediction.save(
2051
+ os.path.join(ckpt_path, item, name,
2052
+ img_name[0] + '.png'))
2053
+ #+end_src
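main() expects several module-level names (ckpt_path, to_test, args, img_transform, transforms_var, to_pil, check_mkdir) that are defined in the import and function blocks earlier in this document. A hypothetical configuration, shown only to make the expected directory layout explicit:

#+begin_src python :results output
# Hypothetical values; the real ones live in the import block of this document.
ckpt_path = './saved_model/MVANet'          # main() loads '<ckpt_path>/<item>.pth'
to_test = {'DIS-VD': '/path/to/DIS-VD'}     # every entry must contain an 'images/' folder
args = {'crf_refine': True, 'save_results': True}
# Predictions are written to '<ckpt_path>/<item>/<dataset_name>/<image_name>.png'.
#+end_src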
2054
+
2055
+ ** MVANet_inference execute
2056
+ #+begin_src python :shebang #!/usr/bin/python3 :results output :tangle ./MVANet_inference.execute.py
2057
+ def do_merge(path_image, path_mask, path_out):
2058
+ image = cv2.imread(path_image, cv2.IMREAD_COLOR)
2059
+ mask = cv2.imread(path_mask, cv2.IMREAD_GRAYSCALE)
2060
+ mask = (mask > 127).astype(dtype=np.uint8) * 255
2061
+ out = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
2062
+ out[:, :, 0:3] = image
2063
+ out[:, :, 3] = mask
2064
+ cv2.imwrite(path_out, out)
2065
+
2066
+
2067
+ if __name__ == '__main__':
2068
+
2069
+ # do_infer_modular_cv(
2070
+ # input_image_path=
2071
+ # '/home/asd/DATASETS/SD_BG_SWAP_TEST/comfyui_outputs/4/output_fooocus/bgswap-output.png',
2072
+ # output_mask_path='./tmp.png',
2073
+ # net=load_model(finetuned_MVANet_model_path),
2074
+ # all_transforms=load_transforms(),
2075
+ # )
2076
+
2077
+ # net = load_model(
2078
+ # HOME_DIR + '/dreambooth_experiments/MVANet/MVANet_cloth_segment_14.pth')
2079
+
2080
+ # net = load_model(
2081
+ # HOME_DIR +
2082
+ # '/dreambooth_experiments/MVANet/new_type_crop_with_midshot.pth')
2083
+
2084
+ # net = load_model('/home/asd/MODEL_CHECKPOINTS/MVANet/SKIN_SEGMENTATION/1/Model_4.pth')
2085
+
2086
+ net = load_model('/home/asd/MODEL_CHECKPOINTS/MVANet/SKIN_SEGMENTATION/3/Model_14.pth')
2087
+
2088
+
2089
+ # net = load_model(HOME_DIR +
2090
+ # '/dreambooth_experiments/MVANet/mvanet_normal_crop_2.pth')
2091
+
2092
+ DATA_DIR_BASE = HOME_DIR + '/DATASETS/cloth_segmentation_test_images.dir/cloth_segmentation_test_images/'
2093
+
2094
+ images = (
2095
+ '1370', '1371', '1372', '1373', '1374', '1375', '1376', '1377', '1378',
2096
+ '1379', '1380', '1381', '1382', '1383', '1384', '1385', '1386', '1387',
2097
+ '1388', '1389', '1390', '1391', '1392', '1393', '1394', '1395', '1396',
2098
+ '1397', '1398', '1399', '1400', '1401', '1402', '1403', '1404', '1405',
2099
+ '1406', '1407', '1408', '1409', '1410', '1411', '1412', '1413', '1414',
2100
+ '1415', '1539', '1541', '1542', '1543', '17320', '4129', '4190',
2101
+ '4191', '4192', '4193', '4202', '4203', '4204', '4207', '4208', '4209',
2102
+ '4210', '4213', '4214', '4221', '4222', '4223', '4224', '4225', '4226',
2103
+ '4227', '4228', '4229', '4230', '4231', '4232', '4233', '4234', '4235',
2104
+ '4236', '4237', '4238', '4239', '4240', '4241', '4242', '4251', '4252',
2105
+ '4253', '4254', '4255', '4256', '4257', '4258', '4259', '4260', '4261',
2106
+ '4262', '4263', '4264', '6581', '6642', '6647', '6656', '6660', '6690',
2107
+ '6696', '6724', '6767', '6771', '6788', '6791', '6807', '6821', '6824',
2108
+ '6833', '6847', '6850', '6879', '6941', '7001', '7070', '7083', '7092',
2109
+ '7093', '7119', '7191', '7220', '7252', '7264', '7276', '7278', '7281',
2110
+ '7290', '7301', '7312', '7340', '7398', '7404', '7412', '7429', '7439',
2111
+ '7478', '7491', '7631', '7687', '7699', '7719', '7770', '7784', '7793',
2112
+ '7811', '7829', '7861', '7864', '7868', '7980', '7987', '7990', '8069',
2113
+ '8083', '8100', '8108', '8227', '8323', '8329', '8358', '8383', '8401',
2114
+ '8415', '8488', '8515', '8518', '8560', '8565', '8595', '8639', '8676',
2115
+ '8690', '8691', '8701', '8703', '8723', '8726', '8756', '8783', '8801',
2116
+ '8820', '8826', '8842', '8865', '8874', '8875', '8882', '8911', '8946',
2117
+ '8947', '8969', '8979', '8983')
2118
+
2119
+ masks = [DATA_DIR_BASE + i + '/garment_mask.png' for i in images]
2120
+ out = [DATA_DIR_BASE + i + '/garment_transparent.png' for i in images]
2121
+
2122
+ images = [DATA_DIR_BASE + i + '/original.jpg' for i in images]
2123
+
2124
+ for i in range(len(images)):
2125
+ image = images[i]
2126
+ image = load_image_torch(image)
2127
+ mask = do_infer_tensor2tensor(image, net)
2128
+ save_mask_torch(output_image_path=masks[i], mask=mask)
2129
+ do_merge(path_image=images[i], path_mask=masks[i], path_out=out[i])
2130
+
2131
+ # img = load_image_torch(
2132
+ # '/home/asd/DATASETS/SD_BG_SWAP_TEST/comfyui_outputs/4/output_fooocus/bgswap-output.png'
2133
+ # )
2134
+ # # all_transforms = load_transforms()
2135
+ # masks = do_infer_tensor2tensor(img, net)
2136
+ # save_mask_torch(output_image_path='./tmp.png', mask=masks)
2137
+ #+end_src
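To eyeball the RGBA cut-outs written by do_merge, it is handy to composite them over a flat colour; a small self-contained sketch (file names hypothetical, colours in OpenCV's BGR order):

#+begin_src python :results output
import cv2
import numpy as np

def preview_over_background(path_rgba, path_out, bg=(255, 255, 255)):
    # Alpha-composite an RGBA cut-out over a solid background colour (BGR order).
    rgba = cv2.imread(path_rgba, cv2.IMREAD_UNCHANGED).astype(np.float32)
    alpha = rgba[:, :, 3:4] / 255.0
    background = np.ones_like(rgba[:, :, :3]) * np.array(bg, dtype=np.float32)
    composite = rgba[:, :, :3] * alpha + background * (1.0 - alpha)
    cv2.imwrite(path_out, composite.astype(np.uint8))

# preview_over_background('garment_transparent.png', 'garment_preview.png')
#+end_src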
2138
+
2139
+ ** MVANet_inference unify
2140
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.unify.sh
2141
+ . "${HOME}/dbnew.sh"
2142
+
2143
+ (
2144
+ echo '#!/usr/bin/python3'
2145
+ cat \
2146
+ './MVANet_inference.import.py' \
2147
+ './MVANet_inference.function.py' \
2148
+ './MVANet_inference.class.py' \
2149
+ './MVANet_inference.execute.py' \
2150
+ | expand | yapf3 \
2151
+ | grep -v '#!/usr/bin/python3' \
2152
+ ;
2153
+ ) > './MVANet_inference.py' \
2154
+ ;
2155
+ #+end_src
2156
+
2157
+ ** MVANet_inference run
2158
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./MVANet_inference.run.sh
2159
+ . "${HOME}/dbnew.sh"
2160
+ python3 './MVANet_inference.py'
2161
+ #+end_src
2162
+
2163
+ * WORK SPACE
2164
+
2165
+ ** elisp
2166
+ #+begin_src elisp
2167
+ (save-buffer)
2168
+ (org-babel-tangle)
2169
+ (shell-command "./MVANet_inference.unify.sh")
2170
+ #+end_src
2171
+
2172
+ #+RESULTS:
2173
+ : 0
2174
+
2175
+ ** sh
2176
+ #+begin_src sh :shebang #!/bin/sh :results output
2177
+ realpath .
2178
+ cd /home/asd/GITHUB/aravind-h-v/dreambooth_experiments/MVANet
2179
+ #+end_src
README.md ADDED
@@ -0,0 +1,131 @@
1
+ # Self Correction for Human Parsing
2
+ This is a copy of https://github.com/GoGoDuck912/Self-Correction-Human-Parsing
3
+
4
+
5
+ ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg)
6
+ [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
7
+
8
+ An out-of-box human parsing representation extractor.
9
+
10
+ Our solution ranks 1st for all human parsing tracks (including single, multiple and video) in the third LIP challenge!
11
+
12
+ ![lip-visualization](./demo/lip-visualization.jpg)
13
+
14
+ Features:
15
+ - [x] Out-of-box human parsing extractor for other downstream applications.
16
+ - [x] Pretrained model on three popular single person human parsing datasets.
17
+ - [x] Training and inference code.
18
+ - [x] Simple yet effective extension on multi-person and video human parsing tasks.
19
+
20
+ ## Requirements
21
+
22
+ ```
23
+ conda env create -f environment.yaml
24
+ conda activate schp
25
+ pip install -r requirements.txt
26
+ ```
27
+
28
+ ## Simple Out-of-Box Extractor
29
+
30
+ The easiest way to get started is to use our trained SCHP models on your own images to extract human parsing representations. Here we provide state-of-the-art [trained models](https://drive.google.com/drive/folders/1uOaQCpNtosIjEL2phQKEdiYd0Td18jNo?usp=sharing) on three popular datasets. These three datasets use different label systems, so you can choose the one that best fits your own task.
31
+
32
+ **LIP** ([exp-schp-201908261155-lip.pth](https://drive.google.com/file/d/1k4dllHpu0bdx38J7H28rVVLpU-kOHmnH/view?usp=sharing))
33
+
34
+ * mIoU on LIP validation: **59.36 %**.
35
+
36
+ * LIP is the largest single-person human parsing dataset, with 50000+ images. This dataset focuses more on complicated real-world scenarios. LIP has 20 labels, including 'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat', 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm', 'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe'.
37
+
38
+ **ATR** ([exp-schp-201908301523-atr.pth](https://drive.google.com/file/d/1ruJg4lqR_jgQPj-9K0PP-L2vJERYOxLP/view?usp=sharing))
39
+
40
+ * mIoU on ATR test: **82.29%**.
41
+
42
+ * ATR is a large single-person human parsing dataset, with 17000+ images. This dataset focuses more on fashion AI. ATR has 18 labels, including 'Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt', 'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf'.
43
+
44
+ **Pascal-Person-Part** ([exp-schp-201908270938-pascal-person-part.pth](https://drive.google.com/file/d/1E5YwNKW2VOEayK9mWCS3Kpsxf-3z04ZE/view?usp=sharing))
45
+
46
+ * mIoU on Pascal-Person-Part validation: **71.46** %.
47
+
48
+ * Pascal Person Part is a tiny single-person human parsing dataset, with 3000+ images. This dataset focuses more on body part segmentation. Pascal Person Part has 7 labels, including 'Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms', 'Upper Legs', 'Lower Legs'.
49
+
50
+ Choose one and have fun on your own task!
51
+
52
+ To extract the human parsing representation, simply put your own images in the `INPUT_PATH` folder, then download a pretrained model and run the following command. The output images, with the same file names, will be saved in `OUTPUT_PATH`.
53
+
54
+ ```
55
+ python simple_extractor.py --dataset [DATASET] --model-restore [CHECKPOINT_PATH] --input-dir [INPUT_PATH] --output-dir [OUTPUT_PATH]
56
+ ```
57
+
58
+ **[Updated]** Here is also a [colab demo example](https://colab.research.google.com/drive/1JOwOPaChoc9GzyBi5FUEYTSaP2qxJl10?usp=sharing) for quick inference provided by [@levindabhi](https://github.com/levindabhi).
59
+
60
+ The `DATASET` argument has three options: 'lip', 'atr' and 'pascal'. Note that each pixel in the output images denotes the predicted label number, and the output images have the same size as the input ones. For better visualization, a palette is attached to the output images. We suggest reading the images with `PIL`.
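+ For example, the predicted label ids can be read back directly from a paletted output PNG (a minimal sketch; the path is a placeholder):
+
+ ```python
+ from PIL import Image
+ import numpy as np
+
+ parsing = Image.open('OUTPUT_PATH/example.png')  # mode 'P': each pixel stores a label id
+ labels = np.array(parsing)                       # (H, W) array of class indices
+ print(np.unique(labels))                         # label numbers present in the image
+ ```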
61
+
62
+ If you need not only the final parsing images but also the feature map representations, add the `--logits` flag to save the output feature maps. These feature maps are the logits before the softmax layer.
63
+
64
+ ## Dataset Preparation
65
+
66
+ Please download the [LIP](http://sysu-hcp.net/lip/) dataset following the below structure.
67
+
68
+ ```commandline
69
+ data/LIP
70
+ |--- train_images # 30462 training single person images
71
+ |--- val_images # 10000 validation single person images
72
+ |--- train_segmentations # 30462 training annotations
73
+ |--- val_segmentations # 10000 validation annotations
74
+ |--- train_id.txt # training image list
75
+ |--- val_id.txt # validation image list
76
+ ```
77
+
78
+ ## Training
79
+
80
+ ```
81
+ python train.py
82
+ ```
83
+ By default, the trained model will be saved in the `./log` directory. Please read the arguments for more details.
84
+
85
+ ## Evaluation
86
+ ```
87
+ python evaluate.py --model-restore [CHECKPOINT_PATH]
88
+ ```
89
+ CHECKPOINT_PATH should be the path of the trained model.
90
+
91
+ ## Extension on Multiple Human Parsing
92
+
93
+ Please read [MultipleHumanParsing.md](./mhp_extension/README.md) for more details.
94
+
95
+ ## Citation
96
+
97
+ Please cite our work if you find this repo useful in your research.
98
+
99
+ ```latex
100
+ @article{li2020self,
101
+ title={Self-Correction for Human Parsing},
102
+ author={Li, Peike and Xu, Yunqiu and Wei, Yunchao and Yang, Yi},
103
+ journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
104
+ year={2020},
105
+ doi={10.1109/TPAMI.2020.3048039}}
106
+ ```
107
+
108
+ ## Visualization
109
+
110
+ * Source Image.
111
+ ![demo](./demo/demo.jpg)
112
+ * LIP Parsing Result.
113
+ ![demo-lip](./demo/demo_lip.png)
114
+ * ATR Parsing Result.
115
+ ![demo-atr](./demo/demo_atr.png)
116
+ * Pascal-Person-Part Parsing Result.
117
+ ![demo-pascal](./demo/demo_pascal.png)
118
+ * Source Image.
119
+ ![demo](./mhp_extension/demo/demo.jpg)
120
+ * Instance Human Mask.
121
+ ![demo-lip](./mhp_extension/demo/demo_instance_human_mask.png)
122
+ * Global Human Parsing Result.
123
+ ![demo-lip](./mhp_extension/demo/demo_global_human_parsing.png)
124
+ * Multiple Human Parsing Result.
125
+ ![demo-lip](./mhp_extension/demo/demo_multiple_human_parsing.png)
126
+
127
+
128
+ ## Related
129
+ Our code adopts the [InplaceSyncBN](https://github.com/mapillary/inplace_abn) to save gpu memory cost.
130
+
131
+ There is also a [PaddlePaddle](https://github.com/PaddlePaddle/PaddleSeg/tree/develop/contrib/ACE2P) Implementation of this project.
checkpoints/AEMatter/AEM_RWA.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a475193549365ff3c892a85d2f4ca90ece2ac8dc4de4a39df250c76ca870d280
3
+ size 205399637
checkpoints/MVANet/garment.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7604ed46e06fbcff3b8f38c8934d253617171d02aecdd028f0f01086d9344893
3
+ size 380785263
checkpoints/MVANet/skin.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c71afcdd9cb1be73e43d84f5ffc2ae12b4964cc13c8460fc0adb6d52a0603cd4
3
+ size 380782803
checkpoints/Model_80.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffec20a382b0a1832786438475e8b912a03be727a0e3197e7ab039153fb3bc46
3
+ size 386621643
checkpoints/StableDiffusion/90c7c97574f8db765509b6a5d2e7b2551b430a10cac03e37d368654eac5e8169cd149644d188be4b5b2f1b9f29e66b64a02535f622f2bf284c319b076224cb2b ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:010be7341cd98a136da775330ba3eb4e87025c6cfd2f5455dc64daee2200ae98
3
+ size 7105348616
checkpoints/StableDiffusion/b970812225cfb95427c13e73b75eef66430e2a525876dddac494d70fe4ed0524cb197043e0ac3dc3026b32a45cd1d6d126ec2fe74a5bc3ef5df21836ca022b30 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1689257e6e1b2e61544b1a41fc114e7d798f68854b3f875cd52070bfe1fbc00
3
+ size 6938072258
checkpoints/StableDiffusion/hash ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ 90c7c97574f8db765509b6a5d2e7b2551b430a10cac03e37d368654eac5e8169cd149644d188be4b5b2f1b9f29e66b64a02535f622f2bf284c319b076224cb2b Juggernaut_X_RunDiffusion_Hyper.safetensors
2
+ b970812225cfb95427c13e73b75eef66430e2a525876dddac494d70fe4ed0524cb197043e0ac3dc3026b32a45cd1d6d126ec2fe74a5bc3ef5df21836ca022b30 juggernautXL_versionXInpaint.safetensors
checkpoints/atr.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9d7c91ce3b4e7133df56b599fc817b533e3439c5e8d282a59126d2fda339a2a
3
+ size 267445237
checkpoints/lip.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24fa3254ceeb74c8435458994a64b522fb439a3635b7b86ff470457e0413da00
3
+ size 267449349
checkpoints/pascal.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b03d343c39fb0696f75d45c44c67b2fc23f5d0bf0925a82c0465e415799fa85
3
+ size 267422621
datasets/__init__.py ADDED
File without changes
datasets/datasets.py ADDED
@@ -0,0 +1,201 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : datasets.py
8
+ @Time : 8/4/19 3:35 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import os
15
+ import numpy as np
16
+ import random
17
+ import torch
18
+ import cv2
19
+ from torch.utils import data
20
+ from utils.transforms import get_affine_transform
21
+
22
+
23
+ class LIPDataSet(data.Dataset):
24
+ def __init__(self, root, dataset, crop_size=[473, 473], scale_factor=0.25,
25
+ rotation_factor=30, ignore_label=255, transform=None):
26
+ self.root = root
27
+ self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
28
+ self.crop_size = np.asarray(crop_size)
29
+ self.ignore_label = ignore_label
30
+ self.scale_factor = scale_factor
31
+ self.rotation_factor = rotation_factor
32
+ self.flip_prob = 0.5
33
+ self.transform = transform
34
+ self.dataset = dataset
35
+
36
+ list_path = os.path.join(self.root, self.dataset + '_id.txt')
37
+ train_list = [i_id.strip() for i_id in open(list_path)]
38
+
39
+ self.train_list = train_list
40
+ self.number_samples = len(self.train_list)
41
+
42
+ def __len__(self):
43
+ return self.number_samples
44
+
45
+ def _box2cs(self, box):
46
+ x, y, w, h = box[:4]
47
+ return self._xywh2cs(x, y, w, h)
48
+
49
+ def _xywh2cs(self, x, y, w, h):
50
+ center = np.zeros((2), dtype=np.float32)
51
+ center[0] = x + w * 0.5
52
+ center[1] = y + h * 0.5
53
+ if w > self.aspect_ratio * h:
54
+ h = w * 1.0 / self.aspect_ratio
55
+ elif w < self.aspect_ratio * h:
56
+ w = h * self.aspect_ratio
57
+ scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
58
+ return center, scale
59
+
60
+ def __getitem__(self, index):
61
+ train_item = self.train_list[index]
62
+
63
+ im_path = os.path.join(self.root, self.dataset + '_images', train_item + '.jpg')
64
+ parsing_anno_path = os.path.join(self.root, self.dataset + '_segmentations', train_item + '.png')
65
+
66
+ im = cv2.imread(im_path, cv2.IMREAD_COLOR)
67
+ h, w, _ = im.shape
68
+ parsing_anno = np.zeros((h, w), dtype=np.int64)  # np.long was removed in recent NumPy versions
69
+
70
+ # Get person center and scale
71
+ person_center, s = self._box2cs([0, 0, w - 1, h - 1])
72
+ r = 0
73
+
74
+ if self.dataset != 'test':
75
+ # Get pose annotation
76
+ parsing_anno = cv2.imread(parsing_anno_path, cv2.IMREAD_GRAYSCALE)
77
+ if self.dataset == 'train' or self.dataset == 'trainval':
78
+ sf = self.scale_factor
79
+ rf = self.rotation_factor
80
+ s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
81
+ r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) if random.random() <= 0.6 else 0
82
+
83
+ if random.random() <= self.flip_prob:
84
+ im = im[:, ::-1, :]
85
+ parsing_anno = parsing_anno[:, ::-1]
86
+ person_center[0] = im.shape[1] - person_center[0] - 1
87
+ right_idx = [15, 17, 19]
88
+ left_idx = [14, 16, 18]
89
+ for i in range(0, 3):
90
+ right_pos = np.where(parsing_anno == right_idx[i])
91
+ left_pos = np.where(parsing_anno == left_idx[i])
92
+ parsing_anno[right_pos[0], right_pos[1]] = left_idx[i]
93
+ parsing_anno[left_pos[0], left_pos[1]] = right_idx[i]
94
+
95
+ trans = get_affine_transform(person_center, s, r, self.crop_size)
96
+ input = cv2.warpAffine(
97
+ im,
98
+ trans,
99
+ (int(self.crop_size[1]), int(self.crop_size[0])),
100
+ flags=cv2.INTER_LINEAR,
101
+ borderMode=cv2.BORDER_CONSTANT,
102
+ borderValue=(0, 0, 0))
103
+
104
+ if self.transform:
105
+ input = self.transform(input)
106
+
107
+ meta = {
108
+ 'name': train_item,
109
+ 'center': person_center,
110
+ 'height': h,
111
+ 'width': w,
112
+ 'scale': s,
113
+ 'rotation': r
114
+ }
115
+
116
+ if self.dataset == 'val' or self.dataset == 'test':
117
+ return input, meta
118
+ else:
119
+ label_parsing = cv2.warpAffine(
120
+ parsing_anno,
121
+ trans,
122
+ (int(self.crop_size[1]), int(self.crop_size[0])),
123
+ flags=cv2.INTER_NEAREST,
124
+ borderMode=cv2.BORDER_CONSTANT,
125
+ borderValue=(255))
126
+
127
+ label_parsing = torch.from_numpy(label_parsing)
128
+
129
+ return input, label_parsing, meta
130
+
131
+
132
+ class LIPDataValSet(data.Dataset):
133
+ def __init__(self, root, dataset='val', crop_size=[473, 473], transform=None, flip=False):
134
+ self.root = root
135
+ self.crop_size = crop_size
136
+ self.transform = transform
137
+ self.flip = flip
138
+ self.dataset = dataset
139
+ self.root = root
140
+ self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
141
+ self.crop_size = np.asarray(crop_size)
142
+
143
+ list_path = os.path.join(self.root, self.dataset + '_id.txt')
144
+ val_list = [i_id.strip() for i_id in open(list_path)]
145
+
146
+ self.val_list = val_list
147
+ self.number_samples = len(self.val_list)
148
+
149
+ def __len__(self):
150
+ return len(self.val_list)
151
+
152
+ def _box2cs(self, box):
153
+ x, y, w, h = box[:4]
154
+ return self._xywh2cs(x, y, w, h)
155
+
156
+ def _xywh2cs(self, x, y, w, h):
157
+ center = np.zeros((2), dtype=np.float32)
158
+ center[0] = x + w * 0.5
159
+ center[1] = y + h * 0.5
160
+ if w > self.aspect_ratio * h:
161
+ h = w * 1.0 / self.aspect_ratio
162
+ elif w < self.aspect_ratio * h:
163
+ w = h * self.aspect_ratio
164
+ scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
165
+
166
+ return center, scale
167
+
168
+ def __getitem__(self, index):
169
+ val_item = self.val_list[index]
170
+ # Load training image
171
+ im_path = os.path.join(self.root, self.dataset + '_images', val_item + '.jpg')
172
+ im = cv2.imread(im_path, cv2.IMREAD_COLOR)
173
+ h, w, _ = im.shape
174
+ # Get person center and scale
175
+ person_center, s = self._box2cs([0, 0, w - 1, h - 1])
176
+ r = 0
177
+ trans = get_affine_transform(person_center, s, r, self.crop_size)
178
+ input = cv2.warpAffine(
179
+ im,
180
+ trans,
181
+ (int(self.crop_size[1]), int(self.crop_size[0])),
182
+ flags=cv2.INTER_LINEAR,
183
+ borderMode=cv2.BORDER_CONSTANT,
184
+ borderValue=(0, 0, 0))
185
+ input = self.transform(input)
186
+ flip_input = input.flip(dims=[-1])
187
+ if self.flip:
188
+ batch_input_im = torch.stack([input, flip_input])
189
+ else:
190
+ batch_input_im = input
191
+
192
+ meta = {
193
+ 'name': val_item,
194
+ 'center': person_center,
195
+ 'height': h,
196
+ 'width': w,
197
+ 'scale': s,
198
+ 'rotation': r
199
+ }
200
+
201
+ return batch_input_im, meta
datasets/simple_extractor_dataset.py ADDED
@@ -0,0 +1,78 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : dataset.py
8
+ @Time : 8/30/19 9:12 PM
9
+ @Desc : Dataset Definition
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import os
15
+ import cv2
16
+ import numpy as np
17
+
18
+ from torch.utils import data
19
+ from utils.transforms import get_affine_transform
20
+
21
+
22
+ class SimpleFolderDataset(data.Dataset):
23
+ def __init__(self, root, input_size=[512, 512], transform=None):
24
+ self.root = root
25
+ self.input_size = input_size
26
+ self.transform = transform
27
+ self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
28
+ self.input_size = np.asarray(input_size)
29
+
30
+ self.file_list = os.listdir(self.root)
31
+
32
+ def __len__(self):
33
+ return len(self.file_list)
34
+
35
+ def _box2cs(self, box):
36
+ x, y, w, h = box[:4]
37
+ return self._xywh2cs(x, y, w, h)
38
+
39
+ def _xywh2cs(self, x, y, w, h):
40
+ center = np.zeros((2), dtype=np.float32)
41
+ center[0] = x + w * 0.5
42
+ center[1] = y + h * 0.5
43
+ if w > self.aspect_ratio * h:
44
+ h = w * 1.0 / self.aspect_ratio
45
+ elif w < self.aspect_ratio * h:
46
+ w = h * self.aspect_ratio
47
+ scale = np.array([w, h], dtype=np.float32)
48
+ return center, scale
49
+
50
+ def __getitem__(self, index):
51
+ img_name = self.file_list[index]
52
+ img_path = os.path.join(self.root, img_name)
53
+ img = cv2.imread(img_path, cv2.IMREAD_COLOR)
54
+ h, w, _ = img.shape
55
+
56
+ # Get person center and scale
57
+ person_center, s = self._box2cs([0, 0, w - 1, h - 1])
58
+ r = 0
59
+ trans = get_affine_transform(person_center, s, r, self.input_size)
60
+ input = cv2.warpAffine(
61
+ img,
62
+ trans,
63
+ (int(self.input_size[1]), int(self.input_size[0])),
64
+ flags=cv2.INTER_LINEAR,
65
+ borderMode=cv2.BORDER_CONSTANT,
66
+ borderValue=(0, 0, 0))
67
+
68
+ input = self.transform(input)
69
+ meta = {
70
+ 'name': img_name,
71
+ 'center': person_center,
72
+ 'height': h,
73
+ 'width': w,
74
+ 'scale': s,
75
+ 'rotation': r
76
+ }
77
+
78
+ return input, meta
datasets/target_generation.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+
5
+ def generate_edge_tensor(label, edge_width=3):
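+ # Marks a pixel as edge (1) wherever its label differs from a neighbouring pixel
+ # (four neighbour directions are checked), ignoring the 255 ignore-label, and then
+ # thickens the edge map by convolving with an edge_width x edge_width ones kernel.
+ # Note that the padding is hard-coded to 1, so the output matches the input size
+ # only for the default edge_width=3.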
6
+ label = label.type(torch.cuda.FloatTensor)
7
+ if len(label.shape) == 2:
8
+ label = label.unsqueeze(0)
9
+ n, h, w = label.shape
10
+ edge = torch.zeros(label.shape, dtype=torch.float).cuda()
11
+ # right
12
+ edge_right = edge[:, 1:h, :]
13
+ edge_right[(label[:, 1:h, :] != label[:, :h - 1, :]) & (label[:, 1:h, :] != 255)
14
+ & (label[:, :h - 1, :] != 255)] = 1
15
+
16
+ # up
17
+ edge_up = edge[:, :, :w - 1]
18
+ edge_up[(label[:, :, :w - 1] != label[:, :, 1:w])
19
+ & (label[:, :, :w - 1] != 255)
20
+ & (label[:, :, 1:w] != 255)] = 1
21
+
22
+ # upright
23
+ edge_upright = edge[:, :h - 1, :w - 1]
24
+ edge_upright[(label[:, :h - 1, :w - 1] != label[:, 1:h, 1:w])
25
+ & (label[:, :h - 1, :w - 1] != 255)
26
+ & (label[:, 1:h, 1:w] != 255)] = 1
27
+
28
+ # bottomright
29
+ edge_bottomright = edge[:, :h - 1, 1:w]
30
+ edge_bottomright[(label[:, :h - 1, 1:w] != label[:, 1:h, :w - 1])
31
+ & (label[:, :h - 1, 1:w] != 255)
32
+ & (label[:, 1:h, :w - 1] != 255)] = 1
33
+
34
+ kernel = torch.ones((1, 1, edge_width, edge_width), dtype=torch.float).cuda()
35
+ with torch.no_grad():
36
+ edge = edge.unsqueeze(1)
37
+ edge = F.conv2d(edge, kernel, stride=1, padding=1)
38
+ edge[edge!=0] = 1
39
+ edge = edge.squeeze()
40
+ return edge
demo/demo.jpg ADDED

Git LFS Details

  • SHA256: 6871c209cc202232323f309bbdec6ef9c2834aedaa3aef3f50293c4e783f0fec
  • Pointer size: 131 Bytes
  • Size of remote file: 310 kB
demo/demo_atr.png ADDED
demo/demo_lip.png ADDED
demo/demo_pascal.png ADDED
demo/lip-visualization.jpg ADDED

Git LFS Details

  • SHA256: d311b9ac4871d4e05a6b29953b13d6431afb269514571992267ef7038953bf1d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.56 MB
environment.yaml ADDED
@@ -0,0 +1,49 @@
1
+ name: schp
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=main
7
+ - blas=1.0=mkl
8
+ - ca-certificates=2020.12.8=h06a4308_0
9
+ - certifi=2020.12.5=py38h06a4308_0
10
+ - cudatoolkit=10.1.243=h6bb024c_0
11
+ - freetype=2.10.4=h5ab3b9f_0
12
+ - intel-openmp=2020.2=254
13
+ - jpeg=9b=h024ee3a_2
14
+ - lcms2=2.11=h396b838_0
15
+ - ld_impl_linux-64=2.33.1=h53a641e_7
16
+ - libedit=3.1.20191231=h14c3975_1
17
+ - libffi=3.3=he6710b0_2
18
+ - libgcc-ng=9.1.0=hdf63c60_0
19
+ - libpng=1.6.37=hbc83047_0
20
+ - libstdcxx-ng=9.1.0=hdf63c60_0
21
+ - libtiff=4.1.0=h2733197_1
22
+ - lz4-c=1.9.2=heb0550a_3
23
+ - mkl=2020.2=256
24
+ - mkl-service=2.3.0=py38he904b0f_0
25
+ - mkl_fft=1.2.0=py38h23d657b_0
26
+ - mkl_random=1.1.1=py38h0573a6f_0
27
+ - ncurses=6.2=he6710b0_1
28
+ - ninja=1.10.2=py38hff7bd54_0
29
+ - numpy=1.19.2=py38h54aff64_0
30
+ - numpy-base=1.19.2=py38hfa32c7d_0
31
+ - olefile=0.46=py_0
32
+ - openssl=1.1.1i=h27cfd23_0
33
+ - pillow=8.0.1=py38he98fc37_0
34
+ - pip=20.3.3=py38h06a4308_0
35
+ - python=3.8.5=h7579374_1
36
+ - readline=8.0=h7b6447c_0
37
+ - setuptools=51.0.0=py38h06a4308_2
38
+ - six=1.15.0=py38h06a4308_0
39
+ - sqlite=3.33.0=h62c20be_0
40
+ - tk=8.6.10=hbc83047_0
41
+ - tqdm=4.55.0=pyhd3eb1b0_0
42
+ - wheel=0.36.2=pyhd3eb1b0_0
43
+ - xz=5.2.5=h7b6447c_0
44
+ - zlib=1.2.11=h7b6447c_3
45
+ - zstd=1.4.5=h9ceee32_0
46
+ - pytorch=1.5.1=py3.8_cuda10.1.243_cudnn7.6.3_0
47
+ - torchvision=0.6.1=py38_cu101
48
+ prefix: /home/peike/opt/anaconda3/envs/schp
49
+
evaluate.py ADDED
@@ -0,0 +1,209 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : evaluate.py
8
+ @Time : 8/4/19 3:36 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import os
15
+ import argparse
16
+ import numpy as np
17
+ import torch
18
+
19
+ from torch.utils import data
20
+ from tqdm import tqdm
21
+ from PIL import Image as PILImage
22
+ import torchvision.transforms as transforms
23
+ import torch.backends.cudnn as cudnn
24
+
25
+ import networks
26
+ from datasets.datasets import LIPDataValSet
27
+ from utils.miou import compute_mean_ioU
28
+ from utils.transforms import BGR2RGB_transform
29
+ from utils.transforms import transform_parsing
30
+
31
+
32
+ def get_arguments():
33
+ """Parse all the arguments provided from the CLI.
34
+
35
+ Returns:
36
+ A list of parsed arguments.
37
+ """
38
+ parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")
39
+
40
+ # Network Structure
41
+ parser.add_argument("--arch", type=str, default='resnet101')
42
+ # Data Preference
43
+ parser.add_argument("--data-dir", type=str, default='./data/LIP')
44
+ parser.add_argument("--batch-size", type=int, default=1)
45
+ parser.add_argument("--input-size", type=str, default='473,473')
46
+ parser.add_argument("--num-classes", type=int, default=20)
47
+ parser.add_argument("--ignore-label", type=int, default=255)
48
+ parser.add_argument("--random-mirror", action="store_true")
49
+ parser.add_argument("--random-scale", action="store_true")
50
+ # Evaluation Preference
51
+ parser.add_argument("--log-dir", type=str, default='./log')
52
+ parser.add_argument("--model-restore", type=str, default='./log/checkpoint.pth.tar')
53
+ parser.add_argument("--gpu", type=str, default='0', help="choose gpu device.")
54
+ parser.add_argument("--save-results", action="store_true", help="whether to save the results.")
55
+ parser.add_argument("--flip", action="store_true", help="random flip during the test.")
56
+ parser.add_argument("--multi-scales", type=str, default='1', help="multiple scales during the test")
57
+ return parser.parse_args()
58
+
59
+
60
+ def get_palette(num_cls):
61
+ """ Returns the color map for visualizing the segmentation mask.
62
+ Args:
63
+ num_cls: Number of classes
64
+ Returns:
65
+ The color map
66
+ """
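+ # The bit-interleaving below reproduces the VOC-style colour map,
+ # e.g. label 0 -> (0, 0, 0), 1 -> (128, 0, 0), 2 -> (0, 128, 0), 3 -> (128, 128, 0).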
67
+ n = num_cls
68
+ palette = [0] * (n * 3)
69
+ for j in range(0, n):
70
+ lab = j
71
+ palette[j * 3 + 0] = 0
72
+ palette[j * 3 + 1] = 0
73
+ palette[j * 3 + 2] = 0
74
+ i = 0
75
+ while lab:
76
+ palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
77
+ palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
78
+ palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
79
+ i += 1
80
+ lab >>= 3
81
+ return palette
82
+
83
+
84
+ def multi_scale_testing(model, batch_input_im, crop_size=[473, 473], flip=True, multi_scales=[1]):
85
+ flipped_idx = (15, 14, 17, 16, 19, 18)
86
+ if len(batch_input_im.shape) > 4:
87
+ batch_input_im = batch_input_im.squeeze()
88
+ if len(batch_input_im.shape) == 3:
89
+ batch_input_im = batch_input_im.unsqueeze(0)
90
+
91
+ interp = torch.nn.Upsample(size=crop_size, mode='bilinear', align_corners=True)
92
+ ms_outputs = []
93
+ for s in multi_scales:
94
+ interp_im = torch.nn.Upsample(scale_factor=s, mode='bilinear', align_corners=True)
95
+ scaled_im = interp_im(batch_input_im)
96
+ parsing_output = model(scaled_im)
97
+ parsing_output = parsing_output[0][-1]
98
+ output = parsing_output[0]
99
+ if flip:
100
+ flipped_output = parsing_output[1]
101
+ flipped_output[14:20, :, :] = flipped_output[flipped_idx, :, :]
102
+ output += flipped_output.flip(dims=[-1])
103
+ output *= 0.5
104
+ output = interp(output.unsqueeze(0))
105
+ ms_outputs.append(output[0])
106
+ ms_fused_parsing_output = torch.stack(ms_outputs)
107
+ ms_fused_parsing_output = ms_fused_parsing_output.mean(0)
108
+ ms_fused_parsing_output = ms_fused_parsing_output.permute(1, 2, 0) # HWC
109
+ parsing = torch.argmax(ms_fused_parsing_output, dim=2)
110
+ parsing = parsing.data.cpu().numpy()
111
+ ms_fused_parsing_output = ms_fused_parsing_output.data.cpu().numpy()
112
+ return parsing, ms_fused_parsing_output
113
+
114
+
115
+ def main():
116
+ """Create the model and start the evaluation process."""
117
+ args = get_arguments()
118
+ multi_scales = [float(i) for i in args.multi_scales.split(',')]
119
+ gpus = [int(i) for i in args.gpu.split(',')]
120
+ assert len(gpus) == 1
121
+ if not args.gpu == 'None':
122
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
123
+
124
+ cudnn.benchmark = True
125
+ cudnn.enabled = True
126
+
127
+ h, w = map(int, args.input_size.split(','))
128
+ input_size = [h, w]
129
+
130
+ model = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=None)
131
+
132
+ IMAGE_MEAN = model.mean
133
+ IMAGE_STD = model.std
134
+ INPUT_SPACE = model.input_space
135
+ print('image mean: {}'.format(IMAGE_MEAN))
136
+ print('image std: {}'.format(IMAGE_STD))
137
+ print('input space:{}'.format(INPUT_SPACE))
138
+ if INPUT_SPACE == 'BGR':
139
+ print('BGR Transformation')
140
+ transform = transforms.Compose([
141
+ transforms.ToTensor(),
142
+ transforms.Normalize(mean=IMAGE_MEAN,
143
+ std=IMAGE_STD),
144
+
145
+ ])
146
+ if INPUT_SPACE == 'RGB':
147
+ print('RGB Transformation')
148
+ transform = transforms.Compose([
149
+ transforms.ToTensor(),
150
+ BGR2RGB_transform(),
151
+ transforms.Normalize(mean=IMAGE_MEAN,
152
+ std=IMAGE_STD),
153
+ ])
154
+
155
+ # Data loader
156
+ lip_test_dataset = LIPDataValSet(args.data_dir, 'val', crop_size=input_size, transform=transform, flip=args.flip)
157
+ num_samples = len(lip_test_dataset)
158
+ print('Total number of testing samples: {}'.format(num_samples))
159
+ testloader = data.DataLoader(lip_test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)
160
+
161
+ # Load model weight
162
+ state_dict = torch.load(args.model_restore)['state_dict']
163
+ from collections import OrderedDict
164
+ new_state_dict = OrderedDict()
165
+ for k, v in state_dict.items():
166
+ name = k[7:] # remove `module.`
167
+ new_state_dict[name] = v
168
+ model.load_state_dict(new_state_dict)
169
+ model.cuda()
170
+ model.eval()
171
+
172
+ sp_results_dir = os.path.join(args.log_dir, 'sp_results')
173
+ if not os.path.exists(sp_results_dir):
174
+ os.makedirs(sp_results_dir)
175
+
176
+ palette = get_palette(20)
177
+ parsing_preds = []
178
+ scales = np.zeros((num_samples, 2), dtype=np.float32)
179
+ centers = np.zeros((num_samples, 2), dtype=np.int32)
180
+ with torch.no_grad():
181
+ for idx, batch in enumerate(tqdm(testloader)):
182
+ image, meta = batch
183
+ if (len(image.shape) > 4):
184
+ image = image.squeeze()
185
+ im_name = meta['name'][0]
186
+ c = meta['center'].numpy()[0]
187
+ s = meta['scale'].numpy()[0]
188
+ w = meta['width'].numpy()[0]
189
+ h = meta['height'].numpy()[0]
190
+ scales[idx, :] = s
191
+ centers[idx, :] = c
192
+ parsing, logits = multi_scale_testing(model, image.cuda(), crop_size=input_size, flip=args.flip,
193
+ multi_scales=multi_scales)
194
+ if args.save_results:
195
+ parsing_result = transform_parsing(parsing, c, s, w, h, input_size)
196
+ parsing_result_path = os.path.join(sp_results_dir, im_name + '.png')
197
+ output_im = PILImage.fromarray(np.asarray(parsing_result, dtype=np.uint8))
198
+ output_im.putpalette(palette)
199
+ output_im.save(parsing_result_path)
200
+
201
+ parsing_preds.append(parsing)
202
+ assert len(parsing_preds) == num_samples
203
+ mIoU = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, input_size)
204
+ print(mIoU)
205
+ return
206
+
207
+
208
+ if __name__ == '__main__':
209
+ main()
main.org ADDED
@@ -0,0 +1,663 @@
1
+ * COMMENT WORK SPACE
2
+ cd $HOME/HUGGINGFACE/aravindhv10/Self-Correction-Human-Parsing
3
+
4
+ ** ELISP
5
+ #+begin_src elisp
6
+ (save-buffer)
7
+ (save-some-buffers)
8
+ (org-babel-tangle)
9
+ (shell-command "./work.sh" "output_log_work")
10
+ #+end_src
11
+
12
+ #+RESULTS:
13
+ : 0
14
+
15
+ ** ELISP
16
+ #+begin_src elisp
17
+ (shell-command "git status" "output_log_git_status")
18
+ #+end_src
19
+
20
+ #+RESULTS:
21
+ : 0
22
+
23
+ ** ELISP
24
+ #+begin_src elisp
25
+ (shell-command "./commit_and_push.sh" "output_log_commit_and_push")
26
+ #+end_src
27
+
28
+ #+RESULTS:
29
+ : 0
30
+
31
+ * Commit and push
32
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./commit_and_push.sh
33
+ git commit -m 'Routine updates'
34
+ git push
35
+ #+end_src
36
+
37
+ * Main script to do everything
38
+ #+begin_src sh :shebang #!/bin/sh :results output :tangle ./work.sh
39
+ do_ignore(){
40
+ 'sed' 's@^@/@g' './rm.txt';
41
+ 'cat' './gitignore.txt';
42
+ }
43
+
44
+ do_add(){
45
+ 'sed' 's@^@("git" "lfs" "track" "./@g;s@$@");@g' './git_lfs_track.txt' ;
46
+ 'cat' './git_add.txt' './git_lfs_track.txt' | \
47
+ 'sed' 's@^@("git" "add" "./@g;s@$@");@g' ;
48
+ }
49
+
50
+ do_rm(){
51
+ 'sed' 's@^@("rm" "-vf" "--" "./@g ; s@$@");@g' './rm.txt' ;
52
+ }
53
+
54
+ all_commands(){
55
+ do_add
56
+ do_rm
57
+ }
58
+
59
+ do_all(){
60
+ do_ignore > './.gitignore'
61
+ all_commands | sh
62
+ }
63
+
64
+ do_all
65
+ #+end_src
66
+
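+ For example, the entry checkpoints/atr.pth in ./git_lfs_track.txt is expanded by do_add into the commands ("git" "lfs" "track" "./checkpoints/atr.pth"); and ("git" "add" "./checkpoints/atr.pth");, which all_commands | sh then executes.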
67
+ * List of large files
68
+ #+begin_src conf :tangle ./git_lfs_track.txt
69
+ checkpoints/AEMatter/AEM_RWA.ckpt
70
+ checkpoints/atr.pth
71
+ checkpoints/lip.pth
72
+ checkpoints/Model_80.pth
73
+ checkpoints/MVANet/garment.pth
74
+ checkpoints/MVANet/skin.pth
75
+ checkpoints/pascal.pth
76
+ checkpoints/StableDiffusion/90c7c97574f8db765509b6a5d2e7b2551b430a10cac03e37d368654eac5e8169cd149644d188be4b5b2f1b9f29e66b64a02535f622f2bf284c319b076224cb2b
77
+ checkpoints/StableDiffusion/b970812225cfb95427c13e73b75eef66430e2a525876dddac494d70fe4ed0524cb197043e0ac3dc3026b32a45cd1d6d126ec2fe74a5bc3ef5df21836ca022b30
78
+ demo/demo_atr.png
79
+ demo/demo.jpg
80
+ demo/demo_lip.png
81
+ demo/demo_pascal.png
82
+ demo/lip-visualization.jpg
83
+ #+end_src
84
+
85
+ * List of source files to add
86
+ #+begin_src conf :tangle ./git_add.txt
87
+ checkpoints/StableDiffusion/hash
88
+ ComfyUI_AEMatter/AEMatter.py
89
+ ComfyUI_AEMatter/AEMatter.run.sh
90
+ ComfyUI_AEMatter/__init__.py
91
+ ComfyUI_AEMatter/README.org
92
+ ComfyUI_MVANet/download.sh
93
+ ComfyUI_MVANet/__init__.py
94
+ ComfyUI_MVANet/MVANet_inference.py
95
+ ComfyUI_MVANet/MVANet_inference.run.sh
96
+ ComfyUI_MVANet/README.org
97
+ ComfyUI_MVANet/requirements.txt
98
+ datasets/datasets.py
99
+ datasets/__init__.py
100
+ datasets/simple_extractor_dataset.py
101
+ datasets/target_generation.py
102
+ environment.yaml
103
+ evaluate.py
104
+ .gitattributes
105
+ .gitignore
106
+ LICENSE
107
+ main.org
108
+ mhp_extension/coco_style_annotation_creator/human_to_coco.py
109
+ mhp_extension/coco_style_annotation_creator/pycococreatortools.py
110
+ mhp_extension/coco_style_annotation_creator/test_human2coco_format.py
111
+ mhp_extension/demo.ipynb
112
+ mhp_extension/detectron2/.circleci/config.yml
113
+ mhp_extension/detectron2/.clang-format
114
+ mhp_extension/detectron2/configs/Base-RCNN-C4.yaml
115
+ mhp_extension/detectron2/configs/Base-RCNN-DilatedC5.yaml
116
+ mhp_extension/detectron2/configs/Base-RCNN-FPN.yaml
117
+ mhp_extension/detectron2/configs/Base-RetinaNet.yaml
118
+ mhp_extension/detectron2/configs/Cityscapes/mask_rcnn_R_50_FPN.yaml
119
+ mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_101_C4_3x.yaml
120
+ mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_101_DC5_3x.yaml
121
+ mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml
122
+ mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_50_C4_1x.yaml
123
+ mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_50_C4_3x.yaml
124
+ mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_50_DC5_1x.yaml
125
+ mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_50_DC5_3x.yaml
126
+ mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml
127
+ mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml
128
+ mhp_extension/detectron2/configs/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml
129
+ mhp_extension/detectron2/configs/COCO-Detection/fast_rcnn_R_50_FPN_1x.yaml
130
+ mhp_extension/detectron2/configs/COCO-Detection/retinanet_R_101_FPN_3x.yaml
131
+ mhp_extension/detectron2/configs/COCO-Detection/retinanet_R_50_FPN_1x.yaml
132
+ mhp_extension/detectron2/configs/COCO-Detection/retinanet_R_50_FPN_3x.yaml
133
+ mhp_extension/detectron2/configs/COCO-Detection/rpn_R_50_C4_1x.yaml
134
+ mhp_extension/detectron2/configs/COCO-Detection/rpn_R_50_FPN_1x.yaml
135
+ mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x.yaml
136
+ mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_DC5_3x.yaml
137
+ mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml
138
+ mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_1x.yaml
139
+ mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x.yaml
140
+ mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_1x.yaml
141
+ mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x.yaml
142
+ mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
143
+ mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml
144
+ mhp_extension/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml
145
+ mhp_extension/detectron2/configs/COCO-Keypoints/Base-Keypoint-RCNN-FPN.yaml
146
+ mhp_extension/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x.yaml
147
+ mhp_extension/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_1x.yaml
148
+ mhp_extension/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml
149
+ mhp_extension/detectron2/configs/COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x.yaml
150
+ mhp_extension/detectron2/configs/COCO-PanopticSegmentation/Base-Panoptic-FPN.yaml
151
+ mhp_extension/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml
152
+ mhp_extension/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_1x.yaml
153
+ mhp_extension/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml
154
+ mhp_extension/detectron2/configs/Detectron1-Comparisons/faster_rcnn_R_50_FPN_noaug_1x.yaml
155
+ mhp_extension/detectron2/configs/Detectron1-Comparisons/keypoint_rcnn_R_50_FPN_1x.yaml
156
+ mhp_extension/detectron2/configs/Detectron1-Comparisons/mask_rcnn_R_50_FPN_noaug_1x.yaml
157
+ mhp_extension/detectron2/configs/Detectron1-Comparisons/README.md
158
+ mhp_extension/detectron2/configs/LVIS-InstanceSegmentation/mask_rcnn_R_101_FPN_1x.yaml
159
+ mhp_extension/detectron2/configs/LVIS-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
160
+ mhp_extension/detectron2/configs/LVIS-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x.yaml
161
+ mhp_extension/detectron2/configs/Misc/cascade_mask_rcnn_R_50_FPN_1x.yaml
162
+ mhp_extension/detectron2/configs/Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml
163
+ mhp_extension/detectron2/configs/Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv_parsing.yaml
164
+ mhp_extension/detectron2/configs/Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml
165
+ mhp_extension/detectron2/configs/Misc/demo.yaml
166
+ mhp_extension/detectron2/configs/Misc/mask_rcnn_R_50_FPN_1x_cls_agnostic.yaml
167
+ mhp_extension/detectron2/configs/Misc/mask_rcnn_R_50_FPN_1x_dconv_c3-c5.yaml
168
+ mhp_extension/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_dconv_c3-c5.yaml
169
+ mhp_extension/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_gn.yaml
170
+ mhp_extension/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_syncbn.yaml
171
+ mhp_extension/detectron2/configs/Misc/panoptic_fpn_R_101_dconv_cascade_gn_3x.yaml
172
+ mhp_extension/detectron2/configs/Misc/parsing_finetune_cihp.yaml
173
+ mhp_extension/detectron2/configs/Misc/parsing_inference.yaml
174
+ mhp_extension/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_3x_gn.yaml
175
+ mhp_extension/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_9x_gn.yaml
176
+ mhp_extension/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_9x_syncbn.yaml
177
+ mhp_extension/detectron2/configs/Misc/semantic_R_50_FPN_1x.yaml
178
+ mhp_extension/detectron2/configs/my_Base-RCNN-FPN.yaml
179
+ mhp_extension/detectron2/configs/PascalVOC-Detection/faster_rcnn_R_50_C4.yaml
180
+ mhp_extension/detectron2/configs/PascalVOC-Detection/faster_rcnn_R_50_FPN.yaml
181
+ mhp_extension/detectron2/configs/quick_schedules/cascade_mask_rcnn_R_50_FPN_inference_acc_test.yaml
182
+ mhp_extension/detectron2/configs/quick_schedules/cascade_mask_rcnn_R_50_FPN_instant_test.yaml
183
+ mhp_extension/detectron2/configs/quick_schedules/fast_rcnn_R_50_FPN_inference_acc_test.yaml
184
+ mhp_extension/detectron2/configs/quick_schedules/fast_rcnn_R_50_FPN_instant_test.yaml
185
+ mhp_extension/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_inference_acc_test.yaml
186
+ mhp_extension/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_instant_test.yaml
187
+ mhp_extension/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_normalized_training_acc_test.yaml
188
+ mhp_extension/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_training_acc_test.yaml
189
+ mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_GCV_instant_test.yaml
190
+ mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_inference_acc_test.yaml
191
+ mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_instant_test.yaml
192
+ mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_training_acc_test.yaml
193
+ mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_DC5_inference_acc_test.yaml
194
+ mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml
195
+ mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_instant_test.yaml
196
+ mhp_extension/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_training_acc_test.yaml
197
+ mhp_extension/detectron2/configs/quick_schedules/panoptic_fpn_R_50_inference_acc_test.yaml
198
+ mhp_extension/detectron2/configs/quick_schedules/panoptic_fpn_R_50_instant_test.yaml
199
+ mhp_extension/detectron2/configs/quick_schedules/panoptic_fpn_R_50_training_acc_test.yaml
200
+ mhp_extension/detectron2/configs/quick_schedules/README.md
201
+ mhp_extension/detectron2/configs/quick_schedules/retinanet_R_50_FPN_inference_acc_test.yaml
202
+ mhp_extension/detectron2/configs/quick_schedules/retinanet_R_50_FPN_instant_test.yaml
203
+ mhp_extension/detectron2/configs/quick_schedules/rpn_R_50_FPN_inference_acc_test.yaml
204
+ mhp_extension/detectron2/configs/quick_schedules/rpn_R_50_FPN_instant_test.yaml
205
+ mhp_extension/detectron2/configs/quick_schedules/semantic_R_50_FPN_inference_acc_test.yaml
206
+ mhp_extension/detectron2/configs/quick_schedules/semantic_R_50_FPN_instant_test.yaml
207
+ mhp_extension/detectron2/configs/quick_schedules/semantic_R_50_FPN_training_acc_test.yaml
208
+ mhp_extension/detectron2/demo/demo.py
209
+ mhp_extension/detectron2/demo/predictor.py
210
+ mhp_extension/detectron2/demo/README.md
211
+ mhp_extension/detectron2/detectron2/checkpoint/c2_model_loading.py
212
+ mhp_extension/detectron2/detectron2/checkpoint/catalog.py
213
+ mhp_extension/detectron2/detectron2/checkpoint/detection_checkpoint.py
214
+ mhp_extension/detectron2/detectron2/checkpoint/__init__.py
215
+ mhp_extension/detectron2/detectron2/config/compat.py
216
+ mhp_extension/detectron2/detectron2/config/config.py
217
+ mhp_extension/detectron2/detectron2/config/defaults.py
218
+ mhp_extension/detectron2/detectron2/config/__init__.py
219
+ mhp_extension/detectron2/detectron2/data/build.py
220
+ mhp_extension/detectron2/detectron2/data/catalog.py
221
+ mhp_extension/detectron2/detectron2/data/common.py
222
+ mhp_extension/detectron2/detectron2/data/dataset_mapper.py
223
+ mhp_extension/detectron2/detectron2/data/datasets/builtin_meta.py
224
+ mhp_extension/detectron2/detectron2/data/datasets/builtin.py
225
+ mhp_extension/detectron2/detectron2/data/datasets/cityscapes.py
226
+ mhp_extension/detectron2/detectron2/data/datasets/coco.py
227
+ mhp_extension/detectron2/detectron2/data/datasets/__init__.py
228
+ mhp_extension/detectron2/detectron2/data/datasets/lvis.py
229
+ mhp_extension/detectron2/detectron2/data/datasets/lvis_v0_5_categories.py
230
+ mhp_extension/detectron2/detectron2/data/datasets/pascal_voc.py
231
+ mhp_extension/detectron2/detectron2/data/datasets/README.md
232
+ mhp_extension/detectron2/detectron2/data/datasets/register_coco.py
233
+ mhp_extension/detectron2/detectron2/data/detection_utils.py
234
+ mhp_extension/detectron2/detectron2/data/__init__.py
235
+ mhp_extension/detectron2/detectron2/data/samplers/distributed_sampler.py
236
+ mhp_extension/detectron2/detectron2/data/samplers/grouped_batch_sampler.py
237
+ mhp_extension/detectron2/detectron2/data/samplers/__init__.py
238
+ mhp_extension/detectron2/detectron2/data/transforms/__init__.py
239
+ mhp_extension/detectron2/detectron2/data/transforms/transform_gen.py
240
+ mhp_extension/detectron2/detectron2/data/transforms/transform.py
241
+ mhp_extension/detectron2/detectron2/engine/defaults.py
242
+ mhp_extension/detectron2/detectron2/engine/hooks.py
243
+ mhp_extension/detectron2/detectron2/engine/__init__.py
244
+ mhp_extension/detectron2/detectron2/engine/launch.py
245
+ mhp_extension/detectron2/detectron2/engine/train_loop.py
246
+ mhp_extension/detectron2/detectron2/evaluation/cityscapes_evaluation.py
247
+ mhp_extension/detectron2/detectron2/evaluation/coco_evaluation.py
248
+ mhp_extension/detectron2/detectron2/evaluation/evaluator.py
249
+ mhp_extension/detectron2/detectron2/evaluation/__init__.py
250
+ mhp_extension/detectron2/detectron2/evaluation/lvis_evaluation.py
251
+ mhp_extension/detectron2/detectron2/evaluation/panoptic_evaluation.py
252
+ mhp_extension/detectron2/detectron2/evaluation/pascal_voc_evaluation.py
253
+ mhp_extension/detectron2/detectron2/evaluation/rotated_coco_evaluation.py
254
+ mhp_extension/detectron2/detectron2/evaluation/sem_seg_evaluation.py
255
+ mhp_extension/detectron2/detectron2/evaluation/testing.py
256
+ mhp_extension/detectron2/detectron2/export/api.py
257
+ mhp_extension/detectron2/detectron2/export/c10.py
258
+ mhp_extension/detectron2/detectron2/export/caffe2_export.py
259
+ mhp_extension/detectron2/detectron2/export/caffe2_inference.py
260
+ mhp_extension/detectron2/detectron2/export/caffe2_modeling.py
261
+ mhp_extension/detectron2/detectron2/export/__init__.py
262
+ mhp_extension/detectron2/detectron2/export/patcher.py
263
+ mhp_extension/detectron2/detectron2/export/README.md
264
+ mhp_extension/detectron2/detectron2/export/shared.py
265
+ mhp_extension/detectron2/detectron2/__init__.py
266
+ mhp_extension/detectron2/detectron2/layers/batch_norm.py
267
+ mhp_extension/detectron2/detectron2/layers/blocks.py
268
+ mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp
269
+ mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
270
+ mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
271
+ mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
272
+ mhp_extension/detectron2/detectron2/layers/csrc/cuda_version.cu
273
+ mhp_extension/detectron2/detectron2/layers/csrc/deformable/deform_conv_cuda.cu
274
+ mhp_extension/detectron2/detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu
275
+ mhp_extension/detectron2/detectron2/layers/csrc/deformable/deform_conv.h
276
+ mhp_extension/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp
277
+ mhp_extension/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu
278
+ mhp_extension/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated.h
279
+ mhp_extension/detectron2/detectron2/layers/csrc/README.md
280
+ mhp_extension/detectron2/detectron2/layers/csrc/ROIAlign/ROIAlign_cpu.cpp
281
+ mhp_extension/detectron2/detectron2/layers/csrc/ROIAlign/ROIAlign_cuda.cu
282
+ mhp_extension/detectron2/detectron2/layers/csrc/ROIAlign/ROIAlign.h
283
+ mhp_extension/detectron2/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp
284
+ mhp_extension/detectron2/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu
285
+ mhp_extension/detectron2/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h
286
+ mhp_extension/detectron2/detectron2/layers/csrc/vision.cpp
287
+ mhp_extension/detectron2/detectron2/layers/deform_conv.py
288
+ mhp_extension/detectron2/detectron2/layers/__init__.py
289
+ mhp_extension/detectron2/detectron2/layers/mask_ops.py
290
+ mhp_extension/detectron2/detectron2/layers/nms.py
291
+ mhp_extension/detectron2/detectron2/layers/roi_align.py
292
+ mhp_extension/detectron2/detectron2/layers/roi_align_rotated.py
293
+ mhp_extension/detectron2/detectron2/layers/rotated_boxes.py
294
+ mhp_extension/detectron2/detectron2/layers/shape_spec.py
295
+ mhp_extension/detectron2/detectron2/layers/wrappers.py
296
+ mhp_extension/detectron2/detectron2/modeling/anchor_generator.py
297
+ mhp_extension/detectron2/detectron2/modeling/backbone/backbone.py
298
+ mhp_extension/detectron2/detectron2/modeling/backbone/build.py
299
+ mhp_extension/detectron2/detectron2/modeling/backbone/fpn.py
300
+ mhp_extension/detectron2/detectron2/modeling/backbone/__init__.py
301
+ mhp_extension/detectron2/detectron2/modeling/backbone/resnet.py
302
+ mhp_extension/detectron2/detectron2/modeling/box_regression.py
303
+ mhp_extension/detectron2/detectron2/modeling/__init__.py
304
+ mhp_extension/detectron2/detectron2/modeling/matcher.py
305
+ mhp_extension/detectron2/detectron2/modeling/meta_arch/build.py
306
+ mhp_extension/detectron2/detectron2/modeling/meta_arch/__init__.py
307
+ mhp_extension/detectron2/detectron2/modeling/meta_arch/panoptic_fpn.py
308
+ mhp_extension/detectron2/detectron2/modeling/meta_arch/rcnn.py
309
+ mhp_extension/detectron2/detectron2/modeling/meta_arch/retinanet.py
310
+ mhp_extension/detectron2/detectron2/modeling/meta_arch/semantic_seg.py
311
+ mhp_extension/detectron2/detectron2/modeling/poolers.py
312
+ mhp_extension/detectron2/detectron2/modeling/postprocessing.py
313
+ mhp_extension/detectron2/detectron2/modeling/proposal_generator/build.py
314
+ mhp_extension/detectron2/detectron2/modeling/proposal_generator/__init__.py
315
+ mhp_extension/detectron2/detectron2/modeling/proposal_generator/proposal_utils.py
316
+ mhp_extension/detectron2/detectron2/modeling/proposal_generator/rpn_outputs.py
317
+ mhp_extension/detectron2/detectron2/modeling/proposal_generator/rpn.py
318
+ mhp_extension/detectron2/detectron2/modeling/proposal_generator/rrpn.py
319
+ mhp_extension/detectron2/detectron2/modeling/roi_heads/box_head.py
320
+ mhp_extension/detectron2/detectron2/modeling/roi_heads/cascade_rcnn.py
321
+ mhp_extension/detectron2/detectron2/modeling/roi_heads/fast_rcnn.py
322
+ mhp_extension/detectron2/detectron2/modeling/roi_heads/__init__.py
323
+ mhp_extension/detectron2/detectron2/modeling/roi_heads/keypoint_head.py
324
+ mhp_extension/detectron2/detectron2/modeling/roi_heads/mask_head.py
325
+ mhp_extension/detectron2/detectron2/modeling/roi_heads/roi_heads.py
326
+ mhp_extension/detectron2/detectron2/modeling/roi_heads/rotated_fast_rcnn.py
327
+ mhp_extension/detectron2/detectron2/modeling/sampling.py
328
+ mhp_extension/detectron2/detectron2/modeling/test_time_augmentation.py
329
+ mhp_extension/detectron2/detectron2/model_zoo/__init__.py
330
+ mhp_extension/detectron2/detectron2/model_zoo/model_zoo.py
331
+ mhp_extension/detectron2/detectron2/solver/build.py
332
+ mhp_extension/detectron2/detectron2/solver/__init__.py
333
+ mhp_extension/detectron2/detectron2/solver/lr_scheduler.py
334
+ mhp_extension/detectron2/detectron2/structures/boxes.py
335
+ mhp_extension/detectron2/detectron2/structures/image_list.py
336
+ mhp_extension/detectron2/detectron2/structures/__init__.py
337
+ mhp_extension/detectron2/detectron2/structures/instances.py
338
+ mhp_extension/detectron2/detectron2/structures/keypoints.py
339
+ mhp_extension/detectron2/detectron2/structures/masks.py
340
+ mhp_extension/detectron2/detectron2/structures/rotated_boxes.py
341
+ mhp_extension/detectron2/detectron2/utils/analysis.py
342
+ mhp_extension/detectron2/detectron2/utils/collect_env.py
343
+ mhp_extension/detectron2/detectron2/utils/colormap.py
344
+ mhp_extension/detectron2/detectron2/utils/comm.py
345
+ mhp_extension/detectron2/detectron2/utils/env.py
346
+ mhp_extension/detectron2/detectron2/utils/events.py
347
+ mhp_extension/detectron2/detectron2/utils/__init__.py
348
+ mhp_extension/detectron2/detectron2/utils/logger.py
349
+ mhp_extension/detectron2/detectron2/utils/memory.py
350
+ mhp_extension/detectron2/detectron2/utils/README.md
351
+ mhp_extension/detectron2/detectron2/utils/registry.py
352
+ mhp_extension/detectron2/detectron2/utils/serialize.py
353
+ mhp_extension/detectron2/detectron2/utils/video_visualizer.py
354
+ mhp_extension/detectron2/detectron2/utils/visualizer.py
355
+ mhp_extension/detectron2/dev/linter.sh
356
+ mhp_extension/detectron2/dev/packaging/build_all_wheels.sh
357
+ mhp_extension/detectron2/dev/packaging/build_wheel.sh
358
+ mhp_extension/detectron2/dev/packaging/gen_wheel_index.sh
359
+ mhp_extension/detectron2/dev/packaging/pkg_helpers.bash
360
+ mhp_extension/detectron2/dev/packaging/README.md
361
+ mhp_extension/detectron2/dev/parse_results.sh
362
+ mhp_extension/detectron2/dev/README.md
363
+ mhp_extension/detectron2/dev/run_inference_tests.sh
364
+ mhp_extension/detectron2/dev/run_instant_tests.sh
365
+ mhp_extension/detectron2/docker/docker-compose.yml
366
+ mhp_extension/detectron2/docker/Dockerfile
367
+ mhp_extension/detectron2/docker/Dockerfile-circleci
368
+ mhp_extension/detectron2/docker/README.md
369
+ mhp_extension/detectron2/docs/conf.py
370
+ mhp_extension/detectron2/docs/.gitignore
371
+ mhp_extension/detectron2/docs/index.rst
372
+ mhp_extension/detectron2/docs/Makefile
373
+ mhp_extension/detectron2/docs/modules/checkpoint.rst
374
+ mhp_extension/detectron2/docs/modules/config.rst
375
+ mhp_extension/detectron2/docs/modules/data.rst
376
+ mhp_extension/detectron2/docs/modules/engine.rst
377
+ mhp_extension/detectron2/docs/modules/evaluation.rst
378
+ mhp_extension/detectron2/docs/modules/export.rst
379
+ mhp_extension/detectron2/docs/modules/index.rst
380
+ mhp_extension/detectron2/docs/modules/layers.rst
381
+ mhp_extension/detectron2/docs/modules/modeling.rst
382
+ mhp_extension/detectron2/docs/modules/model_zoo.rst
383
+ mhp_extension/detectron2/docs/modules/solver.rst
384
+ mhp_extension/detectron2/docs/modules/structures.rst
385
+ mhp_extension/detectron2/docs/modules/utils.rst
386
+ mhp_extension/detectron2/docs/notes/benchmarks.md
387
+ mhp_extension/detectron2/docs/notes/changelog.md
388
+ mhp_extension/detectron2/docs/notes/compatibility.md
389
+ mhp_extension/detectron2/docs/notes/contributing.md
390
+ mhp_extension/detectron2/docs/notes/index.rst
391
+ mhp_extension/detectron2/docs/README.md
392
+ mhp_extension/detectron2/docs/tutorials/builtin_datasets.md
393
+ mhp_extension/detectron2/docs/tutorials/configs.md
394
+ mhp_extension/detectron2/docs/tutorials/data_loading.md
395
+ mhp_extension/detectron2/docs/tutorials/datasets.md
396
+ mhp_extension/detectron2/docs/tutorials/deployment.md
397
+ mhp_extension/detectron2/docs/tutorials/evaluation.md
398
+ mhp_extension/detectron2/docs/tutorials/extend.md
399
+ mhp_extension/detectron2/docs/tutorials/getting_started.md
400
+ mhp_extension/detectron2/docs/tutorials/index.rst
401
+ mhp_extension/detectron2/docs/tutorials/install.md
402
+ mhp_extension/detectron2/docs/tutorials/models.md
403
+ mhp_extension/detectron2/docs/tutorials/README.md
404
+ mhp_extension/detectron2/docs/tutorials/training.md
405
+ mhp_extension/detectron2/docs/tutorials/write-models.md
406
+ mhp_extension/detectron2/.flake8
407
+ mhp_extension/detectron2/GETTING_STARTED.md
408
+ mhp_extension/detectron2/.gitignore
409
+ mhp_extension/detectron2/INSTALL.md
410
+ mhp_extension/detectron2/LICENSE
411
+ mhp_extension/detectron2/MODEL_ZOO.md
412
+ mhp_extension/detectron2/projects/DensePose/apply_net.py
413
+ mhp_extension/detectron2/projects/DensePose/configs/Base-DensePose-RCNN-FPN.yaml
414
+ mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_s1x.yaml
415
+ mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC1_s1x.yaml
416
+ mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC2_s1x.yaml
417
+ mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_s1x_legacy.yaml
418
+ mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_s1x.yaml
419
+ mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC1_s1x.yaml
420
+ mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC2_s1x.yaml
421
+ mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_s1x.yaml
422
+ mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC1_s1x.yaml
423
+ mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC2_s1x.yaml
424
+ mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x_legacy.yaml
425
+ mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x.yaml
426
+ mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC1_s1x.yaml
427
+ mhp_extension/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC2_s1x.yaml
428
+ mhp_extension/detectron2/projects/DensePose/configs/evolution/Base-RCNN-FPN-MC.yaml
429
+ mhp_extension/detectron2/projects/DensePose/configs/evolution/faster_rcnn_R_50_FPN_1x_MC.yaml
430
+ mhp_extension/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_DL_instant_test.yaml
431
+ mhp_extension/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_inference_acc_test.yaml
432
+ mhp_extension/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_instant_test.yaml
433
+ mhp_extension/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_training_acc_test.yaml
434
+ mhp_extension/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_TTA_inference_acc_test.yaml
435
+ mhp_extension/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_WC1_instant_test.yaml
436
+ mhp_extension/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_WC2_instant_test.yaml
437
+ mhp_extension/detectron2/projects/DensePose/densepose/config.py
438
+ mhp_extension/detectron2/projects/DensePose/densepose/data/build.py
439
+ mhp_extension/detectron2/projects/DensePose/densepose/data/dataset_mapper.py
440
+ mhp_extension/detectron2/projects/DensePose/densepose/data/datasets/builtin.py
441
+ mhp_extension/detectron2/projects/DensePose/densepose/data/datasets/coco.py
442
+ mhp_extension/detectron2/projects/DensePose/densepose/data/datasets/__init__.py
443
+ mhp_extension/detectron2/projects/DensePose/densepose/data/__init__.py
444
+ mhp_extension/detectron2/projects/DensePose/densepose/data/structures.py
445
+ mhp_extension/detectron2/projects/DensePose/densepose/densepose_coco_evaluation.py
446
+ mhp_extension/detectron2/projects/DensePose/densepose/densepose_head.py
447
+ mhp_extension/detectron2/projects/DensePose/densepose/evaluator.py
448
+ mhp_extension/detectron2/projects/DensePose/densepose/__init__.py
449
+ mhp_extension/detectron2/projects/DensePose/densepose/modeling/test_time_augmentation.py
450
+ mhp_extension/detectron2/projects/DensePose/densepose/roi_head.py
451
+ mhp_extension/detectron2/projects/DensePose/densepose/utils/dbhelper.py
452
+ mhp_extension/detectron2/projects/DensePose/densepose/utils/logger.py
453
+ mhp_extension/detectron2/projects/DensePose/densepose/utils/transform.py
454
+ mhp_extension/detectron2/projects/DensePose/densepose/vis/base.py
455
+ mhp_extension/detectron2/projects/DensePose/densepose/vis/bounding_box.py
456
+ mhp_extension/detectron2/projects/DensePose/densepose/vis/densepose.py
457
+ mhp_extension/detectron2/projects/DensePose/densepose/vis/extractor.py
458
+ mhp_extension/detectron2/projects/DensePose/dev/README.md
459
+ mhp_extension/detectron2/projects/DensePose/dev/run_inference_tests.sh
460
+ mhp_extension/detectron2/projects/DensePose/dev/run_instant_tests.sh
461
+ mhp_extension/detectron2/projects/DensePose/doc/GETTING_STARTED.md
462
+ mhp_extension/detectron2/projects/DensePose/doc/MODEL_ZOO.md
463
+ mhp_extension/detectron2/projects/DensePose/doc/TOOL_APPLY_NET.md
464
+ mhp_extension/detectron2/projects/DensePose/doc/TOOL_QUERY_DB.md
465
+ mhp_extension/detectron2/projects/DensePose/query_db.py
466
+ mhp_extension/detectron2/projects/DensePose/README.md
467
+ mhp_extension/detectron2/projects/DensePose/tests/common.py
468
+ mhp_extension/detectron2/projects/DensePose/tests/test_model_e2e.py
469
+ mhp_extension/detectron2/projects/DensePose/tests/test_setup.py
470
+ mhp_extension/detectron2/projects/DensePose/tests/test_structures.py
471
+ mhp_extension/detectron2/projects/DensePose/train_net.py
472
+ mhp_extension/detectron2/projects/PointRend/configs/InstanceSegmentation/Base-PointRend-RCNN-FPN.yaml
473
+ mhp_extension/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes.yaml
474
+ mhp_extension/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml
475
+ mhp_extension/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml
476
+ mhp_extension/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_parsing.yaml
477
+ mhp_extension/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_parsing.yaml
478
+ mhp_extension/detectron2/projects/PointRend/configs/SemanticSegmentation/Base-PointRend-Semantic-FPN.yaml
479
+ mhp_extension/detectron2/projects/PointRend/configs/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes.yaml
480
+ mhp_extension/detectron2/projects/PointRend/configs/SemanticSegmentation/pointrend_semantic_R_50_FPN_1x_coco.yaml
481
+ mhp_extension/detectron2/projects/PointRend/finetune_net.py
482
+ mhp_extension/detectron2/projects/PointRend/logs/hadoop.kylin.libdfs.log
483
+ mhp_extension/detectron2/projects/PointRend/point_rend/coarse_mask_head.py
484
+ mhp_extension/detectron2/projects/PointRend/point_rend/color_augmentation.py
485
+ mhp_extension/detectron2/projects/PointRend/point_rend/config.py
486
+ mhp_extension/detectron2/projects/PointRend/point_rend/dataset_mapper.py
487
+ mhp_extension/detectron2/projects/PointRend/point_rend/__init__.py
488
+ mhp_extension/detectron2/projects/PointRend/point_rend/point_features.py
489
+ mhp_extension/detectron2/projects/PointRend/point_rend/point_head.py
490
+ mhp_extension/detectron2/projects/PointRend/point_rend/roi_heads.py
491
+ mhp_extension/detectron2/projects/PointRend/point_rend/semantic_seg.py
492
+ mhp_extension/detectron2/projects/PointRend/README.md
493
+ mhp_extension/detectron2/projects/PointRend/run.sh
494
+ mhp_extension/detectron2/projects/PointRend/train_net.py
495
+ mhp_extension/detectron2/projects/README.md
496
+ mhp_extension/detectron2/projects/TensorMask/configs/Base-TensorMask.yaml
497
+ mhp_extension/detectron2/projects/TensorMask/configs/tensormask_R_50_FPN_1x.yaml
498
+ mhp_extension/detectron2/projects/TensorMask/configs/tensormask_R_50_FPN_6x.yaml
499
+ mhp_extension/detectron2/projects/TensorMask/README.md
500
+ mhp_extension/detectron2/projects/TensorMask/setup.py
501
+ mhp_extension/detectron2/projects/TensorMask/tensormask/arch.py
502
+ mhp_extension/detectron2/projects/TensorMask/tensormask/config.py
503
+ mhp_extension/detectron2/projects/TensorMask/tensormask/__init__.py
504
+ mhp_extension/detectron2/projects/TensorMask/tensormask/layers/csrc/SwapAlign2Nat/SwapAlign2Nat_cuda.cu
505
+ mhp_extension/detectron2/projects/TensorMask/tensormask/layers/csrc/SwapAlign2Nat/SwapAlign2Nat.h
506
+ mhp_extension/detectron2/projects/TensorMask/tensormask/layers/csrc/vision.cpp
507
+ mhp_extension/detectron2/projects/TensorMask/tensormask/layers/__init__.py
508
+ mhp_extension/detectron2/projects/TensorMask/tensormask/layers/swap_align2nat.py
509
+ mhp_extension/detectron2/projects/TensorMask/tests/__init__.py
510
+ mhp_extension/detectron2/projects/TensorMask/tests/test_swap_align2nat.py
511
+ mhp_extension/detectron2/projects/TensorMask/train_net.py
512
+ mhp_extension/detectron2/projects/TridentNet/configs/Base-TridentNet-Fast-C4.yaml
513
+ mhp_extension/detectron2/projects/TridentNet/configs/tridentnet_fast_R_101_C4_3x.yaml
514
+ mhp_extension/detectron2/projects/TridentNet/configs/tridentnet_fast_R_50_C4_1x.yaml
515
+ mhp_extension/detectron2/projects/TridentNet/configs/tridentnet_fast_R_50_C4_3x.yaml
516
+ mhp_extension/detectron2/projects/TridentNet/README.md
517
+ mhp_extension/detectron2/projects/TridentNet/train_net.py
518
+ mhp_extension/detectron2/projects/TridentNet/tridentnet/config.py
519
+ mhp_extension/detectron2/projects/TridentNet/tridentnet/__init__.py
520
+ mhp_extension/detectron2/projects/TridentNet/tridentnet/trident_backbone.py
521
+ mhp_extension/detectron2/projects/TridentNet/tridentnet/trident_conv.py
522
+ mhp_extension/detectron2/projects/TridentNet/tridentnet/trident_rcnn.py
523
+ mhp_extension/detectron2/projects/TridentNet/tridentnet/trident_rpn.py
524
+ mhp_extension/detectron2/README.md
525
+ mhp_extension/detectron2/setup.cfg
526
+ mhp_extension/detectron2/setup.py
527
+ mhp_extension/detectron2/tests/data/__init__.py
528
+ mhp_extension/detectron2/tests/data/test_coco.py
529
+ mhp_extension/detectron2/tests/data/test_detection_utils.py
530
+ mhp_extension/detectron2/tests/data/test_rotation_transform.py
531
+ mhp_extension/detectron2/tests/data/test_sampler.py
532
+ mhp_extension/detectron2/tests/data/test_transforms.py
533
+ mhp_extension/detectron2/tests/__init__.py
534
+ mhp_extension/detectron2/tests/layers/__init__.py
535
+ mhp_extension/detectron2/tests/layers/test_mask_ops.py
536
+ mhp_extension/detectron2/tests/layers/test_nms_rotated.py
537
+ mhp_extension/detectron2/tests/layers/test_roi_align.py
538
+ mhp_extension/detectron2/tests/layers/test_roi_align_rotated.py
539
+ mhp_extension/detectron2/tests/modeling/__init__.py
540
+ mhp_extension/detectron2/tests/modeling/test_anchor_generator.py
541
+ mhp_extension/detectron2/tests/modeling/test_box2box_transform.py
542
+ mhp_extension/detectron2/tests/modeling/test_fast_rcnn.py
543
+ mhp_extension/detectron2/tests/modeling/test_model_e2e.py
544
+ mhp_extension/detectron2/tests/modeling/test_roi_heads.py
545
+ mhp_extension/detectron2/tests/modeling/test_roi_pooler.py
546
+ mhp_extension/detectron2/tests/modeling/test_rpn.py
547
+ mhp_extension/detectron2/tests/README.md
548
+ mhp_extension/detectron2/tests/structures/__init__.py
549
+ mhp_extension/detectron2/tests/structures/test_boxes.py
550
+ mhp_extension/detectron2/tests/structures/test_imagelist.py
551
+ mhp_extension/detectron2/tests/structures/test_instances.py
552
+ mhp_extension/detectron2/tests/structures/test_rotated_boxes.py
553
+ mhp_extension/detectron2/tests/test_checkpoint.py
554
+ mhp_extension/detectron2/tests/test_config.py
555
+ mhp_extension/detectron2/tests/test_export_caffe2.py
556
+ mhp_extension/detectron2/tests/test_model_analysis.py
557
+ mhp_extension/detectron2/tests/test_model_zoo.py
558
+ mhp_extension/detectron2/tests/test_visualizer.py
559
+ mhp_extension/detectron2/tools/analyze_model.py
560
+ mhp_extension/detectron2/tools/benchmark.py
561
+ mhp_extension/detectron2/tools/convert-torchvision-to-d2.py
562
+ mhp_extension/detectron2/tools/deploy/caffe2_converter.py
563
+ mhp_extension/detectron2/tools/deploy/caffe2_mask_rcnn.cpp
564
+ mhp_extension/detectron2/tools/deploy/README.md
565
+ mhp_extension/detectron2/tools/deploy/torchscript_traced_mask_rcnn.cpp
566
+ mhp_extension/detectron2/tools/finetune_net.py
567
+ mhp_extension/detectron2/tools/inference.sh
568
+ mhp_extension/detectron2/tools/plain_train_net.py
569
+ mhp_extension/detectron2/tools/README.md
570
+ mhp_extension/detectron2/tools/run.sh
571
+ mhp_extension/detectron2/tools/train_net.py
572
+ mhp_extension/detectron2/tools/visualize_data.py
573
+ mhp_extension/detectron2/tools/visualize_json_results.py
574
+ mhp_extension/global_local_parsing/global_local_datasets.py
575
+ mhp_extension/global_local_parsing/global_local_evaluate.py
576
+ mhp_extension/global_local_parsing/global_local_train.py
577
+ mhp_extension/global_local_parsing/make_id_list.py
578
+ mhp_extension/logits_fusion.py
579
+ mhp_extension/make_crop_and_mask_w_mask_nms.py
580
+ mhp_extension/README.md
581
+ mhp_extension/scripts/make_coco_style_annotation.sh
582
+ mhp_extension/scripts/make_crop.sh
583
+ mhp_extension/scripts/parsing_fusion.sh
584
+ modules/bn.py
585
+ modules/deeplab.py
586
+ modules/dense.py
587
+ modules/functions.py
588
+ modules/__init__.py
589
+ modules/misc.py
590
+ modules/residual.py
591
+ modules/src/checks.h
592
+ modules/src/inplace_abn.cpp
593
+ modules/src/inplace_abn_cpu.cpp
594
+ modules/src/inplace_abn_cuda.cu
595
+ modules/src/inplace_abn_cuda_half.cu
596
+ modules/src/inplace_abn.h
597
+ modules/src/utils/checks.h
598
+ modules/src/utils/common.h
599
+ modules/src/utils/cuda.cuh
600
+ networks/AugmentCE2P.py
601
+ networks/backbone/mobilenetv2.py
602
+ networks/backbone/resnet.py
603
+ networks/backbone/resnext.py
604
+ networks/context_encoding/aspp.py
605
+ networks/context_encoding/ocnet.py
606
+ networks/context_encoding/psp.py
607
+ networks/__init__.py
608
+ README.md
609
+ requirements.txt
610
+ simple_extractor.py
611
+ training_code/MVANet/README.org
612
+ train.py
613
+ utils/consistency_loss.py
614
+ utils/criterion.py
615
+ utils/encoding.py
616
+ utils/__init__.py
617
+ utils/kl_loss.py
618
+ utils/lovasz_softmax.py
619
+ utils/miou.py
620
+ utils/schp.py
621
+ utils/soft_dice_loss.py
622
+ utils/transforms.py
623
+ utils/warmup_scheduler.py
624
+ MVANet_Inference/README.org
625
+ #+end_src
626
+
627
+ * List of files to remove
628
+ #+begin_src conf :tangle ./rm.txt
629
+ ComfyUI_MVANet/__pycache__/__init__.cpython-310.pyc
630
+ ComfyUI_MVANet/#README.org#
631
+ ComfyUI_MVANet/.#README.org
632
+ ComfyUI_MVANet/README.org~
633
+ ComfyUI_MVANet/.README.org.~undo-tree~
634
+ #main.org#
635
+ .#main.org
636
+ main.org~
637
+ .main.org.~undo-tree~
638
+ .README.md.~undo-tree~
639
+ ComfyUI_MVANet/.#README.org
640
+ ComfyUI_AEMatter/__pycache__/__init__.cpython-310.pyc
641
+ ComfyUI_AEMatter/AEMatter.class.py
642
+ ComfyUI_AEMatter/AEMatter.execute.py
643
+ ComfyUI_AEMatter/AEMatter.function.py
644
+ ComfyUI_AEMatter/AEMatter.import.py
645
+ ComfyUI_MVANet/MVANet_inference.class.py
646
+ ComfyUI_MVANet/MVANet_inference.execute.py
647
+ ComfyUI_MVANet/MVANet_inference.function.py
648
+ ComfyUI_MVANet/MVANet_inference.import.py
649
+ ComfyUI_MVANet/MVANet_inference.unify.sh
650
+ ComfyUI_AEMatter/AEMatter.unify.sh
651
+ git_add.txt
652
+ git_lfs_track.txt
653
+ gitignore.txt
654
+ rm.txt
655
+ work.sh
656
+ #+end_src
657
+
658
+ * List of patterns to ignore
659
+ #+begin_src conf :tangle ./gitignore.txt
660
+ log/
661
+ pretrain_model/
662
+ commit_and_push.sh
663
+ #+end_src
mhp_extension/README.md ADDED
@@ -0,0 +1,38 @@
1
+ # Self Correction for Human Parsing
2
+
3
+ We propose a simple yet effective multiple human parsing framework by extending our self-correction network.
4
+
5
+ We provide an example usage Jupyter notebook in [demo.ipynb](./demo.ipynb).
6
+
7
+ ## Requirements
8
+
9
+ Please see [INSTALL.md](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md) for further requirements.
10
+
11
+ ## Citation
12
+
13
+ Please cite our work if you find this repo useful in your research.
14
+
15
+ ```latex
16
+ @article{li2019self,
17
+ title={Self-Correction for Human Parsing},
18
+ author={Li, Peike and Xu, Yunqiu and Wei, Yunchao and Yang, Yi},
19
+ journal={arXiv preprint arXiv:1910.09777},
20
+ year={2019}
21
+ }
22
+ ```
23
+
24
+ ## Visualization
25
+
26
+ * Source Image.
27
+ ![demo](./demo/demo.jpg)
28
+ * Instance Human Mask.
29
+ ![demo-lip](./demo/demo_instance_human_mask.png)
30
+ * Global Human Parsing Result.
31
+ ![demo-lip](./demo/demo_global_human_parsing.png)
32
+ * Multiple Human Parsing Result.
33
+ ![demo-lip](./demo/demo_multiple_human_parsing.png)
34
+
35
+ ## Related
36
+
37
+ Our implementation is based on [Detectron2](https://github.com/facebookresearch/detectron2).
38
+
mhp_extension/coco_style_annotation_creator/__pycache__/pycococreatortools.cpython-37.pyc ADDED
Binary file (3.6 kB). View file
 
mhp_extension/coco_style_annotation_creator/human_to_coco.py ADDED
@@ -0,0 +1,166 @@
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
+ import pycococreatortools
9
+
10
+
11
+ def get_arguments():
12
+ parser = argparse.ArgumentParser(description="transform mask annotation to coco annotation")
13
+ parser.add_argument("--dataset", type=str, default='CIHP', help="name of dataset (CIHP, MHPv2 or VIP)")
14
+ parser.add_argument("--json_save_dir", type=str, default='../data/msrcnn_finetune_annotations',
15
+ help="path to save coco-style annotation json file")
16
+ parser.add_argument("--use_val", type=bool, default=False,
17
+ help="use train+val set for finetuning or not")
18
+ parser.add_argument("--train_img_dir", type=str, default='../data/instance-level_human_parsing/Training/Images',
19
+ help="train image path")
20
+ parser.add_argument("--train_anno_dir", type=str,
21
+ default='../data/instance-level_human_parsing/Training/Human_ids',
22
+ help="train human mask path")
23
+ parser.add_argument("--val_img_dir", type=str, default='../data/instance-level_human_parsing/Validation/Images',
24
+ help="val image path")
25
+ parser.add_argument("--val_anno_dir", type=str,
26
+ default='../data/instance-level_human_parsing/Validation/Human_ids',
27
+ help="val human mask path")
28
+ return parser.parse_args()
29
+
30
+
31
+ def main(args):
32
+ INFO = {
33
+ "description": args.split_name + " Dataset",
34
+ "url": "",
35
+ "version": "",
36
+ "year": 2019,
37
+ "contributor": "xyq",
38
+ "date_created": datetime.datetime.utcnow().isoformat(' ')
39
+ }
40
+
41
+ LICENSES = [
42
+ {
43
+ "id": 1,
44
+ "name": "",
45
+ "url": ""
46
+ }
47
+ ]
48
+
49
+ CATEGORIES = [
50
+ {
51
+ 'id': 1,
52
+ 'name': 'person',
53
+ 'supercategory': 'person',
54
+ },
55
+ ]
56
+
57
+ coco_output = {
58
+ "info": INFO,
59
+ "licenses": LICENSES,
60
+ "categories": CATEGORIES,
61
+ "images": [],
62
+ "annotations": []
63
+ }
64
+
65
+ image_id = 1
66
+ segmentation_id = 1
67
+
68
+ for image_name in os.listdir(args.train_img_dir):
69
+ image = Image.open(os.path.join(args.train_img_dir, image_name))
70
+ image_info = pycococreatortools.create_image_info(
71
+ image_id, image_name, image.size
72
+ )
73
+ coco_output["images"].append(image_info)
74
+
75
+ human_mask_name = os.path.splitext(image_name)[0] + '.png'
76
+ human_mask = np.asarray(Image.open(os.path.join(args.train_anno_dir, human_mask_name)))
77
+ human_gt_labels = np.unique(human_mask)
78
+
79
+ for i in range(1, len(human_gt_labels)):
80
+ category_info = {'id': 1, 'is_crowd': 0}
81
+ binary_mask = np.uint8(human_mask == i)
82
+ annotation_info = pycococreatortools.create_annotation_info(
83
+ segmentation_id, image_id, category_info, binary_mask,
84
+ image.size, tolerance=10
85
+ )
86
+ if annotation_info is not None:
87
+ coco_output["annotations"].append(annotation_info)
88
+
89
+ segmentation_id += 1
90
+ image_id += 1
91
+
92
+ if not os.path.exists(args.json_save_dir):
93
+ os.makedirs(args.json_save_dir)
94
+ if not args.use_val:
95
+ with open('{}/{}_train.json'.format(args.json_save_dir, args.dataset), 'w') as output_json_file:
96
+ json.dump(coco_output, output_json_file)
97
+ else:
98
+ for image_name in os.listdir(args.val_img_dir):
99
+ image = Image.open(os.path.join(args.val_img_dir, image_name))
100
+ image_info = pycococreatortools.create_image_info(
101
+ image_id, image_name, image.size
102
+ )
103
+ coco_output["images"].append(image_info)
104
+
105
+ human_mask_name = os.path.splitext(image_name)[0] + '.png'
106
+ human_mask = np.asarray(Image.open(os.path.join(args.val_anno_dir, human_mask_name)))
107
+ human_gt_labels = np.unique(human_mask)
108
+
109
+ for i in range(1, len(human_gt_labels)):
110
+ category_info = {'id': 1, 'is_crowd': 0}
111
+ binary_mask = np.uint8(human_mask == i)
112
+ annotation_info = pycococreatortools.create_annotation_info(
113
+ segmentation_id, image_id, category_info, binary_mask,
114
+ image.size, tolerance=10
115
+ )
116
+ if annotation_info is not None:
117
+ coco_output["annotations"].append(annotation_info)
118
+
119
+ segmentation_id += 1
120
+ image_id += 1
121
+
122
+ with open('{}/{}_trainval.json'.format(args.json_save_dir, args.dataset), 'w') as output_json_file:
123
+ json.dump(coco_output, output_json_file)
124
+
125
+ coco_output_val = {
126
+ "info": INFO,
127
+ "licenses": LICENSES,
128
+ "categories": CATEGORIES,
129
+ "images": [],
130
+ "annotations": []
131
+ }
132
+
133
+ image_id_val = 1
134
+ segmentation_id_val = 1
135
+
136
+ for image_name in os.listdir(args.val_img_dir):
137
+ image = Image.open(os.path.join(args.val_img_dir, image_name))
138
+ image_info = pycococreatortools.create_image_info(
139
+ image_id_val, image_name, image.size
140
+ )
141
+ coco_output_val["images"].append(image_info)
142
+
143
+ human_mask_name = os.path.splitext(image_name)[0] + '.png'
144
+ human_mask = np.asarray(Image.open(os.path.join(args.val_anno_dir, human_mask_name)))
145
+ human_gt_labels = np.unique(human_mask)
146
+
147
+ for i in range(1, len(human_gt_labels)):
148
+ category_info = {'id': 1, 'is_crowd': 0}
149
+ binary_mask = np.uint8(human_mask == i)
150
+ annotation_info = pycococreatortools.create_annotation_info(
151
+ segmentation_id_val, image_id_val, category_info, binary_mask,
152
+ image.size, tolerance=10
153
+ )
154
+ if annotation_info is not None:
155
+ coco_output_val["annotations"].append(annotation_info)
156
+
157
+ segmentation_id_val += 1
158
+ image_id_val += 1
159
+
160
+ with open('{}/{}_val.json'.format(args.json_save_dir, args.dataset), 'w') as output_json_file_val:
161
+ json.dump(coco_output_val, output_json_file_val)
162
+
163
+
164
+ if __name__ == "__main__":
165
+ args = get_arguments()
166
+ main(args)
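The script above only writes the COCO-style JSON; a quick way to check the output is to load it back with the `pycocotools` COCO API. The sketch below is illustrative only: the annotation path is the hypothetical result of the default `--json_save_dir` and `--dataset` values, and it assumes `pycocotools` is installed.

```python
# Hedged sanity check for the generated annotations (path is hypothetical,
# derived from the default --json_save_dir and --dataset arguments above).
from pycocotools.coco import COCO

ann_file = "../data/msrcnn_finetune_annotations/CIHP_train.json"
coco = COCO(ann_file)  # indexes images, annotations and the single 'person' category

img_ids = coco.getImgIds()
print("images:", len(img_ids), "annotations:", len(coco.getAnnIds()))

# Rebuild the binary mask of one person instance from the first image.
first_img = coco.loadImgs(img_ids[:1])[0]
ann_ids = coco.getAnnIds(imgIds=first_img["id"], catIds=[1])
if ann_ids:
    ann = coco.loadAnns(ann_ids[:1])[0]
    instance_mask = coco.annToMask(ann)  # HxW uint8 array, 1 inside the instance
    print(first_img["file_name"], instance_mask.shape, int(instance_mask.sum()))
```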
mhp_extension/coco_style_annotation_creator/pycococreatortools.py ADDED
@@ -0,0 +1,114 @@
1
+ import re
2
+ import datetime
3
+ import numpy as np
4
+ from itertools import groupby
5
+ from skimage import measure
6
+ from PIL import Image
7
+ from pycocotools import mask
8
+
9
+ convert = lambda text: int(text) if text.isdigit() else text.lower()
10
+ natrual_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
11
+
12
+
13
+ def resize_binary_mask(array, new_size):
14
+ image = Image.fromarray(array.astype(np.uint8) * 255)
15
+ image = image.resize(new_size)
16
+ return np.asarray(image).astype(np.bool_)
17
+
18
+
19
+ def close_contour(contour):
20
+ if not np.array_equal(contour[0], contour[-1]):
21
+ contour = np.vstack((contour, contour[0]))
22
+ return contour
23
+
24
+
25
+ def binary_mask_to_rle(binary_mask):
26
+ rle = {'counts': [], 'size': list(binary_mask.shape)}
27
+ counts = rle.get('counts')
28
+ for i, (value, elements) in enumerate(groupby(binary_mask.ravel(order='F'))):
29
+ if i == 0 and value == 1:
30
+ counts.append(0)
31
+ counts.append(len(list(elements)))
32
+
33
+ return rle
34
+
35
+
36
+ def binary_mask_to_polygon(binary_mask, tolerance=0):
37
+ """Converts a binary mask to COCO polygon representation
38
+ Args:
39
+ binary_mask: a 2D binary numpy array where '1's represent the object
40
+ tolerance: Maximum distance from original points of polygon to approximated
41
+ polygonal chain. If tolerance is 0, the original coordinate array is returned.
42
+ """
43
+ polygons = []
44
+ # pad mask to close contours of shapes which start and end at an edge
45
+ padded_binary_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)
46
+ contours = measure.find_contours(padded_binary_mask, 0.5)
47
+ contours = np.subtract(contours, 1)
48
+ for contour in contours:
49
+ contour = close_contour(contour)
50
+ contour = measure.approximate_polygon(contour, tolerance)
51
+ if len(contour) < 3:
52
+ continue
53
+ contour = np.flip(contour, axis=1)
54
+ segmentation = contour.ravel().tolist()
55
+ # after padding and subtracting 1 we may get -0.5 points in our segmentation
56
+ segmentation = [0 if i < 0 else i for i in segmentation]
57
+ polygons.append(segmentation)
58
+
59
+ return polygons
60
+
61
+
62
+ def create_image_info(image_id, file_name, image_size,
63
+ date_captured=datetime.datetime.utcnow().isoformat(' '),
64
+ license_id=1, coco_url="", flickr_url=""):
65
+ image_info = {
66
+ "id": image_id,
67
+ "file_name": file_name,
68
+ "width": image_size[0],
69
+ "height": image_size[1],
70
+ "date_captured": date_captured,
71
+ "license": license_id,
72
+ "coco_url": coco_url,
73
+ "flickr_url": flickr_url
74
+ }
75
+
76
+ return image_info
77
+
78
+
79
+ def create_annotation_info(annotation_id, image_id, category_info, binary_mask,
80
+ image_size=None, tolerance=2, bounding_box=None):
81
+ if image_size is not None:
82
+ binary_mask = resize_binary_mask(binary_mask, image_size)
83
+
84
+ binary_mask_encoded = mask.encode(np.asfortranarray(binary_mask.astype(np.uint8)))
85
+
86
+ area = mask.area(binary_mask_encoded)
87
+ if area < 1:
88
+ return None
89
+
90
+ if bounding_box is None:
91
+ bounding_box = mask.toBbox(binary_mask_encoded)
92
+
93
+ if category_info["is_crowd"]:
94
+ is_crowd = 1
95
+ segmentation = binary_mask_to_rle(binary_mask)
96
+ else:
97
+ is_crowd = 0
98
+ segmentation = binary_mask_to_polygon(binary_mask, tolerance)
99
+ if not segmentation:
100
+ return None
101
+
102
+ annotation_info = {
103
+ "id": annotation_id,
104
+ "image_id": image_id,
105
+ "category_id": category_info["id"],
106
+ "iscrowd": is_crowd,
107
+ "area": area.tolist(),
108
+ "bbox": bounding_box.tolist(),
109
+ "segmentation": segmentation,
110
+ "width": binary_mask.shape[1],
111
+ "height": binary_mask.shape[0],
112
+ }
113
+
114
+ return annotation_info
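For orientation, here is how the two helpers above combine on a toy mask, mirroring their use in `human_to_coco.py`. This is a minimal sketch rather than repository code: it assumes `pycococreatortools.py` is importable from the working directory and that `numpy`, `Pillow`, `scikit-image` and `pycocotools` are available.

```python
# Toy end-to-end use of create_image_info / create_annotation_info.
# The image name and sizes are made up for illustration.
import numpy as np
from pycococreatortools import create_annotation_info, create_image_info

binary_mask = np.zeros((8, 8), dtype=np.uint8)
binary_mask[2:6, 3:7] = 1  # one square "person" instance

image_info = create_image_info(image_id=1, file_name="toy.jpg", image_size=(8, 8))
annotation_info = create_annotation_info(
    annotation_id=1,
    image_id=1,
    category_info={"id": 1, "is_crowd": 0},
    binary_mask=binary_mask,
    image_size=(8, 8),  # mask is resized to (width, height) before encoding
    tolerance=2,        # polygon simplification distance, as in the docstring
)

print(image_info["width"], image_info["height"])
if annotation_info is not None:  # None when the mask is empty or degenerate
    print(annotation_info["bbox"], annotation_info["area"])
```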
mhp_extension/coco_style_annotation_creator/test_human2coco_format.py ADDED
@@ -0,0 +1,74 @@
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ from PIL import Image
6
+
7
+ import pycococreatortools
8
+
9
+
10
+ def get_arguments():
11
+ parser = argparse.ArgumentParser(description="transform mask annotation to coco annotation")
12
+ parser.add_argument("--dataset", type=str, default='CIHP', help="name of dataset (CIHP, MHPv2 or VIP)")
13
+ parser.add_argument("--json_save_dir", type=str, default='../data/CIHP/annotations',
14
+ help="path to save coco-style annotation json file")
15
+ parser.add_argument("--test_img_dir", type=str, default='../data/CIHP/Testing/Images',
16
+ help="test image path")
17
+ return parser.parse_args()
18
+
19
+ args = get_arguments()
20
+
21
+ INFO = {
22
+ "description": args.dataset + "Dataset",
23
+ "url": "",
24
+ "version": "",
25
+ "year": 2020,
26
+ "contributor": "yunqiuxu",
27
+ "date_created": datetime.datetime.utcnow().isoformat(' ')
28
+ }
29
+
30
+ LICENSES = [
31
+ {
32
+ "id": 1,
33
+ "name": "",
34
+ "url": ""
35
+ }
36
+ ]
37
+
38
+ CATEGORIES = [
39
+ {
40
+ 'id': 1,
41
+ 'name': 'person',
42
+ 'supercategory': 'person',
43
+ },
44
+ ]
45
+
46
+
47
+ def main(args):
48
+ coco_output = {
49
+ "info": INFO,
50
+ "licenses": LICENSES,
51
+ "categories": CATEGORIES,
52
+ "images": [],
53
+ "annotations": []
54
+ }
55
+
56
+ image_id = 1
57
+
58
+ for image_name in os.listdir(args.test_img_dir):
59
+ image = Image.open(os.path.join(args.test_img_dir, image_name))
60
+ image_info = pycococreatortools.create_image_info(
61
+ image_id, image_name, image.size
62
+ )
63
+ coco_output["images"].append(image_info)
64
+ image_id += 1
65
+
66
+ if not os.path.exists(os.path.join(args.json_save_dir)):
67
+ os.mkdir(os.path.join(args.json_save_dir))
68
+
69
+ with open('{}/{}.json'.format(args.json_save_dir, args.dataset), 'w') as output_json_file:
70
+ json.dump(coco_output, output_json_file)
71
+
72
+
73
+ if __name__ == "__main__":
74
+ main(args)
mhp_extension/demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
mhp_extension/detectron2/.circleci/config.yml ADDED
@@ -0,0 +1,179 @@
1
+ # Python CircleCI 2.0 configuration file
2
+ #
3
+ # Check https://circleci.com/docs/2.0/language-python/ for more details
4
+ #
5
+ version: 2
6
+
7
+ # -------------------------------------------------------------------------------------
8
+ # Environments to run the jobs in
9
+ # -------------------------------------------------------------------------------------
10
+ cpu: &cpu
11
+ docker:
12
+ - image: circleci/python:3.6.8-stretch
13
+ resource_class: medium
14
+
15
+ gpu: &gpu
16
+ machine:
17
+ image: ubuntu-1604:201903-01
18
+ docker_layer_caching: true
19
+ resource_class: gpu.small
20
+
21
+ # -------------------------------------------------------------------------------------
22
+ # Re-usable commands
23
+ # -------------------------------------------------------------------------------------
24
+ install_python: &install_python
25
+ - run:
26
+ name: Install Python
27
+ working_directory: ~/
28
+ command: |
29
+ pyenv install 3.6.1
30
+ pyenv global 3.6.1
31
+
32
+ setup_venv: &setup_venv
33
+ - run:
34
+ name: Setup Virtual Env
35
+ working_directory: ~/
36
+ command: |
37
+ python -m venv ~/venv
38
+ echo ". ~/venv/bin/activate" >> $BASH_ENV
39
+ . ~/venv/bin/activate
40
+ python --version
41
+ which python
42
+ which pip
43
+ pip install --upgrade pip
44
+
45
+ install_dep: &install_dep
46
+ - run:
47
+ name: Install Dependencies
48
+ command: |
49
+ pip install --progress-bar off -U 'git+https://github.com/facebookresearch/fvcore'
50
+ pip install --progress-bar off cython opencv-python
51
+ pip install --progress-bar off 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
52
+ pip install --progress-bar off torch torchvision
53
+
54
+ install_detectron2: &install_detectron2
55
+ - run:
56
+ name: Install Detectron2
57
+ command: |
58
+ gcc --version
59
+ pip install -U --progress-bar off -e .[dev]
60
+ python -m detectron2.utils.collect_env
61
+
62
+ install_nvidia_driver: &install_nvidia_driver
63
+ - run:
64
+ name: Install nvidia driver
65
+ working_directory: ~/
66
+ command: |
67
+ wget -q 'https://s3.amazonaws.com/ossci-linux/nvidia_driver/NVIDIA-Linux-x86_64-430.40.run'
68
+ sudo /bin/bash ./NVIDIA-Linux-x86_64-430.40.run -s --no-drm
69
+ nvidia-smi
70
+
71
+ run_unittests: &run_unittests
72
+ - run:
73
+ name: Run Unit Tests
74
+ command: |
75
+ python -m unittest discover -v -s tests
76
+
77
+ # -------------------------------------------------------------------------------------
78
+ # Jobs to run
79
+ # -------------------------------------------------------------------------------------
80
+ jobs:
81
+ cpu_tests:
82
+ <<: *cpu
83
+
84
+ working_directory: ~/detectron2
85
+
86
+ steps:
87
+ - checkout
88
+ - <<: *setup_venv
89
+
90
+ # Cache the venv directory that contains dependencies
91
+ - restore_cache:
92
+ keys:
93
+ - cache-key-{{ .Branch }}-ID-20200425
94
+
95
+ - <<: *install_dep
96
+
97
+ - save_cache:
98
+ paths:
99
+ - ~/venv
100
+ key: cache-key-{{ .Branch }}-ID-20200425
101
+
102
+ - <<: *install_detectron2
103
+
104
+ - run:
105
+ name: isort
106
+ command: |
107
+ isort -c -sp .
108
+ - run:
109
+ name: black
110
+ command: |
111
+ black --check -l 100 .
112
+ - run:
113
+ name: flake8
114
+ command: |
115
+ flake8 .
116
+
117
+ - <<: *run_unittests
118
+
119
+ gpu_tests:
120
+ <<: *gpu
121
+
122
+ working_directory: ~/detectron2
123
+
124
+ steps:
125
+ - checkout
126
+ - <<: *install_nvidia_driver
127
+
128
+ - run:
129
+ name: Install nvidia-docker
130
+ working_directory: ~/
131
+ command: |
132
+ curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add -
133
+ distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
134
+ curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | \
135
+ sudo tee /etc/apt/sources.list.d/nvidia-docker.list
136
+ sudo apt-get update && sudo apt-get install -y nvidia-docker2
137
+ # reload the docker daemon configuration
138
+ sudo pkill -SIGHUP dockerd
139
+
140
+ - run:
141
+ name: Launch docker
142
+ working_directory: ~/detectron2/docker
143
+ command: |
144
+ nvidia-docker build -t detectron2:v0 -f Dockerfile-circleci .
145
+ nvidia-docker run -itd --name d2 detectron2:v0
146
+ docker exec -it d2 nvidia-smi
147
+
148
+ - run:
149
+ name: Build Detectron2
150
+ command: |
151
+ docker exec -it d2 pip install 'git+https://github.com/facebookresearch/fvcore'
152
+ docker cp ~/detectron2 d2:/detectron2
153
+ # This will build d2 for the target GPU arch only
154
+ docker exec -it d2 pip install -e /detectron2
155
+ docker exec -it d2 python3 -m detectron2.utils.collect_env
156
+ docker exec -it d2 python3 -c 'import torch; assert(torch.cuda.is_available())'
157
+
158
+ - run:
159
+ name: Run Unit Tests
160
+ command: |
161
+ docker exec -e CIRCLECI=true -it d2 python3 -m unittest discover -v -s /detectron2/tests
162
+
163
+ workflows:
164
+ version: 2
165
+ regular_test:
166
+ jobs:
167
+ - cpu_tests
168
+ - gpu_tests
169
+
170
+ #nightly_test:
171
+ #jobs:
172
+ #- gpu_tests
173
+ #triggers:
174
+ #- schedule:
175
+ #cron: "0 0 * * *"
176
+ #filters:
177
+ #branches:
178
+ #only:
179
+ #- master
mhp_extension/detectron2/.clang-format ADDED
@@ -0,0 +1,85 @@
1
+ AccessModifierOffset: -1
2
+ AlignAfterOpenBracket: AlwaysBreak
3
+ AlignConsecutiveAssignments: false
4
+ AlignConsecutiveDeclarations: false
5
+ AlignEscapedNewlinesLeft: true
6
+ AlignOperands: false
7
+ AlignTrailingComments: false
8
+ AllowAllParametersOfDeclarationOnNextLine: false
9
+ AllowShortBlocksOnASingleLine: false
10
+ AllowShortCaseLabelsOnASingleLine: false
11
+ AllowShortFunctionsOnASingleLine: Empty
12
+ AllowShortIfStatementsOnASingleLine: false
13
+ AllowShortLoopsOnASingleLine: false
14
+ AlwaysBreakAfterReturnType: None
15
+ AlwaysBreakBeforeMultilineStrings: true
16
+ AlwaysBreakTemplateDeclarations: true
17
+ BinPackArguments: false
18
+ BinPackParameters: false
19
+ BraceWrapping:
20
+ AfterClass: false
21
+ AfterControlStatement: false
22
+ AfterEnum: false
23
+ AfterFunction: false
24
+ AfterNamespace: false
25
+ AfterObjCDeclaration: false
26
+ AfterStruct: false
27
+ AfterUnion: false
28
+ BeforeCatch: false
29
+ BeforeElse: false
30
+ IndentBraces: false
31
+ BreakBeforeBinaryOperators: None
32
+ BreakBeforeBraces: Attach
33
+ BreakBeforeTernaryOperators: true
34
+ BreakConstructorInitializersBeforeComma: false
35
+ BreakAfterJavaFieldAnnotations: false
36
+ BreakStringLiterals: false
37
+ ColumnLimit: 80
38
+ CommentPragmas: '^ IWYU pragma:'
39
+ ConstructorInitializerAllOnOneLineOrOnePerLine: true
40
+ ConstructorInitializerIndentWidth: 4
41
+ ContinuationIndentWidth: 4
42
+ Cpp11BracedListStyle: true
43
+ DerivePointerAlignment: false
44
+ DisableFormat: false
45
+ ForEachMacros: [ FOR_EACH, FOR_EACH_ENUMERATE, FOR_EACH_KV, FOR_EACH_R, FOR_EACH_RANGE, ]
46
+ IncludeCategories:
47
+ - Regex: '^<.*\.h(pp)?>'
48
+ Priority: 1
49
+ - Regex: '^<.*'
50
+ Priority: 2
51
+ - Regex: '.*'
52
+ Priority: 3
53
+ IndentCaseLabels: true
54
+ IndentWidth: 2
55
+ IndentWrappedFunctionNames: false
56
+ KeepEmptyLinesAtTheStartOfBlocks: false
57
+ MacroBlockBegin: ''
58
+ MacroBlockEnd: ''
59
+ MaxEmptyLinesToKeep: 1
60
+ NamespaceIndentation: None
61
+ ObjCBlockIndentWidth: 2
62
+ ObjCSpaceAfterProperty: false
63
+ ObjCSpaceBeforeProtocolList: false
64
+ PenaltyBreakBeforeFirstCallParameter: 1
65
+ PenaltyBreakComment: 300
66
+ PenaltyBreakFirstLessLess: 120
67
+ PenaltyBreakString: 1000
68
+ PenaltyExcessCharacter: 1000000
69
+ PenaltyReturnTypeOnItsOwnLine: 200
70
+ PointerAlignment: Left
71
+ ReflowComments: true
72
+ SortIncludes: true
73
+ SpaceAfterCStyleCast: false
74
+ SpaceBeforeAssignmentOperators: true
75
+ SpaceBeforeParens: ControlStatements
76
+ SpaceInEmptyParentheses: false
77
+ SpacesBeforeTrailingComments: 1
78
+ SpacesInAngles: false
79
+ SpacesInContainerLiterals: true
80
+ SpacesInCStyleCastParentheses: false
81
+ SpacesInParentheses: false
82
+ SpacesInSquareBrackets: false
83
+ Standard: Cpp11
84
+ TabWidth: 8
85
+ UseTab: Never
mhp_extension/detectron2/.flake8 ADDED
@@ -0,0 +1,9 @@
1
+ # This is an example .flake8 config, used when developing *Black* itself.
2
+ # Keep in sync with setup.cfg which is used for source packages.
3
+
4
+ [flake8]
5
+ ignore = W503, E203, E221, C901, C408, E741
6
+ max-line-length = 100
7
+ max-complexity = 18
8
+ select = B,C,E,F,W,T4,B9
9
+ exclude = build,__init__.py
mhp_extension/detectron2/.gitignore ADDED
@@ -0,0 +1,46 @@
1
+ # output dir
2
+ output
3
+ instant_test_output
4
+ inference_test_output
5
+
6
+
7
+ *.jpg
8
+ *.png
9
+ *.txt
10
+ *.json
11
+ *.diff
12
+
13
+ # compilation and distribution
14
+ __pycache__
15
+ _ext
16
+ *.pyc
17
+ *.so
18
+ detectron2.egg-info/
19
+ build/
20
+ dist/
21
+ wheels/
22
+
23
+ # pytorch/python/numpy formats
24
+ *.pth
25
+ *.pkl
26
+ *.npy
27
+
28
+ # ipython/jupyter notebooks
29
+ *.ipynb
30
+ **/.ipynb_checkpoints/
31
+
32
+ # Editor temporaries
33
+ *.swn
34
+ *.swo
35
+ *.swp
36
+ *~
37
+
38
+ # editor settings
39
+ .idea
40
+ .vscode
41
+
42
+ # project dirs
43
+ /detectron2/model_zoo/configs
44
+ /datasets
45
+ /projects/*/datasets
46
+ /models
mhp_extension/detectron2/GETTING_STARTED.md ADDED
@@ -0,0 +1,79 @@
1
+ ## Getting Started with Detectron2
2
+
3
+ This document provides a brief intro of the usage of builtin command-line tools in detectron2.
4
+
5
+ For a tutorial that involves actual coding with the API,
6
+ see our [Colab Notebook](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5)
7
+ which covers how to run inference with an
8
+ existing model, and how to train a builtin model on a custom dataset.
9
+
10
+ For more advanced tutorials, refer to our [documentation](https://detectron2.readthedocs.io/tutorials/extend.html).
11
+
12
+
13
+ ### Inference Demo with Pre-trained Models
14
+
15
+ 1. Pick a model and its config file from
16
+ [model zoo](MODEL_ZOO.md),
17
+ for example, `mask_rcnn_R_50_FPN_3x.yaml`.
18
+ 2. We provide `demo.py` that is able to run builtin standard models. Run it with:
19
+ ```
20
+ cd demo/
21
+ python demo.py --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \
22
+ --input input1.jpg input2.jpg \
23
+ [--other-options]
24
+ --opts MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl
25
+ ```
26
+ The configs are made for training, therefore we need to specify `MODEL.WEIGHTS` to a model from model zoo for evaluation.
27
+ This command will run the inference and show visualizations in an OpenCV window.
28
+
29
+ For details of the command line arguments, see `demo.py -h` or look at its source code
30
+ to understand its behavior. Some common arguments are:
31
+ * To run __on your webcam__, replace `--input files` with `--webcam`.
32
+ * To run __on a video__, replace `--input files` with `--video-input video.mp4`.
33
+ * To run __on cpu__, add `MODEL.DEVICE cpu` after `--opts`.
34
+ * To save outputs to a directory (for images) or a file (for webcam or video), use `--output`.
35
+
36
+
37
+ ### Training & Evaluation in Command Line
38
+
39
+ We provide a script in "tools/{,plain_}train_net.py", that is made to train
40
+ all the configs provided in detectron2.
41
+ You may want to use it as a reference to write your own training script.
42
+
43
+ To train a model with "train_net.py", first
44
+ setup the corresponding datasets following
45
+ [datasets/README.md](./datasets/README.md),
46
+ then run:
47
+ ```
48
+ cd tools/
49
+ ./train_net.py --num-gpus 8 \
50
+ --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
51
+ ```
52
+
53
+ The configs are made for 8-GPU training.
54
+ To train on 1 GPU, you may need to [change some parameters](https://arxiv.org/abs/1706.02677), e.g.:
55
+ ```
56
+ ./train_net.py \
57
+ --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
58
+ --num-gpus 1 SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025
59
+ ```
60
+
61
+ For most models, CPU training is not supported.
62
+
63
+ To evaluate a model's performance, use
64
+ ```
65
+ ./train_net.py \
66
+ --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
67
+ --eval-only MODEL.WEIGHTS /path/to/checkpoint_file
68
+ ```
69
+ For more options, see `./train_net.py -h`.
70
+
71
+ ### Use Detectron2 APIs in Your Code
72
+
73
+ See our [Colab Notebook](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5)
74
+ to learn how to use detectron2 APIs to:
75
+ 1. run inference with an existing model
76
+ 2. train a builtin model on a custom dataset
77
+
78
+ See [detectron2/projects](https://github.com/facebookresearch/detectron2/tree/master/projects)
79
+ for more ways to build your project on detectron2.
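The Colab notebook linked above is the canonical API walkthrough; as a rough orientation, single-image inference with the detectron2 API usually looks like the following minimal sketch (the config name and image path reuse the examples from the demo command above and are not files shipped in this repository):

```python
# Minimal single-image inference sketch with the detectron2 API.
import cv2
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(
    "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
    "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # confidence threshold for detections
cfg.MODEL.DEVICE = "cpu"                     # or "cuda" when a GPU is available

predictor = DefaultPredictor(cfg)
image = cv2.imread("input1.jpg")             # BGR array, as DefaultPredictor expects
outputs = predictor(image)
print(outputs["instances"].pred_classes)
print(outputs["instances"].pred_boxes)
```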
mhp_extension/detectron2/INSTALL.md ADDED
@@ -0,0 +1,184 @@
1
+ ## Installation
2
+
3
+ Our [Colab Notebook](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5)
4
+ has step-by-step instructions that install detectron2.
5
+ The [Dockerfile](docker)
6
+ also installs detectron2 with a few simple commands.
7
+
8
+ ### Requirements
9
+ - Linux or macOS with Python ≥ 3.6
10
+ - PyTorch ≥ 1.4
11
+ - [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation.
12
+ You can install them together at [pytorch.org](https://pytorch.org) to make sure of this.
13
+ - OpenCV, optional, needed by demo and visualization
14
+ - pycocotools: `pip install cython; pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'`
15
+
16
+
17
+ ### Build Detectron2 from Source
18
+
19
+ gcc & g++ ≥ 5 are required. [ninja](https://ninja-build.org/) is recommended for faster build.
20
+ After having them, run:
21
+ ```
22
+ python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
23
+ # (add --user if you don't have permission)
24
+
25
+ # Or, to install it from a local clone:
26
+ git clone https://github.com/facebookresearch/detectron2.git
27
+ python -m pip install -e detectron2
28
+
29
+ # Or if you are on macOS
30
+ # CC=clang CXX=clang++ python -m pip install -e .
31
+ ```
32
+
33
+ To __rebuild__ detectron2 that's built from a local clone, use `rm -rf build/ **/*.so` to clean the
34
+ old build first. You often need to rebuild detectron2 after reinstalling PyTorch.
35
+
36
+ ### Install Pre-Built Detectron2 (Linux only)
37
+ ```
38
+ # for CUDA 10.1:
39
+ python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/index.html
40
+ ```
41
+ You can replace cu101 with "cu{100,92}" or "cpu".
42
+
43
+ Note that:
44
+ 1. Such installation has to be used with certain version of official PyTorch release.
45
+ See [releases](https://github.com/facebookresearch/detectron2/releases) for requirements.
46
+ It will not work with a different version of PyTorch or a non-official build of PyTorch.
47
+ 2. Such installation is out-of-date w.r.t. master branch of detectron2. It may not be
48
+ compatible with the master branch of a research project that uses detectron2 (e.g. those in
49
+ [projects](projects) or [meshrcnn](https://github.com/facebookresearch/meshrcnn/)).
50
+
51
+ ### Common Installation Issues
52
+
53
+ If you run into issues using the pre-built detectron2, please uninstall it and try building it from source.
54
+
55
+ Click each issue for its solutions:
56
+
57
+ <details>
58
+ <summary>
59
+ Undefined torch/aten/caffe2 symbols, or segmentation fault immediately when running the library.
60
+ </summary>
61
+ <br/>
62
+
63
+ This usually happens when detectron2 or torchvision is not
64
+ compiled with the version of PyTorch you're running.
65
+
66
+ Pre-built torchvision or detectron2 has to work with the corresponding official release of pytorch.
67
+ If the error comes from a pre-built torchvision, uninstall torchvision and pytorch and reinstall them
68
+ following [pytorch.org](http://pytorch.org). So the versions will match.
69
+
70
+ If the error comes from a pre-built detectron2, check [release notes](https://github.com/facebookresearch/detectron2/releases)
71
+ to see the corresponding pytorch version required for each pre-built detectron2.
72
+
73
+ If the error comes from detectron2 or torchvision that you built manually from source,
74
+ remove files you built (`build/`, `**/*.so`) and rebuild it so it can pick up the version of pytorch currently in your environment.
75
+
76
+ If you cannot resolve this problem, please include the output of `gdb -ex "r" -ex "bt" -ex "quit" --args python -m detectron2.utils.collect_env`
77
+ in your issue.
78
+ </details>
79
+
80
+ <details>
81
+ <summary>
82
+ Undefined C++ symbols (e.g. `GLIBCXX`) or C++ symbols not found.
83
+ </summary>
84
+ <br/>
85
+ Usually it's because the library is compiled with a newer C++ compiler but run with an old C++ runtime.
86
+
87
+ This often happens with old anaconda.
88
+ Try `conda update libgcc`. Then rebuild detectron2.
89
+
90
+ The fundamental solution is to run the code with proper C++ runtime.
91
+ One way is to use `LD_PRELOAD=/path/to/libstdc++.so`.
92
+
93
+ </details>
94
+
95
+ <details>
96
+ <summary>
97
+ "Not compiled with GPU support" or "Detectron2 CUDA Compiler: not available".
98
+ </summary>
99
+ <br/>
100
+ CUDA is not found when building detectron2.
101
+ You should make sure
102
+
103
+ ```
104
+ python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(torch.cuda.is_available(), CUDA_HOME)'
105
+ ```
106
+
107
+ print valid outputs at the time you build detectron2.
108
+
109
+ Most models can run inference (but not training) without GPU support. To use CPUs, set `MODEL.DEVICE='cpu'` in the config.
110
+ </details>
111
+
112
+ <details>
113
+ <summary>
114
+ "invalid device function" or "no kernel image is available for execution".
115
+ </summary>
116
+ <br/>
117
+ Two possibilities:
118
+
119
+ * You build detectron2 with one version of CUDA but run it with a different version.
120
+
121
+ To check whether it is the case,
122
+ use `python -m detectron2.utils.collect_env` to find out inconsistent CUDA versions.
123
+ In the output of this command, you should expect "Detectron2 CUDA Compiler", "CUDA_HOME", "PyTorch built with - CUDA"
124
+ to contain cuda libraries of the same version.
125
+
126
+ When they are inconsistent,
127
+ you need to either install a different build of PyTorch (or build by yourself)
128
+ to match your local CUDA installation, or install a different version of CUDA to match PyTorch.
129
+
130
+ * Detectron2 or PyTorch/torchvision is not built for the correct GPU architecture (compute compatibility).
131
+
132
+ The GPU architecture for PyTorch/detectron2/torchvision is available in the "architecture flags" in
133
+ `python -m detectron2.utils.collect_env`.
134
+
135
+ The GPU architecture flags of detectron2/torchvision by default matches the GPU model detected
136
+ during compilation. This means the compiled code may not work on a different GPU model.
137
+ To overwrite the GPU architecture for detectron2/torchvision, use `TORCH_CUDA_ARCH_LIST` environment variable during compilation.
138
+
139
+ For example, `export TORCH_CUDA_ARCH_LIST=6.0,7.0` makes it compile for both P100s and V100s.
140
+ Visit [developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus) to find out
141
+ the correct compute compatibility number for your device.
142
+
143
+ </details>
144
+
145
+ <details>
146
+ <summary>
147
+ Undefined CUDA symbols; cannot open libcudart.so; other nvcc failures.
148
+ </summary>
149
+ <br/>
150
+ The version of NVCC you use to build detectron2 or torchvision does
151
+ not match the version of CUDA you are running with.
152
+ This often happens when using anaconda's CUDA runtime.
153
+
154
+ Use `python -m detectron2.utils.collect_env` to find out inconsistent CUDA versions.
155
+ In the output of this command, you should expect "Detectron2 CUDA Compiler", "CUDA_HOME", "PyTorch built with - CUDA"
156
+ to contain cuda libraries of the same version.
157
+
158
+ When they are inconsistent,
159
+ you need to either install a different build of PyTorch (or build by yourself)
160
+ to match your local CUDA installation, or install a different version of CUDA to match PyTorch.
161
+ </details>
162
+
163
+
164
+ <details>
165
+ <summary>
166
+ "ImportError: cannot import name '_C'".
167
+ </summary>
168
+ <br/>
169
+ Please build and install detectron2 following the instructions above.
170
+
171
+ If you are running code from detectron2's root directory, `cd` to a different one.
172
+ Otherwise you may not import the code that you installed.
173
+ </details>
174
+
175
+ <details>
176
+ <summary>
177
+ ONNX conversion segfault after some "TraceWarning".
178
+ </summary>
179
+ <br/>
180
+ The ONNX package was compiled with a compiler that is too old.
181
+
182
+ Please build and install ONNX from its source code using a compiler
183
+ whose version is closer to what's used by PyTorch (available in `torch.__config__.show()`).
184
+ </details>
mhp_extension/detectron2/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2019 - present, Facebook, Inc
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.