| """Regression: pipeline model paths like ckpts/foo must not be treated as Hub org/model.""" |
|
|
| from __future__ import annotations |
|
|
| import re |
| import unittest |
| from pathlib import Path |
|
|
|
|
# Absolute path to the pipeline base module whose source text is inspected
# below. It is resolved relative to this file (two directory levels up from the
# file itself) so the tests do not depend on the current working directory.
_BASE_PATH = Path(__file__).resolve().parents[1] / "trellis/pipelines/base.py"
|
|
|
|
class PipelineHubIdTests(unittest.TestCase):
    """Regression tests: local paths such as ``ckpts/foo`` are not Hub ids."""

    def test_base_py_documents_ckpts_exclusion(self) -> None:
        """Check that the text of base.py still contains the exclusion markers.

        This is a purely textual check: the file at ``_BASE_PATH`` is read as
        UTF-8 and searched for literal substrings; it is never imported.
        """
        src = _BASE_PATH.read_text(encoding="utf-8")
        self.assertIn("_NOT_HUB_PATH_PREFIXES", src)
        self.assertIn('"ckpts/"', src)
        self.assertIn("models.from_pretrained(os.path.join(root, vp))", src)

    def test_regex_ckpts_two_segment_not_hub_semantics(self) -> None:
        """Mirror _is_hub_model_id: two-segment ckpts/... is not a Hub id."""
        # Prefixes that mark a two-segment string as a local path.
        # NOTE(review): intended to match the prefixes used in base.py, which
        # is not visible from this file -- keep the two in sync.
        not_hub = ("ckpts/", "weights/", "configs/", "checkpoints/")

        def is_hub(t: str) -> bool:
            """Return True if *t* is ``name/name`` and has no local-path prefix."""
            # Normalise surrounding whitespace and Windows separators first.
            t = t.strip().replace("\\", "/")
            if not re.fullmatch(r"[\w.-]+/[\w.-]+", t):
                return False
            # str.startswith accepts a tuple of prefixes; the comparison is
            # done on the lower-cased string, so it is case-insensitive.
            return not t.lower().startswith(not_hub)

        self.assertFalse(is_hub("ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"))
        self.assertTrue(is_hub("luh0502/NeAR"))
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|