techfreakworm commited on
Commit
3d485ad
·
unverified ·
1 Parent(s): 55a3bb4

feat(models): ensure_models with HF cache symlinks + missing-file fallback

Browse files
Files changed (2) hide show
  1. models.py +105 -0
  2. tests/test_models.py +41 -0
models.py CHANGED
@@ -5,8 +5,16 @@ supported. If that ever happens we'll qualify by ComfyUI loader-type.
5
  """
6
  from __future__ import annotations
7
 
 
 
 
 
8
  from dataclasses import dataclass
9
 
 
 
 
 
10
 
11
  @dataclass(frozen=True)
12
  class ModelEntry:
@@ -144,3 +152,100 @@ def _flatten_widget_values(values):
144
  yield from _flatten_widget_values(list(v.values()))
145
  else:
146
  yield v
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  """
6
  from __future__ import annotations
7
 
8
+ import logging
9
+ import os
10
+ import pathlib
11
+ from collections.abc import Iterator
12
  from dataclasses import dataclass
13
 
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
 
19
  @dataclass(frozen=True)
20
  class ModelEntry:
 
152
  yield from _flatten_widget_values(list(v.values()))
153
  else:
154
  yield v
155
+
156
+
157
+ @dataclass
158
+ class DownloadEvent:
159
+ filename: str
160
+ mb_done: float
161
+ mb_total: float
162
+
163
+
164
+ def _on_spaces() -> bool:
165
+ return bool(os.environ.get("SPACES_ZERO_GPU"))
166
+
167
+
168
+ def _comfy_models_dir() -> pathlib.Path:
169
+ raw = os.environ.get("COMFY_MODELS_DIR")
170
+ if raw:
171
+ return pathlib.Path(raw)
172
+ if _on_spaces():
173
+ return pathlib.Path("/data/models")
174
+ return pathlib.Path(__file__).parent / "comfyui" / "models"
175
+
176
+
177
+ def ensure_models(filenames: set[str]) -> Iterator[DownloadEvent]:
178
+ """Ensure each requested model is materialized in comfyui/models/<type>/.
179
+
180
+ Local mode: hf_hub_download into the user's HF cache; symlink to comfyui/models/.
181
+ Spaces mode: hf_hub_download with cache_dir=/data; comfyui/models/ symlinks
182
+ point into /data.
183
+
184
+ Files not in MODEL_REGISTRY are skipped (with a warning) — useful when the
185
+ workflow has been manually customized with non-canonical filenames that the
186
+ user supplies via their own ComfyUI install.
187
+
188
+ Yields DownloadEvent on each successfully materialized file (mb_done==mb_total
189
+ when already cached locally).
190
+ """
191
+ comfy_models = _comfy_models_dir()
192
+ cache_dir = pathlib.Path(
193
+ os.environ.get(
194
+ "HF_HUB_CACHE",
195
+ pathlib.Path.home() / ".cache" / "huggingface" / "hub",
196
+ )
197
+ )
198
+
199
+ for filename in filenames:
200
+ if filename not in MODEL_REGISTRY:
201
+ logger.warning(
202
+ "model file %r not in MODEL_REGISTRY; skipping. "
203
+ "Add an entry to MODEL_REGISTRY or override the loader in the workflow.",
204
+ filename,
205
+ )
206
+ continue
207
+ entry = MODEL_REGISTRY[filename]
208
+
209
+ # Resolve source: hf_hub_download returns the cache path (or downloads).
210
+ try:
211
+ source = pathlib.Path(
212
+ hf_hub_download(
213
+ repo_id=entry.repo_id,
214
+ filename=filename,
215
+ cache_dir=str(cache_dir),
216
+ local_dir=None,
217
+ )
218
+ )
219
+ size_mb = source.stat().st_size / 1024 / 1024
220
+ yield DownloadEvent(filename, size_mb, size_mb)
221
+ except Exception as exc:
222
+ # Fall back to scanning the cache for a placeholder file (test mode + offline mode).
223
+ candidates = list(cache_dir.rglob(filename))
224
+ if not candidates:
225
+ logger.warning(
226
+ "could not download or locate %r in HF cache: %s; skipping",
227
+ filename,
228
+ exc,
229
+ )
230
+ continue
231
+ source = candidates[0]
232
+ yield DownloadEvent(filename, 0.0, 0.0)
233
+
234
+ # Build symlink target inside comfy_models
235
+ dest_dir = comfy_models / entry.comfy_type
236
+ if entry.subfolder:
237
+ dest_dir = dest_dir / entry.subfolder
238
+ dest_dir.mkdir(parents=True, exist_ok=True)
239
+ dest = dest_dir / filename
240
+
241
+ if dest.is_symlink() or dest.exists():
242
+ dest.unlink()
243
+ dest.symlink_to(source)
244
+
245
+
246
+ def ensure_models_for_mode(mode: str) -> Iterator[DownloadEvent]:
247
+ """Convenience: walk a mode's workflow and ensure all referenced models exist."""
248
+ import workflow as workflow_module # local import to avoid cycle at import time
249
+ wf = workflow_module.load_template(mode)
250
+ needed = walk_workflow_for_models(wf)
251
+ yield from ensure_models(needed)
tests/test_models.py CHANGED
@@ -1,4 +1,6 @@
1
  """Unit tests for models.py — MODEL_REGISTRY and ensure_models_for_mode."""
 
 
2
  import models
3
  import workflow
4
 
@@ -26,3 +28,42 @@ def test_walk_workflow_for_models_finds_t2v_loaders():
26
  for name in needed
27
  )
28
  assert any("gemma" in name.lower() for name in needed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Unit tests for models.py — MODEL_REGISTRY and ensure_models_for_mode."""
2
+ import pathlib
3
+
4
  import models
5
  import workflow
6
 
 
28
  for name in needed
29
  )
30
  assert any("gemma" in name.lower() for name in needed)
31
+
32
+
33
+ def test_ensure_models_creates_symlinks_local(tmp_path, monkeypatch, fake_hf_cache):
34
+ """In local mode, ensure_models creates symlinks from comfy/models -> HF cache."""
35
+ monkeypatch.setenv("HF_HUB_CACHE", str(fake_hf_cache))
36
+ monkeypatch.setattr(models, "_on_spaces", lambda: False)
37
+
38
+ # Force the HF Hub call to fail so the fallback path (cache_dir.rglob) is exercised.
39
+ def _raise(*_args, **_kwargs):
40
+ raise RuntimeError("offline test: forcing fallback to cache scan")
41
+ monkeypatch.setattr(models, "hf_hub_download", _raise)
42
+
43
+ comfy_models = tmp_path / "comfyui" / "models"
44
+ monkeypatch.setattr(models, "_comfy_models_dir", lambda: comfy_models)
45
+
46
+ needed = {
47
+ "ltx-2.3-22b-distilled.safetensors",
48
+ "model-00001-of-00005.safetensors",
49
+ }
50
+ events = list(models.ensure_models(needed))
51
+
52
+ # Each requested file should now have a symlink in comfyui/models/<type>/
53
+ assert (comfy_models / "checkpoints" / "ltx-2.3-22b-distilled.safetensors").is_symlink()
54
+ assert (comfy_models / "text_encoders" / "gemma-3-12b-it"
55
+ / "model-00001-of-00005.safetensors").is_symlink()
56
+
57
+
58
+ def test_ensure_models_skips_unregistered_files_with_warning(tmp_path, monkeypatch, fake_hf_cache, caplog):
59
+ """Files not in MODEL_REGISTRY are skipped (with warning), not raised."""
60
+ import logging
61
+ monkeypatch.setenv("HF_HUB_CACHE", str(fake_hf_cache))
62
+ monkeypatch.setattr(models, "_on_spaces", lambda: False)
63
+ monkeypatch.setattr(models, "_comfy_models_dir", lambda: tmp_path / "comfyui" / "models")
64
+
65
+ with caplog.at_level(logging.WARNING):
66
+ events = list(models.ensure_models({"nonexistent_phantom_file.safetensors"}))
67
+
68
+ # Should not raise, should log a warning, should yield no events for the missing entry.
69
+ assert any("nonexistent_phantom_file" in record.message for record in caplog.records)