shubhamrazzsharma committed
Commit beecf40 · verified · 1 Parent(s): 32dc4a1

Update app.py

Files changed (1):
  1. app.py +143 -106
app.py CHANGED
@@ -1,140 +1,177 @@
  import gradio as gr
  import torch
  import torch.nn as nn
- import librosa
  import numpy as np

- # ─────────────────────────────────────────────
- # 1. PASTE YOUR CNN ARCHITECTURE HERE
- #    (copy the class definition from your Kaggle notebook)
- # ─────────────────────────────────────────────
- class CNNModel(nn.Module):
-     def __init__(self, num_classes=10):
-         super(CNNModel, self).__init__()
-         # ⬇⬇ REPLACE THIS BLOCK WITH YOUR ACTUAL ARCHITECTURE ⬇⬇
          self.conv1 = nn.Sequential(
-             nn.Conv2d(1, 32, kernel_size=3, padding=1),
-             nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2)
          )
          self.conv2 = nn.Sequential(
-             nn.Conv2d(32, 64, kernel_size=3, padding=1),
-             nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2)
          )
          self.conv3 = nn.Sequential(
-             nn.Conv2d(64, 128, kernel_size=3, padding=1),
-             nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2)
          )
-         self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
          self.classifier = nn.Sequential(
              nn.Flatten(),
-             nn.Linear(128, 256), nn.ReLU(), nn.Dropout(0.3),
-             nn.Linear(256, num_classes)
          )
-         # ⬆⬆ REPLACE UP TO HERE ⬆⬆

      def forward(self, x):
          x = self.conv1(x)
          x = self.conv2(x)
          x = self.conv3(x)
-         x = self.global_avg_pool(x)
          return self.classifier(x)

- # ─────────────────────────────────────────────
- # 2. CONFIG - change these if needed
- # ─────────────────────────────────────────────
- NUM_CLASSES = 10
- SAMPLE_RATE = 22050
- N_MELS = 128
- N_FFT = 2048
- HOP_LENGTH = 512
- DURATION = 30              # seconds of audio to use
- TARGET_SHAPE = (128, 512)  # must match your training shape
-
- GENRES = [
-     "blues", "classical", "country", "disco", "hiphop",
-     "jazz", "metal", "pop", "reggae", "rock"
- ]
-
- # ─────────────────────────────────────────────
- # 3. LOAD MODEL (runs once at startup)
- # ─────────────────────────────────────────────
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- model = CNNModel(num_classes=NUM_CLASSES)
  model.load_state_dict(
-     torch.load("best_model (1).pth", map_location=device)
  )
- model.to(device)
  model.eval()

- # ─────────────────────────────────────────────
- # 4. PREPROCESSING - same pipeline as training
- # ─────────────────────────────────────────────
- def audio_to_melspectrogram(audio_path):
-     y, sr = librosa.load(audio_path, sr=SAMPLE_RATE, duration=DURATION, mono=True)
-
-     # Pad if clip is shorter than DURATION
-     target_length = SAMPLE_RATE * DURATION
-     if len(y) < target_length:
-         y = np.pad(y, (0, target_length - len(y)))
-
-     mel = librosa.feature.melspectrogram(
-         y=y, sr=sr, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH
-     )
      mel_db = librosa.power_to_db(mel, ref=np.max)
-
-     # Resize to training shape (128, 512)
-     if mel_db.shape != TARGET_SHAPE:
-         from PIL import Image
-         import PIL
-         mel_img = Image.fromarray(mel_db)
-         mel_img = mel_img.resize((TARGET_SHAPE[1], TARGET_SHAPE[0]), PIL.Image.BILINEAR)
-         mel_db = np.array(mel_img)
-
-     # Normalize to [0, 1]
      mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-6)
-     return mel_db

- # ─────────────────────────────────────────────
- # 5. INFERENCE
- # ─────────────────────────────────────────────
  def predict_genre(audio_path):
      if audio_path is None:
          return {}
-
      try:
-         mel = audio_to_melspectrogram(audio_path)             # (128, 512)
-         tensor = torch.tensor(mel, dtype=torch.float32)
-         tensor = tensor.unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, 128, 512)
-
-         with torch.no_grad():
-             logits = model(tensor)
-             probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()
-
-         return {GENRES[i]: float(probs[i]) for i in range(NUM_CLASSES)}
-
      except Exception as e:
-         return {"error": str(e)}
-
- # ─────────────────────────────────────────────
- # 6. GRADIO UI
- # ─────────────────────────────────────────────
- with gr.Blocks(title="Music Genre Classifier") as demo:
-     gr.Markdown("## 🎡 Music Genre Classifier")
-     gr.Markdown("Upload a song clip and the model will predict its genre.")
-
-     with gr.Row():
-         audio_input = gr.Audio(type="filepath", label="Upload Audio (.wav / .mp3)")
-
-     predict_btn = gr.Button("Predict Genre", variant="primary")
-
-     output = gr.Label(num_top_classes=5, label="Genre Probabilities")
-
-     predict_btn.click(fn=predict_genre, inputs=audio_input, outputs=output)
-
-     gr.Examples(
-         examples=[],  # optionally add example audio file paths here
-         inputs=audio_input
-     )
-
- demo.launch()

  import gradio as gr
  import torch
  import torch.nn as nn
  import numpy as np
+ import librosa
+
+ # ── Constants (must match your notebook exactly) ──────────────────────────────
+ SR = 22050
+ DURATION = 30
+ N_SAMPLES = SR * DURATION  # 661500
+ N_MELS = 128
+ HOP_LENGTH = 512
+ N_FFT = 2048
+ N_CLASSES = 10
+
+ GENRES = sorted([
+     "blues", "classical", "country", "disco", "hiphop",
+     "jazz", "metal", "pop", "reggae", "rock"
+ ])
+
+ DEVICE = torch.device("cpu")  # HF Spaces free tier = CPU only
+
+
+ # ── Model Architecture (must be identical to your notebook Cell 64) ───────────
+ class GenreCNN(nn.Module):
+     def __init__(self, n_classes=N_CLASSES):
+         super().__init__()
          self.conv1 = nn.Sequential(
+             nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32),
+             nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
          )
          self.conv2 = nn.Sequential(
+             nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64),
+             nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
          )
          self.conv3 = nn.Sequential(
+             nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128),
+             nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
+         )
+         self.conv4 = nn.Sequential(
+             nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256),
+             nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
          )
+         self.pool = nn.AdaptiveAvgPool2d((1, 1))
          self.classifier = nn.Sequential(
              nn.Flatten(),
+             nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.5),
+             nn.Linear(128, n_classes)
          )

      def forward(self, x):
          x = self.conv1(x)
          x = self.conv2(x)
          x = self.conv3(x)
+         x = self.conv4(x)
+         x = self.pool(x)
          return self.classifier(x)


+ # ── Load model once at startup ────────────────────────────────────────────────
+ model = GenreCNN().to(DEVICE)
  model.load_state_dict(
+     torch.load("best_model (1).pth", map_location=DEVICE)
  )
  model.eval()
+ print("✓ Model loaded successfully")


+ # ── Audio preprocessing helpers ───────────────────────────────────────────────
+ def get_mel_spectrogram(y):
+     """Convert a raw audio array → normalised (128, 512) mel spectrogram array."""
+     mel = librosa.feature.melspectrogram(
+         y=y, sr=SR, n_mels=N_MELS,
+         hop_length=HOP_LENGTH, n_fft=N_FFT
+     )
      mel_db = librosa.power_to_db(mel, ref=np.max)
+
+     # normalise to 0-1
      mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-6)

+     # pad or crop to exactly 512 time frames
+     if mel_db.shape[1] >= 512:
+         mel_db = mel_db[:, :512]
+     else:
+         mel_db = np.pad(mel_db, ((0, 0), (0, 512 - mel_db.shape[1])))
+
+     return mel_db.astype(np.float32)  # (128, 512)
+
+
+ def extract_3_crops(y):
+     """Return [start, center, end] crops of exactly N_SAMPLES each."""
+     total = len(y)
+
+     if total <= N_SAMPLES:
+         y = np.pad(y, (0, N_SAMPLES - total))
+         return [y, y, y]
+
+     start = y[:N_SAMPLES]
+     mid_s = (total - N_SAMPLES) // 2
+     center = y[mid_s : mid_s + N_SAMPLES]
+     end = y[total - N_SAMPLES:]
+     return [start, center, end]
+
+
+ # ── Main prediction function ──────────────────────────────────────────────────
  def predict_genre(audio_path):
+     """
+     Takes an audio file path from Gradio and returns a dict of
+     {genre: probability} for the label display.
+     """
      if audio_path is None:
          return {}
+
+     # 1. Load audio (librosa handles mp3, wav, flac, ogg, etc.)
      try:
+         y, _ = librosa.load(audio_path, sr=SR, mono=True)
      except Exception as e:
+         return {f"Error loading audio: {str(e)}": 1.0}
+
+     # 2. Extract 3 crops for test-time augmentation
+     crops = extract_3_crops(y)
+
+     # 3. Convert each crop to a mel spectrogram tensor
+     #    Shape per crop: (1, 1, 128, 512) ← batch=1, channel=1, H, W
+     mel_tensors = [
+         torch.tensor(get_mel_spectrogram(crop)).unsqueeze(0).unsqueeze(0)
+         for crop in crops
+     ]
+
+     # 4. Run all 3 crops through the model, average probabilities (TTA)
+     probs = torch.zeros(1, N_CLASSES)
+
+     with torch.no_grad():
+         for mel in mel_tensors:
+             logits = model(mel.to(DEVICE))
+             probs += torch.softmax(logits, dim=1).cpu()
+
+     probs /= 3.0  # average over 3 crops
+
+     # 5. Build output dict for Gradio's Label component
+     prob_np = probs.squeeze().numpy()
+     return {GENRES[i]: float(prob_np[i]) for i in range(N_CLASSES)}
+
+
+ # ── Gradio UI ─────────────────────────────────────────────────────────────────
+ demo = gr.Interface(
+     fn=predict_genre,
+     inputs=gr.Audio(
+         type="filepath",
+         label="Upload an audio file (wav, mp3, flac, ogg - any genre)"
+     ),
+     outputs=gr.Label(
+         num_top_classes=10,
+         label="Genre probabilities"
+     ),
+     title="🎡 Music Genre Classifier",
+     description=(
+         "Upload any audio clip and the model will predict its genre.\n\n"
+         "**Genres:** Blues · Classical · Country · Disco · Hip-hop · "
+         "Jazz · Metal · Pop · Reggae · Rock\n\n"
+         "**How it works:** The audio is converted to a mel spectrogram "
+         "(a 2D image of frequency vs time), then passed through a CNN trained "
+         "on 27,000 synthetic audio mashups. Three time-crop predictions are "
+         "averaged for more reliable results."
+     ),
+     examples=[],  # add example file paths here if you upload samples
+     allow_flagging="never",
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
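
One property of the new pipeline worth keeping in mind: at SR = 22050 and HOP_LENGTH = 512, a 30 s crop yields about 1292 mel frames, so the hard crop to 512 frames in get_mel_spectrogram keeps only the first ~12 s of each 30 s crop. The sketch below is a standalone check of those numbers; it re-declares the constants instead of importing app.py (importing would also run the model load at startup):

    import numpy as np
    import librosa

    # Constants copied from the new app.py
    SR, DURATION = 22050, 30
    N_SAMPLES = SR * DURATION                  # 661500 samples per crop
    N_MELS, HOP_LENGTH, N_FFT = 128, 512, 2048

    # A synthetic 30 s noise signal stands in for one real crop.
    y = np.random.default_rng(0).standard_normal(N_SAMPLES).astype(np.float32)

    mel = librosa.feature.melspectrogram(
        y=y, sr=SR, n_mels=N_MELS, hop_length=HOP_LENGTH, n_fft=N_FFT
    )
    print(mel.shape)                           # (128, 1292) with librosa's default centred framing

    # get_mel_spectrogram() keeps only the first 512 of those frames,
    # i.e. roughly 512 * HOP_LENGTH / SR seconds of each crop:
    print(round(512 * HOP_LENGTH / SR, 1))     # 11.9

Since the removed code resized the full 30 s spectrogram to (128, 512) rather than cropping it, this preprocessing only matches training if the notebook used the same 512-frame crop.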