techfreakworm committed on
Commit
ca78147
·
unverified ·
1 Parent(s): 46728df

feat(audio): wav validation, write helper, mono/16k normalization

Browse files
Files changed (2) hide show
  1. server/audio.py +69 -0
  2. tests/test_audio.py +76 -0
server/audio.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio I/O and validation utilities."""
2
+ from __future__ import annotations
3
+
4
+ import io
5
+ from dataclasses import dataclass
6
+
7
+ import numpy as np
8
+ import soundfile as sf
9
+
10
+
11
# Bounds enforced on reference clips by validate_reference_clip.
# Durations are in seconds, the sample rate is in Hz.
MIN_DURATION_S = 0.5
MAX_DURATION_S = 60.0
MIN_SAMPLE_RATE = 16_000
14
+
15
+
16
class AudioValidationError(ValueError):
    """Signal that a reference clip was rejected by validation.

    Derives from ValueError, so handlers written for ``ValueError`` also
    catch it.
    """
18
+
19
+
20
@dataclass(frozen=True)
class ClipInfo:
    """Immutable summary of an audio clip's basic properties."""

    # Clip length in seconds.
    duration_s: float
    # Sample rate in Hz.
    sample_rate: int
    # Number of audio channels.
    channels: int
25
+
26
+
27
def validate_reference_clip(wav_bytes: bytes) -> ClipInfo:
    """Read an encoded clip's header and enforce the reference-clip limits.

    Args:
        wav_bytes: Encoded audio file contents, opened through soundfile.

    Returns:
        A ClipInfo holding the clip's duration (seconds), sample rate and
        channel count.

    Raises:
        AudioValidationError: If the bytes cannot be opened as audio, the
            sample rate is below MIN_SAMPLE_RATE, or the duration falls
            outside [MIN_DURATION_S, MAX_DURATION_S].
    """
    try:
        with sf.SoundFile(io.BytesIO(wav_bytes)) as handle:
            sample_rate = handle.samplerate
            channels = handle.channels
            n_frames = handle.frames
    except Exception as exc:
        raise AudioValidationError(f"invalid audio format: {exc}") from exc

    # A zero rate would divide by zero here; the rate check below rejects it.
    duration_s = n_frames / float(sample_rate) if sample_rate else 0.0

    # Checks run in a fixed order: sample rate, then too short, then too long.
    if sample_rate < MIN_SAMPLE_RATE:
        problem = f"sample rate {sample_rate} below minimum {MIN_SAMPLE_RATE}"
    elif duration_s < MIN_DURATION_S:
        problem = f"clip too short ({duration_s:.2f}s)"
    elif duration_s > MAX_DURATION_S:
        problem = f"clip too long ({duration_s:.2f}s)"
    else:
        return ClipInfo(
            duration_s=duration_s, sample_rate=sample_rate, channels=channels
        )
    raise AudioValidationError(problem)
48
+
49
+
50
def write_wav_bytes(samples: np.ndarray, sample_rate: int) -> bytes:
    """Encode samples as a 16-bit PCM WAV file held in memory.

    Args:
        samples: Audio samples handed to ``soundfile.write`` unchanged.
        sample_rate: Sample rate recorded in the WAV header.

    Returns:
        The complete WAV file contents as bytes.
    """
    with io.BytesIO() as sink:
        sf.write(sink, samples, sample_rate, format="WAV", subtype="PCM_16")
        return sink.getvalue()
54
+
55
+
56
def normalize_to_mono_16k(
    samples: np.ndarray, original_sr: int, target_sr: int = 16000
) -> tuple[np.ndarray, int]:
    """Downmix to mono and naively resample to ``target_sr``.

    Resampling is plain linear interpolation with no anti-aliasing filter,
    so downsampling can alias content above the new Nyquist frequency.

    Args:
        samples: 1-D mono samples, or a 2-D array that is averaged over
            axis 1 (i.e. assumed to be laid out as (frames, channels) --
            confirm against callers).
        original_sr: Sample rate of ``samples`` in Hz; must be positive.
        target_sr: Desired output sample rate in Hz; must be positive.

    Returns:
        A ``(mono_samples, target_sr)`` tuple where ``mono_samples`` is a
        float32 array. When the rates already match, the samples are only
        downmixed and cast.

    Raises:
        ValueError: If either sample rate is not positive.
    """
    if original_sr <= 0 or target_sr <= 0:
        raise ValueError(
            f"sample rates must be positive "
            f"(original_sr={original_sr}, target_sr={target_sr})"
        )

    if samples.ndim == 2:
        samples = samples.mean(axis=1)
    if original_sr == target_sr:
        return samples.astype(np.float32), target_sr

    source_len = samples.shape[0]
    if source_len == 0:
        # np.interp raises on an empty set of sample points; an empty clip
        # simply resamples to an empty clip.
        return np.zeros(0, dtype=np.float32), target_sr

    duration = source_len / float(original_sr)
    target_len = int(round(duration * target_sr))
    # Both grids span the same normalized [0, 1) timeline, so interpolation
    # maps each output position onto the input signal.
    x_old = np.linspace(0.0, 1.0, source_len, endpoint=False)
    x_new = np.linspace(0.0, 1.0, target_len, endpoint=False)
    out = np.interp(x_new, x_old, samples).astype(np.float32)
    return out, target_sr
tests/test_audio.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import wave
3
+
4
+ import numpy as np
5
+ import pytest
6
+
7
+ from server.audio import (
8
+ AudioValidationError,
9
+ normalize_to_mono_16k,
10
+ validate_reference_clip,
11
+ write_wav_bytes,
12
+ )
13
+
14
+
15
def _make_wav_bytes(samples: np.ndarray, sample_rate: int, channels: int = 1) -> bytes:
    """Build an in-memory 16-bit PCM WAV file with the stdlib ``wave`` module.

    Samples are scaled by 32767, clipped to the int16 range and cast to
    int16. For ``channels > 1`` each sample is duplicated across every
    channel so the frames are interleaved copies of the same signal.
    """
    scaled = (samples * 32767).clip(-32768, 32767)
    frames = scaled.astype(np.int16)
    if channels > 1:
        # Column-repeat then flatten -> sample0 x channels, sample1 x channels, ...
        frames = np.repeat(frames[:, np.newaxis], channels, axis=1).flatten()

    sink = io.BytesIO()
    with wave.open(sink, "wb") as writer:
        writer.setnchannels(channels)
        writer.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        writer.setframerate(sample_rate)
        writer.writeframes(frames.tobytes())
    return sink.getvalue()
26
+
27
+
28
def test_write_wav_bytes_roundtrip():
    """write_wav_bytes output is a RIFF file that ``wave`` reads as mono 24 kHz."""
    tone = np.sin(np.linspace(0, 6.28, 24000)).astype(np.float32)
    encoded = write_wav_bytes(tone, sample_rate=24000)
    assert encoded[:4] == b"RIFF"
    with wave.open(io.BytesIO(encoded)) as reader:
        assert reader.getnchannels() == 1
        assert reader.getframerate() == 24000
35
+
36
+
37
def test_validate_accepts_valid_clip():
    """A 2 s, 24 kHz clip passes validation and reports its properties."""
    silence = np.zeros(48000, dtype=np.float32)  # 48000 frames at 24 kHz = 2 s
    clip = _make_wav_bytes(silence, 24000)
    result = validate_reference_clip(clip)
    assert result.duration_s == pytest.approx(2.0, rel=1e-3)
    assert result.sample_rate == 24000
43
+
44
+
45
def test_validate_rejects_too_short():
    """A 0.1 s clip is rejected with a "too short" error."""
    silence = np.zeros(2400, dtype=np.float32)  # 2400 frames at 24 kHz = 0.1 s
    clip = _make_wav_bytes(silence, 24000)
    with pytest.raises(AudioValidationError, match="too short"):
        validate_reference_clip(clip)
50
+
51
+
52
def test_validate_rejects_too_long():
    """A 70 s clip is rejected with a "too long" error."""
    silence = np.zeros(24000 * 70, dtype=np.float32)  # 70 s at 24 kHz
    clip = _make_wav_bytes(silence, 24000)
    with pytest.raises(AudioValidationError, match="too long"):
        validate_reference_clip(clip)
57
+
58
+
59
def test_validate_rejects_low_sample_rate():
    """An 8 kHz clip is rejected with a "sample rate" error."""
    silence = np.zeros(8000, dtype=np.float32)  # 8000 frames at 8 kHz = 1 s
    clip = _make_wav_bytes(silence, 8000)
    with pytest.raises(AudioValidationError, match="sample rate"):
        validate_reference_clip(clip)
64
+
65
+
66
def test_validate_rejects_non_wav_bytes():
    """Bytes that are not an audio file raise a "format" validation error."""
    garbage = b"not a wav"
    with pytest.raises(AudioValidationError, match="format"):
        validate_reference_clip(garbage)
69
+
70
+
71
def test_normalize_downmixes_stereo_and_resamples():
    """One second of 48 kHz stereo becomes 16000 mono samples at 16 kHz."""
    stereo = np.zeros((48000, 2), dtype=np.float32)  # (frames, channels), 1 s
    mono, rate = normalize_to_mono_16k(stereo, original_sr=48000)
    assert mono.ndim == 1
    assert rate == 16000
    assert mono.shape[0] == 16000