# Wuji Hand Gesture VAM (TI2V-5B + 30L ActionDiT)
World Action Model checkpoints for hand gesture manipulation, trained with a TI2V-5B video backbone and a 30-layer ActionDiT action decoder.
## Model Architecture

- Video Backbone: Wan2.2-TI2V-5B (5B parameters, `video_dim=3072`)
- Action Decoder: ActionDiT with 30 transformer layers (`dim=768`, `ffn_dim=3072`, `num_heads=12`)
- Bridge Type: `cross_attn_detach` (REPA-style, stop gradient at the bridge)
- Bridge Layers: all 30 layers (`0,1,2,...,29`), with `bridge_exclude_full_ref` enabled
- Action Dimension: 20
- Context Frames: 17
- Video Resolution: 480 x 832
- Denoising: Flow matching, 20 steps at inference
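The flow-matching inference step above can be sketched as plain Euler integration of a learned velocity field over 20 steps. The toy velocity field below is a stand-in for illustration only; the real ActionDiT predicts the velocity conditioned on video features.

```python
import numpy as np

def flow_matching_sample(velocity_fn, x0, num_steps=20):
    """Euler-integrate a flow-matching ODE from noise (t=0) to a sample (t=1)."""
    x = x0.copy()
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        x = x + dt * velocity_fn(x, t)  # one Euler step along the predicted velocity
    return x

# Toy velocity field that flows toward a fixed target point
# (a hypothetical stand-in, not the real model's output).
target = np.zeros(20)  # action_dim = 20
def toy_velocity(x, t):
    return (target - x) / max(1.0 - t, 1e-3)

rng = np.random.default_rng(0)
noise = rng.standard_normal(20)
action = flow_matching_sample(toy_velocity, noise, num_steps=20)
```

With this particular field, each Euler step moves the sample a fraction `1/(N-k)` of the remaining distance to the target, so 20 steps land essentially on it; the real model replaces `toy_velocity` with a network forward pass per step.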
## Checkpoints

| Checkpoint | val/loss_action | val/loss_video | val/loss (total) | Notes |
|---|---|---|---|---|
| `step-3000.safetensors` | 0.167 | 0.040 | 0.207 | Best overall validation loss |
| `step-5000.safetensors` | 0.192 | 0.065 | 0.257 | Mid-late training |
| `step-7000.safetensors` | 0.216 | 0.057 | 0.273 | Latest checkpoint |
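As a sanity check on the table, the total validation loss is simply the sum of the action and video terms:

```python
# (loss_action, loss_video, loss_total) per checkpoint, from the table above
rows = {
    "step-3000": (0.167, 0.040, 0.207),
    "step-5000": (0.192, 0.065, 0.257),
    "step-7000": (0.216, 0.057, 0.273),
}
for name, (action, video, total) in rows.items():
    assert abs((action + video) - total) < 1e-9, name
```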
**Recommendation:** use `step-3000.safetensors` for the best action prediction quality.
## Training Details
- Dataset: Wuji hand gesture cropped dataset
- Training: 32 GPUs (4 nodes x 8 GPUs), multi-node distributed
- Learning rate: 5e-5
- Wandb project: `wuji_hand_gesture`
## Usage

```bash
# Server side
python -m deploy.scripts.serve_policy \
    --checkpoint step-3000.safetensors \
    --task_name "hand gesture" \
    --action_dim 20 \
    --bridge_type cross_attn_detach \
    --action_dit_num_layers 30 \
    --action_dit_bridge_layers "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29" \
    --video_dim 3072 \
    --num_frames 17 \
    --port 8000
```

```python
# Client side
from deploy.serving import WebsocketClientPolicy, ActionChunkBroker

client = WebsocketClientPolicy(host="gpu-server", port=8000)
broker = ActionChunkBroker(client, action_horizon=17, replan_steps=5)

result = broker.infer({"head_camera": observation_image})
action = result["actions"]  # (20,) single-step action
```
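The chunk/replan pattern behind `ActionChunkBroker` can be illustrated with a minimal self-contained sketch. This re-implementation is an assumption about its behavior, not the real `deploy.serving` API: the broker caches a chunk of `action_horizon` actions from one policy query and only re-queries the policy every `replan_steps` steps.

```python
class SimpleChunkBroker:
    """Minimal sketch of chunked action replanning (illustrative, not the real API)."""

    def __init__(self, policy, action_horizon=17, replan_steps=5):
        self.policy = policy
        self.action_horizon = action_horizon
        self.replan_steps = replan_steps
        self._chunk = None   # cached (action_horizon, action_dim) chunk
        self._step = 0       # index of the next action to replay from the chunk

    def infer(self, obs):
        # Re-query the policy every `replan_steps` steps; otherwise replay
        # the next action from the cached chunk.
        if self._chunk is None or self._step >= self.replan_steps:
            self._chunk = self.policy.infer(obs)["actions"]
            self._step = 0
        action = self._chunk[self._step]
        self._step += 1
        return {"actions": action}


class FakePolicy:
    """Stand-in for the remote policy; counts how often it is queried."""

    def __init__(self):
        self.calls = 0

    def infer(self, obs):
        self.calls += 1
        # Chunk of 17 toy actions tagged with the query number.
        return {"actions": [[self.calls, i] for i in range(17)]}


policy = FakePolicy()
broker = SimpleChunkBroker(policy, action_horizon=17, replan_steps=5)
for _ in range(12):
    out = broker.infer({"head_camera": None})
# 12 control steps with replan_steps=5 -> the policy is queried 3 times
```

Replanning every 5 steps instead of consuming all 17 actions trades extra inference calls for faster reaction to new observations.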