# contextforge-demo/contextforge/normalization/prefix_normalizer.py
# ContextForge v3.0: production-grade shared context compiler
"""Prefix Normalizer for vLLM prefix caching (enable_prefix_caching=True).
vLLM requires token-identical prefixes across requests to trigger KV cache hits.
A single extra space or different capitalization creates a completely different
token sequence and breaks cache sharing.
Key enforcement:
- FIXED order: [canonical_system_prompt][SEP][agent_role_prompt][SEP][user_prompt]
- SEPARATOR is exactly two newlines: "\n\n" (never one, never three)
- Each segment stripped of trailing whitespace before assembly
- SHA256 validation catches mismatched canonical prefixes
Usage:
    normalizer = PrefixNormalizer(
        canonical_system_prompt="You are a helpful AI assistant."
    )
    # All agents use the same normalizer
    prompt1 = normalizer.normalize("agent1", "What is AI?", "retriever role")
    prompt2 = normalizer.normalize("agent2", "What is AI?", "summarizer role")
    # prompt1 and prompt2 are byte-identical at the system prompt prefix
"""
import hashlib
import logging
from typing import Optional
logger = logging.getLogger(__name__)

# Fixed separator between prompt segments: exactly two newlines, per the
# module contract (never one, never three — any deviation changes the token
# sequence and defeats vLLM prefix caching).
SEPARATOR = "\n\n"


class PrefixNormalizer:
    """
    Enforces token-identical prefixes for vLLM prefix caching.

    All agents must use the same canonical_system_prompt. Any deviation
    is logged as a WARNING (not ERROR) because vLLM silently degrades
    to non-cached computation when prefixes don't match.

    Usage:
        normalizer = PrefixNormalizer(
            canonical_system_prompt="You are a helpful AI assistant."
        )
        final_prompt = normalizer.normalize(
            agent_id="agent1",
            user_prompt="What is machine learning?",
            agent_role_prompt="You are a retriever agent."
        )
    """

    def __init__(
        self,
        canonical_system_prompt: str,
        separator: str = SEPARATOR,
    ):
        """
        Initialize with the shared system prompt.

        Args:
            canonical_system_prompt: The shared base prompt (must be identical
                byte-for-byte across all agents). Stripped of surrounding
                whitespace before hashing so cosmetic padding cannot break
                prefix identity.
            separator: Separator between segments (default: two newlines)
        """
        self._canonical_system_prompt = canonical_system_prompt.strip()
        self._separator = separator
        self._canonical_hash = self._compute_hash(self._canonical_system_prompt)
        # Agents seen by normalize(); used only for one-time debug logging.
        self._registered_agents: set[str] = set()
        # Lazy %-style args: avoids string formatting when INFO is disabled.
        logger.info(
            "PrefixNormalizer initialized with system prompt hash: %s...",
            self._canonical_hash[:16],
        )

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"canonical_hash={self._canonical_hash[:16]!r}, "
            f"separator={self._separator!r})"
        )

    @staticmethod
    def _compute_hash(text: str) -> str:
        """Compute the SHA256 hex digest of *text* (UTF-8 encoded)."""
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    def normalize(
        self,
        agent_id: str,
        user_prompt: str,
        agent_role_prompt: str,
    ) -> str:
        """
        Assemble final prompt in FIXED order with canonical system prompt.

        Order: [canonical_system_prompt][SEP][agent_role_prompt][SEP][user_prompt]

        Args:
            agent_id: Agent identifier (for logging only)
            user_prompt: User's query/input
            agent_role_prompt: Agent-specific role prompt

        Returns:
            Final assembled prompt with byte-identical system prefix
        """
        # Strip surrounding whitespace from each segment so stray spaces or
        # newlines cannot alter the token sequence. The system part was
        # already stripped in __init__.
        system_part = self._canonical_system_prompt
        role_part = agent_role_prompt.strip()
        user_part = user_prompt.strip()

        # Assemble in fixed order with the fixed separator.
        segments = [system_part, role_part, user_part]
        assembled = self._separator.join(segments)

        # The canonical system prompt is stored internally, so no hash check
        # is needed here; use validate_system_prompt() to vet an externally
        # supplied prompt. We just record first-seen agents for debugging.
        if agent_id not in self._registered_agents:
            self._registered_agents.add(agent_id)
            logger.debug("Registered agent %s with PrefixNormalizer", agent_id)

        return assembled

    def validate_system_prompt(self, system_prompt: str) -> bool:
        """
        Validate that a system prompt matches the canonical one.

        The candidate is stripped before hashing, mirroring __init__, so
        only meaningful differences cause a mismatch.

        Args:
            system_prompt: System prompt to validate

        Returns:
            True if identical, False otherwise
        """
        hash_to_check = self._compute_hash(system_prompt.strip())
        matches = hash_to_check == self._canonical_hash
        if not matches:
            logger.warning(
                "Agent system prompt hash MISMATCH. Expected %s, got %s. "
                "vLLM prefix caching will NOT work for this agent.",
                self._canonical_hash[:16],
                hash_to_check[:16],
            )
        return matches

    def get_canonical_hash(self) -> str:
        """Get SHA256 of the canonical system prompt."""
        return self._canonical_hash

    def get_canonical_prompt(self) -> str:
        """Get the canonical system prompt."""
        return self._canonical_system_prompt

    @property
    def separator(self) -> str:
        """Get the separator string."""
        return self._separator

    def compute_prompt_hash(self, prompt: str) -> str:
        """Compute hash of an assembled prompt (for debugging)."""
        return self._compute_hash(prompt)
def create_prefix_normalizer(
    canonical_system_prompt: Optional[str] = None,
) -> PrefixNormalizer:
    """
    Factory to create a PrefixNormalizer with default or custom system prompt.

    Args:
        canonical_system_prompt: Custom system prompt (optional). Any falsy
            value (None or empty string) selects the built-in default.

    Returns:
        Configured PrefixNormalizer instance
    """
    fallback_prompt = (
        "You are a helpful AI assistant. "
        "Provide accurate, detailed, and thoughtful responses. "
        "Use chain-of-thought reasoning when appropriate."
    )
    # Truthiness check (not an `is None` test) so an empty string also
    # falls back to the default, matching the original behavior.
    chosen_prompt = canonical_system_prompt if canonical_system_prompt else fallback_prompt
    return PrefixNormalizer(
        canonical_system_prompt=chosen_prompt,
        separator=SEPARATOR,
    )