Spaces:
Running
Running
| import os | |
| import subprocess | |
| import argparse | |
| import math | |
| import time | |
| import shutil | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| import base64 | |
| import io | |
| import json | |
| from datetime import datetime | |
| from typing import * | |
| from PIL import Image | |
| import threading | |
# Optionally allow re-entrant asyncio event loops (nest_asyncio patches
# asyncio so run_until_complete can be nested). Skipped silently when the
# package is not installed.
try:
    import nest_asyncio
    nest_asyncio.apply()
except ImportError:
    pass
# Lock for model initialization: init_models() takes it so concurrent callers
# load the models only once.
init_lock = threading.Lock()
# The environment variables below are set before the `spaces`, `gradio` and
# `trellis2`/`o_voxel` imports that follow. Presumably those packages read
# them at import time, so this ordering matters — do not move these lines
# below those imports.
# OpenCV refuses to decode .exr files unless this flag is set; the HDRI
# environment maps loaded in init_models() are .exr files.
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
# PyTorch caching-allocator option (reduces fragmentation for varying sizes).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# NOTE(review): presumably selects trellis2's attention implementation — confirm.
os.environ["ATTN_BACKEND"] = "flash_attn_3"
# Autotuning cache for the flex-gemm kernels, stored next to this script.
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
| import spaces | |
| from gradio import Server | |
| from gradio.data_classes import FileData | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from trellis2.modules.sparse import SparseTensor | |
| from trellis2.pipelines import Pixal3DImageTo3DPipeline | |
| from trellis2.renderers import EnvMap | |
| from trellis2.utils import render_utils | |
| import o_voxel | |
# ============================================================================
# Constants & Defaults
# ============================================================================
# Upper bound for user-supplied seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max

# Scratch/output directory next to this script; created eagerly at import.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)

# Viewer modes as (display name, icon path, key into the render dictionary).
_MODE_TABLE = (
    ("Normal", "assets/app/normal.png", "normal"),
    ("Clay render", "assets/app/clay.png", "clay"),
    ("Base color", "assets/app/basecolor.png", "base_color"),
    ("HDRI forest", "assets/app/hdri_forest.png", "shaded_forest"),
    ("HDRI sunset", "assets/app/hdri_sunset.png", "shaded_sunset"),
    ("HDRI courtyard", "assets/app/hdri_courtyard.png", "shaded_courtyard"),
)
MODES = [
    {"name": label, "icon": icon_path, "render_key": render_key}
    for label, icon_path, render_key in _MODE_TABLE
]

# Number of frames rendered for each preview (used as num_frames).
STEPS = 8

# Cascade parameters
CASCADE_LR_RESOLUTION = 512
CASCADE_MAX_NUM_TOKENS = 49152

# MoGe defaults
MOGE_MODEL_NAME = "Ruicheng/moge-2-vitl"
WILD_MESH_SCALE = 1.0
WILD_EXTEND_PIXEL = 0
WILD_IMAGE_RESOLUTION = 512

# Image Cond Model configs: constructor kwargs for DinoV3ProjFeatureExtractor,
# one entry per pipeline stage.
_DINOV3_MODEL = "camenduru/dinov3-vitl16-pretrain-lvd1689m"
IMAGE_COND_CONFIGS = {
    "ss": dict(
        model_name=_DINOV3_MODEL,
        image_size=512,
        grid_resolution=16,
    ),
    "shape_512": dict(
        model_name=_DINOV3_MODEL,
        image_size=512,
        grid_resolution=32,
        use_naf_upsample=True,
        naf_target_size=512,
    ),
    "shape_1024": dict(
        model_name=_DINOV3_MODEL,
        image_size=1024,
        grid_resolution=64,
        use_naf_upsample=True,
        naf_target_size=512,
    ),
    "tex_1024": dict(
        model_name=_DINOV3_MODEL,
        image_size=1024,
        grid_resolution=64,
        use_naf_upsample=True,
        naf_target_size=1024,
    ),
}
# ============================================================================
# Model Loading
# ============================================================================
def build_image_cond_model(config: dict):
    """Create a DinoV3ProjFeatureExtractor from *config* and switch it to eval mode.

    Every key of *config* is forwarded unchanged as a constructor keyword
    argument (the entries of IMAGE_COND_CONFIGS have this shape).
    """
    # Imported lazily so the trainer module is only loaded when a model is built.
    from trellis2.trainers.flow_matching.mixins.image_conditioned_proj import (
        DinoV3ProjFeatureExtractor,
    )

    extractor = DinoV3ProjFeatureExtractor(**config)
    extractor.eval()
    return extractor
def load_moge_model(device="cuda", model_name=MOGE_MODEL_NAME):
    """Load a MoGe-2 model named *model_name*, place it on *device*, set eval mode."""
    from moge.model.v2 import MoGeModel

    model = MoGeModel.from_pretrained(model_name)
    model = model.to(device)
    model.eval()
    return model
# Global model handles. They start as None and are populated by init_models(),
# either lazily on the first request or eagerly at startup.
pipeline = moge_model = envmap = None
def _print_gpu_diagnostics():
    """Print PyTorch/CUDA/GPU details to stdout; an nvidia-smi failure is only reported."""
    print("=" * 60)
    print("[Diagnostics] PyTorch version:", torch.__version__)
    print("[Diagnostics] CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("[Diagnostics] CUDA version:", torch.version.cuda)
        print("[Diagnostics] cuDNN version:", torch.backends.cudnn.version())
        for i in range(torch.cuda.device_count()):
            name = torch.cuda.get_device_name(i)
            cap = torch.cuda.get_device_capability(i)
            mem = torch.cuda.get_device_properties(i).total_memory / 1024**3
            print(f"[Diagnostics] GPU {i}: {name}, sm_{cap[0]}{cap[1]}, {mem:.1f} GB")
    try:
        res = subprocess.run(
            ["nvidia-smi", "--query-gpu=name,compute_cap,memory.total", "--format=csv,noheader"],
            capture_output=True, text=True, timeout=10,
        )
        print("[Diagnostics] nvidia-smi:", res.stdout.strip())
    except Exception as e:
        # Diagnostics only: a missing/hanging nvidia-smi must not block startup.
        print(f"[Diagnostics] nvidia-smi failed: {e}")
    print("=" * 60)


def _load_envmap(path):
    """Read an HDR image at *path* with OpenCV and wrap it as an EnvMap on CUDA.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (cv2.imread returns
            None instead of raising).
    """
    bgr = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if bgr is None:
        raise FileNotFoundError(f"Could not read environment map: {path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return EnvMap(torch.tensor(rgb, dtype=torch.float32, device='cuda'))


def init_models():
    """Load the Pixal3D pipeline, MoGe-2 and the HDRI env maps into module globals.

    Thread-safe via ``init_lock`` and a no-op once loading has succeeded.
    Everything is built into locals first and the three globals are assigned
    together at the very end. A failure part-way (download error, missing EXR,
    CUDA OOM) therefore leaves all of them None, and the next call retries.
    Assigning ``pipeline`` early would instead make later calls return with a
    half-initialized state.
    """
    global pipeline, moge_model, envmap
    with init_lock:
        if pipeline is not None:
            return

        # GPU / CUDA Diagnostics (runs when GPU is allocated)
        _print_gpu_diagnostics()

        model_path = "TencentARC/Pixal3D-T"
        print(f"[Pipeline] Loading from {model_path}...")
        new_pipeline = Pixal3DImageTo3DPipeline.from_pretrained(model_path)

        print("[ImageCond] Building DinoV3ProjFeatureExtractor models...")
        # pipeline attribute -> IMAGE_COND_CONFIGS key
        cond_models = {
            'image_cond_model_ss': "ss",
            'image_cond_model_shape_512': "shape_512",
            'image_cond_model_shape_1024': "shape_1024",
            'image_cond_model_tex_1024': "tex_1024",
        }
        for attr, cfg_key in cond_models.items():
            setattr(new_pipeline, attr, build_image_cond_model(IMAGE_COND_CONFIGS[cfg_key]))
        new_pipeline.low_vram = False
        new_pipeline.cuda()
        # Ensure image_cond_models are on GPU: they were attached after
        # from_pretrained, so they are moved explicitly in case
        # pipeline.cuda() does not cover them.
        for attr in cond_models:
            getattr(new_pipeline, attr).cuda()

        print("[NAF] Pre-loading NAF upsampler model...")
        for attr in cond_models:
            model = getattr(new_pipeline, attr, None)
            if model is not None and getattr(model, 'use_naf_upsample', False):
                model._load_naf()

        print("[MoGe-2] Loading model for camera estimation...")
        new_moge = load_moge_model(device="cuda")

        print("[EnvMap] Loading environment maps...")
        base_dir = os.path.dirname(os.path.abspath(__file__))
        hdri_files = {
            'forest': 'assets/hdri/forest.exr',
            'sunset': 'assets/hdri/sunset.exr',
            'courtyard': 'assets/hdri/courtyard.exr',
        }
        new_envmap = {
            name: _load_envmap(os.path.join(base_dir, rel_path))
            for name, rel_path in hdri_files.items()
        }

        # Publish only after every component loaded successfully.
        pipeline, moge_model, envmap = new_pipeline, new_moge, new_envmap
# ============================================================================
# Utilities
# ============================================================================
def compute_f_pixels(camera_angle_x: float, resolution: int) -> float:
    """Convert a horizontal field of view (radians) into a focal length in pixels.

    Uses a 32-unit sensor width: focal = 16 / tan(fov / 2) in sensor units,
    then rescaled by ``resolution / 32``. The tangent is evaluated with torch
    (float32).
    """
    half_fov = torch.tensor(camera_angle_x / 2.0)
    focal_sensor_units = 16.0 / torch.tan(half_fov)
    focal_px = focal_sensor_units * resolution / 32.0
    return float(focal_px.item())
def distance_from_fov(camera_angle_x, grid_point, target_point, mesh_scale, image_resolution):
    """Solve for the camera distance that puts *grid_point* on a given pixel column.

    The grid point is first mapped by a fixed axis change
    ``(x, y, z) -> (x, -z, y)`` and scaled by ``1 / (2 * mesh_scale)``, giving
    ``(xw, yw, zw)``. The returned distance ``d`` satisfies
    ``f_pixels * xw / (d + yw) == target_x - image_resolution / 2``.

    Args:
        camera_angle_x: Horizontal field of view in radians.
        grid_point: 3-element tensor (x, y, z).
        target_point: Tensor whose first element is the target pixel column.
            Only the column is used; any further elements are ignored.
        mesh_scale: Scale divisor applied to the grid point.
        image_resolution: Image width in pixels.

    Returns:
        dict with float entries ``distance_from_x`` and ``f_pixels``.

    Raises:
        ValueError: if the target column is exactly the image centre, where
            the equation above has no solution.
    """
    rotation_matrix = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
    gp = grid_point.to(torch.float32) @ rotation_matrix.T
    gp = gp / mesh_scale / 2
    xw, yw = gp[0].item(), gp[1].item()
    xt = float(target_point[0].item())
    f_pixels = compute_f_pixels(camera_angle_x, image_resolution)
    # Horizontal offset of the target column from the image centre (pixels).
    x_offset = xt - image_resolution / 2.0
    if x_offset == 0:
        raise ValueError("target column lies on the image centre; distance is undefined")
    distance_x = f_pixels * xw / x_offset - yw
    return {"distance_from_x": float(distance_x), "f_pixels": float(f_pixels)}
def get_camera_params_wild_moge(image_path, device="cuda", mesh_scale=1.0, extend_pixel=0, image_resolution=512):
    """Estimate camera parameters for an in-the-wild image with MoGe-2.

    The horizontal FOV is derived from MoGe's predicted intrinsics, whose fx is
    normalized by the image width. The distance is then fitted so that the grid
    point (-1, 0, 0) lands on pixel column ``-extend_pixel`` of an
    ``image_resolution``-wide image (see ``distance_from_fov``).

    Requires ``init_models()`` to have populated the global ``moge_model``.

    Returns:
        dict with ``camera_angle_x`` (radians), ``distance`` and ``mesh_scale``.
    """
    with Image.open(image_path) as src:
        rgb_image = src.convert("RGB")
    img_width = rgb_image.size[0]

    pixels = np.array(rgb_image).astype(np.float32) / 255.0
    chw = torch.from_numpy(pixels).permute(2, 0, 1).to(device)
    with torch.no_grad():
        prediction = moge_model.infer(chw)

    intrinsics = prediction["intrinsics"].squeeze().cpu().numpy()
    # Undo the width normalization to get fx in pixels.
    fx_pixels = intrinsics[0, 0] * img_width
    fov_x = 2 * math.atan(img_width / (2 * fx_pixels))

    target_pixel = torch.tensor([0 - extend_pixel, image_resolution - 1 + extend_pixel])
    fit = distance_from_fov(
        fov_x,
        torch.tensor([-1.0, 0.0, 0.0]),
        target_pixel,
        mesh_scale,
        image_resolution,
    )
    return {'camera_angle_x': fov_x, 'distance': fit["distance_from_x"], 'mesh_scale': mesh_scale}
def pack_state(shape_slat, tex_slat, res):
    """Serialize shape/texture latents into a compressed ``.npz`` inside TMP_DIR.

    Only ``shape_slat.coords`` is stored. The texture latent is assumed to share
    those coordinates, since ``unpack_state`` rebuilds it on them.

    Returns:
        Path of the written file. The name contains a random UUID, so
        concurrent calls never collide. Because TMP_DIR is also mounted as a
        static directory, the name must not be guessable by other clients
        either.
    """
    import uuid

    state_data = {
        'shape_slat_feats': shape_slat.feats.cpu().numpy(),
        'tex_slat_feats': tex_slat.feats.cpu().numpy(),
        'coords': shape_slat.coords.cpu().numpy(),
        'res': res,
    }
    state_path = os.path.join(TMP_DIR, f"state_{int(time.time() * 1000)}_{uuid.uuid4().hex}.npz")
    np.savez_compressed(state_path, **state_data)
    return state_path
def unpack_state(state_path):
    """Load latents written by ``pack_state`` back onto the GPU.

    Returns:
        ``(shape_slat, tex_slat, res)``. Both SparseTensors share the stored
        coordinates.

    Raises:
        ValueError: if *state_path* does not resolve to a location inside
            TMP_DIR. The path may come from API clients (extract_glb_api takes
            it as a parameter), so it is confined to the directory that
            pack_state writes into.
    """
    tmp_root = os.path.realpath(TMP_DIR)
    resolved = os.path.realpath(state_path)
    if os.path.commonpath([tmp_root, resolved]) != tmp_root:
        raise ValueError(f"state_path must be inside {TMP_DIR!r}")
    # np.load on an .npz returns an NpzFile that keeps the file open; use it as
    # a context manager so the descriptor is released once the arrays are read.
    with np.load(resolved, allow_pickle=False) as data:
        shape_feats = data['shape_slat_feats']
        tex_feats = data['tex_slat_feats']
        coords = data['coords']
        res = int(data['res'])
    shape_slat = SparseTensor(
        feats=torch.from_numpy(shape_feats).cuda(),
        coords=torch.from_numpy(coords).cuda(),
    )
    tex_slat = shape_slat.replace(torch.from_numpy(tex_feats).cuda())
    return shape_slat, tex_slat, res
# ============================================================================
# Progress Tracking (file-based, cross-process safe for @spaces.GPU)
# ============================================================================
import asyncio
from fastapi.responses import JSONResponse
from fastapi import Request

# Per-thread binding of the session that the current request reports to.
_thread_local = threading.local()

# One JSON progress file per session is kept in this directory.
PROGRESS_DIR = os.path.join(TMP_DIR, "_progress")
os.makedirs(PROGRESS_DIR, exist_ok=True)
def _progress_file(session_id: str) -> str:
    """Return path to a session's progress JSON file.

    *session_id* comes straight from clients (query string or API argument),
    so it is reduced to a safe file-name stem before joining:
    - every character other than an ASCII letter, digit, ``-`` or ``_`` becomes ``_``;
    - the result is capped at 128 characters.

    This keeps every progress file inside PROGRESS_DIR, because ``..`` and path
    separators cannot survive. The mapping is deterministic, so writers and the
    polling reader resolve the same id to the same file.
    """
    safe_id = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "-_" else "_"
        for ch in session_id
    )[:128]
    return os.path.join(PROGRESS_DIR, f"{safe_id}.json")
def _reset_progress(session_id: str):
    """Bind *session_id* to the current thread and write its initial progress record.

    Later ``_update_progress``/``_finish_progress`` calls on this thread report
    to this session. An empty id disables reporting: nothing is written, which
    matches how those two helpers treat an empty session.
    """
    _thread_local.active_session = session_id
    if session_id:
        _write_progress_file(session_id, {"stage": "Initializing...", "step": 0, "total": 0, "done": False})
def _update_progress(stage: str, step: int, total: int):
    """Record ``step``/``total`` for *stage* under this thread's active session, if any."""
    active = getattr(_thread_local, 'active_session', '')
    if not active:
        return
    record = {"stage": stage, "step": step, "total": total, "done": False}
    _write_progress_file(active, record)
def _finish_progress():
    """Mark this thread's active session as done and unbind it from the thread.

    Clearing the binding matters if the worker thread is reused. Otherwise a
    later tqdm update on this thread (from a call that never invoked
    ``_reset_progress``) would overwrite the finished session's
    ``{"done": True}`` record with a not-done state.
    """
    session_id = getattr(_thread_local, 'active_session', '')
    _thread_local.active_session = ''
    if session_id:
        _write_progress_file(session_id, {"done": True})
def _write_progress_file(session_id: str, data: dict):
    """Atomically write progress JSON to a file (cross-process safe, best effort).

    The JSON is written to a temporary sibling whose name includes this
    process id and thread id, so two writers never share a temp file. It is
    then moved into place with ``os.replace``, which is atomic on POSIX:
    readers see either the previous record or the new one.

    Errors are swallowed on purpose, because progress reporting must never
    break generation. The temp file is removed so failures do not leave
    ``.tmp`` files behind.
    """
    path = _progress_file(session_id)
    tmp_path = f"{path}.{os.getpid()}.{threading.get_ident()}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f)
        os.replace(tmp_path, path)  # atomic on POSIX
    except Exception:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
# Monkey-patch tqdm to intercept progress
import tqdm as _tqdm_module
_original_tqdm = _tqdm_module.tqdm


class _TqdmProgressInterceptor(_original_tqdm):
    """tqdm subclass that mirrors each progress update into the progress file.

    The bar's description is reported as the stage name through
    ``_update_progress``, falling back to ``'Processing'`` when no description
    is given.
    """

    def __init__(self, *args, **kwargs):
        # tqdm's signature is tqdm(iterable=None, desc=None, ...), so the
        # description may arrive positionally as the second argument. An
        # explicit desc=None falls back like set_description() does.
        desc = kwargs.get('desc')
        if desc is None and len(args) > 1:
            desc = args[1]
        self._stage_desc = desc or 'Processing'
        super().__init__(*args, **kwargs)

    def set_description(self, desc=None, refresh=True):
        self._stage_desc = desc or 'Processing'
        super().set_description(desc, refresh)

    def update(self, n=1):
        # Keep tqdm's return value (whether a display was triggered).
        displayed = super().update(n)
        _update_progress(self._stage_desc, self.n, self.total or 0)
        return displayed
# Patch tqdm globally: code that resolves `tqdm.tqdm` through the module
# object, or imports it after this point, gets the interceptor.
_tqdm_module.tqdm = _TqdmProgressInterceptor
# Also patch the direct import in the sampler module and render_utils.
# Modules that already bound the name `tqdm` in their own namespace keep the
# original class, so that name is rebound in each module whose loops should
# report progress. (render_utils, for one, is already imported at the top of
# this file, before the global patch above.)
# NOTE(review): assumes each module looks up its module-level `tqdm` name at
# call time — confirm if progress stops updating for a stage.
import trellis2.pipelines.samplers.flow_euler as _fe_module
_fe_module.tqdm = _TqdmProgressInterceptor
import trellis2.utils.render_utils as _ru_module
_ru_module.tqdm = _TqdmProgressInterceptor
import o_voxel.postprocess as _ovp_module
_ovp_module.tqdm = _TqdmProgressInterceptor
# ============================================================================
# API Implementation
# ============================================================================
app = Server()


async def homepage():
    """Serve the single-page frontend: the index.html that sits next to this script."""
    page_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(page_path, "r", encoding="utf-8") as page_file:
        markup = page_file.read()
    return HTMLResponse(content=markup)
async def progress_poll(request: Request):
    """Polling endpoint for real-time progress updates during generation.

    Reads ``session_id`` from the query string and returns that session's
    progress record. If the record cannot be read or parsed, a neutral
    "Waiting..." record is returned instead of a server error. This covers a
    missing file (polling can start before the first write), an unreadable
    path and invalid JSON or encoding.
    """
    session_id = request.query_params.get("session_id", "")
    path = _progress_file(session_id)
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError covers FileNotFoundError/IsADirectoryError/PermissionError;
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return JSONResponse({"stage": "Waiting...", "step": 0, "total": 0, "done": False})
    return JSONResponse(data)
def preprocess(image: FileData) -> FileData:
    """Run the pipeline's image preprocessing on an uploaded image.

    Returns:
        FileData pointing at the processed PNG in TMP_DIR. The file name
        includes a random UUID, so concurrent requests within the same
        millisecond cannot overwrite each other's output.
    """
    import uuid

    init_models()
    out_path = os.path.join(TMP_DIR, f"preprocessed_{int(time.time() * 1000)}_{uuid.uuid4().hex}.png")
    # Save while the source is still open in case the result references it.
    with Image.open(image["path"]) as img:
        processed = pipeline.preprocess_image(img)
        processed.save(out_path)
    return FileData(path=out_path)
def generate_3d(
    image: FileData,
    seed: int,
    resolution: int,
    ss_guidance_strength: float = 7.5,
    ss_guidance_rescale: float = 0.7,
    ss_sampling_steps: int = 12,
    ss_rescale_t: float = 5.0,
    shape_slat_guidance_strength: float = 7.5,
    shape_slat_guidance_rescale: float = 0.5,
    shape_slat_sampling_steps: int = 12,
    shape_slat_rescale_t: float = 3.0,
    tex_slat_guidance_strength: float = 1.0,
    tex_slat_guidance_rescale: float = 0.0,
    tex_slat_sampling_steps: int = 12,
    tex_slat_rescale_t: float = 3.0,
    session_id: str = "",
) -> Dict:
    """Generate a 3D asset from a preprocessed image and render preview frames.

    Steps:
    1. Estimate the camera with MoGe-2.
    2. Run the pipeline as ``f"{resolution}_cascade"``.
    3. Save the latents with ``pack_state`` for a later ``extract_glb_api`` call.
    4. Render ``STEPS`` projection-aligned frames for every render mode.

    Args:
        image: Image already processed by the ``preprocess`` step.
        seed: Random seed, coerced to int.
        resolution: High-resolution cascade stage (selects the pipeline type).
        ss_*, shape_slat_*, tex_slat_*: Sampler overrides for the
            sparse-structure, shape-latent and texture-latent stages: steps,
            guidance strength, guidance rescale and rescale_t.
        session_id: Client id used for progress polling. An empty id disables
            progress reporting.

    Returns:
        dict with these keys:
        - ``render_paths``: render key -> list of FileData JPEGs;
        - ``state_path``: absolute path of the saved latents;
        - ``camera_angle_x`` and ``distance``.

    Progress is marked finished even when generation raises, so pollers do not
    wait forever.
    """
    import uuid

    _reset_progress(session_id)
    try:
        init_models()
        _update_progress("Preprocessing & Camera Estimation", 0, 1)
        seed = int(seed)
        torch.manual_seed(seed)
        hr_resolution = int(resolution)

        # Image is already preprocessed by /preprocess endpoint, use directly.
        image_preprocessed = Image.open(image["path"])
        # MoGe reads the same file straight from disk (no temporary copy).
        camera_params = get_camera_params_wild_moge(
            image["path"], device="cuda",
            mesh_scale=WILD_MESH_SCALE, extend_pixel=WILD_EXTEND_PIXEL,
            image_resolution=WILD_IMAGE_RESOLUTION,
        )
        _update_progress("Preprocessing & Camera Estimation", 1, 1)

        ss_sampler_override = {"steps": ss_sampling_steps, "guidance_strength": ss_guidance_strength,
                               "guidance_rescale": ss_guidance_rescale, "rescale_t": ss_rescale_t}
        shape_sampler_override = {"steps": shape_slat_sampling_steps, "guidance_strength": shape_slat_guidance_strength,
                                  "guidance_rescale": shape_slat_guidance_rescale, "rescale_t": shape_slat_rescale_t}
        tex_sampler_override = {"steps": tex_slat_sampling_steps, "guidance_strength": tex_slat_guidance_strength,
                                "guidance_rescale": tex_slat_guidance_rescale, "rescale_t": tex_slat_rescale_t}

        mesh_list, (shape_slat, tex_slat, res) = pipeline.run(
            image_preprocessed,
            camera_params=camera_params,
            seed=seed,
            sparse_structure_sampler_params=ss_sampler_override,
            shape_slat_sampler_params=shape_sampler_override,
            tex_slat_sampler_params=tex_sampler_override,
            preprocess_image=False,
            return_latent=True,
            pipeline_type=f"{hr_resolution}_cascade",
            max_num_tokens=CASCADE_MAX_NUM_TOKENS,
        )
        mesh = mesh_list[0]
        state_path = pack_state(shape_slat, tex_slat, res)

        _update_progress("Rendering views", 0, 1)
        mesh.simplify(16777216)  # 2**24; NOTE(review): presumably a face budget for rendering
        cam_dist = camera_params['distance']
        # Clip planes bracket the camera distance; near is kept positive.
        near = max(0.01, cam_dist - 2.0)
        far = cam_dist + 10.0
        renders = render_utils.render_proj_aligned_video(
            mesh, camera_angle_x=camera_params['camera_angle_x'],
            distance=cam_dist, resolution=1024,
            num_frames=STEPS, envmap=envmap,
            near=near, far=far,
        )
        _update_progress("Rendering views", 1, 1)

        # One random tag per call keeps these files distinct from those of
        # concurrent requests and unguessable under the public /tmp mount.
        run_tag = f"{int(time.time() * 1000)}_{uuid.uuid4().hex}"
        render_files = {}
        for mode_key, frames in renders.items():
            mode_files = []
            for i, frame in enumerate(frames):
                p = os.path.abspath(os.path.join(TMP_DIR, f"render_{mode_key}_{i}_{run_tag}.jpg"))
                Image.fromarray(frame).save(p, quality=85)
                mode_files.append(FileData(path=p))
            render_files[mode_key] = mode_files
    finally:
        _finish_progress()

    return {
        "render_paths": render_files,
        "state_path": os.path.abspath(state_path),
        "camera_angle_x": camera_params['camera_angle_x'],
        "distance": camera_params['distance'],
    }
def extract_glb_api(state_path: str, decimation_target: int, texture_size: int, session_id: str = "") -> FileData:
    """Decode saved latents into a mesh and export it as a GLB file.

    Args:
        state_path: Latent file produced by ``generate_3d`` (via ``pack_state``).
        decimation_target: Passed to ``o_voxel.postprocess.to_glb``.
        texture_size: Passed to ``o_voxel.postprocess.to_glb``.
        session_id: Client id used for progress polling. An empty id disables
            progress reporting.

    Returns:
        FileData for the exported ``.glb`` in TMP_DIR. The name includes a
        random UUID, so concurrent exports cannot overwrite each other.

    Progress is marked finished even when extraction raises.
    """
    import uuid

    _reset_progress(session_id)
    try:
        init_models()
        _update_progress("Decoding latent", 0, 1)
        shape_slat, tex_slat, res = unpack_state(state_path)
        mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
        _update_progress("Decoding latent", 1, 1)

        glb = o_voxel.postprocess.to_glb(
            vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs,
            coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout,
            grid_size=res, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
            decimation_target=decimation_target, texture_size=texture_size,
            remesh=True, remesh_band=1, remesh_project=0, use_tqdm=True,
        )
        # Fixed axis change for the exported scene: (x, y, z) -> (-x, -z, -y)
        # (determinant +1, so a proper rotation, not a mirror).
        rot = np.array([
            [-1,  0,  0, 0],
            [ 0,  0, -1, 0],
            [ 0, -1,  0, 0],
            [ 0,  0,  0, 1],
        ], dtype=np.float64)
        glb.apply_transform(rot)

        out_glb = os.path.join(TMP_DIR, f"result_{int(time.time() * 1000)}_{uuid.uuid4().hex}.glb")
        glb.export(out_glb, extension_webp=True)
    finally:
        _finish_progress()
    return FileData(path=out_glb)
# Mount assets and tmp for direct access. The assets directory is resolved
# relative to this file, like index.html and the HDRIs, so the server works
# regardless of the current working directory.
app.mount(
    "/assets",
    StaticFiles(directory=os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")),
    name="assets",
)
# NOTE: every file written to TMP_DIR (renders, latent states, GLBs) can be
# fetched by anyone under /tmp/<file name>.
app.mount("/tmp", StaticFiles(directory=TMP_DIR), name="tmp")

if __name__ == "__main__":
    import sys

    # Re-install utils3d as in original app.py. Run pip through the current
    # interpreter so the wheel lands in this environment rather than in
    # whichever `pip` happens to be first on PATH.
    # NOTE(review): modules already imported above keep the previously loaded
    # utils3d in memory — confirm the reinstall is effective for this process.
    subprocess.run([
        sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps",
        "https://github.com/LDYang694/Storages/releases/download/20260430/utils3d-0.0.2-py3-none-any.whl"
    ], check=True)
    # Pre-initialize models before launching the server
    init_models()
    app.launch(show_error=True, share=True)