File size: 6,722 Bytes
4eefabb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""
SQLite-backed grid cache with risk-adaptive TTL.

Design notes
------------
* WAL journal mode lets concurrent reads proceed during writes — critical
  for FastAPI's async I/O. Default rollback-journal mode would serialise
  every reader behind a writer.
* All blocking sqlite3 calls are wrapped in `asyncio.to_thread` so they
  never stall the event loop.
* Cache key quantises (lat, lon) to a fixed grid resolution (~1.1 km).
  Without quantisation, floating-point jitter destroys hit rate.
"""
from __future__ import annotations

import asyncio
import json
import sqlite3
import time
from pathlib import Path
from typing import Any

from . import config

# Schema applied by _init_blocking() through executescript(). Every statement
# uses IF NOT EXISTS, so re-running the script on an existing database is a
# no-op.
#
# grid_cache: one row per cache key built by _grid_key(); `payload` holds the
#   json.dumps() of the cached dict and `expires_at` is a Unix timestamp in
#   whole seconds. idx_expires backs the `expires_at <= ?` delete in
#   _prune_blocking().
# inference_log: rows appended by _log_blocking(); `ts` is a Unix timestamp in
#   whole seconds and `veto` is stored via int(). idx_log_ts backs the
#   retention delete in _prune_blocking().
_INIT_SQL = """
CREATE TABLE IF NOT EXISTS grid_cache (
    grid_key   TEXT PRIMARY KEY,
    payload    TEXT NOT NULL,
    expires_at INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_expires ON grid_cache(expires_at);

CREATE TABLE IF NOT EXISTS inference_log (
    id        INTEGER PRIMARY KEY AUTOINCREMENT,
    ts        INTEGER NOT NULL,
    lat       REAL NOT NULL,
    lon       REAL NOT NULL,
    risk      INTEGER NOT NULL,
    veto      INTEGER NOT NULL,
    summary   TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_log_ts ON inference_log(ts);
"""

# Inference-log retention, in days: _prune_blocking() deletes inference_log
# rows whose `ts` is older than this window. When pruning happens depends on
# the caller of prune_expired() (presumably at startup — confirm).
INFERENCE_LOG_RETENTION_DAYS: int = 7


def _grid_key(lat: float, lon: float, activity: str = "general") -> str:
    """Build the cache key of the grid cell that contains (lat, lon).

    Each coordinate is divided by ``config.GRID_RESOLUTION_DEG`` and rounded
    with the builtin ``round`` (ties go to the even neighbour), so points that
    differ only by floating-point jitter share a key. The activity is appended
    so different activities never share a cached entry.
    """
    step = config.GRID_RESOLUTION_DEG
    row, col = (round(coord / step) for coord in (lat, lon))
    return f"{row}:{col}:{activity}"


def _connect(db_path: Path) -> sqlite3.Connection:
    conn = sqlite3.connect(db_path, timeout=5.0, isolation_level=None)
    conn.execute("PRAGMA journal_mode=WAL;")
    conn.execute("PRAGMA synchronous=NORMAL;")
    conn.execute("PRAGMA busy_timeout=5000;")
    return conn


def _init_blocking(db_path: Path) -> None:
    """Synchronously run the ``_INIT_SQL`` schema script against *db_path*.

    The connection is closed whether or not the script succeeds.
    """
    db = _connect(db_path)
    try:
        db.executescript(_INIT_SQL)
    finally:
        db.close()


async def init_db(db_path: Path | None = None) -> None:
    """Create the cache tables/indexes and switch the file to WAL. Idempotent.

    Args:
        db_path: Database file to initialise. ``None`` (the default) means
            ``config.DB_PATH`` looked up at call time — the same way the
            cache read/write coroutines resolve it — so a DB_PATH changed
            after import (e.g. in tests) is honoured. Binding the default at
            definition time would freeze the import-time value instead.

    The blocking sqlite3 work runs in a worker thread via ``asyncio.to_thread``.
    """
    path = config.DB_PATH if db_path is None else db_path
    await asyncio.to_thread(_init_blocking, path)


def _get_blocking(db_path: Path, key: str) -> tuple[dict[str, Any], int] | None:
    """Look up *key* in grid_cache.

    Returns:
        ``(payload, ttl_remaining)`` for a live entry. *payload* is the
        ``json.loads`` of the stored text, and *ttl_remaining* is the number
        of whole seconds until expiry (always >= 1). Returns ``None`` on a
        miss or when the row has already expired. Expired rows are not
        deleted here; ``_prune_blocking`` removes them.
    """
    conn = _connect(db_path)
    try:
        row = conn.execute(
            "SELECT payload, expires_at FROM grid_cache WHERE grid_key=?",
            (key,),
        ).fetchone()
    finally:
        # Release the connection before decoding; the row is already fetched.
        conn.close()
    if row is None:
        return None
    payload, expires_at = row
    # Sample the clock exactly once so the freshness check and the reported
    # TTL agree; two separate reads could accept an entry and then report a
    # remaining TTL of 0.
    ttl_remaining = expires_at - int(time.time())
    if ttl_remaining <= 0:
        return None
    return json.loads(payload), ttl_remaining


async def get(lat: float, lon: float, *, activity: str = "general") -> tuple[dict[str, Any], int] | None:
    """Fetch the cached entry for the grid cell of (lat, lon) under *activity*.

    Resolves to ``(payload, ttl_remaining)`` on a live hit, else ``None``.
    ``config.DB_PATH`` is read on every call, and the sqlite3 lookup runs in
    a worker thread so the event loop is never blocked.
    """
    db = config.DB_PATH
    cell = _grid_key(lat, lon, activity)
    return await asyncio.to_thread(_get_blocking, db, cell)


def _set_blocking(db_path: Path, key: str, payload: dict[str, Any], ttl_sec: int) -> None:
    """Write *payload* (JSON-encoded) under *key*, expiring *ttl_sec* seconds
    from now. Any existing row with the same key is replaced outright."""
    db = _connect(db_path)
    try:
        encoded = json.dumps(payload)
        expires_at = int(time.time()) + ttl_sec
        db.execute(
            "INSERT OR REPLACE INTO grid_cache(grid_key, payload, expires_at) "
            "VALUES (?, ?, ?)",
            (key, encoded, expires_at),
        )
    finally:
        db.close()


async def set(lat: float, lon: float, payload: dict[str, Any], ttl_sec: int,
              *, activity: str = "general") -> None:
    """Cache *payload* for the grid cell of (lat, lon) under *activity* for
    *ttl_sec* seconds; the sqlite3 write runs in a worker thread.

    NOTE: this module-level name shadows the builtin ``set`` within this
    module; use ``builtins.set`` here if a real set is ever needed.
    """
    db = config.DB_PATH
    cell = _grid_key(lat, lon, activity)
    await asyncio.to_thread(_set_blocking, db, cell, payload, ttl_sec)


def adaptive_ttl(risk_score: int, has_veto: bool) -> int:
    """Choose a cache TTL (seconds) from a risk assessment.

    Intent: higher risk -> shorter TTL, so a stale 'Safe' result is never
    served while severe weather is developing.

    Tiers (checked in this order):
        * ``has_veto`` or ``risk_score >= 70``  -> ``config.TTL_HIGH_RISK_SEC``
        * ``risk_score >= 40``                  -> ``config.TTL_MID_RISK_SEC``
        * otherwise                             -> ``config.TTL_LOW_RISK_SEC``
    """
    # A veto always forces the high-risk tier, regardless of the score.
    if has_veto or risk_score >= 70:
        return config.TTL_HIGH_RISK_SEC
    if risk_score >= 40:
        return config.TTL_MID_RISK_SEC
    return config.TTL_LOW_RISK_SEC


def _log_blocking(db_path: Path, lat: float, lon: float, risk: int,
                  veto: bool, summary: str) -> None:
    """Append one inference_log row stamped with the current Unix time in
    whole seconds; *veto* is stored as ``int(veto)``."""
    db = _connect(db_path)
    try:
        record = (int(time.time()), lat, lon, risk, int(veto), summary)
        db.execute(
            "INSERT INTO inference_log(ts, lat, lon, risk, veto, summary) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            record,
        )
    finally:
        db.close()


async def log_inference(lat: float, lon: float, risk: int,
                        veto: bool, summary: str) -> None:
    """Record one inference result in inference_log at ``config.DB_PATH``;
    the insert runs in a worker thread so the event loop is not blocked."""
    db = config.DB_PATH
    await asyncio.to_thread(_log_blocking, db, lat, lon, risk, veto, summary)


# ──────────────────────────────────────────────────────────────────────────
# GC / introspection
# ──────────────────────────────────────────────────────────────────────────

def _prune_blocking(db_path: Path) -> int:
    """Remove expired grid_cache rows and inference_log rows older than
    ``INFERENCE_LOG_RETENTION_DAYS``; return the combined number deleted.

    The connection is in autocommit mode, so each DELETE commits on its own.
    """
    now = int(time.time())
    retention_sec = INFERENCE_LOG_RETENTION_DAYS * 86_400
    log_cutoff = now - retention_sec
    db = _connect(db_path)
    try:
        removed = 0
        for sql, bound in (
            ("DELETE FROM grid_cache WHERE expires_at <= ?", now),
            ("DELETE FROM inference_log WHERE ts < ?", log_cutoff),
        ):
            removed += int(db.execute(sql, (bound,)).rowcount or 0)
        return removed
    finally:
        db.close()


async def prune_expired(db_path: Path | None = None) -> int:
    """Run cache GC and return the number of rows removed across both tables.

    Args:
        db_path: Database to prune. ``None`` (the default) means
            ``config.DB_PATH`` looked up at call time, matching how the cache
            read/write coroutines resolve it; a definition-time default would
            freeze the import-time value.
    """
    path = config.DB_PATH if db_path is None else db_path
    return await asyncio.to_thread(_prune_blocking, path)


def _stats_blocking(db_path: Path) -> dict[str, Any]:
    """Collect row counts and size information for the cache database.

    Returns a dict with:
        rows_total          -- all grid_cache rows
        rows_live           -- grid_cache rows with ``expires_at`` after now
        rows_expired        -- rows_total - rows_live (never negative)
        inference_log_rows  -- all inference_log rows
        db_bytes            -- ``PRAGMA page_size * PRAGMA page_count``
    """
    now = int(time.time())
    conn = _connect(db_path)
    try:
        # Count total and live rows in ONE statement so both come from the
        # same snapshot. Two separate autocommit SELECTs could straddle a
        # concurrent insert and report live > total (negative rows_expired).
        total, live = conn.execute(
            "SELECT COUNT(*), COUNT(CASE WHEN expires_at > ? THEN 1 END) "
            "FROM grid_cache",
            (now,),
        ).fetchone()
        logged = conn.execute("SELECT COUNT(*) FROM inference_log").fetchone()[0]
        page_size = conn.execute("PRAGMA page_size").fetchone()[0]
        page_count = conn.execute("PRAGMA page_count").fetchone()[0]
    finally:
        conn.close()
    return {
        "rows_total":         int(total),
        "rows_live":          int(live),
        "rows_expired":       int(total) - int(live),
        "inference_log_rows": int(logged),
        "db_bytes":           int(page_size) * int(page_count),
    }


async def cache_stats(db_path: Path | None = None) -> dict[str, Any]:
    """Return cache statistics: total/live/expired grid_cache rows,
    inference_log row count, and database size in bytes.

    Args:
        db_path: Database to inspect. ``None`` (the default) means
            ``config.DB_PATH`` looked up at call time, matching how the cache
            read/write coroutines resolve it; a definition-time default would
            freeze the import-time value.
    """
    path = config.DB_PATH if db_path is None else db_path
    return await asyncio.to_thread(_stats_blocking, path)