Update inference_tagger_standalone.py

#2
by ClintHardwood - opened
Files changed (1) hide show
  1. inference_tagger_standalone.py +325 -145
inference_tagger_standalone.py CHANGED
@@ -64,17 +64,19 @@ from safetensors.torch import load_file
64
  # All hyperparameters match facebook/dinov3-vith16plus-pretrain-lvd1689m
65
  # =============================================================================
66
 
67
- D_MODEL = 1280
68
- N_HEADS = 20
69
- HEAD_DIM = D_MODEL // N_HEADS # 64
70
- N_LAYERS = 32
71
- D_FFN = 5120
72
  N_REGISTERS = 4
73
- PATCH_SIZE = 16
74
- ROPE_THETA = 100.0
75
- ROPE_RESCALE = 2.0 # pos_embed_rescale applied at inference
76
- LN_EPS = 1e-5
77
- LAYERSCALE = 1.0
 
 
78
 
79
 
80
  # ---------------------------------------------------------------------------
@@ -83,25 +85,23 @@ LAYERSCALE = 1.0
83
 
84
  @lru_cache(maxsize=32)
85
  def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
86
- """Normalised [-1,+1] patch-centre coordinates (float32, cached)."""
87
  device = torch.device(device_str)
88
  cy = torch.arange(0.5, h, dtype=torch.float32, device=device) / h
89
  cx = torch.arange(0.5, w, dtype=torch.float32, device=device) / w
90
  coords = torch.stack(torch.meshgrid(cy, cx, indexing="ij"), dim=-1).flatten(0, 1)
91
- coords = 2.0 * coords - 1.0 # [0,1] β†’ [-1,+1]
92
  coords = coords * ROPE_RESCALE
93
  return coords # [h*w, 2]
94
 
95
 
96
  def _build_rope(h_patches: int, w_patches: int,
97
  dtype: torch.dtype, device: torch.device):
98
- """Return (cos, sin) of shape [1, 1, h*w, HEAD_DIM] for broadcasting."""
99
- coords = _patch_coords_cached(h_patches, w_patches, str(device)) # [P, 2]
100
  inv_freq = 1.0 / (ROPE_THETA ** torch.arange(
101
- 0, 1, 4 / HEAD_DIM, dtype=torch.float32, device=device)) # [D/4]
102
- angles = 2 * math.pi * coords[:, :, None] * inv_freq[None, None, :] # [P, 2, D/4]
103
- angles = angles.flatten(1, 2).tile(2) # [P, D]
104
- cos = torch.cos(angles).to(dtype).unsqueeze(0).unsqueeze(0) # [1,1,P,D]
105
  sin = torch.sin(angles).to(dtype).unsqueeze(0).unsqueeze(0)
106
  return cos, sin
107
 
@@ -113,7 +113,6 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
113
 
114
  def _apply_rope(q: torch.Tensor, k: torch.Tensor,
115
  cos: torch.Tensor, sin: torch.Tensor):
116
- """Apply RoPE only to patch tokens (skip CLS + register prefix)."""
117
  n_pre = 1 + N_REGISTERS
118
  q_pre, q_pat = q[..., :n_pre, :], q[..., n_pre:, :]
119
  k_pre, k_pat = k[..., :n_pre, :], k[..., n_pre:, :]
@@ -123,7 +122,7 @@ def _apply_rope(q: torch.Tensor, k: torch.Tensor,
123
 
124
 
125
  # ---------------------------------------------------------------------------
126
- # Building blocks
127
  # ---------------------------------------------------------------------------
128
 
129
  class _Attention(nn.Module):
@@ -134,7 +133,7 @@ class _Attention(nn.Module):
134
  self.v_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
135
  self.o_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
136
 
137
- def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
138
  B, S, _ = x.shape
139
  q = self.q_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
140
  k = self.k_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
@@ -148,125 +147,259 @@ class _GatedMLP(nn.Module):
148
  def __init__(self):
149
  super().__init__()
150
  self.gate_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
151
- self.up_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
152
- self.down_proj = nn.Linear(D_FFN, D_MODEL, bias=True)
153
 
154
- def forward(self, x: torch.Tensor) -> torch.Tensor:
155
  return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
156
 
157
 
158
  class _Block(nn.Module):
159
  def __init__(self):
160
  super().__init__()
161
- self.norm1 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
162
- self.attention = _Attention()
163
  self.layer_scale1 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
164
- self.norm2 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
165
- self.mlp = _GatedMLP()
166
  self.layer_scale2 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
167
 
168
- def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
169
  x = x + self.attention(self.norm1(x), cos, sin) * self.layer_scale1
170
  x = x + self.mlp(self.norm2(x)) * self.layer_scale2
171
  return x
172
 
173
 
174
- # ---------------------------------------------------------------------------
175
- # Full backbone
176
- # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  class DINOv3ViTH(nn.Module):
179
  """DINOv3 ViT-H/16+ backbone.
180
 
181
- Accepts any H, W that are multiples of 16.
182
- Returns last_hidden_state [B, 1+R+P, D_MODEL].
183
  Token layout: [CLS, reg_0..reg_3, patch_0..patch_N].
184
-
185
- State-dict keys are intentionally identical to the HuggingFace
186
- transformers layout so .safetensors checkpoints load without remapping.
187
  """
188
 
189
  def __init__(self):
190
  super().__init__()
191
- # These names must match HF exactly
192
  self.embeddings = _Embeddings()
193
  self.layer = nn.ModuleList([_Block() for _ in range(N_LAYERS)])
194
- self.norm = nn.LayerNorm(D_MODEL, eps=LN_EPS)
195
-
196
- def _load_from_state_dict(self, state_dict, prefix, local_metadata,
197
- strict, missing_keys, unexpected_keys, error_msgs):
198
- # HF stores layer_scale as a sub-module with a "lambda1" parameter;
199
- # we store it as a plain Parameter directly on _Block.
200
- # Remap "layer.i.layer_scale{1,2}.lambda1" β†’ "layer.i.layer_scale{1,2}"
201
- for k in list(state_dict.keys()):
202
- if k.startswith(prefix) and ".layer_scale" in k and k.endswith(".lambda1"):
203
- new_k = k[:-len(".lambda1")]
204
- state_dict[new_k] = state_dict.pop(k)
205
- # Drop rope_embeddings buffer (computed on-the-fly)
206
- for k in list(state_dict.keys()):
207
- if k.startswith(prefix) and "rope_embeddings" in k:
208
- state_dict.pop(k)
209
- super()._load_from_state_dict(
210
- state_dict, prefix, local_metadata, strict,
211
- missing_keys, unexpected_keys, error_msgs)
212
-
213
- def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
214
- B, _, H, W = pixel_values.shape
215
- x = self.embeddings(pixel_values) # [B, 1+R+P, D]
216
 
 
 
 
217
  h_p, w_p = H // PATCH_SIZE, W // PATCH_SIZE
218
  cos, sin = _build_rope(h_p, w_p, x.dtype, pixel_values.device)
219
-
220
  for block in self.layer:
221
  x = block(x, cos, sin)
222
-
223
  return self.norm(x)
224
 
225
 
226
- class _Embeddings(nn.Module):
227
- """Patch + CLS + register token embeddings.
228
- Key names match HF: embeddings.cls_token, embeddings.register_tokens,
229
- embeddings.patch_embeddings.{weight,bias}.
 
 
 
 
 
 
230
  """
231
 
232
- def __init__(self):
 
233
  super().__init__()
234
- self.cls_token = nn.Parameter(torch.empty(1, 1, D_MODEL))
235
- self.mask_token = nn.Parameter(torch.zeros(1, 1, D_MODEL)) # unused at inference
236
- self.register_tokens = nn.Parameter(torch.empty(1, N_REGISTERS, D_MODEL))
237
- self.patch_embeddings = nn.Conv2d(3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
238
 
239
- def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
240
- B = pixel_values.shape[0]
241
- dtype = self.patch_embeddings.weight.dtype
242
- patches = self.patch_embeddings(pixel_values.to(dtype)).flatten(2).transpose(1, 2)
243
- cls = self.cls_token.expand(B, -1, -1)
244
- regs = self.register_tokens.expand(B, -1, -1)
245
- return torch.cat([cls, regs, patches], dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
 
248
  # =============================================================================
249
- # Tagger head
250
  # =============================================================================
251
 
252
  class DINOv3Tagger(nn.Module):
253
- """DINOv3 ViT-H/16+ backbone + linear projection head.
 
 
 
 
 
 
254
 
255
- features = concat(CLS, reg_0..reg_3) β†’ [B, (1+R)*D]
256
- projection: Linear β†’ [B, num_tags]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  """
 
 
 
 
 
 
 
 
 
 
 
258
 
259
- def __init__(self, num_tags: int, projection_bias: bool = False):
260
- super().__init__()
261
- self.backbone = DINOv3ViTH()
262
- self.projection = nn.Linear((1 + N_REGISTERS) * D_MODEL, num_tags, bias=projection_bias)
263
 
264
- def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
265
- hidden = self.backbone(pixel_values) # [B, S, D]
266
- cls = hidden[:, 0, :] # [B, D]
267
- regs = hidden[:, 1: 1 + N_REGISTERS, :].flatten(1) # [B, R*D]
268
- features = torch.cat([cls, regs], dim=-1) # [B, (1+R)*D]
269
- return self.projection(features.float()) # fp32 for stability
270
 
271
 
272
  # =============================================================================
@@ -274,7 +407,7 @@ class DINOv3Tagger(nn.Module):
274
  # =============================================================================
275
 
276
  _IMAGENET_MEAN = [0.485, 0.456, 0.406]
277
- _IMAGENET_STD = [0.229, 0.224, 0.225]
278
 
279
 
280
  def _snap(x: int, m: int) -> int:
@@ -291,12 +424,22 @@ def _open_image(source) -> Image.Image:
291
 
292
 
293
  def preprocess_image(source, max_size: int = 1024) -> torch.Tensor:
294
- """Load and preprocess an image β†’ [1, 3, H, W] float32, ImageNet-normalised."""
 
 
 
 
295
  img = _open_image(source)
296
  w, h = img.size
297
- scale = min(1.0, max_size / max(w, h))
298
- new_w = _snap(round(w * scale), PATCH_SIZE)
299
- new_h = _snap(round(h * scale), PATCH_SIZE)
 
 
 
 
 
 
300
  return v2.Compose([
301
  v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
302
  v2.ToImage(),
@@ -315,13 +458,15 @@ class Tagger:
315
  Parameters
316
  ----------
317
  checkpoint_path : str
318
- Path to a .safetensors or .pth checkpoint saved by TaggerTrainer.
319
  vocab_path : str
320
- Path to tagger_vocab.json ({"idx2tag": [...]}).
 
321
  device : str
322
- "cuda", "cuda:0", "cpu", etc.
323
  dtype : torch.dtype
324
- bfloat16 recommended on Ampere+; float16 for older GPUs; float32 for CPU.
 
325
  max_size : int
326
  Long-edge cap in pixels before feeding to the model.
327
  """
@@ -334,8 +479,13 @@ class Tagger:
334
  dtype: torch.dtype = torch.bfloat16,
335
  max_size: int = 1024,
336
  ):
337
- self.device = torch.device(device if torch.cuda.is_available() or device == "cpu" else "cpu")
338
- self.dtype = dtype
 
 
 
 
 
339
  self.max_size = max_size
340
 
341
  with open(vocab_path) as f:
@@ -344,36 +494,47 @@ class Tagger:
344
  self.num_tags = len(self.idx2tag)
345
  print(f"[Tagger] Vocabulary: {self.num_tags:,} tags")
346
 
347
- self.model = DINOv3Tagger(num_tags=self.num_tags)
348
-
349
  print(f"[Tagger] Loading checkpoint: {checkpoint_path}")
350
  if checkpoint_path.endswith((".safetensors", ".sft")):
351
- sd = load_file(checkpoint_path, device=str(self.device))
352
  else:
353
- sd = torch.load(checkpoint_path, map_location=str(self.device))
354
-
355
- missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
356
- if missing:
357
- print(f"[Tagger] Missing keys ({len(missing)}): {missing[:5]}{'...' if len(missing) > 5 else ''}")
358
- if unexpected:
359
- print(f"[Tagger] Unexpected keys ({len(unexpected)}): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
360
-
361
- self.model.backbone = self.model.backbone.to(dtype=dtype)
362
- self.model = self.model.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  self.model.eval()
364
- print(f"[Tagger] Ready on {self.device} ({dtype})")
365
 
366
  @torch.no_grad()
367
  def predict(self, image, topk: int | None = 30,
368
  threshold: float | None = None) -> list[tuple[str, float]]:
369
- """Tag a single image (local path or URL).
370
- Specify either topk OR threshold. Returns [(tag, score), ...] desc."""
371
  if topk is None and threshold is None:
372
  topk = 30
373
 
374
  pv = preprocess_image(image, max_size=self.max_size).to(self.device)
375
- with torch.autocast(device_type=self.device.type, dtype=self.dtype):
376
- logits = self.model(pv)[0]
377
  scores = torch.sigmoid(logits.float())
378
 
379
  if topk is not None:
@@ -381,17 +542,18 @@ class Tagger:
381
  else:
382
  assert threshold is not None
383
  indices = (scores >= threshold).nonzero(as_tuple=True)[0]
384
- values = scores[indices]
385
- order = values.argsort(descending=True)
386
  indices, values = indices[order], values[order]
387
 
388
- return [(self.idx2tag[i], float(v)) for i, v in zip(indices.tolist(), values.tolist())]
 
389
 
390
  @torch.no_grad()
391
  def predict_batch(self, images, topk: int | None = 30,
392
- threshold: float | None = None) -> list[list[tuple[str, float]]]:
393
- """Tag multiple images (processed individually for mixed resolutions)."""
394
- return [self.predict(img, topk=topk, threshold=threshold) for img in images]
395
 
396
 
397
  # =============================================================================
@@ -399,17 +561,20 @@ class Tagger:
399
  # =============================================================================
400
 
401
  def _fmt_pretty(path: str, results) -> str:
402
- lines = [f"\n{'─' * 60}", f" {path}", f"{'─' * 60}"]
403
  for rank, (tag, score) in enumerate(results, 1):
404
  bar = "β–ˆ" * int(score * 20)
405
- lines.append(f" {rank:>3}. {score:.3f} {bar:<20} {tag}")
406
  return "\n".join(lines)
407
 
 
408
  def _fmt_tags(results) -> str:
409
  return ", ".join(tag for tag, _ in results)
410
 
 
411
  def _fmt_json(path: str, results) -> dict:
412
- return {"file": path, "tags": [{"tag": t, "score": round(s, 4)} for t, s in results]}
 
413
 
414
 
415
  # =============================================================================
@@ -418,28 +583,40 @@ def _fmt_json(path: str, results) -> dict:
418
 
419
  def main():
420
  parser = argparse.ArgumentParser(
421
- description="DINOv3 ViT-H/16+ tagger inference (standalone, no transformers dep)",
422
  formatter_class=argparse.RawDescriptionHelpFormatter,
423
  )
424
- parser.add_argument("--checkpoint", required=True, help="Path to .safetensors or .pth checkpoint")
425
- parser.add_argument("--vocab", required=True, help="Path to tagger_vocab.json")
426
- parser.add_argument("--images", nargs="+", required=True, help="Image paths and/or http(s) URLs")
427
- parser.add_argument("--device", default="cuda", help="Device: cuda, cuda:0, cpu, … (default: cuda)")
 
 
 
 
428
  parser.add_argument("--max-size", type=int, default=1024,
429
- help="Long-edge cap in pixels, multiple of 16 (default: 1024)")
430
 
431
  mode = parser.add_mutually_exclusive_group()
432
- mode.add_argument("--topk", type=int, default=30, help="Return top-k tags (default: 30)")
433
- mode.add_argument("--threshold", type=float, help="Return all tags with score >= threshold")
 
 
434
 
435
  parser.add_argument("--format", choices=["pretty", "tags", "json"],
436
  default="pretty", help="Output format (default: pretty)")
437
  args = parser.parse_args()
438
 
439
- tagger = Tagger(checkpoint_path=args.checkpoint, vocab_path=args.vocab,
440
- device=args.device, max_size=args.max_size)
 
 
 
 
441
 
442
- topk, threshold = (None, args.threshold) if args.threshold else (args.topk, None)
 
 
443
  json_out = []
444
 
445
  for src in args.images:
@@ -448,13 +625,16 @@ def main():
448
  print(f"[warning] File not found: {src}", file=sys.stderr)
449
  continue
450
  results = tagger.predict(src, topk=topk, threshold=threshold)
451
- if args.format == "pretty": print(_fmt_pretty(src, results))
452
- elif args.format == "tags": print(_fmt_tags(results))
453
- elif args.format == "json": json_out.append(_fmt_json(src, results))
 
 
 
454
 
455
  if args.format == "json":
456
  print(json.dumps(json_out, indent=2, ensure_ascii=False))
457
 
458
 
459
  if __name__ == "__main__":
460
- main()
 
64
  # All hyperparameters match facebook/dinov3-vith16plus-pretrain-lvd1689m
65
  # =============================================================================
66
 
67
+ D_MODEL = 1280
68
+ N_HEADS = 20
69
+ HEAD_DIM = D_MODEL // N_HEADS # 64
70
+ N_LAYERS = 32
71
+ D_FFN = 5120
72
  N_REGISTERS = 4
73
+ PATCH_SIZE = 16
74
+ ROPE_THETA = 100.0
75
+ ROPE_RESCALE = 2.0
76
+ LN_EPS = 1e-5
77
+ LAYERSCALE = 1.0
78
+
79
+ FEATURE_DIM = (1 + N_REGISTERS) * D_MODEL # 6400
80
 
81
 
82
  # ---------------------------------------------------------------------------
 
85
 
86
  @lru_cache(maxsize=32)
87
  def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
 
88
  device = torch.device(device_str)
89
  cy = torch.arange(0.5, h, dtype=torch.float32, device=device) / h
90
  cx = torch.arange(0.5, w, dtype=torch.float32, device=device) / w
91
  coords = torch.stack(torch.meshgrid(cy, cx, indexing="ij"), dim=-1).flatten(0, 1)
92
+ coords = 2.0 * coords - 1.0
93
  coords = coords * ROPE_RESCALE
94
  return coords # [h*w, 2]
95
 
96
 
97
  def _build_rope(h_patches: int, w_patches: int,
98
  dtype: torch.dtype, device: torch.device):
99
+ coords = _patch_coords_cached(h_patches, w_patches, str(device))
 
100
  inv_freq = 1.0 / (ROPE_THETA ** torch.arange(
101
+ 0, 1, 4 / HEAD_DIM, dtype=torch.float32, device=device))
102
+ angles = 2 * math.pi * coords[:, :, None] * inv_freq[None, None, :]
103
+ angles = angles.flatten(1, 2).tile(2)
104
+ cos = torch.cos(angles).to(dtype).unsqueeze(0).unsqueeze(0)
105
  sin = torch.sin(angles).to(dtype).unsqueeze(0).unsqueeze(0)
106
  return cos, sin
107
 
 
113
 
114
  def _apply_rope(q: torch.Tensor, k: torch.Tensor,
115
  cos: torch.Tensor, sin: torch.Tensor):
 
116
  n_pre = 1 + N_REGISTERS
117
  q_pre, q_pat = q[..., :n_pre, :], q[..., n_pre:, :]
118
  k_pre, k_pat = k[..., :n_pre, :], k[..., n_pre:, :]
 
122
 
123
 
124
  # ---------------------------------------------------------------------------
125
+ # Transformer blocks
126
  # ---------------------------------------------------------------------------
127
 
128
  class _Attention(nn.Module):
 
133
  self.v_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
134
  self.o_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
135
 
136
+ def forward(self, x, cos, sin):
137
  B, S, _ = x.shape
138
  q = self.q_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
139
  k = self.k_proj(x).view(B, S, N_HEADS, HEAD_DIM).transpose(1, 2)
 
147
  def __init__(self):
148
  super().__init__()
149
  self.gate_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
150
+ self.up_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
151
+ self.down_proj = nn.Linear(D_FFN, D_MODEL, bias=True)
152
 
153
+ def forward(self, x):
154
  return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
155
 
156
 
157
  class _Block(nn.Module):
158
  def __init__(self):
159
  super().__init__()
160
+ self.norm1 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
161
+ self.attention = _Attention()
162
  self.layer_scale1 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
163
+ self.norm2 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
164
+ self.mlp = _GatedMLP()
165
  self.layer_scale2 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
166
 
167
+ def forward(self, x, cos, sin):
168
  x = x + self.attention(self.norm1(x), cos, sin) * self.layer_scale1
169
  x = x + self.mlp(self.norm2(x)) * self.layer_scale2
170
  return x
171
 
172
 
173
+ class _Embeddings(nn.Module):
174
+ def __init__(self):
175
+ super().__init__()
176
+ # zeros() rather than empty() so a forgotten checkpoint key fails
177
+ # predictably instead of producing undefined outputs.
178
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
179
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
180
+ self.register_tokens = nn.Parameter(torch.zeros(1, N_REGISTERS, D_MODEL))
181
+ self.patch_embeddings = nn.Conv2d(
182
+ 3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
183
+
184
+ def forward(self, pixel_values):
185
+ B = pixel_values.shape[0]
186
+ dtype = self.patch_embeddings.weight.dtype
187
+ patches = self.patch_embeddings(
188
+ pixel_values.to(dtype)).flatten(2).transpose(1, 2)
189
+ cls = self.cls_token.expand(B, -1, -1)
190
+ regs = self.register_tokens.expand(B, -1, -1)
191
+ return torch.cat([cls, regs, patches], dim=1)
192
+
193
 
194
  class DINOv3ViTH(nn.Module):
195
  """DINOv3 ViT-H/16+ backbone.
196
 
 
 
197
  Token layout: [CLS, reg_0..reg_3, patch_0..patch_N].
198
+ Returns last_hidden_state [B, 1+R+P, D_MODEL].
 
 
199
  """
200
 
201
  def __init__(self):
202
  super().__init__()
 
203
  self.embeddings = _Embeddings()
204
  self.layer = nn.ModuleList([_Block() for _ in range(N_LAYERS)])
205
+ self.norm = nn.LayerNorm(D_MODEL, eps=LN_EPS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
+ def forward(self, pixel_values):
208
+ _, _, H, W = pixel_values.shape
209
+ x = self.embeddings(pixel_values)
210
  h_p, w_p = H // PATCH_SIZE, W // PATCH_SIZE
211
  cos, sin = _build_rope(h_p, w_p, x.dtype, pixel_values.device)
 
212
  for block in self.layer:
213
  x = block(x, cos, sin)
 
214
  return self.norm(x)
215
 
216
 
217
+ # =============================================================================
218
+ # Head β€” auto-detected from the checkpoint
219
+ # =============================================================================
220
+
221
+ class _LowRankHead(nn.Module):
222
+ """Two-matrix low-rank projection head.
223
+
224
+ features (in_dim)
225
+ β†’ Linear(in_dim, rank, bias=?)
226
+ β†’ Linear(rank, num_tags, bias=?)
227
  """
228
 
229
+ def __init__(self, in_dim: int, rank: int, num_tags: int,
230
+ down_bias: bool, up_bias: bool):
231
  super().__init__()
232
+ self.proj_down = nn.Linear(in_dim, rank, bias=down_bias)
233
+ self.proj_up = nn.Linear(rank, num_tags, bias=up_bias)
 
 
234
 
235
+ def forward(self, x):
236
+ return self.proj_up(self.proj_down(x))
237
+
238
+
239
+ def _build_head_from_checkpoint(
240
+ head_sd: dict,
241
+ in_dim: int,
242
+ num_tags: int,
243
+ ) -> tuple[nn.Module, dict]:
244
+ """Inspect head_sd and build a matching Module.
245
+
246
+ Supports two layouts, in order of preference:
247
+ 1. Single linear β€” any ``*.weight`` with shape [num_tags, in_dim]
248
+ 2. Low-rank pair (2 mats) β€” one ``*.weight`` [rank, in_dim] plus
249
+ one ``*.weight`` [num_tags, rank]
250
+
251
+ Returns (module, remapped_state_dict) where the remapped state dict
252
+ matches the module's own key names so strict loading works.
253
+ """
254
+ weights_2d = [(k, v) for k, v in head_sd.items()
255
+ if k.endswith(".weight") and v.ndim == 2]
256
+
257
+ # --- Case 1: single dense linear ---------------------------------------
258
+ singles = [(k, v) for k, v in weights_2d
259
+ if tuple(v.shape) == (num_tags, in_dim)]
260
+ if len(weights_2d) <= 2 and len(singles) == 1:
261
+ wkey, wval = singles[0]
262
+ base = wkey[:-len(".weight")]
263
+ bias_key = base + ".bias"
264
+ has_bias = bias_key in head_sd
265
+ module = nn.Linear(in_dim, num_tags, bias=has_bias)
266
+ remapped = {"weight": wval}
267
+ if has_bias:
268
+ remapped["bias"] = head_sd[bias_key]
269
+ # Sanity check: no extra keys we don't understand
270
+ expected_src = {wkey} | ({bias_key} if has_bias else set())
271
+ extra = set(head_sd) - expected_src
272
+ if extra:
273
+ raise RuntimeError(
274
+ f"Head has single-linear shape but extra unknown keys: {sorted(extra)}")
275
+ return module, remapped
276
+
277
+ # --- Case 2: low-rank pair ---------------------------------------------
278
+ down = None # (key, tensor) with shape [rank, in_dim]
279
+ up = None # (key, tensor) with shape [num_tags, rank]
280
+ for k, v in weights_2d:
281
+ if v.shape[1] == in_dim and v.shape[0] != num_tags:
282
+ down = (k, v)
283
+ elif v.shape[0] == num_tags and v.shape[1] != in_dim:
284
+ up = (k, v)
285
+
286
+ if down is not None and up is not None:
287
+ rank_down = down[1].shape[0]
288
+ rank_up = up[1].shape[1]
289
+ if rank_down != rank_up:
290
+ raise RuntimeError(
291
+ f"Low-rank head: inner dims disagree "
292
+ f"(down out={rank_down}, up in={rank_up})")
293
+
294
+ down_key, down_w = down
295
+ up_key, up_w = up
296
+ down_base = down_key[:-len(".weight")]
297
+ up_base = up_key[:-len(".weight")]
298
+ down_bias_key = down_base + ".bias"
299
+ up_bias_key = up_base + ".bias"
300
+ has_down_bias = down_bias_key in head_sd
301
+ has_up_bias = up_bias_key in head_sd
302
+
303
+ module = _LowRankHead(
304
+ in_dim=in_dim,
305
+ rank=rank_down,
306
+ num_tags=num_tags,
307
+ down_bias=has_down_bias,
308
+ up_bias=has_up_bias,
309
+ )
310
+ remapped = {
311
+ "proj_down.weight": down_w,
312
+ "proj_up.weight": up_w,
313
+ }
314
+ if has_down_bias:
315
+ remapped["proj_down.bias"] = head_sd[down_bias_key]
316
+ if has_up_bias:
317
+ remapped["proj_up.bias"] = head_sd[up_bias_key]
318
+
319
+ # Sanity check
320
+ expected_src = {down_key, up_key}
321
+ if has_down_bias:
322
+ expected_src.add(down_bias_key)
323
+ if has_up_bias:
324
+ expected_src.add(up_bias_key)
325
+ extra = set(head_sd) - expected_src
326
+ if extra:
327
+ raise RuntimeError(
328
+ f"Low-rank head detected but checkpoint has extra unknown "
329
+ f"head keys: {sorted(extra)}")
330
+
331
+ print(f"[Tagger] Detected low-rank head: "
332
+ f"in_dim={in_dim}, rank={rank_down}, num_tags={num_tags} "
333
+ f"(down_bias={has_down_bias}, up_bias={has_up_bias})")
334
+ return module, remapped
335
+
336
+ raise RuntimeError(
337
+ "Could not infer head architecture from checkpoint. "
338
+ f"Non-backbone keys found: {sorted(head_sd.keys())}"
339
+ )
340
 
341
 
342
  # =============================================================================
343
+ # Tagger wrapper module
344
  # =============================================================================
345
 
346
  class DINOv3Tagger(nn.Module):
347
+ """Backbone + head. The head is attached after the checkpoint is
348
+ inspected (so we can build the right shape)."""
349
+
350
+ def __init__(self):
351
+ super().__init__()
352
+ self.backbone = DINOv3ViTH()
353
+ self.head: nn.Module | None = None # attached by Tagger
354
 
355
+ def forward(self, pixel_values):
356
+ hidden = self.backbone(pixel_values)
357
+ cls = hidden[:, 0, :]
358
+ regs = hidden[:, 1: 1 + N_REGISTERS, :].flatten(1)
359
+ features = torch.cat([cls, regs], dim=-1).float() # fp32 for head
360
+ return self.head(features)
361
+
362
+
363
+ # =============================================================================
364
+ # Checkpoint loading helpers
365
+ # =============================================================================
366
+
367
+ def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]:
368
+ """Split full state dict into (backbone_sd, head_sd), stripping the
369
+ ``backbone.`` prefix and applying the remaps needed to match
370
+ ``DINOv3ViTH``'s parameter layout:
371
+
372
+ 1. ``backbone.model.layer.N.*`` β†’ ``layer.N.*``
373
+ (the checkpoint has an HF-style intermediate ``model`` wrapper
374
+ that our flat backbone class does not)
375
+ 2. ``...layer_scale{1,2}.lambda1`` β†’ ``...layer_scale{1,2}``
376
+ (HF stores layer_scale as a sub-module with a ``lambda1``
377
+ parameter; we use a plain ``nn.Parameter``)
378
+ 3. Drop any ``rope_embeddings`` buffers (recomputed on the fly)
379
  """
380
+ backbone_sd: dict = {}
381
+ head_sd: dict = {}
382
+ for k, v in sd.items():
383
+ if k.startswith("backbone."):
384
+ nk = k[len("backbone."):]
385
+ # Remap (1): strip intermediate "model." before "layer."
386
+ if nk.startswith("model.layer."):
387
+ nk = nk[len("model."):]
388
+ backbone_sd[nk] = v
389
+ else:
390
+ head_sd[k] = v
391
 
392
+ # Remap (2): layer.N.layer_scale{1,2}.lambda1 β†’ layer.N.layer_scale{1,2}
393
+ for k in list(backbone_sd.keys()):
394
+ if ".layer_scale" in k and k.endswith(".lambda1"):
395
+ backbone_sd[k[:-len(".lambda1")]] = backbone_sd.pop(k)
396
 
397
+ # Remap (3): drop rope buffers (recomputed on the fly)
398
+ for k in list(backbone_sd.keys()):
399
+ if "rope_embeddings" in k:
400
+ backbone_sd.pop(k)
401
+
402
+ return backbone_sd, head_sd
403
 
404
 
405
  # =============================================================================
 
407
  # =============================================================================
408
 
409
  _IMAGENET_MEAN = [0.485, 0.456, 0.406]
410
+ _IMAGENET_STD = [0.229, 0.224, 0.225]
411
 
412
 
413
  def _snap(x: int, m: int) -> int:
 
424
 
425
 
426
  def preprocess_image(source, max_size: int = 1024) -> torch.Tensor:
427
+ """Load and preprocess an image β†’ [1, 3, H, W] float32, ImageNet-normalised.
428
+
429
+ Aspect ratio is preserved: a single scale factor is chosen so that the
430
+ long edge fits inside max_size after snapping to a PATCH_SIZE multiple.
431
+ """
432
  img = _open_image(source)
433
  w, h = img.size
434
+
435
+ # Target long-edge (snapped to patch multiple).
436
+ long_edge = max(w, h)
437
+ target_long = _snap(min(long_edge, max_size), PATCH_SIZE)
438
+ scale = target_long / long_edge
439
+
440
+ new_w = _snap(max(PATCH_SIZE, round(w * scale)), PATCH_SIZE)
441
+ new_h = _snap(max(PATCH_SIZE, round(h * scale)), PATCH_SIZE)
442
+
443
  return v2.Compose([
444
  v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
445
  v2.ToImage(),
 
458
  Parameters
459
  ----------
460
  checkpoint_path : str
461
+ Path to a .safetensors or .pt/.pth checkpoint.
462
  vocab_path : str
463
+ Path to tagger_vocab.json or tagger_vocab_with_categories.json
464
+ (either must contain an ``idx2tag`` list).
465
  device : str
466
+ "cuda", "cuda:0", "cpu", ...
467
  dtype : torch.dtype
468
+ Backbone precision. bfloat16 recommended on Ampere+, float16 for
469
+ older GPUs, float32 for CPU. The head always runs in fp32.
470
  max_size : int
471
  Long-edge cap in pixels before feeding to the model.
472
  """
 
479
  dtype: torch.dtype = torch.bfloat16,
480
  max_size: int = 1024,
481
  ):
482
+ want_cuda = device.startswith("cuda")
483
+ if want_cuda and not torch.cuda.is_available():
484
+ print("[Tagger] CUDA not available, falling back to CPU")
485
+ device = "cpu"
486
+ dtype = torch.float32
487
+ self.device = torch.device(device)
488
+ self.dtype = dtype
489
  self.max_size = max_size
490
 
491
  with open(vocab_path) as f:
 
494
  self.num_tags = len(self.idx2tag)
495
  print(f"[Tagger] Vocabulary: {self.num_tags:,} tags")
496
 
497
+ # --- Load checkpoint to CPU first so we can inspect shapes ---------
 
498
  print(f"[Tagger] Loading checkpoint: {checkpoint_path}")
499
  if checkpoint_path.endswith((".safetensors", ".sft")):
500
+ sd = load_file(checkpoint_path, device="cpu")
501
  else:
502
+ sd = torch.load(checkpoint_path, map_location="cpu")
503
+
504
+ backbone_sd, head_sd = _split_and_clean_state_dict(sd)
505
+
506
+ if not head_sd:
507
+ raise RuntimeError(
508
+ "Checkpoint contains no non-backbone keys β€” cannot build head.")
509
+
510
+ # --- Build model, inferring head shape from the checkpoint --------
511
+ self.model = DINOv3Tagger()
512
+ head_module, head_sd_remapped = _build_head_from_checkpoint(
513
+ head_sd, in_dim=FEATURE_DIM, num_tags=self.num_tags,
514
+ )
515
+ self.model.head = head_module
516
+
517
+ # --- Strict load β€” mismatches raise instead of silently passing ----
518
+ self.model.backbone.load_state_dict(backbone_sd, strict=True)
519
+ self.model.head.load_state_dict(head_sd_remapped, strict=True)
520
+
521
+ # --- Move to device. Backbone β†’ bf16/fp16; head stays fp32. --------
522
+ self.model.backbone = self.model.backbone.to(
523
+ device=self.device, dtype=dtype)
524
+ self.model.head = self.model.head.to(
525
+ device=self.device, dtype=torch.float32)
526
  self.model.eval()
527
+ print(f"[Tagger] Ready on {self.device} (backbone={dtype}, head=fp32)")
528
 
529
  @torch.no_grad()
530
  def predict(self, image, topk: int | None = 30,
531
  threshold: float | None = None) -> list[tuple[str, float]]:
532
+ """Tag a single image (local path or URL)."""
 
533
  if topk is None and threshold is None:
534
  topk = 30
535
 
536
  pv = preprocess_image(image, max_size=self.max_size).to(self.device)
537
+ logits = self.model(pv)[0]
 
538
  scores = torch.sigmoid(logits.float())
539
 
540
  if topk is not None:
 
542
  else:
543
  assert threshold is not None
544
  indices = (scores >= threshold).nonzero(as_tuple=True)[0]
545
+ values = scores[indices]
546
+ order = values.argsort(descending=True)
547
  indices, values = indices[order], values[order]
548
 
549
+ return [(self.idx2tag[i], float(v))
550
+ for i, v in zip(indices.tolist(), values.tolist())]
551
 
552
  @torch.no_grad()
553
  def predict_batch(self, images, topk: int | None = 30,
554
+ threshold: float | None = None):
555
+ return [self.predict(img, topk=topk, threshold=threshold)
556
+ for img in images]
557
 
558
 
559
  # =============================================================================
 
561
  # =============================================================================
562
 
563
  def _fmt_pretty(path: str, results) -> str:
564
+ lines = [f"\n{'─' * 60}", f" {path}", f"{'─' * 60}"]
565
  for rank, (tag, score) in enumerate(results, 1):
566
  bar = "β–ˆ" * int(score * 20)
567
+ lines.append(f" {rank:>3}. {score:.3f} {bar:<20} {tag}")
568
  return "\n".join(lines)
569
 
570
+
571
  def _fmt_tags(results) -> str:
572
  return ", ".join(tag for tag, _ in results)
573
 
574
+
575
  def _fmt_json(path: str, results) -> dict:
576
+ return {"file": path,
577
+ "tags": [{"tag": t, "score": round(s, 4)} for t, s in results]}
578
 
579
 
580
  # =============================================================================
 
583
 
584
  def main():
585
  parser = argparse.ArgumentParser(
586
+ description="DINOv3 ViT-H/16+ tagger inference (standalone)",
587
  formatter_class=argparse.RawDescriptionHelpFormatter,
588
  )
589
+ parser.add_argument("--checkpoint", required=True,
590
+ help="Path to .safetensors or .pt checkpoint")
591
+ parser.add_argument("--vocab", required=True,
592
+ help="Path to tagger_vocab*.json")
593
+ parser.add_argument("--images", nargs="+", required=True,
594
+ help="Image paths and/or http(s) URLs")
595
+ parser.add_argument("--device", default="cuda",
596
+ help="Device: cuda, cuda:0, cpu (default: cuda)")
597
  parser.add_argument("--max-size", type=int, default=1024,
598
+ help="Long-edge cap in pixels (default: 1024)")
599
 
600
  mode = parser.add_mutually_exclusive_group()
601
+ mode.add_argument("--topk", type=int, default=30,
602
+ help="Return top-k tags (default: 30)")
603
+ mode.add_argument("--threshold", type=float,
604
+ help="Return all tags with score >= threshold")
605
 
606
  parser.add_argument("--format", choices=["pretty", "tags", "json"],
607
  default="pretty", help="Output format (default: pretty)")
608
  args = parser.parse_args()
609
 
610
+ tagger = Tagger(
611
+ checkpoint_path=args.checkpoint,
612
+ vocab_path=args.vocab,
613
+ device=args.device,
614
+ max_size=args.max_size,
615
+ )
616
 
617
+ topk, threshold = (
618
+ (None, args.threshold) if args.threshold else (args.topk, None)
619
+ )
620
  json_out = []
621
 
622
  for src in args.images:
 
625
  print(f"[warning] File not found: {src}", file=sys.stderr)
626
  continue
627
  results = tagger.predict(src, topk=topk, threshold=threshold)
628
+ if args.format == "pretty":
629
+ print(_fmt_pretty(src, results))
630
+ elif args.format == "tags":
631
+ print(_fmt_tags(results))
632
+ elif args.format == "json":
633
+ json_out.append(_fmt_json(src, results))
634
 
635
  if args.format == "json":
636
  print(json.dumps(json_out, indent=2, ensure_ascii=False))
637
 
638
 
639
  if __name__ == "__main__":
640
+ main()