"""Gradio app: music genre classification from audio files.

Loads a CNN trained on mel spectrograms and serves a web UI that predicts
one of ten genres, averaging predictions over three time crops (simple
test-time augmentation).
"""

import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import librosa

# ── Constants (must match your notebook exactly) ──────────────────────────────
SR = 22050                    # sample rate used at training time
DURATION = 30                 # seconds of audio per model input
N_SAMPLES = SR * DURATION     # 661500 samples per crop
N_MELS = 128                  # mel bands → spectrogram height
HOP_LENGTH = 512
N_FFT = 2048
N_FRAMES = 512                # spectrogram width (time frames) the CNN expects
N_CLASSES = 10
GENRES = sorted([
    "blues", "classical", "country", "disco", "hiphop",
    "jazz", "metal", "pop", "reggae", "rock"
])
DEVICE = torch.device("cpu")  # HF Spaces free tier = CPU only


# ── Model Architecture (must be identical to your notebook Cell 64) ───────────
class GenreCNN(nn.Module):
    """Four conv blocks + adaptive pooling + MLP head.

    Input:  (batch, 1, N_MELS, N_FRAMES) normalised mel spectrograms.
    Output: (batch, n_classes) raw logits.
    """

    def __init__(self, n_classes: int = N_CLASSES):
        super().__init__()
        # NOTE: attribute names (conv1..conv4, pool, classifier) must not
        # change — the saved state_dict keys depend on them.
        self.conv1 = self._block(1, 32)
        self.conv2 = self._block(32, 64)
        self.conv3 = self._block(64, 128)
        self.conv4 = self._block(128, 256)
        self.pool = nn.AdaptiveAvgPool2d((4, 4))  # 256*4*4 = 4096 features
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(4096, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, n_classes),
        )

    @staticmethod
    def _block(c_in: int, c_out: int) -> nn.Sequential:
        """Conv → BatchNorm → ReLU → 2x2 max-pool → spatial dropout."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.25),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pool(x)
        return self.classifier(x)


# ── Load model once at startup ────────────────────────────────────────────────
model = GenreCNN().to(DEVICE)
model.load_state_dict(
    # weights_only=True: the checkpoint is a plain state_dict, so refuse to
    # unpickle arbitrary objects (this is the default from torch >= 2.6).
    torch.load("best_model.pth", map_location=DEVICE, weights_only=True)
)
model.eval()
print("✓ Model loaded successfully")


# ── Audio preprocessing helpers ───────────────────────────────────────────────
def get_mel_spectrogram(y: np.ndarray) -> np.ndarray:
    """Convert raw audio array → normalised (N_MELS, N_FRAMES) mel spectrogram.

    The spectrogram is converted to dB, min-max scaled to [0, 1], then padded
    or cropped along the time axis to exactly N_FRAMES columns.
    """
    mel = librosa.feature.melspectrogram(
        y=y, sr=SR, n_mels=N_MELS, hop_length=HOP_LENGTH, n_fft=N_FFT
    )
    mel_db = librosa.power_to_db(mel, ref=np.max)
    # normalise to 0-1 (epsilon guards against a constant, e.g. silent, input)
    mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-6)
    # pad or crop to exactly N_FRAMES time frames
    if mel_db.shape[1] >= N_FRAMES:
        mel_db = mel_db[:, :N_FRAMES]
    else:
        mel_db = np.pad(mel_db, ((0, 0), (0, N_FRAMES - mel_db.shape[1])))
    return mel_db.astype(np.float32)  # (N_MELS, N_FRAMES)


def extract_3_crops(y: np.ndarray) -> list:
    """Return up to three crops ([start, center, end]) of exactly N_SAMPLES.

    Audio shorter than N_SAMPLES is zero-padded and returned as a single
    crop: the three crops would be identical, so averaging over them would
    give the same result while tripling the compute.
    """
    total = len(y)
    if total <= N_SAMPLES:
        return [np.pad(y, (0, N_SAMPLES - total))]
    start = y[:N_SAMPLES]
    mid_s = (total - N_SAMPLES) // 2
    center = y[mid_s: mid_s + N_SAMPLES]
    end = y[total - N_SAMPLES:]
    return [start, center, end]


# ── Main prediction function ──────────────────────────────────────────────────
def predict_genre(audio_path):
    """Predict genre probabilities for an uploaded audio file.

    Takes an audio file path from Gradio, returns a dict of
    {genre: probability} for the Label display. On load failure the error
    text is returned as the single "label" so it is visible in the UI.
    """
    if audio_path is None:
        return {}

    # 1. Load audio (librosa handles mp3, wav, flac, ogg, etc.)
    try:
        y, _ = librosa.load(audio_path, sr=SR, mono=True)
    except Exception as e:
        return {f"Error loading audio: {str(e)}": 1.0}

    # 2. Extract crops for test-time augmentation
    crops = extract_3_crops(y)

    # 3. Convert the crops to one batched tensor: (n_crops, 1, N_MELS, N_FRAMES)
    mels = np.stack([get_mel_spectrogram(crop) for crop in crops])
    batch = torch.from_numpy(mels).unsqueeze(1).to(DEVICE)

    # 4. One forward pass over all crops, average the softmax probabilities.
    #    In eval() mode BatchNorm uses running stats, so batching the crops
    #    is mathematically identical to running them one at a time.
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1).mean(dim=0).cpu()

    # 5. Build output dict for Gradio's Label component
    prob_np = probs.numpy()
    return {GENRES[i]: float(prob_np[i]) for i in range(N_CLASSES)}


# ── Gradio UI ─────────────────────────────────────────────────────────────────
demo = gr.Interface(
    fn=predict_genre,
    inputs=gr.Audio(
        type="filepath",
        label="Upload an audio file (wav, mp3, flac, ogg — any genre)"
    ),
    outputs=gr.Label(
        num_top_classes=10,
        label="Genre probabilities"
    ),
    title="🎵 Music Genre Classifier",
    description=(
        "Upload any audio clip and the model will predict its genre.\n\n"
        "**Genres:** Blues · Classical · Country · Disco · Hip-hop · "
        "Jazz · Metal · Pop · Reggae · Rock\n\n"
        "**How it works:** The audio is converted to a mel spectrogram "
        "(a 2D image of frequency vs time), then passed through a CNN trained "
        "on 27,000 synthetic audio mashups. Three time-crop predictions are "
        "averaged for more reliable results."
    ),
    examples=[],  # add example file paths here if you upload samples
)

if __name__ == "__main__":
    demo.launch()