lodestones committed on
Commit
99a41dc
·
verified ·
1 Parent(s): 85a4088

Update inference_tagger_standalone.py

Browse files
Files changed (1) hide show
  1. inference_tagger_standalone.py +79 -0
inference_tagger_standalone.py CHANGED
@@ -213,6 +213,20 @@ class DINOv3ViTH(nn.Module):
213
  x = block(x, cos, sin)
214
  return self.norm(x)
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
  # =============================================================================
218
  # Head — auto-detected from the checkpoint
@@ -526,6 +540,71 @@ class Tagger:
526
  self.model.eval()
527
  print(f"[Tagger] Ready on {self.device} (backbone={dtype}, head=fp32)")
528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
  @torch.no_grad()
530
  def predict(self, image, topk: int | None = 30,
531
  threshold: float | None = None) -> list[tuple[str, float]]:
 
213
  x = block(x, cos, sin)
214
  return self.norm(x)
215
 
216
+ def get_image_tokens(self, pixel_values):
217
+ """Return patch tokens only (no CLS/registers) as [B, h_p*w_p, D_MODEL]
218
+ and the spatial grid dimensions (h_p, w_p)."""
219
+ _, _, H, W = pixel_values.shape
220
+ h_p, w_p = H // PATCH_SIZE, W // PATCH_SIZE
221
+ x = self.embeddings(pixel_values)
222
+ cos, sin = _build_rope(h_p, w_p, x.dtype, pixel_values.device)
223
+ for block in self.layer:
224
+ x = block(x, cos, sin)
225
+ x = self.norm(x)
226
+ # token layout: [CLS, reg_0..reg_R-1, patch_0..patch_N]
227
+ patch_tokens = x[:, 1 + N_REGISTERS:, :] # [B, h_p*w_p, D_MODEL]
228
+ return patch_tokens, h_p, w_p
229
+
230
 
231
  # =============================================================================
232
  # Head — auto-detected from the checkpoint
 
540
  self.model.eval()
541
  print(f"[Tagger] Ready on {self.device} (backbone={dtype}, head=fp32)")
542
 
543
+ @torch.no_grad()
544
+ def embed_pca(
545
+ self,
546
+ image,
547
+ n_components: int = 3,
548
+ max_size: int | None = None,
549
+ ) -> "Image.Image":
550
+ """Run PCA on the patch-token features of *image* and return a
551
+ false-colour RGB PIL image where R/G/B channels correspond to the
552
+ first three principal components, each normalised to [0, 255].
553
+
554
+ Parameters
555
+ ----------
556
+ image :
557
+ Local path, URL, or PIL.Image.Image.
558
+ n_components :
559
+ Number of PCA components (must be 3 for RGB output).
560
+ max_size :
561
+ Long-edge cap in pixels (defaults to ``self.max_size``).
562
+ """
563
+ if n_components != 3:
564
+ raise ValueError("n_components must be 3 for false-colour RGB output")
565
+ if max_size is None:
566
+ max_size = self.max_size
567
+
568
+ if isinstance(image, Image.Image):
569
+ img = image.convert("RGB")
570
+ w, h = img.size
571
+ scale = min(1.0, max_size / max(w, h))
572
+ new_w = _snap(round(w * scale), PATCH_SIZE)
573
+ new_h = _snap(round(h * scale), PATCH_SIZE)
574
+ pv = v2.Compose([
575
+ v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
576
+ v2.ToImage(),
577
+ v2.ToDtype(torch.float32, scale=True),
578
+ v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
579
+ ])(img).unsqueeze(0).to(self.device)
580
+ else:
581
+ pv = preprocess_image(image, max_size=max_size).to(self.device)
582
+
583
+ with torch.autocast(device_type=self.device.type, dtype=self.dtype):
584
+ patch_tokens, h_p, w_p = self.model.backbone.get_image_tokens(pv)
585
+
586
+ # patch_tokens: [1, h_p*w_p, D_MODEL] → [N, D]
587
+ tokens = patch_tokens[0].float() # fp32 for PCA
588
+
589
+ # Centre
590
+ mean = tokens.mean(dim=0, keepdim=True)
591
+ tokens_c = tokens - mean
592
+
593
+ # PCA via SVD (economy)
594
+ _, _, Vt = torch.linalg.svd(tokens_c, full_matrices=False)
595
+ components = Vt[:n_components] # [3, D]
596
+ projected = tokens_c @ components.T # [N, 3]
597
+
598
+ # Normalise each component to [0, 1]
599
+ lo = projected.min(dim=0).values
600
+ hi = projected.max(dim=0).values
601
+ projected = (projected - lo) / (hi - lo + 1e-8)
602
+
603
+ # Reshape to spatial grid and convert to uint8 PIL image
604
+ rgb = projected.reshape(h_p, w_p, 3).cpu().numpy()
605
+ rgb_uint8 = (rgb * 255).clip(0, 255).astype("uint8")
606
+ return Image.fromarray(rgb_uint8, mode="RGB")
607
+
608
  @torch.no_grad()
609
  def predict(self, image, topk: int | None = 30,
610
  threshold: float | None = None) -> list[tuple[str, float]]: