Pj12 committed on
Commit
aeefca3
·
verified ·
1 Parent(s): 688b95c

Upload 2 files

Browse files
process_ckpt.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, traceback, os, pdb, sys
2
+
3
+ now_dir = os.getcwd()
4
+ sys.path.append(now_dir)
5
+ from collections import OrderedDict
6
+ from i18n import I18nAuto
7
+
8
+ i18n = I18nAuto()
9
+
10
+
11
+ def savee(ckpt, sr, if_f0, name, epoch, version, hps, experiment_name):
12
+ try:
13
+ opt = OrderedDict()
14
+ opt["weight"] = {}
15
+ for key in ckpt.keys():
16
+ if "enc_q" in key:
17
+ continue
18
+ opt["weight"][key] = ckpt[key].half()
19
+ opt["config"] = [
20
+ hps.data.filter_length // 2 + 1,
21
+ 32,
22
+ hps.model.inter_channels,
23
+ hps.model.hidden_channels,
24
+ hps.model.filter_channels,
25
+ hps.model.n_heads,
26
+ hps.model.n_layers,
27
+ hps.model.kernel_size,
28
+ hps.model.p_dropout,
29
+ hps.model.resblock,
30
+ hps.model.resblock_kernel_sizes,
31
+ hps.model.resblock_dilation_sizes,
32
+ hps.model.upsample_rates,
33
+ hps.model.upsample_initial_channel,
34
+ hps.model.upsample_kernel_sizes,
35
+ hps.model.spk_embed_dim,
36
+ hps.model.gin_channels,
37
+ hps.data.sampling_rate,
38
+ ]
39
+ opt["info"] = "%sepoch" % epoch
40
+ opt["sr"] = sr
41
+ opt["f0"] = if_f0
42
+ opt["version"] = version
43
+ torch.save(opt, f"logs/{experiment_name}/weights/{name}.pth")
44
+ return "Success."
45
+ except:
46
+ return traceback.format_exc()
47
+
48
+
49
+ def show_info(path):
50
+ try:
51
+ a = torch.load(path, map_location="cpu")
52
+ return "模型信息:%s\n采样率:%s\n模型是否输入音高引导:%s\n版本:%s" % (
53
+ a.get("info", "None"),
54
+ a.get("sr", "None"),
55
+ a.get("f0", "None"),
56
+ a.get("version", "None"),
57
+ )
58
+ except:
59
+ return traceback.format_exc()
60
+
61
+
62
+ def extract_small_model(path, name, sr, if_f0, info, version):
63
+ try:
64
+ ckpt = torch.load(path, map_location="cpu")
65
+ if "model" in ckpt:
66
+ ckpt = ckpt["model"]
67
+ opt = OrderedDict()
68
+ opt["weight"] = {}
69
+ for key in ckpt.keys():
70
+ if "enc_q" in key:
71
+ continue
72
+ opt["weight"][key] = ckpt[key].half()
73
+ if sr == "40k":
74
+ opt["config"] = [
75
+ 1025,
76
+ 32,
77
+ 192,
78
+ 192,
79
+ 768,
80
+ 2,
81
+ 6,
82
+ 3,
83
+ 0,
84
+ "1",
85
+ [3, 7, 11],
86
+ [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
87
+ [10, 10, 2, 2],
88
+ 512,
89
+ [16, 16, 4, 4],
90
+ 109,
91
+ 256,
92
+ 40000,
93
+ ]
94
+ elif sr == "48k":
95
+ if version == "v1":
96
+ opt["config"] = [
97
+ 1025,
98
+ 32,
99
+ 192,
100
+ 192,
101
+ 768,
102
+ 2,
103
+ 6,
104
+ 3,
105
+ 0,
106
+ "1",
107
+ [3, 7, 11],
108
+ [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
109
+ [10, 6, 2, 2, 2],
110
+ 512,
111
+ [16, 16, 4, 4, 4],
112
+ 109,
113
+ 256,
114
+ 48000,
115
+ ]
116
+ else:
117
+ opt["config"] = [
118
+ 1025,
119
+ 32,
120
+ 192,
121
+ 192,
122
+ 768,
123
+ 2,
124
+ 6,
125
+ 3,
126
+ 0,
127
+ "1",
128
+ [3, 7, 11],
129
+ [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
130
+ [12, 10, 2, 2],
131
+ 512,
132
+ [24, 20, 4, 4],
133
+ 109,
134
+ 256,
135
+ 48000,
136
+ ]
137
+ elif sr == "32k":
138
+ if version == "v1":
139
+ opt["config"] = [
140
+ 513,
141
+ 32,
142
+ 192,
143
+ 192,
144
+ 768,
145
+ 2,
146
+ 6,
147
+ 3,
148
+ 0,
149
+ "1",
150
+ [3, 7, 11],
151
+ [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
152
+ [10, 4, 2, 2, 2],
153
+ 512,
154
+ [16, 16, 4, 4, 4],
155
+ 109,
156
+ 256,
157
+ 32000,
158
+ ]
159
+ else:
160
+ opt["config"] = [
161
+ 513,
162
+ 32,
163
+ 192,
164
+ 192,
165
+ 768,
166
+ 2,
167
+ 6,
168
+ 3,
169
+ 0,
170
+ "1",
171
+ [3, 7, 11],
172
+ [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
173
+ [10, 8, 2, 2],
174
+ 512,
175
+ [20, 16, 4, 4],
176
+ 109,
177
+ 256,
178
+ 32000,
179
+ ]
180
+ if info == "":
181
+ info = "Extracted model."
182
+ opt["info"] = info
183
+ opt["version"] = version
184
+ opt["sr"] = sr
185
+ opt["f0"] = int(if_f0)
186
+ torch.save(opt, "weights/%s.pth" % name)
187
+ return "Success."
188
+ except:
189
+ return traceback.format_exc()
190
+
191
+
192
+ def change_info(path, info, name):
193
+ try:
194
+ ckpt = torch.load(path, map_location="cpu")
195
+ ckpt["info"] = info
196
+ if name == "":
197
+ name = os.path.basename(path)
198
+ torch.save(ckpt, "weights/%s" % name)
199
+ return "Success."
200
+ except:
201
+ return traceback.format_exc()
202
+
203
+
204
+ def merge(path1, path2, alpha1, sr, f0, info, name, version):
205
+ try:
206
+
207
+ def extract(ckpt):
208
+ a = ckpt["model"]
209
+ opt = OrderedDict()
210
+ opt["weight"] = {}
211
+ for key in a.keys():
212
+ if "enc_q" in key:
213
+ continue
214
+ opt["weight"][key] = a[key]
215
+ return opt
216
+
217
+ ckpt1 = torch.load(path1, map_location="cpu")
218
+ ckpt2 = torch.load(path2, map_location="cpu")
219
+ cfg = ckpt1["config"]
220
+ if "model" in ckpt1:
221
+ ckpt1 = extract(ckpt1)
222
+ else:
223
+ ckpt1 = ckpt1["weight"]
224
+ if "model" in ckpt2:
225
+ ckpt2 = extract(ckpt2)
226
+ else:
227
+ ckpt2 = ckpt2["weight"]
228
+ if sorted(list(ckpt1.keys())) != sorted(list(ckpt2.keys())):
229
+ return "Fail to merge the models. The model architectures are not the same."
230
+ opt = OrderedDict()
231
+ opt["weight"] = {}
232
+ for key in ckpt1.keys():
233
+ # try:
234
+ if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape:
235
+ min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0])
236
+ opt["weight"][key] = (
237
+ alpha1 * (ckpt1[key][:min_shape0].float())
238
+ + (1 - alpha1) * (ckpt2[key][:min_shape0].float())
239
+ ).half()
240
+ else:
241
+ opt["weight"][key] = (
242
+ alpha1 * (ckpt1[key].float()) + (1 - alpha1) * (ckpt2[key].float())
243
+ ).half()
244
+ # except:
245
+ # pdb.set_trace()
246
+ opt["config"] = cfg
247
+ """
248
+ if(sr=="40k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 40000]
249
+ elif(sr=="48k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10,6,2,2,2], 512, [16, 16, 4, 4], 109, 256, 48000]
250
+ elif(sr=="32k"):opt["config"] = [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 32000]
251
+ """
252
+ opt["sr"] = sr
253
+ opt["f0"] = 1 if f0 else 0
254
+ opt["version"] = version
255
+ opt["info"] = info
256
+ torch.save(opt, "weights/%s.pth" % name)
257
+ return "Success."
258
+ except:
259
+ return traceback.format_exc()
train_nsf_sim_cache_sid_load_pretrain.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+
3
+ now_dir = os.getcwd()
4
+ sys.path.append(os.path.join(now_dir))
5
+ sys.path.append(os.path.join(now_dir, "train"))
6
+ import utils
7
+ import datetime
8
+
9
+ hps = utils.get_hparams()
10
+ experiment_name = hps.name
11
+ os.environ["CUDA_VISIBLE_DEVICES"] = hps.gpus.replace("-", ",")
12
+ n_gpus = len(hps.gpus.split("-"))
13
+ from random import shuffle, randint
14
+ import traceback, json, argparse, itertools, math, torch, pdb
15
+
16
+ torch.backends.cudnn.deterministic = False
17
+ torch.backends.cudnn.benchmark = False
18
+ from torch import nn, optim
19
+ from torch.nn import functional as F
20
+ from torch.utils.data import DataLoader
21
+ from torch.utils.tensorboard import SummaryWriter
22
+ import torch.multiprocessing as mp
23
+ import torch.distributed as dist
24
+ from torch.nn.parallel import DistributedDataParallel as DDP
25
+ from torch.cuda.amp import autocast, GradScaler
26
+ from lib.infer_pack import commons
27
+ from time import sleep
28
+ from time import time as ttime
29
+ from data_utils import (
30
+ TextAudioLoaderMultiNSFsid,
31
+ TextAudioLoader,
32
+ TextAudioCollateMultiNSFsid,
33
+ TextAudioCollate,
34
+ DistributedBucketSampler,
35
+ )
36
+
37
+ import csv
38
+
39
+ if hps.version == "v1":
40
+ from lib.infer_pack.models import (
41
+ SynthesizerTrnMs256NSFsid as RVC_Model_f0,
42
+ SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
43
+ MultiPeriodDiscriminator,
44
+ )
45
+ else:
46
+ from lib.infer_pack.models import (
47
+ SynthesizerTrnMs768NSFsid as RVC_Model_f0,
48
+ SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
49
+ MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
50
+ )
51
+ from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
52
+ from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
53
+ from process_ckpt import savee
54
+
55
+ global global_step
56
+ global_step = 0
57
+
58
+
59
+ class EpochRecorder:
60
+ def __init__(self):
61
+ self.last_time = ttime()
62
+
63
+ def record(self):
64
+ now_time = ttime()
65
+ elapsed_time = now_time - self.last_time
66
+ self.last_time = now_time
67
+ elapsed_time_str = str(datetime.timedelta(seconds=elapsed_time))
68
+ current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
69
+ return f"[{current_time}] | ({elapsed_time_str})"
70
+
71
+
72
+ def main():
73
+ n_gpus = torch.cuda.device_count()
74
+ if torch.cuda.is_available() == False and torch.backends.mps.is_available() == True:
75
+ n_gpus = 1
76
+ os.environ["MASTER_ADDR"] = "localhost"
77
+ os.environ["MASTER_PORT"] = str(randint(20000, 55555))
78
+ children = []
79
+ for i in range(n_gpus):
80
+ subproc = mp.Process(
81
+ target=run,
82
+ args=(
83
+ i,
84
+ n_gpus,
85
+ hps,
86
+ ),
87
+ )
88
+ children.append(subproc)
89
+ subproc.start()
90
+
91
+ for i in range(n_gpus):
92
+ children[i].join()
93
+
94
+ def reset_stop_flag():
95
+ with open("csvdb/stop.csv", "w+", newline="") as STOPCSVwrite:
96
+ csv_writer = csv.writer(STOPCSVwrite, delimiter=",")
97
+ csv_writer.writerow(["False"])
98
+
99
+ def create_model(hps, model_f0, model_nof0):
100
+ filter_length_adjusted = hps.data.filter_length // 2 + 1
101
+ segment_size_adjusted = hps.train.segment_size // hps.data.hop_length
102
+ is_half = hps.train.fp16_run
103
+ sr = hps.sample_rate
104
+
105
+ model = model_f0 if hps.if_f0 == 1 else model_nof0
106
+
107
+ return model(
108
+ filter_length_adjusted,
109
+ segment_size_adjusted,
110
+ **hps.model,
111
+ is_half=is_half,
112
+ sr=sr
113
+ )
114
+
115
+ def move_model_to_cuda_if_available(model, rank):
116
+ if torch.cuda.is_available():
117
+ return model.cuda(rank)
118
+ else:
119
+ return model
120
+
121
+ def create_optimizer(model, hps):
122
+ return torch.optim.AdamW(
123
+ model.parameters(),
124
+ hps.train.learning_rate,
125
+ betas=hps.train.betas,
126
+ eps=hps.train.eps,
127
+ )
128
+
129
+ def create_ddp_model(model, rank):
130
+ if torch.cuda.is_available():
131
+ return DDP(model, device_ids=[rank])
132
+ else:
133
+ return DDP(model)
134
+
135
+ def create_dataset(hps, if_f0=True):
136
+ return TextAudioLoaderMultiNSFsid(hps.data.training_files, hps.data) if if_f0 else TextAudioLoader(hps.data.training_files, hps.data)
137
+
138
+ def create_sampler(dataset, batch_size, n_gpus, rank):
139
+ return DistributedBucketSampler(
140
+ dataset,
141
+ batch_size * n_gpus,
142
+ # [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1200,1400], # 16s
143
+ [100, 200, 300, 400, 500, 600, 700, 800, 900], # 16s
144
+ num_replicas=n_gpus,
145
+ rank=rank,
146
+ shuffle=True,
147
+ )
148
+
149
+ def set_collate_fn(if_f0=True):
150
+ return TextAudioCollateMultiNSFsid() if if_f0 else TextAudioCollate()
151
+
152
+ def run(rank, n_gpus, hps):
153
+ global global_step
154
+ if rank == 0:
155
+ logger = utils.get_logger(hps.model_dir)
156
+ logger.info(hps)
157
+ # utils.check_git_hash(hps.model_dir)
158
+ writer = SummaryWriter(log_dir=hps.model_dir)
159
+ writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
160
+
161
+ dist.init_process_group(
162
+ backend="gloo", init_method="env://", world_size=n_gpus, rank=rank
163
+ )
164
+ torch.manual_seed(hps.train.seed)
165
+ if torch.cuda.is_available():
166
+ torch.cuda.set_device(rank)
167
+
168
+
169
+ train_dataset = TextAudioLoaderMultiNSFsid(
170
+ hps.data.training_files, hps.data
171
+ ) if hps.if_f0 == 1 else TextAudioLoader(hps.data.training_files, hps.data)
172
+
173
+ train_sampler = DistributedBucketSampler(
174
+ train_dataset,
175
+ hps.train.batch_size * n_gpus,
176
+ # [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1200,1400], # 16s
177
+ [100, 200, 300, 400, 500, 600, 700, 800, 900], # 16s
178
+ num_replicas=n_gpus,
179
+ rank=rank,
180
+ shuffle=True,
181
+ )
182
+ # It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.
183
+ # num_workers=8 -> num_workers=4
184
+
185
+ collate_fn = TextAudioCollateMultiNSFsid() if hps.if_f0 == 1 else TextAudioCollate()
186
+ train_loader = DataLoader(
187
+ train_dataset,
188
+ num_workers=4,
189
+ shuffle=False,
190
+ pin_memory=True,
191
+ collate_fn=collate_fn,
192
+ batch_sampler=train_sampler,
193
+ persistent_workers=True,
194
+ prefetch_factor=8,
195
+ )
196
+
197
+ net_g = create_model(hps, RVC_Model_f0, RVC_Model_nof0)
198
+
199
+ net_g = move_model_to_cuda_if_available(net_g, rank)
200
+ net_d = move_model_to_cuda_if_available(MultiPeriodDiscriminator(hps.model.use_spectral_norm), rank)
201
+
202
+ optim_g = create_optimizer(net_g, hps)
203
+ optim_d = create_optimizer(net_d, hps)
204
+ # net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
205
+ # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
206
+ net_g = create_ddp_model(net_g, rank)
207
+ net_d = create_ddp_model(net_d, rank)
208
+
209
+ try: # 如果能加载自动resume
210
+ _, _, _, epoch_str = utils.load_checkpoint(
211
+ utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
212
+ ) # D多半加载没事
213
+ if rank == 0:
214
+ logger.info("loaded D")
215
+ # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
216
+ _, _, _, epoch_str = utils.load_checkpoint(
217
+ utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
218
+ )
219
+ global_step = (epoch_str - 1) * len(train_loader)
220
+ # epoch_str = 1
221
+ # global_step = 0
222
+ except: # 如果首次不能加载,加载pretrain
223
+ # traceback.print_exc()
224
+ epoch_str = 1
225
+ global_step = 0
226
+ if hps.pretrainG != "":
227
+ if rank == 0:
228
+ logger.info(f"loaded pretrained {hps.pretrainG}")
229
+ print(
230
+ net_g.module.load_state_dict(
231
+ torch.load(hps.pretrainG, map_location="cpu")["model"]
232
+ )
233
+ ) ##测试不加载优化器
234
+ if hps.pretrainD != "":
235
+ if rank == 0:
236
+ logger.info("loaded pretrained %s" % (hps.pretrainD))
237
+ print(
238
+ net_d.module.load_state_dict(
239
+ torch.load(hps.pretrainD, map_location="cpu")["model"]
240
+ )
241
+ )
242
+
243
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
244
+ optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
245
+ )
246
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
247
+ optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
248
+ )
249
+
250
+ scaler = GradScaler(enabled=hps.train.fp16_run)
251
+
252
+ cache = []
253
+ for epoch in range(epoch_str, hps.train.epochs + 1):
254
+ if rank == 0:
255
+ train_and_evaluate(
256
+ rank,
257
+ epoch,
258
+ hps,
259
+ [net_g, net_d],
260
+ [optim_g, optim_d],
261
+ [scheduler_g, scheduler_d],
262
+ scaler,
263
+ [train_loader, None],
264
+ logger,
265
+ [writer, writer_eval],
266
+ cache,
267
+ )
268
+ else:
269
+ train_and_evaluate(
270
+ rank,
271
+ epoch,
272
+ hps,
273
+ [net_g, net_d],
274
+ [optim_g, optim_d],
275
+ [scheduler_g, scheduler_d],
276
+ scaler,
277
+ [train_loader, None],
278
+ None,
279
+ None,
280
+ cache,
281
+ )
282
+ scheduler_g.step()
283
+ scheduler_d.step()
284
+
285
+
286
+ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache):
287
+ net_g, net_d = nets
288
+ optim_g, optim_d = optims
289
+ train_loader, eval_loader = loaders
290
+ writer, writer_eval = (writers if writers is not None else (None, None))
291
+
292
+ train_loader.batch_sampler.set_epoch(epoch)
293
+ global global_step
294
+
295
+ nets = [net_g, net_d]
296
+ for net in nets:
297
+ net.train()
298
+
299
+ def save_checkpoint(name):
300
+ ckpt = net_g.module.state_dict() if hasattr(net_g, "module") else net_g.state_dict()
301
+ result = savee(ckpt, hps.sample_rate, hps.if_f0, name, epoch, hps.version, hps, experiment_name)
302
+ logger.info("Saving final ckpt: {}".format(result))
303
+ sleep(1)
304
+
305
+ if hps.if_cache_data_in_gpu:
306
+ # Use Cache
307
+ data_iterator = cache
308
+ if len(cache) == 0:
309
+ gpu_available = torch.cuda.is_available()
310
+
311
+ for batch_idx, info in enumerate(train_loader):
312
+ # Unpack
313
+ info = list(info)
314
+ if hps.if_f0:
315
+ tensors = info
316
+ else:
317
+ # We consider that pitch and pitchf are not included in this case
318
+ tensors = info[:2] + info[4:]
319
+
320
+ # Load on CUDA
321
+ if gpu_available:
322
+ tensors = [tensor.cuda(rank, non_blocking=True) for tensor in tensors]
323
+
324
+ # Cache on list
325
+ cache.extend([(batch_idx, tuple(tensor for tensor in tensors if tensor is not None))])
326
+ else:
327
+ shuffle(cache)
328
+ else:
329
+ data_iterator = enumerate(train_loader)
330
+
331
+ def to_gpu_if_available(tensor):
332
+ return tensor.cuda(rank, non_blocking=True) if torch.cuda.is_available() else tensor
333
+
334
+ # Run steps
335
+ gpu_available = torch.cuda.is_available()
336
+ epoch_recorder = EpochRecorder()
337
+ fp16_run = hps.train.fp16_run
338
+ c_mel = hps.train.c_mel
339
+
340
+ for batch_idx, info in data_iterator:
341
+ # Data
342
+ ## Unpack
343
+ if hps.if_f0 == 1:
344
+ phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid = info
345
+ else:
346
+ phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
347
+ ## Load on CUDA
348
+ if (not hps.if_cache_data_in_gpu) and gpu_available:
349
+ phone = to_gpu_if_available(phone)
350
+ phone_lengths = to_gpu_if_available(phone_lengths)
351
+ sid = to_gpu_if_available(sid)
352
+ spec = to_gpu_if_available(spec)
353
+ spec_lengths = to_gpu_if_available(spec_lengths)
354
+ wave = to_gpu_if_available(wave)
355
+
356
+ if hps.if_f0 == 1:
357
+ pitch = to_gpu_if_available(pitch)
358
+ pitchf = to_gpu_if_available(pitchf)
359
+
360
+ # Calculate
361
+ with autocast(enabled=fp16_run):
362
+ if hps.if_f0 == 1:
363
+ y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = \
364
+ net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
365
+ else:
366
+ y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = \
367
+ net_g(phone, phone_lengths, spec, spec_lengths, sid)
368
+ mel = spec_to_mel_torch(spec, hps.data.filter_length, hps.data.n_mel_channels,
369
+ hps.data.sampling_rate, hps.data.mel_fmin, hps.data.mel_fmax)
370
+
371
+ y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
372
+ y_hat_mel = mel_spectrogram_torch(
373
+ y_hat.float().squeeze(1),
374
+ hps.data.filter_length,
375
+ hps.data.n_mel_channels,
376
+ hps.data.sampling_rate,
377
+ hps.data.hop_length,
378
+ hps.data.win_length,
379
+ hps.data.mel_fmin,
380
+ hps.data.mel_fmax,
381
+ )
382
+
383
+ if fp16_run: y_hat_mel = y_hat_mel.half()
384
+
385
+ wave = commons.slice_segments(wave, ids_slice * hps.data.hop_length,
386
+ hps.train.segment_size) # slice
387
+
388
+ y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
389
+
390
+ loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
391
+ net_d_params = net_d.parameters()
392
+ net_g_params = net_g.parameters()
393
+ lr_scalar = optim_g.param_groups[0]["lr"]
394
+
395
+ optim_d.zero_grad()
396
+ scaler.scale(loss_disc).backward()
397
+ scaler.unscale_(optim_d)
398
+ grad_norm_d = commons.clip_grad_value_(net_d_params, None)
399
+ scaler.step(optim_d)
400
+
401
+ with autocast(enabled=fp16_run):
402
+ y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
403
+
404
+ loss_mel = F.l1_loss(y_mel, y_hat_mel) * c_mel
405
+ loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
406
+ loss_fm = feature_loss(fmap_r, fmap_g)
407
+ loss_gen, losses_gen = generator_loss(y_d_hat_g)
408
+ loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
409
+
410
+ optim_g.zero_grad()
411
+ scaler.scale(loss_gen_all).backward()
412
+ scaler.unscale_(optim_g)
413
+ grad_norm_g = commons.clip_grad_value_(net_g_params, None)
414
+ scaler.step(optim_g)
415
+ scaler.update()
416
+
417
+ if rank == 0 and global_step % hps.train.log_interval == 0:
418
+ lr = lr_scalar # use stored lr scalar here
419
+ logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader)))
420
+
421
+ # Amor For Tensorboard display
422
+ loss_mel, loss_kl = min(loss_mel, 75), min(loss_kl, 9)
423
+
424
+ scalar_dict = {
425
+ "loss/g/total": loss_gen_all,
426
+ "loss/d/total": loss_disc,
427
+ "learning_rate": lr,
428
+ "grad_norm_d": grad_norm_d,
429
+ "grad_norm_g": grad_norm_g,
430
+ "loss/g/fm": loss_fm,
431
+ "loss/g/mel": loss_mel,
432
+ "loss/g/kl": loss_kl,
433
+ **{"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)},
434
+ **{"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)},
435
+ **{"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)},
436
+ }
437
+
438
+ image_dict = {
439
+ "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
440
+ "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
441
+ "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
442
+ }
443
+
444
+ utils.summarize(
445
+ writer=writer,
446
+ global_step=global_step,
447
+ images=image_dict,
448
+ scalars=scalar_dict,
449
+ )
450
+ global_step += 1
451
+
452
+ if epoch % hps.save_every_epoch == 0:
453
+ if rank == 0:
454
+ save_format = str(2333333) if hps.if_latest else str(global_step)
455
+ model_dir = hps.model_dir
456
+ learning_rate = hps.train.learning_rate
457
+ name_epoch = f"{hps.name}_e{epoch}"
458
+ models = {'G': net_g, 'D': net_d}
459
+ optims = {'G': optim_g, 'D': optim_d}
460
+
461
+ for model_name, model in models.items():
462
+ path = os.path.join(model_dir, f"{model_name}_{save_format}.pth")
463
+ utils.save_checkpoint(model, optims[model_name], learning_rate, epoch, path)
464
+
465
+ if hps.save_every_weights == "1":
466
+ ckpt = net_g.module.state_dict() if hasattr(net_g, "module") else net_g.state_dict()
467
+ logger.info(
468
+ "saving ckpt %s_%s"
469
+ % (
470
+ name_epoch,
471
+ savee(
472
+ ckpt,
473
+ hps.sample_rate,
474
+ hps.if_f0,
475
+ f"{name_epoch}_s{global_step}",
476
+ epoch,
477
+ hps.version,
478
+ hps,
479
+ experiment_name,
480
+ ),
481
+ )
482
+ )
483
+
484
+ stopbtn = False
485
+ try:
486
+ with open("csvdb/stop.csv", 'r') as csv_file:
487
+ stopbtn_str = next(csv.reader(csv_file), [None])[0]
488
+ if stopbtn_str is not None: stopbtn = stopbtn_str.lower() == 'true'
489
+ except (ValueError, TypeError, FileNotFoundError, IndexError) as e:
490
+ print(f"Handling exception: {e}")
491
+ stopbtn = False
492
+
493
+ if stopbtn:
494
+ logger.info("Stop Button was pressed. The program is closed.")
495
+ ckpt = net_g.module.state_dict() if hasattr(net_g, "module") else net_g.state_dict()
496
+ logger.info(f"Saving final ckpt:{savee(ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps, experiment_name)}")
497
+ sleep(1)
498
+ reset_stop_flag()
499
+ os._exit(2333333)
500
+
501
+ if rank == 0:
502
+ logger.info(f"====> Epoch: {epoch} {epoch_recorder.record()}")
503
+
504
+ if epoch >= hps.total_epoch:
505
+ logger.info("Training is done. The program is closed.")
506
+ save_checkpoint(hps.name)
507
+ os._exit(2333333)
508
+
509
+
510
+ if __name__ == "__main__":
511
+ torch.multiprocessing.set_start_method("spawn")
512
+ main()