linyueqian committed on
Commit
3690391
·
verified ·
1 Parent(s): 2b5660c

feat: add custom tokenizer with multi-char Chinese token splitting

Browse files

VoxCPM2 was trained with `mask_multichar_chinese_tokens` which splits
multi-character Chinese tokens (e.g. "你好" → ["你", "好"]) into individual
character IDs. The current HF repo ships a plain LlamaTokenizerFast without
this splitting, causing downstream inference frameworks (vLLM-Omni, etc.) to
produce garbled Chinese audio.

This PR adds `VoxCPM2Tokenizer` (subclass of LlamaTokenizerFast) that applies
char-splitting inside `encode()` and `__call__()` transparently.

Note: `tokenizer_config.json` also needs `auto_map` and `tokenizer_class`
updates to point to this custom tokenizer. Will add in a follow-up commit
to this PR.

Files changed (1) hide show
  1. tokenization_voxcpm2.py +72 -0
tokenization_voxcpm2.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Custom tokenizer for VoxCPM2 that splits multi-character Chinese tokens.
2
+
3
+ VoxCPM2 was trained with ``mask_multichar_chinese_tokens`` which splits
4
+ multi-character Chinese tokens (e.g. "你好" -> ["你", "好"]) into individual
5
+ character IDs before embedding. The base LlamaTokenizerFast produces
6
+ multi-character Chinese tokens that the model has never seen during training,
7
+ yielding garbled Chinese audio output in downstream inference frameworks.
8
+
9
+ This module provides ``VoxCPM2Tokenizer`` which transparently applies the
10
+ character splitting inside ``encode()`` and ``__call__()``, so any downstream
11
+ consumer (vLLM, vLLM-Omni, Nano-vLLM, etc.) gets correct single-character
12
+ IDs without code changes.
13
+ """
14
+
15
+ from transformers import LlamaTokenizerFast
16
+
17
+
18
class VoxCPM2Tokenizer(LlamaTokenizerFast):
    """LlamaTokenizerFast that expands multi-character CJK tokens into chars.

    VoxCPM2 was trained with ``mask_multichar_chinese_tokens``: every
    multi-character Chinese token (e.g. "你好") was replaced by the sequence of
    its single-character token ids (["你", "好"]) before embedding. The plain
    LlamaTokenizerFast emits the merged ids, which the model never saw during
    training. This subclass builds a one-time id -> [char ids] map from the
    vocabulary and applies it transparently in ``encode()`` and ``__call__()``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Built once per tokenizer instance: merged-token id -> list of
        # single-character token ids.
        self._split_map = self._build_split_map()

    def _build_split_map(self) -> dict[int, list[int]]:
        """Map each multi-character CJK token id to its per-character ids.

        A vocabulary token is only added when it is two or more characters
        long, every character is a CJK ideograph, and every character exists
        in the vocabulary as a standalone token — so expansion never loses
        information by substituting UNK ids.
        """
        vocab = self.get_vocab()
        split_map: dict[int, list[int]] = {}
        for token, tid in vocab.items():
            # SentencePiece marks a leading space with U+2581 ("▁"); strip it
            # so "▁你好" is still recognised as a pure-CJK token.
            # NOTE(review): the ▁ marker is dropped from the expansion, i.e.
            # the leading-space information is discarded — presumably this
            # matches how training masked these tokens; confirm upstream.
            clean = token.replace("\u2581", "")
            if len(clean) >= 2 and all(self._is_cjk(c) for c in clean):
                char_ids = self.convert_tokens_to_ids(list(clean))
                # Skip tokens with characters missing from the vocab. Fast
                # tokenizers report those as unk_token_id, or as None when no
                # UNK token is configured — guard against both.
                if all(
                    c is not None and c != self.unk_token_id for c in char_ids
                ):
                    split_map[tid] = char_ids
        return split_map

    @staticmethod
    def _is_cjk(c: str) -> bool:
        """Return True if *c* lies in the main CJK Unified Ideograph ranges."""
        return (
            "\u4e00" <= c <= "\u9fff"  # CJK Unified Ideographs
            or "\u3400" <= c <= "\u4dbf"  # Extension A
            or "\uf900" <= c <= "\ufaff"  # Compatibility Ideographs
            or "\U00020000" <= c <= "\U0002a6df"  # Extension B
        )

    def _expand_ids(self, ids: list[int]) -> list[int]:
        """Replace every splittable token id with its single-character ids."""
        result: list[int] = []
        for tid in ids:
            expansion = self._split_map.get(tid)
            if expansion is not None:
                result.extend(expansion)
            else:
                result.append(tid)
        return result

    def encode(self, text, *args, **kwargs):
        """Encode *text* with the base tokenizer, then expand CJK token ids."""
        ids = super().encode(text, *args, **kwargs)
        return self._expand_ids(ids)

    def __call__(self, text, *args, **kwargs):
        """Tokenize like the base class, then expand multi-char CJK ids.

        Expansion is applied only when ``input_ids`` are plain Python lists
        (single sequence or batch). Tensor outputs (``return_tensors=...``)
        are returned untouched: expanding them element-wise would silently
        convert the tensor batch into a list of row tensors, corrupting the
        result — callers needing expansion should tokenize to lists first.

        NOTE(review): the attention mask is rebuilt as all-ones because
        expansion changes sequence lengths; if padding/truncation was
        requested, pad positions and max_length are not re-applied after
        expansion — confirm downstream callers do not rely on them.
        """
        result = super().__call__(text, *args, **kwargs)
        if not (isinstance(result, dict) and "input_ids" in result):
            return result
        ids = result["input_ids"]
        if isinstance(ids, list) and ids and isinstance(ids[0], list):
            # Batched input: expand each sequence, then rebuild the masks
            # to match the new (longer) sequence lengths.
            result["input_ids"] = [self._expand_ids(seq) for seq in ids]
            if "attention_mask" in result:
                result["attention_mask"] = [
                    [1] * len(seq) for seq in result["input_ids"]
                ]
        elif isinstance(ids, list):
            # Single (possibly empty) sequence.
            result["input_ids"] = self._expand_ids(ids)
            if "attention_mask" in result:
                result["attention_mask"] = [1] * len(result["input_ids"])
        return result