File size: 6,722 Bytes
a8358d8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 | """
SQLite-backed grid cache with risk-adaptive TTL.
Design notes
------------
* WAL journal mode lets concurrent reads proceed during writes — critical
for FastAPI's async I/O. Default rollback-journal mode would serialise
every reader behind a writer.
* All blocking sqlite3 calls are wrapped in `asyncio.to_thread` so they
never stall the event loop.
* Cache key quantises (lat, lon) to a fixed grid resolution (~1.1 km).
Without quantisation, floating-point jitter destroys hit rate.
"""
from __future__ import annotations
import asyncio
import json
import sqlite3
import time
from pathlib import Path
from typing import Any
from . import config
# Schema script, run by _init_blocking via executescript(); every statement is
# guarded by IF NOT EXISTS, so re-running it is harmless.
# grid_cache    -- one row per quantised cell key (see _grid_key). `payload` is
#                  the JSON-encoded result (json.dumps in _set_blocking) and
#                  `expires_at` an absolute epoch-second deadline. Pruning and
#                  stats both filter on expires_at, hence idx_expires.
# inference_log -- rows inserted by _log_blocking (veto stored as 0/1, ts in
#                  epoch seconds) and deleted by _prune_blocking once older
#                  than the retention window; ts is indexed for that DELETE.
_INIT_SQL = """
CREATE TABLE IF NOT EXISTS grid_cache (
grid_key TEXT PRIMARY KEY,
payload TEXT NOT NULL,
expires_at INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_expires ON grid_cache(expires_at);
CREATE TABLE IF NOT EXISTS inference_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ts INTEGER NOT NULL,
lat REAL NOT NULL,
lon REAL NOT NULL,
risk INTEGER NOT NULL,
veto INTEGER NOT NULL,
summary TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_log_ts ON inference_log(ts);
"""
# Inference-log retention, in days: _prune_blocking deletes inference_log rows
# whose ts is older than now minus this many days.
# NOTE(review): pruning is presumably triggered on startup via prune_expired();
# the call site is not in this module -- confirm where it is scheduled.
INFERENCE_LOG_RETENTION_DAYS = 7
def _grid_key(lat: float, lon: float, activity: str = "general") -> str:
    """Build the cache key for the grid cell containing (lat, lon).

    Each coordinate is divided by ``config.GRID_RESOLUTION_DEG`` and rounded
    to an integer cell index (Python's ``round``: ties go to the even index),
    so nearby points that land in the same cell share a key. The activity
    name is appended so different activities never share an entry.
    """
    step = config.GRID_RESOLUTION_DEG
    lat_cell = round(lat / step)
    lon_cell = round(lon / step)
    return f"{lat_cell}:{lon_cell}:{activity}"
def _connect(db_path: Path) -> sqlite3.Connection:
conn = sqlite3.connect(db_path, timeout=5.0, isolation_level=None)
conn.execute("PRAGMA journal_mode=WAL;")
conn.execute("PRAGMA synchronous=NORMAL;")
conn.execute("PRAGMA busy_timeout=5000;")
return conn
def _init_blocking(db_path: Path) -> None:
    """Apply the ``_INIT_SQL`` schema script to *db_path*.

    Opens a dedicated connection (which also sets the WAL/sync PRAGMAs), runs
    the script, and always closes the connection afterwards.
    """
    connection = _connect(db_path)
    try:
        schema_script = _INIT_SQL
        connection.executescript(schema_script)
    finally:
        connection.close()
async def init_db(db_path: Path | None = None) -> None:
    """Create tables/indexes and switch the database to WAL. Idempotent.

    Args:
        db_path: database file. ``None`` (the default) means
            ``config.DB_PATH`` looked up at call time, the same way
            ``get``/``set``/``log_inference`` resolve it. This avoids binding
            a stale path at import time.
    """
    if db_path is None:
        db_path = config.DB_PATH
    await asyncio.to_thread(_init_blocking, db_path)
def _get_blocking(db_path: Path, key: str) -> tuple[dict[str, Any], int] | None:
    """Fetch the live cache entry stored under *key*.

    Returns:
        ``(payload, ttl_remaining)`` where *payload* is the decoded JSON and
        *ttl_remaining* is the number of whole seconds until expiry (always
        >= 1), or ``None`` if the key is missing or already expired. Expired
        rows are not deleted here; ``_prune_blocking`` removes them.
    """
    conn = _connect(db_path)
    try:
        # Read the clock once so the expiry test and the reported TTL agree;
        # two separate reads could straddle a second boundary and report a
        # "hit" with 0 seconds left.
        now = int(time.time())
        row = conn.execute(
            "SELECT payload, expires_at FROM grid_cache "
            "WHERE grid_key=? AND expires_at > ?",
            (key, now),
        ).fetchone()
    finally:
        conn.close()
    if row is None:
        return None
    payload, expires_at = row
    return json.loads(payload), expires_at - now
async def get(lat: float, lon: float, *, activity: str = "general") -> tuple[dict[str, Any], int] | None:
    """Look up the cached result for the grid cell containing (lat, lon).

    The SQLite read runs in a worker thread. Returns
    ``(payload, seconds_until_expiry)`` on a hit, or ``None`` on a miss or an
    expired entry.
    """
    db = config.DB_PATH
    cell_key = _grid_key(lat, lon, activity)
    return await asyncio.to_thread(_get_blocking, db, cell_key)
def _set_blocking(db_path: Path, key: str, payload: dict[str, Any], ttl_sec: int) -> None:
    """Upsert *payload* (JSON-encoded) under *key*.

    The entry expires *ttl_sec* seconds from now. ``INSERT OR REPLACE``
    overwrites any existing row with the same key.
    """
    connection = _connect(db_path)
    try:
        encoded = json.dumps(payload)
        deadline = int(time.time()) + ttl_sec
        connection.execute(
            "INSERT OR REPLACE INTO grid_cache(grid_key, payload, expires_at) "
            "VALUES (?, ?, ?)",
            (key, encoded, deadline),
        )
    finally:
        connection.close()
async def set(lat: float, lon: float, payload: dict[str, Any], ttl_sec: int,
              *, activity: str = "general") -> None:
    """Cache *payload* for the grid cell containing (lat, lon).

    The entry lives for *ttl_sec* seconds, and the SQLite write runs in a
    worker thread. Any existing entry for the same cell and activity is
    replaced.

    NOTE: this public name shadows the builtin ``set`` inside this module.
    """
    db = config.DB_PATH
    cell_key = _grid_key(lat, lon, activity)
    await asyncio.to_thread(_set_blocking, db, cell_key, payload, ttl_sec)
def adaptive_ttl(risk_score: int, has_veto: bool) -> int:
    """Choose a cache TTL (seconds) from the risk assessment.

    Higher risk → shorter TTL. We must not serve stale 'Safe' results while
    severe weather is developing.

    * veto present, or risk_score >= 70 -> ``config.TTL_HIGH_RISK_SEC``
    * 40 <= risk_score < 70             -> ``config.TTL_MID_RISK_SEC``
    * otherwise                         -> ``config.TTL_LOW_RISK_SEC``
    """
    is_high_risk = has_veto or risk_score >= 70
    if is_high_risk:
        return config.TTL_HIGH_RISK_SEC
    return config.TTL_MID_RISK_SEC if risk_score >= 40 else config.TTL_LOW_RISK_SEC
def _log_blocking(db_path: Path, lat: float, lon: float, risk: int,
                  veto: bool, summary: str) -> None:
    """Insert one ``inference_log`` row stamped with the current epoch second.

    *veto* is stored as the integer 0 or 1.
    """
    connection = _connect(db_path)
    try:
        record = (int(time.time()), lat, lon, risk, int(veto), summary)
        connection.execute(
            "INSERT INTO inference_log(ts, lat, lon, risk, veto, summary) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            record,
        )
    finally:
        connection.close()
async def log_inference(lat: float, lon: float, risk: int,
                        veto: bool, summary: str) -> None:
    """Append an inference result to ``inference_log``.

    The insert runs in a worker thread so it does not block the event loop.
    """
    db = config.DB_PATH
    await asyncio.to_thread(_log_blocking, db, lat, lon, risk, veto, summary)
# --------------------------------------------------------------------------
# GC / introspection
# --------------------------------------------------------------------------
def _prune_blocking(db_path: Path) -> int:
    """Garbage-collect both tables and return how many rows were deleted.

    * grid_cache: rows whose ``expires_at`` is at or before now.
    * inference_log: rows whose ``ts`` is older than
      ``INFERENCE_LOG_RETENTION_DAYS`` days before now.
    """
    seconds_per_day = 86_400
    now = int(time.time())
    oldest_kept_ts = now - INFERENCE_LOG_RETENTION_DAYS * seconds_per_day
    connection = _connect(db_path)
    try:
        expired_rows = connection.execute(
            "DELETE FROM grid_cache WHERE expires_at <= ?", (now,)
        ).rowcount
        stale_log_rows = connection.execute(
            "DELETE FROM inference_log WHERE ts < ?", (oldest_kept_ts,)
        ).rowcount
    finally:
        connection.close()
    return sum(int(count or 0) for count in (expired_rows, stale_log_rows))
async def prune_expired(db_path: Path | None = None) -> int:
    """Run cache GC in a worker thread.

    Deletes expired ``grid_cache`` rows and ``inference_log`` rows older than
    ``INFERENCE_LOG_RETENTION_DAYS``.

    Args:
        db_path: database file. ``None`` (the default) means
            ``config.DB_PATH`` looked up at call time, the same way
            ``get``/``set`` resolve it, rather than a path frozen at import.

    Returns:
        Number of rows removed across both tables.
    """
    if db_path is None:
        db_path = config.DB_PATH
    return await asyncio.to_thread(_prune_blocking, db_path)
def _stats_blocking(db_path: Path) -> dict[str, Any]:
    """Snapshot row counts and database size.

    Returns a dict with:
        rows_total          -- all grid_cache rows
        rows_live           -- grid_cache rows with expires_at > now
        rows_expired        -- rows_total - rows_live (expired, not yet pruned)
        inference_log_rows  -- all inference_log rows
        db_bytes            -- PRAGMA page_size * PRAGMA page_count
    """
    now = int(time.time())
    conn = _connect(db_path)
    try:
        # One statement for both counts. Two separate autocommit SELECTs can
        # see different table states (an insert in between can push rows_live
        # above rows_total, making rows_expired negative).
        total, live = conn.execute(
            "SELECT COUNT(*), COUNT(CASE WHEN expires_at > ? THEN 1 END) "
            "FROM grid_cache",
            (now,),
        ).fetchone()
        logged = conn.execute("SELECT COUNT(*) FROM inference_log").fetchone()[0]
        page_size = conn.execute("PRAGMA page_size").fetchone()[0]
        page_count = conn.execute("PRAGMA page_count").fetchone()[0]
    finally:
        conn.close()
    return {
        "rows_total": int(total),
        "rows_live": int(live),
        "rows_expired": int(total) - int(live),
        "inference_log_rows": int(logged),
        "db_bytes": int(page_size) * int(page_count),
    }
async def cache_stats(db_path: Path | None = None) -> dict[str, Any]:
    """Return cache/log row counts and database size, computed in a worker thread.

    Args:
        db_path: database file. ``None`` (the default) means
            ``config.DB_PATH`` looked up at call time, consistent with
            ``get``/``set`` instead of a path bound at import time.
    """
    if db_path is None:
        db_path = config.DB_PATH
    return await asyncio.to_thread(_stats_blocking, db_path)
|