| """Tool image cache module for storing and retrieving images returned by tools. |
| |
| This module lets the LLM review images before deciding whether to send them to users. |
| """ |
|
|
| import base64 |
| import os |
| import time |
| from dataclasses import dataclass, field |
| from typing import ClassVar |
|
|
| from astrbot import logger |
| from astrbot.core.utils.astrbot_path import get_astrbot_temp_path |
|
|
|
|
@dataclass
class CachedImage:
    """Record describing one image that a tool call produced and that was cached.

    Attributes:
        tool_call_id: ID of the tool call that produced the image.
        tool_name: Name of the tool that produced the image.
        file_path: Location on disk where the image is stored.
        mime_type: MIME type of the image.
        created_at: Timestamp of when the image was cached; filled in with
            ``time.time()`` at construction unless given explicitly.
    """

    tool_call_id: str
    tool_name: str
    file_path: str
    mime_type: str
    created_at: float = field(default_factory=time.time)
|
|
|
|
| class ToolImageCache: |
| """Manages cached images from tool calls. |
| |
| Images are stored in data/temp/tool_images/ and can be retrieved by file path. |
| """ |
|
|
| _instance: ClassVar["ToolImageCache | None"] = None |
| CACHE_DIR_NAME: ClassVar[str] = "tool_images" |
| |
| CACHE_EXPIRY: ClassVar[int] = 3600 |
|
|
| def __new__(cls) -> "ToolImageCache": |
| if cls._instance is None: |
| cls._instance = super().__new__(cls) |
| cls._instance._initialized = False |
| return cls._instance |
|
|
| def __init__(self) -> None: |
| if self._initialized: |
| return |
| self._initialized = True |
| self._cache_dir = os.path.join(get_astrbot_temp_path(), self.CACHE_DIR_NAME) |
| os.makedirs(self._cache_dir, exist_ok=True) |
| logger.debug(f"ToolImageCache initialized, cache dir: {self._cache_dir}") |
|
|
| def _get_file_extension(self, mime_type: str) -> str: |
| """Get file extension from MIME type.""" |
| mime_to_ext = { |
| "image/png": ".png", |
| "image/jpeg": ".jpg", |
| "image/jpg": ".jpg", |
| "image/gif": ".gif", |
| "image/webp": ".webp", |
| "image/bmp": ".bmp", |
| "image/svg+xml": ".svg", |
| } |
| return mime_to_ext.get(mime_type.lower(), ".png") |
|
|
| def save_image( |
| self, |
| base64_data: str, |
| tool_call_id: str, |
| tool_name: str, |
| index: int = 0, |
| mime_type: str = "image/png", |
| ) -> CachedImage: |
| """Save an image to cache and return the cached image info. |
| |
| Args: |
| base64_data: Base64 encoded image data. |
| tool_call_id: The tool call ID that produced this image. |
| tool_name: The name of the tool that produced this image. |
| index: The index of the image (for multiple images from same tool call). |
| mime_type: The MIME type of the image. |
| |
| Returns: |
| CachedImage object with file path. |
| """ |
| ext = self._get_file_extension(mime_type) |
| file_name = f"{tool_call_id}_{index}{ext}" |
| file_path = os.path.join(self._cache_dir, file_name) |
|
|
| |
| try: |
| image_bytes = base64.b64decode(base64_data) |
| with open(file_path, "wb") as f: |
| f.write(image_bytes) |
| logger.debug(f"Saved tool image to: {file_path}") |
| except Exception as e: |
| logger.error(f"Failed to save tool image: {e}") |
| raise |
|
|
| return CachedImage( |
| tool_call_id=tool_call_id, |
| tool_name=tool_name, |
| file_path=file_path, |
| mime_type=mime_type, |
| ) |
|
|
| def get_image_base64_by_path( |
| self, file_path: str, mime_type: str = "image/png" |
| ) -> tuple[str, str] | None: |
| """Read an image file and return its base64 encoded data. |
| |
| Args: |
| file_path: The file path of the cached image. |
| mime_type: The MIME type of the image. |
| |
| Returns: |
| Tuple of (base64_data, mime_type) if found, None otherwise. |
| """ |
| if not os.path.exists(file_path): |
| return None |
|
|
| try: |
| with open(file_path, "rb") as f: |
| image_bytes = f.read() |
| base64_data = base64.b64encode(image_bytes).decode("utf-8") |
| return base64_data, mime_type |
| except Exception as e: |
| logger.error(f"Failed to read cached image {file_path}: {e}") |
| return None |
|
|
| def cleanup_expired(self) -> int: |
| """Clean up expired cached images. |
| |
| Returns: |
| Number of images cleaned up. |
| """ |
| now = time.time() |
| cleaned = 0 |
|
|
| try: |
| for file_name in os.listdir(self._cache_dir): |
| file_path = os.path.join(self._cache_dir, file_name) |
| if os.path.isfile(file_path): |
| file_age = now - os.path.getmtime(file_path) |
| if file_age > self.CACHE_EXPIRY: |
| os.remove(file_path) |
| cleaned += 1 |
| except Exception as e: |
| logger.warning(f"Error during cache cleanup: {e}") |
|
|
| if cleaned: |
| logger.info(f"Cleaned up {cleaned} expired cached images") |
|
|
| return cleaned |
|
|
|
|
| |
# Shared module-level cache instance. ToolImageCache is a singleton, so calling
# ToolImageCache() elsewhere returns this same object. Constructing it here
# means the cache directory is created when this module is first imported.
tool_image_cache = ToolImageCache()
|
|