MCplayer committed on
Commit
4b0005e
·
1 Parent(s): 76689e2

XY_Tokenizer AutoModel version support

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.wav filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+
6
+ ```python
7
+ import torchaudio
8
+ from transformers import AutoFeatureExtractor, AutoModel
9
+
10
+ wav_form, sampling_rate = torchaudio.load("examples/zh_spk1_moon.wav")
11
+ feature_extractor = AutoFeatureExtractor.from_pretrained("MCplayer/XY_Tokenizer", trust_remote_code=True)
12
+ codec = AutoModel.from_pretrained("MCplayer/XY_Tokenizer", trust_remote_code=True, device_map="auto").eval()
13
+
14
+ if sampling_rate != 16000:
15
+ resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
16
+ wav_form = resampler(wav_form)
17
+
18
+ input_spectrum = feature_extractor(wav_form, sampling_rate=16000, return_attention_mask=True, return_tensors="pt")
19
+ code = codec.encode(input_spectrum)
20
+
21
+ output_wav = codec.decode(code["audio_codes"], overlap_seconds=10)
22
+ for i, audio in enumerate(output_wav["audio_values"]):
23
+ torchaudio.save(f"outputs/audio{i}.wav", audio.cpu(), 24000)
24
+
25
+
26
+ ```
config.json ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "xy_tokenizer",
3
+ "auto_map": {
4
+ "AutoFeatureExtractor": "feature_extraction_xy_tokenizer.XYTokenizerFeatureExtractor",
5
+ "AutoConfig": "configuration_xy_tokenizer.XYTokenizerConfig",
6
+ "AutoModel": "modeling_xy_tokenizer.XYTokenizerModel"
7
+ },
8
+ "input_sample_rate": 16000,
9
+ "output_sample_rate": 24000,
10
+ "encoder_downsample_rate": 1280,
11
+ "decoder_upsample_rate": 1920,
12
+ "code_dim": 3072,
13
+ "params": {
14
+ "feature_extractor_kwargs": {
15
+ "chunk_length": 30,
16
+ "feature_size": 80,
17
+ "hop_length": 160,
18
+ "n_fft": 400,
19
+ "n_samples": 480000,
20
+ "nb_max_frames": 3000,
21
+ "padding_side": "right",
22
+ "padding_value": 0.0,
23
+ "sampling_rate": 16000,
24
+ "encoder_downsample_rate": 1280,
25
+ "return_attention_mask": true,
26
+ "return_tensors": "pt"
27
+ },
28
+ "semantic_encoder_kwargs": {
29
+ "num_mel_bins": 80,
30
+ "sampling_rate": 16000,
31
+ "hop_length": 160,
32
+ "stride_size": 2,
33
+ "kernel_size": 3,
34
+ "d_model": 768,
35
+ "scale_embedding": false,
36
+ "max_audio_seconds": 30,
37
+ "encoder_layers": 12,
38
+ "encoder_attention_heads": 12,
39
+ "encoder_ffn_dim": 3072,
40
+ "activation_function": "gelu"
41
+ },
42
+ "semantic_encoder_adapter_kwargs": {
43
+ "input_dim": 768,
44
+ "output_dim": 768,
45
+ "d_model": 768,
46
+ "max_source_positions": 1500,
47
+ "encoder_layers": 4,
48
+ "encoder_attention_heads": 12,
49
+ "encoder_ffn_dim": 3072
50
+ },
51
+ "acoustic_encoder_kwargs": {
52
+ "num_mel_bins": 80,
53
+ "sampling_rate": 16000,
54
+ "hop_length": 160,
55
+ "stride_size": 2,
56
+ "kernel_size": 3,
57
+ "d_model": 768,
58
+ "scale_embedding": false,
59
+ "max_audio_seconds": 30,
60
+ "encoder_layers": 12,
61
+ "encoder_attention_heads": 12,
62
+ "encoder_ffn_dim": 3072,
63
+ "activation_function": "gelu"
64
+ },
65
+ "pre_rvq_adapter_kwargs": {
66
+ "input_dim": 1536,
67
+ "output_dim": 768,
68
+ "d_model": 768,
69
+ "max_source_positions": 1500,
70
+ "encoder_layers": 4,
71
+ "encoder_attention_heads": 12,
72
+ "encoder_ffn_dim": 3072
73
+ },
74
+ "downsample_kwargs": {
75
+ "d_model": 768,
76
+ "avg_pooler": 4
77
+ },
78
+ "quantizer_kwargs": {
79
+ "input_dim": 3072,
80
+ "rvq_dim": 512,
81
+ "output_dim": 3072,
82
+ "num_quantizers": 8,
83
+ "codebook_size": 1024,
84
+ "codebook_dim": 512,
85
+ "quantizer_dropout": 0.0
86
+ },
87
+ "post_rvq_adapter_kwargs": {
88
+ "input_dim": 3072,
89
+ "output_dim": 3072,
90
+ "d_model": 768,
91
+ "max_source_positions": 375,
92
+ "encoder_layers": 4,
93
+ "encoder_attention_heads": 12,
94
+ "encoder_ffn_dim": 3072
95
+ },
96
+ "upsample_kwargs": {
97
+ "d_model": 768,
98
+ "stride": 4
99
+ },
100
+ "acoustic_decoder_kwargs": {
101
+ "num_mel_bins": 80,
102
+ "sampling_rate": 16000,
103
+ "hop_length": 160,
104
+ "stride_size": 2,
105
+ "kernel_size": 3,
106
+ "d_model": 768,
107
+ "scale_embedding": false,
108
+ "max_audio_seconds": 30,
109
+ "decoder_layers": 12,
110
+ "decoder_attention_heads": 12,
111
+ "decoder_ffn_dim": 3072,
112
+ "activation_function": "gelu"
113
+ },
114
+ "vocos_kwargs": {
115
+ "input_channels": 80,
116
+ "dim": 512,
117
+ "intermediate_dim": 4096,
118
+ "num_layers": 30,
119
+ "n_fft": 960,
120
+ "hop_size": 240,
121
+ "padding": "same"
122
+ }
123
+ }
124
+ }
configuration_xy_tokenizer.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """XYTokenizer model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+
23
class XYTokenizerConfig(PretrainedConfig):
    r"""
    Configuration class for [`XYTokenizerModel`]. It is used to instantiate an
    XY Tokenizer model according to the specified arguments, defining the model
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.

    Args:
        input_sample_rate (`int`, *optional*, defaults to 16000):
            Sampling rate of the input audio.
        output_sample_rate (`int`, *optional*, defaults to 16000):
            Sampling rate of the reconstructed audio. (The released
            `config.json` overrides this to 24000.)
        encoder_downsample_rate (`int`, *optional*, defaults to 1280):
            Total downsampling factor of the encoder part.
        decoder_upsample_rate (`int`, *optional*, defaults to 1920):
            Total upsampling factor of the decoder part.
        code_dim (`int`, *optional*, defaults to 1280):
            Dimension of the code embeddings. (The released `config.json`
            overrides this to 3072.)
        kwargs:
            All remaining keyword arguments — e.g. the nested ``params`` dict
            holding the per-module ``*_kwargs`` groups in ``config.json`` —
            are stored verbatim on ``self.params`` and also forwarded to
            ``PretrainedConfig.__init__``.
    """
    model_type = "xy_tokenizer"

    # Shortcut implementation: instead of declaring every sub-module parameter
    # explicitly, all extras are accepted via **kwargs. A production-ready
    # config should list all parameters explicitly.
    def __init__(
        self,
        input_sample_rate=16000,
        output_sample_rate=16000,
        encoder_downsample_rate=1280,
        decoder_upsample_rate=1920,
        code_dim=1280,
        # A real config would have dozens of parameters here; they are
        # dynamically accepted via **kwargs.
        **kwargs,
    ):
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.encoder_downsample_rate = encoder_downsample_rate
        self.decoder_upsample_rate = decoder_upsample_rate
        self.code_dim = code_dim

        # Store all other parameters dynamically. NOTE(review): with the
        # shipped config.json this means `self.params` contains the nested
        # "params" key itself — confirm the model reads it accordingly.
        self.params = kwargs

        super().__init__(**kwargs)
80
+
81
+
82
+ __all__ = ["XYTokenizerConfig"]
examples/m1.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93148fbccb0a2e589bcb1690f16fa1bb54bea633925d8c42436969d4cf2d0cc4
3
+ size 64844
examples/m2.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ae2ec7cee605794830d0569014b52aeb5c6f96aafee93b2f6698661944b1166
3
+ size 567044
examples/pod_f_enhanced.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8209f371d91b68e45ec28f4772fa6152f277caa3c4308de5f3a7730a2de7066f
3
+ size 523244
examples/pod_m_enhanced.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29f50bbfb8497dc3727e1412b7c299f7d3c894bb8436388fa11ea86301dbccca
3
+ size 546284
examples/zh_spk1_moon.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd95ddcb03537bc04d19a6e5ab56b2694c9d65324c231c6676569877dc005ace
3
+ size 457956
examples/zh_spk2_moon.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ad41d9c8967148fd0875967c1498c7aaab7dfdd919d7e7e063ae91e7a4a1d19
3
+ size 414720
feature_extraction_xy_tokenizer.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Feature extractor class for Whisper
17
+ """
18
+ from functools import partial
19
+ from typing import List, Optional, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from transformers import WhisperFeatureExtractor
24
+ from transformers.audio_utils import mel_filter_bank
25
+ from transformers.configuration_utils import PretrainedConfig
26
+ from transformers.feature_extraction_utils import BatchFeature
27
+ from transformers.utils import TensorType, logging
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
class ExtractorIterator:
    """Iterator that splits raw waveforms into fixed-duration chunks and
    lazily yields feature-extracted batches.

    Each input sample is cut into chunks of ``chunk_length - overlap_seconds``
    seconds of audio; chunks (possibly from several consecutive samples) are
    packed into batches of ``batch_size`` and passed through ``encode_func``
    (the underlying Whisper feature-extractor call). Each yielded
    ``BatchFeature`` additionally carries ``input_lengths`` (valid samples per
    chunk) and ``chunk_seq_no`` (index of the source sample for each chunk).
    """

    def __init__(
        self,
        data,
        batch_size=1,
        chunk_length=30,
        overlap_seconds=10,
        sampling_rate=16000,
        encoder_downsample_rate=1280,
        encode_func=None,
    ) -> None:
        self.data = data
        self.batch_size = batch_size
        self.chunk_length = chunk_length
        self.overlap_seconds = overlap_seconds
        self.sampling_rate = sampling_rate
        self.encoder_downsample_rate = encoder_downsample_rate

        # duration_size is the effective audio length processed per chunk.
        self.duration_seconds = self.chunk_length - self.overlap_seconds
        self.duration_size = int(self.duration_seconds * self.sampling_rate)
        self.code_duration_length = self.duration_size // self.encoder_downsample_rate
        # NOTE: only non-overlapping chunks of duration_size are produced here;
        # overlap between adjacent chunks (if needed) is handled by the caller.

        assert callable(encode_func)
        self.encode_func = encode_func

    def __iter__(self):
        """Generator implementing all the batching logic.

        Batch bookkeeping lives in locals of ``__iter__`` so each fresh
        iteration starts from a clean state.
        """
        batch_num = 0  # number of chunks currently packed into the buffers

        # Buffers for one batch; chunk_and_pad_view emits chunks of duration_size.
        wav_tensor = torch.zeros(self.batch_size, 1, self.duration_size)
        input_lengths = torch.zeros(self.batch_size, dtype=torch.long)
        input_seq_no = torch.zeros(self.batch_size, dtype=torch.long)

        def chunk_and_pad_view(tensor, chunk_size, seq_no):
            # Split one sample into equal chunks, zero-padding the tail.
            # Only the first channel is used (mono audio is assumed).
            x = tensor[0:1, :].unsqueeze(0)
            B, C, L = x.shape
            num_chunks = (L + chunk_size - 1) // chunk_size
            target_len = num_chunks * chunk_size
            pad_len = target_len - L
            padded_x = F.pad(x, (0, pad_len))
            output_tensor = padded_x.view(B, num_chunks, chunk_size).transpose(0, 1)
            output_lengths = torch.full((num_chunks,), chunk_size, dtype=torch.long)
            if pad_len > 0:
                # The final chunk carries only the unpadded samples.
                output_lengths[-1] = chunk_size - pad_len
            output_seq_no = torch.full((num_chunks,), seq_no, dtype=torch.long)
            return output_tensor, output_lengths, output_seq_no

        def make_batch(count):
            # Build a BatchFeature from the first `count` packed chunks only.
            list_x = [
                wav_tensor[xi, :, :x_len].reshape(-1).cpu().numpy()
                for xi, x_len in enumerate(input_lengths[:count].tolist())
            ]
            return BatchFeature({
                **self.encode_func(list_x),
                "input_lengths": input_lengths[:count].clone(),
                "chunk_seq_no": input_seq_no[:count].clone(),
            })

        for i, sample in enumerate(self.data):
            sample_chunks, sample_lengths, sample_seq_no = chunk_and_pad_view(sample, self.duration_size, i)

            processed_in_sample = 0
            while processed_in_sample < len(sample_chunks):
                space_in_batch = self.batch_size - batch_num
                chunks_to_add = min(space_in_batch, len(sample_chunks) - processed_in_sample)

                # Slice ranges for the sample and the batch buffers.
                start_idx_sample = processed_in_sample
                end_idx_sample = processed_in_sample + chunks_to_add
                start_idx_batch = batch_num
                end_idx_batch = batch_num + chunks_to_add

                # Copy data into the batch buffers.
                wav_tensor[start_idx_batch:end_idx_batch] = sample_chunks[start_idx_sample:end_idx_sample]
                input_lengths[start_idx_batch:end_idx_batch] = sample_lengths[start_idx_sample:end_idx_sample]
                input_seq_no[start_idx_batch:end_idx_batch] = sample_seq_no[start_idx_sample:end_idx_sample]

                # Advance counters.
                batch_num += chunks_to_add
                processed_in_sample += chunks_to_add

                # When the batch is full, yield a copy and reset the buffers.
                if batch_num == self.batch_size:
                    yield make_batch(batch_num)
                    batch_num = 0
                    wav_tensor.zero_()
                    input_lengths.zero_()
                    input_seq_no.zero_()

        # Flush the final, possibly partial batch.
        # BUG FIX: the original yielded full-size `input_lengths` (and fed
        # zero-length placeholder slots through `encode_func`) while slicing
        # `chunk_seq_no` to `batch_num`, yielding mismatched batch dimensions;
        # all fields are now consistently truncated to the valid chunk count.
        if batch_num > 0:
            yield make_batch(batch_num)
139
+
140
+
141
class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
    """Log-mel feature extractor for XY Tokenizer, based on Whisper's.

    Differences from `WhisperFeatureExtractor`:

    * the mel filter bank is rebuilt with a configurable ``max_frequency``
      (defaulting to the Nyquist frequency ``sampling_rate / 2``);
    * ``__call__`` does not compute features eagerly — it returns an
      `ExtractorIterator` that chunks long waveforms and lazily yields
      feature batches, delegating the actual extraction of each batch back
      to ``WhisperFeatureExtractor.__call__``.
    """

    def __init__(
        self,
        feature_size=80,
        sampling_rate=16000,
        encoder_downsample_rate=1280,
        hop_length=160,
        chunk_length=30,
        n_fft=400,
        padding_value=0.0,
        dither=0.0,
        return_attention_mask=False,
        max_frequency=None,
        batch_size=None,
        **kwargs,
    ):
        # Standard Whisper feature-extractor setup (mel bins, STFT params, padding).
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            hop_length=hop_length,
            chunk_length=chunk_length,
            n_fft=n_fft,
            padding_value=padding_value,
            dither=dither,
            return_attention_mask=return_attention_mask,
            **kwargs,
        )
        # Default the mel upper edge to the Nyquist frequency.
        self.max_frequency = max_frequency if max_frequency is not None else sampling_rate / 2
        self.encoder_downsample_rate = encoder_downsample_rate
        # batch_size=None means "one batch per input list" (set in __call__).
        self.batch_size = batch_size
        # Rebuild the parent's mel filters so max_frequency takes effect.
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=1 + n_fft // 2,
            num_mel_filters=feature_size,
            min_frequency=0.0,
            max_frequency=self.max_frequency,
            sampling_rate=sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )

    def __call__(
        self,
        raw_speech: Union[torch.Tensor, List[torch.Tensor]],
        truncation: bool = True,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: Optional[bool] = None,
        padding: Optional[str] = "max_length",
        max_length: Optional[int] = None,
        sampling_rate: Optional[int] = None,
        do_normalize: Optional[bool] = None,
        device: Optional[str] = "cpu",
        return_token_timestamps: Optional[bool] = None,
        overlap_seconds: int = 10,
        **kwargs,
    ) -> ExtractorIterator:
        """Return a lazy `ExtractorIterator` over feature batches.

        Args:
            raw_speech: a waveform tensor or list of waveform tensors.
            overlap_seconds: overlap (in seconds) subtracted from
                ``chunk_length`` to get the effective chunk duration.
            All remaining arguments are captured and forwarded verbatim to
            ``WhisperFeatureExtractor.__call__`` for each batch.
        """
        # Normalize the input to a list of samples.
        if not isinstance(raw_speech, list):
            raw_speech = [raw_speech]

        return ExtractorIterator(
            raw_speech,
            batch_size=len(raw_speech) if self.batch_size is None else self.batch_size,
            chunk_length=self.chunk_length,
            overlap_seconds=overlap_seconds,
            sampling_rate=self.sampling_rate,
            encoder_downsample_rate=self.encoder_downsample_rate,
            # Freeze the parent call with all extraction options; the iterator
            # invokes it once per packed batch.
            encode_func=partial(
                super().__call__,
                truncation=truncation,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=return_tensors,
                return_attention_mask=return_attention_mask,
                padding=padding,
                max_length=max_length,
                sampling_rate=sampling_rate,
                do_normalize=do_normalize,
                device=device,
                return_token_timestamps=return_token_timestamps,
                **kwargs,
            )
        )
modeling_xy_tokenizer.py ADDED
@@ -0,0 +1,1227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Transformers XYTokenizer model."""
16
+
17
+ import math
18
+ from collections import defaultdict
19
+ from dataclasses import asdict, dataclass
20
+ from typing import Optional, Tuple, Union, List
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.distributed as dist
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ from einops import rearrange
28
+ from torch.nn.utils.parametrizations import weight_norm
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_utils import PreTrainedModel
31
+ from transformers.utils import ModelOutput, logging
32
+ from transformers.feature_extraction_utils import BatchFeature
33
+
34
+ from .configuration_xy_tokenizer import XYTokenizerConfig
35
+ from .feature_extraction_xy_tokenizer import ExtractorIterator
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ # ----------------------------------------------- #
41
+ # Model Output Dataclasses #
42
+ # ----------------------------------------------- #
43
@dataclass
class XYTokenizerEncodeOutput(ModelOutput):
    """
    Output type of [`XYTokenizerModel.encode`].

    Args:
        quantized_representation (`torch.FloatTensor` of shape `(batch_size, hidden_dim, sequence_length)`):
            The quantized continuous representation of the input audio. This is the output of the quantizer.
        audio_codes (`torch.LongTensor` of shape `(num_codebooks, batch_size, sequence_length)`):
            The discrete codes from the quantizer for each codebook.
        codes_lengths (`torch.LongTensor` of shape `(batch_size,)`):
            The valid length of each sequence in `audio_codes`.
        commit_loss (`torch.FloatTensor`, *optional*):
            The commitment loss from the vector quantizer.
        overlap_seconds (`int`, *optional*):
            The duration of the overlap in seconds between adjacent audio chunks.
    """
    # Fields default to None so ModelOutput can drop unset entries.
    quantized_representation: Optional[torch.FloatTensor] = None
    audio_codes: Optional[torch.LongTensor] = None
    codes_lengths: Optional[torch.LongTensor] = None
    commit_loss: Optional[torch.FloatTensor] = None
    overlap_seconds: Optional[int] = None
65
+
66
+
67
@dataclass
class XYTokenizerDecodeOutput(ModelOutput):
    """
    Output type of [`XYTokenizerModel.decode`].

    Args:
        audio_values (`torch.FloatTensor` of shape `(batch_size, 1, sequence_length)`):
            The reconstructed audio waveform.
        output_length (`torch.LongTensor` of shape `(batch_size,)`):
            The valid length of each sequence in `audio_values`.
    """
    # Fields default to None so ModelOutput can drop unset entries.
    audio_values: Optional[torch.FloatTensor] = None
    output_length: Optional[torch.LongTensor] = None
80
+
81
+
82
@dataclass
class XYTokenizerModelOutput(ModelOutput):
    """
    Output type of [`XYTokenizerModel`]'s forward pass.

    Args:
        audio_values (`torch.FloatTensor` of shape `(batch_size, 1, sequence_length)`):
            The reconstructed audio waveform.
        output_length (`torch.LongTensor` of shape `(batch_size,)`):
            The valid length of each sequence in `audio_values`.
        quantized_representation (`torch.FloatTensor` of shape `(batch_size, hidden_dim, sequence_length)`):
            The quantized continuous representation of the input audio. This is the output of the quantizer.
        audio_codes (`torch.LongTensor` of shape `(num_codebooks, batch_size, sequence_length)`):
            The discrete codes from the quantizer for each codebook.
        codes_lengths (`torch.LongTensor` of shape `(batch_size,)`):
            The valid length of each sequence in `audio_codes`.
        commit_loss (`torch.FloatTensor`, *optional*):
            The commitment loss from the vector quantizer.
    """
    # Fields default to None so ModelOutput can drop unset entries.
    audio_values: Optional[torch.FloatTensor] = None
    output_length: Optional[torch.LongTensor] = None
    quantized_representation: Optional[torch.FloatTensor] = None
    audio_codes: Optional[torch.LongTensor] = None
    codes_lengths: Optional[torch.LongTensor] = None
    commit_loss: Optional[torch.FloatTensor] = None
107
+
108
+
109
@dataclass
class VectorQuantizerConfig:
    """Configuration for the VectorQuantize module."""
    # NOTE(review): field semantics below follow common EMA vector-quantizer
    # conventions — confirm against the quantizer implementation (not shown here).
    commitment: float = 1.0        # presumably weight of the commitment loss term
    decay: float = 0.99            # presumably EMA decay for codebook updates
    epsilon: float = 1e-5          # presumably numerical-stability constant
    threshold_ema_dead: int = 2    # presumably EMA cluster-size threshold for dead codes
    kmeans_init: bool = True       # presumably initialize the codebook via k-means
    kmeans_iters: int = 10         # presumably number of k-means iterations at init
118
+
119
+
120
+ # ----------------------------------------------- #
121
+ # All Helper Modules (Copied from source) #
122
+ # ----------------------------------------------- #
123
def sinusoids(length, channels, max_timescale=10000):
    """Return sinusoidal position embeddings of shape ``(length, channels)``.

    The first ``channels // 2`` columns are sines, the rest cosines, with
    timescales spaced geometrically from 1 up to ``max_timescale``.
    """
    assert channels % 2 == 0
    half = channels // 2
    # Geometric spacing of inverse timescales across the half-dimension.
    log_increment = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(torch.arange(half) * (-log_increment))
    # Outer product of positions and inverse timescales -> (length, half).
    angles = torch.arange(length).unsqueeze(1) * inv_timescales.unsqueeze(0)
    return torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)
129
+
130
+
131
def get_sequence_mask(inputs, inputs_length):
    """Build a boolean validity mask of shape ``(bsz, tgt_len, 1)``.

    Entry ``[b, t, 0]`` is True iff position ``t`` is within the valid
    length ``inputs_length[b]``. ``tgt_len`` is taken from ``inputs`` when
    it is 3-D, otherwise from the maximum of ``inputs_length``.
    """
    if inputs.dim() == 3:
        bsz, tgt_len = inputs.size(0), inputs.size(1)
    else:
        bsz = inputs_length.shape[0]
        tgt_len = torch.max(inputs_length)
    positions = torch.arange(0, tgt_len, device=inputs.device)
    valid = positions < inputs_length.reshape(bsz, 1)
    return valid.view(bsz, tgt_len, 1)
139
+
140
+
141
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: scale by 1/RMS, no mean subtraction, no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable per-channel gain, initialized to 1.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Compute the mean square in float32 for numerical stability.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = hidden_states * torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast back when the gain is stored in half precision.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normed = normed.to(self.weight.dtype)
        return self.weight * normed
153
+
154
+
155
class VarLenAttention(nn.Module):
    """Multi-head self-attention over padded batches of variable-length sequences.

    Padding positions are excluded via an additive mask built from per-sample
    valid lengths; with ``causal=True`` future positions are masked as well.
    """

    def __init__(self, embed_dim, num_heads, causal=False, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.causal = causal
        self.dropout = nn.Dropout(dropout)
        # Pre-scale queries by 1/sqrt(head_dim) (applied in forward).
        self.scaling = self.head_dim ** -0.5
        # NOTE: only k_proj has no bias, mirroring Whisper's attention layout.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def _create_attention_mask(self, seq_len, max_len, device, dtype):
        """Build an additive mask of shape (bsz, 1, max_len, max_len).

        Valid (attendable) pairs get 0; masked pairs get the dtype's most
        negative finite value so they vanish after softmax.
        """
        bsz = seq_len.size(0)
        mask = torch.ones(bsz, 1, max_len, max_len, device=device, dtype=dtype)
        seq_indices = torch.arange(max_len, device=device).unsqueeze(0)
        seq_len_expanded = seq_len.unsqueeze(1)
        # valid_mask[b, t] is True iff position t < seq_len[b].
        valid_mask = seq_indices < seq_len_expanded.unsqueeze(-1)
        # Keep only pairs where BOTH query and key positions are valid.
        mask = mask * (valid_mask.unsqueeze(2) & valid_mask.unsqueeze(3)).to(dtype)
        if self.causal:
            # Zero out strictly-upper-triangular (future) positions.
            causal_mask = torch.triu(torch.ones(max_len, max_len, device=device, dtype=torch.bool), diagonal=1)
            mask = mask * (~causal_mask.unsqueeze(0).unsqueeze(1)).to(dtype)
        # Convert the 0/1 keep-mask into an additive mask: 1 -> 0, 0 -> -inf-ish.
        mask = mask + (1.0 - mask) * torch.finfo(dtype).min
        return mask

    def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
        """Apply self-attention.

        Args:
            hidden_states: (bsz, max_len, embed_dim) padded inputs.
            seq_len: (bsz,) valid length per sample.

        Returns:
            (bsz, max_len, embed_dim) attended output (padding rows are not
            zeroed here; callers mask them downstream as needed).
        """
        bsz, max_len, _ = hidden_states.size()
        query = self.q_proj(hidden_states) * self.scaling
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)
        # Reshape to (bsz, num_heads, max_len, head_dim).
        query = query.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2)
        attn_scores = torch.matmul(query, key.transpose(-1, -2))
        attn_mask = self._create_attention_mask(seq_len, max_len, hidden_states.device, attn_scores.dtype)
        attn_scores = attn_scores + attn_mask
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value)
        # Merge heads back into the embedding dimension.
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, max_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output
200
+
201
+
202
class OmniWhisperTransformerLayer(nn.Module):
    """Pre-norm transformer layer: VarLenAttention + feed-forward, each with
    a residual connection and a configurable norm (LayerNorm or RMSNorm)."""

    def __init__(self, activation_function="gelu", d_model=1280, attention_heads=20, ffn_dim=5120, causal=False, ln_type="LayerNorm", attn_type="varlen"):
        super().__init__()
        self.embed_dim = d_model
        # Only the variable-length attention implementation is wired up.
        if attn_type != "varlen":
            raise ValueError(f"Unknown attn_type: {attn_type}. Only 'varlen' is supported.")
        self.self_attn = VarLenAttention(self.embed_dim, attention_heads, causal)
        if ln_type == "LayerNorm":
            self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        elif ln_type == "RMSNorm":
            self.self_attn_layer_norm = RMSNorm(self.embed_dim)
        else:
            raise ValueError(f"Unknown ln_type: {ln_type}")
        # Activation looked up from transformers' ACT2FN registry.
        self.activation_fn = ACT2FN[activation_function]
        self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
        if ln_type == "LayerNorm":
            self.final_layer_norm = nn.LayerNorm(self.embed_dim)
        elif ln_type == "RMSNorm":
            self.final_layer_norm = RMSNorm(self.embed_dim)
        else:
            raise ValueError(f"Unknown ln_type: {ln_type}")

    def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
        """Run one pre-norm attention + FFN block.

        Args:
            hidden_states: (bsz, max_len, d_model) padded inputs.
            seq_len: (bsz,) valid length per sample, forwarded to attention.
        """
        # Self-attention sub-block (pre-norm, residual).
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states, seq_len)
        hidden_states = residual + hidden_states
        # Feed-forward sub-block (pre-norm, residual).
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # Guard against half-precision overflow: clamp if inf/nan appeared.
        if (hidden_states.dtype == torch.float16 or hidden_states.dtype == torch.bfloat16) and \
                (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
        return hidden_states
240
+
241
+
242
class OmniAudioEncoder(nn.Module):
    """Whisper-style audio encoder.

    Two Conv1d layers (the second strided) downsample mel features, sinusoidal
    positions are added, and a stack of OmniWhisperTransformerLayer blocks
    produces frame-level hidden states.
    """

    def __init__(
        self, num_mel_bins=128, sampling_rate=16000, hop_length=160, stride_size=2, kernel_size=3,
        d_model=1280, scale_embedding=True, max_audio_seconds=30, encoder_layers=32,
        encoder_attention_heads=20, encoder_ffn_dim=5120, activation_function="gelu", attn_type="varlen"
    ):
        super().__init__()
        # Maximum number of post-conv frames: samples / hop_length / stride_size.
        self.max_source_positions = (max_audio_seconds * sampling_rate // hop_length) // stride_size
        self.embed_scale = math.sqrt(d_model) if scale_embedding else 1.0  # NOTE(review): set but not used in forward()
        self.num_mel_bins, self.d_model, self.stride_size = num_mel_bins, d_model, stride_size
        self.conv1 = nn.Conv1d(num_mel_bins, d_model, kernel_size=kernel_size, padding=1)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=kernel_size, stride=stride_size, padding=1)
        self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model))
        self.layers = nn.ModuleList([
            OmniWhisperTransformerLayer(activation_function, d_model, encoder_attention_heads, encoder_ffn_dim, False, attn_type=attn_type)
            for _ in range(encoder_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input_features, input_length, output_hidden_states=False):
        """Encode mel features.

        Args:
            input_features: (batch, num_mel_bins, T) spectrogram.
            input_length: per-item valid frame counts before the strided conv.
            output_hidden_states: also return a tuple of every layer's states.

        Returns:
            (hidden_states, output_length[, all_hidden]) with hidden_states in
            channel-first layout (batch, d_model, T // stride_size).
        """
        input_features = input_features.to(self.conv1.weight.dtype)
        inputs_embeds = F.gelu(self.conv1(input_features))
        inputs_embeds = F.gelu(self.conv2(inputs_embeds))
        output_length = (input_length // self.stride_size).long()
        hidden_states = inputs_embeds.permute(0, 2, 1)  # -> (batch, T', d_model)
        bsz, tgt_len, _ = hidden_states.size()
        # Truncate positional table to the actual sequence length.
        pos_embed = self.positional_embedding[:tgt_len] if tgt_len < self.positional_embedding.shape[0] else self.positional_embedding
        # Add positions in float32 for precision, then cast back.
        hidden_states = (hidden_states.to(torch.float32) + pos_embed).to(hidden_states.dtype)
        attention_mask = get_sequence_mask(hidden_states, output_length)
        all_hidden = () if output_hidden_states else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden += (hidden_states,)
            hidden_states = layer(hidden_states, output_length)
        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            all_hidden += (hidden_states,)
        # Zero out padding frames and return to channel-first layout.
        hidden_states = torch.where(attention_mask, hidden_states, 0).transpose(1, 2)
        if not output_hidden_states:
            return hidden_states, output_length
        return hidden_states, output_length, all_hidden
283
+
284
+
285
class OmniAudioDecoder(nn.Module):
    """Mirror of OmniAudioEncoder: transformer stack followed by two transposed
    convolutions that upsample hidden states back to mel-frame rate."""

    def __init__(
        self, num_mel_bins=128, sampling_rate=16000, hop_length=160, stride_size=2, kernel_size=3,
        d_model=1280, scale_embedding=True, max_audio_seconds=30, decoder_layers=32,
        decoder_attention_heads=20, decoder_ffn_dim=5120, activation_function="gelu", attn_type="varlen"
    ):
        super().__init__()
        self.max_source_positions = (max_audio_seconds * sampling_rate // hop_length) // stride_size
        self.embed_scale = math.sqrt(d_model) if scale_embedding else 1.0  # NOTE(review): set but not used in forward()
        self.num_mel_bins, self.d_model, self.stride_size = num_mel_bins, d_model, stride_size
        # padding=0 makes both deconvs produce a few extra samples; forward() trims them.
        self.deconv1 = nn.ConvTranspose1d(d_model, d_model, kernel_size, stride_size, padding=0, output_padding=0)
        self.deconv2 = nn.ConvTranspose1d(d_model, num_mel_bins, kernel_size, stride=1, padding=0)
        self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model))
        self.layers = nn.ModuleList([
            OmniWhisperTransformerLayer(activation_function, d_model, decoder_attention_heads, decoder_ffn_dim, False, attn_type=attn_type)
            for _ in range(decoder_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, hidden_states, input_length):
        """hidden_states: (batch, d_model, T). Returns ((batch, num_mel_bins,
        T * stride_size), lengths * stride_size)."""
        hidden_states = hidden_states.transpose(1, 2)  # -> (batch, T, d_model)
        bsz, tgt_len, _ = hidden_states.size()
        pos_embed = self.positional_embedding[:tgt_len] if tgt_len < self.positional_embedding.shape[0] else self.positional_embedding
        # Add positions in float32 for precision, then cast back.
        hidden_states = (hidden_states.to(torch.float32) + pos_embed).to(hidden_states.dtype)
        attention_mask = get_sequence_mask(hidden_states, input_length)
        for layer in self.layers:
            hidden_states = layer(hidden_states, input_length)
        hidden_states = self.layer_norm(hidden_states)
        # Zero padding frames, back to channel-first for the deconvs.
        hidden_states = torch.where(attention_mask, hidden_states, 0).permute(0, 2, 1)
        output_features = F.gelu(self.deconv1(hidden_states))
        output_features = F.gelu(self.deconv2(output_features))
        # Trim the transposed-conv overhang to exactly stride_size * T samples.
        expected_length = tgt_len * self.stride_size
        if output_features.size(2) > expected_length:
            output_features = output_features[:, :, :expected_length]
        output_length = input_length * self.stride_size
        return output_features, output_length
321
+
322
+
323
class ResidualDownConv(nn.Module):
    """Downsamples (B, d_model, T) along time by `avg_pooler` using a gated
    strided conv plus a residual frame-stacking path, yielding
    (B, d_model * avg_pooler, T // avg_pooler)."""

    def __init__(self, d_model=1280, avg_pooler=4):
        super().__init__()
        self.d_model, self.avg_pooler = d_model, avg_pooler
        self.intermediate_dim = d_model * avg_pooler
        # Strided convs act as learned pooling over groups of avg_pooler frames.
        self.gate_proj = nn.Conv1d(d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False)
        self.up_proj = nn.Conv1d(d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False)
        self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False)
        self.act_fn = ACT2FN['silu']
        self.layer_norm = nn.LayerNorm(self.intermediate_dim)

    def forward(self, x, input_length):
        """x: (B, d_model, T). Returns the downsampled features (channel-first)
        and the per-item lengths divided by avg_pooler."""
        output_length = input_length // self.avg_pooler
        x = x.transpose(1, 2)  # -> (B, T, d_model)
        batch_size, seq_len, _ = x.shape
        if seq_len % self.avg_pooler != 0:
            # Right-pad time so it divides evenly into avg_pooler groups.
            pad_size = self.avg_pooler - seq_len % self.avg_pooler
            x = F.pad(x, (0, 0, 0, pad_size), "constant", 0)  # Pad sequence dim
        xt = x.permute(0, 2, 1)
        # SwiGLU-style gating computed on the downsampled frames.
        g, u = self.gate_proj(xt).permute(0, 2, 1), self.up_proj(xt).permute(0, 2, 1)
        # Residual path: stack each group of avg_pooler frames into one vector.
        x = x.reshape(batch_size, -1, self.intermediate_dim)
        c = self.down_proj(self.act_fn(g) * u)
        res = self.layer_norm(c + x).transpose(1, 2)
        return res, output_length
347
+
348
+
349
class UpConv(nn.Module):
    """Transposed-convolution upsampler (inverse of ResidualDownConv).

    Maps (B, stride * d_model, T) to (B, d_model, T * stride) with a single
    ConvTranspose1d whose kernel size equals its stride.
    """

    def __init__(self, d_model=1280, stride=4):
        super().__init__()
        self.d_model = d_model
        self.stride = stride
        # kernel == stride => non-overlapping expansion, output exactly stride * T.
        self.up_conv = nn.ConvTranspose1d(self.stride * d_model, d_model, stride, stride, bias=False)

    def forward(self, x, input_length):
        """Upsample `x` and scale each per-item valid length by the stride."""
        upsampled = self.up_conv(x)
        return upsampled, input_length * self.stride
359
+
360
+
361
class Transformer(nn.Module):
    """Plain transformer adapter with optional input/output projections and
    sinusoidal positions; operates on channel-first (B, C, T) features."""

    def __init__(
        self, input_dim=1280, d_model=1280, output_dim=1280, max_source_positions=1500,
        encoder_layers=32, encoder_attention_heads=20, encoder_ffn_dim=5120,
        activation_function="gelu", attn_type="varlen"
    ):
        super().__init__()
        self.input_dim, self.d_model, self.output_dim, self.max_source_positions = input_dim, d_model, output_dim, max_source_positions
        # Projections are only created when the dims actually differ.
        self.proj = nn.Linear(input_dim, d_model, bias=True) if input_dim != d_model else None
        self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model))
        self.layers = nn.ModuleList([
            OmniWhisperTransformerLayer(activation_function, d_model, encoder_attention_heads, encoder_ffn_dim, False, attn_type=attn_type)
            for _ in range(encoder_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model)
        self.out_proj = nn.Linear(d_model, output_dim, bias=True) if output_dim != d_model else None

    def forward(self, input_features, input_length, output_hidden_states=False):
        """input_features: (B, input_dim, T). Returns (hidden_states,
        output_length[, all_hidden]) with hidden_states (B, output_dim, T)."""
        output_length = input_length.long()
        # Optional input projection (applied in time-major layout).
        hidden_states = self.proj(input_features.permute(0, 2, 1)).permute(0, 2, 1) if self.proj else input_features
        hidden_states = hidden_states.permute(0, 2, 1)  # -> (B, T, d_model)
        bsz, tgt_len, _ = hidden_states.size()
        pos_embed = self.positional_embedding[:tgt_len] if tgt_len < self.positional_embedding.shape[0] else self.positional_embedding
        # Add positions in float32 for precision, then cast back.
        hidden_states = (hidden_states.to(torch.float32) + pos_embed).to(hidden_states.dtype)
        attention_mask = get_sequence_mask(hidden_states, output_length)
        all_hidden = () if output_hidden_states else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden += (hidden_states,)
            hidden_states = layer(hidden_states, output_length)
        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            all_hidden += (hidden_states,)
        # Zero padding frames; back to channel-first.
        hidden_states = torch.where(attention_mask, hidden_states, 0).transpose(1, 2)
        if self.out_proj:
            hidden_states = self.out_proj(hidden_states.permute(0, 2, 1)).permute(0, 2, 1)
        if not output_hidden_states:
            return hidden_states, output_length
        return hidden_states, output_length, all_hidden
400
+
401
+
402
+ # Note: The other helper classes like STFT, ISTFT, Vocos, VectorQuantize, etc.,
403
+ # would be placed here. For brevity, they are omitted but are required dependencies.
404
+ # Assuming they are defined in the same way as the user provided code.
405
+ # The code below will assume these classes are defined in the current scope.
406
+ # ... [Paste all other helper classes here] ...
407
class ISTFT(nn.Module):
    """Inverse STFT supporting 'center' and 'same' padding.

    'center' delegates to torch.istft; 'same' performs a manual windowed
    overlap-add (Vocos-style) and trims the padding from both ends.
    """

    def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding, self.n_fft, self.hop_length, self.win_length = padding, n_fft, hop_length, win_length
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """spec: complex (B, n_fft // 2 + 1, T) -> real waveform (B, samples)."""
        if self.padding == "center":
            return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
        elif self.padding == "same":
            # NOTE(review): assumes win_length > hop_length; if pad were 0 the
            # [pad:-pad] slices below would produce empty tensors — confirm
            # callers never configure win_length == hop_length.
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")
        B, N, T = spec.shape
        # Per-frame inverse rFFT, windowed: (B, n_fft, T).
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") * self.window[None, :, None]
        output_size = (T - 1) * self.hop_length + self.win_length

        # Overlap-add via fold, then trim the 'same' padding.
        y = F.fold(ifft, (1, output_size), (1, self.win_length), stride=(1, self.hop_length))[:, 0, 0, pad:-pad]
        # Classic OLA normalization: sum of squared windows at each sample.
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]
        # Guard against division by (near-)zero envelope values.
        assert (window_envelope > 1e-11).all()
        return y / window_envelope
436
+
437
+
438
class FourierHead(nn.Module):
    """Abstract inverse-spectral head: subclasses map features to waveforms."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Must be overridden to produce time-domain audio from `x`."""
        raise NotImplementedError("Subclasses must implement the forward method.")
441
+
442
+
443
class ISTFTHead(FourierHead):
    """Projects hidden states to per-bin magnitude/phase and inverts via ISTFT."""

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        # n_fft + 2 == 2 * (n_fft // 2 + 1): one magnitude and one phase per rFFT bin.
        self.out = nn.Linear(dim, n_fft + 2)
        self.istft = ISTFT(n_fft, hop_length, n_fft, padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, T, dim) hidden states -> (B, samples) waveform in x's dtype."""
        projected = self.out(x).transpose(1, 2)
        mag, phase = projected.chunk(2, dim=1)
        # exp() can blow up in low precision; cap magnitudes at 1e2.
        mag = torch.exp(mag).clip(max=1e2)
        complex_spec = mag.float() * (torch.cos(phase).float() + 1j * torch.sin(phase).float())
        return self.istft(complex_spec).to(x.dtype)
455
+
456
+
457
class AdaLayerNorm(nn.Module):
    """Layer norm whose affine scale/shift are looked up per condition id.

    Scale embeddings are initialized to 1 and shifts to 0, so before training
    the module is exactly F.layer_norm.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = embedding_dim
        self.scale = nn.Embedding(num_embeddings, embedding_dim)
        self.shift = nn.Embedding(num_embeddings, embedding_dim)
        torch.nn.init.ones_(self.scale.weight)
        torch.nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
        """Normalize `x` over its last dim, then apply the conditioned affine."""
        gamma = self.scale(cond_embedding_id)
        beta = self.shift(cond_embedding_id)
        normalized = F.layer_norm(x, (self.dim,), eps=self.eps)
        return normalized * gamma + beta
470
+
471
+
472
class ConvNeXtBlock(nn.Module):
    """ConvNeXt residual block on (B, C, T) sequences: depthwise conv, norm,
    pointwise MLP, optional layer scale, residual add."""

    def __init__(self, dim, intermediate_dim, layer_scale_init_value, adanorm_num_embeddings=None):
        super().__init__()
        # Depthwise 7-tap conv: mixes along time only, channel count unchanged.
        self.dwconv = nn.Conv1d(dim, dim, 7, 1, 3, groups=dim)
        self.adanorm = adanorm_num_embeddings is not None
        if self.adanorm:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.gamma = None

    def forward(self, x, cond_embedding_id=None):
        """x: (B, C, T); cond_embedding_id is only used with AdaLayerNorm."""
        shortcut = x
        hidden = self.dwconv(x).transpose(1, 2)  # -> (B, T, C) for norm/linears
        if self.adanorm:
            hidden = self.norm(hidden, cond_embedding_id)
        else:
            hidden = self.norm(hidden)
        hidden = self.pwconv2(self.act(self.pwconv1(hidden)))
        if self.gamma is not None:
            hidden = self.gamma * hidden
        return shortcut + hidden.transpose(1, 2)
492
+
493
+
494
class Backbone(nn.Module):
    """Abstract vocoder backbone; subclasses implement the feature transform."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Must be overridden by concrete backbones."""
        raise NotImplementedError("Subclasses must implement the forward method.")
497
+
498
+
499
class VocosBackbone(Backbone):
    """ConvNeXt-based backbone of the Vocos vocoder: embeds (B, C, T) features
    and returns (B, T, dim) hidden states."""

    def __init__(self, input_channels, dim, intermediate_dim, num_layers, layer_scale_init_value=None, adanorm_num_embeddings=None):
        super().__init__()
        self.input_channels, self.embed = input_channels, nn.Conv1d(input_channels, dim, 7, 1, 3)
        self.adanorm = adanorm_num_embeddings is not None
        self.norm = AdaLayerNorm(adanorm_num_embeddings, dim) if self.adanorm else nn.LayerNorm(dim, eps=1e-6)
        # Layer scale defaults to 1/num_layers when not given (or falsy).
        self.convnext = nn.ModuleList([ConvNeXtBlock(dim, intermediate_dim, layer_scale_init_value or 1/num_layers, adanorm_num_embeddings) for _ in range(num_layers)])
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal conv/linear weights, zero biases.
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """x: (B, input_channels, T) -> (B, T, dim). Optional kwarg
        `bandwidth_id` conditions the AdaLayerNorms when enabled."""
        x = self.embed(x).transpose(1, 2)
        x = self.norm(x, kwargs.get("bandwidth_id")) if self.adanorm else self.norm(x)
        x = x.transpose(1, 2)
        for block in self.convnext:
            x = block(x, kwargs.get("bandwidth_id"))
        return self.final_layer_norm(x.transpose(1, 2))
522
+
523
+
524
class Vocos(nn.Module):
    """Vocoder: spectral features -> waveform via VocosBackbone + ISTFTHead."""

    def __init__(self, input_channels=128, dim=512, intermediate_dim=4096, num_layers=30, n_fft=640, hop_size=160, padding="same", adanorm_num_embeddings=None):
        super().__init__()
        self.backbone = VocosBackbone(input_channels, dim, intermediate_dim, num_layers, adanorm_num_embeddings=adanorm_num_embeddings)
        self.head = ISTFTHead(dim, n_fft, hop_size, padding)
        self.hop_size = hop_size

    def forward(self, x, input_length):
        """x: (B, input_channels, T). Returns ((B, 1, samples) audio and the
        per-item lengths scaled from frames to samples (* hop_size)."""
        x = self.backbone(x)
        x = self.head(x)
        # Insert a channel dimension for the mono waveform.
        return x[:, None, :], input_length * self.hop_size
535
+
536
+
537
def WNConv1d(*args, **kwargs):
    """Build an nn.Conv1d and wrap it with weight normalization."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
539
+
540
+
541
def ema_inplace(moving_avg, new, decay):
    """In-place EMA update: moving_avg <- decay * moving_avg + (1 - decay) * new."""
    decayed = moving_avg.data.mul_(decay)
    decayed.add_(new.float(), alpha=1 - decay)
543
+
544
+
545
def sample_vectors(samples, num):
    """Randomly pick `num` rows from `samples`; sampling is without replacement
    when enough rows exist, with replacement otherwise. Returns float32."""
    pool_size = samples.shape[0]
    device = samples.device
    if pool_size >= num:
        # Enough rows: a shuffled prefix gives distinct picks.
        chosen = torch.randperm(pool_size, device=device)[:num]
    else:
        # Too few rows: draw indices with replacement.
        chosen = torch.randint(0, pool_size, (num,), device=device)
    return samples[chosen].float()
549
+
550
+
551
def kmeans(samples, num_clusters, num_iters=10):
    """Lloyd's k-means in float32.

    Args:
        samples: (N, dim) points; centers are seeded via sample_vectors.
        num_clusters: number of centroids.
        num_iters: Lloyd iterations.

    Returns:
        (means, counts): (num_clusters, dim) centers and final per-cluster
        assignment counts as floats.
    """
    dim, means = samples.shape[-1], sample_vectors(samples, num_clusters).float()
    for _ in range(num_iters):
        # Negative squared Euclidean distance, so max() selects the nearest center.
        dists = -(samples.float().pow(2).sum(1, keepdim=True) - 2 * samples.float() @ means.t() + means.t().float().pow(2).sum(0, keepdim=True))
        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)  # avoid div-by-zero for empty clusters
        new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32).scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples.float()) / bins_min_clamped[..., None]
        # Empty clusters keep their previous center.
        means = torch.where(zero_mask[..., None], means, new_means)
    dists = -(samples.float().pow(2).sum(1, keepdim=True) - 2 * samples.float() @ means.t() + means.t().float().pow(2).sum(0, keepdim=True))
    return means, torch.bincount(dists.max(dim=-1).indices, minlength=num_clusters).float()
563
+
564
+
565
class VectorQuantize(nn.Module):
    """Single EMA vector quantizer (VQ-VAE style) with optional k-means init,
    dead-code replacement, and distributed-aware statistics."""

    def __init__(self, input_dim, codebook_size, codebook_dim, commitment=1.0, decay=0.99, epsilon=1e-5, threshold_ema_dead=2, kmeans_init=True, kmeans_iters=10):
        super().__init__()
        self.input_dim, self.codebook_size, self.codebook_dim = input_dim, codebook_size, codebook_dim
        self.commitment, self.decay, self.epsilon, self.threshold_ema_dead = commitment, decay, epsilon, threshold_ema_dead
        self.kmeans_init, self.kmeans_iters = kmeans_init, kmeans_iters
        # Project into/out of the (usually smaller) codebook space when dims differ.
        self.in_project = WNConv1d(input_dim, codebook_dim, 1) if input_dim != codebook_dim else nn.Identity()
        self.out_project = WNConv1d(codebook_dim, input_dim, 1) if codebook_dim != input_dim else nn.Identity()
        # Codebook starts at zero when k-means init will fill it on first forward.
        self.register_buffer("codebook", torch.zeros(codebook_size, codebook_dim) if kmeans_init else torch.randn(codebook_size, codebook_dim))
        self.register_buffer("inited", torch.tensor(not kmeans_init, dtype=torch.bool))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed_avg", self.codebook.clone())

    def ema_update(self, encodings, embed_onehot):
        """EMA update of cluster sizes and codebook vectors (all-reduced in DDP)."""
        encodings, embed_onehot = encodings.float(), embed_onehot.float()
        cluster_size_new, embed_sum = embed_onehot.sum(0), encodings.t() @ embed_onehot
        if dist.is_initialized():
            dist.all_reduce(cluster_size_new)
            dist.all_reduce(embed_sum)
        ema_inplace(self.cluster_size, cluster_size_new, self.decay)
        ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
        # Laplace-smoothed cluster sizes keep the division below stable.
        cluster_size = (self.cluster_size + self.epsilon) / (self.cluster_size.sum() + self.codebook_size * self.epsilon) * self.cluster_size.sum()
        self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1))

    def replace_dead_codes(self, encodings):
        """Re-seed codebook entries whose EMA usage fell below the threshold."""
        if self.threshold_ema_dead == 0: return
        dead_mask = self.cluster_size < self.threshold_ema_dead
        if dead_mask.any():
            # Rank 0 samples replacements; others receive them via broadcast.
            samples = sample_vectors(encodings.float(), self.codebook_size) if not dist.is_initialized() or dist.get_rank() == 0 else torch.zeros_like(self.codebook)
            if dist.is_initialized(): dist.broadcast(samples, src=0)
            self.codebook[dead_mask] = samples[:dead_mask.sum()].to(self.codebook.dtype)

    def init_codebook(self, encodings):
        """One-time k-means initialization of the codebook (rank 0, then broadcast)."""
        if self.inited.item(): return
        if not dist.is_initialized() or dist.get_rank() == 0:
            embed, cluster_sizes = kmeans(encodings.float(), self.codebook_size, self.kmeans_iters)
        else:
            embed, cluster_sizes = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device), torch.zeros(self.codebook_size, device=encodings.device)
        if dist.is_initialized():
            dist.broadcast(embed, src=0)
            dist.broadcast(cluster_sizes, src=0)
        self.codebook.copy_(embed)
        self.embed_avg.copy_(embed.clone())
        self.cluster_size.copy_(cluster_sizes)
        self.inited.fill_(True)

    def forward(self, z):
        """Quantize z (B, input_dim, T).

        Returns (z_q, commit_loss, aux_loss=0, indices, z_e); gradients flow
        through the straight-through estimator below.
        """
        z_e = self.in_project(z.float())
        encodings = rearrange(z_e, "b d t -> (b t) d")
        if self.kmeans_init and not self.inited.item(): self.init_codebook(encodings)
        # NOTE(review): this local `dist` shadows the torch.distributed alias
        # used by the methods above (safe here, but easy to trip over).
        dist = encodings.pow(2).sum(1, keepdim=True) - 2 * encodings @ self.codebook.float().t() + self.codebook.float().pow(2).sum(1, keepdim=True).t()
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=z.size(0))
        z_q = self.decode_code(indices)
        commit_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) * self.commitment
        if self.training and torch.is_grad_enabled():
            self.ema_update(encodings, F.one_hot(indices.view(-1), self.codebook_size))
            self.replace_dead_codes(encodings)
        # Straight-through: forward uses z_q, backward sees identity on z_e.
        z_q = self.out_project(z_e + (z_q - z_e).detach())
        return z_q, commit_loss, torch.tensor(0.0, device=z.device), indices, z_e

    def decode_code(self, embed_id):
        """Look up codebook vectors for indices (B, T) -> (B, codebook_dim, T)."""
        return F.embedding(embed_id, self.codebook.float()).transpose(1, 2)
627
+
628
+
629
class ResidualVQ(nn.Module):
    """Residual vector quantizer: a chain of VectorQuantize stages, each
    quantizing the residual left by the previous one, with quantizer dropout
    and optional stochastic RVQ skipping during training."""

    def __init__(
        self,
        input_dim: int = 1280,
        rvq_dim: int = None,
        output_dim: int = None,
        num_quantizers: int = 32,
        codebook_size: int = 1024,
        codebook_dim: int = 8,
        quantizer_dropout: float = 0.5,
        skip_rvq_ratio: float = 0.0,
        vq_config: VectorQuantizerConfig = None,
        **kwargs
    ):
        super().__init__()
        self.input_dim, self.rvq_dim, self.output_dim = input_dim, rvq_dim, output_dim or input_dim
        self.num_quantizers, self.codebook_size, self.codebook_dim = num_quantizers, codebook_size, codebook_dim
        self.quantizer_dropout, self.skip_rvq_ratio = quantizer_dropout, skip_rvq_ratio
        self.input_proj = WNConv1d(input_dim, rvq_dim, 1) if input_dim != rvq_dim else nn.Identity()
        self.output_proj = WNConv1d(rvq_dim, self.output_dim, 1) if rvq_dim != self.output_dim else nn.Identity()
        if vq_config is None:
            vq_config = VectorQuantizerConfig()
        quantizer_kwargs = asdict(vq_config)
        self.quantizers = nn.ModuleList([VectorQuantize(rvq_dim, codebook_size, codebook_dim, **quantizer_kwargs, **kwargs) for _ in range(num_quantizers)])

    def forward(self, z, input_length, n_quantizers: int = None):
        """Quantize z (B, input_dim, T) given per-item valid lengths.

        Returns (quantized_out, indices (nq, B, T), commit_losses (nq,),
        per-stage quantized tensors (nq, B, rvq_dim, T), input_length).
        """
        z = self.input_proj(z)

        # Quantization math is kept in float32 regardless of autocast.
        with torch.autocast('cuda', enabled=False):
            batch_size, _, max_time = z.shape
            device = z.device
            # (B, T) validity mask from per-item lengths.
            mask = torch.arange(max_time, device=device).expand(batch_size, max_time) < input_length.unsqueeze(1)

            quantized_out = torch.zeros_like(z)
            residual = z.clone().float()

            all_commit_losses = []
            all_indices = []
            all_quantized = []

            # Per-item number of active quantizers (dropout) and skip mask.
            n_q_tensor = self._get_n_quantizers_tensor(batch_size, device, n_quantizers)
            skip_mask = self._get_skip_mask(batch_size, device)

            # In training all stages run (dropout is applied per item);
            # in eval the loop itself is truncated.
            max_q_to_run = self.num_quantizers if self.training else (n_quantizers or self.num_quantizers)

            for i, quantizer in enumerate(self.quantizers[:max_q_to_run]):
                # Which batch items still use stage i.
                active_in_iteration_mask = (i < n_q_tensor)

                if not active_in_iteration_mask.any():
                    # Nothing active at this depth: record placeholders.
                    all_commit_losses.append(torch.tensor(0.0, device=device))
                    all_indices.append(torch.zeros(batch_size, max_time, dtype=torch.long, device=device))
                    all_quantized.append(torch.zeros_like(z))
                    continue

                masked_residual = residual * mask.unsqueeze(1)

                z_q_i, commit_loss_i, indices_i = self._quantize_step(quantizer, masked_residual, skip_mask)

                # Update only items active at this depth and within valid length.
                update_mask = (active_in_iteration_mask.view(-1, 1, 1) & mask.unsqueeze(1))

                quantized_out += z_q_i * update_mask
                residual -= z_q_i * update_mask

                # Average commitment loss over active items only.
                commit_loss_i = commit_loss_i[active_in_iteration_mask].mean() if active_in_iteration_mask.any() else torch.tensor(0.0, device=device)

                all_commit_losses.append(commit_loss_i)
                all_indices.append(indices_i)
                all_quantized.append(z_q_i)

            # Pad outputs to num_quantizers when the loop was truncated (eval).
            num_loops_done = len(all_commit_losses)
            if num_loops_done < self.num_quantizers:
                remaining = self.num_quantizers - num_loops_done
                all_commit_losses.extend([torch.tensor(0.0, device=device)] * remaining)
                all_indices.extend([torch.zeros(batch_size, max_time, dtype=torch.long, device=device)] * remaining)
                all_quantized.extend([torch.zeros_like(z)] * remaining)

        quantized_out = self.output_proj(quantized_out)
        all_indices_tensor = torch.stack(all_indices)
        all_commit_losses_tensor = torch.stack(all_commit_losses)
        all_quantized_tensor = torch.stack(all_quantized)

        return (
            quantized_out,
            all_indices_tensor,
            all_commit_losses_tensor,
            all_quantized_tensor,
            input_length,
        )

    def decode_codes(self, codes):
        """Reconstruct features from stacked codes (nq, B, T) by summing each
        stage's codebook lookup, then applying the output projection."""
        nq, B, T = codes.shape
        emb = torch.zeros(B, self.rvq_dim, T, device=codes.device, dtype=torch.float32)
        for i, quantizer in enumerate(self.quantizers[:nq]):
            emb += quantizer.decode_code(codes[i])
        return self.output_proj(emb)

    def _get_n_quantizers_tensor(self, batch_size: int, device: torch.device, n_quantizers_override: Optional[int] = None) -> torch.Tensor:
        """
        Determines the number of quantizers to use for each item in the batch,
        applying dropout during training.
        """
        # If not training or dropout is disabled, use the override or default number of quantizers
        is_training = self.training and torch.is_grad_enabled()
        if not is_training or self.quantizer_dropout == 0:
            num_q = n_quantizers_override or self.num_quantizers
            return torch.full((batch_size,), num_q, dtype=torch.long, device=device)

        # During training, a random subset of items gets a random (shallower) depth.
        n_q_tensor = torch.full((batch_size,), self.num_quantizers, device=device)
        n_dropout = int(batch_size * self.quantizer_dropout)
        if n_dropout > 0:
            dropout_indices = torch.randperm(batch_size, device=device)[:n_dropout]
            dropout_values = torch.randint(1, self.num_quantizers + 1, (n_dropout,), device=device)
            n_q_tensor[dropout_indices] = dropout_values

        return n_q_tensor

    def _get_skip_mask(self, batch_size: int, device: torch.device) -> Optional[torch.Tensor]:
        """Generates a mask for skipping RVQ during training if skip_rvq_ratio > 0."""
        is_training = self.training and torch.is_grad_enabled()
        if not is_training or self.skip_rvq_ratio <= 0:
            return None

        skip_mask = torch.rand(batch_size, device=device) < self.skip_rvq_ratio
        # Ensure at least one sample is not skipped to avoid errors in modules like DDP
        if skip_mask.all():
            skip_mask[0] = False
        return skip_mask

    def _quantize_step(self, quantizer, residual, skip_mask):
        """Helper to perform one step of quantization, handling the skip logic."""
        z_q_i, commit_loss_i, _, indices_i, z_e_i = quantizer(residual.float())

        # For skipped samples the "quantized" output is the residual itself
        # and the commitment loss is zeroed.
        if skip_mask is not None:
            skip_mask_expanded = skip_mask.view(-1, 1, 1)
            z_q_i = torch.where(skip_mask_expanded, residual, z_q_i)
            commit_loss_i = torch.where(skip_mask, torch.zeros_like(commit_loss_i), commit_loss_i)

        return z_q_i, commit_loss_i, indices_i
787
+
788
+
789
+
790
+ # ----------------------------------------------- #
791
+ # PreTrainedModel Base Class #
792
+ # ----------------------------------------------- #
793
class XYTokenizerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    config_class = XYTokenizerConfig
    base_model_prefix = "xy_tokenizer"
    main_input_name = "input_values"
    _supports_grad_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights: normal(0, 0.02) for linear/conv/embedding,
        zero biases, zeroed padding row for embeddings."""
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle gradient checkpointing on the heavy transformer stacks only.
        if isinstance(module, (OmniAudioEncoder, OmniAudioDecoder, Transformer)):
            module.gradient_checkpointing = value
817
+
818
+
819
+ # ----------------------------------------------- #
820
+ # Main Model Class #
821
+ # ----------------------------------------------- #
822
+ class XYTokenizerModel(XYTokenizerPreTrainedModel):
823
    def __init__(self, config: XYTokenizerConfig):
        """Build the full tokenizer pipeline from `config.params`.

        The pipeline is: semantic/acoustic encoders -> pre-RVQ adapter ->
        downsample -> residual VQ -> post-RVQ adapter -> upsample ->
        acoustic decoder -> Vocos vocoder.
        """
        super().__init__(config)
        # Sub-modules are built from nested kwarg dicts stored on the config;
        # a deeper refactor could have them consume the flat config directly.
        self.config = config

        params = config.params
        self.semantic_encoder = OmniAudioEncoder(**params['semantic_encoder_kwargs'])
        self.semantic_encoder_adapter = Transformer(**params['semantic_encoder_adapter_kwargs'])
        self.acoustic_encoder = OmniAudioEncoder(**params['acoustic_encoder_kwargs'])
        self.pre_rvq_adapter = Transformer(**params['pre_rvq_adapter_kwargs'])
        self.downsample = ResidualDownConv(**params['downsample_kwargs'])
        self.quantizer = ResidualVQ(**params['quantizer_kwargs'])
        self.post_rvq_adapter = Transformer(**params['post_rvq_adapter_kwargs'])
        self.upsample = UpConv(**params['upsample_kwargs'])
        self.acoustic_decoder = OmniAudioDecoder(**params['acoustic_decoder_kwargs'])
        self.enhanced_vocos = Vocos(**params['vocos_kwargs'])
        # Plain dict of feature-extractor settings (n_fft, hop_length, ...).
        self.feature_extractor = params['feature_extractor_kwargs']
        # Number of RVQ codebooks, kept handy for encode/decode.
        self.nq = params['quantizer_kwargs']['num_quantizers']

        # Initialize weights and apply final processing
        self.post_init()
847
+
848
+ def _get_feat_extract_output_lengths(self, input_lengths: Optional[torch.Tensor]):
849
+ """
850
+ Computes the output lengths of the feature extractor.
851
+ """
852
+ def _get_out_len(in_len):
853
+ return (in_len - self.feature_extractor["n_fft"]) // self.feature_extractor["hop_length"] + 1
854
+
855
+ if input_lengths is None:
856
+ return None
857
+
858
+ return torch.tensor([_get_out_len(l) for l in input_lengths], device=self.device)
859
+
860
@torch.inference_mode
def encode(
    self,
    features: Union[BatchFeature, ExtractorIterator],
    n_quantizers: Optional[int] = None,
    return_dict: Optional[bool] = True,
) -> Union[XYTokenizerEncodeOutput, Tuple]:
    r"""
    Encodes the input audio features into discrete codes.

    Args:
        features (`BatchFeature` or `ExtractorIterator`):
            A single batch of features, or an iterator that yields batches of chunks
            for long audio files. Each yielded chunk must contain a `chunk_seq_no`
            tensor of shape `(batch_size,)` mapping each item in the chunk back to
            its original sequence.
        n_quantizers (`int`, *optional*):
            The number of quantizers to use. If not specified, all quantizers are used.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
    Returns:
        [`XYTokenizerEncodeOutput`] or `tuple(torch.FloatTensor)`
    """
    assert isinstance(features, (BatchFeature, ExtractorIterator))

    # Single batch: delegate directly to the core encoder.
    if isinstance(features, BatchFeature):
        return self._encode(features, n_quantizers, return_dict)

    # Streaming/chunked case: encode chunk by chunk, then reassemble per sequence.
    # Chunks are grouped by their original sequence ID.
    encodings = defaultdict(lambda: {"zq": [], "codes": [], "length": 0})
    commit_losses = []
    total_frames = 0
    # Valid (non-overlap) code frames per chunk; invariant across the loop, so hoisted.
    code_duration_length = features.code_duration_length

    # 1. Iterate through chunks and store intermediate results.
    for chunk_features in features:
        # Always use return_dict=True for easier access to named outputs.
        chunk_output = self._encode(chunk_features, n_quantizers, return_dict=True)
        valid_code_lengths = torch.clamp(chunk_output.codes_lengths, 0, code_duration_length)

        # Accumulate the commit loss, re-weighted to count only valid frames.
        chunk_length = chunk_output.codes_lengths.sum().item()
        valid_chunk_length = valid_code_lengths.sum().item()
        if chunk_output.commit_loss is not None and valid_chunk_length > 0:
            commit_loss = chunk_output.commit_loss / chunk_length * valid_chunk_length
            commit_losses.append((commit_loss.cpu(), valid_chunk_length))
            total_frames += valid_chunk_length

        # Group results by original sequence ID, keeping only the valid prefix of each chunk.
        for i, seq_id in enumerate(chunk_features["chunk_seq_no"].tolist()):
            valid_code_length = valid_code_lengths[i]
            if valid_code_length > 0:
                # quantized_representation is (B, C, T): keep a (1, C, valid) slice.
                encodings[seq_id]["zq"].append(chunk_output.quantized_representation[i:i + 1, :, :valid_code_length])
                # audio_codes is (nq, B, T): keep a (nq, 1, valid) slice.
                encodings[seq_id]["codes"].append(chunk_output.audio_codes[:, i:i + 1, :valid_code_length])
                # Add the valid length of this chunk to the total for this sequence.
                encodings[seq_id]["length"] += valid_code_length.item()

    # 2. Concatenate every sequence's chunks along the time axis.
    final_outputs = []
    for seq_id, seq_data in encodings.items():
        final_outputs.append({
            "zq": torch.cat(seq_data["zq"], dim=2),
            "codes": torch.cat(seq_data["codes"], dim=2),
            "length": seq_data["length"]
        })

    # 3. Pad all sequences to the same length and stack into a batch.
    max_len = max(seq["zq"].shape[2] for seq in final_outputs)

    batch_zq = []
    batch_codes = []
    batch_lengths = []

    for seq in final_outputs:
        pad_amount = max_len - seq["zq"].shape[2]
        # Pad on the right side of the last dimension (time).
        padded_zq = F.pad(seq["zq"], (0, pad_amount))
        padded_codes = F.pad(seq["codes"], (0, pad_amount))

        batch_zq.append(padded_zq)
        batch_codes.append(padded_codes)
        batch_lengths.append(seq["length"])

    # Stack the per-sequence tensors into batch tensors.
    quantized_representation = torch.cat(batch_zq, dim=0)  # (B, C, T_max)
    # FIX: each entry of batch_codes is (nq, 1, T), so the batch axis is dim=1.
    # Concatenating on dim=0 produced (nq * B, 1, T) for batch_size > 1, which
    # `decode` (indexing audio_codes[:, i, :]) cannot consume.
    audio_codes = torch.cat(batch_codes, dim=1)  # (nq, B, T_max)
    codes_lengths = torch.tensor(batch_lengths, dtype=torch.long, device=self.device)

    # 4. Final commit loss: weighted average over all valid frames.
    if total_frames > 0:
        commit_loss = sum(loss * length for loss, length in commit_losses) / total_frames
        commit_loss = commit_loss.to(self.device)
    else:
        commit_loss = torch.tensor(0.0, device=self.device)

    if not return_dict:
        return (quantized_representation, audio_codes, codes_lengths, commit_loss)

    return XYTokenizerEncodeOutput(
        quantized_representation=quantized_representation,
        audio_codes=audio_codes,
        codes_lengths=codes_lengths,
        commit_loss=commit_loss,
        overlap_seconds=features.overlap_seconds,
    )
966
+
967
def _encode(
    self,
    features: BatchFeature,
    n_quantizers: Optional[int] = None,
    return_dict: Optional[bool] = True,
) -> Union[XYTokenizerEncodeOutput, Tuple]:
    """Run one batch of mel features through the encoder branches and the quantizer."""
    mel = features['input_features'].to(self.device, dtype=self.dtype)
    mel_mask = features['attention_mask'].to(self.device)
    declared_lengths = features['input_lengths'].to(self.device).unsqueeze(1)

    # Effective mel length per item: the smaller of the mask-derived length
    # and the declared input length.
    masked_lengths = mel_mask.sum(dim=-1).long().unsqueeze(1)
    mel_lengths = torch.cat((masked_lengths, declared_lengths), dim=1).min(dim=1).values

    # --- Encoder path: semantic branch + acoustic branch, fused channel-wise ---
    sem_out, sem_len = self.semantic_encoder(mel, mel_lengths)
    sem_adapted, _ = self.semantic_encoder_adapter(sem_out, sem_len)
    aco_out, aco_len = self.acoustic_encoder(mel, mel_lengths)

    fused = torch.cat([sem_adapted, aco_out], dim=1)

    pre_rvq_out, _ = self.pre_rvq_adapter(fused, aco_len)
    down_out, down_len = self.downsample(pre_rvq_out, aco_len)

    # Quantize with the requested number of codebooks (all of them by default).
    n_quantizers = n_quantizers or self.quantizer.num_quantizers
    zq, codes, vq_loss, _, code_len = self.quantizer(down_out, down_len, n_quantizers=n_quantizers)

    if not return_dict:
        return (zq, codes, code_len, vq_loss)

    return XYTokenizerEncodeOutput(
        quantized_representation=zq,
        audio_codes=codes,
        codes_lengths=code_len,
        commit_loss=vq_loss.mean()
    )
1001
+
1002
@torch.inference_mode
def decode(
    self,
    audio_codes: Union[torch.Tensor, XYTokenizerEncodeOutput],
    overlap_seconds: int = 10,
    return_dict: Optional[bool] = True,
) -> Union[XYTokenizerDecodeOutput, Tuple]:
    r"""
    Decodes discrete codes back into an audio waveform, chunk by chunk.

    Args:
        audio_codes (`torch.LongTensor` of shape `(num_codebooks, batch_size, sequence_length)` or `XYTokenizerEncodeOutput`):
            The discrete codes from the quantizer, or the full output of `encode`.
        overlap_seconds (`int`, *optional*, defaults to 10):
            Overlap between decoding windows. When an `XYTokenizerEncodeOutput` carrying
            an `overlap_seconds` value is passed, that value takes precedence.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
    Returns:
        [`XYTokenizerDecodeOutput`] or `tuple(torch.FloatTensor)`
    """
    assert not isinstance(audio_codes, tuple), "try to set param `return_dict=True` for `codec.encode()` function"
    assert isinstance(audio_codes, (torch.Tensor, XYTokenizerEncodeOutput)), \
        "only accept `torch.Tensor` or `XYTokenizerEncodeOutput` for `codec.decode()` function"
    if isinstance(audio_codes, XYTokenizerEncodeOutput):
        # FIX: read overlap_seconds from the encode output BEFORE unwrapping it.
        # Previously the hasattr check ran on the unwrapped tensor, so the value
        # stored by `encode` was always silently ignored.
        if getattr(audio_codes, "overlap_seconds", None) is not None:
            overlap_seconds = audio_codes.overlap_seconds
        audio_codes = audio_codes.audio_codes
    if overlap_seconds is None:
        overlap_seconds = 0

    chunk_length = self.feature_extractor["chunk_length"]
    duration_seconds = chunk_length - overlap_seconds
    # Maximum code length per decode window.
    chunk_code_length = int(chunk_length * self.feature_extractor["sampling_rate"] // self.config.encoder_downsample_rate)
    # Valid (non-overlap) code length per window.
    duration_code_length = int(duration_seconds * self.feature_extractor["sampling_rate"] // self.config.encoder_downsample_rate)
    # Valid waveform samples produced per window.
    duration_wav_length = duration_code_length * self.config.decoder_upsample_rate

    # Right-pad every sequence's codes into one (nq, B, T_max) tensor and
    # remember each sequence's true length.
    batch_size = audio_codes.shape[1]
    codes_list = [audio_codes[:, i, :] for i in range(batch_size)]
    max_code_length = max(codes.shape[-1] for codes in codes_list)
    batch_size = len(codes_list)
    codes_tensor = torch.zeros(self.nq, batch_size, max_code_length, device=self.device, dtype=torch.long)
    code_lengths = torch.zeros(batch_size, dtype=torch.long, device=self.device)
    for i, codes in enumerate(codes_list):
        codes_tensor[:, i, :codes.shape[-1]] = codes.to(self.device)
        code_lengths[i] = codes.shape[-1]  # (B,)

    # Number of sliding windows needed to cover the longest sequence.
    max_chunks = (max_code_length + duration_code_length - 1) // duration_code_length
    wav_list = []

    # Process the entire batch window by window.
    for chunk_idx in range(max_chunks):
        start = chunk_idx * duration_code_length
        end = min(start + chunk_code_length, max_code_length)
        chunk_codes = codes_tensor[:, :, start:end]  # (nq, B, T')
        chunk_code_lengths = torch.clamp(code_lengths - start, 0, end - start)  # (B,)

        # Skip windows past the end of every sequence.
        if chunk_code_lengths.max() == 0:
            continue

        result = self._decode(chunk_codes, chunk_code_lengths)
        chunk_wav = result["audio_values"]  # (B, 1, T')
        chunk_wav_lengths = result["output_length"]  # (B,)

        # Keep only the non-overlap portion of each decoded window.
        valid_wav_lengths = torch.clamp(chunk_wav_lengths, 0, duration_wav_length)  # (B,)
        valid_chunk_wav = torch.zeros(batch_size, 1, duration_wav_length, device=self.device)
        for b in range(batch_size):
            if valid_wav_lengths[b] > 0:
                valid_chunk_wav[b, :, :valid_wav_lengths[b]] = chunk_wav[b, :, :valid_wav_lengths[b]]

        wav_list.append(valid_chunk_wav)  # (B, 1, duration_wav_length)

    # Concatenate all windows and trim each sequence to its true waveform length.
    if wav_list:
        wav_tensor = torch.cat(wav_list, dim=-1)  # (B, 1, T_total)
        syn_wav_list = [wav_tensor[i, :, :code_lengths[i] * self.config.decoder_upsample_rate] for i in range(batch_size)]
    else:
        syn_wav_list = [torch.zeros(1, 0, device=self.device) for _ in range(batch_size)]

    if not return_dict:
        return (syn_wav_list,)

    return XYTokenizerDecodeOutput(
        audio_values=syn_wav_list
    )
1090
+
1091
def _decode(
    self,
    audio_codes: torch.Tensor,
    codes_lengths: Optional[torch.Tensor] = None,
    return_dict: Optional[bool] = True,
) -> Union[XYTokenizerDecodeOutput, Tuple]:
    r"""
    Decodes one window of discrete codes back into a waveform.

    Args:
        audio_codes (`torch.LongTensor` of shape `(num_codebooks, batch_size, sequence_length)`):
            The discrete codes from the quantizer for each codebook.
        codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            The valid length of each sequence in `audio_codes`. Defaults to the full length.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
    Returns:
        [`XYTokenizerDecodeOutput`] or `tuple(torch.FloatTensor)`
    """
    if return_dict is None:
        return_dict = self.config.use_return_dict

    if codes_lengths is None:
        # Assume every sequence in the batch is fully valid.
        codes_lengths = torch.full((audio_codes.shape[1],), audio_codes.shape[2], device=self.device)

    # --- Decoder path: codes -> embeddings -> adapter -> upsample -> vocoder ---
    embeddings = self.quantizer.decode_codes(audio_codes)

    adapted, adapted_len = self.post_rvq_adapter(embeddings, codes_lengths)
    upsampled, upsampled_len = self.upsample(adapted, adapted_len)
    decoded, decoded_len = self.acoustic_decoder(upsampled, upsampled_len)
    waveform, waveform_len = self.enhanced_vocos(decoded, decoded_len)

    if not return_dict:
        return (waveform, waveform_len)

    return XYTokenizerDecodeOutput(
        audio_values=waveform,
        output_length=waveform_len
    )
1130
+
1131
def forward(
    self,
    input_values: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    n_quantizers: Optional[int] = None,
    return_dict: Optional[bool] = True,
) -> Union[XYTokenizerModelOutput, Tuple]:
    r"""
    The forward method that handles the full encoding and decoding process.

    Args:
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of the input audio waveform.
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices.
        n_quantizers (`int`, *optional*):
            The number of quantizers to use for encoding. If not specified, all quantizers are used.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

    Returns:
        [`XYTokenizerModelOutput`] or `tuple(torch.FloatTensor)`
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # NOTE(review): `encode` is declared as `encode(self, features, n_quantizers,
    # return_dict)` taking a `BatchFeature`/`ExtractorIterator` of mel features —
    # it has no `input_values` / `attention_mask` parameters, so this call appears
    # to raise TypeError as written. Verify against the intended feature-extractor
    # pipeline before relying on `forward`.
    encoder_outputs = self.encode(
        input_values=input_values,
        attention_mask=attention_mask,
        n_quantizers=n_quantizers,
        return_dict=True
    )

    # Passing the full encode output lets `decode` pick up `overlap_seconds`.
    decoder_outputs = self.decode(
        audio_codes=encoder_outputs,
        return_dict=True
    )

    if not return_dict:
        # NOTE(review): `decode`'s dict path only populates `audio_values`;
        # `output_length` is presumably None here — confirm.
        return (
            decoder_outputs.audio_values,
            decoder_outputs.output_length,
            encoder_outputs.quantized_representation,
            encoder_outputs.audio_codes,
            encoder_outputs.codes_lengths,
            encoder_outputs.commit_loss
        )

    return XYTokenizerModelOutput(
        audio_values=decoder_outputs.audio_values,
        output_length=decoder_outputs.output_length,
        quantized_representation=encoder_outputs.quantized_representation,
        audio_codes=encoder_outputs.audio_codes,
        codes_lengths=encoder_outputs.codes_lengths,
        commit_loss=encoder_outputs.commit_loss
    )
preprocessor_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "chunk_length": 30,
3
+ "feature_size": 80,
4
+ "hop_length": 160,
5
+ "n_fft": 400,
6
+ "n_samples": 480000,
7
+ "nb_max_frames": 3000,
8
+ "padding_side": "right",
9
+ "padding_value": 0.0,
10
+ "sampling_rate": 16000,
11
+ "encoder_downsample_rate": 1280,
12
+ "return_attention_mask": true,
13
+ "return_tensors": "pt"
14
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f1074eb82317fc9d767e23175ed485b84478841461d5d95451af5d8ec89aaf6
3
+ size 2137329653