feat(audio): wav validation, write helper, mono/16k normalization
Browse files- server/audio.py +69 -0
- tests/test_audio.py +76 -0
server/audio.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio I/O and validation utilities."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import io
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import soundfile as sf
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
MIN_DURATION_S = 0.5
|
| 12 |
+
MAX_DURATION_S = 60.0
|
| 13 |
+
MIN_SAMPLE_RATE = 16000
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class AudioValidationError(ValueError):
    """Signal that a reference clip did not pass validation.

    Derives from ``ValueError``, so handlers written for ``ValueError`` also
    catch it.
    """
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass(frozen=True)
class ClipInfo:
    """Immutable summary of an audio clip's basic properties."""

    # Clip length in seconds (frame count divided by the sample rate).
    duration_s: float
    # Sample rate as reported by the decoded file header.
    sample_rate: int
    # Number of audio channels reported by the decoded file header.
    channels: int
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def validate_reference_clip(wav_bytes: bytes) -> ClipInfo:
    """Check an in-memory audio clip against the rate and duration limits.

    The bytes are opened with ``soundfile``; only the sample rate, channel
    count and frame count are read from the opened file.

    Args:
        wav_bytes: Encoded audio file contents.

    Returns:
        A ``ClipInfo`` describing the clip when every check passes.

    Raises:
        AudioValidationError: If the bytes cannot be opened as audio, the
            sample rate is below ``MIN_SAMPLE_RATE``, or the duration is
            shorter than ``MIN_DURATION_S`` or longer than ``MAX_DURATION_S``.
    """
    try:
        with sf.SoundFile(io.BytesIO(wav_bytes)) as clip:
            rate, n_channels, n_frames = clip.samplerate, clip.channels, clip.frames
    except Exception as err:
        # Any failure to open or inspect the data is reported as a format error.
        raise AudioValidationError(f"invalid audio format: {err}") from err

    # A zero rate would divide by zero; treat such a clip as having no length.
    if rate:
        length_s = n_frames / float(rate)
    else:
        length_s = 0.0

    # The sample-rate check runs first, so it wins when several limits fail.
    if rate < MIN_SAMPLE_RATE:
        raise AudioValidationError(
            f"sample rate {rate} below minimum {MIN_SAMPLE_RATE}"
        )
    if length_s < MIN_DURATION_S:
        raise AudioValidationError(f"clip too short ({length_s:.2f}s)")
    if length_s > MAX_DURATION_S:
        raise AudioValidationError(f"clip too long ({length_s:.2f}s)")

    return ClipInfo(duration_s=length_s, sample_rate=rate, channels=n_channels)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def write_wav_bytes(samples: np.ndarray, sample_rate: int) -> bytes:
    """Encode ``samples`` as a 16-bit PCM WAV file held in memory.

    Args:
        samples: Audio data handed to ``soundfile.write`` unchanged.
        sample_rate: Sample rate recorded in the WAV header.

    Returns:
        The complete WAV file contents as bytes.
    """
    with io.BytesIO() as sink:
        sf.write(sink, samples, sample_rate, format="WAV", subtype="PCM_16")
        # Copy the bytes out before the buffer is closed on leaving the block.
        return sink.getvalue()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def normalize_to_mono_16k(
    samples: np.ndarray, original_sr: int, target_sr: int = 16000
) -> tuple[np.ndarray, int]:
    """Downmix to mono and naively resample to ``target_sr``.

    Resampling is plain linear interpolation (``np.interp``); no anti-alias
    filtering is applied before downsampling.

    Args:
        samples: 1-D audio, or 2-D audio that is averaged over axis 1
            (i.e. treated as ``(frames, channels)``).
        original_sr: Sample rate of ``samples``; must be positive.
        target_sr: Desired output sample rate; must be positive.

    Returns:
        A ``(mono_samples, target_sr)`` tuple where ``mono_samples`` is a 1-D
        ``float32`` array. Empty input yields an empty array.

    Raises:
        ValueError: If either sample rate is not positive.
    """
    if original_sr <= 0 or target_sr <= 0:
        raise ValueError(
            f"sample rates must be positive (got {original_sr} -> {target_sr})"
        )
    if samples.ndim == 2:
        samples = samples.mean(axis=1)
    n_in = samples.shape[0]
    # Nothing to resample when the rates match. Empty input is also returned
    # here because np.interp rejects an empty array of sample points.
    if original_sr == target_sr or n_in == 0:
        return samples.astype(np.float32), target_sr
    duration = n_in / float(original_sr)
    target_len = int(round(duration * target_sr))
    # Both grids span [0, 1) so that sample k of each sits at k / length.
    x_old = np.linspace(0.0, 1.0, n_in, endpoint=False)
    x_new = np.linspace(0.0, 1.0, target_len, endpoint=False)
    out = np.interp(x_new, x_old, samples).astype(np.float32)
    return out, target_sr
|
tests/test_audio.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import wave
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from server.audio import (
|
| 8 |
+
AudioValidationError,
|
| 9 |
+
normalize_to_mono_16k,
|
| 10 |
+
validate_reference_clip,
|
| 11 |
+
write_wav_bytes,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _make_wav_bytes(samples: np.ndarray, sample_rate: int, channels: int = 1) -> bytes:
|
| 16 |
+
buf = io.BytesIO()
|
| 17 |
+
with wave.open(buf, "wb") as w:
|
| 18 |
+
w.setnchannels(channels)
|
| 19 |
+
w.setsampwidth(2)
|
| 20 |
+
w.setframerate(sample_rate)
|
| 21 |
+
pcm = (samples * 32767).clip(-32768, 32767).astype(np.int16)
|
| 22 |
+
if channels > 1:
|
| 23 |
+
pcm = np.repeat(pcm[:, None], channels, axis=1).flatten()
|
| 24 |
+
w.writeframes(pcm.tobytes())
|
| 25 |
+
return buf.getvalue()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def test_write_wav_bytes_roundtrip():
    """write_wav_bytes output decodes back to the samples that were written."""
    # Half-scale amplitude keeps every sample clear of the int16 full-scale
    # boundary, so the comparison below does not depend on clipping behaviour.
    samples = (0.5 * np.sin(np.linspace(0, 6.28, 24000))).astype(np.float32)
    wav_bytes = write_wav_bytes(samples, sample_rate=24000)
    assert wav_bytes[:4] == b"RIFF"
    with wave.open(io.BytesIO(wav_bytes)) as w:
        assert w.getnchannels() == 1
        assert w.getframerate() == 24000
        assert w.getsampwidth() == 2  # PCM_16 means two bytes per sample
        assert w.getnframes() == samples.shape[0]
        raw = w.readframes(w.getnframes())
    # WAV PCM data is little-endian; the tolerance covers 16-bit quantisation.
    decoded = np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768.0
    np.testing.assert_allclose(decoded, samples, atol=1e-3)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_validate_accepts_valid_clip():
    """A 2-second, 24 kHz clip passes validation and reports its metadata."""
    rate = 24000
    silence = np.zeros(2 * rate, dtype=np.float32)
    clip = validate_reference_clip(_make_wav_bytes(silence, rate))
    assert clip.duration_s == pytest.approx(2.0, rel=1e-3)
    assert clip.sample_rate == rate
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def test_validate_rejects_too_short():
    """A 0.1-second clip is rejected with a "too short" error."""
    rate = 24000
    payload = _make_wav_bytes(np.zeros(rate // 10, dtype=np.float32), rate)
    with pytest.raises(AudioValidationError, match="too short"):
        validate_reference_clip(payload)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test_validate_rejects_too_long():
    """A 70-second clip is rejected with a "too long" error."""
    rate = 24000
    payload = _make_wav_bytes(np.zeros(rate * 70, dtype=np.float32), rate)
    with pytest.raises(AudioValidationError, match="too long"):
        validate_reference_clip(payload)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def test_validate_rejects_low_sample_rate():
    """A 1-second clip recorded at 8 kHz is rejected for its sample rate."""
    rate = 8000
    payload = _make_wav_bytes(np.zeros(rate, dtype=np.float32), rate)
    with pytest.raises(AudioValidationError, match="sample rate"):
        validate_reference_clip(payload)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def test_validate_rejects_non_wav_bytes():
    """Bytes that are not decodable audio raise a format error."""
    garbage = b"not a wav"
    with pytest.raises(AudioValidationError, match="format"):
        validate_reference_clip(garbage)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def test_normalize_downmixes_stereo_and_resamples():
    """Stereo 48 kHz input comes back as averaged mono float32 at 16 kHz."""
    # Distinct constant channels: averaging yields 0.5 everywhere, whereas
    # keeping a single channel or summing them would yield 1.0 or 0.0.
    samples = np.zeros((48000, 2), dtype=np.float32)  # 1s stereo at 48kHz
    samples[:, 0] = 1.0
    out, sr = normalize_to_mono_16k(samples, original_sr=48000)
    assert out.ndim == 1
    assert sr == 16000
    assert out.shape[0] == 16000
    assert out.dtype == np.float32
    np.testing.assert_allclose(out, 0.5)
|