"""Metric depth <-> RGB encodings, per Vision Banana (Gabeur et al. 2026). Common front-end (all colormaps share this): Barron (2025) power-transform: f(d, lambda=-3, c=10/3) = 1 - (1 + d/10)^(-2), mapping metric depth [0, inf) -> curve parameter u in [0, 1). Primary (canonical) colormap: Hilbert u -> RGB via piecewise-linear interp along a Hamiltonian path across 8 cube corners (black -> blue -> cyan -> green -> yellow -> red -> magenta -> white). Invertible: project RGB onto nearest segment. Augmentation colormaps (forward-only for training variety; not used by the eval decoder): Plasma / Inferno / Viridis: matplotlib perceptually-uniform LUTs applied to u. Grayscale: u replicated to all 3 channels. At eval we always request Hilbert so the RGB->depth inverse is well-defined. """ from __future__ import annotations import torch LAMBDA: float = -3.0 C: float = 10.0 / 3.0 LAMBDA_C: float = LAMBDA * C # -10 # Hamiltonian path on the cube: black -> ... -> white, each step flips one axis. CORNERS = torch.tensor( [ [0.0, 0.0, 0.0], # black [0.0, 0.0, 1.0], # blue [0.0, 1.0, 1.0], # cyan [0.0, 1.0, 0.0], # green [1.0, 1.0, 0.0], # yellow [1.0, 0.0, 0.0], # red [1.0, 0.0, 1.0], # magenta [1.0, 1.0, 1.0], # white ] ) N_SEG = CORNERS.shape[0] - 1 # 7 def depth_to_curve(depth: torch.Tensor) -> torch.Tensor: """Barron power-transform: metric depth in [0, inf) -> curve parameter u in [0, 1). NaN / negative / inf inputs map to 0 (encoded as black), so downstream integer indexing is safe. """ d = torch.nan_to_num(depth, nan=0.0, posinf=1e6, neginf=0.0).clamp_min(0.0) return 1.0 - (1.0 + d / 10.0).pow(-2.0) def curve_to_depth(u: torch.Tensor) -> torch.Tensor: """Inverse Barron transform: u in [0, 1) -> metric depth in [0, inf).""" u_safe = u.clamp(0.0, 0.9999) return 10.0 * ((1.0 - u_safe).rsqrt() - 1.0) def curve_to_rgb(u: torch.Tensor) -> torch.Tensor: """u in [0, 1] -> RGB in [0, 1]^3 along the 7-segment Hamiltonian path.""" u_clamped = u.clamp(0.0, 1.0) scaled = u_clamped * N_SEG # [0, 7] idx = scaled.floor().clamp_(0, N_SEG - 1).long() # segment index [0, 6] t = (scaled - idx.to(scaled.dtype)).unsqueeze(-1) # local parameter [0, 1] corners = CORNERS.to(u.device, dtype=u.dtype) a = corners[idx] b = corners[idx + 1] return a + t * (b - a) def rgb_to_curve(rgb: torch.Tensor) -> torch.Tensor: """RGB in [0, 1]^3 -> curve parameter u in [0, 1] via nearest-segment projection. rgb: (..., 3) tensor. Returns: (...) tensor. """ corners = CORNERS.to(rgb.device, dtype=rgb.dtype) a = corners[:-1] # (7, 3) segment starts b = corners[1:] # (7, 3) segment ends d_vec = b - a # (7, 3) direction (unit length since corner-to-corner) # Broadcast rgb (..., 1, 3) against segments (7, 3). x = rgb.unsqueeze(-2) - a # (..., 7, 3) # Segment length squared is 1 for every corner-to-corner edge. t = (x * d_vec).sum(-1).clamp(0.0, 1.0) # (..., 7) proj = a + t.unsqueeze(-1) * d_vec # (..., 7, 3) dist2 = (rgb.unsqueeze(-2) - proj).pow(2).sum(-1) # (..., 7) seg_idx = dist2.argmin(dim=-1) # (...,) seg_t = t.gather(-1, seg_idx.unsqueeze(-1)).squeeze(-1) # (...,) return (seg_idx.to(rgb.dtype) + seg_t) / N_SEG def depth_to_rgb(depth: torch.Tensor) -> torch.Tensor: """Metric depth (..., H, W) -> RGB (..., H, W, 3) via the canonical Hilbert path.""" return curve_to_rgb(depth_to_curve(depth)) def rgb_to_depth(rgb: torch.Tensor) -> torch.Tensor: """RGB (..., H, W, 3) -> metric depth (..., H, W). Assumes the Hilbert encoding.""" return curve_to_depth(rgb_to_curve(rgb)) # ---- augmentation colormaps (forward-only) --------------------------------- _MPL_LUT_CACHE: dict[str, torch.Tensor] = {} def _mpl_lut(name: str, n: int = 1024) -> torch.Tensor: """Return a (n, 3) RGB LUT for a matplotlib colormap, cached on CPU in float32.""" key = f"{name}:{n}" if key not in _MPL_LUT_CACHE: import numpy as np import matplotlib.cm as mcm cmap = mcm.get_cmap(name) xs = np.linspace(0.0, 1.0, n, dtype=np.float32) rgb = cmap(xs)[:, :3].astype(np.float32) # drop alpha _MPL_LUT_CACHE[key] = torch.from_numpy(rgb) return _MPL_LUT_CACHE[key] def _curve_to_lut(u: torch.Tensor, lut: torch.Tensor) -> torch.Tensor: """Sample u in [0,1] into a (n,3) LUT with linear interpolation.""" n = lut.shape[0] lut = lut.to(u.device, dtype=u.dtype) scaled = u.clamp(0.0, 1.0) * (n - 1) idx_lo = scaled.floor().clamp_(0, n - 2).long() t = (scaled - idx_lo.to(scaled.dtype)).unsqueeze(-1) a = lut[idx_lo] b = lut[idx_lo + 1] return a + t * (b - a) COLORMAPS = ["hilbert", "plasma", "inferno", "viridis", "grayscale"] CM_DESCRIPTIONS = { "hilbert": ( "Color sequence from near to far: pure black (0,0,0), blue (0,0,255), cyan (0,255,255), " "green (0,255,0), yellow (255,255,0), red (255,0,0), magenta (255,0,255), white (255,255,255), " "with smooth gradients along this Hamiltonian cube path." ), "plasma": ( "Color sequence from near to far: dark purple, magenta, orange, yellow-white, using the plasma perceptual colormap." ), "inferno": ( "Color sequence from near to far: pure black, dark purple, red, orange, yellow, near-white, using the inferno perceptual colormap." ), "viridis": ( "Color sequence from near to far: dark purple, blue, teal, green, yellow, using the viridis perceptual colormap." ), "grayscale": ( "Near is pure black; far is pure white; pixels in between are monochrome gray scaled linearly with curved depth." ), } def depth_to_rgb_cm(depth: torch.Tensor, cm_name: str) -> torch.Tensor: """Encode metric depth with the named colormap. Only hilbert is invertible.""" u = depth_to_curve(depth) cm = cm_name.lower() if cm == "hilbert": return curve_to_rgb(u) if cm == "grayscale": return u.unsqueeze(-1).expand(*u.shape, 3).clone() if cm in ("plasma", "inferno", "viridis"): return _curve_to_lut(u, _mpl_lut(cm)) raise ValueError(f"unknown colormap: {cm_name}") if __name__ == "__main__": import math # 1. Round-trip error on a log-spaced depth grid from 1 cm to 100 m. depths = torch.logspace(-2, 2, steps=1000, dtype=torch.float64) recovered = rgb_to_depth(depth_to_rgb(depths)) err = (recovered - depths).abs() rel = err / depths print(f"round-trip: max abs err = {err.max().item()*100:.4f} cm") print(f" max rel err = {rel.max().item()*100:.5f} %") print(f" mean rel err = {rel.mean().item()*100:.5f} %") # 2. Endpoint sanity. print(f"d=0: rgb = {depth_to_rgb(torch.tensor(0.0, dtype=torch.float64)).tolist()}") print(f"d=10: rgb = {depth_to_rgb(torch.tensor(10.0, dtype=torch.float64)).tolist()}") print(f"d=50: rgb = {depth_to_rgb(torch.tensor(50.0, dtype=torch.float64)).tolist()}") print(f"d=1000: rgb = {depth_to_rgb(torch.tensor(1000.0, dtype=torch.float64)).tolist()}") # 3. Noise robustness: add gaussian noise to RGB, measure metric depth error. torch.manual_seed(0) depths = torch.linspace(0.1, 30.0, steps=500, dtype=torch.float64) rgb = depth_to_rgb(depths) for sigma in (0.0, 0.01, 0.02, 0.05): rgb_noisy = (rgb + sigma * torch.randn_like(rgb)).clamp(0, 1) recovered = rgb_to_depth(rgb_noisy) rel = ((recovered - depths).abs() / depths).mean().item() print(f"noise sigma={sigma:.2f}: mean rel err = {rel*100:.3f} %") # 4. GPU / batch shapes. if torch.cuda.is_available(): d = torch.rand(2, 512, 512, device="cuda") * 20.0 rgb = depth_to_rgb(d) d_back = rgb_to_depth(rgb) rel = ((d_back - d).abs() / d.clamp_min(1e-3)).mean().item() print(f"GPU batch (2,512,512): mean rel err = {rel*100:.4f} % rgb shape {tuple(rgb.shape)}")