File size: 5,236 Bytes
8ede856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""Tool image cache module for storing and retrieving images returned by tools.

This module allows LLM to review images before deciding whether to send them to users.
"""

import base64
import os
import time
from dataclasses import dataclass, field
from typing import ClassVar

from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path


@dataclass
class CachedImage:
    """Represents a cached image from a tool call."""

    tool_call_id: str
    """The tool call ID that produced this image."""
    tool_name: str
    """The name of the tool that produced this image."""
    file_path: str
    """The file path where the image is stored."""
    mime_type: str
    """The MIME type of the image."""
    created_at: float = field(default_factory=time.time)
    """Timestamp when the image was cached."""


class ToolImageCache:
    """Manages cached images from tool calls.

    Images are stored in data/temp/tool_images/ and can be retrieved by file path.
    """

    _instance: ClassVar["ToolImageCache | None"] = None
    CACHE_DIR_NAME: ClassVar[str] = "tool_images"
    # Cache expiry time in seconds (1 hour)
    CACHE_EXPIRY: ClassVar[int] = 3600

    def __new__(cls) -> "ToolImageCache":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        if self._initialized:
            return
        self._initialized = True
        self._cache_dir = os.path.join(get_astrbot_temp_path(), self.CACHE_DIR_NAME)
        os.makedirs(self._cache_dir, exist_ok=True)
        logger.debug(f"ToolImageCache initialized, cache dir: {self._cache_dir}")

    def _get_file_extension(self, mime_type: str) -> str:
        """Get file extension from MIME type."""
        mime_to_ext = {
            "image/png": ".png",
            "image/jpeg": ".jpg",
            "image/jpg": ".jpg",
            "image/gif": ".gif",
            "image/webp": ".webp",
            "image/bmp": ".bmp",
            "image/svg+xml": ".svg",
        }
        return mime_to_ext.get(mime_type.lower(), ".png")

    def save_image(
        self,
        base64_data: str,
        tool_call_id: str,
        tool_name: str,
        index: int = 0,
        mime_type: str = "image/png",
    ) -> CachedImage:
        """Save an image to cache and return the cached image info.

        Args:
            base64_data: Base64 encoded image data.
            tool_call_id: The tool call ID that produced this image.
            tool_name: The name of the tool that produced this image.
            index: The index of the image (for multiple images from same tool call).
            mime_type: The MIME type of the image.

        Returns:
            CachedImage object with file path.
        """
        ext = self._get_file_extension(mime_type)
        file_name = f"{tool_call_id}_{index}{ext}"
        file_path = os.path.join(self._cache_dir, file_name)

        # Decode and save the image
        try:
            image_bytes = base64.b64decode(base64_data)
            with open(file_path, "wb") as f:
                f.write(image_bytes)
            logger.debug(f"Saved tool image to: {file_path}")
        except Exception as e:
            logger.error(f"Failed to save tool image: {e}")
            raise

        return CachedImage(
            tool_call_id=tool_call_id,
            tool_name=tool_name,
            file_path=file_path,
            mime_type=mime_type,
        )

    def get_image_base64_by_path(
        self, file_path: str, mime_type: str = "image/png"
    ) -> tuple[str, str] | None:
        """Read an image file and return its base64 encoded data.

        Args:
            file_path: The file path of the cached image.
            mime_type: The MIME type of the image.

        Returns:
            Tuple of (base64_data, mime_type) if found, None otherwise.
        """
        if not os.path.exists(file_path):
            return None

        try:
            with open(file_path, "rb") as f:
                image_bytes = f.read()
            base64_data = base64.b64encode(image_bytes).decode("utf-8")
            return base64_data, mime_type
        except Exception as e:
            logger.error(f"Failed to read cached image {file_path}: {e}")
            return None

    def cleanup_expired(self) -> int:
        """Clean up expired cached images.

        Returns:
            Number of images cleaned up.
        """
        now = time.time()
        cleaned = 0

        try:
            for file_name in os.listdir(self._cache_dir):
                file_path = os.path.join(self._cache_dir, file_name)
                if os.path.isfile(file_path):
                    file_age = now - os.path.getmtime(file_path)
                    if file_age > self.CACHE_EXPIRY:
                        os.remove(file_path)
                        cleaned += 1
        except Exception as e:
            logger.warning(f"Error during cache cleanup: {e}")

        if cleaned:
            logger.info(f"Cleaned up {cleaned} expired cached images")

        return cleaned


# Global singleton instance
tool_image_cache = ToolImageCache()