Yang2001 committed on
Commit
604ad14
·
1 Parent(s): c3cee00

feat: add clear button, SSE progress tracking, multi-session support

Browse files

- Add trash button in header to clear state and return to upload
- Add real-time progress bar via SSE with tqdm interception
- Per-session progress queues to prevent cross-talk between users
- Fix temp file naming to avoid concurrent conflicts
- Support near/far params in render_proj_aligned_video

Files changed (3) hide show
  1. app.py +125 -4
  2. index.html +162 -26
  3. trellis2/utils/render_utils.py +7 -1
app.py CHANGED
@@ -204,7 +204,8 @@ def pack_state(shape_slat, tex_slat, res):
204
  'coords': shape_slat.coords.cpu().numpy(),
205
  'res': res,
206
  }
207
- state_path = os.path.join(TMP_DIR, f"state_{int(time.time()*1000)}.npz")
 
208
  np.savez_compressed(state_path, **state_data)
209
  return state_path
210
 
@@ -217,6 +218,77 @@ def unpack_state(state_path):
217
  tex_slat = shape_slat.replace(torch.from_numpy(data['tex_slat_feats']).cuda())
218
  return shape_slat, tex_slat, int(data['res'])
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  # ============================================================================
221
  # API Implementation
222
  # ============================================================================
@@ -229,6 +301,36 @@ async def homepage():
229
  with open(html_path, "r", encoding="utf-8") as f:
230
  return HTMLResponse(content=f.read())
231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  @app.api()
233
  def preprocess(image: FileData) -> FileData:
234
  init_models()
@@ -256,14 +358,18 @@ def generate_3d(
256
  tex_slat_guidance_rescale: float = 0.0,
257
  tex_slat_sampling_steps: int = 12,
258
  tex_slat_rescale_t: float = 3.0,
 
259
  ) -> Dict:
260
  init_models()
 
 
 
261
  torch.manual_seed(seed)
262
  hr_resolution = int(resolution)
263
 
264
  img = Image.open(image["path"])
265
  image_preprocessed = pipeline.preprocess_image(img)
266
- temp_processed_path = os.path.join(TMP_DIR, "temp_proc.png")
267
  image_preprocessed.save(temp_processed_path)
268
 
269
  camera_params = get_camera_params_wild_moge(
@@ -271,6 +377,7 @@ def generate_3d(
271
  mesh_scale=WILD_MESH_SCALE, extend_pixel=WILD_EXTEND_PIXEL,
272
  image_resolution=WILD_IMAGE_RESOLUTION,
273
  )
 
274
 
275
  ss_sampler_override = {"steps": ss_sampling_steps, "guidance_strength": ss_guidance_strength,
276
  "guidance_rescale": ss_guidance_rescale, "rescale_t": ss_rescale_t}
@@ -296,12 +403,18 @@ def generate_3d(
296
  mesh = mesh_list[0]
297
  state_path = pack_state(shape_slat, tex_slat, res)
298
 
 
299
  mesh.simplify(16777216)
 
 
 
300
  renders = render_utils.render_proj_aligned_video(
301
  mesh, camera_angle_x=camera_params['camera_angle_x'],
302
- distance=camera_params['distance'], resolution=1024,
303
  num_frames=STEPS, envmap=envmap,
 
304
  )
 
305
 
306
  # Save renders and return paths
307
  render_files = {}
@@ -313,6 +426,7 @@ def generate_3d(
313
  mode_files.append(FileData(path=p))
314
  render_files[mode_key] = mode_files
315
 
 
316
  return {
317
  "render_paths": render_files,
318
  "state_path": os.path.abspath(state_path)
@@ -320,10 +434,16 @@ def generate_3d(
320
 
321
  @app.api()
322
  @spaces.GPU(duration=240)
323
- def extract_glb_api(state_path: str, decimation_target: int, texture_size: int) -> FileData:
324
  init_models()
 
 
 
325
  shape_slat, tex_slat, res = unpack_state(state_path)
326
  mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
 
 
 
327
  glb = o_voxel.postprocess.to_glb(
328
  vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs,
329
  coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout,
@@ -341,6 +461,7 @@ def extract_glb_api(state_path: str, decimation_target: int, texture_size: int)
341
 
342
  out_glb = os.path.join(TMP_DIR, f"result_{int(time.time()*1000)}.glb")
343
  glb.export(out_glb, extension_webp=True)
 
344
  return FileData(path=out_glb)
345
 
346
  # Mount assets and tmp for direct access
 
204
  'coords': shape_slat.coords.cpu().numpy(),
205
  'res': res,
206
  }
207
+ import random
208
+ state_path = os.path.join(TMP_DIR, f"state_{int(time.time()*1000)}_{random.randint(0,9999):04d}.npz")
209
  np.savez_compressed(state_path, **state_data)
210
  return state_path
211
 
 
218
  tex_slat = shape_slat.replace(torch.from_numpy(data['tex_slat_feats']).cuda())
219
  return shape_slat, tex_slat, int(data['res'])
220
 
221
+ # ============================================================================
222
+ # Progress Tracking (SSE-based, tqdm interception, multi-session)
223
+ # ============================================================================
224
+
225
+ import asyncio
226
+ import queue
227
+ from fastapi.responses import StreamingResponse
228
+ from fastapi import Request
229
+
230
+ # Per-session progress queues
231
+ _progress_queues: Dict[str, queue.Queue] = {}
232
+ _active_session: str = "" # Which session is currently running GPU work
233
+
234
+ def _reset_progress(session_id: str):
235
+ global _active_session
236
+ _active_session = session_id
237
+ if session_id not in _progress_queues:
238
+ _progress_queues[session_id] = queue.Queue()
239
+ # Drain old items
240
+ q = _progress_queues[session_id]
241
+ while not q.empty():
242
+ try:
243
+ q.get_nowait()
244
+ except:
245
+ break
246
+
247
+ def _update_progress(stage: str, step: int, total: int):
248
+ data = {"stage": stage, "step": step, "total": total, "done": False}
249
+ session_id = _active_session
250
+ if session_id and session_id in _progress_queues:
251
+ try:
252
+ _progress_queues[session_id].put_nowait(data)
253
+ except:
254
+ pass
255
+
256
+ def _finish_progress():
257
+ session_id = _active_session
258
+ if session_id and session_id in _progress_queues:
259
+ try:
260
+ _progress_queues[session_id].put_nowait({"done": True})
261
+ except:
262
+ pass
263
+ # Schedule cleanup after a short delay (let SSE client receive the done signal)
264
+ def _cleanup():
265
+ time.sleep(5)
266
+ _progress_queues.pop(session_id, None)
267
+ threading.Thread(target=_cleanup, daemon=True).start()
268
+
269
+ # Monkey-patch tqdm to intercept progress
270
+ import tqdm as _tqdm_module
271
+
272
+ _original_tqdm = _tqdm_module.tqdm
273
+
274
+ class _TqdmProgressInterceptor(_original_tqdm):
275
+ """Wraps tqdm to push progress updates to SSE."""
276
+ def __init__(self, *args, **kwargs):
277
+ self._stage_desc = kwargs.get('desc', 'Processing')
278
+ super().__init__(*args, **kwargs)
279
+
280
+ def update(self, n=1):
281
+ super().update(n)
282
+ _update_progress(self._stage_desc, self.n, self.total or 0)
283
+
284
+ # Patch tqdm globally
285
+ _tqdm_module.tqdm = _TqdmProgressInterceptor
286
+ # Also patch the direct import in the sampler module and render_utils
287
+ import trellis2.pipelines.samplers.flow_euler as _fe_module
288
+ _fe_module.tqdm = _TqdmProgressInterceptor
289
+ import trellis2.utils.render_utils as _ru_module
290
+ _ru_module.tqdm = _TqdmProgressInterceptor
291
+
292
  # ============================================================================
293
  # API Implementation
294
  # ============================================================================
 
301
  with open(html_path, "r", encoding="utf-8") as f:
302
  return HTMLResponse(content=f.read())
303
 
304
+ @app.get("/progress")
305
+ async def progress_sse(request: Request):
306
+ """SSE endpoint for real-time progress updates during generation."""
307
+ session_id = request.query_params.get("session_id", "")
308
+ if session_id and session_id not in _progress_queues:
309
+ _progress_queues[session_id] = queue.Queue()
310
+
311
+ async def event_stream():
312
+ q = _progress_queues.get(session_id)
313
+ timeout_count = 0
314
+ while True:
315
+ if q:
316
+ try:
317
+ data = q.get_nowait()
318
+ yield f"data: {json.dumps(data)}\n\n"
319
+ if data.get("done"):
320
+ break
321
+ timeout_count = 0
322
+ except queue.Empty:
323
+ yield f": keepalive\n\n"
324
+ timeout_count += 1
325
+ else:
326
+ yield f": keepalive\n\n"
327
+ timeout_count += 1
328
+ # Timeout after 5 minutes of no data
329
+ if timeout_count > 1000:
330
+ break
331
+ await asyncio.sleep(0.3)
332
+ return StreamingResponse(event_stream(), media_type="text/event-stream")
333
+
334
  @app.api()
335
  def preprocess(image: FileData) -> FileData:
336
  init_models()
 
358
  tex_slat_guidance_rescale: float = 0.0,
359
  tex_slat_sampling_steps: int = 12,
360
  tex_slat_rescale_t: float = 3.0,
361
+ session_id: str = "",
362
  ) -> Dict:
363
  init_models()
364
+ _reset_progress(session_id)
365
+ _update_progress("Preprocessing & Camera Estimation", 0, 1)
366
+
367
  torch.manual_seed(seed)
368
  hr_resolution = int(resolution)
369
 
370
  img = Image.open(image["path"])
371
  image_preprocessed = pipeline.preprocess_image(img)
372
+ temp_processed_path = os.path.join(TMP_DIR, f"temp_proc_{session_id[:8]}_{int(time.time()*1000)}.png")
373
  image_preprocessed.save(temp_processed_path)
374
 
375
  camera_params = get_camera_params_wild_moge(
 
377
  mesh_scale=WILD_MESH_SCALE, extend_pixel=WILD_EXTEND_PIXEL,
378
  image_resolution=WILD_IMAGE_RESOLUTION,
379
  )
380
+ _update_progress("Preprocessing & Camera Estimation", 1, 1)
381
 
382
  ss_sampler_override = {"steps": ss_sampling_steps, "guidance_strength": ss_guidance_strength,
383
  "guidance_rescale": ss_guidance_rescale, "rescale_t": ss_rescale_t}
 
403
  mesh = mesh_list[0]
404
  state_path = pack_state(shape_slat, tex_slat, res)
405
 
406
+ _update_progress("Rendering views", 0, 1)
407
  mesh.simplify(16777216)
408
+ cam_dist = camera_params['distance']
409
+ near = max(0.01, cam_dist - 2.0)
410
+ far = cam_dist + 10.0
411
  renders = render_utils.render_proj_aligned_video(
412
  mesh, camera_angle_x=camera_params['camera_angle_x'],
413
+ distance=cam_dist, resolution=1024,
414
  num_frames=STEPS, envmap=envmap,
415
+ near=near, far=far,
416
  )
417
+ _update_progress("Rendering views", 1, 1)
418
 
419
  # Save renders and return paths
420
  render_files = {}
 
426
  mode_files.append(FileData(path=p))
427
  render_files[mode_key] = mode_files
428
 
429
+ _finish_progress()
430
  return {
431
  "render_paths": render_files,
432
  "state_path": os.path.abspath(state_path)
 
434
 
435
  @app.api()
436
  @spaces.GPU(duration=240)
437
+ def extract_glb_api(state_path: str, decimation_target: int, texture_size: int, session_id: str = "") -> FileData:
438
  init_models()
439
+ _reset_progress(session_id)
440
+ _update_progress("Decoding latent", 0, 1)
441
+
442
  shape_slat, tex_slat, res = unpack_state(state_path)
443
  mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
444
+ _update_progress("Decoding latent", 1, 1)
445
+
446
+ _update_progress("Extracting GLB mesh", 0, 1)
447
  glb = o_voxel.postprocess.to_glb(
448
  vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs,
449
  coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout,
 
461
 
462
  out_glb = os.path.join(TMP_DIR, f"result_{int(time.time()*1000)}.glb")
463
  glb.export(out_glb, extension_webp=True)
464
+ _finish_progress()
465
  return FileData(path=out_glb)
466
 
467
  # Mount assets and tmp for direct access
index.html CHANGED
@@ -548,9 +548,9 @@
548
 
549
  <div class="sidebar-section" style="margin-bottom: 1.5rem;">
550
  <p style="font-size: 0.82rem; color: var(--text-dim); line-height: 1.6;">
551
- Upload an image and click Generate.<br>
552
- Click Extract GLB to export.<br>
553
- Download the generated GLB file.
554
  </p>
555
  <p style="font-size: 0.72rem; color: var(--text-dim); line-height: 1.5; margin-top: 0.5rem; opacity: 0.7;">
556
  Note: Camera estimated automatically via MoGe-2.
@@ -651,9 +651,9 @@
651
  <span>3. RESULT</span>
652
  </div>
653
  </div>
654
- <div style="color: var(--text-dim); font-size: 0.8rem; font-weight: 500;">
655
- Pixal3D V1.1
656
- </div>
657
  </header>
658
 
659
  <div class="workspace">
@@ -716,9 +716,21 @@
716
 
717
  <div class="loading-overlay" id="loading-overlay">
718
  <div class="loader-ring"></div>
719
- <div style="text-align: center;">
720
- <h2 id="loading-title" style="font-family: 'Outfit'; margin-bottom: 0.5rem;">Synthesizing Geometry</h2>
721
- <p id="loading-subtitle" style="color: var(--text-dim);">The neural engine is crafting your 3D model...</p>
 
 
 
 
 
 
 
 
 
 
 
 
722
  </div>
723
  </div>
724
 
@@ -736,6 +748,7 @@
736
  let generationResult = null;
737
  let currentMode = "shaded_forest";
738
  let currentFrame = 0;
 
739
  let currentStep = 1;
740
 
741
  const MODES = [
@@ -782,6 +795,52 @@
782
  link.click();
783
  };
784
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
785
  // Slider
786
  document.getElementById('angle-slider').oninput = (e) => {
787
  currentFrame = parseInt(e.target.value);
@@ -815,19 +874,24 @@
815
  img.style.display = 'block';
816
  hint.style.display = 'none';
817
  document.getElementById('generate-btn').disabled = false;
818
- // Update reference thumbnails
819
- const thumb2 = document.getElementById('ref-thumb-2');
820
- const thumb3 = document.getElementById('ref-thumb-3');
821
- thumb2.src = e.target.result;
822
- thumb2.style.display = 'block';
823
- thumb3.src = e.target.result;
824
- thumb3.style.display = 'block';
825
  setStep(1);
826
  };
827
  reader.readAsDataURL(file);
828
 
829
- // Background pre-warm
830
- client.predict("/preprocess", { image: handle_file(file) }).catch(console.error);
 
 
 
 
 
 
 
 
 
 
 
 
831
  }
832
 
833
  function setStep(num) {
@@ -850,7 +914,8 @@
850
  async function startGeneration() {
851
  if (!currentFile) return;
852
 
853
- showLoading("Neural Synthesis", "Optimizing geometry for " + (document.getElementById('resolution').value) + "px output...");
 
854
  try {
855
  const params = {
856
  image: handle_file(currentFile),
@@ -858,23 +923,89 @@
858
  resolution: parseInt(document.getElementById('resolution').value),
859
  ss_guidance_strength: parseFloat(document.getElementById('ss_gs').value),
860
  ss_sampling_steps: parseInt(document.getElementById('ss_steps').value),
861
- shape_slat_guidance_strength: parseFloat(document.getElementById('shape_gs').value)
 
862
  };
863
 
864
  const result = await client.predict("/generate_3d", params);
865
  generationResult = result.data[0];
866
 
 
867
  populateFrames(generationResult.render_paths);
868
  setStep(2);
869
  hideLoading();
870
  showToast("Generation complete!");
871
  } catch (err) {
872
  console.error(err);
 
873
  hideLoading();
874
  showToast("An error occurred during synthesis.");
875
  }
876
  }
877
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
878
  function populateFrames(renderPaths) {
879
  const container = document.getElementById('frame-container');
880
  container.innerHTML = '';
@@ -912,17 +1043,20 @@
912
  async function startExtraction() {
913
  if (!generationResult) return;
914
 
915
- showLoading("Finalizing Mesh", "Performing PBR texture baking and decimation...");
 
916
  try {
917
  const params = {
918
  state_path: generationResult.state_path,
919
  decimation_target: parseInt(document.getElementById('decimation').value),
920
- texture_size: 4096 // Constant for highest quality
 
921
  };
922
 
923
  const result = await client.predict("/extract_glb_api", params);
924
  const glbUrl = result.data[0].url;
925
 
 
926
  const viewer = document.getElementById('main-3d-viewer');
927
  viewer.src = glbUrl;
928
  setStep(3);
@@ -930,6 +1064,7 @@
930
  showToast("3D Asset ready!");
931
  } catch (err) {
932
  console.error(err);
 
933
  hideLoading();
934
  showToast("Extraction failed.");
935
  }
@@ -951,7 +1086,7 @@
951
  div.className = 'example-item';
952
  div.innerHTML = `<img src="${path}">`;
953
  div.onclick = async () => {
954
- showLoading("Fetching Sample", "Loading high-resolution asset from gallery...");
955
  const res = await fetch(path);
956
  const blob = await res.blob();
957
  const file = new File([blob], "sample.webp", { type: "image/webp" });
@@ -998,14 +1133,15 @@
998
  document.getElementById('seed-display').textContent = s;
999
  };
1000
 
1001
- function showLoading(title, sub) {
1002
- document.getElementById('loading-title').textContent = title;
1003
- document.getElementById('loading-subtitle').textContent = sub;
1004
  document.getElementById('loading-overlay').style.display = 'flex';
1005
  }
1006
 
1007
  function hideLoading() {
1008
  document.getElementById('loading-overlay').style.display = 'none';
 
 
 
1009
  }
1010
 
1011
  function showToast(msg) {
 
548
 
549
  <div class="sidebar-section" style="margin-bottom: 1.5rem;">
550
  <p style="font-size: 0.82rem; color: var(--text-dim); line-height: 1.6;">
551
+ 1. Upload an image and click Generate.<br>
552
+ 2. Click Extract GLB to export.<br>
553
+ 3. Download the generated GLB file.
554
  </p>
555
  <p style="font-size: 0.72rem; color: var(--text-dim); line-height: 1.5; margin-top: 0.5rem; opacity: 0.7;">
556
  Note: Camera estimated automatically via MoGe-2.
 
651
  <span>3. RESULT</span>
652
  </div>
653
  </div>
654
+ <button class="btn btn-outline" id="clear-btn" title="Clear all & restart" style="width: 34px; height: 34px; padding: 0; border-radius: 50%; display: flex; align-items: center; justify-content: center; border-color: rgba(248,113,113,0.3);">
655
+ <i data-lucide="trash-2" style="width: 16px; height: 16px; color: #f87171;"></i>
656
+ </button>
657
  </header>
658
 
659
  <div class="workspace">
 
716
 
717
  <div class="loading-overlay" id="loading-overlay">
718
  <div class="loader-ring"></div>
719
+ <div style="text-align: center; width: 100%; max-width: 500px; padding: 0 2rem;">
720
+ <!-- Progress stages -->
721
+ <div id="progress-stages" style="display: none; text-align: left;">
722
+ <div class="progress-stage" id="progress-stage-item" style="margin-bottom: 1rem;">
723
+ <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 0.4rem;">
724
+ <span id="progress-stage-name" style="font-size: 0.85rem; font-weight: 600; color: var(--primary);">Initializing...</span>
725
+ <span id="progress-step-text" style="font-size: 0.75rem; color: var(--text-dim); font-family: monospace;">0/0</span>
726
+ </div>
727
+ <div style="width: 100%; height: 6px; background: var(--border); border-radius: 3px; overflow: hidden;">
728
+ <div id="progress-bar-fill" style="width: 0%; height: 100%; background: linear-gradient(90deg, var(--primary), var(--accent)); border-radius: 3px; transition: width 0.3s ease;"></div>
729
+ </div>
730
+ </div>
731
+ <!-- Stage history log -->
732
+ <div id="progress-log" style="font-size: 0.75rem; color: var(--text-dim); line-height: 1.8; max-height: 180px; overflow-y: auto; margin-top: 1rem; padding-top: 0.5rem; border-top: 1px solid var(--border);"></div>
733
+ </div>
734
  </div>
735
  </div>
736
 
 
748
  let generationResult = null;
749
  let currentMode = "shaded_forest";
750
  let currentFrame = 0;
751
+ const sessionId = crypto.randomUUID();
752
  let currentStep = 1;
753
 
754
  const MODES = [
 
795
  link.click();
796
  };
797
 
798
+ // Clear button
799
+ document.getElementById('clear-btn').onclick = () => {
800
+ // Reset state
801
+ currentFile = null;
802
+ generationResult = null;
803
+ currentFrame = 0;
804
+ currentMode = "shaded_forest";
805
+
806
+ // Reset source preview
807
+ document.getElementById('source-preview').src = '';
808
+ document.getElementById('source-preview').style.display = 'none';
809
+ document.getElementById('upload-hint').style.display = 'flex';
810
+ document.getElementById('file-input').value = '';
811
+
812
+ // Reset generate button
813
+ document.getElementById('generate-btn').disabled = true;
814
+
815
+ // Reset preview frames
816
+ document.getElementById('frame-container').innerHTML = '';
817
+ document.getElementById('angle-slider').value = 0;
818
+ document.getElementById('angle-display').textContent = '00';
819
+
820
+ // Reset 3D viewer
821
+ document.getElementById('main-3d-viewer').removeAttribute('src');
822
+
823
+ // Reset thumbnails
824
+ document.getElementById('ref-thumb-2').style.display = 'none';
825
+ document.getElementById('ref-thumb-2').src = '';
826
+ document.getElementById('ref-thumb-3').style.display = 'none';
827
+ document.getElementById('ref-thumb-3').src = '';
828
+
829
+ // Reset mode tabs
830
+ document.querySelectorAll('.mode-tab').forEach(t => {
831
+ t.classList.toggle('active', t.textContent === 'Forest');
832
+ });
833
+
834
+ // Go back to step 1
835
+ setStep(1);
836
+ showToast("Cleared. Ready for new upload.");
837
+ };
838
+
839
+ // Step navigation click
840
+ document.getElementById('step-1').onclick = () => setStep(1);
841
+ document.getElementById('step-2').onclick = () => { if (generationResult) setStep(2); };
842
+ document.getElementById('step-3').onclick = () => { if (document.getElementById('main-3d-viewer').src) setStep(3); };
843
+
844
  // Slider
845
  document.getElementById('angle-slider').oninput = (e) => {
846
  currentFrame = parseInt(e.target.value);
 
874
  img.style.display = 'block';
875
  hint.style.display = 'none';
876
  document.getElementById('generate-btn').disabled = false;
 
 
 
 
 
 
 
877
  setStep(1);
878
  };
879
  reader.readAsDataURL(file);
880
 
881
+ // Call preprocess and update with segmented result
882
+ try {
883
+ const result = await client.predict("/preprocess", { image: handle_file(file) });
884
+ const processedUrl = result.data[0].url;
885
+ if (processedUrl) {
886
+ document.getElementById('source-preview').src = processedUrl;
887
+ document.getElementById('ref-thumb-2').src = processedUrl;
888
+ document.getElementById('ref-thumb-2').style.display = 'block';
889
+ document.getElementById('ref-thumb-3').src = processedUrl;
890
+ document.getElementById('ref-thumb-3').style.display = 'block';
891
+ }
892
+ } catch (err) {
893
+ console.error("Preprocess failed:", err);
894
+ }
895
  }
896
 
897
  function setStep(num) {
 
914
  async function startGeneration() {
915
  if (!currentFile) return;
916
 
917
+ showLoading();
918
+ startProgressListener();
919
  try {
920
  const params = {
921
  image: handle_file(currentFile),
 
923
  resolution: parseInt(document.getElementById('resolution').value),
924
  ss_guidance_strength: parseFloat(document.getElementById('ss_gs').value),
925
  ss_sampling_steps: parseInt(document.getElementById('ss_steps').value),
926
+ shape_slat_guidance_strength: parseFloat(document.getElementById('shape_gs').value),
927
+ session_id: sessionId
928
  };
929
 
930
  const result = await client.predict("/generate_3d", params);
931
  generationResult = result.data[0];
932
 
933
+ stopProgressListener();
934
  populateFrames(generationResult.render_paths);
935
  setStep(2);
936
  hideLoading();
937
  showToast("Generation complete!");
938
  } catch (err) {
939
  console.error(err);
940
+ stopProgressListener();
941
  hideLoading();
942
  showToast("An error occurred during synthesis.");
943
  }
944
  }
945
 
946
+ // SSE Progress Listener
947
+ let progressEventSource = null;
948
+ let lastStageName = "";
949
+
950
+ function startProgressListener() {
951
+ // Show progress UI
952
+ document.getElementById('progress-stages').style.display = 'block';
953
+ document.getElementById('progress-log').innerHTML = '';
954
+ document.getElementById('progress-stage-name').textContent = 'Initializing...';
955
+ document.getElementById('progress-step-text').textContent = '';
956
+ document.getElementById('progress-bar-fill').style.width = '0%';
957
+ lastStageName = "";
958
+
959
+ progressEventSource = new EventSource(`/progress?session_id=${sessionId}`);
960
+ progressEventSource.onmessage = (event) => {
961
+ try {
962
+ const data = JSON.parse(event.data);
963
+ if (data.done) {
964
+ stopProgressListener();
965
+ return;
966
+ }
967
+ updateProgressUI(data);
968
+ } catch (e) {}
969
+ };
970
+ progressEventSource.onerror = () => {
971
+ // Silently ignore SSE errors, generation continues
972
+ };
973
+ }
974
+
975
+ function stopProgressListener() {
976
+ if (progressEventSource) {
977
+ progressEventSource.close();
978
+ progressEventSource = null;
979
+ }
980
+ }
981
+
982
+ function updateProgressUI(data) {
983
+ const stageName = data.stage || '';
984
+ const step = data.step || 0;
985
+ const total = data.total || 0;
986
+
987
+ // If stage changed, log the previous one as completed
988
+ if (stageName && stageName !== lastStageName) {
989
+ if (lastStageName) {
990
+ const logEl = document.getElementById('progress-log');
991
+ logEl.innerHTML += `<div style="display:flex;align-items:center;gap:0.4rem;"><span style="color:var(--accent);">✓</span> ${lastStageName}</div>`;
992
+ logEl.scrollTop = logEl.scrollHeight;
993
+ }
994
+ lastStageName = stageName;
995
+ }
996
+
997
+ // Update current stage display
998
+ document.getElementById('progress-stage-name').textContent = stageName;
999
+ if (total > 0) {
1000
+ document.getElementById('progress-step-text').textContent = `${step}/${total}`;
1001
+ const pct = Math.min(100, (step / total) * 100);
1002
+ document.getElementById('progress-bar-fill').style.width = pct + '%';
1003
+ } else {
1004
+ document.getElementById('progress-step-text').textContent = '';
1005
+ document.getElementById('progress-bar-fill').style.width = '0%';
1006
+ }
1007
+ }
1008
+
1009
  function populateFrames(renderPaths) {
1010
  const container = document.getElementById('frame-container');
1011
  container.innerHTML = '';
 
1043
  async function startExtraction() {
1044
  if (!generationResult) return;
1045
 
1046
+ showLoading();
1047
+ startProgressListener();
1048
  try {
1049
  const params = {
1050
  state_path: generationResult.state_path,
1051
  decimation_target: parseInt(document.getElementById('decimation').value),
1052
+ texture_size: 4096,
1053
+ session_id: sessionId
1054
  };
1055
 
1056
  const result = await client.predict("/extract_glb_api", params);
1057
  const glbUrl = result.data[0].url;
1058
 
1059
+ stopProgressListener();
1060
  const viewer = document.getElementById('main-3d-viewer');
1061
  viewer.src = glbUrl;
1062
  setStep(3);
 
1064
  showToast("3D Asset ready!");
1065
  } catch (err) {
1066
  console.error(err);
1067
+ stopProgressListener();
1068
  hideLoading();
1069
  showToast("Extraction failed.");
1070
  }
 
1086
  div.className = 'example-item';
1087
  div.innerHTML = `<img src="${path}">`;
1088
  div.onclick = async () => {
1089
+ showLoading();
1090
  const res = await fetch(path);
1091
  const blob = await res.blob();
1092
  const file = new File([blob], "sample.webp", { type: "image/webp" });
 
1133
  document.getElementById('seed-display').textContent = s;
1134
  };
1135
 
1136
+ function showLoading() {
 
 
1137
  document.getElementById('loading-overlay').style.display = 'flex';
1138
  }
1139
 
1140
  function hideLoading() {
1141
  document.getElementById('loading-overlay').style.display = 'none';
1142
+ document.getElementById('progress-stages').style.display = 'none';
1143
+ document.getElementById('progress-log').innerHTML = '';
1144
+ document.getElementById('progress-bar-fill').style.width = '0%';
1145
  }
1146
 
1147
  function showToast(msg) {
trellis2/utils/render_utils.py CHANGED
@@ -198,8 +198,14 @@ def render_proj_aligned_video(sample, camera_angle_x, distance, resolution=1024,
198
  extrinsics_list.append(extr_rotated)
199
  intrinsics_list.append(intr_first)
200
 
 
 
 
 
 
 
201
  return render_frames(sample, extrinsics_list, intrinsics_list,
202
- {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
203
 
204
 
205
  def make_pbr_vis_frames(result, resolution=1024):
 
198
  extrinsics_list.append(extr_rotated)
199
  intrinsics_list.append(intr_first)
200
 
201
+ render_options = {'resolution': resolution, 'bg_color': bg_color}
202
+ if 'near' in kwargs:
203
+ render_options['near'] = kwargs.pop('near')
204
+ if 'far' in kwargs:
205
+ render_options['far'] = kwargs.pop('far')
206
+
207
  return render_frames(sample, extrinsics_list, intrinsics_list,
208
+ render_options, **kwargs)
209
 
210
 
211
  def make_pbr_vis_frames(result, resolution=1024):