Nanny7 commited on
Commit
3359103
·
1 Parent(s): 7bc688a

Add ZeroGPU support: import spaces first, add @spaces.GPU decorators, add flex_gemm env vars

Browse files
Files changed (1) hide show
  1. app.py +7 -0
app.py CHANGED
@@ -5,6 +5,7 @@ Image-to-3D generation using Proj-mode Cascade inference (512->1024/1536).
5
 
6
  """
7
 
 
8
  import gradio as gr
9
 
10
  import os
@@ -16,6 +17,10 @@ subprocess.run([
16
 
17
  os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
18
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
 
 
 
19
  import argparse
20
  import math
21
  import time
@@ -324,6 +329,7 @@ def get_seed(randomize_seed, seed):
324
  # Core Inference
325
  # ============================================================================
326
 
 
327
  def image_to_3d(
328
  image, seed, resolution,
329
  ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
@@ -422,6 +428,7 @@ def image_to_3d(
422
  return state, full_html
423
 
424
 
 
425
  def extract_glb(state, decimation_target, texture_size, req: gr.Request, progress=gr.Progress(track_tqdm=True)):
426
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
427
  shape_slat, tex_slat, res = unpack_state(state)
 
5
 
6
  """
7
 
8
+ import spaces
9
  import gradio as gr
10
 
11
  import os
 
17
 
18
  os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
19
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
20
+ os.environ["ATTN_BACKEND"] = "flash_attn_3"
21
+ os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
22
+ os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
23
+
24
  import argparse
25
  import math
26
  import time
 
329
  # Core Inference
330
  # ============================================================================
331
 
332
+ @spaces.GPU(duration=120)
333
  def image_to_3d(
334
  image, seed, resolution,
335
  ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
 
428
  return state, full_html
429
 
430
 
431
+ @spaces.GPU(duration=120)
432
  def extract_glb(state, decimation_target, texture_size, req: gr.Request, progress=gr.Progress(track_tqdm=True)):
433
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
434
  shape_slat, tex_slat, res = unpack_state(state)