| |
|
|
| |
| |
|
|
"""
RobotHub Inference Server - Model inference engines for various policy types.

This module provides unified inference engines for different policy architectures,
including ACT, Pi0, Pi0Fast, SmolVLA, and Diffusion policies. Use
``get_inference_engine`` to construct an engine from its policy-type name.
"""
|
|
| import logging |
|
|
| from .act_inference import ACTInferenceEngine |
| from .base_inference import BaseInferenceEngine |
| from .diffusion_inference import DiffusionInferenceEngine |
| from .joint_config import JointConfig |
| from .pi0_inference import Pi0InferenceEngine |
| from .pi0fast_inference import Pi0FastInferenceEngine |
| from .smolvla_inference import SmolVLAInferenceEngine |
|
|
# Module-scoped logger, named after this module so log records can be
# filtered/configured per package.
logger = logging.getLogger(name=__name__)
|
|
| |
# Names exported by ``from <package> import *``: the concrete engine classes,
# the shared base class and joint configuration, and the factory function.
# Order is kept alphabetical (case-sensitive), with the factory last.
__all__ = [
    "ACTInferenceEngine", "BaseInferenceEngine", "DiffusionInferenceEngine",
    "JointConfig", "Pi0FastInferenceEngine", "Pi0InferenceEngine",
    "SmolVLAInferenceEngine",
    "get_inference_engine",
]
|
|
|
|
# Registry mapping each policy-type name accepted by get_inference_engine to
# the engine class that implements it. Insertion order is preserved and is the
# order reported in the "Available: [...]" error message.
POLICY_ENGINES = dict(
    act=ACTInferenceEngine,
    pi0=Pi0InferenceEngine,
    pi0fast=Pi0FastInferenceEngine,
    smolvla=SmolVLAInferenceEngine,
    diffusion=DiffusionInferenceEngine,
)
|
|
|
|
def get_inference_engine(policy_type: str, **kwargs) -> BaseInferenceEngine:
    """
    Get an inference engine instance for the specified policy type.

    The engine class is looked up in ``POLICY_ENGINES`` and instantiated with
    ``**kwargs`` unchanged; the keyword arguments are not inspected here, so any
    error they cause is raised by the engine's constructor.

    Args:
        policy_type: Type of policy ('act', 'pi0', 'pi0fast', 'smolvla', 'diffusion').
            Matching is exact (case-sensitive).
        **kwargs: Additional arguments passed to the engine constructor

    Returns:
        BaseInferenceEngine: Configured inference engine instance

    Raises:
        ValueError: If policy_type is not supported or not available

    """
    # One dictionary lookup instead of a membership test followed by indexing.
    engine_class = POLICY_ENGINES.get(policy_type)
    if engine_class is None:
        available = list(POLICY_ENGINES)
        if not available:
            # Defensive: the module-level registry is populated unconditionally
            # at import, so it is only empty if entries were removed at runtime.
            msg = "No policy engines are available. Check your LeRobot installation."
        else:
            msg = f"Unsupported policy type: {policy_type}. Available: {available}"
        raise ValueError(msg)

    return engine_class(**kwargs)
|
|