Wuji Hand Gesture VAM (TI2V-5B + 30L ActionDiT)

World Action Model checkpoints for hand gesture manipulation, trained with a TI2V-5B video backbone and a 30-layer ActionDiT action decoder.

Model Architecture

  • Video Backbone: Wan2.2-TI2V-5B (5B parameters, video_dim=3072)
  • Action Decoder: ActionDiT with 30 transformer layers
    • dim=768, ffn_dim=3072, num_heads=12
    • Bridge type: cross_attn_detach (REPA-style, stop gradient at bridge)
    • Bridge layers: all 30 layers (0,1,2,...,29)
    • bridge_exclude_full_ref: enabled
  • Action Dimension: 20
  • Context Frames: 17
  • Video Resolution: 480 x 832
  • Denoising: Flow matching, 20 steps at inference
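At inference the action decoder runs flow-matching denoising over 20 steps, i.e. it integrates a learned velocity field from noise to a sample. A minimal Euler-integration sketch (the velocity field below is a stand-in toy, not the trained ActionDiT):

```python
import numpy as np

def flow_matching_sample(velocity_fn, shape, num_steps=20, seed=0):
    """Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 with Euler steps."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape)          # start from Gaussian noise
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        x = x + dt * velocity_fn(x, t)      # Euler update along the learned flow
    return x

# Stand-in velocity field that pulls samples toward a fixed 20-dim action target.
target = np.ones(20)
actions = flow_matching_sample(lambda x, t: target - x, shape=(20,))
```

In the real model the velocity field is the 30-layer ActionDiT conditioned on video features through the cross-attention bridge; the loop structure is the same.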

Checkpoints

Checkpoint              val/loss_action   val/loss_video   val/loss (total)   Notes
step-3000.safetensors   0.167             0.040            0.207              Best overall validation loss
step-5000.safetensors   0.192             0.065            0.257              Mid-late training
step-7000.safetensors   0.216             0.057            0.273              Latest checkpoint

Recommendation: Use step-3000 for the best action prediction quality; it has both the lowest val/loss_action and the lowest total validation loss.

Training Details

  • Dataset: Wuji hand gesture cropped dataset
  • Training: 32 GPUs (4 nodes x 8 GPUs), multi-node distributed
  • Learning rate: 5e-5
  • Wandb project: wuji_hand_gesture
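The 32-GPU layout above (4 nodes × 8 GPUs) implies the usual multi-node rank bookkeeping: each process gets a unique global rank derived from its node index and local GPU index. A quick sketch of that mapping (function name is illustrative, not part of the training code):

```python
def global_rank(node_rank, local_rank, gpus_per_node=8):
    """Map (node index, local GPU index) to a unique global rank across the cluster."""
    return node_rank * gpus_per_node + local_rank

world_size = 4 * 8  # 32 GPUs total
ranks = [global_rank(n, g) for n in range(4) for g in range(8)]
```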

Usage

# Server side
python -m deploy.scripts.serve_policy \
    --checkpoint step-3000.safetensors \
    --task_name "hand gesture" \
    --action_dim 20 \
    --bridge_type cross_attn_detach \
    --action_dit_num_layers 30 \
    --action_dit_bridge_layers "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29" \
    --video_dim 3072 \
    --num_frames 17 \
    --port 8000

# Client side
from deploy.serving import WebsocketClientPolicy, ActionChunkBroker

client = WebsocketClientPolicy(host="gpu-server", port=8000)
broker = ActionChunkBroker(client, action_horizon=17, replan_steps=5)

result = broker.infer({"head_camera": observation_image})
action = result["actions"]  # (20,) single-step action
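The broker amortizes inference cost: it requests one 17-step action chunk from the server, executes replan_steps=5 of those actions, then re-queries with a fresh observation. A simplified sketch of that loop (an illustrative reimplementation, not the shipped ActionChunkBroker):

```python
class SimpleChunkBroker:
    """Serve `replan_steps` actions from each predicted chunk before re-planning."""

    def __init__(self, policy, action_horizon=17, replan_steps=5):
        self.policy = policy                  # callable: observation -> action chunk
        self.action_horizon = action_horizon
        self.replan_steps = replan_steps
        self._chunk = None
        self._idx = 0

    def infer(self, obs):
        # Re-plan when no chunk is cached or the replan budget is exhausted.
        if self._chunk is None or self._idx >= self.replan_steps:
            self._chunk = self.policy(obs)[: self.action_horizon]
            self._idx = 0
        action = self._chunk[self._idx]       # next single-step action
        self._idx += 1
        return {"actions": action}
```

With replan_steps=5 and action_horizon=17, the policy is queried once every 5 control steps, and the remaining 12 predicted actions of each chunk are discarded when the broker re-plans.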