| import asyncio |
| import hashlib |
| import threading |
| import typing as T |
| import uuid |
| from collections.abc import Awaitable, Callable |
| from datetime import datetime, timedelta, timezone |
|
|
| from sqlalchemy import CursorResult, Row |
| from sqlalchemy.ext.asyncio import AsyncSession |
| from sqlmodel import col, delete, desc, func, or_, select, text, update |
|
|
| from astrbot.core.db import BaseDatabase |
| from astrbot.core.db.po import ( |
| ApiKey, |
| Attachment, |
| ChatUIProject, |
| CommandConfig, |
| CommandConflict, |
| ConversationV2, |
| CronJob, |
| Persona, |
| PersonaFolder, |
| PlatformMessageHistory, |
| PlatformSession, |
| PlatformStat, |
| Preference, |
| SessionProjectRelation, |
| SQLModel, |
| ) |
| from astrbot.core.db.po import ( |
| Platform as DeprecatedPlatformStat, |
| ) |
| from astrbot.core.db.po import ( |
| Stats as DeprecatedStats, |
| ) |
| from astrbot.core.sentinels import NOT_GIVEN |
|
|
# Result type of the coroutine handed to ``_run_in_tx``; lets the helper's
# return type follow whatever the callback returns.
TxResult = T.TypeVar("TxResult")
# Identity sentinel: compare with ``is`` to distinguish "argument omitted"
# from an explicit None. NOTE(review): not referenced in this chunk;
# presumably consumed by cron-job helpers elsewhere in the file — confirm.
CRON_FIELD_NOT_SET: T.Final = object()
|
|
|
|
| class SQLiteDatabase(BaseDatabase): |
def __init__(self, db_path: str) -> None:
    """Configure an aiosqlite-backed database stored in the file ``db_path``.

    The instance attributes are assigned before ``BaseDatabase.__init__``
    runs. NOTE(review): the base initializer presumably reads
    ``DATABASE_URL`` to build ``self.engine`` — keep that ordering.
    """
    url = f"sqlite+aiosqlite:///{db_path}"
    self.db_path = db_path
    self.DATABASE_URL = url
    self.inited = False
    super().__init__()
|
|
async def initialize(self) -> None:
    """Create missing tables, tune SQLite, migrate legacy columns, seed a key.

    Runs inside one ``engine.begin()`` block, then seeds the default
    developer API key in a separate step.
    """
    tuning_pragmas = (
        "PRAGMA journal_mode=WAL",  # write-ahead log: readers don't block the writer
        "PRAGMA synchronous=NORMAL",
        "PRAGMA cache_size=20000",  # positive value = number of pages
        "PRAGMA temp_store=MEMORY",
        "PRAGMA mmap_size=134217728",  # 128 MiB memory-mapped I/O
        "PRAGMA optimize",
    )
    async with self.engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)
        for pragma in tuning_pragmas:
            await conn.execute(text(pragma))
        # Additive column migrations for databases created by older releases.
        await self._ensure_persona_folder_columns(conn)
        await self._ensure_persona_skills_column(conn)
        await self._ensure_persona_custom_error_message_column(conn)
        await conn.commit()

    await self._create_default_api_key()
|
|
async def _create_default_api_key(self) -> None:
    """Seed a default developer API key when the ``api_keys`` table is empty.

    The stored hash is PBKDF2-HMAC-SHA256 (100,000 iterations, fixed salt
    ``b"astrbot_api_key"``) of the raw key ``abk_astrbot``. The first 12
    characters of the raw key are stored as ``key_prefix``.

    SECURITY NOTE(review): the raw key is hard-coded, so any installation
    that still holds this row accepts the same publicly known credential
    (scopes chat/config/file/im) until it is revoked.
    """
    async with self.engine.begin() as conn:
        result = await conn.execute(text("SELECT COUNT(*) FROM api_keys"))
        if result.scalar():
            # Cheap early exit: skip the PBKDF2 derivation on every start-up
            # after the first one.
            return

    raw_key = "abk_astrbot"
    key_hash = hashlib.pbkdf2_hmac(
        "sha256",
        raw_key.encode("utf-8"),
        b"astrbot_api_key",
        100_000,
    ).hex()
    now = datetime.now(timezone.utc).isoformat()

    # ``engine.begin()`` commits on exit; no explicit commit is needed.
    async with self.engine.begin() as conn:
        # The count above ran in a separate transaction, so another
        # initializer may have seeded a key since. Re-checking inside the
        # INSERT makes check-and-insert one atomic statement.
        await conn.execute(
            text("""
                INSERT INTO api_keys (key_id, name, key_hash, key_prefix, scopes, created_by, created_at, updated_at)
                SELECT :key_id, :name, :key_hash, :key_prefix, :scopes, :created_by, :created_at, :updated_at
                WHERE NOT EXISTS (SELECT 1 FROM api_keys)
            """),
            {
                "key_id": str(uuid.uuid4()),
                "name": "Default Developer Key",
                "key_hash": key_hash,
                "key_prefix": raw_key[:12],
                "scopes": '["chat","config","file","im"]',
                "created_by": "system",
                "created_at": now,
                "updated_at": now,
            },
        )
|
|
async def _ensure_persona_folder_columns(self, conn) -> None:
    """Add ``folder_id`` and ``sort_order`` to a legacy ``personas`` table.

    New databases already get these columns from
    ``SQLModel.metadata.create_all``. This only upgrades files created by
    older releases, adding each column only when it is missing.
    """
    table_info = await conn.execute(text("PRAGMA table_info(personas)"))
    # Index 1 of each PRAGMA table_info row is the column name.
    present = {row[1] for row in table_info.fetchall()}
    migrations = (
        (
            "folder_id",
            "ALTER TABLE personas ADD COLUMN folder_id VARCHAR(36) DEFAULT NULL",
        ),
        ("sort_order", "ALTER TABLE personas ADD COLUMN sort_order INTEGER DEFAULT 0"),
    )
    for column_name, ddl in migrations:
        if column_name in present:
            continue
        await conn.execute(text(ddl))
|
|
async def _ensure_persona_skills_column(self, conn) -> None:
    """Add the JSON ``skills`` column to a legacy ``personas`` table if absent.

    Fresh databases already get it from ``SQLModel.metadata.create_all``.
    This keeps databases from older releases upgradable in place.
    """
    table_info = await conn.execute(text("PRAGMA table_info(personas)"))
    # Index 1 of each PRAGMA table_info row is the column name.
    if any(row[1] == "skills" for row in table_info.fetchall()):
        return
    await conn.execute(text("ALTER TABLE personas ADD COLUMN skills JSON"))
|
|
async def _ensure_persona_custom_error_message_column(self, conn) -> None:
    """Add the TEXT ``custom_error_message`` column to ``personas`` if absent."""
    table_info = await conn.execute(text("PRAGMA table_info(personas)"))
    # Index 1 of each PRAGMA table_info row is the column name.
    has_column = any(
        row[1] == "custom_error_message" for row in table_info.fetchall()
    )
    if not has_column:
        await conn.execute(
            text("ALTER TABLE personas ADD COLUMN custom_error_message TEXT")
        )
|
|
| |
| |
| |
|
|
async def insert_platform_stats(
    self,
    platform_id,
    platform_type,
    count=1,
    timestamp=None,
) -> None:
    """Add ``count`` to the hourly ``platform_stats`` bucket of a platform.

    Args:
        platform_id: platform identifier stored in the row.
        platform_type: platform type stored in the row.
        count: amount to add (default 1).
        timestamp: bucket time. When None, the current local time truncated
            to the hour is used; a caller-supplied value is stored as given.

    A row that already exists for (timestamp, platform_id, platform_type)
    has its ``count`` incremented via ``ON CONFLICT ... DO UPDATE`` instead
    of a second row being inserted.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            bucket = (
                timestamp
                if timestamp is not None
                else datetime.now().replace(minute=0, second=0, microsecond=0)
            )
            params = {
                "timestamp": bucket,
                "platform_id": platform_id,
                "platform_type": platform_type,
                "count": count,
            }
            await session.execute(
                text("""
                    INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
                    VALUES (:timestamp, :platform_id, :platform_type, :count)
                    ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
                    count = platform_stats.count + EXCLUDED.count
                """),
                params,
            )
|
|
async def count_platform_stats(self) -> int:
    """Return the number of ``platform_stats`` rows with a non-NULL platform_id.

    Returns 0 when the aggregate yields no value.
    """
    stmt = select(func.count(col(PlatformStat.platform_id))).select_from(
        PlatformStat,
    )
    async with self.get_db() as session:
        session: AsyncSession
        total = (await session.execute(stmt)).scalar_one_or_none()
        return 0 if total is None else total
|
|
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
    """Return one PlatformStat row per platform_id from the last ``offset_sec`` seconds.

    Args:
        offset_sec: look-back window in seconds, measured from local
            ``datetime.now()``.

    Returns:
        ``PlatformStat`` instances, one per ``platform_id`` (via GROUP BY),
        ordered by timestamp descending.

    NOTE(review): ``SELECT *`` with ``GROUP BY platform_id`` and no aggregate
    relies on SQLite's bare-column rule, so which row represents each
    platform is not specified by SQL. Confirm this is the intended
    summarisation.
    """
    start_time = datetime.now() - timedelta(seconds=offset_sec)
    raw_sql = text("""
        SELECT * FROM platform_stats
        WHERE timestamp >= :start_time
        GROUP BY platform_id
        ORDER BY timestamp DESC
    """)
    async with self.get_db() as session:
        session: AsyncSession
        # ``from_statement`` maps each raw row onto PlatformStat. Executing the
        # bare text() and calling ``.scalars()`` would yield only the first
        # column of every row, not model instances.
        result = await session.execute(
            select(PlatformStat).from_statement(raw_sql),
            {"start_time": start_time},
        )
        return list(result.scalars().all())
|
|
| |
| |
| |
|
|
async def get_conversations(self, user_id=None, platform_id=None):
    """Return conversations, newest ``created_at`` first.

    ``user_id`` and ``platform_id`` each narrow the result only when truthy;
    None or an empty string means "no filter" for that field.
    """
    conditions = []
    if user_id:
        conditions.append(ConversationV2.user_id == user_id)
    if platform_id:
        conditions.append(ConversationV2.platform_id == platform_id)

    stmt = select(ConversationV2)
    for condition in conditions:
        stmt = stmt.where(condition)
    stmt = stmt.order_by(desc(ConversationV2.created_at))

    async with self.get_db() as session:
        session: AsyncSession
        rows = await session.execute(stmt)
        return rows.scalars().all()
|
|
async def get_conversation_by_id(self, cid):
    """Return the ConversationV2 whose ``conversation_id`` is ``cid``, or None."""
    stmt = select(ConversationV2).where(ConversationV2.conversation_id == cid)
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(stmt)
        return found.scalar_one_or_none()
|
|
async def get_all_conversations(self, page=1, page_size=20):
    """Return page ``page`` (1-based) of all conversations, newest first."""
    skip = page_size * (page - 1)
    stmt = (
        select(ConversationV2)
        .order_by(desc(ConversationV2.created_at))
        .offset(skip)
        .limit(page_size)
    )
    async with self.get_db() as session:
        session: AsyncSession
        return (await session.execute(stmt)).scalars().all()
|
|
async def get_filtered_conversations(
    self,
    page=1,
    page_size=20,
    platform_ids=None,
    search_query="",
    **kwargs,
):
    """Return one page of conversations matching the filters, plus the total.

    Args:
        page: 1-based page number.
        page_size: rows per page.
        platform_ids: when non-empty, keep only these platform ids.
        search_query: case-insensitive substring searched in title, content,
            user_id and conversation_id.
        **kwargs: optional ``message_types`` (a conversation matches when its
            user_id contains ``:<type>:`` for ANY listed type) and optional
            ``platforms`` (an additional platform_id whitelist).

    Returns:
        ``(conversations, total)``: the page ordered newest ``created_at``
        first, and the number of matches across all pages.
    """
    stmt = select(ConversationV2)

    if platform_ids:
        stmt = stmt.where(col(ConversationV2.platform_id).in_(platform_ids))

    if search_query:
        raw_pattern = f"%{search_query}%"
        # ``content`` is JSON. With an ensure_ascii JSON serializer, non-ASCII
        # text is stored as \uXXXX escapes, so content is also matched against
        # the escaped query. Plain-text columns are matched against the raw
        # query; the escaped form could never match non-ASCII text in them.
        escaped = search_query.encode("unicode_escape").decode("utf-8")
        escaped_pattern = f"%{escaped}%"
        stmt = stmt.where(
            or_(
                col(ConversationV2.title).ilike(raw_pattern),
                col(ConversationV2.content).ilike(escaped_pattern),
                col(ConversationV2.content).ilike(raw_pattern),
                col(ConversationV2.user_id).ilike(raw_pattern),
                col(ConversationV2.conversation_id).ilike(raw_pattern),
            ),
        )

    message_types = kwargs.get("message_types") or []
    if message_types:
        # NOTE: user_id presumably has the form "platform:MessageType:session",
        # so it carries a single type; the selected types are alternatives
        # (OR). AND-ing them made any 2+ type selection return nothing.
        stmt = stmt.where(
            or_(
                *(
                    col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
                    for msg_type in message_types
                )
            ),
        )

    platforms = kwargs.get("platforms") or []
    if platforms:
        stmt = stmt.where(col(ConversationV2.platform_id).in_(platforms))

    async with self.get_db() as session:
        session: AsyncSession
        # Count over the filtered query before pagination is applied.
        count_query = select(func.count()).select_from(stmt.subquery())
        total = (await session.execute(count_query)).scalar_one()

        page_query = (
            stmt.order_by(desc(ConversationV2.created_at))
            .offset((page - 1) * page_size)
            .limit(page_size)
        )
        conversations = (await session.execute(page_query)).scalars().all()

    return conversations, total
|
|
async def create_conversation(
    self,
    user_id,
    platform_id,
    content=None,
    title=None,
    persona_id=None,
    cid=None,
    created_at=None,
    updated_at=None,
):
    """Insert a new ConversationV2 row and return the model instance.

    ``cid`` (stored as ``conversation_id``), ``created_at`` and
    ``updated_at`` are forwarded only when truthy, so the model's own
    defaults apply otherwise. A falsy ``content`` is stored as ``[]``.
    """
    optional_fields = {
        field: value
        for field, value in (
            ("conversation_id", cid),
            ("created_at", created_at),
            ("updated_at", updated_at),
        )
        if value
    }
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            conversation = ConversationV2(
                user_id=user_id,
                content=content or [],
                platform_id=platform_id,
                title=title,
                persona_id=persona_id,
                **optional_fields,
            )
            session.add(conversation)
    return conversation
|
|
async def update_conversation(
    self, cid, title=None, persona_id=None, content=None, token_usage=None
):
    """Update the non-None fields of conversation ``cid`` and return it re-read.

    Returns None, without issuing an UPDATE, when every field argument is
    None. Otherwise the row is fetched again after the transaction commits.
    """
    candidates = (
        ("title", title),
        ("persona_id", persona_id),
        ("content", content),
        ("token_usage", token_usage),
    )
    changes = {name: value for name, value in candidates if value is not None}
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            if not changes:
                return None
            await session.execute(
                update(ConversationV2)
                .where(col(ConversationV2.conversation_id) == cid)
                .values(**changes),
            )
    # Re-read after commit so the returned row reflects the update.
    return await self.get_conversation_by_id(cid)
|
|
async def delete_conversation(self, cid) -> None:
    """Delete the conversation whose ``conversation_id`` is ``cid`` (no-op if absent)."""
    stmt = delete(ConversationV2).where(col(ConversationV2.conversation_id) == cid)
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(stmt)
|
|
async def delete_conversations_by_user_id(self, user_id: str) -> None:
    """Delete every conversation owned by ``user_id``."""
    stmt = delete(ConversationV2).where(col(ConversationV2.user_id) == user_id)
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(stmt)
|
|
async def get_session_conversations(
    self,
    page=1,
    page_size=20,
    search_query=None,
    platform=None,
) -> tuple[list[dict], int]:
    """Return a page of sessions joined with their selected conversation and persona.

    A session is a ``Preference`` row with scope ``"umo"`` and key
    ``"sel_conv_id"``. The JSON path ``$.val`` of its value is the selected
    conversation id, which is outer-joined to ConversationV2 and then
    Persona.

    Args:
        page: 1-based page number.
        page_size: rows per page.
        search_query: optional case-insensitive substring matched against the
            session id, conversation title and persona id.
        platform: optional platform name; keeps session ids starting with
            ``"<platform>:"``.

    Returns:
        ``(sessions, total)``. Each session is a dict with keys session_id,
        conversation_id, persona_id, title and persona_name (ordered by
        session id). ``total`` counts all matches across pages.
    """
    offset = (page - 1) * page_size
    selected_conv_id = func.json_extract(Preference.value, "$.val")

    rows_query = self._filter_session_conversations(
        select(
            col(Preference.scope_id).label("session_id"),
            selected_conv_id.label("conversation_id"),
            col(ConversationV2.persona_id).label("persona_id"),
            col(ConversationV2.title).label("title"),
            col(Persona.persona_id).label("persona_name"),
        ),
        search_query,
        platform,
    )
    # Same joins and filters as the page query, so the total always agrees.
    count_query = self._filter_session_conversations(
        select(func.count(col(Preference.scope_id))),
        search_query,
        platform,
    )

    async with self.get_db() as session:
        session: AsyncSession
        result = await session.execute(
            rows_query.order_by(Preference.scope_id).offset(offset).limit(page_size),
        )
        rows = result.fetchall()
        total_result = await session.execute(count_query)
        total = total_result.scalar() or 0

    sessions_data = [
        {
            "session_id": row.session_id,
            "conversation_id": row.conversation_id,
            "persona_id": row.persona_id,
            "title": row.title,
            "persona_name": row.persona_name,
        }
        for row in rows
    ]
    return sessions_data, total


@staticmethod
def _filter_session_conversations(query, search_query=None, platform=None):
    """Attach the shared FROM/JOIN/WHERE clauses of ``get_session_conversations``."""
    selected_conv_id = func.json_extract(Preference.value, "$.val")
    query = (
        query.select_from(Preference)
        .outerjoin(
            ConversationV2,
            selected_conv_id == ConversationV2.conversation_id,
        )
        .outerjoin(
            Persona,
            col(ConversationV2.persona_id) == Persona.persona_id,
        )
        .where(Preference.scope == "umo", Preference.key == "sel_conv_id")
    )
    if search_query:
        search_pattern = f"%{search_query}%"
        query = query.where(
            or_(
                col(Preference.scope_id).ilike(search_pattern),
                col(ConversationV2.title).ilike(search_pattern),
                col(Persona.persona_id).ilike(search_pattern),
            ),
        )
    if platform:
        query = query.where(col(Preference.scope_id).like(f"{platform}:%"))
    return query
|
|
async def insert_platform_message_history(
    self,
    platform_id,
    user_id,
    content,
    sender_id=None,
    sender_name=None,
):
    """Persist one PlatformMessageHistory row and return the model instance."""
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            record = PlatformMessageHistory(
                platform_id=platform_id,
                user_id=user_id,
                content=content,
                sender_id=sender_id,
                sender_name=sender_name,
            )
            session.add(record)
    return record
|
|
async def delete_platform_message_offset(
    self,
    platform_id,
    user_id,
    offset_sec=86400,
) -> None:
    """Delete a user's messages on a platform created within the last ``offset_sec`` seconds.

    Rows with ``created_at >= now - offset_sec`` are removed; older rows are
    kept. NOTE(review): ``now`` is naive local ``datetime.now()``; confirm it
    matches how ``created_at`` is stored.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            cutoff = datetime.now() - timedelta(seconds=offset_sec)
            criteria = (
                col(PlatformMessageHistory.platform_id) == platform_id,
                col(PlatformMessageHistory.user_id) == user_id,
                col(PlatformMessageHistory.created_at) >= cutoff,
            )
            await session.execute(delete(PlatformMessageHistory).where(*criteria))
|
|
async def get_platform_message_history(
    self,
    platform_id,
    user_id,
    page=1,
    page_size=20,
):
    """Return page ``page`` (1-based) of a user's platform messages, newest first."""
    skip = page_size * (page - 1)
    stmt = (
        select(PlatformMessageHistory)
        .where(
            PlatformMessageHistory.platform_id == platform_id,
            PlatformMessageHistory.user_id == user_id,
        )
        .order_by(desc(PlatformMessageHistory.created_at))
        .offset(skip)
        .limit(page_size)
    )
    async with self.get_db() as session:
        session: AsyncSession
        page_rows = await session.execute(stmt)
        return page_rows.scalars().all()
|
|
async def get_platform_message_history_by_id(
    self, message_id: int
) -> PlatformMessageHistory | None:
    """Return the message-history row whose primary key ``id`` is ``message_id``, or None."""
    stmt = select(PlatformMessageHistory).where(
        PlatformMessageHistory.id == message_id
    )
    async with self.get_db() as session:
        session: AsyncSession
        return (await session.execute(stmt)).scalar_one_or_none()
|
|
async def insert_attachment(self, path, type, mime_type):
    """Persist an Attachment row and return the model instance.

    ``type`` shadows the builtin but is kept for keyword-call compatibility.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            attachment = Attachment(path=path, type=type, mime_type=mime_type)
            session.add(attachment)
    return attachment
|
|
async def get_attachment_by_id(self, attachment_id):
    """Return the Attachment whose ``attachment_id`` matches, or None."""
    stmt = select(Attachment).where(Attachment.attachment_id == attachment_id)
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(stmt)
        return found.scalar_one_or_none()
|
|
async def get_attachments(self, attachment_ids: list[str]) -> list:
    """Return the Attachments whose ids appear in ``attachment_ids``.

    An empty id list short-circuits to ``[]`` without touching the database.
    Ids with no matching row are silently absent from the result.
    """
    if len(attachment_ids) == 0:
        return []
    stmt = select(Attachment).where(col(Attachment.attachment_id).in_(attachment_ids))
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(stmt)
        return list(found.scalars().all())
|
|
async def delete_attachment(self, attachment_id: str) -> bool:
    """Delete one attachment by id.

    Returns:
        True when a row was removed, False when no attachment matched.
    """
    stmt = delete(Attachment).where(col(Attachment.attachment_id) == attachment_id)
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            outcome = T.cast(CursorResult, await session.execute(stmt))
            removed = outcome.rowcount > 0
            return removed
|
|
async def delete_attachments(self, attachment_ids: list[str]) -> int:
    """Delete every attachment whose id is in ``attachment_ids``.

    Returns:
        The number of rows removed (0 for an empty id list, without a query).
    """
    if len(attachment_ids) == 0:
        return 0
    stmt = delete(Attachment).where(col(Attachment.attachment_id).in_(attachment_ids))
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            outcome = T.cast(CursorResult, await session.execute(stmt))
            return outcome.rowcount
|
|
async def create_api_key(
    self,
    name: str,
    key_hash: str,
    key_prefix: str,
    scopes: list[str] | None,
    created_by: str,
    expires_at: datetime | None = None,
) -> ApiKey:
    """Insert a new ApiKey row and return it.

    The instance is flushed and then refreshed inside the transaction, so its
    attributes are reloaded from the database before being returned.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            new_key = ApiKey(
                name=name,
                key_hash=key_hash,
                key_prefix=key_prefix,
                scopes=scopes,
                created_by=created_by,
                expires_at=expires_at,
            )
            session.add(new_key)
            await session.flush()
            await session.refresh(new_key)
    return new_key
|
|
async def list_api_keys(self) -> list[ApiKey]:
    """Return every API key, most recently created first (revoked ones included)."""
    stmt = select(ApiKey).order_by(desc(ApiKey.created_at))
    async with self.get_db() as session:
        session: AsyncSession
        keys = (await session.execute(stmt)).scalars().all()
        return list(keys)
|
|
async def get_api_key_by_id(self, key_id: str) -> ApiKey | None:
    """Return the ApiKey with ``key_id``, or None (revocation/expiry not checked)."""
    stmt = select(ApiKey).where(ApiKey.key_id == key_id)
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(stmt)
        return found.scalar_one_or_none()
|
|
async def get_active_api_key_by_hash(self, key_hash: str) -> ApiKey | None:
    """Return the usable API key whose hash is ``key_hash``, or None.

    A key is usable when ``revoked_at`` is NULL and ``expires_at`` is either
    NULL or later than the current UTC time.
    """
    async with self.get_db() as session:
        session: AsyncSession
        current = datetime.now(timezone.utc)
        not_revoked = col(ApiKey.revoked_at).is_(None)
        not_expired = or_(
            col(ApiKey.expires_at).is_(None),
            col(ApiKey.expires_at) > current,
        )
        stmt = select(ApiKey).where(ApiKey.key_hash == key_hash, not_revoked, not_expired)
        found = await session.execute(stmt)
        return found.scalar_one_or_none()
|
|
async def touch_api_key(self, key_id: str) -> None:
    """Stamp ``last_used_at`` of API key ``key_id`` with the current UTC time."""
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            stamp = datetime.now(timezone.utc)
            stmt = (
                update(ApiKey)
                .where(col(ApiKey.key_id) == key_id)
                .values(last_used_at=stamp)
            )
            await session.execute(stmt)
|
|
async def revoke_api_key(self, key_id: str) -> bool:
    """Set ``revoked_at`` of API key ``key_id`` to the current UTC time.

    Returns True when a key with that id exists. An already revoked key is
    re-stamped (and still reports True), since the UPDATE does not filter on
    ``revoked_at``.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            stamp = datetime.now(timezone.utc)
            stmt = (
                update(ApiKey)
                .where(col(ApiKey.key_id) == key_id)
                .values(revoked_at=stamp)
            )
            outcome = T.cast(CursorResult, await session.execute(stmt))
            return outcome.rowcount > 0
|
|
async def delete_api_key(self, key_id: str) -> bool:
    """Delete API key ``key_id``; return True if a row was removed."""
    stmt = delete(ApiKey).where(col(ApiKey.key_id) == key_id)
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            raw_result = await session.execute(stmt)
            outcome = T.cast(CursorResult, raw_result)
            return outcome.rowcount > 0
|
|
async def insert_persona(
    self,
    persona_id,
    system_prompt,
    begin_dialogs=None,
    tools=None,
    skills=None,
    custom_error_message=None,
    folder_id=None,
    sort_order=0,
):
    """Insert a new Persona row and return it, reloaded from the database.

    A falsy ``begin_dialogs`` is stored as ``[]``. ``tools`` and ``skills``
    are stored exactly as given (None included). ``folder_id=None`` places
    the persona at the root.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            persona = Persona(
                persona_id=persona_id,
                system_prompt=system_prompt,
                begin_dialogs=begin_dialogs or [],
                tools=tools,
                skills=skills,
                custom_error_message=custom_error_message,
                folder_id=folder_id,
                sort_order=sort_order,
            )
            session.add(persona)
            await session.flush()
            await session.refresh(persona)
    return persona
|
|
async def get_persona_by_id(self, persona_id):
    """Return the Persona whose ``persona_id`` matches, or None."""
    stmt = select(Persona).where(Persona.persona_id == persona_id)
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(stmt)
        return found.scalar_one_or_none()
|
|
async def get_personas(self):
    """Return every Persona row, unfiltered and in database order."""
    async with self.get_db() as session:
        session: AsyncSession
        everything = await session.execute(select(Persona))
        return everything.scalars().all()
|
|
async def update_persona(
    self,
    persona_id,
    system_prompt=None,
    begin_dialogs=None,
    tools=NOT_GIVEN,
    skills=NOT_GIVEN,
    custom_error_message=NOT_GIVEN,
):
    """Update selected fields of persona ``persona_id`` and return it re-read.

    ``system_prompt`` and ``begin_dialogs`` are written only when not None.
    ``tools``, ``skills`` and ``custom_error_message`` use the ``NOT_GIVEN``
    sentinel instead, so passing None explicitly stores NULL. Returns None
    without issuing an UPDATE when nothing would change.
    """
    changes = {}
    if system_prompt is not None:
        changes["system_prompt"] = system_prompt
    if begin_dialogs is not None:
        changes["begin_dialogs"] = begin_dialogs
    for field, value in (
        ("tools", tools),
        ("skills", skills),
        ("custom_error_message", custom_error_message),
    ):
        if value is not NOT_GIVEN:
            changes[field] = value

    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            if not changes:
                return None
            await session.execute(
                update(Persona)
                .where(col(Persona.persona_id) == persona_id)
                .values(**changes)
            )
    # Re-read after commit so the returned row reflects the update.
    return await self.get_persona_by_id(persona_id)
|
|
async def delete_persona(self, persona_id) -> None:
    """Delete the persona whose ``persona_id`` matches (no-op if absent)."""
    stmt = delete(Persona).where(col(Persona.persona_id) == persona_id)
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(stmt)
|
|
| |
| |
| |
|
|
async def insert_persona_folder(
    self,
    name: str,
    parent_id: str | None = None,
    description: str | None = None,
    sort_order: int = 0,
) -> PersonaFolder:
    """Create a persona folder (at the root when ``parent_id`` is None) and return it.

    The new row is flushed and refreshed so its generated ``folder_id`` is
    populated on the returned instance.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            folder = PersonaFolder(
                name=name,
                parent_id=parent_id,
                description=description,
                sort_order=sort_order,
            )
            session.add(folder)
            await session.flush()
            await session.refresh(folder)
    return folder
|
|
async def get_persona_folder_by_id(self, folder_id: str) -> PersonaFolder | None:
    """Return the PersonaFolder with ``folder_id``, or None."""
    stmt = select(PersonaFolder).where(PersonaFolder.folder_id == folder_id)
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(stmt)
        return found.scalar_one_or_none()
|
|
async def get_persona_folders(
    self, parent_id: str | None = None
) -> list[PersonaFolder]:
    """Return the direct children of ``parent_id``, or the root folders when None.

    Results are ordered by ``sort_order`` and then ``name``.
    """
    if parent_id is None:
        parent_filter = col(PersonaFolder.parent_id).is_(None)
    else:
        parent_filter = PersonaFolder.parent_id == parent_id
    stmt = (
        select(PersonaFolder)
        .where(parent_filter)
        .order_by(col(PersonaFolder.sort_order), col(PersonaFolder.name))
    )
    async with self.get_db() as session:
        session: AsyncSession
        folders = await session.execute(stmt)
        return list(folders.scalars().all())
|
|
async def get_all_persona_folders(self) -> list[PersonaFolder]:
    """Return every persona folder at any depth, ordered by ``sort_order`` then ``name``."""
    ordering = (col(PersonaFolder.sort_order), col(PersonaFolder.name))
    async with self.get_db() as session:
        session: AsyncSession
        folders = await session.execute(select(PersonaFolder).order_by(*ordering))
        return list(folders.scalars().all())
|
|
async def update_persona_folder(
    self,
    folder_id: str,
    name: str | None = None,
    parent_id: T.Any = NOT_GIVEN,
    description: T.Any = NOT_GIVEN,
    sort_order: int | None = None,
) -> PersonaFolder | None:
    """Update selected fields of folder ``folder_id`` and return it re-read.

    ``name`` and ``sort_order`` are written only when not None.
    ``parent_id`` and ``description`` use the ``NOT_GIVEN`` sentinel, so an
    explicit None stores NULL (e.g. moves the folder to the root). Returns
    None without issuing an UPDATE when nothing would change.
    """
    changes: dict[str, T.Any] = {}
    if name is not None:
        changes["name"] = name
    if parent_id is not NOT_GIVEN:
        changes["parent_id"] = parent_id
    if description is not NOT_GIVEN:
        changes["description"] = description
    if sort_order is not None:
        changes["sort_order"] = sort_order

    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            if not changes:
                return None
            await session.execute(
                update(PersonaFolder)
                .where(col(PersonaFolder.folder_id) == folder_id)
                .values(**changes)
            )
    # Re-read after commit so the returned row reflects the update.
    return await self.get_persona_folder_by_id(folder_id)
|
|
async def delete_persona_folder(self, folder_id: str) -> None:
    """Delete persona folder ``folder_id`` without orphaning its contents.

    In one transaction:
      * personas inside the folder are moved to the root (folder_id = NULL);
      * direct child folders are moved to the root (parent_id = NULL);
      * the folder row itself is deleted.
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(
                update(Persona)
                .where(col(Persona.folder_id) == folder_id)
                .values(folder_id=None)
            )
            # Without this, child folders keep pointing at a deleted parent
            # and are unreachable from root listings (parent_id IS NULL).
            await session.execute(
                update(PersonaFolder)
                .where(col(PersonaFolder.parent_id) == folder_id)
                .values(parent_id=None)
            )
            await session.execute(
                delete(PersonaFolder).where(
                    col(PersonaFolder.folder_id) == folder_id
                ),
            )
|
|
async def move_persona_to_folder(
    self, persona_id: str, folder_id: str | None
) -> Persona | None:
    """Set persona ``persona_id``'s folder (None = root) and return it re-read.

    Returns None if no persona with that id exists.
    """
    stmt = (
        update(Persona)
        .where(col(Persona.persona_id) == persona_id)
        .values(folder_id=folder_id)
    )
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(stmt)
    # Re-read after commit so the returned row reflects the move.
    return await self.get_persona_by_id(persona_id)
|
|
async def get_personas_by_folder(
    self, folder_id: str | None = None
) -> list[Persona]:
    """Return personas directly inside ``folder_id`` (the root when None).

    Ordered by ``sort_order`` and then ``persona_id``.
    """
    if folder_id is None:
        in_folder = col(Persona.folder_id).is_(None)
    else:
        in_folder = Persona.folder_id == folder_id
    stmt = (
        select(Persona)
        .where(in_folder)
        .order_by(col(Persona.sort_order), col(Persona.persona_id))
    )
    async with self.get_db() as session:
        session: AsyncSession
        personas = await session.execute(stmt)
        return list(personas.scalars().all())
|
|
async def batch_update_sort_order(
    self,
    items: list[dict],
) -> None:
    """Apply many ``sort_order`` changes to personas and folders in one transaction.

    Args:
        items: dicts with keys ``id`` (persona_id or folder_id), ``type``
            (``"persona"`` or ``"folder"``) and ``sort_order``. Items missing
            any key (or holding None) and items of any other type are
            skipped silently.
    """
    if not items:
        return

    # type name -> (model, id column) used to target the UPDATE.
    targets = {
        "persona": (Persona, col(Persona.persona_id)),
        "folder": (PersonaFolder, col(PersonaFolder.folder_id)),
    }

    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            for entry in items:
                entry_id = entry.get("id")
                entry_type = entry.get("type")
                new_order = entry.get("sort_order")
                if None in (entry_id, entry_type, new_order):
                    continue
                target = targets.get(entry_type)
                if target is None:
                    continue
                model, id_column = target
                await session.execute(
                    update(model)
                    .where(id_column == entry_id)
                    .values(sort_order=new_order)
                )
|
|
async def insert_preference_or_update(self, scope, scope_id, key, value):
    """Set preference (scope, scope_id, key) to ``value``, creating it if needed.

    Returns the updated existing Preference, or the newly added one.
    NOTE(review): the select-then-insert is not a single atomic statement;
    confirm a unique constraint prevents duplicates under concurrency.
    """
    lookup = select(Preference).where(
        Preference.scope == scope,
        Preference.scope_id == scope_id,
        Preference.key == key,
    )
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            preference = (await session.execute(lookup)).scalar_one_or_none()
            if preference is None:
                preference = Preference(
                    scope=scope,
                    scope_id=scope_id,
                    key=key,
                    value=value,
                )
                session.add(preference)
            else:
                preference.value = value
    return preference
|
|
async def get_preference(self, scope, scope_id, key):
    """Return the Preference for (scope, scope_id, key), or None."""
    stmt = select(Preference).where(
        Preference.scope == scope,
        Preference.scope_id == scope_id,
        Preference.key == key,
    )
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(stmt)
        return found.scalar_one_or_none()
|
|
async def get_preferences(self, scope, scope_id=None, key=None):
    """Return all preferences in ``scope``, optionally narrowed by scope_id and/or key.

    A None ``scope_id`` or ``key`` means "any value" for that field.
    """
    conditions = [Preference.scope == scope]
    if scope_id is not None:
        conditions.append(Preference.scope_id == scope_id)
    if key is not None:
        conditions.append(Preference.key == key)
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(select(Preference).where(*conditions))
        return found.scalars().all()
|
|
async def remove_preference(self, scope, scope_id, key) -> None:
    """Delete the preference identified by (scope, scope_id, key), if present.

    The ``session.begin()`` context commits on successful exit, so no
    explicit ``commit()`` is issued (consistent with the other writers).
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(
                delete(Preference).where(
                    col(Preference.scope) == scope,
                    col(Preference.scope_id) == scope_id,
                    col(Preference.key) == key,
                ),
            )
|
|
async def clear_preferences(self, scope, scope_id) -> None:
    """Delete every preference stored under (scope, scope_id).

    The ``session.begin()`` context commits on successful exit, so no
    explicit ``commit()`` is issued (consistent with the other writers).
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(
                delete(Preference).where(
                    col(Preference.scope) == scope,
                    col(Preference.scope_id) == scope_id,
                ),
            )
|
|
| |
| |
| |
|
|
async def _run_in_tx(
    self,
    fn: Callable[[AsyncSession], Awaitable[TxResult]],
) -> TxResult:
    """Await ``fn(session)`` inside a fresh transaction and return its result.

    The transaction is opened with ``session.begin()``: it commits when
    ``fn`` returns normally and rolls back if ``fn`` raises.
    """
    async with self.get_db() as db:
        db: AsyncSession
        async with db.begin():
            outcome = await fn(db)
        return outcome
|
|
| @staticmethod |
| def _apply_updates(model, **updates) -> None: |
| for field, value in updates.items(): |
| if value is not None: |
| setattr(model, field, value) |
|
|
@staticmethod
def _new_command_config(
    handler_full_name: str,
    plugin_name: str,
    module_path: str,
    original_command: str,
    *,
    resolved_command: str | None = None,
    enabled: bool | None = None,
    keep_original_alias: bool | None = None,
    conflict_key: str | None = None,
    resolution_strategy: str | None = None,
    note: str | None = None,
    extra_data: dict | None = None,
    auto_managed: bool | None = None,
) -> CommandConfig:
    """Build an unsaved ``CommandConfig``, filling defaults for omitted options.

    Defaults applied:
    - ``enabled`` -> ``True`` when ``None``.
    - ``keep_original_alias`` -> ``False`` when ``None``.
    - ``conflict_key`` -> ``original_command`` when falsy (``None`` or ``""``).
    - ``auto_managed`` -> ``bool(auto_managed)``, so ``None`` becomes ``False``.
    """
    is_enabled = enabled if enabled is not None else True
    keep_alias = keep_original_alias if keep_original_alias is not None else False
    effective_key = conflict_key if conflict_key else original_command

    return CommandConfig(
        handler_full_name=handler_full_name,
        plugin_name=plugin_name,
        module_path=module_path,
        original_command=original_command,
        resolved_command=resolved_command,
        enabled=is_enabled,
        keep_original_alias=keep_alias,
        conflict_key=effective_key,
        resolution_strategy=resolution_strategy,
        note=note,
        extra_data=extra_data,
        auto_managed=bool(auto_managed),
    )
|
|
@staticmethod
def _new_command_conflict(
    conflict_key: str,
    handler_full_name: str,
    plugin_name: str,
    *,
    status: str | None = None,
    resolution: str | None = None,
    resolved_command: str | None = None,
    note: str | None = None,
    extra_data: dict | None = None,
    auto_generated: bool | None = None,
) -> CommandConflict:
    """Build an unsaved ``CommandConflict`` record.

    A falsy ``status`` falls back to ``"pending"``; ``auto_generated`` is
    coerced with ``bool()`` so ``None`` becomes ``False``.
    """
    effective_status = status if status else "pending"
    return CommandConflict(
        conflict_key=conflict_key,
        handler_full_name=handler_full_name,
        plugin_name=plugin_name,
        status=effective_status,
        resolution=resolution,
        resolved_command=resolved_command,
        note=note,
        extra_data=extra_data,
        auto_generated=bool(auto_generated),
    )
|
|
async def get_command_configs(self) -> list[CommandConfig]:
    """Return every stored ``CommandConfig`` row as a list."""
    async with self.get_db() as db:
        db: AsyncSession
        configs = await db.scalars(select(CommandConfig))
        return list(configs)
|
|
async def get_command_config(
    self,
    handler_full_name: str,
) -> CommandConfig | None:
    """Look up a ``CommandConfig`` by primary key; ``None`` when absent."""
    async with self.get_db() as db:
        db: AsyncSession
        found = await db.get(CommandConfig, handler_full_name)
    return found
|
|
async def upsert_command_config(
    self,
    handler_full_name: str,
    plugin_name: str,
    module_path: str,
    original_command: str,
    *,
    resolved_command: str | None = None,
    enabled: bool | None = None,
    keep_original_alias: bool | None = None,
    conflict_key: str | None = None,
    resolution_strategy: str | None = None,
    note: str | None = None,
    extra_data: dict | None = None,
    auto_managed: bool | None = None,
) -> CommandConfig:
    """Insert or update the ``CommandConfig`` keyed by ``handler_full_name``.

    A missing row is created through ``_new_command_config``, which fills
    defaults. For an existing row, only arguments that are not ``None`` are
    written (via ``_apply_updates``). Everything runs in one transaction and
    the flushed, refreshed row is returned.
    """
    optional_fields = {
        "resolved_command": resolved_command,
        "enabled": enabled,
        "keep_original_alias": keep_original_alias,
        "conflict_key": conflict_key,
        "resolution_strategy": resolution_strategy,
        "note": note,
        "extra_data": extra_data,
        "auto_managed": auto_managed,
    }

    async def _upsert(db: AsyncSession) -> CommandConfig:
        target = await db.get(CommandConfig, handler_full_name)
        if target:
            self._apply_updates(
                target,
                plugin_name=plugin_name,
                module_path=module_path,
                original_command=original_command,
                **optional_fields,
            )
        else:
            target = self._new_command_config(
                handler_full_name,
                plugin_name,
                module_path,
                original_command,
                **optional_fields,
            )
            db.add(target)
        await db.flush()
        await db.refresh(target)
        return target

    return await self._run_in_tx(_upsert)
|
|
async def delete_command_config(self, handler_full_name: str) -> None:
    """Delete one ``CommandConfig`` by delegating to the batch variant."""
    single = [handler_full_name]
    await self.delete_command_configs(single)
|
|
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
    """Delete every ``CommandConfig`` whose handler name is in the given list.

    An empty list returns immediately without opening a transaction.
    """
    if handler_full_names:

        async def _purge(db: AsyncSession) -> None:
            names_match = col(CommandConfig.handler_full_name).in_(
                handler_full_names,
            )
            await db.execute(delete(CommandConfig).where(names_match))

        await self._run_in_tx(_purge)
|
|
async def list_command_conflicts(
    self,
    status: str | None = None,
) -> list[CommandConflict]:
    """Return command conflicts, optionally restricted to one ``status``.

    A falsy ``status`` (``None`` or ``""``) disables the filter.
    """
    async with self.get_db() as db:
        db: AsyncSession
        stmt = select(CommandConflict)
        if status:
            stmt = stmt.where(CommandConflict.status == status)
        return list(await db.scalars(stmt))
|
|
async def upsert_command_conflict(
    self,
    conflict_key: str,
    handler_full_name: str,
    plugin_name: str,
    *,
    status: str | None = None,
    resolution: str | None = None,
    resolved_command: str | None = None,
    note: str | None = None,
    extra_data: dict | None = None,
    auto_generated: bool | None = None,
) -> CommandConflict:
    """Insert or update the conflict for (``conflict_key``, ``handler_full_name``).

    A new record gets its defaults from ``_new_command_conflict``. For an
    existing record, only arguments that are not ``None`` are written. The
    work runs in one transaction and the refreshed record is returned.
    """
    shared_fields = {
        "status": status,
        "resolution": resolution,
        "resolved_command": resolved_command,
        "note": note,
        "extra_data": extra_data,
        "auto_generated": auto_generated,
    }

    async def _upsert(db: AsyncSession) -> CommandConflict:
        lookup = select(CommandConflict).where(
            CommandConflict.conflict_key == conflict_key,
            CommandConflict.handler_full_name == handler_full_name,
        )
        record = (await db.scalars(lookup)).one_or_none()
        if record:
            self._apply_updates(record, plugin_name=plugin_name, **shared_fields)
        else:
            record = self._new_command_conflict(
                conflict_key,
                handler_full_name,
                plugin_name,
                **shared_fields,
            )
            db.add(record)
        await db.flush()
        await db.refresh(record)
        return record

    return await self._run_in_tx(_upsert)
|
|
async def delete_command_conflicts(self, ids: list[int]) -> None:
    """Delete the ``CommandConflict`` rows whose ids are listed.

    An empty list returns immediately without opening a transaction.
    """
    if ids:

        async def _purge(db: AsyncSession) -> None:
            id_match = col(CommandConflict.id).in_(ids)
            await db.execute(delete(CommandConflict).where(id_match))

        await self._run_in_tx(_purge)
|
|
| |
| |
| |
|
|
def get_base_stats(self, offset_sec=86400):
    """Get base statistics within the specified offset in seconds.

    Collects every ``PlatformStat`` row with ``timestamp >= now - offset_sec``
    (``now`` comes from ``datetime.now()``, i.e. a naive local timestamp).
    Each row becomes a legacy ``Platform`` entry inside a ``Stats`` container.

    This is a synchronous method. The async query runs under
    ``asyncio.run`` on a dedicated worker thread, and the caller blocks until
    it finishes. Presumably this is done so the method also works when the
    calling thread already runs an event loop.

    Raises:
        Any exception raised by the query. It is captured in the worker
        thread and re-raised here; otherwise ``threading.Thread`` would
        swallow it and the method would silently return ``None``.
    """

    async def _inner():
        async with self.get_db() as session:
            session: AsyncSession
            now = datetime.now()
            start_time = now - timedelta(seconds=offset_sec)
            result = await session.execute(
                select(PlatformStat).where(PlatformStat.timestamp >= start_time),
            )
            all_datas = result.scalars().all()
            deprecated_stats = DeprecatedStats()
            for data in all_datas:
                deprecated_stats.platform.append(
                    DeprecatedPlatformStat(
                        name=data.platform_id,
                        count=data.count,
                        timestamp=int(data.timestamp.timestamp()),
                    ),
                )
            return deprecated_stats

    result = None
    error: BaseException | None = None

    def runner() -> None:
        nonlocal result, error
        try:
            result = asyncio.run(_inner())
        except BaseException as exc:  # handed back to the caller after join()
            error = exc

    t = threading.Thread(target=runner)
    t.start()
    t.join()
    if error is not None:
        raise error
    return result
|
|
def get_total_message_count(self):
    """Get the total message count from platform statistics.

    Returns ``SUM(PlatformStat.count)`` over all rows, or ``0`` when the
    table is empty (SUM yields NULL in that case).

    This is a synchronous method. The async query runs under
    ``asyncio.run`` on a dedicated worker thread, and the caller blocks until
    it finishes.

    Raises:
        Any exception raised by the query. It is re-raised in the calling
        thread instead of being swallowed by the worker thread, which would
        otherwise make the method return ``None``.
    """

    async def _inner():
        async with self.get_db() as session:
            session: AsyncSession
            result = await session.execute(
                select(func.sum(PlatformStat.count)).select_from(PlatformStat),
            )
            total_count = result.scalar_one_or_none()
            return total_count if total_count is not None else 0

    result = None
    error: BaseException | None = None

    def runner() -> None:
        nonlocal result, error
        try:
            result = asyncio.run(_inner())
        except BaseException as exc:  # handed back to the caller after join()
            error = exc

    t = threading.Thread(target=runner)
    t.start()
    t.join()
    if error is not None:
        raise error
    return result
|
|
def get_grouped_base_stats(self, offset_sec=86400):
    """Get per-platform message totals for the last ``offset_sec`` seconds.

    Sums ``PlatformStat.count`` grouped by ``platform_id`` over rows with
    ``timestamp >= now - offset_sec`` (``now`` is ``datetime.now()``). The
    result is a legacy ``Stats`` container with one ``Platform`` entry per
    platform. Every entry's ``timestamp`` is the start of the window, not a
    per-row time.

    This is a synchronous method. The async query runs under
    ``asyncio.run`` on a dedicated worker thread, and the caller blocks until
    it finishes.

    Raises:
        Any exception raised by the query. It is re-raised in the calling
        thread instead of being swallowed by the worker thread, which would
        otherwise make the method return ``None``.
    """

    async def _inner():
        async with self.get_db() as session:
            session: AsyncSession
            now = datetime.now()
            start_time = now - timedelta(seconds=offset_sec)
            result = await session.execute(
                select(PlatformStat.platform_id, func.sum(PlatformStat.count))
                .where(PlatformStat.timestamp >= start_time)
                .group_by(PlatformStat.platform_id),
            )
            grouped_stats = result.all()
            deprecated_stats = DeprecatedStats()
            for platform_id, count in grouped_stats:
                deprecated_stats.platform.append(
                    DeprecatedPlatformStat(
                        name=platform_id,
                        count=count,
                        timestamp=int(start_time.timestamp()),
                    ),
                )
            return deprecated_stats

    result = None
    error: BaseException | None = None

    def runner() -> None:
        nonlocal result, error
        try:
            result = asyncio.run(_inner())
        except BaseException as exc:  # handed back to the caller after join()
            error = exc

    t = threading.Thread(target=runner)
    t.start()
    t.join()
    if error is not None:
        raise error
    return result
|
|
| |
| |
| |
|
|
async def create_platform_session(
    self,
    creator: str,
    platform_id: str = "webchat",
    session_id: str | None = None,
    display_name: str | None = None,
    is_group: int = 0,
) -> PlatformSession:
    """Insert a new ``PlatformSession`` row and return it after refresh.

    ``session_id`` is passed to the model only when it is truthy. Otherwise
    the model's own default applies; that default is defined on
    ``PlatformSession``, not here.
    """
    id_override = {"session_id": session_id} if session_id else {}

    async with self.get_db() as db:
        db: AsyncSession
        async with db.begin():
            record = PlatformSession(
                creator=creator,
                platform_id=platform_id,
                display_name=display_name,
                is_group=is_group,
                **id_override,
            )
            db.add(record)
            await db.flush()
            await db.refresh(record)
            return record
|
|
async def get_platform_session_by_id(
    self, session_id: str
) -> PlatformSession | None:
    """Return the ``PlatformSession`` whose id is ``session_id``, or ``None``."""
    async with self.get_db() as db:
        db: AsyncSession
        matches = await db.scalars(
            select(PlatformSession).where(
                PlatformSession.session_id == session_id,
            ),
        )
        return matches.one_or_none()
|
|
async def get_platform_sessions_by_creator(
    self,
    creator: str,
    platform_id: str | None = None,
    page: int = 1,
    page_size: int = 20,
) -> list[dict]:
    """Return one page of a creator's sessions, including project sessions.

    Each item is a dict with ``session``, ``project_id``, ``project_title``
    and ``project_emoji``. The project fields are ``None`` for sessions that
    are not in a project. The total count computed by the paginated variant
    is discarded.
    """
    paged = await self.get_platform_sessions_by_creator_paginated(
        creator=creator,
        platform_id=platform_id,
        page=page,
        page_size=page_size,
        exclude_project_sessions=False,
    )
    items = paged[0]
    return items
|
|
@staticmethod
def _build_platform_sessions_query(
    creator: str,
    platform_id: str | None = None,
    exclude_project_sessions: bool = False,
):
    """Build the SELECT that lists a creator's sessions with project info.

    Each row is ``(PlatformSession, project_id, project_title, project_emoji)``.
    LEFT OUTER JOINs through ``SessionProjectRelation`` make the project
    columns NULL for sessions that have no project.

    ``platform_id`` (when truthy) restricts the platform. With
    ``exclude_project_sessions`` set, only sessions with no project are kept.
    """
    session_to_relation = col(PlatformSession.session_id) == col(
        SessionProjectRelation.session_id
    )
    relation_to_project = col(SessionProjectRelation.project_id) == col(
        ChatUIProject.project_id
    )
    selected = (
        PlatformSession,
        col(ChatUIProject.project_id),
        col(ChatUIProject.title).label("project_title"),
        col(ChatUIProject.emoji).label("project_emoji"),
    )

    criteria = [col(PlatformSession.creator) == creator]
    if platform_id:
        criteria.append(PlatformSession.platform_id == platform_id)
    if exclude_project_sessions:
        criteria.append(col(ChatUIProject.project_id).is_(None))

    return (
        select(*selected)
        .outerjoin(SessionProjectRelation, session_to_relation)
        .outerjoin(ChatUIProject, relation_to_project)
        .where(*criteria)
    )
|
|
@staticmethod
def _rows_to_session_dicts(rows: T.Sequence[Row[tuple]]) -> list[dict]:
    """Turn rows from ``_build_platform_sessions_query`` into plain dicts.

    Row positions 0-3 map to the keys ``session``, ``project_id``,
    ``project_title`` and ``project_emoji``, in that order.
    """
    field_names = ("session", "project_id", "project_title", "project_emoji")
    return [
        {name: row[position] for position, name in enumerate(field_names)}
        for row in rows
    ]
|
|
async def get_platform_sessions_by_creator_paginated(
    self,
    creator: str,
    platform_id: str | None = None,
    page: int = 1,
    page_size: int = 20,
    exclude_project_sessions: bool = False,
) -> tuple[list[dict], int]:
    """Return ``(items, total)`` for a creator's sessions.

    ``items`` holds one page (``page`` is 1-based), ordered by
    ``updated_at`` newest first; each item is a dict from
    ``_rows_to_session_dicts``. ``total`` counts every matching row before
    pagination.
    """
    async with self.get_db() as db:
        db: AsyncSession
        filtered = self._build_platform_sessions_query(
            creator=creator,
            platform_id=platform_id,
            exclude_project_sessions=exclude_project_sessions,
        )

        count_stmt = select(func.count()).select_from(filtered.subquery())
        count_value = (await db.execute(count_stmt)).scalar_one()
        total = int(count_value or 0)

        skip = (page - 1) * page_size
        page_stmt = (
            filtered.order_by(col(PlatformSession.updated_at).desc())
            .limit(page_size)
            .offset(skip)
        )
        page_rows = (await db.execute(page_stmt)).all()

        return self._rows_to_session_dicts(page_rows), total
|
|
async def update_platform_session(
    self,
    session_id: str,
    display_name: str | None = None,
) -> None:
    """Refresh a Platform session's ``updated_at`` and optionally rename it.

    ``updated_at`` is set to the current UTC time. ``display_name`` is
    written only when it is not ``None``. If no row matches ``session_id``,
    nothing is updated.
    """
    async with self.get_db() as db:
        db: AsyncSession
        async with db.begin():
            changes: dict[str, T.Any] = {
                "updated_at": datetime.now(timezone.utc),
                **({} if display_name is None else {"display_name": display_name}),
            }
            stmt = (
                update(PlatformSession)
                .where(col(PlatformSession.session_id) == session_id)
                .values(**changes)
            )
            await db.execute(stmt)
|
|
async def delete_platform_session(self, session_id: str) -> None:
    """Delete a Platform session by its ID.

    The session's ``SessionProjectRelation`` rows are deleted in the same
    transaction, as ``delete_chatui_project`` does for projects. This keeps
    no relation pointing at a session that no longer exists. Relations are
    removed first, so the cleanup also works if the relation table enforces
    a foreign key to the session (not visible here).
    """
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(
                delete(SessionProjectRelation).where(
                    col(SessionProjectRelation.session_id) == session_id,
                ),
            )
            await session.execute(
                delete(PlatformSession).where(
                    col(PlatformSession.session_id) == session_id,
                ),
            )
|
|
| |
| |
| |
|
|
async def create_chatui_project(
    self,
    creator: str,
    title: str,
    emoji: str | None = "📁",
    description: str | None = None,
) -> ChatUIProject:
    """Insert a new ChatUI project owned by ``creator`` and return it refreshed."""
    async with self.get_db() as db:
        db: AsyncSession
        async with db.begin():
            new_project = ChatUIProject(
                creator=creator,
                title=title,
                emoji=emoji,
                description=description,
            )
            db.add(new_project)
            await db.flush()
            await db.refresh(new_project)
            return new_project
|
|
async def get_chatui_project_by_id(self, project_id: str) -> ChatUIProject | None:
    """Return the ChatUI project with ``project_id``, or ``None``."""
    id_match = col(ChatUIProject.project_id) == project_id
    async with self.get_db() as db:
        db: AsyncSession
        found = await db.scalars(select(ChatUIProject).where(id_match))
        return found.one_or_none()
|
|
async def get_chatui_projects_by_creator(
    self,
    creator: str,
    page: int = 1,
    page_size: int = 100,
) -> list[ChatUIProject]:
    """Return one page (1-based) of a creator's projects, newest-updated first."""
    skip = (page - 1) * page_size
    stmt = (
        select(ChatUIProject)
        .where(col(ChatUIProject.creator) == creator)
        .order_by(col(ChatUIProject.updated_at).desc())
        .offset(skip)
        .limit(page_size)
    )
    async with self.get_db() as db:
        db: AsyncSession
        return list(await db.scalars(stmt))
|
|
async def update_chatui_project(
    self,
    project_id: str,
    title: str | None = None,
    emoji: str | None = None,
    description: str | None = None,
) -> None:
    """Update a ChatUI project's metadata and bump its ``updated_at``.

    ``updated_at`` is always set to the current UTC time. ``title``,
    ``emoji`` and ``description`` are written only when they are not
    ``None``.
    """
    async with self.get_db() as db:
        db: AsyncSession
        async with db.begin():
            requested = {
                "title": title,
                "emoji": emoji,
                "description": description,
            }
            values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
            values.update(
                {field: val for field, val in requested.items() if val is not None},
            )
            stmt = (
                update(ChatUIProject)
                .where(col(ChatUIProject.project_id) == project_id)
                .values(**values)
            )
            await db.execute(stmt)
|
|
async def delete_chatui_project(self, project_id: str) -> None:
    """Delete a ChatUI project together with its session relations.

    Relations are removed before the project row, inside one transaction.
    The sessions themselves are not deleted.
    """
    async with self.get_db() as db:
        db: AsyncSession
        async with db.begin():
            cleanup = (
                delete(SessionProjectRelation).where(
                    col(SessionProjectRelation.project_id) == project_id,
                ),
                delete(ChatUIProject).where(
                    col(ChatUIProject.project_id) == project_id,
                ),
            )
            for stmt in cleanup:
                await db.execute(stmt)
|
|
async def add_session_to_project(
    self,
    session_id: str,
    project_id: str,
) -> SessionProjectRelation:
    """Attach a session to a project, replacing any previous attachment.

    Every existing relation for ``session_id`` is deleted before the new
    relation is inserted. Through this method a session therefore ends up
    in at most one project. Returns the refreshed new relation.
    """
    async with self.get_db() as db:
        db: AsyncSession
        async with db.begin():
            await db.execute(
                delete(SessionProjectRelation).where(
                    col(SessionProjectRelation.session_id) == session_id,
                ),
            )
            link = SessionProjectRelation(session_id=session_id, project_id=project_id)
            db.add(link)
            await db.flush()
            await db.refresh(link)
            return link
|
|
async def remove_session_from_project(self, session_id: str) -> None:
    """Detach ``session_id`` from whatever project it belongs to (if any)."""
    detach = delete(SessionProjectRelation).where(
        col(SessionProjectRelation.session_id) == session_id,
    )
    async with self.get_db() as db:
        db: AsyncSession
        async with db.begin():
            await db.execute(detach)
|
|
async def get_project_sessions(
    self,
    project_id: str,
    page: int = 1,
    page_size: int = 100,
) -> list[PlatformSession]:
    """Return one page (1-based) of a project's sessions, newest-updated first."""
    skip = (page - 1) * page_size
    membership = col(PlatformSession.session_id) == col(
        SessionProjectRelation.session_id
    )
    stmt = (
        select(PlatformSession)
        .join(SessionProjectRelation, membership)
        .where(col(SessionProjectRelation.project_id) == project_id)
        .order_by(col(PlatformSession.updated_at).desc())
        .offset(skip)
        .limit(page_size)
    )
    async with self.get_db() as db:
        db: AsyncSession
        return list(await db.scalars(stmt))
|
|
async def get_project_by_session(
    self, session_id: str, creator: str
) -> ChatUIProject | None:
    """Return the project containing ``session_id``, if it is owned by ``creator``.

    Returns ``None`` when the session is in no project, or when that
    project belongs to a different creator.
    """
    linked = col(ChatUIProject.project_id) == col(SessionProjectRelation.project_id)
    stmt = (
        select(ChatUIProject)
        .join(SessionProjectRelation, linked)
        .where(col(SessionProjectRelation.session_id) == session_id)
        .where(col(ChatUIProject.creator) == creator)
    )
    async with self.get_db() as db:
        db: AsyncSession
        return (await db.scalars(stmt)).one_or_none()
|
|
| |
| |
| |
|
|
async def create_cron_job(
    self,
    name: str,
    job_type: str,
    cron_expression: str | None,
    *,
    timezone: str | None = None,
    payload: dict | None = None,
    description: str | None = None,
    enabled: bool = True,
    persistent: bool = True,
    run_once: bool = False,
    status: str | None = None,
    job_id: str | None = None,
) -> CronJob:
    """Insert a new ``CronJob`` and return the refreshed row.

    A falsy ``payload`` becomes ``{}`` and a falsy ``status`` becomes
    ``"scheduled"``. ``job_id`` overrides the model-assigned id only when it
    is truthy.
    """
    async with self.get_db() as db:
        db: AsyncSession
        async with db.begin():
            job = CronJob(
                name=name,
                job_type=job_type,
                cron_expression=cron_expression,
                timezone=timezone,
                payload=payload if payload else {},
                description=description,
                enabled=enabled,
                persistent=persistent,
                run_once=run_once,
                status=status if status else "scheduled",
            )
            if job_id:
                # Overwrite after construction, leaving the model's own
                # id default untouched when no explicit id is given.
                job.job_id = job_id
            db.add(job)
            await db.flush()
            await db.refresh(job)
            return job
|
|
async def update_cron_job(
    self,
    job_id: str,
    *,
    name: str | None | object = CRON_FIELD_NOT_SET,
    cron_expression: str | None | object = CRON_FIELD_NOT_SET,
    timezone: str | None | object = CRON_FIELD_NOT_SET,
    payload: dict | None | object = CRON_FIELD_NOT_SET,
    description: str | None | object = CRON_FIELD_NOT_SET,
    enabled: bool | None | object = CRON_FIELD_NOT_SET,
    persistent: bool | None | object = CRON_FIELD_NOT_SET,
    run_once: bool | None | object = CRON_FIELD_NOT_SET,
    status: str | None | object = CRON_FIELD_NOT_SET,
    next_run_time: datetime | None | object = CRON_FIELD_NOT_SET,
    last_run_at: datetime | None | object = CRON_FIELD_NOT_SET,
    last_error: str | None | object = CRON_FIELD_NOT_SET,
) -> CronJob | None:
    """Update selected columns of a cron job and return the current row.

    Only arguments that are explicitly passed are written. The default
    ``CRON_FIELD_NOT_SET`` sentinel means "leave unchanged", so passing
    ``None`` clears a column.

    When no field is passed, no UPDATE is issued; an UPDATE with an empty
    value set would otherwise be sent to the database. The row is then just
    re-read.

    Returns:
        The row re-read after the update, or ``None`` if no job has
        ``job_id``.
    """
    candidates: dict[str, T.Any] = {
        "name": name,
        "cron_expression": cron_expression,
        "timezone": timezone,
        "payload": payload,
        "description": description,
        "enabled": enabled,
        "persistent": persistent,
        "run_once": run_once,
        "status": status,
        "next_run_time": next_run_time,
        "last_run_at": last_run_at,
        "last_error": last_error,
    }
    updates = {
        key: val for key, val in candidates.items() if val is not CRON_FIELD_NOT_SET
    }

    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            if updates:
                stmt = (
                    update(CronJob)
                    .where(col(CronJob.job_id) == job_id)
                    .values(**updates)
                    .execution_options(synchronize_session="fetch")
                )
                await session.execute(stmt)
            result = await session.execute(
                select(CronJob).where(col(CronJob.job_id) == job_id)
            )
            return result.scalar_one_or_none()
|
|
async def delete_cron_job(self, job_id: str) -> None:
    """Delete the cron job with ``job_id``; a missing job is a no-op."""
    removal = delete(CronJob).where(col(CronJob.job_id) == job_id)
    async with self.get_db() as db:
        db: AsyncSession
        async with db.begin():
            await db.execute(removal)
|
|
async def get_cron_job(self, job_id: str) -> CronJob | None:
    """Return the cron job with ``job_id``, or ``None`` if it does not exist."""
    async with self.get_db() as db:
        db: AsyncSession
        jobs = await db.scalars(select(CronJob).where(col(CronJob.job_id) == job_id))
        return jobs.one_or_none()
|
|
async def list_cron_jobs(self, job_type: str | None = None) -> list[CronJob]:
    """List cron jobs, newest-created first.

    A truthy ``job_type`` restricts the list to that type.
    """
    stmt = select(CronJob)
    if job_type:
        stmt = stmt.where(col(CronJob.job_type) == job_type)
    stmt = stmt.order_by(col(CronJob.created_at).desc())
    async with self.get_db() as db:
        db: AsyncSession
        return list(await db.scalars(stmt))
|
|