luh1124 commited on
Commit
57f3596
·
1 Parent(s): 5950889

fix(pipeline): do not treat ckpts/... as Hugging Face org/model id

Browse files

- _is_hub_model_id: exclude ckpts/, weights/, configs/, checkpoints/
- Load models from snapshot root only for relative paths; remove broad try/except fallback
- Fixes 404 on hf.co/ckpts/... when NeAR CPU preload runs on Space

Made-with: Cursor

tests/test_pipeline_hub_id.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Regression: pipeline model paths like ckpts/foo must not be treated as Hub org/model."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ import unittest
7
+ from pathlib import Path
8
+
9
+
10
+ _BASE_PATH = Path(__file__).resolve().parents[1] / "trellis/pipelines/base.py"
11
+
12
+
13
+ class PipelineHubIdTests(unittest.TestCase):
14
+ def test_base_py_documents_ckpts_exclusion(self) -> None:
15
+ src = _BASE_PATH.read_text(encoding="utf-8")
16
+ self.assertIn("_NOT_HUB_PATH_PREFIXES", src)
17
+ self.assertIn('"ckpts/"', src)
18
+ self.assertIn("models.from_pretrained(os.path.join(root, vp))", src)
19
+
20
+ def test_regex_ckpts_two_segment_not_hub_semantics(self) -> None:
21
+ """Mirror _is_hub_model_id: two-segment ckpts/... is not a Hub id."""
22
+ not_hub = ("ckpts/", "weights/", "configs/", "checkpoints/")
23
+
24
+ def is_hub(t: str) -> bool:
25
+ t = t.strip().replace("\\", "/")
26
+ if not re.fullmatch(r"[\w.-]+/[\w.-]+", t):
27
+ return False
28
+ tl = t.lower()
29
+ return not any(tl.startswith(p) for p in not_hub)
30
+
31
+ self.assertFalse(is_hub("ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"))
32
+ self.assertTrue(is_hub("luh0502/NeAR"))
33
+
34
+
35
+ if __name__ == "__main__":
36
+ unittest.main()
trellis/pipelines/base.py CHANGED
@@ -10,10 +10,22 @@ import torch.nn as nn
10
  from .. import models
11
 
12
 
 
 
 
 
 
 
 
 
 
13
  def _is_hub_model_id(s: str) -> bool:
14
- """`org/model` style id, not a filesystem path."""
15
  t = s.strip().replace("\\", "/")
16
- return bool(re.fullmatch(r"[\w.-]+/[\w.-]+", t))
 
 
 
17
 
18
 
19
  class Pipeline:
@@ -85,10 +97,11 @@ class Pipeline:
85
 
86
  _models = {}
87
  for k, v in args["models"].items():
88
- try:
89
- _models[k] = models.from_pretrained(os.path.join(root, v))
90
- except Exception:
91
- _models[k] = models.from_pretrained(v)
 
92
 
93
  new_pipeline = Pipeline(_models)
94
  new_pipeline._pretrained_args = args
 
10
  from .. import models
11
 
12
 
13
+ # Paths under a snapshot (e.g. ckpts/foo) also match org/model regex — exclude them.
14
+ _NOT_HUB_PATH_PREFIXES = (
15
+ "ckpts/",
16
+ "weights/",
17
+ "configs/",
18
+ "checkpoints/",
19
+ )
20
+
21
+
22
  def _is_hub_model_id(s: str) -> bool:
23
+ """True for Hugging Face ``org/model``, not for relative snapshot paths like ``ckpts/...``."""
24
  t = s.strip().replace("\\", "/")
25
+ if not re.fullmatch(r"[\w.-]+/[\w.-]+", t):
26
+ return False
27
+ tl = t.lower()
28
+ return not any(tl.startswith(p) for p in _NOT_HUB_PATH_PREFIXES)
29
 
30
 
31
  class Pipeline:
 
97
 
98
  _models = {}
99
  for k, v in args["models"].items():
100
+ vp = str(v).strip().replace("\\", "/")
101
+ if _is_hub_model_id(vp):
102
+ _models[k] = models.from_pretrained(vp)
103
+ else:
104
+ _models[k] = models.from_pretrained(os.path.join(root, vp))
105
 
106
  new_pipeline = Pipeline(_models)
107
  new_pipeline._pretrained_args = args