Dervlex commited on
Commit
926efac
·
verified ·
1 Parent(s): b98ab32

Upload 17 files

Browse files
ComfyUI-TiledDiffusion/.gitignore ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ backup*
163
+ **/.DS_Store
164
+ **/.venv
165
+ **/.vscode
166
+
167
+ .*
ComfyUI-TiledDiffusion/.patches.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def calc_cond_batch(model, conds, x_in, timestep, model_options):
    """Tiled-diffusion override of ComfyUI's `calc_cond_batch`.

    When tiled diffusion is not active, delegates to the saved original
    implementation. Otherwise it batches compatible conds together (sized to
    free VRAM), runs the model once per batch, and accumulates the weighted
    outputs per cond slot, normalizing by the accumulated weights at the end.

    Returns a list of denoised tensors, one per entry in `conds`.
    """
    # Fall through to the stashed original when this sampler run is not tiled.
    if 'tiled_diffusion' not in model_options:
        return calc_cond_batch_original_tiled_diffusion_875b8c8d(model, conds, x_in, timestep, model_options)
    out_conds = []
    out_counts = []
    to_run = []

    for i in range(len(conds)):
        out_conds.append(torch.zeros_like(x_in))
        # Tiny epsilon instead of zeros so the final division never hits 0/0.
        out_counts.append(torch.ones_like(x_in) * 1e-37)

        cond = conds[i]
        if cond is not None:
            for x in cond:
                # p bundles (input_x, mult, conditioning, area, control, patches);
                # None means this cond does not apply at the current timestep.
                p = get_area_and_mult(x, x_in, timestep)
                if p is None:
                    continue

                to_run += [(p, i)]

    while len(to_run) > 0:
        first = to_run[0]
        first_shape = first[0][0].shape
        to_batch_temp = []
        # Collect every pending cond that can be concatenated with the first one.
        for x in range(len(to_run)):
            if can_concat_cond(to_run[x][0], first[0]):
                to_batch_temp += [x]

        # Reversed so later pops remove from the tail first (indices stay valid).
        to_batch_temp.reverse()
        to_batch = to_batch_temp[:1]

        # Grow the batch as large as free VRAM allows (with a 1.5x safety margin).
        free_memory = model_management.get_free_memory(x_in.device)
        for i in range(1, len(to_batch_temp) + 1):
            batch_amount = to_batch_temp[:len(to_batch_temp)//i]
            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
            if model.memory_required(input_shape) * 1.5 < free_memory:
                to_batch = batch_amount
                break

        input_x = []
        mult = []
        c = []
        cond_or_uncond = []
        area = []
        control = None
        patches = None
        for x in to_batch:
            o = to_run.pop(x)
            p = o[0]
            input_x.append(p.input_x)
            mult.append(p.mult)
            c.append(p.conditioning)
            area.append(p.area)
            cond_or_uncond.append(o[1])
            control = p.control
            patches = p.patches

        batch_chunks = len(cond_or_uncond)
        input_x = torch.cat(input_x)
        c = cond_cat(c)
        timestep_ = torch.cat([timestep] * batch_chunks)

        if control is not None:
            # Key difference from stock comfy: under tiled diffusion the raw
            # ControlNet object is passed through so the tiling wrapper can
            # crop/apply hints per tile instead of computing control here.
            c['control'] = control if 'tiled_diffusion' in model_options else control.get_control(input_x, timestep_, c, len(cond_or_uncond))

        transformer_options = {}
        if 'transformer_options' in model_options:
            transformer_options = model_options['transformer_options'].copy()

        if patches is not None:
            # Merge per-cond patches on top of any global transformer patches.
            if "patches" in transformer_options:
                cur_patches = transformer_options["patches"].copy()
                for p in patches:
                    if p in cur_patches:
                        cur_patches[p] = cur_patches[p] + patches[p]
                    else:
                        cur_patches[p] = patches[p]
                transformer_options["patches"] = cur_patches
            else:
                transformer_options["patches"] = patches

        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
        transformer_options["sigmas"] = timestep

        c['transformer_options'] = transformer_options

        if 'model_function_wrapper' in model_options:
            output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
        else:
            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)

        for o in range(batch_chunks):
            cond_index = cond_or_uncond[o]
            a = area[o]
            if a is None:
                # Full-canvas cond: accumulate everywhere.
                out_conds[cond_index] += output[o] * mult[o]
                out_counts[cond_index] += mult[o]
            else:
                # Area-restricted cond: narrow into the sub-region and
                # accumulate in place (narrow returns a view).
                out_c = out_conds[cond_index]
                out_cts = out_counts[cond_index]
                dims = len(a) // 2
                for i in range(dims):
                    out_c = out_c.narrow(i + 2, a[i + dims], a[i])
                    out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
                out_c += output[o] * mult[o]
                out_cts += mult[o]

    # Normalize each accumulator by its coverage weights.
    for i in range(len(out_conds)):
        out_conds[i] /= out_counts[i]

    return out_conds
113
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
    """Blend `x0` with a Gaussian-blurred copy of itself, gated by attention.

    Tokens whose pooled attention score exceeds `threshold` select the blurred
    version; everything else keeps the original pixels (self-attention-guidance
    style masking — presumably; confirm against caller).
    """
    _, hw1, hw2 = attn.shape
    batch, _, lat_h, lat_w = x0.shape

    # Average over heads, then sum over queries, to score each key token.
    attn = attn.reshape(batch, -1, hw1, hw2)
    selection = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold

    def _closest_factor_pair(n):
        # Factor pair of n whose members are as close as possible.
        for cand in range(int(math.sqrt(n)), 0, -1):
            if n % cand == 0:
                return (cand, n // cand)

    pair = _closest_factor_pair(hw1)
    # Orient the pair so the taller grid axis follows the taller latent axis.
    grid_h = max(pair) if lat_h > lat_w else min(pair)
    grid_w = pair[1] if grid_h == pair[0] else pair[0]

    # Un-flatten the token mask into its 2D grid and upsample to latent size.
    selection = selection.reshape(batch, grid_h, grid_w).unsqueeze(1).type(attn.dtype)
    selection = F.interpolate(selection, (lat_h, lat_w))

    smoothed = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
    return smoothed * selection + x0 * (1 - selection)
143
+
144
def pre_run_control(model, conds):
    """Patched `pre_run_control`: resets the tiled-diffusion wrapper state,
    then prepares each cond's ControlNet for the upcoming sampling run."""
    s = model.model_sampling

    def find_outer_instance(target:str, target_type):
        # Walk up to 7 caller frames looking for a local named `target` of
        # type `target_type`. Fragile by nature — depends on comfy's call
        # stack shape at the time this patch runs.
        import inspect
        frame = inspect.currentframe()
        i = 0
        while frame and i < 7:
            if (found:=frame.f_locals.get(target, None)) is not None:
                if isinstance(found, target_type):
                    return found
            frame = frame.f_back
            i += 1
        return None
    from comfy.model_patcher import ModelPatcher
    # If the caller holds a ModelPatcher whose model_function_wrapper is a
    # tiled-diffusion AbstractDiffusion, reset its per-run caches now.
    if (_model:=find_outer_instance('model', ModelPatcher)) is not None:
        if (model_function_wrapper:=_model.model_options.get('model_function_wrapper', None)) is not None:
            import sys
            # The module name depends on how the custom node was loaded, so
            # fall back to scanning sys.modules for any 'tiled_diffusion'.
            tiled_diffusion = sys.modules.get('ComfyUI-TiledDiffusion.tiled_diffusion', None)
            if tiled_diffusion is None:
                for key in sys.modules:
                    if 'tiled_diffusion' in key:
                        tiled_diffusion = sys.modules[key]
                        break
            if (AbstractDiffusion:=getattr(tiled_diffusion, 'AbstractDiffusion', None)) is not None:
                if isinstance(model_function_wrapper, AbstractDiffusion):
                    model_function_wrapper.reset()

    for t in range(len(conds)):
        x = conds[t]

        # NOTE(review): timestep_start/timestep_end are assigned but unused —
        # likely left over from the upstream function this patch mirrors.
        timestep_start = None
        timestep_end = None
        percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
        if 'control' in x:
            # Best-effort cleanup of any previous run's hint tensors.
            try: x['control'].cleanup()
            except Exception: ...
            x['control'].pre_run(model, percent_to_timestep_function)
182
def _set_position(self, boxes, masks, positive_embeddings):
    """Build a per-transformer-block callback that injects position conditioning.

    The conditioning tensor is computed once from `position_net`; the returned
    closure applies the block indexed by `extra_options["transformer_index"]`.
    """
    position_objs = self.position_net(boxes, masks, positive_embeddings)

    def apply(x, extra_options):
        idx = extra_options["transformer_index"]
        block = self.module_list[idx]
        # Move/cast the conditioning lazily to match the activation tensor.
        return block(x, position_objs.to(device=x.device, dtype=x.dtype))

    return apply
189
+
ComfyUI-TiledDiffusion/README.md ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Tiled Diffusion & VAE for ComfyUI
2
+
3
+ Check out the [SD-WebUI extension](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/) for more information.
4
+
5
+ This extension enables **large image drawing & upscaling with limited VRAM** via the following techniques:
6
+
7
+ 1. Two SOTA diffusion tiling algorithms: [Mixture of Diffusers](https://github.com/albarji/mixture-of-diffusers) <a href="https://arxiv.org/abs/2302.02412"><img width="32" alt="Mixture of Diffusers Paper" src="https://github.com/shiimizu/ComfyUI-TiledDiffusion/assets/54494639/b753b7f6-f9c0-405d-bace-792b9bbce5d5"></a> and [MultiDiffusion](https://github.com/omerbt/MultiDiffusion) <a href="https://arxiv.org/abs/2302.08113"><img width="32" alt="MultiDiffusion Paper" src="https://github.com/shiimizu/ComfyUI-TiledDiffusion/assets/54494639/b753b7f6-f9c0-405d-bace-792b9bbce5d5"></a>
8
+ 2. pkuliyi2015 & Kahsolt's Tiled VAE algorithm.
9
+ 3. ~~pkuliyi2015 & Kahsolt's Tiled Noise Inversion for better upscaling.~~
10
+
11
+ > [!NOTE]
12
+ > Sizes/dimensions are in pixels and then converted to latent-space sizes.
13
+
14
+
15
+ ## Features
16
+ - [x] SDXL model support
17
+ - [x] ControlNet support
18
+ - [ ] ~~StableSR support~~
19
+ - [ ] ~~Tiled Noise Inversion~~
20
+ - [x] Tiled VAE
21
+ - [ ] Regional Prompt Control
22
+ - [x] Img2img upscale
23
+ - [x] Ultra-Large image generation
24
+
25
+ ## Tiled Diffusion
26
+
27
+ <div align="center">
28
+ <img width="500" alt="Tiled_Diffusion" src="https://github.com/shiimizu/ComfyUI-TiledDiffusion/assets/54494639/7cb897a3-a645-426f-8742-d6ba5cf04b64">
29
+ </div>
30
+
31
+ > [!TIP]
32
+ > Set `tile_overlap` to 0 and `denoise` to 1 to see the tile seams and then adjust the options to your needs. Also, increase `tile_batch_size` to increase speed (if your machine can handle it).
33
+
34
+ | Name | Description |
35
+ |-------------------|--------------------------------------------------------------|
36
+ | `method` | Tiling [strategy](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/blob/fbb24736c9bc374c7f098f82b575fcd14a73936a/scripts/tilediffusion.py#L39-L46). `MultiDiffusion` or `Mixture of Diffusers`. |
37
+ | `tile_width` | Tile's width |
38
+ | `tile_height` | Tile's height |
39
+ | `tile_overlap` | Tile's overlap |
40
+ | `tile_batch_size` | The number of tiles to process in a batch |
41
+
42
+ ### How can I specify the tiles' arrangement?
43
+
44
+ If you have the [Math Expression](https://github.com/pythongosssss/ComfyUI-Custom-Scripts#math-expression) node (or something similar), you can use that to pass in the latent that's passed in your KSampler and divide the `tile_height`/`tile_width` by the number of rows/columns you want.
45
+
46
+ `C` = number of columns you want
47
+ `R` = number of rows you want
48
+
49
+ `pixel width of input image or latent // C` = `tile_width`
50
+ `pixel height of input image or latent // R` = `tile_height`
51
+
52
+ <img width="800" alt="Tile_arrangement" src="https://github.com/shiimizu/ComfyUI-TiledDiffusion/assets/54494639/9952e7d8-909e-436f-a284-c00f0fb71665">
53
+
54
+ ## Tiled VAE
55
+
56
+ <div align="center">
57
+ <img width="900" alt="Tiled_VAE" src="https://github.com/shiimizu/ComfyUI-TiledDiffusion/assets/54494639/b5850e03-2cac-49ce-b1fe-a67906bf4c9d">
58
+ </div>
59
+
60
+ <br>
61
+
62
+ The recommended tile sizes are given upon the creation of the node based on the available VRAM.
63
+
64
+ > [!NOTE]
65
+ > Enabling `fast` for the decoder may produce images with slightly higher contrast and brightness.
66
+
67
+
68
+ | Name | Description |
69
+ |-------------|----------------------------------------------------------------------------------------------------------------------------------------------|
70
+ | `tile_size` | <blockquote>The image is split into tiles, which are then padded with 11/32 pixels in the decoder/encoder.</blockquote> |
71
+ | `fast` | <blockquote><p>When Fast Mode is disabled:</p> <ol> <li>The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile.</li> <li>When GroupNorm is needed, it suspends, stores current GroupNorm mean and var, send everything to RAM, and turns to the next tile.</li> <li>After all GroupNorm means and vars are summarized, it applies group norm to tiles and continues. </li> <li>A zigzag execution order is used to reduce unnecessary data transfer.</li> </ol> <p>When Fast Mode is enabled:</p> <ol> <li>The original input is downsampled and passed to a separate task queue.</li> <li>Its group norm parameters are recorded and used by all tiles&#39; task queues.</li> <li>Each tile is separately processed without any RAM-VRAM data transfer.</li> </ol> <p>After all tiles are processed, tiles are written to a result buffer and returned.</p></blockquote> |
72
+ | `color_fix` | <blockquote>Only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode.</blockquote><p>Only for the encoder. Can restore colors if tiles are too small.</p> |
73
+
74
+
75
+
76
+ ## Workflows
77
+
78
+ The following images can be loaded in ComfyUI.
79
+
80
+
81
+ <div align="center">
82
+ <img alt="ComfyUI_07501_" src="https://github.com/shiimizu/ComfyUI-TiledDiffusion/assets/54494639/c3713cfb-e083-4df4-a310-9467827ee666">
83
+ <p>Simple upscale.</p>
84
+ </div>
85
+
86
+ <br>
87
+
88
+ <div align="center">
89
+
90
+ <img alt="ComfyUI_07503_" src="https://github.com/shiimizu/ComfyUI-TiledDiffusion/assets/54494639/b681b617-4bb1-49e5-b85a-ef5a0f6e4830">
91
+ <p>4x upscale. 3 passes.</p>
92
+ </div>
93
+
94
+ ## Citation
95
+
96
+ ```bibtex
97
+ @article{jimenez2023mixtureofdiffusers,
98
+ title={Mixture of Diffusers for scene composition and high resolution image generation},
99
+ author={Álvaro Barbero Jiménez},
100
+ journal={arXiv preprint arXiv:2302.02412},
101
+ year={2023}
102
+ }
103
+ ```
104
+
105
+ ```bibtex
106
+ @article{bar2023multidiffusion,
107
+ title={MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation},
108
+ author={Bar-Tal, Omer and Yariv, Lior and Lipman, Yaron and Dekel, Tali},
109
+ journal={arXiv preprint arXiv:2302.08113},
110
+ year={2023}
111
+ }
112
+ ```
ComfyUI-TiledDiffusion/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
from .tiled_diffusion import NODE_CLASS_MAPPINGS as TD_NCM, NODE_DISPLAY_NAME_MAPPINGS as TD_NDCM
from .tiled_vae import NODE_CLASS_MAPPINGS as TV_NCM, NODE_DISPLAY_NAME_MAPPINGS as TV_NDCM

# Merge the node registries from both submodules. Later entries win on key
# clashes, matching the original update() ordering (tiled_vae last).
NODE_CLASS_MAPPINGS = {**TD_NCM, **TV_NCM}
NODE_DISPLAY_NAME_MAPPINGS = {**TD_NDCM, **TV_NDCM}

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
ComfyUI-TiledDiffusion/__pycache__/.patches.cpython-310.pyc ADDED
Binary file (5 kB). View file
 
ComfyUI-TiledDiffusion/__pycache__/.patches.cpython-311.pyc ADDED
Binary file (10.4 kB). View file
 
ComfyUI-TiledDiffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (431 Bytes). View file
 
ComfyUI-TiledDiffusion/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (717 Bytes). View file
 
ComfyUI-TiledDiffusion/__pycache__/tiled_diffusion.cpython-310.pyc ADDED
Binary file (19 kB). View file
 
ComfyUI-TiledDiffusion/__pycache__/tiled_diffusion.cpython-311.pyc ADDED
Binary file (37 kB). View file
 
ComfyUI-TiledDiffusion/__pycache__/tiled_vae.cpython-310.pyc ADDED
Binary file (24.9 kB). View file
 
ComfyUI-TiledDiffusion/__pycache__/tiled_vae.cpython-311.pyc ADDED
Binary file (45.7 kB). View file
 
ComfyUI-TiledDiffusion/__pycache__/utils.cpython-310.pyc ADDED
Binary file (7.83 kB). View file
 
ComfyUI-TiledDiffusion/__pycache__/utils.cpython-311.pyc ADDED
Binary file (12.9 kB). View file
 
ComfyUI-TiledDiffusion/tiled_diffusion.py ADDED
@@ -0,0 +1,650 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division
2
+ import torch
3
+ from torch import Tensor
4
+ import comfy.model_management
5
+ from comfy.model_patcher import ModelPatcher
6
+ import comfy.model_patcher
7
+ from comfy.model_base import BaseModel
8
+ from typing import List, Union, Tuple, Dict
9
+ from nodes import ImageScale
10
+ import comfy.utils
11
+ from comfy.controlnet import ControlNet, T2IAdapter
12
+
13
opt_C = 4  # latent channel count — presumably SD-style 4-channel latents; confirm for other model families
opt_f = 8  # VAE spatial downscale factor; used to convert pixel coords/sizes to latent space and back
15
+
16
def ceildiv(big, small):
    """Integer ceiling division, equivalent to math.ceil(big / small) but exact
    (no float round-trip): negate, floor-divide, negate back."""
    return -(-big // small)
19
+
20
from enum import Enum


class BlendMode(Enum):
    """Layer type used when blending regions (i.e. LayerType)."""
    FOREGROUND = 'Foreground'
    BACKGROUND = 'Background'
24
+
25
class Processing:
    """Empty stub kept for code ported from the A1111 extension."""


class Device:
    """Bare attribute namespace; `devices.device` holds the torch device."""


devices = Device()
devices.device = comfy.model_management.get_torch_device()
29
+
30
def null_decorator(fn):
    """No-op decorator used purely as a semantic tag on methods.

    Returns `fn` unchanged. The previous implementation wrapped `fn` in a
    pass-through closure, which added a pointless call layer and discarded
    the function's metadata (`__name__`, `__doc__`, signature) since it did
    not use functools.wraps. Returning the function itself is call-compatible
    for every decorated method and keeps introspection intact.
    """
    return fn


# Semantic aliases: mark which feature area a method belongs to; no behavior.
keep_signature = null_decorator
controlnet = null_decorator
stablesr = null_decorator
grid_bbox = null_decorator
custom_bbox = null_decorator
noise_inverse = null_decorator
41
+
42
class BBox:
    """Axis-aligned grid tile in latent coordinates.

    Stores the corner/size, the `[x1, y1, x2, y2]` box, and a 4D slicer
    (`[:, :, y:y+h, x:x+w]`) for indexing NCHW tensors.
    """

    def __init__(self, x:int, y:int, w:int, h:int):
        self.x, self.y = x, y
        self.w, self.h = w, h
        self.box = [x, y, x + w, y + h]
        self.slicer = (slice(None), slice(None), slice(y, y + h), slice(x, x + w))

    def __getitem__(self, idx:int) -> int:
        """Index into the [x1, y1, x2, y2] box."""
        return self.box[idx]
55
+
56
def split_bboxes(w:int, h:int, tile_w:int, tile_h:int, overlap:int=16, init_weight:Union[Tensor, float]=1.0) -> Tuple[List[BBox], Tensor]:
    """Split a (h, w) latent canvas into a grid of overlapping tiles.

    Returns the tile BBoxes (row-major) and a (1, 1, h, w) coverage map that
    accumulates `init_weight` for every tile covering each pixel.
    """
    n_cols = ceildiv(w - overlap, tile_w - overlap)
    n_rows = ceildiv(h - overlap, tile_h - overlap)
    # Even spacing of tile origins; a single tile per axis needs no stride.
    step_x = (w - tile_w) / (n_cols - 1) if n_cols > 1 else 0
    step_y = (h - tile_h) / (n_rows - 1) if n_rows > 1 else 0

    tiles: List[BBox] = []
    coverage = torch.zeros((1, 1, h, w), device=devices.device, dtype=torch.float32)
    for row_idx in range(n_rows):
        top = min(int(row_idx * step_y), h - tile_h)
        for col_idx in range(n_cols):
            left = min(int(col_idx * step_x), w - tile_w)
            tile = BBox(left, top, tile_w, tile_h)
            tiles.append(tile)
            coverage[tile.slicer] += init_weight

    return tiles, coverage
74
+
75
class CustomBBox(BBox):
    """Region-control bbox (regional prompt); currently identical to BBox."""
78
+
79
+ class AbstractDiffusion:
80
    def __init__(self):
        """Shared state for tiled-diffusion samplers (MultiDiffusion / Mixture of Diffusers subclasses)."""
        # `method` is the concrete subclass name, used as a label.
        self.method = self.__class__.__name__
        self.pbar = None


        # latent-space canvas size; set before tiling begins
        self.w: int = 0
        self.h: int = 0
        # user-facing tile settings, populated externally before sampling
        self.tile_width: int = None
        self.tile_height: int = None
        self.tile_overlap: int = None
        self.tile_batch_size: int = None

        # cache. final result of current sampling step, [B, C=4, H//8, W//8]
        # avoiding overhead of creating new tensors and weight summing
        self.x_buffer: Tensor = None
        # self.w: int = int(self.p.width // opt_f) # latent size
        # self.h: int = int(self.p.height // opt_f)
        # weights for background & grid bboxes
        self._weights: Tensor = None
        # self.weights: Tensor = torch.zeros((1, 1, self.h, self.w), device=devices.device, dtype=torch.float32)
        self._init_grid_bbox = None
        self._init_done = None

        # count the step correctly
        self.step_count = 0
        self.inner_loop_count = 0
        self.kdiff_step = -1

        # ext. Grid tiling painting (grid bbox)
        self.enable_grid_bbox: bool = False
        self.tile_w: int = None
        self.tile_h: int = None
        self.tile_bs: int = None
        self.num_tiles: int = None
        self.num_batches: int = None
        self.batched_bboxes: List[List[BBox]] = []

        # ext. Region Prompt Control (custom bbox)
        self.enable_custom_bbox: bool = False
        self.custom_bboxes: List[CustomBBox] = []
        # self.cond_basis: Cond = None
        # self.uncond_basis: Uncond = None
        # self.draw_background: bool = True # by default we draw major prompts in grid tiles
        # self.causal_layers: bool = None

        # ext. ControlNet
        self.enable_controlnet: bool = False
        # self.controlnet_script: ModuleType = None
        self.control_tensor_batch_dict = {}
        self.control_tensor_batch: List[List[Tensor]] = [[]]
        # self.control_params: Dict[str, Tensor] = None # {}
        self.control_params: Dict[Tuple, List[List[Tensor]]] = {}
        self.control_tensor_cpu: bool = None
        self.control_tensor_custom: List[List[Tensor]] = []

        self.draw_background: bool = True # by default we draw major prompts in grid tiles
        # NOTE(review): control_tensor_cpu is annotated None above and then
        # unconditionally overwritten here — the earlier line looks redundant.
        self.control_tensor_cpu = False
        self.weights = None
        self.imagescale = ImageScale()
139
+
140
+ def reset(self):
141
+ tile_width = self.tile_width
142
+ tile_height = self.tile_height
143
+ tile_overlap = self.tile_overlap
144
+ tile_batch_size = self.tile_batch_size
145
+ self.__init__()
146
+ self.tile_width = tile_width
147
+ self.tile_height = tile_height
148
+ self.tile_overlap = tile_overlap
149
+ self.tile_batch_size = tile_batch_size
150
+
151
+ def repeat_tensor(self, x:Tensor, n:int, concat=False, concat_to=0) -> Tensor:
152
+ ''' repeat the tensor on it's first dim '''
153
+ if n == 1: return x
154
+ B = x.shape[0]
155
+ r_dims = len(x.shape) - 1
156
+ if B == 1: # batch_size = 1 (not `tile_batch_size`)
157
+ shape = [n] + [-1] * r_dims # [N, -1, ...]
158
+ return x.expand(shape) # `expand` is much lighter than `tile`
159
+ else:
160
+ if concat:
161
+ return torch.cat([x for _ in range(n)], dim=0)[:concat_to]
162
+ shape = [n] + [1] * r_dims # [N, 1, ...]
163
+ return x.repeat(shape)
164
    def update_pbar(self):
        """Advance the tiling progress bar, once per inner (per-tile-batch) pass.

        NOTE(review): `sampling_step` is hard-coded to 20 but only used as a
        sentinel to tell "subsequent call within a step" from "first call of a
        step" — the literal value looks arbitrary; confirm intent.
        """
        if self.pbar.n >= self.pbar.total:
            self.pbar.close()
        else:
            # self.pbar.update()
            sampling_step = 20
            if self.step_count == sampling_step:
                # Subsequent call within the same step: count one inner pass.
                self.inner_loop_count += 1
                if self.inner_loop_count < self.total_bboxes:
                    self.pbar.update()
            else:
                # First call of a new step: mark it and reset the inner counter.
                self.step_count = sampling_step
                self.inner_loop_count = 0
177
+ def reset_buffer(self, x_in:Tensor):
178
+ # Judge if the shape of x_in is the same as the shape of x_buffer
179
+ if self.x_buffer is None or self.x_buffer.shape != x_in.shape:
180
+ self.x_buffer = torch.zeros_like(x_in, device=x_in.device, dtype=x_in.dtype)
181
+ else:
182
+ self.x_buffer.zero_()
183
+
184
    @grid_bbox
    def init_grid_bbox(self, tile_w:int, tile_h:int, overlap:int, tile_bs:int):
        """Split the (self.h, self.w) latent into overlapping tiles and batch them.

        Populates `self.weights` (per-pixel coverage), the tile grid counts,
        and `self.batched_bboxes` (tiles grouped into ~tile_bs-sized batches).
        All sizes are in latent units.
        """
        # if self._init_grid_bbox is not None: return
        # self._init_grid_bbox = True
        self.weights = torch.zeros((1, 1, self.h, self.w), device=devices.device, dtype=torch.float32)
        self.enable_grid_bbox = True

        # clamp tile size to the canvas; cap overlap so tiles still advance
        self.tile_w = min(tile_w, self.w)
        self.tile_h = min(tile_h, self.h)
        overlap = max(0, min(overlap, min(tile_w, tile_h) - 4))
        # split the latent into overlapped tiles, then batching
        # weights basically indicate how many times a pixel is painted
        bboxes, weights = split_bboxes(self.w, self.h, self.tile_w, self.tile_h, overlap, self.get_tile_weights())
        self.weights += weights
        self.num_tiles = len(bboxes)
        self.num_batches = ceildiv(self.num_tiles , tile_bs)
        # re-derive the batch size so batches come out evenly sized
        self.tile_bs = ceildiv(len(bboxes) , self.num_batches) # optimal_batch_size
        self.batched_bboxes = [bboxes[i*self.tile_bs:(i+1)*self.tile_bs] for i in range(self.num_batches)]
202
+
203
+ @grid_bbox
204
+ def get_tile_weights(self) -> Union[Tensor, float]:
205
+ return 1.0
206
+
207
+ @noise_inverse
208
+ def init_noise_inverse(self, steps:int, retouch:float, get_cache_callback, set_cache_callback, renoise_strength:float, renoise_kernel:int):
209
+ self.noise_inverse_enabled = True
210
+ self.noise_inverse_steps = steps
211
+ self.noise_inverse_retouch = float(retouch)
212
+ self.noise_inverse_renoise_strength = float(renoise_strength)
213
+ self.noise_inverse_renoise_kernel = int(renoise_kernel)
214
+ self.noise_inverse_set_cache = set_cache_callback
215
+ self.noise_inverse_get_cache = get_cache_callback
216
+
217
+ def init_done(self):
218
+ '''
219
+ Call this after all `init_*`, settings are done, now perform:
220
+ - settings sanity check
221
+ - pre-computations, cache init
222
+ - anything thing needed before denoising starts
223
+ '''
224
+
225
+ # if self._init_done is not None: return
226
+ # self._init_done = True
227
+ self.total_bboxes = 0
228
+ if self.enable_grid_bbox: self.total_bboxes += self.num_batches
229
+ if self.enable_custom_bbox: self.total_bboxes += len(self.custom_bboxes)
230
+ assert self.total_bboxes > 0, "Nothing to paint! No background to draw and no custom bboxes were provided."
231
+
232
+ # sampling_steps = _steps
233
+ # self.pbar = tqdm(total=(self.total_bboxes) * sampling_steps, desc=f"{self.method} Sampling: ")
234
+
235
    @controlnet
    def prepare_controlnet_tensors(self, refresh:bool=False, tensor=None):
        ''' Crop the control tensor into tiles and cache them '''
        # Skip when caches already exist, unless the caller forces a refresh.
        if not refresh:
            if self.control_tensor_batch is not None or self.control_params is not None: return
        tensors = [tensor]
        self.org_control_tensor_batch = tensors
        self.control_tensor_batch = []
        for i in range(len(tensors)):
            control_tile_list = []
            control_tensor = tensors[i]
            for bboxes in self.batched_bboxes:
                single_batch_tensors = []
                for bbox in bboxes:
                    # ensure NCHW before slicing (mutates in place on purpose)
                    if len(control_tensor.shape) == 3:
                        control_tensor.unsqueeze_(0)
                    # bbox coords are latent-space; scale by opt_f to index the pixel-space hint
                    control_tile = control_tensor[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f]
                    single_batch_tensors.append(control_tile)
                # one tensor per tile batch, stacked along the batch dim
                control_tile = torch.cat(single_batch_tensors, dim=0)
                if self.control_tensor_cpu:
                    # optionally park the cache in RAM to save VRAM
                    control_tile = control_tile.cpu()
                control_tile_list.append(control_tile)
            self.control_tensor_batch.append(control_tile_list)

            # custom (region) bboxes are cached as individual, unbatched crops
            if len(self.custom_bboxes) > 0:
                custom_control_tile_list = []
                for bbox in self.custom_bboxes:
                    if len(control_tensor.shape) == 3:
                        control_tensor.unsqueeze_(0)
                    control_tile = control_tensor[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f]
                    if self.control_tensor_cpu:
                        control_tile = control_tile.cpu()
                    custom_control_tile_list.append(control_tile)
                self.control_tensor_custom.append(custom_control_tile_list)
269
+
270
    @controlnet
    def switch_controlnet_tensors(self, batch_id:int, x_batch_size:int, tile_batch_size:int, is_denoise=False):
        """Broadcast the cached control tiles for `batch_id` to the latent batch size.

        NOTE(review): this rewrites `self.control_tensor_batch` in place, so a
        second call with x_batch_size > 1 would re-expand already-expanded
        tiles — presumably the caller guards against that; confirm.
        """
        # if not self.enable_controlnet: return
        if self.control_tensor_batch is None: return
        # self.control_params = [0]

        # for param_id in range(len(self.control_params)):
        for param_id in range(len(self.control_tensor_batch)):
            # tensor that was concatenated in `prepare_controlnet_tensors`
            control_tile = self.control_tensor_batch[param_id][batch_id]
            # broadcast to latent batch size
            if x_batch_size > 1: # self.is_kdiff:
                all_control_tile = []
                for i in range(tile_batch_size):
                    # duplicate each tile's hint once per latent in the batch
                    this_control_tile = [control_tile[i].unsqueeze(0)] * x_batch_size
                    all_control_tile.append(torch.cat(this_control_tile, dim=0))
                control_tile = torch.cat(all_control_tile, dim=0) # [:x_tile.shape[0]]
                self.control_tensor_batch[param_id][batch_id] = control_tile
            # else:
            # control_tile = control_tile.repeat([x_batch_size if is_denoise else x_batch_size * 2, 1, 1, 1])
            # self.control_params[param_id].hint_cond = control_tile.to(devices.device)
292
    def process_controlnet(self, x_shape, x_dtype, c_in: dict, cond_or_uncond: List, bboxes, batch_size: int, batch_id: int):
        """Prepare (upscale, broadcast, and tile) the hint image of every chained ControlNet.

        Walks the `previous_controlnet` chain, resizing each `cond_hint_original` to the full
        pixel resolution, then slicing it into the current tile bboxes and caching the result
        in `self.control_params` keyed by (cond_or_uncond..., *x_shape).

        @param x_shape: shape of the concatenated tile latent (used as part of the cache key)
        @param x_dtype: fallback dtype for the hint when the control model declares none
        @param c_in: conditioning dict; `c_in['control']` is the head of the ControlNet chain
        @param cond_or_uncond: ComfyUI cond/uncond marker list (part of the cache key)
        @param bboxes: latent-space tile bboxes [x0, y0, x1, y1] for this tile batch
        @param batch_size: latent batch size N to broadcast the hint across
        @param batch_id: index of the current tile batch (cache slot)
        """
        control: ControlNet = c_in['control']
        param_id = -1 # current controlnet & previous_controlnets
        tuple_key = tuple(cond_or_uncond) + tuple(x_shape)
        while control is not None:
            param_id += 1
            # full pixel-space resolution; assumes an 8x latent->pixel factor — TODO confirm for non-SD VAEs
            PH, PW = self.h*8, self.w*8

            # lazily grow the nested cache: control_params[key][param_id][batch_id]
            if self.control_params.get(tuple_key, None) is None:
                self.control_params[tuple_key] = [[None]]
            val = self.control_params[tuple_key]
            if param_id+1 >= len(val):
                val.extend([[None] for _ in range(param_id+1)])
            if len(self.batched_bboxes) >= len(val[param_id]):
                val[param_id].extend([[None] for _ in range(len(self.batched_bboxes))])

            while len(self.control_params[tuple_key]) <= param_id:
                self.control_params[tuple_key].extend([None])

            while len(self.control_params[tuple_key][param_id]) <= batch_id:
                self.control_params[tuple_key][param_id].extend([None])

            # Below is taken from comfy.controlnet.py, but we need to additionally tile the cnets.
            # Recompute when the grid was refreshed, the hint was never built, or the cache slot is cold.
            if self.refresh or control.cond_hint is None or not isinstance(self.control_params[tuple_key][param_id][batch_id], Tensor):
                dtype = getattr(control, 'manual_cast_dtype', None)
                if dtype is None: dtype = getattr(getattr(control, 'control_model', None), 'dtype', None)
                if dtype is None: dtype = x_dtype
                if isinstance(control, T2IAdapter):
                    width, height = control.scale_image_to(PW, PH)
                    control.cond_hint = comfy.utils.common_upscale(control.cond_hint_original, width, height, 'nearest-exact', "center").float().to(control.device)
                    if control.channels_in == 1 and control.cond_hint.shape[1] > 1:
                        control.cond_hint = torch.mean(control.cond_hint, 1, keepdim=True)
                elif control.__class__.__name__ == 'ControlLLLiteAdvanced':
                    if control.sub_idxs is not None and control.cond_hint_original.shape[0] >= control.full_latent_length:
                        control.cond_hint = comfy.utils.common_upscale(control.cond_hint_original[control.sub_idxs], PW, PH, 'nearest-exact', "center").to(dtype=dtype, device=control.device)
                    else:
                        if (PH, PW) == (control.cond_hint_original.shape[-2], control.cond_hint_original.shape[-1]):
                            control.cond_hint = control.cond_hint_original.clone().to(dtype=dtype, device=control.device)
                        else:
                            control.cond_hint = comfy.utils.common_upscale(control.cond_hint_original, PW, PH, 'nearest-exact', "center").to(dtype=dtype, device=control.device)
                else:
                    if (PH, PW) == (control.cond_hint_original.shape[-2], control.cond_hint_original.shape[-1]):
                        control.cond_hint = control.cond_hint_original.clone().to(dtype=dtype, device=control.device)
                    else:
                        control.cond_hint = comfy.utils.common_upscale(control.cond_hint_original, PW, PH, 'nearest-exact', 'center').to(dtype=dtype, device=control.device)

                # Broadcast then tile
                #
                # Below can be in the parent's if clause because self.refresh will trigger on resolution change, e.g. cause of ConditioningSetArea
                # so that particular case isn't cached atm.
                cond_hint_pre_tile = control.cond_hint
                if control.cond_hint.shape[0] < batch_size :
                    cond_hint_pre_tile = self.repeat_tensor(control.cond_hint, ceildiv(batch_size, control.cond_hint.shape[0]))[:batch_size]
                # pixel-space slices of the hint, one per latent tile bbox
                cns = [cond_hint_pre_tile[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f] for bbox in bboxes]
                control.cond_hint = torch.cat(cns, dim=0)
                self.control_params[tuple_key][param_id][batch_id]=control.cond_hint
            else:
                control.cond_hint = self.control_params[tuple_key][param_id][batch_id]
            control = control.previous_controlnet
355
import numpy as np
from numpy import pi, exp, sqrt

def gaussian_weights(tile_w:int, tile_h:int) -> Tensor:
    '''
    Copy from the original implementation of Mixture of Diffusers
    https://github.com/albarji/mixture-of-diffusers/blob/master/mixdiff/tiling.py
    This generates gaussian weights to smooth the noise of each tile.
    This is critical for this method to work.

    Returns a (tile_h, tile_w) float32 tensor on `devices.device`.
    '''
    # NOTE(review): the variance scale uses tile_w*tile_w for BOTH axes, and the y midpoint
    # is tile_h/2 while x uses (tile_w-1)/2 — this asymmetry matches upstream; kept as-is.
    f = lambda x, midpoint, var=0.01: exp(-(x-midpoint)*(x-midpoint) / (tile_w*tile_w) / (2*var)) / sqrt(2*pi*var)
    x_probs = [f(x, (tile_w - 1) / 2) for x in range(tile_w)] # -1 because index goes from 0 to latent_width - 1
    y_probs = [f(y, tile_h / 2) for y in range(tile_h)]

    # outer product gives the separable 2D weight map
    w = np.outer(y_probs, x_probs)
    return torch.from_numpy(w).to(devices.device, dtype=torch.float32)
371
class CondDict: ...  # empty placeholder type used only for type-hint readability
372
+
373
class MultiDiffusion(AbstractDiffusion):
    """MultiDiffusion tiling: overlapping tile predictions are summed into a canvas
    and divided by the per-pixel overlap count (uniform averaging)."""

    @torch.no_grad()
    def __call__(self, model_function: BaseModel.apply_model, args: dict):
        """UNet wrapper (ComfyUI `set_model_unet_function_wrapper` contract).

        Splits the latent into grid tiles, runs the model per tile batch, and
        fuses the per-tile outputs back into a full-size latent.
        """
        x_in: Tensor = args["input"]              # noisy latent [N, C, H, W]
        t_in: Tensor = args["timestep"]
        c_in: dict = args["c"]
        cond_or_uncond: List = args["cond_or_uncond"]
        c_crossattn: Tensor = c_in['c_crossattn']

        N, C, H, W = x_in.shape

        # comfyui can feed in a latent that's a different size cause of SetArea, so we'll refresh in that case.
        self.refresh = False
        if self.weights is None or self.h != H or self.w != W:
            self.h, self.w = H, W
            self.refresh = True
            self.init_grid_bbox(self.tile_width, self.tile_height, self.tile_overlap, self.tile_batch_size)
            # init everything done, perform sanity check & pre-computations
            self.init_done()
        self.h, self.w = H, W
        # clear buffer canvas
        self.reset_buffer(x_in)

        # Background sampling (grid bbox)
        if self.draw_background:
            for batch_id, bboxes in enumerate(self.batched_bboxes):
                if comfy.model_management.processing_interrupted():
                    return x_in

                # batching & compute tiles
                x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0) # [TB, C, TH, TW]
                n_rep = len(bboxes)
                ts_tile = self.repeat_tensor(t_in, n_rep)
                cond_tile = self.repeat_tensor(c_crossattn, n_rep)
                c_tile = c_in.copy()
                c_tile['c_crossattn'] = cond_tile
                if 'time_context' in c_in:
                    c_tile['time_context'] = self.repeat_tensor(c_in['time_context'], n_rep)
                for key in c_tile:
                    if key in ['y', 'c_concat']:
                        icond = c_tile[key]
                        # spatial conds (same H/W as the latent) are sliced per tile;
                        # everything else is just repeated per tile
                        if icond.shape[2:] == (self.h, self.w):
                            c_tile[key] = torch.cat([icond[bbox.slicer] for bbox in bboxes])
                        else:
                            c_tile[key] = self.repeat_tensor(icond, n_rep)

                # controlnet tiling
                if 'control' in c_in:
                    control = c_in['control']
                    self.process_controlnet(x_tile.shape, x_tile.dtype, c_in, cond_or_uncond, bboxes, N, batch_id)
                    c_tile['control'] = control.get_control(x_tile, ts_tile, c_tile, len(cond_or_uncond))

                x_tile_out = model_function(x_tile, ts_tile, **c_tile)

                # de-batch: accumulate each tile's prediction into the shared canvas
                for i, bbox in enumerate(bboxes):
                    self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :]
                del x_tile_out, x_tile, ts_tile, c_tile

        # Averaging background buffer: where tiles overlap, divide by the overlap count
        x_out = torch.where(self.weights > 1, self.x_buffer / self.weights, self.x_buffer)

        return x_out
445
class MixtureOfDiffusers(AbstractDiffusion):
    """
    Mixture-of-Diffusers Implementation
    https://github.com/albarji/mixture-of-diffusers

    Like MultiDiffusion, but each tile's prediction is blended with a Gaussian
    weight map, so tile seams fade smoothly instead of being uniformly averaged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # weights for custom bboxes
        self.custom_weights: List[Tensor] = []
        self.get_weight = gaussian_weights

    def init_done(self):
        """Finish initialization: pre-compute the rescaling needed for numerical stability."""
        super().init_done()
        # The original gaussian weights can be extremely small, so we rescale them for numerical stability
        self.rescale_factor = 1 / self.weights
        # Meanwhile, we rescale the custom weights in advance to save time of slicing
        for bbox_id, bbox in enumerate(self.custom_bboxes):
            if bbox.blend_mode == BlendMode.BACKGROUND:
                self.custom_weights[bbox_id] *= self.rescale_factor[bbox.slicer]

    @grid_bbox
    def get_tile_weights(self) -> Tensor:
        """Return the Gaussian weight map for a grid tile (recomputed every call)."""
        # x_in can change sizes cause of ConditioningSetArea, so we have to recalculate each time
        self.tile_weights = self.get_weight(self.tile_w, self.tile_h)
        return self.tile_weights

    @torch.no_grad()
    def __call__(self, model_function: BaseModel.apply_model, args: dict):
        """UNet wrapper: tile the latent, denoise per tile, and blend with Gaussian weights."""
        x_in: Tensor = args["input"]
        t_in: Tensor = args["timestep"]
        c_in: dict = args["c"]
        cond_or_uncond: List = args["cond_or_uncond"]
        c_crossattn: Tensor = c_in['c_crossattn']

        N, C, H, W = x_in.shape

        # refresh the grid when the latent size changes (e.g. ConditioningSetArea)
        self.refresh = False
        if self.weights is None or self.h != H or self.w != W:
            self.h, self.w = H, W
            self.refresh = True
            self.init_grid_bbox(self.tile_width, self.tile_height, self.tile_overlap, self.tile_batch_size)
            # init everything done, perform sanity check & pre-computations
            self.init_done()
        self.h, self.w = H, W
        # clear buffer canvas
        self.reset_buffer(x_in)

        # Global sampling
        if self.draw_background:
            for batch_id, bboxes in enumerate(self.batched_bboxes): # batch_id is the `Latent tile batch size`
                if comfy.model_management.processing_interrupted():
                    return x_in

                # batching
                x_tile_list = []
                t_tile_list = []
                icond_map = {}
                for bbox in bboxes:
                    x_tile_list.append(x_in[bbox.slicer])
                    t_tile_list.append(t_in)
                    if isinstance(c_in, dict):
                        # spatial conds present in sdxl ('y', 'c_concat') are sliced per tile
                        for key in ['y', 'c_concat']:
                            if key in c_in:
                                icond = c_in[key]
                                if icond.shape[2:] == (self.h, self.w):
                                    icond = icond[bbox.slicer]
                                if icond_map.get(key, None) is None:
                                    icond_map[key] = []
                                icond_map[key].append(icond)
                    else:
                        print('>> [WARN] not supported, make an issue on github!!')
                n_rep = len(bboxes)
                x_tile = torch.cat(x_tile_list, dim=0)               # differs per tile
                t_tile = self.repeat_tensor(t_in, n_rep)             # just repeat
                tcond_tile = self.repeat_tensor(c_crossattn, n_rep)  # just repeat
                c_tile = c_in.copy()
                c_tile['c_crossattn'] = tcond_tile
                if 'time_context' in c_in:
                    c_tile['time_context'] = self.repeat_tensor(c_in['time_context'], n_rep) # just repeat
                for key in c_tile:
                    if key in ['y', 'c_concat']:
                        icond_tile = torch.cat(icond_map[key], dim=0) # differs per tile
                        c_tile[key] = icond_tile

                # controlnet
                if 'control' in c_in:
                    control = c_in['control']
                    self.process_controlnet(x_tile.shape, x_tile.dtype, c_in, cond_or_uncond, bboxes, N, batch_id)
                    c_tile['control'] = control.get_control(x_tile, t_tile, c_tile, len(cond_or_uncond))

                # denoising: here the x is the noise
                x_tile_out = model_function(x_tile, t_tile, **c_tile)

                # de-batching
                for i, bbox in enumerate(bboxes):
                    # These weights can be calculated in advance, but will cost a lot of vram
                    # when you have many tiles. So we calculate it here.
                    w = self.tile_weights * self.rescale_factor[bbox.slicer]
                    self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :] * w
                del x_tile_out, x_tile, t_tile, c_tile

        # weights already sum to 1 after rescaling, so no final division is needed
        x_out = self.x_buffer

        return x_out
578
from .utils import hook_all
hook_all()  # install the extension's runtime patches at import time (see utils.hook_all)
+
581
MAX_RESOLUTION = 8192  # upper bound for pixel-space tile sizes exposed in the node UI

class TiledDiffusion():
    """ComfyUI node: wrap a model so its UNet runs tiled (MultiDiffusion or Mixture of Diffusers)."""

    @classmethod
    def INPUT_TYPES(s):
        # tile_* values are in PIXELS here; `apply` converts them to latent units (// opt_f)
        return {"required": {"model": ("MODEL", ),
                             "method": (["MultiDiffusion", "Mixture of Diffusers"], {"default": "Mixture of Diffusers"}),
                             "tile_width": ("INT", {"default": 96*opt_f, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
                             "tile_height": ("INT", {"default": 96*opt_f, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
                             "tile_overlap": ("INT", {"default": 8*opt_f, "min": 0, "max": 256*opt_f, "step": 4*opt_f}),
                             "tile_batch_size": ("INT", {"default": 4, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                             }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply"
    CATEGORY = "_for_testing"

    def apply(self, model: ModelPatcher, method, tile_width, tile_height, tile_overlap, tile_batch_size):
        """Clone `model` and install the chosen tiling implementation as its UNet wrapper.

        @param model: the ComfyUI model to wrap (cloned; the input is not mutated)
        @param method: "MultiDiffusion" or "Mixture of Diffusers"
        @return: 1-tuple with the wrapped model clone
        """
        if method == "Mixture of Diffusers":
            implement = MixtureOfDiffusers()
        else:
            implement = MultiDiffusion()

        # convert pixel-space UI values to latent units
        implement.tile_width = tile_width // opt_f
        implement.tile_height = tile_height // opt_f
        implement.tile_overlap = tile_overlap // opt_f
        implement.tile_batch_size = tile_batch_size
        # hijack the behaviours
        model = model.clone()
        model.set_model_unet_function_wrapper(implement)
        model.model_options['tiled_diffusion'] = True
        return (model,)
+
623
class NoiseInversion():
    """ComfyUI node stub for noise inversion.

    Currently unregistered and unimplemented: `sample` passes the latent
    through unchanged.
    """

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"
    CATEGORY = "sampling"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "model": ("MODEL", ),
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
            "latent_image": ("LATENT", ),
            "image": ("IMAGE", ),
            "steps": ("INT", {"default": 10, "min": 1, "max": 208, "step": 1}),
            "retouch": ("FLOAT", {"default": 1, "min": 1, "max": 100, "step": 0.1}),
            "renoise_strength": ("FLOAT", {"default": 1, "min": 1, "max": 2, "step": 0.01}),
            "renoise_kernel_size": ("INT", {"default": 2, "min": 2, "max": 512, "step": 1}),
        }
        return {"required": required}

    def sample(self, model: ModelPatcher, positive, negative,
               latent_image, image, steps, retouch, renoise_strength, renoise_kernel_size):
        # Placeholder: return the input latent untouched.
        return (latent_image,)
+
643
# ComfyUI registration tables: node class and human-readable display name.
NODE_CLASS_MAPPINGS = {
    "TiledDiffusion": TiledDiffusion,
    # "NoiseInversion": NoiseInversion,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "TiledDiffusion": "Tiled Diffusion",
    # "NoiseInversion": "Noise Inversion",
}
ComfyUI-TiledDiffusion/tiled_vae.py ADDED
@@ -0,0 +1,868 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ # ------------------------------------------------------------------------
3
+ #
4
+ # Tiled VAE
5
+ #
6
+ # Introducing a revolutionary new optimization designed to make
7
+ # the VAE work with giant images on limited VRAM!
8
+ # Say goodbye to the frustration of OOM and hello to seamless output!
9
+ #
10
+ # ------------------------------------------------------------------------
11
+ #
12
+ # This script is a wild hack that splits the image into tiles,
13
+ # encodes each tile separately, and merges the result back together.
14
+ #
15
+ # Advantages:
16
+ # - The VAE can now work with giant images on limited VRAM
17
+ # (~10 GB for 8K images!)
18
+ # - The merged output is completely seamless without any post-processing.
19
+ #
20
+ # Drawbacks:
21
+ # - NaNs always appear for 8k images when you use a fp16 (half) VAE
22
+ # You must use --no-half-vae to disable half VAE for that giant image.
23
+ # - The gradient calculation is not compatible with this hack. It
24
+ # will break any backward() or torch.autograd.grad() that passes VAE.
25
+ # (But you can still use the VAE to generate training data.)
26
+ #
27
+ # How it works:
28
+ # 1. The image is split into tiles, which are then padded with 11/32 pixels in the decoder/encoder.
29
+ # 2. When Fast Mode is disabled:
30
+ # 1. The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile.
31
+ # 2. When GroupNorm is needed, it suspends, stores current GroupNorm mean and var, send everything to RAM, and turns to the next tile.
32
+ # 3. After all GroupNorm means and vars are summarized, it applies group norm to tiles and continues.
33
+ # 4. A zigzag execution order is used to reduce unnecessary data transfer.
34
+ # 3. When Fast Mode is enabled:
35
+ # 1. The original input is downsampled and passed to a separate task queue.
36
+ # 2. Its group norm parameters are recorded and used by all tiles' task queues.
37
+ # 3. Each tile is separately processed without any RAM-VRAM data transfer.
38
+ # 4. After all tiles are processed, tiles are written to a result buffer and returned.
39
+ # Encoder color fix = only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode.
40
+ #
41
+ # Enjoy!
42
+ #
43
+ # @Author: LI YI @ Nanyang Technological University - Singapore
44
+ # @Date: 2023-03-02
45
+ # @License: CC BY-NC-SA 4.0
46
+ #
47
+ # Please give https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111
48
+ # a star if you like the project!
49
+ #
50
+ # -------------------------------------------------------------------------
51
+ '''
52
+
53
+ import gc
54
+ import math
55
+ from time import time
56
+ from tqdm import tqdm
57
+
58
+ import torch
59
+ import torch.version
60
+ import torch.nn.functional as F
61
+ # import gradio as gr
62
+
63
+ # import modules.scripts as scripts
64
+ # from .modules import devices
65
+ # from modules.shared import state
66
+ # from modules.ui import gr_show
67
+ # from modules.processing import opt_f
68
+ # from modules.sd_vae_approx import cheap_approximation
69
+ # from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock
70
+
71
+ # from tile_utils.attn import get_attn_func
72
+ # from tile_utils.typing import Processing
73
+
74
+ import comfy
75
+ import comfy.model_management
76
+ from comfy.model_management import processing_interrupted
77
+ import contextlib
78
+
79
opt_C = 4             # latent channel count (SD 1.x/2.x/XL latents)
opt_f = 8             # latent -> pixel scale factor
is_sdxl = False       # toggles the SDXL latent->RGB coefficients in cheap_approximation
disable_nan_check = True  # when True, test_for_nans is a no-op

# Minimal shim replicating the A1111 `modules.devices` interface on top of ComfyUI.
class Device: ...

devices = Device()
devices.device = comfy.model_management.get_torch_device()
devices.cpu = torch.device('cpu')
devices.torch_gc = lambda: comfy.model_management.soft_empty_cache()
devices.get_optimal_device = lambda: comfy.model_management.get_torch_device()
+
91
class NansException(Exception): ...  # raised when a tensor is entirely NaN

def test_for_nans(x, where):
    """Raise NansException with a context-specific hint if `x` is ALL NaN.

    No-op when `disable_nan_check` is True (the default in this port).

    @param x: tensor to check
    @param where: "unet", "vae", or anything else (selects the error message)
    @raises NansException: when every element of x is NaN
    """
    if disable_nan_check:
        return
    # note: fires only when the WHOLE tensor is NaN, matching the A1111 check
    if not torch.all(torch.isnan(x)).item():
        return
    if where == "unet":
        message = "A tensor with all NaNs was produced in Unet."
        if comfy.model_management.unet_dtype(x.device) != torch.float32:
            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
    elif where == "vae":
        message = "A tensor with all NaNs was produced in VAE."
        if comfy.model_management.unet_dtype(x.device) != torch.float32 and comfy.model_management.vae_dtype() != torch.float32:
            message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
    else:
        message = "A tensor with all NaNs was produced."
    message += " Use --disable-nan-check commandline argument to disable this check."
    raise NansException(message)
+
110
def _autocast(disable=False):
    """Return an autocast context for the current device, or a null context.

    Null context when: explicitly disabled, the UNet runs in fp32, or the
    device is MPS (autocast unsupported there).
    """
    if disable:
        return contextlib.nullcontext()

    if comfy.model_management.unet_dtype() == torch.float32 or comfy.model_management.get_torch_device() == torch.device("mps"):
        return contextlib.nullcontext()

    # only cuda reaches this point
    autocast_device = comfy.model_management.get_autocast_device(comfy.model_management.get_torch_device())
    return torch.autocast(autocast_device)
+
121
def without_autocast(disable=False):
    """Return a context that suspends CUDA autocast if it is currently active.

    @param disable: when True, never suspend autocast (always a null context)
    """
    if torch.is_autocast_enabled() and not disable:
        return torch.autocast("cuda", enabled=False)
    return contextlib.nullcontext()
+
124
# Expose the shims through the A1111-style `devices` namespace used by the ported code.
devices.test_for_nans = test_for_nans
devices.autocast = _autocast
devices.without_autocast = without_autocast
+
128
def cheap_approximation(sample):
    """Approximate latent -> RGB with a fixed linear projection (no VAE decode).

    https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2

    @param sample: latent tensor with 4 channels in dimension -3 (shape ...,4,H,W)
    @return: tensor with 3 RGB channels in place of the 4 latent channels
    """
    # per-model projection matrices (rows = latent channels, cols = RGB)
    if is_sdxl:
        matrix = [
            [ 0.3448,  0.4168,  0.4395],
            [-0.1953, -0.0290,  0.0250],
            [ 0.1074,  0.0886, -0.0163],
            [-0.3730, -0.2499, -0.2088],
        ]
    else:
        matrix = [
            [ 0.298,  0.207,  0.208],
            [ 0.187,  0.286,  0.173],
            [-0.158,  0.189,  0.264],
            [-0.184, -0.271, -0.473],
        ]

    projection = torch.tensor(matrix).to(sample.device)
    return torch.einsum("...lxy,lr -> ...rxy", sample, projection)
+
152
def get_rcmd_enc_tsize():
    """Recommended encoder tile size (in pixels) based on total device VRAM."""
    if not (torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]):
        return 512
    total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20
    # thresholds are in MB (16 GB, 12 GB, 8 GB)
    if total_memory > 16*1000:
        return 3072
    if total_memory > 12*1000:
        return 2048
    if total_memory > 8*1000:
        return 1536
    return 960
+
162
+
163
def get_rcmd_dec_tsize():
    """Recommended decoder tile size (in latent units) based on total device VRAM."""
    if not (torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]):
        return 64
    total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20
    # thresholds are in MB (30 GB, 16 GB, 12 GB, 8 GB)
    if total_memory > 30*1000:
        return 256
    if total_memory > 16*1000:
        return 192
    if total_memory > 12*1000:
        return 128
    if total_memory > 8*1000:
        return 96
    return 64
+
174
+
175
def inplace_nonlinearity(t):
    """SiLU activation applied in place (saves one allocation; also a NaN-fix test)."""
    return F.silu(t, inplace=True)
+
179
def _attn_forward(self, x):
    """VAE attention block forward without the residual add and pre-norm.

    Mirrors comfy.ldm.modules.diffusionmodules.model.AttnBlock.forward; the
    residual/normalization steps are handled separately by the task queue.
    """
    q = self.q(x)
    k = self.k(x)
    v = self.v(x)
    attended = self.optimized_attention(q, k, v)
    return self.proj_out(attended)
+
190
def get_attn_func():
    """Return the attention forward used for VAE tiles (hook point for alternatives)."""
    return _attn_forward
+
193
def attn2task(task_queue, net):
    """Decompose an attention block into task-queue steps.

    @param task_queue: target list of (name, callable) / [name, payload] tasks
    @param net: the attention block module (provides .norm and q/k/v/proj_out)
    """
    attn_forward = get_attn_func()
    # store_res/add_res bracket the block so the residual is applied after attention
    task_queue.append(('store_res', lambda x: x))
    task_queue.append(('pre_norm', net.norm))
    # bind net as a default arg so each task captures its own block
    task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
    task_queue.append(['add_res', None])
201
+
202
def resblock2task(queue, block):
    """Turn a ResNetBlock into a sequence of tasks and append to the task queue.

    @param queue: the target task queue
    @param block: ResNetBlock
    """
    # choose the residual path: identity when channels match, else the block's shortcut conv
    if block.in_channels != block.out_channels:
        shortcut = block.conv_shortcut if block.use_conv_shortcut else block.nin_shortcut
    else:
        shortcut = lambda x: x
    queue.append(('store_res', shortcut))
    # main path: norm -> silu -> conv, twice
    queue.extend([
        ('pre_norm', block.norm1),
        ('silu', inplace_nonlinearity),
        ('conv1', block.conv1),
        ('pre_norm', block.norm2),
        ('silu', inplace_nonlinearity),
        ('conv2', block.conv2),
    ])
    # mutable task: the runner fills in the stored residual here
    queue.append(['add_res', None])
+
225
+
226
def build_sampling(task_queue, net, is_decoder):
    """Build the up/down-sampling trunk of a task queue.

    @param task_queue: the target task queue
    @param net: the network
    @param is_decoder: currently building decoder or encoder
    """
    if is_decoder:
        # decoder: middle blocks first, then upsample from coarse to fine
        resblock2task(task_queue, net.mid.block_1)
        attn2task(task_queue, net.mid.attn_1)
        resblock2task(task_queue, net.mid.block_2)
        level_iter = reversed(range(net.num_resolutions))
        blocks_per_level = net.num_res_blocks + 1
        last_level = 0
        stages = net.up
        resample_name = 'upsample'
    else:
        level_iter = range(net.num_resolutions)
        blocks_per_level = net.num_res_blocks
        last_level = net.num_resolutions - 1
        stages = net.down
        resample_name = 'downsample'

    for i_level in level_iter:
        stage = stages[i_level]
        for i_block in range(blocks_per_level):
            resblock2task(task_queue, stage.block[i_block])
        # no resampling after the final level
        if i_level != last_level:
            task_queue.append((resample_name, getattr(stage, resample_name)))

    if not is_decoder:
        # encoder: middle blocks come after downsampling
        resblock2task(task_queue, net.mid.block_1)
        attn2task(task_queue, net.mid.attn_1)
        resblock2task(task_queue, net.mid.block_2)
+
260
+
261
def build_task_queue(net, is_decoder):
    """Build a single task queue for the encoder or decoder.

    @param net: the VAE decoder or encoder network
    @param is_decoder: currently building decoder or encoder
    @return: the task queue (list of (name, callable) / [name, payload] steps)
    """
    queue = [('conv_in', net.conv_in)]

    # encoder and decoder share the same trunk architecture, so the
    # sampling part is factored out into build_sampling
    build_sampling(queue, net, is_decoder)

    # output head (skipped when the decoder is configured to give the pre-end tensor)
    if not is_decoder or not net.give_pre_end:
        queue.append(('pre_norm', net.norm_out))
        queue.append(('silu', inplace_nonlinearity))
        queue.append(('conv_out', net.conv_out))
        if is_decoder and net.tanh_out:
            queue.append(('tanh', torch.tanh))

    return queue
+
284
+
285
def clone_task_queue(task_queue):
    """Clone a task queue: each task becomes a fresh (mutable) list.

    @param task_queue: the task queue to be cloned
    @return: the cloned task queue
    """
    return [list(task) for task in task_queue]
+
293
+
294
def get_var_mean(input, num_groups, eps=1e-6):
    """Compute per-group variance and mean exactly as GroupNorm would.

    @param input: (B, C, H, W) tensor
    @param num_groups: number of groups; C must be divisible by it
    @param eps: unused, kept for signature compatibility
    @return: (var, mean), each of shape (B * num_groups,)
    """
    b, c = input.size(0), input.size(1)
    channels_per_group = int(c / num_groups)
    # fold batch and group into one axis, then reduce over everything else
    grouped = input.contiguous().view(1, int(b * num_groups), channels_per_group, *input.size()[2:])
    var, mean = torch.var_mean(grouped, dim=[0, 2, 3, 4], unbiased=False)
    return var, mean
+
304
+
305
def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
    """Apply group norm using externally supplied statistics.

    @param input: input tensor (B, C, H, W)
    @param num_groups: number of groups. by default, num_groups = 32
    @param mean: mean, must be pre-calculated by get_var_mean
    @param var: var, must be pre-calculated by get_var_mean
    @param weight: weight, should be fetched from the original group norm
    @param bias: bias, should be fetched from the original group norm
    @param eps: epsilon, by default, eps = 1e-6 to match the original group norm

    @return: normalized tensor
    """
    b, c = input.size(0), input.size(1)
    channels_per_group = int(c / num_groups)
    grouped = input.contiguous().view(
        1, int(b * num_groups), channels_per_group, *input.size()[2:])

    # batch_norm in eval mode normalizes each (batch, group) channel with the given stats
    normed = F.batch_norm(grouped, mean, var, weight=None, bias=None, training=False, momentum=0, eps=eps)
    normed = normed.view(b, c, *input.size()[2:])

    # post affine transform using the original GroupNorm's weight/bias
    if weight is not None:
        normed *= weight.view(1, -1, 1, 1)
    if bias is not None:
        normed += bias.view(1, -1, 1, 1)
    return normed
+
334
+
335
def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
    """Crop the valid (non-padding) region from a processed tile.

    @param x: input tile (after encode/decode)
    @param input_bbox: original padded bounding box [x0, x1, y0, y1] in input space
    @param target_bbox: desired output bounding box in output space
    @param is_decoder: decoder upscales coordinates 8x, encoder downscales 8x
    @return: view of x cropped to the target region
    """
    # map the padded input bbox into the output coordinate space
    if is_decoder:
        padded_bbox = [v * 8 for v in input_bbox]
    else:
        padded_bbox = [v // 8 for v in input_bbox]
    dx0, dx1, dy0, dy1 = (target_bbox[i] - padded_bbox[i] for i in range(4))
    # negative end offsets trim padding from the right/bottom edges
    return x[:, :, dy0:x.size(2) + dy1, dx0:x.size(3) + dx1]
+
348
+
349
+ # ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
350
+
351
def perfcount(fn):
    """Decorator: report wall time and peak VRAM of a Tiled VAE call.

    Resets CUDA peak-memory stats and garbage-collects before the call, then
    garbage-collects again and prints elapsed time (plus max VRAM on CUDA).

    Fix: use functools.wraps so the wrapped function keeps its __name__,
    __doc__ and signature metadata (the original wrapper clobbered them).

    @param fn: the function to instrument
    @return: the instrumented wrapper
    """
    from functools import wraps

    @wraps(fn)
    def wrapper(*args, **kwargs):
        ts = time()

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(devices.device)
        devices.torch_gc()
        gc.collect()

        ret = fn(*args, **kwargs)

        devices.torch_gc()
        gc.collect()
        if torch.cuda.is_available():
            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
        else:
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')

        return ret
    return wrapper
+
373
+ # ↑↑↑ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↑↑↑
374
+
375
+
376
class GroupNormParam:
    """Accumulates per-tile group-norm statistics so a single, globally
    consistent group normalization can be applied across all tiles."""

    def __init__(self):
        # Per-tile statistics collected by add_tile(); weight/bias come from
        # the last group-norm layer seen.
        self.var_list = []
        self.mean_list = []
        self.pixel_list = []
        self.weight = None
        self.bias = None

    def add_tile(self, tile, layer):
        """Record the 32-group var/mean of one tile plus the layer's affine params."""
        var, mean = get_var_mean(tile, 32)
        # For giant images, the variance can be larger than max float16
        # In this case we create a copy to float32
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
        # ============= DEBUG: test for infinite =============
        # if torch.isinf(var).any():
        #     print('[Tiled VAE]: inf test', var)
        # ====================================================
        self.var_list.append(var)
        self.mean_list.append(mean)
        # pixel count is used later as the weighting for this tile's stats
        self.pixel_list.append(
            tile.shape[2]*tile.shape[3])
        if hasattr(layer, 'weight'):
            self.weight = layer.weight
            self.bias = layer.bias
        else:
            self.weight = None
            self.bias = None

    def summary(self):
        """
        summarize the mean and var and return a function
        that apply group norm on each tile

        Tiles' statistics are combined as a pixel-count-weighted average
        (normalized by the largest tile first to avoid fp overflow).
        Returns None when no tile was added.
        """
        if len(self.var_list) == 0: return None

        var = torch.vstack(self.var_list)
        mean = torch.vstack(self.mean_list)
        max_value = max(self.pixel_list)
        pixels = torch.tensor(self.pixel_list, dtype=torch.float32, device=devices.device) / max_value
        sum_pixels = torch.sum(pixels)
        pixels = pixels.unsqueeze(1) / sum_pixels
        # var = torch.sum(var * pixels.to(var.device), dim=0)
        # mean = torch.sum(mean * pixels.to(var.device), dim=0)
        var = torch.sum(var * pixels, dim=0)
        mean = torch.sum(mean * pixels, dim=0)
        return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)

    @staticmethod
    def from_tile(tile, norm):
        """
        create a function from a single tile without summary

        Used by fast-mode estimation: the stats of one (downsampled) tile
        stand in for the whole image.
        """
        var, mean = get_var_mean(tile, 32)
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
            # if it is a macbook, we need to convert back to float16
            if var.device.type == 'mps':
                # clamp to avoid overflow
                var = torch.clamp(var, 0, 60000)
                var = var.half()
                mean = mean.half()
        if hasattr(norm, 'weight'):
            weight = norm.weight
            bias = norm.bias
        else:
            weight = None
            bias = None

        # Bind stats as defaults so the closure is self-contained.
        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
        return group_norm_func
451
+
452
+
453
class VAEHook:
    """Replacement for a VAE encoder/decoder ``forward`` that runs inference
    tile-by-tile with padding, sharing group-norm statistics across tiles.
    Install as ``encoder.forward = VAEHook(encoder, ...)``."""

    def __init__(self, net, tile_size, is_decoder:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool, to_gpu:bool=False):
        self.net = net # encoder | decoder
        self.tile_size = tile_size
        self.is_decoder = is_decoder
        # fast mode estimates group-norm params on a downsampled copy instead
        # of synchronizing statistics between tiles every pass
        self.fast_mode = (fast_encoder and not is_decoder) or (fast_decoder and is_decoder)
        self.color_fix = color_fix and not is_decoder
        self.to_gpu = to_gpu
        self.pad = 11 if is_decoder else 32 # FIXME: magic number

    def __call__(self, x):
        # original_device = next(self.net.parameters()).device
        try:
            # if self.to_gpu:
            #     self.net = self.net.to(devices.get_optimal_device())
            B, C, H, W = x.shape
            # NOTE(review): tiny-input shortcut is disabled (`if False`);
            # tiling is always used regardless of input size.
            if False:#max(H, W) <= self.pad * 2 + self.tile_size:
                print("[Tiled VAE]: the input size is tiny and unnecessary to tile.", x.shape, self.pad * 2 + self.tile_size)
                return self.net.original_forward(x)
            else:
                return self.vae_tile_forward(x)
        finally:
            pass
            # self.net = self.net.to(original_device)

    def get_best_tile_size(self, lowerbound, upperbound):
        """
        Get the best tile size for GPU memory

        Rounds *lowerbound* up to the nearest multiple of the largest
        power-of-two divider (32 down to 2) that still fits *upperbound*;
        falls back to *lowerbound* unchanged.
        """
        divider = 32
        while divider >= 2:
            remainer = lowerbound % divider
            if remainer == 0:
                return lowerbound
            candidate = lowerbound - remainer + divider
            if candidate <= upperbound:
                return candidate
            divider //= 2
        return lowerbound

    def split_tiles(self, h, w):
        """
        Tool function to split the image into tiles
        @param h: height of the image
        @param w: width of the image
        @return: tile_input_bboxes, tile_output_bboxes

        Bboxes are [x1, x2, y1, y2]. Input bboxes are in the incoming
        coordinate space and include the pad; output bboxes are scaled
        (x8 for decoder, /8 for encoder) to the result's coordinate space.
        """
        tile_input_bboxes, tile_output_bboxes = [], []
        tile_size = self.tile_size
        pad = self.pad
        num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
        num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
        # If any of the numbers are 0, we let it be 1
        # This is to deal with long and thin images
        num_height_tiles = max(num_height_tiles, 1)
        num_width_tiles = max(num_width_tiles, 1)

        # Suggestions from https://github.com/Kahsolt: auto shrink the tile size
        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)

        print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
              f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')

        for i in range(num_height_tiles):
            for j in range(num_width_tiles):
                # bbox: [x1, x2, y1, y2]
                # the padding is is unnessary for image borders. So we directly start from (32, 32)
                input_bbox = [
                    pad + j * real_tile_width,
                    min(pad + (j + 1) * real_tile_width, w),
                    pad + i * real_tile_height,
                    min(pad + (i + 1) * real_tile_height, h),
                ]

                # if the output bbox is close to the image boundary, we extend it to the image boundary
                output_bbox = [
                    input_bbox[0] if input_bbox[0] > pad else 0,
                    input_bbox[1] if input_bbox[1] < w - pad else w,
                    input_bbox[2] if input_bbox[2] > pad else 0,
                    input_bbox[3] if input_bbox[3] < h - pad else h,
                ]

                # scale to get the final output bbox
                output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
                tile_output_bboxes.append(output_bbox)

                # indistinguishable expand the input bbox by pad pixels
                tile_input_bboxes.append([
                    max(0, input_bbox[0] - pad),
                    min(w, input_bbox[1] + pad),
                    max(0, input_bbox[2] - pad),
                    min(h, input_bbox[3] + pad),
                ])

        return tile_input_bboxes, tile_output_bboxes

    @torch.no_grad()
    def estimate_group_norm(self, z, task_queue, color_fix):
        """Fast-mode helper: run the task queue once on the (downsampled)
        input *z*, converting every 'pre_norm' task into a concrete
        'apply_norm' closure in place. Returns True on success, False when
        NaNs appear (caller then disables fast mode). Mutates *task_queue*."""
        device = z.device
        tile = z
        last_id = len(task_queue) - 1
        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
            last_id -= 1
        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
            raise ValueError('No group norm found in the task queue')
        # estimate until the last group norm
        for i in range(last_id + 1):
            task = task_queue[i]
            if task[0] == 'pre_norm':
                group_norm_func = GroupNormParam.from_tile(tile, task[1])
                task_queue[i] = ('apply_norm', group_norm_func)
                if i == last_id:
                    return True
                tile = group_norm_func(tile)
            elif task[0] == 'store_res':
                task_id = i + 1
                while task_id < last_id and task_queue[task_id][0] != 'add_res':
                    task_id += 1
                if task_id >= last_id:
                    continue
                task_queue[task_id][1] = task[1](tile)
            elif task[0] == 'add_res':
                tile += task[1].to(device)
                task[1] = None
            elif color_fix and task[0] == 'downsample':
                # color-fix: keep residuals on CPU from the first downsample on
                for j in range(i, last_id + 1):
                    if task_queue[j][0] == 'store_res':
                        task_queue[j] = ('store_res_cpu', task_queue[j][1])
                return True
            else:
                tile = task[1](tile)
            # NOTE(review): bare except intentionally catches the NaN error
            # raised by devices.test_for_nans and degrades to non-fast mode.
            try:
                devices.test_for_nans(tile, "vae")
            except:
                print(f'Nan detected in fast mode estimation. Fast mode disabled.')
                return False

        raise IndexError('Should not reach here')

    @perfcount
    @torch.no_grad()
    def vae_tile_forward(self, z):
        """
        Decode a latent vector z into an image in a tiled manner.
        @param z: latent vector
        @return: image

        Tiles are processed incrementally, pausing at each 'pre_norm' task so
        that group-norm statistics can be aggregated across all tiles before
        normalization is applied; execution alternates direction between
        passes to keep one tile resident on the GPU.
        """
        device = next(self.net.parameters()).device
        net = self.net
        tile_size = self.tile_size
        is_decoder = self.is_decoder

        z = z.detach() # detach the input to avoid backprop

        N, height, width = z.shape[0], z.shape[2], z.shape[3]
        net.last_z_shape = z.shape

        # Split the input into tiles and build a task queue for each tile
        print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')

        in_bboxes, out_bboxes = self.split_tiles(height, width)

        # Prepare tiles by split the input latents
        tiles = []
        for input_bbox in in_bboxes:
            tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
            tiles.append(tile)

        num_tiles = len(tiles)
        num_completed = 0

        # Build task queues
        single_task_queue = build_task_queue(net, is_decoder)
        if self.fast_mode:
            # Fast mode: downsample the input image to the tile size,
            # then estimate the group norm parameters on the downsampled image
            scale_factor = tile_size / max(height, width)
            z = z.to(device)
            downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
            # use nearest-exact to keep statictics as close as possible
            print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

            # ======= Special thanks to @Kahsolt for distribution shift issue ======= #
            # The downsampling will heavily distort its mean and std, so we need to recover it.
            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
            std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
            downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
            del std_old, mean_old, std_new, mean_new
            # occasionally the std_new is too small or too large, which exceeds the range of float16
            # so we need to clamp it to max z's range.
            downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
            estimate_task_queue = clone_task_queue(single_task_queue)
            if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
                single_task_queue = estimate_task_queue
            del downsampled_z

        task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]

        # Dummy result
        result = None
        # cheap preview used as fallback if the real result never completes
        result_approx = None
        try:
            with devices.autocast():
                result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu()
        except: pass
        # Free memory of input latent tensor
        del z

        # Task queue execution
        pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")
        pbar_comfy = comfy.utils.ProgressBar(num_tiles * len(task_queues[0]))

        # execute the task back and forth when switch tiles so that we always
        # keep one tile on the GPU to reduce unnecessary data transfer
        forward = True
        interrupted = False
        # NOTE(review): interruption state is sampled once, before the loop —
        # a mid-run interrupt is presumably picked up on the next invocation.
        state_interrupted = processing_interrupted()
        #state.interrupted = interrupted
        while True:
            if state_interrupted: interrupted = True ; break

            group_norm_param = GroupNormParam()
            for i in range(num_tiles) if forward else reversed(range(num_tiles)):
                if state_interrupted: interrupted = True ; break

                tile = tiles[i].to(device)
                input_bbox = in_bboxes[i]
                task_queue = task_queues[i]

                interrupted = False
                while len(task_queue) > 0:
                    if state_interrupted: interrupted = True ; break

                    # DEBUG: current task
                    # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape)
                    task = task_queue.pop(0)
                    if task[0] == 'pre_norm':
                        # pause this tile: stats are aggregated across tiles,
                        # then an 'apply_norm' task is prepended for the next pass
                        group_norm_param.add_tile(tile, task[1])
                        break
                    elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
                        task_id = 0
                        res = task[1](tile)
                        if not self.fast_mode or task[0] == 'store_res_cpu':
                            res = res.cpu()
                        while task_queue[task_id][0] != 'add_res':
                            task_id += 1
                        task_queue[task_id][1] = res
                    elif task[0] == 'add_res':
                        tile += task[1].to(device)
                        task[1] = None
                    else:
                        tile = task[1](tile)
                    pbar.update(1)
                    pbar_comfy.update(1)


                if interrupted: break

                # check for NaNs in the tile.
                # If there are NaNs, we abort the process to save user's time
                devices.test_for_nans(tile, "vae")

                if len(task_queue) == 0:
                    tiles[i] = None
                    num_completed += 1
                    if result is None: # NOTE: dim C varies from different cases, can only be inited dynamically
                        result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
                    result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
                    del tile
                elif i == num_tiles - 1 and forward:
                    # last tile of a forward pass stays on GPU for the reverse pass
                    forward = False
                    tiles[i] = tile
                elif i == 0 and not forward:
                    forward = True
                    tiles[i] = tile
                else:
                    tiles[i] = tile.cpu()
                    del tile

            if interrupted: break
            if num_completed == num_tiles: break

            # insert the group norm task to the head of each task queue
            group_norm_func = group_norm_param.summary()
            if group_norm_func is not None:
                for i in range(num_tiles):
                    task_queue = task_queues[i]
                    task_queue.insert(0, ('apply_norm', group_norm_func))

        # Done!
        pbar.close()
        if interrupted:
            del result, result_approx
            comfy.model_management.throw_exception_if_processing_interrupted()
        vae_dtype = comfy.model_management.vae_dtype()
        return result.to(dtype=vae_dtype, device=device) if result is not None else result_approx.to(device=device, dtype=vae_dtype)
753
+
754
+ # from .tiled_vae import VAEHook, get_rcmd_enc_tsize, get_rcmd_dec_tsize
755
+ from nodes import VAEEncode, VAEDecode
756
class TiledVAE:
    """Shared implementation for the tiled VAE encode/decode nodes.

    Subclasses set ``self.is_decoder`` in ``__init__``. ``process``
    temporarily replaces the VAE encoder/decoder ``forward`` with a
    ``VAEHook`` that runs tiled inference, delegates to the stock
    VAEEncode/VAEDecode nodes, and always restores the original
    ``forward`` afterwards.
    """

    def process(self, *args, **kwargs):
        """Run tiled VAE encode or decode.

        Inputs may arrive positionally or by keyword depending on how the
        node is invoked, so both forms are supported.
        @return: one-tuple with the LATENT (encoder) or IMAGE (decoder)
        """
        samples = kwargs['samples'] if 'samples' in kwargs else (kwargs['pixels'] if 'pixels' in kwargs else args[0])
        _vae = kwargs['vae'] if 'vae' in kwargs else args[1]
        tile_size = kwargs['tile_size'] if 'tile_size' in kwargs else args[2]
        fast = kwargs['fast'] if 'fast' in kwargs else args[3]
        color_fix = kwargs['color_fix'] if 'color_fix' in kwargs else False
        is_decoder = self.is_decoder

        # for shorthand
        vae = _vae.first_stage_model
        encoder = vae.encoder
        decoder = vae.decoder

        # Undo a stale hijack first (e.g. if a previous run crashed before cleanup).
        if isinstance(encoder.forward, VAEHook):
            encoder.forward.net = None
            encoder.forward = encoder.original_forward
        if isinstance(decoder.forward, VAEHook):
            decoder.forward.net = None
            decoder.forward = decoder.original_forward

        # save original forward (only once)
        if not hasattr(encoder, 'original_forward'): setattr(encoder, 'original_forward', encoder.forward)
        if not hasattr(decoder, 'original_forward'): setattr(decoder, 'original_forward', decoder.forward)

        # Install the tiled forward. The decoder works in latent space, so
        # its tile size is divided by 8 to match.
        fn = VAEHook(net=decoder if is_decoder else encoder, tile_size=tile_size // 8 if is_decoder else tile_size,
                     is_decoder=is_decoder, fast_decoder=fast, fast_encoder=fast,
                     color_fix=color_fix, to_gpu=comfy.model_management.vae_device().type != 'cpu')
        if is_decoder:
            decoder.forward = fn
        else:
            encoder.forward = fn

        ret = (None,)
        try:
            with devices.without_autocast():
                # (fixed) the decode branch previously repeated the
                # `if is_decoder` test in a redundant nested conditional
                if is_decoder:
                    ret = VAEDecode().decode(_vae, samples)
                else:
                    ret = VAEEncode().encode(_vae, samples)
        finally:
            # Always unhook, even on error/interrupt, so the VAE is left clean.
            if isinstance(encoder.forward, VAEHook):
                encoder.forward.net = None
                encoder.forward = encoder.original_forward
            if isinstance(decoder.forward, VAEHook):
                decoder.forward.net = None
                decoder.forward = decoder.original_forward
        return ret
824
+
825
class VAEEncodeTiled_TiledDiffusion(TiledVAE):
    """ComfyUI node: tiled VAE encode (IMAGE -> LATENT)."""

    @classmethod
    def INPUT_TYPES(s):
        fast = True
        tile_size = get_rcmd_enc_tsize()
        # NOTE(review): color_fix default reuses `fast` (True) — looks like a
        # copy-paste; presumably a dedicated default was intended. Confirm.
        return {"required": {"pixels": ("IMAGE", ),
                             "vae": ("VAE", ),
                             "tile_size": ("INT", {"default": tile_size, "min": 256, "max": 4096, "step": 16}),
                             "fast": ("BOOLEAN", {"default": fast}),
                             "color_fix": ("BOOLEAN", {"default": fast}),
                             }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "process"
    CATEGORY = "_for_testing"

    def __init__(self):
        # TiledVAE.process() dispatches on this flag
        self.is_decoder = False
        super().__init__()
843
+
844
class VAEDecodeTiled_TiledDiffusion(TiledVAE):
    """ComfyUI node: tiled VAE decode (LATENT -> IMAGE)."""

    @classmethod
    def INPUT_TYPES(s):
        # tile size is expressed in image pixels; opt_f converts from latent units
        tile_size = get_rcmd_dec_tsize() * opt_f
        return {"required": {"samples": ("LATENT", ),
                             "vae": ("VAE", ),
                             "tile_size": ("INT", {"default": tile_size, "min": 48*opt_f, "max": 4096, "step": 16}),
                             "fast": ("BOOLEAN", {"default": True}),
                             }}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "_for_testing"

    def __init__(self):
        # TiledVAE.process() dispatches on this flag
        self.is_decoder = True
        super().__init__()
860
+
861
# ComfyUI node registration: internal name -> node class
NODE_CLASS_MAPPINGS = {
    "VAEEncodeTiled_TiledDiffusion": VAEEncodeTiled_TiledDiffusion,
    "VAEDecodeTiled_TiledDiffusion": VAEDecodeTiled_TiledDiffusion,
}
# internal name -> label shown in the ComfyUI node picker
NODE_DISPLAY_NAME_MAPPINGS = {
    "VAEEncodeTiled_TiledDiffusion": "Tiled VAE Encode",
    "VAEDecodeTiled_TiledDiffusion": "Tiled VAE Decode",
}
ComfyUI-TiledDiffusion/utils.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import importlib
3
+ from textwrap import dedent, indent
4
+ from copy import copy
5
+ import types
6
+ import functools
7
+ import os
8
+ import sys
9
+ import binascii
10
+ from typing import List, NamedTuple
11
+
12
class Hook(NamedTuple):
    """Description of one monkey-patch: a replacement callable plus where to install it."""
    fn: object              # replacement function to install
    module_name: str        # dotted module path, e.g. 'comfy.samplers'
    target: str             # attribute (possibly dotted) to overwrite in that module
    orig_key: str           # attribute name under which the original is stashed
    module_name_path: str   # module_name as an OS path fragment, used for suffix matching
18
+
19
def gen_id():
    """Return a random 8-character lowercase hex identifier.

    Equivalent to the former ``hexlify(os.urandom(1024))[64:72]`` slice —
    8 hex digits encode 4 random bytes — without wasting 1 KiB of entropy.
    """
    return os.urandom(4).hex()
21
+
22
def hook_calc_cond_uncond_batch():
    """Build a Hook that patches comfy's cond batching so that, when
    'tiled_diffusion' is active in model_options, control outputs are passed
    through as-is; otherwise the untouched original is called."""
    # newer comfy exposes calc_cond_batch; fall back to the legacy name
    try:
        from comfy.samplers import calc_cond_batch
        calc_cond_batch_ = calc_cond_batch
    except Exception:
        from comfy.samplers import calc_cond_uncond_batch
        calc_cond_batch_ = calc_cond_uncond_batch
    # this function should only be run by us
    orig_key = f"{calc_cond_batch_.__name__}_original_tiled_diffusion_{gen_id()}"
    payload = [{
        "mode": "replace",
        "target_line": 'control.get_control',
        "code_to_insert": """control if 'tiled_diffusion' in model_options else control.get_control"""
    },
    {
        "dedent": False,
        "target_line": calc_cond_batch_.__name__,
        # early-exit guard inserted right after the def line: delegate to the
        # stashed original whenever tiled diffusion is not in play
        "code_to_insert": f"""
    if 'tiled_diffusion' not in model_options:
        return {orig_key}{inspect.signature(calc_cond_batch_)}"""
    }]
    # 'w' truncates the patch file: this is the first patch written each load
    fn = inject_code(calc_cond_batch_, payload, 'w')
    return create_hook(fn, 'comfy.samplers', orig_key=orig_key)
45
+
46
def hook_sag_create_blur_map():
    """Build a Hook for comfy_extras' SAG blur-map helper so its mid_shape is
    derived from the closest factor pair of hw1 (robust to non-square latents
    produced by tiling). Returns None when comfy_extras is unavailable."""
    imported = False
    try:
        import comfy_extras
        from comfy_extras import nodes_sag
        imported = True
    except Exception: ...
    if not imported: return
    import comfy_extras
    from comfy_extras import nodes_sag
    import re
    source=inspect.getsource(nodes_sag.create_blur_map)
    # replacement body: factor hw1 into the pair (b, c) closest to a square,
    # orienting the larger factor along the longer latent side
    replace_str="""
    def calc_closest_factors(a):
        for b in range(int(math.sqrt(a)), 0, -1):
            if a%b == 0:
                c = a // b
                return (b,c)
    m = calc_closest_factors(hw1)
    mh = max(m) if lh > lw else min(m)
    mw = m[1] if mh == m[0] else m[0]
    mid_shape = mh, mw"""
    # NOTE(review): the inserted code uses `math`, which must already be
    # imported in nodes_sag's module namespace — confirm.
    modified_source = re.sub(r"ratio =.*\s+mid_shape =.*", replace_str, source, flags=re.MULTILINE)
    fn = write_to_file_and_return_fn(nodes_sag.create_blur_map, modified_source)
    return create_hook(fn, 'comfy_extras.nodes_sag')
71
+
72
def hook_samplers_pre_run_control():
    """Build a Hook for comfy.samplers.pre_run_control that (a) cleans up any
    lingering control object and (b) resets a tiled-diffusion
    model_function_wrapper found on the enclosing ModelPatcher before a run."""
    from comfy.samplers import pre_run_control
    payload = [{
        "dedent": False,
        "target_line": "if 'control' in x:",
        # best-effort cleanup; exceptions deliberately swallowed
        "code_to_insert": """ try: x['control'].cleanup()\n except Exception: ..."""
    },
    {
        "target_line": "s = model.model_sampling",
        # walk up the call stack (max 7 frames) looking for the ModelPatcher,
        # then reset its AbstractDiffusion wrapper if one is installed
        "code_to_insert": """
def find_outer_instance(target:str, target_type):
    import inspect
    frame = inspect.currentframe()
    i = 0
    while frame and i < 7:
        if (found:=frame.f_locals.get(target, None)) is not None:
            if isinstance(found, target_type):
                return found
        frame = frame.f_back
        i += 1
    return None
from comfy.model_patcher import ModelPatcher
if (_model:=find_outer_instance('model', ModelPatcher)) is not None:
    if (model_function_wrapper:=_model.model_options.get('model_function_wrapper', None)) is not None:
        import sys
        tiled_diffusion = sys.modules.get('ComfyUI-TiledDiffusion.tiled_diffusion', None)
        if tiled_diffusion is None:
            for key in sys.modules:
                if 'tiled_diffusion' in key:
                    tiled_diffusion = sys.modules[key]
                    break
        if (AbstractDiffusion:=getattr(tiled_diffusion, 'AbstractDiffusion', None)) is not None:
            if isinstance(model_function_wrapper, AbstractDiffusion):
                model_function_wrapper.reset()
"""}]
    fn = inject_code(pre_run_control, payload)
    return create_hook(fn, 'comfy.samplers')
109
+
110
def hook_gligen__set_position():
    """Build a Hook for comfy.gligen.Gligen._set_position that tiles `objs`
    along the batch dimension when the incoming batch is larger (as happens
    with tiled batching), instead of failing on the size mismatch."""
    from comfy.gligen import Gligen
    source=inspect.getsource(Gligen._set_position)
    # -(a // -b) is ceiling division: repeat objs enough times to cover x's batch
    replace_str="""
            nonlocal objs
            if x.shape[0] > objs.shape[0]:
                _objs = objs.repeat(-(x.shape[0] // -objs.shape[0]),1,1)
            else:
                _objs = objs
            return module(x, _objs)"""
    modified_source = dedent(source.replace(" return module(x, objs)", replace_str, 1))
    fn = write_to_file_and_return_fn(Gligen._set_position, modified_source)
    return create_hook(fn, 'comfy.gligen', 'Gligen._set_position')
123
+
124
def create_hook(fn, module_name:str, target = None, orig_key = None):
    """Bundle a replacement callable and its install location into a Hook.

    Defaults: *target* falls back to the function's own name, and *orig_key*
    to '<target>_original'. module_name_path is the dotted module path turned
    into a normalized OS path fragment for suffix matching in hook_all().
    """
    target = fn.__name__ if target is None else target
    orig_key = f'{target}_original' if orig_key is None else orig_key
    as_path = module_name.replace('.', '/')
    return Hook(fn, module_name, target, orig_key, os.path.normpath(as_path))
129
+
130
+ def _getattr(obj, name:str, default=None):
131
+ """multi-level getattr"""
132
+ for attr in name.split('.'):
133
+ obj = getattr(obj, attr, default)
134
+ return obj
135
+
136
def _hasattr(obj, name:str):
    """Multi-level hasattr: True iff the dotted path resolves to a non-None value."""
    resolved = _getattr(obj, name)
    return resolved is not None
139
+
140
+ def _setattr(obj, name:str, value=None):
141
+ """multi-level setattr"""
142
+ split = name.split('.')
143
+ if not split[:-1]:
144
+ return setattr(obj, name, value)
145
+ else:
146
+ name = split[-1]
147
+ for attr in split[:-1]:
148
+ obj = getattr(obj, attr, None)
149
+ return setattr(obj, name, value)
150
+
151
def hook_all(restore=False, hooks=None):
    """Install (or, with restore=True, uninstall) every monkey-patch.

    For each loaded module whose name matches a hook's module path, the
    original attribute is stashed once under hook.orig_key, then either the
    replacement (install) or the stashed original (restore) is assigned.

    @param restore: when True, put the stashed originals back
    @param hooks: explicit list of Hook objects; defaults to all known hooks
    """
    if hooks is None:
        hooks: List[Hook] = [
            hook_calc_cond_uncond_batch(),
            hook_sag_create_blur_map(),
            hook_samplers_pre_run_control(),
            hook_gligen__set_position(),
        ]
    # hook_sag_create_blur_map() returns None when comfy_extras is missing;
    # drop such entries so attribute access below cannot fail on None.
    hooks = [hook for hook in hooks if hook is not None]
    for key, module in sys.modules.items():
        for hook in hooks:
            # match the exact dotted name, or a path-suffix for modules
            # registered under a custom-node prefix
            if key == hook.module_name or key.endswith(hook.module_name_path):
                if _hasattr(module, hook.target):
                    # stash the original exactly once
                    if not _hasattr(module, hook.orig_key):
                        if (orig_fn:=_getattr(module, hook.target, None)) is not None:
                            _setattr(module, hook.orig_key, orig_fn)
                    if restore:
                        _setattr(module, hook.target, _getattr(module, hook.orig_key, None))
                    else:
                        _setattr(module, hook.target, hook.fn)
170
+
171
def inject_code(original_func, data, mode='a'):
    """Return a copy of *original_func* whose source has been edited.

    @param original_func: function whose source is fetched via inspect
    @param data: list of edit dicts with keys:
        'target_line'    — substring locating the anchor line
        'code_to_insert' — replacement text ('replace' mode) or code inserted
                           after the anchor ('insert' mode, the default)
        'mode'           — 'insert' (default) or 'replace'
        'dedent'         — dedent code_to_insert before re-indenting (default True)
    @param mode: file open mode for the patch file ('a' append / 'w' truncate)
    @return: the recompiled function (see write_to_file_and_return_fn)
    """
    # Get the source code of the original function
    original_source = inspect.getsource(original_func)

    # Split the source code into lines
    lines = original_source.split("\n")

    for item in data:
        # Find the line number of the target line
        # Only the FIRST matching line is edited; later matches are ignored.
        target_line_number = None
        for i, line in enumerate(lines):
            if item['target_line'] not in line: continue
            target_line_number = i + 1
            if item.get("mode","insert") == "replace":
                lines[i] = lines[i].replace(item['target_line'], item['code_to_insert'])
                break

            # Find the indentation of the line where the new code will be inserted
            indentation = ''
            for char in line:
                if char == ' ':
                    indentation += char
                else:
                    break

            # Indent the new code to match the original
            code_to_insert = item['code_to_insert']
            if item.get("dedent",True):
                code_to_insert = dedent(item['code_to_insert'])
            code_to_insert = indent(code_to_insert, indentation)

            break

        # Insert the code to be injected after the target line
        # (skipped when the anchor was never found, so code_to_insert is
        # only read on the path where it was assigned above)
        if item.get("mode","insert") == "insert" and target_line_number is not None:
            lines.insert(target_line_number, code_to_insert)

    # Recreate the modified source code
    modified_source = "\n".join(lines)
    modified_source = dedent(modified_source.strip("\n"))
    return write_to_file_and_return_fn(original_func, modified_source, mode)
212
+
213
def write_to_file_and_return_fn(original_func, source:str, mode='a'):
    """Compile *source* by writing it to a side-car module and importing it,
    then return a copy of *original_func* carrying the new code object.

    Keeping the text on disk (".patches.py") means the patched source and its
    stack traces remain inspectable while debugging. The returned function
    reuses original_func's __globals__/closure, so the new code runs in the
    original module's namespace.
    """
    # Write the modified source code to a temporary file so the
    # source code and stack traces can still be viewed when debugging.
    custom_name = ".patches.py"
    current_dir = os.path.dirname(os.path.abspath(__file__))
    temp_file_path = os.path.join(current_dir, custom_name)
    with open(temp_file_path, mode) as temp_file:
        temp_file.write(source)
        temp_file.write("\n")
        temp_file.flush()

        # import the patch file as a throwaway module to compile it
        MODULE_PATH = temp_file.name
        MODULE_NAME = __name__.split('.')[0].replace('-','_') + "_patch_modules"
        spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)
        module = importlib.util.module_from_spec(spec)
        sys.modules[spec.name] = module
        spec.loader.exec_module(module)

    # Retrieve the modified function from the module
    # NOTE(review): with mode 'a' the file accumulates every patch; this
    # presumably relies on each patched function name being unique — confirm.
    modified_function = getattr(module, original_func.__name__)

    # Adapted from https://stackoverflow.com/a/49077211
    def copy_func(f, globals=None, module=None, code=None, update_wrapper=True):
        # clone f, optionally swapping in a different code object/globals
        if globals is None: globals = f.__globals__
        if code is None: code = f.__code__
        g = types.FunctionType(code, globals, name=f.__name__,
                               argdefs=f.__defaults__, closure=f.__closure__)
        if update_wrapper: g = functools.update_wrapper(g, f)
        if module is not None: g.__module__ = module
        g.__kwdefaults__ = copy(f.__kwdefaults__)
        return g

    return copy_func(original_func, code=modified_function.__code__, update_wrapper=False)
246
+