# CephVIT / heatmap_utils.py
# Uploaded by farrell236 ("Upload heatmap_utils.py"), commit c67f469 (verified).
import math
import torch
import torch.nn.functional as F
def _gaussian_blur_heatmaps(heatmaps: torch.Tensor, kernel: int = 11) -> torch.Tensor:
    """Separably blur every heatmap channel with a normalized Gaussian kernel.

    The 1-D kernel has ``kernel`` taps and ``sigma = kernel / 6``. It is applied
    along W, then along H. Borders are reflect-padded. When an axis is too
    small for reflect padding (``kernel // 2 >= size``), that axis uses
    replicate padding instead of raising.

    Args:
        heatmaps: [B, N, H, W] floating-point tensor.
        kernel: odd, positive kernel length.

    Returns:
        Blurred tensor with the same shape, dtype and device as ``heatmaps``.

    Raises:
        ValueError: if ``kernel`` is even or not positive, or ``heatmaps`` is not 4-D.
    """
    if kernel % 2 == 0:
        raise ValueError("kernel must be odd")
    if kernel < 1:
        raise ValueError("kernel must be positive")
    if heatmaps.ndim != 4:
        raise ValueError(f"Expected [B, N, H, W], got {tuple(heatmaps.shape)}")

    sigma = kernel / 6.0
    radius = kernel // 2
    taps = torch.arange(kernel, device=heatmaps.device, dtype=heatmaps.dtype) - radius
    g = torch.exp(-(taps ** 2) / (2 * sigma * sigma))
    g = g / g.sum()
    g_x = g.view(1, 1, 1, kernel)
    g_y = g.view(1, 1, kernel, 1)

    B, N, H, W = heatmaps.shape
    # Fold landmarks into the batch axis so one single-channel conv blurs them all.
    x_in = heatmaps.reshape(B * N, 1, H, W)
    # "reflect" padding needs pad < size on that axis; use "replicate" on
    # axes that are too small so tiny heatmaps still decode.
    mode_x = "reflect" if radius < W else "replicate"
    mode_y = "reflect" if radius < H else "replicate"
    x_in = F.pad(x_in, (radius, radius, 0, 0), mode=mode_x)
    x_in = F.conv2d(x_in, g_x)
    x_in = F.pad(x_in, (0, 0, radius, radius), mode=mode_y)
    x_in = F.conv2d(x_in, g_y)
    return x_in.reshape(B, N, H, W)


def _dark_offsets(
    flat: torch.Tensor,
    px: torch.Tensor,
    py: torch.Tensor,
    H: int,
    W: int,
) -> torch.Tensor:
    """Return the per-peak Newton offset [B, N, 2] (x, y), zero where rejected.

    ``flat`` is the log-heatmap reshaped to [B, N, H * W]. ``px``/``py`` are
    the integer argmax columns/rows, each [B, N]. Requires H >= 3 and W >= 3.
    """
    # Clamp into the interior so every gather index is valid. Border peaks are
    # masked out below, so the values read for them are never used.
    cx = px.clamp(1, W - 2)
    cy = py.clamp(1, H - 2)

    def at(dy: int, dx: int) -> torch.Tensor:
        # Value of the log-heatmap at (cy + dy, cx + dx) for every map.
        lin = (cy + dy) * W + (cx + dx)
        return flat.gather(-1, lin.unsqueeze(-1)).squeeze(-1)

    center = at(0, 0)
    right, left = at(0, 1), at(0, -1)
    down, up = at(1, 0), at(-1, 0)

    # Central finite differences of the log-heatmap.
    dx = 0.5 * (right - left)
    dy = 0.5 * (down - up)
    dxx = right - 2 * center + left
    dyy = down - 2 * center + up
    dxy = 0.25 * (at(1, 1) - at(1, -1) - at(-1, 1) + at(-1, -1))

    det = dxx * dyy - dxy * dxy
    interior = (px >= 1) & (px <= W - 2) & (py >= 1) & (py <= H - 2)
    # A NaN det fails ">=", so non-finite derivatives are rejected too.
    ok = interior & (det.abs() >= 1e-6)
    # Avoid dividing by ~0 for rejected entries; their offsets are discarded.
    safe_det = torch.where(ok, det, torch.ones_like(det))

    # offset = -H^{-1} g, with H^{-1} = [[dyy, -dxy], [-dxy, dxx]] / det
    off_x = -(dyy * dx - dxy * dy) / safe_det
    off_y = -(dxx * dy - dxy * dx) / safe_det
    offset = torch.stack([off_x, off_y], dim=-1)

    # Keep refinement bounded; if huge, it's unstable (NaN also fails this test).
    ok = ok & (offset.abs() <= 1.5).all(dim=-1)
    return torch.where(ok.unsqueeze(-1), offset, torch.zeros_like(offset))


def heatmaps_to_coords_dark(
    heatmaps: torch.Tensor,
    blur_kernel: int = 11,
    eps: float = 1e-10,
) -> torch.Tensor:
    """
    DARK-style decoding with second-order local refinement.

    The heatmaps are Gaussian-blurred, clamped to ``eps`` and log-transformed.
    The integer argmax of each map is then refined by one Newton step,
    ``offset = -H^{-1} g``. Here ``g`` and ``H`` are the central
    finite-difference gradient and Hessian of the log-heatmap at the peak.
    All maps are refined at once, with tensor ops instead of a per-landmark
    Python loop.

    A peak keeps its integer argmax location when any of these holds:
      * it lies on the outermost row/column (no 3x3 neighbourhood),
      * ``|det(H)| < 1e-6`` or the derivatives are not finite,
      * either offset component exceeds 1.5 pixels in magnitude.

    Args:
        heatmaps: [B, N, H, W] or [N, H, W]
        blur_kernel: odd Gaussian blur kernel length applied before the log
        eps: lower clamp before the log, for numerical stability

    Returns:
        coords: float32 [B, N, 2] or [N, 2] as (x, y) in heatmap coordinates

    Raises:
        ValueError: if ``heatmaps`` is not 3-D/4-D or ``blur_kernel`` is invalid.
    """
    squeeze_batch = heatmaps.ndim == 3
    if squeeze_batch:
        heatmaps = heatmaps.unsqueeze(0)
    if heatmaps.ndim != 4:
        raise ValueError(f"Expected [B, N, H, W] or [N, H, W], got {heatmaps.shape}")

    B, N, H, W = heatmaps.shape

    # Blur then log, as in DARK-style refinement
    hm = _gaussian_blur_heatmaps(heatmaps, kernel=blur_kernel)
    hm = torch.clamp(hm, min=eps).log()

    # Coarse argmax
    flat = hm.reshape(B, N, H * W)
    idx = flat.argmax(dim=-1)
    py = idx // W
    px = idx % W
    coords = torch.stack([px.float(), py.float()], dim=-1)

    # Central differences need a full 3x3 neighbourhood; with H < 3 or W < 3
    # no peak has one, so no refinement is possible.
    if H >= 3 and W >= 3:
        coords = coords + _dark_offsets(flat, px, py, H, W).to(coords.dtype)

    if squeeze_batch:
        coords = coords[0]
    return coords
def heatmap_coords_to_image_coords(
    coords: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
) -> torch.Tensor:
    """
    Map coordinates from heatmap space back to image space.

    x (``[..., 0]``) is scaled by ``W_img / W_hm`` and y (``[..., 1]``) by
    ``H_img / H_hm``. The input tensor is not modified. Integer inputs are
    promoted to float32, so fractional scale factors are neither truncated nor
    rejected by an in-place integer multiply.

    NOTE(review): this module defines ``heatmap_coords_to_image_coords`` twice.
    The later definition rebinds the name, so that one is in effect. Keep the
    two in sync or delete one.

    Args:
        coords: [B, N, 2] or [N, 2] (x, y) in heatmap coordinates
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)

    Returns:
        Floating-point tensor of the same shape with (x, y) in image coordinates.
    """
    H_img, W_img = image_size
    H_hm, W_hm = heatmap_size
    # clone() / float() both return a new tensor, so the caller's coords are untouched.
    out = coords.clone() if coords.is_floating_point() else coords.float()
    out[..., 0] *= W_img / W_hm
    out[..., 1] *= H_img / H_hm
    return out
def gaussian2d(size: int, sigma: float, device=None) -> torch.Tensor:
    """
    Create an unnormalized 2D Gaussian of shape [size, size] (float32).

    Entry (i, j) is ``exp(-(dx^2 + dy^2) / (2 * sigma^2))``. Here
    ``dx = j - (size - 1) / 2`` and ``dy = i - (size - 1) / 2``. For odd
    ``size`` the centre pixel is exactly 1.
    """
    coords = torch.arange(size, device=device, dtype=torch.float32)
    center = (size - 1) / 2.0
    x = coords - center
    y = coords - center
    yy, xx = torch.meshgrid(y, x, indexing="ij")
    g = torch.exp(-(xx**2 + yy**2) / (2 * sigma * sigma))
    return g


def draw_gaussian(
    heatmap: torch.Tensor,
    center_x: float,
    center_y: float,
    sigma: float,
) -> torch.Tensor:
    """
    Draw a Gaussian on a single heatmap in-place.

    A ``(2r+1) x (2r+1)`` patch from ``gaussian2d`` with ``r = int(3 * sigma)``
    is centred on the pixel nearest ``(center_x, center_y)`` (Python ``round``).
    It is merged into ``heatmap`` with an element-wise maximum. The patch is
    clipped to the heatmap bounds, so a centre that rounds onto or just past
    the last row/column still draws its visible part (e.g. ``x = W - 0.3``
    rounds to ``W``). If the patch does not overlap the heatmap at all, the
    heatmap is returned unchanged.

    Args:
        heatmap: [H, W] tensor, modified in place
        center_x, center_y: landmark coordinates in heatmap space; Python
            numbers or single-element tensors
        sigma: Gaussian sigma in heatmap pixels; must be > 0

    Returns:
        ``heatmap`` (the same tensor object).

    Raises:
        ValueError: if ``sigma`` is not positive (the patch would contain NaN).
    """
    if not sigma > 0:
        raise ValueError(f"sigma must be positive, got {sigma}")
    H, W = heatmap.shape
    radius = int(3 * sigma)
    # float() accepts Python numbers as well as single-element tensors.
    mu_x = int(round(float(center_x)))
    mu_y = int(round(float(center_y)))

    # Heatmap window covered by the patch, clipped to [0, W) x [0, H).
    h_x0, h_x1 = max(mu_x - radius, 0), min(mu_x + radius + 1, W)
    h_y0, h_y1 = max(mu_y - radius, 0), min(mu_y + radius + 1, H)
    if h_x0 >= h_x1 or h_y0 >= h_y1:
        return heatmap  # patch lies entirely outside the heatmap

    g = gaussian2d(size=2 * radius + 1, sigma=sigma, device=heatmap.device)
    # Matching window inside the patch: patch index = heatmap index - (mu - radius).
    g_x0 = h_x0 - (mu_x - radius)
    g_y0 = h_y0 - (mu_y - radius)
    g_patch = g[g_y0:g_y0 + (h_y1 - h_y0), g_x0:g_x0 + (h_x1 - h_x0)]

    heatmap[h_y0:h_y1, h_x0:h_x1] = torch.maximum(
        heatmap[h_y0:h_y1, h_x0:h_x1],
        g_patch,
    )
    return heatmap
def generate_heatmaps(
    landmarks: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
    sigma: float = 2.0,
) -> torch.Tensor:
    """
    Render one Gaussian target heatmap per landmark.

    Each (x, y) landmark is rescaled from image to heatmap resolution and
    drawn with ``draw_gaussian``. A landmark whose rescaled position is not
    inside [0, W_hm) x [0, H_hm) gets an all-zero map. This includes NaN
    coordinates, because every comparison with NaN is False.

    Args:
        landmarks: [N, 2] tensor of (x, y) in original image coordinates
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)
        sigma: Gaussian sigma in heatmap pixels

    Returns:
        heatmaps: float32 [N, H_hm, W_hm] on ``landmarks.device``

    Raises:
        ValueError: if ``landmarks`` is not shaped [N, 2].
    """
    if not (landmarks.ndim == 2 and landmarks.shape[1] == 2):
        raise ValueError(f"Expected landmarks shape [N, 2], got {landmarks.shape}")

    img_h, img_w = image_size
    out_h, out_w = heatmap_size
    sx = out_w / img_w
    sy = out_h / img_h

    maps = torch.zeros(
        (landmarks.shape[0], out_h, out_w),
        dtype=torch.float32,
        device=landmarks.device,
    )
    for k, (lx, ly) in enumerate(landmarks):
        cx = lx * sx
        cy = ly * sy
        inside = 0 <= cx < out_w and 0 <= cy < out_h
        if inside:
            draw_gaussian(maps[k], cx, cy, sigma=sigma)
    return maps
def generate_batch_heatmaps(
    landmarks_batch: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
    sigma: float = 2.0,
) -> torch.Tensor:
    """
    Batch version of ``generate_heatmaps``.

    Each sample's [N, 2] landmarks are rendered with ``generate_heatmaps`` and
    the results are stacked along a new leading batch axis. An empty batch
    (B == 0) returns an empty float32 tensor of shape [0, N, H_hm, W_hm]
    instead of failing inside ``torch.stack``.

    Args:
        landmarks_batch: [B, N, 2]
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)
        sigma: Gaussian sigma in heatmap pixels

    Returns:
        heatmaps: [B, N, H_hm, W_hm]

    Raises:
        ValueError: if ``landmarks_batch`` is not shaped [B, N, 2].
    """
    if landmarks_batch.ndim != 3 or landmarks_batch.shape[-1] != 2:
        raise ValueError(f"Expected [B, N, 2], got {landmarks_batch.shape}")

    if landmarks_batch.shape[0] == 0:
        H_hm, W_hm = heatmap_size
        # torch.stack rejects an empty list; build the empty result directly
        # with the same dtype/device that generate_heatmaps uses.
        return torch.zeros(
            (0, landmarks_batch.shape[1], H_hm, W_hm),
            dtype=torch.float32,
            device=landmarks_batch.device,
        )

    return torch.stack(
        [
            generate_heatmaps(
                landmarks=landmarks,
                image_size=image_size,
                heatmap_size=heatmap_size,
                sigma=sigma,
            )
            for landmarks in landmarks_batch
        ],
        dim=0,
    )
def heatmaps_to_coords_argmax(heatmaps: torch.Tensor) -> torch.Tensor:
    """
    Decode coordinates from heatmaps using argmax.

    Returns the integer (x, y) location of each map's maximum as float32.

    Args:
        heatmaps: [B, N, H, W] or [N, H, W]; need not be contiguous

    Returns:
        coords: [B, N, 2] or [N, 2] in heatmap coordinates

    Raises:
        ValueError: if ``heatmaps`` is not 3-D or 4-D.
    """
    squeeze_batch = heatmaps.ndim == 3
    if squeeze_batch:
        heatmaps = heatmaps.unsqueeze(0)
    if heatmaps.ndim != 4:
        raise ValueError(f"Expected [B, N, H, W] or [N, H, W], got {heatmaps.shape}")

    B, N, H, W = heatmaps.shape
    # reshape (not view) so non-contiguous inputs, e.g. transposed or
    # channels_last model outputs, are accepted.
    idx = heatmaps.reshape(B, N, H * W).argmax(dim=-1)
    coords = torch.stack([(idx % W).float(), (idx // W).float()], dim=-1)
    return coords[0] if squeeze_batch else coords
def heatmap_coords_to_image_coords(
    coords: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
) -> torch.Tensor:
    """
    Map coordinates from heatmap space back to image space.

    x (``[..., 0]``) is scaled by ``W_img / W_hm`` and y (``[..., 1]``) by
    ``H_img / H_hm``. The input tensor is not modified. Integer inputs are
    promoted to float32 so fractional scale factors are not truncated on
    write-back.

    NOTE(review): this module defines ``heatmap_coords_to_image_coords`` twice.
    This later definition rebinds the name, so it is the one in effect.

    Args:
        coords: [B, N, 2] or [N, 2] (x, y) in heatmap coordinates
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)

    Returns:
        Floating-point tensor of the same shape with (x, y) in image coordinates.
    """
    H_img, W_img = image_size
    H_hm, W_hm = heatmap_size
    scale_x = W_img / W_hm
    scale_y = H_img / H_hm
    # clone() / float() both return a new tensor, so the caller's coords are untouched.
    out = coords.clone() if coords.is_floating_point() else coords.float()
    out[..., 0] = out[..., 0] * scale_x
    out[..., 1] = out[..., 1] * scale_y
    return out