from __future__ import annotations

import importlib.util
import os
import re
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock
|
|
|
|
# Load trellis/utils/dino_hub.py straight from its file path, under a
# test-private module name, instead of importing it through the package.
# NOTE(review): presumably this avoids executing the ``trellis`` package
# __init__ and its dependencies — confirm.
MODULE_PATH = Path(__file__).resolve().parents[1] / "trellis" / "utils" / "dino_hub.py"
_SPEC = importlib.util.spec_from_file_location("near_test_dino_hub", MODULE_PATH)
if _SPEC is None or _SPEC.loader is None:
    # Explicit raise instead of ``assert``: the check must survive ``python -O``.
    raise ImportError(f"cannot build an import spec for {MODULE_PATH}")
dino_hub = importlib.util.module_from_spec(_SPEC)
_SPEC.loader.exec_module(dino_hub)
|
|
|
|
def _write_hub_repo(repo_root: Path) -> None:
    """Create a minimal hub repository (a directory holding ``hubconf.py``).

    Missing parent directories are created, an existing directory is
    accepted, and ``hubconf.py`` is (over)written with a stub
    ``dinov2_vitl14_reg`` entry point that returns ``None``.
    """
    hubconf_source = "def dinov2_vitl14_reg(pretrained=True): return None\n"
    repo_root.mkdir(exist_ok=True, parents=True)
    hubconf_path = repo_root / "hubconf.py"
    hubconf_path.write_text(hubconf_source, encoding="utf-8")
|
|
|
|
class DinoHubTests(unittest.TestCase):
    """Tests for ``dino_hub`` repo resolution and model loading.

    Each test runs with the variables in ``_ENV_KEYS`` unset, a fresh
    temporary directory at ``self.tmp_path``, and a cleared
    ``resolve_dinov2_repo_root`` cache. The environment and temp dir are
    restored/removed afterwards and the cache is cleared again.
    """

    # NEAR_* environment variables this suite controls; they are removed
    # before each test so the host environment cannot influence results.
    _ENV_KEYS = (
        "NEAR_DINO_LOCAL_REPO",
        "NEAR_AUX_REPO",
        "NEAR_DINO_REPO_SUBDIR",
        "NEAR_DINO_MODEL_ID",
        "NEAR_DINO_FILENAME",
        "NEAR_DINO_TORCH_HUB_REPO",
    )

    def setUp(self) -> None:
        # Each cleanup is registered *before* the state it undoes is changed.
        # unittest skips tearDown when setUp raises but still runs cleanups,
        # so a failure part-way through setUp can no longer leak the temp dir
        # or leave env vars deleted for every later test in the process.
        self._tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmpdir.cleanup)
        self.tmp_path = Path(self._tmpdir.name)

        # patch.dict snapshots os.environ now and restores it exactly on stop().
        env_patcher = mock.patch.dict(os.environ)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)
        for key in self._ENV_KEYS:
            os.environ.pop(key, None)

        # The resolver is cached (it exposes ``cache_clear``); reset it before
        # and after every test so results never carry over between tests.
        self.addCleanup(dino_hub.resolve_dinov2_repo_root.cache_clear)
        dino_hub.resolve_dinov2_repo_root.cache_clear()

    @staticmethod
    def _make_fake_torch() -> types.SimpleNamespace:
        """Build a stand-in ``torch`` module whose ``hub.load`` is a Mock.

        ``hub.load`` returns a sentinel ``object()``. Install the result with
        ``mock.patch.dict(sys.modules, {"torch": ...})``.
        """
        hub = types.SimpleNamespace(load=mock.Mock(return_value=object()))
        return types.SimpleNamespace(hub=hub)

    def test_resolve_dinov2_repo_root_prefers_local_override(self) -> None:
        """A NEAR_DINO_LOCAL_REPO hub repo is returned as a resolved path."""
        repo_root = self.tmp_path / "dinov2-local"
        _write_hub_repo(repo_root)
        os.environ["NEAR_DINO_LOCAL_REPO"] = str(repo_root)

        resolved = dino_hub.resolve_dinov2_repo_root()

        self.assertEqual(resolved, str(repo_root.resolve()))

    def test_resolve_dinov2_repo_root_requires_mirror_when_no_local(self) -> None:
        """With no NEAR_* location configured, resolution raises FileNotFoundError."""
        with self.assertRaises(FileNotFoundError):
            dino_hub.resolve_dinov2_repo_root()

    def test_resolve_dinov2_repo_root_uses_aux_repo_subdir(self) -> None:
        """NEAR_AUX_REPO joined with NEAR_DINO_REPO_SUBDIR locates the repo."""
        aux_root = self.tmp_path / "near-assets"
        repo_root = aux_root / "dinov2"
        _write_hub_repo(repo_root)
        os.environ["NEAR_AUX_REPO"] = str(aux_root)
        os.environ["NEAR_DINO_REPO_SUBDIR"] = "dinov2"

        resolved = dino_hub.resolve_dinov2_repo_root()

        self.assertEqual(resolved, str(repo_root.resolve()))

    def test_resolve_dinov2_repo_root_validates_optional_filename(self) -> None:
        """A NEAR_DINO_FILENAME absent from the repo raises an error naming it."""
        repo_root = self.tmp_path / "dinov2-local"
        _write_hub_repo(repo_root)
        os.environ["NEAR_DINO_LOCAL_REPO"] = str(repo_root)
        os.environ["NEAR_DINO_FILENAME"] = "weights/model.bin"

        # re.escape: match the filename literally ('.' would be a regex wildcard).
        with self.assertRaisesRegex(FileNotFoundError, re.escape("weights/model.bin")):
            dino_hub.resolve_dinov2_repo_root()

    def test_load_dinov2_model_uses_local_torch_hub_mirror(self) -> None:
        """A local repo is loaded with source="local" and NEAR_DINO_MODEL_ID as entry."""
        repo_root = self.tmp_path / "dinov2-local"
        _write_hub_repo(repo_root)
        os.environ["NEAR_DINO_LOCAL_REPO"] = str(repo_root)
        os.environ["NEAR_DINO_MODEL_ID"] = "custom_dino_entry"

        fake_torch = self._make_fake_torch()
        with mock.patch.dict(sys.modules, {"torch": fake_torch}):
            loaded = dino_hub.load_dinov2_model("dinov2_vitl14_reg")

        self.assertIsNotNone(loaded)
        fake_torch.hub.load.assert_called_once_with(
            repo_or_dir=str(repo_root.resolve()),
            model="custom_dino_entry",
            source="local",
            pretrained=True,
        )

    def test_load_dinov2_model_torch_hub_github_when_no_mirror_env(self) -> None:
        """Without mirror env vars, torch.hub.load targets facebookresearch/dinov2."""
        fake_torch = self._make_fake_torch()
        with mock.patch.dict(sys.modules, {"torch": fake_torch}):
            loaded = dino_hub.load_dinov2_model("dinov2_vitl14_reg")

        self.assertIsNotNone(loaded)
        fake_torch.hub.load.assert_called_once_with(
            "facebookresearch/dinov2",
            "dinov2_vitl14_reg",
            pretrained=True,
        )
|
|
|
|
if __name__ == "__main__":
    # Direct execution (``python <this file>``) loads and runs the TestCase
    # classes defined in this module.
    unittest.main(module="__main__")
|
|