| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Gradio app: 3DGS / 2DGS image fitting with gsplat (Adam + rasterization). |
| |
| HF ZeroGPU: set ``app_file: app_gsplat.py`` in Space README front matter. |
| Heavy work runs inside ``@spaces.GPU`` (training requires gradients — no inference_mode). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| import os |
| import time |
| import traceback |
| from typing import Literal |
|
|
| import gradio as gr |
| import numpy as np |
| import torch |
| from PIL import Image |
| from torch import Tensor, optim |
|
|
# When neither Hub token variable is set, promote a token found in the
# ``near`` (or ``NEAR``) environment variable to ``HF_TOKEN``.
if not (os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")):
    _t = os.environ.get("near") or os.environ.get("NEAR") or ""
    _t = _t.strip()
    if _t:
        os.environ["HF_TOKEN"] = _t
|
|
try:
    import spaces  # Hugging Face ZeroGPU helper; absent when running locally
except ImportError:
    spaces = None

# ``@GPU`` is ``spaces.GPU`` when available, otherwise a no-op decorator.
if spaces is None:
    GPU = lambda f: f  # noqa: E731
else:
    GPU = spaces.GPU

# Port used when neither PORT nor GRADIO_SERVER_PORT is set.
DEFAULT_PORT = 7863
|
|
|
|
def _render_colors_to_hw3(render_colors: Tensor) -> Tensor:
    """Reduce a rendered-colour tensor to a single ``[H, W, C]`` image, RGB only.

    Every dimension in front of the last three (e.g. batch / camera axes) is
    indexed at 0, so the first image is selected and nothing is merged. Only the
    first three entries of the last dimension are kept.
    """
    extra_dims = render_colors.dim() - 3
    if extra_dims > 0:
        render_colors = render_colors[(0,) * extra_dims]
    return render_colors[..., :3]
|
|
|
|
def _gray_preview() -> np.ndarray:
    """Return a 64x128 dark-gray RGB placeholder image (uint8, every value 48)."""
    placeholder = np.empty((64, 128, 3), dtype=np.uint8)
    placeholder.fill(48)
    return placeholder
|
|
|
|
def _fit_within(width: int, height: int, max_side: int) -> tuple[int, int]:
    """Return ``(width, height)`` shrunk so the longer side equals ``max_side``.

    Sizes whose longer side is already ``<= max_side`` (or that are empty) are
    returned unchanged. Otherwise the aspect ratio is preserved, the longer side
    lands exactly on ``max_side``, and the shorter side is rounded (not
    truncated) and kept at least 1.
    """
    longest = max(width, height)
    if longest <= max_side or longest <= 0:
        return width, height
    # Multiply before dividing: ``longest * max_side / longest`` is exactly
    # ``max_side``, while ``longest * (max_side / longest)`` can land just below
    # it in floating point (e.g. ``49 * (1 / 49) == 0.9999999999999999``) and
    # then truncate to ``max_side - 1``.
    return (
        max(1, round(width * max_side / longest)),
        max(1, round(height * max_side / longest)),
    )


def _pil_to_gt_tensor(img: Image.Image, max_side: int) -> Tensor:
    """Convert a PIL image into a float32 ``[H, W, 3]`` target tensor in [0, 1].

    The image is converted to RGB, so any alpha channel is dropped. If its longer
    side exceeds ``max_side``, it is downscaled with Lanczos resampling so that
    side becomes ``max_side`` (aspect ratio preserved).

    Args:
        img: Source image in any PIL mode.
        max_side: Upper bound on the longer side, in pixels.

    Returns:
        CPU tensor of shape ``[H, W, 3]`` with values ``pixel / 255``.
    """
    rgb = img.convert("RGB")
    target_size = _fit_within(rgb.size[0], rgb.size[1], max_side)
    if target_size != rgb.size:
        rgb = rgb.resize(target_size, Image.Resampling.LANCZOS)
    arr = np.asarray(rgb, dtype=np.float32) / 255.0
    return torch.from_numpy(arr)
|
|
|
|
def _default_gt_tensor(height: int, width: int) -> Tensor:
    """Build the fallback target: a white ``[height, width, 3]`` float image.

    The top-left quadrant is painted red and the bottom-right quadrant blue.
    For odd sizes, the middle row/column belongs to the bottom/right halves.
    """
    mid_h, mid_w = height // 2, width // 2
    canvas = torch.ones(height, width, 3)
    canvas[:mid_h, :mid_w] = torch.tensor([1.0, 0.0, 0.0])  # red
    canvas[mid_h:, mid_w:] = torch.tensor([0.0, 0.0, 1.0])  # blue
    return canvas
|
|
|
|
class SimpleImageFitter:
    """Train random Gaussians to match a target image (MSE).

    A single fixed pinhole camera renders Gaussians initialised inside the cube
    [-1, 1]^3. The camera has a 90 degree horizontal field of view, an identity
    rotation, and a translation of 8 along z. Adam then optimises every Gaussian
    parameter so the rasterized image matches ``gt_image``.
    """

    def __init__(self, gt_image: Tensor, num_points: int, device: torch.device) -> None:
        """Store the target, build the camera intrinsics and the random Gaussians.

        Args:
            gt_image: Target image indexed ``[H, W, C]`` (height is dim 0, width
                dim 1). It is moved to ``device`` as float32. Presumably RGB in
                [0, 1] to match the sigmoid-activated render; confirm for new
                callers.
            num_points: Number of Gaussians to optimise.
            device: Device on which every tensor is allocated.
        """
        self.device = device
        self.gt_image = gt_image.to(device=device, dtype=torch.float32)
        self.num_points = num_points
        self.H, self.W = int(gt_image.shape[0]), int(gt_image.shape[1])
        # Focal length in pixels for a 90 degree horizontal field of view.
        fov_x = math.pi / 2.0
        self.focal = 0.5 * float(self.W) / math.tan(0.5 * fov_x)
        # Pinhole intrinsics: square pixels, principal point at the image centre.
        self.K = torch.tensor(
            [
                [self.focal, 0.0, self.W / 2.0],
                [0.0, self.focal, self.H / 2.0],
                [0.0, 0.0, 1.0],
            ],
            device=self.device,
            dtype=torch.float32,
        )
        self._init_gaussians()

    def _init_gaussians(self) -> None:
        """Create the trainable Gaussian parameters and the fixed view matrix.

        All five parameter tensors are leaf tensors with ``requires_grad=True``,
        so :meth:`train` can hand them directly to Adam.
        """
        bd = 2.0  # edge length of the cube the means are sampled in
        n = self.num_points
        dev = self.device

        # Means uniform in [-bd/2, bd/2)^3. Scales are used as-is (no activation).
        self.means = bd * (torch.rand(n, 3, device=dev) - 0.5)
        self.scales = torch.rand(n, 3, device=dev)
        d = 3
        # Colour logits; _rasterize applies a sigmoid.
        self.rgbs = torch.rand(n, d, device=dev)

        # Uniformly distributed random unit quaternions (Shoemake's method).
        u = torch.rand(n, 1, device=dev)
        v = torch.rand(n, 1, device=dev)
        w = torch.rand(n, 1, device=dev)
        self.quats = torch.cat(
            [
                torch.sqrt(1.0 - u) * torch.sin(2.0 * math.pi * v),
                torch.sqrt(1.0 - u) * torch.cos(2.0 * math.pi * v),
                torch.sqrt(u) * torch.sin(2.0 * math.pi * w),
                torch.sqrt(u) * torch.cos(2.0 * math.pi * w),
            ],
            dim=-1,
        )
        # Opacity logits; sigmoid(1) is roughly 0.73 at the start.
        self.opacities = torch.ones(n, device=dev)

        # World-to-camera transform: identity rotation, translation +8 along z.
        self.viewmat = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 8.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            device=dev,
            dtype=torch.float32,
        )

        for param in (self.means, self.scales, self.quats, self.rgbs, self.opacities):
            param.requires_grad_(True)

    def _rasterize(
        self,
        rasterize_fnc,
        means: Tensor,
        quats: Tensor,
        scales: Tensor,
        opacities: Tensor,
        rgbs: Tensor,
    ) -> Tensor:
        """Render the Gaussians from the fixed camera.

        Args:
            rasterize_fnc: ``gsplat.rasterization`` or ``gsplat.rasterization_2dgs``.
            means, quats, scales: Gaussian geometry; ``quats`` are normalised here.
            opacities, rgbs: Logits; a sigmoid is applied here.

        Returns:
            The first output of the rasterizer, reduced by
            ``_render_colors_to_hw3`` (expected ``[H, W, 3]``).
        """
        vm = self.viewmat[None]  # leading camera axis: [1, 4, 4]
        k = self.K[None]  # [1, 3, 3]
        quats_n = quats / quats.norm(dim=-1, keepdim=True)
        op = torch.sigmoid(opacities)
        colors = torch.sigmoid(rgbs)
        # NOTE(review): the 2DGS rasterizer is given colours with an explicit
        # leading camera axis ([1, N, D]); presumably needed by the installed
        # gsplat version. Confirm before removing.
        is_2dgs = getattr(rasterize_fnc, "__name__", "") == "rasterization_2dgs"
        if is_2dgs and colors.dim() == 2:
            colors = colors.unsqueeze(0)
        out = rasterize_fnc(
            means,
            quats_n,
            scales,
            op,
            colors,
            vm,
            k,
            self.W,
            self.H,
            packed=False,
        )
        return _render_colors_to_hw3(out[0])

    @staticmethod
    def _side_by_side(gt: Tensor, pred: Tensor) -> np.ndarray:
        """Concatenate ``gt`` (left) and ``pred`` (right) into one uint8 image.

        Both images are clamped to [0, 1], scaled by 255 and truncated to uint8.
        They are then cropped to their common height and width before being
        joined along the width axis.
        """

        def to_u8(img: Tensor) -> np.ndarray:
            return (img.detach().clamp(0.0, 1.0).cpu().numpy() * 255.0).astype(np.uint8)

        gt_u8, pred_u8 = to_u8(gt), to_u8(pred)
        h = min(gt_u8.shape[0], pred_u8.shape[0])
        w = min(gt_u8.shape[1], pred_u8.shape[1])
        return np.concatenate([gt_u8[:h, :w], pred_u8[:h, :w]], axis=1)

    def train(
        self,
        iterations: int,
        lr: float,
        model_type: Literal["3dgs", "2dgs"],
        progress: gr.Progress | None = None,
    ) -> tuple[np.ndarray, str]:
        """Optimise the Gaussians with Adam for ``iterations`` steps.

        Args:
            iterations: Number of optimisation steps.
            lr: Adam learning rate, shared by all parameters.
            model_type: ``"3dgs"`` selects ``gsplat.rasterization``; any other
                value selects ``gsplat.rasterization_2dgs``.
            progress: Optional Gradio progress tracker that wraps the loop.

        Returns:
            ``(image, markdown)``:
                - ``image`` is a uint8 picture with the target on the left and
                  the final render on the right.
                - ``markdown`` is a summary whose MSE is measured on that final
                  render.

        Raises:
            ImportError: If ``gsplat`` cannot be imported.
        """
        from gsplat import rasterization, rasterization_2dgs

        rasterize_fnc = rasterization if model_type == "3dgs" else rasterization_2dgs
        # Argument order expected by _rasterize.
        params = (self.means, self.quats, self.scales, self.opacities, self.rgbs)

        optimizer = optim.Adam(
            [self.rgbs, self.means, self.scales, self.opacities, self.quats],
            lr=lr,
        )
        mse = torch.nn.MSELoss()
        # CUDA kernels run asynchronously; synchronise so the timings cover them.
        use_cuda = torch.cuda.is_available()
        t_rast = 0.0
        t_bwd = 0.0

        iterator = range(iterations)
        if progress is not None:
            iterator = progress.tqdm(iterator, desc="gsplat fit")

        for _ in iterator:
            t0 = time.perf_counter()
            out_img = self._rasterize(rasterize_fnc, *params)
            if use_cuda:
                torch.cuda.synchronize()
            t_rast += time.perf_counter() - t0

            loss = mse(out_img, self.gt_image)
            optimizer.zero_grad(set_to_none=True)
            t1 = time.perf_counter()
            loss.backward()
            if use_cuda:
                torch.cuda.synchronize()
            t_bwd += time.perf_counter() - t1
            optimizer.step()

        with torch.no_grad():
            final = self._rasterize(rasterize_fnc, *params)
            # Score the image that is actually shown (after the last optimizer
            # step), not the previous iteration's pre-step loss.
            final_loss = mse(final, self.gt_image).item()

        combo = self._side_by_side(self.gt_image, final)

        msg = (
            f"**{model_type}** | **{self.W}x{self.H}** | **{self.num_points}** splats | "
            f"**{iterations}** iters | final MSE loss: **{final_loss:.6f}**\n\n"
            f"Time (total): raster **{t_rast:.2f}s**, backward **{t_bwd:.2f}s**\n\n"
            f"Left: target, right: fit."
        )
        return combo, msg
|
|
|
|
@GPU
def run_fit(
    image: Image.Image | None,
    max_side: int,
    num_points: int,
    iterations: int,
    lr: float,
    model_type: Literal["3dgs", "2dgs"],
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> tuple[np.ndarray, str]:
    """Gradio handler: fit Gaussians to ``image`` and return ``(preview, markdown)``.

    Without CUDA, a gray placeholder and an explanatory message are returned.
    Numeric inputs are clamped to safe ranges first:
    - ``max_side`` to [32, 512]
    - ``num_points`` to [256, 200000]
    - ``iterations`` to [1, 5000]
    - ``lr`` to [1e-5, 0.5]

    With no image, a synthetic square target of side ``min(max_side, 256)`` is
    used. Any exception raised during training is caught, and its traceback is
    returned as Markdown next to the gray placeholder.
    """
    if not torch.cuda.is_available():
        return _gray_preview(), "CUDA is not available (needs a GPU backend, e.g. ZeroGPU)."

    def _clamp(value, low, high):
        return max(low, min(high, value))

    max_side = int(_clamp(max_side, 32, 512))
    num_points = int(_clamp(num_points, 256, 200_000))
    iterations = int(_clamp(iterations, 1, 5000))
    lr = float(_clamp(lr, 1e-5, 0.5))

    try:
        if image is not None:
            gt = _pil_to_gt_tensor(image, max_side)
        else:
            side = min(max_side, 256)
            gt = _default_gt_tensor(side, side)

        fitter = SimpleImageFitter(gt, num_points=num_points, device=torch.device("cuda"))
        return fitter.train(
            iterations=iterations,
            lr=lr,
            model_type=model_type,
            progress=progress,
        )
    except Exception:
        return _gray_preview(), f"Training failed:\n```\n{traceback.format_exc()}\n```"
|
|
|
|
def build_app() -> gr.Blocks:
    """Assemble the Gradio UI and wire the Train button to ``run_fit``."""
    intro = (
        "Fit **3DGS** or **2DGS** splats to an image with **Adam + gsplat rasterization** "
        "(same idea as `gsplat` `examples/image_fitting.py`). "
        "For **Hugging Face ZeroGPU**, use `app_file: app_gsplat.py` and keep "
        "`spaces` in `requirements.txt`."
    )
    with gr.Blocks(title="gsplat image fitting") as app:
        gr.Markdown(intro)

        # Input image next to the combined target | prediction output.
        with gr.Row():
            target_img = gr.Image(label="Target image (optional)", type="pil", height=320)
            result_img = gr.Image(label="Target | prediction", interactive=False, height=320)

        # Training hyper-parameters.
        with gr.Row():
            side_slider = gr.Slider(64, 512, value=256, step=16, label="Max image side (px)")
            points_slider = gr.Slider(1000, 100000, value=8000, step=500, label="Num Gaussians")
            iters_slider = gr.Slider(50, 2000, value=300, step=50, label="Iterations")
            lr_slider = gr.Slider(0.001, 0.05, value=0.01, step=0.001, label="Learning rate")
            rasterizer = gr.Radio(["3dgs", "2dgs"], value="3dgs", label="Rasterizer")

        train_btn = gr.Button("Train", variant="primary")
        status = gr.Markdown("—")

        train_btn.click(
            fn=run_fit,
            inputs=[target_img, side_slider, points_slider, iters_slider, lr_slider, rasterizer],
            outputs=[result_img, status],
        )
    return app
|
|
|
|
# Build the UI at import time and enable Gradio's request queue (max_size=4).
demo = build_app().queue(max_size=4)
|
|
if __name__ == "__main__":
    import argparse

    def _env_first(*names: str, default: str) -> str:
        """Return the first non-empty (after stripping) env var among ``names``, else ``default``.

        Empty values are treated as unset, so e.g. ``PORT=""`` falls through to
        the next source instead of crashing ``int("")``.
        """
        for name in names:
            value = os.environ.get(name, "").strip()
            if value:
                return value
        return default

    p = argparse.ArgumentParser(description="Run the gsplat image-fitting Gradio app.")
    p.add_argument("--host", default=_env_first("GRADIO_SERVER_NAME", default="0.0.0.0"))
    # String default: argparse converts it with ``type=int`` only when --port is
    # not given, and reports a bad value as a normal usage error.
    p.add_argument(
        "--port",
        type=int,
        default=_env_first("PORT", "GRADIO_SERVER_PORT", default=str(DEFAULT_PORT)),
    )
    p.add_argument("--share", action="store_true")
    a = p.parse_args()
    demo.launch(server_name=a.host, server_port=a.port, share=a.share)
|
|