Upload 17 files
Browse files- ComfyUI-TiledDiffusion/.gitignore +167 -0
- ComfyUI-TiledDiffusion/.patches.py +189 -0
- ComfyUI-TiledDiffusion/README.md +112 -0
- ComfyUI-TiledDiffusion/__init__.py +9 -0
- ComfyUI-TiledDiffusion/__pycache__/.patches.cpython-310.pyc +0 -0
- ComfyUI-TiledDiffusion/__pycache__/.patches.cpython-311.pyc +0 -0
- ComfyUI-TiledDiffusion/__pycache__/__init__.cpython-310.pyc +0 -0
- ComfyUI-TiledDiffusion/__pycache__/__init__.cpython-311.pyc +0 -0
- ComfyUI-TiledDiffusion/__pycache__/tiled_diffusion.cpython-310.pyc +0 -0
- ComfyUI-TiledDiffusion/__pycache__/tiled_diffusion.cpython-311.pyc +0 -0
- ComfyUI-TiledDiffusion/__pycache__/tiled_vae.cpython-310.pyc +0 -0
- ComfyUI-TiledDiffusion/__pycache__/tiled_vae.cpython-311.pyc +0 -0
- ComfyUI-TiledDiffusion/__pycache__/utils.cpython-310.pyc +0 -0
- ComfyUI-TiledDiffusion/__pycache__/utils.cpython-311.pyc +0 -0
- ComfyUI-TiledDiffusion/tiled_diffusion.py +650 -0
- ComfyUI-TiledDiffusion/tiled_vae.py +868 -0
- ComfyUI-TiledDiffusion/utils.py +246 -0
ComfyUI-TiledDiffusion/.gitignore
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
env/
|
| 126 |
+
venv/
|
| 127 |
+
ENV/
|
| 128 |
+
env.bak/
|
| 129 |
+
venv.bak/
|
| 130 |
+
|
| 131 |
+
# Spyder project settings
|
| 132 |
+
.spyderproject
|
| 133 |
+
.spyproject
|
| 134 |
+
|
| 135 |
+
# Rope project settings
|
| 136 |
+
.ropeproject
|
| 137 |
+
|
| 138 |
+
# mkdocs documentation
|
| 139 |
+
/site
|
| 140 |
+
|
| 141 |
+
# mypy
|
| 142 |
+
.mypy_cache/
|
| 143 |
+
.dmypy.json
|
| 144 |
+
dmypy.json
|
| 145 |
+
|
| 146 |
+
# Pyre type checker
|
| 147 |
+
.pyre/
|
| 148 |
+
|
| 149 |
+
# pytype static type analyzer
|
| 150 |
+
.pytype/
|
| 151 |
+
|
| 152 |
+
# Cython debug symbols
|
| 153 |
+
cython_debug/
|
| 154 |
+
|
| 155 |
+
# PyCharm
|
| 156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 160 |
+
#.idea/
|
| 161 |
+
|
| 162 |
+
backup*
|
| 163 |
+
**/.DS_Store
|
| 164 |
+
**/.venv
|
| 165 |
+
**/.vscode
|
| 166 |
+
|
| 167 |
+
.*
|
ComfyUI-TiledDiffusion/.patches.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
| 2 |
+
|
| 3 |
+
if 'tiled_diffusion' not in model_options:
|
| 4 |
+
return calc_cond_batch_original_tiled_diffusion_875b8c8d(model, conds, x_in, timestep, model_options)
|
| 5 |
+
out_conds = []
|
| 6 |
+
out_counts = []
|
| 7 |
+
to_run = []
|
| 8 |
+
|
| 9 |
+
for i in range(len(conds)):
|
| 10 |
+
out_conds.append(torch.zeros_like(x_in))
|
| 11 |
+
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
| 12 |
+
|
| 13 |
+
cond = conds[i]
|
| 14 |
+
if cond is not None:
|
| 15 |
+
for x in cond:
|
| 16 |
+
p = get_area_and_mult(x, x_in, timestep)
|
| 17 |
+
if p is None:
|
| 18 |
+
continue
|
| 19 |
+
|
| 20 |
+
to_run += [(p, i)]
|
| 21 |
+
|
| 22 |
+
while len(to_run) > 0:
|
| 23 |
+
first = to_run[0]
|
| 24 |
+
first_shape = first[0][0].shape
|
| 25 |
+
to_batch_temp = []
|
| 26 |
+
for x in range(len(to_run)):
|
| 27 |
+
if can_concat_cond(to_run[x][0], first[0]):
|
| 28 |
+
to_batch_temp += [x]
|
| 29 |
+
|
| 30 |
+
to_batch_temp.reverse()
|
| 31 |
+
to_batch = to_batch_temp[:1]
|
| 32 |
+
|
| 33 |
+
free_memory = model_management.get_free_memory(x_in.device)
|
| 34 |
+
for i in range(1, len(to_batch_temp) + 1):
|
| 35 |
+
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
| 36 |
+
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
| 37 |
+
if model.memory_required(input_shape) * 1.5 < free_memory:
|
| 38 |
+
to_batch = batch_amount
|
| 39 |
+
break
|
| 40 |
+
|
| 41 |
+
input_x = []
|
| 42 |
+
mult = []
|
| 43 |
+
c = []
|
| 44 |
+
cond_or_uncond = []
|
| 45 |
+
area = []
|
| 46 |
+
control = None
|
| 47 |
+
patches = None
|
| 48 |
+
for x in to_batch:
|
| 49 |
+
o = to_run.pop(x)
|
| 50 |
+
p = o[0]
|
| 51 |
+
input_x.append(p.input_x)
|
| 52 |
+
mult.append(p.mult)
|
| 53 |
+
c.append(p.conditioning)
|
| 54 |
+
area.append(p.area)
|
| 55 |
+
cond_or_uncond.append(o[1])
|
| 56 |
+
control = p.control
|
| 57 |
+
patches = p.patches
|
| 58 |
+
|
| 59 |
+
batch_chunks = len(cond_or_uncond)
|
| 60 |
+
input_x = torch.cat(input_x)
|
| 61 |
+
c = cond_cat(c)
|
| 62 |
+
timestep_ = torch.cat([timestep] * batch_chunks)
|
| 63 |
+
|
| 64 |
+
if control is not None:
|
| 65 |
+
c['control'] = control if 'tiled_diffusion' in model_options else control.get_control(input_x, timestep_, c, len(cond_or_uncond))
|
| 66 |
+
|
| 67 |
+
transformer_options = {}
|
| 68 |
+
if 'transformer_options' in model_options:
|
| 69 |
+
transformer_options = model_options['transformer_options'].copy()
|
| 70 |
+
|
| 71 |
+
if patches is not None:
|
| 72 |
+
if "patches" in transformer_options:
|
| 73 |
+
cur_patches = transformer_options["patches"].copy()
|
| 74 |
+
for p in patches:
|
| 75 |
+
if p in cur_patches:
|
| 76 |
+
cur_patches[p] = cur_patches[p] + patches[p]
|
| 77 |
+
else:
|
| 78 |
+
cur_patches[p] = patches[p]
|
| 79 |
+
transformer_options["patches"] = cur_patches
|
| 80 |
+
else:
|
| 81 |
+
transformer_options["patches"] = patches
|
| 82 |
+
|
| 83 |
+
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
| 84 |
+
transformer_options["sigmas"] = timestep
|
| 85 |
+
|
| 86 |
+
c['transformer_options'] = transformer_options
|
| 87 |
+
|
| 88 |
+
if 'model_function_wrapper' in model_options:
|
| 89 |
+
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
| 90 |
+
else:
|
| 91 |
+
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
| 92 |
+
|
| 93 |
+
for o in range(batch_chunks):
|
| 94 |
+
cond_index = cond_or_uncond[o]
|
| 95 |
+
a = area[o]
|
| 96 |
+
if a is None:
|
| 97 |
+
out_conds[cond_index] += output[o] * mult[o]
|
| 98 |
+
out_counts[cond_index] += mult[o]
|
| 99 |
+
else:
|
| 100 |
+
out_c = out_conds[cond_index]
|
| 101 |
+
out_cts = out_counts[cond_index]
|
| 102 |
+
dims = len(a) // 2
|
| 103 |
+
for i in range(dims):
|
| 104 |
+
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
| 105 |
+
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
| 106 |
+
out_c += output[o] * mult[o]
|
| 107 |
+
out_cts += mult[o]
|
| 108 |
+
|
| 109 |
+
for i in range(len(out_conds)):
|
| 110 |
+
out_conds[i] /= out_counts[i]
|
| 111 |
+
|
| 112 |
+
return out_conds
|
| 113 |
+
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
| 114 |
+
# reshape and GAP the attention map
|
| 115 |
+
_, hw1, hw2 = attn.shape
|
| 116 |
+
b, _, lh, lw = x0.shape
|
| 117 |
+
attn = attn.reshape(b, -1, hw1, hw2)
|
| 118 |
+
# Global Average Pool
|
| 119 |
+
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
| 120 |
+
|
| 121 |
+
def calc_closest_factors(a):
|
| 122 |
+
for b in range(int(math.sqrt(a)), 0, -1):
|
| 123 |
+
if a%b == 0:
|
| 124 |
+
c = a // b
|
| 125 |
+
return (b,c)
|
| 126 |
+
m = calc_closest_factors(hw1)
|
| 127 |
+
mh = max(m) if lh > lw else min(m)
|
| 128 |
+
mw = m[1] if mh == m[0] else m[0]
|
| 129 |
+
mid_shape = mh, mw
|
| 130 |
+
|
| 131 |
+
# Reshape
|
| 132 |
+
mask = (
|
| 133 |
+
mask.reshape(b, *mid_shape)
|
| 134 |
+
.unsqueeze(1)
|
| 135 |
+
.type(attn.dtype)
|
| 136 |
+
)
|
| 137 |
+
# Upsample
|
| 138 |
+
mask = F.interpolate(mask, (lh, lw))
|
| 139 |
+
|
| 140 |
+
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
|
| 141 |
+
blurred = blurred * mask + x0 * (1 - mask)
|
| 142 |
+
return blurred
|
| 143 |
+
|
| 144 |
+
def pre_run_control(model, conds):
|
| 145 |
+
s = model.model_sampling
|
| 146 |
+
|
| 147 |
+
def find_outer_instance(target:str, target_type):
|
| 148 |
+
import inspect
|
| 149 |
+
frame = inspect.currentframe()
|
| 150 |
+
i = 0
|
| 151 |
+
while frame and i < 7:
|
| 152 |
+
if (found:=frame.f_locals.get(target, None)) is not None:
|
| 153 |
+
if isinstance(found, target_type):
|
| 154 |
+
return found
|
| 155 |
+
frame = frame.f_back
|
| 156 |
+
i += 1
|
| 157 |
+
return None
|
| 158 |
+
from comfy.model_patcher import ModelPatcher
|
| 159 |
+
if (_model:=find_outer_instance('model', ModelPatcher)) is not None:
|
| 160 |
+
if (model_function_wrapper:=_model.model_options.get('model_function_wrapper', None)) is not None:
|
| 161 |
+
import sys
|
| 162 |
+
tiled_diffusion = sys.modules.get('ComfyUI-TiledDiffusion.tiled_diffusion', None)
|
| 163 |
+
if tiled_diffusion is None:
|
| 164 |
+
for key in sys.modules:
|
| 165 |
+
if 'tiled_diffusion' in key:
|
| 166 |
+
tiled_diffusion = sys.modules[key]
|
| 167 |
+
break
|
| 168 |
+
if (AbstractDiffusion:=getattr(tiled_diffusion, 'AbstractDiffusion', None)) is not None:
|
| 169 |
+
if isinstance(model_function_wrapper, AbstractDiffusion):
|
| 170 |
+
model_function_wrapper.reset()
|
| 171 |
+
|
| 172 |
+
for t in range(len(conds)):
|
| 173 |
+
x = conds[t]
|
| 174 |
+
|
| 175 |
+
timestep_start = None
|
| 176 |
+
timestep_end = None
|
| 177 |
+
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
| 178 |
+
if 'control' in x:
|
| 179 |
+
try: x['control'].cleanup()
|
| 180 |
+
except Exception: ...
|
| 181 |
+
x['control'].pre_run(model, percent_to_timestep_function)
|
| 182 |
+
def _set_position(self, boxes, masks, positive_embeddings):
|
| 183 |
+
objs = self.position_net(boxes, masks, positive_embeddings)
|
| 184 |
+
def func(x, extra_options):
|
| 185 |
+
key = extra_options["transformer_index"]
|
| 186 |
+
module = self.module_list[key]
|
| 187 |
+
return module(x, objs.to(device=x.device, dtype=x.dtype))
|
| 188 |
+
return func
|
| 189 |
+
|
ComfyUI-TiledDiffusion/README.md
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Tiled Diffusion & VAE for ComfyUI
|
| 2 |
+
|
| 3 |
+
Check out the [SD-WebUI extension](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/) for more information.
|
| 4 |
+
|
| 5 |
+
This extension enables **large image drawing & upscaling with limited VRAM** via the following techniques:
|
| 6 |
+
|
| 7 |
+
1. Two SOTA diffusion tiling algorithms: [Mixture of Diffusers](https://github.com/albarji/mixture-of-diffusers) <a href="https://arxiv.org/abs/2302.02412"><img width="32" alt="Mixture of Diffusers Paper" src="https://github.com/shiimizu/ComfyUI-TiledDiffusion/assets/54494639/b753b7f6-f9c0-405d-bace-792b9bbce5d5"></a> and [MultiDiffusion](https://github.com/omerbt/MultiDiffusion) <a href="https://arxiv.org/abs/2302.08113"><img width="32" alt="MultiDiffusion Paper" src="https://github.com/shiimizu/ComfyUI-TiledDiffusion/assets/54494639/b753b7f6-f9c0-405d-bace-792b9bbce5d5"></a>
|
| 8 |
+
2. pkuliyi2015 & Kahsolt's Tiled VAE algorithm.
|
| 9 |
+
3. ~~pkuliyi2015 & Kahsolt's TIled Noise Inversion for better upscaling.~~
|
| 10 |
+
|
| 11 |
+
> [!NOTE]
|
| 12 |
+
> Sizes/dimensions are in pixels and then converted to latent-space sizes.
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
## Features
|
| 16 |
+
- [x] SDXL model support
|
| 17 |
+
- [x] ControlNet support
|
| 18 |
+
- [ ] ~~StableSR support~~
|
| 19 |
+
- [ ] ~~Tiled Noise Inversion~~
|
| 20 |
+
- [x] Tiled VAE
|
| 21 |
+
- [ ] Regional Prompt Control
|
| 22 |
+
- [x] Img2img upscale
|
| 23 |
+
- [x] Ultra-Large image generation
|
| 24 |
+
|
| 25 |
+
## Tiled Diffusion
|
| 26 |
+
|
| 27 |
+
<div align="center">
|
| 28 |
+
<img width="500" alt="Tiled_Diffusion" src="https://github.com/shiimizu/ComfyUI-TiledDiffusion/assets/54494639/7cb897a3-a645-426f-8742-d6ba5cf04b64">
|
| 29 |
+
</div>
|
| 30 |
+
|
| 31 |
+
> [!TIP]
|
| 32 |
+
> Set `tile_overlap` to 0 and `denoise` to 1 to see the tile seams and then adjust the options to your needs. Also, increase `tile_batch_size` to increase speed (if your machine can handle it).
|
| 33 |
+
|
| 34 |
+
| Name | Description |
|
| 35 |
+
|-------------------|--------------------------------------------------------------|
|
| 36 |
+
| `method` | Tiling [strategy](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/blob/fbb24736c9bc374c7f098f82b575fcd14a73936a/scripts/tilediffusion.py#L39-L46). `MultiDiffusion` or `Mixture of Diffusers`. |
|
| 37 |
+
| `tile_width` | Tile's width |
|
| 38 |
+
| `tile_height` | Tile's height |
|
| 39 |
+
| `tile_overlap` | Tile's overlap |
|
| 40 |
+
| `tile_batch_size` | The number of tiles to process in a batch |
|
| 41 |
+
|
| 42 |
+
### How can I specify the tiles' arrangement?
|
| 43 |
+
|
| 44 |
+
If you have the [Math Expression](https://github.com/pythongosssss/ComfyUI-Custom-Scripts#math-expression) node (or something similar), you can use that to pass in the latent that's passed in your KSampler and divide the `tile_height`/`tile_width` by the number of rows/columns you want.
|
| 45 |
+
|
| 46 |
+
`C` = number of columns you want
|
| 47 |
+
`R` = number of rows you want
|
| 48 |
+
|
| 49 |
+
`pixel width of input image or latent // C` = `tile_width`
|
| 50 |
+
`pixel height of input image or latent // R` = `tile_height`
|
| 51 |
+
|
| 52 |
+
<img width="800" alt="Tile_arrangement" src="https://github.com/shiimizu/ComfyUI-TiledDiffusion/assets/54494639/9952e7d8-909e-436f-a284-c00f0fb71665">
|
| 53 |
+
|
| 54 |
+
## Tiled VAE
|
| 55 |
+
|
| 56 |
+
<div align="center">
|
| 57 |
+
<img width="900" alt="Tiled_VAE" src="https://github.com/shiimizu/ComfyUI-TiledDiffusion/assets/54494639/b5850e03-2cac-49ce-b1fe-a67906bf4c9d">
|
| 58 |
+
</div>
|
| 59 |
+
|
| 60 |
+
<br>
|
| 61 |
+
|
| 62 |
+
The recommended tile sizes are given upon the creation of the node based on the available VRAM.
|
| 63 |
+
|
| 64 |
+
> [!NOTE]
|
| 65 |
+
> Enabling `fast` for the decoder may produce images with slightly higher contrast and brightness.
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
| Name | Description |
|
| 69 |
+
|-------------|----------------------------------------------------------------------------------------------------------------------------------------------|
|
| 70 |
+
| `tile_size` | <blockquote>The image is split into tiles, which are then padded with 11/32 pixels' in the decoder/encoder.</blockquote> |
|
| 71 |
+
| `fast` | <blockquote><p>When Fast Mode is disabled:</p> <ol> <li>The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile.</li> <li>When GroupNorm is needed, it suspends, stores current GroupNorm mean and var, send everything to RAM, and turns to the next tile.</li> <li>After all GroupNorm means and vars are summarized, it applies group norm to tiles and continues. </li> <li>A zigzag execution order is used to reduce unnecessary data transfer.</li> </ol> <p>When Fast Mode is enabled:</p> <ol> <li>The original input is downsampled and passed to a separate task queue.</li> <li>Its group norm parameters are recorded and used by all tiles' task queues.</li> <li>Each tile is separately processed without any RAM-VRAM data transfer.</li> </ol> <p>After all tiles are processed, tiles are written to a result buffer and returned.</p></blockquote> |
|
| 72 |
+
| `color_fix` | <blockquote>Only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode.</blockquote><p>Only for the encoder. Can restore colors if tiles are too small.</p> |
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
## Workflows
|
| 77 |
+
|
| 78 |
+
The following images can be loaded in ComfyUI.
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
<div align="center">
|
| 82 |
+
<img alt="ComfyUI_07501_" src="https://github.com/shiimizu/ComfyUI-TiledDiffusion/assets/54494639/c3713cfb-e083-4df4-a310-9467827ee666">
|
| 83 |
+
<p>Simple upscale.</p>
|
| 84 |
+
</div>
|
| 85 |
+
|
| 86 |
+
<br>
|
| 87 |
+
|
| 88 |
+
<div align="center">
|
| 89 |
+
|
| 90 |
+
<img alt="ComfyUI_07503_" src="https://github.com/shiimizu/ComfyUI-TiledDiffusion/assets/54494639/b681b617-4bb1-49e5-b85a-ef5a0f6e4830">
|
| 91 |
+
<p>4x upscale. 3 passes.</p>
|
| 92 |
+
</div>
|
| 93 |
+
|
| 94 |
+
## Citation
|
| 95 |
+
|
| 96 |
+
```bibtex
|
| 97 |
+
@article{jimenez2023mixtureofdiffusers,
|
| 98 |
+
title={Mixture of Diffusers for scene composition and high resolution image generation},
|
| 99 |
+
author={Álvaro Barbero Jiménez},
|
| 100 |
+
journal={arXiv preprint arXiv:2302.02412},
|
| 101 |
+
year={2023}
|
| 102 |
+
}
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
```bibtex
|
| 106 |
+
@article{bar2023multidiffusion,
|
| 107 |
+
title={MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation},
|
| 108 |
+
author={Bar-Tal, Omer and Yariv, Lior and Lipman, Yaron and Dekel, Tali},
|
| 109 |
+
journal={arXiv preprint arXiv:2302.08113},
|
| 110 |
+
year={2023}
|
| 111 |
+
}
|
| 112 |
+
```
|
ComfyUI-TiledDiffusion/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .tiled_diffusion import NODE_CLASS_MAPPINGS as TD_NCM, NODE_DISPLAY_NAME_MAPPINGS as TD_NDCM
|
| 2 |
+
from .tiled_vae import NODE_CLASS_MAPPINGS as TV_NCM, NODE_DISPLAY_NAME_MAPPINGS as TV_NDCM
|
| 3 |
+
NODE_CLASS_MAPPINGS = {}
|
| 4 |
+
NODE_DISPLAY_NAME_MAPPINGS = {}
|
| 5 |
+
NODE_CLASS_MAPPINGS.update(TD_NCM)
|
| 6 |
+
NODE_DISPLAY_NAME_MAPPINGS.update(TD_NDCM)
|
| 7 |
+
NODE_CLASS_MAPPINGS.update(TV_NCM)
|
| 8 |
+
NODE_DISPLAY_NAME_MAPPINGS.update(TV_NDCM)
|
| 9 |
+
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
|
ComfyUI-TiledDiffusion/__pycache__/.patches.cpython-310.pyc
ADDED
|
Binary file (5 kB). View file
|
|
|
ComfyUI-TiledDiffusion/__pycache__/.patches.cpython-311.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
ComfyUI-TiledDiffusion/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (431 Bytes). View file
|
|
|
ComfyUI-TiledDiffusion/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (717 Bytes). View file
|
|
|
ComfyUI-TiledDiffusion/__pycache__/tiled_diffusion.cpython-310.pyc
ADDED
|
Binary file (19 kB). View file
|
|
|
ComfyUI-TiledDiffusion/__pycache__/tiled_diffusion.cpython-311.pyc
ADDED
|
Binary file (37 kB). View file
|
|
|
ComfyUI-TiledDiffusion/__pycache__/tiled_vae.cpython-310.pyc
ADDED
|
Binary file (24.9 kB). View file
|
|
|
ComfyUI-TiledDiffusion/__pycache__/tiled_vae.cpython-311.pyc
ADDED
|
Binary file (45.7 kB). View file
|
|
|
ComfyUI-TiledDiffusion/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (7.83 kB). View file
|
|
|
ComfyUI-TiledDiffusion/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (12.9 kB). View file
|
|
|
ComfyUI-TiledDiffusion/tiled_diffusion.py
ADDED
|
@@ -0,0 +1,650 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import division
|
| 2 |
+
import torch
|
| 3 |
+
from torch import Tensor
|
| 4 |
+
import comfy.model_management
|
| 5 |
+
from comfy.model_patcher import ModelPatcher
|
| 6 |
+
import comfy.model_patcher
|
| 7 |
+
from comfy.model_base import BaseModel
|
| 8 |
+
from typing import List, Union, Tuple, Dict
|
| 9 |
+
from nodes import ImageScale
|
| 10 |
+
import comfy.utils
|
| 11 |
+
from comfy.controlnet import ControlNet, T2IAdapter
|
| 12 |
+
|
| 13 |
+
opt_C = 4
|
| 14 |
+
opt_f = 8
|
| 15 |
+
|
| 16 |
+
def ceildiv(big, small):
|
| 17 |
+
# Correct ceiling division that avoids floating-point errors and importing math.ceil.
|
| 18 |
+
return -(big // -small)
|
| 19 |
+
|
| 20 |
+
from enum import Enum
|
| 21 |
+
class BlendMode(Enum): # i.e. LayerType
|
| 22 |
+
FOREGROUND = 'Foreground'
|
| 23 |
+
BACKGROUND = 'Background'
|
| 24 |
+
|
| 25 |
+
class Processing: ...
|
| 26 |
+
class Device: ...
|
| 27 |
+
devices = Device()
|
| 28 |
+
devices.device = comfy.model_management.get_torch_device()
|
| 29 |
+
|
| 30 |
+
def null_decorator(fn):
|
| 31 |
+
def wrapper(*args, **kwargs):
|
| 32 |
+
return fn(*args, **kwargs)
|
| 33 |
+
return wrapper
|
| 34 |
+
|
| 35 |
+
keep_signature = null_decorator
|
| 36 |
+
controlnet = null_decorator
|
| 37 |
+
stablesr = null_decorator
|
| 38 |
+
grid_bbox = null_decorator
|
| 39 |
+
custom_bbox = null_decorator
|
| 40 |
+
noise_inverse = null_decorator
|
| 41 |
+
|
| 42 |
+
class BBox:
|
| 43 |
+
''' grid bbox '''
|
| 44 |
+
|
| 45 |
+
def __init__(self, x:int, y:int, w:int, h:int):
|
| 46 |
+
self.x = x
|
| 47 |
+
self.y = y
|
| 48 |
+
self.w = w
|
| 49 |
+
self.h = h
|
| 50 |
+
self.box = [x, y, x+w, y+h]
|
| 51 |
+
self.slicer = slice(None), slice(None), slice(y, y+h), slice(x, x+w)
|
| 52 |
+
|
| 53 |
+
def __getitem__(self, idx:int) -> int:
|
| 54 |
+
return self.box[idx]
|
| 55 |
+
|
| 56 |
+
def split_bboxes(w:int, h:int, tile_w:int, tile_h:int, overlap:int=16, init_weight:Union[Tensor, float]=1.0) -> Tuple[List[BBox], Tensor]:
|
| 57 |
+
cols = ceildiv((w - overlap) , (tile_w - overlap))
|
| 58 |
+
rows = ceildiv((h - overlap) , (tile_h - overlap))
|
| 59 |
+
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
|
| 60 |
+
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
|
| 61 |
+
|
| 62 |
+
bbox_list: List[BBox] = []
|
| 63 |
+
weight = torch.zeros((1, 1, h, w), device=devices.device, dtype=torch.float32)
|
| 64 |
+
for row in range(rows):
|
| 65 |
+
y = min(int(row * dy), h - tile_h)
|
| 66 |
+
for col in range(cols):
|
| 67 |
+
x = min(int(col * dx), w - tile_w)
|
| 68 |
+
|
| 69 |
+
bbox = BBox(x, y, tile_w, tile_h)
|
| 70 |
+
bbox_list.append(bbox)
|
| 71 |
+
weight[bbox.slicer] += init_weight
|
| 72 |
+
|
| 73 |
+
return bbox_list, weight
|
| 74 |
+
|
| 75 |
+
class CustomBBox(BBox):
|
| 76 |
+
''' region control bbox '''
|
| 77 |
+
pass
|
| 78 |
+
|
| 79 |
+
class AbstractDiffusion:
|
| 80 |
+
def __init__(self):
|
| 81 |
+
self.method = self.__class__.__name__
|
| 82 |
+
self.pbar = None
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
self.w: int = 0
|
| 86 |
+
self.h: int = 0
|
| 87 |
+
self.tile_width: int = None
|
| 88 |
+
self.tile_height: int = None
|
| 89 |
+
self.tile_overlap: int = None
|
| 90 |
+
self.tile_batch_size: int = None
|
| 91 |
+
|
| 92 |
+
# cache. final result of current sampling step, [B, C=4, H//8, W//8]
|
| 93 |
+
# avoiding overhead of creating new tensors and weight summing
|
| 94 |
+
self.x_buffer: Tensor = None
|
| 95 |
+
# self.w: int = int(self.p.width // opt_f) # latent size
|
| 96 |
+
# self.h: int = int(self.p.height // opt_f)
|
| 97 |
+
# weights for background & grid bboxes
|
| 98 |
+
self._weights: Tensor = None
|
| 99 |
+
# self.weights: Tensor = torch.zeros((1, 1, self.h, self.w), device=devices.device, dtype=torch.float32)
|
| 100 |
+
self._init_grid_bbox = None
|
| 101 |
+
self._init_done = None
|
| 102 |
+
|
| 103 |
+
# count the step correctly
|
| 104 |
+
self.step_count = 0
|
| 105 |
+
self.inner_loop_count = 0
|
| 106 |
+
self.kdiff_step = -1
|
| 107 |
+
|
| 108 |
+
# ext. Grid tiling painting (grid bbox)
|
| 109 |
+
self.enable_grid_bbox: bool = False
|
| 110 |
+
self.tile_w: int = None
|
| 111 |
+
self.tile_h: int = None
|
| 112 |
+
self.tile_bs: int = None
|
| 113 |
+
self.num_tiles: int = None
|
| 114 |
+
self.num_batches: int = None
|
| 115 |
+
self.batched_bboxes: List[List[BBox]] = []
|
| 116 |
+
|
| 117 |
+
# ext. Region Prompt Control (custom bbox)
|
| 118 |
+
self.enable_custom_bbox: bool = False
|
| 119 |
+
self.custom_bboxes: List[CustomBBox] = []
|
| 120 |
+
# self.cond_basis: Cond = None
|
| 121 |
+
# self.uncond_basis: Uncond = None
|
| 122 |
+
# self.draw_background: bool = True # by default we draw major prompts in grid tiles
|
| 123 |
+
# self.causal_layers: bool = None
|
| 124 |
+
|
| 125 |
+
# ext. ControlNet
|
| 126 |
+
self.enable_controlnet: bool = False
|
| 127 |
+
# self.controlnet_script: ModuleType = None
|
| 128 |
+
self.control_tensor_batch_dict = {}
|
| 129 |
+
self.control_tensor_batch: List[List[Tensor]] = [[]]
|
| 130 |
+
# self.control_params: Dict[str, Tensor] = None # {}
|
| 131 |
+
self.control_params: Dict[Tuple, List[List[Tensor]]] = {}
|
| 132 |
+
self.control_tensor_cpu: bool = None
|
| 133 |
+
self.control_tensor_custom: List[List[Tensor]] = []
|
| 134 |
+
|
| 135 |
+
self.draw_background: bool = True # by default we draw major prompts in grid tiles
|
| 136 |
+
self.control_tensor_cpu = False
|
| 137 |
+
self.weights = None
|
| 138 |
+
self.imagescale = ImageScale()
|
| 139 |
+
|
| 140 |
+
def reset(self):
|
| 141 |
+
tile_width = self.tile_width
|
| 142 |
+
tile_height = self.tile_height
|
| 143 |
+
tile_overlap = self.tile_overlap
|
| 144 |
+
tile_batch_size = self.tile_batch_size
|
| 145 |
+
self.__init__()
|
| 146 |
+
self.tile_width = tile_width
|
| 147 |
+
self.tile_height = tile_height
|
| 148 |
+
self.tile_overlap = tile_overlap
|
| 149 |
+
self.tile_batch_size = tile_batch_size
|
| 150 |
+
|
| 151 |
+
def repeat_tensor(self, x:Tensor, n:int, concat=False, concat_to=0) -> Tensor:
|
| 152 |
+
''' repeat the tensor on it's first dim '''
|
| 153 |
+
if n == 1: return x
|
| 154 |
+
B = x.shape[0]
|
| 155 |
+
r_dims = len(x.shape) - 1
|
| 156 |
+
if B == 1: # batch_size = 1 (not `tile_batch_size`)
|
| 157 |
+
shape = [n] + [-1] * r_dims # [N, -1, ...]
|
| 158 |
+
return x.expand(shape) # `expand` is much lighter than `tile`
|
| 159 |
+
else:
|
| 160 |
+
if concat:
|
| 161 |
+
return torch.cat([x for _ in range(n)], dim=0)[:concat_to]
|
| 162 |
+
shape = [n] + [1] * r_dims # [N, 1, ...]
|
| 163 |
+
return x.repeat(shape)
|
| 164 |
+
def update_pbar(self):
|
| 165 |
+
if self.pbar.n >= self.pbar.total:
|
| 166 |
+
self.pbar.close()
|
| 167 |
+
else:
|
| 168 |
+
# self.pbar.update()
|
| 169 |
+
sampling_step = 20
|
| 170 |
+
if self.step_count == sampling_step:
|
| 171 |
+
self.inner_loop_count += 1
|
| 172 |
+
if self.inner_loop_count < self.total_bboxes:
|
| 173 |
+
self.pbar.update()
|
| 174 |
+
else:
|
| 175 |
+
self.step_count = sampling_step
|
| 176 |
+
self.inner_loop_count = 0
|
| 177 |
+
def reset_buffer(self, x_in:Tensor):
|
| 178 |
+
# Judge if the shape of x_in is the same as the shape of x_buffer
|
| 179 |
+
if self.x_buffer is None or self.x_buffer.shape != x_in.shape:
|
| 180 |
+
self.x_buffer = torch.zeros_like(x_in, device=x_in.device, dtype=x_in.dtype)
|
| 181 |
+
else:
|
| 182 |
+
self.x_buffer.zero_()
|
| 183 |
+
|
| 184 |
+
@grid_bbox
|
| 185 |
+
def init_grid_bbox(self, tile_w:int, tile_h:int, overlap:int, tile_bs:int):
|
| 186 |
+
# if self._init_grid_bbox is not None: return
|
| 187 |
+
# self._init_grid_bbox = True
|
| 188 |
+
self.weights = torch.zeros((1, 1, self.h, self.w), device=devices.device, dtype=torch.float32)
|
| 189 |
+
self.enable_grid_bbox = True
|
| 190 |
+
|
| 191 |
+
self.tile_w = min(tile_w, self.w)
|
| 192 |
+
self.tile_h = min(tile_h, self.h)
|
| 193 |
+
overlap = max(0, min(overlap, min(tile_w, tile_h) - 4))
|
| 194 |
+
# split the latent into overlapped tiles, then batching
|
| 195 |
+
# weights basically indicate how many times a pixel is painted
|
| 196 |
+
bboxes, weights = split_bboxes(self.w, self.h, self.tile_w, self.tile_h, overlap, self.get_tile_weights())
|
| 197 |
+
self.weights += weights
|
| 198 |
+
self.num_tiles = len(bboxes)
|
| 199 |
+
self.num_batches = ceildiv(self.num_tiles , tile_bs)
|
| 200 |
+
self.tile_bs = ceildiv(len(bboxes) , self.num_batches) # optimal_batch_size
|
| 201 |
+
self.batched_bboxes = [bboxes[i*self.tile_bs:(i+1)*self.tile_bs] for i in range(self.num_batches)]
|
| 202 |
+
|
| 203 |
+
@grid_bbox
|
| 204 |
+
def get_tile_weights(self) -> Union[Tensor, float]:
|
| 205 |
+
return 1.0
|
| 206 |
+
|
| 207 |
+
@noise_inverse
|
| 208 |
+
def init_noise_inverse(self, steps:int, retouch:float, get_cache_callback, set_cache_callback, renoise_strength:float, renoise_kernel:int):
|
| 209 |
+
self.noise_inverse_enabled = True
|
| 210 |
+
self.noise_inverse_steps = steps
|
| 211 |
+
self.noise_inverse_retouch = float(retouch)
|
| 212 |
+
self.noise_inverse_renoise_strength = float(renoise_strength)
|
| 213 |
+
self.noise_inverse_renoise_kernel = int(renoise_kernel)
|
| 214 |
+
self.noise_inverse_set_cache = set_cache_callback
|
| 215 |
+
self.noise_inverse_get_cache = get_cache_callback
|
| 216 |
+
|
| 217 |
+
def init_done(self):
|
| 218 |
+
'''
|
| 219 |
+
Call this after all `init_*`, settings are done, now perform:
|
| 220 |
+
- settings sanity check
|
| 221 |
+
- pre-computations, cache init
|
| 222 |
+
- anything thing needed before denoising starts
|
| 223 |
+
'''
|
| 224 |
+
|
| 225 |
+
# if self._init_done is not None: return
|
| 226 |
+
# self._init_done = True
|
| 227 |
+
self.total_bboxes = 0
|
| 228 |
+
if self.enable_grid_bbox: self.total_bboxes += self.num_batches
|
| 229 |
+
if self.enable_custom_bbox: self.total_bboxes += len(self.custom_bboxes)
|
| 230 |
+
assert self.total_bboxes > 0, "Nothing to paint! No background to draw and no custom bboxes were provided."
|
| 231 |
+
|
| 232 |
+
# sampling_steps = _steps
|
| 233 |
+
# self.pbar = tqdm(total=(self.total_bboxes) * sampling_steps, desc=f"{self.method} Sampling: ")
|
| 234 |
+
|
| 235 |
+
@controlnet
|
| 236 |
+
def prepare_controlnet_tensors(self, refresh:bool=False, tensor=None):
|
| 237 |
+
''' Crop the control tensor into tiles and cache them '''
|
| 238 |
+
if not refresh:
|
| 239 |
+
if self.control_tensor_batch is not None or self.control_params is not None: return
|
| 240 |
+
tensors = [tensor]
|
| 241 |
+
self.org_control_tensor_batch = tensors
|
| 242 |
+
self.control_tensor_batch = []
|
| 243 |
+
for i in range(len(tensors)):
|
| 244 |
+
control_tile_list = []
|
| 245 |
+
control_tensor = tensors[i]
|
| 246 |
+
for bboxes in self.batched_bboxes:
|
| 247 |
+
single_batch_tensors = []
|
| 248 |
+
for bbox in bboxes:
|
| 249 |
+
if len(control_tensor.shape) == 3:
|
| 250 |
+
control_tensor.unsqueeze_(0)
|
| 251 |
+
control_tile = control_tensor[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f]
|
| 252 |
+
single_batch_tensors.append(control_tile)
|
| 253 |
+
control_tile = torch.cat(single_batch_tensors, dim=0)
|
| 254 |
+
if self.control_tensor_cpu:
|
| 255 |
+
control_tile = control_tile.cpu()
|
| 256 |
+
control_tile_list.append(control_tile)
|
| 257 |
+
self.control_tensor_batch.append(control_tile_list)
|
| 258 |
+
|
| 259 |
+
if len(self.custom_bboxes) > 0:
|
| 260 |
+
custom_control_tile_list = []
|
| 261 |
+
for bbox in self.custom_bboxes:
|
| 262 |
+
if len(control_tensor.shape) == 3:
|
| 263 |
+
control_tensor.unsqueeze_(0)
|
| 264 |
+
control_tile = control_tensor[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f]
|
| 265 |
+
if self.control_tensor_cpu:
|
| 266 |
+
control_tile = control_tile.cpu()
|
| 267 |
+
custom_control_tile_list.append(control_tile)
|
| 268 |
+
self.control_tensor_custom.append(custom_control_tile_list)
|
| 269 |
+
|
| 270 |
+
@controlnet
|
| 271 |
+
def switch_controlnet_tensors(self, batch_id:int, x_batch_size:int, tile_batch_size:int, is_denoise=False):
|
| 272 |
+
# if not self.enable_controlnet: return
|
| 273 |
+
if self.control_tensor_batch is None: return
|
| 274 |
+
# self.control_params = [0]
|
| 275 |
+
|
| 276 |
+
# for param_id in range(len(self.control_params)):
|
| 277 |
+
for param_id in range(len(self.control_tensor_batch)):
|
| 278 |
+
# tensor that was concatenated in `prepare_controlnet_tensors`
|
| 279 |
+
control_tile = self.control_tensor_batch[param_id][batch_id]
|
| 280 |
+
# broadcast to latent batch size
|
| 281 |
+
if x_batch_size > 1: # self.is_kdiff:
|
| 282 |
+
all_control_tile = []
|
| 283 |
+
for i in range(tile_batch_size):
|
| 284 |
+
this_control_tile = [control_tile[i].unsqueeze(0)] * x_batch_size
|
| 285 |
+
all_control_tile.append(torch.cat(this_control_tile, dim=0))
|
| 286 |
+
control_tile = torch.cat(all_control_tile, dim=0) # [:x_tile.shape[0]]
|
| 287 |
+
self.control_tensor_batch[param_id][batch_id] = control_tile
|
| 288 |
+
# else:
|
| 289 |
+
# control_tile = control_tile.repeat([x_batch_size if is_denoise else x_batch_size * 2, 1, 1, 1])
|
| 290 |
+
# self.control_params[param_id].hint_cond = control_tile.to(devices.device)
|
| 291 |
+
|
| 292 |
+
def process_controlnet(self, x_shape, x_dtype, c_in: dict, cond_or_uncond: List, bboxes, batch_size: int, batch_id: int):
|
| 293 |
+
control: ControlNet = c_in['control']
|
| 294 |
+
param_id = -1 # current controlnet & previous_controlnets
|
| 295 |
+
tuple_key = tuple(cond_or_uncond) + tuple(x_shape)
|
| 296 |
+
while control is not None:
|
| 297 |
+
param_id += 1
|
| 298 |
+
PH, PW = self.h*8, self.w*8
|
| 299 |
+
|
| 300 |
+
if self.control_params.get(tuple_key, None) is None:
|
| 301 |
+
self.control_params[tuple_key] = [[None]]
|
| 302 |
+
val = self.control_params[tuple_key]
|
| 303 |
+
if param_id+1 >= len(val):
|
| 304 |
+
val.extend([[None] for _ in range(param_id+1)])
|
| 305 |
+
if len(self.batched_bboxes) >= len(val[param_id]):
|
| 306 |
+
val[param_id].extend([[None] for _ in range(len(self.batched_bboxes))])
|
| 307 |
+
|
| 308 |
+
while len(self.control_params[tuple_key]) <= param_id:
|
| 309 |
+
self.control_params[tuple_key].extend([None])
|
| 310 |
+
# print('extending param_id')
|
| 311 |
+
|
| 312 |
+
while len(self.control_params[tuple_key][param_id]) <= batch_id:
|
| 313 |
+
self.control_params[tuple_key][param_id].extend([None])
|
| 314 |
+
# print('extending batch_id')
|
| 315 |
+
|
| 316 |
+
# Below is taken from comfy.controlnet.py, but we need to additionally tile the cnets.
|
| 317 |
+
# if statement: eager eval. first time when cond_hint is None.
|
| 318 |
+
if self.refresh or control.cond_hint is None or not isinstance(self.control_params[tuple_key][param_id][batch_id], Tensor):
|
| 319 |
+
dtype = getattr(control, 'manual_cast_dtype', None)
|
| 320 |
+
if dtype is None: dtype = getattr(getattr(control, 'control_model', None), 'dtype', None)
|
| 321 |
+
if dtype is None: dtype = x_dtype
|
| 322 |
+
if isinstance(control, T2IAdapter):
|
| 323 |
+
width, height = control.scale_image_to(PW, PH)
|
| 324 |
+
control.cond_hint = comfy.utils.common_upscale(control.cond_hint_original, width, height, 'nearest-exact', "center").float().to(control.device)
|
| 325 |
+
if control.channels_in == 1 and control.cond_hint.shape[1] > 1:
|
| 326 |
+
control.cond_hint = torch.mean(control.cond_hint, 1, keepdim=True)
|
| 327 |
+
elif control.__class__.__name__ == 'ControlLLLiteAdvanced':
|
| 328 |
+
if control.sub_idxs is not None and control.cond_hint_original.shape[0] >= control.full_latent_length:
|
| 329 |
+
control.cond_hint = comfy.utils.common_upscale(control.cond_hint_original[control.sub_idxs], PW, PH, 'nearest-exact', "center").to(dtype=dtype, device=control.device)
|
| 330 |
+
else:
|
| 331 |
+
if (PH, PW) == (control.cond_hint_original.shape[-2], control.cond_hint_original.shape[-1]):
|
| 332 |
+
control.cond_hint = control.cond_hint_original.clone().to(dtype=dtype, device=control.device)
|
| 333 |
+
else:
|
| 334 |
+
control.cond_hint = comfy.utils.common_upscale(control.cond_hint_original, PW, PH, 'nearest-exact', "center").to(dtype=dtype, device=control.device)
|
| 335 |
+
else:
|
| 336 |
+
if (PH, PW) == (control.cond_hint_original.shape[-2], control.cond_hint_original.shape[-1]):
|
| 337 |
+
control.cond_hint = control.cond_hint_original.clone().to(dtype=dtype, device=control.device)
|
| 338 |
+
else:
|
| 339 |
+
control.cond_hint = comfy.utils.common_upscale(control.cond_hint_original, PW, PH, 'nearest-exact', 'center').to(dtype=dtype, device=control.device)
|
| 340 |
+
|
| 341 |
+
# Broadcast then tile
|
| 342 |
+
#
|
| 343 |
+
# Below can be in the parent's if clause because self.refresh will trigger on resolution change, e.g. cause of ConditioningSetArea
|
| 344 |
+
# so that particular case isn't cached atm.
|
| 345 |
+
cond_hint_pre_tile = control.cond_hint
|
| 346 |
+
if control.cond_hint.shape[0] < batch_size :
|
| 347 |
+
cond_hint_pre_tile = self.repeat_tensor(control.cond_hint, ceildiv(batch_size, control.cond_hint.shape[0]))[:batch_size]
|
| 348 |
+
cns = [cond_hint_pre_tile[:, :, bbox[1]*opt_f:bbox[3]*opt_f, bbox[0]*opt_f:bbox[2]*opt_f] for bbox in bboxes]
|
| 349 |
+
control.cond_hint = torch.cat(cns, dim=0)
|
| 350 |
+
self.control_params[tuple_key][param_id][batch_id]=control.cond_hint
|
| 351 |
+
else:
|
| 352 |
+
control.cond_hint = self.control_params[tuple_key][param_id][batch_id]
|
| 353 |
+
control = control.previous_controlnet
|
| 354 |
+
|
| 355 |
+
import numpy as np
|
| 356 |
+
from numpy import pi, exp, sqrt
|
| 357 |
+
def gaussian_weights(tile_w:int, tile_h:int) -> Tensor:
|
| 358 |
+
'''
|
| 359 |
+
Copy from the original implementation of Mixture of Diffusers
|
| 360 |
+
https://github.com/albarji/mixture-of-diffusers/blob/master/mixdiff/tiling.py
|
| 361 |
+
This generates gaussian weights to smooth the noise of each tile.
|
| 362 |
+
This is critical for this method to work.
|
| 363 |
+
'''
|
| 364 |
+
f = lambda x, midpoint, var=0.01: exp(-(x-midpoint)*(x-midpoint) / (tile_w*tile_w) / (2*var)) / sqrt(2*pi*var)
|
| 365 |
+
x_probs = [f(x, (tile_w - 1) / 2) for x in range(tile_w)] # -1 because index goes from 0 to latent_width - 1
|
| 366 |
+
y_probs = [f(y, tile_h / 2) for y in range(tile_h)]
|
| 367 |
+
|
| 368 |
+
w = np.outer(y_probs, x_probs)
|
| 369 |
+
return torch.from_numpy(w).to(devices.device, dtype=torch.float32)
|
| 370 |
+
|
| 371 |
+
class CondDict: ...
|
| 372 |
+
|
| 373 |
+
class MultiDiffusion(AbstractDiffusion):
|
| 374 |
+
|
| 375 |
+
@torch.no_grad()
|
| 376 |
+
def __call__(self, model_function: BaseModel.apply_model, args: dict):
|
| 377 |
+
x_in: Tensor = args["input"]
|
| 378 |
+
t_in: Tensor = args["timestep"]
|
| 379 |
+
c_in: dict = args["c"]
|
| 380 |
+
cond_or_uncond: List = args["cond_or_uncond"]
|
| 381 |
+
c_crossattn: Tensor = c_in['c_crossattn']
|
| 382 |
+
|
| 383 |
+
N, C, H, W = x_in.shape
|
| 384 |
+
|
| 385 |
+
# comfyui can feed in a latent that's a different size cause of SetArea, so we'll refresh in that case.
|
| 386 |
+
self.refresh = False
|
| 387 |
+
if self.weights is None or self.h != H or self.w != W:
|
| 388 |
+
self.h, self.w = H, W
|
| 389 |
+
self.refresh = True
|
| 390 |
+
self.init_grid_bbox(self.tile_width, self.tile_height, self.tile_overlap, self.tile_batch_size)
|
| 391 |
+
# init everything done, perform sanity check & pre-computations
|
| 392 |
+
self.init_done()
|
| 393 |
+
self.h, self.w = H, W
|
| 394 |
+
# clear buffer canvas
|
| 395 |
+
self.reset_buffer(x_in)
|
| 396 |
+
|
| 397 |
+
# Background sampling (grid bbox)
|
| 398 |
+
if self.draw_background:
|
| 399 |
+
for batch_id, bboxes in enumerate(self.batched_bboxes):
|
| 400 |
+
if comfy.model_management.processing_interrupted():
|
| 401 |
+
# self.pbar.close()
|
| 402 |
+
return x_in
|
| 403 |
+
|
| 404 |
+
# batching & compute tiles
|
| 405 |
+
x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0) # [TB, C, TH, TW]
|
| 406 |
+
n_rep = len(bboxes)
|
| 407 |
+
ts_tile = self.repeat_tensor(t_in, n_rep)
|
| 408 |
+
cond_tile = self.repeat_tensor(c_crossattn, n_rep)
|
| 409 |
+
c_tile = c_in.copy()
|
| 410 |
+
c_tile['c_crossattn'] = cond_tile
|
| 411 |
+
if 'time_context' in c_in:
|
| 412 |
+
c_tile['time_context'] = self.repeat_tensor(c_in['time_context'], n_rep)
|
| 413 |
+
for key in c_tile:
|
| 414 |
+
if key in ['y', 'c_concat']:
|
| 415 |
+
icond = c_tile[key]
|
| 416 |
+
if icond.shape[2:] == (self.h, self.w):
|
| 417 |
+
c_tile[key] = torch.cat([icond[bbox.slicer] for bbox in bboxes])
|
| 418 |
+
else:
|
| 419 |
+
c_tile[key] = self.repeat_tensor(icond, n_rep)
|
| 420 |
+
|
| 421 |
+
# controlnet tiling
|
| 422 |
+
# self.switch_controlnet_tensors(batch_id, N, len(bboxes))
|
| 423 |
+
if 'control' in c_in:
|
| 424 |
+
control=c_in['control']
|
| 425 |
+
self.process_controlnet(x_tile.shape, x_tile.dtype, c_in, cond_or_uncond, bboxes, N, batch_id)
|
| 426 |
+
c_tile['control'] = control.get_control(x_tile, ts_tile, c_tile, len(cond_or_uncond))
|
| 427 |
+
|
| 428 |
+
# stablesr tiling
|
| 429 |
+
# self.switch_stablesr_tensors(batch_id)
|
| 430 |
+
|
| 431 |
+
x_tile_out = model_function(x_tile, ts_tile, **c_tile)
|
| 432 |
+
|
| 433 |
+
for i, bbox in enumerate(bboxes):
|
| 434 |
+
self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :]
|
| 435 |
+
del x_tile_out, x_tile, ts_tile, c_tile
|
| 436 |
+
|
| 437 |
+
# update progress bar
|
| 438 |
+
# self.update_pbar()
|
| 439 |
+
|
| 440 |
+
# Averaging background buffer
|
| 441 |
+
x_out = torch.where(self.weights > 1, self.x_buffer / self.weights, self.x_buffer)
|
| 442 |
+
|
| 443 |
+
return x_out
|
| 444 |
+
|
| 445 |
+
class MixtureOfDiffusers(AbstractDiffusion):
|
| 446 |
+
"""
|
| 447 |
+
Mixture-of-Diffusers Implementation
|
| 448 |
+
https://github.com/albarji/mixture-of-diffusers
|
| 449 |
+
"""
|
| 450 |
+
|
| 451 |
+
def __init__(self, *args, **kwargs):
|
| 452 |
+
super().__init__(*args, **kwargs)
|
| 453 |
+
|
| 454 |
+
# weights for custom bboxes
|
| 455 |
+
self.custom_weights: List[Tensor] = []
|
| 456 |
+
self.get_weight = gaussian_weights
|
| 457 |
+
|
| 458 |
+
def init_done(self):
|
| 459 |
+
super().init_done()
|
| 460 |
+
# The original gaussian weights can be extremely small, so we rescale them for numerical stability
|
| 461 |
+
self.rescale_factor = 1 / self.weights
|
| 462 |
+
# Meanwhile, we rescale the custom weights in advance to save time of slicing
|
| 463 |
+
for bbox_id, bbox in enumerate(self.custom_bboxes):
|
| 464 |
+
if bbox.blend_mode == BlendMode.BACKGROUND:
|
| 465 |
+
self.custom_weights[bbox_id] *= self.rescale_factor[bbox.slicer]
|
| 466 |
+
|
| 467 |
+
@grid_bbox
|
| 468 |
+
def get_tile_weights(self) -> Tensor:
|
| 469 |
+
# weights for grid bboxes
|
| 470 |
+
# if not hasattr(self, 'tile_weights'):
|
| 471 |
+
# x_in can change sizes cause of ConditioningSetArea, so we have to recalcualte each time
|
| 472 |
+
self.tile_weights = self.get_weight(self.tile_w, self.tile_h)
|
| 473 |
+
return self.tile_weights
|
| 474 |
+
|
| 475 |
+
@torch.no_grad()
|
| 476 |
+
def __call__(self, model_function: BaseModel.apply_model, args: dict):
|
| 477 |
+
x_in: Tensor = args["input"]
|
| 478 |
+
t_in: Tensor = args["timestep"]
|
| 479 |
+
c_in: dict = args["c"]
|
| 480 |
+
cond_or_uncond: List= args["cond_or_uncond"]
|
| 481 |
+
c_crossattn: Tensor = c_in['c_crossattn']
|
| 482 |
+
|
| 483 |
+
N, C, H, W = x_in.shape
|
| 484 |
+
|
| 485 |
+
self.refresh = False
|
| 486 |
+
# self.refresh = True
|
| 487 |
+
if self.weights is None or self.h != H or self.w != W:
|
| 488 |
+
self.h, self.w = H, W
|
| 489 |
+
self.refresh = True
|
| 490 |
+
self.init_grid_bbox(self.tile_width, self.tile_height, self.tile_overlap, self.tile_batch_size)
|
| 491 |
+
# init everything done, perform sanity check & pre-computations
|
| 492 |
+
self.init_done()
|
| 493 |
+
self.h, self.w = H, W
|
| 494 |
+
# clear buffer canvas
|
| 495 |
+
self.reset_buffer(x_in)
|
| 496 |
+
|
| 497 |
+
# self.pbar = tqdm(total=(self.total_bboxes) * sampling_steps, desc=f"{self.method} Sampling: ")
|
| 498 |
+
# self.pbar = tqdm(total=len(self.batched_bboxes), desc=f"{self.method} Sampling: ")
|
| 499 |
+
|
| 500 |
+
# Global sampling
|
| 501 |
+
if self.draw_background:
|
| 502 |
+
for batch_id, bboxes in enumerate(self.batched_bboxes): # batch_id is the `Latent tile batch size`
|
| 503 |
+
if comfy.model_management.processing_interrupted():
|
| 504 |
+
# self.pbar.close()
|
| 505 |
+
return x_in
|
| 506 |
+
|
| 507 |
+
# batching
|
| 508 |
+
x_tile_list = []
|
| 509 |
+
t_tile_list = []
|
| 510 |
+
icond_map = {}
|
| 511 |
+
# tcond_tile_list = []
|
| 512 |
+
# icond_tile_list = []
|
| 513 |
+
# vcond_tile_list = []
|
| 514 |
+
# control_list = []
|
| 515 |
+
for bbox in bboxes:
|
| 516 |
+
x_tile_list.append(x_in[bbox.slicer])
|
| 517 |
+
t_tile_list.append(t_in)
|
| 518 |
+
if isinstance(c_in, dict):
|
| 519 |
+
# tcond
|
| 520 |
+
# tcond_tile = c_crossattn #self.get_tcond(c_in) # cond, [1, 77, 768]
|
| 521 |
+
# tcond_tile_list.append(tcond_tile)
|
| 522 |
+
# present in sdxl
|
| 523 |
+
for key in ['y', 'c_concat']:
|
| 524 |
+
if key in c_in:
|
| 525 |
+
icond=c_in[key] # self.get_icond(c_in)
|
| 526 |
+
if icond.shape[2:] == (self.h, self.w):
|
| 527 |
+
icond = icond[bbox.slicer]
|
| 528 |
+
if icond_map.get(key, None) is None:
|
| 529 |
+
icond_map[key] = []
|
| 530 |
+
icond_map[key].append(icond)
|
| 531 |
+
# # vcond:
|
| 532 |
+
# vcond = self.get_vcond(c_in)
|
| 533 |
+
# vcond_tile_list.append(vcond)
|
| 534 |
+
else:
|
| 535 |
+
print('>> [WARN] not supported, make an issue on github!!')
|
| 536 |
+
n_rep = len(bboxes)
|
| 537 |
+
x_tile = torch.cat(x_tile_list, dim=0) # differs each
|
| 538 |
+
t_tile = self.repeat_tensor(t_in, n_rep) # just repeat
|
| 539 |
+
tcond_tile = self.repeat_tensor(c_crossattn, n_rep) # just repeat
|
| 540 |
+
c_tile = c_in.copy()
|
| 541 |
+
c_tile['c_crossattn'] = tcond_tile
|
| 542 |
+
if 'time_context' in c_in:
|
| 543 |
+
c_tile['time_context'] = self.repeat_tensor(c_in['time_context'], n_rep) # just repeat
|
| 544 |
+
for key in c_tile:
|
| 545 |
+
if key in ['y', 'c_concat']:
|
| 546 |
+
icond_tile = torch.cat(icond_map[key], dim=0) # differs each
|
| 547 |
+
c_tile[key] = icond_tile
|
| 548 |
+
# vcond_tile = torch.cat(vcond_tile_list, dim=0) if None not in vcond_tile_list else None # just repeat
|
| 549 |
+
|
| 550 |
+
# controlnet
|
| 551 |
+
# self.switch_controlnet_tensors(batch_id, N, len(bboxes), is_denoise=True)
|
| 552 |
+
if 'control' in c_in:
|
| 553 |
+
control=c_in['control']
|
| 554 |
+
self.process_controlnet(x_tile.shape, x_tile.dtype, c_in, cond_or_uncond, bboxes, N, batch_id)
|
| 555 |
+
c_tile['control'] = control.get_control(x_tile, t_tile, c_tile, len(cond_or_uncond))
|
| 556 |
+
|
| 557 |
+
# stablesr
|
| 558 |
+
# self.switch_stablesr_tensors(batch_id)
|
| 559 |
+
|
| 560 |
+
# denoising: here the x is the noise
|
| 561 |
+
x_tile_out = model_function(x_tile, t_tile, **c_tile)
|
| 562 |
+
|
| 563 |
+
# de-batching
|
| 564 |
+
for i, bbox in enumerate(bboxes):
|
| 565 |
+
# These weights can be calcluated in advance, but will cost a lot of vram
|
| 566 |
+
# when you have many tiles. So we calculate it here.
|
| 567 |
+
w = self.tile_weights * self.rescale_factor[bbox.slicer]
|
| 568 |
+
self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :] * w
|
| 569 |
+
del x_tile_out, x_tile, t_tile, c_tile
|
| 570 |
+
|
| 571 |
+
# self.update_pbar()
|
| 572 |
+
# self.pbar.update()
|
| 573 |
+
# self.pbar.close()
|
| 574 |
+
x_out = self.x_buffer
|
| 575 |
+
|
| 576 |
+
return x_out
|
| 577 |
+
|
| 578 |
+
from .utils import hook_all
|
| 579 |
+
hook_all()
|
| 580 |
+
|
| 581 |
+
MAX_RESOLUTION=8192
|
| 582 |
+
class TiledDiffusion():
|
| 583 |
+
@classmethod
|
| 584 |
+
def INPUT_TYPES(s):
|
| 585 |
+
return {"required": {"model": ("MODEL", ),
|
| 586 |
+
"method": (["MultiDiffusion", "Mixture of Diffusers"], {"default": "Mixture of Diffusers"}),
|
| 587 |
+
# "tile_width": ("INT", {"default": 96, "min": 16, "max": 256, "step": 16}),
|
| 588 |
+
"tile_width": ("INT", {"default": 96*opt_f, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
|
| 589 |
+
# "tile_height": ("INT", {"default": 96, "min": 16, "max": 256, "step": 16}),
|
| 590 |
+
"tile_height": ("INT", {"default": 96*opt_f, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
|
| 591 |
+
"tile_overlap": ("INT", {"default": 8*opt_f, "min": 0, "max": 256*opt_f, "step": 4*opt_f}),
|
| 592 |
+
"tile_batch_size": ("INT", {"default": 4, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
| 593 |
+
}}
|
| 594 |
+
RETURN_TYPES = ("MODEL",)
|
| 595 |
+
FUNCTION = "apply"
|
| 596 |
+
CATEGORY = "_for_testing"
|
| 597 |
+
|
| 598 |
+
def apply(self, model: ModelPatcher, method, tile_width, tile_height, tile_overlap, tile_batch_size):
|
| 599 |
+
if method == "Mixture of Diffusers":
|
| 600 |
+
implement = MixtureOfDiffusers()
|
| 601 |
+
else:
|
| 602 |
+
implement = MultiDiffusion()
|
| 603 |
+
|
| 604 |
+
# if noise_inversion:
|
| 605 |
+
# get_cache_callback = self.noise_inverse_get_cache
|
| 606 |
+
# set_cache_callback = None # lambda x0, xt, prompts: self.noise_inverse_set_cache(p, x0, xt, prompts, steps, retouch)
|
| 607 |
+
# implement.init_noise_inverse(steps, retouch, get_cache_callback, set_cache_callback, renoise_strength, renoise_kernel_size)
|
| 608 |
+
|
| 609 |
+
implement.tile_width = tile_width // opt_f
|
| 610 |
+
implement.tile_height = tile_height // opt_f
|
| 611 |
+
implement.tile_overlap = tile_overlap // opt_f
|
| 612 |
+
implement.tile_batch_size = tile_batch_size
|
| 613 |
+
# implement.init_grid_bbox(tile_width, tile_height, tile_overlap, tile_batch_size)
|
| 614 |
+
# # init everything done, perform sanity check & pre-computations
|
| 615 |
+
# implement.init_done()
|
| 616 |
+
# hijack the behaviours
|
| 617 |
+
# implement.hook()
|
| 618 |
+
model = model.clone()
|
| 619 |
+
model.set_model_unet_function_wrapper(implement)
|
| 620 |
+
model.model_options['tiled_diffusion'] = True
|
| 621 |
+
return (model,)
|
| 622 |
+
|
| 623 |
+
class NoiseInversion():
|
| 624 |
+
@classmethod
|
| 625 |
+
def INPUT_TYPES(s):
|
| 626 |
+
return {"required": {"model": ("MODEL", ),
|
| 627 |
+
"positive": ("CONDITIONING", ),
|
| 628 |
+
"negative": ("CONDITIONING", ),
|
| 629 |
+
"latent_image": ("LATENT", ),
|
| 630 |
+
"image": ("IMAGE", ),
|
| 631 |
+
"steps": ("INT", {"default": 10, "min": 1, "max": 208, "step": 1}),
|
| 632 |
+
"retouch": ("FLOAT", {"default": 1, "min": 1, "max": 100, "step": 0.1}),
|
| 633 |
+
"renoise_strength": ("FLOAT", {"default": 1, "min": 1, "max": 2, "step": 0.01}),
|
| 634 |
+
"renoise_kernel_size": ("INT", {"default": 2, "min": 2, "max": 512, "step": 1}),
|
| 635 |
+
}}
|
| 636 |
+
RETURN_TYPES = ("LATENT",)
|
| 637 |
+
FUNCTION = "sample"
|
| 638 |
+
CATEGORY = "sampling"
|
| 639 |
+
def sample(self, model: ModelPatcher, positive, negative,
|
| 640 |
+
latent_image, image, steps, retouch, renoise_strength, renoise_kernel_size):
|
| 641 |
+
return (latent_image,)
|
| 642 |
+
|
| 643 |
+
NODE_CLASS_MAPPINGS = {
|
| 644 |
+
"TiledDiffusion": TiledDiffusion,
|
| 645 |
+
# "NoiseInversion": NoiseInversion,
|
| 646 |
+
}
|
| 647 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
| 648 |
+
"TiledDiffusion": "Tiled Diffusion",
|
| 649 |
+
# "NoiseInversion": "Noise Inversion",
|
| 650 |
+
}
|
ComfyUI-TiledDiffusion/tiled_vae.py
ADDED
|
@@ -0,0 +1,868 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
# ------------------------------------------------------------------------
|
| 3 |
+
#
|
| 4 |
+
# Tiled VAE
|
| 5 |
+
#
|
| 6 |
+
# Introducing a revolutionary new optimization designed to make
|
| 7 |
+
# the VAE work with giant images on limited VRAM!
|
| 8 |
+
# Say goodbye to the frustration of OOM and hello to seamless output!
|
| 9 |
+
#
|
| 10 |
+
# ------------------------------------------------------------------------
|
| 11 |
+
#
|
| 12 |
+
# This script is a wild hack that splits the image into tiles,
|
| 13 |
+
# encodes each tile separately, and merges the result back together.
|
| 14 |
+
#
|
| 15 |
+
# Advantages:
|
| 16 |
+
# - The VAE can now work with giant images on limited VRAM
|
| 17 |
+
# (~10 GB for 8K images!)
|
| 18 |
+
# - The merged output is completely seamless without any post-processing.
|
| 19 |
+
#
|
| 20 |
+
# Drawbacks:
|
| 21 |
+
# - NaNs always appear in for 8k images when you use fp16 (half) VAE
|
| 22 |
+
# You must use --no-half-vae to disable half VAE for that giant image.
|
| 23 |
+
# - The gradient calculation is not compatible with this hack. It
|
| 24 |
+
# will break any backward() or torch.autograd.grad() that passes VAE.
|
| 25 |
+
# (But you can still use the VAE to generate training data.)
|
| 26 |
+
#
|
| 27 |
+
# How it works:
|
| 28 |
+
# 1. The image is split into tiles, which are then padded with 11/32 pixels' in the decoder/encoder.
|
| 29 |
+
# 2. When Fast Mode is disabled:
|
| 30 |
+
# 1. The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile.
|
| 31 |
+
# 2. When GroupNorm is needed, it suspends, stores current GroupNorm mean and var, send everything to RAM, and turns to the next tile.
|
| 32 |
+
# 3. After all GroupNorm means and vars are summarized, it applies group norm to tiles and continues.
|
| 33 |
+
# 4. A zigzag execution order is used to reduce unnecessary data transfer.
|
| 34 |
+
# 3. When Fast Mode is enabled:
|
| 35 |
+
# 1. The original input is downsampled and passed to a separate task queue.
|
| 36 |
+
# 2. Its group norm parameters are recorded and used by all tiles' task queues.
|
| 37 |
+
# 3. Each tile is separately processed without any RAM-VRAM data transfer.
|
| 38 |
+
# 4. After all tiles are processed, tiles are written to a result buffer and returned.
|
| 39 |
+
# Encoder color fix = only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode.
|
| 40 |
+
#
|
| 41 |
+
# Enjoy!
|
| 42 |
+
#
|
| 43 |
+
# @Author: LI YI @ Nanyang Technological University - Singapore
|
| 44 |
+
# @Date: 2023-03-02
|
| 45 |
+
# @License: CC BY-NC-SA 4.0
|
| 46 |
+
#
|
| 47 |
+
# Please give https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111
|
| 48 |
+
# a star if you like the project!
|
| 49 |
+
#
|
| 50 |
+
# -------------------------------------------------------------------------
|
| 51 |
+
'''
|
| 52 |
+
|
| 53 |
+
import gc
|
| 54 |
+
import math
|
| 55 |
+
from time import time
|
| 56 |
+
from tqdm import tqdm
|
| 57 |
+
|
| 58 |
+
import torch
|
| 59 |
+
import torch.version
|
| 60 |
+
import torch.nn.functional as F
|
| 61 |
+
# import gradio as gr
|
| 62 |
+
|
| 63 |
+
# import modules.scripts as scripts
|
| 64 |
+
# from .modules import devices
|
| 65 |
+
# from modules.shared import state
|
| 66 |
+
# from modules.ui import gr_show
|
| 67 |
+
# from modules.processing import opt_f
|
| 68 |
+
# from modules.sd_vae_approx import cheap_approximation
|
| 69 |
+
# from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock
|
| 70 |
+
|
| 71 |
+
# from tile_utils.attn import get_attn_func
|
| 72 |
+
# from tile_utils.typing import Processing
|
| 73 |
+
|
| 74 |
+
import comfy
|
| 75 |
+
import comfy.model_management
|
| 76 |
+
from comfy.model_management import processing_interrupted
|
| 77 |
+
import contextlib
|
| 78 |
+
|
| 79 |
+
opt_C = 4
|
| 80 |
+
opt_f = 8
|
| 81 |
+
is_sdxl = False
|
| 82 |
+
disable_nan_check = True
|
| 83 |
+
|
| 84 |
+
class Device: ...
|
| 85 |
+
devices = Device()
|
| 86 |
+
devices.device = comfy.model_management.get_torch_device()
|
| 87 |
+
devices.cpu = torch.device('cpu')
|
| 88 |
+
devices.torch_gc = lambda: comfy.model_management.soft_empty_cache()
|
| 89 |
+
devices.get_optimal_device = lambda: comfy.model_management.get_torch_device()
|
| 90 |
+
|
| 91 |
+
class NansException(Exception): ...
|
| 92 |
+
def test_for_nans(x, where):
|
| 93 |
+
if disable_nan_check:
|
| 94 |
+
return
|
| 95 |
+
if not torch.all(torch.isnan(x)).item():
|
| 96 |
+
return
|
| 97 |
+
if where == "unet":
|
| 98 |
+
message = "A tensor with all NaNs was produced in Unet."
|
| 99 |
+
if comfy.model_management.unet_dtype(x.device) != torch.float32:
|
| 100 |
+
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
|
| 101 |
+
elif where == "vae":
|
| 102 |
+
message = "A tensor with all NaNs was produced in VAE."
|
| 103 |
+
if comfy.model_management.unet_dtype(x.device) != torch.float32 and comfy.model_management.vae_dtype() != torch.float32:
|
| 104 |
+
message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
|
| 105 |
+
else:
|
| 106 |
+
message = "A tensor with all NaNs was produced."
|
| 107 |
+
message += " Use --disable-nan-check commandline argument to disable this check."
|
| 108 |
+
raise NansException(message)
|
| 109 |
+
|
| 110 |
+
def _autocast(disable=False):
|
| 111 |
+
if disable:
|
| 112 |
+
return contextlib.nullcontext()
|
| 113 |
+
|
| 114 |
+
if comfy.model_management.unet_dtype() == torch.float32 or comfy.model_management.get_torch_device() == torch.device("mps"): # or shared.cmd_opts.precision == "full":
|
| 115 |
+
return contextlib.nullcontext()
|
| 116 |
+
|
| 117 |
+
# only cuda
|
| 118 |
+
autocast_device = comfy.model_management.get_autocast_device(comfy.model_management.get_torch_device())
|
| 119 |
+
return torch.autocast(autocast_device)
|
| 120 |
+
|
| 121 |
+
def without_autocast(disable=False):
|
| 122 |
+
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
|
| 123 |
+
|
| 124 |
+
devices.test_for_nans = test_for_nans
|
| 125 |
+
devices.autocast = _autocast
|
| 126 |
+
devices.without_autocast = without_autocast
|
| 127 |
+
|
| 128 |
+
def cheap_approximation(sample):
|
| 129 |
+
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
|
| 130 |
+
|
| 131 |
+
if is_sdxl:
|
| 132 |
+
coeffs = [
|
| 133 |
+
[ 0.3448, 0.4168, 0.4395],
|
| 134 |
+
[-0.1953, -0.0290, 0.0250],
|
| 135 |
+
[ 0.1074, 0.0886, -0.0163],
|
| 136 |
+
[-0.3730, -0.2499, -0.2088],
|
| 137 |
+
]
|
| 138 |
+
else:
|
| 139 |
+
coeffs = [
|
| 140 |
+
[ 0.298, 0.207, 0.208],
|
| 141 |
+
[ 0.187, 0.286, 0.173],
|
| 142 |
+
[-0.158, 0.189, 0.264],
|
| 143 |
+
[-0.184, -0.271, -0.473],
|
| 144 |
+
]
|
| 145 |
+
|
| 146 |
+
coefs = torch.tensor(coeffs).to(sample.device)
|
| 147 |
+
|
| 148 |
+
x_sample = torch.einsum("...lxy,lr -> ...rxy", sample, coefs)
|
| 149 |
+
|
| 150 |
+
return x_sample
|
| 151 |
+
|
| 152 |
+
def get_rcmd_enc_tsize():
|
| 153 |
+
if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]:
|
| 154 |
+
total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20
|
| 155 |
+
if total_memory > 16*1000: ENCODER_TILE_SIZE = 3072
|
| 156 |
+
elif total_memory > 12*1000: ENCODER_TILE_SIZE = 2048
|
| 157 |
+
elif total_memory > 8*1000: ENCODER_TILE_SIZE = 1536
|
| 158 |
+
else: ENCODER_TILE_SIZE = 960
|
| 159 |
+
else: ENCODER_TILE_SIZE = 512
|
| 160 |
+
return ENCODER_TILE_SIZE
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def get_rcmd_dec_tsize():
|
| 164 |
+
if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]:
|
| 165 |
+
total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20
|
| 166 |
+
if total_memory > 30*1000: DECODER_TILE_SIZE = 256
|
| 167 |
+
elif total_memory > 16*1000: DECODER_TILE_SIZE = 192
|
| 168 |
+
elif total_memory > 12*1000: DECODER_TILE_SIZE = 128
|
| 169 |
+
elif total_memory > 8*1000: DECODER_TILE_SIZE = 96
|
| 170 |
+
else: DECODER_TILE_SIZE = 64
|
| 171 |
+
else: DECODER_TILE_SIZE = 64
|
| 172 |
+
return DECODER_TILE_SIZE
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def inplace_nonlinearity(x):
|
| 176 |
+
# Test: fix for Nans
|
| 177 |
+
return F.silu(x, inplace=True)
|
| 178 |
+
|
| 179 |
+
def _attn_forward(self, x):
|
| 180 |
+
# From comfy.Idm.modules.diffusionmodules.model.AttnBlock.forward
|
| 181 |
+
# However, the residual & normalization are removed and computed separately.
|
| 182 |
+
h_ = x
|
| 183 |
+
q = self.q(h_)
|
| 184 |
+
k = self.k(h_)
|
| 185 |
+
v = self.v(h_)
|
| 186 |
+
h_ = self.optimized_attention(q, k, v)
|
| 187 |
+
h_ = self.proj_out(h_)
|
| 188 |
+
return h_
|
| 189 |
+
|
| 190 |
+
def get_attn_func():
|
| 191 |
+
return _attn_forward
|
| 192 |
+
|
| 193 |
+
def attn2task(task_queue, net):
|
| 194 |
+
|
| 195 |
+
attn_forward = get_attn_func()
|
| 196 |
+
task_queue.append(('store_res', lambda x: x))
|
| 197 |
+
task_queue.append(('pre_norm', net.norm))
|
| 198 |
+
task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
|
| 199 |
+
task_queue.append(['add_res', None])
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def resblock2task(queue, block):
|
| 203 |
+
"""
|
| 204 |
+
Turn a ResNetBlock into a sequence of tasks and append to the task queue
|
| 205 |
+
|
| 206 |
+
@param queue: the target task queue
|
| 207 |
+
@param block: ResNetBlock
|
| 208 |
+
|
| 209 |
+
"""
|
| 210 |
+
if block.in_channels != block.out_channels:
|
| 211 |
+
if block.use_conv_shortcut:
|
| 212 |
+
queue.append(('store_res', block.conv_shortcut))
|
| 213 |
+
else:
|
| 214 |
+
queue.append(('store_res', block.nin_shortcut))
|
| 215 |
+
else:
|
| 216 |
+
queue.append(('store_res', lambda x: x))
|
| 217 |
+
queue.append(('pre_norm', block.norm1))
|
| 218 |
+
queue.append(('silu', inplace_nonlinearity))
|
| 219 |
+
queue.append(('conv1', block.conv1))
|
| 220 |
+
queue.append(('pre_norm', block.norm2))
|
| 221 |
+
queue.append(('silu', inplace_nonlinearity))
|
| 222 |
+
queue.append(('conv2', block.conv2))
|
| 223 |
+
queue.append(['add_res', None])
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def build_sampling(task_queue, net, is_decoder):
|
| 227 |
+
"""
|
| 228 |
+
Build the sampling part of a task queue
|
| 229 |
+
@param task_queue: the target task queue
|
| 230 |
+
@param net: the network
|
| 231 |
+
@param is_decoder: currently building decoder or encoder
|
| 232 |
+
"""
|
| 233 |
+
if is_decoder:
|
| 234 |
+
resblock2task(task_queue, net.mid.block_1)
|
| 235 |
+
attn2task(task_queue, net.mid.attn_1)
|
| 236 |
+
resblock2task(task_queue, net.mid.block_2)
|
| 237 |
+
resolution_iter = reversed(range(net.num_resolutions))
|
| 238 |
+
block_ids = net.num_res_blocks + 1
|
| 239 |
+
condition = 0
|
| 240 |
+
module = net.up
|
| 241 |
+
func_name = 'upsample'
|
| 242 |
+
else:
|
| 243 |
+
resolution_iter = range(net.num_resolutions)
|
| 244 |
+
block_ids = net.num_res_blocks
|
| 245 |
+
condition = net.num_resolutions - 1
|
| 246 |
+
module = net.down
|
| 247 |
+
func_name = 'downsample'
|
| 248 |
+
|
| 249 |
+
for i_level in resolution_iter:
|
| 250 |
+
for i_block in range(block_ids):
|
| 251 |
+
resblock2task(task_queue, module[i_level].block[i_block])
|
| 252 |
+
if i_level != condition:
|
| 253 |
+
task_queue.append((func_name, getattr(module[i_level], func_name)))
|
| 254 |
+
|
| 255 |
+
if not is_decoder:
|
| 256 |
+
resblock2task(task_queue, net.mid.block_1)
|
| 257 |
+
attn2task(task_queue, net.mid.attn_1)
|
| 258 |
+
resblock2task(task_queue, net.mid.block_2)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def build_task_queue(net, is_decoder):
|
| 262 |
+
"""
|
| 263 |
+
Build a single task queue for the encoder or decoder
|
| 264 |
+
@param net: the VAE decoder or encoder network
|
| 265 |
+
@param is_decoder: currently building decoder or encoder
|
| 266 |
+
@return: the task queue
|
| 267 |
+
"""
|
| 268 |
+
task_queue = []
|
| 269 |
+
task_queue.append(('conv_in', net.conv_in))
|
| 270 |
+
|
| 271 |
+
# construct the sampling part of the task queue
|
| 272 |
+
# because encoder and decoder share the same architecture, we extract the sampling part
|
| 273 |
+
build_sampling(task_queue, net, is_decoder)
|
| 274 |
+
|
| 275 |
+
if not is_decoder or not net.give_pre_end:
|
| 276 |
+
task_queue.append(('pre_norm', net.norm_out))
|
| 277 |
+
task_queue.append(('silu', inplace_nonlinearity))
|
| 278 |
+
task_queue.append(('conv_out', net.conv_out))
|
| 279 |
+
if is_decoder and net.tanh_out:
|
| 280 |
+
task_queue.append(('tanh', torch.tanh))
|
| 281 |
+
|
| 282 |
+
return task_queue
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def clone_task_queue(task_queue):
|
| 286 |
+
"""
|
| 287 |
+
Clone a task queue
|
| 288 |
+
@param task_queue: the task queue to be cloned
|
| 289 |
+
@return: the cloned task queue
|
| 290 |
+
"""
|
| 291 |
+
return [[item for item in task] for task in task_queue]
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def get_var_mean(input, num_groups, eps=1e-6):
|
| 295 |
+
"""
|
| 296 |
+
Get mean and var for group norm
|
| 297 |
+
"""
|
| 298 |
+
b, c = input.size(0), input.size(1)
|
| 299 |
+
channel_in_group = int(c/num_groups)
|
| 300 |
+
input_reshaped = input.contiguous().view(1, int(b * num_groups), channel_in_group, *input.size()[2:])
|
| 301 |
+
var, mean = torch.var_mean(input_reshaped, dim=[0, 2, 3, 4], unbiased=False)
|
| 302 |
+
return var, mean
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
|
| 306 |
+
"""
|
| 307 |
+
Custom group norm with fixed mean and var
|
| 308 |
+
|
| 309 |
+
@param input: input tensor
|
| 310 |
+
@param num_groups: number of groups. by default, num_groups = 32
|
| 311 |
+
@param mean: mean, must be pre-calculated by get_var_mean
|
| 312 |
+
@param var: var, must be pre-calculated by get_var_mean
|
| 313 |
+
@param weight: weight, should be fetched from the original group norm
|
| 314 |
+
@param bias: bias, should be fetched from the original group norm
|
| 315 |
+
@param eps: epsilon, by default, eps = 1e-6 to match the original group norm
|
| 316 |
+
|
| 317 |
+
@return: normalized tensor
|
| 318 |
+
"""
|
| 319 |
+
b, c = input.size(0), input.size(1)
|
| 320 |
+
channel_in_group = int(c/num_groups)
|
| 321 |
+
input_reshaped = input.contiguous().view(
|
| 322 |
+
1, int(b * num_groups), channel_in_group, *input.size()[2:])
|
| 323 |
+
|
| 324 |
+
out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None, training=False, momentum=0, eps=eps)
|
| 325 |
+
out = out.view(b, c, *input.size()[2:])
|
| 326 |
+
|
| 327 |
+
# post affine transform
|
| 328 |
+
if weight is not None:
|
| 329 |
+
out *= weight.view(1, -1, 1, 1)
|
| 330 |
+
if bias is not None:
|
| 331 |
+
out += bias.view(1, -1, 1, 1)
|
| 332 |
+
return out
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
|
| 336 |
+
"""
|
| 337 |
+
Crop the valid region from the tile
|
| 338 |
+
@param x: input tile
|
| 339 |
+
@param input_bbox: original input bounding box
|
| 340 |
+
@param target_bbox: output bounding box
|
| 341 |
+
@param scale: scale factor
|
| 342 |
+
@return: cropped tile
|
| 343 |
+
"""
|
| 344 |
+
padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox]
|
| 345 |
+
margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]
|
| 346 |
+
return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]]
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
|
| 350 |
+
|
| 351 |
+
def perfcount(fn):
|
| 352 |
+
def wrapper(*args, **kwargs):
|
| 353 |
+
ts = time()
|
| 354 |
+
|
| 355 |
+
if torch.cuda.is_available():
|
| 356 |
+
torch.cuda.reset_peak_memory_stats(devices.device)
|
| 357 |
+
devices.torch_gc()
|
| 358 |
+
gc.collect()
|
| 359 |
+
|
| 360 |
+
ret = fn(*args, **kwargs)
|
| 361 |
+
|
| 362 |
+
devices.torch_gc()
|
| 363 |
+
gc.collect()
|
| 364 |
+
if torch.cuda.is_available():
|
| 365 |
+
vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
|
| 366 |
+
print(f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
|
| 367 |
+
else:
|
| 368 |
+
print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')
|
| 369 |
+
|
| 370 |
+
return ret
|
| 371 |
+
return wrapper
|
| 372 |
+
|
| 373 |
+
# ↑↑↑ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↑↑↑
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class GroupNormParam:
|
| 377 |
+
|
| 378 |
+
def __init__(self):
|
| 379 |
+
self.var_list = []
|
| 380 |
+
self.mean_list = []
|
| 381 |
+
self.pixel_list = []
|
| 382 |
+
self.weight = None
|
| 383 |
+
self.bias = None
|
| 384 |
+
|
| 385 |
+
def add_tile(self, tile, layer):
|
| 386 |
+
var, mean = get_var_mean(tile, 32)
|
| 387 |
+
# For giant images, the variance can be larger than max float16
|
| 388 |
+
# In this case we create a copy to float32
|
| 389 |
+
if var.dtype == torch.float16 and var.isinf().any():
|
| 390 |
+
fp32_tile = tile.float()
|
| 391 |
+
var, mean = get_var_mean(fp32_tile, 32)
|
| 392 |
+
# ============= DEBUG: test for infinite =============
|
| 393 |
+
# if torch.isinf(var).any():
|
| 394 |
+
# print('[Tiled VAE]: inf test', var)
|
| 395 |
+
# ====================================================
|
| 396 |
+
self.var_list.append(var)
|
| 397 |
+
self.mean_list.append(mean)
|
| 398 |
+
self.pixel_list.append(
|
| 399 |
+
tile.shape[2]*tile.shape[3])
|
| 400 |
+
if hasattr(layer, 'weight'):
|
| 401 |
+
self.weight = layer.weight
|
| 402 |
+
self.bias = layer.bias
|
| 403 |
+
else:
|
| 404 |
+
self.weight = None
|
| 405 |
+
self.bias = None
|
| 406 |
+
|
| 407 |
+
def summary(self):
|
| 408 |
+
"""
|
| 409 |
+
summarize the mean and var and return a function
|
| 410 |
+
that apply group norm on each tile
|
| 411 |
+
"""
|
| 412 |
+
if len(self.var_list) == 0: return None
|
| 413 |
+
|
| 414 |
+
var = torch.vstack(self.var_list)
|
| 415 |
+
mean = torch.vstack(self.mean_list)
|
| 416 |
+
max_value = max(self.pixel_list)
|
| 417 |
+
pixels = torch.tensor(self.pixel_list, dtype=torch.float32, device=devices.device) / max_value
|
| 418 |
+
sum_pixels = torch.sum(pixels)
|
| 419 |
+
pixels = pixels.unsqueeze(1) / sum_pixels
|
| 420 |
+
# var = torch.sum(var * pixels.to(var.device), dim=0)
|
| 421 |
+
# mean = torch.sum(mean * pixels.to(var.device), dim=0)
|
| 422 |
+
var = torch.sum(var * pixels, dim=0)
|
| 423 |
+
mean = torch.sum(mean * pixels, dim=0)
|
| 424 |
+
return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)
|
| 425 |
+
|
| 426 |
+
@staticmethod
|
| 427 |
+
def from_tile(tile, norm):
|
| 428 |
+
"""
|
| 429 |
+
create a function from a single tile without summary
|
| 430 |
+
"""
|
| 431 |
+
var, mean = get_var_mean(tile, 32)
|
| 432 |
+
if var.dtype == torch.float16 and var.isinf().any():
|
| 433 |
+
fp32_tile = tile.float()
|
| 434 |
+
var, mean = get_var_mean(fp32_tile, 32)
|
| 435 |
+
# if it is a macbook, we need to convert back to float16
|
| 436 |
+
if var.device.type == 'mps':
|
| 437 |
+
# clamp to avoid overflow
|
| 438 |
+
var = torch.clamp(var, 0, 60000)
|
| 439 |
+
var = var.half()
|
| 440 |
+
mean = mean.half()
|
| 441 |
+
if hasattr(norm, 'weight'):
|
| 442 |
+
weight = norm.weight
|
| 443 |
+
bias = norm.bias
|
| 444 |
+
else:
|
| 445 |
+
weight = None
|
| 446 |
+
bias = None
|
| 447 |
+
|
| 448 |
+
def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
|
| 449 |
+
return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
|
| 450 |
+
return group_norm_func
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
class VAEHook:
|
| 454 |
+
|
| 455 |
+
def __init__(self, net, tile_size, is_decoder:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool, to_gpu:bool=False):
|
| 456 |
+
self.net = net # encoder | decoder
|
| 457 |
+
self.tile_size = tile_size
|
| 458 |
+
self.is_decoder = is_decoder
|
| 459 |
+
self.fast_mode = (fast_encoder and not is_decoder) or (fast_decoder and is_decoder)
|
| 460 |
+
self.color_fix = color_fix and not is_decoder
|
| 461 |
+
self.to_gpu = to_gpu
|
| 462 |
+
self.pad = 11 if is_decoder else 32 # FIXME: magic number
|
| 463 |
+
|
| 464 |
+
def __call__(self, x):
|
| 465 |
+
# original_device = next(self.net.parameters()).device
|
| 466 |
+
try:
|
| 467 |
+
# if self.to_gpu:
|
| 468 |
+
# self.net = self.net.to(devices.get_optimal_device())
|
| 469 |
+
B, C, H, W = x.shape
|
| 470 |
+
if False:#max(H, W) <= self.pad * 2 + self.tile_size:
|
| 471 |
+
print("[Tiled VAE]: the input size is tiny and unnecessary to tile.", x.shape, self.pad * 2 + self.tile_size)
|
| 472 |
+
return self.net.original_forward(x)
|
| 473 |
+
else:
|
| 474 |
+
return self.vae_tile_forward(x)
|
| 475 |
+
finally:
|
| 476 |
+
pass
|
| 477 |
+
# self.net = self.net.to(original_device)
|
| 478 |
+
|
| 479 |
+
def get_best_tile_size(self, lowerbound, upperbound):
|
| 480 |
+
"""
|
| 481 |
+
Get the best tile size for GPU memory
|
| 482 |
+
"""
|
| 483 |
+
divider = 32
|
| 484 |
+
while divider >= 2:
|
| 485 |
+
remainer = lowerbound % divider
|
| 486 |
+
if remainer == 0:
|
| 487 |
+
return lowerbound
|
| 488 |
+
candidate = lowerbound - remainer + divider
|
| 489 |
+
if candidate <= upperbound:
|
| 490 |
+
return candidate
|
| 491 |
+
divider //= 2
|
| 492 |
+
return lowerbound
|
| 493 |
+
|
| 494 |
+
def split_tiles(self, h, w):
|
| 495 |
+
"""
|
| 496 |
+
Tool function to split the image into tiles
|
| 497 |
+
@param h: height of the image
|
| 498 |
+
@param w: width of the image
|
| 499 |
+
@return: tile_input_bboxes, tile_output_bboxes
|
| 500 |
+
"""
|
| 501 |
+
tile_input_bboxes, tile_output_bboxes = [], []
|
| 502 |
+
tile_size = self.tile_size
|
| 503 |
+
pad = self.pad
|
| 504 |
+
num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
|
| 505 |
+
num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
|
| 506 |
+
# If any of the numbers are 0, we let it be 1
|
| 507 |
+
# This is to deal with long and thin images
|
| 508 |
+
num_height_tiles = max(num_height_tiles, 1)
|
| 509 |
+
num_width_tiles = max(num_width_tiles, 1)
|
| 510 |
+
|
| 511 |
+
# Suggestions from https://github.com/Kahsolt: auto shrink the tile size
|
| 512 |
+
real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
|
| 513 |
+
real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
|
| 514 |
+
real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
|
| 515 |
+
real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)
|
| 516 |
+
|
| 517 |
+
print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
|
| 518 |
+
f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')
|
| 519 |
+
|
| 520 |
+
for i in range(num_height_tiles):
|
| 521 |
+
for j in range(num_width_tiles):
|
| 522 |
+
# bbox: [x1, x2, y1, y2]
|
| 523 |
+
# the padding is is unnessary for image borders. So we directly start from (32, 32)
|
| 524 |
+
input_bbox = [
|
| 525 |
+
pad + j * real_tile_width,
|
| 526 |
+
min(pad + (j + 1) * real_tile_width, w),
|
| 527 |
+
pad + i * real_tile_height,
|
| 528 |
+
min(pad + (i + 1) * real_tile_height, h),
|
| 529 |
+
]
|
| 530 |
+
|
| 531 |
+
# if the output bbox is close to the image boundary, we extend it to the image boundary
|
| 532 |
+
output_bbox = [
|
| 533 |
+
input_bbox[0] if input_bbox[0] > pad else 0,
|
| 534 |
+
input_bbox[1] if input_bbox[1] < w - pad else w,
|
| 535 |
+
input_bbox[2] if input_bbox[2] > pad else 0,
|
| 536 |
+
input_bbox[3] if input_bbox[3] < h - pad else h,
|
| 537 |
+
]
|
| 538 |
+
|
| 539 |
+
# scale to get the final output bbox
|
| 540 |
+
output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
|
| 541 |
+
tile_output_bboxes.append(output_bbox)
|
| 542 |
+
|
| 543 |
+
# indistinguishable expand the input bbox by pad pixels
|
| 544 |
+
tile_input_bboxes.append([
|
| 545 |
+
max(0, input_bbox[0] - pad),
|
| 546 |
+
min(w, input_bbox[1] + pad),
|
| 547 |
+
max(0, input_bbox[2] - pad),
|
| 548 |
+
min(h, input_bbox[3] + pad),
|
| 549 |
+
])
|
| 550 |
+
|
| 551 |
+
return tile_input_bboxes, tile_output_bboxes
|
| 552 |
+
|
| 553 |
+
@torch.no_grad()
|
| 554 |
+
def estimate_group_norm(self, z, task_queue, color_fix):
|
| 555 |
+
device = z.device
|
| 556 |
+
tile = z
|
| 557 |
+
last_id = len(task_queue) - 1
|
| 558 |
+
while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
|
| 559 |
+
last_id -= 1
|
| 560 |
+
if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
|
| 561 |
+
raise ValueError('No group norm found in the task queue')
|
| 562 |
+
# estimate until the last group norm
|
| 563 |
+
for i in range(last_id + 1):
|
| 564 |
+
task = task_queue[i]
|
| 565 |
+
if task[0] == 'pre_norm':
|
| 566 |
+
group_norm_func = GroupNormParam.from_tile(tile, task[1])
|
| 567 |
+
task_queue[i] = ('apply_norm', group_norm_func)
|
| 568 |
+
if i == last_id:
|
| 569 |
+
return True
|
| 570 |
+
tile = group_norm_func(tile)
|
| 571 |
+
elif task[0] == 'store_res':
|
| 572 |
+
task_id = i + 1
|
| 573 |
+
while task_id < last_id and task_queue[task_id][0] != 'add_res':
|
| 574 |
+
task_id += 1
|
| 575 |
+
if task_id >= last_id:
|
| 576 |
+
continue
|
| 577 |
+
task_queue[task_id][1] = task[1](tile)
|
| 578 |
+
elif task[0] == 'add_res':
|
| 579 |
+
tile += task[1].to(device)
|
| 580 |
+
task[1] = None
|
| 581 |
+
elif color_fix and task[0] == 'downsample':
|
| 582 |
+
for j in range(i, last_id + 1):
|
| 583 |
+
if task_queue[j][0] == 'store_res':
|
| 584 |
+
task_queue[j] = ('store_res_cpu', task_queue[j][1])
|
| 585 |
+
return True
|
| 586 |
+
else:
|
| 587 |
+
tile = task[1](tile)
|
| 588 |
+
try:
|
| 589 |
+
devices.test_for_nans(tile, "vae")
|
| 590 |
+
except:
|
| 591 |
+
print(f'Nan detected in fast mode estimation. Fast mode disabled.')
|
| 592 |
+
return False
|
| 593 |
+
|
| 594 |
+
raise IndexError('Should not reach here')
|
| 595 |
+
|
| 596 |
+
@perfcount
|
| 597 |
+
@torch.no_grad()
|
| 598 |
+
def vae_tile_forward(self, z):
|
| 599 |
+
"""
|
| 600 |
+
Decode a latent vector z into an image in a tiled manner.
|
| 601 |
+
@param z: latent vector
|
| 602 |
+
@return: image
|
| 603 |
+
"""
|
| 604 |
+
device = next(self.net.parameters()).device
|
| 605 |
+
net = self.net
|
| 606 |
+
tile_size = self.tile_size
|
| 607 |
+
is_decoder = self.is_decoder
|
| 608 |
+
|
| 609 |
+
z = z.detach() # detach the input to avoid backprop
|
| 610 |
+
|
| 611 |
+
N, height, width = z.shape[0], z.shape[2], z.shape[3]
|
| 612 |
+
net.last_z_shape = z.shape
|
| 613 |
+
|
| 614 |
+
# Split the input into tiles and build a task queue for each tile
|
| 615 |
+
print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')
|
| 616 |
+
|
| 617 |
+
in_bboxes, out_bboxes = self.split_tiles(height, width)
|
| 618 |
+
|
| 619 |
+
# Prepare tiles by split the input latents
|
| 620 |
+
tiles = []
|
| 621 |
+
for input_bbox in in_bboxes:
|
| 622 |
+
tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
|
| 623 |
+
tiles.append(tile)
|
| 624 |
+
|
| 625 |
+
num_tiles = len(tiles)
|
| 626 |
+
num_completed = 0
|
| 627 |
+
|
| 628 |
+
# Build task queues
|
| 629 |
+
single_task_queue = build_task_queue(net, is_decoder)
|
| 630 |
+
if self.fast_mode:
|
| 631 |
+
# Fast mode: downsample the input image to the tile size,
|
| 632 |
+
# then estimate the group norm parameters on the downsampled image
|
| 633 |
+
scale_factor = tile_size / max(height, width)
|
| 634 |
+
z = z.to(device)
|
| 635 |
+
downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
|
| 636 |
+
# use nearest-exact to keep statictics as close as possible
|
| 637 |
+
print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')
|
| 638 |
+
|
| 639 |
+
# ======= Special thanks to @Kahsolt for distribution shift issue ======= #
|
| 640 |
+
# The downsampling will heavily distort its mean and std, so we need to recover it.
|
| 641 |
+
std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
|
| 642 |
+
std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
|
| 643 |
+
downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
|
| 644 |
+
del std_old, mean_old, std_new, mean_new
|
| 645 |
+
# occasionally the std_new is too small or too large, which exceeds the range of float16
|
| 646 |
+
# so we need to clamp it to max z's range.
|
| 647 |
+
downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
|
| 648 |
+
estimate_task_queue = clone_task_queue(single_task_queue)
|
| 649 |
+
if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
|
| 650 |
+
single_task_queue = estimate_task_queue
|
| 651 |
+
del downsampled_z
|
| 652 |
+
|
| 653 |
+
task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]
|
| 654 |
+
|
| 655 |
+
# Dummy result
|
| 656 |
+
result = None
|
| 657 |
+
result_approx = None
|
| 658 |
+
try:
|
| 659 |
+
with devices.autocast():
|
| 660 |
+
result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu()
|
| 661 |
+
except: pass
|
| 662 |
+
# Free memory of input latent tensor
|
| 663 |
+
del z
|
| 664 |
+
|
| 665 |
+
# Task queue execution
|
| 666 |
+
pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")
|
| 667 |
+
pbar_comfy = comfy.utils.ProgressBar(num_tiles * len(task_queues[0]))
|
| 668 |
+
|
| 669 |
+
# execute the task back and forth when switch tiles so that we always
|
| 670 |
+
# keep one tile on the GPU to reduce unnecessary data transfer
|
| 671 |
+
forward = True
|
| 672 |
+
interrupted = False
|
| 673 |
+
state_interrupted = processing_interrupted()
|
| 674 |
+
#state.interrupted = interrupted
|
| 675 |
+
while True:
|
| 676 |
+
if state_interrupted: interrupted = True ; break
|
| 677 |
+
|
| 678 |
+
group_norm_param = GroupNormParam()
|
| 679 |
+
for i in range(num_tiles) if forward else reversed(range(num_tiles)):
|
| 680 |
+
if state_interrupted: interrupted = True ; break
|
| 681 |
+
|
| 682 |
+
tile = tiles[i].to(device)
|
| 683 |
+
input_bbox = in_bboxes[i]
|
| 684 |
+
task_queue = task_queues[i]
|
| 685 |
+
|
| 686 |
+
interrupted = False
|
| 687 |
+
while len(task_queue) > 0:
|
| 688 |
+
if state_interrupted: interrupted = True ; break
|
| 689 |
+
|
| 690 |
+
# DEBUG: current task
|
| 691 |
+
# print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape)
|
| 692 |
+
task = task_queue.pop(0)
|
| 693 |
+
if task[0] == 'pre_norm':
|
| 694 |
+
group_norm_param.add_tile(tile, task[1])
|
| 695 |
+
break
|
| 696 |
+
elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
|
| 697 |
+
task_id = 0
|
| 698 |
+
res = task[1](tile)
|
| 699 |
+
if not self.fast_mode or task[0] == 'store_res_cpu':
|
| 700 |
+
res = res.cpu()
|
| 701 |
+
while task_queue[task_id][0] != 'add_res':
|
| 702 |
+
task_id += 1
|
| 703 |
+
task_queue[task_id][1] = res
|
| 704 |
+
elif task[0] == 'add_res':
|
| 705 |
+
tile += task[1].to(device)
|
| 706 |
+
task[1] = None
|
| 707 |
+
else:
|
| 708 |
+
tile = task[1](tile)
|
| 709 |
+
pbar.update(1)
|
| 710 |
+
pbar_comfy.update(1)
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
if interrupted: break
|
| 714 |
+
|
| 715 |
+
# check for NaNs in the tile.
|
| 716 |
+
# If there are NaNs, we abort the process to save user's time
|
| 717 |
+
devices.test_for_nans(tile, "vae")
|
| 718 |
+
|
| 719 |
+
if len(task_queue) == 0:
|
| 720 |
+
tiles[i] = None
|
| 721 |
+
num_completed += 1
|
| 722 |
+
if result is None: # NOTE: dim C varies from different cases, can only be inited dynamically
|
| 723 |
+
result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
|
| 724 |
+
result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
|
| 725 |
+
del tile
|
| 726 |
+
elif i == num_tiles - 1 and forward:
|
| 727 |
+
forward = False
|
| 728 |
+
tiles[i] = tile
|
| 729 |
+
elif i == 0 and not forward:
|
| 730 |
+
forward = True
|
| 731 |
+
tiles[i] = tile
|
| 732 |
+
else:
|
| 733 |
+
tiles[i] = tile.cpu()
|
| 734 |
+
del tile
|
| 735 |
+
|
| 736 |
+
if interrupted: break
|
| 737 |
+
if num_completed == num_tiles: break
|
| 738 |
+
|
| 739 |
+
# insert the group norm task to the head of each task queue
|
| 740 |
+
group_norm_func = group_norm_param.summary()
|
| 741 |
+
if group_norm_func is not None:
|
| 742 |
+
for i in range(num_tiles):
|
| 743 |
+
task_queue = task_queues[i]
|
| 744 |
+
task_queue.insert(0, ('apply_norm', group_norm_func))
|
| 745 |
+
|
| 746 |
+
# Done!
|
| 747 |
+
pbar.close()
|
| 748 |
+
if interrupted:
|
| 749 |
+
del result, result_approx
|
| 750 |
+
comfy.model_management.throw_exception_if_processing_interrupted()
|
| 751 |
+
vae_dtype = comfy.model_management.vae_dtype()
|
| 752 |
+
return result.to(dtype=vae_dtype, device=device) if result is not None else result_approx.to(device=device, dtype=vae_dtype)
|
| 753 |
+
|
| 754 |
+
# from .tiled_vae import VAEHook, get_rcmd_enc_tsize, get_rcmd_dec_tsize
|
| 755 |
+
from nodes import VAEEncode, VAEDecode
|
| 756 |
+
class TiledVAE:
|
| 757 |
+
def process(self, *args, **kwargs):
|
| 758 |
+
samples = kwargs['samples'] if 'samples' in kwargs else (kwargs['pixels'] if 'pixels' in kwargs else args[0])
|
| 759 |
+
_vae = kwargs['vae'] if 'vae' in kwargs else args[1]
|
| 760 |
+
tile_size = kwargs['tile_size'] if 'tile_size' in kwargs else args[2]
|
| 761 |
+
fast = kwargs['fast'] if 'fast' in kwargs else args[3]
|
| 762 |
+
color_fix = kwargs['color_fix'] if 'color_fix' in kwargs else False
|
| 763 |
+
is_decoder = self.is_decoder
|
| 764 |
+
|
| 765 |
+
# for shorthand
|
| 766 |
+
vae = _vae.first_stage_model
|
| 767 |
+
encoder = vae.encoder
|
| 768 |
+
decoder = vae.decoder
|
| 769 |
+
|
| 770 |
+
# # undo hijack if disabled (in cases last time crashed)
|
| 771 |
+
# if not enabled:
|
| 772 |
+
# if self.hooked:
|
| 773 |
+
if isinstance(encoder.forward, VAEHook):
|
| 774 |
+
encoder.forward.net = None
|
| 775 |
+
encoder.forward = encoder.original_forward
|
| 776 |
+
if isinstance(decoder.forward, VAEHook):
|
| 777 |
+
decoder.forward.net = None
|
| 778 |
+
decoder.forward = decoder.original_forward
|
| 779 |
+
# self.hooked = False
|
| 780 |
+
# return
|
| 781 |
+
|
| 782 |
+
# if devices.get_optimal_device_name().startswith('cuda') and vae.device == devices.cpu and not vae_to_gpu:
|
| 783 |
+
# print("[Tiled VAE] warn: VAE is not on GPU, check 'Move VAE to GPU' if possible.")
|
| 784 |
+
|
| 785 |
+
# do hijack
|
| 786 |
+
# kwargs = {
|
| 787 |
+
# 'fast_decoder': fast_decoder,
|
| 788 |
+
# 'fast_encoder': fast_encoder,
|
| 789 |
+
# 'color_fix': color_fix,
|
| 790 |
+
# 'to_gpu': vae_to_gpu,
|
| 791 |
+
# }
|
| 792 |
+
|
| 793 |
+
# save original forward (only once)
|
| 794 |
+
if not hasattr(encoder, 'original_forward'): setattr(encoder, 'original_forward', encoder.forward)
|
| 795 |
+
if not hasattr(decoder, 'original_forward'): setattr(decoder, 'original_forward', decoder.forward)
|
| 796 |
+
|
| 797 |
+
# self.hooked = True
|
| 798 |
+
|
| 799 |
+
# encoder.forward = VAEHook(encoder, encoder_tile_size, is_decoder=False, **kwargs)
|
| 800 |
+
# decoder.forward = VAEHook(decoder, decoder_tile_size, is_decoder=True, **kwargs)
|
| 801 |
+
fn = VAEHook(net=decoder if is_decoder else encoder, tile_size=tile_size // 8 if is_decoder else tile_size,
|
| 802 |
+
is_decoder=is_decoder, fast_decoder=fast, fast_encoder=fast,
|
| 803 |
+
color_fix=color_fix, to_gpu=comfy.model_management.vae_device().type != 'cpu')
|
| 804 |
+
if is_decoder:
|
| 805 |
+
decoder.forward = fn
|
| 806 |
+
else:
|
| 807 |
+
encoder.forward = fn
|
| 808 |
+
|
| 809 |
+
ret = (None,)
|
| 810 |
+
try:
|
| 811 |
+
with devices.without_autocast():
|
| 812 |
+
if not is_decoder:
|
| 813 |
+
ret = VAEEncode().encode(_vae, samples)
|
| 814 |
+
else:
|
| 815 |
+
ret = VAEDecode().decode(_vae, samples) if is_decoder else VAEEncode().encode(_vae, samples)
|
| 816 |
+
finally:
|
| 817 |
+
if isinstance(encoder.forward, VAEHook):
|
| 818 |
+
encoder.forward.net = None
|
| 819 |
+
encoder.forward = encoder.original_forward
|
| 820 |
+
if isinstance(decoder.forward, VAEHook):
|
| 821 |
+
decoder.forward.net = None
|
| 822 |
+
decoder.forward = decoder.original_forward
|
| 823 |
+
return ret
|
| 824 |
+
|
| 825 |
+
class VAEEncodeTiled_TiledDiffusion(TiledVAE):
|
| 826 |
+
@classmethod
|
| 827 |
+
def INPUT_TYPES(s):
|
| 828 |
+
fast = True
|
| 829 |
+
tile_size = get_rcmd_enc_tsize()
|
| 830 |
+
return {"required": {"pixels": ("IMAGE", ),
|
| 831 |
+
"vae": ("VAE", ),
|
| 832 |
+
"tile_size": ("INT", {"default": tile_size, "min": 256, "max": 4096, "step": 16}),
|
| 833 |
+
"fast": ("BOOLEAN", {"default": fast}),
|
| 834 |
+
"color_fix": ("BOOLEAN", {"default": fast}),
|
| 835 |
+
}}
|
| 836 |
+
RETURN_TYPES = ("LATENT",)
|
| 837 |
+
FUNCTION = "process"
|
| 838 |
+
CATEGORY = "_for_testing"
|
| 839 |
+
|
| 840 |
+
def __init__(self):
|
| 841 |
+
self.is_decoder = False
|
| 842 |
+
super().__init__()
|
| 843 |
+
|
| 844 |
+
class VAEDecodeTiled_TiledDiffusion(TiledVAE):
|
| 845 |
+
@classmethod
|
| 846 |
+
def INPUT_TYPES(s):
|
| 847 |
+
tile_size = get_rcmd_dec_tsize() * opt_f
|
| 848 |
+
return {"required": {"samples": ("LATENT", ),
|
| 849 |
+
"vae": ("VAE", ),
|
| 850 |
+
"tile_size": ("INT", {"default": tile_size, "min": 48*opt_f, "max": 4096, "step": 16}),
|
| 851 |
+
"fast": ("BOOLEAN", {"default": True}),
|
| 852 |
+
}}
|
| 853 |
+
RETURN_TYPES = ("IMAGE",)
|
| 854 |
+
FUNCTION = "process"
|
| 855 |
+
CATEGORY = "_for_testing"
|
| 856 |
+
|
| 857 |
+
def __init__(self):
|
| 858 |
+
self.is_decoder = True
|
| 859 |
+
super().__init__()
|
| 860 |
+
|
| 861 |
+
NODE_CLASS_MAPPINGS = {
|
| 862 |
+
"VAEEncodeTiled_TiledDiffusion": VAEEncodeTiled_TiledDiffusion,
|
| 863 |
+
"VAEDecodeTiled_TiledDiffusion": VAEDecodeTiled_TiledDiffusion,
|
| 864 |
+
}
|
| 865 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
| 866 |
+
"VAEEncodeTiled_TiledDiffusion": "Tiled VAE Encode",
|
| 867 |
+
"VAEDecodeTiled_TiledDiffusion": "Tiled VAE Decode",
|
| 868 |
+
}
|
ComfyUI-TiledDiffusion/utils.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import importlib
|
| 3 |
+
from textwrap import dedent, indent
|
| 4 |
+
from copy import copy
|
| 5 |
+
import types
|
| 6 |
+
import functools
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import binascii
|
| 10 |
+
from typing import List, NamedTuple
|
| 11 |
+
|
| 12 |
+
class Hook(NamedTuple):
|
| 13 |
+
fn: object
|
| 14 |
+
module_name: str
|
| 15 |
+
target: str
|
| 16 |
+
orig_key: str
|
| 17 |
+
module_name_path: str
|
| 18 |
+
|
| 19 |
+
def gen_id():
|
| 20 |
+
return binascii.hexlify(os.urandom(1024))[64:72].decode("utf-8")
|
| 21 |
+
|
| 22 |
+
def hook_calc_cond_uncond_batch():
|
| 23 |
+
try:
|
| 24 |
+
from comfy.samplers import calc_cond_batch
|
| 25 |
+
calc_cond_batch_ = calc_cond_batch
|
| 26 |
+
except Exception:
|
| 27 |
+
from comfy.samplers import calc_cond_uncond_batch
|
| 28 |
+
calc_cond_batch_ = calc_cond_uncond_batch
|
| 29 |
+
# this function should only be run by us
|
| 30 |
+
orig_key = f"{calc_cond_batch_.__name__}_original_tiled_diffusion_{gen_id()}"
|
| 31 |
+
payload = [{
|
| 32 |
+
"mode": "replace",
|
| 33 |
+
"target_line": 'control.get_control',
|
| 34 |
+
"code_to_insert": """control if 'tiled_diffusion' in model_options else control.get_control"""
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"dedent": False,
|
| 38 |
+
"target_line": calc_cond_batch_.__name__,
|
| 39 |
+
"code_to_insert": f"""
|
| 40 |
+
if 'tiled_diffusion' not in model_options:
|
| 41 |
+
return {orig_key}{inspect.signature(calc_cond_batch_)}"""
|
| 42 |
+
}]
|
| 43 |
+
fn = inject_code(calc_cond_batch_, payload, 'w')
|
| 44 |
+
return create_hook(fn, 'comfy.samplers', orig_key=orig_key)
|
| 45 |
+
|
| 46 |
+
def hook_sag_create_blur_map():
|
| 47 |
+
imported = False
|
| 48 |
+
try:
|
| 49 |
+
import comfy_extras
|
| 50 |
+
from comfy_extras import nodes_sag
|
| 51 |
+
imported = True
|
| 52 |
+
except Exception: ...
|
| 53 |
+
if not imported: return
|
| 54 |
+
import comfy_extras
|
| 55 |
+
from comfy_extras import nodes_sag
|
| 56 |
+
import re
|
| 57 |
+
source=inspect.getsource(nodes_sag.create_blur_map)
|
| 58 |
+
replace_str="""
|
| 59 |
+
def calc_closest_factors(a):
|
| 60 |
+
for b in range(int(math.sqrt(a)), 0, -1):
|
| 61 |
+
if a%b == 0:
|
| 62 |
+
c = a // b
|
| 63 |
+
return (b,c)
|
| 64 |
+
m = calc_closest_factors(hw1)
|
| 65 |
+
mh = max(m) if lh > lw else min(m)
|
| 66 |
+
mw = m[1] if mh == m[0] else m[0]
|
| 67 |
+
mid_shape = mh, mw"""
|
| 68 |
+
modified_source = re.sub(r"ratio =.*\s+mid_shape =.*", replace_str, source, flags=re.MULTILINE)
|
| 69 |
+
fn = write_to_file_and_return_fn(nodes_sag.create_blur_map, modified_source)
|
| 70 |
+
return create_hook(fn, 'comfy_extras.nodes_sag')
|
| 71 |
+
|
| 72 |
+
def hook_samplers_pre_run_control():
|
| 73 |
+
from comfy.samplers import pre_run_control
|
| 74 |
+
payload = [{
|
| 75 |
+
"dedent": False,
|
| 76 |
+
"target_line": "if 'control' in x:",
|
| 77 |
+
"code_to_insert": """ try: x['control'].cleanup()\n except Exception: ..."""
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"target_line": "s = model.model_sampling",
|
| 81 |
+
"code_to_insert": """
|
| 82 |
+
def find_outer_instance(target:str, target_type):
|
| 83 |
+
import inspect
|
| 84 |
+
frame = inspect.currentframe()
|
| 85 |
+
i = 0
|
| 86 |
+
while frame and i < 7:
|
| 87 |
+
if (found:=frame.f_locals.get(target, None)) is not None:
|
| 88 |
+
if isinstance(found, target_type):
|
| 89 |
+
return found
|
| 90 |
+
frame = frame.f_back
|
| 91 |
+
i += 1
|
| 92 |
+
return None
|
| 93 |
+
from comfy.model_patcher import ModelPatcher
|
| 94 |
+
if (_model:=find_outer_instance('model', ModelPatcher)) is not None:
|
| 95 |
+
if (model_function_wrapper:=_model.model_options.get('model_function_wrapper', None)) is not None:
|
| 96 |
+
import sys
|
| 97 |
+
tiled_diffusion = sys.modules.get('ComfyUI-TiledDiffusion.tiled_diffusion', None)
|
| 98 |
+
if tiled_diffusion is None:
|
| 99 |
+
for key in sys.modules:
|
| 100 |
+
if 'tiled_diffusion' in key:
|
| 101 |
+
tiled_diffusion = sys.modules[key]
|
| 102 |
+
break
|
| 103 |
+
if (AbstractDiffusion:=getattr(tiled_diffusion, 'AbstractDiffusion', None)) is not None:
|
| 104 |
+
if isinstance(model_function_wrapper, AbstractDiffusion):
|
| 105 |
+
model_function_wrapper.reset()
|
| 106 |
+
"""}]
|
| 107 |
+
fn = inject_code(pre_run_control, payload)
|
| 108 |
+
return create_hook(fn, 'comfy.samplers')
|
| 109 |
+
|
| 110 |
+
def hook_gligen__set_position():
|
| 111 |
+
from comfy.gligen import Gligen
|
| 112 |
+
source=inspect.getsource(Gligen._set_position)
|
| 113 |
+
replace_str="""
|
| 114 |
+
nonlocal objs
|
| 115 |
+
if x.shape[0] > objs.shape[0]:
|
| 116 |
+
_objs = objs.repeat(-(x.shape[0] // -objs.shape[0]),1,1)
|
| 117 |
+
else:
|
| 118 |
+
_objs = objs
|
| 119 |
+
return module(x, _objs)"""
|
| 120 |
+
modified_source = dedent(source.replace(" return module(x, objs)", replace_str, 1))
|
| 121 |
+
fn = write_to_file_and_return_fn(Gligen._set_position, modified_source)
|
| 122 |
+
return create_hook(fn, 'comfy.gligen', 'Gligen._set_position')
|
| 123 |
+
|
| 124 |
+
def create_hook(fn, module_name:str, target = None, orig_key = None):
|
| 125 |
+
if target is None: target = fn.__name__
|
| 126 |
+
if orig_key is None: orig_key = f'{target}_original'
|
| 127 |
+
module_name_path = os.path.normpath(module_name.replace('.', '/'))
|
| 128 |
+
return Hook(fn, module_name, target, orig_key, module_name_path)
|
| 129 |
+
|
| 130 |
+
def _getattr(obj, name:str, default=None):
|
| 131 |
+
"""multi-level getattr"""
|
| 132 |
+
for attr in name.split('.'):
|
| 133 |
+
obj = getattr(obj, attr, default)
|
| 134 |
+
return obj
|
| 135 |
+
|
| 136 |
+
def _hasattr(obj, name:str):
|
| 137 |
+
"""multi-level hasattr"""
|
| 138 |
+
return _getattr(obj, name) is not None
|
| 139 |
+
|
| 140 |
+
def _setattr(obj, name:str, value=None):
|
| 141 |
+
"""multi-level setattr"""
|
| 142 |
+
split = name.split('.')
|
| 143 |
+
if not split[:-1]:
|
| 144 |
+
return setattr(obj, name, value)
|
| 145 |
+
else:
|
| 146 |
+
name = split[-1]
|
| 147 |
+
for attr in split[:-1]:
|
| 148 |
+
obj = getattr(obj, attr, None)
|
| 149 |
+
return setattr(obj, name, value)
|
| 150 |
+
|
| 151 |
+
def hook_all(restore=False, hooks=None):
|
| 152 |
+
if hooks is None:
|
| 153 |
+
hooks: List[Hook] = [
|
| 154 |
+
hook_calc_cond_uncond_batch(),
|
| 155 |
+
hook_sag_create_blur_map(),
|
| 156 |
+
hook_samplers_pre_run_control(),
|
| 157 |
+
hook_gligen__set_position(),
|
| 158 |
+
]
|
| 159 |
+
for key, module in sys.modules.items():
|
| 160 |
+
for hook in hooks:
|
| 161 |
+
if key == hook.module_name or key.endswith(hook.module_name_path):
|
| 162 |
+
if _hasattr(module, hook.target):
|
| 163 |
+
if not _hasattr(module, hook.orig_key):
|
| 164 |
+
if (orig_fn:=_getattr(module, hook.target, None)) is not None:
|
| 165 |
+
_setattr(module, hook.orig_key, orig_fn)
|
| 166 |
+
if restore:
|
| 167 |
+
_setattr(module, hook.target, _getattr(module, hook.orig_key, None))
|
| 168 |
+
else:
|
| 169 |
+
_setattr(module, hook.target, hook.fn)
|
| 170 |
+
|
| 171 |
+
def inject_code(original_func, data, mode='a'):
|
| 172 |
+
# Get the source code of the original function
|
| 173 |
+
original_source = inspect.getsource(original_func)
|
| 174 |
+
|
| 175 |
+
# Split the source code into lines
|
| 176 |
+
lines = original_source.split("\n")
|
| 177 |
+
|
| 178 |
+
for item in data:
|
| 179 |
+
# Find the line number of the target line
|
| 180 |
+
target_line_number = None
|
| 181 |
+
for i, line in enumerate(lines):
|
| 182 |
+
if item['target_line'] not in line: continue
|
| 183 |
+
target_line_number = i + 1
|
| 184 |
+
if item.get("mode","insert") == "replace":
|
| 185 |
+
lines[i] = lines[i].replace(item['target_line'], item['code_to_insert'])
|
| 186 |
+
break
|
| 187 |
+
|
| 188 |
+
# Find the indentation of the line where the new code will be inserted
|
| 189 |
+
indentation = ''
|
| 190 |
+
for char in line:
|
| 191 |
+
if char == ' ':
|
| 192 |
+
indentation += char
|
| 193 |
+
else:
|
| 194 |
+
break
|
| 195 |
+
|
| 196 |
+
# Indent the new code to match the original
|
| 197 |
+
code_to_insert = item['code_to_insert']
|
| 198 |
+
if item.get("dedent",True):
|
| 199 |
+
code_to_insert = dedent(item['code_to_insert'])
|
| 200 |
+
code_to_insert = indent(code_to_insert, indentation)
|
| 201 |
+
|
| 202 |
+
break
|
| 203 |
+
|
| 204 |
+
# Insert the code to be injected after the target line
|
| 205 |
+
if item.get("mode","insert") == "insert" and target_line_number is not None:
|
| 206 |
+
lines.insert(target_line_number, code_to_insert)
|
| 207 |
+
|
| 208 |
+
# Recreate the modified source code
|
| 209 |
+
modified_source = "\n".join(lines)
|
| 210 |
+
modified_source = dedent(modified_source.strip("\n"))
|
| 211 |
+
return write_to_file_and_return_fn(original_func, modified_source, mode)
|
| 212 |
+
|
| 213 |
+
def write_to_file_and_return_fn(original_func, source:str, mode='a'):
|
| 214 |
+
# Write the modified source code to a temporary file so the
|
| 215 |
+
# source code and stack traces can still be viewed when debugging.
|
| 216 |
+
custom_name = ".patches.py"
|
| 217 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 218 |
+
temp_file_path = os.path.join(current_dir, custom_name)
|
| 219 |
+
with open(temp_file_path, mode) as temp_file:
|
| 220 |
+
temp_file.write(source)
|
| 221 |
+
temp_file.write("\n")
|
| 222 |
+
temp_file.flush()
|
| 223 |
+
|
| 224 |
+
MODULE_PATH = temp_file.name
|
| 225 |
+
MODULE_NAME = __name__.split('.')[0].replace('-','_') + "_patch_modules"
|
| 226 |
+
spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)
|
| 227 |
+
module = importlib.util.module_from_spec(spec)
|
| 228 |
+
sys.modules[spec.name] = module
|
| 229 |
+
spec.loader.exec_module(module)
|
| 230 |
+
|
| 231 |
+
# Retrieve the modified function from the module
|
| 232 |
+
modified_function = getattr(module, original_func.__name__)
|
| 233 |
+
|
| 234 |
+
# Adapted from https://stackoverflow.com/a/49077211
|
| 235 |
+
def copy_func(f, globals=None, module=None, code=None, update_wrapper=True):
|
| 236 |
+
if globals is None: globals = f.__globals__
|
| 237 |
+
if code is None: code = f.__code__
|
| 238 |
+
g = types.FunctionType(code, globals, name=f.__name__,
|
| 239 |
+
argdefs=f.__defaults__, closure=f.__closure__)
|
| 240 |
+
if update_wrapper: g = functools.update_wrapper(g, f)
|
| 241 |
+
if module is not None: g.__module__ = module
|
| 242 |
+
g.__kwdefaults__ = copy(f.__kwdefaults__)
|
| 243 |
+
return g
|
| 244 |
+
|
| 245 |
+
return copy_func(original_func, code=modified_function.__code__, update_wrapper=False)
|
| 246 |
+
|