fix(dino): drop default HF near-assets; use torch.hub facebookresearch/dinov2
Browse files- No snapshot_download unless NEAR_AUX_REPO or NEAR_DINO_LOCAL_REPO is set
- Optional NEAR_DINO_TORCH_HUB_REPO overrides github repo id
- Tests: github default path; resolve() without mirror raises
Made-with: Cursor
- tests/test_dino_hub.py +17 -1
- trellis/utils/dino_hub.py +57 -14
tests/test_dino_hub.py
CHANGED
|
@@ -33,6 +33,7 @@ class DinoHubTests(unittest.TestCase):
|
|
| 33 |
"NEAR_DINO_REPO_SUBDIR",
|
| 34 |
"NEAR_DINO_MODEL_ID",
|
| 35 |
"NEAR_DINO_FILENAME",
|
|
|
|
| 36 |
):
|
| 37 |
os.environ.pop(key, None)
|
| 38 |
dino_hub.resolve_dinov2_repo_root.cache_clear()
|
|
@@ -47,12 +48,15 @@ class DinoHubTests(unittest.TestCase):
|
|
| 47 |
repo_root = self.tmp_path / "dinov2-local"
|
| 48 |
_write_hub_repo(repo_root)
|
| 49 |
os.environ["NEAR_DINO_LOCAL_REPO"] = str(repo_root)
|
| 50 |
-
os.environ["NEAR_AUX_REPO"] = "luh0502/near-assets"
|
| 51 |
|
| 52 |
resolved = dino_hub.resolve_dinov2_repo_root()
|
| 53 |
|
| 54 |
self.assertEqual(resolved, str(repo_root.resolve()))
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
def test_resolve_dinov2_repo_root_uses_aux_repo_subdir(self) -> None:
|
| 57 |
aux_root = self.tmp_path / "near-assets"
|
| 58 |
repo_root = aux_root / "dinov2"
|
|
@@ -91,6 +95,18 @@ class DinoHubTests(unittest.TestCase):
|
|
| 91 |
pretrained=True,
|
| 92 |
)
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
if __name__ == "__main__":
|
| 96 |
unittest.main()
|
|
|
|
| 33 |
"NEAR_DINO_REPO_SUBDIR",
|
| 34 |
"NEAR_DINO_MODEL_ID",
|
| 35 |
"NEAR_DINO_FILENAME",
|
| 36 |
+
"NEAR_DINO_TORCH_HUB_REPO",
|
| 37 |
):
|
| 38 |
os.environ.pop(key, None)
|
| 39 |
dino_hub.resolve_dinov2_repo_root.cache_clear()
|
|
|
|
| 48 |
repo_root = self.tmp_path / "dinov2-local"
|
| 49 |
_write_hub_repo(repo_root)
|
| 50 |
os.environ["NEAR_DINO_LOCAL_REPO"] = str(repo_root)
|
|
|
|
| 51 |
|
| 52 |
resolved = dino_hub.resolve_dinov2_repo_root()
|
| 53 |
|
| 54 |
self.assertEqual(resolved, str(repo_root.resolve()))
|
| 55 |
|
| 56 |
+
def test_resolve_dinov2_repo_root_requires_mirror_when_no_local(self) -> None:
|
| 57 |
+
with self.assertRaises(FileNotFoundError):
|
| 58 |
+
dino_hub.resolve_dinov2_repo_root()
|
| 59 |
+
|
| 60 |
def test_resolve_dinov2_repo_root_uses_aux_repo_subdir(self) -> None:
|
| 61 |
aux_root = self.tmp_path / "near-assets"
|
| 62 |
repo_root = aux_root / "dinov2"
|
|
|
|
| 95 |
pretrained=True,
|
| 96 |
)
|
| 97 |
|
| 98 |
+
def test_load_dinov2_model_torch_hub_github_when_no_mirror_env(self) -> None:
|
| 99 |
+
fake_torch = types.SimpleNamespace(hub=types.SimpleNamespace(load=mock.Mock(return_value=object())))
|
| 100 |
+
with mock.patch.dict(sys.modules, {"torch": fake_torch}):
|
| 101 |
+
loaded = dino_hub.load_dinov2_model("dinov2_vitl14_reg")
|
| 102 |
+
|
| 103 |
+
self.assertIsNotNone(loaded)
|
| 104 |
+
fake_torch.hub.load.assert_called_once_with(
|
| 105 |
+
"facebookresearch/dinov2",
|
| 106 |
+
"dinov2_vitl14_reg",
|
| 107 |
+
pretrained=True,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
|
| 111 |
if __name__ == "__main__":
|
| 112 |
unittest.main()
|
trellis/utils/dino_hub.py
CHANGED
|
@@ -31,7 +31,12 @@ def _validate_repo_root(repo_root: Path) -> Path:
|
|
| 31 |
|
| 32 |
@lru_cache(maxsize=1)
|
| 33 |
def resolve_dinov2_repo_root() -> str:
|
| 34 |
-
"""Resolve a local or HF-cached torch.hub-compatible DINOv2 repository.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
t0 = time.time()
|
| 36 |
local_repo = (os.environ.get("NEAR_DINO_LOCAL_REPO") or "").strip()
|
| 37 |
if _is_local_dir(local_repo):
|
|
@@ -39,7 +44,14 @@ def resolve_dinov2_repo_root() -> str:
|
|
| 39 |
print(f"[NeAR] timing dino.resolve_repo_root: {time.time() - t0:.1f}s (source=local)", flush=True)
|
| 40 |
return resolved
|
| 41 |
|
| 42 |
-
aux_repo = (os.environ.get("NEAR_AUX_REPO") or "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
subdir = (os.environ.get("NEAR_DINO_REPO_SUBDIR") or "dinov2").strip().strip("/")
|
| 44 |
if _is_local_dir(aux_repo):
|
| 45 |
local_aux = Path(aux_repo).expanduser()
|
|
@@ -70,30 +82,61 @@ def resolve_dinov2_repo_root() -> str:
|
|
| 70 |
return resolved
|
| 71 |
|
| 72 |
raise FileNotFoundError(
|
| 73 |
-
"Could not locate a torch.hub-compatible DINOv2 repo. "
|
| 74 |
-
"
|
| 75 |
-
|
| 76 |
)
|
| 77 |
|
| 78 |
|
| 79 |
def load_dinov2_model(model_name: str):
|
| 80 |
-
"""Load
|
| 81 |
import torch
|
| 82 |
|
| 83 |
t0 = time.time()
|
| 84 |
-
repo_root = resolve_dinov2_repo_root()
|
| 85 |
resolved_model_name = (os.environ.get("NEAR_DINO_MODEL_ID") or model_name).strip() or model_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
print(
|
| 87 |
-
f"[NeAR] Loading DINOv2
|
| 88 |
flush=True,
|
| 89 |
)
|
| 90 |
t_hub = time.time()
|
| 91 |
-
model = torch.hub.load(
|
| 92 |
-
repo_or_dir=repo_root,
|
| 93 |
-
model=resolved_model_name,
|
| 94 |
-
source="local",
|
| 95 |
-
pretrained=True,
|
| 96 |
-
)
|
| 97 |
print(f"[NeAR] timing dino.torch_hub_load: {time.time() - t_hub:.1f}s", flush=True)
|
| 98 |
print(f"[NeAR] timing dino.load_total: {time.time() - t0:.1f}s", flush=True)
|
| 99 |
return model
|
|
|
|
| 31 |
|
| 32 |
@lru_cache(maxsize=1)
|
| 33 |
def resolve_dinov2_repo_root() -> str:
|
| 34 |
+
"""Resolve a local or HF-cached torch.hub-compatible DINOv2 repository (mirror path only).
|
| 35 |
+
|
| 36 |
+
Requires ``NEAR_DINO_LOCAL_REPO`` (existing directory) or non-empty ``NEAR_AUX_REPO``
|
| 37 |
+
(local directory or Hugging Face ``org/model``). There is no default Hub id here —
|
| 38 |
+
default loading in :func:`load_dinov2_model` uses ``torch.hub`` + GitHub instead.
|
| 39 |
+
"""
|
| 40 |
t0 = time.time()
|
| 41 |
local_repo = (os.environ.get("NEAR_DINO_LOCAL_REPO") or "").strip()
|
| 42 |
if _is_local_dir(local_repo):
|
|
|
|
| 44 |
print(f"[NeAR] timing dino.resolve_repo_root: {time.time() - t0:.1f}s (source=local)", flush=True)
|
| 45 |
return resolved
|
| 46 |
|
| 47 |
+
aux_repo = (os.environ.get("NEAR_AUX_REPO") or "").strip()
|
| 48 |
+
if not aux_repo:
|
| 49 |
+
raise FileNotFoundError(
|
| 50 |
+
"resolve_dinov2_repo_root() needs NEAR_DINO_LOCAL_REPO or NEAR_AUX_REPO. "
|
| 51 |
+
"For normal runs, use load_dinov2_model() without calling this — it loads from "
|
| 52 |
+
"torch.hub (facebookresearch/dinov2) when no mirror env vars are set."
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
subdir = (os.environ.get("NEAR_DINO_REPO_SUBDIR") or "dinov2").strip().strip("/")
|
| 56 |
if _is_local_dir(aux_repo):
|
| 57 |
local_aux = Path(aux_repo).expanduser()
|
|
|
|
| 82 |
return resolved
|
| 83 |
|
| 84 |
raise FileNotFoundError(
|
| 85 |
+
"Could not locate a torch.hub-compatible DINOv2 repo under NEAR_AUX_REPO. "
|
| 86 |
+
f"Checked {aux_repo!r} with subdir {subdir!r}, or unset NEAR_AUX_REPO to use "
|
| 87 |
+
"default torch.hub (facebookresearch/dinov2)."
|
| 88 |
)
|
| 89 |
|
| 90 |
|
| 91 |
def load_dinov2_model(model_name: str):
|
| 92 |
+
"""Load DINOv2: local mirror, optional HF aux mirror, or official torch.hub (GitHub)."""
|
| 93 |
import torch
|
| 94 |
|
| 95 |
t0 = time.time()
|
|
|
|
| 96 |
resolved_model_name = (os.environ.get("NEAR_DINO_MODEL_ID") or model_name).strip() or model_name
|
| 97 |
+
|
| 98 |
+
local_repo = (os.environ.get("NEAR_DINO_LOCAL_REPO") or "").strip()
|
| 99 |
+
if _is_local_dir(local_repo):
|
| 100 |
+
repo_root = str(_validate_repo_root(Path(local_repo).expanduser()))
|
| 101 |
+
print(
|
| 102 |
+
f"[NeAR] Loading DINOv2 backbone {resolved_model_name!r} from local mirror {repo_root}",
|
| 103 |
+
flush=True,
|
| 104 |
+
)
|
| 105 |
+
t_hub = time.time()
|
| 106 |
+
model = torch.hub.load(
|
| 107 |
+
repo_or_dir=repo_root,
|
| 108 |
+
model=resolved_model_name,
|
| 109 |
+
source="local",
|
| 110 |
+
pretrained=True,
|
| 111 |
+
)
|
| 112 |
+
print(f"[NeAR] timing dino.torch_hub_load: {time.time() - t_hub:.1f}s", flush=True)
|
| 113 |
+
print(f"[NeAR] timing dino.load_total: {time.time() - t0:.1f}s", flush=True)
|
| 114 |
+
return model
|
| 115 |
+
|
| 116 |
+
if (os.environ.get("NEAR_AUX_REPO") or "").strip():
|
| 117 |
+
repo_root = resolve_dinov2_repo_root()
|
| 118 |
+
print(
|
| 119 |
+
f"[NeAR] Loading DINOv2 backbone {resolved_model_name!r} from aux mirror {repo_root}",
|
| 120 |
+
flush=True,
|
| 121 |
+
)
|
| 122 |
+
t_hub = time.time()
|
| 123 |
+
model = torch.hub.load(
|
| 124 |
+
repo_or_dir=repo_root,
|
| 125 |
+
model=resolved_model_name,
|
| 126 |
+
source="local",
|
| 127 |
+
pretrained=True,
|
| 128 |
+
)
|
| 129 |
+
print(f"[NeAR] timing dino.torch_hub_load: {time.time() - t_hub:.1f}s", flush=True)
|
| 130 |
+
print(f"[NeAR] timing dino.load_total: {time.time() - t0:.1f}s", flush=True)
|
| 131 |
+
return model
|
| 132 |
+
|
| 133 |
+
hub_repo = (os.environ.get("NEAR_DINO_TORCH_HUB_REPO") or "facebookresearch/dinov2").strip()
|
| 134 |
print(
|
| 135 |
+
f"[NeAR] Loading DINOv2 {resolved_model_name!r} via torch.hub from {hub_repo!r} (github)",
|
| 136 |
flush=True,
|
| 137 |
)
|
| 138 |
t_hub = time.time()
|
| 139 |
+
model = torch.hub.load(hub_repo, resolved_model_name, pretrained=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
print(f"[NeAR] timing dino.torch_hub_load: {time.time() - t_hub:.1f}s", flush=True)
|
| 141 |
print(f"[NeAR] timing dino.load_total: {time.time() - t0:.1f}s", flush=True)
|
| 142 |
return model
|