Pj12 committed on
Commit
cbfe346
·
verified ·
1 Parent(s): cfefce2

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +504 -0
utils.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, traceback
2
+ import glob
3
+ import sys
4
+ import argparse
5
+ import logging
6
+ import json
7
+ import subprocess
8
+ import numpy as np
9
+ from scipy.io.wavfile import read
10
+ import torch
11
+
12
+ MATPLOTLIB_FLAG = False
13
+
14
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
15
+ logger = logging
16
+
17
+
18
def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
    """Restore two discriminator sub-models (and optionally an optimizer).

    The checkpoint is read with ``torch.load`` and must contain the keys
    ``"combd"``, ``"sbd"``, ``"iteration"`` and ``"learning_rate"``; the key
    ``"optimizer"`` is read only when the optimizer state is restored.

    Args:
        checkpoint_path: Path of an existing checkpoint file.
        combd: Model restored from the ``"combd"`` entry. If it exposes a
            ``module`` attribute, the wrapped module is loaded instead.
        sbd: Model restored from the ``"sbd"`` entry (same unwrapping rule).
        optimizer: Optional optimizer whose state is restored.
        load_opt: The optimizer state is restored only when this equals 1.

    Returns:
        A 4-tuple ``((combd, sbd), optimizer, learning_rate, iteration)``.
        The original code returned an undefined name ``model`` in the first
        slot, which raised ``NameError``; the pair of restored models is
        returned there instead so the tuple keeps four elements.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
        KeyError: If a required key is missing from the checkpoint.
    """
    assert os.path.isfile(checkpoint_path)
    # NOTE(review): torch.load unpickles the file, so only trusted
    # checkpoints should be loaded here.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")

    def go(model, bkey):
        """Copy every compatible tensor of ``checkpoint_dict[bkey]`` into ``model``."""
        saved_state_dict = checkpoint_dict[bkey]
        target = model.module if hasattr(model, "module") else model
        state_dict = target.state_dict()
        new_state_dict = {}
        for k, v in state_dict.items():  # shapes required by the model
            if k in saved_state_dict:
                saved = saved_state_dict[k]
                saved_shape = getattr(saved, "shape", None)
                if saved_shape == v.shape:
                    new_state_dict[k] = saved
                    continue
                print(
                    "shape-%s-mismatch|need-%s|get-%s" % (k, v.shape, saved_shape)
                )
            # Missing from the pretrained checkpoint (or incompatible):
            # keep the value the model was initialised with.
            logger.info("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
        target.load_state_dict(new_state_dict, strict=False)

    go(combd, "combd")
    go(sbd, "sbd")
    logger.info("Loaded model weights")

    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if optimizer is not None and load_opt == 1:
        # A failure here is deliberately not caught: per the original author's
        # note, the training script handles it at its outermost level because
        # re-initialising may affect the learning-rate schedule.
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
    return (combd, sbd), optimizer, learning_rate, iteration
64
+
65
+
66
+ # def load_checkpoint(checkpoint_path, model, optimizer=None):
67
+ # assert os.path.isfile(checkpoint_path)
68
+ # checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
69
+ # iteration = checkpoint_dict['iteration']
70
+ # learning_rate = checkpoint_dict['learning_rate']
71
+ # if optimizer is not None:
72
+ # optimizer.load_state_dict(checkpoint_dict['optimizer'])
73
+ # # print(1111)
74
+ # saved_state_dict = checkpoint_dict['model']
75
+ # # print(1111)
76
+ #
77
+ # if hasattr(model, 'module'):
78
+ # state_dict = model.module.state_dict()
79
+ # else:
80
+ # state_dict = model.state_dict()
81
+ # new_state_dict= {}
82
+ # for k, v in state_dict.items():
83
+ # try:
84
+ # new_state_dict[k] = saved_state_dict[k]
85
+ # except:
86
+ # logger.info("%s is not in the checkpoint" % k)
87
+ # new_state_dict[k] = v
88
+ # if hasattr(model, 'module'):
89
+ # model.module.load_state_dict(new_state_dict)
90
+ # else:
91
+ # model.load_state_dict(new_state_dict)
92
+ # logger.info("Loaded checkpoint '{}' (epoch {})" .format(
93
+ # checkpoint_path, iteration))
94
+ # return model, optimizer, learning_rate, iteration
95
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
    """Restore model weights (and optionally optimizer state) from a checkpoint.

    The checkpoint is read with ``torch.load`` and must contain the keys
    ``"model"``, ``"iteration"`` and ``"learning_rate"``; the key
    ``"optimizer"`` is read only when the optimizer state is restored.
    Parameters that are absent from the checkpoint, or whose shape differs
    from what the model expects, keep the model's current values.

    Args:
        checkpoint_path: Path of an existing checkpoint file.
        model: Model to restore. If it exposes a ``module`` attribute, the
            wrapped module is loaded instead.
        optimizer: Optional optimizer whose state is restored.
        load_opt: The optimizer state is restored only when this equals 1.

    Returns:
        A 4-tuple ``(model, optimizer, learning_rate, iteration)``.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
        KeyError: If a required key is missing from the checkpoint.
    """
    assert os.path.isfile(checkpoint_path)
    # NOTE(review): torch.load unpickles the file, so only trusted
    # checkpoints should be loaded here.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")

    saved_state_dict = checkpoint_dict["model"]
    target = model.module if hasattr(model, "module") else model
    state_dict = target.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():  # shapes required by the model
        if k in saved_state_dict:
            saved = saved_state_dict[k]
            saved_shape = getattr(saved, "shape", None)
            if saved_shape == v.shape:
                new_state_dict[k] = saved
                continue
            print("shape-%s-mismatch|need-%s|get-%s" % (k, v.shape, saved_shape))
        # Missing from the pretrained checkpoint (or incompatible):
        # keep the value the model was initialised with.
        logger.info("%s is not in the checkpoint" % k)
        new_state_dict[k] = v
    target.load_state_dict(new_state_dict, strict=False)
    logger.info("Loaded model weights")

    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if optimizer is not None and load_opt == 1:
        # A failure here is deliberately not caught: per the original author's
        # note, the training script handles it at its outermost level because
        # re-initialising may affect the learning-rate schedule.
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
135
+
136
+
137
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    """Write model weights, optimizer state and training progress to disk.

    The saved dictionary has the keys ``"model"``, ``"iteration"``,
    ``"optimizer"`` and ``"learning_rate"``.

    Args:
        model: Model to save. If it exposes a ``module`` attribute, the
            wrapped module's state dict is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        learning_rate: Value stored under ``"learning_rate"``.
        iteration: Value stored under ``"iteration"`` (logged as the epoch).
        checkpoint_path: Destination passed to ``torch.save``.
    """
    message = "Saving model and optimizer state at epoch {} to {}".format(
        iteration, checkpoint_path
    )
    logger.info(message)

    # Store the weights of the wrapped module when the model is a wrapper.
    network = model.module if hasattr(model, "module") else model
    payload = {
        "model": network.state_dict(),
        "iteration": iteration,
        "optimizer": optimizer.state_dict(),
        "learning_rate": learning_rate,
    }
    torch.save(payload, checkpoint_path)
156
+
157
+
158
def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
    """Write both discriminator sub-models and the optimizer state to disk.

    The saved dictionary has the keys ``"combd"``, ``"sbd"``, ``"iteration"``,
    ``"optimizer"`` and ``"learning_rate"``.

    Args:
        combd: Model stored under ``"combd"``. If it exposes a ``module``
            attribute, the wrapped module's state dict is stored.
        sbd: Model stored under ``"sbd"`` (same unwrapping rule).
        optimizer: Optimizer whose ``state_dict()`` is stored.
        learning_rate: Value stored under ``"learning_rate"``.
        iteration: Value stored under ``"iteration"`` (logged as the epoch).
        checkpoint_path: Destination passed to ``torch.save``.
    """
    message = "Saving model and optimizer state at epoch {} to {}".format(
        iteration, checkpoint_path
    )
    logger.info(message)

    # Store the weights of the wrapped module when a model is a wrapper.
    combd_net = combd.module if hasattr(combd, "module") else combd
    combd_weights = combd_net.state_dict()
    sbd_net = sbd.module if hasattr(sbd, "module") else sbd
    sbd_weights = sbd_net.state_dict()

    payload = {
        "combd": combd_weights,
        "sbd": sbd_weights,
        "iteration": iteration,
        "optimizer": optimizer.state_dict(),
        "learning_rate": learning_rate,
    }
    torch.save(payload, checkpoint_path)
182
+
183
+
184
def summarize(
    writer,
    global_step,
    scalars=None,
    histograms=None,
    images=None,
    audios=None,
    audio_sampling_rate=22050,
):
    """Forward grouped summaries to a summary writer.

    Args:
        writer: Object providing ``add_scalar``, ``add_histogram``,
            ``add_image`` and ``add_audio``.
        global_step: Step value passed along with every entry.
        scalars: Optional mapping of tag -> value for ``add_scalar``.
        histograms: Optional mapping of tag -> value for ``add_histogram``.
        images: Optional mapping of tag -> image for ``add_image``; images are
            submitted with ``dataformats="HWC"``.
        audios: Optional mapping of tag -> audio for ``add_audio``.
        audio_sampling_rate: Sampling rate passed to ``add_audio``.
    """
    # ``None`` defaults replace the former shared mutable ``{}`` defaults.
    for tag, value in (scalars or {}).items():
        writer.add_scalar(tag, value, global_step)
    for tag, value in (histograms or {}).items():
        writer.add_histogram(tag, value, global_step)
    for tag, value in (images or {}).items():
        writer.add_image(tag, value, global_step, dataformats="HWC")
    for tag, value in (audios or {}).items():
        writer.add_audio(tag, value, global_step, audio_sampling_rate)
201
+
202
+
203
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return the checkpoint in ``dir_path`` with the highest number in its name.

    Args:
        dir_path: Directory that is searched.
        regex: Despite the name, a ``glob`` pattern (not a regular
            expression) joined onto ``dir_path``.

    Returns:
        The matching path whose file name contains the largest number. All
        digits of the file name are concatenated to form that number; names
        without any digit sort first.

    Raises:
        IndexError: If no file matches the pattern.
    """
    f_list = glob.glob(os.path.join(dir_path, regex))
    if not f_list:
        raise IndexError(
            "no checkpoint matching %r found in %r" % (regex, dir_path)
        )

    def step_number(path):
        # Only the file name is inspected: digits in the directory part of
        # the path must not leak into the step number.
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        return int(digits) if digits else -1

    f_list.sort(key=step_number)
    x = f_list[-1]
    print(x)
    return x
209
+
210
+
211
def plot_spectrogram_to_numpy(spectrogram):
    """Render a spectrogram with matplotlib and return the plot as an array.

    Args:
        spectrogram: 2-D data accepted by ``Axes.imshow``.

    Returns:
        A ``uint8`` NumPy array of shape ``(height, width, 3)`` holding the
        RGB pixels of the rendered figure.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        # Done once per process: select the Agg backend and silence
        # matplotlib's DEBUG/INFO log output.
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # ``canvas.tostring_rgb()`` and binary-mode ``np.fromstring`` are
    # deprecated/removed in current Matplotlib and NumPy; read the RGBA
    # buffer instead and drop the alpha channel. ``copy()`` detaches the
    # result from the canvas buffer before the figure is closed.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    data = rgba[..., :3].copy()
    plt.close(fig)
    return data
235
+
236
+
237
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix with matplotlib and return the plot as an array.

    Args:
        alignment: 2-D data providing ``transpose()``; the transposed matrix
            is drawn with ``Axes.imshow``.
        info: Optional text appended below the x-axis label.

    Returns:
        A ``uint8`` NumPy array of shape ``(height, width, 3)`` holding the
        RGB pixels of the rendered figure.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        # Done once per process: select the Agg backend and silence
        # matplotlib's DEBUG/INFO log output.
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()

    fig.canvas.draw()
    # ``canvas.tostring_rgb()`` and binary-mode ``np.fromstring`` are
    # deprecated/removed in current Matplotlib and NumPy; read the RGBA
    # buffer instead and drop the alpha channel. ``copy()`` detaches the
    # result from the canvas buffer before the figure is closed.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    data = rgba[..., :3].copy()
    plt.close(fig)
    return data
266
+
267
+
268
def load_wav_to_torch(full_path):
    """Read a WAV file and return its samples as a float tensor.

    Args:
        full_path: Path handed to ``scipy.io.wavfile.read``.

    Returns:
        A tuple ``(samples, sampling_rate)`` where ``samples`` is a
        ``torch.FloatTensor`` built from the data cast to ``float32``
        (values are cast, not rescaled).
    """
    wav = read(full_path)
    sampling_rate = wav[0]
    samples = wav[1].astype(np.float32)
    return torch.FloatTensor(samples), sampling_rate
271
+
272
+
273
def load_filepaths_and_text(filename, split="|"):
    """Parse a UTF-8 file list into one list of fields per line.

    Args:
        filename: Path of the text file to read.
        split: Separator between the fields of a line.

    Returns:
        A list with one entry per line; each entry is the stripped line
        split on ``split``.
    """
    rows = []
    with open(filename, encoding="utf-8") as handle:
        for raw_line in handle:
            fields = raw_line.strip().split(split)
            rows.append(fields)
    return rows
277
+
278
+
279
def _build_hparams_parser():
    """Create the argument parser for the training command line."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-se",
        "--save_every_epoch",
        type=int,
        required=True,
        help="checkpoint save frequency (epoch)",
    )
    parser.add_argument(
        "-te", "--total_epoch", type=int, required=True, help="total_epoch"
    )
    # The help texts of -pg/-pd were swapped in the original.
    parser.add_argument(
        "-pg", "--pretrainG", type=str, default="", help="Pretrained Generator path"
    )
    parser.add_argument(
        "-pd",
        "--pretrainD",
        type=str,
        default="",
        help="Pretrained Discriminator path",
    )
    parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -")
    parser.add_argument(
        "-bs", "--batch_size", type=int, required=True, help="batch size"
    )
    parser.add_argument(
        "-e", "--experiment_dir", type=str, required=True, help="experiment dir"
    )
    parser.add_argument(
        "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
    )
    parser.add_argument(
        "-sw",
        "--save_every_weights",
        type=str,
        default="0",
        help="save the extracted model in weights directory when saving checkpoints",
    )
    parser.add_argument(
        "-v", "--version", type=str, required=True, help="model version"
    )
    parser.add_argument(
        "-f0",
        "--if_f0",
        type=int,
        required=True,
        help="use f0 as one of the inputs of the model, 1 or 0",
    )
    parser.add_argument(
        "-l",
        "--if_latest",
        type=int,
        required=True,
        help="if only save the latest G/D pth file, 1 or 0",
    )
    parser.add_argument(
        "-c",
        "--if_cache_data_in_gpu",
        type=int,
        required=True,
        help="if caching the dataset in GPU memory, 1 or 0",
    )
    parser.add_argument(
        "-li", "--log_interval", type=int, required=True, help="log interval"
    )
    parser.add_argument(
        "-overtrain", "--overtrain", type=int, required=True, help="Detecting Overtrain"
    )
    return parser


def get_hparams(init=True):
    """Build the training hyper-parameters from the command line and a config file.

    The command line (``sys.argv``) is parsed, ``./logs/<experiment_dir>`` is
    created if needed, and a JSON configuration is loaded:

    * ``configs/<sample_rate>.json`` when the version is ``"v1"`` or the
      sample rate is ``"40k"``, otherwise ``configs/<sample_rate>_v2.json``;
    * with ``init`` true that file is copied to ``config.json`` inside the
      experiment directory, otherwise the existing ``config.json`` is read.

    Command-line values are then stored on the resulting ``HParams`` object,
    ``data.training_files`` is pointed at ``<experiment_dir>/filelist.txt``,
    and the configuration (with ``train.log_interval`` updated) is written
    back to ``config.json``.

    Args:
        init: Whether to initialise ``config.json`` from the template config.

    Returns:
        The populated ``HParams`` instance.
    """
    args = _build_hparams_parser().parse_args()
    name = args.experiment_dir
    experiment_dir = os.path.join("./logs", args.experiment_dir)

    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(experiment_dir, exist_ok=True)

    if args.version == "v1" or args.sample_rate == "40k":
        config_path = "configs/%s.json" % args.sample_rate
    else:
        config_path = "configs/%s_v2.json" % args.sample_rate
    config_save_path = os.path.join(experiment_dir, "config.json")
    if init:
        with open(config_path, "r") as f:
            data = f.read()
        with open(config_save_path, "w") as f:
            f.write(data)
    else:
        with open(config_save_path, "r") as f:
            data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = hparams.experiment_dir = experiment_dir
    hparams.save_every_epoch = args.save_every_epoch
    hparams.name = name
    hparams.total_epoch = args.total_epoch
    hparams.pretrainG = args.pretrainG
    hparams.pretrainD = args.pretrainD
    hparams.version = args.version
    hparams.gpus = args.gpus
    hparams.train.batch_size = args.batch_size
    hparams.sample_rate = args.sample_rate
    hparams.if_f0 = args.if_f0
    hparams.if_latest = args.if_latest
    hparams.save_every_weights = args.save_every_weights
    hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu
    hparams.data.training_files = "%s/filelist.txt" % experiment_dir

    hparams.train.log_interval = args.log_interval
    hparams.overtrain = args.overtrain

    # Update log_interval in the 'train' section of the config dictionary
    config["train"]["log_interval"] = args.log_interval

    # Save the updated config back to the config_save_path
    with open(config_save_path, "w") as f:
        json.dump(config, f, indent=4)

    return hparams
413
+
414
+
415
def get_hparams_from_dir(model_dir):
    """Load ``config.json`` from ``model_dir`` into an ``HParams`` object.

    Args:
        model_dir: Directory containing ``config.json``.

    Returns:
        The ``HParams`` built from the file, with ``model_dir`` set to the
        given directory.
    """
    with open(os.path.join(model_dir, "config.json"), "r") as handle:
        settings = json.loads(handle.read())

    result = HParams(**settings)
    result.model_dir = model_dir
    return result
424
+
425
+
426
def get_hparams_from_file(config_path):
    """Load a JSON configuration file into an ``HParams`` object.

    Args:
        config_path: Path of the JSON file.

    Returns:
        The ``HParams`` built from the file's top-level keys.
    """
    with open(config_path, "r") as handle:
        settings = json.loads(handle.read())
    return HParams(**settings)
433
+
434
+
435
def check_git_hash(model_dir):
    """Warn when the code's git commit differs from the one stored with a model.

    The directory containing this source file must hold a ``.git`` entry;
    otherwise a warning is logged and nothing else happens. The current
    ``HEAD`` hash is compared with the content of ``<model_dir>/githash``;
    when that file does not exist yet it is created with the current hash.

    Args:
        model_dir: Directory holding (or receiving) the ``githash`` file.
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    # Query the repository that was just checked (source_dir), not whatever
    # directory the process happens to run in.
    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            cwd=source_dir,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
        )
    except OSError:
        logger.warning(
            "git could not be executed, therefore hash value comparison will be ignored."
        )
        return
    if result.returncode != 0:
        # Do not store git's error text as if it were a hash.
        logger.warning(
            "git rev-parse failed, therefore hash value comparison will be ignored."
        )
        return
    cur_hash = result.stdout.strip()

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        with open(path) as f:
            saved_hash = f.read().strip()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        with open(path, "w") as f:
            f.write(cur_hash)
458
+
459
+
460
def get_logger(model_dir, filename="train.log"):
    """Return a DEBUG-level logger that writes to ``<model_dir>/<filename>``.

    The logger is named after the last path component of ``model_dir`` and is
    also bound to the module-level name ``logger``. ``model_dir`` is created
    when missing.

    Args:
        model_dir: Directory that receives the log file.
        filename: Name of the log file inside ``model_dir``.

    Returns:
        The configured ``logging.Logger``.
    """
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    os.makedirs(model_dir, exist_ok=True)
    log_path = os.path.abspath(os.path.join(model_dir, filename))

    # logging.getLogger returns the same object for the same name, so a
    # repeated call must not attach a second handler for the same file
    # (every record would be written twice).
    already_attached = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not already_attached:
        formatter = logging.Formatter(
            "%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s"
        )
        h = logging.FileHandler(log_path)
        h.setLevel(logging.DEBUG)
        h.setFormatter(formatter)
        logger.addHandler(h)
    return logger
473
+
474
+
475
class HParams:
    """Attribute-style container for (nested) hyper-parameters.

    Keyword arguments become instance attributes; mapping values are wrapped
    recursively in ``HParams``. Entries can be accessed as attributes or with
    ``[]``, and the object supports ``len``, ``in``, ``keys``, ``items`` and
    ``values`` over its attributes.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance (rather than an exact type comparison) also wraps
            # dict subclasses such as OrderedDict.
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        """Return a view of the stored parameter names."""
        return self.__dict__.keys()

    def items(self):
        """Return a view of the stored ``(name, value)`` pairs."""
        return self.__dict__.items()

    def values(self):
        """Return a view of the stored values."""
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()