# NeAR / tests/test_pipeline_hub_id.py
# Last commit 57f3596 (luh1124):
#   fix(pipeline): do not treat ckpts/... as Hugging Face org/model id
"""Regression: pipeline model paths like ckpts/foo must not be treated as Hub org/model."""
from __future__ import annotations
import re
import unittest
from pathlib import Path
# Pipeline module inspected by the static source check. It is located relative to
# this file: parents[1] is the directory one level above the one holding this test
# (the repository root when the file lives in tests/).
_BASE_PATH = (
    Path(__file__).resolve().parents[1].joinpath("trellis", "pipelines", "base.py")
)
class PipelineHubIdTests(unittest.TestCase):
    """Regression tests: local paths like ``ckpts/foo`` must not be treated as Hub ids."""

    # Prefixes the mirrored predicate refuses to treat as a Hub "org/model" id.
    # NOTE(review): intended to mirror _NOT_HUB_PATH_PREFIXES in
    # trellis/pipelines/base.py. Only "ckpts/" is confirmed present there by
    # test_base_py_documents_ckpts_exclusion; keep this tuple in sync by hand.
    _NOT_HUB_PREFIXES = ("ckpts/", "weights/", "configs/", "checkpoints/")

    # Exactly two non-empty segments of word chars, dots or dashes: "org/model".
    _HUB_ID_RE = re.compile(r"[\w.-]+/[\w.-]+")

    @classmethod
    def _is_hub(cls, text: str) -> bool:
        """Local mirror of ``_is_hub_model_id`` from the pipeline base module.

        The logic is duplicated here rather than imported (presumably so the
        test runs without importing ``trellis`` -- confirm). Drift between this
        copy and production is NOT detected; the static source check is the
        only link to the real implementation.

        Returns True only when *text*, after stripping surrounding whitespace
        and converting backslashes to forward slashes, is exactly two segments
        and does not start (case-insensitively) with an excluded local prefix.
        """
        normalized = text.strip().replace("\\", "/")
        if not cls._HUB_ID_RE.fullmatch(normalized):
            return False
        return not normalized.lower().startswith(cls._NOT_HUB_PREFIXES)

    def test_base_py_documents_ckpts_exclusion(self) -> None:
        """Static check: the pipeline source still contains the ckpts/ exclusion."""
        # Fail with a readable message instead of a raw FileNotFoundError when the
        # module has moved or the tests run outside the repository checkout.
        self.assertTrue(
            _BASE_PATH.is_file(),
            f"pipeline module not found at {_BASE_PATH}",
        )
        src = _BASE_PATH.read_text(encoding="utf-8")
        for snippet in (
            "_NOT_HUB_PATH_PREFIXES",
            '"ckpts/"',
            "models.from_pretrained(os.path.join(root, vp))",
        ):
            # subTest reports every missing snippet, not just the first one.
            with self.subTest(snippet=snippet):
                self.assertIn(snippet, src)

    def test_regex_ckpts_two_segment_not_hub_semantics(self) -> None:
        """Mirror _is_hub_model_id: two-segment ckpts/... is not a Hub id."""
        # Original regression cases.
        self.assertFalse(self._is_hub("ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"))
        self.assertTrue(self._is_hub("luh0502/NeAR"))

        # Every excluded prefix is rejected, regardless of letter case.
        for prefix in self._NOT_HUB_PREFIXES:
            with self.subTest(prefix=prefix):
                self.assertFalse(self._is_hub(prefix + "model"))
                self.assertFalse(self._is_hub(prefix.upper() + "model"))

        # Windows separators and surrounding whitespace are normalized first.
        self.assertFalse(self._is_hub("ckpts\\slat_dec"))
        self.assertTrue(self._is_hub("  luh0502/NeAR  "))

        # Anything other than exactly two segments is never a Hub id.
        self.assertFalse(self._is_hub("NeAR"))
        self.assertFalse(self._is_hub("ckpts/a/b"))
        self.assertFalse(self._is_hub("org/model/extra"))
if __name__ == "__main__":
    # Executed directly (not via a test runner): hand off to unittest's CLI,
    # which discovers the TestCase classes in this module and exits with the result.
    unittest.main(module="__main__")