# NeAR / tests/test_dino_hub.py
# Commit 94d9e58 (luh1124):
#   fix(dino): drop default HF near-assets; use torch.hub facebookresearch/dinov2
from __future__ import annotations

import importlib.util
import os
import re
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock
# Load trellis/utils/dino_hub.py straight from its file path under a private
# module name (presumably so the tests avoid importing the whole ``trellis``
# package and its dependencies).
MODULE_PATH = Path(__file__).resolve().parents[1] / "trellis" / "utils" / "dino_hub.py"
_SPEC = importlib.util.spec_from_file_location("near_test_dino_hub", MODULE_PATH)
if _SPEC is None or _SPEC.loader is None:
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    raise ImportError(f"Cannot build an import spec for dino_hub at {MODULE_PATH}")
dino_hub = importlib.util.module_from_spec(_SPEC)
_SPEC.loader.exec_module(dino_hub)
def _write_hub_repo(repo_root: Path) -> None:
    """Create a minimal torch.hub-style repository at *repo_root*.

    Parent directories are created as needed and an existing directory is
    reused. The repo contains a single ``hubconf.py`` defining a stub
    ``dinov2_vitl14_reg`` entry point that returns ``None``; the tests use it
    as a stand-in local DINOv2 mirror.
    """
    stub_source = "def dinov2_vitl14_reg(pretrained=True): return None\n"
    hubconf_file = repo_root / "hubconf.py"
    repo_root.mkdir(parents=True, exist_ok=True)
    hubconf_file.write_text(stub_source, encoding="utf-8")
class DinoHubTests(unittest.TestCase):
    """Tests for DINOv2 hub-repo resolution and model loading in ``dino_hub``.

    Each test starts with the environment variables below unset, a fresh
    temporary directory, and an empty ``resolve_dinov2_repo_root`` cache. All
    of that is undone afterwards via ``addCleanup``.
    """

    # Environment variables the tests use to steer ``dino_hub``. They are cleared
    # in setUp so settings from the developer's shell cannot leak into a test.
    _DINO_ENV_KEYS = (
        "NEAR_DINO_LOCAL_REPO",
        "NEAR_AUX_REPO",
        "NEAR_DINO_REPO_SUBDIR",
        "NEAR_DINO_MODEL_ID",
        "NEAR_DINO_FILENAME",
        "NEAR_DINO_TORCH_HUB_REPO",
    )

    def setUp(self) -> None:
        # Cleanups (unlike tearDown) also run when setUp fails part-way, so the
        # temp dir and environment are always restored.
        self._tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmpdir.cleanup)
        self.tmp_path = Path(self._tmpdir.name)

        # patch.dict snapshots os.environ and restores it exactly on stop.
        env_patcher = mock.patch.dict(os.environ)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)
        for key in self._DINO_ENV_KEYS:
            os.environ.pop(key, None)

        # resolve_dinov2_repo_root is cached (it exposes cache_clear); reset it
        # so a result computed under another test's environment is never reused.
        dino_hub.resolve_dinov2_repo_root.cache_clear()
        self.addCleanup(dino_hub.resolve_dinov2_repo_root.cache_clear)

    @staticmethod
    def _make_fake_torch() -> tuple[types.SimpleNamespace, object]:
        """Return ``(fake_torch, sentinel)``; ``fake_torch.hub.load`` returns sentinel."""
        sentinel = object()
        fake_torch = types.SimpleNamespace(
            hub=types.SimpleNamespace(load=mock.Mock(return_value=sentinel))
        )
        return fake_torch, sentinel

    def test_resolve_dinov2_repo_root_prefers_local_override(self) -> None:
        """NEAR_DINO_LOCAL_REPO pointing at a hub repo is returned as a resolved path."""
        repo_root = self.tmp_path / "dinov2-local"
        _write_hub_repo(repo_root)
        os.environ["NEAR_DINO_LOCAL_REPO"] = str(repo_root)

        resolved = dino_hub.resolve_dinov2_repo_root()

        self.assertEqual(resolved, str(repo_root.resolve()))

    def test_resolve_dinov2_repo_root_requires_mirror_when_no_local(self) -> None:
        """With no local or aux mirror configured, resolution raises FileNotFoundError."""
        with self.assertRaises(FileNotFoundError):
            dino_hub.resolve_dinov2_repo_root()

    def test_resolve_dinov2_repo_root_uses_aux_repo_subdir(self) -> None:
        """NEAR_AUX_REPO + NEAR_DINO_REPO_SUBDIR locate the repo inside the aux root."""
        aux_root = self.tmp_path / "near-assets"
        repo_root = aux_root / "dinov2"
        _write_hub_repo(repo_root)
        os.environ["NEAR_AUX_REPO"] = str(aux_root)
        os.environ["NEAR_DINO_REPO_SUBDIR"] = "dinov2"

        resolved = dino_hub.resolve_dinov2_repo_root()

        self.assertEqual(resolved, str(repo_root.resolve()))

    def test_resolve_dinov2_repo_root_validates_optional_filename(self) -> None:
        """A NEAR_DINO_FILENAME missing from the repo is reported by name."""
        repo_root = self.tmp_path / "dinov2-local"
        _write_hub_repo(repo_root)
        os.environ["NEAR_DINO_LOCAL_REPO"] = str(repo_root)
        os.environ["NEAR_DINO_FILENAME"] = "weights/model.bin"

        # re.escape: match the filename literally ('.' would otherwise match any char).
        with self.assertRaisesRegex(FileNotFoundError, re.escape("weights/model.bin")):
            dino_hub.resolve_dinov2_repo_root()

    def test_load_dinov2_model_uses_local_torch_hub_mirror(self) -> None:
        """A local mirror is loaded with source='local' and the NEAR_DINO_MODEL_ID entry."""
        repo_root = self.tmp_path / "dinov2-local"
        _write_hub_repo(repo_root)
        os.environ["NEAR_DINO_LOCAL_REPO"] = str(repo_root)
        os.environ["NEAR_DINO_MODEL_ID"] = "custom_dino_entry"
        fake_torch, sentinel = self._make_fake_torch()

        # Replace the real torch for the duration of the call only.
        with mock.patch.dict(sys.modules, {"torch": fake_torch}):
            loaded = dino_hub.load_dinov2_model("dinov2_vitl14_reg")

        # The loader should hand back exactly what torch.hub.load produced.
        self.assertIs(loaded, sentinel)
        fake_torch.hub.load.assert_called_once_with(
            repo_or_dir=str(repo_root.resolve()),
            model="custom_dino_entry",
            source="local",
            pretrained=True,
        )

    def test_load_dinov2_model_torch_hub_github_when_no_mirror_env(self) -> None:
        """Without mirror env vars, the model comes from facebookresearch/dinov2."""
        fake_torch, sentinel = self._make_fake_torch()

        with mock.patch.dict(sys.modules, {"torch": fake_torch}):
            loaded = dino_hub.load_dinov2_model("dinov2_vitl14_reg")

        self.assertIs(loaded, sentinel)
        fake_torch.hub.load.assert_called_once_with(
            "facebookresearch/dinov2",
            "dinov2_vitl14_reg",
            pretrained=True,
        )
if __name__ == "__main__":
    # Running this file directly discovers and runs the test cases defined in it.
    unittest.main(verbosity=1)