shubhamrazzsharma committed on
Commit
3425304
·
verified ·
1 Parent(s): 978a35c

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -177
app.py DELETED
@@ -1,177 +0,0 @@
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import librosa

# ── Constants (must match your notebook exactly) ──────────────────────────────
SR = 22050                 # sample rate (Hz) audio is loaded/resampled to
DURATION = 30              # crop length in seconds
N_SAMPLES = SR * DURATION  # 661500 — samples per 30 s crop
N_MELS = 128               # mel bands → spectrogram height
HOP_LENGTH = 512           # STFT hop size in samples
N_FFT = 2048               # STFT window size in samples
N_CLASSES = 10             # number of output genres

# GTZAN-style genre labels; sorted() pins the alphabetical order so the
# index of each label matches the class indices used at training time.
GENRES = sorted([
    "blues", "classical", "country", "disco", "hiphop",
    "jazz", "metal", "pop", "reggae", "rock"
])

DEVICE = torch.device("cpu")  # HF Spaces free tier = CPU only
23
-
24
# ── Model Architecture (must be identical to your notebook Cell 64) ───────────
class GenreCNN(nn.Module):
    """CNN genre classifier over single-channel mel-spectrogram images.

    Four convolutional stages double the channel width (1→32→64→128→256),
    each halving the spatial resolution; global average pooling then feeds
    a small fully-connected head that emits one logit per genre.
    """

    @staticmethod
    def _stage(in_ch, out_ch):
        """One stage: Conv 3x3 (same pad) → BN → ReLU → MaxPool 2x2 → Dropout2d."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.25),
        )

    def __init__(self, n_classes=N_CLASSES):
        super().__init__()
        # Attribute names (conv1..conv4, pool, classifier) must stay exactly
        # as in the notebook so the saved state_dict keys still line up.
        self.conv1 = self._stage(1, 32)
        self.conv2 = self._stage(32, 64)
        self.conv3 = self._stage(64, 128)
        self.conv4 = self._stage(128, 256)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, n_classes),
        )

    def forward(self, x):
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        return self.classifier(self.pool(x))
59
-
60
-
61
# ── Load model once at startup ────────────────────────────────────────────────
# Instantiated at import time so every Gradio request reuses the same weights
# instead of reloading the checkpoint from disk.
model = GenreCNN().to(DEVICE)
model.load_state_dict(
    # weights_only=True restricts unpickling to tensors/containers — safer
    # against malicious checkpoints and the default from torch 2.6 onward.
    # The checkpoint here is a plain state_dict, so this is fully compatible.
    torch.load("best_model.pth", map_location=DEVICE, weights_only=True)
)
model.eval()  # inference mode: disable dropout, use BatchNorm running stats
print("✓ Model loaded successfully")
68
-
69
-
70
# ── Audio preprocessing helpers ───────────────────────────────────────────────
def get_mel_spectrogram(y, n_frames=512):
    """Convert a raw audio array to a normalised (N_MELS, n_frames) mel spectrogram.

    Parameters
    ----------
    y : np.ndarray
        1-D mono audio signal sampled at SR.
    n_frames : int, optional
        Width of the output along the time axis (default 512 — the width
        the CNN was trained on). Parameterized so other crop durations can
        reuse this helper.

    Returns
    -------
    np.ndarray
        float32 array of shape (N_MELS, n_frames), values in [0, 1].
    """
    mel = librosa.feature.melspectrogram(
        y=y, sr=SR, n_mels=N_MELS,
        hop_length=HOP_LENGTH, n_fft=N_FFT,
    )
    # Log scale relative to the loudest bin, then min-max normalise to 0-1;
    # the 1e-6 guards against division by zero on constant (e.g. silent) input.
    mel_db = librosa.power_to_db(mel, ref=np.max)
    mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-6)

    # Crop, or zero-pad on the right, so the time axis is exactly n_frames.
    if mel_db.shape[1] >= n_frames:
        mel_db = mel_db[:, :n_frames]
    else:
        mel_db = np.pad(mel_db, ((0, 0), (0, n_frames - mel_db.shape[1])))

    return mel_db.astype(np.float32)  # (N_MELS, n_frames)
89
-
90
-
91
def extract_3_crops(y, n_samples=None):
    """Return [start, center, end] crops of exactly *n_samples* samples each.

    Parameters
    ----------
    y : np.ndarray
        1-D audio signal.
    n_samples : int, optional
        Crop length in samples. Defaults to the module-level N_SAMPLES
        (30 s at SR); parameterized so other durations can reuse this helper.

    Returns
    -------
    list[np.ndarray]
        Three crops. If the input is not longer than the crop length it is
        zero-padded on the right once and that same array is returned three
        times (all crops coincide).
    """
    if n_samples is None:
        n_samples = N_SAMPLES  # module constant: SR * DURATION

    total = len(y)
    if total <= n_samples:
        # Too short for distinct crops: pad with silence.
        y = np.pad(y, (0, n_samples - total))
        return [y, y, y]

    mid_start = (total - n_samples) // 2
    return [
        y[:n_samples],                        # start
        y[mid_start:mid_start + n_samples],   # center
        y[total - n_samples:],                # end
    ]
104
-
105
-
106
# ── Main prediction function ──────────────────────────────────────────────────
def predict_genre(audio_path):
    """Classify the audio file at *audio_path*.

    Takes a file path from Gradio and returns a {genre: probability} dict
    for the Label display. Uses test-time augmentation: three 30 s crops
    (start / center / end) are scored independently and their softmax
    probabilities averaged.
    """
    if audio_path is None:
        return {}

    # Decode the file; librosa handles mp3, wav, flac, ogg, etc. and
    # resamples to SR mono.
    try:
        signal, _ = librosa.load(audio_path, sr=SR, mono=True)
    except Exception as err:
        return {f"Error loading audio: {str(err)}": 1.0}

    # One tensor per crop, shaped (1, 1, 128, 512): batch=1, channel=1, H, W.
    crop_batches = [
        torch.tensor(get_mel_spectrogram(crop)).unsqueeze(0).unsqueeze(0)
        for crop in extract_3_crops(signal)
    ]

    # Run all 3 crops through the model and average the softmax outputs (TTA).
    with torch.no_grad():
        avg_probs = sum(
            torch.softmax(model(batch.to(DEVICE)), dim=1).cpu()
            for batch in crop_batches
        ) / 3.0

    # Build the output dict for Gradio's Label component.
    scores = avg_probs.squeeze().numpy()
    return {genre: float(scores[i]) for i, genre in enumerate(GENRES)}
144
-
145
-
146
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Single-function interface: audio file in, genre probabilities out.
demo = gr.Interface(
    fn=predict_genre,

    inputs=gr.Audio(
        type="filepath",  # hand predict_genre a path, not a (sr, array) tuple
        label="Upload an audio file (wav, mp3, flac, ogg — any genre)"
    ),

    outputs=gr.Label(
        num_top_classes=10,  # show all ten genres, not just the top few
        label="Genre probabilities"
    ),

    title="🎡 Music Genre Classifier",
    description=(
        "Upload any audio clip and the model will predict its genre.\n\n"
        "**Genres:** Blues · Classical · Country · Disco · Hip-hop · "
        "Jazz · Metal · Pop · Reggae · Rock\n\n"
        "**How it works:** The audio is converted to a mel spectrogram "
        "(a 2D image of frequency vs time), then passed through a CNN trained "
        "on 27,000 synthetic audio mashups. Three time-crop predictions are "
        "averaged for more reliable results."
    ),

    examples=[],  # add example file paths here if you upload samples

    # NOTE(review): allow_flagging is deprecated in Gradio 4.x in favour of
    # flagging_mode — confirm this Space's Gradio version before changing.
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()