| from contextlib import asynccontextmanager |
| from typing import AsyncGenerator |
| from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine |
|
|
| from Backend.core.config import settings |
|
|
def _to_asyncpg_url(url: str) -> str:
    """Return *url* rewritten to use SQLAlchemy's ``postgresql+asyncpg`` driver.

    Only a bare ``postgresql://`` or ``postgres://`` scheme *prefix* is
    rewritten. ``postgres://`` is what many hosting providers hand out, but
    SQLAlchemy >= 1.4 no longer accepts it as a dialect name. URLs that already
    name a driver (e.g. ``postgresql+asyncpg://``) or use another scheme are
    returned unchanged. Unlike ``str.replace``, text elsewhere in the URL
    (e.g. inside a password or query string) is never touched.
    """
    for scheme in ("postgresql://", "postgres://"):
        if url.startswith(scheme):
            return "postgresql+asyncpg://" + url[len(scheme):]
    return url


# Async-driver form of the configured database URL, used by the engine below.
database_url = _to_asyncpg_url(settings.database_url)
|
|
# Process-wide async engine. Every pool knob is read from ``settings`` so it
# can be tuned per deployment without code changes.
engine = create_async_engine(
    database_url,
    echo=False,  # SQL statement logging disabled
    # --- connection-pool sizing and lifecycle ---
    pool_size=settings.db_pool_size,
    max_overflow=settings.db_max_overflow,
    pool_timeout=settings.db_pool_timeout,
    pool_recycle=settings.db_pool_recycle,
    # Test each pooled connection on checkout so stale ones are replaced.
    pool_pre_ping=True,
    # --- driver-level statement caching ---
    # ``statement_cache_size`` is an asyncpg connect() option;
    # ``prepared_statement_cache_size`` is read by SQLAlchemy's asyncpg adapter.
    # NOTE(review): these are often set to 0 behind PgBouncer in
    # transaction-pooling mode -- confirm the intended deployment.
    connect_args=dict(
        statement_cache_size=settings.db_statement_cache_size,
        prepared_statement_cache_size=settings.db_prepared_statement_cache_size,
    ),
)
|
|
# Factory producing AsyncSession objects bound to ``engine``.
#   expire_on_commit=False -> loaded attributes are not expired by commit(), so
#                             objects stay readable afterwards without an
#                             implicit refresh (which async code cannot do lazily).
#   autoflush=False        -> pending changes are not flushed automatically
#                             before queries; flush()/commit() sends them.
# (The former ``autocommit=False`` argument was dropped: on SQLAlchemy 2.0 it
#  is the only accepted value and has no effect.)
async_session_factory = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autoflush=False,
    expire_on_commit=False,
)
|
|
|
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield one ``AsyncSession`` and finalize its transaction afterwards.

    Written as an async generator, presumably for use as a framework
    dependency (e.g. FastAPI ``Depends``) -- TODO confirm caller.

    * Consumer finishes normally -> the session is committed.
    * Consumer raises an ``Exception`` (or the commit itself fails) -> the
      session is rolled back and the exception re-raised.
    * In every case the session is closed when the ``async with`` exits.
    """
    async with async_session_factory() as db:
        try:
            yield db
            # Reached only when the consumer did not raise.
            await db.commit()
        except Exception:
            # Undo partial work, then let the original error propagate.
            await db.rollback()
            raise
|
|
|
|
@asynccontextmanager
async def get_db_context() -> AsyncGenerator[AsyncSession, None]:
    """Provide an ``AsyncSession`` as an ``async with`` context manager.

    Mirrors ``get_db`` for callers that want a context manager instead of a
    generator (e.g. background jobs or scripts -- assumption, confirm)::

        async with get_db_context() as session:
            ...

    A clean exit from the block commits. An ``Exception`` raised inside the
    block (or by the commit) triggers a rollback and is re-raised. The session
    is always closed on exit.
    """
    async with async_session_factory() as sess:
        try:
            yield sess
            await sess.commit()
        except Exception:
            await sess.rollback()
            raise
|
|
|
|
async def init_db() -> None:
    """Create every table registered on ``Base.metadata`` that does not yet exist.

    ``MetaData.create_all`` is a synchronous API, so it is executed through
    ``AsyncConnection.run_sync`` inside ``engine.begin()``, which commits on
    success. Existing tables are left as they are (no schema migration).
    """
    # Deferred import -- presumably to avoid a circular import between the
    # models module and this one (TODO confirm).
    from Backend.database.models import Base

    async with engine.begin() as connection:
        create_tables = Base.metadata.create_all
        await connection.run_sync(create_tables)
|
|
|
|
async def close_db() -> None:
    """Dispose of ``engine``'s connection pool (presumably at application shutdown).

    ``close=True`` is SQLAlchemy's default, spelled out here to make explicit
    that checked-in pooled connections are actually closed.
    """
    await engine.dispose(close=True)
|
|