techfreakworm commited on
Commit
261639d
·
unverified ·
1 Parent(s): a7d0a44

feat(models): hf cache mirror (hardlink blobs, preserve snapshot symlinks, copy refs)

Browse files
Files changed (2) hide show
  1. models.py +52 -0
  2. tests/test_models.py +52 -0
models.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
 
4
  import os
5
  from dataclasses import dataclass, field
 
6
  from typing import Any
7
 
8
  # Avoid importing torch at module load — keeps `import models` fast in CI.
@@ -94,3 +95,54 @@ def build_diffsynth_configs(
94
  DSConfig(model_id=c.model_id, origin_file_pattern=c.origin_file_pattern, **(vram_cfg or {}))
95
  for c in configs
96
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  import os
5
  from dataclasses import dataclass, field
6
+ from pathlib import Path
7
  from typing import Any
8
 
9
  # Avoid importing torch at module load — keeps `import models` fast in CI.
 
95
  DSConfig(model_id=c.model_id, origin_file_pattern=c.origin_file_pattern, **(vram_cfg or {}))
96
  for c in configs
97
  ]
98
+
99
+
100
+ def mirror_preload_hf_cache(src_root: Path | str, dst_root: Path | str) -> None:
101
+ """Mirror a read-only HF cache tree (preload_from_hub) into a writable tree.
102
+
103
+ - ``blobs/<sha>`` files -> **hardlinked** (zero-copy, shared inode).
104
+ - ``snapshots/<commit>/...`` symlinks -> **preserved** with original relative target.
105
+ - ``refs/<branch>`` files -> **byte-copied** (HF lib overwrites on etag check).
106
+ - Directories -> ``mkdir`` so the runtime user owns them.
107
+
108
+ Falls back to ``symlink`` when ``os.link()`` raises EXDEV (cross-device).
109
+ """
110
+ import errno
111
+ import shutil
112
+
113
+ src_root = Path(src_root)
114
+ dst_root = Path(dst_root)
115
+
116
+ if not (src_root / "hub").exists():
117
+ return # nothing preloaded -- no-op
118
+
119
+ for src_dir, _, files in os.walk(src_root / "hub"):
120
+ rel = Path(src_dir).relative_to(src_root)
121
+ dst_dir = dst_root / rel
122
+ dst_dir.mkdir(parents=True, exist_ok=True)
123
+
124
+ for name in files:
125
+ src_path = Path(src_dir) / name
126
+ dst_path = dst_dir / name
127
+ if dst_path.exists():
128
+ continue
129
+
130
+ # Refs get byte-copied
131
+ if "refs/" in str(rel).replace("\\", "/"):
132
+ shutil.copy2(src_path, dst_path)
133
+ continue
134
+
135
+ # Symlinks (snapshot files) preserve their relative target
136
+ if src_path.is_symlink():
137
+ target = os.readlink(src_path)
138
+ dst_path.symlink_to(target)
139
+ continue
140
+
141
+ # Regular files (blobs) hardlink with EXDEV fallback
142
+ try:
143
+ os.link(src_path, dst_path)
144
+ except OSError as e:
145
+ if e.errno == errno.EXDEV:
146
+ dst_path.symlink_to(src_path)
147
+ else:
148
+ raise
tests/test_models.py CHANGED
@@ -36,3 +36,55 @@ def test_vram_limit_for_mps_is_unified_memory_aware():
36
 
37
  def test_vram_limit_for_cpu_is_zero():
38
  assert models.vram_limit_for("cpu", free_gb=64.0) == 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def test_vram_limit_for_cpu_is_zero():
38
  assert models.vram_limit_for("cpu", free_gb=64.0) == 0.0
39
+
40
+
41
+ def test_mirror_hardlinks_blobs(tmp_path):
42
+ """Blobs (content-addressed files) get hardlinked into the mirror."""
43
+ src = tmp_path / "src" / "hub"
44
+ dst = tmp_path / "rw"
45
+ blob_dir = src / "blobs"
46
+ blob_dir.mkdir(parents=True)
47
+ blob = blob_dir / "abcdef"
48
+ blob.write_bytes(b"hello")
49
+
50
+ models.mirror_preload_hf_cache(src.parent, dst)
51
+
52
+ mirrored = dst / "hub" / "blobs" / "abcdef"
53
+ assert mirrored.exists()
54
+ assert mirrored.stat().st_ino == blob.stat().st_ino, "should be hardlinked"
55
+
56
+
57
+ def test_mirror_preserves_snapshot_symlinks(tmp_path):
58
+ """Snapshot symlinks point at relative blob paths -- preserve as-is."""
59
+ src = tmp_path / "src" / "hub"
60
+ dst = tmp_path / "rw"
61
+ (src / "blobs").mkdir(parents=True)
62
+ blob = src / "blobs" / "abc"
63
+ blob.write_bytes(b"content")
64
+ snap_dir = src / "snapshots" / "v1"
65
+ snap_dir.mkdir(parents=True)
66
+ link = snap_dir / "model.safetensors"
67
+ link.symlink_to("../../blobs/abc")
68
+
69
+ models.mirror_preload_hf_cache(src.parent, dst)
70
+
71
+ mirrored_link = dst / "hub" / "snapshots" / "v1" / "model.safetensors"
72
+ assert mirrored_link.is_symlink()
73
+ target = os.readlink(mirrored_link)
74
+ assert target == "../../blobs/abc"
75
+
76
+
77
+ def test_mirror_byte_copies_refs(tmp_path):
78
+ """Refs are rewritten by HF lib on etag; must be a real copy, not hardlink."""
79
+ src = tmp_path / "src" / "hub"
80
+ dst = tmp_path / "rw"
81
+ refs_dir = src / "refs" / "main"
82
+ refs_dir.mkdir(parents=True)
83
+ ref = refs_dir / "v1"
84
+ ref.write_text("commit-sha\n")
85
+
86
+ models.mirror_preload_hf_cache(src.parent, dst)
87
+
88
+ mirrored_ref = dst / "hub" / "refs" / "main" / "v1"
89
+ assert mirrored_ref.read_text() == "commit-sha\n"
90
+ assert mirrored_ref.stat().st_ino != ref.stat().st_ino, "must be a real copy"