"""Prefix Normalizer for vLLM prefix caching (enable_prefix_caching=True).
vLLM requires token-identical prefixes across requests to trigger KV cache hits.
A single extra space or different capitalization creates a completely different
token sequence and breaks cache sharing.
Key enforcement:
- FIXED order: [canonical_system_prompt][SEP][agent_role_prompt][SEP][user_prompt]
- SEPARATOR is exactly two newlines: "\n\n" (never one, never three)
- Each segment stripped of trailing whitespace before assembly
- SHA256 validation catches mismatched canonical prefixes
Usage:
normalizer = PrefixNormalizer(
canonical_system_prompt="You are a helpful AI assistant."
)
# All agents use the same normalizer
prompt1 = normalizer.normalize("agent1", "What is AI?", "retriever role")
prompt2 = normalizer.normalize("agent2", "What is AI?", "summarizer role")
# prompt1 and prompt2 are byte-identical at the system prompt prefix
"""
import hashlib
import logging
from typing import Optional

logger = logging.getLogger(__name__)

# Fixed separator between prompt segments: exactly two newlines.
# One vs. two newlines tokenizes differently and breaks vLLM cache hits.
SEPARATOR = "\n\n"


class PrefixNormalizer:
    """
    Enforces token-identical prefixes for vLLM prefix caching.

    All agents must use the same canonical_system_prompt. Any deviation
    is logged as a WARNING (not ERROR) because vLLM silently degrades
    to non-cached computation when prefixes don't match.

    Usage:
        normalizer = PrefixNormalizer(
            canonical_system_prompt="You are a helpful AI assistant."
        )
        final_prompt = normalizer.normalize(
            agent_id="agent1",
            user_prompt="What is machine learning?",
            agent_role_prompt="You are a retriever agent."
        )
    """

    def __init__(
        self,
        canonical_system_prompt: str,
        separator: str = SEPARATOR,
    ):
        """
        Initialize with the shared system prompt.

        Args:
            canonical_system_prompt: The shared base prompt (must be identical
                byte-for-byte across all agents)
            separator: Separator between segments (default: two newlines)
        """
        # Strip once at construction so every assembled prompt shares the
        # exact same prefix bytes.
        self._canonical_system_prompt = canonical_system_prompt.strip()
        self._separator = separator
        self._canonical_hash = self._compute_hash(self._canonical_system_prompt)
        # agent_ids seen by normalize(); used for first-seen tracking only.
        self._registered_agents: set[str] = set()
        # Lazy %-args: the message is only formatted if INFO is enabled.
        logger.info(
            "PrefixNormalizer initialized with system prompt hash: %s...",
            self._canonical_hash[:16],
        )

    @staticmethod
    def _compute_hash(text: str) -> str:
        """Compute SHA256 hex digest of text (UTF-8 encoded)."""
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    def normalize(
        self,
        agent_id: str,
        user_prompt: str,
        agent_role_prompt: str,
    ) -> str:
        """
        Assemble final prompt in FIXED order with canonical system prompt.

        Order: [canonical_system_prompt][SEP][agent_role_prompt][SEP][user_prompt]

        Args:
            agent_id: Agent identifier (for logging only)
            user_prompt: User's query/input
            agent_role_prompt: Agent-specific role prompt

        Returns:
            Final assembled prompt with byte-identical system prefix
        """
        # Strip surrounding whitespace from the variable segments so the
        # boundary between segments is exactly self._separator; a stray
        # newline would change the token sequence and miss the KV cache.
        role_part = agent_role_prompt.strip()
        user_part = user_prompt.strip()

        # Assemble in fixed order (canonical prefix was stripped in __init__).
        assembled = self._separator.join(
            [self._canonical_system_prompt, role_part, user_part]
        )

        # First-seen tracking only. No hash validation is needed here because
        # the canonical system prompt is stored internally and always used
        # verbatim; callers holding an *external* system prompt should check
        # it with validate_system_prompt().
        if agent_id not in self._registered_agents:
            self._registered_agents.add(agent_id)
            logger.debug("Registered agent %r with PrefixNormalizer", agent_id)

        return assembled

    def validate_system_prompt(self, system_prompt: str) -> bool:
        """
        Validate that a system prompt matches the canonical one.

        Args:
            system_prompt: System prompt to validate (stripped before hashing,
                mirroring __init__)

        Returns:
            True if identical, False otherwise
        """
        hash_to_check = self._compute_hash(system_prompt.strip())
        matches = hash_to_check == self._canonical_hash
        if not matches:
            # WARNING, not ERROR: vLLM keeps working, just without cache hits.
            logger.warning(
                "Agent system prompt hash MISMATCH. Expected %s, got %s. "
                "vLLM prefix caching will NOT work for this agent.",
                self._canonical_hash[:16],
                hash_to_check[:16],
            )
        return matches

    def get_canonical_hash(self) -> str:
        """Get SHA256 of the canonical system prompt."""
        return self._canonical_hash

    def get_canonical_prompt(self) -> str:
        """Get the canonical system prompt (already stripped)."""
        return self._canonical_system_prompt

    @property
    def separator(self) -> str:
        """Get the separator string."""
        return self._separator

    def compute_prompt_hash(self, prompt: str) -> str:
        """Compute SHA256 of an assembled prompt (for debugging)."""
        return self._compute_hash(prompt)
def create_prefix_normalizer(
    canonical_system_prompt: Optional[str] = None,
) -> PrefixNormalizer:
    """
    Factory to create a PrefixNormalizer with default or custom system prompt.

    Args:
        canonical_system_prompt: Custom system prompt; when omitted (or falsy),
            a built-in default prompt is used instead.

    Returns:
        Configured PrefixNormalizer instance
    """
    fallback_prompt = (
        "You are a helpful AI assistant. "
        "Provide accurate, detailed, and thoughtful responses. "
        "Use chain-of-thought reasoning when appropriate."
    )
    # Preserve falsy-fallback semantics: None and "" both select the default.
    chosen = canonical_system_prompt if canonical_system_prompt else fallback_prompt
    return PrefixNormalizer(
        canonical_system_prompt=chosen,
        separator=SEPARATOR,
    )