| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
class ResBlock1D(nn.Module):
    """
    Residual Block for extracting rhythmic features from audio spectrograms.
    Maintains temporal resolution while increasing receptive field.

    Two Conv1d -> BatchNorm1d stages (GELU after the first) are followed by a
    skip connection and a final GELU.

    Args:
        channels: Number of input and output channels. Both convolutions keep
            the channel count so the input can be added back unchanged.
        kernel_size: Kernel width used by both convolutions.
        dilation: Dilation used by both convolutions.

    Raises:
        ValueError: If ``(kernel_size - 1) * dilation`` is odd. Symmetric
            padding cannot preserve the sequence length in that case, so the
            residual addition in ``forward`` would see mismatched shapes.
    """

    def __init__(self, channels, kernel_size=3, dilation=1):
        super().__init__()
        # A dilated kernel spans (kernel_size - 1) * dilation extra frames;
        # half of that on each side keeps the output as long as the input.
        span = (kernel_size - 1) * dilation
        if span % 2 != 0:
            raise ValueError(
                "(kernel_size - 1) * dilation must be even to preserve the "
                f"sequence length, got kernel_size={kernel_size}, "
                f"dilation={dilation}"
            )
        padding = span // 2
        self.conv1 = nn.Conv1d(
            channels, channels, kernel_size, padding=padding, dilation=dilation
        )
        self.bn1 = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(
            channels, channels, kernel_size, padding=padding, dilation=dilation
        )
        self.bn2 = nn.BatchNorm1d(channels)

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        res = x
        x = F.gelu(self.bn1(self.conv1(x)))
        # No activation before the skip connection; GELU is applied to the sum.
        x = self.bn2(self.conv2(x))
        return F.gelu(x + res)
|
|
|
|
class GameChartEvaluator(nn.Module, PyTorchModelHubMixin):
    """
    Scores a chart against its music from a pair of spectrogram-like inputs.

    The two inputs are concatenated on the channel axis, projected to
    ``d_model`` channels, passed through a stack of dilated residual blocks,
    and mapped to one sigmoid score per timestep. ``forward`` pools those
    per-timestep scores into a single score per batch item; ``predict_trace``
    returns the per-timestep scores themselves.

    Args:
        input_dim: Channel count of EACH input (music and chart).
        d_model: Channel width used inside the encoder.
        n_layers: Number of residual blocks. Block ``i`` uses dilation
            ``2 ** i``, so the default of 4 gives dilations 1, 2, 4, 8.
    """

    def __init__(self, input_dim=80, d_model=128, n_layers=4):
        super().__init__()

        # Music and chart are stacked on the channel axis in the forward
        # pass, hence 2 * input_dim input channels. kernel_size=3 with
        # padding=1 keeps the time dimension unchanged.
        self.input_proj = nn.Conv1d(
            input_dim * 2, d_model, kernel_size=3, stride=1, padding=1
        )

        # Dilation doubles with depth. With n_layers=4 this yields the same
        # modules and state_dict keys (encoder.0 .. encoder.3) as the
        # previous hard-coded stack.
        # NOTE(review): n_layers used to be ignored; a checkpoint whose saved
        # config carries n_layers != 4 was still built with 4 blocks and will
        # need n_layers=4 passed explicitly to load.
        self.encoder = nn.Sequential(
            *(
                ResBlock1D(d_model, kernel_size=3, dilation=2 ** i)
                for i in range(n_layers)
            )
        )

        # Per-timestep projection from d_model features to one logit.
        self.quality_proj = nn.Linear(d_model, 1)

        # Unconstrained scalar; sigmoid(raw_severity) is the weight given to
        # the worst-segment score in forward(). 0.0 -> weight 0.5.
        self.raw_severity = nn.Parameter(torch.tensor(0.0))

    def _local_scores(self, music_mels, chart_mels):
        """Return per-timestep scores in (0, 1) with shape (Batch, Time, 1)."""
        x = torch.cat([music_mels, chart_mels], dim=1)
        x = F.gelu(self.input_proj(x))
        x = self.encoder(x)
        # (Batch, d_model, Time) -> (Batch, Time, d_model) so the Linear
        # layer acts on the feature axis.
        x = x.permute(0, 2, 1)
        return torch.sigmoid(self.quality_proj(x))

    def forward(self, music_mels, chart_mels):
        """
        music_mels: (Batch, input_dim, Time)
        chart_mels: (Batch, input_dim, Time)

        Returns:
            (Batch,) tensor: a learned blend of the mean per-timestep score
            and the mean of the lowest-scoring 10% of timesteps.
        """
        local_scores = self._local_scores(music_mels, chart_mels)

        avg_score = local_scores.mean(dim=1)

        # Mean of the worst 10% of timesteps (at least one timestep).
        k = max(1, int(local_scores.size(1) * 0.1))
        min_vals, _ = torch.topk(local_scores, k, dim=1, largest=False)
        worst_score = min_vals.mean(dim=1)

        # alpha in (0, 1) trades off worst-segment score against the average.
        alpha = torch.sigmoid(self.raw_severity)
        final_score = (alpha * worst_score) + ((1 - alpha) * avg_score)

        return final_score.squeeze(1)

    def predict_trace(self, music_mels, chart_mels):
        """
        Explainability Method: Returns the per-timestep quality curve.

        Runs under ``torch.no_grad()``. This does not change the module's
        train/eval mode: the BatchNorm layers inside the encoder still use
        batch statistics and update their running stats if the module is in
        training mode, so call ``.eval()`` first for inference.

        Returns:
            local_scores: (Batch, Time) - The quality score at every timestep.
        """
        with torch.no_grad():
            return self._local_scores(music_mels, chart_mels).squeeze(2)
|
|
|
|
if __name__ == "__main__":
    # Smoke test; torchinfo is imported only when the file is run as a script.
    from torchinfo import summary

    model = GameChartEvaluator()
    print(f"Model initialized. Learnable Severity: {torch.sigmoid(model.raw_severity).item():.2f}")

    # Random stand-in inputs: batch of 2, 80 channels, 1000 timesteps each.
    dummy_shape = (2, 80, 1000)
    music = torch.randn(dummy_shape)
    chart = torch.randn(dummy_shape)

    # One pooled score per batch item.
    output = model(music, chart)
    print(f"Output shape: {output.shape}")
    print(f"Scores: {output}")

    # Per-timestep scores for the same inputs.
    trace = model.predict_trace(music, chart)
    print(f"Trace shape: {trace.shape}")

    summary(model, input_data=[music, chart])
|
|