| """Metric depth <-> RGB encodings, per Vision Banana (Gabeur et al. 2026). |
| |
| Common front-end (all colormaps share this): |
| Barron (2025) power-transform: f(d, lambda=-3, c=10/3) = 1 - (1 + d/10)^(-2), |
| mapping metric depth [0, inf) -> curve parameter u in [0, 1). |
| |
| Primary (canonical) colormap: Hilbert |
| u -> RGB via piecewise-linear interp along a Hamiltonian path across 8 cube corners |
| (black -> blue -> cyan -> green -> yellow -> red -> magenta -> white). |
| Invertible: project RGB onto nearest segment. |
| |
| Augmentation colormaps (forward-only for training variety; not used by the eval decoder): |
| Plasma / Inferno / Viridis: matplotlib perceptually-uniform LUTs applied to u. |
| Grayscale: u replicated to all 3 channels. |
| |
| At eval we always request Hilbert so the RGB->depth inverse is well-defined. |
| """ |
| from __future__ import annotations |
|
|
| import torch |
|
|
# Barron power-transform parameters (lambda = -3, c = 10/3); see module docstring.
LAMBDA: float = -3.0
C: float = 10.0 / 3.0
# Derived product lambda * c = -10.0. Not referenced elsewhere in this file —
# NOTE(review): presumably kept for external callers; confirm before removing.
LAMBDA_C: float = LAMBDA * C



# The 8 RGB-cube corners visited by the Hamiltonian ("Hilbert") path, in near-to-far
# order. Consecutive corners differ in exactly one channel, so every segment of the
# path has unit length (rgb_to_curve's projection relies on this).
CORNERS = torch.tensor(
    [
        [0.0, 0.0, 0.0],  # black
        [0.0, 0.0, 1.0],  # blue
        [0.0, 1.0, 1.0],  # cyan
        [0.0, 1.0, 0.0],  # green
        [1.0, 1.0, 0.0],  # yellow
        [1.0, 0.0, 0.0],  # red
        [1.0, 0.0, 1.0],  # magenta
        [1.0, 1.0, 1.0],  # white
    ]
)
# Number of linear segments along the path (7 for 8 corners).
N_SEG = CORNERS.shape[0] - 1
|
|
|
|
| def depth_to_curve(depth: torch.Tensor) -> torch.Tensor: |
| """Barron power-transform: metric depth in [0, inf) -> curve parameter u in [0, 1). |
| |
| NaN / negative / inf inputs map to 0 (encoded as black), so downstream integer indexing is safe. |
| """ |
| d = torch.nan_to_num(depth, nan=0.0, posinf=1e6, neginf=0.0).clamp_min(0.0) |
| return 1.0 - (1.0 + d / 10.0).pow(-2.0) |
|
|
|
|
def curve_to_depth(u: torch.Tensor) -> torch.Tensor:
    """Inverse Barron transform: u in [0, 1) -> metric depth in [0, inf).

    u is clipped to [0, 0.9999], capping the recoverable depth at 990 m.
    """
    clipped = torch.clamp(u, min=0.0, max=0.9999)
    return 10.0 * (torch.rsqrt(1.0 - clipped) - 1.0)
|
|
|
|
def curve_to_rgb(u: torch.Tensor) -> torch.Tensor:
    """u in [0, 1] -> RGB in [0, 1]^3 along the 7-segment Hamiltonian path."""
    # Scale u onto [0, N_SEG] and split into segment index + fractional position.
    pos = u.clamp(0.0, 1.0) * N_SEG
    seg = torch.clamp(pos.floor(), 0, N_SEG - 1).long()  # u == 1 lands in the last segment
    frac = (pos - seg.to(pos.dtype))[..., None]

    path = CORNERS.to(u.device, dtype=u.dtype)
    lo = path[seg]
    hi = path[seg + 1]
    return torch.lerp(lo, hi, frac)
|
|
|
|
def rgb_to_curve(rgb: torch.Tensor) -> torch.Tensor:
    """RGB in [0, 1]^3 -> curve parameter u in [0, 1] via nearest-segment projection.

    rgb: (..., 3) tensor. Returns: (...) tensor.
    """
    path = CORNERS.to(rgb.device, dtype=rgb.dtype)
    starts = path[:-1]
    direction = path[1:] - starts  # unit length per segment, so no normalization needed

    # Offset of each pixel from every segment start: (..., N_SEG, 3).
    offset = rgb.unsqueeze(-2) - starts
    # Parametric position along each segment, clamped to the segment body.
    frac = (offset * direction).sum(-1).clamp(0.0, 1.0)

    # Squared distance from the pixel to its projection on each segment.
    nearest = starts + frac.unsqueeze(-1) * direction
    sq_dist = (rgb.unsqueeze(-2) - nearest).square().sum(-1)

    # Pick the closest segment and convert (segment, fraction) back to global u.
    best = sq_dist.argmin(dim=-1)
    best_frac = frac.gather(-1, best.unsqueeze(-1)).squeeze(-1)
    return (best.to(rgb.dtype) + best_frac) / N_SEG
|
|
|
|
def depth_to_rgb(depth: torch.Tensor) -> torch.Tensor:
    """Metric depth (..., H, W) -> RGB (..., H, W, 3) via the canonical Hilbert path."""
    u = depth_to_curve(depth)
    return curve_to_rgb(u)
|
|
|
|
def rgb_to_depth(rgb: torch.Tensor) -> torch.Tensor:
    """RGB (..., H, W, 3) -> metric depth (..., H, W). Assumes the Hilbert encoding."""
    u = rgb_to_curve(rgb)
    return curve_to_depth(u)
|
|
|
|
| |
|
|
# Cache of matplotlib colormap LUTs keyed by "name:n"; populated lazily by _mpl_lut.
_MPL_LUT_CACHE: dict[str, torch.Tensor] = {}
|
|
|
|
def _mpl_lut(name: str, n: int = 1024) -> torch.Tensor:
    """Return an (n, 3) RGB LUT for a matplotlib colormap, cached on CPU in float32.

    numpy/matplotlib are imported lazily so the canonical Hilbert path carries no
    dependency on them.
    """
    key = f"{name}:{n}"
    if key not in _MPL_LUT_CACHE:
        import numpy as np

        try:
            # matplotlib >= 3.5 registry; `cm.get_cmap` was deprecated in 3.7
            # and removed in 3.9, so prefer the registry when available.
            from matplotlib import colormaps
            cmap = colormaps[name]
        except ImportError:
            import matplotlib.cm as mcm
            cmap = mcm.get_cmap(name)
        xs = np.linspace(0.0, 1.0, n, dtype=np.float32)
        rgb = cmap(xs)[:, :3].astype(np.float32)  # drop the alpha channel
        _MPL_LUT_CACHE[key] = torch.from_numpy(rgb)
    return _MPL_LUT_CACHE[key]
|
|
|
|
| def _curve_to_lut(u: torch.Tensor, lut: torch.Tensor) -> torch.Tensor: |
| """Sample u in [0,1] into a (n,3) LUT with linear interpolation.""" |
| n = lut.shape[0] |
| lut = lut.to(u.device, dtype=u.dtype) |
| scaled = u.clamp(0.0, 1.0) * (n - 1) |
| idx_lo = scaled.floor().clamp_(0, n - 2).long() |
| t = (scaled - idx_lo.to(scaled.dtype)).unsqueeze(-1) |
| a = lut[idx_lo] |
| b = lut[idx_lo + 1] |
| return a + t * (b - a) |
|
|
|
|
# Colormap names accepted by depth_to_rgb_cm; only "hilbert" is invertible.
COLORMAPS = ["hilbert", "plasma", "inferno", "viridis", "grayscale"]


# Human-readable near-to-far color descriptions of each colormap
# (e.g. for prompting or logging; not used by the encoders themselves).
CM_DESCRIPTIONS = {
    "hilbert": (
        "Color sequence from near to far: pure black (0,0,0), blue (0,0,255), cyan (0,255,255), "
        "green (0,255,0), yellow (255,255,0), red (255,0,0), magenta (255,0,255), white (255,255,255), "
        "with smooth gradients along this Hamiltonian cube path."
    ),
    "plasma": (
        "Color sequence from near to far: dark purple, magenta, orange, yellow-white, using the plasma perceptual colormap."
    ),
    "inferno": (
        "Color sequence from near to far: pure black, dark purple, red, orange, yellow, near-white, using the inferno perceptual colormap."
    ),
    "viridis": (
        "Color sequence from near to far: dark purple, blue, teal, green, yellow, using the viridis perceptual colormap."
    ),
    "grayscale": (
        "Near is pure black; far is pure white; pixels in between are monochrome gray scaled linearly with curved depth."
    ),
}
|
|
|
|
def depth_to_rgb_cm(depth: torch.Tensor, cm_name: str) -> torch.Tensor:
    """Encode metric depth with the named colormap. Only hilbert is invertible."""
    name = cm_name.lower()
    u = depth_to_curve(depth)
    if name == "hilbert":
        return curve_to_rgb(u)
    if name == "grayscale":
        # Replicate u across the channel axis; clone so the result owns its memory.
        gray = u.unsqueeze(-1)
        return gray.expand(*u.shape, 3).clone()
    if name in {"plasma", "inferno", "viridis"}:
        return _curve_to_lut(u, _mpl_lut(name))
    raise ValueError(f"unknown colormap: {cm_name}")
|
|
|
|
| if __name__ == "__main__": |
| import math |
|
|
| |
| depths = torch.logspace(-2, 2, steps=1000, dtype=torch.float64) |
| recovered = rgb_to_depth(depth_to_rgb(depths)) |
| err = (recovered - depths).abs() |
| rel = err / depths |
| print(f"round-trip: max abs err = {err.max().item()*100:.4f} cm") |
| print(f" max rel err = {rel.max().item()*100:.5f} %") |
| print(f" mean rel err = {rel.mean().item()*100:.5f} %") |
|
|
| |
| print(f"d=0: rgb = {depth_to_rgb(torch.tensor(0.0, dtype=torch.float64)).tolist()}") |
| print(f"d=10: rgb = {depth_to_rgb(torch.tensor(10.0, dtype=torch.float64)).tolist()}") |
| print(f"d=50: rgb = {depth_to_rgb(torch.tensor(50.0, dtype=torch.float64)).tolist()}") |
| print(f"d=1000: rgb = {depth_to_rgb(torch.tensor(1000.0, dtype=torch.float64)).tolist()}") |
|
|
| |
| torch.manual_seed(0) |
| depths = torch.linspace(0.1, 30.0, steps=500, dtype=torch.float64) |
| rgb = depth_to_rgb(depths) |
| for sigma in (0.0, 0.01, 0.02, 0.05): |
| rgb_noisy = (rgb + sigma * torch.randn_like(rgb)).clamp(0, 1) |
| recovered = rgb_to_depth(rgb_noisy) |
| rel = ((recovered - depths).abs() / depths).mean().item() |
| print(f"noise sigma={sigma:.2f}: mean rel err = {rel*100:.3f} %") |
|
|
| |
| if torch.cuda.is_available(): |
| d = torch.rand(2, 512, 512, device="cuda") * 20.0 |
| rgb = depth_to_rgb(d) |
| d_back = rgb_to_depth(rgb) |
| rel = ((d_back - d).abs() / d.clamp_min(1e-3)).mean().item() |
| print(f"GPU batch (2,512,512): mean rel err = {rel*100:.4f} % rgb shape {tuple(rgb.shape)}") |
|
|