Pixal3D-T / app.py
acecalisto3's picture
Update app.py
be06730 verified
import os
import subprocess
import argparse
import math
import time
import shutil
import cv2
import torch
import numpy as np
import base64
import io
import json
from datetime import datetime
from typing import *
from PIL import Image
import os
from huggingface_hub import InferenceClient
from PIL import Image
client = InferenceClient(api_key=os.getenv("HF_TOKEN"))
image_path = "cats.jpg"
try:
output = client.image_segmentation(image_path, model="briaai/RMBG-2.0")
output.save("output_no_bg.png")
print("Success! Background removed and saved as output_no_bg.png")
except Exception as e:
print(f"Error encountered: {e}")
import threading
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
init_lock = threading.Lock()
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
import spaces
from gradio import Server
from gradio.data_classes import FileData
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Pixal3DImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils
import o_voxel
# ============================================================================
# Constants & Defaults
# ============================================================================
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
MODES = [
{"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
{"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
{"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
{"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
{"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
{"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
]
STEPS = 8
# Cascade parameters
CASCADE_LR_RESOLUTION = 512
CASCADE_MAX_NUM_TOKENS = 49152
# MoGe defaults
MOGE_MODEL_NAME = "Ruicheng/moge-2-vitl"
WILD_MESH_SCALE = 1.0
WILD_EXTEND_PIXEL = 0
WILD_IMAGE_RESOLUTION = 512
# Image Cond Model configs
IMAGE_COND_CONFIGS = {
"ss": {
"model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
"image_size": 512,
"grid_resolution": 16,
},
"shape_512": {
"model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
"image_size": 512,
"grid_resolution": 32,
"use_naf_upsample": True,
"naf_target_size": 512,
},
"shape_1024": {
"model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
"image_size": 1024,
"grid_resolution": 64,
"use_naf_upsample": True,
"naf_target_size": 512,
},
"tex_1024": {
"model_name": "camenduru/dinov3-vitl16-pretrain-lvd1689m",
"image_size": 1024,
"grid_resolution": 64,
"use_naf_upsample": True,
"naf_target_size": 1024,
},
}
# ============================================================================
# Model Loading
# ============================================================================
def build_image_cond_model(config: dict):
from trellis2.trainers.flow_matching.mixins.image_conditioned_proj import DinoV3ProjFeatureExtractor
model = DinoV3ProjFeatureExtractor(**config)
model.eval()
return model
def load_moge_model(device="cuda", model_name=MOGE_MODEL_NAME):
from moge.model.v2 import MoGeModel
moge_model = MoGeModel.from_pretrained(model_name).to(device)
moge_model.eval()
return moge_model
# Global instances (lazy loaded or loaded at start)
pipeline = None
moge_model = None
envmap = None
def init_models():
global pipeline, moge_model, envmap
with init_lock:
if pipeline is not None:
return
model_path = "TencentARC/Pixal3D-T"
print(f"[Pipeline] Loading from {model_path}...")
pipeline = Pixal3DImageTo3DPipeline.from_pretrained(model_path)
print("[ImageCond] Building DinoV3ProjFeatureExtractor models...")
pipeline.image_cond_model_ss = build_image_cond_model(IMAGE_COND_CONFIGS["ss"])
pipeline.image_cond_model_shape_512 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_512"])
pipeline.image_cond_model_shape_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_1024"])
pipeline.image_cond_model_tex_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["tex_1024"])
pipeline.cuda()
print("[NAF] Pre-loading NAF upsampler model...")
for attr in ['image_cond_model_ss', 'image_cond_model_shape_512', 'image_cond_model_shape_1024', 'image_cond_model_tex_1024']:
model = getattr(pipeline, attr, None)
if model is not None and getattr(model, 'use_naf_upsample', False):
model._load_naf()
print("[MoGe-2] Loading model for camera estimation...")
moge_model = load_moge_model(device="cuda")
print("[EnvMap] Loading environment maps...")
_base = os.path.dirname(os.path.abspath(__file__))
envmap = {
'forest': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread(os.path.join(_base, 'assets/hdri/forest.exr'), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
'sunset': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread(os.path.join(_base, 'assets/hdri/sunset.exr'), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
'courtyard': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread(os.path.join(_base, 'assets/hdri/courtyard.exr'), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
}
# ============================================================================
# Utilities
# ============================================================================
def compute_f_pixels(camera_angle_x: float, resolution: int) -> float:
focal_length = 16.0 / torch.tan(torch.tensor(camera_angle_x / 2.0))
f_pixels = focal_length * resolution / 32.0
return float(f_pixels.item())
def distance_from_fov(camera_angle_x, grid_point, target_point, mesh_scale, image_resolution):
rotation_matrix = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
gp = grid_point.to(torch.float32) @ rotation_matrix.T
gp = gp / mesh_scale / 2
xw, yw, zw = gp[0].item(), gp[1].item(), gp[2].item()
xt, yt = float(target_point[0].item()), float(target_point[1].item())
f_pixels = compute_f_pixels(camera_angle_x, image_resolution)
x_ndc = xt - image_resolution / 2.0
y_ndc = -(yt - image_resolution / 2.0)
distance_x = f_pixels * xw / x_ndc - yw
return {"distance_from_x": float(distance_x), "f_pixels": float(f_pixels)}
def get_camera_params_wild_moge(image_path, device="cuda", mesh_scale=1.0, extend_pixel=0, image_resolution=512):
pil_image = Image.open(image_path).convert("RGB")
width, height = pil_image.size
image_np = np.array(pil_image).astype(np.float32) / 255.0
image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).to(device)
with torch.no_grad():
output = moge_model.infer(image_tensor)
intrinsics = output["intrinsics"].squeeze().cpu().numpy()
fx_normalized = intrinsics[0, 0]
fx = fx_normalized * width
camera_angle_x = 2 * math.atan(width / (2 * fx))
grid_point = torch.tensor([-1.0, 0.0, 0.0])
distance = distance_from_fov(
camera_angle_x, grid_point,
torch.tensor([0 - extend_pixel, image_resolution - 1 + extend_pixel]),
mesh_scale, image_resolution
)["distance_from_x"]
return {'camera_angle_x': camera_angle_x, 'distance': distance, 'mesh_scale': mesh_scale}
def pack_state(shape_slat, tex_slat, res):
state_data = {
'shape_slat_feats': shape_slat.feats.cpu().numpy(),
'tex_slat_feats': tex_slat.feats.cpu().numpy(),
'coords': shape_slat.coords.cpu().numpy(),
'res': res,
}
import random
state_path = os.path.join(TMP_DIR, f"state_{int(time.time()*1000)}_{random.randint(0,9999):04d}.npz")
np.savez_compressed(state_path, **state_data)
return state_path
def unpack_state(state_path):
data = np.load(state_path)
shape_slat = SparseTensor(
feats=torch.from_numpy(data['shape_slat_feats']).cuda(),
coords=torch.from_numpy(data['coords']).cuda(),
)
tex_slat = shape_slat.replace(torch.from_numpy(data['tex_slat_feats']).cuda())
return shape_slat, tex_slat, int(data['res'])
# ============================================================================
# Progress Tracking (SSE-based, tqdm interception, multi-session)
# ============================================================================
import asyncio
import queue
from fastapi.responses import StreamingResponse
from fastapi import Request
_progress_queues: Dict[str, queue.Queue] = {}
_active_session: str = ""
def _reset_progress(session_id: str):
global _active_session
_active_session = session_id
if session_id not in _progress_queues:
_progress_queues[session_id] = queue.Queue()
# Drain old items
q = _progress_queues[session_id]
while not q.empty():
try:
q.get_nowait()
except:
break
def _update_progress(stage: str, step: int, total: int):
data = {"stage": stage, "step": step, "total": total, "done": False}
session_id = _active_session
if session_id and session_id in _progress_queues:
try:
_progress_queues[session_id].put_nowait(data)
except:
pass
def _finish_progress():
session_id = _active_session
if session_id and session_id in _progress_queues:
try:
_progress_queues[session_id].put_nowait({"done": True})
except:
pass
# Schedule cleanup after a short delay (let SSE client receive the done signal)
def _cleanup():
time.sleep(5)
_progress_queues.pop(session_id, None)
threading.Thread(target=_cleanup, daemon=True).start()
# Monkey-patch tqdm to intercept progress
import tqdm as _tqdm_module
_original_tqdm = _tqdm_module.tqdm
class _TqdmProgressInterceptor(_original_tqdm):
"""Wraps tqdm to push progress updates to SSE."""
def __init__(self, *args, **kwargs):
self._stage_desc = kwargs.get('desc', 'Processing')
super().__init__(*args, **kwargs)
def update(self, n=1):
super().update(n)
_update_progress(self._stage_desc, self.n, self.total or 0)
# Patch tqdm globally
_tqdm_module.tqdm = _TqdmProgressInterceptor
import trellis2.pipelines.samplers.flow_euler as _fe_module
_fe_module.tqdm = _TqdmProgressInterceptor
import trellis2.utils.render_utils as _ru_module
_ru_module.tqdm = _TqdmProgressInterceptor
# ============================================================================
# API Implementation
# ============================================================================
app = Server()
@app.get("/")
async def homepage():
html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
with open(html_path, "r", encoding="utf-8") as f:
return HTMLResponse(content=f.read())
@app.get("/progress")
async def progress_sse(request: Request):
"""SSE endpoint for real-time progress updates during generation."""
session_id = request.query_params.get("session_id", "")
if session_id and session_id not in _progress_queues:
_progress_queues[session_id] = queue.Queue()
async def event_stream():
q = _progress_queues.get(session_id)
timeout_count = 0
while True:
if q:
try:
data = q.get_nowait()
yield f"data: {json.dumps(data)}\n\n"
if data.get("done"):
break
timeout_count = 0
except queue.Empty:
yield f": keepalive\n\n"
timeout_count += 1
else:
yield f": keepalive\n\n"
timeout_count += 1
# Timeout after 5 minutes of no data
if timeout_count > 1000:
break
await asyncio.sleep(0.3)
return StreamingResponse(event_stream(), media_type="text/event-stream")
@app.api()
@spaces.GPU(duration=60)
def preprocess(image: FileData) -> FileData:
init_models()
img = Image.open(image["path"])
processed = pipeline.preprocess_image(img)
out_path = os.path.join(TMP_DIR, f"preprocessed_{int(time.time()*1000)}.png")
processed.save(out_path)
return FileData(path=out_path)
@app.api()
@spaces.GPU(duration=120)
def generate_3d(
image: FileData,
seed: int,
resolution: int,
ss_guidance_strength: float = 7.5,
ss_guidance_rescale: float = 0.7,
ss_sampling_steps: int = 12,
ss_rescale_t: float = 5.0,
shape_slat_guidance_strength: float = 7.5,
shape_slat_guidance_rescale: float = 0.5,
shape_slat_sampling_steps: int = 12,
shape_slat_rescale_t: float = 3.0,
tex_slat_guidance_strength: float = 1.0,
tex_slat_guidance_rescale: float = 0.0,
tex_slat_sampling_steps: int = 12,
tex_slat_rescale_t: float = 3.0,
session_id: str = "",
) -> Dict:
init_models()
_reset_progress(session_id)
_update_progress("Preprocessing & Camera Estimation", 0, 1)
torch.manual_seed(seed)
hr_resolution = int(resolution)
img = Image.open(image["path"])
image_preprocessed = pipeline.preprocess_image(img)
temp_processed_path = os.path.join(TMP_DIR, f"temp_proc_{session_id[:8]}_{int(time.time()*1000)}.png")
image_preprocessed.save(temp_processed_path)
camera_params = get_camera_params_wild_moge(
temp_processed_path, device="cuda",
mesh_scale=WILD_MESH_SCALE, extend_pixel=WILD_EXTEND_PIXEL,
image_resolution=WILD_IMAGE_RESOLUTION,
)
_update_progress("Preprocessing & Camera Estimation", 1, 1)
ss_sampler_override = {"steps": ss_sampling_steps, "guidance_strength": ss_guidance_strength,
"guidance_rescale": ss_guidance_rescale, "rescale_t": ss_rescale_t}
shape_sampler_override = {"steps": shape_slat_sampling_steps, "guidance_strength": shape_slat_guidance_strength,
"guidance_rescale": shape_slat_guidance_rescale, "rescale_t": shape_slat_rescale_t}
tex_sampler_override = {"steps": tex_slat_sampling_steps, "guidance_strength": tex_slat_guidance_strength,
"guidance_rescale": tex_slat_guidance_rescale, "rescale_t": tex_slat_rescale_t}
pipeline_type = f"{hr_resolution}_cascade"
mesh_list, (shape_slat, tex_slat, res) = pipeline.run(
image_preprocessed,
camera_params=camera_params,
seed=seed,
sparse_structure_sampler_params=ss_sampler_override,
shape_slat_sampler_params=shape_sampler_override,
tex_slat_sampler_params=tex_sampler_override,
preprocess_image=False,
return_latent=True,
pipeline_type=pipeline_type,
max_num_tokens=CASCADE_MAX_NUM_TOKENS,
)
mesh = mesh_list[0]
state_path = pack_state(shape_slat, tex_slat, res)
_update_progress("Rendering views", 0, 1)
mesh.simplify(16777216)
cam_dist = camera_params['distance']
near = max(0.01, cam_dist - 2.0)
far = cam_dist + 10.0
renders = render_utils.render_proj_aligned_video(
mesh, camera_angle_x=camera_params['camera_angle_x'],
distance=cam_dist, resolution=1024,
num_frames=STEPS, envmap=envmap,
near=near, far=far,
)
_update_progress("Rendering views", 1, 1)
# Save renders and return paths
render_files = {}
for mode_key, frames in renders.items():
mode_files = []
for i, frame in enumerate(frames):
p = os.path.abspath(os.path.join(TMP_DIR, f"render_{mode_key}_{i}_{int(time.time()*1000)}.jpg"))
Image.fromarray(frame).save(p, quality=85)
mode_files.append(FileData(path=p))
render_files[mode_key] = mode_files
_finish_progress()
return {
"render_paths": render_files,
"state_path": os.path.abspath(state_path)
}
@app.api()
@spaces.GPU(duration=240)
def extract_glb_api(state_path: str, decimation_target: int, texture_size: int, session_id: str = "") -> FileData:
init_models()
_reset_progress(session_id)
_update_progress("Decoding latent", 0, 1)
shape_slat, tex_slat, res = unpack_state(state_path)
mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
_update_progress("Decoding latent", 1, 1)
_update_progress("Extracting GLB mesh", 0, 1)
glb = o_voxel.postprocess.to_glb(
vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs,
coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout,
grid_size=res, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
decimation_target=decimation_target, texture_size=texture_size,
remesh=True, remesh_band=1, remesh_project=0, use_tqdm=True,
)
rot = np.array([
[-1, 0, 0, 0],
[ 0, 0, -1, 0],
[ 0, -1, 0, 0],
[ 0, 0, 0, 1],
], dtype=np.float64)
glb.apply_transform(rot)
out_glb = os.path.join(TMP_DIR, f"result_{int(time.time()*1000)}.glb")
glb.export(out_glb, extension_webp=True)
_finish_progress()
return FileData(path=out_glb)
app.mount("/assets", StaticFiles(directory="assets"), name="assets")
app.mount("/tmp", StaticFiles(directory=TMP_DIR), name="tmp")
if __name__ == "__main__":
# Reinstall utils3d if needed
subprocess.run([
"pip", "install", "--force-reinstall", "--no-deps",
"https://github.com/LDYang694/Storages/releases/download/20260430/utils3d-0.0.2-py3-none-any.whl"
], check=True)
init_models()
app.launch(show_error=True, share=True)