"""Gaussian heatmap encoding and decoding utilities for landmark detection."""

import math

import torch
import torch.nn.functional as F


def _pad_mode(radius: int, size: int) -> str:
    """Return a padding mode that is legal for an axis of length ``size``.

    ``F.pad(..., mode="reflect")`` requires the padding to be strictly smaller
    than the padded dimension, so fall back to ``"replicate"`` on short axes.
    """
    return "reflect" if radius < size else "replicate"


def _gaussian_blur_heatmaps(heatmaps: torch.Tensor, kernel: int = 11) -> torch.Tensor:
    """Blur every heatmap with a separable, normalized 1-D Gaussian.

    Args:
        heatmaps: [B, N, H, W] tensor. Non-floating inputs are converted to float32.
        kernel: Positive odd kernel width. sigma is ``kernel / 6`` so the
            half-width ``kernel // 2`` covers roughly 3 sigma.

    Returns:
        Blurred heatmaps with the same [B, N, H, W] shape.

    Raises:
        ValueError: If ``kernel`` is not a positive odd integer.
    """
    if kernel < 1 or kernel % 2 == 0:
        raise ValueError("kernel must be a positive odd integer")
    if not heatmaps.is_floating_point():
        heatmaps = heatmaps.float()

    B, N, H, W = heatmaps.shape
    radius = kernel // 2
    sigma = kernel / 6.0
    taps = torch.arange(kernel, device=heatmaps.device, dtype=heatmaps.dtype) - radius
    weights = torch.exp(-(taps ** 2) / (2 * sigma * sigma))
    weights = weights / weights.sum()

    # Treat each heatmap as an independent single-channel image and blur the
    # two axes separately (horizontal pass, then vertical pass).
    x = heatmaps.reshape(B * N, 1, H, W)
    x = F.pad(x, (radius, radius, 0, 0), mode=_pad_mode(radius, W))
    x = F.conv2d(x, weights.view(1, 1, 1, kernel))
    x = F.pad(x, (0, 0, radius, radius), mode=_pad_mode(radius, H))
    x = F.conv2d(x, weights.view(1, 1, kernel, 1))
    return x.reshape(B, N, H, W)


def _dark_newton_offsets(
    log_hm: torch.Tensor,
    px: torch.Tensor,
    py: torch.Tensor,
    max_offset: float = 1.5,
    det_eps: float = 1e-6,
) -> torch.Tensor:
    """Compute one Newton step ``-H^{-1} g`` of the log-heatmap at each argmax.

    Args:
        log_hm: [B, N, H, W] log-heatmaps.
        px, py: [B, N] integer argmax column / row indices.
        max_offset: Offsets with any component larger than this are rejected.
        det_eps: Hessians with ``|det|`` below this are treated as singular.

    Returns:
        [B, N, 2] (dx, dy) offsets. The offset is zero wherever refinement is not
        applicable: peak on the 1-pixel border, near-singular Hessian,
        non-finite or too-large step.
    """
    B, N, H, W = log_hm.shape

    # Zero-pad by one pixel so every 3x3 gather below is in range; peaks on the
    # original border are masked out afterwards, so the pad value never leaks.
    padded = F.pad(log_hm, (1, 1, 1, 1))
    row = W + 2
    flat = padded.reshape(B, N, (H + 2) * row)
    steps = torch.arange(-1, 2, device=log_hm.device)
    off_r, off_c = torch.meshgrid(steps, steps, indexing="ij")
    rel = (off_r * row + off_c).reshape(-1)  # 9 row-major neighbour offsets
    centre = (py + 1) * row + (px + 1)  # [B, N] index into the padded map
    p = flat.gather(-1, centre.unsqueeze(-1) + rel).reshape(B, N, 3, 3)

    c = p[..., 1, 1]
    left, right = p[..., 1, 0], p[..., 1, 2]
    up, down = p[..., 0, 1], p[..., 2, 1]

    # Central finite differences for gradient and Hessian.
    gx = 0.5 * (right - left)
    gy = 0.5 * (down - up)
    hxx = right - 2 * c + left
    hyy = down - 2 * c + up
    hxy = 0.25 * (p[..., 2, 2] - p[..., 2, 0] - p[..., 0, 2] + p[..., 0, 0])

    det = hxx * hyy - hxy * hxy
    det_ok = det.abs() >= det_eps
    safe_det = torch.where(det_ok, det, torch.ones_like(det))
    # Closed-form inverse of the symmetric 2x2 Hessian [[hxx, hxy], [hxy, hyy]].
    step_x = -(hyy * gx - hxy * gy) / safe_det
    step_y = -(hxx * gy - hxy * gx) / safe_det
    offset = torch.stack([step_x, step_y], dim=-1)

    interior = (px >= 1) & (px <= W - 2) & (py >= 1) & (py <= H - 2)
    # NaN compares False here, so non-finite steps are rejected as well.
    bounded = (offset.abs() <= max_offset).all(dim=-1)
    valid = interior & det_ok & bounded
    return torch.where(valid.unsqueeze(-1), offset, torch.zeros_like(offset))


def heatmaps_to_coords_dark(
    heatmaps: torch.Tensor,
    blur_kernel: int = 11,
    eps: float = 1e-10,
) -> torch.Tensor:
    """
    DARK-style decoding with second-order local refinement.

    The heatmaps are Gaussian-blurred, log-transformed, and the integer argmax
    of each map is refined by one Newton step of the log-heatmap. Refinement is
    skipped (argmax kept) for peaks on the 1-pixel border, near-singular
    Hessians, and steps larger than 1.5 pixels in either axis.

    Args:
        heatmaps: [B, N, H, W] or [N, H, W]
        blur_kernel: Gaussian blur kernel before log (positive, odd)
        eps: numerical stability for log

    Returns:
        coords: float32 [B, N, 2] or [N, 2] of (x, y) in heatmap coordinates

    Raises:
        ValueError: on unsupported rank or an invalid ``blur_kernel``.
    """
    squeeze_batch = heatmaps.ndim == 3
    if squeeze_batch:
        heatmaps = heatmaps.unsqueeze(0)
    if heatmaps.ndim != 4:
        raise ValueError(f"Expected [B, N, H, W] or [N, H, W], got {heatmaps.shape}")

    B, N, H, W = heatmaps.shape
    if B * N == 0:
        coords = torch.zeros((B, N, 2), dtype=torch.float32, device=heatmaps.device)
        return coords[0] if squeeze_batch else coords

    # Blur then log, as in DARK-style refinement.
    hm = _gaussian_blur_heatmaps(heatmaps, kernel=blur_kernel)
    hm = torch.clamp(hm, min=eps).log()

    # Coarse argmax.
    idx = hm.reshape(B, N, H * W).argmax(dim=-1)
    py = idx // W
    px = idx % W
    coords = torch.stack([px, py], dim=-1).to(torch.float32)

    coords = coords + _dark_newton_offsets(hm, px, py).to(coords.dtype)
    return coords[0] if squeeze_batch else coords


def gaussian2d(
    size: int,
    sigma: float,
    device=None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Create an unnormalized 2D Gaussian of shape [size, size].

    The peak value is 1 at the centre ``(size - 1) / 2`` (exactly one pixel
    when ``size`` is odd).
    """
    offsets = torch.arange(size, device=device, dtype=dtype) - (size - 1) / 2.0
    yy, xx = torch.meshgrid(offsets, offsets, indexing="ij")
    return torch.exp(-(xx ** 2 + yy ** 2) / (2 * sigma * sigma))


def draw_gaussian(
    heatmap: torch.Tensor,
    center_x: float,
    center_y: float,
    sigma: float,
) -> torch.Tensor:
    """
    Draw a Gaussian on a single heatmap in-place.

    The Gaussian is centred on the nearest integer pixel (halves round up) and
    merged with existing values via element-wise maximum. Centres whose pixel
    falls outside the heatmap leave it untouched.

    Args:
        heatmap: [H, W]
        center_x, center_y: landmark coordinates in heatmap space (Python
            numbers or single-element tensors)
        sigma: Gaussian sigma in heatmap pixels; must be positive

    Returns:
        The same ``heatmap`` tensor, for chaining.

    Raises:
        ValueError: if ``sigma`` is not positive.
    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    H, W = heatmap.shape
    radius = int(3 * sigma)
    size = 2 * radius + 1

    # float() accepts both numbers and 1-element tensors. Round half up instead
    # of Python's round() (banker's rounding), which sends .5 both ways.
    mu_x = math.floor(float(center_x) + 0.5)
    mu_y = math.floor(float(center_y) + 0.5)

    # How far the kernel may extend on each side without leaving the map;
    # a negative extent means the centre pixel itself is outside.
    left = min(mu_x, radius)
    right = min(W - mu_x - 1, radius)
    top = min(mu_y, radius)
    bottom = min(H - mu_y - 1, radius)
    if left < 0 or right < 0 or top < 0 or bottom < 0:
        return heatmap

    g_dtype = heatmap.dtype if heatmap.is_floating_point() else torch.float32
    g = gaussian2d(size=size, sigma=sigma, device=heatmap.device, dtype=g_dtype)

    hm_window = heatmap[mu_y - top : mu_y + bottom + 1, mu_x - left : mu_x + right + 1]
    g_window = g[radius - top : radius + bottom + 1, radius - left : radius + right + 1]
    heatmap[mu_y - top : mu_y + bottom + 1, mu_x - left : mu_x + right + 1] = torch.maximum(
        hm_window, g_window
    )
    return heatmap


def generate_heatmaps(
    landmarks: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
    sigma: float = 2.0,
) -> torch.Tensor:
    """
    Generate Gaussian heatmaps for landmark detection.

    Landmarks are scaled by ``heatmap_size / image_size`` per axis. Landmarks
    that fall outside the heatmap (or are NaN) yield an all-zero map.

    Args:
        landmarks: [N, 2] tensor of (x, y) in original image coordinates
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)
        sigma: Gaussian sigma in heatmap pixels

    Returns:
        heatmaps: float32 [N, H_hm, W_hm] on ``landmarks.device``

    Raises:
        ValueError: if ``landmarks`` is not shaped [N, 2].
    """
    if landmarks.ndim != 2 or landmarks.shape[1] != 2:
        raise ValueError(f"Expected landmarks shape [N, 2], got {landmarks.shape}")

    H_img, W_img = image_size
    H_hm, W_hm = heatmap_size
    scale_x = W_hm / W_img
    scale_y = H_hm / H_img

    heatmaps = torch.zeros(
        (landmarks.shape[0], H_hm, W_hm), dtype=torch.float32, device=landmarks.device
    )
    # One host transfer instead of a device sync per landmark.
    for i, (x, y) in enumerate(landmarks.tolist()):
        x_hm = x * scale_x
        y_hm = y * scale_y
        # NaN fails both comparisons, so missing landmarks are skipped.
        if 0 <= x_hm < W_hm and 0 <= y_hm < H_hm:
            draw_gaussian(heatmaps[i], x_hm, y_hm, sigma=sigma)
    return heatmaps


def generate_batch_heatmaps(
    landmarks_batch: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
    sigma: float = 2.0,
) -> torch.Tensor:
    """
    Batch version of ``generate_heatmaps``.

    Args:
        landmarks_batch: [B, N, 2]
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)
        sigma: Gaussian sigma in heatmap pixels

    Returns:
        heatmaps: float32 [B, N, H_hm, W_hm]; an empty batch gives B == 0

    Raises:
        ValueError: if ``landmarks_batch`` is not shaped [B, N, 2].
    """
    if landmarks_batch.ndim != 3 or landmarks_batch.shape[-1] != 2:
        raise ValueError(f"Expected [B, N, 2], got {landmarks_batch.shape}")

    if landmarks_batch.shape[0] == 0:
        # torch.stack([]) raises, so build the empty result explicitly.
        H_hm, W_hm = heatmap_size
        return torch.zeros(
            (0, landmarks_batch.shape[1], H_hm, W_hm),
            dtype=torch.float32,
            device=landmarks_batch.device,
        )

    per_sample = [
        generate_heatmaps(
            landmarks=sample,
            image_size=image_size,
            heatmap_size=heatmap_size,
            sigma=sigma,
        )
        for sample in landmarks_batch
    ]
    return torch.stack(per_sample, dim=0)


def heatmaps_to_coords_argmax(heatmaps: torch.Tensor) -> torch.Tensor:
    """
    Decode coordinates from heatmaps using argmax.

    Args:
        heatmaps: [B, N, H, W] or [N, H, W]; need not be contiguous

    Returns:
        coords: float32 [B, N, 2] or [N, 2] of (x, y) in heatmap coordinates

    Raises:
        ValueError: on unsupported rank.
    """
    squeeze_batch = heatmaps.ndim == 3
    if squeeze_batch:
        heatmaps = heatmaps.unsqueeze(0)
    if heatmaps.ndim != 4:
        raise ValueError(f"Expected [B, N, H, W] or [N, H, W], got {heatmaps.shape}")

    B, N, H, W = heatmaps.shape
    # reshape (not view) so permuted / sliced inputs work too.
    idx = heatmaps.reshape(B, N, H * W).argmax(dim=-1)
    coords = torch.stack([idx % W, idx // W], dim=-1).float()
    return coords[0] if squeeze_batch else coords


def heatmap_coords_to_image_coords(
    coords: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
) -> torch.Tensor:
    """
    Map coordinates from heatmap space back to image space.

    Inverse of the per-axis scaling applied in ``generate_heatmaps``. Only the
    first two entries of the last axis (x, y) are scaled; any further columns
    are copied unchanged. The input is never modified.

    Args:
        coords: [..., 2+] of (x, y, ...); integer inputs are promoted to float32
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)

    Returns:
        A new tensor of the same shape in image coordinates.
    """
    H_img, W_img = image_size
    H_hm, W_hm = heatmap_size

    # Integer coords would otherwise be truncated by the float scaling.
    out = coords.clone() if coords.is_floating_point() else coords.to(torch.float32)
    out[..., 0] = out[..., 0] * (W_img / W_hm)
    out[..., 1] = out[..., 1] * (H_img / H_hm)
    return out