asffhgjasdasfdsa commited on
Commit
5e37875
·
1 Parent(s): 099172c

Upload files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -35
  2. .gitignore +23 -0
  3. README.md +5 -6
  4. alembic.ini +84 -0
  5. alembic_db/README.md +4 -0
  6. alembic_db/env.py +64 -0
  7. alembic_db/script.py.mako +28 -0
  8. alembic_db/versions/0001_assets.py +174 -0
  9. api_server/__init__.py +0 -0
  10. api_server/routes/__init__.py +0 -0
  11. api_server/routes/internal/README.md +3 -0
  12. api_server/routes/internal/__init__.py +0 -0
  13. api_server/routes/internal/internal_routes.py +78 -0
  14. api_server/services/__init__.py +0 -0
  15. api_server/services/terminal_service.py +60 -0
  16. api_server/utils/file_operations.py +42 -0
  17. app.py +309 -0
  18. app/__init__.py +0 -0
  19. app/app_settings.py +65 -0
  20. app/assets/api/routes.py +102 -0
  21. app/assets/api/schemas_in.py +94 -0
  22. app/assets/api/schemas_out.py +60 -0
  23. app/assets/database/bulk_ops.py +204 -0
  24. app/assets/database/models.py +233 -0
  25. app/assets/database/queries.py +267 -0
  26. app/assets/database/tags.py +62 -0
  27. app/assets/hashing.py +75 -0
  28. app/assets/helpers.py +217 -0
  29. app/assets/manager.py +123 -0
  30. app/assets/scanner.py +229 -0
  31. app/custom_node_manager.py +145 -0
  32. app/database/db.py +112 -0
  33. app/database/models.py +21 -0
  34. app/frontend_management.py +457 -0
  35. app/logger.py +98 -0
  36. app/model_manager.py +195 -0
  37. app/subgraph_manager.py +132 -0
  38. app/user_manager.py +456 -0
  39. blueprints/put_blueprints_here +0 -0
  40. comfy/audio_encoders/audio_encoders.py +91 -0
  41. comfy/audio_encoders/wav2vec2.py +252 -0
  42. comfy/audio_encoders/whisper.py +186 -0
  43. comfy/checkpoint_pickle.py +13 -0
  44. comfy/cldm/cldm.py +434 -0
  45. comfy/cldm/control_types.py +10 -0
  46. comfy/cldm/dit_embedder.py +120 -0
  47. comfy/cldm/mmdit.py +81 -0
  48. comfy/cli_args.py +259 -0
  49. comfy/clip_config_bigg.json +23 -0
  50. comfy/clip_model.py +331 -0
.gitattributes CHANGED
@@ -1,35 +1,4 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ /web/assets/** linguist-generated
2
+ /web/** linguist-vendored
3
+ comfy_api_nodes/apis/__init__.py linguist-generated
4
+ comfy/text_encoders/t5_pile_tokenizer/tokenizer.model filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ /output/
4
+ /input/
5
+ !/input/example.png
6
+ /models/
7
+ /temp/
8
+ /.vs
9
+ .vscode/
10
+ .idea/
11
+ venv/
12
+ .venv/
13
+ /web/extensions/*
14
+ !/web/extensions/logging.js.example
15
+ !/web/extensions/core/
16
+ /tests-ui/data/object_info.json
17
+ /user/
18
+ *.log
19
+ web_custom_versions/
20
+ .DS_Store
21
+ openapi.yaml
22
+ filtered-openapi.yaml
23
+ uv.lock
README.md CHANGED
@@ -1,13 +1,12 @@
1
  ---
2
- title: Cfui
3
- emoji: 🐨
4
- colorFrom: gray
5
- colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 6.3.0
8
- python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Comfytest
3
+ emoji: 😻
4
+ colorFrom: pink
5
+ colorTo: purple
6
  sdk: gradio
7
  sdk_version: 6.3.0
 
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
alembic.ini ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A generic, single database configuration.
2
+
3
+ [alembic]
4
+ # path to migration scripts
5
+ # Use forward slashes (/) also on windows to provide an os agnostic path
6
+ script_location = alembic_db
7
+
8
+ # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
9
+ # Uncomment the line below if you want the files to be prepended with date and time
10
+ # see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
11
+ # for all available tokens
12
+ # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
13
+
14
+ # sys.path path, will be prepended to sys.path if present.
15
+ # defaults to the current working directory.
16
+ prepend_sys_path = .
17
+
18
+ # timezone to use when rendering the date within the migration file
19
+ # as well as the filename.
20
+ # If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
21
+ # Any required deps can installed by adding `alembic[tz]` to the pip requirements
22
+ # string value is passed to ZoneInfo()
23
+ # leave blank for localtime
24
+ # timezone =
25
+
26
+ # max length of characters to apply to the "slug" field
27
+ # truncate_slug_length = 40
28
+
29
+ # set to 'true' to run the environment during
30
+ # the 'revision' command, regardless of autogenerate
31
+ # revision_environment = false
32
+
33
+ # set to 'true' to allow .pyc and .pyo files without
34
+ # a source .py file to be detected as revisions in the
35
+ # versions/ directory
36
+ # sourceless = false
37
+
38
+ # version location specification; This defaults
39
+ # to alembic_db/versions. When using multiple version
40
+ # directories, initial revisions must be specified with --version-path.
41
+ # The path separator used here should be the separator specified by "version_path_separator" below.
42
+ # version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
43
+
44
+ # version path separator; As mentioned above, this is the character used to split
45
+ # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
46
+ # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
47
+ # Valid values for version_path_separator are:
48
+ #
49
+ # version_path_separator = :
50
+ # version_path_separator = ;
51
+ # version_path_separator = space
52
+ # version_path_separator = newline
53
+ #
54
+ # Use os.pathsep. Default configuration used for new projects.
55
+ version_path_separator = os
56
+
57
+ # set to 'true' to search source files recursively
58
+ # in each "version_locations" directory
59
+ # new in Alembic version 1.10
60
+ # recursive_version_locations = false
61
+
62
+ # the output encoding used when revision files
63
+ # are written from script.py.mako
64
+ # output_encoding = utf-8
65
+
66
+ sqlalchemy.url = sqlite:///user/comfyui.db
67
+
68
+
69
+ [post_write_hooks]
70
+ # post_write_hooks defines scripts or Python functions that are run
71
+ # on newly generated revision scripts. See the documentation for further
72
+ # detail and examples
73
+
74
+ # format using "black" - use the console_scripts runner, against the "black" entrypoint
75
+ # hooks = black
76
+ # black.type = console_scripts
77
+ # black.entrypoint = black
78
+ # black.options = -l 79 REVISION_SCRIPT_FILENAME
79
+
80
+ # lint with attempts to fix using "ruff" - use the exec runner, execute a binary
81
+ # hooks = ruff
82
+ # ruff.type = exec
83
+ # ruff.executable = %(here)s/.venv/bin/ruff
84
+ # ruff.options = check --fix REVISION_SCRIPT_FILENAME
alembic_db/README.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ## Generate new revision
2
+
3
+ 1. Update models in `/app/database/models.py`
4
+ 2. Run `alembic revision --autogenerate -m "{your message}"`
alembic_db/env.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config


# Import the project's declarative base so autogenerate can diff the
# live database schema against the ORM models.
from app.database.models import Base
target_metadata = Base.metadata
13
+
14
+ # other values from the config, defined by the needs of env.py,
15
+ # can be acquired:
16
+ # my_important_option = config.get_main_option("my_important_option")
17
+ # ... etc.
18
+
19
+
20
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    Configures the Alembic context with just a database URL and no
    Engine, so no DBAPI needs to be available. Calls to
    context.execute() emit the given string to the script output
    instead of executing it.
    """
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    with context.begin_transaction():
        context.run_migrations()
40
+
41
def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    Builds an Engine from the [alembic] section of the ini file and
    runs the migrations over a live connection.
    """
    engine = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    with engine.connect() as conn:
        context.configure(connection=conn, target_metadata=target_metadata)
        with context.begin_transaction():
            context.run_migrations()
59
+
60
+
61
# Entry point: Alembic selects offline mode for `alembic upgrade --sql`.
(run_migrations_offline if context.is_offline_mode() else run_migrations_online)()
alembic_db/script.py.mako ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """${message}
2
+
3
+ Revision ID: ${up_revision}
4
+ Revises: ${down_revision | comma,n}
5
+ Create Date: ${create_date}
6
+
7
+ """
8
+ from typing import Sequence, Union
9
+
10
+ from alembic import op
11
+ import sqlalchemy as sa
12
+ ${imports if imports else ""}
13
+
14
+ # revision identifiers, used by Alembic.
15
+ revision: str = ${repr(up_revision)}
16
+ down_revision: Union[str, None] = ${repr(down_revision)}
17
+ branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
18
+ depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
19
+
20
+
21
+ def upgrade() -> None:
22
+ """Upgrade schema."""
23
+ ${upgrades if upgrades else "pass"}
24
+
25
+
26
+ def downgrade() -> None:
27
+ """Downgrade schema."""
28
+ ${downgrades if downgrades else "pass"}
alembic_db/versions/0001_assets.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Initial assets schema
3
+ Revision ID: 0001_assets
4
+ Revises: None
5
+ Create Date: 2025-12-10 00:00:00
6
+ """
7
+
8
+ from alembic import op
9
+ import sqlalchemy as sa
10
+
11
+ revision = "0001_assets"
12
+ down_revision = None
13
+ branch_labels = None
14
+ depends_on = None
15
+
16
+
17
def upgrade() -> None:
    """Create the initial asset-management schema and seed the system tags."""
    # ASSETS: content identity
    op.create_table(
        "assets",
        sa.Column("id", sa.String(length=36), primary_key=True),
        sa.Column("hash", sa.String(length=256), nullable=True),
        sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
        sa.Column("mime_type", sa.String(length=255), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
        sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
    )
    # Hash is unique when present (content-addressed identity); nullable, so
    # multiple un-hashed rows are allowed.
    op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
    op.create_index("ix_assets_mime_type", "assets", ["mime_type"])

    # ASSETS_INFO: user-visible references
    op.create_table(
        "assets_info",
        sa.Column("id", sa.String(length=36), primary_key=True),
        sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
        sa.Column("name", sa.String(length=512), nullable=False),
        sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
        sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
        sa.Column("user_metadata", sa.JSON(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
        sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
        sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
    )
    op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
    op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
    op.create_index("ix_assets_info_name", "assets_info", ["name"])
    op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
    op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
    op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])

    # TAGS: normalized tag vocabulary
    op.create_table(
        "tags",
        sa.Column("name", sa.String(length=512), primary_key=True),
        sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
        sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
    )
    op.create_index("ix_tags_tag_type", "tags", ["tag_type"])

    # ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
    op.create_table(
        "asset_info_tags",
        sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
        sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
        sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
        sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
        sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
    )
    op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
    op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])

    # ASSET_CACHE_STATE: N:1 local cache rows per Asset
    op.create_table(
        "asset_cache_state",
        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
        sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
        sa.Column("file_path", sa.Text(), nullable=False),  # absolute local path to cached file
        sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
        sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
        sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
        sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
    )
    op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
    op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])

    # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
    op.create_table(
        "asset_info_meta",
        sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
        sa.Column("key", sa.String(length=256), nullable=False),
        sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("val_str", sa.String(length=2048), nullable=True),
        sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
        sa.Column("val_bool", sa.Boolean(), nullable=True),
        sa.Column("val_json", sa.JSON(), nullable=True),
        sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
    )
    op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
    op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
    op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
    op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])

    # Tags vocabulary: lightweight table handles (no ORM) for bulk insert.
    tags_table = sa.table(
        "tags",
        sa.column("name", sa.String(length=512)),
        sa.column("tag_type", sa.String()),
    )
    # Seed the built-in "system" tag names used by the scanner/manager.
    op.bulk_insert(
        tags_table,
        [
            {"name": "models", "tag_type": "system"},
            {"name": "input", "tag_type": "system"},
            {"name": "output", "tag_type": "system"},

            {"name": "configs", "tag_type": "system"},
            {"name": "checkpoints", "tag_type": "system"},
            {"name": "loras", "tag_type": "system"},
            {"name": "vae", "tag_type": "system"},
            {"name": "text_encoders", "tag_type": "system"},
            {"name": "diffusion_models", "tag_type": "system"},
            {"name": "clip_vision", "tag_type": "system"},
            {"name": "style_models", "tag_type": "system"},
            {"name": "embeddings", "tag_type": "system"},
            {"name": "diffusers", "tag_type": "system"},
            {"name": "vae_approx", "tag_type": "system"},
            {"name": "controlnet", "tag_type": "system"},
            {"name": "gligen", "tag_type": "system"},
            {"name": "upscale_models", "tag_type": "system"},
            {"name": "hypernetworks", "tag_type": "system"},
            {"name": "photomaker", "tag_type": "system"},
            {"name": "classifiers", "tag_type": "system"},

            {"name": "encoder", "tag_type": "system"},
            {"name": "decoder", "tag_type": "system"},

            {"name": "missing", "tag_type": "system"},
            {"name": "rescan", "tag_type": "system"},
        ],
    )
142
+
143
+
144
def downgrade() -> None:
    """Drop the assets schema in reverse dependency order.

    The unique constraints declared inline in upgrade()
    (uq_assets_info_asset_owner_name, uq_asset_cache_state_file_path) are
    removed together with their tables by drop_table. Explicit
    op.drop_constraint calls are intentionally omitted: the configured
    backend (sqlalchemy.url = sqlite:///...) does not support
    ALTER TABLE ... DROP CONSTRAINT, so those calls would raise unless
    wrapped in a batch_alter_table, and they are redundant before a
    drop_table anyway.
    """
    op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
    op.drop_table("asset_info_meta")

    op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
    op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
    op.drop_table("asset_cache_state")

    op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
    op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
    op.drop_table("asset_info_tags")

    op.drop_index("ix_tags_tag_type", table_name="tags")
    op.drop_table("tags")

    op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
    op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
    op.drop_index("ix_assets_info_created_at", table_name="assets_info")
    op.drop_index("ix_assets_info_name", table_name="assets_info")
    op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
    op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
    op.drop_table("assets_info")

    op.drop_index("uq_assets_hash", table_name="assets")
    op.drop_index("ix_assets_mime_type", table_name="assets")
    op.drop_table("assets")
api_server/__init__.py ADDED
File without changes
api_server/routes/__init__.py ADDED
File without changes
api_server/routes/internal/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # ComfyUI Internal Routes
2
+
3
+ All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications may change at any time without notice.
api_server/routes/internal/__init__.py ADDED
File without changes
api_server/routes/internal/internal_routes.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from aiohttp import web
2
+ from typing import Optional
3
+ from folder_paths import folder_names_and_paths, get_directory_by_type
4
+ from api_server.services.terminal_service import TerminalService
5
+ import app.logger
6
+ import os
7
+
8
+ class InternalRoutes:
9
+ '''
10
+ The top level web router for internal routes: /internal/*
11
+ The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
12
+ Check README.md for more information.
13
+ '''
14
+
15
+ def __init__(self, prompt_server):
16
+ self.routes: web.RouteTableDef = web.RouteTableDef()
17
+ self._app: Optional[web.Application] = None
18
+ self.prompt_server = prompt_server
19
+ self.terminal_service = TerminalService(prompt_server)
20
+
21
+ def setup_routes(self):
22
+ @self.routes.get('/logs')
23
+ async def get_logs(request):
24
+ return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
25
+
26
+ @self.routes.get('/logs/raw')
27
+ async def get_raw_logs(request):
28
+ self.terminal_service.update_size()
29
+ return web.json_response({
30
+ "entries": list(app.logger.get_logs()),
31
+ "size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
32
+ })
33
+
34
+ @self.routes.patch('/logs/subscribe')
35
+ async def subscribe_logs(request):
36
+ json_data = await request.json()
37
+ client_id = json_data["clientId"]
38
+ enabled = json_data["enabled"]
39
+ if enabled:
40
+ self.terminal_service.subscribe(client_id)
41
+ else:
42
+ self.terminal_service.unsubscribe(client_id)
43
+
44
+ return web.Response(status=200)
45
+
46
+
47
+ @self.routes.get('/folder_paths')
48
+ async def get_folder_paths(request):
49
+ response = {}
50
+ for key in folder_names_and_paths:
51
+ response[key] = folder_names_and_paths[key][0]
52
+ return web.json_response(response)
53
+
54
+ @self.routes.get('/files/{directory_type}')
55
+ async def get_files(request: web.Request) -> web.Response:
56
+ directory_type = request.match_info['directory_type']
57
+ if directory_type not in ("output", "input", "temp"):
58
+ return web.json_response({"error": "Invalid directory type"}, status=400)
59
+
60
+ directory = get_directory_by_type(directory_type)
61
+
62
+ def is_visible_file(entry: os.DirEntry) -> bool:
63
+ """Filter out hidden files (e.g., .DS_Store on macOS)."""
64
+ return entry.is_file() and not entry.name.startswith('.')
65
+
66
+ sorted_files = sorted(
67
+ (entry for entry in os.scandir(directory) if is_visible_file(entry)),
68
+ key=lambda entry: -entry.stat().st_mtime
69
+ )
70
+ return web.json_response([entry.name for entry in sorted_files], status=200)
71
+
72
+
73
+ def get_app(self):
74
+ if self._app is None:
75
+ self._app = web.Application()
76
+ self.setup_routes()
77
+ self._app.add_routes(self.routes)
78
+ return self._app
api_server/services/__init__.py ADDED
File without changes
api_server/services/terminal_service.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.logger import on_flush
2
+ import os
3
+ import shutil
4
+
5
+
6
class TerminalService:
    """Tracks the terminal dimensions and forwards flushed log entries to
    subscribed websocket clients."""

    def __init__(self, server):
        self.server = server
        self.cols = None
        self.rows = None
        self.subscriptions = set()
        # Deliver every flushed batch of log entries to subscribers.
        on_flush(self.send_messages)

    def get_terminal_size(self):
        """Return (columns, lines); fall back to 80x24 if nothing works."""
        try:
            dims = os.get_terminal_size()
        except OSError:
            try:
                dims = shutil.get_terminal_size()
            except OSError:
                return (80, 24)  # fallback to 80x24
        return (dims.columns, dims.lines)

    def update_size(self):
        """Refresh the cached size.

        Returns the new size dict when either dimension changed,
        otherwise None.
        """
        cols, rows = self.get_terminal_size()
        if (cols, rows) == (self.cols, self.rows):
            return None
        self.cols = cols
        self.rows = rows
        return {"cols": self.cols, "rows": self.rows}

    def subscribe(self, client_id):
        self.subscriptions.add(client_id)

    def unsubscribe(self, client_id):
        self.subscriptions.discard(client_id)

    def send_messages(self, entries):
        """Push log entries (plus any size change) to all live subscribers."""
        if not entries or not self.subscriptions:
            return

        new_size = self.update_size()

        # Iterate over a snapshot: we may unsubscribe while walking.
        for client_id in list(self.subscriptions):
            if client_id not in self.server.sockets:
                # Automatically unsub if the socket has disconnected.
                self.unsubscribe(client_id)
                continue
            self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
api_server/utils/file_operations.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from typing import List, Union, TypedDict, Literal
from typing_extensions import TypeGuard


class FileInfo(TypedDict):
    """A regular file found during a directory walk."""
    # Base name of the file.
    name: str
    # Path relative to the walked root directory.
    path: str
    type: Literal["file"]
    # Size in bytes at the time of the walk.
    size: int


class DirectoryInfo(TypedDict):
    """A directory found during a directory walk."""
    # Base name of the directory.
    name: str
    # Path relative to the walked root directory.
    path: str
    type: Literal["directory"]


# Tagged union of walk results, discriminated by the "type" field.
FileSystemItem = Union[FileInfo, DirectoryInfo]


def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
    """Narrow a FileSystemItem to FileInfo via its "type" tag."""
    return item["type"] == "file"
19
+
20
class FileSystemOperations:
    """Helpers for enumerating a directory tree."""

    @staticmethod
    def walk_directory(directory: str) -> List[FileSystemItem]:
        """Recursively list every file and directory under *directory*.

        Each entry's "path" is relative to *directory*; file entries also
        carry their size in bytes.
        """
        collected: List[FileSystemItem] = []
        for root, dirs, files in os.walk(directory):
            for file_name in files:
                absolute = os.path.join(root, file_name)
                collected.append({
                    "name": file_name,
                    "path": os.path.relpath(absolute, directory),
                    "type": "file",
                    "size": os.path.getsize(absolute)
                })
            for dir_name in dirs:
                absolute = os.path.join(root, dir_name)
                collected.append({
                    "name": dir_name,
                    "path": os.path.relpath(absolute, directory),
                    "type": "directory"
                })
        return collected
app.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import spaces
import os
import random
import sys
from typing import Sequence, Mapping, Any, Union
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from comfy import model_management

# Fetch the model weights this Space needs at startup. hf_hub_download
# caches downloads, so repeated launches do not re-fetch the files.
hf_hub_download(repo_id="John6666/zuki-cute-ill-v60-sdxl", filename="zukiCuteILL_v60.safetensors", local_dir="models/checkpoints")
hf_hub_download(repo_id="ximso/RealESRGAN_x4plus_anime_6B", filename="RealESRGAN_x4plus_anime_6B.pth", local_dir="models/upscale_models")
13
+
14
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Returns the value at the given index of a sequence or mapping.

    If the object is a sequence (like list or string), returns the value at the given index.
    If the object is a mapping (like a dictionary), *index* is first used directly as a key;
    if that raises KeyError, the lookup falls back to the "result" key and indexes into its
    value (node outputs are often shaped like {"result": [...]}).

    Args:
        obj (Union[Sequence, Mapping]): The object to retrieve the value from.
        index (int): The index of the value to retrieve.

    Returns:
        Any: The value at the given index.

    Raises:
        IndexError: If the index is out of bounds for the object and the object is not a mapping.
        KeyError: If a mapping contains neither *index* nor "result" as a key.
    """
    # NOTE: the docstring previously said the fallback key was "results",
    # but the code reads "result" — the docstring now matches the code.
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]
36
+
37
+
38
def find_path(name: str, path: str = None) -> str:
    """
    Walk upward from *path* (default: the current working directory) until a
    directory containing *name* is found.
    Returns the full path to the match, or None once the filesystem root is
    reached without finding it.
    """
    current = os.getcwd() if path is None else path

    while True:
        # Does the current directory contain the target entry?
        if name in os.listdir(current):
            found = os.path.join(current, name)
            print(f"{name} found: {found}")
            return found

        parent = os.path.dirname(current)
        # dirname() of the root is the root itself: stop searching.
        if parent == current:
            return None
        current = parent
62
+
63
+
64
def add_comfyui_directory_to_sys_path() -> None:
    """
    Add 'ComfyUI' to the sys.path
    """
    found = find_path("ComfyUI")
    # Guard clause: do nothing unless an existing directory was located.
    if found is None or not os.path.isdir(found):
        return
    sys.path.append(found)
    print(f"'{found}' added to sys.path")
72
+
73
+
74
def add_extra_model_paths() -> None:
    """
    Parse the optional extra_model_paths.yaml file and add the parsed paths to the sys.path.
    """
    # The loader moved between releases: prefer main.py, fall back to utils.
    try:
        from main import load_extra_path_config
    except ImportError:
        print(
            "Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead."
        )
        from utils.extra_config import load_extra_path_config

    config_path = find_path("extra_model_paths.yaml")
    if config_path is None:
        print("Could not find the extra_model_paths config file.")
    else:
        load_extra_path_config(config_path)
92
+
93
+
94
+ add_comfyui_directory_to_sys_path()
95
+ add_extra_model_paths()
96
+
97
+
98
def import_custom_nodes() -> None:
    """Find all custom nodes in the custom_nodes folder and add those node objects to NODE_CLASS_MAPPINGS

    This function sets up a new asyncio event loop, initializes the PromptServer,
    creates a PromptQueue, and initializes the custom nodes.
    """
    # Imported lazily so this module can be loaded before ComfyUI is on sys.path.
    import asyncio
    import execution
    from nodes import init_extra_nodes
    import server

    # Creating a new event loop and setting it as the default loop
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Creating an instance of PromptServer with the loop.
    # PromptQueue registers itself on the server instance as a side effect.
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)

    # Initializing custom nodes.
    # NOTE(review): asyncio.run() creates and closes its *own* loop, so
    # init_extra_nodes does not actually run on `loop` above — confirm
    # PromptServer does not rely on the loop having executed this coroutine.
    asyncio.run(init_extra_nodes())
119
+
120
+
121
from nodes import NODE_CLASS_MAPPINGS
from comfy_extras.nodes_upscale_model import UpscaleModelLoader

# Populate NODE_CLASS_MAPPINGS with custom nodes before instantiating any node.
import_custom_nodes()

# Instantiate the node objects once at import time; generate_image() reuses them.
checkpointloadersimple = NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"]()
# Tuple of (model, clip, vae) for the checkpoint — indexed via get_value_at_index.
checkpointloadersimple_4 = checkpointloadersimple.load_checkpoint(
    ckpt_name="zukiCuteILL_v60.safetensors"
)
cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()

loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
vaeencode = NODE_CLASS_MAPPINGS["VAEEncode"]()
conditioningconcat = NODE_CLASS_MAPPINGS["ConditioningConcat"]()
repeatlatentbatch = NODE_CLASS_MAPPINGS["RepeatLatentBatch"]()
ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()
# NOTE(review): execute() is called on the class, not an instance — confirm
# UpscaleModelLoader.execute is a class/static method in this ComfyUI version.
upscalemodelloader_220 = UpscaleModelLoader.execute(
    model_name="RealESRGAN_x4plus_anime_6B.pth"
)
pixelksampleupscalerprovider = NODE_CLASS_MAPPINGS["PixelKSampleUpscalerProvider"]()
iterativelatentupscale = NODE_CLASS_MAPPINGS["IterativeLatentUpscale"]()
stepsschedulehookprovider = NODE_CLASS_MAPPINGS["StepsScheduleHookProvider"]()
cfgschedulehookprovider = NODE_CLASS_MAPPINGS["CfgScheduleHookProvider"]()
pixelksamplehookcombine = NODE_CLASS_MAPPINGS["PixelKSampleHookCombine"]()

model_loaders = [checkpointloadersimple_4]

# Pre-load the diffusion models onto the GPU, unwrapping `.patcher` where the
# loader result exposes one and skipping dict-shaped entries.
valid_models = [
    getattr(loader[0], 'patcher', loader[0])
    for loader in model_loaders
    if not isinstance(loader[0], dict) and not isinstance(getattr(loader[0], 'patcher', None), dict)
]
# NOTE(review): model_management is not imported in this chunk — presumably
# imported earlier in the file; verify.
model_management.load_models_gpu(valid_models)

# Prompt-independent conditionings/hooks computed once and reused per request.
cliptextencode_7 = cliptextencode.encode(
    text="lowres, bad quality, worst quality, bad anatomy, sketch, jpeg artifacts, ugly, poorly drawn, (signature, watermark, username, logo, web address, twitter_username, patreon_username, character_name, copyright_name), (censored, mosaic_censoring, convenient_censoring, bar_censor, heart_censor), blurry, simple background, transparent background,",
    clip=get_value_at_index(checkpointloadersimple_4, 1),
)
cliptextencode_525 = cliptextencode.encode(
    text="masterpiece, best quality, amazing quality, very aesthetic, absurdres, newest, volumetric lighting, dramatic lighting, ",
    clip=get_value_at_index(checkpointloadersimple_4, 1),
)
cfgschedulehookprovider_541 = cfgschedulehookprovider.doit(
    schedule_for_iteration="simple", target_cfg=10
)
168
+
169
@spaces.GPU
def generate_image(param_image, param_prompt, param_creative, param_style, param_prefix):
    """Run the img2img + iterative-upscale pipeline and return saved file paths.

    Args:
        param_image: path of the input image (Gradio `type="filepath"`).
        param_prompt: main positive prompt text.
        param_creative: denoise strength; anything float()-convertible.
        param_style: style prompt, concatenated ahead of the main prompt.
        param_prefix: filename prefix for SaveImage.

    Returns:
        List of three "output/<filename>" paths (both branches below produce
        amount1 * amount2 == 3 images).
    """
    param_creative = float(param_creative)

    # Higher creativity: batch 3 early, keep 1 after sampling, lighter upscale
    # steps; lower creativity: the reverse trade-off.
    if param_creative > 0.35:
        param_amount1 = 3
        param_amount2 = 1
        param_step = 7
        param_step2 = 15
    else:
        param_amount1 = 1
        param_amount2 = 3
        param_step = 8
        param_step2 = 17

    with torch.inference_mode():

        loadimage_89 = loadimage.load_image(image=param_image)
        # Encode the input image into latent space with the checkpoint's VAE.
        vaeencode_229 = vaeencode.encode(
            pixels=get_value_at_index(loadimage_89, 0),
            vae=get_value_at_index(checkpointloadersimple_4, 2),
        )
        cliptextencode_524 = cliptextencode.encode(
            text=param_prompt,
            clip=get_value_at_index(checkpointloadersimple_4, 1),
        )
        cliptextencode_526 = cliptextencode.encode(
            text=param_style,
            clip=get_value_at_index(checkpointloadersimple_4, 1),
        )
        # Conditioning order: style -> user prompt -> module-level quality tags.
        conditioningconcat_521 = conditioningconcat.concat(
            conditioning_to=get_value_at_index(cliptextencode_526, 0),
            conditioning_from=get_value_at_index(cliptextencode_524, 0),
        )
        conditioningconcat_527 = conditioningconcat.concat(
            conditioning_to=get_value_at_index(conditioningconcat_521, 0),
            conditioning_from=get_value_at_index(cliptextencode_525, 0),
        )
        repeatlatentbatch_506 = repeatlatentbatch.repeat(
            amount=param_amount1, samples=get_value_at_index(vaeencode_229, 0)
        )
        # First sampling pass; denoise == param_creative drives img2img strength.
        # NOTE(review): randint(1, 2**64) is inclusive of 2**64 — confirm the
        # sampler's seed range accepts it (many cap at 2**64 - 1).
        ksampler_230 = ksampler.sample(
            seed=random.randint(1, 2**64),
            steps=20,
            cfg=6,
            sampler_name="euler_ancestral",
            scheduler="normal",
            denoise=param_creative,
            model=get_value_at_index(checkpointloadersimple_4, 0),
            positive=get_value_at_index(conditioningconcat_527, 0),
            negative=get_value_at_index(cliptextencode_7, 0),
            latent_image=get_value_at_index(repeatlatentbatch_506, 0),
        )
        repeatlatentbatch_509 = repeatlatentbatch.repeat(
            amount=param_amount2, samples=get_value_at_index(ksampler_230, 0)
        )
        # Hooks schedule steps/cfg across upscale iterations.
        stepsschedulehookprovider_537 = stepsschedulehookprovider.doit(
            schedule_for_iteration="simple", target_steps=param_step2
        )
        pixelksamplehookcombine_540 = pixelksamplehookcombine.doit(
            hook1=get_value_at_index(stepsschedulehookprovider_537, 0),
            hook2=get_value_at_index(cfgschedulehookprovider_541, 0),
        )
        pixelksampleupscalerprovider_462 = pixelksampleupscalerprovider.doit(
            scale_method="lanczos",
            seed=random.randint(1, 2**64),
            steps=param_step,
            cfg=9,
            sampler_name="euler",
            scheduler="normal",
            denoise=0.35,
            use_tiled_vae=False,
            tile_size=512,
            model=get_value_at_index(checkpointloadersimple_4, 0),
            vae=get_value_at_index(checkpointloadersimple_4, 2),
            positive=get_value_at_index(conditioningconcat_527, 0),
            negative=get_value_at_index(cliptextencode_7, 0),
            upscale_model_opt=get_value_at_index(upscalemodelloader_220, 0),
            pk_hook_opt=get_value_at_index(pixelksamplehookcombine_540, 0),
        )
        # 1.5x latent upscale in 2 iterations using the provider above.
        iterativelatentupscale_461 = iterativelatentupscale.doit(
            upscale_factor=1.5,
            steps=2,
            temp_prefix="",
            step_mode="simple",
            samples=get_value_at_index(repeatlatentbatch_509, 0),
            upscaler=get_value_at_index(pixelksampleupscalerprovider_462, 0),
            unique_id=1445395014345641493,
        )
        vaedecode_233 = vaedecode.decode(
            samples=get_value_at_index(iterativelatentupscale_461, 0),
            vae=get_value_at_index(iterativelatentupscale_461, 1),
        )
        saveimage_410 = saveimage.save_images(
            filename_prefix=param_prefix,
            images=get_value_at_index(vaedecode_233, 0),
        )
        # Exactly three images are assumed (amount1 * amount2 == 3 above);
        # indexing [0..2] would raise IndexError otherwise.
        saved_path = [
            f"output/{saveimage_410['ui']['images'][0]['filename']}",
            f"output/{saveimage_410['ui']['images'][1]['filename']}",
            f"output/{saveimage_410['ui']['images'][2]['filename']}",
        ]
        return saved_path
273
+
274
# Gradio UI: left column holds the inputs, right column shows the 3-image gallery.
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(label="Image", type="filepath", height=300, show_label=False)
            prompt = gr.Textbox(label="prompt", lines=3, max_lines=3, placeholder="prompt")
            style = gr.Textbox(label="style", lines=2, max_lines=2, placeholder="style")
            # Dropdown maps labels to denoise strengths; custom numeric values
            # are allowed and later coerced via float() in generate_image.
            creative = gr.Dropdown(
                choices=[
                    ("balance", 0.65),
                    ("none", 0),
                    ("low", 0.25),
                    ("normal", 0.5),
                    ("high", 0.75),
                    ("ultra", 1),
                ],
                allow_custom_value=True,
                value=0.65,
                label="creative"
            )
            run_btn = gr.Button("Generate", variant="primary")
            # Hidden constant filename prefix passed through to SaveImage.
            prefix = gr.Textbox(visible=False, value="comfyui_")
        with gr.Column(scale=2):
            output_image = gr.Gallery(
                label="Result",
                columns=3,
                object_fit="contain",
                height="auto"
            )
    run_btn.click(
        fn=generate_image,
        inputs=[image, prompt, creative, style, prefix],
        outputs=[output_image]
    )
306
+ )
307
+
308
if __name__ == "__main__":
    # NOTE(review): share=True opens a public Gradio tunnel — confirm this is
    # intended for the deployment target.
    app.launch(share=True)
app/__init__.py ADDED
File without changes
app/app_settings.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from aiohttp import web
4
+ import logging
5
+
6
+
7
class AppSettings():
    """Per-user frontend settings persisted as comfy.settings.json.

    The file path is resolved per request through the UserManager, so each
    user gets an independent settings file.
    """

    def __init__(self, user_manager):
        self.user_manager = user_manager

    def get_settings(self, request):
        """Return the requesting user's settings dict.

        Returns {} when the file is missing or unreadable/corrupt.
        Raises web.HTTPUnauthorized when the user cannot be resolved.
        """
        try:
            file = self.user_manager.get_request_user_filepath(
                request,
                "comfy.settings.json"
            )
        except KeyError as e:
            logging.error("User settings not found.")
            raise web.HTTPUnauthorized() from e
        if os.path.isfile(file):
            try:
                with open(file) as f:
                    return json.load(f)
            # Narrowed from a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit; corrupt settings still degrade
            # to defaults instead of crashing the request.
            except Exception:
                logging.error(f"The user settings file is corrupted: {file}")
                return {}
        else:
            return {}

    def save_settings(self, request, settings):
        """Serialize *settings* to the user's settings file (pretty-printed)."""
        file = self.user_manager.get_request_user_filepath(
            request, "comfy.settings.json")
        with open(file, "w") as f:
            json.dump(settings, f, indent=4)

    def add_routes(self, routes):
        """Register GET/POST handlers for whole-file and per-key settings access."""

        @routes.get("/settings")
        async def get_settings(request):
            return web.json_response(self.get_settings(request))

        @routes.get("/settings/{id}")
        async def get_setting(request):
            # Missing keys respond with JSON null rather than 404.
            value = None
            settings = self.get_settings(request)
            setting_id = request.match_info.get("id", None)
            if setting_id and setting_id in settings:
                value = settings[setting_id]
            return web.json_response(value)

        @routes.post("/settings")
        async def post_settings(request):
            # Shallow merge: posted keys overwrite, others are preserved.
            settings = self.get_settings(request)
            new_settings = await request.json()
            self.save_settings(request, {**settings, **new_settings})
            return web.Response(status=200)

        @routes.post("/settings/{id}")
        async def post_setting(request):
            setting_id = request.match_info.get("id", None)
            if not setting_id:
                return web.Response(status=400)
            settings = self.get_settings(request)
            settings[setting_id] = await request.json()
            self.save_settings(request, settings)
            return web.Response(status=200)
app/assets/api/routes.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import uuid
3
+ from aiohttp import web
4
+
5
+ from pydantic import ValidationError
6
+
7
+ import app.assets.manager as manager
8
+ from app import user_manager
9
+ from app.assets.api import schemas_in
10
+ from app.assets.helpers import get_query_dict
11
+
12
+ ROUTES = web.RouteTableDef()
13
+ USER_MANAGER: user_manager.UserManager | None = None
14
+
15
+ # UUID regex (canonical hyphenated form, case-insensitive)
16
+ UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
17
+
18
def register_assets_system(app: web.Application, user_manager_instance: web.Application | None = None) -> None:  # type: ignore[assignment]
    """Wire the assets API into *app*.

    Stores the UserManager in a module-level global (the route handlers below
    resolve the per-request owner through it) and mounts every route declared
    on ROUTES.
    """
    global USER_MANAGER
    USER_MANAGER = user_manager_instance
    app.add_routes(ROUTES)
+
23
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
    """Build the uniform JSON error envelope: {"error": {code, message, details}}."""
    payload = {
        "error": {
            "code": code,
            "message": message,
            "details": details or {},
        }
    }
    return web.json_response(payload, status=status)
25
+
26
+
27
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
    """400 response wrapping a pydantic ValidationError (JSON-string form)."""
    detail_payload = {"errors": ve.json()}
    return _error_response(400, code, "Validation failed.", detail_payload)
29
+
30
+
31
+ @ROUTES.get("/api/assets")
32
+ async def list_assets(request: web.Request) -> web.Response:
33
+ """
34
+ GET request to list assets.
35
+ """
36
+ query_dict = get_query_dict(request)
37
+ try:
38
+ q = schemas_in.ListAssetsQuery.model_validate(query_dict)
39
+ except ValidationError as ve:
40
+ return _validation_error_response("INVALID_QUERY", ve)
41
+
42
+ payload = manager.list_assets(
43
+ include_tags=q.include_tags,
44
+ exclude_tags=q.exclude_tags,
45
+ name_contains=q.name_contains,
46
+ metadata_filter=q.metadata_filter,
47
+ limit=q.limit,
48
+ offset=q.offset,
49
+ sort=q.sort,
50
+ order=q.order,
51
+ owner_id=USER_MANAGER.get_request_user_id(request),
52
+ )
53
+ return web.json_response(payload.model_dump(mode="json"))
54
+
55
+
56
+ @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
57
+ async def get_asset(request: web.Request) -> web.Response:
58
+ """
59
+ GET request to get an asset's info as JSON.
60
+ """
61
+ asset_info_id = str(uuid.UUID(request.match_info["id"]))
62
+ try:
63
+ result = manager.get_asset(
64
+ asset_info_id=asset_info_id,
65
+ owner_id=USER_MANAGER.get_request_user_id(request),
66
+ )
67
+ except ValueError as e:
68
+ return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id})
69
+ except Exception:
70
+ logging.exception(
71
+ "get_asset failed for asset_info_id=%s, owner_id=%s",
72
+ asset_info_id,
73
+ USER_MANAGER.get_request_user_id(request),
74
+ )
75
+ return _error_response(500, "INTERNAL", "Unexpected server error.")
76
+ return web.json_response(result.model_dump(mode="json"), status=200)
77
+
78
+
79
+ @ROUTES.get("/api/tags")
80
+ async def get_tags(request: web.Request) -> web.Response:
81
+ """
82
+ GET request to list all tags based on query parameters.
83
+ """
84
+ query_map = dict(request.rel_url.query)
85
+
86
+ try:
87
+ query = schemas_in.TagsListQuery.model_validate(query_map)
88
+ except ValidationError as e:
89
+ return web.json_response(
90
+ {"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}},
91
+ status=400,
92
+ )
93
+
94
+ result = manager.list_tags(
95
+ prefix=query.prefix,
96
+ limit=query.limit,
97
+ offset=query.offset,
98
+ order=query.order,
99
+ include_zero=query.include_zero,
100
+ owner_id=USER_MANAGER.get_request_user_id(request),
101
+ )
102
+ return web.json_response(result.model_dump(mode="json"))
app/assets/api/schemas_in.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import uuid
3
+ from typing import Any, Literal
4
+
5
+ from pydantic import (
6
+ BaseModel,
7
+ ConfigDict,
8
+ Field,
9
+ conint,
10
+ field_validator,
11
+ )
12
+
13
+
14
class ListAssetsQuery(BaseModel):
    """Validated query parameters for GET /api/assets."""

    include_tags: list[str] = Field(default_factory=list)
    exclude_tags: list[str] = Field(default_factory=list)
    name_contains: str | None = None

    # Accept either a JSON string (query param) or a dict
    metadata_filter: dict[str, Any] | None = None

    limit: conint(ge=1, le=500) = 20
    offset: conint(ge=0) = 0

    sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
    order: Literal["asc", "desc"] = "desc"

    @field_validator("include_tags", "exclude_tags", mode="before")
    @classmethod
    def _split_csv_tags(cls, v):
        """Be liberal: accept "a,b,c", ["a", "b"], or even a list of CSV strings."""
        if v is None:
            return []
        if isinstance(v, str):
            return [piece.strip() for piece in v.split(",") if piece.strip()]
        if isinstance(v, list):
            collected: list[str] = []
            for element in v:
                if isinstance(element, str):
                    collected.extend(piece.strip() for piece in element.split(",") if piece.strip())
            return collected
        return v

    @field_validator("metadata_filter", mode="before")
    @classmethod
    def _parse_metadata_json(cls, v):
        """Pass dicts through; parse non-empty strings as a JSON object."""
        if v is None or isinstance(v, dict):
            return v
        if not (isinstance(v, str) and v.strip()):
            return None
        try:
            parsed = json.loads(v)
        except Exception as e:
            raise ValueError(f"metadata_filter must be JSON: {e}") from e
        if not isinstance(parsed, dict):
            raise ValueError("metadata_filter must be a JSON object")
        return parsed
58
+
59
+
60
class TagsListQuery(BaseModel):
    """Validated query parameters for GET /api/tags."""

    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

    prefix: str | None = Field(None, min_length=1, max_length=256)
    limit: int = Field(100, ge=1, le=1000)
    offset: int = Field(0, ge=0, le=10_000_000)
    order: Literal["count_desc", "name_asc"] = "count_desc"
    include_zero: bool = True

    @field_validator("prefix")
    @classmethod
    def normalize_prefix(cls, v: str | None) -> str | None:
        """Lower-case the prefix; collapse empty/whitespace-only input to None."""
        if v is None:
            return v
        lowered = v.strip().lower()
        return lowered if lowered else None
76
+
77
+
78
class SetPreviewBody(BaseModel):
    """Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
    preview_id: str | None = None

    @field_validator("preview_id", mode="before")
    @classmethod
    def _norm_uuid(cls, v):
        """None/blank clears the preview; anything else must parse as a UUID."""
        if v is None:
            return None
        candidate = str(v).strip()
        if not candidate:
            return None
        try:
            uuid.UUID(candidate)
        except Exception:
            raise ValueError("preview_id must be a UUID")
        return candidate
app/assets/api/schemas_out.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from typing import Any
3
+
4
+ from pydantic import BaseModel, ConfigDict, Field, field_serializer
5
+
6
+
7
class AssetSummary(BaseModel):
    """Compact asset view returned by the list endpoint (no user_metadata)."""

    id: str
    name: str
    asset_hash: str | None = None
    size: int | None = None
    mime_type: str | None = None
    tags: list[str] = Field(default_factory=list)
    preview_url: str | None = None
    created_at: datetime | None = None
    updated_at: datetime | None = None
    last_access_time: datetime | None = None

    # Allow construction directly from ORM objects.
    model_config = ConfigDict(from_attributes=True)

    @field_serializer("created_at", "updated_at", "last_access_time")
    def _ser_dt(self, v: datetime | None, _info):
        # Emit ISO-8601 strings; missing timestamps serialize as null.
        return v.isoformat() if v else None
24
+
25
+
26
class AssetsList(BaseModel):
    """Paginated response envelope for GET /api/assets."""

    assets: list[AssetSummary]
    total: int        # total matching rows, ignoring limit/offset
    has_more: bool    # True when offset + len(assets) < total
30
+
31
+
32
class AssetDetail(BaseModel):
    """Full asset view for GET /api/assets/{id}, including user_metadata."""

    id: str
    name: str
    asset_hash: str | None = None
    size: int | None = None
    mime_type: str | None = None
    tags: list[str] = Field(default_factory=list)
    user_metadata: dict[str, Any] = Field(default_factory=dict)
    preview_id: str | None = None
    created_at: datetime | None = None
    last_access_time: datetime | None = None

    # Allow construction directly from ORM objects.
    model_config = ConfigDict(from_attributes=True)

    @field_serializer("created_at", "last_access_time")
    def _ser_dt(self, v: datetime | None, _info):
        # Emit ISO-8601 strings; missing timestamps serialize as null.
        return v.isoformat() if v else None
49
+
50
+
51
class TagUsage(BaseModel):
    """One tag with its usage count, as returned by GET /api/tags."""

    name: str
    count: int
    type: str
55
+
56
+
57
class TagsList(BaseModel):
    """Paginated response envelope for GET /api/tags."""

    tags: list[TagUsage] = Field(default_factory=list)
    total: int        # total matching tags, ignoring limit/offset
    has_more: bool    # True when more pages remain
app/assets/database/bulk_ops.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import sqlalchemy
4
+ from typing import Iterable
5
+ from sqlalchemy.orm import Session
6
+ from sqlalchemy.dialects import sqlite
7
+
8
+ from app.assets.helpers import utcnow
9
+ from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
10
+
11
+ MAX_BIND_PARAMS = 800
12
+
13
+ def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
14
+ if not rows:
15
+ return []
16
+ rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
17
+ for i in range(0, len(rows), rows_per_stmt):
18
+ yield rows[i:i + rows_per_stmt]
19
+
20
+ def _iter_chunks(seq, n: int):
21
+ for i in range(0, len(seq), n):
22
+ yield seq[i:i + n]
23
+
24
def _rows_per_stmt(cols: int) -> int:
    """Rows that fit in one statement given *cols* bind params per row (min 1)."""
    per_row = max(1, cols)
    return max(1, MAX_BIND_PARAMS // per_row)
26
+
27
+
28
def seed_from_paths_batch(
    session: Session,
    *,
    specs: list[dict],
    owner_id: str = "",
) -> dict:
    """Seed Asset/AssetCacheState/AssetInfo rows for scanned files in bulk.

    Each spec is a dict with keys:
    - abs_path: str
    - size_bytes: int
    - mtime_ns: int
    - info_name: str
    - tags: list[str]
    - fname: Optional[str]

    Concurrency is resolved through the unique file_path on asset_cache_state:
    every spec optimistically gets a fresh Asset, then ON CONFLICT DO NOTHING
    decides which insert "won" the path; losers' Assets are deleted again.

    Returns a summary dict: inserted_infos / won_states / lost_states.
    NOTE(review): no commit here — caller is presumed to own the transaction.
    """
    if not specs:
        return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}

    now = utcnow()
    asset_rows: list[dict] = []
    state_rows: list[dict] = []
    path_to_asset: dict[str, str] = {}
    asset_to_info: dict[str, dict] = {}  # asset_id -> prepared info row
    path_list: list[str] = []

    # Pre-generate all row dicts (and their UUIDs) before touching the DB.
    for sp in specs:
        ap = os.path.abspath(sp["abs_path"])
        aid = str(uuid.uuid4())
        iid = str(uuid.uuid4())
        path_list.append(ap)
        path_to_asset[ap] = aid

        asset_rows.append(
            {
                "id": aid,
                "hash": None,  # seed rows have no content hash yet
                "size_bytes": sp["size_bytes"],
                "mime_type": None,
                "created_at": now,
            }
        )
        state_rows.append(
            {
                "asset_id": aid,
                "file_path": ap,
                "mtime_ns": sp["mtime_ns"],
            }
        )
        asset_to_info[aid] = {
            "id": iid,
            "owner_id": owner_id,
            "name": sp["info_name"],
            "asset_id": aid,
            "preview_id": None,
            "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
            "created_at": now,
            "updated_at": now,
            "last_access_time": now,
            # Underscore-prefixed keys are side-channel data consumed below;
            # they are not columns and must not reach the INSERT.
            "_tags": sp["tags"],
            "_filename": sp["fname"],
        }

    # insert all seed Assets (hash=NULL); chunked to stay under the bind limit
    ins_asset = sqlite.insert(Asset)
    for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
        session.execute(ins_asset, chunk)

    # try to claim AssetCacheState (file_path)
    # Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
    ins_state = (
        sqlite.insert(AssetCacheState)
        .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
    )
    for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
        session.execute(ins_state, chunk)

    # Query to find which of our paths won (were actually inserted).
    # Matching on both file_path AND our asset ids excludes pre-existing rows
    # that hold the same path under a different asset.
    winners_by_path: set[str] = set()
    for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
        result = session.execute(
            sqlalchemy.select(AssetCacheState.file_path)
            .where(AssetCacheState.file_path.in_(chunk))
            .where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
        )
        winners_by_path.update(result.scalars().all())

    all_paths_set = set(path_list)
    losers_by_path = all_paths_set - winners_by_path
    lost_assets = [path_to_asset[p] for p in losers_by_path]
    if lost_assets:  # losers get their Asset removed
        for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
            session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))

    if not winners_by_path:
        return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}

    # insert AssetInfo only for winners
    # Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
    winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
    ins_info = (
        sqlite.insert(AssetInfo)
        .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
    )
    for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
        session.execute(ins_info, chunk)

    # Query to find which info rows were actually inserted (by matching our generated IDs)
    all_info_ids = [row["id"] for row in winner_info_rows]
    inserted_info_ids: set[str] = set()
    for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
        result = session.execute(
            sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
        )
        inserted_info_ids.update(result.scalars().all())

    # build and insert tag + meta rows for the AssetInfo
    tag_rows: list[dict] = []
    meta_rows: list[dict] = []
    if inserted_info_ids:
        for row in winner_info_rows:
            iid = row["id"]
            if iid not in inserted_info_ids:
                continue
            for t in row["_tags"]:
                tag_rows.append({
                    "asset_info_id": iid,
                    "tag_name": t,
                    "origin": "automatic",
                    "added_at": now,
                })
            if row["_filename"]:
                meta_rows.append(
                    {
                        "asset_info_id": iid,
                        "key": "filename",
                        "ordinal": 0,
                        "val_str": row["_filename"],
                        "val_num": None,
                        "val_bool": None,
                        "val_json": None,
                    }
                )

    bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
    return {
        "inserted_infos": len(inserted_info_ids),
        "won_states": len(winners_by_path),
        "lost_states": len(losers_by_path),
    }
176
+
177
+
178
def bulk_insert_tags_and_meta(
    session: Session,
    *,
    tag_rows: list[dict],
    meta_rows: list[dict],
    max_bind_params: int,
) -> None:
    """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.

    - tag_rows keys: asset_info_id, tag_name, origin, added_at
    - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json

    Duplicates (same PK/unique key) are silently skipped; statements are
    chunked so each stays under *max_bind_params* bound parameters.
    """
    if tag_rows:
        ins_links = (
            sqlite.insert(AssetInfoTag)
            .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
        )
        # 4 columns per tag row
        for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
            session.execute(ins_links, chunk)
    if meta_rows:
        ins_meta = (
            sqlite.insert(AssetInfoMeta)
            .on_conflict_do_nothing(
                index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
            )
        )
        # 7 columns per meta row
        for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
            session.execute(ins_meta, chunk)
app/assets/database/models.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import uuid
4
+ from datetime import datetime
5
+
6
+ from typing import Any
7
+ from sqlalchemy import (
8
+ JSON,
9
+ BigInteger,
10
+ Boolean,
11
+ CheckConstraint,
12
+ DateTime,
13
+ ForeignKey,
14
+ Index,
15
+ Integer,
16
+ Numeric,
17
+ String,
18
+ Text,
19
+ UniqueConstraint,
20
+ )
21
+ from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
22
+
23
+ from app.assets.helpers import utcnow
24
+ from app.database.models import to_dict, Base
25
+
26
+
27
class Asset(Base):
    """Content-addressed blob record; hash is NULL until the file is hashed."""

    __tablename__ = "assets"

    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # Unique when present (see uq_assets_hash); NULL for freshly-seeded rows.
    hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
    size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
    mime_type: Mapped[str | None] = mapped_column(String(255))
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=utcnow
    )

    # AssetInfo rows referencing this asset as their content; deleting the
    # Asset cascades through the ORM (passive_deletes defers to the DB FK).
    infos: Mapped[list[AssetInfo]] = relationship(
        "AssetInfo",
        back_populates="asset",
        primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
        foreign_keys=lambda: [AssetInfo.asset_id],
        cascade="all,delete-orphan",
        passive_deletes=True,
    )

    # AssetInfo rows that use this asset as their preview image (read-only view).
    preview_of: Mapped[list[AssetInfo]] = relationship(
        "AssetInfo",
        back_populates="preview_asset",
        primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
        foreign_keys=lambda: [AssetInfo.preview_id],
        viewonly=True,
    )

    cache_states: Mapped[list[AssetCacheState]] = relationship(
        back_populates="asset",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )

    __table_args__ = (
        Index("uq_assets_hash", "hash", unique=True),
        Index("ix_assets_mime_type", "mime_type"),
        CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        return to_dict(self, include_none=include_none)

    def __repr__(self) -> str:
        return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
72
+
73
+
74
class AssetCacheState(Base):
    """Maps an on-disk file path to an Asset, with the mtime observed at scan.

    file_path is globally unique (uq_asset_cache_state_file_path) — that
    constraint is what arbitrates concurrent seeding in bulk_ops.
    """

    __tablename__ = "asset_cache_state"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
    file_path: Mapped[str] = mapped_column(Text, nullable=False)
    # Last-seen modification time in nanoseconds; NULL when unknown.
    mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
    # Set when the cached entry should be re-validated against the file.
    needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    asset: Mapped[Asset] = relationship(back_populates="cache_states")

    __table_args__ = (
        Index("ix_asset_cache_state_file_path", "file_path"),
        Index("ix_asset_cache_state_asset_id", "asset_id"),
        CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
        UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        return to_dict(self, include_none=include_none)

    def __repr__(self) -> str:
        return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
97
+
98
+
99
class AssetInfo(Base):
    """Per-owner view of an Asset: name, tags, metadata, preview, timestamps.

    Unique per (asset_id, owner_id, name) — one owner can expose the same
    blob under several names, and several owners can share one blob.
    """

    __tablename__ = "assets_info"

    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # "" denotes the default/anonymous owner.
    owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
    name: Mapped[str] = mapped_column(String(512), nullable=False)
    # RESTRICT: the underlying Asset cannot be deleted while infos reference it.
    asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
    # Optional preview blob; cleared automatically if that Asset disappears.
    preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
    user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
    last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)

    asset: Mapped[Asset] = relationship(
        "Asset",
        back_populates="infos",
        foreign_keys=[asset_id],
        lazy="selectin",
    )
    preview_asset: Mapped[Asset | None] = relationship(
        "Asset",
        back_populates="preview_of",
        foreign_keys=[preview_id],
    )

    metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
        back_populates="asset_info",
        cascade="all,delete-orphan",
        passive_deletes=True,
    )

    # Writable association rows; overlaps= silences SQLAlchemy's warning about
    # sharing the association table with the viewonly `tags` relationship.
    tag_links: Mapped[list[AssetInfoTag]] = relationship(
        back_populates="asset_info",
        cascade="all,delete-orphan",
        passive_deletes=True,
        overlaps="tags,asset_infos",
    )

    # Read-only convenience view of Tag objects via the association table.
    tags: Mapped[list[Tag]] = relationship(
        secondary="asset_info_tags",
        back_populates="asset_infos",
        lazy="selectin",
        viewonly=True,
        overlaps="tag_links,asset_info_links,asset_infos,tag",
    )

    __table_args__ = (
        UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
        Index("ix_assets_info_owner_name", "owner_id", "name"),
        Index("ix_assets_info_owner_id", "owner_id"),
        Index("ix_assets_info_asset_id", "asset_id"),
        Index("ix_assets_info_name", "name"),
        Index("ix_assets_info_created_at", "created_at"),
        Index("ix_assets_info_last_access_time", "last_access_time"),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        # Flatten Tag objects to their names for serialization.
        data = to_dict(self, include_none=include_none)
        data["tags"] = [t.name for t in self.tags]
        return data

    def __repr__(self) -> str:
        return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
162
+
163
+
164
class AssetInfoMeta(Base):
    """Typed projection of one user_metadata entry, enabling indexed filtering.

    The value lives in whichever val_* column matches its type; a row with all
    val_* columns NULL encodes an explicit null value. ``ordinal`` distinguishes
    multiple values projected from a list under the same key.
    """

    __tablename__ = "asset_info_meta"

    asset_info_id: Mapped[str] = mapped_column(
        String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
    )
    key: Mapped[str] = mapped_column(String(256), primary_key=True)
    # Position within a list-valued key; 0 for scalar values.
    ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)

    val_str: Mapped[str | None] = mapped_column(String(2048), nullable=True)
    val_num: Mapped[float | None] = mapped_column(Numeric(38, 10), nullable=True)
    val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
    val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)

    asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")

    # One composite index per typed value column supports key+value lookups.
    __table_args__ = (
        Index("ix_asset_info_meta_key", "key"),
        Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
        Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
        Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
    )
186
+
187
+
188
class AssetInfoTag(Base):
    """Association row linking an AssetInfo to a Tag.

    ``origin`` records how the link was created (default "manual";
    scanner-added links use "automatic" — see tags.add_missing_tag_for_asset_id).
    """

    __tablename__ = "asset_info_tags"

    # CASCADE: links disappear with their AssetInfo.
    asset_info_id: Mapped[str] = mapped_column(
        String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
    )
    # RESTRICT: a Tag cannot be dropped while links still reference it.
    tag_name: Mapped[str] = mapped_column(
        String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
    )
    origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
    added_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=utcnow
    )

    asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
    tag: Mapped[Tag] = relationship(back_populates="asset_info_links")

    __table_args__ = (
        Index("ix_asset_info_tags_tag_name", "tag_name"),
        Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
    )
209
+
210
+
211
class Tag(Base):
    """A tag, keyed by its name; ``tag_type`` groups tags (default "user")."""

    __tablename__ = "tags"

    name: Mapped[str] = mapped_column(String(512), primary_key=True)
    tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")

    # Writable association rows.
    asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
        back_populates="tag",
        overlaps="asset_infos,tags",
    )
    # Read-only convenience view of tagged AssetInfos; mutate via asset_info_links.
    asset_infos: Mapped[list[AssetInfo]] = relationship(
        secondary="asset_info_tags",
        back_populates="tags",
        viewonly=True,
        overlaps="asset_info_links,tag_links,tags,asset_info",
    )

    __table_args__ = (
        Index("ix_tags_tag_type", "tag_type"),
    )

    def __repr__(self) -> str:
        return f"<Tag {self.name}>"
app/assets/database/queries.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlalchemy as sa
2
+ from collections import defaultdict
3
+ from sqlalchemy import select, exists, func
4
+ from sqlalchemy.orm import Session, contains_eager, noload
5
+ from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
6
+ from app.assets.helpers import escape_like_prefix, normalize_tags
7
+ from typing import Sequence
8
+
9
+
10
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
    """Predicate selecting AssetInfo rows visible to *owner_id*.

    Rows with an empty owner are public; a non-empty owner additionally
    sees their own rows.
    """
    normalized = (owner_id or "").strip()
    if normalized:
        return AssetInfo.owner_id.in_(["", normalized])
    return AssetInfo.owner_id == ""
16
+
17
+
18
def apply_tag_filters(
    stmt: sa.sql.Select,
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
    """Constrain *stmt* by tags: every include_tag must be present, no exclude_tag may be."""
    wanted = normalize_tags(include_tags)
    banned = normalize_tags(exclude_tags)

    # One EXISTS per required tag (AND semantics).
    for tag_name in wanted:
        has_tag = exists().where(
            (AssetInfoTag.asset_info_id == AssetInfo.id)
            & (AssetInfoTag.tag_name == tag_name)
        )
        stmt = stmt.where(has_tag)

    # A single NOT EXISTS covers all banned tags at once.
    if banned:
        has_banned = exists().where(
            (AssetInfoTag.asset_info_id == AssetInfo.id)
            & (AssetInfoTag.tag_name.in_(banned))
        )
        stmt = stmt.where(~has_banned)
    return stmt
44
+
45
def apply_metadata_filter(
    stmt: sa.sql.Select,
    metadata_filter: dict | None = None,
) -> sa.sql.Select:
    """Apply filters using the asset_info_meta projection table.

    Each key/value pair becomes an EXISTS subquery against the typed value
    columns (val_bool / val_num / val_str / val_json). A list value is treated
    as OR-of-alternatives; None matches "no row for key" or an all-NULL row.
    """
    if not metadata_filter:
        return stmt

    def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
        # EXISTS over meta rows of the current AssetInfo for this key plus extra predicates.
        return sa.exists().where(
            AssetInfoMeta.asset_info_id == AssetInfo.id,
            AssetInfoMeta.key == key,
            *preds,
        )

    def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
        if value is None:
            # "key is null" matches either an absent key or a row with every value column NULL.
            no_row_for_key = sa.not_(
                sa.exists().where(
                    AssetInfoMeta.asset_info_id == AssetInfo.id,
                    AssetInfoMeta.key == key,
                )
            )
            null_row = _exists_for_pred(
                key,
                AssetInfoMeta.val_json.is_(None),
                AssetInfoMeta.val_str.is_(None),
                AssetInfoMeta.val_num.is_(None),
                AssetInfoMeta.val_bool.is_(None),
            )
            return sa.or_(no_row_for_key, null_row)

        # bool is tested before int/float because bool is a subclass of int.
        if isinstance(value, bool):
            return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
        if isinstance(value, (int, float)):
            from decimal import Decimal
            # Route numbers through Decimal so comparison hits the Numeric column exactly.
            num = value if isinstance(value, Decimal) else Decimal(str(value))
            return _exists_for_pred(key, AssetInfoMeta.val_num == num)
        if isinstance(value, str):
            return _exists_for_pred(key, AssetInfoMeta.val_str == value)
        # Anything else (dict, list, ...) is compared against the JSON column.
        return _exists_for_pred(key, AssetInfoMeta.val_json == value)

    for k, v in metadata_filter.items():
        if isinstance(v, list):
            # A list means "match any of these values" (OR). Empty list: no constraint.
            ors = [_exists_clause_for_value(k, elem) for elem in v]
            if ors:
                stmt = stmt.where(sa.or_(*ors))
        else:
            stmt = stmt.where(_exists_clause_for_value(k, v))
    return stmt
95
+
96
+
97
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
    """Return True if any Asset row carries the given content hash."""
    probe = (
        select(sa.literal(True))
        .select_from(Asset)
        .where(Asset.hash == asset_hash)
        .limit(1)
    )
    return session.execute(probe).first() is not None
107
+
108
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
    """Fetch an AssetInfo by primary key (identity-map aware), or None if absent."""
    return session.get(AssetInfo, asset_info_id)
110
+
111
def list_asset_infos_page(
    session: Session,
    owner_id: str = "",
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    name_contains: str | None = None,
    metadata_filter: dict | None = None,
    limit: int = 20,
    offset: int = 0,
    sort: str = "created_at",
    order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
    """Return one page of AssetInfo rows, their tag names, and the total count.

    Returns:
        (infos, tag_map, total) where tag_map maps AssetInfo.id -> [tag names].

    The same filter pipeline is applied to both the page query and the count
    query via a single helper, so the two can never disagree.
    """

    def _apply_filters(stmt: sa.sql.Select) -> sa.sql.Select:
        # Shared filter pipeline (previously duplicated for base and count,
        # which risked the two drifting apart).
        stmt = stmt.where(visible_owner_clause(owner_id))
        if name_contains:
            escaped, esc = escape_like_prefix(name_contains)
            stmt = stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
        stmt = apply_tag_filters(stmt, include_tags, exclude_tags)
        return apply_metadata_filter(stmt, metadata_filter)

    base = _apply_filters(
        select(AssetInfo)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
    )

    # Whitelist-based sorting; unknown fields fall back to created_at.
    sort = (sort or "created_at").lower()
    order = (order or "desc").lower()
    sort_map = {
        "name": AssetInfo.name,
        "created_at": AssetInfo.created_at,
        "updated_at": AssetInfo.updated_at,
        "last_access_time": AssetInfo.last_access_time,
        "size": Asset.size_bytes,
    }
    sort_col = sort_map.get(sort, AssetInfo.created_at)
    sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()

    base = base.order_by(sort_exp).limit(limit).offset(offset)

    count_stmt = _apply_filters(
        select(sa.func.count())
        .select_from(AssetInfo)
        .join(Asset, Asset.id == AssetInfo.asset_id)
    )
    total = int((session.execute(count_stmt)).scalar_one() or 0)

    infos = (session.execute(base)).unique().scalars().all()

    # Fetch all tags for the page in one query instead of per-row lazy loads.
    id_list: list[str] = [i.id for i in infos]
    tag_map: dict[str, list[str]] = defaultdict(list)
    if id_list:
        rows = session.execute(
            select(AssetInfoTag.asset_info_id, Tag.name)
            .join(Tag, Tag.name == AssetInfoTag.tag_name)
            .where(AssetInfoTag.asset_info_id.in_(id_list))
        )
        for aid, tag_name in rows.all():
            tag_map[aid].append(tag_name)

    return infos, tag_map, total
179
+
180
def fetch_asset_info_asset_and_tags(
    session: Session,
    asset_info_id: str,
    owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
    """Load one AssetInfo (visible to *owner_id*) together with its Asset and tag names.

    Uses LEFT OUTER joins so an info without tags still yields one row;
    returns None when the id does not exist or is not visible to this owner.
    """
    stmt = (
        select(AssetInfo, Asset, Tag.name)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
        .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
        .where(
            AssetInfo.id == asset_info_id,
            visible_owner_clause(owner_id),
        )
        .options(noload(AssetInfo.tags))  # tags collected manually from the joined rows
        .order_by(Tag.name.asc())
    )

    rows = (session.execute(stmt)).all()
    if not rows:
        return None

    # One row per tag (or a single row with tag None); info/asset repeat on every row.
    first_info, first_asset, _ = rows[0]
    tags: list[str] = []
    seen: set[str] = set()
    for _info, _asset, tag_name in rows:
        if tag_name and tag_name not in seen:
            seen.add(tag_name)
            tags.append(tag_name)
    return first_info, first_asset, tags
210
+
211
def list_tags_with_usage(
    session: Session,
    prefix: str | None = None,
    limit: int = 100,
    offset: int = 0,
    include_zero: bool = True,
    order: str = "count_desc",
    owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
    """Page through tags with their usage counts among assets visible to *owner_id*.

    Returns (rows, total) where each row is (name, tag_type, count).
    ``include_zero=False`` hides tags with zero visible usage; ``order`` is
    "name_asc" or the default count-descending (ties broken by name).
    """
    # Per-tag usage counts, restricted to AssetInfos visible to this owner.
    counts_sq = (
        select(
            AssetInfoTag.tag_name.label("tag_name"),
            func.count(AssetInfoTag.asset_info_id).label("cnt"),
        )
        .select_from(AssetInfoTag)
        .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
        .where(visible_owner_clause(owner_id))
        .group_by(AssetInfoTag.tag_name)
        .subquery()
    )

    # LEFT OUTER join so unused tags still appear, with count 0.
    q = (
        select(
            Tag.name,
            Tag.tag_type,
            func.coalesce(counts_sq.c.cnt, 0).label("count"),
        )
        .select_from(Tag)
        .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
    )

    if prefix:
        escaped, esc = escape_like_prefix(prefix.strip().lower())
        q = q.where(Tag.name.like(escaped + "%", escape=esc))

    if not include_zero:
        q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)

    if order == "name_asc":
        q = q.order_by(Tag.name.asc())
    else:
        q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())

    # NOTE(review): with include_zero=False the total counts tags with *any*
    # usage, while the paged rows require owner-visible usage — the two can
    # disagree for tags used only on other owners' assets. Confirm intended.
    total_q = select(func.count()).select_from(Tag)
    if prefix:
        escaped, esc = escape_like_prefix(prefix.strip().lower())
        total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
    if not include_zero:
        total_q = total_q.where(
            Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
        )

    rows = (session.execute(q.limit(limit).offset(offset))).all()
    total = (session.execute(total_q)).scalar_one()

    # Normalize DB numerics to plain ints.
    rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
    return rows_norm, int(total or 0)
app/assets/database/tags.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Iterable
2
+
3
+ import sqlalchemy
4
+ from sqlalchemy.orm import Session
5
+ from sqlalchemy.dialects import sqlite
6
+
7
+ from app.assets.helpers import normalize_tags, utcnow
8
+ from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
9
+
10
+
11
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
    """Insert any missing Tag rows for *names* (normalized), with the given type.

    Idempotent: existing tags are left untouched via ON CONFLICT DO NOTHING.
    """
    wanted = normalize_tags(list(names))
    if not wanted:
        return
    # dict.fromkeys de-duplicates while preserving order.
    rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
    ins = (
        sqlite.insert(Tag)
        .values(rows)
        .on_conflict_do_nothing(index_elements=[Tag.name])
    )
    # BUG FIX: previously `return session.execute(ins)` leaked a CursorResult
    # from a function annotated `-> None`.
    session.execute(ins)
22
+
23
def add_missing_tag_for_asset_id(
    session: Session,
    *,
    asset_id: str,
    origin: str = "automatic",
) -> None:
    """Attach the 'missing' tag to every AssetInfo of *asset_id* that lacks it.

    Implemented as one INSERT ... FROM SELECT with ON CONFLICT DO NOTHING,
    so the operation is idempotent and race-safe at the DB level.
    """
    select_rows = (
        sqlalchemy.select(
            AssetInfo.id.label("asset_info_id"),
            sqlalchemy.literal("missing").label("tag_name"),
            sqlalchemy.literal(origin).label("origin"),
            # NOTE: utcnow() is evaluated once, at statement-construction time.
            sqlalchemy.literal(utcnow()).label("added_at"),
        )
        .where(AssetInfo.asset_id == asset_id)
        .where(
            # Only infos that do not already carry the 'missing' tag.
            sqlalchemy.not_(
                sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
            )
        )
    )
    session.execute(
        sqlite.insert(AssetInfoTag)
        .from_select(
            ["asset_info_id", "tag_name", "origin", "added_at"],
            select_rows,
        )
        .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
    )
51
+
52
def remove_missing_tag_for_asset_id(
    session: Session,
    *,
    asset_id: str,
) -> None:
    """Delete the 'missing' tag link from every AssetInfo of the given asset."""
    info_ids = sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)
    stmt = sqlalchemy.delete(AssetInfoTag).where(
        AssetInfoTag.asset_info_id.in_(info_ids),
        AssetInfoTag.tag_name == "missing",
    )
    session.execute(stmt)
app/assets/hashing.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from blake3 import blake3
2
+ from typing import IO
3
+ import os
4
+ import asyncio
5
+
6
+
7
# Default streaming chunk size for hashing: 8 MiB.
DEFAULT_CHUNK = 8 * 1024 * 1024
8
+
9
+ # NOTE: this allows hashing different representations of a file-like object
10
def blake3_hash(
    fp: str | IO[bytes],
    chunk_size: int = DEFAULT_CHUNK,
) -> str:
    """Return the BLAKE3 hex digest of *fp*.

    ``fp`` may be a filename (str/bytes/PathLike) or an open **binary** file
    object supporting ``read``, ``seek`` and ``tell``. Hashing starts from the
    beginning of the stream; the prior position is restored afterward.
    """
    # Duck-typed check: anything with .read() is treated as a file object.
    if not hasattr(fp, "read"):
        with open(os.fspath(fp), "rb") as handle:
            return _hash_file_obj(handle, chunk_size)
    return _hash_file_obj(fp, chunk_size)
28
+
29
+
30
async def blake3_hash_async(
    fp: str | IO[bytes],
    chunk_size: int = DEFAULT_CHUNK,
) -> str:
    """Async wrapper for ``blake3_hash``.

    Uses a worker thread so the event loop remains responsive.
    """
    # File-like object: delegate directly to the sync entry point in a thread.
    if hasattr(fp, "read"):
        return await asyncio.to_thread(blake3_hash, fp, chunk_size)

    # Path: open inside the worker thread to keep all I/O off the loop.
    def _worker() -> str:
        with open(os.fspath(fp), "rb") as f:
            return _hash_file_obj(f, chunk_size)

    return await asyncio.to_thread(_worker)
46
+
47
+
48
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
    """Hash an already-open binary file object by streaming in chunks.

    Seeks to the beginning before reading and always restores the original
    position afterward.
    """
    if chunk_size <= 0:
        chunk_size = DEFAULT_CHUNK

    # Remember where the caller left the stream so it can be restored.
    orig_pos = file_obj.tell()

    try:
        if orig_pos != 0:
            file_obj.seek(0)

        h = blake3()
        while True:
            chunk = file_obj.read(chunk_size)
            if not chunk:
                break
            h.update(chunk)
        return h.hexdigest()
    finally:
        # BUG FIX: the restore was previously skipped when orig_pos == 0,
        # leaving a caller-supplied stream parked at EOF even though the
        # docstring promised restoration. Always seek back.
        file_obj.seek(orig_pos)
app/assets/helpers.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import os
3
+ from aiohttp import web
4
+ from datetime import datetime, timezone
5
+ from pathlib import Path
6
+ from typing import Literal, Any
7
+
8
+ import folder_paths
9
+
10
+
11
# The three top-level asset roots; ALLOWED_ROOTS is the runtime whitelist.
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
13
+
14
def get_query_dict(request: web.Request) -> dict[str, Any]:
    """Convert ``request.query`` (a MultiMapping) into a plain dict for Pydantic.

    Keys that appear more than once map to the list of all their values;
    single occurrences map to the bare value.
    """
    result: dict[str, Any] = {}
    for key in request.query.keys():
        values = request.query.getall(key)
        result[key] = values if len(values) > 1 else request.query.get(key)
    return result
25
+
26
def list_tree(base_dir: str) -> list[str]:
    """Recursively list absolute paths of all files under *base_dir*.

    Returns an empty list when base_dir is not a directory. Symlinked
    directories are not followed.
    """
    root = os.path.abspath(base_dir)
    if not os.path.isdir(root):
        return []
    found: list[str] = []
    for dirpath, _dirs, files in os.walk(root, topdown=True, followlinks=False):
        found.extend(os.path.abspath(os.path.join(dirpath, fname)) for fname in files)
    return found
35
+
36
def prefixes_for_root(root: RootType) -> list[str]:
    """Absolute base directories that files of *root* may live under."""
    if root == "input":
        return [os.path.abspath(folder_paths.get_input_directory())]
    if root == "output":
        return [os.path.abspath(folder_paths.get_output_directory())]
    if root == "models":
        # Flatten every configured model category's base paths.
        collected: list[str] = []
        for _category, base_paths in get_comfy_models_folders():
            collected.extend(os.path.abspath(p) for p in base_paths)
        return collected
    return []
47
+
48
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
    """Escape LIKE wildcards (% and _) and the escape character itself in *s*.

    Returns (escaped_text, escape_char); the caller appends '%' and passes
    escape_char as the ``escape=`` argument of ``.like()``.
    """
    # Double the escape character first so later insertions are not re-escaped.
    out = s.replace(escape, escape + escape)
    out = out.replace("%", escape + "%")
    out = out.replace("_", escape + "_")
    return out, escape
55
+
56
+ def fast_asset_file_check(
57
+ *,
58
+ mtime_db: int | None,
59
+ size_db: int | None,
60
+ stat_result: os.stat_result,
61
+ ) -> bool:
62
+ if mtime_db is None:
63
+ return False
64
+ actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
65
+ if int(mtime_db) != int(actual_mtime_ns):
66
+ return False
67
+ sz = int(size_db or 0)
68
+ if sz > 0:
69
+ return int(stat_result.st_size) == sz
70
+ return True
71
+
72
def utcnow() -> datetime:
    """Current UTC time as a naive datetime (DB datetimes are implicitly UTC)."""
    aware = datetime.now(timezone.utc)
    return aware.replace(tzinfo=None)
75
+
76
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
    """(folder_name, base_paths) for every category configured under models_dir.

    A category qualifies when *any* of its base paths lies inside the Comfy
    ``models_dir``. Only the first two slots of each ``folder_names_and_paths``
    value are touched, guarding against node packs that hack extra entries in.
    """
    models_root = os.path.abspath(folder_paths.models_dir)
    prefix = models_root + os.sep
    result: list[tuple[str, list[str]]] = []
    for name, values in folder_paths.folder_names_and_paths.items():
        paths, _exts = values[0], values[1]
        if any(os.path.abspath(p).startswith(prefix) for p in paths):
            result.append((name, paths))
    return result
89
+
90
def compute_relative_filename(file_path: str) -> str | None:
    """Model path relative to its category folder, with forward slashes.

    e.g. ".../models/checkpoints/flux/123/flux.safetensors" -> "flux/123/flux.safetensors"
         ".../models/text_encoders/clip_g.safetensors"      -> "clip_g.safetensors"

    Returns None for paths outside the known roots.
    NOTE: temporary helper, only used to initialize metadata["filename"].
    """
    try:
        root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
    except ValueError:
        return None

    pure = Path(rel_path)
    segments = [part for part in pure.parts if part not in (".", "..", pure.anchor)]
    if not segments:
        return None

    if root_category != "models":
        # input/output: keep every component
        return "/".join(segments)
    # models: drop the leading category ("checkpoints", "vae", ...) when present
    if len(segments) > 1:
        return "/".join(segments[1:])
    return segments[0]
115
+
116
+
117
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
    """Given an absolute or relative file path, determine which root category the path belongs to:
    - 'input' if the file resides under `folder_paths.get_input_directory()`
    - 'output' if the file resides under `folder_paths.get_output_directory()`
    - 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`

    Returns:
        (root_category, relative_path_inside_that_root)
        For 'models', the relative path is prefixed with the category name:
        e.g. ('models', 'vae/test/sub/ae.safetensors')

    Raises:
        ValueError: if the path does not belong to input, output, or configured model bases.
    """
    fp_abs = os.path.abspath(file_path)

    def _is_within(child: str, parent: str) -> bool:
        # commonpath raises on mixed drives / empty input; treat that as "not inside".
        try:
            return os.path.commonpath([child, parent]) == parent
        except Exception:
            return False

    def _rel(child: str, parent: str) -> str:
        # Anchor the relative path at os.sep and re-relativize to normalize
        # separators and strip any leading root component.
        return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)

    # 1) input
    input_base = os.path.abspath(folder_paths.get_input_directory())
    if _is_within(fp_abs, input_base):
        return "input", _rel(fp_abs, input_base)

    # 2) output
    output_base = os.path.abspath(folder_paths.get_output_directory())
    if _is_within(fp_abs, output_base):
        return "output", _rel(fp_abs, output_base)

    # 3) models (check deepest matching base to avoid ambiguity)
    best: tuple[int, str, str] | None = None  # (base_len, bucket, rel_inside_bucket)
    for bucket, bases in get_comfy_models_folders():
        for b in bases:
            base_abs = os.path.abspath(b)
            if not _is_within(fp_abs, base_abs):
                continue
            # Longer base path = more specific match; keep the deepest one.
            cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
            if best is None or cand[0] > best[0]:
                best = cand

    if best is not None:
        _, bucket, rel_inside = best
        # Prefix the category bucket so the result is unique across categories.
        combined = os.path.join(bucket, rel_inside)
        return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)

    raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
+
170
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
    """Derive (name, tags) for an asset from its filesystem location.

    ``name`` is the base filename (with extension); ``tags`` are the root
    category followed by the parent folders of the relative path, normalized
    and de-duplicated in order. Example:
        '/.../ModelsDir/vae/test_tag/ae.safetensors'
        -> ('ae.safetensors', ['models', 'vae', 'test_tag'])

    Raises:
        ValueError: if the path is outside input/output/configured model bases.
    """
    root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
    rel = Path(rel_path)
    folders = [seg for seg in rel.parent.parts if seg not in (".", "..", rel.anchor)]
    tags = list(dict.fromkeys(normalize_tags([root_category, *folders])))
    return rel.name, tags
190
+
191
+ def normalize_tags(tags: list[str] | None) -> list[str]:
192
+ """
193
+ Normalize a list of tags by:
194
+ - Stripping whitespace and converting to lowercase.
195
+ - Removing duplicates.
196
+ """
197
+ return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
198
+
199
def collect_models_files() -> list[str]:
    """Absolute paths of every model file reported by folder_paths, restricted
    to files that actually live under one of the category's base folders."""
    collected: list[str] = []
    for folder_name, bases in get_comfy_models_folders():
        base_abs_list = [os.path.abspath(b) for b in bases]
        for rel_path in folder_paths.get_filename_list(folder_name) or []:
            full = folder_paths.get_full_path(folder_name, rel_path)
            if not full:
                continue
            full = os.path.abspath(full)
            for base_abs in base_abs_list:
                # commonpath can raise (e.g. mixed drives); treat as "not inside".
                with contextlib.suppress(Exception):
                    if os.path.commonpath([full, base_abs]) == base_abs:
                        collected.append(full)
                        break
    return collected
app/assets/manager.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+
3
+ from app.database.db import create_session
4
+ from app.assets.api import schemas_out
5
+ from app.assets.database.queries import (
6
+ asset_exists_by_hash,
7
+ fetch_asset_info_asset_and_tags,
8
+ list_asset_infos_page,
9
+ list_tags_with_usage,
10
+ )
11
+
12
+
13
+ def _safe_sort_field(requested: str | None) -> str:
14
+ if not requested:
15
+ return "created_at"
16
+ v = requested.lower()
17
+ if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
18
+ return v
19
+ return "created_at"
20
+
21
+
22
def asset_exists(asset_hash: str) -> bool:
    """Public entry point: True if an asset with this content hash is stored."""
    with create_session() as session:
        return asset_exists_by_hash(session, asset_hash=asset_hash)
25
+
26
def list_assets(
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    name_contains: str | None = None,
    metadata_filter: dict | None = None,
    limit: int = 20,
    offset: int = 0,
    sort: str = "created_at",
    order: str = "desc",
    owner_id: str = "",
) -> schemas_out.AssetsList:
    """Paged listing of assets visible to *owner_id*, with tags and a total count.

    ``sort`` and ``order`` are sanitized against whitelists before querying.
    """
    sort = _safe_sort_field(sort)
    order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()

    with create_session() as session:
        infos, tag_map, total = list_asset_infos_page(
            session,
            owner_id=owner_id,
            include_tags=include_tags,
            exclude_tags=exclude_tags,
            name_contains=name_contains,
            metadata_filter=metadata_filter,
            limit=limit,
            offset=offset,
            sort=sort,
            order=order,
        )

        summaries: list[schemas_out.AssetSummary] = []
        for info in infos:
            asset = info.asset
            summaries.append(
                schemas_out.AssetSummary(
                    id=info.id,
                    name=info.name,
                    asset_hash=asset.hash if asset else None,
                    # BUG FIX: guard size_bytes against None (as get_asset does)
                    # instead of crashing on int(None).
                    size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
                    mime_type=asset.mime_type if asset else None,
                    tags=tag_map.get(info.id, []),
                    preview_url=f"/api/assets/{info.id}/content",
                    created_at=info.created_at,
                    updated_at=info.updated_at,
                    last_access_time=info.last_access_time,
                )
            )

        return schemas_out.AssetsList(
            assets=summaries,
            total=total,
            has_more=(offset + len(summaries)) < total,
        )
78
+
79
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
    """Full detail for one AssetInfo.

    Raises:
        ValueError: when the id does not exist or is not visible to this owner.
    """
    with create_session() as session:
        found = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
        if found is None:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        info, asset, tag_names = found

        size = None
        if asset and asset.size_bytes is not None:
            size = int(asset.size_bytes)

        return schemas_out.AssetDetail(
            id=info.id,
            name=info.name,
            asset_hash=asset.hash if asset else None,
            size=size,
            mime_type=asset.mime_type if asset else None,
            tags=tag_names,
            user_metadata=info.user_metadata or {},
            preview_id=info.preview_id,
            created_at=info.created_at,
            last_access_time=info.last_access_time,
        )
99
+
100
def list_tags(
    prefix: str | None = None,
    limit: int = 100,
    offset: int = 0,
    order: str = "count_desc",
    include_zero: bool = True,
    owner_id: str = "",
) -> schemas_out.TagsList:
    """List tags with usage counts; limit is clamped to [1, 1000], offset to >= 0."""
    clamped_limit = min(1000, max(1, limit))
    clamped_offset = max(0, offset)

    with create_session() as session:
        rows, total = list_tags_with_usage(
            session,
            prefix=prefix,
            limit=clamped_limit,
            offset=clamped_offset,
            include_zero=include_zero,
            order=order,
            owner_id=owner_id,
        )

    usage = [
        schemas_out.TagUsage(name=name, count=count, type=tag_type)
        for (name, tag_type, count) in rows
    ]
    return schemas_out.TagsList(
        tags=usage,
        total=total,
        has_more=(clamped_offset + len(usage)) < total,
    )
app/assets/scanner.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import time
3
+ import logging
4
+ import os
5
+ import sqlalchemy
6
+
7
+ import folder_paths
8
+ from app.database.db import create_session, dependencies_available
9
+ from app.assets.helpers import (
10
+ collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
11
+ list_tree,prefixes_for_root, escape_like_prefix,
12
+ RootType
13
+ )
14
+ from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
15
+ from app.assets.database.bulk_ops import seed_from_paths_batch
16
+ from app.assets.database.models import Asset, AssetCacheState, AssetInfo
17
+
18
+
19
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
    """
    Scan the given roots and seed the assets into the database.

    For each root a fast DB/filesystem consistency pass runs first; the paths
    it reports as already tracked are skipped.  Remaining files are stat'ed,
    turned into spec dicts, and bulk-inserted within a single session.
    Empty and unreadable files are silently skipped.
    """
    if not dependencies_available():
        if enable_logging:
            logging.warning("Database dependencies not available, skipping assets scan")
        return
    t_start = time.perf_counter()
    created = 0
    skipped_existing = 0
    paths: list[str] = []
    try:
        # Absolute paths already present (and healthy) in the DB -- used below
        # to skip re-seeding files we already know about.
        existing_paths: set[str] = set()
        for r in roots:
            try:
                survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
                if survivors:
                    existing_paths.update(survivors)
            except Exception as e:
                # A failed consistency pass for one root must not abort seeding.
                logging.exception("fast DB scan failed for %s: %s", r, e)

        if "models" in roots:
            paths.extend(collect_models_files())
        if "input" in roots:
            paths.extend(list_tree(folder_paths.get_input_directory()))
        if "output" in roots:
            paths.extend(list_tree(folder_paths.get_output_directory()))

        specs: list[dict] = []
        tag_pool: set[str] = set()
        for p in paths:
            abs_p = os.path.abspath(p)
            if abs_p in existing_paths:
                skipped_existing += 1
                continue
            try:
                stat_p = os.stat(abs_p, follow_symlinks=False)
            except OSError:
                # File vanished between listing and stat, or is unreadable.
                continue
            # skip empty files
            if not stat_p.st_size:
                continue
            name, tags = get_name_and_tags_from_asset_path(abs_p)
            specs.append(
                {
                    "abs_path": abs_p,
                    "size_bytes": stat_p.st_size,
                    # st_mtime_ns may be absent on some platforms; fall back to
                    # converting the float mtime to nanoseconds.
                    "mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
                    "info_name": name,
                    "tags": tags,
                    "fname": compute_relative_filename(abs_p),
                }
            )
            for t in tags:
                tag_pool.add(t)
        # if no file specs, nothing to do
        if not specs:
            return
        with create_session() as sess:
            if tag_pool:
                ensure_tags_exist(sess, tag_pool, tag_type="user")

            result = seed_from_paths_batch(sess, specs=specs, owner_id="")
            created += result["inserted_infos"]
            sess.commit()
    finally:
        # The early `return` above still lands here, so a summary line is
        # logged for every invocation when logging is enabled.
        if enable_logging:
            logging.info(
                "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
                roots,
                time.perf_counter() - t_start,
                created,
                skipped_existing,
                len(paths),
            )
95
+
96
+
97
def _fast_db_consistency_pass(
    root: RootType,
    *,
    collect_existing_paths: bool = False,
    update_missing_tags: bool = False,
) -> set[str] | None:
    """Fast DB+FS pass for a root:
    - Toggle needs_verify per state using fast check
    - For hashed assets with at least one fast-ok state in this root: delete stale missing states
    - For seed assets with all states missing: delete Asset and its AssetInfos
    - Optionally add/remove 'missing' tags based on fast-ok in this root
    - Optionally return surviving absolute paths
    """
    prefixes = prefixes_for_root(root)
    if not prefixes:
        return set() if collect_existing_paths else None

    # Build LIKE conditions selecting cache states whose file_path lives
    # under one of this root's directory prefixes.
    conds = []
    for p in prefixes:
        base = os.path.abspath(p)
        if not base.endswith(os.sep):
            base += os.sep
        escaped, esc = escape_like_prefix(base)
        conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))

    with create_session() as sess:
        rows = (
            sess.execute(
                sqlalchemy.select(
                    AssetCacheState.id,
                    AssetCacheState.file_path,
                    AssetCacheState.mtime_ns,
                    AssetCacheState.needs_verify,
                    AssetCacheState.asset_id,
                    Asset.hash,
                    Asset.size_bytes,
                )
                .join(Asset, Asset.id == AssetCacheState.asset_id)
                .where(sqlalchemy.or_(*conds))
                .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
            )
        ).all()

        # Group cache states by asset, recording for each state whether the
        # file exists and whether size/mtime still match the DB ("fast_ok").
        by_asset: dict[str, dict] = {}
        for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
            acc = by_asset.get(aid)
            if acc is None:
                acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
                by_asset[aid] = acc

            fast_ok = False
            try:
                exists = True
                fast_ok = fast_asset_file_check(
                    mtime_db=mtime_db,
                    size_db=acc["size_db"],
                    stat_result=os.stat(fp, follow_symlinks=True),
                )
            except FileNotFoundError:
                exists = False
            except OSError:
                # Unreadable is treated the same as missing for this pass.
                exists = False

            acc["states"].append({
                "sid": sid,
                "fp": fp,
                "exists": exists,
                "fast_ok": fast_ok,
                "needs_verify": bool(needs_verify),
            })

        to_set_verify: list[int] = []
        to_clear_verify: list[int] = []
        stale_state_ids: list[int] = []
        survivors: set[str] = set()

        for aid, acc in by_asset.items():
            a_hash = acc["hash"]
            states = acc["states"]
            any_fast_ok = any(s["fast_ok"] for s in states)
            all_missing = all(not s["exists"] for s in states)

            # Reconcile each existing state's needs_verify flag with the
            # fast-check outcome (set when stale, clear when re-validated).
            for s in states:
                if not s["exists"]:
                    continue
                if s["fast_ok"] and s["needs_verify"]:
                    to_clear_verify.append(s["sid"])
                if not s["fast_ok"] and not s["needs_verify"]:
                    to_set_verify.append(s["sid"])

            if a_hash is None:
                if states and all_missing:  # remove seed Asset completely, if no valid AssetCache exists
                    sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
                    asset = sess.get(Asset, aid)
                    if asset:
                        sess.delete(asset)
                else:
                    for s in states:
                        if s["exists"]:
                            survivors.add(os.path.abspath(s["fp"]))
                continue

            if any_fast_ok:  # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
                for s in states:
                    if not s["exists"]:
                        stale_state_ids.append(s["sid"])
                if update_missing_tags:
                    # Tag maintenance is best-effort; failures must not break the pass.
                    with contextlib.suppress(Exception):
                        remove_missing_tag_for_asset_id(sess, asset_id=aid)
            elif update_missing_tags:
                with contextlib.suppress(Exception):
                    add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")

            for s in states:
                if s["exists"]:
                    survivors.add(os.path.abspath(s["fp"]))

        # Apply all accumulated deletions/updates in bulk, then commit once.
        if stale_state_ids:
            sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
        if to_set_verify:
            sess.execute(
                sqlalchemy.update(AssetCacheState)
                .where(AssetCacheState.id.in_(to_set_verify))
                .values(needs_verify=True)
            )
        if to_clear_verify:
            sess.execute(
                sqlalchemy.update(AssetCacheState)
                .where(AssetCacheState.id.in_(to_clear_verify))
                .values(needs_verify=False)
            )
        sess.commit()
        return survivors if collect_existing_paths else None
app/custom_node_manager.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import folder_paths
5
+ import glob
6
+ from aiohttp import web
7
+ import json
8
+ import logging
9
+ from functools import lru_cache
10
+
11
+ from utils.json_util import merge_json_recursive
12
+
13
+
14
+ # Extra locale files to load into main.json
15
+ EXTRA_LOCALE_FILES = [
16
+ "nodeDefs.json",
17
+ "commands.json",
18
+ "settings.json",
19
+ ]
20
+
21
+
22
def safe_load_json_file(file_path: str) -> dict:
    """Load a JSON file, returning {} instead of raising on any failure.

    A missing file is expected (not every custom node ships every locale
    file) and returns {} silently; parse and read errors are logged and
    also return {}, so one broken locale file cannot break startup.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # Missing locale files are normal; not an error.
        return {}
    except (json.JSONDecodeError, OSError, UnicodeDecodeError):
        # Previously only JSONDecodeError was handled, and the exists() check
        # raced with open() (TOCTOU); unreadable or mis-encoded files now
        # degrade gracefully as well.
        logging.error(f"Error loading {file_path}")
        return {}
32
+
33
+
34
class CustomNodeManager:
    """Aggregates resources contributed by custom nodes: locale translations,
    example workflow templates, and the HTTP routes that serve them."""

    # NOTE(review): lru_cache on an instance method keys on `self` and keeps
    # that instance alive for the cache's lifetime (ruff B019). Harmless if a
    # single long-lived manager exists -- confirm.
    @lru_cache(maxsize=1)
    def build_translations(self):
        """Load all custom nodes translations during initialization. Translations are
        expected to be loaded from `locales/` folder.

        The folder structure is expected to be the following:
        - custom_nodes/
            - custom_node_1/
                - locales/
                    - en/
                        - main.json
                        - commands.json
                        - settings.json

        returned translations are expected to be in the following format:
        {
            "en": {
                "nodeDefs": {...},
                "commands": {...},
                "settings": {...},
                ...{other main.json keys}
            }
        }
        """

        translations = {}

        for folder in folder_paths.get_folder_paths("custom_nodes"):
            # Sort glob results for deterministic ordering
            for custom_node_dir in sorted(glob.glob(os.path.join(folder, "*/"))):
                locales_dir = os.path.join(custom_node_dir, "locales")
                if not os.path.exists(locales_dir):
                    continue

                for lang_dir in glob.glob(os.path.join(locales_dir, "*/")):
                    lang_code = os.path.basename(os.path.dirname(lang_dir))

                    if lang_code not in translations:
                        translations[lang_code] = {}

                    # Load main.json
                    main_file = os.path.join(lang_dir, "main.json")
                    node_translations = safe_load_json_file(main_file)

                    # Load extra locale files
                    for extra_file in EXTRA_LOCALE_FILES:
                        extra_file_path = os.path.join(lang_dir, extra_file)
                        # e.g. "nodeDefs" from "nodeDefs.json" becomes the merge key.
                        key = extra_file.split(".")[0]
                        json_data = safe_load_json_file(extra_file_path)
                        if json_data:
                            node_translations[key] = json_data

                    if node_translations:
                        # Deep-merge so multiple custom nodes can contribute to
                        # the same language without clobbering each other.
                        translations[lang_code] = merge_json_recursive(
                            translations[lang_code], node_translations
                        )

        return translations

    def add_routes(self, routes, webapp, loadedModules):
        """Register /workflow_templates, /i18n, and per-module static template routes."""

        example_workflow_folder_names = ["example_workflows", "example", "examples", "workflow", "workflows"]

        @routes.get("/workflow_templates")
        async def get_workflow_templates(request):
            """Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""

            files = []

            for folder in folder_paths.get_folder_paths("custom_nodes"):
                for folder_name in example_workflow_folder_names:
                    pattern = os.path.join(folder, f"*/{folder_name}/*.json")
                    matched_files = glob.glob(pattern)
                    files.extend(matched_files)

            workflow_templates_dict = (
                {}
            )  # custom_nodes folder name -> example workflow names
            for file in files:
                # Path shape: <custom_node>/<folder_name>/<workflow>.json,
                # so two dirname() hops reach the custom node directory.
                custom_nodes_name = os.path.basename(
                    os.path.dirname(os.path.dirname(file))
                )
                workflow_name = os.path.splitext(os.path.basename(file))[0]
                workflow_templates_dict.setdefault(custom_nodes_name, []).append(
                    workflow_name
                )
            return web.json_response(workflow_templates_dict)

        # Serve workflow templates from custom nodes.
        for module_name, module_dir in loadedModules:
            for folder_name in example_workflow_folder_names:
                workflows_dir = os.path.join(module_dir, folder_name)

                if os.path.exists(workflows_dir):
                    if folder_name != "example_workflows":
                        logging.debug(
                            "Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
                            folder_name, module_name)

                    webapp.add_routes(
                        [
                            web.static(
                                "/api/workflow_templates/" + module_name, workflows_dir
                            )
                        ]
                    )

        @routes.get("/i18n")
        async def get_i18n(request):
            """Returns translations from all custom nodes' locales folders."""
            return web.json_response(self.build_translations())
app/database/db.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import shutil
4
+ from app.logger import log_startup_warning
5
+ from utils.install_util import get_missing_requirements_message
6
+ from comfy.cli_args import args
7
+
8
+ _DB_AVAILABLE = False
9
+ Session = None
10
+
11
+
12
+ try:
13
+ from alembic import command
14
+ from alembic.config import Config
15
+ from alembic.runtime.migration import MigrationContext
16
+ from alembic.script import ScriptDirectory
17
+ from sqlalchemy import create_engine
18
+ from sqlalchemy.orm import sessionmaker
19
+
20
+ _DB_AVAILABLE = True
21
+ except ImportError as e:
22
+ log_startup_warning(
23
+ f"""
24
+ ------------------------------------------------------------------------
25
+ Error importing dependencies: {e}
26
+ {get_missing_requirements_message()}
27
+ This error is happening because ComfyUI now uses a local sqlite database.
28
+ ------------------------------------------------------------------------
29
+ """.strip()
30
+ )
31
+
32
+
33
def dependencies_available():
    """Report whether the optional DB dependencies (SQLAlchemy/Alembic)
    imported successfully at module load time.

    Temporary helper kept for the initial rollout of the local database.
    """
    return bool(_DB_AVAILABLE)
38
+
39
+
40
def can_create_session():
    """Temporary check that the database is ready for session creation.

    During the initial release, missing dependencies or environment issues may
    prevent the database from being created; both conditions return False.
    """
    if not dependencies_available():
        return False
    return Session is not None
46
+
47
+
48
def get_alembic_config():
    """Build the Alembic configuration used for schema migrations.

    Points Alembic at the repository-level ``alembic.ini`` and the
    ``alembic_db`` script directory, and injects the configured database URL.
    """
    repo_root = os.path.join(os.path.dirname(__file__), "../..")

    cfg = Config(os.path.abspath(os.path.join(repo_root, "alembic.ini")))
    cfg.set_main_option("script_location", os.path.abspath(os.path.join(repo_root, "alembic_db")))
    cfg.set_main_option("sqlalchemy.url", args.database_url)
    return cfg
58
+
59
+
60
def get_db_path():
    """Return the on-disk SQLite file path derived from the configured URL.

    Raises:
        ValueError: if the URL does not use the ``sqlite:///`` scheme.
    """
    url = args.database_url
    if not url.startswith("sqlite:///"):
        raise ValueError(f"Unsupported database URL '{url}'.")
    return url.split("///")[1]
66
+
67
+
68
def init_db():
    """Create or upgrade the database schema via Alembic.

    Backs up the SQLite file before applying migrations and restores it if the
    upgrade fails; on success the module-level ``Session`` factory is bound to
    the engine.
    """
    db_url = args.database_url
    logging.debug(f"Database URL: {db_url}")
    db_path = get_db_path()
    db_exists = os.path.exists(db_path)

    config = get_alembic_config()

    # Check if we need to upgrade.
    engine = create_engine(db_url)
    # Fix: the original called engine.connect() and never closed the
    # connection. Scope it to the revision check only -- Alembic opens its
    # own connection for the upgrade itself.
    with engine.connect() as conn:
        context = MigrationContext.configure(conn)
        current_rev = context.get_current_revision()

    script = ScriptDirectory.from_config(config)
    target_rev = script.get_current_head()

    if target_rev is None:
        logging.warning("No target revision found.")
    elif current_rev != target_rev:
        # Backup the database pre upgrade
        backup_path = db_path + ".bkp"
        if db_exists:
            shutil.copy(db_path, backup_path)
        else:
            backup_path = None

        try:
            command.upgrade(config, target_rev)
            logging.info(f"Database upgraded from {current_rev} to {target_rev}")
        except Exception as e:
            if backup_path:
                # Restore the database from backup if upgrade fails
                shutil.copy(backup_path, db_path)
                os.remove(backup_path)
            logging.exception("Error upgrading database: ")
            raise e
        # NOTE(review): on success the .bkp file is intentionally left behind.

    global Session
    Session = sessionmaker(bind=engine)
109
+
110
+
111
def create_session():
    """Return a new SQLAlchemy session bound to the initialized engine.

    NOTE(review): only valid after init_db() has run; before that ``Session``
    is still None and this call raises TypeError.
    """
    return Session()
app/database/models.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ from datetime import datetime
3
+ from sqlalchemy.orm import DeclarativeBase
4
+
5
class Base(DeclarativeBase):
    """Declarative base class shared by all ORM models in this package."""
    pass
7
+
8
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
    """Serialize a mapped model instance into a plain dict, column by column.

    Datetime values are rendered as ISO-8601 strings; None-valued columns are
    dropped unless ``include_none`` is True.
    """
    result: dict[str, Any] = {}
    for column_name in obj.__table__.columns.keys():
        value = getattr(obj, column_name)
        if value is None:
            if include_none:
                result[column_name] = None
            continue
        result[column_name] = value.isoformat() if isinstance(value, datetime) else value
    return result
20
+
21
+ # TODO: Define models here
app/frontend_management.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import argparse
3
+ import logging
4
+ import os
5
+ import re
6
+ import sys
7
+ import tempfile
8
+ import zipfile
9
+ import importlib
10
+ from dataclasses import dataclass
11
+ from functools import cached_property
12
+ from pathlib import Path
13
+ from typing import Dict, TypedDict, Optional
14
+ from aiohttp import web
15
+ from importlib.metadata import version
16
+
17
+ import requests
18
+ from typing_extensions import NotRequired
19
+
20
+ from utils.install_util import get_missing_requirements_message, requirements_path
21
+
22
+ from comfy.cli_args import DEFAULT_VERSION_STRING
23
+ import app.logger
24
+
25
+
26
def frontend_install_warning_message():
    """Build the user-facing warning shown when the frontend pip package is
    missing or outdated (embeds the project's missing-requirements message)."""
    return f"""
{get_missing_requirements_message()}

This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
""".strip()
32
+
33
def parse_version(version: str) -> tuple[int, int, int]:
    """Convert a dotted version string like "1.2.3" into a tuple of ints."""
    return tuple(int(part) for part in version.split("."))
35
+
36
def is_valid_version(version: str) -> bool:
    """Validate if a string is a valid semantic version (X.Y.Z format)."""
    semver_pattern = r"^(\d+)\.(\d+)\.(\d+)$"
    return re.match(semver_pattern, version) is not None
40
+
41
def get_installed_frontend_version():
    """Get the currently installed frontend package version string."""
    return version("comfyui-frontend-package")
45
+
46
+
47
def get_required_frontend_version():
    """Get the required frontend version pinned in requirements.txt.

    Returns the version string, or None when the file is missing, unreadable,
    or the pin is absent/malformed (each case is logged).
    """
    pin_prefix = "comfyui-frontend-package=="
    try:
        with open(requirements_path, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line.startswith(pin_prefix):
                    continue
                version_str = line.split("==")[-1]
                if is_valid_version(version_str):
                    return version_str
                logging.error(f"Invalid version format in requirements.txt: {version_str}")
                return None
            logging.error("comfyui-frontend-package not found in requirements.txt")
            return None
    except FileNotFoundError:
        logging.error("requirements.txt not found. Cannot determine required frontend version.")
        return None
    except Exception as e:
        logging.error(f"Error reading requirements.txt: {e}")
        return None
67
+
68
+
69
def check_frontend_version():
    """Check if the frontend version is up to date.

    Compares the installed comfyui-frontend-package version against the pin
    in requirements.txt and emits a startup warning when it is older.
    """

    try:
        frontend_version_str = get_installed_frontend_version()
        frontend_version = parse_version(frontend_version_str)
        required_frontend_str = get_required_frontend_version()
        # NOTE(review): get_required_frontend_version() may return None, in
        # which case parse_version raises and the generic error log below runs.
        required_frontend = parse_version(required_frontend_str)
        if frontend_version < required_frontend:
            app.logger.log_startup_warning(
                f"""
________________________________________________________________________
WARNING WARNING WARNING WARNING WARNING

Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.

{frontend_install_warning_message()}
________________________________________________________________________
""".strip()
            )
        else:
            logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
    except Exception as e:
        # Any failure here (missing package, unreadable requirements) is
        # non-fatal: version checking must never block startup.
        logging.error(f"Failed to check frontend version: {e}")
93
+
94
+
95
+ REQUEST_TIMEOUT = 10 # seconds
96
+
97
+
98
class Asset(TypedDict):
    """Subset of a GitHub release-asset object that this module reads.

    ``name`` is declared in addition to ``url`` because
    download_release_asset_zip() selects the asset via ``asset["name"] ==
    "dist.zip"`` before fetching ``asset["url"]``; the original TypedDict
    omitted it, so type checkers flagged that access.
    """
    name: str
    url: str
100
+
101
+
102
class Release(TypedDict):
    """GitHub release object fields used by this module (subset of the API payload)."""
    id: int
    tag_name: str       # e.g. "v1.2.3"
    name: str
    prerelease: bool
    created_at: str     # ISO-8601 timestamp strings as returned by the API
    published_at: str
    body: str
    # Not present on every response shape, hence NotRequired.
    assets: NotRequired[list[Asset]]
111
+
112
+
113
@dataclass
class FrontEndProvider:
    """A GitHub owner/repo pair from which frontend releases are fetched."""

    owner: str
    repo: str

    @property
    def folder_name(self) -> str:
        # Directory name used under FrontendManager.CUSTOM_FRONTENDS_ROOT.
        return f"{self.owner}_{self.repo}"

    @property
    def release_url(self) -> str:
        # GitHub REST endpoint listing all releases of the repo.
        return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"

    @cached_property
    def all_releases(self) -> list[Release]:
        """Fetch every release, following GitHub's Link-header pagination.

        Cached per instance; network errors propagate as requests exceptions.
        """
        releases = []
        api_url = self.release_url
        while api_url:
            response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
            response.raise_for_status()  # Raises an HTTPError if the response was an error
            releases.extend(response.json())
            # GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
            if "next" in response.links:
                api_url = response.links["next"]["url"]
            else:
                api_url = None
        return releases

    @cached_property
    def latest_release(self) -> Release:
        """Fetch the release GitHub marks as latest (never a pre-release)."""
        latest_release_url = f"{self.release_url}/latest"
        response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()  # Raises an HTTPError if the response was an error
        return response.json()

    @cached_property
    def latest_prerelease(self) -> Release:
        """Get the latest pre-release version - even if it's older than the latest release"""
        release = [release for release in self.all_releases if release["prerelease"]]

        if not release:
            raise ValueError("No pre-releases found")

        # GitHub returns releases in reverse chronological order, so first is latest
        return release[0]

    def get_release(self, version: str) -> Release:
        """Resolve "latest", "prerelease", or a concrete tag (with or without a
        leading "v") to a Release.

        Raises:
            ValueError: if the requested version is not among the releases.
        """
        if version == "latest":
            return self.latest_release
        elif version == "prerelease":
            return self.latest_prerelease
        else:
            for release in self.all_releases:
                if release["tag_name"] in [version, f"v{version}"]:
                    return release
            raise ValueError(f"Version {version} not found in releases")
169
+
170
+
171
def download_release_asset_zip(release: Release, destination_path: str) -> None:
    """Download dist.zip from github release and extract it.

    Raises:
        ValueError: if the release carries no asset named "dist.zip".
    """
    # Locate the dist.zip asset among the release's assets, if any.
    asset_url = next(
        (asset["url"] for asset in release.get("assets", []) if asset["name"] == "dist.zip"),
        None,
    )
    if not asset_url:
        raise ValueError("dist.zip not found in the release assets")

    # Buffer the archive in an anonymous temporary file, then unpack it.
    with tempfile.TemporaryFile() as tmp_file:
        response = requests.get(
            asset_url,
            headers={"Accept": "application/octet-stream"},
            allow_redirects=True,
            timeout=REQUEST_TIMEOUT,
        )
        response.raise_for_status()  # Ensure we got a successful response

        tmp_file.write(response.content)
        tmp_file.seek(0)

        with zipfile.ZipFile(tmp_file, "r") as zip_ref:
            zip_ref.extractall(destination_path)
199
+
200
+
201
class FrontendManager:
    """Resolves, downloads, and serves frontend and workflow-template assets."""

    # Cache root for downloaded non-default frontends, laid out as
    # <owner>_<repo>/<semver>/ under the repository directory.
    CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")

    @classmethod
    def get_required_frontend_version(cls) -> str:
        """Get the required frontend package version."""
        return get_required_frontend_version()

    @classmethod
    def get_installed_templates_version(cls) -> str:
        """Get the currently installed workflow templates package version."""
        try:
            templates_version_str = version("comfyui-workflow-templates")
            return templates_version_str
        except Exception:
            # Package not installed or metadata unreadable.
            return None

    @classmethod
    def get_required_templates_version(cls) -> str:
        """Get the required workflow templates version from requirements.txt."""
        try:
            with open(requirements_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line.startswith("comfyui-workflow-templates=="):
                        version_str = line.split("==")[-1]
                        if not is_valid_version(version_str):
                            logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
                            return None
                        return version_str
                logging.error("comfyui-workflow-templates not found in requirements.txt")
                return None
        except FileNotFoundError:
            logging.error("requirements.txt not found. Cannot determine required templates version.")
            return None
        except Exception as e:
            logging.error(f"Error reading requirements.txt: {e}")
            return None

    @classmethod
    def default_frontend_path(cls) -> str:
        """Return the static-files path of the pip-installed frontend.

        Exits the process when comfyui-frontend-package is not installed,
        since the server cannot serve a UI without it.
        """
        try:
            import comfyui_frontend_package

            return str(importlib.resources.files(comfyui_frontend_package) / "static")
        except ImportError:
            logging.error(
                f"""
********** ERROR ***********

comfyui-frontend-package is not installed.

{frontend_install_warning_message()}

********** ERROR ***********
""".strip()
            )
            sys.exit(-1)

    @classmethod
    def template_asset_map(cls) -> Optional[Dict[str, str]]:
        """Return a mapping of template asset names to their absolute paths."""
        try:
            from comfyui_workflow_templates import (
                get_asset_path,
                iter_templates,
            )
        except ImportError:
            logging.error(
                f"""
********** ERROR ***********

comfyui-workflow-templates is not installed.

{frontend_install_warning_message()}

********** ERROR ***********
""".strip()
            )
            return None

        try:
            template_entries = list(iter_templates())
        except Exception as exc:
            logging.error(f"Failed to enumerate workflow templates: {exc}")
            return None

        # Flatten all template entries into one filename -> absolute-path map.
        asset_map: Dict[str, str] = {}
        try:
            for entry in template_entries:
                for asset in entry.assets:
                    asset_map[asset.filename] = get_asset_path(
                        entry.template_id, asset.filename
                    )
        except Exception as exc:
            logging.error(f"Failed to resolve template asset paths: {exc}")
            return None

        if not asset_map:
            logging.error("No workflow template assets found. Did the packages install correctly?")
            return None

        return asset_map


    @classmethod
    def legacy_templates_path(cls) -> Optional[str]:
        """Return the legacy templates directory shipped inside the meta package."""
        try:
            import comfyui_workflow_templates

            return str(
                importlib.resources.files(comfyui_workflow_templates) / "templates"
            )
        except ImportError:
            logging.error(
                f"""
********** ERROR ***********

comfyui-workflow-templates is not installed.

{frontend_install_warning_message()}

********** ERROR ***********
""".strip()
            )
            return None

    @classmethod
    def embedded_docs_path(cls) -> str:
        """Get the path to embedded documentation (None when not installed)."""
        try:
            import comfyui_embedded_docs

            return str(
                importlib.resources.files(comfyui_embedded_docs) / "docs"
            )
        except ImportError:
            logging.info("comfyui-embedded-docs package not found")
            return None

    @classmethod
    def parse_version_string(cls, value: str) -> tuple[str, str, str]:
        """
        Args:
            value (str): The version string to parse, e.g. "owner/repo@v1.2.3".

        Returns:
            tuple[str, str, str]: A tuple containing owner, repo name, and version
            ("latest" and "prerelease" are accepted as versions).

        Raises:
            argparse.ArgumentTypeError: If the version string is invalid.
        """
        VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$"
        match_result = re.match(VERSION_PATTERN, value)
        if match_result is None:
            raise argparse.ArgumentTypeError(f"Invalid version string: {value}")

        return match_result.group(1), match_result.group(2), match_result.group(3)

    @classmethod
    def init_frontend_unsafe(
        cls, version_string: str, provider: Optional[FrontEndProvider] = None
    ) -> str:
        """
        Initializes the frontend for the specified version.

        Args:
            version_string (str): The version string.
            provider (FrontEndProvider, optional): The provider to use. Defaults to None.

        Returns:
            str: The path to the initialized frontend.

        Raises:
            Exception: If there is an error during the initialization process.
            main error source might be request timeout or invalid URL.
        """
        if version_string == DEFAULT_VERSION_STRING:
            check_frontend_version()
            return cls.default_frontend_path()

        repo_owner, repo_name, version = cls.parse_version_string(version_string)

        # Exact "vX.Y.Z" tags can be served straight from the local cache
        # without touching GitHub.
        if version.startswith("v"):
            expected_path = str(
                Path(cls.CUSTOM_FRONTENDS_ROOT)
                / f"{repo_owner}_{repo_name}"
                / version.lstrip("v")
            )
            if os.path.exists(expected_path):
                logging.info(
                    f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
                )
                return expected_path

        logging.info(
            f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
        )

        provider = provider or FrontEndProvider(repo_owner, repo_name)
        release = provider.get_release(version)

        semantic_version = release["tag_name"].lstrip("v")
        web_root = str(
            Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
        )
        if not os.path.exists(web_root):
            try:
                os.makedirs(web_root, exist_ok=True)
                logging.info(
                    "Downloading frontend(%s) version(%s) to (%s)",
                    provider.folder_name,
                    semantic_version,
                    web_root,
                )
                logging.debug(release)
                download_release_asset_zip(release, destination_path=web_root)
            finally:
                # Clean up the directory if it is empty, i.e. the download failed
                if not os.listdir(web_root):
                    os.rmdir(web_root)

        return web_root

    @classmethod
    def init_frontend(cls, version_string: str) -> str:
        """
        Initializes the frontend with the specified version string.

        Falls back to the default pip-installed frontend on any failure.

        Args:
            version_string (str): The version string to initialize the frontend with.

        Returns:
            str: The path of the initialized frontend.
        """
        try:
            return cls.init_frontend_unsafe(version_string)
        except Exception as e:
            logging.error("Failed to initialize frontend: %s", e)
            logging.info("Falling back to the default frontend.")
            check_frontend_version()
            return cls.default_frontend_path()

    @classmethod
    def template_asset_handler(cls):
        """Build an aiohttp handler serving template assets, or None when unavailable."""
        assets = cls.template_asset_map()
        if not assets:
            return None

        async def serve_template(request: web.Request) -> web.StreamResponse:
            rel_path = request.match_info.get("path", "")
            target = assets.get(rel_path)
            if target is None:
                raise web.HTTPNotFound()
            return web.FileResponse(target)

        return serve_template
app/logger.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ from datetime import datetime
3
+ import io
4
+ import logging
5
+ import sys
6
+ import threading
7
+
8
+ logs = None
9
+ stdout_interceptor = None
10
+ stderr_interceptor = None
11
+
12
+
13
class LogInterceptor(io.TextIOWrapper):
    """Wraps a text stream (stdout/stderr) and mirrors everything written to it
    into the module-level ``logs`` ring buffer, while still writing through to
    the underlying stream."""

    def __init__(self, stream, *args, **kwargs):
        buffer = stream.buffer
        encoding = stream.encoding
        super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
        self._lock = threading.Lock()
        self._flush_callbacks = []
        self._logs_since_flush = []

    def write(self, data):
        entry = {"t": datetime.now().isoformat(), "m": data}
        with self._lock:
            self._logs_since_flush.append(entry)

            # Simple handling for cr to overwrite the last output if it isn't a full line
            # else logs just get full of progress messages.
            # FIX: guard on `logs` being non-empty — the original read logs[-1]
            # unconditionally, raising IndexError when the very first write
            # starts with "\r" (or if setup left `logs` as None).
            if isinstance(data, str) and data.startswith("\r") and logs and not logs[-1]["m"].endswith("\n"):
                logs.pop()
            logs.append(entry)
        super().write(data)

    def flush(self):
        """Flush the underlying stream, then notify registered callbacks with
        everything written since the previous flush."""
        super().flush()
        for cb in self._flush_callbacks:
            cb(self._logs_since_flush)
        self._logs_since_flush = []

    def on_flush(self, callback):
        """Register `callback` to receive the list of entries at each flush."""
        self._flush_callbacks.append(callback)
42
+
43
+
44
def get_logs():
    """Return the in-memory log ring buffer (a deque of {"t", "m"} entries),
    or None if setup_logger() has not been called yet."""
    return logs
46
+
47
+
48
def on_flush(callback):
    """Register `callback` on whichever stream interceptors are installed.

    The callback receives the list of entries written since the previous
    flush of that stream.
    """
    for interceptor in (stdout_interceptor, stderr_interceptor):
        if interceptor is not None:
            interceptor.on_flush(callback)
53
+
54
def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
    """Install log interception on stdout/stderr and configure the root logger.

    Args:
        log_level: Root logger level name (e.g. 'INFO', 'DEBUG').
        capacity: Maximum number of entries kept in the in-memory ring buffer.
        use_stdout: If True, records below ERROR go to stdout and only
            ERROR/CRITICAL go to stderr; otherwise everything goes to stderr.

    Idempotent: subsequent calls are no-ops.
    """
    global logs
    # FIX: `logs` is created as an *empty* deque, which is falsy, so the
    # original `if logs:` guard failed to prevent double initialization when
    # setup_logger() was called twice before anything was logged — wrapping
    # sys.stdout/sys.stderr a second time. Check for None explicitly.
    if logs is not None:
        return

    # Override output streams and log to buffer
    logs = deque(maxlen=capacity)

    global stdout_interceptor
    global stderr_interceptor
    stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
    stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)

    # Setup default global logger
    logger = logging.getLogger()
    logger.setLevel(log_level)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter("%(message)s"))

    if use_stdout:
        # Only errors and critical to stderr
        stream_handler.addFilter(lambda record: record.levelno >= logging.ERROR)

        # Lesser to stdout
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(logging.Formatter("%(message)s"))
        stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
        logger.addHandler(stdout_handler)

    logger.addHandler(stream_handler)
85
+
86
+
87
# Warnings emitted during startup, retained so they can be re-printed once
# startup output has settled.
STARTUP_WARNINGS = []


def log_startup_warning(msg):
    """Emit `msg` as a warning now and remember it for later re-printing."""
    logging.warning(msg)
    STARTUP_WARNINGS.append(msg)


def print_startup_warnings():
    """Re-emit every recorded startup warning, then clear the list."""
    for warning in STARTUP_WARNINGS:
        logging.warning(warning)
    STARTUP_WARNINGS.clear()
app/model_manager.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import base64
5
+ import json
6
+ import time
7
+ import logging
8
+ import folder_paths
9
+ import glob
10
+ import comfy.utils
11
+ from aiohttp import web
12
+ from PIL import Image
13
+ from io import BytesIO
14
+ from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
15
+
16
+
17
class ModelFileManager:
    """Caches directory scans of model folders and serves file listings and
    image previews over the experimental `/experiment/models` HTTP routes."""

    def __init__(self) -> None:
        # folder path -> (file entries, {directory path: mtime}, scan timestamp)
        self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}

    def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
        """Return the cached scan for `key`, or `default` when absent."""
        return self.cache.get(key, default)

    def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
        """Store a scan result for `key`."""
        self.cache[key] = value

    def clear_cache(self):
        """Drop every cached scan."""
        self.cache.clear()

    def add_routes(self, routes):
        # NOTE: This is an experiment to replace `/models`
        @routes.get("/experiment/models")
        async def get_model_folders(request):
            model_types = list(folder_paths.folder_names_and_paths.keys())
            folder_black_list = ["configs", "custom_nodes"]
            output_folders: list[dict] = []
            for folder in model_types:
                if folder in folder_black_list:
                    continue
                output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
            return web.json_response(output_folders)

        # NOTE: This is an experiment to replace `/models/{folder}`
        @routes.get("/experiment/models/{folder}")
        async def get_all_models(request):
            folder = request.match_info.get("folder", None)
            if folder not in folder_paths.folder_names_and_paths:
                return web.Response(status=404)
            files = self.get_model_file_list(folder)
            return web.json_response(files)

        @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
        async def get_model_preview(request):
            folder_name = request.match_info.get("folder", None)
            path_index = int(request.match_info.get("path_index", None))
            filename = request.match_info.get("filename", None)

            if folder_name not in folder_paths.folder_names_and_paths:
                return web.Response(status=404)

            folders = folder_paths.folder_names_and_paths[folder_name]
            folder = folders[0][path_index]
            full_filename = os.path.join(folder, filename)

            previews = self.get_model_previews(full_filename)
            default_preview = previews[0] if len(previews) > 0 else None
            if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
                return web.Response(status=404)

            try:
                with Image.open(default_preview) as img:
                    img_bytes = BytesIO()
                    img.save(img_bytes, format="WEBP")
                    img_bytes.seek(0)
                    return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
            except Exception:
                # FIX: was a bare `except:`, which would also swallow
                # KeyboardInterrupt/SystemExit. Any decode/convert failure
                # is reported as a missing preview.
                return web.Response(status=404)

    def get_model_file_list(self, folder_name: str):
        """Return the merged (cached where fresh) file list for every path
        registered under `folder_name`."""
        folder_name = map_legacy(folder_name)
        folders = folder_paths.folder_names_and_paths[folder_name]
        output_list: list[dict] = []

        for index, folder in enumerate(folders[0]):
            if not os.path.isdir(folder):
                continue
            out = self.cache_model_file_list_(folder)
            if out is None:
                out = self.recursive_search_models_(folder, index)
                self.set_cache(folder, out)
            output_list.extend(out[0])

        return output_list

    def cache_model_file_list_(self, folder: str):
        """Return the cached scan for `folder` if still fresh, else None.

        Freshness: every directory recorded during the scan (including the
        scan root) must still exist with an unchanged mtime.
        """
        model_file_list_cache = self.get_cache(folder)

        if model_file_list_cache is None:
            return None
        if not os.path.isdir(folder):
            return None
        # FIX: the original first compared os.path.getmtime(folder) (a float)
        # against model_file_list_cache[1] (the whole {dir: mtime} dict) —
        # always unequal, so the cache never hit. The per-directory loop
        # below is the real invalidation check.
        for dir_path, time_modified in model_file_list_cache[1].items():
            try:
                if os.path.getmtime(dir_path) != time_modified:
                    return None
            except OSError:
                # A scanned directory disappeared: the cache is stale.
                return None

        return model_file_list_cache

    def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[dict], dict[str, float], float]:
        """Walk `directory` and return (file entries, {dir: mtime}, scan time).

        FIX: return annotation corrected — entries are dicts, not strings.
        """
        if not os.path.isdir(directory):
            return [], {}, time.perf_counter()

        excluded_dir_names = [".git"]
        # TODO use settings
        include_hidden_files = False

        result: list[dict] = []
        dirs: dict[str, float] = {}

        # FIX: record the root's own mtime so that adding/removing files
        # directly in the root also invalidates the cache (previously only
        # subdirectories were tracked).
        try:
            dirs[directory] = os.path.getmtime(directory)
        except OSError:
            pass

        for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
            subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
            if not include_hidden_files:
                subdirs[:] = [d for d in subdirs if not d.startswith(".")]
                filenames = [f for f in filenames if not f.startswith(".")]

            filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)

            for file_name in filenames:
                try:
                    full_path = os.path.join(dirpath, file_name)
                    relative_path = os.path.relpath(full_path, directory)

                    # Get file metadata
                    file_info = {
                        "name": relative_path,
                        "pathIndex": pathIndex,
                        "modified": os.path.getmtime(full_path),  # Add modification time
                        "created": os.path.getctime(full_path),  # Add creation time
                        "size": os.path.getsize(full_path)  # Add file size
                    }
                    result.append(file_info)

                except Exception as e:
                    logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
                    continue

            for d in subdirs:
                path: str = os.path.join(dirpath, d)
                try:
                    dirs[path] = os.path.getmtime(path)
                except FileNotFoundError:
                    logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
                    continue

        return result, dirs, time.perf_counter()

    def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
        """Return preview candidates for a model file: sibling image files
        (`<name>.*` / `<name>.preview.*`) and any cover images embedded in a
        matching .safetensors header."""
        dirname = os.path.dirname(filepath)

        if not os.path.exists(dirname):
            return []

        basename = os.path.splitext(filepath)[0]
        match_files = glob.glob(f"{basename}.*", recursive=False)
        image_files = filter_files_content_types(match_files, "image")
        safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
        safetensors_metadata = {}

        result: list[str | BytesIO] = []

        for filename in image_files:
            _basename = os.path.splitext(filename)[0]
            if _basename == basename:
                result.append(filename)
            if _basename == f"{basename}.preview":
                result.append(filename)

        if safetensors_file:
            safetensors_filepath = os.path.join(dirname, safetensors_file)
            # Cap header read at 8 MiB to avoid loading huge malformed headers.
            header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
            if header:
                safetensors_metadata = json.loads(header)
            safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
            if safetensors_images:
                safetensors_images = json.loads(safetensors_images)
                for image in safetensors_images:
                    result.append(BytesIO(base64.b64decode(image)))

        return result

    def __exit__(self, exc_type, exc_value, traceback):
        # NOTE(review): there is no matching __enter__ visible in this class;
        # confirm how callers use this as a context manager.
        self.clear_cache()
app/subgraph_manager.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import TypedDict
4
+ import os
5
+ import folder_paths
6
+ import glob
7
+ from aiohttp import web
8
+ import hashlib
9
+
10
+
11
class Source:
    """String constants identifying where a subgraph entry was discovered."""
    # Shipped inside a custom node pack (<pack>/subgraphs/*.json).
    custom_node = "custom_node"
    # Shipped in the top-level blueprints directory.
    templates = "templates"
14
+
15
class SubgraphEntry(TypedDict):
    source: str
    """
    Source of subgraph - custom_nodes vs templates.
    """
    path: str
    """
    Relative path of the subgraph file.
    For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
    """
    name: str
    """
    Name of subgraph file.
    """
    info: CustomNodeSubgraphEntryInfo
    """
    Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
    """
    # Raw text contents of the subgraph JSON file; populated lazily by
    # SubgraphManager.load_entry_data and stripped from API listings.
    data: str
34
+
35
class CustomNodeSubgraphEntryInfo(TypedDict):
    """Extra metadata attached to a subgraph entry."""
    node_pack: str
    """Node pack name."""
38
+
39
class SubgraphManager:
    """Discovers "global" subgraph JSON files from custom node packs and the
    top-level blueprints directory, caches them, and serves them over HTTP."""

    def __init__(self):
        # FIX: annotations were `dict[SubgraphEntry]` (invalid — dict takes a
        # key and a value type); keys are sha256-hex entry ids.
        self.cached_custom_node_subgraphs: dict[str, SubgraphEntry] | None = None
        self.cached_blueprint_subgraphs: dict[str, SubgraphEntry] | None = None

    def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]:
        """Create a subgraph entry from a file path. Expects normalized path (forward slashes).

        The entry id is a sha256 over source+path, so it is stable across runs.
        """
        entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
        entry: SubgraphEntry = {
            "source": source,
            "name": os.path.splitext(os.path.basename(file))[0],
            "path": file,
            "info": {"node_pack": node_pack},
        }
        return entry_id, entry

    async def load_entry_data(self, entry: SubgraphEntry):
        """Populate entry['data'] with the file's contents and return the entry."""
        # FIX: subgraph files are JSON — read them as UTF-8 explicitly rather
        # than the platform default encoding.
        with open(entry['path'], 'r', encoding='utf-8') as f:
            entry['data'] = f.read()
        return entry

    async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
        """Return a copy of `entry` safe for API responses (server path removed,
        and optionally the bulky data payload)."""
        if entry is None:
            return None
        entry = entry.copy()
        entry.pop('path', None)
        if remove_data:
            entry.pop('data', None)
        return entry

    async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
        """Sanitize every entry; the input mapping is not mutated."""
        return {key: await self.sanitize_entry(value, remove_data) for key, value in entries.items()}

    async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
        """Load subgraphs from custom nodes."""
        if not force_reload and self.cached_custom_node_subgraphs is not None:
            return self.cached_custom_node_subgraphs

        subgraphs_dict: dict[str, SubgraphEntry] = {}
        for folder in folder_paths.get_folder_paths("custom_nodes"):
            pattern = os.path.join(folder, "*/subgraphs/*.json")
            for file in glob.glob(pattern):
                file = file.replace('\\', '/')
                # <custom_nodes_dir>/<pack>/subgraphs/<name>.json -> pack name
                node_pack = "custom_nodes." + file.split('/')[-3]
                entry_id, entry = self._create_entry(file, Source.custom_node, node_pack)
                subgraphs_dict[entry_id] = entry

        self.cached_custom_node_subgraphs = subgraphs_dict
        return subgraphs_dict

    async def get_blueprint_subgraphs(self, force_reload=False):
        """Load subgraphs from the blueprints directory."""
        if not force_reload and self.cached_blueprint_subgraphs is not None:
            return self.cached_blueprint_subgraphs

        subgraphs_dict: dict[str, SubgraphEntry] = {}
        blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints')

        if os.path.exists(blueprints_dir):
            for file in glob.glob(os.path.join(blueprints_dir, "*.json")):
                file = file.replace('\\', '/')
                entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
                subgraphs_dict[entry_id] = entry

        self.cached_blueprint_subgraphs = subgraphs_dict
        return subgraphs_dict

    async def get_all_subgraphs(self, loadedModules, force_reload=False):
        """Get all subgraphs from all sources (custom nodes and blueprints)."""
        custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload)
        blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload)
        return {**custom_node_subgraphs, **blueprint_subgraphs}

    async def get_subgraph(self, id: str, loadedModules):
        """Get a specific subgraph by ID from any source, loading its data lazily."""
        entry = (await self.get_all_subgraphs(loadedModules)).get(id)
        if entry is not None and entry.get('data') is None:
            await self.load_entry_data(entry)
        return entry

    def add_routes(self, routes, loadedModules):
        @routes.get("/global_subgraphs")
        async def get_global_subgraphs(request):
            subgraphs_dict = await self.get_all_subgraphs(loadedModules)
            return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))

        @routes.get("/global_subgraphs/{id}")
        async def get_global_subgraph(request):
            id = request.match_info.get("id", None)
            subgraph = await self.get_subgraph(id, loadedModules)
            return web.json_response(await self.sanitize_entry(subgraph))
app/user_manager.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import json
3
+ import os
4
+ import re
5
+ import uuid
6
+ import glob
7
+ import shutil
8
+ import logging
9
+ from aiohttp import web
10
+ from urllib import parse
11
+ from comfy.cli_args import args
12
+ import folder_paths
13
+ from .app_settings import AppSettings
14
+ from typing import TypedDict
15
+
16
+ default_user = "default"
17
+
18
+
19
class FileInfo(TypedDict):
    """Metadata describing a single file in a user's data directory."""
    path: str  # relative path from the user's root, '/'-separated
    size: int  # size in bytes
    modified: float  # last-modified time; os.path.getmtime returns float seconds
    created: float  # creation/metadata-change time; os.path.getctime returns float seconds
24
+
25
+
26
def get_file_info(path: str, relative_to: str) -> FileInfo:
    """Build a FileInfo dict for `path`, reported relative to `relative_to`
    with forward-slash separators."""
    relative = os.path.relpath(path, relative_to).replace(os.sep, '/')
    stats = os.stat(path)
    return {
        "path": relative,
        "size": stats.st_size,
        "modified": stats.st_mtime,
        "created": stats.st_ctime,
    }
33
+
34
+
35
class UserManager():
    """Manages user profiles and per-user data files.

    In multi-user mode (``--multi-user``) users are stored in users.json and
    selected per-request via the ``comfy-user`` header; otherwise a single
    "default" user is used. All file access goes through
    get_request_user_filepath, which confines paths to the user's directory.
    """
    def __init__(self):
        user_directory = folder_paths.get_user_directory()

        self.settings = AppSettings(self)
        if not os.path.exists(user_directory):
            os.makedirs(user_directory, exist_ok=True)
            if not args.multi_user:
                logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
                logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

        if args.multi_user:
            if os.path.isfile(self.get_users_file()):
                with open(self.get_users_file()) as f:
                    self.users = json.load(f)
            else:
                self.users = {}
        else:
            self.users = {"default": "default"}

    def get_users_file(self):
        """Path of the users.json registry inside the user directory."""
        return os.path.join(folder_paths.get_user_directory(), "users.json")

    def get_request_user_id(self, request):
        """Resolve the requesting user's id; raises KeyError for unknown or
        system-prefixed users (same message for both, to prevent probing)."""
        user = "default"
        if args.multi_user and "comfy-user" in request.headers:
            user = request.headers["comfy-user"]
            # Block System Users (use same error message to prevent probing)
            if user.startswith(folder_paths.SYSTEM_USER_PREFIX):
                raise KeyError("Unknown user: " + user)

        if user not in self.users:
            raise KeyError("Unknown user: " + user)

        return user

    def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
        """Resolve `file` inside the requesting user's directory.

        Returns None when the path would escape the user's directory (path
        traversal) — callers treat None as "forbidden". The two commonpath
        checks are order-sensitive; do not reorder.
        """
        if type == "userdata":
            root_dir = folder_paths.get_user_directory()
        else:
            raise KeyError("Unknown filepath type:" + type)

        user = self.get_request_user_id(request)
        user_root = folder_paths.get_public_user_directory(user)
        if user_root is None:
            return None
        path = user_root

        # prevent leaving /{type}
        if os.path.commonpath((root_dir, user_root)) != root_dir:
            return None

        if file is not None:
            # Check if filename is url encoded
            if "%" in file:
                file = parse.unquote(file)

            # prevent leaving /{type}/{user}
            path = os.path.abspath(os.path.join(user_root, file))
            if os.path.commonpath((user_root, path)) != user_root:
                return None

        parent = os.path.split(path)[0]

        if create_dir and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)

        return path

    def add_user(self, name):
        """Register a new user; returns a generated unique user id.

        Raises ValueError for empty names or names/ids that collide with the
        reserved system-user prefix.
        """
        name = name.strip()
        if not name:
            raise ValueError("username not provided")
        if name.startswith(folder_paths.SYSTEM_USER_PREFIX):
            raise ValueError("System User prefix not allowed")
        user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
        if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX):
            raise ValueError("System User prefix not allowed")
        user_id = user_id + "_" + str(uuid.uuid4())

        self.users[user_id] = name

        with open(self.get_users_file(), "w") as f:
            json.dump(self.users, f)

        return user_id

    def add_routes(self, routes):
        """Register all user/userdata HTTP routes on `routes`."""
        self.settings.add_routes(routes)

        @routes.get("/users")
        async def get_users(request):
            if args.multi_user:
                return web.json_response({"storage": "server", "users": self.users})
            else:
                user_dir = self.get_request_user_filepath(request, None, create_dir=False)
                return web.json_response({
                    "storage": "server",
                    "migrated": os.path.exists(user_dir)
                })

        @routes.post("/users")
        async def post_users(request):
            body = await request.json()
            username = body["username"]
            if username in self.users.values():
                return web.json_response({"error": "Duplicate username."}, status=400)

            try:
                user_id = self.add_user(username)
            except ValueError as e:
                return web.json_response({"error": str(e)}, status=400)
            return web.json_response(user_id)

        @routes.get("/userdata")
        async def listuserdata(request):
            """
            List user data files in a specified directory.

            This endpoint allows listing files in a user's data directory, with options for recursion,
            full file information, and path splitting.

            Query Parameters:
            - dir (required): The directory to list files from.
            - recurse (optional): If "true", recursively list files in subdirectories.
            - full_info (optional): If "true", return detailed file information (path, size, modified time).
            - split (optional): If "true", split file paths into components (only applies when full_info is false).

            Returns:
            - 400: If 'dir' parameter is missing.
            - 403: If the requested path is not allowed.
            - 404: If the requested directory does not exist.
            - 200: JSON response with the list of files or file information.

            The response format depends on the query parameters:
            - Default: List of relative file paths.
            - full_info=true: List of dictionaries with file details.
            - split=true (and full_info=false): List of lists, each containing path components.
            """
            directory = request.rel_url.query.get('dir', '')
            if not directory:
                return web.Response(status=400, text="Directory not provided")

            path = self.get_request_user_filepath(request, directory)
            if not path:
                return web.Response(status=403, text="Invalid directory")

            if not os.path.exists(path):
                return web.Response(status=404, text="Directory not found")

            recurse = request.rel_url.query.get('recurse', '').lower() == "true"
            full_info = request.rel_url.query.get('full_info', '').lower() == "true"
            split_path = request.rel_url.query.get('split', '').lower() == "true"

            # Use different patterns based on whether we're recursing or not
            if recurse:
                pattern = os.path.join(glob.escape(path), '**', '*')
            else:
                pattern = os.path.join(glob.escape(path), '*')

            def process_full_path(full_path: str) -> FileInfo | str | list[str]:
                if full_info:
                    return get_file_info(full_path, path)

                rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
                if split_path:
                    # NOTE: first element is the whole relative path, followed
                    # by its components.
                    return [rel_path] + rel_path.split('/')

                return rel_path

            results = [
                process_full_path(full_path)
                for full_path in glob.glob(pattern, recursive=recurse)
                if os.path.isfile(full_path)
            ]

            return web.json_response(results)

        @routes.get("/v2/userdata")
        async def list_userdata_v2(request):
            """
            List files and directories in a user's data directory.

            This endpoint provides a structured listing of contents within a specified
            subdirectory of the user's data storage.

            Query Parameters:
            - path (optional): The relative path within the user's data directory
              to list. Defaults to the root ('').

            Returns:
            - 400: If the requested path is invalid, outside the user's data directory, or is not a directory.
            - 404: If the requested path does not exist.
            - 403: If the user is invalid.
            - 500: If there is an error reading the directory contents.
            - 200: JSON response containing a list of file and directory objects.
              Each object includes:
              - name: The name of the file or directory.
              - type: 'file' or 'directory'.
              - path: The relative path from the user's data root.
              - size (for files): The size in bytes.
              - modified (for files): The last modified timestamp (Unix epoch).
            """
            requested_rel_path = request.rel_url.query.get('path', '')

            # URL-decode the path parameter
            try:
                requested_rel_path = parse.unquote(requested_rel_path)
            except Exception as e:
                logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
                return web.Response(status=400, text="Invalid characters in path parameter")


            # Check user validity and get the absolute path for the requested directory
            try:
                base_user_path = self.get_request_user_filepath(request, None, create_dir=False)

                if requested_rel_path:
                    target_abs_path = self.get_request_user_filepath(request, requested_rel_path, create_dir=False)
                else:
                    target_abs_path = base_user_path

            except KeyError as e:
                # Invalid user detected by get_request_user_id inside get_request_user_filepath
                logging.warning(f"Access denied for user: {e}")
                return web.Response(status=403, text="Invalid user specified in request")


            if not target_abs_path:
                # Path traversal or other issue detected by get_request_user_filepath
                return web.Response(status=400, text="Invalid path requested")

            # Handle cases where the user directory or target path doesn't exist
            if not os.path.exists(target_abs_path):
                # Check if it's the base user directory that's missing (new user case)
                if target_abs_path == base_user_path:
                    # It's okay if the base user directory doesn't exist yet, return empty list
                    return web.json_response([])
                else:
                    # A specific subdirectory was requested but doesn't exist
                    return web.Response(status=404, text="Requested path not found")

            if not os.path.isdir(target_abs_path):
                return web.Response(status=400, text="Requested path is not a directory")

            results = []
            try:
                for root, dirs, files in os.walk(target_abs_path, topdown=True):
                    # Process directories
                    for dir_name in dirs:
                        dir_path = os.path.join(root, dir_name)
                        rel_path = os.path.relpath(dir_path, base_user_path).replace(os.sep, '/')
                        results.append({
                            "name": dir_name,
                            "path": rel_path,
                            "type": "directory"
                        })

                    # Process files
                    for file_name in files:
                        file_path = os.path.join(root, file_name)
                        rel_path = os.path.relpath(file_path, base_user_path).replace(os.sep, '/')
                        entry_info = {
                            "name": file_name,
                            "path": rel_path,
                            "type": "file"
                        }
                        try:
                            stats = os.stat(file_path) # Use os.stat for potentially better performance with os.walk
                            entry_info["size"] = stats.st_size
                            entry_info["modified"] = stats.st_mtime
                        except OSError as stat_error:
                            logging.warning(f"Could not stat file {file_path}: {stat_error}")
                            pass # Include file with available info
                        results.append(entry_info)
            except OSError as e:
                logging.error(f"Error listing directory {target_abs_path}: {e}")
                return web.Response(status=500, text="Error reading directory contents")

            # Sort results alphabetically, directories first then files
            results.sort(key=lambda x: (x['type'] != 'directory', x['name'].lower()))

            return web.json_response(results)

        def get_user_data_path(request, check_exists = False, param = "file"):
            # Shared helper for the {file} routes: returns the resolved path,
            # or an aiohttp Response (400/403/404) the caller must return as-is.
            file = request.match_info.get(param, None)
            if not file:
                return web.Response(status=400)

            path = self.get_request_user_filepath(request, file)
            if not path:
                return web.Response(status=403)

            if check_exists and not os.path.exists(path):
                return web.Response(status=404)

            return path

        @routes.get("/userdata/{file}")
        async def getuserdata(request):
            path = get_user_data_path(request, check_exists=True)
            if not isinstance(path, str):
                return path

            return web.FileResponse(path)

        @routes.post("/userdata/{file}")
        async def post_userdata(request):
            """
            Upload or update a user data file.

            This endpoint handles file uploads to a user's data directory, with options for
            controlling overwrite behavior and response format.

            Query Parameters:
            - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
            - full_info (optional): If "true", returns detailed file information (path, size, modified time).
              If "false", returns only the relative file path.

            Path Parameters:
            - file: The target file path (URL encoded if necessary).

            Returns:
            - 400: If 'file' parameter is missing.
            - 403: If the requested path is not allowed.
            - 409: If overwrite=false and the file already exists.
            - 200: JSON response with either:
              - Full file information (if full_info=true)
              - Relative file path (if full_info=false)

            The request body should contain the raw file content to be written.
            """
            path = get_user_data_path(request)
            if not isinstance(path, str):
                return path

            overwrite = request.query.get("overwrite", 'true') != "false"
            full_info = request.query.get('full_info', 'false').lower() == "true"

            if not overwrite and os.path.exists(path):
                return web.Response(status=409, text="File already exists")

            try:
                body = await request.read()

                with open(path, "wb") as f:
                    f.write(body)
            except OSError as e:
                logging.warning(f"Error saving file '{path}': {e}")
                return web.Response(
                    status=400,
                    reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
                )

            user_path = self.get_request_user_filepath(request, None)
            if full_info:
                resp = get_file_info(path, user_path)
            else:
                resp = os.path.relpath(path, user_path)

            return web.json_response(resp)

        @routes.delete("/userdata/{file}")
        async def delete_userdata(request):
            path = get_user_data_path(request, check_exists=True)
            if not isinstance(path, str):
                return path

            os.remove(path)

            return web.Response(status=204)

        @routes.post("/userdata/{file}/move/{dest}")
        async def move_userdata(request):
            """
            Move or rename a user data file.

            This endpoint handles moving or renaming files within a user's data directory, with options for
            controlling overwrite behavior and response format.

            Path Parameters:
            - file: The source file path (URL encoded if necessary)
            - dest: The destination file path (URL encoded if necessary)

            Query Parameters:
            - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
            - full_info (optional): If "true", returns detailed file information (path, size, modified time).
              If "false", returns only the relative file path.

            Returns:
            - 400: If either 'file' or 'dest' parameter is missing
            - 403: If either requested path is not allowed
            - 404: If the source file does not exist
            - 409: If overwrite=false and the destination file already exists
            - 200: JSON response with either:
              - Full file information (if full_info=true)
              - Relative file path (if full_info=false)
            """
            source = get_user_data_path(request, check_exists=True)
            if not isinstance(source, str):
                return source

            dest = get_user_data_path(request, check_exists=False, param="dest")
            if not isinstance(dest, str):
                return dest

            overwrite = request.query.get("overwrite", 'true') != "false"
            full_info = request.query.get('full_info', 'false').lower() == "true"

            if not overwrite and os.path.exists(dest):
                return web.Response(status=409, text="File already exists")

            logging.info(f"moving '{source}' -> '{dest}'")
            shutil.move(source, dest)

            user_path = self.get_request_user_filepath(request, None)
            if full_info:
                resp = get_file_info(dest, user_path)
            else:
                resp = os.path.relpath(dest, user_path)

            return web.json_response(resp)
blueprints/put_blueprints_here ADDED
File without changes
comfy/audio_encoders/audio_encoders.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .wav2vec2 import Wav2Vec2Model
2
+ from .whisper import WhisperLargeV3
3
+ import comfy.model_management
4
+ import comfy.ops
5
+ import comfy.utils
6
+ import logging
7
+ import torchaudio
8
+
9
+
10
class AudioEncoderModel():
    """Runtime wrapper around an audio encoder backbone (wav2vec2 or whisper3).

    Handles device placement/offload through a ModelPatcher and resamples
    incoming audio to the encoder's native sample rate before encoding.
    """

    def __init__(self, config):
        # Audio encoders reuse the text-encoder device/dtype policy.
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
        # "model_type" selects the backbone; the remaining keys are forwarded
        # to the backbone constructor as-is.
        model_type = config.pop("model_type")
        model_config = dict(config)
        model_config.update({
            "dtype": self.dtype,
            "device": offload_device,
            "operations": comfy.ops.manual_cast
        })

        if model_type == "wav2vec2":
            self.model = Wav2Vec2Model(**model_config)
        elif model_type == "whisper3":
            self.model = WhisperLargeV3(**model_config)
        # NOTE(review): an unrecognized model_type leaves self.model unset and
        # fails on the next line with AttributeError — presumably callers
        # (load_audio_encoder_from_sd) only pass validated types; verify.
        self.model.eval()
        # NOTE(review): comfy.model_patcher is not imported in this module's
        # visible imports — presumably reachable as an attribute of the comfy
        # package via a transitive import; confirm.
        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        # Both supported backbones operate on 16 kHz audio.
        self.model_sample_rate = 16000

    def load_sd(self, sd):
        """Load a state dict into the backbone; returns (missing, unexpected)."""
        # strict=False: callers log missing/unexpected keys themselves.
        return self.model.load_state_dict(sd, strict=False)

    def get_sd(self):
        """Return the backbone's state dict."""
        return self.model.state_dict()

    def encode_audio(self, audio, sample_rate):
        """Resample `audio` to the model rate and run the encoder.

        Returns a dict with the final hidden states ("encoded_audio"), every
        intermediate layer's states ("encoded_audio_all_layers"), and the
        number of audio samples after resampling ("audio_samples").
        """
        comfy.model_management.load_model_gpu(self.patcher)
        audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
        out, all_layers = self.model(audio.to(self.load_device))
        outputs = {}
        outputs["encoded_audio"] = out
        outputs["encoded_audio_all_layers"] = all_layers
        # assumes audio is (batch, channels, samples) — TODO confirm
        outputs["audio_samples"] = audio.shape[2]
        return outputs
46
+
47
+
48
def load_audio_encoder_from_sd(sd, prefix=""):
    """Build an AudioEncoderModel from a raw state dict, auto-detecting the
    architecture from characteristic keys.

    Args:
        sd: encoder state dict, possibly prefixed with "wav2vec2." or "model.".
        prefix: currently unused; kept for interface parity with other loaders.

    Returns:
        An AudioEncoderModel with the weights loaded (non-strict).

    Raises:
        RuntimeError: if the architecture cannot be identified, or the
            wav2vec2 embedding width is unsupported.
    """
    sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
    if "encoder.layer_norm.bias" in sd: #wav2vec2
        # The embedding width distinguishes the large (1024) and base (768)
        # wav2vec2 variants; each has a fixed hyperparameter set.
        embed_dim = sd["encoder.layer_norm.bias"].shape[0]
        if embed_dim == 1024:# large
            config = {
                "model_type": "wav2vec2",
                "embed_dim": 1024,
                "num_heads": 16,
                "num_layers": 24,
                "conv_norm": True,
                "conv_bias": True,
                "do_normalize": True,
                "do_stable_layer_norm": True
            }
        elif embed_dim == 768: # base
            config = {
                "model_type": "wav2vec2",
                "embed_dim": 768,
                "num_heads": 12,
                "num_layers": 12,
                "conv_norm": False,
                "conv_bias": False,
                "do_normalize": False, # chinese-wav2vec2-base has this False
                "do_stable_layer_norm": False
            }
        else:
            raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
    elif "model.encoder.embed_positions.weight" in sd:
        # Whisper checkpoints nest everything under "model."; strip it so the
        # keys line up with WhisperLargeV3's module names.
        sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
        config = {
            "model_type": "whisper3",
        }
    else:
        raise RuntimeError("ERROR: audio encoder not supported.")

    audio_encoder = AudioEncoderModel(config)
    m, u = audio_encoder.load_sd(sd)
    # Non-strict load: surface any key mismatch in the log rather than failing.
    if len(m) > 0:
        logging.warning("missing audio encoder: {}".format(m))
    if len(u) > 0:
        logging.warning("unexpected audio encoder: {}".format(u))

    return audio_encoder
comfy/audio_encoders/wav2vec2.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from comfy.ldm.modules.attention import optimized_attention_masked
4
+
5
+
6
+ class LayerNormConv(nn.Module):
7
+ def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
8
+ super().__init__()
9
+ self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
10
+ self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)
11
+
12
+ def forward(self, x):
13
+ x = self.conv(x)
14
+ return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
15
+
16
+ class LayerGroupNormConv(nn.Module):
17
+ def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
18
+ super().__init__()
19
+ self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
20
+ self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
21
+
22
+ def forward(self, x):
23
+ x = self.conv(x)
24
+ return torch.nn.functional.gelu(self.layer_norm(x))
25
+
26
+ class ConvNoNorm(nn.Module):
27
+ def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
28
+ super().__init__()
29
+ self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
30
+
31
+ def forward(self, x):
32
+ x = self.conv(x)
33
+ return torch.nn.functional.gelu(x)
34
+
35
+
36
class ConvFeatureEncoder(nn.Module):
    """Stack of 7 temporal convolutions turning a raw waveform into frame features.

    Kernel/stride schedule is the standard wav2vec2 one: (10,5) then four
    (3,2) stages then two (2,2) stages, for a total downsampling of 320x.
    """

    def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
        super().__init__()
        specs = [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]
        common = dict(device=device, dtype=dtype, operations=operations)
        layers = []
        for idx, (kernel, stride) in enumerate(specs):
            in_ch = 1 if idx == 0 else conv_dim
            if conv_norm:
                # LayerNorm variant; the first conv always carries a bias.
                layers.append(LayerNormConv(in_ch, conv_dim, kernel_size=kernel, stride=stride,
                                            bias=True if idx == 0 else conv_bias, **common))
            elif idx == 0:
                # GroupNorm only on the first conv, as in the base model.
                layers.append(LayerGroupNormConv(in_ch, conv_dim, kernel_size=kernel, stride=stride,
                                                 bias=conv_bias, **common))
            else:
                layers.append(ConvNoNorm(conv_dim, conv_dim, kernel_size=kernel, stride=stride,
                                         bias=conv_bias, **common))
        self.conv_layers = nn.ModuleList(layers)

    def forward(self, x):
        feats = x.unsqueeze(1)
        for conv in self.conv_layers:
            feats = conv(feats)
        # (B, C, T) -> (B, T, C) for the transformer that follows.
        return feats.transpose(1, 2)
67
+
68
+
69
+ class FeatureProjection(nn.Module):
70
+ def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
71
+ super().__init__()
72
+ self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
73
+ self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)
74
+
75
+ def forward(self, x):
76
+ x = self.layer_norm(x)
77
+ x = self.projection(x)
78
+ return x
79
+
80
+
81
+ class PositionalConvEmbedding(nn.Module):
82
+ def __init__(self, embed_dim=768, kernel_size=128, groups=16):
83
+ super().__init__()
84
+ self.conv = nn.Conv1d(
85
+ embed_dim,
86
+ embed_dim,
87
+ kernel_size=kernel_size,
88
+ padding=kernel_size // 2,
89
+ groups=groups,
90
+ )
91
+ self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
92
+ self.activation = nn.GELU()
93
+
94
+ def forward(self, x):
95
+ x = x.transpose(1, 2)
96
+ x = self.conv(x)[:, :, :-1]
97
+ x = self.activation(x)
98
+ x = x.transpose(1, 2)
99
+ return x
100
+
101
+
102
class TransformerEncoder(nn.Module):
    """Stack of wav2vec2 transformer layers with convolutional positional embeddings.

    Returns both the final hidden state and a tuple containing every layer's
    input plus the final output, so consumers can pick intermediate features.
    """

    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        num_layers=12,
        mlp_ratio=4.0,
        do_stable_layer_norm=True,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                embed_dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                do_stable_layer_norm=do_stable_layer_norm,
                device=device, dtype=dtype, operations=operations
            )
            for _ in range(num_layers)
        ])

        self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
        self.do_stable_layer_norm = do_stable_layer_norm

    def forward(self, x, mask=None):
        # Positional information is added once, before the layer stack.
        x = x + self.pos_conv_embed(x)
        all_x = ()
        # Stable/pre-norm (large) applies the final LayerNorm after the stack;
        # the post-norm (base) variant normalizes once up front instead.
        if not self.do_stable_layer_norm:
            x = self.layer_norm(x)
        for layer in self.layers:
            # Record each layer's input (pre-layer hidden states).
            all_x += (x,)
            x = layer(x, mask)
        if self.do_stable_layer_norm:
            x = self.layer_norm(x)
        # Final (possibly normalized) output is appended last.
        all_x += (x,)
        return x, all_x
141
+
142
+
143
class Attention(nn.Module):
    """Multi-head self-attention with separate q/k/v projections (wav2vec2)."""

    def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)

    def forward(self, x, mask=None):
        # Masked attention is not implemented for this encoder path.
        assert (mask is None)  # TODO?
        queries = self.q_proj(x)
        keys = self.k_proj(x)
        values = self.v_proj(x)
        return self.out_proj(optimized_attention_masked(queries, keys, values, self.num_heads))
163
+
164
+
165
+ class FeedForward(nn.Module):
166
+ def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
167
+ super().__init__()
168
+ self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
169
+ self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)
170
+
171
+ def forward(self, x):
172
+ x = self.intermediate_dense(x)
173
+ x = torch.nn.functional.gelu(x)
174
+ x = self.output_dense(x)
175
+ return x
176
+
177
+
178
class TransformerEncoderLayer(nn.Module):
    """Single wav2vec2 transformer layer.

    Supports both norm layouts found in wav2vec2 checkpoints:
    - stable/pre-norm (large): LayerNorm before attention and before the MLP;
    - post-norm (base): LayerNorm after each residual addition.
    """

    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        mlp_ratio=4.0,
        do_stable_layer_norm=True,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)

        self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
        self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
        self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
        self.do_stable_layer_norm = do_stable_layer_norm

    def forward(self, x, mask=None):
        residual = x
        if self.do_stable_layer_norm:
            # Pre-norm: normalize the attention input.
            x = self.layer_norm(x)
        x = self.attention(x, mask=mask)
        x = residual + x
        if not self.do_stable_layer_norm:
            # Post-norm: normalize after the attention residual, and again
            # after the MLP residual.
            x = self.layer_norm(x)
            return self.final_layer_norm(x + self.feed_forward(x))
        else:
            # Pre-norm: final_layer_norm feeds the MLP instead of closing it.
            return x + self.feed_forward(self.final_layer_norm(x))
207
+
208
+
209
class Wav2Vec2Model(nn.Module):
    """Complete Wav2Vec 2.0 model: conv feature extractor + transformer encoder."""

    def __init__(
        self,
        embed_dim=1024,
        final_dim=256,
        num_heads=16,
        num_layers=24,
        conv_norm=True,
        conv_bias=True,
        do_normalize=True,
        do_stable_layer_norm=True,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        conv_dim = 512  # fixed width of the conv feature extractor
        self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
        self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)

        # Present so checkpoints load cleanly; not used by this forward pass
        # (no time masking is applied at inference).
        self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
        self.do_normalize = do_normalize

        self.encoder = TransformerEncoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            do_stable_layer_norm=do_stable_layer_norm,
            device=device, dtype=dtype, operations=operations
        )

    def forward(self, x, mask_time_indices=None, return_dict=False):
        # assumes x is (batch, channels, samples) — TODO confirm; channels are
        # averaged down to mono here.
        x = torch.mean(x, dim=1)

        if self.do_normalize:
            # NOTE(review): mean/var are computed over the entire tensor, not
            # per sample — confirm this matches the intended preprocessing.
            x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)

        features = self.feature_extractor(x)
        features = self.feature_projection(features)
        batch_size, seq_len, _ = features.shape

        # Returns (final hidden states, tuple of all per-layer states).
        x, all_x = self.encoder(features)
        return x, all_x
comfy/audio_encoders/whisper.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchaudio
5
+ from typing import Optional
6
+ from comfy.ldm.modules.attention import optimized_attention_masked
7
+ import comfy.ops
8
+
9
class WhisperFeatureExtractor(nn.Module):
    """Compute Whisper log-mel spectrogram features from raw audio.

    Mirrors the reference Whisper preprocessing: mono downmix, pad/trim to a
    fixed 30 s window (480000 samples @ 16 kHz), slaney-scaled mel filterbank,
    log10 with dynamic-range clamping, then rescale.
    """

    def __init__(self, n_mels=128, device=None):
        super().__init__()
        self.sample_rate = 16000
        self.n_fft = 400
        self.hop_length = 160
        self.n_mels = n_mels
        self.chunk_length = 30     # seconds per window
        self.n_samples = 480000    # chunk_length * sample_rate

        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            f_min=0,
            f_max=8000,
            norm="slaney",
            mel_scale="slaney",
        ).to(device)

    def __call__(self, audio):
        # assumes audio is (batch, channels, samples) — TODO confirm
        audio = torch.mean(audio, dim=1)
        batch_size = audio.shape[0]
        processed_audio = []

        # Pad or truncate every sample to exactly 30 s, as Whisper expects.
        for i in range(batch_size):
            aud = audio[i]
            if aud.shape[0] > self.n_samples:
                aud = aud[:self.n_samples]
            elif aud.shape[0] < self.n_samples:
                aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
            processed_audio.append(aud)

        audio = torch.stack(processed_audio)

        # Run the transform on whichever device it was constructed on (its
        # STFT window buffer tells us where that is), then move the result
        # back; the trailing STFT frame is dropped to get exactly 3000 frames.
        mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)

        # log10 with an epsilon floor, clamp to an 8-decade dynamic range
        # below the peak, then rescale — same constants as the reference.
        log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
        log_mel_spec = (log_mel_spec + 4.0) / 4.0

        return log_mel_spec
+
53
+
54
class MultiHeadAttention(nn.Module):
    """Whisper-style multi-head attention: q/v projections carry a bias, k does not."""

    def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
        self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
        self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
        self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        attended = optimized_attention_masked(
            self.q_proj(query),
            self.k_proj(key),
            self.v_proj(value),
            self.n_heads,
            mask,
        )
        return self.out_proj(attended)
+
86
+
87
class EncoderLayer(nn.Module):
    """Pre-norm transformer encoder block: self-attention then GELU MLP, both residual."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
        super().__init__()

        self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
        self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)

        self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
        self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
        self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Attention sub-block (pre-norm residual).
        normed = self.self_attn_layer_norm(x)
        x = x + self.self_attn(normed, normed, normed, attention_mask)

        # Feed-forward sub-block (pre-norm residual).
        return x + self.fc2(F.gelu(self.fc1(self.final_layer_norm(x))))
+
117
+
118
class AudioEncoder(nn.Module):
    """Whisper audio encoder: two downsampling convs, learned positional
    embeddings, then a stack of pre-norm transformer layers.

    Returns (final hidden states, tuple containing each layer's input plus
    the final normalized output).
    """

    def __init__(
        self,
        n_mels: int = 128,
        n_ctx: int = 1500,
        n_state: int = 1280,
        n_head: int = 20,
        n_layer: int = 32,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()

        # conv2 has stride 2, halving the mel frame count (3000 -> 1500 = n_ctx).
        self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
        self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)

        # Weight shape (n_ctx, n_state): one learned embedding per output frame.
        self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)

        self.layers = nn.ModuleList([
            EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
            for _ in range(n_layer)
        ])

        self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))

        # (B, n_state, T) -> (B, T, n_state)
        x = x.transpose(1, 2)

        # Fix: slice the positional table along the position axis (dim 0), as
        # in the reference Whisper encoder. The previous `weight[:, :x.shape[1]]`
        # sliced the feature axis instead, which only behaved correctly because
        # the feature extractor always pads inputs to exactly n_ctx frames
        # (making the slice a no-op); this form is correct for any T <= n_ctx.
        x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:x.shape[1]], x)

        all_x = ()
        for layer in self.layers:
            all_x += (x,)  # record each layer's input
            x = layer(x)

        x = self.layer_norm(x)
        all_x += (x,)
        return x, all_x
160
+
161
+
162
class WhisperLargeV3(nn.Module):
    """Whisper large-v3 encoder: log-mel feature extraction + transformer encoder."""

    def __init__(
        self,
        n_mels: int = 128,
        n_audio_ctx: int = 1500,
        n_audio_state: int = 1280,
        n_audio_head: int = 20,
        n_audio_layer: int = 32,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()

        self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)

        self.encoder = AudioEncoder(
            n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
            dtype=dtype, device=device, operations=operations
        )

    def forward(self, audio):
        # Returns (final hidden states, tuple of all per-layer states).
        mel = self.feature_extractor(audio)
        return self.encoder(mel)
comfy/checkpoint_pickle.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pickle

# Plain pickle.load re-exported for callers that don't need filtering.
load = pickle.load


class Empty:
    # Placeholder substituted for classes we refuse to resolve while
    # unpickling checkpoints.
    pass


class Unpickler(pickle.Unpickler):
    """Unpickler that stubs out pytorch_lightning classes with Empty."""

    def find_class(self, module, name):
        #TODO: safe unpickle
        # Lightning checkpoints reference trainer/callback classes that are
        # neither installed nor needed to recover the weights.
        if module.startswith("pytorch_lightning"):
            return Empty
        return super().find_class(module, name)
comfy/cldm/cldm.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #taken from: https://github.com/lllyasviel/ControlNet
2
+ #and modified
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from ..ldm.modules.diffusionmodules.util import (
8
+ timestep_embedding,
9
+ )
10
+
11
+ from ..ldm.modules.attention import SpatialTransformer
12
+ from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
13
+ from ..ldm.util import exists
14
+ from .control_types import UNION_CONTROLNET_TYPES
15
+ from collections import OrderedDict
16
+ import comfy.ops
17
+ from comfy.ldm.modules.attention import optimized_attention
18
+
19
class OptimizedAttention(nn.Module):
    """Self-attention with a fused qkv projection, using comfy's optimized kernel."""

    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead
        self.c = c

        self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
        self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

    def forward(self, x):
        # One projection produces q, k and v; split along the channel dim.
        qkv = self.in_proj(x)
        q, k, v = qkv.split(self.c, dim=2)
        return self.out_proj(optimized_attention(q, k, v, self.heads))
33
+
34
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation used by CLIP: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        return torch.sigmoid(1.702 * x) * x
37
+
38
class ResBlockUnionControlnet(nn.Module):
    """CLIP-style residual transformer block used by the union-controlnet head."""

    def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
        self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
        # Key names (c_fc/gelu/c_proj) match the CLIP checkpoint layout.
        mlp_layers = OrderedDict([
            ("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)),
            ("gelu", QuickGELU()),
            ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device)),
        ])
        self.mlp = nn.Sequential(mlp_layers)
        self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)

    def attention(self, x: torch.Tensor):
        return self.attn(x)

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
55
+
56
class ControlledUnetModel(UNetModel):
    # Kept for config/checkpoint compatibility: the control-injection logic
    # lives in the base ldm UNetModel, so no overrides are needed here.
    #implemented in the ldm unet
    pass
59
+
60
+ class ControlNet(nn.Module):
61
+ def __init__(
62
+ self,
63
+ image_size,
64
+ in_channels,
65
+ model_channels,
66
+ hint_channels,
67
+ num_res_blocks,
68
+ dropout=0,
69
+ channel_mult=(1, 2, 4, 8),
70
+ conv_resample=True,
71
+ dims=2,
72
+ num_classes=None,
73
+ use_checkpoint=False,
74
+ dtype=torch.float32,
75
+ num_heads=-1,
76
+ num_head_channels=-1,
77
+ num_heads_upsample=-1,
78
+ use_scale_shift_norm=False,
79
+ resblock_updown=False,
80
+ use_new_attention_order=False,
81
+ use_spatial_transformer=False, # custom transformer support
82
+ transformer_depth=1, # custom transformer support
83
+ context_dim=None, # custom transformer support
84
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
85
+ legacy=True,
86
+ disable_self_attentions=None,
87
+ num_attention_blocks=None,
88
+ disable_middle_self_attn=False,
89
+ use_linear_in_transformer=False,
90
+ adm_in_channels=None,
91
+ transformer_depth_middle=None,
92
+ transformer_depth_output=None,
93
+ attn_precision=None,
94
+ union_controlnet_num_control_type=None,
95
+ device=None,
96
+ operations=comfy.ops.disable_weight_init,
97
+ **kwargs,
98
+ ):
99
+ super().__init__()
100
+ assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
101
+ if use_spatial_transformer:
102
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
103
+
104
+ if context_dim is not None:
105
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
106
+ # from omegaconf.listconfig import ListConfig
107
+ # if type(context_dim) == ListConfig:
108
+ # context_dim = list(context_dim)
109
+
110
+ if num_heads_upsample == -1:
111
+ num_heads_upsample = num_heads
112
+
113
+ if num_heads == -1:
114
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
115
+
116
+ if num_head_channels == -1:
117
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
118
+
119
+ self.dims = dims
120
+ self.image_size = image_size
121
+ self.in_channels = in_channels
122
+ self.model_channels = model_channels
123
+
124
+ if isinstance(num_res_blocks, int):
125
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
126
+ else:
127
+ if len(num_res_blocks) != len(channel_mult):
128
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
129
+ "as a list/tuple (per-level) with the same length as channel_mult")
130
+ self.num_res_blocks = num_res_blocks
131
+
132
+ if disable_self_attentions is not None:
133
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
134
+ assert len(disable_self_attentions) == len(channel_mult)
135
+ if num_attention_blocks is not None:
136
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
137
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
138
+
139
+ transformer_depth = transformer_depth[:]
140
+
141
+ self.dropout = dropout
142
+ self.channel_mult = channel_mult
143
+ self.conv_resample = conv_resample
144
+ self.num_classes = num_classes
145
+ self.use_checkpoint = use_checkpoint
146
+ self.dtype = dtype
147
+ self.num_heads = num_heads
148
+ self.num_head_channels = num_head_channels
149
+ self.num_heads_upsample = num_heads_upsample
150
+ self.predict_codebook_ids = n_embed is not None
151
+
152
+ time_embed_dim = model_channels * 4
153
+ self.time_embed = nn.Sequential(
154
+ operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
155
+ nn.SiLU(),
156
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
157
+ )
158
+
159
+ if self.num_classes is not None:
160
+ if isinstance(self.num_classes, int):
161
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
162
+ elif self.num_classes == "continuous":
163
+ self.label_emb = nn.Linear(1, time_embed_dim)
164
+ elif self.num_classes == "sequential":
165
+ assert adm_in_channels is not None
166
+ self.label_emb = nn.Sequential(
167
+ nn.Sequential(
168
+ operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
169
+ nn.SiLU(),
170
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
171
+ )
172
+ )
173
+ else:
174
+ raise ValueError()
175
+
176
+ self.input_blocks = nn.ModuleList(
177
+ [
178
+ TimestepEmbedSequential(
179
+ operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
180
+ )
181
+ ]
182
+ )
183
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
184
+
185
+ self.input_hint_block = TimestepEmbedSequential(
186
+ operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
187
+ nn.SiLU(),
188
+ operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
189
+ nn.SiLU(),
190
+ operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
191
+ nn.SiLU(),
192
+ operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
193
+ nn.SiLU(),
194
+ operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
195
+ nn.SiLU(),
196
+ operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
197
+ nn.SiLU(),
198
+ operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
199
+ nn.SiLU(),
200
+ operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
201
+ )
202
+
203
+ self._feature_size = model_channels
204
+ input_block_chans = [model_channels]
205
+ ch = model_channels
206
+ ds = 1
207
+ for level, mult in enumerate(channel_mult):
208
+ for nr in range(self.num_res_blocks[level]):
209
+ layers = [
210
+ ResBlock(
211
+ ch,
212
+ time_embed_dim,
213
+ dropout,
214
+ out_channels=mult * model_channels,
215
+ dims=dims,
216
+ use_checkpoint=use_checkpoint,
217
+ use_scale_shift_norm=use_scale_shift_norm,
218
+ dtype=self.dtype,
219
+ device=device,
220
+ operations=operations,
221
+ )
222
+ ]
223
+ ch = mult * model_channels
224
+ num_transformers = transformer_depth.pop(0)
225
+ if num_transformers > 0:
226
+ if num_head_channels == -1:
227
+ dim_head = ch // num_heads
228
+ else:
229
+ num_heads = ch // num_head_channels
230
+ dim_head = num_head_channels
231
+ if legacy:
232
+ #num_heads = 1
233
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
234
+ if exists(disable_self_attentions):
235
+ disabled_sa = disable_self_attentions[level]
236
+ else:
237
+ disabled_sa = False
238
+
239
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
240
+ layers.append(
241
+ SpatialTransformer(
242
+ ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
243
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
244
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
245
+ )
246
+ )
247
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
248
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
249
+ self._feature_size += ch
250
+ input_block_chans.append(ch)
251
+ if level != len(channel_mult) - 1:
252
+ out_ch = ch
253
+ self.input_blocks.append(
254
+ TimestepEmbedSequential(
255
+ ResBlock(
256
+ ch,
257
+ time_embed_dim,
258
+ dropout,
259
+ out_channels=out_ch,
260
+ dims=dims,
261
+ use_checkpoint=use_checkpoint,
262
+ use_scale_shift_norm=use_scale_shift_norm,
263
+ down=True,
264
+ dtype=self.dtype,
265
+ device=device,
266
+ operations=operations
267
+ )
268
+ if resblock_updown
269
+ else Downsample(
270
+ ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
271
+ )
272
+ )
273
+ )
274
+ ch = out_ch
275
+ input_block_chans.append(ch)
276
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
277
+ ds *= 2
278
+ self._feature_size += ch
279
+
280
+ if num_head_channels == -1:
281
+ dim_head = ch // num_heads
282
+ else:
283
+ num_heads = ch // num_head_channels
284
+ dim_head = num_head_channels
285
+ if legacy:
286
+ #num_heads = 1
287
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
288
+ mid_block = [
289
+ ResBlock(
290
+ ch,
291
+ time_embed_dim,
292
+ dropout,
293
+ dims=dims,
294
+ use_checkpoint=use_checkpoint,
295
+ use_scale_shift_norm=use_scale_shift_norm,
296
+ dtype=self.dtype,
297
+ device=device,
298
+ operations=operations
299
+ )]
300
+ if transformer_depth_middle >= 0:
301
+ mid_block += [SpatialTransformer( # always uses a self-attn
302
+ ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
303
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
304
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
305
+ ),
306
+ ResBlock(
307
+ ch,
308
+ time_embed_dim,
309
+ dropout,
310
+ dims=dims,
311
+ use_checkpoint=use_checkpoint,
312
+ use_scale_shift_norm=use_scale_shift_norm,
313
+ dtype=self.dtype,
314
+ device=device,
315
+ operations=operations
316
+ )]
317
+ self.middle_block = TimestepEmbedSequential(*mid_block)
318
+ self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
319
+ self._feature_size += ch
320
+
321
+ if union_controlnet_num_control_type is not None:
322
+ self.num_control_type = union_controlnet_num_control_type
323
+ num_trans_channel = 320
324
+ num_trans_head = 8
325
+ num_trans_layer = 1
326
+ num_proj_channel = 320
327
+ # task_scale_factor = num_trans_channel ** 0.5
328
+ self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
329
+
330
+ self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
331
+ self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
332
+ #-----------------------------------------------------------------------------------------------------
333
+
334
+ control_add_embed_dim = 256
335
+ class ControlAddEmbedding(nn.Module):
336
+ def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
337
+ super().__init__()
338
+ self.num_control_type = num_control_type
339
+ self.in_dim = in_dim
340
+ self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
341
+ self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
342
+ def forward(self, control_type, dtype, device):
343
+ c_type = torch.zeros((self.num_control_type,), device=device)
344
+ c_type[control_type] = 1.0
345
+ c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
346
+ return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
347
+
348
+ self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
349
+ else:
350
+ self.task_embedding = None
351
+ self.control_add_embedding = None
352
+
353
+ def union_controlnet_merge(self, hint, control_type, emb, context):
354
+ # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
355
+ inputs = []
356
+ condition_list = []
357
+
358
+ for idx in range(min(1, len(control_type))):
359
+ controlnet_cond = self.input_hint_block(hint[idx], emb, context)
360
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
361
+ if idx < len(control_type):
362
+ feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)
363
+
364
+ inputs.append(feat_seq.unsqueeze(1))
365
+ condition_list.append(controlnet_cond)
366
+
367
+ x = torch.cat(inputs, dim=1)
368
+ x = self.transformer_layes(x)
369
+ controlnet_cond_fuser = None
370
+ for idx in range(len(control_type)):
371
+ alpha = self.spatial_ch_projs(x[:, idx])
372
+ alpha = alpha.unsqueeze(-1).unsqueeze(-1)
373
+ o = condition_list[idx] + alpha
374
+ if controlnet_cond_fuser is None:
375
+ controlnet_cond_fuser = o
376
+ else:
377
+ controlnet_cond_fuser += o
378
+ return controlnet_cond_fuser
379
+
380
+ def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
381
+ return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
382
+
383
    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
        """Run the controlnet and collect residuals for each UNet level.

        Args:
            x: noisy latent input.
            hint: control image(s); for Union controlnets this may be a stack
                of hints (unsqueezed to 5 dims below).
            timesteps: diffusion timesteps used for the time embedding.
            context: cross-attention conditioning.
            y: class/ADM conditioning vector; required when ``num_classes`` is
                set (i.e. SDXL-style models).
            **kwargs: may carry ``control_type`` (list of int type ids) for
                Union controlnets.

        Returns:
            dict with "output" (one tensor per input block, from the zero
            convs) and "middle" (single middle-block tensor).
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        guided_hint = None
        if self.control_add_embedding is not None: # Union Controlnet: type-conditioned extra embedding
            control_type = kwargs.get("control_type", [])

            # Reject control-type ids the loaded model does not support.
            if any([c >= self.num_control_type for c in control_type]):
                max_type = max(control_type)
                max_type_name = {
                    v: k for k, v in UNION_CONTROLNET_TYPES.items()
                }[max_type]
                raise ValueError(
                    f"Control type {max_type_name}({max_type}) is out of range for the number of control types" +
                    f"({self.num_control_type}) supported.\n" +
                    "Please consider using the ProMax ControlNet Union model.\n" +
                    "https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main"
                )

            emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
            if len(control_type) > 0:
                # union_controlnet_merge expects a stacked (5-dim) hint tensor.
                if len(hint.shape) < 5:
                    hint = hint.unsqueeze(dim=0)
                guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)

        # Non-union path (or union with no control types): plain hint encoder.
        if guided_hint is None:
            guided_hint = self.input_hint_block(hint, emb, context)

        out_output = []
        out_middle = []

        if self.num_classes is not None:
            if y is None:
                raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?")
            emb = emb + self.label_emb(y)

        h = x
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                # The hint is added once, after the first input block.
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            out_output.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        out_middle.append(self.middle_block_out(h, emb, context))

        return {"middle": out_middle, "output": out_output}
434
+
comfy/cldm/control_types.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
# Mapping of ControlNet Union control-type names to the integer ids used by
# the xinsir controlnet-union models (see
# https://github.com/xinsir6/ControlNetPlus). Keys group interchangeable
# preprocessors that share one id.
UNION_CONTROLNET_TYPES = {
    "openpose": 0,
    "depth": 1,
    "hed/pidi/scribble/ted": 2,
    "canny/lineart/anime_lineart/mlsd": 3,
    "normal": 4,
    "segment": 5,
    "tile": 6,
    "repaint": 7,
}
comfy/cldm/dit_embedder.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+
8
+ from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
9
+
10
+
11
class ControlNetEmbedder(nn.Module):
    """ControlNet adapter for an MMDiT (SD3-style) diffusion backbone.

    Patch-embeds the noisy latent and a control hint, processes the sum with a
    stack of DismantledBlocks, and emits one linear projection of the hidden
    state per block, repeated so the outputs cover every double block of the
    main model.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        attention_head_dim: int,
        num_attention_heads: int,
        adm_in_channels: int,
        num_layers: int,
        main_model_double: int,
        double_y_emb: bool,
        device: torch.device,
        dtype: torch.dtype,
        pos_embed_max_size: Optional[int] = None,
        operations = None,
    ):
        super().__init__()
        self.main_model_double = main_model_double
        self.dtype = dtype
        self.hidden_size = num_attention_heads * attention_head_dim
        self.patch_size = patch_size
        # Patch embedding for the noisy latent input.
        self.x_embedder = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=pos_embed_max_size is None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

        self.double_y_emb = double_y_emb
        if self.double_y_emb:
            # Two-stage conditioning: project the ADM vector, then re-embed it.
            self.orig_y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )
            self.y_embedder = VectorEmbedder(
                self.hidden_size, self.hidden_size, dtype, device, operations=operations
            )
        else:
            self.y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )

        self.transformer_blocks = nn.ModuleList(
            DismantledBlock(
                hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        )

        # self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
        # TODO double check this logic when 8b
        self.use_y_embedder = True

        # One output projection per transformer block.
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            self.controlnet_blocks.append(controlnet_block)

        # Patch embedding for the control hint (lenient about image size).
        self.pos_embed_input = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint = None,
    ) -> dict:  # NOTE(review): annotation corrected — the body returns a dict, not Tuple[Tensor, List[Tensor]]
        """Return {"output": tuple of per-block control residuals}.

        Each transformer-block output is projected and repeated ``repeat``
        times so that len(output) covers the main model's double blocks.
        """
        x_shape = list(x.shape)
        x = self.x_embedder(x)
        if not self.double_y_emb:
            # Recompute the token grid from the input spatial size and add a
            # fixed sin/cos positional embedding.
            h = (x_shape[-2] + 1) // self.patch_size
            w = (x_shape[-1] + 1) // self.patch_size
            x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            if self.double_y_emb:
                y = self.orig_y_embedder(y)
            y = self.y_embedder(y)
            c = c + y

        # Inject the control hint as additional patch tokens.
        x = x + self.pos_embed_input(hint)

        block_out = ()

        # Repeat each block's output so the residual list spans all of the
        # main model's double blocks.
        repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
        for i in range(len(self.transformer_blocks)):
            out = self.transformer_blocks[i](x, c)
            if not self.double_y_emb:
                x = out
            block_out += (self.controlnet_blocks[i](out),) * repeat

        return {"output": block_out}
comfy/cldm/mmdit.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Optional
3
+ import comfy.ldm.modules.diffusionmodules.mmdit
4
+
5
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
    """SD3 ControlNet built on MMDiT.

    Reuses the MMDiT joint blocks (without the final layer), adds a patch
    embedder for the control hint and one linear "tap" per joint block whose
    outputs are the control residuals fed to the main model.
    """

    def __init__(
        self,
        num_blocks = None,
        control_latent_channels = None,
        dtype = None,
        device = None,
        operations = None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
        # controlnet_blocks: one linear projection per joint block.
        self.controlnet_blocks = torch.nn.ModuleList([])
        for _ in range(len(self.joint_blocks)):
            self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))

        # Default: the hint latent has the same channel count as the input.
        if control_latent_channels is None:
            control_latent_channels = self.in_channels

        self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
            None,
            self.patch_size,
            control_latent_channels,
            self.hidden_size,
            bias=True,
            strict_img_size=False,
            dtype=dtype,
            device=device,
            operations=operations
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint = None,
    ) -> dict:  # NOTE(review): annotation corrected — the body returns {"output": [...]}, not a Tensor
        """Return {"output": [...]} with one residual per main-model depth slot."""

        # weird sd3 controlnet specific stuff: y is intentionally zeroed out
        y = torch.zeros_like(y)

        if self.context_processor is not None:
            context = self.context_processor(context)

        hw = x.shape[-2:]
        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
        # Inject the control hint as additional patch tokens.
        x += self.pos_embed_input(hint)

        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            y = self.y_embedder(y)
            c = c + y

        if context is not None:
            context = self.context_embedder(context)

        output = []

        blocks = len(self.joint_blocks)
        for i in range(blocks):
            context, x = self.joint_blocks[i](
                context,
                x,
                c=c,
                use_checkpoint=self.use_checkpoint,
            )

            out = self.controlnet_blocks[i](x)
            # Repeat each tap so the residuals cover self.depth main-model
            # blocks; the last tap contributes one fewer copy.
            count = self.depth // blocks
            if i == blocks - 1:
                count -= 1
            for j in range(count):
                output.append(out)

        return {"output": output}
comfy/cli_args.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import enum
3
+ import os
4
+ import comfy.options
5
+
6
+
7
class EnumAction(argparse.Action):
    """Argparse action that converts a raw CLI value into an Enum member."""

    def __init__(self, **kwargs):
        # The "type" kwarg carries the Enum class itself; remove it before
        # delegating to argparse.Action, which expects a plain callable.
        enum_cls = kwargs.pop("type", None)
        if enum_cls is None:
            raise ValueError("type must be assigned an Enum when using EnumAction")
        if not issubclass(enum_cls, enum.Enum):
            raise TypeError("type must be an Enum when using EnumAction")

        # Offer each member's value as a valid CLI choice and show them in help.
        member_values = tuple(member.value for member in enum_cls)
        kwargs.setdefault("choices", member_values)
        kwargs.setdefault("metavar", f"[{','.join(list(member_values))}]")

        super().__init__(**kwargs)

        self._enum = enum_cls

    def __call__(self, parser, namespace, values, option_string=None):
        # Map the parsed string back to the corresponding Enum member.
        setattr(namespace, self.dest, self._enum(values))
34
+
35
+
36
# Global CLI parser for ComfyUI; parsed at import time (see bottom of file).
parser = argparse.ArgumentParser()

# --- Network / server options ---
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")

# --- Directory layout ---
parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")

# --- Device selection ---
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")


# --- Precision: global, unet, VAE, text encoder ---
fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")

fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")

fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")

parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")

fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")

parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")

# --- Alternative backends ---
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
93
+
94
class LatentPreviewMethod(enum.Enum):
    """How sampler nodes render latent previews."""

    NoPreviews = "none"
    Auto = "auto"
    Latent2RGB = "latent2rgb"
    TAESD = "taesd"

    @classmethod
    def from_string(cls, value: str):
        """Return the member whose value equals ``value``, or None if no match."""
        return next((member for member in cls if member.value == value), None)
106
+
107
+ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
108
+
109
+ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
110
+
111
+ cache_group = parser.add_mutually_exclusive_group()
112
+ cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
113
+ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
114
+ cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
115
+ cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
116
+
117
+ attn_group = parser.add_mutually_exclusive_group()
118
+ attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
119
+ attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
120
+ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
121
+ attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
122
+ attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
123
+
124
+ parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
125
+
126
+ upcast = parser.add_mutually_exclusive_group()
127
+ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
128
+ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
129
+
130
+
131
+ parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
132
+ manager_group = parser.add_mutually_exclusive_group()
133
+ manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
134
+ manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
135
+
136
+
137
+ vram_group = parser.add_mutually_exclusive_group()
138
+ vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
139
+ vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
140
+ vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
141
+ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
142
+ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
143
+ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
144
+
145
+ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
146
+
147
+ parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
148
+ parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
149
+
150
+ parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
151
+
152
+ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
153
+
154
+ parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
155
+ parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
156
+
157
# Opt-in experimental speed features, selected via the --fast flag (defined
# below); member values are the strings accepted on the command line.
class PerformanceFeature(enum.Enum):
    Fp16Accumulation = "fp16_accumulation"
    Fp8MatrixMultiplication = "fp8_matrix_mult"
    CublasOps = "cublas_ops"
    AutoTune = "autotune"
162
+
163
+ parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
164
+
165
+ parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
166
+
167
+ parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
168
+ parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
169
+
170
+ parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
171
+ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
172
+ parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
173
+
174
+ parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
175
+ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
176
+ parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
177
+ parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.")
178
+
179
+ parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
180
+
181
+ parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
182
+ parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
183
+
184
+
185
+ # The default built-in provider hosted under web/
186
+ DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
187
+
188
+ parser.add_argument(
189
+ "--front-end-version",
190
+ type=str,
191
+ default=DEFAULT_VERSION_STRING,
192
+ help="""
193
+ Specifies the version of the frontend to be used. This command needs internet connectivity to query and
194
+ download available frontend implementations from GitHub releases.
195
+
196
+ The version string should be in the format of:
197
+ [repoOwner]/[repoName]@[version]
198
+ where version is one of: "latest" or a valid version number (e.g. "1.0.0")
199
+ """,
200
+ )
201
+
202
def is_valid_directory(path: str) -> str:
    """Ensure *path* exists, is a directory, and is readable.

    Intended for use as an argparse ``type=`` converter: returns the path
    unchanged on success, raises ``argparse.ArgumentTypeError`` otherwise.
    """
    checks = (
        (os.path.exists(path), f"The path '{path}' does not exist."),
        (os.path.isdir(path), f"'{path}' is not a directory."),
        (os.access(path, os.R_OK), f"You do not have read permissions for '{path}'."),
    )
    for passed, message in checks:
        if not passed:
            raise argparse.ArgumentTypeError(message)
    return path
211
+
212
parser.add_argument(
    "--front-end-root",
    type=is_valid_directory,
    default=None,
    help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)

parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")

parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")

parser.add_argument(
    "--comfy-api-base",
    type=str,
    default="https://api.comfy.org",
    help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
)

# Default SQLite database lives next to the ComfyUI install, under user/.
database_default_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")

# Parse at import time; when args parsing is disabled (library use), fall
# back to defaults by parsing an empty argv.
if comfy.options.args_parsing:
    args = parser.parse_args()
else:
    args = parser.parse_args([])

# Post-parse fixups: standalone Windows builds auto-launch the browser
# unless explicitly disabled.
if args.windows_standalone_build:
    args.auto_launch = True

if args.disable_auto_launch:
    args.auto_launch = False

# --force-fp16 implies running the diffusion model in fp16.
if args.force_fp16:
    args.fp16_unet = True


# Normalize args.fast to a set of PerformanceFeature members:
# '--fast' is not provided, use an empty set
if args.fast is None:
    args.fast = set()
# '--fast' is provided with an empty list, enable all optimizations
elif args.fast == []:
    args.fast = set(PerformanceFeature)
# '--fast' is provided with a list of performance features, use that list
else:
    args.fast = set(args.fast)
comfy/clip_config_bigg.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CLIPTextModel"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "bos_token_id": 0,
7
+ "dropout": 0.0,
8
+ "eos_token_id": 49407,
9
+ "hidden_act": "gelu",
10
+ "hidden_size": 1280,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 5120,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 77,
16
+ "model_type": "clip_text_model",
17
+ "num_attention_heads": 20,
18
+ "num_hidden_layers": 32,
19
+ "pad_token_id": 1,
20
+ "projection_dim": 1280,
21
+ "torch_dtype": "float32",
22
+ "vocab_size": 49408
23
+ }
comfy/clip_model.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from comfy.ldm.modules.attention import optimized_attention_for_device
3
+ import comfy.ops
4
+ import math
5
+
6
def clip_preprocess(image, size=224, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711), crop=True):
    """Resize (and optionally center-crop) an NHWC image batch to `size` x `size`
    and normalize it with CLIP statistics.

    Args:
        image: float tensor of shape (batch, H, W, C) with values in [0, 1];
            channels beyond the first 3 are dropped.
        size: target spatial size of the square output.
        mean/std: per-channel normalization statistics. (Changed from mutable
            list defaults to tuples; values are identical.)
        crop: when True, scale the short side to `size` and center-crop the
            long side; when False, squash to (size, size).

    Returns:
        Normalized NCHW tensor of shape (batch, 3, size, size).
    """
    image = image[:, :, :, :3] if image.shape[3] > 3 else image
    mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
    std = torch.tensor(std, device=image.device, dtype=image.dtype)
    image = image.movedim(-1, 1)  # NHWC -> NCHW
    if not (image.shape[2] == size and image.shape[3] == size):
        if crop:
            # Scale the short side to `size`; the long side is cropped below.
            scale = (size / min(image.shape[2], image.shape[3]))
            scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
        else:
            scale_size = (size, size)

        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
        h = (image.shape[2] - size) // 2
        w = (image.shape[3] - size) // 2
        image = image[:, :, h:h + size, w:w + size]
    # Quantize back to 8-bit levels to match PIL-based preprocessing pipelines.
    image = torch.clip((255. * image), 0, 255).round() / 255.0
    return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1])
24
+
25
def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5):
    """Return the largest patch-aligned (h, w) whose total patch count fits in
    `max_num_patches`, approximately preserving the (oh, ow) aspect ratio."""

    def snap(dim, factor):
        # Scale, then round up to a whole number of patches (at least one).
        aligned = math.ceil(dim * factor / patch_size) * patch_size
        return max(patch_size, int(aligned))

    # Bisect on the scale factor: small scales are always feasible, large ones
    # are not, so the boundary is found by halving the interval.
    low, high = eps / 10, 100.0
    while high - low >= eps:
        probe = (low + high) / 2
        nh, nw = snap(oh, probe), snap(ow, probe)
        if (nh // patch_size) * (nw // patch_size) <= max_num_patches:
            low = probe
        else:
            high = probe

    return snap(oh, low), snap(ow, low)
41
+
42
def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True):
    """SigLIP2 preprocessing: fixed-size checkpoints (size > 0) defer to
    clip_preprocess; flex checkpoints pick a patch-aligned resolution bounded
    by `num_patches`."""
    if size > 0:
        # Non-flex models behave exactly like CLIP preprocessing.
        return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop)

    if image.shape[3] > 3:
        image = image[:, :, :, :3]
    mean_t = torch.tensor(mean, device=image.device, dtype=image.dtype)
    std_t = torch.tensor(std, device=image.device, dtype=image.dtype)
    image = image.movedim(-1, 1)  # NHWC -> NCHW

    b, c, h, w = image.shape
    h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches)

    image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True)
    # Quantize to 8-bit levels before normalizing, mirroring PIL-based pipelines.
    image = torch.clip((255. * image), 0, 255).round() / 255.0
    return (image - mean_t.view([3, 1, 1])) / std_t.view([3, 1, 1])
57
+
58
class CLIPAttention(torch.nn.Module):
    """Multi-head self-attention with separate q/k/v projections, matching the
    HuggingFace CLIP parameter layout. The attention kernel itself is injected
    per call via `optimized_attention`."""

    def __init__(self, embed_dim, heads, dtype, device, operations):
        super().__init__()

        self.heads = heads
        # Creation order (q, k, v, out) is kept for deterministic initialization.
        self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)

        self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x, mask=None, optimized_attention=None):
        # Project once per tensor, then delegate to the supplied kernel.
        projected = (self.q_proj(x), self.k_proj(x), self.v_proj(x))
        attended = optimized_attention(*projected, self.heads, mask)
        return self.out_proj(attended)
76
+
77
def _quick_gelu(a):
    # Sigmoid approximation of GELU used by the original CLIP weights.
    return a * torch.sigmoid(1.702 * a)


def _gelu_pytorch_tanh(a):
    # PyTorch's tanh-approximated GELU variant.
    return torch.nn.functional.gelu(a, approximate="tanh")


# Activation functions selectable by the config's "hidden_act" name.
ACTIVATIONS = {
    "quick_gelu": _quick_gelu,
    "gelu": torch.nn.functional.gelu,
    "gelu_pytorch_tanh": _gelu_pytorch_tanh,
}
81
+
82
class CLIPMLP(torch.nn.Module):
    """Transformer feed-forward block: Linear -> activation -> Linear.

    `activation` is a name looked up in the module-level ACTIVATIONS table.
    """

    def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
        super().__init__()
        self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
        self.activation = ACTIVATIONS[activation]
        self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x):
        # Expand, apply the configured non-linearity, project back down.
        return self.fc2(self.activation(self.fc1(x)))
94
+
95
class CLIPLayer(torch.nn.Module):
    """One pre-norm transformer encoder layer: self-attention then MLP, each
    wrapped in a residual connection."""
    def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
        super().__init__()
        self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
        self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)

    def forward(self, x, mask=None, optimized_attention=None):
        # Residual additions are performed in place: x is mutated, not reallocated.
        x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
        x += self.mlp(self.layer_norm2(x))
        return x
107
+
108
+
109
class CLIPEncoder(torch.nn.Module):
    """Stack of CLIPLayer blocks with optional capture of intermediate
    activations (one chosen layer, or all layers)."""

    def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for _ in range(num_layers)]
        )

    def forward(self, x, mask=None, intermediate_output=None):
        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

        # intermediate_output: None, a layer index (negative counts from the
        # end), or the string "all" to stack every layer's output.
        collect_all = intermediate_output == "all"
        if collect_all:
            captured = []
            intermediate_output = None
        else:
            captured = None
            if intermediate_output is not None and intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output

        intermediate = None
        for idx, layer in enumerate(self.layers):
            x = layer(x, mask, optimized_attention)
            if idx == intermediate_output:
                intermediate = x.clone()
            if captured is not None:
                captured.append(x.unsqueeze(1).clone())

        if captured is not None:
            # Shape: (batch, num_layers, seq, dim)
            intermediate = torch.cat(captured, dim=1)

        return x, intermediate
137
+
138
class CLIPEmbeddings(torch.nn.Module):
    """Token + learned absolute position embeddings for the CLIP text encoder."""

    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None):
        super().__init__()
        self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
        self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens, dtype=torch.float32):
        # Cast the position table to the requested dtype / token device before adding.
        tokens = self.token_embedding(input_tokens, out_dtype=dtype)
        positions = comfy.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device)
        return tokens + positions
146
+
147
+
148
class CLIPTextModel_(torch.nn.Module):
    """Core CLIP text transformer (embeddings + causal encoder + final norm)."""

    def __init__(self, config_dict, dtype, device, operations):
        # Expected config keys: num_hidden_layers, hidden_size,
        # num_attention_heads, intermediate_size, hidden_act,
        # max_position_embeddings, eos_token_id.
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]
        num_positions = config_dict["max_position_embeddings"]
        self.eos_token_id = config_dict["eos_token_id"]

        super().__init__()
        self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
        self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=None):
        """Encode token ids (or precomputed embeddings).

        Args:
            input_tokens: (batch, seq) integer token ids; used when `embeds` is None
                and for locating the EOS token when pooling.
            attention_mask: optional (batch, seq) mask, 1 = attend, 0 = padded.
            embeds: optional precomputed token embeddings; position embeddings
                are still added here.
            num_tokens: optional per-sample valid-token counts; when given, the
                pooled output is taken at index num_tokens - 1 instead of at the
                first EOS token.
            intermediate_output: layer index (negative counts from the end) or
                "all"; forwarded to the encoder.
            final_layer_norm_intermediate: also apply the final LayerNorm to the
                intermediate output.
            embeds_info: unused; kept for interface compatibility. Its default
                was changed from a mutable list literal (`[]`) to None — it is
                never read or mutated in this body.

        Returns:
            (last_hidden_state, intermediate, pooled_output)
        """
        if embeds is not None:
            x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
        else:
            x = self.embeddings(input_tokens, dtype=dtype)

        mask = None
        if attention_mask is not None:
            # Expand the 0/1 padding mask to (batch, 1, seq, seq) and convert to
            # additive form: masked positions get the most negative finite value.
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

        # The CLIP text encoder is causal; combine with the padding mask if any.
        causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)

        if mask is not None:
            mask += causal_mask
        else:
            mask = causal_mask

        x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
        x = self.final_layer_norm(x)
        if i is not None and final_layer_norm_intermediate:
            i = self.final_layer_norm(i)

        if num_tokens is not None:
            # Pool at the last valid token of each sample.
            pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
        else:
            # Pool at the first EOS token (argmax returns the first max index).
            pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
        return x, i, pooled_output
191
+
192
class CLIPTextModel(torch.nn.Module):
    """CLIP text encoder plus the final text projection head."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
        embed_dim = config_dict["hidden_size"]
        self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
        self.dtype = dtype

    def get_input_embeddings(self):
        # The token embedding table is the "input embeddings" in HF terminology.
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, embeddings):
        self.text_model.embeddings.token_embedding = embeddings

    def forward(self, *args, **kwargs):
        hidden, intermediate, pooled = self.text_model(*args, **kwargs)
        # Return both the projected and the raw pooled output.
        return (hidden, intermediate, self.text_projection(pooled), pooled)
211
+
212
def siglip2_pos_embed(embed_weight, embeds, orig_shape):
    """Bilinearly resample a square learned position-embedding table to
    `orig_shape` (patches-high, patches-wide) and add it to `embeds`."""
    side = round(embed_weight.shape[0] ** 0.5)
    # (num_pos, dim) -> (1, dim, side, side) grid for spatial interpolation.
    grid = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, side, side)
    grid = torch.nn.functional.interpolate(grid, size=orig_shape, mode="bilinear", align_corners=False, antialias=True)
    # Back to (num_pos', dim) so it broadcasts against the patch sequence.
    flat = grid.reshape(-1, grid.shape[-2] * grid.shape[-1]).movedim(0, 1)
    return embeds + flat
218
+
219
class Siglip2Embeddings(torch.nn.Module):
    """Flex-resolution SigLIP2 patch embedding: linear projection of flattened
    patches plus a position embedding resampled to the input's patch grid."""

    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None):
        super().__init__()
        self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device)
        self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
        self.patch_size = patch_size

    def forward(self, pixel_values):
        b, c, h, w = pixel_values.shape
        p = self.patch_size
        # Cut the NCHW image into (h/p)*(w/p) patches, each flattened to p*p*c values.
        patches = pixel_values.movedim(1, -1).reshape(b, h // p, p, w // p, p, c)
        patches = patches.permute(0, 1, 3, 2, 4, 5)
        patches = patches.reshape(b, patches.shape[1] * patches.shape[2], -1)
        embeds = self.patch_embedding(patches)
        return siglip2_pos_embed(self.position_embedding.weight, embeds, (h // p, w // p))
233
+
234
class CLIPVisionEmbeddings(torch.nn.Module):
    """Conv patch embedding with a class token (CLIP) or without one (SigLIP)."""

    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
        super().__init__()

        num_patches = (image_size // patch_size) ** 2
        if model_type == "siglip_vision_model":
            # SigLIP has no class token, but its conv projection carries a bias.
            self.class_embedding = None
            patch_bias = True
        else:
            num_patches = num_patches + 1  # +1 for the class-token position
            self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
            patch_bias = False

        self.patch_embedding = operations.Conv2d(
            in_channels=num_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=patch_bias,
            dtype=dtype,
            device=device
        )

        self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)

    def forward(self, pixel_values):
        # Conv output (b, dim, gh, gw) -> patch sequence (b, gh*gw, dim).
        embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
        if self.class_embedding is not None:
            cls_token = comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1)
            embeds = torch.cat([cls_token, embeds], dim=1)
        return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
264
+
265
+
266
class CLIPVision(torch.nn.Module):
    """Vision transformer trunk shared by the CLIP / SigLIP / SigLIP2 image
    encoders; `config_dict` follows the HuggingFace vision-config layout."""
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]
        model_type = config_dict["model_type"]

        if model_type in ["siglip2_vision_model"]:
            self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations)
        else:
            self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
        if model_type in ["siglip_vision_model", "siglip2_vision_model"]:
            # SigLIP variants skip the pre-norm and apply the final norm to the
            # whole sequence. NOTE: "pre_layrnorm" spelling matches the original
            # checkpoint key — do not rename.
            self.pre_layrnorm = lambda a: a
            self.output_layernorm = True
        else:
            self.pre_layrnorm = operations.LayerNorm(embed_dim)
            self.output_layernorm = False
        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
        self.post_layernorm = operations.LayerNorm(embed_dim)

    def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
        """Returns (last_hidden_state, intermediate, pooled_output); the
        attention_mask parameter is currently accepted but unused."""
        x = self.embeddings(pixel_values)
        x = self.pre_layrnorm(x)
        #TODO: attention_mask?
        x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
        if self.output_layernorm:
            # SigLIP: pooled output is the normalized full sequence.
            x = self.post_layernorm(x)
            pooled_output = x
        else:
            # CLIP: pool via the class token (position 0) after the final norm.
            pooled_output = self.post_layernorm(x[:, 0, :])
        return x, i, pooled_output
300
+
301
class LlavaProjector(torch.nn.Module):
    """Two-layer GELU MLP mapping vision features into the LLaVA language
    space; the class token is dropped before projecting."""

    def __init__(self, in_dim, out_dim, dtype, device, operations):
        super().__init__()
        self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
        self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        # x[:, 1:] skips the class token.
        hidden = torch.nn.functional.gelu(self.linear_1(x[:, 1:]))
        return self.linear_2(hidden)
309
+
310
class CLIPVisionModelProjection(torch.nn.Module):
    """Vision trunk plus an optional visual projection head and an optional
    LLaVA multi-modal projector."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.vision_model = CLIPVision(config_dict, dtype, device, operations)
        if "projection_dim" in config_dict:
            self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
        else:
            # No projection head configured: pass features through unchanged.
            self.visual_projection = lambda a: a

        if "llava3" == config_dict.get("projector_type", None):
            self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
        else:
            self.multi_modal_projector = None

    def forward(self, *args, **kwargs):
        hidden, intermediate, pooled = self.vision_model(*args, **kwargs)
        out = self.visual_projection(pooled)
        # The LLaVA projector consumes the intermediate activations, if present.
        projected = None if self.multi_modal_projector is None else self.multi_modal_projector(intermediate)
        return (hidden, intermediate, out, projected)