# app.py — music-genre classification demo (Gradio + PyTorch) for Hugging Face Spaces.
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import librosa
# ── Constants (must match your notebook exactly) ──────────────────────────────
SR = 22050          # audio sample rate (Hz) assumed by the model's training data
DURATION = 30       # seconds of audio per crop fed to the network
N_SAMPLES = SR * DURATION # 661500
N_MELS = 128        # mel bands → spectrogram height
HOP_LENGTH = 512    # STFT hop length (samples between frames)
N_FFT = 2048        # STFT window size
N_CLASSES = 10      # one output logit per genre
# sorted() fixes alphabetical order — must match the label indices used at training time
GENRES = sorted([
    "blues", "classical", "country", "disco", "hiphop",
    "jazz", "metal", "pop", "reggae", "rock"
])
DEVICE = torch.device("cpu") # HF Spaces free tier = CPU only
# ── Model Architecture (must be identical to your notebook Cell 64) ───────────
class GenreCNN(nn.Module):
    """Four-stage CNN mapping a (B, 1, 128, T) mel spectrogram to genre logits.

    NOTE: attribute names (conv1..conv4, pool, classifier) and the module
    order inside each Sequential are load-bearing — they define the
    state_dict keys expected by the saved checkpoint.
    """

    @staticmethod
    def _conv_stage(in_ch, out_ch):
        # One stage: conv → batch-norm → ReLU → 2x2 max-pool → spatial dropout.
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.25),
        )

    def __init__(self, n_classes=N_CLASSES):
        super().__init__()
        self.conv1 = self._conv_stage(1, 32)
        self.conv2 = self._conv_stage(32, 64)
        self.conv3 = self._conv_stage(64, 128)
        self.conv4 = self._conv_stage(128, 256)
        # Adaptive pooling yields a fixed 4x4 map: 256 * 4 * 4 = 4096 features.
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(128, n_classes),
        )

    def forward(self, x):
        """Run the four conv stages, pool, and classify; returns raw logits."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        return self.classifier(self.pool(x))
# ── Load model once at startup ────────────────────────────────────────────────
# Instantiate the network and restore trained weights from the checkpoint.
model = GenreCNN().to(DEVICE)
model.load_state_dict(
    # weights_only=True restricts unpickling to tensors/primitives, guarding
    # against arbitrary-code execution from a tampered checkpoint file.
    torch.load("best_model.pth", map_location=DEVICE, weights_only=True)
)
# eval(): disable dropout and use BatchNorm running statistics for inference.
model.eval()
print("✓ Model loaded successfully")
# ── Audio preprocessing helpers ───────────────────────────────────────────────
def get_mel_spectrogram(y, n_frames=512):
    """Convert a raw audio array into a normalised (N_MELS, n_frames) mel spectrogram.

    Parameters
    ----------
    y : np.ndarray
        1-D mono audio samples (assumed sampled at SR — confirm at call site).
    n_frames : int, optional
        Width (time frames) to pad/crop to; default 512, the width the CNN
        was trained on.

    Returns
    -------
    np.ndarray
        float32 array of shape (N_MELS, n_frames) with values in [0, 1].
    """
    mel = librosa.feature.melspectrogram(
        y=y, sr=SR, n_mels=N_MELS,
        hop_length=HOP_LENGTH, n_fft=N_FFT
    )
    # Log-compress relative to the loudest bin, then min-max normalise to 0-1;
    # the epsilon guards against division by zero on perfectly silent input.
    mel_db = librosa.power_to_db(mel, ref=np.max)
    mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-6)
    # Force a fixed time dimension so every tensor fed to the model is the same shape.
    if mel_db.shape[1] >= n_frames:
        mel_db = mel_db[:, :n_frames]
    else:
        mel_db = np.pad(mel_db, ((0, 0), (0, n_frames - mel_db.shape[1])))
    return mel_db.astype(np.float32)  # (N_MELS, n_frames)
def extract_3_crops(y, n_samples=None):
    """Return [start, center, end] windows of exactly *n_samples* samples each.

    Clips shorter than the window are zero-padded once and returned as all
    three "crops", so downstream test-time augmentation always averages
    exactly three predictions.

    Parameters
    ----------
    y : np.ndarray
        1-D audio sample array.
    n_samples : int or None, optional
        Window length in samples; defaults to the module-level N_SAMPLES.

    Returns
    -------
    list[np.ndarray]
        Three arrays, each of length n_samples.
    """
    if n_samples is None:
        n_samples = N_SAMPLES
    total = len(y)
    if total <= n_samples:
        padded = np.pad(y, (0, n_samples - total))
        return [padded, padded, padded]
    mid_start = (total - n_samples) // 2
    return [
        y[:n_samples],                          # start of clip
        y[mid_start : mid_start + n_samples],   # center of clip
        y[total - n_samples:],                  # end of clip
    ]
# ── Main prediction function ──────────────────────────────────────────────────
def predict_genre(audio_path):
    """
    Gradio callback: classify the audio file at *audio_path*.

    Returns a {genre: probability} dict for the Label component. An empty
    dict is returned when no file was provided; a decode failure returns
    the error text keyed at probability 1.0.
    """
    if audio_path is None:
        return {}

    # Decode to a mono waveform at the training sample rate
    # (librosa handles mp3, wav, flac, ogg, etc.).
    try:
        y, _ = librosa.load(audio_path, sr=SR, mono=True)
    except Exception as e:
        return {f"Error loading audio: {str(e)}": 1.0}

    # Test-time augmentation: classify three 30-second windows
    # (start / center / end) and average their softmax distributions.
    probs = torch.zeros(1, N_CLASSES)
    with torch.no_grad():
        for crop in extract_3_crops(y):
            spec = torch.tensor(get_mel_spectrogram(crop))
            batch = spec.unsqueeze(0).unsqueeze(0)  # (1, 1, 128, 512)
            logits = model(batch.to(DEVICE))
            probs += torch.softmax(logits, dim=1).cpu()
    probs /= 3.0  # average over 3 crops

    # Build the {genre: probability} mapping Gradio's Label expects.
    scores = probs.squeeze().numpy()
    return {genre: float(p) for genre, p in zip(GENRES, scores)}
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# NOTE(review): several UI literals below (e.g. "β€”", "Β·", the title emoji)
# look mojibake-encoded — re-save the file as UTF-8 and restore the intended
# characters; left untouched here since they are runtime strings.
demo = gr.Interface(
    fn=predict_genre,
    inputs=gr.Audio(
        type="filepath",  # hand the callback a temp-file path, not raw samples
        label="Upload an audio file (wav, mp3, flac, ogg β€” any genre)"
    ),
    outputs=gr.Label(
        num_top_classes=10,  # show the full 10-genre probability distribution
        label="Genre probabilities"
    ),
    title="🎡 Music Genre Classifier",
    description=(
        "Upload any audio clip and the model will predict its genre.\n\n"
        "**Genres:** Blues Β· Classical Β· Country Β· Disco Β· Hip-hop Β· "
        "Jazz Β· Metal Β· Pop Β· Reggae Β· Rock\n\n"
        "**How it works:** The audio is converted to a mel spectrogram "
        "(a 2D image of frequency vs time), then passed through a CNN trained "
        "on 27,000 synthetic audio mashups. Three time-crop predictions are "
        "averaged for more reliable results."
    ),
    examples=[], # add example file paths here if you upload samples
)

# Launch the Gradio server when run as a script (HF Spaces entry point).
if __name__ == "__main__":
    demo.launch()