Rahma89 committed
Commit 27441d2 · verified · 1 parent: 11742d1

Upload Conv-TasNet best checkpoint

Files changed (8)
  1. README.md +95 -0
  2. best.ckpt +3 -0
  3. configs/data.yaml +10 -0
  4. configs/train.yaml +35 -0
  5. requirements.txt +5 -0
  6. src/model.py +114 -0
  7. src/separate.py +141 -0
  8. training_metadata.json +9 -0
README.md ADDED
@@ -0,0 +1,95 @@
---
license: apache-2.0
tags:
- audio
- speech
- source-separation
- conv-tasnet
- asteroid
- pytorch
library_name: pytorch
pipeline_tag: audio-to-audio
---

# Cocktail Party AI - Conv-TasNet 3-Source Separator

This repository contains the best checkpoint from a Conv-TasNet model trained for speech source separation.
The model takes a mixed speech waveform and estimates 3 separated source waveforms.

## Checkpoint

- File: `best.ckpt`
- Architecture: Asteroid `ConvTasNet`
- Number of sources: 3
- Sample rate: 16 kHz
- Training checkpoint epoch: 68
- Best validation loss: -2.909952
- Approximate validation SI-SNR: 2.91 dB

## Files

```text
best.ckpt
configs/data.yaml
configs/train.yaml
requirements.txt
src/model.py
src/separate.py
```

## Usage

Install dependencies:

```bash
pip install -r requirements.txt
```

Load the checkpoint with the project code:

```python
import yaml
import torch

from src.model import build_model, load_checkpoint

with open("configs/train.yaml") as f:
    train_cfg = yaml.safe_load(f)
with open("configs/data.yaml") as f:
    data_cfg = yaml.safe_load(f)

mod = train_cfg["model"]
ds = data_cfg["dataset"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = build_model(
    n_src=ds["n_src"],
    sample_rate=ds["sample_rate"],
    n_filters=mod["n_filters"],
    filter_length=mod["filter_length"],
    stride=mod["stride"],
    n_blocks=mod["n_blocks"],
    n_repeats=mod["n_repeats"],
    bn_chan=mod["bn_chan"],
    hid_chan=mod["hid_chan"],
    skip_chan=mod["skip_chan"],
    norm_type=mod["norm_type"],
    mask_act=mod["mask_act"],
    use_gradient_checkpointing=False,
).to(device)

load_checkpoint(model, "best.ckpt", device)
model.eval()
```
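
Once loaded, inference mirrors what `src/separate.py` does. A minimal sketch, continuing from the snippet above (`mixture.wav` is a placeholder path; stereo input is downmixed to mono as in the project script):

```python
import torchaudio

# Load the mixture and match the model's 16 kHz training rate.
mixture, sr = torchaudio.load("mixture.wav")  # placeholder path
if sr != ds["sample_rate"]:
    mixture = torchaudio.functional.resample(mixture, sr, ds["sample_rate"])
if mixture.shape[0] > 1:
    mixture = mixture.mean(dim=0, keepdim=True)  # downmix to mono (1, T)

with torch.no_grad():
    # Add a batch dimension; the output is (n_src, T) after squeezing.
    est_sources = model(mixture.unsqueeze(0).to(device)).squeeze(0)
```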
85
+
86
+ To separate a WAV file using this project:
87
+
88
+ ```bash
89
+ python src/separate.py --mix path/to/mixture.wav --ckpt best.ckpt
90
+ ```

## Notes

This is a research/training checkpoint, not a fully packaged `transformers` pipeline.
It depends on PyTorch, Torchaudio, and Asteroid.
best.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7ca8736f8da084c480bd765a99cb0fb4f2f1b2a9c94437cf345896c75bab8ee4
size 23479878
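
This is a Git LFS pointer, so a plain `git clone` without LFS installed fetches only these three lines, not the 23 MB checkpoint. Run `git lfs pull` after cloning, or download the file programmatically; a sketch using `huggingface_hub`, where the repo id is a placeholder for wherever this model is hosted:

```python
from huggingface_hub import hf_hub_download

# Placeholder repo id; substitute the actual Hub repository name.
ckpt_path = hf_hub_download(repo_id="user/conv-tasnet-3src", filename="best.ckpt")
```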
configs/data.yaml ADDED
@@ -0,0 +1,10 @@
dataset:
  root_dir: data/mixtures       # path to the mixtures directory
  n_src: 3                      # number of sources (source_1 to source_n)
  sample_rate: 16000            # sampling rate in Hz
  segment_duration: 5.0         # segment length in seconds (0 = whole file)

splits:
  train_ratio: 0.8              # 80% train
  val_ratio: 0.1                # 10% validation
  test_ratio: 0.1               # computed automatically
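
The dataset loader itself is not part of this upload. For illustration only, a deterministic way these ratios are commonly applied (the function name and the reproducible shuffle with the training seed are assumptions, not project code):

```python
import random

def split_indices(n_items, train_ratio=0.8, val_ratio=0.1, seed=42):
    """Shuffle indices reproducibly, then slice into train/val/test."""
    idx = list(range(n_items))
    random.Random(seed).shuffle(idx)
    n_train = int(n_items * train_ratio)
    n_val = int(n_items * val_ratio)
    # The test split is whatever remains, matching "computed automatically".
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]
```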
configs/train.yaml ADDED
@@ -0,0 +1,35 @@
training:
  epochs: 200
  early_stopping: 20
  batch_size: 48
  accumulation_steps: 2         # effective batch = 48 × 2 = 96
  learning_rate: 0.00005
  grad_clip: 30.0               # loosened: clipping at 1.0 caused instabilities
  save_every: 10
  num_workers: 4
  seed: 42

scheduler:
  name: warmup_cosine
  warmup_epochs: 5              # increased: more warmup helps with little data
  min_lr: 0.0000001
optimizer:
  name: adam
  weight_decay: 0.001

model:
  n_filters: 256
  filter_length: 16
  stride: 8
  n_blocks: 6
  n_repeats: 3
  bn_chan: 128
  hid_chan: 256
  skip_chan: 128
  norm_type: gLN
  mask_act: relu
  gradient_checkpointing: true

paths:
  checkpoint_dir: checkpoints_3src_2
  log_dir: logs_3src_2
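
The `warmup_cosine` scheduler implementation is not included in this upload. A common formulation consistent with these settings, linear warmup to `learning_rate` followed by cosine decay to `min_lr`, is sketched below (per-epoch; an assumption about the schedule's shape, not the project's code):

```python
import math

def warmup_cosine_lr(epoch, base_lr=5e-5, min_lr=1e-7,
                     warmup_epochs=5, total_epochs=200):
    """Linear warmup for the first epochs, then cosine decay to min_lr."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```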
requirements.txt ADDED
@@ -0,0 +1,5 @@
torch>=2.0.0
torchaudio>=2.0.0
asteroid>=0.6.0
numpy>=1.24.0
matplotlib>=3.10.8
src/model.py ADDED
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from asteroid.models import ConvTasNet


def build_model(n_src=5, sample_rate=8000,
                n_filters=512, filter_length=16,
                stride=8, n_blocks=8, n_repeats=3,
                bn_chan=128, hid_chan=512, skip_chan=128,
                norm_type="gLN", mask_act="relu",
                use_gradient_checkpointing=False):

    model = ConvTasNet(
        n_src=n_src, sample_rate=sample_rate,
        # Asteroid's ConvTasNet takes the encoder filter length as `kernel_size`.
        n_filters=n_filters, kernel_size=filter_length,
        stride=stride, n_blocks=n_blocks, n_repeats=n_repeats,
        bn_chan=bn_chan, hid_chan=hid_chan, skip_chan=skip_chan,
        norm_type=norm_type, mask_act=mask_act,
    )

    if use_gradient_checkpointing:
        _apply_gradient_checkpointing(model)
        print("[Model] Gradient checkpointing: ENABLED (-50% VRAM, +30% time)")
    else:
        print("[Model] Gradient checkpointing: disabled")

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"[Model] Conv-TasNet | Trainable parameters: {n_params:,}")
    return model


def _apply_gradient_checkpointing(model):
    if not hasattr(model, "masker") or not hasattr(model.masker, "TCN"):
        print("[Warning] masker.TCN not found; gradient checkpointing not applied.")
        return

    original_blocks = list(model.masker.TCN.named_children())
    if not original_blocks:
        return

    for name, block in original_blocks:
        _wrap_block(model.masker.TCN, name, block)

    print(f"[Model] {len(original_blocks)} TCN blocks checkpointed.")


def _wrap_block(parent, name, block):
    class CheckpointedBlock(nn.Module):
        def __init__(self, inner):
            super().__init__()
            self.inner = inner

        def forward(self, x):
            if not x.requires_grad:
                x = x.requires_grad_(True)
            return checkpoint(self.inner, x, use_reentrant=False)

    setattr(parent, name, CheckpointedBlock(block))


def load_checkpoint(model, path, device="cpu"):
    """
    Load a checkpoint safely.
    Automatically handles the '.inner.' key mismatch caused by the
    gradient checkpointing wrapper (CheckpointedBlock).
    """
    ckpt = torch.load(path, map_location=device)
    state = ckpt.get("model_state_dict", ckpt)

    # Try loading as-is first
    missing, unexpected = model.load_state_dict(state, strict=False)
    missing_set = set(missing)
    unexpected_set = set(unexpected)

    # Case 1: checkpoint has .inner. but model doesn't → strip .inner.
    if any(".inner." in k for k in unexpected_set) and \
       any(".inner." not in k for k in missing_set):
        state = {k.replace(".inner.", "."): v for k, v in state.items()}
        model.load_state_dict(state, strict=True)
        print("[Model] '.inner.' stripped from checkpoint keys (GC ON → OFF)")

    # Case 2: model has .inner. but checkpoint doesn't → add .inner.
    elif any(".inner." in k for k in missing_set) and \
         any(".inner." not in k for k in unexpected_set):
        new_state = {}
        for k, v in state.items():
            if "masker.TCN." in k and ".inner." not in k:
                parts = k.split(".")
                parts.insert(3, "inner")
                k = ".".join(parts)
            new_state[k] = v
        model.load_state_dict(new_state, strict=True)
        print("[Model] '.inner.' added to checkpoint keys (GC OFF → ON)")

    # Case 3: loaded fine on the first try
    elif len(missing) == 0 and len(unexpected) == 0:
        print("[Model] Checkpoint loaded without key rewriting")

    else:
        raise RuntimeError(
            f"Cannot load checkpoint, unresolvable key mismatch:\n"
            f"  Missing   : {list(missing)[:3]}...\n"
            f"  Unexpected: {list(unexpected)[:3]}..."
        )

    epoch = ckpt.get("epoch", "?")
    val_loss = ckpt.get("best_val_loss", "?")
    print(f"[Model] Checkpoint loaded from {path} "
          f"(epoch {epoch}, val loss {val_loss})")
    return model
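
To make the `.inner.` remapping in `load_checkpoint` concrete: wrapping a TCN block in `CheckpointedBlock` moves its parameters one level down, so Case 2 rewrites the plain keys. A toy illustration (the parameter name is an example, not taken from an actual checkpoint):

```python
# A checkpoint saved without gradient checkpointing stores keys like:
plain_key = "masker.TCN.0.res_conv.weight"

# After wrapping, the same tensor lives under an extra ".inner" segment,
# which is exactly what parts.insert(3, "inner") produces:
parts = plain_key.split(".")
parts.insert(3, "inner")
print(".".join(parts))  # masker.TCN.0.inner.res_conv.weight
```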
src/separate.py ADDED
@@ -0,0 +1,141 @@
"""
separate.py - Source separation with the trained model.
Usage:
    python src/separate.py --mix data/mixture/mix_0/mixture.wav
    python src/separate.py --mix data/mixture/mix_0/mixture.wav --ckpt checkpoints/best.ckpt
    python src/separate.py --mix my_audio.wav --out_dir outputs/separated_audio
"""

import os
import argparse

import torch
import torchaudio
import yaml

from src.model import build_model, load_checkpoint


def load_config(path):
    with open(path, "r") as f:
        return yaml.safe_load(f)


def parse_args():
    p = argparse.ArgumentParser(description="Conv-TasNet source separation")
    p.add_argument("--mix", type=str, required=True,
                   help="Path to the mixture.wav file to separate")
    p.add_argument("--ckpt", type=str, default="checkpoints/best.ckpt",
                   help="Trained model checkpoint")
    p.add_argument("--out_dir", type=str, default="outputs/separated_audio",
                   help="Output directory for the separated sources")
    p.add_argument("--train_cfg", type=str, default="configs/train.yaml")
    p.add_argument("--data_cfg", type=str, default="configs/data.yaml")
    return p.parse_args()


def separate(mix_path, model, sample_rate, device, out_dir):
    """Load a mixture.wav, separate the sources, and save them as .wav files."""

    # ── Load the audio file ──────────────────
    mixture, sr = torchaudio.load(mix_path)

    if sr != sample_rate:
        print(f"  Resample {sr} Hz → {sample_rate} Hz")
        mixture = torchaudio.functional.resample(mixture, sr, sample_rate)

    # Downmix to mono (1, T)
    if mixture.shape[0] > 1:
        mixture = mixture.mean(dim=0, keepdim=True)

    print(f"  Duration: {mixture.shape[-1] / sample_rate:.2f}s "
          f"({mixture.shape[-1]} samples)")

    # ── Inference ────────────────────────────
    mixture = mixture.to(device)  # (1, T)
    with torch.no_grad():
        # Add a batch dimension: (1, T) → (1, 1, T); Asteroid accepts (batch, 1, time)
        est_sources = model(mixture.unsqueeze(0))  # (1, n_src, T)
        est_sources = est_sources.squeeze(0)       # (n_src, T)

    # ── Save the separated sources ───────────
    os.makedirs(out_dir, exist_ok=True)
    mix_name = os.path.splitext(os.path.basename(mix_path))[0]

    for i, src in enumerate(est_sources):
        src_cpu = src.unsqueeze(0).cpu()  # (1, T)

        # Peak-normalize to avoid clipping
        max_val = src_cpu.abs().max()
        if max_val > 0:
            src_cpu = src_cpu / max_val * 0.9

        out_path = os.path.join(out_dir, f"{mix_name}_source_{i+1}.wav")
        torchaudio.save(out_path, src_cpu, sample_rate)
        print(f"  ✓ Source {i+1} saved: {out_path}")

    return est_sources


def main():
    args = parse_args()
    tcfg = load_config(args.train_cfg)
    dcfg = load_config(args.data_cfg)

    mod = tcfg["model"]
    ds = dcfg["dataset"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n[Config] Device     : {device}")
    print(f"[Config] Checkpoint : {args.ckpt}")
    print(f"[Config] Mix file   : {args.mix}\n")

    # ── Build the model ──────────────────────
    model = build_model(
        n_src=ds["n_src"],
        sample_rate=ds["sample_rate"],
        n_filters=mod["n_filters"],
        filter_length=mod["filter_length"],
        stride=mod["stride"],
        n_blocks=mod["n_blocks"],
        n_repeats=mod["n_repeats"],
        bn_chan=mod["bn_chan"],
        hid_chan=mod["hid_chan"],
        skip_chan=mod["skip_chan"],
        norm_type=mod["norm_type"],
        mask_act=mod["mask_act"],
        use_gradient_checkpointing=False,  # not needed for inference
    )

    # ── Load the trained weights ─────────────
    if not os.path.exists(args.ckpt):
        raise FileNotFoundError(
            f"Checkpoint not found: {args.ckpt}\n"
            f"Run training first: python main.py train"
        )

    load_checkpoint(model, args.ckpt, device)
    model.to(device)
    model.eval()

    ckpt = torch.load(args.ckpt, map_location="cpu")
    epoch = ckpt.get("epoch", "?")
    val = ckpt.get("best_val_loss", None)
    if val is not None:
        print(f"[Model] Checkpoint loaded (epoch {epoch}, val loss {val:.4f})\n")
    else:
        print(f"[Model] Checkpoint loaded (epoch {epoch})\n")

    # ── Separation ───────────────────────────
    separate(
        mix_path=args.mix,
        model=model,
        sample_rate=ds["sample_rate"],
        device=device,
        out_dir=args.out_dir,
    )

    print(f"\n[Done] Separated sources in: {args.out_dir}/")


if __name__ == "__main__":
    main()
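
With the defaults above, `python src/separate.py --mix mixture.wav --ckpt best.ckpt` writes one peak-normalized file per source to `outputs/separated_audio/`, named `<mix>_source_1.wav` through `<mix>_source_3.wav`. Loading them back for a quick check:

```python
import torchaudio

for i in range(1, 4):  # n_src = 3
    wav, sr = torchaudio.load(f"outputs/separated_audio/mixture_source_{i}.wav")
    print(f"source {i}: {wav.shape[-1] / sr:.2f}s at {sr} Hz")
```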
training_metadata.json ADDED
@@ -0,0 +1,9 @@
{
  "checkpoint_file": "best.ckpt",
  "epoch": 68,
  "best_val_loss": -2.9099522034327188,
  "validation_si_snr_db": 2.9099522034327188,
  "contains_optimizer_state": true,
  "contains_scaler_state": true,
  "model_state_tensors": 261
}
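
Note that `best_val_loss` is the negative SI-SNR training objective, so the two metrics are the same value with opposite signs. Reading the metadata back:

```python
import json

with open("training_metadata.json") as f:
    meta = json.load(f)

# The SI-SNR in dB is the negated validation loss.
assert meta["validation_si_snr_db"] == -meta["best_val_loss"]
print(f"epoch {meta['epoch']}: {meta['validation_si_snr_db']:.2f} dB SI-SNR")
```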