"""Token counting via real Qwen3 tokenizer - fixes BUG-001.

Replaces heuristic len(text.split()) // 4 * 3 with accurate tokenization.
Uses transformers AutoTokenizer for Qwen3-35B-A3B (or fallback).
"""
import asyncio
import logging
from typing import Optional

logger = logging.getLogger(__name__)


class TokenCounter:
    """
    Accurate token counter backed by the Qwen3 tokenizer.

    Implemented as a singleton; the tokenizer is loaded lazily on first use.
    
    Usage:
        counter = TokenCounter.get()
        token_count = counter.count("Hello world")
        token_ids = counter.encode("Hello world")
        kv_bytes = counter.compute_kv_vram_bytes(token_count)
    """
    
    _instance: Optional["TokenCounter"] = None
    
    def __init__(
        self,
        model_id: str = "Qwen/Qwen3-235B-A22B",
        use_fast: bool = True,
    ):
        self._model_id = model_id
        self._use_fast = use_fast
        self._tokenizer = None
        self._initialized = False
        self._use_fallback = False
    
    @classmethod
    def get(cls, model_id: str = "Qwen/Qwen3-235B-A22B") -> "TokenCounter":
        """Get or create singleton instance."""
        if cls._instance is None:
            cls._instance = cls(model_id)
        return cls._instance
    
    @classmethod
    def reset(cls) -> None:
        """Reset singleton (for testing)."""
        cls._instance = None
    
    def _ensure_initialized(self) -> None:
        """Lazy initialization of tokenizer."""
        if self._initialized:
            return
        
        try:
            from transformers import AutoTokenizer
            self._tokenizer = AutoTokenizer.from_pretrained(
                self._model_id,
                trust_remote_code=True,
                use_fast=self._use_fast,
            )
            self._initialized = True
            logger.info(f"TokenCounter initialized with {self._model_id}")
        except Exception as e:
            logger.warning(f"Failed to load {self._model_id}: {e}. Using fallback.")
            self._use_fallback = True
            self._initialized = True
    
    def count(self, text: str) -> int:
        """
        Count tokens in text (blocking - use count_async in hot path).
        
        Args:
            text: Input string
            
        Returns:
            Number of tokens
        """
        self._ensure_initialized()
        
        if self._use_fallback:
            # Rough fallback (mirrors the old word-count heuristic): ~0.75 tokens per word
            return max(1, int(len(text.split()) * 0.75))
        
        return len(self._tokenizer.encode(text, add_special_tokens=False))
    
    def encode(self, text: str) -> list[int]:
        """
        Encode text to token IDs (blocking).
        
        Args:
            text: Input string
            
        Returns:
            List of token IDs
        """
        self._ensure_initialized()
        
        if self._use_fallback:
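            # Placeholder IDs only: hashed words are not real vocab indices
            # and cannot be round-tripped through decode().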
            return [hash(w) % 50000 for w in text.split()]
        
        return self._tokenizer.encode(text, add_special_tokens=False)
    
    def decode(self, token_ids: list[int]) -> str:
        """Decode token IDs back to text."""
        self._ensure_initialized()
        
        if self._use_fallback:
            return " ".join(str(t) for t in token_ids)
        
        return self._tokenizer.decode(token_ids, skip_special_tokens=True)
    
    async def count_async(self, text: str) -> int:
        """
        Async token counting - non-blocking in hot path.
        
        Args:
            text: Input string
            
        Returns:
            Number of tokens
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.count, text)
    
    async def encode_async(self, text: str) -> list[int]:
        """
        Async encoding - non-blocking in hot path.
        
        Args:
            text: Input string
            
        Returns:
            List of token IDs
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.encode, text)
    
    def compute_kv_vram_bytes(
        self,
        token_count: int,
        n_layers: int = 64,
        n_kv_heads: int = 8,
        head_dim: int = 128,
        dtype_bytes: int = 2,  # fp16 = 2 bytes, bf16 = 2 bytes
    ) -> int:
        """
        Compute VRAM bytes for KV cache given token count.
        
        Formula: 2 (K+V) × layers × tokens × kv_heads × head_dim × dtype_bytes
        
        Args:
            token_count: Number of tokens in context
            n_layers: Number of transformer layers (model-dependent; default 64)
            n_kv_heads: Number of KV heads (Qwen3 uses GQA; model-dependent, default 8)
            head_dim: Dimension per head (typically 128 for Qwen)
            dtype_bytes: Bytes per value (2 for fp16/bf16)
            
        Returns:
            VRAM bytes needed for KV cache
        """
        return 2 * n_layers * token_count * n_kv_heads * head_dim * dtype_bytes
    
    def compute_kv_vram_gb(
        self,
        token_count: int,
        **kwargs
    ) -> float:
        """Compute VRAM in gigabytes."""
        return self.compute_kv_vram_bytes(token_count, **kwargs) / (1024 ** 3)


# Convenience functions for use throughout the codebase
def count_tokens(text: str) -> int:
    """Quick token count."""
    return TokenCounter.get().count(text)


def encode_tokens(text: str) -> list[int]:
    """Quick token encode."""
    return TokenCounter.get().encode(text)


def compute_kv_gb(token_count: int, **kwargs) -> float:
    """Quick KV VRAM compute in GB."""
    return TokenCounter.get().compute_kv_vram_gb(token_count, **kwargs)
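

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; not part of the library API).
    # Whether real token counts or the word-count fallback are printed depends
    # on whether the Qwen tokenizer could be loaded in this environment.
    counter = TokenCounter.get()
    sample = "The quick brown fox jumps over the lazy dog."

    print(f"tokens: {counter.count(sample)}")
    print(f"first ids: {counter.encode(sample)[:8]}")
    print(f"KV cache for a 32k-token context: {counter.compute_kv_vram_gb(32_768):.2f} GiB")

    # The async variants offload tokenization to a thread pool, so they can be
    # awaited from a request handler without stalling the event loop.
    print(f"tokens (async): {asyncio.run(counter.count_async(sample))}")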