#!/usr/bin/env python3
"""
Run Pi0.5 inference on SO-101.

Uses LeRobot's FeetechMotorsBus with calibration for correct normalization,
but bypasses lerobot_record's problematic control loop.

Usage:
    python infer_so101.py --task "pick up the blue football"
"""

import argparse
import json
import logging
import sys
import time
from pathlib import Path

import cv2
import numpy as np
import scservo_sdk as scs
import torch

sys.path.insert(0, str(Path(__file__).parent))
sys.path.insert(0, str(Path.home() / "lerobot" / "src"))

logging.basicConfig(level=logging.WARNING, format='%(asctime)s %(message)s', datefmt='%H:%M:%S')
log = logging.getLogger()

MOTOR_NAMES = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
MOTOR_IDS = [1, 2, 3, 4, 5, 6]

# Register address written with 0 by the raw fallback torque-off path.
# NOTE(review): presumably the STS3215 "Torque_Enable" register — confirm
# against the Feetech control table.
TORQUE_ENABLE_ADDR = 40

FRAME_WIDTH = 640
FRAME_HEIGHT = 480


def _open_camera(index):
    """Open the OpenCV camera at *index* and request 640x480 frames.

    Raises:
        RuntimeError: if the device cannot be opened. Failing here avoids
            entering the control loop with a dead camera, where every read
            would fail and the loop would spin with motor torque enabled.
    """
    cap = cv2.VideoCapture(index)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open camera index {index}")
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_WIDTH)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_HEIGHT)
    return cap


def _reset_serial_port(bus):
    """Clear the port's busy flag and input buffer after a failed transfer."""
    bus.port_handler.is_using = False
    bus.port_handler.ser.reset_input_buffer()


def _shutdown(bus, cameras):
    """Disable motor torque, disconnect the bus and release all cameras.

    Torque-off is best effort: if the bus-level call fails, each motor id is
    written individually through the raw packet handler. Cameras are released
    even when ``bus.disconnect()`` raises; that exception still propagates.
    """
    log.warning("Disabling torque...")
    try:
        bus.disable_torque()
    except Exception:
        log.warning("bus.disable_torque() failed; falling back to raw writes", exc_info=True)
        for mid in MOTOR_IDS:
            try:
                bus.packet_handler.write1ByteTxRx(bus.port_handler, mid, TORQUE_ENABLE_ADDR, 0)
            except Exception:
                # Keep going: the remaining motors should still be released.
                log.warning("Could not disable torque on motor id %d", mid, exc_info=True)
    try:
        bus.disconnect()
    finally:
        for cap in cameras:
            cap.release()


def main():
    """Parse CLI arguments, bring up the robot, and run the policy loop.

    Each iteration reads the motor positions, grabs one frame per camera,
    runs the policy, and writes the resulting goal positions to the bus.
    Runs until ``--max-steps`` iterations have completed (0 = until Ctrl+C).
    Torque is disabled and devices are released on any exit path once the
    bus has been connected.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, required=True)
    parser.add_argument("--checkpoint", type=str,
                        default="/mnt/hdd/pi05-training/full_run/checkpoints/015000/pretrained_model")
    parser.add_argument("--port", type=str, default="/dev/ttyACM0")
    parser.add_argument("--cam-front", type=int, default=2)
    parser.add_argument("--cam-wrist", type=int, default=0)
    parser.add_argument("--max-steps", type=int, default=0, help="0 = run until Ctrl+C")
    args = parser.parse_args()

    # --- Connect motors using LeRobot's bus (for calibration/normalization) ---
    from lerobot.motors.feetech.feetech import FeetechMotorsBus
    from lerobot.motors import Motor, MotorNormMode, MotorCalibration

    # Built from MOTOR_NAMES / MOTOR_IDS so the bus and the loop below cannot
    # drift apart; only the gripper uses the 0..100 range.
    motors = {
        name: Motor(
            motor_id,
            'sts3215',
            MotorNormMode.RANGE_0_100 if name == "gripper" else MotorNormMode.RANGE_M100_100,
        )
        for name, motor_id in zip(MOTOR_NAMES, MOTOR_IDS)
    }
    bus = FeetechMotorsBus(port=args.port, motors=motors)
    bus.connect()

    # Everything after connect() runs under try/finally: torque is enabled
    # well before the control loop starts, so a failure while opening cameras
    # or loading the policy must still de-energise the arm.
    cameras = []
    try:
        # Load calibration
        cal_path = Path.home() / ".cache/huggingface/lerobot/calibration/robots/so_follower/my_so101.json"
        with cal_path.open(encoding="utf-8") as cal_file:
            cal = json.load(cal_file)
        cal_dict = {name: MotorCalibration(**vals) for name, vals in cal.items()}
        bus.write_calibration(cal_dict)
        log.warning("Bus connected with calibration")

        # Configure motors the same way LeRobot does in so_follower.configure()
        # This uses torque_disabled() context which disables torque, configures, re-enables
        with bus.torque_disabled():
            bus.configure_motors()
            for motor in bus.motors:
                bus.write("Operating_Mode", motor, 0)  # Position mode
                bus.write("P_Coefficient", motor, 16)
                bus.write("I_Coefficient", motor, 0)
                bus.write("D_Coefficient", motor, 32)
                bus.write("Goal_Velocity", motor, 600)  # Slow velocity limit
                bus.write("Acceleration", motor, 50)  # Gentle acceleration
                if motor == "gripper":
                    bus.write("Max_Torque_Limit", motor, 500)
                    bus.write("Protection_Current", motor, 250)
                    bus.write("Overload_Torque", motor, 25)
        # torque_disabled() re-enables torque on exit
        # Velocity and acceleration limits prevent snapping
        log.warning("Motors configured and torque enabled (velocity/accel limited)")

        # --- Open cameras ---
        cap_front = _open_camera(args.cam_front)
        cameras.append(cap_front)
        cap_wrist = _open_camera(args.cam_wrist)
        cameras.append(cap_wrist)
        log.warning("Cameras open")

        # --- Load policy + preprocessor + postprocessor ---
        from lerobot.policies.factory import make_pre_post_processors
        from lerobot.policies.utils import prepare_observation_for_inference, make_robot_action
        from lerobot.configs.policies import PreTrainedConfig
        from lerobot.processor.rename_processor import rename_stats
        from lerobot.policies.pi05.modeling_pi05 import PI05Policy

        log.warning("Loading Pi0.5...")
        policy_cfg = PreTrainedConfig.from_pretrained(args.checkpoint)
        policy_cfg.pretrained_path = Path(args.checkpoint)

        policy = PI05Policy.from_pretrained(args.checkpoint)
        policy = policy.to("cuda")
        policy.eval()
        policy.reset()

        # Build stats from checkpoint's saved preprocessor
        rename_map = {
            "observation.images.front": "observation.images.base_0_rgb",
            "observation.images.wrist": "observation.images.left_wrist_0_rgb",
        }
        preprocessor, postprocessor = make_pre_post_processors(
            policy_cfg=policy_cfg,
            pretrained_path=policy_cfg.pretrained_path,
            preprocessor_overrides={
                "device_processor": {"device": "cuda"},
                "rename_observations_processor": {"rename_map": rename_map},
            },
        )

        action_names = [f"{name}.pos" for name in MOTOR_NAMES]
        ds_features = {"action": {"names": action_names}}

        # --- Set up live camera display ---
        try:
            import rerun as rr
            rr.init("so101_inference", spawn=True)
            use_rerun = True
            log.warning("Rerun viewer launched — live camera feed")
        except ImportError:
            use_rerun = False
            log.warning("Rerun not available, no live view")

        log.warning(f"Running: '{args.task}' — Ctrl+C to stop")

        step = 0
        while args.max_steps == 0 or step < args.max_steps:
            t0 = time.perf_counter()

            # 1. Read motor positions (calibrated/normalized by bus)
            try:
                pos_dict = bus.sync_read("Present_Position", num_retry=5)
            except ConnectionError:
                _reset_serial_port(bus)
                continue

            # Build observation dict
            state_array = np.array([pos_dict[name] for name in MOTOR_NAMES], dtype=np.float32)

            # 2. Capture camera images
            ret_f, frame_front = cap_front.read()
            ret_w, frame_wrist = cap_wrist.read()
            if not ret_f or not ret_w:
                continue

            # OpenCV delivers BGR; convert to RGB before the frames reach the
            # policy and the viewer.
            # NOTE(review): assumes the training data was recorded through
            # LeRobot's camera wrapper, which yields RGB by default — confirm.
            frame_front = cv2.cvtColor(frame_front, cv2.COLOR_BGR2RGB)
            frame_wrist = cv2.cvtColor(frame_wrist, cv2.COLOR_BGR2RGB)

            # Live display
            if use_rerun:
                rr.set_time_sequence("step", step)
                rr.log("camera/front", rr.Image(frame_front))
                rr.log("camera/wrist", rr.Image(frame_wrist))
                rr.log("state", rr.BarChart([pos_dict[n] for n in MOTOR_NAMES]))

            observation = {
                "observation.images.front": frame_front,
                "observation.images.wrist": frame_wrist,
                "observation.state": state_array,
            }

            # 3. Inference
            with torch.inference_mode():
                obs = prepare_observation_for_inference(
                    observation, torch.device("cuda"), args.task, "so101_follower"
                )
                obs = preprocessor(obs)
                action = policy.select_action(obs)
                action = postprocessor(action)

            # 4. Convert to motor commands
            robot_action = make_robot_action(action, ds_features)

            # 5. Send to motors (calibrated/normalized by bus)
            goal_pos = {name: robot_action[f"{name}.pos"] for name in MOTOR_NAMES}
            try:
                bus.sync_write("Goal_Position", goal_pos)
            except ConnectionError:
                # Drop this command and carry on; the next iteration sends a
                # fresh goal.
                _reset_serial_port(bus)

            dt = time.perf_counter() - t0
            step += 1

            if step % 10 == 0:
                pos_str = " ".join(f"{pos_dict[n]:>7.1f}" for n in MOTOR_NAMES)
                act_str = " ".join(f"{robot_action[f'{n}.pos']:>7.1f}" for n in MOTOR_NAMES)
                log.warning(f"step {step:>4} | state=[{pos_str}] | action=[{act_str}] | {dt*1000:.0f}ms")

    except KeyboardInterrupt:
        log.warning("Stopped by user")
    finally:
        _shutdown(bus, cameras)
        log.warning("Done")


if __name__ == "__main__":
    main()