Spaces:
Sleeping
Sleeping
File size: 5,409 Bytes
bfb7184 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | """AgentStepGraph — workflow dependency graph for KV cache eviction priority.
Based on KVFlow (NeurIPS 2025, arXiv:2507.07400):
- Workflow-aware eviction: evict caches of agents with high steps-to-execution
(agents far from being invoked) before agents about to run.
- Overlapped KV prefetching: proactively prefetch KV tensors for agents
scheduled in the next N steps.
Result from paper: 1.83x speedup over SGLang, 2.19x for concurrent workflows.
V4.0 CHANGES: New module for workflow-aware eviction.
"""
import sys
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class AgentStep:
"""A single step in a workflow graph."""
agent_id: str
depends_on: list[str] = field(default_factory=list)
step_index: int = 0
estimated_tokens: int = 0
is_optional: bool = False # True for dynamic conditional agents
class AgentStepGraph:
"""
Workflow dependency graph for KV cache eviction priority.
Usage:
graph = AgentStepGraph()
graph.add_step(AgentStep(agent_id="retriever", depends_on=[], step_index=0))
graph.add_step(AgentStep(agent_id="summarizer", depends_on=["retriever"], step_index=1))
order = graph.get_eviction_priority_order() # agents far from execution first
"""
def __init__(self):
self._steps: dict[str, AgentStep] = {}
self._step_list: list[AgentStep] = [] # topological order
def add_step(self, step: AgentStep) -> "AgentStepGraph":
"""Add a step to the graph. Returns self for chaining."""
self._steps[step.agent_id] = step
self._step_list.append(step)
return self
def compute_steps_to_execution(self, agent_id: str, current_step: int = 0) -> int:
"""
Returns how many steps must complete before agent_id is invoked.
Returns:
0 if agent is the current step.
sys.maxsize if agent_id not in graph.
Raises ValueError if graph has cycles.
"""
self.validate_dag() # Will raise if cycles
if agent_id not in self._steps:
return sys.maxsize
step = self._steps[agent_id]
# Compute longest path from any root to this step
if step.step_index <= current_step:
return 0
# BFS/DFS to compute depth
visited = set()
def compute_depth(s: AgentStep, visited: set) -> int:
if s.agent_id in visited:
return 0
visited.add(s.agent_id)
if not s.depends_on:
return s.step_index
max_parent_depth = 0
for dep_id in s.depends_on:
if dep_id in self._steps:
max_parent_depth = max(max_parent_depth, compute_depth(self._steps[dep_id], visited))
return max_parent_depth + 1
return compute_depth(step, set())
def get_prefetch_candidates(
self,
current_step: int,
lookahead: int = 2,
) -> list[str]:
"""Return agent_ids to prefetch within `lookahead` steps."""
candidates = []
for step in self._step_list:
if step.step_index <= current_step:
continue
if step.step_index <= current_step + lookahead:
candidates.append(step.agent_id)
return candidates
def get_eviction_priority_order(self) -> list[str]:
"""
Return agent_ids ordered from lowest to highest eviction priority
(first in list = evict first = highest steps_to_execution).
"""
# Sort by steps_to_execution descending (agents far from execution evict first)
priorities = []
for step in self._step_list:
steps = self.compute_steps_to_execution(step.agent_id, current_step=0)
priorities.append((step.agent_id, steps))
# Sort descending by steps (highest first = evict first)
priorities.sort(key=lambda x: x[1], reverse=True)
return [agent_id for agent_id, _ in priorities]
def validate_dag(self) -> None:
"""Raise ValueError if graph contains cycles."""
# DFS-based cycle detection
WHITE, GRAY, BLACK = 0, 1, 2
color = {sid: WHITE for sid in self._steps}
def dfs(node_id: str) -> None:
color[node_id] = GRAY
if node_id in self._steps:
for dep in self._steps[node_id].depends_on:
if dep not in color:
color[dep] = WHITE
if color.get(dep, WHITE) == GRAY:
raise ValueError(f"Cycle detected involving agent '{node_id}'")
if color.get(dep, WHITE) == WHITE:
dfs(dep)
color[node_id] = BLACK
for sid in self._steps:
if color[sid] == WHITE:
dfs(sid)
@property
def size(self) -> int:
"""Number of steps in the graph."""
return len(self._steps)
def get_step(self, agent_id: str) -> Optional[AgentStep]:
"""Get step by agent_id."""
return self._steps.get(agent_id)
def get_all_agents(self) -> list[str]:
"""Get all agent IDs in the graph."""
return list(self._steps.keys()) |