File size: 2,154 Bytes
6d9c72b
 
 
 
 
cf0a8ed
6d9c72b
 
 
 
 
 
 
 
 
cf0a8ed
6d9c72b
 
 
 
 
 
 
 
cf0a8ed
6d9c72b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""LLMLingua-2 async wrapper - runs in ThreadPoolExecutor."""
import asyncio
import logging
from typing import Literal

from llmlingua import PromptCompressor

# Logger named after this module; used for load and compression diagnostics.
logger: logging.Logger = logging.getLogger(__name__)


class ContextCompressor:
    """Async wrapper for LLMLingua-2 compression."""

    def __init__(self, model_name: str = "microsoft/llmlingua-2-xlm-roberta-large-meetingbank"):
        self._model_name = model_name
        self._model: PromptCompressor | None = None
        self._lock = asyncio.Lock()

    async def load(self) -> None:
        """Lazy load the compressor model."""
        if self._model is None:
            async with self._lock:
                if self._model is None:
                    logger.info(f"Loading compressor: {self._model_name}")
                    self._model = PromptCompressor(self._model_name)

    async def compress(self, context: str, rate: float = 0.5) -> tuple[str, float]:
        """
        Compress context at given rate.
        Returns (compressed_text, actual_compression_ratio).
        """
        await self.load()
        loop = asyncio.get_event_loop()
        
        def sync_compress():
            assert self._model is not None
            result = self._model.compress_prompt(
                context,
                rate=rate,
                force_tokens=[".", "!", "?", ",", "\n"],
            )
            return result["compressed_prompt"]

        compressed = await loop.run_in_executor(None, sync_compress)
        original_tokens = len(context.split())
        compressed_tokens = len(compressed.split())
        actual_ratio = original_tokens / compressed_tokens if compressed_tokens > 0 else 1.0
        logger.debug(f"Compressed {original_tokens} -> {compressed_tokens} tokens (rate={rate})")
        return compressed, actual_ratio

    async def compress_batch(
        self, contexts: list[str], rate: float = 0.5
    ) -> list[tuple[str, float]]:
        """Compress multiple contexts."""
        results = []
        for ctx in contexts:
            compressed, ratio = await self.compress(ctx, rate)
            results.append((compressed, ratio))
        return results