""" 数据库会话管理 """ from contextlib import contextmanager from typing import Generator from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.exc import SQLAlchemyError import os import logging from .models import Base logger = logging.getLogger(__name__) def _build_sqlalchemy_url(database_url: str) -> str: if database_url.startswith("postgresql://"): return "postgresql+psycopg://" + database_url[len("postgresql://"):] if database_url.startswith("postgres://"): return "postgresql+psycopg://" + database_url[len("postgres://"):] return database_url class DatabaseSessionManager: """数据库会话管理器""" def __init__(self, database_url: str = None): if database_url is None: env_url = os.environ.get("APP_DATABASE_URL") or os.environ.get("DATABASE_URL") if env_url: database_url = env_url else: # 优先使用 APP_DATA_DIR 环境变量(PyInstaller 打包后由 webui.py 设置) data_dir = os.environ.get('APP_DATA_DIR') or os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'data' ) db_path = os.path.join(data_dir, 'database.db') # 确保目录存在 os.makedirs(data_dir, exist_ok=True) database_url = f"sqlite:///{db_path}" self.database_url = _build_sqlalchemy_url(database_url) self.engine = create_engine( self.database_url, connect_args={"check_same_thread": False} if self.database_url.startswith("sqlite") else {}, echo=False, # 设置为 True 可以查看所有 SQL 语句 pool_pre_ping=True # 连接池预检查 ) self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine) def get_db(self) -> Generator[Session, None, None]: """ 获取数据库会话的上下文管理器 使用示例: with get_db() as db: # 使用 db 进行数据库操作 pass """ db = self.SessionLocal() try: yield db finally: db.close() @contextmanager def session_scope(self) -> Generator[Session, None, None]: """ 事务作用域上下文管理器 使用示例: with session_scope() as session: # 数据库操作 pass """ session = self.SessionLocal() try: yield session session.commit() except Exception as e: session.rollback() raise e finally: session.close() def create_tables(self): """创建所有表""" Base.metadata.create_all(bind=self.engine) def drop_tables(self): """删除所有表(谨慎使用)""" Base.metadata.drop_all(bind=self.engine) def migrate_tables(self): """ 数据库迁移 - 添加缺失的列 用于在不删除数据的情况下更新表结构 """ if not self.database_url.startswith("sqlite"): logger.info("非 SQLite 数据库,跳过自动迁移") return # 需要检查和添加的新列 migrations = [ # (表名, 列名, 列类型) ("accounts", "cpa_uploaded", "BOOLEAN DEFAULT 0"), ("accounts", "cpa_uploaded_at", "DATETIME"), ("accounts", "source", "VARCHAR(20) DEFAULT 'register'"), ("accounts", "subscription_type", "VARCHAR(20)"), ("accounts", "subscription_at", "DATETIME"), ("accounts", "cookies", "TEXT"), ("proxies", "is_default", "BOOLEAN DEFAULT 0"), ] # 确保新表存在(create_tables 已处理,此处兜底) Base.metadata.create_all(bind=self.engine) with self.engine.connect() as conn: # 数据迁移:将旧的 custom_domain 记录统一为 moe_mail try: conn.execute(text("UPDATE email_services SET service_type='moe_mail' WHERE service_type='custom_domain'")) conn.execute(text("UPDATE accounts SET email_service='moe_mail' WHERE email_service='custom_domain'")) conn.commit() except Exception as e: logger.warning(f"迁移 custom_domain -> moe_mail 时出错: {e}") for table_name, column_name, column_type in migrations: try: # 检查列是否存在 result = conn.execute(text( f"SELECT * FROM pragma_table_info('{table_name}') WHERE name='{column_name}'" )) if result.fetchone() is None: # 列不存在,添加它 logger.info(f"添加列 {table_name}.{column_name}") conn.execute(text( f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}" )) conn.commit() logger.info(f"成功添加列 {table_name}.{column_name}") except Exception as e: logger.warning(f"迁移列 {table_name}.{column_name} 时出错: {e}") # 全局数据库会话管理器实例 _db_manager: DatabaseSessionManager = None def init_database(database_url: str = None) -> DatabaseSessionManager: """ 初始化数据库会话管理器 """ global _db_manager if _db_manager is None: _db_manager = DatabaseSessionManager(database_url) _db_manager.create_tables() # 执行数据库迁移 _db_manager.migrate_tables() return _db_manager def get_session_manager() -> DatabaseSessionManager: """ 获取数据库会话管理器 """ if _db_manager is None: raise RuntimeError("数据库未初始化,请先调用 init_database()") return _db_manager @contextmanager def get_db() -> Generator[Session, None, None]: """ 获取数据库会话的快捷函数 """ manager = get_session_manager() db = manager.SessionLocal() try: yield db finally: db.close()