Prince-1 committed on
Commit bccd4cb · verified · 1 Parent(s): 8514b28

Add files using upload-large-folder tool

Files changed (5)
  1. Convert.py +71 -0
  2. __init__.py +1 -0
  3. config.json +4 -0
  4. modeling_omnivoice.py +1598 -0
  5. tokenizer.model +3 -0
Convert.py ADDED
@@ -0,0 +1,71 @@
1
+ # /// script
2
+ # requires-python = ">=3.12"
3
+ # dependencies = [
4
+ # "accelerate>=1.13.0",
5
+ # "flash-linear-attention>=0.4.2",
6
+ # "hf-xet>=1.4.3",
7
+ # "huggingface-hub>=1.8.0",
8
+ # "onnx>=1.21.0",
9
+ # "onnx-ir>=0.2.0",
10
+ # "onnxruntime>=1.24.4",
11
+ # "onnxruntime-genai>=0.13.1",
12
+ # "optimum>=2.1.0",
13
+ # "sentencepiece>=0.2.1",
14
+ # "tiktoken>=0.12.0",
15
+ # "torch>=2.11.0",
16
+ # "transformers==5.7.0",
17
+ # ]
18
+ # ///
19
+
20
+ import argparse
21
+ from pathlib import Path
22
+ from huggingface_hub import snapshot_download
23
+ #from onnxruntime_genai.python.models.builder import create_model
24
+ from onnxruntime_genai.models.builder import create_model
25
+ def main():
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument("--name", required=False,default=None)
28
+ parser.add_argument("--token",required=False)
29
+ args = parser.parse_args()
30
+
31
+ token = args.token if args.token else None
32
+
33
+ pwd = Path.cwd()
34
+ model_dir = pwd / "model"
35
+ onnx_dir = pwd / "onnx"
36
+ cache_dir = pwd / "cache"
37
+ model_dir.mkdir(exist_ok=True)
38
+ onnx_dir.mkdir(exist_ok=True)
39
+ cache_dir.mkdir(exist_ok=True)
40
+
41
+ # ===== STEP 1: DOWNLOAD (HF HUB + XET backend automatically used) =====
42
+ print(">> Downloading model via huggingface_hub (Xet enabled if installed)...")
43
+
44
+ # local_path = snapshot_download(
45
+ # repo_id=args.name,
46
+ # local_dir=str(model_dir),
47
+ # token=token
48
+ # #local_dir_use_symlinks=False # important for ONNX tools
49
+ # )
50
+
51
+ #print(f"Model downloaded to: {local_path}")
52
+
53
+ # ===== STEP 2: CONVERT USING ONNX GENAI BUILDER =====
54
+ print(">> Converting to ONNX (GenAI format)...")
55
+
56
+ create_model(
57
+ model_name=args.name,
58
+ input_path=str(model_dir), # HF model directory
59
+ output_dir=str(onnx_dir), # ONNX output
60
+ precision="fp16", # fp32 | fp16 | int8 | int4 (if supported)
61
+ execution_provider="cpu", # cpu | cuda | dml
62
+ cache_dir=str(cache_dir), # optional cache
63
+ extra_options={}
64
+ )
65
+
66
+ print("\n✅ Done")
67
+ print(f"ONNX model at: {onnx_dir}")
68
+
69
+
70
+ if __name__ == "__main__":
71
+ main()
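Note that the snapshot_download call in STEP 1 is commented out in the committed script, so the model/ directory must already contain the Hugging Face checkpoint before create_model runs. Below is a minimal sketch of the same flow with the download step enabled; the repo id is a placeholder, not part of this commit:

    from pathlib import Path
    from huggingface_hub import snapshot_download
    from onnxruntime_genai.models.builder import create_model

    repo_id = "<org>/<model-repo>"  # placeholder repo id
    model_dir = Path.cwd() / "model"
    onnx_dir = Path.cwd() / "onnx"
    model_dir.mkdir(exist_ok=True)
    onnx_dir.mkdir(exist_ok=True)

    # Step 1: fetch the HF checkpoint (Xet-accelerated when hf-xet is installed)
    snapshot_download(repo_id=repo_id, local_dir=str(model_dir))

    # Step 2: convert to the ONNX Runtime GenAI format, as in the script above
    create_model(
        model_name=repo_id,
        input_path=str(model_dir),
        output_dir=str(onnx_dir),
        precision="fp16",
        execution_provider="cpu",
        cache_dir=str(Path.cwd() / "cache"),
        extra_options={},
    )

Because of the PEP 723 metadata block at the top of the file, the committed script can also be run with a tool that understands inline script dependencies (for example uv), which installs the pinned packages in an isolated environment before executing main().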
__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .modeling_omnivoice import OmniVoice, OmniVoiceConfig
config.json CHANGED
@@ -2,6 +2,10 @@
2
  "architectures": [
3
  "OmniVoice"
4
  ],
5
+ "auto_map": {
6
+ "AutoConfig": "modeling_omnivoice.OmniVoiceConfig",
7
+ "AutoModel": "modeling_omnivoice.OmniVoice"
8
+ },
9
  "audio_codebook_weights": [
10
  8,
11
  8,
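The added auto_map block is what lets the Hub's Auto classes resolve this repository's custom code in modeling_omnivoice.py. A hedged sketch of the loading path this enables; the repo id is a placeholder, and trust_remote_code=True is required because the classes live in the repository rather than in transformers itself:

    from transformers import AutoConfig, AutoModel

    repo_id = "<org>/<this-repo>"  # placeholder repo id
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)  # resolves to OmniVoiceConfig
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)    # resolves to OmniVoice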
modeling_omnivoice.py ADDED
@@ -0,0 +1,1598 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Core OmniVoice model implementation.
19
+
20
+ Defines the ``OmniVoice`` model class, generation config, and inference pipeline.
21
+ This is the main entry point for both inference and training:
22
+
23
+ - **Inference**: ``OmniVoice.from_pretrained()`` loads the model, then
24
+ ``model.generate()`` supports voice cloning, voice design, and auto voice.
25
+ - **Training**: ``model.forward()`` computes the training loss; the model is
26
+ built and used by ``omnivoice.training.builder`` and ``omnivoice.training.trainer``.
27
+
28
+ """
29
+
30
+ import difflib
31
+ import logging
32
+ import math
33
+ import os
34
+ import re
35
+ from dataclasses import dataclass, fields
36
+ from functools import partial
37
+ from typing import Any, List, Optional, Union
38
+
39
+ import numpy as np
40
+ import torch
41
+ import torch.nn as nn
42
+ import torch.nn.functional as F
43
+ import torchaudio
44
+
45
+ try:
46
+ from torch.nn.attention.flex_attention import create_block_mask
47
+
48
+ _flex_attention_available = True
49
+ except ImportError:
50
+ _flex_attention_available = False
51
+ from transformers import (
52
+ AutoFeatureExtractor,
53
+ AutoModel,
54
+ AutoTokenizer,
55
+ HiggsAudioV2TokenizerModel,
56
+ PretrainedConfig,
57
+ PreTrainedModel,
58
+ )
59
+ from transformers.modeling_outputs import ModelOutput
60
+ from transformers.models.auto import CONFIG_MAPPING, AutoConfig
61
+
62
+ from omnivoice.utils.audio import (
63
+ cross_fade_chunks,
64
+ fade_and_pad_audio,
65
+ load_audio,
66
+ remove_silence,
67
+ trim_long_audio,
68
+ )
69
+ from omnivoice.utils.duration import RuleDurationEstimator
70
+ from omnivoice.utils.lang_map import LANG_IDS, LANG_NAMES
71
+ from omnivoice.utils.text import add_punctuation, chunk_text_punctuation
72
+ from omnivoice.utils.voice_design import (
73
+ _INSTRUCT_ALL_VALID,
74
+ _INSTRUCT_EN_TO_ZH,
75
+ _INSTRUCT_MUTUALLY_EXCLUSIVE,
76
+ _INSTRUCT_VALID_EN,
77
+ _INSTRUCT_VALID_ZH,
78
+ _INSTRUCT_ZH_TO_EN,
79
+ _ZH_RE,
80
+ )
81
+
82
+ logger = logging.getLogger(__name__)
83
+
84
+
85
+ # ---------------------------------------------------------------------------
86
+ # Dataclasses
87
+ # ---------------------------------------------------------------------------
88
+
89
+
90
+ @dataclass
91
+ class VoiceClonePrompt:
92
+ ref_audio_tokens: torch.Tensor # (C, T)
93
+ ref_text: str
94
+ ref_rms: float
95
+
96
+
97
+ @dataclass
98
+ class OmniVoiceGenerationConfig:
99
+ num_step: int = 32
100
+ guidance_scale: float = 2.0
101
+ t_shift: float = 0.1
102
+ layer_penalty_factor: float = 5.0
103
+ position_temperature: float = 5.0
104
+ class_temperature: float = 0.0
105
+ denoise: bool = True
106
+ preprocess_prompt: bool = True
107
+ postprocess_output: bool = True
108
+ audio_chunk_duration: float = 15.0
109
+ audio_chunk_threshold: float = 30.0
110
+
111
+ @classmethod
112
+ def from_dict(cls, kwargs_dict):
113
+ valid_keys = {f.name for f in fields(cls)}
114
+ filtered = {k: v for k, v in kwargs_dict.items() if k in valid_keys}
115
+ return cls(**filtered)
116
+
117
+
118
+ @dataclass
119
+ class GenerationTask:
120
+ batch_size: int
121
+ texts: List[str]
122
+ target_lens: List[int]
123
+ langs: List[Optional[str]]
124
+ instructs: List[Optional[str]]
125
+ ref_texts: List[Optional[str]]
126
+ ref_audio_tokens: List[Optional[torch.Tensor]]
127
+ ref_rms: List[Optional[float]]
128
+ speed: Optional[List[float]] = None
129
+
130
+ def get_indices(self, config: OmniVoiceGenerationConfig, frame_rate: int):
131
+ threshold = int(config.audio_chunk_threshold * frame_rate)
132
+ short_idx = [i for i, l in enumerate(self.target_lens) if l <= threshold]
133
+ long_idx = [i for i, l in enumerate(self.target_lens) if l > threshold]
134
+ return short_idx, long_idx
135
+
136
+ def slice_task(self, indices: List[int]):
137
+ if not indices:
138
+ return None
139
+ return GenerationTask(
140
+ batch_size=len(indices),
141
+ texts=[self.texts[i] for i in indices],
142
+ target_lens=[self.target_lens[i] for i in indices],
143
+ langs=[self.langs[i] for i in indices],
144
+ instructs=[self.instructs[i] for i in indices],
145
+ ref_texts=[self.ref_texts[i] for i in indices],
146
+ ref_audio_tokens=[self.ref_audio_tokens[i] for i in indices],
147
+ ref_rms=[self.ref_rms[i] for i in indices],
148
+ speed=[self.speed[i] for i in indices] if self.speed else None,
149
+ )
150
+
151
+
152
+ @dataclass
153
+ class OmniVoiceModelOutput(ModelOutput):
154
+ loss: Optional[torch.Tensor] = None
155
+ logits: Optional[torch.Tensor] = None
156
+
157
+
158
+ # ---------------------------------------------------------------------------
159
+ # Config & Model
160
+ # ---------------------------------------------------------------------------
161
+
162
+
163
+ class OmniVoiceConfig(PretrainedConfig):
164
+ model_type = "omnivoice"
165
+ sub_configs = {"llm_config": AutoConfig}
166
+
167
+ def __init__(
168
+ self,
169
+ audio_vocab_size: int = 1025,
170
+ audio_mask_id: int = 1024,
171
+ num_audio_codebook: int = 8,
172
+ audio_codebook_weights: Optional[list[float]] = None,
173
+ llm_config: Optional[Union[dict, PretrainedConfig]] = None,
174
+ **kwargs,
175
+ ):
176
+
177
+ if isinstance(llm_config, dict):
178
+ llm_config = CONFIG_MAPPING[llm_config["model_type"]](**llm_config)
179
+
180
+ self.llm_config = llm_config
181
+
182
+ super().__init__(**kwargs)
183
+ self.audio_vocab_size = audio_vocab_size
184
+ self.audio_mask_id = audio_mask_id
185
+ self.num_audio_codebook = num_audio_codebook
186
+ if audio_codebook_weights is None:
187
+ audio_codebook_weights = [8, 8, 6, 6, 4, 4, 2, 2]
188
+ self.audio_codebook_weights = audio_codebook_weights
189
+
190
+
191
+ def _resolve_model_path(name_or_path: str) -> str:
192
+ if os.path.isdir(name_or_path):
193
+ return name_or_path
194
+ from huggingface_hub import snapshot_download
195
+
196
+ return snapshot_download(name_or_path)
197
+
198
+
199
+ class OmniVoice(PreTrainedModel):
200
+ _supports_flex_attn = True
201
+ _supports_flash_attn_2 = True
202
+ _supports_sdpa = True
203
+ config_class = OmniVoiceConfig
204
+
205
+ def __init__(self, config: OmniVoiceConfig, llm: Optional[PreTrainedModel] = None):
206
+ super().__init__(config)
207
+
208
+ if llm is not None:
209
+ # If an LLM instance is provided, use it directly
210
+ # (skipping config-based init).
211
+ self.llm = llm
212
+ else:
213
+ # Otherwise, initialize the LLM from the config.
214
+ self.llm = AutoModel.from_config(self.config.llm_config)
215
+
216
+ self.audio_embeddings = nn.Embedding(
217
+ config.num_audio_codebook * config.audio_vocab_size,
218
+ self.config.llm_config.hidden_size,
219
+ )
220
+ self.register_buffer(
221
+ "codebook_layer_offsets",
222
+ torch.arange(config.num_audio_codebook) * config.audio_vocab_size,
223
+ )
224
+
225
+ self.audio_heads = nn.Linear(
226
+ self.config.llm_config.hidden_size,
227
+ config.num_audio_codebook * config.audio_vocab_size,
228
+ bias=False,
229
+ )
230
+
231
+ self.normalized_audio_codebook_weights = [
232
+ w / sum(config.audio_codebook_weights)
233
+ for w in config.audio_codebook_weights
234
+ ]
235
+
236
+ self.post_init()
237
+
238
+ # Inference-only attributes (set by from_pretrained when not in train mode)
239
+ self.text_tokenizer = None
240
+ self.audio_tokenizer = None
241
+ self.duration_estimator = None
242
+ self.sampling_rate = None
243
+ self._asr_pipe = None
244
+
245
+ @classmethod
246
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
247
+ train_mode = kwargs.pop("train", False)
248
+ load_asr = kwargs.pop("load_asr", False)
249
+ asr_model_name = kwargs.pop("asr_model_name", "openai/whisper-large-v3-turbo")
250
+
251
+ # Suppress noisy INFO logs from transformers/huggingface_hub during loading
252
+ _prev_disable = logging.root.manager.disable
253
+ logging.disable(logging.INFO)
254
+
255
+ try:
256
+ # Resolve to local path first; download only if not already cached
257
+ resolved_path = _resolve_model_path(pretrained_model_name_or_path)
258
+
259
+ model = super().from_pretrained(resolved_path, *args, **kwargs)
260
+
261
+ if not train_mode:
262
+ model.text_tokenizer = AutoTokenizer.from_pretrained(resolved_path)
263
+
264
+ audio_tokenizer_path = os.path.join(resolved_path, "audio_tokenizer")
265
+
266
+ if not os.path.isdir(audio_tokenizer_path):
267
+ audio_tokenizer_path = _resolve_model_path(
268
+ "eustlb/higgs-audio-v2-tokenizer"
269
+ )
270
+
271
+ # higgs-audio-v2-tokenizer does not support MPS
272
+ # (output channels > 65536)
273
+ tokenizer_device = (
274
+ "cpu" if str(model.device).startswith("mps") else model.device
275
+ )
276
+ model.audio_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(
277
+ audio_tokenizer_path, device_map=tokenizer_device
278
+ )
279
+ model.feature_extractor = AutoFeatureExtractor.from_pretrained(
280
+ audio_tokenizer_path
281
+ )
282
+
283
+ model.sampling_rate = model.feature_extractor.sampling_rate
284
+
285
+ model.duration_estimator = RuleDurationEstimator()
286
+
287
+ if load_asr:
288
+ model.load_asr_model(model_name=asr_model_name)
289
+ finally:
290
+ logging.disable(_prev_disable)
291
+
292
+ return model
293
+
294
+ # -------------------------------------------------------------------
295
+ # ASR support (optional, for auto-transcription)
296
+ # -------------------------------------------------------------------
297
+
298
+ def load_asr_model(self, model_name: str = "openai/whisper-large-v3-turbo"):
299
+ """Load a Whisper ASR model for reference audio transcription.
300
+
301
+ Args:
302
+ model_name: HuggingFace model name or local path for the Whisper model.
303
+ """
304
+ from transformers import pipeline as hf_pipeline
305
+
306
+ logger.info("Loading ASR model %s ...", model_name)
307
+ asr_dtype = (
308
+ torch.float16 if str(self.device).startswith("cuda") else torch.float32
309
+ )
310
+
311
+ model_name = _resolve_model_path(model_name)
312
+
313
+ self._asr_pipe = hf_pipeline(
314
+ "automatic-speech-recognition",
315
+ model=model_name,
316
+ dtype=asr_dtype,
317
+ device_map=self.device,
318
+ )
319
+ logger.info("ASR model loaded on %s.", self.device)
320
+
321
+ @torch.inference_mode()
322
+ def transcribe(
323
+ self,
324
+ audio: Union[str, tuple],
325
+ ) -> str:
326
+ """Transcribe audio using the loaded Whisper ASR model.
327
+
328
+ Args:
329
+ audio: File path or ``(waveform, sample_rate)`` tuple.
330
+ Waveform can be a numpy array or torch.Tensor of shape
331
+ ``(1, T)`` or ``(T,)``.
332
+
333
+ Returns:
334
+ Transcribed text.
335
+ """
336
+ if self._asr_pipe is None:
337
+ raise RuntimeError(
338
+ "ASR model is not loaded. Call model.load_asr_model() first."
339
+ )
340
+
341
+ if isinstance(audio, str):
342
+ return self._asr_pipe(audio)["text"].strip()
343
+ else:
344
+ waveform, sr = audio
345
+ if isinstance(waveform, torch.Tensor):
346
+ waveform = waveform.cpu().numpy()
347
+ waveform = np.squeeze(waveform) # (1, T) or (T,) → (T,)
348
+ audio_input = {
349
+ "array": waveform,
350
+ "sampling_rate": sr,
351
+ }
352
+ return self._asr_pipe(audio_input)["text"].strip()
353
+
354
+ def get_input_embeddings(self):
355
+ return self.llm.get_input_embeddings()
356
+
357
+ def set_input_embeddings(self, value):
358
+ self.llm.set_input_embeddings(value)
359
+
360
+ def _prepare_embed_inputs(
361
+ self, input_ids: torch.Tensor, audio_mask: torch.Tensor
362
+ ) -> torch.Tensor:
363
+ """
364
+ Prepares embeddings from input_ids of shape (batch_size, layers, seq_length).
365
+ Embedding shape is (batch_size, seq_length, hidden_size).
366
+ """
367
+ text_embeds = self.get_input_embeddings()(input_ids[:, 0, :])
368
+
369
+ # Apply shift to audio IDs based on codebook layer
370
+ # audio_ids: [Batch, 8, Seq]
371
+ # codebook_layer_offsets: [1, 8, 1]
372
+ # Result: layer 0 keeps its ID, layer 1 gets ID + 1025, layer 2 gets ID + 2050, ...
373
+ shifted_ids = (
374
+ input_ids * audio_mask.unsqueeze(1)
375
+ ) + self.codebook_layer_offsets.view(1, -1, 1)
376
+
377
+ # input: [Batch, 8, Seq] -> output: [Batch, Seq, Hidden]
378
+ audio_embeds = self.audio_embeddings(shifted_ids).sum(dim=1)
379
+
380
+ return torch.where(audio_mask.unsqueeze(-1), audio_embeds, text_embeds)
381
+
382
+ def forward(
383
+ self,
384
+ input_ids: torch.LongTensor,
385
+ audio_mask: torch.Tensor,
386
+ labels: Optional[torch.LongTensor] = None,
387
+ attention_mask: Optional[torch.Tensor] = None,
388
+ document_ids: Optional[torch.Tensor] = None,
389
+ position_ids: Optional[torch.LongTensor] = None,
390
+ ):
391
+
392
+ inputs_embeds = self._prepare_embed_inputs(input_ids, audio_mask)
393
+
394
+ if attention_mask is None and document_ids is not None:
395
+ if not _flex_attention_available:
396
+ raise RuntimeError(
397
+ "flex_attention is not available in the current environment. "
398
+ "If you do not need flex_attention, set "
399
+ '"attn_implementation": "sdpa" in your training config.'
400
+ )
401
+ attention_mask = create_block_mask(
402
+ _get_packed_mask(
403
+ document_ids[0].to(inputs_embeds.device),
404
+ ),
405
+ B=None,
406
+ H=None,
407
+ Q_LEN=input_ids.size(-1),
408
+ KV_LEN=input_ids.size(-1),
409
+ _compile=True,
410
+ device=inputs_embeds.device,
411
+ )
412
+
413
+ llm_outputs = self.llm(
414
+ inputs_embeds=inputs_embeds,
415
+ attention_mask=attention_mask,
416
+ return_dict=True,
417
+ position_ids=position_ids,
418
+ )
419
+ hidden_states = llm_outputs[0]
420
+
421
+ loss = None
422
+
423
+ # Shape: [B, S, C * Vocab]
424
+ batch_size, seq_len, _ = hidden_states.shape
425
+ logits_flat = self.audio_heads(hidden_states)
426
+ # Shape: [B, S, C, Vocab] -> [B, C, S, Vocab]
427
+ audio_logits = logits_flat.view(
428
+ batch_size,
429
+ seq_len,
430
+ self.config.num_audio_codebook,
431
+ self.config.audio_vocab_size,
432
+ ).permute(0, 2, 1, 3)
433
+
434
+ if labels is not None:
435
+
436
+ # audio_logits.permute(0, 3, 1, 2):
437
+ # [Batch, Layer, Seq, Vocab] -> [Batch, Vocab, Layer, Seq]
438
+ # per_token_loss shape: [Batch, Layer, Seq]; positions labelled -100 are ignored
439
+ per_token_loss = torch.nn.functional.cross_entropy(
440
+ audio_logits.permute(0, 3, 1, 2),
441
+ labels,
442
+ reduction="none",
443
+ ignore_index=-100,
444
+ )
445
+ # valid_mask shape: [Batch, Layer, Seq]
446
+ valid_mask = (labels != -100).float()
447
+
448
+ # layer_means shape: [num_layers]
449
+ layer_means = (per_token_loss * valid_mask).sum(
450
+ dim=(0, 2)
451
+ ) / valid_mask.sum(dim=(0, 2)).clamp(min=1.0)
452
+
453
+ weights = torch.tensor(
454
+ self.normalized_audio_codebook_weights, device=audio_logits.device
455
+ )
456
+ loss = (layer_means * weights).sum()
457
+
458
+ return OmniVoiceModelOutput(
459
+ loss=loss,
460
+ logits=audio_logits,
461
+ )
462
+
463
+ def supported_language_ids(self) -> set[str]:
464
+ """Return the set of supported language IDs."""
465
+ return LANG_IDS
466
+
467
+ def supported_language_names(self) -> set[str]:
468
+ """Return the set of supported language names."""
469
+ return LANG_NAMES
470
+
471
+ # -------------------------------------------------------------------
472
+ # Inference API
473
+ # -------------------------------------------------------------------
474
+
475
+ @torch.inference_mode()
476
+ def generate(
477
+ self,
478
+ text: Union[str, list[str]],
479
+ language: Union[str, list[str], None] = None,
480
+ ref_text: Union[str, list[str], None] = None,
481
+ ref_audio: Union[
482
+ str,
483
+ list[str],
484
+ tuple[torch.Tensor, int],
485
+ list[tuple[torch.Tensor, int]],
486
+ None,
487
+ ] = None,
488
+ voice_clone_prompt: Union[
489
+ VoiceClonePrompt, list[VoiceClonePrompt], None
490
+ ] = None,
491
+ instruct: Union[str, list[str], None] = None,
492
+ duration: Union[float, list[Optional[float]], None] = None,
493
+ speed: Union[float, list[Optional[float]], None] = None,
494
+ generation_config: Optional[OmniVoiceGenerationConfig] = None,
495
+ **kwargs,
496
+ ) -> list[np.ndarray]:
497
+ """Generate speech audio given text in various modes.
498
+
499
+ Supports three modes:
500
+
501
+ 1. **Voice clone** — clone the voice style from the reference audio.
502
+ Should provide ``voice_clone_prompt`` (from
503
+ :meth:`create_voice_clone_prompt`) or ``ref_text`` + ``ref_audio``.
504
+ 2. **Voice design** — provide ``instruct`` text describing
505
+ the desired voice style; no reference audio needed.
506
+ 3. **Auto** — provide neither; the model picks a voice itself.
507
+
508
+ Args:
509
+ text: Target text (single string or list for batch).
510
+ language: Language name (e.g. ``"English"``) or code
511
+ (e.g. ``"en"``). ``None`` for language-agnostic mode.
512
+ Performance is slightly better if you specify the language.
513
+ ref_text: Optional reference text for voice cloning mode.
514
+ ref_audio: Optional reference audio for voice cloning mode.
515
+ Can be a file path or a (waveform, sample_rate) tuple.
516
+ voice_clone_prompt: Reusable prompt from :meth:`create_voice_clone_prompt`.
517
+ If provided, it overrides ``ref_text`` and ``ref_audio``.
518
+ instruct: Style instruction for voice design mode.
519
+ duration: Fixed output duration in seconds. If a single float,
520
+ applies to all items; if a list, one value per item.
521
+ ``None`` (default) lets the model estimate duration from text.
522
+ Overrides ``speed`` when both are provided.
523
+ speed: Speaking speed factor. ``> 1.0`` for faster, ``< 1.0`` for
524
+ slower. If a list, one value per item. ``None`` (default) uses
525
+ the model's default estimation.
526
+ generation_config: Explicit config object. If provided, takes
527
+ precedence over ``**kwargs``.
528
+ **kwargs: Generation config or its fields:
529
+ denoise: Whether to prepend the ``<|denoise|>`` token.
530
+ num_step: Number of iterative decoding steps.
531
+ guidance_scale: Classifier-free guidance scale.
532
+ t_shift: Time-step shift (smaller → emphasise low-SNR).
533
+ postprocess_output: Post-process output (remove silence, fade-in/out, pad edges).
534
+ layer_penalty_factor: Penalty encouraging earlier codebook
535
+ layers to unmask first.
536
+ position_temperature: Temperature for position selection.
537
+ class_temperature: Temperature for token sampling (0 = greedy).
538
+ audio_chunk_duration: If > 0, split long text into chunks of
539
+ this duration (seconds) and generate chunk by chunk.
540
+ audio_chunk_threshold: Only apply chunking if estimated audio
541
+ duration exceeds this threshold (seconds).
542
+ Returns:
543
+ ``audios``: a list of 1-D ``np.ndarray`` with shape ``(T,)`` and
544
+ sampling rate consistent with the model's audio tokenizer
545
+ (usually 24 000 Hz). Can be saved directly with
546
+ ``soundfile.write("out.wav", audios[0], model.sampling_rate)``.
547
+ """
548
+
549
+ if self.audio_tokenizer is None or self.text_tokenizer is None:
550
+ raise RuntimeError(
551
+ "Model is not loaded with audio/text tokenizers. Make sure you "
552
+ "loaded the model with OmniVoice.from_pretrained()."
553
+ )
554
+ gen_config = (
555
+ generation_config
556
+ if generation_config is not None
557
+ else OmniVoiceGenerationConfig.from_dict(kwargs)
558
+ )
559
+
560
+ self.eval()
561
+
562
+ full_task = self._preprocess_all(
563
+ text=text,
564
+ language=language,
565
+ ref_text=ref_text,
566
+ ref_audio=ref_audio,
567
+ voice_clone_prompt=voice_clone_prompt,
568
+ instruct=instruct,
569
+ preprocess_prompt=gen_config.preprocess_prompt,
570
+ speed=speed,
571
+ duration=duration,
572
+ )
573
+
574
+ short_idx, long_idx = full_task.get_indices(
575
+ gen_config, self.audio_tokenizer.config.frame_rate
576
+ )
577
+
578
+ results = [None] * full_task.batch_size
579
+
580
+ if short_idx:
581
+ short_task = full_task.slice_task(short_idx)
582
+ short_results = self._generate_iterative(short_task, gen_config)
583
+ for idx, res in zip(short_idx, short_results):
584
+ results[idx] = res
585
+
586
+ if long_idx:
587
+ long_task = full_task.slice_task(long_idx)
588
+ long_results = self._generate_chunked(long_task, gen_config)
589
+ for idx, res in zip(long_idx, long_results):
590
+ results[idx] = res
591
+
592
+ generated_audios = []
593
+ for i in range(full_task.batch_size):
594
+ assert results[i] is not None, f"Result {i} was not generated"
595
+ generated_audios.append(
596
+ self._decode_and_post_process(
597
+ results[i], full_task.ref_rms[i], gen_config # type: ignore[arg-type]
598
+ )
599
+ )
600
+
601
+ return generated_audios
602
+
603
+ def create_voice_clone_prompt(
604
+ self,
605
+ ref_audio: Union[str, tuple[torch.Tensor, int]],
606
+ ref_text: Optional[str] = None,
607
+ preprocess_prompt: bool = True,
608
+ ) -> VoiceClonePrompt:
609
+ """Create a reusable voice clone prompt from reference audio.
610
+
611
+ Args:
612
+ ref_audio: File path (str) or ``(waveform, sample_rate)`` tuple.
613
+ waveform should be a 1-D or 2-D torch.Tensor (channels x samples).
614
+ ref_text: Transcript of the reference audio. If ``None``, the
615
+ ASR model will be used to auto-transcribe (must call
616
+ :meth:`load_asr_model` first).
617
+ preprocess_prompt: If ``True`` (default), apply silence removal and
618
+ trimming to the reference audio, and append punctuation to the
619
+ end of the reference text if it is missing.
620
+
621
+ Returns:
622
+ A :class:`VoiceClonePrompt` that can be passed to :meth:`generate`.
623
+ """
624
+ if self.audio_tokenizer is None:
625
+ raise RuntimeError(
626
+ "Audio tokenizer is not loaded. Make sure you loaded the model "
627
+ "with OmniVoice.from_pretrained()."
628
+ )
629
+
630
+ if isinstance(ref_audio, str):
631
+ ref_wav = load_audio(ref_audio, self.sampling_rate)
632
+ else:
633
+ waveform, sr = ref_audio
634
+ if isinstance(waveform, torch.Tensor):
635
+ waveform = waveform.cpu().numpy()
636
+ if waveform.ndim == 1:
637
+ waveform = waveform[np.newaxis, :]
638
+ if waveform.shape[0] > 1:
639
+ waveform = np.mean(waveform, axis=0, keepdims=True)
640
+ if sr != self.sampling_rate:
641
+ waveform = torchaudio.functional.resample(
642
+ torch.from_numpy(waveform),
643
+ orig_freq=sr,
644
+ new_freq=self.sampling_rate,
645
+ ).numpy()
646
+ ref_wav = waveform
647
+
648
+ ref_rms = float(np.sqrt(np.mean(ref_wav**2)))
649
+ if 0 < ref_rms < 0.1:
650
+ ref_wav = ref_wav * 0.1 / ref_rms
651
+
652
+ if preprocess_prompt:
653
+ # Trim long reference audio (>20s) by splitting at the largest silence gap.
654
+ # Skip trimming when ref_text is user-provided, otherwise the
655
+ # trimmed audio will no longer match the full transcript.
656
+ if ref_text is None:
657
+ ref_wav = trim_long_audio(
658
+ ref_wav, self.sampling_rate, trim_threshold=20.0
659
+ )
660
+ ref_wav = remove_silence(
661
+ ref_wav,
662
+ self.sampling_rate,
663
+ mid_sil=200,
664
+ lead_sil=100,
665
+ trail_sil=200,
666
+ )
667
+ if ref_wav.shape[-1] == 0:
668
+ raise ValueError(
669
+ "Reference audio is empty after silence removal. "
670
+ "Try setting preprocess_prompt=False."
671
+ )
672
+
673
+ ref_duration = ref_wav.shape[-1] / self.sampling_rate
674
+ if ref_duration > 20.0:
675
+ logger.warning(
676
+ "Reference audio is %.1fs long (>20s). This may cause slower "
677
+ "generation, higher memory usage, and degraded voice cloning "
678
+ "quality. We recommend trimming it to 3-10s.",
679
+ ref_duration,
680
+ )
681
+
682
+ # Auto-transcribe if ref_text not provided
683
+ if ref_text is None:
684
+ if self._asr_pipe is None:
685
+ logger.info("ASR model not loaded yet, loading on-the-fly ...")
686
+ self.load_asr_model()
687
+ ref_text = self.transcribe((ref_wav, self.sampling_rate))
688
+ logger.debug("Auto-transcribed ref_text: %s", ref_text)
689
+
690
+ chunk_size = self.audio_tokenizer.config.hop_length
691
+ clip_size = int(ref_wav.shape[-1] % chunk_size)
692
+ ref_wav = ref_wav[:, :-clip_size] if clip_size > 0 else ref_wav
693
+ # numpy → torch at tokenizer boundary
694
+ ref_wav_tensor = torch.from_numpy(ref_wav).to(self.audio_tokenizer.device)
695
+ ref_audio_tokens = self.audio_tokenizer.encode(
696
+ ref_wav_tensor.unsqueeze(0),
697
+ ).audio_codes.squeeze(
698
+ 0
699
+ ) # (C, T)
700
+
701
+ if preprocess_prompt:
702
+ ref_text = add_punctuation(ref_text)
703
+
704
+ return VoiceClonePrompt(
705
+ ref_audio_tokens=ref_audio_tokens,
706
+ ref_text=ref_text,
707
+ ref_rms=ref_rms,
708
+ )
709
+
710
+ def _decode_and_post_process(
711
+ self,
712
+ tokens: Union[torch.Tensor, List[torch.Tensor]],
713
+ rms: Union[float, None],
714
+ gen_config: OmniVoiceGenerationConfig,
715
+ ) -> np.ndarray:
716
+ """
717
+ Args:
718
+ tokens: Audio tokens — either a single tensor of shape
719
+ (num_codebooks, seq_len) or a list of chunk tensors.
720
+ rms: RMS of the reference audio for volume adjustment.
721
+ gen_config: Generation config for post-processing options.
722
+ Returns:
723
+ Decoded and post-processed audio array of shape (T,).
724
+ """
725
+ tokenizer_device = self.audio_tokenizer.device
726
+ if isinstance(tokens, list):
727
+ chunk_audios = [
728
+ self.audio_tokenizer.decode(t.to(tokenizer_device).unsqueeze(0))
729
+ .audio_values[0]
730
+ .cpu()
731
+ .numpy()
732
+ for t in tokens
733
+ ]
734
+ audio_waveform = cross_fade_chunks(chunk_audios, self.sampling_rate)
735
+ else:
736
+ audio_waveform = (
737
+ self.audio_tokenizer.decode(tokens.to(tokenizer_device).unsqueeze(0))
738
+ .audio_values[0]
739
+ .cpu()
740
+ .numpy()
741
+ )
742
+
743
+ audio_waveform = self._post_process_audio(
744
+ audio_waveform,
745
+ postprocess_output=gen_config.postprocess_output,
746
+ ref_rms=rms,
747
+ )
748
+ return audio_waveform.squeeze(0)
749
+
750
+ def _post_process_audio(
751
+ self,
752
+ generated_audio: np.ndarray,
753
+ postprocess_output: bool,
754
+ ref_rms: Union[float, None],
755
+ ) -> np.ndarray:
756
+ """Optionally remove long silences, adjust volume, and add edge padding.
757
+
758
+ Args:
759
+ generated_audio: Numpy array of shape (1, T).
760
+ postprocess_output: If True, remove long silences and apply fade/pad.
761
+ ref_rms: RMS of the reference audio for volume normalisation.
762
+ Returns:
763
+ Processed numpy array of shape (1, T).
764
+ """
765
+ if postprocess_output:
766
+ generated_audio = remove_silence(
767
+ generated_audio,
768
+ self.sampling_rate,
769
+ mid_sil=500,
770
+ lead_sil=100,
771
+ trail_sil=100,
772
+ )
773
+
774
+ if ref_rms is not None and ref_rms < 0.1:
775
+ generated_audio = generated_audio * ref_rms / 0.1
776
+ elif ref_rms is None:
777
+ peak = np.abs(generated_audio).max()
778
+ if peak > 1e-6:
779
+ generated_audio = generated_audio / peak * 0.5
780
+
781
+ generated_audio = fade_and_pad_audio(
782
+ generated_audio,
783
+ sample_rate=self.sampling_rate,
784
+ )
785
+ return generated_audio
786
+
787
+ def _generate_chunked(
788
+ self, task: GenerationTask, gen_config: OmniVoiceGenerationConfig
789
+ ) -> List[List[torch.Tensor]]:
790
+ """Generate long audio by splitting text into chunks and batching.
791
+
792
+ Each item in the returned list corresponds to one input and contains
793
+ a list of audio token tensors — one per text chunk.
794
+
795
+ Args:
796
+ task: A :class:`GenerationTask` with one or more items whose
797
+ estimated audio exceeds ``audio_chunk_threshold``.
798
+ gen_config: Generation config (``audio_chunk_duration`` controls
799
+ chunk size).
800
+ Returns:
801
+ Per-item list of chunk token-tensor lists.
802
+ """
803
+ # Chunk each item's text
804
+ all_chunks = []
805
+ for i in range(task.batch_size):
806
+ avg_tokens_per_char = task.target_lens[i] / len(task.texts[i])
807
+ text_chunk_len = int(
808
+ gen_config.audio_chunk_duration
809
+ * self.audio_tokenizer.config.frame_rate
810
+ / avg_tokens_per_char
811
+ )
812
+ chunks = chunk_text_punctuation(
813
+ text=task.texts[i],
814
+ chunk_len=text_chunk_len,
815
+ min_chunk_len=3,
816
+ )
817
+ logger.debug(f"Item {i} chunked into {len(chunks)} pieces: {chunks}")
818
+ all_chunks.append(chunks)
819
+
820
+ has_ref = [t is not None for t in task.ref_audio_tokens]
821
+ assert all(has_ref) or not any(has_ref), (
822
+ "Chunked inference requires all items to either have or not have "
823
+ "ref_audio. Mixed ref/non-ref is not supported."
824
+ )
825
+
826
+ max_num_chunks = max(len(c) for c in all_chunks)
827
+
828
+ # chunk_results[item_idx] = list of generated token tensors per chunk
829
+ chunk_results = [[] for _ in range(task.batch_size)]
830
+
831
+ def _run_batch(indices, texts, ref_audios, ref_texts):
832
+ speed_list = task.speed
833
+ target_lens = [
834
+ self._estimate_target_tokens(
835
+ texts[j],
836
+ ref_texts[j],
837
+ ref_audios[j].size(-1) if ref_audios[j] is not None else None,
838
+ speed=speed_list[i] if speed_list else 1.0,
839
+ )
840
+ for j, i in enumerate(indices)
841
+ ]
842
+ sub_task = GenerationTask(
843
+ batch_size=len(indices),
844
+ texts=texts,
845
+ target_lens=target_lens,
846
+ langs=[task.langs[i] for i in indices],
847
+ instructs=[task.instructs[i] for i in indices],
848
+ ref_texts=ref_texts,
849
+ ref_audio_tokens=ref_audios,
850
+ ref_rms=[task.ref_rms[i] for i in indices],
851
+ speed=[task.speed[i] for i in indices] if task.speed else None,
852
+ )
853
+ gen_tokens = self._generate_iterative(sub_task, gen_config)
854
+ for j, idx in enumerate(indices):
855
+ chunk_results[idx].append(gen_tokens[j])
856
+
857
+ if all(has_ref):
858
+ # All items have reference audio.
859
+ # We still sequentially generate chunks within each item, but we
860
+ # batch across items for the same chunk index. This keeps VRAM
861
+ # usage manageable while still benefiting from batching.
862
+ for ci in range(max_num_chunks):
863
+ indices = [i for i in range(task.batch_size) if ci < len(all_chunks[i])]
864
+ if not indices:
865
+ continue
866
+ _run_batch(
867
+ indices,
868
+ texts=[all_chunks[i][ci] for i in indices],
869
+ ref_audios=[task.ref_audio_tokens[i] for i in indices],
870
+ ref_texts=[task.ref_texts[i] for i in indices],
871
+ )
872
+ else:
873
+ # No reference audio — generate chunk 0 for all items first,
874
+ # then use chunk 0 output as reference for all subsequent chunks.
875
+ indices_0 = [i for i in range(task.batch_size) if len(all_chunks[i]) > 0]
876
+ _run_batch(
877
+ indices_0,
878
+ texts=[all_chunks[i][0] for i in indices_0],
879
+ ref_audios=[None] * len(indices_0),
880
+ ref_texts=[None] * len(indices_0),
881
+ )
882
+ first_chunk_map = {idx: chunk_results[idx][0] for idx in indices_0}
883
+
884
+ # Batch all remaining chunks, using chunk 0 as fixed reference
885
+ for ci in range(1, max_num_chunks):
886
+ indices = [i for i in range(task.batch_size) if ci < len(all_chunks[i])]
887
+ if not indices:
888
+ continue
889
+ _run_batch(
890
+ indices,
891
+ texts=[all_chunks[i][ci] for i in indices],
892
+ ref_audios=[first_chunk_map[i] for i in indices],
893
+ ref_texts=[all_chunks[i][0] for i in indices],
894
+ )
895
+
896
+ return chunk_results
897
+
898
+ def _preprocess_all(
899
+ self,
900
+ text: Union[str, list[str]],
901
+ language: Union[str, list[str], None] = None,
902
+ ref_text: Union[str, list[str], None] = None,
903
+ ref_audio: Union[
904
+ str,
905
+ list[str],
906
+ tuple[torch.Tensor, int],
907
+ list[tuple[torch.Tensor, int]],
908
+ None,
909
+ ] = None,
910
+ voice_clone_prompt: Union[
911
+ VoiceClonePrompt, list[VoiceClonePrompt], None
912
+ ] = None,
913
+ instruct: Union[str, list[str], None] = None,
914
+ preprocess_prompt: bool = True,
915
+ speed: Union[float, list[Optional[float]], None] = None,
916
+ duration: Union[float, list[Optional[float]], None] = None,
917
+ ) -> GenerationTask:
918
+
919
+ if isinstance(text, str):
920
+ text_list = [text]
921
+ else:
922
+ assert isinstance(
923
+ text, list
924
+ ), "text should be a string or a list of strings"
925
+ text_list = text
926
+ batch_size = len(text_list)
927
+
928
+ language_list = self._ensure_list(language, batch_size)
929
+ language_list = [_resolve_language(lang) for lang in language_list]
930
+ instruct_list = self._ensure_list(instruct, batch_size)
931
+ for i, s in enumerate(instruct_list):
932
+ if s is None:
933
+ continue
934
+ use_zh = bool(text_list[i] and _ZH_RE.search(text_list[i]))
935
+ instruct_list[i] = _resolve_instruct(s, use_zh=use_zh)
936
+
937
+ if voice_clone_prompt is not None and (
938
+ ref_text is not None or ref_audio is not None
939
+ ):
940
+ logger.warning(
941
+ "Both voice_clone_prompt and ref_text/ref_audio are provided. "
942
+ "ref_text/ref_audio will be ignored."
943
+ )
944
+ if voice_clone_prompt is None and ref_audio is not None:
945
+ # If voice_clone_prompt is not provided, create it from
946
+ # ref_audio (ref_text will be auto-transcribed if not given).
947
+ ref_text_list = self._ensure_list(ref_text, batch_size, auto_repeat=False)
948
+ ref_audio_list = self._ensure_list(ref_audio, batch_size, auto_repeat=False)
949
+
950
+ voice_clone_prompt = []
951
+ for i in range(len(ref_text_list)):
952
+ voice_clone_prompt.append(
953
+ self.create_voice_clone_prompt(
954
+ ref_audio=ref_audio_list[i],
955
+ ref_text=ref_text_list[i],
956
+ preprocess_prompt=preprocess_prompt,
957
+ )
958
+ )
959
+
960
+ voice_clone_prompt_list = self._ensure_list(voice_clone_prompt, batch_size)
961
+ if voice_clone_prompt_list[0] is not None:
962
+ ref_text_list = [vc.ref_text for vc in voice_clone_prompt_list]
963
+ ref_audio_tokens_list = [
964
+ vc.ref_audio_tokens for vc in voice_clone_prompt_list
965
+ ]
966
+ ref_rms_list = [vc.ref_rms for vc in voice_clone_prompt_list]
967
+ else:
968
+ ref_text_list = [None] * batch_size
969
+ ref_audio_tokens_list = [None] * batch_size
970
+ ref_rms_list = [None] * batch_size
971
+
972
+ # Normalize speed/duration to per-item lists (may contain None).
973
+ if speed is not None:
974
+ if isinstance(speed, (int, float)):
975
+ user_speed = [float(speed)] * batch_size
976
+ else:
977
+ user_speed = list(speed)
978
+ else:
979
+ user_speed = None
980
+
981
+ if duration is not None:
982
+ if isinstance(duration, (int, float)):
983
+ durations = [float(duration)] * batch_size
984
+ else:
985
+ durations = list(duration)
986
+ else:
987
+ durations = None
988
+
989
+ num_target_tokens_list = []
990
+ for i in range(batch_size):
991
+ # duration[i] overrides speed for estimation: use speed=1.0
992
+ # to get the raw estimate, then override target_lens below.
993
+ has_dur = durations is not None and durations[i] is not None
994
+ item_speed = 1.0 if has_dur else (user_speed[i] if user_speed else 1.0)
995
+ est = self._estimate_target_tokens(
996
+ text_list[i],
997
+ ref_text_list[i],
998
+ ref_audio_tokens_list[i].size(-1)
999
+ if ref_audio_tokens_list[i] is not None
1000
+ else None,
1001
+ speed=item_speed,
1002
+ )
1003
+ num_target_tokens_list.append(est)
1004
+
1005
+ # Per-item duration overrides: set target_lens to exact frame count
1006
+ # and compute speed ratio so chunked generation scales proportionally.
1007
+ speed_list: Optional[List[float]] = None
1008
+ if durations is not None:
1009
+ frame_rate = self.audio_tokenizer.config.frame_rate
1010
+ speed_list = []
1011
+ for i in range(batch_size):
1012
+ if durations[i] is not None:
1013
+ target_tokens = max(1, int(durations[i] * frame_rate))
1014
+ est = num_target_tokens_list[i]
1015
+ speed_list.append(est / target_tokens if target_tokens > 0 else 1.0)
1016
+ num_target_tokens_list[i] = target_tokens
1017
+ else:
1018
+ s = user_speed[i] if user_speed else None
1019
+ speed_list.append(s if s is not None else 1.0)
1020
+ elif user_speed is not None:
1021
+ speed_list = [s if s is not None else 1.0 for s in user_speed]
1022
+
1023
+ return GenerationTask(
1024
+ batch_size=batch_size,
1025
+ texts=text_list,
1026
+ target_lens=num_target_tokens_list,
1027
+ langs=language_list,
1028
+ instructs=instruct_list,
1029
+ ref_texts=ref_text_list,
1030
+ ref_audio_tokens=ref_audio_tokens_list,
1031
+ ref_rms=ref_rms_list,
1032
+ speed=speed_list,
1033
+ )
1034
+
1035
+ def _estimate_target_tokens(self, text, ref_text, num_ref_audio_tokens, speed=1.0):
1036
+ """Estimate number of target audio tokens."""
1037
+ if num_ref_audio_tokens is None or ref_text is None or len(ref_text) == 0:
1038
+ # Fall back to a simple heuristic
1039
+ ref_text = "Nice to meet you."
1040
+ num_ref_audio_tokens = 25
1041
+
1042
+ est = self.duration_estimator.estimate_duration(
1043
+ text, ref_text, num_ref_audio_tokens
1044
+ )
1045
+ if speed > 0 and speed != 1.0:
1046
+ est = est / speed
1047
+ return max(1, int(est))
1048
+
1049
+ def _ensure_list(
1050
+ self, x: Union[Any, List[Any]], batch_size: int, auto_repeat: bool = True
1051
+ ) -> List[Any]:
1052
+ x_list = x if isinstance(x, list) else [x]
1053
+ if len(x_list) not in (
1054
+ 1,
1055
+ batch_size,
1056
+ ):
1057
+ raise ValueError(
1058
+ f"should contain either 1 item or one per text item, but got {len(x_list)}"
1059
+ )
1060
+ if auto_repeat and len(x_list) == 1 and batch_size is not None:
1061
+ x_list = x_list * batch_size
1062
+ return x_list
1063
+
1064
+ def _prepare_inference_inputs(
1065
+ self,
1066
+ text: str,
1067
+ num_target_tokens: int,
1068
+ ref_text: Optional[str] = None,
1069
+ ref_audio_tokens: Optional[torch.Tensor] = None,
1070
+ lang: Optional[str] = None,
1071
+ instruct: Optional[str] = None,
1072
+ denoise: bool = True,
1073
+ ):
1074
+ """Prepare input_ids and audio masks for inference.
1075
+ Args:
1076
+ text: Target text to generate.
1077
+ num_target_tokens: Number of audio tokens to generate.
1078
+ ref_text: Optional reference text for voice cloning.
1079
+ ref_audio_tokens: Optional reference audio tokens for voice cloning.
1080
+ with shape (C, T).
1081
+ lang: Optional language ID.
1082
+ instruct: Optional style instruction for voice design.
1083
+ denoise: Whether to include the <|denoise|> token.
1084
+ """
1085
+
1086
+ # Build style tokens: <|denoise|> + <|lang_start|>...<|lang_end|>
1087
+ # + <|instruct_start|>...<|instruct_end|>
1088
+ style_text = ""
1089
+ if denoise and ref_audio_tokens is not None:
1090
+ style_text += "<|denoise|>"
1091
+ lang_str = lang if lang else "None"
1092
+ instruct_str = instruct if instruct else "None"
1093
+ style_text += f"<|lang_start|>{lang_str}<|lang_end|>"
1094
+ style_text += f"<|instruct_start|>{instruct_str}<|instruct_end|>"
1095
+
1096
+ style_tokens = (
1097
+ self.text_tokenizer(style_text, return_tensors="pt")
1098
+ .input_ids.repeat(self.config.num_audio_codebook, 1)
1099
+ .unsqueeze(0)
1100
+ ).to(
1101
+ self.device
1102
+ ) # [1, C, N1]
1103
+
1104
+ # Build text tokens
1105
+ full_text = _combine_text(ref_text=ref_text, text=text)
1106
+ wrapped_text = f"<|text_start|>{full_text}<|text_end|>"
1107
+ text_tokens = (
1108
+ _tokenize_with_nonverbal_tags(wrapped_text, self.text_tokenizer)
1109
+ .repeat(self.config.num_audio_codebook, 1)
1110
+ .unsqueeze(0)
1111
+ ).to(
1112
+ self.device
1113
+ ) # [1, C, N2]
1114
+
1115
+ # Target: all MASK
1116
+ target_audio_tokens = torch.full(
1117
+ (1, self.config.num_audio_codebook, num_target_tokens),
1118
+ self.config.audio_mask_id,
1119
+ dtype=torch.long,
1120
+ device=self.device,
1121
+ )
1122
+
1123
+ # Conditional input
1124
+ parts = [style_tokens, text_tokens]
1125
+ if ref_audio_tokens is not None:
1126
+ parts.append(ref_audio_tokens.unsqueeze(0).to(self.device))
1127
+ parts.append(target_audio_tokens)
1128
+ cond_input_ids = torch.cat(parts, dim=2)
1129
+
1130
+ cond_total_length = cond_input_ids.shape[2]
1131
+ cond_audio_start_idx = cond_total_length - num_target_tokens
1132
+ if ref_audio_tokens is not None:
1133
+ cond_audio_start_idx -= ref_audio_tokens.size(-1)
1134
+
1135
+ cond_audio_mask = torch.zeros(
1136
+ 1, cond_total_length, dtype=torch.bool, device=self.device
1137
+ )
1138
+ cond_audio_mask[0, cond_audio_start_idx:] = True
1139
+
1140
+ return {
1141
+ "input_ids": cond_input_ids,
1142
+ "audio_mask": cond_audio_mask,
1143
+ }
1144
+
1145
+ def _generate_iterative(
1146
+ self, task: GenerationTask, gen_config: OmniVoiceGenerationConfig
1147
+ ) -> List[torch.Tensor]:
1148
+ """N-step iterative unmasking decoding.
1149
+
1150
+ Args:
1151
+ task: A :class:`GenerationTask` containing batch texts, target
1152
+ lengths, languages, instructions, and optional reference data.
1153
+ gen_config: A :class:`OmniVoiceGenerationConfig` controlling
1154
+ decoding steps, guidance, temperatures, etc.
1155
+ Returns:
1156
+ List of generated audio token tensors of shape (C, T) (one per
1157
+ input text).
1158
+ """
1159
+
1160
+ B = task.batch_size
1161
+
1162
+ for i in range(B):
1163
+ logger.debug(
1164
+ "Item %d — text: %s | ref_text: %s | instruct: %s | lang: %s | target_tokens: %d",
1165
+ i,
1166
+ task.texts[i],
1167
+ task.ref_texts[i],
1168
+ task.instructs[i],
1169
+ task.langs[i],
1170
+ task.target_lens[i],
1171
+ )
1172
+
1173
+ inputs_list = [
1174
+ self._prepare_inference_inputs(
1175
+ task.texts[i],
1176
+ task.target_lens[i],
1177
+ task.ref_texts[i],
1178
+ task.ref_audio_tokens[i],
1179
+ task.langs[i],
1180
+ task.instructs[i],
1181
+ gen_config.denoise,
1182
+ )
1183
+ for i in range(B)
1184
+ ]
1185
+
1186
+ c_lens = [inp["input_ids"].size(2) for inp in inputs_list]
1187
+ max_c_len = max(c_lens)
1188
+ pad_id = self.config.audio_mask_id # Or any other tokens
1189
+
1190
+ batch_input_ids = torch.full(
1191
+ (2 * B, self.config.num_audio_codebook, max_c_len),
1192
+ pad_id,
1193
+ dtype=torch.long,
1194
+ device=self.device,
1195
+ )
1196
+ batch_audio_mask = torch.zeros(
1197
+ (2 * B, max_c_len), dtype=torch.bool, device=self.device
1198
+ )
1199
+ batch_attention_mask = torch.zeros(
1200
+ (2 * B, 1, max_c_len, max_c_len), dtype=torch.bool, device=self.device
1201
+ )
1202
+
1203
+ for i, inp in enumerate(inputs_list):
1204
+ c_len, u_len = c_lens[i], task.target_lens[i]
1205
+
1206
+ # Cond (0 ~ B-1)
1207
+ batch_input_ids[i, :, :c_len] = inp["input_ids"]
1208
+ batch_audio_mask[i, :c_len] = inp["audio_mask"]
1209
+ batch_attention_mask[i, :, :c_len, :c_len] = True
1210
+
1211
+ # Uncond (B ~ 2B-1)
1212
+ batch_input_ids[B + i, :, :u_len] = inp["input_ids"][..., -u_len:]
1213
+ batch_audio_mask[B + i, :u_len] = inp["audio_mask"][..., -u_len:]
1214
+ batch_attention_mask[B + i, :, :u_len, :u_len] = True
1215
+ if max_c_len > u_len:
1216
+ pad_diag = torch.arange(u_len, max_c_len, device=self.device)
1217
+ batch_attention_mask[B + i, :, pad_diag, pad_diag] = True
1218
+
1219
+ tokens = torch.full(
1220
+ (B, self.config.num_audio_codebook, max(task.target_lens)),
1221
+ self.config.audio_mask_id,
1222
+ dtype=torch.long,
1223
+ device=self.device,
1224
+ )
1225
+
1226
+ timesteps = _get_time_steps(
1227
+ t_start=0.0,
1228
+ t_end=1.0,
1229
+ num_step=gen_config.num_step,
1230
+ t_shift=gen_config.t_shift,
1231
+ ).tolist()
1232
+ schedules = []
1233
+ for t_len in task.target_lens:
1234
+ total_mask = t_len * self.config.num_audio_codebook
1235
+ rem = total_mask
1236
+ sched = []
1237
+ for step in range(gen_config.num_step):
1238
+ num = (
1239
+ rem
1240
+ if step == gen_config.num_step - 1
1241
+ else min(
1242
+ math.ceil(total_mask * (timesteps[step + 1] - timesteps[step])),
1243
+ rem,
1244
+ )
1245
+ )
1246
+ sched.append(int(num))
1247
+ rem -= int(num)
1248
+ schedules.append(sched)
1249
+
1250
+ layer_ids = torch.arange(
1251
+ self.config.num_audio_codebook, device=self.device
1252
+ ).view(1, -1, 1)
1253
+
1254
+ for step in range(gen_config.num_step):
1255
+ batch_logits = self(
1256
+ input_ids=batch_input_ids,
1257
+ audio_mask=batch_audio_mask,
1258
+ attention_mask=batch_attention_mask,
1259
+ ).logits.to(torch.float32)
1260
+
1261
+ for i in range(B):
1262
+ k = schedules[i][step]
1263
+ if k <= 0:
1264
+ continue
1265
+
1266
+ c_len, t_len = c_lens[i], task.target_lens[i]
1267
+
1268
+ # Extract real target Logits
1269
+ # [1, C, T, V]
1270
+ c_logits = batch_logits[i : i + 1, :, c_len - t_len : c_len, :]
1271
+ u_logits = batch_logits[B + i : B + i + 1, :, :t_len, :]
1272
+
1273
+ pred_tokens, scores = self._predict_tokens_with_scoring(
1274
+ c_logits, u_logits, gen_config
1275
+ )
1276
+
1277
+ scores = scores - (layer_ids * gen_config.layer_penalty_factor)
1278
+
1279
+ if gen_config.position_temperature > 0.0:
1280
+ scores = _gumbel_sample(scores, gen_config.position_temperature)
1281
+
1282
+ sample_tokens = tokens[i : i + 1, :, :t_len]
1283
+ scores.masked_fill_(
1284
+ sample_tokens != self.config.audio_mask_id, -float("inf")
1285
+ )
1286
+
1287
+ _, topk_idx = torch.topk(scores.flatten(), k)
1288
+ flat_tokens = sample_tokens.flatten()
1289
+ flat_tokens[topk_idx] = pred_tokens.flatten()[topk_idx]
1290
+ sample_tokens.copy_(flat_tokens.view_as(sample_tokens))
1291
+
1292
+ # Update individual slices into batched structure
1293
+ tokens[i : i + 1, :, :t_len] = sample_tokens
1294
+ batch_input_ids[i : i + 1, :, c_len - t_len : c_len] = sample_tokens
1295
+ batch_input_ids[B + i : B + i + 1, :, :t_len] = sample_tokens
1296
+
1297
+ return [tokens[i, :, : task.target_lens[i]] for i in range(B)]
1298
+
1299
+ def _predict_tokens_with_scoring(self, c_logits, u_logits, gen_config):
1300
+ if gen_config.guidance_scale != 0:
1301
+ c_log_probs = F.log_softmax(c_logits, dim=-1)
1302
+ u_log_probs = F.log_softmax(u_logits, dim=-1)
1303
+ log_probs = torch.log_softmax(
1304
+ c_log_probs + gen_config.guidance_scale * (c_log_probs - u_log_probs),
1305
+ dim=-1,
1306
+ )
1307
+ else:
1308
+ log_probs = F.log_softmax(c_logits, dim=-1)
1309
+
1310
+ log_probs[..., self.config.audio_mask_id] = -float("inf")
1311
+
1312
+ if gen_config.class_temperature > 0.0:
1313
+ filtered_probs = _filter_top_k(log_probs, ratio=0.1)
1314
+ pred_tokens = _gumbel_sample(
1315
+ filtered_probs, gen_config.class_temperature
1316
+ ).argmax(dim=-1)
1317
+ else:
1318
+ pred_tokens = log_probs.argmax(dim=-1)
1319
+
1320
+ confidence_scores = log_probs.max(dim=-1)[0]
1321
+
1322
+ return pred_tokens, confidence_scores
1323
+
1324
+
1325
+ # ---------------------------------------------------------------------------
1326
+ # Standalone helpers
1327
+ # ---------------------------------------------------------------------------
1328
+
1329
+
1330
+ def _get_packed_mask(document_ids):
1331
+ return partial(_mask_mod_packed, document_ids)
1332
+
1333
+
1334
+ def _mask_mod_packed(document_ids, b, h, q_idx, kv_idx):
1335
+ # 1. Sequence Packing Logic: Tokens must belong to the same document.
1336
+ # Note: The doc_id for padding tokens is -1, which will automatically not match
1337
+ # (if handled correctly) or be ignored.
1338
+ same_doc = document_ids[q_idx] == document_ids[kv_idx]
1339
+ return same_doc
1340
+
1341
+
1342
+ def _resolve_language(language: Optional[str]) -> Union[str, None]:
1343
+ from omnivoice.utils.lang_map import LANG_IDS, LANG_NAME_TO_ID
1344
+
1345
+ if language is None or language.lower() == "none":
1346
+ return None
1347
+ if language in LANG_IDS:
1348
+ return language
1349
+ key = language.lower()
1350
+ if key in LANG_NAME_TO_ID:
1351
+ return LANG_NAME_TO_ID[key]
1352
+ logger.warning(
1353
+ f"Language '{language}' is not recognized. "
1354
+ f"Please use a valid language ID (e.g., 'en', 'zh', 'ja', 'de') "
1355
+ f"or a full language name (e.g., 'English', 'Chinese', 'Japanese'). "
1356
+ f"See supported_language_ids() or supported_language_names() for details. "
1357
+ f"Falling back to None (language-agnostic mode)."
1358
+ )
1359
+ return None
1360
+
1361
+
1362
+ def _resolve_instruct(
1363
+ instruct: Optional[str], use_zh: bool = False
1364
+ ) -> Union[str, None]:
1365
+ """Validate and normalise a voice-design instruct string.
1366
+
1367
+ Supported instruct items (case-insensitive for English):
1368
+
1369
+ English (comma + space separated):
1370
+ gender: male, female
1371
+ age: child, teenager, young adult, middle-aged, elderly
1372
+ pitch: very low pitch, low pitch, moderate pitch,
1373
+ high pitch, very high pitch
1374
+ style: whisper
1375
+ accent: american accent, british accent, australian accent, ...
1376
+
1377
+ Chinese (full-width comma separated):
1378
+ gender: 男, 女
1379
+ age: 儿童, 少年, 青年, 中年, 老年
1380
+ pitch: 极低音调, 低音调, 中音调, 高音调, 极高音调
1381
+ style: 耳语
1382
+ dialect: 河南话, 陕西话, 四川话, 贵州话, 云南话,
1383
+ 桂林话, 济南话, 石家庄话, 甘肃话, 宁夏话,
1384
+ 青岛话, 东北话
1385
+
1386
+ Minor issues (auto-fixed):
1387
+ - Wrong separator (half-width comma in Chinese instruct or
1388
+ full-width comma in English instruct)
1389
+ - Leading / trailing commas
1390
+
1391
+ Major issues (raise ``ValueError``):
1392
+ - Unsupported or misspelled instruct items
1393
+ - Suggestions are offered for close matches
1394
+
1395
+ Args:
1396
+ instruct: Raw instruct string, or ``None``.
1397
+ use_zh: If True, normalise all items to Chinese (used when the
1398
+ synthesis text contains Chinese and no accent is specified).
1399
+
1400
+ Returns:
1401
+ Normalised instruct string, or ``None``.
1402
+
1403
+ Raises:
1404
+ ValueError: if any instruct item is unsupported or misspelled.
1405
+ """
1406
+ if instruct is None:
1407
+ return None
1408
+
1409
+ instruct_str = instruct.strip()
1410
+ if not instruct_str:
1411
+ return None
1412
+
1413
+ # Split on both half-width and full-width commas
1414
+ raw_items = re.split(r"\s*[,,]\s*", instruct_str)
1415
+ raw_items = [x for x in raw_items if x]
1416
+
1417
+ # Validate each item
1418
+ unknown = []
1419
+ normalised = []
1420
+ for raw in raw_items:
1421
+ n = raw.strip().lower()
1422
+ if n in _INSTRUCT_ALL_VALID:
1423
+ normalised.append(n)
1424
+ else:
1425
+ sug = difflib.get_close_matches(n, _INSTRUCT_ALL_VALID, n=1, cutoff=0.6)
1426
+ unknown.append((raw, n, sug[0] if sug else None))
1427
+
1428
+ if unknown:
1429
+ lines = []
1430
+ for raw, n, sug in unknown:
1431
+ if sug:
1432
+ lines.append(f" '{raw}' -> '{n}' (unsupported; did you mean '{sug}'?)")
1433
+ else:
1434
+ lines.append(f" '{raw}' -> '{n}' (unsupported)")
1435
+ err = (
1436
+ f"Unsupported instruct items found in {instruct_str}:\n"
1437
+ + "\n".join(lines)
1438
+ + "\n\nValid English items: "
1439
+ + ", ".join(sorted(_INSTRUCT_VALID_EN))
1440
+ + "\nValid Chinese items: "
1441
+ + ",".join(sorted(_INSTRUCT_VALID_ZH))
1442
+ + "\n\nTip: Use only English or only Chinese instructs. "
1443
+ "English instructs should use comma + space (e.g. "
1444
+ "'male, indian accent'),\nChinese instructs should use full-width "
1445
+ "comma (e.g. '男,河南话')."
1446
+ )
1447
+ raise ValueError(err)
1448
+
1449
+ # --- Language consistency: dialect forces Chinese, accent forces English ---
1450
+ has_dialect = any(n.endswith("话") for n in normalised)
1451
+ has_accent = any(" accent" in n for n in normalised)
1452
+
1453
+ if has_dialect and has_accent:
1454
+ raise ValueError(
1455
+ "Cannot mix Chinese dialect and English accent in a single instruct. "
1456
+ "Dialects are for Chinese speech, accents for English speech."
1457
+ )
1458
+
1459
+ if has_dialect:
1460
+ use_zh = True
1461
+ elif has_accent:
1462
+ use_zh = False
1463
+
1464
+ # --- Unify to single language ---
1465
+ if use_zh:
1466
+ normalised = [_INSTRUCT_EN_TO_ZH.get(n, n) for n in normalised]
1467
+ else:
1468
+ normalised = [_INSTRUCT_ZH_TO_EN.get(n, n) for n in normalised]
1469
+
1470
+ # --- Category conflict check ---
1471
+ conflicts = []
1472
+ for cat in _INSTRUCT_MUTUALLY_EXCLUSIVE:
1473
+ hits = [n for n in normalised if n in cat]
1474
+ if len(hits) > 1:
1475
+ conflicts.append(hits)
1476
+ if conflicts:
1477
+ parts = []
1478
+ for group in conflicts:
1479
+ parts.append(" vs ".join(f"'{x}'" for x in group))
1480
+ raise ValueError(
1481
+ "Conflicting instruct items within the same category: "
1482
+ + "; ".join(parts)
1483
+ + ". Each category (gender, age, pitch, style, accent, dialect) "
1484
+ "allows at most one item."
1485
+ )
1486
+
1487
+ # Determine separator based on language
1488
+ has_zh = any(any("\u4e00" <= c <= "\u9fff" for c in n) for n in normalised)
1489
+ separator = "," if has_zh else ", "
1490
+
1491
+ return separator.join(normalised)
1492
+
1493
+
1494
+ def _filter_top_k(logits: torch.Tensor, ratio: float = 0.1) -> torch.Tensor:
1495
+ k = math.ceil(ratio * logits.shape[-1])
1496
+ val, ind = logits.topk(k, dim=-1)
1497
+ probs = torch.full_like(logits, float("-inf"))
1498
+ probs.scatter_(-1, ind, val)
1499
+ return probs
1500
+
1501
+
1502
+ def _gumbel_sample(logits: torch.Tensor, temperature: float) -> torch.Tensor:
1503
+ scaled_logits = logits / temperature
1504
+ u = torch.rand_like(scaled_logits)
1505
+ gumbel_noise = -torch.log(-torch.log(u + 1e-10) + 1e-10)
1506
+ return scaled_logits + gumbel_noise
1507
+
1508
+
1509
+ def _get_time_steps(
1510
+ t_start: float = 0.0,
1511
+ t_end: float = 1.0,
1512
+ num_step: int = 10,
1513
+ t_shift: float = 1.0,
1514
+ device: torch.device = torch.device("cpu"),
1515
+ ) -> torch.Tensor:
1516
+ timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)
1517
+ timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
1518
+ return timesteps
1519
+
1520
+
1521
+ _NONVERBAL_PATTERN = re.compile(
1522
+ r"\[(laughter|sigh|confirmation-en|question-en|question-ah|question-oh|"
1523
+ r"question-ei|question-yi|surprise-ah|surprise-oh|surprise-wa|"
1524
+ r"surprise-yo|dissatisfaction-hnn)\]"
1525
+ )
1526
+
1527
+
1528
+ def _tokenize_with_nonverbal_tags(text: str, tokenizer) -> torch.Tensor:
1529
+ """Tokenize text containing non-verbal tags, handling each tag independently.
1530
+
1531
+ Non-verbal tags are tokenized standalone to guarantee consistent token
1532
+ IDs regardless of surrounding language context (Chinese, English, etc.).
1533
+
1534
+ Args:
1535
+ text: Full text string potentially containing non-verbal tags.
1536
+ tokenizer: HuggingFace text tokenizer instance.
1537
+ Returns:
1538
+ Token IDs tensor of shape (1, seq_len).
1539
+ """
1540
+ parts = []
1541
+ last_end = 0
1542
+ for m in _NONVERBAL_PATTERN.finditer(text):
1543
+ if m.start() > last_end:
1544
+ segment = text[last_end : m.start()]
1545
+ ids = tokenizer(segment, add_special_tokens=False).input_ids
1546
+ if ids:
1547
+ parts.append(ids)
1548
+ tag_ids = tokenizer(m.group(), add_special_tokens=False).input_ids
1549
+ if tag_ids:
1550
+ parts.append(tag_ids)
1551
+ last_end = m.end()
1552
+ if last_end < len(text):
1553
+ segment = text[last_end:]
1554
+ ids = tokenizer(segment, add_special_tokens=False).input_ids
1555
+ if ids:
1556
+ parts.append(ids)
1557
+
1558
+ if not parts:
1559
+ result = tokenizer(text, return_tensors="pt").input_ids
1560
+ else:
1561
+ combined = []
1562
+ for p in parts:
1563
+ combined.extend(p)
1564
+ result = torch.tensor([combined], dtype=torch.long)
1565
+ return result
1566
+
1567
+
1568
+ def _combine_text(text, ref_text: Optional[str] = None) -> str:
1569
+
1570
+ # combine with reference text if not None
1571
+ if ref_text:
1572
+ full_text = ref_text.strip() + " " + text.strip()
1573
+ else:
1574
+ full_text = text.strip()
1575
+
1576
+ # filter out newline / carriage-return characters
1577
+ full_text = re.sub(r"[\r\n]+", "", full_text)
1578
+
1579
+ # replace Chinese parentheses with English ones
1580
+ full_text = full_text.replace("\uff08", "(").replace("\uff09", ")")
1581
+
1582
+ # collapse consecutive spaces / tabs into a single space
1583
+ full_text = re.sub(r"[ \t]+", " ", full_text)
1584
+
1585
+ # remove spaces around chinese characters
1586
+ chinese_range = r"[\u4e00-\u9fff]"
1587
+ pattern = rf"(?<={chinese_range})\s+|\s+(?={chinese_range})"
1588
+ full_text = re.sub(pattern, "", full_text)
1589
+
1590
+ return full_text
1591
+
1592
+
1593
+ # ---------------------------------------------------------------------------
1594
+ # Register with HuggingFace Auto classes
1595
+ # ---------------------------------------------------------------------------
1596
+
1597
+ AutoConfig.register("omnivoice", OmniVoiceConfig)
1598
+ AutoModel.register(OmniVoiceConfig, OmniVoice)
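To tie the inference entry points together, here is a minimal voice-cloning sketch based on the generate() docstring above. It assumes modeling_omnivoice.py and its omnivoice.* utility dependencies are importable locally; the repo id, reference clip, and transcript are placeholders:

    import soundfile as sf
    from modeling_omnivoice import OmniVoice

    model = OmniVoice.from_pretrained("<org>/<this-repo>")  # placeholder repo id

    audios = model.generate(
        text="Hello from OmniVoice.",
        language="en",
        ref_audio="reference.wav",                     # placeholder clip to clone
        ref_text="Transcript of the reference clip.",  # omit to auto-transcribe with the ASR model
        num_step=32,
        guidance_scale=2.0,
    )

    # generate() returns a list of 1-D numpy arrays at model.sampling_rate (usually 24 kHz)
    sf.write("out.wav", audios[0], model.sampling_rate)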
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c353ee1479b536bf414c1b247f5542b6607fb8ae91320e5af1781fee200fddff
3
+ size 470897