import os
import subprocess
import argparse
import math
import time
import shutil
import cv2
import torch
import numpy as np
import base64
import io
import json
from datetime import datetime
from typing import Dict
from PIL import Image

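# nest_asyncio (optional) lets an already-running asyncio event loop be re-entered.
# init_lock below serializes one-time model initialization across request threads.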
import threading
try:
    import nest_asyncio
    nest_asyncio.apply()
except ImportError:
    pass


init_lock = threading.Lock()

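# Environment configuration: enable OpenEXR decoding in OpenCV (for the .exr HDRIs),
# use expandable CUDA allocator segments, select the FlashAttention-3 attention
# backend, and point the flex-GEMM autotuner at a persistent cache file.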
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'

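# Hugging Face Spaces GPU decorator, Gradio server and file payload types, and
# FastAPI helpers for HTML responses and static file serving.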
import spaces
from gradio import Server
from gradio.data_classes import FileData
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles

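# Pixal3D / trellis2 components: sparse tensors, the image-to-3D pipeline,
# environment-map based rendering, and the o_voxel post-processing package.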
from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Pixal3DImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils
import o_voxel

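# Temporary working directory for preprocessed images, rendered preview frames,
# packed latent states, and exported GLB files.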
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)

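# Render modes offered by the front end; "render_key" selects the matching frame
# set in the renderer output. STEPS is the number of frames rendered per mode.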
MODES = [
    {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
    {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
    {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
    {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
    {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
    {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
]
STEPS = 8

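# Cascade settings: resolution of the low-resolution pass and the token budget
# passed to pipeline.run() as max_num_tokens for the high-resolution pass.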
CASCADE_LR_RESOLUTION = 512
CASCADE_MAX_NUM_TOKENS = 49152

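# MoGe-2 camera-estimation defaults for in-the-wild input images.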
MOGE_MODEL_NAME = "Ruicheng/moge-2-vitl"
WILD_MESH_SCALE = 1.0
WILD_EXTEND_PIXEL = 0
WILD_IMAGE_RESOLUTION = 512

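# DINOv3 feature-extractor configurations for image conditioning. "ss" conditions
# the sparse-structure stage; the "shape_*" and "tex_1024" variants condition the
# structured-latent stages and optionally upsample features with NAF to the target
# grid size.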
IMAGE_COND_CONFIGS = {
    "ss": {
        "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
        "image_size": 512,
        "grid_resolution": 16,
    },
    "shape_512": {
        "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
        "image_size": 512,
        "grid_resolution": 32,
        "use_naf_upsample": True,
        "naf_target_size": 512,
    },
    "shape_1024": {
        "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
        "image_size": 1024,
        "grid_resolution": 64,
        "use_naf_upsample": True,
        "naf_target_size": 512,
    },
    "tex_1024": {
        "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
        "image_size": 1024,
        "grid_resolution": 64,
        "use_naf_upsample": True,
        "naf_target_size": 1024,
    },
}

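# Model construction helpers. The DINOv3 projection feature extractor and MoGe-2
# are imported lazily inside the helpers so the modules are only loaded when needed.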
def build_image_cond_model(config: dict):
    from trellis2.trainers.flow_matching.mixins.image_conditioned_proj import DinoV3ProjFeatureExtractor
    model = DinoV3ProjFeatureExtractor(**config)
    model.eval()
    return model


def load_moge_model(device="cuda", model_name=MOGE_MODEL_NAME):
    from moge.model.v2 import MoGeModel
    moge_model = MoGeModel.from_pretrained(model_name).to(device)
    moge_model.eval()
    return moge_model

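# Lazily-initialized global state shared across requests.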
pipeline = None
moge_model = None
envmap = None

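# init_models() is called from every API entry point; init_lock plus the early
# return make it a thread-safe, one-time initialization.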
def init_models():
    global pipeline, moge_model, envmap
    with init_lock:
        if pipeline is not None:
            return

        model_path = "TencentARC/Pixal3D-T"
        print(f"[Pipeline] Loading from {model_path}...")
        pipeline = Pixal3DImageTo3DPipeline.from_pretrained(model_path)

        print("[ImageCond] Building DinoV3ProjFeatureExtractor models...")
        pipeline.image_cond_model_ss = build_image_cond_model(IMAGE_COND_CONFIGS["ss"])
        pipeline.image_cond_model_shape_512 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_512"])
        pipeline.image_cond_model_shape_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_1024"])
        pipeline.image_cond_model_tex_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["tex_1024"])

        pipeline.low_vram = False
        pipeline.cuda()

        pipeline.image_cond_model_ss.cuda()
        pipeline.image_cond_model_shape_512.cuda()
        pipeline.image_cond_model_shape_1024.cuda()
        pipeline.image_cond_model_tex_1024.cuda()

        print("[NAF] Pre-loading NAF upsampler model...")
        for attr in ['image_cond_model_ss', 'image_cond_model_shape_512', 'image_cond_model_shape_1024', 'image_cond_model_tex_1024']:
            model = getattr(pipeline, attr, None)
            if model is not None and getattr(model, 'use_naf_upsample', False):
                model._load_naf()

        print("[MoGe-2] Loading model for camera estimation...")
        moge_model = load_moge_model(device="cuda")

        print("[EnvMap] Loading environment maps...")
        _base = os.path.dirname(os.path.abspath(__file__))
        envmap = {}
        for name in ('forest', 'sunset', 'courtyard'):
            hdr = cv2.imread(os.path.join(_base, f'assets/hdri/{name}.exr'), cv2.IMREAD_UNCHANGED)
            hdr = cv2.cvtColor(hdr, cv2.COLOR_BGR2RGB)
            envmap[name] = EnvMap(torch.tensor(hdr, dtype=torch.float32, device='cuda'))

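# Camera estimation. MoGe-2 predicts normalized intrinsics for the input image;
# the helpers below convert the horizontal field of view into a focal length in
# pixels and derive a camera distance for the projection-aligned renders.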
def compute_f_pixels(camera_angle_x: float, resolution: int) -> float:
    focal_length = 16.0 / torch.tan(torch.tensor(camera_angle_x / 2.0))
    f_pixels = focal_length * resolution / 32.0
    return float(f_pixels.item())


def distance_from_fov(camera_angle_x, grid_point, target_point, mesh_scale, image_resolution):
    rotation_matrix = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
    gp = grid_point.to(torch.float32) @ rotation_matrix.T
    gp = gp / mesh_scale / 2
    xw, yw, zw = gp[0].item(), gp[1].item(), gp[2].item()
    xt, yt = float(target_point[0].item()), float(target_point[1].item())
    f_pixels = compute_f_pixels(camera_angle_x, image_resolution)
    x_ndc = xt - image_resolution / 2.0
    y_ndc = -(yt - image_resolution / 2.0)
    distance_x = f_pixels * xw / x_ndc - yw
    return {"distance_from_x": float(distance_x), "f_pixels": float(f_pixels)}


def get_camera_params_wild_moge(image_path, device="cuda", mesh_scale=1.0, extend_pixel=0, image_resolution=512):
    pil_image = Image.open(image_path).convert("RGB")
    width, height = pil_image.size
    image_np = np.array(pil_image).astype(np.float32) / 255.0
    image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).to(device)
    with torch.no_grad():
        output = moge_model.infer(image_tensor)
    intrinsics = output["intrinsics"].squeeze().cpu().numpy()
    fx_normalized = intrinsics[0, 0]
    fx = fx_normalized * width
    camera_angle_x = 2 * math.atan(width / (2 * fx))

    grid_point = torch.tensor([-1.0, 0.0, 0.0])
    distance = distance_from_fov(
        camera_angle_x, grid_point,
        torch.tensor([0 - extend_pixel, image_resolution - 1 + extend_pixel]),
        mesh_scale, image_resolution
    )["distance_from_x"]
    return {'camera_angle_x': camera_angle_x, 'distance': distance, 'mesh_scale': mesh_scale}

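# Latent state is round-tripped through compressed .npz files in TMP_DIR so that
# generation and GLB extraction can run as separate GPU calls.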
def pack_state(shape_slat, tex_slat, res):
    state_data = {
        'shape_slat_feats': shape_slat.feats.cpu().numpy(),
        'tex_slat_feats': tex_slat.feats.cpu().numpy(),
        'coords': shape_slat.coords.cpu().numpy(),
        'res': res,
    }
    import random
    state_path = os.path.join(TMP_DIR, f"state_{int(time.time()*1000)}_{random.randint(0,9999):04d}.npz")
    np.savez_compressed(state_path, **state_data)
    return state_path


def unpack_state(state_path):
    data = np.load(state_path)
    shape_slat = SparseTensor(
        feats=torch.from_numpy(data['shape_slat_feats']).cuda(),
        coords=torch.from_numpy(data['coords']).cuda(),
    )
    tex_slat = shape_slat.replace(torch.from_numpy(data['tex_slat_feats']).cuda())
    return shape_slat, tex_slat, int(data['res'])

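# Server-sent-events (SSE) progress reporting. Worker threads push
# {stage, step, total} updates into a per-session queue; the /progress endpoint
# streams them to the browser.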
import asyncio
import queue
from fastapi.responses import StreamingResponse
from fastapi import Request


_progress_queues: Dict[str, queue.Queue] = {}
_thread_local = threading.local()

def _reset_progress(session_id: str):
    _thread_local.active_session = session_id
    if session_id not in _progress_queues:
        _progress_queues[session_id] = queue.Queue()

    # Drain any updates left over from a previous run in this session.
    q = _progress_queues[session_id]
    while not q.empty():
        try:
            q.get_nowait()
        except queue.Empty:
            break

def _update_progress(stage: str, step: int, total: int):
    data = {"stage": stage, "step": step, "total": total, "done": False}
    session_id = getattr(_thread_local, 'active_session', '')
    if session_id and session_id in _progress_queues:
        try:
            _progress_queues[session_id].put_nowait(data)
        except queue.Full:
            pass

def _finish_progress():
    session_id = getattr(_thread_local, 'active_session', '')
    if session_id and session_id in _progress_queues:
        try:
            _progress_queues[session_id].put_nowait({"done": True})
        except queue.Full:
            pass

        # Drop the session queue a few seconds after the final event so the SSE
        # stream has time to deliver it.
        def _cleanup():
            time.sleep(5)
            _progress_queues.pop(session_id, None)
        threading.Thread(target=_cleanup, daemon=True).start()

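# Route tqdm progress bars from the sampler, renderer, and post-processing modules
# into the SSE queues by substituting a tqdm subclass that reports every update().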
import tqdm as _tqdm_module

_original_tqdm = _tqdm_module.tqdm


class _TqdmProgressInterceptor(_original_tqdm):
    """Wraps tqdm to push progress updates to SSE."""

    def __init__(self, *args, **kwargs):
        self._stage_desc = kwargs.get('desc', 'Processing')
        super().__init__(*args, **kwargs)

    def set_description(self, desc=None, refresh=True):
        self._stage_desc = desc or 'Processing'
        super().set_description(desc, refresh)

    def update(self, n=1):
        super().update(n)
        _update_progress(self._stage_desc, self.n, self.total or 0)


_tqdm_module.tqdm = _TqdmProgressInterceptor

import trellis2.pipelines.samplers.flow_euler as _fe_module
_fe_module.tqdm = _TqdmProgressInterceptor
import trellis2.utils.render_utils as _ru_module
_ru_module.tqdm = _TqdmProgressInterceptor
import o_voxel.postprocess as _ovp_module
_ovp_module.tqdm = _TqdmProgressInterceptor

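# Web application: a static HTML front end, the SSE progress endpoint, and
# GPU-backed API endpoints for preprocessing, generation, and GLB extraction.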
app = Server()


@app.get("/")
async def homepage():
    html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(html_path, "r", encoding="utf-8") as f:
        return HTMLResponse(content=f.read())

| @app.get("/progress") |
| async def progress_sse(request: Request): |
| """SSE endpoint for real-time progress updates during generation.""" |
| session_id = request.query_params.get("session_id", "") |
| if session_id and session_id not in _progress_queues: |
| _progress_queues[session_id] = queue.Queue() |
| |
| async def event_stream(): |
| q = _progress_queues.get(session_id) |
| timeout_count = 0 |
| while True: |
| if q: |
| try: |
| data = q.get_nowait() |
| yield f"data: {json.dumps(data)}\n\n" |
| if data.get("done"): |
| break |
| timeout_count = 0 |
| except queue.Empty: |
| yield f": keepalive\n\n" |
| timeout_count += 1 |
| else: |
| yield f": keepalive\n\n" |
| timeout_count += 1 |
| |
| if timeout_count > 1000: |
| break |
| await asyncio.sleep(0.3) |
| return StreamingResponse(event_stream(), media_type="text/event-stream") |
|
|
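# GPU-backed API endpoints. Each one calls init_models(), which is a no-op once
# the models have been loaded.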
@app.api()
@spaces.GPU(duration=30)
def preprocess(image: FileData) -> FileData:
    init_models()
    img = Image.open(image["path"])
    processed = pipeline.preprocess_image(img)
    out_path = os.path.join(TMP_DIR, f"preprocessed_{int(time.time()*1000)}.png")
    processed.save(out_path)
    return FileData(path=out_path)

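# Main generation endpoint: estimate the camera with MoGe-2, run the cascaded
# sparse-structure / shape-latent / texture-latent sampling, render preview frames,
# and return the frame files plus a packed latent state for later GLB extraction.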
@app.api()
@spaces.GPU(duration=120)
def generate_3d(
    image: FileData,
    seed: int,
    resolution: int,
    ss_guidance_strength: float = 7.5,
    ss_guidance_rescale: float = 0.7,
    ss_sampling_steps: int = 12,
    ss_rescale_t: float = 5.0,
    shape_slat_guidance_strength: float = 7.5,
    shape_slat_guidance_rescale: float = 0.5,
    shape_slat_sampling_steps: int = 12,
    shape_slat_rescale_t: float = 3.0,
    tex_slat_guidance_strength: float = 1.0,
    tex_slat_guidance_rescale: float = 0.0,
    tex_slat_sampling_steps: int = 12,
    tex_slat_rescale_t: float = 3.0,
    session_id: str = "",
) -> Dict:
    init_models()
    _reset_progress(session_id)
    _update_progress("Preprocessing & Camera Estimation", 0, 1)

    torch.manual_seed(seed)
    hr_resolution = int(resolution)

    # The input is expected to be already preprocessed (see /preprocess); it is
    # saved to disk here only for camera estimation.
    img = Image.open(image["path"])

    image_preprocessed = img
    temp_processed_path = os.path.join(TMP_DIR, f"temp_proc_{session_id[:8]}_{int(time.time()*1000)}.png")
    image_preprocessed.save(temp_processed_path)

    camera_params = get_camera_params_wild_moge(
        temp_processed_path, device="cuda",
        mesh_scale=WILD_MESH_SCALE, extend_pixel=WILD_EXTEND_PIXEL,
        image_resolution=WILD_IMAGE_RESOLUTION,
    )
    _update_progress("Preprocessing & Camera Estimation", 1, 1)

    # Per-stage sampler overrides built from the request parameters.
    ss_sampler_override = {"steps": ss_sampling_steps, "guidance_strength": ss_guidance_strength,
                           "guidance_rescale": ss_guidance_rescale, "rescale_t": ss_rescale_t}
    shape_sampler_override = {"steps": shape_slat_sampling_steps, "guidance_strength": shape_slat_guidance_strength,
                              "guidance_rescale": shape_slat_guidance_rescale, "rescale_t": shape_slat_rescale_t}
    tex_sampler_override = {"steps": tex_slat_sampling_steps, "guidance_strength": tex_slat_guidance_strength,
                            "guidance_rescale": tex_slat_guidance_rescale, "rescale_t": tex_slat_rescale_t}

    pipeline_type = f"{hr_resolution}_cascade"
    mesh_list, (shape_slat, tex_slat, res) = pipeline.run(
        image_preprocessed,
        camera_params=camera_params,
        seed=seed,
        sparse_structure_sampler_params=ss_sampler_override,
        shape_slat_sampler_params=shape_sampler_override,
        tex_slat_sampler_params=tex_sampler_override,
        preprocess_image=False,
        return_latent=True,
        pipeline_type=pipeline_type,
        max_num_tokens=CASCADE_MAX_NUM_TOKENS,
    )

    mesh = mesh_list[0]
    state_path = pack_state(shape_slat, tex_slat, res)

    _update_progress("Rendering views", 0, 1)
    mesh.simplify(16777216)
    cam_dist = camera_params['distance']
    near = max(0.01, cam_dist - 2.0)
    far = cam_dist + 10.0
    renders = render_utils.render_proj_aligned_video(
        mesh, camera_angle_x=camera_params['camera_angle_x'],
        distance=cam_dist, resolution=1024,
        num_frames=STEPS, envmap=envmap,
        near=near, far=far,
    )
    _update_progress("Rendering views", 1, 1)

    # Save each rendered frame as a JPEG and group the files by render mode.
    render_files = {}
    for mode_key, frames in renders.items():
        mode_files = []
        for i, frame in enumerate(frames):
            p = os.path.abspath(os.path.join(TMP_DIR, f"render_{mode_key}_{i}_{int(time.time()*1000)}.jpg"))
            Image.fromarray(frame).save(p, quality=85)
            mode_files.append(FileData(path=p))
        render_files[mode_key] = mode_files

    _finish_progress()
    return {
        "render_paths": render_files,
        "state_path": os.path.abspath(state_path)
    }

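# GLB extraction endpoint: decode the packed latents into a mesh, bake textures
# with o_voxel.postprocess.to_glb, then export a GLB with WebP textures.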
@app.api()
@spaces.GPU(duration=240)
def extract_glb_api(state_path: str, decimation_target: int, texture_size: int, session_id: str = "") -> FileData:
    init_models()
    _reset_progress(session_id)
    _update_progress("Decoding latent", 0, 1)

    shape_slat, tex_slat, res = unpack_state(state_path)
    mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
    _update_progress("Decoding latent", 1, 1)

    glb = o_voxel.postprocess.to_glb(
        vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs,
        coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout,
        grid_size=res, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=decimation_target, texture_size=texture_size,
        remesh=True, remesh_band=1, remesh_project=0, use_tqdm=True,
    )
    # Reorient the mesh from the pipeline's coordinate frame before export.
    rot = np.array([
        [-1, 0, 0, 0],
        [ 0, 0, -1, 0],
        [ 0, -1, 0, 0],
        [ 0, 0, 0, 1],
    ], dtype=np.float64)
    glb.apply_transform(rot)

    out_glb = os.path.join(TMP_DIR, f"result_{int(time.time()*1000)}.glb")
    glb.export(out_glb, extension_webp=True)
    _finish_progress()
    return FileData(path=out_glb)

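# Serve the front-end assets and the generated files in TMP_DIR.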
| app.mount("/assets", StaticFiles(directory="assets"), name="assets") |
| app.mount("/tmp", StaticFiles(directory=TMP_DIR), name="tmp") |
|
|
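# Startup: force-install the pinned utils3d wheel (without touching its
# dependencies), warm up the models, then launch the server.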
if __name__ == "__main__":
    subprocess.run([
        "pip", "install", "--force-reinstall", "--no-deps",
        "https://github.com/LDYang694/Storages/releases/download/20260430/utils3d-0.0.2-py3-none-any.whl"
    ], check=True)

    init_models()

    app.launch(show_error=True, share=True)