# Pixal3D / app_bak.py
# Last commit by Yang2001 (f7a2756):
#   chore: update app.py, add app_bak.py, update requirements and autotune cache
import os
import subprocess
import argparse
import math
import time
import shutil
import cv2
import torch
import numpy as np
import base64
import io
import json
from datetime import datetime
from typing import *
from PIL import Image
import threading
# nest_asyncio is optional; when it is installed it patches asyncio so that an
# already-running event loop can be re-entered.
try:
    import nest_asyncio
except ImportError:
    pass
else:
    nest_asyncio.apply()

# Lock for model initialization (held by init_models).
init_lock = threading.Lock()

# Set before the third-party imports that follow, presumably because those
# libraries read these variables at import time. OPENCV_IO_ENABLE_OPENEXR is
# what lets cv2 decode the .exr environment maps.
os.environ.update({
    'OPENCV_IO_ENABLE_OPENEXR': '1',
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    "ATTN_BACKEND": "flash_attn_3",
    "FLEX_GEMM_AUTOTUNE_CACHE_PATH": os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json'
    ),
    "FLEX_GEMM_AUTOTUNER_VERBOSE": '1',
})
import spaces
from gradio import Server
from gradio.data_classes import FileData
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Pixal3DImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils
import o_voxel
# ============================================================================
# Constants & Defaults
# ============================================================================
# Largest signed 32-bit integer (presumably the upper bound for user seeds).
MAX_SEED = np.iinfo(np.int32).max

# Scratch directory next to this file for generated images, states and meshes.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)

# Preview render modes: (display name, icon path, key into the render output).
MODES = [
    {"name": mode_name, "icon": icon_path, "render_key": key}
    for mode_name, icon_path, key in (
        ("Normal", "assets/app/normal.png", "normal"),
        ("Clay render", "assets/app/clay.png", "clay"),
        ("Base color", "assets/app/basecolor.png", "base_color"),
        ("HDRI forest", "assets/app/hdri_forest.png", "shaded_forest"),
        ("HDRI sunset", "assets/app/hdri_sunset.png", "shaded_sunset"),
        ("HDRI courtyard", "assets/app/hdri_courtyard.png", "shaded_courtyard"),
    )
]

# Number of preview frames rendered per generation (num_frames in generate_3d).
STEPS = 8
# Cascade parameters
CASCADE_LR_RESOLUTION = 512
CASCADE_MAX_NUM_TOKENS = 48 * 1024  # 49152; passed as max_num_tokens to pipeline.run

# MoGe defaults (used for in-the-wild camera estimation)
MOGE_MODEL_NAME = "Ruicheng/moge-2-vitl"
WILD_MESH_SCALE = 1.0
WILD_EXTEND_PIXEL = 0
WILD_IMAGE_RESOLUTION = 512

# Image Cond Model configs: all stages share one DINOv3 checkpoint and differ in
# input size, feature grid resolution and optional NAF upsampling.
_DINOV3_CHECKPOINT = "camenduru/dinov3-vitl16-pretrain-lvd1689m"


def _dinov3_config(image_size, grid_resolution, naf_target_size=None):
    """Build one DinoV3ProjFeatureExtractor kwargs dict.

    NAF upsampling keys are only added when *naf_target_size* is given.
    """
    cfg = {
        "model_name": _DINOV3_CHECKPOINT,
        "image_size": image_size,
        "grid_resolution": grid_resolution,
    }
    if naf_target_size is not None:
        cfg["use_naf_upsample"] = True
        cfg["naf_target_size"] = naf_target_size
    return cfg


IMAGE_COND_CONFIGS = {
    "ss": _dinov3_config(512, 16),
    "shape_512": _dinov3_config(512, 32, naf_target_size=512),
    "shape_1024": _dinov3_config(1024, 64, naf_target_size=512),
    "tex_1024": _dinov3_config(1024, 64, naf_target_size=1024),
}
# ============================================================================
# Model Loading
# ============================================================================
def build_image_cond_model(config: dict):
    """Create a DinoV3ProjFeatureExtractor from *config* and switch it to eval mode.

    Args:
        config: keyword arguments for the extractor (see IMAGE_COND_CONFIGS).

    Returns:
        The extractor instance, in eval mode.
    """
    # Function-scope import: the trainer module is only needed when building.
    from trellis2.trainers.flow_matching.mixins.image_conditioned_proj import (
        DinoV3ProjFeatureExtractor,
    )

    extractor = DinoV3ProjFeatureExtractor(**config)
    extractor.eval()
    return extractor
def load_moge_model(device="cuda", model_name=MOGE_MODEL_NAME):
    """Load the MoGe-2 model *model_name*, move it to *device* and put it in eval mode."""
    from moge.model.v2 import MoGeModel

    model = MoGeModel.from_pretrained(model_name)
    model = model.to(device)
    model.eval()
    return model
# Global instances (lazy loaded or loaded at start)
pipeline = None
moge_model = None
envmap = None


def _print_gpu_diagnostics():
    """Print PyTorch / CUDA / GPU details and an nvidia-smi summary to stdout."""
    print("=" * 60)
    print("[Diagnostics] PyTorch version:", torch.__version__)
    print("[Diagnostics] CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("[Diagnostics] CUDA version:", torch.version.cuda)
        print("[Diagnostics] cuDNN version:", torch.backends.cudnn.version())
        for i in range(torch.cuda.device_count()):
            name = torch.cuda.get_device_name(i)
            cap = torch.cuda.get_device_capability(i)
            mem = torch.cuda.get_device_properties(i).total_memory / 1024**3
            print(f"[Diagnostics] GPU {i}: {name}, sm_{cap[0]}{cap[1]}, {mem:.1f} GB")
    try:
        res = subprocess.run(
            ["nvidia-smi", "--query-gpu=name,compute_cap,memory.total", "--format=csv,noheader"],
            capture_output=True, text=True, timeout=10,
        )
        print("[Diagnostics] nvidia-smi:", res.stdout.strip())
    except (OSError, subprocess.SubprocessError) as e:
        # Diagnostics only: a missing binary or a timeout must not block loading.
        print(f"[Diagnostics] nvidia-smi failed: {e}")
    print("=" * 60)


def _load_envmap(path):
    """Read the HDR image at *path*, convert BGR->RGB and wrap it as a CUDA EnvMap.

    Raises:
        FileNotFoundError: if OpenCV cannot read *path*.
    """
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read environment map: {path}")
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return EnvMap(torch.tensor(rgb, dtype=torch.float32, device='cuda'))


def init_models():
    """Load the pipeline, the MoGe-2 model and the HDRI environment maps once.

    Every endpoint calls this. The first caller does the work while holding
    ``init_lock``; later callers see ``pipeline`` set and return immediately.
    The globals are assigned only after everything has loaded. If loading
    fails part-way, they all stay ``None`` and the next call retries instead
    of running with a half-initialised state.
    """
    global pipeline, moge_model, envmap
    with init_lock:
        if pipeline is not None:
            return

        # GPU / CUDA Diagnostics (runs when GPU is allocated)
        _print_gpu_diagnostics()

        model_path = "TencentARC/Pixal3D-T"
        print(f"[Pipeline] Loading from {model_path}...")
        new_pipeline = Pixal3DImageTo3DPipeline.from_pretrained(model_path)

        # Pipeline attribute -> IMAGE_COND_CONFIGS key.
        cond_models = {
            'image_cond_model_ss': "ss",
            'image_cond_model_shape_512': "shape_512",
            'image_cond_model_shape_1024': "shape_1024",
            'image_cond_model_tex_1024': "tex_1024",
        }
        print("[ImageCond] Building DinoV3ProjFeatureExtractor models...")
        for attr, config_key in cond_models.items():
            setattr(new_pipeline, attr, build_image_cond_model(IMAGE_COND_CONFIGS[config_key]))
        new_pipeline.low_vram = False
        new_pipeline.cuda()
        # Ensure image_cond_models are on GPU
        for attr in cond_models:
            getattr(new_pipeline, attr).cuda()

        print("[NAF] Pre-loading NAF upsampler model...")
        for attr in cond_models:
            cond_model = getattr(new_pipeline, attr, None)
            if cond_model is not None and getattr(cond_model, 'use_naf_upsample', False):
                cond_model._load_naf()

        print("[MoGe-2] Loading model for camera estimation...")
        new_moge_model = load_moge_model(device="cuda")

        print("[EnvMap] Loading environment maps...")
        base_dir = os.path.dirname(os.path.abspath(__file__))
        new_envmap = {
            name: _load_envmap(os.path.join(base_dir, f'assets/hdri/{name}.exr'))
            for name in ('forest', 'sunset', 'courtyard')
        }

        # Publish only after every model loaded; ``pipeline`` goes last because
        # it is the "already initialised" flag checked above.
        moge_model = new_moge_model
        envmap = new_envmap
        pipeline = new_pipeline
# ============================================================================
# Utilities
# ============================================================================
def compute_f_pixels(camera_angle_x: float, resolution: int) -> float:
    """Return the focal length in pixels for a horizontal field of view.

    Models a sensor 32 units wide: the focal length is ``16 / tan(fov / 2)`` in
    sensor units, scaled by ``resolution / 32`` into pixels. This equals
    ``resolution / (2 * tan(camera_angle_x / 2))``.

    Args:
        camera_angle_x: horizontal field of view in radians, in (0, pi).
        resolution: image width in pixels.

    Returns:
        Focal length in pixels, as a Python float.

    Raises:
        ValueError: if camera_angle_x is not strictly between 0 and pi, where
            the focal length is infinite or negative.
    """
    if not 0.0 < camera_angle_x < math.pi:
        raise ValueError(f"camera_angle_x must be in (0, pi), got {camera_angle_x!r}")
    # Plain float64 math: routing a scalar through torch.tensor would silently
    # round it (and the result) to float32.
    focal_length = 16.0 / math.tan(camera_angle_x / 2.0)
    return float(focal_length * resolution / 32.0)
def distance_from_fov(camera_angle_x, grid_point, target_point, mesh_scale, image_resolution):
    """Solve for the camera distance that places *grid_point* on a target pixel column.

    *grid_point* (a 3-element tensor) is rotated (x, y, z) -> (x, -z, y) and
    scaled by ``1 / (2 * mesh_scale)``, giving (xw, yw, ...). With ``x_c`` the
    horizontal offset of ``target_point[0]`` from the image centre
    (``image_resolution / 2``), the returned distance ``d`` solves
    ``x_c = f_pixels * xw / (d + yw)``.

    Args:
        camera_angle_x: horizontal field of view in radians.
        grid_point: 3-element tensor (only its x/y after rotation are used).
        target_point: tensor whose element 0 is the target pixel column.
        mesh_scale: mesh normalisation scale.
        image_resolution: image width in pixels.

    Returns:
        ``{"distance_from_x": float, "f_pixels": float}``.

    Raises:
        ZeroDivisionError: if the target column is exactly the image centre.
    """
    rotation_matrix = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
    gp = grid_point.to(torch.float32) @ rotation_matrix.T
    gp = gp / mesh_scale / 2
    xw, yw = gp[0].item(), gp[1].item()
    f_pixels = compute_f_pixels(camera_angle_x, image_resolution)
    # Horizontal offset of the target pixel from the image centre.
    x_centered = float(target_point[0].item()) - image_resolution / 2.0
    distance_x = f_pixels * xw / x_centered - yw
    return {"distance_from_x": float(distance_x), "f_pixels": float(f_pixels)}
def get_camera_params_wild_moge(image_path, device="cuda", mesh_scale=1.0, extend_pixel=0, image_resolution=512):
    """Estimate camera parameters for an in-the-wild image with the global MoGe-2 model.

    The horizontal FOV comes from MoGe's normalized intrinsics
    (``fx = intrinsics[0, 0] * width``). The distance is the one at which grid
    point (-1, 0, 0) projects to pixel column ``-extend_pixel`` (see
    distance_from_fov).

    Args:
        image_path: path of the image to analyse (converted to RGB).
        device: device the image tensor is moved to for MoGe inference.
        mesh_scale: mesh normalisation scale passed to distance_from_fov.
        extend_pixel: pixel margin added beyond the image border.
        image_resolution: image width in pixels used for the distance solve.

    Returns:
        dict with ``camera_angle_x`` (radians), ``distance`` and ``mesh_scale``.
    """
    # convert() returns a new image, so the source file can be closed right away.
    with Image.open(image_path) as src:
        pil_image = src.convert("RGB")
    width = pil_image.width
    # (3, H, W) float tensor with values in [0, 1].
    image_np = np.array(pil_image).astype(np.float32) / 255.0
    image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).to(device)
    with torch.no_grad():
        output = moge_model.infer(image_tensor)
    intrinsics = output["intrinsics"].squeeze().cpu().numpy()
    fx = intrinsics[0, 0] * width
    camera_angle_x = 2 * math.atan(width / (2 * fx))
    grid_point = torch.tensor([-1.0, 0.0, 0.0])
    target_point = torch.tensor([0 - extend_pixel, image_resolution - 1 + extend_pixel])
    distance = distance_from_fov(
        camera_angle_x, grid_point, target_point, mesh_scale, image_resolution,
    )["distance_from_x"]
    return {'camera_angle_x': camera_angle_x, 'distance': distance, 'mesh_scale': mesh_scale}
def pack_state(shape_slat, tex_slat, res, out_dir=None):
    """Save the latents needed by ``unpack_state`` to a compressed ``.npz`` file.

    Args:
        shape_slat: sparse tensor with ``feats`` and ``coords``. Only its coords
            are stored; tex_slat is assumed to share them.
        tex_slat: sparse tensor whose ``feats`` are stored alongside.
        res: grid resolution, stored as-is.
        out_dir: target directory; defaults to ``TMP_DIR``.

    Returns:
        Path of the written file, named ``state_<uuid hex>.npz``.
    """
    import uuid

    state_data = {
        'shape_slat_feats': shape_slat.feats.cpu().numpy(),
        'tex_slat_feats': tex_slat.feats.cpu().numpy(),
        'coords': shape_slat.coords.cpu().numpy(),
        'res': res,
    }
    directory = TMP_DIR if out_dir is None else out_dir
    # uuid4 avoids the collisions that a millisecond timestamp plus a 4-digit
    # random suffix allows under concurrent requests.
    state_path = os.path.join(directory, f"state_{uuid.uuid4().hex}.npz")
    np.savez_compressed(state_path, **state_data)
    return state_path
def unpack_state(state_path):
    """Load latents written by ``pack_state`` back onto the GPU.

    Args:
        state_path: path of a ``state_*.npz`` file. API clients send it back
            verbatim, so it must resolve to a location inside ``TMP_DIR``.

    Returns:
        ``(shape_slat, tex_slat, res)``: two CUDA SparseTensors sharing coords,
        and the grid resolution as an int.

    Raises:
        ValueError: if *state_path* resolves outside ``TMP_DIR``.
    """
    resolved = os.path.realpath(state_path)
    tmp_root = os.path.realpath(TMP_DIR)
    if os.path.commonpath([resolved, tmp_root]) != tmp_root:
        raise ValueError("state_path must point inside the server's temporary directory")
    # NpzFile keeps the archive open; the context manager closes it.
    # allow_pickle stays at numpy's default (False).
    with np.load(resolved) as data:
        shape_feats = torch.from_numpy(data['shape_slat_feats']).cuda()
        coords = torch.from_numpy(data['coords']).cuda()
        tex_feats = torch.from_numpy(data['tex_slat_feats']).cuda()
        res = int(data['res'])
    shape_slat = SparseTensor(feats=shape_feats, coords=coords)
    tex_slat = shape_slat.replace(tex_feats)
    return shape_slat, tex_slat, res
# ============================================================================
# Progress Tracking (file-based, cross-process safe for @spaces.GPU)
# ============================================================================
import asyncio
from fastapi.responses import JSONResponse
from fastapi import Request
# One "<session_id>.json" file per session lives here.
PROGRESS_DIR = os.path.join(TMP_DIR, "_progress")
os.makedirs(PROGRESS_DIR, exist_ok=True)

# Per-thread slot holding the session id that tqdm updates are reported to.
_thread_local = threading.local()
def _progress_file(session_id: str) -> str:
    """Return the path of a session's progress JSON file inside PROGRESS_DIR.

    session_id comes straight from clients (the ``/progress`` query string and
    API arguments). Every character outside ``[A-Za-z0-9_-]`` is replaced with
    ``_``, so ids such as ``../../x`` or ``/etc/x`` cannot escape PROGRESS_DIR.
    Ids made only of those characters (e.g. UUID strings) map to the same file
    as before.
    """
    import re

    safe_id = re.sub(r'[^A-Za-z0-9_-]', '_', str(session_id))
    return os.path.join(PROGRESS_DIR, f"{safe_id}.json")
def _reset_progress(session_id: str):
    """Bind *session_id* to the current thread and write its initial progress."""
    initial_state = {"stage": "Initializing...", "step": 0, "total": 0, "done": False}
    _thread_local.active_session = session_id
    _write_progress_file(session_id, initial_state)
def _update_progress(stage: str, step: int, total: int):
    """Record *stage* at *step*/*total* for this thread's session; no-op without one."""
    active = getattr(_thread_local, 'active_session', '')
    if not active:
        return
    _write_progress_file(active, {"stage": stage, "step": step, "total": total, "done": False})
def _finish_progress():
    """Mark this thread's session as done, then unbind the session from the thread.

    Unbinding stops later progress updates on the same thread from rewriting
    the finished file with ``"done": False``. That can happen, for example, if
    the thread later runs a request that uses tqdm without calling
    ``_reset_progress``.
    """
    session_id = getattr(_thread_local, 'active_session', '')
    if session_id:
        _write_progress_file(session_id, {"done": True})
    _thread_local.active_session = ''
def _write_progress_file(session_id: str, data: dict):
    """Atomically replace a session's progress file with *data* serialised as JSON.

    Best effort: I/O and serialisation errors are swallowed so that progress
    reporting can never fail a generation request. On failure the temporary
    file is removed.
    """
    path = _progress_file(session_id)
    # Writer-unique temp name: concurrent writers sharing one ".tmp" could
    # interleave and publish a corrupt file through os.replace.
    tmp_path = f"{path}.{os.getpid()}.{threading.get_ident()}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f)
        os.replace(tmp_path, path)  # atomic on POSIX
    except (OSError, TypeError, ValueError):
        try:
            os.remove(tmp_path)
        except OSError:
            pass
# Monkey-patch tqdm to intercept progress
import tqdm as _tqdm_module
_original_tqdm = _tqdm_module.tqdm


class _TqdmProgressInterceptor(_original_tqdm):
    """tqdm subclass that mirrors progress into the session progress file.

    On every ``update`` it sends the bar's description, position (``n``) and
    ``total`` (0 when unknown) to ``_update_progress``. That function writes
    the file served by the ``/progress`` polling endpoint.
    """

    def __init__(self, *args, **kwargs):
        # tqdm's signature is tqdm(iterable=None, desc=None, ...), so desc can
        # also arrive positionally. Fall back to 'Processing' when it is
        # missing or empty, as set_description does.
        desc = kwargs.get('desc')
        if desc is None and len(args) > 1:
            desc = args[1]
        self._stage_desc = desc or 'Processing'
        super().__init__(*args, **kwargs)

    def set_description(self, desc=None, refresh=True):
        # Keep the plain description for progress reporting.
        self._stage_desc = desc or 'Processing'
        super().set_description(desc, refresh)

    def update(self, n=1):
        super().update(n)
        _update_progress(self._stage_desc, self.n, self.total or 0)
# Patch tqdm globally
_tqdm_module.tqdm = _TqdmProgressInterceptor

# The sampler, render_utils and o_voxel postprocess modules bind tqdm through a
# direct import, so they presumably still hold the original class; patch each.
import trellis2.pipelines.samplers.flow_euler as _fe_module
import trellis2.utils.render_utils as _ru_module
import o_voxel.postprocess as _ovp_module

for _patched_module in (_fe_module, _ru_module, _ovp_module):
    _patched_module.tqdm = _TqdmProgressInterceptor
del _patched_module
# ============================================================================
# API Implementation
# ============================================================================
app = Server()


@app.get("/")
async def homepage():
    """Serve the frontend page, index.html, located next to this file."""
    here = os.path.dirname(os.path.abspath(__file__))
    index_path = os.path.join(here, "index.html")
    with open(index_path, "r", encoding="utf-8") as page:
        markup = page.read()
    return HTMLResponse(content=markup)
@app.get("/progress")
async def progress_poll(request: Request):
    """Return the latest progress JSON for the ``session_id`` query parameter.

    If the file does not exist yet or is not valid JSON, a "Waiting..."
    placeholder is returned instead.
    """
    progress_path = _progress_file(request.query_params.get("session_id", ""))
    try:
        with open(progress_path, 'r') as fh:
            payload = json.load(fh)
    except (FileNotFoundError, json.JSONDecodeError):
        payload = {"stage": "Waiting...", "step": 0, "total": 0, "done": False}
    return JSONResponse(payload)
@app.api()
@spaces.GPU(duration=30)
def preprocess(image: FileData) -> FileData:
    """Run the pipeline's image preprocessing on an upload and save the result as PNG.

    Args:
        image: uploaded file; only ``image["path"]`` is read.

    Returns:
        FileData pointing at the preprocessed PNG in TMP_DIR.
    """
    import uuid

    init_models()
    # Timestamp + random suffix so concurrent requests never share a filename.
    out_path = os.path.join(
        TMP_DIR, f"preprocessed_{int(time.time()*1000)}_{uuid.uuid4().hex[:8]}.png"
    )
    # Keep the source file open only while it is needed. Saving inside the block
    # also covers the case where preprocess_image returns the input image itself.
    with Image.open(image["path"]) as img:
        processed = pipeline.preprocess_image(img)
        processed.save(out_path)
    return FileData(path=out_path)
def _sampler_params(steps, guidance_strength, guidance_rescale, rescale_t):
    """Bundle one sampling stage's overrides into the dict passed to pipeline.run."""
    return {"steps": steps, "guidance_strength": guidance_strength,
            "guidance_rescale": guidance_rescale, "rescale_t": rescale_t}


def _save_render_frames(renders, tag):
    """Save every frame of every render mode to TMP_DIR as JPEG (quality 85).

    Args:
        renders: mapping of render mode -> sequence of frames accepted by
            ``Image.fromarray``.
        tag: per-call token embedded in the filenames to keep them unique.

    Returns:
        dict mapping each render mode to a list of FileData, in frame order.
    """
    render_files = {}
    for mode_key, frames in renders.items():
        mode_files = []
        for i, frame in enumerate(frames):
            p = os.path.abspath(os.path.join(TMP_DIR, f"render_{mode_key}_{i}_{tag}.jpg"))
            Image.fromarray(frame).save(p, quality=85)
            mode_files.append(FileData(path=p))
        render_files[mode_key] = mode_files
    return render_files


@app.api()
@spaces.GPU(duration=120)
def generate_3d(
    image: FileData,
    seed: int,
    resolution: int,
    ss_guidance_strength: float = 7.5,
    ss_guidance_rescale: float = 0.7,
    ss_sampling_steps: int = 12,
    ss_rescale_t: float = 5.0,
    shape_slat_guidance_strength: float = 7.5,
    shape_slat_guidance_rescale: float = 0.5,
    shape_slat_sampling_steps: int = 12,
    shape_slat_rescale_t: float = 3.0,
    tex_slat_guidance_strength: float = 1.0,
    tex_slat_guidance_rescale: float = 0.0,
    tex_slat_sampling_steps: int = 12,
    tex_slat_rescale_t: float = 3.0,
    session_id: str = "",
) -> Dict:
    """Generate a 3D asset from a preprocessed image and render preview views.

    The image must already have gone through the ``preprocess`` endpoint. The
    camera is estimated with MoGe-2, the ``<resolution>_cascade`` pipeline is
    run, and the latents are packed for ``extract_glb_api``. ``STEPS`` preview
    frames are then rendered for every render mode.

    Progress is reported under *session_id*. The session is always marked done
    when this function exits, even if it raises.

    Returns:
        dict with ``render_paths`` ({render mode: [FileData, ...]}),
        ``state_path`` (absolute path of the packed latents), ``camera_angle_x``
        and ``distance``.
    """
    import uuid

    init_models()
    _reset_progress(session_id)
    # Per-call token keeps temp/render filenames unique across concurrent
    # requests. session_id is reduced to alphanumerics before it touches a path.
    call_tag = f"{int(time.time()*1000)}_{uuid.uuid4().hex[:8]}"
    session_tag = "".join(ch for ch in session_id[:8] if ch.isalnum())
    temp_processed_path = os.path.join(TMP_DIR, f"temp_proc_{session_tag}_{call_tag}.png")
    try:
        _update_progress("Preprocessing & Camera Estimation", 0, 1)
        torch.manual_seed(seed)
        hr_resolution = int(resolution)
        # Image is already preprocessed by /preprocess endpoint, use directly.
        # load() reads the pixels now; Pillow then closes the file for
        # single-frame images.
        image_preprocessed = Image.open(image["path"])
        image_preprocessed.load()
        # Camera estimation reads from a path, so write the image out once.
        image_preprocessed.save(temp_processed_path)
        camera_params = get_camera_params_wild_moge(
            temp_processed_path, device="cuda",
            mesh_scale=WILD_MESH_SCALE, extend_pixel=WILD_EXTEND_PIXEL,
            image_resolution=WILD_IMAGE_RESOLUTION,
        )
        _update_progress("Preprocessing & Camera Estimation", 1, 1)

        mesh_list, (shape_slat, tex_slat, res) = pipeline.run(
            image_preprocessed,
            camera_params=camera_params,
            seed=seed,
            sparse_structure_sampler_params=_sampler_params(
                ss_sampling_steps, ss_guidance_strength,
                ss_guidance_rescale, ss_rescale_t),
            shape_slat_sampler_params=_sampler_params(
                shape_slat_sampling_steps, shape_slat_guidance_strength,
                shape_slat_guidance_rescale, shape_slat_rescale_t),
            tex_slat_sampler_params=_sampler_params(
                tex_slat_sampling_steps, tex_slat_guidance_strength,
                tex_slat_guidance_rescale, tex_slat_rescale_t),
            preprocess_image=False,
            return_latent=True,
            pipeline_type=f"{hr_resolution}_cascade",
            max_num_tokens=CASCADE_MAX_NUM_TOKENS,
        )
        mesh = mesh_list[0]
        # Latents are persisted so extract_glb_api can decode them later.
        state_path = pack_state(shape_slat, tex_slat, res)

        _update_progress("Rendering views", 0, 1)
        mesh.simplify(16777216)  # 2**24; NOTE(review): presumably a face budget — confirm
        cam_dist = camera_params['distance']
        # Clip planes around the estimated camera distance; near stays positive.
        near = max(0.01, cam_dist - 2.0)
        far = cam_dist + 10.0
        renders = render_utils.render_proj_aligned_video(
            mesh, camera_angle_x=camera_params['camera_angle_x'],
            distance=cam_dist, resolution=1024,
            num_frames=STEPS, envmap=envmap,
            near=near, far=far,
        )
        _update_progress("Rendering views", 1, 1)
        render_files = _save_render_frames(renders, call_tag)
    finally:
        # Always release the progress poller and drop the intermediate PNG,
        # including when generation fails.
        _finish_progress()
        try:
            os.remove(temp_processed_path)
        except OSError:
            pass
    return {
        "render_paths": render_files,
        "state_path": os.path.abspath(state_path),
        "camera_angle_x": camera_params['camera_angle_x'],
        "distance": camera_params['distance'],
    }
@app.api()
@spaces.GPU(duration=240)
def extract_glb_api(state_path: str, decimation_target: int, texture_size: int, session_id: str = "") -> FileData:
    """Decode packed latents into a textured mesh and export it as GLB.

    Args:
        state_path: path returned as ``state_path`` by ``generate_3d``.
        decimation_target: passed to ``o_voxel.postprocess.to_glb``.
        texture_size: texture resolution passed to ``to_glb``.
        session_id: progress session; always marked done on exit, even on error.

    Returns:
        FileData pointing at the exported ``.glb`` in TMP_DIR.
    """
    import uuid

    init_models()
    _reset_progress(session_id)
    try:
        _update_progress("Decoding latent", 0, 1)
        shape_slat, tex_slat, res = unpack_state(state_path)
        mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
        _update_progress("Decoding latent", 1, 1)
        glb = o_voxel.postprocess.to_glb(
            vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs,
            coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout,
            grid_size=res, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
            decimation_target=decimation_target, texture_size=texture_size,
            remesh=True, remesh_band=1, remesh_project=0, use_tqdm=True,
        )
        # Axis change before export: (x, y, z) -> (-x, -z, -y).
        rot = np.array([
            [-1, 0, 0, 0],
            [ 0, 0, -1, 0],
            [ 0, -1, 0, 0],
            [ 0, 0, 0, 1],
        ], dtype=np.float64)
        glb.apply_transform(rot)
        # Timestamp + random suffix so concurrent exports never share a filename.
        out_glb = os.path.join(
            TMP_DIR, f"result_{int(time.time()*1000)}_{uuid.uuid4().hex[:8]}.glb"
        )
        glb.export(out_glb, extension_webp=True)
    finally:
        _finish_progress()
    return FileData(path=out_glb)
# Mount assets and tmp for direct access. "assets" is resolved next to this
# file, like index.html and the HDRI maps, so startup does not depend on the
# current working directory.
# NOTE(review): /tmp serves all of TMP_DIR, including state_*.npz files and
# the _progress/ JSON files — confirm that exposure is intended.
app.mount(
    "/assets",
    StaticFiles(directory=os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")),
    name="assets",
)
app.mount("/tmp", StaticFiles(directory=TMP_DIR), name="tmp")
if __name__ == "__main__":
    import sys

    # Re-install utils3d as in original app.py. Run pip through the current
    # interpreter so the wheel is installed into this environment rather than
    # whichever `pip` happens to be first on PATH.
    # NOTE(review): trellis2 / o_voxel are imported above this point. If they
    # already imported utils3d, this process keeps the previously loaded
    # version until restart — confirm that is intended.
    subprocess.run([
        sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps",
        "https://github.com/LDYang694/Storages/releases/download/20260430/utils3d-0.0.2-py3-none-any.whl",
    ], check=True)
    # Pre-initialize models before launching the server
    init_models()
    app.launch(show_error=True, share=True)