Upload Conv-TasNet best checkpoint
Browse files- README.md +95 -0
- best.ckpt +3 -0
- configs/data.yaml +10 -0
- configs/train.yaml +35 -0
- requirements.txt +5 -0
- src/model.py +114 -0
- src/separate.py +141 -0
- training_metadata.json +9 -0
README.md
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- audio
|
| 5 |
+
- speech
|
| 6 |
+
- source-separation
|
| 7 |
+
- conv-tasnet
|
| 8 |
+
- asteroid
|
| 9 |
+
- pytorch
|
| 10 |
+
library_name: pytorch
|
| 11 |
+
pipeline_tag: audio-to-audio
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# Cocktail Party AI - Conv-TasNet 3-Source Separator
|
| 15 |
+
|
| 16 |
+
This repository contains the best checkpoint from a Conv-TasNet model trained for speech source separation.
|
| 17 |
+
The model takes a mixed speech waveform and estimates 3 separated source waveforms.
|
| 18 |
+
|
| 19 |
+
## Checkpoint
|
| 20 |
+
|
| 21 |
+
- File: `best.ckpt`
|
| 22 |
+
- Architecture: Asteroid `ConvTasNet`
|
| 23 |
+
- Number of sources: 3
|
| 24 |
+
- Sample rate: 16 kHz
|
| 25 |
+
- Training checkpoint epoch: 68
|
| 26 |
+
- Best validation loss: -2.909952
|
| 27 |
+
- Approximate validation SI-SNR: 2.91 dB
|
| 28 |
+
|
| 29 |
+
## Files
|
| 30 |
+
|
| 31 |
+
```text
|
| 32 |
+
best.ckpt
|
| 33 |
+
configs/data.yaml
|
| 34 |
+
configs/train.yaml
|
| 35 |
+
requirements.txt
|
| 36 |
+
src/model.py
|
| 37 |
+
src/separate.py
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
## Usage
|
| 41 |
+
|
| 42 |
+
Install dependencies:
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
pip install -r requirements.txt
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
Load the checkpoint with the project code:
|
| 49 |
+
|
| 50 |
+
```python
|
| 51 |
+
import yaml
|
| 52 |
+
import torch
|
| 53 |
+
|
| 54 |
+
from src.model import build_model, load_checkpoint
|
| 55 |
+
|
| 56 |
+
with open("configs/train.yaml") as f:
|
| 57 |
+
train_cfg = yaml.safe_load(f)
|
| 58 |
+
with open("configs/data.yaml") as f:
|
| 59 |
+
data_cfg = yaml.safe_load(f)
|
| 60 |
+
|
| 61 |
+
mod = train_cfg["model"]
|
| 62 |
+
ds = data_cfg["dataset"]
|
| 63 |
+
|
| 64 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 65 |
+
|
| 66 |
+
model = build_model(
|
| 67 |
+
n_src=ds["n_src"],
|
| 68 |
+
sample_rate=ds["sample_rate"],
|
| 69 |
+
n_filters=mod["n_filters"],
|
| 70 |
+
filter_length=mod["filter_length"],
|
| 71 |
+
stride=mod["stride"],
|
| 72 |
+
n_blocks=mod["n_blocks"],
|
| 73 |
+
n_repeats=mod["n_repeats"],
|
| 74 |
+
bn_chan=mod["bn_chan"],
|
| 75 |
+
hid_chan=mod["hid_chan"],
|
| 76 |
+
skip_chan=mod["skip_chan"],
|
| 77 |
+
norm_type=mod["norm_type"],
|
| 78 |
+
mask_act=mod["mask_act"],
|
| 79 |
+
use_gradient_checkpointing=False,
|
| 80 |
+
).to(device)
|
| 81 |
+
|
| 82 |
+
load_checkpoint(model, "best.ckpt", device)
|
| 83 |
+
model.eval()
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
To separate a WAV file using this project:
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
python src/separate.py --mix path/to/mixture.wav --ckpt best.ckpt
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
## Notes
|
| 93 |
+
|
| 94 |
+
This is a research/training checkpoint, not a fully packaged `transformers` pipeline.
|
| 95 |
+
It depends on PyTorch, Torchaudio, and Asteroid.
|
best.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7ca8736f8da084c480bd765a99cb0fb4f2f1b2a9c94437cf345896c75bab8ee4
|
| 3 |
+
size 23479878
|
configs/data.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
root_dir: data/mixtures # chemin vers le dossier des mélanges
|
| 3 |
+
n_src: 3 # nombre de sources (source_1 à source_n)
|
| 4 |
+
sample_rate: 16000 # fréquence d'échantillonnage en Hz
|
| 5 |
+
segment_duration: 5.0 # durée d'un segment en secondes (0 = fichier entier)
|
| 6 |
+
|
| 7 |
+
splits:
|
| 8 |
+
train_ratio: 0.8 # 80% train
|
| 9 |
+
val_ratio: 0.1 # 10% validation
|
| 10 |
+
test_ratio: 0.1 # calculé automatiquement
|
configs/train.yaml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
training:
|
| 2 |
+
epochs: 200
|
| 3 |
+
early_stopping: 20
|
| 4 |
+
batch_size: 48
|
| 5 |
+
accumulation_steps: 2 # effective batch = 48 × 2 = 96
|
| 6 |
+
learning_rate: 0.00005
|
| 7 |
+
grad_clip: 30.0 # ← réduit : 1.0 provoquait des instabilités
|
| 8 |
+
save_every: 10
|
| 9 |
+
num_workers: 4
|
| 10 |
+
seed: 42
|
| 11 |
+
|
| 12 |
+
scheduler:
|
| 13 |
+
name: warmup_cosine
|
| 14 |
+
warmup_epochs: 5 # ← augmenté : plus de warmup avec peu de données
|
| 15 |
+
min_lr: 0.0000001
|
| 16 |
+
optimizer:
|
| 17 |
+
name: adam
|
| 18 |
+
weight_decay: 0.001
|
| 19 |
+
|
| 20 |
+
model:
|
| 21 |
+
n_filters: 256
|
| 22 |
+
filter_length: 16
|
| 23 |
+
stride: 8
|
| 24 |
+
n_blocks: 6
|
| 25 |
+
n_repeats: 3
|
| 26 |
+
bn_chan: 128
|
| 27 |
+
hid_chan: 256
|
| 28 |
+
skip_chan: 128
|
| 29 |
+
norm_type: gLN
|
| 30 |
+
mask_act: relu
|
| 31 |
+
gradient_checkpointing: true
|
| 32 |
+
|
| 33 |
+
paths:
|
| 34 |
+
checkpoint_dir: checkpoints_3src_2
|
| 35 |
+
log_dir: logs_3src_2
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
torchaudio>=2.0.0
|
| 3 |
+
asteroid>=0.6.0
|
| 4 |
+
numpy>=1.24.0
|
| 5 |
+
matplotlib>=3.10.8
|
src/model.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.utils.checkpoint import checkpoint
|
| 4 |
+
|
| 5 |
+
from asteroid.models import ConvTasNet
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def build_model(n_src=5, sample_rate=8000,
                n_filters=512, filter_length=16,
                stride=8, n_blocks=8, n_repeats=3,
                bn_chan=128, hid_chan=512, skip_chan=128,
                norm_type="gLN", mask_act="relu",
                use_gradient_checkpointing=False):
    """Build an Asteroid Conv-TasNet separator.

    Optionally wraps the masker's TCN blocks in activation checkpointing
    (memory for compute trade-off), then reports the trainable parameter
    count and returns the model.
    """
    # NOTE(review): `filter_length` is forwarded verbatim to ConvTasNet;
    # recent Asteroid releases name this filterbank parameter `kernel_size`
    # — confirm the pinned asteroid version accepts `filter_length`.
    arch = dict(
        n_src=n_src, sample_rate=sample_rate,
        n_filters=n_filters, filter_length=filter_length,
        stride=stride, n_blocks=n_blocks, n_repeats=n_repeats,
        bn_chan=bn_chan, hid_chan=hid_chan, skip_chan=skip_chan,
        norm_type=norm_type, mask_act=mask_act,
    )
    model = ConvTasNet(**arch)

    if use_gradient_checkpointing:
        _apply_gradient_checkpointing(model)
        print("[Model] Gradient checkpointing : ACTIVÉ (-50% VRAM, +30% temps)")
    else:
        print("[Model] Gradient checkpointing : désactivé")

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"[Model] Conv-TasNet | Paramètres entraînables : {trainable:,}")
    return model
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _apply_gradient_checkpointing(model):
    """Replace every direct child of ``model.masker.TCN`` with a checkpointed wrapper."""
    if not (hasattr(model, "masker") and hasattr(model.masker, "TCN")):
        # Architecture does not expose the expected masker.TCN container.
        print("[Warning] masker.TCN introuvable — gradient checkpointing non appliqué.")
        return

    tcn = model.masker.TCN
    children = list(tcn.named_children())
    if not children:
        return

    for child_name, child in children:
        _wrap_block(tcn, child_name, child)

    print(f"[Model] {len(children)} blocs TCN checkpointés.")
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _wrap_block(parent, name, block):
|
| 50 |
+
class CheckpointedBlock(nn.Module):
|
| 51 |
+
def __init__(self, inner):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.inner = inner
|
| 54 |
+
|
| 55 |
+
def forward(self, x):
|
| 56 |
+
if not x.requires_grad:
|
| 57 |
+
x = x.requires_grad_(True)
|
| 58 |
+
return checkpoint(self.inner, x, use_reentrant=False)
|
| 59 |
+
|
| 60 |
+
setattr(parent, name, CheckpointedBlock(block))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def load_checkpoint(model, path, device="cpu"):
    """Load trained weights from *path* into *model*.

    Automatically handles the ``.inner.`` state-dict key shift introduced by
    the gradient-checkpointing wrapper (``CheckpointedBlock``), in both
    directions:

    * checkpoint saved with GC ON, model built with GC OFF → strip ``.inner.``
    * checkpoint saved with GC OFF, model built with GC ON → insert ``inner``

    Parameters
    ----------
    model : torch.nn.Module
        Freshly built model whose architecture matches the checkpoint.
    path : str
        Path to the ``.ckpt`` file — either a raw state dict or a training
        checkpoint dict containing ``model_state_dict``.
    device : str or torch.device
        Map location for the loaded tensors.

    Returns
    -------
    torch.nn.Module
        The same *model* instance, with weights loaded.

    Raises
    ------
    RuntimeError
        If the key mismatch cannot be resolved by the ``.inner.`` rewrites.
    """
    # weights_only=False pins pre-torch-2.6 semantics: this is a trusted
    # local training artifact that may embed non-tensor objects
    # (optimizer/scaler state), which weights_only=True could reject.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    state = ckpt.get("model_state_dict", ckpt)

    # Lenient first attempt; the missing/unexpected lists tell us which
    # direction (if any) the keys need to be rewritten.
    missing, unexpected = model.load_state_dict(state, strict=False)
    missing_set = set(missing)
    unexpected_set = set(unexpected)

    # Case 1: checkpoint has '.inner.' but model doesn't → strip '.inner.'
    if any(".inner." in k for k in unexpected_set) and \
       any(".inner." not in k for k in missing_set):
        state = {k.replace(".inner.", "."): v for k, v in state.items()}
        model.load_state_dict(state, strict=True)
        print("[Model] '.inner.' stripped from checkpoint keys (GC ON → OFF)")

    # Case 2: model has '.inner.' but checkpoint doesn't → add '.inner.'
    # after the TCN child name: masker.TCN.<child>.rest → masker.TCN.<child>.inner.rest
    elif any(".inner." in k for k in missing_set) and \
         any(".inner." not in k for k in unexpected_set):
        new_state = {}
        for k, v in state.items():
            if "masker.TCN." in k and ".inner." not in k:
                parts = k.split(".")
                parts.insert(3, "inner")
                k = ".".join(parts)
            new_state[k] = v
        model.load_state_dict(new_state, strict=True)
        print("[Model] '.inner.' added to checkpoint keys (GC OFF → ON)")

    # Case 3: loaded fine on the first, lenient attempt.
    elif not missing and not unexpected:
        print("[Model] Checkpoint chargé sans modification")

    else:
        raise RuntimeError(
            f"Cannot load checkpoint — unresolvable key mismatch:\n"
            f" Missing : {list(missing)[:3]}...\n"
            f" Unexpected: {list(unexpected)[:3]}..."
        )

    # Report training metadata when the checkpoint dict carries it.
    epoch = ckpt.get("epoch", "?")
    val_loss = ckpt.get("best_val_loss", "?")
    print(f"[Model] Checkpoint chargé depuis {path} "
          f"(epoch {epoch}, val loss {val_loss})")
    return model
src/separate.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
separate.py — Séparation de sources avec le modèle entraîné
|
| 3 |
+
Usage :
|
| 4 |
+
python main.py separate --mix data/mixture/mix_0/mixture.wav
|
| 5 |
+
python main.py separate --mix data/mixture/mix_0/mixture.wav --ckpt checkpoints/best.ckpt
|
| 6 |
+
python main.py separate --mix mon_audio.wav --out_dir outputs/separated_audio
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import argparse
|
| 11 |
+
import torch
|
| 12 |
+
import torchaudio
|
| 13 |
+
|
| 14 |
+
from src.model import build_model, load_checkpoint
|
| 15 |
+
import yaml
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_config(path):
    """Read the YAML file at *path* and return its parsed contents."""
    with open(path, "r") as handle:
        parsed = yaml.safe_load(handle)
    return parsed
| 22 |
+
|
| 23 |
+
def parse_args():
    """Define and parse the CLI arguments for the separation script."""
    parser = argparse.ArgumentParser(description="Séparation de sources Conv-TasNet")
    parser.add_argument("--mix", required=True, type=str,
                        help="Chemin vers le fichier mixture.wav à séparer")
    parser.add_argument("--ckpt", default="checkpoints/best.ckpt", type=str,
                        help="Checkpoint du modèle entraîné")
    parser.add_argument("--out_dir", default="outputs/separated_audio", type=str,
                        help="Dossier de sortie pour les sources séparées")
    parser.add_argument("--train_cfg", default="configs/train.yaml", type=str)
    parser.add_argument("--data_cfg", default="configs/data.yaml", type=str)
    return parser.parse_args()
| 35 |
+
|
| 36 |
+
def separate(mix_path, model, sample_rate, device, out_dir):
    """Load a mixture wav, run the separator, and save one wav per source.

    Returns the estimated sources as an (n_src, T) tensor on *device*
    (un-normalized; only the saved files are peak-normalized).
    """
    # ── Load the audio file ──────────────────
    waveform, file_sr = torchaudio.load(mix_path)

    # Resample to the model's training rate if the file differs.
    if file_sr != sample_rate:
        print(f" Resample {file_sr} Hz → {sample_rate} Hz")
        waveform = torchaudio.functional.resample(waveform, file_sr, sample_rate)

    # Downmix multi-channel audio to mono (1, T).
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    n_samples = waveform.shape[-1]
    print(f" Durée : {n_samples / sample_rate:.2f}s "
          f"({n_samples} samples)")

    # ── Inference ────────────────────────────
    waveform = waveform.to(device)  # (1, T)
    with torch.no_grad():
        # Add a batch dim; output comes back as (1, n_src, T).
        estimates = model(waveform.unsqueeze(0)).squeeze(0)  # (n_src, T)

    # ── Save the separated sources ───────────
    os.makedirs(out_dir, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(mix_path))[0]

    for src_idx, source in enumerate(estimates, start=1):
        track = source.unsqueeze(0).cpu()  # (1, T)

        # Peak-normalize to 0.9 to avoid clipping in the saved file.
        peak = track.abs().max()
        if peak > 0:
            track = track / peak * 0.9

        out_path = os.path.join(out_dir, f"{base_name}_source_{src_idx}.wav")
        torchaudio.save(out_path, track, sample_rate)
        print(f" ✓ Source {src_idx} sauvegardée : {out_path}")

    return estimates
| 78 |
+
|
| 79 |
+
def main():
    """CLI entry point: build the model, load trained weights, separate one mixture.

    Reads the model hyperparameters from the train config and dataset
    parameters (n_src, sample_rate) from the data config, then writes the
    separated sources to ``--out_dir``.
    """
    args = parse_args()
    tcfg = load_config(args.train_cfg)
    dcfg = load_config(args.data_cfg)

    mod = tcfg["model"]
    ds = dcfg["dataset"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n[Config] Device : {device}")
    print(f"[Config] Checkpoint : {args.ckpt}")
    print(f"[Config] Fichier mix : {args.mix}\n")

    # ── Build the model ──────────────────────
    model = build_model(
        n_src=ds["n_src"],
        sample_rate=ds["sample_rate"],
        n_filters=mod["n_filters"],
        filter_length=mod["filter_length"],
        stride=mod["stride"],
        n_blocks=mod["n_blocks"],
        n_repeats=mod["n_repeats"],
        bn_chan=mod["bn_chan"],
        hid_chan=mod["hid_chan"],
        skip_chan=mod["skip_chan"],
        norm_type=mod["norm_type"],
        mask_act=mod["mask_act"],
        use_gradient_checkpointing=False,  # not needed for inference
    )

    # ── Load the trained weights ─────────────
    if not os.path.exists(args.ckpt):
        raise FileNotFoundError(
            f"Checkpoint introuvable : {args.ckpt}\n"
            f"Lancez d'abord : python main.py train"
        )

    # load_checkpoint already reports epoch / best val loss, so we no longer
    # re-read the full checkpoint (weights + optimizer + scaler state) from
    # disk a second time just to print the same metadata.
    load_checkpoint(model, args.ckpt, device)
    model.to(device)
    model.eval()

    # ── Separation ───────────────────────────
    separate(
        mix_path=args.mix,
        model=model,
        sample_rate=ds["sample_rate"],
        device=device,
        out_dir=args.out_dir,
    )

    print(f"\n[Done] Sources séparées dans : {args.out_dir}/")
| 139 |
+
|
| 140 |
+
# Script entry point (see module docstring for example invocations).
if __name__ == "__main__":
    main()
|
training_metadata.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"checkpoint_file": "best.ckpt",
|
| 3 |
+
"epoch": 68,
|
| 4 |
+
"best_val_loss": -2.9099522034327188,
|
| 5 |
+
"validation_si_snr_db": 2.9099522034327188,
|
| 6 |
+
"contains_optimizer_state": true,
|
| 7 |
+
"contains_scaler_state": true,
|
| 8 |
+
"model_state_tensors": 261
|
| 9 |
+
}
|