AnonymousUser20 committed
Commit 3e426e9 · verified · 1 Parent(s): 178d33b

Upload 944 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. .gitattributes +35 -0
  2. ID-like-train-change-bg/README.md +3 -0
  3. ID-like-train-change-bg/__pycache__/config.cpython-311.pyc +0 -0
  4. ID-like-train-change-bg/__pycache__/config.cpython-37.pyc +0 -0
  5. ID-like-train-change-bg/bash_allocation.slurm +15 -0
  6. ID-like-train-change-bg/batch_file_deal.py +39 -0
  7. ID-like-train-change-bg/clip checkpoint path/ViT-B-16.pt +3 -0
  8. ID-like-train-change-bg/clip/__init__.py +1 -0
  9. ID-like-train-change-bg/clip/__pycache__/__init__.cpython-311.pyc +0 -0
  10. ID-like-train-change-bg/clip/__pycache__/clip.cpython-311.pyc +0 -0
  11. ID-like-train-change-bg/clip/__pycache__/model.cpython-311.pyc +0 -0
  12. ID-like-train-change-bg/clip/__pycache__/simple_tokenizer.cpython-311.pyc +0 -0
  13. ID-like-train-change-bg/clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  14. ID-like-train-change-bg/clip/clip.py +232 -0
  15. ID-like-train-change-bg/clip/model.py +438 -0
  16. ID-like-train-change-bg/clip/simple_tokenizer.py +132 -0
  17. ID-like-train-change-bg/config.py +248 -0
  18. ID-like-train-change-bg/dataloaders/__init__.py +4 -0
  19. ID-like-train-change-bg/dataloaders/__pycache__/__init__.cpython-311.pyc +0 -0
  20. ID-like-train-change-bg/dataloaders/__pycache__/bird200.cpython-311.pyc +0 -0
  21. ID-like-train-change-bg/dataloaders/__pycache__/car196.cpython-311.pyc +0 -0
  22. ID-like-train-change-bg/dataloaders/__pycache__/food101.cpython-311.pyc +0 -0
  23. ID-like-train-change-bg/dataloaders/__pycache__/pet37.cpython-311.pyc +0 -0
  24. ID-like-train-change-bg/dataloaders/bird200.py +64 -0
  25. ID-like-train-change-bg/dataloaders/car196.py +149 -0
  26. ID-like-train-change-bg/dataloaders/food101.py +123 -0
  27. ID-like-train-change-bg/dataloaders/pet37.py +152 -0
  28. ID-like-train-change-bg/error1.txt +0 -0
  29. ID-like-train-change-bg/eval_ood_detection.py +123 -0
  30. ID-like-train-change-bg/output1.txt +0 -0
  31. ID-like-train-change-bg/utils/__init__.py +2 -0
  32. ID-like-train-change-bg/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  33. ID-like-train-change-bg/utils/__pycache__/__init__.cpython-37.pyc +0 -0
  34. ID-like-train-change-bg/utils/__pycache__/common.cpython-311.pyc +0 -0
  35. ID-like-train-change-bg/utils/__pycache__/common.cpython-37.pyc +0 -0
  36. ID-like-train-change-bg/utils/__pycache__/dataloaders_utils.cpython-311.pyc +0 -0
  37. ID-like-train-change-bg/utils/__pycache__/file_ops.cpython-311.pyc +0 -0
  38. ID-like-train-change-bg/utils/__pycache__/file_ops.cpython-37.pyc +0 -0
  39. ID-like-train-change-bg/utils/__pycache__/id_like.cpython-311.pyc +0 -0
  40. ID-like-train-change-bg/utils/__pycache__/id_like_loss.cpython-311.pyc +0 -0
  41. ID-like-train-change-bg/utils/__pycache__/id_like_utils.cpython-311.pyc +0 -0
  42. ID-like-train-change-bg/utils/__pycache__/imagenet_templates.cpython-311.pyc +0 -0
  43. ID-like-train-change-bg/utils/__pycache__/plot_util.cpython-311.pyc +0 -0
  44. ID-like-train-change-bg/utils/__pycache__/plot_util.cpython-37.pyc +0 -0
  45. ID-like-train-change-bg/utils/common.py +164 -0
  46. ID-like-train-change-bg/utils/dataloaders_utils.py +462 -0
  47. ID-like-train-change-bg/utils/file_ops.py +68 -0
  48. ID-like-train-change-bg/utils/id_like.py +184 -0
  49. ID-like-train-change-bg/utils/id_like_loss.py +52 -0
  50. ID-like-train-change-bg/utils/id_like_utils.py +298 -0
.gitattributes CHANGED
@@ -33,3 +33,38 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train-change-bg/figure/framework.png filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train-change-bg/figure/visualization_examples.png filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train-change-bg/output/shot_1/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train-change-bg/output/shot_10/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train-change-bg/output/shot_100000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train-change-bg/output/shot_30000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train-change-bg/output/shot_5/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train-change-bg/output/shot_60000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train-change-bg/output/shot_70000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train-change-bg/output/shot_80000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train-change-bg/output/shot_90000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/figure/framework.png filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/figure/visualization_examples.png filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/output/shot_1/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/output/shot_10/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/output/shot_10000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/output/shot_100000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/output/shot_20000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/output/shot_30000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/output/shot_40000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/output/shot_5/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/output/shot_50000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/output/shot_60000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/output/shot_70000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/output/shot_80000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+LoCoOp-train/output/shot_90000/prompt_learner/model.pth.tar-50 filter=lfs diff=lfs merge=lfs -text
+zest_code/demo_assets/depths/n02824058_184.png filter=lfs diff=lfs merge=lfs -text
+zest_code/demo_assets/input_imgs/n02824058_184.png filter=lfs diff=lfs merge=lfs -text
+zest_code/demo_assets/material_exemplars/101001.png filter=lfs diff=lfs merge=lfs -text
+zest_code/demo_assets/output_images/result.png filter=lfs diff=lfs merge=lfs -text
+zest_code/demo_assets/temp_file/init_img.png filter=lfs diff=lfs merge=lfs -text
+zest_code/error1.txt filter=lfs diff=lfs merge=lfs -text
+zest_code/error2.txt filter=lfs diff=lfs merge=lfs -text
+zest_code/fig/gradio_demo.png filter=lfs diff=lfs merge=lfs -text
+zest_code/fig/method.jpg filter=lfs diff=lfs merge=lfs -text
ID-like-train-change-bg/README.md ADDED
@@ -0,0 +1,3 @@
+# ID-like Prompt Learning for Few-Shot Out-of-Distribution Detection
+
+This repository contains the code for our CVPR 2024 paper, "ID-like Prompt Learning for Few-Shot Out-of-Distribution Detection". We will gradually improve and extend the code.
ID-like-train-change-bg/__pycache__/config.cpython-311.pyc ADDED
Binary file (26.9 kB)

ID-like-train-change-bg/__pycache__/config.cpython-37.pyc ADDED
Binary file (31.9 kB)

ID-like-train-change-bg/bash_allocation.slurm ADDED
@@ -0,0 +1,15 @@
+#!/bin/bash
+#SBATCH --job-name=zzzz1
+#SBATCH --output=output1.txt
+#SBATCH --error=error1.txt
+#SBATCH --cpus-per-task=5
+#SBATCH --ntasks=4
+#SBATCH --gres=gpu:4
+#SBATCH --mem=100000
+#SBATCH -N 1
+
+
+python batch_file_deal.py
+
+# Cancel the current job to release the node
+scancel $SLURM_JOB_ID
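
On a typical SLURM cluster this script would be submitted with `sbatch bash_allocation.slurm`; per the directives above, stdout and stderr land in output1.txt and error1.txt, and the final `scancel` releases the allocated node as soon as batch_file_deal.py returns.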
ID-like-train-change-bg/batch_file_deal.py ADDED
@@ -0,0 +1,39 @@
+import subprocess
+import os
+
+# Set the PYTHONPATH environment variable
+pythonpath = '.'
+if 'PYTHONPATH' in os.environ:
+    pythonpath += ':' + os.environ['PYTHONPATH']
+os.environ['PYTHONPATH'] = pythonpath
+
+ROOT = "/home/zhourixin/OOD_Folder/CODE/other_methods/ID-like-train-change-bg"
+
+
+run_file = ROOT + "/eval_ood_detection.py"
+
+# subprocess.run(["python", run_file, "--n_shot=1", "--batch_size=1"])
+# subprocess.run(["python", run_file, "--n_shot=5", "--batch_size=1"])
+# subprocess.run(["python", run_file, "--n_shot=10", "--batch_size=1"])
+
+# subprocess.run(["python", run_file, "--n_shot=10000"])
+
+# subprocess.run(["python", run_file, "--n_shot=20000"])
+
+# subprocess.run(["python", run_file, "--n_shot=30000"])
+
+# subprocess.run(["python", run_file, "--n_shot=40000"])
+
+# subprocess.run(["python", run_file, "--n_shot=50000"])
+
+# subprocess.run(["python", run_file, "--n_shot=60000"])
+
+# subprocess.run(["python", run_file, "--n_shot=70000"])
+
+# subprocess.run(["python", run_file, "--n_shot=80000"])
+
+# subprocess.run(["python", run_file, "--n_shot=90000"])
+
+subprocess.run(["python", run_file, "--n_shot=100000"])
+
+
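
The file above enables one sweep value at a time by toggling comments. A minimal sketch of the same launcher written as a loop, assuming eval_ood_detection.py accepts the --n_shot and --batch_size flags exactly as used above:

import subprocess

ROOT = "/home/zhourixin/OOD_Folder/CODE/other_methods/ID-like-train-change-bg"
run_file = ROOT + "/eval_ood_detection.py"

# n_shot values taken from the calls above; the small-shot runs also pin batch_size=1 there.
for n_shot in [1, 5, 10, 10000, 20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000, 100000]:
    args = ["python", run_file, f"--n_shot={n_shot}"]
    if n_shot <= 10:
        args.append("--batch_size=1")
    subprocess.run(args, check=True)  # check=True aborts the sweep if a run fails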
ID-like-train-change-bg/clip checkpoint path/ViT-B-16.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e213bf161ab676ffde04a98d171217cce89419f17ec4b3fd69552102861c01ca
+size 13434880
ID-like-train-change-bg/clip/__init__.py ADDED
@@ -0,0 +1 @@
+from .clip import *
ID-like-train-change-bg/clip/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (236 Bytes)

ID-like-train-change-bg/clip/__pycache__/clip.cpython-311.pyc ADDED
Binary file (15.2 kB)

ID-like-train-change-bg/clip/__pycache__/model.cpython-311.pyc ADDED
Binary file (31.9 kB)

ID-like-train-change-bg/clip/__pycache__/simple_tokenizer.cpython-311.pyc ADDED
Binary file (11.1 kB)

ID-like-train-change-bg/clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
ID-like-train-change-bg/clip/clip.py ADDED
@@ -0,0 +1,232 @@
+import hashlib
+import os
+import urllib
+import warnings
+from typing import Any, Union, List
+# from pkg_resources import packaging
+
+import torch
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+from tqdm import tqdm
+
+from .model import build_model
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+try:
+    from torchvision.transforms import InterpolationMode
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
+
+# if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
+#     warnings.warn("PyTorch version 1.7.1 or higher is recommended")
+
+
+__all__ = ["available_models", "load", "tokenize"]
+_tokenizer = _Tokenizer()
+
+_MODELS = {
+    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
+    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+}
+
+
+def _download(url: str, root: str):
+    os.makedirs(root, exist_ok=True)
+    filename = os.path.basename(url)
+
+    expected_sha256 = url.split("/")[-2]
+    download_target = os.path.join(root, filename)
+
+    if os.path.exists(download_target) and not os.path.isfile(download_target):
+        raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+    if os.path.isfile(download_target):
+        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+            return download_target
+        else:
+            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
+
+    return download_target
+
+
+def _convert_image_to_rgb(image):
+    return image.convert("RGB")
+
+
+def _transform(n_px):
+    return Compose([
+        Resize(n_px, interpolation=BICUBIC),
+        CenterCrop(n_px),
+        _convert_image_to_rgb,
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+    ])
+
+
+def available_models() -> List[str]:
+    """Returns the names of available CLIP models"""
+    return list(_MODELS.keys())
+
+
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
+    """Load a CLIP model
+
+    Parameters
+    ----------
+    name : str
+        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+
+    device : Union[str, torch.device]
+        The device to put the loaded model
+
+    jit : bool
+        Whether to load the optimized JIT model or the more hackable non-JIT model (default).
+
+    download_root: str
+        path to download the model files; by default, it uses "~/.cache/clip"
+
+    Returns
+    -------
+    model : torch.nn.Module
+        The CLIP model
+
+    preprocess : Callable[[PIL.Image], torch.Tensor]
+        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+    """
+    if name in _MODELS:
+        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
+    elif os.path.isfile(name):
+        model_path = name
+    else:
+        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+    try:
+        # loading JIT archive
+        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+        state_dict = None
+    except RuntimeError:
+        # loading saved state dict
+        if jit:
+            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+            jit = False
+        state_dict = torch.load(model_path, map_location="cpu")
+
+    embed_dim = (state_dict or model.state_dict())["text_projection"].shape[1]  # works for JIT archives and plain state dicts
+    if not jit:
+        model = build_model(state_dict or model.state_dict()).to(device)
+        if str(device) == "cpu":
+            model.float()
+        return model, embed_dim, _transform(model.visual.input_resolution)
+
+    # patch the device names
+    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+    def patch_device(module):
+        try:
+            graphs = [module.graph] if hasattr(module, "graph") else []
+        except RuntimeError:
+            graphs = []
+
+        if hasattr(module, "forward1"):
+            graphs.append(module.forward1.graph)
+
+        for graph in graphs:
+            for node in graph.findAllNodes("prim::Constant"):
+                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+                    node.copyAttributes(device_node)
+
+    model.apply(patch_device)
+    patch_device(model.encode_image)
+    patch_device(model.encode_text)
+
+    # patch dtype to float32 on CPU
+    if str(device) == "cpu":
+        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+        float_node = float_input.node()
+
+        def patch_float(module):
+            try:
+                graphs = [module.graph] if hasattr(module, "graph") else []
+            except RuntimeError:
+                graphs = []
+
+            if hasattr(module, "forward1"):
+                graphs.append(module.forward1.graph)
+
+            for graph in graphs:
+                for node in graph.findAllNodes("aten::to"):
+                    inputs = list(node.inputs())
+                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
+                        if inputs[i].node()["value"] == 5:
+                            inputs[i].node().copyAttributes(float_node)
+
+        model.apply(patch_float)
+        patch_float(model.encode_image)
+        patch_float(model.encode_text)
+
+        model.float()
+
+    return model, embed_dim, _transform(model.input_resolution.item())
+
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
+    """
+    Returns the tokenized representation of given input string(s)
+
+    Parameters
+    ----------
+    texts : Union[str, List[str]]
+        An input string or a list of input strings to tokenize
+
+    context_length : int
+        The context length to use; all CLIP models use 77 as the context length
+
+    truncate: bool
+        Whether to truncate the text in case its encoding is longer than the context length
+
+    Returns
+    -------
+    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+    """
+    if isinstance(texts, str):
+        texts = [texts]
+
+    sot_token = _tokenizer.encoder["<|startoftext|>"]
+    eot_token = _tokenizer.encoder["<|endoftext|>"]
+    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+    for i, tokens in enumerate(all_tokens):
+        if len(tokens) > context_length:
+            if truncate:
+                tokens = tokens[:context_length]
+                tokens[-1] = eot_token
+            else:
+                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+        result[i, :len(tokens)] = torch.tensor(tokens)
+
+    return result
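
A minimal usage sketch for this fork's loader. Note that, unlike upstream OpenAI CLIP, load() here returns three values (model, embedding dimension, and preprocessing transform); the image path below is a placeholder:

import torch
from PIL import Image
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, embed_dim, preprocess = clip.load("ViT-B/16", device=device)

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder path
text = clip.tokenize(["a photo of a cat", "a photo of a dog"]).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1)  # shape [1, 2], probabilities over the two prompts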
ID-like-train-change-bg/clip/model.py ADDED
@@ -0,0 +1,438 @@
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1):
+        super().__init__()
+
+        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu3 = nn.ReLU(inplace=True)
+
+        self.downsample = None
+        self.stride = stride
+
+        if stride > 1 or inplanes != planes * Bottleneck.expansion:
+            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+            self.downsample = nn.Sequential(OrderedDict([
+                ("-1", nn.AvgPool2d(stride)),
+                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+                ("1", nn.BatchNorm2d(planes * self.expansion))
+            ]))
+
+    def forward(self, x: torch.Tensor):
+        identity = x
+
+        out = self.relu1(self.bn1(self.conv1(x)))
+        out = self.relu2(self.bn2(self.conv2(out)))
+        out = self.avgpool(out)
+        out = self.bn3(self.conv3(out))
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu3(out)
+        return out
+
+
+class AttentionPool2d(nn.Module):
+    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x[:1], key=x, value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False
+        )
+        return x.squeeze(0)
+
+
+class ModifiedResNet(nn.Module):
+    """
+    A ResNet class that is similar to torchvision's but contains the following changes:
+    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+    - The final pooling layer is a QKV attention instead of an average pool
+    """
+
+    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+        super().__init__()
+        self.output_dim = output_dim
+        self.input_resolution = input_resolution
+
+        # the 3-layer stem
+        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(width // 2)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(width // 2)
+        self.relu2 = nn.ReLU(inplace=True)
+        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(width)
+        self.relu3 = nn.ReLU(inplace=True)
+        self.avgpool = nn.AvgPool2d(2)
+
+        # residual layers
+        self._inplanes = width  # this is a *mutable* variable used during construction
+        self.layer1 = self._make_layer(width, layers[0])
+        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+        embed_dim = width * 32  # the ResNet feature dimension
+        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
+
+    def _make_layer(self, planes, blocks, stride=1):
+        layers = [Bottleneck(self._inplanes, planes, stride)]
+
+        self._inplanes = planes * Bottleneck.expansion
+        for _ in range(1, blocks):
+            layers.append(Bottleneck(self._inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        def stem(x):
+            x = self.relu1(self.bn1(self.conv1(x)))
+            x = self.relu2(self.bn2(self.conv2(x)))
+            x = self.relu3(self.bn3(self.conv3(x)))
+            x = self.avgpool(x)
+            return x
+
+        x = x.type(self.conv1.weight.dtype)
+        x = stem(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.attnpool(x)
+
+        return x
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4)),
+            ("gelu", QuickGELU()),
+            ("c_proj", nn.Linear(d_model * 4, d_model))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+    def forward(self, x: torch.Tensor):
+        return self.resblocks(x)
+
+
+class VisionTransformer(nn.Module):
+    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.output_dim = output_dim
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+        scale = width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+        self.ln_pre = LayerNorm(width)
+
+        self.transformer = Transformer(width, layers, heads)
+
+        self.ln_post = LayerNorm(width)
+        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        x = self.ln_post(x[:, 0, :])
+
+        if self.proj is not None:
+            x = x @ self.proj
+
+        return x
+
+
+class CLIP(nn.Module):
+    def __init__(self,
+                 embed_dim: int,
+                 # vision
+                 image_resolution: int,
+                 vision_layers: Union[Tuple[int, int, int, int], int],
+                 vision_width: int,
+                 vision_patch_size: int,
+                 # text
+                 context_length: int,
+                 vocab_size: int,
+                 transformer_width: int,
+                 transformer_heads: int,
+                 transformer_layers: int
+                 ):
+        super().__init__()
+
+        self.context_length = context_length
+
+        if isinstance(vision_layers, (tuple, list)):
+            vision_heads = vision_width * 32 // 64
+            self.visual = ModifiedResNet(
+                layers=vision_layers,
+                output_dim=embed_dim,
+                heads=vision_heads,
+                input_resolution=image_resolution,
+                width=vision_width
+            )
+        else:
+            vision_heads = vision_width // 64
+            self.visual = VisionTransformer(
+                input_resolution=image_resolution,
+                patch_size=vision_patch_size,
+                width=vision_width,
+                layers=vision_layers,
+                heads=vision_heads,
+                output_dim=embed_dim
+            )
+
+        self.transformer = Transformer(
+            width=transformer_width,
+            layers=transformer_layers,
+            heads=transformer_heads,
+            attn_mask=self.build_attention_mask()
+        )
+
+        self.vocab_size = vocab_size
+        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+        self.ln_final = LayerNorm(transformer_width)
+
+        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+        self.initialize_parameters()
+
+    def initialize_parameters(self):
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+        nn.init.normal_(self.positional_embedding, std=0.01)
+
+        if isinstance(self.visual, ModifiedResNet):
+            if self.visual.attnpool is not None:
+                std = self.visual.attnpool.c_proj.in_features ** -0.5
+                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
+                for name, param in resnet_block.named_parameters():
+                    if name.endswith("bn3.weight"):
+                        nn.init.zeros_(param)
+
+        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+        attn_std = self.transformer.width ** -0.5
+        fc_std = (2 * self.transformer.width) ** -0.5
+        for block in self.transformer.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+        if self.text_projection is not None:
+            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+    def build_attention_mask(self):
+        # lazily create causal attention mask, with full attention between the vision tokens
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the lower diagonal
+        return mask
+
+    @property
+    def dtype(self):
+        return self.visual.conv1.weight.dtype
+
+    def encode_image(self, image):
+        return self.visual(image.type(self.dtype))
+
+    def encode_text(self, text):
+        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
+
+        x = x + self.positional_embedding.type(self.dtype)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x).type(self.dtype)
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+        return x
+
+    def forward(self, image, text):
+        image_features = self.encode_image(image)
+        text_features = self.encode_text(text)
+
+        # normalized features
+        image_features = image_features / image_features.norm(dim=1, keepdim=True)
+        text_features = text_features / text_features.norm(dim=1, keepdim=True)
+
+        # cosine similarity as logits
+        logit_scale = self.logit_scale.exp()
+        logits_per_image = logit_scale * image_features @ text_features.t()
+        logits_per_text = logits_per_image.t()
+
+        # shape = [global_batch_size, global_batch_size]
+        return logits_per_image, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+    """Convert applicable model parameters to fp16"""
+
+    def _convert_weights_to_fp16(l):
+        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+            l.weight.data = l.weight.data.half()
+            if l.bias is not None:
+                l.bias.data = l.bias.data.half()
+
+        if isinstance(l, nn.MultiheadAttention):
+            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+                tensor = getattr(l, attr)
+                if tensor is not None:
+                    tensor.data = tensor.data.half()
+
+        for name in ["text_projection", "proj"]:
+            if hasattr(l, name):
+                attr = getattr(l, name)
+                if attr is not None:
+                    attr.data = attr.data.half()
+
+    model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict):
+    vit = "visual.proj" in state_dict
+
+    if vit:
+        vision_width = state_dict["visual.conv1.weight"].shape[0]
+        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+        image_resolution = vision_patch_size * grid_size
+    else:
+        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+        vision_layers = tuple(counts)
+        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+        vision_patch_size = None
+        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+        image_resolution = output_width * 32
+
+    embed_dim = state_dict["text_projection"].shape[1]
+    context_length = state_dict["positional_embedding"].shape[0]
+    vocab_size = state_dict["token_embedding.weight"].shape[0]
+    transformer_width = state_dict["ln_final.weight"].shape[0]
+    transformer_heads = transformer_width // 64
+    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
+
+    model = CLIP(
+        embed_dim,
+        image_resolution, vision_layers, vision_width, vision_patch_size,
+        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
+    )
+
+    for key in ["input_resolution", "context_length", "vocab_size"]:
+        if key in state_dict:
+            del state_dict[key]
+
+    # convert_weights(model)
+    model.load_state_dict(state_dict)
+    del state_dict
+    torch.cuda.empty_cache()
+    return model.eval()
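
Because encode_image and encode_text are independent, text features for a fixed prompt set can be computed once and reused across a whole evaluation run. A minimal sketch, assuming `model` was built via build_model(...), `prompt_tokens` comes from clip.tokenize(...), and `image_batch` was preprocessed with the transform returned by clip.load(...):

import torch

with torch.no_grad():
    txt = model.encode_text(prompt_tokens)            # [N, embed_dim], computed once
    txt = txt / txt.norm(dim=1, keepdim=True)
    img = model.encode_image(image_batch)             # [B, embed_dim], per batch
    img = img / img.norm(dim=1, keepdim=True)
    logits = model.logit_scale.exp() * img @ txt.t()  # same result as model(image, text)[0]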
ID-like-train-change-bg/clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+        merges = merges[1:49152-256-2+1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v+'</w>' for v in vocab]
+        for merge in merges:
+            vocab.append(''.join(merge))
+        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + '</w>',)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token+'</w>'
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word)-1 and word[i+1] == second:
+                    new_word.append(first+second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+        return text
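
A quick round-trip sketch of the tokenizer; decode() rejoins BPE word pieces by replacing the '</w>' marker with a space, so the decoded string carries a trailing space:

from clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()  # reads bpe_simple_vocab_16e6.txt.gz next to the module
ids = tokenizer.encode("a photo of a dog")  # list of integer BPE token ids
roundtrip = tokenizer.decode(ids)           # "a photo of a dog " (note trailing space)
assert roundtrip.strip() == "a photo of a dog"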
ID-like-train-change-bg/config.py ADDED
@@ -0,0 +1,248 @@
+data_path = 'your datasets path'
+
+DOWNLOAD_ROOT = '/home/zhourixin/OOD_Folder/CODE/other_methods/ID-like/clip checkpoint path'
+
+CLIP_ckpt = 'ViT-B/16'
+
+n_ctx = 16
+ctx_init = None
+ctx_position = 'end'
+learned_cls = False
+
+n_ex_ctx = 16
+ex_ctx_init = None
+ex_ctx_position = 'end'
+ex_learned_cls = True
+
+data_info = {
+    'ImageNet': {
+        'n_cls': 1000,
+        'labels': ['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead shark', 'electric ray',
+                   'stingray', 'rooster', 'hen', 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco',
+                   'indigo bunting', 'American robin', 'bulbul', 'jay', 'magpie', 'chickadee', 'American dipper',
+                   'kite (bird of prey)', 'bald eagle', 'vulture', 'great grey owl', 'fire salamander', 'smooth newt',
+                   'eft', 'spotted salamander', 'axolotl', 'American bullfrog', 'tree frog', 'tailed frog',
+                   'loggerhead sea turtle', 'leatherback sea turtle', 'mud turtle', 'terrapin', 'box turtle',
+                   'banded gecko', 'green iguana', 'Carolina anole', 'desert grassland whiptail lizard', 'agama',
+                   'frilled-necked lizard', 'alligator lizard', 'Gila monster', 'European green lizard', 'chameleon',
+                   'Komodo dragon', 'Nile crocodile', 'American alligator', 'triceratops', 'worm snake',
+                   'ring-necked snake', 'eastern hog-nosed snake', 'smooth green snake', 'kingsnake', 'garter snake',
+                   'water snake', 'vine snake', 'night snake', 'boa constrictor', 'African rock python', 'Indian cobra',
+                   'green mamba', 'sea snake', 'Saharan horned viper', 'eastern diamondback rattlesnake',
+                   'sidewinder rattlesnake', 'trilobite', 'harvestman', 'scorpion', 'yellow garden spider',
+                   'barn spider', 'European garden spider', 'southern black widow', 'tarantula', 'wolf spider', 'tick',
+                   'centipede', 'black grouse', 'ptarmigan', 'ruffed grouse', 'prairie grouse', 'peafowl', 'quail',
+                   'partridge', 'african grey parrot', 'macaw', 'sulphur-crested cockatoo', 'lorikeet', 'coucal',
+                   'bee eater', 'hornbill', 'hummingbird', 'jacamar', 'toucan', 'duck', 'red-breasted merganser',
+                   'goose', 'black swan', 'tusker', 'echidna', 'platypus', 'wallaby', 'koala', 'wombat', 'jellyfish',
+                   'sea anemone', 'brain coral', 'flatworm', 'nematode', 'conch', 'snail', 'slug', 'sea slug', 'chiton',
+                   'chambered nautilus', 'Dungeness crab', 'rock crab', 'fiddler crab', 'red king crab',
+                   'American lobster', 'spiny lobster', 'crayfish', 'hermit crab', 'isopod', 'white stork',
+                   'black stork', 'spoonbill', 'flamingo', 'little blue heron', 'great egret', 'bittern bird',
+                   'crane bird', 'limpkin', 'common gallinule', 'American coot', 'bustard', 'ruddy turnstone', 'dunlin',
+                   'common redshank', 'dowitcher', 'oystercatcher', 'pelican', 'king penguin', 'albatross',
+                   'grey whale', 'killer whale', 'dugong', 'sea lion', 'Chihuahua', 'Japanese Chin', 'Maltese',
+                   'Pekingese', 'Shih Tzu', 'King Charles Spaniel', 'Papillon', 'toy terrier', 'Rhodesian Ridgeback',
+                   'Afghan Hound', 'Basset Hound', 'Beagle', 'Bloodhound', 'Bluetick Coonhound',
+                   'Black and Tan Coonhound', 'Treeing Walker Coonhound', 'English foxhound', 'Redbone Coonhound',
+                   'borzoi', 'Irish Wolfhound', 'Italian Greyhound', 'Whippet', 'Ibizan Hound', 'Norwegian Elkhound',
+                   'Otterhound', 'Saluki', 'Scottish Deerhound', 'Weimaraner', 'Staffordshire Bull Terrier',
+                   'American Staffordshire Terrier', 'Bedlington Terrier', 'Border Terrier', 'Kerry Blue Terrier',
+                   'Irish Terrier', 'Norfolk Terrier', 'Norwich Terrier', 'Yorkshire Terrier', 'Wire Fox Terrier',
+                   'Lakeland Terrier', 'Sealyham Terrier', 'Airedale Terrier', 'Cairn Terrier', 'Australian Terrier',
+                   'Dandie Dinmont Terrier', 'Boston Terrier', 'Miniature Schnauzer', 'Giant Schnauzer',
+                   'Standard Schnauzer', 'Scottish Terrier', 'Tibetan Terrier', 'Australian Silky Terrier',
+                   'Soft-coated Wheaten Terrier', 'West Highland White Terrier', 'Lhasa Apso', 'Flat-Coated Retriever',
+                   'Curly-coated Retriever', 'Golden Retriever', 'Labrador Retriever', 'Chesapeake Bay Retriever',
+                   'German Shorthaired Pointer', 'Vizsla', 'English Setter', 'Irish Setter', 'Gordon Setter',
+                   'Brittany dog', 'Clumber Spaniel', 'English Springer Spaniel', 'Welsh Springer Spaniel',
+                   'Cocker Spaniel', 'Sussex Spaniel', 'Irish Water Spaniel', 'Kuvasz', 'Schipperke', 'Groenendael dog',
+                   'Malinois', 'Briard', 'Australian Kelpie', 'Komondor', 'Old English Sheepdog', 'Shetland Sheepdog',
+                   'collie', 'Border Collie', 'Bouvier des Flandres dog', 'Rottweiler', 'German Shepherd Dog',
+                   'Dobermann', 'Miniature Pinscher', 'Greater Swiss Mountain Dog', 'Bernese Mountain Dog',
+                   'Appenzeller Sennenhund', 'Entlebucher Sennenhund', 'Boxer', 'Bullmastiff', 'Tibetan Mastiff',
+                   'French Bulldog', 'Great Dane', 'St. Bernard', 'husky', 'Alaskan Malamute', 'Siberian Husky',
+                   'Dalmatian', 'Affenpinscher', 'Basenji', 'pug', 'Leonberger', 'Newfoundland dog',
+                   'Great Pyrenees dog', 'Samoyed', 'Pomeranian', 'Chow Chow', 'Keeshond', 'brussels griffon',
+                   'Pembroke Welsh Corgi', 'Cardigan Welsh Corgi', 'Toy Poodle', 'Miniature Poodle', 'Standard Poodle',
+                   'Mexican hairless dog (xoloitzcuintli)', 'grey wolf', 'Alaskan tundra wolf',
+                   'red wolf or maned wolf', 'coyote', 'dingo', 'dhole', 'African wild dog', 'hyena', 'red fox',
+                   'kit fox', 'Arctic fox', 'grey fox', 'tabby cat', 'tiger cat', 'Persian cat', 'Siamese cat',
+                   'Egyptian Mau', 'cougar', 'lynx', 'leopard', 'snow leopard', 'jaguar', 'lion', 'tiger', 'cheetah',
+                   'brown bear', 'American black bear', 'polar bear', 'sloth bear', 'mongoose', 'meerkat',
+                   'tiger beetle', 'ladybug', 'ground beetle', 'longhorn beetle', 'leaf beetle', 'dung beetle',
+                   'rhinoceros beetle', 'weevil', 'fly', 'bee', 'ant', 'grasshopper', 'cricket insect', 'stick insect',
+                   'cockroach', 'praying mantis', 'cicada', 'leafhopper', 'lacewing', 'dragonfly', 'damselfly',
+                   'red admiral butterfly', 'ringlet butterfly', 'monarch butterfly', 'small white butterfly',
+                   'sulphur butterfly', 'gossamer-winged butterfly', 'starfish', 'sea urchin', 'sea cucumber',
+                   'cottontail rabbit', 'hare', 'Angora rabbit', 'hamster', 'porcupine', 'fox squirrel', 'marmot',
+                   'beaver', 'guinea pig', 'common sorrel horse', 'zebra', 'pig', 'wild boar', 'warthog',
+                   'hippopotamus', 'ox', 'water buffalo', 'bison', 'ram (adult male sheep)', 'bighorn sheep',
+                   'Alpine ibex', 'hartebeest', 'impala (antelope)', 'gazelle', 'arabian camel', 'llama', 'weasel',
+                   'mink', 'European polecat', 'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo',
+                   'three-toed sloth', 'orangutan', 'gorilla', 'chimpanzee', 'gibbon', 'siamang', 'guenon',
+                   'patas monkey', 'baboon', 'macaque', 'langur', 'black-and-white colobus', 'proboscis monkey',
+                   'marmoset', 'white-headed capuchin', 'howler monkey', 'titi monkey', "Geoffroy's spider monkey",
+                   'common squirrel monkey', 'ring-tailed lemur', 'indri', 'Asian elephant', 'African bush elephant',
+                   'red panda', 'giant panda', 'snoek fish', 'eel', 'silver salmon', 'rock beauty fish', 'clownfish',
+                   'sturgeon', 'gar fish', 'lionfish', 'pufferfish', 'abacus', 'abaya', 'academic gown', 'accordion',
+                   'acoustic guitar', 'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance',
+                   'amphibious vehicle', 'analog clock', 'apiary', 'apron', 'trash can', 'assault rifle', 'backpack',
+                   'bakery', 'balance beam', 'balloon', 'ballpoint pen', 'Band-Aid', 'banjo', 'baluster / handrail',
+                   'barbell', 'barber chair', 'barbershop', 'barn', 'barometer', 'barrel', 'wheelbarrow', 'baseball',
+                   'basketball', 'bassinet', 'bassoon', 'swimming cap', 'bath towel', 'bathtub', 'station wagon',
+                   'lighthouse', 'beaker', 'military hat (bearskin or shako)', 'beer bottle', 'beer glass',
+                   'bell tower', 'baby bib', 'tandem bicycle', 'bikini', 'ring binder', 'binoculars', 'birdhouse',
+                   'boathouse', 'bobsleigh', 'bolo tie', 'poke bonnet', 'bookcase', 'bookstore', 'bottle cap',
+                   'hunting bow', 'bow tie', 'brass memorial plaque', 'bra', 'breakwater', 'breastplate', 'broom',
+                   'bucket', 'buckle', 'bulletproof vest', 'high-speed train', 'butcher shop', 'taxicab', 'cauldron',
+                   'candle', 'cannon', 'canoe', 'can opener', 'cardigan', 'car mirror', 'carousel', 'tool kit',
+                   'cardboard box / carton', 'car wheel', 'automated teller machine', 'cassette', 'cassette player',
+                   'castle', 'catamaran', 'CD player', 'cello', 'mobile phone', 'chain', 'chain-link fence',
+                   'chain mail', 'chainsaw', 'storage chest', 'chiffonier', 'bell or wind chime', 'china cabinet',
+                   'Christmas stocking', 'church', 'movie theater', 'cleaver', 'cliff dwelling', 'cloak', 'clogs',
+                   'cocktail shaker', 'coffee mug', 'coffeemaker', 'spiral or coil', 'combination lock',
+                   'computer keyboard', 'candy store', 'container ship', 'convertible', 'corkscrew', 'cornet',
+                   'cowboy boot', 'cowboy hat', 'cradle', 'construction crane', 'crash helmet', 'crate', 'infant bed',
+                   'Crock Pot', 'croquet ball', 'crutch', 'cuirass', 'dam', 'desk', 'desktop computer',
+                   'rotary dial telephone', 'diaper', 'digital clock', 'digital watch', 'dining table', 'dishcloth',
+                   'dishwasher', 'disc brake', 'dock', 'dog sled', 'dome', 'doormat', 'drilling rig', 'drum',
+                   'drumstick', 'dumbbell', 'Dutch oven', 'electric fan', 'electric guitar', 'electric locomotive',
+                   'entertainment center', 'envelope', 'espresso machine', 'face powder', 'feather boa',
+                   'filing cabinet', 'fireboat', 'fire truck', 'fire screen', 'flagpole', 'flute', 'folding chair',
+                   'football helmet', 'forklift', 'fountain', 'fountain pen', 'four-poster bed', 'freight car',
+                   'French horn', 'frying pan', 'fur coat', 'garbage truck', 'gas mask or respirator', 'gas pump',
+                   'goblet', 'go-kart', 'golf ball', 'golf cart', 'gondola', 'gong', 'gown', 'grand piano',
+                   'greenhouse', 'radiator grille', 'grocery store', 'guillotine', 'hair clip', 'hair spray',
+                   'half-track', 'hammer', 'hamper', 'hair dryer', 'hand-held computer', 'handkerchief',
+                   'hard disk drive', 'harmonica', 'harp', 'combine harvester', 'hatchet', 'holster', 'home theater',
+                   'honeycomb', 'hook', 'hoop skirt', 'gymnastic horizontal bar', 'horse-drawn vehicle', 'hourglass',
+                   'iPod', 'clothes iron', 'carved pumpkin', 'jeans', 'jeep', 'T-shirt', 'jigsaw puzzle', 'rickshaw',
+                   'joystick', 'kimono', 'knee pad', 'knot', 'lab coat', 'ladle', 'lampshade', 'laptop computer',
+                   'lawn mower', 'lens cap', 'letter opener', 'library', 'lifeboat', 'lighter', 'limousine',
+                   'ocean liner', 'lipstick', 'slip-on shoe', 'lotion', 'music speaker', 'loupe magnifying glass',
+                   'sawmill', 'magnetic compass', 'messenger bag', 'mailbox', 'maillot', 'one-piece bathing suit',
+                   'manhole cover', 'maraca', 'marimba', 'mask', 'matchstick', 'maypole', 'maze', 'measuring cup',
+                   'medicine cabinet', 'megalith', 'microphone', 'microwave oven', 'military uniform', 'milk can',
+                   'minibus', 'miniskirt', 'minivan', 'missile', 'mitten', 'mixing bowl', 'mobile home', 'ford model t',
+                   'modem', 'monastery', 'monitor', 'moped', 'mortar and pestle', 'graduation cap', 'mosque',
+                   'mosquito net', 'vespa', 'mountain bike', 'tent', 'computer mouse', 'mousetrap', 'moving van',
+                   'muzzle', 'metal nail', 'neck brace', 'necklace', 'baby pacifier', 'notebook computer', 'obelisk',
+                   'oboe', 'ocarina', 'odometer', 'oil filter', 'pipe organ', 'oscilloscope', 'overskirt',
+                   'bullock cart', 'oxygen mask', 'product packet / packaging', 'paddle', 'paddle wheel', 'padlock',
+                   'paintbrush', 'pajamas', 'palace', 'pan flute', 'paper towel', 'parachute', 'parallel bars',
+                   'park bench', 'parking meter', 'railroad car', 'patio', 'payphone', 'pedestal', 'pencil case',
+                   'pencil sharpener', 'perfume', 'Petri dish', 'photocopier', 'plectrum', 'Pickelhaube',
+                   'picket fence', 'pickup truck', 'pier', 'piggy bank', 'pill bottle', 'pillow', 'ping-pong ball',
+                   'pinwheel', 'pirate ship', 'drink pitcher', 'block plane', 'planetarium', 'plastic bag',
+                   'plate rack', 'farm plow', 'plunger', 'Polaroid camera', 'pole', 'police van', 'poncho',
+                   'pool table', 'soda bottle', 'plant pot', "potter's wheel", 'power drill', 'prayer rug', 'printer',
+                   'prison', 'projectile', 'projector', 'hockey puck', 'punching bag', 'purse', 'quill', 'quilt',
+                   'race car', 'racket', 'radiator', 'radio', 'radio telescope', 'rain barrel', 'recreational vehicle',
+                   'fishing casting reel', 'reflex camera', 'refrigerator', 'remote control', 'restaurant', 'revolver',
+                   'rifle', 'rocking chair', 'rotisserie', 'eraser', 'rugby ball', 'ruler measuring stick', 'sneaker',
+                   'safe', 'safety pin', 'salt shaker', 'sandal', 'sarong', 'saxophone', 'scabbard', 'weighing scale',
+                   'school bus', 'schooner', 'scoreboard', 'CRT monitor', 'screw', 'screwdriver', 'seat belt',
+                   'sewing machine', 'shield', 'shoe store', 'shoji screen / room divider', 'shopping basket',
+                   'shopping cart', 'shovel', 'shower cap', 'shower curtain', 'ski', 'balaclava ski mask',
+                   'sleeping bag', 'slide rule', 'sliding door', 'slot machine', 'snorkel', 'snowmobile', 'snowplow',
+                   'soap dispenser', 'soccer ball', 'sock', 'solar thermal collector', 'sombrero', 'soup bowl',
+                   'keyboard space bar', 'space heater', 'space shuttle', 'spatula', 'motorboat', 'spider web',
+                   'spindle', 'sports car', 'spotlight', 'stage', 'steam locomotive', 'through arch bridge',
+                   'steel drum', 'stethoscope', 'scarf', 'stone wall', 'stopwatch', 'stove', 'strainer', 'tram',
+                   'stretcher', 'couch', 'stupa', 'submarine', 'suit', 'sundial', 'sunglass', 'sunglasses', 'sunscreen',
+                   'suspension bridge', 'mop', 'sweatshirt', 'swim trunks / shorts', 'swing', 'electrical switch',
+                   'syringe', 'table lamp', 'tank', 'tape player', 'teapot', 'teddy bear', 'television', 'tennis ball',
+                   'thatched roof', 'front curtain', 'thimble', 'threshing machine', 'throne', 'tile roof', 'toaster',
+                   'tobacco shop', 'toilet seat', 'torch', 'totem pole', 'tow truck', 'toy store', 'tractor',
+                   'semi-trailer truck', 'tray', 'trench coat', 'tricycle', 'trimaran', 'tripod', 'triumphal arch',
+                   'trolleybus', 'trombone', 'hot tub', 'turnstile', 'typewriter keyboard', 'umbrella', 'unicycle',
+                   'upright piano', 'vacuum cleaner', 'vase', 'vaulted or arched ceiling', 'velvet fabric',
+                   'vending machine', 'vestment', 'viaduct', 'violin', 'volleyball', 'waffle iron', 'wall clock',
+                   'wallet', 'wardrobe', 'military aircraft', 'sink', 'washing machine', 'water bottle', 'water jug',
+                   'water tower', 'whiskey jug', 'whistle', 'hair wig', 'window screen', 'window shade', 'Windsor tie',
+                   'wine bottle', 'airplane wing', 'wok', 'wooden spoon', 'wool', 'split-rail fence', 'shipwreck',
+                   'sailboat', 'yurt', 'website', 'comic book', 'crossword', 'traffic or street sign', 'traffic light',
+                   'dust jacket', 'menu', 'plate', 'guacamole', 'consomme', 'hot pot', 'trifle', 'ice cream',
+                   'popsicle', 'baguette', 'bagel', 'pretzel', 'cheeseburger', 'hot dog', 'mashed potatoes', 'cabbage',
+                   'broccoli', 'cauliflower', 'zucchini', 'spaghetti squash', 'acorn squash', 'butternut squash',
+                   'cucumber', 'artichoke', 'bell pepper', 'cardoon', 'mushroom', 'Granny Smith apple', 'strawberry',
+                   'orange', 'lemon', 'fig', 'pineapple', 'banana', 'jackfruit', 'cherimoya (custard apple)',
+                   'pomegranate', 'hay', 'carbonara', 'chocolate syrup', 'dough', 'meatloaf', 'pizza', 'pot pie',
+                   'burrito', 'red wine', 'espresso', 'tea cup', 'eggnog', 'mountain', 'bubble', 'cliff', 'coral reef',
+                   'geyser', 'lakeshore', 'promontory', 'sandbar', 'beach', 'valley', 'volcano', 'baseball player',
+                   'bridegroom', 'scuba diver', 'rapeseed', 'daisy', "yellow lady's slipper", 'corn', 'acorn',
+                   'rose hip', 'horse chestnut seed', 'coral fungus', 'agaric', 'gyromitra', 'stinkhorn mushroom',
+                   'earth star fungus', 'hen of the woods mushroom', 'bolete', 'corn cob', 'toilet paper'],
+    },
+    'ImageNet100': {
+        'n_cls': 100,
+        'labels': ['stingray', 'ostrich', 'jay', 'American dipper', 'spotted salamander', 'alligator lizard',
+                   'Komodo dragon', 'wolf spider', 'african grey parrot', 'jacamar', 'red-breasted merganser', 'tusker',
+                   'jellyfish', 'brain coral', 'snail', 'white stork', 'dowitcher', 'albatross', 'Beagle', 'Otterhound',
+                   'Lakeland Terrier', 'Giant Schnauzer', 'Cocker Spaniel', 'Australian Kelpie', 'Miniature Pinscher',
+                   'Samoyed', 'Cardigan Welsh Corgi', 'Standard Poodle', 'Egyptian Mau', 'snow leopard', 'jaguar',
+                   'polar bear', 'cockroach', 'hare', 'orangutan', 'gibbon', 'guenon', 'black-and-white colobus',
+                   "Geoffroy's spider monkey", 'bath towel', 'bell tower', 'birdhouse', 'bookstore',
+                   'cardboard box / carton', 'chainsaw', 'chiffonier', 'cornet', 'cradle', 'crate', 'Crock Pot',
+                   'desktop computer', 'rotary dial telephone', 'dog sled', 'electric locomotive', 'flagpole',
+                   'four-poster bed', 'French horn', 'frying pan', 'fur coat', 'gas pump', 'gong', 'greenhouse', 'jeep',
190
+ 'ladle', 'lighter', 'one-piece bathing suit', 'marimba', 'ocarina', 'overskirt', 'palace',
191
+ 'paper towel', 'railroad car', 'pencil sharpener', 'Pickelhaube', 'pier', 'piggy bank', 'pool table',
192
+ 'power drill', 'race car', 'radio', 'rifle', 'sarong', 'schooner', 'sewing machine', 'sliding door',
193
+ 'sunglasses', 'swim trunks / shorts', 'syringe', 'front curtain', 'tow truck', 'trimaran',
194
+ 'wardrobe', 'water tower', 'shipwreck', 'crossword', 'ice cream', 'cabbage', 'promontory',
195
+ 'baseball player', 'hen of the woods mushroom'],
196
+ },
197
+ 'ImageNet10': {
198
+ 'n_cls': 10,
199
+ 'labels': ['brambling', 'American bullfrog', 'Greater Swiss Mountain Dog', 'Siamese cat', 'common sorrel horse',
200
+ 'impala (antelope)', 'container ship', 'garbage truck', 'sports car', 'military aircraft'],
201
+ },
202
+ 'ImageNet20': {
203
+ 'n_cls': 20,
204
+ 'labels': ['smooth newt', 'eft', 'spotted salamander', 'European green lizard', 'Nile crocodile', 'grey wolf',
205
+ 'Arctic fox', 'brown bear', 'starfish', 'zebra', 'balloon', 'high-speed train', 'canoe', 'missile',
206
+ 'moped', 'schooner', 'snowmobile', 'space shuttle', 'steam locomotive', 'tank'],
207
+ },
208
+ 'car196': {
209
+ 'n_cls': 196,
210
+ 'labels': ['AM General Hummer SUV 2000', 'Acura RL Sedan 2012', 'Acura TL Sedan 2012', 'Acura TL Type-S 2008', 'Acura TSX Sedan 2012', 'Acura Integra Type R 2001', 'Acura ZDX Hatchback 2012', 'Aston Martin V8 Vantage Convertible 2012', 'Aston Martin V8 Vantage Coupe 2012', 'Aston Martin Virage Convertible 2012', 'Aston Martin Virage Coupe 2012', 'Audi RS 4 Convertible 2008', 'Audi A5 Coupe 2012', 'Audi TTS Coupe 2012', 'Audi R8 Coupe 2012', 'Audi V8 Sedan 1994', 'Audi 100 Sedan 1994', 'Audi 100 Wagon 1994', 'Audi TT Hatchback 2011', 'Audi S6 Sedan 2011', 'Audi S5 Convertible 2012', 'Audi S5 Coupe 2012', 'Audi S4 Sedan 2012', 'Audi S4 Sedan 2007', 'Audi TT RS Coupe 2012', 'BMW ActiveHybrid 5 Sedan 2012', 'BMW 1 Series Convertible 2012', 'BMW 1 Series Coupe 2012', 'BMW 3 Series Sedan 2012', 'BMW 3 Series Wagon 2012', 'BMW 6 Series Convertible 2007', 'BMW X5 SUV 2007', 'BMW X6 SUV 2012', 'BMW M3 Coupe 2012', 'BMW M5 Sedan 2010', 'BMW M6 Convertible 2010', 'BMW X3 SUV 2012', 'BMW Z4 Convertible 2012', 'Bentley Continental Supersports Conv. Convertible 2012', 'Bentley Arnage Sedan 2009', 'Bentley Mulsanne Sedan 2011', 'Bentley Continental GT Coupe 2012', 'Bentley Continental GT Coupe 2007', 'Bentley Continental Flying Spur Sedan 2007', 'Bugatti Veyron 16.4 Convertible 2009', 'Bugatti Veyron 16.4 Coupe 2009', 'Buick Regal GS 2012', 'Buick Rainier SUV 2007', 'Buick Verano Sedan 2012', 'Buick Enclave SUV 2012', 'Cadillac CTS-V Sedan 2012', 'Cadillac SRX SUV 2012', 'Cadillac Escalade EXT Crew Cab 2007', 'Chevrolet Silverado 1500 Hybrid Crew Cab 2012', 'Chevrolet Corvette Convertible 2012', 'Chevrolet Corvette ZR1 2012', 'Chevrolet Corvette Ron Fellows Edition Z06 2007', 'Chevrolet Traverse SUV 2012', 'Chevrolet Camaro Convertible 2012', 'Chevrolet HHR SS 2010', 'Chevrolet Impala Sedan 2007', 'Chevrolet Tahoe Hybrid SUV 2012', 'Chevrolet Sonic Sedan 2012', 'Chevrolet Express Cargo Van 2007', 'Chevrolet Avalanche Crew Cab 2012', 'Chevrolet Cobalt SS 2010', 'Chevrolet Malibu Hybrid Sedan 2010', 'Chevrolet TrailBlazer SS 2009', 'Chevrolet Silverado 2500HD Regular Cab 2012', 'Chevrolet Silverado 1500 Classic Extended Cab 2007', 'Chevrolet Express Van 2007', 'Chevrolet Monte Carlo Coupe 2007', 'Chevrolet Malibu Sedan 2007', 'Chevrolet Silverado 1500 Extended Cab 2012', 'Chevrolet Silverado 1500 Regular Cab 2012', 'Chrysler Aspen SUV 2009', 'Chrysler Sebring Convertible 2010', 'Chrysler Town and Country Minivan 2012', 'Chrysler 300 SRT-8 2010', 'Chrysler Crossfire Convertible 2008', 'Chrysler PT Cruiser Convertible 2008', 'Daewoo Nubira Wagon 2002', 'Dodge Caliber Wagon 2012', 'Dodge Caliber Wagon 2007', 'Dodge Caravan Minivan 1997', 'Dodge Ram Pickup 3500 Crew Cab 2010', 'Dodge Ram Pickup 3500 Quad Cab 2009', 'Dodge Sprinter Cargo Van 2009', 'Dodge Journey SUV 2012', 'Dodge Dakota Crew Cab 2010', 'Dodge Dakota Club Cab 2007', 'Dodge Magnum Wagon 2008', 'Dodge Challenger SRT8 2011', 'Dodge Durango SUV 2012', 'Dodge Durango SUV 2007', 'Dodge Charger Sedan 2012', 'Dodge Charger SRT-8 2009', 'Eagle Talon Hatchback 1998', 'FIAT 500 Abarth 2012', 'FIAT 500 Convertible 2012', 'Ferrari FF Coupe 2012', 'Ferrari California Convertible 2012', 'Ferrari 458 Italia Convertible 2012', 'Ferrari 458 Italia Coupe 2012', 'Fisker Karma Sedan 2012', 'Ford F-450 Super Duty Crew Cab 2012', 'Ford Mustang Convertible 2007', 'Ford Freestar Minivan 2007', 'Ford Expedition EL SUV 2009', 'Ford Edge SUV 2012', 'Ford Ranger SuperCab 2011', 'Ford GT Coupe 2006', 'Ford F-150 Regular Cab 2012', 'Ford F-150 Regular Cab 2007', 'Ford 
Focus Sedan 2007', 'Ford E-Series Wagon Van 2012', 'Ford Fiesta Sedan 2012', 'GMC Terrain SUV 2012', 'GMC Savana Van 2012', 'GMC Yukon Hybrid SUV 2012', 'GMC Acadia SUV 2012', 'GMC Canyon Extended Cab 2012', 'Geo Metro Convertible 1993', 'HUMMER H3T Crew Cab 2010', 'HUMMER H2 SUT Crew Cab 2009', 'Honda Odyssey Minivan 2012', 'Honda Odyssey Minivan 2007', 'Honda Accord Coupe 2012', 'Honda Accord Sedan 2012', 'Hyundai Veloster Hatchback 2012', 'Hyundai Santa Fe SUV 2012', 'Hyundai Tucson SUV 2012', 'Hyundai Veracruz SUV 2012', 'Hyundai Sonata Hybrid Sedan 2012', 'Hyundai Elantra Sedan 2007', 'Hyundai Accent Sedan 2012', 'Hyundai Genesis Sedan 2012', 'Hyundai Sonata Sedan 2012', 'Hyundai Elantra Touring Hatchback 2012', 'Hyundai Azera Sedan 2012', 'Infiniti G Coupe IPL 2012', 'Infiniti QX56 SUV 2011', 'Isuzu Ascender SUV 2008', 'Jaguar XK XKR 2012', 'Jeep Patriot SUV 2012', 'Jeep Wrangler SUV 2012', 'Jeep Liberty SUV 2012', 'Jeep Grand Cherokee SUV 2012', 'Jeep Compass SUV 2012', 'Lamborghini Reventon Coupe 2008', 'Lamborghini Aventador Coupe 2012', 'Lamborghini Gallardo LP 570-4 Superleggera 2012', 'Lamborghini Diablo Coupe 2001', 'Land Rover Range Rover SUV 2012', 'Land Rover LR2 SUV 2012', 'Lincoln Town Car Sedan 2011', 'MINI Cooper Roadster Convertible 2012', 'Maybach Landaulet Convertible 2012', 'Mazda Tribute SUV 2011', 'McLaren MP4-12C Coupe 2012', 'Mercedes-Benz 300-Class Convertible 1993', 'Mercedes-Benz C-Class Sedan 2012', 'Mercedes-Benz SL-Class Coupe 2009', 'Mercedes-Benz E-Class Sedan 2012', 'Mercedes-Benz S-Class Sedan 2012', 'Mercedes-Benz Sprinter Van 2012', 'Mitsubishi Lancer Sedan 2012', 'Nissan Leaf Hatchback 2012', 'Nissan NV Passenger Van 2012', 'Nissan Juke Hatchback 2012', 'Nissan 240SX Coupe 1998', 'Plymouth Neon Coupe 1999', 'Porsche Panamera Sedan 2012', 'Ram C/V Cargo Van Minivan 2012', 'Rolls-Royce Phantom Drophead Coupe Convertible 2012', 'Rolls-Royce Ghost Sedan 2012', 'Rolls-Royce Phantom Sedan 2012', 'Scion xD Hatchback 2012', 'Spyker C8 Convertible 2009', 'Spyker C8 Coupe 2009', 'Suzuki Aerio Sedan 2007', 'Suzuki Kizashi Sedan 2012', 'Suzuki SX4 Hatchback 2012', 'Suzuki SX4 Sedan 2012', 'Tesla Model S Sedan 2012', 'Toyota Sequoia SUV 2012', 'Toyota Camry Sedan 2012', 'Toyota Corolla Sedan 2012', 'Toyota 4Runner SUV 2012', 'Volkswagen Golf Hatchback 2012', 'Volkswagen Golf Hatchback 1991', 'Volkswagen Beetle Hatchback 2012', 'Volvo C30 Hatchback 2012', 'Volvo 240 Sedan 1993', 'Volvo XC90 SUV 2007', 'smart fortwo Convertible 2012'],
211
+ },
212
+ 'food101': {
213
+ 'n_cls': 101,
214
+ 'labels': ['Apple pie', 'Baby back ribs', 'Baklava', 'Beef carpaccio', 'Beef tartare', 'Beet salad', 'Beignets', 'Bibimbap', 'Bread pudding', 'Breakfast burrito', 'Bruschetta', 'Caesar salad', 'Cannoli', 'Caprese salad', 'Carrot cake', 'Ceviche', 'Cheesecake', 'Cheese plate', 'Chicken curry', 'Chicken quesadilla', 'Chicken wings', 'Chocolate cake', 'Chocolate mousse', 'Churros', 'Clam chowder', 'Club sandwich', 'Crab cakes', 'Creme brulee', 'Croque madame', 'Cup cakes', 'Deviled eggs', 'Donuts', 'Dumplings', 'Edamame', 'Eggs benedict', 'Escargots', 'Falafel', 'Filet mignon', 'Fish and chips', 'Foie gras', 'French fries', 'French onion soup', 'French toast', 'Fried calamari', 'Fried rice', 'Frozen yogurt', 'Garlic bread', 'Gnocchi', 'Greek salad', 'Grilled cheese sandwich', 'Grilled salmon', 'Guacamole', 'Gyoza', 'Hamburger', 'Hot and sour soup', 'Hot dog', 'Huevos rancheros', 'Hummus', 'Ice cream', 'Lasagna', 'Lobster bisque', 'Lobster roll sandwich', 'Macaroni and cheese', 'Macarons', 'Miso soup', 'Mussels', 'Nachos', 'Omelette', 'Onion rings', 'Oysters', 'Pad thai', 'Paella', 'Pancakes', 'Panna cotta', 'Peking duck', 'Pho', 'Pizza', 'Pork chop', 'Poutine', 'Prime rib', 'Pulled pork sandwich', 'Ramen', 'Ravioli', 'Red velvet cake', 'Risotto', 'Samosa', 'Sashimi', 'Scallops', 'Seaweed salad', 'Shrimp and grits', 'Spaghetti bolognese', 'Spaghetti carbonara', 'Spring rolls', 'Steak', 'Strawberry shortcake', 'Sushi', 'Tacos', 'Takoyaki', 'Tiramisu', 'Tuna tartare', 'Waffles'],
215
+ },
216
+ 'pet37': {
217
+ 'n_cls': 37,
218
+ 'labels': ['Abyssinian', 'American Bulldog', 'American Pit Bull Terrier', 'Basset Hound', 'Beagle', 'Bengal', 'Birman', 'Bombay', 'Boxer', 'British Shorthair', 'Chihuahua', 'Egyptian Mau', 'English Cocker Spaniel', 'English Setter', 'German Shorthaired', 'Great Pyrenees', 'Havanese', 'Japanese Chin', 'Keeshond', 'Leonberger', 'Maine Coon', 'Miniature Pinscher', 'Newfoundland', 'Persian', 'Pomeranian', 'Pug', 'Ragdoll', 'Russian Blue', 'Saint Bernard', 'Samoyed', 'Scottish Terrier', 'Shiba Inu', 'Siamese', 'Sphynx', 'Staffordshire Bull Terrier', 'Wheaten Terrier', 'Yorkshire Terrier'],
219
+ },
220
+ 'bird200': {
221
+ 'n_cls': 200,
222
+ 'labels': ['Black footed Albatross', 'Laysan Albatross', 'Sooty Albatross', 'Groove billed Ani', 'Crested Auklet', 'Least Auklet', 'Parakeet Auklet', 'Rhinoceros Auklet', 'Brewer Blackbird', 'Red winged Blackbird', 'Rusty Blackbird', 'Yellow headed Blackbird', 'Bobolink', 'Indigo Bunting', 'Lazuli Bunting', 'Painted Bunting', 'Cardinal', 'Spotted Catbird', 'Gray Catbird', 'Yellow breasted Chat', 'Eastern Towhee', 'Chuck will Widow', 'Brandt Cormorant', 'Red faced Cormorant', 'Pelagic Cormorant', 'Bronzed Cowbird', 'Shiny Cowbird', 'Brown Creeper', 'American Crow', 'Fish Crow', 'Black billed Cuckoo', 'Mangrove Cuckoo', 'Yellow billed Cuckoo', 'Gray crowned Rosy Finch', 'Purple Finch', 'Northern Flicker', 'Acadian Flycatcher', 'Great Crested Flycatcher', 'Least Flycatcher', 'Olive sided Flycatcher', 'Scissor tailed Flycatcher', 'Vermilion Flycatcher', 'Yellow bellied Flycatcher', 'Frigatebird', 'Northern Fulmar', 'Gadwall', 'American Goldfinch', 'European Goldfinch', 'Boat tailed Grackle', 'Eared Grebe', 'Horned Grebe', 'Pied billed Grebe', 'Western Grebe', 'Blue Grosbeak', 'Evening Grosbeak', 'Pine Grosbeak', 'Rose breasted Grosbeak', 'Pigeon Guillemot', 'California Gull', 'Glaucous winged Gull', 'Heermann Gull', 'Herring Gull', 'Ivory Gull', 'Ring billed Gull', 'Slaty backed Gull', 'Western Gull', 'Anna Hummingbird', 'Ruby throated Hummingbird', 'Rufous Hummingbird', 'Green Violetear', 'Long tailed Jaeger', 'Pomarine Jaeger', 'Blue Jay', 'Florida Jay', 'Green Jay', 'Dark eyed Junco', 'Tropical Kingbird', 'Gray Kingbird', 'Belted Kingfisher', 'Green Kingfisher', 'Pied Kingfisher', 'Ringed Kingfisher', 'White breasted Kingfisher', 'Red legged Kittiwake', 'Horned Lark', 'Pacific Loon', 'Mallard', 'Western Meadowlark', 'Hooded Merganser', 'Red breasted Merganser', 'Mockingbird', 'Nighthawk', 'Clark Nutcracker', 'White breasted Nuthatch', 'Baltimore Oriole', 'Hooded Oriole', 'Orchard Oriole', 'Scott Oriole', 'Ovenbird', 'Brown Pelican', 'White Pelican', 'Western Wood Pewee', 'Sayornis', 'American Pipit', 'Whip poor Will', 'Horned Puffin', 'Common Raven', 'White necked Raven', 'American Redstart', 'Geococcyx', 'Loggerhead Shrike', 'Great Grey Shrike', 'Baird Sparrow', 'Black throated Sparrow', 'Brewer Sparrow', 'Chipping Sparrow', 'Clay colored Sparrow', 'House Sparrow', 'Field Sparrow', 'Fox Sparrow', 'Grasshopper Sparrow', 'Harris Sparrow', 'Henslow Sparrow', 'Le Conte Sparrow', 'Lincoln Sparrow', 'Nelson Sharp tailed Sparrow', 'Savannah Sparrow', 'Seaside Sparrow', 'Song Sparrow', 'Tree Sparrow', 'Vesper Sparrow', 'White crowned Sparrow', 'White throated Sparrow', 'Cape Glossy Starling', 'Bank Swallow', 'Barn Swallow', 'Cliff Swallow', 'Tree Swallow', 'Scarlet Tanager', 'Summer Tanager', 'Artic Tern', 'Black Tern', 'Caspian Tern', 'Common Tern', 'Elegant Tern', 'Forsters Tern', 'Least Tern', 'Green tailed Towhee', 'Brown Thrasher', 'Sage Thrasher', 'Black capped Vireo', 'Blue headed Vireo', 'Philadelphia Vireo', 'Red eyed Vireo', 'Warbling Vireo', 'White eyed Vireo', 'Yellow throated Vireo', 'Bay breasted Warbler', 'Black and white Warbler', 'Black throated Blue Warbler', 'Blue winged Warbler', 'Canada Warbler', 'Cape May Warbler', 'Cerulean Warbler', 'Chestnut sided Warbler', 'Golden winged Warbler', 'Hooded Warbler', 'Kentucky Warbler', 'Magnolia Warbler', 'Mourning Warbler', 'Myrtle Warbler', 'Nashville Warbler', 'Orange crowned Warbler', 'Palm Warbler', 'Pine Warbler', 'Prairie Warbler', 'Prothonotary Warbler', 'Swainson Warbler', 'Tennessee Warbler', 'Wilson Warbler', 'Worm eating 
Warbler', 'Yellow Warbler', 'Northern Waterthrush', 'Louisiana Waterthrush', 'Bohemian Waxwing', 'Cedar Waxwing', 'American Three toed Woodpecker', 'Pileated Woodpecker', 'Red bellied Woodpecker', 'Red cockaded Woodpecker', 'Red headed Woodpecker', 'Downy Woodpecker', 'Bewick Wren', 'Cactus Wren', 'Carolina Wren', 'House Wren', 'Marsh Wren', 'Rock Wren', 'Winter Wren', 'Common Yellowthroat'],
223
+ },
224
+ 'cifar10': {
225
+ 'n_cls': 10,
226
+ 'labels': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
227
+ },
228
+ 'cifar100': {
229
+ 'n_cls': 100,
230
+ 'labels': ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle',
231
+ 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
232
+ 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
233
+ 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
234
+ 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
235
+ 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
236
+ 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
237
+ 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
238
+ 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor',
239
+ 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'],
240
+ },
241
+ 'bronze2NotLine': {
242
+ 'n_cls': 11,
243
+ 'labels': ['bronze ware of the early Shang age', 'bronze ware of the late Shang age',
244
+ 'bronze ware of the early Western Zhou age', 'bronze ware of the mid Western Zhou age', 'bronze ware of the late Western Zhou age',
245
+ 'bronze ware of the early Spring and Autumn age', 'bronze ware of the mid Spring and Autumn age', 'bronze ware of the late Spring and Autumn age',
246
+ 'bronze ware of the early Warring States age', 'bronze ware of the mid Warring States age', 'bronze ware of the late Warring States age'],
247
+ },
248
+ }
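For reference, this table is consumed by the evaluation entry point (eval_ood_detection.py below), which looks up 'n_cls' and 'labels' by the --in_dataset key. A minimal sketch; only the lookup itself comes from this repo, the printing is illustrative:

    import config

    info = config.data_info['bronze2NotLine']   # any key defined above works
    assert info['n_cls'] == len(info['labels'])
    print(f"{info['n_cls']} classes, e.g. {info['labels'][0]!r}")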
ID-like-train-change-bg/dataloaders/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .pet37 import OxfordIIITPet
+ from .car196 import StanfordCars
+ from .food101 import Food101
+ from .bird200 import Cub2011
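The four re-exports above give the package a flat import surface. A quick smoke test, assuming the datasets already live under a local root (the path below is a placeholder, not one from this repo):

    from dataloaders import OxfordIIITPet, StanfordCars, Food101, Cub2011

    root = "/path/to/datasets"  # placeholder root directory
    pets = OxfordIIITPet(root=root, split="trainval")
    birds = Cub2011(root=root, train=True)
    print(len(pets), len(birds))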
ID-like-train-change-bg/dataloaders/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (440 Bytes)
ID-like-train-change-bg/dataloaders/__pycache__/bird200.cpython-311.pyc ADDED
Binary file (4.55 kB)
ID-like-train-change-bg/dataloaders/__pycache__/car196.cpython-311.pyc ADDED
Binary file (9.21 kB)
ID-like-train-change-bg/dataloaders/__pycache__/food101.cpython-311.pyc ADDED
Binary file (10.1 kB)
ID-like-train-change-bg/dataloaders/__pycache__/pet37.cpython-311.pyc ADDED
Binary file (10.9 kB)
ID-like-train-change-bg/dataloaders/bird200.py ADDED
@@ -0,0 +1,64 @@
+ import os
+ import pandas as pd
+ from torchvision.datasets.folder import default_loader
+ from torchvision.datasets.utils import download_url
+ from torch.utils.data import Dataset
+ import torch
+
+ class Cub2011(Dataset):
+     base_folder = 'CUB_200_2011/images'
+
+     def __init__(self, root, train=True, transform=None, loader=default_loader):
+         self.root = os.path.expanduser(root)
+         self.transform = transform
+         self.loader = loader  # honor the caller-supplied loader
+         self.train = train
+
+         self._load_metadata()
+
+     def _load_metadata(self):
+         images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
+                              names=['img_id', 'filepath'])
+         image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
+                                          sep=' ', names=['img_id', 'target'])
+         train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
+                                        sep=' ', names=['img_id', 'is_training_img'])
+
+         data = images.merge(image_class_labels, on='img_id')
+         self.data = data.merge(train_test_split, on='img_id')
+
+         if self.train:
+             self.data = self.data[self.data.is_training_img == 1]
+         else:
+             self.data = self.data[self.data.is_training_img == 0]
+
+         class_names = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'classes.txt'),
+                                   sep=' ', names=['class_id', 'target'])
+         self.class_names_str = [name.split(".")[1].replace('_', ' ') for name in class_names.target]
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         sample = self.data.iloc[idx]
+         path = os.path.join(self.root, self.base_folder, sample.filepath)
+         target = sample.target - 1  # targets start at 1 by default, so shift to 0
+         img = self.loader(path)
+
+         if self.transform is not None:
+             img = self.transform(img)
+
+         return img, target
+
+
+ if __name__ == "__main__":
+     train_set = Cub2011(root="/nobackup/dataset_myf", train=True)
+     val_set = Cub2011(root="/nobackup/dataset_myf", train=False)
+
+     kwargs = {'num_workers': 4, 'pin_memory': True}
+     train_loader = torch.utils.data.DataLoader(train_set,
+                                                batch_size=16, shuffle=True, **kwargs)
+     val_loader = torch.utils.data.DataLoader(val_set,
+                                              batch_size=16, shuffle=False, **kwargs)
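Since __getitem__ shifts CUB targets to 0-based, they can index class_names_str directly; a small sketch under the same placeholder-root assumption as before:

    ds = Cub2011(root="/path/to/datasets", train=False)  # placeholder root
    img, target = ds[0]
    print(target, ds.class_names_str[target])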
ID-like-train-change-bg/dataloaders/car196.py ADDED
@@ -0,0 +1,149 @@
+ import pathlib
+ from typing import Callable, Optional, Any, Tuple
+
+ from PIL import Image
+ import torch
+
+ from torchvision.datasets.utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+ from torchvision.datasets.vision import VisionDataset
+
+
+ class StanfordCars(VisionDataset):
+     """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset
+
+     The Cars dataset contains 16,185 images of 196 classes of cars. The data is
+     split into 8,144 training images and 8,041 testing images, where each class
+     has been split roughly 50-50.
+
+     .. note::
+
+         This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+     Args:
+         root (string): Root directory of dataset.
+         split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
+         transform (callable, optional): A function/transform that takes in a PIL image
+             and returns a transformed version. E.g, ``transforms.RandomCrop``.
+         target_transform (callable, optional): A function/transform that takes in the
+             target and transforms it.
+         download (bool, optional): If True, downloads the dataset from the internet and
+             puts it in the root directory. If the dataset is already downloaded, it is not
+             downloaded again."""
+
+     def __init__(
+         self,
+         root: str,
+         split: str = "train",
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         download: bool = False,
+     ) -> None:
+
+         try:
+             import scipy.io as sio
+         except ImportError:
+             raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+         super().__init__(root, transform=transform, target_transform=target_transform)
+
+         self._split = verify_str_arg(split, "split", ("train", "test"))
+         self._base_folder = pathlib.Path(root) / "stanford_cars"
+         devkit = self._base_folder / "devkit"
+
+         if self._split == "train":
+             self._annotations_mat_path = devkit / "cars_train_annos.mat"
+             self._images_base_path = self._base_folder / "cars_train"
+         else:
+             self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
+             self._images_base_path = self._base_folder / "cars_test"
+
+         if download:
+             self.download()
+
+         if not self._check_exists():
+             raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+         self._samples = [
+             (
+                 str(self._images_base_path / annotation["fname"]),
+                 annotation["class"] - 1,  # original target mapping starts from 1, hence -1
+             )
+             for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
+         ]
+
+         self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
+         self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
+
+         self.class_names_str = self.classes
+
+     def __len__(self) -> int:
+         return len(self._samples)
+
+     def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+         """Returns pil_image and class_id for a given index"""
+         image_path, target = self._samples[idx]
+         pil_image = Image.open(image_path).convert("RGB")
+
+         if self.transform is not None:
+             pil_image = self.transform(pil_image)
+         if self.target_transform is not None:
+             target = self.target_transform(target)
+         return pil_image, target
+
+     def download(self) -> None:
+         if self._check_exists():
+             return
+
+         download_and_extract_archive(
+             url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
+             download_root=str(self._base_folder),
+             md5="c3b158d763b6e2245038c8ad08e45376",
+         )
+         if self._split == "train":
+             download_and_extract_archive(
+                 url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
+                 download_root=str(self._base_folder),
+                 md5="065e5b463ae28d29e77c1b4b166cfe61",
+             )
+         else:
+             download_and_extract_archive(
+                 url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
+                 download_root=str(self._base_folder),
+                 md5="4ce7ebf6a94d07f1952d94dd34c4d501",
+             )
+             download_url(
+                 url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
+                 root=str(self._base_folder),
+                 md5="b0a2b23655a3edd16d84508592a98d10",
+             )
+
+     def _check_exists(self) -> bool:
+         if not (self._base_folder / "devkit").is_dir():
+             return False
+
+         return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
+
+
+ def examine_count(counter, name="train"):
+     print(f"in the {name} set")
+     for label in counter:
+         print(label, counter[label])
+
+
+ if __name__ == "__main__":
+     train_set = StanfordCars(root="/nobackup/dataset_myf", split="train", download=True)
+     test_set = StanfordCars(root="/nobackup/dataset_myf", split="test", download=True)
+     print(f"train set len {len(train_set)}")
+     print(f"test set len {len(test_set)}")
+     from collections import Counter
+     train_label_count = Counter([label for img, label in train_set._samples])
+     test_label_count = Counter([label for img, label in test_set._samples])
+     examine_count(train_label_count, name="train")
+     examine_count(test_label_count, name="test")
+
+     kwargs = {'num_workers': 4, 'pin_memory': True}
+     train_loader = torch.utils.data.DataLoader(train_set,
+                                                batch_size=16, shuffle=True, **kwargs)
+     val_loader = torch.utils.data.DataLoader(test_set,
+                                              batch_size=16, shuffle=False, **kwargs)
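Targets here are likewise 0-based indices into classes, with class_to_idx as the inverse map; a sketch under the same placeholder-root assumption:

    ds = StanfordCars(root="/path/to/datasets", split="test")  # placeholder root
    _, target = ds[0]
    name = ds.classes[target]
    assert ds.class_to_idx[name] == target

Note that the ai.stanford.edu URLs hard-coded in download() have been unreliable in recent years, so the archives may need to be obtained elsewhere and unpacked under root/stanford_cars by hand.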
ID-like-train-change-bg/dataloaders/food101.py ADDED
@@ -0,0 +1,123 @@
+ from pathlib import Path
+ import json
+ from typing import Any, Tuple, Callable, Optional
+ import torch
+ import PIL.Image
+
+ from torchvision.datasets.utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+ from torchvision.datasets.vision import VisionDataset
+
+ class Food101(VisionDataset):
+     """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_.
+
+     Food-101 is a challenging data set of 101 food categories, with 101'000 images.
+     For each class, 250 manually reviewed test images are provided as well as 750 training images.
+     On purpose, the training images were not cleaned, and thus still contain some amount of noise.
+     This comes mostly in the form of intense colors and sometimes wrong labels. All images were
+     rescaled to have a maximum side length of 512 pixels.
+
+     Args:
+         root (string): Root directory of the dataset.
+         split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
+         transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+             version. E.g, ``transforms.RandomCrop``.
+         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+         download (bool, optional): If True, downloads the dataset from the internet and
+             puts it in the root directory. If the dataset is already downloaded, it is not
+             downloaded again. Default is False.
+     """
+
+     _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
+     _MD5 = "85eeb15f3717b99a5da872d97d918f87"
+
+     def __init__(
+         self,
+         root: str,
+         split: str = "train",
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         download: bool = False,
+     ) -> None:
+         super().__init__(root, transform=transform, target_transform=target_transform)
+         self._split = verify_str_arg(split, "split", ("train", "test"))
+         self._base_folder = Path(self.root) / "food-101"
+         self._meta_folder = self._base_folder / "meta"
+         self._images_folder = self._base_folder / "images"
+         self.class_names_str = ['Apple pie', 'Baby back ribs', 'Baklava', 'Beef carpaccio', 'Beef tartare', 'Beet salad', 'Beignets', 'Bibimbap', 'Bread pudding', 'Breakfast burrito', 'Bruschetta', 'Caesar salad', 'Cannoli', 'Caprese salad', 'Carrot cake', 'Ceviche', 'Cheesecake', 'Cheese plate', 'Chicken curry', 'Chicken quesadilla', 'Chicken wings', 'Chocolate cake', 'Chocolate mousse', 'Churros', 'Clam chowder', 'Club sandwich', 'Crab cakes', 'Creme brulee', 'Croque madame', 'Cup cakes', 'Deviled eggs', 'Donuts', 'Dumplings', 'Edamame', 'Eggs benedict', 'Escargots', 'Falafel', 'Filet mignon', 'Fish and chips', 'Foie gras', 'French fries', 'French onion soup', 'French toast', 'Fried calamari', 'Fried rice', 'Frozen yogurt', 'Garlic bread', 'Gnocchi', 'Greek salad', 'Grilled cheese sandwich', 'Grilled salmon', 'Guacamole', 'Gyoza', 'Hamburger', 'Hot and sour soup', 'Hot dog', 'Huevos rancheros', 'Hummus', 'Ice cream', 'Lasagna', 'Lobster bisque', 'Lobster roll sandwich', 'Macaroni and cheese', 'Macarons', 'Miso soup', 'Mussels', 'Nachos', 'Omelette', 'Onion rings', 'Oysters', 'Pad thai', 'Paella', 'Pancakes', 'Panna cotta', 'Peking duck', 'Pho', 'Pizza', 'Pork chop', 'Poutine', 'Prime rib', 'Pulled pork sandwich', 'Ramen', 'Ravioli', 'Red velvet cake', 'Risotto', 'Samosa', 'Sashimi', 'Scallops', 'Seaweed salad', 'Shrimp and grits', 'Spaghetti bolognese', 'Spaghetti carbonara', 'Spring rolls', 'Steak', 'Strawberry shortcake', 'Sushi', 'Tacos', 'Takoyaki', 'Tiramisu', 'Tuna tartare', 'Waffles']
+
+         if download:
+             self._download()
+
+         if not self._check_exists():
+             raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+         self._labels = []
+         self._image_files = []
+         with open(self._meta_folder / f"{split}.json") as f:
+             metadata = json.loads(f.read())
+
+         self.classes = sorted(metadata.keys())
+         self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+         for class_label, im_rel_paths in metadata.items():
+             self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
+             self._image_files += [
+                 self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
+             ]
+
+     def __len__(self) -> int:
+         return len(self._image_files)
+
+     def __getitem__(self, idx) -> Tuple[Any, Any]:
+         image_file, label = self._image_files[idx], self._labels[idx]
+         image = PIL.Image.open(image_file).convert("RGB")
+
+         if self.transform:
+             image = self.transform(image)
+
+         if self.target_transform:
+             label = self.target_transform(label)
+
+         return image, label
+
+     def extra_repr(self) -> str:
+         return f"split={self._split}"
+
+     def _check_exists(self) -> bool:
+         return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder))
+
+     def _download(self) -> None:
+         if self._check_exists():
+             return
+         download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
+
+
+ def examine_count(counter, name="train"):
+     print(f"in the {name} set")
+     for label in counter:
+         print(label, counter[label])
+
+
+ if __name__ == "__main__":
+
+     label_names = []
+     with open('debug/food101_labels.txt') as f:
+         for name in f:
+             label_names.append(name.strip())
+     print(label_names)
+
+     train_set = Food101(root="/nobackup/dataset_myf", split="train", download=True)
+     test_set = Food101(root="/nobackup/dataset_myf", split="test")
+     print(f"train set len {len(train_set)}")
+     print(f"test set len {len(test_set)}")
+     from collections import Counter
+     train_label_count = Counter(train_set._labels)
+     test_label_count = Counter(test_set._labels)
+     # examine_count(train_label_count, name="train")
+     # examine_count(test_label_count, name="test")
+
+     kwargs = {'num_workers': 4, 'pin_memory': True}
+     train_loader = torch.utils.data.DataLoader(train_set,
+                                                batch_size=16, shuffle=True, **kwargs)
+     val_loader = torch.utils.data.DataLoader(test_set,
+                                              batch_size=16, shuffle=False, **kwargs)
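The hard-coded class_names_str above is just a prettified copy of the sorted metadata keys; assuming the usual snake_case keys (e.g. 'apple_pie'), the correspondence can be checked like this (placeholder root):

    ds = Food101(root="/path/to/datasets", split="test")  # placeholder root
    pretty = [c.replace("_", " ").capitalize() for c in ds.classes]
    assert pretty == ds.class_names_str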
ID-like-train-change-bg/dataloaders/pet37.py ADDED
@@ -0,0 +1,152 @@
+ import os
+ import os.path
+
+ from typing import Any, Tuple, Callable, Optional, Union, Sequence
+ import torch
+ from PIL import Image
+ import pathlib
+ from torchvision.datasets.utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+ from torchvision.datasets.vision import VisionDataset
+
+
+ class OxfordIIITPet(VisionDataset):
+     """`Oxford-IIIT Pet Dataset <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_.
+
+     Args:
+         root (string): Root directory of the dataset.
+         split (string, optional): The dataset split, supports ``"trainval"`` (default) or ``"test"``.
+         target_types (string, sequence of strings, optional): Types of target to use. Can be ``category`` (default) or
+             ``segmentation``. Can also be a list to output a tuple with all specified target types. The types represent:
+
+                 - ``category`` (int): Label for one of the 37 pet categories.
+                 - ``segmentation`` (PIL image): Segmentation trimap of the image.
+
+             If empty, ``None`` will be returned as target.
+
+         transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+             version. E.g, ``transforms.RandomCrop``.
+         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+         download (bool, optional): If True, downloads the dataset from the internet and puts it into
+             ``root/oxford-iiit-pet``. If the dataset is already downloaded, it is not downloaded again.
+     """
+
+     _RESOURCES = (
+         ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"),
+         ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"),
+     )
+     _VALID_TARGET_TYPES = ("category", "segmentation")
+
+     def __init__(
+         self,
+         root: str,
+         split: str = "trainval",
+         target_types: Union[Sequence[str], str] = "category",
+         transforms: Optional[Callable] = None,
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         download: bool = False,
+     ):
+         self._split = verify_str_arg(split, "split", ("trainval", "test"))
+         if isinstance(target_types, str):
+             target_types = [target_types]
+         self._target_types = [
+             verify_str_arg(target_type, "target_types", self._VALID_TARGET_TYPES) for target_type in target_types
+         ]
+
+         super().__init__(root, transforms=transforms, transform=transform, target_transform=target_transform)
+         self._base_folder = pathlib.Path(self.root) / "oxford-iiit-pet"
+         self._images_folder = self._base_folder / "images"
+         self._anns_folder = self._base_folder / "annotations"
+         self._segs_folder = self._anns_folder / "trimaps"
+
+         if download:
+             self._download()
+
+         if not self._check_exists():
+             raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+         image_ids = []
+         self._labels = []
+         with open(self._anns_folder / f"{self._split}.txt") as file:
+             for line in file:
+                 image_id, label, *_ = line.strip().split()
+                 image_ids.append(image_id)
+                 self._labels.append(int(label) - 1)
+
+         self.classes = [
+             " ".join(part.title() for part in raw_cls.split("_"))
+             for raw_cls, _ in sorted(
+                 {(image_id.rsplit("_", 1)[0], label) for image_id, label in zip(image_ids, self._labels)},
+                 key=lambda image_id_and_label: image_id_and_label[1],
+             )
+         ]
+         self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+         self._images = [self._images_folder / f"{image_id}.jpg" for image_id in image_ids]
+         self._segs = [self._segs_folder / f"{image_id}.png" for image_id in image_ids]
+
+         self.class_names_str = self.classes
+
+     def __len__(self) -> int:
+         return len(self._images)
+
+     def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+         image = Image.open(self._images[idx]).convert("RGB")
+
+         target: Any = []
+         for target_type in self._target_types:
+             if target_type == "category":
+                 target.append(self._labels[idx])
+             else:  # target_type == "segmentation"
+                 target.append(Image.open(self._segs[idx]))
+
+         if not target:
+             target = None
+         elif len(target) == 1:
+             target = target[0]
+         else:
+             target = tuple(target)
+
+         if self.transforms:
+             image, target = self.transforms(image, target)
+
+         return image, target
+
+     def _check_exists(self) -> bool:
+         for folder in (self._images_folder, self._anns_folder):
+             if not (os.path.exists(folder) and os.path.isdir(folder)):
+                 return False
+         else:
+             return True
+
+     def _download(self) -> None:
+         if self._check_exists():
+             return
+
+         for url, md5 in self._RESOURCES:
+             download_and_extract_archive(url, download_root=str(self._base_folder), md5=md5)
+
+
+ def examine_count(counter, name="train"):
+     print(f"in the {name} set")
+     for label in counter:
+         print(label, counter[label])
+
+
+ if __name__ == "__main__":
+
+     train_set = OxfordIIITPet(root="/nobackup/dataset_myf", split="trainval", download=True)
+     test_set = OxfordIIITPet(root="/nobackup/dataset_myf", split="test")
+     print(f"train set len {len(train_set)}")
+     print(f"test set len {len(test_set)}")
+     from collections import Counter
+     train_label_count = Counter(train_set._labels)
+     test_label_count = Counter(test_set._labels)
+     examine_count(train_label_count, name="train")
+     examine_count(test_label_count, name="test")
+
+     kwargs = {'num_workers': 4, 'pin_memory': True}
+     train_loader = torch.utils.data.DataLoader(train_set,
+                                                batch_size=16, shuffle=True, **kwargs)
+     val_loader = torch.utils.data.DataLoader(test_set,
+                                              batch_size=16, shuffle=False, **kwargs)
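With both target types requested, __getitem__ returns the category label and the trimap together, as the docstring describes; a sketch (placeholder root):

    ds = OxfordIIITPet(root="/path/to/datasets", split="test",
                       target_types=["category", "segmentation"])
    img, (label, trimap) = ds[0]
    print(ds.classes[label], trimap.size)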
ID-like-train-change-bg/error1.txt ADDED
The diff for this file is too large to render.
ID-like-train-change-bg/eval_ood_detection.py ADDED
@@ -0,0 +1,123 @@
+ import os
+ import argparse
+ import numpy as np
+ import torch
+ from scipy import stats
+ import config
+ from utils.common import setup_seed, get_and_print_results, print_measures
+ from utils.file_ops import save_as_dataframe, setup_log
+ from utils.plot_util import plot_distribution
+ from utils.dataloaders_utils import set_few_shot_loader, set_val_loader, set_ood_loader_ImageNet
+ from utils.id_like import get_prompts, get_result, load_model
+
+ import pickle
+ import collections
+
+ def process_args():
+     parser = argparse.ArgumentParser(description='Evaluates OOD for CLIP',
+                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+     parser.add_argument('--root-dir', default="/home/zhourixin/OOD_Folder/CODE/other_methods/ID-like-train-change-bg/OODdata", type=str,
+                         help='root dir of datasets')
+     parser.add_argument('--in_dataset', default='bronze2NotLine', type=str, help='in-distribution dataset')
+     parser.add_argument('--seed', default=1, type=int, help="random seed")
+     parser.add_argument('--score', default='id-like', type=str)
+     parser.add_argument('--CLIP_ckpt', type=str, default='ViT-B/16',
+                         choices=['ViT-B/32', 'ViT-B/16', 'ViT-L/14'],
+                         help='which pretrained image encoder to use')
+     parser.add_argument('--n_shot', default=5, type=int,
+                         help="how many samples are used to estimate classwise mean and precision matrix")
+     parser.add_argument('--batch_size', default=10, type=int, help='mini-batch size')
+     parser.add_argument('--test_batch_size', default=512, type=int, help='mini-batch size at test time')
+     parser.add_argument('--n_crop', default=256, type=int, help='number of crops')
+     parser.add_argument('--n_selection', default=32, type=int, help='number of selected crops')
+     # parser.add_argument('--selection_p', default=0.2, type=float, help='confidence selection percentile')
+     parser.add_argument('--n_ex_prompts', default=100, type=int, help='number of extra prompts')
+     parser.add_argument('--n_epoch', default=3, type=int, help='number of epochs')
+     parser.add_argument('--lr', '--learning-rate', default=5e-3, type=float, metavar='LR', help='initial learning rate',
+                         dest='lr')
+     parser.add_argument('--lam_in', default=1.0, type=float, help='lambda of ID loss')
+     parser.add_argument('--lam_out', default=0.3, type=float, help='lambda of OOD loss')
+     parser.add_argument('--lam_diff', default=0.2, type=float, help='lambda of difference')
+
+     args = parser.parse_args()
+
+     args.n_cls = config.data_info[args.in_dataset]['n_cls']
+
+     args.log_directory = f"/home/zhourixin/OOD_Folder/CODE/other_methods/ID-like-train-change-bg/results/{args.in_dataset}/id-like/{args.n_shot}shot/"
+
+     os.makedirs(args.log_directory, exist_ok=True)
+     setup_seed(args.seed)
+     return args
+
+ def update(d, u):
+     for k, v in u.items():
+         if isinstance(v, collections.abc.Mapping):
+             d[k] = update(d.get(k, {}), v)
+         else:
+             d[k] = v
+     return d
+
+ def train():
+     args = process_args()
+
+     log = setup_log(args)
+     # out_datasets = ['ssb_hard', 'ninco']
+     # out_datasets = ['imagenet22k_container', 'ssb_hard', 'ninco',
+     #                 'inaturalist', 'textures', 'openimage_o']
+     out_datasets = ['imagenet22k_container_refine', 'bronzeS_containerM',
+                     'bronzeM_containerS', 'bronze_Line', 'ssb_hard', 'ninco',
+                     'inaturalist', 'textures', 'openimage_o']
+
+     test_labels = config.data_info[args.in_dataset]['labels']
+     ex_labels = ['X'] * args.n_ex_prompts
+
+     model_checkpoint_save_path = os.path.join(args.log_directory, 'model_checkpoint.pth')
+
+     if os.path.exists(model_checkpoint_save_path):
+         model = load_model(args, test_labels, ex_labels)
+     else:
+         few_shot_loader = set_few_shot_loader(args)
+         model = get_prompts(args, few_shot_loader, test_labels, ex_labels)
+
+     score_result_dic = {}
+
+     test_loader = set_val_loader(args)
+     result_in = get_result(args, model, test_loader, test_labels, ex_labels, if_acc=True)
+     score_in = result_in['scores']
+     acc = result_in['acc']
+     log.debug(f"Acc: {acc}")
+
+     update(score_result_dic, {"score_in": score_in})
+
+     auroc_list, aupr_list, fpr_list = [], [], []
+     for out_dataset in out_datasets:
+         log.debug(f"Evaluating OOD dataset {out_dataset}")
+         ood_loader = set_ood_loader_ImageNet(args, out_dataset)
+         result_out = get_result(args, model, ood_loader, test_labels, ex_labels)
+         score_out = result_out['scores']
+         log.debug(f"in scores: {stats.describe(score_in)}")
+         log.debug(f"out scores: {stats.describe(score_out)}")
+         plot_distribution(args, score_in, score_out, out_dataset)
+         get_and_print_results(args, log, score_in, score_out,
+                               auroc_list, aupr_list, fpr_list)
+         update(score_result_dic, {"out_score": {out_dataset: score_out}})
+
+     log.debug('\n\nMean Test Results')
+     print_measures(log, np.mean(auroc_list), np.mean(aupr_list),
+                    np.mean(fpr_list), method_name=args.score)
+     save_as_dataframe(args, out_datasets, fpr_list, auroc_list, aupr_list, acc)
+
+     with open(os.path.join(args.log_directory, 'score.pkl'), 'wb') as f:
+         pickle.dump(score_result_dic, f, pickle.HIGHEST_PROTOCOL)
+
+
+ if __name__ == "__main__":
+     train()
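The update helper above is a recursive dict merge, which is what lets each OOD dataset's scores accumulate under the single "out_score" key instead of overwriting it. A toy illustration:

    d = {"score_in": [0.9]}
    update(d, {"out_score": {"ninco": [0.1]}})
    update(d, {"out_score": {"textures": [0.2]}})
    # d == {'score_in': [0.9], 'out_score': {'ninco': [0.1], 'textures': [0.2]}}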
ID-like-train-change-bg/output1.txt ADDED
The diff for this file is too large to render.
ID-like-train-change-bg/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from __future__ import absolute_import
+ from .common import *
ID-like-train-change-bg/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (310 Bytes)
ID-like-train-change-bg/utils/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (236 Bytes)
ID-like-train-change-bg/utils/__pycache__/common.cpython-311.pyc ADDED
Binary file (11 kB)
ID-like-train-change-bg/utils/__pycache__/common.cpython-37.pyc ADDED
Binary file (5.43 kB)
ID-like-train-change-bg/utils/__pycache__/dataloaders_utils.cpython-311.pyc ADDED
Binary file (27.7 kB)
ID-like-train-change-bg/utils/__pycache__/file_ops.cpython-311.pyc ADDED
Binary file (7.21 kB)
ID-like-train-change-bg/utils/__pycache__/file_ops.cpython-37.pyc ADDED
Binary file (3.42 kB)
ID-like-train-change-bg/utils/__pycache__/id_like.cpython-311.pyc ADDED
Binary file (12.5 kB)
ID-like-train-change-bg/utils/__pycache__/id_like_loss.cpython-311.pyc ADDED
Binary file (4.02 kB)
ID-like-train-change-bg/utils/__pycache__/id_like_utils.cpython-311.pyc ADDED
Binary file (19.4 kB)
ID-like-train-change-bg/utils/__pycache__/imagenet_templates.cpython-311.pyc ADDED
Binary file (12.8 kB)
ID-like-train-change-bg/utils/__pycache__/plot_util.cpython-311.pyc ADDED
Binary file (2.76 kB)
ID-like-train-change-bg/utils/__pycache__/plot_util.cpython-37.pyc ADDED
Binary file (1.53 kB)
ID-like-train-change-bg/utils/common.py ADDED
@@ -0,0 +1,164 @@
+ import torch
+ import torch.nn.functional as F
+ import os
+ import numpy as np
+ import json
+ import random
+ import sklearn.metrics as sk
+
+
+ def setup_seed(seed):
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+
+
+ def accuracy(output, target, topk=(1,)):
+     """Computes the precision@k for the specified values of k"""
+     maxk = max(topk)
+     batch_size = target.size(0)
+     _, pred = output.topk(maxk, 1, True, True)
+     pred = pred.t()
+     correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+     res = []
+     for k in topk:
+         correct_k = correct[:k].flatten().float().sum(0)
+         res.append(correct_k.mul_(100.0 / batch_size))
+     return res
+
+
+ def read_file(file_path, root='corpus'):
+     corpus = []
+     with open(os.path.join(root, file_path)) as f:
+         for line in f:
+             corpus.append(line[:-1])
+     return corpus
+
+
+ def calculate_cosine_similarity(image_features, text_features):
+     image_features /= image_features.norm(dim=-1, keepdim=True)
+     text_features /= text_features.norm(dim=-1, keepdim=True)
+     similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
+     return similarity
+
+
+ class AverageMeter(object):
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.val = 0
+         self.avg = 0
+         self.sum = 0
+         self.count = 0
+
+     def update(self, val, n=1):
+         self.val = val
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / self.count
+
+
+ def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
+     """Use high precision for cumsum and check that the final value matches sum.
+
+     Parameters
+     ----------
+     arr : array-like
+         To be cumulatively summed as flat
+     rtol : float
+         Relative tolerance, see ``np.allclose``
+     atol : float
+         Absolute tolerance, see ``np.allclose``
+     """
+     out = np.cumsum(arr, dtype=np.float64)
+     expected = np.sum(arr, dtype=np.float64)
+     if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
+         raise RuntimeError('cumsum was found to be unstable: '
+                            'its last element does not correspond to sum')
+     return out
+
+
+ def fpr_and_fdr_at_recall(y_true, y_score, recall_level=0.95, pos_label=None):
+     classes = np.unique(y_true)
+     if (pos_label is None and
+             not (np.array_equal(classes, [0, 1]) or
+                  np.array_equal(classes, [-1, 1]) or
+                  np.array_equal(classes, [0]) or
+                  np.array_equal(classes, [-1]) or
+                  np.array_equal(classes, [1]))):
+         raise ValueError("Data is not binary and pos_label is not specified")
+     elif pos_label is None:
+         pos_label = 1.
+
+     # make y_true a boolean vector
+     y_true = (y_true == pos_label)
+
+     # sort scores and corresponding truth values
+     desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
+     y_score = y_score[desc_score_indices]
+     y_true = y_true[desc_score_indices]
+
+     # y_score typically has many tied values. Here we extract
+     # the indices associated with the distinct values. We also
+     # concatenate a value for the end of the curve.
+     distinct_value_indices = np.where(np.diff(y_score))[0]
+     threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
+
+     # accumulate the true positives with decreasing threshold
+     tps = stable_cumsum(y_true)[threshold_idxs]
+     fps = 1 + threshold_idxs - tps  # add one because of zero-based indexing
+
+     thresholds = y_score[threshold_idxs]
+
+     recall = tps / tps[-1]
+
+     last_ind = tps.searchsorted(tps[-1])
+     sl = slice(last_ind, None, -1)  # [last_ind::-1]
+     recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl]
+
+     cutoff = np.argmin(np.abs(recall - recall_level))
+
+     return fps[cutoff] / (np.sum(np.logical_not(y_true)))  # , fps[cutoff]/(fps[cutoff] + tps[cutoff])
+
+
+ def get_measures(_pos, _neg, recall_level=0.95):
+     pos = np.array(_pos[:]).reshape((-1, 1))
+     neg = np.array(_neg[:]).reshape((-1, 1))
+     examples = np.squeeze(np.vstack((pos, neg)))
+     labels = np.zeros(len(examples), dtype=np.int32)
+     labels[:len(pos)] += 1
+
+     auroc = sk.roc_auc_score(labels, examples)
+     aupr = sk.average_precision_score(labels, examples)
+     fpr = fpr_and_fdr_at_recall(labels, examples, recall_level)
+
+     return auroc, aupr, fpr
+
+
+ def print_measures(log, auroc, aupr, fpr, method_name='Ours', recall_level=0.95):
+     if log is None:
+         print('FPR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), 100 * fpr))
+         print('AUROC: \t\t\t{:.2f}'.format(100 * auroc))
+         print('AUPR: \t\t\t{:.2f}'.format(100 * aupr))
+     else:
+         log.debug('\t\t\t\t' + method_name)
+         log.debug('  FPR{:d} AUROC AUPR'.format(int(100 * recall_level)))
+         log.debug('& {:.2f} & {:.2f} & {:.2f}'.format(100 * fpr, 100 * auroc, 100 * aupr))
+
+
+ def get_and_print_results(args, log, in_score, out_score, auroc_list, aupr_list, fpr_list):
+     '''
+     1) evaluate detection performance for a given OOD test set (loader)
+     2) print results (FPR95, AUROC, AUPR)
+     '''
+     aurocs, auprs, fprs = [], [], []
+     measures = get_measures(-in_score, -out_score)
+     aurocs.append(measures[0]); auprs.append(measures[1]); fprs.append(measures[2])
+     print(f'in score samples (first 3): {in_score[:3]}, out score samples: {out_score[:3]}')
+     auroc = np.mean(aurocs); aupr = np.mean(auprs); fpr = np.mean(fprs)
+     auroc_list.append(auroc); aupr_list.append(aupr); fpr_list.append(fpr)  # used to calculate the avg over multiple OOD test sets
+     print_measures(log, auroc, aupr, fpr, args.score)
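A toy run of get_measures; note that get_and_print_results passes negated scores, so in this standalone sketch higher simply means "more in-distribution". The numbers are illustrative only:

    import numpy as np

    pos = np.array([0.9, 0.8, 0.7, 0.6])    # scores on ID samples
    neg = np.array([0.5, 0.4, 0.3, 0.65])   # scores on OOD samples
    auroc, aupr, fpr95 = get_measures(pos, neg)
    print(f"AUROC {100 * auroc:.2f}  AUPR {100 * aupr:.2f}  FPR95 {100 * fpr95:.2f}")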
ID-like-train-change-bg/utils/dataloaders_utils.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import sys
+ import os
+ import numpy as np
+ import torch
+ import torchvision
+ from torch import nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, Subset, DataLoader
+ # from transformers import CLIPModel
+ from torchvision import datasets, transforms
+ import torchvision.transforms as transforms
+ from dataloaders import StanfordCars, Food101, OxfordIIITPet, Cub2011
+ from torchvision.datasets import CIFAR10, CIFAR100, SVHN
+ from tqdm import tqdm
+ import config
+ from clip import load, tokenize
+ from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+ _tokenizer = _Tokenizer()
+
+
+ def update_class_to_idx(dataset, new_class_to_idx):
+     """
+     Update the class_to_idx mapping of a dataset instance.
+
+     :param dataset: an ImageFolder instance.
+     :param new_class_to_idx: the new class-name-to-index mapping dict.
+     """
+     # update the class_to_idx mapping
+     dataset.class_to_idx = new_class_to_idx
+
+     # rebuild the idx_to_class mapping from the new class_to_idx
+     dataset.idx_to_class = {idx: class_name for class_name, idx in new_class_to_idx.items()}
+
+     # rebuild the sample list so that it matches the new mapping
+     dataset.samples = []
+     for class_name, idx in new_class_to_idx.items():
+         class_dir = os.path.join(dataset.root, class_name)
+         for entry in os.listdir(class_dir):
+             full_path = os.path.join(class_dir, entry)
+             if os.path.isfile(full_path):
+                 dataset.samples.append((full_path, idx))
+
+
+ def set_train_loader(args, subset=False, max_count=0):
+     root = args.root_dir
+     normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
+                                      std=(0.26862954, 0.26130258, 0.27577711))  # for CLIP
+     preprocess = transforms.Compose([
+         transforms.Resize(224),
+         transforms.CenterCrop(224),
+         transforms.ToTensor(),
+         normalize
+     ])
+     kwargs = {'num_workers': 4, 'pin_memory': True}
+     batch_size = args.batch_size
+     batch_size = 256  # NOTE: hard-coded override of args.batch_size
+     shuffle = True
+     if args.in_dataset == "ImageNet":
+         path = os.path.join(root, 'ImageNet', 'train')
+     elif args.in_dataset == "ImageNet100":
+         path = os.path.join(root, "ImageNet100", 'train')
+     elif args.in_dataset == "ImageNet10":
+         path = os.path.join(root, "ImageNet10", 'train')
+     elif args.in_dataset == "ImageNet20":
+         path = os.path.join(root, "ImageNet20", 'train')
+     dataset = datasets.ImageFolder(path, transform=preprocess)
+     if subset:
+         from collections import defaultdict
+         classwise_count = defaultdict(int)
+         indices = []
+         for i, label in enumerate(dataset.targets):
+             if classwise_count[label] < max_count:
+                 indices.append(i)
+                 classwise_count[label] += 1
+         dataset = torch.utils.data.Subset(dataset, indices)
+     train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
+
+     return train_loader
+
+
+ def set_val_loader(args):
+     root = args.root_dir
+     normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
+                                      std=(0.26862954, 0.26130258, 0.27577711))  # for CLIP
+     preprocess = transforms.Compose([
+         transforms.Resize(224),
+         transforms.CenterCrop(224),
+         transforms.ToTensor(),
+         normalize
+     ])
+     kwargs = {'num_workers': 4, 'pin_memory': True}
+     if args.in_dataset == "ImageNet":
+         path = os.path.join(root, 'ImageNet', 'val')
+         dataset = datasets.ImageFolder(path, transform=preprocess)
+     elif args.in_dataset == "bronze2NotLine":
+         path = os.path.join(root, "bronze_ID_and_OOD", "composite_split", "test")
+         dataset = datasets.ImageFolder(path, transform=preprocess)
+         new_class_to_idx = {'age_0': 0, 'age_1': 1, 'age_2': 2, 'age_3': 3,
+                             'age_4': 4, 'age_5': 5, 'age_6': 6, 'age_7': 7, 'age_8': 8, 'age_9': 9, 'age_10': 10}
+         update_class_to_idx(dataset, new_class_to_idx)
+     elif args.in_dataset == "ImageNet100":
+         path = os.path.join(root, "ImageNet100", 'val')
+         dataset = datasets.ImageFolder(path, transform=preprocess)
+     elif args.in_dataset == "ImageNet10":
+         path = os.path.join(root, "ImageNet10", 'train')
+         dataset = datasets.ImageFolder(path, transform=preprocess)
+     elif args.in_dataset == "ImageNet20":
+         path = os.path.join(root, "ImageNet20", 'train')
+         dataset = datasets.ImageFolder(path, transform=preprocess)
+     elif args.in_dataset == "car196":
+         path = root
+         dataset = StanfordCars(path, split="test", download=True, transform=preprocess)
+     elif args.in_dataset == "food101":
+         path = root
+         dataset = Food101(path, split="test", download=True, transform=preprocess)
+     elif args.in_dataset == "pet37":
+         path = root
+         dataset = OxfordIIITPet(path, split="test", download=True, transform=preprocess)
+     elif args.in_dataset == "bird200":
+         path = root
+         dataset = Cub2011(path, train=False, transform=preprocess)
+     elif args.in_dataset == "cifar10":
+         path = root
+         dataset = CIFAR10(path, train=False, transform=preprocess)
+     elif args.in_dataset == "cifar100":
+         path = root
+         dataset = CIFAR100(path, train=False, transform=preprocess)
+
+     val_loader = torch.utils.data.DataLoader(dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs)
+
+     return val_loader
+
+
+ def set_ood_loader_ImageNet(args, out_dataset):
+     '''
+     set OOD loader for ImageNet-scale datasets
+     '''
+     # root = os.path.join(args.root_dir, 'ImageNet_OOD_dataset')
+     root = args.root_dir
+     # normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
+     normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
+                                      std=(0.26862954, 0.26130258, 0.27577711))  # for CLIP
+     preprocess = transforms.Compose([
+         transforms.Resize(224),
+         transforms.CenterCrop(224),
+         transforms.ToTensor(),
+         normalize
+     ])
+     if out_dataset == 'imagenet22k_container':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'images_largescale', 'imagenet-21k-container', 'images'), transform=preprocess)
+     elif out_dataset == 'imagenet22k_container_refine':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'images_largescale', 'imagenet-21k-container-refine', 'images'), transform=preprocess)
+     elif out_dataset == 'bronzeS_containerM':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'images_largescale', 'transfer_dataset', 'bronze_structure_container_material', 'test'), transform=preprocess)
+     elif out_dataset == 'bronzeM_containerS':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'images_largescale', 'transfer_dataset', 'container_structure_bronze_material', 'test'), transform=preprocess)
+     elif out_dataset == 'bronze_Line':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'images_largescale', 'bronze_line'), transform=preprocess)
+     elif out_dataset == 'ssb_hard':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'images_largescale', 'ssb_hard'), transform=preprocess)
+     elif out_dataset == 'ninco':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'images_largescale', 'ninco'), transform=preprocess)
+     elif out_dataset == 'inaturalist':
+         # testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'iNaturalist'), transform=preprocess)
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'images_largescale', 'inaturalist'), transform=preprocess)
+     elif out_dataset == 'textures':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'images_classic', 'texture'), transform=preprocess)
+     elif out_dataset == 'openimage_o':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'images_largescale', 'openimage_o'), transform=preprocess)
+     elif out_dataset == 'SUN':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'SUN'), transform=preprocess)
+     elif out_dataset == 'places365':  # filtered places
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'Places'), transform=preprocess)
+     elif out_dataset == 'placesbg':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'placesbg'), transform=preprocess)
+     elif out_dataset == 'dtd':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(root, 'dtd', 'images'), transform=preprocess)
+     elif out_dataset == 'svhn':
+         testsetout = SVHN(root=os.path.join(args.root_dir, 'svhn'), split='test', transform=preprocess)
+     elif out_dataset == "cifar10":
+         testsetout = CIFAR10(root=args.root_dir, train=False, transform=preprocess)
+     elif out_dataset == "cifar100":
+         testsetout = CIFAR100(root=args.root_dir, train=False, transform=preprocess)
+     # NOTE: the three branches below repeat conditions already handled above
+     # (with different roots) and are therefore unreachable as written.
+     elif out_dataset == 'ssb_hard':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(args.root_dir, 'ssb_hard'), transform=preprocess)
+     elif out_dataset == 'ninco':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(args.root_dir, 'ninco'), transform=preprocess)
+     elif out_dataset == 'openimage_o':
+         testsetout = torchvision.datasets.ImageFolder(root=os.path.join(args.root_dir, 'openimage_o'), transform=preprocess)
+     testloaderOut = torch.utils.data.DataLoader(testsetout, batch_size=args.test_batch_size, shuffle=False, num_workers=0)
+
+     return testloaderOut
+
+
+ class RandomCrop(object):
+     def __init__(self, n_crop=2):
+         normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
+                                          std=(0.26862954, 0.26130258, 0.27577711))  # for CLIP
+         self.n_crop = n_crop
+         self.random_crop = transforms.Compose([
+             # transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
+             transforms.RandomResizedCrop(224),
+             transforms.RandomHorizontalFlip(),
+             transforms.ToTensor(),
+             normalize
+         ])
+
+     def __call__(self, x):
+         views = [self.random_crop(x).unsqueeze(dim=0) for _ in range(self.n_crop)]
+         views = torch.cat(views, dim=0)
+         return views
+
+
+ def set_few_shot_loader(args):
+     root = args.root_dir
+     data_transform = RandomCrop(args.n_crop)
+     # data_transform = RandomCropAndMask(args.n_crop, args.n_crop)
+     shuffle = True
+     kwargs = {'num_workers': 0, 'pin_memory': True}
+
+     if args.in_dataset == "ImageNet":
+         path = os.path.join(root, 'images_largescale', 'imagenet_1k', 'train')
+         dataset = datasets.ImageFolder(path)
+     elif args.in_dataset == "bronze2NotLine":
+         path = os.path.join(root, "bronze_ID_and_OOD", "composite_split", "train")
+         dataset = datasets.ImageFolder(path)
+         new_class_to_idx = {'age_0': 0, 'age_1': 1, 'age_2': 2, 'age_3': 3,
+                             'age_4': 4, 'age_5': 5, 'age_6': 6, 'age_7': 7, 'age_8': 8, 'age_9': 9, 'age_10': 10}
+         update_class_to_idx(dataset, new_class_to_idx)
+         # A = dataset.class_to_idx
+     elif args.in_dataset == "ImageNet100":
+         path = os.path.join(root, "ImageNet100", 'train')
+         dataset = datasets.ImageFolder(path)
+     elif args.in_dataset == "ImageNet10":
+         path = os.path.join(root, "ImageNet10", 'train')
+         dataset = datasets.ImageFolder(path)
+     elif args.in_dataset == "ImageNet20":
+         path = os.path.join(root, "ImageNet20", 'train')
+         dataset = datasets.ImageFolder(path)
+     elif args.in_dataset == "car196":
+         path = root
+         dataset = StanfordCars(path, split="train", download=True)
+         dataset.targets = [target for _, target in dataset]
+     elif args.in_dataset == "food101":
+         path = root
+         dataset = Food101(path, split="train", download=True)
+         dataset.targets = [target for _, target in dataset]
+     elif args.in_dataset == "pet37":
+         path = root
+         dataset = OxfordIIITPet(path, split="trainval", download=True)
+         dataset.targets = [target for _, target in dataset]
+     elif args.in_dataset == "bird200":
+         path = root
+         dataset = Cub2011(path, train=True)
+         dataset.targets = [dataset.data.iloc[idx].target - 1 for idx in range(len(dataset))]
+     elif args.in_dataset == "cifar10":
+         path = root
+         dataset = CIFAR10(path, train=True)
+     elif args.in_dataset == "cifar100":
+         path = root
+         dataset = CIFAR100(path, train=True)
+
+     indices = []
+     from collections import defaultdict
+     classwise_idx = defaultdict(list)
+     print('get dataset index')
+     for i, target in enumerate(tqdm(dataset.targets)):
+         classwise_idx[target].append(i)
+     print('sample few shot dataset')
+
+     from random import sample
+     # n_shot values that are multiples of 10000 encode class-wise fractions
+     # (100000 -> 100%, 90000 -> 90%, ..., 10000 -> 10%); any other value is an
+     # absolute number of shots per class.
+     fraction_table = {100000: 1.0, 90000: 0.9, 80000: 0.8, 70000: 0.7, 60000: 0.6,
+                       50000: 0.5, 40000: 0.4, 30000: 0.3, 20000: 0.2, 10000: 0.1}
+     for i in tqdm(range(args.n_cls)):
+         sample_length = len(classwise_idx[i])
+         if args.n_shot in fraction_table:
+             sl = sample(classwise_idx[i], int(sample_length * fraction_table[args.n_shot]))
+         else:
+             sl = sample(classwise_idx[i], args.n_shot)
+         indices.extend(sl)
+
+     if args.in_dataset == "ImageNet":
+         path = os.path.join(root, 'images_largescale', 'imagenet_1k', 'train')
+         dataset = datasets.ImageFolder(path, transform=data_transform)
+         # path = os.path.join(root, 'ImageNet', 'train')
+         # dataset = datasets.ImageFolder(path, transform=data_transform)
+     elif args.in_dataset == "bronze2NotLine":
+         path = os.path.join(root, "bronze_ID_and_OOD", "composite_split", "train")
+         dataset = datasets.ImageFolder(path, transform=data_transform)
+         new_class_to_idx = {'age_0': 0, 'age_1': 1, 'age_2': 2, 'age_3': 3,
+                             'age_4': 4, 'age_5': 5, 'age_6': 6, 'age_7': 7, 'age_8': 8, 'age_9': 9, 'age_10': 10}
+         update_class_to_idx(dataset, new_class_to_idx)
+     elif args.in_dataset == "ImageNet100":
+         path = os.path.join(root, "ImageNet100", 'train')
+         dataset = datasets.ImageFolder(path, transform=data_transform)
+     elif args.in_dataset == "ImageNet10":
+         path = os.path.join(root, "ImageNet10", 'train')
+         dataset = datasets.ImageFolder(path, transform=data_transform)
+     elif args.in_dataset == "ImageNet20":
+         path = os.path.join(root, "ImageNet20", 'train')
+         dataset = datasets.ImageFolder(path, transform=data_transform)
+     elif args.in_dataset == "car196":
+         path = root
+         dataset = StanfordCars(path, split="train", download=True, transform=data_transform)
+     elif args.in_dataset == "food101":
+         path = root
+         dataset = Food101(path, split="train", download=True, transform=data_transform)
+     elif args.in_dataset == "pet37":
+         path = root
+         dataset = OxfordIIITPet(path, split="trainval", download=True, transform=data_transform)
+     elif args.in_dataset == "bird200":
+         path = root
+         dataset = Cub2011(path, train=True, transform=data_transform)
+     elif args.in_dataset == "cifar10":
+         path = root
+         dataset = CIFAR10(path, train=True, transform=data_transform)
+     elif args.in_dataset == "cifar100":
+         path = root
+         dataset = CIFAR100(path, train=True, transform=data_transform)
+
+     dataset = torch.utils.data.Subset(dataset, indices)
+     few_shot_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle, **kwargs)
+
+     # from torch.utils.data.distributed import DistributedSampler
+     # sampler = DistributedSampler(dataset)
+     # few_shot_loader = torch.utils.data.DataLoader(dataset, sampler=sampler,
+     #                                               batch_size=args.batch_size,
+     #                                               shuffle=False, **kwargs)
+
+     return few_shot_loader
+
+
+ def set_few_shot_loader_normal(args):
+     root = args.root_dir
+     normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
+                                      std=(0.26862954, 0.26130258, 0.27577711))  # for CLIP
+     data_transform = transforms.Compose([
+         transforms.Resize(224),
+         transforms.CenterCrop(224),
+         transforms.ToTensor(),
+         normalize
+     ])
+     # data_transform = RandomCropAndMask(args.n_crop, args.n_crop)
+     shuffle = True
+     kwargs = {'num_workers': 0, 'pin_memory': True}
+
+     if args.in_dataset == "ImageNet":
+         path = os.path.join(root, 'ImageNet', 'train')
+         dataset = datasets.ImageFolder(path)
+     elif args.in_dataset == "ImageNet100":
+         path = os.path.join(root, "ImageNet100", 'train')
+         dataset = datasets.ImageFolder(path)
+     elif args.in_dataset == "ImageNet10":
+         path = os.path.join(root, "ImageNet10", 'train')
+         dataset = datasets.ImageFolder(path)
+     elif args.in_dataset == "ImageNet20":
+         path = os.path.join(root, "ImageNet20", 'train')
+         dataset = datasets.ImageFolder(path)
+     elif args.in_dataset == "car196":
+         path = root
+         dataset = StanfordCars(path, split="train", download=True)
+         dataset.targets = [target for _, target in dataset]
+     elif args.in_dataset == "food101":
+         path = root
+         dataset = Food101(path, split="train", download=True)
+         dataset.targets = [target for _, target in dataset]
+     elif args.in_dataset == "pet37":
+         path = root
+         dataset = OxfordIIITPet(path, split="trainval", download=True)
+         dataset.targets = [target for _, target in dataset]
+     elif args.in_dataset == "bird200":
+         path = root
+         dataset = Cub2011(path, train=True)
+         dataset.targets = [dataset.data.iloc[idx].target - 1 for idx in range(len(dataset))]
+     elif args.in_dataset == "cifar10":
+         path = root
+         dataset = CIFAR10(path, train=True)
+     elif args.in_dataset == "cifar100":
+         path = root
+         dataset = CIFAR100(path, train=True)
+
+     indices = []
+     from collections import defaultdict
+     classwise_idx = defaultdict(list)
+     print('get dataset index')
+     for i, target in enumerate(tqdm(dataset.targets)):
+         classwise_idx[target].append(i)
+     print('sample few shot dataset')
+     from random import sample
+     for i in tqdm(range(args.n_cls)):
+         sl = sample(classwise_idx[i], args.n_shot)
+         indices.extend(sl)
+
+     if args.in_dataset == "ImageNet":
+         path = os.path.join(root, 'ImageNet', 'train')
+         dataset = datasets.ImageFolder(path, transform=data_transform)
+     elif args.in_dataset == "ImageNet100":
+         path = os.path.join(root, "ImageNet100", 'train')
+         dataset = datasets.ImageFolder(path, transform=data_transform)
+     elif args.in_dataset == "ImageNet10":
+         path = os.path.join(root, "ImageNet10", 'train')
+         dataset = datasets.ImageFolder(path, transform=data_transform)
+     elif args.in_dataset == "ImageNet20":
+         path = os.path.join(root, "ImageNet20", 'train')
+         dataset = datasets.ImageFolder(path, transform=data_transform)
+     elif args.in_dataset == "car196":
+         path = root
+         dataset = StanfordCars(path, split="train", download=True, transform=data_transform)
+     elif args.in_dataset == "food101":
+         path = root
+         dataset = Food101(path, split="train", download=True, transform=data_transform)
+     elif args.in_dataset == "pet37":
+         path = root
+         dataset = OxfordIIITPet(path, split="trainval", download=True, transform=data_transform)
+     elif args.in_dataset == "bird200":
+         path = root
+         dataset = Cub2011(path, train=True, transform=data_transform)
+     elif args.in_dataset == "cifar10":
+         path = root
+         dataset = CIFAR10(path, train=True, transform=data_transform)
+     elif args.in_dataset == "cifar100":
+         path = root
+         dataset = CIFAR100(path, train=True, transform=data_transform)
+     dataset = torch.utils.data.Subset(dataset, indices)
+     few_shot_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle, **kwargs)
+
+     # from torch.utils.data.distributed import DistributedSampler
+     # sampler = DistributedSampler(dataset)
+     # few_shot_loader = torch.utils.data.DataLoader(dataset, sampler=sampler,
+     #                                               batch_size=args.batch_size,
+     #                                               shuffle=False, **kwargs)
+
+     return few_shot_loader
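A minimal sketch of building a few-shot loader (not part of the commit; the args fields mirror what this module reads, the dataset path and values are illustrative only, and the data must already exist under root_dir):

    from types import SimpleNamespace
    from utils.dataloaders_utils import set_few_shot_loader

    args = SimpleNamespace(root_dir='./datasets', in_dataset='ImageNet10',
                           n_cls=10, n_shot=16, n_crop=2, batch_size=32)
    loader = set_few_shot_loader(args)
    images, targets = next(iter(loader))
    # images has shape (batch, n_crop, 3, 224, 224): RandomCrop stacks n_crop views per image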
ID-like-train-change-bg/utils/file_ops.py ADDED
@@ -0,0 +1,68 @@
+ import os
+ import shutil
+ import numpy as np
+ import logging
+ import pandas as pd
+
+
+ def save_scores(args, scores, dataset_name):
+     with open(os.path.join(args.log_directory, f'{dataset_name}_scores.npy'), 'wb') as f:
+         np.save(f, scores)
+
+
+ def load_scores(args, dataset_name):
+     with open(os.path.join(args.log_directory, f'{dataset_name}_scores.npy'), 'rb') as f:
+         scores = np.load(f)
+     return scores
+
+
+ def setup_log(args):
+     log = logging.getLogger(__name__)
+     formatter = logging.Formatter('%(asctime)s : %(message)s')
+     fileHandler = logging.FileHandler(os.path.join(args.log_directory, "ood_eval_info.log"), mode='w')
+     fileHandler.setFormatter(formatter)
+     streamHandler = logging.StreamHandler()
+     streamHandler.setFormatter(formatter)
+     log.setLevel(logging.DEBUG)
+     log.addHandler(fileHandler)
+     log.addHandler(streamHandler)
+     # log.debug(f"#########{args.name}############")
+     return log
+
+
+ def save_as_dataframe(args, out_datasets, fpr_list, auroc_list, aupr_list, acc_in):
+     fpr_list = [float('{:.2f}'.format(100 * fpr)) for fpr in fpr_list]
+     auroc_list = [float('{:.2f}'.format(100 * auroc)) for auroc in auroc_list]
+     aupr_list = [float('{:.2f}'.format(100 * aupr)) for aupr in aupr_list]
+     acc_in_list = [float('{:.2f}'.format(acc_in[0]))] * len(aupr_list)
+     data = {k: v for k, v in zip(out_datasets, zip(fpr_list, auroc_list, aupr_list, acc_in_list))}
+     data['AVG'] = [np.mean(fpr_list), np.mean(auroc_list), np.mean(aupr_list), np.mean(acc_in_list)]
+     data['AVG'] = [float('{:.2f}'.format(metric)) for metric in data['AVG']]
+     # Specify orient='index' to create the DataFrame using dictionary keys as rows
+     df = pd.DataFrame.from_dict(data, orient='index', columns=['FPR95', 'AUROC', 'AUPR', 'ACC_IN'])
+     df.to_csv(os.path.join(args.log_directory, 'result.csv'))
+
+
+ def create_ImageNet_subset(src, dst, target_dirs):
+     assert os.path.exists(src)
+     if not os.path.exists(dst):
+         os.makedirs(dst)
+     splits = ['train', 'val']
+     for split in splits:
+         for dir_name in os.listdir(os.path.join(src, split)):
+             if dir_name in target_dirs:
+                 shutil.copytree(os.path.join(src, split, dir_name), os.path.join(dst, split, dir_name))
+
+
+ def prepare_dataframe(captions_dir='gen_captions', dataset_name='imagenet_val', multiple=False):
+     # load caption file
+     captions_path = os.path.join(captions_dir, f'{dataset_name}_captions.tsv')
+     df = pd.read_csv(captions_path, sep='\t')
+     df.columns = ["image_id", "caption", "cls"]
+     if multiple:  # in case a single img has multiple captions
+         x = list(set(df['image_id'].values))
+         image_ids = np.arange(0, len(x))
+         train_images = [x[i] for i in image_ids]
+         df = df[df["image_id"].isin(train_images)].reset_index(drop=True)
+     return df
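An illustrative sketch of persisting results with save_as_dataframe (not part of the commit; the metric lists would normally come from get_and_print_results, and the values here are made up):

    import os
    from types import SimpleNamespace
    from utils.file_ops import save_as_dataframe

    args = SimpleNamespace(log_directory='./logs/demo')
    os.makedirs(args.log_directory, exist_ok=True)
    out_datasets = ['inaturalist', 'textures']            # hypothetical OOD sets
    fpr_list, auroc_list, aupr_list = [0.21, 0.35], [0.95, 0.90], [0.96, 0.92]
    save_as_dataframe(args, out_datasets, fpr_list, auroc_list, aupr_list, acc_in=[78.5])
    # writes ./logs/demo/result.csv with one row per OOD set plus an AVG row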
ID-like-train-change-bg/utils/id_like.py ADDED
@@ -0,0 +1,184 @@
+ import os
+ import math
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from tqdm import tqdm
+ from utils.id_like_utils import ClipPromptLearner
+ from utils.id_like_loss import get_loss
+ from utils.common import AverageMeter, accuracy
+ from utils import imagenet_templates
+ from clip import load, tokenize
+ from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
+ _tokenizer = _Tokenizer()
+ import config
+
+
+ def select_in_out(args, image_features, sim):
+
+     idx_in = torch.topk(sim, dim=0, k=args.n_selection)[1].squeeze()
+     image_features_crop_in_temp = torch.index_select(image_features, index=idx_in, dim=0)
+     idx_out = torch.topk(-sim, dim=0, k=args.n_selection)[1].squeeze()
+     image_features_crop_out_temp = torch.index_select(image_features, index=idx_out, dim=0)
+     image_features_in_temp = image_features_crop_in_temp
+     image_features_out_temp = image_features_crop_out_temp
+
+     return image_features_in_temp, image_features_out_temp
+
+
+ def get_in_out(args, clip, model, labels, images, targets):
+
+     image_features_in = []
+     image_features_out = []
+     targets_in = []
+     targets_out = []
+     with torch.no_grad():
+
+         for image_idx, (image, target) in enumerate(zip(images, targets)):
+             label = labels[target.item()]
+             # openai_imagenet_template = imagenet_templates.openai_imagenet_template
+             openai_imagenet_template = [lambda c: f'a photo of a {c}.']
+             select_prompts_in = [func(label) for func in openai_imagenet_template]
+             text_inputs = tokenize(select_prompts_in).cuda()
+             select_prompts_in = clip.encode_text(text_inputs)
+             select_prompts_in /= select_prompts_in.norm(dim=-1, keepdim=True)
+
+             image = image.cuda()
+             target = target.cuda()
+             image_features = model.get_image_features(image)
+             image_features /= image_features.norm(dim=-1, keepdim=True)
+
+             sim = image_features @ select_prompts_in.t()
+             sim = torch.max(sim, dim=1, keepdim=True)[0]
+
+             image_features_in_temp, image_features_out_temp = select_in_out(args, image_features, sim)
+             image_features_in.append(image_features_in_temp)
+             image_features_out.append(image_features_out_temp)
+
+             # create in targets
+             targets_in_temp = torch.tile(target, dims=(image_features_in_temp.size(0),))
+             targets_in.append(targets_in_temp)
+
+             # create out targets by assigning each OOD-like crop to its closest
+             # exclusive prompt (note: the current loss does not use these)
+             prompt_features = model.get_text_features()
+             prompt_features = prompt_features / prompt_features.norm(dim=-1, keepdim=True)
+             prompt_features_out = prompt_features[args.n_cls:, ...]
+             logit_out_temp = image_features_out_temp @ prompt_features_out.t()
+             targets_out_temp = torch.max(logit_out_temp, dim=1)[1] + args.n_cls
+             targets_out.append(targets_out_temp)
+
+     image_features_in = torch.cat(image_features_in, dim=0)
+     image_features_out = torch.cat(image_features_out, dim=0)
+     targets_in = torch.cat(targets_in, dim=0).cuda()
+     targets_out = torch.cat(targets_out, dim=0).cuda()
+     return image_features_in, image_features_out, targets_in, targets_out
+
+
+ def get_prompts(args, loader, labels, ex_labels):
+     model = ClipPromptLearner(args,
+                               classnames=labels, ex_classnames=ex_labels, arch=args.CLIP_ckpt, device='cuda',
+                               n_ctx=config.n_ctx, ctx_init=config.ctx_init,
+                               ctx_position=config.ctx_position, learned_cls=config.learned_cls,
+                               n_ex_ctx=config.n_ex_ctx, ex_ctx_init=config.ex_ctx_init,
+                               ex_ctx_position=config.ex_ctx_position, ex_learned_cls=config.ex_learned_cls)
+
+     loss_meter = AverageMeter()
+     optimizer = torch.optim.AdamW([{'params': model.prompt_learner.parameters()},
+                                    {'params': model.ex_prompt_learner.parameters()}], args.lr)
+
+     clip, _, _ = load(args.CLIP_ckpt, device='cuda', download_root=config.DOWNLOAD_ROOT)
+
+     for epoch in range(args.n_epoch):
+
+         tqdm.write(f'Train epoch:{epoch + 1}/{args.n_epoch}')
+         for batch_idx, (images, targets) in enumerate(tqdm(loader)):
+             image_features_in, image_features_out, targets_in, targets_out = \
+                 get_in_out(args, clip, model, labels, images, targets)
+
+             # train
+             # get prompts
+             logit_scale = model.logit_scale.exp()
+             prompt_features = model.get_text_features()
+             prompt_features = prompt_features / prompt_features.norm(dim=-1, keepdim=True)
+             loss, loss_str = get_loss(args, prompt_features,
+                                       image_features_in, image_features_out,
+                                       targets_in, targets_out, logit_scale)
+
+             # update
+             loss.backward()
+             optimizer.step()
+             optimizer.zero_grad()
+
+             loss_meter.update(loss.detach().cpu().item())
+             tqdm.write(f'Train epoch:{epoch + 1}/{args.n_epoch}\t'
+                        f'Loss_avg:{loss_meter.avg:.6f}\t' + loss_str)
+
+         if epoch + 1 == args.n_epoch:
+             model_save_dir = args.log_directory
+             os.makedirs(model_save_dir, exist_ok=True)
+             model_checkpoint_save_path = os.path.join(model_save_dir, 'model_checkpoint.pth')
+             model_checkpoint = {
+                 'prompt_learner_state_dict': model.prompt_learner.state_dict(),
+                 'ex_prompt_learner_state_dict': model.ex_prompt_learner.state_dict(),
+             }
+             torch.save(model_checkpoint, model_checkpoint_save_path)
+
+     return model
+
+
+ def get_result(args, model, loader, labels, ex_labels, if_acc=False):
+     tqdm_object = tqdm(loader, total=len(loader))
+     outputs = []
+     all_targets = []
+     result = {
+         'scores': None,
+         'acc': None,
+     }
+
+     with torch.no_grad():
+         text_features = model.get_text_features()
+         text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+     for batch_idx, (images, targets) in enumerate(tqdm_object):
+         with torch.no_grad():
+             images = images.cuda()
+             targets = targets.long().cuda()
+             image_features = model.image_encoder(images)
+             image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+             logit_scale = model.logit_scale.exp()
+             output = logit_scale * image_features @ text_features.t()
+
+         output = output.detach().cpu()
+         outputs.append(output)
+         all_targets.append(targets)
+     outputs = torch.cat(outputs, dim=0)
+     all_targets = torch.cat(all_targets, dim=0)
+
+     # scores: probability mass on the exclusive (OOD) prompts, shifted by -1
+     outputs_softmax = F.softmax(outputs, dim=1)
+     scores = torch.sum(outputs_softmax[:, args.n_cls:], dim=1).detach().cpu().squeeze().numpy() - 1
+
+     result['scores'] = scores
+     # acc
+     if if_acc:
+         res = accuracy(outputs[:, :args.n_cls], all_targets.detach().cpu())
+         result['acc'] = [acc.item() for acc in res]
+     return result
+
+
+ def load_model(args, labels, ex_labels):
+     model = ClipPromptLearner(args,
+                               classnames=labels, ex_classnames=ex_labels, arch=args.CLIP_ckpt, device='cuda',
+                               n_ctx=config.n_ctx, ctx_init=config.ctx_init,
+                               ctx_position=config.ctx_position, learned_cls=config.learned_cls,
+                               n_ex_ctx=config.n_ex_ctx, ex_ctx_init=config.ex_ctx_init,
+                               ex_ctx_position=config.ex_ctx_position, ex_learned_cls=config.ex_learned_cls)
+     model_checkpoint_save_path = os.path.join(args.log_directory, 'model_checkpoint.pth')
+     model_checkpoint = torch.load(model_checkpoint_save_path)
+     model.prompt_learner.load_state_dict(model_checkpoint['prompt_learner_state_dict'])
+     model.ex_prompt_learner.load_state_dict(model_checkpoint['ex_prompt_learner_state_dict'])
+     return model.cuda()
+
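An end-to-end sketch of the training/evaluation flow in this module (not part of the commit; labels are the ID class names, ex_labels name the learnable exclusive prompts, and args carries the same fields the functions above read):

    from utils.id_like import get_prompts, get_result
    from utils.dataloaders_utils import set_few_shot_loader, set_val_loader

    few_shot_loader = set_few_shot_loader(args)
    model = get_prompts(args, few_shot_loader, labels, ex_labels)   # trains and checkpoints the prompts
    val_loader = set_val_loader(args)
    result = get_result(args, model, val_loader, labels, ex_labels, if_acc=True)
    print(result['acc'], result['scores'][:5])                      # ID accuracy and OOD scores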
ID-like-train-change-bg/utils/id_like_loss.py ADDED
@@ -0,0 +1,52 @@
+ import math
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from tqdm import tqdm
+ from utils.id_like_utils import ClipPromptLearner
+ from utils.common import AverageMeter, accuracy
+ import config
+
+
+ def get_loss(args, prompt_features, image_features_in, image_features_out, targets_in, targets_out, logit_scale):
+     prompt_features_in = prompt_features[:args.n_cls, ...]
+     prompt_features_out = prompt_features[args.n_cls:, ...]
+     # loss_in: cross-entropy of the ID-like crops against their class prompts
+     logit_in = logit_scale * image_features_in @ prompt_features.t()
+     # logit_in = logit_scale * image_features_in @ prompt_features_in.t()
+     loss_in = F.cross_entropy(logit_in, targets_in)
+
+     # loss_out
+     logit_out = logit_scale * image_features_out @ prompt_features.t()
+
+     # logit_out_softmax_probs = F.softmax(logit_out, dim=1)
+     # flag_out = torch.cat([torch.LongTensor([0] * args.n_cls + [1] * args.n_ex_prompts)], dim=0).cuda()
+     # logit_out_softmax_probs_in = torch.sum(logit_out_softmax_probs * (1 - flag_out), dim=1)
+     # logit_out_softmax_probs_in_log = -torch.log(1. - logit_out_softmax_probs_in)
+     # loss_out = torch.mean(logit_out_softmax_probs_in_log)
+
+     # loss_out: log of the probability mass the OOD-like crops place on the ID
+     # prompts; minimising it pushes that mass onto the exclusive prompts
+     logit_out_softmax_probs = F.softmax(logit_out, dim=1)
+     flag_out = torch.cat([torch.LongTensor([0] * args.n_cls + [1] * args.n_ex_prompts)], dim=0).cuda()
+     logit_out_softmax_probs_in = torch.sum(logit_out_softmax_probs * (1 - flag_out), dim=1)
+     logit_out_softmax_probs_in_log = torch.log(logit_out_softmax_probs_in + 1e-16)
+     loss_out = torch.mean(logit_out_softmax_probs_in_log)
+
+     # loss_diff: encourage the exclusive prompts to be mutually dissimilar
+     loss_diff = torch.FloatTensor([0.]).cuda()
+     for p in range(prompt_features_out.size(0) - 1):
+         for q in range(p + 1, prompt_features_out.size(0)):
+             loss_diff += F.cosine_embedding_loss(input1=prompt_features_out[p].unsqueeze(dim=0),
+                                                  input2=prompt_features_out[q].unsqueeze(dim=0),
+                                                  target=torch.LongTensor([-1]).cuda())
+     if prompt_features_out.size(0) > 1:
+         loss_diff /= (prompt_features_out.size(0) * (prompt_features_out.size(0) - 1) / 2.)
+
+     # total loss
+     loss = loss_in * args.lam_in + loss_out * args.lam_out + loss_diff * args.lam_diff
+     loss_str = f'Loss_now:{loss.detach().cpu().item():.6f}\t' \
+                f'Loss_in:{loss_in.detach().cpu().item():.6f}\t' \
+                f'Loss_out:{loss_out.detach().cpu().item():.6f}\t' \
+                f'Loss_diff:{loss_diff.detach().cpu().item():.6f}'
+
+     return loss, loss_str
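A tiny smoke test for get_loss with random features (not part of the commit; it requires a GPU because get_loss allocates CUDA tensors internally, and the args fields listed are exactly the ones the function reads):

    import torch
    import torch.nn.functional as F
    from types import SimpleNamespace
    from utils.id_like_loss import get_loss

    args = SimpleNamespace(n_cls=10, n_ex_prompts=5, lam_in=1.0, lam_out=1.0, lam_diff=1.0)
    d = 512
    prompts = F.normalize(torch.randn(args.n_cls + args.n_ex_prompts, d), dim=-1).cuda()
    feats_in = F.normalize(torch.randn(8, d), dim=-1).cuda()
    feats_out = F.normalize(torch.randn(8, d), dim=-1).cuda()
    targets_in = torch.randint(0, args.n_cls, (8,)).cuda()
    targets_out = torch.randint(args.n_cls, args.n_cls + args.n_ex_prompts, (8,)).cuda()
    loss, loss_str = get_loss(args, prompts, feats_in, feats_out,
                              targets_in, targets_out, logit_scale=torch.tensor(100.0).cuda())
    print(loss_str)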
ID-like-train-change-bg/utils/id_like_utils.py ADDED
@@ -0,0 +1,298 @@
+ import math
+ from typing import List, Tuple
+ import os
+ import json
+ import time
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.cuda.amp import autocast, GradScaler
+ import numpy as np
+
+ # from transformers import CLIPTokenizer
+ from tqdm import tqdm
+
+ from clip import load, tokenize
+ from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+ import config
+
+ _tokenizer = _Tokenizer()
+ DOWNLOAD_ROOT = config.DOWNLOAD_ROOT
+
+
+ class TextEncoder(nn.Module):
+     def __init__(self, clip_model):
+         super().__init__()
+         self.transformer = clip_model.transformer
+         self.positional_embedding = clip_model.positional_embedding
+         self.ln_final = clip_model.ln_final
+         self.text_projection = clip_model.text_projection
+         self.dtype = clip_model.dtype
+
+     def forward(self, prompts, tokenized_prompts):
+         x = prompts + self.positional_embedding.type(self.dtype)
+         x = x.permute(1, 0, 2)  # NLD -> LND
+         x = self.transformer(x)
+         x = x.permute(1, 0, 2)  # LND -> NLD
+         x = self.ln_final(x).type(self.dtype)
+
+         # x.shape = [batch_size, n_ctx, transformer.width]
+         # take features from the eot embedding (eot_token is the highest number in each sequence)
+         x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
+
+         return x
+
+
+ class PromptLearner(nn.Module):
+     def __init__(self, clip_model, classnames, n_ctx=16, ctx_init=None, ctx_position='end',
+                  learned_cls=False):
+         super().__init__()
+         n_cls = len(classnames)
+         self.learned_cls = learned_cls
+         dtype = clip_model.dtype
+         self.dtype = dtype
+         self.device = clip_model.visual.conv1.weight.device
+         ctx_dim = clip_model.ln_final.weight.shape[0]
+         self.ctx_dim = ctx_dim
+
+         if ctx_init:
+             print("Initializing the context with given words: [{}]".format(ctx_init))
+             ctx_init = ctx_init.replace("_", " ")
+             if '[CLS]' in ctx_init:
+                 ctx_list = ctx_init.split(" ")
+                 split_idx = ctx_list.index("[CLS]")
+                 ctx_init = ctx_init.replace("[CLS] ", "")
+                 ctx_position = "middle"
+             else:
+                 split_idx = None
+             self.split_idx = split_idx
+             n_ctx = len(ctx_init.split(" "))
+             prompt = tokenize(ctx_init).to(self.device)
+             with torch.no_grad():
+                 embedding = clip_model.token_embedding(prompt).type(dtype)
+             ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
+             prompt_prefix = ctx_init
+         else:
+             ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
+             nn.init.normal_(ctx_vectors, std=0.02)
+             prompt_prefix = " ".join(["X"] * n_ctx)
+
+         self.prompt_prefix = prompt_prefix
+
+         print(f'Initial context: "{prompt_prefix}"')
+         print(f"Number of context words (tokens): {n_ctx}")
+
+         self.ctx_init_state = ctx_vectors.detach().clone()
+         self.ctx = nn.Parameter(ctx_vectors)  # to be optimized
+
+         if not self.learned_cls:
+             classnames = [name.replace("_", " ") for name in classnames]
+             name_lens = [len(_tokenizer.encode(name)) for name in classnames]
+             prompts = [prompt_prefix + " " + name + "." for name in classnames]
+         else:
+             cls_vectors = torch.empty(n_cls, 1, ctx_dim, dtype=dtype)  # assume each learnable cls_token is only 1 word
+             nn.init.normal_(cls_vectors, std=0.02)
+             cls_token = "X"
+             name_lens = [1 for _ in classnames]
+             prompts = [prompt_prefix + " " + cls_token + "." for _ in classnames]
+
+             self.cls_init_state = cls_vectors.detach().clone()
+             self.cls = nn.Parameter(cls_vectors)
+
+         tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device)
+         with torch.no_grad():
+             embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
+
+         self.register_buffer("token_prefix", embedding[:, :1, :])
+         if self.learned_cls:
+             self.register_buffer("token_suffix", embedding[:, 1 + n_ctx + 1:, :])
+         else:
+             self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])
+
+         self.ctx_init = ctx_init
+         self.tokenized_prompts = tokenized_prompts  # torch.Tensor
+         self.name_lens = name_lens
+         self.class_token_position = ctx_position
+         self.n_cls = n_cls
+         self.n_ctx = n_ctx
+         self.classnames = classnames
+
+     def reset(self):
+         ctx_vectors = self.ctx_init_state
+         self.ctx.copy_(ctx_vectors)
+         if self.learned_cls:
+             cls_vectors = self.cls_init_state
+             self.cls.copy_(cls_vectors)
+
+     def reset_classnames(self, classnames, arch):
+         self.n_cls = len(classnames)
+         if not self.learned_cls:
+             classnames = [name.replace("_", " ") for name in classnames]
+             name_lens = [len(_tokenizer.encode(name)) for name in classnames]
+             prompts = [self.prompt_prefix + " " + name + "." for name in classnames]
+         else:
+             cls_vectors = torch.empty(self.n_cls, 1, self.ctx_dim, dtype=self.dtype)
+             nn.init.normal_(cls_vectors, std=0.02)
+             cls_token = "X"
+             name_lens = [1 for _ in classnames]
+             prompts = [self.prompt_prefix + " " + cls_token + "." for _ in classnames]
+             # TODO: re-init the cls parameters
+             # self.cls = nn.Parameter(cls_vectors)  # to be optimized
+             self.cls_init_state = cls_vectors.detach().clone()
+         tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device)
+
+         clip, _, _ = load(arch, device=self.device, download_root=DOWNLOAD_ROOT)
+
+         with torch.no_grad():
+             embedding = clip.token_embedding(tokenized_prompts).type(self.dtype)
+
+         self.token_prefix = embedding[:, :1, :]
+         self.token_suffix = embedding[:, 1 + self.n_ctx:, :]  # CLS, EOS
+
+         self.name_lens = name_lens
+         self.tokenized_prompts = tokenized_prompts
+         self.classnames = classnames
+
+     def forward(self, init=None):
+         if init is not None:
+             ctx = init
+         else:
+             ctx = self.ctx
+         if ctx.dim() == 2:
+             ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
+         elif ctx.size(0) != self.n_cls:
+             ctx = ctx.unsqueeze(1).expand(-1, self.n_cls, -1, -1)
+
+         prefix = self.token_prefix
+         suffix = self.token_suffix
+
+         if self.learned_cls:
+             assert self.class_token_position == "end"
+         if self.class_token_position == "end":
+             if self.learned_cls:
+                 cls = self.cls
+                 prompts = torch.cat(
+                     [
+                         prefix,  # (n_cls, 1, dim)
+                         ctx.to(self.device),  # (n_cls, n_ctx, dim)
+                         cls.to(self.device),  # (n_cls, 1, dim)
+                         suffix,  # (n_cls, *, dim)
+                     ],
+                     dim=-2,
+                 )
+             else:
+                 prompts = torch.cat(
+                     [
+                         prefix,  # (n_cls, 1, dim)
+                         ctx.to(self.device),  # (n_cls, n_ctx, dim)
+                         suffix,  # (n_cls, *, dim)
+                     ],
+                     dim=-2,
+                 )
+         elif self.class_token_position == "middle":
+             # TODO: to work with a batch of prompts
+             if self.split_idx is not None:
+                 half_n_ctx = self.split_idx  # split the ctx at the position of [CLS] in `ctx_init`
+             else:
+                 half_n_ctx = self.n_ctx // 2
+             prompts = []
+             for i in range(self.n_cls):
+                 name_len = self.name_lens[i]
+                 prefix_i = prefix[i: i + 1, :, :]
+                 class_i = suffix[i: i + 1, :name_len, :]
+                 suffix_i = suffix[i: i + 1, name_len:, :]
+                 ctx_i_half1 = ctx[i: i + 1, :half_n_ctx, :]
+                 ctx_i_half2 = ctx[i: i + 1, half_n_ctx:, :]
+                 prompt = torch.cat(
+                     [
+                         prefix_i,  # (1, 1, dim)
+                         ctx_i_half1.to(self.device),  # (1, n_ctx//2, dim)
+                         class_i.to(self.device),  # (1, name_len, dim)
+                         ctx_i_half2.to(self.device),  # (1, n_ctx//2, dim)
+                         suffix_i,  # (1, *, dim)
+                     ],
+                     dim=1,
+                 )
+                 prompts.append(prompt)
+             prompts = torch.cat(prompts, dim=0)
+
+         elif self.class_token_position == "front":
+             prompts = []
+             for i in range(self.n_cls):
+                 name_len = self.name_lens[i]
+                 prefix_i = prefix[i: i + 1, :, :]
+                 class_i = suffix[i: i + 1, :name_len, :]
+                 suffix_i = suffix[i: i + 1, name_len:, :]
+                 ctx_i = ctx[i: i + 1, :, :]
+                 prompt = torch.cat(
+                     [
+                         prefix_i,  # (1, 1, dim)
+                         class_i.to(self.device),  # (1, name_len, dim)
+                         ctx_i.to(self.device),  # (1, n_ctx, dim)
+                         suffix_i,  # (1, *, dim)
+                     ],
+                     dim=1,
+                 )
+                 prompts.append(prompt)
+             prompts = torch.cat(prompts, dim=0)
+
+         else:
+             raise ValueError
+
+         return prompts
+
+
+ class ClipPromptLearner(nn.Module):
+     def __init__(self, args,
+                  classnames, ex_classnames,
+                  criterion='cosine', arch="ViT-B/16", device='cuda',
+                  n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False,
+                  n_ex_ctx=16, ex_ctx_init=None, ex_ctx_position='end', ex_learned_cls=True):
+         super(ClipPromptLearner, self).__init__()
+         clip, _, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT)
+
+         self.image_encoder = clip.visual
+
+         self.text_encoder = TextEncoder(clip)
+         self.text_encoder = nn.parallel.DataParallel(self.text_encoder).to(torch.device("cuda"))  # for multi-GPU
+         self.logit_scale = clip.logit_scale.data
+         # prompt learners for the ID classes and the exclusive (OOD) classes
+         self.prompt_learner = PromptLearner(clip, classnames, n_ctx,
+                                             ctx_init, ctx_position, learned_cls=learned_cls)
+         self.ex_prompt_learner = PromptLearner(clip, ex_classnames, n_ex_ctx,
+                                                ex_ctx_init, ex_ctx_position, learned_cls=ex_learned_cls)
+
+         self.criterion = criterion
+
+     @property
+     def dtype(self):
+         return self.image_encoder.conv1.weight.dtype
+
+     def get_text_features(self):
+         prompts = torch.cat((self.prompt_learner(),
+                              self.ex_prompt_learner()), dim=0)
+         tokenized_prompts = torch.cat((self.prompt_learner.tokenized_prompts,
+                                        self.ex_prompt_learner.tokenized_prompts), dim=0)
+         prompts = prompts.cuda(non_blocking=True)  # for multi-GPU
+         tokenized_prompts = tokenized_prompts.cuda(non_blocking=True)  # for multi-GPU
+         t_features = self.text_encoder(prompts, tokenized_prompts)
+         return t_features
+
+     def get_image_features(self, image):
+         image_features = self.image_encoder(image.type(self.dtype))
+         return image_features
+
+     def forward(self, image):
+         with torch.no_grad():
+             image_features = self.image_encoder(image.type(self.dtype))
+
+         text_features = self.get_text_features()
+         image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+
+         logit_scale = self.logit_scale.exp()
+         logits = logit_scale * image_features @ text_features.t()
+
+         return logits
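A minimal sketch of instantiating the prompt learner directly (not part of the commit; the class names are made up, args is passed as None since the shown constructor does not use it, and the config fields mirror the keyword arguments used by get_prompts/load_model in utils/id_like.py):

    import config
    from utils.id_like_utils import ClipPromptLearner

    labels = ['cat', 'dog']                     # hypothetical ID classes
    ex_labels = [f'ood_{i}' for i in range(3)]  # slots for the learnable exclusive prompts
    model = ClipPromptLearner(None, classnames=labels, ex_classnames=ex_labels,
                              arch='ViT-B/16', device='cuda',
                              n_ctx=config.n_ctx, ctx_init=config.ctx_init,
                              ctx_position=config.ctx_position, learned_cls=config.learned_cls,
                              n_ex_ctx=config.n_ex_ctx, ex_ctx_init=config.ex_ctx_init,
                              ex_ctx_position=config.ex_ctx_position, ex_learned_cls=config.ex_learned_cls)
    text_features = model.get_text_features()   # (len(labels) + len(ex_labels), 512) for ViT-B/16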