| from contextlib import asynccontextmanager |
| from typing import AsyncGenerator |
| from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine |
| from sqlalchemy.pool import NullPool |
|
|
| from Backend.core.config import settings |
|
|
def _as_asyncpg_url(url: str) -> str:
    """Return *url* with its scheme switched to the asyncpg driver.

    Only a leading ``postgresql://`` or ``postgres://`` scheme is rewritten
    to ``postgresql+asyncpg://``; the remainder of the URL is left untouched.
    Any other URL (including one that already names a driver, such as
    ``postgresql+asyncpg://``) is returned unchanged.
    """
    for scheme in ("postgresql://", "postgres://"):
        if url.startswith(scheme):
            # Anchor on the prefix so text later in the URL is never rewritten.
            return "postgresql+asyncpg://" + url[len(scheme):]
    return url


# Connection URL handed to create_async_engine below.
database_url = _as_asyncpg_url(settings.database_url)
|
|
# Module-level async engine shared by the session factory, init_db and close_db.
engine = create_async_engine(
    database_url,
    echo=False,
    # NullPool: SQLAlchemy keeps no connections open between checkouts.
    poolclass=NullPool,
    # Both statement caches are sized to zero.
    # NOTE(review): presumably required by a transaction-pooling proxy such
    # as PgBouncer in front of Postgres — confirm against the deployment.
    connect_args=dict(
        statement_cache_size=0,
        prepared_statement_cache_size=0,
    ),
)
|
|
# Factory producing AsyncSession objects bound to the module-level engine.
# Sessions do not autoflush, and loaded attributes are not expired on commit.
async_session_factory = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)
|
|
|
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session wrapped in a commit-or-rollback scope.

    A session is opened from ``async_session_factory`` and yielded once.
    When the consumer resumes the generator normally the session is
    committed. If an ``Exception`` is raised at the yield point or by the
    commit itself, the session is rolled back and the exception re-raised.
    The surrounding ``async with`` exits the session in every case.

    Yields:
        The open ``AsyncSession``.
    """
    async with async_session_factory() as db_session:
        try:
            yield db_session
            # Reached only when the consumer finished without raising.
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
|
|
|
|
@asynccontextmanager
async def get_db_context() -> AsyncGenerator[AsyncSession, None]:
    """Async context manager providing a commit-or-rollback session scope.

    Usage: ``async with get_db_context() as session: ...``. The session is
    committed when the ``with`` body completes normally. If the body or the
    commit raises an ``Exception``, the session is rolled back and the
    exception propagates to the caller.

    Yields:
        The open ``AsyncSession`` created by ``async_session_factory``.
    """
    async with async_session_factory() as db_session:
        try:
            yield db_session
            # The with-body finished cleanly; persist its work.
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
|
|
|
|
async def init_db() -> None:
    """Create the tables registered on ``Base.metadata``.

    Opens a transaction on the module-level engine with ``engine.begin()``
    and runs the synchronous ``create_all`` through ``run_sync``, using
    ``create_all``'s default arguments.
    """
    # Imported at call time rather than module import time.
    # NOTE(review): presumably to avoid a circular import — confirm.
    from Backend.database.models import Base

    async with engine.begin() as connection:
        create_all_tables = Base.metadata.create_all
        await connection.run_sync(create_all_tables)
|
|
|
|
async def close_db() -> None:
    """Dispose of the module-level engine by awaiting ``engine.dispose()``."""
    await engine.dispose()
|
|