Antigravity AI commited on
Commit
6a43f46
·
1 Parent(s): 5321307

Optimize RVC training: resumable training with checkpoints every 10 epochs and advanced UI options

Browse files
Files changed (2) hide show
  1. app.py +24 -4
  2. pipeline/rvc_training.py +6 -6
app.py CHANGED
@@ -248,6 +248,26 @@ with gr.Blocks(title="Voice Clone RVC", theme=gr.themes.Soft()) as app:
248
  step=1,
249
  label="Epochs (Iteraciones de entrenamiento)",
250
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  rvc_train_btn = gr.Button(
252
  "Iniciar Entrenamiento RVC",
253
  variant="primary",
@@ -265,14 +285,14 @@ with gr.Blocks(title="Voice Clone RVC", theme=gr.themes.Soft()) as app:
265
  )
266
 
267
  gr.Markdown(
268
- "**Advertencia sobre ZeroGPU:**\n"
269
- "- El entrenamiento de modelos requiere mucho tiempo. ZeroGPU puede cancelar el proceso si supera el tiempo máximo asignado (usualmente 10 minutos).\n"
270
- "- Si falla por timeout, considera bajar los epochs o usar menos audio."
271
  )
272
 
273
  rvc_train_btn.click(
274
  fn=train_rvc_model,
275
- inputs=[rvc_audio, rvc_model_name, rvc_epochs],
276
  outputs=[rvc_status, rvc_download],
277
  )
278
 
 
248
  step=1,
249
  label="Epochs (Iteraciones de entrenamiento)",
250
  )
251
+ with gr.Accordion("Opciones Avanzadas", open=False):
252
+ rvc_f0_method = gr.Dropdown(
253
+ choices=["rmvpe", "crepe", "fcpe"],
254
+ value="rmvpe",
255
+ label="Método de Extracción de Pitch (f0)"
256
+ )
257
+ rvc_batch_size = gr.Slider(
258
+ minimum=1,
259
+ maximum=24,
260
+ value=4,
261
+ step=1,
262
+ label="Batch Size (Tamaño de lote)"
263
+ )
264
+ rvc_save_every = gr.Slider(
265
+ minimum=1,
266
+ maximum=50,
267
+ value=10,
268
+ step=1,
269
+ label="Guardar Checkpoint cada (Epochs)"
270
+ )
271
  rvc_train_btn = gr.Button(
272
  "Iniciar Entrenamiento RVC",
273
  variant="primary",
 
285
  )
286
 
287
  gr.Markdown(
288
+ "**🚀 Entrenamiento Resumible:**\n"
289
+ "- Si ZeroGPU corta el entrenamiento por tiempo (10 min), puedes volver a dar clic en el botón y el proceso continuará desde el último punto guardado.\n"
290
+ "- Los checkpoints se guardan cada **10 epochs** por defecto."
291
  )
292
 
293
  rvc_train_btn.click(
294
  fn=train_rvc_model,
295
+ inputs=[rvc_audio, rvc_model_name, rvc_epochs, rvc_batch_size, rvc_f0_method, rvc_save_every],
296
  outputs=[rvc_status, rvc_download],
297
  )
298
 
pipeline/rvc_training.py CHANGED
@@ -45,7 +45,7 @@ except Exception as e:
45
  RVC_LOGIC_AVAILABLE = False
46
 
47
  @spaces.GPU(duration=1000)
48
- def train_rvc_model(audio_path, model_name, epochs=100, progress=None):
49
  if not RVC_LOGIC_AVAILABLE:
50
  return f"Error: rvc_logic module failed to load. Reason: {RVC_IMPORT_ERROR}", None
51
 
@@ -75,20 +75,20 @@ def train_rvc_model(audio_path, model_name, epochs=100, progress=None):
75
  )
76
 
77
  # 3. Extract Features
78
- p(0.4, "Step 3/4: Extracting features (F0 & Content)...")
79
  extract.extract_features(
80
  model_name=model_name,
81
- f0_method=F0Method.RMVPE,
82
  embedder_model=EmbedderModel.CONTENTVEC
83
  )
84
 
85
  # 4. Train
86
- p(0.6, f"Step 4/4: Training for {epochs} epochs...")
87
  result_paths = train.run_training(
88
  model_name=model_name,
89
  num_epochs=epochs,
90
- batch_size=4,
91
- save_interval=epochs
92
  )
93
 
94
  if not result_paths or len(result_paths) < 2:
 
45
  RVC_LOGIC_AVAILABLE = False
46
 
47
  @spaces.GPU(duration=1000)
48
+ def train_rvc_model(audio_path, model_name, epochs=100, batch_size=4, f0_method="rmvpe", save_every=10, progress=None):
49
  if not RVC_LOGIC_AVAILABLE:
50
  return f"Error: rvc_logic module failed to load. Reason: {RVC_IMPORT_ERROR}", None
51
 
 
75
  )
76
 
77
  # 3. Extract Features
78
+ p(0.4, f"Step 3/4: Extracting features (Method: {f0_method})...")
79
  extract.extract_features(
80
  model_name=model_name,
81
+ f0_method=f0_method,
82
  embedder_model=EmbedderModel.CONTENTVEC
83
  )
84
 
85
  # 4. Train
86
+ p(0.6, f"Step 4/4: Training for {epochs} epochs (Batch: {batch_size}, Checkpoints: {save_every})...")
87
  result_paths = train.run_training(
88
  model_name=model_name,
89
  num_epochs=epochs,
90
+ batch_size=batch_size,
91
+ save_interval=save_every
92
  )
93
 
94
  if not result_paths or len(result_paths) < 2: