# YT-AI-Automation / backend/src/utils/cache_manager.py
"""Response caching for AI requests with TTL and size limits.
Thread-safety
~~~~~~~~~~~~~
A single ``threading.RLock`` guards both the in-memory ``self.cache`` dict
and the on-disk JSON. Without it, two concurrent SSE workers calling
``set()`` could lose entries (last writer wins on the JSON file), and
``get()`` -> ``last_accessed`` updates could race with other writers.
Hot-path I/O
~~~~~~~~~~~~
The previous implementation rewrote the whole JSON file on every cache
*hit* just to update ``last_accessed``. That's an unbounded write rate
under load. We now batch those updates and only flush to disk every
``LRU_FLUSH_INTERVAL`` seconds (or when ``set()`` mutates state). Worst
case after a crash: a hit's ``last_accessed`` is up to 30 s stale, which
is harmless for LRU eviction.
"""
import hashlib
import json
import os
import threading
import time
from pathlib import Path
# How long we may delay flushing access-time updates for cache hits.
LRU_FLUSH_INTERVAL = 30.0


class CacheManager:
    """Manage cached AI responses with TTL expiration and LRU eviction.

    A single ``threading.RLock`` serializes every access to the in-memory
    ``self.cache`` dict and the on-disk JSON file, so one instance may be
    shared across threads.

    Args:
        cache_dir: Directory holding ``ai_responses.json`` (created if
            missing).
        max_entries: Soft cap on stored entries; least-recently-accessed
            entries are evicted once exceeded.
        ttl_days: Entry lifetime in days; expired entries are dropped
            lazily when ``get()`` touches them.
    """

    def __init__(self, cache_dir="output/cache", max_entries=100, ttl_days=7):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.cache_file = self.cache_dir / "ai_responses.json"
        self.max_entries = max_entries
        self.ttl_seconds = ttl_days * 86400
        # RLock (not Lock): public methods call private *_locked helpers,
        # and re-entrancy keeps that composition deadlock-free.
        self._lock = threading.RLock()
        self._dirty_since_flush = False
        self._last_flush = 0.0
        self.cache = self._load_cache()

    def _load_cache(self):
        """Load the cache dict from disk; start empty on any failure."""
        if self.cache_file.exists():
            try:
                with open(self.cache_file, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except Exception as e:
                # Best-effort: a corrupt cache file must not break startup.
                print(f"Warning: Could not load cache: {e}")
                return {}
        return {}

    def _save_cache_locked(self, force: bool = False):
        """Save cache to disk while holding the lock.

        Uses an atomic ``write tmp + rename`` so a crash mid-write doesn't
        leave a half-written ``ai_responses.json``. With ``force=False``
        skips the write if no state has changed since the last flush.
        """
        if not force and not self._dirty_since_flush:
            return
        tmp_path = self.cache_file.with_suffix('.json.tmp')
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.cache, f, ensure_ascii=False, indent=2)
            # os.replace is atomic on POSIX and Windows alike.
            os.replace(tmp_path, self.cache_file)
            self._dirty_since_flush = False
            self._last_flush = time.time()
        except Exception as e:
            print(f"Warning: Could not save cache: {e}")
            # Fix: don't leave a stale temp file behind after a failed dump.
            try:
                if tmp_path.exists():
                    tmp_path.unlink()
            except OSError:
                pass

    def _generate_key(self, text):
        """Generate cache key from input text.

        ``text`` here is treated as the full composite key (model | prompt
        | user text) — see ``ai_client.generate_response``.
        """
        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    def _is_expired(self, entry):
        """Check if a cache entry has expired based on TTL."""
        if 'created_at' not in entry:
            return False  # Legacy entries without timestamp are kept
        return (time.time() - entry['created_at']) > self.ttl_seconds

    def _evict_if_needed_locked(self):
        """Evict oldest entries if cache exceeds max_entries (LRU)."""
        if len(self.cache) <= self.max_entries:
            return
        # Oldest-accessed first; legacy entries fall back to created_at / 0.
        sorted_keys = sorted(
            self.cache.keys(),
            key=lambda k: self.cache[k].get(
                'last_accessed', self.cache[k].get('created_at', 0)
            ),
        )
        entries_to_remove = len(self.cache) - self.max_entries
        for key in sorted_keys[:entries_to_remove]:
            del self.cache[key]
            self._dirty_since_flush = True
            print(f"✓ Cache evicted entry (LRU): {key[:8]}...")

    def get(self, text):
        """Return the cached response for *text*, or None on miss/expiry."""
        key = self._generate_key(text)
        with self._lock:
            entry = self.cache.get(key)
            if entry is None:
                return None
            if self._is_expired(entry):
                del self.cache[key]
                self._dirty_since_flush = True
                self._save_cache_locked(force=True)
                print(f"✗ Cache entry expired: {key[:8]}...")
                return None
            entry['last_accessed'] = time.time()
            self._dirty_since_flush = True
            # Throttle flushes for hits — see module docstring.
            if time.time() - self._last_flush > LRU_FLUSH_INTERVAL:
                self._save_cache_locked()
            # Fix: legacy on-disk entries may predate the 'timestamp' field
            # (cf. _is_expired's legacy tolerance) — don't KeyError a hit
            # over a log message.
            print(
                f"✓ Cache hit! Using cached response from "
                f"{entry.get('timestamp', 'unknown')}"
            )
            return entry['response']

    def set(self, text, response):
        """Cache a response with timestamp tracking."""
        key = self._generate_key(text)
        now = time.time()
        with self._lock:
            self.cache[key] = {
                'response': response,
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                'created_at': now,
                'last_accessed': now,
                'text_length': len(text),
            }
            self._dirty_since_flush = True
            self._evict_if_needed_locked()
            self._save_cache_locked(force=True)
            print(f"✓ Response cached (key: {key[:8]}...)")

    def flush(self):
        """Force a flush of any pending access-time updates to disk."""
        with self._lock:
            self._save_cache_locked(force=True)

    def clear(self):
        """Clear all cached responses."""
        with self._lock:
            self.cache = {}
            self._dirty_since_flush = True
            self._save_cache_locked(force=True)
            print("✓ Cache cleared")

    def get_stats(self):
        """Get cache statistics (snapshot — safe to call from any thread)."""
        with self._lock:
            expired_count = sum(
                1 for entry in self.cache.values() if self._is_expired(entry)
            )
            total = len(self.cache)
            cache_size = (
                self.cache_file.stat().st_size if self.cache_file.exists() else 0
            )
            return {
                'total_entries': total,
                'expired_entries': expired_count,
                'active_entries': total - expired_count,
                'max_entries': self.max_entries,
                'ttl_days': self.ttl_seconds / 86400,
                'cache_file': str(self.cache_file),
                'cache_size_kb': round(cache_size / 1024, 2),
            }