# SPDX-FileCopyrightText: Copyright 2023-2026 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Gradio app: 3DGS / 2DGS image fitting with gsplat (Adam + rasterization). HF ZeroGPU: set ``app_file: app_gsplat.py`` in Space README front matter. Heavy work runs inside ``@spaces.GPU`` (training requires gradients — no inference_mode). """ from __future__ import annotations import math import os import time import traceback from typing import Literal import gradio as gr import numpy as np import torch from PIL import Image from torch import Tensor, optim if not os.environ.get("HF_TOKEN") and not os.environ.get("HUGGING_FACE_HUB_TOKEN"): _t = (os.environ.get("near") or os.environ.get("NEAR") or "").strip() if _t: os.environ["HF_TOKEN"] = _t try: import spaces # pyright: ignore[reportMissingImports] except ImportError: spaces = None GPU = spaces.GPU if spaces is not None else (lambda f: f) DEFAULT_PORT = 7863 def _render_colors_to_hw3(render_colors: Tensor) -> Tensor: """Flatten batch / camera dims to [H, W, C] and keep RGB.""" x = render_colors while x.dim() > 3: x = x[0] return x[..., :3] def _gray_preview() -> np.ndarray: return np.full((64, 128, 3), 48, dtype=np.uint8) def _pil_to_gt_tensor(img: Image.Image, max_side: int) -> Tensor: rgb = img.convert("RGB") w, h = rgb.size m = max(w, h) if m > max_side and m > 0: s = max_side / float(m) rgb = rgb.resize((max(1, int(w * s)), max(1, int(h * s))), Image.Resampling.LANCZOS) arr = np.asarray(rgb, dtype=np.float32) / 255.0 return torch.from_numpy(arr) def _default_gt_tensor(height: int, width: int) -> Tensor: gt = torch.ones((height, width, 3)) gt[: height // 2, : width // 2, :] = torch.tensor([1.0, 0.0, 0.0]) gt[height // 2 :, width // 2 :, :] = torch.tensor([0.0, 0.0, 1.0]) return gt class SimpleImageFitter: """Train random Gaussians to match a target image (MSE).""" def __init__(self, gt_image: Tensor, num_points: int, device: torch.device) -> None: self.device = device self.gt_image = gt_image.to(device=device, dtype=torch.float32) self.num_points = num_points self.H, self.W = int(gt_image.shape[0]), int(gt_image.shape[1]) fov_x = math.pi / 2.0 self.focal = 0.5 * float(self.W) / math.tan(0.5 * fov_x) self.K = torch.tensor( [ [self.focal, 0.0, self.W / 2.0], [0.0, self.focal, self.H / 2.0], [0.0, 0.0, 1.0], ], device=self.device, dtype=torch.float32, ) self._init_gaussians() def _init_gaussians(self) -> None: bd = 2.0 n = self.num_points dev = self.device self.means = bd * (torch.rand(n, 3, device=dev) - 0.5) self.scales = torch.rand(n, 3, device=dev) d = 3 self.rgbs = torch.rand(n, d, device=dev) u = torch.rand(n, 1, device=dev) v = torch.rand(n, 1, device=dev) w = torch.rand(n, 1, device=dev) self.quats = torch.cat( [ torch.sqrt(1.0 - u) * torch.sin(2.0 * math.pi * v), torch.sqrt(1.0 - u) * torch.cos(2.0 * math.pi * v), torch.sqrt(u) * torch.sin(2.0 * math.pi * w), torch.sqrt(u) * torch.cos(2.0 * math.pi * w), ], dim=-1, ) self.opacities = torch.ones(n, device=dev) self.viewmat = torch.tensor( [ [1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 8.0], [0.0, 0.0, 0.0, 1.0], ], device=dev, dtype=torch.float32, ) self.means.requires_grad_(True) self.scales.requires_grad_(True) self.quats.requires_grad_(True) self.rgbs.requires_grad_(True) self.opacities.requires_grad_(True) def _rasterize( self, rasterize_fnc, means: Tensor, quats: Tensor, scales: Tensor, opacities: Tensor, rgbs: Tensor, ) -> Tensor: vm = self.viewmat[None] k = self.K[None] quats_n = quats / quats.norm(dim=-1, keepdim=True) op = torch.sigmoid(opacities) colors = torch.sigmoid(rgbs) # 2DGS packed=False: colors must be [..., C, N, D] to match projected means2d # (same leading dims as viewmats batch). Plain [N, D] hits assert in rasterize_to_pixels_2dgs. if rasterize_fnc.__name__ == "rasterization_2dgs" and colors.dim() == 2: colors = colors.unsqueeze(0) out = rasterize_fnc( means, quats_n, scales, op, colors, vm, k, self.W, self.H, packed=False, ) rc = out[0] return _render_colors_to_hw3(rc) def train( self, iterations: int, lr: float, model_type: Literal["3dgs", "2dgs"], progress: gr.Progress | None = None, ) -> tuple[np.ndarray, str]: from gsplat import rasterization, rasterization_2dgs # pyright: ignore[reportMissingImports] rasterize_fnc = rasterization if model_type == "3dgs" else rasterization_2dgs optimizer = optim.Adam( [self.rgbs, self.means, self.scales, self.opacities, self.quats], lr=lr, ) mse = torch.nn.MSELoss() t_rast = 0.0 t_bwd = 0.0 iterator = range(iterations) if progress is not None: iterator = progress.tqdm(iterator, desc="gsplat fit") last_loss = 0.0 for it in iterator: t0 = time.perf_counter() out_img = self._rasterize( rasterize_fnc, self.means, self.quats, self.scales, self.opacities, self.rgbs, ) if torch.cuda.is_available(): torch.cuda.synchronize() t_rast += time.perf_counter() - t0 loss = mse(out_img, self.gt_image) last_loss = float(loss.item()) optimizer.zero_grad(set_to_none=True) t1 = time.perf_counter() loss.backward() if torch.cuda.is_available(): torch.cuda.synchronize() t_bwd += time.perf_counter() - t1 optimizer.step() with torch.no_grad(): final = self._rasterize( rasterize_fnc, self.means, self.quats, self.scales, self.opacities, self.rgbs, ) pred_u8 = (final.clamp(0.0, 1.0).cpu().numpy() * 255.0).astype(np.uint8) gt_u8 = (self.gt_image.clamp(0.0, 1.0).cpu().numpy() * 255.0).astype(np.uint8) h = min(pred_u8.shape[0], gt_u8.shape[0]) w = min(pred_u8.shape[1], gt_u8.shape[1]) combo = np.concatenate([gt_u8[:h, :w], pred_u8[:h, :w]], axis=1) msg = ( f"**{model_type}** | **{self.W}x{self.H}** | **{self.num_points}** splats | " f"**{iterations}** iters | final MSE loss: **{last_loss:.6f}**\n\n" f"Time (total): raster **{t_rast:.2f}s**, backward **{t_bwd:.2f}s**\n\n" f"Left: target, right: fit." ) return combo, msg @GPU def run_fit( image: Image.Image | None, max_side: int, num_points: int, iterations: int, lr: float, model_type: Literal["3dgs", "2dgs"], progress: gr.Progress = gr.Progress(track_tqdm=True), ) -> tuple[np.ndarray, str]: if not torch.cuda.is_available(): return _gray_preview(), "CUDA is not available (needs a GPU backend, e.g. ZeroGPU)." max_side = int(max(32, min(512, max_side))) num_points = int(max(256, min(200_000, num_points))) iterations = int(max(1, min(5000, iterations))) lr = float(max(1e-5, min(0.5, lr))) try: if image is None: s = min(max_side, 256) gt = _default_gt_tensor(s, s) else: gt = _pil_to_gt_tensor(image, max_side) dev = torch.device("cuda") fitter = SimpleImageFitter(gt, num_points=num_points, device=dev) return fitter.train( iterations=iterations, lr=lr, model_type=model_type, progress=progress, ) except Exception: return _gray_preview(), f"Training failed:\n```\n{traceback.format_exc()}\n```" def build_app() -> gr.Blocks: with gr.Blocks(title="gsplat image fitting") as demo: gr.Markdown( "Fit **3DGS** or **2DGS** splats to an image with **Adam + gsplat rasterization** " "(same idea as `gsplat` `examples/image_fitting.py`). " "For **Hugging Face ZeroGPU**, use `app_file: app_gsplat.py` and keep " "`spaces` in `requirements.txt`." ) with gr.Row(): inp = gr.Image(label="Target image (optional)", type="pil", height=320) out = gr.Image(label="Target | prediction", interactive=False, height=320) with gr.Row(): max_side = gr.Slider(64, 512, value=256, step=16, label="Max image side (px)") num_points = gr.Slider(1000, 100000, value=8000, step=500, label="Num Gaussians") iterations = gr.Slider(50, 2000, value=300, step=50, label="Iterations") lr = gr.Slider(0.001, 0.05, value=0.01, step=0.001, label="Learning rate") model_type = gr.Radio(["3dgs", "2dgs"], value="3dgs", label="Rasterizer") go = gr.Button("Train", variant="primary") md = gr.Markdown("—") go.click( run_fit, [inp, max_side, num_points, iterations, lr, model_type], [out, md], ) return demo demo = build_app() demo.queue(max_size=4) if __name__ == "__main__": import argparse p = argparse.ArgumentParser() p.add_argument("--host", default=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")) p.add_argument( "--port", type=int, default=int(os.environ.get("PORT", os.environ.get("GRADIO_SERVER_PORT", str(DEFAULT_PORT)))), ) p.add_argument("--share", action="store_true") a = p.parse_args() demo.launch(server_name=a.host, server_port=a.port, share=a.share)