microclimate-x / backend /cache.py
W1nd5pac's picture
Deploy 2026-05-20T06:52:08Z β€” 11e81c5
4eefabb verified
"""
SQLite-backed grid cache with risk-adaptive TTL.
Design notes
------------
* WAL journal mode lets concurrent reads proceed during writes β€” critical
for FastAPI's async I/O. Default rollback-journal mode would serialise
every reader behind a writer.
* All blocking sqlite3 calls are wrapped in `asyncio.to_thread` so they
never stall the event loop.
* Cache key quantises (lat, lon) to a fixed grid resolution (~1.1 km).
Without quantisation, floating-point jitter destroys hit rate.
"""
from __future__ import annotations
import asyncio
import json
import sqlite3
import time
from pathlib import Path
from typing import Any
from . import config
_INIT_SQL = """
CREATE TABLE IF NOT EXISTS grid_cache (
grid_key TEXT PRIMARY KEY,
payload TEXT NOT NULL,
expires_at INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_expires ON grid_cache(expires_at);
CREATE TABLE IF NOT EXISTS inference_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ts INTEGER NOT NULL,
lat REAL NOT NULL,
lon REAL NOT NULL,
risk INTEGER NOT NULL,
veto INTEGER NOT NULL,
summary TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_log_ts ON inference_log(ts);
"""
# Inference-log retention β€” older rows are pruned on startup.
INFERENCE_LOG_RETENTION_DAYS = 7
def _grid_key(lat: float, lon: float, activity: str = "general") -> str:
res = config.GRID_RESOLUTION_DEG
return f"{round(lat / res)}:{round(lon / res)}:{activity}"
def _connect(db_path: Path) -> sqlite3.Connection:
conn = sqlite3.connect(db_path, timeout=5.0, isolation_level=None)
conn.execute("PRAGMA journal_mode=WAL;")
conn.execute("PRAGMA synchronous=NORMAL;")
conn.execute("PRAGMA busy_timeout=5000;")
return conn
def _init_blocking(db_path: Path) -> None:
conn = _connect(db_path)
try:
conn.executescript(_INIT_SQL)
finally:
conn.close()
async def init_db(db_path: Path = config.DB_PATH) -> None:
"""Create tables and switch to WAL. Idempotent."""
await asyncio.to_thread(_init_blocking, db_path)
def _get_blocking(db_path: Path, key: str) -> tuple[dict[str, Any], int] | None:
conn = _connect(db_path)
try:
row = conn.execute(
"SELECT payload, expires_at FROM grid_cache WHERE grid_key=?",
(key,),
).fetchone()
if row is None:
return None
payload, expires_at = row
if expires_at <= int(time.time()):
return None
ttl_remaining = expires_at - int(time.time())
return json.loads(payload), ttl_remaining
finally:
conn.close()
async def get(lat: float, lon: float, *, activity: str = "general") -> tuple[dict[str, Any], int] | None:
return await asyncio.to_thread(_get_blocking, config.DB_PATH, _grid_key(lat, lon, activity))
def _set_blocking(db_path: Path, key: str, payload: dict[str, Any], ttl_sec: int) -> None:
conn = _connect(db_path)
try:
conn.execute(
"INSERT OR REPLACE INTO grid_cache(grid_key, payload, expires_at) "
"VALUES (?, ?, ?)",
(key, json.dumps(payload), int(time.time()) + ttl_sec),
)
finally:
conn.close()
async def set(lat: float, lon: float, payload: dict[str, Any], ttl_sec: int,
*, activity: str = "general") -> None:
await asyncio.to_thread(_set_blocking, config.DB_PATH, _grid_key(lat, lon, activity),
payload, ttl_sec)
def adaptive_ttl(risk_score: int, has_veto: bool) -> int:
"""Higher risk β†’ shorter TTL. We must not serve stale 'Safe' results
while severe weather is developing."""
if has_veto or risk_score >= 70:
return config.TTL_HIGH_RISK_SEC
if risk_score >= 40:
return config.TTL_MID_RISK_SEC
return config.TTL_LOW_RISK_SEC
def _log_blocking(db_path: Path, lat: float, lon: float, risk: int,
veto: bool, summary: str) -> None:
conn = _connect(db_path)
try:
conn.execute(
"INSERT INTO inference_log(ts, lat, lon, risk, veto, summary) "
"VALUES (?, ?, ?, ?, ?, ?)",
(int(time.time()), lat, lon, risk, int(veto), summary),
)
finally:
conn.close()
async def log_inference(lat: float, lon: float, risk: int,
veto: bool, summary: str) -> None:
await asyncio.to_thread(_log_blocking, config.DB_PATH, lat, lon,
risk, veto, summary)
# ──────────────────────────────────────────────────────────────────────────
# GC / introspection
# ──────────────────────────────────────────────────────────────────────────
def _prune_blocking(db_path: Path) -> int:
"""Delete expired cache rows + old inference_log rows. Returns total deleted."""
now = int(time.time())
log_cutoff = now - INFERENCE_LOG_RETENTION_DAYS * 86_400
conn = _connect(db_path)
try:
c1 = conn.execute("DELETE FROM grid_cache WHERE expires_at <= ?", (now,)).rowcount
c2 = conn.execute("DELETE FROM inference_log WHERE ts < ?", (log_cutoff,)).rowcount
return int(c1 or 0) + int(c2 or 0)
finally:
conn.close()
async def prune_expired(db_path: Path = config.DB_PATH) -> int:
"""Run cache GC. Returns number of rows removed across both tables."""
return await asyncio.to_thread(_prune_blocking, db_path)
def _stats_blocking(db_path: Path) -> dict[str, Any]:
now = int(time.time())
conn = _connect(db_path)
try:
total = conn.execute("SELECT COUNT(*) FROM grid_cache").fetchone()[0]
live = conn.execute(
"SELECT COUNT(*) FROM grid_cache WHERE expires_at > ?",
(now,),
).fetchone()[0]
logged = conn.execute("SELECT COUNT(*) FROM inference_log").fetchone()[0]
page_size = conn.execute("PRAGMA page_size").fetchone()[0]
page_count = conn.execute("PRAGMA page_count").fetchone()[0]
return {
"rows_total": int(total),
"rows_live": int(live),
"rows_expired": int(total) - int(live),
"inference_log_rows": int(logged),
"db_bytes": int(page_size) * int(page_count),
}
finally:
conn.close()
async def cache_stats(db_path: Path = config.DB_PATH) -> dict[str, Any]:
return await asyncio.to_thread(_stats_blocking, db_path)