| import abc |
| import datetime |
| import typing as T |
| from contextlib import asynccontextmanager |
| from dataclasses import dataclass |
|
|
| from deprecated import deprecated |
| from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine |
|
|
| from astrbot.core.db.po import ( |
| ApiKey, |
| Attachment, |
| ChatUIProject, |
| CommandConfig, |
| CommandConflict, |
| ConversationV2, |
| CronJob, |
| Persona, |
| PersonaFolder, |
| PlatformMessageHistory, |
| PlatformSession, |
| PlatformStat, |
| Preference, |
| SessionProjectRelation, |
| Stats, |
| ) |
|
|
|
|
| @dataclass |
| class BaseDatabase(abc.ABC): |
| """数据库基类""" |
|
|
| DATABASE_URL = "" |
|
|
| def __init__(self) -> None: |
| self.engine = create_async_engine( |
| self.DATABASE_URL, |
| echo=False, |
| future=True, |
| ) |
| self.AsyncSessionLocal = async_sessionmaker( |
| self.engine, |
| class_=AsyncSession, |
| expire_on_commit=False, |
| ) |
|
|
| async def initialize(self) -> None: |
| """初始化数据库连接""" |
|
|
| @asynccontextmanager |
| async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]: |
| """Get a database session.""" |
| if not self.inited: |
| await self.initialize() |
| self.inited = True |
| async with self.AsyncSessionLocal() as session: |
| yield session |
|
|
| @deprecated(version="4.0.0", reason="Use get_platform_stats instead") |
| @abc.abstractmethod |
| def get_base_stats(self, offset_sec: int = 86400) -> Stats: |
| """获取基础统计数据""" |
| raise NotImplementedError |
|
|
| @deprecated(version="4.0.0", reason="Use get_platform_stats instead") |
| @abc.abstractmethod |
| def get_total_message_count(self) -> int: |
| """获取总消息数""" |
| raise NotImplementedError |
|
|
| @deprecated(version="4.0.0", reason="Use get_platform_stats instead") |
| @abc.abstractmethod |
| def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: |
| """获取基础统计数据(合并)""" |
| raise NotImplementedError |
|
|
| |
|
|
| @abc.abstractmethod |
| async def insert_platform_stats( |
| self, |
| platform_id: str, |
| platform_type: str, |
| count: int = 1, |
| timestamp: datetime.datetime | None = None, |
| ) -> None: |
| """Insert a new platform statistic record.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def count_platform_stats(self) -> int: |
| """Count the number of platform statistics records.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]: |
| """Get platform statistics within the specified offset in seconds and group by platform_id.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_conversations( |
| self, |
| user_id: str | None = None, |
| platform_id: str | None = None, |
| ) -> list[ConversationV2]: |
| """Get all conversations for a specific user and platform_id(optional). |
| |
| content is not included in the result. |
| """ |
| ... |
|
|
| @abc.abstractmethod |
| async def get_conversation_by_id(self, cid: str) -> ConversationV2: |
| """Get a specific conversation by its ID.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_all_conversations( |
| self, |
| page: int = 1, |
| page_size: int = 20, |
| ) -> list[ConversationV2]: |
| """Get all conversations with pagination.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_filtered_conversations( |
| self, |
| page: int = 1, |
| page_size: int = 20, |
| platform_ids: list[str] | None = None, |
| search_query: str = "", |
| **kwargs, |
| ) -> tuple[list[ConversationV2], int]: |
| """Get conversations filtered by platform IDs and search query.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def create_conversation( |
| self, |
| user_id: str, |
| platform_id: str, |
| content: list[dict] | None = None, |
| title: str | None = None, |
| persona_id: str | None = None, |
| cid: str | None = None, |
| created_at: datetime.datetime | None = None, |
| updated_at: datetime.datetime | None = None, |
| ) -> ConversationV2: |
| """Create a new conversation.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def update_conversation( |
| self, |
| cid: str, |
| title: str | None = None, |
| persona_id: str | None = None, |
| content: list[dict] | None = None, |
| token_usage: int | None = None, |
| ) -> None: |
| """Update a conversation's history.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def delete_conversation(self, cid: str) -> None: |
| """Delete a conversation by its ID.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def delete_conversations_by_user_id(self, user_id: str) -> None: |
| """Delete all conversations for a specific user.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def insert_platform_message_history( |
| self, |
| platform_id: str, |
| user_id: str, |
| content: dict, |
| sender_id: str | None = None, |
| sender_name: str | None = None, |
| ) -> PlatformMessageHistory: |
| """Insert a new platform message history record.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def delete_platform_message_offset( |
| self, |
| platform_id: str, |
| user_id: str, |
| offset_sec: int = 86400, |
| ) -> None: |
| """Delete platform message history records newer than the specified offset.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_platform_message_history( |
| self, |
| platform_id: str, |
| user_id: str, |
| page: int = 1, |
| page_size: int = 20, |
| ) -> list[PlatformMessageHistory]: |
| """Get platform message history for a specific user.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_platform_message_history_by_id( |
| self, |
| message_id: int, |
| ) -> PlatformMessageHistory | None: |
| """Get a platform message history record by its ID.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def insert_attachment( |
| self, |
| path: str, |
| type: str, |
| mime_type: str, |
| ): |
| """Insert a new attachment record.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_attachment_by_id(self, attachment_id: str) -> Attachment: |
| """Get an attachment by its ID.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_attachments(self, attachment_ids: list[str]) -> list[Attachment]: |
| """Get multiple attachments by their IDs.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def delete_attachment(self, attachment_id: str) -> bool: |
| """Delete an attachment by its ID. |
| |
| Returns True if the attachment was deleted, False if it was not found. |
| """ |
| ... |
|
|
| @abc.abstractmethod |
| async def delete_attachments(self, attachment_ids: list[str]) -> int: |
| """Delete multiple attachments by their IDs. |
| |
| Returns the number of attachments deleted. |
| """ |
| ... |
|
|
| @abc.abstractmethod |
| async def create_api_key( |
| self, |
| name: str, |
| key_hash: str, |
| key_prefix: str, |
| scopes: list[str] | None, |
| created_by: str, |
| expires_at: datetime.datetime | None = None, |
| ) -> ApiKey: |
| """Create a new API key record.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def list_api_keys(self) -> list[ApiKey]: |
| """List all API keys.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_api_key_by_id(self, key_id: str) -> ApiKey | None: |
| """Get an API key by key_id.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_active_api_key_by_hash(self, key_hash: str) -> ApiKey | None: |
| """Get an active API key by hash (not revoked, not expired).""" |
| ... |
|
|
| @abc.abstractmethod |
| async def touch_api_key(self, key_id: str) -> None: |
| """Update last_used_at of an API key.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def revoke_api_key(self, key_id: str) -> bool: |
| """Revoke an API key. |
| |
| Returns True when the key exists and is updated. |
| """ |
| ... |
|
|
| @abc.abstractmethod |
| async def delete_api_key(self, key_id: str) -> bool: |
| """Delete an API key. |
| |
| Returns True when the key exists and is deleted. |
| """ |
| ... |
|
|
| @abc.abstractmethod |
| async def insert_persona( |
| self, |
| persona_id: str, |
| system_prompt: str, |
| begin_dialogs: list[str] | None = None, |
| tools: list[str] | None = None, |
| skills: list[str] | None = None, |
| custom_error_message: str | None = None, |
| folder_id: str | None = None, |
| sort_order: int = 0, |
| ) -> Persona: |
| """Insert a new persona record. |
| |
| Args: |
| persona_id: Unique identifier for the persona |
| system_prompt: System prompt for the persona |
| begin_dialogs: Optional list of initial dialog strings |
| tools: Optional list of tool names (None means all tools, [] means no tools) |
| skills: Optional list of skill names (None means all skills, [] means no skills) |
| custom_error_message: Optional persona-level fallback error message |
| folder_id: Optional folder ID to place the persona in (None means root) |
| sort_order: Sort order within the folder (default 0) |
| """ |
| ... |
|
|
| @abc.abstractmethod |
| async def get_persona_by_id(self, persona_id: str) -> Persona: |
| """Get a persona by its ID.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_personas(self) -> list[Persona]: |
| """Get all personas for a specific bot.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def update_persona( |
| self, |
| persona_id: str, |
| system_prompt: str | None = None, |
| begin_dialogs: list[str] | None = None, |
| tools: list[str] | None = None, |
| skills: list[str] | None = None, |
| custom_error_message: str | None = None, |
| ) -> Persona | None: |
| """Update a persona's system prompt or begin dialogs.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def delete_persona(self, persona_id: str) -> None: |
| """Delete a persona by its ID.""" |
| ... |
|
|
| |
| |
| |
|
|
| @abc.abstractmethod |
| async def insert_persona_folder( |
| self, |
| name: str, |
| parent_id: str | None = None, |
| description: str | None = None, |
| sort_order: int = 0, |
| ) -> PersonaFolder: |
| """Insert a new persona folder.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_persona_folder_by_id(self, folder_id: str) -> PersonaFolder | None: |
| """Get a persona folder by its folder_id.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_persona_folders( |
| self, parent_id: str | None = None |
| ) -> list[PersonaFolder]: |
| """Get all persona folders, optionally filtered by parent_id.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_all_persona_folders(self) -> list[PersonaFolder]: |
| """Get all persona folders.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def update_persona_folder( |
| self, |
| folder_id: str, |
| name: str | None = None, |
| parent_id: T.Any = None, |
| description: T.Any = None, |
| sort_order: int | None = None, |
| ) -> PersonaFolder | None: |
| """Update a persona folder.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def delete_persona_folder(self, folder_id: str) -> None: |
| """Delete a persona folder by its folder_id.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def move_persona_to_folder( |
| self, persona_id: str, folder_id: str | None |
| ) -> Persona | None: |
| """Move a persona to a folder (or root if folder_id is None).""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_personas_by_folder( |
| self, folder_id: str | None = None |
| ) -> list[Persona]: |
| """Get all personas in a specific folder.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def batch_update_sort_order( |
| self, |
| items: list[dict], |
| ) -> None: |
| """Batch update sort_order for personas and/or folders. |
| |
| Args: |
| items: List of dicts with keys: |
| - id: The persona_id or folder_id |
| - type: Either "persona" or "folder" |
| - sort_order: The new sort_order value |
| """ |
| ... |
|
|
| @abc.abstractmethod |
| async def insert_preference_or_update( |
| self, |
| scope: str, |
| scope_id: str, |
| key: str, |
| value: dict, |
| ) -> Preference: |
| """Insert a new preference record.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference: |
| """Get a preference by scope ID and key.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_preferences( |
| self, |
| scope: str, |
| scope_id: str | None = None, |
| key: str | None = None, |
| ) -> list[Preference]: |
| """Get all preferences for a specific scope ID or key.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def remove_preference(self, scope: str, scope_id: str, key: str) -> None: |
| """Remove a preference by scope ID and key.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def clear_preferences(self, scope: str, scope_id: str) -> None: |
| """Clear all preferences for a specific scope ID.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_command_configs(self) -> list[CommandConfig]: |
| """Get all stored command configurations.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_command_config(self, handler_full_name: str) -> CommandConfig | None: |
| """Fetch a single command configuration by handler.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def upsert_command_config( |
| self, |
| handler_full_name: str, |
| plugin_name: str, |
| module_path: str, |
| original_command: str, |
| *, |
| resolved_command: str | None = None, |
| enabled: bool | None = None, |
| keep_original_alias: bool | None = None, |
| conflict_key: str | None = None, |
| resolution_strategy: str | None = None, |
| note: str | None = None, |
| extra_data: dict | None = None, |
| auto_managed: bool | None = None, |
| ) -> CommandConfig: |
| """Create or update a command configuration.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def delete_command_config(self, handler_full_name: str) -> None: |
| """Delete a single command configuration.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def delete_command_configs(self, handler_full_names: list[str]) -> None: |
| """Bulk delete command configurations.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def list_command_conflicts( |
| self, |
| status: str | None = None, |
| ) -> list[CommandConflict]: |
| """List recorded command conflict entries.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def upsert_command_conflict( |
| self, |
| conflict_key: str, |
| handler_full_name: str, |
| plugin_name: str, |
| *, |
| status: str | None = None, |
| resolution: str | None = None, |
| resolved_command: str | None = None, |
| note: str | None = None, |
| extra_data: dict | None = None, |
| auto_generated: bool | None = None, |
| ) -> CommandConflict: |
| """Create or update a conflict record.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def delete_command_conflicts(self, ids: list[int]) -> None: |
| """Delete conflict records.""" |
| ... |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| @abc.abstractmethod |
| async def get_session_conversations( |
| self, |
| page: int = 1, |
| page_size: int = 20, |
| search_query: str | None = None, |
| platform: str | None = None, |
| ) -> tuple[list[dict], int]: |
| """Get paginated session conversations with joined conversation and persona details, support search and platform filter.""" |
| ... |
|
|
| |
| |
| |
|
|
| @abc.abstractmethod |
| async def create_cron_job( |
| self, |
| name: str, |
| job_type: str, |
| cron_expression: str | None, |
| *, |
| timezone: str | None = None, |
| payload: dict | None = None, |
| description: str | None = None, |
| enabled: bool = True, |
| persistent: bool = True, |
| run_once: bool = False, |
| status: str | None = None, |
| job_id: str | None = None, |
| ) -> CronJob: |
| """Create and persist a cron job definition.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def update_cron_job( |
| self, |
| job_id: str, |
| *, |
| name: str | None = None, |
| cron_expression: str | None = None, |
| timezone: str | None = None, |
| payload: dict | None = None, |
| description: str | None = None, |
| enabled: bool | None = None, |
| persistent: bool | None = None, |
| run_once: bool | None = None, |
| status: str | None = None, |
| next_run_time: datetime.datetime | None = None, |
| last_run_at: datetime.datetime | None = None, |
| last_error: str | None = None, |
| ) -> CronJob | None: |
| """Update fields of a cron job by job_id.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def delete_cron_job(self, job_id: str) -> None: |
| """Delete a cron job by its public job_id.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_cron_job(self, job_id: str) -> CronJob | None: |
| """Fetch a cron job by job_id.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def list_cron_jobs(self, job_type: str | None = None) -> list[CronJob]: |
| """List cron jobs, optionally filtered by job_type.""" |
| ... |
|
|
| |
| |
| |
|
|
| @abc.abstractmethod |
| async def create_platform_session( |
| self, |
| creator: str, |
| platform_id: str = "webchat", |
| session_id: str | None = None, |
| display_name: str | None = None, |
| is_group: int = 0, |
| ) -> PlatformSession: |
| """Create a new Platform session.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_platform_session_by_id( |
| self, session_id: str |
| ) -> PlatformSession | None: |
| """Get a Platform session by its ID.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_platform_sessions_by_creator( |
| self, |
| creator: str, |
| platform_id: str | None = None, |
| page: int = 1, |
| page_size: int = 20, |
| ) -> list[dict]: |
| """Get all Platform sessions for a specific creator (username) and optionally platform. |
| |
| Returns a list of dicts containing session info and project info (if session belongs to a project). |
| """ |
| ... |
|
|
| @abc.abstractmethod |
| async def get_platform_sessions_by_creator_paginated( |
| self, |
| creator: str, |
| platform_id: str | None = None, |
| page: int = 1, |
| page_size: int = 20, |
| exclude_project_sessions: bool = False, |
| ) -> tuple[list[dict], int]: |
| """Get paginated platform sessions and total count for a creator. |
| |
| Returns: |
| tuple[list[dict], int]: (sessions_with_project_info, total_count) |
| """ |
| ... |
|
|
| @abc.abstractmethod |
| async def update_platform_session( |
| self, |
| session_id: str, |
| display_name: str | None = None, |
| ) -> None: |
| """Update a Platform session's updated_at timestamp and optionally display_name.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def delete_platform_session(self, session_id: str) -> None: |
| """Delete a Platform session by its ID.""" |
| ... |
|
|
| |
| |
| |
|
|
| @abc.abstractmethod |
| async def create_chatui_project( |
| self, |
| creator: str, |
| title: str, |
| emoji: str | None = "📁", |
| description: str | None = None, |
| ) -> ChatUIProject: |
| """Create a new ChatUI project.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_chatui_project_by_id(self, project_id: str) -> ChatUIProject | None: |
| """Get a ChatUI project by its ID.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_chatui_projects_by_creator( |
| self, |
| creator: str, |
| page: int = 1, |
| page_size: int = 100, |
| ) -> list[ChatUIProject]: |
| """Get all ChatUI projects for a specific creator.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def update_chatui_project( |
| self, |
| project_id: str, |
| title: str | None = None, |
| emoji: str | None = None, |
| description: str | None = None, |
| ) -> None: |
| """Update a ChatUI project.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def delete_chatui_project(self, project_id: str) -> None: |
| """Delete a ChatUI project by its ID.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def add_session_to_project( |
| self, |
| session_id: str, |
| project_id: str, |
| ) -> SessionProjectRelation: |
| """Add a session to a project.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def remove_session_from_project(self, session_id: str) -> None: |
| """Remove a session from its project.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_project_sessions( |
| self, |
| project_id: str, |
| page: int = 1, |
| page_size: int = 100, |
| ) -> list[PlatformSession]: |
| """Get all sessions in a project.""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_project_by_session( |
| self, session_id: str, creator: str |
| ) -> ChatUIProject | None: |
| """Get the project that a session belongs to.""" |
| ... |
|
|