import base64
import json
import os
from io import BytesIO
from typing import Any, Dict, List, Optional

import numpy as np
from PIL import Image

from openpi.policies import policy_config
from openpi.training import config as train_config
|
|
|
|
class EndpointHandler:
    """HuggingFace Inference Endpoint handler for pi0 policy inference.

    Loads an openpi trained policy once at startup and serves action
    predictions: each request supplies base64-encoded camera images plus a
    robot state vector and receives a chunk of predicted actions.
    """

    # Canonical camera names the openpi pi0 policy expects.
    _REQUIRED_CAMERAS = ("cam_high", "cam_left_wrist", "cam_right_wrist")

    # Accepted request aliases for each camera stream -> canonical openpi
    # name. Unknown names are passed through unchanged. Hoisted to a class
    # constant so it is not rebuilt on every request.
    _CAMERA_MAPPING = {
        "camera0": "cam_high",
        "camera1": "cam_left_wrist",
        "camera2": "cam_right_wrist",
        "base_camera": "cam_high",
        "left_wrist": "cam_left_wrist",
        "right_wrist": "cam_right_wrist",
        "cam_high": "cam_high",
        "cam_left_wrist": "cam_left_wrist",
        "cam_right_wrist": "cam_right_wrist",
    }

    # Target (width, height) every input image is resized to.
    _IMAGE_SIZE = (224, 224)

    def __init__(self, path: str = ""):
        """
        Initialize the handler for pi0 model inference using openpi infrastructure.

        Args:
            path: Path to the model weights directory. The MODEL_PATH
                environment variable, when set and non-empty, takes
                precedence; if both are empty, "weights/pi0" is used.
        """
        # Env var overrides the constructor argument; empty values fall
        # back to the bundled default weights directory.
        model_path = os.environ.get("MODEL_PATH", path) or "weights/pi0"

        # The checkpoint directory ships a config.json describing the model.
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, "r") as f:
            model_config = json.load(f)

        model_type = model_config.get("type", "pi0")
        # Only the "pi0" training config is supported today. The original
        # if/else here had byte-identical branches; kept as an explicit
        # fallback so future model variants have an obvious hook.
        if model_type != "pi0":
            # NOTE(review): unknown model types silently fall back to the
            # pi0 config; consider logging or raising here instead.
            model_type = "pi0"
        self.train_config = train_config.get_config(model_type)

        # NOTE(review): GPU use is detected only via an explicitly set
        # CUDA_VISIBLE_DEVICES variable — confirm this matches deployment.
        device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
        self.policy = policy_config.create_trained_policy(
            self.train_config,
            model_path,
            pytorch_device=device,
        )

        # Action-chunk length returned when a request omits num_actions.
        self.default_num_steps = 50

    def _decode_base64_image(self, base64_str: str) -> np.ndarray:
        """
        Decode a base64 image string to an RGB numpy array.

        Args:
            base64_str: Base64 encoded image, optionally prefixed with a
                "data:image/...;base64," data-URL header.

        Returns:
            numpy array of shape (H, W, 3) with values in [0, 255].
        """
        # Strip a data-URL header ("data:image/png;base64,...") if present.
        if base64_str.startswith("data:image"):
            base64_str = base64_str.split(",", 1)[1]

        image_bytes = base64.b64decode(base64_str)
        # Force RGB so grayscale/RGBA inputs yield a consistent 3 channels.
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        return np.array(image)

    def _prepare_observation(
        self,
        images: Dict[str, str],
        state: List[float],
        prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Prepare an observation dictionary in the format expected by openpi.

        Args:
            images: Mapping of camera names (aliases accepted, see
                _CAMERA_MAPPING) to base64 encoded images.
            state: Robot state values.
            prompt: Optional text prompt; omitted from the observation when
                empty or None.

        Returns:
            Observation dict with "state", "images", and optionally "prompt".
        """
        processed_images: Dict[str, np.ndarray] = {}
        for input_name, image_b64 in images.items():
            openpi_name = self._CAMERA_MAPPING.get(input_name, input_name)
            image_array = self._decode_base64_image(image_b64)

            # Resize to the resolution the policy expects.
            if image_array.shape[:2] != self._IMAGE_SIZE:
                image_array = np.array(
                    Image.fromarray(image_array).resize(self._IMAGE_SIZE)
                )

            processed_images[openpi_name] = image_array.astype(np.uint8)

        # Pad missing required cameras with black frames so the policy
        # always receives a complete image set.
        for cam_name in self._REQUIRED_CAMERAS:
            if cam_name not in processed_images:
                processed_images[cam_name] = np.zeros((224, 224, 3), dtype=np.uint8)

        observation: Dict[str, Any] = {
            "state": np.array(state, dtype=np.float32),
            "images": processed_images,
        }
        if prompt:
            observation["prompt"] = prompt
        return observation

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Main inference function called by the HuggingFace endpoint.

        Args:
            data: Input data dictionary containing:
                - inputs: Dictionary with:
                    - images: Dict mapping camera names to base64 encoded images
                    - state: List of robot state values
                    - prompt: Optional text prompt
                    - num_actions: Optional, maximum number of actions to
                      return (default: 50); extra predicted steps are truncated
                    - noise: Optional, noise array for sampling

        Returns:
            Single-element list with either the prediction payload
            (success=True) or an error description (success=False).
        """
        try:
            inputs = data.get("inputs", {})

            images = inputs.get("images", {})
            state = inputs.get("state", [])
            prompt = inputs.get("prompt", "")
            num_actions = inputs.get("num_actions", self.default_num_steps)
            noise_input = inputs.get("noise", None)

            if not images:
                raise ValueError("No images provided")
            if not state:
                raise ValueError("No state provided")

            observation = self._prepare_observation(images, state, prompt)

            noise = None
            if noise_input is not None:
                noise = np.array(noise_input, dtype=np.float32)

            result = self.policy.infer(observation, noise=noise)

            actions = result["actions"]
            if isinstance(actions, np.ndarray):
                actions_list = actions.tolist()
            else:
                actions_list = actions

            # Honor the requested chunk length. Previously num_actions was
            # read from the request but never applied; the model predicts a
            # fixed horizon, so we truncate to the requested number of steps.
            actions_list = actions_list[:num_actions]

            return [{
                "actions": actions_list,
                "num_actions": len(actions_list),
                "action_horizon": len(actions_list),
                "action_dim": len(actions_list[0]) if actions_list else 0,
                "success": True,
                "metadata": {
                    "model_type": self.train_config.model.model_type.value,
                    "policy_metadata": getattr(self.policy, '_metadata', {})
                }
            }]

        except Exception as e:
            # Endpoint boundary: return a structured error payload rather
            # than letting the exception surface as an opaque 500.
            return [{
                "error": str(e),
                "success": False
            }]