""" Custom Inference API handler for the Sinama audio classifier. Receives a raw audio file (WAV, MP3, etc.), extracts Mel Spectrogram features, runs inference through the CNN, and returns predicted class probabilities. """ import json import os import tempfile import librosa import numpy as np import tensorflow as tf class EndpointHandler: """HF Inference Endpoints handler.""" def __init__(self, path: str = ""): # path is the model directory on the endpoint model_path = os.path.join(path, "best_model.keras") self.model = tf.keras.models.load_model(model_path) with open(os.path.join(path, "label_map.json"), "r") as f: raw = json.load(f) self.label_map = {int(k): v for k, v in raw.items()} with open(os.path.join(path, "config.json"), "r") as f: self.cfg = json.load(f) def preprocess(self, audio_bytes: bytes) -> np.ndarray: """Convert raw audio bytes into a Mel Spectrogram array.""" sr = self.cfg["sample_rate"] duration = self.cfg["duration"] n_mels = self.cfg["n_mels"] n_fft = self.cfg["n_fft"] hop = self.cfg["hop_length"] target_len = int(sr * duration) # Write bytes to a temp file so librosa can read it with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: tmp.write(audio_bytes) tmp_path = tmp.name try: waveform, _ = librosa.load(tmp_path, sr=sr, mono=True) finally: os.unlink(tmp_path) # Pad / trim if len(waveform) < target_len: waveform = np.pad(waveform, (0, target_len - len(waveform))) else: waveform = waveform[:target_len] # Mel spectrogram mel = librosa.feature.melspectrogram( y=waveform, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop ) mel_db = librosa.power_to_db(mel, ref=np.max) # Normalise mean, std = mel_db.mean(), mel_db.std() mel_db = (mel_db - mean) / (std + 1e-9) # Add batch + channel dims → (1, freq, time, 1) return mel_db[np.newaxis, ..., np.newaxis] def __call__(self, data): """ Handle an inference request. Parameters ---------- data : dict Either {"inputs": } for audio data, or the raw request body bytes. Returns ------- list[dict] – [{"label": "word", "score": 0.95}, ...] """ # Extract audio bytes from the request if isinstance(data, dict): audio = data.get("inputs", data.get("body", b"")) else: audio = data if isinstance(audio, str): import base64 audio = base64.b64decode(audio) features = self.preprocess(audio) preds = self.model.predict(features, verbose=0)[0] # Return top-5 predictions sorted by confidence top_indices = np.argsort(preds)[::-1][:5] results = [ {"label": self.label_map[int(i)], "score": round(float(preds[i]), 4)} for i in top_indices ] return results