| import asyncio |
| import logging |
| from abc import ABC, abstractmethod |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
|
|
| from .joint_config import JointConfig |
|
|
# Module-level logger named after this module; used by the engine classes below.
| logger = logging.getLogger(__name__) |
|
|
|
|
class BaseInferenceEngine(ABC):
    """
    Base class for all inference engines.

    This class provides common functionality for:
    - Image preprocessing and normalization
    - Joint data handling and validation
    - Model loading and management
    - Action prediction interface

    Subclasses must implement ``load_policy`` and ``predict``.
    """

    def __init__(
        self,
        policy_path: str,
        camera_names: list[str],
        device: str | None = None,
    ):
        """
        Store the engine configuration and select the torch device.

        Args:
            policy_path: Identifier/location of the policy. This base class
                only stores it and includes it in log messages.
            camera_names: Camera names accepted by ``preprocess_images``.
            device: Torch device string. When None, "cuda" is chosen if
                ``torch.cuda.is_available()`` is true, otherwise "cpu".

        """
        self.policy_path = policy_path
        self.camera_names = camera_names

        # Resolve the compute device, preferring CUDA when none is requested.
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        logger.info("Using device: %s", self.device)

        # Model-related state; expected to be populated by subclasses.
        self.policy = None
        self.image_transforms = {}  # optional per-camera transform callables
        self.stats = None

        # Runtime state.
        self.is_loaded = False
        self.last_images = {}
        self.last_joint_positions = None

    @abstractmethod
    async def load_policy(self):
        """Load the policy model. Must be implemented by subclasses."""

    @abstractmethod
    async def predict(
        self, images: dict[str, np.ndarray], joint_positions: np.ndarray, **kwargs
    ) -> np.ndarray:
        """Run inference. Must be implemented by subclasses."""

    @staticmethod
    def _to_uint8(image: np.ndarray) -> np.ndarray:
        """
        Convert an image array to uint8 without wrap-around.

        - uint8 arrays are returned unchanged.
        - Floating arrays are treated as [0, 1] intensities: clipped to that
          range, scaled by 255 and truncated.
        - Boolean arrays map False/True to 0/255.
        - Other integer arrays are treated as 0..255 intensities and clipped.
        """
        if image.dtype == np.uint8:
            return image
        if np.issubdtype(image.dtype, np.floating):
            return (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
        if image.dtype == np.bool_:
            return image.astype(np.uint8) * 255
        # Scaling integer data by 255 would overflow uint8, so clip instead.
        return np.clip(image, 0, 255).astype(np.uint8)

    def preprocess_images(
        self, images: dict[str, np.ndarray]
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess images for inference.

        Cameras not listed in ``self.camera_names`` are skipped with a
        warning. Each remaining image is run through its camera-specific
        transform from ``self.image_transforms`` when one is registered,
        otherwise through ``_default_image_transform``, and the result is
        moved to ``self.device``.

        Args:
            images: Dictionary of {camera_name: rgb_image_array}. Values that
                are not numpy arrays are passed to the transform unchanged
                (presumably PIL images - confirm against callers).

        Returns:
            Dictionary of {camera_name: preprocessed_tensor}

        """
        processed_images = {}

        for camera_name, image in images.items():
            if camera_name not in self.camera_names:
                logger.warning("Unexpected camera: %s", camera_name)
                continue

            # Numpy input is normalised to uint8 before building a PIL image.
            if isinstance(image, np.ndarray):
                pil_image = Image.fromarray(self._to_uint8(image))
            else:
                pil_image = image

            transform = self.image_transforms.get(
                camera_name, self._default_image_transform
            )
            processed_images[camera_name] = transform(pil_image).to(self.device)

        return processed_images

    def _default_image_transform(self, image: Image.Image) -> torch.Tensor:
        """
        Default image preprocessing.

        Resizes to 224x224 with Lanczos resampling, converts to a float
        tensor divided by 255, and for 3-D data moves the last axis to the
        front (channels-last to channels-first). 2-D (single-band) data is
        returned without a channel axis.
        """
        resized = image.resize((224, 224), Image.Resampling.LANCZOS)

        tensor = torch.from_numpy(np.array(resized)).float() / 255.0

        if tensor.ndim == 3:
            tensor = tensor.permute(2, 0, 1)

        return tensor

    def preprocess_joint_positions(self, joint_positions: np.ndarray) -> torch.Tensor:
        """
        Preprocess joint positions for inference.

        Args:
            joint_positions: Array of joint positions in standard order

        Returns:
            Float tensor on ``self.device``, passed through
            ``_normalize_joints`` when ``self.stats`` exposes ``joint_stats``.

        """
        # NOTE(review): JointConfig.validate_joint_values is defined outside
        # this file; it must return a numpy array for torch.from_numpy below.
        joint_positions = JointConfig.validate_joint_values(joint_positions)

        joint_tensor = torch.from_numpy(joint_positions).float().to(self.device)

        # Identity check: stats may be an object whose truth value is
        # ambiguous or empty-but-valid.
        if self.stats is not None and hasattr(self.stats, "joint_stats"):
            joint_tensor = self._normalize_joints(joint_tensor)

        return joint_tensor

    def _normalize_joints(self, joint_tensor: torch.Tensor) -> torch.Tensor:
        """
        Normalize joint values using dataset statistics.

        Currently a placeholder: the tensor is returned unchanged. Subclasses
        can override this to apply real normalization.
        """
        return joint_tensor

    def get_joint_commands_with_names(self, action: np.ndarray) -> list[dict]:
        """
        Convert action array to joint commands with names.

        Args:
            action: Array of joint actions in standard order

        Returns:
            List of joint command dictionaries, as built by
            ``JointConfig.create_joint_commands``.

        """
        action = JointConfig.validate_joint_values(action)
        return JointConfig.create_joint_commands(action)

    def reset(self):
        """Reset the inference engine state and, if supported, the policy."""
        self.last_images = {}
        self.last_joint_positions = None

        # Policies without a reset() method (including None) are left alone.
        if hasattr(self.policy, "reset"):
            self.policy.reset()

    def _release_policy(self) -> None:
        """Synchronously drop the policy and release cached CUDA memory."""
        self.policy = None

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.is_loaded = False

    async def cleanup(self):
        """Clean up resources."""
        self._release_policy()
        logger.info("Cleaned up inference engine for %s", self.policy_path)

    def __del__(self):
        """
        Destructor to ensure cleanup.

        Releases resources synchronously. A destructor can run with no
        running event loop, so scheduling the async ``cleanup`` from here
        would raise and leave an un-awaited coroutine.
        """
        # getattr guards against instances whose __init__ never completed.
        if getattr(self, "policy", None) is None:
            return
        self._release_policy()
|
|