Spaces:
Sleeping
Sleeping
| """Adaptive Compression Budget Manager v3.0 - Dynamic per-segment rates. | |
| Replaces static COMPRESSION_BUDGET table with dynamic rates that: | |
| 1. Vary by segment_type (validated against LLMLingua-2 research, ACL 2024 Findings) | |
| 2. Respond to VRAM pressure (emergency compression when GPU memory is tight) | |
| 3. Use sample-wise probability threshold θ (dynamic per-segment, not fixed ratio) | |
| Key rates (from LLMLingua-2 §L): | |
| - system_prompt: 0.9 (near-lossless - role-critical information must be preserved) | |
| - shared_context: 0.5 (high compression - shared docs have high redundancy) | |
| - agent_output: 0.7 (moderate - reasoning chains have task-critical steps) | |
| - tool_result: 0.6 (moderate-high - tool outputs often contain padded JSON/XML) | |
| - user_query: 1.0 (NEVER compress - user intent must be preserved exactly) | |
| Under VRAM pressure > 0.85: multiply rates by 0.8 (emergency); user_query and system_prompt are exempt. | |
| Usage: | |
| manager = CompressionBudgetManager() | |
| rate = manager.get_rate_for_segment("shared_context", token_count=1000, vram_pressure=0.5) | |
| # rate = 0.5 (normal) | |
| rate_emergency = manager.get_rate_for_segment("shared_context", token_count=1000, vram_pressure=0.9) | |
| # rate = 0.4 (0.5 * 0.8 emergency multiplier) | |
| """ | |
| import asyncio | |
| import logging | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from typing import Optional | |
logger = logging.getLogger(__name__)

# Segments with fewer tokens than this are planned as "do not compress":
# below this size the compression overhead is not considered worthwhile.
COMPRESSION_MIN_TOKENS = 512

# VRAM utilization (0.0-1.0) above which emergency compression kicks in.
VRAM_EMERGENCY_THRESHOLD = 0.85

# Factor applied to a segment's rate once VRAM pressure exceeds the threshold.
VRAM_EMERGENCY_MULTIPLIER = 0.8
class SegmentType(Enum):
    """Kind of content segment; the kind selects the compression budget.

    Each member's value is the string name accepted wherever a segment
    type is passed as text (for example ``SegmentType("tool_result")``).
    """

    SYSTEM_PROMPT = "system_prompt"
    SHARED_CONTEXT = "shared_context"
    AGENT_OUTPUT = "agent_output"
    TOOL_RESULT = "tool_result"
    USER_QUERY = "user_query"
    RETRIEVED_DOCS = "retrieved_docs"
    CONV_HISTORY = "conv_history"
    RECENT_TURNS = "recent_turns"
    COT_REASONING = "cot_reasoning"
    RAG_CHUNK = "rag_chunk"
    # Fallback for anything that cannot be classified.
    UNKNOWN = "unknown"
# Dynamic compression rate table, keyed by segment type.
# Scale: 1.0 = no compression; LOWER values compress MORE aggressively
# (e.g. 0.9 is near-lossless, 0.4 is heavy compression).
# Source: LLMLingua-2 research (ACL 2024 Findings) - dynamic per-sample approach
DYNAMIC_RATE_TABLE: dict[SegmentType, float] = {
    # Near-lossless: system prompts are dense with role-critical information
    SegmentType.SYSTEM_PROMPT: 0.9,
    # High compression: shared retrieved docs have high redundancy
    SegmentType.SHARED_CONTEXT: 0.5,
    SegmentType.RETRIEVED_DOCS: 0.5,
    # Moderate: agent reasoning chains contain task-critical steps
    SegmentType.AGENT_OUTPUT: 0.7,
    SegmentType.COT_REASONING: 0.7,
    # Moderate-high: tool outputs often contain padded JSON/XML
    SegmentType.TOOL_RESULT: 0.6,
    # High compression: resolved context is safe to compress
    SegmentType.CONV_HISTORY: 0.4,
    SegmentType.RAG_CHUNK: 0.4,
    # NO compression: recent relevance and user intent must be exact.
    # Both entries use 1.0; RECENT_TURNS was previously 0.0, which on this
    # scale is the most aggressive setting rather than "no compression".
    SegmentType.RECENT_TURNS: 1.0,
    SegmentType.USER_QUERY: 1.0,
    # Safe default
    SegmentType.UNKNOWN: 0.5,
}
@dataclass
class CompressionPlan:
    """Compression plan for a single segment.

    Instances are built with keyword arguments, so the class must be a
    dataclass to get the generated ``__init__``.

    Attributes:
        segment: Text content the plan refers to.
        segment_type: Kind of content the segment was planned as.
        original_tokens: Token count of the segment before compression.
        target_rate: Rate handed to the compressor when ``should_compress``
            is True. Same scale as ``DYNAMIC_RATE_TABLE``: 1.0 means no
            compression, lower values compress more aggressively.
        should_compress: Whether the compressor should be run at all.
        reason: Human-readable explanation of the decision.
        emergency: True if the VRAM emergency multiplier was applied.
    """

    segment: str
    segment_type: SegmentType
    original_tokens: int
    target_rate: float
    should_compress: bool
    reason: str
    emergency: bool = False
class CompressionBudgetManager:
    """Dynamic compression budget manager with VRAM-pressure-responsive rates.

    Rates come from ``DYNAMIC_RATE_TABLE`` and share its scale: 1.0 means
    "no compression" and lower values compress more aggressively.

    Design note (from the original author): the rate is meant to drive a
    dynamic per-sample probability threshold rather than a strictly
    enforced ratio, so the achieved compression may vary per segment.

    Usage:
        manager = CompressionBudgetManager()
        plan = manager.plan(segment_text, SegmentType.SHARED_CONTEXT)
        # Or get the rate directly for custom compression
        rate = manager.get_rate_for_segment(
            "agent_output", token_count=1000, vram_pressure=0.5
        )
    """

    # Segment types that must pass through untouched, with the reason that
    # is reported in the plan. The rate table's comments mark both of these
    # as "NO compression".
    _NEVER_COMPRESS_REASONS = {
        SegmentType.USER_QUERY: "user_query: never compress (intent must be preserved)",
        SegmentType.RECENT_TURNS: "recent_turns: never compress (recent turns must be preserved exactly)",
    }

    def __init__(self):
        # Not acquired anywhere in this class; kept so existing attribute
        # access on instances keeps working.
        self._lock = asyncio.Lock()

    def get_rate_for_segment(
        self,
        segment_type: str,
        token_count: int,
        vram_pressure: float = 0.0,
    ) -> float:
        """Get the compression rate for a segment type, adjusted for VRAM pressure.

        Args:
            segment_type: String name of segment type (e.g., "shared_context").
                Unrecognized names are treated as ``SegmentType.UNKNOWN``.
            token_count: Number of tokens in segment. Currently not used in
                the calculation; accepted for interface compatibility.
            vram_pressure: Current VRAM utilization (0.0-1.0).

        Returns:
            Compression rate (0.0-1.0); 1.0 means no compression.
        """
        try:
            kind = SegmentType(segment_type)
        except ValueError:
            kind = SegmentType.UNKNOWN

        # User queries and recent turns are never compressed.
        if kind in self._NEVER_COMPRESS_REASONS:
            return 1.0

        rate = DYNAMIC_RATE_TABLE.get(kind, DYNAMIC_RATE_TABLE[SegmentType.UNKNOWN])

        # System prompts keep their near-lossless table rate even under
        # VRAM pressure; the emergency multiplier is deliberately skipped.
        if kind == SegmentType.SYSTEM_PROMPT:
            return rate

        if vram_pressure > VRAM_EMERGENCY_THRESHOLD:
            rate = rate * VRAM_EMERGENCY_MULTIPLIER
        return rate

    def plan(
        self,
        segment: str,
        segment_type: SegmentType,
        token_count: Optional[int] = None,
        vram_pressure: float = 0.0,
    ) -> CompressionPlan:
        """Create a compression plan for a segment.

        Args:
            segment: Text content to potentially compress.
            segment_type: Type of content (determines budget).
            token_count: Optional pre-computed token count. When omitted,
                the project's TokenCounter is used to count ``segment``.
            vram_pressure: Current VRAM utilization for emergency detection.

        Returns:
            CompressionPlan with decision and parameters.
        """
        if token_count is None:
            # Imported only when a count is actually needed, so callers that
            # supply token_count do not depend on the token counter module.
            from apohara_context_forge.token_counter import TokenCounter

            token_count = TokenCounter.get().count(segment)

        # Hard rule: user queries and recent turns pass through untouched.
        never_reason = self._NEVER_COMPRESS_REASONS.get(segment_type)
        if never_reason is not None:
            return CompressionPlan(
                segment=segment,
                segment_type=segment_type,
                original_tokens=token_count,
                target_rate=1.0,
                should_compress=False,
                reason=never_reason,
            )

        rate = self.get_rate_for_segment(segment_type.value, token_count, vram_pressure)

        # System prompts are always planned for near-lossless compression.
        # This check runs before the minimum-length check, so short system
        # prompts are planned for compression too (original behavior).
        if segment_type == SegmentType.SYSTEM_PROMPT:
            return CompressionPlan(
                segment=segment,
                segment_type=segment_type,
                original_tokens=token_count,
                target_rate=rate,
                should_compress=True,
                reason="system_prompt: near-lossless compression (prefix cache ok)",
            )

        # Skip compression for too-short segments. target_rate is not
        # consulted when should_compress is False; 0.0 is kept for
        # backward compatibility with existing consumers of the plan.
        if token_count < COMPRESSION_MIN_TOKENS:
            return CompressionPlan(
                segment=segment,
                segment_type=segment_type,
                original_tokens=token_count,
                target_rate=0.0,
                should_compress=False,
                reason=f"too short ({token_count} tokens < {COMPRESSION_MIN_TOKENS} minimum)",
            )

        emergency = vram_pressure > VRAM_EMERGENCY_THRESHOLD
        reason = f"{segment_type.value}: rate={rate} (vram_pressure={vram_pressure:.2f})"
        if emergency:
            reason += " [EMERGENCY]"
        return CompressionPlan(
            segment=segment,
            segment_type=segment_type,
            original_tokens=token_count,
            target_rate=rate,
            should_compress=True,
            reason=reason,
            emergency=emergency,
        )

    async def compress_with_plan(self, plan: CompressionPlan) -> tuple[str, float]:
        """Execute compression according to plan.

        Args:
            plan: CompressionPlan from .plan()

        Returns:
            Tuple of (compressed_text, actual_compression_ratio). When the
            plan says not to compress, the original segment is returned
            with a ratio of 1.0.
        """
        if not plan.should_compress:
            return plan.segment, 1.0

        from apohara_context_forge.compression.compressor import ContextCompressor

        # NOTE(review): a new compressor is constructed and loaded on every
        # call; confirm whether load() is cheap or cached by the compressor.
        compressor = ContextCompressor()
        await compressor.load()
        return await compressor.compress(
            plan.segment,
            rate=plan.target_rate,
        )

    def plan_and_compress(
        self,
        segment: str,
        segment_type: SegmentType,
        vram_pressure: float = 0.0,
    ) -> tuple[CompressionPlan, Optional[tuple[str, float]]]:
        """Create a plan synchronously; compression itself is left to the caller.

        This helper never runs the async compressor, so the second element
        of the returned tuple is always None. When ``plan.should_compress``
        is True, the caller should ``await compress_with_plan(plan)`` to
        perform the actual compression.

        Returns:
            ``(plan, None)``.
        """
        return self.plan(segment, segment_type, vram_pressure=vram_pressure), None
def detect_segment_type(segment: str) -> SegmentType:
    """Guess a segment's type from keyword patterns in its text.

    Checks run in a fixed order and the first match wins. Most checks only
    look at a prefix of the lower-cased text; the chain-of-thought check
    scans the whole segment. Override with an explicit type when known.

    Args:
        segment: Text content to classify.

    Returns:
        The first matching SegmentType, or ``SegmentType.UNKNOWN``.
    """
    # Lower-case once; every check below works on this copy. Slicing is done
    # after lower-casing so prefix boundaries match the lowered text.
    lowered = segment.lower()

    def prefix_has(limit: int, indicators: tuple) -> bool:
        """Return True if any indicator occurs in the first ``limit`` chars."""
        head = lowered[:limit]
        return any(indicator in head for indicator in indicators)

    # System prompt indicators
    if prefix_has(100, ("system:", "instructions:", "# system", "you are a ")):
        return SegmentType.SYSTEM_PROMPT

    # User query indicators (should be near the start)
    if prefix_has(50, ("query:", "question:", "what is", "how do", "tell me")):
        return SegmentType.USER_QUERY

    # Tool output indicators
    if prefix_has(
        100,
        ("tool:", "function:", "execution result:", "output:", "tool result:"),
    ):
        return SegmentType.TOOL_RESULT

    # CoT reasoning is checked before agent output, because a bare "step"
    # near the start would otherwise be classified as agent output.
    if "step by step" in lowered or ("step" in lowered and "reasoning" in lowered):
        return SegmentType.COT_REASONING

    # Agent output indicators (after CoT)
    if prefix_has(150, ("summarized", "analyzed", "reasoning:", "step")):
        return SegmentType.AGENT_OUTPUT

    # RAG / retrieved content
    if prefix_has(200, ("document", "retrieved", "context:", "reference:")):
        return SegmentType.RETRIEVED_DOCS

    # Shared context (general knowledge). "context:" is already claimed by
    # the retrieved-docs check above, so it never matches here.
    if prefix_has(200, ("knowledge", "context:", "background:")):
        return SegmentType.SHARED_CONTEXT

    return SegmentType.UNKNOWN
# Backwards-compatibility alias: the name of the former static table now
# refers to the same dict object as DYNAMIC_RATE_TABLE.
COMPRESSION_BUDGET = DYNAMIC_RATE_TABLE