import os

# Library configuration via environment variables. These are set before
# cv2 / torch / trellis2 are imported so they take effect regardless of when
# each library reads its configuration (at import time or lazily on first use).
# ATTN_BACKEND and FLEX_GEMM_* are presumably consumed by trellis2 and its GEMM
# backend — confirm against those packages before reordering imports below.
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'  # needed to cv2.imread the .exr HDRIs
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json'
)
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'

import subprocess
import argparse
import math
import time
import shutil
import cv2
import torch
import numpy as np
import base64
import io
import json
from datetime import datetime
from typing import *
from PIL import Image
import threading

# nest_asyncio is optional: when installed it allows nested event loops.
try:
    import nest_asyncio
    nest_asyncio.apply()
except ImportError:
    pass

# Lock for model initialization
init_lock = threading.Lock()

import spaces
from gradio import Server
from gradio.data_classes import FileData
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Pixal3DImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils
import o_voxel

# ============================================================================
# Constants & Defaults
# ============================================================================
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)

# UI render modes: display name, icon asset, and the key used for the render output.
MODES = [
    {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
    {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
    {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
    {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
    {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
    {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
]
STEPS = 8

# Cascade parameters
CASCADE_LR_RESOLUTION = 512
CASCADE_MAX_NUM_TOKENS = 49152

# MoGe defaults
MOGE_MODEL_NAME = "Ruicheng/moge-2-vitl"
WILD_MESH_SCALE = 1.0
WILD_EXTEND_PIXEL = 0
WILD_IMAGE_RESOLUTION = 512

# Image Cond Model configs (kwargs for DinoV3ProjFeatureExtractor)
IMAGE_COND_CONFIGS = {
    "ss": {
        "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
        "image_size": 512,
        "grid_resolution": 16,
    },
    "shape_512": {
        "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
        "image_size": 512,
        "grid_resolution": 32,
        "use_naf_upsample": True,
        "naf_target_size": 512,
    },
    "shape_1024": {
        "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
        "image_size": 1024,
        "grid_resolution": 64,
        "use_naf_upsample": True,
        "naf_target_size": 512,
    },
    "tex_1024": {
        "model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
        "image_size": 1024,
        "grid_resolution": 64,
        "use_naf_upsample": True,
        "naf_target_size": 1024,
    },
}

# HDRI environment maps (envmap key -> file path, relative to the working directory).
ENVMAP_FILES = {
    'forest': 'assets/hdri/forest.exr',
    'sunset': 'assets/hdri/sunset.exr',
    'courtyard': 'assets/hdri/courtyard.exr',
}

# ============================================================================
# Model Loading
# ============================================================================

def build_image_cond_model(config: dict):
    """Construct a DinoV3ProjFeatureExtractor from ``config`` and set it to eval mode."""
    from trellis2.trainers.flow_matching.mixins.image_conditioned_proj import DinoV3ProjFeatureExtractor

    model = DinoV3ProjFeatureExtractor(**config)
    model.eval()
    return model


def load_moge_model(device="cuda", model_name=MOGE_MODEL_NAME):
    """Load the MoGe-2 model ``model_name`` onto ``device`` in eval mode."""
    from moge.model.v2 import MoGeModel

    moge_model = MoGeModel.from_pretrained(model_name).to(device)
    moge_model.eval()
    return moge_model


def _load_envmap(path, device="cuda"):
    """Read an HDR image with OpenCV, convert BGR->RGB, and wrap it in an EnvMap.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (``cv2.imread`` signals
            failure by returning ``None`` rather than raising).
    """
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if image is None:
        raise FileNotFoundError(f"Could not read environment map: {path}")
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return EnvMap(torch.tensor(rgb, dtype=torch.float32, device=device))


# Global instances, populated once by init_models().
pipeline = None
moge_model = None
envmap = None


def init_models():
    """Load the pipeline, MoGe-2 and the environment maps exactly once.

    Serialized by ``init_lock``. Everything is built into locals first and the
    globals are published only after all loads succeed, with ``pipeline``
    assigned last. Since ``pipeline is not None`` is the "already initialized"
    check, an exception part-way through leaves the globals untouched and the
    next call retries instead of running with ``moge_model``/``envmap`` unset.
    """
    global pipeline, moge_model, envmap
    with init_lock:
        if pipeline is not None:
            return

        model_path = "TencentARC/Pixal3D-T"
        print(f"[Pipeline] Loading from {model_path}...")
        pipe = Pixal3DImageTo3DPipeline.from_pretrained(model_path)

        print("[ImageCond] Building DinoV3ProjFeatureExtractor models...")
        pipe.image_cond_model_ss = build_image_cond_model(IMAGE_COND_CONFIGS["ss"])
        pipe.image_cond_model_shape_512 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_512"])
        pipe.image_cond_model_shape_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_1024"])
        pipe.image_cond_model_tex_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["tex_1024"])
        pipe.cuda()

        print("[NAF] Pre-loading NAF upsampler model...")
        for attr in ['image_cond_model_ss', 'image_cond_model_shape_512',
                     'image_cond_model_shape_1024', 'image_cond_model_tex_1024']:
            model = getattr(pipe, attr, None)
            if model is not None and getattr(model, 'use_naf_upsample', False):
                model._load_naf()

        print("[MoGe-2] Loading model for camera estimation...")
        loaded_moge = load_moge_model(device="cuda")

        print("[EnvMap] Loading environment maps...")
        loaded_envmap = {name: _load_envmap(path) for name, path in ENVMAP_FILES.items()}

        # Publish only after every load succeeded; pipeline last (it is the init flag).
        moge_model = loaded_moge
        envmap = loaded_envmap
        pipeline = pipe

# ============================================================================
# Utilities
# ============================================================================

def _unique_tmp_path(prefix, suffix):
    """Return a new path ``TMP_DIR/{prefix}_{ms timestamp}_{random hex}{suffix}``.

    The random component keeps concurrent requests that land in the same
    millisecond from overwriting each other's files.
    """
    import uuid

    name = f"{prefix}_{int(time.time() * 1000)}_{uuid.uuid4().hex[:12]}{suffix}"
    return os.path.join(TMP_DIR, name)


def compute_f_pixels(camera_angle_x: float, resolution: int) -> float:
    """Return the focal length in pixels for a horizontal FOV and image width.

    Uses a 32-unit sensor width, so the result equals
    ``(resolution / 2) / tan(camera_angle_x / 2)``.
    """
    focal_length = 16.0 / torch.tan(torch.tensor(camera_angle_x / 2.0))
    f_pixels = focal_length * resolution / 32.0
    return float(f_pixels.item())


def distance_from_fov(camera_angle_x, grid_point, target_point, mesh_scale, image_resolution):
    """Solve for the camera distance that projects ``grid_point`` onto ``target_point``'s x.

    ``grid_point`` is rotated by a fixed matrix, scaled by ``1 / (2 * mesh_scale)``,
    and the distance ``d`` satisfying ``x_ndc = f_pixels * xw / (d + yw)`` is
    returned, where ``x_ndc`` is the target's x offset from the image centre.

    Returns:
        dict with ``distance_from_x`` and ``f_pixels`` (both floats).

    Raises:
        ValueError: if the target x lies exactly on the image centre, where the
            distance is undefined.
    """
    rotation_matrix = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
    gp = grid_point.to(torch.float32) @ rotation_matrix.T
    gp = gp / mesh_scale / 2
    xw, yw = gp[0].item(), gp[1].item()
    xt = float(target_point[0].item())

    f_pixels = compute_f_pixels(camera_angle_x, image_resolution)
    x_ndc = xt - image_resolution / 2.0
    if x_ndc == 0:
        raise ValueError("target_point x is at the image centre; camera distance is undefined")
    distance_x = f_pixels * xw / x_ndc - yw
    return {"distance_from_x": float(distance_x), "f_pixels": float(f_pixels)}


def get_camera_params_wild_moge(image_path, device="cuda", mesh_scale=1.0, extend_pixel=0, image_resolution=512):
    """Estimate camera FOV and distance for an in-the-wild image using MoGe-2.

    Requires ``init_models()`` to have run (reads the global ``moge_model``).

    Returns:
        dict with ``camera_angle_x`` (radians), ``distance`` and ``mesh_scale``.
    """
    with Image.open(image_path) as src:
        pil_image = src.convert("RGB")
    width = pil_image.size[0]
    image_np = np.array(pil_image).astype(np.float32) / 255.0
    image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).to(device)

    with torch.no_grad():
        output = moge_model.infer(image_tensor)

    intrinsics = output["intrinsics"].squeeze().cpu().numpy()
    # intrinsics[0, 0] is scaled by the image width here, i.e. treated as normalized fx.
    fx_normalized = intrinsics[0, 0]
    fx = fx_normalized * width
    camera_angle_x = 2 * math.atan(width / (2 * fx))

    grid_point = torch.tensor([-1.0, 0.0, 0.0])
    distance = distance_from_fov(
        camera_angle_x,
        grid_point,
        torch.tensor([0 - extend_pixel, image_resolution - 1 + extend_pixel]),
        mesh_scale,
        image_resolution,
    )["distance_from_x"]
    return {'camera_angle_x': camera_angle_x, 'distance': distance, 'mesh_scale': mesh_scale}


def pack_state(shape_slat, tex_slat, res):
    """Save shape/texture latent features, shared coords and ``res`` to a new .npz in TMP_DIR.

    Returns the path of the written file.
    """
    state_data = {
        'shape_slat_feats': shape_slat.feats.cpu().numpy(),
        'tex_slat_feats': tex_slat.feats.cpu().numpy(),
        'coords': shape_slat.coords.cpu().numpy(),
        'res': res,
    }
    state_path = _unique_tmp_path("state", ".npz")
    np.savez_compressed(state_path, **state_data)
    return state_path


def unpack_state(state_path):
    """Inverse of ``pack_state``: load the .npz and rebuild the latents on the GPU.

    Returns:
        ``(shape_slat, tex_slat, res)``; ``tex_slat`` shares ``shape_slat``'s coords.
    """
    # NpzFile keeps the archive open until closed; the context manager releases it.
    with np.load(state_path) as data:
        shape_slat = SparseTensor(
            feats=torch.from_numpy(data['shape_slat_feats']).cuda(),
            coords=torch.from_numpy(data['coords']).cuda(),
        )
        tex_slat = shape_slat.replace(torch.from_numpy(data['tex_slat_feats']).cuda())
        res = int(data['res'])
    return shape_slat, tex_slat, res


def _resolve_state_path(state_path):
    """Return the real path of a client-supplied ``state_path`` inside TMP_DIR.

    Raises:
        ValueError: if the path resolves outside TMP_DIR (``state_path`` comes
            from API callers and must not reach arbitrary server files).
    """
    real_tmp = os.path.realpath(TMP_DIR)
    real_path = os.path.realpath(state_path)
    try:
        inside = os.path.commonpath([real_tmp, real_path]) == real_tmp
    except ValueError:  # e.g. different drives on Windows
        inside = False
    if not inside:
        raise ValueError(f"Invalid state path: {state_path}")
    return real_path

# ============================================================================
# API Implementation
# ============================================================================
app = Server()


@app.get("/")
async def homepage():
    """Serve the ``index.html`` located next to this module."""
    html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(html_path, "r", encoding="utf-8") as f:
        return HTMLResponse(content=f.read())


@app.api()
def preprocess(image: FileData) -> FileData:
    """Run the pipeline's image preprocessing and return the result as a PNG file."""
    init_models()
    img = Image.open(image["path"])
    processed = pipeline.preprocess_image(img)
    out_path = _unique_tmp_path("preprocessed", ".png")
    processed.save(out_path)
    return FileData(path=out_path)


@app.api()
@spaces.GPU(duration=120)
def generate_3d(
    image: FileData,
    seed: int,
    resolution: int,
    ss_guidance_strength: float = 7.5,
    ss_guidance_rescale: float = 0.7,
    ss_sampling_steps: int = 12,
    ss_rescale_t: float = 5.0,
    shape_slat_guidance_strength: float = 7.5,
    shape_slat_guidance_rescale: float = 0.5,
    shape_slat_sampling_steps: int = 12,
    shape_slat_rescale_t: float = 3.0,
    tex_slat_guidance_strength: float = 1.0,
    tex_slat_guidance_rescale: float = 0.0,
    tex_slat_sampling_steps: int = 12,
    tex_slat_rescale_t: float = 3.0,
) -> Dict:
    """Generate a 3D asset from one image and render preview frames.

    Returns:
        dict with ``render_paths`` (render key -> list of FileData, one per
        frame) and ``state_path`` (absolute path of the packed latents, to be
        passed to ``extract_glb_api``).
    """
    init_models()
    torch.manual_seed(seed)
    hr_resolution = int(resolution)

    img = Image.open(image["path"])
    image_preprocessed = pipeline.preprocess_image(img)

    # Camera estimation reads the image from disk. A per-request file name keeps
    # concurrent requests from overwriting each other's input between save and read.
    temp_processed_path = _unique_tmp_path("temp_proc", ".png")
    image_preprocessed.save(temp_processed_path)
    try:
        camera_params = get_camera_params_wild_moge(
            temp_processed_path,
            device="cuda",
            mesh_scale=WILD_MESH_SCALE,
            extend_pixel=WILD_EXTEND_PIXEL,
            image_resolution=WILD_IMAGE_RESOLUTION,
        )
    finally:
        try:
            os.remove(temp_processed_path)
        except OSError:
            pass  # best-effort cleanup; a leftover temp file is harmless

    ss_sampler_override = {
        "steps": ss_sampling_steps,
        "guidance_strength": ss_guidance_strength,
        "guidance_rescale": ss_guidance_rescale,
        "rescale_t": ss_rescale_t,
    }
    shape_sampler_override = {
        "steps": shape_slat_sampling_steps,
        "guidance_strength": shape_slat_guidance_strength,
        "guidance_rescale": shape_slat_guidance_rescale,
        "rescale_t": shape_slat_rescale_t,
    }
    tex_sampler_override = {
        "steps": tex_slat_sampling_steps,
        "guidance_strength": tex_slat_guidance_strength,
        "guidance_rescale": tex_slat_guidance_rescale,
        "rescale_t": tex_slat_rescale_t,
    }

    pipeline_type = f"{hr_resolution}_cascade"
    mesh_list, (shape_slat, tex_slat, res) = pipeline.run(
        image_preprocessed,
        camera_params=camera_params,
        seed=seed,
        sparse_structure_sampler_params=ss_sampler_override,
        shape_slat_sampler_params=shape_sampler_override,
        tex_slat_sampler_params=tex_sampler_override,
        preprocess_image=False,
        return_latent=True,
        pipeline_type=pipeline_type,
        max_num_tokens=CASCADE_MAX_NUM_TOKENS,
    )
    mesh = mesh_list[0]

    state_path = pack_state(shape_slat, tex_slat, res)

    # 16777216 == 2**24; presumably a face budget for simplify — confirm in trellis2.
    mesh.simplify(16777216)
    renders = render_utils.render_proj_aligned_video(
        mesh,
        camera_angle_x=camera_params['camera_angle_x'],
        distance=camera_params['distance'],
        resolution=1024,
        num_frames=STEPS,
        envmap=envmap,
    )

    # Save renders and return paths
    render_files = {}
    for mode_key, frames in renders.items():
        mode_files = []
        for i, frame in enumerate(frames):
            p = os.path.abspath(_unique_tmp_path(f"render_{mode_key}_{i}", ".jpg"))
            Image.fromarray(frame).save(p, quality=85)
            mode_files.append(FileData(path=p))
        render_files[mode_key] = mode_files

    return {
        "render_paths": render_files,
        "state_path": os.path.abspath(state_path),
    }


@app.api()
@spaces.GPU(duration=120)
def extract_glb_api(state_path: str, decimation_target: int, texture_size: int) -> FileData:
    """Decode latents saved by ``generate_3d`` and export them as a textured GLB.

    Raises:
        ValueError: if ``state_path`` does not point inside TMP_DIR.
    """
    init_models()
    shape_slat, tex_slat, res = unpack_state(_resolve_state_path(state_path))
    mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]

    glb = o_voxel.postprocess.to_glb(
        vertices=mesh.vertices,
        faces=mesh.faces,
        attr_volume=mesh.attrs,
        coords=mesh.coords,
        attr_layout=pipeline.pbr_attr_layout,
        grid_size=res,
        aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=decimation_target,
        texture_size=texture_size,
        remesh=True,
        remesh_band=1,
        remesh_project=0,
        use_tqdm=True,
    )

    # Fixed rotation (x, y, z) -> (-x, -z, -y) applied to the exported scene.
    rot = np.array([
        [-1, 0, 0, 0],
        [0, 0, -1, 0],
        [0, -1, 0, 0],
        [0, 0, 0, 1],
    ], dtype=np.float64)
    glb.apply_transform(rot)

    out_glb = _unique_tmp_path("result", ".glb")
    glb.export(out_glb, extension_webp=True)
    return FileData(path=out_glb)


# Mount assets and tmp for direct access
app.mount("/assets", StaticFiles(directory="assets"), name="assets")
app.mount("/tmp", StaticFiles(directory=TMP_DIR), name="tmp")

if __name__ == "__main__":
    import sys

    # Re-install utils3d as in original app.py. Invoke pip through the running
    # interpreter so the wheel lands in this environment, not whichever `pip`
    # happens to be first on PATH.
    subprocess.run([
        sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps",
        "https://github.com/LDYang694/Storages/releases/download/20260430/utils3d-0.0.2-py3-none-any.whl",
    ], check=True)

    # Pre-initialize models before launching the server
    init_models()

    app.launch(show_error=True, share=True)