# astrbbbb / astrbot/core/db/sqlite.py
# Uploaded by qa1145 ("Upload sqlite.py"), revision 1c8cdd0 (verified).
import asyncio
import hashlib
import threading
import typing as T
import uuid
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta, timezone
from sqlalchemy import CursorResult, Row
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import (
ApiKey,
Attachment,
ChatUIProject,
CommandConfig,
CommandConflict,
ConversationV2,
CronJob,
Persona,
PersonaFolder,
PlatformMessageHistory,
PlatformSession,
PlatformStat,
Preference,
SessionProjectRelation,
SQLModel,
)
from astrbot.core.db.po import (
Platform as DeprecatedPlatformStat,
)
from astrbot.core.db.po import (
Stats as DeprecatedStats,
)
from astrbot.core.sentinels import NOT_GIVEN
# Generic type of the value produced by a callback run through
# ``SQLiteDatabase._run_in_tx``.
TxResult = T.TypeVar("TxResult")
# Unique module-level sentinel object. It is not referenced in the visible part
# of this file; presumably it marks "cron field not supplied" — TODO confirm.
CRON_FIELD_NOT_SET = object()
class SQLiteDatabase(BaseDatabase):
def __init__(self, db_path: str) -> None:
    """Remember where the SQLite file lives and initialise the base class.

    Args:
        db_path: Path of the SQLite database file.
    """
    self.inited = False
    self.db_path = db_path
    # Assigned before ``super().__init__()`` runs, as in the original ordering;
    # presumably the base class reads it — confirm in BaseDatabase.
    self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
    super().__init__()
async def initialize(self) -> None:
    """Create missing tables, apply SQLite tuning and upgrade old schemas.

    Within one engine transaction this creates every table known to
    ``SQLModel.metadata``, issues the PRAGMA statements below, and runs the
    ``personas`` column upgrades. Afterwards it seeds the default API key.
    """
    pragma_statements = (
        "PRAGMA journal_mode=WAL",
        "PRAGMA synchronous=NORMAL",
        "PRAGMA cache_size=20000",
        "PRAGMA temp_store=MEMORY",
        "PRAGMA mmap_size=134217728",
        "PRAGMA optimize",
    )
    column_upgrades = (
        self._ensure_persona_folder_columns,
        self._ensure_persona_skills_column,
        self._ensure_persona_custom_error_message_column,
    )
    async with self.engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)
        for statement in pragma_statements:
            await conn.execute(text(statement))
        # Add columns that databases created by older versions lack.
        for upgrade in column_upgrades:
            await upgrade(conn)
        await conn.commit()
    await self._create_default_api_key()
async def _create_default_api_key(self) -> None:
    """Seed ``api_keys`` with one developer key when the table is empty.

    NOTE(review): the key value and the PBKDF2 salt are hard-coded, so every
    installation that reaches this path gets the same credential, and the
    stored ``key_prefix`` (first 12 characters) covers the whole 11-character
    key. Confirm this is acceptable outside development setups.
    """
    async with self.engine.begin() as conn:
        count_result = await conn.execute(text("SELECT COUNT(*) FROM api_keys"))
        existing_keys = count_result.scalar()
        if existing_keys > 0:
            # At least one key exists; nothing to seed.
            return

    raw_key = "abk_astrbot"
    digest = hashlib.pbkdf2_hmac(
        "sha256",
        raw_key.encode("utf-8"),
        b"astrbot_api_key",
        100_000,
    )
    stamp = datetime.now(timezone.utc).isoformat()
    row = {
        "key_id": str(uuid.uuid4()),
        "name": "Default Developer Key",
        "key_hash": digest.hex(),
        "key_prefix": raw_key[:12],
        "scopes": '["chat","config","file","im"]',
        "created_by": "system",
        "created_at": stamp,
        "updated_at": stamp,
    }
    insert_sql = text("""
INSERT INTO api_keys (key_id, name, key_hash, key_prefix, scopes, created_by, created_at, updated_at)
VALUES (:key_id, :name, :key_hash, :key_prefix, :scopes, :created_by, :created_at, :updated_at)
""")
    async with self.engine.begin() as conn:
        await conn.execute(insert_sql, row)
        await conn.commit()
async def _ensure_persona_folder_columns(self, conn) -> None:
    """Make sure ``personas`` has the ``folder_id`` and ``sort_order`` columns.

    Supports upgrading databases created by older versions in place; new
    databases get these columns from ``SQLModel.metadata.create_all``.

    Args:
        conn: Connection on which the PRAGMA and ALTER statements are run.
    """
    table_info = await conn.execute(text("PRAGMA table_info(personas)"))
    # Index 1 of each ``table_info`` row is the column name.
    present = {info_row[1] for info_row in table_info.fetchall()}
    upgrades = (
        (
            "folder_id",
            "ALTER TABLE personas ADD COLUMN folder_id VARCHAR(36) DEFAULT NULL",
        ),
        (
            "sort_order",
            "ALTER TABLE personas ADD COLUMN sort_order INTEGER DEFAULT 0",
        ),
    )
    for column_name, ddl in upgrades:
        if column_name in present:
            continue
        await conn.execute(text(ddl))
async def _ensure_persona_skills_column(self, conn) -> None:
    """Make sure ``personas`` has the ``skills`` column.

    Supports upgrading databases created by older versions in place; new
    databases get the column from ``SQLModel.metadata.create_all``.

    Args:
        conn: Connection on which the PRAGMA and ALTER statements are run.
    """
    table_info = await conn.execute(text("PRAGMA table_info(personas)"))
    # Index 1 of each ``table_info`` row is the column name.
    present = {info_row[1] for info_row in table_info.fetchall()}
    if "skills" in present:
        return
    await conn.execute(text("ALTER TABLE personas ADD COLUMN skills JSON"))
async def _ensure_persona_custom_error_message_column(self, conn) -> None:
    """Make sure ``personas`` has the ``custom_error_message`` column.

    Args:
        conn: Connection on which the PRAGMA and ALTER statements are run.
    """
    table_info = await conn.execute(text("PRAGMA table_info(personas)"))
    # Index 1 of each ``table_info`` row is the column name.
    present = {info_row[1] for info_row in table_info.fetchall()}
    if "custom_error_message" in present:
        return
    await conn.execute(
        text("ALTER TABLE personas ADD COLUMN custom_error_message TEXT")
    )
# ====
# Platform Statistics
# ====
async def insert_platform_stats(
    self,
    platform_id,
    platform_type,
    count=1,
    timestamp=None,
) -> None:
    """Add ``count`` to the statistics row of a platform for a time bucket.

    Args:
        platform_id: Identifier of the platform instance.
        platform_type: Type name of the platform.
        count: Amount added to the bucket's counter.
        timestamp: Bucket timestamp. When ``None`` the current local time
            truncated to the hour is used; an explicit value is stored as is.
    """
    upsert_sql = text("""
INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
VALUES (:timestamp, :platform_id, :platform_type, :count)
ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
count = platform_stats.count + EXCLUDED.count
""")
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            if timestamp is not None:
                bucket = timestamp
            else:
                bucket = datetime.now().replace(
                    minute=0,
                    second=0,
                    microsecond=0,
                )
            params = {
                "timestamp": bucket,
                "platform_id": platform_id,
                "platform_type": platform_type,
                "count": count,
            }
            await session.execute(upsert_sql, params)
async def count_platform_stats(self) -> int:
    """Return how many platform statistics rows exist (0 when none)."""
    counting_query = select(func.count(col(PlatformStat.platform_id))).select_from(
        PlatformStat,
    )
    async with self.get_db() as session:
        session: AsyncSession
        outcome = await session.execute(counting_query)
        total = outcome.scalar_one_or_none()
        if total is None:
            return 0
        return total
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
    """Return platform statistics newer than ``offset_sec`` seconds ago.

    The SQL groups by ``platform_id``, so at most one row per platform is
    returned, ordered by timestamp descending.

    Args:
        offset_sec: Size of the look-back window in seconds, measured from
            the current local time.

    Returns:
        A list of ``PlatformStat`` entities.
    """
    raw_sql = text("""
SELECT * FROM platform_stats
WHERE timestamp >= :start_time
GROUP BY platform_id
ORDER BY timestamp DESC
""")
    # Calling ``scalars()`` on a plain textual result yields only the first
    # column of each row. Routing the same SQL through ``from_statement``
    # makes the ORM build ``PlatformStat`` objects, matching the return type.
    entity_query = select(PlatformStat).from_statement(raw_sql)
    async with self.get_db() as session:
        session: AsyncSession
        start_time = datetime.now() - timedelta(seconds=offset_sec)
        result = await session.execute(entity_query, {"start_time": start_time})
        return list(result.scalars().all())
# ====
# Conversation Management
# ====
async def get_conversations(self, user_id=None, platform_id=None):
    """List conversations, newest first, optionally filtered.

    Args:
        user_id: When truthy, only conversations of this user are returned.
        platform_id: When truthy, only conversations of this platform are
            returned.
    """
    stmt = select(ConversationV2)
    if user_id:
        stmt = stmt.where(ConversationV2.user_id == user_id)
    if platform_id:
        stmt = stmt.where(ConversationV2.platform_id == platform_id)
    stmt = stmt.order_by(desc(ConversationV2.created_at))
    async with self.get_db() as session:
        session: AsyncSession
        rows = await session.execute(stmt)
        return rows.scalars().all()
async def get_conversation_by_id(self, cid):
    """Return the conversation whose ``conversation_id`` is ``cid``, or ``None``."""
    lookup = select(ConversationV2).where(ConversationV2.conversation_id == cid)
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(lookup)
        return found.scalar_one_or_none()
async def get_all_conversations(self, page=1, page_size=20):
    """Return one page of all conversations, newest first.

    Args:
        page: 1-based page number.
        page_size: Number of conversations per page.
    """
    skip = (page - 1) * page_size
    paged_query = (
        select(ConversationV2)
        .order_by(desc(ConversationV2.created_at))
        .offset(skip)
        .limit(page_size)
    )
    async with self.get_db() as session:
        session: AsyncSession
        rows = await session.execute(paged_query)
        return rows.scalars().all()
async def get_filtered_conversations(
    self,
    page=1,
    page_size=20,
    platform_ids=None,
    search_query="",
    **kwargs,
):
    """Return one page of conversations matching the filters, plus the total.

    Args:
        page: 1-based page number.
        page_size: Number of conversations per page.
        platform_ids: When non-empty, restrict to these platform ids.
        search_query: When non-empty, case-insensitive substring matched
            against title, content, user id and conversation id.
        **kwargs: Optional extra filters:
            ``message_types`` - keep conversations whose ``user_id`` contains
            ``:<type>:`` for ANY of the given types;
            ``platforms`` - restrict to these platform ids.

    Returns:
        ``(conversations, total)`` where ``total`` counts all rows matching
        the filters, independent of pagination.
    """
    # ``or ()`` tolerates callers that pass ``None`` for an optional filter.
    message_types = kwargs.get("message_types") or ()
    platforms = kwargs.get("platforms") or ()

    async with self.get_db() as session:
        session: AsyncSession
        base_query = select(ConversationV2)
        if platform_ids:
            base_query = base_query.where(
                col(ConversationV2.platform_id).in_(platform_ids),
            )
        if search_query:
            # NOTE(review): non-ASCII text is escaped before matching,
            # presumably because stored content uses \uXXXX escapes — confirm.
            search_query = search_query.encode("unicode_escape").decode("utf-8")
            pattern = f"%{search_query}%"
            base_query = base_query.where(
                or_(
                    col(ConversationV2.title).ilike(pattern),
                    col(ConversationV2.content).ilike(pattern),
                    col(ConversationV2.user_id).ilike(pattern),
                    col(ConversationV2.conversation_id).ilike(pattern),
                ),
            )
        if message_types:
            # A user_id can embed only one message type, so several types must
            # be combined with OR; chaining them with AND never matches.
            base_query = base_query.where(
                or_(
                    *(
                        col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
                        for msg_type in message_types
                    ),
                ),
            )
        if platforms:
            base_query = base_query.where(
                col(ConversationV2.platform_id).in_(platforms),
            )

        # Total number of rows matching the filters.
        count_query = select(func.count()).select_from(base_query.subquery())
        total = (await session.execute(count_query)).scalar_one()

        # Requested page, newest first.
        offset = (page - 1) * page_size
        result_query = (
            base_query.order_by(desc(ConversationV2.created_at))
            .offset(offset)
            .limit(page_size)
        )
        result = await session.execute(result_query)
        conversations = result.scalars().all()
        return conversations, total
async def create_conversation(
    self,
    user_id,
    platform_id,
    content=None,
    title=None,
    persona_id=None,
    cid=None,
    created_at=None,
    updated_at=None,
):
    """Create and store a new conversation.

    ``cid``, ``created_at`` and ``updated_at`` are passed to the model only
    when truthy; otherwise the model's own defaults apply. A falsy
    ``content`` is stored as an empty list.

    Returns:
        The newly added ``ConversationV2`` instance.
    """
    optional_fields = (
        ("conversation_id", cid),
        ("created_at", created_at),
        ("updated_at", updated_at),
    )
    overrides = {name: value for name, value in optional_fields if value}
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            conversation = ConversationV2(
                user_id=user_id,
                content=content or [],
                platform_id=platform_id,
                title=title,
                persona_id=persona_id,
                **overrides,
            )
            session.add(conversation)
            return conversation
async def update_conversation(
    self, cid, title=None, persona_id=None, content=None, token_usage=None
):
    """Update the given fields of a conversation.

    Arguments left as ``None`` are not modified.

    Returns:
        The conversation re-read after the update, or ``None`` when no field
        was supplied.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            candidates = (
                ("title", title),
                ("persona_id", persona_id),
                ("content", content),
                ("token_usage", token_usage),
            )
            changes = {
                field: value for field, value in candidates if value is not None
            }
            if not changes:
                return None
            stmt = (
                update(ConversationV2)
                .where(col(ConversationV2.conversation_id) == cid)
                .values(**changes)
            )
            await session.execute(stmt)
    return await self.get_conversation_by_id(cid)
async def delete_conversation(self, cid) -> None:
    """Delete the conversation whose ``conversation_id`` is ``cid``."""
    removal = delete(ConversationV2).where(
        col(ConversationV2.conversation_id) == cid,
    )
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(removal)
async def delete_conversations_by_user_id(self, user_id: str) -> None:
    """Delete every conversation that belongs to ``user_id``."""
    removal = delete(ConversationV2).where(col(ConversationV2.user_id) == user_id)
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(removal)
async def get_session_conversations(
    self,
    page=1,
    page_size=20,
    search_query=None,
    platform=None,
) -> tuple[list[dict], int]:
    """Return one page of sessions with their selected conversation details.

    Sessions are ``Preference`` rows with scope ``"umo"`` and key
    ``"sel_conv_id"``; each is outer-joined to the conversation named by the
    preference's ``$.val`` JSON field and to that conversation's persona.

    Args:
        page: 1-based page number.
        page_size: Number of sessions per page.
        search_query: Optional case-insensitive substring matched against the
            session id, conversation title and persona id.
        platform: Optional platform name; keeps session ids starting with
            ``"<platform>:"``.

    Returns:
        ``(sessions, total)`` where each session is a dict with the keys
        ``session_id``, ``conversation_id``, ``persona_id``, ``title`` and
        ``persona_name``; ``total`` ignores pagination.
    """

    def _selected_cid():
        # Conversation id stored inside the preference's JSON value.
        return func.json_extract(Preference.value, "$.val")

    def _with_joins_and_filters(stmt):
        # Shared FROM / JOIN / WHERE part of the page and count queries.
        stmt = (
            stmt.select_from(Preference)
            .outerjoin(
                ConversationV2,
                _selected_cid() == ConversationV2.conversation_id,
            )
            .outerjoin(
                Persona,
                col(ConversationV2.persona_id) == Persona.persona_id,
            )
            .where(Preference.scope == "umo", Preference.key == "sel_conv_id")
        )
        # Search filter
        if search_query:
            needle = f"%{search_query}%"
            stmt = stmt.where(
                or_(
                    col(Preference.scope_id).ilike(needle),
                    col(ConversationV2.title).ilike(needle),
                    col(Persona.persona_id).ilike(needle),
                ),
            )
        # Platform filter
        if platform:
            stmt = stmt.where(col(Preference.scope_id).like(f"{platform}:%"))
        return stmt

    async with self.get_db() as session:
        session: AsyncSession
        skip = (page - 1) * page_size

        # Page of results, ordered by session id.
        columns_query = select(
            col(Preference.scope_id).label("session_id"),
            _selected_cid().label("conversation_id"),  # type: ignore
            col(ConversationV2.persona_id).label("persona_id"),
            col(ConversationV2.title).label("title"),
            col(Persona.persona_id).label("persona_name"),
        )
        page_query = (
            _with_joins_and_filters(columns_query)
            .order_by(Preference.scope_id)
            .offset(skip)
            .limit(page_size)
        )
        page_rows = (await session.execute(page_query)).fetchall()

        # Total count with the same filters applied.
        count_query = _with_joins_and_filters(
            select(func.count(col(Preference.scope_id))),
        )
        total = (await session.execute(count_query)).scalar() or 0

        sessions_data = []
        for record in page_rows:
            sessions_data.append(
                {
                    "session_id": record.session_id,
                    "conversation_id": record.conversation_id,
                    "persona_id": record.persona_id,
                    "title": record.title,
                    "persona_name": record.persona_name,
                },
            )
        return sessions_data, total
async def insert_platform_message_history(
    self,
    platform_id,
    user_id,
    content,
    sender_id=None,
    sender_name=None,
):
    """Store one platform message history record and return it."""
    fields = {
        "platform_id": platform_id,
        "user_id": user_id,
        "content": content,
        "sender_id": sender_id,
        "sender_name": sender_name,
    }
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            record = PlatformMessageHistory(**fields)
            session.add(record)
            return record
async def delete_platform_message_offset(
    self,
    platform_id,
    user_id,
    offset_sec=86400,
) -> None:
    """Delete a user's recent message history on a platform.

    Removes records whose ``created_at`` is at or after "now minus
    ``offset_sec`` seconds".

    NOTE(review): the cutoff is built from naive local time; confirm that
    matches how ``created_at`` is stored.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            cutoff_time = datetime.now() - timedelta(seconds=offset_sec)
            removal = delete(PlatformMessageHistory).where(
                col(PlatformMessageHistory.platform_id) == platform_id,
                col(PlatformMessageHistory.user_id) == user_id,
                col(PlatformMessageHistory.created_at) >= cutoff_time,
            )
            await session.execute(removal)
async def get_platform_message_history(
    self,
    platform_id,
    user_id,
    page=1,
    page_size=20,
):
    """Return one page of a user's message history on a platform, newest first.

    Args:
        platform_id: Platform the messages belong to.
        user_id: User the messages belong to.
        page: 1-based page number.
        page_size: Number of records per page.
    """
    skip = (page - 1) * page_size
    history_query = (
        select(PlatformMessageHistory)
        .where(
            PlatformMessageHistory.platform_id == platform_id,
            PlatformMessageHistory.user_id == user_id,
        )
        .order_by(desc(PlatformMessageHistory.created_at))
    )
    paged_query = history_query.offset(skip).limit(page_size)
    async with self.get_db() as session:
        session: AsyncSession
        rows = await session.execute(paged_query)
        return rows.scalars().all()
async def get_platform_message_history_by_id(
    self, message_id: int
) -> PlatformMessageHistory | None:
    """Return the message history record with primary key ``message_id``, or ``None``."""
    lookup = select(PlatformMessageHistory).where(
        PlatformMessageHistory.id == message_id
    )
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(lookup)
        return found.scalar_one_or_none()
async def insert_attachment(self, path, type, mime_type):
    """Store a new attachment record and return it.

    Args:
        path: Value for the attachment's ``path`` field.
        type: Value for the attachment's ``type`` field.
        mime_type: Value for the attachment's ``mime_type`` field.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            attachment = Attachment(path=path, type=type, mime_type=mime_type)
            session.add(attachment)
            return attachment
async def get_attachment_by_id(self, attachment_id):
    """Return the attachment with the given ``attachment_id``, or ``None``."""
    lookup = select(Attachment).where(Attachment.attachment_id == attachment_id)
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(lookup)
        return found.scalar_one_or_none()
async def get_attachments(self, attachment_ids: list[str]) -> list:
    """Return the attachments whose ids appear in ``attachment_ids``.

    An empty id list returns ``[]`` without querying the database.
    """
    if attachment_ids:
        lookup = select(Attachment).where(
            col(Attachment.attachment_id).in_(attachment_ids)
        )
        async with self.get_db() as session:
            session: AsyncSession
            rows = await session.execute(lookup)
            return list(rows.scalars().all())
    return []
async def delete_attachment(self, attachment_id: str) -> bool:
    """Delete one attachment by id.

    Returns:
        ``True`` when a row was deleted, ``False`` when no row matched.
    """
    removal = delete(Attachment).where(
        col(Attachment.attachment_id) == attachment_id
    )
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            outcome = T.cast(CursorResult, await session.execute(removal))
            deleted_any = outcome.rowcount > 0
            return deleted_any
async def delete_attachments(self, attachment_ids: list[str]) -> int:
    """Delete several attachments by id.

    Returns:
        The number of rows deleted; ``0`` right away for an empty id list.
    """
    if attachment_ids:
        removal = delete(Attachment).where(
            col(Attachment.attachment_id).in_(attachment_ids)
        )
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                outcome = T.cast(CursorResult, await session.execute(removal))
                return outcome.rowcount
    return 0
async def create_api_key(
    self,
    name: str,
    key_hash: str,
    key_prefix: str,
    scopes: list[str] | None,
    created_by: str,
    expires_at: datetime | None = None,
) -> ApiKey:
    """Store a new API key record.

    The record is flushed and refreshed inside the transaction before it is
    returned, so values assigned during the insert are loaded.

    Returns:
        The stored ``ApiKey``.
    """
    fields = {
        "name": name,
        "key_hash": key_hash,
        "key_prefix": key_prefix,
        "scopes": scopes,
        "created_by": created_by,
        "expires_at": expires_at,
    }
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            record = ApiKey(**fields)
            session.add(record)
            await session.flush()
            await session.refresh(record)
            return record
async def list_api_keys(self) -> list[ApiKey]:
    """Return every API key, most recently created first."""
    newest_first = select(ApiKey).order_by(desc(ApiKey.created_at))
    async with self.get_db() as session:
        session: AsyncSession
        rows = await session.execute(newest_first)
        return list(rows.scalars().all())
async def get_api_key_by_id(self, key_id: str) -> ApiKey | None:
    """Return the API key with the given ``key_id``, or ``None``."""
    lookup = select(ApiKey).where(ApiKey.key_id == key_id)
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(lookup)
        return found.scalar_one_or_none()
async def get_active_api_key_by_hash(self, key_hash: str) -> ApiKey | None:
    """Return the usable API key matching ``key_hash``, or ``None``.

    A key is usable when it has not been revoked and either has no expiry
    or expires after the current UTC time.
    """
    async with self.get_db() as session:
        session: AsyncSession
        current_time = datetime.now(timezone.utc)
        not_revoked = col(ApiKey.revoked_at).is_(None)
        not_expired = or_(
            col(ApiKey.expires_at).is_(None),
            col(ApiKey.expires_at) > current_time,
        )
        lookup = select(ApiKey).where(
            ApiKey.key_hash == key_hash,
            not_revoked,
            not_expired,
        )
        found = await session.execute(lookup)
        return found.scalar_one_or_none()
async def touch_api_key(self, key_id: str) -> None:
    """Set ``last_used_at`` of the given API key to the current UTC time."""
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            used_at = datetime.now(timezone.utc)
            stmt = (
                update(ApiKey)
                .where(col(ApiKey.key_id) == key_id)
                .values(last_used_at=used_at)
            )
            await session.execute(stmt)
async def revoke_api_key(self, key_id: str) -> bool:
    """Mark an API key as revoked at the current UTC time.

    Returns:
        ``True`` when a row was updated, ``False`` when no key matched.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            revoked_at = datetime.now(timezone.utc)
            stmt = update(ApiKey).where(col(ApiKey.key_id) == key_id)
            stmt = stmt.values(revoked_at=revoked_at)
            outcome = T.cast(CursorResult, await session.execute(stmt))
            return outcome.rowcount > 0
async def delete_api_key(self, key_id: str) -> bool:
    """Delete an API key.

    Returns:
        ``True`` when a row was deleted, ``False`` when no key matched.
    """
    removal = delete(ApiKey).where(col(ApiKey.key_id) == key_id)
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            raw_outcome = await session.execute(removal)
            outcome = T.cast(CursorResult, raw_outcome)
            return outcome.rowcount > 0
async def insert_persona(
    self,
    persona_id,
    system_prompt,
    begin_dialogs=None,
    tools=None,
    skills=None,
    custom_error_message=None,
    folder_id=None,
    sort_order=0,
):
    """Store a new persona and return it.

    A falsy ``begin_dialogs`` is stored as an empty list. The record is
    flushed and refreshed inside the transaction before being returned.
    """
    fields = {
        "persona_id": persona_id,
        "system_prompt": system_prompt,
        "begin_dialogs": begin_dialogs or [],
        "tools": tools,
        "skills": skills,
        "custom_error_message": custom_error_message,
        "folder_id": folder_id,
        "sort_order": sort_order,
    }
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            persona = Persona(**fields)
            session.add(persona)
            await session.flush()
            await session.refresh(persona)
            return persona
async def get_persona_by_id(self, persona_id):
    """Return the persona with the given ``persona_id``, or ``None``."""
    lookup = select(Persona).where(Persona.persona_id == persona_id)
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(lookup)
        return found.scalar_one_or_none()
async def get_personas(self):
    """Return every stored persona (no filtering or ordering is applied)."""
    everything = select(Persona)
    async with self.get_db() as session:
        session: AsyncSession
        rows = await session.execute(everything)
        return rows.scalars().all()
async def update_persona(
    self,
    persona_id,
    system_prompt=None,
    begin_dialogs=None,
    tools=NOT_GIVEN,
    skills=NOT_GIVEN,
    custom_error_message=NOT_GIVEN,
):
    """Update selected fields of a persona.

    ``system_prompt`` and ``begin_dialogs`` are skipped when ``None``.
    ``tools``, ``skills`` and ``custom_error_message`` are skipped only when
    left as ``NOT_GIVEN``, so they can be explicitly set to ``None``.

    Returns:
        The persona re-read after the update, or ``None`` when no field was
        supplied.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            # (column, new value, marker meaning "leave unchanged")
            candidates = (
                ("system_prompt", system_prompt, None),
                ("begin_dialogs", begin_dialogs, None),
                ("tools", tools, NOT_GIVEN),
                ("skills", skills, NOT_GIVEN),
                ("custom_error_message", custom_error_message, NOT_GIVEN),
            )
            changes = {
                column: value
                for column, value, unset in candidates
                if value is not unset
            }
            if not changes:
                return None
            stmt = (
                update(Persona)
                .where(col(Persona.persona_id) == persona_id)
                .values(**changes)
            )
            await session.execute(stmt)
    return await self.get_persona_by_id(persona_id)
async def delete_persona(self, persona_id) -> None:
    """Delete the persona with the given ``persona_id``."""
    removal = delete(Persona).where(col(Persona.persona_id) == persona_id)
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(removal)
# ====
# Persona Folder Management
# ====
async def insert_persona_folder(
    self,
    name: str,
    parent_id: str | None = None,
    description: str | None = None,
    sort_order: int = 0,
) -> PersonaFolder:
    """Store a new persona folder and return it.

    The record is flushed and refreshed inside the transaction before being
    returned.
    """
    fields = {
        "name": name,
        "parent_id": parent_id,
        "description": description,
        "sort_order": sort_order,
    }
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            folder = PersonaFolder(**fields)
            session.add(folder)
            await session.flush()
            await session.refresh(folder)
            return folder
async def get_persona_folder_by_id(self, folder_id: str) -> PersonaFolder | None:
    """Return the persona folder with the given ``folder_id``, or ``None``."""
    lookup = select(PersonaFolder).where(PersonaFolder.folder_id == folder_id)
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(lookup)
        return found.scalar_one_or_none()
async def get_persona_folders(
    self, parent_id: str | None = None
) -> list[PersonaFolder]:
    """Return the folders directly under ``parent_id``.

    Args:
        parent_id: ``None`` selects root folders (``parent_id`` IS NULL);
            otherwise the children of that folder are returned.

    Returns:
        Folders ordered by ``sort_order`` and then ``name``.
    """
    if parent_id is None:
        # Root folders have no parent.
        parent_filter = col(PersonaFolder.parent_id).is_(None)
    else:
        parent_filter = PersonaFolder.parent_id == parent_id
    folder_query = (
        select(PersonaFolder)
        .where(parent_filter)
        .order_by(col(PersonaFolder.sort_order), col(PersonaFolder.name))
    )
    async with self.get_db() as session:
        session: AsyncSession
        rows = await session.execute(folder_query)
        return list(rows.scalars().all())
async def get_all_persona_folders(self) -> list[PersonaFolder]:
    """Return every persona folder, ordered by ``sort_order`` then ``name``."""
    ordering = (col(PersonaFolder.sort_order), col(PersonaFolder.name))
    folder_query = select(PersonaFolder).order_by(*ordering)
    async with self.get_db() as session:
        session: AsyncSession
        rows = await session.execute(folder_query)
        return list(rows.scalars().all())
async def update_persona_folder(
    self,
    folder_id: str,
    name: str | None = None,
    parent_id: T.Any = NOT_GIVEN,
    description: T.Any = NOT_GIVEN,
    sort_order: int | None = None,
) -> PersonaFolder | None:
    """Update selected fields of a persona folder.

    ``name`` and ``sort_order`` are skipped when ``None``. ``parent_id`` and
    ``description`` are skipped only when left as ``NOT_GIVEN``, so they can
    be explicitly set to ``None``.

    Returns:
        The folder re-read after the update, or ``None`` when no field was
        supplied.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            # (column, new value, marker meaning "leave unchanged")
            candidates = (
                ("name", name, None),
                ("parent_id", parent_id, NOT_GIVEN),
                ("description", description, NOT_GIVEN),
                ("sort_order", sort_order, None),
            )
            changes: dict[str, T.Any] = {
                column: value
                for column, value, unset in candidates
                if value is not unset
            }
            if not changes:
                return None
            stmt = (
                update(PersonaFolder)
                .where(col(PersonaFolder.folder_id) == folder_id)
                .values(**changes)
            )
            await session.execute(stmt)
    return await self.get_persona_folder_by_id(folder_id)
async def delete_persona_folder(self, folder_id: str) -> None:
    """Delete a persona folder and move its personas to the root.

    In one transaction, personas whose ``folder_id`` matches are set to
    ``NULL`` and then the folder row is deleted.
    """
    detach_personas = (
        update(Persona)
        .where(col(Persona.folder_id) == folder_id)
        .values(folder_id=None)
    )
    drop_folder = delete(PersonaFolder).where(
        col(PersonaFolder.folder_id) == folder_id
    )
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            # Detach the personas first, then remove the folder itself.
            for stmt in (detach_personas, drop_folder):
                await session.execute(stmt)
async def move_persona_to_folder(
    self, persona_id: str, folder_id: str | None
) -> Persona | None:
    """Move a persona into a folder, or to the root when ``folder_id`` is ``None``.

    Returns:
        The persona re-read after the update, or ``None`` if it does not exist.
    """
    relocation = (
        update(Persona)
        .where(col(Persona.persona_id) == persona_id)
        .values(folder_id=folder_id)
    )
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(relocation)
    return await self.get_persona_by_id(persona_id)
async def get_personas_by_folder(
    self, folder_id: str | None = None
) -> list[Persona]:
    """Return the personas stored in one folder.

    Args:
        folder_id: ``None`` selects personas in the root directory
            (``folder_id`` IS NULL).

    Returns:
        Personas ordered by ``sort_order`` and then ``persona_id``.
    """
    if folder_id is None:
        folder_filter = col(Persona.folder_id).is_(None)
    else:
        folder_filter = Persona.folder_id == folder_id
    persona_query = (
        select(Persona)
        .where(folder_filter)
        .order_by(col(Persona.sort_order), col(Persona.persona_id))
    )
    async with self.get_db() as session:
        session: AsyncSession
        rows = await session.execute(persona_query)
        return list(rows.scalars().all())
async def batch_update_sort_order(
    self,
    items: list[dict],
) -> None:
    """Apply new ``sort_order`` values to personas and/or folders in one transaction.

    Args:
        items: Dicts with the keys ``id`` (persona_id or folder_id), ``type``
            (``"persona"`` or ``"folder"``) and ``sort_order``. Entries that
            lack any of the three, or have another ``type``, are ignored.
    """
    if not items:
        return
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            for entry in items:
                ident = entry.get("id")
                kind = entry.get("type")
                position = entry.get("sort_order")
                if any(part is None for part in (ident, kind, position)):
                    continue
                if kind == "persona":
                    model, key_column = Persona, Persona.persona_id
                elif kind == "folder":
                    model, key_column = PersonaFolder, PersonaFolder.folder_id
                else:
                    continue
                await session.execute(
                    update(model)
                    .where(col(key_column) == ident)
                    .values(sort_order=position)
                )
async def insert_preference_or_update(self, scope, scope_id, key, value):
    """Set a preference value, creating the record when it does not exist.

    Returns:
        The updated existing ``Preference`` or the newly added one.
    """
    lookup = select(Preference).where(
        Preference.scope == scope,
        Preference.scope_id == scope_id,
        Preference.key == key,
    )
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            preference = (await session.execute(lookup)).scalar_one_or_none()
            if preference:
                preference.value = value
                return preference
            preference = Preference(
                scope=scope,
                scope_id=scope_id,
                key=key,
                value=value,
            )
            session.add(preference)
            return preference
async def get_preference(self, scope, scope_id, key):
    """Return the preference identified by ``(scope, scope_id, key)``, or ``None``."""
    lookup = select(Preference).where(
        Preference.scope == scope,
        Preference.scope_id == scope_id,
        Preference.key == key,
    )
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(lookup)
        return found.scalar_one_or_none()
async def get_preferences(self, scope, scope_id=None, key=None):
    """Return the preferences of a scope, optionally narrowed.

    Args:
        scope: Scope to list.
        scope_id: When not ``None``, only preferences of this scope id.
        key: When not ``None``, only preferences with this key.
    """
    conditions = [Preference.scope == scope]
    if scope_id is not None:
        conditions.append(Preference.scope_id == scope_id)
    if key is not None:
        conditions.append(Preference.key == key)
    listing = select(Preference)
    for condition in conditions:
        listing = listing.where(condition)
    async with self.get_db() as session:
        session: AsyncSession
        rows = await session.execute(listing)
        return rows.scalars().all()
async def remove_preference(self, scope, scope_id, key) -> None:
    """Remove the preference identified by ``(scope, scope_id, key)``.

    The ``session.begin()`` block commits on successful exit, so no explicit
    commit is issued — the same pattern the other delete methods here use.
    """
    removal = delete(Preference).where(
        col(Preference.scope) == scope,
        col(Preference.scope_id) == scope_id,
        col(Preference.key) == key,
    )
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(removal)
async def clear_preferences(self, scope, scope_id) -> None:
    """Remove every preference stored for ``(scope, scope_id)``.

    The ``session.begin()`` block commits on successful exit, so no explicit
    commit is issued — the same pattern the other delete methods here use.
    """
    removal = delete(Preference).where(
        col(Preference.scope) == scope,
        col(Preference.scope_id) == scope_id,
    )
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(removal)
# ====
# Command Configuration & Conflict Tracking
# ====
async def _run_in_tx(
    self,
    fn: Callable[[AsyncSession], Awaitable[TxResult]],
) -> TxResult:
    """Run ``fn`` with a session inside a ``session.begin()`` transaction.

    Args:
        fn: Coroutine function that receives the open session.

    Returns:
        Whatever ``fn`` returns.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            outcome = await fn(session)
            return outcome
@staticmethod
def _apply_updates(model, **updates) -> None:
for field, value in updates.items():
if value is not None:
setattr(model, field, value)
@staticmethod
def _new_command_config(
    handler_full_name: str,
    plugin_name: str,
    module_path: str,
    original_command: str,
    *,
    resolved_command: str | None = None,
    enabled: bool | None = None,
    keep_original_alias: bool | None = None,
    conflict_key: str | None = None,
    resolution_strategy: str | None = None,
    note: str | None = None,
    extra_data: dict | None = None,
    auto_managed: bool | None = None,
) -> CommandConfig:
    """Build an unsaved ``CommandConfig`` with defaults filled in.

    Defaults applied here: ``enabled`` is ``True`` when ``None``;
    ``keep_original_alias`` is ``False`` when ``None``; a falsy
    ``conflict_key`` falls back to ``original_command``; ``auto_managed`` is
    coerced with ``bool()``.
    """
    if enabled is None:
        enabled = True
    if keep_original_alias is None:
        keep_original_alias = False
    effective_conflict_key = conflict_key or original_command
    return CommandConfig(
        handler_full_name=handler_full_name,
        plugin_name=plugin_name,
        module_path=module_path,
        original_command=original_command,
        resolved_command=resolved_command,
        enabled=enabled,
        keep_original_alias=keep_original_alias,
        conflict_key=effective_conflict_key,
        resolution_strategy=resolution_strategy,
        note=note,
        extra_data=extra_data,
        auto_managed=bool(auto_managed),
    )
@staticmethod
def _new_command_conflict(
    conflict_key: str,
    handler_full_name: str,
    plugin_name: str,
    *,
    status: str | None = None,
    resolution: str | None = None,
    resolved_command: str | None = None,
    note: str | None = None,
    extra_data: dict | None = None,
    auto_generated: bool | None = None,
) -> CommandConflict:
    """Build an unsaved ``CommandConflict`` with defaults filled in.

    A falsy ``status`` becomes ``"pending"`` and ``auto_generated`` is
    coerced with ``bool()``.
    """
    effective_status = status if status else "pending"
    fields = {
        "conflict_key": conflict_key,
        "handler_full_name": handler_full_name,
        "plugin_name": plugin_name,
        "status": effective_status,
        "resolution": resolution,
        "resolved_command": resolved_command,
        "note": note,
        "extra_data": extra_data,
        "auto_generated": bool(auto_generated),
    }
    return CommandConflict(**fields)
async def get_command_configs(self) -> list[CommandConfig]:
    """Return every stored command configuration."""
    everything = select(CommandConfig)
    async with self.get_db() as session:
        session: AsyncSession
        rows = await session.execute(everything)
        return list(rows.scalars().all())
async def get_command_config(
    self,
    handler_full_name: str,
) -> CommandConfig | None:
    """Return the command configuration keyed by ``handler_full_name``, or ``None``."""
    async with self.get_db() as session:
        session: AsyncSession
        config = await session.get(CommandConfig, handler_full_name)
        return config
async def upsert_command_config(
    self,
    handler_full_name: str,
    plugin_name: str,
    module_path: str,
    original_command: str,
    *,
    resolved_command: str | None = None,
    enabled: bool | None = None,
    keep_original_alias: bool | None = None,
    conflict_key: str | None = None,
    resolution_strategy: str | None = None,
    note: str | None = None,
    extra_data: dict | None = None,
    auto_managed: bool | None = None,
) -> CommandConfig:
    """Create or update the command configuration of a handler.

    A missing record is built through ``_new_command_config``. On an existing
    record only the non-``None`` arguments are written, via
    ``_apply_updates``.

    Returns:
        The stored ``CommandConfig`` after flush and refresh.
    """
    optional_fields = {
        "resolved_command": resolved_command,
        "enabled": enabled,
        "keep_original_alias": keep_original_alias,
        "conflict_key": conflict_key,
        "resolution_strategy": resolution_strategy,
        "note": note,
        "extra_data": extra_data,
        "auto_managed": auto_managed,
    }

    async def _write(session: AsyncSession) -> CommandConfig:
        record = await session.get(CommandConfig, handler_full_name)
        if record:
            self._apply_updates(
                record,
                plugin_name=plugin_name,
                module_path=module_path,
                original_command=original_command,
                **optional_fields,
            )
        else:
            record = self._new_command_config(
                handler_full_name,
                plugin_name,
                module_path,
                original_command,
                **optional_fields,
            )
            session.add(record)
        await session.flush()
        await session.refresh(record)
        return record

    return await self._run_in_tx(_write)
async def delete_command_config(self, handler_full_name: str) -> None:
    """Delete one command config by delegating to ``delete_command_configs``."""
    single_target = [handler_full_name]
    await self.delete_command_configs(single_target)
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
    """Delete every ``CommandConfig`` whose ``handler_full_name`` is listed.

    An empty (falsy) list is a no-op: no transaction is started.
    """
    if handler_full_names:

        async def _purge(tx: AsyncSession) -> None:
            targeted = col(CommandConfig.handler_full_name).in_(handler_full_names)
            await tx.execute(delete(CommandConfig).where(targeted))

        await self._run_in_tx(_purge)
async def list_command_conflicts(
    self,
    status: str | None = None,
) -> list[CommandConflict]:
    """List ``CommandConflict`` rows, optionally restricted to one status.

    A falsy ``status`` (``None`` or an empty string) applies no filter.
    """
    async with self.get_db() as db:
        stmt = select(CommandConflict)
        stmt = stmt.where(CommandConflict.status == status) if status else stmt
        rows = await db.execute(stmt)
        return list(rows.scalars().all())
async def upsert_command_conflict(
    self,
    conflict_key: str,
    handler_full_name: str,
    plugin_name: str,
    *,
    status: str | None = None,
    resolution: str | None = None,
    resolved_command: str | None = None,
    note: str | None = None,
    extra_data: dict | None = None,
    auto_generated: bool | None = None,
) -> CommandConflict:
    """Insert or update the conflict for ``(conflict_key, handler_full_name)``.

    A missing record is created with ``_new_command_conflict`` and added to
    the session; an existing one receives the supplied fields through
    ``_apply_updates``. The session is flushed and the record refreshed
    before it is returned. The work runs through ``_run_in_tx``.
    """
    # Keyword-only fields are forwarded unchanged to whichever helper runs.
    optional_fields = {
        "status": status,
        "resolution": resolution,
        "resolved_command": resolved_command,
        "note": note,
        "extra_data": extra_data,
        "auto_generated": auto_generated,
    }

    async def _upsert(tx: AsyncSession) -> CommandConflict:
        lookup = select(CommandConflict).where(
            CommandConflict.conflict_key == conflict_key,
            CommandConflict.handler_full_name == handler_full_name,
        )
        found = await tx.execute(lookup)
        row = found.scalar_one_or_none()
        if row:
            self._apply_updates(row, plugin_name=plugin_name, **optional_fields)
        else:
            row = self._new_command_conflict(
                conflict_key,
                handler_full_name,
                plugin_name,
                **optional_fields,
            )
            tx.add(row)
        await tx.flush()
        await tx.refresh(row)
        return row

    return await self._run_in_tx(_upsert)
async def delete_command_conflicts(self, ids: list[int]) -> None:
    """Delete the ``CommandConflict`` rows whose ``id`` is in ``ids``.

    An empty (falsy) list is a no-op: no transaction is started.
    """
    if ids:

        async def _purge(tx: AsyncSession) -> None:
            targeted = col(CommandConflict.id).in_(ids)
            await tx.execute(delete(CommandConflict).where(targeted))

        await self._run_in_tx(_purge)
# ====
# Deprecated Methods
# ====
def get_base_stats(self, offset_sec=86400):
    """Get base statistics within the specified offset in seconds.

    Synchronous entry point: the query coroutine is executed with
    ``asyncio.run`` on a helper thread, and this call blocks until that
    thread has finished.

    Args:
        offset_sec: Size of the look-back window in seconds, counted back
            from ``datetime.now()``.

    Returns:
        A ``DeprecatedStats`` whose ``platform`` list holds one
        ``DeprecatedPlatformStat`` per ``PlatformStat`` row in the window.

    Raises:
        Exception: Any error raised while running the query is re-raised in
            the calling thread.
    """

    async def _inner():
        async with self.get_db() as session:
            session: AsyncSession
            now = datetime.now()
            start_time = now - timedelta(seconds=offset_sec)
            result = await session.execute(
                select(PlatformStat).where(PlatformStat.timestamp >= start_time),
            )
            all_datas = result.scalars().all()
            deprecated_stats = DeprecatedStats()
            for data in all_datas:
                deprecated_stats.platform.append(
                    DeprecatedPlatformStat(
                        name=data.platform_id,
                        count=data.count,
                        timestamp=int(data.timestamp.timestamp()),
                    ),
                )
            return deprecated_stats

    result = None
    error: Exception | None = None

    def runner() -> None:
        nonlocal result, error
        try:
            result = asyncio.run(_inner())
        except Exception as exc:
            # Thread.join() does not propagate exceptions, so keep the error
            # and re-raise it in the caller instead of returning None.
            error = exc

    t = threading.Thread(target=runner)
    t.start()
    t.join()
    if error is not None:
        raise error
    return result
def get_total_message_count(self):
    """Get the total message count from platform statistics.

    Synchronous entry point: the query coroutine is executed with
    ``asyncio.run`` on a helper thread, and this call blocks until that
    thread has finished.

    Returns:
        The sum of ``PlatformStat.count`` over all rows, or ``0`` when the
        query yields ``None``.

    Raises:
        Exception: Any error raised while running the query is re-raised in
            the calling thread.
    """

    async def _inner():
        async with self.get_db() as session:
            session: AsyncSession
            result = await session.execute(
                select(func.sum(PlatformStat.count)).select_from(PlatformStat),
            )
            total_count = result.scalar_one_or_none()
            return total_count if total_count is not None else 0

    result = None
    error: Exception | None = None

    def runner() -> None:
        nonlocal result, error
        try:
            result = asyncio.run(_inner())
        except Exception as exc:
            # Thread.join() does not propagate exceptions, so keep the error
            # and re-raise it in the caller instead of returning None.
            error = exc

    t = threading.Thread(target=runner)
    t.start()
    t.join()
    if error is not None:
        raise error
    return result
def get_grouped_base_stats(self, offset_sec=86400):
    """Get statistics within the offset, summed per ``platform_id``.

    Synchronous entry point: the query coroutine is executed with
    ``asyncio.run`` on a helper thread, and this call blocks until that
    thread has finished.

    Args:
        offset_sec: Size of the look-back window in seconds, counted back
            from ``datetime.now()``.

    Returns:
        A ``DeprecatedStats`` whose ``platform`` list holds one
        ``DeprecatedPlatformStat`` per platform. Each entry carries the
        summed count and uses the window start as its ``timestamp``.

    Raises:
        Exception: Any error raised while running the query is re-raised in
            the calling thread.
    """

    # group by platform_id
    async def _inner():
        async with self.get_db() as session:
            session: AsyncSession
            now = datetime.now()
            start_time = now - timedelta(seconds=offset_sec)
            result = await session.execute(
                select(PlatformStat.platform_id, func.sum(PlatformStat.count))
                .where(PlatformStat.timestamp >= start_time)
                .group_by(PlatformStat.platform_id),
            )
            grouped_stats = result.all()
            deprecated_stats = DeprecatedStats()
            for platform_id, count in grouped_stats:
                deprecated_stats.platform.append(
                    DeprecatedPlatformStat(
                        name=platform_id,
                        count=count,
                        timestamp=int(start_time.timestamp()),
                    ),
                )
            return deprecated_stats

    result = None
    error: Exception | None = None

    def runner() -> None:
        nonlocal result, error
        try:
            result = asyncio.run(_inner())
        except Exception as exc:
            # Thread.join() does not propagate exceptions, so keep the error
            # and re-raise it in the caller instead of returning None.
            error = exc

    t = threading.Thread(target=runner)
    t.start()
    t.join()
    if error is not None:
        raise error
    return result
# ====
# Platform Session Management
# ====
async def create_platform_session(
    self,
    creator: str,
    platform_id: str = "webchat",
    session_id: str | None = None,
    display_name: str | None = None,
    is_group: int = 0,
) -> PlatformSession:
    """Persist a new ``PlatformSession`` and return the refreshed row.

    ``session_id`` is only passed to the model when it is truthy; otherwise
    the model is constructed without it.
    """
    # Omit session_id entirely when it was not provided.
    explicit_id = {"session_id": session_id} if session_id else {}
    async with self.get_db() as db, db.begin():
        created = PlatformSession(
            creator=creator,
            platform_id=platform_id,
            display_name=display_name,
            is_group=is_group,
            **explicit_id,
        )
        db.add(created)
        await db.flush()
        await db.refresh(created)
        return created
async def get_platform_session_by_id(
    self, session_id: str
) -> PlatformSession | None:
    """Fetch the ``PlatformSession`` whose ``session_id`` matches, or ``None``."""
    async with self.get_db() as db:
        matches_id = PlatformSession.session_id == session_id
        rows = await db.execute(select(PlatformSession).where(matches_id))
        return rows.scalar_one_or_none()
async def get_platform_sessions_by_creator(
    self,
    creator: str,
    platform_id: str | None = None,
    page: int = 1,
    page_size: int = 20,
) -> list[dict]:
    """Return one page of Platform sessions for a creator, project info included.

    Thin wrapper over ``get_platform_sessions_by_creator_paginated`` with
    ``exclude_project_sessions=False``; the total count it reports is
    discarded and only the list of session dicts is returned.
    """
    page_rows, _total = await self.get_platform_sessions_by_creator_paginated(
        creator=creator,
        platform_id=platform_id,
        page=page,
        page_size=page_size,
        exclude_project_sessions=False,
    )
    return page_rows
@staticmethod
def _build_platform_sessions_query(
    creator: str,
    platform_id: str | None = None,
    exclude_project_sessions: bool = False,
):
    """Build the SELECT that pairs each session with its optional project.

    The statement selects ``PlatformSession`` plus the project's id, title
    (labelled ``project_title``) and emoji (labelled ``project_emoji``). It
    outer-joins through ``SessionProjectRelation`` to ``ChatUIProject`` and
    filters on ``creator``. A truthy ``platform_id`` adds a platform filter.
    ``exclude_project_sessions`` keeps only rows with no project id.
    """
    selected = select(
        PlatformSession,
        col(ChatUIProject.project_id),
        col(ChatUIProject.title).label("project_title"),
        col(ChatUIProject.emoji).label("project_emoji"),
    )
    with_relation = selected.outerjoin(
        SessionProjectRelation,
        col(PlatformSession.session_id) == col(SessionProjectRelation.session_id),
    )
    with_project = with_relation.outerjoin(
        ChatUIProject,
        col(SessionProjectRelation.project_id) == col(ChatUIProject.project_id),
    )
    stmt = with_project.where(col(PlatformSession.creator) == creator)
    if platform_id:
        stmt = stmt.where(PlatformSession.platform_id == platform_id)
    if exclude_project_sessions:
        stmt = stmt.where(col(ChatUIProject.project_id).is_(None))
    return stmt
@staticmethod
def _rows_to_session_dicts(rows: T.Sequence[Row[tuple]]) -> list[dict]:
sessions_with_projects = []
for row in rows:
platform_session = row[0]
project_id = row[1]
project_title = row[2]
project_emoji = row[3]
session_dict = {
"session": platform_session,
"project_id": project_id,
"project_title": project_title,
"project_emoji": project_emoji,
}
sessions_with_projects.append(session_dict)
return sessions_with_projects
async def get_platform_sessions_by_creator_paginated(
    self,
    creator: str,
    platform_id: str | None = None,
    page: int = 1,
    page_size: int = 20,
    exclude_project_sessions: bool = False,
) -> tuple[list[dict], int]:
    """Return one page of a creator's sessions together with the total count.

    The count is taken over the filtered query before paging. The page is
    ordered by ``PlatformSession.updated_at`` descending, skipping
    ``(page - 1) * page_size`` rows and returning at most ``page_size``.

    Returns:
        ``(session_dicts, total)``, where ``session_dicts`` comes from
        ``_rows_to_session_dicts``.
    """
    async with self.get_db() as db:
        skip = (page - 1) * page_size
        filtered = self._build_platform_sessions_query(
            creator=creator,
            platform_id=platform_id,
            exclude_project_sessions=exclude_project_sessions,
        )

        # Count all matching rows before offset/limit are applied.
        count_stmt = select(func.count()).select_from(filtered.subquery())
        counted = await db.execute(count_stmt)
        total = int(counted.scalar_one() or 0)

        newest_first = filtered.order_by(desc(PlatformSession.updated_at))
        page_stmt = newest_first.offset(skip).limit(page_size)
        fetched = await db.execute(page_stmt)
        return self._rows_to_session_dicts(fetched.all()), total
async def update_platform_session(
    self,
    session_id: str,
    display_name: str | None = None,
) -> None:
    """Set a Platform session's ``updated_at`` to now, and optionally rename it.

    ``display_name`` is only written when it is not ``None``.
    """
    async with self.get_db() as db, db.begin():
        changes = {"updated_at": datetime.now(timezone.utc)}
        if display_name is not None:
            changes["display_name"] = display_name
        matches_id = col(PlatformSession.session_id) == session_id
        await db.execute(update(PlatformSession).where(matches_id).values(**changes))
async def delete_platform_session(self, session_id: str) -> None:
    """Delete the ``PlatformSession`` row with the given ``session_id``."""
    async with self.get_db() as db, db.begin():
        matches_id = col(PlatformSession.session_id) == session_id
        await db.execute(delete(PlatformSession).where(matches_id))
# ====
# ChatUI Project Management
# ====
async def create_chatui_project(
    self,
    creator: str,
    title: str,
    emoji: str | None = "📁",
    description: str | None = None,
) -> ChatUIProject:
    """Persist a new ``ChatUIProject`` and return the refreshed row."""
    async with self.get_db() as db, db.begin():
        created = ChatUIProject(
            creator=creator,
            title=title,
            emoji=emoji,
            description=description,
        )
        db.add(created)
        await db.flush()
        await db.refresh(created)
        return created
async def get_chatui_project_by_id(self, project_id: str) -> ChatUIProject | None:
    """Fetch the ``ChatUIProject`` whose ``project_id`` matches, or ``None``."""
    async with self.get_db() as db:
        matches_id = col(ChatUIProject.project_id) == project_id
        rows = await db.execute(select(ChatUIProject).where(matches_id))
        return rows.scalar_one_or_none()
async def get_chatui_projects_by_creator(
    self,
    creator: str,
    page: int = 1,
    page_size: int = 100,
) -> list[ChatUIProject]:
    """Return one page of a creator's ChatUI projects, newest update first.

    The page skips ``(page - 1) * page_size`` rows and holds at most
    ``page_size`` projects.
    """
    async with self.get_db() as db:
        skip = (page - 1) * page_size
        owned = select(ChatUIProject).where(col(ChatUIProject.creator) == creator)
        newest_first = owned.order_by(desc(ChatUIProject.updated_at))
        rows = await db.execute(newest_first.limit(page_size).offset(skip))
        return list(rows.scalars().all())
async def update_chatui_project(
    self,
    project_id: str,
    title: str | None = None,
    emoji: str | None = None,
    description: str | None = None,
) -> None:
    """Update a ChatUI project's ``updated_at`` and any non-``None`` fields.

    ``title``, ``emoji`` and ``description`` are written only when the
    argument is not ``None``; ``updated_at`` is always set to the current
    UTC time.
    """
    async with self.get_db() as db, db.begin():
        changes = {"updated_at": datetime.now(timezone.utc)}
        provided = (
            ("title", title),
            ("emoji", emoji),
            ("description", description),
        )
        for column, value in provided:
            if value is not None:
                changes[column] = value
        matches_id = col(ChatUIProject.project_id) == project_id
        await db.execute(update(ChatUIProject).where(matches_id).values(**changes))
async def delete_chatui_project(self, project_id: str) -> None:
    """Delete a ChatUI project and its session relations in one transaction."""
    async with self.get_db() as db, db.begin():
        # Relations go first, then the project row itself.
        drop_relations = delete(SessionProjectRelation).where(
            col(SessionProjectRelation.project_id) == project_id,
        )
        await db.execute(drop_relations)

        drop_project = delete(ChatUIProject).where(
            col(ChatUIProject.project_id) == project_id,
        )
        await db.execute(drop_project)
async def add_session_to_project(
    self,
    session_id: str,
    project_id: str,
) -> SessionProjectRelation:
    """Attach a session to a project, replacing any relation it already had.

    Existing ``SessionProjectRelation`` rows for ``session_id`` are deleted
    first, then a new relation is inserted, refreshed and returned.
    """
    async with self.get_db() as db, db.begin():
        # Clear any previous relation for this session before linking.
        drop_existing = delete(SessionProjectRelation).where(
            col(SessionProjectRelation.session_id) == session_id,
        )
        await db.execute(drop_existing)

        link = SessionProjectRelation(session_id=session_id, project_id=project_id)
        db.add(link)
        await db.flush()
        await db.refresh(link)
        return link
async def remove_session_from_project(self, session_id: str) -> None:
    """Delete every ``SessionProjectRelation`` row for ``session_id``."""
    async with self.get_db() as db, db.begin():
        matches_session = col(SessionProjectRelation.session_id) == session_id
        await db.execute(delete(SessionProjectRelation).where(matches_session))
async def get_project_sessions(
    self,
    project_id: str,
    page: int = 1,
    page_size: int = 100,
) -> list[PlatformSession]:
    """Return one page of the sessions linked to a project, newest update first.

    Sessions are joined to ``SessionProjectRelation`` on ``session_id`` and
    filtered by ``project_id``. The page skips ``(page - 1) * page_size``
    rows and holds at most ``page_size`` sessions.
    """
    async with self.get_db() as db:
        skip = (page - 1) * page_size
        same_session = col(PlatformSession.session_id) == col(
            SessionProjectRelation.session_id
        )
        linked = (
            select(PlatformSession)
            .join(SessionProjectRelation, same_session)
            .where(col(SessionProjectRelation.project_id) == project_id)
        )
        newest_first = linked.order_by(desc(PlatformSession.updated_at))
        rows = await db.execute(newest_first.limit(page_size).offset(skip))
        return list(rows.scalars().all())
async def get_project_by_session(
    self, session_id: str, creator: str
) -> ChatUIProject | None:
    """Return the project a session is linked to, limited to ``creator``'s projects.

    Yields ``None`` when no ``ChatUIProject`` owned by ``creator`` is related
    to ``session_id``.
    """
    async with self.get_db() as db:
        same_project = col(ChatUIProject.project_id) == col(
            SessionProjectRelation.project_id
        )
        stmt = (
            select(ChatUIProject)
            .join(SessionProjectRelation, same_project)
            .where(
                col(SessionProjectRelation.session_id) == session_id,
                col(ChatUIProject.creator) == creator,
            )
        )
        rows = await db.execute(stmt)
        return rows.scalar_one_or_none()
# ====
# Cron Job Management
# ====
async def create_cron_job(
    self,
    name: str,
    job_type: str,
    cron_expression: str | None,
    *,
    timezone: str | None = None,
    payload: dict | None = None,
    description: str | None = None,
    enabled: bool = True,
    persistent: bool = True,
    run_once: bool = False,
    status: str | None = None,
    job_id: str | None = None,
) -> CronJob:
    """Persist a new ``CronJob`` and return the refreshed row.

    A falsy ``payload`` is stored as ``{}`` and a falsy ``status`` as
    ``"scheduled"``. ``job_id`` is assigned to the job only when it is
    truthy.
    """
    async with self.get_db() as db, db.begin():
        created = CronJob(
            name=name,
            job_type=job_type,
            cron_expression=cron_expression,
            timezone=timezone,
            payload=payload or {},
            description=description,
            enabled=enabled,
            persistent=persistent,
            run_once=run_once,
            status=status or "scheduled",
        )
        if job_id:
            # Caller-supplied id replaces whatever the model set on construction.
            created.job_id = job_id
        db.add(created)
        await db.flush()
        await db.refresh(created)
        return created
async def update_cron_job(
    self,
    job_id: str,
    *,
    name: str | None | object = CRON_FIELD_NOT_SET,
    cron_expression: str | None | object = CRON_FIELD_NOT_SET,
    timezone: str | None | object = CRON_FIELD_NOT_SET,
    payload: dict | None | object = CRON_FIELD_NOT_SET,
    description: str | None | object = CRON_FIELD_NOT_SET,
    enabled: bool | None | object = CRON_FIELD_NOT_SET,
    persistent: bool | None | object = CRON_FIELD_NOT_SET,
    run_once: bool | None | object = CRON_FIELD_NOT_SET,
    status: str | None | object = CRON_FIELD_NOT_SET,
    next_run_time: datetime | None | object = CRON_FIELD_NOT_SET,
    last_run_at: datetime | None | object = CRON_FIELD_NOT_SET,
    last_error: str | None | object = CRON_FIELD_NOT_SET,
) -> CronJob | None:
    """Update the given fields of a cron job and return the row afterwards.

    Every keyword defaults to the ``CRON_FIELD_NOT_SET`` sentinel, and only
    arguments that differ from it are written. Passing ``None`` explicitly
    therefore stores ``None``.

    Returns:
        The ``CronJob`` selected by ``job_id`` after the update, or ``None``
        when no such row exists.
    """
    async with self.get_db() as db, db.begin():
        candidates = {
            "name": name,
            "cron_expression": cron_expression,
            "timezone": timezone,
            "payload": payload,
            "description": description,
            "enabled": enabled,
            "persistent": persistent,
            "run_once": run_once,
            "status": status,
            "next_run_time": next_run_time,
            "last_run_at": last_run_at,
            "last_error": last_error,
        }
        # Keep only the fields the caller actually supplied.
        changes: dict = {
            column: value
            for column, value in candidates.items()
            if value is not CRON_FIELD_NOT_SET
        }
        matches_job = col(CronJob.job_id) == job_id
        write = (
            update(CronJob)
            .where(matches_job)
            .values(**changes)
            .execution_options(synchronize_session="fetch")
        )
        await db.execute(write)
        reread = await db.execute(
            select(CronJob).where(col(CronJob.job_id) == job_id)
        )
        return reread.scalar_one_or_none()
async def delete_cron_job(self, job_id: str) -> None:
    """Delete the ``CronJob`` row with the given ``job_id``."""
    async with self.get_db() as db, db.begin():
        matches_job = col(CronJob.job_id) == job_id
        await db.execute(delete(CronJob).where(matches_job))
async def get_cron_job(self, job_id: str) -> CronJob | None:
    """Fetch the ``CronJob`` whose ``job_id`` matches, or ``None``."""
    async with self.get_db() as db:
        matches_job = col(CronJob.job_id) == job_id
        rows = await db.execute(select(CronJob).where(matches_job))
        return rows.scalar_one_or_none()
async def list_cron_jobs(self, job_type: str | None = None) -> list[CronJob]:
    """List cron jobs ordered by ``created_at`` descending.

    A truthy ``job_type`` restricts the list to that type; ``None`` or an
    empty string applies no filter.
    """
    async with self.get_db() as db:
        stmt = select(CronJob)
        stmt = stmt.where(col(CronJob.job_type) == job_type) if job_type else stmt
        rows = await db.execute(stmt.order_by(desc(CronJob.created_at)))
        return list(rows.scalars().all())