Leon299 committed on
Commit
8337fa0
·
verified ·
1 Parent(s): 834120b

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. MuCodec/.gitattributes +2 -0
  2. MuCodec/.gitignore +3 -0
  3. MuCodec/LICENSE +21 -0
  4. MuCodec/LICENSE_weights +399 -0
  5. MuCodec/__pycache__/generate.cpython-310.pyc +0 -0
  6. MuCodec/__pycache__/generate.cpython-312.pyc +0 -0
  7. MuCodec/__pycache__/model.cpython-310.pyc +0 -0
  8. MuCodec/__pycache__/model.cpython-312.pyc +0 -0
  9. MuCodec/configs/models/transformer2D.json +25 -0
  10. MuCodec/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json +14 -0
  11. MuCodec/generate.py +247 -0
  12. MuCodec/libs/rvq/__pycache__/descript_quantize3.cpython-310.pyc +0 -0
  13. MuCodec/libs/rvq/descript_quantize3.py +298 -0
  14. MuCodec/model.py +367 -0
  15. MuCodec/models/attention.py +682 -0
  16. MuCodec/models/transformer_2d_flow.py +545 -0
  17. MuCodec/mp3_to_code.py +187 -0
  18. MuCodec/muq_dev/test.py +22 -0
  19. MuCodec/readme.md +67 -0
  20. MuCodec/requirements.txt +335 -0
  21. MuCodec/tools/get_melvaehifigan48k.py +1551 -0
  22. MuCodec/tools/torch_tools.py +100 -0
  23. __pycache__/audio_tokens.cpython-310.pyc +0 -0
  24. __pycache__/audio_tokens.cpython-312.pyc +0 -0
  25. __pycache__/condition_encoders.cpython-310.pyc +0 -0
  26. __pycache__/condition_encoders.cpython-312.pyc +0 -0
  27. __pycache__/dataset.cpython-310.pyc +0 -0
  28. __pycache__/dataset.cpython-312.pyc +0 -0
  29. __pycache__/decoders.cpython-310.pyc +0 -0
  30. __pycache__/decoders.cpython-312.pyc +0 -0
  31. __pycache__/inference_full.cpython-310.pyc +0 -0
  32. __pycache__/inference_full.cpython-312.pyc +0 -0
  33. __pycache__/modelling_qwen3.cpython-310.pyc +0 -0
  34. __pycache__/modelling_qwen3.cpython-312.pyc +0 -0
  35. __pycache__/runtime_utils.cpython-310.pyc +0 -0
  36. __pycache__/runtime_utils.cpython-312.pyc +0 -0
  37. audio_tokens.py +21 -0
  38. batch_infer_checkpoints.py +402 -0
  39. condition_encoders.py +149 -0
  40. dataset.py +513 -0
  41. decoders.py +158 -0
  42. inference_full.py +1084 -0
  43. modelling_qwen3.py +237 -0
  44. muse_mucodec_chord.ds/dataset_dict.json +1 -0
  45. runtime_utils.py +111 -0
  46. train.py +259 -0
  47. vocab/__init__.py +51 -0
  48. vocab/chord.py +144 -0
  49. vocab/sections.py +105 -0
  50. wandb/debug-cli.root.log +0 -0
MuCodec/.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.pth filter=lfs diff=lfs merge=lfs -text
MuCodec/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__
2
+ *.pt
3
+ *.pth
MuCodec/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) Meta Platforms, Inc. and affiliates.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
MuCodec/LICENSE_weights ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+ Section 1 -- Definitions.
71
+
72
+ a. Adapted Material means material subject to Copyright and Similar
73
+ Rights that is derived from or based upon the Licensed Material
74
+ and in which the Licensed Material is translated, altered,
75
+ arranged, transformed, or otherwise modified in a manner requiring
76
+ permission under the Copyright and Similar Rights held by the
77
+ Licensor. For purposes of this Public License, where the Licensed
78
+ Material is a musical work, performance, or sound recording,
79
+ Adapted Material is always produced where the Licensed Material is
80
+ synched in timed relation with a moving image.
81
+
82
+ b. Adapter's License means the license You apply to Your Copyright
83
+ and Similar Rights in Your contributions to Adapted Material in
84
+ accordance with the terms and conditions of this Public License.
85
+
86
+ c. Copyright and Similar Rights means copyright and/or similar rights
87
+ closely related to copyright including, without limitation,
88
+ performance, broadcast, sound recording, and Sui Generis Database
89
+ Rights, without regard to how the rights are labeled or
90
+ categorized. For purposes of this Public License, the rights
91
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
92
+ Rights.
93
+ d. Effective Technological Measures means those measures that, in the
94
+ absence of proper authority, may not be circumvented under laws
95
+ fulfilling obligations under Article 11 of the WIPO Copyright
96
+ Treaty adopted on December 20, 1996, and/or similar international
97
+ agreements.
98
+
99
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
100
+ any other exception or limitation to Copyright and Similar Rights
101
+ that applies to Your use of the Licensed Material.
102
+
103
+ f. Licensed Material means the artistic or literary work, database,
104
+ or other material to which the Licensor applied this Public
105
+ License.
106
+
107
+ g. Licensed Rights means the rights granted to You subject to the
108
+ terms and conditions of this Public License, which are limited to
109
+ all Copyright and Similar Rights that apply to Your use of the
110
+ Licensed Material and that the Licensor has authority to license.
111
+
112
+ h. Licensor means the individual(s) or entity(ies) granting rights
113
+ under this Public License.
114
+
115
+ i. NonCommercial means not primarily intended for or directed towards
116
+ commercial advantage or monetary compensation. For purposes of
117
+ this Public License, the exchange of the Licensed Material for
118
+ other material subject to Copyright and Similar Rights by digital
119
+ file-sharing or similar means is NonCommercial provided there is
120
+ no payment of monetary compensation in connection with the
121
+ exchange.
122
+
123
+ j. Share means to provide material to the public by any means or
124
+ process that requires permission under the Licensed Rights, such
125
+ as reproduction, public display, public performance, distribution,
126
+ dissemination, communication, or importation, and to make material
127
+ available to the public including in ways that members of the
128
+ public may access the material from a place and at a time
129
+ individually chosen by them.
130
+
131
+ k. Sui Generis Database Rights means rights other than copyright
132
+ resulting from Directive 96/9/EC of the European Parliament and of
133
+ the Council of 11 March 1996 on the legal protection of databases,
134
+ as amended and/or succeeded, as well as other essentially
135
+ equivalent rights anywhere in the world.
136
+
137
+ l. You means the individual or entity exercising the Licensed Rights
138
+ under this Public License. Your has a corresponding meaning.
139
+
140
+ Section 2 -- Scope.
141
+
142
+ a. License grant.
143
+
144
+ 1. Subject to the terms and conditions of this Public License,
145
+ the Licensor hereby grants You a worldwide, royalty-free,
146
+ non-sublicensable, non-exclusive, irrevocable license to
147
+ exercise the Licensed Rights in the Licensed Material to:
148
+
149
+ a. reproduce and Share the Licensed Material, in whole or
150
+ in part, for NonCommercial purposes only; and
151
+
152
+ b. produce, reproduce, and Share Adapted Material for
153
+ NonCommercial purposes only.
154
+
155
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
156
+ Exceptions and Limitations apply to Your use, this Public
157
+ License does not apply, and You do not need to comply with
158
+ its terms and conditions.
159
+
160
+ 3. Term. The term of this Public License is specified in Section
161
+ 6(a).
162
+
163
+ 4. Media and formats; technical modifications allowed. The
164
+ Licensor authorizes You to exercise the Licensed Rights in
165
+ all media and formats whether now known or hereafter created,
166
+ and to make technical modifications necessary to do so. The
167
+ Licensor waives and/or agrees not to assert any right or
168
+ authority to forbid You from making technical modifications
169
+ necessary to exercise the Licensed Rights, including
170
+ technical modifications necessary to circumvent Effective
171
+ Technological Measures. For purposes of this Public License,
172
+ simply making modifications authorized by this Section 2(a)
173
+ (4) never produces Adapted Material.
174
+
175
+ 5. Downstream recipients.
176
+
177
+ a. Offer from the Licensor -- Licensed Material. Every
178
+ recipient of the Licensed Material automatically
179
+ receives an offer from the Licensor to exercise the
180
+ Licensed Rights under the terms and conditions of this
181
+ Public License.
182
+
183
+ b. No downstream restrictions. You may not offer or impose
184
+ any additional or different terms or conditions on, or
185
+ apply any Effective Technological Measures to, the
186
+ Licensed Material if doing so restricts exercise of the
187
+ Licensed Rights by any recipient of the Licensed
188
+ Material.
189
+
190
+ 6. No endorsement. Nothing in this Public License constitutes or
191
+ may be construed as permission to assert or imply that You
192
+ are, or that Your use of the Licensed Material is, connected
193
+ with, or sponsored, endorsed, or granted official status by,
194
+ the Licensor or others designated to receive attribution as
195
+ provided in Section 3(a)(1)(A)(i).
196
+
197
+ b. Other rights.
198
+
199
+ 1. Moral rights, such as the right of integrity, are not
200
+ licensed under this Public License, nor are publicity,
201
+ privacy, and/or other similar personality rights; however, to
202
+ the extent possible, the Licensor waives and/or agrees not to
203
+ assert any such rights held by the Licensor to the limited
204
+ extent necessary to allow You to exercise the Licensed
205
+ Rights, but not otherwise.
206
+
207
+ 2. Patent and trademark rights are not licensed under this
208
+ Public License.
209
+
210
+ 3. To the extent possible, the Licensor waives any right to
211
+ collect royalties from You for the exercise of the Licensed
212
+ Rights, whether directly or through a collecting society
213
+ under any voluntary or waivable statutory or compulsory
214
+ licensing scheme. In all other cases the Licensor expressly
215
+ reserves any right to collect such royalties, including when
216
+ the Licensed Material is used other than for NonCommercial
217
+ purposes.
218
+
219
+ Section 3 -- License Conditions.
220
+
221
+ Your exercise of the Licensed Rights is expressly made subject to the
222
+ following conditions.
223
+
224
+ a. Attribution.
225
+
226
+ 1. If You Share the Licensed Material (including in modified
227
+ form), You must:
228
+
229
+ a. retain the following if it is supplied by the Licensor
230
+ with the Licensed Material:
231
+
232
+ i. identification of the creator(s) of the Licensed
233
+ Material and any others designated to receive
234
+ attribution, in any reasonable manner requested by
235
+ the Licensor (including by pseudonym if
236
+ designated);
237
+
238
+ ii. a copyright notice;
239
+
240
+ iii. a notice that refers to this Public License;
241
+
242
+ iv. a notice that refers to the disclaimer of
243
+ warranties;
244
+
245
+ v. a URI or hyperlink to the Licensed Material to the
246
+ extent reasonably practicable;
247
+
248
+ b. indicate if You modified the Licensed Material and
249
+ retain an indication of any previous modifications; and
250
+
251
+ c. indicate the Licensed Material is licensed under this
252
+ Public License, and include the text of, or the URI or
253
+ hyperlink to, this Public License.
254
+
255
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
256
+ reasonable manner based on the medium, means, and context in
257
+ which You Share the Licensed Material. For example, it may be
258
+ reasonable to satisfy the conditions by providing a URI or
259
+ hyperlink to a resource that includes the required
260
+ information.
261
+
262
+ 3. If requested by the Licensor, You must remove any of the
263
+ information required by Section 3(a)(1)(A) to the extent
264
+ reasonably practicable.
265
+
266
+ 4. If You Share Adapted Material You produce, the Adapter's
267
+ License You apply must not prevent recipients of the Adapted
268
+ Material from complying with this Public License.
269
+
270
+ Section 4 -- Sui Generis Database Rights.
271
+
272
+ Where the Licensed Rights include Sui Generis Database Rights that
273
+ apply to Your use of the Licensed Material:
274
+
275
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276
+ to extract, reuse, reproduce, and Share all or a substantial
277
+ portion of the contents of the database for NonCommercial purposes
278
+ only;
279
+
280
+ b. if You include all or a substantial portion of the database
281
+ contents in a database in which You have Sui Generis Database
282
+ Rights, then the database in which You have Sui Generis Database
283
+ Rights (but not its individual contents) is Adapted Material; and
284
+
285
+ c. You must comply with the conditions in Section 3(a) if You Share
286
+ all or a substantial portion of the contents of the database.
287
+
288
+ For the avoidance of doubt, this Section 4 supplements and does not
289
+ replace Your obligations under this Public License where the Licensed
290
+ Rights include other Copyright and Similar Rights.
291
+
292
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293
+
294
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304
+
305
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314
+
315
+ c. The disclaimer of warranties and limitation of liability provided
316
+ above shall be interpreted in a manner that, to the extent
317
+ possible, most closely approximates an absolute disclaimer and
318
+ waiver of all liability.
319
+
320
+ Section 6 -- Term and Termination.
321
+
322
+ a. This Public License applies for the term of the Copyright and
323
+ Similar Rights licensed here. However, if You fail to comply with
324
+ this Public License, then Your rights under this Public License
325
+ terminate automatically.
326
+
327
+ b. Where Your right to use the Licensed Material has terminated under
328
+ Section 6(a), it reinstates:
329
+
330
+ 1. automatically as of the date the violation is cured, provided
331
+ it is cured within 30 days of Your discovery of the
332
+ violation; or
333
+
334
+ 2. upon express reinstatement by the Licensor.
335
+
336
+ For the avoidance of doubt, this Section 6(b) does not affect any
337
+ right the Licensor may have to seek remedies for Your violations
338
+ of this Public License.
339
+
340
+ c. For the avoidance of doubt, the Licensor may also offer the
341
+ Licensed Material under separate terms or conditions or stop
342
+ distributing the Licensed Material at any time; however, doing so
343
+ will not terminate this Public License.
344
+
345
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346
+ License.
347
+
348
+ Section 7 -- Other Terms and Conditions.
349
+
350
+ a. The Licensor shall not be bound by any additional or different
351
+ terms or conditions communicated by You unless expressly agreed.
352
+
353
+ b. Any arrangements, understandings, or agreements regarding the
354
+ Licensed Material not stated herein are separate from and
355
+ independent of the terms and conditions of this Public License.
356
+
357
+ Section 8 -- Interpretation.
358
+
359
+ a. For the avoidance of doubt, this Public License does not, and
360
+ shall not be interpreted to, reduce, limit, restrict, or impose
361
+ conditions on any use of the Licensed Material that could lawfully
362
+ be made without permission under this Public License.
363
+
364
+ b. To the extent possible, if any provision of this Public License is
365
+ deemed unenforceable, it shall be automatically reformed to the
366
+ minimum extent necessary to make it enforceable. If the provision
367
+ cannot be reformed, it shall be severed from this Public License
368
+ without affecting the enforceability of the remaining terms and
369
+ conditions.
370
+
371
+ c. No term or condition of this Public License will be waived and no
372
+ failure to comply consented to unless expressly agreed to by the
373
+ Licensor.
374
+
375
+ d. Nothing in this Public License constitutes or may be interpreted
376
+ as a limitation upon, or waiver of, any privileges and immunities
377
+ that apply to the Licensor or You, including from the legal
378
+ processes of any jurisdiction or authority.
379
+
380
+ =======================================================================
381
+
382
+ Creative Commons is not a party to its public
383
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
384
+ its public licenses to material it publishes and in those instances
385
+ will be considered the “Licensor.” The text of the Creative Commons
386
+ public licenses is dedicated to the public domain under the CC0 Public
387
+ Domain Dedication. Except for the limited purpose of indicating that
388
+ material is shared under a Creative Commons public license or as
389
+ otherwise permitted by the Creative Commons policies published at
390
+ creativecommons.org/policies, Creative Commons does not authorize the
391
+ use of the trademark "Creative Commons" or any other trademark or logo
392
+ of Creative Commons without its prior written consent including,
393
+ without limitation, in connection with any unauthorized modifications
394
+ to any of its public licenses or any other arrangements,
395
+ understandings, or agreements concerning use of licensed material. For
396
+ the avoidance of doubt, this paragraph does not form part of the
397
+ public licenses.
398
+
399
+ Creative Commons may be contacted at creativecommons.org.
MuCodec/__pycache__/generate.cpython-310.pyc ADDED
Binary file (8.18 kB). View file
 
MuCodec/__pycache__/generate.cpython-312.pyc ADDED
Binary file (18 kB). View file
 
MuCodec/__pycache__/model.cpython-310.pyc ADDED
Binary file (12.2 kB). View file
 
MuCodec/__pycache__/model.cpython-312.pyc ADDED
Binary file (21.9 kB). View file
 
MuCodec/configs/models/transformer2D.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "Transformer2DModel",
3
+ "activation_fn": "gelu-approximate",
4
+ "attention_bias": true,
5
+ "attention_head_dim": 72,
6
+ "attention_type": "default",
7
+ "cross_attention_dim": null,
8
+ "double_self_attention": false,
9
+ "dropout": 0.0,
10
+ "in_channels": 96,
11
+ "norm_elementwise_affine": false,
12
+ "norm_eps": 1e-06,
13
+ "norm_num_groups": 32,
14
+ "norm_type": "ada_norm_single",
15
+ "num_attention_heads": 22,
16
+ "num_embeds_ada_norm": 1000,
17
+ "num_layers": 24,
18
+ "num_vector_embeds": null,
19
+ "only_cross_attention": false,
20
+ "out_channels": 32,
21
+ "patch_size": 2,
22
+ "sample_size": 384,
23
+ "upcast_attention": false,
24
+ "use_linear_projection": false
25
+ }
MuCodec/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.8.0",
4
+ "beta_end": 0.02,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.0015,
7
+ "clip_sample": false,
8
+ "num_train_timesteps": 1000,
9
+ "prediction_type": "sample",
10
+ "set_alpha_to_one": false,
11
+ "skip_prk_steps": true,
12
+ "steps_offset": 1,
13
+ "trained_betas": null
14
+ }
MuCodec/generate.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ from tqdm import tqdm
4
+ import sys
5
+ from model import PromptCondAudioDiffusion
6
+ from diffusers import DDIMScheduler, DDPMScheduler
7
+ import torchaudio
8
+ import librosa
9
+ import os
10
+ import math
11
+ import numpy as np
12
+ from tools.get_melvaehifigan48k import build_pretrained_models
13
+ import tools.torch_tools as torch_tools
14
+ from safetensors.torch import load_file
15
+
16
+ class MuCodec:
17
+ def __init__(self, \
18
+ model_path, \
19
+ layer_num, \
20
+ load_main_model=True, \
21
+ device="cuda:0"):
22
+
23
+ self.layer_num = layer_num - 1
24
+ self.sample_rate = 48000
25
+ self.device = device
26
+
27
+ self.MAX_DURATION = 360
28
+ if load_main_model:
29
+ audio_ldm_path = os.path.dirname(os.path.abspath(__file__)) + "/tools/audioldm_48k.pth"
30
+ self.vae, self.stft = build_pretrained_models(audio_ldm_path)
31
+ self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
32
+ main_config = {
33
+ "num_channels":32,
34
+ "unet_model_name":None,
35
+ "unet_model_config_path":os.path.dirname(os.path.abspath(__file__)) + "/configs/models/transformer2D.json",
36
+ "snr_gamma":None,
37
+ }
38
+ self.model = PromptCondAudioDiffusion(**main_config)
39
+ if model_path.endswith('.safetensors'):
40
+ main_weights = load_file(model_path)
41
+ else:
42
+ main_weights = torch.load(model_path, map_location='cpu')
43
+ self.model.load_state_dict(main_weights, strict=False)
44
+ self.model = self.model.to(device)
45
+ print ("Successfully loaded checkpoint from:", model_path)
46
+ else:
47
+ main_config = {
48
+ "num_channels":32,
49
+ "unet_model_name":None,
50
+ "unet_model_config_path":None,
51
+ "snr_gamma":None,
52
+ }
53
+ self.model = PromptCondAudioDiffusion(**main_config).to(device)
54
+ main_weights = torch.load(model_path, map_location='cpu')
55
+ self.model.load_state_dict(main_weights, strict=False)
56
+ self.model = self.model.to(device)
57
+ print ("Successfully loaded checkpoint from:", model_path)
58
+
59
+ self.model.eval()
60
+ self.model.init_device_dtype(torch.device(device), torch.float32)
61
+ print("scaling factor: ", self.model.normfeat.std)
62
+
63
+ def file2code(self, fname):
64
+ orig_samples, fs = torchaudio.load(fname)
65
+ if(fs!=self.sample_rate):
66
+ orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
67
+ fs = self.sample_rate
68
+ if orig_samples.shape[0] == 1:
69
+ orig_samples = torch.cat([orig_samples, orig_samples], 0)
70
+ return self.sound2code(orig_samples)
71
+
72
    @torch.no_grad()
    @torch.autocast(device_type="cuda", dtype=torch.float32)
    def sound2code(self, orig_samples, batch_size=3):
        """Encode stereo waveform samples into discrete codec codes.

        Accepts a (channels, samples) or (batch, channels, samples) tensor,
        tiles it to a whole number of ~40.96 s windows, encodes each window
        in batches through the model's quantizer, and returns codes trimmed
        to the original audio length (25 code frames per second, plus one).
        """
        if(orig_samples.ndim == 2):
            audios = orig_samples.unsqueeze(0).to(self.device)
        elif(orig_samples.ndim == 3):
            audios = orig_samples.to(self.device)
        else:
            # Unreachable guard: only 2-D / 3-D inputs are supported.
            assert orig_samples.ndim in (2,3), orig_samples.shape
        audios = self.preprocess_audio(audios)
        audios = audios.squeeze(0)
        orig_length = audios.shape[-1]
        # One encoding window is 40.96 s of audio; 480 extra samples of
        # lookahead are appended to each window below.
        min_samples = int(40.96 * self.sample_rate)
        # 25 code frames per second of input, plus one trailing frame.
        output_len = int(orig_length / float(self.sample_rate) * 25) + 1
        print("output_len: ", output_len)

        # Tile short clips until at least one full window (+ lookahead) exists.
        while(audios.shape[-1] < min_samples + 480):
            audios = torch.cat([audios, audios], -1)
        int_max_len=audios.shape[-1]//min_samples+1
        # print("int_max_len: ", int_max_len)
        # Double once more so truncation below always has enough samples.
        audios = torch.cat([audios, audios], -1)
        # print("audios:",audios.shape)
        audios=audios[:,:int(int_max_len*(min_samples+480))]
        codes_list=[]

        # Split the stereo stream into windows: (2, n_win, W) -> (n_win, 2, W).
        audio_input = audios.reshape(2, -1, min_samples+480).permute(1, 0, 2).reshape(-1, 2, min_samples+480)

        for audio_inx in range(0, audio_input.shape[0], batch_size):
            # import pdb; pdb.set_trace()
            # layer=self.layer_num selects how many RVQ layers to fetch.
            codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num)
            codes_list.append(torch.cat(codes, 1))
            # print("codes_list",codes_list[0].shape)

        # Reassemble windows into one sequence, then trim the padding
        # introduced by tiling back to the true output length.
        codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(1, -1)[None] # B 3 T -> 3 B T
        codes=codes[:,:,:output_len]

        return codes
109
+
110
    @torch.no_grad()
    def code2sound(self, codes, prompt=None, duration=40.96, guidance_scale=1.5, num_steps=20, disable_progress=False):
        """Decode RVQ codes (B x N x T) back into a stereo waveform.

        Optionally conditions the first window on an audio `prompt`: the prompt
        is encoded to VAE latents and codes, prepended, and later cut back off.
        Long code streams are decoded in overlapping windows (75% hop) and the
        waveforms are stitched with a linear crossfade.

        NOTE(review): the `guidance_scale` argument is ignored here; 1.5 is
        hard-coded in the inference_codes calls below — confirm intended.
        """
        codes = codes.to(self.device)
        # Noise-initialized latent canvas for the first window (B, 32, 512, 32).
        first_latent = torch.randn(codes.shape[0], 32, 512, 32).to(self.device)
        first_latent_length = 0
        first_latent_codes_length = 0
        if(isinstance(prompt, torch.Tensor)):
            # Normalize the prompt to shape (2, samples): drop a singleton batch
            # dim, or duplicate a mono channel into stereo.
            prompt = prompt.to(self.device)
            if(prompt.ndim == 3):
                assert prompt.shape[0] == 1, prompt.shape
                prompt = prompt[0]
            elif(prompt.ndim == 1):
                prompt = prompt.unsqueeze(0).repeat(2,1)
            elif(prompt.ndim == 2):
                if(prompt.shape[0] == 1):
                    prompt = prompt.repeat(2,1)

            # Take a 10.24 s excerpt: the head for short prompts, otherwise the
            # 20.48–30.72 s region.
            if(prompt.shape[-1] < int(30.76 * self.sample_rate)):
                prompt = prompt[:,:int(10.24*self.sample_rate)] # limit max length to 10.24
            else:
                prompt = prompt[:,int(20.48*self.sample_rate):int(30.72*self.sample_rate)] # limit max length to 10.24

            # Encode the prompt to mel, then per-channel VAE latents, then fold
            # the two channels into the channel dim of a single latent.
            true_mel , _, _ = torch_tools.wav_to_fbank2(prompt, -1, fn_STFT=self.stft) # maximum 10.24s
            true_mel = true_mel.unsqueeze(1).to(self.device)
            true_latent = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(true_mel[[m]])) for m in range(true_mel.shape[0])],0)
            true_latent = true_latent.reshape(true_latent.shape[0]//2, -1, true_latent.shape[2], true_latent.shape[3]).detach()

            first_latent[:,:,0:true_latent.shape[2],:] = true_latent
            first_latent_length = true_latent.shape[2]
            # Prepend the prompt's codes so the first window is continuation,
            # not free generation. 2 code frames per latent frame (see reshape
            # in inference_codes) — TODO confirm.
            first_latent_codes = self.sound2code(prompt)[:,:,0:first_latent_length*2] # B 4 T
            first_latent_codes_length = first_latent_codes.shape[-1]
            codes = torch.cat([first_latent_codes, codes], -1)

        # Windowing parameters in code frames: 1024-frame windows, 75% hop,
        # 25% overlap; latent frames are half the code frame count.
        min_samples = 1024
        hop_samples = min_samples // 4 * 3
        ovlp_samples = min_samples - hop_samples
        hop_frames = hop_samples // 2
        ovlp_frames = ovlp_samples // 2

        codes_len= codes.shape[-1]
        # Output length in waveform samples, excluding the prompt's codes.
        # Assumes 100 code frames per 4 s of audio — TODO confirm rate.
        target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)

        # Tile the code stream by self-repetition until it fills whole windows.
        if(codes_len < min_samples):
            while(codes.shape[-1] < min_samples):
                codes = torch.cat([codes, codes], -1)
            codes = codes[:,:,0:min_samples]
        codes_len = codes.shape[-1]
        if((codes_len - ovlp_frames) % hop_samples > 0):
            len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
            while(codes.shape[-1] < len_codes):
                codes = torch.cat([codes, codes], -1)
            codes = codes[:,:,0:len_codes]
        latent_length = 512
        latent_list = []
        # Dummy speaker embedding (unused downstream) — TODO confirm.
        spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
                codes_input=[]
                codes_input.append(codes[:,:,sinx:sinx+min_samples])
                if(sinx == 0):
                    # First window: in-context frames come from the prompt latent
                    # (0 when there is no prompt).
                    incontext_length = first_latent_length
                    latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
                    latent_list.append(latents)
                else:
                    # Later windows: seed with the overlap tail of the previous
                    # latent, padded with noise up to the 512-frame window.
                    true_latent = latent_list[-1][:,:,-ovlp_frames:,:]
                    len_add_to_512 = 512 - true_latent.shape[-2]
                    incontext_length = true_latent.shape[-2]
                    true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], true_latent.shape[1], len_add_to_512, true_latent.shape[-1]).to(self.device)], -2)
                    latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
                    latent_list.append(latents)

        latent_list = [l.float() for l in latent_list]
        # Drop the prompt-conditioning frames from the first window.
        latent_list[0] = latent_list[0][:,:,first_latent_length:,:]
        # Re-derive the windowing parameters in waveform samples for overlap-add.
        min_samples = int(duration * self.sample_rate)
        hop_samples = min_samples // 4 * 3
        ovlp_samples = min_samples - hop_samples
        with torch.no_grad():
            output = None
            for i in range(len(latent_list)):
                latent = latent_list[i]
                bsz , ch, t, f = latent.shape
                # Unfold the stereo pair back out of the channel dim before the
                # per-channel VAE decode.
                latent = latent.reshape(bsz*2, ch//2, t, f)
                mel = self.vae.decode_first_stage(latent)
                cur_output = self.vae.decode_to_waveform(mel)
                cur_output = torch.from_numpy(cur_output)[:, 0:min_samples]

                if output is None:
                    output = cur_output
                else:
                    # Linear crossfade: previous tail ramps 1->0, new head 0->1.
                    ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
                    ov_win = torch.cat([ov_win, 1 - ov_win], -1)
                    output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
                    output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
            output = output[:, 0:target_len]
        return output
205
+
206
+ @torch.no_grad()
207
+ def preprocess_audio(self, input_audios, threshold=0.8):
208
+ assert len(input_audios.shape) == 3, input_audios.shape
209
+ nchan = input_audios.shape[1]
210
+ input_audios = input_audios.reshape(input_audios.shape[0], -1)
211
+ norm_value = torch.ones_like(input_audios[:,0])
212
+ max_volume = input_audios.abs().max(dim=-1)[0]
213
+ norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
214
+ return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
215
+
216
+ @torch.no_grad()
217
+ def sound2sound(self, sound, prompt=None, min_duration=40.96, steps=50, disable_progress=False):
218
+ codes = self.sound2code(sound)
219
+ wave = self.code2sound(codes, prompt, duration=min_duration, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
220
+ return wave
221
+
222
if __name__=="__main__":
    # Demo driver: reconstruct every audio file in ./test_wav through the codec
    # and write the results to ./reconstructed.
    ckpt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ckpt/mucodec.pt")
    mucodec = MuCodec(model_path=ckpt_path,layer_num=7,load_main_model=True)

    filelist = []

    root_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_wav")
    for f in [os.path.join(root_dir, f) for f in os.listdir(root_dir) if '.flac' in f or '.wav' in f or '.mp3' in f]:
        a, fs = torchaudio.load(f)
        # The codec expects 48 kHz stereo: resample and duplicate mono channels.
        if(fs!=48000):
            a = torchaudio.functional.resample(a, fs, 48000)
        if(a.shape[0]==1):
            a = torch.cat([a,a],0)
        ori_len = a.shape[-1]
        # Entry layout: [audio, lyric (unused), [start, end] seconds, path, length].
        filelist.append([a, '', [0, a.shape[-1]/48000.], f,ori_len])

    reconstructed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "reconstructed")

    os.makedirs(reconstructed_dir, exist_ok=True)

    for sample_idx, (orig_samples, lyric, st_et, fname,ori_len) in enumerate(filelist):
        print(fname, lyric)
        wave = mucodec.sound2sound(orig_samples,None)
        # Trim the padding introduced by window tiling back to the input length.
        wave = wave[:,0:ori_len]
        torchaudio.save(os.path.join(reconstructed_dir, os.path.basename(fname)),wave.detach().cpu(), 48000)
247
+
MuCodec/libs/rvq/__pycache__/descript_quantize3.cpython-310.pyc ADDED
Binary file (10 kB). View file
 
MuCodec/libs/rvq/descript_quantize3.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
def WNConv1d(*args, **kwargs):
    """Construct an ``nn.Conv1d`` (same arguments) wrapped with weight normalization."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
12
+
13
class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
           for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
           improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # Factorized codes: the nearest-neighbour search happens in the
        # low-dimensional codebook space; 1x1 convs map in and out of it.
        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)
        # Per-code count of consecutive training batches without a single hit;
        # codes reaching `stale_tolerance` are re-seeded (see decode_latents).
        self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
        self.stale_tolerance = stale_tolerance

    def forward(self, z):
        """Quantized the input tensor using a fixed codebook and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[B]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries (per batch item; reduced over channel and time)
        Tensor[B]
            Codebook loss to update the codebook (per batch item)
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x D x T)
        z_q, indices = self.decode_latents(z_e)

        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        z_q = (
            z_e + (z_q - z_e).detach()
        )  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        # Look up codebook vectors for given indices: (B, T) -> (B, T, D).
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        # Channel-first lookup: (B, T) -> (B, D, T).
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        """Nearest-neighbour quantization of `latents` (B x D x T).

        Returns the (un-normalized) codebook vectors (B x D x T) and the chosen
        indices (B x T). During training, also re-seeds codes that have been
        unused for `stale_tolerance` consecutive batches with random batch
        encodings, to combat codebook collapse.
        """
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)

        if(self.training):
            onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float()  # B, T, codebook_size
            # 1.0 for codes never selected in this batch, 0.0 otherwise.
            stale_codes = (onehots.sum(0).sum(0) == 0).float()
            # Increment counters of unused codes; reset counters of used ones.
            self.stale_counter = self.stale_counter * stale_codes + stale_codes

            # random replace codes that haven't been used for a while
            replace_code = (self.stale_counter == self.stale_tolerance).float()  # codebook_size
            if replace_code.sum(-1) > 0:
                print("Replace {} codes".format(replace_code.sum(-1)))
                # Draw replacements from a random permutation of this batch's
                # normalized encodings, tiling when the batch is smaller than
                # the codebook.
                random_input_idx = torch.randperm(encodings.shape[0])
                random_input = encodings[random_input_idx].view(encodings.shape)
                if random_input.shape[0] < self.codebook_size:
                    random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
                random_input = random_input[:self.codebook_size,:].contiguous()  # codebook_size, dim

                self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
                self.stale_counter = self.stale_counter * (1 - replace_code)

        return z_q, indices
116
+
117
+
118
class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312

    Chains `n_codebooks` VectorQuantize stages; each stage quantizes the
    residual left by the previous stages, so the sum of stage outputs
    approximates the input increasingly well.
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
        stale_tolerance: int = 100,
    ):
        super().__init__()
        # A single int codebook_dim applies to every stage.
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantized the input tensor using a fixed set of `n` codebooks and returns
        the corresponding codebook vectors
        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if `self.quantizer_dropout` is True, this argument is ignored
            when in training mode, and a random number of quantizers is used.
        Returns
        -------
        tuple
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
            rvq_usage : LongTensor[B]
                Index of the last quantizer applied per batch item
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            # Quantizer dropout: the first n_dropout batch items use a random
            # number of codebooks; the rest use all of them.
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)
        else:
            n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            # if self.training is False and i >= n_quantizers:
            #     break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            # Sum losses
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        # Debug aid: report per-layer codebook utilization.
        # Fix: the original message read "Lyaer" — corrected to "Layer".
        encodings = F.one_hot(codes, self.codebook_size).float()  # B N T 1024
        for n in range(encodings.shape[1]):
            print("Layer {}, Ratio of unused vector : {:.1f}".format(n,
                (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
            ))

        return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation
        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input
        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[B x sum(codebook_dim) x T]
            Concatenated per-stage codebook vectors
        Tensor[B x N x T]
            The input codes, passed through
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N x T]
            Continuous representation of input after projection

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x D x T]
            Quantized representation of latent space
        Tensor[B x N x T]
            Codebook indices per stage
        """
        z_q = 0
        z_p = []
        codes = []
        # Channel offsets of each stage's slice within `latents`.
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        # Use only as many codebooks as fully fit in the provided channels.
        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
            0
        ]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
287
+
288
+
289
if __name__ == "__main__":
    # Smoke test: quantize a random batch and print output shapes, then form
    # the two typical training losses.
    rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0)
    x = torch.randn(16, 1024, 80)
    quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x)
    print(quantized_prompt_embeds.shape)
    print(codes.shape)
    # w/o reconstruction
    loss = commitment_loss * 0.25 + codebook_loss * 1.0
    # w/ reconstruction
    loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
MuCodec/model.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ import random
3
+ import inspect
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ import typing as tp
7
+ from abc import ABC
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torchaudio
13
+
14
+ from einops import repeat
15
+ from tools.torch_tools import wav_to_fbank
16
+ import os
17
+ import diffusers
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+ from diffusers import DDPMScheduler
20
+ from models.transformer_2d_flow import Transformer2DModel
21
+ from libs.rvq.descript_quantize3 import ResidualVectorQuantize
22
+ from torch.cuda.amp import autocast
23
+ from muq_dev.test import load_model
24
+
25
+
26
+
27
+
28
class SampleProcessor(torch.nn.Module):
    """Identity pre/post-processor base class.

    Subclasses override both hooks to map samples into (and back out of) the
    space where diffusion runs; this base implementation changes nothing.
    """

    def project_sample(self, x: torch.Tensor):
        """Project the original sample to the 'space' where the diffusion will happen."""
        return x

    def return_sample(self, z: torch.Tensor):
        """Project back from diffusion space to the actual sample space."""
        return z
36
+
37
class Feature2DProcessor(SampleProcessor):
    """Running-statistics whitener for 4-D features shaped (B, dim, T, 32).

    Accumulates per-(channel, freq-bin) mean/std over batches seen through
    `project_sample` (until `num_samples` items), normalizes inputs with
    (x - mean) * (target_std / std) ** power_std, and inverts the mapping in
    `return_sample`. Statistics live in buffers so they persist in checkpoints.
    """
    def __init__(self, dim: int = 8, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1., \
        num_samples: int = 100_000):
        super().__init__()
        self.num_samples = num_samples
        self.dim = dim
        self.power_std = power_std
        self.register_buffer('counts', torch.zeros(1))
        self.register_buffer('sum_x', torch.zeros(dim, 32))
        self.register_buffer('sum_x2', torch.zeros(dim, 32))
        # NOTE(review): sum_target_x2 is registered but never updated or read
        # in this class — possibly kept for checkpoint compatibility; confirm.
        self.register_buffer('sum_target_x2', torch.zeros(dim, 32))
        self.counts: torch.Tensor
        self.sum_x: torch.Tensor
        self.sum_x2: torch.Tensor

    @property
    def mean(self):
        # Running mean per (dim, 32) cell.
        mean = self.sum_x / self.counts
        return mean

    @property
    def std(self):
        # Running std; clamp(min=0) guards against negative rounding error.
        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
        return std

    @property
    def target_std(self):
        # Unit target variance after projection.
        return 1

    def project_sample(self, x: torch.Tensor):
        assert x.dim() == 4
        # Keep accumulating statistics until num_samples items have been seen.
        # Time dim is averaged out, so stats are per (channel, freq-bin).
        if self.counts.item() < self.num_samples:
            self.counts += len(x)
            self.sum_x += x.mean(dim=(2,)).sum(dim=0)
            self.sum_x2 += x.pow(2).mean(dim=(2,)).sum(dim=0)
        rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std  # same output size
        x = (x - self.mean.view(1, -1, 1, 32).contiguous()) * rescale.view(1, -1, 1, 32).contiguous()
        return x

    def return_sample(self, x: torch.Tensor):
        assert x.dim() == 4
        # Exact inverse of project_sample (given unchanged statistics).
        rescale = (self.std / self.target_std) ** self.power_std
        x = x * rescale.view(1, -1, 1, 32).contiguous() + self.mean.view(1, -1, 1, 32).contiguous()
        return x
81
+
82
+
83
class BASECFM(torch.nn.Module, ABC):
    """Conditional flow-matching sampler wrapping an estimator network.

    Generates latents by integrating the learned velocity field from noise
    (t=0) to data (t=1) with a fixed-step Euler solver, optionally pinning a
    leading span of in-context frames and applying classifier-free guidance.
    """
    def __init__(
        self,
        estimator,
    ):
        super().__init__()
        # Minimum noise level of the flow-matching interpolation path.
        self.sigma_min = 1e-4

        self.estimator = estimator

    @torch.inference_mode()
    def forward(self, mu, n_timesteps, temperature=1.0):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_channels, mel_timesteps, n_feats)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_channels, mel_timesteps, n_feats)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        # NOTE(review): solve_euler requires incontext_x/incontext_length/mu/
        # added_cond_kwargs/guidance_scale, which are not supplied here, so this
        # entry point would raise if called — verify it is unused (inference
        # paths call solve_euler directly).
        return self.solve_euler(z, t_span=t_span)

    def solve_euler(self, x, incontext_x, incontext_length, t_span, mu, added_cond_kwargs, guidance_scale):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            incontext_x (torch.Tensor): clean latents whose first
                `incontext_length` frames are re-imposed at every step
            incontext_length (int): number of leading frames to keep pinned
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): conditioning features concatenated onto the
                estimator input; zeroed for the unconditional CFG branch
                shape: (batch_size, n_channels, mel_timesteps, n_feats)
            added_cond_kwargs (dict): extra conditioning tensors forwarded to
                the estimator
            guidance_scale (float): classifier-free guidance weight; values > 1
                trigger a doubled (uncond + cond) batch per step
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        noise = x.clone()

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        for step in tqdm(range(1, len(t_span))):
            # Re-noise the pinned context frames to the current time t so they
            # stay on the flow-matching interpolation path.
            x[:,:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,:,0:incontext_length,:] + t * incontext_x[:,:,0:incontext_length,:]
            if(guidance_scale > 1.0):
                # CFG: run unconditional (mu zeroed) and conditional passes in
                # one doubled batch, then extrapolate between them.
                dphi_dt = self.estimator( \
                    torch.cat([ \
                        torch.cat([x, x], 0), \
                        torch.cat([incontext_x, incontext_x], 0), \
                        torch.cat([torch.zeros_like(mu), mu], 0), \
                    ], 1), \
                    timestep = t.unsqueeze(-1).repeat(2), \
                    added_cond_kwargs={k:torch.cat([v,v],0) for k,v in added_cond_kwargs.items()}).sample
                dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
                dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
            else:
                dphi_dt = self.estimator(torch.cat([x, incontext_x, mu], 1), \
                    timestep = t.unsqueeze(-1),
                    added_cond_kwargs=added_cond_kwargs).sample

            # Explicit Euler update.
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]
153
+
154
+
155
class PromptCondAudioDiffusion(nn.Module):
    """Flow-matching latent generator conditioned on MuEncoder RVQ codes.

    Bundles: a frozen MuEncoder feature extractor, a residual vector quantizer
    over its embeddings (the discrete codec), a projection of dequantized codes
    into a 2-D conditioning map, and a Transformer2D estimator sampled through
    a BASECFM Euler solver.
    """

    def __init__(
        self,
        num_channels,
        unet_model_name=None,
        unet_model_config_path=None,
        snr_gamma=None,
        uncondition=True,
        out_paint=False,
    ):
        """
        Args:
            num_channels: channel count of the latent space to generate.
            unet_model_name / unet_model_config_path: at least one is required;
                only the config path is used below to build the Transformer2D.
            snr_gamma, uncondition, out_paint: stored; not used in this class's
                inference paths.
        """
        super().__init__()

        assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"

        self.unet_model_name = unet_model_name
        self.unet_model_config_path = unet_model_config_path
        self.snr_gamma = snr_gamma
        self.uncondition = uncondition
        self.num_channels = num_channels

        # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
        self.normfeat = Feature2DProcessor(dim=num_channels)

        self.sample_rate = 48000
        self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
        self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
        muencoder_dir = "muq_dev/muq_fairseq"
        muencoder_ckpt = "muq_dev/muq.pt"

        # Frozen MuEncoder used purely as a feature extractor.
        self.muencoder = load_model(
            model_dir=os.path.abspath(muencoder_dir),
            checkpoint_dir=os.path.abspath(muencoder_ckpt),
        )
        self.rsq48tomuencoder = torchaudio.transforms.Resample(48000, 24000)
        for v in self.muencoder.parameters():
            v.requires_grad = False
        # Single-codebook RVQ over 1024-d MuEncoder embeddings.
        self.rvq_muencoder_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
        # Projects each dequantized frame to a 16x32 conditioning patch.
        self.cond_muencoder_emb = nn.Linear(1024, 16*32)
        # Learned null condition substituted on masked-out frames.
        self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))

        unet = Transformer2DModel.from_config(
            unet_model_config_path,
        )
        self.set_from = "random"
        self.cfm_wrapper = BASECFM(unet)
        print("Transformer initialized from pretrain.")

    def compute_snr(self, timesteps):
        """
        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
        """
        # NOTE(review): relies on self.noise_scheduler, which is never assigned
        # in this class — confirm it is attached externally before calling.
        alphas_cumprod = self.noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = alphas_cumprod**0.5
        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

        # Expand the tensors.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

        # Compute SNR.
        snr = (alpha / sigma) ** 2
        return snr

    def preprocess_audio(self, input_audios, threshold=0.9):
        """Peak-limit (B, T) audio: items whose absolute peak exceeds
        `threshold` are scaled down so the peak equals `threshold`."""
        assert len(input_audios.shape) == 2, input_audios.shape
        norm_value = torch.ones_like(input_audios[:,0])
        max_volume = input_audios.abs().max(dim=-1)[0]
        norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
        return input_audios/norm_value.unsqueeze(-1)

    def extract_muencoder_embeds(self, input_audio_0, input_audio_1, layer):
        """Run the (frozen) MuEncoder on the downmixed stereo pair and return
        the hidden states of `layer`, channel-first (B, 1024, T)."""
        input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
        input_wav_mean = self.muencoder(self.rsq48tomuencoder(input_wav_mean), features_only = True)
        layer_results = input_wav_mean['layer_results']
        muencoder_emb = layer_results[layer]
        muencoder_emb = muencoder_emb.permute(0,2,1).contiguous()
        return muencoder_emb

    def init_device_dtype(self, device, dtype):
        """Record the device/dtype used when preparing inference latents."""
        self.device = device
        self.dtype = dtype

    @torch.no_grad()
    def fetch_codes(self, input_audios, additional_feats, layer):
        """Quantize a single stereo clip (2, T) into RVQ codes.

        Returns ([codes], [embeddings], None) — the trailing None is a
        placeholder for speaker embeddings.
        """
        input_audio_0 = input_audios[[0],:]
        input_audio_1 = input_audios[[1],:]
        input_audio_0 = self.preprocess_audio(input_audio_0)
        input_audio_1 = self.preprocess_audio(input_audio_1)

        self.muencoder.eval()

        muencoder_emb = self.extract_muencoder_embeds(input_audio_0, input_audio_1, layer)
        muencoder_emb = muencoder_emb.detach()

        self.rvq_muencoder_emb.eval()
        quantized_muencoder_emb, codes_muencoder_emb, *_ = self.rvq_muencoder_emb(muencoder_emb)

        spk_embeds = None

        return [codes_muencoder_emb], [muencoder_emb], spk_embeds

    @torch.no_grad()
    def fetch_codes_batch(self, input_audios, additional_feats, layer):
        """Batched variant of fetch_codes for input shaped (B, 2, T)."""
        input_audio_0 = input_audios[:,0,:]
        input_audio_1 = input_audios[:,1,:]
        input_audio_0 = self.preprocess_audio(input_audio_0)
        input_audio_1 = self.preprocess_audio(input_audio_1)

        self.muencoder.eval()

        muencoder_emb = self.extract_muencoder_embeds(input_audio_0, input_audio_1, layer)
        muencoder_emb = muencoder_emb.detach()

        self.rvq_muencoder_emb.eval()
        quantized_muencoder_emb, codes_muencoder_emb, *_ = self.rvq_muencoder_emb(muencoder_emb)  # b,d,t

        spk_embeds = None

        return [codes_muencoder_emb], [muencoder_emb], spk_embeds

    @torch.no_grad()
    def inference_codes(self, codes, spk_embeds, true_latents, latent_length, incontext_length, additional_feats,
        guidance_scale=2, num_steps=20,
        disable_progress=True, scenario='start_seg'):
        """Decode RVQ codes into VAE latents with the CFM Euler sampler.

        Args:
            codes: list whose first element is the code tensor (B, N, T).
            true_latents: clean latents used for the in-context frames.
            latent_length: number of latent frames to actually generate.
            incontext_length: leading frames pinned to `true_latents`
                (only honored when scenario == 'other_seg').
        """
        classifier_free_guidance = guidance_scale > 1.0
        device = self.device
        dtype = self.dtype
        codes_muencoder_emb = codes[0]

        batch_size = codes_muencoder_emb.shape[0]

        # Dequantize codes back into continuous MuEncoder embeddings.
        quantized_muencoder_emb,_,_=self.rvq_muencoder_emb.from_codes(codes_muencoder_emb)

        # Project each frame to 16x32 and fold frame pairs into the channel
        # dim: (B, T, 16*32) -> (B, 32, T/2, 32).
        quantized_muencoder_emb = self.cond_muencoder_emb(quantized_muencoder_emb.permute(0,2,1))  # b t 16*32
        quantized_muencoder_emb = quantized_muencoder_emb.reshape(quantized_muencoder_emb.shape[0], quantized_muencoder_emb.shape[1]//2, 2, 16, 32).reshape(quantized_muencoder_emb.shape[0], quantized_muencoder_emb.shape[1]//2, 2*16, 32).permute(0,2,1,3).contiguous()  # b 32 t f

        num_frames = quantized_muencoder_emb.shape[-2]

        num_channels_latents = self.num_channels
        latents = self.prepare_latents(batch_size, num_frames, num_channels_latents, dtype, device)

        # Resolution/aspect-ratio micro-conditioning for the transformer.
        bsz, _, height, width = latents.shape
        resolution = torch.tensor([height, width]).repeat(bsz, 1)
        aspect_ratio = torch.tensor([float(height / width)]).repeat(bsz, 1)
        resolution = resolution.to(dtype=quantized_muencoder_emb.dtype, device=device)
        aspect_ratio = aspect_ratio.to(dtype=quantized_muencoder_emb.dtype, device=device)
        if classifier_free_guidance:
            resolution = torch.cat([resolution, resolution], 0)
            aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], 0)
        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

        # Mask semantics per frame: 0 = outside segment, 1 = in-context
        # (continuation) frame, 2 = frame to generate.
        latent_masks = torch.zeros(latents.shape[0], latents.shape[2], dtype=torch.int64, device=latents.device)
        latent_masks[:,0:latent_length] = 2
        if(scenario=='other_seg'):
            latent_masks[:,0:incontext_length] = 1

        # Swap in the learned null condition wherever the mask is 0.
        quantized_muencoder_emb = (latent_masks > 0.5).unsqueeze(1).unsqueeze(-1) * quantized_muencoder_emb \
            + (latent_masks < 0.5).unsqueeze(1).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,32,1,32)
        true_latents = self.normfeat.project_sample(true_latents)
        incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(1).unsqueeze(-1).float()
        # Recompute the effective in-context length from the masks.
        incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]

        additional_model_input = torch.cat([quantized_muencoder_emb],1)

        temperature = 1.0
        t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_muencoder_emb.device)
        latents = self.cfm_wrapper.solve_euler(latents * temperature, incontext_latents, incontext_length, t_span, additional_model_input, added_cond_kwargs, guidance_scale)

        # Pin the in-context frames exactly, then undo feature normalization.
        latents[:,:,0:incontext_length,:] = incontext_latents[:,:,0:incontext_length,:]
        latents = self.normfeat.return_sample(latents)
        return latents

    @torch.no_grad()
    def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
        disable_progress=True, layer=5, scenario='start_seg'):
        """End-to-end: encode `input_audios` to codes, then sample latents.

        `lyric` is accepted for interface compatibility but unused here.
        """
        codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats, layer)

        # Fix: inference_codes takes `incontext_length` before
        # `additional_feats`; the original call omitted it, shifting arguments
        # by one and raising TypeError. 0 is safe: incontext_length is only
        # read for scenario='other_seg' and is recomputed from the latent
        # masks inside inference_codes.
        latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, 0, additional_feats, \
            guidance_scale=guidance_scale, num_steps=num_steps, \
            disable_progress=disable_progress, scenario=scenario)
        return latents

    def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
        """Draw Gaussian starting latents (B, C, T, 32), rounding the frame
        count to a multiple of 4 as the transformer requires."""
        divisor = 4
        shape = (batch_size, num_channels_latents, num_frames, 32)
        if(num_frames%divisor>0):
            num_frames = round(num_frames/float(divisor))*divisor
            shape = (batch_size, num_channels_latents, num_frames, 32)
        latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
        return latents
366
+
367
+
MuCodec/models/attention.py ADDED
@@ -0,0 +1,682 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import USE_PEFT_BACKEND
21
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
22
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
23
+ from diffusers.models.attention_processor import Attention
24
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
25
+ from diffusers.models.lora import LoRACompatibleLinear
26
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
27
+
28
+
29
+ def _chunked_feed_forward(
30
+ ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
31
+ ):
32
+ # "feed_forward_chunk_size" can be used to save memory
33
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
+ raise ValueError(
35
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
+ )
37
+
38
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
+ if lora_scale is None:
40
+ ff_output = torch.cat(
41
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
42
+ dim=chunk_dim,
43
+ )
44
+ else:
45
+ # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
46
+ ff_output = torch.cat(
47
+ [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
48
+ dim=chunk_dim,
49
+ )
50
+
51
+ return ff_output
52
+
53
+
54
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
    r"""
    A gated self-attention dense layer that combines visual features and object features.

    Used for GLIGEN-style conditioning: object features are projected into the
    visual feature space, jointly self-attended, and blended back through
    learnable zero-initialized gates.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        context_dim (`int`): The number of channels in the context.
        n_heads (`int`): The number of heads to use for attention.
        d_head (`int`): The number of channels in each head.
    """

    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
        super().__init__()

        # we need a linear projection since we need cat visual feature and obj feature
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, activation_fn="geglu")

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        # Gates start at tanh(0) == 0, so the module begins as an identity mapping.
        self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
        self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))

        # Runtime switch: when False, forward() is a passthrough.
        self.enabled = True

    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
        if not self.enabled:
            return x

        n_visual = x.shape[1]
        objs = self.linear(objs)

        # Self-attend over the concatenation [visual; objects], then keep only
        # the visual positions; each residual is scaled by its learned gate.
        x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))

        return x
94
+
95
+
96
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block: self-attention, optional cross-attention, feed-forward,
    with one of several adaptive-normalization schemes selected by `norm_type`.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"`, `"ada_norm_zero"`,
            `"ada_norm_single"` or `"ada_norm_continuous"`.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply to.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        attention_type: str = "default",
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
        ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
        ada_norm_bias: Optional[int] = None,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        # Mutually exclusive flags derived from `norm_type`; exactly one is used
        # per forward pass to pick the normalization/modulation flavor.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
        self.use_layer_norm = norm_type == "layer_norm"
        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_continuous:
            self.norm1 = AdaLayerNormContinuous(
                dim,
                ada_norm_continous_conditioning_embedding_dim,
                norm_elementwise_affine,
                norm_eps,
                ada_norm_bias,
                "rms_norm",
            )
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            if self.use_ada_layer_norm:
                self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
            elif self.use_ada_layer_norm_continuous:
                self.norm2 = AdaLayerNormContinuous(
                    dim,
                    ada_norm_continous_conditioning_embedding_dim,
                    norm_elementwise_affine,
                    norm_eps,
                    ada_norm_bias,
                    "rms_norm",
                )
            else:
                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                out_bias=attention_out_bias,
            )  # is self-attn if encoder_hidden_states is none
        else:
            # norm2 is still created (PixArt's ada_norm_single reuses it before the FF).
            self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
            self.attn2 = None

        # 3. Feed-forward
        if self.use_ada_layer_norm_continuous:
            self.norm3 = AdaLayerNormContinuous(
                dim,
                ada_norm_continous_conditioning_embedding_dim,
                norm_elementwise_affine,
                norm_eps,
                ada_norm_bias,
                "layer_norm",
            )
        elif not self.use_ada_layer_norm_single:
            self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        # 4. Fuser
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)

        # 5. Scale-shift for PixArt-Alpha.
        if self.use_ada_layer_norm_single:
            # 6 modulation vectors: shift/scale/gate for attention and for the FF.
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.FloatTensor:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.use_layer_norm:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.use_ada_layer_norm_continuous:
            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif self.use_ada_layer_norm_single:
            # PixArt-Alpha style: `timestep` here is assumed to already be the
            # projected embedding of shape (batch, 6 * dim) — TODO confirm against caller.
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            norm_hidden_states = norm_hidden_states.squeeze(1)
        else:
            raise ValueError("Incorrect norm used")

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)


        # 1. Retrieve lora scale.
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        # 2. Prepare GLIGEN inputs
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.use_ada_layer_norm_single:
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 2.5 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero or self.use_layer_norm:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.use_ada_layer_norm_single:
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            elif self.use_ada_layer_norm_continuous:
                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
            else:
                raise ValueError("Incorrect norm")

            if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        if self.use_ada_layer_norm_continuous:
            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif not self.use_ada_layer_norm_single:
            norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self.use_ada_layer_norm_single:
            # PixArt reuses norm2 ahead of the FF (there is no norm3 in this mode).
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(
                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.use_ada_layer_norm_single:
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
419
+
420
+
421
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block for video like data.

    Attention runs along the *time* axis: the input of shape
    (batch * frames, seq, channels) is transposed so each spatial position
    attends over its frames.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        time_mix_inner_dim (`int`): The number of channels for temporal attention.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
    """

    def __init__(
        self,
        dim: int,
        time_mix_inner_dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        # Residual connections are only valid when input/inner widths match.
        self.is_res = dim == time_mix_inner_dim

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        # Fix: `self.norm_in` was assigned twice in the original; the duplicate
        # (behaviorally inert) assignment has been removed.
        self.norm_in = nn.LayerNorm(dim)
        self.ff_in = FeedForward(
            dim,
            dim_out=time_mix_inner_dim,
            activation_fn="geglu",
        )

        self.norm1 = nn.LayerNorm(time_mix_inner_dim)
        self.attn1 = Attention(
            query_dim=time_mix_inner_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            cross_attention_dim=None,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = nn.LayerNorm(time_mix_inner_dim)
            self.attn2 = Attention(
                query_dim=time_mix_inner_dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(time_mix_inner_dim)
        self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = None

    def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
        self._chunk_dim = 1

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        num_frames: int,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Reshape (batch*frames, seq, C) -> (batch*seq, frames, C) so attention mixes time.
        batch_frames, seq_length, channels = hidden_states.shape
        batch_size = batch_frames // num_frames

        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)

        residual = hidden_states
        hidden_states = self.norm_in(hidden_states)

        if self._chunk_size is not None:
            hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
        else:
            hidden_states = self.ff_in(hidden_states)

        if self.is_res:
            hidden_states = hidden_states + residual

        # 1. Self-Attention (temporal)
        norm_hidden_states = self.norm1(hidden_states)
        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
        hidden_states = attn_output + hidden_states

        # 3. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = self.norm2(hidden_states)
            attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.is_res:
            hidden_states = ff_output + hidden_states
        else:
            hidden_states = ff_output

        # Restore the original (batch*frames, seq, C) layout.
        hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)

        return hidden_states
550
+
551
+
552
class SkipFFTransformerBlock(nn.Module):
    """Transformer block with two attention layers and *no* feed-forward ("skip FF").

    Both attention layers attend to `encoder_hidden_states`; when the encoder
    width differs from `dim`, a SiLU + linear `kv_mapper` first projects it.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        kv_input_dim: int,
        kv_input_dim_proj_use_bias: bool,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # Project encoder states to `dim` only when widths differ.
        if kv_input_dim == dim:
            self.kv_mapper = None
        else:
            self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)

        self.norm1 = RMSNorm(dim, 1e-06)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim,
            out_bias=attention_out_bias,
        )

        self.norm2 = RMSNorm(dim, 1e-06)
        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

    def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
        attn_kwargs = {} if cross_attention_kwargs is None else cross_attention_kwargs.copy()

        if self.kv_mapper is not None:
            encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))

        # Two identical pre-norm residual attention stages.
        for norm, attn in ((self.norm1, self.attn1), (self.norm2, self.attn2)):
            attn_output = attn(
                norm(hidden_states),
                encoder_hidden_states=encoder_hidden_states,
                **attn_kwargs,
            )
            hidden_states = attn_output + hidden_states

        return hidden_states
622
+
623
+
624
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"`, `"geglu-approximate"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.

    Raises:
        ValueError: If `activation_fn` is not one of the supported activations.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # Fix: the original used `if` + `if/elif`, which left `act_fn` unbound
        # (NameError) for an unrecognized `activation_fn`; use a single chain
        # and fail fast with a clear error instead.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        else:
            raise ValueError(f"Unsupported activation_fn: {activation_fn!r}")

        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # Only LoRA-aware modules accept the `scale` keyword; plain modules do not.
        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
        for module in self.net:
            if isinstance(module, compatible_cls):
                hidden_states = module(hidden_states, scale)
            else:
                hidden_states = module(hidden_states)
        return hidden_states
MuCodec/models/transformer_2d_flow.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ import math
16
+ from typing import Any, Dict, Optional, Tuple
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
24
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
25
+ from models.attention import BasicTransformerBlock
26
+ from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection
27
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.models.embeddings import TimestepEmbedding
30
+
31
class PixArtAlphaCombinedFlowEmbeddings(nn.Module):
    """
    Flow-matching variant of the PixArt-Alpha combined timestep/size embeddings.

    Continuous flow times are lifted to a 512-dim sinusoidal feature vector and
    projected by an MLP; optional resolution / aspect-ratio conditions are
    embedded and added to the timestep embedding.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        self.flow_t_size = 512  # width of the raw sinusoidal features
        self.outdim = size_emb_dim
        self.timestep_embedder = TimestepEmbedding(in_channels=self.flow_t_size, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if use_additional_conditions:
            # Fix: `Timesteps` was never imported at module level, so enabling
            # this branch raised NameError; import it where it is needed.
            from diffusers.models.embeddings import Timesteps

            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

    # https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/nn.py#L87
    def timestep_embedding(self, timesteps, max_period=10000, scale=1000):
        """Create sinusoidal timestep embeddings.

        :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
        :param max_period: controls the minimum frequency of the embeddings.
        :param scale: multiplier applied to `timesteps` before the sinusoids
            (flow times presumably live in [0, 1], so this maps them into a
            diffusion-like range — TODO confirm against the caller).
        :return: an [N x flow_t_size] Tensor of positional embeddings.
        """
        half = self.flow_t_size // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, device=timesteps.device) / half).type(timesteps.type())
        args = timesteps[:, None] * freqs[None] * scale
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.flow_t_size % 2:
            # Odd target width: pad with one zero column.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        """Return the conditioning vector for `timestep` (plus optional size conditions)."""
        timesteps_proj = self.timestep_embedding(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        if self.use_additional_conditions:
            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
        else:
            conditioning = timesteps_emb

        return conditioning
83
+
84
class AdaLayerNormSingleFlow(nn.Module):
    r"""
    Adaptive layer-norm single (adaLN-single) conditioning head.

    Produces the six modulation chunks consumed by PixArt-style transformer
    blocks from a flow-matching timestep, as proposed in PixArt-Alpha
    (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
        super().__init__()

        self.emb = PixArtAlphaCombinedFlowEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
        )

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Embed the timestep (plus any extra conditions); no modulation happens here.
        embedded = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        # Project to the 6 * dim modulation vector expected by the blocks.
        modulation = self.linear(self.silu(embedded))
        return modulation, embedded
115
+
116
+
117
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    Output container for [`Transformer2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. For discrete
            (vectorized) inputs this holds probability distributions for the unnoised latent pixels.
    """

    sample: torch.FloatTensor
129
+
130
+
131
class Transformer2DModel(ModelMixin, ConfigMixin):
    """
    A 2D Transformer model for image-like data.

    NOTE(review): this is a modified copy of diffusers' `Transformer2DModel` in which the
    `ada_norm_single` path uses the flow-matching variants `AdaLayerNormSingleFlow` /
    `PixArtAlphaCombinedFlowEmbeddings` defined above instead of the diffusion-step embeddings.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        attention_type: str = "default",
        caption_channels: Optional[int] = None,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # LoRA-compatible layer classes are only needed when PEFT is not managing adapters.
        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        # The three input modes are mutually exclusive; fail fast on invalid combinations.
        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = linear_cls(in_channels, inner_dim)
            else:
                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            interpolation_scale = self.config.sample_size // 64  # => 64 (= 512 pixart) has interpolation scale 1
            interpolation_scale = max(interpolation_scale, 1)
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
                interpolation_scale=interpolation_scale,
            )

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    double_self_attention=double_self_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    attention_type=attention_type,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
                self.proj_out = linear_cls(inner_dim, in_channels)
            else:
                self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches and norm_type != "ada_norm_single":
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
        elif self.is_input_patches and norm_type == "ada_norm_single":
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            # Learned global shift/scale added to the per-timestep modulation.
            self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
            self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

        # 5. PixArt-Alpha blocks.
        self.adaln_single = None
        self.use_additional_conditions = False
        if norm_type == "ada_norm_single":
            self.use_additional_conditions = self.config.sample_size == 128
            # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
            # additional conditions until we find better name
            self.adaln_single = AdaLayerNormSingleFlow(inner_dim, use_additional_conditions=self.use_additional_conditions)

        self.caption_projection = None
        if caption_channels is not None:
            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle activation checkpointing on any submodule that supports it.
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`Transformer2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            attention_mask ( `torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                * Mask `(batch, sequence_length)` True = keep, False = discard.
                * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # Retrieve lora scale.
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        # 1. Input
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                # conv projection first, then flatten spatial dims into a token axis
                hidden_states = (
                    self.proj_in(hidden_states, scale=lora_scale)
                    if not USE_PEFT_BACKEND
                    else self.proj_in(hidden_states)
                )
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                # flatten first, then linear projection
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                hidden_states = (
                    self.proj_in(hidden_states, scale=lora_scale)
                    if not USE_PEFT_BACKEND
                    else self.proj_in(hidden_states)
                )

        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
            hidden_states = self.pos_embed(hidden_states)

            # NOTE(review): nesting reconstructed to match upstream diffusers — the
            # adaln_single conditioning is computed only for patched inputs here.
            # Confirm against the MuCodec transformer config, since continuous inputs
            # would skip this and pass the raw `timestep` into the blocks.
            if self.adaln_single is not None:
                if self.use_additional_conditions and added_cond_kwargs is None:
                    raise ValueError(
                        "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                    )
                batch_size = hidden_states.shape[0]
                timestep, embedded_timestep = self.adaln_single(
                    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
                )

        # 2. Blocks
        if self.caption_projection is not None:
            batch_size = hidden_states.shape[0]
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        # 3. Output
        if self.is_input_continuous:
            if not self.use_linear_projection:
                # un-flatten back to (B, C, H, W) before the conv projection
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                hidden_states = (
                    self.proj_out(hidden_states, scale=lora_scale)
                    if not USE_PEFT_BACKEND
                    else self.proj_out(hidden_states)
                )
            else:
                hidden_states = (
                    self.proj_out(hidden_states, scale=lora_scale)
                    if not USE_PEFT_BACKEND
                    else self.proj_out(hidden_states)
                )
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

            # residual connection around the whole transformer stack
            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.permute(0, 2, 1)

            # log(p(x_0))
            output = F.log_softmax(logits.double(), dim=1).float()

        if self.is_input_patches:
            if self.config.norm_type != "ada_norm_single":
                conditioning = self.transformer_blocks[0].norm1.emb(
                    timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
                shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
                hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
                hidden_states = self.proj_out_2(hidden_states)
            elif self.config.norm_type == "ada_norm_single":
                shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
                hidden_states = self.norm_out(hidden_states)
                # Modulation
                hidden_states = hidden_states * (1 + scale) + shift
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.squeeze(1)

            # unpatchify
            if self.adaln_single is None:
                height = width = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
            )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
MuCodec/mp3_to_code.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+ import tempfile
5
+ import traceback
6
+
7
+ import numpy as np
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+
12
def parse_args():
    """Parse CLI options for batch MP3 -> MuCodec code encoding."""
    arg_parser = argparse.ArgumentParser(
        description="Batch encode MP3 files to MuCodec codes (recursive)."
    )
    # Positional paths: source tree to scan and mirror-destination for codes.
    arg_parser.add_argument("input_dir", type=Path, help="Input folder (recursive scan)")
    arg_parser.add_argument("output_dir", type=Path, help="Output folder for saved codes")

    default_ckpt = Path(__file__).resolve().parent / "ckpt" / "mucodec.pt"
    arg_parser.add_argument("--ckpt", type=Path, default=default_ckpt, help="Path to MuCodec checkpoint")
    arg_parser.add_argument("--layer-num", type=int, default=7, help="MuCodec layer num (default follows generate.py)")
    arg_parser.add_argument("--device", default="cuda:0", help="Torch device, e.g. cuda:0")
    arg_parser.add_argument("--ext", nargs="+", default=[".mp3"], help="Audio extensions to include, e.g. .mp3 .wav .flac")
    arg_parser.add_argument(
        "--format",
        choices=["npz", "pt", "npy", "both", "all"],
        default="npz",
        help="Output format for code files",
    )
    arg_parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Recompute files even if output already exists (disable resume)",
    )
    arg_parser.add_argument("--strict", action="store_true", help="Stop immediately on first failed file")
    return arg_parser.parse_args()
58
+
59
+
60
def list_audio_files(root: Path, exts):
    """Recursively collect files under ``root`` whose suffix matches ``exts``.

    Extensions are matched case-insensitively and may be passed with or
    without a leading dot. Results are sorted for deterministic processing.
    """
    wanted = {ext.lower() if ext.startswith(".") else f".{ext.lower()}" for ext in exts}
    return sorted(
        candidate
        for candidate in root.rglob("*")
        if candidate.is_file() and candidate.suffix.lower() in wanted
    )
69
+
70
+
71
def expected_output_paths(output_stem: Path, fmt: str):
    """Return the output file paths implied by ``fmt`` for a given stem.

    Raises:
        ValueError: if ``fmt`` is not one of the supported format names.
    """
    suffixes_by_format = {
        "npz": (".npz",),
        "pt": (".pt",),
        "npy": (".npy",),
        "both": (".pt", ".npy"),
        "all": (".npz", ".pt", ".npy"),
    }
    if fmt not in suffixes_by_format:
        raise ValueError(f"Unsupported format: {fmt}")
    return [output_stem.with_suffix(suffix) for suffix in suffixes_by_format[fmt]]
87
+
88
+
89
def save_npz_atomic(codes_np: np.ndarray, output_path: Path):
    """Atomically write ``codes_np`` to ``output_path`` as a compressed .npz.

    The archive is first written to a temporary file in the destination
    directory and then moved into place with ``os.replace``, so readers never
    observe a partially-written file. The temp file is removed on failure.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_file = tempfile.NamedTemporaryFile(
        mode="wb", suffix=".npz", dir=output_path.parent, delete=False
    )
    tmp_path = Path(tmp_file.name)
    try:
        with tmp_file:
            np.savez_compressed(tmp_file, codes=codes_np)
        os.replace(tmp_path, output_path)
    except Exception:
        if tmp_path.exists():
            tmp_path.unlink()
        raise
106
+
107
+
108
def save_codes(codes: torch.Tensor, output_stem: Path, fmt: str):
    """Persist ``codes`` next to ``output_stem`` in the requested format(s).

    ``fmt`` follows the same names as :func:`expected_output_paths`; the
    "both"/"all" values write multiple files.
    """
    cpu_codes = codes.detach().cpu()
    numpy_codes = cpu_codes.numpy()
    if fmt in {"npz", "all"}:
        save_npz_atomic(numpy_codes, output_stem.with_suffix(".npz"))
    if fmt in {"pt", "both", "all"}:
        torch.save(cpu_codes, output_stem.with_suffix(".pt"))
    if fmt in {"npy", "both", "all"}:
        np.save(output_stem.with_suffix(".npy"), numpy_codes)
117
+
118
+
119
def main():
    """CLI entry point: recursively encode audio under ``input_dir`` into MuCodec
    codes, mirroring the directory layout under ``output_dir``.

    Supports resume (skips files whose expected outputs already exist unless
    ``--overwrite`` is given), several output formats, and a ``--strict`` mode
    that aborts on the first failure. Prints a summary and any failed files.
    """
    args = parse_args()

    # Imported lazily so argument parsing / --help works without the heavy model stack.
    from generate import MuCodec

    if not args.input_dir.exists() or not args.input_dir.is_dir():
        raise ValueError(f"input_dir does not exist or is not a directory: {args.input_dir}")

    if not args.ckpt.exists():
        raise FileNotFoundError(f"Checkpoint not found: {args.ckpt}")

    if args.device.startswith("cuda") and not torch.cuda.is_available():
        raise RuntimeError("CUDA device requested but torch.cuda.is_available() is False")

    audio_files = list_audio_files(args.input_dir, args.ext)
    if not audio_files:
        print("No audio files found.")
        return

    args.output_dir.mkdir(parents=True, exist_ok=True)

    mucodec = MuCodec(
        model_path=str(args.ckpt),
        layer_num=args.layer_num,
        load_main_model=True,
        device=args.device,
    )

    resume_enabled = not args.overwrite
    ok = 0
    skipped = 0
    failed = []  # (source_path, error_message) pairs, reported at the end

    for src in tqdm(audio_files, desc="Encoding", unit="file"):
        # Mirror the input tree under output_dir, dropping the audio suffix.
        rel = src.relative_to(args.input_dir)
        output_stem = (args.output_dir / rel).with_suffix("")
        output_paths = expected_output_paths(output_stem, args.format)

        # Resume: skip only when every expected output for the chosen format exists.
        if resume_enabled and all(p.exists() for p in output_paths):
            skipped += 1
            continue

        output_stem.parent.mkdir(parents=True, exist_ok=True)

        try:
            codes = mucodec.file2code(str(src))
            save_codes(codes, output_stem, args.format)
            ok += 1
        except Exception as e:
            # Best-effort batch: record the failure and keep going unless --strict.
            failed.append((src, str(e)))
            print(f"[FAILED] {src}: {e}")
            if args.strict:
                print("--strict enabled, stopping on first failure.")
                traceback.print_exc()
                break

    print(
        "Done. "
        f"success={ok}, skipped={skipped}, failed={len(failed)}, total={len(audio_files)}"
    )
    if failed:
        print("Failed files:")
        for path, err in failed:
            print(f"- {path}: {err}")


if __name__ == "__main__":
    main()
187
+
MuCodec/muq_dev/test.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from dataclasses import dataclass
3
+ import fairseq
4
+ import os.path as op
5
+
6
+ root = op.dirname(op.abspath(__file__))
7
+
8
+
9
@dataclass
class UserDirModule:
    # Minimal stand-in for a fairseq argparse namespace:
    # fairseq.utils.import_user_module() only reads the `user_dir` attribute.
    user_dir: str
13
def load_model(model_dir, checkpoint_dir):
    '''Load Fairseq SSL model'''

    # Register the user module directory so fairseq can resolve custom tasks/models.
    user_module = UserDirModule(model_dir)
    fairseq.utils.import_user_module(user_module)

    # Non-strict load tolerates checkpoints with extra/missing keys.
    ensemble, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [checkpoint_dir], strict=False
    )
    return ensemble[0]
MuCodec/readme.md ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MuCodec: Ultra Low-Bitrate Music Codec
2
+
3
+ This repository is the official code repository for MuCodec: Ultra Low-Bitrate Music Codec. You can find our paper on [arXiv](https://arxiv.org/pdf/2409.13216). The demo page is available [here](https://xuyaoxun.github.io/MuCodec_demo/).
4
+
5
+ In this repository, we provide the MuCodec model, inference scripts, and the checkpoint that has been trained on the Million Song Dataset. Specifically, we have released the model and inference code corresponding to the lowest bitrate of 0.35 kbps as mentioned in the paper, to demonstrate the effectiveness of our work.
6
+
7
+
8
+ MuCodec supports 48kHz, dual-channel (stereo) audio reconstruction. If the original audio is in a different format, it will first be converted to 48kHz, dual-channel audio.
9
+
10
+ ## Installation
11
+
12
+ You can install the necessary dependencies using the `requirements.txt` file with Python 3.8.12:
13
+
14
+ ```bash
15
+ pip install -r requirements.txt
16
+ ```
17
+
18
+ Due to storage limitations, we have saved the model checkpoints on Hugging Face at https://huggingface.co/yaoxunxu/mucodec. You can easily download the models from Hugging Face and save them in the following directories:
19
+
20
+ - Save `audioldm_48k.pth` in the `tools` folder.
21
+ - Save `muq.pt` in the `muq_dev` folder.
22
+ - Save `mucodec.pt` in the `ckpt` folder.
23
+
24
+ Please note that all three checkpoints must be downloaded completely for the model to load correctly. The final file paths should be:
25
+
26
+ ```
27
+ tools/audioldm_48k.pth
28
+ muq_dev/muq.pt
29
+ ckpt/mucodec.pt
30
+ ```
31
+
32
+ The file `audioldm_48k.pth` is sourced from https://huggingface.co/haoheliu/audioldm_48k/blob/main/audioldm_48k.pth.
33
+
34
+ ## Inference
35
+
36
+ To run inference, use the following command:
37
+
38
+ ```bash
39
+ python3 generate.py
40
+ ```
41
+
42
+ We have provided a sample song `test.wav`, randomly sampled from the Million Song Dataset, in the `test_wav` folder. The default input path is `test_wav/test.wav`, and the output path for the reconstructed audio is `reconstruct/test.wav`.
43
+
44
+ In the `generate.py` file, we have implemented several functions to facilitate the music compression and reconstruction process. You can easily obtain compressed tokens from audio using the `sound2code` function, and reconstruct the audio from tokens using the `code2sound` function.
45
+
46
+ ## Note
47
+
48
+ Please note that the open-sourced model was trained solely on the Million Song Dataset. Considering the quality issues of this dataset, the open-sourced model may not achieve the same performance as demonstrated in the demo. Unfortunately, due to copyright restrictions, we are unable to release the checkpoints trained on additional datasets. However, you can use your own dataset to further train the model and achieve better results.
49
+
50
+ ## License
51
+
52
+ The code in this repository is released under the MIT license as found in the [LICENSE](LICENSE) file.
53
+
54
+ The model weights (muq.pt, mucodec.pt) in this repository are released under the CC-BY-NC 4.0 license, as detailed in the [LICENSE_weights](LICENSE_weights) file.
55
+
56
+ ## Citation
57
+
58
+ If you find our work useful, please cite our paper:
59
+
60
+ ```bibtex
61
+ @article{xu2024mucodec,
62
+ title={MuCodec: Ultra Low-Bitrate Music Codec},
63
+ author={Xu, Yaoxun and Chen, Hangting and Yu, Jianwei and Tan, Wei and Gu, Rongzhi and Lei, Shun and Lin, Zhiwei and Wu, Zhiyong},
64
+ journal={arXiv preprint arXiv:2409.13216},
65
+ year={2024}
66
+ }
67
+ ```
MuCodec/requirements.txt ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.0.0
2
+ accelerate==0.30.1
3
+ aeiou==0.0.20
4
+ aiobotocore==2.13.1
5
+ aiofiles==23.2.1
6
+ aiohttp==3.9.3
7
+ aioitertools==0.11.0
8
+ aiosignal==1.3.1
9
+ alias-free-torch==0.0.6
10
+ altair==5.3.0
11
+ annotated-types==0.6.0
12
+ antlr4-python3-runtime==4.8
13
+ anyio==4.3.0
14
+ appdirs==1.4.4
15
+ argbind==0.3.9
16
+ asttokens==2.4.1
17
+ astunparse==1.6.3
18
+ async-timeout==4.0.3
19
+ attrs==23.1.0
20
+ audioread==3.0.1
21
+ auraloss==0.4.0
22
+ av==11.0.0
23
+ backcall==0.2.0
24
+ beartype==0.18.5
25
+ bitarray==2.9.2
26
+ bleach==6.1.0
27
+ blis==0.7.11
28
+ bokeh==3.1.1
29
+ botocore==1.34.131
30
+ braceexpand==0.1.7
31
+ cachetools==5.3.2
32
+ catalogue==2.0.10
33
+ certifi==2023.11.17
34
+ cffi==1.16.0
35
+ charset-normalizer==3.3.2
36
+ clean-fid==0.1.35
37
+ click==8.1.7
38
+ clip-anytorch==2.6.0
39
+ cloudpathlib==0.16.0
40
+ cloudpickle==3.0.0
41
+ cn2an==0.5.22
42
+ colorama==0.4.6
43
+ colorcet==3.1.0
44
+ colorlog==6.8.2
45
+ confection==0.1.4
46
+ configparser==7.0.0
47
+ contourpy==1.1.1
48
+ cycler==0.12.1
49
+ cymem==2.0.8
50
+ Cython==3.0.10
51
+ dataclasses==0.6
52
+ datasets
53
+ dctorch==0.1.2
54
+ decorator==5.1.1
55
+ decord==0.6.0
56
+ deepspeed==0.14.0
57
+ demucs==4.0.1
58
+ descript-audio-codec==1.0.0
59
+ descript-audiotools==0.7.2
60
+ diffusers==0.27.2
61
+ dill==0.3.8
62
+ Distance==0.1.3
63
+ docker-pycreds==0.4.0
64
+ docopt==0.6.2
65
+ docstring_parser==0.16
66
+ dora_search==0.1.12
67
+ einops==0.7.0
68
+ einops-exts==0.0.4
69
+ einx==0.3.0
70
+ ema-pytorch==0.2.3
71
+ encodec==0.1.1
72
+ exceptiongroup==1.2.0
73
+ executing==2.0.1
74
+ expecttest==0.1.6
75
+ fairseq==0.12.2
76
+ fastapi==0.110.3
77
+ fastcore==1.6.3
78
+ ffmpy==0.3.2
79
+ filelock==3.13.1
80
+ fire==0.6.0
81
+ flashy==0.0.2
82
+ flatten-dict==0.4.2
83
+ fonttools==4.49.0
84
+ frozendict==2.4.4
85
+ frozenlist==1.4.1
86
+ fsspec==2024.6.1
87
+ ftfy==6.1.3
88
+ future==1.0.0
89
+ g2p-en==2.1.0
90
+ gin-config==0.5.0
91
+ gitdb==4.0.11
92
+ GitPython==3.1.43
93
+ google-auth==2.23.4
94
+ google-auth-oauthlib==1.0.0
95
+ gradio==4.26.0
96
+ gradio_client==0.15.1
97
+ grpcio==1.59.3
98
+ h11==0.14.0
99
+ h5py==3.11.0
100
+ hjson==3.1.0
101
+ holoviews==1.17.1
102
+ httpcore==1.0.5
103
+ httpx==0.27.0
104
+ huggingface-hub==0.23.5
105
+ hydra-colorlog==1.2.0
106
+ hydra-core==1.0.7
107
+ hypothesis==6.90.0
108
+ idna==3.4
109
+ imageio==2.34.2
110
+ importlib-metadata==6.8.0
111
+ importlib-resources==5.12.0
112
+ inflect==7.0.0
113
+ ipython==8.12.3
114
+ jedi==0.19.1
115
+ jieba-fast==0.53
116
+ Jinja2==3.1.2
117
+ jmespath==1.0.1
118
+ joblib==1.3.2
119
+ json5==0.9.25
120
+ jsonlines==4.0.0
121
+ jsonmerge==1.9.2
122
+ jsonschema==4.22.0
123
+ jsonschema-specifications==2023.12.1
124
+ julius==0.2.7
125
+ k-diffusion==0.1.1
126
+ kaldiio==2.18.0
127
+ kiwisolver==1.4.5
128
+ kornia==0.7.3
129
+ kornia_rs==0.1.5
130
+ laion-clap==1.1.4
131
+ lameenc==1.7.0
132
+ langcodes==3.4.0
133
+ language_data==1.2.0
134
+ lazy_loader==0.3
135
+ librosa==0.9.2
136
+ lightning==2.2.1
137
+ lightning-utilities==0.10.1
138
+ linkify-it-py==2.0.3
139
+ lion-pytorch==0.2.2
140
+ llvmlite==0.41.1
141
+ local-attention==1.8.6
142
+ loguru==0.7.2
143
+ lxml==5.2.2
144
+ marisa-trie==1.1.1
145
+ Markdown==3.5.1
146
+ markdown-it-py==3.0.0
147
+ markdown2==2.5.0
148
+ MarkupSafe==2.1.3
149
+ matplotlib==3.7.5
150
+ matplotlib-inline==0.1.7
151
+ mdit-py-plugins==0.4.1
152
+ mdurl==0.1.2
153
+ mpmath==1.3.0
154
+ msgpack==1.0.8
155
+ multidict==6.0.5
156
+ multiprocess==0.70.16
157
+ murmurhash==1.0.10
158
+ mypy-extensions==1.0.0
159
+ networkx==3.1
160
+ ninja==1.11.1.1
161
+ nltk==3.8.1
162
+ nnAudio==0.3.3
163
+ num2words==0.5.13
164
+ numba==0.58.1
165
+ numpy==1.23.5
166
+ nvidia-cublas-cu11==11.11.3.6
167
+ nvidia-cuda-cupti-cu11==11.8.87
168
+ nvidia-cuda-nvrtc-cu11==11.8.89
169
+ nvidia-cuda-runtime-cu11==11.8.89
170
+ nvidia-cudnn-cu11==8.7.0.84
171
+ nvidia-cufft-cu11==10.9.0.58
172
+ nvidia-curand-cu11==10.3.0.86
173
+ nvidia-cusolver-cu11==11.4.1.48
174
+ nvidia-cusparse-cu11==11.7.5.86
175
+ nvidia-nccl-cu11==2.19.3
176
+ nvidia-nvtx-cu11==11.8.86
177
+ oauthlib==3.2.2
178
+ omegaconf
179
+ opencv-contrib-python==4.8.1.78
180
+ opencv-python==4.8.1.78
181
+ openunmix==1.2.1
182
+ orjson==3.10.3
183
+ packaging==23.2
184
+ pandas==2.0.2
185
+ panel==1.2.3
186
+ param==2.1.1
187
+ parso==0.8.4
188
+ pathtools==0.1.2
189
+ pedalboard==0.7.4
190
+ peft==0.10.0
191
+ pexpect==4.9.0
192
+ pickleshare==0.7.5
193
+ Pillow==10.1.0
194
+ pkgutil_resolve_name==1.3.10
195
+ platformdirs==4.2.0
196
+ plotly==5.23.0
197
+ pooch==1.8.1
198
+ portalocker==2.10.1
199
+ prefigure==0.0.9
200
+ preshed==3.0.9
201
+ proces==0.1.7
202
+ prodict==0.8.18
203
+ progressbar==2.5
204
+ prompt_toolkit==3.0.47
205
+ protobuf==3.19.6
206
+ psutil==5.9.6
207
+ ptyprocess==0.7.0
208
+ pure_eval==0.2.3
209
+ py-cpuinfo==9.0.0
210
+ pyarrow==17.0.0
211
+ pyarrow-hotfix==0.6
212
+ pyasn1==0.5.1
213
+ pyasn1-modules==0.3.0
214
+ pybind11==2.11.1
215
+ pycparser==2.21
216
+ pydantic==2.6.3
217
+ pydantic_core==2.16.3
218
+ pydub==0.25.1
219
+ Pygments==2.18.0
220
+ pyloudnorm==0.1.1
221
+ pynndescent==0.5.13
222
+ pynvml==11.5.0
223
+ pyparsing==3.1.2
224
+ pypinyin==0.51.0
225
+ pyre-extensions==0.0.29
226
+ pyreaper==0.0.10
227
+ pystoi==0.4.1
228
+ python-dateutil==2.8.2
229
+ python-multipart==0.0.9
230
+ pytorch-lightning==2.1.0
231
+ pytz==2023.3.post1
232
+ pyviz_comms==3.0.3
233
+ PyWavelets==1.4.1
234
+ PyYAML==6.0.1
235
+ randomname==0.2.1
236
+ referencing==0.35.1
237
+ regex==2023.10.3
238
+ requests==2.32.3
239
+ requests-oauthlib==1.3.1
240
+ resampy==0.4.3
241
+ retrying==1.3.4
242
+ rich==13.7.1
243
+ rpds-py==0.18.1
244
+ rsa==4.9
245
+ ruamel.yaml==0.18.5
246
+ ruamel.yaml.clib==0.2.8
247
+ ruff==0.4.4
248
+ s3fs==2024.6.1
249
+ s3transfer==0.7.0
250
+ sacrebleu==2.4.2
251
+ safetensors==0.4.3
252
+ scikit-image==0.21.0
253
+ scikit-learn==1.3.2
254
+ scipy==1.10.1
255
+ semantic-version==2.10.0
256
+ sentencepiece==0.1.99
257
+ sentry-sdk==2.10.0
258
+ setproctitle==1.3.3
259
+ shellingham==1.5.4
260
+ six==1.16.0
261
+ smart-open==6.4.0
262
+ smmap==5.0.1
263
+ sniffio==1.3.1
264
+ sortedcontainers==2.4.0
265
+ SoundFile==0.10.2
266
+ sox==1.4.1
267
+ soxr==0.3.7
268
+ spacy==3.7.4
269
+ spacy-legacy==3.0.12
270
+ spacy-loggers==1.0.5
271
+ srsly==2.4.8
272
+ stack-data==0.6.3
273
+ starlette==0.37.2
274
+ submitit==1.5.1
275
+ sympy==1.12
276
+ tabulate==0.9.0
277
+ tenacity==9.0.0
278
+ tensorboard==2.14.0
279
+ tensorboard-data-server==0.7.2
280
+ termcolor==2.3.0
281
+ thinc==8.2.3
282
+ threadpoolctl==3.3.0
283
+ tifffile==2023.7.10
284
+ timm==0.9.11
285
+ tokenizers==0.19.1
286
+ tomlkit==0.12.0
287
+ toolz==0.12.1
288
+ torch==2.2.0+cu118
289
+ torch-stoi==0.2.1
290
+ torchaudio==2.2.0+cu118
291
+ torchdata==0.7.1
292
+ torchdiffeq==0.2.4
293
+ torchlibrosa==0.1.0
294
+ torchmetrics==0.11.4
295
+ torchsde==0.2.6
296
+ torchtext==0.17.0
297
+ torchvision==0.17.0+cu118
298
+ tornado==6.4.1
299
+ tqdm==4.66.4
300
+ traitlets==5.14.3
301
+ trampoline==0.1.2
302
+ transformers==4.42.4
303
+ treetable==0.2.5
304
+ triton==2.2.0
305
+ typeguard==2.13.0
306
+ typer==0.9.4
307
+ types-dataclasses==0.6.6
308
+ typing-inspect==0.9.0
309
+ typing_extensions==4.8.0
310
+ tzdata==2023.3
311
+ uc-micro-py==1.0.3
312
+ umap-learn==0.5.6
313
+ Unidecode==1.3.8
314
+ urllib3==1.26.18
315
+ uvicorn==0.29.0
316
+ v-diffusion-pytorch==0.0.2
317
+ vector-quantize-pytorch==1.9.14
318
+ wandb==0.15.4
319
+ wasabi==1.1.2
320
+ wcwidth==0.2.12
321
+ weasel==0.3.4
322
+ webdataset==0.2.48
323
+ webencodings==0.5.1
324
+ websockets==11.0.3
325
+ Werkzeug==3.0.1
326
+ wget==3.2
327
+ wordsegment==1.3.1
328
+ wrapt==1.16.0
329
+ x-clip==0.14.4
330
+ x-transformers==1.26.6
331
+ xformers==0.0.24+cu118
332
+ xxhash==3.4.1
333
+ xyzservices==2024.6.0
334
+ yarl==1.9.4
335
+ zipp==3.17.0
MuCodec/tools/get_melvaehifigan48k.py ADDED
@@ -0,0 +1,1551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import soundfile as sf
3
+ import os
4
+ from librosa.filters import mel as librosa_mel_fn
5
+ import sys
6
+ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
7
+ import tools.torch_tools as torch_tools
8
+ import torch.nn as nn
9
+ import torch
10
+ import numpy as np
11
+ from einops import rearrange
12
+ from scipy.signal import get_window
13
+ from librosa.util import pad_center, tiny
14
+ import librosa.util as librosa_util
15
+
16
class AttrDict(dict):
    """Dict whose entries are also readable/writable as attributes (cfg.key == cfg["key"])."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Point the attribute namespace at the mapping itself so attribute and
        # item access always stay in sync.
        self.__dict__ = self
20
+
21
def init_weights(m, mean=0.0, std=0.01):
    """In-place N(mean, std) initialisation of conv-like modules.

    Any module whose class name contains "Conv" gets its weights redrawn;
    every other module is left untouched.  Intended for ``Module.apply``.
    """
    if "Conv" in type(m).__name__:
        m.weight.data.normal_(mean, std)
25
+
26
+
27
def get_padding(kernel_size, dilation=1):
    """Return the symmetric padding that keeps a stride-1 dilated conv length-preserving."""
    return (kernel_size - 1) * dilation // 2
29
+
30
# Negative slope shared by every LeakyReLU in the HiFi-GAN generator and resblocks.
LRELU_SLOPE = 0.1
31
+
32
class ResBlock(torch.nn.Module):
    """HiFi-GAN residual block.

    Each of the three residual branches applies LeakyReLU -> dilated conv ->
    LeakyReLU -> plain (dilation-1) conv, then adds the branch input back.
    All convs are weight-normalised and initialised via ``init_weights``.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__()
        self.h = h

        def _wn_conv(dil):
            # Weight-normalised, length-preserving 1-D conv.
            return torch.nn.utils.weight_norm(
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dil,
                    padding=get_padding(kernel_size, dil),
                )
            )

        # First conv of each branch uses the configured dilations ...
        self.convs1 = nn.ModuleList([_wn_conv(d) for d in dilation])
        # ... the second conv of each branch always uses dilation 1.
        self.convs2 = nn.ModuleList([_wn_conv(1) for _ in dilation])
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            residual = x
            x = c1(torch.nn.functional.leaky_relu(x, LRELU_SLOPE))
            x = c2(torch.nn.functional.leaky_relu(x, LRELU_SLOPE))
            x = x + residual
        return x

    def remove_weight_norm(self):
        """Strip weight-norm reparameterisation from every conv (for inference/export)."""
        for conv in list(self.convs1) + list(self.convs2):
            torch.nn.utils.remove_weight_norm(conv)
122
+
123
+
124
class Generator_old(torch.nn.Module):
    """HiFi-GAN style generator: mel spectrogram -> raw waveform.

    Input is (B, h.num_mels, T); output is (B, 1, T * prod(h.upsample_rates)),
    squashed to [-1, 1] by the final tanh.

    ``h`` is an AttrDict-like config providing at least: num_mels,
    upsample_initial_channel, upsample_rates, upsample_kernel_sizes,
    resblock_kernel_sizes, resblock_dilation_sizes.

    NOTE: the order in which ``ups`` / ``resblocks`` are appended determines
    state_dict keys, so it must not change if pretrained checkpoints
    (e.g. audioldm_48k.pth) are to load.
    """

    def __init__(self, h):
        super(Generator_old, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        # Weight-normalised "pre" conv lifting mel bins to the widest channel count.
        self.conv_pre = torch.nn.utils.weight_norm(
            nn.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock

        # Transposed convs; channel width halves at every upsampling stage.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                torch.nn.utils.weight_norm(
                    nn.ConvTranspose1d(
                        h.upsample_initial_channel // (2**i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # One ResBlock per (upsample stage x resblock kernel): the multi-receptive
        # field fusion stack, indexed as i * num_kernels + j in forward().
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(h, ch, k, d))

        # Final conv down to a single waveform channel (ch is the last stage width).
        self.conv_post = torch.nn.utils.weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        # x: (B, num_mels, T) mel spectrogram.
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of this stage's ResBlocks (MRF fusion).
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = torch.nn.functional.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight-norm from all convs; call once before inference/export."""
        # print("Removing weight norm...")
        for l in self.ups:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        torch.nn.utils.remove_weight_norm(self.conv_pre)
        torch.nn.utils.remove_weight_norm(self.conv_post)
187
+
188
+
189
+
190
def nonlinearity(x):
    """Swish/SiLU activation: elementwise x * sigmoid(x)."""
    return torch.sigmoid(x) * x
193
+
194
+
195
def Normalize(in_channels, num_groups=32):
    """GroupNorm layer (affine, eps=1e-6) used throughout this VAE's blocks."""
    return torch.nn.GroupNorm(
        num_channels=in_channels, num_groups=num_groups, eps=1e-6, affine=True
    )
199
+
200
class Downsample(nn.Module):
    """2x spatial downsampling via a stride-2 conv (``with_conv``) or average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # Torch convs only support symmetric padding, so the asymmetric
            # (right/bottom-only) padding is applied manually in forward().
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # Pad one column on the right and one row at the bottom, then conv.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
219
+
220
+
221
class DownsampleTimeStride4(nn.Module):
    """Downsampling that shrinks the first spatial axis 4x and the second 2x.

    Uses a stride-(4, 2) conv when ``with_conv`` else a (4, 2) average pool.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # Asymmetric extra padding is applied manually in forward().
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
            )

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(
                x, kernel_size=(4, 2), stride=(4, 2)
            )
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
240
+
241
class Upsample(nn.Module):
    """2x nearest-neighbour upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        up = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(up) if self.with_conv else up
255
+
256
+
257
class UpsampleTimeStride4(nn.Module):
    """Nearest-neighbour upsampling: 4x on the first spatial axis, 2x on the second.

    Mirrors DownsampleTimeStride4; optionally followed by a 5x5 conv.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=5, stride=1, padding=2
            )

    def forward(self, x):
        up = torch.nn.functional.interpolate(
            x, scale_factor=(4.0, 2.0), mode="nearest"
        )
        return self.conv(up) if self.with_conv else up
271
+
272
class AttnBlock(nn.Module):
    """Single-head self-attention over the flattened (h*w) spatial positions.

    1x1 convs produce q/k/v from the GroupNorm-ed input; attention weights are
    softmax(q.k / sqrt(c)); the attended values are projected by another 1x1
    conv and added residually to the input.  Input/output: (B, C, H, W).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # b,hw,c
        k = k.reshape(b, c, h * w).contiguous()  # b,c,hw
        w_ = torch.bmm(q, k).contiguous()  # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))  # scale by 1/sqrt(c) before softmax
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w).contiguous()
        w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(
            v, w_
        ).contiguous()  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w).contiguous()

        h_ = self.proj_out(h_)

        # Residual connection around the whole attention operation.
        return x + h_
318
+
319
+
320
def make_attn(in_channels, attn_type="vanilla"):
    """Factory for the attention layer used by Encoder/Decoder.

    "vanilla" -> AttnBlock, "none" -> nn.Identity.  NOTE: "linear" passes the
    assert below but has no implementation, so it falls through to ValueError.
    """
    assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
    # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    if attn_type == "none":
        # nn.Identity ignores its constructor argument.
        return nn.Identity(in_channels)
    raise ValueError(attn_type)
329
+
330
+
331
class ResnetBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> swish -> 3x3 conv) twice.

    When in/out channel counts differ, the skip path is adapted by a 3x3 conv
    (``conv_shortcut=True``) or a 1x1 conv.  If ``temb_channels > 0`` a
    timestep embedding ``temb`` can be projected and added after the first
    conv; the Encoder/Decoder in this file construct it with temb_channels=0
    and call forward with temb=None.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        # Inject the (swish-activated) timestep embedding, broadcast over H, W.
        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        # Adapt the skip path only when the channel count changes.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
389
+
390
+
391
class Encoder(nn.Module):
    """Taming-transformers style convolutional VAE encoder.

    Downsamples a (B, in_channels, H, W) input through ``len(ch_mult)``
    resolution levels of ResnetBlocks (attention is inserted at resolutions
    listed in ``attn_resolutions``), runs a block/attention/block middle
    stack, and projects to ``2 * z_channels`` channels when ``double_z``
    (mean + logvar of a diagonal Gaussian) else ``z_channels``.

    Levels listed in ``downsample_time_stride4_levels`` downsample the first
    spatial axis by 4 instead of 2 (presumably the time axis — see
    DownsampleTimeStride4; confirm against the caller's mel layout).

    NOTE: registration order of ``self.down`` determines state_dict keys.
    ``downsample_time_stride4_levels=[]`` is a mutable default but is never
    mutated here.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        downsample_time_stride4_levels=[],
        **ignore_kwargs,
    ):
        super().__init__()
        # NOTE(review): attn_type="linear" is not implemented by make_attn and
        # will raise ValueError there if use_linear_attn is set.
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep conditioning on the VAE path
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.downsample_time_stride4_levels = downsample_time_stride4_levels

        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
                % str(self.num_resolutions)
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        # Channel multiplier for each level's *input*: (1, ch_mult...).
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            # Every level except the last ends with a downsampling layer.
            if i_level != self.num_resolutions - 1:
                if i_level in self.downsample_time_stride4_levels:
                    down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end: project to latent (doubled for mean/logvar when double_z)
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # timestep embedding (unused on the VAE path)
        temb = None
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
516
+
517
+
518
class Decoder(nn.Module):
    """Taming-transformers style convolutional VAE decoder (mirror of Encoder).

    Maps a latent (B, z_channels, h, w) through a middle block/attention/block
    stack, then upsamples through ``len(ch_mult)`` resolution levels of
    ResnetBlocks (attention at resolutions in ``attn_resolutions``) back to
    ``out_ch`` channels.  ``give_pre_end`` returns the feature map before the
    final norm/act/conv; ``tanh_out`` squashes the output to [-1, 1].

    Levels whose index-1 is in ``downsample_time_stride4_levels`` upsample the
    first spatial axis by 4 instead of 2 (mirrors the encoder's stride-4
    levels).  NOTE: registration order of ``self.up`` (built reversed, then
    prepended) determines state_dict keys and must not change.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        downsample_time_stride4_levels=[],
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        # NOTE(review): attn_type="linear" is not implemented by make_attn and
        # will raise ValueError there if use_linear_attn is set.
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep conditioning on the VAE path
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.downsample_time_stride4_levels = downsample_time_stride4_levels

        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
                % str(self.num_resolutions)
            )

        # Channel count and spatial size at the lowest resolution.
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling (built from lowest to highest resolution, prepended so that
        # self.up[0] is the highest-resolution level, matching forward()).
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            # One extra ResnetBlock per level compared to the encoder.
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                if i_level - 1 in self.downsample_time_stride4_levels:
                    up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding (unused on the VAE path)
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
659
+
660
+
661
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterised by a tensor of concatenated [mean, logvar].

    ``parameters`` is split in half along dim=1; logvar is clamped to
    [-30, 20] for numerical stability.  With ``deterministic=True`` the
    variance is zeroed so sample() == mode() == mean.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def sample(self):
        # Reparameterised sample: mean + std * eps, eps ~ N(0, I).
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x

    def kl(self, other=None):
        """KL divergence to N(0, I) (other=None) or to another diagonal Gaussian.

        NOTE: reduced with torch.mean over dims [1, 2, 3] (not summed), so the
        value is a per-element average KL per batch item.
        """
        if self.deterministic:
            # NOTE(review): returned on CPU regardless of parameters.device.
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        """Negative log-likelihood of ``sample`` under this Gaussian, summed over ``dims``."""
        if self.deterministic:
            # NOTE(review): returned on CPU regardless of parameters.device.
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        # Mode of a Gaussian is its mean.
        return self.mean
710
+
711
def get_vocoder_config_48k():
    """Return the hyper-parameter dictionary for the 48 kHz HiFi-GAN vocoder."""
    # Training hyper-parameters.
    config = {
        "resblock": "1",
        "num_gpus": 8,
        "batch_size": 128,
        "learning_rate": 0.0001,
        "adam_b1": 0.8,
        "adam_b2": 0.99,
        "lr_decay": 0.999,
        "seed": 1234,
    }
    # Generator topology: 6*5*4*2*2 = 480x upsampling, i.e. one hop per mel frame.
    config.update({
        "upsample_rates": [6, 5, 4, 2, 2],
        "upsample_kernel_sizes": [12, 10, 8, 4, 4],
        "upsample_initial_channel": 1536,
        "resblock_kernel_sizes": [3, 7, 11, 15],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
    })
    # Feature-extraction / audio settings.
    config.update({
        "segment_size": 15360,
        "num_mels": 256,
        "n_fft": 2048,
        "hop_size": 480,
        "win_size": 2048,
        "sampling_rate": 48000,
        "fmin": 20,
        "fmax": 24000,
        "fmax_for_loss": None,
        "num_workers": 8,
        "dist_config": {
            "dist_backend": "nccl",
            "dist_url": "tcp://localhost:18273",
            "world_size": 1,
        },
    })
    return config
748
+
749
def get_vocoder(config, device, mel_bins):
    """Build a mel-to-waveform vocoder and move it to `device`.

    Only the HiFi-GAN path is active (`name` is hard-coded below); the MelGAN
    branch is retained for reference. `mel_bins` selects the generator
    configuration — only 256 bins (48 kHz) is supported, anything else raises.
    The `config` argument is ignored on the 256-bin path (overwritten by
    get_vocoder_config_48k()).
    """
    name = "HiFi-GAN"
    speaker = ""
    if name == "MelGAN":
        # Dead branch while name is hard-coded to "HiFi-GAN".
        # NOTE(review): if speaker is neither "LJSpeech" nor "universal",
        # `vocoder` would be unbound here — confirm before re-enabling.
        if speaker == "LJSpeech":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
            )
        elif speaker == "universal":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
            )
        vocoder.mel2wav.eval()
        vocoder.mel2wav.to(device)
    elif name == "HiFi-GAN":
        if(mel_bins == 256):
            config = get_vocoder_config_48k()
            config = AttrDict(config)
            vocoder = Generator_old(config)
            # Checkpoint loading is handled elsewhere; kept for reference.
            # print("Load hifigan/g_01080000")
            # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
            # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
            # ckpt = torch_version_orig_mod_remove(ckpt)
            # vocoder.load_state_dict(ckpt["generator"])
            vocoder.eval()
            # Fuse weight norm for inference speed/stability.
            vocoder.remove_weight_norm()
            vocoder.to(device)
        else:
            raise ValueError(mel_bins)
    return vocoder
779
+
780
def vocoder_infer(mels, vocoder, lengths=None):
    """Run the vocoder on a batch of mels and return numpy waveforms.

    `lengths`, when given, is a single sample count applied to every item
    in the batch (the output is sliced to `[:, :lengths]`).
    """
    with torch.no_grad():
        generated = vocoder(mels)
    waveforms = generated.squeeze(1).cpu().numpy()
    if lengths is None:
        return waveforms
    # Trim every waveform in the batch to the requested number of samples.
    return waveforms[:, :lengths]
797
+
798
@torch.no_grad()
def vocoder_chunk_infer(mels, vocoder, lengths=None):
    """Vocode a long mel spectrogram in overlapping chunks with linear cross-fade.

    Processes `mels` (batch, mel_bins, frames) in windows of `chunk_size`
    frames advanced by `shift_size` frames, and blends consecutive chunk
    waveforms over the overlap region to hide chunk-boundary artifacts.
    Returns numpy waveforms, optionally trimmed to `lengths` samples.
    """
    chunk_size = 256*4
    shift_size = 256*1
    # Overlap (in frames) between consecutive chunks.
    ov_size = chunk_size-shift_size
    # import pdb;pdb.set_trace()

    for cinx in range(0, mels.shape[2], shift_size):
        if(cinx==0):
            # First chunk: establishes samples-per-chunk and the cross-fade window.
            wavs = vocoder(mels[:,:,cinx:cinx+chunk_size]).squeeze(1).cpu()
            # Truncate to a whole number of chunk-sized sample groups.
            num_samples = int(wavs.shape[-1]/chunk_size)*chunk_size
            wavs = wavs[:,0:num_samples]
            # Overlap length in samples, proportional to the frame overlap.
            ov_sample = int(float(wavs.shape[-1]) * ov_size / chunk_size)
            # ov_win = [ramp up 0->1 | ramp down 1->0]; the two halves are the
            # fade-in (new chunk head) and fade-out (previous tail) weights.
            ov_win = torch.from_numpy(np.linspace(0,1,ov_sample)[None,:])
            ov_win = torch.cat([ov_win,1-ov_win],-1)
            if(cinx+chunk_size>=mels.shape[2]):
                break
        else:
            cur_wav = vocoder(mels[:,:,cinx:cinx+chunk_size]).squeeze(1).cpu()[:,0:num_samples]
            # Cross-fade: previous tail fades out while the new head fades in.
            wavs[:,-ov_sample:] = wavs[:,-ov_sample:] * ov_win[:,-ov_sample:] + cur_wav[:,0:ov_sample] * ov_win[:,0:ov_sample]
            # wavs[:,-ov_sample:] = wavs[:,-ov_sample:] * 1.0 + cur_wav[:,0:ov_sample] * 0.0
            # Append the non-overlapping remainder of the new chunk.
            wavs = torch.cat([wavs, cur_wav[:,ov_sample:]],-1)
            if(cinx+chunk_size>=mels.shape[2]):
                break
    # print(wavs.shape)

    wavs = (wavs.cpu().numpy())

    if lengths is not None:
        # Single sample count applied to the whole batch.
        wavs = wavs[:, :lengths]
    # print(wavs.shape)
    return wavs
830
+
831
def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
    """Vocode ground-truth and predicted mels for one sample.

    Returns (wav_reconstruction, wav_prediction); both are None when no
    vocoder is supplied. `labels` is currently unused.
    """
    if vocoder is None:
        return None, None
    # The vocoder expects (batch, mel_bins, frames); inputs arrive as
    # (batch, frames, mel_bins), hence the permute.
    wav_reconstruction = vocoder_infer(mel_input.permute(0, 2, 1), vocoder)
    wav_prediction = vocoder_infer(mel_prediction.permute(0, 2, 1), vocoder)
    return wav_reconstruction, wav_prediction
846
+
847
+
848
class AutoencoderKL(nn.Module):
    """Mel-spectrogram VAE (KL-regularized) with an attached HiFi-GAN vocoder.

    Encodes fbank/stft images into a diagonal-Gaussian latent and decodes
    back; `decode_to_waveform` turns decoded mels into audio via the vocoder.
    `lossconfig` and `batchsize` are accepted but unused here.
    """

    def __init__(
        self,
        ddconfig=None,
        lossconfig=None,
        batchsize=None,
        embed_dim=None,
        time_shuffle=1,
        subband=1,
        sampling_rate=16000,
        ckpt_path=None,
        reload_from_ckpt=None,
        ignore_keys=[],  # NOTE(review): mutable default, shared across calls
        image_key="fbank",
        colorize_nlabels=None,
        monitor=None,
        base_learning_rate=1e-5,
        scale_factor=1
    ):
        super().__init__()
        self.automatic_optimization = False
        assert (
            "mel_bins" in ddconfig.keys()
        ), "mel_bins is not specified in the Autoencoder config"
        num_mel = ddconfig["mel_bins"]
        self.image_key = image_key
        self.sampling_rate = sampling_rate
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)

        self.loss = None
        self.subband = int(subband)

        if self.subband > 1:
            print("Use subband decomposition %s" % self.subband)

        # quant_conv maps encoder output to 2*embed_dim (mean + logvar halves);
        # post_quant_conv maps a sampled latent back to decoder channels.
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

        if self.image_key == "fbank":
            # Vocoder is built on CPU and moved later by the caller if needed.
            self.vocoder = get_vocoder(None, "cpu", num_mel)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.learning_rate = float(base_learning_rate)
        # print("Initial learning rate %s" % self.learning_rate)

        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None

        self.feature_cache = None
        self.flag_first_run = True
        self.train_step = 0

        self.logger_save_dir = None
        self.logger_exp_name = None
        # Latents are multiplied by scale_factor on encode and divided on decode.
        self.scale_factor = scale_factor

        print("Num parameters:")
        print("Encoder : ", sum(p.numel() for p in self.encoder.parameters()))
        print("Decoder : ", sum(p.numel() for p in self.decoder.parameters()))
        print("Vocoder : ", sum(p.numel() for p in self.vocoder.parameters()))

    def get_log_dir(self):
        """Directory used for logging artifacts (explicit override or logger default)."""
        if self.logger_save_dir is None and self.logger_exp_name is None:
            return os.path.join(self.logger.save_dir, self.logger._project)
        else:
            return os.path.join(self.logger_save_dir, self.logger_exp_name)

    def set_log_dir(self, save_dir, exp_name):
        """Override the log directory components used by get_log_dir()."""
        self.logger_save_dir = save_dir
        self.logger_exp_name = exp_name

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load weights from a Lightning-style checkpoint, dropping ignored key prefixes.

        Uses strict=False so partial state dicts are accepted.
        """
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode an input image to a DiagonalGaussianDistribution posterior."""
        # x = self.time_shuffle_operation(x)
        # x = self.freq_split_subband(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode a latent z back to the mel/stft image domain."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        # bs, ch, shuffled_timesteps, fbins = dec.size()
        # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins)
        # dec = self.freq_merge_subband(dec)
        return dec

    def decode_to_waveform(self, dec):
        """Convert a decoded spectrogram image into waveforms.

        fbank inputs go through the chunked HiFi-GAN vocoder; stft inputs go
        through `self.wave_decoder` (assigned elsewhere — not set in __init__).
        """
        if self.image_key == "fbank":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = vocoder_chunk_infer(dec, self.vocoder)
        elif self.image_key == "stft":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = self.wave_decoder(dec)
        return wav_reconstruction

    def mel_spectrogram_to_waveform(
        self, mel, savepath=".", bs=None, name="outwav", save=True
    ):
        """Vocode a mel batch directly (no chunking); returns numpy waveforms."""
        # Mel: [bs, 1, t-steps, fbins]
        if len(mel.size()) == 4:
            mel = mel.squeeze(1)
        mel = mel.permute(0, 2, 1)
        waveform = self.vocoder(mel)
        waveform = waveform.cpu().detach().numpy()
        #if save:
        #    self.save_waveform(waveform, savepath, name)
        return waveform

    @torch.no_grad()
    def encode_first_stage(self, x):
        """Gradient-free wrapper around encode()."""
        return self.encode(x)

    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        """Gradient-free decode; un-scales the latent by 1/scale_factor first."""
        if predict_cids:
            # Codebook path (VQ): convert logits to codebook entries.
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, "b h w c -> b c h w").contiguous()

        z = 1.0 / self.scale_factor * z
        return self.decode(z)

    def decode_first_stage_withgrad(self, z):
        """Decode with gradients enabled (for end-to-end fine-tuning)."""
        z = 1.0 / self.scale_factor * z
        return self.decode(z)

    def get_first_stage_encoding(self, encoder_posterior, use_mode=False):
        """Turn an encoder posterior (or raw tensor) into a scaled latent tensor."""
        if isinstance(encoder_posterior, DiagonalGaussianDistribution) and not use_mode:
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, DiagonalGaussianDistribution) and use_mode:
            z = encoder_posterior.mode()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def visualize_latent(self, input):
        """Debug helper: dump latents for time/frequency-masked inputs as .npy + plots."""
        import matplotlib.pyplot as plt

        # for i in range(10):
        #     zero_input = torch.zeros_like(input) - 11.59
        #     zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59

        #     posterior = self.encode(zero_input)
        #     latent = posterior.sample()
        #     avg_latent = torch.mean(latent, dim=1)[0]
        #     plt.imshow(avg_latent.cpu().detach().numpy().T)
        #     plt.savefig("%s.png" % i)
        #     plt.close()

        np.save("input.npy", input.cpu().detach().numpy())
        # zero_input = torch.zeros_like(input) - 11.59
        time_input = input.clone()
        # -11.59 presumably corresponds to silence in log-mel space — TODO confirm.
        time_input[:, :, :, :32] *= 0
        time_input[:, :, :, :32] -= 11.59

        np.save("time_input.npy", time_input.cpu().detach().numpy())

        posterior = self.encode(time_input)
        latent = posterior.sample()
        np.save("time_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("freq_%s.png" % i)
            plt.close()

        freq_input = input.clone()
        freq_input[:, :, :512, :] *= 0
        freq_input[:, :, :512, :] -= 11.59

        np.save("freq_input.npy", freq_input.cpu().detach().numpy())

        posterior = self.encode(freq_input)
        latent = posterior.sample()
        np.save("freq_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("time_%s.png" % i)
            plt.close()

    def get_input(self, batch):
        """Unpack a dataloader batch dict into the tensors this model consumes."""
        fname, text, label_indices, waveform, stft, fbank = (
            batch["fname"],
            batch["text"],
            batch["label_vector"],
            batch["waveform"],
            batch["stft"],
            batch["log_mel_spec"],
        )
        # if(self.time_shuffle != 1):
        #     if(fbank.size(1) % self.time_shuffle != 0):
        #         pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle)
        #         fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len))

        ret = {}

        # Add a channel dimension so spectrograms are image-like (B, 1, T, F).
        ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
            fbank.unsqueeze(1),
            stft.unsqueeze(1),
            fname,
            waveform.unsqueeze(1),
        )

        return ret

    def save_wave(self, batch_wav, fname, save_dir):
        """Write each waveform in the batch to save_dir under its basename."""
        os.makedirs(save_dir, exist_ok=True)

        for wav, name in zip(batch_wav, fname):
            name = os.path.basename(name)

            sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)

    def get_last_layer(self):
        """Weight of the final decoder conv (used by adaptive GAN-loss weighting)."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
        """Log input/reconstruction/sample spectrograms (and audio) to the logger."""
        log = dict()
        x = batch.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            log["samples"] = self.decode(posterior.sample())
            log["reconstructions"] = xrec

        log["inputs"] = x
        wavs = self._log_img(log, train=train, index=0, waveform=waveform)
        return wavs

    def _log_img(self, log, train=True, index=0, waveform=None):
        """Render one batch item as images + audio and push to the logger.

        Returns (wav_original, wav_prediction, wav_samples).
        """
        images_input = self.tensor2numpy(log["inputs"][index, 0]).T
        images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
        images_samples = self.tensor2numpy(log["samples"][index, 0]).T

        if train:
            name = "train"
        else:
            name = "val"

        if self.logger is not None:
            self.logger.log_image(
                "img_%s" % name,
                [images_input, images_reconstruct, images_samples],
                caption=["input", "reconstruct", "samples"],
            )

        inputs, reconstructions, samples = (
            log["inputs"],
            log["reconstructions"],
            log["samples"],
        )

        if self.image_key == "fbank":
            wav_original, wav_prediction = synth_one_sample(
                inputs[index],
                reconstructions[index],
                labels="validation",
                vocoder=self.vocoder,
            )
            wav_original, wav_samples = synth_one_sample(
                inputs[index], samples[index], labels="validation", vocoder=self.vocoder
            )
            wav_original, wav_samples, wav_prediction = (
                wav_original[0],
                wav_samples[0],
                wav_prediction[0],
            )
        elif self.image_key == "stft":
            wav_prediction = (
                self.decode_to_waveform(reconstructions)[index, 0]
                .cpu()
                .detach()
                .numpy()
            )
            wav_samples = (
                self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
            )
            wav_original = waveform[index, 0].cpu().detach().numpy()

        if self.logger is not None:
            self.logger.experiment.log(
                {
                    "original_%s"
                    % name: wandb.Audio(
                        wav_original, caption="original", sample_rate=self.sampling_rate
                    ),
                    "reconstruct_%s"
                    % name: wandb.Audio(
                        wav_prediction,
                        caption="reconstruct",
                        sample_rate=self.sampling_rate,
                    ),
                    "samples_%s"
                    % name: wandb.Audio(
                        wav_samples, caption="samples", sample_rate=self.sampling_rate
                    ),
                }
            )

        return wav_original, wav_prediction, wav_samples

    def tensor2numpy(self, tensor):
        """Detach a tensor and move it to a numpy array on CPU."""
        return tensor.cpu().detach().numpy()

    def to_rgb(self, x):
        """Project a segmentation map to 3 channels via a fixed random conv, scaled to [-1, 1]."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = torch.nn.functional.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x
1188
+
1189
+
1190
class IdentityFirstStage(torch.nn.Module):
    """Pass-through stand-in for a first-stage autoencoder (no-op codec)."""

    def __init__(self, *args, vq_interface=False, **kwargs):
        super().__init__()
        # TODO: Should be true by default but check to not break older stuff
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        """Identity encode."""
        return x

    def decode(self, x, *args, **kwargs):
        """Identity decode."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Identity quantize; mimics a VQ layer's (x, loss, info) tuple when requested."""
        return (x, None, [None, None, None]) if self.vq_interface else x

    def forward(self, x, *args, **kwargs):
        """Identity forward."""
        return x
1208
+
1209
+
1210
def window_sumsquare(
    window,
    n_frames,
    hop_length,
    win_length,
    n_fft,
    dtype=np.float32,
    norm=None,
):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    norm : [optional]
        Normalization mode forwarded to `librosa_util.normalize`.

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    # `size` is keyword-only in librosa >= 0.10; the positional form raised a
    # TypeError there. This also matches the STFT class's pad_center call.
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
    return x
1267
+
1268
def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
    """Log-compress magnitudes: normalize_fun(clamp(x, clip_val) * C).

    PARAMS
    ------
    C: compression factor
    clip_val: lower bound that keeps the log argument strictly positive
    """
    clamped = torch.clamp(x, min=clip_val)
    return normalize_fun(clamped * C)
1275
+
1276
+
1277
def dynamic_range_decompression(x, C=1):
    """Invert dynamic_range_compression: exp(x) then undo the gain C.

    PARAMS
    ------
    C: compression factor used to compress
    """
    decompressed = torch.exp(x)
    return decompressed / C
1284
+
1285
+
1286
class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft

    Implements forward/inverse STFT as 1-D convolutions against a fixed
    (windowed) Fourier basis registered as buffers.
    """

    def __init__(self, filter_length, hop_length, win_length, window="hann"):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        # DFT of the identity yields the Fourier basis rows.
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        # Keep only the non-redundant half; stack real and imaginary parts.
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        # Pseudo-inverse of the scaled basis gives the synthesis filters.
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )

        if window is not None:
            assert filter_length >= win_length
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())

    def transform(self, input_data):
        """Analysis: waveform (B, T) -> (magnitude, phase), each (B, cutoff, frames)."""
        device = self.forward_basis.device
        input_data = input_data.to(device)

        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = torch.nn.functional.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode="reflect",
        )
        input_data = input_data.squeeze(1)

        # Strided conv against the Fourier basis == hop-wise DFT.
        forward_transform = torch.nn.functional.conv1d(
            input_data,
            torch.autograd.Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )#.cpu()

        # First half of output channels are real parts, second half imaginary.
        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Synthesis: (magnitude, phase) -> waveform (B, 1, T) via transposed conv."""
        device = self.forward_basis.device
        magnitude, phase = magnitude.to(device), phase.to(device)

        # Rebuild real/imaginary channels from polar form.
        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )

        inverse_transform = torch.nn.functional.conv_transpose1d(
            recombine_magnitude_phase,
            torch.autograd.Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )

        if self.window is not None:
            # Compensate for the overlap-added squared window envelope.
            window_sum = window_sumsquare(
                self.window,
                magnitude.size(-1),
                hop_length=self.hop_length,
                win_length=self.win_length,
                n_fft=self.filter_length,
                dtype=np.float32,
            )
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]
            )
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False
            )
            window_sum = window_sum  # no-op; kept from upstream
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                approx_nonzero_indices
            ]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        # Strip the reflect-padding added in transform().
        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]

        return inverse_transform

    def forward(self, input_data):
        """Round-trip: analyze then resynthesize the input waveform."""
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
1407
+
1408
+
1409
class TacotronSTFT(torch.nn.Module):
    """Tacotron-style front-end: STFT followed by a mel filterbank and log compression."""

    def __init__(
        self,
        filter_length,
        hop_length,
        win_length,
        n_mel_channels,
        sampling_rate,
        mel_fmin,
        mel_fmax,
    ):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        # Fixed mel filterbank, stored as a buffer so it follows .to(device).
        mel_basis = librosa_mel_fn(
            sr = sampling_rate, n_fft = filter_length, n_mels = n_mel_channels, fmin = mel_fmin, fmax = mel_fmax
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)

    def spectral_normalize(self, magnitudes, normalize_fun):
        """Apply dynamic-range (log) compression to magnitudes."""
        output = dynamic_range_compression(magnitudes, normalize_fun)
        return output

    def spectral_de_normalize(self, magnitudes):
        """Invert spectral_normalize (exp)."""
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y, normalize_fun=torch.log):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert torch.min(y.data) >= -1, torch.min(y.data)
        assert torch.max(y.data) <= 1, torch.max(y.data)

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output, normalize_fun)
        # Per-frame energy (L2 norm over frequency bins).
        energy = torch.norm(magnitudes, dim=1)

        log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)

        return mel_output, log_magnitudes, energy
1460
+
1461
+
1462
def build_pretrained_models(ckpt):
    """Load the pretrained mel-VAE and its STFT front-end from a checkpoint.

    Extracts the `first_stage_model.*` weights and the stored `scale_factor`
    from a Lightning-style checkpoint, instantiates AutoencoderKL with the
    hard-coded 48 kHz configuration below, and returns (vae, fn_STFT), both
    in eval mode.
    """
    checkpoint = torch.load(ckpt, map_location="cpu")
    scale_factor = checkpoint["state_dict"]["scale_factor"].item()
    print("scale_factor: ", scale_factor)

    # Strip the "first_stage_model." prefix (18 characters) from matching keys.
    vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}

    # Hard-coded configuration matching the released 48 kHz checkpoint.
    config = {
        "preprocessing": {
            "audio": {
                "sampling_rate": 48000,
                "max_wav_value": 32768,
                "duration": 10.24
            },
            "stft": {
                "filter_length": 2048,
                "hop_length": 480,
                "win_length": 2048
            },
            "mel": {
                "n_mel_channels": 256,
                "mel_fmin": 20,
                "mel_fmax": 24000
            }
        },
        "model": {
            "params": {
                "first_stage_config": {
                    "params": {
                        "sampling_rate": 48000,
                        "batchsize": 4,
                        "monitor": "val/rec_loss",
                        "image_key": "fbank",
                        "subband": 1,
                        "embed_dim": 16,
                        "time_shuffle": 1,
                        # lossconfig is unused at inference (AutoencoderKL sets loss=None).
                        "lossconfig": {
                            "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
                            "params": {
                                "disc_start": 50001,
                                "kl_weight": 1000,
                                "disc_weight": 0.5,
                                "disc_in_channels": 1
                            }
                        },
                        "ddconfig": {
                            "double_z": True,
                            "mel_bins": 256,
                            "z_channels": 16,
                            "resolution": 256,
                            "downsample_time": False,
                            "in_channels": 1,
                            "out_ch": 1,
                            "ch": 128,
                            "ch_mult": [
                                1,
                                2,
                                4,
                                8
                            ],
                            "num_res_blocks": 2,
                            "attn_resolutions": [],
                            "dropout": 0
                        }
                    }
                },
            }
        }
    }
    vae_config = config["model"]["params"]["first_stage_config"]["params"]
    vae_config["scale_factor"] = scale_factor

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(vae_state_dict)

    fn_STFT = TacotronSTFT(
        config["preprocessing"]["stft"]["filter_length"],
        config["preprocessing"]["stft"]["hop_length"],
        config["preprocessing"]["stft"]["win_length"],
        config["preprocessing"]["mel"]["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        config["preprocessing"]["mel"]["mel_fmin"],
        config["preprocessing"]["mel"]["mel_fmax"],
    )

    vae.eval()
    fn_STFT.eval()
    return vae, fn_STFT
1550
+
1551
+
MuCodec/tools/torch_tools.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ import random
4
+ import itertools
5
+ import numpy as np
6
+
7
+
8
+
9
def normalize_wav(waveform):
    """Zero-center the waveform and rescale its peak amplitude to 0.5."""
    centered = waveform - torch.mean(waveform)
    # Epsilon avoids division by zero for an all-constant input.
    peak = torch.max(torch.abs(centered)) + 1e-8
    return (centered / peak) * 0.5
13
+
14
+
15
def pad_wav(waveform, segment_length):
    """Crop or zero-pad a 1-D waveform to exactly segment_length samples.

    A segment_length of None returns the waveform unchanged.
    """
    current_length = len(waveform)

    if segment_length is None or current_length == segment_length:
        return waveform
    if current_length > segment_length:
        return waveform[:segment_length]
    # Too short: append zeros on the same device.
    tail = torch.zeros(segment_length - current_length).to(waveform.device)
    return torch.cat([waveform, tail])
26
+
27
+
28
+ def _pad_spec(fbank, target_length=1024):
29
+ batch, n_frames, channels = fbank.shape
30
+ p = target_length - n_frames
31
+ if p > 0:
32
+ pad = torch.zeros(batch, p, channels).to(fbank.device)
33
+ fbank = torch.cat([fbank, pad], 1)
34
+ elif p < 0:
35
+ fbank = fbank[:, :target_length, :]
36
+
37
+ if channels % 2 != 0:
38
+ fbank = fbank[:, :, :-1]
39
+
40
+ return fbank
41
+
42
+
43
def read_wav_file(filename, segment_length):
    """Load an audio file, resample to 16 kHz mono, normalize, and pad/crop.

    Returns a (1, segment_length) tensor with peak amplitude 0.5. On a
    normalization failure the file is replaced by a dummy constant signal so
    one bad file does not abort a batch.
    """
    waveform, sr = torchaudio.load(filename)  # Faster!!!
    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0]
    try:
        waveform = normalize_wav(waveform)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit pass through.
        print("Exception normalizing:", filename)
        waveform = torch.ones(160000)
    waveform = pad_wav(waveform, segment_length).unsqueeze(0)
    # Re-normalize after padding; epsilon guards against an all-zero waveform
    # (a plain division produced NaNs in that edge case).
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
    waveform = 0.5 * waveform
    return waveform
55
+
56
+
57
def get_mel_from_wav(audio, _stft):
    """Compute (mel, log-magnitude STFT, energy) from a waveform batch via `_stft`.

    Clips the audio to [-1, 1] and replaces NaNs before feature extraction.
    """
    sanitized = torch.nan_to_num(torch.clip(audio, -1, 1))
    sanitized = torch.autograd.Variable(sanitized, requires_grad=False)
    return _stft.mel_spectrogram(sanitized)
62
+
63
+
64
def wav_to_fbank(paths, target_length=1024, fn_STFT=None):
    """Load wav files and return (fbank, log-magnitude STFT, waveform) batches,
    each spectrogram padded/cropped to target_length frames."""
    assert fn_STFT is not None

    # hop size is 160 samples, so target_length frames span target_length * 160 samples
    segment_samples = target_length * 160
    waveform = torch.cat([read_wav_file(path, segment_samples) for path in paths], 0)

    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
    # (batch, mel, frames) -> (batch, frames, mel) for padding.
    fbank = fbank.transpose(1, 2)
    log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)

    fbank = _pad_spec(fbank, target_length)
    log_magnitudes_stft = _pad_spec(log_magnitudes_stft, target_length)

    return fbank, log_magnitudes_stft, waveform
78
+
79
def wav_to_fbank2(waveform, target_length=-1, fn_STFT=None):
    """Like wav_to_fbank but for an already-loaded waveform batch.

    Padding/cropping to target_length frames is skipped when target_length <= 0.
    """
    assert fn_STFT is not None

    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
    # (batch, mel, frames) -> (batch, frames, mel).
    fbank = fbank.transpose(1, 2)
    log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
    # print(fbank.shape, log_magnitudes_stft.shape)

    if target_length > 0:
        fbank = _pad_spec(fbank, target_length)
        log_magnitudes_stft = _pad_spec(log_magnitudes_stft, target_length)

    return fbank, log_magnitudes_stft, waveform
93
+
94
+
95
def uncapitalize(s):
    """Lower-case only the first character of s; empty/falsy input yields ""."""
    if not s:
        return ""
    return s[:1].lower() + s[1:]
100
+
__pycache__/audio_tokens.cpython-310.pyc ADDED
Binary file (826 Bytes). View file
 
__pycache__/audio_tokens.cpython-312.pyc ADDED
Binary file (951 Bytes). View file
 
__pycache__/condition_encoders.cpython-310.pyc ADDED
Binary file (4.24 kB). View file
 
__pycache__/condition_encoders.cpython-312.pyc ADDED
Binary file (6.85 kB). View file
 
__pycache__/dataset.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
__pycache__/dataset.cpython-312.pyc ADDED
Binary file (21.5 kB). View file
 
__pycache__/decoders.cpython-310.pyc ADDED
Binary file (3.86 kB). View file
 
__pycache__/decoders.cpython-312.pyc ADDED
Binary file (5.77 kB). View file
 
__pycache__/inference_full.cpython-310.pyc ADDED
Binary file (24.5 kB). View file
 
__pycache__/inference_full.cpython-312.pyc ADDED
Binary file (35 kB). View file
 
__pycache__/modelling_qwen3.cpython-310.pyc ADDED
Binary file (5.33 kB). View file
 
__pycache__/modelling_qwen3.cpython-312.pyc ADDED
Binary file (9.53 kB). View file
 
__pycache__/runtime_utils.cpython-310.pyc ADDED
Binary file (3.07 kB). View file
 
__pycache__/runtime_utils.cpython-312.pyc ADDED
Binary file (3.65 kB). View file
 
audio_tokens.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Special-token strings used to delimit and mask audio spans in the
# tokenized text stream (registered on the tokenizer in
# add_audio_special_tokens below).
SOA_TOKEN = "[SOA]"  # start-of-audio marker
EOA_TOKEN = "[EOA]"  # end-of-audio marker
MASK_AUDIO_TOKEN = "[MASK_AUDIO]"  # placeholder for masked audio positions
7
+
8
+
9
def audio_id_to_token(audio_id: int) -> str:
    """Map an integer audio codebook id to its special-token string."""
    return "<AUDIO_{}>".format(int(audio_id))
11
+
12
def add_audio_special_tokens(tokenizer, num_audio_token: int) -> int:
    """Register all audio special tokens on *tokenizer*.

    Adds one ``<AUDIO_i>`` token per codebook entry followed by the mask /
    start / end markers, and returns whatever ``tokenizer.add_tokens``
    reports (the number of tokens actually added).
    """
    extra_tokens = [audio_id_to_token(i) for i in range(num_audio_token)]
    extra_tokens.extend([MASK_AUDIO_TOKEN, SOA_TOKEN, EOA_TOKEN])
    return tokenizer.add_tokens(extra_tokens, special_tokens=True)
batch_infer_checkpoints.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import argparse
5
+ import json
6
+ import traceback
7
+ from pathlib import Path
8
+
9
+ import datasets
10
+ import torch
11
+
12
+ from inference_full import (
13
+ TokenLayout,
14
+ batch_generate_segmentwise,
15
+ build_mucodec_decoder,
16
+ generate_segmentwise,
17
+ load_hf_template_sample_from_music_dataset,
18
+ save_outputs,
19
+ )
20
+ from runtime_utils import (
21
+ load_magel_checkpoint,
22
+ load_music_dataset,
23
+ maybe_compile_model,
24
+ resolve_device,
25
+ seed_everything,
26
+ )
27
+
28
+
29
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for multi-checkpoint batch inference.

    Requires at least one of --checkpoint_list / --checkpoint_dir;
    parser.error() exits with a usage message otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Run audio inference on validation samples for multiple checkpoints."
    )
    parser.add_argument(
        "--checkpoint_list",
        type=str,
        default=None,
        help="Text file with one checkpoint path per line.",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default=None,
        help="Directory to scan for checkpoint-* subdirectories and optional final.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="muse_mucodec_chord.ds",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="validation",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="checkpoints/Qwen3-0.6B",
    )
    parser.add_argument(
        "--sample_indices",
        type=int,
        nargs="*",
        default=None,
        help="Specific sample indices to infer. Leave unset to run the full split.",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=0,
        help="Run only the first N samples from the split. Ignored if --sample_indices is set.",
    )
    parser.add_argument(
        "--infer_batch_size",
        type=int,
        default=1,
        help="Number of samples to decode together per step for the same checkpoint.",
    )
    # Sampling hyper-parameters for the audio-token decoder.
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.90)
    parser.add_argument("--greedy", action="store_true", default=False)
    parser.add_argument("--max_audio_tokens", type=int, default=0)
    parser.add_argument("--fps", type=int, default=25)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        choices=["eager", "sdpa", "flash_attention_2"],
    )
    # NOTE(review): --use_cache defaults to True, so passing the flag is a
    # no-op; caching is effectively disabled only via --no_cache (combined
    # in main() as `use_cache and not no_cache`).
    parser.add_argument("--use_cache", action="store_true", default=True)
    parser.add_argument("--no_cache", action="store_true", default=False)
    parser.add_argument("--compile", action="store_true", default=False)
    parser.add_argument(
        "--compile_mode",
        type=str,
        default="reduce-overhead",
        choices=["default", "reduce-overhead", "max-autotune"],
    )
    # MuCodec decoder settings (forwarded to build_mucodec_decoder).
    parser.add_argument("--mucodec_device", type=str, default="auto")
    parser.add_argument("--mucodec_layer_num", type=int, default=7)
    parser.add_argument("--mucodec_duration", type=float, default=40.96)
    parser.add_argument("--mucodec_guidance_scale", type=float, default=1.5)
    parser.add_argument("--mucodec_num_steps", type=int, default=20)
    parser.add_argument("--mucodec_sample_rate", type=int, default=48000)
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/root/new_batch_predictions",
        help="Root output dir. Each checkpoint gets its own subdirectory.",
    )
    parser.add_argument(
        "--summary_json",
        type=str,
        default="/root/new_batch_predictions/summary.json",
    )
    args = parser.parse_args()
    if not args.checkpoint_list and not args.checkpoint_dir:
        parser.error("one of --checkpoint_list or --checkpoint_dir is required")
    return args
129
+
130
+
131
def parse_checkpoint_list(path: str) -> list[str]:
    """Read checkpoint paths from a text file, one per line.

    Blank lines and lines starting with '#' (after stripping) are skipped.

    Raises:
        ValueError: if the file yields no checkpoint paths.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped = [line.strip() for line in handle]

    checkpoints = [entry for entry in stripped if entry and not entry.startswith("#")]
    if not checkpoints:
        raise ValueError(f"No checkpoints found in list: {path}")
    return checkpoints
142
+
143
+
144
def scan_checkpoint_dir(path: str) -> list[str]:
    """Collect checkpoint directories under *path*.

    Finds every ``checkpoint-*`` subdirectory, sorted by training step
    (numeric suffixes ascending first, then non-numeric suffixes
    lexicographically), and appends an optional ``final`` directory.

    Bug fix: the previous sort key returned ``int`` for numeric suffixes and
    ``str`` otherwise, so a directory mixing ``checkpoint-100`` with e.g.
    ``checkpoint-best`` raised ``TypeError`` inside ``sorted``. A homogeneous
    tuple key is used instead.

    Raises:
        NotADirectoryError: if *path* is not a directory.
        ValueError: if no checkpoint directories are found.
    """
    root = Path(path)
    if not root.is_dir():
        raise NotADirectoryError(f"Checkpoint directory not found: {path}")

    def _step_key(item: Path) -> tuple[int, int, str]:
        # (group, numeric step, name) — comparable for every directory name.
        suffix = item.name.split("-", 1)[1]
        if suffix.isdigit():
            return (0, int(suffix), "")
        return (1, 0, suffix)

    checkpoint_dirs = sorted(
        (
            item
            for item in root.iterdir()
            if item.is_dir() and item.name.startswith("checkpoint-")
        ),
        key=_step_key,
    )

    final_dir = root / "final"
    if final_dir.is_dir():
        checkpoint_dirs.append(final_dir)

    checkpoints = [str(path_obj) for path_obj in checkpoint_dirs]
    if not checkpoints:
        raise ValueError(f"No checkpoint-* directories found under: {path}")
    return checkpoints
169
+
170
+
171
def get_dtype(name: str) -> torch.dtype:
    """Resolve a dtype name string to its torch dtype (KeyError if unknown)."""
    name_to_dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return name_to_dtype[name]
177
+
178
+
179
def get_split_size(dataset_path: str, split: str) -> int:
    """Return the row count of *split* in the dataset saved at *dataset_path*.

    A plain Dataset (no splits) ignores *split*; a DatasetDict raises
    KeyError when the split is missing.
    """
    loaded = datasets.load_from_disk(dataset_path)
    if not isinstance(loaded, datasets.DatasetDict):
        return len(loaded)
    if split not in loaded:
        raise KeyError(f"Split not found: {split}")
    return len(loaded[split])
186
+
187
+
188
def resolve_sample_indices(
    dataset_path: str,
    split: str,
    sample_indices: list[int] | None,
    max_samples: int,
) -> list[int]:
    """Decide which sample indices to run.

    Explicit *sample_indices* win outright; otherwise the whole split is
    used, optionally truncated to the first *max_samples* rows.
    """
    if sample_indices:
        return list(sample_indices)

    count = get_split_size(dataset_path, split)
    if max_samples and max_samples > 0:
        count = min(count, max_samples)
    return list(range(count))
200
+
201
+
202
def sanitize_checkpoint_name(checkpoint_path: str) -> str:
    """Build a flat, filesystem-friendly name '<parent>__<leaf>' for a path.

    A bare leaf (no parent component) is returned unchanged.
    """
    leaf = Path(checkpoint_path.rstrip("/"))
    parent_name = leaf.parent.name
    return f"{parent_name}__{leaf.name}" if parent_name else leaf.name
207
+
208
+
209
def chunk_list(items: list[int], chunk_size: int) -> list[list[int]]:
    """Split *items* into consecutive chunks of at most *chunk_size* elements."""
    chunks: list[list[int]] = []
    for start in range(0, len(items), chunk_size):
        chunks.append(items[start : start + chunk_size])
    return chunks
211
+
212
+
213
def main() -> None:
    """Run segment-wise audio generation on a set of samples for every
    checkpoint, decode tokens to WAV via MuCodec, and write a JSON summary.

    Per checkpoint: load model, decode all requested samples (batched when
    --infer_batch_size > 1, with a single-sample fallback on batch failure),
    save outputs, then free the model before moving on. A failure on one
    checkpoint is recorded in the summary and does not stop the others.
    """
    args = parse_args()
    seed_everything(args.seed)

    # --checkpoint_list takes precedence over --checkpoint_dir.
    if args.checkpoint_list:
        checkpoints = parse_checkpoint_list(args.checkpoint_list)
    else:
        checkpoints = scan_checkpoint_dir(args.checkpoint_dir)
    sample_indices = resolve_sample_indices(
        dataset_path=args.dataset_path,
        split=args.split,
        sample_indices=args.sample_indices,
        max_samples=args.max_samples,
    )

    # --use_cache defaults True, so this reduces to "not --no_cache".
    use_cache = args.use_cache and not args.no_cache
    device = resolve_device(args.device)
    dtype = get_dtype(args.dtype)
    if device.type == "cpu" and dtype != torch.float32:
        print(f"[WARN] dtype {dtype} on CPU may be unsupported; fallback to float32.")
        dtype = torch.float32

    output_root = Path(args.output_dir)
    output_root.mkdir(parents=True, exist_ok=True)

    print(f"[INFO] checkpoints={len(checkpoints)}")
    print(f"[INFO] samples_per_checkpoint={len(sample_indices)}")
    print(f"[INFO] device={device}, dtype={dtype}, use_cache={use_cache}")

    # The MuCodec decoder is checkpoint-independent: build it once.
    mucodec_decoder = build_mucodec_decoder(args)
    summary: list[dict] = []

    for checkpoint_path in checkpoints:
        ckpt_name = sanitize_checkpoint_name(checkpoint_path)
        ckpt_output_dir = output_root / ckpt_name
        json_dir = ckpt_output_dir / "json"
        wav_dir = ckpt_output_dir / "wav"

        print(f"\n[INFO] loading model from {checkpoint_path}")
        model = load_magel_checkpoint(
            checkpoint_path=checkpoint_path,
            device=device,
            dtype=dtype,
            attn_implementation=args.attn_implementation,
        )
        model = maybe_compile_model(
            model,
            enabled=bool(args.compile),
            mode=str(args.compile_mode),
        )
        # Dataset tokenization depends on the checkpoint's audio vocab size,
        # so the dataset is (re)loaded per checkpoint.
        num_audio_codebook = int(getattr(model.config, "magel_num_audio_token", 16384))
        music_ds = load_music_dataset(
            dataset_path=args.dataset_path,
            split=args.split,
            tokenizer_path=args.tokenizer_path,
            num_audio_token=num_audio_codebook,
            use_fast=True,
        )

        checkpoint_record = {
            "checkpoint_path": checkpoint_path,
            "checkpoint_name": ckpt_name,
            "status": "ok",
            "num_samples_requested": len(sample_indices),
            "results": [],
        }

        try:
            for batch_indices in chunk_list(sample_indices, max(1, int(args.infer_batch_size))):
                samples = []
                for sample_idx in batch_indices:
                    print(
                        f"[INFO] checkpoint={ckpt_name} sample_idx={sample_idx} split={args.split}"
                    )
                    samples.append(
                        load_hf_template_sample_from_music_dataset(
                            music_ds=music_ds,
                            sample_idx=sample_idx,
                            num_audio_codebook=num_audio_codebook,
                        )
                    )

                # Layout is taken from the first sample of the batch;
                # presumably num_text_token is identical across samples of
                # one dataset — TODO confirm.
                layout = TokenLayout(
                    num_text_token=samples[0].num_text_token,
                    num_audio_codebook=num_audio_codebook,
                )

                if len(samples) == 1:
                    batch_outputs = [
                        generate_segmentwise(
                            model=model,
                            sample=samples[0],
                            layout=layout,
                            device=device,
                            use_cache=use_cache,
                            temperature=float(args.temperature),
                            top_k=int(args.top_k),
                            top_p=float(args.top_p),
                            greedy=bool(args.greedy),
                            max_audio_tokens=max(0, int(args.max_audio_tokens)),
                        )
                    ]
                else:
                    # Batched decode is best-effort: on any failure, fall back
                    # to one-by-one decoding with identical sampling settings.
                    try:
                        batch_outputs = batch_generate_segmentwise(
                            model=model,
                            samples=samples,
                            layout=layout,
                            device=device,
                            use_cache=use_cache,
                            temperature=float(args.temperature),
                            top_k=int(args.top_k),
                            top_p=float(args.top_p),
                            greedy=bool(args.greedy),
                            max_audio_tokens=max(0, int(args.max_audio_tokens)),
                        )
                    except Exception as exc:
                        print(
                            "[WARN] batch_generate_segmentwise failed; "
                            f"falling back to single-sample decode. error={exc!r}"
                        )
                        traceback.print_exc()
                        batch_outputs = [
                            generate_segmentwise(
                                model=model,
                                sample=sample,
                                layout=layout,
                                device=device,
                                use_cache=use_cache,
                                temperature=float(args.temperature),
                                top_k=int(args.top_k),
                                top_p=float(args.top_p),
                                greedy=bool(args.greedy),
                                max_audio_tokens=max(0, int(args.max_audio_tokens)),
                            )
                            for sample in samples
                        ]

                for sample_idx, sample, batch_output in zip(batch_indices, samples, batch_outputs):
                    generated_ids, sampled_count, sampled_chord_ids, sampled_segment_ids = batch_output
                    prefix = f"{sample_idx:05d}_{sample.song_id}"

                    # save_outputs expects these attributes on args.
                    # (Mutating the parsed namespace per sample is a known wart.)
                    args.sample_idx = sample_idx
                    args.json_output_dir = str(json_dir)
                    args.wav_output_dir = str(wav_dir)

                    save_outputs(
                        output_dir=str(ckpt_output_dir),
                        output_prefix=prefix,
                        sample=sample,
                        layout=layout,
                        generated_ids=generated_ids,
                        sampled_chord_ids=sampled_chord_ids,
                        sampled_segment_ids=sampled_segment_ids,
                        args=args,
                        mucodec_decoder=mucodec_decoder,
                    )

                    # NOTE(review): wav/json paths are reconstructed here, not
                    # returned by save_outputs — verify naming stays in sync.
                    checkpoint_record["results"].append(
                        {
                            "sample_idx": sample_idx,
                            "song_id": sample.song_id,
                            "generated_audio_tokens": sampled_count,
                            "wav_path": str(wav_dir / f"{prefix}.wav"),
                            "json_path": str(json_dir / f"{prefix}.chord_segment.json"),
                        }
                    )
        except Exception as exc:
            # Record the failure but keep iterating remaining checkpoints.
            checkpoint_record["status"] = "error"
            checkpoint_record["error"] = str(exc)
            print(f"[ERROR] checkpoint {checkpoint_path}: {exc!r}")
            traceback.print_exc()

        summary.append(checkpoint_record)

        # Release GPU memory before loading the next checkpoint.
        del model
        if device.type == "cuda":
            torch.cuda.empty_cache()

    summary_path = Path(args.summary_json)
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print(f"\nSaved summary to: {summary_path}")
400
+
401
+ if __name__ == "__main__":
402
+ main()
condition_encoders.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import Optional
5
+
6
+ from vocab import NUM_CHORD_CLASSES, NUM_STRUCTURE_CLASSES
7
+
8
+
9
class EncoderBlock(nn.Module):
    """A small stack of bidirectional (non-causal) Transformer encoder layers.

    Pre-norm, GELU, batch-first; feed-forward width is fixed at 4*d_model.
    The submodule is named ``enc`` — do not rename, state_dict keys depend
    on it.
    """

    def __init__(
        self, d_model: int, n_layers: int = 2, n_heads: int = 8, dropout: float = 0.0
    ):
        super().__init__()
        base_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.enc = nn.TransformerEncoder(base_layer, num_layers=n_layers)

    def forward(
        self, x: torch.Tensor, pad_mask: Optional[torch.BoolTensor] = None
    ) -> torch.Tensor:
        """Encode ``x`` [B,T,D]; ``pad_mask`` [B,T] marks PAD positions with True."""
        encoded = self.enc(x, src_key_padding_mask=pad_mask)
        return encoded
30
+
31
+
32
class ConditionEncoder(nn.Module):
    """
    Condition encoder for AdaLN-Zero:
    Inputs: two aligned sequences
    x_chord: [B,T,D_in]
    x_seg: [B,T,D_in]

    Output:
    cond_expanded: [B,T,H] (feed into your AdaLN layers as cond_expanded)

    What it encodes per token:
    - token-level: chord/segment content at time t
    - position: global position (always) + optional segment-relative position
    - segment context: via x_seg + bidirectional transformer mixing

    Notes:
    - Non-causal (sees future): good for "guidance" conditions.
    - Compute once per sample at generation start; slice per step.
    - Index 0 is the padding id for both embeddings (padding_idx=0).
    """

    def __init__(
        self,
        hidden_size: int,
        chord_embed_dim: int = 512,
        structure_embed_dim: int = 512,
        n_layers: int = 2,
        n_heads: int = 8,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        # Vocabulary sizes come from the shared `vocab` module; id 0 embeds
        # to a frozen zero vector (padding_idx).
        self.chord_embedding = nn.Embedding(
            NUM_CHORD_CLASSES, chord_embed_dim, padding_idx=0
        )
        self.structure_embedding = nn.Embedding(
            NUM_STRUCTURE_CLASSES, structure_embed_dim, padding_idx=0
        )

        # Chord and structure embeddings are concatenated, then projected
        # down to the model hidden size.
        self.cond_dim = chord_embed_dim + structure_embed_dim
        self.cond_proj = nn.Linear(self.cond_dim, hidden_size)

        # Small bidirectional transformer
        self.encoder = EncoderBlock(
            d_model=hidden_size, n_layers=n_layers, n_heads=n_heads, dropout=dropout
        )

        self.proj_out = nn.Linear(hidden_size, hidden_size)

    @staticmethod
    def _sincos_pos(
        positions: torch.Tensor, dim: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        positions: [B, T], absolute positions (0..T-1)
        returns: [B, T, dim] sinusoidal positional encoding

        Computed in float32 and cast to *dtype* at the end for numeric
        stability. For odd *dim*, the last channel is left at zero.
        """
        if dim <= 0:
            raise ValueError("dim must be > 0 for positional encoding.")

        half = dim // 2
        if half == 0:
            # dim == 1: no sin/cos pair fits; return all-zero encoding.
            return torch.zeros(
                positions.size(0),
                positions.size(1),
                dim,
                device=positions.device,
                dtype=dtype,
            )

        pos = positions.to(dtype=torch.float32)
        # Geometric frequency ladder, 1 .. 1/10000, as in "Attention Is All
        # You Need".
        freqs = torch.exp(
            -math.log(10000.0)
            * torch.arange(half, device=positions.device, dtype=torch.float32)
            / half
        )
        angles = pos.unsqueeze(-1) * freqs  # [B, T, half]

        enc = torch.zeros(
            positions.size(0),
            positions.size(1),
            dim,
            device=positions.device,
            dtype=torch.float32,
        )
        # Interleave: sin on even channels, cos on odd channels.
        enc[..., 0 : 2 * half : 2] = torch.sin(angles)
        enc[..., 1 : 2 * half : 2] = torch.cos(angles)

        return enc.to(dtype=dtype)

    def forward(
        self,
        chord_ids: torch.Tensor,  # [B, T]
        structure_ids: torch.Tensor,  # [B, T]
    ) -> torch.Tensor:
        """Encode aligned chord/structure id sequences into [B, T, hidden_size]."""

        chord_emb = self.chord_embedding(chord_ids)  # [B, T, chord_dim]
        structure_emb = self.structure_embedding(structure_ids)  # [B, T, struct_dim]

        cond = torch.cat([chord_emb, structure_emb], dim=-1)
        cond = self.cond_proj(cond)

        # Encoder attention mask is computed separately from condition content.
        # True means this token can be attended by the condition encoder.
        # A position is valid if either id is non-padding.
        valid_tokens = chord_ids.ne(0) | structure_ids.ne(0)
        pad_mask = ~valid_tokens

        # Position ids are contiguous only on valid condition-id tokens.
        # (Padding positions are forced back to position 0.)
        pos = valid_tokens.to(torch.long).cumsum(dim=1) - 1
        pos = torch.where(valid_tokens, pos, torch.zeros_like(pos))

        # Positional encoding is added only at valid positions; padding rows
        # stay at their (zero) content embedding.
        pos_enc = self._sincos_pos(pos, self.hidden_size, cond.dtype)
        valid_mask = valid_tokens.unsqueeze(-1)
        cond = cond + pos_enc * valid_mask.to(dtype=cond.dtype)
        encoded = self.encoder(cond, pad_mask=pad_mask)

        # [B, T, hidden_size]
        return self.proj_out(encoded)
dataset.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ """Dataset/collate implementation for music training data."""
4
+
5
+ import math
6
+ import re
7
+
8
+ import torch
9
+ from torch.nn.utils.rnn import pad_sequence
10
+ from transformers import AutoTokenizer
11
+
12
+ from audio_tokens import (
13
+ EOA_TOKEN,
14
+ MASK_AUDIO_TOKEN,
15
+ SOA_TOKEN,
16
+ add_audio_special_tokens,
17
+ audio_id_to_token,
18
+ )
19
+ from vocab import (
20
+ CHORD_BOS_ID,
21
+ CHORD_EOS_ID,
22
+ STRUCTURE_BOS_ID,
23
+ STRUCTURE_EOS_ID,
24
+ build_frame_chord_ids,
25
+ build_frame_structure_ids,
26
+ normalize_structure_label,
27
+ )
28
+
29
+
30
# Language tags (lower-cased) that trigger Chinese-specific lyric handling
# in detect_language().
CN_LANGUAGE_LABELS = {"cn", "zh", "zh-cn", "chinese"}
# Display names for normalized structure labels (keys produced by
# normalize_structure_label).
SECTION_NAME_MAP = {
    "intro": "Intro",
    "verse": "Verse",
    "chorus": "Chorus",
    "prechorus": "Pre-Chorus",
    "bridge": "Bridge",
    "outro": "Outro",
    "pad": "Pad",
}
# Sections that normally occur once per song; their first occurrence is
# rendered without a number.
SINGLETON_SECTION_NAMES = {"intro", "outro", "pad"}
# Characters accepted as sentence-final punctuation (ASCII + full-width);
# normalize_section_text appends ';' when the text ends with anything else.
ENDING_PUNCTUATION = {".", ";", "!", "?", "。", "?", "!", ";"}
42
+
43
+
44
+ def _pad_batch_field(batch, key: str, padding_value):
45
+ return pad_sequence(
46
+ [row[key] for row in batch],
47
+ batch_first=True,
48
+ padding_value=padding_value,
49
+ )
50
+
51
+
52
+ def detect_language(text: str, language: str | None = None) -> str:
53
+ return (
54
+ text.replace(" ", ";")
55
+ if str(language).strip().lower() in CN_LANGUAGE_LABELS
56
+ else text
57
+ )
58
+
59
+
60
+ def normalize_section_text(
61
+ text: str, structure: str, language: str | None = None
62
+ ) -> str:
63
+ text = str(text or "")
64
+ text = (
65
+ text.replace(f"[{structure.upper()}]", "")
66
+ .replace(f"[{structure.lower()}]", "")
67
+ .replace(",", ";")
68
+ .replace(".", ";")
69
+ .replace(",", ";")
70
+ .replace("。", ";")
71
+ )
72
+ text = detect_language(text, language=language)
73
+ text = re.sub(r";(?=[A-Za-z])", "; ", text)
74
+ if text and text[-1] not in ENDING_PUNCTUATION:
75
+ text += ";"
76
+ return text
77
+
78
+
79
class DataCollate:
    """Collate callable: pads each per-sample field to a rectangular batch."""

    def __call__(self, batch):
        token_ids = _pad_batch_field(batch, "token_ids", 0)
        return {
            "input_ids": token_ids,
            # Labels alias input_ids; loss masking is applied downstream.
            "labels": token_ids,
            "masks": _pad_batch_field(batch, "mask", 0),
            "attention_mask": _pad_batch_field(batch, "attention_mask", 0),
            "chord_ids": _pad_batch_field(batch, "chord_ids", 0),
            "structure_ids": _pad_batch_field(batch, "structure_ids", 0),
            "condition_mask": _pad_batch_field(batch, "condition_mask", False),
        }
98
+
99
+
100
+ class MusicDataset(torch.utils.data.Dataset):
101
+ """Fly dataset with music-code tokens and section-conditioned text."""
102
+
103
+ def __init__(
104
+ self,
105
+ datasets,
106
+ split: str,
107
+ tokenizer_path: str,
108
+ num_audio_token=16384,
109
+ fps=25,
110
+ use_fast=True,
111
+ ):
112
+ self._data = datasets[split]
113
+ self.tokenizer_path = tokenizer_path
114
+ self.use_fast = use_fast
115
+ self.num_audio_token = num_audio_token
116
+ self.fps = fps
117
+
118
+ self.tokenizer = AutoTokenizer.from_pretrained(
119
+ self.tokenizer_path,
120
+ local_files_only=True,
121
+ use_fast=self.use_fast,
122
+ )
123
+ add_audio_special_tokens(self.tokenizer, self.num_audio_token)
124
+ self.tokenizer_vocab_size = len(self.tokenizer)
125
+
126
+ self.audio_prefix_length = int(
127
+ self.tokenizer.convert_tokens_to_ids(audio_id_to_token(0))
128
+ )
129
+ self.num_text_token = self.audio_prefix_length
130
+ self.MASK_AUDIO = int(self.tokenizer.convert_tokens_to_ids(MASK_AUDIO_TOKEN))
131
+ self.BOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(SOA_TOKEN))
132
+ self.EOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(EOA_TOKEN))
133
+ self._assistant_audio_placeholder = f"{SOA_TOKEN}{EOA_TOKEN}"
134
+ self._chat_template_kwargs = {"enable_thinking": False}
135
+
136
+ def __len__(self):
137
+ return len(self._data)
138
+
139
+ @staticmethod
140
+ def _positions(token_ids: torch.Tensor, target_id: int) -> torch.Tensor:
141
+ return torch.nonzero(token_ids == target_id, as_tuple=False).squeeze(-1)
142
+
143
+ @staticmethod
144
+ def _sorted_sections(sample: dict) -> list[dict]:
145
+ return sorted(
146
+ (
147
+ {
148
+ "raw_index": raw_index,
149
+ "text": str(seg["text"]),
150
+ "desc": str(seg["desc"]).strip(),
151
+ "start": float(seg["start"]),
152
+ "end": float(seg["end"]),
153
+ "structure": normalize_structure_label(seg["section"]),
154
+ }
155
+ for raw_index, seg in enumerate(sample.get("sections", []))
156
+ ),
157
+ key=lambda seg: (seg["start"], seg["end"], seg["raw_index"]),
158
+ )
159
+
160
+ @staticmethod
161
+ def _sorted_chords(sample: dict) -> list[dict]:
162
+ return sorted(
163
+ (
164
+ {
165
+ "raw_index": raw_index,
166
+ "type": str(seg.get("type")),
167
+ "start": float(seg.get("start", 0.0)),
168
+ "end": float(seg.get("end", 0.0)),
169
+ }
170
+ for raw_index, seg in enumerate(sample.get("chords", []))
171
+ ),
172
+ key=lambda seg: (seg["start"], seg["end"], seg["raw_index"]),
173
+ )
174
+
175
+ def __getitem__(self, idx):
176
+ sample = self._data[idx]
177
+ sections = self._prepare_sections(sample)
178
+ chords = self._prepare_chords(sample)
179
+ token_ids, attention_mask, frame_idx_map = self._tokenize_messages(
180
+ self._build_messages(sample, sections),
181
+ sample["mucodec_codes"],
182
+ sections,
183
+ )
184
+
185
+ total_frames = len(sample["mucodec_codes"])
186
+ structure_ids = build_frame_structure_ids(sections, total_frames, fps=self.fps)
187
+ chord_ids = build_frame_chord_ids(chords, total_frames, fps=self.fps)
188
+
189
+ structure_ids = torch.from_numpy(structure_ids)
190
+ chord_ids = torch.from_numpy(chord_ids)
191
+
192
+ (
193
+ audio_codebook_mask,
194
+ bos_audio_mask,
195
+ eos_mask,
196
+ label_mask,
197
+ condition_mask,
198
+ ) = self._build_token_masks(token_ids)
199
+
200
+ chord_ids_aligned, structure_ids_aligned = self._align_condition_ids(
201
+ token_ids=token_ids,
202
+ frame_idx_map=frame_idx_map,
203
+ total_frames=total_frames,
204
+ chord_ids=chord_ids,
205
+ structure_ids=structure_ids,
206
+ audio_codebook_mask=audio_codebook_mask,
207
+ bos_audio_mask=bos_audio_mask,
208
+ eos_mask=eos_mask,
209
+ )
210
+
211
+ return {
212
+ "token_ids": token_ids,
213
+ "mask": label_mask,
214
+ "attention_mask": attention_mask,
215
+ "chord_ids": chord_ids_aligned,
216
+ "structure_ids": structure_ids_aligned,
217
+ "condition_mask": condition_mask,
218
+ }
219
+
220
+ def _tokenize_messages(
221
+ self,
222
+ messages: list[dict[str, str]],
223
+ full_audio_codes,
224
+ sections: list[dict],
225
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
226
+
227
+ chat_inputs = self.tokenizer.apply_chat_template(
228
+ messages,
229
+ tokenize=True,
230
+ add_generation_prompt=False,
231
+ return_tensors="pt",
232
+ return_dict=True,
233
+ **self._chat_template_kwargs,
234
+ )
235
+
236
+ token_ids = chat_inputs["input_ids"]
237
+ attention_mask = chat_inputs["attention_mask"]
238
+
239
+ token_ids = token_ids.squeeze(0)
240
+ attention_mask = attention_mask.squeeze(0)
241
+
242
+ token_ids = token_ids.to(torch.long)
243
+ attention_mask = attention_mask.to(torch.long)
244
+
245
+ return self._expand_audio_tokens(
246
+ token_ids=token_ids,
247
+ attention_mask=attention_mask,
248
+ full_audio_codes=full_audio_codes,
249
+ sections=sections,
250
+ )
251
+
252
+ def _frame_bounds(
253
+ self,
254
+ start: float,
255
+ end: float,
256
+ total_frames: int,
257
+ prev_end_idx: int = 0,
258
+ ) -> tuple[int, int]:
259
+ start_idx = int(start * self.fps)
260
+ end_idx = int(math.ceil(end * self.fps))
261
+ start_idx = max(prev_end_idx, min(total_frames, start_idx))
262
+ end_idx = max(start_idx, min(total_frames, end_idx))
263
+
264
+ return start_idx, end_idx
265
+
266
+ def _prepare_sections(self, sample: dict) -> list[dict]:
267
+ sections = []
268
+ section_counts: dict[str, int] = {}
269
+ sample_language = sample.get("language")
270
+ total_frames = len(sample["mucodec_codes"])
271
+ prev_end_idx = 0
272
+
273
+ for seg in self._sorted_sections(sample):
274
+ structure = seg["structure"]
275
+ section_counts[structure] = section_counts.get(structure, 0) + 1
276
+ raw_start_idx = max(0, min(total_frames, int(seg["start"] * self.fps)))
277
+ raw_end_idx = max(
278
+ raw_start_idx,
279
+ min(total_frames, int(math.ceil(seg["end"] * self.fps))),
280
+ )
281
+ start_idx = prev_end_idx
282
+ end_idx = max(start_idx, raw_end_idx)
283
+
284
+ sections.append(
285
+ {
286
+ "text": normalize_section_text(
287
+ seg["text"], structure, language=sample_language
288
+ ),
289
+ "desc": seg["desc"],
290
+ "start": start_idx / float(self.fps),
291
+ "end": end_idx / float(self.fps),
292
+ "start_frame": start_idx,
293
+ "end_frame": end_idx,
294
+ "structure": structure,
295
+ "tag": f"{structure}{section_counts[structure]}",
296
+ "index": section_counts[structure],
297
+ }
298
+ )
299
+ prev_end_idx = end_idx
300
+
301
+ if sections:
302
+ sections[-1]["end_frame"] = total_frames
303
+ sections[-1]["end"] = total_frames / float(self.fps)
304
+
305
+ return sections
306
+
307
+ def _prepare_chords(self, sample: dict) -> list[dict]:
308
+ chords = []
309
+ total_frames = len(sample["mucodec_codes"])
310
+ prev_end_idx = 0
311
+
312
+ for seg in self._sorted_chords(sample):
313
+ start_idx, end_idx = self._frame_bounds(
314
+ seg["start"],
315
+ seg["end"],
316
+ total_frames,
317
+ prev_end_idx=prev_end_idx,
318
+ )
319
+
320
+ chords.append(
321
+ {
322
+ "type": seg["type"],
323
+ "start": start_idx / float(self.fps),
324
+ "end": end_idx / float(self.fps),
325
+ "start_frame": start_idx,
326
+ "end_frame": end_idx,
327
+ }
328
+ )
329
+ prev_end_idx = end_idx
330
+
331
+ return chords
332
+
333
+ def _format_section_label(self, section: dict) -> str:
334
+ structure = section["structure"]
335
+ index = section["index"]
336
+ label = SECTION_NAME_MAP[structure]
337
+ if structure in SINGLETON_SECTION_NAMES and index == 1:
338
+ return label
339
+ return f"{label} {index}"
340
+
341
+ def _build_section_user_content(
342
+ self, sample: dict, section: dict, is_first_turn: bool
343
+ ) -> str:
344
+ parts = []
345
+ if is_first_turn:
346
+ style = sample["style"].strip()
347
+ if style:
348
+ parts.append(
349
+ f"Please generate a song in the following style:{style}\n"
350
+ "Next, I will tell you the requirements and lyrics for the song "
351
+ "fragment to be generated, section by section."
352
+ )
353
+ else:
354
+ parts.append(
355
+ "Please generate the song section by section. "
356
+ "Next, I will tell you the requirements and lyrics for each fragment."
357
+ )
358
+
359
+ section_parts = [f"[{self._format_section_label(section)}]"]
360
+ desc = section["desc"]
361
+ if desc:
362
+ section_parts.append(f"[desc:{desc}]")
363
+
364
+ lyrics = section["text"]
365
+ if lyrics:
366
+ section_parts.append(f"[lyrics:{lyrics}]")
367
+
368
+ parts.append("".join(section_parts))
369
+
370
+ return "\n".join(part for part in parts if part)
371
+
372
+ def _build_messages(
373
+ self,
374
+ sample: dict,
375
+ sections: list[dict],
376
+ ) -> list[dict[str, str]]:
377
+ messages: list[dict[str, str]] = [None] * (2 * len(sections))
378
+
379
+ for i, section in enumerate(sections):
380
+ msg_idx = 2 * i
381
+ messages[msg_idx] = {
382
+ "role": "user",
383
+ "content": self._build_section_user_content(
384
+ sample, section, is_first_turn=(i == 0)
385
+ ),
386
+ }
387
+ messages[msg_idx + 1] = {
388
+ "role": "assistant",
389
+ "content": self._assistant_audio_placeholder,
390
+ }
391
+
392
+ return messages
393
+
394
    def _expand_audio_tokens(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        full_audio_codes,
        sections: list[dict],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Splice per-section audio codes into the token scaffold.

        For each section, the audio frames [start_frame, end_frame) are taken
        from ``full_audio_codes``, shifted by ``self.audio_prefix_length`` into
        the audio-codebook id range, and inserted between that section's
        BOS_AUDIO and EOS_AUDIO markers. A frame-index map is built so later
        stages can recover which source frame each expanded position came from
        (-1 for non-audio positions; BOS carries the first frame, EOS the last).

        Returns:
            (expanded_token_ids, expanded_attention_mask, frame_idx_map),
            all 1-D long tensors of the same expanded length.
        """

        if not sections:
            # Nothing to insert: pass inputs through with an all -1 frame map.
            return (
                token_ids,
                attention_mask,
                torch.full(token_ids.shape, -1, dtype=torch.long),
            )

        # Assumes one BOS/EOS marker pair per section, in order — TODO confirm
        # this invariant holds for every sample upstream.
        bos_positions = self._positions(token_ids, self.BOS_AUDIO)
        eos_positions = self._positions(token_ids, self.EOS_AUDIO)

        audio_code_tensor = torch.as_tensor(full_audio_codes, dtype=torch.long)
        extra_audio_tokens = sum(
            int(section["end_frame"]) - int(section["start_frame"])
            for section in sections
        )
        final_len = token_ids.numel() + extra_audio_tokens

        expanded_token_ids = torch.empty(final_len, dtype=torch.long)
        expanded_attention_mask = torch.empty(final_len, dtype=torch.long)
        frame_idx_map = torch.full((final_len,), -1, dtype=torch.long)

        read_pos = 0   # cursor into the original token_ids
        write_pos = 0  # cursor into the expanded buffers

        for bos_pos, eos_pos, section in zip(
            bos_positions.tolist(), eos_positions.tolist(), sections
        ):
            start_idx = int(section["start_frame"])
            end_idx = int(section["end_frame"])
            audio_len = end_idx - start_idx

            # Copy everything up to and including the BOS marker verbatim.
            prefix_len = bos_pos + 1 - read_pos
            next_write = write_pos + prefix_len
            expanded_token_ids[write_pos:next_write] = token_ids[read_pos : bos_pos + 1]
            expanded_attention_mask[write_pos:next_write] = attention_mask[
                read_pos : bos_pos + 1
            ]
            # The BOS position carries the section's first frame index (if any).
            frame_idx_map[next_write - 1] = start_idx if audio_len > 0 else -1
            write_pos = next_write

            if audio_len > 0:
                # Insert the section's audio codes, shifted into codebook range,
                # with attention enabled and a dense frame-index mapping.
                next_write = write_pos + audio_len
                expanded_token_ids[write_pos:next_write] = audio_code_tensor[
                    start_idx:end_idx
                ]
                expanded_token_ids[write_pos:next_write].add_(self.audio_prefix_length)
                expanded_attention_mask[write_pos:next_write] = 1
                frame_idx_map[write_pos:next_write] = torch.arange(
                    start_idx, end_idx, dtype=torch.long
                )
                write_pos = next_write

            # The EOS position carries the section's last frame index.
            expanded_token_ids[write_pos] = token_ids[eos_pos]
            expanded_attention_mask[write_pos] = attention_mask[eos_pos]
            frame_idx_map[write_pos] = end_idx - 1 if audio_len > 0 else -1
            write_pos += 1
            read_pos = eos_pos + 1

        # Copy any trailing tokens after the final EOS marker.
        tail_len = token_ids.numel() - read_pos
        if tail_len > 0:
            expanded_token_ids[write_pos : write_pos + tail_len] = token_ids[read_pos:]
            expanded_attention_mask[write_pos : write_pos + tail_len] = attention_mask[
                read_pos:
            ]

        return expanded_token_ids, expanded_attention_mask, frame_idx_map
468
+
469
+ def _build_token_masks(
470
+ self, token_ids: torch.Tensor
471
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
472
+
473
+ audio_codebook_mask = (token_ids >= self.audio_prefix_length) & (
474
+ token_ids < self.MASK_AUDIO
475
+ )
476
+ bos_audio_mask = token_ids == self.BOS_AUDIO
477
+ eos_mask = token_ids == self.EOS_AUDIO
478
+ label_mask = (audio_codebook_mask | eos_mask).long()
479
+ condition_mask = audio_codebook_mask | bos_audio_mask | eos_mask
480
+
481
+ return audio_codebook_mask, bos_audio_mask, eos_mask, label_mask, condition_mask
482
+
483
+ def _align_condition_ids(
484
+ self,
485
+ token_ids: torch.Tensor,
486
+ frame_idx_map: torch.Tensor,
487
+ total_frames: int,
488
+ chord_ids: torch.Tensor,
489
+ structure_ids: torch.Tensor,
490
+ audio_codebook_mask: torch.Tensor,
491
+ bos_audio_mask: torch.Tensor,
492
+ eos_mask: torch.Tensor,
493
+ ) -> tuple[torch.Tensor, torch.Tensor]:
494
+
495
+ seq_len = token_ids.numel()
496
+ chord_ids_aligned = torch.zeros(seq_len, dtype=torch.long)
497
+ structure_ids_aligned = torch.zeros(seq_len, dtype=torch.long)
498
+
499
+ bos_positions = torch.nonzero(bos_audio_mask, as_tuple=False).squeeze(-1)
500
+ chord_ids_aligned[bos_positions] = CHORD_BOS_ID
501
+ structure_ids_aligned[bos_positions] = STRUCTURE_BOS_ID
502
+
503
+ audio_positions = torch.nonzero(audio_codebook_mask, as_tuple=False).squeeze(-1)
504
+ cur_frame_idx = frame_idx_map[audio_positions]
505
+ cur_frame_idx = cur_frame_idx.clamp(0, max(total_frames - 1, 0))
506
+ chord_ids_aligned[audio_positions] = chord_ids[cur_frame_idx]
507
+ structure_ids_aligned[audio_positions] = structure_ids[cur_frame_idx]
508
+
509
+ eos_positions = torch.nonzero(eos_mask, as_tuple=False).squeeze(-1)
510
+ chord_ids_aligned[eos_positions] = CHORD_EOS_ID
511
+ structure_ids_aligned[eos_positions] = STRUCTURE_EOS_ID
512
+
513
+ return chord_ids_aligned, structure_ids_aligned
decoders.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Optional, Tuple
3
+ from torch import nn as nn
4
+
5
+ from transformers.cache_utils import Cache
6
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
7
+
8
+
9
class AdaLN(nn.Module):
    """DiT-style adaptive LayerNorm modulation head.

    Projects a conditioning token into six per-position modulation tensors:
    (shift, scale, gate) for the attention branch and for the MLP branch.
    With ``zero_init=True`` the projection starts at exactly zero, so a
    modulated layer initially reproduces the unconditioned base model.
    """

    def __init__(
        self,
        hidden_size: int,
        cond_dim: int,
        zero_init: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size

        self.act = nn.SiLU()
        self.linear = nn.Linear(cond_dim, 6 * hidden_size, bias=True)

        if zero_init:
            # Step-0 identity: every shift/scale/gate is exactly zero.
            nn.init.zeros_(self.linear.weight)
            nn.init.zeros_(self.linear.bias)

    def forward(self, cond_token: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Map ``cond_token`` [B, T, cond_dim] to six [B, T, H] tensors.

        Order: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp.
        """
        modulation = self.linear(self.act(cond_token))  # [B, T, 6H]
        return modulation.chunk(6, dim=-1)
50
+
51
+
52
def apply_adaln(
    x_norm: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    """Affine-modulate a normalized activation: ``x * (1 + scale) + shift``."""
    return shift + x_norm * (1.0 + scale)
57
+
58
+
59
class Qwen3DecoderLayerAdaLN(Qwen3DecoderLayer):
    """
    Qwen3 decoder layer with AdaLN injection:
    - Modulate normalized input with (shift, scale) on masked positions.
    - IMPORTANT: gate must preserve base behavior at gate=0:
        out = out_base * (1 + gate)   (on masked positions)
      so that when gate==0, out==out_base.

    Only applied on audio positions (condition_mask==True).
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        cond_dim: int,
        zero_init: bool = True,
    ):
        super().__init__(config, layer_idx)

        # Zero-initialized by default so the layer starts out exactly equal to
        # a plain Qwen3 decoder layer (see AdaLN docstring).
        self.dit_adaln = AdaLN(
            hidden_size=config.hidden_size,
            cond_dim=cond_dim,
            zero_init=zero_init,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        cond_expanded: Optional[torch.Tensor] = None,  # [B, T, cond_dim]
        condition_mask: Optional[torch.BoolTensor] = None,  # [B, T]
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ):
        """Run one decoder layer, AdaLN-modulated only where condition_mask is True.

        When either ``cond_expanded`` or ``condition_mask`` is None, the layer
        behaves exactly like the base Qwen3 decoder layer.
        """
        # Keep the condition path fully tensor-based; avoid .item() checks that
        # can force GPU-CPU synchronization in autoregressive decoding.
        do_cond = (cond_expanded is not None) and (condition_mask is not None)

        if do_cond:
            (
                shift_msa,
                scale_msa,
                gate_msa,
                shift_mlp,
                scale_mlp,
                gate_mlp,
            ) = self.dit_adaln(cond_expanded)
            mask_expanded = condition_mask.unsqueeze(-1)  # [B, T, 1]

        # ---- Self-Attention branch ----
        residual = hidden_states
        x_norm = self.input_layernorm(hidden_states)  # RMSNorm in Qwen3

        if do_cond:
            # Shift/scale only on masked positions; others see the plain norm.
            x_mod = apply_adaln(x_norm, shift_msa, scale_msa)
            x_in = torch.where(mask_expanded, x_mod, x_norm)
        else:
            x_in = x_norm

        attn_out, _ = self.self_attn(
            hidden_states=x_in,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )

        if do_cond:
            # Preserve base when gate==0: attn_out_audio = (1 + gate) * attn_out_base
            attn_out = torch.where(mask_expanded, (1.0 + gate_msa) * attn_out, attn_out)

        hidden_states = residual + attn_out

        # ---- MLP branch ----
        residual = hidden_states
        x_norm = self.post_attention_layernorm(hidden_states)

        if do_cond:
            x_mod = apply_adaln(x_norm, shift_mlp, scale_mlp)
            x_in = torch.where(mask_expanded, x_mod, x_norm)
        else:
            x_in = x_norm

        mlp_out = self.mlp(x_in)

        if do_cond:
            # Preserve base when gate==0
            mlp_out = torch.where(mask_expanded, (1.0 + gate_mlp) * mlp_out, mlp_out)

        hidden_states = residual + mlp_out

        return hidden_states
inference_full.py ADDED
@@ -0,0 +1,1084 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ HF-driven inference for MAGEL with segment-level autoregressive generation.
5
+
6
+ Uses from HF sample:
7
+ - text instruction/template tokens (token_ids scaffold)
8
+ - control tokens: chord_ids/structure_ids
9
+
10
+ Does NOT use:
11
+ - ground-truth audio token values as input (audio codebook positions are masked)
12
+ """
13
+
14
+ import argparse
15
+ import contextlib
16
+ import importlib
17
+ import json
18
+ import os
19
+ import sys
20
+ from dataclasses import dataclass
21
+ from datetime import datetime
22
+ from pathlib import Path
23
+ from typing import Any, Optional
24
+
25
+ import numpy as np
26
+ import torch
27
+
28
+ from runtime_utils import (
29
+ load_magel_checkpoint,
30
+ load_music_dataset,
31
+ maybe_compile_model,
32
+ maybe_mark_compile_step_begin,
33
+ resolve_device,
34
+ seed_everything,
35
+ )
36
+ from vocab import (
37
+ CHORD_BOS_ID,
38
+ CHORD_EOS_ID,
39
+ STRUCTURE_EOS_ID,
40
+ chord_id_to_label,
41
+ structure_id_to_label,
42
+ )
43
+ from modelling_qwen3 import MAGEL
44
+
45
+ REPO_ROOT = Path(__file__).resolve().parent
46
+ MUCODEC_ROOT = REPO_ROOT / "MuCodec"
47
+
48
+
49
@dataclass
class TokenLayout:
    """Vocabulary layout: text ids first, then the audio codebook, then specials.

    Special ids after the codebook: MASK_AUDIO, BOS_AUDIO, EOS_AUDIO.
    """

    num_text_token: int
    num_audio_codebook: int = 16384

    @property
    def audio_start(self) -> int:
        # First audio-codebook id sits right after the text vocabulary.
        return self.num_text_token

    @property
    def audio_end(self) -> int:
        # One past the last audio-codebook id.
        return self.audio_start + self.num_audio_codebook

    @property
    def mask_audio(self) -> int:
        return self.audio_end

    @property
    def bos_audio(self) -> int:
        return self.audio_end + 1

    @property
    def eos_audio(self) -> int:
        return self.audio_end + 2
73
+
74
+
75
@dataclass
class SegmentSpan:
    """Location of one audio segment inside the token scaffold."""

    seg_idx: int  # 0-based segment index within the sample
    bos_pos: int  # position of the BOS_AUDIO marker
    eos_pos: int  # position of the matching EOS_AUDIO marker
    audio_positions: list[int]  # audio-codebook slot positions between BOS and EOS
81
+
82
+
83
@dataclass
class HFTemplateSample:
    """One HF dataset row prepared as an inference scaffold.

    Ground-truth audio token values are removed from ``input_ids`` (replaced
    with MASK_AUDIO); controls and masks stay aligned to the token sequence.
    """

    song_id: str
    num_text_token: int
    template_ids: torch.Tensor  # [T], original token_ids
    input_ids: torch.Tensor  # [T], audio codebook replaced with MASK_AUDIO
    chord_ids: torch.Tensor  # [T]
    structure_ids: torch.Tensor  # [T]
    condition_mask: torch.Tensor  # [T]
    is_audio_codebook: torch.Tensor  # [T]
    is_eos: torch.Tensor  # [T]
    segments: list[SegmentSpan]
    raw_item: dict[str, Any]

    @property
    def seq_len(self) -> int:
        """Scaffold length T, in tokens."""
        return int(self.input_ids.numel())
100
+
101
+
102
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for segment-wise AR generation plus MuCodec decoding."""
    parser = argparse.ArgumentParser(
        description="Segment-wise AR generation from HF controls/scaffold."
    )
    # --- model / data selection ---
    parser.add_argument(
        "--model_path",
        type=str,
        default="./output_qwen3_0p6b_train/final",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="muse_mucodec_chord.ds",
    )
    parser.add_argument("--split", type=str, default="validation")
    parser.add_argument("--sample_idx", type=int, default=0)
    parser.add_argument(
        "--tokenizer_path", type=str, default="checkpoints/Qwen3-0.6B"
    )
    parser.add_argument(
        "--num_audio_codebook",
        type=int,
        default=None,
        help="Audio codebook size. Defaults to checkpoint metadata when available.",
    )

    # --- sampling controls ---
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.90)
    parser.add_argument("--greedy", action="store_true", default=False)
    parser.add_argument("--max_audio_tokens", type=int, default=0)
    parser.add_argument("--fps", type=int, default=25)

    # --- runtime / execution ---
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    # NOTE(review): with default=True, passing --use_cache is a no-op; the
    # effective toggle is --no_cache. Consider argparse.BooleanOptionalAction.
    parser.add_argument("--use_cache", action="store_true", default=True)
    parser.add_argument("--no_cache", action="store_true", default=False)
    parser.add_argument("--compile", action="store_true", default=False)
    parser.add_argument(
        "--compile_mode",
        type=str,
        default="reduce-overhead",
        choices=["default", "reduce-overhead", "max-autotune"],
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        choices=["eager", "sdpa", "flash_attention_2"],
    )
    # --- outputs ---
    parser.add_argument("--output_dir", type=str, default="predictions")
    parser.add_argument("--output_prefix", type=str, default="")
    parser.add_argument(
        "--json_output_dir",
        type=str,
        default="predictions/json",
        help="Directory for chord/segment json. Default: <output_dir>/json",
    )
    # --- MuCodec decoding ---
    parser.add_argument(
        "--mucodec_device",
        type=str,
        default="auto",
        help="Device string for MuCodec, for example cuda:0.",
    )
    parser.add_argument(
        "--mucodec_layer_num",
        type=int,
        default=7,
        help="MuCodec layer_num passed to the official decoder.",
    )
    parser.add_argument(
        "--mucodec_duration",
        type=float,
        default=40.96,
        help="Chunk duration argument passed to MuCodec code2sound.",
    )
    parser.add_argument(
        "--mucodec_guidance_scale",
        type=float,
        default=1.5,
        help="Guidance scale argument passed to MuCodec code2sound.",
    )
    parser.add_argument(
        "--mucodec_num_steps",
        type=int,
        default=20,
        help="Sampling steps argument passed to MuCodec code2sound.",
    )
    parser.add_argument(
        "--mucodec_sample_rate",
        type=int,
        default=48000,
        help="Sample rate used when saving decoded wav.",
    )
    parser.add_argument(
        "--wav_output_dir",
        type=str,
        default="predictions/wav",
        help="Directory for decoded wav. Default: <output_dir>/wav",
    )
    return parser.parse_args()
209
+
210
+
211
def resolve_runtime_device_str(device_arg: str) -> str:
    """Resolve "auto" to the best available backend; pass other values through.

    Preference order for "auto": CUDA device 0, then MPS, then CPU.
    """
    if device_arg != "auto":
        return device_arg
    if torch.cuda.is_available():
        return "cuda:0"
    return "mps" if torch.backends.mps.is_available() else "cpu"
219
+
220
+
221
@contextlib.contextmanager
def pushd(path: str):
    """Temporarily chdir into ``path``, restoring the previous cwd on exit."""
    original_cwd = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original_cwd)
229
+
230
+
231
def ensure_sys_path(path: str) -> None:
    """Prepend ``path`` to ``sys.path`` unless it is empty or already present."""
    if not path:
        return
    if path not in sys.path:
        sys.path.insert(0, path)
234
+
235
+
236
def get_mucodec_root() -> str:
    """Return the MuCodec repo path after validating its layout.

    Raises:
        FileNotFoundError: if the directory or its ``generate.py`` is missing.
    """
    if not MUCODEC_ROOT.is_dir():
        raise FileNotFoundError(f"MuCodec directory not found: {MUCODEC_ROOT}")
    entrypoint = MUCODEC_ROOT / "generate.py"
    if not entrypoint.is_file():
        raise FileNotFoundError(f"MuCodec entrypoint not found: {entrypoint}")
    return str(MUCODEC_ROOT)
244
+
245
+
246
def import_mucodec_class():
    """Import the official ``MuCodec`` class from the vendored repo.

    Returns:
        Tuple of (MuCodec class object, resolved repo path).

    Raises:
        ImportError: if the module or attribute cannot be loaded; the original
            exception is chained as the cause for debuggability.
    """
    repo_path = get_mucodec_root()
    ensure_sys_path(repo_path)
    try:
        module = importlib.import_module("generate")
        return getattr(module, "MuCodec"), repo_path
    except Exception as exc:  # pragma: no cover - env dependent
        # Chain the cause explicitly (B904) so the real import failure is
        # preserved in the traceback instead of only its stringified message.
        raise ImportError(
            f"Could not import MuCodec from {repo_path}/generate.py: {exc}"
        ) from exc
254
+
255
+
256
def build_mucodec_decoder(args: argparse.Namespace) -> Any:
    """Instantiate the official MuCodec decoder from the vendored repo.

    Validates the checkpoint and the auxiliary weight files expected by the
    current folder layout, then constructs MuCodec with the repo as cwd
    (presumably because its loaders use repo-relative paths — confirm against
    MuCodec/generate.py). The resolved repo path is stashed on the decoder so
    decoding can ``pushd`` back into it later.

    Raises:
        FileNotFoundError: if the checkpoint or a required dependency is missing.
    """
    MuCodec, resolved_repo = import_mucodec_class()

    ckpt_path = os.path.join(resolved_repo, "ckpt", "mucodec.pt")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"MuCodec checkpoint not found: {ckpt_path}")

    # Auxiliary weights the decoder needs beyond the main checkpoint.
    required_local_files = [
        os.path.join(resolved_repo, "tools", "audioldm_48k.pth"),
        os.path.join(resolved_repo, "muq_dev", "muq.pt"),
    ]
    for path in required_local_files:
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"Required MuCodec dependency not found for current folder structure: {path}"
            )

    mucodec_device = resolve_runtime_device_str(args.mucodec_device)
    if resolved_repo:
        print(f"[INFO] resolved MuCodec repo: {resolved_repo}")
    print(f"[INFO] loading MuCodec from {ckpt_path} on {mucodec_device}")
    with pushd(resolved_repo):
        decoder = MuCodec(
            model_path=ckpt_path,
            layer_num=int(args.mucodec_layer_num),
            load_main_model=True,
            device=mucodec_device,
        )
    # Remember where the repo lives so decode_mucodec_codes can chdir into it.
    setattr(decoder, "_magel_mucodec_repo", resolved_repo)
    return decoder
286
+
287
+
288
def decode_mucodec_codes(
    mucodec_decoder: Any,
    shifted_codes: np.ndarray,
    args: argparse.Namespace,
) -> torch.Tensor:
    """Decode a 1-D stream of MuCodec token ids into a waveform tensor.

    Returns:
        A float32 CPU tensor; a 1-D decoder output is promoted to [1, samples].

    Raises:
        ValueError: if ``shifted_codes`` is not 1-D.
    """
    if shifted_codes.ndim != 1:
        raise ValueError(
            f"Expected 1D MuCodec token stream, got shape {shifted_codes.shape}"
        )

    codes = torch.from_numpy(shifted_codes.astype(np.int64, copy=False))
    # Shape to [1, 1, T] — assumes code2sound expects (batch, layer, frames);
    # TODO confirm against the MuCodec code2sound signature.
    codes = codes.unsqueeze(0).unsqueeze(0)
    repo_path = getattr(mucodec_decoder, "_magel_mucodec_repo", "")
    # Decode from inside the repo so its relative resource paths resolve.
    decode_ctx = pushd(repo_path) if repo_path else contextlib.nullcontext()
    with decode_ctx:
        wave = mucodec_decoder.code2sound(
            codes,
            prompt=None,
            duration=float(args.mucodec_duration),
            guidance_scale=float(args.mucodec_guidance_scale),
            num_steps=int(args.mucodec_num_steps),
            disable_progress=True,
        )
    if not torch.is_tensor(wave):
        wave = torch.as_tensor(wave)
    if wave.ndim == 1:
        wave = wave.unsqueeze(0)
    return wave.detach().cpu().to(torch.float32)
316
+
317
+
318
def build_segment_spans(
    template_ids: torch.Tensor,
    is_audio_codebook: torch.Tensor,
    layout: TokenLayout,
) -> list[SegmentSpan]:
    """Pair each BOS_AUDIO with the next EOS_AUDIO and collect audio slots between.

    EOS markers at or before a BOS are skipped; a BOS without a following EOS
    ends the scan. Returns one SegmentSpan per matched pair, in order.
    """
    bos_positions = torch.where(template_ids.eq(layout.bos_audio))[0].tolist()
    eos_positions = torch.where(template_ids.eq(layout.eos_audio))[0].tolist()
    if not bos_positions or not eos_positions:
        return []

    # Hoisted out of the loop: the position-index vector is loop-invariant,
    # so there is no need to rebuild it once per BOS marker.
    idx = torch.arange(template_ids.numel(), device=template_ids.device)

    spans: list[SegmentSpan] = []
    eos_ptr = 0
    for b in bos_positions:
        # Advance to the first EOS strictly after this BOS.
        while eos_ptr < len(eos_positions) and eos_positions[eos_ptr] <= b:
            eos_ptr += 1
        if eos_ptr >= len(eos_positions):
            break
        e = eos_positions[eos_ptr]
        eos_ptr += 1
        # Audio slots strictly between the BOS/EOS markers.
        mask = is_audio_codebook & (idx > b) & (idx < e)
        audio_positions = torch.where(mask)[0].tolist()
        spans.append(
            SegmentSpan(
                seg_idx=len(spans),
                bos_pos=int(b),
                eos_pos=int(e),
                audio_positions=[int(p) for p in audio_positions],
            )
        )
    return spans
349
+
350
+
351
def load_hf_template_sample(
    dataset_path: str,
    split: str,
    tokenizer_path: str,
    sample_idx: int,
    num_audio_codebook: int,
) -> HFTemplateSample:
    """Load the music dataset and build the inference scaffold for one sample.

    Thin convenience wrapper around load_music_dataset +
    load_hf_template_sample_from_music_dataset.
    """
    music_ds = load_music_dataset(
        dataset_path=dataset_path,
        split=split,
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_codebook,
        use_fast=True,
    )
    return load_hf_template_sample_from_music_dataset(
        music_ds=music_ds,
        sample_idx=sample_idx,
        num_audio_codebook=num_audio_codebook,
    )
370
+
371
+
372
def load_hf_template_sample_from_music_dataset(
    music_ds,
    sample_idx: int,
    num_audio_codebook: int,
) -> HFTemplateSample:
    """Turn one dataset row into an inference scaffold.

    Ground-truth audio codebook tokens are replaced by MASK_AUDIO so the
    scaffold carries only text/control structure; chord/structure controls
    and the condition mask stay aligned to the token sequence.

    Raises:
        ValueError: if any control tensor length differs from the token length.
    """
    layout = TokenLayout(
        num_text_token=music_ds.num_text_token,
        num_audio_codebook=num_audio_codebook,
    )

    # NOTE(review): reaches into the private ``_data`` for raw metadata — this
    # couples to the dataset implementation; verify if that class changes.
    raw_item = music_ds._data[sample_idx]
    row = music_ds[sample_idx]

    template_ids = row["token_ids"].to(torch.long)
    chord_ids = row["chord_ids"].to(torch.long)
    structure_ids = row["structure_ids"].to(torch.long)
    condition_mask = row["condition_mask"].to(torch.bool)

    seq_len = int(template_ids.numel())
    for name, t in [
        ("chord_ids", chord_ids),
        ("structure_ids", structure_ids),
        ("condition_mask", condition_mask),
    ]:
        if int(t.numel()) != seq_len:
            raise ValueError(f"{name} length mismatch: {int(t.numel())} != {seq_len}")

    is_audio_codebook = (template_ids >= layout.audio_start) & (
        template_ids < layout.audio_end
    )
    is_eos = template_ids.eq(layout.eos_audio)

    # Remove GT audio token values from input scaffold.
    input_ids = template_ids.clone()
    input_ids[is_audio_codebook] = layout.mask_audio

    spans = build_segment_spans(template_ids, is_audio_codebook, layout)

    return HFTemplateSample(
        song_id=str(raw_item.get("song_id", f"sample_{sample_idx}")),
        num_text_token=music_ds.num_text_token,
        template_ids=template_ids,
        input_ids=input_ids,
        chord_ids=chord_ids,
        structure_ids=structure_ids,
        condition_mask=condition_mask,
        is_audio_codebook=is_audio_codebook,
        is_eos=is_eos,
        segments=spans,
        raw_item=raw_item,
    )
423
+
424
+
425
def apply_top_k_top_p(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """Mask logits outside the top-k / nucleus (top-p) sets with -inf.

    ``top_k <= 0`` (or None) disables top-k; ``top_p`` outside (0, 1) disables
    nucleus filtering. The highest-probability token is always retained.
    """
    neg_inf = float("-inf")

    if top_k is not None and top_k > 0:
        keep = min(top_k, logits.shape[-1])
        kth_value = torch.topk(logits, keep, dim=-1).values[:, -1].unsqueeze(-1)
        logits = logits.masked_fill(logits < kth_value, neg_inf)

    if top_p is not None and 0.0 < top_p < 1.0:
        ordered, order = torch.sort(logits, descending=True, dim=-1)
        cumulative = torch.cumsum(torch.softmax(ordered, dim=-1), dim=-1)
        drop = cumulative > top_p
        drop[:, 0] = False  # never drop the single most probable token
        ordered = ordered.masked_fill(drop, neg_inf)
        logits = torch.full_like(logits, neg_inf).scatter_(
            dim=-1, index=order, src=ordered
        )
    return logits
443
+
444
+
445
def sample_from_logits(
    logits: torch.Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
) -> int:
    """Draw a single token id from [1, V] logits.

    Greedy mode (or a non-positive temperature) returns the argmax; otherwise
    temperature-scaled logits are filtered by top-k/top-p and sampled.

    Raises:
        RuntimeError: if filtering leaves no finite logit to sample from.
    """
    if greedy or temperature <= 0:
        return int(torch.argmax(logits, dim=-1).item())

    scaled = logits / max(temperature, 1e-6)
    filtered = apply_top_k_top_p(scaled, top_k=top_k, top_p=top_p)
    if not torch.isfinite(filtered).any():
        raise RuntimeError("All logits are -inf after filtering.")
    distribution = torch.softmax(filtered, dim=-1)
    return int(torch.multinomial(distribution, num_samples=1).item())
460
+
461
+
462
def sample_audio_token_from_logits(
    logits: torch.Tensor,
    layout: TokenLayout,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
) -> int:
    """Sample within the audio-codebook slice and map back to a global token id."""
    codebook_slice = logits[:, layout.audio_start : layout.audio_end]
    local_idx = sample_from_logits(
        codebook_slice,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        greedy=greedy,
    )
    return int(layout.audio_start + local_idx)
479
+
480
+
481
def chord_id_to_type(chord_id: int) -> str:
    """Decode a chord id to its label; tag undecodable ids as ``unknown_<id>``.

    A decoded "N" is trusted only for id 1 and the chord BOS/EOS control ids;
    any other id that decodes to "N" is reported as unknown.
    """
    label = chord_id_to_label(chord_id)
    if label != "N":
        return label
    if chord_id in {1, CHORD_BOS_ID, CHORD_EOS_ID}:
        return label
    return f"unknown_{chord_id}"
484
+
485
+
486
def segment_id_to_type(segment_id: int) -> str:
    """Decode a structure id to its label; out-of-range ids become ``unknown_<id>``."""
    decoded_label = structure_id_to_label(segment_id)
    if 0 <= segment_id <= STRUCTURE_EOS_ID:
        return decoded_label
    return f"unknown_{segment_id}"
489
+
490
+
491
def to_intervals(type_ids: list[int], fps: int, mapper) -> list[dict[str, Any]]:
    """Run-length encode frame-level ids into [{start, end, type}] intervals.

    Times are in seconds (frame / fps, rounded to 6 decimals); ``mapper``
    converts each run's id to its display type.
    """
    if not type_ids:
        return []

    intervals: list[dict[str, Any]] = []
    frame_rate = float(fps)
    run_start = 0
    total = len(type_ids)
    for pos in range(1, total + 1):
        at_end = pos == total
        if at_end or type_ids[pos] != type_ids[run_start]:
            intervals.append(
                {
                    "start": round(run_start / frame_rate, 6),
                    "end": round(pos / frame_rate, 6),
                    "type": mapper(int(type_ids[run_start])),
                }
            )
            run_start = pos
    return intervals
510
+
511
+
512
def merge_same_type_with_small_gap(
    intervals: list[dict[str, Any]], fps: int, max_gap_frames: int = 1
) -> list[dict[str, Any]]:
    """Merge consecutive same-type intervals separated by at most ``max_gap_frames``.

    Input intervals are copied, never mutated; a small epsilon absorbs float
    rounding in the gap comparison.
    """
    if not intervals:
        return []

    gap_limit_s = float(max_gap_frames) / float(fps)
    result = [dict(intervals[0])]
    for interval in intervals[1:]:
        last = result[-1]
        gap_s = float(interval["start"]) - float(last["end"])
        same_type = last.get("type") == interval.get("type")
        if same_type and gap_s <= (gap_limit_s + 1e-9):
            last["end"] = interval["end"]
        else:
            result.append(dict(interval))
    return result
527
+
528
+
529
@torch.inference_mode()
def generate_segmentwise(
    model: MAGEL,
    sample: HFTemplateSample,
    layout: TokenLayout,
    device: torch.device,
    use_cache: bool,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
    max_audio_tokens: int,
) -> tuple[torch.Tensor, int, list[int], list[int]]:
    """Fill the scaffold's audio slots autoregressively, one token per step.

    Walks the scaffold from the first audio/EOS slot to the last segment's
    EOS, sampling audio-codebook tokens at audio positions, forcing EOS_AUDIO
    at EOS positions, and copying template tokens everywhere else. Conditions
    (chord/structure) are precomputed once over the full sequence so cached
    decoding matches the training-time condition-encoder context.

    Returns:
        (generated_ids on CPU, number of sampled audio tokens,
         chord ids at sampled positions, structure ids at sampled positions).
    """
    import time

    seq_template = sample.input_ids.to(device)
    chord_template = sample.chord_ids.to(device)
    structure_template = sample.structure_ids.to(device)
    condition_mask_template = sample.condition_mask.to(device)
    is_audio_code = sample.is_audio_codebook.to(device)
    is_eos = sample.is_eos.to(device)

    slot_positions = torch.where(is_audio_code | is_eos)[0]
    if slot_positions.numel() == 0:
        # No generation slot: return scaffold as-is.
        return seq_template.detach().cpu(), 0, [], []

    start_pos = int(slot_positions[0].item())
    # Stop at the last segment's EOS when segment spans are known; otherwise
    # fall back to the last audio/EOS slot.
    if sample.segments:
        end_pos = int(sample.segments[-1].eos_pos)
    else:
        end_pos = int(slot_positions[-1].item())

    sampled_chord_ids: list[int] = []
    sampled_segment_ids: list[int] = []

    generated_ids = seq_template.clone()
    sampled_count = 0
    past_key_values: Optional[tuple] = None

    # Precompute full-sequence condition once so cached decoding keeps
    # the same global condition-encoder context as training.
    cond_template: torch.Tensor = model.condition_encoder(
        chord_template.unsqueeze(0),
        structure_template.unsqueeze(0),
    )

    # Prefill with fixed prefix.
    full_attention_mask = torch.ones(
        (1, sample.seq_len), dtype=torch.long, device=device
    )
    prefix_ids = generated_ids[:start_pos].unsqueeze(0)
    prefix_attn = full_attention_mask[:, :start_pos]
    model_kwargs = dict(
        input_ids=prefix_ids,
        attention_mask=prefix_attn,
        condition_mask=condition_mask_template[:start_pos].unsqueeze(0),
        cond_precomputed=cond_template[:, :start_pos, :],
        use_cache=use_cache,
    )
    maybe_mark_compile_step_begin(model)
    prefill_t0 = time.perf_counter()
    out = model(**model_kwargs)
    prefill_time_s = time.perf_counter() - prefill_t0
    logits_next = out.logits[:, -1, :]
    if use_cache:
        past_key_values = out.past_key_values
    # Reusable [1, 1] buffer for single-token cached decode steps.
    step_ids = torch.empty((1, 1), dtype=torch.long, device=device)

    decode_time_s = 0.0
    for i in range(start_pos, end_pos + 1):
        if bool(is_audio_code[i].item()):
            if max_audio_tokens > 0 and sampled_count >= max_audio_tokens:
                break
            next_id = sample_audio_token_from_logits(
                logits_next,
                layout=layout,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                greedy=greedy,
            )
            sampled_count += 1
            # Controls are input-aligned to the token sequence.
            cond_pos = i
            sampled_chord_ids.append(int(chord_template[cond_pos].item()))
            sampled_segment_ids.append(int(structure_template[cond_pos].item()))
        elif bool(is_eos[i].item()):
            # EOS positions are forced, never sampled.
            next_id = layout.eos_audio
        else:
            # Non-slot positions copy the template token verbatim.
            next_id = int(seq_template[i].item())

        generated_ids[i] = int(next_id)

        if i >= end_pos:
            break

        if use_cache:
            # Single-token step against the KV cache.
            step_ids[0, 0] = int(next_id)
            step_attn = full_attention_mask[:, : i + 2]
            model_kwargs = dict(
                input_ids=step_ids,
                attention_mask=step_attn,
                condition_mask=condition_mask_template[i : i + 1].unsqueeze(0),
                cond_precomputed=cond_template[:, i : i + 1, :],
                past_key_values=past_key_values,
                use_cache=True,
            )
            maybe_mark_compile_step_begin(model)
            step_t0 = time.perf_counter()
            out = model(**model_kwargs)
            decode_time_s += time.perf_counter() - step_t0
            logits_next = out.logits[:, -1, :]
            past_key_values = out.past_key_values
        else:
            # Cache-less fallback: re-run the full prefix every step (O(T^2)).
            cur_len = i + 1
            model_kwargs = dict(
                input_ids=generated_ids[:cur_len].unsqueeze(0),
                attention_mask=full_attention_mask[:, :cur_len],
                condition_mask=condition_mask_template[:cur_len].unsqueeze(0),
                cond_precomputed=cond_template[:, :cur_len, :],
                use_cache=False,
            )
            maybe_mark_compile_step_begin(model)
            step_t0 = time.perf_counter()
            out = model(**model_kwargs)
            decode_time_s += time.perf_counter() - step_t0
            logits_next = out.logits[:, -1, :]

    total_gen_time_s = prefill_time_s + decode_time_s
    tokens_per_second = (
        float(sampled_count) / decode_time_s if decode_time_s > 0 and sampled_count > 0 else 0.0
    )
    print(
        "[PROFILE] generation "
        f"prefill_s={prefill_time_s:.3f} "
        f"decode_s={decode_time_s:.3f} "
        f"total_s={total_gen_time_s:.3f} "
        f"sampled_audio_tokens={sampled_count} "
        f"decode_tok_per_s={tokens_per_second:.3f}"
    )

    return (
        generated_ids.detach().cpu(),
        sampled_count,
        sampled_chord_ids,
        sampled_segment_ids,
    )
677
+
678
+
679
@torch.inference_mode()
def batch_generate_segmentwise(
    model: MAGEL,
    samples: list[HFTemplateSample],
    layout: TokenLayout,
    device: torch.device,
    use_cache: bool,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
    max_audio_tokens: int,
) -> list[tuple[torch.Tensor, int, list[int], list[int]]]:
    """Run template-constrained autoregressive decoding for a batch of samples.

    Samples are right-padded to the longest sequence length and decoded in
    lockstep with a shared KV cache: one batched prefill over each sample's
    prompt prefix, then one batched forward per decode step.  At each position
    the template decides what happens: audio-codebook slots are sampled from
    the model's logits, EOS slots are forced to ``layout.eos_audio``, and all
    other positions are copied verbatim from the input template.

    Returns one ``(generated_ids, sampled_count, chord_ids, segment_ids)``
    tuple per input sample, with tensors detached and moved to CPU.  When
    ``use_cache`` is False this falls back to per-sample
    ``generate_segmentwise`` calls (no batching benefit without a cache).
    """
    import time

    if not samples:
        return []
    if not use_cache:
        # Cacheless decoding recomputes the full prefix each step; batching
        # it here would not help, so defer to the single-sample path.
        return [
            generate_segmentwise(
                model=model,
                sample=sample,
                layout=layout,
                device=device,
                use_cache=use_cache,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                greedy=greedy,
                max_audio_tokens=max_audio_tokens,
            )
            for sample in samples
        ]

    batch_size = len(samples)
    seq_lens = [sample.seq_len for sample in samples]
    max_seq_len = max(seq_lens)

    # Right-padded per-row template buffers (token ids, controls, and the
    # boolean masks that classify each position).
    seq_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    generated_ids = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    chord_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    structure_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    condition_mask_templates = torch.zeros(
        (batch_size, max_seq_len), dtype=torch.bool, device=device
    )
    is_audio_code_templates = torch.zeros(
        (batch_size, max_seq_len), dtype=torch.bool, device=device
    )
    is_eos_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.bool, device=device)

    start_positions: list[int] = []
    end_positions: list[int] = []
    sampled_counts = [0 for _ in samples]
    sampled_chord_ids: list[list[int]] = [[] for _ in samples]
    sampled_segment_ids: list[list[int]] = [[] for _ in samples]
    valid_sample_mask = torch.ones(batch_size, dtype=torch.bool, device=device)

    for row_idx, sample in enumerate(samples):
        seq_templates[row_idx, : sample.seq_len] = sample.input_ids.to(device)
        generated_ids[row_idx, : sample.seq_len] = sample.input_ids.to(device)
        chord_templates[row_idx, : sample.seq_len] = sample.chord_ids.to(device)
        structure_templates[row_idx, : sample.seq_len] = sample.structure_ids.to(device)
        condition_mask_templates[row_idx, : sample.seq_len] = sample.condition_mask.to(device)
        is_audio_code_templates[row_idx, : sample.seq_len] = sample.is_audio_codebook.to(device)
        is_eos_templates[row_idx, : sample.seq_len] = sample.is_eos.to(device)

        # Decoding spans from the first audio/EOS slot to the last segment's
        # EOS position (or the last slot when no segment metadata exists).
        slot_positions = torch.where(
            is_audio_code_templates[row_idx, : sample.seq_len]
            | is_eos_templates[row_idx, : sample.seq_len]
        )[0]
        if slot_positions.numel() == 0:
            # Nothing to decode for this row; mark it invalid so the step
            # loop skips it entirely.
            valid_sample_mask[row_idx] = False
            start_positions.append(sample.seq_len)
            end_positions.append(sample.seq_len - 1)
            continue
        start_pos = int(slot_positions[0].item())
        if sample.segments:
            end_pos = int(sample.segments[-1].eos_pos)
        else:
            end_pos = int(slot_positions[-1].item())
        start_positions.append(start_pos)
        end_positions.append(end_pos)

    if not bool(valid_sample_mask.any().item()):
        # No row has anything to decode: return the untouched templates.
        return [
            (sample.input_ids.detach().cpu(), 0, [], [])
            for sample in samples
        ]

    start_positions_t = torch.tensor(start_positions, dtype=torch.long, device=device)
    end_positions_t = torch.tensor(end_positions, dtype=torch.long, device=device)
    prefix_lens = start_positions_t.clone()
    max_prefix_len = int(prefix_lens.max().item())
    max_decode_steps = int((end_positions_t - start_positions_t + 1).clamp_min(0).max().item())

    # Encode the chord/structure controls once for the whole padded batch and
    # slice per position afterwards.
    cond_template = model.condition_encoder(chord_templates, structure_templates)

    # NOTE(review): all rows are prefilled to max_prefix_len with per-row
    # attention masking, so shorter rows have masked tail positions inside
    # the KV cache before their first decode step — confirm that position
    # handling inside the model matches this layout.
    prefix_attention_mask = (
        torch.arange(max_prefix_len, device=device).unsqueeze(0) < prefix_lens.unsqueeze(1)
    ).to(torch.long)
    prefill_t0 = time.perf_counter()
    maybe_mark_compile_step_begin(model)
    out = model(
        input_ids=generated_ids[:, :max_prefix_len],
        attention_mask=prefix_attention_mask,
        condition_mask=condition_mask_templates[:, :max_prefix_len],
        cond_precomputed=cond_template[:, :max_prefix_len, :],
        use_cache=True,
    )
    prefill_time_s = time.perf_counter() - prefill_t0

    # Pick each row's logits at its own last real prefix token (rows may have
    # different prefix lengths inside the shared padded prefill).
    gather_idx = (prefix_lens - 1).clamp_min(0)
    batch_indices = torch.arange(batch_size, device=device)
    logits_next = out.logits[batch_indices, gather_idx, :]
    past_key_values = out.past_key_values

    step_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
    # decode_valid_mask[row, step] records whether the row produced a token at
    # that step; it doubles as the attention mask for the decode tail.
    decode_valid_mask = torch.zeros(
        (batch_size, max_decode_steps), dtype=torch.bool, device=device
    )
    decode_time_s = 0.0

    for step_idx in range(max_decode_steps):
        cur_positions = start_positions_t + step_idx
        active_mask = valid_sample_mask & cur_positions.le(end_positions_t)
        if not bool(active_mask.any().item()):
            break

        # Per-row token selection: sample audio slots, force EOS slots, and
        # copy everything else from the template.
        next_ids = torch.zeros(batch_size, dtype=torch.long, device=device)
        for row_idx in range(batch_size):
            if not bool(active_mask[row_idx].item()):
                continue
            cur_pos = int(cur_positions[row_idx].item())
            if bool(is_audio_code_templates[row_idx, cur_pos].item()):
                if max_audio_tokens > 0 and sampled_counts[row_idx] >= max_audio_tokens:
                    # Token budget exhausted: permanently retire this row.
                    valid_sample_mask[row_idx] = False
                    continue
                next_id = sample_audio_token_from_logits(
                    logits_next[row_idx : row_idx + 1],
                    layout=layout,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    greedy=greedy,
                )
                sampled_counts[row_idx] += 1
                # Record the input-aligned controls next to each sampled token.
                sampled_chord_ids[row_idx].append(
                    int(chord_templates[row_idx, cur_pos].item())
                )
                sampled_segment_ids[row_idx].append(
                    int(structure_templates[row_idx, cur_pos].item())
                )
            elif bool(is_eos_templates[row_idx, cur_pos].item()):
                next_id = layout.eos_audio
            else:
                next_id = int(seq_templates[row_idx, cur_pos].item())

            generated_ids[row_idx, cur_pos] = int(next_id)
            next_ids[row_idx] = int(next_id)
            decode_valid_mask[row_idx, step_idx] = True

        if step_idx >= max_decode_steps - 1:
            # Last step: no further forward pass needed.
            break

        step_ids[:, 0] = next_ids
        # Attend over the (masked) prefill region plus every decode step that
        # actually produced a token for this row.
        step_attention_mask = torch.cat(
            [
                prefix_attention_mask,
                decode_valid_mask[:, : step_idx + 1].to(torch.long),
            ],
            dim=1,
        )
        step_condition_mask = torch.zeros((batch_size, 1), dtype=torch.bool, device=device)
        step_cond = torch.zeros(
            (batch_size, 1, cond_template.shape[-1]),
            dtype=cond_template.dtype,
            device=device,
        )
        for row_idx in range(batch_size):
            if not bool(decode_valid_mask[row_idx, step_idx].item()):
                continue
            cur_pos = int(cur_positions[row_idx].item())
            step_condition_mask[row_idx, 0] = condition_mask_templates[row_idx, cur_pos]
            step_cond[row_idx, 0, :] = cond_template[row_idx, cur_pos, :]

        step_t0 = time.perf_counter()
        maybe_mark_compile_step_begin(model)
        out = model(
            input_ids=step_ids,
            attention_mask=step_attention_mask,
            condition_mask=step_condition_mask,
            cond_precomputed=step_cond,
            past_key_values=past_key_values,
            use_cache=True,
        )
        decode_time_s += time.perf_counter() - step_t0
        logits_next = out.logits[:, -1, :]
        past_key_values = out.past_key_values

    total_sampled_tokens = sum(sampled_counts)
    total_gen_time_s = prefill_time_s + decode_time_s
    tokens_per_second = (
        float(total_sampled_tokens) / decode_time_s
        if decode_time_s > 0 and total_sampled_tokens > 0
        else 0.0
    )
    print(
        "[PROFILE] batch_generation "
        f"batch_size={batch_size} "
        f"prefill_s={prefill_time_s:.3f} "
        f"decode_s={decode_time_s:.3f} "
        f"total_s={total_gen_time_s:.3f} "
        f"sampled_audio_tokens={total_sampled_tokens} "
        f"decode_tok_per_s={tokens_per_second:.3f}"
    )

    # Assemble per-sample outputs, trimming each row back to its true length.
    outputs: list[tuple[torch.Tensor, int, list[int], list[int]]] = []
    for row_idx, sample in enumerate(samples):
        if not bool((torch.where(sample.is_audio_codebook | sample.is_eos)[0]).numel()):
            # Rows with no slots were never decoded; echo the input template.
            outputs.append((sample.input_ids.detach().cpu(), 0, [], []))
            continue
        outputs.append(
            (
                generated_ids[row_idx, : sample.seq_len].detach().cpu(),
                sampled_counts[row_idx],
                sampled_chord_ids[row_idx],
                sampled_segment_ids[row_idx],
            )
        )
    return outputs
909
+
910
+
911
def save_outputs(
    output_dir: str,
    output_prefix: str,
    sample: HFTemplateSample,
    layout: TokenLayout,
    generated_ids: torch.Tensor,
    sampled_chord_ids: list[int],
    sampled_segment_ids: list[int],
    args: argparse.Namespace,
    mucodec_decoder: Any = None,
) -> None:
    """Persist one generation result as a decoded wav plus a chord/segment JSON.

    Audio tokens are extracted from ``generated_ids`` by the id range
    ``[layout.audio_start, layout.audio_end)``, shifted back to codebook space,
    and decoded through MuCodec; the sampled per-token chord/segment ids are
    converted to time intervals, cleaned of "pad" entries, and merged across
    one-frame gaps before being written out.

    Output locations: ``args.json_output_dir``/``args.wav_output_dir`` when
    set, otherwise ``<output_dir>/json`` and ``<output_dir>/wav``; file names
    use ``output_prefix`` or an auto-generated ``song_id_idx_timestamp``.
    """
    import time

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    prefix = output_prefix or f"{sample.song_id}_{args.sample_idx}_{stamp}"

    json_dir = args.json_output_dir or os.path.join(output_dir, "json")
    wav_dir = args.wav_output_dir or os.path.join(output_dir, "wav")
    Path(json_dir).mkdir(parents=True, exist_ok=True)
    Path(wav_dir).mkdir(parents=True, exist_ok=True)

    json_path = os.path.join(json_dir, f"{prefix}.chord_segment.json")
    wav_path = os.path.join(wav_dir, f"{prefix}.wav")

    gen_full = generated_ids.cpu().numpy().astype(np.int64)

    # Keep only ids inside the audio-token range, then shift them back to
    # raw MuCodec codebook indices.
    gen_audio_raw = gen_full[
        (gen_full >= layout.audio_start) & (gen_full < layout.audio_end)
    ]
    gen_audio_shift = gen_audio_raw - layout.audio_start

    save_t0 = time.perf_counter()
    if gen_audio_shift.size == 0:
        print("[WARN] No generated MuCodec tokens; skipping wav decode.")
    else:
        # torchaudio is only needed on the decode path; import lazily.
        # NOTE(review): this branch assumes mucodec_decoder is usable despite
        # its None default — confirm callers always pass a decoder when audio
        # tokens can be generated.
        import torchaudio

        wave = decode_mucodec_codes(mucodec_decoder, gen_audio_shift, args)
        torchaudio.save(wav_path, wave, int(args.mucodec_sample_rate))
        print(f"[OK] {wav_path}")

    chord_intervals = to_intervals(
        sampled_chord_ids, fps=int(args.fps), mapper=chord_id_to_type
    )
    segment_intervals = to_intervals(
        sampled_segment_ids, fps=int(args.fps), mapper=segment_id_to_type
    )

    # PAD is used for EOS-related conditioning; drop it in exported json.
    chord_intervals = [x for x in chord_intervals if x.get("type") != "pad"]
    segment_intervals = [x for x in segment_intervals if x.get("type") != "pad"]
    # Dropping "pad" can split an interval in two; merging across a one-frame
    # gap stitches same-type neighbors back together.
    chord_intervals = merge_same_type_with_small_gap(
        chord_intervals, fps=int(args.fps), max_gap_frames=1
    )
    segment_intervals = merge_same_type_with_small_gap(
        segment_intervals, fps=int(args.fps), max_gap_frames=1
    )

    chord_segment = {
        "song_id": sample.song_id,
        "sample_idx": int(args.sample_idx),
        "fps": int(args.fps),
        "generated_audio_count": int(gen_audio_raw.shape[0]),
        "chord": chord_intervals,
        "segment": segment_intervals,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(chord_segment, f, ensure_ascii=False, indent=2)

    print(f"[OK] {json_path}")
    save_time_s = time.perf_counter() - save_t0
    print(
        "[PROFILE] save "
        f"save_s={save_time_s:.3f} "
        f"generated_audio_count={int(gen_audio_raw.shape[0])}"
    )
988
+
989
+
990
def main() -> None:
    """End-to-end single-sample inference: load model and sample, decode, save.

    Pipeline: parse args, seed RNGs, resolve device/dtype, load the MAGEL
    checkpoint (optionally torch.compile'd), fetch one template sample from
    the HF dataset, run segment-wise autoregressive generation, then write
    the wav + chord/segment JSON via ``save_outputs``.
    """
    import time

    args = parse_args()
    seed_everything(args.seed)

    # --no_cache overrides --use_cache.
    use_cache = args.use_cache and not args.no_cache

    device = resolve_device(args.device)
    dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[args.dtype]
    if device.type == "cpu" and dtype != torch.float32:
        # Half-precision kernels are commonly missing on CPU; degrade safely.
        print(f"[WARN] dtype {dtype} on CPU may be unsupported; fallback to float32.")
        dtype = torch.float32

    print(f"[INFO] device={device}, dtype={dtype}, use_cache={use_cache}")
    print(f"[INFO] loading model from {args.model_path}")
    model = load_magel_checkpoint(
        checkpoint_path=args.model_path,
        device=device,
        dtype=dtype,
        attn_implementation=args.attn_implementation,
    )
    model = maybe_compile_model(
        model,
        enabled=bool(args.compile),
        mode=str(args.compile_mode),
    )
    # CLI flag wins; otherwise fall back to the checkpoint config (16384 as a
    # last resort when the config lacks magel_num_audio_token).
    num_audio_codebook = (
        int(args.num_audio_codebook)
        if args.num_audio_codebook is not None
        else int(getattr(model.config, "magel_num_audio_token", 16384))
    )
    print(f"[INFO] num_audio_codebook={num_audio_codebook}")

    print(f"[INFO] loading HF sample idx={args.sample_idx} from {args.dataset_path}")
    sample = load_hf_template_sample(
        dataset_path=args.dataset_path,
        split=args.split,
        tokenizer_path=args.tokenizer_path,
        sample_idx=args.sample_idx,
        num_audio_codebook=num_audio_codebook,
    )
    layout = TokenLayout(
        num_text_token=sample.num_text_token,
        num_audio_codebook=num_audio_codebook,
    )
    print(
        f"[INFO] song_id={sample.song_id}, seq_len={sample.seq_len}, segments={len(sample.segments)}"
    )
    mucodec_decoder = build_mucodec_decoder(args)
    print("[INFO] running segment-level autoregressive generation...")
    t1 = time.time()
    (
        generated_ids,
        sampled_count,
        sampled_chord_ids,
        sampled_segment_ids,
    ) = generate_segmentwise(
        model=model,
        sample=sample,
        layout=layout,
        device=device,
        use_cache=use_cache,
        temperature=float(args.temperature),
        top_k=int(args.top_k),
        top_p=float(args.top_p),
        greedy=bool(args.greedy),
        max_audio_tokens=max(0, int(args.max_audio_tokens)),
    )

    print(f"[INFO] sampled audio tokens: {sampled_count}")
    print(f"[INFO] output sequence length: {generated_ids.numel()}")
    t2 = time.time()

    print("total time:", t2 - t1)

    save_outputs(
        output_dir=args.output_dir,
        output_prefix=args.output_prefix,
        sample=sample,
        layout=layout,
        generated_ids=generated_ids,
        sampled_chord_ids=sampled_chord_ids,
        sampled_segment_ids=sampled_segment_ids,
        args=args,
        mucodec_decoder=mucodec_decoder,
    )
1081
+
1082
+
1083
+ if __name__ == "__main__":
1084
+ main()
modelling_qwen3.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from typing import Any, Optional
4
+
5
+ from transformers.modeling_outputs import CausalLMOutputWithPast
6
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3Config, Qwen3ForCausalLM
7
+ from transformers.cache_utils import Cache
8
+
9
+ from decoders import Qwen3DecoderLayerAdaLN
10
+ from condition_encoders import ConditionEncoder
11
+ from vocab import (
12
+ CHORD_BOS_ID,
13
+ CHORD_EOS_ID,
14
+ CHORD_N_ID,
15
+ SEGMENT_FALLBACK_ID,
16
+ STRUCTURE_BOS_ID,
17
+ STRUCTURE_EOS_ID,
18
+ )
19
+
20
+
21
class MAGEL(Qwen3ForCausalLM):
    """Qwen3 causal LM adapted for conditioned music-token generation.

    Differences from the stock Qwen3ForCausalLM:
    - masks-based CE loss (``ce_loss``): only positions flagged in ``masks``
      contribute to the cross-entropy.
    - decoder layers replaced with ``Qwen3DecoderLayerAdaLN`` so that a
      chord/structure condition embedding can modulate each layer (AdaLN).
    - during training, contiguous spans of chord/structure conditions are
      randomly dropped to a neutral id (classifier-free-guidance-style
      condition dropout).
    """

    def __init__(
        self,
        config: Qwen3Config,
        **kwargs: Any,
    ):
        super().__init__(config)

        # AdaLN conditioning dimension mirrors the transformer hidden size.
        adaln_dim = int(config.hidden_size)
        chord_dropout_trigger_prob = float(config.magel_chord_dropout_trigger_prob)
        structure_dropout_trigger_prob = float(config.magel_structure_dropout_trigger_prob)

        self.vocab_size = config.vocab_size
        self.adaln_dim = adaln_dim

        self.condition_encoder = ConditionEncoder(hidden_size=adaln_dim)
        self.chord_dropout_trigger_prob = chord_dropout_trigger_prob
        self.structure_dropout_trigger_prob = structure_dropout_trigger_prob

        # Swap every decoder layer for its AdaLN-conditioned variant.
        for layer_idx in range(len(self.model.layers)):
            self.model.layers[layer_idx] = Qwen3DecoderLayerAdaLN(
                config,
                layer_idx=layer_idx,
                cond_dim=adaln_dim,
            )

        # Persist MAGEL-specific ctor args so checkpoints can be reloaded without
        # out-of-band flags.
        self.config.magel_chord_dropout_trigger_prob = chord_dropout_trigger_prob
        self.config.magel_structure_dropout_trigger_prob = structure_dropout_trigger_prob

        self.post_init()

    @staticmethod
    def _drop_audio_condition_spans(
        ids: torch.LongTensor,
        condition_mask: torch.BoolTensor,
        trigger_prob: float,
        replacement_id: int,
        bos_id: int,
        eos_id: int,
    ) -> torch.LongTensor:
        """Randomly replace spans of condition ids with ``replacement_id``.

        For each batch row, with probability ``trigger_prob``, a random
        fraction of that row's eligible condition positions is overwritten in
        chunks of up to 25 contiguous eligible positions.  BOS/EOS ids and
        positions outside ``condition_mask`` are never touched.  Returns the
        input unchanged (same tensor) when no dropout applies, otherwise a
        modified clone.
        """
        if trigger_prob <= 0.0:
            return ids

        # Only drop aligned audio-condition positions; keep BOS/EOS untouched.
        eligible_mask = condition_mask & (ids != bos_id) & (ids != eos_id)

        if not eligible_mask.any():
            return ids

        dropped = ids.clone()
        # One Bernoulli draw per row decides whether that row gets dropout.
        trigger_mask = torch.rand(ids.size(0), device=ids.device) < trigger_prob
        span_len = 25

        for batch_idx in torch.nonzero(trigger_mask, as_tuple=False).flatten():
            candidate_positions = torch.nonzero(
                eligible_mask[batch_idx], as_tuple=False
            ).flatten()
            num_candidates = int(candidate_positions.numel())
            if num_candidates == 0:
                continue
            # Uniform fraction of this row's candidates gets dropped.
            drop_ratio = torch.rand((), device=ids.device).item()
            num_to_drop = int(round(drop_ratio * num_candidates))
            if num_to_drop <= 0:
                continue

            # Drop in contiguous (in candidate order) chunks of <= span_len,
            # removing used positions from the pool after each chunk.
            remaining = num_to_drop
            available_positions = candidate_positions.clone()
            while remaining > 0:
                num_available = int(available_positions.numel())
                if num_available == 0:
                    break

                cur_span_len = min(span_len, remaining)
                if num_available <= cur_span_len:
                    start_idx = 0
                    selected_positions = available_positions[:cur_span_len]
                else:
                    max_start = num_available - cur_span_len + 1
                    start_idx = int(
                        torch.randint(0, max_start, (1,), device=ids.device).item()
                    )
                    selected_positions = available_positions[
                        start_idx : start_idx + cur_span_len
                    ]
                dropped[batch_idx, selected_positions] = replacement_id

                keep_mask = torch.ones(
                    num_available,
                    dtype=torch.bool,
                    device=ids.device,
                )
                keep_mask[start_idx : start_idx + int(selected_positions.numel())] = False
                available_positions = available_positions[keep_mask]
                remaining -= int(selected_positions.numel())

        return dropped

    def _build_condition(
        self,
        chord_ids: Optional[torch.LongTensor],
        structure_ids: Optional[torch.LongTensor],
        condition_mask: Optional[torch.BoolTensor],
        cond_precomputed: Optional[torch.FloatTensor],
    ) -> Optional[torch.FloatTensor]:
        """Resolve the AdaLN condition embedding for this forward pass.

        Precedence: a caller-supplied ``cond_precomputed`` is used verbatim
        (the inference path precomputes it once per sequence); otherwise the
        condition is encoded from ``chord_ids``/``structure_ids``, with span
        dropout applied first when in training mode.  Returns ``None`` when
        no condition inputs are available.
        """
        if cond_precomputed is not None:
            return cond_precomputed
        if chord_ids is None or structure_ids is None:
            return None
        if self.training:
            if condition_mask is None:
                raise ValueError("condition_mask is required during training.")
            chord_ids = self._drop_audio_condition_spans(
                ids=chord_ids,
                condition_mask=condition_mask,
                trigger_prob=self.chord_dropout_trigger_prob,
                replacement_id=CHORD_N_ID,
                bos_id=CHORD_BOS_ID,
                eos_id=CHORD_EOS_ID,
            )
            structure_ids = self._drop_audio_condition_spans(
                ids=structure_ids,
                condition_mask=condition_mask,
                trigger_prob=self.structure_dropout_trigger_prob,
                replacement_id=SEGMENT_FALLBACK_ID,
                bos_id=STRUCTURE_BOS_ID,
                eos_id=STRUCTURE_EOS_ID,
            )
        return self.condition_encoder(chord_ids, structure_ids)

    def ce_loss(
        self,
        logits: torch.FloatTensor,
        labels: Optional[torch.LongTensor],
        masks: Optional[torch.LongTensor],
    ) -> Optional[torch.Tensor]:
        """Masked next-token cross-entropy, averaged over valid positions.

        Standard causal shift (predict position t+1 from t); positions where
        ``masks`` is falsy are set to ignore_index.  Returns ``None`` when
        labels or masks are absent, and a zero scalar when no position is
        valid (keeps the graph differentiable without dividing by zero).
        """
        if labels is None or masks is None:
            return None

        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].clone()
        valid_token_mask = masks[:, 1:].bool().contiguous()

        if not valid_token_mask.any():
            return shift_logits.new_zeros(())

        shift_labels.masked_fill_(~valid_token_mask, -100)
        # Sum-then-divide gives an exact mean over valid tokens only.
        loss_sum = F.cross_entropy(
            shift_logits.view(-1, self.config.vocab_size),
            shift_labels.view(-1).to(shift_logits.device),
            ignore_index=-100,
            reduction="sum",
        )
        valid_count = valid_token_mask.sum().to(
            device=loss_sum.device,
            dtype=loss_sum.dtype,
        )
        return loss_sum / valid_count.clamp_min(1)

    def forward(
        self,
        input_ids: torch.LongTensor,
        masks: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        chord_ids: Optional[torch.LongTensor] = None,
        structure_ids: Optional[torch.LongTensor] = None,
        condition_mask: Optional[torch.BoolTensor] = None,
        cond_precomputed: Optional[torch.FloatTensor] = None,
    ) -> CausalLMOutputWithPast:
        """Conditioned causal-LM forward pass.

        Builds (or reuses) the condition embedding, runs the AdaLN-modified
        backbone with ``cond_expanded``/``condition_mask`` extras, and applies
        the masked CE loss when ``labels`` and ``masks`` are provided.
        """

        if use_cache is None:
            use_cache = self.config.use_cache

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        cond = self._build_condition(
            chord_ids=chord_ids,
            structure_ids=structure_ids,
            condition_mask=condition_mask,
            cond_precomputed=cond_precomputed,
        )

        # The backbone's decoder layers consume cond_expanded/condition_mask
        # (extra kwargs accepted by Qwen3DecoderLayerAdaLN).
        base_out = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cond_expanded=cond,
            condition_mask=condition_mask,
            cache_position=cache_position,
        )

        hidden_states = base_out.last_hidden_state
        logits = self.lm_head(hidden_states)
        loss = self.ce_loss(logits=logits, labels=labels, masks=masks)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=base_out.past_key_values,
            hidden_states=base_out.hidden_states,
            attentions=base_out.attentions,
        )
muse_mucodec_chord.ds/dataset_dict.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"splits": ["train", "validation"]}
runtime_utils.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import datasets
4
+ import numpy as np
5
+ import torch
6
+ from datasets import DatasetDict
7
+ from transformers import AutoConfig
8
+
9
+ from dataset import MusicDataset
10
+ from modelling_qwen3 import MAGEL
11
+
12
+
13
def seed_everything(seed: int) -> None:
    """Seed every RNG the pipeline touches (Python, NumPy, torch, CUDA)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
19
+
20
+
21
def resolve_device(device_arg: str) -> torch.device:
    """Turn a CLI device string into a torch.device.

    ``"auto"`` picks the best available backend (cuda > mps > cpu); any
    other value is passed straight to ``torch.device``.
    """
    if device_arg == "auto":
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")
    return torch.device(device_arg)
29
+
30
+
31
def move_batch_to_device(
    batch: dict[str, torch.Tensor], device: torch.device
) -> dict[str, torch.Tensor]:
    """Return a copy of *batch* with every tensor value moved to *device*.

    Non-tensor values (strings, ints, lists, ...) are passed through as-is.
    """
    moved: dict[str, torch.Tensor] = {}
    for key, value in batch.items():
        moved[key] = value.to(device) if torch.is_tensor(value) else value
    return moved
38
+
39
def load_music_dataset(
    dataset_path: str,
    split: str,
    tokenizer_path: str,
    num_audio_token: int = 16384,
    fps: int = 25,
    use_fast: bool = True,
) -> MusicDataset:
    """Load an on-disk HF dataset and wrap the requested split in MusicDataset.

    Accepts either a ``DatasetDict`` (the split must exist in it) or a bare
    dataset, which is wrapped under the requested split name.
    """
    loaded = datasets.load_from_disk(dataset_path)
    if not isinstance(loaded, DatasetDict):
        container = {split: loaded}
    else:
        if split not in loaded:
            raise KeyError(f"Split not found: {split}")
        container = loaded
    return MusicDataset(
        datasets=container,
        split=split,
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_token,
        fps=fps,
        use_fast=use_fast,
    )
62
+
63
+
64
def load_magel_checkpoint(
    checkpoint_path: str,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
    attn_implementation: str = "sdpa",
) -> MAGEL:
    """Load a MAGEL checkpoint from a local directory, in eval mode on *device*.

    Both the config and the weights are read with ``local_files_only=True``,
    so no hub access is attempted.
    """
    checkpoint_config = AutoConfig.from_pretrained(
        checkpoint_path,
        local_files_only=True,
    )

    loaded_model = MAGEL.from_pretrained(
        checkpoint_path,
        config=checkpoint_config,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
        local_files_only=True,
    )
    loaded_model.to(device=device)
    loaded_model.eval()
    return loaded_model
85
+
86
+
87
def maybe_compile_model(
    model,
    enabled: bool = False,
    mode: str = "reduce-overhead",
):
    """Optionally wrap *model* in ``torch.compile``.

    Tags the returned object with ``_magel_is_compiled`` so later code (see
    the step-begin marker helper) can tell whether CUDA-graph bookkeeping is
    needed.  Raises RuntimeError when compilation is requested but this
    PyTorch build has no ``torch.compile``.
    """
    if not enabled:
        model._magel_is_compiled = False
        return model
    if not hasattr(torch, "compile"):
        raise RuntimeError("torch.compile is not available in this PyTorch build.")
    compiled = torch.compile(model, mode=mode)
    compiled._magel_is_compiled = True
    return compiled
100
+
101
+
102
def maybe_mark_compile_step_begin(model) -> None:
    """Signal a new CUDA-graph step for compiled models; no-op otherwise.

    Does nothing when the model was not tagged by ``maybe_compile_model`` or
    when this PyTorch build lacks ``torch.compiler.cudagraph_mark_step_begin``.
    """
    if not getattr(model, "_magel_is_compiled", False):
        return
    mark_step_begin = getattr(
        getattr(torch, "compiler", None), "cudagraph_mark_step_begin", None
    )
    if mark_step_begin is not None:
        mark_step_begin()
train.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Train MAGEL directly from a vanilla Qwen3 checkpoint.
5
+
6
+ Compared with train.py/train_newparaonly.py, this script:
7
+ 1) Loads an original Qwen3 base checkpoint.
8
+ 2) Resolves MAGEL hparams explicitly at construction time.
9
+ 3) Initializes MAGEL extra modules from scratch and trains end-to-end.
10
+ """
11
+
12
+ import argparse
13
+ import os
14
+
15
+ import torch
16
+ from transformers import (
17
+ AutoConfig,
18
+ Trainer,
19
+ TrainingArguments,
20
+ )
21
+ import datasets
22
+ from dataset import DataCollate, MusicDataset
23
+ from modelling_qwen3 import MAGEL
24
+
25
+
26
+ def resolve_model_source(model_path: str, resume_from_checkpoint: str | None) -> str:
27
+ if not resume_from_checkpoint:
28
+ return model_path
29
+
30
+ if os.path.abspath(model_path) != os.path.abspath(resume_from_checkpoint):
31
+ print(
32
+ "Ignoring --model_path during resume and loading config/model from: "
33
+ f"{resume_from_checkpoint}"
34
+ )
35
+ return resume_from_checkpoint
36
+
37
+
38
def create_model(
    model_path: str,
    model_dtype: torch.dtype,
    target_vocab_size: int,
    attn_implementation: str,
) -> MAGEL:
    """Build a MAGEL model from a local Qwen3 checkpoint and report its size.

    Loads config + weights with ``local_files_only=True`` (size mismatches
    from the MAGEL extras are ignored), resizes the token embeddings to
    *target_vocab_size*, and prints parameter counts plus the resolved MAGEL
    hyper-parameters.
    """
    print(f"Loading Qwen3 model from: {model_path}")

    base_config = AutoConfig.from_pretrained(
        model_path,
        local_files_only=True,
    )

    magel_model = MAGEL.from_pretrained(
        model_path,
        torch_dtype=model_dtype,
        config=base_config,
        attn_implementation=attn_implementation,
        ignore_mismatched_sizes=True,
        local_files_only=True,
    )
    magel_model.resize_token_embeddings(target_vocab_size)

    # Parameter accounting: total, trainable, and MAGEL-only extras
    # (condition encoder + per-layer AdaLN modules).
    total_params = sum(p.numel() for p in magel_model.parameters())
    trainable_params = sum(
        p.numel() for p in magel_model.parameters() if p.requires_grad
    )
    magel_extra_params = sum(
        p.numel()
        for name, p in magel_model.named_parameters()
        if ("condition_encoder" in name or "dit_adaln" in name)
    )

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"MAGEL extra parameters: {magel_extra_params:,}")
    print(
        "MAGEL config: "
        f"adaln_dim={magel_model.adaln_dim}, "
        f"chord_dropout_trigger_prob={magel_model.chord_dropout_trigger_prob}, "
        f"structure_dropout_trigger_prob={magel_model.structure_dropout_trigger_prob}"
    )

    return magel_model
80
+
81
+
82
def create_dataset(
    dataset_path: str,
    tokenizer_path: str,
    num_audio_token: int = 16384,
) -> MusicDataset:
    """Load the on-disk HF dataset and wrap its train split in MusicDataset."""
    print(f"Loading dataset from: {dataset_path}")
    print(f"Loading tokenizer from: {tokenizer_path}")

    raw_dataset = datasets.load_from_disk(dataset_path)

    train_dataset = MusicDataset(
        raw_dataset,
        split="train",
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_token,
        use_fast=True,
    )
    print(f"Dataset size: {len(train_dataset)}")
    return train_dataset
103
+
104
+
105
def build_arg_parser() -> argparse.ArgumentParser:
    """Define the training CLI.

    Extracted from `main` so the argument surface (names, defaults, choices)
    can be inspected and tested without touching model or dataset code.
    """
    parser = argparse.ArgumentParser(
        description="Train MAGEL directly from a vanilla Qwen3 base checkpoint."
    )

    # Data / checkpoint locations.
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="muse_mucodec_chord.ds",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="checkpoints/Qwen3-0.6B",
        help="Local Qwen3 base checkpoint path.",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="checkpoints/Qwen3-0.6B",
        help="Local tokenizer checkpoint path.",
    )
    parser.add_argument(
        "--model_dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        choices=["eager", "sdpa", "flash_attention_2"],
    )

    # Optimization schedule.
    parser.add_argument("--output_dir", type=str, default="./output_qwen3_0p6b_train")
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--num_train_epochs", type=float, default=20)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--max_grad_norm", type=float, default=5.0)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Resume training from a Trainer checkpoint directory such as output_dir/checkpoint-500.",
    )

    # Runtime behavior.
    parser.add_argument("--dataloader_num_workers", type=int, default=12)
    parser.add_argument(
        "--gradient_checkpointing",
        dest="gradient_checkpointing",
        action="store_true",
    )
    parser.add_argument(
        "--deepspeed",
        type=str,
        default=None,
        help="Path to DeepSpeed config. Leave unset to disable DeepSpeed.",
    )

    # Experiment tracking.
    parser.add_argument("--report_to", type=str, default="wandb")
    parser.add_argument("--wandb_project", type=str, default="vaultum-qwen3-0p6b")
    parser.add_argument("--wandb_run_name", type=str, default=None)

    return parser


def main():
    """Entry point: parse args, build dataset and model, run HF Trainer."""
    args = build_arg_parser().parse_args()

    model_dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[args.model_dtype]

    # When resuming, config/weights are read from the checkpoint directory
    # rather than the base model path.
    model_source = resolve_model_source(
        model_path=args.model_path,
        resume_from_checkpoint=args.resume_from_checkpoint,
    )

    base_config = AutoConfig.from_pretrained(
        model_source,
        local_files_only=True,
    )

    # Audio vocab size is taken from the checkpoint config so that dataset
    # tokenization and the model's audio head stay in sync.
    num_audio_token = int(base_config.magel_num_audio_token)
    print(f"Using num_audio_token={num_audio_token}")

    train_dataset = create_dataset(
        dataset_path=args.dataset_path,
        tokenizer_path=args.tokenizer_path,
        num_audio_token=num_audio_token,
    )

    target_vocab_size = train_dataset.tokenizer_vocab_size

    model = create_model(
        model_path=model_source,
        model_dtype=model_dtype,
        attn_implementation=args.attn_implementation,
        target_vocab_size=target_vocab_size,
    )

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        num_train_epochs=args.num_train_epochs,
        warmup_steps=args.warmup_steps,
        max_grad_norm=args.max_grad_norm,
        logging_steps=args.logging_steps,
        save_strategy="epoch",
        dataloader_num_workers=args.dataloader_num_workers,
        bf16=(args.model_dtype == "bfloat16"),
        fp16=(args.model_dtype == "float16"),
        gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        deepspeed=args.deepspeed,
        remove_unused_columns=False,
        dataloader_drop_last=True,
        report_to=args.report_to,
        logging_dir=None,
        run_name=args.wandb_run_name,
    )

    # Must be exported before the Trainer initializes its wandb callback.
    if args.wandb_project and "wandb" in args.report_to:
        os.environ["WANDB_PROJECT"] = args.wandb_project

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=DataCollate(),
    )

    if args.resume_from_checkpoint:
        print(f"Resuming training from checkpoint: {args.resume_from_checkpoint}")
    else:
        print("Starting training from current model initialization.")

    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    # Persist final weights and tokenizer together for later inference.
    final_dir = os.path.join(args.output_dir, "final")
    trainer.save_model(final_dir)
    train_dataset.tokenizer.save_pretrained(final_dir)

    print(f"Training complete. Final model saved to: {final_dir}")


if __name__ == "__main__":
    main()
vocab/__init__.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""Condition vocab package.

Re-exports the chord and section/structure vocabularies so callers can use
``from vocab import chord_to_id`` etc. ``__all__`` mirrors the public
constants and helpers of the two submodules.
"""

from .chord import (
    CHORD_BOS_ID,
    CHORD_EOS_ID,
    CHORD_LABELS,
    CHORD_LABEL_TO_ID,
    CHORD_N_ID,
    NUM_CHORD_CLASSES,
    build_frame_chord_ids,
    chord_id_to_label,
    chord_to_id,
    normalize_chord_text,
)
from .sections import (
    SEGMENT_FALLBACK_ID,
    SEGMENT_LABELS,
    SEGMENT_LABEL_TO_ID,
    NUM_STRUCTURE_CLASSES,
    STRUCTURE_BOS_ID,
    STRUCTURE_EOS_ID,
    build_frame_structure_ids,
    normalize_structure_label,
    structure_id_to_label,
    structure_to_id,
)

# Explicit public API for `from vocab import *`.
__all__ = [
    "CHORD_BOS_ID",
    "CHORD_EOS_ID",
    "CHORD_LABELS",
    "CHORD_LABEL_TO_ID",
    "CHORD_N_ID",
    "NUM_CHORD_CLASSES",
    "normalize_chord_text",
    "chord_to_id",
    "chord_id_to_label",
    "build_frame_chord_ids",
    "SEGMENT_LABELS",
    "SEGMENT_LABEL_TO_ID",
    "SEGMENT_FALLBACK_ID",
    "STRUCTURE_BOS_ID",
    "STRUCTURE_EOS_ID",
    "NUM_STRUCTURE_CLASSES",
    "normalize_structure_label",
    "structure_to_id",
    "structure_id_to_label",
    "build_frame_structure_ids",
]
vocab/chord.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""Chord vocab helpers.

Dataset inspection summary for `muse_mucodec_chord.ds`:
- `chords[].type` uses only 25 labels
- labels are either `N` or `Root:maj|min`
- roots are represented with sharps rather than flats
"""

import math
import re

import numpy as np

# Id layout: 0 = pad, 1 = N (no chord), 2..25 = Root:maj|min, 26 = bos, 27 = eos.
CHORD_BOS_ID = 26
CHORD_EOS_ID = 27
CHORD_N_ID = 1
NUM_CHORD_CLASSES = 28

# Canonical roots are spelled with sharps; enharmonic flat/edge spellings
# fold onto the same pitch class.
_SHARP_ROOTS = (
    "C",
    "C#",
    "D",
    "D#",
    "E",
    "F",
    "F#",
    "G",
    "G#",
    "A",
    "A#",
    "B",
)
_ENHARMONIC_PC = {
    "B#": 0,
    "Db": 1,
    "Eb": 3,
    "Fb": 4,
    "E#": 5,
    "Gb": 6,
    "Ab": 8,
    "Bb": 10,
    "Cb": 11,
}
_PITCH_TO_PC = {root: pc for pc, root in enumerate(_SHARP_ROOTS)}
_PITCH_TO_PC.update(_ENHARMONIC_PC)

CHORD_ROOTS = _SHARP_ROOTS
_QUALITIES = ("maj", "min")
CHORD_LABELS = (
    ("pad", "N")
    + tuple(f"{root}:{quality}" for root in CHORD_ROOTS for quality in _QUALITIES)
    + ("bos", "eos")
)

CHORD_LABEL_TO_ID = {label: index for index, label in enumerate(CHORD_LABELS)}


def normalize_chord_text(label: str) -> str:
    """Canonicalize raw chord text to `Root:maj|min` with a sharp root, or `N`."""
    if not isinstance(label, str):
        return "N"

    text = label.strip()
    if not text or text.lower() in {"n", "none"}:
        return "N"

    # Fold unicode accidentals to ASCII before matching.
    cleaned = text.replace("♯", "#").replace("♭", "b").strip()
    match = re.fullmatch(r"([A-Ga-g])([#b]?):(maj|min)", cleaned, flags=re.IGNORECASE)
    if match is None:
        return "N"

    pitch_class = _PITCH_TO_PC.get(match.group(1).upper() + match.group(2))
    if pitch_class is None:
        return "N"

    return f"{CHORD_ROOTS[pitch_class]}:{match.group(3).lower()}"


def chord_to_id(label: str) -> int:
    """Return the class id for a chord label; unrecognized text maps to `N`."""
    canonical = normalize_chord_text(label)
    return CHORD_LABEL_TO_ID.get(canonical, CHORD_N_ID)


def chord_id_to_label(chord_id: int) -> str:
    """Inverse of `chord_to_id`; out-of-range ids map to `N`."""
    if chord_id < 0 or chord_id >= len(CHORD_LABELS):
        return "N"
    return CHORD_LABELS[chord_id]


def build_frame_chord_ids(
    chord_segments: list[dict], total_frames: int, fps: int = 25
) -> np.ndarray:
    """Rasterize chord segments into a per-frame id array of length `total_frames`.

    Segments may carry explicit `start_frame`/`end_frame` indices, or
    `start`/`end` times in seconds (converted at `fps` frames per second).
    Frames not covered by any segment keep CHORD_N_ID.
    """
    if total_frames > 0:
        frame_ids = np.full((total_frames,), CHORD_N_ID, dtype=np.int64)
    else:
        # Non-positive lengths yield an empty array instead of raising.
        frame_ids = np.zeros((0,), dtype=np.int64)

    if not isinstance(chord_segments, (list, tuple)):
        return frame_ids

    for segment in chord_segments:
        if not isinstance(segment, dict):
            continue

        if "start_frame" in segment and "end_frame" in segment:
            start = int(segment.get("start_frame", 0))
            end = int(segment.get("end_frame", 0))
        else:
            start = int(float(segment.get("start", 0.0)) * fps)
            end = int(math.ceil(float(segment.get("end", 0.0)) * fps))

        # Clamp into [0, total_frames]; empty spans are skipped.
        start = max(0, min(total_frames, start))
        end = max(start, min(total_frames, end))
        if end <= start:
            continue

        raw = segment.get("type") or segment.get("chord") or segment.get("label") or "N"
        frame_ids[start:end] = chord_to_id(str(raw))

    return frame_ids


if __name__ == "__main__":
    import datasets

    ds = datasets.load_from_disk("muse_mucodec_chord.ds")
    sample = ds["train"][0]
    total_frames = 1500  # 60 seconds of audio at 25 fps
    chord_arr = build_frame_chord_ids(sample["chords"], total_frames)
    print(chord_arr)
vocab/sections.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""Section vocab helpers.

Dataset inspection summary for `muse_mucodec_chord.ds`:
- `sections[]` has fields `section/text/start/end/desc`
- `sections[].section` uses only 6 labels:
  `Intro`, `Verse`, `Prechorus`, `Chorus`, `Bridge`, `Outro`
"""

import math
import re

import numpy as np

# Frame-level class ids: 0 is the padding/unknown class, 1..6 the sections.
SEGMENT_LABELS = (
    "pad",
    "intro",
    "verse",
    "chorus",
    "prechorus",
    "bridge",
    "outro",
)
SEGMENT_LABEL_TO_ID = {label: index for index, label in enumerate(SEGMENT_LABELS)}
SEGMENT_FALLBACK_ID = SEGMENT_LABEL_TO_ID["pad"]

# BOS/EOS ids are appended after the section classes for sequence decoding.
STRUCTURE_BOS_ID = len(SEGMENT_LABELS)
STRUCTURE_EOS_ID = STRUCTURE_BOS_ID + 1
NUM_STRUCTURE_CLASSES = STRUCTURE_EOS_ID + 1


def normalize_structure_label(label: str) -> str:
    """Canonicalize a raw section name to one of SEGMENT_LABELS.

    Case, whitespace, dashes/underscores, and numbering are ignored
    (e.g. "Verse 2" -> "verse"); anything unrecognized maps to "pad".
    """
    if not isinstance(label, str):
        return "pad"

    normalized = re.sub(r"[\s_-]+", "", label.strip().lower())
    normalized = re.sub(r"\d+", "", normalized)

    # "pad" itself also maps to "pad", so membership alone is sufficient.
    if normalized in SEGMENT_LABEL_TO_ID:
        return normalized
    return "pad"


def structure_to_id(structure: str) -> int:
    """Return the class id for a section label (unknown -> SEGMENT_FALLBACK_ID)."""
    return SEGMENT_LABEL_TO_ID.get(
        normalize_structure_label(structure), SEGMENT_FALLBACK_ID
    )


def structure_id_to_label(segment_id: int) -> str:
    """Inverse of `structure_to_id`, including the bos/eos sentinel ids."""
    if segment_id == STRUCTURE_BOS_ID:
        return "bos"
    if segment_id == STRUCTURE_EOS_ID:
        return "eos"
    if 0 <= segment_id < len(SEGMENT_LABELS):
        return SEGMENT_LABELS[segment_id]
    return "pad"


def build_frame_structure_ids(
    sections: list[dict], total_frames: int, fps: int = 25
) -> np.ndarray:
    """Rasterize section annotations into a per-frame id array.

    Segments may carry explicit `start_frame`/`end_frame` indices, or
    `start`/`end` times in seconds (converted at `fps` frames per second).
    Frames not covered by any segment keep SEGMENT_FALLBACK_ID.
    """
    # Guard non-positive lengths so np.full never sees a negative shape;
    # mirrors the behavior of build_frame_chord_ids in vocab/chord.py.
    if total_frames <= 0:
        return np.zeros((0,), dtype=np.int64)

    labels = np.full((total_frames,), SEGMENT_FALLBACK_ID, dtype=np.int64)
    if not isinstance(sections, (list, tuple)):
        return labels

    for seg in sections:
        if not isinstance(seg, dict):
            continue

        if "start_frame" in seg and "end_frame" in seg:
            start = int(seg.get("start_frame", 0))
            end = int(seg.get("end_frame", 0))
        else:
            start = int(float(seg.get("start", 0.0)) * fps)
            end = int(math.ceil(float(seg.get("end", 0.0)) * fps))

        # Clamp into [0, total_frames]; empty spans are skipped.
        start = max(0, min(total_frames, start))
        end = max(start, min(total_frames, end))
        if end <= start:
            continue

        label = seg.get("section", seg.get("structure", "")) or ""
        labels[start:end] = structure_to_id(str(label))

    return labels


if __name__ == "__main__":
    # Example usage: a 50-second clip rasterized at 25 fps.
    segments = [
        {"section": "intro", "start": 0.0, "end": 10.0},
        {"section": "Verse", "start": 10.0, "end": 30.0},
        {"section": "Chorus", "start": 30.0, "end": 50.0},
    ]
    total_frames = 50 * 25  # 1250 frames for 50 seconds at 25 fps
    labels = build_frame_structure_ids(segments, total_frames)
    print(labels)
wandb/debug-cli.root.log ADDED
File without changes