Add ZeroGPU support: import spaces first, add @spaces.GPU decorators, add flex_gemm env vars
Browse files
app.py
CHANGED
|
@@ -5,6 +5,7 @@ Image-to-3D generation using Proj-mode Cascade inference (512->1024/1536).
|
|
| 5 |
|
| 6 |
"""
|
| 7 |
|
|
|
|
| 8 |
import gradio as gr
|
| 9 |
|
| 10 |
import os
|
|
@@ -16,6 +17,10 @@ subprocess.run([
|
|
| 16 |
|
| 17 |
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
|
| 18 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
import argparse
|
| 20 |
import math
|
| 21 |
import time
|
|
@@ -324,6 +329,7 @@ def get_seed(randomize_seed, seed):
|
|
| 324 |
# Core Inference
|
| 325 |
# ============================================================================
|
| 326 |
|
|
|
|
| 327 |
def image_to_3d(
|
| 328 |
image, seed, resolution,
|
| 329 |
ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
|
|
@@ -422,6 +428,7 @@ def image_to_3d(
|
|
| 422 |
return state, full_html
|
| 423 |
|
| 424 |
|
|
|
|
| 425 |
def extract_glb(state, decimation_target, texture_size, req: gr.Request, progress=gr.Progress(track_tqdm=True)):
|
| 426 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
| 427 |
shape_slat, tex_slat, res = unpack_state(state)
|
|
|
|
| 5 |
|
| 6 |
"""
|
| 7 |
|
| 8 |
+
import spaces
|
| 9 |
import gradio as gr
|
| 10 |
|
| 11 |
import os
|
|
|
|
| 17 |
|
| 18 |
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
|
| 19 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 20 |
+
os.environ["ATTN_BACKEND"] = "flash_attn_3"
|
| 21 |
+
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
|
| 22 |
+
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
|
| 23 |
+
|
| 24 |
import argparse
|
| 25 |
import math
|
| 26 |
import time
|
|
|
|
| 329 |
# Core Inference
|
| 330 |
# ============================================================================
|
| 331 |
|
| 332 |
+
@spaces.GPU(duration=120)
|
| 333 |
def image_to_3d(
|
| 334 |
image, seed, resolution,
|
| 335 |
ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
|
|
|
|
| 428 |
return state, full_html
|
| 429 |
|
| 430 |
|
| 431 |
+
@spaces.GPU(duration=120)
|
| 432 |
def extract_glb(state, decimation_target, texture_size, req: gr.Request, progress=gr.Progress(track_tqdm=True)):
|
| 433 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
| 434 |
shape_slat, tex_slat, res = unpack_state(state)
|