import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import librosa
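
# Deployment note (assumption, not from the original notebook): on HF Spaces,
# requirements.txt should list gradio, torch, librosa, and numpy; for mp3
# input, librosa also needs a decoding backend (soundfile or audioread+ffmpeg).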

# ── Constants (must match your notebook exactly) ──────────────────────────────
SR          = 22050
DURATION    = 30
N_SAMPLES   = SR * DURATION      # 661500
N_MELS      = 128
HOP_LENGTH  = 512
N_FFT       = 2048
N_CLASSES   = 10

GENRES = sorted([
    "blues", "classical", "country", "disco", "hiphop",
    "jazz", "metal", "pop", "reggae", "rock"
])

DEVICE = torch.device("cpu")     # HF Spaces free tier = CPU only
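
# For reference: at SR=22050 and HOP_LENGTH=512, a 30 s crop produces about
# 1 + N_SAMPLES // HOP_LENGTH = 1292 mel frames; get_mel_spectrogram() below
# keeps only the first 512 frames, i.e. roughly the first 11.9 s of each crop.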


# ── Model Architecture (must be identical to your notebook Cell 64) ───────────
class GenreCNN(nn.Module):
    def __init__(self, n_classes=N_CLASSES):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32),
            nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64),
            nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128),
            nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256),
            nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)
        )
        self.pool       = nn.AdaptiveAvgPool2d((4, 4))   # 256*4*4 = 4096
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(512, 128),  nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(128, n_classes)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pool(x)
        return self.classifier(x)
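
# Shape trace for a (1, 1, 128, 512) input: each conv block halves H and W,
# giving (1, 256, 8, 32) after conv4; AdaptiveAvgPool2d((4, 4)) then fixes
# the flattened classifier input at 256 * 4 * 4 = 4096 for any crop length.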


# ── Load model once at startup ────────────────────────────────────────────────
model = GenreCNN().to(DEVICE)
model.load_state_dict(
    torch.load("best_model.pth", map_location=DEVICE)
)
model.eval()
print("βœ“ Model loaded successfully")


# ── Audio preprocessing helpers ───────────────────────────────────────────────
def get_mel_spectrogram(y):
    """Convert raw audio array β†’ normalised (128, 512) mel spectrogram tensor."""
    mel    = librosa.feature.melspectrogram(
                y=y, sr=SR, n_mels=N_MELS,
                hop_length=HOP_LENGTH, n_fft=N_FFT
             )
    mel_db = librosa.power_to_db(mel, ref=np.max)

    # normalise to 0-1
    mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-6)

    # pad or crop to exactly 512 time frames
    if mel_db.shape[1] >= 512:
        mel_db = mel_db[:, :512]
    else:
        mel_db = np.pad(mel_db, ((0, 0), (0, 512 - mel_db.shape[1])))

    return mel_db.astype(np.float32)   # (128, 512)
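
# Example (hypothetical input): even a 1 s noise burst yields a full-width
# spectrogram thanks to the zero-padding above:
#   get_mel_spectrogram(np.random.randn(SR).astype(np.float32)).shape == (128, 512)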


def extract_3_crops(y):
    """Return [start, center, end] crops of exactly N_SAMPLES each."""
    total = len(y)

    if total <= N_SAMPLES:
        y = np.pad(y, (0, N_SAMPLES - total))
        return [y, y, y]

    start  = y[:N_SAMPLES]
    mid_s  = (total - N_SAMPLES) // 2
    center = y[mid_s : mid_s + N_SAMPLES]
    end    = y[total - N_SAMPLES:]
    return [start, center, end]
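
# Example: for a 60 s clip (2 * N_SAMPLES samples), the crops cover
# samples [0 s, 30 s), [15 s, 45 s) and [30 s, 60 s) respectively.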


# ── Main prediction function ──────────────────────────────────────────────────
def predict_genre(audio_path):
    """
    Takes an audio file path from Gradio,
    returns a dict of {genre: probability} for the label display.
    """
    if audio_path is None:
        return {}

    # 1. Load audio (librosa handles mp3, wav, flac, ogg, etc.)
    try:
        y, _ = librosa.load(audio_path, sr=SR, mono=True)
    except Exception as e:
        return {f"Error loading audio: {str(e)}": 1.0}

    # 2. Extract 3 crops for test-time augmentation
    crops = extract_3_crops(y)

    # 3. Convert each crop to a mel spectrogram tensor
    #    Shape per crop: (1, 1, 128, 512)  ← batch=1, channel=1, H, W
    mel_tensors = [
        torch.tensor(get_mel_spectrogram(crop)).unsqueeze(0).unsqueeze(0)
        for crop in crops
    ]

    # 4. Run all 3 crops through the model, average probabilities (TTA)
    probs = torch.zeros(1, N_CLASSES)

    with torch.no_grad():
        for mel in mel_tensors:
            logits  = model(mel.to(DEVICE))
            probs  += torch.softmax(logits, dim=1).cpu()

    probs /= 3.0   # average over 3 crops

    # 5. Build output dict for Gradio's Label component
    prob_np = probs.squeeze().numpy()
    return {GENRES[i]: float(prob_np[i]) for i in range(N_CLASSES)}
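
# Quick local check without the UI (hypothetical file path):
#   print(predict_genre("samples/example.wav"))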


# ── Gradio UI ─────────────────────────────────────────────────────────────────
demo = gr.Interface(
    fn=predict_genre,

    inputs=gr.Audio(
        type="filepath",
        label="Upload an audio file (wav, mp3, flac, ogg β€” any genre)"
    ),

    outputs=gr.Label(
        num_top_classes=10,
        label="Genre probabilities"
    ),

    title="🎡 Music Genre Classifier",
    description=(
        "Upload any audio clip and the model will predict its genre.\n\n"
        "**Genres:** Blues Β· Classical Β· Country Β· Disco Β· Hip-hop Β· "
        "Jazz Β· Metal Β· Pop Β· Reggae Β· Rock\n\n"
        "**How it works:** The audio is converted to a mel spectrogram "
        "(a 2D image of frequency vs time), then passed through a CNN trained "
        "on 27,000 synthetic audio mashups. Three time-crop predictions are "
        "averaged for more reliable results."
    ),

    examples=[],   # add example file paths here if you upload samples

)

if __name__ == "__main__":
    demo.launch()