ClintHardwood committed on
Commit
3136b21
·
1 Parent(s): 93a0200

Update inference_tagger_standalone.py

Browse files

Fix backbone state-dict loading: remap backbone.model.layer.* → backbone.layer.*
The checkpoint stores the 32 transformer blocks under backbone.model.layer.N.* (HF-style, with an intermediate model wrapper), but DINOv3ViTH in this script declares them at backbone.layer.N.*. Combined with strict=False, assign=True in load_state_dict, all 608 block parameters (32 layers × 19 tensors per block) were silently failing to load — the backbone ran on default nn.Linear / nn.LayerNorm initializations while only the head loaded correctly. The only hint was a printed [Tagger] Missing keys (608): ['backbone.layer.0.layer_scale1', ...] line that was easy to miss, and the model produced plausible-looking but essentially random tag predictions, making it feel like undertraining.
Confirmed by dumping tagger_proto.safetensors keys — they're all under backbone.model.layer.N.* and the head is a single projection.weight of shape (74625, 6400).
Changes:

Strip the intermediate `model.` segment from backbone keys during loading so backbone.model.layer.N.* maps to self.layer[N].* correctly.
Load both backbone and head with strict=True, so any future name/shape drift fails loudly at load time instead of silently returning noise.
Auto-detect head layout (currently a single Linear) so this class of silent mis-load can't recur if the head changes later.
Minor: consistent aspect-ratio preservation in preprocessing, torch.zeros instead of torch.empty for embedding parameters, drop the redundant torch.autocast wrapper (backbone is explicitly cast to bf16, head stays fp32 per the training recipe).

Verified by running the loader against a synthesized state dict matching the real key layout (616 keys: 5 embedding + 608 block + 2 final norm + 1 head) — strict load passes and a forward returns the right logit shape. Also confirmed by another user who hit the same bug and fixed it by remapping the keys, reporting that outputs went "from horrifically bad to pretty much perfect."

Files changed (1) hide show
  1. inference_tagger_standalone.py +325 -145
inference_tagger_standalone.py CHANGED
@@ -64,17 +64,19 @@ from safetensors.torch import load_file
64
  # All hyperparameters match facebook/dinov3-vith16plus-pretrain-lvd1689m
65
  # =============================================================================
66
 
67
- D_MODEL = 1280
68
- N_HEADS = 20
69
- HEAD_DIM = D_MODEL // N_HEADS # 64
70
- N_LAYERS = 32
71
- D_FFN = 5120
72
  N_REGISTERS = 4
73
- PATCH_SIZE = 16
74
- ROPE_THETA = 100.0
75
- ROPE_RESCALE = 2.0 # pos_embed_rescale applied at inference
76
- LN_EPS = 1e-5
77
- LAYERSCALE = 1.0
 
 
78
 
79
 
80
  # ---------------------------------------------------------------------------
@@ -83,25 +85,23 @@ LAYERSCALE = 1.0
83
 
84
  @lru_cache(maxsize=32)
85
  def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
86
- """Normalised [-1,+1] patch-centre coordinates (float32, cached)."""
87
  device = torch.device(device_str)
88
  cy = torch.arange(0.5, h, dtype=torch.float32, device=device) / h
89
  cx = torch.arange(0.5, w, dtype=torch.float32, device=device) / w
90
  coords = torch.stack(torch.meshgrid(cy, cx, indexing="ij"), dim=-1).flatten(0, 1)
91
- coords = 2.0 * coords - 1.0 # [0,1] β†’ [-1,+1]
92
  coords = coords * ROPE_RESCALE
93
  return coords # [h*w, 2]
94
 
95
 
96
  def _build_rope(h_patches: int, w_patches: int,
97
  dtype: torch.dtype, device: torch.device):
98
- """Return (cos, sin) of shape [1, 1, h*w, HEAD_DIM] for broadcasting."""
99
- coords = _patch_coords_cached(h_patches, w_patches, str(device)) # [P, 2]
100
  inv_freq = 1.0 / (ROPE_THETA ** torch.arange(
101
- 0, 1, 4 / HEAD_DIM, dtype=torch.float32, device=device)) # [D/4]
102
- angles = 2 * math.pi * coords[:, :, None] * inv_freq[None, None, :] # [P, 2, D/4]
103
- angles = angles.flatten(1, 2).tile(2) # [P, D]
104
- cos = torch.cos(angles).to(dtype).unsqueeze(0).unsqueeze(0) # [1,1,P,D]
105
  sin = torch.sin(angles).to(dtype).unsqueeze(0).unsqueeze(0)
106
  return cos, sin
107
 
@@ -113,7 +113,6 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
113
 
114
  def _apply_rope(q: torch.Tensor, k: torch.Tensor,
115
  cos: torch.Tensor, sin: torch.Tensor):
116
- """Apply RoPE only to patch tokens (skip CLS + register prefix)."""
117
  n_pre = 1 + N_REGISTERS
118
  q_pre, q_pat = q[..., :n_pre, :], q[..., n_pre:, :]
119
  k_pre, k_pat = k[..., :n_pre, :], k[..., n_pre:, :]
@@ -123,7 +122,7 @@ def _apply_rope(q: torch.Tensor, k: torch.Tensor,
123
 
124
 
125
  # ---------------------------------------------------------------------------
126
- # Building blocks
127
  # ---------------------------------------------------------------------------
128
 
129
  class _Attention(nn.Module):
@@ -134,7 +133,7 @@ class _Attention(nn.Module):
134
  self.v_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
135
  self.o_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
136
 
137
- def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
138
  B, S, _ = x.shape
139
  q = self.q_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
140
  k = self.k_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
@@ -148,125 +147,259 @@ class _GatedMLP(nn.Module):
148
  def __init__(self):
149
  super().__init__()
150
  self.gate_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
151
- self.up_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
152
- self.down_proj = nn.Linear(D_FFN, D_MODEL, bias=True)
153
 
154
- def forward(self, x: torch.Tensor) -> torch.Tensor:
155
  return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
156
 
157
 
158
  class _Block(nn.Module):
159
  def __init__(self):
160
  super().__init__()
161
- self.norm1 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
162
- self.attention = _Attention()
163
  self.layer_scale1 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
164
- self.norm2 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
165
- self.mlp = _GatedMLP()
166
  self.layer_scale2 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
167
 
168
- def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
169
  x = x + self.attention(self.norm1(x), cos, sin) * self.layer_scale1
170
  x = x + self.mlp(self.norm2(x)) * self.layer_scale2
171
  return x
172
 
173
 
174
- # ---------------------------------------------------------------------------
175
- # Full backbone
176
- # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  class DINOv3ViTH(nn.Module):
179
  """DINOv3 ViT-H/16+ backbone.
180
 
181
- Accepts any H, W that are multiples of 16.
182
- Returns last_hidden_state [B, 1+R+P, D_MODEL].
183
  Token layout: [CLS, reg_0..reg_3, patch_0..patch_N].
184
-
185
- State-dict keys are intentionally identical to the HuggingFace
186
- transformers layout so .safetensors checkpoints load without remapping.
187
  """
188
 
189
  def __init__(self):
190
  super().__init__()
191
- # These names must match HF exactly
192
  self.embeddings = _Embeddings()
193
  self.layer = nn.ModuleList([_Block() for _ in range(N_LAYERS)])
194
- self.norm = nn.LayerNorm(D_MODEL, eps=LN_EPS)
195
-
196
- def _load_from_state_dict(self, state_dict, prefix, local_metadata,
197
- strict, missing_keys, unexpected_keys, error_msgs):
198
- # HF stores layer_scale as a sub-module with a "lambda1" parameter;
199
- # we store it as a plain Parameter directly on _Block.
200
- # Remap "layer.i.layer_scale{1,2}.lambda1" β†’ "layer.i.layer_scale{1,2}"
201
- for k in list(state_dict.keys()):
202
- if k.startswith(prefix) and ".layer_scale" in k and k.endswith(".lambda1"):
203
- new_k = k[:-len(".lambda1")]
204
- state_dict[new_k] = state_dict.pop(k)
205
- # Drop rope_embeddings buffer (computed on-the-fly)
206
- for k in list(state_dict.keys()):
207
- if k.startswith(prefix) and "rope_embeddings" in k:
208
- state_dict.pop(k)
209
- super()._load_from_state_dict(
210
- state_dict, prefix, local_metadata, strict,
211
- missing_keys, unexpected_keys, error_msgs)
212
-
213
- def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
214
- B, _, H, W = pixel_values.shape
215
- x = self.embeddings(pixel_values) # [B, 1+R+P, D]
216
 
 
 
 
217
  h_p, w_p = H // PATCH_SIZE, W // PATCH_SIZE
218
  cos, sin = _build_rope(h_p, w_p, x.dtype, pixel_values.device)
219
-
220
  for block in self.layer:
221
  x = block(x, cos, sin)
222
-
223
  return self.norm(x)
224
 
225
 
226
- class _Embeddings(nn.Module):
227
- """Patch + CLS + register token embeddings.
228
- Key names match HF: embeddings.cls_token, embeddings.register_tokens,
229
- embeddings.patch_embeddings.{weight,bias}.
 
 
 
 
 
 
230
  """
231
 
232
- def __init__(self):
 
233
  super().__init__()
234
- self.cls_token = nn.Parameter(torch.empty(1, 1, D_MODEL))
235
- self.mask_token = nn.Parameter(torch.zeros(1, 1, D_MODEL)) # unused at inference
236
- self.register_tokens = nn.Parameter(torch.empty(1, N_REGISTERS, D_MODEL))
237
- self.patch_embeddings = nn.Conv2d(3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
238
 
239
- def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
240
- B = pixel_values.shape[0]
241
- dtype = self.patch_embeddings.weight.dtype
242
- patches = self.patch_embeddings(pixel_values.to(dtype)).flatten(2).transpose(1, 2)
243
- cls = self.cls_token.expand(B, -1, -1)
244
- regs = self.register_tokens.expand(B, -1, -1)
245
- return torch.cat([cls, regs, patches], dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
 
248
  # =============================================================================
249
- # Tagger head
250
  # =============================================================================
251
 
252
  class DINOv3Tagger(nn.Module):
253
- """DINOv3 ViT-H/16+ backbone + linear projection head.
 
 
 
 
 
 
254
 
255
- features = concat(CLS, reg_0..reg_3) β†’ [B, (1+R)*D]
256
- projection: Linear β†’ [B, num_tags]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  """
 
 
 
 
 
 
 
 
 
 
 
258
 
259
- def __init__(self, num_tags: int, projection_bias: bool = False):
260
- super().__init__()
261
- self.backbone = DINOv3ViTH()
262
- self.projection = nn.Linear((1 + N_REGISTERS) * D_MODEL, num_tags, bias=projection_bias)
263
 
264
- def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
265
- hidden = self.backbone(pixel_values) # [B, S, D]
266
- cls = hidden[:, 0, :] # [B, D]
267
- regs = hidden[:, 1: 1 + N_REGISTERS, :].flatten(1) # [B, R*D]
268
- features = torch.cat([cls, regs], dim=-1) # [B, (1+R)*D]
269
- return self.projection(features.float()) # fp32 for stability
270
 
271
 
272
  # =============================================================================
@@ -274,7 +407,7 @@ class DINOv3Tagger(nn.Module):
274
  # =============================================================================
275
 
276
  _IMAGENET_MEAN = [0.485, 0.456, 0.406]
277
- _IMAGENET_STD = [0.229, 0.224, 0.225]
278
 
279
 
280
  def _snap(x: int, m: int) -> int:
@@ -291,12 +424,22 @@ def _open_image(source) -> Image.Image:
291
 
292
 
293
  def preprocess_image(source, max_size: int = 1024) -> torch.Tensor:
294
- """Load and preprocess an image β†’ [1, 3, H, W] float32, ImageNet-normalised."""
 
 
 
 
295
  img = _open_image(source)
296
  w, h = img.size
297
- scale = min(1.0, max_size / max(w, h))
298
- new_w = _snap(round(w * scale), PATCH_SIZE)
299
- new_h = _snap(round(h * scale), PATCH_SIZE)
 
 
 
 
 
 
300
  return v2.Compose([
301
  v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
302
  v2.ToImage(),
@@ -315,13 +458,15 @@ class Tagger:
315
  Parameters
316
  ----------
317
  checkpoint_path : str
318
- Path to a .safetensors or .pth checkpoint saved by TaggerTrainer.
319
  vocab_path : str
320
- Path to tagger_vocab.json ({"idx2tag": [...]}).
 
321
  device : str
322
- "cuda", "cuda:0", "cpu", etc.
323
  dtype : torch.dtype
324
- bfloat16 recommended on Ampere+; float16 for older GPUs; float32 for CPU.
 
325
  max_size : int
326
  Long-edge cap in pixels before feeding to the model.
327
  """
@@ -334,8 +479,13 @@ class Tagger:
334
  dtype: torch.dtype = torch.bfloat16,
335
  max_size: int = 1024,
336
  ):
337
- self.device = torch.device(device if torch.cuda.is_available() or device == "cpu" else "cpu")
338
- self.dtype = dtype
 
 
 
 
 
339
  self.max_size = max_size
340
 
341
  with open(vocab_path) as f:
@@ -344,36 +494,47 @@ class Tagger:
344
  self.num_tags = len(self.idx2tag)
345
  print(f"[Tagger] Vocabulary: {self.num_tags:,} tags")
346
 
347
- self.model = DINOv3Tagger(num_tags=self.num_tags)
348
-
349
  print(f"[Tagger] Loading checkpoint: {checkpoint_path}")
350
  if checkpoint_path.endswith((".safetensors", ".sft")):
351
- sd = load_file(checkpoint_path, device=str(self.device))
352
  else:
353
- sd = torch.load(checkpoint_path, map_location=str(self.device))
354
-
355
- missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
356
- if missing:
357
- print(f"[Tagger] Missing keys ({len(missing)}): {missing[:5]}{'...' if len(missing) > 5 else ''}")
358
- if unexpected:
359
- print(f"[Tagger] Unexpected keys ({len(unexpected)}): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
360
-
361
- self.model.backbone = self.model.backbone.to(dtype=dtype)
362
- self.model = self.model.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  self.model.eval()
364
- print(f"[Tagger] Ready on {self.device} ({dtype})")
365
 
366
  @torch.no_grad()
367
  def predict(self, image, topk: int | None = 30,
368
  threshold: float | None = None) -> list[tuple[str, float]]:
369
- """Tag a single image (local path or URL).
370
- Specify either topk OR threshold. Returns [(tag, score), ...] desc."""
371
  if topk is None and threshold is None:
372
  topk = 30
373
 
374
  pv = preprocess_image(image, max_size=self.max_size).to(self.device)
375
- with torch.autocast(device_type=self.device.type, dtype=self.dtype):
376
- logits = self.model(pv)[0]
377
  scores = torch.sigmoid(logits.float())
378
 
379
  if topk is not None:
@@ -381,17 +542,18 @@ class Tagger:
381
  else:
382
  assert threshold is not None
383
  indices = (scores >= threshold).nonzero(as_tuple=True)[0]
384
- values = scores[indices]
385
- order = values.argsort(descending=True)
386
  indices, values = indices[order], values[order]
387
 
388
- return [(self.idx2tag[i], float(v)) for i, v in zip(indices.tolist(), values.tolist())]
 
389
 
390
  @torch.no_grad()
391
  def predict_batch(self, images, topk: int | None = 30,
392
- threshold: float | None = None) -> list[list[tuple[str, float]]]:
393
- """Tag multiple images (processed individually for mixed resolutions)."""
394
- return [self.predict(img, topk=topk, threshold=threshold) for img in images]
395
 
396
 
397
  # =============================================================================
@@ -399,17 +561,20 @@ class Tagger:
399
  # =============================================================================
400
 
401
  def _fmt_pretty(path: str, results) -> str:
402
- lines = [f"\n{'─' * 60}", f" {path}", f"{'─' * 60}"]
403
  for rank, (tag, score) in enumerate(results, 1):
404
  bar = "β–ˆ" * int(score * 20)
405
- lines.append(f" {rank:>3}. {score:.3f} {bar:<20} {tag}")
406
  return "\n".join(lines)
407
 
 
408
  def _fmt_tags(results) -> str:
409
  return ", ".join(tag for tag, _ in results)
410
 
 
411
  def _fmt_json(path: str, results) -> dict:
412
- return {"file": path, "tags": [{"tag": t, "score": round(s, 4)} for t, s in results]}
 
413
 
414
 
415
  # =============================================================================
@@ -418,28 +583,40 @@ def _fmt_json(path: str, results) -> dict:
418
 
419
  def main():
420
  parser = argparse.ArgumentParser(
421
- description="DINOv3 ViT-H/16+ tagger inference (standalone, no transformers dep)",
422
  formatter_class=argparse.RawDescriptionHelpFormatter,
423
  )
424
- parser.add_argument("--checkpoint", required=True, help="Path to .safetensors or .pth checkpoint")
425
- parser.add_argument("--vocab", required=True, help="Path to tagger_vocab.json")
426
- parser.add_argument("--images", nargs="+", required=True, help="Image paths and/or http(s) URLs")
427
- parser.add_argument("--device", default="cuda", help="Device: cuda, cuda:0, cpu, … (default: cuda)")
 
 
 
 
428
  parser.add_argument("--max-size", type=int, default=1024,
429
- help="Long-edge cap in pixels, multiple of 16 (default: 1024)")
430
 
431
  mode = parser.add_mutually_exclusive_group()
432
- mode.add_argument("--topk", type=int, default=30, help="Return top-k tags (default: 30)")
433
- mode.add_argument("--threshold", type=float, help="Return all tags with score >= threshold")
 
 
434
 
435
  parser.add_argument("--format", choices=["pretty", "tags", "json"],
436
  default="pretty", help="Output format (default: pretty)")
437
  args = parser.parse_args()
438
 
439
- tagger = Tagger(checkpoint_path=args.checkpoint, vocab_path=args.vocab,
440
- device=args.device, max_size=args.max_size)
 
 
 
 
441
 
442
- topk, threshold = (None, args.threshold) if args.threshold else (args.topk, None)
 
 
443
  json_out = []
444
 
445
  for src in args.images:
@@ -448,13 +625,16 @@ def main():
448
  print(f"[warning] File not found: {src}", file=sys.stderr)
449
  continue
450
  results = tagger.predict(src, topk=topk, threshold=threshold)
451
- if args.format == "pretty": print(_fmt_pretty(src, results))
452
- elif args.format == "tags": print(_fmt_tags(results))
453
- elif args.format == "json": json_out.append(_fmt_json(src, results))
 
 
 
454
 
455
  if args.format == "json":
456
  print(json.dumps(json_out, indent=2, ensure_ascii=False))
457
 
458
 
459
  if __name__ == "__main__":
460
- main()
 
64
  # All hyperparameters match facebook/dinov3-vith16plus-pretrain-lvd1689m
65
  # =============================================================================
66
 
67
+ D_MODEL = 1280
68
+ N_HEADS = 20
69
+ HEAD_DIM = D_MODEL // N_HEADS # 64
70
+ N_LAYERS = 32
71
+ D_FFN = 5120
72
  N_REGISTERS = 4
73
+ PATCH_SIZE = 16
74
+ ROPE_THETA = 100.0
75
+ ROPE_RESCALE = 2.0
76
+ LN_EPS = 1e-5
77
+ LAYERSCALE = 1.0
78
+
79
+ FEATURE_DIM = (1 + N_REGISTERS) * D_MODEL # 6400
80
 
81
 
82
  # ---------------------------------------------------------------------------
 
85
 
86
  @lru_cache(maxsize=32)
87
  def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
 
88
  device = torch.device(device_str)
89
  cy = torch.arange(0.5, h, dtype=torch.float32, device=device) / h
90
  cx = torch.arange(0.5, w, dtype=torch.float32, device=device) / w
91
  coords = torch.stack(torch.meshgrid(cy, cx, indexing="ij"), dim=-1).flatten(0, 1)
92
+ coords = 2.0 * coords - 1.0
93
  coords = coords * ROPE_RESCALE
94
  return coords # [h*w, 2]
95
 
96
 
97
  def _build_rope(h_patches: int, w_patches: int,
98
  dtype: torch.dtype, device: torch.device):
99
+ coords = _patch_coords_cached(h_patches, w_patches, str(device))
 
100
  inv_freq = 1.0 / (ROPE_THETA ** torch.arange(
101
+ 0, 1, 4 / HEAD_DIM, dtype=torch.float32, device=device))
102
+ angles = 2 * math.pi * coords[:, :, None] * inv_freq[None, None, :]
103
+ angles = angles.flatten(1, 2).tile(2)
104
+ cos = torch.cos(angles).to(dtype).unsqueeze(0).unsqueeze(0)
105
  sin = torch.sin(angles).to(dtype).unsqueeze(0).unsqueeze(0)
106
  return cos, sin
107
 
 
113
 
114
  def _apply_rope(q: torch.Tensor, k: torch.Tensor,
115
  cos: torch.Tensor, sin: torch.Tensor):
 
116
  n_pre = 1 + N_REGISTERS
117
  q_pre, q_pat = q[..., :n_pre, :], q[..., n_pre:, :]
118
  k_pre, k_pat = k[..., :n_pre, :], k[..., n_pre:, :]
 
122
 
123
 
124
  # ---------------------------------------------------------------------------
125
+ # Transformer blocks
126
  # ---------------------------------------------------------------------------
127
 
128
  class _Attention(nn.Module):
 
133
  self.v_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
134
  self.o_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
135
 
136
+ def forward(self, x, cos, sin):
137
  B, S, _ = x.shape
138
  q = self.q_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
139
  k = self.k_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
 
147
  def __init__(self):
148
  super().__init__()
149
  self.gate_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
150
+ self.up_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
151
+ self.down_proj = nn.Linear(D_FFN, D_MODEL, bias=True)
152
 
153
+ def forward(self, x):
154
  return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
155
 
156
 
157
  class _Block(nn.Module):
158
  def __init__(self):
159
  super().__init__()
160
+ self.norm1 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
161
+ self.attention = _Attention()
162
  self.layer_scale1 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
163
+ self.norm2 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
164
+ self.mlp = _GatedMLP()
165
  self.layer_scale2 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
166
 
167
+ def forward(self, x, cos, sin):
168
  x = x + self.attention(self.norm1(x), cos, sin) * self.layer_scale1
169
  x = x + self.mlp(self.norm2(x)) * self.layer_scale2
170
  return x
171
 
172
 
173
+ class _Embeddings(nn.Module):
174
+ def __init__(self):
175
+ super().__init__()
176
+ # zeros() rather than empty() so a forgotten checkpoint key fails
177
+ # predictably instead of producing undefined outputs.
178
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
179
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
180
+ self.register_tokens = nn.Parameter(torch.zeros(1, N_REGISTERS, D_MODEL))
181
+ self.patch_embeddings = nn.Conv2d(
182
+ 3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
183
+
184
+ def forward(self, pixel_values):
185
+ B = pixel_values.shape[0]
186
+ dtype = self.patch_embeddings.weight.dtype
187
+ patches = self.patch_embeddings(
188
+ pixel_values.to(dtype)).flatten(2).transpose(1, 2)
189
+ cls = self.cls_token.expand(B, -1, -1)
190
+ regs = self.register_tokens.expand(B, -1, -1)
191
+ return torch.cat([cls, regs, patches], dim=1)
192
+
193
 
194
  class DINOv3ViTH(nn.Module):
195
  """DINOv3 ViT-H/16+ backbone.
196
 
 
 
197
  Token layout: [CLS, reg_0..reg_3, patch_0..patch_N].
198
+ Returns last_hidden_state [B, 1+R+P, D_MODEL].
 
 
199
  """
200
 
201
  def __init__(self):
202
  super().__init__()
 
203
  self.embeddings = _Embeddings()
204
  self.layer = nn.ModuleList([_Block() for _ in range(N_LAYERS)])
205
+ self.norm = nn.LayerNorm(D_MODEL, eps=LN_EPS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
+ def forward(self, pixel_values):
208
+ _, _, H, W = pixel_values.shape
209
+ x = self.embeddings(pixel_values)
210
  h_p, w_p = H // PATCH_SIZE, W // PATCH_SIZE
211
  cos, sin = _build_rope(h_p, w_p, x.dtype, pixel_values.device)
 
212
  for block in self.layer:
213
  x = block(x, cos, sin)
 
214
  return self.norm(x)
215
 
216
 
217
+ # =============================================================================
218
+ # Head β€” auto-detected from the checkpoint
219
+ # =============================================================================
220
+
221
+ class _LowRankHead(nn.Module):
222
+ """Two-matrix low-rank projection head.
223
+
224
+ features (in_dim)
225
+ β†’ Linear(in_dim, rank, bias=?)
226
+ β†’ Linear(rank, num_tags, bias=?)
227
  """
228
 
229
+ def __init__(self, in_dim: int, rank: int, num_tags: int,
230
+ down_bias: bool, up_bias: bool):
231
  super().__init__()
232
+ self.proj_down = nn.Linear(in_dim, rank, bias=down_bias)
233
+ self.proj_up = nn.Linear(rank, num_tags, bias=up_bias)
 
 
234
 
235
+ def forward(self, x):
236
+ return self.proj_up(self.proj_down(x))
237
+
238
+
239
+ def _build_head_from_checkpoint(
240
+ head_sd: dict,
241
+ in_dim: int,
242
+ num_tags: int,
243
+ ) -> tuple[nn.Module, dict]:
244
+ """Inspect head_sd and build a matching Module.
245
+
246
+ Supports two layouts, in order of preference:
247
+ 1. Single linear β€” any ``*.weight`` with shape [num_tags, in_dim]
248
+ 2. Low-rank pair (2 mats) β€” one ``*.weight`` [rank, in_dim] plus
249
+ one ``*.weight`` [num_tags, rank]
250
+
251
+ Returns (module, remapped_state_dict) where the remapped state dict
252
+ matches the module's own key names so strict loading works.
253
+ """
254
+ weights_2d = [(k, v) for k, v in head_sd.items()
255
+ if k.endswith(".weight") and v.ndim == 2]
256
+
257
+ # --- Case 1: single dense linear ---------------------------------------
258
+ singles = [(k, v) for k, v in weights_2d
259
+ if tuple(v.shape) == (num_tags, in_dim)]
260
+ if len(weights_2d) <= 2 and len(singles) == 1:
261
+ wkey, wval = singles[0]
262
+ base = wkey[:-len(".weight")]
263
+ bias_key = base + ".bias"
264
+ has_bias = bias_key in head_sd
265
+ module = nn.Linear(in_dim, num_tags, bias=has_bias)
266
+ remapped = {"weight": wval}
267
+ if has_bias:
268
+ remapped["bias"] = head_sd[bias_key]
269
+ # Sanity check: no extra keys we don't understand
270
+ expected_src = {wkey} | ({bias_key} if has_bias else set())
271
+ extra = set(head_sd) - expected_src
272
+ if extra:
273
+ raise RuntimeError(
274
+ f"Head has single-linear shape but extra unknown keys: {sorted(extra)}")
275
+ return module, remapped
276
+
277
+ # --- Case 2: low-rank pair ---------------------------------------------
278
+ down = None # (key, tensor) with shape [rank, in_dim]
279
+ up = None # (key, tensor) with shape [num_tags, rank]
280
+ for k, v in weights_2d:
281
+ if v.shape[1] == in_dim and v.shape[0] != num_tags:
282
+ down = (k, v)
283
+ elif v.shape[0] == num_tags and v.shape[1] != in_dim:
284
+ up = (k, v)
285
+
286
+ if down is not None and up is not None:
287
+ rank_down = down[1].shape[0]
288
+ rank_up = up[1].shape[1]
289
+ if rank_down != rank_up:
290
+ raise RuntimeError(
291
+ f"Low-rank head: inner dims disagree "
292
+ f"(down out={rank_down}, up in={rank_up})")
293
+
294
+ down_key, down_w = down
295
+ up_key, up_w = up
296
+ down_base = down_key[:-len(".weight")]
297
+ up_base = up_key[:-len(".weight")]
298
+ down_bias_key = down_base + ".bias"
299
+ up_bias_key = up_base + ".bias"
300
+ has_down_bias = down_bias_key in head_sd
301
+ has_up_bias = up_bias_key in head_sd
302
+
303
+ module = _LowRankHead(
304
+ in_dim=in_dim,
305
+ rank=rank_down,
306
+ num_tags=num_tags,
307
+ down_bias=has_down_bias,
308
+ up_bias=has_up_bias,
309
+ )
310
+ remapped = {
311
+ "proj_down.weight": down_w,
312
+ "proj_up.weight": up_w,
313
+ }
314
+ if has_down_bias:
315
+ remapped["proj_down.bias"] = head_sd[down_bias_key]
316
+ if has_up_bias:
317
+ remapped["proj_up.bias"] = head_sd[up_bias_key]
318
+
319
+ # Sanity check
320
+ expected_src = {down_key, up_key}
321
+ if has_down_bias:
322
+ expected_src.add(down_bias_key)
323
+ if has_up_bias:
324
+ expected_src.add(up_bias_key)
325
+ extra = set(head_sd) - expected_src
326
+ if extra:
327
+ raise RuntimeError(
328
+ f"Low-rank head detected but checkpoint has extra unknown "
329
+ f"head keys: {sorted(extra)}")
330
+
331
+ print(f"[Tagger] Detected low-rank head: "
332
+ f"in_dim={in_dim}, rank={rank_down}, num_tags={num_tags} "
333
+ f"(down_bias={has_down_bias}, up_bias={has_up_bias})")
334
+ return module, remapped
335
+
336
+ raise RuntimeError(
337
+ "Could not infer head architecture from checkpoint. "
338
+ f"Non-backbone keys found: {sorted(head_sd.keys())}"
339
+ )
340
 
341
 
342
  # =============================================================================
343
+ # Tagger wrapper module
344
  # =============================================================================
345
 
346
  class DINOv3Tagger(nn.Module):
347
+ """Backbone + head. The head is attached after the checkpoint is
348
+ inspected (so we can build the right shape)."""
349
+
350
+ def __init__(self):
351
+ super().__init__()
352
+ self.backbone = DINOv3ViTH()
353
+ self.head: nn.Module | None = None # attached by Tagger
354
 
355
+ def forward(self, pixel_values):
356
+ hidden = self.backbone(pixel_values)
357
+ cls = hidden[:, 0, :]
358
+ regs = hidden[:, 1: 1 + N_REGISTERS, :].flatten(1)
359
+ features = torch.cat([cls, regs], dim=-1).float() # fp32 for head
360
+ return self.head(features)
361
+
362
+
363
+ # =============================================================================
364
+ # Checkpoint loading helpers
365
+ # =============================================================================
366
+
367
+ def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]:
368
+ """Split full state dict into (backbone_sd, head_sd), stripping the
369
+ ``backbone.`` prefix and applying the remaps needed to match
370
+ ``DINOv3ViTH``'s parameter layout:
371
+
372
+ 1. ``backbone.model.layer.N.*`` β†’ ``layer.N.*``
373
+ (the checkpoint has an HF-style intermediate ``model`` wrapper
374
+ that our flat backbone class does not)
375
+ 2. ``...layer_scale{1,2}.lambda1`` β†’ ``...layer_scale{1,2}``
376
+ (HF stores layer_scale as a sub-module with a ``lambda1``
377
+ parameter; we use a plain ``nn.Parameter``)
378
+ 3. Drop any ``rope_embeddings`` buffers (recomputed on the fly)
379
  """
380
+ backbone_sd: dict = {}
381
+ head_sd: dict = {}
382
+ for k, v in sd.items():
383
+ if k.startswith("backbone."):
384
+ nk = k[len("backbone."):]
385
+ # Remap (1): strip intermediate "model." before "layer."
386
+ if nk.startswith("model.layer."):
387
+ nk = nk[len("model."):]
388
+ backbone_sd[nk] = v
389
+ else:
390
+ head_sd[k] = v
391
 
392
+ # Remap (2): layer.N.layer_scale{1,2}.lambda1 β†’ layer.N.layer_scale{1,2}
393
+ for k in list(backbone_sd.keys()):
394
+ if ".layer_scale" in k and k.endswith(".lambda1"):
395
+ backbone_sd[k[:-len(".lambda1")]] = backbone_sd.pop(k)
396
 
397
+ # Remap (3): drop rope buffers (recomputed on the fly)
398
+ for k in list(backbone_sd.keys()):
399
+ if "rope_embeddings" in k:
400
+ backbone_sd.pop(k)
401
+
402
+ return backbone_sd, head_sd
403
 
404
 
405
  # =============================================================================
 
407
  # =============================================================================
408
 
409
  _IMAGENET_MEAN = [0.485, 0.456, 0.406]
410
+ _IMAGENET_STD = [0.229, 0.224, 0.225]
411
 
412
 
413
  def _snap(x: int, m: int) -> int:
 
424
 
425
 
426
  def preprocess_image(source, max_size: int = 1024) -> torch.Tensor:
427
+ """Load and preprocess an image β†’ [1, 3, H, W] float32, ImageNet-normalised.
428
+
429
+ Aspect ratio is preserved: a single scale factor is chosen so that the
430
+ long edge fits inside max_size after snapping to a PATCH_SIZE multiple.
431
+ """
432
  img = _open_image(source)
433
  w, h = img.size
434
+
435
+ # Target long-edge (snapped to patch multiple).
436
+ long_edge = max(w, h)
437
+ target_long = _snap(min(long_edge, max_size), PATCH_SIZE)
438
+ scale = target_long / long_edge
439
+
440
+ new_w = _snap(max(PATCH_SIZE, round(w * scale)), PATCH_SIZE)
441
+ new_h = _snap(max(PATCH_SIZE, round(h * scale)), PATCH_SIZE)
442
+
443
  return v2.Compose([
444
  v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
445
  v2.ToImage(),
 
458
  Parameters
459
  ----------
460
  checkpoint_path : str
461
+ Path to a .safetensors or .pt/.pth checkpoint.
462
  vocab_path : str
463
+ Path to tagger_vocab.json or tagger_vocab_with_categories.json
464
+ (either must contain an ``idx2tag`` list).
465
  device : str
466
+ "cuda", "cuda:0", "cpu", ...
467
  dtype : torch.dtype
468
+ Backbone precision. bfloat16 recommended on Ampere+, float16 for
469
+ older GPUs, float32 for CPU. The head always runs in fp32.
470
  max_size : int
471
  Long-edge cap in pixels before feeding to the model.
472
  """
 
479
  dtype: torch.dtype = torch.bfloat16,
480
  max_size: int = 1024,
481
  ):
482
+ want_cuda = device.startswith("cuda")
483
+ if want_cuda and not torch.cuda.is_available():
484
+ print("[Tagger] CUDA not available, falling back to CPU")
485
+ device = "cpu"
486
+ dtype = torch.float32
487
+ self.device = torch.device(device)
488
+ self.dtype = dtype
489
  self.max_size = max_size
490
 
491
  with open(vocab_path) as f:
 
494
  self.num_tags = len(self.idx2tag)
495
  print(f"[Tagger] Vocabulary: {self.num_tags:,} tags")
496
 
497
+ # --- Load checkpoint to CPU first so we can inspect shapes ---------
 
498
  print(f"[Tagger] Loading checkpoint: {checkpoint_path}")
499
  if checkpoint_path.endswith((".safetensors", ".sft")):
500
+ sd = load_file(checkpoint_path, device="cpu")
501
  else:
502
+ sd = torch.load(checkpoint_path, map_location="cpu")
503
+
504
+ backbone_sd, head_sd = _split_and_clean_state_dict(sd)
505
+
506
+ if not head_sd:
507
+ raise RuntimeError(
508
+ "Checkpoint contains no non-backbone keys β€” cannot build head.")
509
+
510
+ # --- Build model, inferring head shape from the checkpoint --------
511
+ self.model = DINOv3Tagger()
512
+ head_module, head_sd_remapped = _build_head_from_checkpoint(
513
+ head_sd, in_dim=FEATURE_DIM, num_tags=self.num_tags,
514
+ )
515
+ self.model.head = head_module
516
+
517
+ # --- Strict load β€” mismatches raise instead of silently passing ----
518
+ self.model.backbone.load_state_dict(backbone_sd, strict=True)
519
+ self.model.head.load_state_dict(head_sd_remapped, strict=True)
520
+
521
+ # --- Move to device. Backbone β†’ bf16/fp16; head stays fp32. --------
522
+ self.model.backbone = self.model.backbone.to(
523
+ device=self.device, dtype=dtype)
524
+ self.model.head = self.model.head.to(
525
+ device=self.device, dtype=torch.float32)
526
  self.model.eval()
527
+ print(f"[Tagger] Ready on {self.device} (backbone={dtype}, head=fp32)")
528
 
529
  @torch.no_grad()
530
  def predict(self, image, topk: int | None = 30,
531
  threshold: float | None = None) -> list[tuple[str, float]]:
532
+ """Tag a single image (local path or URL)."""
 
533
  if topk is None and threshold is None:
534
  topk = 30
535
 
536
  pv = preprocess_image(image, max_size=self.max_size).to(self.device)
537
+ logits = self.model(pv)[0]
 
538
  scores = torch.sigmoid(logits.float())
539
 
540
  if topk is not None:
 
542
  else:
543
  assert threshold is not None
544
  indices = (scores >= threshold).nonzero(as_tuple=True)[0]
545
+ values = scores[indices]
546
+ order = values.argsort(descending=True)
547
  indices, values = indices[order], values[order]
548
 
549
+ return [(self.idx2tag[i], float(v))
550
+ for i, v in zip(indices.tolist(), values.tolist())]
551
 
552
  @torch.no_grad()
553
  def predict_batch(self, images, topk: int | None = 30,
554
+ threshold: float | None = None):
555
+ return [self.predict(img, topk=topk, threshold=threshold)
556
+ for img in images]
557
 
558
 
559
  # =============================================================================
 
561
  # =============================================================================
562
 
563
  def _fmt_pretty(path: str, results) -> str:
564
+ lines = [f"\n{'─' * 60}", f" {path}", f"{'─' * 60}"]
565
  for rank, (tag, score) in enumerate(results, 1):
566
  bar = "β–ˆ" * int(score * 20)
567
+ lines.append(f" {rank:>3}. {score:.3f} {bar:<20} {tag}")
568
  return "\n".join(lines)
569
 
570
+
571
  def _fmt_tags(results) -> str:
572
  return ", ".join(tag for tag, _ in results)
573
 
574
+
575
  def _fmt_json(path: str, results) -> dict:
576
+ return {"file": path,
577
+ "tags": [{"tag": t, "score": round(s, 4)} for t, s in results]}
578
 
579
 
580
  # =============================================================================
 
583
 
584
  def main():
585
  parser = argparse.ArgumentParser(
586
+ description="DINOv3 ViT-H/16+ tagger inference (standalone)",
587
  formatter_class=argparse.RawDescriptionHelpFormatter,
588
  )
589
+ parser.add_argument("--checkpoint", required=True,
590
+ help="Path to .safetensors or .pt checkpoint")
591
+ parser.add_argument("--vocab", required=True,
592
+ help="Path to tagger_vocab*.json")
593
+ parser.add_argument("--images", nargs="+", required=True,
594
+ help="Image paths and/or http(s) URLs")
595
+ parser.add_argument("--device", default="cuda",
596
+ help="Device: cuda, cuda:0, cpu (default: cuda)")
597
  parser.add_argument("--max-size", type=int, default=1024,
598
+ help="Long-edge cap in pixels (default: 1024)")
599
 
600
  mode = parser.add_mutually_exclusive_group()
601
+ mode.add_argument("--topk", type=int, default=30,
602
+ help="Return top-k tags (default: 30)")
603
+ mode.add_argument("--threshold", type=float,
604
+ help="Return all tags with score >= threshold")
605
 
606
  parser.add_argument("--format", choices=["pretty", "tags", "json"],
607
  default="pretty", help="Output format (default: pretty)")
608
  args = parser.parse_args()
609
 
610
+ tagger = Tagger(
611
+ checkpoint_path=args.checkpoint,
612
+ vocab_path=args.vocab,
613
+ device=args.device,
614
+ max_size=args.max_size,
615
+ )
616
 
617
+ topk, threshold = (
618
+ (None, args.threshold) if args.threshold else (args.topk, None)
619
+ )
620
  json_out = []
621
 
622
  for src in args.images:
 
625
  print(f"[warning] File not found: {src}", file=sys.stderr)
626
  continue
627
  results = tagger.predict(src, topk=topk, threshold=threshold)
628
+ if args.format == "pretty":
629
+ print(_fmt_pretty(src, results))
630
+ elif args.format == "tags":
631
+ print(_fmt_tags(results))
632
+ elif args.format == "json":
633
+ json_out.append(_fmt_json(src, results))
634
 
635
  if args.format == "json":
636
  print(json.dumps(json_out, indent=2, ensure_ascii=False))
637
 
638
 
639
# CLI entry point: only run when executed as a script, not on import.
if __name__ == "__main__":
    main()