luh1124 commited on
Commit
94d9e58
·
1 Parent(s): da36f0f

fix(dino): drop default HF near-assets; use torch.hub facebookresearch/dinov2

Browse files

- No snapshot_download unless NEAR_AUX_REPO or NEAR_DINO_LOCAL_REPO is set
- Optional NEAR_DINO_TORCH_HUB_REPO overrides github repo id
- Tests: github default path; resolve() without mirror raises

Made-with: Cursor

Files changed (2) hide show
  1. tests/test_dino_hub.py +17 -1
  2. trellis/utils/dino_hub.py +57 -14
tests/test_dino_hub.py CHANGED
@@ -33,6 +33,7 @@ class DinoHubTests(unittest.TestCase):
33
  "NEAR_DINO_REPO_SUBDIR",
34
  "NEAR_DINO_MODEL_ID",
35
  "NEAR_DINO_FILENAME",
 
36
  ):
37
  os.environ.pop(key, None)
38
  dino_hub.resolve_dinov2_repo_root.cache_clear()
@@ -47,12 +48,15 @@ class DinoHubTests(unittest.TestCase):
47
  repo_root = self.tmp_path / "dinov2-local"
48
  _write_hub_repo(repo_root)
49
  os.environ["NEAR_DINO_LOCAL_REPO"] = str(repo_root)
50
- os.environ["NEAR_AUX_REPO"] = "luh0502/near-assets"
51
 
52
  resolved = dino_hub.resolve_dinov2_repo_root()
53
 
54
  self.assertEqual(resolved, str(repo_root.resolve()))
55
 
 
 
 
 
56
  def test_resolve_dinov2_repo_root_uses_aux_repo_subdir(self) -> None:
57
  aux_root = self.tmp_path / "near-assets"
58
  repo_root = aux_root / "dinov2"
@@ -91,6 +95,18 @@ class DinoHubTests(unittest.TestCase):
91
  pretrained=True,
92
  )
93
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  if __name__ == "__main__":
96
  unittest.main()
 
33
  "NEAR_DINO_REPO_SUBDIR",
34
  "NEAR_DINO_MODEL_ID",
35
  "NEAR_DINO_FILENAME",
36
+ "NEAR_DINO_TORCH_HUB_REPO",
37
  ):
38
  os.environ.pop(key, None)
39
  dino_hub.resolve_dinov2_repo_root.cache_clear()
 
48
  repo_root = self.tmp_path / "dinov2-local"
49
  _write_hub_repo(repo_root)
50
  os.environ["NEAR_DINO_LOCAL_REPO"] = str(repo_root)
 
51
 
52
  resolved = dino_hub.resolve_dinov2_repo_root()
53
 
54
  self.assertEqual(resolved, str(repo_root.resolve()))
55
 
56
+ def test_resolve_dinov2_repo_root_requires_mirror_when_no_local(self) -> None:
57
+ with self.assertRaises(FileNotFoundError):
58
+ dino_hub.resolve_dinov2_repo_root()
59
+
60
  def test_resolve_dinov2_repo_root_uses_aux_repo_subdir(self) -> None:
61
  aux_root = self.tmp_path / "near-assets"
62
  repo_root = aux_root / "dinov2"
 
95
  pretrained=True,
96
  )
97
 
98
+ def test_load_dinov2_model_torch_hub_github_when_no_mirror_env(self) -> None:
99
+ fake_torch = types.SimpleNamespace(hub=types.SimpleNamespace(load=mock.Mock(return_value=object())))
100
+ with mock.patch.dict(sys.modules, {"torch": fake_torch}):
101
+ loaded = dino_hub.load_dinov2_model("dinov2_vitl14_reg")
102
+
103
+ self.assertIsNotNone(loaded)
104
+ fake_torch.hub.load.assert_called_once_with(
105
+ "facebookresearch/dinov2",
106
+ "dinov2_vitl14_reg",
107
+ pretrained=True,
108
+ )
109
+
110
 
111
  if __name__ == "__main__":
112
  unittest.main()
trellis/utils/dino_hub.py CHANGED
@@ -31,7 +31,12 @@ def _validate_repo_root(repo_root: Path) -> Path:
31
 
32
  @lru_cache(maxsize=1)
33
  def resolve_dinov2_repo_root() -> str:
34
- """Resolve a local or HF-cached torch.hub-compatible DINOv2 repository."""
 
 
 
 
 
35
  t0 = time.time()
36
  local_repo = (os.environ.get("NEAR_DINO_LOCAL_REPO") or "").strip()
37
  if _is_local_dir(local_repo):
@@ -39,7 +44,14 @@ def resolve_dinov2_repo_root() -> str:
39
  print(f"[NeAR] timing dino.resolve_repo_root: {time.time() - t0:.1f}s (source=local)", flush=True)
40
  return resolved
41
 
42
- aux_repo = (os.environ.get("NEAR_AUX_REPO") or "luh0502/near-assets").strip()
 
 
 
 
 
 
 
43
  subdir = (os.environ.get("NEAR_DINO_REPO_SUBDIR") or "dinov2").strip().strip("/")
44
  if _is_local_dir(aux_repo):
45
  local_aux = Path(aux_repo).expanduser()
@@ -70,30 +82,61 @@ def resolve_dinov2_repo_root() -> str:
70
  return resolved
71
 
72
  raise FileNotFoundError(
73
- "Could not locate a torch.hub-compatible DINOv2 repo. "
74
- "Set NEAR_DINO_LOCAL_REPO to a local mirror, or publish one under "
75
- f"{aux_repo!r} with subdir {subdir!r}."
76
  )
77
 
78
 
79
  def load_dinov2_model(model_name: str):
80
- """Load a DINOv2 backbone from a local or HF-cached mirror instead of GitHub."""
81
  import torch
82
 
83
  t0 = time.time()
84
- repo_root = resolve_dinov2_repo_root()
85
  resolved_model_name = (os.environ.get("NEAR_DINO_MODEL_ID") or model_name).strip() or model_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  print(
87
- f"[NeAR] Loading DINOv2 backbone {resolved_model_name!r} from local/HF mirror {repo_root}",
88
  flush=True,
89
  )
90
  t_hub = time.time()
91
- model = torch.hub.load(
92
- repo_or_dir=repo_root,
93
- model=resolved_model_name,
94
- source="local",
95
- pretrained=True,
96
- )
97
  print(f"[NeAR] timing dino.torch_hub_load: {time.time() - t_hub:.1f}s", flush=True)
98
  print(f"[NeAR] timing dino.load_total: {time.time() - t0:.1f}s", flush=True)
99
  return model
 
31
 
32
  @lru_cache(maxsize=1)
33
  def resolve_dinov2_repo_root() -> str:
34
+ """Resolve a local or HF-cached torch.hub-compatible DINOv2 repository (mirror path only).
35
+
36
+ Requires ``NEAR_DINO_LOCAL_REPO`` (existing directory) or non-empty ``NEAR_AUX_REPO``
37
+ (local directory or Hugging Face ``org/model``). There is no default Hub id here —
38
+ default loading in :func:`load_dinov2_model` uses ``torch.hub`` + GitHub instead.
39
+ """
40
  t0 = time.time()
41
  local_repo = (os.environ.get("NEAR_DINO_LOCAL_REPO") or "").strip()
42
  if _is_local_dir(local_repo):
 
44
  print(f"[NeAR] timing dino.resolve_repo_root: {time.time() - t0:.1f}s (source=local)", flush=True)
45
  return resolved
46
 
47
+ aux_repo = (os.environ.get("NEAR_AUX_REPO") or "").strip()
48
+ if not aux_repo:
49
+ raise FileNotFoundError(
50
+ "resolve_dinov2_repo_root() needs NEAR_DINO_LOCAL_REPO or NEAR_AUX_REPO. "
51
+ "For normal runs, use load_dinov2_model() without calling this — it loads from "
52
+ "torch.hub (facebookresearch/dinov2) when no mirror env vars are set."
53
+ )
54
+
55
  subdir = (os.environ.get("NEAR_DINO_REPO_SUBDIR") or "dinov2").strip().strip("/")
56
  if _is_local_dir(aux_repo):
57
  local_aux = Path(aux_repo).expanduser()
 
82
  return resolved
83
 
84
  raise FileNotFoundError(
85
+ "Could not locate a torch.hub-compatible DINOv2 repo under NEAR_AUX_REPO. "
86
+ f"Checked {aux_repo!r} with subdir {subdir!r}, or unset NEAR_AUX_REPO to use "
87
+ "default torch.hub (facebookresearch/dinov2)."
88
  )
89
 
90
 
91
  def load_dinov2_model(model_name: str):
92
+ """Load DINOv2: local mirror, optional HF aux mirror, or official torch.hub (GitHub)."""
93
  import torch
94
 
95
  t0 = time.time()
 
96
  resolved_model_name = (os.environ.get("NEAR_DINO_MODEL_ID") or model_name).strip() or model_name
97
+
98
+ local_repo = (os.environ.get("NEAR_DINO_LOCAL_REPO") or "").strip()
99
+ if _is_local_dir(local_repo):
100
+ repo_root = str(_validate_repo_root(Path(local_repo).expanduser()))
101
+ print(
102
+ f"[NeAR] Loading DINOv2 backbone {resolved_model_name!r} from local mirror {repo_root}",
103
+ flush=True,
104
+ )
105
+ t_hub = time.time()
106
+ model = torch.hub.load(
107
+ repo_or_dir=repo_root,
108
+ model=resolved_model_name,
109
+ source="local",
110
+ pretrained=True,
111
+ )
112
+ print(f"[NeAR] timing dino.torch_hub_load: {time.time() - t_hub:.1f}s", flush=True)
113
+ print(f"[NeAR] timing dino.load_total: {time.time() - t0:.1f}s", flush=True)
114
+ return model
115
+
116
+ if (os.environ.get("NEAR_AUX_REPO") or "").strip():
117
+ repo_root = resolve_dinov2_repo_root()
118
+ print(
119
+ f"[NeAR] Loading DINOv2 backbone {resolved_model_name!r} from aux mirror {repo_root}",
120
+ flush=True,
121
+ )
122
+ t_hub = time.time()
123
+ model = torch.hub.load(
124
+ repo_or_dir=repo_root,
125
+ model=resolved_model_name,
126
+ source="local",
127
+ pretrained=True,
128
+ )
129
+ print(f"[NeAR] timing dino.torch_hub_load: {time.time() - t_hub:.1f}s", flush=True)
130
+ print(f"[NeAR] timing dino.load_total: {time.time() - t0:.1f}s", flush=True)
131
+ return model
132
+
133
+ hub_repo = (os.environ.get("NEAR_DINO_TORCH_HUB_REPO") or "facebookresearch/dinov2").strip()
134
  print(
135
+ f"[NeAR] Loading DINOv2 {resolved_model_name!r} via torch.hub from {hub_repo!r} (github)",
136
  flush=True,
137
  )
138
  t_hub = time.time()
139
+ model = torch.hub.load(hub_repo, resolved_model_name, pretrained=True)
 
 
 
 
 
140
  print(f"[NeAR] timing dino.torch_hub_load: {time.time() - t_hub:.1f}s", flush=True)
141
  print(f"[NeAR] timing dino.load_total: {time.time() - t0:.1f}s", flush=True)
142
  return model