Pj12 commited on
Commit
8bdb1e0
·
verified ·
1 Parent(s): c6114a2

Upload infer.py

Browse files
Files changed (1) hide show
  1. infer.py +496 -0
infer.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import torch
5
+ import librosa
6
+ import logging
7
+ import traceback
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import noisereduce as nr
11
+ from pedalboard import (
12
+ Pedalboard,
13
+ Chorus,
14
+ Distortion,
15
+ Reverb,
16
+ PitchShift,
17
+ Limiter,
18
+ Gain,
19
+ Bitcrush,
20
+ Clipping,
21
+ Compressor,
22
+ Delay,
23
+ )
24
+
25
+ from scipy.io import wavfile
26
+ from audio_upscaler import upscale
27
+
28
+ now_dir = os.getcwd()
29
+ sys.path.append(now_dir)
30
+
31
+ from rvc.infer.pipeline import Pipeline as VC
32
+ from rvc.lib.utils import load_audio_infer, load_embedding
33
+ from rvc.lib.tools.split_audio import process_audio, merge_audio
34
+ from rvc.lib.algorithm.synthesizers import Synthesizer
35
+ from rvc.configs.config import Config
36
+
37
+ logging.getLogger("httpx").setLevel(logging.WARNING)
38
+ logging.getLogger("httpcore").setLevel(logging.WARNING)
39
+ logging.getLogger("faiss").setLevel(logging.WARNING)
40
+ logging.getLogger("faiss.loader").setLevel(logging.WARNING)
41
+
42
+
43
+ class VoiceConverter:
44
+ """
45
+ A class for performing voice conversion using the Retrieval-Based Voice Conversion (RVC) method.
46
+ """
47
+
48
+ def __init__(self):
49
+ """
50
+ Initializes the VoiceConverter with default configuration, and sets up models and parameters.
51
+ """
52
+ self.config = Config() # Load RVC configuration
53
+ self.hubert_model = (
54
+ None # Initialize the Hubert model (for embedding extraction)
55
+ )
56
+ self.last_embedder_model = None # Last used embedder model
57
+ self.tgt_sr = None # Target sampling rate for the output audio
58
+ self.net_g = None # Generator network for voice conversion
59
+ self.vc = None # Voice conversion pipeline instance
60
+ self.cpt = None # Checkpoint for loading model weights
61
+ self.version = None # Model version
62
+ self.n_spk = None # Number of speakers in the model
63
+ self.use_f0 = None # Whether the model uses F0
64
+
65
+ def load_hubert(self, embedder_model: str, embedder_model_custom: str = None):
66
+ """
67
+ Loads the HuBERT model for speaker embedding extraction.
68
+
69
+ Args:
70
+ embedder_model (str): Path to the pre-trained HuBERT model.
71
+ embedder_model_custom (str): Path to the custom HuBERT model.
72
+ """
73
+ self.hubert_model = load_embedding(embedder_model, embedder_model_custom)
74
+ self.hubert_model.to(self.config.device)
75
+ self.hubert_model = (
76
+ self.hubert_model.half()
77
+ if self.config.is_half
78
+ else self.hubert_model.float()
79
+ )
80
+ self.hubert_model.eval()
81
+
82
+ @staticmethod
83
+ def remove_audio_noise(data, sr, reduction_strength=0.7):
84
+ """
85
+ Removes noise from an audio file using the NoiseReduce library.
86
+
87
+ Args:
88
+ data (numpy.ndarray): The audio data as a NumPy array.
89
+ sr (int): The sample rate of the audio data.
90
+ reduction_strength (float): Strength of the noise reduction. Default is 0.7.
91
+ """
92
+ try:
93
+ reduced_noise = nr.reduce_noise(
94
+ y=data, sr=sr, prop_decrease=reduction_strength
95
+ )
96
+ return reduced_noise
97
+ except Exception as error:
98
+ print(f"An error occurred removing audio noise: {error}")
99
+ return None
100
+
101
+ @staticmethod
102
+ def convert_audio_format(input_path, output_path, output_format):
103
+ """
104
+ Converts an audio file to a specified output format.
105
+
106
+ Args:
107
+ input_path (str): Path to the input audio file.
108
+ output_path (str): Path to the output audio file.
109
+ output_format (str): Desired audio format (e.g., "WAV", "MP3").
110
+ """
111
+ try:
112
+ if output_format != "WAV":
113
+ print(f"Converting audio to {output_format} format...")
114
+ audio, sample_rate = librosa.load(input_path, sr=None)
115
+ common_sample_rates = [
116
+ 8000,
117
+ 11025,
118
+ 12000,
119
+ 16000,
120
+ 22050,
121
+ 24000,
122
+ 32000,
123
+ 44100,
124
+ 48000,
125
+ ]
126
+ target_sr = min(common_sample_rates, key=lambda x: abs(x - sample_rate))
127
+ audio = librosa.resample(
128
+ audio, orig_sr=sample_rate, target_sr=target_sr
129
+ )
130
+ sf.write(output_path, audio, target_sr, format=output_format.lower())
131
+ return output_path
132
+ except Exception as error:
133
+ print(f"An error occurred converting the audio format: {error}")
134
+
135
+ @staticmethod
136
+ def post_process_audio(
137
+ audio_input,
138
+ sample_rate,
139
+ **kwargs,
140
+ ):
141
+ board = Pedalboard()
142
+ if kwargs.get("reverb", False):
143
+ reverb = Reverb(
144
+ room_size=kwargs.get("reverb_room_size", 0.5),
145
+ damping=kwargs.get("reverb_damping", 0.5),
146
+ wet_level=kwargs.get("reverb_wet_level", 0.33),
147
+ dry_level=kwargs.get("reverb_dry_level", 0.4),
148
+ width=kwargs.get("reverb_width", 1.0),
149
+ freeze_mode=kwargs.get("reverb_freeze_mode", 0),
150
+ )
151
+ board.append(reverb)
152
+ if kwargs.get("pitch_shift", False):
153
+ pitch_shift = PitchShift(semitones=kwargs.get("pitch_shift_semitones", 0))
154
+ board.append(pitch_shift)
155
+ if kwargs.get("limiter", False):
156
+ limiter = Limiter(
157
+ threshold_db=kwargs.get("limiter_threshold", -6),
158
+ release_ms=kwargs.get("limiter_release", 0.05),
159
+ )
160
+ board.append(limiter)
161
+ if kwargs.get("gain", False):
162
+ gain = Gain(gain_db=kwargs.get("gain_db", 0))
163
+ board.append(gain)
164
+ if kwargs.get("distortion", False):
165
+ distortion = Distortion(drive_db=kwargs.get("distortion_gain", 25))
166
+ board.append(distortion)
167
+ if kwargs.get("chorus", False):
168
+ chorus = Chorus(
169
+ rate_hz=kwargs.get("chorus_rate", 1.0),
170
+ depth=kwargs.get("chorus_depth", 0.25),
171
+ centre_delay_ms=kwargs.get("chorus_delay", 7),
172
+ feedback=kwargs.get("chorus_feedback", 0.0),
173
+ mix=kwargs.get("chorus_mix", 0.5),
174
+ )
175
+ board.append(chorus)
176
+ if kwargs.get("bitcrush", False):
177
+ bitcrush = Bitcrush(bit_depth=kwargs.get("bitcrush_bit_depth", 8))
178
+ board.append(bitcrush)
179
+ if kwargs.get("clipping", False):
180
+ clipping = Clipping(threshold_db=kwargs.get("clipping_threshold", 0))
181
+ board.append(clipping)
182
+ if kwargs.get("compressor", False):
183
+ compressor = Compressor(
184
+ threshold_db=kwargs.get("compressor_threshold", 0),
185
+ ratio=kwargs.get("compressor_ratio", 1),
186
+ attack_ms=kwargs.get("compressor_attack", 1.0),
187
+ release_ms=kwargs.get("compressor_release", 100),
188
+ )
189
+ board.append(compressor)
190
+ if kwargs.get("delay", False):
191
+ delay = Delay(
192
+ delay_seconds=kwargs.get("delay_seconds", 0.5),
193
+ feedback=kwargs.get("delay_feedback", 0.0),
194
+ mix=kwargs.get("delay_mix", 0.5),
195
+ )
196
+ board.append(delay)
197
+ return board(audio_input, sample_rate)
198
+
199
+ def convert_audio(
200
+ self,
201
+ audio_input_path: str,
202
+ audio_output_path: str,
203
+ model_path: str,
204
+ index_path: str,
205
+ pitch: int = 0,
206
+ f0_file: str = None,
207
+ f0_method: str = "rmvpe",
208
+ index_rate: float = 0.75,
209
+ volume_envelope: float = 1,
210
+ protect: float = 0.5,
211
+ hop_length: int = 128,
212
+ split_audio: bool = False,
213
+ f0_autotune: bool = False,
214
+ filter_radius: int = 3,
215
+ embedder_model: str = "contentvec",
216
+ embedder_model_custom: str = None,
217
+ clean_audio: bool = False,
218
+ clean_strength: float = 0.5,
219
+ export_format: str = "WAV",
220
+ upscale_audio: bool = False,
221
+ post_process: bool = False,
222
+ resample_sr: int = 0,
223
+ sid: int = 0,
224
+ **kwargs,
225
+ ):
226
+ """
227
+ Performs voice conversion on the input audio.
228
+
229
+ Args:
230
+ pitch (int): Key for F0 up-sampling.
231
+ filter_radius (int): Radius for filtering.
232
+ index_rate (float): Rate for index matching.
233
+ volume_envelope (int): RMS mix rate.
234
+ protect (float): Protection rate for certain audio segments.
235
+ hop_length (int): Hop length for audio processing.
236
+ f0_method (str): Method for F0 extraction.
237
+ audio_input_path (str): Path to the input audio file.
238
+ audio_output_path (str): Path to the output audio file.
239
+ model_path (str): Path to the voice conversion model.
240
+ index_path (str): Path to the index file.
241
+ split_audio (bool): Whether to split the audio for processing.
242
+ f0_autotune (bool): Whether to use F0 autotune.
243
+ clean_audio (bool): Whether to clean the audio.
244
+ clean_strength (float): Strength of the audio cleaning.
245
+ export_format (str): Format for exporting the audio.
246
+ upscale_audio (bool): Whether to upscale the audio.
247
+ f0_file (str): Path to the F0 file.
248
+ embedder_model (str): Path to the embedder model.
249
+ embedder_model_custom (str): Path to the custom embedder model.
250
+ resample_sr (int, optional): Resample sampling rate. Default is 0.
251
+ sid (int, optional): Speaker ID. Default is 0.
252
+ **kwargs: Additional keyword arguments.
253
+ """
254
+ self.get_vc(model_path, sid)
255
+ try:
256
+ start_time = time.time()
257
+ print(f"Converting audio '{audio_input_path}'...")
258
+
259
+ if upscale_audio == True:
260
+ upscale(audio_input_path, audio_input_path)
261
+ audio = load_audio_infer(
262
+ audio_input_path,
263
+ 16000,
264
+ **kwargs,
265
+ )
266
+ audio_max = np.abs(audio).max() / 0.95
267
+
268
+ if audio_max > 1:
269
+ audio /= audio_max
270
+
271
+ if not self.hubert_model or embedder_model != self.last_embedder_model:
272
+ self.load_hubert(embedder_model, embedder_model_custom)
273
+ self.last_embedder_model = embedder_model
274
+
275
+ file_index = (
276
+ index_path.strip()
277
+ .strip('"')
278
+ .strip("\n")
279
+ .strip('"')
280
+ .strip()
281
+ .replace("trained", "added")
282
+ )
283
+
284
+ if self.tgt_sr != resample_sr >= 16000:
285
+ self.tgt_sr = resample_sr
286
+
287
+ if split_audio:
288
+ chunks, intervals = process_audio(audio, 16000)
289
+ print(f"Audio split into {len(chunks)} chunks for processing.")
290
+ else:
291
+ chunks = []
292
+ chunks.append(audio)
293
+
294
+ converted_chunks = []
295
+ for c in chunks:
296
+ audio_opt = self.vc.pipeline(
297
+ model=self.hubert_model,
298
+ net_g=self.net_g,
299
+ sid=sid,
300
+ audio=c,
301
+ pitch=pitch,
302
+ f0_method=f0_method,
303
+ file_index=file_index,
304
+ index_rate=index_rate,
305
+ pitch_guidance=self.use_f0,
306
+ filter_radius=filter_radius,
307
+ volume_envelope=volume_envelope,
308
+ version=self.version,
309
+ protect=protect,
310
+ hop_length=hop_length,
311
+ f0_autotune=f0_autotune,
312
+ f0_file=f0_file,
313
+ )
314
+ converted_chunks.append(audio_opt)
315
+ if split_audio:
316
+ print(f"Converted audio chunk {len(converted_chunks)}")
317
+
318
+ if split_audio:
319
+ audio_opt = merge_audio(converted_chunks, intervals, 16000, self.tgt_sr)
320
+ else:
321
+ audio_opt = converted_chunks[0]
322
+
323
+ if clean_audio:
324
+ cleaned_audio = self.remove_audio_noise(
325
+ audio_opt, self.tgt_sr, clean_strength
326
+ )
327
+ if cleaned_audio is not None:
328
+ audio_opt = cleaned_audio
329
+
330
+ if post_process:
331
+ audio_opt = self.post_process_audio(
332
+ audio_input=audio_opt,
333
+ sample_rate=self.tgt_sr,
334
+ **kwargs,
335
+ )
336
+
337
+ sf.write(audio_output_path, audio_opt, self.tgt_sr, format="WAV")
338
+ output_path_format = audio_output_path.replace(
339
+ ".wav", f".{export_format.lower()}"
340
+ )
341
+ audio_output_path = self.convert_audio_format(
342
+ audio_output_path, output_path_format, export_format
343
+ )
344
+
345
+ elapsed_time = time.time() - start_time
346
+ print(
347
+ f"Conversion completed at '{audio_output_path}' in {elapsed_time:.2f} seconds."
348
+ )
349
+ except Exception as error:
350
+ print(f"An error occurred during audio conversion: {error}")
351
+ print(traceback.format_exc())
352
+
353
+ def convert_audio_batch(
354
+ self,
355
+ audio_input_paths: str,
356
+ audio_output_path: str,
357
+ **kwargs,
358
+ ):
359
+ """
360
+ Performs voice conversion on a batch of input audio files.
361
+
362
+ Args:
363
+ audio_input_paths (str): List of paths to the input audio files.
364
+ audio_output_path (str): Path to the output audio file.
365
+ resample_sr (int, optional): Resample sampling rate. Default is 0.
366
+ sid (int, optional): Speaker ID. Default is 0.
367
+ **kwargs: Additional keyword arguments.
368
+ """
369
+ pid = os.getpid()
370
+ try:
371
+ with open(
372
+ os.path.join(now_dir, "assets", "infer_pid.txt"), "w"
373
+ ) as pid_file:
374
+ pid_file.write(str(pid))
375
+ start_time = time.time()
376
+ print(f"Converting audio batch '{audio_input_paths}'...")
377
+ audio_files = [
378
+ f
379
+ for f in os.listdir(audio_input_paths)
380
+ if f.endswith(
381
+ (
382
+ "wav",
383
+ "mp3",
384
+ "flac",
385
+ "ogg",
386
+ "opus",
387
+ "m4a",
388
+ "mp4",
389
+ "aac",
390
+ "alac",
391
+ "wma",
392
+ "aiff",
393
+ "webm",
394
+ "ac3",
395
+ )
396
+ )
397
+ ]
398
+ print(f"Detected {len(audio_files)} audio files for inference.")
399
+ for a in audio_files:
400
+ new_input = os.path.join(audio_input_paths, a)
401
+ new_output = os.path.splitext(a)[0] + "_output.wav"
402
+ new_output = os.path.join(audio_output_path, new_output)
403
+ if os.path.exists(new_output):
404
+ continue
405
+ self.convert_audio(
406
+ audio_input_path=new_input,
407
+ audio_output_path=new_output,
408
+ **kwargs,
409
+ )
410
+ print(f"Conversion completed at '{audio_input_paths}'.")
411
+ elapsed_time = time.time() - start_time
412
+ print(f"Batch conversion completed in {elapsed_time:.2f} seconds.")
413
+ except Exception as error:
414
+ print(f"An error occurred during audio batch conversion: {error}")
415
+ print(traceback.format_exc())
416
+ finally:
417
+ os.remove(os.path.join(now_dir, "assets", "infer_pid.txt"))
418
+
419
+ def get_vc(self, weight_root, sid):
420
+ """
421
+ Loads the voice conversion model and sets up the pipeline.
422
+
423
+ Args:
424
+ weight_root (str): Path to the model weights.
425
+ sid (int): Speaker ID.
426
+ """
427
+ if sid == "" or sid == []:
428
+ self.cleanup_model()
429
+ if torch.cuda.is_available():
430
+ torch.cuda.empty_cache()
431
+
432
+ self.load_model(weight_root)
433
+
434
+ if self.cpt is not None:
435
+ self.setup_network()
436
+ self.setup_vc_instance()
437
+
438
+ def cleanup_model(self):
439
+ """
440
+ Cleans up the model and releases resources.
441
+ """
442
+ if self.hubert_model is not None:
443
+ del self.net_g, self.n_spk, self.vc, self.hubert_model, self.tgt_sr
444
+ self.hubert_model = self.net_g = self.n_spk = self.vc = self.tgt_sr = None
445
+ if torch.cuda.is_available():
446
+ torch.cuda.empty_cache()
447
+
448
+ del self.net_g, self.cpt
449
+ if torch.cuda.is_available():
450
+ torch.cuda.empty_cache()
451
+ self.cpt = None
452
+
453
+ def load_model(self, weight_root):
454
+ """
455
+ Loads the model weights from the specified path.
456
+
457
+ Args:
458
+ weight_root (str): Path to the model weights.
459
+ """
460
+ self.cpt = (
461
+ torch.load(weight_root, map_location="cpu")
462
+ if os.path.isfile(weight_root)
463
+ else None
464
+ )
465
+
466
+ def setup_network(self):
467
+ """
468
+ Sets up the network configuration based on the loaded checkpoint.
469
+ """
470
+ if self.cpt is not None:
471
+ self.tgt_sr = self.cpt["config"][-1]
472
+ self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0]
473
+ self.use_f0 = self.cpt.get("f0", 1)
474
+
475
+ self.version = self.cpt.get("version", "v1")
476
+ self.text_enc_hidden_dim = 768 if self.version == "v2" and os.path.exists("/content/Applio/Large_hubert.txt") == False else 256 if self.version == "v1" and os.path.exists("/content/Applio/Large_hubert.txt") == False else 1024
477
+ self.net_g = Synthesizer(
478
+ *self.cpt["config"],
479
+ use_f0=self.use_f0,
480
+ text_enc_hidden_dim=self.text_enc_hidden_dim,
481
+ is_half=self.config.is_half,
482
+ )
483
+ del self.net_g.enc_q
484
+ self.net_g.load_state_dict(self.cpt["weight"], strict=False)
485
+ self.net_g.eval().to(self.config.device)
486
+ self.net_g = (
487
+ self.net_g.half() if self.config.is_half else self.net_g.float()
488
+ )
489
+
490
+ def setup_vc_instance(self):
491
+ """
492
+ Sets up the voice conversion pipeline instance based on the target sampling rate and configuration.
493
+ """
494
+ if self.cpt is not None:
495
+ self.vc = VC(self.tgt_sr, self.config)
496
+ self.n_spk = self.cpt["config"][-3]