# Pablo
# feat: APOHARA: Context Forge V5 — synthesis + rebrand complete
# cf0a8ed
"""PBKVPredictor — prediction-based KV cache eviction priority.
Based on PBKV (arXiv:2605.06472, May 2026):
Prediction-based KV cache management for dynamic agent workflows.
Key result: 1.26x speedup over KVFlow (NeurIPS 2025).
Implementation: 2nd-order Markov chain over agent_id sequences.
State: (agent_id_t-2, agent_id_t-1)
Transition: predict agent_id_t with highest probability
Training: MLE on JSONL logs from PBKVPredictor stub output
Why Markov over neural:
- Zero VRAM overhead
- <1μs prediction latency
- Sufficient for agentic workflow patterns (low entropy, high repetition)
- PBKV paper uses similar lightweight approach for dynamic scenarios
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from apohara_context_forge.scheduling.step_graph import AgentStepGraph
logger = logging.getLogger(__name__)
@dataclass
class WorkflowStepRecord:
    """Single step in a workflow sequence.

    Instances are serialized to the JSONL step log via ``__dict__`` and
    later replayed to train the Markov transition tables.
    """

    step_idx: int  # position of this step within its workflow
    agent_id: str  # identifier of the agent that executed the step
    anchor_hash: str  # hash of the KV-cache anchor associated with this step
    token_length: int  # token count of the step's context
    cla_group: Optional[int] = None  # optional CLA group id — semantics not defined in this module
@dataclass
class PredictionResult:
    """Prediction for next KV cache access."""

    predicted_agents: list[str]  # ranked by probability, most likely first
    predicted_anchor_hashes: list[str]  # most recent anchor hashes seen for the predicted agents
    confidence: float  # probability of the top predicted agent (0.0 when no data)
    prefetch_block_ids: list[str] = field(default_factory=list)  # not populated by this module — presumably filled by callers
class PBKVPredictor:
"""Predictor-based KV cache prefetching using 2nd-order Markov chain.
Design:
1. Log each workflow step to local JSONL file
2. Train Markov transition table from logged steps
3. Predict next agents using transition probabilities
4. Blend with AgentStepGraph for eviction/prefetch decisions
Markov Chain:
- 2nd-order: state = (prev_agent, curr_agent) → next_agent
- 1st-order fallback: state = curr_agent → next_agent
- Laplace smoothing (alpha=1) for unseen transitions
"""
def __init__(
self,
log_dir: Optional[str] = None,
max_history_steps: int = 1000,
blend_alpha: float = 0.6,
):
self._log_dir = Path(log_dir) if log_dir else Path(".") / ".pbkv_logs"
self._max_history_steps = max_history_steps
self._blend_alpha = blend_alpha
self._history: list[WorkflowStepRecord] = []
self._transition_table: dict[tuple[str, str], dict[str, int]] = {}
self._first_order_table: dict[str, dict[str, int]] = {}
self._all_agents: set[str] = set()
self._lock = asyncio.Lock()
self._log_file = self._log_dir / "workflow_steps.jsonl"
self._log_dir.mkdir(parents=True, exist_ok=True)
self._trained = False
async def log_workflow_step(
self,
step_idx: int,
agent_id: str,
anchor_hash: str,
token_length: int,
cla_group: Optional[int] = None,
) -> None:
"""Log a workflow step for future prediction training."""
record = WorkflowStepRecord(
step_idx=step_idx,
agent_id=agent_id,
anchor_hash=anchor_hash,
token_length=token_length,
cla_group=cla_group,
)
async with self._lock:
self._history.append(record)
if len(self._history) > self._max_history_steps:
self._history.pop(0)
# Append to JSONL log
try:
with open(self._log_file, "a") as f:
f.write(json.dumps(record.__dict__) + "\n")
except Exception as e:
logger.warning(f"Failed to write PBKV log: {e}")
def train_from_jsonl(self, path: str) -> None:
"""Load JSONL and build Markov transition table.
Reads workflow_steps.jsonl files from the log directory.
Builds: {(prev_agent, curr_agent): {next_agent: count}}
Also builds 1st-order fallback: {curr_agent: {next_agent: count}}
Uses Laplace smoothing (alpha=1) for unseen transitions.
"""
log_path = Path(path)
if log_path.is_dir():
log_path = log_path / "workflow_steps.jsonl"
if not log_path.exists():
logger.warning(f"JSONL file not found: {log_path}")
return
sequences: list[list[str]] = []
current_seq: list[str] = []
with open(log_path, "r") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
record = json.loads(line)
current_seq.append(record["agent_id"])
except (json.JSONDecodeError, KeyError):
# End of sequence marker (empty line or invalid)
if current_seq:
sequences.append(current_seq)
current_seq = []
if current_seq:
sequences.append(current_seq)
# Build transition tables
self._transition_table.clear()
self._first_order_table.clear()
self._all_agents.clear()
for seq in sequences:
for i, agent_id in enumerate(seq):
self._all_agents.add(agent_id)
if i >= 1:
prev_agent = seq[i - 1]
# 2nd-order: (prev, curr) → next
key = (prev_agent, agent_id)
if key not in self._transition_table:
self._transition_table[key] = {}
self._transition_table[key][agent_id] = \
self._transition_table[key].get(agent_id, 0) + 1
if i >= 2:
# 1st-order: curr → next
curr_agent = seq[i - 1]
next_agent = seq[i]
if curr_agent not in self._first_order_table:
self._first_order_table[curr_agent] = {}
self._first_order_table[curr_agent][next_agent] = \
self._first_order_table[curr_agent].get(next_agent, 0) + 1
self._trained = True
logger.info(
f"Trained Markov model: {len(self._transition_table)} 2nd-order states, "
f"{len(self._first_order_table)} 1st-order states, "
f"{len(self._all_agents)} unique agents"
)
def _get_transition_probs(
self,
prev_agent: Optional[str],
curr_agent: str,
) -> dict[str, float]:
"""Get transition probabilities for given state.
Uses 2nd-order if prev_agent available, else 1st-order.
Applies Laplace smoothing (alpha=1).
"""
alpha = 1.0
num_states = len(self._all_agents) if self._all_agents else 1
if prev_agent is not None:
key = (prev_agent, curr_agent)
if key in self._transition_table:
total = sum(self._transition_table[key].values())
probs = {}
for agent in self._all_agents:
count = self._transition_table[key].get(agent, 0)
probs[agent] = (count + alpha) / (total + alpha * num_states)
return probs
# Fallback to 1st-order
if curr_agent in self._first_order_table:
total = sum(self._first_order_table[curr_agent].values())
probs = {}
for agent in self._all_agents:
count = self._first_order_table[curr_agent].get(agent, 0)
probs[agent] = (count + alpha) / (total + alpha * num_states)
return probs
# Uniform fallback
return {agent: 1.0 / num_states for agent in self._all_agents}
def predict_next_agents(
self,
current_agent_id: str,
top_k: int = 3,
) -> list[str]:
"""Predict top-k most likely next agents (synchronous).
Uses only the last observed agent as prev_state for 1st-order
approximation if history is empty, but tries (prev, curr) → next
if available.
"""
if not self._trained and not self._history:
return [current_agent_id]
prev_agent: Optional[str] = None
curr_agent = current_agent_id
# Build sequences from history if not trained from JSONL
if not self._trained:
seq: list[str] = [s.agent_id for s in self._history]
for i, agent_id in enumerate(seq):
if agent_id == current_agent_id and i > 0:
prev_agent = seq[i - 1]
break
if prev_agent is None and len(seq) >= 2:
prev_agent = seq[-2]
curr_agent = seq[-1]
probs = self._get_transition_probs(prev_agent, curr_agent)
sorted_agents = sorted(probs.items(), key=lambda x: -x[1])
return [agent for agent, _ in sorted_agents[:top_k]]
async def _predict_next_agents_async(
self,
current_agent_id: str,
current_step: int = 0,
num_predictions: int = 3,
) -> PredictionResult:
"""Async wrapper for backward compatibility with PredictionResult.
Internal use only. Use predict_next_agents() for the public API.
"""
async with self._lock:
history_copy = list(self._history)
if not history_copy:
return PredictionResult(
predicted_agents=[current_agent_id],
predicted_anchor_hashes=[],
confidence=0.0,
)
# Determine prev_agent from history
prev_agent: Optional[str] = None
curr_agent = current_agent_id
# Find current agent in history to get preceding agent
for i, step in enumerate(history_copy):
if step.agent_id == current_agent_id and i > 0:
prev_agent = history_copy[i - 1].agent_id
curr_agent = current_agent_id
break
# Get transition probabilities
probs = self._get_transition_probs(prev_agent, curr_agent)
# Sort by probability descending
sorted_agents = sorted(probs.items(), key=lambda x: -x[1])
top_agents = [agent for agent, _ in sorted_agents[:num_predictions]]
confidence = sorted_agents[0][1] if sorted_agents else 0.0
# Get anchor hashes from recent history for predicted agents
anchor_hashes = []
agent_set = set(top_agents)
for step in reversed(history_copy):
if step.agent_id in agent_set and step.anchor_hash not in anchor_hashes:
anchor_hashes.append(step.anchor_hash)
if len(anchor_hashes) >= num_predictions:
break
return PredictionResult(
predicted_agents=top_agents,
predicted_anchor_hashes=anchor_hashes,
confidence=confidence,
)
async def get_eviction_priority(
self,
agent_ids: list[str],
step_graph: Optional["AgentStepGraph"] = None,
) -> list[str]:
"""Order agents by inverse predicted probability for eviction.
Evicts agents least likely to be needed next (low priority).
Blends with AgentStepGraph if available using blend_alpha:
- blend_alpha=0.6: step_graph weight
- (1-blend_alpha)=0.4: pbkv weight
"""
if not agent_ids:
return []
# Get PBKV priorities (lower prob = higher eviction priority)
pbkv_scores: dict[str, float] = {}
if self._trained or self._history:
for agent_id in agent_ids:
top_k = self.predict_next_agents(agent_id, top_k=len(agent_ids))
# Score = position in ranked list (lower position = higher prob)
if agent_id in top_k:
pbkv_scores[agent_id] = 1.0 / (top_k.index(agent_id) + 1)
else:
pbkv_scores[agent_id] = 0.0
else:
# Uniform if no training data
for agent_id in agent_ids:
pbkv_scores[agent_id] = 1.0 / len(agent_ids)
# Get AgentStepGraph priorities if available
if step_graph is not None:
try:
graph_priorities = step_graph.get_eviction_priority_order()
graph_scores: dict[str, float] = {}
for rank, agent_id in enumerate(graph_priorities):
if agent_id in agent_ids:
graph_scores[agent_id] = 1.0 / (rank + 1)
# Blend scores
blended_scores: dict[str, float] = {}
for agent_id in agent_ids:
pbkv = pbkv_scores.get(agent_id, 0.0)
graph = graph_scores.get(agent_id, 0.0)
blended_scores[agent_id] = (
self._blend_alpha * graph + (1 - self._blend_alpha) * pbkv
)
# Sort ascending (low score = evict first = low priority)
sorted_agents = sorted(
agent_ids, key=lambda x: blended_scores.get(x, 0.0)
)
except Exception as e:
logger.warning(f"AgentStepGraph blend failed: {e}")
sorted_agents = sorted(
agent_ids, key=lambda x: pbkv_scores.get(x, 0.0)
)
else:
# PBKV only: sort ascending (low prob = evict first)
sorted_agents = sorted(
agent_ids, key=lambda x: pbkv_scores.get(x, 0.0)
)
return sorted_agents
async def get_prefetch_candidates(
self,
current_agent_id: str,
step: int = 0,
lookahead: int = 2,
) -> list[str]:
"""Get list of agent IDs to prefetch within lookahead steps.
Uses Markov prediction to find agents within 2 steps.
"""
prediction = await self._predict_next_agents_async(
current_agent_id, current_step=step, num_predictions=lookahead
)
candidates = prediction.predicted_agents
logger.debug(
f"PBKV prefetch candidates for agent={current_agent_id} step={step}: "
f"{len(candidates)} candidates"
)
return candidates
def get_stats(self) -> dict:
"""Return PBKV predictor statistics."""
return {
"history_size": len(self._history),
"log_file": str(self._log_file),
"max_history_steps": self._max_history_steps,
"blend_alpha": self._blend_alpha,
"trained": self._trained,
"transition_table_size": len(self._transition_table),
"first_order_table_size": len(self._first_order_table),
"unique_agents": len(self._all_agents),
}