Add files using upload-large-folder tool
Browse files- MuCodec/.gitattributes +2 -0
- MuCodec/.gitignore +3 -0
- MuCodec/LICENSE +21 -0
- MuCodec/LICENSE_weights +399 -0
- MuCodec/__pycache__/generate.cpython-310.pyc +0 -0
- MuCodec/__pycache__/generate.cpython-312.pyc +0 -0
- MuCodec/__pycache__/model.cpython-310.pyc +0 -0
- MuCodec/__pycache__/model.cpython-312.pyc +0 -0
- MuCodec/configs/models/transformer2D.json +25 -0
- MuCodec/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json +14 -0
- MuCodec/generate.py +247 -0
- MuCodec/libs/rvq/__pycache__/descript_quantize3.cpython-310.pyc +0 -0
- MuCodec/libs/rvq/descript_quantize3.py +298 -0
- MuCodec/model.py +367 -0
- MuCodec/models/attention.py +682 -0
- MuCodec/models/transformer_2d_flow.py +545 -0
- MuCodec/mp3_to_code.py +187 -0
- MuCodec/muq_dev/test.py +22 -0
- MuCodec/readme.md +67 -0
- MuCodec/requirements.txt +335 -0
- MuCodec/tools/get_melvaehifigan48k.py +1551 -0
- MuCodec/tools/torch_tools.py +100 -0
- __pycache__/audio_tokens.cpython-310.pyc +0 -0
- __pycache__/audio_tokens.cpython-312.pyc +0 -0
- __pycache__/condition_encoders.cpython-310.pyc +0 -0
- __pycache__/condition_encoders.cpython-312.pyc +0 -0
- __pycache__/dataset.cpython-310.pyc +0 -0
- __pycache__/dataset.cpython-312.pyc +0 -0
- __pycache__/decoders.cpython-310.pyc +0 -0
- __pycache__/decoders.cpython-312.pyc +0 -0
- __pycache__/inference_full.cpython-310.pyc +0 -0
- __pycache__/inference_full.cpython-312.pyc +0 -0
- __pycache__/modelling_qwen3.cpython-310.pyc +0 -0
- __pycache__/modelling_qwen3.cpython-312.pyc +0 -0
- __pycache__/runtime_utils.cpython-310.pyc +0 -0
- __pycache__/runtime_utils.cpython-312.pyc +0 -0
- audio_tokens.py +21 -0
- batch_infer_checkpoints.py +402 -0
- condition_encoders.py +149 -0
- dataset.py +513 -0
- decoders.py +158 -0
- inference_full.py +1084 -0
- modelling_qwen3.py +237 -0
- muse_mucodec_chord.ds/dataset_dict.json +1 -0
- runtime_utils.py +111 -0
- train.py +259 -0
- vocab/__init__.py +51 -0
- vocab/chord.py +144 -0
- vocab/sections.py +105 -0
- wandb/debug-cli.root.log +0 -0
MuCodec/.gitattributes
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
MuCodec/.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
*.pt
|
| 3 |
+
*.pth
|
MuCodec/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
MuCodec/LICENSE_weights
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Attribution-NonCommercial 4.0 International
|
| 2 |
+
|
| 3 |
+
=======================================================================
|
| 4 |
+
|
| 5 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
| 6 |
+
does not provide legal services or legal advice. Distribution of
|
| 7 |
+
Creative Commons public licenses does not create a lawyer-client or
|
| 8 |
+
other relationship. Creative Commons makes its licenses and related
|
| 9 |
+
information available on an "as-is" basis. Creative Commons gives no
|
| 10 |
+
warranties regarding its licenses, any material licensed under their
|
| 11 |
+
terms and conditions, or any related information. Creative Commons
|
| 12 |
+
disclaims all liability for damages resulting from their use to the
|
| 13 |
+
fullest extent possible.
|
| 14 |
+
|
| 15 |
+
Using Creative Commons Public Licenses
|
| 16 |
+
|
| 17 |
+
Creative Commons public licenses provide a standard set of terms and
|
| 18 |
+
conditions that creators and other rights holders may use to share
|
| 19 |
+
original works of authorship and other material subject to copyright
|
| 20 |
+
and certain other rights specified in the public license below. The
|
| 21 |
+
following considerations are for informational purposes only, are not
|
| 22 |
+
exhaustive, and do not form part of our licenses.
|
| 23 |
+
|
| 24 |
+
Considerations for licensors: Our public licenses are
|
| 25 |
+
intended for use by those authorized to give the public
|
| 26 |
+
permission to use material in ways otherwise restricted by
|
| 27 |
+
copyright and certain other rights. Our licenses are
|
| 28 |
+
irrevocable. Licensors should read and understand the terms
|
| 29 |
+
and conditions of the license they choose before applying it.
|
| 30 |
+
Licensors should also secure all rights necessary before
|
| 31 |
+
applying our licenses so that the public can reuse the
|
| 32 |
+
material as expected. Licensors should clearly mark any
|
| 33 |
+
material not subject to the license. This includes other CC-
|
| 34 |
+
licensed material, or material used under an exception or
|
| 35 |
+
limitation to copyright. More considerations for licensors:
|
| 36 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
| 37 |
+
|
| 38 |
+
Considerations for the public: By using one of our public
|
| 39 |
+
licenses, a licensor grants the public permission to use the
|
| 40 |
+
licensed material under specified terms and conditions. If
|
| 41 |
+
the licensor's permission is not necessary for any reason--for
|
| 42 |
+
example, because of any applicable exception or limitation to
|
| 43 |
+
copyright--then that use is not regulated by the license. Our
|
| 44 |
+
licenses grant only permissions under copyright and certain
|
| 45 |
+
other rights that a licensor has authority to grant. Use of
|
| 46 |
+
the licensed material may still be restricted for other
|
| 47 |
+
reasons, including because others have copyright or other
|
| 48 |
+
rights in the material. A licensor may make special requests,
|
| 49 |
+
such as asking that all changes be marked or described.
|
| 50 |
+
Although not required by our licenses, you are encouraged to
|
| 51 |
+
respect those requests where reasonable. More_considerations
|
| 52 |
+
for the public:
|
| 53 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
| 54 |
+
|
| 55 |
+
=======================================================================
|
| 56 |
+
|
| 57 |
+
Creative Commons Attribution-NonCommercial 4.0 International Public
|
| 58 |
+
License
|
| 59 |
+
|
| 60 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
| 61 |
+
to be bound by the terms and conditions of this Creative Commons
|
| 62 |
+
Attribution-NonCommercial 4.0 International Public License ("Public
|
| 63 |
+
License"). To the extent this Public License may be interpreted as a
|
| 64 |
+
contract, You are granted the Licensed Rights in consideration of Your
|
| 65 |
+
acceptance of these terms and conditions, and the Licensor grants You
|
| 66 |
+
such rights in consideration of benefits the Licensor receives from
|
| 67 |
+
making the Licensed Material available under these terms and
|
| 68 |
+
conditions.
|
| 69 |
+
|
| 70 |
+
Section 1 -- Definitions.
|
| 71 |
+
|
| 72 |
+
a. Adapted Material means material subject to Copyright and Similar
|
| 73 |
+
Rights that is derived from or based upon the Licensed Material
|
| 74 |
+
and in which the Licensed Material is translated, altered,
|
| 75 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
| 76 |
+
permission under the Copyright and Similar Rights held by the
|
| 77 |
+
Licensor. For purposes of this Public License, where the Licensed
|
| 78 |
+
Material is a musical work, performance, or sound recording,
|
| 79 |
+
Adapted Material is always produced where the Licensed Material is
|
| 80 |
+
synched in timed relation with a moving image.
|
| 81 |
+
|
| 82 |
+
b. Adapter's License means the license You apply to Your Copyright
|
| 83 |
+
and Similar Rights in Your contributions to Adapted Material in
|
| 84 |
+
accordance with the terms and conditions of this Public License.
|
| 85 |
+
|
| 86 |
+
c. Copyright and Similar Rights means copyright and/or similar rights
|
| 87 |
+
closely related to copyright including, without limitation,
|
| 88 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
| 89 |
+
Rights, without regard to how the rights are labeled or
|
| 90 |
+
categorized. For purposes of this Public License, the rights
|
| 91 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
| 92 |
+
Rights.
|
| 93 |
+
d. Effective Technological Measures means those measures that, in the
|
| 94 |
+
absence of proper authority, may not be circumvented under laws
|
| 95 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
| 96 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
| 97 |
+
agreements.
|
| 98 |
+
|
| 99 |
+
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
| 100 |
+
any other exception or limitation to Copyright and Similar Rights
|
| 101 |
+
that applies to Your use of the Licensed Material.
|
| 102 |
+
|
| 103 |
+
f. Licensed Material means the artistic or literary work, database,
|
| 104 |
+
or other material to which the Licensor applied this Public
|
| 105 |
+
License.
|
| 106 |
+
|
| 107 |
+
g. Licensed Rights means the rights granted to You subject to the
|
| 108 |
+
terms and conditions of this Public License, which are limited to
|
| 109 |
+
all Copyright and Similar Rights that apply to Your use of the
|
| 110 |
+
Licensed Material and that the Licensor has authority to license.
|
| 111 |
+
|
| 112 |
+
h. Licensor means the individual(s) or entity(ies) granting rights
|
| 113 |
+
under this Public License.
|
| 114 |
+
|
| 115 |
+
i. NonCommercial means not primarily intended for or directed towards
|
| 116 |
+
commercial advantage or monetary compensation. For purposes of
|
| 117 |
+
this Public License, the exchange of the Licensed Material for
|
| 118 |
+
other material subject to Copyright and Similar Rights by digital
|
| 119 |
+
file-sharing or similar means is NonCommercial provided there is
|
| 120 |
+
no payment of monetary compensation in connection with the
|
| 121 |
+
exchange.
|
| 122 |
+
|
| 123 |
+
j. Share means to provide material to the public by any means or
|
| 124 |
+
process that requires permission under the Licensed Rights, such
|
| 125 |
+
as reproduction, public display, public performance, distribution,
|
| 126 |
+
dissemination, communication, or importation, and to make material
|
| 127 |
+
available to the public including in ways that members of the
|
| 128 |
+
public may access the material from a place and at a time
|
| 129 |
+
individually chosen by them.
|
| 130 |
+
|
| 131 |
+
k. Sui Generis Database Rights means rights other than copyright
|
| 132 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
| 133 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
| 134 |
+
as amended and/or succeeded, as well as other essentially
|
| 135 |
+
equivalent rights anywhere in the world.
|
| 136 |
+
|
| 137 |
+
l. You means the individual or entity exercising the Licensed Rights
|
| 138 |
+
under this Public License. Your has a corresponding meaning.
|
| 139 |
+
|
| 140 |
+
Section 2 -- Scope.
|
| 141 |
+
|
| 142 |
+
a. License grant.
|
| 143 |
+
|
| 144 |
+
1. Subject to the terms and conditions of this Public License,
|
| 145 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
| 146 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
| 147 |
+
exercise the Licensed Rights in the Licensed Material to:
|
| 148 |
+
|
| 149 |
+
a. reproduce and Share the Licensed Material, in whole or
|
| 150 |
+
in part, for NonCommercial purposes only; and
|
| 151 |
+
|
| 152 |
+
b. produce, reproduce, and Share Adapted Material for
|
| 153 |
+
NonCommercial purposes only.
|
| 154 |
+
|
| 155 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
| 156 |
+
Exceptions and Limitations apply to Your use, this Public
|
| 157 |
+
License does not apply, and You do not need to comply with
|
| 158 |
+
its terms and conditions.
|
| 159 |
+
|
| 160 |
+
3. Term. The term of this Public License is specified in Section
|
| 161 |
+
6(a).
|
| 162 |
+
|
| 163 |
+
4. Media and formats; technical modifications allowed. The
|
| 164 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
| 165 |
+
all media and formats whether now known or hereafter created,
|
| 166 |
+
and to make technical modifications necessary to do so. The
|
| 167 |
+
Licensor waives and/or agrees not to assert any right or
|
| 168 |
+
authority to forbid You from making technical modifications
|
| 169 |
+
necessary to exercise the Licensed Rights, including
|
| 170 |
+
technical modifications necessary to circumvent Effective
|
| 171 |
+
Technological Measures. For purposes of this Public License,
|
| 172 |
+
simply making modifications authorized by this Section 2(a)
|
| 173 |
+
(4) never produces Adapted Material.
|
| 174 |
+
|
| 175 |
+
5. Downstream recipients.
|
| 176 |
+
|
| 177 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
| 178 |
+
recipient of the Licensed Material automatically
|
| 179 |
+
receives an offer from the Licensor to exercise the
|
| 180 |
+
Licensed Rights under the terms and conditions of this
|
| 181 |
+
Public License.
|
| 182 |
+
|
| 183 |
+
b. No downstream restrictions. You may not offer or impose
|
| 184 |
+
any additional or different terms or conditions on, or
|
| 185 |
+
apply any Effective Technological Measures to, the
|
| 186 |
+
Licensed Material if doing so restricts exercise of the
|
| 187 |
+
Licensed Rights by any recipient of the Licensed
|
| 188 |
+
Material.
|
| 189 |
+
|
| 190 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
| 191 |
+
may be construed as permission to assert or imply that You
|
| 192 |
+
are, or that Your use of the Licensed Material is, connected
|
| 193 |
+
with, or sponsored, endorsed, or granted official status by,
|
| 194 |
+
the Licensor or others designated to receive attribution as
|
| 195 |
+
provided in Section 3(a)(1)(A)(i).
|
| 196 |
+
|
| 197 |
+
b. Other rights.
|
| 198 |
+
|
| 199 |
+
1. Moral rights, such as the right of integrity, are not
|
| 200 |
+
licensed under this Public License, nor are publicity,
|
| 201 |
+
privacy, and/or other similar personality rights; however, to
|
| 202 |
+
the extent possible, the Licensor waives and/or agrees not to
|
| 203 |
+
assert any such rights held by the Licensor to the limited
|
| 204 |
+
extent necessary to allow You to exercise the Licensed
|
| 205 |
+
Rights, but not otherwise.
|
| 206 |
+
|
| 207 |
+
2. Patent and trademark rights are not licensed under this
|
| 208 |
+
Public License.
|
| 209 |
+
|
| 210 |
+
3. To the extent possible, the Licensor waives any right to
|
| 211 |
+
collect royalties from You for the exercise of the Licensed
|
| 212 |
+
Rights, whether directly or through a collecting society
|
| 213 |
+
under any voluntary or waivable statutory or compulsory
|
| 214 |
+
licensing scheme. In all other cases the Licensor expressly
|
| 215 |
+
reserves any right to collect such royalties, including when
|
| 216 |
+
the Licensed Material is used other than for NonCommercial
|
| 217 |
+
purposes.
|
| 218 |
+
|
| 219 |
+
Section 3 -- License Conditions.
|
| 220 |
+
|
| 221 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
| 222 |
+
following conditions.
|
| 223 |
+
|
| 224 |
+
a. Attribution.
|
| 225 |
+
|
| 226 |
+
1. If You Share the Licensed Material (including in modified
|
| 227 |
+
form), You must:
|
| 228 |
+
|
| 229 |
+
a. retain the following if it is supplied by the Licensor
|
| 230 |
+
with the Licensed Material:
|
| 231 |
+
|
| 232 |
+
i. identification of the creator(s) of the Licensed
|
| 233 |
+
Material and any others designated to receive
|
| 234 |
+
attribution, in any reasonable manner requested by
|
| 235 |
+
the Licensor (including by pseudonym if
|
| 236 |
+
designated);
|
| 237 |
+
|
| 238 |
+
ii. a copyright notice;
|
| 239 |
+
|
| 240 |
+
iii. a notice that refers to this Public License;
|
| 241 |
+
|
| 242 |
+
iv. a notice that refers to the disclaimer of
|
| 243 |
+
warranties;
|
| 244 |
+
|
| 245 |
+
v. a URI or hyperlink to the Licensed Material to the
|
| 246 |
+
extent reasonably practicable;
|
| 247 |
+
|
| 248 |
+
b. indicate if You modified the Licensed Material and
|
| 249 |
+
retain an indication of any previous modifications; and
|
| 250 |
+
|
| 251 |
+
c. indicate the Licensed Material is licensed under this
|
| 252 |
+
Public License, and include the text of, or the URI or
|
| 253 |
+
hyperlink to, this Public License.
|
| 254 |
+
|
| 255 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
| 256 |
+
reasonable manner based on the medium, means, and context in
|
| 257 |
+
which You Share the Licensed Material. For example, it may be
|
| 258 |
+
reasonable to satisfy the conditions by providing a URI or
|
| 259 |
+
hyperlink to a resource that includes the required
|
| 260 |
+
information.
|
| 261 |
+
|
| 262 |
+
3. If requested by the Licensor, You must remove any of the
|
| 263 |
+
information required by Section 3(a)(1)(A) to the extent
|
| 264 |
+
reasonably practicable.
|
| 265 |
+
|
| 266 |
+
4. If You Share Adapted Material You produce, the Adapter's
|
| 267 |
+
License You apply must not prevent recipients of the Adapted
|
| 268 |
+
Material from complying with this Public License.
|
| 269 |
+
|
| 270 |
+
Section 4 -- Sui Generis Database Rights.
|
| 271 |
+
|
| 272 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
| 273 |
+
apply to Your use of the Licensed Material:
|
| 274 |
+
|
| 275 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
| 276 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
| 277 |
+
portion of the contents of the database for NonCommercial purposes
|
| 278 |
+
only;
|
| 279 |
+
|
| 280 |
+
b. if You include all or a substantial portion of the database
|
| 281 |
+
contents in a database in which You have Sui Generis Database
|
| 282 |
+
Rights, then the database in which You have Sui Generis Database
|
| 283 |
+
Rights (but not its individual contents) is Adapted Material; and
|
| 284 |
+
|
| 285 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
| 286 |
+
all or a substantial portion of the contents of the database.
|
| 287 |
+
|
| 288 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
| 289 |
+
replace Your obligations under this Public License where the Licensed
|
| 290 |
+
Rights include other Copyright and Similar Rights.
|
| 291 |
+
|
| 292 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
| 293 |
+
|
| 294 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
| 295 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
| 296 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
| 297 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
| 298 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
| 299 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
| 300 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
| 301 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
| 302 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
| 303 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
| 304 |
+
|
| 305 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
| 306 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
| 307 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
| 308 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
| 309 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
| 310 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
| 311 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
| 312 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
| 313 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
| 314 |
+
|
| 315 |
+
c. The disclaimer of warranties and limitation of liability provided
|
| 316 |
+
above shall be interpreted in a manner that, to the extent
|
| 317 |
+
possible, most closely approximates an absolute disclaimer and
|
| 318 |
+
waiver of all liability.
|
| 319 |
+
|
| 320 |
+
Section 6 -- Term and Termination.
|
| 321 |
+
|
| 322 |
+
a. This Public License applies for the term of the Copyright and
|
| 323 |
+
Similar Rights licensed here. However, if You fail to comply with
|
| 324 |
+
this Public License, then Your rights under this Public License
|
| 325 |
+
terminate automatically.
|
| 326 |
+
|
| 327 |
+
b. Where Your right to use the Licensed Material has terminated under
|
| 328 |
+
Section 6(a), it reinstates:
|
| 329 |
+
|
| 330 |
+
1. automatically as of the date the violation is cured, provided
|
| 331 |
+
it is cured within 30 days of Your discovery of the
|
| 332 |
+
violation; or
|
| 333 |
+
|
| 334 |
+
2. upon express reinstatement by the Licensor.
|
| 335 |
+
|
| 336 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
| 337 |
+
right the Licensor may have to seek remedies for Your violations
|
| 338 |
+
of this Public License.
|
| 339 |
+
|
| 340 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
| 341 |
+
Licensed Material under separate terms or conditions or stop
|
| 342 |
+
distributing the Licensed Material at any time; however, doing so
|
| 343 |
+
will not terminate this Public License.
|
| 344 |
+
|
| 345 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
| 346 |
+
License.
|
| 347 |
+
|
| 348 |
+
Section 7 -- Other Terms and Conditions.
|
| 349 |
+
|
| 350 |
+
a. The Licensor shall not be bound by any additional or different
|
| 351 |
+
terms or conditions communicated by You unless expressly agreed.
|
| 352 |
+
|
| 353 |
+
b. Any arrangements, understandings, or agreements regarding the
|
| 354 |
+
Licensed Material not stated herein are separate from and
|
| 355 |
+
independent of the terms and conditions of this Public License.
|
| 356 |
+
|
| 357 |
+
Section 8 -- Interpretation.
|
| 358 |
+
|
| 359 |
+
a. For the avoidance of doubt, this Public License does not, and
|
| 360 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
| 361 |
+
conditions on any use of the Licensed Material that could lawfully
|
| 362 |
+
be made without permission under this Public License.
|
| 363 |
+
|
| 364 |
+
b. To the extent possible, if any provision of this Public License is
|
| 365 |
+
deemed unenforceable, it shall be automatically reformed to the
|
| 366 |
+
minimum extent necessary to make it enforceable. If the provision
|
| 367 |
+
cannot be reformed, it shall be severed from this Public License
|
| 368 |
+
without affecting the enforceability of the remaining terms and
|
| 369 |
+
conditions.
|
| 370 |
+
|
| 371 |
+
c. No term or condition of this Public License will be waived and no
|
| 372 |
+
failure to comply consented to unless expressly agreed to by the
|
| 373 |
+
Licensor.
|
| 374 |
+
|
| 375 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
| 376 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
| 377 |
+
that apply to the Licensor or You, including from the legal
|
| 378 |
+
processes of any jurisdiction or authority.
|
| 379 |
+
|
| 380 |
+
=======================================================================
|
| 381 |
+
|
| 382 |
+
Creative Commons is not a party to its public
|
| 383 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
| 384 |
+
its public licenses to material it publishes and in those instances
|
| 385 |
+
will be considered the “Licensor.” The text of the Creative Commons
|
| 386 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
| 387 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
| 388 |
+
material is shared under a Creative Commons public license or as
|
| 389 |
+
otherwise permitted by the Creative Commons policies published at
|
| 390 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
| 391 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
| 392 |
+
of Creative Commons without its prior written consent including,
|
| 393 |
+
without limitation, in connection with any unauthorized modifications
|
| 394 |
+
to any of its public licenses or any other arrangements,
|
| 395 |
+
understandings, or agreements concerning use of licensed material. For
|
| 396 |
+
the avoidance of doubt, this paragraph does not form part of the
|
| 397 |
+
public licenses.
|
| 398 |
+
|
| 399 |
+
Creative Commons may be contacted at creativecommons.org.
|
MuCodec/__pycache__/generate.cpython-310.pyc
ADDED
|
Binary file (8.18 kB). View file
|
|
|
MuCodec/__pycache__/generate.cpython-312.pyc
ADDED
|
Binary file (18 kB). View file
|
|
|
MuCodec/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (12.2 kB). View file
|
|
|
MuCodec/__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (21.9 kB). View file
|
|
|
MuCodec/configs/models/transformer2D.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "Transformer2DModel",
|
| 3 |
+
"activation_fn": "gelu-approximate",
|
| 4 |
+
"attention_bias": true,
|
| 5 |
+
"attention_head_dim": 72,
|
| 6 |
+
"attention_type": "default",
|
| 7 |
+
"cross_attention_dim": null,
|
| 8 |
+
"double_self_attention": false,
|
| 9 |
+
"dropout": 0.0,
|
| 10 |
+
"in_channels": 96,
|
| 11 |
+
"norm_elementwise_affine": false,
|
| 12 |
+
"norm_eps": 1e-06,
|
| 13 |
+
"norm_num_groups": 32,
|
| 14 |
+
"norm_type": "ada_norm_single",
|
| 15 |
+
"num_attention_heads": 22,
|
| 16 |
+
"num_embeds_ada_norm": 1000,
|
| 17 |
+
"num_layers": 24,
|
| 18 |
+
"num_vector_embeds": null,
|
| 19 |
+
"only_cross_attention": false,
|
| 20 |
+
"out_channels": 32,
|
| 21 |
+
"patch_size": 2,
|
| 22 |
+
"sample_size": 384,
|
| 23 |
+
"upcast_attention": false,
|
| 24 |
+
"use_linear_projection": false
|
| 25 |
+
}
|
MuCodec/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "DDIMScheduler",
|
| 3 |
+
"_diffusers_version": "0.8.0",
|
| 4 |
+
"beta_end": 0.02,
|
| 5 |
+
"beta_schedule": "scaled_linear",
|
| 6 |
+
"beta_start": 0.0015,
|
| 7 |
+
"clip_sample": false,
|
| 8 |
+
"num_train_timesteps": 1000,
|
| 9 |
+
"prediction_type": "sample",
|
| 10 |
+
"set_alpha_to_one": false,
|
| 11 |
+
"skip_prk_steps": true,
|
| 12 |
+
"steps_offset": 1,
|
| 13 |
+
"trained_betas": null
|
| 14 |
+
}
|
MuCodec/generate.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import torch
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import sys
|
| 5 |
+
from model import PromptCondAudioDiffusion
|
| 6 |
+
from diffusers import DDIMScheduler, DDPMScheduler
|
| 7 |
+
import torchaudio
|
| 8 |
+
import librosa
|
| 9 |
+
import os
|
| 10 |
+
import math
|
| 11 |
+
import numpy as np
|
| 12 |
+
from tools.get_melvaehifigan48k import build_pretrained_models
|
| 13 |
+
import tools.torch_tools as torch_tools
|
| 14 |
+
from safetensors.torch import load_file
|
| 15 |
+
|
| 16 |
+
class MuCodec:
    """High-level wrapper around the MuCodec music codec.

    Encodes 48 kHz stereo audio into discrete codes (``sound2code``) and
    decodes codes back to a waveform via a latent-diffusion model plus a
    mel VAE / vocoder (``code2sound``).
    """

    def __init__(self, model_path, layer_num, load_main_model=True, device="cuda:0"):
        """Load codec weights and supporting models.

        Args:
            model_path: checkpoint path (``.safetensors`` or a torch checkpoint).
            layer_num: 1-based feature-layer index; stored 0-based in ``self.layer_num``.
            load_main_model: when True, also build the mel VAE / STFT front end and
                use the transformer2D model config; otherwise load the diffusion
                model alone with no UNet config.
            device: torch device string used for every module.
        """
        self.layer_num = layer_num - 1  # convert 1-based argument to 0-based index
        self.sample_rate = 48000
        self.device = device

        self.MAX_DURATION = 360  # NOTE(review): defined but never read in this class
        if load_main_model:
            audio_ldm_path = os.path.dirname(os.path.abspath(__file__)) + "/tools/audioldm_48k.pth"
            self.vae, self.stft = build_pretrained_models(audio_ldm_path)
            self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
            main_config = {
                "num_channels":32,
                "unet_model_name":None,
                "unet_model_config_path":os.path.dirname(os.path.abspath(__file__)) + "/configs/models/transformer2D.json",
                "snr_gamma":None,
            }
            self.model = PromptCondAudioDiffusion(**main_config)
            if model_path.endswith('.safetensors'):
                main_weights = load_file(model_path)
            else:
                main_weights = torch.load(model_path, map_location='cpu')
            # strict=False: checkpoint may omit or carry extra keys
            self.model.load_state_dict(main_weights, strict=False)
            self.model = self.model.to(device)
            print ("Successfully loaded checkpoint from:", model_path)
        else:
            main_config = {
                "num_channels":32,
                "unet_model_name":None,
                "unet_model_config_path":None,
                "snr_gamma":None,
            }
            self.model = PromptCondAudioDiffusion(**main_config).to(device)
            main_weights = torch.load(model_path, map_location='cpu')
            self.model.load_state_dict(main_weights, strict=False)
            self.model = self.model.to(device)
            print ("Successfully loaded checkpoint from:", model_path)

        self.model.eval()
        self.model.init_device_dtype(torch.device(device), torch.float32)
        print("scaling factor: ", self.model.normfeat.std)

    def file2code(self, fname):
        """Load an audio file, resample to 48 kHz, force stereo, and encode to codes."""
        orig_samples, fs = torchaudio.load(fname)
        if(fs!=self.sample_rate):
            orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
            fs = self.sample_rate
        if orig_samples.shape[0] == 1:
            # duplicate mono channel into fake stereo
            orig_samples = torch.cat([orig_samples, orig_samples], 0)
        return self.sound2code(orig_samples)

    @torch.no_grad()
    @torch.autocast(device_type="cuda", dtype=torch.float32)
    def sound2code(self, orig_samples, batch_size=3):
        """Encode a stereo waveform into discrete codes.

        Args:
            orig_samples: (2, T) or (B, 2, T) waveform at 48 kHz
                — assumes stereo channel layout; TODO confirm with callers.
            batch_size: number of ~40.96 s segments encoded per model call.

        Returns:
            Code tensor trimmed to 25 frames per second of input audio.
        """
        if(orig_samples.ndim == 2):
            audios = orig_samples.unsqueeze(0).to(self.device)
        elif(orig_samples.ndim == 3):
            audios = orig_samples.to(self.device)
        else:
            assert orig_samples.ndim in (2,3), orig_samples.shape
        audios = self.preprocess_audio(audios)
        audios = audios.squeeze(0)
        orig_length = audios.shape[-1]
        min_samples = int(40.96 * self.sample_rate)  # segment length: 40.96 s
        # 25 code frames per second of audio
        output_len = int(orig_length / float(self.sample_rate) * 25) + 1
        print("output_len: ", output_len)

        # tile the audio until it covers one full segment (+480-sample pad)
        while(audios.shape[-1] < min_samples + 480):
            audios = torch.cat([audios, audios], -1)
        int_max_len=audios.shape[-1]//min_samples+1
        # print("int_max_len: ", int_max_len)
        audios = torch.cat([audios, audios], -1)
        # print("audios:",audios.shape)
        audios=audios[:,:int(int_max_len*(min_samples+480))]
        codes_list=[]

        # split into non-overlapping (2, min_samples+480) stereo segments
        audio_input = audios.reshape(2, -1, min_samples+480).permute(1, 0, 2).reshape(-1, 2, min_samples+480)

        for audio_inx in range(0, audio_input.shape[0], batch_size):
            codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num)
            codes_list.append(torch.cat(codes, 1))
            # print("codes_list",codes_list[0].shape)

        codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(1, -1)[None] # B 3 T -> 3 B T
        codes=codes[:,:,:output_len]  # drop frames produced by the tiling pad

        return codes

    @torch.no_grad()
    def code2sound(self, codes, prompt=None, duration=40.96, guidance_scale=1.5, num_steps=20, disable_progress=False):
        """Decode codes to a waveform with windowed diffusion sampling.

        Args:
            codes: code tensor as produced by ``sound2code``.
            prompt: optional waveform used as acoustic in-context prompt.
            duration: waveform window duration (s) used for the cross-fade stage.
            guidance_scale: NOTE(review) — this parameter is NOT forwarded;
                ``inference_codes`` below is called with a literal 1.5.
            num_steps: diffusion sampling steps per window.
            disable_progress: passed through to the sampler's progress bar.

        Returns:
            Stereo waveform tensor trimmed to the length implied by the codes.
        """
        codes = codes.to(self.device)
        # start from pure noise in the (B, 32, 512, 32) latent space
        first_latent = torch.randn(codes.shape[0], 32, 512, 32).to(self.device)
        first_latent_length = 0
        first_latent_codes_length = 0
        if(isinstance(prompt, torch.Tensor)):
            # normalize prompt to (2, T) stereo
            prompt = prompt.to(self.device)
            if(prompt.ndim == 3):
                assert prompt.shape[0] == 1, prompt.shape
                prompt = prompt[0]
            elif(prompt.ndim == 1):
                prompt = prompt.unsqueeze(0).repeat(2,1)
            elif(prompt.ndim == 2):
                if(prompt.shape[0] == 1):
                    prompt = prompt.repeat(2,1)

            if(prompt.shape[-1] < int(30.76 * self.sample_rate)):
                prompt = prompt[:,:int(10.24*self.sample_rate)] # limit max length to 10.24
            else:
                prompt = prompt[:,int(20.48*self.sample_rate):int(30.72*self.sample_rate)] # limit max length to 10.24

            true_mel , _, _ = torch_tools.wav_to_fbank2(prompt, -1, fn_STFT=self.stft) # maximum 10.24s
            true_mel = true_mel.unsqueeze(1).to(self.device)
            # encode channels one at a time, then fold the stereo pair back together
            true_latent = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(true_mel[[m]])) for m in range(true_mel.shape[0])],0)
            true_latent = true_latent.reshape(true_latent.shape[0]//2, -1, true_latent.shape[2], true_latent.shape[3]).detach()

            # seed the head of the noise latent with the encoded prompt
            first_latent[:,:,0:true_latent.shape[2],:] = true_latent
            first_latent_length = true_latent.shape[2]
            first_latent_codes = self.sound2code(prompt)[:,:,0:first_latent_length*2] # B 4 T
            first_latent_codes_length = first_latent_codes.shape[-1]
            codes = torch.cat([first_latent_codes, codes], -1)

        # windowing over the code axis: 1024-frame windows with a 75% hop
        min_samples = 1024
        hop_samples = min_samples // 4 * 3
        ovlp_samples = min_samples - hop_samples
        hop_frames = hop_samples // 2
        ovlp_frames = ovlp_samples // 2

        codes_len= codes.shape[-1]
        target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)

        # tile the codes so every window is full
        if(codes_len < min_samples):
            while(codes.shape[-1] < min_samples):
                codes = torch.cat([codes, codes], -1)
            codes = codes[:,:,0:min_samples]
        codes_len = codes.shape[-1]
        # NOTE(review): this mixes ovlp_frames with hop_samples; presumably
        # (codes_len - ovlp_samples) % hop_samples was intended — confirm
        if((codes_len - ovlp_frames) % hop_samples > 0):
            len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
            while(codes.shape[-1] < len_codes):
                codes = torch.cat([codes, codes], -1)
            codes = codes[:,:,0:len_codes]
        latent_length = 512
        latent_list = []
        spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
                codes_input=[]
                codes_input.append(codes[:,:,sinx:sinx+min_samples])
                if(sinx == 0):
                    # first window: condition on the prompt latent (if any)
                    incontext_length = first_latent_length
                    latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
                    latent_list.append(latents)
                else:
                    # later windows: condition on the overlap tail of the previous latent
                    true_latent = latent_list[-1][:,:,-ovlp_frames:,:]
                    len_add_to_512 = 512 - true_latent.shape[-2]
                    incontext_length = true_latent.shape[-2]
                    true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], true_latent.shape[1], len_add_to_512, true_latent.shape[-1]).to(self.device)], -2)
                    latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
                    latent_list.append(latents)

        latent_list = [l.float() for l in latent_list]
        latent_list[0] = latent_list[0][:,:,first_latent_length:,:]  # drop the prompt part
        # switch to waveform-domain window sizes for the cross-fade below
        min_samples = int(duration * self.sample_rate)
        hop_samples = min_samples // 4 * 3
        ovlp_samples = min_samples - hop_samples
        with torch.no_grad():
            output = None
            for i in range(len(latent_list)):
                latent = latent_list[i]
                bsz , ch, t, f = latent.shape
                # unfold stereo: (B, ch, t, f) -> (B*2, ch//2, t, f)
                latent = latent.reshape(bsz*2, ch//2, t, f)
                mel = self.vae.decode_first_stage(latent)
                cur_output = self.vae.decode_to_waveform(mel)
                cur_output = torch.from_numpy(cur_output)[:, 0:min_samples]

                if output is None:
                    output = cur_output
                else:
                    # linear cross-fade over the overlap region
                    ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
                    ov_win = torch.cat([ov_win, 1 - ov_win], -1)
                    output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
                    output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
            output = output[:, 0:target_len]
        return output

    @torch.no_grad()
    def preprocess_audio(self, input_audios, threshold=0.8):
        """Scale down any batch item whose absolute peak exceeds ``threshold``.

        Items below the threshold are returned unchanged (norm factor 1).
        Input and output are both (B, C, T).
        """
        assert len(input_audios.shape) == 3, input_audios.shape
        nchan = input_audios.shape[1]
        input_audios = input_audios.reshape(input_audios.shape[0], -1)
        norm_value = torch.ones_like(input_audios[:,0])
        max_volume = input_audios.abs().max(dim=-1)[0]
        norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
        return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)

    @torch.no_grad()
    def sound2sound(self, sound, prompt=None, min_duration=40.96, steps=50, disable_progress=False):
        """Round-trip: encode ``sound`` to codes and decode straight back to audio."""
        codes = self.sound2code(sound)
        wave = self.code2sound(codes, prompt, duration=min_duration, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
        return wave
|
| 221 |
+
|
| 222 |
+
if __name__=="__main__":
    # Demo driver: round-trip every audio file found in ./test_wav through the
    # codec and write the reconstructions into ./reconstructed.
    ckpt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ckpt/mucodec.pt")
    mucodec = MuCodec(model_path=ckpt_path,layer_num=7,load_main_model=True)

    filelist = []

    root_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_wav")
    # collect .flac/.wav/.mp3 files, resampled to 48 kHz stereo
    for f in [os.path.join(root_dir, f) for f in os.listdir(root_dir) if '.flac' in f or '.wav' in f or '.mp3' in f]:
        a, fs = torchaudio.load(f)
        if(fs!=48000):
            a = torchaudio.functional.resample(a, fs, 48000)
        if(a.shape[0]==1):
            # mono -> fake stereo
            a = torch.cat([a,a],0)
        ori_len = a.shape[-1]
        # entry layout: (waveform, lyric placeholder, [start, end] seconds, path, length)
        filelist.append([a, '', [0, a.shape[-1]/48000.], f,ori_len])

    reconstructed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "reconstructed")

    os.makedirs(reconstructed_dir, exist_ok=True)

    for sample_idx, (orig_samples, lyric, st_et, fname,ori_len) in enumerate(filelist):
        print(fname, lyric)
        wave = mucodec.sound2sound(orig_samples,None)
        wave = wave[:,0:ori_len]  # trim codec padding back to the original length
        torchaudio.save(os.path.join(reconstructed_dir, os.path.basename(fname)),wave.detach().cpu(), 48000)
|
| 247 |
+
|
MuCodec/libs/rvq/__pycache__/descript_quantize3.cpython-310.pyc
ADDED
|
Binary file (10 kB). View file
|
|
|
MuCodec/libs/rvq/descript_quantize3.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from torch.nn.utils import weight_norm
|
| 9 |
+
|
| 10 |
+
def WNConv1d(*args, **kwargs):
    """Build a 1-D convolution wrapped in weight normalization.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv1d``.
    """
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
|
| 12 |
+
|
| 13 |
+
class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # project into the low-dimensional code space and back (factorized codes)
        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)
        # per-code counter of consecutive training batches with zero assignments
        self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
        self.stale_tolerance = stale_tolerance

    def forward(self, z):
        """Quantized the input tensor using a fixed codebook and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        Tensor[1]
            Codebook loss to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x D x T)
        z_q, indices = self.decode_latents(z_e)

        # per-item losses (reduced over channel and time, kept per batch element)
        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        z_q = (
            z_e + (z_q - z_e).detach()
        )  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        # look up code vectors for integer ids -> (..., D)
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        # (B, T) ids -> (B, D, T) code vectors
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        """Nearest-neighbour lookup of `latents` in the (L2-normalized) codebook.

        Returns the quantized vectors (B x D x T) and the chosen indices (B x T).
        During training, additionally re-seeds codebook entries that stayed
        unused for `stale_tolerance` consecutive batches.
        """
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)

        if(self.training):
            # count how many batches each code has gone without a single hit
            onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float()  # B, T, codebook_size
            stale_codes = (onehots.sum(0).sum(0) == 0).float()
            self.stale_counter = self.stale_counter * stale_codes + stale_codes

            # random replace codes that haven't been used for a while
            replace_code = (self.stale_counter == self.stale_tolerance).float()  # codebook_size
            if replace_code.sum(-1) > 0:
                print("Replace {} codes".format(replace_code.sum(-1)))
                # draw replacement vectors from the current batch's encodings
                random_input_idx = torch.randperm(encodings.shape[0])
                random_input = encodings[random_input_idx].view(encodings.shape)
                if random_input.shape[0] < self.codebook_size:
                    random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
                random_input = random_input[:self.codebook_size,:].contiguous()  # codebook_size, dim

                # overwrite only the stale rows; reset their counters
                self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
                self.stale_counter = self.stale_counter * (1 - replace_code)

        return z_q, indices
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312

    A cascade of ``VectorQuantize`` layers where each stage quantizes the
    residual left over by the previous stage.
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
        stale_tolerance: int = 100,
    ):
        super().__init__()
        # allow a single int codebook_dim to apply to every stage
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantized the input tensor using a fixed set of `n` codebooks and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if `self.quantizer_dropout` is True, this argument is ignored
            when in training mode, and a random number of quantizers is used.

        Returns
        -------
        tuple
            z_q : Tensor[B x D x T]
                Quantized continuous representation of input
            codes : Tensor[B x N x T]
                Codebook indices for each codebook
            latents : Tensor[B x N*D x T]
                Projected latents (continuous representation before quantization)
            commitment_loss : Tensor[1]
                Commitment loss to train the encoder
            codebook_loss : Tensor[1]
                Codebook loss to update the codebooks
            Tensor[B]
                0-based number of quantizers applied per batch item
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            # quantizer dropout: a random prefix length for the first n_dropout items
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)
        else:
            n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            # Sum losses
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        # Diagnostic: per-layer fraction of codebook entries unused in this batch.
        # Fixed the "Lyaer" typo in the message; print itself kept to preserve
        # the existing (verbose) behavior.
        encodings = F.one_hot(codes, self.codebook_size).float()  # B N T 1024
        for n in range(encodings.shape[1]):
            print("Layer {}, Ratio of unused vector : {:.1f}".format(n,
                (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
            ))

        return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation
        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input
        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            # decode each stage's ids and sum the projected contributions
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N x T]
            Continuous representation of input after projection

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x D x T]
            Quantized representation of latent space
        """
        z_q = 0
        z_p = []
        codes = []
        # cumulative channel offsets of each stage's codebook_dim slice
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        # use as many codebooks as fully fit in the provided latent channels
        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
            0
        ]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
if __name__ == "__main__":
    # Smoke test: quantize a random batch and show both loss formulations.
    rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0)
    x = torch.randn(16, 1024, 80)
    quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x)
    print(quantized_prompt_embeds.shape)
    print(codes.shape)
    # w/o reconstruction
    loss = commitment_loss * 0.25 + codebook_loss * 1.0
    # w/ reconstruction
    loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
|
MuCodec/model.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
import random
|
| 3 |
+
import inspect
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import typing as tp
|
| 7 |
+
from abc import ABC
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import torchaudio
|
| 13 |
+
|
| 14 |
+
from einops import repeat
|
| 15 |
+
from tools.torch_tools import wav_to_fbank
|
| 16 |
+
import os
|
| 17 |
+
import diffusers
|
| 18 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 19 |
+
from diffusers import DDPMScheduler
|
| 20 |
+
from models.transformer_2d_flow import Transformer2DModel
|
| 21 |
+
from libs.rvq.descript_quantize3 import ResidualVectorQuantize
|
| 22 |
+
from torch.cuda.amp import autocast
|
| 23 |
+
from muq_dev.test import load_model
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SampleProcessor(torch.nn.Module):
    """Base class mapping samples into and out of the diffusion space.

    The default implementation is the identity in both directions; subclasses
    such as normalizers override both methods.
    """

    def project_sample(self, x: torch.Tensor):
        """Map a raw sample into the space where diffusion operates."""
        return x

    def return_sample(self, z: torch.Tensor):
        """Map a diffusion-space tensor back to the raw sample space."""
        return z
|
| 36 |
+
|
| 37 |
+
class Feature2DProcessor(SampleProcessor):
    """Running mean/std normalizer for 4-D feature tensors.

    Statistics are accumulated online (per channel and per last-dim bin of
    size 32) from the first ``num_samples`` items seen by ``project_sample``,
    then effectively frozen.
    """

    def __init__(self, dim: int = 8, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1., \
            num_samples: int = 100_000):
        super().__init__()
        self.num_samples = num_samples
        self.dim = dim
        self.power_std = power_std  # exponent applied to the rescale factor
        # running sums kept as buffers so they travel with the state_dict
        self.register_buffer('counts', torch.zeros(1))
        self.register_buffer('sum_x', torch.zeros(dim, 32))
        self.register_buffer('sum_x2', torch.zeros(dim, 32))
        self.register_buffer('sum_target_x2', torch.zeros(dim, 32))
        self.counts: torch.Tensor
        self.sum_x: torch.Tensor
        self.sum_x2: torch.Tensor

    @property
    def mean(self):
        # running mean over everything accumulated so far
        mean = self.sum_x / self.counts
        return mean

    @property
    def std(self):
        # E[x^2] - E[x]^2, clamped against negative rounding error
        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
        return std

    @property
    def target_std(self):
        # normalize toward unit standard deviation
        return 1

    def project_sample(self, x: torch.Tensor):
        assert x.dim() == 4
        # keep accumulating statistics until num_samples items have been seen;
        # the update happens BEFORE normalization, so counts is never zero here
        if self.counts.item() < self.num_samples:
            self.counts += len(x)
            self.sum_x += x.mean(dim=(2,)).sum(dim=0)
            self.sum_x2 += x.pow(2).mean(dim=(2,)).sum(dim=0)
        rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
        x = (x - self.mean.view(1, -1, 1, 32).contiguous()) * rescale.view(1, -1, 1, 32).contiguous()
        return x

    def return_sample(self, x: torch.Tensor):
        assert x.dim() == 4
        # inverse of project_sample using the current running statistics
        rescale = (self.std / self.target_std) ** self.power_std
        x = x * rescale.view(1, -1, 1, 32).contiguous() + self.mean.view(1, -1, 1, 32).contiguous()
        return x
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class BASECFM(torch.nn.Module, ABC):
    """Euler ODE sampler for conditional flow matching (CFM).

    Wraps an `estimator` network called diffusers-style: it receives a
    channel-concatenated tensor plus `timestep` and `added_cond_kwargs`, and
    returns an object exposing the prediction as `.sample`.
    """

    def __init__(
        self,
        estimator,
    ):
        super().__init__()
        # Minimum noise scale of the conditional probability path.
        self.sigma_min = 1e-4

        self.estimator = estimator

    @torch.inference_mode()
    def forward(self, mu, n_timesteps, temperature=1.0):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_channels, mel_timesteps, n_feats)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_channels, mel_timesteps, n_feats)

        NOTE(review): this call to `solve_euler` omits its required
        `incontext_x`, `incontext_length`, `mu`, `added_cond_kwargs` and
        `guidance_scale` arguments and would raise a TypeError if invoked —
        confirm whether `forward` is dead code or needs updating.
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span)

    def solve_euler(self, x, incontext_x, incontext_length, t_span, mu, added_cond_kwargs, guidance_scale):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            incontext_x (torch.Tensor): clean latents whose leading frames act
                as an in-context (continuation) prefix
            incontext_length (int or scalar torch.Tensor): number of leading
                time frames taken from `incontext_x`
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_channels, mel_timesteps, n_feats)
            added_cond_kwargs (dict): extra conditioning tensors forwarded to
                the estimator (duplicated along batch when CFG is active)
            guidance_scale (float): classifier-free-guidance weight; values
                > 1.0 trigger an extra unconditional estimator pass
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        # Keep the initial noise so the in-context prefix can be re-noised to
        # the current time level at every step.
        noise = x.clone()

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        for step in tqdm(range(1, len(t_span))):
            # Overwrite the prefix with its value on the conditional path
            # interpolating between the stored noise and the clean latents.
            x[:,:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,:,0:incontext_length,:] + t * incontext_x[:,:,0:incontext_length,:]
            if(guidance_scale > 1.0):
                # Classifier-free guidance: run unconditional (mu zeroed) and
                # conditional passes in a single batched call, then blend.
                dphi_dt = self.estimator( \
                    torch.cat([ \
                        torch.cat([x, x], 0), \
                        torch.cat([incontext_x, incontext_x], 0), \
                        torch.cat([torch.zeros_like(mu), mu], 0), \
                    ], 1), \
                    timestep = t.unsqueeze(-1).repeat(2), \
                    added_cond_kwargs={k:torch.cat([v,v],0) for k,v in added_cond_kwargs.items()}).sample
                dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
                dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
            else:
                dphi_dt = self.estimator(torch.cat([x, incontext_x, mu], 1), \
                    timestep = t.unsqueeze(-1),
                    added_cond_kwargs=added_cond_kwargs).sample

            # Explicit Euler update.
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class PromptCondAudioDiffusion(nn.Module):
    """Latent flow-matching generator conditioned on RVQ-quantized MuEncoder embeddings.

    Pipeline as implemented below: two audio rows at 48 kHz are averaged,
    resampled to 24 kHz, embedded by a frozen MuEncoder, quantized by a
    single-codebook residual VQ, projected to a 2D conditioning map, and fed
    to a Transformer2D estimator driven by the `BASECFM` Euler solver.
    Latents are normalized/denormalized through `Feature2DProcessor`.
    """

    def __init__(
        self,
        num_channels,
        unet_model_name=None,
        unet_model_config_path=None,
        snr_gamma=None,
        uncondition=True,
        out_paint=False,
    ):
        """
        Args:
            num_channels: number of latent channels (also the `Feature2DProcessor` dim).
            unet_model_name: optional pretrained-model name; only checked in the assert.
            unet_model_config_path: config path for `Transformer2DModel.from_config`.
            snr_gamma: Min-SNR weighting gamma (stored only; not read in this class).
            uncondition: stored flag for unconditional dropout (not read in this class).
            out_paint: unused flag kept for interface compatibility.
        """
        super().__init__()

        assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"

        self.unet_model_name = unet_model_name
        self.unet_model_config_path = unet_model_config_path
        self.snr_gamma = snr_gamma
        self.uncondition = uncondition
        self.num_channels = num_channels

        # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
        self.normfeat = Feature2DProcessor(dim=num_channels)

        self.sample_rate = 48000
        self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
        self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
        muencoder_dir = "muq_dev/muq_fairseq"
        muencoder_ckpt = "muq_dev/muq.pt"

        self.muencoder = load_model(
            model_dir=os.path.abspath(muencoder_dir),
            checkpoint_dir=os.path.abspath(muencoder_ckpt),
        )
        self.rsq48tomuencoder = torchaudio.transforms.Resample(48000, 24000)
        # The MuEncoder is a frozen feature extractor.
        for v in self.muencoder.parameters():
            v.requires_grad = False
        self.rvq_muencoder_emb = ResidualVectorQuantize(input_dim=1024, n_codebooks=1, codebook_size=16_384, codebook_dim=32, quantizer_dropout=0.0, stale_tolerance=200)
        # Projects each 1024-dim quantized embedding to a 16x32 conditioning patch.
        self.cond_muencoder_emb = nn.Linear(1024, 16 * 32)
        # Learned embedding substituted wherever the conditioning is masked out.
        self.zero_cond_embedding1 = nn.Parameter(torch.randn(32 * 32,))

        unet = Transformer2DModel.from_config(
            unet_model_config_path,
        )
        self.set_from = "random"
        self.cfm_wrapper = BASECFM(unet)
        print("Transformer initialized from pretrain.")

    def compute_snr(self, timesteps):
        """
        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

        NOTE(review): relies on `self.noise_scheduler`, which is never assigned
        in this class — presumably set externally before calling; verify.
        """
        alphas_cumprod = self.noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = alphas_cumprod**0.5
        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

        # Expand the tensors.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

        # Compute SNR.
        snr = (alpha / sigma) ** 2
        return snr

    def preprocess_audio(self, input_audios, threshold=0.9):
        """Rescale any (batch, samples) row whose peak exceeds `threshold` so its peak equals `threshold`."""
        assert len(input_audios.shape) == 2, input_audios.shape
        norm_value = torch.ones_like(input_audios[:, 0])
        max_volume = input_audios.abs().max(dim=-1)[0]
        norm_value[max_volume > threshold] = max_volume[max_volume > threshold] / threshold
        return input_audios / norm_value.unsqueeze(-1)

    def extract_muencoder_embeds(self, input_audio_0, input_audio_1, layer):
        """Average the two audio rows, resample to 24 kHz, and return the MuEncoder
        features of `layer` transposed to (batch, dim, time)."""
        input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
        input_wav_mean = self.muencoder(self.rsq48tomuencoder(input_wav_mean), features_only=True)
        layer_results = input_wav_mean['layer_results']
        muencoder_emb = layer_results[layer]
        muencoder_emb = muencoder_emb.permute(0, 2, 1).contiguous()
        return muencoder_emb

    def init_device_dtype(self, device, dtype):
        """Record the device/dtype used by the inference helpers."""
        self.device = device
        self.dtype = dtype

    @torch.no_grad()
    def fetch_codes(self, input_audios, additional_feats, layer):
        """Encode a single stereo pair (rows 0 and 1 of `input_audios`) into RVQ codes.

        Returns ([codes], [embeddings], spk_embeds=None).
        """
        input_audio_0 = input_audios[[0], :]
        input_audio_1 = input_audios[[1], :]
        input_audio_0 = self.preprocess_audio(input_audio_0)
        input_audio_1 = self.preprocess_audio(input_audio_1)

        self.muencoder.eval()

        muencoder_emb = self.extract_muencoder_embeds(input_audio_0, input_audio_1, layer)
        muencoder_emb = muencoder_emb.detach()

        self.rvq_muencoder_emb.eval()
        quantized_muencoder_emb, codes_muencoder_emb, *_ = self.rvq_muencoder_emb(muencoder_emb)

        spk_embeds = None

        return [codes_muencoder_emb], [muencoder_emb], spk_embeds

    @torch.no_grad()
    def fetch_codes_batch(self, input_audios, additional_feats, layer):
        """Batched variant of `fetch_codes`: expects (batch, 2, samples) audio."""
        input_audio_0 = input_audios[:, 0, :]
        input_audio_1 = input_audios[:, 1, :]
        input_audio_0 = self.preprocess_audio(input_audio_0)
        input_audio_1 = self.preprocess_audio(input_audio_1)

        self.muencoder.eval()

        muencoder_emb = self.extract_muencoder_embeds(input_audio_0, input_audio_1, layer)
        muencoder_emb = muencoder_emb.detach()

        self.rvq_muencoder_emb.eval()
        quantized_muencoder_emb, codes_muencoder_emb, *_ = self.rvq_muencoder_emb(muencoder_emb)  # b,d,t

        spk_embeds = None

        return [codes_muencoder_emb], [muencoder_emb], spk_embeds

    @torch.no_grad()
    def inference_codes(self, codes, spk_embeds, true_latents, latent_length, incontext_length, additional_feats,
                        guidance_scale=2, num_steps=20,
                        disable_progress=True, scenario='start_seg'):
        """Decode RVQ codes into latents with the CFM Euler solver.

        Args:
            codes: list whose first entry is the (b, n_q, t) MuEncoder code tensor.
            spk_embeds: unused here (kept for interface compatibility).
            true_latents: ground-truth latents used as in-context prefix.
            latent_length: number of valid latent frames.
            incontext_length: prefix length marked in-context when scenario=='other_seg'.
            additional_feats: unused here (kept for interface compatibility).
            guidance_scale: CFG weight; > 1.0 enables classifier-free guidance.
            num_steps: number of Euler steps.
            disable_progress: unused here (kept for interface compatibility).
            scenario: 'start_seg' or 'other_seg' (continuation with prefix).
        """
        classifier_free_guidance = guidance_scale > 1.0
        device = self.device
        dtype = self.dtype
        codes_muencoder_emb = codes[0]

        batch_size = codes_muencoder_emb.shape[0]

        quantized_muencoder_emb, _, _ = self.rvq_muencoder_emb.from_codes(codes_muencoder_emb)

        quantized_muencoder_emb = self.cond_muencoder_emb(quantized_muencoder_emb.permute(0, 2, 1))  # b t 16*32
        # Fold pairs of time steps into the channel axis: (b, t, 512) -> (b, 32, t//2, 32).
        quantized_muencoder_emb = quantized_muencoder_emb.reshape(quantized_muencoder_emb.shape[0], quantized_muencoder_emb.shape[1]//2, 2, 16, 32).reshape(quantized_muencoder_emb.shape[0], quantized_muencoder_emb.shape[1]//2, 2*16, 32).permute(0, 2, 1, 3).contiguous()  # b 32 t f

        num_frames = quantized_muencoder_emb.shape[-2]

        num_channels_latents = self.num_channels
        latents = self.prepare_latents(batch_size, num_frames, num_channels_latents, dtype, device)

        # PixArt-style micro-conditioning on resolution/aspect ratio.
        bsz, _, height, width = latents.shape
        resolution = torch.tensor([height, width]).repeat(bsz, 1)
        aspect_ratio = torch.tensor([float(height / width)]).repeat(bsz, 1)
        resolution = resolution.to(dtype=quantized_muencoder_emb.dtype, device=device)
        aspect_ratio = aspect_ratio.to(dtype=quantized_muencoder_emb.dtype, device=device)
        if classifier_free_guidance:
            resolution = torch.cat([resolution, resolution], 0)
            aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], 0)
        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

        # Frame mask: 0 = padding, 1 = in-context prefix, 2 = frames to generate.
        latent_masks = torch.zeros(latents.shape[0], latents.shape[2], dtype=torch.int64, device=latents.device)
        latent_masks[:, 0:latent_length] = 2
        if(scenario == 'other_seg'):
            latent_masks[:, 0:incontext_length] = 1

        # Replace conditioning on padding frames with the learned "zero" embedding.
        quantized_muencoder_emb = (latent_masks > 0.5).unsqueeze(1).unsqueeze(-1) * quantized_muencoder_emb \
            + (latent_masks < 0.5).unsqueeze(1).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1, 32, 1, 32)
        true_latents = self.normfeat.project_sample(true_latents)
        incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(1).unsqueeze(-1).float()
        # Actual prefix length derived from the mask (frames labeled 1).
        incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]

        additional_model_input = torch.cat([quantized_muencoder_emb], 1)

        temperature = 1.0
        t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_muencoder_emb.device)
        latents = self.cfm_wrapper.solve_euler(latents * temperature, incontext_latents, incontext_length, t_span, additional_model_input, added_cond_kwargs, guidance_scale)

        # Restore the clean prefix, then denormalize.
        latents[:, :, 0:incontext_length, :] = incontext_latents[:, :, 0:incontext_length, :]
        latents = self.normfeat.return_sample(latents)
        return latents

    @torch.no_grad()
    def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
                  disable_progress=True, layer=5, scenario='start_seg', incontext_length=0):
        """Encode `input_audios` to codes, then decode them back to latents.

        `lyric` is accepted for interface compatibility but unused here.

        Fix: the previous implementation called `inference_codes` without the
        required `incontext_length` positional argument, which misbound
        `additional_feats` and raised a TypeError. `incontext_length` is now an
        explicit keyword (default 0, backward compatible) and is forwarded.
        """
        codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats, layer)

        latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length,
                                       incontext_length, additional_feats,
                                       guidance_scale=guidance_scale, num_steps=num_steps,
                                       disable_progress=disable_progress, scenario=scenario)
        return latents

    def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
        """Draw initial Gaussian latents of shape (b, c, t, 32), with t rounded to a multiple of 4."""
        divisor = 4
        shape = (batch_size, num_channels_latents, num_frames, 32)
        if(num_frames % divisor > 0):
            # Round frame count to the nearest multiple of `divisor`.
            num_frames = round(num_frames / float(divisor)) * divisor
            shape = (batch_size, num_channels_latents, num_frames, 32)
        latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
        return latents
|
| 366 |
+
|
| 367 |
+
|
MuCodec/models/attention.py
ADDED
|
@@ -0,0 +1,682 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Any, Dict, Optional
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from torch import nn
|
| 19 |
+
|
| 20 |
+
from diffusers.utils import USE_PEFT_BACKEND
|
| 21 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 22 |
+
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
|
| 23 |
+
from diffusers.models.attention_processor import Attention
|
| 24 |
+
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
|
| 25 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
| 26 |
+
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _chunked_feed_forward(
|
| 30 |
+
ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
|
| 31 |
+
):
|
| 32 |
+
# "feed_forward_chunk_size" can be used to save memory
|
| 33 |
+
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
| 34 |
+
raise ValueError(
|
| 35 |
+
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
|
| 39 |
+
if lora_scale is None:
|
| 40 |
+
ff_output = torch.cat(
|
| 41 |
+
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
| 42 |
+
dim=chunk_dim,
|
| 43 |
+
)
|
| 44 |
+
else:
|
| 45 |
+
# TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
|
| 46 |
+
ff_output = torch.cat(
|
| 47 |
+
[ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
| 48 |
+
dim=chunk_dim,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
return ff_output
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
    r"""
    A gated self-attention dense layer that fuses visual features with object features.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        context_dim (`int`): The number of channels in the context.
        n_heads (`int`): The number of heads to use for attention.
        d_head (`int`): The number of channels in each head.
    """

    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
        super().__init__()

        # Project object features into the visual feature space so the two can
        # be concatenated along the sequence dimension.
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, activation_fn="geglu")

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        # Learnable gates; tanh(0) == 0, so both branches start disabled.
        self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
        self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))

        self.enabled = True

    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
        if not self.enabled:
            return x

        n_visual = x.shape[1]
        objs = self.linear(objs)

        # Self-attend over [visual; objects], keep only the visual positions.
        attn_input = self.norm1(torch.cat([x, objs], dim=1))
        x = x + self.alpha_attn.tanh() * self.attn(attn_input)[:, :n_visual, :]
        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))

        return x
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@maybe_allow_in_graph
|
| 97 |
+
class BasicTransformerBlock(nn.Module):
|
| 98 |
+
r"""
|
| 99 |
+
A basic Transformer block.
|
| 100 |
+
|
| 101 |
+
Parameters:
|
| 102 |
+
dim (`int`): The number of channels in the input and output.
|
| 103 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
| 104 |
+
attention_head_dim (`int`): The number of channels in each head.
|
| 105 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 106 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
| 107 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
| 108 |
+
num_embeds_ada_norm (:
|
| 109 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
| 110 |
+
attention_bias (:
|
| 111 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
| 112 |
+
only_cross_attention (`bool`, *optional*):
|
| 113 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
| 114 |
+
double_self_attention (`bool`, *optional*):
|
| 115 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
| 116 |
+
upcast_attention (`bool`, *optional*):
|
| 117 |
+
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
| 118 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
| 119 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
| 120 |
+
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
| 121 |
+
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
| 122 |
+
final_dropout (`bool` *optional*, defaults to False):
|
| 123 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
| 124 |
+
attention_type (`str`, *optional*, defaults to `"default"`):
|
| 125 |
+
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
| 126 |
+
positional_embeddings (`str`, *optional*, defaults to `None`):
|
| 127 |
+
The type of positional embeddings to apply to.
|
| 128 |
+
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
| 129 |
+
The maximum number of positional embeddings to apply.
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
def __init__(
|
| 133 |
+
self,
|
| 134 |
+
dim: int,
|
| 135 |
+
num_attention_heads: int,
|
| 136 |
+
attention_head_dim: int,
|
| 137 |
+
dropout=0.0,
|
| 138 |
+
cross_attention_dim: Optional[int] = None,
|
| 139 |
+
activation_fn: str = "geglu",
|
| 140 |
+
num_embeds_ada_norm: Optional[int] = None,
|
| 141 |
+
attention_bias: bool = False,
|
| 142 |
+
only_cross_attention: bool = False,
|
| 143 |
+
double_self_attention: bool = False,
|
| 144 |
+
upcast_attention: bool = False,
|
| 145 |
+
norm_elementwise_affine: bool = True,
|
| 146 |
+
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
|
| 147 |
+
norm_eps: float = 1e-5,
|
| 148 |
+
final_dropout: bool = False,
|
| 149 |
+
attention_type: str = "default",
|
| 150 |
+
positional_embeddings: Optional[str] = None,
|
| 151 |
+
num_positional_embeddings: Optional[int] = None,
|
| 152 |
+
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
|
| 153 |
+
ada_norm_bias: Optional[int] = None,
|
| 154 |
+
ff_inner_dim: Optional[int] = None,
|
| 155 |
+
ff_bias: bool = True,
|
| 156 |
+
attention_out_bias: bool = True,
|
| 157 |
+
):
|
| 158 |
+
super().__init__()
|
| 159 |
+
self.only_cross_attention = only_cross_attention
|
| 160 |
+
|
| 161 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
| 162 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
| 163 |
+
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
| 164 |
+
self.use_layer_norm = norm_type == "layer_norm"
|
| 165 |
+
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
| 166 |
+
|
| 167 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
| 168 |
+
raise ValueError(
|
| 169 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
| 170 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
if positional_embeddings and (num_positional_embeddings is None):
|
| 174 |
+
raise ValueError(
|
| 175 |
+
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
if positional_embeddings == "sinusoidal":
|
| 179 |
+
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
| 180 |
+
else:
|
| 181 |
+
self.pos_embed = None
|
| 182 |
+
|
| 183 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
| 184 |
+
# 1. Self-Attn
|
| 185 |
+
if self.use_ada_layer_norm:
|
| 186 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
| 187 |
+
elif self.use_ada_layer_norm_zero:
|
| 188 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
| 189 |
+
elif self.use_ada_layer_norm_continuous:
|
| 190 |
+
self.norm1 = AdaLayerNormContinuous(
|
| 191 |
+
dim,
|
| 192 |
+
ada_norm_continous_conditioning_embedding_dim,
|
| 193 |
+
norm_elementwise_affine,
|
| 194 |
+
norm_eps,
|
| 195 |
+
ada_norm_bias,
|
| 196 |
+
"rms_norm",
|
| 197 |
+
)
|
| 198 |
+
else:
|
| 199 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
| 200 |
+
|
| 201 |
+
self.attn1 = Attention(
|
| 202 |
+
query_dim=dim,
|
| 203 |
+
heads=num_attention_heads,
|
| 204 |
+
dim_head=attention_head_dim,
|
| 205 |
+
dropout=dropout,
|
| 206 |
+
bias=attention_bias,
|
| 207 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
| 208 |
+
upcast_attention=upcast_attention,
|
| 209 |
+
out_bias=attention_out_bias,
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# 2. Cross-Attn
|
| 213 |
+
if cross_attention_dim is not None or double_self_attention:
|
| 214 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
| 215 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
| 216 |
+
# the second cross attention block.
|
| 217 |
+
if self.use_ada_layer_norm:
|
| 218 |
+
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
| 219 |
+
elif self.use_ada_layer_norm_continuous:
|
| 220 |
+
self.norm2 = AdaLayerNormContinuous(
|
| 221 |
+
dim,
|
| 222 |
+
ada_norm_continous_conditioning_embedding_dim,
|
| 223 |
+
norm_elementwise_affine,
|
| 224 |
+
norm_eps,
|
| 225 |
+
ada_norm_bias,
|
| 226 |
+
"rms_norm",
|
| 227 |
+
)
|
| 228 |
+
else:
|
| 229 |
+
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
| 230 |
+
|
| 231 |
+
self.attn2 = Attention(
|
| 232 |
+
query_dim=dim,
|
| 233 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
| 234 |
+
heads=num_attention_heads,
|
| 235 |
+
dim_head=attention_head_dim,
|
| 236 |
+
dropout=dropout,
|
| 237 |
+
bias=attention_bias,
|
| 238 |
+
upcast_attention=upcast_attention,
|
| 239 |
+
out_bias=attention_out_bias,
|
| 240 |
+
) # is self-attn if encoder_hidden_states is none
|
| 241 |
+
else:
|
| 242 |
+
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
| 243 |
+
self.attn2 = None
|
| 244 |
+
|
| 245 |
+
# 3. Feed-forward
|
| 246 |
+
if self.use_ada_layer_norm_continuous:
|
| 247 |
+
self.norm3 = AdaLayerNormContinuous(
|
| 248 |
+
dim,
|
| 249 |
+
ada_norm_continous_conditioning_embedding_dim,
|
| 250 |
+
norm_elementwise_affine,
|
| 251 |
+
norm_eps,
|
| 252 |
+
ada_norm_bias,
|
| 253 |
+
"layer_norm",
|
| 254 |
+
)
|
| 255 |
+
elif not self.use_ada_layer_norm_single:
|
| 256 |
+
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
| 257 |
+
|
| 258 |
+
self.ff = FeedForward(
|
| 259 |
+
dim,
|
| 260 |
+
dropout=dropout,
|
| 261 |
+
activation_fn=activation_fn,
|
| 262 |
+
final_dropout=final_dropout,
|
| 263 |
+
inner_dim=ff_inner_dim,
|
| 264 |
+
bias=ff_bias,
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# 4. Fuser
|
| 268 |
+
if attention_type == "gated" or attention_type == "gated-text-image":
|
| 269 |
+
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
| 270 |
+
|
| 271 |
+
# 5. Scale-shift for PixArt-Alpha.
|
| 272 |
+
if self.use_ada_layer_norm_single:
|
| 273 |
+
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
| 274 |
+
|
| 275 |
+
# let chunk size default to None
|
| 276 |
+
self._chunk_size = None
|
| 277 |
+
self._chunk_dim = 0
|
| 278 |
+
|
| 279 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
    """Configure chunked execution of the feed-forward sub-block.

    Args:
        chunk_size: Number of rows processed per feed-forward chunk, or
            ``None`` to run the feed-forward in a single pass.
        dim: Tensor dimension along which the hidden states are split.
    """
    self._chunk_dim = dim
    self._chunk_size = chunk_size
| 284 |
+
def forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    class_labels: Optional[torch.LongTensor] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.FloatTensor:
    """Apply self-attention, optional cross-attention, and feed-forward.

    Normalization is always applied before the real computation in each
    sub-block; which norm variant is used is selected by the ``use_*``
    flags configured in ``__init__``.

    Args:
        hidden_states: Input sequence of shape ``(batch, seq_len, dim)``
            — assumed from the ``shape[0]``/modulation usage below; confirm
            against callers.
        attention_mask: Mask passed to the self-attention block.
        encoder_hidden_states: Context for cross-attention (``attn2``), or
            for ``attn1`` when ``only_cross_attention`` is set.
        encoder_attention_mask: Mask for the cross-attention context.
        timestep: Diffusion timestep; for the "single" (PixArt-Alpha) norm
            it instead carries the 6 stacked modulation vectors.
        cross_attention_kwargs: Extra kwargs forwarded to the attention
            processors; ``"scale"`` and ``"gligen"`` keys are consumed here.
        class_labels: Class conditioning used by AdaLayerNormZero.
        added_cond_kwargs: Extra conditioning; the continuous ada-norm path
            reads ``"pooled_text_emb"`` from it.

    Returns:
        The transformed hidden states, same shape as the input.

    Raises:
        ValueError: If no recognized norm flag is set.
    """
    # Notice that normalization is always applied before the real computation in the following blocks.
    # 0. Self-Attention — select the norm variant for the first sub-block.
    batch_size = hidden_states.shape[0]

    if self.use_ada_layer_norm:
        norm_hidden_states = self.norm1(hidden_states, timestep)
    elif self.use_ada_layer_norm_zero:
        # AdaLN-Zero also returns the gate/shift/scale modulation tensors
        # consumed by the attention and feed-forward residual paths below.
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
    elif self.use_layer_norm:
        norm_hidden_states = self.norm1(hidden_states)
    elif self.use_ada_layer_norm_continuous:
        norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
    elif self.use_ada_layer_norm_single:
        # PixArt-Alpha single ada-norm: `timestep` holds all six modulation
        # vectors; combine with the learned table and split them apart.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
        ).chunk(6, dim=1)
        norm_hidden_states = self.norm1(hidden_states)
        # Modulate: broadcast (B, 1, D) scale/shift over the sequence axis.
        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
        norm_hidden_states = norm_hidden_states.squeeze(1)
    else:
        raise ValueError("Incorrect norm used")

    if self.pos_embed is not None:
        norm_hidden_states = self.pos_embed(norm_hidden_states)

    # 1. Retrieve lora scale.
    lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

    # 2. Prepare GLIGEN inputs — copy so the pop does not mutate the caller's dict.
    cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
    gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

    attn_output = self.attn1(
        norm_hidden_states,
        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
        attention_mask=attention_mask,
        **cross_attention_kwargs,
    )
    if self.use_ada_layer_norm_zero:
        attn_output = gate_msa.unsqueeze(1) * attn_output
    elif self.use_ada_layer_norm_single:
        attn_output = gate_msa * attn_output

    hidden_states = attn_output + hidden_states
    if hidden_states.ndim == 4:
        hidden_states = hidden_states.squeeze(1)

    # 2.5 GLIGEN Control
    if gligen_kwargs is not None:
        hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

    # 3. Cross-Attention
    if self.attn2 is not None:
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm2(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero or self.use_layer_norm:
            norm_hidden_states = self.norm2(hidden_states)
        elif self.use_ada_layer_norm_single:
            # For PixArt norm2 isn't applied here:
            # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
            norm_hidden_states = hidden_states
        elif self.use_ada_layer_norm_continuous:
            norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
        else:
            raise ValueError("Incorrect norm")

        if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

    # 4. Feed-forward
    if self.use_ada_layer_norm_continuous:
        norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
    elif not self.use_ada_layer_norm_single:
        norm_hidden_states = self.norm3(hidden_states)

    if self.use_ada_layer_norm_zero:
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

    if self.use_ada_layer_norm_single:
        # In the PixArt single path, norm2 (a plain LayerNorm in that config)
        # normalizes before the feed-forward modulation.
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

    if self._chunk_size is not None:
        # "feed_forward_chunk_size" can be used to save memory
        ff_output = _chunked_feed_forward(
            self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
        )
    else:
        ff_output = self.ff(norm_hidden_states, scale=lora_scale)

    if self.use_ada_layer_norm_zero:
        ff_output = gate_mlp.unsqueeze(1) * ff_output
    elif self.use_ada_layer_norm_single:
        ff_output = gate_mlp * ff_output

    hidden_states = ff_output + hidden_states
    if hidden_states.ndim == 4:
        hidden_states = hidden_states.squeeze(1)

    return hidden_states
| 420 |
+
|
| 421 |
+
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block for video like data.

    The block first reshapes ``(batch * frames, seq, ch)`` input so that
    attention runs along the frame axis, applies an input feed-forward,
    temporal self-attention, optional cross-attention, and an output
    feed-forward, then restores the original layout.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        time_mix_inner_dim (`int`): The number of channels for temporal attention.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
    """

    def __init__(
        self,
        dim: int,
        time_mix_inner_dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        # The residual connection around the temporal path is only valid
        # when the temporal inner width matches the input width.
        self.is_res = dim == time_mix_inner_dim

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        # (the original assigned `self.norm_in` twice; one deterministic
        # LayerNorm construction is equivalent)
        self.norm_in = nn.LayerNorm(dim)
        self.ff_in = FeedForward(
            dim,
            dim_out=time_mix_inner_dim,
            activation_fn="geglu",
        )

        self.norm1 = nn.LayerNorm(time_mix_inner_dim)
        self.attn1 = Attention(
            query_dim=time_mix_inner_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            cross_attention_dim=None,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = nn.LayerNorm(time_mix_inner_dim)
            self.attn2 = Attention(
                query_dim=time_mix_inner_dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(time_mix_inner_dim)
        self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = None

    def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
        """Enable chunked feed-forward with the given chunk size (or disable with ``None``)."""
        self._chunk_size = chunk_size
        # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
        self._chunk_dim = 1

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        num_frames: int,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Run the temporal block.

        Args:
            hidden_states: Input of shape ``(batch * num_frames, seq, ch)``.
            num_frames: Number of frames folded into the batch dimension.
            encoder_hidden_states: Optional cross-attention context.

        Returns:
            Output tensor of shape ``(batch * num_frames, seq, ch)``.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention — fold frames out of the batch axis so that the
        # attention below mixes information across time, not across tokens.
        # (a redundant `batch_size = hidden_states.shape[0]` assignment,
        # immediately overwritten, was removed)
        batch_frames, seq_length, channels = hidden_states.shape
        batch_size = batch_frames // num_frames

        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)

        residual = hidden_states
        hidden_states = self.norm_in(hidden_states)

        if self._chunk_size is not None:
            hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
        else:
            hidden_states = self.ff_in(hidden_states)

        if self.is_res:
            hidden_states = hidden_states + residual

        norm_hidden_states = self.norm1(hidden_states)
        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
        hidden_states = attn_output + hidden_states

        # 3. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = self.norm2(hidden_states)
            attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.is_res:
            hidden_states = ff_output + hidden_states
        else:
            hidden_states = ff_output

        # Restore the original (batch * frames, seq, ch) layout.
        hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)

        return hidden_states
|
| 552 |
+
class SkipFFTransformerBlock(nn.Module):
    """Transformer block with two attention sub-blocks and no feed-forward.

    An optional linear mapper projects the key/value context into the block
    dimension; each attention sub-block is preceded by an RMSNorm and wrapped
    in a residual connection.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        kv_input_dim: int,
        kv_input_dim_proj_use_bias: bool,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        attention_out_bias: bool = True,
    ):
        super().__init__()
        # Project the key/value context into the block dimension only when
        # the widths differ.
        self.kv_mapper = (
            nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
            if kv_input_dim != dim
            else None
        )

        self.norm1 = RMSNorm(dim, 1e-06)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim,
            out_bias=attention_out_bias,
        )

        self.norm2 = RMSNorm(dim, 1e-06)
        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

    def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
        """Apply both attention sub-blocks with residual connections."""
        extra_attn_kwargs = {} if cross_attention_kwargs is None else cross_attention_kwargs.copy()

        if self.kv_mapper is not None:
            encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))

        # Both sub-blocks share the same pre-norm + residual pattern.
        for norm, attn in ((self.norm1, self.attn1), (self.norm2, self.attn2)):
            hidden_states = hidden_states + attn(
                norm(hidden_states),
                encoder_hidden_states=encoder_hidden_states,
                **extra_attn_kwargs,
            )

        return hidden_states
|
| 624 |
+
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"`, `"geglu-approximate"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
        inner_dim (`int`, *optional*): Explicit hidden width; overrides `dim * mult` when given.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

        # Bug fix: the original used two independent `if` statements for
        # "gelu" / "gelu-approximate", and an unrecognized activation left
        # `act_fn` unbound (NameError). A proper elif chain with an explicit
        # error makes the failure mode clear.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        else:
            raise ValueError(f"Unsupported activation function: {activation_fn}")

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        """Apply the feed-forward stack; `scale` is forwarded to LoRA-compatible layers only."""
        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
        for module in self.net:
            if isinstance(module, compatible_cls):
                hidden_states = module(hidden_states, scale)
            else:
                hidden_states = module(hidden_states)
        return hidden_states
|
MuCodec/models/transformer_2d_flow.py
ADDED
|
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import (
    ImagePositionalEmbeddings,
    PatchEmbed,
    PixArtAlphaTextProjection,
    TimestepEmbedding,
    Timesteps,
)
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version

from models.attention import BasicTransformerBlock
+
|
| 31 |
+
class PixArtAlphaCombinedFlowEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Timestep conditioning adapted for flow matching: a fixed sinusoidal
    projection of the (possibly fractional) timestep is passed through an
    MLP; resolution/aspect-ratio embeddings are optionally added.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        # Width of the sinusoidal timestep projection fed to the MLP below.
        self.flow_t_size = 512
        self.outdim = size_emb_dim
        self.timestep_embedder = TimestepEmbedding(in_channels=self.flow_t_size, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if use_additional_conditions:
            # NOTE(review): `Timesteps` is not in this module's visible imports;
            # this branch raises NameError unless it is imported from
            # diffusers.models.embeddings — confirm before enabling.
            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

    # https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/nn.py#L87
    def timestep_embedding(self, timesteps, max_period=10000, scale=1000):
        """Create sinusoidal timestep embeddings.

        :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
        :param max_period: controls the minimum frequency of the embeddings.
        :param scale: multiplier applied to the timestep before the sinusoid
            (flow-matching timesteps in [0, 1] are scaled up to diffusion-like range).
        :return: an [N x self.flow_t_size] Tensor of positional embeddings.
        """
        half = self.flow_t_size // 2
        # Geometric frequency ladder from 1 down to 1/max_period, cast to the input dtype.
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, device=timesteps.device) / half).type(timesteps.type())
        args = timesteps[:, None] * freqs[None] * scale
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # Zero-pad one channel when flow_t_size is odd (dead branch for the fixed 512).
        if self.flow_t_size % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        """Return the combined conditioning vector of shape (N, embedding_dim)."""
        timesteps_proj = self.timestep_embedding(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        if self.use_additional_conditions:
            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
        else:
            conditioning = timesteps_emb

        return conditioning
|
| 84 |
+
class AdaLayerNormSingleFlow(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single), flow-matching variant.

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
        super().__init__()

        self.emb = PixArtAlphaCombinedFlowEmbeddings(
            embedding_dim,
            size_emb_dim=embedding_dim // 3,
            use_additional_conditions=use_additional_conditions,
        )
        self.silu = nn.SiLU()
        # Produces the 6 stacked modulation vectors consumed by the transformer blocks.
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed the timestep and return (modulation, embedded_timestep).

        No modulation is applied here; the caller splits the 6*D output.
        """
        embedded_timestep = self.emb(
            timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype
        )
        modulation = self.linear(self.silu(embedded_timestep))
        return modulation, embedded_timestep
|
| 117 |
+
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    # Continuous latents, or per-pixel class probabilities in the discrete case.
    sample: torch.FloatTensor
+
|
| 130 |
+
|
| 131 |
+
class Transformer2DModel(ModelMixin, ConfigMixin):
|
| 132 |
+
"""
|
| 133 |
+
A 2D Transformer model for image-like data.
|
| 134 |
+
|
| 135 |
+
Parameters:
|
| 136 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
| 137 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
| 138 |
+
in_channels (`int`, *optional*):
|
| 139 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
| 140 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
| 141 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 142 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
| 143 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
| 144 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
| 145 |
+
num_vector_embeds (`int`, *optional*):
|
| 146 |
+
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
| 147 |
+
Includes the class for the masked latent pixel.
|
| 148 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
| 149 |
+
num_embeds_ada_norm ( `int`, *optional*):
|
| 150 |
+
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
| 151 |
+
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
| 152 |
+
added to the hidden states.
|
| 153 |
+
|
| 154 |
+
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
| 155 |
+
attention_bias (`bool`, *optional*):
|
| 156 |
+
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
_supports_gradient_checkpointing = True
|
| 160 |
+
|
| 161 |
+
@register_to_config
def __init__(
    self,
    num_attention_heads: int = 16,
    attention_head_dim: int = 88,
    in_channels: Optional[int] = None,
    out_channels: Optional[int] = None,
    num_layers: int = 1,
    dropout: float = 0.0,
    norm_num_groups: int = 32,
    cross_attention_dim: Optional[int] = None,
    attention_bias: bool = False,
    sample_size: Optional[int] = None,
    num_vector_embeds: Optional[int] = None,
    patch_size: Optional[int] = None,
    activation_fn: str = "geglu",
    num_embeds_ada_norm: Optional[int] = None,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    double_self_attention: bool = False,
    upcast_attention: bool = False,
    norm_type: str = "layer_norm",
    norm_elementwise_affine: bool = True,
    norm_eps: float = 1e-5,
    attention_type: str = "default",
    caption_channels: int = None,
):
    """Build the model: input projection, transformer blocks, and output head.

    The layer layout depends on which of the three mutually-exclusive input
    modes is selected by the config (continuous / vectorized / patched); see
    the branch logic below.  NOTE(review): `@register_to_config` records this
    exact signature, so parameter names/defaults are part of the public
    checkpoint format — do not rename them.
    """
    super().__init__()
    self.use_linear_projection = use_linear_projection
    self.num_attention_heads = num_attention_heads
    self.attention_head_dim = attention_head_dim
    # Hidden width of every transformer block.
    inner_dim = num_attention_heads * attention_head_dim

    # When the PEFT backend is active, plain torch layers are used; otherwise
    # the LoRA-compatible wrappers (which accept a `scale` kwarg in forward).
    conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
    linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

    # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
    # Define whether input is continuous or discrete depending on configuration
    self.is_input_continuous = (in_channels is not None) and (patch_size is None)
    self.is_input_vectorized = num_vector_embeds is not None
    self.is_input_patches = in_channels is not None and patch_size is not None

    # Legacy-config shim: old checkpoints used "layer_norm" together with
    # `num_embeds_ada_norm`; silently upgrade to "ada_norm" with a deprecation.
    if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
        deprecation_message = (
            f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
            " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
            " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
            " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
            " would be very nice if you could open a Pull request for the `transformer/config.json` file"
        )
        deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
        norm_type = "ada_norm"

    # Exactly one input mode must be selected; reject ambiguous configs.
    if self.is_input_continuous and self.is_input_vectorized:
        raise ValueError(
            f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
            " sure that either `in_channels` or `num_vector_embeds` is None."
        )
    elif self.is_input_vectorized and self.is_input_patches:
        raise ValueError(
            f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
            " sure that either `num_vector_embeds` or `num_patches` is None."
        )
    elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
        raise ValueError(
            f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
            f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
        )

    # 2. Define input layers
    if self.is_input_continuous:
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        if use_linear_projection:
            self.proj_in = linear_cls(in_channels, inner_dim)
        else:
            # 1x1 conv == per-pixel linear projection, kept for checkpoint compat.
            self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
    elif self.is_input_vectorized:
        assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
        assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

        self.height = sample_size
        self.width = sample_size
        self.num_vector_embeds = num_vector_embeds
        self.num_latent_pixels = self.height * self.width

        self.latent_image_embedding = ImagePositionalEmbeddings(
            num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
        )
    elif self.is_input_patches:
        assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

        self.height = sample_size
        self.width = sample_size

        self.patch_size = patch_size
        interpolation_scale = self.config.sample_size // 64  # => 64 (= 512 pixart) has interpolation scale 1
        interpolation_scale = max(interpolation_scale, 1)
        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            interpolation_scale=interpolation_scale,
        )

    # 3. Define transformers blocks
    self.transformer_blocks = nn.ModuleList(
        [
            BasicTransformerBlock(
                inner_dim,
                num_attention_heads,
                attention_head_dim,
                dropout=dropout,
                cross_attention_dim=cross_attention_dim,
                activation_fn=activation_fn,
                num_embeds_ada_norm=num_embeds_ada_norm,
                attention_bias=attention_bias,
                only_cross_attention=only_cross_attention,
                double_self_attention=double_self_attention,
                upcast_attention=upcast_attention,
                norm_type=norm_type,
                norm_elementwise_affine=norm_elementwise_affine,
                norm_eps=norm_eps,
                attention_type=attention_type,
            )
            for d in range(num_layers)
        ]
    )

    # 4. Define output layers
    self.out_channels = in_channels if out_channels is None else out_channels
    if self.is_input_continuous:
        # TODO: should use out_channels for continuous projections
        if use_linear_projection:
            self.proj_out = linear_cls(inner_dim, in_channels)
        else:
            self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
    elif self.is_input_vectorized:
        self.norm_out = nn.LayerNorm(inner_dim)
        # `- 1`: the masked-pixel class is excluded from the output logits.
        self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
    elif self.is_input_patches and norm_type != "ada_norm_single":
        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
        self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
    elif self.is_input_patches and norm_type == "ada_norm_single":
        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

    # 5. PixArt-Alpha blocks.
    self.adaln_single = None
    self.use_additional_conditions = False
    if norm_type == "ada_norm_single":
        self.use_additional_conditions = self.config.sample_size == 128
        # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
        # additional conditions until we find better name
        self.adaln_single = AdaLayerNormSingleFlow(inner_dim, use_additional_conditions=self.use_additional_conditions)

    self.caption_projection = None
    if caption_channels is not None:
        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

    # Off by default; enabled by `_set_gradient_checkpointing`.
    self.gradient_checkpointing = False
|
| 327 |
+
|
| 328 |
+
def _set_gradient_checkpointing(self, module, value=False):
    """Enable/disable gradient checkpointing on any submodule that exposes the flag."""
    if hasattr(module, "gradient_checkpointing"):
        setattr(module, "gradient_checkpointing", value)
|
| 331 |
+
|
| 332 |
+
def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    added_cond_kwargs: Dict[str, torch.Tensor] = None,
    class_labels: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
):
    """
    The [`Transformer2DModel`] forward method.

    Args:
        hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
            Input `hidden_states`.
        encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
            Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
            self-attention.
        timestep ( `torch.LongTensor`, *optional*):
            Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
        class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
            Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
            `AdaLayerZeroNorm`.
        cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        attention_mask ( `torch.Tensor`, *optional*):
            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
            is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
            negative values to the attention scores corresponding to "discard" tokens.
        encoder_attention_mask ( `torch.Tensor`, *optional*):
            Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                * Mask `(batch, sequence_length)` True = keep, False = discard.
                * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

            If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
            above. This bias will be added to the cross-attention scores.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
    #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
    #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
    # expects mask of shape:
    #   [batch, key_tokens]
    # adds singleton query_tokens dimension:
    #   [batch,                    1, key_tokens]
    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
    #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
    if attention_mask is not None and attention_mask.ndim == 2:
        # assume that mask is expressed as:
        #   (1 = keep,      0 = discard)
        # convert mask into a bias that can be added to attention scores:
        #       (keep = +0,     discard = -10000.0)
        attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    # Retrieve lora scale.
    lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

    # 1. Input
    # Reshape the input into a `(batch, tokens, inner_dim)` sequence; the order
    # of projection vs. flatten depends on `use_linear_projection`.
    if self.is_input_continuous:
        batch, _, height, width = hidden_states.shape
        residual = hidden_states  # saved for the skip connection in step 3

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            # conv projection first (operates on NCHW), then flatten to tokens
            hidden_states = (
                self.proj_in(hidden_states, scale=lora_scale)
                if not USE_PEFT_BACKEND
                else self.proj_in(hidden_states)
            )
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            # flatten to tokens first, then linear projection
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = (
                self.proj_in(hidden_states, scale=lora_scale)
                if not USE_PEFT_BACKEND
                else self.proj_in(hidden_states)
            )

    elif self.is_input_vectorized:
        hidden_states = self.latent_image_embedding(hidden_states)
    elif self.is_input_patches:
        height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
        hidden_states = self.pos_embed(hidden_states)

        if self.adaln_single is not None:
            if self.use_additional_conditions and added_cond_kwargs is None:
                raise ValueError(
                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                )
            batch_size = hidden_states.shape[0]
            # `timestep` is replaced by its embedded form for the blocks below.
            timestep, embedded_timestep = self.adaln_single(
                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
            )

    # 2. Blocks
    if self.caption_projection is not None:
        batch_size = hidden_states.shape[0]
        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

    for block in self.transformer_blocks:
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                timestep,
                cross_attention_kwargs,
                class_labels,
                **ckpt_kwargs,
            )
        else:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

    # 3. Output
    if self.is_input_continuous:
        # Mirror of the input path: un-flatten / project in the matching order,
        # then add the residual saved in step 1.
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = (
                self.proj_out(hidden_states, scale=lora_scale)
                if not USE_PEFT_BACKEND
                else self.proj_out(hidden_states)
            )
        else:
            hidden_states = (
                self.proj_out(hidden_states, scale=lora_scale)
                if not USE_PEFT_BACKEND
                else self.proj_out(hidden_states)
            )
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual
    elif self.is_input_vectorized:
        hidden_states = self.norm_out(hidden_states)
        logits = self.out(hidden_states)
        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
        logits = logits.permute(0, 2, 1)

        # log(p(x_0))
        output = F.log_softmax(logits.double(), dim=1).float()

    if self.is_input_patches:
        if self.config.norm_type != "ada_norm_single":
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)
        elif self.config.norm_type == "ada_norm_single":
            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states)
            # Modulation
            hidden_states = hidden_states * (1 + scale) + shift
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.squeeze(1)

        # unpatchify
        if self.adaln_single is None:
            height = width = int(hidden_states.shape[1] ** 0.5)
        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
        )

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
|
MuCodec/mp3_to_code.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import tempfile
|
| 5 |
+
import traceback
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def parse_args():
    """Build and parse the command-line interface for the batch encoder."""
    parser = argparse.ArgumentParser(
        description="Batch encode MP3 files to MuCodec codes (recursive)."
    )

    # Positional paths.
    parser.add_argument("input_dir", type=Path, help="Input folder (recursive scan)")
    parser.add_argument("output_dir", type=Path, help="Output folder for saved codes")

    # Model options.
    default_ckpt = Path(__file__).resolve().parent / "ckpt" / "mucodec.pt"
    parser.add_argument(
        "--ckpt", type=Path, default=default_ckpt, help="Path to MuCodec checkpoint"
    )
    parser.add_argument(
        "--layer-num",
        type=int,
        default=7,
        help="MuCodec layer num (default follows generate.py)",
    )
    parser.add_argument("--device", default="cuda:0", help="Torch device, e.g. cuda:0")

    # Scan / output options.
    parser.add_argument(
        "--ext",
        nargs="+",
        default=[".mp3"],
        help="Audio extensions to include, e.g. .mp3 .wav .flac",
    )
    parser.add_argument(
        "--format",
        choices=["npz", "pt", "npy", "both", "all"],
        default="npz",
        help="Output format for code files",
    )

    # Behavior flags.
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Recompute files even if output already exists (disable resume)",
    )
    parser.add_argument(
        "--strict",
        action="store_true",
        help="Stop immediately on first failed file",
    )

    return parser.parse_args()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def list_audio_files(root: Path, exts):
    """Recursively collect files under `root` whose suffix matches `exts`.

    Extensions are compared case-insensitively and a missing leading dot is
    tolerated.  Returns the matches in sorted order for deterministic runs.
    """
    wanted = set()
    for ext in exts:
        lowered = ext.lower()
        wanted.add(lowered if lowered.startswith(".") else "." + lowered)

    matches = []
    for candidate in root.rglob("*"):
        if candidate.is_file() and candidate.suffix.lower() in wanted:
            matches.append(candidate)

    matches.sort()
    return matches
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def expected_output_paths(output_stem: Path, fmt: str):
    """Return every output path the format `fmt` will produce for `output_stem`.

    Used by the resume logic: a file is skipped only if all of these exist.
    Raises ValueError for an unknown format name.
    """
    suffixes_by_format = {
        "npz": (".npz",),
        "pt": (".pt",),
        "npy": (".npy",),
        "both": (".pt", ".npy"),
        "all": (".npz", ".pt", ".npy"),
    }
    if fmt not in suffixes_by_format:
        raise ValueError(f"Unsupported format: {fmt}")
    return [output_stem.with_suffix(suffix) for suffix in suffixes_by_format[fmt]]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def save_npz_atomic(codes_np: np.ndarray, output_path: Path):
    """Atomically write `codes_np` to `output_path` as a compressed .npz.

    Writes to a sibling temporary file, then renames over the target, so a
    crash mid-write never leaves a truncated .npz behind.  On any failure the
    temporary file is removed and the exception re-raised.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = None
    try:
        # The temp file lives in the destination directory so os.replace is a
        # same-filesystem rename (atomic on POSIX).
        tmp_file = tempfile.NamedTemporaryFile(
            mode="wb", suffix=".npz", dir=output_path.parent, delete=False
        )
        with tmp_file:
            tmp_path = Path(tmp_file.name)
            np.savez_compressed(tmp_file, codes=codes_np)
        os.replace(tmp_path, output_path)
    except Exception:
        if tmp_path is not None and tmp_path.exists():
            tmp_path.unlink()
        raise
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def save_codes(codes: torch.Tensor, output_stem: Path, fmt: str):
    """Persist `codes` at `output_stem` in the requested format(s).

    "npz" saves atomically via save_npz_atomic; "pt" keeps the torch tensor;
    "npy" stores the raw numpy array.  "both" = pt+npy, "all" = every format.
    """
    tensor_cpu = codes.detach().cpu()
    array = tensor_cpu.numpy()
    if fmt in {"npz", "all"}:
        save_npz_atomic(array, output_stem.with_suffix(".npz"))
    if fmt in {"pt", "both", "all"}:
        torch.save(tensor_cpu, output_stem.with_suffix(".pt"))
    if fmt in {"npy", "both", "all"}:
        np.save(output_stem.with_suffix(".npy"), array)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def main():
    """CLI entry point: validate args, scan input_dir, encode each file, save codes."""
    args = parse_args()

    # Deferred import: lets `--help` and argument errors run without loading
    # the heavy model stack.
    from generate import MuCodec

    # Fail fast on bad inputs before touching the model.
    if not args.input_dir.exists() or not args.input_dir.is_dir():
        raise ValueError(f"input_dir does not exist or is not a directory: {args.input_dir}")
    if not args.ckpt.exists():
        raise FileNotFoundError(f"Checkpoint not found: {args.ckpt}")
    if args.device.startswith("cuda") and not torch.cuda.is_available():
        raise RuntimeError("CUDA device requested but torch.cuda.is_available() is False")

    audio_files = list_audio_files(args.input_dir, args.ext)
    if not audio_files:
        print("No audio files found.")
        return

    args.output_dir.mkdir(parents=True, exist_ok=True)

    codec = MuCodec(
        model_path=str(args.ckpt),
        layer_num=args.layer_num,
        load_main_model=True,
        device=args.device,
    )

    resume = not args.overwrite
    n_ok = 0
    n_skipped = 0
    failures = []

    for src in tqdm(audio_files, desc="Encoding", unit="file"):
        # Mirror the input tree under output_dir, dropping the audio suffix.
        stem = (args.output_dir / src.relative_to(args.input_dir)).with_suffix("")
        targets = expected_output_paths(stem, args.format)

        if resume and all(t.exists() for t in targets):
            n_skipped += 1
            continue

        stem.parent.mkdir(parents=True, exist_ok=True)

        try:
            save_codes(codec.file2code(str(src)), stem, args.format)
            n_ok += 1
        except Exception as e:
            failures.append((src, str(e)))
            print(f"[FAILED] {src}: {e}")
            if args.strict:
                print("--strict enabled, stopping on first failure.")
                traceback.print_exc()
                break

    print(
        "Done. "
        f"success={n_ok}, skipped={n_skipped}, failed={len(failures)}, total={len(audio_files)}"
    )
    if failures:
        print("Failed files:")
        for path, err in failures:
            print(f"- {path}: {err}")
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# Script entry point: run the batch encoder only when executed directly.
if __name__ == "__main__":
    main()
|
| 187 |
+
|
MuCodec/muq_dev/test.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
import fairseq
|
| 4 |
+
import os.path as op
|
| 5 |
+
|
| 6 |
+
root = op.dirname(op.abspath(__file__))
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
class UserDirModule:
    # Minimal namespace object handed to fairseq.utils.import_user_module in
    # load_model below; presumably only the `user_dir` attribute is read —
    # TODO confirm against the fairseq version in use.
    user_dir: str
|
| 12 |
+
|
| 13 |
+
def load_model(model_dir, checkpoint_dir):
    """Load a Fairseq SSL model from a checkpoint.

    `model_dir` holds the custom model/task code registered with fairseq via
    import_user_module; `checkpoint_dir` is the checkpoint path.  Returns the
    first model of the (single-entry) ensemble.
    """
    fairseq.utils.import_user_module(UserDirModule(model_dir))

    # strict=False: tolerate missing/unexpected keys in the checkpoint.
    models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [checkpoint_dir], strict=False
    )
    return models[0]
|
MuCodec/readme.md
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MuCodec: Ultra Low-Bitrate Music Codec
|
| 2 |
+
|
| 3 |
+
This repository is the official code repository for MuCodec: Ultra Low-Bitrate Music Codec. You can find our paper on [arXiv](https://arxiv.org/pdf/2409.13216). The demo page is available [here](https://xuyaoxun.github.io/MuCodec_demo/).
|
| 4 |
+
|
| 5 |
+
In this repository, we provide the MuCodec model, inference scripts, and a checkpoint trained on the Million Song Dataset. Specifically, we have released the model and inference code corresponding to the lowest bitrate of 0.35 kbps mentioned in the paper, to demonstrate the effectiveness of our work.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
MuCodec supports 48kHz, dual-channel (stereo) audio reconstruction. If the original audio is in a different format, it will first be converted to 48kHz, dual-channel audio.
|
| 9 |
+
|
| 10 |
+
## Installation
|
| 11 |
+
|
| 12 |
+
You can install the necessary dependencies using the `requirements.txt` file with Python 3.8.12:
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
pip install -r requirements.txt
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
Due to storage limitations, we have saved the model checkpoints on Hugging Face at https://huggingface.co/yaoxunxu/mucodec. You can easily download the models from Hugging Face and save them in the following directories:
|
| 19 |
+
|
| 20 |
+
- Save `audioldm_48k.pth` in the `tools` folder.
|
| 21 |
+
- Save `muq.pt` in the `muq_dev` folder.
|
| 22 |
+
- Save `mucodec.pt` in the `ckpt` folder.
|
| 23 |
+
|
| 24 |
+
Please note that all three checkpoints must be downloaded completely for the model to load correctly. The final file paths should be:
|
| 25 |
+
|
| 26 |
+
```
|
| 27 |
+
tools/audioldm_48k.pth
|
| 28 |
+
muq_dev/muq.pt
|
| 29 |
+
ckpt/mucodec.pt
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
The file `audioldm_48k.pth` is sourced from https://huggingface.co/haoheliu/audioldm_48k/blob/main/audioldm_48k.pth.
|
| 33 |
+
|
| 34 |
+
## Inference
|
| 35 |
+
|
| 36 |
+
To run inference, use the following command:
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
python3 generate.py
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
We have provided a sample song `test.wav`, randomly sampled from the Million Song Dataset, in the `test_wav` folder. The default input path is `test_wav/test.wav`, and the output path for the reconstructed audio is `reconstruct/test.wav`.
|
| 43 |
+
|
| 44 |
+
In the `generate.py` file, we have implemented several functions to facilitate the music compression and reconstruction process. You can easily obtain compressed tokens from audio using the `sound2code` function, and reconstruct the audio from tokens using the `code2sound` function.
|
| 45 |
+
|
| 46 |
+
## Note
|
| 47 |
+
|
| 48 |
+
Please note that the open-sourced model was trained solely on the Million Song Dataset. Considering the quality issues of this dataset, the open-sourced model may not achieve the same performance as demonstrated in the demo. Unfortunately, due to copyright restrictions, we are unable to release the checkpoints trained on additional datasets. However, you can use your own dataset to further train the model and achieve better results.
|
| 49 |
+
|
| 50 |
+
## License
|
| 51 |
+
|
| 52 |
+
The code in this repository is released under the MIT license as found in the [LICENSE](LICENSE) file.
|
| 53 |
+
|
| 54 |
+
The model weights (muq.pt, mucodec.pt) in this repository are released under the CC-BY-NC 4.0 license, as detailed in the [LICENSE_weights](LICENSE_weights) file.
|
| 55 |
+
|
| 56 |
+
## Citation
|
| 57 |
+
|
| 58 |
+
If you find our work useful, please cite our paper:
|
| 59 |
+
|
| 60 |
+
```bibtex
|
| 61 |
+
@article{xu2024mucodec,
|
| 62 |
+
title={MuCodec: Ultra Low-Bitrate Music Codec},
|
| 63 |
+
author={Xu, Yaoxun and Chen, Hangting and Yu, Jianwei and Tan, Wei and Gu, Rongzhi and Lei, Shun and Lin, Zhiwei and Wu, Zhiyong},
|
| 64 |
+
journal={arXiv preprint arXiv:2409.13216},
|
| 65 |
+
year={2024}
|
| 66 |
+
}
|
| 67 |
+
```
|
MuCodec/requirements.txt
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==2.0.0
|
| 2 |
+
accelerate==0.30.1
|
| 3 |
+
aeiou==0.0.20
|
| 4 |
+
aiobotocore==2.13.1
|
| 5 |
+
aiofiles==23.2.1
|
| 6 |
+
aiohttp==3.9.3
|
| 7 |
+
aioitertools==0.11.0
|
| 8 |
+
aiosignal==1.3.1
|
| 9 |
+
alias-free-torch==0.0.6
|
| 10 |
+
altair==5.3.0
|
| 11 |
+
annotated-types==0.6.0
|
| 12 |
+
antlr4-python3-runtime==4.8
|
| 13 |
+
anyio==4.3.0
|
| 14 |
+
appdirs==1.4.4
|
| 15 |
+
argbind==0.3.9
|
| 16 |
+
asttokens==2.4.1
|
| 17 |
+
astunparse==1.6.3
|
| 18 |
+
async-timeout==4.0.3
|
| 19 |
+
attrs==23.1.0
|
| 20 |
+
audioread==3.0.1
|
| 21 |
+
auraloss==0.4.0
|
| 22 |
+
av==11.0.0
|
| 23 |
+
backcall==0.2.0
|
| 24 |
+
beartype==0.18.5
|
| 25 |
+
bitarray==2.9.2
|
| 26 |
+
bleach==6.1.0
|
| 27 |
+
blis==0.7.11
|
| 28 |
+
bokeh==3.1.1
|
| 29 |
+
botocore==1.34.131
|
| 30 |
+
braceexpand==0.1.7
|
| 31 |
+
cachetools==5.3.2
|
| 32 |
+
catalogue==2.0.10
|
| 33 |
+
certifi==2023.11.17
|
| 34 |
+
cffi==1.16.0
|
| 35 |
+
charset-normalizer==3.3.2
|
| 36 |
+
clean-fid==0.1.35
|
| 37 |
+
click==8.1.7
|
| 38 |
+
clip-anytorch==2.6.0
|
| 39 |
+
cloudpathlib==0.16.0
|
| 40 |
+
cloudpickle==3.0.0
|
| 41 |
+
cn2an==0.5.22
|
| 42 |
+
colorama==0.4.6
|
| 43 |
+
colorcet==3.1.0
|
| 44 |
+
colorlog==6.8.2
|
| 45 |
+
confection==0.1.4
|
| 46 |
+
configparser==7.0.0
|
| 47 |
+
contourpy==1.1.1
|
| 48 |
+
cycler==0.12.1
|
| 49 |
+
cymem==2.0.8
|
| 50 |
+
Cython==3.0.10
|
| 51 |
+
dataclasses==0.6
|
| 52 |
+
datasets
|
| 53 |
+
dctorch==0.1.2
|
| 54 |
+
decorator==5.1.1
|
| 55 |
+
decord==0.6.0
|
| 56 |
+
deepspeed==0.14.0
|
| 57 |
+
demucs==4.0.1
|
| 58 |
+
descript-audio-codec==1.0.0
|
| 59 |
+
descript-audiotools==0.7.2
|
| 60 |
+
diffusers==0.27.2
|
| 61 |
+
dill==0.3.8
|
| 62 |
+
Distance==0.1.3
|
| 63 |
+
docker-pycreds==0.4.0
|
| 64 |
+
docopt==0.6.2
|
| 65 |
+
docstring_parser==0.16
|
| 66 |
+
dora_search==0.1.12
|
| 67 |
+
einops==0.7.0
|
| 68 |
+
einops-exts==0.0.4
|
| 69 |
+
einx==0.3.0
|
| 70 |
+
ema-pytorch==0.2.3
|
| 71 |
+
encodec==0.1.1
|
| 72 |
+
exceptiongroup==1.2.0
|
| 73 |
+
executing==2.0.1
|
| 74 |
+
expecttest==0.1.6
|
| 75 |
+
fairseq==0.12.2
|
| 76 |
+
fastapi==0.110.3
|
| 77 |
+
fastcore==1.6.3
|
| 78 |
+
ffmpy==0.3.2
|
| 79 |
+
filelock==3.13.1
|
| 80 |
+
fire==0.6.0
|
| 81 |
+
flashy==0.0.2
|
| 82 |
+
flatten-dict==0.4.2
|
| 83 |
+
fonttools==4.49.0
|
| 84 |
+
frozendict==2.4.4
|
| 85 |
+
frozenlist==1.4.1
|
| 86 |
+
fsspec==2024.6.1
|
| 87 |
+
ftfy==6.1.3
|
| 88 |
+
future==1.0.0
|
| 89 |
+
g2p-en==2.1.0
|
| 90 |
+
gin-config==0.5.0
|
| 91 |
+
gitdb==4.0.11
|
| 92 |
+
GitPython==3.1.43
|
| 93 |
+
google-auth==2.23.4
|
| 94 |
+
google-auth-oauthlib==1.0.0
|
| 95 |
+
gradio==4.26.0
|
| 96 |
+
gradio_client==0.15.1
|
| 97 |
+
grpcio==1.59.3
|
| 98 |
+
h11==0.14.0
|
| 99 |
+
h5py==3.11.0
|
| 100 |
+
hjson==3.1.0
|
| 101 |
+
holoviews==1.17.1
|
| 102 |
+
httpcore==1.0.5
|
| 103 |
+
httpx==0.27.0
|
| 104 |
+
huggingface-hub==0.23.5
|
| 105 |
+
hydra-colorlog==1.2.0
|
| 106 |
+
hydra-core==1.0.7
|
| 107 |
+
hypothesis==6.90.0
|
| 108 |
+
idna==3.4
|
| 109 |
+
imageio==2.34.2
|
| 110 |
+
importlib-metadata==6.8.0
|
| 111 |
+
importlib-resources==5.12.0
|
| 112 |
+
inflect==7.0.0
|
| 113 |
+
ipython==8.12.3
|
| 114 |
+
jedi==0.19.1
|
| 115 |
+
jieba-fast==0.53
|
| 116 |
+
Jinja2==3.1.2
|
| 117 |
+
jmespath==1.0.1
|
| 118 |
+
joblib==1.3.2
|
| 119 |
+
json5==0.9.25
|
| 120 |
+
jsonlines==4.0.0
|
| 121 |
+
jsonmerge==1.9.2
|
| 122 |
+
jsonschema==4.22.0
|
| 123 |
+
jsonschema-specifications==2023.12.1
|
| 124 |
+
julius==0.2.7
|
| 125 |
+
k-diffusion==0.1.1
|
| 126 |
+
kaldiio==2.18.0
|
| 127 |
+
kiwisolver==1.4.5
|
| 128 |
+
kornia==0.7.3
|
| 129 |
+
kornia_rs==0.1.5
|
| 130 |
+
laion-clap==1.1.4
|
| 131 |
+
lameenc==1.7.0
|
| 132 |
+
langcodes==3.4.0
|
| 133 |
+
language_data==1.2.0
|
| 134 |
+
lazy_loader==0.3
|
| 135 |
+
librosa==0.9.2
|
| 136 |
+
lightning==2.2.1
|
| 137 |
+
lightning-utilities==0.10.1
|
| 138 |
+
linkify-it-py==2.0.3
|
| 139 |
+
lion-pytorch==0.2.2
|
| 140 |
+
llvmlite==0.41.1
|
| 141 |
+
local-attention==1.8.6
|
| 142 |
+
loguru==0.7.2
|
| 143 |
+
lxml==5.2.2
|
| 144 |
+
marisa-trie==1.1.1
|
| 145 |
+
Markdown==3.5.1
|
| 146 |
+
markdown-it-py==3.0.0
|
| 147 |
+
markdown2==2.5.0
|
| 148 |
+
MarkupSafe==2.1.3
|
| 149 |
+
matplotlib==3.7.5
|
| 150 |
+
matplotlib-inline==0.1.7
|
| 151 |
+
mdit-py-plugins==0.4.1
|
| 152 |
+
mdurl==0.1.2
|
| 153 |
+
mpmath==1.3.0
|
| 154 |
+
msgpack==1.0.8
|
| 155 |
+
multidict==6.0.5
|
| 156 |
+
multiprocess==0.70.16
|
| 157 |
+
murmurhash==1.0.10
|
| 158 |
+
mypy-extensions==1.0.0
|
| 159 |
+
networkx==3.1
|
| 160 |
+
ninja==1.11.1.1
|
| 161 |
+
nltk==3.8.1
|
| 162 |
+
nnAudio==0.3.3
|
| 163 |
+
num2words==0.5.13
|
| 164 |
+
numba==0.58.1
|
| 165 |
+
numpy==1.23.5
|
| 166 |
+
nvidia-cublas-cu11==11.11.3.6
|
| 167 |
+
nvidia-cuda-cupti-cu11==11.8.87
|
| 168 |
+
nvidia-cuda-nvrtc-cu11==11.8.89
|
| 169 |
+
nvidia-cuda-runtime-cu11==11.8.89
|
| 170 |
+
nvidia-cudnn-cu11==8.7.0.84
|
| 171 |
+
nvidia-cufft-cu11==10.9.0.58
|
| 172 |
+
nvidia-curand-cu11==10.3.0.86
|
| 173 |
+
nvidia-cusolver-cu11==11.4.1.48
|
| 174 |
+
nvidia-cusparse-cu11==11.7.5.86
|
| 175 |
+
nvidia-nccl-cu11==2.19.3
|
| 176 |
+
nvidia-nvtx-cu11==11.8.86
|
| 177 |
+
oauthlib==3.2.2
|
| 178 |
+
omegaconf
|
| 179 |
+
opencv-contrib-python==4.8.1.78
|
| 180 |
+
opencv-python==4.8.1.78
|
| 181 |
+
openunmix==1.2.1
|
| 182 |
+
orjson==3.10.3
|
| 183 |
+
packaging==23.2
|
| 184 |
+
pandas==2.0.2
|
| 185 |
+
panel==1.2.3
|
| 186 |
+
param==2.1.1
|
| 187 |
+
parso==0.8.4
|
| 188 |
+
pathtools==0.1.2
|
| 189 |
+
pedalboard==0.7.4
|
| 190 |
+
peft==0.10.0
|
| 191 |
+
pexpect==4.9.0
|
| 192 |
+
pickleshare==0.7.5
|
| 193 |
+
Pillow==10.1.0
|
| 194 |
+
pkgutil_resolve_name==1.3.10
|
| 195 |
+
platformdirs==4.2.0
|
| 196 |
+
plotly==5.23.0
|
| 197 |
+
pooch==1.8.1
|
| 198 |
+
portalocker==2.10.1
|
| 199 |
+
prefigure==0.0.9
|
| 200 |
+
preshed==3.0.9
|
| 201 |
+
proces==0.1.7
|
| 202 |
+
prodict==0.8.18
|
| 203 |
+
progressbar==2.5
|
| 204 |
+
prompt_toolkit==3.0.47
|
| 205 |
+
protobuf==3.19.6
|
| 206 |
+
psutil==5.9.6
|
| 207 |
+
ptyprocess==0.7.0
|
| 208 |
+
pure_eval==0.2.3
|
| 209 |
+
py-cpuinfo==9.0.0
|
| 210 |
+
pyarrow==17.0.0
|
| 211 |
+
pyarrow-hotfix==0.6
|
| 212 |
+
pyasn1==0.5.1
|
| 213 |
+
pyasn1-modules==0.3.0
|
| 214 |
+
pybind11==2.11.1
|
| 215 |
+
pycparser==2.21
|
| 216 |
+
pydantic==2.6.3
|
| 217 |
+
pydantic_core==2.16.3
|
| 218 |
+
pydub==0.25.1
|
| 219 |
+
Pygments==2.18.0
|
| 220 |
+
pyloudnorm==0.1.1
|
| 221 |
+
pynndescent==0.5.13
|
| 222 |
+
pynvml==11.5.0
|
| 223 |
+
pyparsing==3.1.2
|
| 224 |
+
pypinyin==0.51.0
|
| 225 |
+
pyre-extensions==0.0.29
|
| 226 |
+
pyreaper==0.0.10
|
| 227 |
+
pystoi==0.4.1
|
| 228 |
+
python-dateutil==2.8.2
|
| 229 |
+
python-multipart==0.0.9
|
| 230 |
+
pytorch-lightning==2.1.0
|
| 231 |
+
pytz==2023.3.post1
|
| 232 |
+
pyviz_comms==3.0.3
|
| 233 |
+
PyWavelets==1.4.1
|
| 234 |
+
PyYAML==6.0.1
|
| 235 |
+
randomname==0.2.1
|
| 236 |
+
referencing==0.35.1
|
| 237 |
+
regex==2023.10.3
|
| 238 |
+
requests==2.32.3
|
| 239 |
+
requests-oauthlib==1.3.1
|
| 240 |
+
resampy==0.4.3
|
| 241 |
+
retrying==1.3.4
|
| 242 |
+
rich==13.7.1
|
| 243 |
+
rpds-py==0.18.1
|
| 244 |
+
rsa==4.9
|
| 245 |
+
ruamel.yaml==0.18.5
|
| 246 |
+
ruamel.yaml.clib==0.2.8
|
| 247 |
+
ruff==0.4.4
|
| 248 |
+
s3fs==2024.6.1
|
| 249 |
+
s3transfer==0.7.0
|
| 250 |
+
sacrebleu==2.4.2
|
| 251 |
+
safetensors==0.4.3
|
| 252 |
+
scikit-image==0.21.0
|
| 253 |
+
scikit-learn==1.3.2
|
| 254 |
+
scipy==1.10.1
|
| 255 |
+
semantic-version==2.10.0
|
| 256 |
+
sentencepiece==0.1.99
|
| 257 |
+
sentry-sdk==2.10.0
|
| 258 |
+
setproctitle==1.3.3
|
| 259 |
+
shellingham==1.5.4
|
| 260 |
+
six==1.16.0
|
| 261 |
+
smart-open==6.4.0
|
| 262 |
+
smmap==5.0.1
|
| 263 |
+
sniffio==1.3.1
|
| 264 |
+
sortedcontainers==2.4.0
|
| 265 |
+
SoundFile==0.10.2
|
| 266 |
+
sox==1.4.1
|
| 267 |
+
soxr==0.3.7
|
| 268 |
+
spacy==3.7.4
|
| 269 |
+
spacy-legacy==3.0.12
|
| 270 |
+
spacy-loggers==1.0.5
|
| 271 |
+
srsly==2.4.8
|
| 272 |
+
stack-data==0.6.3
|
| 273 |
+
starlette==0.37.2
|
| 274 |
+
submitit==1.5.1
|
| 275 |
+
sympy==1.12
|
| 276 |
+
tabulate==0.9.0
|
| 277 |
+
tenacity==9.0.0
|
| 278 |
+
tensorboard==2.14.0
|
| 279 |
+
tensorboard-data-server==0.7.2
|
| 280 |
+
termcolor==2.3.0
|
| 281 |
+
thinc==8.2.3
|
| 282 |
+
threadpoolctl==3.3.0
|
| 283 |
+
tifffile==2023.7.10
|
| 284 |
+
timm==0.9.11
|
| 285 |
+
tokenizers==0.19.1
|
| 286 |
+
tomlkit==0.12.0
|
| 287 |
+
toolz==0.12.1
|
| 288 |
+
torch==2.2.0+cu118
|
| 289 |
+
torch-stoi==0.2.1
|
| 290 |
+
torchaudio==2.2.0+cu118
|
| 291 |
+
torchdata==0.7.1
|
| 292 |
+
torchdiffeq==0.2.4
|
| 293 |
+
torchlibrosa==0.1.0
|
| 294 |
+
torchmetrics==0.11.4
|
| 295 |
+
torchsde==0.2.6
|
| 296 |
+
torchtext==0.17.0
|
| 297 |
+
torchvision==0.17.0+cu118
|
| 298 |
+
tornado==6.4.1
|
| 299 |
+
tqdm==4.66.4
|
| 300 |
+
traitlets==5.14.3
|
| 301 |
+
trampoline==0.1.2
|
| 302 |
+
transformers==4.42.4
|
| 303 |
+
treetable==0.2.5
|
| 304 |
+
triton==2.2.0
|
| 305 |
+
typeguard==2.13.0
|
| 306 |
+
typer==0.9.4
|
| 307 |
+
types-dataclasses==0.6.6
|
| 308 |
+
typing-inspect==0.9.0
|
| 309 |
+
typing_extensions==4.8.0
|
| 310 |
+
tzdata==2023.3
|
| 311 |
+
uc-micro-py==1.0.3
|
| 312 |
+
umap-learn==0.5.6
|
| 313 |
+
Unidecode==1.3.8
|
| 314 |
+
urllib3==1.26.18
|
| 315 |
+
uvicorn==0.29.0
|
| 316 |
+
v-diffusion-pytorch==0.0.2
|
| 317 |
+
vector-quantize-pytorch==1.9.14
|
| 318 |
+
wandb==0.15.4
|
| 319 |
+
wasabi==1.1.2
|
| 320 |
+
wcwidth==0.2.12
|
| 321 |
+
weasel==0.3.4
|
| 322 |
+
webdataset==0.2.48
|
| 323 |
+
webencodings==0.5.1
|
| 324 |
+
websockets==11.0.3
|
| 325 |
+
Werkzeug==3.0.1
|
| 326 |
+
wget==3.2
|
| 327 |
+
wordsegment==1.3.1
|
| 328 |
+
wrapt==1.16.0
|
| 329 |
+
x-clip==0.14.4
|
| 330 |
+
x-transformers==1.26.6
|
| 331 |
+
xformers==0.0.24+cu118
|
| 332 |
+
xxhash==3.4.1
|
| 333 |
+
xyzservices==2024.6.0
|
| 334 |
+
yarl==1.9.4
|
| 335 |
+
zipp==3.17.0
|
MuCodec/tools/get_melvaehifigan48k.py
ADDED
|
@@ -0,0 +1,1551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import soundfile as sf
|
| 3 |
+
import os
|
| 4 |
+
from librosa.filters import mel as librosa_mel_fn
|
| 5 |
+
import sys
|
| 6 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
| 7 |
+
import tools.torch_tools as torch_tools
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from scipy.signal import get_window
|
| 13 |
+
from librosa.util import pad_center, tiny
|
| 14 |
+
import librosa.util as librosa_util
|
| 15 |
+
|
| 16 |
+
class AttrDict(dict):
|
| 17 |
+
def __init__(self, *args, **kwargs):
|
| 18 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
| 19 |
+
self.__dict__ = self
|
| 20 |
+
|
| 21 |
+
def init_weights(m, mean=0.0, std=0.01):
|
| 22 |
+
classname = m.__class__.__name__
|
| 23 |
+
if classname.find("Conv") != -1:
|
| 24 |
+
m.weight.data.normal_(mean, std)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_padding(kernel_size, dilation=1):
|
| 28 |
+
return int((kernel_size * dilation - dilation) / 2)
|
| 29 |
+
|
| 30 |
+
LRELU_SLOPE = 0.1
|
| 31 |
+
|
| 32 |
+
class ResBlock(torch.nn.Module):
|
| 33 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
| 34 |
+
super(ResBlock, self).__init__()
|
| 35 |
+
self.h = h
|
| 36 |
+
self.convs1 = nn.ModuleList(
|
| 37 |
+
[
|
| 38 |
+
torch.nn.utils.weight_norm(
|
| 39 |
+
nn.Conv1d(
|
| 40 |
+
channels,
|
| 41 |
+
channels,
|
| 42 |
+
kernel_size,
|
| 43 |
+
1,
|
| 44 |
+
dilation=dilation[0],
|
| 45 |
+
padding=get_padding(kernel_size, dilation[0]),
|
| 46 |
+
)
|
| 47 |
+
),
|
| 48 |
+
torch.nn.utils.weight_norm(
|
| 49 |
+
nn.Conv1d(
|
| 50 |
+
channels,
|
| 51 |
+
channels,
|
| 52 |
+
kernel_size,
|
| 53 |
+
1,
|
| 54 |
+
dilation=dilation[1],
|
| 55 |
+
padding=get_padding(kernel_size, dilation[1]),
|
| 56 |
+
)
|
| 57 |
+
),
|
| 58 |
+
torch.nn.utils.weight_norm(
|
| 59 |
+
nn.Conv1d(
|
| 60 |
+
channels,
|
| 61 |
+
channels,
|
| 62 |
+
kernel_size,
|
| 63 |
+
1,
|
| 64 |
+
dilation=dilation[2],
|
| 65 |
+
padding=get_padding(kernel_size, dilation[2]),
|
| 66 |
+
)
|
| 67 |
+
),
|
| 68 |
+
]
|
| 69 |
+
)
|
| 70 |
+
self.convs1.apply(init_weights)
|
| 71 |
+
|
| 72 |
+
self.convs2 = nn.ModuleList(
|
| 73 |
+
[
|
| 74 |
+
torch.nn.utils.weight_norm(
|
| 75 |
+
nn.Conv1d(
|
| 76 |
+
channels,
|
| 77 |
+
channels,
|
| 78 |
+
kernel_size,
|
| 79 |
+
1,
|
| 80 |
+
dilation=1,
|
| 81 |
+
padding=get_padding(kernel_size, 1),
|
| 82 |
+
)
|
| 83 |
+
),
|
| 84 |
+
torch.nn.utils.weight_norm(
|
| 85 |
+
nn.Conv1d(
|
| 86 |
+
channels,
|
| 87 |
+
channels,
|
| 88 |
+
kernel_size,
|
| 89 |
+
1,
|
| 90 |
+
dilation=1,
|
| 91 |
+
padding=get_padding(kernel_size, 1),
|
| 92 |
+
)
|
| 93 |
+
),
|
| 94 |
+
torch.nn.utils.weight_norm(
|
| 95 |
+
nn.Conv1d(
|
| 96 |
+
channels,
|
| 97 |
+
channels,
|
| 98 |
+
kernel_size,
|
| 99 |
+
1,
|
| 100 |
+
dilation=1,
|
| 101 |
+
padding=get_padding(kernel_size, 1),
|
| 102 |
+
)
|
| 103 |
+
),
|
| 104 |
+
]
|
| 105 |
+
)
|
| 106 |
+
self.convs2.apply(init_weights)
|
| 107 |
+
|
| 108 |
+
def forward(self, x):
|
| 109 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
| 110 |
+
xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
|
| 111 |
+
xt = c1(xt)
|
| 112 |
+
xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
|
| 113 |
+
xt = c2(xt)
|
| 114 |
+
x = xt + x
|
| 115 |
+
return x
|
| 116 |
+
|
| 117 |
+
def remove_weight_norm(self):
|
| 118 |
+
for l in self.convs1:
|
| 119 |
+
torch.nn.utils.remove_weight_norm(l)
|
| 120 |
+
for l in self.convs2:
|
| 121 |
+
torch.nn.utils.remove_weight_norm(l)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class Generator_old(torch.nn.Module):
|
| 125 |
+
def __init__(self, h):
|
| 126 |
+
super(Generator_old, self).__init__()
|
| 127 |
+
self.h = h
|
| 128 |
+
self.num_kernels = len(h.resblock_kernel_sizes)
|
| 129 |
+
self.num_upsamples = len(h.upsample_rates)
|
| 130 |
+
self.conv_pre = torch.nn.utils.weight_norm(
|
| 131 |
+
nn.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
| 132 |
+
)
|
| 133 |
+
resblock = ResBlock
|
| 134 |
+
|
| 135 |
+
self.ups = nn.ModuleList()
|
| 136 |
+
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
| 137 |
+
self.ups.append(
|
| 138 |
+
torch.nn.utils.weight_norm(
|
| 139 |
+
nn.ConvTranspose1d(
|
| 140 |
+
h.upsample_initial_channel // (2**i),
|
| 141 |
+
h.upsample_initial_channel // (2 ** (i + 1)),
|
| 142 |
+
k,
|
| 143 |
+
u,
|
| 144 |
+
padding=(k - u) // 2,
|
| 145 |
+
)
|
| 146 |
+
)
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
self.resblocks = nn.ModuleList()
|
| 150 |
+
for i in range(len(self.ups)):
|
| 151 |
+
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
| 152 |
+
for j, (k, d) in enumerate(
|
| 153 |
+
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
|
| 154 |
+
):
|
| 155 |
+
self.resblocks.append(resblock(h, ch, k, d))
|
| 156 |
+
|
| 157 |
+
self.conv_post = torch.nn.utils.weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
|
| 158 |
+
self.ups.apply(init_weights)
|
| 159 |
+
self.conv_post.apply(init_weights)
|
| 160 |
+
|
| 161 |
+
def forward(self, x):
|
| 162 |
+
x = self.conv_pre(x)
|
| 163 |
+
for i in range(self.num_upsamples):
|
| 164 |
+
x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
|
| 165 |
+
x = self.ups[i](x)
|
| 166 |
+
xs = None
|
| 167 |
+
for j in range(self.num_kernels):
|
| 168 |
+
if xs is None:
|
| 169 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
| 170 |
+
else:
|
| 171 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
| 172 |
+
x = xs / self.num_kernels
|
| 173 |
+
x = torch.nn.functional.leaky_relu(x)
|
| 174 |
+
x = self.conv_post(x)
|
| 175 |
+
x = torch.tanh(x)
|
| 176 |
+
|
| 177 |
+
return x
|
| 178 |
+
|
| 179 |
+
def remove_weight_norm(self):
|
| 180 |
+
# print("Removing weight norm...")
|
| 181 |
+
for l in self.ups:
|
| 182 |
+
torch.nn.utils.remove_weight_norm(l)
|
| 183 |
+
for l in self.resblocks:
|
| 184 |
+
l.remove_weight_norm()
|
| 185 |
+
torch.nn.utils.remove_weight_norm(self.conv_pre)
|
| 186 |
+
torch.nn.utils.remove_weight_norm(self.conv_post)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def nonlinearity(x):
|
| 191 |
+
# swish
|
| 192 |
+
return x * torch.sigmoid(x)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def Normalize(in_channels, num_groups=32):
|
| 196 |
+
return torch.nn.GroupNorm(
|
| 197 |
+
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
class Downsample(nn.Module):
|
| 201 |
+
def __init__(self, in_channels, with_conv):
|
| 202 |
+
super().__init__()
|
| 203 |
+
self.with_conv = with_conv
|
| 204 |
+
if self.with_conv:
|
| 205 |
+
# Do time downsampling here
|
| 206 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 207 |
+
self.conv = torch.nn.Conv2d(
|
| 208 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
def forward(self, x):
|
| 212 |
+
if self.with_conv:
|
| 213 |
+
pad = (0, 1, 0, 1)
|
| 214 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
| 215 |
+
x = self.conv(x)
|
| 216 |
+
else:
|
| 217 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
| 218 |
+
return x
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class DownsampleTimeStride4(nn.Module):
|
| 222 |
+
def __init__(self, in_channels, with_conv):
|
| 223 |
+
super().__init__()
|
| 224 |
+
self.with_conv = with_conv
|
| 225 |
+
if self.with_conv:
|
| 226 |
+
# Do time downsampling here
|
| 227 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 228 |
+
self.conv = torch.nn.Conv2d(
|
| 229 |
+
in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
def forward(self, x):
|
| 233 |
+
if self.with_conv:
|
| 234 |
+
pad = (0, 1, 0, 1)
|
| 235 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
| 236 |
+
x = self.conv(x)
|
| 237 |
+
else:
|
| 238 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
|
| 239 |
+
return x
|
| 240 |
+
|
| 241 |
+
class Upsample(nn.Module):
|
| 242 |
+
def __init__(self, in_channels, with_conv):
|
| 243 |
+
super().__init__()
|
| 244 |
+
self.with_conv = with_conv
|
| 245 |
+
if self.with_conv:
|
| 246 |
+
self.conv = torch.nn.Conv2d(
|
| 247 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
def forward(self, x):
|
| 251 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 252 |
+
if self.with_conv:
|
| 253 |
+
x = self.conv(x)
|
| 254 |
+
return x
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class UpsampleTimeStride4(nn.Module):
|
| 258 |
+
def __init__(self, in_channels, with_conv):
|
| 259 |
+
super().__init__()
|
| 260 |
+
self.with_conv = with_conv
|
| 261 |
+
if self.with_conv:
|
| 262 |
+
self.conv = torch.nn.Conv2d(
|
| 263 |
+
in_channels, in_channels, kernel_size=5, stride=1, padding=2
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
def forward(self, x):
|
| 267 |
+
x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
|
| 268 |
+
if self.with_conv:
|
| 269 |
+
x = self.conv(x)
|
| 270 |
+
return x
|
| 271 |
+
|
| 272 |
+
class AttnBlock(nn.Module):
    """Single-head self-attention over the h*w spatial positions of a 2D map.

    q/k/v and the output projection are all 1x1 convolutions; the whole block
    is wrapped in a residual connection (returns x + attention(x)).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        # Normalize is defined elsewhere in this file (GroupNorm-style, presumably).
        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        # Residual branch: normalise, then project to query/key/value.
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        # Flatten spatial dims so attention runs over hw positions.
        q = q.reshape(b, c, h * w).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # b,hw,c
        k = k.reshape(b, c, h * w).contiguous()  # b,c,hw
        w_ = torch.bmm(q, k).contiguous()  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        # Scale by 1/sqrt(c) before the softmax (standard dot-product scaling).
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w).contiguous()
        # Transpose so bmm contracts over the key axis of w_.
        w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(
            v, w_
        ).contiguous()  # b, c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w).contiguous()

        h_ = self.proj_out(h_)

        # Residual connection around the whole attention computation.
        return x + h_
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def make_attn(in_channels, attn_type="vanilla"):
    """Factory for an attention layer of the requested type.

    "vanilla" -> AttnBlock; "none" -> identity pass-through.
    "linear" survives the assert but is not implemented here, so it raises.
    """
    assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    if attn_type == "none":
        # nn.Identity ignores its constructor arguments.
        return nn.Identity(in_channels)
    raise ValueError(attn_type)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class ResnetBlock(nn.Module):
    """Residual block: (norm -> nonlinearity -> conv) twice, with optional
    timestep-embedding injection and a learned shortcut when channels change.

    Keyword-only constructor; ``dropout`` is required, ``out_channels`` defaults
    to ``in_channels``.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        # Default to a channel-preserving block.
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # Normalize / nonlinearity are module-level helpers defined earlier in this file.
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            # Projects the timestep embedding into this block's channel space.
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            # Shortcut path must match channel count: either a 3x3 conv or a 1x1 conv.
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        """Apply the block. ``temb`` may be None to skip embedding injection
        (the encoder/decoder here pass None; temb_ch is 0)."""
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # Broadcast the projected embedding over the spatial dimensions.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        # Residual sum of (possibly projected) input and transformed branch.
        return x + h
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class Encoder(nn.Module):
    """Convolutional VAE encoder: conv_in -> per-level residual/attn stacks with
    downsampling -> middle (res/attn/res) -> norm/nonlinearity/conv_out.

    Levels listed in ``downsample_time_stride4_levels`` use the anisotropic
    stride-(4,2) downsampler instead of the standard stride-2 one.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        downsample_time_stride4_levels=[],  # NOTE(review): mutable default, shared across instances; safe only because it is never mutated here.
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        # No timestep embedding is used by this encoder.
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.downsample_time_stride4_levels = downsample_time_stride4_levels

        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
                % str(self.num_resolutions)
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        # Channel multiplier for each level's *input*: level i reads ch*in_ch_mult[i]
        # and writes ch*ch_mult[i].
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                # Attach attention at resolutions listed in attn_resolutions.
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            # Every level except the last ends in a downsampler.
            if i_level != self.num_resolutions - 1:
                if i_level in self.downsample_time_stride4_levels:
                    down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    down.downsample = Downsample(block_in, resamp_with_conv)
                # NOTE(review): curr_res only tracks the factor-2 axis, even for
                # the stride-(4,2) levels — confirm attn_resolutions accounts for this.
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        # double_z doubles the output channels so the head can emit mean+logvar.
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Encode ``x`` to (2*)z_channels feature maps (moments when double_z)."""
        # timestep embedding
        temb = None
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
class Decoder(nn.Module):
    """Convolutional VAE decoder: conv_in -> middle (res/attn/res) -> per-level
    residual/attn stacks with upsampling -> norm/nonlinearity/conv_out.

    Mirrors Encoder; levels whose *lower* neighbour is listed in
    ``downsample_time_stride4_levels`` use the (4,2) upsampler.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        downsample_time_stride4_levels=[],  # NOTE(review): mutable default; never mutated here.
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        # No timestep embedding is used by this decoder.
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # give_pre_end: return features before the final norm/act/conv head.
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.downsample_time_stride4_levels = downsample_time_stride4_levels

        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
                % str(self.num_resolutions)
            )

        # compute in_ch_mult, block_in and curr_res at lowest res
        # NOTE(review): the next expression's result is discarded — vestigial
        # leftover from the Encoder's in_ch_mult computation.
        (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        # print(
        #     "Working with z of shape {} = {} dimensions.".format(
        #         self.z_shape, np.prod(self.z_shape)
        #     )
        # )

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling (built from the lowest resolution up)
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            # One extra residual block per level compared to the encoder.
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                # Undo the downsampling applied at level i_level-1 in the encoder.
                if i_level - 1 in self.downsample_time_stride4_levels:
                    up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        """Decode latent ``z`` to an out_ch-channel map (tanh-squashed if tanh_out)."""
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over latents, parameterised by concatenated moments.

    ``parameters`` has shape (B, 2*C, ...): the first half along dim=1 is the
    mean, the second half the log-variance.  With ``deterministic=True`` the
    distribution collapses to a point mass at the mean (std = var = 0).

    Fixes over the original:
    - ``nll`` default ``dims`` is a tuple, not a shared mutable list;
    - degenerate kl()/nll() return a zero tensor on the parameters' device
      instead of an unconditional CPU ``torch.Tensor([0.0])``.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp logvar so exp() below stays finite.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Point mass: zero noise regardless of the predicted logvar.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Reparameterised sample: mean + std * eps, with eps ~ N(0, I)."""
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x

    def kl(self, other=None):
        """KL divergence to ``other`` (or to the standard normal when None).

        Note: *averaged* over the non-batch dims [1, 2, 3] (not summed),
        yielding one value per batch element.
        """
        if self.deterministic:
            # Degenerate distribution; reported as 0 by convention.
            return torch.zeros(1, device=self.parameters.device)
        if other is None:
            return 0.5 * torch.mean(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3],
            )
        return 0.5 * torch.mean(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2, 3],
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample``, summed over ``dims``."""
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        """Distribution mode (= mean for a Gaussian)."""
        return self.mean
|
| 710 |
+
|
| 711 |
+
def get_vocoder_config_48k():
    """Return the HiFi-GAN configuration dict used for 48 kHz mel vocoding.

    Identical content to a single literal dict; grouped here by concern
    (generator architecture, audio/feature extraction, training/runtime).
    """
    generator = {
        "resblock": "1",
        "upsample_rates": [6, 5, 4, 2, 2],
        "upsample_kernel_sizes": [12, 10, 8, 4, 4],
        "upsample_initial_channel": 1536,
        "resblock_kernel_sizes": [3, 7, 11, 15],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
    }
    audio = {
        "segment_size": 15360,
        "num_mels": 256,
        "n_fft": 2048,
        "hop_size": 480,
        "win_size": 2048,
        "sampling_rate": 48000,
        "fmin": 20,
        "fmax": 24000,
        "fmax_for_loss": None,
    }
    training = {
        "num_gpus": 8,
        "batch_size": 128,
        "learning_rate": 0.0001,
        "adam_b1": 0.8,
        "adam_b2": 0.99,
        "lr_decay": 0.999,
        "seed": 1234,
        "num_workers": 8,
        "dist_config": {
            "dist_backend": "nccl",
            "dist_url": "tcp://localhost:18273",
            "world_size": 1,
        },
    }
    return {**generator, **audio, **training}
|
| 748 |
+
|
| 749 |
+
def get_vocoder(config, device, mel_bins):
    """Build and return a mel-to-waveform vocoder on ``device``.

    ``name`` is hard-coded to "HiFi-GAN", so the MelGAN branch below is dead
    code as written; ``config`` is ignored and rebuilt from
    get_vocoder_config_48k() when mel_bins == 256.
    """
    name = "HiFi-GAN"
    speaker = ""
    # Dead branch: name is never "MelGAN". Also note that if it were taken with
    # an unknown speaker, ``vocoder`` would be unbound below.
    if name == "MelGAN":
        if speaker == "LJSpeech":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
            )
        elif speaker == "universal":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
            )
        vocoder.mel2wav.eval()
        vocoder.mel2wav.to(device)
    elif name == "HiFi-GAN":
        if(mel_bins == 256):
            config = get_vocoder_config_48k()
            # AttrDict / Generator_old come from the bundled HiFi-GAN code.
            config = AttrDict(config)
            vocoder = Generator_old(config)
            # print("Load hifigan/g_01080000")
            # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
            # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
            # ckpt = torch_version_orig_mod_remove(ckpt)
            # vocoder.load_state_dict(ckpt["generator"])
            # NOTE(review): checkpoint loading is commented out, so the returned
            # generator is randomly initialised here — presumably weights are
            # loaded by the caller; confirm.
            vocoder.eval()
            vocoder.remove_weight_norm()
            vocoder.to(device)
        else:
            raise ValueError(mel_bins)
    return vocoder
|
| 779 |
+
|
| 780 |
+
def vocoder_infer(mels, vocoder, lengths=None):
    """Run ``vocoder`` on a batch of mels and return float numpy waveforms.

    Assumes the vocoder output carries a singleton channel dim at position 1
    (it is squeezed away) — TODO confirm against the generator.
    ``lengths``, when given, truncates the output to ``wavs[:, :lengths]``.
    """
    with torch.no_grad():
        waveforms = vocoder(mels).squeeze(1)

    # Kept as float; the int16 conversion of the original pipeline was dropped.
    waveforms = waveforms.cpu().numpy()

    return waveforms if lengths is None else waveforms[:, :lengths]
|
| 797 |
+
|
| 798 |
+
@torch.no_grad()
def vocoder_chunk_infer(mels, vocoder, lengths=None):
    """Vocode long mel inputs chunk-by-chunk with linear cross-fade overlap-add.

    The mel time axis (dim 2) is processed in windows of ``chunk_size`` frames
    advanced by ``shift_size``; the corresponding waveform overlap is blended
    with complementary linear ramps to hide chunk seams.  Returns a numpy
    array; ``lengths`` (if given) truncates to ``wavs[:, :lengths]``.
    """
    chunk_size = 256*4
    shift_size = 256*1
    # Overlap between consecutive chunks, in mel frames.
    ov_size = chunk_size-shift_size
    # import pdb;pdb.set_trace()

    # NOTE(review): if mels.shape[2] == 0 the loop never runs and ``wavs`` below
    # is unbound — callers presumably never pass empty input; confirm.
    for cinx in range(0, mels.shape[2], shift_size):
        if(cinx==0):
            # First chunk establishes the per-chunk sample count and the
            # cross-fade window (ascending ramp followed by descending ramp).
            wavs = vocoder(mels[:,:,cinx:cinx+chunk_size]).squeeze(1).cpu()
            num_samples = int(wavs.shape[-1]/chunk_size)*chunk_size
            wavs = wavs[:,0:num_samples]
            # Overlap length translated from mel frames to waveform samples.
            ov_sample = int(float(wavs.shape[-1]) * ov_size / chunk_size)
            ov_win = torch.from_numpy(np.linspace(0,1,ov_sample)[None,:])
            ov_win = torch.cat([ov_win,1-ov_win],-1)
            if(cinx+chunk_size>=mels.shape[2]):
                break
        else:
            cur_wav = vocoder(mels[:,:,cinx:cinx+chunk_size]).squeeze(1).cpu()[:,0:num_samples]
            # Cross-fade: fade the existing tail out (descending half of ov_win)
            # while fading the new chunk's head in (ascending half).
            wavs[:,-ov_sample:] = wavs[:,-ov_sample:] * ov_win[:,-ov_sample:] + cur_wav[:,0:ov_sample] * ov_win[:,0:ov_sample]
            # wavs[:,-ov_sample:] = wavs[:,-ov_sample:] * 1.0 + cur_wav[:,0:ov_sample] * 0.0
            # Append the non-overlapping remainder of the new chunk.
            wavs = torch.cat([wavs, cur_wav[:,ov_sample:]],-1)
            if(cinx+chunk_size>=mels.shape[2]):
                break
    # print(wavs.shape)

    wavs = (wavs.cpu().numpy())

    if lengths is not None:
        wavs = wavs[:, :lengths]
    # print(wavs.shape)
    return wavs
|
| 830 |
+
|
| 831 |
+
def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
    """Vocode a ground-truth mel and a predicted mel into waveforms.

    ``labels`` is accepted for interface compatibility but unused here.
    Returns (wav_reconstruction, wav_prediction); both are None when no
    vocoder is supplied.
    """
    if vocoder is None:
        return None, None

    # Mels arrive as (batch, time, mel_bins); the vocoder expects
    # (batch, mel_bins, time), hence the permute.
    wav_reconstruction = vocoder_infer(mel_input.permute(0, 2, 1), vocoder)
    wav_prediction = vocoder_infer(mel_prediction.permute(0, 2, 1), vocoder)
    return wav_reconstruction, wav_prediction
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
class AutoencoderKL(nn.Module):
|
| 849 |
+
def __init__(
|
| 850 |
+
self,
|
| 851 |
+
ddconfig=None,
|
| 852 |
+
lossconfig=None,
|
| 853 |
+
batchsize=None,
|
| 854 |
+
embed_dim=None,
|
| 855 |
+
time_shuffle=1,
|
| 856 |
+
subband=1,
|
| 857 |
+
sampling_rate=16000,
|
| 858 |
+
ckpt_path=None,
|
| 859 |
+
reload_from_ckpt=None,
|
| 860 |
+
ignore_keys=[],
|
| 861 |
+
image_key="fbank",
|
| 862 |
+
colorize_nlabels=None,
|
| 863 |
+
monitor=None,
|
| 864 |
+
base_learning_rate=1e-5,
|
| 865 |
+
scale_factor=1
|
| 866 |
+
):
|
| 867 |
+
super().__init__()
|
| 868 |
+
self.automatic_optimization = False
|
| 869 |
+
assert (
|
| 870 |
+
"mel_bins" in ddconfig.keys()
|
| 871 |
+
), "mel_bins is not specified in the Autoencoder config"
|
| 872 |
+
num_mel = ddconfig["mel_bins"]
|
| 873 |
+
self.image_key = image_key
|
| 874 |
+
self.sampling_rate = sampling_rate
|
| 875 |
+
self.encoder = Encoder(**ddconfig)
|
| 876 |
+
self.decoder = Decoder(**ddconfig)
|
| 877 |
+
|
| 878 |
+
self.loss = None
|
| 879 |
+
self.subband = int(subband)
|
| 880 |
+
|
| 881 |
+
if self.subband > 1:
|
| 882 |
+
print("Use subband decomposition %s" % self.subband)
|
| 883 |
+
|
| 884 |
+
assert ddconfig["double_z"]
|
| 885 |
+
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
|
| 886 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
| 887 |
+
|
| 888 |
+
if self.image_key == "fbank":
|
| 889 |
+
self.vocoder = get_vocoder(None, "cpu", num_mel)
|
| 890 |
+
self.embed_dim = embed_dim
|
| 891 |
+
if colorize_nlabels is not None:
|
| 892 |
+
assert type(colorize_nlabels) == int
|
| 893 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
| 894 |
+
if monitor is not None:
|
| 895 |
+
self.monitor = monitor
|
| 896 |
+
if ckpt_path is not None:
|
| 897 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
| 898 |
+
self.learning_rate = float(base_learning_rate)
|
| 899 |
+
# print("Initial learning rate %s" % self.learning_rate)
|
| 900 |
+
|
| 901 |
+
self.time_shuffle = time_shuffle
|
| 902 |
+
self.reload_from_ckpt = reload_from_ckpt
|
| 903 |
+
self.reloaded = False
|
| 904 |
+
self.mean, self.std = None, None
|
| 905 |
+
|
| 906 |
+
self.feature_cache = None
|
| 907 |
+
self.flag_first_run = True
|
| 908 |
+
self.train_step = 0
|
| 909 |
+
|
| 910 |
+
self.logger_save_dir = None
|
| 911 |
+
self.logger_exp_name = None
|
| 912 |
+
self.scale_factor = scale_factor
|
| 913 |
+
|
| 914 |
+
print("Num parameters:")
|
| 915 |
+
print("Encoder : ", sum(p.numel() for p in self.encoder.parameters()))
|
| 916 |
+
print("Decoder : ", sum(p.numel() for p in self.decoder.parameters()))
|
| 917 |
+
print("Vocoder : ", sum(p.numel() for p in self.vocoder.parameters()))
|
| 918 |
+
|
| 919 |
+
def get_log_dir(self):
|
| 920 |
+
if self.logger_save_dir is None and self.logger_exp_name is None:
|
| 921 |
+
return os.path.join(self.logger.save_dir, self.logger._project)
|
| 922 |
+
else:
|
| 923 |
+
return os.path.join(self.logger_save_dir, self.logger_exp_name)
|
| 924 |
+
|
| 925 |
+
def set_log_dir(self, save_dir, exp_name):
|
| 926 |
+
self.logger_save_dir = save_dir
|
| 927 |
+
self.logger_exp_name = exp_name
|
| 928 |
+
|
| 929 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
| 930 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
| 931 |
+
keys = list(sd.keys())
|
| 932 |
+
for k in keys:
|
| 933 |
+
for ik in ignore_keys:
|
| 934 |
+
if k.startswith(ik):
|
| 935 |
+
print("Deleting key {} from state_dict.".format(k))
|
| 936 |
+
del sd[k]
|
| 937 |
+
self.load_state_dict(sd, strict=False)
|
| 938 |
+
print(f"Restored from {path}")
|
| 939 |
+
|
| 940 |
+
def encode(self, x):
|
| 941 |
+
# x = self.time_shuffle_operation(x)
|
| 942 |
+
# x = self.freq_split_subband(x)
|
| 943 |
+
h = self.encoder(x)
|
| 944 |
+
moments = self.quant_conv(h)
|
| 945 |
+
posterior = DiagonalGaussianDistribution(moments)
|
| 946 |
+
return posterior
|
| 947 |
+
|
| 948 |
+
def decode(self, z):
|
| 949 |
+
z = self.post_quant_conv(z)
|
| 950 |
+
dec = self.decoder(z)
|
| 951 |
+
# bs, ch, shuffled_timesteps, fbins = dec.size()
|
| 952 |
+
# dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins)
|
| 953 |
+
# dec = self.freq_merge_subband(dec)
|
| 954 |
+
return dec
|
| 955 |
+
|
| 956 |
+
def decode_to_waveform(self, dec):
|
| 957 |
+
|
| 958 |
+
if self.image_key == "fbank":
|
| 959 |
+
dec = dec.squeeze(1).permute(0, 2, 1)
|
| 960 |
+
wav_reconstruction = vocoder_chunk_infer(dec, self.vocoder)
|
| 961 |
+
elif self.image_key == "stft":
|
| 962 |
+
dec = dec.squeeze(1).permute(0, 2, 1)
|
| 963 |
+
wav_reconstruction = self.wave_decoder(dec)
|
| 964 |
+
return wav_reconstruction
|
| 965 |
+
|
| 966 |
+
def mel_spectrogram_to_waveform(
|
| 967 |
+
self, mel, savepath=".", bs=None, name="outwav", save=True
|
| 968 |
+
):
|
| 969 |
+
# Mel: [bs, 1, t-steps, fbins]
|
| 970 |
+
if len(mel.size()) == 4:
|
| 971 |
+
mel = mel.squeeze(1)
|
| 972 |
+
mel = mel.permute(0, 2, 1)
|
| 973 |
+
waveform = self.vocoder(mel)
|
| 974 |
+
waveform = waveform.cpu().detach().numpy()
|
| 975 |
+
#if save:
|
| 976 |
+
# self.save_waveform(waveform, savepath, name)
|
| 977 |
+
return waveform
|
| 978 |
+
|
| 979 |
+
@torch.no_grad()
|
| 980 |
+
def encode_first_stage(self, x):
|
| 981 |
+
return self.encode(x)
|
| 982 |
+
|
| 983 |
+
@torch.no_grad()
|
| 984 |
+
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
| 985 |
+
if predict_cids:
|
| 986 |
+
if z.dim() == 4:
|
| 987 |
+
z = torch.argmax(z.exp(), dim=1).long()
|
| 988 |
+
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
| 989 |
+
z = rearrange(z, "b h w c -> b c h w").contiguous()
|
| 990 |
+
|
| 991 |
+
z = 1.0 / self.scale_factor * z
|
| 992 |
+
return self.decode(z)
|
| 993 |
+
|
| 994 |
+
def decode_first_stage_withgrad(self, z):
|
| 995 |
+
z = 1.0 / self.scale_factor * z
|
| 996 |
+
return self.decode(z)
|
| 997 |
+
|
| 998 |
+
def get_first_stage_encoding(self, encoder_posterior, use_mode=False):
|
| 999 |
+
if isinstance(encoder_posterior, DiagonalGaussianDistribution) and not use_mode:
|
| 1000 |
+
z = encoder_posterior.sample()
|
| 1001 |
+
elif isinstance(encoder_posterior, DiagonalGaussianDistribution) and use_mode:
|
| 1002 |
+
z = encoder_posterior.mode()
|
| 1003 |
+
elif isinstance(encoder_posterior, torch.Tensor):
|
| 1004 |
+
z = encoder_posterior
|
| 1005 |
+
else:
|
| 1006 |
+
raise NotImplementedError(
|
| 1007 |
+
f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
|
| 1008 |
+
)
|
| 1009 |
+
return self.scale_factor * z
|
| 1010 |
+
|
| 1011 |
+
def visualize_latent(self, input):
|
| 1012 |
+
import matplotlib.pyplot as plt
|
| 1013 |
+
|
| 1014 |
+
# for i in range(10):
|
| 1015 |
+
# zero_input = torch.zeros_like(input) - 11.59
|
| 1016 |
+
# zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59
|
| 1017 |
+
|
| 1018 |
+
# posterior = self.encode(zero_input)
|
| 1019 |
+
# latent = posterior.sample()
|
| 1020 |
+
# avg_latent = torch.mean(latent, dim=1)[0]
|
| 1021 |
+
# plt.imshow(avg_latent.cpu().detach().numpy().T)
|
| 1022 |
+
# plt.savefig("%s.png" % i)
|
| 1023 |
+
# plt.close()
|
| 1024 |
+
|
| 1025 |
+
np.save("input.npy", input.cpu().detach().numpy())
|
| 1026 |
+
# zero_input = torch.zeros_like(input) - 11.59
|
| 1027 |
+
time_input = input.clone()
|
| 1028 |
+
time_input[:, :, :, :32] *= 0
|
| 1029 |
+
time_input[:, :, :, :32] -= 11.59
|
| 1030 |
+
|
| 1031 |
+
np.save("time_input.npy", time_input.cpu().detach().numpy())
|
| 1032 |
+
|
| 1033 |
+
posterior = self.encode(time_input)
|
| 1034 |
+
latent = posterior.sample()
|
| 1035 |
+
np.save("time_latent.npy", latent.cpu().detach().numpy())
|
| 1036 |
+
avg_latent = torch.mean(latent, dim=1)
|
| 1037 |
+
for i in range(avg_latent.size(0)):
|
| 1038 |
+
plt.imshow(avg_latent[i].cpu().detach().numpy().T)
|
| 1039 |
+
plt.savefig("freq_%s.png" % i)
|
| 1040 |
+
plt.close()
|
| 1041 |
+
|
| 1042 |
+
freq_input = input.clone()
|
| 1043 |
+
freq_input[:, :, :512, :] *= 0
|
| 1044 |
+
freq_input[:, :, :512, :] -= 11.59
|
| 1045 |
+
|
| 1046 |
+
np.save("freq_input.npy", freq_input.cpu().detach().numpy())
|
| 1047 |
+
|
| 1048 |
+
posterior = self.encode(freq_input)
|
| 1049 |
+
latent = posterior.sample()
|
| 1050 |
+
np.save("freq_latent.npy", latent.cpu().detach().numpy())
|
| 1051 |
+
avg_latent = torch.mean(latent, dim=1)
|
| 1052 |
+
for i in range(avg_latent.size(0)):
|
| 1053 |
+
plt.imshow(avg_latent[i].cpu().detach().numpy().T)
|
| 1054 |
+
plt.savefig("time_%s.png" % i)
|
| 1055 |
+
plt.close()
|
| 1056 |
+
|
| 1057 |
+
def get_input(self, batch):
|
| 1058 |
+
fname, text, label_indices, waveform, stft, fbank = (
|
| 1059 |
+
batch["fname"],
|
| 1060 |
+
batch["text"],
|
| 1061 |
+
batch["label_vector"],
|
| 1062 |
+
batch["waveform"],
|
| 1063 |
+
batch["stft"],
|
| 1064 |
+
batch["log_mel_spec"],
|
| 1065 |
+
)
|
| 1066 |
+
# if(self.time_shuffle != 1):
|
| 1067 |
+
# if(fbank.size(1) % self.time_shuffle != 0):
|
| 1068 |
+
# pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle)
|
| 1069 |
+
# fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len))
|
| 1070 |
+
|
| 1071 |
+
ret = {}
|
| 1072 |
+
|
| 1073 |
+
ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
|
| 1074 |
+
fbank.unsqueeze(1),
|
| 1075 |
+
stft.unsqueeze(1),
|
| 1076 |
+
fname,
|
| 1077 |
+
waveform.unsqueeze(1),
|
| 1078 |
+
)
|
| 1079 |
+
|
| 1080 |
+
return ret
|
| 1081 |
+
|
| 1082 |
+
def save_wave(self, batch_wav, fname, save_dir):
|
| 1083 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 1084 |
+
|
| 1085 |
+
for wav, name in zip(batch_wav, fname):
|
| 1086 |
+
name = os.path.basename(name)
|
| 1087 |
+
|
| 1088 |
+
sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)
|
| 1089 |
+
|
| 1090 |
+
def get_last_layer(self):
|
| 1091 |
+
return self.decoder.conv_out.weight
|
| 1092 |
+
|
| 1093 |
+
@torch.no_grad()
|
| 1094 |
+
def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
|
| 1095 |
+
log = dict()
|
| 1096 |
+
x = batch.to(self.device)
|
| 1097 |
+
if not only_inputs:
|
| 1098 |
+
xrec, posterior = self(x)
|
| 1099 |
+
log["samples"] = self.decode(posterior.sample())
|
| 1100 |
+
log["reconstructions"] = xrec
|
| 1101 |
+
|
| 1102 |
+
log["inputs"] = x
|
| 1103 |
+
wavs = self._log_img(log, train=train, index=0, waveform=waveform)
|
| 1104 |
+
return wavs
|
| 1105 |
+
|
| 1106 |
+
def _log_img(self, log, train=True, index=0, waveform=None):
|
| 1107 |
+
images_input = self.tensor2numpy(log["inputs"][index, 0]).T
|
| 1108 |
+
images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
|
| 1109 |
+
images_samples = self.tensor2numpy(log["samples"][index, 0]).T
|
| 1110 |
+
|
| 1111 |
+
if train:
|
| 1112 |
+
name = "train"
|
| 1113 |
+
else:
|
| 1114 |
+
name = "val"
|
| 1115 |
+
|
| 1116 |
+
if self.logger is not None:
|
| 1117 |
+
self.logger.log_image(
|
| 1118 |
+
"img_%s" % name,
|
| 1119 |
+
[images_input, images_reconstruct, images_samples],
|
| 1120 |
+
caption=["input", "reconstruct", "samples"],
|
| 1121 |
+
)
|
| 1122 |
+
|
| 1123 |
+
inputs, reconstructions, samples = (
|
| 1124 |
+
log["inputs"],
|
| 1125 |
+
log["reconstructions"],
|
| 1126 |
+
log["samples"],
|
| 1127 |
+
)
|
| 1128 |
+
|
| 1129 |
+
if self.image_key == "fbank":
|
| 1130 |
+
wav_original, wav_prediction = synth_one_sample(
|
| 1131 |
+
inputs[index],
|
| 1132 |
+
reconstructions[index],
|
| 1133 |
+
labels="validation",
|
| 1134 |
+
vocoder=self.vocoder,
|
| 1135 |
+
)
|
| 1136 |
+
wav_original, wav_samples = synth_one_sample(
|
| 1137 |
+
inputs[index], samples[index], labels="validation", vocoder=self.vocoder
|
| 1138 |
+
)
|
| 1139 |
+
wav_original, wav_samples, wav_prediction = (
|
| 1140 |
+
wav_original[0],
|
| 1141 |
+
wav_samples[0],
|
| 1142 |
+
wav_prediction[0],
|
| 1143 |
+
)
|
| 1144 |
+
elif self.image_key == "stft":
|
| 1145 |
+
wav_prediction = (
|
| 1146 |
+
self.decode_to_waveform(reconstructions)[index, 0]
|
| 1147 |
+
.cpu()
|
| 1148 |
+
.detach()
|
| 1149 |
+
.numpy()
|
| 1150 |
+
)
|
| 1151 |
+
wav_samples = (
|
| 1152 |
+
self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
|
| 1153 |
+
)
|
| 1154 |
+
wav_original = waveform[index, 0].cpu().detach().numpy()
|
| 1155 |
+
|
| 1156 |
+
if self.logger is not None:
|
| 1157 |
+
self.logger.experiment.log(
|
| 1158 |
+
{
|
| 1159 |
+
"original_%s"
|
| 1160 |
+
% name: wandb.Audio(
|
| 1161 |
+
wav_original, caption="original", sample_rate=self.sampling_rate
|
| 1162 |
+
),
|
| 1163 |
+
"reconstruct_%s"
|
| 1164 |
+
% name: wandb.Audio(
|
| 1165 |
+
wav_prediction,
|
| 1166 |
+
caption="reconstruct",
|
| 1167 |
+
sample_rate=self.sampling_rate,
|
| 1168 |
+
),
|
| 1169 |
+
"samples_%s"
|
| 1170 |
+
% name: wandb.Audio(
|
| 1171 |
+
wav_samples, caption="samples", sample_rate=self.sampling_rate
|
| 1172 |
+
),
|
| 1173 |
+
}
|
| 1174 |
+
)
|
| 1175 |
+
|
| 1176 |
+
return wav_original, wav_prediction, wav_samples
|
| 1177 |
+
|
| 1178 |
+
def tensor2numpy(self, tensor):
    """Detach `tensor` from the autograd graph and return it as a CPU NumPy array."""
    return tensor.detach().cpu().numpy()
|
| 1180 |
+
|
| 1181 |
+
def to_rgb(self, x):
    """Project a segmentation map to 3 channels via a fixed random 1x1 conv,
    then rescale the result to [-1, 1]. Only valid when image_key == "segmentation".
    """
    assert self.image_key == "segmentation"
    if not hasattr(self, "colorize"):
        # Lazily create the projection weights on first use, matching x's device/dtype.
        self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
    x = torch.nn.functional.conv2d(x, weight=self.colorize)
    lo, hi = x.min(), x.max()
    return 2.0 * (x - lo) / (hi - lo) - 1.0
|
| 1188 |
+
|
| 1189 |
+
|
| 1190 |
+
class IdentityFirstStage(torch.nn.Module):
    """No-op first stage: encode/decode/quantize/forward all return the input.

    Drop-in placeholder when no autoencoder first stage is wanted. With
    vq_interface=True, quantize mimics a VQ layer's (quantized, loss, info)
    return tuple so VQ-aware callers keep working.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Initialize the nn.Module machinery before setting any attributes;
        # the original assigned first, which is fragile against nn.Module's
        # custom __setattr__ bookkeeping.
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            # (quantized, emb_loss, (perplexity, min_encodings, min_indices)) shape.
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
|
| 1208 |
+
|
| 1209 |
+
|
| 1210 |
+
def window_sumsquare(
    window,
    n_frames,
    hop_length,
    win_length,
    n_fft,
    dtype=np.float32,
    norm=None,
):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`
    n_frames : int > 0
        The number of analysis frames
    hop_length : int > 0
        The number of samples to advance between frames
    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.
    n_fft : int > 0
        The length of each analysis frame.
    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    # FIX: use the keyword form — librosa >= 0.10 made `size` keyword-only,
    # and this matches the pad_center(..., size=...) call in STFT.__init__.
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)

    # Fill the envelope by overlap-adding the squared window at each hop.
    for i in range(n_frames):
        sample = i * hop_length
        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
    return x
|
| 1267 |
+
|
| 1268 |
+
def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
    """Compress dynamic range: normalize_fun(clamp(x, min=clip_val) * C).

    C is the compression factor; clamping keeps the log well-defined for
    zero/near-zero magnitudes.
    """
    clamped = torch.clamp(x, min=clip_val)
    return normalize_fun(clamped * C)
|
| 1275 |
+
|
| 1276 |
+
|
| 1277 |
+
def dynamic_range_decompression(x, C=1):
    """Invert dynamic_range_compression (for the torch.log case): exp(x) / C.

    C is the compression factor that was used to compress.
    """
    expanded = torch.exp(x)
    return expanded / C
|
| 1284 |
+
|
| 1285 |
+
|
| 1286 |
+
class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""

    def __init__(self, filter_length, hop_length, win_length, window="hann"):
        # filter_length: FFT size; hop_length: frame step in samples;
        # win_length: analysis window length (must be <= filter_length);
        # window: window name passed to get_window, or None to disable windowing.
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        # DFT of the identity matrix yields the full Fourier basis as rows.
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        # Stack real and imaginary parts so a single conv1d produces both.
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        # Scaled pseudo-inverse gives the synthesis basis for the inverse STFT.
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )

        if window is not None:
            assert filter_length >= win_length
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        # Buffers move with the module across devices but are not trained.
        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())

    def transform(self, input_data):
        # Compute (magnitude, phase) of the STFT of a (batch, num_samples) signal.

        device = self.forward_basis.device
        input_data = input_data.to(device)

        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = torch.nn.functional.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode="reflect",
        )
        input_data = input_data.squeeze(1)

        # The strided conv1d against the Fourier basis is the framed DFT.
        forward_transform = torch.nn.functional.conv1d(
            input_data,
            torch.autograd.Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )#.cpu()

        # First `cutoff` output channels are real parts, the rest imaginary.
        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase

    def inverse(self, magnitude, phase):
        # Reconstruct the waveform from (magnitude, phase) via transposed conv.

        device = self.forward_basis.device
        magnitude, phase = magnitude.to(device), phase.to(device)

        # Rebuild real/imag channels from polar form, matching transform()'s layout.
        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )

        inverse_transform = torch.nn.functional.conv_transpose1d(
            recombine_magnitude_phase,
            torch.autograd.Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )

        if self.window is not None:
            # Undo the analysis-window modulation (overlap-add correction).
            window_sum = window_sumsquare(
                self.window,
                magnitude.size(-1),
                hop_length=self.hop_length,
                win_length=self.win_length,
                n_fft=self.filter_length,
                dtype=np.float32,
            )
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]
            )
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False
            )
            window_sum = window_sum
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                approx_nonzero_indices
            ]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        # Trim the half-filter reflect padding added in transform().
        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]

        return inverse_transform

    def forward(self, input_data):
        # NOTE: stores magnitude/phase on self as a side effect before inverting.
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
|
| 1407 |
+
|
| 1408 |
+
|
| 1409 |
+
class TacotronSTFT(torch.nn.Module):
    # Mel-spectrogram frontend: STFT followed by a fixed mel filterbank
    # projection and log dynamic-range compression.
    def __init__(
        self,
        filter_length,
        hop_length,
        win_length,
        n_mel_channels,
        sampling_rate,
        mel_fmin,
        mel_fmax,
    ):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        # Keyword arguments are required by librosa >= 0.10.
        mel_basis = librosa_mel_fn(
            sr = sampling_rate, n_fft = filter_length, n_mels = n_mel_channels, fmin = mel_fmin, fmax = mel_fmax
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)

    def spectral_normalize(self, magnitudes, normalize_fun):
        # Log-compress magnitudes (see dynamic_range_compression).
        output = dynamic_range_compression(magnitudes, normalize_fun)
        return output

    def spectral_de_normalize(self, magnitudes):
        # Inverse of spectral_normalize for the torch.log case.
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y, normalize_fun=torch.log):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert torch.min(y.data) >= -1, torch.min(y.data)
        assert torch.max(y.data) <= 1, torch.max(y.data)

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output, normalize_fun)
        # Per-frame energy: L2 norm over frequency bins.
        energy = torch.norm(magnitudes, dim=1)

        log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)

        return mel_output, log_magnitudes, energy
|
| 1460 |
+
|
| 1461 |
+
|
| 1462 |
+
def build_pretrained_models(ckpt):
    """Load the pretrained first-stage VAE and the mel STFT frontend from `ckpt`.

    Reads the Lightning-style checkpoint, extracts the first-stage weights and
    scale factor, and builds both models with the fixed 48 kHz preprocessing
    configuration. Returns (vae, fn_STFT), both switched to eval mode.
    """
    checkpoint = torch.load(ckpt, map_location="cpu")
    scale_factor = checkpoint["state_dict"]["scale_factor"].item()
    print("scale_factor: ", scale_factor)

    # Keep only first-stage weights, stripping the "first_stage_model." prefix
    # (18 characters) so keys match the standalone VAE's state dict.
    vae_state_dict = {
        key[18:]: value
        for key, value in checkpoint["state_dict"].items()
        if "first_stage_model." in key
    }

    config = {
        "preprocessing": {
            "audio": {"sampling_rate": 48000, "max_wav_value": 32768, "duration": 10.24},
            "stft": {"filter_length": 2048, "hop_length": 480, "win_length": 2048},
            "mel": {"n_mel_channels": 256, "mel_fmin": 20, "mel_fmax": 24000},
        },
        "model": {
            "params": {
                "first_stage_config": {
                    "params": {
                        "sampling_rate": 48000,
                        "batchsize": 4,
                        "monitor": "val/rec_loss",
                        "image_key": "fbank",
                        "subband": 1,
                        "embed_dim": 16,
                        "time_shuffle": 1,
                        "lossconfig": {
                            "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
                            "params": {
                                "disc_start": 50001,
                                "kl_weight": 1000,
                                "disc_weight": 0.5,
                                "disc_in_channels": 1,
                            },
                        },
                        "ddconfig": {
                            "double_z": True,
                            "mel_bins": 256,
                            "z_channels": 16,
                            "resolution": 256,
                            "downsample_time": False,
                            "in_channels": 1,
                            "out_ch": 1,
                            "ch": 128,
                            "ch_mult": [1, 2, 4, 8],
                            "num_res_blocks": 2,
                            "attn_resolutions": [],
                            "dropout": 0,
                        },
                    }
                },
            }
        },
    }
    vae_config = config["model"]["params"]["first_stage_config"]["params"]
    vae_config["scale_factor"] = scale_factor

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(vae_state_dict)

    stft_cfg = config["preprocessing"]["stft"]
    mel_cfg = config["preprocessing"]["mel"]
    fn_STFT = TacotronSTFT(
        stft_cfg["filter_length"],
        stft_cfg["hop_length"],
        stft_cfg["win_length"],
        mel_cfg["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        mel_cfg["mel_fmin"],
        mel_cfg["mel_fmax"],
    )

    vae.eval()
    fn_STFT.eval()
    return vae, fn_STFT
|
| 1550 |
+
|
| 1551 |
+
|
MuCodec/tools/torch_tools.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchaudio
|
| 3 |
+
import random
|
| 4 |
+
import itertools
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def normalize_wav(waveform):
    """Zero-mean the waveform, peak-normalize it, and scale to the +/-0.5 range."""
    centered = waveform - torch.mean(waveform)
    peak = torch.max(torch.abs(centered)) + 1e-8
    return 0.5 * (centered / peak)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def pad_wav(waveform, segment_length):
    """Crop or zero-pad a 1-D waveform to exactly `segment_length` samples.

    With segment_length=None (or an already-matching length) the waveform is
    returned untouched.
    """
    length = len(waveform)
    if segment_length is None or length == segment_length:
        return waveform
    if length > segment_length:
        return waveform[:segment_length]
    tail = torch.zeros(segment_length - length).to(waveform.device)
    return torch.cat([waveform, tail])
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _pad_spec(fbank, target_length=1024):
|
| 29 |
+
batch, n_frames, channels = fbank.shape
|
| 30 |
+
p = target_length - n_frames
|
| 31 |
+
if p > 0:
|
| 32 |
+
pad = torch.zeros(batch, p, channels).to(fbank.device)
|
| 33 |
+
fbank = torch.cat([fbank, pad], 1)
|
| 34 |
+
elif p < 0:
|
| 35 |
+
fbank = fbank[:, :target_length, :]
|
| 36 |
+
|
| 37 |
+
if channels % 2 != 0:
|
| 38 |
+
fbank = fbank[:, :, :-1]
|
| 39 |
+
|
| 40 |
+
return fbank
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def read_wav_file(filename, segment_length):
    """Load an audio file, resample to 16 kHz (first channel only), normalize,
    pad/crop to `segment_length` samples, and rescale the peak to 0.5.

    Returns a tensor of shape (1, segment_length).
    """
    waveform, sr = torchaudio.load(filename)  # Faster!!!
    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0]
    try:
        waveform = normalize_wav(waveform)
    except Exception:
        # FIX: was a bare `except:` which also swallowed SystemExit and
        # KeyboardInterrupt. Fall back to 10 s of constant signal at 16 kHz.
        print("Exception normalizing:", filename)
        waveform = torch.ones(160000)
    waveform = pad_wav(waveform, segment_length).unsqueeze(0)
    # Epsilon guards against division by zero (NaNs) on all-silent input.
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
    waveform = 0.5 * waveform
    return waveform
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_mel_from_wav(audio, _stft):
    """Clamp `audio` into [-1, 1], replace NaNs, and run the mel frontend.

    Returns the (melspec, log_magnitudes_stft, energy) triple produced by
    `_stft.mel_spectrogram`.
    """
    cleaned = torch.nan_to_num(torch.clip(audio, -1, 1))
    cleaned = torch.autograd.Variable(cleaned, requires_grad=False)
    return _stft.mel_spectrogram(cleaned)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def wav_to_fbank(paths, target_length=1024, fn_STFT=None):
    """Read each audio path, stack the waveforms, and compute padded
    mel / log-STFT features via `fn_STFT` (required)."""
    assert fn_STFT is not None

    # hop size is 160 samples, so target_length frames correspond to
    # target_length * 160 samples at 16 kHz.
    waveform = torch.cat(
        [read_wav_file(path, target_length * 160) for path in paths], 0
    )

    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
    fbank = fbank.transpose(1, 2)
    log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)

    fbank = _pad_spec(fbank, target_length)
    log_magnitudes_stft = _pad_spec(log_magnitudes_stft, target_length)

    return fbank, log_magnitudes_stft, waveform
|
| 78 |
+
|
| 79 |
+
def wav_to_fbank2(waveform, target_length=-1, fn_STFT=None):
    """Compute mel / log-STFT features from an in-memory waveform.

    Padding/cropping to `target_length` frames only happens when
    target_length > 0; `fn_STFT` is required.
    """
    assert fn_STFT is not None

    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
    fbank = fbank.transpose(1, 2)
    log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)

    if target_length > 0:
        fbank = _pad_spec(fbank, target_length)
        log_magnitudes_stft = _pad_spec(log_magnitudes_stft, target_length)

    return fbank, log_magnitudes_stft, waveform
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def uncapitalize(s):
    """Lower-case only the first character of `s`; return "" for empty input."""
    return s[:1].lower() + s[1:] if s else ""
|
| 100 |
+
|
__pycache__/audio_tokens.cpython-310.pyc
ADDED
|
Binary file (826 Bytes). View file
|
|
|
__pycache__/audio_tokens.cpython-312.pyc
ADDED
|
Binary file (951 Bytes). View file
|
|
|
__pycache__/condition_encoders.cpython-310.pyc
ADDED
|
Binary file (4.24 kB). View file
|
|
|
__pycache__/condition_encoders.cpython-312.pyc
ADDED
|
Binary file (6.85 kB). View file
|
|
|
__pycache__/dataset.cpython-310.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
__pycache__/dataset.cpython-312.pyc
ADDED
|
Binary file (21.5 kB). View file
|
|
|
__pycache__/decoders.cpython-310.pyc
ADDED
|
Binary file (3.86 kB). View file
|
|
|
__pycache__/decoders.cpython-312.pyc
ADDED
|
Binary file (5.77 kB). View file
|
|
|
__pycache__/inference_full.cpython-310.pyc
ADDED
|
Binary file (24.5 kB). View file
|
|
|
__pycache__/inference_full.cpython-312.pyc
ADDED
|
Binary file (35 kB). View file
|
|
|
__pycache__/modelling_qwen3.cpython-310.pyc
ADDED
|
Binary file (5.33 kB). View file
|
|
|
__pycache__/modelling_qwen3.cpython-312.pyc
ADDED
|
Binary file (9.53 kB). View file
|
|
|
__pycache__/runtime_utils.cpython-310.pyc
ADDED
|
Binary file (3.07 kB). View file
|
|
|
__pycache__/runtime_utils.cpython-312.pyc
ADDED
|
Binary file (3.65 kB). View file
|
|
|
audio_tokens.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
SOA_TOKEN = "[SOA]"
|
| 5 |
+
EOA_TOKEN = "[EOA]"
|
| 6 |
+
MASK_AUDIO_TOKEN = "[MASK_AUDIO]"
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def audio_id_to_token(audio_id: int) -> str:
    """Map an audio codebook id to its special-token string, e.g. 7 -> "<AUDIO_7>"."""
    return "<AUDIO_{}>".format(int(audio_id))
|
| 11 |
+
|
| 12 |
+
def add_audio_special_tokens(tokenizer, num_audio_token: int) -> int:
    """Register `num_audio_token` <AUDIO_i> tokens plus the mask/start/end
    markers on `tokenizer`; return tokenizer.add_tokens's count of added tokens."""
    tokens = [audio_id_to_token(i) for i in range(num_audio_token)]
    tokens.extend([MASK_AUDIO_TOKEN, SOA_TOKEN, EOA_TOKEN])
    return tokenizer.add_tokens(tokens, special_tokens=True)
|
batch_infer_checkpoints.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
import traceback
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import datasets
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from inference_full import (
|
| 13 |
+
TokenLayout,
|
| 14 |
+
batch_generate_segmentwise,
|
| 15 |
+
build_mucodec_decoder,
|
| 16 |
+
generate_segmentwise,
|
| 17 |
+
load_hf_template_sample_from_music_dataset,
|
| 18 |
+
save_outputs,
|
| 19 |
+
)
|
| 20 |
+
from runtime_utils import (
|
| 21 |
+
load_magel_checkpoint,
|
| 22 |
+
load_music_dataset,
|
| 23 |
+
maybe_compile_model,
|
| 24 |
+
resolve_device,
|
| 25 |
+
seed_everything,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for batched checkpoint inference.

    Exactly one of --checkpoint_list / --checkpoint_dir must be supplied;
    otherwise the parser exits with an error message.
    """
    parser = argparse.ArgumentParser(
        description="Run audio inference on validation samples for multiple checkpoints."
    )
    # Checkpoint discovery (one of the two is required).
    parser.add_argument("--checkpoint_list", type=str, default=None,
                        help="Text file with one checkpoint path per line.")
    parser.add_argument("--checkpoint_dir", type=str, default=None,
                        help="Directory to scan for checkpoint-* subdirectories and optional final.")
    # Dataset / tokenizer locations.
    parser.add_argument("--dataset_path", type=str, default="muse_mucodec_chord.ds")
    parser.add_argument("--split", type=str, default="validation")
    parser.add_argument("--tokenizer_path", type=str, default="checkpoints/Qwen3-0.6B")
    # Sample selection.
    parser.add_argument("--sample_indices", type=int, nargs="*", default=None,
                        help="Specific sample indices to infer. Leave unset to run the full split.")
    parser.add_argument("--max_samples", type=int, default=0,
                        help="Run only the first N samples from the split. Ignored if --sample_indices is set.")
    parser.add_argument("--infer_batch_size", type=int, default=1,
                        help="Number of samples to decode together per step for the same checkpoint.")
    # Sampling hyperparameters.
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.90)
    parser.add_argument("--greedy", action="store_true", default=False)
    parser.add_argument("--max_audio_tokens", type=int, default=0)
    parser.add_argument("--fps", type=int, default=25)
    parser.add_argument("--seed", type=int, default=1234)
    # Runtime / execution settings.
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--dtype", type=str, default="bfloat16",
                        choices=["float32", "float16", "bfloat16"])
    parser.add_argument("--attn_implementation", type=str, default="sdpa",
                        choices=["eager", "sdpa", "flash_attention_2"])
    parser.add_argument("--use_cache", action="store_true", default=True)
    parser.add_argument("--no_cache", action="store_true", default=False)
    parser.add_argument("--compile", action="store_true", default=False)
    parser.add_argument("--compile_mode", type=str, default="reduce-overhead",
                        choices=["default", "reduce-overhead", "max-autotune"])
    # MuCodec decoder settings.
    parser.add_argument("--mucodec_device", type=str, default="auto")
    parser.add_argument("--mucodec_layer_num", type=int, default=7)
    parser.add_argument("--mucodec_duration", type=float, default=40.96)
    parser.add_argument("--mucodec_guidance_scale", type=float, default=1.5)
    parser.add_argument("--mucodec_num_steps", type=int, default=20)
    parser.add_argument("--mucodec_sample_rate", type=int, default=48000)
    # Output locations.
    parser.add_argument("--output_dir", type=str, default="/root/new_batch_predictions",
                        help="Root output dir. Each checkpoint gets its own subdirectory.")
    parser.add_argument("--summary_json", type=str,
                        default="/root/new_batch_predictions/summary.json")

    args = parser.parse_args()
    if not args.checkpoint_list and not args.checkpoint_dir:
        parser.error("one of --checkpoint_list or --checkpoint_dir is required")
    return args
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def parse_checkpoint_list(path: str) -> list[str]:
    """Read checkpoint paths from a text file, one per line.

    Blank lines and '#'-comment lines are ignored; raises ValueError when no
    usable entries remain.
    """
    with open(path, "r", encoding="utf-8") as handle:
        entries = [raw.strip() for raw in handle]
    checkpoints = [entry for entry in entries if entry and not entry.startswith("#")]
    if not checkpoints:
        raise ValueError(f"No checkpoints found in list: {path}")
    return checkpoints
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def scan_checkpoint_dir(path: str) -> list[str]:
    """Collect checkpoint-* subdirectories of `path` (numeric suffixes in
    numeric order, then non-numeric ones alphabetically), plus an optional
    trailing final/ directory.

    Raises NotADirectoryError when `path` is not a directory and ValueError
    when no checkpoint-* subdirectory exists.
    """
    root = Path(path)
    if not root.is_dir():
        raise NotADirectoryError(f"Checkpoint directory not found: {path}")

    checkpoint_dirs = [
        item
        for item in root.iterdir()
        if item.is_dir() and item.name.startswith("checkpoint-")
    ]

    def sort_key(p: Path) -> tuple:
        # FIX: the original key returned a bare int OR str, which raises
        # TypeError under sorted() when numeric and non-numeric suffixes
        # coexist. Tagged tuples keep numeric suffixes first and comparable.
        suffix = p.name.split("-", 1)[1]
        return (0, int(suffix)) if suffix.isdigit() else (1, suffix)

    checkpoint_dirs.sort(key=sort_key)

    final_dir = root / "final"
    if final_dir.is_dir():
        checkpoint_dirs.append(final_dir)

    checkpoints = [str(path_obj) for path_obj in checkpoint_dirs]
    if not checkpoints:
        raise ValueError(f"No checkpoint-* directories found under: {path}")
    return checkpoints
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def get_dtype(name: str) -> torch.dtype:
    """Map a dtype name ("float32" / "float16" / "bfloat16") to the torch dtype.

    Raises KeyError for any other name.
    """
    mapping = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return mapping[name]
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def get_split_size(dataset_path: str, split: str) -> int:
    """Return the number of rows in `split` of the dataset saved at `dataset_path`.

    A plain (non-dict) dataset ignores `split`; a missing split in a
    DatasetDict raises KeyError.
    """
    dataset_obj = datasets.load_from_disk(dataset_path)
    if not isinstance(dataset_obj, datasets.DatasetDict):
        return len(dataset_obj)
    if split not in dataset_obj:
        raise KeyError(f"Split not found: {split}")
    return len(dataset_obj[split])
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def resolve_sample_indices(
|
| 189 |
+
dataset_path: str,
|
| 190 |
+
split: str,
|
| 191 |
+
sample_indices: list[int] | None,
|
| 192 |
+
max_samples: int,
|
| 193 |
+
) -> list[int]:
|
| 194 |
+
if sample_indices:
|
| 195 |
+
return list(sample_indices)
|
| 196 |
+
split_size = get_split_size(dataset_path, split)
|
| 197 |
+
if max_samples and max_samples > 0:
|
| 198 |
+
split_size = min(split_size, max_samples)
|
| 199 |
+
return list(range(split_size))
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def sanitize_checkpoint_name(checkpoint_path: str) -> str:
    """Build a flat directory name "parent__leaf" from a checkpoint path.

    Falls back to the leaf alone when the path has no parent component;
    trailing slashes are ignored.
    """
    path = Path(checkpoint_path.rstrip("/"))
    parent = path.parent.name
    return f"{parent}__{path.name}" if parent else path.name
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def chunk_list(items: list[int], chunk_size: int) -> list[list[int]]:
    """Split `items` into consecutive chunks of at most `chunk_size` elements."""
    chunks = []
    for start in range(0, len(items), chunk_size):
        chunks.append(items[start : start + chunk_size])
    return chunks
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def main() -> None:
    """Batch inference driver: decode a fixed sample set with every checkpoint.

    For each checkpoint, loads the model, generates audio tokens for the
    resolved sample indices (batched, with single-sample fallback), writes
    per-sample wav/json outputs, and finally dumps a JSON summary of all runs.
    """
    args = parse_args()
    seed_everything(args.seed)

    # Checkpoints come either from an explicit list or from scanning a directory.
    if args.checkpoint_list:
        checkpoints = parse_checkpoint_list(args.checkpoint_list)
    else:
        checkpoints = scan_checkpoint_dir(args.checkpoint_dir)
    sample_indices = resolve_sample_indices(
        dataset_path=args.dataset_path,
        split=args.split,
        sample_indices=args.sample_indices,
        max_samples=args.max_samples,
    )

    # --no-cache overrides --use-cache.
    use_cache = args.use_cache and not args.no_cache
    device = resolve_device(args.device)
    dtype = get_dtype(args.dtype)
    if device.type == "cpu" and dtype != torch.float32:
        # Half/bfloat16 kernels may be missing on CPU; degrade gracefully.
        print(f"[WARN] dtype {dtype} on CPU may be unsupported; fallback to float32.")
        dtype = torch.float32

    output_root = Path(args.output_dir)
    output_root.mkdir(parents=True, exist_ok=True)

    print(f"[INFO] checkpoints={len(checkpoints)}")
    print(f"[INFO] samples_per_checkpoint={len(sample_indices)}")
    print(f"[INFO] device={device}, dtype={dtype}, use_cache={use_cache}")

    # The waveform decoder is checkpoint-independent; build it once.
    mucodec_decoder = build_mucodec_decoder(args)
    summary: list[dict] = []

    for checkpoint_path in checkpoints:
        ckpt_name = sanitize_checkpoint_name(checkpoint_path)
        ckpt_output_dir = output_root / ckpt_name
        json_dir = ckpt_output_dir / "json"
        wav_dir = ckpt_output_dir / "wav"

        print(f"\n[INFO] loading model from {checkpoint_path}")
        model = load_magel_checkpoint(
            checkpoint_path=checkpoint_path,
            device=device,
            dtype=dtype,
            attn_implementation=args.attn_implementation,
        )
        model = maybe_compile_model(
            model,
            enabled=bool(args.compile),
            mode=str(args.compile_mode),
        )
        # Codebook size is read from the checkpoint config so the dataset's
        # audio token ids line up with the model's vocabulary.
        num_audio_codebook = int(getattr(model.config, "magel_num_audio_token", 16384))
        music_ds = load_music_dataset(
            dataset_path=args.dataset_path,
            split=args.split,
            tokenizer_path=args.tokenizer_path,
            num_audio_token=num_audio_codebook,
            use_fast=True,
        )

        checkpoint_record = {
            "checkpoint_path": checkpoint_path,
            "checkpoint_name": ckpt_name,
            "status": "ok",
            "num_samples_requested": len(sample_indices),
            "results": [],
        }

        try:
            for batch_indices in chunk_list(sample_indices, max(1, int(args.infer_batch_size))):
                samples = []
                for sample_idx in batch_indices:
                    print(
                        f"[INFO] checkpoint={ckpt_name} sample_idx={sample_idx} split={args.split}"
                    )
                    samples.append(
                        load_hf_template_sample_from_music_dataset(
                            music_ds=music_ds,
                            sample_idx=sample_idx,
                            num_audio_codebook=num_audio_codebook,
                        )
                    )

                # NOTE(review): layout is derived from the first sample of the
                # batch; assumes num_text_token is identical across samples.
                layout = TokenLayout(
                    num_text_token=samples[0].num_text_token,
                    num_audio_codebook=num_audio_codebook,
                )

                if len(samples) == 1:
                    batch_outputs = [
                        generate_segmentwise(
                            model=model,
                            sample=samples[0],
                            layout=layout,
                            device=device,
                            use_cache=use_cache,
                            temperature=float(args.temperature),
                            top_k=int(args.top_k),
                            top_p=float(args.top_p),
                            greedy=bool(args.greedy),
                            max_audio_tokens=max(0, int(args.max_audio_tokens)),
                        )
                    ]
                else:
                    try:
                        batch_outputs = batch_generate_segmentwise(
                            model=model,
                            samples=samples,
                            layout=layout,
                            device=device,
                            use_cache=use_cache,
                            temperature=float(args.temperature),
                            top_k=int(args.top_k),
                            top_p=float(args.top_p),
                            greedy=bool(args.greedy),
                            max_audio_tokens=max(0, int(args.max_audio_tokens)),
                        )
                    except Exception as exc:
                        # Batched decode is best-effort; fall back to the slower
                        # but more robust per-sample path.
                        print(
                            "[WARN] batch_generate_segmentwise failed; "
                            f"falling back to single-sample decode. error={exc!r}"
                        )
                        traceback.print_exc()
                        batch_outputs = [
                            generate_segmentwise(
                                model=model,
                                sample=sample,
                                layout=layout,
                                device=device,
                                use_cache=use_cache,
                                temperature=float(args.temperature),
                                top_k=int(args.top_k),
                                top_p=float(args.top_p),
                                greedy=bool(args.greedy),
                                max_audio_tokens=max(0, int(args.max_audio_tokens)),
                            )
                            for sample in samples
                        ]

                for sample_idx, sample, batch_output in zip(batch_indices, samples, batch_outputs):
                    generated_ids, sampled_count, sampled_chord_ids, sampled_segment_ids = batch_output
                    prefix = f"{sample_idx:05d}_{sample.song_id}"

                    # save_outputs expects these attributes on args.
                    args.sample_idx = sample_idx
                    args.json_output_dir = str(json_dir)
                    args.wav_output_dir = str(wav_dir)

                    save_outputs(
                        output_dir=str(ckpt_output_dir),
                        output_prefix=prefix,
                        sample=sample,
                        layout=layout,
                        generated_ids=generated_ids,
                        sampled_chord_ids=sampled_chord_ids,
                        sampled_segment_ids=sampled_segment_ids,
                        args=args,
                        mucodec_decoder=mucodec_decoder,
                    )

                    # Paths recorded here mirror the save_outputs naming scheme;
                    # presumably save_outputs writes exactly these files — verify.
                    checkpoint_record["results"].append(
                        {
                            "sample_idx": sample_idx,
                            "song_id": sample.song_id,
                            "generated_audio_tokens": sampled_count,
                            "wav_path": str(wav_dir / f"{prefix}.wav"),
                            "json_path": str(json_dir / f"{prefix}.chord_segment.json"),
                        }
                    )
        except Exception as exc:
            # One bad checkpoint must not abort the sweep; record and continue.
            checkpoint_record["status"] = "error"
            checkpoint_record["error"] = str(exc)
            print(f"[ERROR] checkpoint {checkpoint_path}: {exc!r}")
            traceback.print_exc()

        summary.append(checkpoint_record)

        # Release the model between checkpoints to keep peak GPU memory bounded.
        del model
        if device.type == "cuda":
            torch.cuda.empty_cache()

    summary_path = Path(args.summary_json)
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print(f"\nSaved summary to: {summary_path}")
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
# Script entry point: run the batch inference driver when executed directly.
if __name__ == "__main__":
    main()
|
condition_encoders.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
from vocab import NUM_CHORD_CLASSES, NUM_STRUCTURE_CLASSES
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class EncoderBlock(nn.Module):
    """A small stack of pre-norm, batch-first Transformer encoder layers.

    Thin convenience wrapper so callers only deal with (x, pad_mask).
    """

    def __init__(
        self, d_model: int, n_layers: int = 2, n_heads: int = 8, dropout: float = 0.0
    ):
        super().__init__()
        layer_kwargs = dict(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.enc = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs), num_layers=n_layers
        )

    def forward(
        self, x: torch.Tensor, pad_mask: Optional[torch.BoolTensor] = None
    ) -> torch.Tensor:
        """Encode `x` ([B, T, d_model]); `pad_mask` is [B, T], True = PAD (masked out)."""
        return self.enc(x, src_key_padding_mask=pad_mask)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ConditionEncoder(nn.Module):
    """
    Condition encoder for AdaLN-Zero:
    Inputs: two aligned id sequences
      chord_ids:     [B, T]
      structure_ids: [B, T]

    Output:
      cond_expanded: [B, T, H] (feed into your AdaLN layers as cond_expanded)

    What it encodes per token:
      - token-level: chord/segment content at time t
      - position: sinusoidal position counted over valid (non-pad) tokens
      - segment context: via structure embeddings + bidirectional transformer mixing

    Notes:
      - Non-causal (sees future): good for "guidance" conditions.
      - Compute once per sample at generation start; slice per step.
      - Index 0 is the padding id for BOTH embedding tables (padding_idx=0),
        and a token counts as padding only when chord AND structure are 0.
    """

    def __init__(
        self,
        hidden_size: int,
        chord_embed_dim: int = 512,
        structure_embed_dim: int = 512,
        n_layers: int = 2,
        n_heads: int = 8,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        # padding_idx=0 keeps the pad embedding frozen at zero.
        self.chord_embedding = nn.Embedding(
            NUM_CHORD_CLASSES, chord_embed_dim, padding_idx=0
        )
        self.structure_embedding = nn.Embedding(
            NUM_STRUCTURE_CLASSES, structure_embed_dim, padding_idx=0
        )

        # Chord and structure embeddings are concatenated, then projected to H.
        self.cond_dim = chord_embed_dim + structure_embed_dim
        self.cond_proj = nn.Linear(self.cond_dim, hidden_size)

        # Small bidirectional transformer
        self.encoder = EncoderBlock(
            d_model=hidden_size, n_layers=n_layers, n_heads=n_heads, dropout=dropout
        )

        self.proj_out = nn.Linear(hidden_size, hidden_size)

    @staticmethod
    def _sincos_pos(
        positions: torch.Tensor, dim: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        positions: [B, T], absolute positions (0..T-1)
        returns: [B, T, dim] sinusoidal positional encoding

        Even channels carry sin, odd channels carry cos; when dim is odd the
        final channel stays zero. Computed in float32 for numerical stability,
        then cast to `dtype`.
        """
        if dim <= 0:
            raise ValueError("dim must be > 0 for positional encoding.")

        half = dim // 2
        if half == 0:
            # dim == 1: no sin/cos pair fits; return all-zero encoding.
            return torch.zeros(
                positions.size(0),
                positions.size(1),
                dim,
                device=positions.device,
                dtype=dtype,
            )

        pos = positions.to(dtype=torch.float32)
        # Geometric frequency ladder: freqs[i] = 10000^(-i/half).
        freqs = torch.exp(
            -math.log(10000.0)
            * torch.arange(half, device=positions.device, dtype=torch.float32)
            / half
        )
        angles = pos.unsqueeze(-1) * freqs  # [B, T, half]

        enc = torch.zeros(
            positions.size(0),
            positions.size(1),
            dim,
            device=positions.device,
            dtype=torch.float32,
        )
        enc[..., 0 : 2 * half : 2] = torch.sin(angles)
        enc[..., 1 : 2 * half : 2] = torch.cos(angles)

        return enc.to(dtype=dtype)

    def forward(
        self,
        chord_ids: torch.Tensor,  # [B, T]
        structure_ids: torch.Tensor,  # [B, T]
    ) -> torch.Tensor:
        """Return the [B, T, hidden_size] condition sequence for AdaLN layers."""

        chord_emb = self.chord_embedding(chord_ids)  # [B, T, chord_dim]
        structure_emb = self.structure_embedding(structure_ids)  # [B, T, struct_dim]

        cond = torch.cat([chord_emb, structure_emb], dim=-1)
        cond = self.cond_proj(cond)

        # Encoder attention mask is computed separately from condition content.
        # True means this token can be attended by the condition encoder.
        valid_tokens = chord_ids.ne(0) | structure_ids.ne(0)
        pad_mask = ~valid_tokens

        # Position ids are contiguous only on valid condition-id tokens.
        # (Padding positions are forced to 0; their pos encoding is zeroed below.)
        pos = valid_tokens.to(torch.long).cumsum(dim=1) - 1
        pos = torch.where(valid_tokens, pos, torch.zeros_like(pos))

        pos_enc = self._sincos_pos(pos, self.hidden_size, cond.dtype)
        valid_mask = valid_tokens.unsqueeze(-1)
        # Add positional information only on valid tokens.
        cond = cond + pos_enc * valid_mask.to(dtype=cond.dtype)
        encoded = self.encoder(cond, pad_mask=pad_mask)

        # [B, T, hidden_size]
        return self.proj_out(encoded)
|
dataset.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
"""Dataset/collate implementation for music training data."""
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 10 |
+
from transformers import AutoTokenizer
|
| 11 |
+
|
| 12 |
+
from audio_tokens import (
|
| 13 |
+
EOA_TOKEN,
|
| 14 |
+
MASK_AUDIO_TOKEN,
|
| 15 |
+
SOA_TOKEN,
|
| 16 |
+
add_audio_special_tokens,
|
| 17 |
+
audio_id_to_token,
|
| 18 |
+
)
|
| 19 |
+
from vocab import (
|
| 20 |
+
CHORD_BOS_ID,
|
| 21 |
+
CHORD_EOS_ID,
|
| 22 |
+
STRUCTURE_BOS_ID,
|
| 23 |
+
STRUCTURE_EOS_ID,
|
| 24 |
+
build_frame_chord_ids,
|
| 25 |
+
build_frame_structure_ids,
|
| 26 |
+
normalize_structure_label,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Lowercased language tags that are treated as Chinese by detect_language.
CN_LANGUAGE_LABELS = {"cn", "zh", "zh-cn", "chinese"}
# Canonical display names for normalized structure labels, used in prompts.
SECTION_NAME_MAP = {
    "intro": "Intro",
    "verse": "Verse",
    "chorus": "Chorus",
    "prechorus": "Pre-Chorus",
    "bridge": "Bridge",
    "outro": "Outro",
    "pad": "Pad",
}
# Sections expected at most once; their first occurrence is rendered unnumbered.
SINGLETON_SECTION_NAMES = {"intro", "outro", "pad"}
# Sentence-final marks (ASCII and fullwidth); lyrics without one get ";" appended.
ENDING_PUNCTUATION = {".", ";", "!", "?", "。", "?", "!", ";"}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _pad_batch_field(batch, key: str, padding_value):
|
| 45 |
+
return pad_sequence(
|
| 46 |
+
[row[key] for row in batch],
|
| 47 |
+
batch_first=True,
|
| 48 |
+
padding_value=padding_value,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def detect_language(text: str, language: str | None = None) -> str:
    """Replace spaces with ';' when `language` names Chinese; otherwise return as-is.

    Despite the name, this does not sniff the language — it only consults the
    provided `language` tag against CN_LANGUAGE_LABELS.
    """
    label = str(language).strip().lower()
    if label in CN_LANGUAGE_LABELS:
        return text.replace(" ", ";")
    return text
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def normalize_section_text(
    text: str, structure: str, language: str | None = None
) -> str:
    """Clean one section's lyrics: strip structure tags, unify punctuation to ';',
    apply the Chinese space rule, and guarantee a terminal punctuation mark."""
    cleaned = str(text or "")
    # Remove embedded [VERSE]/[verse]-style tags for this section's structure.
    for tag in (f"[{structure.upper()}]", f"[{structure.lower()}]"):
        cleaned = cleaned.replace(tag, "")
    # Normalize ASCII and fullwidth sentence punctuation to ';'.
    for mark in (",", ".", ",", "。"):
        cleaned = cleaned.replace(mark, ";")
    cleaned = detect_language(cleaned, language=language)
    # Re-insert a space after ';' when a Latin letter follows.
    cleaned = re.sub(r";(?=[A-Za-z])", "; ", cleaned)
    if cleaned and cleaned[-1] not in ENDING_PUNCTUATION:
        cleaned += ";"
    return cleaned
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class DataCollate:
    """Collate MusicDataset rows into right-padded batch tensors."""

    def __call__(self, batch):
        token_ids = _pad_batch_field(batch, "token_ids", 0)
        # NOTE(review): labels alias the exact same tensor as input_ids —
        # presumably shifting/masking happens downstream; verify before mutating.
        return {
            "input_ids": token_ids,
            "labels": token_ids,
            "masks": _pad_batch_field(batch, "mask", 0),
            "attention_mask": _pad_batch_field(batch, "attention_mask", 0),
            "chord_ids": _pad_batch_field(batch, "chord_ids", 0),
            "structure_ids": _pad_batch_field(batch, "structure_ids", 0),
            "condition_mask": _pad_batch_field(batch, "condition_mask", False),
        }
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class MusicDataset(torch.utils.data.Dataset):
|
| 101 |
+
"""Fly dataset with music-code tokens and section-conditioned text."""
|
| 102 |
+
|
| 103 |
+
    def __init__(
        self,
        datasets,
        split: str,
        tokenizer_path: str,
        num_audio_token=16384,
        fps=25,
        use_fast=True,
    ):
        """Wrap one split of a dataset and prepare an audio-aware tokenizer.

        Args:
            datasets: Mapping of split name -> dataset (e.g. a ``DatasetDict``).
            split: Which split of ``datasets`` this instance serves.
            tokenizer_path: Local path passed to ``AutoTokenizer.from_pretrained``.
            num_audio_token: Size of the audio codebook vocabulary to register.
            fps: Audio frames per second, used to convert times to frame indices.
            use_fast: Forwarded to the tokenizer loader.
        """
        self._data = datasets[split]
        self.tokenizer_path = tokenizer_path
        self.use_fast = use_fast
        self.num_audio_token = num_audio_token
        self.fps = fps

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_path,
            local_files_only=True,
            use_fast=self.use_fast,
        )
        # Register the audio special tokens BEFORE reading the vocab size so the
        # recorded size reflects the extended vocabulary.
        add_audio_special_tokens(self.tokenizer, self.num_audio_token)
        self.tokenizer_vocab_size = len(self.tokenizer)

        # Id of audio code 0 doubles as the text-vocab size; assumes audio
        # tokens occupy a contiguous block appended after the text vocab —
        # TODO confirm against add_audio_special_tokens.
        self.audio_prefix_length = int(
            self.tokenizer.convert_tokens_to_ids(audio_id_to_token(0))
        )
        self.num_text_token = self.audio_prefix_length
        self.MASK_AUDIO = int(self.tokenizer.convert_tokens_to_ids(MASK_AUDIO_TOKEN))
        self.BOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(SOA_TOKEN))
        self.EOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(EOA_TOKEN))
        # Assistant turns initially contain only this "<soa><eoa>" pair; real
        # audio codes are inserted between them during expansion.
        self._assistant_audio_placeholder = f"{SOA_TOKEN}{EOA_TOKEN}"
        self._chat_template_kwargs = {"enable_thinking": False}
|
| 135 |
+
|
| 136 |
+
def __len__(self):
|
| 137 |
+
return len(self._data)
|
| 138 |
+
|
| 139 |
+
@staticmethod
|
| 140 |
+
def _positions(token_ids: torch.Tensor, target_id: int) -> torch.Tensor:
|
| 141 |
+
return torch.nonzero(token_ids == target_id, as_tuple=False).squeeze(-1)
|
| 142 |
+
|
| 143 |
+
@staticmethod
|
| 144 |
+
def _sorted_sections(sample: dict) -> list[dict]:
|
| 145 |
+
return sorted(
|
| 146 |
+
(
|
| 147 |
+
{
|
| 148 |
+
"raw_index": raw_index,
|
| 149 |
+
"text": str(seg["text"]),
|
| 150 |
+
"desc": str(seg["desc"]).strip(),
|
| 151 |
+
"start": float(seg["start"]),
|
| 152 |
+
"end": float(seg["end"]),
|
| 153 |
+
"structure": normalize_structure_label(seg["section"]),
|
| 154 |
+
}
|
| 155 |
+
for raw_index, seg in enumerate(sample.get("sections", []))
|
| 156 |
+
),
|
| 157 |
+
key=lambda seg: (seg["start"], seg["end"], seg["raw_index"]),
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
@staticmethod
|
| 161 |
+
def _sorted_chords(sample: dict) -> list[dict]:
|
| 162 |
+
return sorted(
|
| 163 |
+
(
|
| 164 |
+
{
|
| 165 |
+
"raw_index": raw_index,
|
| 166 |
+
"type": str(seg.get("type")),
|
| 167 |
+
"start": float(seg.get("start", 0.0)),
|
| 168 |
+
"end": float(seg.get("end", 0.0)),
|
| 169 |
+
}
|
| 170 |
+
for raw_index, seg in enumerate(sample.get("chords", []))
|
| 171 |
+
),
|
| 172 |
+
key=lambda seg: (seg["start"], seg["end"], seg["raw_index"]),
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
    def __getitem__(self, idx):
        """Build one training example: expanded token ids plus aligned conditions.

        Returns a dict with token ids (text + audio codes interleaved via the
        chat template), a label mask, attention mask, and per-token chord /
        structure condition ids aligned to the expanded sequence.
        """
        sample = self._data[idx]
        sections = self._prepare_sections(sample)
        chords = self._prepare_chords(sample)
        # frame_idx_map records, per token position, which audio frame (if any)
        # starts there; -1 marks non-audio positions.
        token_ids, attention_mask, frame_idx_map = self._tokenize_messages(
            self._build_messages(sample, sections),
            sample["mucodec_codes"],
            sections,
        )

        total_frames = len(sample["mucodec_codes"])
        structure_ids = build_frame_structure_ids(sections, total_frames, fps=self.fps)
        chord_ids = build_frame_chord_ids(chords, total_frames, fps=self.fps)

        structure_ids = torch.from_numpy(structure_ids)
        chord_ids = torch.from_numpy(chord_ids)

        # _build_token_masks / _align_condition_ids are defined later in this
        # class (outside this view); names here follow their return order.
        (
            audio_codebook_mask,
            bos_audio_mask,
            eos_mask,
            label_mask,
            condition_mask,
        ) = self._build_token_masks(token_ids)

        chord_ids_aligned, structure_ids_aligned = self._align_condition_ids(
            token_ids=token_ids,
            frame_idx_map=frame_idx_map,
            total_frames=total_frames,
            chord_ids=chord_ids,
            structure_ids=structure_ids,
            audio_codebook_mask=audio_codebook_mask,
            bos_audio_mask=bos_audio_mask,
            eos_mask=eos_mask,
        )

        return {
            "token_ids": token_ids,
            "mask": label_mask,
            "attention_mask": attention_mask,
            "chord_ids": chord_ids_aligned,
            "structure_ids": structure_ids_aligned,
            "condition_mask": condition_mask,
        }
|
| 219 |
+
|
| 220 |
+
def _tokenize_messages(
|
| 221 |
+
self,
|
| 222 |
+
messages: list[dict[str, str]],
|
| 223 |
+
full_audio_codes,
|
| 224 |
+
sections: list[dict],
|
| 225 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 226 |
+
|
| 227 |
+
chat_inputs = self.tokenizer.apply_chat_template(
|
| 228 |
+
messages,
|
| 229 |
+
tokenize=True,
|
| 230 |
+
add_generation_prompt=False,
|
| 231 |
+
return_tensors="pt",
|
| 232 |
+
return_dict=True,
|
| 233 |
+
**self._chat_template_kwargs,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
token_ids = chat_inputs["input_ids"]
|
| 237 |
+
attention_mask = chat_inputs["attention_mask"]
|
| 238 |
+
|
| 239 |
+
token_ids = token_ids.squeeze(0)
|
| 240 |
+
attention_mask = attention_mask.squeeze(0)
|
| 241 |
+
|
| 242 |
+
token_ids = token_ids.to(torch.long)
|
| 243 |
+
attention_mask = attention_mask.to(torch.long)
|
| 244 |
+
|
| 245 |
+
return self._expand_audio_tokens(
|
| 246 |
+
token_ids=token_ids,
|
| 247 |
+
attention_mask=attention_mask,
|
| 248 |
+
full_audio_codes=full_audio_codes,
|
| 249 |
+
sections=sections,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
def _frame_bounds(
|
| 253 |
+
self,
|
| 254 |
+
start: float,
|
| 255 |
+
end: float,
|
| 256 |
+
total_frames: int,
|
| 257 |
+
prev_end_idx: int = 0,
|
| 258 |
+
) -> tuple[int, int]:
|
| 259 |
+
start_idx = int(start * self.fps)
|
| 260 |
+
end_idx = int(math.ceil(end * self.fps))
|
| 261 |
+
start_idx = max(prev_end_idx, min(total_frames, start_idx))
|
| 262 |
+
end_idx = max(start_idx, min(total_frames, end_idx))
|
| 263 |
+
|
| 264 |
+
return start_idx, end_idx
|
| 265 |
+
|
| 266 |
+
    def _prepare_sections(self, sample: dict) -> list[dict]:
        """Convert raw section annotations into contiguous, frame-aligned sections.

        Each output section gets normalized lyrics, frame bounds, and a
        per-structure occurrence index/tag (e.g. "verse2"). Sections are forced
        to be back-to-back: every section starts exactly where the previous one
        ended (annotated gaps and overlaps are absorbed), and the final section
        is stretched to cover all remaining frames.
        """
        sections = []
        section_counts: dict[str, int] = {}
        sample_language = sample.get("language")
        total_frames = len(sample["mucodec_codes"])
        prev_end_idx = 0

        for seg in self._sorted_sections(sample):
            structure = seg["structure"]
            # Count occurrences per structure to build "verse1", "verse2", ...
            section_counts[structure] = section_counts.get(structure, 0) + 1
            # raw_start_idx only serves as a floor for raw_end_idx; the actual
            # start is overridden with prev_end_idx below to keep contiguity.
            raw_start_idx = max(0, min(total_frames, int(seg["start"] * self.fps)))
            raw_end_idx = max(
                raw_start_idx,
                min(total_frames, int(math.ceil(seg["end"] * self.fps))),
            )
            start_idx = prev_end_idx
            end_idx = max(start_idx, raw_end_idx)

            sections.append(
                {
                    "text": normalize_section_text(
                        seg["text"], structure, language=sample_language
                    ),
                    "desc": seg["desc"],
                    # Times are re-derived from the clamped frame indices so
                    # they stay consistent with start_frame/end_frame.
                    "start": start_idx / float(self.fps),
                    "end": end_idx / float(self.fps),
                    "start_frame": start_idx,
                    "end_frame": end_idx,
                    "structure": structure,
                    "tag": f"{structure}{section_counts[structure]}",
                    "index": section_counts[structure],
                }
            )
            prev_end_idx = end_idx

        if sections:
            # Extend the last section to the end of the audio so every frame
            # belongs to exactly one section.
            sections[-1]["end_frame"] = total_frames
            sections[-1]["end"] = total_frames / float(self.fps)

        return sections
|
| 306 |
+
|
| 307 |
+
def _prepare_chords(self, sample: dict) -> list[dict]:
|
| 308 |
+
chords = []
|
| 309 |
+
total_frames = len(sample["mucodec_codes"])
|
| 310 |
+
prev_end_idx = 0
|
| 311 |
+
|
| 312 |
+
for seg in self._sorted_chords(sample):
|
| 313 |
+
start_idx, end_idx = self._frame_bounds(
|
| 314 |
+
seg["start"],
|
| 315 |
+
seg["end"],
|
| 316 |
+
total_frames,
|
| 317 |
+
prev_end_idx=prev_end_idx,
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
chords.append(
|
| 321 |
+
{
|
| 322 |
+
"type": seg["type"],
|
| 323 |
+
"start": start_idx / float(self.fps),
|
| 324 |
+
"end": end_idx / float(self.fps),
|
| 325 |
+
"start_frame": start_idx,
|
| 326 |
+
"end_frame": end_idx,
|
| 327 |
+
}
|
| 328 |
+
)
|
| 329 |
+
prev_end_idx = end_idx
|
| 330 |
+
|
| 331 |
+
return chords
|
| 332 |
+
|
| 333 |
+
def _format_section_label(self, section: dict) -> str:
|
| 334 |
+
structure = section["structure"]
|
| 335 |
+
index = section["index"]
|
| 336 |
+
label = SECTION_NAME_MAP[structure]
|
| 337 |
+
if structure in SINGLETON_SECTION_NAMES and index == 1:
|
| 338 |
+
return label
|
| 339 |
+
return f"{label} {index}"
|
| 340 |
+
|
| 341 |
+
def _build_section_user_content(
|
| 342 |
+
self, sample: dict, section: dict, is_first_turn: bool
|
| 343 |
+
) -> str:
|
| 344 |
+
parts = []
|
| 345 |
+
if is_first_turn:
|
| 346 |
+
style = sample["style"].strip()
|
| 347 |
+
if style:
|
| 348 |
+
parts.append(
|
| 349 |
+
f"Please generate a song in the following style:{style}\n"
|
| 350 |
+
"Next, I will tell you the requirements and lyrics for the song "
|
| 351 |
+
"fragment to be generated, section by section."
|
| 352 |
+
)
|
| 353 |
+
else:
|
| 354 |
+
parts.append(
|
| 355 |
+
"Please generate the song section by section. "
|
| 356 |
+
"Next, I will tell you the requirements and lyrics for each fragment."
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
section_parts = [f"[{self._format_section_label(section)}]"]
|
| 360 |
+
desc = section["desc"]
|
| 361 |
+
if desc:
|
| 362 |
+
section_parts.append(f"[desc:{desc}]")
|
| 363 |
+
|
| 364 |
+
lyrics = section["text"]
|
| 365 |
+
if lyrics:
|
| 366 |
+
section_parts.append(f"[lyrics:{lyrics}]")
|
| 367 |
+
|
| 368 |
+
parts.append("".join(section_parts))
|
| 369 |
+
|
| 370 |
+
return "\n".join(part for part in parts if part)
|
| 371 |
+
|
| 372 |
+
def _build_messages(
|
| 373 |
+
self,
|
| 374 |
+
sample: dict,
|
| 375 |
+
sections: list[dict],
|
| 376 |
+
) -> list[dict[str, str]]:
|
| 377 |
+
messages: list[dict[str, str]] = [None] * (2 * len(sections))
|
| 378 |
+
|
| 379 |
+
for i, section in enumerate(sections):
|
| 380 |
+
msg_idx = 2 * i
|
| 381 |
+
messages[msg_idx] = {
|
| 382 |
+
"role": "user",
|
| 383 |
+
"content": self._build_section_user_content(
|
| 384 |
+
sample, section, is_first_turn=(i == 0)
|
| 385 |
+
),
|
| 386 |
+
}
|
| 387 |
+
messages[msg_idx + 1] = {
|
| 388 |
+
"role": "assistant",
|
| 389 |
+
"content": self._assistant_audio_placeholder,
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
return messages
|
| 393 |
+
|
| 394 |
+
    def _expand_audio_tokens(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        full_audio_codes,
        sections: list[dict],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Splice real audio-code tokens into each BOS_AUDIO/EOS_AUDIO pair.

        For the i-th section, the span between its BOS and EOS markers is
        filled with full_audio_codes[start_frame:end_frame], shifted into the
        audio codebook id range by self.audio_prefix_length.

        Returns:
            (expanded_token_ids, expanded_attention_mask, frame_idx_map);
            frame_idx_map[i] is the source frame index aligned to position i,
            or -1 for positions carrying no audio frame.

        NOTE(review): assumes one BOS/EOS marker pair per entry of
        ``sections``, in order — zip silently truncates on mismatch; confirm
        upstream guarantees.
        """

        if not sections:
            # Nothing to splice: pass the sequence through with an all -1 map.
            return (
                token_ids,
                attention_mask,
                torch.full(token_ids.shape, -1, dtype=torch.long),
            )

        bos_positions = self._positions(token_ids, self.BOS_AUDIO)
        eos_positions = self._positions(token_ids, self.EOS_AUDIO)

        audio_code_tensor = torch.as_tensor(full_audio_codes, dtype=torch.long)
        # Total number of audio tokens inserted across all sections.
        extra_audio_tokens = sum(
            int(section["end_frame"]) - int(section["start_frame"])
            for section in sections
        )
        final_len = token_ids.numel() + extra_audio_tokens

        expanded_token_ids = torch.empty(final_len, dtype=torch.long)
        expanded_attention_mask = torch.empty(final_len, dtype=torch.long)
        frame_idx_map = torch.full((final_len,), -1, dtype=torch.long)

        # Two-pointer copy: read_pos walks the original sequence, write_pos
        # the expanded one.
        read_pos = 0
        write_pos = 0

        for bos_pos, eos_pos, section in zip(
            bos_positions.tolist(), eos_positions.tolist(), sections
        ):
            start_idx = int(section["start_frame"])
            end_idx = int(section["end_frame"])
            audio_len = end_idx - start_idx

            # Copy everything up to and including this BOS marker verbatim.
            prefix_len = bos_pos + 1 - read_pos
            next_write = write_pos + prefix_len
            expanded_token_ids[write_pos:next_write] = token_ids[read_pos : bos_pos + 1]
            expanded_attention_mask[write_pos:next_write] = attention_mask[
                read_pos : bos_pos + 1
            ]
            # The BOS position itself is aligned to the section's first frame.
            frame_idx_map[next_write - 1] = start_idx if audio_len > 0 else -1
            write_pos = next_write

            if audio_len > 0:
                # Insert the section's audio codes, shifted into the codebook
                # id range, fully attended, each mapped to its source frame.
                next_write = write_pos + audio_len
                expanded_token_ids[write_pos:next_write] = audio_code_tensor[
                    start_idx:end_idx
                ]
                expanded_token_ids[write_pos:next_write].add_(self.audio_prefix_length)
                expanded_attention_mask[write_pos:next_write] = 1
                frame_idx_map[write_pos:next_write] = torch.arange(
                    start_idx, end_idx, dtype=torch.long
                )
                write_pos = next_write

            # Copy the EOS marker; it is aligned to the section's last frame.
            expanded_token_ids[write_pos] = token_ids[eos_pos]
            expanded_attention_mask[write_pos] = attention_mask[eos_pos]
            frame_idx_map[write_pos] = end_idx - 1 if audio_len > 0 else -1
            write_pos += 1
            read_pos = eos_pos + 1

        # Copy any trailing tokens after the final EOS marker.
        tail_len = token_ids.numel() - read_pos
        if tail_len > 0:
            expanded_token_ids[write_pos : write_pos + tail_len] = token_ids[read_pos:]
            expanded_attention_mask[write_pos : write_pos + tail_len] = attention_mask[
                read_pos:
            ]

        return expanded_token_ids, expanded_attention_mask, frame_idx_map
|
| 468 |
+
|
| 469 |
+
def _build_token_masks(
|
| 470 |
+
self, token_ids: torch.Tensor
|
| 471 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 472 |
+
|
| 473 |
+
audio_codebook_mask = (token_ids >= self.audio_prefix_length) & (
|
| 474 |
+
token_ids < self.MASK_AUDIO
|
| 475 |
+
)
|
| 476 |
+
bos_audio_mask = token_ids == self.BOS_AUDIO
|
| 477 |
+
eos_mask = token_ids == self.EOS_AUDIO
|
| 478 |
+
label_mask = (audio_codebook_mask | eos_mask).long()
|
| 479 |
+
condition_mask = audio_codebook_mask | bos_audio_mask | eos_mask
|
| 480 |
+
|
| 481 |
+
return audio_codebook_mask, bos_audio_mask, eos_mask, label_mask, condition_mask
|
| 482 |
+
|
| 483 |
+
def _align_condition_ids(
|
| 484 |
+
self,
|
| 485 |
+
token_ids: torch.Tensor,
|
| 486 |
+
frame_idx_map: torch.Tensor,
|
| 487 |
+
total_frames: int,
|
| 488 |
+
chord_ids: torch.Tensor,
|
| 489 |
+
structure_ids: torch.Tensor,
|
| 490 |
+
audio_codebook_mask: torch.Tensor,
|
| 491 |
+
bos_audio_mask: torch.Tensor,
|
| 492 |
+
eos_mask: torch.Tensor,
|
| 493 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 494 |
+
|
| 495 |
+
seq_len = token_ids.numel()
|
| 496 |
+
chord_ids_aligned = torch.zeros(seq_len, dtype=torch.long)
|
| 497 |
+
structure_ids_aligned = torch.zeros(seq_len, dtype=torch.long)
|
| 498 |
+
|
| 499 |
+
bos_positions = torch.nonzero(bos_audio_mask, as_tuple=False).squeeze(-1)
|
| 500 |
+
chord_ids_aligned[bos_positions] = CHORD_BOS_ID
|
| 501 |
+
structure_ids_aligned[bos_positions] = STRUCTURE_BOS_ID
|
| 502 |
+
|
| 503 |
+
audio_positions = torch.nonzero(audio_codebook_mask, as_tuple=False).squeeze(-1)
|
| 504 |
+
cur_frame_idx = frame_idx_map[audio_positions]
|
| 505 |
+
cur_frame_idx = cur_frame_idx.clamp(0, max(total_frames - 1, 0))
|
| 506 |
+
chord_ids_aligned[audio_positions] = chord_ids[cur_frame_idx]
|
| 507 |
+
structure_ids_aligned[audio_positions] = structure_ids[cur_frame_idx]
|
| 508 |
+
|
| 509 |
+
eos_positions = torch.nonzero(eos_mask, as_tuple=False).squeeze(-1)
|
| 510 |
+
chord_ids_aligned[eos_positions] = CHORD_EOS_ID
|
| 511 |
+
structure_ids_aligned[eos_positions] = STRUCTURE_EOS_ID
|
| 512 |
+
|
| 513 |
+
return chord_ids_aligned, structure_ids_aligned
|
decoders.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
from torch import nn as nn
|
| 4 |
+
|
| 5 |
+
from transformers.cache_utils import Cache
|
| 6 |
+
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AdaLN(nn.Module):
    """DiT-style adaptive layer-norm modulation head.

    Projects a conditioning token to six modulation tensors:
    (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp).

    With ``zero_init=True`` the projection weight and bias start at exactly
    zero, so all six outputs are zero at step 0 and the modulated layer
    reproduces the unconditioned base model mathematically.
    """

    def __init__(
        self,
        hidden_size: int,
        cond_dim: int,
        zero_init: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size

        self.act = nn.SiLU()
        self.linear = nn.Linear(cond_dim, 6 * hidden_size, bias=True)

        # Zero init guarantees identity behavior before any training step.
        if zero_init:
            nn.init.zeros_(self.linear.weight)
            nn.init.zeros_(self.linear.bias)

    def forward(self, cond_token: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Map ``cond_token`` [B, T, cond_dim] to six [B, T, H] tensors."""
        modulation = self.linear(self.act(cond_token))  # [B, T, 6H]
        # chunk yields the six heads in declaration order.
        return tuple(modulation.chunk(6, dim=-1))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def apply_adaln(
    x_norm: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    """Modulate a normalized activation: ``x_norm * (1 + scale) + shift``."""
    # addcmul(shift, x_norm, scale + 1) == shift + x_norm * (1 + scale).
    return torch.addcmul(shift, x_norm, scale + 1.0)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class Qwen3DecoderLayerAdaLN(Qwen3DecoderLayer):
    """
    Qwen3 decoder layer with AdaLN injection:
    - Modulate normalized input with (shift, scale) on masked positions.
    - IMPORTANT: gate must preserve base behavior at gate=0:
        out = out_base * (1 + gate)   (on masked positions)
      so that when gate==0, out==out_base.

    Only applied on audio positions (condition_mask==True); all other
    positions run the unmodified base layer computation.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        cond_dim: int,
        zero_init: bool = True,
    ):
        """Build the base Qwen3 layer plus a per-layer AdaLN head.

        Args:
            config: Qwen3 model config (provides hidden_size).
            layer_idx: index of this layer in the stack.
            cond_dim: width of the conditioning embedding.
            zero_init: start AdaLN at zero so the layer initially matches
                the base model exactly.
        """
        super().__init__(config, layer_idx)

        self.dit_adaln = AdaLN(
            hidden_size=config.hidden_size,
            cond_dim=cond_dim,
            zero_init=zero_init,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        cond_expanded: Optional[torch.Tensor] = None,  # [B, T, cond_dim]
        condition_mask: Optional[torch.BoolTensor] = None,  # [B, T]
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ):
        """Run one decoder layer, optionally AdaLN-modulated on masked positions.

        Returns the hidden states tensor (not the base layer's tuple).
        """
        # Keep the condition path fully tensor-based; avoid .item() checks that
        # can force GPU-CPU synchronization in autoregressive decoding.
        do_cond = (cond_expanded is not None) and (condition_mask is not None)

        if do_cond:
            (
                shift_msa,
                scale_msa,
                gate_msa,
                shift_mlp,
                scale_mlp,
                gate_mlp,
            ) = self.dit_adaln(cond_expanded)
            mask_expanded = condition_mask.unsqueeze(-1)  # [B, T, 1]

        # ---- Self-Attention branch ----
        residual = hidden_states
        x_norm = self.input_layernorm(hidden_states)  # RMSNorm in Qwen3

        if do_cond:
            # Modulate only the masked (audio) positions; others stay base.
            x_mod = apply_adaln(x_norm, shift_msa, scale_msa)
            x_in = torch.where(mask_expanded, x_mod, x_norm)
        else:
            x_in = x_norm

        attn_out, _ = self.self_attn(
            hidden_states=x_in,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )

        if do_cond:
            # Preserve base when gate==0: attn_out_audio = (1 + gate) * attn_out_base
            attn_out = torch.where(mask_expanded, (1.0 + gate_msa) * attn_out, attn_out)

        hidden_states = residual + attn_out

        # ---- MLP branch ----
        residual = hidden_states
        x_norm = self.post_attention_layernorm(hidden_states)

        if do_cond:
            x_mod = apply_adaln(x_norm, shift_mlp, scale_mlp)
            x_in = torch.where(mask_expanded, x_mod, x_norm)
        else:
            x_in = x_norm

        mlp_out = self.mlp(x_in)

        if do_cond:
            # Preserve base when gate==0
            mlp_out = torch.where(mask_expanded, (1.0 + gate_mlp) * mlp_out, mlp_out)

        hidden_states = residual + mlp_out

        return hidden_states
|
inference_full.py
ADDED
|
@@ -0,0 +1,1084 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
HF-driven inference for MAGEL with segment-level autoregressive generation.
|
| 5 |
+
|
| 6 |
+
Uses from HF sample:
|
| 7 |
+
- text instruction/template tokens (token_ids scaffold)
|
| 8 |
+
- control tokens: chord_ids/structure_ids
|
| 9 |
+
|
| 10 |
+
Does NOT use:
|
| 11 |
+
- ground-truth audio token values as input (audio codebook positions are masked)
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import contextlib
|
| 16 |
+
import importlib
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import sys
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from datetime import datetime
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Any, Optional
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
from runtime_utils import (
|
| 29 |
+
load_magel_checkpoint,
|
| 30 |
+
load_music_dataset,
|
| 31 |
+
maybe_compile_model,
|
| 32 |
+
maybe_mark_compile_step_begin,
|
| 33 |
+
resolve_device,
|
| 34 |
+
seed_everything,
|
| 35 |
+
)
|
| 36 |
+
from vocab import (
|
| 37 |
+
CHORD_BOS_ID,
|
| 38 |
+
CHORD_EOS_ID,
|
| 39 |
+
STRUCTURE_EOS_ID,
|
| 40 |
+
chord_id_to_label,
|
| 41 |
+
structure_id_to_label,
|
| 42 |
+
)
|
| 43 |
+
from modelling_qwen3 import MAGEL
|
| 44 |
+
|
| 45 |
+
REPO_ROOT = Path(__file__).resolve().parent
|
| 46 |
+
MUCODEC_ROOT = REPO_ROOT / "MuCodec"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
class TokenLayout:
    """Combined id-space layout: text ids, audio codebook, then specials.

    Layout (half-open ranges)::

        [0, num_text_token)          text tokens
        [audio_start, audio_end)     audio codebook
        audio_end, +1, +2            MASK_AUDIO, BOS_AUDIO, EOS_AUDIO
    """

    num_text_token: int
    num_audio_codebook: int = 16384

    @property
    def audio_start(self) -> int:
        """First audio codebook id (audio ids follow text ids)."""
        return self.num_text_token

    @property
    def audio_end(self) -> int:
        """One past the last audio codebook id."""
        return self.num_text_token + self.num_audio_codebook

    @property
    def mask_audio(self) -> int:
        """Id of the MASK_AUDIO special token."""
        return self.audio_end

    @property
    def bos_audio(self) -> int:
        """Id of the BOS_AUDIO special token."""
        return self.audio_end + 1

    @property
    def eos_audio(self) -> int:
        """Id of the EOS_AUDIO special token."""
        return self.audio_end + 2
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
class SegmentSpan:
    """Location of one audio segment inside the flattened token sequence."""

    # 0-based ordinal of this segment within the song.
    seg_idx: int
    # Index of the BOS_AUDIO token opening the segment.
    bos_pos: int
    # Index of the EOS_AUDIO token closing the segment.
    eos_pos: int
    # Indices strictly between bos_pos and eos_pos holding audio codebook tokens.
    audio_positions: list[int]
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@dataclass
class HFTemplateSample:
    """One HF dataset row prepared for segment-wise AR inference.

    ``template_ids`` keeps the original scaffold (including ground-truth
    audio codes); ``input_ids`` is the same scaffold with audio codebook
    positions replaced by MASK_AUDIO so no GT audio leaks into the input.
    """

    song_id: str
    # Size of the text vocabulary; audio codebook ids start here.
    num_text_token: int
    template_ids: torch.Tensor  # [T], original token_ids
    input_ids: torch.Tensor  # [T], audio codebook replaced with MASK_AUDIO
    chord_ids: torch.Tensor  # [T]
    structure_ids: torch.Tensor  # [T]
    condition_mask: torch.Tensor  # [T]
    is_audio_codebook: torch.Tensor  # [T]
    is_eos: torch.Tensor  # [T]
    # BOS/EOS span per audio segment, in song order.
    segments: list[SegmentSpan]
    # Untokenized dataset row, kept for metadata lookups.
    raw_item: dict[str, Any]

    @property
    def seq_len(self) -> int:
        """Sequence length T of the token scaffold."""
        return int(self.input_ids.numel())
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for segment-wise AR generation + MuCodec decoding."""
    parser = argparse.ArgumentParser(
        description="Segment-wise AR generation from HF controls/scaffold."
    )
    # --- model / data ---
    parser.add_argument(
        "--model_path",
        type=str,
        default="./output_qwen3_0p6b_train/final",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="muse_mucodec_chord.ds",
    )
    parser.add_argument("--split", type=str, default="validation")
    parser.add_argument("--sample_idx", type=int, default=0)
    parser.add_argument(
        "--tokenizer_path", type=str, default="checkpoints/Qwen3-0.6B"
    )
    parser.add_argument(
        "--num_audio_codebook",
        type=int,
        default=None,
        help="Audio codebook size. Defaults to checkpoint metadata when available.",
    )

    # --- sampling ---
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.90)
    parser.add_argument("--greedy", action="store_true", default=False)
    parser.add_argument("--max_audio_tokens", type=int, default=0)
    parser.add_argument("--fps", type=int, default=25)

    # --- runtime ---
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    # NOTE(review): --use_cache is store_true with default=True, so the flag
    # itself is a no-op; caching is effectively controlled by --no_cache.
    parser.add_argument("--use_cache", action="store_true", default=True)
    parser.add_argument("--no_cache", action="store_true", default=False)
    parser.add_argument("--compile", action="store_true", default=False)
    parser.add_argument(
        "--compile_mode",
        type=str,
        default="reduce-overhead",
        choices=["default", "reduce-overhead", "max-autotune"],
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        choices=["eager", "sdpa", "flash_attention_2"],
    )
    # --- outputs ---
    parser.add_argument("--output_dir", type=str, default="predictions")
    parser.add_argument("--output_prefix", type=str, default="")
    parser.add_argument(
        "--json_output_dir",
        type=str,
        default="predictions/json",
        help="Directory for chord/segment json. Default: <output_dir>/json",
    )
    # --- MuCodec decoding ---
    parser.add_argument(
        "--mucodec_device",
        type=str,
        default="auto",
        help="Device string for MuCodec, for example cuda:0.",
    )
    parser.add_argument(
        "--mucodec_layer_num",
        type=int,
        default=7,
        help="MuCodec layer_num passed to the official decoder.",
    )
    parser.add_argument(
        "--mucodec_duration",
        type=float,
        default=40.96,
        help="Chunk duration argument passed to MuCodec code2sound.",
    )
    parser.add_argument(
        "--mucodec_guidance_scale",
        type=float,
        default=1.5,
        help="Guidance scale argument passed to MuCodec code2sound.",
    )
    parser.add_argument(
        "--mucodec_num_steps",
        type=int,
        default=20,
        help="Sampling steps argument passed to MuCodec code2sound.",
    )
    parser.add_argument(
        "--mucodec_sample_rate",
        type=int,
        default=48000,
        help="Sample rate used when saving decoded wav.",
    )
    parser.add_argument(
        "--wav_output_dir",
        type=str,
        default="predictions/wav",
        help="Directory for decoded wav. Default: <output_dir>/wav",
    )
    return parser.parse_args()
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def resolve_runtime_device_str(device_arg: str) -> str:
    """Resolve 'auto' to cuda:0 / mps / cpu by availability; echo anything else."""
    if device_arg == "auto":
        # Probe backends in preference order; first available wins.
        for is_available, device in (
            (torch.cuda.is_available, "cuda:0"),
            (torch.backends.mps.is_available, "mps"),
        ):
            if is_available():
                return device
        return "cpu"
    return device_arg
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
@contextlib.contextmanager
def pushd(path: str):
    """Temporarily chdir into *path*, restoring the previous cwd on exit."""
    original_cwd = os.getcwd()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(original_cwd)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def ensure_sys_path(path: str) -> None:
    """Prepend *path* to sys.path once; empty strings and duplicates are ignored."""
    if not path:
        return
    if path not in sys.path:
        sys.path.insert(0, path)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def get_mucodec_root() -> str:
    """Return the vendored MuCodec repo path after validating its layout.

    Raises:
        FileNotFoundError: if the repo directory or its generate.py is missing.
    """
    if not MUCODEC_ROOT.is_dir():
        raise FileNotFoundError(f"MuCodec directory not found: {MUCODEC_ROOT}")
    entrypoint = MUCODEC_ROOT / "generate.py"
    if not entrypoint.is_file():
        raise FileNotFoundError(
            f"MuCodec entrypoint not found: {entrypoint}"
        )
    return str(MUCODEC_ROOT)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def import_mucodec_class():
    """Import the MuCodec class from the vendored repo's generate.py.

    Returns:
        Tuple of (MuCodec class, repo path string).

    Raises:
        ImportError: when the module or attribute cannot be loaded; the
            underlying failure is chained as ``__cause__``.
    """
    repo_path = get_mucodec_root()
    ensure_sys_path(repo_path)
    try:
        module = importlib.import_module("generate")
        return getattr(module, "MuCodec"), repo_path
    except Exception as exc:  # pragma: no cover - env dependent
        # Chain the original exception so the real root cause (syntax error,
        # missing dependency, ...) is not hidden by the wrapper message.
        raise ImportError(
            f"Could not import MuCodec from {repo_path}/generate.py: {exc}"
        ) from exc
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def build_mucodec_decoder(args: argparse.Namespace) -> Any:
    """Construct the official MuCodec decoder from the vendored repo.

    Validates that the checkpoint and its local asset files exist, then
    instantiates MuCodec with the cwd set to the repo (it loads assets via
    relative paths).  The resolved repo path is attached to the instance as
    ``_magel_mucodec_repo`` so later decode calls can pushd back into it.

    Raises:
        FileNotFoundError: when the checkpoint or a required asset is missing.
    """
    MuCodec, resolved_repo = import_mucodec_class()

    ckpt_path = os.path.join(resolved_repo, "ckpt", "mucodec.pt")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"MuCodec checkpoint not found: {ckpt_path}")

    # Assets MuCodec loads internally via paths relative to its repo root.
    required_local_files = [
        os.path.join(resolved_repo, "tools", "audioldm_48k.pth"),
        os.path.join(resolved_repo, "muq_dev", "muq.pt"),
    ]
    for path in required_local_files:
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"Required MuCodec dependency not found for current folder structure: {path}"
            )

    mucodec_device = resolve_runtime_device_str(args.mucodec_device)
    if resolved_repo:
        print(f"[INFO] resolved MuCodec repo: {resolved_repo}")
    print(f"[INFO] loading MuCodec from {ckpt_path} on {mucodec_device}")
    # Instantiate inside the repo so its relative asset paths resolve.
    with pushd(resolved_repo):
        decoder = MuCodec(
            model_path=ckpt_path,
            layer_num=int(args.mucodec_layer_num),
            load_main_model=True,
            device=mucodec_device,
        )
    setattr(decoder, "_magel_mucodec_repo", resolved_repo)
    return decoder
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def decode_mucodec_codes(
    mucodec_decoder: Any,
    shifted_codes: np.ndarray,
    args: argparse.Namespace,
) -> torch.Tensor:
    """Decode a 1D stream of MuCodec token ids to a waveform.

    Args:
        mucodec_decoder: instance built by ``build_mucodec_decoder``.
        shifted_codes: 1D int array of codebook ids (already shifted back
            into MuCodec's own id range).
        args: CLI namespace supplying the code2sound decode parameters.

    Returns:
        Float32 CPU tensor of shape [channels, samples] (1D input from the
        decoder is promoted to a single channel).

    Raises:
        ValueError: when ``shifted_codes`` is not 1D.
    """
    if shifted_codes.ndim != 1:
        raise ValueError(
            f"Expected 1D MuCodec token stream, got shape {shifted_codes.shape}"
        )

    codes = torch.from_numpy(shifted_codes.astype(np.int64, copy=False))
    # code2sound expects [batch, layer, T]; add both leading dims.
    codes = codes.unsqueeze(0).unsqueeze(0)
    # Decode from inside the MuCodec repo so its relative asset paths resolve.
    repo_path = getattr(mucodec_decoder, "_magel_mucodec_repo", "")
    decode_ctx = pushd(repo_path) if repo_path else contextlib.nullcontext()
    with decode_ctx:
        wave = mucodec_decoder.code2sound(
            codes,
            prompt=None,
            duration=float(args.mucodec_duration),
            guidance_scale=float(args.mucodec_guidance_scale),
            num_steps=int(args.mucodec_num_steps),
            disable_progress=True,
        )
    if not torch.is_tensor(wave):
        wave = torch.as_tensor(wave)
    if wave.ndim == 1:
        wave = wave.unsqueeze(0)
    return wave.detach().cpu().to(torch.float32)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def build_segment_spans(
    template_ids: torch.Tensor,
    is_audio_codebook: torch.Tensor,
    layout: TokenLayout,
) -> list[SegmentSpan]:
    """Pair each BOS_AUDIO with the next following EOS_AUDIO.

    Args:
        template_ids: [T] token scaffold containing BOS/EOS audio markers.
        is_audio_codebook: [T] bool mask of audio codebook positions.
        layout: id layout supplying the BOS/EOS special token ids.

    Returns:
        One SegmentSpan per matched BOS/EOS pair, each listing the audio
        codebook positions strictly between its markers.  Unmatched trailing
        BOS markers are dropped.
    """
    bos_positions = torch.where(template_ids.eq(layout.bos_audio))[0].tolist()
    eos_positions = torch.where(template_ids.eq(layout.eos_audio))[0].tolist()
    if not bos_positions or not eos_positions:
        return []

    # Hoisted out of the loop: the position index is loop-invariant, and
    # rebuilding it per BOS made span construction quadratic in T.
    idx = torch.arange(template_ids.numel(), device=template_ids.device)

    spans: list[SegmentSpan] = []
    eos_ptr = 0
    for b in bos_positions:
        # Advance to the first EOS strictly after this BOS.
        while eos_ptr < len(eos_positions) and eos_positions[eos_ptr] <= b:
            eos_ptr += 1
        if eos_ptr >= len(eos_positions):
            break
        e = eos_positions[eos_ptr]
        eos_ptr += 1
        mask = is_audio_codebook & (idx > b) & (idx < e)
        audio_positions = torch.where(mask)[0].tolist()
        spans.append(
            SegmentSpan(
                seg_idx=len(spans),
                bos_pos=int(b),
                eos_pos=int(e),
                audio_positions=[int(p) for p in audio_positions],
            )
        )
    return spans
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def load_hf_template_sample(
    dataset_path: str,
    split: str,
    tokenizer_path: str,
    sample_idx: int,
    num_audio_codebook: int,
) -> HFTemplateSample:
    """Load one dataset row and prepare it for segment-wise AR inference.

    Convenience wrapper: opens the music dataset, then delegates to
    ``load_hf_template_sample_from_music_dataset`` for the actual scaffold
    preparation.
    """
    music_ds = load_music_dataset(
        dataset_path=dataset_path,
        split=split,
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_codebook,
        use_fast=True,
    )
    return load_hf_template_sample_from_music_dataset(
        music_ds=music_ds,
        sample_idx=sample_idx,
        num_audio_codebook=num_audio_codebook,
    )
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def load_hf_template_sample_from_music_dataset(
    music_ds,
    sample_idx: int,
    num_audio_codebook: int,
) -> HFTemplateSample:
    """Extract sample ``sample_idx`` from an already-open music dataset.

    Builds the token layout, pulls the aligned id/condition tensors from the
    dataset row, validates that all control streams share the template length,
    masks ground-truth audio tokens out of the input scaffold, and computes
    the BOS/EOS segment spans.

    Raises:
        ValueError: if chord/structure/condition streams disagree in length
            with the token template.
    """
    layout = TokenLayout(
        num_text_token=music_ds.num_text_token,
        num_audio_codebook=num_audio_codebook,
    )

    raw_item = music_ds._data[sample_idx]
    row = music_ds[sample_idx]

    template_ids = row["token_ids"].to(torch.long)
    chord_ids = row["chord_ids"].to(torch.long)
    structure_ids = row["structure_ids"].to(torch.long)
    condition_mask = row["condition_mask"].to(torch.bool)

    # All aligned streams must match the template length exactly.
    seq_len = int(template_ids.numel())
    aligned_streams = {
        "chord_ids": chord_ids,
        "structure_ids": structure_ids,
        "condition_mask": condition_mask,
    }
    for name, tensor in aligned_streams.items():
        if int(tensor.numel()) != seq_len:
            raise ValueError(f"{name} length mismatch: {int(tensor.numel())} != {seq_len}")

    in_audio_range = template_ids >= layout.audio_start
    below_audio_end = template_ids < layout.audio_end
    is_audio_codebook = in_audio_range & below_audio_end
    is_eos = template_ids.eq(layout.eos_audio)

    # Remove GT audio token values from input scaffold.
    input_ids = template_ids.clone()
    input_ids[is_audio_codebook] = layout.mask_audio

    spans = build_segment_spans(template_ids, is_audio_codebook, layout)

    return HFTemplateSample(
        song_id=str(raw_item.get("song_id", f"sample_{sample_idx}")),
        num_text_token=music_ds.num_text_token,
        template_ids=template_ids,
        input_ids=input_ids,
        chord_ids=chord_ids,
        structure_ids=structure_ids,
        condition_mask=condition_mask,
        is_audio_codebook=is_audio_codebook,
        is_eos=is_eos,
        segments=spans,
        raw_item=raw_item,
    )
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def apply_top_k_top_p(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """Mask logits outside the top-k set and outside the top-p nucleus with -inf.

    Top-k (when ``top_k > 0``) is applied first, then nucleus filtering (when
    ``0 < top_p < 1``). The highest-probability token is always kept.
    Expects ``logits`` of shape (batch, vocab); returns a filtered tensor.
    """
    neg_inf = float("-inf")

    if top_k is not None and top_k > 0:
        keep = min(top_k, logits.shape[-1])
        # k-th largest value per row; everything strictly below it is dropped.
        kth_value = torch.topk(logits, keep, dim=-1).values[:, -1:]
        logits = logits.masked_fill(logits < kth_value, neg_inf)

    if top_p is not None and 0.0 < top_p < 1.0:
        desc_logits, desc_idx = torch.sort(logits, descending=True, dim=-1)
        cdf = torch.cumsum(torch.softmax(desc_logits, dim=-1), dim=-1)
        drop = cdf > top_p
        drop[:, 0] = False  # never drop the most likely token
        desc_logits = desc_logits.masked_fill(drop, neg_inf)
        # Scatter the filtered values back to their original vocab positions.
        restored = torch.full_like(logits, neg_inf)
        restored.scatter_(dim=-1, index=desc_idx, src=desc_logits)
        logits = restored

    return logits
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def sample_from_logits(
    logits: torch.Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
) -> int:
    """Draw one token id from a (batch=1, vocab) logits row.

    Greedy (or non-positive temperature) returns the argmax. Otherwise the
    logits are temperature-scaled, top-k/top-p filtered, and sampled
    multinomially.

    Raises:
        RuntimeError: if filtering leaves no finite logit to sample from.
    """
    if greedy or temperature <= 0:
        return int(torch.argmax(logits, dim=-1).item())

    scaled = logits / max(temperature, 1e-6)
    filtered = apply_top_k_top_p(scaled, top_k=top_k, top_p=top_p)
    if not torch.isfinite(filtered).any():
        raise RuntimeError("All logits are -inf after filtering.")

    probs = torch.softmax(filtered, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return int(choice.item())
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def sample_audio_token_from_logits(
    logits: torch.Tensor,
    layout: TokenLayout,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
) -> int:
    """Sample an audio token id, restricted to the audio slice of the vocab.

    The logits are narrowed to ``[layout.audio_start, layout.audio_end)`` so
    only audio-codebook tokens can be drawn; the sampled index is shifted back
    into absolute vocabulary space before returning.
    """
    audio_slice = logits[:, layout.audio_start : layout.audio_end]
    relative_idx = sample_from_logits(
        audio_slice,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        greedy=greedy,
    )
    return int(layout.audio_start + relative_idx)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def chord_id_to_type(chord_id: int) -> str:
    """Map a chord id to its label, flagging ids that decode to an unexpected "N".

    Ids decoding to something other than "N", plus id 1 and the BOS/EOS ids
    (for which "N" is legitimate), return the decoded label; anything else is
    reported as ``unknown_<id>``.
    """
    label = chord_id_to_label(chord_id)
    if label != "N":
        return label
    if chord_id in (1, CHORD_BOS_ID, CHORD_EOS_ID):
        return label
    return f"unknown_{chord_id}"
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def segment_id_to_type(segment_id: int) -> str:
    """Map a structure id to its label, or ``unknown_<id>`` when out of range.

    Ids inside ``[0, STRUCTURE_EOS_ID]`` are trusted and decoded via
    ``structure_id_to_label``; anything outside is flagged.
    """
    if 0 <= segment_id <= STRUCTURE_EOS_ID:
        return structure_id_to_label(segment_id)
    return f"unknown_{segment_id}"
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def to_intervals(type_ids: list[int], fps: int, mapper) -> list[dict[str, Any]]:
    """Run-length encode per-frame ids into ``{start, end, type}`` time intervals.

    Consecutive equal ids collapse into one interval; frame indices are
    converted to seconds via ``fps`` and rounded to 6 decimals, and ``mapper``
    turns each id into its display type.
    """
    intervals: list[dict[str, Any]] = []
    if not type_ids:
        return intervals

    total = len(type_ids)
    run_start = 0
    # Walk one past the end so the final run is always flushed.
    for pos in range(1, total + 1):
        if pos < total and type_ids[pos] == type_ids[run_start]:
            continue
        intervals.append(
            {
                "start": round(run_start / float(fps), 6),
                "end": round(pos / float(fps), 6),
                "type": mapper(int(type_ids[run_start])),
            }
        )
        run_start = pos
    return intervals
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def merge_same_type_with_small_gap(
    intervals: list[dict[str, Any]], fps: int, max_gap_frames: int = 1
) -> list[dict[str, Any]]:
    """Fuse adjacent same-type intervals separated by at most ``max_gap_frames``.

    Input intervals are copied (callers' dicts are never mutated). Two
    neighbors merge when their ``type`` matches and the start/end gap is within
    ``max_gap_frames / fps`` seconds (plus a tiny float tolerance).
    """
    if not intervals:
        return []

    # Tolerance absorbs float rounding from the interval construction.
    gap_limit_s = float(max_gap_frames) / float(fps) + 1e-9
    fused = [dict(intervals[0])]
    for item in intervals[1:]:
        last = fused[-1]
        close_enough = float(item["start"]) - float(last["end"]) <= gap_limit_s
        if last.get("type") == item.get("type") and close_enough:
            last["end"] = item["end"]
        else:
            fused.append(dict(item))
    return fused
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
@torch.inference_mode()
def generate_segmentwise(
    model: MAGEL,
    sample: HFTemplateSample,
    layout: TokenLayout,
    device: torch.device,
    use_cache: bool,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
    max_audio_tokens: int,
) -> tuple[torch.Tensor, int, list[int], list[int]]:
    """Autoregressively fill the audio slots of one template sample.

    Walks the template from the first audio/EOS slot to the last segment's
    EOS. Audio-codebook positions are sampled from the model; EOS positions
    are forced to ``layout.eos_audio``; all other positions are copied from
    the (masked) input scaffold. With ``use_cache`` the model is prefixed
    once and then stepped token-by-token through its KV cache; otherwise the
    full growing prefix is re-run every step.

    Returns:
        (generated_ids on CPU, sampled audio-token count,
         chord ids aligned to sampled positions,
         structure ids aligned to sampled positions).
    """
    import time

    seq_template = sample.input_ids.to(device)
    chord_template = sample.chord_ids.to(device)
    structure_template = sample.structure_ids.to(device)
    condition_mask_template = sample.condition_mask.to(device)
    is_audio_code = sample.is_audio_codebook.to(device)
    is_eos = sample.is_eos.to(device)

    # Generation slots are audio-code positions plus forced-EOS positions.
    slot_positions = torch.where(is_audio_code | is_eos)[0]
    if slot_positions.numel() == 0:
        # No generation slot: return scaffold as-is.
        return seq_template.detach().cpu(), 0, [], []

    start_pos = int(slot_positions[0].item())
    if sample.segments:
        # Stop at the last segment's EOS rather than the last raw slot.
        end_pos = int(sample.segments[-1].eos_pos)
    else:
        end_pos = int(slot_positions[-1].item())

    sampled_chord_ids: list[int] = []
    sampled_segment_ids: list[int] = []

    generated_ids = seq_template.clone()
    sampled_count = 0
    past_key_values: Optional[tuple] = None

    # Precompute full-sequence condition once so cached decoding keeps
    # the same global condition-encoder context as training.
    cond_template: torch.Tensor = model.condition_encoder(
        chord_template.unsqueeze(0),
        structure_template.unsqueeze(0),
    )

    # Prefill with fixed prefix.
    full_attention_mask = torch.ones(
        (1, sample.seq_len), dtype=torch.long, device=device
    )
    prefix_ids = generated_ids[:start_pos].unsqueeze(0)
    prefix_attn = full_attention_mask[:, :start_pos]
    model_kwargs = dict(
        input_ids=prefix_ids,
        attention_mask=prefix_attn,
        condition_mask=condition_mask_template[:start_pos].unsqueeze(0),
        cond_precomputed=cond_template[:, :start_pos, :],
        use_cache=use_cache,
    )
    maybe_mark_compile_step_begin(model)
    prefill_t0 = time.perf_counter()
    out = model(**model_kwargs)
    prefill_time_s = time.perf_counter() - prefill_t0
    logits_next = out.logits[:, -1, :]
    if use_cache:
        past_key_values = out.past_key_values
    # Reusable (1, 1) buffer for single-token cached steps.
    step_ids = torch.empty((1, 1), dtype=torch.long, device=device)

    decode_time_s = 0.0
    for i in range(start_pos, end_pos + 1):
        if bool(is_audio_code[i].item()):
            # Respect the optional hard cap on sampled audio tokens.
            if max_audio_tokens > 0 and sampled_count >= max_audio_tokens:
                break
            next_id = sample_audio_token_from_logits(
                logits_next,
                layout=layout,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                greedy=greedy,
            )
            sampled_count += 1
            # Controls are input-aligned to the token sequence.
            cond_pos = i
            sampled_chord_ids.append(int(chord_template[cond_pos].item()))
            sampled_segment_ids.append(int(structure_template[cond_pos].item()))
        elif bool(is_eos[i].item()):
            # EOS slots are forced, never sampled.
            next_id = layout.eos_audio
        else:
            # Scaffold position: copy the template token verbatim.
            next_id = int(seq_template[i].item())

        generated_ids[i] = int(next_id)

        if i >= end_pos:
            break

        if use_cache:
            # Feed only the newly committed token; the KV cache holds the
            # i+1 previous positions, so the mask covers i+2 positions.
            step_ids[0, 0] = int(next_id)
            step_attn = full_attention_mask[:, : i + 2]
            model_kwargs = dict(
                input_ids=step_ids,
                attention_mask=step_attn,
                condition_mask=condition_mask_template[i : i + 1].unsqueeze(0),
                cond_precomputed=cond_template[:, i : i + 1, :],
                past_key_values=past_key_values,
                use_cache=True,
            )
            maybe_mark_compile_step_begin(model)
            step_t0 = time.perf_counter()
            out = model(**model_kwargs)
            decode_time_s += time.perf_counter() - step_t0
            logits_next = out.logits[:, -1, :]
            past_key_values = out.past_key_values
        else:
            # Cache-free path: re-run the model on the entire prefix so far.
            cur_len = i + 1
            model_kwargs = dict(
                input_ids=generated_ids[:cur_len].unsqueeze(0),
                attention_mask=full_attention_mask[:, :cur_len],
                condition_mask=condition_mask_template[:cur_len].unsqueeze(0),
                cond_precomputed=cond_template[:, :cur_len, :],
                use_cache=False,
            )
            maybe_mark_compile_step_begin(model)
            step_t0 = time.perf_counter()
            out = model(**model_kwargs)
            decode_time_s += time.perf_counter() - step_t0
            logits_next = out.logits[:, -1, :]

    total_gen_time_s = prefill_time_s + decode_time_s
    tokens_per_second = (
        float(sampled_count) / decode_time_s if decode_time_s > 0 and sampled_count > 0 else 0.0
    )
    print(
        "[PROFILE] generation "
        f"prefill_s={prefill_time_s:.3f} "
        f"decode_s={decode_time_s:.3f} "
        f"total_s={total_gen_time_s:.3f} "
        f"sampled_audio_tokens={sampled_count} "
        f"decode_tok_per_s={tokens_per_second:.3f}"
    )

    return (
        generated_ids.detach().cpu(),
        sampled_count,
        sampled_chord_ids,
        sampled_segment_ids,
    )
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
@torch.inference_mode()
def batch_generate_segmentwise(
    model: MAGEL,
    samples: list[HFTemplateSample],
    layout: TokenLayout,
    device: torch.device,
    use_cache: bool,
    temperature: float,
    top_k: int,
    top_p: float,
    greedy: bool,
    max_audio_tokens: int,
) -> list[tuple[torch.Tensor, int, list[int], list[int]]]:
    """Batched variant of ``generate_segmentwise`` for KV-cached decoding.

    Samples are right-padded to the longest template; rows advance in
    lockstep (row r decodes position ``start_positions[r] + step_idx``), so
    decoding stops once every row has passed its own end position. Without
    ``use_cache`` this simply falls back to per-sample generation.

    Returns one (generated_ids, sampled_count, chord_ids, segment_ids)
    tuple per input sample, in order.
    """
    import time

    if not samples:
        return []
    if not use_cache:
        # The batched fast path is cache-only; fall back to the serial routine.
        return [
            generate_segmentwise(
                model=model,
                sample=sample,
                layout=layout,
                device=device,
                use_cache=use_cache,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                greedy=greedy,
                max_audio_tokens=max_audio_tokens,
            )
            for sample in samples
        ]

    batch_size = len(samples)
    seq_lens = [sample.seq_len for sample in samples]
    max_seq_len = max(seq_lens)

    # Right-padded per-row buffers; positions past seq_len stay zero/False.
    seq_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    generated_ids = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    chord_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    structure_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
    condition_mask_templates = torch.zeros(
        (batch_size, max_seq_len), dtype=torch.bool, device=device
    )
    is_audio_code_templates = torch.zeros(
        (batch_size, max_seq_len), dtype=torch.bool, device=device
    )
    is_eos_templates = torch.zeros((batch_size, max_seq_len), dtype=torch.bool, device=device)

    start_positions: list[int] = []
    end_positions: list[int] = []
    sampled_counts = [0 for _ in samples]
    sampled_chord_ids: list[list[int]] = [[] for _ in samples]
    sampled_segment_ids: list[list[int]] = [[] for _ in samples]
    # Rows with no slots (or that hit the audio-token cap) are switched off here.
    valid_sample_mask = torch.ones(batch_size, dtype=torch.bool, device=device)

    for row_idx, sample in enumerate(samples):
        seq_templates[row_idx, : sample.seq_len] = sample.input_ids.to(device)
        generated_ids[row_idx, : sample.seq_len] = sample.input_ids.to(device)
        chord_templates[row_idx, : sample.seq_len] = sample.chord_ids.to(device)
        structure_templates[row_idx, : sample.seq_len] = sample.structure_ids.to(device)
        condition_mask_templates[row_idx, : sample.seq_len] = sample.condition_mask.to(device)
        is_audio_code_templates[row_idx, : sample.seq_len] = sample.is_audio_codebook.to(device)
        is_eos_templates[row_idx, : sample.seq_len] = sample.is_eos.to(device)

        slot_positions = torch.where(
            is_audio_code_templates[row_idx, : sample.seq_len]
            | is_eos_templates[row_idx, : sample.seq_len]
        )[0]
        if slot_positions.numel() == 0:
            # Nothing to generate for this row; mark it inert with a
            # degenerate (start > end) range so the step loop skips it.
            valid_sample_mask[row_idx] = False
            start_positions.append(sample.seq_len)
            end_positions.append(sample.seq_len - 1)
            continue
        start_pos = int(slot_positions[0].item())
        if sample.segments:
            end_pos = int(sample.segments[-1].eos_pos)
        else:
            end_pos = int(slot_positions[-1].item())
        start_positions.append(start_pos)
        end_positions.append(end_pos)

    if not bool(valid_sample_mask.any().item()):
        # Every row is slot-free: echo the scaffolds back unchanged.
        return [
            (sample.input_ids.detach().cpu(), 0, [], [])
            for sample in samples
        ]

    start_positions_t = torch.tensor(start_positions, dtype=torch.long, device=device)
    end_positions_t = torch.tensor(end_positions, dtype=torch.long, device=device)
    prefix_lens = start_positions_t.clone()
    max_prefix_len = int(prefix_lens.max().item())
    max_decode_steps = int((end_positions_t - start_positions_t + 1).clamp_min(0).max().item())

    # Full-sequence conditions computed once, reused per cached step.
    cond_template = model.condition_encoder(chord_templates, structure_templates)

    # Per-row prefix mask: 1 for real prefix positions, 0 for right padding.
    prefix_attention_mask = (
        torch.arange(max_prefix_len, device=device).unsqueeze(0) < prefix_lens.unsqueeze(1)
    ).to(torch.long)
    prefill_t0 = time.perf_counter()
    maybe_mark_compile_step_begin(model)
    out = model(
        input_ids=generated_ids[:, :max_prefix_len],
        attention_mask=prefix_attention_mask,
        condition_mask=condition_mask_templates[:, :max_prefix_len],
        cond_precomputed=cond_template[:, :max_prefix_len, :],
        use_cache=True,
    )
    prefill_time_s = time.perf_counter() - prefill_t0

    # Each row reads its logits at its own last real prefix position.
    gather_idx = (prefix_lens - 1).clamp_min(0)
    batch_indices = torch.arange(batch_size, device=device)
    logits_next = out.logits[batch_indices, gather_idx, :]
    past_key_values = out.past_key_values

    step_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
    # decode_valid_mask[r, s] is True once row r committed a token at step s;
    # it doubles as the attention mask for the decoded suffix.
    decode_valid_mask = torch.zeros(
        (batch_size, max_decode_steps), dtype=torch.bool, device=device
    )
    decode_time_s = 0.0

    for step_idx in range(max_decode_steps):
        cur_positions = start_positions_t + step_idx
        active_mask = valid_sample_mask & cur_positions.le(end_positions_t)
        if not bool(active_mask.any().item()):
            break

        next_ids = torch.zeros(batch_size, dtype=torch.long, device=device)
        for row_idx in range(batch_size):
            if not bool(active_mask[row_idx].item()):
                continue
            cur_pos = int(cur_positions[row_idx].item())
            if bool(is_audio_code_templates[row_idx, cur_pos].item()):
                if max_audio_tokens > 0 and sampled_counts[row_idx] >= max_audio_tokens:
                    # Cap reached: retire this row for all remaining steps.
                    valid_sample_mask[row_idx] = False
                    continue
                next_id = sample_audio_token_from_logits(
                    logits_next[row_idx : row_idx + 1],
                    layout=layout,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    greedy=greedy,
                )
                sampled_counts[row_idx] += 1
                sampled_chord_ids[row_idx].append(
                    int(chord_templates[row_idx, cur_pos].item())
                )
                sampled_segment_ids[row_idx].append(
                    int(structure_templates[row_idx, cur_pos].item())
                )
            elif bool(is_eos_templates[row_idx, cur_pos].item()):
                # Forced EOS slot.
                next_id = layout.eos_audio
            else:
                # Scaffold token copied verbatim from the template.
                next_id = int(seq_templates[row_idx, cur_pos].item())

            generated_ids[row_idx, cur_pos] = int(next_id)
            next_ids[row_idx] = int(next_id)
            decode_valid_mask[row_idx, step_idx] = True

        if step_idx >= max_decode_steps - 1:
            break

        step_ids[:, 0] = next_ids
        # Mask = per-row prefix + all tokens committed so far this decode.
        step_attention_mask = torch.cat(
            [
                prefix_attention_mask,
                decode_valid_mask[:, : step_idx + 1].to(torch.long),
            ],
            dim=1,
        )
        step_condition_mask = torch.zeros((batch_size, 1), dtype=torch.bool, device=device)
        step_cond = torch.zeros(
            (batch_size, 1, cond_template.shape[-1]),
            dtype=cond_template.dtype,
            device=device,
        )
        for row_idx in range(batch_size):
            if not bool(decode_valid_mask[row_idx, step_idx].item()):
                continue
            cur_pos = int(cur_positions[row_idx].item())
            step_condition_mask[row_idx, 0] = condition_mask_templates[row_idx, cur_pos]
            step_cond[row_idx, 0, :] = cond_template[row_idx, cur_pos, :]

        step_t0 = time.perf_counter()
        maybe_mark_compile_step_begin(model)
        out = model(
            input_ids=step_ids,
            attention_mask=step_attention_mask,
            condition_mask=step_condition_mask,
            cond_precomputed=step_cond,
            past_key_values=past_key_values,
            use_cache=True,
        )
        decode_time_s += time.perf_counter() - step_t0
        logits_next = out.logits[:, -1, :]
        past_key_values = out.past_key_values

    total_sampled_tokens = sum(sampled_counts)
    total_gen_time_s = prefill_time_s + decode_time_s
    tokens_per_second = (
        float(total_sampled_tokens) / decode_time_s
        if decode_time_s > 0 and total_sampled_tokens > 0
        else 0.0
    )
    print(
        "[PROFILE] batch_generation "
        f"batch_size={batch_size} "
        f"prefill_s={prefill_time_s:.3f} "
        f"decode_s={decode_time_s:.3f} "
        f"total_s={total_gen_time_s:.3f} "
        f"sampled_audio_tokens={total_sampled_tokens} "
        f"decode_tok_per_s={tokens_per_second:.3f}"
    )

    outputs: list[tuple[torch.Tensor, int, list[int], list[int]]] = []
    for row_idx, sample in enumerate(samples):
        if not bool((torch.where(sample.is_audio_codebook | sample.is_eos)[0]).numel()):
            # Slot-free row: return its untouched scaffold.
            outputs.append((sample.input_ids.detach().cpu(), 0, [], []))
            continue
        outputs.append(
            (
                generated_ids[row_idx, : sample.seq_len].detach().cpu(),
                sampled_counts[row_idx],
                sampled_chord_ids[row_idx],
                sampled_segment_ids[row_idx],
            )
        )
    return outputs
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
def save_outputs(
    output_dir: str,
    output_prefix: str,
    sample: HFTemplateSample,
    layout: TokenLayout,
    generated_ids: torch.Tensor,
    sampled_chord_ids: list[int],
    sampled_segment_ids: list[int],
    args: argparse.Namespace,
    mucodec_decoder: Any = None,
) -> None:
    """Persist one generation result as a decoded wav and a chord/segment json.

    Audio tokens are extracted from ``generated_ids`` by vocab range, shifted
    back to codebook space, and decoded via MuCodec (skipped with a warning
    when none were generated). Chord/structure id streams are run-length
    encoded to timed intervals, "pad" entries dropped, and small gaps merged
    before writing the json sidecar.
    """
    import time

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Caller-supplied prefix wins; otherwise derive one from song id + time.
    prefix = output_prefix or f"{sample.song_id}_{args.sample_idx}_{stamp}"

    json_dir = args.json_output_dir or os.path.join(output_dir, "json")
    wav_dir = args.wav_output_dir or os.path.join(output_dir, "wav")
    Path(json_dir).mkdir(parents=True, exist_ok=True)
    Path(wav_dir).mkdir(parents=True, exist_ok=True)

    json_path = os.path.join(json_dir, f"{prefix}.chord_segment.json")
    wav_path = os.path.join(wav_dir, f"{prefix}.wav")

    gen_full = generated_ids.cpu().numpy().astype(np.int64)

    # Keep only tokens inside the audio vocab slice, then shift them back
    # into raw MuCodec codebook indices.
    gen_audio_raw = gen_full[
        (gen_full >= layout.audio_start) & (gen_full < layout.audio_end)
    ]
    gen_audio_shift = gen_audio_raw - layout.audio_start

    save_t0 = time.perf_counter()
    if gen_audio_shift.size == 0:
        print("[WARN] No generated MuCodec tokens; skipping wav decode.")
    else:
        # Imported lazily so json-only runs don't require torchaudio.
        import torchaudio

        wave = decode_mucodec_codes(mucodec_decoder, gen_audio_shift, args)
        torchaudio.save(wav_path, wave, int(args.mucodec_sample_rate))
        print(f"[OK] {wav_path}")

    chord_intervals = to_intervals(
        sampled_chord_ids, fps=int(args.fps), mapper=chord_id_to_type
    )
    segment_intervals = to_intervals(
        sampled_segment_ids, fps=int(args.fps), mapper=segment_id_to_type
    )

    # PAD is used for EOS-related conditioning; drop it in exported json.
    chord_intervals = [x for x in chord_intervals if x.get("type") != "pad"]
    segment_intervals = [x for x in segment_intervals if x.get("type") != "pad"]
    # Re-fuse same-type neighbors split only by the dropped pad frames.
    chord_intervals = merge_same_type_with_small_gap(
        chord_intervals, fps=int(args.fps), max_gap_frames=1
    )
    segment_intervals = merge_same_type_with_small_gap(
        segment_intervals, fps=int(args.fps), max_gap_frames=1
    )

    chord_segment = {
        "song_id": sample.song_id,
        "sample_idx": int(args.sample_idx),
        "fps": int(args.fps),
        "generated_audio_count": int(gen_audio_raw.shape[0]),
        "chord": chord_intervals,
        "segment": segment_intervals,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(chord_segment, f, ensure_ascii=False, indent=2)

    print(f"[OK] {json_path}")
    save_time_s = time.perf_counter() - save_t0
    print(
        "[PROFILE] save "
        f"save_s={save_time_s:.3f} "
        f"generated_audio_count={int(gen_audio_raw.shape[0])}"
    )
|
| 988 |
+
|
| 989 |
+
|
| 990 |
+
def main() -> None:
    """CLI entry point: load model + one HF sample, generate, and save outputs.

    Pipeline: parse args and seed RNGs, resolve device/dtype (CPU forces
    float32), load the MAGEL checkpoint (optionally torch.compile'd), pull one
    template sample from the dataset, run segment-wise generation, then write
    the wav/json artifacts via ``save_outputs``.
    """
    import time

    args = parse_args()
    seed_everything(args.seed)

    # --no-cache overrides --use-cache.
    use_cache = args.use_cache and not args.no_cache

    device = resolve_device(args.device)
    dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[args.dtype]
    if device.type == "cpu" and dtype != torch.float32:
        # Half precision on CPU is frequently unsupported; degrade safely.
        print(f"[WARN] dtype {dtype} on CPU may be unsupported; fallback to float32.")
        dtype = torch.float32

    print(f"[INFO] device={device}, dtype={dtype}, use_cache={use_cache}")
    print(f"[INFO] loading model from {args.model_path}")
    model = load_magel_checkpoint(
        checkpoint_path=args.model_path,
        device=device,
        dtype=dtype,
        attn_implementation=args.attn_implementation,
    )
    model = maybe_compile_model(
        model,
        enabled=bool(args.compile),
        mode=str(args.compile_mode),
    )
    # Prefer the CLI override, else the value stored in the checkpoint config.
    num_audio_codebook = (
        int(args.num_audio_codebook)
        if args.num_audio_codebook is not None
        else int(getattr(model.config, "magel_num_audio_token", 16384))
    )
    print(f"[INFO] num_audio_codebook={num_audio_codebook}")

    print(f"[INFO] loading HF sample idx={args.sample_idx} from {args.dataset_path}")
    sample = load_hf_template_sample(
        dataset_path=args.dataset_path,
        split=args.split,
        tokenizer_path=args.tokenizer_path,
        sample_idx=args.sample_idx,
        num_audio_codebook=num_audio_codebook,
    )
    layout = TokenLayout(
        num_text_token=sample.num_text_token,
        num_audio_codebook=num_audio_codebook,
    )
    print(
        f"[INFO] song_id={sample.song_id}, seq_len={sample.seq_len}, segments={len(sample.segments)}"
    )
    mucodec_decoder = build_mucodec_decoder(args)
    print("[INFO] running segment-level autoregressive generation...")
    t1 = time.time()
    (
        generated_ids,
        sampled_count,
        sampled_chord_ids,
        sampled_segment_ids,
    ) = generate_segmentwise(
        model=model,
        sample=sample,
        layout=layout,
        device=device,
        use_cache=use_cache,
        temperature=float(args.temperature),
        top_k=int(args.top_k),
        top_p=float(args.top_p),
        greedy=bool(args.greedy),
        max_audio_tokens=max(0, int(args.max_audio_tokens)),
    )

    print(f"[INFO] sampled audio tokens: {sampled_count}")
    print(f"[INFO] output sequence length: {generated_ids.numel()}")
    t2 = time.time()

    print("total time:", t2 - t1)

    save_outputs(
        output_dir=args.output_dir,
        output_prefix=args.output_prefix,
        sample=sample,
        layout=layout,
        generated_ids=generated_ids,
        sampled_chord_ids=sampled_chord_ids,
        sampled_segment_ids=sampled_segment_ids,
        args=args,
        mucodec_decoder=mucodec_decoder,
    )
|
| 1081 |
+
|
| 1082 |
+
|
| 1083 |
+
# Script entry point: run the full generation pipeline when executed directly.
if __name__ == "__main__":
    main()
|
modelling_qwen3.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from typing import Any, Optional
|
| 4 |
+
|
| 5 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 6 |
+
from transformers.models.qwen3.modeling_qwen3 import Qwen3Config, Qwen3ForCausalLM
|
| 7 |
+
from transformers.cache_utils import Cache
|
| 8 |
+
|
| 9 |
+
from decoders import Qwen3DecoderLayerAdaLN
|
| 10 |
+
from condition_encoders import ConditionEncoder
|
| 11 |
+
from vocab import (
|
| 12 |
+
CHORD_BOS_ID,
|
| 13 |
+
CHORD_EOS_ID,
|
| 14 |
+
CHORD_N_ID,
|
| 15 |
+
SEGMENT_FALLBACK_ID,
|
| 16 |
+
STRUCTURE_BOS_ID,
|
| 17 |
+
STRUCTURE_EOS_ID,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class MAGEL(Qwen3ForCausalLM):
    """Qwen3 causal LM specialized for music token generation.

    Differences from the stock Qwen3ForCausalLM:
    - masks-based CE loss (loss is averaged only over positions where
      ``masks`` is nonzero, instead of over all label positions)
    - decoder layers replaced with Qwen3DecoderLayerAdaLN, which accept a
      per-frame conditioning tensor (AdaLN modulation) built from chord and
      structure ids
    """

    def __init__(
        self,
        config: Qwen3Config,
        **kwargs: Any,
    ):
        super().__init__(config)

        # AdaLN conditioning dimension mirrors the model width.
        adaln_dim = int(config.hidden_size)
        # Per-batch-row probabilities of triggering span dropout on the
        # chord / structure condition streams (classifier-free-guidance style).
        chord_dropout_trigger_prob = float(config.magel_chord_dropout_trigger_prob)
        structure_dropout_trigger_prob = float(config.magel_structure_dropout_trigger_prob)

        self.vocab_size = config.vocab_size
        self.adaln_dim = adaln_dim

        # Encodes (chord_ids, structure_ids) into the AdaLN conditioning tensor.
        self.condition_encoder = ConditionEncoder(hidden_size=adaln_dim)
        self.chord_dropout_trigger_prob = chord_dropout_trigger_prob
        self.structure_dropout_trigger_prob = structure_dropout_trigger_prob

        # Swap every stock decoder layer for the AdaLN-conditioned variant.
        # NOTE(review): assumes Qwen3DecoderLayerAdaLN loads the original
        # layer's weights via from_pretrained name matching — confirm.
        for layer_idx in range(len(self.model.layers)):
            self.model.layers[layer_idx] = Qwen3DecoderLayerAdaLN(
                config,
                layer_idx=layer_idx,
                cond_dim=adaln_dim,
            )

        # Persist MAGEL-specific ctor args so checkpoints can be reloaded without
        # out-of-band flags.
        self.config.magel_chord_dropout_trigger_prob = chord_dropout_trigger_prob
        self.config.magel_structure_dropout_trigger_prob = structure_dropout_trigger_prob

        self.post_init()

    @staticmethod
    def _drop_audio_condition_spans(
        ids: torch.LongTensor,
        condition_mask: torch.BoolTensor,
        trigger_prob: float,
        replacement_id: int,
        bos_id: int,
        eos_id: int,
    ) -> torch.LongTensor:
        """Randomly blank out contiguous spans of condition ids (training-time dropout).

        For each batch row, with probability ``trigger_prob``, a uniformly random
        fraction of the eligible positions is overwritten with ``replacement_id``,
        in chunks of at most 25 positions chosen at random offsets. Returns ``ids``
        unchanged when no dropout applies; otherwise returns a modified clone.
        """
        if trigger_prob <= 0.0:
            return ids

        # Only drop aligned audio-condition positions; keep BOS/EOS untouched.
        eligible_mask = condition_mask & (ids != bos_id) & (ids != eos_id)

        if not eligible_mask.any():
            return ids

        dropped = ids.clone()
        # One Bernoulli trigger per batch row.
        trigger_mask = torch.rand(ids.size(0), device=ids.device) < trigger_prob
        # Maximum contiguous span dropped per iteration.
        span_len = 25

        for batch_idx in torch.nonzero(trigger_mask, as_tuple=False).flatten():
            candidate_positions = torch.nonzero(
                eligible_mask[batch_idx], as_tuple=False
            ).flatten()
            num_candidates = int(candidate_positions.numel())
            if num_candidates == 0:
                continue
            # Fraction of eligible positions to drop, drawn uniformly in [0, 1).
            drop_ratio = torch.rand((), device=ids.device).item()
            num_to_drop = int(round(drop_ratio * num_candidates))
            if num_to_drop <= 0:
                continue

            remaining = num_to_drop
            available_positions = candidate_positions.clone()
            while remaining > 0:
                num_available = int(available_positions.numel())
                if num_available == 0:
                    break

                cur_span_len = min(span_len, remaining)
                if num_available <= cur_span_len:
                    # Not enough positions left for a random offset; take them all.
                    start_idx = 0
                    selected_positions = available_positions[:cur_span_len]
                else:
                    max_start = num_available - cur_span_len + 1
                    start_idx = int(
                        torch.randint(0, max_start, (1,), device=ids.device).item()
                    )
                    selected_positions = available_positions[
                        start_idx : start_idx + cur_span_len
                    ]
                dropped[batch_idx, selected_positions] = replacement_id

                # Remove the consumed slots so later spans never overlap this one.
                keep_mask = torch.ones(
                    num_available,
                    dtype=torch.bool,
                    device=ids.device,
                )
                keep_mask[start_idx : start_idx + int(selected_positions.numel())] = False
                available_positions = available_positions[keep_mask]
                remaining -= int(selected_positions.numel())

        return dropped

    def _build_condition(
        self,
        chord_ids: Optional[torch.LongTensor],
        structure_ids: Optional[torch.LongTensor],
        condition_mask: Optional[torch.BoolTensor],
        cond_precomputed: Optional[torch.FloatTensor],
    ) -> Optional[torch.FloatTensor]:
        """Produce the AdaLN conditioning tensor, or None when unconditioned.

        A precomputed tensor short-circuits encoding (useful at inference when
        the condition does not change between decode steps). During training,
        span dropout is applied to both id streams before encoding.
        """
        if cond_precomputed is not None:
            return cond_precomputed
        if chord_ids is None or structure_ids is None:
            return None
        if self.training:
            if condition_mask is None:
                raise ValueError("condition_mask is required during training.")
            chord_ids = self._drop_audio_condition_spans(
                ids=chord_ids,
                condition_mask=condition_mask,
                trigger_prob=self.chord_dropout_trigger_prob,
                replacement_id=CHORD_N_ID,
                bos_id=CHORD_BOS_ID,
                eos_id=CHORD_EOS_ID,
            )
            structure_ids = self._drop_audio_condition_spans(
                ids=structure_ids,
                condition_mask=condition_mask,
                trigger_prob=self.structure_dropout_trigger_prob,
                replacement_id=SEGMENT_FALLBACK_ID,
                bos_id=STRUCTURE_BOS_ID,
                eos_id=STRUCTURE_EOS_ID,
            )
        return self.condition_encoder(chord_ids, structure_ids)

    def ce_loss(
        self,
        logits: torch.FloatTensor,
        labels: Optional[torch.LongTensor],
        masks: Optional[torch.LongTensor],
    ) -> Optional[torch.Tensor]:
        """Next-token cross entropy averaged over mask-selected positions only.

        Returns None when labels or masks are absent, and a zero scalar when no
        position is valid (keeps the training step well-defined on all-padding
        batches).
        """
        if labels is None or masks is None:
            return None

        # Standard causal shift: position t predicts token t+1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].clone()
        valid_token_mask = masks[:, 1:].bool().contiguous()

        if not valid_token_mask.any():
            return shift_logits.new_zeros(())

        # Masked positions are excluded via the ignore_index sentinel.
        shift_labels.masked_fill_(~valid_token_mask, -100)
        loss_sum = F.cross_entropy(
            shift_logits.view(-1, self.config.vocab_size),
            shift_labels.view(-1).to(shift_logits.device),
            ignore_index=-100,
            reduction="sum",
        )
        valid_count = valid_token_mask.sum().to(
            device=loss_sum.device,
            dtype=loss_sum.dtype,
        )
        return loss_sum / valid_count.clamp_min(1)

    def forward(
        self,
        input_ids: torch.LongTensor,
        masks: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        chord_ids: Optional[torch.LongTensor] = None,
        structure_ids: Optional[torch.LongTensor] = None,
        condition_mask: Optional[torch.BoolTensor] = None,
        cond_precomputed: Optional[torch.FloatTensor] = None,
    ) -> CausalLMOutputWithPast:
        """Conditioned forward pass returning logits and optional masked CE loss.

        ``chord_ids``/``structure_ids`` (or ``cond_precomputed``) supply the AdaLN
        condition; ``masks`` selects which label positions contribute to the loss.
        """
        if use_cache is None:
            use_cache = self.config.use_cache

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        cond = self._build_condition(
            chord_ids=chord_ids,
            structure_ids=structure_ids,
            condition_mask=condition_mask,
            cond_precomputed=cond_precomputed,
        )

        # cond_expanded/condition_mask are nonstandard kwargs — presumably
        # consumed by the patched Qwen3DecoderLayerAdaLN stack; verify against
        # decoders.py.
        base_out = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cond_expanded=cond,
            condition_mask=condition_mask,
            cache_position=cache_position,
        )

        hidden_states = base_out.last_hidden_state
        logits = self.lm_head(hidden_states)
        loss = self.ce_loss(logits=logits, labels=labels, masks=masks)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=base_out.past_key_values,
            hidden_states=base_out.hidden_states,
            attentions=base_out.attentions,
        )
|
muse_mucodec_chord.ds/dataset_dict.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"splits": ["train", "validation"]}
|
runtime_utils.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import datasets
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from datasets import DatasetDict
|
| 7 |
+
from transformers import AutoConfig
|
| 8 |
+
|
| 9 |
+
from dataset import MusicDataset
|
| 10 |
+
from modelling_qwen3 import MAGEL
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def seed_everything(seed: int) -> None:
    """Seed the Python, NumPy, and Torch RNGs (including all CUDA devices)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def resolve_device(device_arg: str) -> torch.device:
    """Turn a CLI device string into a torch.device.

    "auto" picks the best available backend (cuda > mps > cpu); any other
    string is passed straight to torch.device.
    """
    if device_arg != "auto":
        return torch.device(device_arg)
    if torch.cuda.is_available():
        return torch.device("cuda")
    backend = "mps" if torch.backends.mps.is_available() else "cpu"
    return torch.device(backend)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def move_batch_to_device(
    batch: dict[str, torch.Tensor], device: torch.device
) -> dict[str, torch.Tensor]:
    """Return a new dict with every tensor value moved to *device*.

    Non-tensor values are carried over unchanged.
    """
    moved: dict[str, torch.Tensor] = {}
    for key, value in batch.items():
        moved[key] = value.to(device) if torch.is_tensor(value) else value
    return moved
|
| 38 |
+
|
| 39 |
+
def load_music_dataset(
    dataset_path: str,
    split: str,
    tokenizer_path: str,
    num_audio_token: int = 16384,
    fps: int = 25,
    use_fast: bool = True,
) -> MusicDataset:
    """Load an on-disk HF dataset and wrap the requested split in a MusicDataset.

    A bare Dataset (no splits) is presented under the requested split name so
    MusicDataset always receives a split-keyed container.
    """
    raw = datasets.load_from_disk(dataset_path)
    if isinstance(raw, DatasetDict):
        if split not in raw:
            raise KeyError(f"Split not found: {split}")
        container = raw
    else:
        container = {split: raw}
    return MusicDataset(
        datasets=container,
        split=split,
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_token,
        fps=fps,
        use_fast=use_fast,
    )
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def load_magel_checkpoint(
    checkpoint_path: str,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
    attn_implementation: str = "sdpa",
) -> MAGEL:
    """Restore a trained MAGEL checkpoint for inference.

    Loads config and weights from local files only, moves the model to
    *device*, and switches it to eval mode.
    """
    cfg = AutoConfig.from_pretrained(
        checkpoint_path,
        local_files_only=True,
    )

    magel = MAGEL.from_pretrained(
        checkpoint_path,
        config=cfg,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
        local_files_only=True,
    )
    magel.to(device=device)
    magel.eval()
    return magel
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def maybe_compile_model(
    model,
    enabled: bool = False,
    mode: str = "reduce-overhead",
):
    """Optionally wrap *model* with torch.compile.

    The returned object carries a `_magel_is_compiled` flag so later helpers
    know whether CUDA-graph step markers are required. Raises RuntimeError if
    compilation is requested on a PyTorch build without torch.compile.
    """
    if not enabled:
        model._magel_is_compiled = False
        return model
    if not hasattr(torch, "compile"):
        raise RuntimeError("torch.compile is not available in this PyTorch build.")
    compiled = torch.compile(model, mode=mode)
    compiled._magel_is_compiled = True
    return compiled
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def maybe_mark_compile_step_begin(model) -> None:
    """Signal the start of a new step to the CUDA-graph machinery, if relevant.

    Does nothing unless *model* was produced by maybe_compile_model(enabled=True)
    and the running torch exposes `torch.compiler.cudagraph_mark_step_begin`.
    """
    if not getattr(model, "_magel_is_compiled", False):
        return
    compiler_ns = getattr(torch, "compiler", None)
    mark_step_begin = getattr(compiler_ns, "cudagraph_mark_step_begin", None)
    if mark_step_begin is not None:
        mark_step_begin()
|
train.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Train MAGEL directly from a vanilla Qwen3 checkpoint.
|
| 5 |
+
|
| 6 |
+
This script:
|
| 7 |
+
1) Loads an original Qwen3 base checkpoint.
|
| 8 |
+
2) Resolves MAGEL hparams explicitly at construction time.
|
| 9 |
+
3) Initializes MAGEL extra modules from scratch and trains end-to-end.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from transformers import (
|
| 17 |
+
AutoConfig,
|
| 18 |
+
Trainer,
|
| 19 |
+
TrainingArguments,
|
| 20 |
+
)
|
| 21 |
+
import datasets
|
| 22 |
+
from dataset import DataCollate, MusicDataset
|
| 23 |
+
from modelling_qwen3 import MAGEL
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def resolve_model_source(model_path: str, resume_from_checkpoint: str | None) -> str:
|
| 27 |
+
if not resume_from_checkpoint:
|
| 28 |
+
return model_path
|
| 29 |
+
|
| 30 |
+
if os.path.abspath(model_path) != os.path.abspath(resume_from_checkpoint):
|
| 31 |
+
print(
|
| 32 |
+
"Ignoring --model_path during resume and loading config/model from: "
|
| 33 |
+
f"{resume_from_checkpoint}"
|
| 34 |
+
)
|
| 35 |
+
return resume_from_checkpoint
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def create_model(
    model_path: str,
    model_dtype: torch.dtype,
    target_vocab_size: int,
    attn_implementation: str,
) -> MAGEL:
    """Build a MAGEL model from a Qwen3 checkpoint and resize its vocab.

    ignore_mismatched_sizes allows loading a base checkpoint whose embedding
    table differs from the target vocabulary; the table is then resized to
    ``target_vocab_size``. Parameter counts are printed for sanity checking.
    """
    print(f"Loading Qwen3 model from: {model_path}")

    config = AutoConfig.from_pretrained(
        model_path,
        local_files_only=True,
    )

    model = MAGEL.from_pretrained(
        model_path,
        torch_dtype=model_dtype,
        config=config,
        attn_implementation=attn_implementation,
        ignore_mismatched_sizes=True,
        local_files_only=True,
    )
    model.resize_token_embeddings(target_vocab_size)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # NOTE(review): "dit_adaln" is assumed to match parameter names inside
    # Qwen3DecoderLayerAdaLN — confirm against decoders.py, otherwise this
    # count silently omits the AdaLN parameters.
    magel_extra_params = sum(
        p.numel()
        for name, p in model.named_parameters()
        if ("condition_encoder" in name or "dit_adaln" in name)
    )

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"MAGEL extra parameters: {magel_extra_params:,}")
    print(
        "MAGEL config: "
        f"adaln_dim={model.adaln_dim}, "
        f"chord_dropout_trigger_prob={model.chord_dropout_trigger_prob}, "
        f"structure_dropout_trigger_prob={model.structure_dropout_trigger_prob}"
    )

    return model
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def create_dataset(
    dataset_path: str,
    tokenizer_path: str,
    num_audio_token: int = 16384,
) -> MusicDataset:
    """Load the on-disk HF dataset and wrap its train split in a MusicDataset."""
    print(f"Loading dataset from: {dataset_path}")
    print(f"Loading tokenizer from: {tokenizer_path}")

    hf_ds = datasets.load_from_disk(dataset_path)

    train_dataset = MusicDataset(
        hf_ds,
        split="train",
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_token,
        use_fast=True,
    )
    print(f"Dataset size: {len(train_dataset)}")
    return train_dataset
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def main():
    """CLI entry point: parse args, build dataset + model, run HF Trainer."""
    parser = argparse.ArgumentParser(
        description="Train MAGEL directly from a vanilla Qwen3 base checkpoint."
    )

    parser.add_argument(
        "--dataset_path",
        type=str,
        default="muse_mucodec_chord.ds",
    )

    parser.add_argument(
        "--model_path",
        type=str,
        default="checkpoints/Qwen3-0.6B",
        help="Local Qwen3 base checkpoint path.",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="checkpoints/Qwen3-0.6B",
        help="Local tokenizer checkpoint path.",
    )
    parser.add_argument(
        "--model_dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        choices=["eager", "sdpa", "flash_attention_2"],
    )

    # Optimization hyperparameters.
    parser.add_argument("--output_dir", type=str, default="./output_qwen3_0p6b_train")
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--num_train_epochs", type=float, default=20)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--max_grad_norm", type=float, default=5.0)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Resume training from a Trainer checkpoint directory such as output_dir/checkpoint-500.",
    )

    parser.add_argument("--dataloader_num_workers", type=int, default=12)
    parser.add_argument(
        "--gradient_checkpointing",
        dest="gradient_checkpointing",
        action="store_true",
    )

    parser.add_argument(
        "--deepspeed",
        type=str,
        default=None,
        help="Path to DeepSpeed config. Leave unset to disable DeepSpeed.",
    )

    # Experiment-tracking options.
    parser.add_argument("--report_to", type=str, default="wandb")
    parser.add_argument("--wandb_project", type=str, default="vaultum-qwen3-0p6b")
    parser.add_argument("--wandb_run_name", type=str, default=None)

    args = parser.parse_args()

    model_dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[args.model_dtype]

    # Resume checkpoint (if any) overrides --model_path as the load source.
    model_source = resolve_model_source(
        model_path=args.model_path,
        resume_from_checkpoint=args.resume_from_checkpoint,
    )

    base_config = AutoConfig.from_pretrained(
        model_source,
        local_files_only=True,
    )

    # Audio codebook size is stored on the config by the MAGEL pipeline.
    num_audio_token = int(base_config.magel_num_audio_token)
    print(f"Using num_audio_token={num_audio_token}")

    train_dataset = create_dataset(
        dataset_path=args.dataset_path,
        tokenizer_path=args.tokenizer_path,
        num_audio_token=num_audio_token,
    )

    # Embedding table must cover the dataset tokenizer's full vocabulary.
    target_vocab_size = train_dataset.tokenizer_vocab_size

    model = create_model(
        model_path=model_source,
        model_dtype=model_dtype,
        attn_implementation=args.attn_implementation,
        target_vocab_size=target_vocab_size,
    )

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        num_train_epochs=args.num_train_epochs,
        warmup_steps=args.warmup_steps,
        max_grad_norm=args.max_grad_norm,
        logging_steps=args.logging_steps,
        save_strategy="epoch",
        dataloader_num_workers=args.dataloader_num_workers,
        bf16=(args.model_dtype == "bfloat16"),
        fp16=(args.model_dtype == "float16"),
        gradient_checkpointing=args.gradient_checkpointing,
        # Non-reentrant checkpointing is required for some newer HF features.
        gradient_checkpointing_kwargs={"use_reentrant": False},
        deepspeed=args.deepspeed,
        # Keep custom batch keys (chord_ids/structure_ids/...) intact.
        remove_unused_columns=False,
        dataloader_drop_last=True,
        report_to=args.report_to,
        logging_dir=None,
        run_name=args.wandb_run_name,
    )

    if args.wandb_project and "wandb" in args.report_to:
        os.environ["WANDB_PROJECT"] = args.wandb_project

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=DataCollate(),
    )

    if args.resume_from_checkpoint:
        print(f"Resuming training from checkpoint: {args.resume_from_checkpoint}")
    else:
        print("Starting training from current model initialization.")

    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    # Save the final model plus the tokenizer so the checkpoint is self-contained.
    final_dir = os.path.join(args.output_dir, "final")
    trainer.save_model(final_dir)
    train_dataset.tokenizer.save_pretrained(final_dir)

    print(f"Training complete. Final model saved to: {final_dir}")
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
if __name__ == "__main__":
    # Script entry point.
    main()
|
vocab/__init__.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
"""Condition vocab package."""
|
| 4 |
+
|
| 5 |
+
from .chord import (
|
| 6 |
+
CHORD_BOS_ID,
|
| 7 |
+
CHORD_EOS_ID,
|
| 8 |
+
CHORD_LABELS,
|
| 9 |
+
CHORD_LABEL_TO_ID,
|
| 10 |
+
CHORD_N_ID,
|
| 11 |
+
NUM_CHORD_CLASSES,
|
| 12 |
+
build_frame_chord_ids,
|
| 13 |
+
chord_id_to_label,
|
| 14 |
+
chord_to_id,
|
| 15 |
+
normalize_chord_text,
|
| 16 |
+
)
|
| 17 |
+
from .sections import (
|
| 18 |
+
SEGMENT_FALLBACK_ID,
|
| 19 |
+
SEGMENT_LABELS,
|
| 20 |
+
SEGMENT_LABEL_TO_ID,
|
| 21 |
+
NUM_STRUCTURE_CLASSES,
|
| 22 |
+
STRUCTURE_BOS_ID,
|
| 23 |
+
STRUCTURE_EOS_ID,
|
| 24 |
+
build_frame_structure_ids,
|
| 25 |
+
normalize_structure_label,
|
| 26 |
+
structure_id_to_label,
|
| 27 |
+
structure_to_id,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
__all__ = [
|
| 31 |
+
"CHORD_BOS_ID",
|
| 32 |
+
"CHORD_EOS_ID",
|
| 33 |
+
"CHORD_LABELS",
|
| 34 |
+
"CHORD_LABEL_TO_ID",
|
| 35 |
+
"CHORD_N_ID",
|
| 36 |
+
"NUM_CHORD_CLASSES",
|
| 37 |
+
"normalize_chord_text",
|
| 38 |
+
"chord_to_id",
|
| 39 |
+
"chord_id_to_label",
|
| 40 |
+
"build_frame_chord_ids",
|
| 41 |
+
"SEGMENT_LABELS",
|
| 42 |
+
"SEGMENT_LABEL_TO_ID",
|
| 43 |
+
"SEGMENT_FALLBACK_ID",
|
| 44 |
+
"STRUCTURE_BOS_ID",
|
| 45 |
+
"STRUCTURE_EOS_ID",
|
| 46 |
+
"NUM_STRUCTURE_CLASSES",
|
| 47 |
+
"normalize_structure_label",
|
| 48 |
+
"structure_to_id",
|
| 49 |
+
"structure_id_to_label",
|
| 50 |
+
"build_frame_structure_ids",
|
| 51 |
+
]
|
vocab/chord.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
"""Chord vocab helpers.
|
| 4 |
+
|
| 5 |
+
Dataset inspection summary for `muse_mucodec_chord.ds`:
|
| 6 |
+
- `chords[].type` uses only 25 labels
|
| 7 |
+
- labels are either `N` or `Root:maj|min`
|
| 8 |
+
- roots are represented with sharps rather than flats
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import re
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
# Vocab layout: 0=pad, 1=N, 2..25 = 12 roots x {maj,min}, 26=bos, 27=eos.
CHORD_BOS_ID = 26
CHORD_EOS_ID = 27
CHORD_N_ID = 1
NUM_CHORD_CLASSES = 28

# Pitch-class lookup including enharmonic spellings (B# == C, Cb == B, ...).
_PITCH_TO_PC = {
    "C": 0, "B#": 0,
    "C#": 1, "Db": 1,
    "D": 2,
    "D#": 3, "Eb": 3,
    "E": 4, "Fb": 4,
    "F": 5, "E#": 5,
    "F#": 6, "Gb": 6,
    "G": 7,
    "G#": 8, "Ab": 8,
    "A": 9,
    "A#": 10, "Bb": 10,
    "B": 11, "Cb": 11,
}

# Canonical (sharp-based) spelling for each pitch class.
_PC_TO_ROOT = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]

CHORD_ROOTS = tuple(_PC_TO_ROOT)
_QUALITIES = ["maj", "min"]
CHORD_LABELS = (
    ("pad", "N")
    + tuple(f"{root}:{quality}" for root in CHORD_ROOTS for quality in _QUALITIES)
    + ("bos", "eos")
)

CHORD_LABEL_TO_ID = {label: index for index, label in enumerate(CHORD_LABELS)}


def normalize_chord_text(label: str) -> str:
    """Canonicalize free-form chord text to `Root:maj|min` (sharp roots) or 'N'."""
    if not isinstance(label, str):
        return "N"

    text = label.strip()
    if not text or text.lower() in {"n", "none"}:
        return "N"

    # Unicode accidentals -> ASCII before parsing.
    text = text.replace("♯", "#").replace("♭", "b").strip()
    match = re.fullmatch(r"([A-Ga-g])([#b]?):(maj|min)", text, flags=re.IGNORECASE)
    if match is None:
        return "N"

    root_spelling = (match.group(1).upper() + match.group(2)).strip()
    pitch_class = _PITCH_TO_PC.get(root_spelling)
    if pitch_class is None:
        return "N"

    return f"{CHORD_ROOTS[pitch_class]}:{match.group(3).lower()}"


def chord_to_id(label: str) -> int:
    """Chord text -> vocab id, falling back to the 'N' id."""
    return CHORD_LABEL_TO_ID.get(normalize_chord_text(label), CHORD_N_ID)


def chord_id_to_label(chord_id: int) -> str:
    """Vocab id -> label; out-of-range ids read as 'N'."""
    return CHORD_LABELS[chord_id] if 0 <= chord_id < len(CHORD_LABELS) else "N"


def build_frame_chord_ids(
    chord_segments: list[dict], total_frames: int, fps: int = 25
) -> np.ndarray:
    """Rasterize chord segments into a per-frame id array of length *total_frames*.

    Frames not covered by any segment stay at the 'N' id. Segments may carry
    frame indices (`start_frame`/`end_frame`) or seconds (`start`/`end`,
    converted at *fps*); bounds are clipped to the valid frame range.
    """
    if total_frames > 0:
        frame_ids = np.full((total_frames,), CHORD_N_ID, dtype=np.int64)
    else:
        frame_ids = np.zeros((0,), dtype=np.int64)

    if not isinstance(chord_segments, (list, tuple)):
        return frame_ids

    for seg in chord_segments:
        if not isinstance(seg, dict):
            continue
        # Frame-indexed fields win over second-indexed ones.
        if "start_frame" in seg and "end_frame" in seg:
            begin = int(seg.get("start_frame", 0))
            stop = int(seg.get("end_frame", 0))
        else:
            begin = int(float(seg.get("start", 0.0)) * fps)
            stop = int(math.ceil(float(seg.get("end", 0.0)) * fps))
        begin = max(0, min(total_frames, begin))
        stop = max(begin, min(total_frames, stop))
        if stop <= begin:
            continue
        label = seg.get("type") or seg.get("chord") or seg.get("label") or "N"
        frame_ids[begin:stop] = chord_to_id(str(label))

    return frame_ids
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == "__main__":
    # Smoke test: rasterize the chords of the first training sample.
    import datasets

    ds = datasets.load_from_disk("muse_mucodec_chord.ds")
    sample = ds["train"][0]
    total_frames = 1500  # e.g., for a 60-second clip at 25 fps
    chord_arr = build_frame_chord_ids(sample["chords"], total_frames)
    print(chord_arr)
|
vocab/sections.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
"""Section vocab helpers.
|
| 4 |
+
|
| 5 |
+
Dataset inspection summary for `muse_mucodec_chord.ds`:
|
| 6 |
+
- `sections[]` has fields `section/text/start/end/desc`
|
| 7 |
+
- `sections[].section` uses only 6 labels:
|
| 8 |
+
`Intro`, `Verse`, `Prechorus`, `Chorus`, `Bridge`, `Outro`
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import re
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
# Closed label set observed in the dataset; index 0 doubles as the
# padding / unknown id. Tuple order defines the integer ids.
SEGMENT_LABELS = (
    "pad",
    "intro",
    "verse",
    "chorus",
    "prechorus",
    "bridge",
    "outro",
)
SEGMENT_LABEL_TO_ID = dict(zip(SEGMENT_LABELS, range(len(SEGMENT_LABELS))))
SEGMENT_FALLBACK_ID = SEGMENT_LABEL_TO_ID["pad"]

# Sequence-model special tokens appended directly after the label ids.
STRUCTURE_BOS_ID = len(SEGMENT_LABELS)
STRUCTURE_EOS_ID = STRUCTURE_BOS_ID + 1
NUM_STRUCTURE_CLASSES = STRUCTURE_EOS_ID + 1
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def normalize_structure_label(label: str) -> str:
    """Canonicalize a raw section label into one of ``SEGMENT_LABELS``.

    Lowercases, strips whitespace/underscores/hyphens, and drops digits
    (so ``"Verse 2"`` -> ``"verse"``). Non-string input and any label
    not in the vocab map to ``"pad"``.

    Simplifications vs. the previous version (behavior unchanged):
    - ``label is None`` was redundant — ``None`` already fails ``isinstance``;
    - the empty-string early return and the ``normalized != "pad"`` guard are
      both subsumed by the final membership test, because ``""`` is not in
      the vocab and ``"pad"`` would map to ``"pad"`` either way.
    """
    if not isinstance(label, str):
        return "pad"

    normalized = re.sub(r"[\s_-]+", "", label.strip().lower())
    normalized = re.sub(r"\d+", "", normalized)

    return normalized if normalized in SEGMENT_LABEL_TO_ID else "pad"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def structure_to_id(structure: str) -> int:
    """Return the vocab id for *structure*, falling back to the pad id."""
    key = normalize_structure_label(structure)
    return SEGMENT_LABEL_TO_ID.get(key, SEGMENT_FALLBACK_ID)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def structure_id_to_label(segment_id: int) -> str:
    """Inverse mapping: vocab id -> human-readable label.

    Special ids map to "bos"/"eos"; out-of-range ids map to "pad".
    """
    specials = {STRUCTURE_BOS_ID: "bos", STRUCTURE_EOS_ID: "eos"}
    if segment_id in specials:
        return specials[segment_id]
    if 0 <= segment_id < len(SEGMENT_LABELS):
        return SEGMENT_LABELS[segment_id]
    return "pad"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def build_frame_structure_ids(
    sections: list[dict], total_frames: int, fps: int = 25
) -> np.ndarray:
    """Rasterize section annotations onto a per-frame int64 id array.

    Frames not covered by any section keep the pad id; overlapping
    sections are resolved last-writer-wins. Non-list input yields an
    all-pad array.
    """
    frame_ids = np.full((total_frames,), SEGMENT_FALLBACK_ID, dtype=np.int64)
    if not isinstance(sections, (list, tuple)):
        return frame_ids

    for section in sections:
        if not isinstance(section, dict):
            continue

        # Explicit frame indices take priority; otherwise convert
        # start/end seconds to frames (ceil on the end keeps partial
        # trailing frames).
        if "start_frame" in section and "end_frame" in section:
            first = int(section.get("start_frame", 0))
            last = int(section.get("end_frame", 0))
        else:
            first = int(float(section.get("start", 0.0)) * fps)
            last = int(math.ceil(float(section.get("end", 0.0)) * fps))

        # Clamp into [0, total_frames] and drop empty/inverted spans.
        first = max(0, min(total_frames, first))
        last = max(first, min(total_frames, last))
        if last <= first:
            continue

        raw_label = section.get("section", section.get("structure", "")) or ""
        frame_ids[first:last] = structure_to_id(str(raw_label))

    return frame_ids
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
    # Example usage: sections covering 0-50 s, rasterized at 25 fps.
    segments = [
        {"section": "intro", "start": 0.0, "end": 10.0},
        {"section": "Verse", "start": 10.0, "end": 30.0},
        {"section": "Chorus", "start": 30.0, "end": 50.0},
    ]
    fps = 25
    # A 50-second clip at 25 fps spans 1250 frames (the previous value of
    # 50 truncated the demo to the first 2 seconds, showing only the intro).
    total_frames = 50 * fps
    labels = build_frame_structure_ids(segments, total_frames, fps=fps)
    print(labels)
|
wandb/debug-cli.root.log
ADDED
|
File without changes
|