Update inference_tagger_standalone.py
Browse files
inference_tagger_standalone.py
CHANGED
|
@@ -213,6 +213,20 @@ class DINOv3ViTH(nn.Module):
|
|
| 213 |
x = block(x, cos, sin)
|
| 214 |
return self.norm(x)
|
| 215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
# =============================================================================
|
| 218 |
# Head — auto-detected from the checkpoint
|
|
@@ -526,6 +540,71 @@ class Tagger:
|
|
| 526 |
self.model.eval()
|
| 527 |
print(f"[Tagger] Ready on {self.device} (backbone={dtype}, head=fp32)")
|
| 528 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
@torch.no_grad()
|
| 530 |
def predict(self, image, topk: int | None = 30,
|
| 531 |
threshold: float | None = None) -> list[tuple[str, float]]:
|
|
|
|
| 213 |
x = block(x, cos, sin)
|
| 214 |
return self.norm(x)
|
| 215 |
|
| 216 |
+
def get_image_tokens(self, pixel_values):
    """Return the spatial patch tokens of *pixel_values*.

    Runs the full backbone (patch embedding, RoPE-conditioned
    transformer blocks, final norm) and strips the CLS and register
    tokens, keeping only the patch tokens.

    Parameters
    ----------
    pixel_values :
        Image batch of shape [B, C, H, W]; H and W are assumed to be
        multiples of PATCH_SIZE (callers snap sizes accordingly).

    Returns
    -------
    tuple
        ``(patch_tokens, h_p, w_p)`` where ``patch_tokens`` has shape
        [B, h_p * w_p, D_MODEL] and (h_p, w_p) is the patch grid.
    """
    _, _, height, width = pixel_values.shape
    grid_h, grid_w = height // PATCH_SIZE, width // PATCH_SIZE

    hidden = self.embeddings(pixel_values)
    cos, sin = _build_rope(grid_h, grid_w, hidden.dtype, pixel_values.device)
    for blk in self.layer:
        hidden = blk(hidden, cos, sin)
    hidden = self.norm(hidden)

    # Token layout is [CLS, reg_0..reg_{R-1}, patch_0..patch_{N-1}];
    # drop the leading CLS + register tokens to keep only patches.
    spatial = hidden[:, 1 + N_REGISTERS:, :]  # [B, h_p*w_p, D_MODEL]
    return spatial, grid_h, grid_w
|
| 229 |
+
|
| 230 |
|
| 231 |
# =============================================================================
|
| 232 |
# Head — auto-detected from the checkpoint
|
|
|
|
| 540 |
self.model.eval()
|
| 541 |
print(f"[Tagger] Ready on {self.device} (backbone={dtype}, head=fp32)")
|
| 542 |
|
| 543 |
+
@torch.no_grad()
def embed_pca(
    self,
    image,
    n_components: int = 3,
    max_size: int | None = None,
) -> "Image.Image":
    """Run PCA on the patch-token features of *image* and return a
    false-colour RGB PIL image where R/G/B channels correspond to the
    first three principal components, each normalised to [0, 255].

    The output is at patch-grid resolution (w_p x h_p), i.e. one pixel
    per ViT patch — upscale it for viewing if needed.

    Parameters
    ----------
    image :
        Local path, URL, or PIL.Image.Image.
    n_components :
        Number of PCA components (must be 3 for RGB output).
    max_size :
        Long-edge cap in pixels (defaults to ``self.max_size``).

    Raises
    ------
    ValueError
        If ``n_components`` is not 3.
    """
    if n_components != 3:
        raise ValueError("n_components must be 3 for false-colour RGB output")
    if max_size is None:
        max_size = self.max_size

    if isinstance(image, Image.Image):
        img = image.convert("RGB")
        w, h = img.size
        scale = min(1.0, max_size / max(w, h))
        # Clamp each snapped edge to at least one patch: for tiny inputs,
        # snapping round(w * scale) to a multiple of PATCH_SIZE can
        # collapse to 0, which would make Resize produce a zero-sized
        # tensor and crash downstream.
        new_w = max(PATCH_SIZE, _snap(round(w * scale), PATCH_SIZE))
        new_h = max(PATCH_SIZE, _snap(round(h * scale), PATCH_SIZE))
        pv = v2.Compose([
            v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ])(img).unsqueeze(0).to(self.device)
    else:
        # Path / URL: reuse the shared preprocessing pipeline.
        pv = preprocess_image(image, max_size=max_size).to(self.device)

    # Backbone forward under autocast (same precision policy as predict).
    with torch.autocast(device_type=self.device.type, dtype=self.dtype):
        patch_tokens, h_p, w_p = self.model.backbone.get_image_tokens(pv)

    # patch_tokens: [1, h_p*w_p, D_MODEL] -> [N, D]; fp32 for stable PCA.
    tokens = patch_tokens[0].float()

    # Centre the features, then PCA via economy SVD: the rows of Vt are
    # the principal directions, ordered by singular value.
    tokens_c = tokens - tokens.mean(dim=0, keepdim=True)
    _, _, Vt = torch.linalg.svd(tokens_c, full_matrices=False)
    projected = tokens_c @ Vt[:n_components].T  # [N, 3]

    # Per-component min-max normalisation to [0, 1]; the epsilon guards
    # against a zero range when a component is constant.
    lo = projected.min(dim=0).values
    hi = projected.max(dim=0).values
    projected = (projected - lo) / (hi - lo + 1e-8)

    # Reshape to the spatial grid and convert to an 8-bit PIL image.
    rgb = projected.reshape(h_p, w_p, 3).cpu().numpy()
    rgb_uint8 = (rgb * 255).clip(0, 255).astype("uint8")
    return Image.fromarray(rgb_uint8, mode="RGB")
|
| 607 |
+
|
| 608 |
@torch.no_grad()
|
| 609 |
def predict(self, image, topk: int | None = 30,
|
| 610 |
threshold: float | None = None) -> list[tuple[str, float]]:
|