der02 committed on
Commit
965496a
·
verified ·
1 Parent(s): 8eae9a8

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +104 -0
handler.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Inference API handler for the Sinama audio classifier.
3
+
4
+ Receives a raw audio file (WAV, MP3, etc.), extracts Mel Spectrogram
5
+ features, runs inference through the CNN, and returns predicted class
6
+ probabilities.
7
+ """
8
+
9
+ import json
10
+ import os
11
+ import tempfile
12
+
13
+ import librosa
14
+ import numpy as np
15
+ import tensorflow as tf
16
+
17
+
18
class EndpointHandler:
    """HF Inference Endpoints handler for the Sinama audio classifier.

    Loads the trained Keras CNN together with its label map and
    preprocessing configuration, turns an incoming raw audio file
    (WAV, MP3, etc.) into a normalized Mel-spectrogram tensor, and
    returns the top-5 predicted classes with confidence scores.
    """

    def __init__(self, path: str = ""):
        """Load model artifacts from the endpoint's model directory.

        Parameters
        ----------
        path : str
            Model directory on the endpoint; expected to contain
            ``best_model.keras``, ``label_map.json`` (class index ->
            label) and ``config.json`` (audio preprocessing parameters).
        """
        model_path = os.path.join(path, "best_model.keras")
        self.model = tf.keras.models.load_model(model_path)

        # JSON object keys are strings; convert back to int class indices.
        with open(os.path.join(path, "label_map.json"), "r") as f:
            raw = json.load(f)
        self.label_map = {int(k): v for k, v in raw.items()}

        with open(os.path.join(path, "config.json"), "r") as f:
            self.cfg = json.load(f)

    def preprocess(self, audio_bytes: bytes) -> np.ndarray:
        """Convert raw audio bytes into a normalized Mel-spectrogram array.

        Parameters
        ----------
        audio_bytes : bytes
            Raw contents of an audio file in any format librosa can decode.

        Returns
        -------
        np.ndarray
            Array shaped ``(1, n_mels, time, 1)`` — batch and channel
            dimensions added for the CNN.

        Raises
        ------
        ValueError
            If ``audio_bytes`` is empty.
        """
        # Fail fast with a clear message instead of a cryptic decoder
        # error from librosa when handed an empty temp file.
        if not audio_bytes:
            raise ValueError("No audio data received in request body.")

        sr = self.cfg["sample_rate"]
        duration = self.cfg["duration"]
        n_mels = self.cfg["n_mels"]
        n_fft = self.cfg["n_fft"]
        hop = self.cfg["hop_length"]
        target_len = int(sr * duration)

        # librosa needs a real file path; write bytes to a temp file.
        # delete=False so the closed file can be reopened by name
        # (required on Windows, harmless elsewhere).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name

        try:
            waveform, _ = librosa.load(tmp_path, sr=sr, mono=True)
        finally:
            # Always remove the temp file, even if decoding fails.
            os.unlink(tmp_path)

        # Pad with trailing silence or trim so every clip has a fixed
        # length of ``duration`` seconds, matching the model's input size.
        if len(waveform) < target_len:
            waveform = np.pad(waveform, (0, target_len - len(waveform)))
        else:
            waveform = waveform[:target_len]

        # Mel spectrogram in dB scale.
        mel = librosa.feature.melspectrogram(
            y=waveform, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop
        )
        mel_db = librosa.power_to_db(mel, ref=np.max)

        # Per-clip standardisation; the epsilon guards against a zero
        # std (e.g. pure silence).
        mean, std = mel_db.mean(), mel_db.std()
        mel_db = (mel_db - mean) / (std + 1e-9)

        # Add batch + channel dims -> (1, freq, time, 1)
        return mel_db[np.newaxis, ..., np.newaxis]

    def __call__(self, data):
        """Handle an inference request.

        Parameters
        ----------
        data : dict | bytes
            Either ``{"inputs": <base64 str or bytes>}`` /
            ``{"body": <bytes>}`` for audio data, or the raw request
            body bytes.

        Returns
        -------
        list[dict]
            Top-5 predictions sorted by confidence:
            ``[{"label": "word", "score": 0.95}, ...]``.

        Raises
        ------
        ValueError
            If no audio bytes could be extracted from the request.
        """
        # Extract audio bytes from the request payload.
        if isinstance(data, dict):
            audio = data.get("inputs", data.get("body", b""))
        else:
            audio = data

        # Inference Endpoints commonly deliver binary payloads as
        # base64-encoded strings inside a JSON body.
        if isinstance(audio, str):
            import base64
            audio = base64.b64decode(audio)

        features = self.preprocess(audio)
        preds = self.model.predict(features, verbose=0)[0]

        # Top-5 class indices, highest confidence first.
        top_indices = np.argsort(preds)[::-1][:5]
        results = [
            {"label": self.label_map[int(i)], "score": round(float(preds[i]), 4)}
            for i in top_indices
        ]
        return results