phanerozoic committed
Commit f97c3fc · verified · 1 Parent(s): e9fd714

Initial release: weights, README with three OOD demos, RGB-to-depth decoder

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+readme/beach.png filter=lfs diff=lfs merge=lfs -text
+readme/cat.jpg filter=lfs diff=lfs merge=lfs -text
+readme/skier.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,96 @@
+ ---
+ language: en
+ license: apache-2.0
+ base_model: black-forest-labs/FLUX.2-klein-base-4B
+ library_name: diffusers
+ tags:
+ - depth-estimation
+ - lora
+ - flux2
+ - vision-banana
+ - arxiv:2604.20329
+ pipeline_tag: depth-estimation
+ ---
+
+ # deep-plantain
+
+ A LoRA adapter on FLUX.2 Klein (4B) for monocular depth estimation. Tests one claim from *Image Generators are Generalist Vision Learners* (Gabeur et al., 2026; [arXiv:2604.20329](https://arxiv.org/abs/2604.20329)) using parameter-efficient tuning.
+
+ ## The paper's claim
+
+ Vision Banana argues that image-generation training plays the same foundational role for vision that next-token pretraining plays for language. The latent capability for visual understanding is already inside any sufficiently strong image generator; lightweight instruction-tuning aligns it to produce decodable RGB outputs (segmentation masks, depth maps, surface normals, etc.). The paper demonstrates this on Nano Banana Pro across five tasks — referring, semantic, and instance segmentation; metric depth; surface normals — and matches or beats domain specialists (SAM 3, Depth Anything 3, Lotus-2) without sacrificing the base model's generation quality. The thesis is paradigm-level: **image generation as a universal interface for vision**, analogous to text generation in language.
+
+ ## What this LoRA tests
+
+ One axis of the paper's claim:
+
+ - One task of the five (monocular depth)
+ - An open base model (FLUX.2 Klein 4B)
+ - LoRA, not full instruction-tuning on the paper's original training mixture
+
+ Question: does the thesis — depth understanding latent in image generators, surfaced by instruction-tuning — survive parameter-efficient adaptation on an open backbone?
+
+ ## Method
+
+ Both pieces of the paper's recipe are preserved exactly:
+
+ 1. **Reframe depth as image-to-image generation.** Input RGB → output RGB depth visualization.
+ 2. **Bijective RGB↔depth encoding.** The Barron (2025) power transform compresses metric depth to a curve parameter `u ∈ [0, 1)`; piecewise-linear interpolation along a 7-segment Hamiltonian path through the corners of the RGB cube produces the visualization (black → blue → cyan → green → yellow → red → magenta → white). Decoding projects the predicted RGB onto the nearest cube edge. A worked example follows this section.
+
+ Training data: Hypersim (synthetic indoor) + NYU Depth V2 train split (real indoor). Maximum encoded depth is 15 m, set by the bijection's cap.
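+
+ The corner depths quoted in the usage prompt below fall straight out of the bijection; a minimal sketch, mirroring the constants in `decode_rgb_to_depth.py`:
+
+ ```python
+ # Corner k of the 7-segment path sits at u = k/7; invert the Barron transform.
+ for k, name in enumerate(["black", "blue", "cyan", "green", "yellow", "red", "magenta"]):
+     u = k / 7
+     d = 10.0 * ((1.0 - u) ** -0.5 - 1.0)  # 0, ~0.8, ~1.8, ~3.2, ~5.3, ~8.7, ~16.5 m
+     print(f"{name}: {d:.1f} m")
+ ```
+
+ White sits at `u = 1`, i.e. depth approaching infinity.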
+
+ ## Demos
+
+ Three pictures the model has never seen.
+
+ ![cat](readme/cat.jpg)
+
+ *Indoor portrait, close to the training distribution. The cat is read as foreground (cyan, ~1–2 m), the wall as background (green, ~3 m), the blanket as nearer foreground (deep blue). Internal depth ordering and subject/background separation are both correct.*
+
+ ![beach](readme/beach.png)
+
+ *Outdoor scene, outside the indoor training distribution. Sky and ground are misencoded — the model has no learned representation for "sky" and pins it to ~5 m yellow rather than infinity. But the salient subjects survive: each distant figure, the kite, and the foreground bucket are individually segmented from the global gradient.*
+
+ ![skier](readme/skier.jpg)
+
+ *Outdoor mountain scene, also out-of-distribution. The subject is crisply isolated from snow, mountain, and sky. Relative depth ordering of the background layers is roughly correct (sky > mountain > snow > subject), compressed into the bijection's 15 m range.*
+
+ Across all three, a recurring pattern: the visually prominent subject reads as nearer than its actual metric depth would predict (most clearly the cat's tie). When the depth signal is ambiguous or out-of-distribution, the model falls back on saliency-shaped outputs rather than predicting noise. The behavior is consistent with the paper's argument that the base model carries latent representations of image structure — subjects, prominence, attention — which a depth-only LoRA inherits but does not overwrite.
+
+ ## Status
+
+ This is an early checkpoint. Improved weights from broader training data and longer schedules will replace it as they land.
+
+ ## Usage
+
+ ```python
+ from diffusers import Flux2KleinPipeline
+ from PIL import Image
+ import torch
+
+ pipe = Flux2KleinPipeline.from_pretrained(
+     "black-forest-labs/FLUX.2-klein-base-4B",
+     torch_dtype=torch.bfloat16,
+ ).to("cuda")
+ pipe.load_lora_weights("phanerozoic/deep-plantain")
+
+ src = Image.open("readme/cat.jpg").convert("RGB")  # any RGB input image
+
+ prompt = (
+     "Generate a metric depth visualization of this image. Color scheme: "
+     "0 m black, ~0.8 m blue, ~1.8 m cyan, ~3.2 m green, ~5.3 m yellow, "
+     "~8.7 m red, ~16.5 m magenta, far approaching white. Smooth gradients "
+     "along this path; every pixel follows this depth-to-color scheme."
+ )
+
+ depth_pil = pipe(image=src, prompt=prompt, num_inference_steps=20).images[0]
+ ```
+
+ The decoder for predicted RGB → metric depth (nearest-segment projection + inverse Barron transform) is in `decode_rgb_to_depth.py`.
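+
+ For example, recovering metric depth from the prediction with those helpers (a minimal sketch; `depth_pil` is the PIL image from the snippet above):
+
+ ```python
+ import numpy as np
+ import torch
+ from decode_rgb_to_depth import rgb_to_depth
+
+ rgb = torch.from_numpy(np.asarray(depth_pil, dtype=np.float32) / 255.0)  # (H, W, 3) in [0, 1]
+ depth_m = rgb_to_depth(rgb)  # (H, W) metric depth in meters
+ ```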
+
+ ## License
+
+ Apache 2.0 — matches base FLUX.2 Klein 4B.
+
+ ## References
+
+ - Gabeur, Long, Peng, et al. *Image Generators are Generalist Vision Learners.* [arXiv:2604.20329](https://arxiv.org/abs/2604.20329) (2026).
+ - Barron, J. T. *A Power Transform.* [arXiv:2502.10647](https://arxiv.org/abs/2502.10647) (2025).
+ - Black Forest Labs. *FLUX.2 Klein.* https://bfl.ai/models/flux-2-klein (2025).
decode_rgb_to_depth.py ADDED
@@ -0,0 +1,203 @@
+ """Metric depth <-> RGB encodings, per Vision Banana (Gabeur et al. 2026).
+
+ Common front-end (all colormaps share this):
+     Barron (2025) power-transform: f(d, lambda=-3, c=10/3) = 1 - (1 + d/10)^(-2),
+     mapping metric depth [0, inf) -> curve parameter u in [0, 1).
+
+ Primary (canonical) colormap: Hilbert
+     u -> RGB via piecewise-linear interp along a Hamiltonian path across 8 cube corners
+     (black -> blue -> cyan -> green -> yellow -> red -> magenta -> white).
+     Invertible: project RGB onto nearest segment.
+
+ Augmentation colormaps (forward-only for training variety; not used by the eval decoder):
+     Plasma / Inferno / Viridis: matplotlib perceptually-uniform LUTs applied to u.
+     Grayscale: u replicated to all 3 channels.
+
+ At eval we always request Hilbert so the RGB->depth inverse is well-defined.
+ """
+ from __future__ import annotations
+
+ import torch
+
+ LAMBDA: float = -3.0
+ C: float = 10.0 / 3.0
+ LAMBDA_C: float = LAMBDA * C  # -10
+
+ # Hamiltonian path on the cube: black -> ... -> white, each step flips one axis.
+ CORNERS = torch.tensor(
+     [
+         [0.0, 0.0, 0.0],  # black
+         [0.0, 0.0, 1.0],  # blue
+         [0.0, 1.0, 1.0],  # cyan
+         [0.0, 1.0, 0.0],  # green
+         [1.0, 1.0, 0.0],  # yellow
+         [1.0, 0.0, 0.0],  # red
+         [1.0, 0.0, 1.0],  # magenta
+         [1.0, 1.0, 1.0],  # white
+     ]
+ )
+ N_SEG = CORNERS.shape[0] - 1  # 7
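+
+ # Worked example of the forward path, following the formulas above:
+ #   d = 10 m  ->  u = 1 - (1 + 10/10)**-2 = 0.75
+ #   u * N_SEG = 5.25  ->  segment 5 (red -> magenta), t = 0.25  ->  rgb = (1.0, 0.0, 0.25)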
+
+
+ def depth_to_curve(depth: torch.Tensor) -> torch.Tensor:
+     """Barron power-transform: metric depth in [0, inf) -> curve parameter u in [0, 1).
+
+     NaN and negative inputs map to 0 (encoded as black); +inf is capped at 1e6,
+     so downstream integer indexing is safe.
+     """
+     d = torch.nan_to_num(depth, nan=0.0, posinf=1e6, neginf=0.0).clamp_min(0.0)
+     return 1.0 - (1.0 + d / 10.0).pow(-2.0)
+
+
+ def curve_to_depth(u: torch.Tensor) -> torch.Tensor:
+     """Inverse Barron transform: u in [0, 1) -> metric depth in [0, inf)."""
+     u_safe = u.clamp(0.0, 0.9999)
+     return 10.0 * ((1.0 - u_safe).rsqrt() - 1.0)
+
+
+ def curve_to_rgb(u: torch.Tensor) -> torch.Tensor:
+     """u in [0, 1] -> RGB in [0, 1]^3 along the 7-segment Hamiltonian path."""
+     u_clamped = u.clamp(0.0, 1.0)
+     scaled = u_clamped * N_SEG                         # [0, 7]
+     idx = scaled.floor().clamp_(0, N_SEG - 1).long()   # segment index [0, 6]
+     t = (scaled - idx.to(scaled.dtype)).unsqueeze(-1)  # local parameter [0, 1]
+
+     corners = CORNERS.to(u.device, dtype=u.dtype)
+     a = corners[idx]
+     b = corners[idx + 1]
+     return a + t * (b - a)
+
+
+ def rgb_to_curve(rgb: torch.Tensor) -> torch.Tensor:
+     """RGB in [0, 1]^3 -> curve parameter u in [0, 1] via nearest-segment projection.
+
+     rgb: (..., 3) tensor. Returns: (...) tensor.
+     """
+     corners = CORNERS.to(rgb.device, dtype=rgb.dtype)
+     a = corners[:-1]  # (7, 3) segment starts
+     b = corners[1:]   # (7, 3) segment ends
+     d_vec = b - a     # (7, 3) direction (unit length since corner-to-corner)
+
+     # Broadcast rgb (..., 1, 3) against segments (7, 3).
+     x = rgb.unsqueeze(-2) - a                # (..., 7, 3)
+     # Segment length squared is 1 for every corner-to-corner edge.
+     t = (x * d_vec).sum(-1).clamp(0.0, 1.0)  # (..., 7)
+
+     proj = a + t.unsqueeze(-1) * d_vec                  # (..., 7, 3)
+     dist2 = (rgb.unsqueeze(-2) - proj).pow(2).sum(-1)   # (..., 7)
+
+     seg_idx = dist2.argmin(dim=-1)                            # (...,)
+     seg_t = t.gather(-1, seg_idx.unsqueeze(-1)).squeeze(-1)   # (...,)
+     return (seg_idx.to(rgb.dtype) + seg_t) / N_SEG
+
+
+ def depth_to_rgb(depth: torch.Tensor) -> torch.Tensor:
+     """Metric depth (..., H, W) -> RGB (..., H, W, 3) via the canonical Hilbert path."""
+     return curve_to_rgb(depth_to_curve(depth))
+
+
+ def rgb_to_depth(rgb: torch.Tensor) -> torch.Tensor:
+     """RGB (..., H, W, 3) -> metric depth (..., H, W). Assumes the Hilbert encoding."""
+     return curve_to_depth(rgb_to_curve(rgb))
+
+
+ # ---- augmentation colormaps (forward-only) ---------------------------------
+
+ _MPL_LUT_CACHE: dict[str, torch.Tensor] = {}
+
+
+ def _mpl_lut(name: str, n: int = 1024) -> torch.Tensor:
+     """Return a (n, 3) RGB LUT for a matplotlib colormap, cached on CPU in float32."""
+     key = f"{name}:{n}"
+     if key not in _MPL_LUT_CACHE:
+         import matplotlib
+         import numpy as np
+         cmap = matplotlib.colormaps[name]  # registry lookup; cm.get_cmap is deprecated
+         xs = np.linspace(0.0, 1.0, n, dtype=np.float32)
+         rgb = cmap(xs)[:, :3].astype(np.float32)  # drop alpha
+         _MPL_LUT_CACHE[key] = torch.from_numpy(rgb)
+     return _MPL_LUT_CACHE[key]
+
+
+ def _curve_to_lut(u: torch.Tensor, lut: torch.Tensor) -> torch.Tensor:
+     """Sample u in [0, 1] into a (n, 3) LUT with linear interpolation."""
+     n = lut.shape[0]
+     lut = lut.to(u.device, dtype=u.dtype)
+     scaled = u.clamp(0.0, 1.0) * (n - 1)
+     idx_lo = scaled.floor().clamp_(0, n - 2).long()
+     t = (scaled - idx_lo.to(scaled.dtype)).unsqueeze(-1)
+     a = lut[idx_lo]
+     b = lut[idx_lo + 1]
+     return a + t * (b - a)
+
+
+ COLORMAPS = ["hilbert", "plasma", "inferno", "viridis", "grayscale"]
+
+ CM_DESCRIPTIONS = {
+     "hilbert": (
+         "Color sequence from near to far: pure black (0,0,0), blue (0,0,255), cyan (0,255,255), "
+         "green (0,255,0), yellow (255,255,0), red (255,0,0), magenta (255,0,255), white (255,255,255), "
+         "with smooth gradients along this Hamiltonian cube path."
+     ),
+     "plasma": (
+         "Color sequence from near to far: dark purple, magenta, orange, yellow-white, "
+         "using the plasma perceptual colormap."
+     ),
+     "inferno": (
+         "Color sequence from near to far: pure black, dark purple, red, orange, yellow, "
+         "near-white, using the inferno perceptual colormap."
+     ),
+     "viridis": (
+         "Color sequence from near to far: dark purple, blue, teal, green, yellow, "
+         "using the viridis perceptual colormap."
+     ),
+     "grayscale": (
+         "Near is pure black; far is pure white; pixels in between are monochrome gray "
+         "scaled linearly with curved depth."
+     ),
+ }
+
+
+ def depth_to_rgb_cm(depth: torch.Tensor, cm_name: str) -> torch.Tensor:
+     """Encode metric depth with the named colormap. Only hilbert is invertible."""
+     u = depth_to_curve(depth)
+     cm = cm_name.lower()
+     if cm == "hilbert":
+         return curve_to_rgb(u)
+     if cm == "grayscale":
+         return u.unsqueeze(-1).expand(*u.shape, 3).clone()
+     if cm in ("plasma", "inferno", "viridis"):
+         return _curve_to_lut(u, _mpl_lut(cm))
+     raise ValueError(f"unknown colormap: {cm_name}")
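+
+ # Example (hypothetical shapes): encode a training batch with an augmentation colormap.
+ #   depth = torch.rand(4, 256, 256) * 15.0          # meters
+ #   rgb = depth_to_rgb_cm(depth, "plasma")          # (4, 256, 256, 3); forward-only
+ #   rgb_canon = depth_to_rgb_cm(depth, "hilbert")   # decodable via rgb_to_depth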
+
+
+ if __name__ == "__main__":
+     # 1. Round-trip error on a log-spaced depth grid from 1 cm to 100 m.
+     depths = torch.logspace(-2, 2, steps=1000, dtype=torch.float64)
+     recovered = rgb_to_depth(depth_to_rgb(depths))
+     err = (recovered - depths).abs()
+     rel = err / depths
+     print(f"round-trip: max abs err = {err.max().item()*100:.4f} cm")
+     print(f"            max rel err = {rel.max().item()*100:.5f} %")
+     print(f"            mean rel err = {rel.mean().item()*100:.5f} %")
+
+     # 2. Endpoint sanity.
+     print(f"d=0:    rgb = {depth_to_rgb(torch.tensor(0.0, dtype=torch.float64)).tolist()}")
+     print(f"d=10:   rgb = {depth_to_rgb(torch.tensor(10.0, dtype=torch.float64)).tolist()}")
+     print(f"d=50:   rgb = {depth_to_rgb(torch.tensor(50.0, dtype=torch.float64)).tolist()}")
+     print(f"d=1000: rgb = {depth_to_rgb(torch.tensor(1000.0, dtype=torch.float64)).tolist()}")
+
+     # 3. Noise robustness: add gaussian noise to RGB, measure metric depth error.
+     torch.manual_seed(0)
+     depths = torch.linspace(0.1, 30.0, steps=500, dtype=torch.float64)
+     rgb = depth_to_rgb(depths)
+     for sigma in (0.0, 0.01, 0.02, 0.05):
+         rgb_noisy = (rgb + sigma * torch.randn_like(rgb)).clamp(0, 1)
+         recovered = rgb_to_depth(rgb_noisy)
+         rel = ((recovered - depths).abs() / depths).mean().item()
+         print(f"noise sigma={sigma:.2f}: mean rel err = {rel*100:.3f} %")
+
+     # 4. GPU / batch shapes.
+     if torch.cuda.is_available():
+         d = torch.rand(2, 512, 512, device="cuda") * 20.0
+         rgb = depth_to_rgb(d)
+         d_back = rgb_to_depth(rgb)
+         rel = ((d_back - d).abs() / d.clamp_min(1e-3)).mean().item()
+         print(f"GPU batch (2,512,512): mean rel err = {rel*100:.4f} %  rgb shape {tuple(rgb.shape)}")
pytorch_lora_weights.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:772c103ebdafa461696a15eb151b1fb5387b74915182d2cff0c93b2b022708d5
+ size 66866216
readme/beach.png ADDED

Git LFS Details

  • SHA256: 16e75fdc6bdc38f5c12adbf273adc0d6dbd2a98978189fa990ae510641cf7aa5
  • Pointer size: 131 Bytes
  • Size of remote file: 327 kB
readme/cat.jpg ADDED

Git LFS Details

  • SHA256: a0691b7940b8d6f4f9d0d54ac0be724e1033cdcbaaa470774c1e1ecf460da0c6
  • Pointer size: 131 Bytes
  • Size of remote file: 172 kB
readme/skier.jpg ADDED

Git LFS Details

  • SHA256: 11d04dbd6e9612a778523a420dc30a46bb218bbf9875762b71f3e40b5543ce07
  • Pointer size: 131 Bytes
  • Size of remote file: 170 kB