| import asyncio |
| import contextlib |
| import logging |
| import time |
| from collections import deque |
|
|
| import numpy as np |
| from transport_server_client import RoboticsConsumer, RoboticsProducer |
| from transport_server_client.video import VideoConsumer, VideoProducer |
|
|
| from inference_server.models import get_inference_engine |
| from inference_server.models.joint_config import JointConfig |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def busy_wait(seconds):
    """
    Precise timing function for consistent control loops.

    On some systems, asyncio.sleep is not accurate enough for
    control loops, so we use busy waiting for short delays.

    Args:
        seconds: Delay duration; values <= 0 return immediately.
    """
    if seconds > 0:
        # time.perf_counter() is the monotonic high-resolution clock and,
        # unlike asyncio.get_event_loop() (deprecated outside a running
        # loop since 3.12), is safe and cheap to call in a tight spin.
        end_time = time.perf_counter() + seconds
        while time.perf_counter() < end_time:
            pass
|
|
|
|
| class InferenceSession: |
| """ |
| A single inference session managing one model and its Transport Server connections. |
| |
| Handles joint values in NORMALIZED VALUES throughout the pipeline. |
| Supports multiple camera streams with different camera names. |
| Supports multiple policy types: ACT, Pi0, SmolVLA, Diffusion Policy. |
| """ |
|
|
| def __init__( |
| self, |
| session_id: str, |
| policy_path: str, |
| camera_names: list[str], |
| transport_server_url: str, |
| workspace_id: str, |
| camera_room_ids: dict[str, str], |
| joint_input_room_id: str, |
| joint_output_room_id: str, |
| policy_type: str = "act", |
| language_instruction: str | None = None, |
| ): |
| self.session_id = session_id |
| self.policy_path = policy_path |
| self.policy_type = policy_type.lower() |
| self.camera_names = camera_names |
| self.transport_server_url = transport_server_url |
| self.language_instruction = language_instruction |
|
|
| |
| self.workspace_id = workspace_id |
| self.camera_room_ids = camera_room_ids |
| self.joint_input_room_id = joint_input_room_id |
| self.joint_output_room_id = joint_output_room_id |
|
|
| |
| self.camera_consumers: dict[str, VideoConsumer] = {} |
| self.joint_input_consumer: RoboticsConsumer | None = None |
| self.joint_output_producer: RoboticsProducer | None = None |
|
|
| |
| self.inference_engine = None |
|
|
| |
| self.status = "initializing" |
| self.error_message: str | None = None |
| self.inference_task: asyncio.Task | None = None |
|
|
| |
| self.latest_images: dict[str, np.ndarray] = {} |
| self.latest_joint_positions: np.ndarray | None = None |
| |
| self.complete_joint_state: np.ndarray = np.zeros(6, dtype=np.float32) |
| self.images_updated: dict[str, bool] = dict.fromkeys(camera_names, False) |
| self.joints_updated = False |
|
|
| |
| self.action_queue: deque = deque(maxlen=100) |
| self.n_action_steps = 10 |
|
|
| |
| self.last_queue_cleanup = time.time() |
| self.queue_cleanup_interval = 10.0 |
|
|
| |
| self.control_frequency_hz = 20 |
| self.inference_frequency_hz = 2 |
|
|
| |
| self.stats = { |
| "inference_count": 0, |
| "images_received": dict.fromkeys(camera_names, 0), |
| "joints_received": 0, |
| "commands_sent": 0, |
| "errors": 0, |
| "actions_in_queue": 0, |
| "policy_type": self.policy_type, |
| } |
|
|
| |
| self.last_command_values: np.ndarray | None = None |
| self.last_joint_check_time = time.time() |
|
|
| |
| self.last_activity_time = time.time() |
| self.timeout_seconds = 600 |
| self.timeout_check_task: asyncio.Task | None = None |
|
|
| async def initialize(self): |
| """Initialize the session by loading the model and setting up Transport Server connections.""" |
| logger.info( |
| f"Initializing session {self.session_id} with policy type: {self.policy_type}, " |
| f"cameras: {self.camera_names}" |
| ) |
|
|
| |
| engine_kwargs = { |
| "policy_path": self.policy_path, |
| "camera_names": self.camera_names, |
| } |
|
|
| |
| if ( |
| self.policy_type in {"pi0", "pi0fast", "smolvla"} |
| and self.language_instruction |
| ): |
| engine_kwargs["language_instruction"] = self.language_instruction |
|
|
| self.inference_engine = get_inference_engine(self.policy_type, **engine_kwargs) |
|
|
| |
| await self.inference_engine.load_policy() |
|
|
| |
| for camera_name in self.camera_names: |
| self.camera_consumers[camera_name] = VideoConsumer( |
| self.transport_server_url |
| ) |
|
|
| self.joint_input_consumer = RoboticsConsumer(self.transport_server_url) |
| self.joint_output_producer = RoboticsProducer(self.transport_server_url) |
|
|
| |
| self._setup_callbacks() |
|
|
| |
| await self._connect_to_rooms() |
|
|
| |
| for camera_name, consumer in self.camera_consumers.items(): |
| await consumer.start_receiving() |
| logger.info(f"Started receiving frames for camera: {camera_name}") |
|
|
| |
| self.timeout_check_task = asyncio.create_task(self._timeout_monitor()) |
|
|
| self.status = "ready" |
| logger.info( |
| f"✅ Session {self.session_id} initialized successfully with {self.policy_type} policy " |
| f"and {len(self.camera_names)} cameras" |
| ) |
|
|
| def _setup_callbacks(self): |
| """Set up callbacks for Transport clients.""" |
|
|
| def create_frame_callback(camera_name: str): |
| """Create a frame callback for a specific camera.""" |
|
|
| def on_frame_received(frame_data): |
| """Handle incoming camera frame from VideoConsumer.""" |
| metadata = frame_data.metadata |
| width = metadata.get("width", 0) |
| height = metadata.get("height", 0) |
| format_type = metadata.get("format", "rgb24") |
|
|
| if format_type == "rgb24" and width > 0 and height > 0: |
| |
| frame_bytes = frame_data.data |
|
|
| |
| expected_size = height * width * 3 |
| if len(frame_bytes) != expected_size: |
| logger.warning( |
| f"Frame size mismatch for camera {camera_name}: " |
| f"expected {expected_size}, got {len(frame_bytes)}. Skipping frame." |
| ) |
| self.stats["errors"] += 1 |
| return |
|
|
| img_rgb = np.frombuffer(frame_bytes, dtype=np.uint8).reshape(( |
| height, |
| width, |
| 3, |
| )) |
|
|
| |
| self.latest_images[camera_name] = img_rgb |
| self.images_updated[camera_name] = True |
| self.stats["images_received"][camera_name] += 1 |
| |
| self.last_activity_time = time.time() |
|
|
| return on_frame_received |
|
|
| |
| for camera_name in self.camera_names: |
| callback = create_frame_callback(camera_name) |
| self.camera_consumers[camera_name].on_frame_update(callback) |
|
|
| def on_joints_received(joints_data): |
| """Handle incoming joint data from RoboticsConsumer.""" |
| joint_values = self._parse_joint_data(joints_data) |
| if joint_values: |
| |
| for i, value in enumerate(joint_values[:6]): |
| self.complete_joint_state[i] = value |
|
|
| self.latest_joint_positions = self.complete_joint_state.copy() |
| self.joints_updated = True |
| self.stats["joints_received"] += 1 |
| |
| self.last_activity_time = time.time() |
|
|
| def on_error(error_msg): |
| """Handle Transport client errors.""" |
| logger.error( |
| f"Transport client error in session {self.session_id}: {error_msg}" |
| ) |
| self.error_message = str(error_msg) |
| self.stats["errors"] += 1 |
|
|
| |
| self.joint_input_consumer.on_joint_update(on_joints_received) |
| self.joint_input_consumer.on_state_sync(on_joints_received) |
| self.joint_input_consumer.on_error(on_error) |
|
|
| |
| for consumer in self.camera_consumers.values(): |
| consumer.on_error(on_error) |
|
|
| def _parse_joint_data(self, joints_data) -> list[float]: |
| """ |
| Parse joint data from Transport Server message. |
| |
| Args: |
| joints_data: Joint data from Transport Server message |
| |
| Returns: |
| List of 6 normalized joint values in standard order |
| |
| """ |
| return JointConfig.parse_joint_data(joints_data, self.policy_type) |
|
|
| async def _connect_to_rooms(self): |
| """Connect to all Transport Server rooms.""" |
| |
| for camera_name, consumer in self.camera_consumers.items(): |
| room_id = self.camera_room_ids[camera_name] |
| success = await consumer.connect( |
| self.workspace_id, room_id, f"{self.session_id}-{camera_name}-consumer" |
| ) |
| if not success: |
| msg = f"Failed to connect to camera room for {camera_name}" |
| logger.error(msg) |
| logger.info( |
| f"Connected to camera room for {camera_name}: {room_id} in workspace {self.workspace_id}" |
| ) |
|
|
| |
| success = await self.joint_input_consumer.connect( |
| self.workspace_id, |
| self.joint_input_room_id, |
| f"{self.session_id}-joint-input-consumer", |
| ) |
| if not success: |
| msg = "Failed to connect to joint input room" |
| logger.error(msg) |
|
|
| |
| success = await self.joint_output_producer.connect( |
| self.workspace_id, |
| self.joint_output_room_id, |
| f"{self.session_id}-joint-output-producer", |
| ) |
| if not success: |
| msg = "Failed to connect to joint output room" |
| logger.error(msg) |
|
|
| logger.info( |
| f"Connected to all rooms for session {self.session_id} in workspace {self.workspace_id}" |
| ) |
|
|
| async def start_inference(self): |
| """Start the inference loop.""" |
| if self.status != "ready": |
| msg = f"Session not ready. Current status: {self.status}" |
| logger.error(msg) |
|
|
| self.status = "running" |
| self.inference_task = asyncio.create_task(self._inference_loop()) |
| logger.info(f"Started inference for session {self.session_id}") |
|
|
| async def stop_inference(self): |
| """Stop the inference loop.""" |
| if self.inference_task: |
| self.inference_task.cancel() |
| with contextlib.suppress(asyncio.CancelledError): |
| await self.inference_task |
| self.inference_task = None |
|
|
| self.status = "stopped" |
| logger.info(f"Stopped inference for session {self.session_id}") |
|
|
| async def restart_inference(self): |
| """Restart the inference loop (stop if running, then start).""" |
| logger.info(f"Restarting inference for session {self.session_id}") |
|
|
| |
| await self.stop_inference() |
|
|
| |
| self._reset_session_state() |
|
|
| |
| await self.start_inference() |
|
|
| logger.info(f"Successfully restarted inference for session {self.session_id}") |
|
|
| def _reset_session_state(self): |
| """Reset session state for restart.""" |
| |
| self.action_queue.clear() |
|
|
| |
| self.complete_joint_state.fill(0.0) |
|
|
| |
| for camera_name in self.camera_names: |
| self.images_updated[camera_name] = False |
| self.joints_updated = False |
|
|
| |
| self.last_queue_cleanup = time.time() |
|
|
| |
| if self.inference_engine: |
| self.inference_engine.reset() |
|
|
| |
| self.stats["actions_in_queue"] = 0 |
|
|
| logger.info(f"Reset session state for {self.session_id}") |
|
|
| async def _timeout_monitor(self): |
| """Monitor session for inactivity timeout.""" |
| while True: |
| try: |
| await asyncio.sleep(60) |
|
|
| current_time = time.time() |
| inactive_time = current_time - self.last_activity_time |
|
|
| if inactive_time > self.timeout_seconds: |
| logger.warning( |
| f"Session {self.session_id} has been inactive for " |
| f"{inactive_time:.1f} seconds (timeout: {self.timeout_seconds}s). " |
| f"Marking for cleanup." |
| ) |
| self.status = "timeout" |
| |
| break |
| if inactive_time > self.timeout_seconds * 0.8: |
| logger.info( |
| f"Session {self.session_id} inactive for {inactive_time:.1f}s, " |
| f"will timeout in {self.timeout_seconds - inactive_time:.1f}s" |
| ) |
|
|
| except asyncio.CancelledError: |
| logger.info(f"Timeout monitor cancelled for session {self.session_id}") |
| break |
| except Exception: |
| logger.exception( |
| f"Error in timeout monitor for session {self.session_id}" |
| ) |
| break |
|
|
| def _all_cameras_have_data(self) -> bool: |
| """Check if we have received data from all cameras.""" |
| return all( |
| camera_name in self.latest_images for camera_name in self.camera_names |
| ) |
|
|
| async def _inference_loop(self): |
| """Main inference loop that processes incoming data and sends commands.""" |
| logger.info(f"Starting inference loop for session {self.session_id}") |
| logger.info( |
| f"Control frequency: {self.control_frequency_hz} Hz, Inference frequency: {self.inference_frequency_hz} Hz" |
| ) |
| logger.info( |
| f"Waiting for data from {len(self.camera_names)} cameras: {self.camera_names}" |
| ) |
|
|
| inference_counter = 0 |
| target_dt = 1.0 / self.control_frequency_hz |
| inference_interval = ( |
| self.control_frequency_hz // self.inference_frequency_hz |
| ) |
|
|
| while True: |
| loop_start_time = asyncio.get_event_loop().time() |
|
|
| |
| if ( |
| self._all_cameras_have_data() |
| and self.latest_joint_positions is not None |
| ): |
| |
| |
| should_run_inference = len(self.action_queue) == 0 or ( |
| inference_counter % inference_interval == 0 |
| and len(self.action_queue) < 3 |
| ) |
|
|
| if should_run_inference: |
| |
| if self.stats["inference_count"] % 10 == 0: |
| logger.info( |
| f"Running inference #{self.stats['inference_count']} for session {self.session_id} " |
| f"(queue length: {len(self.action_queue)})" |
| ) |
|
|
| |
| if self.latest_joint_positions.shape != (6,): |
| logger.error( |
| f"Invalid joint positions shape: {self.latest_joint_positions.shape}, " |
| f"expected (6,). Values: {self.latest_joint_positions}" |
| ) |
| |
| self.latest_joint_positions = self.complete_joint_state.copy() |
|
|
| |
| inference_kwargs = { |
| "images": self.latest_images, |
| "joint_positions": self.latest_joint_positions, |
| } |
|
|
| |
| if ( |
| self.policy_type in {"pi0", "pi0fast", "smolvla"} |
| and self.language_instruction |
| ): |
| inference_kwargs["task"] = self.language_instruction |
|
|
| |
| predicted_actions = await self.inference_engine.predict( |
| **inference_kwargs |
| ) |
|
|
| |
| if len(predicted_actions.shape) == 1: |
| |
| actions_to_queue = [predicted_actions] |
| else: |
| |
| actions_to_queue = predicted_actions[: self.n_action_steps] |
|
|
| |
| for action in actions_to_queue: |
| joint_commands = ( |
| self.inference_engine.get_joint_commands_with_names(action) |
| ) |
| self.action_queue.append(joint_commands) |
|
|
| self.stats["inference_count"] += 1 |
| |
| for camera_name in self.camera_names: |
| self.images_updated[camera_name] = False |
| self.joints_updated = False |
|
|
| |
| if len(self.action_queue) > 0: |
| joint_commands = self.action_queue.popleft() |
| |
| if self.stats["commands_sent"] % 100 == 0: |
| logger.info( |
| f"🤖 Sent {self.stats['commands_sent']} commands. Latest: {joint_commands[0]['name']}={joint_commands[0]['value']:.1f}" |
| ) |
|
|
| await self.joint_output_producer.send_joint_update(joint_commands) |
| self.stats["commands_sent"] += 1 |
| self.stats["actions_in_queue"] = len(self.action_queue) |
|
|
| |
| command_values = np.array( |
| [cmd["value"] for cmd in joint_commands], dtype=np.float32 |
| ) |
| self.last_command_values = command_values |
|
|
| |
| current_time = asyncio.get_event_loop().time() |
| if current_time - self.last_queue_cleanup > self.queue_cleanup_interval: |
| |
| if len(self.action_queue) > 80: |
| self.action_queue.clear() |
| self.last_queue_cleanup = current_time |
|
|
| inference_counter += 1 |
|
|
| |
| elapsed_time = asyncio.get_event_loop().time() - loop_start_time |
| sleep_time = target_dt - elapsed_time |
|
|
| if sleep_time > 0.001: |
| await asyncio.sleep(sleep_time) |
| elif sleep_time > 0: |
| busy_wait(sleep_time) |
| elif sleep_time < -0.01: |
| logger.warning( |
| f"Control loop running slow for session {self.session_id}: {elapsed_time * 1000:.1f}ms (target: {target_dt * 1000:.1f}ms)" |
| ) |
|
|
| async def cleanup(self): |
| """Clean up session resources.""" |
| logger.info(f"Cleaning up session {self.session_id}") |
|
|
| |
| if self.timeout_check_task: |
| self.timeout_check_task.cancel() |
| with contextlib.suppress(asyncio.CancelledError): |
| await self.timeout_check_task |
| self.timeout_check_task = None |
|
|
| |
| await self.stop_inference() |
|
|
| |
| for camera_name, consumer in self.camera_consumers.items(): |
| await consumer.stop_receiving() |
| await consumer.disconnect() |
| logger.info(f"Disconnected camera consumer for {camera_name}") |
|
|
| if self.joint_input_consumer: |
| await self.joint_input_consumer.disconnect() |
| if self.joint_output_producer: |
| await self.joint_output_producer.disconnect() |
|
|
| |
| if self.inference_engine: |
| del self.inference_engine |
| self.inference_engine = None |
|
|
| logger.info(f"Session {self.session_id} cleanup completed") |
|
|
| def get_status(self) -> dict: |
| """Get current session status.""" |
| status_dict = { |
| "session_id": self.session_id, |
| "status": self.status, |
| "policy_path": self.policy_path, |
| "policy_type": self.policy_type, |
| "camera_names": self.camera_names, |
| "workspace_id": self.workspace_id, |
| "rooms": { |
| "workspace_id": self.workspace_id, |
| "camera_room_ids": self.camera_room_ids, |
| "joint_input_room_id": self.joint_input_room_id, |
| "joint_output_room_id": self.joint_output_room_id, |
| }, |
| "stats": self.stats.copy(), |
| "error_message": self.error_message, |
| "joint_state": { |
| "complete_joint_state": self.complete_joint_state.tolist(), |
| "latest_joint_positions": ( |
| self.latest_joint_positions.tolist() |
| if self.latest_joint_positions is not None |
| else None |
| ), |
| "joint_state_shape": ( |
| self.latest_joint_positions.shape |
| if self.latest_joint_positions is not None |
| else None |
| ), |
| }, |
| } |
|
|
| return status_dict |
|
|
|
|
| class SessionManager: |
| """Manages multiple inference sessions and their lifecycle.""" |
|
|
| def __init__(self): |
| self.sessions: dict[str, InferenceSession] = {} |
| self.cleanup_task: asyncio.Task | None = None |
| self._start_cleanup_task() |
|
|
| def _start_cleanup_task(self): |
| """Start the automatic cleanup task for timed-out sessions.""" |
| try: |
| |
| loop = asyncio.get_running_loop() |
| self.cleanup_task = loop.create_task(self._periodic_cleanup()) |
| except RuntimeError: |
| |
| pass |
|
|
| async def _periodic_cleanup(self): |
| """Periodically check for and clean up timed-out sessions.""" |
| while True: |
| try: |
| await asyncio.sleep(300) |
|
|
| |
| timed_out_sessions = [] |
| for session_id, session in self.sessions.items(): |
| if session.status == "timeout": |
| timed_out_sessions.append(session_id) |
|
|
| |
| for session_id in timed_out_sessions: |
| logger.info(f"Auto-cleaning up timed-out session: {session_id}") |
| await self.delete_session(session_id) |
|
|
| except asyncio.CancelledError: |
| logger.info("Periodic cleanup task cancelled") |
| break |
|
|
| async def create_session( |
| self, |
| session_id: str, |
| policy_path: str, |
| transport_server_url: str, |
| camera_names: list[str] | None = None, |
| workspace_id: str | None = None, |
| policy_type: str = "act", |
| language_instruction: str | None = None, |
| ) -> dict[str, str]: |
| """Create a new inference session.""" |
| if camera_names is None: |
| camera_names = ["front"] |
| if session_id in self.sessions: |
| msg = f"Session {session_id} already exists" |
| raise ValueError(msg) |
|
|
| |
| video_temp_client = VideoProducer(transport_server_url) |
| camera_room_ids = {} |
|
|
| |
| if workspace_id: |
| target_workspace_id = workspace_id |
| logger.info( |
| f"Using provided workspace ID {target_workspace_id} for session {session_id}" |
| ) |
|
|
| |
| for camera_name in camera_names: |
| _, room_id = await video_temp_client.create_room( |
| workspace_id=target_workspace_id, |
| room_id=f"{session_id}-{camera_name}", |
| ) |
| camera_room_ids[camera_name] = room_id |
| else: |
| |
| first_camera = camera_names[0] |
| target_workspace_id, first_room_id = await video_temp_client.create_room( |
| room_id=f"{session_id}-{first_camera}" |
| ) |
| logger.info( |
| f"Generated new workspace ID {target_workspace_id} for session {session_id}" |
| ) |
|
|
| |
| camera_room_ids[first_camera] = first_room_id |
|
|
| |
| for camera_name in camera_names[1:]: |
| _, room_id = await video_temp_client.create_room( |
| workspace_id=target_workspace_id, |
| room_id=f"{session_id}-{camera_name}", |
| ) |
| camera_room_ids[camera_name] = room_id |
|
|
| |
| robotics_temp_client = RoboticsProducer(transport_server_url) |
| _, joint_input_room_id = await robotics_temp_client.create_room( |
| workspace_id=target_workspace_id, room_id=f"{session_id}-joint-input" |
| ) |
| _, joint_output_room_id = await robotics_temp_client.create_room( |
| workspace_id=target_workspace_id, room_id=f"{session_id}-joint-output" |
| ) |
|
|
| logger.info( |
| f"Created rooms for session {session_id} in workspace {target_workspace_id}:" |
| ) |
| for camera_name, room_id in camera_room_ids.items(): |
| logger.info(f" Camera room ({camera_name}): {room_id}") |
| logger.info(f" Joint input room: {joint_input_room_id}") |
| logger.info(f" Joint output room: {joint_output_room_id}") |
|
|
| |
| session = InferenceSession( |
| session_id=session_id, |
| policy_path=policy_path, |
| camera_names=camera_names, |
| transport_server_url=transport_server_url, |
| workspace_id=target_workspace_id, |
| camera_room_ids=camera_room_ids, |
| joint_input_room_id=joint_input_room_id, |
| joint_output_room_id=joint_output_room_id, |
| policy_type=policy_type, |
| language_instruction=language_instruction, |
| ) |
|
|
| |
| await session.initialize() |
|
|
| |
| self.sessions[session_id] = session |
|
|
| |
| if not self.cleanup_task or self.cleanup_task.done(): |
| self._start_cleanup_task() |
|
|
| return { |
| "workspace_id": target_workspace_id, |
| "camera_room_ids": camera_room_ids, |
| "joint_input_room_id": joint_input_room_id, |
| "joint_output_room_id": joint_output_room_id, |
| } |
|
|
| async def start_inference(self, session_id: str): |
| """Start inference for a specific session.""" |
| if session_id not in self.sessions: |
| msg = f"Session {session_id} not found" |
| raise KeyError(msg) |
| await self.sessions[session_id].start_inference() |
|
|
| async def stop_inference(self, session_id: str): |
| """Stop inference for a specific session.""" |
| if session_id not in self.sessions: |
| msg = f"Session {session_id} not found" |
| raise KeyError(msg) |
| await self.sessions[session_id].stop_inference() |
|
|
| async def restart_inference(self, session_id: str): |
| """Restart inference for a specific session.""" |
| if session_id not in self.sessions: |
| msg = f"Session {session_id} not found" |
| raise KeyError(msg) |
| await self.sessions[session_id].restart_inference() |
|
|
| async def delete_session(self, session_id: str): |
| """Delete a session and clean up all resources.""" |
| if session_id not in self.sessions: |
| msg = f"Session {session_id} not found" |
| raise KeyError(msg) |
|
|
| session = self.sessions[session_id] |
| await session.cleanup() |
| del self.sessions[session_id] |
| logger.info(f"Deleted session {session_id}") |
|
|
| async def list_sessions(self) -> list[dict]: |
| """List all sessions with their status.""" |
| return [session.get_status() for session in self.sessions.values()] |
|
|
| async def cleanup_all_sessions(self): |
| """Clean up all sessions.""" |
| logger.info("Cleaning up all sessions...") |
|
|
| |
| if self.cleanup_task: |
| self.cleanup_task.cancel() |
| with contextlib.suppress(asyncio.CancelledError): |
| await self.cleanup_task |
| self.cleanup_task = None |
|
|
| |
| for session_id in list(self.sessions.keys()): |
| await self.delete_session(session_id) |
| logger.info("All sessions cleaned up") |
|
|