# deep-plantain / decode_rgb_to_depth.py
# phanerozoic — initial release: weights, README with three OOD demos,
# RGB-to-depth decoder (f97c3fc, verified)
"""Metric depth <-> RGB encodings, per Vision Banana (Gabeur et al. 2026).
Common front-end (all colormaps share this):
Barron (2025) power-transform: f(d, lambda=-3, c=10/3) = 1 - (1 + d/10)^(-2),
mapping metric depth [0, inf) -> curve parameter u in [0, 1).
Primary (canonical) colormap: Hilbert
u -> RGB via piecewise-linear interp along a Hamiltonian path across 8 cube corners
(black -> blue -> cyan -> green -> yellow -> red -> magenta -> white).
Invertible: project RGB onto nearest segment.
Augmentation colormaps (forward-only for training variety; not used by the eval decoder):
Plasma / Inferno / Viridis: matplotlib perceptually-uniform LUTs applied to u.
Grayscale: u replicated to all 3 channels.
At eval we always request Hilbert so the RGB->depth inverse is well-defined.
"""
from __future__ import annotations
import torch
# Barron power-transform parameters (see module docstring); with these values
# the forward curve reduces to f(d) = 1 - (1 + d/10)^(-2).
LAMBDA: float = -3.0
C: float = 10.0 / 3.0
LAMBDA_C: float = LAMBDA * C  # -10; product kept for reference, unused below
# Hamiltonian path on the cube: black -> ... -> white, each step flips one axis,
# so every corner-to-corner segment has unit length.
CORNERS = torch.tensor(
    [
        [0.0, 0.0, 0.0],  # black
        [0.0, 0.0, 1.0],  # blue
        [0.0, 1.0, 1.0],  # cyan
        [0.0, 1.0, 0.0],  # green
        [1.0, 1.0, 0.0],  # yellow
        [1.0, 0.0, 0.0],  # red
        [1.0, 0.0, 1.0],  # magenta
        [1.0, 1.0, 1.0],  # white
    ]
)
N_SEG = CORNERS.shape[0] - 1  # 7 linear segments between the 8 corners
def depth_to_curve(depth: torch.Tensor) -> torch.Tensor:
    """Barron power-transform: metric depth in [0, inf) -> curve parameter u in [0, 1).

    Inputs are sanitized first so downstream integer indexing is safe:
    NaN, -inf, and negative depths map to u = 0 (encoded as black), while
    +inf saturates at 1e6 m, i.e. u ~= 1 (encoded as near-white) — NOT black,
    as +inf means "maximally far", the opposite end of the curve.
    """
    d = torch.nan_to_num(depth, nan=0.0, posinf=1e6, neginf=0.0).clamp_min(0.0)
    return 1.0 - (1.0 + d / 10.0).pow(-2.0)
def curve_to_depth(u: torch.Tensor) -> torch.Tensor:
    """Inverse Barron transform: u in [0, 1) -> metric depth in [0, inf)."""
    # Cap u just below 1 so the inverse stays finite (max depth = 990 m).
    capped = torch.clamp(u, 0.0, 0.9999)
    return ((1.0 - capped).rsqrt() - 1.0) * 10.0
def curve_to_rgb(u: torch.Tensor) -> torch.Tensor:
    """u in [0, 1] -> RGB in [0, 1]^3 along the 7-segment Hamiltonian path."""
    pos = u.clamp(0.0, 1.0) * N_SEG                 # position along path, [0, 7]
    seg = pos.floor().clamp_(0, N_SEG - 1).long()   # segment index, 0..6
    frac = (pos - seg.to(pos.dtype)).unsqueeze(-1)  # offset within segment, [0, 1]
    path = CORNERS.to(u.device, dtype=u.dtype)
    start = path[seg]
    end = path[seg + 1]
    return torch.lerp(start, end, frac)
def rgb_to_curve(rgb: torch.Tensor) -> torch.Tensor:
    """RGB in [0, 1]^3 -> curve parameter u in [0, 1] via nearest-segment projection.

    rgb: (..., 3) tensor. Returns: (...) tensor.
    """
    path = CORNERS.to(rgb.device, dtype=rgb.dtype)
    starts = path[:-1]                              # (7, 3) segment starts
    ends = path[1:]                                 # (7, 3) segment ends
    dirs = ends - starts                            # (7, 3); each edge flips one axis -> unit length
    # Broadcast every pixel against all 7 segments: rgb -> (..., 1, 3).
    rel = rgb.unsqueeze(-2) - starts                # (..., 7, 3)
    # Since |dirs|^2 == 1, the dot product alone is the segment parameter.
    t_all = (rel * dirs).sum(-1).clamp(0.0, 1.0)    # (..., 7)
    nearest = starts + t_all.unsqueeze(-1) * dirs   # (..., 7, 3) projected points
    sq_dist = (rgb.unsqueeze(-2) - nearest).pow(2).sum(-1)  # (..., 7)
    best = sq_dist.argmin(dim=-1)                   # (...,) winning segment index
    t_best = t_all.gather(-1, best.unsqueeze(-1)).squeeze(-1)  # (...,)
    return (best.to(rgb.dtype) + t_best) / N_SEG
def depth_to_rgb(depth: torch.Tensor) -> torch.Tensor:
    """Metric depth (..., H, W) -> RGB (..., H, W, 3) via the canonical Hilbert path."""
    u = depth_to_curve(depth)
    return curve_to_rgb(u)
def rgb_to_depth(rgb: torch.Tensor) -> torch.Tensor:
    """RGB (..., H, W, 3) -> metric depth (..., H, W). Assumes the Hilbert encoding."""
    u = rgb_to_curve(rgb)
    return curve_to_depth(u)
# ---- augmentation colormaps (forward-only) ---------------------------------
_MPL_LUT_CACHE: dict[str, torch.Tensor] = {}


def _mpl_lut(name: str, n: int = 1024) -> torch.Tensor:
    """Return a (n, 3) RGB LUT for a matplotlib colormap, cached on CPU in float32.

    Looks the colormap up via the ``matplotlib.colormaps`` registry (mpl >= 3.5);
    falls back to the legacy ``cm.get_cmap`` on older installs. The old accessor
    was deprecated in matplotlib 3.7 and removed in 3.9, so it cannot be the
    primary path.
    """
    key = f"{name}:{n}"
    if key not in _MPL_LUT_CACHE:
        import numpy as np

        try:
            from matplotlib import colormaps
            cmap = colormaps[name]
        except ImportError:  # matplotlib < 3.5
            import matplotlib.cm as mcm
            cmap = mcm.get_cmap(name)
        xs = np.linspace(0.0, 1.0, n, dtype=np.float32)
        rgb = cmap(xs)[:, :3].astype(np.float32)  # drop the alpha channel
        _MPL_LUT_CACHE[key] = torch.from_numpy(rgb)
    return _MPL_LUT_CACHE[key]
def _curve_to_lut(u: torch.Tensor, lut: torch.Tensor) -> torch.Tensor:
"""Sample u in [0,1] into a (n,3) LUT with linear interpolation."""
n = lut.shape[0]
lut = lut.to(u.device, dtype=u.dtype)
scaled = u.clamp(0.0, 1.0) * (n - 1)
idx_lo = scaled.floor().clamp_(0, n - 2).long()
t = (scaled - idx_lo.to(scaled.dtype)).unsqueeze(-1)
a = lut[idx_lo]
b = lut[idx_lo + 1]
return a + t * (b - a)
# Colormap names accepted by depth_to_rgb_cm; only "hilbert" is invertible.
COLORMAPS = ["hilbert", "plasma", "inferno", "viridis", "grayscale"]
# Human-readable near-to-far descriptions of each encoding (e.g. for prompts
# or docs); keys mirror COLORMAPS.
CM_DESCRIPTIONS = {
    "hilbert": (
        "Color sequence from near to far: pure black (0,0,0), blue (0,0,255), cyan (0,255,255), "
        "green (0,255,0), yellow (255,255,0), red (255,0,0), magenta (255,0,255), white (255,255,255), "
        "with smooth gradients along this Hamiltonian cube path."
    ),
    "plasma": (
        "Color sequence from near to far: dark purple, magenta, orange, yellow-white, using the plasma perceptual colormap."
    ),
    "inferno": (
        "Color sequence from near to far: pure black, dark purple, red, orange, yellow, near-white, using the inferno perceptual colormap."
    ),
    "viridis": (
        "Color sequence from near to far: dark purple, blue, teal, green, yellow, using the viridis perceptual colormap."
    ),
    "grayscale": (
        "Near is pure black; far is pure white; pixels in between are monochrome gray scaled linearly with curved depth."
    ),
}
def depth_to_rgb_cm(depth: torch.Tensor, cm_name: str) -> torch.Tensor:
    """Encode metric depth with the named colormap. Only hilbert is invertible."""
    u = depth_to_curve(depth)
    name = cm_name.lower()
    if name == "hilbert":
        rgb = curve_to_rgb(u)
    elif name == "grayscale":
        # Replicate u across the channel axis; clone() detaches from the
        # expanded (non-contiguous) view.
        rgb = u.unsqueeze(-1).expand(*u.shape, 3).clone()
    elif name in ("plasma", "inferno", "viridis"):
        rgb = _curve_to_lut(u, _mpl_lut(name))
    else:
        raise ValueError(f"unknown colormap: {cm_name}")
    return rgb
if __name__ == "__main__":
    # Self-test / demo. (The previous version imported `math` but never used it.)
    # 1. Round-trip error on a log-spaced depth grid from 1 cm to 100 m.
    depths = torch.logspace(-2, 2, steps=1000, dtype=torch.float64)
    recovered = rgb_to_depth(depth_to_rgb(depths))
    err = (recovered - depths).abs()
    rel = err / depths
    print(f"round-trip: max abs err = {err.max().item()*100:.4f} cm")
    print(f"            max rel err = {rel.max().item()*100:.5f} %")
    print(f"            mean rel err = {rel.mean().item()*100:.5f} %")
    # 2. Endpoint sanity.
    print(f"d=0:    rgb = {depth_to_rgb(torch.tensor(0.0, dtype=torch.float64)).tolist()}")
    print(f"d=10:   rgb = {depth_to_rgb(torch.tensor(10.0, dtype=torch.float64)).tolist()}")
    print(f"d=50:   rgb = {depth_to_rgb(torch.tensor(50.0, dtype=torch.float64)).tolist()}")
    print(f"d=1000: rgb = {depth_to_rgb(torch.tensor(1000.0, dtype=torch.float64)).tolist()}")
    # 3. Noise robustness: add gaussian noise to RGB, measure metric depth error.
    torch.manual_seed(0)
    depths = torch.linspace(0.1, 30.0, steps=500, dtype=torch.float64)
    rgb = depth_to_rgb(depths)
    for sigma in (0.0, 0.01, 0.02, 0.05):
        rgb_noisy = (rgb + sigma * torch.randn_like(rgb)).clamp(0, 1)
        recovered = rgb_to_depth(rgb_noisy)
        rel = ((recovered - depths).abs() / depths).mean().item()
        print(f"noise sigma={sigma:.2f}: mean rel err = {rel*100:.3f} %")
    # 4. GPU / batch shapes (skipped when CUDA is unavailable).
    if torch.cuda.is_available():
        d = torch.rand(2, 512, 512, device="cuda") * 20.0
        rgb = depth_to_rgb(d)
        d_back = rgb_to_depth(rgb)
        # clamp_min guards the relative-error denominator against near-zero depths.
        rel = ((d_back - d).abs() / d.clamp_min(1e-3)).mean().item()
        print(f"GPU batch (2,512,512): mean rel err = {rel*100:.4f} % rgb shape {tuple(rgb.shape)}")