# Pablo
# feat: APOHARA: Context Forge V5 — synthesis + rebrand complete
# cf0a8ed
"""VRAM-pressure-aware eviction cache - IMPROVEMENT-002.
Replaces static TTL-based eviction with adaptive LRU/LFU hybrid that responds
to actual GPU memory pressure. Monitors MI300X VRAM via PyRSMI and adjusts
eviction policy dynamically.
Eviction modes:
- RELAXED (VRAM < 70%): No eviction, TTL = 10 minutes
- NORMAL (70-85%): LRU eviction of entries idle > 2 min
- PRESSURE (85-92%): LFU by token_count, evict heaviest first
- CRITICAL (92-96%): Offload inactive KV tensors to CPU RAM
- EMERGENCY (VRAM >= 96%): Hard evict all idle > 30s, block new registrations
"""
import asyncio
import heapq
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional

# VRAMMonitor is instantiated at runtime in VRAMAwareCache.__init__, so it
# must be a real import; keeping it under TYPE_CHECKING raised NameError.
from apohara_context_forge.metrics.vram_monitor import VRAMMonitor

if TYPE_CHECKING:
    from apohara_context_forge.scheduling.step_graph import AgentStepGraph
class EvictionMode(Enum):
    """VRAM-pressure tier that selects the cache's eviction policy.

    The first five tiers map to the pressure thresholds in
    ``VRAMAwareCache._pressure_to_mode``; WORKFLOW_AWARE is never produced
    by that mapping and must be set on the cache explicitly.
    """

    RELAXED = "relaxed"        # pressure < 0.70: no eviction
    NORMAL = "normal"          # 0.70-0.85: LRU, evict entries idle > 2 min
    PRESSURE = "pressure"      # 0.85-0.92: LFU by token weight, heaviest first
    CRITICAL = "critical"      # 0.92-0.96: mark idle entries for CPU offload
    EMERGENCY = "emergency"    # >= 0.96: hard-evict idle > 30s, block new sets
    WORKFLOW_AWARE = "workflow_aware"  # evict by step-graph priority order
@dataclass(order=True)
class CacheEntry:
    """Single cache slot; heap ordering compares only ``priority``.

    All other fields are ``compare=False`` so heapq sorts purely by the
    LRU/LFU hybrid priority recomputed in ``VRAMAwareCache.get``.
    """

    # Priority for heap (lower = evict first): last_accessed - (access_count * 10)
    # LFU/LRU hybrid: frequent+recent entries survive longer
    priority: float = field(compare=True)
    # Monotonic timestamp of the last get()/set() touch.
    last_accessed: float = field(compare=False, default_factory=time.monotonic)
    # Number of accesses; each one buys 10 "seconds" of eviction protection.
    access_count: int = field(compare=False, default=0)
    # Token weight counted against the cache's VRAM budget.
    token_count: int = field(compare=False, default=0)
    # Cache key, e.g. "context:agent1".
    key: str = field(compare=False, default="")
    # Stored payload (opaque to the cache).
    value: Any = field(compare=False, default=None)
    # Set in CRITICAL mode to mark the entry's KV tensors for CPU offload.
    offloaded_to_cpu: bool = field(compare=False, default=False)
class VRAMAwareCache:
"""
LRU/LFU hybrid cache with VRAM pressure-responsive eviction.
Monitors AMD MI300X memory in real-time via PyRSMI.
Usage:
cache = VRAMAwareCache(max_token_budget=50_000_000) # 50M tokens = ~3GB
await cache.start()
await cache.set("agent1", context_entry, token_count=500)
entry = await cache.get("agent1")
await cache.stop()
"""
VRAM_CHECK_INTERVAL = 2.0 # seconds between VRAM pressure checks
def __init__(self, max_token_budget: int = 50_000_000, step_graph: Optional["AgentStepGraph"] = None):
"""
Args:
max_token_budget: Maximum tokens to hold in cache (~3GB for 64-layer model)
step_graph: Optional workflow dependency graph for WORKFLOW_AWARE eviction
"""
self._store: dict[str, CacheEntry] = {}
self._heap: list[CacheEntry] = []
self._total_tokens: int = 0
self._max_token_budget = max_token_budget
self._vram = VRAMMonitor()
self._mode = EvictionMode.RELAXED
self._lock = asyncio.Lock()
self._monitor_task: Optional[asyncio.Task] = None
self._blocked = False
self._step_graph = step_graph
async def start(self) -> None:
"""Start background VRAM monitor."""
if self._monitor_task is not None:
return
self._monitor_task = asyncio.create_task(self._vram_monitor_loop())
async def stop(self) -> None:
"""Stop background monitoring."""
if self._monitor_task:
self._monitor_task.cancel()
try:
await self._monitor_task
except asyncio.CancelledError:
pass
self._monitor_task = None
async def _vram_monitor_loop(self) -> None:
"""Background loop: check VRAM pressure every interval."""
while True:
try:
pressure = self._vram.get_pressure()
new_mode = self._pressure_to_mode(pressure, self._step_graph)
if new_mode != self._mode:
self._mode = new_mode
if new_mode == EvictionMode.EMERGENCY:
self._blocked = True
elif self._mode == EvictionMode.EMERGENCY:
self._blocked = False
await self._apply_eviction_policy()
await asyncio.sleep(self.VRAM_CHECK_INTERVAL)
except asyncio.CancelledError:
break
except Exception as e:
await asyncio.sleep(1) # Brief backoff on error
@staticmethod
def _pressure_to_mode(pressure: float, step_graph=None) -> EvictionMode:
"""Convert VRAM pressure to eviction mode."""
if pressure < 0.70: return EvictionMode.RELAXED
if pressure < 0.85: return EvictionMode.NORMAL
if pressure < 0.92: return EvictionMode.PRESSURE
if pressure < 0.96: return EvictionMode.CRITICAL
return EvictionMode.EMERGENCY
async def set(self, key: str, value: Any, token_count: int) -> bool:
"""
Store value in cache.
Args:
key: Cache key (e.g., "context:agent1")
value: Value to store
token_count: Token count for VRAM tracking
Returns:
True if stored, False if blocked in EMERGENCY mode
"""
if self._blocked:
return False
entry = CacheEntry(
priority=time.monotonic(), # Will be updated by LRU/LFU formula
last_accessed=time.monotonic(),
access_count=1,
token_count=token_count,
key=key,
value=value,
)
async with self._lock:
# Evict old entry if key exists
if key in self._store:
old_entry = self._store[key]
self._total_tokens -= old_entry.token_count
self._store[key] = entry
heapq.heappush(self._heap, entry)
self._total_tokens += token_count
# Trigger eviction check if needed
if self._mode in (EvictionMode.PRESSURE, EvictionMode.CRITICAL, EvictionMode.EMERGENCY):
await self._apply_eviction_policy()
return True
async def get(self, key: str) -> Any | None:
"""Retrieve value, updating access metadata."""
async with self._lock:
entry = self._store.get(key)
if entry is None:
return None
# Update access metadata
entry.last_accessed = time.monotonic()
entry.access_count += 1
# Recalculate priority: lower = evict first
entry.priority = entry.last_accessed - (entry.access_count * 10)
return entry.value
async def delete(self, key: str) -> bool:
"""Delete entry from cache."""
async with self._lock:
entry = self._store.pop(key, None)
if entry:
self._total_tokens -= entry.token_count
return True
return False
async def _apply_eviction_policy(self, pressure: Optional[float] = None) -> int:
"""
Apply eviction policy based on current mode (or pressure override for testing).
Args:
pressure: Optional pressure value to use for mode determination (for testing).
If None, uses the actual VRAM pressure reading.
Returns:
Number of entries evicted
"""
evicted = 0
now = time.monotonic()
# Determine mode: use pressure override if provided (for testing),
# else respect pre-set _mode (tests set it directly),
# else read from VRAM monitor (production)
if pressure is not None:
mode = self._pressure_to_mode(pressure, self._step_graph)
elif hasattr(self, '_mode') and self._mode is not None:
mode = self._mode
else:
mode = self._pressure_to_mode(self._vram.get_pressure(), self._step_graph)
# Update internal mode if pressure override was provided (for testing)
if pressure is not None:
self._mode = mode
async with self._lock:
match mode:
case EvictionMode.EMERGENCY:
self._blocked = True
# Hard evict everything idle > 30s
to_evict = [
k for k, e in self._store.items()
if now - e.last_accessed > 30
]
for k in to_evict:
self._evict(k)
evicted += 1
case EvictionMode.CRITICAL:
self._blocked = False
# Mark inactive for CPU offload instead of destroying
for entry in self._store.values():
if now - entry.last_accessed > 30 and not entry.offloaded_to_cpu:
entry.offloaded_to_cpu = True
case EvictionMode.NORMAL:
self._blocked = False
# LRU: evict entries idle > 120s
to_evict = [
k for k, e in self._store.items()
if now - e.last_accessed > 120
]
for k in to_evict:
self._evict(k)
evicted += 1
case EvictionMode.PRESSURE:
self._blocked = False
# LFU by token_count: evict heaviest, least used first
candidates = sorted(
self._store.values(),
key=lambda e: e.token_count / max(e.access_count, 1),
reverse=True
)
# Evict top 25%
target = max(1, int(len(candidates) * 0.25))
for entry in candidates[:target]:
self._evict(entry.key)
evicted += 1
case EvictionMode.RELAXED:
self._blocked = False
# No eviction needed
case EvictionMode.WORKFLOW_AWARE:
self._blocked = False
if self._step_graph is not None:
priority_order = self._step_graph.get_eviction_priority_order()
# Evict in reverse priority order (lowest priority first)
for agent_id in reversed(priority_order):
key = f"context:{agent_id}"
if key in self._store:
self._evict(key)
evicted += 1
if evicted > 0:
await self._reheap()
return evicted
def _evict(self, key: str) -> None:
"""Remove entry. Must be called under lock."""
entry = self._store.pop(key, None)
if entry:
self._total_tokens -= entry.token_count
async def _reheap(self) -> None:
"""Rebuild heap after evictions."""
self._heap = list(self._store.values())
heapq.heapify(self._heap)
async def clear(self) -> None:
"""Clear all entries."""
async with self._lock:
self._store.clear()
self._heap.clear()
self._total_tokens = 0
@property
def size(self) -> int:
"""Number of entries."""
return len(self._store)
@property
def total_tokens(self) -> int:
"""Total token count in cache."""
return self._total_tokens
@property
def mode(self) -> EvictionMode:
"""Current eviction mode."""
return self._mode
@property
def is_blocked(self) -> bool:
"""True if new registrations are blocked (EMERGENCY mode)."""
return self._blocked
@property
def step_graph(self) -> Optional["AgentStepGraph"]:
"""The workflow dependency graph for WORKFLOW_AWARE eviction."""
return self._step_graph