| import logging |
|
|
| import numpy as np |
| import torch |
| from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy |
| from lerobot.common.utils.utils import init_logging |
| from torchvision import transforms |
|
|
| from .base_inference import BaseInferenceEngine |
|
|
# Module-scoped logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
class Pi0FastInferenceEngine(BaseInferenceEngine):
    """
    Pi0Fast (Physical Intelligence Fast) inference engine.

    Handles image preprocessing, joint normalization, and fast action prediction
    for Pi0Fast models with language instruction support.
    """

    def __init__(
        self,
        policy_path: str,
        camera_names: list[str],
        device: str | None = None,
        language_instruction: str | None = None,
    ):
        """
        Create the engine; the policy weights are loaded later by ``load_policy``.

        Args:
            policy_path: Path/identifier passed to ``PI0FASTPolicy.from_pretrained``.
            camera_names: Cameras whose images are fed to the policy.
            device: Forwarded unchanged to ``BaseInferenceEngine.__init__``.
            language_instruction: Default task text, used by ``predict`` when no
                ``task`` keyword (or ``task=None``) is supplied.
        """
        super().__init__(policy_path, camera_names, device)

        self.language_instruction = language_instruction
        self.supports_language = True

    async def load_policy(self):
        """
        Load the Pi0Fast policy from ``self.policy_path`` and mark the engine ready.

        Moves the policy to ``self.device``, switches it to eval mode, and
        installs per-camera image transforms when the policy has a ``config``.
        """
        logger.info(f"Loading Pi0Fast policy from: {self.policy_path}")

        # NOTE(review): lerobot's init_logging() reconfigures process-wide logging;
        # confirm this side effect is wanted on every load.
        init_logging()

        # NOTE(review): from_pretrained is blocking and runs on the event loop
        # thread here; consider asyncio.to_thread if load latency matters.
        self.policy = PI0FASTPolicy.from_pretrained(self.policy_path)
        self.policy.to(self.device)
        self.policy.eval()

        if hasattr(self.policy, "config"):
            self._setup_image_transforms()

        self.is_loaded = True
        logger.info(f"Pi0Fast policy loaded successfully on {self.device}")

    def _setup_image_transforms(self):
        """
        Register an image transform for every camera in ``self.image_transforms``.

        If the policy exposes an ``image_processor`` attribute, that object is used
        for every camera. Otherwise a resize -> ToTensor -> ImageNet mean/std
        normalization pipeline is used, with the size taken from
        ``config.image_size`` (default 224).
        """
        config = self.policy.config
        image_size = getattr(config, "image_size", 224)

        # The choice does not depend on the camera, so decide (and build) once.
        # Both candidates are stateless, so sharing one object across cameras is safe.
        if hasattr(self.policy, "image_processor"):
            transform = self.policy.image_processor
        else:
            # NOTE(review): confirm that ImageNet statistics match what the Pi0Fast
            # policy expects; it may apply its own resizing/normalization.
            transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ])

        for camera_name in self.camera_names:
            self.image_transforms[camera_name] = transform

    async def predict(
        self, images: dict[str, np.ndarray], joint_positions: np.ndarray, **kwargs
    ) -> np.ndarray:
        """
        Run Pi0Fast inference to predict actions.

        Args:
            images: Dictionary of {camera_name: rgb_image_array}
            joint_positions: Current joint positions in LeRobot standard order
            **kwargs: ``task`` -- optional language instruction. When present and
                not ``None`` it overrides ``self.language_instruction``.

        Returns:
            The policy output as a NumPy array when the policy returns a tensor;
            any other return value is passed through unchanged.

        Raises:
            RuntimeError: If ``load_policy()`` has not been called yet.
        """
        if not self.is_loaded:
            msg = "Policy not loaded. Call load_policy() first."
            raise RuntimeError(msg)

        processed_images = self.preprocess_images(images)
        processed_joints = self.preprocess_joint_positions(joint_positions)

        # An explicit task=None means "no override", not "drop the instruction".
        task = kwargs.get("task")
        if task is None:
            task = self.language_instruction

        batch = self._prepare_batch(processed_images, processed_joints, task)

        with torch.no_grad():
            action = self._run_policy(batch)

        if isinstance(action, torch.Tensor):
            # NumPy has no bfloat16 dtype, so .numpy() would raise; upcast first.
            if action.dtype == torch.bfloat16:
                action = action.float()
            action = action.cpu().numpy()

        return action

    def _run_policy(self, batch: dict):
        """Invoke the policy's inference entry point on ``batch``."""
        # LeRobot policies (PreTrainedPolicy / PI0FASTPolicy) expose select_action
        # rather than predict; predict is kept only as a fallback for policy
        # objects that provide that name instead.
        select_action = getattr(self.policy, "select_action", None)
        if callable(select_action):
            return select_action(batch)
        return self.policy.predict(batch)

    def _prepare_batch(
        self,
        images: dict[str, torch.Tensor],
        joints: torch.Tensor,
        task: str | None = None,
    ) -> dict:
        """
        Prepare batch inputs for Pi0Fast model.

        Args:
            images: Preprocessed images, keyed by camera name
            joints: Preprocessed joint positions
            task: Language instruction; omitted from the batch when falsy

        Returns:
            Dict with ``observation.images.<camera>`` entries, ``observation.state``,
            and, when a task is given, ``task`` as a list of strings.
        """
        batch = {}

        for camera_name, image_tensor in images.items():
            # Give single (3-D) images a leading batch axis of size 1.
            if image_tensor.ndim == 3:
                image_tensor = image_tensor.unsqueeze(0)
            batch[f"observation.images.{camera_name}"] = image_tensor

        # Likewise, a 1-D joint vector becomes a batch of one.
        if joints.ndim == 1:
            joints = joints.unsqueeze(0)
        batch["observation.state"] = joints

        if task:
            # Pi0Fast iterates batch["task"] alongside the batch rows, so a bare
            # string would be consumed character by character; wrap it in a list.
            batch["task"] = [task] if isinstance(task, str) else task

        return batch
|
|