File size: 5,952 Bytes
24d9eca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""Prefix Normalizer for vLLM prefix caching (enable_prefix_caching=True).

vLLM requires token-identical prefixes across requests to trigger KV cache hits.
A single extra space or different capitalization creates a completely different
token sequence and breaks cache sharing.

Key enforcement:
- FIXED order: [canonical_system_prompt][SEP][agent_role_prompt][SEP][user_prompt]
- SEPARATOR is exactly two newlines: "\n\n" (never one, never three)
- Each segment stripped of trailing whitespace before assembly
- SHA256 validation catches mismatched canonical prefixes

Usage:
    normalizer = PrefixNormalizer(
        canonical_system_prompt="You are a helpful AI assistant."
    )

    # All agents use the same normalizer
    prompt1 = normalizer.normalize("agent1", "What is AI?", "retriever role")
    prompt2 = normalizer.normalize("agent2", "What is AI?", "summarizer role")

    # prompt1 and prompt2 are byte-identical at the system prompt prefix
"""
import hashlib
import logging
from typing import Optional

logger = logging.getLogger(__name__)

# Fixed separator between prompt segments
SEPARATOR = "\n\n"


class PrefixNormalizer:
    """
    Enforces token-identical prefixes for vLLM prefix caching.

    All agents must use the same canonical_system_prompt. Any deviation
    is logged as a WARNING (not ERROR) because vLLM silently degrades
    to non-cached computation when prefixes don't match.

    Usage:
        normalizer = PrefixNormalizer(
            canonical_system_prompt="You are a helpful AI assistant."
        )
        final_prompt = normalizer.normalize(
            agent_id="agent1",
            user_prompt="What is machine learning?",
            agent_role_prompt="You are a retriever agent."
        )
    """

    def __init__(
        self,
        canonical_system_prompt: str,
        separator: str = SEPARATOR,
    ):
        """
        Initialize with the shared system prompt.

        Args:
            canonical_system_prompt: The shared base prompt (must be identical
                                     byte-for-byte across all agents)
            separator: Separator between segments (default: two newlines)
        """
        # Strip once at construction so every assembled prompt reuses the
        # exact same byte sequence for the system segment.
        self._canonical_system_prompt = canonical_system_prompt.strip()
        self._separator = separator
        self._canonical_hash = self._compute_hash(self._canonical_system_prompt)
        # Agents seen so far (informational only; no behavior depends on it).
        self._registered_agents: set[str] = set()

        # Lazy %-style args: formatting is skipped when INFO is disabled.
        logger.info(
            "PrefixNormalizer initialized with system prompt hash: %s...",
            self._canonical_hash[:16],
        )

    @staticmethod
    def _compute_hash(text: str) -> str:
        """Compute the SHA256 hex digest of *text* (UTF-8 encoded)."""
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    def normalize(
        self,
        agent_id: str,
        user_prompt: str,
        agent_role_prompt: str,
    ) -> str:
        """
        Assemble final prompt in FIXED order with canonical system prompt.

        Order: [canonical_system_prompt][SEP][agent_role_prompt][SEP][user_prompt]

        Args:
            agent_id: Agent identifier (for logging only)
            user_prompt: User's query/input
            agent_role_prompt: Agent-specific role prompt

        Returns:
            Final assembled prompt with byte-identical system prefix
        """
        # Strip surrounding whitespace from the variable segments; the system
        # segment was already stripped in __init__.
        role_part = agent_role_prompt.strip()
        user_part = user_prompt.strip()

        # Assemble in fixed order. The canonical hash is NOT re-checked here:
        # the stored system prompt cannot drift after __init__. External
        # prompts are checked via validate_system_prompt().
        assembled = self._separator.join(
            [self._canonical_system_prompt, role_part, user_part]
        )

        # set.add is idempotent, so no membership pre-check is needed.
        self._registered_agents.add(agent_id)

        return assembled

    def validate_system_prompt(self, system_prompt: str) -> bool:
        """
        Validate that a system prompt matches the canonical one.

        Args:
            system_prompt: System prompt to validate

        Returns:
            True if identical (after stripping), False otherwise
        """
        hash_to_check = self._compute_hash(system_prompt.strip())
        matches = hash_to_check == self._canonical_hash

        if not matches:
            # WARNING (not ERROR): a mismatch only degrades performance —
            # vLLM falls back to uncached prefill rather than failing.
            logger.warning(
                "Agent system prompt hash MISMATCH. Expected %s, got %s. "
                "vLLM prefix caching will NOT work for this agent.",
                self._canonical_hash[:16],
                hash_to_check[:16],
            )

        return matches

    def get_canonical_hash(self) -> str:
        """Get SHA256 of the canonical system prompt."""
        return self._canonical_hash

    def get_canonical_prompt(self) -> str:
        """Get the canonical system prompt."""
        return self._canonical_system_prompt

    @property
    def separator(self) -> str:
        """Get the separator string."""
        return self._separator

    def compute_prompt_hash(self, prompt: str) -> str:
        """Compute the SHA256 hex digest of an assembled prompt (debugging aid)."""
        return self._compute_hash(prompt)


def create_prefix_normalizer(
    canonical_system_prompt: Optional[str] = None,
) -> PrefixNormalizer:
    """
    Factory to create a PrefixNormalizer with default or custom system prompt.

    Args:
        canonical_system_prompt: Custom system prompt (optional)

    Returns:
        Configured PrefixNormalizer instance
    """
    fallback = (
        "You are a helpful AI assistant. "
        "Provide accurate, detailed, and thoughtful responses. "
        "Use chain-of-thought reasoning when appropriate."
    )
    # Falsy input (None or "") selects the default prompt, matching the
    # original `or` semantics.
    chosen_prompt = canonical_system_prompt or fallback

    return PrefixNormalizer(
        canonical_system_prompt=chosen_prompt,
        separator=SEPARATOR,
    )