# NeAR / app_gsplat.py
# Last commit by luh1124 (2b829a1):
# fix(gsplat): 2DGS image fit — expand colors to [1,N,3] for packed=False
# SPDX-FileCopyrightText: Copyright 2023-2026 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Gradio app: 3DGS / 2DGS image fitting with gsplat (Adam + rasterization).
HF ZeroGPU: set ``app_file: app_gsplat.py`` in Space README front matter.
Heavy work runs inside ``@spaces.GPU`` (training requires gradients — no inference_mode).
"""
from __future__ import annotations
import math
import os
import time
import traceback
from typing import Literal
import gradio as gr
import numpy as np
import torch
from PIL import Image
from torch import Tensor, optim
# Hub-token fallback: when neither of the two Hugging Face token variables is
# set to a non-empty value, reuse the value of the ``near`` / ``NEAR``
# environment variable (whitespace-stripped) as ``HF_TOKEN``.
if not (os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")):
    _t = (os.environ.get("near") or os.environ.get("NEAR") or "").strip()
    if _t:
        os.environ["HF_TOKEN"] = _t
# ``spaces`` is optional: when it cannot be imported the name is bound to
# ``None`` and ``GPU`` degrades to a decorator that returns the function as-is.
try:
    import spaces  # pyright: ignore[reportMissingImports]
except ImportError:
    spaces = None

if spaces is not None:
    GPU = spaces.GPU
else:

    def GPU(f):
        """Identity decorator used when ``spaces`` is unavailable."""
        return f


# Port used by the CLI entry point when no port environment variable is set.
DEFAULT_PORT = 7863
def _render_colors_to_hw3(render_colors: Tensor) -> Tensor:
"""Flatten batch / camera dims to [H, W, C] and keep RGB."""
x = render_colors
while x.dim() > 3:
x = x[0]
return x[..., :3]
def _gray_preview() -> np.ndarray:
return np.full((64, 128, 3), 48, dtype=np.uint8)
def _pil_to_gt_tensor(img: Image.Image, max_side: int) -> Tensor:
    """Convert a PIL image into a float32 RGB tensor scaled by 1/255.

    The image is converted to RGB first. If its longer side exceeds
    ``max_side`` it is shrunk with Lanczos resampling so that the longer side
    is about ``max_side`` pixels (each side is truncated to an int and kept
    at least 1 pixel).
    """
    rgb = img.convert("RGB")
    width, height = rgb.size
    longest = max(width, height)
    if longest > 0 and longest > max_side:
        ratio = max_side / float(longest)
        new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
        rgb = rgb.resize(new_size, Image.Resampling.LANCZOS)
    pixels = np.asarray(rgb, dtype=np.float32)
    return torch.from_numpy(pixels / 255.0)
def _default_gt_tensor(height: int, width: int) -> Tensor:
gt = torch.ones((height, width, 3))
gt[: height // 2, : width // 2, :] = torch.tensor([1.0, 0.0, 0.0])
gt[height // 2 :, width // 2 :, :] = torch.tensor([0.0, 0.0, 1.0])
return gt
class SimpleImageFitter:
    """Train random Gaussians to match a target image (MSE).

    The splat parameters (means, scales, quaternions, colour logits and
    opacity logits) are leaf tensors optimised jointly with Adam. Rendering
    is delegated to the gsplat rasterizer selected in :meth:`train`.
    """

    def __init__(self, gt_image: Tensor, num_points: int, device: torch.device) -> None:
        """Store the target, build the camera intrinsics and initialise the splats.

        Args:
            gt_image: Target image; its first two dimensions are read as
                height and width (assumed ``[H, W, 3]`` with values in
                ``[0, 1]`` — confirm against callers).
            num_points: Number of Gaussians to optimise.
            device: Device that holds the target and all parameters.
        """
        self.device = device
        self.gt_image = gt_image.to(device=device, dtype=torch.float32)
        self.num_points = num_points
        self.H, self.W = int(gt_image.shape[0]), int(gt_image.shape[1])

        # Pinhole intrinsics for a 90 degree horizontal field of view with the
        # principal point at the image centre.
        fov_x = math.pi / 2.0
        self.focal = 0.5 * float(self.W) / math.tan(0.5 * fov_x)
        self.K = torch.tensor(
            [
                [self.focal, 0.0, self.W / 2.0],
                [0.0, self.focal, self.H / 2.0],
                [0.0, 0.0, 1.0],
            ],
            device=self.device,
            dtype=torch.float32,
        )
        self._init_gaussians()

    def _init_gaussians(self) -> None:
        """Create the random initial splat parameters and the fixed view matrix."""
        bd = 2.0  # means are drawn uniformly from a cube of side ``bd`` centred at 0
        n = self.num_points
        dev = self.device

        self.means = bd * (torch.rand(n, 3, device=dev) - 0.5)
        self.scales = torch.rand(n, 3, device=dev)
        self.rgbs = torch.rand(n, 3, device=dev)

        # Random quaternions built from three uniform variates; each row has
        # unit norm by construction ((1 - u) + u == 1).
        u = torch.rand(n, 1, device=dev)
        v = torch.rand(n, 1, device=dev)
        w = torch.rand(n, 1, device=dev)
        self.quats = torch.cat(
            [
                torch.sqrt(1.0 - u) * torch.sin(2.0 * math.pi * v),
                torch.sqrt(1.0 - u) * torch.cos(2.0 * math.pi * v),
                torch.sqrt(u) * torch.sin(2.0 * math.pi * w),
                torch.sqrt(u) * torch.cos(2.0 * math.pi * w),
            ],
            dim=-1,
        )
        # Opacities are stored as logits; ``_rasterize`` applies the sigmoid.
        self.opacities = torch.ones(n, device=dev)

        # Identity rotation with a translation of 8 along z; handed to the
        # rasterizer as the (single) view matrix.
        self.viewmat = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 8.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            device=dev,
            dtype=torch.float32,
        )

        for param in (self.means, self.scales, self.quats, self.rgbs, self.opacities):
            param.requires_grad_(True)

    def _rasterize(
        self,
        rasterize_fnc,
        means: Tensor,
        quats: Tensor,
        scales: Tensor,
        opacities: Tensor,
        rgbs: Tensor,
    ) -> Tensor:
        """Render the given splats and return the image as ``[H, W, 3]``.

        Quaternions are normalised, and opacities / colours are passed through
        a sigmoid before being handed to ``rasterize_fnc``. Only the first
        element of the rasterizer's return value is used.
        """
        vm = self.viewmat[None]
        k = self.K[None]
        quats_n = quats / quats.norm(dim=-1, keepdim=True)
        op = torch.sigmoid(opacities)
        colors = torch.sigmoid(rgbs)
        # 2DGS packed=False: colors must be [..., C, N, D] to match projected means2d
        # (same leading dims as viewmats batch). Plain [N, D] hits assert in rasterize_to_pixels_2dgs.
        # ``getattr`` keeps callables without ``__name__`` (e.g. functools.partial)
        # from raising AttributeError here.
        is_2dgs = getattr(rasterize_fnc, "__name__", "") == "rasterization_2dgs"
        if is_2dgs and colors.dim() == 2:
            colors = colors.unsqueeze(0)
        out = rasterize_fnc(
            means,
            quats_n,
            scales,
            op,
            colors,
            vm,
            k,
            self.W,
            self.H,
            packed=False,
        )
        return _render_colors_to_hw3(out[0])

    def train(
        self,
        iterations: int,
        lr: float,
        model_type: Literal["3dgs", "2dgs"],
        progress: gr.Progress | None = None,
    ) -> tuple[np.ndarray, str]:
        """Optimise the splats with Adam and return a comparison image plus summary.

        Args:
            iterations: Number of optimisation steps.
            lr: Adam learning rate shared by all parameters.
            model_type: ``"3dgs"`` selects ``gsplat.rasterization``; any other
                value selects ``gsplat.rasterization_2dgs``.
            progress: Optional Gradio progress tracker wrapped around the loop.

        Returns:
            A uint8 image with the target on the left and the final render on
            the right, and a Markdown summary string. The reported loss is the
            MSE of the final render that is shown in the image.
        """
        from gsplat import rasterization, rasterization_2dgs  # pyright: ignore[reportMissingImports]

        rasterize_fnc = rasterization if model_type == "3dgs" else rasterization_2dgs
        optimizer = optim.Adam(
            [self.rgbs, self.means, self.scales, self.opacities, self.quats],
            lr=lr,
        )
        mse = torch.nn.MSELoss()
        t_rast = 0.0
        t_bwd = 0.0
        iterator = range(iterations)
        if progress is not None:
            iterator = progress.tqdm(iterator, desc="gsplat fit")

        for _ in iterator:
            t0 = time.perf_counter()
            out_img = self._rasterize(
                rasterize_fnc,
                self.means,
                self.quats,
                self.scales,
                self.opacities,
                self.rgbs,
            )
            # Wait for queued GPU work so the wall-clock timers are meaningful.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            t_rast += time.perf_counter() - t0

            loss = mse(out_img, self.gt_image)
            optimizer.zero_grad(set_to_none=True)
            t1 = time.perf_counter()
            loss.backward()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            t_bwd += time.perf_counter() - t1
            optimizer.step()

        with torch.no_grad():
            final = self._rasterize(
                rasterize_fnc,
                self.means,
                self.quats,
                self.scales,
                self.opacities,
                self.rgbs,
            )
            # Measure the loss on the render that is actually returned, i.e.
            # after the last optimizer step, so the number matches the image.
            final_loss = float(mse(final, self.gt_image).item())

        pred_u8 = (final.clamp(0.0, 1.0).cpu().numpy() * 255.0).astype(np.uint8)
        gt_u8 = (self.gt_image.clamp(0.0, 1.0).cpu().numpy() * 255.0).astype(np.uint8)
        # Crop both images to their common size before placing them side by side.
        h = min(pred_u8.shape[0], gt_u8.shape[0])
        w = min(pred_u8.shape[1], gt_u8.shape[1])
        combo = np.concatenate([gt_u8[:h, :w], pred_u8[:h, :w]], axis=1)
        msg = (
            f"**{model_type}** | **{self.W}x{self.H}** | **{self.num_points}** splats | "
            f"**{iterations}** iters | final MSE loss: **{final_loss:.6f}**\n\n"
            f"Time (total): raster **{t_rast:.2f}s**, backward **{t_bwd:.2f}s**\n\n"
            f"Left: target, right: fit."
        )
        return combo, msg
@GPU
def run_fit(
    image: Image.Image | None,
    max_side: int,
    num_points: int,
    iterations: int,
    lr: float,
    model_type: Literal["3dgs", "2dgs"],
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> tuple[np.ndarray, str]:
    """Gradio callback: fit splats to ``image`` (or the built-in target) on CUDA.

    The numeric inputs are clamped to fixed ranges before use:
    ``max_side`` to [32, 512], ``num_points`` to [256, 200000],
    ``iterations`` to [1, 5000] and ``lr`` to [1e-5, 0.5].

    Returns:
        The comparison image and Markdown summary from
        ``SimpleImageFitter.train``. When CUDA is unavailable or training
        raises, a gray placeholder image and an explanatory message
        (including the traceback on failure) are returned instead.
    """

    def clamp(low, value, high):
        # Same expression shape as a plain max/min nest, so results match it
        # for every input, including non-finite values.
        return max(low, min(high, value))

    if not torch.cuda.is_available():
        return _gray_preview(), "CUDA is not available (needs a GPU backend, e.g. ZeroGPU)."

    max_side = int(clamp(32, max_side, 512))
    num_points = int(clamp(256, num_points, 200_000))
    iterations = int(clamp(1, iterations, 5000))
    lr = float(clamp(1e-5, lr, 0.5))

    try:
        if image is not None:
            gt = _pil_to_gt_tensor(image, max_side)
        else:
            # No upload: use the synthetic square target, at most 256 px wide.
            side = min(max_side, 256)
            gt = _default_gt_tensor(side, side)
        fitter = SimpleImageFitter(gt, num_points=num_points, device=torch.device("cuda"))
        result = fitter.train(
            iterations=iterations,
            lr=lr,
            model_type=model_type,
            progress=progress,
        )
    except Exception:
        # UI boundary: surface the failure in the Markdown panel instead of raising.
        return _gray_preview(), f"Training failed:\n```\n{traceback.format_exc()}\n```"
    return result
def build_app() -> gr.Blocks:
    """Assemble the Gradio UI and wire the Train button to ``run_fit``.

    The page holds an intro text, an input image next to an output image,
    the training controls, a Train button and a Markdown status area. A
    click passes the six control values to ``run_fit`` and shows its two
    results in the output image and the status area.
    """
    intro = (
        "Fit **3DGS** or **2DGS** splats to an image with **Adam + gsplat rasterization** "
        "(same idea as `gsplat` `examples/image_fitting.py`). "
        "For **Hugging Face ZeroGPU**, use `app_file: app_gsplat.py` and keep "
        "`spaces` in `requirements.txt`."
    )
    with gr.Blocks(title="gsplat image fitting") as demo:
        gr.Markdown(intro)
        with gr.Row():
            target_in = gr.Image(label="Target image (optional)", type="pil", height=320)
            result_out = gr.Image(label="Target | prediction", interactive=False, height=320)
        # NOTE(review): which controls share this row was reconstructed from a
        # flattened source listing — confirm the intended grouping.
        with gr.Row():
            max_side = gr.Slider(64, 512, value=256, step=16, label="Max image side (px)")
            num_points = gr.Slider(1000, 100000, value=8000, step=500, label="Num Gaussians")
            iterations = gr.Slider(50, 2000, value=300, step=50, label="Iterations")
            lr = gr.Slider(0.001, 0.05, value=0.01, step=0.001, label="Learning rate")
            model_type = gr.Radio(["3dgs", "2dgs"], value="3dgs", label="Rasterizer")
        train_btn = gr.Button("Train", variant="primary")
        status_md = gr.Markdown("—")

        # Input order must match the positional parameters of ``run_fit``.
        fit_inputs = [target_in, max_side, num_points, iterations, lr, model_type]
        fit_outputs = [result_out, status_md]
        train_btn.click(run_fit, fit_inputs, fit_outputs)
    return demo
# Build the app at import time so ``demo`` exists both when this module is
# imported and when it is run as a script (the ``__main__`` block launches it).
demo = build_app()
# Turn on Gradio's request queue with ``max_size=4``.
demo.queue(max_size=4)
if __name__ == "__main__":
    import argparse

    # Defaults come from the environment: host from GRADIO_SERVER_NAME, port
    # from PORT, then GRADIO_SERVER_PORT, then DEFAULT_PORT.
    default_host = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    fallback_port = os.environ.get("GRADIO_SERVER_PORT", str(DEFAULT_PORT))
    default_port = int(os.environ.get("PORT", fallback_port))

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default=default_host)
    parser.add_argument("--port", type=int, default=default_port)
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()

    demo.launch(server_name=args.host, server_port=args.port, share=args.share)