Spaces:
Sleeping
Sleeping
| """AgentStepGraph — workflow dependency graph for KV cache eviction priority. | |
| Based on KVFlow (NeurIPS 2025, arXiv:2507.07400): | |
| - Workflow-aware eviction: evict caches of agents with high steps-to-execution | |
| (agents far from being invoked) before agents about to run. | |
| - Overlapped KV prefetching: proactively prefetch KV tensors for agents | |
| scheduled in the next N steps. | |
| Result from paper: 1.83x speedup over SGLang, 2.19x for concurrent workflows. | |
| V4.0 CHANGES: New module for workflow-aware eviction. | |
| """ | |
| import sys | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
@dataclass
class AgentStep:
    """A single step in a workflow graph.

    Attributes:
        agent_id: Identifier of the agent; used as the key in ``AgentStepGraph``.
        depends_on: agent_ids this step waits for (a fresh list per instance).
        step_index: Scheduled position of this step in the workflow.
        estimated_tokens: Estimated token count for this step.
        is_optional: True for dynamic conditional agents.
    """

    agent_id: str
    # default_factory gives each instance its own list instead of a shared one.
    depends_on: list[str] = field(default_factory=list)
    step_index: int = 0
    estimated_tokens: int = 0
    is_optional: bool = False  # True for dynamic conditional agents
class AgentStepGraph:
    """
    Workflow dependency graph for KV cache eviction priority.

    Each node is an ``AgentStep`` keyed by ``agent_id``; ``depends_on`` lists
    the agent_ids a step waits for. Dependency ids that are not registered in
    the graph are tolerated and ignored when computing depths.

    Usage:
        graph = AgentStepGraph()
        graph.add_step(AgentStep(agent_id="retriever", depends_on=[], step_index=0))
        graph.add_step(AgentStep(agent_id="summarizer", depends_on=["retriever"], step_index=1))
        order = graph.get_eviction_priority_order()  # agents far from execution first
    """

    def __init__(self):
        # agent_id -> step; the authoritative store (the last add_step wins).
        self._steps: dict[str, "AgentStep"] = {}
        # Steps in first-insertion order. This is NOT guaranteed to be a
        # topological order: callers may add steps in any order.
        self._step_list: list["AgentStep"] = []

    def add_step(self, step: "AgentStep") -> "AgentStepGraph":
        """Add (or replace) a step. Returns self for chaining.

        Re-adding an existing ``agent_id`` replaces the old step in place, so
        the agent keeps its original position and is never listed twice.
        """
        previous = self._steps.get(step.agent_id)
        self._steps[step.agent_id] = step
        if previous is None:
            self._step_list.append(step)
        else:
            for index, existing in enumerate(self._step_list):
                if existing is previous:
                    self._step_list[index] = step
                    break
        return self

    def compute_steps_to_execution(self, agent_id: str, current_step: int = 0) -> int:
        """
        Return how many steps must complete before ``agent_id`` is invoked.

        The distance is the step's longest-path depth from the graph roots (a
        root's depth is its own ``step_index``; each dependency edge adds one),
        measured relative to ``current_step``.

        Args:
            agent_id: Agent to look up.
            current_step: Index of the step currently executing.

        Returns:
            0 if the agent's ``step_index`` is at or before ``current_step``.
            At least 1 for any agent scheduled after ``current_step``.
            ``sys.maxsize`` if ``agent_id`` is not in the graph.

        Raises:
            ValueError: if the graph has cycles. The check runs before the
                membership lookup, so it raises even for unknown agents.
        """
        self.validate_dag()  # Will raise if cycles
        step = self._steps.get(agent_id)
        if step is None:
            return sys.maxsize
        return self._steps_to_execution(step, current_step, {})

    def _steps_to_execution(
        self, step: "AgentStep", current_step: int, memo: dict[str, int]
    ) -> int:
        """Distance of ``step`` from ``current_step``; the graph must be acyclic."""
        if step.step_index <= current_step:
            return 0
        # A step scheduled after current_step is never "current"; clamp to 1
        # even if inconsistent step_index values make the depth look smaller.
        depth = self._longest_path_depth(step.agent_id, memo)
        return max(depth - current_step, 1)

    def _longest_path_depth(self, agent_id: str, memo: dict[str, int]) -> int:
        """Longest-path depth of ``agent_id``; the graph must be acyclic.

        ``memo`` caches depths so a shared ancestor (diamond dependency) is
        computed once and always contributes its true depth.
        """
        cached = memo.get(agent_id)
        if cached is not None:
            return cached
        step = self._steps[agent_id]
        if not step.depends_on:
            depth = step.step_index
        else:
            parent_max = 0
            for dep_id in step.depends_on:
                if dep_id in self._steps:  # unknown dependencies are ignored
                    parent_max = max(parent_max, self._longest_path_depth(dep_id, memo))
            depth = parent_max + 1
        memo[agent_id] = depth
        return depth

    def get_prefetch_candidates(
        self,
        current_step: int,
        lookahead: int = 2,
    ) -> list[str]:
        """Return agent_ids whose step_index is in ``(current_step, current_step + lookahead]``.

        Results follow insertion order.
        """
        upper = current_step + lookahead
        return [
            step.agent_id
            for step in self._step_list
            if current_step < step.step_index <= upper
        ]

    def get_eviction_priority_order(self) -> list[str]:
        """
        Return agent_ids ordered from first-to-evict to last-to-evict.

        Agents with the highest steps-to-execution (measured from step 0) come
        first. Ties keep insertion order because the sort is stable.

        Raises:
            ValueError: if the graph contains a cycle.
        """
        # Validate once and share a single depth memo across all agents, so
        # the traversal work is linear in the graph size rather than one full
        # validation + DFS per agent.
        self.validate_dag()
        memo: dict[str, int] = {}
        priorities = [
            (step.agent_id, self._steps_to_execution(step, 0, memo))
            for step in self._step_list
        ]
        # Sort descending by steps (highest first = evict first)
        priorities.sort(key=lambda item: item[1], reverse=True)
        return [agent_id for agent_id, _ in priorities]

    def validate_dag(self) -> None:
        """Raise ValueError if the graph contains a cycle.

        Uses three-colour DFS. Dependency ids that are not registered steps are
        treated as leaves.
        """
        WHITE, GRAY, BLACK = 0, 1, 2
        color = {sid: WHITE for sid in self._steps}

        def dfs(node_id: str) -> None:
            color[node_id] = GRAY
            if node_id in self._steps:
                for dep in self._steps[node_id].depends_on:
                    if dep not in color:
                        color[dep] = WHITE
                    # A GRAY dependency is still on the DFS stack: back edge.
                    if color.get(dep, WHITE) == GRAY:
                        raise ValueError(f"Cycle detected involving agent '{node_id}'")
                    if color.get(dep, WHITE) == WHITE:
                        dfs(dep)
            color[node_id] = BLACK

        for sid in self._steps:
            if color[sid] == WHITE:
                dfs(sid)

    def size(self) -> int:
        """Number of steps in the graph."""
        return len(self._steps)

    def get_step(self, agent_id: str) -> Optional["AgentStep"]:
        """Get step by agent_id, or None if absent."""
        return self._steps.get(agent_id)

    def get_all_agents(self) -> list[str]:
        """Get all agent IDs in the graph, in first-insertion order."""
        return list(self._steps.keys())