freyza committed on
Commit
d98a90b
·
1 Parent(s): 600b02f

Upload process_ckpt.py

Browse files
Files changed (1) hide show
  1. infer/lib/train/process_ckpt.py +261 -0
infer/lib/train/process_ckpt.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import traceback
4
+ from collections import OrderedDict
5
+
6
+ import torch
7
+
8
+ from i18n.i18n import I18nAuto
9
+
10
+ i18n = I18nAuto()
11
+
12
+
13
def savee(ckpt, sr, if_f0, name, epoch, version, hps):
    """Strip a training state-dict down to inference weights and save it.

    Drops every ``enc_q`` (posterior-encoder, training-only) tensor, converts
    the remaining tensors to half precision, and stores them together with the
    model hyperparameters under ``assets/weights/<name>.pth``.

    Args:
        ckpt: state-dict mapping parameter names to tensors.
        sr: sample-rate tag (e.g. "40k"), stored as metadata.
        if_f0: pitch-guidance flag, stored as metadata.
        name: output file name (without the ``.pth`` suffix).
        epoch: epoch count recorded in the ``info`` field as "<epoch>epoch".
        version: model version tag ("v1"/"v2"), stored as metadata.
        hps: hyperparameter namespace providing ``hps.data`` and ``hps.model``.

    Returns:
        "Success." on success, otherwise the formatted traceback as a string
        (callers display the return value directly in the UI).
    """
    try:
        opt = OrderedDict()
        # enc_q is only needed during training; skip it and halve the rest.
        opt["weight"] = {
            key: value.half() for key, value in ckpt.items() if "enc_q" not in key
        }
        opt["config"] = [
            hps.data.filter_length // 2 + 1,
            32,
            hps.model.inter_channels,
            hps.model.hidden_channels,
            hps.model.filter_channels,
            hps.model.n_heads,
            hps.model.n_layers,
            hps.model.kernel_size,
            hps.model.p_dropout,
            hps.model.resblock,
            hps.model.resblock_kernel_sizes,
            hps.model.resblock_dilation_sizes,
            hps.model.upsample_rates,
            hps.model.upsample_initial_channel,
            hps.model.upsample_kernel_sizes,
            hps.model.spk_embed_dim,
            hps.model.gin_channels,
            hps.data.sampling_rate,
        ]
        opt["info"] = "%sepoch" % epoch
        opt["sr"] = sr
        opt["f0"] = if_f0
        opt["version"] = version
        # Ensure the target directory exists so torch.save cannot fail on it.
        os.makedirs("assets/weights", exist_ok=True)
        torch.save(opt, "assets/weights/%s.pth" % name)
        return "Success."
    except Exception:
        # Never raise into the UI; surface the traceback as the status text.
        return traceback.format_exc()
49
+
50
+
51
def show_info(path):
    """Return a human-readable summary of a saved RVC weight file.

    Reads the ``info`` / ``sr`` / ``f0`` / ``version`` metadata fields from
    the checkpoint at *path*, substituting "None" for any missing key.

    Args:
        path: filesystem path to a ``.pth`` checkpoint.

    Returns:
        A formatted (Chinese) info string, or the traceback text on failure.
    """
    try:
        a = torch.load(path, map_location="cpu")
        # The format string is user-facing output; keep it exactly as shipped.
        return "模型信息:%s\n采样率:%s\n模型是否输入音高引导:%s\n版本:%s" % (
            a.get("info", "None"),
            a.get("sr", "None"),
            a.get("f0", "None"),
            a.get("version", "None"),
        )
    except Exception:
        return traceback.format_exc()
62
+
63
+
64
def extract_small_model(path, name, sr, if_f0, info, version):
    """Extract inference-only weights from a full training checkpoint.

    Loads the checkpoint at *path* (unwrapping a ``"model"`` sub-dict if
    present), drops ``enc_q`` tensors, converts the rest to half precision,
    attaches the fixed architecture config for the given sample rate /
    version, and saves the result to ``assets/weights/<name>.pth``.

    Args:
        path: source checkpoint path.
        name: output file name (without the ``.pth`` suffix).
        sr: sample-rate tag: "40k", "48k" or "32k".
        if_f0: pitch-guidance flag; stored as ``int(if_f0)``.
        info: free-text description; defaults to "Extracted model." if empty.
        version: "v1" selects the v1 topology for 48k/32k; anything else
            selects v2. 40k uses one topology for both versions.

    Returns:
        "Success." on success, otherwise the traceback text.
    """
    try:
        ckpt = torch.load(path, map_location="cpu")
        if "model" in ckpt:
            ckpt = ckpt["model"]
        opt = OrderedDict()
        opt["weight"] = {
            key: value.half() for key, value in ckpt.items() if "enc_q" not in key
        }
        # Only four slots of the 18-element config vary across (sr, version):
        # (spectrogram bins, upsample rates, upsample kernel sizes, sample rate).
        variants = {
            ("40k", "v1"): (1025, [10, 10, 2, 2], [16, 16, 4, 4], 40000),
            ("40k", "v2"): (1025, [10, 10, 2, 2], [16, 16, 4, 4], 40000),
            ("48k", "v1"): (1025, [10, 6, 2, 2, 2], [16, 16, 4, 4, 4], 48000),
            ("48k", "v2"): (1025, [12, 10, 2, 2], [24, 20, 4, 4], 48000),
            ("32k", "v1"): (513, [10, 4, 2, 2, 2], [16, 16, 4, 4, 4], 32000),
            ("32k", "v2"): (513, [10, 8, 2, 2], [20, 16, 4, 4], 32000),
        }
        # Any version other than "v1" maps to the v2 topology, matching the
        # original if/else branching; an unknown sr leaves "config" unset,
        # also matching the original behavior.
        variant = variants.get((sr, "v1" if version == "v1" else "v2"))
        if variant is not None:
            bins, up_rates, up_kernels, rate = variant
            opt["config"] = [
                bins,
                32,
                192,
                192,
                768,
                2,
                6,
                3,
                0,
                "1",
                [3, 7, 11],
                [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                up_rates,
                512,
                up_kernels,
                109,
                256,
                rate,
            ]
        if info == "":
            info = "Extracted model."
        opt["info"] = info
        opt["version"] = version
        opt["sr"] = sr
        opt["f0"] = int(if_f0)
        # Ensure the target directory exists so torch.save cannot fail on it.
        os.makedirs("assets/weights", exist_ok=True)
        torch.save(opt, "assets/weights/%s.pth" % name)
        return "Success."
    except Exception:
        return traceback.format_exc()
192
+
193
+
194
def change_info(path, info, name):
    """Rewrite the ``info`` metadata field of a saved weight file.

    Loads the checkpoint at *path*, replaces its ``info`` entry, and saves the
    result under ``assets/weights/<name>`` (defaulting to the basename of
    *path* when *name* is empty). Note the output always lands in
    ``assets/weights`` regardless of where *path* points — the input file
    itself is never modified in place unless it already lives there.

    Returns:
        "Success." on success, otherwise the traceback text.
    """
    try:
        ckpt = torch.load(path, map_location="cpu")
        ckpt["info"] = info
        if name == "":
            name = os.path.basename(path)
        # Ensure the target directory exists so torch.save cannot fail on it.
        os.makedirs("assets/weights", exist_ok=True)
        torch.save(ckpt, "assets/weights/%s" % name)
        return "Success."
    except Exception:
        return traceback.format_exc()
204
+
205
+
206
def merge(path1, path2, alpha1, sr, f0, info, name, version):
    """Linearly blend two model checkpoints into one.

    Each surviving weight becomes ``alpha1 * w1 + (1 - alpha1) * w2``, stored
    in half precision. ``enc_q`` tensors are dropped; a speaker-embedding
    table (``emb_g.weight``) with mismatched row counts is truncated to the
    common row prefix before blending. The architecture config is taken from
    *path1*.

    Args:
        path1, path2: checkpoints to blend (full training checkpoints with a
            ``"model"`` sub-dict, or already-extracted small models).
        alpha1: blend weight for the first model, expected in [0, 1].
        sr: sample-rate tag stored as metadata.
        f0: pitch-guidance choice; compared against the localized "是" ("yes").
        info: free-text description stored as metadata.
        name: output file name (without the ``.pth`` suffix).
        version: model version tag stored as metadata.

    Returns:
        "Success." on success, an architecture-mismatch message when the key
        sets differ, or the traceback text on failure.
    """
    try:

        def extract(ckpt):
            # Keep only inference weights (drop enc_q) from a full checkpoint.
            # BUGFIX: the previous version returned {"weight": {...}} here and
            # assigned it as the weight dict itself, so merging any full
            # (non-extracted) checkpoint always failed; return the inner dict.
            return {k: v for k, v in ckpt["model"].items() if "enc_q" not in k}

        ckpt1 = torch.load(path1, map_location="cpu")
        ckpt2 = torch.load(path2, map_location="cpu")
        cfg = ckpt1["config"]
        ckpt1 = extract(ckpt1) if "model" in ckpt1 else ckpt1["weight"]
        ckpt2 = extract(ckpt2) if "model" in ckpt2 else ckpt2["weight"]
        if sorted(ckpt1.keys()) != sorted(ckpt2.keys()):
            return "Fail to merge the models. The model architectures are not the same."
        opt = OrderedDict()
        opt["weight"] = {}
        for key in ckpt1.keys():
            if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape:
                # Different speaker counts: blend only the shared rows.
                min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0])
                opt["weight"][key] = (
                    alpha1 * (ckpt1[key][:min_shape0].float())
                    + (1 - alpha1) * (ckpt2[key][:min_shape0].float())
                ).half()
            else:
                opt["weight"][key] = (
                    alpha1 * (ckpt1[key].float()) + (1 - alpha1) * (ckpt2[key].float())
                ).half()
        opt["config"] = cfg
        opt["sr"] = sr
        opt["f0"] = 1 if f0 == i18n("是") else 0
        opt["version"] = version
        opt["info"] = info
        # Ensure the target directory exists so torch.save cannot fail on it.
        os.makedirs("assets/weights", exist_ok=True)
        torch.save(opt, "assets/weights/%s.pth" % name)
        return "Success."
    except Exception:
        return traceback.format_exc()