File size: 1,675 Bytes
47203d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
from urllib.parse import urlparse, urlencode, parse_qs, urlunparse

from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

# Module-wide checkpointer singleton; set by init_checkpointer(), None until then.
_checkpointer: AsyncPostgresSaver | None = None
# Connection pool backing _checkpointer; set by init_checkpointer() and intended
# to stay open for the process lifetime.
# NOTE(review): nothing visible in this file closes it on shutdown.
_pool: AsyncConnectionPool | None = None


def _psycopg_conn_string(url: str) -> str:
    """
    Strip parameters psycopg doesn't support (e.g. channel_binding)
    and ensure sslmode is set to require.
    """
    parsed = urlparse(url)
    params = parse_qs(parsed.query)
    # Remove unsupported params
    params.pop("channel_binding", None)
    # Ensure SSL
    if "sslmode" not in params:
        params["sslmode"] = ["require"]
    new_query = urlencode({k: v[0] for k, v in params.items()})
    clean = urlunparse(parsed._replace(query=new_query))
    return clean


async def init_checkpointer() -> AsyncPostgresSaver:
    """
    Initialize AsyncPostgresSaver backed by a psycopg connection pool.

    Call once at app startup; the pool stays open for the process lifetime.
    If a checkpointer already exists, it is returned as-is rather than
    opening (and leaking) a second pool.

    The connection URL comes from the ``NEON_DB_URL`` environment variable
    and is normalized with ``_psycopg_conn_string``.

    Returns:
        The initialized checkpointer. It is also stored module-wide for
        ``get_checkpointer()``.

    Raises:
        RuntimeError: If ``NEON_DB_URL`` is unset or empty.
        Exception: Any error from opening the pool or from ``setup()`` is
            re-raised after the pool has been closed. Module state is left
            uninitialized in that case.
    """
    global _checkpointer, _pool
    if _checkpointer is not None:
        return _checkpointer

    raw_url = os.getenv("NEON_DB_URL")
    if not raw_url:
        raise RuntimeError("NEON_DB_URL is not set — cannot initialize checkpointer")
    conn_string = _psycopg_conn_string(raw_url)

    pool = AsyncConnectionPool(
        conn_string,
        max_size=5,
        # Connection kwargs follow LangGraph's Postgres saver setup (autocommit
        # and dict_row). NOTE(review): in psycopg, prepare_threshold=0 prepares
        # every statement on first use; confirm this is intended behind Neon's pooler.
        kwargs={"autocommit": True, "prepare_threshold": 0, "row_factory": dict_row},
        open=False,
    )
    try:
        await pool.open()
        checkpointer = AsyncPostgresSaver(pool)
        await checkpointer.setup()
    except Exception:
        # Don't leave an open pool behind when startup fails.
        await pool.close()
        raise

    # Publish module state only after everything succeeded.
    _pool = pool
    _checkpointer = checkpointer
    return checkpointer


def get_checkpointer() -> AsyncPostgresSaver:
    """
    Return the module-wide checkpointer created by ``init_checkpointer()``.

    Raises:
        RuntimeError: If ``init_checkpointer()`` has not completed yet.
    """
    current = _checkpointer
    if current is not None:
        return current
    raise RuntimeError("Checkpointer not initialized — call init_checkpointer() first")