qic999 commited on
Commit
948ad69
·
verified ·
1 Parent(s): 9999797

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +0 -0
  2. var/D3HR/DiT-XL-2-256x256.pt +3 -0
  3. var/D3HR/DiT-XL/.gitattributes +34 -0
  4. var/D3HR/DiT-XL/README.md +10 -0
  5. var/D3HR/DiT-XL/model_index.json +1018 -0
  6. var/D3HR/DiT-XL/scheduler/scheduler_config.json +13 -0
  7. var/D3HR/DiT-XL/transformer/config.json +23 -0
  8. var/D3HR/DiT-XL/transformer/diffusion_pytorch_model.bin +3 -0
  9. var/D3HR/DiT-XL/vae/config.json +30 -0
  10. var/D3HR/DiT-XL/vae/diffusion_pytorch_model.bin +3 -0
  11. var/D3HR/README.md +98 -0
  12. var/D3HR/ds_inf/imagenet1k_train.txt +3 -0
  13. var/D3HR/ds_inf/imagenet_1k_mapping.json +0 -0
  14. var/D3HR/ds_inf/tiny-imagenet-mapping.txt +200 -0
  15. var/D3HR/generation/__init__.py +0 -0
  16. var/D3HR/generation/__pycache__/__init__.cpython-310.pyc +0 -0
  17. var/D3HR/generation/__pycache__/dit_inversion_save_statistic.cpython-310.pyc +0 -0
  18. var/D3HR/generation/dit_inversion_save_statistic.py +437 -0
  19. var/D3HR/generation/dit_inversion_save_statistic.sh +24 -0
  20. var/D3HR/generation/group_sampling.py +368 -0
  21. var/D3HR/generation/group_sampling.sh +20 -0
  22. var/D3HR/imgs/framework.jpg +3 -0
  23. var/D3HR/imgs/framework.pdf +3 -0
  24. var/D3HR/requirements.txt +9 -0
  25. var/D3HR/validation/__pycache__/argument.cpython-310.pyc +0 -0
  26. var/D3HR/validation/argument.py +310 -0
  27. var/D3HR/validation/get_train_list.py +26 -0
  28. var/D3HR/validation/models/__init__.py +190 -0
  29. var/D3HR/validation/models/__pycache__/__init__.cpython-310.pyc +0 -0
  30. var/D3HR/validation/models/__pycache__/__init__.cpython-37.pyc +0 -0
  31. var/D3HR/validation/models/__pycache__/convnet.cpython-310.pyc +0 -0
  32. var/D3HR/validation/models/__pycache__/convnet.cpython-37.pyc +0 -0
  33. var/D3HR/validation/models/__pycache__/mobilenet_v2.cpython-310.pyc +0 -0
  34. var/D3HR/validation/models/__pycache__/mobilenet_v2.cpython-37.pyc +0 -0
  35. var/D3HR/validation/models/__pycache__/resnet.cpython-310.pyc +0 -0
  36. var/D3HR/validation/models/__pycache__/resnet.cpython-37.pyc +0 -0
  37. var/D3HR/validation/models/convnet.py +147 -0
  38. var/D3HR/validation/models/dit_models.py +438 -0
  39. var/D3HR/validation/models/mobilenet_v2.py +151 -0
  40. var/D3HR/validation/models/pipeline_stable_unclip_img2img.py +854 -0
  41. var/D3HR/validation/models/resnet.py +80 -0
  42. var/D3HR/validation/models/scheduling_ddim.py +522 -0
  43. var/D3HR/validation/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
  44. var/D3HR/validation/utils/__pycache__/data_utils.cpython-37.pyc +0 -0
  45. var/D3HR/validation/utils/__pycache__/validate_utils.cpython-310.pyc +0 -0
  46. var/D3HR/validation/utils/__pycache__/validate_utils.cpython-37.pyc +0 -0
  47. var/D3HR/validation/utils/data_utils.py +431 -0
  48. var/D3HR/validation/utils/download.py +50 -0
  49. var/D3HR/validation/utils/syn_utils_dit.py +172 -0
  50. var/D3HR/validation/utils/syn_utils_img2img.py +134 -0
.gitattributes CHANGED
The diff for this file is too large to render. See raw diff
 
var/D3HR/DiT-XL-2-256x256.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ec1876e4c03471bca126663a30e2d1b20610b6d2f87850a39a36f25cc685521
3
+ size 2700611775
var/D3HR/DiT-XL/.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
var/D3HR/DiT-XL/README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-nc-4.0
3
+ ---
4
+
5
+ # Scalable Diffusion Models with Transformers (DiT)
6
+
7
+ ## Abstract
8
+
9
+ We train latent diffusion models, replacing the commonly-used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops---through increased transformer depth/width or increased number of input tokens---consistently have lower FID. In addition to good scalability properties, our DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.
10
+
var/D3HR/DiT-XL/model_index.json ADDED
@@ -0,0 +1,1018 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DiTPipeline",
3
+ "_diffusers_version": "0.12.0.dev0",
4
+ "scheduler": [
5
+ "diffusers",
6
+ "DDIMScheduler"
7
+ ],
8
+ "transformer": [
9
+ "diffusers",
10
+ "Transformer2DModel"
11
+ ],
12
+ "vae": [
13
+ "diffusers",
14
+ "AutoencoderKL"
15
+ ],
16
+ "id2label": {
17
+ "0": "tench, Tinca tinca",
18
+ "1": "goldfish, Carassius auratus",
19
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
20
+ "3": "tiger shark, Galeocerdo cuvieri",
21
+ "4": "hammerhead, hammerhead shark",
22
+ "5": "electric ray, crampfish, numbfish, torpedo",
23
+ "6": "stingray",
24
+ "7": "cock",
25
+ "8": "hen",
26
+ "9": "ostrich, Struthio camelus",
27
+ "10": "brambling, Fringilla montifringilla",
28
+ "11": "goldfinch, Carduelis carduelis",
29
+ "12": "house finch, linnet, Carpodacus mexicanus",
30
+ "13": "junco, snowbird",
31
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
32
+ "15": "robin, American robin, Turdus migratorius",
33
+ "16": "bulbul",
34
+ "17": "jay",
35
+ "18": "magpie",
36
+ "19": "chickadee",
37
+ "20": "water ouzel, dipper",
38
+ "21": "kite",
39
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
40
+ "23": "vulture",
41
+ "24": "great grey owl, great gray owl, Strix nebulosa",
42
+ "25": "European fire salamander, Salamandra salamandra",
43
+ "26": "common newt, Triturus vulgaris",
44
+ "27": "eft",
45
+ "28": "spotted salamander, Ambystoma maculatum",
46
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
47
+ "30": "bullfrog, Rana catesbeiana",
48
+ "31": "tree frog, tree-frog",
49
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
50
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
51
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
52
+ "35": "mud turtle",
53
+ "36": "terrapin",
54
+ "37": "box turtle, box tortoise",
55
+ "38": "banded gecko",
56
+ "39": "common iguana, iguana, Iguana iguana",
57
+ "40": "American chameleon, anole, Anolis carolinensis",
58
+ "41": "whiptail, whiptail lizard",
59
+ "42": "agama",
60
+ "43": "frilled lizard, Chlamydosaurus kingi",
61
+ "44": "alligator lizard",
62
+ "45": "Gila monster, Heloderma suspectum",
63
+ "46": "green lizard, Lacerta viridis",
64
+ "47": "African chameleon, Chamaeleo chamaeleon",
65
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
66
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
67
+ "50": "American alligator, Alligator mississipiensis",
68
+ "51": "triceratops",
69
+ "52": "thunder snake, worm snake, Carphophis amoenus",
70
+ "53": "ringneck snake, ring-necked snake, ring snake",
71
+ "54": "hognose snake, puff adder, sand viper",
72
+ "55": "green snake, grass snake",
73
+ "56": "king snake, kingsnake",
74
+ "57": "garter snake, grass snake",
75
+ "58": "water snake",
76
+ "59": "vine snake",
77
+ "60": "night snake, Hypsiglena torquata",
78
+ "61": "boa constrictor, Constrictor constrictor",
79
+ "62": "rock python, rock snake, Python sebae",
80
+ "63": "Indian cobra, Naja naja",
81
+ "64": "green mamba",
82
+ "65": "sea snake",
83
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
84
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
85
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
86
+ "69": "trilobite",
87
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
88
+ "71": "scorpion",
89
+ "72": "black and gold garden spider, Argiope aurantia",
90
+ "73": "barn spider, Araneus cavaticus",
91
+ "74": "garden spider, Aranea diademata",
92
+ "75": "black widow, Latrodectus mactans",
93
+ "76": "tarantula",
94
+ "77": "wolf spider, hunting spider",
95
+ "78": "tick",
96
+ "79": "centipede",
97
+ "80": "black grouse",
98
+ "81": "ptarmigan",
99
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
100
+ "83": "prairie chicken, prairie grouse, prairie fowl",
101
+ "84": "peacock",
102
+ "85": "quail",
103
+ "86": "partridge",
104
+ "87": "African grey, African gray, Psittacus erithacus",
105
+ "88": "macaw",
106
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
107
+ "90": "lorikeet",
108
+ "91": "coucal",
109
+ "92": "bee eater",
110
+ "93": "hornbill",
111
+ "94": "hummingbird",
112
+ "95": "jacamar",
113
+ "96": "toucan",
114
+ "97": "drake",
115
+ "98": "red-breasted merganser, Mergus serrator",
116
+ "99": "goose",
117
+ "100": "black swan, Cygnus atratus",
118
+ "101": "tusker",
119
+ "102": "echidna, spiny anteater, anteater",
120
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
121
+ "104": "wallaby, brush kangaroo",
122
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
123
+ "106": "wombat",
124
+ "107": "jellyfish",
125
+ "108": "sea anemone, anemone",
126
+ "109": "brain coral",
127
+ "110": "flatworm, platyhelminth",
128
+ "111": "nematode, nematode worm, roundworm",
129
+ "112": "conch",
130
+ "113": "snail",
131
+ "114": "slug",
132
+ "115": "sea slug, nudibranch",
133
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
134
+ "117": "chambered nautilus, pearly nautilus, nautilus",
135
+ "118": "Dungeness crab, Cancer magister",
136
+ "119": "rock crab, Cancer irroratus",
137
+ "120": "fiddler crab",
138
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
139
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
140
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
141
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
142
+ "125": "hermit crab",
143
+ "126": "isopod",
144
+ "127": "white stork, Ciconia ciconia",
145
+ "128": "black stork, Ciconia nigra",
146
+ "129": "spoonbill",
147
+ "130": "flamingo",
148
+ "131": "little blue heron, Egretta caerulea",
149
+ "132": "American egret, great white heron, Egretta albus",
150
+ "133": "bittern",
151
+ "134": "crane",
152
+ "135": "limpkin, Aramus pictus",
153
+ "136": "European gallinule, Porphyrio porphyrio",
154
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
155
+ "138": "bustard",
156
+ "139": "ruddy turnstone, Arenaria interpres",
157
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
158
+ "141": "redshank, Tringa totanus",
159
+ "142": "dowitcher",
160
+ "143": "oystercatcher, oyster catcher",
161
+ "144": "pelican",
162
+ "145": "king penguin, Aptenodytes patagonica",
163
+ "146": "albatross, mollymawk",
164
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
165
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
166
+ "149": "dugong, Dugong dugon",
167
+ "150": "sea lion",
168
+ "151": "Chihuahua",
169
+ "152": "Japanese spaniel",
170
+ "153": "Maltese dog, Maltese terrier, Maltese",
171
+ "154": "Pekinese, Pekingese, Peke",
172
+ "155": "Shih-Tzu",
173
+ "156": "Blenheim spaniel",
174
+ "157": "papillon",
175
+ "158": "toy terrier",
176
+ "159": "Rhodesian ridgeback",
177
+ "160": "Afghan hound, Afghan",
178
+ "161": "basset, basset hound",
179
+ "162": "beagle",
180
+ "163": "bloodhound, sleuthhound",
181
+ "164": "bluetick",
182
+ "165": "black-and-tan coonhound",
183
+ "166": "Walker hound, Walker foxhound",
184
+ "167": "English foxhound",
185
+ "168": "redbone",
186
+ "169": "borzoi, Russian wolfhound",
187
+ "170": "Irish wolfhound",
188
+ "171": "Italian greyhound",
189
+ "172": "whippet",
190
+ "173": "Ibizan hound, Ibizan Podenco",
191
+ "174": "Norwegian elkhound, elkhound",
192
+ "175": "otterhound, otter hound",
193
+ "176": "Saluki, gazelle hound",
194
+ "177": "Scottish deerhound, deerhound",
195
+ "178": "Weimaraner",
196
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
197
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
198
+ "181": "Bedlington terrier",
199
+ "182": "Border terrier",
200
+ "183": "Kerry blue terrier",
201
+ "184": "Irish terrier",
202
+ "185": "Norfolk terrier",
203
+ "186": "Norwich terrier",
204
+ "187": "Yorkshire terrier",
205
+ "188": "wire-haired fox terrier",
206
+ "189": "Lakeland terrier",
207
+ "190": "Sealyham terrier, Sealyham",
208
+ "191": "Airedale, Airedale terrier",
209
+ "192": "cairn, cairn terrier",
210
+ "193": "Australian terrier",
211
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
212
+ "195": "Boston bull, Boston terrier",
213
+ "196": "miniature schnauzer",
214
+ "197": "giant schnauzer",
215
+ "198": "standard schnauzer",
216
+ "199": "Scotch terrier, Scottish terrier, Scottie",
217
+ "200": "Tibetan terrier, chrysanthemum dog",
218
+ "201": "silky terrier, Sydney silky",
219
+ "202": "soft-coated wheaten terrier",
220
+ "203": "West Highland white terrier",
221
+ "204": "Lhasa, Lhasa apso",
222
+ "205": "flat-coated retriever",
223
+ "206": "curly-coated retriever",
224
+ "207": "golden retriever",
225
+ "208": "Labrador retriever",
226
+ "209": "Chesapeake Bay retriever",
227
+ "210": "German short-haired pointer",
228
+ "211": "vizsla, Hungarian pointer",
229
+ "212": "English setter",
230
+ "213": "Irish setter, red setter",
231
+ "214": "Gordon setter",
232
+ "215": "Brittany spaniel",
233
+ "216": "clumber, clumber spaniel",
234
+ "217": "English springer, English springer spaniel",
235
+ "218": "Welsh springer spaniel",
236
+ "219": "cocker spaniel, English cocker spaniel, cocker",
237
+ "220": "Sussex spaniel",
238
+ "221": "Irish water spaniel",
239
+ "222": "kuvasz",
240
+ "223": "schipperke",
241
+ "224": "groenendael",
242
+ "225": "malinois",
243
+ "226": "briard",
244
+ "227": "kelpie",
245
+ "228": "komondor",
246
+ "229": "Old English sheepdog, bobtail",
247
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
248
+ "231": "collie",
249
+ "232": "Border collie",
250
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
251
+ "234": "Rottweiler",
252
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
253
+ "236": "Doberman, Doberman pinscher",
254
+ "237": "miniature pinscher",
255
+ "238": "Greater Swiss Mountain dog",
256
+ "239": "Bernese mountain dog",
257
+ "240": "Appenzeller",
258
+ "241": "EntleBucher",
259
+ "242": "boxer",
260
+ "243": "bull mastiff",
261
+ "244": "Tibetan mastiff",
262
+ "245": "French bulldog",
263
+ "246": "Great Dane",
264
+ "247": "Saint Bernard, St Bernard",
265
+ "248": "Eskimo dog, husky",
266
+ "249": "malamute, malemute, Alaskan malamute",
267
+ "250": "Siberian husky",
268
+ "251": "dalmatian, coach dog, carriage dog",
269
+ "252": "affenpinscher, monkey pinscher, monkey dog",
270
+ "253": "basenji",
271
+ "254": "pug, pug-dog",
272
+ "255": "Leonberg",
273
+ "256": "Newfoundland, Newfoundland dog",
274
+ "257": "Great Pyrenees",
275
+ "258": "Samoyed, Samoyede",
276
+ "259": "Pomeranian",
277
+ "260": "chow, chow chow",
278
+ "261": "keeshond",
279
+ "262": "Brabancon griffon",
280
+ "263": "Pembroke, Pembroke Welsh corgi",
281
+ "264": "Cardigan, Cardigan Welsh corgi",
282
+ "265": "toy poodle",
283
+ "266": "miniature poodle",
284
+ "267": "standard poodle",
285
+ "268": "Mexican hairless",
286
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
287
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
288
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
289
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
290
+ "273": "dingo, warrigal, warragal, Canis dingo",
291
+ "274": "dhole, Cuon alpinus",
292
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
293
+ "276": "hyena, hyaena",
294
+ "277": "red fox, Vulpes vulpes",
295
+ "278": "kit fox, Vulpes macrotis",
296
+ "279": "Arctic fox, white fox, Alopex lagopus",
297
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
298
+ "281": "tabby, tabby cat",
299
+ "282": "tiger cat",
300
+ "283": "Persian cat",
301
+ "284": "Siamese cat, Siamese",
302
+ "285": "Egyptian cat",
303
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
304
+ "287": "lynx, catamount",
305
+ "288": "leopard, Panthera pardus",
306
+ "289": "snow leopard, ounce, Panthera uncia",
307
+ "290": "jaguar, panther, Panthera onca, Felis onca",
308
+ "291": "lion, king of beasts, Panthera leo",
309
+ "292": "tiger, Panthera tigris",
310
+ "293": "cheetah, chetah, Acinonyx jubatus",
311
+ "294": "brown bear, bruin, Ursus arctos",
312
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
313
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
314
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
315
+ "298": "mongoose",
316
+ "299": "meerkat, mierkat",
317
+ "300": "tiger beetle",
318
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
319
+ "302": "ground beetle, carabid beetle",
320
+ "303": "long-horned beetle, longicorn, longicorn beetle",
321
+ "304": "leaf beetle, chrysomelid",
322
+ "305": "dung beetle",
323
+ "306": "rhinoceros beetle",
324
+ "307": "weevil",
325
+ "308": "fly",
326
+ "309": "bee",
327
+ "310": "ant, emmet, pismire",
328
+ "311": "grasshopper, hopper",
329
+ "312": "cricket",
330
+ "313": "walking stick, walkingstick, stick insect",
331
+ "314": "cockroach, roach",
332
+ "315": "mantis, mantid",
333
+ "316": "cicada, cicala",
334
+ "317": "leafhopper",
335
+ "318": "lacewing, lacewing fly",
336
+ "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
337
+ "320": "damselfly",
338
+ "321": "admiral",
339
+ "322": "ringlet, ringlet butterfly",
340
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
341
+ "324": "cabbage butterfly",
342
+ "325": "sulphur butterfly, sulfur butterfly",
343
+ "326": "lycaenid, lycaenid butterfly",
344
+ "327": "starfish, sea star",
345
+ "328": "sea urchin",
346
+ "329": "sea cucumber, holothurian",
347
+ "330": "wood rabbit, cottontail, cottontail rabbit",
348
+ "331": "hare",
349
+ "332": "Angora, Angora rabbit",
350
+ "333": "hamster",
351
+ "334": "porcupine, hedgehog",
352
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
353
+ "336": "marmot",
354
+ "337": "beaver",
355
+ "338": "guinea pig, Cavia cobaya",
356
+ "339": "sorrel",
357
+ "340": "zebra",
358
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
359
+ "342": "wild boar, boar, Sus scrofa",
360
+ "343": "warthog",
361
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
362
+ "345": "ox",
363
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
364
+ "347": "bison",
365
+ "348": "ram, tup",
366
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
367
+ "350": "ibex, Capra ibex",
368
+ "351": "hartebeest",
369
+ "352": "impala, Aepyceros melampus",
370
+ "353": "gazelle",
371
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
372
+ "355": "llama",
373
+ "356": "weasel",
374
+ "357": "mink",
375
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
376
+ "359": "black-footed ferret, ferret, Mustela nigripes",
377
+ "360": "otter",
378
+ "361": "skunk, polecat, wood pussy",
379
+ "362": "badger",
380
+ "363": "armadillo",
381
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
382
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
383
+ "366": "gorilla, Gorilla gorilla",
384
+ "367": "chimpanzee, chimp, Pan troglodytes",
385
+ "368": "gibbon, Hylobates lar",
386
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
387
+ "370": "guenon, guenon monkey",
388
+ "371": "patas, hussar monkey, Erythrocebus patas",
389
+ "372": "baboon",
390
+ "373": "macaque",
391
+ "374": "langur",
392
+ "375": "colobus, colobus monkey",
393
+ "376": "proboscis monkey, Nasalis larvatus",
394
+ "377": "marmoset",
395
+ "378": "capuchin, ringtail, Cebus capucinus",
396
+ "379": "howler monkey, howler",
397
+ "380": "titi, titi monkey",
398
+ "381": "spider monkey, Ateles geoffroyi",
399
+ "382": "squirrel monkey, Saimiri sciureus",
400
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
401
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
402
+ "385": "Indian elephant, Elephas maximus",
403
+ "386": "African elephant, Loxodonta africana",
404
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
405
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
406
+ "389": "barracouta, snoek",
407
+ "390": "eel",
408
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
409
+ "392": "rock beauty, Holocanthus tricolor",
410
+ "393": "anemone fish",
411
+ "394": "sturgeon",
412
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
413
+ "396": "lionfish",
414
+ "397": "puffer, pufferfish, blowfish, globefish",
415
+ "398": "abacus",
416
+ "399": "abaya",
417
+ "400": "academic gown, academic robe, judge's robe",
418
+ "401": "accordion, piano accordion, squeeze box",
419
+ "402": "acoustic guitar",
420
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
421
+ "404": "airliner",
422
+ "405": "airship, dirigible",
423
+ "406": "altar",
424
+ "407": "ambulance",
425
+ "408": "amphibian, amphibious vehicle",
426
+ "409": "analog clock",
427
+ "410": "apiary, bee house",
428
+ "411": "apron",
429
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
430
+ "413": "assault rifle, assault gun",
431
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
432
+ "415": "bakery, bakeshop, bakehouse",
433
+ "416": "balance beam, beam",
434
+ "417": "balloon",
435
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
436
+ "419": "Band Aid",
437
+ "420": "banjo",
438
+ "421": "bannister, banister, balustrade, balusters, handrail",
439
+ "422": "barbell",
440
+ "423": "barber chair",
441
+ "424": "barbershop",
442
+ "425": "barn",
443
+ "426": "barometer",
444
+ "427": "barrel, cask",
445
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
446
+ "429": "baseball",
447
+ "430": "basketball",
448
+ "431": "bassinet",
449
+ "432": "bassoon",
450
+ "433": "bathing cap, swimming cap",
451
+ "434": "bath towel",
452
+ "435": "bathtub, bathing tub, bath, tub",
453
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
454
+ "437": "beacon, lighthouse, beacon light, pharos",
455
+ "438": "beaker",
456
+ "439": "bearskin, busby, shako",
457
+ "440": "beer bottle",
458
+ "441": "beer glass",
459
+ "442": "bell cote, bell cot",
460
+ "443": "bib",
461
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
462
+ "445": "bikini, two-piece",
463
+ "446": "binder, ring-binder",
464
+ "447": "binoculars, field glasses, opera glasses",
465
+ "448": "birdhouse",
466
+ "449": "boathouse",
467
+ "450": "bobsled, bobsleigh, bob",
468
+ "451": "bolo tie, bolo, bola tie, bola",
469
+ "452": "bonnet, poke bonnet",
470
+ "453": "bookcase",
471
+ "454": "bookshop, bookstore, bookstall",
472
+ "455": "bottlecap",
473
+ "456": "bow",
474
+ "457": "bow tie, bow-tie, bowtie",
475
+ "458": "brass, memorial tablet, plaque",
476
+ "459": "brassiere, bra, bandeau",
477
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
478
+ "461": "breastplate, aegis, egis",
479
+ "462": "broom",
480
+ "463": "bucket, pail",
481
+ "464": "buckle",
482
+ "465": "bulletproof vest",
483
+ "466": "bullet train, bullet",
484
+ "467": "butcher shop, meat market",
485
+ "468": "cab, hack, taxi, taxicab",
486
+ "469": "caldron, cauldron",
487
+ "470": "candle, taper, wax light",
488
+ "471": "cannon",
489
+ "472": "canoe",
490
+ "473": "can opener, tin opener",
491
+ "474": "cardigan",
492
+ "475": "car mirror",
493
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
494
+ "477": "carpenter's kit, tool kit",
495
+ "478": "carton",
496
+ "479": "car wheel",
497
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
498
+ "481": "cassette",
499
+ "482": "cassette player",
500
+ "483": "castle",
501
+ "484": "catamaran",
502
+ "485": "CD player",
503
+ "486": "cello, violoncello",
504
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
505
+ "488": "chain",
506
+ "489": "chainlink fence",
507
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
508
+ "491": "chain saw, chainsaw",
509
+ "492": "chest",
510
+ "493": "chiffonier, commode",
511
+ "494": "chime, bell, gong",
512
+ "495": "china cabinet, china closet",
513
+ "496": "Christmas stocking",
514
+ "497": "church, church building",
515
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
516
+ "499": "cleaver, meat cleaver, chopper",
517
+ "500": "cliff dwelling",
518
+ "501": "cloak",
519
+ "502": "clog, geta, patten, sabot",
520
+ "503": "cocktail shaker",
521
+ "504": "coffee mug",
522
+ "505": "coffeepot",
523
+ "506": "coil, spiral, volute, whorl, helix",
524
+ "507": "combination lock",
525
+ "508": "computer keyboard, keypad",
526
+ "509": "confectionery, confectionary, candy store",
527
+ "510": "container ship, containership, container vessel",
528
+ "511": "convertible",
529
+ "512": "corkscrew, bottle screw",
530
+ "513": "cornet, horn, trumpet, trump",
531
+ "514": "cowboy boot",
532
+ "515": "cowboy hat, ten-gallon hat",
533
+ "516": "cradle",
534
+ "517": "crane",
535
+ "518": "crash helmet",
536
+ "519": "crate",
537
+ "520": "crib, cot",
538
+ "521": "Crock Pot",
539
+ "522": "croquet ball",
540
+ "523": "crutch",
541
+ "524": "cuirass",
542
+ "525": "dam, dike, dyke",
543
+ "526": "desk",
544
+ "527": "desktop computer",
545
+ "528": "dial telephone, dial phone",
546
+ "529": "diaper, nappy, napkin",
547
+ "530": "digital clock",
548
+ "531": "digital watch",
549
+ "532": "dining table, board",
550
+ "533": "dishrag, dishcloth",
551
+ "534": "dishwasher, dish washer, dishwashing machine",
552
+ "535": "disk brake, disc brake",
553
+ "536": "dock, dockage, docking facility",
554
+ "537": "dogsled, dog sled, dog sleigh",
555
+ "538": "dome",
556
+ "539": "doormat, welcome mat",
557
+ "540": "drilling platform, offshore rig",
558
+ "541": "drum, membranophone, tympan",
559
+ "542": "drumstick",
560
+ "543": "dumbbell",
561
+ "544": "Dutch oven",
562
+ "545": "electric fan, blower",
563
+ "546": "electric guitar",
564
+ "547": "electric locomotive",
565
+ "548": "entertainment center",
566
+ "549": "envelope",
567
+ "550": "espresso maker",
568
+ "551": "face powder",
569
+ "552": "feather boa, boa",
570
+ "553": "file, file cabinet, filing cabinet",
571
+ "554": "fireboat",
572
+ "555": "fire engine, fire truck",
573
+ "556": "fire screen, fireguard",
574
+ "557": "flagpole, flagstaff",
575
+ "558": "flute, transverse flute",
576
+ "559": "folding chair",
577
+ "560": "football helmet",
578
+ "561": "forklift",
579
+ "562": "fountain",
580
+ "563": "fountain pen",
581
+ "564": "four-poster",
582
+ "565": "freight car",
583
+ "566": "French horn, horn",
584
+ "567": "frying pan, frypan, skillet",
585
+ "568": "fur coat",
586
+ "569": "garbage truck, dustcart",
587
+ "570": "gasmask, respirator, gas helmet",
588
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
589
+ "572": "goblet",
590
+ "573": "go-kart",
591
+ "574": "golf ball",
592
+ "575": "golfcart, golf cart",
593
+ "576": "gondola",
594
+ "577": "gong, tam-tam",
595
+ "578": "gown",
596
+ "579": "grand piano, grand",
597
+ "580": "greenhouse, nursery, glasshouse",
598
+ "581": "grille, radiator grille",
599
+ "582": "grocery store, grocery, food market, market",
600
+ "583": "guillotine",
601
+ "584": "hair slide",
602
+ "585": "hair spray",
603
+ "586": "half track",
604
+ "587": "hammer",
605
+ "588": "hamper",
606
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
607
+ "590": "hand-held computer, hand-held microcomputer",
608
+ "591": "handkerchief, hankie, hanky, hankey",
609
+ "592": "hard disc, hard disk, fixed disk",
610
+ "593": "harmonica, mouth organ, harp, mouth harp",
611
+ "594": "harp",
612
+ "595": "harvester, reaper",
613
+ "596": "hatchet",
614
+ "597": "holster",
615
+ "598": "home theater, home theatre",
616
+ "599": "honeycomb",
617
+ "600": "hook, claw",
618
+ "601": "hoopskirt, crinoline",
619
+ "602": "horizontal bar, high bar",
620
+ "603": "horse cart, horse-cart",
621
+ "604": "hourglass",
622
+ "605": "iPod",
623
+ "606": "iron, smoothing iron",
624
+ "607": "jack-o'-lantern",
625
+ "608": "jean, blue jean, denim",
626
+ "609": "jeep, landrover",
627
+ "610": "jersey, T-shirt, tee shirt",
628
+ "611": "jigsaw puzzle",
629
+ "612": "jinrikisha, ricksha, rickshaw",
630
+ "613": "joystick",
631
+ "614": "kimono",
632
+ "615": "knee pad",
633
+ "616": "knot",
634
+ "617": "lab coat, laboratory coat",
635
+ "618": "ladle",
636
+ "619": "lampshade, lamp shade",
637
+ "620": "laptop, laptop computer",
638
+ "621": "lawn mower, mower",
639
+ "622": "lens cap, lens cover",
640
+ "623": "letter opener, paper knife, paperknife",
641
+ "624": "library",
642
+ "625": "lifeboat",
643
+ "626": "lighter, light, igniter, ignitor",
644
+ "627": "limousine, limo",
645
+ "628": "liner, ocean liner",
646
+ "629": "lipstick, lip rouge",
647
+ "630": "Loafer",
648
+ "631": "lotion",
649
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
650
+ "633": "loupe, jeweler's loupe",
651
+ "634": "lumbermill, sawmill",
652
+ "635": "magnetic compass",
653
+ "636": "mailbag, postbag",
654
+ "637": "mailbox, letter box",
655
+ "638": "maillot",
656
+ "639": "maillot, tank suit",
657
+ "640": "manhole cover",
658
+ "641": "maraca",
659
+ "642": "marimba, xylophone",
660
+ "643": "mask",
661
+ "644": "matchstick",
662
+ "645": "maypole",
663
+ "646": "maze, labyrinth",
664
+ "647": "measuring cup",
665
+ "648": "medicine chest, medicine cabinet",
666
+ "649": "megalith, megalithic structure",
667
+ "650": "microphone, mike",
668
+ "651": "microwave, microwave oven",
669
+ "652": "military uniform",
670
+ "653": "milk can",
671
+ "654": "minibus",
672
+ "655": "miniskirt, mini",
673
+ "656": "minivan",
674
+ "657": "missile",
675
+ "658": "mitten",
676
+ "659": "mixing bowl",
677
+ "660": "mobile home, manufactured home",
678
+ "661": "Model T",
679
+ "662": "modem",
680
+ "663": "monastery",
681
+ "664": "monitor",
682
+ "665": "moped",
683
+ "666": "mortar",
684
+ "667": "mortarboard",
685
+ "668": "mosque",
686
+ "669": "mosquito net",
687
+ "670": "motor scooter, scooter",
688
+ "671": "mountain bike, all-terrain bike, off-roader",
689
+ "672": "mountain tent",
690
+ "673": "mouse, computer mouse",
691
+ "674": "mousetrap",
692
+ "675": "moving van",
693
+ "676": "muzzle",
694
+ "677": "nail",
695
+ "678": "neck brace",
696
+ "679": "necklace",
697
+ "680": "nipple",
698
+ "681": "notebook, notebook computer",
699
+ "682": "obelisk",
700
+ "683": "oboe, hautboy, hautbois",
701
+ "684": "ocarina, sweet potato",
702
+ "685": "odometer, hodometer, mileometer, milometer",
703
+ "686": "oil filter",
704
+ "687": "organ, pipe organ",
705
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
706
+ "689": "overskirt",
707
+ "690": "oxcart",
708
+ "691": "oxygen mask",
709
+ "692": "packet",
710
+ "693": "paddle, boat paddle",
711
+ "694": "paddlewheel, paddle wheel",
712
+ "695": "padlock",
713
+ "696": "paintbrush",
714
+ "697": "pajama, pyjama, pj's, jammies",
715
+ "698": "palace",
716
+ "699": "panpipe, pandean pipe, syrinx",
717
+ "700": "paper towel",
718
+ "701": "parachute, chute",
719
+ "702": "parallel bars, bars",
720
+ "703": "park bench",
721
+ "704": "parking meter",
722
+ "705": "passenger car, coach, carriage",
723
+ "706": "patio, terrace",
724
+ "707": "pay-phone, pay-station",
725
+ "708": "pedestal, plinth, footstall",
726
+ "709": "pencil box, pencil case",
727
+ "710": "pencil sharpener",
728
+ "711": "perfume, essence",
729
+ "712": "Petri dish",
730
+ "713": "photocopier",
731
+ "714": "pick, plectrum, plectron",
732
+ "715": "pickelhaube",
733
+ "716": "picket fence, paling",
734
+ "717": "pickup, pickup truck",
735
+ "718": "pier",
736
+ "719": "piggy bank, penny bank",
737
+ "720": "pill bottle",
738
+ "721": "pillow",
739
+ "722": "ping-pong ball",
740
+ "723": "pinwheel",
741
+ "724": "pirate, pirate ship",
742
+ "725": "pitcher, ewer",
743
+ "726": "plane, carpenter's plane, woodworking plane",
744
+ "727": "planetarium",
745
+ "728": "plastic bag",
746
+ "729": "plate rack",
747
+ "730": "plow, plough",
748
+ "731": "plunger, plumber's helper",
749
+ "732": "Polaroid camera, Polaroid Land camera",
750
+ "733": "pole",
751
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
752
+ "735": "poncho",
753
+ "736": "pool table, billiard table, snooker table",
754
+ "737": "pop bottle, soda bottle",
755
+ "738": "pot, flowerpot",
756
+ "739": "potter's wheel",
757
+ "740": "power drill",
758
+ "741": "prayer rug, prayer mat",
759
+ "742": "printer",
760
+ "743": "prison, prison house",
761
+ "744": "projectile, missile",
762
+ "745": "projector",
763
+ "746": "puck, hockey puck",
764
+ "747": "punching bag, punch bag, punching ball, punchball",
765
+ "748": "purse",
766
+ "749": "quill, quill pen",
767
+ "750": "quilt, comforter, comfort, puff",
768
+ "751": "racer, race car, racing car",
769
+ "752": "racket, racquet",
770
+ "753": "radiator",
771
+ "754": "radio, wireless",
772
+ "755": "radio telescope, radio reflector",
773
+ "756": "rain barrel",
774
+ "757": "recreational vehicle, RV, R.V.",
775
+ "758": "reel",
776
+ "759": "reflex camera",
777
+ "760": "refrigerator, icebox",
778
+ "761": "remote control, remote",
779
+ "762": "restaurant, eating house, eating place, eatery",
780
+ "763": "revolver, six-gun, six-shooter",
781
+ "764": "rifle",
782
+ "765": "rocking chair, rocker",
783
+ "766": "rotisserie",
784
+ "767": "rubber eraser, rubber, pencil eraser",
785
+ "768": "rugby ball",
786
+ "769": "rule, ruler",
787
+ "770": "running shoe",
788
+ "771": "safe",
789
+ "772": "safety pin",
790
+ "773": "saltshaker, salt shaker",
791
+ "774": "sandal",
792
+ "775": "sarong",
793
+ "776": "sax, saxophone",
794
+ "777": "scabbard",
795
+ "778": "scale, weighing machine",
796
+ "779": "school bus",
797
+ "780": "schooner",
798
+ "781": "scoreboard",
799
+ "782": "screen, CRT screen",
800
+ "783": "screw",
801
+ "784": "screwdriver",
802
+ "785": "seat belt, seatbelt",
803
+ "786": "sewing machine",
804
+ "787": "shield, buckler",
805
+ "788": "shoe shop, shoe-shop, shoe store",
806
+ "789": "shoji",
807
+ "790": "shopping basket",
808
+ "791": "shopping cart",
809
+ "792": "shovel",
810
+ "793": "shower cap",
811
+ "794": "shower curtain",
812
+ "795": "ski",
813
+ "796": "ski mask",
814
+ "797": "sleeping bag",
815
+ "798": "slide rule, slipstick",
816
+ "799": "sliding door",
817
+ "800": "slot, one-armed bandit",
818
+ "801": "snorkel",
819
+ "802": "snowmobile",
820
+ "803": "snowplow, snowplough",
821
+ "804": "soap dispenser",
822
+ "805": "soccer ball",
823
+ "806": "sock",
824
+ "807": "solar dish, solar collector, solar furnace",
825
+ "808": "sombrero",
826
+ "809": "soup bowl",
827
+ "810": "space bar",
828
+ "811": "space heater",
829
+ "812": "space shuttle",
830
+ "813": "spatula",
831
+ "814": "speedboat",
832
+ "815": "spider web, spider's web",
833
+ "816": "spindle",
834
+ "817": "sports car, sport car",
835
+ "818": "spotlight, spot",
836
+ "819": "stage",
837
+ "820": "steam locomotive",
838
+ "821": "steel arch bridge",
839
+ "822": "steel drum",
840
+ "823": "stethoscope",
841
+ "824": "stole",
842
+ "825": "stone wall",
843
+ "826": "stopwatch, stop watch",
844
+ "827": "stove",
845
+ "828": "strainer",
846
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
847
+ "830": "stretcher",
848
+ "831": "studio couch, day bed",
849
+ "832": "stupa, tope",
850
+ "833": "submarine, pigboat, sub, U-boat",
851
+ "834": "suit, suit of clothes",
852
+ "835": "sundial",
853
+ "836": "sunglass",
854
+ "837": "sunglasses, dark glasses, shades",
855
+ "838": "sunscreen, sunblock, sun blocker",
856
+ "839": "suspension bridge",
857
+ "840": "swab, swob, mop",
858
+ "841": "sweatshirt",
859
+ "842": "swimming trunks, bathing trunks",
860
+ "843": "swing",
861
+ "844": "switch, electric switch, electrical switch",
862
+ "845": "syringe",
863
+ "846": "table lamp",
864
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
865
+ "848": "tape player",
866
+ "849": "teapot",
867
+ "850": "teddy, teddy bear",
868
+ "851": "television, television system",
869
+ "852": "tennis ball",
870
+ "853": "thatch, thatched roof",
871
+ "854": "theater curtain, theatre curtain",
872
+ "855": "thimble",
873
+ "856": "thresher, thrasher, threshing machine",
874
+ "857": "throne",
875
+ "858": "tile roof",
876
+ "859": "toaster",
877
+ "860": "tobacco shop, tobacconist shop, tobacconist",
878
+ "861": "toilet seat",
879
+ "862": "torch",
880
+ "863": "totem pole",
881
+ "864": "tow truck, tow car, wrecker",
882
+ "865": "toyshop",
883
+ "866": "tractor",
884
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
885
+ "868": "tray",
886
+ "869": "trench coat",
887
+ "870": "tricycle, trike, velocipede",
888
+ "871": "trimaran",
889
+ "872": "tripod",
890
+ "873": "triumphal arch",
891
+ "874": "trolleybus, trolley coach, trackless trolley",
892
+ "875": "trombone",
893
+ "876": "tub, vat",
894
+ "877": "turnstile",
895
+ "878": "typewriter keyboard",
896
+ "879": "umbrella",
897
+ "880": "unicycle, monocycle",
898
+ "881": "upright, upright piano",
899
+ "882": "vacuum, vacuum cleaner",
900
+ "883": "vase",
901
+ "884": "vault",
902
+ "885": "velvet",
903
+ "886": "vending machine",
904
+ "887": "vestment",
905
+ "888": "viaduct",
906
+ "889": "violin, fiddle",
907
+ "890": "volleyball",
908
+ "891": "waffle iron",
909
+ "892": "wall clock",
910
+ "893": "wallet, billfold, notecase, pocketbook",
911
+ "894": "wardrobe, closet, press",
912
+ "895": "warplane, military plane",
913
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
914
+ "897": "washer, automatic washer, washing machine",
915
+ "898": "water bottle",
916
+ "899": "water jug",
917
+ "900": "water tower",
918
+ "901": "whiskey jug",
919
+ "902": "whistle",
920
+ "903": "wig",
921
+ "904": "window screen",
922
+ "905": "window shade",
923
+ "906": "Windsor tie",
924
+ "907": "wine bottle",
925
+ "908": "wing",
926
+ "909": "wok",
927
+ "910": "wooden spoon",
928
+ "911": "wool, woolen, woollen",
929
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
930
+ "913": "wreck",
931
+ "914": "yawl",
932
+ "915": "yurt",
933
+ "916": "web site, website, internet site, site",
934
+ "917": "comic book",
935
+ "918": "crossword puzzle, crossword",
936
+ "919": "street sign",
937
+ "920": "traffic light, traffic signal, stoplight",
938
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
939
+ "922": "menu",
940
+ "923": "plate",
941
+ "924": "guacamole",
942
+ "925": "consomme",
943
+ "926": "hot pot, hotpot",
944
+ "927": "trifle",
945
+ "928": "ice cream, icecream",
946
+ "929": "ice lolly, lolly, lollipop, popsicle",
947
+ "930": "French loaf",
948
+ "931": "bagel, beigel",
949
+ "932": "pretzel",
950
+ "933": "cheeseburger",
951
+ "934": "hotdog, hot dog, red hot",
952
+ "935": "mashed potato",
953
+ "936": "head cabbage",
954
+ "937": "broccoli",
955
+ "938": "cauliflower",
956
+ "939": "zucchini, courgette",
957
+ "940": "spaghetti squash",
958
+ "941": "acorn squash",
959
+ "942": "butternut squash",
960
+ "943": "cucumber, cuke",
961
+ "944": "artichoke, globe artichoke",
962
+ "945": "bell pepper",
963
+ "946": "cardoon",
964
+ "947": "mushroom",
965
+ "948": "Granny Smith",
966
+ "949": "strawberry",
967
+ "950": "orange",
968
+ "951": "lemon",
969
+ "952": "fig",
970
+ "953": "pineapple, ananas",
971
+ "954": "banana",
972
+ "955": "jackfruit, jak, jack",
973
+ "956": "custard apple",
974
+ "957": "pomegranate",
975
+ "958": "hay",
976
+ "959": "carbonara",
977
+ "960": "chocolate sauce, chocolate syrup",
978
+ "961": "dough",
979
+ "962": "meat loaf, meatloaf",
980
+ "963": "pizza, pizza pie",
981
+ "964": "potpie",
982
+ "965": "burrito",
983
+ "966": "red wine",
984
+ "967": "espresso",
985
+ "968": "cup",
986
+ "969": "eggnog",
987
+ "970": "alp",
988
+ "971": "bubble",
989
+ "972": "cliff, drop, drop-off",
990
+ "973": "coral reef",
991
+ "974": "geyser",
992
+ "975": "lakeside, lakeshore",
993
+ "976": "promontory, headland, head, foreland",
994
+ "977": "sandbar, sand bar",
995
+ "978": "seashore, coast, seacoast, sea-coast",
996
+ "979": "valley, vale",
997
+ "980": "volcano",
998
+ "981": "ballplayer, baseball player",
999
+ "982": "groom, bridegroom",
1000
+ "983": "scuba diver",
1001
+ "984": "rapeseed",
1002
+ "985": "daisy",
1003
+ "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1004
+ "987": "corn",
1005
+ "988": "acorn",
1006
+ "989": "hip, rose hip, rosehip",
1007
+ "990": "buckeye, horse chestnut, conker",
1008
+ "991": "coral fungus",
1009
+ "992": "agaric",
1010
+ "993": "gyromitra",
1011
+ "994": "stinkhorn, carrion fungus",
1012
+ "995": "earthstar",
1013
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1014
+ "997": "bolete",
1015
+ "998": "ear, spike, capitulum",
1016
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1017
+ }
1018
+ }
var/D3HR/DiT-XL/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.12.0.dev0",
4
+ "beta_end": 0.02,
5
+ "beta_schedule": "linear",
6
+ "beta_start": 0.0001,
7
+ "clip_sample": false,
8
+ "num_train_timesteps": 1000,
9
+ "prediction_type": "epsilon",
10
+ "set_alpha_to_one": true,
11
+ "steps_offset": 0,
12
+ "trained_betas": null
13
+ }
var/D3HR/DiT-XL/transformer/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "Transformer2DModel",
3
+ "_diffusers_version": "0.12.0.dev0",
4
+ "activation_fn": "gelu-approximate",
5
+ "attention_bias": true,
6
+ "attention_head_dim": 72,
7
+ "cross_attention_dim": null,
8
+ "dropout": 0.0,
9
+ "in_channels": 4,
10
+ "norm_elementwise_affine": false,
11
+ "norm_num_groups": 32,
12
+ "norm_type": "ada_norm_zero",
13
+ "num_attention_heads": 16,
14
+ "num_embeds_ada_norm": 1000,
15
+ "num_layers": 28,
16
+ "num_vector_embeds": null,
17
+ "only_cross_attention": false,
18
+ "out_channels": 8,
19
+ "patch_size": 2,
20
+ "sample_size": 32,
21
+ "upcast_attention": false,
22
+ "use_linear_projection": false
23
+ }
var/D3HR/DiT-XL/transformer/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e592d64df5a579691e65d2b245641a00bb070b652e2c5ca775cce20a729ce9d9
3
+ size 2999533581
var/D3HR/DiT-XL/vae/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.12.0.dev0",
4
+ "_name_or_path": "stabilityai/sd-vae-ft-ema",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "in_channels": 3,
19
+ "latent_channels": 4,
20
+ "layers_per_block": 2,
21
+ "norm_num_groups": 32,
22
+ "out_channels": 3,
23
+ "sample_size": 256,
24
+ "up_block_types": [
25
+ "UpDecoderBlock2D",
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D",
28
+ "UpDecoderBlock2D"
29
+ ]
30
+ }
var/D3HR/DiT-XL/vae/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ec99ed96663e2418dba665762930f9eae8884e6b0a223fd53507931e8446eba
3
+ size 334711857
var/D3HR/README.md ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taming Diffusion for Dataset Distillation with High Representativeness (ICML 2025)
2
+
3
+ This repository is the official implementation of the paper:
4
+
5
+ [**Taming Diffusion for Dataset Distillation with High Representativeness**](https://www.arxiv.org/pdf/2505.18399)
6
+ [*Lin Zhao*](https://lin-zhao-resolve.github.io/),
7
+ [*Yushu Wu*](https://wuyushuwys.github.io/),
8
+ [*Xinru Jiang*](https://oshikaka.github.io/),
9
+ [*Jianyang Gu*](https://vimar-gu.github.io/),
10
+ [*Yanzhi Wang*](https://coe.northeastern.edu/people/wang-yanzhi/),
11
+ [*Xiaolin Xu*](https://www.xiaolinxu.com/),
12
+ [*Pu Zhao*](https://puzhao.info/),
13
+ [*Xue Lin*](https://coe.northeastern.edu/people/lin-xue/),
14
+ ICML, 2025.
15
+
16
+ <div align=center>
17
+ <img width=85% src="./imgs/framework.jpg"/>
18
+ </div>
19
+
20
+ ## Usage
21
+
22
+ 1. [Distilled Datasets](#distilled-datasets)
23
+ 2. [Setup](#setup)
24
+ 3. [Step1: DDIM inversion and distribution matching](#step1-ddim-inversion-and-distribution-matching)
25
+ 4. [Step2: Group sampling](#step2-group-sampling)
26
+ 6. [Evaluation](#evaluation)
27
+
28
+
29
+
30
+ ## Distilled Datasets
31
+ We provide distilled datasets with different IPCs generated by our method on Huggingface🤗! [*Imagenet-1K*](https://www.image-net.org/), [*Tiny-Imagenet*](https://www.kaggle.com/c/tiny-imagenet), [*CIFAR10*](https://www.cs.toronto.edu/~kriz/cifar.html), [*CIFAR100*](https://www.cs.toronto.edu/~kriz/cifar.html) datasets for users to use directly.
32
+
33
+ 🔥Distilled datasets for Imagenet-1K: [10IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/imagenet1k_10ipc), [50IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/imagenet1k_50ipc)
34
+
35
+ 🔥Distilled datasets for Tiny-Imagenet: [10IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/tinyimagenet_10ipc), [50IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/tinyimagenet_50ipc)
36
+
37
+ 🔥Distilled datasets for CIFAR10: [10IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/cifar10_10ipc), [50IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/cifar10_50ipc)
38
+
39
+ 🔥Distilled datasets for CIFAR100: [10IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/cifar100_10ipc), [50IPC](https://huggingface.co/datasets/lin-zhao-resoLve/D3HR/tree/main/cifar100_50ipc)
40
+
41
+ Besides, if you want to use the D3HR to generate distilled datasets by yourself, run the following steps:
42
+
43
+ ## Setup
44
+
45
+ To install the required dependencies, use the following commands:
46
+
47
+ ```bash
48
+ conda create -n D3HR python=3.10
49
+ conda activate D3HR
50
+ cd D3HR
51
+ pip install -e .
52
+ ```
53
+
54
+ ## Step1: DDIM inversion and distribution matching
55
+
56
+ ### (1) load pretrained model
57
+ For Imagenet-1K dataset, you can just use the pretrained [DiT](https://github.com/facebookresearch/DiT) model in huggingface:
58
+ ```bash
59
+ huggingface-cli download facebook/DiT-XL-2-256 --local-dir <your_local_path>
60
+ ```
61
+ For other datasets, you must first fine-tune the pretrained DiT model on the dataset ([github repo](https://github.com/facebookresearch/DiT)), then continue.
62
+
63
+ ### (2) perform DDIM inversion and distribution matching to obtain the statistic information
64
+ ```bash
65
+ sh generation/dit_inversion_save_statistic.sh
66
+ ```
67
+ Note: By default, we store the results at 15 timesteps (23 < t < 39) to support the experiments in Section 6.2.
68
+
69
+ ## Step2: Group sampling
70
+ ```bash
71
+ sh generation/group_sampling.sh
72
+ ```
73
+
74
+
75
+ ## Evaluation
76
+ ```bash
77
+ sh validation/validate.sh
78
+ ```
79
+ Note: The .sh script includes several configuration options—select the one that best fits your needs.
80
+
81
+ ## Acknowledgement
82
+ This project is mainly developed based on:
83
+ [DiT](https://github.com/facebookresearch/DiT)
84
+
85
+
86
+ ## Contact
87
+ If you have any questions, please contact zhao.lin1@northeastern.edu.
88
+
89
+ ## Citation
90
+ If you find our work useful, please cite:
91
+
92
+ ```BiBTeX
93
+ @inproceedings{zhaotaming,
94
+ title={Taming Diffusion for Dataset Distillation with High Representativeness},
95
+ author={Zhao, Lin and Wu, Yushu and Jiang, Xinru and Gu, Jianyang and Wang, Yanzhi and Xu, Xiaolin and Zhao, Pu and Lin, Xue},
96
+ booktitle={Forty-second International Conference on Machine Learning}
97
+ }
98
+ ```
var/D3HR/ds_inf/imagenet1k_train.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fe7183ea2495a3d15bf50036db1047ea387bf397c8c6f1a0bcc30f42df957ce
3
+ size 64470500
var/D3HR/ds_inf/imagenet_1k_mapping.json ADDED
The diff for this file is too large to render. See raw diff
 
var/D3HR/ds_inf/tiny-imagenet-mapping.txt ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ n02124075 285
2
+ n04067472 758
3
+ n04540053 890
4
+ n04099969 765
5
+ n07749582 951
6
+ n01641577 30
7
+ n02802426 430
8
+ n09246464 972
9
+ n07920052 967
10
+ n03970156 731
11
+ n03891332 704
12
+ n02106662 235
13
+ n03201208 532
14
+ n02279972 323
15
+ n02132136 294
16
+ n04146614 779
17
+ n07873807 963
18
+ n02364673 338
19
+ n04507155 879
20
+ n03854065 687
21
+ n03838899 683
22
+ n03733131 645
23
+ n01443537 1
24
+ n07875152 964
25
+ n03544143 604
26
+ n09428293 978
27
+ n03085013 508
28
+ n02437312 354
29
+ n07614500 928
30
+ n03804744 677
31
+ n04265275 811
32
+ n02963159 474
33
+ n02486410 372
34
+ n01944390 113
35
+ n09256479 973
36
+ n02058221 146
37
+ n04275548 815
38
+ n02321529 329
39
+ n02769748 414
40
+ n02099712 208
41
+ n07695742 932
42
+ n02056570 145
43
+ n02281406 325
44
+ n01774750 76
45
+ n02509815 387
46
+ n03983396 737
47
+ n07753592 954
48
+ n04254777 806
49
+ n02233338 314
50
+ n04008634 744
51
+ n02823428 440
52
+ n02236044 315
53
+ n03393912 565
54
+ n07583066 924
55
+ n04074963 761
56
+ n01629819 25
57
+ n09332890 975
58
+ n02481823 367
59
+ n03902125 707
60
+ n03404251 568
61
+ n09193705 970
62
+ n03637318 619
63
+ n04456115 862
64
+ n02666196 398
65
+ n03796401 675
66
+ n02795169 427
67
+ n02123045 281
68
+ n01855672 99
69
+ n01882714 105
70
+ n02917067 466
71
+ n02988304 485
72
+ n04398044 849
73
+ n02843684 448
74
+ n02423022 353
75
+ n02669723 400
76
+ n04465501 866
77
+ n02165456 301
78
+ n03770439 655
79
+ n02099601 207
80
+ n04486054 873
81
+ n02950826 471
82
+ n03814639 678
83
+ n04259630 808
84
+ n03424325 570
85
+ n02948072 470
86
+ n03179701 526
87
+ n03400231 567
88
+ n02206856 309
89
+ n03160309 525
90
+ n01984695 123
91
+ n03977966 734
92
+ n03584254 605
93
+ n04023962 747
94
+ n02814860 437
95
+ n01910747 107
96
+ n04596742 909
97
+ n03992509 739
98
+ n04133789 774
99
+ n03937543 720
100
+ n02927161 467
101
+ n01945685 114
102
+ n02395406 341
103
+ n02125311 286
104
+ n03126707 517
105
+ n04532106 887
106
+ n02268443 319
107
+ n02977058 480
108
+ n07734744 947
109
+ n03599486 612
110
+ n04562935 900
111
+ n03014705 492
112
+ n04251144 801
113
+ n04356056 837
114
+ n02190166 308
115
+ n03670208 627
116
+ n02002724 128
117
+ n02074367 149
118
+ n04285008 817
119
+ n04560804 899
120
+ n04366367 839
121
+ n02403003 345
122
+ n07615774 929
123
+ n04501370 877
124
+ n03026506 496
125
+ n02906734 462
126
+ n01770393 71
127
+ n04597913 910
128
+ n03930313 716
129
+ n04118538 768
130
+ n04179913 786
131
+ n04311004 821
132
+ n02123394 283
133
+ n04070727 760
134
+ n02793495 425
135
+ n02730930 411
136
+ n02094433 187
137
+ n04371430 842
138
+ n04328186 826
139
+ n03649909 621
140
+ n04417672 853
141
+ n03388043 562
142
+ n01774384 75
143
+ n02837789 445
144
+ n07579787 923
145
+ n04399382 850
146
+ n02791270 424
147
+ n03089624 509
148
+ n02814533 436
149
+ n04149813 781
150
+ n07747607 950
151
+ n03355925 557
152
+ n01983481 122
153
+ n04487081 874
154
+ n03250847 542
155
+ n03255030 543
156
+ n02892201 458
157
+ n02883205 457
158
+ n03100240 511
159
+ n02415577 349
160
+ n02480495 365
161
+ n01698640 50
162
+ n01784675 79
163
+ n04376876 845
164
+ n03444034 573
165
+ n01917289 109
166
+ n01950731 115
167
+ n03042490 500
168
+ n07711569 935
169
+ n04532670 888
170
+ n03763968 652
171
+ n07768694 957
172
+ n02999410 488
173
+ n03617480 614
174
+ n06596364 917
175
+ n01768244 69
176
+ n02410509 347
177
+ n03976657 733
178
+ n01742172 61
179
+ n03980874 735
180
+ n02808440 435
181
+ n02226429 311
182
+ n02231487 313
183
+ n02085620 151
184
+ n01644900 32
185
+ n02129165 291
186
+ n02699494 406
187
+ n03837869 682
188
+ n02815834 438
189
+ n07720875 945
190
+ n02788148 421
191
+ n02909870 463
192
+ n03706229 635
193
+ n07871810 962
194
+ n03447447 576
195
+ n02113799 267
196
+ n12267677 988
197
+ n03662601 625
198
+ n02841315 447
199
+ n07715103 938
200
+ n02504458 386
var/D3HR/generation/__init__.py ADDED
File without changes
var/D3HR/generation/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (136 Bytes). View file
 
var/D3HR/generation/__pycache__/dit_inversion_save_statistic.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
var/D3HR/generation/dit_inversion_save_statistic.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from PIL import Image
4
+ from io import BytesIO
5
+ from tqdm.auto import tqdm
6
+ from matplotlib import pyplot as plt
7
+ from torchvision import transforms as tfms
8
+ from diffusers import StableDiffusionPipeline, DDIMScheduler, DiTPipeline
9
+ import argparse
10
+ import os
11
+ from scipy import io
12
+ from diffusers import DiTPipeline
13
+ from diffusers.utils.torch_utils import randn_tensor
14
+ from concurrent.futures import ThreadPoolExecutor, as_completed
15
+ import ipdb
16
+ from torch.utils.data import Dataset
17
+ import torchvision.transforms as transforms
18
+ from itertools import islice
19
+ import json
20
+
21
+ # sample
22
+ @torch.no_grad()
23
+ def sample(
24
+ pipe,
25
+ class_labels,
26
+ start_step=0,
27
+ start_latents=None,
28
+ guidance_scale=4.0,
29
+ num_inference_steps=30,
30
+ do_classifier_free_guidance=True,
31
+ device=None,
32
+ ):
33
+
34
+ batch_size = len(class_labels)
35
+ latent_size = pipe.transformer.config.sample_size
36
+ latent_channels = pipe.transformer.config.in_channels
37
+ if start_latents == None:
38
+ latents = randn_tensor(
39
+ shape=(batch_size, latent_channels, latent_size, latent_size),
40
+ generator=generator,
41
+ device=pipe._execution_device,
42
+ dtype=pipe.transformer.dtype,
43
+ )
44
+ else:
45
+ latents = start_latents.clone()
46
+
47
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
48
+
49
+ class_labels = torch.tensor(class_labels, device=device).reshape(-1)
50
+ class_null = torch.tensor([1000] * batch_size, device=device)
51
+ class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels
52
+ class_labels_input = class_labels_input.to(device)
53
+
54
+ # set step values
55
+ pipe.scheduler.set_timesteps(num_inference_steps)
56
+
57
+
58
+ for i in tqdm(range(start_step, num_inference_steps)):
59
+
60
+ t = pipe.scheduler.timesteps[i]
61
+
62
+ if do_classifier_free_guidance:
63
+ half = latent_model_input[: len(latent_model_input) // 2]
64
+ latent_model_input = torch.cat([half, half], dim=0)
65
+
66
+ latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
67
+
68
+ timesteps = t
69
+ if not torch.is_tensor(timesteps):
70
+ is_mps = latent_model_input.device.type == "mps"
71
+ if isinstance(timesteps, float):
72
+ dtype = torch.float32 if is_mps else torch.float64
73
+ else:
74
+ dtype = torch.int32 if is_mps else torch.int64
75
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
76
+ elif len(timesteps.shape) == 0:
77
+ timesteps = timesteps[None].to(latent_model_input.device)
78
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
79
+ timesteps = timesteps.expand(latent_model_input.shape[0])
80
+ latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
81
+ # predict noise model_output
82
+ noise_pred = pipe.transformer(
83
+ latent_model_input, timestep=timesteps, class_labels=class_labels_input
84
+ ).sample
85
+
86
+ # Perform guidance
87
+ if do_classifier_free_guidance:
88
+ # perform guidance
89
+ # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
90
+ # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
91
+
92
+ eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]
93
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
94
+
95
+ half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
96
+ eps = torch.cat([half_eps, half_eps], dim=0)
97
+
98
+ noise_pred = torch.cat([eps, rest], dim=1)
99
+
100
+ # learned sigma
101
+ if pipe.transformer.config.out_channels // 2 == latent_channels:
102
+
103
+ model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
104
+ else:
105
+ model_output = noise_pred
106
+
107
+ # calculate ddim:
108
+ prev_t = max(1, t.item() - (1000 // num_inference_steps)) # t-1
109
+ alpha_t = pipe.scheduler.alphas_cumprod[t.item()]
110
+ alpha_t_prev = pipe.scheduler.alphas_cumprod[prev_t]
111
+ predicted_x0 = (latent_model_input - (1 - alpha_t).sqrt() * model_output) / alpha_t.sqrt()
112
+ direction_pointing_to_xt = (1 - alpha_t_prev).sqrt() * model_output
113
+ latent_model_input = alpha_t_prev.sqrt() * predicted_x0 + direction_pointing_to_xt
114
+ # latent_model_input = pipe.scheduler.step(model_output, t, latent_model_input).prev_sample
115
+
116
+ if guidance_scale > 1:
117
+ latents, _ = latent_model_input.chunk(2, dim=0)
118
+ else:
119
+ latents = latent_model_input
120
+
121
+ latents = 1 / pipe.vae.config.scaling_factor * latents
122
+ samples = pipe.vae.decode(latents).sample
123
+
124
+ samples = (samples / 2 + 0.5).clamp(0, 1)
125
+
126
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
127
+ samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
128
+
129
+ return samples
130
+
131
+
132
+
133
+
134
+
135
+
136
## Inversion
@torch.no_grad()
def invert(
    pipe,
    start_latents,
    class_labels,
    guidance_scale=4.0,
    num_inference_steps=80,
    do_classifier_free_guidance=True,
    device=None,
):
    """Run DDIM inversion on `start_latents` and return intermediate latents.

    Walks the scheduler timesteps in reverse (clean -> noisy) applying the
    inverted DDIM update, and collects the latents produced at iterations
    24..38 (see the `i>23 and i<39` check below), stacked along dim 0.

    Args:
        pipe: diffusers DiTPipeline (class-conditional transformer + DDIM scheduler).
        start_latents: VAE-encoded, scaled latents to invert.
        class_labels: per-sample class indices (len gives the batch size).
        guidance_scale: classifier-free guidance weight; >1 enables the
            null-class branch for the labels.
        num_inference_steps: number of scheduler steps.
        do_classifier_free_guidance: duplicate the latent batch for cond/uncond.
        device: device used for the label tensors.

    Returns:
        Tensor of shape (num_saved_steps, batch, C, H, W) of saved
        intermediate inverted latents.
    """

    batch_size = len(class_labels)
    latent_size = pipe.transformer.config.sample_size
    latent_channels = pipe.transformer.config.in_channels

    latents = start_latents.clone()

    # Duplicate the batch so conditional/unconditional branches share one forward pass.
    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

    # 1000 is DiT's "null" class id used for the unconditional branch.
    class_labels = torch.tensor(class_labels, device=device).reshape(-1)
    class_null = torch.tensor([1000] * batch_size, device=device)
    class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels
    class_labels_input = class_labels_input.to(device)

    # set step values
    pipe.scheduler.set_timesteps(num_inference_steps)
    intermediate_latents = []

    # Reversed timesteps <<<<<<<<<<<<<<<<<<<<
    # (torch tensors implement __reversed__, so this yields an indexable flipped tensor)
    timesteps_all = reversed(pipe.scheduler.timesteps)

    for i in tqdm(range(1, num_inference_steps), total=num_inference_steps - 1):

        # We'll skip the final iterations (the last 11 steps are not inverted)
        if i >= num_inference_steps - 1 -10:
            continue

        t = timesteps_all[i]

        if do_classifier_free_guidance:
            # Re-duplicate the first half so cond/uncond halves stay identical.
            half = latent_model_input[: len(latent_model_input) // 2]
            latent_model_input = torch.cat([half, half], dim=0)

        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        timesteps = t
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = latent_model_input.device.type == "mps"
            if isinstance(timesteps, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(latent_model_input.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(latent_model_input.shape[0])
        # NOTE(review): scale_model_input is called a second time here (also above) —
        # a no-op for DDIM, but confirm if another scheduler is ever used.
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
        # predict noise model_output
        noise_pred = pipe.transformer(
            latent_model_input, timestep=timesteps, class_labels=class_labels_input
        ).sample

        # Perform guidance
        if do_classifier_free_guidance:
            # perform guidance
            # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # DiT predicts [eps, sigma]; guidance is applied to eps only.
            eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)

            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)

            noise_pred = torch.cat([eps, rest], dim=1)

        # learned sigma
        if pipe.transformer.config.out_channels // 2 == latent_channels:

            model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
        else:
            model_output = noise_pred

        current_t = max(0, t.item() - (1000 // num_inference_steps))  # t
        next_t = t  # min(999, t.item() + (1000//num_inference_steps)) # t+1
        alpha_t = pipe.scheduler.alphas_cumprod[current_t]
        alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]

        # Inverted update step (re-arranging the update step to get x(t) (new latents) as a function of x(t-1) (current latents)
        latent_model_input = (latent_model_input - (1 - alpha_t).sqrt() * model_output) * (alpha_t_next.sqrt() / alpha_t.sqrt()) + (
            1 - alpha_t_next
        ).sqrt() * model_output

        if guidance_scale > 1:
            # Keep only the conditional half for saving.
            latents_out, _ = latent_model_input.chunk(2, dim=0)
        else:
            latents_out = latent_model_input

        # Store i=[3, 8, 13, 18, 23, 28, 33, 38, 43, 48]
        # if (i+2)%5 == 0:
        if i>23 and i<39:
            intermediate_latents.append(latents_out)
    return torch.stack(intermediate_latents, dim=0)

    # return torch.cat(intermediate_latents)
245
+
246
+
247
def parse_args(argv=None):
    """Parse command-line arguments for the DDIM-inversion statistic script.

    Fixed: the argparse description was copy-pasted from an unrelated
    InstructPix2Pix training script. Generalized: accepts an optional
    `argv` list (defaults to sys.argv[1:]) so the parser is testable;
    existing callers (no arguments) are unaffected.

    Args:
        argv: optional list of argument strings; None means sys.argv[1:].

    Returns:
        argparse.Namespace with save_dir, mapping_file, txt_file,
        pretrained_path, batch_size, num_workers, start, end, gpu.
    """
    parser = argparse.ArgumentParser(
        description="Compute and save per-class DDIM-inversion latent statistics."
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="/scratch/zhao.lin1/ddim_inversion_statistic",
        help="statistic save path",
    )
    parser.add_argument(
        "--mapping_file",
        type=str,
        default="ds_inf/imagenet_1k_mapping.json",
    )
    parser.add_argument("--txt_file", default='ds_inf/imagenet1k_train.txt', type=str)
    parser.add_argument("--pretrained_path", default='/scratch/zhao.lin1/DiT-XL-2-256', type=str)
    parser.add_argument(
        "--batch_size",
        type=int,
        default=200,
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=24,
    )
    # start/end select the slice of (sorted) class ids handled by this process.
    parser.add_argument(
        "--start",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--end",
        type=int,
        default=25,
    )
    parser.add_argument(
        "--gpu",
        type=int,
        default=1,
    )
    args = parser.parse_args(argv)

    return args
292
+
293
def view_latents(pipe=None, inverted_latents=None):
    """Decode the last latent in `inverted_latents` and return it as a PIL image.

    Fixed: the original built the PIL image but discarded it (there was no
    return statement), so calling the function had no observable effect.

    Args:
        pipe: a diffusers pipeline exposing decode_latents() and numpy_to_pil().
        inverted_latents: sequence of latents; only the last one is decoded.

    Returns:
        The decoded PIL.Image.
    """
    with torch.no_grad():
        im = pipe.decode_latents(inverted_latents[-1].unsqueeze(0))
    return pipe.numpy_to_pil(im)[0]
297
+
298
+
299
def collate_fn(batch):
    """Collate dataset samples into one batch dict, dropping failed (None) items.

    Returns None when every sample in the batch is None; otherwise a dict with
    stacked 'images', tensor 'labels'/'idx', and a list of 'paths'.
    """
    valid = [sample for sample in batch if sample is not None]
    if not valid:
        return None

    return {
        'images': torch.stack([sample['images'] for sample in valid]),
        'labels': torch.tensor([sample['labels'] for sample in valid]),
        'idx': torch.tensor([sample['idx'] for sample in valid]),
        'paths': [sample['paths'] for sample in valid],
    }
316
+
317
+
318
def save_latent(latent, save_path):
    """Persist a latent tensor to `save_path` via torch.save."""
    torch.save(latent, save_path)
320
+
321
+
322
+
323
class ImageNetDataset(Dataset):
    """Dataset over the training images of a single ImageNet class.

    Image paths are read from a txt file (one path per line) and filtered to
    those whose parent directory equals `class_dir` (a wnid such as
    'n01440764'). Labels come from the wnid -> index mapping file.
    """
    def __init__(self, txt_file='', mapping_file=None, class_dir=None):
        # image paths of the selected class / their integer class indices
        self.images = []
        self.img_labels = []
        self.class_dir = class_dir
        self.transform = self.get_transforms()

        # Load class mapping and json file
        self.wnid_to_index = load_mapping(mapping_file)
        self._load_from_txt(txt_file)


    def _load_from_txt(self, txt_file):
        """Read image paths from `txt_file`, keeping only the requested class."""
        with open(txt_file, "r") as file:
            image_paths = file.readlines()
            # Keep paths whose parent directory matches this dataset's wnid.
            image_paths = [path.strip() for path in image_paths if path.split('/')[-2]==self.class_dir]
            for path in image_paths:
                self.images.append(path)
                class_index = self.wnid_to_index[path.split('/')[-2]]
                self.img_labels.append(class_index)


    def get_transforms(self):
        """Return the eval transform: resize + center-crop to 256, to tensor.

        Normalization is intentionally disabled (commented out): images are
        rescaled to [-1, 1] later, before VAE encoding.
        """
        # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
        # std=[0.229, 0.224, 0.225])

        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            # normalize
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return {'images', 'paths', 'labels', 'idx'} for one sample.

        Unreadable files are replaced by a black 256x256 image rather than
        raising, so the DataLoader keeps going.
        """
        img_path = self.images[idx]
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a black image in case of error
            image = Image.new('RGB', (256, 256))

        img_label = self.img_labels[idx]

        if self.transform:
            image = self.transform(image)

        sample = {
            'images': image,
            'paths': img_path,
            'labels': img_label,
            'idx': idx
        }

        return sample
381
+
382
+
383
def load_mapping(mapping_file):
    """Load a wnid -> class-index mapping.

    Two formats are supported, chosen by the file path:
      * Tiny-ImageNet text files (path contains "tiny"): one wnid per line
        (e.g. 'n01443537 ...'), mapped to its 0-based line number.
      * JSON files: values are dicts with "wnid" and "index" keys.

    Fixed: the original called json.load(file) unconditionally before the
    branch, which exhausted the file handle (so the tiny-file line loop never
    iterated) and raised on non-JSON tiny-imagenet txt files.

    Args:
        mapping_file: path to the mapping file.

    Returns:
        dict mapping wnid strings to integer class indices.
    """
    new_mapping = {}
    with open(mapping_file, 'r') as file:
        if "tiny" in mapping_file:
            for index, line in enumerate(file):
                # Extract wnid (eg. n01443537) from each line; class index is the line number.
                key = line.split()[0]
                new_mapping[key] = index
        else:
            data = json.load(file)
            new_mapping = {item["wnid"]: item["index"] for item in data.values()}
    return new_mapping
395
+
396
def main():
    """Compute per-class DDIM-inversion latent statistics for ImageNet.

    For each class in the [start, end) slice of sorted wnids: encode the
    class's training images with the pipeline VAE, run DDIM inversion, and
    save mean/variance of the flattened intermediate latents to
    <save_dir>/<wnid>.pt.

    Fixed: torch.cuda.set_device(args.gpu) was called twice; the duplicate
    call is removed.
    """
    args = parse_args()
    # NOTE(review): set_device assumes a CUDA device exists; the "cpu"
    # fallback below only affects tensor placement.
    torch.cuda.set_device(args.gpu)
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    wnid_to_index = load_mapping(args.mapping_file)
    class_dirs = sorted(list(wnid_to_index.keys()))[args.start:args.end]

    pipe = DiTPipeline.from_pretrained(args.pretrained_path, torch_dtype=torch.float16)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    os.makedirs(args.save_dir, exist_ok=True)

    pipe = pipe.to(device)

    for class_dir in tqdm(class_dirs):
        imgnet1k_dataset = ImageNetDataset(args.txt_file, args.mapping_file, class_dir)
        trainloader = torch.utils.data.DataLoader(
            imgnet1k_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_workers, drop_last=False, collate_fn=collate_fn)
        latents = []
        for sample in tqdm(trainloader):
            with torch.no_grad():
                images = sample['images'].to(device)
                # VAE expects inputs in [-1, 1]; 0.18215 is the SD VAE latent scaling factor.
                latent = pipe.vae.encode(images.to(device, dtype=torch.float16) * 2 - 1)
                ls = 0.18215 * latent.latent_dist.sample()

            inverted_latents = invert(pipe, start_latents=ls, class_labels=sample['labels'],
                                      num_inference_steps=50, device=device).cpu()
            # (steps, batch, C, H, W) -> (batch, steps, C*H*W)
            latents.append(torch.flatten(inverted_latents.permute(1, 0, 2, 3, 4), start_dim=2))

        latents = torch.cat(latents, dim=0).cpu()
        mean = latents.mean(dim=0)
        variance = latents.var(dim=0)
        torch.save({"mean": mean, "variance": variance}, os.path.join(args.save_dir, class_dir + '.pt'))
var/D3HR/generation/dit_inversion_save_statistic.sh ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch per-class DDIM-inversion statistic computation in parallel,
# assigning one contiguous range of class ids to each GPU.

save_dir='/home/v-qichen3/debug/var/D3HR/ddim_inversion_statistic'
pretrained_path='/home/v-qichen3/debug/var/D3HR/DiT-XL'

# the range of class ids
# To improve efficiency, we distribute the generation of different classes across separate GPUs. You can change it to your own setting.
n=0
declare -a gpus=(0 1)
declare -a starts=($n $(($n+100)))
declare -a ends=($(($n+100)) $(($n+200)))

for i in ${!gpus[@]}; do
    gpu=${gpus[$i]}
    start=${starts[$i]}
    end=${ends[$i]}

    echo "Running on GPU $gpu with start=$start and end=$end"
    # Background each worker so all GPUs run concurrently.
    python generation/dit_inversion_save_statistic.py --start $start --end $end --gpu $gpu --save_dir $save_dir --pretrained_path $pretrained_path &
done

# waiting for all tasks
wait
echo "All tasks completed."
var/D3HR/generation/group_sampling.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import sys
4
+ sys.path.append('/home/zhao.lin1/D3HR')
5
+ from generation.dit_inversion_save_statistic import sample
6
+ from diffusers import DiTPipeline, DDIMScheduler
7
+ import json
8
+ from concurrent.futures import ThreadPoolExecutor, as_completed
9
+ import argparse
10
+ from tqdm import tqdm
11
+ import ipdb
12
+ from PIL import Image
13
+
14
def save_average_latents(input_dir, save_dir):
    """Average the per-image latents of each class and save one file per class.

    Expects input_dir/<class_name>/*.pt (one latent tensor per file) and
    writes save_dir/<class_name>_average_latent.pth.

    Fixed: (1) save_dir is now created if missing (torch.save would otherwise
    fail); (2) a class folder with no .pt files no longer crashes on
    torch.stack([]) — it is skipped.

    Args:
        input_dir: root directory of per-class latent folders.
        save_dir: output directory for the averaged latents.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Iterate over each class folder in input_dir
    for class_name in sorted(os.listdir(input_dir)):
        class_dir = os.path.join(input_dir, class_name)
        if not os.path.isdir(class_dir):
            continue

        latents = [
            torch.load(os.path.join(class_dir, file_name))
            for file_name in sorted(os.listdir(class_dir))
            if file_name.endswith('.pt')
        ]
        if not latents:
            # No latents for this class; skip instead of crashing on stack([]).
            continue

        average_latent = torch.mean(torch.stack(latents), dim=0)

        save_file_path = os.path.join(save_dir, f'{class_name}_average_latent.pth')
        torch.save(average_latent, save_file_path)
        print(f"Saved average latent for class '{class_name}' to {save_file_path}")
    return
39
+
40
def load_mapping(mapping_file):
    """Load a wnid -> class-index mapping (duplicate of the helper in
    dit_inversion_save_statistic.py; kept consistent with that fix).

    Two formats, chosen by the file path:
      * Tiny-ImageNet text files (path contains "tiny"): one wnid per line,
        mapped to its 0-based line number.
      * JSON files: values are dicts with "wnid" and "index" keys.

    Fixed: json.load(file) used to run before the branch, exhausting the file
    handle (the tiny-file line loop never iterated) and raising on non-JSON
    tiny-imagenet txt files.
    """
    new_mapping = {}
    with open(mapping_file, 'r') as file:
        if "tiny" in mapping_file:
            for index, line in enumerate(file):
                # Extract the wnid starting with 'n' from each line; the class
                # index is the 0-based line number.
                key = line.split()[0]
                new_mapping[key] = index
        else:
            data = json.load(file)
            new_mapping = {item["wnid"]: item["index"] for item in data.values()}
    return new_mapping
52
+
53
+
54
def process_class(class_dir, folder_path):
    """Compute element-wise mean/variance over all .pt latents of one class.

    Each latent is flattened from dim 1 before stacking.

    Returns:
        (class_dir, {"mean": ..., "variance": ...}), or (class_dir, None)
        when the class folder contains no .pt files.
    """
    class_path = os.path.join(folder_path, class_dir)
    flattened = [
        torch.flatten(
            torch.load(os.path.join(class_path, name), map_location=torch.device('cpu')),
            start_dim=1,
        )
        for name in os.listdir(class_path)
        if name.endswith(".pt")
    ]

    if not flattened:
        return class_dir, None

    stacked = torch.stack(flattened, dim=0)
    return class_dir, {"mean": stacked.mean(dim=0), "variance": stacked.var(dim=0)}
69
+
70
def process_p_sample(class_dir, folder_path):
    """Load every .pt latent of one class, flatten each, and stack them row-wise.

    Returns a 2-D tensor of shape (num_files, numel). Like the original,
    this raises if the folder contains no .pt files (vstack on an empty list).
    """
    class_path = os.path.join(folder_path, class_dir)
    rows = []
    for name in os.listdir(class_path):
        if not name.endswith(".pt"):
            continue
        loaded = torch.load(os.path.join(class_path, name), map_location=torch.device('cpu'))
        rows.append(loaded.flatten())

    return torch.vstack(rows)
81
+
82
+
83
def kl_divergence(selected_points, mean, std, device):
    """Summed per-dimension KL divergence between diagonal Gaussians.

    The candidate distribution P is estimated (mean/std per dimension) from
    `selected_points`; the target Q is N(mean, std**2). Each dimension
    contributes log(std_P/std_Q) + (std_Q**2 + (mu_Q - mu_P)**2) / (2*std_P**2) - 0.5.

    Args:
        selected_points: (n_samples, dim) candidate set.
        mean, std: per-dimension target parameters.
        device: unused; kept for interface compatibility.

    Returns:
        Scalar float KL value.
    """
    sample_mean = selected_points.mean(dim=0)
    sample_std = selected_points.var(dim=0).sqrt()

    delta = mean - sample_mean
    per_dim = (
        torch.log(sample_std / std)
        + (std ** 2 + delta ** 2) / (2 * sample_std ** 2)
        - 0.5
    )
    return per_dim.sum().item()
102
+
103
def kl_divergence_independent_batch(mean, std, samples, device):
    """Per-sample KL-style score of each row of `samples` against N(mean, std**2).

    Score per sample: sum((x - mean)**2) / (2*std**2) + log(std) - 0.5
    (the log(std/std) cross term is 0 and is omitted).

    Args:
        mean: target mean, broadcastable against samples' rows.
        std: target standard deviation (scalar tensor in practice).
        samples: (batch, dim) tensor.
        device: device all tensors are moved to before computing.

    Returns:
        1-D tensor of per-sample scores.
    """
    mean = mean.to(device)
    std = std.to(device)
    samples = samples.to(device)

    deviation = samples - mean
    quad_term = deviation.pow(2).sum(dim=1) / (2 * std ** 2)
    return quad_term + torch.log(std) - 0.5
113
+
114
def sinkhorn(A, B, epsilon=0.1, max_iter=1000, tol=1e-9):
    """
    Estimate the Wasserstein distance using the Sinkhorn algorithm, which supports distributions with different numbers of samples.
    A, B: The two input distributions (tensors of shape (n_a, dim) and (n_b, dim))
    epsilon: Sinkhorn regularization parameter
    max_iter: Maximum number of iterations
    tol: Convergence tolerance

    NOTE(review): the returned value computes transport_cost as sum(K * C)
    without rebuilding the transport plan from the converged duals u, v, so
    it is not the standard entropic-OT objective — verify before relying on
    absolute values; relative comparisons may still be usable.
    """

    # The amount of samples
    n_a, n_b = A.size(0), B.size(0)

    # Define weights and ensure normalization.
    weight_a = torch.ones(n_a, device=A.device) / n_a
    weight_b = torch.ones(n_b, device=B.device) / n_b

    # Compute the distance matrix
    C = torch.cdist(A, B, p=2) ** 2  # Squared Euclidean distance

    # Initialize dual variables
    u = torch.zeros(n_a, device=A.device)
    v = torch.zeros(n_b, device=B.device)

    K = torch.exp(-C / epsilon)  # Regularized distance matrix

    for _ in range(max_iter):
        # Update u and v in log-space, taking the weights into account
        # NOTE(review): the logsumexp argument uses -K / epsilon where classic
        # log-domain Sinkhorn uses -C / epsilon — confirm against a reference
        # implementation before changing.
        u_new = epsilon * torch.log(weight_a) - epsilon * torch.logsumexp(-K / epsilon + v.view(1, -1), dim=1)
        v_new = epsilon * torch.log(weight_b) - epsilon * torch.logsumexp(-K / epsilon + u_new.view(-1, 1), dim=0)

        # Check convergence
        if torch.max(torch.abs(u_new - u)) < tol and torch.max(torch.abs(v_new - v)) < tol:
            break

        u, v = u_new, v_new

    transport_cost = torch.sum(K * C)
    wasserstein_distance = transport_cost + epsilon * (torch.sum(u * weight_a) + torch.sum(v * weight_b))

    return wasserstein_distance
154
+
155
def skewness(tensor):
    """Sample skewness per column (over dim 0), with the small-sample bias
    correction factor n / ((n - 1) * (n - 2)). Requires n >= 3."""
    n = tensor.size(0)
    standardized = (tensor - torch.mean(tensor, dim=0)) / torch.std(tensor, dim=0)
    correction = n / ((n - 1) * (n - 2))
    return standardized.pow(3).sum(dim=0) * correction
161
+
162
def kurtosis(tensor):
    """Sample excess kurtosis per column (over dim 0) with bias correction.

    Uses the standard unbiased estimator; requires n >= 4 (the correction
    terms divide by (n - 3))."""
    n = tensor.size(0)
    standardized = (tensor - torch.mean(tensor, dim=0)) / torch.std(tensor, dim=0)
    lead = (n * (n + 1)) / ((n - 1) * (n - 2) * (n - 3))
    offset = (3 * (n - 1) ** 2) / ((n - 2) * (n - 3))
    return standardized.pow(4).sum(dim=0) * lead - offset
168
+
169
+
170
def skewness_batch(tensor):
    """Batched skewness over dim 1 of a 3-D tensor (trials, samples, features).

    NOTE(review): moments (mean/std) are taken over dim 1 (samples), but the
    bias-correction uses n = tensor.size(2) (the feature dimension) and the
    cube sum is over dim 2 — inconsistent with the per-column `skewness`
    above. Downstream code (evaluate_distribution*) depends on this exact
    behavior, so it is documented rather than changed; confirm intent before
    fixing.
    """
    mean = torch.mean(tensor, dim=1, keepdim=True)  # shape: [trials, 1, features]
    std = torch.std(tensor, dim=1, keepdim=True)  # shape: [trials, 1, features]

    n = tensor.size(2)  # feature dimension (e.g. 4096) — see NOTE above
    skew = torch.sum(((tensor - mean) / std) ** 3, dim=2) * (n / ((n - 1) * (n - 2)))  # shape: [trials, samples]

    return skew
178
+
179
+
180
def evaluate_distribution(samples, mean, std):
    """Score a candidate sample set against the target N(mean, std**2).

    Lower is better; the score combines the L2 gap of the sample mean, the
    L2 gap of the sample std, and (weighted by 10) the norm of the sample
    skewness (ideal skewness is 0 for a Gaussian).

    NOTE(review): skewness_batch expects a 3-D (trials, samples, features)
    input while mean/std here reduce over dim 0 — confirm callers pass a
    compatible shape. Also, sample_skew is already a tensor, so
    torch.tensor(sample_skew) makes a copy (and warns); left as-is to
    preserve behavior. The weight here is 10 vs 0.1 in
    evaluate_distribution_batch — verify which is intended.
    """
    sample_mean = torch.mean(samples, dim=0)
    mean_diff = torch.norm(sample_mean - mean)

    sample_std = torch.std(samples, dim=0)
    std_diff = torch.norm(sample_std - std)

    sample_skew = skewness_batch(samples)

    skew_diff = torch.norm(torch.tensor(sample_skew) - 0)  # Sample Skewness close to 0
    # kurt_diff = torch.norm(torch.tensor(sample_kurt) - 3) # Sample Kurtosis close to 3

    # Comprehensive evaluation: each component can be weighted as needed
    score = mean_diff + std_diff + 10*skew_diff
    return score
195
+
196
def select_algorithm(n_trials, n_samples, mean, std, device):
    """Draw `n_trials` Gaussian candidate sets of `n_samples` points each and
    return the set with the lowest evaluate_distribution score.

    Each trial's score is printed (matching the original's debug output).
    """
    best_sample = None
    best_score = float('inf')

    for _ in range(n_trials):
        candidate = torch.normal(
            mean.expand(n_samples, -1), std.expand(n_samples, -1)
        ).to(device)

        candidate_score = evaluate_distribution(candidate, mean, std)
        print(candidate_score)

        # Keep the candidate set with the best (lowest) score so far.
        if candidate_score < best_score:
            best_score = candidate_score
            best_sample = candidate

    return best_sample
211
+
212
+
213
def evaluate_distribution_batch(samples, mean, std):
    """Score each candidate set in a batch against the target N(mean, std**2).

    `samples` has shape (trials, n_samples, dim). Lower is better: per trial,
    ||sample_mean - mean|| + ||sample_std - std|| + 0.1 * ||skewness||
    (a Gaussian's skewness is 0, so the raw skewness norm is the penalty).

    Returns a 1-D tensor of per-trial scores.
    """
    mean_gap = torch.norm(torch.mean(samples, dim=1) - mean, dim=1)
    std_gap = torch.norm(torch.std(samples, dim=1) - std, dim=1)

    # Batched skewness; penalize its distance from the Gaussian ideal of 0.
    skew_gap = torch.norm(skewness_batch(samples), dim=1)

    # Comprehensive evaluation: each component can be weighted as needed
    return mean_gap + std_gap + 0.1 * skew_gap
228
+
229
def select_algorithm_batch(n_trials, n_samples, mean, std, device, seed):
    """Sample `n_trials` candidate sets at once and return the best and worst.

    Args:
        n_trials: batch size of candidate sets drawn in one shot.
        n_samples: points per candidate set.
        mean, std: per-dimension target Gaussian parameters.
        device: device the candidates are placed on.
        seed: if not None, torch.manual_seed(seed) for reproducibility.

    Returns:
        (best_sample, worst_sample, best_score, worst_score) — the candidate
        sets with the minimum and maximum evaluate_distribution_batch scores.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # One batched draw: shape (n_trials, n_samples, dim).
    candidates = torch.normal(
        mean.expand(n_trials, n_samples, -1),
        std.expand(n_trials, n_samples, -1),
    ).to(device)

    scores = evaluate_distribution_batch(candidates, mean, std)

    best_score, best_idx = torch.min(scores, dim=0)
    worst_score, worst_idx = torch.max(scores, dim=0)

    return candidates[best_idx], candidates[worst_idx], best_score, worst_score
242
+
243
+
244
+
245
def parse_args(argv=None):
    """Parse command-line arguments for the group-sampling script.

    Fixed: the argparse description was copy-pasted from an unrelated
    InstructPix2Pix training script. Generalized: accepts an optional
    `argv` list (defaults to sys.argv[1:]) so the parser is testable;
    existing callers (no arguments) are unaffected.

    Args:
        argv: optional list of argument strings; None means sys.argv[1:].

    Returns:
        argparse.Namespace with mapping_file, pretrained_path, save_dir,
        statistic_path, start, end, gpu, ipc, start_step, i_step, m.
    """
    parser = argparse.ArgumentParser(
        description="Group-sample latents from saved DDIM-inversion statistics and generate images."
    )
    parser.add_argument(
        "--mapping_file",
        type=str,
        default="ds_inf/imagenet_1k_mapping.json",
    )
    parser.add_argument("--pretrained_path", default='/scratch/zhao.lin1/DiT-XL-2-256', type=str)
    parser.add_argument("--save_dir", default='/scratch/zhao.lin1/distilled_images/', type=str)
    parser.add_argument("--statistic_path", default='/scratch/zhao.lin1/ddim_inversion_statistic', type=str)
    # start/end select the slice of (sorted) class ids handled by this process.
    parser.add_argument(
        "--start",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--end",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--gpu",
        type=int,
        default=0,
    )
    # images per class to synthesize
    parser.add_argument(
        "--ipc",
        type=int,
        default=20,
    )
    # denoising step the sampler resumes from
    parser.add_argument(
        "--start_step",
        type=int,
        default=18,
    )
    # index into the saved intermediate-statistics steps
    parser.add_argument(
        "--i_step",
        type=int,
        default=6,
    )
    # total sampling budget; main() runs m//10 batches of 10000 trials
    parser.add_argument(
        "--m",
        type=int,
        default=100000,
    )
    args = parser.parse_args(argv)

    return args
295
+
296
+
297
def main():
    """Group-sample distilled latents per class and decode them to images.

    For each class: load saved mean/variance of inverted latents, repeatedly
    draw batches of candidate sets, keep the set whose moments best match the
    target statistics, then decode each selected latent with the DiT pipeline
    and save 224x224 PNGs under save_dir/<wnid>/.
    """
    args = parse_args()

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    torch.cuda.set_device(device)

    # x5_step i=[3, 8, 13, 18, 23, 28, 33, 38, 43, 48] start_step = [45, 40, 35, 30, 25, 20, 15, 10, 5, 0]
    # 10-20_step i=[24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39] start_step = [24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9]

    wnid_to_index = load_mapping(args.mapping_file)
    class_dirs = sorted(list(wnid_to_index.keys()))[args.start:args.end]



    pipe = DiTPipeline.from_pretrained(args.pretrained_path, torch_dtype=torch.float16)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)

    for class_dir in tqdm(class_dirs):
        # Compute the best and worst samples
        statics = torch.load(os.path.join(args.statistic_path, class_dir+'.pt'))
        # Select the statistics saved at intermediate inversion step i_step.
        mean = statics['mean'][args.i_step].to(device)
        variance = statics['variance'][args.i_step].to(device)
        std = torch.sqrt(variance)

        latents_best = None
        latents_worst = None
        best_overall_score = float('inf')  # initialize as inf
        worst_overall_score = float('-inf')  # initialize as -inf

        # group sampling
        # NOTE(review): total candidate sets = (m // 10) * 10000; with the
        # default m=100000 that is 10**8 trials — confirm the intended budget.
        for i in range(args.m//10):
            seed = i * 12345  # distinct deterministic seed per batch
            best_sample, worst_sample, best_score, worst_score = select_algorithm_batch(10000, args.ipc, mean, std, device, seed)

            # Update best and worst samples
            if best_score < best_overall_score:
                best_overall_score = best_score
                latents_best = best_sample

            if worst_score > worst_overall_score:
                worst_overall_score = worst_score
                latents_worst = worst_sample

        # Output results
        print("Best overall score:", best_overall_score)
        print("Worst overall score:", worst_overall_score)


        # Reshape flat latents back to the DiT latent layout (C=4, 32x32).
        latents_best = latents_best.view(-1,4,32,32)
        # latents_worst = latents_worst.view(-1,4,32,32)


        # Generate images
        for k, latent in enumerate(latents_best):
            image = sample(
                pipe,
                class_labels=torch.tensor(wnid_to_index[class_dir]).unsqueeze(0),
                start_latents=latent.unsqueeze(0).to(torch.float16),
                start_step=args.start_step,
                num_inference_steps=50,
                device=device
            )
            os.makedirs(os.path.join(args.save_dir,class_dir.split('/')[-1]), exist_ok=True)
            pipe.numpy_to_pil(image)[0].resize((224, 224), Image.LANCZOS).save(os.path.join(args.save_dir,class_dir.split('/')[-1], str(k)+'.png'))
368
+
var/D3HR/generation/group_sampling.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch group-sampling image generation in parallel, assigning one
# contiguous range of class ids to each GPU.

# Define the range of classes
# To improve efficiency, we distribute the generation of different classes across separate GPUs. You can change it to your own setting.
n=0
declare -a gpus=(0 1)
declare -a starts=($n $(($n+100)))
declare -a ends=($(($n+100)) $(($n+200)))

for i in ${!gpus[@]}; do
    gpu=${gpus[$i]}
    start=${starts[$i]}
    end=${ends[$i]}

    echo "Running on GPU $gpu with start=$start and end=$end"
    # Background each worker so all GPUs run concurrently.
    python generation/group_sampling.py --start $start --end $end --gpu $gpu &
done

wait
echo "All tasks completed."
var/D3HR/imgs/framework.jpg ADDED

Git LFS Details

  • SHA256: 7b909c637b11a2aa1796378844de2b270ef68f0cf15c00580dda150444a5dac5
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
var/D3HR/imgs/framework.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:adefcf20eb3eb575bc93846882e857b3cc38bb57e415c0317b0a9f2c9114d27d
3
+ size 301019
var/D3HR/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.6.0
2
+ torchvision==0.21.0
3
+ transformers==4.46.2
4
+ diffusers==0.28.0
5
+ matplotlib==3.10.1
6
+ ipdb==0.13.13
7
+ scipy==1.15.2
8
+ huggingface-hub==0.30.2
9
+ accelerate==1.3.0
var/D3HR/validation/__pycache__/argument.cpython-310.pyc ADDED
Binary file (5.86 kB). View file
 
var/D3HR/validation/argument.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+
4
+
5
+ def str2bool(v):
6
+ """Cast string to boolean
7
+ """
8
+ if isinstance(v, bool):
9
+ return v
10
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
11
+ return True
12
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
13
+ return False
14
+ else:
15
+ raise argparse.ArgumentTypeError('Boolean value expected.')
16
+
17
+
18
+ parser = argparse.ArgumentParser("EEF")
19
+ parser.add_argument(
20
+ "--arch-name",
21
+ type=str,
22
+ default="resnet18",
23
+ help="arch name from pretrained torchvision models",
24
+ )
25
+ parser.add_argument(
26
+ "--subset",
27
+ type=str,
28
+ default="imagenet-1k",
29
+ )
30
+ parser.add_argument(
31
+ "--spec",
32
+ type=str,
33
+ default="none",
34
+ )
35
+ parser.add_argument(
36
+ "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
37
+ )
38
+ parser.add_argument(
39
+ "--data-dir",
40
+ nargs='+',
41
+ default=["../data/imagenet"],
42
+ help="path to imagenet dataset",
43
+ )
44
+ parser.add_argument(
45
+ "--nclass",
46
+ type=int,
47
+ default=10,
48
+ help="number of classes for synthesis or validation",
49
+ )
50
+ parser.add_argument(
51
+ "--ipc",
52
+ type=int,
53
+ default=10,
54
+ help="number of images per class for synthesis or validation",
55
+ )
56
+ parser.add_argument(
57
+ "--target-ipc",
58
+ type=int,
59
+ default=50,
60
+ help="number of images per class for synthesis or validation",
61
+ )
62
+ parser.add_argument(
63
+ "--phase",
64
+ type=int,
65
+ default=0,
66
+ )
67
+ parser.add_argument(
68
+ "--input-size",
69
+ default=224,
70
+ type=int,
71
+ metavar="S",
72
+ )
73
+ parser.add_argument(
74
+ "--save-size",
75
+ default=224,
76
+ type=int,
77
+ metavar="S",
78
+ )
79
+ parser.add_argument(
80
+ "--repeat",
81
+ default=1,
82
+ type=int,
83
+ help="Repeat times for the validation"
84
+ )
85
+ parser.add_argument(
86
+ "--factor",
87
+ default=2,
88
+ type=int,
89
+ )
90
+ parser.add_argument(
91
+ "--batch-size", default=0, type=int, metavar="N"
92
+ )
93
+ parser.add_argument(
94
+ "--accum-steps",
95
+ type=int,
96
+ default=1,
97
+ help="gradient accumulation steps for small gpu memory",
98
+ )
99
+ parser.add_argument(
100
+ "--mix-type",
101
+ default="cutmix",
102
+ type=str,
103
+ choices=["mixup", "cutmix", None],
104
+ help="mixup or cutmix or None",
105
+ )
106
+ parser.add_argument(
107
+ "--stud-name",
108
+ type=str,
109
+ default="resnet18",
110
+ help="arch name from torchvision models",
111
+ )
112
+ parser.add_argument(
113
+ "--workers",
114
+ default=24,
115
+ type=int,
116
+ metavar="N",
117
+ help="number of data loading workers (default: 4)",
118
+ )
119
+ parser.add_argument(
120
+ "--temperature",
121
+ type=float,
122
+ help="temperature for distillation loss",
123
+ )
124
+ parser.add_argument(
125
+ "--min-scale-crops", type=float, default=0.08, help="argument in RandomResizedCrop"
126
+ )
127
+ parser.add_argument(
128
+ "--max-scale-crops", type=float, default=1, help="argument in RandomResizedCrop"
129
+ )
130
+ parser.add_argument("--epochs", default=300, type=int)
131
+ parser.add_argument(
132
+ "--results-dir",
133
+ type=str,
134
+ default="results",
135
+ help="where to store synthetic data",
136
+ )
137
+ parser.add_argument(
138
+ "--seed", default=42, type=int, help="seed for initializing training. "
139
+ )
140
+ parser.add_argument(
141
+ "--mixup",
142
+ type=float,
143
+ default=0.8,
144
+ help="mixup alpha, mixup enabled if > 0. (default: 0.8)",
145
+ )
146
+ parser.add_argument(
147
+ "--cutmix",
148
+ type=float,
149
+ default=1.0,
150
+ help="cutmix alpha, cutmix enabled if > 0. (default: 1.0)",
151
+ )
152
+ parser.add_argument("--cos", default=True, help="cosine lr scheduler")
153
+ parser.add_argument("--verbose", type=str2bool, default=False)
154
+ parser.add_argument("--mapping_file", default="ds_inf/imagenet_1k_mapping.json", type=str)
155
+ parser.add_argument("--txt_file", default='/home/zhao.lin1/DD-DDIM-inversion/ds_inf/imagenet-1k/biggest_20%_ipc_for_all_1k.txt', type=str)
156
+ parser.add_argument("--val_txt_file", default='/home/zhao.lin1/CONCORD/val.txt', type=str)
157
+ # diffusion
158
+ parser.add_argument("--dit-model", default='DiT-XL/2')
159
+ parser.add_argument("--ckpt", type=str, default='pretrained_models/DiT-XL-2-256x256.pt',
160
+ help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).")
161
+ parser.add_argument("--dit-image-size", default=256, type=int)
162
+ parser.add_argument("--num-dit-classes", default=1000, type=int)
163
+ parser.add_argument("--diffusion-steps", default=1000, type=int)
164
+ parser.add_argument("--cfg-scale", type=float, default=4.0)
165
+
166
+ parser.add_argument("--vae-path", default='stabilityai/sd-vae-ft-ema')
167
+
168
+ # distillation
169
+ parser.add_argument("--save-path", default='./results/test')
170
+ parser.add_argument("--description-path", default='./misc/class_description.json')
171
+ parser.add_argument("--clip-alpha", type=float, default=10.0)
172
+ parser.add_argument("--cls-alpha", type=float, default=10.0)
173
+ parser.add_argument("--num-neg-samples", type=int, default=5)
174
+ parser.add_argument("--neg-policy", type=str, default="weighted")
175
+
176
+ # sgd
177
+ parser.add_argument("--sgd", default=False, action="store_true", help="sgd optimizer")
178
+ parser.add_argument(
179
+ "-lr",
180
+ "--learning-rate",
181
+ type=float,
182
+ default=0.1,
183
+ help="sgd init learning rate",
184
+ )
185
+ parser.add_argument("--momentum", type=float, default=0.9, help="sgd momentum")
186
+ parser.add_argument("--weight-decay", type=float, default=1e-4, help="sgd weight decay")
187
+
188
+ # adamw
189
+ parser.add_argument("--adamw-lr", type=float, default=0, help="adamw learning rate")
190
+ parser.add_argument(
191
+ "--adamw-weight-decay", type=float, default=0.01, help="adamw weight decay"
192
+ )
193
+ parser.add_argument(
194
+ "--exp-name",
195
+ type=str,
196
+ help="name of the experiment, subfolder under syn_data_path",
197
+ )
198
+ args = parser.parse_args()
199
+
200
+ # temperature
201
+ if args.mix_type == "mixup":
202
+ args.temperature = 4
203
+ elif args.mix_type == "cutmix":
204
+ args.temperature = 20
205
+
206
+ if args.subset == "imagenet_1k":
207
+ args.nclass = 1000
208
+ args.classes = range(args.nclass)
209
+ args.val_ipc = 50
210
+ args.input_size = 224
211
+
212
+ elif args.subset == "imagewoof":
213
+ args.nclass = 10
214
+ args.classes = range(args.nclass)
215
+ args.val_ipc = 50
216
+ args.input_size = 224
217
+ if args.ipc == 10:
218
+ args.epochs = 2000
219
+ elif args.ipc == 50:
220
+ args.epochs = 1500
221
+ else:
222
+ args.epochs = 1000
223
+
224
+
225
+ elif args.subset == "cifar10":
226
+ args.nclass = 10
227
+ args.classes = range(args.nclass)
228
+ args.val_ipc = 1000
229
+ args.input_size = 32
230
+ args.epochs = 1000
231
+
232
+ elif args.subset == "cifar100":
233
+ args.nclass = 100
234
+ args.classes = range(args.nclass)
235
+ args.val_ipc = 100
236
+ args.input_size = 32
237
+ args.epochs = 400
238
+
239
+ elif args.subset == "tinyimagenet":
240
+ args.nclass = 200
241
+ args.classes = range(args.nclass)
242
+ args.val_ipc = 50
243
+ args.input_size = 64
244
+ args.epochs = 300
245
+
246
+ # set up batch size
247
+ if args.batch_size == 0:
248
+ if args.ipc >= 50:
249
+ args.batch_size = 100
250
+ elif args.ipc >= 10:
251
+ args.batch_size = 50
252
+ elif args.ipc > 0:
253
+ args.batch_size = 15
254
+ elif args.ipc == -1:
255
+ args.batch_size = 100
256
+
257
+ if args.nclass == 10:
258
+ args.batch_size *= 1
259
+ if args.nclass == 100:
260
+ args.batch_size *= 2
261
+ # if args.nclass == 1000:
262
+ # args.batch_size *= 2
263
+
264
+ # reset batch size below ipc * nclass
265
+ if args.ipc != -1 and args.batch_size > args.ipc * args.nclass:
266
+ args.batch_size = int(args.ipc * args.nclass)
267
+
268
+ # reset batch size with accum_steps
269
+ if args.accum_steps != 1:
270
+ args.batch_size = int(args.batch_size / args.accum_steps)
271
+
272
+ # result dir for saving
273
+ args.exp_name = f"{args.spec}_{args.arch_name}_f{args.factor}_ipc{args.ipc}"
274
+ if not os.path.exists(f"./exp/{args.exp_name}"):
275
+ os.makedirs(f"./exp/{args.exp_name}")
276
+
277
+
278
+ # adamw learning rate
279
+ if args.stud_name == "vgg11":
280
+ args.adamw_lr = 0.0005
281
+ elif args.stud_name == "conv3":
282
+ args.adamw_lr = 0.001
283
+ elif args.stud_name == "conv4":
284
+ args.adamw_lr = 0.001
285
+ elif args.stud_name == "conv5":
286
+ args.adamw_lr = 0.001
287
+ elif args.stud_name == "conv6":
288
+ args.adamw_lr = 0.001
289
+ elif args.stud_name == "resnet18":
290
+ args.adamw_lr = 0.001
291
+ elif args.stud_name == "resnet18_modified":
292
+ args.adamw_lr = 0.001
293
+ elif args.stud_name == "efficientnet_b0":
294
+ args.adamw_lr = 0.002
295
+ elif args.stud_name == "mobilenet_v2":
296
+ args.adamw_lr = 0.0025
297
+ elif args.stud_name == "alexnet":
298
+ args.adamw_lr = 0.0001
299
+ elif args.stud_name == "resnet50":
300
+ args.adamw_lr = 0.001
301
+ elif args.stud_name == "resnet50_modified":
302
+ args.adamw_lr = 0.001
303
+ elif args.stud_name == "resnet101":
304
+ args.adamw_lr = 0.001
305
+ elif args.stud_name == "resnet101_modified":
306
+ args.adamw_lr = 0.001
307
+ elif args.stud_name == "vit_b_16":
308
+ args.adamw_lr = 0.0001
309
+ elif args.stud_name == "swin_v2_t":
310
+ args.adamw_lr = 0.0001
var/D3HR/validation/get_train_list.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
def load_mapping(mapping_file):
    """Load a wnid -> class-index mapping from either mapping-file format.

    Two formats are supported, selected by the filename:
      * files whose path contains "tiny": plain text, one entry per line,
        the first whitespace-separated token is the wnid; the line number
        (0-based) becomes the class index.
      * anything else: a JSON object whose values are dicts carrying
        "wnid" and "index" keys.

    Fix: the original called ``json.load(file)`` unconditionally, which
    raised on the plain-text tiny-imagenet file and, even for JSON input,
    exhausted the handle so the subsequent line iteration yielded nothing.
    The parse is now chosen *before* reading the file.

    Args:
        mapping_file: path to the mapping file.

    Returns:
        dict mapping wnid (str) to class index (int).
    """
    new_mapping = {}
    with open(mapping_file, 'r') as file:
        if "tiny" in mapping_file:
            # Plain-text format: one "wnid <rest>" entry per line.
            for index, line in enumerate(file):
                key = line.split()[0]
                new_mapping[key] = index
        else:
            data = json.load(file)
            new_mapping = {item["wnid"]: item["index"] for item in data.values()}
    return new_mapping
15
+
16
+
17
# Build a train-list file for a generated dataset: one image path per line,
# 20 images (IPC) per class, classes ordered by sorted wnid.
wnid_to_index = load_mapping("ds_inf/imagenet_1k_mapping.json")
class_dirs = sorted(list(wnid_to_index.keys()))
path_list = []
for class_dir in class_dirs:
    # Images inside each class directory are named 0.png ... 19.png.
    for i in range(20):
        path_list.append(os.path.join('/scratch/zhao.lin1/imagenet1k_256_4.0classfree_start_step_18_ddim_inversion_20_min_images_2/', class_dir, str(i)+'.png'))
# NOTE(review): dataset root and output path are machine-specific hard-coded
# paths — parameterize (e.g. argparse) before reusing on another host.
output_file = "/scratch/zhao.lin1/imagenet1k_256_4.0classfree_start_step_18_ddim_inversion_20_min_images_2/train.txt"
with open(output_file, "w") as file:
    for path in path_list:
        file.write(path + "\n")
var/D3HR/validation/models/__init__.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models as thmodels
4
+
5
+ from .convnet import ConvNet
6
+ from .resnet import resnet18, resnet50, resnet101, resnet152
7
+ from .mobilenet_v2 import mobilenetv2
8
+ # import timm
9
+
10
+
11
+
12
def load_model(model_name="resnet18", dataset="cifar10", spec='full', pretrained=True, input_size=224, classes=[]):
    """Build a validation model and optionally load pre-trained weights.

    Args:
        model_name: architecture key ("conv3".."conv6", "resnet18/50/101/152",
            the "*_modified" CIFAR-style ResNet variants, "mobilenet_v2",
            "efficientnet_b0", or any torchvision model name as fallback).
        dataset: dataset name; selects the checkpoint source when `pretrained`.
        spec: unused in this function; kept for call-site compatibility.
        pretrained: when True, load pre-trained weights for the model.
        input_size: spatial input size used to shape ConvNet feature maps.
        classes: class indices kept when pruning the classifier head.
            (NOTE: default is a shared mutable list — it is only read here,
            never mutated, so this is safe but worth knowing.)

    Returns:
        torch.nn.Module

    Raises:
        AttributeError: if `pretrained` is requested for an imagenet_1k
            architecture that has no entry in the pre-trained pool.
    """
    def get_model(model_name="resnet18"):
        # Instantiate the requested architecture with random weights.
        if "conv" in model_name:
            size = input_size
            # NOTE(review): nclass is hard-coded to 1000 (the head is pruned
            # afterwards via `classes`) — confirm intended for small datasets.
            nclass = 1000

            model = ConvNet(
                num_classes=nclass,
                net_norm="batch",
                net_act="relu",
                net_pooling="avgpooling",
                net_depth=int(model_name[-1]),
                net_width=128,
                channel=3,
                im_size=(size, size),
            )
        elif model_name == 'resnet18':
            model = resnet18(weights=None)
        elif model_name == 'resnet50':
            model = resnet50(weights=None)
        elif model_name == 'resnet101':
            model = resnet101(weights=None)
        elif model_name == 'resnet152':
            model = resnet152(weights=None)
        elif model_name == 'mobilenet_v2':
            model = mobilenetv2()
        elif model_name == 'efficientnet_b0':
            # Fix: the module-level `import timm` is commented out in this
            # file, so import lazily here — only this branch needs it.
            import timm
            model = timm.create_model('efficientnet_b0.ra_in1k', pretrained=False)
        elif model_name == "resnet18_modified":
            # CIFAR-style stem: 3x3 conv, stride 1, and no max-pool downsampling.
            model = thmodels.__dict__["resnet18"](pretrained=False)
            model.conv1 = nn.Conv2d(
                3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
            )
            model.maxpool = nn.Identity()
        elif model_name == "resnet50_modified":
            model = thmodels.__dict__["resnet50"](pretrained=False)
            model.conv1 = nn.Conv2d(
                3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
            )
            model.maxpool = nn.Identity()
        elif model_name == "resnet101_modified":
            model = thmodels.__dict__["resnet101"](pretrained=False)
            model.conv1 = nn.Conv2d(
                3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
            )
            model.maxpool = nn.Identity()
        else:
            # Fall back to any torchvision architecture of the same name.
            model = thmodels.__dict__[model_name](weights=None)

        return model

    def pruning_classifier(model=None, classes=[]):
        # Keep only the requested class rows of the final layer's weight/bias
        # (assumed to be the last two named parameters).
        try:
            model_named_parameters = [name for name, x in model.named_parameters()]
            for name, x in model.named_parameters():
                if (
                    name == model_named_parameters[-1]
                    or name == model_named_parameters[-2]
                ):
                    x.data = x[classes]
        except Exception:
            # Fix: narrowed from a bare `except:` so SystemExit /
            # KeyboardInterrupt are no longer swallowed.
            print("ERROR in changing the number of classes.")

        return model

    model = get_model(model_name)
    model = pruning_classifier(model, classes)

    if pretrained:
        if dataset == 'imagenet_1k':
            if model_name == "efficientnet_b0":
                import timm  # lazy import; see get_model above
                checkpoint = timm.create_model('efficientnet_b0.ra_in1k', pretrained=True).state_dict()
                model.load_state_dict(checkpoint)
            elif model_name == 'conv4':
                # map_location="cpu" so loading also works on CPU-only hosts.
                state_dict = torch.load('/home/linz/CONCORD/pretrained_models/imagenet-1k_conv4.pth', map_location="cpu")
                model.load_state_dict(state_dict['model'])
            elif model_name == 'resnet18':
                model = resnet18(weights='DEFAULT')
            elif model_name == 'mobilenet_v2':
                model.load_state_dict(torch.load('/home/zhao.lin1/CONCORD/pretrained_models/mobilenetv2_1.0-0c6065bc.pth', map_location="cpu"))
            else:
                raise AttributeError(f'{model_name} is not supported in the pre-trained pool')
        else:
            checkpoint = torch.load(
                f"pretrain_models/{dataset}_{model_name}.pth", map_location="cpu"
            )
            model.load_state_dict(checkpoint["model"])


    return model
102
+
103
+
104
+ # def load_model(model_name="resnet18", dataset="cifar10", pretrained=True, classes=[]):
105
+ # def get_model(model_name="resnet18"):
106
+ # if "conv" in model_name:
107
+ # if dataset in ["cifar10", "cifar100"]:
108
+ # size = 32
109
+ # elif dataset == "tinyimagenet":
110
+ # size = 64
111
+ # elif dataset in ["imagenet-nette", "imagenet-woof", "imagenet-100"]:
112
+ # size = 128
113
+ # else:
114
+ # size = 224
115
+
116
+ # nclass = len(classes)
117
+
118
+ # model = ConvNet(
119
+ # num_classes=nclass,
120
+ # net_norm="batch",
121
+ # net_act="relu",
122
+ # net_pooling="avgpooling",
123
+ # net_depth=int(model_name[-1]),
124
+ # net_width=128,
125
+ # channel=3,
126
+ # im_size=(size, size),
127
+ # )
128
+ # elif model_name == "resnet18_modified":
129
+ # model = thmodels.__dict__["resnet18"](pretrained=False)
130
+ # model.conv1 = nn.Conv2d(
131
+ # 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
132
+ # )
133
+ # model.maxpool = nn.Identity()
134
+ # elif model_name == "resnet101_modified":
135
+ # model = thmodels.__dict__["resnet101"](pretrained=False)
136
+ # model.conv1 = nn.Conv2d(
137
+ # 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
138
+ # )
139
+ # model.maxpool = nn.Identity()
140
+ # else:
141
+ # model = thmodels.__dict__[model_name](pretrained=False)
142
+
143
+ # return model
144
+
145
+ # def pruning_classifier(model=None, classes=[]):
146
+ # try:
147
+ # model_named_parameters = [name for name, x in model.named_parameters()]
148
+ # for name, x in model.named_parameters():
149
+ # if (
150
+ # name == model_named_parameters[-1]
151
+ # or name == model_named_parameters[-2]
152
+ # ):
153
+ # x.data = x[classes]
154
+ # except:
155
+ # print("ERROR in changing the number of classes.")
156
+
157
+ # return model
158
+
159
+ # # "imagenet-100" "imagenet-10" "imagenet-first" "imagenet-nette" "imagenet-woof"
160
+ # model = get_model(model_name)
161
+ # model = pruning_classifier(model, classes)
162
+ # if pretrained:
163
+ # if dataset in [
164
+ # "imagenet-100",
165
+ # "imagenet-10",
166
+ # "imagenet-nette",
167
+ # "imagenet-woof",
168
+ # "tinyimagenet",
169
+ # "cifar10",
170
+ # "cifar100",
171
+ # ]:
172
+ # checkpoint = torch.load(
173
+ # f"./data/pretrain_models/{dataset}_{model_name}.pth", map_location="cpu"
174
+ # )
175
+ # model.load_state_dict(checkpoint["model"])
176
+ # elif dataset in ["imagenet-1k"]:
177
+ # if model_name == "efficientNet-b0":
178
+ # # Specifically, for loading the pre-trained EfficientNet model, the following modifications are made
179
+ # from torchvision.models._api import WeightsEnum
180
+ # from torch.hub import load_state_dict_from_url
181
+
182
+ # def get_state_dict(self, *args, **kwargs):
183
+ # kwargs.pop("check_hash")
184
+ # return load_state_dict_from_url(self.url, *args, **kwargs)
185
+
186
+ # WeightsEnum.get_state_dict = get_state_dict
187
+
188
+ # model = thmodels.__dict__[model_name](pretrained=True)
189
+
190
+ # return model
var/D3HR/validation/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.97 kB). View file
 
var/D3HR/validation/models/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (2.9 kB). View file
 
var/D3HR/validation/models/__pycache__/convnet.cpython-310.pyc ADDED
Binary file (3.56 kB). View file
 
var/D3HR/validation/models/__pycache__/convnet.cpython-37.pyc ADDED
Binary file (3.5 kB). View file
 
var/D3HR/validation/models/__pycache__/mobilenet_v2.cpython-310.pyc ADDED
Binary file (4.43 kB). View file
 
var/D3HR/validation/models/__pycache__/mobilenet_v2.cpython-37.pyc ADDED
Binary file (4.34 kB). View file
 
var/D3HR/validation/models/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (2 kB). View file
 
var/D3HR/validation/models/__pycache__/resnet.cpython-37.pyc ADDED
Binary file (2.33 kB). View file
 
var/D3HR/validation/models/convnet.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ # Conv-3 model
6
class ConvNet(nn.Module):
    """Configurable ConvNet: `net_depth` Conv→Norm→Act→Pool stages followed
    by a linear classifier over the flattened feature map.

    Args:
        num_classes: size of the classifier output.
        net_norm: "batch" | "layer" | "instance" | "group" | "none".
        net_depth: number of conv stages.
        net_width: output channels of every conv stage.
        channel: input channels (3 for RGB, 1 for grayscale).
        net_act: "sigmoid" | "relu" | "leakyrelu".
        net_pooling: "maxpooling" | "avgpooling" | "none".
        im_size: (H, W) of the input; 28x28 inputs are treated as 32x32
            (the first conv pads grayscale input accordingly).
    """

    def __init__(
        self,
        num_classes,
        net_norm="batch",
        net_depth=3,
        net_width=128,
        channel=3,
        net_act="relu",
        net_pooling="avgpooling",
        im_size=(32, 32),
    ):
        super(ConvNet, self).__init__()
        if net_act == "sigmoid":
            self.net_act = nn.Sigmoid()
        elif net_act == "relu":
            self.net_act = nn.ReLU()
        elif net_act == "leakyrelu":
            self.net_act = nn.LeakyReLU(negative_slope=0.01)
        else:
            exit("unknown activation function: %s" % net_act)

        if net_pooling == "maxpooling":
            self.net_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
        elif net_pooling == "avgpooling":
            self.net_pooling = nn.AvgPool2d(kernel_size=2, stride=2)
        elif net_pooling == "none":
            self.net_pooling = None
        else:
            exit("unknown net_pooling: %s" % net_pooling)

        self.depth = net_depth
        self.net_norm = net_norm

        self.layers, shape_feat = self._make_layers(
            channel, net_width, net_depth, net_norm, net_pooling, im_size
        )
        num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
        self.classifier = nn.Linear(num_feat, num_classes)

    def forward(self, x, return_features=False):
        """Run all conv stages plus the classifier.

        Returns logits, or (logits, flattened_features) if return_features.
        """
        for d in range(self.depth):
            x = self.layers["conv"][d](x)
            if len(self.layers["norm"]) > 0:
                x = self.layers["norm"][d](x)
            x = self.layers["act"][d](x)
            if len(self.layers["pool"]) > 0:
                x = self.layers["pool"][d](x)

        out = x.view(x.shape[0], -1)
        logit = self.classifier(out)

        if return_features:
            return logit, out
        else:
            return logit

    def get_feature(
        self, x, idx_from, idx_to=-1, return_prob=False, return_logit=False
    ):
        """Return intermediate stage outputs features[idx_from : idx_to + 1].

        idx_to == -1 defaults to idx_from (a single stage).  When idx_to
        reaches past the last stage, optionally also return softmax probs
        (`return_prob`) or raw logits (`return_logit`).
        """
        if idx_to == -1:
            idx_to = idx_from
        features = []

        for d in range(self.depth):
            x = self.layers["conv"][d](x)
            # Fix: guard on the actual layer lists, mirroring forward().
            # The previous checks (`if self.net_norm:` / `if self.net_pooling:`)
            # were truthy for net_norm == "none" (a non-empty string) and
            # crashed indexing the empty "norm" ModuleList.
            if len(self.layers["norm"]) > 0:
                x = self.layers["norm"][d](x)
            x = self.layers["act"][d](x)
            if len(self.layers["pool"]) > 0:
                x = self.layers["pool"][d](x)
            features.append(x)
            if idx_to < len(features):
                return features[idx_from : idx_to + 1]

        if return_prob:
            out = x.view(x.size(0), -1)
            logit = self.classifier(out)
            prob = torch.softmax(logit, dim=-1)
            return features, prob
        elif return_logit:
            out = x.view(x.size(0), -1)
            logit = self.classifier(out)
            return features, logit
        else:
            return features[idx_from : idx_to + 1]

    def _get_normlayer(self, net_norm, shape_feat):
        """Create one normalization layer for a stage shaped (C, H, W)."""
        if net_norm == "batch":
            norm = nn.BatchNorm2d(shape_feat[0], affine=True)
        elif net_norm == "layer":
            norm = nn.LayerNorm(shape_feat, elementwise_affine=True)
        elif net_norm == "instance":
            # Instance norm expressed as GroupNorm with one group per channel.
            norm = nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
        elif net_norm == "group":
            norm = nn.GroupNorm(4, shape_feat[0], affine=True)
        elif net_norm == "none":
            norm = None
        else:
            exit("unknown net_norm: %s" % net_norm)
        return norm

    def _make_layers(
        self, channel, net_width, net_depth, net_norm, net_pooling, im_size
    ):
        """Build the per-stage layer lists and track the output feature shape."""
        layers = {"conv": [], "norm": [], "act": [], "pool": []}

        in_channels = channel
        # MNIST-style 28x28 inputs are sized up to 32x32 (first conv pads).
        if im_size[0] == 28:
            im_size = (32, 32)
        shape_feat = [in_channels, im_size[0], im_size[1]]

        for d in range(net_depth):
            layers["conv"] += [
                nn.Conv2d(
                    in_channels,
                    net_width,
                    kernel_size=3,
                    padding=3 if channel == 1 and d == 0 else 1,
                )
            ]
            shape_feat[0] = net_width
            if net_norm != "none":
                layers["norm"] += [self._get_normlayer(net_norm, shape_feat)]
            layers["act"] += [self.net_act]
            in_channels = net_width
            if net_pooling != "none":
                layers["pool"] += [self.net_pooling]
                shape_feat[1] //= 2
                shape_feat[2] //= 2

        layers["conv"] = nn.ModuleList(layers["conv"])
        layers["norm"] = nn.ModuleList(layers["norm"])
        layers["act"] = nn.ModuleList(layers["act"])
        layers["pool"] = nn.ModuleList(layers["pool"])
        layers = nn.ModuleDict(layers)

        return layers, shape_feat
var/D3HR/validation/models/dit_models.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from torch.jit import Final
16
+ import numpy as np
17
+ import math
18
+ from timm.models.vision_transformer import PatchEmbed, Mlp
19
+ from timm.layers import use_fused_attn
20
+
21
+
22
def modulate(x, shift, scale):
    """Apply a per-sample affine modulation (adaLN): x * (1 + scale) + shift.

    `shift` and `scale` are (N, D) conditioning vectors, broadcast over the
    token dimension of `x` (N, T, D).
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (scale_b + 1) + shift_b
24
+
25
+
26
class Attention(nn.Module):
    """Multi-head self-attention with optional QK normalization and optional
    learnable per-channel gates (gamma) on the QKV and output projections.
    """
    fused_attn: Final[bool]

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.,
        proj_drop=0.,
        norm_layer=nn.LayerNorm,
        use_gamma=False
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # Use torch's fused SDPA kernel when timm reports it is available.
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Per-head-dim normalization of q and k, or identity when disabled.
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        if use_gamma:
            # Learnable per-channel scales applied to the qkv output (gamma1)
            # and the output projection (gamma2), initialized to 1 (identity).
            self.gamma1 = nn.Parameter(torch.ones(dim * 3))
            self.gamma2 = nn.Parameter(torch.ones(dim))
        else:
            # Plain ints keep the forward arithmetic unchanged when gating is off.
            self.gamma1 = 1
            self.gamma2 = 1

    def forward(self, x):
        # x: (B, N, C) token sequence -> (B, N, C).
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, num_heads, N, head_dim)
        qkv = (self.gamma1 * self.qkv(x)).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            # Fused path: scaling and softmax happen inside the SDPA kernel;
            # attention dropout only applies in training mode.
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            # Manual path: scaled dot-product + softmax + dropout + weighted sum.
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C).
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.gamma2 * self.proj(x)
        x = self.proj_drop(x)
        return x
82
+
83
+
84
+ #################################################################################
85
+ # Embedding Layers for Timesteps and Class Labels #
86
+ #################################################################################
87
+
88
class TimestepEmbedder(nn.Module):
    """
    Maps scalar diffusion timesteps to vector embeddings: a fixed sinusoidal
    encoding followed by a small two-layer MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal embeddings for timesteps *t*.

        :param t: 1-D tensor of N (possibly fractional) timestep indices.
        :param dim: output embedding dimension.
        :param max_period: controls the minimum embedding frequency.
        :return: (N, dim) tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(exponent).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd dim: pad one zero column so the output is exactly `dim` wide.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(freq_emb)
126
+
127
+
128
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vectors.  When dropout_prob > 0 an extra "null"
    embedding row (index == num_classes) is reserved, and labels are randomly
    replaced by it during training to enable classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        has_null_token = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + has_null_token, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Replace selected labels with the null index for classifier-free
        guidance; random selection unless force_drop_ids pins the choice.
        """
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        else:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_ids, self.num_classes, labels)

    def forward(self, labels, train, force_drop_ids=None):
        wants_random_drop = train and self.dropout_prob > 0
        if wants_random_drop or force_drop_ids is not None:
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
156
+
157
+
158
+ #################################################################################
159
+ # Core DiT Model #
160
+ #################################################################################
161
+
162
class DiTBlock(nn.Module):
    """
    DiT transformer block with adaptive layer norm zero (adaLN-Zero)
    conditioning and optional learnable per-channel residual gates.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, use_gamma=False, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, use_gamma=use_gamma, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=lambda: nn.GELU(approximate="tanh"),
            drop=0,
        )
        # One projection produces all six conditioning vectors at once.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
        if use_gamma:
            self.gamma1 = nn.Parameter(torch.ones(hidden_size))
            self.gamma2 = nn.Parameter(torch.ones(hidden_size))
        else:
            # Plain ints: residual arithmetic is unchanged when gating is off.
            self.gamma1 = 1
            self.gamma2 = 1

    def forward(self, x, c):
        # Conditioning yields shift/scale/gate for the attention and MLP branches.
        mods = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods
        attn_out = self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + self.gamma1 * gate_msa.unsqueeze(1) * attn_out
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        x = x + self.gamma2 * gate_mlp.unsqueeze(1) * mlp_out
        return x
190
+
191
+
192
class FinalLayer(nn.Module):
    """
    Output head of DiT: adaLN-modulated LayerNorm followed by a linear
    projection to patch_size**2 * out_channels values per token.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # The conditioning vector supplies both shift and scale.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
210
+
211
+
212
class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size=32,          # latent spatial size (H == W)
        patch_size=2,           # side length of each square patch
        in_channels=4,          # latent channels (e.g. VAE latent)
        hidden_size=1152,       # transformer width
        depth=28,               # number of DiT blocks
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,  # label dropout for classifier-free guidance
        num_classes=1000,
        learn_sigma=True,       # predict variance too -> output channels doubled
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        # First 14 blocks use learnable residual gates, last 14 do not.
        # NOTE(review): this list has exactly 28 entries, so depth > 28
        # raises IndexError — confirm depth is always <= 28 here.
        use_gamma = [True] * 14 + [False] * 14
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, use_gamma=use_gamma[depth_index]) for depth_index in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """DiT weight init: Xavier linears, sin-cos pos_embed, and zeroed
        adaLN/output layers so each block starts as the identity."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        # Assumes a square token grid (T = h * w with h == w).
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def forward(self, x, t, y):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                   # (N, D)
        y = self.y_embedder(y, self.training)    # (N, D)
        # Timestep and label conditioning are fused additively.
        c = t + y                                # (N, D)
        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
        x = self.final_layer(x, c)               # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale, **kwargs):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.

        The caller supplies a doubled batch: conditional half first, then the
        unconditional half (null labels); both halves share the same latents.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
335
+
336
+
337
+ #################################################################################
338
+ # Sine/Cosine Positional Embedding Functions #
339
+ #################################################################################
340
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
341
+
342
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
343
+ """
344
+ grid_size: int of the grid height and width
345
+ return:
346
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
347
+ """
348
+ grid_h = np.arange(grid_size, dtype=np.float32)
349
+ grid_w = np.arange(grid_size, dtype=np.float32)
350
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
351
+ grid = np.stack(grid, axis=0)
352
+
353
+ grid = grid.reshape([2, 1, grid_size, grid_size])
354
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
355
+ if cls_token and extra_tokens > 0:
356
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
357
+ return pos_embed
358
+
359
+
360
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
361
+ assert embed_dim % 2 == 0
362
+
363
+ # use half of dimensions to encode grid_h
364
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
365
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
366
+
367
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
368
+ return emb
369
+
370
+
371
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
372
+ """
373
+ embed_dim: output dimension for each position
374
+ pos: a list of positions to be encoded: size (M,)
375
+ out: (M, D)
376
+ """
377
+ assert embed_dim % 2 == 0
378
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
379
+ omega /= embed_dim / 2.
380
+ omega = 1. / 10000**omega # (D/2,)
381
+
382
+ pos = pos.reshape(-1) # (M,)
383
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
384
+
385
+ emb_sin = np.sin(out) # (M, D/2)
386
+ emb_cos = np.cos(out) # (M, D/2)
387
+
388
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
389
+ return emb
390
+
391
+
392
+ #################################################################################
393
+ # DiT Configs #
394
+ #################################################################################
395
+
396
+ def DiT_XL_2(**kwargs):
397
+ return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
398
+
399
+ def DiT_XL_4(**kwargs):
400
+ return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
401
+
402
+ def DiT_XL_8(**kwargs):
403
+ return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
404
+
405
+ def DiT_L_2(**kwargs):
406
+ return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
407
+
408
+ def DiT_L_4(**kwargs):
409
+ return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
410
+
411
+ def DiT_L_8(**kwargs):
412
+ return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
413
+
414
+ def DiT_B_2(**kwargs):
415
+ return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
416
+
417
+ def DiT_B_4(**kwargs):
418
+ return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
419
+
420
+ def DiT_B_8(**kwargs):
421
+ return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
422
+
423
+ def DiT_S_2(**kwargs):
424
+ return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
425
+
426
+ def DiT_S_4(**kwargs):
427
+ return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
428
+
429
+ def DiT_S_8(**kwargs):
430
+ return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
431
+
432
+
433
+ DiT_models = {
434
+ 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
435
+ 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
436
+ 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
437
+ 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
438
+ }
var/D3HR/validation/models/mobilenet_v2.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Creates a MobileNetV2 Model as defined in:
3
+ Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen. (2018).
4
+ MobileNetV2: Inverted Residuals and Linear Bottlenecks
5
+ arXiv preprint arXiv:1801.04381.
6
+ import from https://github.com/tonylins/pytorch-mobilenet-v2
7
+ """
8
+
9
+ import torch.nn as nn
10
+ import math
11
+
12
+ __all__ = ['mobilenetv2']
13
+
14
+
15
+ def _make_divisible(v, divisor, min_value=None):
16
+ """
17
+ This function is taken from the original tf repo.
18
+ It ensures that all layers have a channel number that is divisible by 8
19
+ It can be seen here:
20
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
21
+ :param v:
22
+ :param divisor:
23
+ :param min_value:
24
+ :return:
25
+ """
26
+ if min_value is None:
27
+ min_value = divisor
28
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
29
+ # Make sure that round down does not go down by more than 10%.
30
+ if new_v < 0.9 * v:
31
+ new_v += divisor
32
+ return new_v
33
+
34
+
35
+ def conv_3x3_bn(inp, oup, stride):
36
+ return nn.Sequential(
37
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
38
+ nn.BatchNorm2d(oup),
39
+ nn.ReLU6(inplace=True)
40
+ )
41
+
42
+
43
+ def conv_1x1_bn(inp, oup):
44
+ return nn.Sequential(
45
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
46
+ nn.BatchNorm2d(oup),
47
+ nn.ReLU6(inplace=True)
48
+ )
49
+
50
+
51
+ class InvertedResidual(nn.Module):
52
+ def __init__(self, inp, oup, stride, expand_ratio):
53
+ super(InvertedResidual, self).__init__()
54
+ assert stride in [1, 2]
55
+
56
+ hidden_dim = round(inp * expand_ratio)
57
+ self.identity = stride == 1 and inp == oup
58
+
59
+ if expand_ratio == 1:
60
+ self.conv = nn.Sequential(
61
+ # dw
62
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
63
+ nn.BatchNorm2d(hidden_dim),
64
+ nn.ReLU6(inplace=True),
65
+ # pw-linear
66
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
67
+ nn.BatchNorm2d(oup),
68
+ )
69
+ else:
70
+ self.conv = nn.Sequential(
71
+ # pw
72
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
73
+ nn.BatchNorm2d(hidden_dim),
74
+ nn.ReLU6(inplace=True),
75
+ # dw
76
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
77
+ nn.BatchNorm2d(hidden_dim),
78
+ nn.ReLU6(inplace=True),
79
+ # pw-linear
80
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
81
+ nn.BatchNorm2d(oup),
82
+ )
83
+
84
+ def forward(self, x):
85
+ if self.identity:
86
+ return x + self.conv(x)
87
+ else:
88
+ return self.conv(x)
89
+
90
+
91
+ class MobileNetV2(nn.Module):
92
+ def __init__(self, num_classes=1000, width_mult=1.):
93
+ super(MobileNetV2, self).__init__()
94
+ # setting of inverted residual blocks
95
+ self.cfgs = [
96
+ # t, c, n, s
97
+ [1, 16, 1, 1],
98
+ [6, 24, 2, 2],
99
+ [6, 32, 3, 2],
100
+ [6, 64, 4, 2],
101
+ [6, 96, 3, 1],
102
+ [6, 160, 3, 2],
103
+ [6, 320, 1, 1],
104
+ ]
105
+
106
+ # building first layer
107
+ input_channel = _make_divisible(32 * width_mult, 4 if width_mult == 0.1 else 8)
108
+ layers = [conv_3x3_bn(3, input_channel, 2)]
109
+ # building inverted residual blocks
110
+ block = InvertedResidual
111
+ for t, c, n, s in self.cfgs:
112
+ output_channel = _make_divisible(c * width_mult, 4 if width_mult == 0.1 else 8)
113
+ for i in range(n):
114
+ layers.append(block(input_channel, output_channel, s if i == 0 else 1, t))
115
+ input_channel = output_channel
116
+ self.features = nn.Sequential(*layers)
117
+ # building last several layers
118
+ output_channel = _make_divisible(1280 * width_mult, 4 if width_mult == 0.1 else 8) if width_mult > 1.0 else 1280
119
+ self.conv = conv_1x1_bn(input_channel, output_channel)
120
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
121
+ self.classifier = nn.Linear(output_channel, num_classes)
122
+
123
+ self._initialize_weights()
124
+
125
+ def forward(self, x):
126
+ x = self.features(x)
127
+ x = self.conv(x)
128
+ x = self.avgpool(x)
129
+ x = x.view(x.size(0), -1)
130
+ x = self.classifier(x)
131
+ return x
132
+
133
+ def _initialize_weights(self):
134
+ for m in self.modules():
135
+ if isinstance(m, nn.Conv2d):
136
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
137
+ m.weight.data.normal_(0, math.sqrt(2. / n))
138
+ if m.bias is not None:
139
+ m.bias.data.zero_()
140
+ elif isinstance(m, nn.BatchNorm2d):
141
+ m.weight.data.fill_(1)
142
+ m.bias.data.zero_()
143
+ elif isinstance(m, nn.Linear):
144
+ m.weight.data.normal_(0, 0.01)
145
+ m.bias.data.zero_()
146
+
147
+ def mobilenetv2(**kwargs):
148
+ """
149
+ Constructs a MobileNet V2 model
150
+ """
151
+ return MobileNetV2(**kwargs)
var/D3HR/validation/models/pipeline_stable_unclip_img2img.py ADDED
@@ -0,0 +1,854 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import PIL.Image
19
+ import torch
20
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
21
+
22
+ from diffusers.image_processor import VaeImageProcessor
23
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
24
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
25
+ from diffusers.models.embeddings import get_timestep_embedding
26
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
27
+ from diffusers.schedulers import KarrasDiffusionSchedulers
28
+ from diffusers.utils import (
29
+ USE_PEFT_BACKEND,
30
+ deprecate,
31
+ logging,
32
+ replace_example_docstring,
33
+ scale_lora_layers,
34
+ unscale_lora_layers,
35
+ )
36
+ from diffusers.utils.torch_utils import randn_tensor
37
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
38
+ from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
39
+
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+ EXAMPLE_DOC_STRING = """
44
+ Examples:
45
+ ```py
46
+ >>> import requests
47
+ >>> import torch
48
+ >>> from PIL import Image
49
+ >>> from io import BytesIO
50
+
51
+ >>> from diffusers import StableUnCLIPImg2ImgPipeline
52
+
53
+ >>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
54
+ ... "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16
55
+ ... ) # TODO update model path
56
+ >>> pipe = pipe.to("cuda")
57
+
58
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
59
+
60
+ >>> response = requests.get(url)
61
+ >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
62
+ >>> init_image = init_image.resize((768, 512))
63
+
64
+ >>> prompt = "A fantasy landscape, trending on artstation"
65
+
66
+ >>> images = pipe(prompt, init_image).images
67
+ >>> images[0].save("fantasy_landscape.png")
68
+ ```
69
+ """
70
+
71
+
72
+ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
73
+ """
74
+ Pipeline for text-guided image-to-image generation using stable unCLIP.
75
+
76
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
77
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
78
+
79
+ The pipeline also inherits the following loading methods:
80
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
81
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
82
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
83
+
84
+ Args:
85
+ feature_extractor ([`CLIPImageProcessor`]):
86
+ Feature extractor for image pre-processing before being encoded.
87
+ image_encoder ([`CLIPVisionModelWithProjection`]):
88
+ CLIP vision model for encoding images.
89
+ image_normalizer ([`StableUnCLIPImageNormalizer`]):
90
+ Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image
91
+ embeddings after the noise has been applied.
92
+ image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
93
+ Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined
94
+ by the `noise_level`.
95
+ tokenizer (`~transformers.CLIPTokenizer`):
96
+ A [`~transformers.CLIPTokenizer`)].
97
+ text_encoder ([`~transformers.CLIPTextModel`]):
98
+ Frozen [`~transformers.CLIPTextModel`] text-encoder.
99
+ unet ([`UNet2DConditionModel`]):
100
+ A [`UNet2DConditionModel`] to denoise the encoded image latents.
101
+ scheduler ([`KarrasDiffusionSchedulers`]):
102
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
103
+ vae ([`AutoencoderKL`]):
104
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
105
+ """
106
+
107
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
108
+ _exclude_from_cpu_offload = ["image_normalizer"]
109
+
110
+ # image encoding components
111
+ feature_extractor: CLIPImageProcessor
112
+ image_encoder: CLIPVisionModelWithProjection
113
+
114
+ # image noising components
115
+ image_normalizer: StableUnCLIPImageNormalizer
116
+ image_noising_scheduler: KarrasDiffusionSchedulers
117
+
118
+ # regular denoising components
119
+ tokenizer: CLIPTokenizer
120
+ text_encoder: CLIPTextModel
121
+ unet: UNet2DConditionModel
122
+ scheduler: KarrasDiffusionSchedulers
123
+
124
+ vae: AutoencoderKL
125
+
126
+ def __init__(
127
+ self,
128
+ # image encoding components
129
+ feature_extractor: CLIPImageProcessor,
130
+ image_encoder: CLIPVisionModelWithProjection,
131
+ # image noising components
132
+ image_normalizer: StableUnCLIPImageNormalizer,
133
+ image_noising_scheduler: KarrasDiffusionSchedulers,
134
+ # regular denoising components
135
+ tokenizer: CLIPTokenizer,
136
+ text_encoder: CLIPTextModel,
137
+ unet: UNet2DConditionModel,
138
+ scheduler: KarrasDiffusionSchedulers,
139
+ # vae
140
+ vae: AutoencoderKL,
141
+ ):
142
+ super().__init__()
143
+
144
+ self.register_modules(
145
+ feature_extractor=feature_extractor,
146
+ image_encoder=image_encoder,
147
+ image_normalizer=image_normalizer,
148
+ image_noising_scheduler=image_noising_scheduler,
149
+ tokenizer=tokenizer,
150
+ text_encoder=text_encoder,
151
+ unet=unet,
152
+ scheduler=scheduler,
153
+ vae=vae,
154
+ )
155
+
156
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
157
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
158
+
159
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
160
+ def enable_vae_slicing(self):
161
+ r"""
162
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
163
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
164
+ """
165
+ self.vae.enable_slicing()
166
+
167
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
168
+ def disable_vae_slicing(self):
169
+ r"""
170
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
171
+ computing decoding in one step.
172
+ """
173
+ self.vae.disable_slicing()
174
+
175
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
176
+ def _encode_prompt(
177
+ self,
178
+ prompt,
179
+ device,
180
+ num_images_per_prompt,
181
+ do_classifier_free_guidance,
182
+ negative_prompt=None,
183
+ prompt_embeds: Optional[torch.FloatTensor] = None,
184
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
185
+ lora_scale: Optional[float] = None,
186
+ **kwargs,
187
+ ):
188
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
189
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
190
+
191
+ prompt_embeds_tuple = self.encode_prompt(
192
+ prompt=prompt,
193
+ device=device,
194
+ num_images_per_prompt=num_images_per_prompt,
195
+ do_classifier_free_guidance=do_classifier_free_guidance,
196
+ negative_prompt=negative_prompt,
197
+ prompt_embeds=prompt_embeds,
198
+ negative_prompt_embeds=negative_prompt_embeds,
199
+ lora_scale=lora_scale,
200
+ **kwargs,
201
+ )
202
+
203
+ # concatenate for backwards comp
204
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
205
+
206
+ return prompt_embeds
207
+
208
+ def _encode_image(
209
+ self,
210
+ image,
211
+ device,
212
+ batch_size,
213
+ num_images_per_prompt,
214
+ do_classifier_free_guidance,
215
+ noise_level,
216
+ generator,
217
+ image_embeds,
218
+ ):
219
+ dtype = next(self.image_encoder.parameters()).dtype
220
+
221
+ if isinstance(image, PIL.Image.Image):
222
+ # the image embedding should repeated so it matches the total batch size of the prompt
223
+ repeat_by = batch_size
224
+ else:
225
+ # assume the image input is already properly batched and just needs to be repeated so
226
+ # it matches the num_images_per_prompt.
227
+ #
228
+ # NOTE(will) this is probably missing a few number of side cases. I.e. batched/non-batched
229
+ # `image_embeds`. If those happen to be common use cases, let's think harder about
230
+ # what the expected dimensions of inputs should be and how we handle the encoding.
231
+ repeat_by = num_images_per_prompt
232
+
233
+ if image_embeds is None:
234
+ if not isinstance(image, torch.Tensor):
235
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
236
+
237
+ image = image.to(device=device, dtype=dtype)
238
+ image_embeds = self.image_encoder(image).image_embeds
239
+
240
+ image_embeds = self.noise_image_embeddings(
241
+ image_embeds=image_embeds,
242
+ noise_level=noise_level,
243
+ generator=generator,
244
+ )
245
+
246
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
247
+ image_embeds = image_embeds.unsqueeze(1)
248
+ bs_embed, seq_len, _ = image_embeds.shape
249
+ image_embeds = image_embeds.repeat(1, repeat_by, 1)
250
+ image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1)
251
+ image_embeds = image_embeds.squeeze(1)
252
+
253
+ if do_classifier_free_guidance:
254
+ negative_prompt_embeds = torch.zeros_like(image_embeds)
255
+
256
+ # For classifier free guidance, we need to do two forward passes.
257
+ # Here we concatenate the unconditional and text embeddings into a single batch
258
+ # to avoid doing two forward passes
259
+ image_embeds = torch.cat([negative_prompt_embeds, image_embeds])
260
+
261
+ return image_embeds
262
+
263
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
264
+ def encode_prompt(
265
+ self,
266
+ prompt,
267
+ device,
268
+ num_images_per_prompt,
269
+ do_classifier_free_guidance,
270
+ negative_prompt=None,
271
+ prompt_embeds: Optional[torch.FloatTensor] = None,
272
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
273
+ lora_scale: Optional[float] = None,
274
+ clip_skip: Optional[int] = None,
275
+ ):
276
+ r"""
277
+ Encodes the prompt into text encoder hidden states.
278
+
279
+ Args:
280
+ prompt (`str` or `List[str]`, *optional*):
281
+ prompt to be encoded
282
+ device: (`torch.device`):
283
+ torch device
284
+ num_images_per_prompt (`int`):
285
+ number of images that should be generated per prompt
286
+ do_classifier_free_guidance (`bool`):
287
+ whether to use classifier free guidance or not
288
+ negative_prompt (`str` or `List[str]`, *optional*):
289
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
290
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
291
+ less than `1`).
292
+ prompt_embeds (`torch.FloatTensor`, *optional*):
293
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
294
+ provided, text embeddings will be generated from `prompt` input argument.
295
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
296
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
297
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
298
+ argument.
299
+ lora_scale (`float`, *optional*):
300
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
301
+ clip_skip (`int`, *optional*):
302
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
303
+ the output of the pre-final layer will be used for computing the prompt embeddings.
304
+ """
305
+ # set lora scale so that monkey patched LoRA
306
+ # function of text encoder can correctly access it
307
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
308
+ self._lora_scale = lora_scale
309
+
310
+ # dynamically adjust the LoRA scale
311
+ if not USE_PEFT_BACKEND:
312
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
313
+ else:
314
+ scale_lora_layers(self.text_encoder, lora_scale)
315
+
316
+ if prompt is not None and isinstance(prompt, str):
317
+ batch_size = 1
318
+ elif prompt is not None and isinstance(prompt, list):
319
+ batch_size = len(prompt)
320
+ else:
321
+ batch_size = prompt_embeds.shape[0]
322
+
323
+ if prompt_embeds is None:
324
+ # textual inversion: procecss multi-vector tokens if necessary
325
+ if isinstance(self, TextualInversionLoaderMixin):
326
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
327
+
328
+ text_inputs = self.tokenizer(
329
+ prompt,
330
+ padding="max_length",
331
+ max_length=self.tokenizer.model_max_length,
332
+ truncation=True,
333
+ return_tensors="pt",
334
+ )
335
+ text_input_ids = text_inputs.input_ids
336
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
337
+
338
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
339
+ text_input_ids, untruncated_ids
340
+ ):
341
+ removed_text = self.tokenizer.batch_decode(
342
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
343
+ )
344
+ logger.warning(
345
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
346
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
347
+ )
348
+
349
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
350
+ attention_mask = text_inputs.attention_mask.to(device)
351
+ else:
352
+ attention_mask = None
353
+
354
+ if clip_skip is None:
355
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
356
+ prompt_embeds = prompt_embeds[0]
357
+ else:
358
+ prompt_embeds = self.text_encoder(
359
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
360
+ )
361
+ # Access the `hidden_states` first, that contains a tuple of
362
+ # all the hidden states from the encoder layers. Then index into
363
+ # the tuple to access the hidden states from the desired layer.
364
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
365
+ # We also need to apply the final LayerNorm here to not mess with the
366
+ # representations. The `last_hidden_states` that we typically use for
367
+ # obtaining the final prompt representations passes through the LayerNorm
368
+ # layer.
369
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
370
+
371
+ if self.text_encoder is not None:
372
+ prompt_embeds_dtype = self.text_encoder.dtype
373
+ elif self.unet is not None:
374
+ prompt_embeds_dtype = self.unet.dtype
375
+ else:
376
+ prompt_embeds_dtype = prompt_embeds.dtype
377
+
378
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
379
+
380
+ bs_embed, seq_len, _ = prompt_embeds.shape
381
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
382
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
383
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
384
+
385
+ # get unconditional embeddings for classifier free guidance
386
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
387
+ uncond_tokens: List[str]
388
+ if negative_prompt is None:
389
+ uncond_tokens = [""] * batch_size
390
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
391
+ raise TypeError(
392
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
393
+ f" {type(prompt)}."
394
+ )
395
+ elif isinstance(negative_prompt, str):
396
+ uncond_tokens = [negative_prompt]
397
+ elif batch_size != len(negative_prompt):
398
+ raise ValueError(
399
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
400
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
401
+ " the batch size of `prompt`."
402
+ )
403
+ else:
404
+ uncond_tokens = negative_prompt
405
+
406
+ # textual inversion: procecss multi-vector tokens if necessary
407
+ if isinstance(self, TextualInversionLoaderMixin):
408
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
409
+
410
+ max_length = prompt_embeds.shape[1]
411
+ uncond_input = self.tokenizer(
412
+ uncond_tokens,
413
+ padding="max_length",
414
+ max_length=max_length,
415
+ truncation=True,
416
+ return_tensors="pt",
417
+ )
418
+
419
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
420
+ attention_mask = uncond_input.attention_mask.to(device)
421
+ else:
422
+ attention_mask = None
423
+
424
+ negative_prompt_embeds = self.text_encoder(
425
+ uncond_input.input_ids.to(device),
426
+ attention_mask=attention_mask,
427
+ )
428
+ negative_prompt_embeds = negative_prompt_embeds[0]
429
+
430
+ if do_classifier_free_guidance:
431
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
432
+ seq_len = negative_prompt_embeds.shape[1]
433
+
434
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
435
+
436
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
437
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
438
+
439
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
440
+ # Retrieve the original scale by scaling back the LoRA layers
441
+ unscale_lora_layers(self.text_encoder, lora_scale)
442
+
443
+ return prompt_embeds, negative_prompt_embeds
444
+
445
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
446
+ def decode_latents(self, latents):
447
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
448
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
449
+
450
+ latents = 1 / self.vae.config.scaling_factor * latents
451
+ image = self.vae.decode(latents, return_dict=False)[0]
452
+ image = (image / 2 + 0.5).clamp(0, 1)
453
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
454
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
455
+ return image
456
+
457
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
458
+ def prepare_extra_step_kwargs(self, generator, eta):
459
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
460
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
461
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
462
+ # and should be between [0, 1]
463
+
464
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
465
+ extra_step_kwargs = {}
466
+ if accepts_eta:
467
+ extra_step_kwargs["eta"] = eta
468
+
469
+ # check if the scheduler accepts generator
470
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
471
+ if accepts_generator:
472
+ extra_step_kwargs["generator"] = generator
473
+ return extra_step_kwargs
474
+
475
+ def check_inputs(
476
+ self,
477
+ prompt,
478
+ image,
479
+ height,
480
+ width,
481
+ callback_steps,
482
+ noise_level,
483
+ negative_prompt=None,
484
+ prompt_embeds=None,
485
+ negative_prompt_embeds=None,
486
+ image_embeds=None,
487
+ ):
488
+ if height % 8 != 0 or width % 8 != 0:
489
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
490
+
491
+ if (callback_steps is None) or (
492
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
493
+ ):
494
+ raise ValueError(
495
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
496
+ f" {type(callback_steps)}."
497
+ )
498
+
499
+ if prompt is not None and prompt_embeds is not None:
500
+ raise ValueError(
501
+ "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two."
502
+ )
503
+
504
+ if prompt is None and prompt_embeds is None:
505
+ raise ValueError(
506
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
507
+ )
508
+
509
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
510
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
511
+
512
+ if negative_prompt is not None and negative_prompt_embeds is not None:
513
+ raise ValueError(
514
+ "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined."
515
+ )
516
+
517
+ if prompt is not None and negative_prompt is not None:
518
+ if type(prompt) is not type(negative_prompt):
519
+ raise TypeError(
520
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
521
+ f" {type(prompt)}."
522
+ )
523
+
524
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
525
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
526
+ raise ValueError(
527
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
528
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
529
+ f" {negative_prompt_embeds.shape}."
530
+ )
531
+
532
+ if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
533
+ raise ValueError(
534
+ f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
535
+ )
536
+
537
+ if image is not None and image_embeds is not None:
538
+ raise ValueError(
539
+ "Provide either `image` or `image_embeds`. Please make sure to define only one of the two."
540
+ )
541
+
542
+ if image is None and image_embeds is None:
543
+ raise ValueError(
544
+ "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
545
+ )
546
+
547
+ if image is not None:
548
+ if (
549
+ not isinstance(image, torch.Tensor)
550
+ and not isinstance(image, PIL.Image.Image)
551
+ and not isinstance(image, list)
552
+ ):
553
+ raise ValueError(
554
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
555
+ f" {type(image)}"
556
+ )
557
+
558
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
559
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
560
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
561
+ if isinstance(generator, list) and len(generator) != batch_size:
562
+ raise ValueError(
563
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
564
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
565
+ )
566
+
567
+ if latents is None:
568
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
569
+ else:
570
+ latents = latents.to(device)
571
+
572
+ # scale the initial noise by the standard deviation required by the scheduler
573
+ latents = latents * self.scheduler.init_noise_sigma
574
+ return latents
575
+
576
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings
577
+ def noise_image_embeddings(
578
+ self,
579
+ image_embeds: torch.Tensor,
580
+ noise_level: int,
581
+ noise: Optional[torch.FloatTensor] = None,
582
+ generator: Optional[torch.Generator] = None,
583
+ ):
584
+ """
585
+ Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
586
+ `noise_level` increases the variance in the final un-noised images.
587
+
588
+ The noise is applied in two ways:
589
+ 1. A noise schedule is applied directly to the embeddings.
590
+ 2. A vector of sinusoidal time embeddings are appended to the output.
591
+
592
+ In both cases, the amount of noise is controlled by the same `noise_level`.
593
+
594
+ The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
595
+ """
596
+ if noise is None:
597
+ noise = randn_tensor(
598
+ image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
599
+ )
600
+
601
+ noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
602
+
603
+ self.image_normalizer.to(image_embeds.device)
604
+ image_embeds = self.image_normalizer.scale(image_embeds)
605
+
606
+ image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
607
+
608
+ image_embeds = self.image_normalizer.unscale(image_embeds)
609
+
610
+ noise_level = get_timestep_embedding(
611
+ timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0
612
+ )
613
+
614
+ # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors,
615
+ # but we might actually be running in fp16. so we need to cast here.
616
+ # there might be better ways to encapsulate this.
617
+ noise_level = noise_level.to(image_embeds.dtype)
618
+
619
+ image_embeds = torch.cat((image_embeds, noise_level), 1)
620
+
621
+ return image_embeds
622
+
623
+ @torch.no_grad()
624
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
625
+ def __call__(
626
+ self,
627
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
628
+ prompt: Union[str, List[str]] = None,
629
+ height: Optional[int] = None,
630
+ width: Optional[int] = None,
631
+ num_inference_steps: int = 20,
632
+ guidance_scale: float = 10,
633
+ negative_prompt: Optional[Union[str, List[str]]] = None,
634
+ num_images_per_prompt: Optional[int] = 1,
635
+ eta: float = 0.0,
636
+ generator: Optional[torch.Generator] = None,
637
+ latents: Optional[torch.FloatTensor] = None,
638
+ prompt_embeds: Optional[torch.FloatTensor] = None,
639
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
640
+ output_type: Optional[str] = "pil",
641
+ return_dict: bool = True,
642
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
643
+ callback_steps: int = 1,
644
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
645
+ noise_level: int = 0,
646
+ image_embeds: Optional[torch.FloatTensor] = None,
647
+ clip_skip: Optional[int] = None,
648
+ cond_fn = None,
649
+ ):
650
+ r"""
651
+ The call function to the pipeline for generation.
652
+
653
+ Args:
654
+ prompt (`str` or `List[str]`, *optional*):
655
+ The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be
656
+ used or prompt is initialized to `""`.
657
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
658
+ `Image` or tensor representing an image batch. The image is encoded to its CLIP embedding which the
659
+ `unet` is conditioned on. The image is _not_ encoded by the `vae` and then used as the latents in the
660
+ denoising process like it is in the standard Stable Diffusion text-guided image variation process.
661
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
662
+ The height in pixels of the generated image.
663
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
664
+ The width in pixels of the generated image.
665
+ num_inference_steps (`int`, *optional*, defaults to 20):
666
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
667
+ expense of slower inference.
668
+ guidance_scale (`float`, *optional*, defaults to 10.0):
669
+ A higher guidance scale value encourages the model to generate images closely linked to the text
670
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
671
+ negative_prompt (`str` or `List[str]`, *optional*):
672
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
673
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
674
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
675
+ The number of images to generate per prompt.
676
+ eta (`float`, *optional*, defaults to 0.0):
677
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
678
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
679
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
680
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
681
+ generation deterministic.
682
+ latents (`torch.FloatTensor`, *optional*):
683
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
684
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
685
+ tensor is generated by sampling using the supplied random `generator`.
686
+ prompt_embeds (`torch.FloatTensor`, *optional*):
687
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
688
+ provided, text embeddings are generated from the `prompt` input argument.
689
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
690
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
691
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
692
+ output_type (`str`, *optional*, defaults to `"pil"`):
693
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
694
+ return_dict (`bool`, *optional*, defaults to `True`):
695
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
696
+ callback (`Callable`, *optional*):
697
+ A function that calls every `callback_steps` steps during inference. The function is called with the
698
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
699
+ callback_steps (`int`, *optional*, defaults to 1):
700
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
701
+ every step.
702
+ cross_attention_kwargs (`dict`, *optional*):
703
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
704
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
705
+ noise_level (`int`, *optional*, defaults to `0`):
706
+ The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
707
+ the final un-noised images. See [`StableUnCLIPPipeline.noise_image_embeddings`] for more details.
708
+ image_embeds (`torch.FloatTensor`, *optional*):
709
+ Pre-generated CLIP embeddings to condition the `unet` on. These latents are not used in the denoising
710
+ process. If you want to provide pre-generated latents, pass them to `__call__` as `latents`.
711
+ clip_skip (`int`, *optional*):
712
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
713
+ the output of the pre-final layer will be used for computing the prompt embeddings.
714
+
715
+ Examples:
716
+
717
+ Returns:
718
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
719
+ [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning
720
+ a tuple, the first element is a list with the generated images.
721
+ """
722
+ # 0. Default height and width to unet
723
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
724
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
725
+
726
+ if prompt is None and prompt_embeds is None:
727
+ prompt = len(image) * [""] if isinstance(image, list) else ""
728
+
729
+ # 1. Check inputs. Raise error if not correct
730
+ self.check_inputs(
731
+ prompt=prompt,
732
+ image=image,
733
+ height=height,
734
+ width=width,
735
+ callback_steps=callback_steps,
736
+ noise_level=noise_level,
737
+ negative_prompt=negative_prompt,
738
+ prompt_embeds=prompt_embeds,
739
+ negative_prompt_embeds=negative_prompt_embeds,
740
+ image_embeds=image_embeds,
741
+ )
742
+
743
+ # 2. Define call parameters
744
+ if prompt is not None and isinstance(prompt, str):
745
+ batch_size = 1
746
+ elif prompt is not None and isinstance(prompt, list):
747
+ batch_size = len(prompt)
748
+ else:
749
+ batch_size = prompt_embeds.shape[0]
750
+
751
+ batch_size = batch_size * num_images_per_prompt
752
+
753
+ device = self._execution_device
754
+
755
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
756
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
757
+ # corresponds to doing no classifier free guidance.
758
+ do_classifier_free_guidance = guidance_scale > 1.0
759
+
760
+ # 3. Encode input prompt
761
+ text_encoder_lora_scale = (
762
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
763
+ )
764
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
765
+ prompt=prompt,
766
+ device=device,
767
+ num_images_per_prompt=num_images_per_prompt,
768
+ do_classifier_free_guidance=do_classifier_free_guidance,
769
+ negative_prompt=negative_prompt,
770
+ prompt_embeds=prompt_embeds,
771
+ negative_prompt_embeds=negative_prompt_embeds,
772
+ lora_scale=text_encoder_lora_scale,
773
+ )
774
+ # For classifier free guidance, we need to do two forward passes.
775
+ # Here we concatenate the unconditional and text embeddings into a single batch
776
+ # to avoid doing two forward passes
777
+ if do_classifier_free_guidance:
778
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
779
+
780
+ # 4. Encoder input image
781
+ noise_level = torch.tensor([noise_level], device=device)
782
+ image_embeds = self._encode_image(
783
+ image=image,
784
+ device=device,
785
+ batch_size=batch_size,
786
+ num_images_per_prompt=num_images_per_prompt,
787
+ do_classifier_free_guidance=do_classifier_free_guidance,
788
+ noise_level=noise_level,
789
+ generator=generator,
790
+ image_embeds=image_embeds,
791
+ )
792
+
793
+ # 5. Prepare timesteps
794
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
795
+ timesteps = self.scheduler.timesteps
796
+
797
+ # 6. Prepare latent variables
798
+ num_channels_latents = self.unet.config.in_channels
799
+ latents = self.prepare_latents(
800
+ batch_size=batch_size,
801
+ num_channels_latents=num_channels_latents,
802
+ height=height,
803
+ width=width,
804
+ dtype=prompt_embeds.dtype,
805
+ device=device,
806
+ generator=generator,
807
+ latents=latents,
808
+ )
809
+
810
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
811
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
812
+
813
+ # 8. Denoising loop
814
+ for i, t in enumerate(timesteps):
815
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
816
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
817
+
818
+ # predict the noise residual
819
+ noise_pred = self.unet(
820
+ latent_model_input,
821
+ t,
822
+ encoder_hidden_states=prompt_embeds,
823
+ class_labels=image_embeds,
824
+ cross_attention_kwargs=cross_attention_kwargs,
825
+ return_dict=False,
826
+ )[0]
827
+
828
+ # perform guidance
829
+ if do_classifier_free_guidance:
830
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
831
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
832
+
833
+ # compute the previous noisy sample x_t -> x_t-1
834
+ latents = self.scheduler.step(noise_pred, t, latents, cond_fn, **extra_step_kwargs, return_dict=False)[0]
835
+
836
+ if callback is not None and i % callback_steps == 0:
837
+ step_idx = i // getattr(self.scheduler, "order", 1)
838
+ callback(step_idx, t, latents)
839
+
840
+ # 9. Post-processing
841
+ if not output_type == "latent":
842
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
843
+ else:
844
+ image = latents
845
+
846
+ image = self.image_processor.postprocess(image, output_type=output_type)
847
+
848
+ # Offload all models
849
+ self.maybe_free_model_hooks()
850
+
851
+ if not return_dict:
852
+ return (image,)
853
+
854
+ return ImagePipelineOutput(images=image)
var/D3HR/validation/models/resnet.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision.models.resnet import (
4
+ ResNet, ResNet18_Weights, ResNet50_Weights, ResNet101_Weights,
5
+ BasicBlock, Bottleneck,
6
+ _ovewrite_named_param
7
+ )
8
+
9
+
10
+ class FeatResNet(ResNet):
11
+ def __init__(self, block, layers, **kwargs):
12
+ super(FeatResNet, self).__init__(block, layers, **kwargs)
13
+
14
+ def get_features(self, x):
15
+ x = self.conv1(x)
16
+ x = self.bn1(x)
17
+ x = self.relu(x)
18
+ x = self.maxpool(x)
19
+
20
+ x = self.layer1(x)
21
+ x = self.layer2(x)
22
+ x = self.layer3(x)
23
+ x = self.layer4(x)
24
+
25
+ x = self.avgpool(x)
26
+ x = torch.flatten(x, 1)
27
+
28
+ return x
29
+
30
+
31
+ def resnet18(*, weights=None, progress=True, **kwargs):
32
+ weights = ResNet18_Weights.verify(weights)
33
+ if weights is not None:
34
+ _ovewrite_named_param(kwargs, 'num_classes', len(weights.meta['categories']))
35
+
36
+ model = FeatResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
37
+
38
+ if weights is not None:
39
+ model.load_state_dict(weights.get_state_dict(progress=progress))
40
+
41
+ return model
42
+
43
+
44
+ def resnet50(*, weights=None, progress=True, **kwargs):
45
+ weights = ResNet50_Weights.verify(weights)
46
+ if weights is not None:
47
+ _ovewrite_named_param(kwargs, 'num_classes', len(weights.meta['categories']))
48
+
49
+ model = FeatResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
50
+
51
+ if weights is not None:
52
+ model.load_state_dict(weights.get_state_dict(progress=progress))
53
+
54
+ return model
55
+
56
+
57
+ def resnet101(*, weights=None, progress=True, **kwargs):
58
+ weights = ResNet101_Weights.verify(weights)
59
+ if weights is not None:
60
+ _ovewrite_named_param(kwargs, 'num_classes', len(weights.meta['categories']))
61
+
62
+ model = FeatResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
63
+
64
+ if weights is not None:
65
+ model.load_state_dict(weights.get_state_dict(progress=progress))
66
+
67
+ return model
68
+
69
+
70
+ def resnet152(*, weights=None, progress=True, **kwargs):
71
+ weights = ResNet101_Weights.verify(weights)
72
+ if weights is not None:
73
+ _ovewrite_named_param(kwargs, 'num_classes', len(weights.meta['categories']))
74
+
75
+ model = FeatResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
76
+
77
+ if weights is not None:
78
+ model.load_state_dict(weights.get_state_dict(progress=progress))
79
+
80
+ return model
var/D3HR/validation/models/scheduling_ddim.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.utils import BaseOutput
27
+ from diffusers.utils.torch_utils import randn_tensor
28
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
29
+
30
+
31
+ @dataclass
32
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
33
+ class DDIMSchedulerOutput(BaseOutput):
34
+ """
35
+ Output class for the scheduler's `step` function output.
36
+
37
+ Args:
38
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
39
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
40
+ denoising loop.
41
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
42
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
43
+ `pred_original_sample` can be used to preview progress or for guidance.
44
+ """
45
+
46
+ prev_sample: torch.Tensor
47
+ pred_original_sample: Optional[torch.Tensor] = None
48
+
49
+
50
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
51
+ def betas_for_alpha_bar(
52
+ num_diffusion_timesteps,
53
+ max_beta=0.999,
54
+ alpha_transform_type="cosine",
55
+ ):
56
+ """
57
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
58
+ (1-beta) over time from t = [0,1].
59
+
60
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
61
+ to that part of the diffusion process.
62
+
63
+
64
+ Args:
65
+ num_diffusion_timesteps (`int`): the number of betas to produce.
66
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
67
+ prevent singularities.
68
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
69
+ Choose from `cosine` or `exp`
70
+
71
+ Returns:
72
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
73
+ """
74
+ if alpha_transform_type == "cosine":
75
+
76
+ def alpha_bar_fn(t):
77
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
78
+
79
+ elif alpha_transform_type == "exp":
80
+
81
+ def alpha_bar_fn(t):
82
+ return math.exp(t * -12.0)
83
+
84
+ else:
85
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
86
+
87
+ betas = []
88
+ for i in range(num_diffusion_timesteps):
89
+ t1 = i / num_diffusion_timesteps
90
+ t2 = (i + 1) / num_diffusion_timesteps
91
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
92
+ return torch.tensor(betas, dtype=torch.float32)
93
+
94
+
95
+ def rescale_zero_terminal_snr(betas):
96
+ """
97
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
98
+
99
+
100
+ Args:
101
+ betas (`torch.Tensor`):
102
+ the betas that the scheduler is being initialized with.
103
+
104
+ Returns:
105
+ `torch.Tensor`: rescaled betas with zero terminal SNR
106
+ """
107
+ # Convert betas to alphas_bar_sqrt
108
+ alphas = 1.0 - betas
109
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
110
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
111
+
112
+ # Store old values.
113
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
114
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
115
+
116
+ # Shift so the last timestep is zero.
117
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
118
+
119
+ # Scale so the first timestep is back to the old value.
120
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
121
+
122
+ # Convert alphas_bar_sqrt to betas
123
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
124
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
125
+ alphas = torch.cat([alphas_bar[0:1], alphas])
126
+ betas = 1 - alphas
127
+
128
+ return betas
129
+
130
+
131
+ class DDIMScheduler(SchedulerMixin, ConfigMixin):
132
+ """
133
+ `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
134
+ non-Markovian guidance.
135
+
136
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
137
+ methods the library implements for all schedulers such as loading and saving.
138
+
139
+ Args:
140
+ num_train_timesteps (`int`, defaults to 1000):
141
+ The number of diffusion steps to train the model.
142
+ beta_start (`float`, defaults to 0.0001):
143
+ The starting `beta` value of inference.
144
+ beta_end (`float`, defaults to 0.02):
145
+ The final `beta` value.
146
+ beta_schedule (`str`, defaults to `"linear"`):
147
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
148
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
149
+ trained_betas (`np.ndarray`, *optional*):
150
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
151
+ clip_sample (`bool`, defaults to `True`):
152
+ Clip the predicted sample for numerical stability.
153
+ clip_sample_range (`float`, defaults to 1.0):
154
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
155
+ set_alpha_to_one (`bool`, defaults to `True`):
156
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
157
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
158
+ otherwise it uses the alpha value at step 0.
159
+ steps_offset (`int`, defaults to 0):
160
+ An offset added to the inference steps, as required by some model families.
161
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
162
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
163
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
164
+ Video](https://imagen.research.google/video/paper.pdf) paper).
165
+ thresholding (`bool`, defaults to `False`):
166
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
167
+ as Stable Diffusion.
168
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
169
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
170
+ sample_max_value (`float`, defaults to 1.0):
171
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
172
+ timestep_spacing (`str`, defaults to `"leading"`):
173
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
174
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
175
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
176
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
177
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
178
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
179
+ """
180
+
181
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
182
+ order = 1
183
+
184
+ @register_to_config
185
+ def __init__(
186
+ self,
187
+ num_train_timesteps: int = 1000,
188
+ beta_start: float = 0.0001,
189
+ beta_end: float = 0.02,
190
+ beta_schedule: str = "linear",
191
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
192
+ clip_sample: bool = True,
193
+ set_alpha_to_one: bool = True,
194
+ steps_offset: int = 0,
195
+ prediction_type: str = "epsilon",
196
+ thresholding: bool = False,
197
+ dynamic_thresholding_ratio: float = 0.995,
198
+ clip_sample_range: float = 1.0,
199
+ sample_max_value: float = 1.0,
200
+ timestep_spacing: str = "leading",
201
+ rescale_betas_zero_snr: bool = False,
202
+ ):
203
+ if trained_betas is not None:
204
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
205
+ elif beta_schedule == "linear":
206
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
207
+ elif beta_schedule == "scaled_linear":
208
+ # this schedule is very specific to the latent diffusion model.
209
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
210
+ elif beta_schedule == "squaredcos_cap_v2":
211
+ # Glide cosine schedule
212
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
213
+ else:
214
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
215
+
216
+ # Rescale for zero SNR
217
+ if rescale_betas_zero_snr:
218
+ self.betas = rescale_zero_terminal_snr(self.betas)
219
+
220
+ self.alphas = 1.0 - self.betas
221
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
222
+
223
+ # At every step in ddim, we are looking into the previous alphas_cumprod
224
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
225
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
226
+ # whether we use the final alpha of the "non-previous" one.
227
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
228
+
229
+ # standard deviation of the initial noise distribution
230
+ self.init_noise_sigma = 1.0
231
+
232
+ # setable values
233
+ self.num_inference_steps = None
234
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
235
+
236
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
    """
    No-op input scaling hook, kept for API interchangeability with schedulers
    (e.g. Euler-style ones) that must rescale the model input per timestep.

    Args:
        sample (`torch.Tensor`): The input sample.
        timestep (`int`, *optional*): The current timestep (unused here).

    Returns:
        `torch.Tensor`: `sample`, unchanged.
    """
    return sample
252
+
253
+ def _get_variance(self, timestep, prev_timestep):
254
+ alpha_prod_t = self.alphas_cumprod[timestep]
255
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
256
+ beta_prod_t = 1 - alpha_prod_t
257
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
258
+
259
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
260
+
261
+ return variance
262
+
263
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
264
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
265
+ """
266
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
267
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
268
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
269
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
270
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
271
+
272
+ https://arxiv.org/abs/2205.11487
273
+ """
274
+ dtype = sample.dtype
275
+ batch_size, channels, *remaining_dims = sample.shape
276
+
277
+ if dtype not in (torch.float32, torch.float64):
278
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
279
+
280
+ # Flatten sample for doing quantile calculation along each image
281
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
282
+
283
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
284
+
285
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
286
+ s = torch.clamp(
287
+ s, min=1, max=self.config.sample_max_value
288
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
289
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
290
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
291
+
292
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
293
+ sample = sample.to(dtype)
294
+
295
+ return sample
296
+
297
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
    """
    Set the discrete timestep schedule used for sampling (run before inference).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            Device onto which the timestep tensor is moved.
    """

    if num_inference_steps > self.config.num_train_timesteps:
        raise ValueError(
            f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
            f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
            f" maximal {self.config.num_train_timesteps} timesteps."
        )

    self.num_inference_steps = num_inference_steps
    train_steps = self.config.num_train_timesteps
    spacing = self.config.timestep_spacing

    # "linspace", "leading", "trailing" follow Table 2 of https://arxiv.org/abs/2305.08891
    if spacing == "linspace":
        schedule = (
            np.linspace(0, train_steps - 1, num_inference_steps)
            .round()[::-1]
            .copy()
            .astype(np.int64)
        )
    elif spacing == "leading":
        ratio = train_steps // self.num_inference_steps
        # integer timesteps via multiplication by the ratio; rounded to dodge
        # issues when num_inference_steps is a power of 3
        schedule = (np.arange(0, num_inference_steps) * ratio).round()[::-1].copy().astype(np.int64)
        schedule += self.config.steps_offset
    elif spacing == "trailing":
        ratio = train_steps / self.num_inference_steps
        schedule = np.round(np.arange(train_steps, 0, -ratio)).astype(np.int64)
        schedule -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
        )

    self.timesteps = torch.from_numpy(schedule).to(device)
341
+
342
def step(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    cond_fn = None,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`float`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        cond_fn (callable, *optional*):
            Guidance hook added in this fork: when provided, `model_output` is shifted by
            ``-(1 - alpha_prod_t) ** 0.5 * cond_fn(sample)`` before the x_0 prediction.
            NOTE(review): assumes `cond_fn(sample)` returns a tensor broadcastable to
            `model_output` — confirm against callers.
        eta (`float`):
            The weight of noise for added noise in diffusion step.
        use_clipped_model_output (`bool`, defaults to `False`):
            If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
            because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
            clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
            `use_clipped_model_output` has no effect.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        variance_noise (`torch.Tensor`):
            Alternative to generating noise with `generator` by directly providing the noise for the variance
            itself. Useful for methods such as [`CycleDiffusion`].
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Notation (<variable name> -> <name in paper>):
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> eta
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"

    # 1. get previous step value (=t-1)
    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

    # 2. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[timestep]
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

    beta_prod_t = 1 - alpha_prod_t

    # Apply the guidance hook before predicting x_0 (score-based conditioning).
    if cond_fn is not None:
        model_output = model_output - (1 - alpha_prod_t) ** (0.5) * cond_fn(sample)

    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 5. compute variance: "sigma_t(eta)" -> see formula (16)
    # sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) * sqrt(1 - alpha_t/alpha_t-1)
    variance = self._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    if eta > 0:
        if variance_noise is not None and generator is not None:
            raise ValueError(
                "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                " `variance_noise` stays `None`."
            )

        if variance_noise is None:
            variance_noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
            )
        variance = std_dev_t * variance_noise

        prev_sample = prev_sample + variance

    if not return_dict:
        return (prev_sample,)

    return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
473
+
474
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Forward-diffuse clean samples: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
    """
    # Keep alphas_cumprod on the samples' device so repeated calls avoid
    # redundant CPU->GPU transfers.
    self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
    alphas = self.alphas_cumprod.to(dtype=original_samples.dtype)
    timesteps = timesteps.to(original_samples.device)

    def _broadcast(vec):
        # Flatten, then append trailing singleton dims until the rank matches
        # the samples so elementwise multiplication broadcasts per-batch.
        vec = vec.flatten()
        while vec.ndim < original_samples.ndim:
            vec = vec.unsqueeze(-1)
        return vec

    sqrt_alpha = _broadcast(alphas[timesteps] ** 0.5)
    sqrt_one_minus_alpha = _broadcast((1 - alphas[timesteps]) ** 0.5)

    return sqrt_alpha * original_samples + sqrt_one_minus_alpha * noise
500
+
501
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
    """
    v-prediction target: v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0.
    """
    # Keep alphas_cumprod on the sample's device/dtype before indexing.
    self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
    alphas = self.alphas_cumprod.to(dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)

    def _broadcast(vec):
        # Append trailing singleton dims until the rank matches the sample.
        vec = vec.flatten()
        while vec.ndim < sample.ndim:
            vec = vec.unsqueeze(-1)
        return vec

    sqrt_alpha = _broadcast(alphas[timesteps] ** 0.5)
    sqrt_one_minus_alpha = _broadcast((1 - alphas[timesteps]) ** 0.5)

    return sqrt_alpha * noise - sqrt_one_minus_alpha * sample
520
+
521
+ def __len__(self):
522
+ return self.config.num_train_timesteps
var/D3HR/validation/utils/__pycache__/data_utils.cpython-310.pyc ADDED
Binary file (10.5 kB). View file
 
var/D3HR/validation/utils/__pycache__/data_utils.cpython-37.pyc ADDED
Binary file (7.34 kB). View file
 
var/D3HR/validation/utils/__pycache__/validate_utils.cpython-310.pyc ADDED
Binary file (3.66 kB). View file
 
var/D3HR/validation/utils/__pycache__/validate_utils.cpython-37.pyc ADDED
Binary file (3.64 kB). View file
 
var/D3HR/validation/utils/data_utils.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ import torchvision
7
+ from torchvision import transforms
8
+ import json
9
+ from torch.utils.data import Dataset
10
+
11
+
12
def find_subclasses(spec, nclass, phase=0):
    """
    Return (class names, name->local-index map) for the `nclass` classes belonging
    to phase `phase` of the given spec ('woof', 'im100', or full ImageNet).
    """
    if spec == 'woof':
        file_list = './misc/class_woof.txt'
    elif spec == 'im100':
        file_list = './misc/class_100.txt'
    else:
        file_list = './misc/class_indices.txt'

    with open(file_list, 'r') as f:
        all_names = [line.split('\n')[0] for line in f.readlines()]

    # Slice out this phase's contiguous window of classes.
    start = nclass * phase
    classes = all_names[start:start + nclass]
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}

    return classes, class_to_idx
31
+
32
+
33
def find_original_classes(spec, classes):
    """
    Map each class name in `classes` to its row index in the full ImageNet
    class list (./misc/class_indices.txt). `spec` is accepted for signature
    compatibility but not used.
    """
    with open('./misc/class_indices.txt', 'r') as f:
        all_names = [line.split('\n')[0] for line in f.readlines()]
    return [all_names.index(class_name) for class_name in classes]
42
+
43
+
44
def load_mapping_imgwoof(mapping_file, names):
    """
    Build a wnid -> label-index mapping for the Imagewoof subset.

    Two file formats are supported:
      * tiny-imagenet ("tiny" in the filename): plain text, one wnid per line;
        the line number becomes the label index.
      * otherwise: JSON of {key: {"wnid": ..., "name": ...}}; only entries whose
        human-readable name is in `names` are kept, indexed by position in `names`.

    BUG FIX: previously `json.load(file)` ran unconditionally, which raised
    JSONDecodeError on tiny-imagenet text files and, even for JSON, exhausted the
    file handle before the tiny branch tried to iterate it. The format check now
    happens before any parsing.
    """
    new_mapping = {}
    with open(mapping_file, 'r') as file:
        if "tiny" in mapping_file:
            # Line order defines the label index.
            for index, line in enumerate(file):
                key = line.split()[0]
                new_mapping[key] = index
        else:
            data = json.load(file)
            new_mapping = {item["wnid"]: names.index(item["name"]) for item in data.values() if item['name'] in names}
    return new_mapping
56
+
57
+
58
def load_mapping(mapping_file):
    """
    Build a wnid -> class-index mapping.

    Two file formats are supported:
      * tiny-imagenet ("tiny" in the filename): plain text, one wnid per line;
        the line number becomes the index.
      * otherwise: JSON of {key: {"wnid": ..., "index": ...}}.

    BUG FIX: previously `json.load(file)` ran unconditionally, which raised
    JSONDecodeError on tiny-imagenet text files and, even for JSON, exhausted the
    file handle before the tiny branch tried to iterate it. The format check now
    happens before any parsing.
    """
    new_mapping = {}
    with open(mapping_file, 'r') as file:
        if "tiny" in mapping_file:
            # Line order defines the class index.
            for index, line in enumerate(file):
                key = line.split()[0]
                new_mapping[key] = index
        else:
            data = json.load(file)
            new_mapping = {item["wnid"]: item["index"] for item in data.values()}
    return new_mapping
70
+
71
+
72
+
73
def load_mapping_txt(mapping_file):
    """Parse a tab-separated "wnid<TAB>index" file into a {wnid: int} dict."""
    mapping = {}
    with open(mapping_file, 'r') as fh:
        for raw_line in fh:
            wnid, idx = raw_line.strip().split('\t')
            mapping[wnid] = int(idx)
    return mapping
80
+
81
def find_classes(class_file):
    """Read class names (one per line), sort them, and return {name: index}."""
    with open(class_file) as fh:
        names = sorted(name.strip() for name in fh.readlines())
    return {name: position for position, name in enumerate(names)}
89
+
90
class ImageFolder(Dataset):
    """Path-list image dataset: reads image file paths from a txt file and derives
    integer labels from each path's parent directory name.

    Label derivation by subset:
      * cifar10 / cifar100: last 3 characters of the class directory name, parsed as int.
      * imagenet_1k: wnid -> index via a JSON mapping file (load_mapping).
      * tinyimagenet: wnid -> index via a sorted class-list file (find_classes).
    """

    def __init__(self, split=None, txt_file=None, subset=None, mapping_file=None, transform=None):
        super(ImageFolder, self).__init__()
        self.split = split
        self.image_paths = []
        self.targets = []
        self.samples = []
        self.subset = subset
        if self.subset == 'imagenet_1k':
            self.wnid_to_index = load_mapping(mapping_file)
        elif self.subset == 'tinyimagenet':
            self.wnid_to_index = find_classes(mapping_file)
        # NOTE(review): both branches are identical — the split distinction is vestigial.
        if split == 'train':
            self._load_from_txt(txt_file)
        else:
            self._load_from_txt(txt_file)
        self.transform = transform

    def _load_from_txt(self, txt_file):
        """Populate image_paths / samples / targets from a newline-separated path list."""
        with open(txt_file, "r") as file:
            image_paths = file.readlines()

        # Strip the trailing newline from every line.
        self.image_paths = [path.strip() for path in image_paths]
        for path in self.image_paths:
            self.samples.append(path)
            if self.subset == 'cifar10' or self.subset == 'cifar100':
                # Class dirs end with a 3-digit label, e.g. ".../class012/img.png".
                class_index = int(path.split('/')[-2][-3:])
            else:
                # The wnid is the parent directory name.
                class_index = self.wnid_to_index[path.split('/')[-2]]
            self.targets.append(class_index)

    # Combine ten txt lists: for each sample index, pick the corresponding path
    # from one of ten candidate runs at random.
    def _load_from_txt_1(self, txt_file):
        # NOTE(review): experimental/dead code — `txt_file` is ignored in favor of a
        # hard-coded /scratch path, and the cifar branch below references an
        # undefined `path` variable (NameError if reached). Confirm before reuse.
        image_paths_10 = []
        for kk in range(10):
            txt_file=f'/scratch/zhao.lin1/tinyimagenet_finetune_start_step_18_ddim_inversion_10_min_images_{kk}/train.txt'
            with open(txt_file, "r") as file:
                image_paths = file.readlines()

            image_paths_10.append([path.strip() for path in image_paths])

        for kk in range(len(image_paths)):
            number = random.randint(0, 9)
            self.image_paths.append(image_paths_10[number][kk])
            if self.subset == 'cifar10' or self.subset == 'cifar100':
                class_index = int(path.split('/')[-2][-3:])
            else:
                class_index = self.wnid_to_index[image_paths_10[number][kk].split('/')[-2]]
            self.targets.append(class_index)

    def __getitem__(self, index):
        """Return (transformed RGB image, int label); substitutes a black 256x256
        image when the file cannot be opened."""
        img_path = self.image_paths[index]
        try:
            sample = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a black image in case of error
            sample = Image.new('RGB', (256, 256))
        sample = self.transform(sample)
        return sample, self.targets[index]

    def __len__(self):
        return len(self.targets)
165
+
166
+
167
class Imagewoof(Dataset):
    """Imagewoof (10 dog-breed ImageNet subset) dataset read from a txt path list.

    Only paths whose parent-directory wnid maps to one of the ten breed names
    are kept; their label is the breed's position in `self.names`.
    """

    def __init__(self, split=None, txt_file=None, subset=None, mapping_file=None, transform=None):
        super(Imagewoof, self).__init__()
        self.split = split
        self.image_paths = []
        self.targets = []
        self.samples = []
        self.subset = subset
        # Canonical breed order — defines each class's label index.
        self.names = ["Australian_terrier", "Border_terrier", "Samoyed", "beagle", "Shih-Tzu", "English_foxhound", "Rhodesian_ridgeback", "dingo", "golden_retriever", "Old_English_sheepdog"]
        self.wnid_to_index = load_mapping_imgwoof(mapping_file, self.names)
        self._load_from_txt(txt_file)
        self.transform = transform

    def _load_from_txt(self, txt_file):
        """Load paths, keeping only those whose wnid is in the Imagewoof mapping."""
        with open(txt_file, "r") as file:
            image_paths = file.readlines()

        # Strip the trailing newline from every line.
        image_paths = [path.strip() for path in image_paths]
        for path in image_paths:
            self.samples.append(path)
            if self.subset == 'cifar10' or self.subset == 'cifar100':
                # NOTE(review): cifar branch computes a label but appends nothing —
                # presumably Imagewoof is never used with cifar subsets; confirm.
                class_index = int(path.split('/')[-2][-3:])
            else:
                if path.split('/')[-2] in list(self.wnid_to_index.keys()):
                    class_index = self.wnid_to_index[path.split('/')[-2]]
                    self.image_paths.append(path)
                    self.targets.append(class_index)

    def __getitem__(self, index):
        """Return (transformed RGB image, int label); substitutes a black 256x256
        image when the file cannot be opened."""
        img_path = self.image_paths[index]
        try:
            sample = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a black image in case of error
            sample = Image.new('RGB', (256, 256))
        sample = self.transform(sample)
        return sample, self.targets[index]

    def __len__(self):
        return len(self.targets)
214
+
215
+
216
+
217
+ # class ImageFolder(torchvision.datasets.ImageFolder):
218
+ # def __init__(self, nclass, ipc, mem=False, spec='none', phase=0, **kwargs):
219
+ # super(ImageFolder, self).__init__(**kwargs)
220
+ # self.mem = mem
221
+ # self.spec = spec
222
+ # self.classes, self.class_to_idx = find_subclasses(
223
+ # spec=spec, nclass=nclass, phase=phase
224
+ # )
225
+ # self.original_classes = find_original_classes(spec=self.spec, classes=self.classes)
226
+ # self.samples, self.targets = self.load_subset(ipc=ipc)
227
+ # if self.mem:
228
+ # self.samples = [self.loader(path) for path in self.samples]
229
+
230
+ # def load_subset(self, ipc=-1):
231
+ # all_samples = torchvision.datasets.folder.make_dataset(
232
+ # self.root, self.class_to_idx, self.extensions
233
+ # )
234
+ # samples = np.array([item[0] for item in all_samples])
235
+ # targets = np.array([item[1] for item in all_samples])
236
+
237
+ # if ipc == -1:
238
+ # return samples, targets
239
+ # else:
240
+ # sub_samples = []
241
+ # sub_targets = []
242
+ # for c in range(len(self.classes)):
243
+ # c_indices = np.where(targets == c)[0]
244
+ # #random.shuffle(c_indices)
245
+ # sub_samples.extend(samples[c_indices[:ipc]])
246
+ # sub_targets.extend(targets[c_indices[:ipc]])
247
+ # return sub_samples, sub_targets
248
+
249
+ # def __getitem__(self, index):
250
+ # if self.mem:
251
+ # sample = self.samples[index]
252
+ # else:
253
+ # sample = self.loader(self.samples[index])
254
+ # sample = self.transform(sample)
255
+ # return sample, self.targets[index]
256
+
257
+ # def __len__(self):
258
+ # return len(self.targets)
259
def random_stitch_crop_4(image):
    """Randomly crop one quadrant of a 2x2-stitched image and return it."""
    width, height = image.size
    half_w, half_h = width // 2, height // 2

    # The four quadrant boxes (left, upper, right, lower).
    quadrants = [
        (0, 0, half_w, half_h),           # top-left
        (half_w, 0, width, half_h),       # top-right
        (0, half_h, half_w, height),      # bottom-left
        (half_w, half_h, width, height),  # bottom-right
    ]

    return image.crop(random.choice(quadrants))
275
+
276
def transform_imagenet(args):
    """
    Build (train_transform, test_transform) for ImageNet-style data.

    Train: RandomResizedCrop(args.input_size, scale=(0.08, args.max_scale_crops))
    + horizontal flip, then ToTensor + ImageNet normalization.
    Test: Resize(input_size // 7 * 8) + CenterCrop(input_size), then ToTensor +
    the same normalization.
    """
    to_tensor = [transforms.ToTensor()]
    imagenet_norm = [transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )]

    # Standard eval protocol: resize to 8/7 of the crop, then center-crop.
    eval_resize = [transforms.Resize(args.input_size // 7 * 8), transforms.CenterCrop(args.input_size)]

    train_aug = [
        transforms.RandomResizedCrop(
            size=args.input_size,
            scale=(0.08, args.max_scale_crops),
            antialias=True,
        ),
        transforms.RandomHorizontalFlip()
    ]

    train_transform = transforms.Compose(train_aug + to_tensor + imagenet_norm)
    test_transform = transforms.Compose(eval_resize + to_tensor + imagenet_norm)

    return train_transform, test_transform
304
+
305
+
306
# Use the file-system sharing strategy for tensors exchanged with DataLoader
# worker processes (avoids "too many open files" with the default
# file-descriptor strategy).
sharing_strategy = "file_system"
torch.multiprocessing.set_sharing_strategy(sharing_strategy)
308
+
309
+
310
def set_worker_sharing_strategy(worker_id: int) -> None:
    """DataLoader worker_init_fn: re-apply the module-level sharing strategy
    inside each freshly spawned worker process."""
    torch.multiprocessing.set_sharing_strategy(sharing_strategy)
312
+
313
+
314
def load_data(args, coreset=False, resize_only=False, mem_flag=True, trainset_only=False):
    """
    Build the train/val datasets and loaders for the configured subset.

    Args:
        args: namespace providing subset, txt_file, val_txt_file, mapping_file,
            batch_size, and the transform parameters used by transform_imagenet.
        coreset: if True, train data uses the deterministic test transform.
        resize_only: if True, train transform is just Resize(512) (no aug/normalize).
        mem_flag: accepted for signature compatibility; not used here.
        trainset_only: if True, return only the train dataset.

    Returns:
        train_dataset if trainset_only else (train_dataset, train_loader, val_loader).

    NOTE(review): CIFAR val roots are hard-coded to /scratch/zhao.lin1/ and
    num_workers is fixed at 24 — confirm for the target machine.
    """
    train_transform, test_transform = transform_imagenet(args)

    if resize_only:
        # Raw export path: only resize, keep PIL-space values.
        train_transform = transforms.Compose([
            transforms.Resize((512, 512)),
        ])
    elif coreset:
        # Coreset evaluation trains on deterministically preprocessed images.
        train_transform = test_transform

    if args.subset == 'imagewoof':
        # Imagewoof uses its own dataset class (name-filtered wnid mapping).
        train_dataset = Imagewoof(
            split = 'train',
            txt_file=args.txt_file,
            mapping_file=args.mapping_file,
            subset = args.subset,
            transform=train_transform,
        )
    else:
        train_dataset = ImageFolder(
            split = 'train',
            txt_file=args.txt_file,
            mapping_file=args.mapping_file,
            subset = args.subset,
            transform=train_transform,
        )

    if trainset_only:
        return train_dataset

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=24,
        pin_memory=True,
        worker_init_fn=set_worker_sharing_strategy,
    )
    if args.subset == 'cifar10':
        val_dataset = torchvision.datasets.CIFAR10(root='/scratch/zhao.lin1/', train=False, download=True, transform=test_transform)
    elif args.subset == 'cifar100':
        val_dataset = torchvision.datasets.CIFAR100(root='/scratch/zhao.lin1/', train=False, download=True, transform=test_transform)
    elif args.subset == 'imagewoof':
        val_dataset = Imagewoof(
            split = 'test',
            txt_file=args.val_txt_file,
            mapping_file=args.mapping_file,
            subset = args.subset,
            transform=test_transform,
        )
    else:
        val_dataset = ImageFolder(
            split = 'test',
            txt_file=args.val_txt_file,
            mapping_file=args.mapping_file,
            subset = args.subset,
            transform=test_transform,
        )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=256,
        shuffle=False,
        num_workers=24,
        pin_memory=True,
        worker_init_fn=set_worker_sharing_strategy,
    )
    print("load data successfully")

    return train_dataset, train_loader, val_loader
405
+
406
+
407
class ShufflePatches(torch.nn.Module):
    """Randomly permute a `factor` x `factor` grid of image patches.

    `forward` shuffles vertical strips, transposes, shuffles again (i.e. the
    horizontal strips), and transposes back, producing a grid permutation.
    """

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def shuffle_weight(self, img, factor):
        """Split img (C, H, W) into `factor` strips along the last dim, shuffle, re-concat.

        BUG FIX: the original reassigned the loop variable (`i = i * tw`) before
        testing `i != factor - 1`, so the "last strip keeps the remainder" branch
        was almost never taken and trailing columns were silently dropped whenever
        W % factor != 0. The loop counter and the slice offset are now distinct.
        """
        h, w = img.shape[1:]
        tw = w // factor
        patches = []
        for k in range(factor):
            start = k * tw
            if k != factor - 1:
                patches.append(img[..., start:start + tw])
            else:
                # Last strip absorbs any remainder columns so no pixels are lost.
                patches.append(img[..., start:])
        random.shuffle(patches)
        return torch.cat(patches, -1)

    def forward(self, img):
        # Shuffle columns, transpose, shuffle rows, transpose back.
        img = self.shuffle_weight(img, self.factor)
        img = img.permute(0, 2, 1)
        img = self.shuffle_weight(img, self.factor)
        img = img.permute(0, 2, 1)
        return img
var/D3HR/validation/utils/download.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Functions for downloading pre-trained DiT models
9
+ """
10
+ from torchvision.datasets.utils import download_url
11
+ import torch
12
+ import os
13
+
14
+
15
+ pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'}
16
+
17
+
18
def find_model(model_name):
    """
    Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.

    Returns the checkpoint state dict; for train.py checkpoints, the EMA weights.
    """
    if model_name in pretrained_models:  # Find/download our pre-trained DiT checkpoints
        return download_model(model_name)
    else:  # Load a custom DiT checkpoint:
        assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}'
        # SECURITY NOTE(review): torch.load unpickles arbitrary data — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
        if "ema" in checkpoint:  # supports checkpoints from train.py
            checkpoint = checkpoint["ema"]
        return checkpoint
30
+
31
+
32
def download_model(model_name):
    """
    Downloads a pre-trained DiT model from the web.

    Skips the download when pretrained_models/<model_name> already exists, then
    loads the checkpoint onto CPU and returns it.
    """
    assert model_name in pretrained_models
    local_path = f'pretrained_models/{model_name}'
    if not os.path.isfile(local_path):
        os.makedirs('pretrained_models', exist_ok=True)
        web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}'
        download_url(web_path, 'pretrained_models')
    # SECURITY NOTE(review): torch.load unpickles arbitrary data — trusted source only.
    model = torch.load(local_path, map_location=lambda storage, loc: storage)
    return model
44
+
45
+
46
if __name__ == "__main__":
    # Download all DiT checkpoints listed in `pretrained_models`.
    for model in pretrained_models:
        download_model(model)
    print('Done.')
var/D3HR/validation/utils/syn_utils_dit.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ import numpy as np
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import transforms as T
8
+ from torchvision.models import resnet18
9
+ from transformers import CLIPModel, AutoTokenizer
10
+
11
+ from .download import find_model
12
+ from diffusion import create_diffusion
13
+ from models.dit_models import DiT_models
14
+ from diffusers.models import AutoencoderKL
15
+
16
+
17
+ class SupConLoss(nn.Module):
18
+ def __init__(self, temperature=0.05, base_temperatue=0.05):
19
+ super(SupConLoss, self).__init__()
20
+ self.temperature = temperature
21
+ self.base_temperature = base_temperatue
22
+
23
+ def forward(self, image_features, text_features, text_labels):
24
+ logits = (image_features @ text_features.T) / self.temperature
25
+ logits_max, _ = torch.max(logits, dim=1, keepdim=True)
26
+ logits = logits - logits_max.detach()
27
+
28
+ exp_logits = torch.exp(logits) * text_labels
29
+ log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
30
+ mean_log_prob_pos = ((1 - text_labels) * log_prob).sum(1) / (1 - text_labels).sum(1)
31
+ loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
32
+ loss = loss.mean()
33
+
34
+ return loss
35
+
36
+
37
class ImageSynthesizer(object):
    """Synthesizes class-conditional images with a DiT diffusion model,
    steering DDIM sampling toward CLIP text descriptions of the target
    class through the gradient-based guidance function ``cond_fn``.
    """

    def __init__(self, args):
        # Frozen VAE (latent <-> pixel) and CLIP (guidance) backbones.
        self.vae = AutoencoderKL.from_pretrained(args.vae_path).to('cuda')
        self.clip_model = CLIPModel.from_pretrained('laion/CLIP-ViT-L-14-laion2B-s32B-b82K').to('cuda')
        self.clip_tokenizer = AutoTokenizer.from_pretrained('laion/CLIP-ViT-L-14-laion2B-s32B-b82K')

        # DiT model; latents are 8x downsampled relative to pixel space.
        assert args.dit_image_size % 8 == 0, 'Image size must be divisible by 8'
        latent_size = args.dit_image_size // 8
        self.latent_size = latent_size
        self.dit = DiT_models[args.dit_model](
            input_size=latent_size,
            num_classes=args.num_dit_classes
        ).to('cuda')
        ckpt_path = args.ckpt
        state_dict = find_model(ckpt_path)
        # NOTE(review): strict=False silently ignores missing/unexpected
        # keys in the checkpoint — confirm this is intended.
        self.dit.load_state_dict(state_dict, strict=False)

        # Diffusion process with the requested number of sampling steps.
        self.diffusion = create_diffusion(str(args.diffusion_steps))

        # Per-class text descriptions (JSON file).
        self.description_file = args.description_path
        self.load_class_description()

        self.cfg_scale = args.cfg_scale
        self.clip_alpha = args.clip_alpha
        # cls_alpha is stored but never read within this class.
        self.cls_alpha = args.cls_alpha
        # Number of positive description prompts per class.
        self.num_pos_samples = 5
        self.num_neg_samples = args.num_neg_samples
        # Standard CLIP preprocessing statistics.
        self.clip_normalize = T.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
        )
        self.contrast_criterion = SupConLoss()
        # 'random' | 'similar' | anything else -> similarity-weighted choice.
        self.neg_policy = args.neg_policy

    def load_class_description(self):
        """Load {class_index: (class_name, [descriptions])} from JSON and
        precompute CLIP name-similarity scores between classes."""
        with open(self.description_file, 'r') as fp:
            descriptions = json.load(fp)
        self.class_names = {}
        self.descriptions = {}

        for class_index, (class_name, description) in descriptions.items():
            self.class_names[class_index] = class_name
            self.descriptions[class_index] = description

        self.class_indices = list(self.class_names.keys())
        self.class_name_list = list(self.class_names.values())
        self.neighbors = {}
        with torch.no_grad():
            class_name_feat = self.extract_text_feature(self.class_name_list)
            name_sims = (class_name_feat @ class_name_feat.T).cpu()
            # Remove self-similarity so a class never counts as its own neighbor.
            name_sims -= torch.eye(len(name_sims))
            name_sims = name_sims.numpy()
        for class_index, sim_indices in zip(self.class_indices, name_sims):
            self.neighbors[class_index] = list(sim_indices)

    def extract_text_feature(self, descriptions):
        """Return L2-normalized CLIP text embeddings for a list of strings."""
        input_text = self.clip_tokenizer(descriptions, padding=True, return_tensors='pt').to('cuda')
        text_feature = self.clip_model.get_text_features(**input_text)
        text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
        return text_feature

    def cond_fn(self, x, t, y=None, text_features=None, contrastive=True, **kwargs):
        """Guidance gradient for DDIM sampling.

        Decodes the conditional half of the CFG latent batch, embeds it with
        CLIP, and returns the negative gradient of the CLIP loss w.r.t. the
        full latent ``x``. Only the conditional half feeds the loss, so the
        unconditional half receives zero gradient.
        """
        with torch.enable_grad():
            x = nn.Parameter(x).requires_grad_()
            # First half of the CFG batch holds the class-conditional latents.
            real_x, _ = x.chunk(2, dim=0)
            # 0.18215 is the SD VAE latent scaling factor; map back to [0, 1].
            pseudo_image = self.vae.decode(real_x / 0.18215, return_dict=False)[0]
            pseudo_image = T.Resize((224, 224))(pseudo_image) * 0.5 + 0.5
            pseudo_image = self.clip_normalize(pseudo_image)

            # Extract image embedding
            clip_feat_image = self.clip_model.get_image_features(pseudo_image)
            clip_feat_image = clip_feat_image / clip_feat_image.norm(dim=-1, keepdim=True)

            # Extract text embedding
            clip_feat_text_pos, clip_feat_text_neg = torch.split(
                text_features, [self.num_pos_samples, self.num_neg_samples]
            )

            if contrastive:
                # Labels: 0 marks positive descriptions, 1 marks negatives.
                clip_loss = self.contrast_criterion(
                    clip_feat_image, torch.cat((clip_feat_text_pos, clip_feat_text_neg), dim=0),
                    torch.cat((torch.zeros(self.num_pos_samples), torch.ones(self.num_neg_samples))).unsqueeze(0).cuda()
                )
            else:
                # Simple cosine-distance pull toward the positive texts.
                clip_loss = 1. - (clip_feat_image @ clip_feat_text_pos.T).mean()

            loss = self.clip_alpha * clip_loss

            return -torch.autograd.grad(loss, x, allow_unused=True)[0]

    def sample(self, original_label, class_index, batch_size=1, device=None):
        """Generate ``batch_size`` guided samples for ``class_index``.

        Returns decoded images resized to 224x224.
        """
        z = torch.randn(batch_size, 4, self.latent_size, self.latent_size, device=device)
        y = torch.tensor([original_label] * batch_size, device=device)

        # classifier-free guidance: duplicate latents, pair with the null
        # class id (1000) for the unconditional half.
        z = torch.cat([z, z], 0)
        y_null = torch.tensor([1000] * batch_size, device=device)
        y = torch.cat([y, y_null], 0)

        # Positive prompts: "<class name> with <description>".
        pos_descriptions = self.descriptions[class_index]
        pos_descriptions = [self.class_names[class_index]+' with '+description for description in pos_descriptions]
        neg_descriptions = []
        if self.neg_policy == 'random':
            neg_classes = random.choices(self.class_indices, k=self.num_neg_samples)
        elif self.neg_policy == 'similar':
            # Most CLIP-similar class names serve as hard negatives.
            max_indices = np.argsort(self.neighbors[class_index])[-self.num_neg_samples:]
            neg_classes = [self.class_indices[max_index] for max_index in max_indices]
        else:
            # Similarity-weighted sampling. NOTE(review): cosine similarities
            # can be negative, which random.choices does not support — verify
            # the weights are non-negative in practice.
            neg_classes = random.choices(self.class_indices, self.neighbors[class_index], k=self.num_neg_samples)
        for rand_index in neg_classes:
            # NOTE(review): randint(0, 4) draws indices 0-3 only, while
            # num_pos_samples suggests 5 descriptions per class — the last
            # description is never sampled here. Confirm whether intended.
            neg_descriptions.append(self.class_names[rand_index] + ' with ' + self.descriptions[rand_index][np.random.randint(0, 4)])
        all_descriptions = pos_descriptions + neg_descriptions
        text_features = self.extract_text_feature(all_descriptions)

        model_kwargs = dict(
            y=y, cfg_scale=self.cfg_scale,
            text_features=text_features, contrastive=True
        )

        def get_samples(z):
            # DDIM sampling with CFG forward pass plus CLIP guidance.
            samples = self.diffusion.ddim_sample_loop(
                self.dit.forward_with_cfg, z.shape, z, clip_denoised=False,
                model_kwargs=model_kwargs, progress=False, device=device,
                cond_fn=self.cond_fn
            )
            # Keep only the conditional half, decode, and resize.
            samples, _ = samples.chunk(2, dim=0)
            samples = self.vae.decode(samples / 0.18215).sample
            samples = T.Resize((224, 224))(samples)

            return samples

        samples = get_samples(z)

        return samples
var/D3HR/validation/utils/syn_utils_img2img.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ from torchvision import transforms as T
6
+ from transformers import CLIPModel, AutoTokenizer
7
+
8
+ from misc import prompts
9
+ from models.scheduling_ddim import DDIMScheduler
10
+
11
+
12
+ class SupConLoss(torch.nn.Module):
13
+ def __init__(self, temperature=0.05, base_temperatue=0.05):
14
+ super(SupConLoss, self).__init__()
15
+ self.temperature = temperature
16
+ self.base_temperature = base_temperatue
17
+
18
+ def forward(self, image_features, text_features, text_labels):
19
+ logits = (image_features @ text_features.T) / self.temperature
20
+ logits_max, _ = torch.max(logits, dim=1, keepdim=True)
21
+ logits = logits - logits_max.detach()
22
+
23
+ exp_logits = torch.exp(logits) * text_labels
24
+ log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
25
+ mean_log_prob_pos = ((1 - text_labels) * log_prob).sum(1) / (1 - text_labels).sum(1)
26
+ loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
27
+ loss = loss.mean()
28
+
29
+ return loss
30
+
31
+
32
class ImageSynthesizer(object):
    """Synthesizes class-consistent images with a Stable-unCLIP img2img
    pipeline, steering sampling toward CLIP text descriptions of the target
    class through the gradient-based guidance function ``cond_fn``.
    """

    def __init__(self, args):
        """Set up CLIP, the class-description table, and guidance settings.

        ``args`` must provide: description_path, diffusion_steps,
        clip_alpha, num_neg_samples, neg_policy.
        """
        self.init_clip()
        self.description_file = args.description_path
        self.load_class_description()
        self.contrast_criterion = SupConLoss()

        self.prompts = prompts.prompt_templates
        self.diffusion_steps = args.diffusion_steps
        self.clip_alpha = args.clip_alpha
        self.num_neg_samples = args.num_neg_samples
        # 'random' | 'similar' | anything else -> similarity-weighted choice.
        self.neg_policy = args.neg_policy
        # Positive-description count of the most recent sample_img2img call;
        # consumed by cond_fn to build the positive/negative label mask.
        # (Replaces the previously hard-coded 5.)
        self.current_num_pos = 5

    def load_class_description(self):
        """Load {class_index: (class_name, [descriptions])} from JSON and
        precompute CLIP name-similarity scores between classes."""
        with open(self.description_file, 'r') as fp:
            descriptions = json.load(fp)
        self.class_names = {}
        self.descriptions = {}

        for class_index, (class_name, description) in descriptions.items():
            self.class_names[class_index] = class_name
            self.descriptions[class_index] = description

        self.class_indices = list(self.class_names.keys())
        self.class_name_list = list(self.class_names.values())
        self.neighbors = {}
        with torch.no_grad():
            class_name_feat = self.extract_clip_text_embed(self.class_name_list)
            name_sims = (class_name_feat @ class_name_feat.T).cpu()
            # Remove self-similarity so a class never counts as its own neighbor.
            name_sims -= torch.eye(len(name_sims))
            name_sims = name_sims.numpy()
        for class_index, sim_indices in zip(self.class_indices, name_sims):
            self.neighbors[class_index] = list(sim_indices)

    def init_clip(self):
        """Load the CLIP model/tokenizer and image preprocessing."""
        self.clip_model = CLIPModel.from_pretrained('laion/CLIP-ViT-L-14-laion2B-s32B-b82K').to('cuda')
        self.clip_tokenizer = AutoTokenizer.from_pretrained('laion/CLIP-ViT-L-14-laion2B-s32B-b82K')
        # Standard CLIP preprocessing statistics.
        self.clip_normalize = T.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
        )
        # Fix: extract_clip_image_embed referenced an undefined
        # self.clip_transform; define the full PIL -> normalized-tensor
        # preprocessing here.
        self.clip_transform = T.Compose([
            T.Resize((224, 224)), T.ToTensor(), self.clip_normalize
        ])

    def extract_clip_image_embed(self, image):
        """Return the CLIP embedding of a raw (PIL) image."""
        image = self.clip_transform(image).unsqueeze(0).to('cuda')
        # Fix: transformers' CLIPModel exposes get_image_features;
        # encode_image (open_clip API) does not exist on it.
        clip_feat = self.clip_model.get_image_features(image)
        return clip_feat

    def extract_clip_text_embed(self, descriptions):
        """Return L2-normalized CLIP text embeddings for a list of strings."""
        input_text = self.clip_tokenizer(descriptions, padding=True, return_tensors='pt').to('cuda')
        text_features = self.clip_model.get_text_features(**input_text)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features

    def cond_fn(self, sample, **kwargs):
        """Guidance gradient: returns the negative gradient of the CLIP
        contrastive loss w.r.t. the latent ``sample``, pulling it toward
        the positive descriptions and away from the negatives."""
        with torch.enable_grad():
            sample = torch.nn.Parameter(sample).requires_grad_()
            # 0.18215 is the SD VAE latent scaling factor; map to [0, 1].
            pseudo_image = self.pipe.vae.decode(sample / 0.18215, return_dict=False)[0]
            pseudo_image = T.Resize((224, 224))(pseudo_image) * 0.5 + 0.5
            pseudo_image = self.clip_normalize(pseudo_image)

            # Extract image embedding
            clip_feat_image = self.clip_model.get_image_features(pseudo_image)
            clip_feat_image = clip_feat_image / clip_feat_image.norm(dim=-1, keepdim=True)

            # Labels: 0 marks positive descriptions, 1 marks negatives.
            # (num_pos derived from the last sample_img2img call instead of
            # the previously hard-coded 5.)
            num_pos = self.current_num_pos
            num_neg = len(self.current_desc_embeddings) - num_pos
            clip_loss = self.contrast_criterion(
                clip_feat_image, self.current_desc_embeddings,
                torch.cat((torch.zeros(num_pos), torch.ones(num_neg))).unsqueeze(0).cuda()
            )

            loss = self.clip_alpha * clip_loss

            return -torch.autograd.grad(loss, sample, allow_unused=True)[0]

    def init_img2img(self):
        """Instantiate the Stable-unCLIP img2img pipeline with a DDIM scheduler."""
        from models.pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
        self.pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
            'radames/stable-diffusion-2-1-unclip-img2img'
        )
        self.pipe.scheduler = DDIMScheduler.from_pretrained('radames/stable-diffusion-2-1-unclip-img2img', subfolder='scheduler')
        self.pipe = self.pipe.to('cuda')

    def sample_img2img(self, image, class_index):
        """Run CLIP-guided img2img on ``image`` toward ``class_index``;
        returns a 224x224 PIL image."""
        class_name = self.class_names[class_index]
        class_name = class_name.split(',')[0]
        pos_descriptions = self.descriptions[class_index]
        prompt = random.choice(self.prompts).format(class_name, '')

        # Positive prompts: "<class name> with <description>".
        pos_descriptions = [self.class_names[class_index]+' with '+description for description in pos_descriptions]
        neg_descriptions = []
        if self.neg_policy == 'random':
            neg_classes = random.choices(self.class_indices, k=self.num_neg_samples)
        elif self.neg_policy == 'similar':
            # Most CLIP-similar class names serve as hard negatives.
            max_indices = np.argsort(self.neighbors[class_index])[-self.num_neg_samples:]
            neg_classes = [self.class_indices[max_index] for max_index in max_indices]
        else:
            # Similarity-weighted sampling. NOTE(review): cosine similarities
            # can be negative, which random.choices does not support — verify
            # the weights are non-negative in practice.
            neg_classes = random.choices(self.class_indices, self.neighbors[class_index], k=self.num_neg_samples)
        for rand_index in neg_classes:
            # Fix: ' with ' was missing its leading space (inconsistent with
            # the positive prompts above and the DiT-based synthesizer), and
            # randint(0, 4) could never select the last description — index
            # over the class's actual description count instead.
            neg_descriptions.append(
                self.class_names[rand_index] + ' with '
                + self.descriptions[rand_index][np.random.randint(len(self.descriptions[rand_index]))]
            )
        self.current_num_pos = len(pos_descriptions)
        self.current_desc_embeddings = self.extract_clip_text_embed(pos_descriptions + neg_descriptions)
        new_image = self.pipe(image=image, prompt=prompt, cond_fn=self.cond_fn, num_inference_steps=self.diffusion_steps).images[0]

        new_image = new_image.resize((224, 224))

        return new_image