import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import librosa
# ── Constants (must match your notebook exactly) ──────────────────────────────
SR = 22050
DURATION = 30
N_SAMPLES = SR * DURATION  # 661500
N_MELS = 128
HOP_LENGTH = 512
N_FFT = 2048
N_CLASSES = 10
GENRES = sorted([
    "blues", "classical", "country", "disco", "hiphop",
    "jazz", "metal", "pop", "reggae", "rock"
])
DEVICE = torch.device("cpu")  # HF Spaces free tier = CPU only
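# Note: sorted() pins GENRES to alphabetical order. This assumes the training
# notebook encoded labels alphabetically (index 0 = "blues", ..., 9 = "rock");
# if it used a different ordering, the predicted labels will come out permuted.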
# ── Model Architecture (must be identical to your notebook Cell 64) ───────────
class GenreCNN(nn.Module):
    def __init__(self, n_classes=N_CLASSES):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32),
            nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64),
            nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128),
            nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256),
            nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
        )
        self.pool = nn.AdaptiveAvgPool2d((4, 4))  # 256*4*4 = 4096
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(128, n_classes)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pool(x)
        return self.classifier(x)
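# Shape flow for a (B, 1, 128, 512) input: each conv block halves H and W
# (128→64→32→16→8, 512→256→128→64→32), so conv4 emits (B, 256, 8, 32); adaptive
# pooling then fixes that at (B, 256, 4, 4) = 4096 features per sample, which is
# exactly what the first Linear layer of the classifier expects.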
# ── Load model once at startup ────────────────────────────────────────────────
model = GenreCNN().to(DEVICE)
model.load_state_dict(
    torch.load("best_model.pth", map_location=DEVICE)
)
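# On PyTorch >= 1.13 you can also pass weights_only=True to torch.load above for
# safer deserialization; a plain state_dict checkpoint loads fine either way.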
model.eval()
print("✓ Model loaded successfully")
# ── Audio preprocessing helpers ───────────────────────────────────────────────
def get_mel_spectrogram(y):
    """Convert a raw audio array → normalised (128, 512) float32 mel spectrogram array."""
    mel = librosa.feature.melspectrogram(
        y=y, sr=SR, n_mels=N_MELS,
        hop_length=HOP_LENGTH, n_fft=N_FFT
    )
    mel_db = librosa.power_to_db(mel, ref=np.max)
    # normalise to 0-1
    mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-6)
    # pad or crop to exactly 512 time frames
    if mel_db.shape[1] >= 512:
        mel_db = mel_db[:, :512]
    else:
        mel_db = np.pad(mel_db, ((0, 0), (0, 512 - mel_db.shape[1])))
    return mel_db.astype(np.float32)  # (128, 512)
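# Frame arithmetic: a full 30 s crop at hop 512 yields 1 + 661500 // 512 = 1292
# frames, so the crop above keeps only the first 512 frames (≈ 11.9 s of audio).
# This matches training as long as the notebook applied the same 512-frame crop.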
def extract_3_crops(y):
    """Return [start, center, end] crops of exactly N_SAMPLES each."""
    total = len(y)
    if total <= N_SAMPLES:
        y = np.pad(y, (0, N_SAMPLES - total))
        return [y, y, y]
    start = y[:N_SAMPLES]
    mid_s = (total - N_SAMPLES) // 2
    center = y[mid_s : mid_s + N_SAMPLES]
    end = y[total - N_SAMPLES:]
    return [start, center, end]
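# For clips shorter than 30 s the three "crops" are identical (zero-padded), and
# for clips under 90 s they overlap; the averaging in predict_genre still works,
# it just gains less from test-time augmentation on short inputs.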
# ── Main prediction function ──────────────────────────────────────────────────
def predict_genre(audio_path):
    """
    Takes an audio file path from Gradio,
    returns a dict of {genre: probability} for the label display.
    """
    if audio_path is None:
        return {}

    # 1. Load audio (librosa handles mp3, wav, flac, ogg, etc.)
    try:
        y, _ = librosa.load(audio_path, sr=SR, mono=True)
    except Exception as e:
        return {f"Error loading audio: {str(e)}": 1.0}

    # 2. Extract 3 crops for test-time augmentation
    crops = extract_3_crops(y)

    # 3. Convert each crop to a mel spectrogram tensor
    #    Shape per crop: (1, 1, 128, 512) → batch=1, channel=1, H, W
    mel_tensors = [
        torch.tensor(get_mel_spectrogram(crop)).unsqueeze(0).unsqueeze(0)
        for crop in crops
    ]

    # 4. Run all 3 crops through the model, average probabilities (TTA)
    probs = torch.zeros(1, N_CLASSES)
    with torch.no_grad():
        for mel in mel_tensors:
            logits = model(mel.to(DEVICE))
            probs += torch.softmax(logits, dim=1).cpu()
    probs /= 3.0  # average over 3 crops

    # 5. Build output dict for Gradio's Label component
    prob_np = probs.squeeze().numpy()
    return {GENRES[i]: float(prob_np[i]) for i in range(N_CLASSES)}
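# A minimal alternative sketch (not wired in): the three crops could be stacked
# into one batch for a single forward pass instead of the loop above, e.g.
#     batch = torch.stack(
#         [torch.from_numpy(get_mel_spectrogram(c)) for c in crops]
#     ).unsqueeze(1)                                    # (3, 1, 128, 512)
#     probs = torch.softmax(model(batch.to(DEVICE)), dim=1).mean(0)
# On the CPU-only free tier the speedup is modest, so the simple loop is kept.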
# ── Gradio UI ─────────────────────────────────────────────────────────────────
demo = gr.Interface(
    fn=predict_genre,
    inputs=gr.Audio(
        type="filepath",
        label="Upload an audio file (wav, mp3, flac, ogg – any genre)"
    ),
    outputs=gr.Label(
        num_top_classes=10,
        label="Genre probabilities"
    ),
    title="🎵 Music Genre Classifier",
    description=(
        "Upload any audio clip and the model will predict its genre.\n\n"
        "**Genres:** Blues · Classical · Country · Disco · Hip-hop · "
        "Jazz · Metal · Pop · Reggae · Rock\n\n"
        "**How it works:** The audio is converted to a mel spectrogram "
        "(a 2D image of frequency vs time), then passed through a CNN trained "
        "on 27,000 synthetic audio mashups. Three time-crop predictions are "
        "averaged for more reliable results."
    ),
    examples=[],  # add example file paths here if you upload samples
)
if __name__ == "__main__":
    demo.launch()