Spaces:
Sleeping
Sleeping
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| def _gaussian_blur_heatmaps(heatmaps: torch.Tensor, kernel: int = 11) -> torch.Tensor: | |
| if kernel % 2 == 0: | |
| raise ValueError("kernel must be odd") | |
| sigma = kernel / 6.0 | |
| radius = kernel // 2 | |
| x = torch.arange(kernel, device=heatmaps.device, dtype=heatmaps.dtype) - radius | |
| g = torch.exp(-(x ** 2) / (2 * sigma * sigma)) | |
| g = g / g.sum() | |
| g_x = g.view(1, 1, 1, kernel) | |
| g_y = g.view(1, 1, kernel, 1) | |
| B, N, H, W = heatmaps.shape | |
| # 🔥 FIX HERE | |
| x_in = heatmaps.reshape(B * N, 1, H, W) | |
| x_in = F.pad(x_in, (radius, radius, 0, 0), mode="reflect") | |
| x_in = F.conv2d(x_in, g_x) | |
| x_in = F.pad(x_in, (0, 0, radius, radius), mode="reflect") | |
| x_in = F.conv2d(x_in, g_y) | |
| return x_in.reshape(B, N, H, W) | |
def heatmaps_to_coords_dark(
    heatmaps: torch.Tensor,
    blur_kernel: int = 11,
    eps: float = 1e-10,
) -> torch.Tensor:
    """
    DARK-style decoding with second-order local refinement.

    The heatmaps are blurred with ``_gaussian_blur_heatmaps`` and
    log-transformed. The coarse peak is found with argmax, and the peak is then
    shifted by one Newton step ``offset = -H^{-1} g``. The gradient g and the
    Hessian H come from central finite differences of the log-heatmap around
    the argmax pixel. The refinement is vectorized over all batch items and
    landmarks (no per-landmark Python loop or host/device synchronization).

    A landmark keeps its integer argmax position when:
      * the argmax pixel lies on the heatmap border (no full 3x3 neighborhood),
      * the Hessian determinant magnitude is below 1e-6, or
      * either offset component exceeds 1.5 pixels in magnitude (or is NaN).

    Args:
        heatmaps: [B, N, H, W] or [N, H, W]
        blur_kernel: odd Gaussian blur kernel size applied before the log
        eps: lower clamp applied before the log for numerical stability

    Returns:
        coords: [B, N, 2] or [N, 2] float32 (x, y) in heatmap coordinates

    Raises:
        ValueError: if ``heatmaps`` is not 3-D or 4-D.
    """
    squeeze_batch = heatmaps.ndim == 3
    if squeeze_batch:
        heatmaps = heatmaps.unsqueeze(0)
    if heatmaps.ndim != 4:
        raise ValueError(f"Expected [B, N, H, W] or [N, H, W], got {heatmaps.shape}")
    B, N, H, W = heatmaps.shape

    # Blur then log, as in DARK-style refinement.
    hm = _gaussian_blur_heatmaps(heatmaps, kernel=blur_kernel)
    hm = torch.clamp(hm, min=eps).log()

    # Coarse argmax over the flattened spatial grid (explicit H * W keeps N == 0 valid).
    flat = hm.reshape(B, N, H * W)
    idx = flat.argmax(dim=-1)
    py = idx // W
    px = idx % W

    def sample(row_shift: int, col_shift: int) -> torch.Tensor:
        # Log-heatmap value at (py + row_shift, px + col_shift) for every landmark.
        # Indices are clamped so border peaks stay in range; those are masked out below.
        rows = (py + row_shift).clamp(0, H - 1)
        cols = (px + col_shift).clamp(0, W - 1)
        return flat.gather(-1, (rows * W + cols).unsqueeze(-1)).squeeze(-1)

    center = sample(0, 0)
    right, left = sample(0, 1), sample(0, -1)
    down, up = sample(1, 0), sample(-1, 0)

    dx = 0.5 * (right - left)
    dy = 0.5 * (down - up)
    dxx = right - 2 * center + left
    dyy = down - 2 * center + up
    dxy = 0.25 * (sample(1, 1) - sample(1, -1) - sample(-1, 1) + sample(-1, -1))

    det = dxx * dyy - dxy * dxy
    interior = (px >= 1) & (px <= W - 2) & (py >= 1) & (py <= H - 2)
    solvable = interior & (det.abs() >= 1e-6)
    # Replace unusable determinants with 1 so the division never yields inf/NaN
    # (which would also poison gradients flowing through torch.where).
    safe_det = torch.where(solvable, det, torch.ones_like(det))

    # Closed-form -H^{-1} g for the symmetric 2x2 Hessian [[dxx, dxy], [dxy, dyy]].
    off_x = -(dyy * dx - dxy * dy) / safe_det
    off_y = -(dxx * dy - dxy * dx) / safe_det
    offset = torch.stack([off_x, off_y], dim=-1)

    # Keep refinement bounded; a huge step means the quadratic model is unreliable.
    # Comparisons with NaN are False, so NaN offsets are rejected too.
    accept = solvable & (offset.abs() <= 1.5).all(dim=-1)

    coords = torch.stack([px, py], dim=-1).float()
    coords = coords + torch.where(
        accept.unsqueeze(-1), offset, torch.zeros_like(offset)
    ).float()

    if squeeze_batch:
        coords = coords[0]
    return coords
def heatmap_coords_to_image_coords(
    coords: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
) -> torch.Tensor:
    """
    Map coordinates from heatmap space back to image space.

    x is multiplied by ``W_img / W_hm`` and y by ``H_img / H_hm``. The input
    tensor is not modified.

    NOTE: a function with this same name is defined again later in this module;
    the later definition is the one bound at import time.

    Args:
        coords: [B, N, 2] or [N, 2] (x, y) in heatmap space; integer tensors are
            promoted to float32 before scaling
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)

    Returns:
        A new floating-point tensor with the same shape as ``coords``.
    """
    H_img, W_img = image_size
    H_hm, W_hm = heatmap_size
    # In-place scaling by a non-integer ratio fails on integer tensors, so promote first.
    out = coords.clone() if coords.is_floating_point() else coords.float()
    out[..., 0] *= W_img / W_hm
    out[..., 1] *= H_img / H_hm
    return out
def gaussian2d(size: int, sigma: float, device=None) -> torch.Tensor:
    """
    Build an unnormalized 2D Gaussian of shape [size, size].

    The Gaussian is centered at ``(size - 1) / 2`` on both axes, so the peak
    value is exactly 1.0 when ``size`` is odd.

    Args:
        size: side length of the square output
        sigma: standard deviation in pixels
        device: device for the returned float32 tensor

    Returns:
        [size, size] float32 tensor with entries exp(-(dx^2 + dy^2) / (2 sigma^2)).
    """
    offsets = torch.arange(size, device=device, dtype=torch.float32) - (size - 1) / 2.0
    sq = offsets ** 2
    # Broadcasting a column (dy^2) against a row (dx^2) yields the same grid as meshgrid.
    dist_sq = sq.unsqueeze(1) + sq.unsqueeze(0)
    return torch.exp(-dist_sq / (2 * sigma * sigma))
def draw_gaussian(
    heatmap: torch.Tensor,
    center_x: float,
    center_y: float,
    sigma: float,
) -> torch.Tensor:
    """
    Draw a Gaussian on a single heatmap in-place.

    The Gaussian (peak value 1.0) is centered on the pixel nearest to
    ``(center_x, center_y)``, rounded with Python's ``round``. It is merged
    into ``heatmap`` with an element-wise maximum, so overlapping Gaussians do
    not accumulate.

    Args:
        heatmap: [H, W], modified in place
        center_x, center_y: landmark coordinates in heatmap space; Python
            numbers and 0-d tensors are both accepted
        sigma: Gaussian sigma in heatmap pixels; the drawn window has radius
            ``int(3 * sigma)``

    Returns:
        The same ``heatmap`` tensor. It is left unchanged when the center is
        non-finite or its rounded position lies outside the heatmap.
    """
    H, W = heatmap.shape
    radius = int(3 * sigma)
    size = 2 * radius + 1

    # float() accepts Python numbers and 0-d tensors alike (.item() only works on tensors).
    cx = float(center_x)
    cy = float(center_y)
    # round() raises on NaN/inf; treat such centers as "nothing to draw".
    if not (math.isfinite(cx) and math.isfinite(cy)):
        return heatmap

    mu_x = int(round(cx))
    mu_y = int(round(cy))
    # A center inside [0, W) can round up to W (e.g. W - 0.3); clamp it back so an
    # in-bounds landmark is not silently dropped. Centers outside the map keep
    # the skip behavior below.
    if 0 <= cx < W:
        mu_x = min(mu_x, W - 1)
    if 0 <= cy < H:
        mu_y = min(mu_y, H - 1)

    # Extent of the window on each side of the center, cropped at the heatmap border.
    left = min(mu_x, radius)
    right = min(W - mu_x - 1, radius)
    top = min(mu_y, radius)
    bottom = min(H - mu_y - 1, radius)
    if left < 0 or right < 0 or top < 0 or bottom < 0:
        return heatmap

    g = gaussian2d(size=size, sigma=sigma, device=heatmap.device)

    # Matching crops: kernel coordinates (g_*) and heatmap coordinates (h_*).
    g_x0 = radius - left
    g_x1 = radius + right + 1
    g_y0 = radius - top
    g_y1 = radius + bottom + 1
    h_x0 = mu_x - left
    h_x1 = mu_x + right + 1
    h_y0 = mu_y - top
    h_y1 = mu_y + bottom + 1

    heatmap[h_y0:h_y1, h_x0:h_x1] = torch.maximum(
        heatmap[h_y0:h_y1, h_x0:h_x1],
        g[g_y0:g_y1, g_x0:g_x1],
    )
    return heatmap
def generate_heatmaps(
    landmarks: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
    sigma: float = 2.0,
) -> torch.Tensor:
    """
    Render one Gaussian target heatmap per landmark.

    Landmark (x, y) positions are rescaled from image space to heatmap space by
    ``W_hm / W_img`` and ``H_hm / H_img``. A landmark whose rescaled position is
    not inside ``[0, W_hm) x [0, H_hm)`` (including NaN positions) leaves its
    heatmap all zeros.

    Args:
        landmarks: [N, 2] tensor of (x, y) in original image coordinates
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)
        sigma: Gaussian sigma in heatmap pixels

    Returns:
        heatmaps: [N, H_hm, W_hm] float32 tensor on ``landmarks.device``

    Raises:
        ValueError: if ``landmarks`` is not shaped [N, 2].
    """
    if not (landmarks.ndim == 2 and landmarks.shape[1] == 2):
        raise ValueError(f"Expected landmarks shape [N, 2], got {landmarks.shape}")

    img_h, img_w = image_size
    out_h, out_w = heatmap_size
    sx = out_w / img_w
    sy = out_h / img_h

    heatmaps = torch.zeros(
        (landmarks.shape[0], out_h, out_w),
        dtype=torch.float32,
        device=landmarks.device,
    )
    for k, (lx, ly) in enumerate(landmarks):
        hx = lx * sx
        hy = ly * sy
        inside = 0 <= hx < out_w and 0 <= hy < out_h
        if not inside:
            continue
        # heatmaps[k] is a view, so draw_gaussian writes straight into the output.
        draw_gaussian(heatmaps[k], hx, hy, sigma=sigma)
    return heatmaps
def generate_batch_heatmaps(
    landmarks_batch: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
    sigma: float = 2.0,
) -> torch.Tensor:
    """
    Batch version of ``generate_heatmaps``.

    Args:
        landmarks_batch: [B, N, 2]
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)
        sigma: Gaussian sigma in heatmap pixels

    Returns:
        heatmaps: [B, N, H_hm, W_hm] float32; an empty batch (B == 0) yields a
        zero-sized tensor of that shape instead of raising.

    Raises:
        ValueError: if ``landmarks_batch`` is not shaped [B, N, 2].
    """
    if landmarks_batch.ndim != 3 or landmarks_batch.shape[-1] != 2:
        raise ValueError(f"Expected [B, N, 2], got {landmarks_batch.shape}")

    if landmarks_batch.shape[0] == 0:
        # torch.stack rejects an empty list; build the empty batch explicitly.
        H_hm, W_hm = heatmap_size
        return torch.zeros(
            (0, landmarks_batch.shape[1], H_hm, W_hm),
            dtype=torch.float32,
            device=landmarks_batch.device,
        )

    return torch.stack(
        [
            generate_heatmaps(
                landmarks=landmarks,
                image_size=image_size,
                heatmap_size=heatmap_size,
                sigma=sigma,
            )
            for landmarks in landmarks_batch
        ],
        dim=0,
    )
def heatmaps_to_coords_argmax(heatmaps: torch.Tensor) -> torch.Tensor:
    """
    Decode coordinates from heatmaps using argmax.

    Ties resolve to whatever index ``torch.argmax`` returns.

    Args:
        heatmaps: [B, N, H, W] or [N, H, W]; need not be contiguous

    Returns:
        coords: [B, N, 2] or [N, 2] float32 (x, y) in heatmap coordinates

    Raises:
        ValueError: if ``heatmaps`` is not 3-D or 4-D.
    """
    squeeze_batch = heatmaps.ndim == 3
    if squeeze_batch:
        heatmaps = heatmaps.unsqueeze(0)
    if heatmaps.ndim != 4:
        raise ValueError(f"Expected [B, N, H, W] or [N, H, W], got {heatmaps.shape}")
    B, N, H, W = heatmaps.shape

    # reshape (not view) accepts non-contiguous inputs; explicit H * W keeps N == 0 valid.
    idx = heatmaps.reshape(B, N, H * W).argmax(dim=-1)
    y = idx // W
    x = idx % W
    coords = torch.stack([x.float(), y.float()], dim=-1)
    return coords[0] if squeeze_batch else coords
def heatmap_coords_to_image_coords(
    coords: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
) -> torch.Tensor:
    """
    Map coordinates from heatmap space back to image space.

    x is multiplied by ``W_img / W_hm`` and y by ``H_img / H_hm``. The input
    tensor is not modified.

    NOTE: a function with this same name is also defined earlier in this
    module; being later, this definition is the one bound at import time.

    Args:
        coords: [..., 2] (x, y) in heatmap space; integer tensors are promoted
            to float32 so the scaled values are not truncated
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)

    Returns:
        A new floating-point tensor with the same shape as ``coords``.
    """
    H_img, W_img = image_size
    H_hm, W_hm = heatmap_size
    scale_x = W_img / W_hm
    scale_y = H_img / H_hm
    # Writing a float product back into an integer tensor would silently truncate.
    out = coords.clone() if coords.is_floating_point() else coords.float()
    out[..., 0] = out[..., 0] * scale_x
    out[..., 1] = out[..., 1] * scale_y
    return out