from __future__ import annotations import os import tempfile import unittest import importlib.util import sys import types from pathlib import Path from unittest import mock MODULE_PATH = Path(__file__).resolve().parents[1] / "trellis" / "utils" / "dino_hub.py" _SPEC = importlib.util.spec_from_file_location("near_test_dino_hub", MODULE_PATH) assert _SPEC is not None and _SPEC.loader is not None dino_hub = importlib.util.module_from_spec(_SPEC) _SPEC.loader.exec_module(dino_hub) def _write_hub_repo(repo_root: Path) -> None: repo_root.mkdir(parents=True, exist_ok=True) (repo_root / "hubconf.py").write_text("def dinov2_vitl14_reg(pretrained=True): return None\n", encoding="utf-8") class DinoHubTests(unittest.TestCase): def setUp(self) -> None: self._tmpdir = tempfile.TemporaryDirectory() self.tmp_path = Path(self._tmpdir.name) self._old_env = os.environ.copy() for key in ( "NEAR_DINO_LOCAL_REPO", "NEAR_AUX_REPO", "NEAR_DINO_REPO_SUBDIR", "NEAR_DINO_MODEL_ID", "NEAR_DINO_FILENAME", "NEAR_DINO_TORCH_HUB_REPO", ): os.environ.pop(key, None) dino_hub.resolve_dinov2_repo_root.cache_clear() def tearDown(self) -> None: os.environ.clear() os.environ.update(self._old_env) dino_hub.resolve_dinov2_repo_root.cache_clear() self._tmpdir.cleanup() def test_resolve_dinov2_repo_root_prefers_local_override(self) -> None: repo_root = self.tmp_path / "dinov2-local" _write_hub_repo(repo_root) os.environ["NEAR_DINO_LOCAL_REPO"] = str(repo_root) resolved = dino_hub.resolve_dinov2_repo_root() self.assertEqual(resolved, str(repo_root.resolve())) def test_resolve_dinov2_repo_root_requires_mirror_when_no_local(self) -> None: with self.assertRaises(FileNotFoundError): dino_hub.resolve_dinov2_repo_root() def test_resolve_dinov2_repo_root_uses_aux_repo_subdir(self) -> None: aux_root = self.tmp_path / "near-assets" repo_root = aux_root / "dinov2" _write_hub_repo(repo_root) os.environ["NEAR_AUX_REPO"] = str(aux_root) os.environ["NEAR_DINO_REPO_SUBDIR"] = "dinov2" resolved = dino_hub.resolve_dinov2_repo_root() self.assertEqual(resolved, str(repo_root.resolve())) def test_resolve_dinov2_repo_root_validates_optional_filename(self) -> None: repo_root = self.tmp_path / "dinov2-local" _write_hub_repo(repo_root) os.environ["NEAR_DINO_LOCAL_REPO"] = str(repo_root) os.environ["NEAR_DINO_FILENAME"] = "weights/model.bin" with self.assertRaisesRegex(FileNotFoundError, "weights/model.bin"): dino_hub.resolve_dinov2_repo_root() def test_load_dinov2_model_uses_local_torch_hub_mirror(self) -> None: repo_root = self.tmp_path / "dinov2-local" _write_hub_repo(repo_root) os.environ["NEAR_DINO_LOCAL_REPO"] = str(repo_root) os.environ["NEAR_DINO_MODEL_ID"] = "custom_dino_entry" fake_torch = types.SimpleNamespace(hub=types.SimpleNamespace(load=mock.Mock(return_value=object()))) with mock.patch.dict(sys.modules, {"torch": fake_torch}): loaded = dino_hub.load_dinov2_model("dinov2_vitl14_reg") self.assertIsNotNone(loaded) fake_torch.hub.load.assert_called_once_with( repo_or_dir=str(repo_root.resolve()), model="custom_dino_entry", source="local", pretrained=True, ) def test_load_dinov2_model_torch_hub_github_when_no_mirror_env(self) -> None: fake_torch = types.SimpleNamespace(hub=types.SimpleNamespace(load=mock.Mock(return_value=object()))) with mock.patch.dict(sys.modules, {"torch": fake_torch}): loaded = dino_hub.load_dinov2_model("dinov2_vitl14_reg") self.assertIsNotNone(loaded) fake_torch.hub.load.assert_called_once_with( "facebookresearch/dinov2", "dinov2_vitl14_reg", pretrained=True, ) if __name__ == "__main__": unittest.main()