"""SQLAlchemy engine, session factory, and schema bootstrap helpers.

The SQLite database file is ``/data/app.db`` when a ``/data`` directory
exists, otherwise ``./app.db`` relative to the current working directory.
"""

from collections.abc import Generator
from pathlib import Path

from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine, event, inspect
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker

DATABASE_URL = "sqlite:////data/app.db"
LOCAL_DATABASE_URL = "sqlite:///./app.db"


class Base(DeclarativeBase):
    """Declarative base class for the application's ORM models."""


def _database_url() -> str:
    """Return the SQLite URL to connect to.

    Returns:
        ``DATABASE_URL`` when ``/data`` is an existing directory, otherwise
        ``LOCAL_DATABASE_URL``.
    """
    # is_dir() rather than exists(): a regular file named /data cannot
    # contain app.db, so it must not select the /data location.
    return DATABASE_URL if Path("/data").is_dir() else LOCAL_DATABASE_URL


# check_same_thread=False disables sqlite3's same-thread check so a
# connection may be used from a thread other than the one that opened it.
engine = create_engine(
    _database_url(),
    connect_args={"check_same_thread": False},
)


@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_conn, _connection_record) -> None:
    """Configure each new raw DBAPI connection opened by ``engine``.

    Switches the journal to write-ahead logging and turns on foreign-key
    enforcement for the connection.

    Args:
        dbapi_conn: The raw sqlite3 connection just opened.
        _connection_record: Pool bookkeeping object (unused).
    """
    cursor = dbapi_conn.cursor()
    try:
        cursor.execute("PRAGMA journal_mode=WAL")
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        # Close even when a PRAGMA fails so the cursor is never leaked.
        cursor.close()


SessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False)


def get_db() -> Generator[Session, None, None]:
    """Yield a fresh ``Session`` bound to ``engine``.

    The session is closed in ``finally`` once the generator is resumed,
    closed, or garbage-collected, i.e. when the consumer is done with it.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def init_db() -> None:
    """Create every table registered on ``Base.metadata`` that is missing."""
    # Imported only for its side effect: presumably the model classes in
    # `models` subclass Base and must be registered before create_all().
    import models  # noqa: F401

    Base.metadata.create_all(bind=engine)


def run_migrations() -> None:
    """Bring the database schema to the latest Alembic revision.

    If a ``users`` table exists but ``alembic_version`` does not, the
    database is stamped as ``head`` without running migrations. Otherwise
    ``upgrade head`` is executed.
    """
    config = Config(str(Path(__file__).parent / "alembic.ini"))
    # Reuse the URL the engine was actually built with, so Alembic and the
    # app cannot target different files (e.g. if /data appears after import).
    # Alembic stores options via ConfigParser, where a literal "%" must be
    # written as "%%".
    url = engine.url.render_as_string(hide_password=False)
    config.set_main_option("sqlalchemy.url", url.replace("%", "%%"))

    table_names = set(inspect(engine).get_table_names())
    # NOTE(review): presumably a database created earlier via create_all()
    # before Alembic was adopted, assumed to already match head — confirm.
    if "users" in table_names and "alembic_version" not in table_names:
        command.stamp(config, "head")
        return
    command.upgrade(config, "head")