File size: 5,409 Bytes
bfb7184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""AgentStepGraph — workflow dependency graph for KV cache eviction priority.

Based on KVFlow (NeurIPS 2025, arXiv:2507.07400):
- Workflow-aware eviction: evict caches of agents with high steps-to-execution
  (agents far from being invoked) before agents about to run.
- Overlapped KV prefetching: proactively prefetch KV tensors for agents
  scheduled in the next N steps.

Result from paper: 1.83x speedup over SGLang, 2.19x for concurrent workflows.

V4.0 CHANGES: New module for workflow-aware eviction.
"""
import sys
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class AgentStep:
    """A single step (one agent invocation) in a workflow graph.

    Instances are plain mutable records; the graph that stores them reads
    ``agent_id``, ``depends_on`` and ``step_index``.
    """
    # Unique identifier of the agent; the graph uses it as the lookup key.
    agent_id: str
    # agent_ids of the steps this step depends on (empty for a root step).
    depends_on: list[str] = field(default_factory=list)
    # Position of the step in the workflow; compared against the caller's
    # current step to decide whether the step is still ahead.
    step_index: int = 0
    # NOTE(review): not read anywhere in this module — presumably a size hint
    # for the step's KV cache; confirm against callers.
    estimated_tokens: int = 0
    is_optional: bool = False  # True for dynamic conditional agents


class AgentStepGraph:
    """
    Workflow dependency graph for KV cache eviction priority.

    Steps are registered with ``add_step`` and looked up by ``agent_id``.
    The graph can then report how deep a step sits in the dependency chain,
    which steps fall inside a prefetch window, and the order in which agent
    caches should be evicted (farthest-from-execution first).

    Usage:
        graph = AgentStepGraph()
        graph.add_step(AgentStep(agent_id="retriever", depends_on=[], step_index=0))
        graph.add_step(AgentStep(agent_id="summarizer", depends_on=["retriever"], step_index=1))
        order = graph.get_eviction_priority_order()  # agents far from execution first
    """

    def __init__(self) -> None:
        # agent_id -> step, for O(1) lookup.
        self._steps: dict[str, AgentStep] = {}
        # Steps in insertion order (one entry per agent_id).
        self._step_list: list[AgentStep] = []

    def add_step(self, step: "AgentStep") -> "AgentStepGraph":
        """Add a step to the graph. Returns self for chaining.

        Re-adding an ``agent_id`` that is already registered replaces the
        previous step in place (keeping its original position) instead of
        leaving a stale duplicate in the ordered step list.
        """
        if step.agent_id in self._steps:
            self._step_list = [
                step if existing.agent_id == step.agent_id else existing
                for existing in self._step_list
            ]
        else:
            self._step_list.append(step)
        self._steps[step.agent_id] = step
        return self

    def compute_steps_to_execution(self, agent_id: str, current_step: int = 0) -> int:
        """
        Return the dependency depth of ``agent_id`` if it has not run yet.

        The depth is the longest dependency chain leading to the step: a step
        without dependencies contributes its own ``step_index``; every other
        step is one more than its deepest known dependency. Dependencies that
        are not registered in the graph are ignored.

        NOTE: the depth is absolute — it is not reduced by ``current_step``,
        which is only used to decide whether the step is already due.

        Returns:
            0 if the step's ``step_index`` is at or before ``current_step``.
            sys.maxsize if ``agent_id`` is not in the graph.

        Raises:
            ValueError: if the graph contains a cycle.
        """
        self.validate_dag()  # Must precede the depth walk: it assumes a DAG.

        step = self._steps.get(agent_id)
        if step is None:
            return sys.maxsize
        return self._steps_to_execution(step, current_step, {})

    def _steps_to_execution(
        self,
        step: "AgentStep",
        current_step: int,
        memo: dict[str, int],
    ) -> int:
        """Steps-to-execution for ``step``; the graph must already be validated."""
        if step.step_index <= current_step:
            return 0
        return self._depth_of(step, memo)

    def _depth_of(self, step: "AgentStep", memo: dict[str, int]) -> int:
        """Longest-chain depth of ``step``, memoised per agent_id in ``memo``.

        Iterative post-order walk so long chains cannot hit the recursion
        limit. Memoising (rather than returning 0 for an already-seen node)
        keeps the depth correct when several paths share an ancestor.
        """
        pending = [step]
        while pending:
            node = pending[-1]
            if node.agent_id in memo:
                pending.pop()
                continue
            if not node.depends_on:
                # Root step: its depth is its declared position.
                memo[node.agent_id] = node.step_index
                pending.pop()
                continue
            parents = [
                self._steps[dep_id]
                for dep_id in node.depends_on
                if dep_id in self._steps
            ]
            unresolved = [p for p in parents if p.agent_id not in memo]
            if unresolved:
                # Resolve dependencies first; this node is revisited afterwards.
                pending.extend(unresolved)
                continue
            deepest = 0
            for parent in parents:
                deepest = max(deepest, memo[parent.agent_id])
            memo[node.agent_id] = deepest + 1
            pending.pop()
        return memo[step.agent_id]

    def get_prefetch_candidates(
        self,
        current_step: int,
        lookahead: int = 2,
    ) -> list[str]:
        """Return agent_ids to prefetch within `lookahead` steps.

        A step qualifies when its ``step_index`` is strictly after
        ``current_step`` and at most ``current_step + lookahead``. Results
        keep the order in which steps were added.
        """
        horizon = current_step + lookahead
        return [
            step.agent_id
            for step in self._step_list
            if current_step < step.step_index <= horizon
        ]

    def get_eviction_priority_order(self, current_step: int = 0) -> list[str]:
        """
        Return agent_ids ordered from evict-first to evict-last.

        Agents with the highest steps-to-execution (farthest from running)
        come first. Ties keep the order in which steps were added.

        Args:
            current_step: workflow position used to decide which steps are
                already due (those get a steps-to-execution of 0).

        Raises:
            ValueError: if the graph contains a cycle.
        """
        # Validate once and share the depth memo across all steps instead of
        # re-walking the graph for every agent.
        self.validate_dag()
        memo: dict[str, int] = {}
        distance = {
            step.agent_id: self._steps_to_execution(step, current_step, memo)
            for step in self._step_list
        }
        ranked = sorted(
            self._step_list,
            key=lambda step: distance[step.agent_id],
            reverse=True,
        )
        return [step.agent_id for step in ranked]

    def validate_dag(self) -> None:
        """Raise ValueError if graph contains cycles.

        Dependencies naming agents that are not registered are treated as
        leaf nodes and never cause an error.
        """
        # Three-colour DFS, done with an explicit stack so deep dependency
        # chains cannot overflow the interpreter's recursion limit.
        WHITE, GRAY, BLACK = 0, 1, 2
        color = dict.fromkeys(self._steps, WHITE)

        def deps_of(node_id: str):
            node = self._steps.get(node_id)
            return iter(node.depends_on) if node is not None else iter(())

        for root in self._steps:
            if color[root] != WHITE:
                continue
            color[root] = GRAY
            stack = [(root, deps_of(root))]
            while stack:
                node_id, remaining = stack[-1]
                for dep in remaining:
                    state = color.get(dep, WHITE)
                    if state == GRAY:
                        # dep is on the current DFS path: back edge.
                        raise ValueError(f"Cycle detected involving agent '{node_id}'")
                    if state == WHITE:
                        color[dep] = GRAY
                        stack.append((dep, deps_of(dep)))
                        break
                else:
                    # All dependencies explored.
                    color[node_id] = BLACK
                    stack.pop()

    @property
    def size(self) -> int:
        """Number of steps in the graph."""
        return len(self._steps)

    def get_step(self, agent_id: str) -> Optional["AgentStep"]:
        """Get step by agent_id, or None if it is not registered."""
        return self._steps.get(agent_id)

    def get_all_agents(self) -> list[str]:
        """Get all agent IDs in the graph, in first-insertion order."""
        return list(self._steps)