Update inference_tagger_standalone.py
Fix backbone state-dict loading: remap backbone.model.layer.* -> backbone.layer.*
The checkpoint stores the 32 transformer blocks under backbone.model.layer.N.* (HF-style, with an intermediate model wrapper), but DINOv3ViTH in this script declares them at backbone.layer.N.*. Combined with strict=False, assign=True in load_state_dict, all 608 block parameters (32 layers × 19 tensors per block) were silently failing to load: the backbone ran on default nn.Linear / nn.LayerNorm initializations while only the head loaded correctly. The only hint was a printed [Tagger] Missing keys (608): ['backbone.layer.0.layer_scale1', ...] line that was easy to miss, and the model produced plausible-looking but essentially random tag predictions, making it feel like undertraining.
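For illustration, a minimal sketch of the failure mode with a toy module and toy key names (not the real classes): load_state_dict(strict=False) reports mismatches in its return value instead of raising, so ignoring that return value is what makes the mis-load silent.

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.ModuleList([nn.Linear(4, 4)])  # declares "layer.0.*"

toy = Toy()
# The checkpoint stores the same tensors under an extra "model." wrapper.
ckpt = {"model.layer.0.weight": torch.randn(4, 4),
        "model.layer.0.bias": torch.randn(4)}

report = toy.load_state_dict(ckpt, strict=False)  # no exception raised
print(report.missing_keys)     # ['layer.0.weight', 'layer.0.bias']
print(report.unexpected_keys)  # ['model.layer.0.weight', 'model.layer.0.bias']
# toy.layer[0] keeps its random default init, exactly like the 608
# backbone block tensors did here.
```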
Confirmed by dumping the tagger_proto.safetensors keys: they're all under backbone.model.layer.N.*, and the head is a single projection.weight of shape (74625, 6400).
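The dump itself is a few lines with safetensors' safe_open (a sketch; the key counts come from the layout described above):

```python
from safetensors import safe_open

# Enumerate tensor names without materialising any tensors.
with safe_open("tagger_proto.safetensors", framework="pt") as f:
    keys = sorted(f.keys())

print(len(keys))  # 616 (5 embedding + 608 block + 2 final norm + 1 head)
print([k for k in keys if not k.startswith("backbone.")])  # ['projection.weight']
print([k for k in keys if ".layer." in k][:2])  # backbone.model.layer.0.* ...
```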
Changes:
Strip the intermediate "model." segment from backbone keys during loading, so backbone.model.layer.N.* maps to self.layer[N].* correctly.
Load both backbone and head with strict=True, so any future name/shape drift fails loudly at load time instead of silently returning noise.
Auto-detect head layout (currently a single Linear) so this class of silent mis-load can't recur if the head changes later.
Minor: consistent aspect-ratio preservation in preprocessing, torch.zeros instead of torch.empty for embedding parameters, and dropping the redundant torch.autocast wrapper (the backbone is explicitly cast to bf16; the head stays fp32 per the training recipe).
Verified by running the loader against a synthesized state dict matching the real key layout (616 keys: 5 embedding + 608 block + 2 final norm + 1 head): strict load passes and a forward pass returns the right logit shape. Also confirmed by another user who hit the same bug and fixed it by remapping the keys, reporting that outputs went "from horrifically bad to pretty much perfect."
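A sketch of that self-check, assuming the DINOv3Tagger, _build_head_from_checkpoint, and _split_and_clean_state_dict definitions from the diff below (synth_checkpoint is a hypothetical helper written just for the test):

```python
import torch

def synth_checkpoint(model: "DINOv3Tagger") -> dict:
    """Fake state dict that mimics the real checkpoint's key layout."""
    sd = {}
    for k, v in model.backbone.state_dict().items():
        nk = "backbone." + k
        if k.startswith("layer."):
            nk = "backbone.model." + k      # HF-style "model." wrapper
        if ".layer_scale" in nk:
            nk += ".lambda1"                # layer_scale stored as sub-module
        sd[nk] = torch.randn_like(v)
    sd["projection.weight"] = torch.randn(74625, 6400)  # single-linear head
    return sd

model = DINOv3Tagger()
backbone_sd, head_sd = _split_and_clean_state_dict(synth_checkpoint(model))
model.backbone.load_state_dict(backbone_sd, strict=True)  # raises on any drift
head, head_sd_remapped = _build_head_from_checkpoint(
    head_sd, in_dim=6400, num_tags=74625)
head.load_state_dict(head_sd_remapped, strict=True)
```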
- inference_tagger_standalone.py +325 -145
@@ -64,17 +64,19 @@ from safetensors.torch import load_file
 # All hyperparameters match facebook/dinov3-vith16plus-pretrain-lvd1689m
 # =============================================================================
 
+D_MODEL = 1280
+N_HEADS = 20
+HEAD_DIM = D_MODEL // N_HEADS  # 64
+N_LAYERS = 32
+D_FFN = 5120
 N_REGISTERS = 4
+PATCH_SIZE = 16
+ROPE_THETA = 100.0
+ROPE_RESCALE = 2.0
+LN_EPS = 1e-5
+LAYERSCALE = 1.0
+
+FEATURE_DIM = (1 + N_REGISTERS) * D_MODEL  # 6400
 
 
 # ---------------------------------------------------------------------------
@@ -83,25 +85,23 @@ LAYERSCALE = 1.0
 
 @lru_cache(maxsize=32)
 def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
-    """Normalised [-1,+1] patch-centre coordinates (float32, cached)."""
     device = torch.device(device_str)
     cy = torch.arange(0.5, h, dtype=torch.float32, device=device) / h
     cx = torch.arange(0.5, w, dtype=torch.float32, device=device) / w
     coords = torch.stack(torch.meshgrid(cy, cx, indexing="ij"), dim=-1).flatten(0, 1)
+    coords = 2.0 * coords - 1.0
     coords = coords * ROPE_RESCALE
     return coords  # [h*w, 2]
 
 
 def _build_rope(h_patches: int, w_patches: int,
                 dtype: torch.dtype, device: torch.device):
-
-    coords = _patch_coords_cached(h_patches, w_patches, str(device))  # [P, 2]
+    coords = _patch_coords_cached(h_patches, w_patches, str(device))
     inv_freq = 1.0 / (ROPE_THETA ** torch.arange(
+        0, 1, 4 / HEAD_DIM, dtype=torch.float32, device=device))
+    angles = 2 * math.pi * coords[:, :, None] * inv_freq[None, None, :]
+    angles = angles.flatten(1, 2).tile(2)
+    cos = torch.cos(angles).to(dtype).unsqueeze(0).unsqueeze(0)
     sin = torch.sin(angles).to(dtype).unsqueeze(0).unsqueeze(0)
     return cos, sin
@@ -113,7 +113,6 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
 
 def _apply_rope(q: torch.Tensor, k: torch.Tensor,
                 cos: torch.Tensor, sin: torch.Tensor):
-    """Apply RoPE only to patch tokens (skip CLS + register prefix)."""
     n_pre = 1 + N_REGISTERS
     q_pre, q_pat = q[..., :n_pre, :], q[..., n_pre:, :]
     k_pre, k_pat = k[..., :n_pre, :], k[..., n_pre:, :]
@@ -123,7 +122,7 @@ def _apply_rope(q: torch.Tensor, k: torch.Tensor,
 
 
 # ---------------------------------------------------------------------------
+# Transformer blocks
 # ---------------------------------------------------------------------------
 
 class _Attention(nn.Module):
@@ -134,7 +133,7 @@ class _Attention(nn.Module):
         self.v_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
         self.o_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
 
+    def forward(self, x, cos, sin):
         B, S, _ = x.shape
         q = self.q_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
         k = self.k_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
@@ -148,125 +147,259 @@ class _GatedMLP(nn.Module):
     def __init__(self):
         super().__init__()
         self.gate_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
+        self.up_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
+        self.down_proj = nn.Linear(D_FFN, D_MODEL, bias=True)
 
+    def forward(self, x):
         return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
 
 
 class _Block(nn.Module):
     def __init__(self):
         super().__init__()
+        self.norm1 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
+        self.attention = _Attention()
         self.layer_scale1 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
+        self.norm2 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
+        self.mlp = _GatedMLP()
         self.layer_scale2 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
 
+    def forward(self, x, cos, sin):
         x = x + self.attention(self.norm1(x), cos, sin) * self.layer_scale1
         x = x + self.mlp(self.norm2(x)) * self.layer_scale2
         return x
 
 
+class _Embeddings(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # zeros() rather than empty() so a forgotten checkpoint key fails
+        # predictably instead of producing undefined outputs.
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
+        self.register_tokens = nn.Parameter(torch.zeros(1, N_REGISTERS, D_MODEL))
+        self.patch_embeddings = nn.Conv2d(
+            3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
+
+    def forward(self, pixel_values):
+        B = pixel_values.shape[0]
+        dtype = self.patch_embeddings.weight.dtype
+        patches = self.patch_embeddings(
+            pixel_values.to(dtype)).flatten(2).transpose(1, 2)
+        cls = self.cls_token.expand(B, -1, -1)
+        regs = self.register_tokens.expand(B, -1, -1)
+        return torch.cat([cls, regs, patches], dim=1)
+
 
 class DINOv3ViTH(nn.Module):
     """DINOv3 ViT-H/16+ backbone.
 
-    Accepts any H, W that are multiples of 16.
-    Returns last_hidden_state [B, 1+R+P, D_MODEL].
-
-    State-dict keys are intentionally identical to the HuggingFace
-    transformers layout so .safetensors checkpoints load without remapping.
     Token layout: [CLS, reg_0..reg_3, patch_0..patch_N].
+    Returns last_hidden_state [B, 1+R+P, D_MODEL].
     """
 
     def __init__(self):
         super().__init__()
-        # These names must match HF exactly
         self.embeddings = _Embeddings()
         self.layer = nn.ModuleList([_Block() for _ in range(N_LAYERS)])
+        self.norm = nn.LayerNorm(D_MODEL, eps=LN_EPS)
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
-                              strict, missing_keys, unexpected_keys, error_msgs):
-        # HF stores layer_scale as a sub-module with a "lambda1" parameter;
-        # we store it as a plain Parameter directly on _Block.
-        # Remap "layer.i.layer_scale{1,2}.lambda1" -> "layer.i.layer_scale{1,2}"
-        for k in list(state_dict.keys()):
-            if k.startswith(prefix) and ".layer_scale" in k and k.endswith(".lambda1"):
-                new_k = k[:-len(".lambda1")]
-                state_dict[new_k] = state_dict.pop(k)
-        # Drop rope_embeddings buffer (computed on-the-fly)
-        for k in list(state_dict.keys()):
-            if k.startswith(prefix) and "rope_embeddings" in k:
-                state_dict.pop(k)
-        super()._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict,
-            missing_keys, unexpected_keys, error_msgs)
-
-    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
-        B, _, H, W = pixel_values.shape
-        x = self.embeddings(pixel_values)  # [B, 1+R+P, D]
+    def forward(self, pixel_values):
+        _, _, H, W = pixel_values.shape
+        x = self.embeddings(pixel_values)
         h_p, w_p = H // PATCH_SIZE, W // PATCH_SIZE
         cos, sin = _build_rope(h_p, w_p, x.dtype, pixel_values.device)
         for block in self.layer:
             x = block(x, cos, sin)
         return self.norm(x)
 
 
-class _Embeddings(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.cls_token = nn.Parameter(torch.empty(1, 1, D_MODEL))
-        self.mask_token = nn.Parameter(torch.empty(1, 1, D_MODEL))
-        self.register_tokens = nn.Parameter(torch.empty(1, N_REGISTERS, D_MODEL))
-        self.patch_embeddings = nn.Conv2d(3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
+# =============================================================================
+# Head (auto-detected from the checkpoint)
+# =============================================================================
+
+class _LowRankHead(nn.Module):
+    """Two-matrix low-rank projection head.
+
+    features (in_dim)
+      -> Linear(in_dim, rank, bias=?)
+      -> Linear(rank, num_tags, bias=?)
+    """
+
+    def __init__(self, in_dim: int, rank: int, num_tags: int,
+                 down_bias: bool, up_bias: bool):
+        super().__init__()
+        self.proj_down = nn.Linear(in_dim, rank, bias=down_bias)
+        self.proj_up = nn.Linear(rank, num_tags, bias=up_bias)
+
+    def forward(self, x):
+        return self.proj_up(self.proj_down(x))
+
+
+def _build_head_from_checkpoint(
+    head_sd: dict,
+    in_dim: int,
+    num_tags: int,
+) -> tuple[nn.Module, dict]:
+    """Inspect head_sd and build a matching Module.
+
+    Supports two layouts, in order of preference:
+      1. Single linear: any ``*.weight`` with shape [num_tags, in_dim]
+      2. Low-rank pair (two matrices): one ``*.weight`` [rank, in_dim]
+         plus one ``*.weight`` [num_tags, rank]
+
+    Returns (module, remapped_state_dict) where the remapped state dict
+    matches the module's own key names so strict loading works.
+    """
+    weights_2d = [(k, v) for k, v in head_sd.items()
+                  if k.endswith(".weight") and v.ndim == 2]
+
+    # --- Case 1: single dense linear ---------------------------------------
+    singles = [(k, v) for k, v in weights_2d
+               if tuple(v.shape) == (num_tags, in_dim)]
+    if len(weights_2d) <= 2 and len(singles) == 1:
+        wkey, wval = singles[0]
+        base = wkey[:-len(".weight")]
+        bias_key = base + ".bias"
+        has_bias = bias_key in head_sd
+        module = nn.Linear(in_dim, num_tags, bias=has_bias)
+        remapped = {"weight": wval}
+        if has_bias:
+            remapped["bias"] = head_sd[bias_key]
+        # Sanity check: no extra keys we don't understand
+        expected_src = {wkey} | ({bias_key} if has_bias else set())
+        extra = set(head_sd) - expected_src
+        if extra:
+            raise RuntimeError(
+                f"Head has single-linear shape but extra unknown keys: {sorted(extra)}")
+        return module, remapped
+
+    # --- Case 2: low-rank pair ---------------------------------------------
+    down = None  # (key, tensor) with shape [rank, in_dim]
+    up = None    # (key, tensor) with shape [num_tags, rank]
+    for k, v in weights_2d:
+        if v.shape[1] == in_dim and v.shape[0] != num_tags:
+            down = (k, v)
+        elif v.shape[0] == num_tags and v.shape[1] != in_dim:
+            up = (k, v)
+
+    if down is not None and up is not None:
+        rank_down = down[1].shape[0]
+        rank_up = up[1].shape[1]
+        if rank_down != rank_up:
+            raise RuntimeError(
+                f"Low-rank head: inner dims disagree "
+                f"(down out={rank_down}, up in={rank_up})")
+
+        down_key, down_w = down
+        up_key, up_w = up
+        down_base = down_key[:-len(".weight")]
+        up_base = up_key[:-len(".weight")]
+        down_bias_key = down_base + ".bias"
+        up_bias_key = up_base + ".bias"
+        has_down_bias = down_bias_key in head_sd
+        has_up_bias = up_bias_key in head_sd
+
+        module = _LowRankHead(
+            in_dim=in_dim,
+            rank=rank_down,
+            num_tags=num_tags,
+            down_bias=has_down_bias,
+            up_bias=has_up_bias,
+        )
+        remapped = {
+            "proj_down.weight": down_w,
+            "proj_up.weight": up_w,
+        }
+        if has_down_bias:
+            remapped["proj_down.bias"] = head_sd[down_bias_key]
+        if has_up_bias:
+            remapped["proj_up.bias"] = head_sd[up_bias_key]
+
+        # Sanity check
+        expected_src = {down_key, up_key}
+        if has_down_bias:
+            expected_src.add(down_bias_key)
+        if has_up_bias:
+            expected_src.add(up_bias_key)
+        extra = set(head_sd) - expected_src
+        if extra:
+            raise RuntimeError(
+                f"Low-rank head detected but checkpoint has extra unknown "
+                f"head keys: {sorted(extra)}")
+
+        print(f"[Tagger] Detected low-rank head: "
+              f"in_dim={in_dim}, rank={rank_down}, num_tags={num_tags} "
+              f"(down_bias={has_down_bias}, up_bias={has_up_bias})")
+        return module, remapped
+
+    raise RuntimeError(
+        "Could not infer head architecture from checkpoint. "
+        f"Non-backbone keys found: {sorted(head_sd.keys())}"
+    )
 
 
 # =============================================================================
-# Tagger
+# Tagger wrapper module
 # =============================================================================
 
 class DINOv3Tagger(nn.Module):
+    """Backbone + head. The head is attached after the checkpoint is
+    inspected (so we can build the right shape)."""
+
+    def __init__(self):
+        super().__init__()
+        self.backbone = DINOv3ViTH()
+        self.head: nn.Module | None = None  # attached by Tagger
+
+    def forward(self, pixel_values):
+        hidden = self.backbone(pixel_values)
+        cls = hidden[:, 0, :]
+        regs = hidden[:, 1:1 + N_REGISTERS, :].flatten(1)
+        features = torch.cat([cls, regs], dim=-1).float()  # fp32 for head
+        return self.head(features)
+
+
+# =============================================================================
+# Checkpoint loading helpers
+# =============================================================================
+
+def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]:
+    """Split the full state dict into (backbone_sd, head_sd), stripping the
+    ``backbone.`` prefix and applying the remaps needed to match
+    ``DINOv3ViTH``'s parameter layout:
+
+    1. ``backbone.model.layer.N.*`` -> ``layer.N.*``
+       (the checkpoint has an HF-style intermediate ``model`` wrapper
+       that our flat backbone class does not)
+    2. ``...layer_scale{1,2}.lambda1`` -> ``...layer_scale{1,2}``
+       (HF stores layer_scale as a sub-module with a ``lambda1``
+       parameter; we use a plain ``nn.Parameter``)
+    3. Drop any ``rope_embeddings`` buffers (recomputed on the fly)
+    """
+    backbone_sd: dict = {}
+    head_sd: dict = {}
+    for k, v in sd.items():
+        if k.startswith("backbone."):
+            nk = k[len("backbone."):]
+            # Remap (1): strip intermediate "model." before "layer."
+            if nk.startswith("model.layer."):
+                nk = nk[len("model."):]
+            backbone_sd[nk] = v
+        else:
+            head_sd[k] = v
+
+    # Remap (2): layer.N.layer_scale{1,2}.lambda1 -> layer.N.layer_scale{1,2}
+    for k in list(backbone_sd.keys()):
+        if ".layer_scale" in k and k.endswith(".lambda1"):
+            backbone_sd[k[:-len(".lambda1")]] = backbone_sd.pop(k)
+
+    # Remap (3): drop rope buffers (recomputed on the fly)
+    for k in list(backbone_sd.keys()):
+        if "rope_embeddings" in k:
+            backbone_sd.pop(k)
+
+    return backbone_sd, head_sd
 
 
 # =============================================================================
@@ -274,7 +407,7 @@ class DINOv3Tagger(nn.Module):
 # =============================================================================
 
 _IMAGENET_MEAN = [0.485, 0.456, 0.406]
+_IMAGENET_STD = [0.229, 0.224, 0.225]
 
 
 def _snap(x: int, m: int) -> int:
@@ -291,12 +424,22 @@ def _open_image(source) -> Image.Image:
 
 
 def preprocess_image(source, max_size: int = 1024) -> torch.Tensor:
+    """Load and preprocess an image -> [1, 3, H, W] float32, ImageNet-normalised.
+
+    Aspect ratio is preserved: a single scale factor is chosen so that the
+    long edge fits inside max_size after snapping to a PATCH_SIZE multiple.
+    """
     img = _open_image(source)
     w, h = img.size
+
+    # Target long edge (snapped to a patch multiple).
+    long_edge = max(w, h)
+    target_long = _snap(min(long_edge, max_size), PATCH_SIZE)
+    scale = target_long / long_edge
+
+    new_w = _snap(max(PATCH_SIZE, round(w * scale)), PATCH_SIZE)
+    new_h = _snap(max(PATCH_SIZE, round(h * scale)), PATCH_SIZE)
+
     return v2.Compose([
         v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
         v2.ToImage(),
@@ -315,13 +458,15 @@ class Tagger:
     Parameters
    ----------
     checkpoint_path : str
-        Path to a .safetensors or .pth checkpoint
+        Path to a .safetensors or .pt/.pth checkpoint.
     vocab_path : str
-        Path to tagger_vocab.json
+        Path to tagger_vocab.json or tagger_vocab_with_categories.json
+        (either must contain an ``idx2tag`` list).
     device : str
+        "cuda", "cuda:0", "cpu", ...
     dtype : torch.dtype
+        Backbone precision. bfloat16 recommended on Ampere+, float16 for
+        older GPUs, float32 for CPU. The head always runs in fp32.
     max_size : int
         Long-edge cap in pixels before feeding to the model.
     """
@@ -334,8 +479,13 @@ class Tagger:
         dtype: torch.dtype = torch.bfloat16,
         max_size: int = 1024,
     ):
+        want_cuda = device.startswith("cuda")
+        if want_cuda and not torch.cuda.is_available():
+            print("[Tagger] CUDA not available, falling back to CPU")
+            device = "cpu"
+            dtype = torch.float32
+        self.device = torch.device(device)
+        self.dtype = dtype
         self.max_size = max_size
 
         with open(vocab_path) as f:
@@ -344,36 +494,47 @@ class Tagger:
         self.num_tags = len(self.idx2tag)
         print(f"[Tagger] Vocabulary: {self.num_tags:,} tags")
 
+        # --- Load checkpoint to CPU first so we can inspect shapes ---------
         print(f"[Tagger] Loading checkpoint: {checkpoint_path}")
         if checkpoint_path.endswith((".safetensors", ".sft")):
+            sd = load_file(checkpoint_path, device="cpu")
         else:
+            sd = torch.load(checkpoint_path, map_location="cpu")
+
+        backbone_sd, head_sd = _split_and_clean_state_dict(sd)
+
+        if not head_sd:
+            raise RuntimeError(
+                "Checkpoint contains no non-backbone keys -- cannot build head.")
+
+        # --- Build model, inferring head shape from the checkpoint ---------
+        self.model = DINOv3Tagger()
+        head_module, head_sd_remapped = _build_head_from_checkpoint(
+            head_sd, in_dim=FEATURE_DIM, num_tags=self.num_tags,
+        )
+        self.model.head = head_module
+
+        # --- Strict load: mismatches raise instead of silently passing -----
+        self.model.backbone.load_state_dict(backbone_sd, strict=True)
+        self.model.head.load_state_dict(head_sd_remapped, strict=True)
+
+        # --- Move to device. Backbone -> bf16/fp16; head stays fp32. -------
+        self.model.backbone = self.model.backbone.to(
+            device=self.device, dtype=dtype)
+        self.model.head = self.model.head.to(
+            device=self.device, dtype=torch.float32)
         self.model.eval()
-        print(f"[Tagger] Ready on {self.device} ({dtype})")
+        print(f"[Tagger] Ready on {self.device} (backbone={dtype}, head=fp32)")
 
     @torch.no_grad()
     def predict(self, image, topk: int | None = 30,
                 threshold: float | None = None) -> list[tuple[str, float]]:
-        """Tag a single image (local path or URL).
-        Specify either topk OR threshold. Returns [(tag, score), ...] desc."""
+        """Tag a single image (local path or URL)."""
        if topk is None and threshold is None:
            topk = 30
 
         pv = preprocess_image(image, max_size=self.max_size).to(self.device)
+        logits = self.model(pv)[0]
         scores = torch.sigmoid(logits.float())
 
         if topk is not None:
@@ -381,17 +542,18 @@ class Tagger:
         else:
             assert threshold is not None
             indices = (scores >= threshold).nonzero(as_tuple=True)[0]
+            values = scores[indices]
+            order = values.argsort(descending=True)
             indices, values = indices[order], values[order]
 
+        return [(self.idx2tag[i], float(v))
+                for i, v in zip(indices.tolist(), values.tolist())]
 
     @torch.no_grad()
     def predict_batch(self, images, topk: int | None = 30,
+                      threshold: float | None = None):
+        return [self.predict(img, topk=topk, threshold=threshold)
+                for img in images]
 
 
@@ -399,17 +561,20 @@ class Tagger:
 # =============================================================================
 
 def _fmt_pretty(path: str, results) -> str:
+    lines = [f"\n{'─' * 60}", f" {path}", f"{'─' * 60}"]
     for rank, (tag, score) in enumerate(results, 1):
         bar = "█" * int(score * 20)
+        lines.append(f" {rank:>3}. {score:.3f} {bar:<20} {tag}")
     return "\n".join(lines)
 
+
 def _fmt_tags(results) -> str:
     return ", ".join(tag for tag, _ in results)
 
+
 def _fmt_json(path: str, results) -> dict:
+    return {"file": path,
+            "tags": [{"tag": t, "score": round(s, 4)} for t, s in results]}
 
 
 # =============================================================================
@@ -418,28 +583,40 @@ def main():
 
 def main():
     parser = argparse.ArgumentParser(
+        description="DINOv3 ViT-H/16+ tagger inference (standalone)",
         formatter_class=argparse.RawDescriptionHelpFormatter,
     )
+    parser.add_argument("--checkpoint", required=True,
+                        help="Path to .safetensors or .pt checkpoint")
+    parser.add_argument("--vocab", required=True,
+                        help="Path to tagger_vocab*.json")
+    parser.add_argument("--images", nargs="+", required=True,
+                        help="Image paths and/or http(s) URLs")
+    parser.add_argument("--device", default="cuda",
+                        help="Device: cuda, cuda:0, cpu (default: cuda)")
     parser.add_argument("--max-size", type=int, default=1024,
+                        help="Long-edge cap in pixels (default: 1024)")
 
     mode = parser.add_mutually_exclusive_group()
+    mode.add_argument("--topk", type=int, default=30,
+                      help="Return top-k tags (default: 30)")
+    mode.add_argument("--threshold", type=float,
+                      help="Return all tags with score >= threshold")
 
     parser.add_argument("--format", choices=["pretty", "tags", "json"],
                         default="pretty", help="Output format (default: pretty)")
     args = parser.parse_args()
 
+    tagger = Tagger(
+        checkpoint_path=args.checkpoint,
+        vocab_path=args.vocab,
+        device=args.device,
+        max_size=args.max_size,
+    )
 
+    topk, threshold = (
+        (None, args.threshold) if args.threshold else (args.topk, None)
+    )
     json_out = []
 
     for src in args.images:
@@ -448,13 +625,16 @@ def main():
             print(f"[warning] File not found: {src}", file=sys.stderr)
             continue
         results = tagger.predict(src, topk=topk, threshold=threshold)
+        if args.format == "pretty":
+            print(_fmt_pretty(src, results))
+        elif args.format == "tags":
+            print(_fmt_tags(results))
+        elif args.format == "json":
+            json_out.append(_fmt_json(src, results))
 
     if args.format == "json":
         print(json.dumps(json_out, indent=2, ensure_ascii=False))
 
 
 if __name__ == "__main__":
     main()