pi05-so100-diverse / infer_so101.py
bot
Restore all project files from original repo
9ad6280
#!/usr/bin/env python3
"""
Run Pi0.5 inference on SO-101.
Uses LeRobot's FeetechMotorsBus with calibration for correct normalization,
but bypasses lerobot_record's problematic control loop.
Usage:
python infer_so101.py --task "pick up the blue football"
"""
import argparse
import json
import logging
import sys
import time
from pathlib import Path
import cv2
import numpy as np
import scservo_sdk as scs
import torch
# Make the script's own directory and a local LeRobot source checkout importable.
# Each entry is pushed to the front of sys.path, so the LeRobot checkout (inserted
# last) ends up being searched first.
for _extra_path in (Path(__file__).parent, Path.home() / "lerobot" / "src"):
    sys.path.insert(0, str(_extra_path))
del _extra_path

logging.basicConfig(level=logging.WARNING, format='%(asctime)s %(message)s', datefmt='%H:%M:%S')
# Root logger; this script reports all progress at WARNING level.
log = logging.getLogger()

# Joint names in the order used for observation.state and the action vector,
# paired index-for-index with their Feetech bus IDs (1..6).
MOTOR_NAMES = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
MOTOR_IDS = list(range(1, len(MOTOR_NAMES) + 1))
def _parse_args(argv=None):
    """Parse command-line options for the inference script.

    Args:
        argv: Argument strings to parse; ``None`` parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``task``, ``checkpoint``, ``port``,
        ``cam_front``, ``cam_wrist``, ``max_steps``, ``fps``, ``device`` and
        ``calibration``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, required=True)
    parser.add_argument("--checkpoint", type=str,
                        default="/mnt/hdd/pi05-training/full_run/checkpoints/015000/pretrained_model")
    parser.add_argument("--port", type=str, default="/dev/ttyACM0")
    parser.add_argument("--cam-front", type=int, default=2)
    parser.add_argument("--cam-wrist", type=int, default=0)
    parser.add_argument("--max-steps", type=int, default=0, help="0 = run until Ctrl+C")
    # NOTE(review): without a cap the loop runs as fast as reads + inference
    # allow. If select_action serves queued steps of an action chunk, that can
    # be faster than the dataset was recorded at; consider setting this to the
    # dataset fps.
    parser.add_argument("--fps", type=float, default=0.0,
                        help="cap the control loop at this rate in Hz (0 = no cap)")
    parser.add_argument("--device", type=str, default="cuda",
                        help="torch device for the policy and preprocessor")
    parser.add_argument(
        "--calibration", type=str,
        default=str(Path.home() / ".cache/huggingface/lerobot/calibration/robots/so_follower/my_so101.json"),
        help="LeRobot motor calibration JSON for this arm",
    )
    return parser.parse_args(argv)


def _release_bus(bus):
    """Disable torque on every motor and close the bus, never raising.

    Falls back to writing 0 to register 40 (Torque_Enable in the STS3215
    control table) on each motor ID when the bus-level disable fails. Errors are
    logged instead of raised so callers in cleanup paths can go on releasing
    other resources.
    """
    log.warning("Disabling torque...")
    try:
        bus.disable_torque()
    except Exception:
        log.warning("bus.disable_torque() failed; writing Torque_Enable=0 per motor", exc_info=True)
        for motor_id in MOTOR_IDS:
            try:
                bus.packet_handler.write1ByteTxRx(bus.port_handler, motor_id, 40, 0)
            except Exception:
                log.warning("Could not disable torque on motor %d", motor_id)
    try:
        bus.disconnect()
    except Exception:
        log.warning("bus.disconnect() failed", exc_info=True)


def _reset_port(bus):
    """Clear the port's in-use flag and flush pending input after a failed transfer."""
    bus.port_handler.is_using = False
    bus.port_handler.ser.reset_input_buffer()


def _connect_bus(port, calibration_path):
    """Open the Feetech bus, apply calibration and configure every motor.

    The calibration JSON is read before the port is opened, so a missing or
    malformed file fails without touching hardware. Motor configuration mirrors
    LeRobot's so_follower.configure(). It runs inside ``torque_disabled()``,
    which re-enables torque on exit, so on success the returned bus has torque
    enabled. If anything fails after ``connect()``, the bus is released before
    the exception propagates.

    Args:
        port: Serial device path of the motor bus.
        calibration_path: Path to a LeRobot calibration JSON keyed by motor name.

    Returns:
        The connected, calibrated and configured FeetechMotorsBus.
    """
    from lerobot.motors import Motor, MotorCalibration, MotorNormMode
    from lerobot.motors.feetech.feetech import FeetechMotorsBus

    with open(calibration_path, encoding="utf-8") as f:
        raw_calibration = json.load(f)
    calibration = {name: MotorCalibration(**vals) for name, vals in raw_calibration.items()}

    # Arm joints are normalized to [-100, 100]; the gripper to [0, 100].
    motors = {
        name: Motor(
            motor_id,
            "sts3215",
            MotorNormMode.RANGE_0_100 if name == "gripper" else MotorNormMode.RANGE_M100_100,
        )
        for name, motor_id in zip(MOTOR_NAMES, MOTOR_IDS)
    }
    bus = FeetechMotorsBus(port=port, motors=motors)
    bus.connect()
    try:
        bus.write_calibration(calibration)
        log.warning("Bus connected with calibration")
        with bus.torque_disabled():
            bus.configure_motors()
            for motor in bus.motors:
                bus.write("Operating_Mode", motor, 0)  # Position mode
                bus.write("P_Coefficient", motor, 16)
                bus.write("I_Coefficient", motor, 0)
                bus.write("D_Coefficient", motor, 32)
                bus.write("Goal_Velocity", motor, 600)  # Slow velocity limit
                bus.write("Acceleration", motor, 50)  # Gentle acceleration
                if motor == "gripper":
                    # Extra torque/current protection limits for the gripper only.
                    bus.write("Max_Torque_Limit", motor, 500)
                    bus.write("Protection_Current", motor, 250)
                    bus.write("Overload_Torque", motor, 25)
    except BaseException:
        _release_bus(bus)
        raise
    log.warning("Motors configured and torque enabled (velocity/accel limited)")
    return bus


def _open_camera(index, width=640, height=480):
    """Open an OpenCV camera and request a resolution; raise if it cannot be opened.

    Raises:
        RuntimeError: the device could not be opened (previously this made the
            control loop spin forever on failed reads).
    """
    cap = cv2.VideoCapture(index)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open camera {index}")
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
    return cap


def _read_rgb(cap):
    """Grab one frame as an RGB uint8 array, or return None if the read failed.

    OpenCV decodes frames as BGR. LeRobot's OpenCVCamera defaults to RGB output,
    so datasets recorded through it (and thus the policy) presumably expect RGB.
    Confirm against the training data. The conversion presumably also fixes
    swapped colours in the Rerun view.
    """
    ok, frame = cap.read()
    if not ok:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def _load_policy(checkpoint, device="cuda"):
    """Load Pi0.5 and its saved pre/post-processors from a checkpoint directory.

    The preprocessor is overridden to run on ``device`` and to rename this
    script's camera keys to the image keys used in the checkpoint.

    Returns:
        (policy, preprocessor, postprocessor)
    """
    from lerobot.configs.policies import PreTrainedConfig
    from lerobot.policies.factory import make_pre_post_processors
    from lerobot.policies.pi05.modeling_pi05 import PI05Policy

    log.warning("Loading Pi0.5...")
    policy_cfg = PreTrainedConfig.from_pretrained(checkpoint)
    policy_cfg.pretrained_path = Path(checkpoint)
    policy = PI05Policy.from_pretrained(checkpoint).to(device)
    policy.eval()
    policy.reset()

    rename_map = {
        "observation.images.front": "observation.images.base_0_rgb",
        "observation.images.wrist": "observation.images.left_wrist_0_rgb",
    }
    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=policy_cfg,
        pretrained_path=policy_cfg.pretrained_path,
        preprocessor_overrides={
            "device_processor": {"device": device},
            "rename_observations_processor": {"rename_map": rename_map},
        },
    )
    return policy, preprocessor, postprocessor


def _init_rerun():
    """Spawn a Rerun viewer for live camera/state display; return the module or None."""
    try:
        import rerun as rr
    except ImportError:
        log.warning("Rerun not available, no live view")
        return None
    rr.init("so101_inference", spawn=True)
    log.warning("Rerun viewer launched — live camera feed")
    return rr


def _remaining_sleep(elapsed, fps):
    """Return the seconds left so one loop iteration lasts 1/fps; 0.0 if fps <= 0 or overrun."""
    if fps <= 0:
        return 0.0
    return max(0.0, 1.0 / fps - elapsed)


def main():
    """Run closed-loop Pi0.5 inference on the SO-101 until Ctrl+C or --max-steps.

    Each step reads joint positions and both camera frames, runs them through
    the preprocessor, policy and postprocessor, and writes the resulting joint
    targets back to the bus. The policy is loaded before any hardware is
    touched. Everything acquired afterwards is released in ``finally``, so
    torque is disabled even when setup or the loop fails.
    """
    args = _parse_args()

    from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference

    device = torch.device(args.device)
    ds_features = {"action": {"names": [f"{name}.pos" for name in MOTOR_NAMES]}}

    bus = cap_front = cap_wrist = None
    try:
        # Load the (slow, failure-prone) policy first: a failed load then never
        # leaves the arm connected with torque enabled.
        policy, preprocessor, postprocessor = _load_policy(args.checkpoint, args.device)
        rr = _init_rerun()

        bus = _connect_bus(args.port, args.calibration)
        cap_front = _open_camera(args.cam_front)
        cap_wrist = _open_camera(args.cam_wrist)
        log.warning("Cameras open")

        log.warning(f"Running: '{args.task}' — Ctrl+C to stop")
        step = 0
        while args.max_steps == 0 or step < args.max_steps:
            t0 = time.perf_counter()

            # 1. Joint positions (calibrated/normalized by the bus).
            try:
                pos_dict = bus.sync_read("Present_Position", num_retry=5)
            except ConnectionError:
                log.warning("Present_Position read failed; resetting port and retrying")
                _reset_port(bus)
                continue
            state_array = np.array([pos_dict[name] for name in MOTOR_NAMES], dtype=np.float32)

            # 2. Camera frames as RGB; a failed read skips this step.
            frame_front = _read_rgb(cap_front)
            frame_wrist = _read_rgb(cap_wrist)
            if frame_front is None or frame_wrist is None:
                continue

            if rr is not None:
                rr.set_time_sequence("step", step)
                rr.log("camera/front", rr.Image(frame_front))
                rr.log("camera/wrist", rr.Image(frame_wrist))
                rr.log("state", rr.BarChart([pos_dict[n] for n in MOTOR_NAMES]))

            observation = {
                "observation.images.front": frame_front,
                "observation.images.wrist": frame_wrist,
                "observation.state": state_array,
            }

            # 3. Inference.
            with torch.inference_mode():
                obs = prepare_observation_for_inference(
                    observation, device, args.task, "so101_follower"
                )
                obs = preprocessor(obs)
                action = postprocessor(policy.select_action(obs))

            # 4. Action vector -> {"<motor>.pos": value}.
            robot_action = make_robot_action(action, ds_features)

            # 5. Send joint targets (calibrated/normalized by the bus).
            goal_pos = {name: robot_action[f"{name}.pos"] for name in MOTOR_NAMES}
            try:
                bus.sync_write("Goal_Position", goal_pos)
            except ConnectionError:
                log.warning("Goal_Position write failed; resetting port")
                _reset_port(bus)

            dt = time.perf_counter() - t0
            step += 1
            if step % 10 == 0:
                pos_str = " ".join(f"{pos_dict[n]:>7.1f}" for n in MOTOR_NAMES)
                act_str = " ".join(f"{robot_action[f'{n}.pos']:>7.1f}" for n in MOTOR_NAMES)
                log.warning(f"step {step:>4} | state=[{pos_str}] | action=[{act_str}] | {dt*1000:.0f}ms")

            time.sleep(_remaining_sleep(time.perf_counter() - t0, args.fps))
    except KeyboardInterrupt:
        log.warning("Stopped by user")
    finally:
        # Torque first: it is the safety-critical release.
        if bus is not None:
            _release_bus(bus)
        for cap in (cap_front, cap_wrist):
            if cap is not None:
                cap.release()
        log.warning("Done")


if __name__ == "__main__":
    main()