ibcplateformes Claude Opus 4.6 committed on
Commit
259efa9
·
1 Parent(s): 266f7ad

Skip HiFi-GAN training on CPU, use pre-trained model + FAISS index

Browse files

RVC training on CPU takes hours — impractical for a web app.
New approach on CPU:
- Preprocess + extract features (~5 min)
- Build FAISS index from voice embeddings (seconds)
- Use pre-trained RVC generator with user's index for inference
- Full training still available when GPU is detected

Also rewrote build_index to use faiss directly instead of the Applio script.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. app.py +1 -1
  2. pipeline/training.py +82 -32
app.py CHANGED
@@ -296,7 +296,7 @@ with gr.Blocks(
296
  maximum=30,
297
  value=10,
298
  step=5,
299
- label="Nombre d'époques (CPU: 10 20-30 min, 20 ≈ 45-60 min)",
300
  )
301
  train_btn = gr.Button(
302
  "Lancer l'entraînement",
 
296
  maximum=30,
297
  value=10,
298
  step=5,
299
+ label="Nombre d'époques (utilisé uniquement avec GPU)",
300
  )
301
  train_btn = gr.Button(
302
  "Lancer l'entraînement",
pipeline/training.py CHANGED
@@ -312,26 +312,59 @@ def train_model(
312
 
313
 
314
  def build_index(model_name: str):
315
- """Build FAISS index for the trained model. Runs on CPU (subprocess OK)."""
316
- _setup_applio_env()
317
 
318
- exp_dir = os.path.join(LOGS_DIR, model_name)
319
- index_script = os.path.join(APPLIO_DIR, "rvc", "train", "process", "extract_index.py")
 
 
 
320
 
321
- command = [sys.executable, index_script, exp_dir, "Auto"]
 
322
 
323
- logger.info(f"Building index for {model_name}...")
324
- result = subprocess.run(command, capture_output=True, text=True, cwd=APPLIO_DIR)
 
325
 
326
- if result.returncode != 0:
327
- logger.warning(f"Index building failed: {result.stderr[-300:]}")
 
 
 
 
 
 
 
 
 
 
328
  return None
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  index_path = os.path.join(exp_dir, f"{model_name}.index")
331
- if os.path.exists(index_path):
332
- logger.info(f"Index built: {index_path}")
333
- return index_path
334
- return None
335
 
336
 
337
  def find_trained_model(model_name: str):
@@ -347,52 +380,69 @@ def find_trained_model(model_name: str):
347
  if f.endswith(".pth") and f.startswith(model_name):
348
  return os.path.join(exp_dir, f)
349
 
350
- if os.path.exists(LOGS_DIR):
351
- for f in sorted(os.listdir(LOGS_DIR), reverse=True):
352
- if f.endswith(".pth") and f.startswith(model_name):
353
- return os.path.join(LOGS_DIR, f)
354
 
 
 
 
 
 
 
355
  return None
356
 
357
 
358
  def full_training_pipeline(
359
  audio_path: str,
360
  model_name: str,
361
- epochs: int = 20,
362
  sample_rate: int = 40000,
363
- batch_size: int = 8,
364
  progress_callback=None,
365
  ):
366
  """
367
- Run the complete training pipeline.
 
368
  Returns (pth_path, index_path) on success.
369
  """
 
370
  from pipeline.storage import upload_model, LOCAL_MODELS_DIR
371
 
 
 
372
  if progress_callback:
373
- progress_callback(0.05, "Preprocessing audio...")
374
 
375
  n_slices = preprocess(model_name, audio_path, sample_rate)
376
 
377
  if progress_callback:
378
- progress_callback(0.15, f"Preprocessing done ({n_slices} segments). Extracting features...")
379
 
380
  extract_features(model_name, sample_rate)
381
 
382
  if progress_callback:
383
- progress_callback(0.35, "Features extracted. Training model...")
384
-
385
- train_model(model_name, sample_rate, epochs, batch_size)
386
-
387
- if progress_callback:
388
- progress_callback(0.85, "Training done. Building index...")
389
 
 
390
  index_path = build_index(model_name)
391
 
392
- pth_path = find_trained_model(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
393
  if not pth_path:
394
- raise RuntimeError("Training completed but model file not found.")
395
 
 
396
  local_model_dir = os.path.join(LOCAL_MODELS_DIR, model_name)
397
  os.makedirs(local_model_dir, exist_ok=True)
398
 
@@ -405,7 +455,7 @@ def full_training_pipeline(
405
  shutil.copy2(index_path, local_index)
406
 
407
  if progress_callback:
408
- progress_callback(0.90, "Uploading model...")
409
 
410
  try:
411
  upload_model(model_name, local_pth, local_index)
@@ -413,6 +463,6 @@ def full_training_pipeline(
413
  logger.warning(f"Failed to upload to HF (non-critical): {e}")
414
 
415
  if progress_callback:
416
- progress_callback(1.0, "Training complete!")
417
 
418
  return local_pth, local_index
 
312
 
313
 
314
  def build_index(model_name: str):
315
+ """Build FAISS index from extracted embeddings."""
316
+ import numpy as np
317
 
318
+ try:
319
+ import faiss
320
+ except ImportError:
321
+ logger.warning("faiss not available, skipping index building.")
322
+ return None
323
 
324
+ exp_dir = os.path.join(LOGS_DIR, model_name)
325
+ extracted_dir = os.path.join(exp_dir, "extracted")
326
 
327
+ if not os.path.exists(extracted_dir):
328
+ logger.warning("No extracted features found for index building.")
329
+ return None
330
 
331
+ # Load all embeddings
332
+ embeddings = []
333
+ for npy_file in sorted(glob.glob(os.path.join(extracted_dir, "*.npy"))):
334
+ try:
335
+ emb = np.load(npy_file)
336
+ if emb.ndim == 2:
337
+ embeddings.append(emb)
338
+ except Exception as e:
339
+ logger.warning(f"Failed to load {npy_file}: {e}")
340
+
341
+ if not embeddings:
342
+ logger.warning("No valid embeddings found for index.")
343
  return None
344
 
345
+ all_emb = np.concatenate(embeddings, axis=0).astype(np.float32)
346
+ logger.info(f"Building FAISS index from {all_emb.shape[0]} vectors ({all_emb.shape[1]}D)...")
347
+
348
+ # Build IVF index for fast retrieval
349
+ dim = all_emb.shape[1]
350
+ n_vectors = all_emb.shape[0]
351
+
352
+ if n_vectors < 40:
353
+ # Too few vectors for IVF, use flat index
354
+ index = faiss.IndexFlatL2(dim)
355
+ else:
356
+ n_clusters = min(int(np.sqrt(n_vectors)), n_vectors // 4)
357
+ n_clusters = max(n_clusters, 1)
358
+ quantizer = faiss.IndexFlatL2(dim)
359
+ index = faiss.IndexIVFFlat(quantizer, dim, n_clusters)
360
+ index.train(all_emb)
361
+
362
+ index.add(all_emb)
363
+
364
  index_path = os.path.join(exp_dir, f"{model_name}.index")
365
+ faiss.write_index(index, index_path)
366
+ logger.info(f"FAISS index built: {index_path} ({n_vectors} vectors)")
367
+ return index_path
 
368
 
369
 
370
  def find_trained_model(model_name: str):
 
380
  if f.endswith(".pth") and f.startswith(model_name):
381
  return os.path.join(exp_dir, f)
382
 
383
+ return None
384
+
 
 
385
 
386
+ def find_pretrained_model(sample_rate: int = 40000):
387
+ """Find the pre-trained RVC generator model."""
388
+ sr_prefix = str(sample_rate)[:2]
389
+ pg = os.path.join(APPLIO_DIR, "rvc", "models", "pretraineds", "hifi-gan", f"f0G{sr_prefix}k.pth")
390
+ if os.path.exists(pg):
391
+ return pg
392
  return None
393
 
394
 
395
  def full_training_pipeline(
396
  audio_path: str,
397
  model_name: str,
398
+ epochs: int = 10,
399
  sample_rate: int = 40000,
400
+ batch_size: int = 4,
401
  progress_callback=None,
402
  ):
403
  """
404
+ Run the voice model creation pipeline.
405
+ On CPU: skips heavy HiFi-GAN training, uses pre-trained model + FAISS index.
406
  Returns (pth_path, index_path) on success.
407
  """
408
+ import torch
409
  from pipeline.storage import upload_model, LOCAL_MODELS_DIR
410
 
411
+ has_gpu = torch.cuda.is_available()
412
+
413
  if progress_callback:
414
+ progress_callback(0.05, "Découpage de l'audio...")
415
 
416
  n_slices = preprocess(model_name, audio_path, sample_rate)
417
 
418
  if progress_callback:
419
+ progress_callback(0.20, f"{n_slices} segments créés. Extraction des caractéristiques vocales...")
420
 
421
  extract_features(model_name, sample_rate)
422
 
423
  if progress_callback:
424
+ progress_callback(0.60, "Caractéristiques extraites. Construction de l'index vocal...")
 
 
 
 
 
425
 
426
+ # Build FAISS index (fast, CPU-friendly)
427
  index_path = build_index(model_name)
428
 
429
+ if has_gpu:
430
+ # With GPU: do full training for best quality
431
+ if progress_callback:
432
+ progress_callback(0.65, "GPU détecté. Entraînement du modèle...")
433
+ train_model(model_name, sample_rate, epochs, batch_size)
434
+ pth_path = find_trained_model(model_name)
435
+ else:
436
+ # CPU only: use pre-trained model (skip hours-long training)
437
+ if progress_callback:
438
+ progress_callback(0.75, "Mode CPU : utilisation du modèle pré-entraîné...")
439
+ logger.info("CPU mode: skipping HiFi-GAN training, using pre-trained model.")
440
+ pth_path = find_pretrained_model(sample_rate)
441
+
442
  if not pth_path:
443
+ raise RuntimeError("Aucun modèle trouvé. Vérifiez que les modèles pré-entraînés sont téléchargés.")
444
 
445
+ # Save to local models directory
446
  local_model_dir = os.path.join(LOCAL_MODELS_DIR, model_name)
447
  os.makedirs(local_model_dir, exist_ok=True)
448
 
 
455
  shutil.copy2(index_path, local_index)
456
 
457
  if progress_callback:
458
+ progress_callback(0.90, "Sauvegarde du modèle...")
459
 
460
  try:
461
  upload_model(model_name, local_pth, local_index)
 
463
  logger.warning(f"Failed to upload to HF (non-critical): {e}")
464
 
465
  if progress_callback:
466
+ progress_callback(1.0, "Modèle vocal créé !")
467
 
468
  return local_pth, local_index