| import os |
| import sys |
| import cv2 |
| import math |
| import json |
| import torch |
| import gradio as gr |
| import numpy as np |
| from PIL import Image |
| from PIL import ImageOps |
| from pathlib import Path |
| import multiprocessing as mp |
| from vitra.utils.data_utils import resize_short_side_to_target, load_normalizer, recon_traj |
| from vitra.utils.config_utils import load_config |
| from scipy.spatial.transform import Rotation as R |
| import spaces |
|
|
# Put this file's directory at the front of the import path so the local
# `visualization` and `inference_human_prediction` modules imported below resolve.
repo_root = Path(__file__).parent
sys.path.insert(0, os.fspath(repo_root))
|
|
| from visualization.visualize_core import HandVisualizer, normalize_camera_intrinsics, save_to_video, Renderer, process_single_hand_labels |
| from visualization.visualize_core import Config as HandConfig |
|
|
| |
| from inference_human_prediction import ( |
| get_state, |
| euler_traj_to_rotmat_traj, |
| ) |
|
|
| |
# Turn off HuggingFace tokenizers' internal parallelism for this process
# (presumably to avoid fork-related warnings/contention; not visible here).
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Process-wide services; all stay None until initialize_services() fills them in.
vla_model = vla_normalizer = hand_reconstructor = None
visualizer = hand_config = app_config = None
|
|
def vla_predict(model, normalizer, image, instruction, state, state_mask,
                action_mask, fov, num_ddim_steps, cfg_scale, sample_times):
    """Run one VLA inference pass on the GPU and return un-normalized actions.

    The model is expected to have been loaded onto CUDA by the main process;
    it is moved to CUDA again here, which is a no-op when it already is.

    Args:
        model: VLA model exposing ``predict_action``.
        normalizer: object providing ``normalize_state``, ``unnormalize_action``,
            ``action_mean`` and ``state_mean``.
        image: resized RGB image as a numpy array.
        instruction: combined "Left hand: ... Right hand: ..." text.
        state: concatenated per-hand state vector (numpy array).
        state_mask: boolean ``[use_left, use_right]`` mask.
        action_mask: boolean per-step, per-hand mask (``(chunk_size, 2)`` as
            built by ``generate_prediction``).
        fov: ``[fov_x, fov_y]`` in radians as a float32 numpy array.
        num_ddim_steps: number of DDIM sampling steps.
        cfg_scale: classifier-free guidance scale.
        sample_times: number of trajectories to sample.

    Returns:
        numpy array holding the first 102 action channels of the model output
        (presumably shaped ``(sample_times, steps, 102)`` -- the caller indexes
        it as ``[i, :, 0:51]`` for the left hand and ``[i, :, 51:102]`` for the right).
    """
    from vitra.datasets.human_dataset import pad_state_human, pad_action
    from vitra.datasets.dataset_utils import ActionFeature, StateFeature

    # Two hands x 51 action channels; only this leading slice is returned.
    num_hand_action_dims = 102

    action_dim = normalizer.action_mean.shape[0]
    state_dim = normalizer.state_mean.shape[0]

    # Work on a copy in case normalize_state modifies its input in place.
    norm_state = normalizer.normalize_state(state.copy())

    # Pad state and action mask into the model's unified feature dimensions.
    unified_state, unified_state_mask = pad_state_human(
        state=norm_state,
        state_mask=state_mask,
        action_dim=action_dim,
        state_dim=state_dim,
        unified_state_dim=StateFeature.ALL_FEATURES[1],
    )
    _, unified_action_mask = pad_action(
        actions=None,
        action_mask=action_mask.copy(),
        action_dim=action_dim,
        unified_action_dim=ActionFeature.ALL_FEATURES[1],
    )

    # Add a batch dimension of 1 and move every input to the GPU.
    device = torch.device('cuda')
    fov_t = torch.from_numpy(fov).unsqueeze(0).to(device)
    unified_state = unified_state.unsqueeze(0).to(device)
    unified_state_mask = unified_state_mask.unsqueeze(0).to(device)
    unified_action_mask = unified_action_mask.unsqueeze(0).to(device)

    model = model.to(device)

    # Pure inference: disable autograd so no computation graph is kept on the GPU.
    with torch.no_grad():
        norm_action = model.predict_action(
            image=image,
            instruction=instruction,
            current_state=unified_state,
            current_state_mask=unified_state_mask,
            action_mask_torch=unified_action_mask,
            num_ddim_steps=num_ddim_steps,
            cfg_scale=cfg_scale,
            fov=fov_t,
            sample_times=sample_times,
        )
        norm_action = norm_action[:, :, :num_hand_action_dims]
        unnorm_action = normalizer.unnormalize_action(norm_action)

    if isinstance(unnorm_action, torch.Tensor):
        # detach() keeps .numpy() safe even if the model re-enabled grad internally.
        return unnorm_action.detach().cpu().numpy()
    return np.array(unnorm_action)
|
|
class GradioConfig:
    """Static settings for the Gradio demo: model sources, asset paths and video FPS."""

    def __init__(self):
        # VLA model: config source, plus optional weight/statistics overrides
        # (None means "keep whatever the loaded config specifies").
        self.config_path = 'microsoft/VITRA-VLA-3B'
        self.model_path = self.statistics_path = None

        # Hand-reconstruction assets (HaWoR model, detector weights, MoGe model, MANO files).
        (self.hawor_model_path,
         self.detector_path,
         self.moge_model_name,
         self.mano_path) = (
            'arnoldland/HAWOR',
            './weights/hawor/external/detector.pt',
            'Ruicheng/moge-2-vitl',
            './weights/mano',
        )

        # Frame rate of the rendered prediction video.
        self.fps = 8
|
|
|
|
def initialize_services():
    """Load every model/service once at startup.

    Populates the module-level ``app_config``, ``vla_model``, ``vla_normalizer``,
    ``hand_reconstructor``, ``hand_config`` and ``visualizer`` globals. Each
    service is loaded only while it is still ``None``. A call after a partial
    failure therefore retries just the missing pieces, instead of reporting
    success because the first service happened to load.

    Returns:
        A human-readable status string. On error it contains "Failed" plus the
        traceback (the ``__main__`` block checks for that substring).
    """
    global vla_model, vla_normalizer, hand_reconstructor, visualizer, hand_config, app_config

    services = (app_config, vla_model, vla_normalizer, hand_reconstructor, hand_config, visualizer)
    if all(service is not None for service in services):
        return "Services already initialized"

    try:
        if app_config is None:
            app_config = GradioConfig()

        hf_token = os.environ.get('HF_TOKEN')
        if hf_token:
            from huggingface_hub import login
            login(token=hf_token)
            print("Logged in to HuggingFace Hub")

        if vla_model is None or vla_normalizer is None:
            print("Loading VLA model...")
            from vitra.models import load_model

            configs = load_config(app_config.config_path)
            if app_config.model_path is not None:
                configs['model_load_path'] = app_config.model_path
            if app_config.statistics_path is not None:
                configs['statistics_path'] = app_config.statistics_path

            model = load_model(configs).cuda()
            model.eval()
            normalizer = load_normalizer(configs)
            # Publish both together so a failure leaves neither half-set.
            vla_model, vla_normalizer = model, normalizer
            print("VLA model loaded")

        if hand_reconstructor is None:
            print("Loading Hand Reconstructor...")
            from types import SimpleNamespace
            from data.tools.hand_recon_core import Config, HandReconstructor

            recon_args = SimpleNamespace(
                hawor_model_path=app_config.hawor_model_path,
                detector_path=app_config.detector_path,
                moge_model_name=app_config.moge_model_name,
                mano_path=app_config.mano_path,
            )
            hand_reconstructor = HandReconstructor(config=Config(recon_args), device='cuda')
            print("Hand Reconstructor loaded")

        if visualizer is None or hand_config is None:
            print("Loading Visualizer...")
            vis_config = HandConfig(app_config)
            vis_config.FPS = app_config.fps
            vis = HandVisualizer(vis_config, render_gradual_traj=False)
            vis.mano = vis.mano.cuda()
            hand_config, visualizer = vis_config, vis
            print("Visualizer loaded")

        return "✅ All services initialized successfully!"

    except Exception as e:
        import traceback
        return f"❌ Failed to initialize services: {str(e)}\n{traceback.format_exc()}"
|
|
|
|
def validate_image_dimensions(image):
    """Check that an image is landscape (width >= height) before any GPU work.

    A missing image (``None``) is treated as valid so the caller can handle it.

    Returns:
        ``(is_valid, message)``; ``message`` is empty when the image is valid.
    """
    if image is None:
        return True, ""

    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    width, height = pil_image.size

    if width >= height:
        return True, ""
    return False, (
        f"❌ Please upload a landscape image (width ≥ height).\n"
        f"Current image: {width}x{height} (portrait orientation)"
    )
|
|
|
|
def validate_and_process_wrapper(image, session_state, progress=gr.Progress()):
    """Reject missing or portrait images cheaply, before a GPU is requested.

    Only images that pass ``validate_image_dimensions`` are forwarded to the
    GPU-decorated ``process_image_upload``. Rejections return the same 6-tuple
    shape as that function: status text, a disabled generate button, no hand
    data, no detected hands, and the unchanged session state.
    """
    if image is None:
        status = "Waiting for image upload..."
    else:
        is_valid, status = validate_image_dimensions(image)
        if is_valid:
            return process_image_upload(image, session_state, progress)

    return (status,
            gr.update(interactive=False),
            None,
            False,
            False,
            session_state)
|
|
|
|
@spaces.GPU(duration=120)
def process_image_upload(image, session_state, progress=gr.Progress()):
    """Run hand reconstruction on an uploaded image inside a ``spaces.GPU`` allocation.

    Args:
        image: the uploaded image (PIL image or numpy array).
        session_state: per-session dict; this function stores ``current_image``,
            ``current_hand_data``, ``detected_left`` and ``detected_right`` in it.
        progress: Gradio progress tracker.

    Returns:
        ``(status_message, generate_button_update, hand_data, has_left,
        has_right, session_state)``, the six outputs wired to
        ``input_image.change`` in ``create_gradio_interface``.
    """
    import time

    if torch.cuda.is_available():
        print("CUDA is available for image processing")
    else:
        print("CUDA is NOT available for image processing")

    # Retry a tiny CUDA allocation for up to 60 s in case device setup fails
    # transiently. Stop as soon as it succeeds or CUDA is simply unavailable;
    # only real errors trigger the 2 s back-off, so the loop never busy-spins.
    # `except Exception` (not a bare except) lets Ctrl-C / SystemExit through.
    deadline = time.time() + 60
    while time.time() < deadline:
        try:
            if torch.cuda.is_available():
                torch.zeros(1).cuda()
            break
        except Exception:
            time.sleep(2)

    if hand_reconstructor is None:
        return ("Services not initialized. Please wait for initialization to complete.",
                gr.update(interactive=False),
                None,
                False,
                False,
                session_state)

    try:
        progress(0, desc="Preparing image...")

        img_pil = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        session_state['current_image'] = img_pil

        progress(0.2, desc="Running hand reconstruction...")

        # Use the same orientation generate_prediction applies (EXIF transpose)
        # so the reconstructed hands and the image fed to the model share a frame.
        # Convert to 3-channel RGB so cvtColor also accepts grayscale/palette inputs.
        recon_image = img_pil
        try:
            recon_image = ImageOps.exif_transpose(recon_image)
        except Exception:
            pass  # best effort: keep the image as uploaded
        image_bgr = cv2.cvtColor(np.array(recon_image.convert("RGB")), cv2.COLOR_RGB2BGR)

        # The reconstructor takes a list of BGR frames; here a single frame.
        hand_data = hand_reconstructor.recon([image_bgr])
        session_state['current_hand_data'] = hand_data

        progress(1.0, desc="Hand reconstruction complete!")

        has_left = 'left' in hand_data and len(hand_data['left']) > 0
        has_right = 'right' in hand_data and len(hand_data['right']) > 0

        detection_summary = {
            (True, True): "Left ✓, Right ✓",
            (True, False): "Left ✓, Right ✗",
            (False, True): "Left ✗, Right ✓",
            (False, False): "None detected",
        }[(has_left, has_right)]
        info_msg = "✅ Hand reconstruction complete!\nDetected hands: " + detection_summary

        session_state['detected_left'] = has_left
        session_state['detected_right'] = has_right

        return (info_msg,
                gr.update(interactive=True),
                hand_data,
                has_left,
                has_right,
                session_state)

    except Exception as e:
        import traceback
        error_msg = f"❌ Hand reconstruction failed: {str(e)}\n{traceback.format_exc()}"

        session_state['detected_left'] = False
        session_state['detected_right'] = False

        return (error_msg,
                gr.update(interactive=True),
                None,
                False,
                False,
                session_state)
|
|
def update_checkboxes(has_left, has_right):
    """Sync hand checkboxes and instruction boxes with the detection result.

    A detected hand gets a checked, clickable checkbox; an undetected one is
    unchecked, non-interactive and greyed out via the ``disabled-checkbox``
    class. The instruction textboxes follow the same enabled/disabled state.

    Returns:
        ``(left_checkbox, right_checkbox, left_instruction, right_instruction)``
        Gradio update objects.
    """
    def checkbox_state(detected):
        return gr.update(
            value=detected,
            interactive=bool(detected),
            elem_classes="" if detected else "disabled-checkbox",
        )

    left_text_update, right_text_update = update_instruction_interactivity(has_left, has_right)
    return checkbox_state(has_left), checkbox_state(has_right), left_text_update, right_text_update
|
|
|
|
def update_instruction_interactivity(use_left, use_right):
    """Enable each instruction textbox only while its hand's checkbox is ticked.

    Disabled boxes get the ``disabled-textbox`` CSS class.

    Returns:
        ``(left_update, right_update)`` Gradio update objects.
    """
    return tuple(
        gr.update(
            interactive=enabled,
            elem_classes="" if enabled else "disabled-textbox",
        )
        for enabled in (use_left, use_right)
    )
|
|
def update_final_instruction(left_instruction, right_instruction, use_left, use_right):
    """Combine the per-hand instructions into the text sent to the model.

    A hand whose checkbox is off contributes "None." regardless of its textbox.

    Returns:
        ``(html_update, final_text)``: a ``gr.update`` carrying the styled HTML
        preview, and the plain combined string (stored in the state that
        ``generate_prediction`` receives as its instruction).
    """
    import html

    left_text = left_instruction if use_left else "None."
    right_text = right_instruction if use_right else "None."

    final = f"Left hand: {left_text} Right hand: {right_text}"

    # The preview is rendered as raw HTML; escape the user-typed text so that
    # "<", ">" and "&" display literally instead of being parsed as markup.
    # The plain `final` string returned for the model is left untouched.
    styled_output = f"""<div style='padding: 12px; background-color: #f0f7ff; border-left: 4px solid #4A90E2; border-radius: 4px; margin-top: 10px;'>
<strong style='color: #2c5282;'>📝 Final Instruction:</strong><br>
<span style='color: #1a365d; font-size: 14px;'>{html.escape(final)}</span>
</div>"""

    return gr.update(value=styled_output), final
|
|
def parse_instruction(instruction_text):
    """Split a combined "Left hand: ... Right hand: ..." instruction into its parts.

    Matching is case-insensitive, and the word "hand" is optional ("Left: ..."
    also works). The left part runs up to the next "Right[ hand]:" marker, or to
    the end of the text. The right part runs to the end of the text, across
    newlines. A side whose marker is missing comes back as "None.".

    Returns:
        ``(left_text, right_text)``, each stripped of surrounding whitespace.
    """
    import re

    # End the left part at the "Right[ hand]:" marker itself. This lets it contain
    # several sentences and any letters. The old `[^LR]` class, under IGNORECASE,
    # rejected every l/r after a period.
    left_match = re.search(
        r'Left(?:\s+hand)?:\s*(.*?)\s*(?=Right(?:\s+hand)?:|$)',
        instruction_text,
        re.IGNORECASE | re.DOTALL,
    )
    right_match = re.search(
        r'Right(?:\s+hand)?:\s*(.+)$',
        instruction_text,
        re.IGNORECASE | re.DOTALL,
    )

    left_text = left_match.group(1).strip() if left_match else "None."
    right_text = right_match.group(1).strip() if right_match else "None."

    return left_text, right_text
|
|
def _hand_labels_from_traj(traj, beta, num_frames):
    """Build the MANO label dict passed to ``process_single_hand_labels``.

    ``traj`` is indexed as ``(num_frames, 51)``:
      * columns 0:3 hold the translation;
      * columns 3:6 hold the global orientation as 'xyz' Euler angles;
      * columns 6:51 hold the hand pose as Euler angles.
    Both orientation and pose are converted to rotation matrices here.
    """
    return {
        'transl_worldspace': traj[:, 0:3],
        'global_orient_worldspace': R.from_euler('xyz', traj[:, 3:6]).as_matrix(),
        'hand_pose': euler_traj_to_rotmat_traj(traj[:, 6:51], num_frames),
        'beta': beta,
    }


def _tile_sample_frames(all_rendered_frames, sample_times):
    """Tile the per-sample renderings into one near-square grid per time step.

    ``all_rendered_frames[i][t]`` is frame ``t`` of sample ``i``; all frames
    are assumed to share one shape. Unused grid cells are filled with black.

    Returns:
        List of combined frames, one per time step.
    """
    grid_cols = math.ceil(math.sqrt(sample_times))
    grid_rows = math.ceil(sample_times / grid_cols)
    num_cells = grid_rows * grid_cols

    combined_frames = []
    for frame_idx in range(len(all_rendered_frames[0])):
        cells = [all_rendered_frames[i][frame_idx] for i in range(sample_times)]
        cells += [np.zeros_like(cells[0])] * (num_cells - len(cells))
        rows = [
            np.concatenate(cells[r * grid_cols:(r + 1) * grid_cols], axis=1)
            for r in range(grid_rows)
        ]
        combined_frames.append(np.concatenate(rows, axis=0))
    return combined_frames


def _unique_output_path(output_dir, max_age_s=600):
    """Return a fresh ``prediction_<uuid>.mp4`` path inside ``output_dir``.

    A per-request name stops concurrent sessions from overwriting each other's
    video before it is served. Earlier predictions older than ``max_age_s``
    seconds are deleted on a best-effort basis, so the directory does not grow
    without bound. The 600 s default matches the Blocks ``delete_cache`` setting.
    """
    import time
    import uuid

    output_dir.mkdir(parents=True, exist_ok=True)
    cutoff = time.time() - max_age_s
    for stale in output_dir.glob("prediction_*.mp4"):
        try:
            if stale.stat().st_mtime < cutoff:
                stale.unlink()
        except OSError:
            pass  # e.g. already removed by another request
    return output_dir / f"prediction_{uuid.uuid4().hex}.mp4"


@spaces.GPU(duration=120)
def generate_prediction(instruction, use_left, use_right, sample_times, num_ddim_steps, cfg_scale, hand_data, image, progress=gr.Progress()):
    """Predict future hand motion with the VLA model and render it to a video.

    Args:
        instruction: combined "Left hand: ... Right hand: ..." text.
        use_left, use_right: which hands to predict.
        sample_times: number of sampled trajectories; each becomes one tile
            of the output video grid.
        num_ddim_steps, cfg_scale: diffusion settings forwarded to ``vla_predict``.
        hand_data: reconstruction result produced by ``process_image_upload``.
        image: the uploaded PIL image.
        progress: Gradio progress tracker.

    Returns:
        ``(video_path, status_message)``; ``video_path`` is ``None`` on failure.
    """
    import time

    # Retry a tiny CUDA allocation for up to 60 s in case device setup fails
    # transiently; stop on success or when CUDA is simply unavailable, and
    # never busy-spin. `except Exception` lets Ctrl-C / SystemExit through.
    deadline = time.time() + 60
    while time.time() < deadline:
        try:
            if torch.cuda.is_available():
                torch.zeros(1).cuda()
            break
        except Exception:
            time.sleep(2)

    services = (app_config, vla_model, vla_normalizer, visualizer, hand_config)
    if any(service is None for service in services):
        return None, "Services not initialized. Please wait for initialization to complete."

    if hand_data is None:
        return None, "Please upload an image and wait for hand reconstruction first"

    if not use_left and not use_right:
        return None, "Please select at least one hand (left or right)"

    try:
        progress(0, desc="Preparing data...")

        if image is None:
            return None, "Image not found. Please upload an image first."

        # Gradio sliders may deliver floats; range()/np.zeros need ints.
        sample_times = int(sample_times)
        num_ddim_steps = int(num_ddim_steps)

        try:
            image = ImageOps.exif_transpose(image)
        except Exception:
            pass  # best effort: keep the image as uploaded
        image = image.convert("RGB")
        # Measure AFTER orientation is applied so fov_y matches the pixels used.
        ori_w, ori_h = image.size

        image_resized = resize_short_side_to_target(image, target=224)
        w, h = image_resized.size

        progress(0.1, desc="Extracting hand states...")

        state_left = beta_left = state_right = beta_right = None
        if use_right:
            state_right, beta_right, fov_x, _ = get_state(hand_data, hand_side='right')
        if use_left:
            # When both hands are used, the left hand's field of view is kept.
            state_left, beta_left, fov_x, _ = get_state(hand_data, hand_side='left')

        if state_left is None and state_right is None:
            return None, "No valid hand states found"

        # The model always receives both hands. An unused hand is zero-filled
        # with the shape of the used one and switched off in state_mask.
        if not use_left:
            state_left, beta_left = np.zeros_like(state_right), np.zeros_like(beta_right)
        if not use_right:
            state_right, beta_right = np.zeros_like(state_left), np.zeros_like(beta_left)

        # get_state's horizontal FoV is in degrees; convert to radians.
        fov_x = fov_x * np.pi / 180
        # Vertical FoV from the original image's focal length and height.
        f_ori = ori_w / np.tan(fov_x / 2) / 2
        fov_y = 2 * np.arctan(ori_h / (2 * f_ori))
        # Focal length (pixels) of the resized image at the same horizontal FoV.
        focal = w / np.tan(fov_x / 2) / 2

        state = np.concatenate([state_left, beta_left, state_right, beta_right], axis=0)
        state_mask = np.array([use_left, use_right], dtype=bool)

        configs = load_config(app_config.config_path)
        chunk_size = configs.get('fwd_pred_next_n', 16)
        hand_flags = np.array([[use_left, use_right]], dtype=bool)
        action_mask = np.tile(hand_flags, (chunk_size, 1))

        fov = np.array([fov_x, fov_y], dtype=np.float32)
        image_resized_np = np.array(image_resized)

        progress(0.3, desc="Running VLA inference...")

        unnorm_action = vla_predict(
            model=vla_model,
            normalizer=vla_normalizer,
            image=image_resized_np,
            instruction=instruction,
            state=state,
            state_mask=state_mask,
            action_mask=action_mask,
            fov=fov,
            num_ddim_steps=num_ddim_steps,
            cfg_scale=cfg_scale,
            sample_times=sample_times,
        )

        progress(0.6, desc="Visualizing predictions...")

        renderer = Renderer(w, h, (focal, focal), 'cuda')

        # Current state plus chunk_size predicted steps.
        T = chunk_size + 1
        traj_mask = np.tile(hand_flags, (T, 1))
        left_hand_mask = traj_mask[:, 0]
        right_hand_mask = traj_mask[:, 1]
        hand_mask = (left_hand_mask, right_hand_mask)

        all_rendered_frames = []
        for i in range(sample_times):
            progress(0.6 + 0.3 * (i / sample_times), desc=f"Rendering sample {i+1}/{sample_times}...")

            if use_left:
                traj_left = recon_traj(state=state_left, rel_action=unnorm_action[i, :, 0:51])
            else:
                traj_left = np.zeros((T, 51), dtype=np.float32)
            if use_right:
                traj_right = recon_traj(state=state_right, rel_action=unnorm_action[i, :, 51:102])
            else:
                traj_right = np.zeros((T, 51), dtype=np.float32)

            verts_left_worldspace, _ = process_single_hand_labels(
                _hand_labels_from_traj(traj_left, beta_left, T), left_hand_mask, visualizer.mano, is_left=True)
            verts_right_worldspace, _ = process_single_hand_labels(
                _hand_labels_from_traj(traj_right, beta_right, T), right_hand_mask, visualizer.mano, is_left=False)

            # Identity extrinsics: the camera frame is used as the world frame.
            R_w2c = np.broadcast_to(np.eye(3), (T, 3, 3)).copy()
            t_w2c = np.zeros((T, 3, 1), dtype=np.float32)

            # The same (RGB -> BGR) still image is the background of every frame.
            image_bgr = image_resized_np[..., ::-1]
            save_frames = visualizer._render_hand_trajectory(
                [image_bgr] * T,
                (verts_left_worldspace, verts_right_worldspace),
                hand_mask,
                (R_w2c, t_w2c),
                renderer,
                mode='first'
            )
            all_rendered_frames.append(save_frames)

        progress(0.95, desc="Creating output video...")

        combined_frames = _tile_sample_frames(all_rendered_frames, sample_times)

        output_path = _unique_output_path(Path("./temp_gradio/outputs"))
        save_to_video(combined_frames, str(output_path), fps=hand_config.FPS)

        progress(1.0, desc="Complete!")

        return str(output_path), f"✅ Generated {sample_times} prediction samples successfully!"

    except Exception as e:
        import traceback
        error_msg = f"❌ Prediction failed: {str(e)}\n{traceback.format_exc()}"
        return None, error_msg
|
|
|
|
def load_examples():
    """Collect the bundled example images and their default instructions.

    Scans the ``examples`` folder next to this file for ``.jpg``/``.jpeg``/``.png``
    files (case-insensitive suffix), in sorted path order. Images without a
    known instruction get a generic left-hand placeholder.

    Returns:
        ``(examples_images, instructions_map)``. ``examples_images`` is a list of
        one-element ``[path_str]`` lists. ``instructions_map`` maps each
        ``path_str`` to its instruction. Both are empty if the folder is missing.
    """
    examples_dir = Path(__file__).parent / "examples"

    default_instructions = {
        "0001.jpg": "Left hand: Put the trash into the garbage. Right hand: None.",
        "0002.jpg": "Left hand: None. Right hand: Pick up the picture of Michael Jackson.",
        "0003.png": "Left hand: None. Right hand: Pick up the metal water cup.",
        "0004.jpg": "Left hand: Squeeze the dish sponge. Right hand: None.",
        "0005.jpg": "Left hand: None. Right hand: Cut the meat with the knife.",
        "0006.jpg": "Left hand: Open the closet door. Right hand: None.",
        "0007.jpg": "Left hand: None. Right hand: Cut the paper with the scissors.",
        "0008.jpg": "Left hand: Wipe the countertop with the cloth. Right hand: None.",
        "0009.jpg": "Left hand: None. Right hand: Open the cabinet door.",
        "0010.png": "Left hand: None. Right hand: Turn on the faucet.",
        "0011.jpg": "Left hand: Put the drink bottle into the trash can. Right hand: None.",
        "0012.jpg": "Left hand: None. Right hand: Pick up the gray cup from the cabinet.",
        "0013.jpg": "Left hand: None. Right hand: Take the milk bottle out of the fridge.",
        "0014.jpg": "Left hand: None. Right hand: 拿起气球。",
        "0015.jpg": "Left hand: None. Right hand: Pick up the picture with the smaller red heart.",
        "0016.jpg": "Left hand: None. Right hand: Pick up the picture with \"Cat\".",
        "0017.jpg": "Left hand: None. Right hand: Pick up the picture of the Statue of Liberty.",
        "0018.jpg": "Left hand: None. Right hand: Pick up the picture of the two people.",
    }
    fallback_instruction = "Left hand: Perform the action. Right hand: None."
    image_suffixes = {'.jpg', '.jpeg', '.png'}

    if not examples_dir.exists():
        return [], {}

    image_files = sorted(
        entry for entry in examples_dir.iterdir()
        if entry.suffix.lower() in image_suffixes
    )

    examples_images = [[str(entry)] for entry in image_files]
    instructions_map = {
        str(entry): default_instructions.get(entry.name, fallback_instruction)
        for entry in image_files
    }
    return examples_images, instructions_map
|
|
|
|
def get_instruction_for_image(image_path, instructions_map):
    """Look up the default instruction for an example image path.

    The path is compared as a string; unknown paths yield ``""``. A ``None``
    path returns an empty ``gr.update()`` so the bound component is unchanged.
    """
    if image_path is not None:
        return instructions_map.get(str(image_path), "")
    return gr.update()
|
|
|
|
|
|
def create_gradio_interface():
    """Build the Gradio Blocks UI and wire all of its event handlers.

    Layout:
      * Input column: image upload, reconstruction status, hand selection,
        per-hand instructions, advanced sampling settings and the generate button.
      * Output column: rendered video, generation status and the example gallery.

    Event flow:
      * Changing the image validates it and runs hand reconstruction. It then
        enables only the checkboxes and instruction boxes of detected hands.
      * Toggling a hand or editing an instruction rebuilds the combined
        "Left hand: ... Right hand: ..." instruction.
      * Selecting a gallery example loads its image and instructions.
      * The generate button runs ``generate_prediction``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    # Instructions shown before the user picks an example or types their own.
    default_left_instruction = "Put the trash into the garbage."
    default_right_instruction = "None."
    default_final_instruction = (
        f"Left hand: {default_left_instruction} Right hand: {default_right_instruction}"
    )

    with gr.Blocks(delete_cache=(600, 600), title="3D Hand Motion Prediction with VITRA") as demo:

        # CSS classes used to grey out controls of hands that were not detected.
        gr.HTML("""
        <style>
        .disabled-checkbox {
            opacity: 0.5 !important;
            pointer-events: none !important;
        }
        .disabled-textbox textarea {
            background-color: #f5f5f5 !important;
            color: #9e9e9e !important;
            cursor: not-allowed !important;
        }
        </style>
        """)

        gr.HTML("""
        <div align="center">
            <h1> 🤖 Hand Action Prediction with <a href="https://microsoft.github.io/VITRA/" target="_blank" style="text-decoration: underline; font-weight: bold; color: #4A90E2;">VITRA</a> <a title="Github" href="https://github.com/microsoft/VITRA" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> <img src="https://img.shields.io/github/stars/microsoft/VITRA?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars"> </a> </h1>
        </div>

        <div style="line-height: 1.8;">
            <br>
            <p style="font-size: 16px;">Upload a <strong style="color: #7C4DFF;">landscape</strong>, <strong style="color: #7C4DFF;">egocentric (first-person)</strong> image containing hand(s) and provide instructions to predict future 3D hand trajectories.</p>

            <h3>🌟 Steps:</h3>
            <ol>
                <li>Upload a landscape view image containing hand(s).</li>
                <li>Enter text instructions describing the desired task.</li>
                <li>Configure advanced settings (Optional) and click "Generate 3D Hand Trajectory".</li>
            </ol>

            <h3>💡 Tips:</h3>
            <ul>
                <li><strong>Use Left/Right Hand</strong>: Select which hand to predict based on what's detected and what you want to predict.</li>
                <li><strong>Instruction</strong>: Provide clear and specific imperative instructions separately for the left and right hands, and enter them in the corresponding fields. If the results are unsatisfactory, <strong style="color: #7C4DFF;">try providing more detailed instructions</strong> (e.g., color, orientation, etc.).</li>
                <li>For best inference quality, it is recommended to <strong style="color: #7C4DFF;">capture landscape view images from a camera height close to that of a human head</strong>. Highly unusual or distorted hand poses/positions may cause inference failures.</li>
                <li>It is worth noting that each generation produces only a single action chunking starting from the current state, which <strong style="color: #7C4DFF;">does not necessarily complete the entire task</strong>. Executing an entire chunking in one step may lead to reduced precision.</li>
            </ul>

        </div>

        <hr style='border: none; border-top: 1px solid #e0e0e0; margin: 20px 0;'>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("""
                <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; border-radius: 8px; margin-bottom: 15px;'>
                    <h3 style='color: white; margin: 0; text-align: center;'>📄 Input</h3>
                </div>
                """)

                input_image = gr.Image(
                    label="🖼️ Upload Image with Hands",
                    type="pil",
                    height=300,
                )

                recon_status = gr.Textbox(
                    label="🔍 Hand Reconstruction Status",
                    value="⏳ Waiting for image upload...",
                    interactive=False,
                    lines=2,
                    container=True
                )

                gr.Markdown("---")
                gr.HTML("""
                <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; border-radius: 8px; margin-bottom: 15px;'>
                    <h3 style='color: white; margin: 0; text-align: center;'>⚙️ Prediction Settings</h3>
                </div>
                """)
                gr.HTML("""
                <div style='padding: 8px; background-color: #e8eaf6; border-left: 4px solid #5c6bc0; border-radius: 4px; margin-bottom: 10px;'>
                    <strong style='color: #3949ab;'>👋 Select Hands:</strong>
                </div>
                """)
                with gr.Row():
                    use_left = gr.Checkbox(label="Use Left Hand", value=True)
                    use_right = gr.Checkbox(label="Use Right Hand", value=True)

                gr.HTML("""
                <div style='padding: 8px; background-color: #e8eaf6; border-left: 4px solid #5c6bc0; border-radius: 4px; margin: 15px 0 10px 0;'>
                    <strong style='color: #3949ab;'>✍️ Instructions:</strong>
                </div>
                """)
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            gr.HTML("<div style='display: flex; align-items: center; min-height: 40px; padding-right: 2px;'><span style='font-weight: 600; color: #5c6bc0; white-space: nowrap;'>Left hand:</span></div>")
                            left_instruction = gr.Textbox(
                                label="",
                                value=default_left_instruction,
                                lines=1,
                                max_lines=5,
                                placeholder="Describe left hand action...",
                                show_label=False,
                                interactive=True,
                                scale=3
                            )
                    with gr.Column():
                        with gr.Row():
                            gr.HTML("<div style='display: flex; align-items: center; min-height: 40px; padding-right: 2px;'><span style='font-weight: 600; color: #5c6bc0; white-space: nowrap;'>Right hand:</span></div>")
                            right_instruction = gr.Textbox(
                                label="",
                                value=default_right_instruction,
                                lines=1,
                                max_lines=5,
                                placeholder="Describe right hand action...",
                                show_label=False,
                                interactive=True,
                                scale=3
                            )

                # Styled preview of the combined instruction plus the plain text
                # actually passed to generate_prediction.
                final_instruction = gr.HTML(
                    value=f"""<div style='padding: 12px; background-color: #f0f7ff; border-left: 4px solid #4A90E2; border-radius: 4px; margin-top: 10px;'>
<strong style='color: #2c5282;'>📝 Final Instruction:</strong><br>
<span style='color: #1a365d; font-size: 14px;'>{default_final_instruction}</span>
</div>""",
                    show_label=False
                )
                final_instruction_text = gr.State(value=default_final_instruction)

                with gr.Accordion("🔧 Advanced Settings", open=False):
                    sample_times = gr.Slider(
                        minimum=1,
                        maximum=9,
                        value=4,
                        step=1,
                        label="Number of Samples",
                        info="Multiple samples show different possible trajectories."
                    )
                    # step=1 so the default of 10 lies on the slider grid
                    # (with minimum=1 and step=5 only 1, 6, 11, ... were reachable).
                    num_ddim_steps = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=10,
                        step=1,
                        label="DDIM Steps",
                        info="DDIM steps of the diffusion model. 10 is usually sufficient."
                    )
                    cfg_scale = gr.Slider(
                        minimum=1.0,
                        maximum=15.0,
                        value=5.0,
                        step=0.5,
                        label="CFG Scale",
                        info="Classifier-free guidance scale of the diffusion model."
                    )

                generate_btn = gr.Button("🎬 Generate 3D Hand Trajectory", variant="primary", size="lg")

                # Per-session state filled by hand reconstruction.
                hand_data = gr.State(value=None)
                detected_left = gr.State(value=False)
                detected_right = gr.State(value=False)
                session_state = gr.State(value={})

            with gr.Column(scale=1):
                gr.HTML("""
                <div style='background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%); padding: 15px; border-radius: 8px; margin-bottom: 15px;'>
                    <h3 style='color: white; margin: 0; text-align: center;'>🎬 Output</h3>
                </div>
                """)

                output_video = gr.Video(
                    label="🎬 Predicted Hand Motion",
                    height=500,
                    autoplay=True
                )

                gen_status = gr.Textbox(
                    label="📊 Generation Status",
                    value="",
                    interactive=False,
                    lines=2
                )

                gr.Markdown("---")
                gr.HTML("""
                <div style='background: linear-gradient(135deg, #89f7fe 0%, #66a6ff 100%); padding: 15px; border-radius: 8px; margin: 20px 0 10px 0;'>
                    <h3 style='color: white; margin: 0; text-align: center;'>📋 Examples</h3>
                </div>
                """)
                gr.HTML("""
                <div style='padding: 10px; background-color: #e7f3ff; border-left: 4px solid #2196F3; border-radius: 4px; margin-bottom: 15px;'>
                    <span style='color: #1565c0;'>👆 Click any example below to load the image and instruction</span>
                </div>
                """)

                examples_images, instructions_map = load_examples()

                # Numeric height is interpreted as pixels; the former "450"
                # string had no CSS unit.
                example_gallery = gr.Gallery(
                    value=[img[0] for img in examples_images],
                    label="",
                    columns=6,
                    height=450,
                    object_fit="contain",
                    show_label=False
                )

                def load_example_from_gallery(evt: gr.SelectData):
                    """Load the clicked example's image and instructions.

                    The generate button is disabled until the new image has been
                    reconstructed (the image-change handler re-enables it).
                    """
                    selected_index = evt.index
                    if selected_index < len(examples_images):
                        img_path = examples_images[selected_index][0]
                        instruction_text = instructions_map.get(img_path, "")
                        left_text, right_text = parse_instruction(instruction_text)
                        return gr.update(value=img_path), gr.update(value=left_text), gr.update(value=right_text), gr.update(interactive=False)
                    return gr.update(), gr.update(), gr.update(), gr.update()

        example_gallery.select(
            fn=load_example_from_gallery,
            inputs=[],
            outputs=[input_image, left_instruction, right_instruction, generate_btn],
            show_progress=False
        ).then(
            fn=update_final_instruction,
            inputs=[left_instruction, right_instruction, use_left, use_right],
            outputs=[final_instruction, final_instruction_text],
            show_progress=False
        )

        # New image: validate + reconstruct, then sync the hand controls.
        input_image.change(
            fn=validate_and_process_wrapper,
            inputs=[input_image, session_state],
            outputs=[recon_status, generate_btn, hand_data, detected_left, detected_right, session_state],
            show_progress='full'
        ).then(
            fn=update_checkboxes,
            inputs=[detected_left, detected_right],
            outputs=[use_left, use_right, left_instruction, right_instruction],
            show_progress=False
        )

        # Toggling either hand updates textbox interactivity, then the combined instruction.
        for hand_checkbox in (use_left, use_right):
            hand_checkbox.change(
                fn=update_instruction_interactivity,
                inputs=[use_left, use_right],
                outputs=[left_instruction, right_instruction],
                show_progress=False
            ).then(
                fn=update_final_instruction,
                inputs=[left_instruction, right_instruction, use_left, use_right],
                outputs=[final_instruction, final_instruction_text],
                show_progress=False
            )

        # Editing either instruction rebuilds the combined instruction.
        for instruction_box in (left_instruction, right_instruction):
            instruction_box.change(
                fn=update_final_instruction,
                inputs=[left_instruction, right_instruction, use_left, use_right],
                outputs=[final_instruction, final_instruction_text],
                show_progress=False
            )

        generate_btn.click(
            fn=generate_prediction,
            inputs=[final_instruction_text, use_left, use_right, sample_times, num_ddim_steps, cfg_scale, hand_data, input_image],
            outputs=[output_video, gen_status],
            show_progress='full'
        )

    return demo
|
|
if __name__ == "__main__":
    # Launch the Gradio app. Load every model up front, then build and serve the UI.
    # A failed initialization is reported, but the UI still launches; its
    # handlers then answer with a "Services not initialized" message.
    print("Initializing services...")
    init_msg = initialize_services()
    print(init_msg)

    initialization_failed = "Failed" in init_msg
    if initialization_failed:
        print("⚠️ Services failed to initialize. Please check the configuration and try again.")

    demo = create_gradio_interface()
    demo.launch()