"""Response caching for AI requests with TTL and size limits.
Thread-safety
~~~~~~~~~~~~~
A single ``threading.RLock`` guards both the in-memory ``self.cache`` dict
and the on-disk JSON. Without it, two concurrent SSE workers calling
``set()`` could lose entries (last writer wins on the JSON file), and
``get()`` -> ``last_accessed`` updates could race with other writers.
Hot-path I/O
~~~~~~~~~~~~
The previous implementation rewrote the whole JSON file on every cache
*hit* just to update ``last_accessed``. That's an unbounded write rate
under load. We now batch those updates and only flush to disk every
``LRU_FLUSH_INTERVAL`` seconds (or when ``set()`` mutates state). Worst
case after a crash: a hit's ``last_accessed`` is up to 30 s stale, which
is harmless for LRU eviction.
"""
import hashlib
import json
import os
import threading
import time
from pathlib import Path
# Maximum seconds an in-memory ``last_accessed`` update from a cache hit may
# stay unflushed before ``get()`` writes the JSON file to disk. Trades a
# little crash-staleness for a bounded write rate (see module docstring).
LRU_FLUSH_INTERVAL = 30.0
class CacheManager:
    """Manage cached AI responses with TTL expiration and LRU eviction.

    Entries live in ``self.cache`` (a dict keyed by SHA-256 of the input
    text) and are mirrored to ``<cache_dir>/ai_responses.json``. A single
    ``threading.RLock`` guards both the dict and the file; disk writes for
    cache *hits* are throttled (see ``flush_interval``) while mutations
    (``set``/expiry/``clear``) flush immediately.
    """

    def __init__(self, cache_dir="output/cache", max_entries=100, ttl_days=7,
                 flush_interval=30.0):
        """Create a cache manager backed by a JSON file.

        Args:
            cache_dir: Directory for the cache file (created if missing).
            max_entries: LRU eviction kicks in above this entry count.
            ttl_days: Entries older than this are treated as expired.
            flush_interval: Max seconds a hit's ``last_accessed`` update may
                stay unflushed before ``get()`` writes to disk. Default
                matches the module-level ``LRU_FLUSH_INTERVAL`` (30 s).
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.cache_file = self.cache_dir / "ai_responses.json"
        self.max_entries = max_entries
        self.ttl_seconds = ttl_days * 86400
        self._flush_interval = flush_interval
        self._lock = threading.RLock()
        # True whenever in-memory state differs from the on-disk JSON.
        self._dirty_since_flush = False
        self._last_flush = 0.0
        self.cache = self._load_cache()

    def _load_cache(self):
        """Load the cache dict from disk; return an empty dict on any failure."""
        if not self.cache_file.exists():
            return {}
        try:
            with open(self.cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            print(f"Warning: Could not load cache: {e}")
            return {}
        if not isinstance(data, dict):
            # A non-dict payload (e.g. a JSON list) would parse fine here but
            # crash every later .get()/del on self.cache — start fresh instead.
            print("Warning: Could not load cache: file is not a JSON object")
            return {}
        return data

    def _save_cache_locked(self, force: bool = False):
        """Save cache to disk while holding the lock.

        Uses an atomic ``write tmp + rename`` so a crash mid-write doesn't
        leave a half-written ``ai_responses.json``. With ``force=False``
        skips the write if no state has changed since the last flush.
        """
        if not force and not self._dirty_since_flush:
            return
        tmp_path = self.cache_file.with_suffix('.json.tmp')
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.cache, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, self.cache_file)
            self._dirty_since_flush = False
            self._last_flush = time.time()
        except Exception as e:
            print(f"Warning: Could not save cache: {e}")
            # Best-effort cleanup so a failed dump doesn't strand a stale
            # .tmp file next to the real cache file.
            try:
                tmp_path.unlink()
            except OSError:
                pass

    def _generate_key(self, text):
        """Generate cache key from input text.

        ``text`` here is treated as the full composite key (model | prompt
        | user text) β see ``ai_client.generate_response``.
        """
        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    def _is_expired(self, entry):
        """Check if a cache entry has expired based on TTL."""
        if 'created_at' not in entry:
            return False  # Legacy entries without timestamp are kept
        age = time.time() - entry['created_at']
        return age > self.ttl_seconds

    def _evict_if_needed_locked(self):
        """Evict oldest entries if cache exceeds max_entries (LRU)."""
        if len(self.cache) <= self.max_entries:
            return
        # Oldest-first by last access; entries missing both timestamps sort
        # as 0 and are evicted first.
        sorted_keys = sorted(
            self.cache.keys(),
            key=lambda k: self.cache[k].get(
                'last_accessed', self.cache[k].get('created_at', 0)
            ),
        )
        entries_to_remove = len(self.cache) - self.max_entries
        for key in sorted_keys[:entries_to_remove]:
            del self.cache[key]
            self._dirty_since_flush = True
            print(f"β Cache evicted entry (LRU): {key[:8]}...")

    def get(self, text):
        """Get cached response if available and not expired; else None."""
        key = self._generate_key(text)
        with self._lock:
            entry = self.cache.get(key)
            if entry is None:
                return None
            if self._is_expired(entry):
                del self.cache[key]
                self._dirty_since_flush = True
                self._save_cache_locked(force=True)
                print(f"β Cache entry expired: {key[:8]}...")
                return None
            entry['last_accessed'] = time.time()
            self._dirty_since_flush = True
            # Throttle flushes for hits β see module docstring.
            if time.time() - self._last_flush > self._flush_interval:
                self._save_cache_locked()
            # .get(): legacy entries loaded from older cache files may lack
            # 'timestamp'; a plain [] lookup raised KeyError on a valid hit.
            print(f"β Cache hit! Using cached response from {entry.get('timestamp', 'unknown')}")
            return entry['response']

    def set(self, text, response):
        """Cache a response with timestamp tracking; flushes immediately."""
        key = self._generate_key(text)
        now = time.time()
        with self._lock:
            self.cache[key] = {
                'response': response,
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                'created_at': now,
                'last_accessed': now,
                'text_length': len(text),
            }
            self._dirty_since_flush = True
            self._evict_if_needed_locked()
            self._save_cache_locked(force=True)
            print(f"β Response cached (key: {key[:8]}...)")

    def flush(self):
        """Force a flush of any pending access-time updates to disk."""
        with self._lock:
            self._save_cache_locked(force=True)

    def clear(self):
        """Clear all cached responses (memory and disk)."""
        with self._lock:
            self.cache = {}
            self._dirty_since_flush = True
            self._save_cache_locked(force=True)
            print("β Cache cleared")

    def get_stats(self):
        """Get cache statistics (snapshot β safe to call from any thread)."""
        with self._lock:
            expired_count = sum(
                1 for entry in self.cache.values() if self._is_expired(entry)
            )
            total = len(self.cache)
            cache_size = (
                self.cache_file.stat().st_size if self.cache_file.exists() else 0
            )
            return {
                'total_entries': total,
                'expired_entries': expired_count,
                'active_entries': total - expired_count,
                'max_entries': self.max_entries,
                'ttl_days': self.ttl_seconds / 86400,
                'cache_file': str(self.cache_file),
                'cache_size_kb': round(cache_size / 1024, 2),
            }