import csv
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import tensorflow as tf
import tensorflow_hub as hub
from scipy.signal import resample
|
|
| |
# Load the pretrained YAMNet audio-event classifier from TensorFlow Hub.
# NOTE(review): this downloads/loads the model at import time, so importing
# this module requires network access (or a populated TFHUB cache).
yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")
|
|
| |
def load_class_map():
    """Download YAMNet's class-map CSV and return its display names.

    Returns:
        list[str]: Human-readable class names, index-aligned with the
        model's per-class score vector.

    Uses ``csv.reader`` instead of a naive ``line.split(',')`` because
    several display names in the file contain embedded commas and are
    quoted (e.g. "Accelerating, revving, vroom"); splitting on commas
    truncates those labels.
    """
    class_map_path = tf.keras.utils.get_file(
        'yamnet_class_map.csv',
        'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv'
    )
    with open(class_map_path, 'r', newline='') as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row (index,mid,display_name)
        return [row[2] for row in reader]
|
|
# Human-readable class names, index-aligned with YAMNet's score outputs.
# Fetched once at import time (requires network access on first run).
class_names = load_class_map()
|
|
| |
def _read_audio(audio_input):
    """Load samples from a file path or a readable file-like object.

    Returns:
        tuple: (audio_data ndarray, sample_rate int) as produced by
        ``soundfile.read``.

    Raises:
        ValueError: If ``audio_input`` is neither a path string nor an
        object with a ``read`` method.
    """
    if isinstance(audio_input, str):
        return sf.read(audio_input)
    if hasattr(audio_input, "read"):
        # Spool the uploaded bytes to a temp file so soundfile can sniff
        # the container format from disk.
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                tmp.write(audio_input.read())
                tmp_path = tmp.name
            return sf.read(tmp_path)
        finally:
            # Always remove the temp file — the original code leaked it
            # whenever sf.read raised before the unlink.
            if tmp_path is not None and os.path.exists(tmp_path):
                os.unlink(tmp_path)
    raise ValueError("Unsupported input format")


def _preprocess(audio_data, sample_rate, target_rate=16000):
    """Mix to mono, peak-normalize, and resample to ``target_rate`` Hz.

    ``target_rate`` defaults to 16 kHz, the sample rate YAMNet expects.
    """
    if len(audio_data.shape) > 1:
        audio_data = np.mean(audio_data, axis=1)

    # Guard against empty or silent clips: dividing by a zero peak would
    # produce NaNs (or a ZeroDivisionError) in the original code.
    peak = np.max(np.abs(audio_data)) if audio_data.size else 0.0
    if peak > 0:
        audio_data = audio_data / peak

    if sample_rate != target_rate:
        duration = audio_data.shape[0] / sample_rate
        new_length = int(duration * target_rate)
        audio_data = resample(audio_data, new_length)
    return audio_data


def classify_audio(audio_input):
    """Classify an audio clip with YAMNet and plot its waveform.

    Parameters:
        audio_input: Path to an audio file (Gradio ``type="filepath"``)
            or an open binary file-like object.

    Returns:
        tuple: (top class name, {class name: score} for the top 5 classes,
        matplotlib Figure of the waveform).  On any failure, returns
        (error message string, empty dict, None) so the Gradio worker
        keeps serving instead of crashing.
    """
    try:
        audio_data, sample_rate = _read_audio(audio_input)
        audio_data = _preprocess(audio_data, sample_rate)

        waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)

        # YAMNet returns one score row per ~0.96 s frame; average over
        # frames for a single clip-level prediction.
        scores, embeddings, spectrogram = yamnet_model(waveform)
        mean_scores = tf.reduce_mean(scores, axis=0).numpy()
        top_5 = np.argsort(mean_scores)[::-1][:5]

        top_prediction = class_names[top_5[0]]
        top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}

        fig, ax = plt.subplots()
        ax.plot(audio_data)
        ax.set_title("Waveform")
        ax.set_xlabel("Time (samples)")
        ax.set_ylabel("Amplitude")
        plt.tight_layout()

        return top_prediction, top_scores, fig

    except Exception as e:
        # Surface the error in the UI rather than raising into Gradio.
        return f"Error processing audio: {str(e)}", {}, None
|
|
| |
# Gradio UI wiring: one audio input -> (top label, top-5 score dict,
# waveform plot), matching classify_audio's three return values.
interface = gr.Interface(
    fn=classify_audio,
    # type="filepath" means classify_audio receives a path string.
    inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),
    outputs=[
        gr.Textbox(label="Top Prediction"),
        gr.Label(label="Top 5 Classes with Scores"),
        gr.Plot(label="Waveform")
    ],
    title="Audtheia YAMNet Audio Classifier",
    description="Upload an environmental or animal sound to classify using the YAMNet model. Returns label predictions and waveform."
)
|
|
# Launch the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()
|
|