| |
| """ |
| Run Pi0.5 inference on SO-101. |
| |
| Uses LeRobot's FeetechMotorsBus with calibration for correct normalization, |
| but bypasses lerobot_record's problematic control loop. |
| |
| Usage: |
| python infer_so101.py --task "pick up the blue football" |
| """ |
| import argparse |
| import json |
| import logging |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import cv2 |
| import numpy as np |
| import scservo_sdk as scs |
| import torch |
|
|
# Make this script's directory and a local LeRobot source checkout importable.
# Both inserts go to the front of sys.path, so after the second one the
# LeRobot checkout is searched before the script directory.
sys.path.insert(0, str(Path(__file__).parent))
sys.path.insert(0, str(Path.home() / "lerobot" / "src"))

# Timestamped output on the root logger. The script reports its progress at
# WARNING level, which is exactly the threshold configured here.
logging.basicConfig(
    level=logging.WARNING,
    format='%(asctime)s %(message)s',
    datefmt='%H:%M:%S',
)
log = logging.getLogger()

# Joint names in the order used for state/action vectors ...
MOTOR_NAMES = [
    "shoulder_pan",
    "shoulder_lift",
    "elbow_flex",
    "wrist_flex",
    "wrist_roll",
    "gripper",
]
# ... and their bus IDs, numbered 1..6 in the same order.
MOTOR_IDS = list(range(1, len(MOTOR_NAMES) + 1))
|
|
|
|
# Abort the control loop after this many back-to-back failed iterations
# (bus read error or missing camera frame) instead of spinning forever on
# disconnected hardware.
_MAX_CONSECUTIVE_FAILURES = 100


def _parse_args():
    """Parse the command-line options for one inference run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, required=True)
    parser.add_argument("--checkpoint", type=str,
                        default="/mnt/hdd/pi05-training/full_run/checkpoints/015000/pretrained_model")
    parser.add_argument("--port", type=str, default="/dev/ttyACM0")
    parser.add_argument("--cam-front", type=int, default=2)
    parser.add_argument("--cam-wrist", type=int, default=0)
    parser.add_argument("--max-steps", type=int, default=0, help="0 = run until Ctrl+C")
    return parser.parse_args()


def _make_bus(port):
    """Build the (not yet connected) Feetech bus for the six SO-101 joints."""
    from lerobot.motors.feetech.feetech import FeetechMotorsBus
    from lerobot.motors import Motor, MotorNormMode

    return FeetechMotorsBus(
        port=port,
        motors={
            'shoulder_pan': Motor(1, 'sts3215', MotorNormMode.RANGE_M100_100),
            'shoulder_lift': Motor(2, 'sts3215', MotorNormMode.RANGE_M100_100),
            'elbow_flex': Motor(3, 'sts3215', MotorNormMode.RANGE_M100_100),
            'wrist_flex': Motor(4, 'sts3215', MotorNormMode.RANGE_M100_100),
            'wrist_roll': Motor(5, 'sts3215', MotorNormMode.RANGE_M100_100),
            'gripper': Motor(6, 'sts3215', MotorNormMode.RANGE_0_100),
        },
    )


def _apply_calibration(bus):
    """Load the saved follower calibration JSON and write it to the bus."""
    from lerobot.motors import MotorCalibration

    cal_path = Path.home() / ".cache/huggingface/lerobot/calibration/robots/so_follower/my_so101.json"
    with open(cal_path) as f:  # closed as soon as the JSON is parsed
        cal = json.load(f)
    cal_dict = {name: MotorCalibration(**vals) for name, vals in cal.items()}
    bus.write_calibration(cal_dict)
    log.warning("Bus connected with calibration")


def _configure_motors(bus):
    """Write mode, PID gains, motion limits and gripper caps to every motor.

    The writes happen inside ``bus.torque_disabled()``. The log line below
    reports torque as enabled afterwards (presumably the context manager
    re-enables it on exit).
    """
    with bus.torque_disabled():
        bus.configure_motors()
        for motor in bus.motors:
            # presumably 0 = position mode, since the loop drives Goal_Position
            bus.write("Operating_Mode", motor, 0)
            bus.write("P_Coefficient", motor, 16)
            bus.write("I_Coefficient", motor, 0)
            bus.write("D_Coefficient", motor, 32)
            # Cap speed/acceleration so policy jumps are executed gently.
            bus.write("Goal_Velocity", motor, 600)
            bus.write("Acceleration", motor, 50)
            if motor == "gripper":
                # Limit gripper torque/current (presumably to avoid overload
                # when squeezing an object — confirm values against the servo docs).
                bus.write("Max_Torque_Limit", motor, 500)
                bus.write("Protection_Current", motor, 250)
                bus.write("Overload_Torque", motor, 25)
    log.warning("Motors configured and torque enabled (velocity/accel limited)")


def _open_camera(index):
    """Open an OpenCV camera at 640x480, failing fast if it cannot be opened.

    Raises:
        RuntimeError: if ``cv2.VideoCapture`` could not open the device.
    """
    cap = cv2.VideoCapture(index)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open camera {index}")
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    return cap


def _read_rgb(cap):
    """Grab one frame as an RGB array, or return None if the read failed.

    OpenCV delivers BGR. Both the policy input and the rerun viewer are fed
    RGB. Presumably the training frames were RGB too (LeRobot's camera
    wrapper defaults to RGB) — confirm against the dataset's color mode.
    """
    ok, frame = cap.read()
    if not ok:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def _load_policy(checkpoint):
    """Load Pi0.5 on CUDA plus its pre/post-processors for this robot's cameras.

    Returns:
        (policy, preprocessor, postprocessor)
    """
    from lerobot.policies.factory import make_pre_post_processors
    from lerobot.configs.policies import PreTrainedConfig
    from lerobot.policies.pi05.modeling_pi05 import PI05Policy

    log.warning("Loading Pi0.5...")
    policy_cfg = PreTrainedConfig.from_pretrained(checkpoint)
    policy_cfg.pretrained_path = Path(checkpoint)

    policy = PI05Policy.from_pretrained(checkpoint)
    policy = policy.to("cuda")
    policy.eval()
    policy.reset()

    # Map this script's camera keys onto the Pi0.5 image keys (presumably the
    # names the checkpoint was trained with).
    rename_map = {
        "observation.images.front": "observation.images.base_0_rgb",
        "observation.images.wrist": "observation.images.left_wrist_0_rgb",
    }
    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=policy_cfg,
        pretrained_path=policy_cfg.pretrained_path,
        preprocessor_overrides={
            "device_processor": {"device": "cuda"},
            "rename_observations_processor": {"rename_map": rename_map},
        },
    )
    return policy, preprocessor, postprocessor


def _init_rerun():
    """Spawn a rerun viewer if the package is installed; return the module or None."""
    try:
        import rerun as rr
    except ImportError:
        log.warning("Rerun not available, no live view")
        return None
    rr.init("so101_inference", spawn=True)
    log.warning("Rerun viewer launched — live camera feed")
    return rr


def _recover_port(bus):
    """Clear the port's busy flag and stale input after a failed transaction."""
    bus.port_handler.is_using = False
    bus.port_handler.ser.reset_input_buffer()


def _control_loop(bus, cap_front, cap_wrist, policy, preprocessor, postprocessor,
                  rr, task, max_steps):
    """Read state + images, run the policy, send goal positions; repeat.

    Runs until ``max_steps`` successful steps (0 = forever).

    Raises:
        RuntimeError: after ``_MAX_CONSECUTIVE_FAILURES`` iterations in a row
            without a complete observation (bus read error or missing frame).
    """
    from lerobot.policies.utils import prepare_observation_for_inference, make_robot_action

    action_names = [f"{name}.pos" for name in MOTOR_NAMES]
    ds_features = {"action": {"names": action_names}}
    device = torch.device("cuda")

    step = 0
    failures = 0
    while max_steps == 0 or step < max_steps:
        if failures >= _MAX_CONSECUTIVE_FAILURES:
            raise RuntimeError(
                f"{failures} consecutive failed reads from bus/cameras — aborting"
            )
        t0 = time.perf_counter()

        try:
            pos_dict = bus.sync_read("Present_Position", num_retry=5)
        except ConnectionError:
            _recover_port(bus)
            failures += 1
            continue

        state_array = np.array([pos_dict[name] for name in MOTOR_NAMES], dtype=np.float32)

        # Read both cameras before checking, so the two streams stay in step.
        frame_front = _read_rgb(cap_front)
        frame_wrist = _read_rgb(cap_wrist)
        if frame_front is None or frame_wrist is None:
            failures += 1
            continue
        failures = 0

        if rr is not None:
            rr.set_time_sequence("step", step)
            rr.log("camera/front", rr.Image(frame_front))
            rr.log("camera/wrist", rr.Image(frame_wrist))
            rr.log("state", rr.BarChart([pos_dict[n] for n in MOTOR_NAMES]))

        observation = {
            "observation.images.front": frame_front,
            "observation.images.wrist": frame_wrist,
            "observation.state": state_array,
        }

        with torch.inference_mode():
            obs = prepare_observation_for_inference(observation, device, task, "so101_follower")
            obs = preprocessor(obs)
            action = policy.select_action(obs)
            action = postprocessor(action)

        robot_action = make_robot_action(action, ds_features)

        goal_pos = {name: robot_action[f"{name}.pos"] for name in MOTOR_NAMES}
        try:
            bus.sync_write("Goal_Position", goal_pos)
        except ConnectionError:
            # Best-effort: drop this command and keep going with the next one.
            _recover_port(bus)

        dt = time.perf_counter() - t0
        step += 1

        if step % 10 == 0:
            pos_str = " ".join(f"{pos_dict[n]:>7.1f}" for n in MOTOR_NAMES)
            act_str = " ".join(f"{robot_action[f'{n}.pos']:>7.1f}" for n in MOTOR_NAMES)
            log.warning(f"step {step:>4} | state=[{pos_str}] | action=[{act_str}] | {dt*1000:.0f}ms")


def _shutdown(bus, caps):
    """Disable torque, disconnect the bus and release cameras, best-effort."""
    log.warning("Disabling torque...")
    try:
        bus.disable_torque()
    except Exception:
        # Fall back to raw 1-byte writes of 0 to register 40 on each motor ID
        # (presumably the Torque_Enable address — confirm against the STS3215
        # control table).
        for mid in MOTOR_IDS:
            try:
                bus.packet_handler.write1ByteTxRx(bus.port_handler, mid, 40, 0)
            except Exception:
                pass
    try:
        bus.disconnect()
    except Exception:
        # Don't let a failed disconnect prevent releasing the cameras.
        log.exception("Bus disconnect failed")
    for cap in caps:
        if cap is not None:
            cap.release()
    log.warning("Done")


def main():
    """Run Pi0.5 closed-loop on the SO-101 until Ctrl+C or ``--max-steps``.

    Any failure or Ctrl+C after the bus connects still runs the shutdown
    path: torque off, bus disconnected, cameras released.
    """
    args = _parse_args()

    bus = _make_bus(args.port)
    bus.connect()

    cap_front = cap_wrist = None
    try:
        _apply_calibration(bus)
        _configure_motors(bus)

        cap_front = _open_camera(args.cam_front)
        cap_wrist = _open_camera(args.cam_wrist)
        log.warning("Cameras open")

        policy, preprocessor, postprocessor = _load_policy(args.checkpoint)
        rr = _init_rerun()

        log.warning(f"Running: '{args.task}' — Ctrl+C to stop")
        _control_loop(bus, cap_front, cap_wrist, policy, preprocessor, postprocessor,
                      rr, args.task, args.max_steps)
    except KeyboardInterrupt:
        log.warning("Stopped by user")
    finally:
        _shutdown(bus, (cap_front, cap_wrist))
|
|
|
|
if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0
    # unless an exception escapes.
    sys.exit(main())
|
|