import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import librosa
# ── Constants (must match your notebook exactly) ──────────────────────────────
SR = 22050
DURATION = 30
N_SAMPLES = SR * DURATION # 661500
N_MELS = 128
HOP_LENGTH = 512
N_FFT = 2048
N_CLASSES = 10
GENRES = sorted([
"blues", "classical", "country", "disco", "hiphop",
"jazz", "metal", "pop", "reggae", "rock"
])
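# Note: the alphabetical order here must line up with the label encoding used
# at training time, so that index i of the model's output maps to GENRES[i].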
DEVICE = torch.device("cpu") # HF Spaces free tier = CPU only
# ── Model Architecture (must be identical to your notebook Cell 64) ───────────
class GenreCNN(nn.Module):
def __init__(self, n_classes=N_CLASSES):
super().__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32),
nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
)
self.conv2 = nn.Sequential(
nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64),
nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
)
self.conv3 = nn.Sequential(
nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128),
nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
)
self.conv4 = nn.Sequential(
nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256),
nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
)
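        # Four stride-2 max-pools shrink the (1, 128, 512) input to (256, 8, 32)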
self.pool = nn.AdaptiveAvgPool2d((4, 4)) # 256*4*4 = 4096
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.5),
nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.5),
nn.Linear(128, n_classes)
)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.pool(x)
return self.classifier(x)
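# Illustrative shape sanity check (not executed by the app; run manually if
# you adapt the architecture):
#   GenreCNN()(torch.randn(1, 1, 128, 512)).shape  ->  torch.Size([1, 10])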
# ── Load model once at startup ─────────────────────────────────────────────────
model = GenreCNN().to(DEVICE)
model.load_state_dict(
torch.load("best_model.pth", map_location=DEVICE)
)
model.eval()
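# eval() switches BatchNorm to its running statistics and disables Dropout;
# without it, single-clip predictions would be noisy and batch-size dependent.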
print("β Model loaded successfully")
# ── Audio preprocessing helpers ────────────────────────────────────────────────
def get_mel_spectrogram(y):
"""Convert raw audio array β normalised (128, 512) mel spectrogram tensor."""
mel = librosa.feature.melspectrogram(
y=y, sr=SR, n_mels=N_MELS,
hop_length=HOP_LENGTH, n_fft=N_FFT
)
mel_db = librosa.power_to_db(mel, ref=np.max)
    # min-max normalise to [0, 1]; the 1e-6 epsilon guards against division
    # by zero on silent (constant) input
mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-6)
# pad or crop to exactly 512 time frames
if mel_db.shape[1] >= 512:
mel_db = mel_db[:, :512]
else:
mel_db = np.pad(mel_db, ((0, 0), (0, 512 - mel_db.shape[1])))
return mel_db.astype(np.float32) # (128, 512)
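# Note: at sr=22050 and hop_length=512, 512 frames span roughly 11.9 s, so each
# 30 s crop is effectively trimmed to its first ~12 seconds of spectrogram.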
def extract_3_crops(y):
"""Return [start, center, end] crops of exactly N_SAMPLES each."""
total = len(y)
if total <= N_SAMPLES:
y = np.pad(y, (0, N_SAMPLES - total))
return [y, y, y]
start = y[:N_SAMPLES]
mid_s = (total - N_SAMPLES) // 2
center = y[mid_s : mid_s + N_SAMPLES]
end = y[total - N_SAMPLES:]
return [start, center, end]
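# Example: a 60 s clip (1,323,000 samples) yields crops starting at samples
# 0, 330,750 and 661,500, i.e. the beginning, middle and end of the track.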
# ── Main prediction function ───────────────────────────────────────────────────
def predict_genre(audio_path):
"""
Takes an audio file path from Gradio,
returns a dict of {genre: probability} for the label display.
"""
if audio_path is None:
return {}
# 1. Load audio (librosa handles mp3, wav, flac, ogg, etc.)
try:
y, _ = librosa.load(audio_path, sr=SR, mono=True)
except Exception as e:
return {f"Error loading audio: {str(e)}": 1.0}
# 2. Extract 3 crops for test-time augmentation
crops = extract_3_crops(y)
# 3. Convert each crop to a mel spectrogram tensor
    #    Shape per crop: (1, 1, 128, 512) → batch=1, channel=1, H, W
mel_tensors = [
torch.tensor(get_mel_spectrogram(crop)).unsqueeze(0).unsqueeze(0)
for crop in crops
]
# 4. Run all 3 crops through the model, average probabilities (TTA)
probs = torch.zeros(1, N_CLASSES)
with torch.no_grad():
for mel in mel_tensors:
logits = model(mel.to(DEVICE))
probs += torch.softmax(logits, dim=1).cpu()
probs /= 3.0 # average over 3 crops
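    # Averaging softmax probabilities (rather than raw logits) keeps each
    # crop's vote bounded to [0, 1], a common choice for test-time augmentation.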
# 5. Build output dict for Gradio's Label component
prob_np = probs.squeeze().numpy()
return {GENRES[i]: float(prob_np[i]) for i in range(N_CLASSES)}
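# Hypothetical local smoke test (the file path is illustrative, not shipped
# with this repo):
#   probs = predict_genre("some_clip.wav")
#   print(max(probs, key=probs.get))  # top-1 genre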
# ── Gradio UI ──────────────────────────────────────────────────────────────────
demo = gr.Interface(
fn=predict_genre,
inputs=gr.Audio(
type="filepath",
label="Upload an audio file (wav, mp3, flac, ogg β any genre)"
),
outputs=gr.Label(
num_top_classes=10,
label="Genre probabilities"
),
title="π΅ Music Genre Classifier",
description=(
"Upload any audio clip and the model will predict its genre.\n\n"
"**Genres:** Blues Β· Classical Β· Country Β· Disco Β· Hip-hop Β· "
"Jazz Β· Metal Β· Pop Β· Reggae Β· Rock\n\n"
"**How it works:** The audio is converted to a mel spectrogram "
"(a 2D image of frequency vs time), then passed through a CNN trained "
"on 27,000 synthetic audio mashups. Three time-crop predictions are "
"averaged for more reliable results."
),
examples=[], # add example file paths here if you upload samples
)
if __name__ == "__main__":
    demo.launch()
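# Deployment note: Hugging Face Spaces installs packages from requirements.txt;
# for this app that would typically mean gradio, torch, librosa and numpy
# (librosa pulls in soundfile/audioread for mp3 and flac decoding).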