shubhamrazzsharma commited on
Commit
d662341
Β·
verified Β·
1 Parent(s): 3425304

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -0
app.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import librosa
6
+
7
+ # ── Constants (must match your notebook exactly) ──────────────────────────────
8
+ SR = 22050
9
+ DURATION = 30
10
+ N_SAMPLES = SR * DURATION # 661500
11
+ N_MELS = 128
12
+ HOP_LENGTH = 512
13
+ N_FFT = 2048
14
+ N_CLASSES = 10
15
+
16
+ GENRES = sorted([
17
+ "blues", "classical", "country", "disco", "hiphop",
18
+ "jazz", "metal", "pop", "reggae", "rock"
19
+ ])
20
+
21
+ DEVICE = torch.device("cpu") # HF Spaces free tier = CPU only
22
+
23
+
24
+ # ── Model Architecture (must be identical to your notebook Cell 64) ───────────
25
+ class GenreCNN(nn.Module):
26
+ def __init__(self, n_classes=N_CLASSES):
27
+ super().__init__()
28
+
29
+ self.conv1 = nn.Sequential(
30
+ nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32),
31
+ nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
32
+ )
33
+ self.conv2 = nn.Sequential(
34
+ nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64),
35
+ nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
36
+ )
37
+ self.conv3 = nn.Sequential(
38
+ nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128),
39
+ nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
40
+ )
41
+ self.conv4 = nn.Sequential(
42
+ nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256),
43
+ nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
44
+ )
45
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
46
+ self.classifier = nn.Sequential(
47
+ nn.Flatten(),
48
+ nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.5),
49
+ nn.Linear(128, n_classes)
50
+ )
51
+
52
+ def forward(self, x):
53
+ x = self.conv1(x)
54
+ x = self.conv2(x)
55
+ x = self.conv3(x)
56
+ x = self.conv4(x)
57
+ x = self.pool(x)
58
+ return self.classifier(x)
59
+
60
+
61
+ # ── Load model once at startup ────────────────────────────────────────────────
62
+ model = GenreCNN().to(DEVICE)
63
+ model.load_state_dict(
64
+ torch.load("best_model.pth", map_location=DEVICE)
65
+ )
66
+ model.eval()
67
+ print("βœ“ Model loaded successfully")
68
+
69
+
70
+ # ── Audio preprocessing helpers ───────────────────────────────────────────────
71
+ def get_mel_spectrogram(y):
72
+ """Convert raw audio array β†’ normalised (128, 512) mel spectrogram tensor."""
73
+ mel = librosa.feature.melspectrogram(
74
+ y=y, sr=SR, n_mels=N_MELS,
75
+ hop_length=HOP_LENGTH, n_fft=N_FFT
76
+ )
77
+ mel_db = librosa.power_to_db(mel, ref=np.max)
78
+
79
+ # normalise to 0-1
80
+ mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-6)
81
+
82
+ # pad or crop to exactly 512 time frames
83
+ if mel_db.shape[1] >= 512:
84
+ mel_db = mel_db[:, :512]
85
+ else:
86
+ mel_db = np.pad(mel_db, ((0, 0), (0, 512 - mel_db.shape[1])))
87
+
88
+ return mel_db.astype(np.float32) # (128, 512)
89
+
90
+
91
+ def extract_3_crops(y):
92
+ """Return [start, center, end] crops of exactly N_SAMPLES each."""
93
+ total = len(y)
94
+
95
+ if total <= N_SAMPLES:
96
+ y = np.pad(y, (0, N_SAMPLES - total))
97
+ return [y, y, y]
98
+
99
+ start = y[:N_SAMPLES]
100
+ mid_s = (total - N_SAMPLES) // 2
101
+ center = y[mid_s : mid_s + N_SAMPLES]
102
+ end = y[total - N_SAMPLES:]
103
+ return [start, center, end]
104
+
105
+
106
+ # ── Main prediction function ──────────────────────────────────────────────────
107
+ def predict_genre(audio_path):
108
+ """
109
+ Takes an audio file path from Gradio,
110
+ returns a dict of {genre: probability} for the label display.
111
+ """
112
+ if audio_path is None:
113
+ return {}
114
+
115
+ # 1. Load audio (librosa handles mp3, wav, flac, ogg, etc.)
116
+ try:
117
+ y, _ = librosa.load(audio_path, sr=SR, mono=True)
118
+ except Exception as e:
119
+ return {f"Error loading audio: {str(e)}": 1.0}
120
+
121
+ # 2. Extract 3 crops for test-time augmentation
122
+ crops = extract_3_crops(y)
123
+
124
+ # 3. Convert each crop to a mel spectrogram tensor
125
+ # Shape per crop: (1, 1, 128, 512) ← batch=1, channel=1, H, W
126
+ mel_tensors = [
127
+ torch.tensor(get_mel_spectrogram(crop)).unsqueeze(0).unsqueeze(0)
128
+ for crop in crops
129
+ ]
130
+
131
+ # 4. Run all 3 crops through the model, average probabilities (TTA)
132
+ probs = torch.zeros(1, N_CLASSES)
133
+
134
+ with torch.no_grad():
135
+ for mel in mel_tensors:
136
+ logits = model(mel.to(DEVICE))
137
+ probs += torch.softmax(logits, dim=1).cpu()
138
+
139
+ probs /= 3.0 # average over 3 crops
140
+
141
+ # 5. Build output dict for Gradio's Label component
142
+ prob_np = probs.squeeze().numpy()
143
+ return {GENRES[i]: float(prob_np[i]) for i in range(N_CLASSES)}
144
+
145
+
146
+ # ── Gradio UI ──────────────���──────────────────────────────────────────────────
147
+ demo = gr.Interface(
148
+ fn=predict_genre,
149
+
150
+ inputs=gr.Audio(
151
+ type="filepath",
152
+ label="Upload an audio file (wav, mp3, flac, ogg β€” any genre)"
153
+ ),
154
+
155
+ outputs=gr.Label(
156
+ num_top_classes=10,
157
+ label="Genre probabilities"
158
+ ),
159
+
160
+ title="🎡 Music Genre Classifier",
161
+ description=(
162
+ "Upload any audio clip and the model will predict its genre.\n\n"
163
+ "**Genres:** Blues Β· Classical Β· Country Β· Disco Β· Hip-hop Β· "
164
+ "Jazz Β· Metal Β· Pop Β· Reggae Β· Rock\n\n"
165
+ "**How it works:** The audio is converted to a mel spectrogram "
166
+ "(a 2D image of frequency vs time), then passed through a CNN trained "
167
+ "on 27,000 synthetic audio mashups. Three time-crop predictions are "
168
+ "averaged for more reliable results."
169
+ ),
170
+
171
+ examples=[], # add example file paths here if you upload samples
172
+
173
+ allow_flagging="never",
174
+ )
175
+
176
+ if __name__ == "__main__":
177
+ demo.launch()