Initial release: TheArtist chord-generation paper companion
Browse files- README.md +98 -0
- best.pt +3 -0
- config.json +25 -0
- eval_results.csv +8 -0
- model.py +294 -0
- tokenizer.json +356 -0
- tokenizer.py +379 -0
README.md
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: other
|
| 3 |
+
library_name: pytorch
|
| 4 |
+
tags:
|
| 5 |
+
- music
|
| 6 |
+
- music-generation
|
| 7 |
+
- chord-generation
|
| 8 |
+
- symbolic-music
|
| 9 |
+
- music-transformer
|
| 10 |
+
- jazz
|
| 11 |
+
- pop
|
| 12 |
+
language:
|
| 13 |
+
- en
|
| 14 |
+
pipeline_tag: text-generation
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
# TheArtist Music Transformer — F2 (Pop 5K Mix)
|
| 18 |
+
|
| 19 |
+
**Jazz-adapted chord model with a 5,000-sequence pop rehearsal buffer. Calibration point that the paper finds is dominated by F3 on every axis.**
|
| 20 |
+
|
| 21 |
+
One of six checkpoints released alongside the paper *Empirical Study of Pop and Jazz Mix Ratios for Genre-Adaptive Chord Generation* (Lee, 2026). See the collection overview at `PearlLeeStudio/TheArtist-MusicTransformer-pop-baseline`.
|
| 22 |
+
|
| 23 |
+
## Model summary
|
| 24 |
+
|
| 25 |
+
| Field | Value |
|
| 26 |
+
|---|---|
|
| 27 |
+
| Architecture | Music Transformer with relative positional attention |
|
| 28 |
+
| Parameters | 25,661,440 |
|
| 29 |
+
| Vocabulary size | 351 tokens |
|
| 30 |
+
| Max sequence length | 256 |
|
| 31 |
+
| d_model / heads / FFN / layers | 512 / 8 / 2048 / 8 |
|
| 32 |
+
| Fine-tune resumed from | Phase 0 pop baseline |
|
| 33 |
+
| Best epoch | 4 |
|
| 34 |
+
|
| 35 |
+
## Training data
|
| 36 |
+
|
| 37 |
+
All 1,513 jazz training sequences plus 5,000 pop rehearsal sequences (seed 42). Pop:jazz ≈ 3.3:1.
|
| 38 |
+
|
| 39 |
+
## Evaluation (held-out per-genre test sets)
|
| 40 |
+
|
| 41 |
+
| Metric | Pop test | Jazz test |
|
| 42 |
+
|---|---:|---:|
|
| 43 |
+
| Top-1 accuracy | 84.07% | 79.90% |
|
| 44 |
+
| Top-5 accuracy | 97.04% | 92.14% |
|
| 45 |
+
| Perplexity | 1.75 | 2.33 |
|
| 46 |
+
| Δ vs. Phase 0 baseline | −0.17 | +7.04 |
|
| 47 |
+
|
| 48 |
+
F2 is dominated by F3 on every axis. It is released for reproducibility of the saturation curve described in the paper (see paper §6.1, §7.3) but is not the recommended choice for any operating point. Prefer F3 for the balanced setting, F1 for pop-leaning, or F4 for jazz-leaning.
|
| 49 |
+
|
| 50 |
+
## Intended use
|
| 51 |
+
|
| 52 |
+
Reference checkpoint for replication and saturation-curve analysis. Not recommended as a default for chord-composition workflows.
|
| 53 |
+
|
| 54 |
+
## Usage
|
| 55 |
+
|
| 56 |
+
```python
|
| 57 |
+
import torch
|
| 58 |
+
from huggingface_hub import hf_hub_download
|
| 59 |
+
from model import MusicTransformer
|
| 60 |
+
from tokenizer import ChordTokenizer
|
| 61 |
+
|
| 62 |
+
ckpt_path = hf_hub_download(
|
| 63 |
+
repo_id="PearlLeeStudio/TheArtist-MusicTransformer-ft-pop67",
|
| 64 |
+
filename="best.pt",
|
| 65 |
+
)
|
| 66 |
+
tokenizer = ChordTokenizer()
|
| 67 |
+
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
| 68 |
+
|
| 69 |
+
model = MusicTransformer(
|
| 70 |
+
vocab_size=tokenizer.vocab_size,
|
| 71 |
+
d_model=512, n_heads=8, d_ff=2048, n_layers=8,
|
| 72 |
+
max_seq_len=256, dropout=0.0, pad_id=tokenizer.pad_id,
|
| 73 |
+
)
|
| 74 |
+
model.load_state_dict(ckpt["model_state_dict"])
|
| 75 |
+
model.eval()
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
## Training-data licenses
|
| 79 |
+
|
| 80 |
+
| Dataset | License |
|
| 81 |
+
|---|---|
|
| 82 |
+
| Chordonomicon | Public (user-generated) |
|
| 83 |
+
| McGill Billboard | CC0 |
|
| 84 |
+
| Jazz Harmony Treebank | Public |
|
| 85 |
+
| JazzStandards (iReal Pro) | Community redistribution |
|
| 86 |
+
| Weimar Jazz Database | ODbL |
|
| 87 |
+
| JAAH | Research-use public |
|
| 88 |
+
|
| 89 |
+
## Citation
|
| 90 |
+
|
| 91 |
+
```bibtex
|
| 92 |
+
@misc{lee2026chordmix,
|
| 93 |
+
title = {Empirical Study of Pop and Jazz Mix Ratios for Genre-Adaptive Chord Generation},
|
| 94 |
+
author = {Lee, Jinju},
|
| 95 |
+
year = {2026},
|
| 96 |
+
eprint = {arXiv:XXXX.XXXXX}
|
| 97 |
+
}
|
| 98 |
+
```
|
best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f9aebcae5294c331aab43c509af698062f9bc7e50fb06e92280ee93126491d7b
|
| 3 |
+
size 308077642
|
config.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"run_name": "ft_jazz_pop67",
|
| 3 |
+
"resume_from": "checkpoints/phase0_pop_baseline/best.pt",
|
| 4 |
+
"pop_mix_count": 5000,
|
| 5 |
+
"epochs": 10,
|
| 6 |
+
"batch_size": 64,
|
| 7 |
+
"gradient_accumulation_steps": 2,
|
| 8 |
+
"lr": 2e-05,
|
| 9 |
+
"weight_decay": 0.01,
|
| 10 |
+
"warmup_epochs": 2,
|
| 11 |
+
"max_grad_norm": 1.0,
|
| 12 |
+
"d_model": 512,
|
| 13 |
+
"n_heads": 8,
|
| 14 |
+
"d_ff": 2048,
|
| 15 |
+
"n_layers": 8,
|
| 16 |
+
"max_seq_len": 256,
|
| 17 |
+
"dropout": 0.1,
|
| 18 |
+
"use_amp": true,
|
| 19 |
+
"checkpoint_every": 1,
|
| 20 |
+
"patience": 5,
|
| 21 |
+
"num_workers": 4,
|
| 22 |
+
"persistent_workers": true,
|
| 23 |
+
"prefetch_factor": 4,
|
| 24 |
+
"log_every_steps": 200
|
| 25 |
+
}
|
eval_results.csv
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
epoch,lr,train_loss,val_loss,val_ppl,val_top1,val_top5,pop_loss,pop_ppl,pop_top1,pop_top5,jazz_loss,jazz_ppl,jazz_top1,jazz_top5
|
| 2 |
+
3,2.56e-04,0.7450,0.5703,1.77,83.99,96.81,0.5490,1.73,84.21,97.09,1.3893,4.01,72.86,86.51
|
| 3 |
+
4,2.07e-04,0.6459,0.5660,1.76,84.03,96.05,0.5599,1.75,84.06,96.17,0.8482,2.34,79.91,91.46
|
| 4 |
+
5,1.50e-04,0.6020,0.5750,1.78,83.83,96.01,0.5694,1.77,83.87,96.12,0.8305,2.29,80.28,91.86
|
| 5 |
+
6,9.26e-05,0.5770,0.5843,1.79,83.67,95.97,0.5790,1.78,83.69,96.08,0.8298,2.29,80.33,91.73
|
| 6 |
+
7,4.39e-05,0.5587,0.5926,1.81,83.54,95.94,0.5879,1.80,83.54,96.05,0.8339,2.30,80.19,91.77
|
| 7 |
+
8,1.14e-05,0.5471,0.5983,1.82,83.40,96.78,0.5937,1.81,83.41,96.88,0.8365,2.31,80.09,92.58
|
| 8 |
+
9,0.00e+00,0.5410,0.6005,1.82,83.38,96.78,0.5958,1.81,83.39,96.87,0.8374,2.31,80.12,92.62
|
model.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Music Transformer with relative attention for chord generation.
|
| 2 |
+
|
| 3 |
+
Architecture: Transformer decoder (autoregressive) with relative position
|
| 4 |
+
encoding (Shaw et al. 2018, efficient skewing from Huang et al. 2018).
|
| 5 |
+
|
| 6 |
+
Default config (~25M params):
|
| 7 |
+
d_model=512, n_heads=8, d_ff=2048, n_layers=8
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class RelativeMultiHeadAttention(nn.Module):
|
| 20 |
+
"""Multi-head self-attention with relative position bias."""
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
d_model: int,
|
| 25 |
+
n_heads: int,
|
| 26 |
+
max_seq_len: int,
|
| 27 |
+
dropout: float = 0.1,
|
| 28 |
+
) -> None:
|
| 29 |
+
super().__init__()
|
| 30 |
+
assert d_model % n_heads == 0
|
| 31 |
+
self.n_heads = n_heads
|
| 32 |
+
self.d_k = d_model // n_heads
|
| 33 |
+
self.scale = math.sqrt(self.d_k)
|
| 34 |
+
|
| 35 |
+
self.w_q = nn.Linear(d_model, d_model)
|
| 36 |
+
self.w_k = nn.Linear(d_model, d_model)
|
| 37 |
+
self.w_v = nn.Linear(d_model, d_model)
|
| 38 |
+
self.w_o = nn.Linear(d_model, d_model)
|
| 39 |
+
|
| 40 |
+
# Learnable relative position embeddings: positions in [-max_len+1, max_len-1]
|
| 41 |
+
self.max_seq_len = max_seq_len
|
| 42 |
+
self.rel_emb = nn.Embedding(2 * max_seq_len - 1, self.d_k)
|
| 43 |
+
self.dropout = nn.Dropout(dropout)
|
| 44 |
+
|
| 45 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
|
| 46 |
+
"""
|
| 47 |
+
Args:
|
| 48 |
+
x: (B, L, D)
|
| 49 |
+
mask: (L, L) bool — True = masked (don't attend)
|
| 50 |
+
Returns:
|
| 51 |
+
(B, L, D)
|
| 52 |
+
"""
|
| 53 |
+
B, L, _ = x.shape
|
| 54 |
+
H, dk = self.n_heads, self.d_k
|
| 55 |
+
|
| 56 |
+
Q = self.w_q(x).view(B, L, H, dk).transpose(1, 2) # (B, H, L, dk)
|
| 57 |
+
K = self.w_k(x).view(B, L, H, dk).transpose(1, 2)
|
| 58 |
+
V = self.w_v(x).view(B, L, H, dk).transpose(1, 2)
|
| 59 |
+
|
| 60 |
+
# Content attention: Q K^T
|
| 61 |
+
content = torch.matmul(Q, K.transpose(-2, -1)) # (B, H, L, L)
|
| 62 |
+
|
| 63 |
+
# Relative position attention: Q R^T via efficient gather
|
| 64 |
+
rel = self._relative_attention(Q, L) # (B, H, L, L)
|
| 65 |
+
|
| 66 |
+
attn = (content + rel) / self.scale
|
| 67 |
+
|
| 68 |
+
if mask is not None:
|
| 69 |
+
attn = attn.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))
|
| 70 |
+
|
| 71 |
+
attn = self.dropout(F.softmax(attn, dim=-1))
|
| 72 |
+
out = torch.matmul(attn, V) # (B, H, L, dk)
|
| 73 |
+
out = out.transpose(1, 2).contiguous().view(B, L, -1)
|
| 74 |
+
return self.w_o(out)
|
| 75 |
+
|
| 76 |
+
def _relative_attention(self, Q: torch.Tensor, L: int) -> torch.Tensor:
|
| 77 |
+
"""Compute Q @ R^T using relative position embeddings.
|
| 78 |
+
|
| 79 |
+
Uses the index-gather approach: for each (i, j) pair, the relative
|
| 80 |
+
position is j - i, shifted to a non-negative index.
|
| 81 |
+
"""
|
| 82 |
+
device = Q.device
|
| 83 |
+
# Relative position indices: rel[i,j] = j - i + max_seq_len - 1
|
| 84 |
+
positions = torch.arange(L, device=device)
|
| 85 |
+
rel_idx = positions.unsqueeze(0) - positions.unsqueeze(1) + self.max_seq_len - 1
|
| 86 |
+
rel_idx = rel_idx.clamp(0, 2 * self.max_seq_len - 2)
|
| 87 |
+
|
| 88 |
+
R = self.rel_emb(rel_idx) # (L, L, dk)
|
| 89 |
+
|
| 90 |
+
# Q: (B, H, L, dk) R: (L, L, dk) → need (B, H, L, L)
|
| 91 |
+
# Reshape Q to (B*H, L, dk), bmm with R^T reshaped
|
| 92 |
+
BH = Q.shape[0] * Q.shape[1]
|
| 93 |
+
Q_flat = Q.reshape(BH, L, self.d_k) # (BH, L, dk)
|
| 94 |
+
|
| 95 |
+
# For each query position i, we want dot(Q[i], R[i, :, :]) → (BH, L, L)
|
| 96 |
+
# R: (L, L, dk) → transpose last two → (L, dk, L)
|
| 97 |
+
# Then Q_flat[:, i, :] @ R[i, :, :].T for each i
|
| 98 |
+
# Efficient: einsum
|
| 99 |
+
rel_score = torch.einsum("bld,lsd->bls", Q_flat, R) # (BH, L, L)
|
| 100 |
+
return rel_score.view(Q.shape[0], Q.shape[1], L, L)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class TransformerBlock(nn.Module):
|
| 104 |
+
"""Pre-norm Transformer decoder block."""
|
| 105 |
+
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
d_model: int,
|
| 109 |
+
n_heads: int,
|
| 110 |
+
d_ff: int,
|
| 111 |
+
max_seq_len: int,
|
| 112 |
+
dropout: float = 0.1,
|
| 113 |
+
) -> None:
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 116 |
+
self.attn = RelativeMultiHeadAttention(d_model, n_heads, max_seq_len, dropout)
|
| 117 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 118 |
+
self.ffn = nn.Sequential(
|
| 119 |
+
nn.Linear(d_model, d_ff),
|
| 120 |
+
nn.GELU(),
|
| 121 |
+
nn.Dropout(dropout),
|
| 122 |
+
nn.Linear(d_ff, d_model),
|
| 123 |
+
nn.Dropout(dropout),
|
| 124 |
+
)
|
| 125 |
+
self.drop = nn.Dropout(dropout)
|
| 126 |
+
|
| 127 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
|
| 128 |
+
x = x + self.drop(self.attn(self.norm1(x), mask))
|
| 129 |
+
x = x + self.ffn(self.norm2(x))
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class MusicTransformer(nn.Module):
|
| 134 |
+
"""Autoregressive Music Transformer for chord generation."""
|
| 135 |
+
|
| 136 |
+
def __init__(
|
| 137 |
+
self,
|
| 138 |
+
vocab_size: int,
|
| 139 |
+
d_model: int = 512,
|
| 140 |
+
n_heads: int = 8,
|
| 141 |
+
d_ff: int = 2048,
|
| 142 |
+
n_layers: int = 8,
|
| 143 |
+
max_seq_len: int = 512,
|
| 144 |
+
dropout: float = 0.1,
|
| 145 |
+
pad_id: int = 0,
|
| 146 |
+
) -> None:
|
| 147 |
+
super().__init__()
|
| 148 |
+
self.d_model = d_model
|
| 149 |
+
self.max_seq_len = max_seq_len
|
| 150 |
+
self.pad_id = pad_id
|
| 151 |
+
|
| 152 |
+
self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
|
| 153 |
+
self.drop = nn.Dropout(dropout)
|
| 154 |
+
|
| 155 |
+
self.layers = nn.ModuleList([
|
| 156 |
+
TransformerBlock(d_model, n_heads, d_ff, max_seq_len, dropout)
|
| 157 |
+
for _ in range(n_layers)
|
| 158 |
+
])
|
| 159 |
+
|
| 160 |
+
self.norm = nn.LayerNorm(d_model)
|
| 161 |
+
self.out_proj = nn.Linear(d_model, vocab_size, bias=False)
|
| 162 |
+
|
| 163 |
+
# Weight tying (embedding ↔ output projection)
|
| 164 |
+
self.out_proj.weight = self.token_emb.weight
|
| 165 |
+
|
| 166 |
+
self._init_weights()
|
| 167 |
+
|
| 168 |
+
def _init_weights(self) -> None:
|
| 169 |
+
for name, p in self.named_parameters():
|
| 170 |
+
if p.dim() > 1 and "token_emb" not in name:
|
| 171 |
+
nn.init.xavier_uniform_(p)
|
| 172 |
+
# Embedding std=1/sqrt(d_model) so that after *sqrt(d_model) scaling
|
| 173 |
+
# inputs have unit variance, and weight-tied output logits stay small
|
| 174 |
+
nn.init.normal_(self.token_emb.weight, mean=0.0, std=self.d_model ** -0.5)
|
| 175 |
+
|
| 176 |
+
@staticmethod
|
| 177 |
+
def _causal_mask(L: int, device: torch.device) -> torch.Tensor:
|
| 178 |
+
"""Upper-triangular causal mask (True = masked)."""
|
| 179 |
+
return torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)
|
| 180 |
+
|
| 181 |
+
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 182 |
+
"""
|
| 183 |
+
Args:
|
| 184 |
+
input_ids: (B, L) token IDs
|
| 185 |
+
Returns:
|
| 186 |
+
logits: (B, L, vocab_size)
|
| 187 |
+
"""
|
| 188 |
+
B, L = input_ids.shape
|
| 189 |
+
x = self.token_emb(input_ids) * math.sqrt(self.d_model)
|
| 190 |
+
x = self.drop(x)
|
| 191 |
+
|
| 192 |
+
mask = self._causal_mask(L, input_ids.device)
|
| 193 |
+
for layer in self.layers:
|
| 194 |
+
x = layer(x, mask)
|
| 195 |
+
|
| 196 |
+
return self.out_proj(self.norm(x))
|
| 197 |
+
|
| 198 |
+
def count_parameters(self) -> int:
|
| 199 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 200 |
+
|
| 201 |
+
@torch.no_grad()
|
| 202 |
+
def generate(
|
| 203 |
+
self,
|
| 204 |
+
prompt_ids: torch.Tensor,
|
| 205 |
+
max_new_tokens: int = 64,
|
| 206 |
+
temperature: float = 1.0,
|
| 207 |
+
top_k: int = 0,
|
| 208 |
+
top_p: float = 0.9,
|
| 209 |
+
eos_id: int = 2,
|
| 210 |
+
repetition_penalty: float = 1.0,
|
| 211 |
+
no_repeat_ngram_size: int = 0,
|
| 212 |
+
ignore_repeat_token_ids: set[int] | None = None,
|
| 213 |
+
) -> torch.Tensor:
|
| 214 |
+
"""Autoregressive generation from a prompt.
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
prompt_ids: (1, L) token IDs including [BOS] and context.
|
| 218 |
+
max_new_tokens: maximum tokens to generate.
|
| 219 |
+
temperature: sampling temperature (lower = more deterministic).
|
| 220 |
+
top_k: keep only top-k logits (0 = disabled).
|
| 221 |
+
top_p: nucleus sampling threshold.
|
| 222 |
+
eos_id: stop token.
|
| 223 |
+
repetition_penalty: divide logits of previously-seen tokens by
|
| 224 |
+
this factor (HF convention). > 1.0 discourages repeats.
|
| 225 |
+
1.0 disables. Typical: 1.2–1.5.
|
| 226 |
+
no_repeat_ngram_size: ban candidate tokens that would complete
|
| 227 |
+
an n-gram already present in the current sequence (n =
|
| 228 |
+
this value). 0 disables. Typical: 3 for chord sequences.
|
| 229 |
+
ignore_repeat_token_ids: token ids exempt from the two repetition
|
| 230 |
+
controls above — e.g. [BAR] or other separators that
|
| 231 |
+
*should* recur. If None, no exemptions.
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
(1, L') full sequence including prompt and generated tokens.
|
| 235 |
+
"""
|
| 236 |
+
self.eval()
|
| 237 |
+
ids = prompt_ids.clone()
|
| 238 |
+
exempt = ignore_repeat_token_ids or set()
|
| 239 |
+
|
| 240 |
+
for _ in range(max_new_tokens):
|
| 241 |
+
ctx = ids[:, -self.max_seq_len :]
|
| 242 |
+
logits = self(ctx)[:, -1, :] / max(temperature, 1e-8)
|
| 243 |
+
|
| 244 |
+
# Repetition penalty (HuggingFace-style): scale already-seen token
|
| 245 |
+
# logits so they are less attractive. Positive logits get divided,
|
| 246 |
+
# negative logits get multiplied (stays "less attractive" either sign).
|
| 247 |
+
if repetition_penalty != 1.0:
|
| 248 |
+
seen = set(ids[0].tolist()) - exempt
|
| 249 |
+
if seen:
|
| 250 |
+
idx = torch.tensor(list(seen), device=logits.device, dtype=torch.long)
|
| 251 |
+
vals = logits[0, idx]
|
| 252 |
+
vals = torch.where(
|
| 253 |
+
vals > 0,
|
| 254 |
+
vals / repetition_penalty,
|
| 255 |
+
vals * repetition_penalty,
|
| 256 |
+
)
|
| 257 |
+
logits[0, idx] = vals
|
| 258 |
+
|
| 259 |
+
# No-repeat n-gram: block any candidate token that would complete
|
| 260 |
+
# an n-gram already present earlier in the sequence.
|
| 261 |
+
if no_repeat_ngram_size > 0 and ids.shape[1] >= no_repeat_ngram_size:
|
| 262 |
+
n = no_repeat_ngram_size
|
| 263 |
+
seq = ids[0].tolist()
|
| 264 |
+
prefix = tuple(seq[-(n - 1):]) if n > 1 else ()
|
| 265 |
+
banned: set[int] = set()
|
| 266 |
+
for i in range(len(seq) - n + 1):
|
| 267 |
+
if tuple(seq[i : i + n - 1]) == prefix:
|
| 268 |
+
banned.add(seq[i + n - 1])
|
| 269 |
+
banned -= exempt
|
| 270 |
+
if banned:
|
| 271 |
+
bidx = torch.tensor(list(banned), device=logits.device, dtype=torch.long)
|
| 272 |
+
logits[0, bidx] = float("-inf")
|
| 273 |
+
|
| 274 |
+
# Top-k
|
| 275 |
+
if top_k > 0:
|
| 276 |
+
topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 277 |
+
logits[logits < topk_vals[:, -1:]] = float("-inf")
|
| 278 |
+
|
| 279 |
+
# Top-p (nucleus)
|
| 280 |
+
if 0 < top_p < 1.0:
|
| 281 |
+
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
|
| 282 |
+
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 283 |
+
remove = cum_probs - F.softmax(sorted_logits, dim=-1) > top_p
|
| 284 |
+
sorted_logits[remove] = float("-inf")
|
| 285 |
+
logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
|
| 286 |
+
|
| 287 |
+
probs = F.softmax(logits, dim=-1)
|
| 288 |
+
next_id = torch.multinomial(probs, num_samples=1)
|
| 289 |
+
ids = torch.cat([ids, next_id], dim=-1)
|
| 290 |
+
|
| 291 |
+
if (next_id == eos_id).all():
|
| 292 |
+
break
|
| 293 |
+
|
| 294 |
+
return ids
|
tokenizer.json
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"token2id": {
|
| 3 |
+
"[PAD]": 0,
|
| 4 |
+
"[BOS]": 1,
|
| 5 |
+
"[EOS]": 2,
|
| 6 |
+
"[BAR]": 3,
|
| 7 |
+
"[KEY:Cmaj]": 4,
|
| 8 |
+
"[KEY:Dbmaj]": 5,
|
| 9 |
+
"[KEY:Dmaj]": 6,
|
| 10 |
+
"[KEY:Ebmaj]": 7,
|
| 11 |
+
"[KEY:Emaj]": 8,
|
| 12 |
+
"[KEY:Fmaj]": 9,
|
| 13 |
+
"[KEY:F#maj]": 10,
|
| 14 |
+
"[KEY:Gmaj]": 11,
|
| 15 |
+
"[KEY:Abmaj]": 12,
|
| 16 |
+
"[KEY:Amaj]": 13,
|
| 17 |
+
"[KEY:Bbmaj]": 14,
|
| 18 |
+
"[KEY:Bmaj]": 15,
|
| 19 |
+
"[KEY:Cmin]": 16,
|
| 20 |
+
"[KEY:Dbmin]": 17,
|
| 21 |
+
"[KEY:Dmin]": 18,
|
| 22 |
+
"[KEY:Ebmin]": 19,
|
| 23 |
+
"[KEY:Emin]": 20,
|
| 24 |
+
"[KEY:Fmin]": 21,
|
| 25 |
+
"[KEY:F#min]": 22,
|
| 26 |
+
"[KEY:Gmin]": 23,
|
| 27 |
+
"[KEY:Abmin]": 24,
|
| 28 |
+
"[KEY:Amin]": 25,
|
| 29 |
+
"[KEY:Bbmin]": 26,
|
| 30 |
+
"[KEY:Bmin]": 27,
|
| 31 |
+
"[TIME:4/4]": 28,
|
| 32 |
+
"[TIME:3/4]": 29,
|
| 33 |
+
"[TIME:6/8]": 30,
|
| 34 |
+
"[TIME:2/4]": 31,
|
| 35 |
+
"[TIME:5/4]": 32,
|
| 36 |
+
"[GENRE:jazz]": 33,
|
| 37 |
+
"[GENRE:pop]": 34,
|
| 38 |
+
"[GENRE:rock]": 35,
|
| 39 |
+
"[GENRE:blues]": 36,
|
| 40 |
+
"[GENRE:bossa]": 37,
|
| 41 |
+
"[GENRE:none]": 38,
|
| 42 |
+
"Cmaj": 39,
|
| 43 |
+
"Cm": 40,
|
| 44 |
+
"C7": 41,
|
| 45 |
+
"Cmaj7": 42,
|
| 46 |
+
"Cm7": 43,
|
| 47 |
+
"Cm7b5": 44,
|
| 48 |
+
"Cdim7": 45,
|
| 49 |
+
"Cdim": 46,
|
| 50 |
+
"Caug": 47,
|
| 51 |
+
"Csus4": 48,
|
| 52 |
+
"Csus2": 49,
|
| 53 |
+
"C6": 50,
|
| 54 |
+
"Cm6": 51,
|
| 55 |
+
"C9": 52,
|
| 56 |
+
"Cm9": 53,
|
| 57 |
+
"Cmaj9": 54,
|
| 58 |
+
"C11": 55,
|
| 59 |
+
"Cm11": 56,
|
| 60 |
+
"C13": 57,
|
| 61 |
+
"Cm13": 58,
|
| 62 |
+
"Cadd9": 59,
|
| 63 |
+
"CmMaj7": 60,
|
| 64 |
+
"C7b9": 61,
|
| 65 |
+
"C7#9": 62,
|
| 66 |
+
"C7#11": 63,
|
| 67 |
+
"C7b13": 64,
|
| 68 |
+
"Dbmaj": 65,
|
| 69 |
+
"Dbm": 66,
|
| 70 |
+
"Db7": 67,
|
| 71 |
+
"Dbmaj7": 68,
|
| 72 |
+
"Dbm7": 69,
|
| 73 |
+
"Dbm7b5": 70,
|
| 74 |
+
"Dbdim7": 71,
|
| 75 |
+
"Dbdim": 72,
|
| 76 |
+
"Dbaug": 73,
|
| 77 |
+
"Dbsus4": 74,
|
| 78 |
+
"Dbsus2": 75,
|
| 79 |
+
"Db6": 76,
|
| 80 |
+
"Dbm6": 77,
|
| 81 |
+
"Db9": 78,
|
| 82 |
+
"Dbm9": 79,
|
| 83 |
+
"Dbmaj9": 80,
|
| 84 |
+
"Db11": 81,
|
| 85 |
+
"Dbm11": 82,
|
| 86 |
+
"Db13": 83,
|
| 87 |
+
"Dbm13": 84,
|
| 88 |
+
"Dbadd9": 85,
|
| 89 |
+
"DbmMaj7": 86,
|
| 90 |
+
"Db7b9": 87,
|
| 91 |
+
"Db7#9": 88,
|
| 92 |
+
"Db7#11": 89,
|
| 93 |
+
"Db7b13": 90,
|
| 94 |
+
"Dmaj": 91,
|
| 95 |
+
"Dm": 92,
|
| 96 |
+
"D7": 93,
|
| 97 |
+
"Dmaj7": 94,
|
| 98 |
+
"Dm7": 95,
|
| 99 |
+
"Dm7b5": 96,
|
| 100 |
+
"Ddim7": 97,
|
| 101 |
+
"Ddim": 98,
|
| 102 |
+
"Daug": 99,
|
| 103 |
+
"Dsus4": 100,
|
| 104 |
+
"Dsus2": 101,
|
| 105 |
+
"D6": 102,
|
| 106 |
+
"Dm6": 103,
|
| 107 |
+
"D9": 104,
|
| 108 |
+
"Dm9": 105,
|
| 109 |
+
"Dmaj9": 106,
|
| 110 |
+
"D11": 107,
|
| 111 |
+
"Dm11": 108,
|
| 112 |
+
"D13": 109,
|
| 113 |
+
"Dm13": 110,
|
| 114 |
+
"Dadd9": 111,
|
| 115 |
+
"DmMaj7": 112,
|
| 116 |
+
"D7b9": 113,
|
| 117 |
+
"D7#9": 114,
|
| 118 |
+
"D7#11": 115,
|
| 119 |
+
"D7b13": 116,
|
| 120 |
+
"Ebmaj": 117,
|
| 121 |
+
"Ebm": 118,
|
| 122 |
+
"Eb7": 119,
|
| 123 |
+
"Ebmaj7": 120,
|
| 124 |
+
"Ebm7": 121,
|
| 125 |
+
"Ebm7b5": 122,
|
| 126 |
+
"Ebdim7": 123,
|
| 127 |
+
"Ebdim": 124,
|
| 128 |
+
"Ebaug": 125,
|
| 129 |
+
"Ebsus4": 126,
|
| 130 |
+
"Ebsus2": 127,
|
| 131 |
+
"Eb6": 128,
|
| 132 |
+
"Ebm6": 129,
|
| 133 |
+
"Eb9": 130,
|
| 134 |
+
"Ebm9": 131,
|
| 135 |
+
"Ebmaj9": 132,
|
| 136 |
+
"Eb11": 133,
|
| 137 |
+
"Ebm11": 134,
|
| 138 |
+
"Eb13": 135,
|
| 139 |
+
"Ebm13": 136,
|
| 140 |
+
"Ebadd9": 137,
|
| 141 |
+
"EbmMaj7": 138,
|
| 142 |
+
"Eb7b9": 139,
|
| 143 |
+
"Eb7#9": 140,
|
| 144 |
+
"Eb7#11": 141,
|
| 145 |
+
"Eb7b13": 142,
|
| 146 |
+
"Emaj": 143,
|
| 147 |
+
"Em": 144,
|
| 148 |
+
"E7": 145,
|
| 149 |
+
"Emaj7": 146,
|
| 150 |
+
"Em7": 147,
|
| 151 |
+
"Em7b5": 148,
|
| 152 |
+
"Edim7": 149,
|
| 153 |
+
"Edim": 150,
|
| 154 |
+
"Eaug": 151,
|
| 155 |
+
"Esus4": 152,
|
| 156 |
+
"Esus2": 153,
|
| 157 |
+
"E6": 154,
|
| 158 |
+
"Em6": 155,
|
| 159 |
+
"E9": 156,
|
| 160 |
+
"Em9": 157,
|
| 161 |
+
"Emaj9": 158,
|
| 162 |
+
"E11": 159,
|
| 163 |
+
"Em11": 160,
|
| 164 |
+
"E13": 161,
|
| 165 |
+
"Em13": 162,
|
| 166 |
+
"Eadd9": 163,
|
| 167 |
+
"EmMaj7": 164,
|
| 168 |
+
"E7b9": 165,
|
| 169 |
+
"E7#9": 166,
|
| 170 |
+
"E7#11": 167,
|
| 171 |
+
"E7b13": 168,
|
| 172 |
+
"Fmaj": 169,
|
| 173 |
+
"Fm": 170,
|
| 174 |
+
"F7": 171,
|
| 175 |
+
"Fmaj7": 172,
|
| 176 |
+
"Fm7": 173,
|
| 177 |
+
"Fm7b5": 174,
|
| 178 |
+
"Fdim7": 175,
|
| 179 |
+
"Fdim": 176,
|
| 180 |
+
"Faug": 177,
|
| 181 |
+
"Fsus4": 178,
|
| 182 |
+
"Fsus2": 179,
|
| 183 |
+
"F6": 180,
|
| 184 |
+
"Fm6": 181,
|
| 185 |
+
"F9": 182,
|
| 186 |
+
"Fm9": 183,
|
| 187 |
+
"Fmaj9": 184,
|
| 188 |
+
"F11": 185,
|
| 189 |
+
"Fm11": 186,
|
| 190 |
+
"F13": 187,
|
| 191 |
+
"Fm13": 188,
|
| 192 |
+
"Fadd9": 189,
|
| 193 |
+
"FmMaj7": 190,
|
| 194 |
+
"F7b9": 191,
|
| 195 |
+
"F7#9": 192,
|
| 196 |
+
"F7#11": 193,
|
| 197 |
+
"F7b13": 194,
|
| 198 |
+
"F#maj": 195,
|
| 199 |
+
"F#m": 196,
|
| 200 |
+
"F#7": 197,
|
| 201 |
+
"F#maj7": 198,
|
| 202 |
+
"F#m7": 199,
|
| 203 |
+
"F#m7b5": 200,
|
| 204 |
+
"F#dim7": 201,
|
| 205 |
+
"F#dim": 202,
|
| 206 |
+
"F#aug": 203,
|
| 207 |
+
"F#sus4": 204,
|
| 208 |
+
"F#sus2": 205,
|
| 209 |
+
"F#6": 206,
|
| 210 |
+
"F#m6": 207,
|
| 211 |
+
"F#9": 208,
|
| 212 |
+
"F#m9": 209,
|
| 213 |
+
"F#maj9": 210,
|
| 214 |
+
"F#11": 211,
|
| 215 |
+
"F#m11": 212,
|
| 216 |
+
"F#13": 213,
|
| 217 |
+
"F#m13": 214,
|
| 218 |
+
"F#add9": 215,
|
| 219 |
+
"F#mMaj7": 216,
|
| 220 |
+
"F#7b9": 217,
|
| 221 |
+
"F#7#9": 218,
|
| 222 |
+
"F#7#11": 219,
|
| 223 |
+
"F#7b13": 220,
|
| 224 |
+
"Gmaj": 221,
|
| 225 |
+
"Gm": 222,
|
| 226 |
+
"G7": 223,
|
| 227 |
+
"Gmaj7": 224,
|
| 228 |
+
"Gm7": 225,
|
| 229 |
+
"Gm7b5": 226,
|
| 230 |
+
"Gdim7": 227,
|
| 231 |
+
"Gdim": 228,
|
| 232 |
+
"Gaug": 229,
|
| 233 |
+
"Gsus4": 230,
|
| 234 |
+
"Gsus2": 231,
|
| 235 |
+
"G6": 232,
|
| 236 |
+
"Gm6": 233,
|
| 237 |
+
"G9": 234,
|
| 238 |
+
"Gm9": 235,
|
| 239 |
+
"Gmaj9": 236,
|
| 240 |
+
"G11": 237,
|
| 241 |
+
"Gm11": 238,
|
| 242 |
+
"G13": 239,
|
| 243 |
+
"Gm13": 240,
|
| 244 |
+
"Gadd9": 241,
|
| 245 |
+
"GmMaj7": 242,
|
| 246 |
+
"G7b9": 243,
|
| 247 |
+
"G7#9": 244,
|
| 248 |
+
"G7#11": 245,
|
| 249 |
+
"G7b13": 246,
|
| 250 |
+
"Abmaj": 247,
|
| 251 |
+
"Abm": 248,
|
| 252 |
+
"Ab7": 249,
|
| 253 |
+
"Abmaj7": 250,
|
| 254 |
+
"Abm7": 251,
|
| 255 |
+
"Abm7b5": 252,
|
| 256 |
+
"Abdim7": 253,
|
| 257 |
+
"Abdim": 254,
|
| 258 |
+
"Abaug": 255,
|
| 259 |
+
"Absus4": 256,
|
| 260 |
+
"Absus2": 257,
|
| 261 |
+
"Ab6": 258,
|
| 262 |
+
"Abm6": 259,
|
| 263 |
+
"Ab9": 260,
|
| 264 |
+
"Abm9": 261,
|
| 265 |
+
"Abmaj9": 262,
|
| 266 |
+
"Ab11": 263,
|
| 267 |
+
"Abm11": 264,
|
| 268 |
+
"Ab13": 265,
|
| 269 |
+
"Abm13": 266,
|
| 270 |
+
"Abadd9": 267,
|
| 271 |
+
"AbmMaj7": 268,
|
| 272 |
+
"Ab7b9": 269,
|
| 273 |
+
"Ab7#9": 270,
|
| 274 |
+
"Ab7#11": 271,
|
| 275 |
+
"Ab7b13": 272,
|
| 276 |
+
"Amaj": 273,
|
| 277 |
+
"Am": 274,
|
| 278 |
+
"A7": 275,
|
| 279 |
+
"Amaj7": 276,
|
| 280 |
+
"Am7": 277,
|
| 281 |
+
"Am7b5": 278,
|
| 282 |
+
"Adim7": 279,
|
| 283 |
+
"Adim": 280,
|
| 284 |
+
"Aaug": 281,
|
| 285 |
+
"Asus4": 282,
|
| 286 |
+
"Asus2": 283,
|
| 287 |
+
"A6": 284,
|
| 288 |
+
"Am6": 285,
|
| 289 |
+
"A9": 286,
|
| 290 |
+
"Am9": 287,
|
| 291 |
+
"Amaj9": 288,
|
| 292 |
+
"A11": 289,
|
| 293 |
+
"Am11": 290,
|
| 294 |
+
"A13": 291,
|
| 295 |
+
"Am13": 292,
|
| 296 |
+
"Aadd9": 293,
|
| 297 |
+
"AmMaj7": 294,
|
| 298 |
+
"A7b9": 295,
|
| 299 |
+
"A7#9": 296,
|
| 300 |
+
"A7#11": 297,
|
| 301 |
+
"A7b13": 298,
|
| 302 |
+
"Bbmaj": 299,
|
| 303 |
+
"Bbm": 300,
|
| 304 |
+
"Bb7": 301,
|
| 305 |
+
"Bbmaj7": 302,
|
| 306 |
+
"Bbm7": 303,
|
| 307 |
+
"Bbm7b5": 304,
|
| 308 |
+
"Bbdim7": 305,
|
| 309 |
+
"Bbdim": 306,
|
| 310 |
+
"Bbaug": 307,
|
| 311 |
+
"Bbsus4": 308,
|
| 312 |
+
"Bbsus2": 309,
|
| 313 |
+
"Bb6": 310,
|
| 314 |
+
"Bbm6": 311,
|
| 315 |
+
"Bb9": 312,
|
| 316 |
+
"Bbm9": 313,
|
| 317 |
+
"Bbmaj9": 314,
|
| 318 |
+
"Bb11": 315,
|
| 319 |
+
"Bbm11": 316,
|
| 320 |
+
"Bb13": 317,
|
| 321 |
+
"Bbm13": 318,
|
| 322 |
+
"Bbadd9": 319,
|
| 323 |
+
"BbmMaj7": 320,
|
| 324 |
+
"Bb7b9": 321,
|
| 325 |
+
"Bb7#9": 322,
|
| 326 |
+
"Bb7#11": 323,
|
| 327 |
+
"Bb7b13": 324,
|
| 328 |
+
"Bmaj": 325,
|
| 329 |
+
"Bm": 326,
|
| 330 |
+
"B7": 327,
|
| 331 |
+
"Bmaj7": 328,
|
| 332 |
+
"Bm7": 329,
|
| 333 |
+
"Bm7b5": 330,
|
| 334 |
+
"Bdim7": 331,
|
| 335 |
+
"Bdim": 332,
|
| 336 |
+
"Baug": 333,
|
| 337 |
+
"Bsus4": 334,
|
| 338 |
+
"Bsus2": 335,
|
| 339 |
+
"B6": 336,
|
| 340 |
+
"Bm6": 337,
|
| 341 |
+
"B9": 338,
|
| 342 |
+
"Bm9": 339,
|
| 343 |
+
"Bmaj9": 340,
|
| 344 |
+
"B11": 341,
|
| 345 |
+
"Bm11": 342,
|
| 346 |
+
"B13": 343,
|
| 347 |
+
"Bm13": 344,
|
| 348 |
+
"Badd9": 345,
|
| 349 |
+
"BmMaj7": 346,
|
| 350 |
+
"B7b9": 347,
|
| 351 |
+
"B7#9": 348,
|
| 352 |
+
"B7#11": 349,
|
| 353 |
+
"B7b13": 350
|
| 354 |
+
},
|
| 355 |
+
"vocab_size": 351
|
| 356 |
+
}
|
tokenizer.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Chord sequence tokenizer for Music Transformer training.
|
| 2 |
+
|
| 3 |
+
Vocabulary (~350 tokens):
|
| 4 |
+
[PAD]=0, [BOS]=1, [EOS]=2, [BAR]=3
|
| 5 |
+
[KEY:Cmaj] ... [KEY:Bmin] (24 keys)
|
| 6 |
+
[TIME:4/4] ... [TIME:5/4] (5 time sigs)
|
| 7 |
+
[GENRE:jazz] ... [GENRE:none] (6 genres)
|
| 8 |
+
Cmaj, Cm, C7, ... B7b13 (12 roots x 26 qualities = 312 chords)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
# Canonical root names (jazz convention: prefer flats)
|
| 17 |
+
ROOTS = ["C", "Db", "D", "Eb", "E", "F", "F#", "G", "Ab", "A", "Bb", "B"]
|
| 18 |
+
|
| 19 |
+
# Root name aliases for normalization
|
| 20 |
+
ROOT_ALIASES: dict[str, str] = {
|
| 21 |
+
"C#": "Db", "D#": "Eb", "E#": "F", "Fb": "E",
|
| 22 |
+
"G#": "Ab", "A#": "Bb", "B#": "C", "Cb": "B",
|
| 23 |
+
"Gb": "F#",
|
| 24 |
+
# Lowercase
|
| 25 |
+
"c": "C", "d": "D", "e": "E", "f": "F", "g": "G", "a": "A", "b": "B",
|
| 26 |
+
"c#": "Db", "db": "Db", "d#": "Eb", "eb": "Eb",
|
| 27 |
+
"f#": "F#", "gb": "F#", "g#": "Ab", "ab": "Ab",
|
| 28 |
+
"a#": "Bb", "bb": "Bb", "cb": "B", "fb": "E",
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
# Chord qualities in our vocabulary
|
| 32 |
+
QUALITIES = [
|
| 33 |
+
"maj", "m", "7", "maj7", "m7", "m7b5", "dim7", "dim", "aug",
|
| 34 |
+
"sus4", "sus2", "6", "m6", "9", "m9", "maj9", "11", "m11",
|
| 35 |
+
"13", "m13", "add9", "mMaj7", "7b9", "7#9", "7#11", "7b13",
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
# Quality alias mapping → canonical quality
|
| 39 |
+
_QUALITY_ALIASES: dict[str, str] = {
|
| 40 |
+
# Major
|
| 41 |
+
"major": "maj", "M": "maj",
|
| 42 |
+
# Minor
|
| 43 |
+
"min": "m", "minor": "m", "-": "m", "mi": "m",
|
| 44 |
+
# Dominant 7
|
| 45 |
+
"dom7": "7", "dom": "7",
|
| 46 |
+
# Major 7
|
| 47 |
+
"^7": "maj7", "M7": "maj7", "Maj7": "maj7", "major7": "maj7",
|
| 48 |
+
"j7": "maj7", "^": "maj7", "delta": "maj7",
|
| 49 |
+
# Minor 7
|
| 50 |
+
"min7": "m7", "-7": "m7", "mi7": "m7",
|
| 51 |
+
# Half-diminished
|
| 52 |
+
"hdim7": "m7b5", "hdim": "m7b5", "h7": "m7b5",
|
| 53 |
+
"%7": "m7b5", "%": "m7b5",
|
| 54 |
+
# Diminished
|
| 55 |
+
"o": "dim", "o7": "dim7",
|
| 56 |
+
# Augmented
|
| 57 |
+
"+": "aug",
|
| 58 |
+
# Suspended
|
| 59 |
+
"sus": "sus4",
|
| 60 |
+
# 6th
|
| 61 |
+
"min6": "m6", "-6": "m6",
|
| 62 |
+
# 9th
|
| 63 |
+
"min9": "m9", "-9": "m9", "M9": "maj9", "^9": "maj9", "Maj9": "maj9",
|
| 64 |
+
# 11th
|
| 65 |
+
"min11": "m11", "-11": "m11",
|
| 66 |
+
# 13th
|
| 67 |
+
"min13": "m13", "-13": "m13",
|
| 68 |
+
# Minor-major 7
|
| 69 |
+
"minmaj7": "mMaj7", "-^7": "mMaj7", "mM7": "mMaj7",
|
| 70 |
+
# Altered dominants
|
| 71 |
+
"7alt": "7b9",
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
# Keys and metadata
|
| 75 |
+
MAJOR_KEYS = [f"{r}maj" for r in ROOTS]
|
| 76 |
+
MINOR_KEYS = [f"{r}min" for r in ROOTS]
|
| 77 |
+
ALL_KEYS = MAJOR_KEYS + MINOR_KEYS
|
| 78 |
+
TIME_SIGS = ["4/4", "3/4", "6/8", "2/4", "5/4"]
|
| 79 |
+
GENRES = ["jazz", "pop", "rock", "blues", "bossa"]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class ChordTokenizer:
|
| 83 |
+
"""Deterministic tokenizer for chord sequences."""
|
| 84 |
+
|
| 85 |
+
PAD = 0
|
| 86 |
+
BOS = 1
|
| 87 |
+
EOS = 2
|
| 88 |
+
BAR = 3
|
| 89 |
+
|
| 90 |
+
def __init__(self) -> None:
|
| 91 |
+
self.token2id: dict[str, int] = {}
|
| 92 |
+
self.id2token: dict[int, str] = {}
|
| 93 |
+
self._build_vocab()
|
| 94 |
+
|
| 95 |
+
# ------------------------------------------------------------------
|
| 96 |
+
# Vocab construction
|
| 97 |
+
# ------------------------------------------------------------------
|
| 98 |
+
|
| 99 |
+
def _build_vocab(self) -> None:
|
| 100 |
+
tokens: list[str] = ["[PAD]", "[BOS]", "[EOS]", "[BAR]"]
|
| 101 |
+
for key in ALL_KEYS:
|
| 102 |
+
tokens.append(f"[KEY:{key}]")
|
| 103 |
+
for ts in TIME_SIGS:
|
| 104 |
+
tokens.append(f"[TIME:{ts}]")
|
| 105 |
+
for genre in GENRES:
|
| 106 |
+
tokens.append(f"[GENRE:{genre}]")
|
| 107 |
+
tokens.append("[GENRE:none]")
|
| 108 |
+
for root in ROOTS:
|
| 109 |
+
for quality in QUALITIES:
|
| 110 |
+
tokens.append(f"{root}{quality}")
|
| 111 |
+
for i, tok in enumerate(tokens):
|
| 112 |
+
self.token2id[tok] = i
|
| 113 |
+
self.id2token[i] = tok
|
| 114 |
+
|
| 115 |
+
@property
|
| 116 |
+
def vocab_size(self) -> int:
|
| 117 |
+
return len(self.token2id)
|
| 118 |
+
|
| 119 |
+
@property
|
| 120 |
+
def pad_id(self) -> int:
|
| 121 |
+
return self.PAD
|
| 122 |
+
|
| 123 |
+
@property
|
| 124 |
+
def bos_id(self) -> int:
|
| 125 |
+
return self.BOS
|
| 126 |
+
|
| 127 |
+
@property
|
| 128 |
+
def eos_id(self) -> int:
|
| 129 |
+
return self.EOS
|
| 130 |
+
|
| 131 |
+
@property
|
| 132 |
+
def bar_id(self) -> int:
|
| 133 |
+
return self.BAR
|
| 134 |
+
|
| 135 |
+
# ------------------------------------------------------------------
|
| 136 |
+
# Encoding helpers
|
| 137 |
+
# ------------------------------------------------------------------
|
| 138 |
+
|
| 139 |
+
def encode_chord(self, chord_str: str) -> int | None:
|
| 140 |
+
token = self.normalize_chord(chord_str)
|
| 141 |
+
return self.token2id.get(token) if token else None
|
| 142 |
+
|
| 143 |
+
def encode_key(self, key_str: str) -> int | None:
|
| 144 |
+
return self.token2id.get(f"[KEY:{key_str}]")
|
| 145 |
+
|
| 146 |
+
def encode_time_sig(self, ts: str) -> int | None:
|
| 147 |
+
return self.token2id.get(f"[TIME:{ts}]")
|
| 148 |
+
|
| 149 |
+
def encode_genre(self, genre: str) -> int | None:
|
| 150 |
+
return self.token2id.get(f"[GENRE:{genre}]")
|
| 151 |
+
|
| 152 |
+
def encode_sequence(self, song: dict) -> list[int]:
|
| 153 |
+
"""Encode a unified song dict to a token-ID sequence.
|
| 154 |
+
|
| 155 |
+
Expected *song* format::
|
| 156 |
+
|
| 157 |
+
{
|
| 158 |
+
"key": "Cmaj",
|
| 159 |
+
"time_signature": "4/4",
|
| 160 |
+
"genre": "jazz",
|
| 161 |
+
"bars": [["Cmaj7", "Am7"], ["Dm7", "G7"], ...]
|
| 162 |
+
}
|
| 163 |
+
"""
|
| 164 |
+
ids: list[int] = [self.BOS]
|
| 165 |
+
|
| 166 |
+
kid = self.encode_key(song.get("key", "Cmaj"))
|
| 167 |
+
if kid is not None:
|
| 168 |
+
ids.append(kid)
|
| 169 |
+
|
| 170 |
+
tid = self.encode_time_sig(song.get("time_signature", "4/4"))
|
| 171 |
+
if tid is not None:
|
| 172 |
+
ids.append(tid)
|
| 173 |
+
|
| 174 |
+
gid = self.encode_genre(song.get("genre", "none"))
|
| 175 |
+
if gid is not None:
|
| 176 |
+
ids.append(gid)
|
| 177 |
+
|
| 178 |
+
for bar in song.get("bars", []):
|
| 179 |
+
ids.append(self.BAR)
|
| 180 |
+
for chord in bar:
|
| 181 |
+
cid = self.encode_chord(chord)
|
| 182 |
+
if cid is not None:
|
| 183 |
+
ids.append(cid)
|
| 184 |
+
|
| 185 |
+
ids.append(self.EOS)
|
| 186 |
+
return ids
|
| 187 |
+
|
| 188 |
+
def decode(self, ids: list[int]) -> list[str]:
|
| 189 |
+
return [self.id2token.get(i, "[UNK]") for i in ids]
|
| 190 |
+
|
| 191 |
+
# ------------------------------------------------------------------
|
| 192 |
+
# Chord normalization
|
| 193 |
+
# ------------------------------------------------------------------
|
| 194 |
+
|
| 195 |
+
@staticmethod
|
| 196 |
+
def normalize_root(root: str) -> str | None:
|
| 197 |
+
"""Normalize a root note name to canonical form."""
|
| 198 |
+
if root in ROOTS:
|
| 199 |
+
return root
|
| 200 |
+
if root in ROOT_ALIASES:
|
| 201 |
+
return ROOT_ALIASES[root]
|
| 202 |
+
# Try capitalize first letter
|
| 203 |
+
cap = root[0].upper() + root[1:] if len(root) > 1 else root.upper()
|
| 204 |
+
if cap in ROOTS:
|
| 205 |
+
return cap
|
| 206 |
+
if cap in ROOT_ALIASES:
|
| 207 |
+
return ROOT_ALIASES[cap]
|
| 208 |
+
return None
|
| 209 |
+
|
| 210 |
+
@staticmethod
|
| 211 |
+
def normalize_chord(chord_str: str) -> str | None:
|
| 212 |
+
"""Normalize any chord notation to ``{Root}{quality}`` in our vocab."""
|
| 213 |
+
if not chord_str or chord_str in (
|
| 214 |
+
"N", "NC", "N.C.", "X", "x",
|
| 215 |
+
"pause", "silence", "&pause", "end",
|
| 216 |
+
):
|
| 217 |
+
return None
|
| 218 |
+
|
| 219 |
+
# Strip slash-chord bass
|
| 220 |
+
if "/" in chord_str:
|
| 221 |
+
chord_str = chord_str.split("/")[0]
|
| 222 |
+
|
| 223 |
+
# Billboard colon format Root:Quality
|
| 224 |
+
if ":" in chord_str:
|
| 225 |
+
root_part, qual_part = chord_str.split(":", 1)
|
| 226 |
+
# qual_part may also have /bass — already stripped above
|
| 227 |
+
else:
|
| 228 |
+
root_part = chord_str[0]
|
| 229 |
+
qual_part = chord_str[1:]
|
| 230 |
+
if qual_part and qual_part[0] in ("b", "#"):
|
| 231 |
+
root_part += qual_part[0]
|
| 232 |
+
qual_part = qual_part[1:]
|
| 233 |
+
|
| 234 |
+
norm_root = ChordTokenizer.normalize_root(root_part)
|
| 235 |
+
if norm_root is None:
|
| 236 |
+
return None
|
| 237 |
+
|
| 238 |
+
quality = ChordTokenizer._normalize_quality(qual_part)
|
| 239 |
+
if quality is None or quality not in QUALITIES:
|
| 240 |
+
return None
|
| 241 |
+
|
| 242 |
+
return f"{norm_root}{quality}"
|
| 243 |
+
|
| 244 |
+
@staticmethod
|
| 245 |
+
def _normalize_quality(q: str) -> str | None:
|
| 246 |
+
"""Map various quality notations to our canonical set."""
|
| 247 |
+
if not q:
|
| 248 |
+
return "maj"
|
| 249 |
+
|
| 250 |
+
# Direct hit
|
| 251 |
+
if q in QUALITIES:
|
| 252 |
+
return q
|
| 253 |
+
|
| 254 |
+
# Alias table
|
| 255 |
+
if q in _QUALITY_ALIASES:
|
| 256 |
+
return _QUALITY_ALIASES[q]
|
| 257 |
+
|
| 258 |
+
# Case-insensitive alias search
|
| 259 |
+
for alias, canon in _QUALITY_ALIASES.items():
|
| 260 |
+
if q.lower() == alias.lower():
|
| 261 |
+
return canon
|
| 262 |
+
|
| 263 |
+
# ---- Heuristic fallbacks for unusual notations ----
|
| 264 |
+
|
| 265 |
+
# WJazzD altered dominants: "79b" → 7b9, "79#" → 7#9, etc.
|
| 266 |
+
if q.startswith("7"):
|
| 267 |
+
tail = q[1:]
|
| 268 |
+
if "b9" in tail or "9b" in tail:
|
| 269 |
+
return "7b9"
|
| 270 |
+
if "#9" in tail or "9#" in tail:
|
| 271 |
+
return "7#9"
|
| 272 |
+
if "#11" in tail or "11#" in tail:
|
| 273 |
+
return "7#11"
|
| 274 |
+
if "b13" in tail or "13b" in tail:
|
| 275 |
+
return "7b13"
|
| 276 |
+
|
| 277 |
+
# Compound minor qualities
|
| 278 |
+
if q.startswith("m") or q.startswith("-"):
|
| 279 |
+
inner = q.lstrip("m").lstrip("-")
|
| 280 |
+
if "7" in inner and ("b5" in inner or "b5" in q):
|
| 281 |
+
return "m7b5"
|
| 282 |
+
if "7" in inner:
|
| 283 |
+
return "m7"
|
| 284 |
+
if "9" in inner:
|
| 285 |
+
return "m9"
|
| 286 |
+
if "11" in inner:
|
| 287 |
+
return "m11"
|
| 288 |
+
if "13" in inner:
|
| 289 |
+
return "m13"
|
| 290 |
+
if "6" in inner:
|
| 291 |
+
return "m6"
|
| 292 |
+
return "m"
|
| 293 |
+
|
| 294 |
+
# Bare numbers
|
| 295 |
+
if q in ("7",):
|
| 296 |
+
return "7"
|
| 297 |
+
if q in ("9",):
|
| 298 |
+
return "9"
|
| 299 |
+
if q in ("6",):
|
| 300 |
+
return "6"
|
| 301 |
+
if q in ("11",):
|
| 302 |
+
return "11"
|
| 303 |
+
if q in ("13",):
|
| 304 |
+
return "13"
|
| 305 |
+
|
| 306 |
+
# If nothing matched, approximate as major
|
| 307 |
+
return "maj"
|
| 308 |
+
|
| 309 |
+
# ------------------------------------------------------------------
|
| 310 |
+
# Transposition
|
| 311 |
+
# ------------------------------------------------------------------
|
| 312 |
+
|
| 313 |
+
def transpose_chord_token(self, token: str, semitones: int) -> str | None:
|
| 314 |
+
"""Transpose a chord token string by *semitones*."""
|
| 315 |
+
if token.startswith("["):
|
| 316 |
+
return None
|
| 317 |
+
root = token[0]
|
| 318 |
+
rest = token[1:]
|
| 319 |
+
if rest and rest[0] in ("b", "#"):
|
| 320 |
+
root += rest[0]
|
| 321 |
+
rest = rest[1:]
|
| 322 |
+
norm_root = self.normalize_root(root)
|
| 323 |
+
if norm_root is None:
|
| 324 |
+
return None
|
| 325 |
+
new_root = ROOTS[(ROOTS.index(norm_root) + semitones) % 12]
|
| 326 |
+
return f"{new_root}{rest}"
|
| 327 |
+
|
| 328 |
+
def transpose_key_token(self, token: str, semitones: int) -> str:
|
| 329 |
+
"""Transpose a key token like ``[KEY:Cmaj]``."""
|
| 330 |
+
inner = token[5:-1] # strip [KEY: and ]
|
| 331 |
+
if inner.endswith("maj"):
|
| 332 |
+
root, mode = inner[:-3], "maj"
|
| 333 |
+
elif inner.endswith("min"):
|
| 334 |
+
root, mode = inner[:-3], "min"
|
| 335 |
+
else:
|
| 336 |
+
return token
|
| 337 |
+
norm = self.normalize_root(root)
|
| 338 |
+
if norm is None:
|
| 339 |
+
return token
|
| 340 |
+
new_root = ROOTS[(ROOTS.index(norm) + semitones) % 12]
|
| 341 |
+
return f"[KEY:{new_root}{mode}]"
|
| 342 |
+
|
| 343 |
+
def transpose_sequence(self, ids: list[int], semitones: int) -> list[int]:
|
| 344 |
+
"""Transpose every chord & key token in *ids* by *semitones*."""
|
| 345 |
+
if semitones % 12 == 0:
|
| 346 |
+
return list(ids)
|
| 347 |
+
out: list[int] = []
|
| 348 |
+
for tid in ids:
|
| 349 |
+
tok = self.id2token.get(tid)
|
| 350 |
+
if tok is None:
|
| 351 |
+
out.append(tid)
|
| 352 |
+
elif tok.startswith("[KEY:"):
|
| 353 |
+
new = self.transpose_key_token(tok, semitones)
|
| 354 |
+
out.append(self.token2id.get(new, tid))
|
| 355 |
+
elif tok.startswith("[") or tid <= self.BAR:
|
| 356 |
+
out.append(tid)
|
| 357 |
+
else:
|
| 358 |
+
new = self.transpose_chord_token(tok, semitones)
|
| 359 |
+
out.append(self.token2id[new] if new and new in self.token2id else tid)
|
| 360 |
+
return out
|
| 361 |
+
|
| 362 |
+
# ------------------------------------------------------------------
|
| 363 |
+
# Persistence
|
| 364 |
+
# ------------------------------------------------------------------
|
| 365 |
+
|
| 366 |
+
def save(self, path: str | Path) -> None:
|
| 367 |
+
Path(path).write_text(json.dumps({
|
| 368 |
+
"token2id": self.token2id,
|
| 369 |
+
"vocab_size": self.vocab_size,
|
| 370 |
+
}, indent=2, ensure_ascii=False))
|
| 371 |
+
|
| 372 |
+
@classmethod
|
| 373 |
+
def load(cls, path: str | Path) -> ChordTokenizer:
|
| 374 |
+
tok = cls()
|
| 375 |
+
data = json.loads(Path(path).read_text())
|
| 376 |
+
assert data["vocab_size"] == tok.vocab_size, (
|
| 377 |
+
f"Vocab mismatch: file={data['vocab_size']}, current={tok.vocab_size}"
|
| 378 |
+
)
|
| 379 |
+
return tok
|