Spaces:
Running
Running
| import os | |
| import subprocess | |
| import argparse | |
| import math | |
| import time | |
| import shutil | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| import base64 | |
| import io | |
| import json | |
| from datetime import datetime | |
| from typing import * | |
| from PIL import Image | |
| import threading | |
# nest_asyncio is an optional dependency: apply it when it is installed and
# carry on silently when it is not.
try:
    import nest_asyncio
    nest_asyncio.apply()
except ImportError:
    pass

# Guards init_models() so concurrent callers do not load the models twice.
init_lock = threading.Lock()

# Environment switches, exported before the remaining imports below execute
# (presumably read by OpenCV / PyTorch / the attention and GEMM backends at
# import time -- confirm before reordering).
os.environ.update({
    'OPENCV_IO_ENABLE_OPENEXR': '1',
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    "ATTN_BACKEND": "flash_attn_3",
    "FLEX_GEMM_AUTOTUNE_CACHE_PATH": os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json'
    ),
    "FLEX_GEMM_AUTOTUNER_VERBOSE": '1',
})
| import spaces | |
| from gradio import Server | |
| from gradio.data_classes import FileData | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from trellis2.modules.sparse import SparseTensor | |
| from trellis2.pipelines import Pixal3DImageTo3DPipeline | |
| from trellis2.renderers import EnvMap | |
| from trellis2.utils import render_utils | |
| import o_voxel | |
| # ============================================================================ | |
| # Constants & Defaults | |
| # ============================================================================ | |
# Upper bound for seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Scratch directory beside this file; created on import if missing.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)

# Preview modes, one row per mode: (display name, icon path, render key).
MODES = [
    dict(zip(("name", "icon", "render_key"), row))
    for row in (
        ("Normal", "assets/app/normal.png", "normal"),
        ("Clay render", "assets/app/clay.png", "clay"),
        ("Base color", "assets/app/basecolor.png", "base_color"),
        ("HDRI forest", "assets/app/hdri_forest.png", "shaded_forest"),
        ("HDRI sunset", "assets/app/hdri_sunset.png", "shaded_sunset"),
        ("HDRI courtyard", "assets/app/hdri_courtyard.png", "shaded_courtyard"),
    )
]

# Number of preview frames requested from the renderer in generate_3d().
STEPS = 8

# Cascade parameters
CASCADE_LR_RESOLUTION = 512
CASCADE_MAX_NUM_TOKENS = 49152

# MoGe defaults
MOGE_MODEL_NAME = "Ruicheng/moge-2-vitl"
WILD_MESH_SCALE = 1.0
WILD_EXTEND_PIXEL = 0
WILD_IMAGE_RESOLUTION = 512

# Image Cond Model configs: every entry uses the same DINOv3 checkpoint and is
# passed as keyword arguments to build_image_cond_model().
_DINOV3_MODEL_NAME = "camenduru/dinov3-vitl16-pretrain-lvd1689m"
IMAGE_COND_CONFIGS = {
    "ss": dict(
        model_name=_DINOV3_MODEL_NAME,
        image_size=512,
        grid_resolution=16,
    ),
    "shape_512": dict(
        model_name=_DINOV3_MODEL_NAME,
        image_size=512,
        grid_resolution=32,
        use_naf_upsample=True,
        naf_target_size=512,
    ),
    "shape_1024": dict(
        model_name=_DINOV3_MODEL_NAME,
        image_size=1024,
        grid_resolution=64,
        use_naf_upsample=True,
        naf_target_size=512,
    ),
    "tex_1024": dict(
        model_name=_DINOV3_MODEL_NAME,
        image_size=1024,
        grid_resolution=64,
        use_naf_upsample=True,
        naf_target_size=1024,
    ),
}
| # ============================================================================ | |
| # Model Loading | |
| # ============================================================================ | |
def build_image_cond_model(config: dict):
    """Construct a ``DinoV3ProjFeatureExtractor`` and put it in eval mode.

    Args:
        config: Keyword arguments forwarded unchanged to the extractor's
            constructor.

    Returns:
        The newly built extractor, after ``eval()`` has been called on it.
    """
    # Resolved at call time rather than at module import.
    from trellis2.trainers.flow_matching.mixins.image_conditioned_proj import (
        DinoV3ProjFeatureExtractor as extractor_cls,
    )

    extractor = extractor_cls(**config)
    extractor.eval()
    return extractor
def load_moge_model(device="cuda", model_name=MOGE_MODEL_NAME):
    """Load a pretrained MoGe v2 model, move it to *device* and set eval mode.

    Args:
        device: Target passed to ``.to()``; defaults to ``"cuda"``.
        model_name: Identifier handed to ``MoGeModel.from_pretrained``.

    Returns:
        The loaded model.
    """
    # Resolved at call time rather than at module import.
    from moge.model.v2 import MoGeModel

    loaded = MoGeModel.from_pretrained(model_name)
    loaded = loaded.to(device)
    loaded.eval()
    return loaded
# Process-wide singletons. They start out as None and are filled in by
# init_models().
pipeline = moge_model = envmap = None
def _load_envmap(name):
    """Read ``assets/hdri/<name>.exr`` and wrap it in an ``EnvMap`` on CUDA.

    Args:
        name: Base name of the EXR file (without extension).

    Returns:
        An ``EnvMap`` built from the image converted from BGR to RGB as a
        float32 CUDA tensor.

    Raises:
        FileNotFoundError: If OpenCV could not read the file. ``cv2.imread``
            signals failure by returning ``None`` rather than raising.
    """
    path = f'assets/hdri/{name}.exr'
    bgr = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if bgr is None:
        raise FileNotFoundError(f"Could not read environment map: {path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return EnvMap(torch.tensor(rgb, dtype=torch.float32, device='cuda'))


def init_models():
    """Load the pipeline, the MoGe model and the environment maps once.

    The whole routine runs under ``init_lock``. A call that finds ``pipeline``
    already set returns immediately. Everything is built in local variables
    and only published to the module globals after every step succeeded, so a
    failure part-way through leaves the globals untouched and the next call
    retries from scratch instead of returning early on a half-initialised
    state.
    """
    global pipeline, moge_model, envmap
    with init_lock:
        if pipeline is not None:
            return

        model_path = "TencentARC/Pixal3D-T"
        print(f"[Pipeline] Loading from {model_path}...")
        new_pipeline = Pixal3DImageTo3DPipeline.from_pretrained(model_path)

        print("[ImageCond] Building DinoV3ProjFeatureExtractor models...")
        # (pipeline attribute, key into IMAGE_COND_CONFIGS)
        cond_models = (
            ('image_cond_model_ss', "ss"),
            ('image_cond_model_shape_512', "shape_512"),
            ('image_cond_model_shape_1024', "shape_1024"),
            ('image_cond_model_tex_1024', "tex_1024"),
        )
        for attr, config_key in cond_models:
            setattr(new_pipeline, attr, build_image_cond_model(IMAGE_COND_CONFIGS[config_key]))
        new_pipeline.cuda()

        print("[NAF] Pre-loading NAF upsampler model...")
        for attr, _ in cond_models:
            model = getattr(new_pipeline, attr, None)
            if model is not None and getattr(model, 'use_naf_upsample', False):
                model._load_naf()

        print("[MoGe-2] Loading model for camera estimation...")
        new_moge_model = load_moge_model(device="cuda")

        print("[EnvMap] Loading environment maps...")
        new_envmap = {name: _load_envmap(name) for name in ('forest', 'sunset', 'courtyard')}

        # Publish last; `pipeline` doubles as the "already initialised" flag.
        moge_model = new_moge_model
        envmap = new_envmap
        pipeline = new_pipeline
| # ============================================================================ | |
| # Utilities | |
| # ============================================================================ | |
def compute_f_pixels(camera_angle_x: float, resolution: int) -> float:
    """Convert a horizontal field of view into a focal length in pixels.

    Evaluates ``(16 / tan(camera_angle_x / 2)) * resolution / 32`` using torch
    tensors and returns the result as a Python float. The constants 16 and 32
    presumably correspond to half and full sensor width in mm -- confirm.

    Args:
        camera_angle_x: Field of view in radians.
        resolution: Image size in pixels along the same axis.

    Returns:
        The focal length expressed in pixels.
    """
    half_angle = torch.tensor(camera_angle_x / 2.0)
    focal = 16.0 / torch.tan(half_angle)
    return float((focal * resolution / 32.0).item())
def distance_from_fov(camera_angle_x, grid_point, target_point, mesh_scale, image_resolution):
    """Compute the camera distance at which *grid_point* projects onto a pixel column.

    The grid point is rotated by a fixed axis-swap matrix, divided by
    ``mesh_scale`` and by 2, and the distance is then
    ``f_pixels * x / (target_x - image_resolution / 2) - y`` using the rotated
    point's first two components.

    Only the x coordinate of *target_point* is used. The previous
    implementation also derived a y offset and read the rotated z component
    but never used either; that dead code has been removed.

    Args:
        camera_angle_x: Field of view in radians, forwarded to
            ``compute_f_pixels``.
        grid_point: Tensor with three components.
        target_point: Tensor whose first element is the target pixel x
            coordinate.
        mesh_scale: Divisor applied to the rotated grid point.
        image_resolution: Image size in pixels.

    Returns:
        A dict with ``"distance_from_x"`` and ``"f_pixels"`` as floats.

    Raises:
        ZeroDivisionError: If the target x coordinate equals
            ``image_resolution / 2``.
    """
    rotation_matrix = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
    gp = grid_point.to(torch.float32) @ rotation_matrix.T
    gp = gp / mesh_scale / 2
    xw, yw = gp[0].item(), gp[1].item()
    xt = float(target_point[0].item())
    f_pixels = compute_f_pixels(camera_angle_x, image_resolution)
    # Horizontal pixel offset of the target from the image centre.
    x_offset = xt - image_resolution / 2.0
    distance_x = f_pixels * xw / x_offset - yw
    return {"distance_from_x": float(distance_x), "f_pixels": float(f_pixels)}
def get_camera_params_wild_moge(image_path, device="cuda", mesh_scale=1.0, extend_pixel=0, image_resolution=512):
    """Estimate the camera field of view and distance for an image using MoGe.

    Relies on the module-level ``moge_model``, so ``init_models()`` must have
    run first.

    Args:
        image_path: Path of the image to analyse; it is converted to RGB.
        device: Device the image tensor is moved to before inference.
        mesh_scale: Forwarded to ``distance_from_fov`` and echoed in the result.
        extend_pixel: Number of pixels by which the target point is pushed
            outside the ``[0, image_resolution - 1]`` range.
        image_resolution: Resolution used for the distance computation.

    Returns:
        A dict with ``'camera_angle_x'``, ``'distance'`` and ``'mesh_scale'``.
    """
    # The context manager closes the file handle; convert() returns a separate
    # image object that stays usable afterwards.
    with Image.open(image_path) as source_image:
        pil_image = source_image.convert("RGB")
    width = pil_image.size[0]

    # Scale pixel values to [0, 1] and reorder HWC -> CHW for the model.
    image_np = np.array(pil_image).astype(np.float32) / 255.0
    image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).to(device)
    with torch.no_grad():
        output = moge_model.infer(image_tensor)

    intrinsics = output["intrinsics"].squeeze().cpu().numpy()
    # The code treats intrinsics[0, 0] as a focal length normalised by width.
    fx_normalized = intrinsics[0, 0]
    fx = fx_normalized * width
    camera_angle_x = 2 * math.atan(width / (2 * fx))

    grid_point = torch.tensor([-1.0, 0.0, 0.0])
    distance = distance_from_fov(
        camera_angle_x, grid_point,
        torch.tensor([0 - extend_pixel, image_resolution - 1 + extend_pixel]),
        mesh_scale, image_resolution
    )["distance_from_x"]
    return {'camera_angle_x': camera_angle_x, 'distance': distance, 'mesh_scale': mesh_scale}
def pack_state(shape_slat, tex_slat, res):
    """Persist the latents of a generation run to a compressed ``.npz`` file.

    Args:
        shape_slat: Object exposing ``feats`` and ``coords`` tensors.
        tex_slat: Object exposing a ``feats`` tensor.
        res: Resolution value stored alongside the arrays.

    Returns:
        Path of the written file inside ``TMP_DIR``.
    """
    import uuid

    state_data = {
        'shape_slat_feats': shape_slat.feats.cpu().numpy(),
        'tex_slat_feats': tex_slat.feats.cpu().numpy(),
        'coords': shape_slat.coords.cpu().numpy(),
        'res': res,
    }
    # A millisecond timestamp alone can collide when two requests finish in
    # the same millisecond; the random suffix keeps every file name unique.
    state_path = os.path.join(TMP_DIR, f"state_{int(time.time()*1000)}_{uuid.uuid4().hex}.npz")
    np.savez_compressed(state_path, **state_data)
    return state_path
def unpack_state(state_path):
    """Load latents written by ``pack_state`` and rebuild them on the GPU.

    Args:
        state_path: Path of an ``.npz`` file containing the keys
            ``shape_slat_feats``, ``tex_slat_feats``, ``coords`` and ``res``.

    Returns:
        A tuple ``(shape_slat, tex_slat, res)`` where ``tex_slat`` is derived
        from ``shape_slat`` via ``replace()`` and ``res`` is an ``int``.
    """
    # np.load returns an NpzFile that keeps the archive open; read everything
    # inside the context manager so the handle is released promptly.
    with np.load(state_path) as data:
        shape_feats = torch.from_numpy(data['shape_slat_feats'])
        tex_feats = torch.from_numpy(data['tex_slat_feats'])
        coords = torch.from_numpy(data['coords'])
        res = int(data['res'])

    shape_slat = SparseTensor(
        feats=shape_feats.cuda(),
        coords=coords.cuda(),
    )
    tex_slat = shape_slat.replace(tex_feats.cuda())
    return shape_slat, tex_slat, res
| # ============================================================================ | |
| # API Implementation | |
| # ============================================================================ | |
# Server instance for this application; the static mounts and the launch call
# at the bottom of the file operate on it.
app = Server()
async def homepage():
    """Return ``index.html`` (located next to this file) as an HTML response."""
    base_dir = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(base_dir, "index.html"), "r", encoding="utf-8") as page:
        markup = page.read()
    return HTMLResponse(content=markup)
def preprocess(image: FileData) -> FileData:
    """Run the pipeline's image preprocessing and save the result as a PNG.

    Args:
        image: Uploaded file descriptor; its ``"path"`` entry is opened.

    Returns:
        ``FileData`` pointing at the preprocessed PNG inside ``TMP_DIR``.
    """
    import uuid

    init_models()
    with Image.open(image["path"]) as img:
        # Pull the pixel data into memory so the result never depends on the
        # file handle that is closed when this block exits.
        img.load()
        processed = pipeline.preprocess_image(img)
    # The random suffix avoids name clashes between requests that land in the
    # same millisecond.
    out_path = os.path.join(TMP_DIR, f"preprocessed_{int(time.time()*1000)}_{uuid.uuid4().hex}.png")
    processed.save(out_path)
    return FileData(path=out_path)
def _save_render_frames(renders, tag):
    """Write every rendered frame to ``TMP_DIR`` as a JPEG (quality 85).

    Args:
        renders: Mapping of mode key to a sequence of frames accepted by
            ``Image.fromarray``.
        tag: Per-request suffix that keeps file names unique.

    Returns:
        Dict mapping each mode key to a list of ``FileData`` with absolute paths.
    """
    render_files = {}
    for mode_key, frames in renders.items():
        mode_files = []
        for i, frame in enumerate(frames):
            p = os.path.abspath(os.path.join(TMP_DIR, f"render_{mode_key}_{i}_{tag}.jpg"))
            Image.fromarray(frame).save(p, quality=85)
            mode_files.append(FileData(path=p))
        render_files[mode_key] = mode_files
    return render_files


def generate_3d(
    image: FileData,
    seed: int,
    resolution: int,
    ss_guidance_strength: float = 7.5,
    ss_guidance_rescale: float = 0.7,
    ss_sampling_steps: int = 12,
    ss_rescale_t: float = 5.0,
    shape_slat_guidance_strength: float = 7.5,
    shape_slat_guidance_rescale: float = 0.5,
    shape_slat_sampling_steps: int = 12,
    shape_slat_rescale_t: float = 3.0,
    tex_slat_guidance_strength: float = 1.0,
    tex_slat_guidance_rescale: float = 0.0,
    tex_slat_sampling_steps: int = 12,
    tex_slat_rescale_t: float = 3.0,
) -> Dict:
    """Generate a 3D asset from an image and render preview frames.

    Steps: preprocess the image, estimate camera parameters with MoGe, run the
    pipeline in ``"<resolution>_cascade"`` mode, store the latents with
    ``pack_state`` and render ``STEPS`` preview frames per mode.

    Args:
        image: Uploaded file descriptor; its ``"path"`` entry is opened.
        seed: Seed applied via ``torch.manual_seed`` and passed to the pipeline.
        resolution: High-resolution setting used to build the pipeline type.
        ss_*: Sampler overrides for the sparse-structure stage.
        shape_slat_*: Sampler overrides for the shape latent stage.
        tex_slat_*: Sampler overrides for the texture latent stage.

    Returns:
        Dict with ``"render_paths"`` (mode key -> list of ``FileData``) and
        ``"state_path"`` (absolute path of the saved latents).
    """
    import uuid

    init_models()
    torch.manual_seed(seed)
    hr_resolution = int(resolution)

    # One unique tag per request: a fixed or timestamp-only file name would be
    # shared or collide when several requests run at the same time.
    run_tag = f"{int(time.time()*1000)}_{uuid.uuid4().hex}"

    with Image.open(image["path"]) as img:
        # Pull the pixel data into memory so the result never depends on the
        # file handle that is closed when this block exits.
        img.load()
        image_preprocessed = pipeline.preprocess_image(img)

    # The camera estimator takes a path, so the preprocessed image goes through
    # a per-request scratch file that is removed again afterwards.
    temp_processed_path = os.path.join(TMP_DIR, f"temp_proc_{run_tag}.png")
    try:
        image_preprocessed.save(temp_processed_path)
        camera_params = get_camera_params_wild_moge(
            temp_processed_path, device="cuda",
            mesh_scale=WILD_MESH_SCALE, extend_pixel=WILD_EXTEND_PIXEL,
            image_resolution=WILD_IMAGE_RESOLUTION,
        )
    finally:
        if os.path.exists(temp_processed_path):
            os.remove(temp_processed_path)

    ss_sampler_override = {"steps": ss_sampling_steps, "guidance_strength": ss_guidance_strength,
                           "guidance_rescale": ss_guidance_rescale, "rescale_t": ss_rescale_t}
    shape_sampler_override = {"steps": shape_slat_sampling_steps, "guidance_strength": shape_slat_guidance_strength,
                              "guidance_rescale": shape_slat_guidance_rescale, "rescale_t": shape_slat_rescale_t}
    tex_sampler_override = {"steps": tex_slat_sampling_steps, "guidance_strength": tex_slat_guidance_strength,
                            "guidance_rescale": tex_slat_guidance_rescale, "rescale_t": tex_slat_rescale_t}

    pipeline_type = f"{hr_resolution}_cascade"
    mesh_list, (shape_slat, tex_slat, res) = pipeline.run(
        image_preprocessed,
        camera_params=camera_params,
        seed=seed,
        sparse_structure_sampler_params=ss_sampler_override,
        shape_slat_sampler_params=shape_sampler_override,
        tex_slat_sampler_params=tex_sampler_override,
        preprocess_image=False,  # already preprocessed above
        return_latent=True,
        pipeline_type=pipeline_type,
        max_num_tokens=CASCADE_MAX_NUM_TOKENS,
    )
    mesh = mesh_list[0]

    # Latents are saved before the mesh is simplified for preview rendering.
    state_path = pack_state(shape_slat, tex_slat, res)
    mesh.simplify(16777216)  # 2**24; presumably a face-count target -- confirm

    renders = render_utils.render_proj_aligned_video(
        mesh, camera_angle_x=camera_params['camera_angle_x'],
        distance=camera_params['distance'], resolution=1024,
        num_frames=STEPS, envmap=envmap,
    )

    # Save renders and return paths
    return {
        "render_paths": _save_render_frames(renders, run_tag),
        "state_path": os.path.abspath(state_path)
    }
def extract_glb_api(state_path: str, decimation_target: int, texture_size: int) -> FileData:
    """Decode saved latents into a mesh and export it as a GLB file.

    Args:
        state_path: Path of a state file produced by ``generate_3d``. It must
            resolve to a location inside ``TMP_DIR``.
        decimation_target: Forwarded to ``o_voxel.postprocess.to_glb``.
        texture_size: Forwarded to ``o_voxel.postprocess.to_glb``.

    Returns:
        ``FileData`` pointing at the exported ``.glb`` inside ``TMP_DIR``.

    Raises:
        ValueError: If ``state_path`` resolves outside ``TMP_DIR``.
    """
    import uuid

    init_models()

    # state_path comes from the API caller: refuse anything that does not
    # resolve into the directory where generate_3d writes its state files, so
    # the endpoint cannot be pointed at arbitrary files on disk.
    tmp_root = os.path.realpath(TMP_DIR)
    resolved_state = os.path.realpath(state_path)
    if os.path.commonpath([tmp_root, resolved_state]) != tmp_root:
        raise ValueError("state_path must point to a file inside the application's tmp directory")

    shape_slat, tex_slat, res = unpack_state(resolved_state)
    mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
    glb = o_voxel.postprocess.to_glb(
        vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs,
        coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout,
        grid_size=res, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=decimation_target, texture_size=texture_size,
        remesh=True, remesh_band=1, remesh_project=0, use_tqdm=True,
    )
    # Axis remap applied before export: (x, y, z) -> (-x, -z, -y).
    rot = np.array([
        [-1,  0,  0, 0],
        [ 0,  0, -1, 0],
        [ 0, -1,  0, 0],
        [ 0,  0,  0, 1],
    ], dtype=np.float64)
    glb.apply_transform(rot)
    # The random suffix avoids name clashes between requests that land in the
    # same millisecond.
    out_glb = os.path.join(TMP_DIR, f"result_{int(time.time()*1000)}_{uuid.uuid4().hex}.glb")
    glb.export(out_glb, extension_webp=True)
    return FileData(path=out_glb)
# Expose the bundled assets and the scratch directory as static files so they
# can be fetched directly by URL.
for _route, _directory, _mount_name in (
    ("/assets", "assets", "assets"),
    ("/tmp", TMP_DIR, "tmp"),
):
    app.mount(_route, StaticFiles(directory=_directory), name=_mount_name)
if __name__ == "__main__":
    import sys

    # Re-install utils3d as in original app.py.
    # Invoke pip through the running interpreter so the wheel is installed
    # into the environment this process uses, not whichever `pip` is first
    # on PATH.
    # NOTE(review): the imports at the top of the file have already executed
    # by now; if any of them imported utils3d, the reinstalled copy may only
    # take effect in a fresh process -- confirm.
    subprocess.run([
        sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps",
        "https://github.com/LDYang694/Storages/releases/download/20260430/utils3d-0.0.2-py3-none-any.whl"
    ], check=True)
    # Pre-initialize models before launching the server
    init_models()
    app.launch(show_error=True, share=True)