| """Token counting via real Qwen3 tokenizer - fixes BUG-001. | |
| Replaces heuristic len(text.split()) // 4 * 3 with accurate tokenization. | |
| Uses transformers AutoTokenizer for Qwen3-35B-A3B (or fallback). | |
| """ | |
import asyncio
import logging
from typing import Optional

logger = logging.getLogger(__name__)


class TokenCounter:
    """
    Accurate token counter using the Qwen3 tokenizer.

    Singleton pattern for lazy initialization.

    Usage:
        counter = TokenCounter.get()
        token_count = counter.count("Hello world")
        token_ids = counter.encode("Hello world")
        kv_bytes = counter.compute_kv_vram_bytes(token_count)
    """

    _instance: Optional["TokenCounter"] = None

    def __init__(
        self,
        model_id: str = "Qwen/Qwen3-235B-A22B",
        use_fast: bool = True,
    ):
        self._model_id = model_id
        self._use_fast = use_fast
        self._tokenizer = None
        self._initialized = False
        self._use_fallback = False

    @classmethod
    def get(cls, model_id: str = "Qwen/Qwen3-235B-A22B") -> "TokenCounter":
        """Get or create singleton instance."""
        if cls._instance is None:
            cls._instance = cls(model_id)
        return cls._instance

    @classmethod
    def reset(cls) -> None:
        """Reset singleton (for testing)."""
        cls._instance = None

    def _ensure_initialized(self) -> None:
        """Lazy initialization of the tokenizer."""
        if self._initialized:
            return
        try:
            from transformers import AutoTokenizer

            self._tokenizer = AutoTokenizer.from_pretrained(
                self._model_id,
                trust_remote_code=True,
                use_fast=self._use_fast,
            )
            self._initialized = True
            logger.info(f"TokenCounter initialized with {self._model_id}")
        except Exception as e:
            logger.warning(f"Failed to load {self._model_id}: {e}. Using fallback.")
            self._use_fallback = True
            self._initialized = True

    def count(self, text: str) -> int:
        """
        Count tokens in text (blocking - use count_async in the hot path).

        Args:
            text: Input string

        Returns:
            Number of tokens
        """
        self._ensure_initialized()
        if self._use_fallback:
            # Rough fallback: ~0.75 tokens per word
            return max(1, int(len(text.split()) * 0.75))
        return len(self._tokenizer.encode(text, add_special_tokens=False))

    def encode(self, text: str) -> list[int]:
        """
        Encode text to token IDs (blocking).

        Args:
            text: Input string

        Returns:
            List of token IDs
        """
        self._ensure_initialized()
        if self._use_fallback:
            # Fallback: deterministic pseudo-IDs; only useful for sizing estimates.
            return [hash(w) % 50000 for w in text.split()]
        return self._tokenizer.encode(text, add_special_tokens=False)

    def decode(self, token_ids: list[int]) -> str:
        """Decode token IDs back to text."""
        self._ensure_initialized()
        if self._use_fallback:
            return " ".join(str(t) for t in token_ids)
        return self._tokenizer.decode(token_ids, skip_special_tokens=True)

    async def count_async(self, text: str) -> int:
        """
        Async token counting - non-blocking in the hot path.

        Args:
            text: Input string

        Returns:
            Number of tokens
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.count, text)

    async def encode_async(self, text: str) -> list[int]:
        """
        Async encoding - non-blocking in the hot path.

        Args:
            text: Input string

        Returns:
            List of token IDs
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.encode, text)

    def compute_kv_vram_bytes(
        self,
        token_count: int,
        n_layers: int = 64,
        n_kv_heads: int = 8,
        head_dim: int = 128,
        dtype_bytes: int = 2,  # fp16 = 2 bytes, bf16 = 2 bytes
    ) -> int:
        """
        Compute VRAM bytes for the KV cache given a token count.

        Formula: 2 (K+V) × layers × tokens × kv_heads × head_dim × dtype_bytes

        Args:
            token_count: Number of tokens in context
            n_layers: Number of transformer layers
            n_kv_heads: Number of KV heads (Qwen3 uses GQA)
            head_dim: Dimension per head (typically 128 for Qwen)
            dtype_bytes: Bytes per value (2 for fp16/bf16)

        Returns:
            VRAM bytes needed for the KV cache
        """
        return 2 * n_layers * token_count * n_kv_heads * head_dim * dtype_bytes

    def compute_kv_vram_gb(
        self,
        token_count: int,
        **kwargs,
    ) -> float:
        """Compute KV-cache VRAM in gigabytes."""
        return self.compute_kv_vram_bytes(token_count, **kwargs) / (1024 ** 3)
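
    # Worked example with the defaults above (illustrative arithmetic, not a
    # measured figure): a 32,768-token context needs
    #   2 * 64 * 32768 * 8 * 128 * 2 bytes = 8,589,934,592 bytes ≈ 8.0 GiB.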


# Convenience functions for use throughout the codebase
def count_tokens(text: str) -> int:
    """Quick token count."""
    return TokenCounter.get().count(text)


def encode_tokens(text: str) -> list[int]:
    """Quick token encode."""
    return TokenCounter.get().encode(text)


def compute_kv_gb(token_count: int, **kwargs) -> float:
    """Quick KV VRAM compute in GB."""
    return TokenCounter.get().compute_kv_vram_gb(token_count, **kwargs)
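

# Usage sketch (an assumption-laden demo, not part of the library surface): it
# exercises the convenience helpers and the async path. The sample text and the
# 32,768-token context size are made-up illustration values; if `transformers`
# is missing, the heuristic fallback answers instead of the real tokenizer.
if __name__ == "__main__":

    async def _demo() -> None:
        counter = TokenCounter.get()
        sample = "Hello world, this is a quick tokenizer smoke test."
        # Blocking helpers for one-off calls.
        print("count:", count_tokens(sample))
        print("ids:  ", encode_tokens(sample)[:10])
        # Async variant for request handlers (runs in a thread executor).
        print("async count:", await counter.count_async(sample))
        # KV-cache sizing for a hypothetical 32,768-token context.
        print(f"KV VRAM: {compute_kv_gb(32_768):.2f} GiB")

    asyncio.run(_demo())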