PearlLeeStudio committed
Commit f2e6b6d · verified · 1 parent: 8005576

Initial release: TheArtist chord-generation paper companion

Files changed (7)
  1. README.md +98 -0
  2. best.pt +3 -0
  3. config.json +25 -0
  4. eval_results.csv +8 -0
  5. model.py +294 -0
  6. tokenizer.json +356 -0
  7. tokenizer.py +379 -0
README.md ADDED
@@ -0,0 +1,98 @@
+ ---
+ license: other
+ library_name: pytorch
+ tags:
+ - music
+ - music-generation
+ - chord-generation
+ - symbolic-music
+ - music-transformer
+ - jazz
+ - pop
+ language:
+ - en
+ pipeline_tag: text-generation
+ ---
+
+ # TheArtist Music Transformer — F2 (Pop 5K Mix)
+
+ **Jazz-adapted chord model with a 5,000-sequence pop rehearsal buffer. Calibration point that the paper finds is dominated by F3 on every axis.**
+
+ One of six checkpoints released alongside the paper *Empirical Study of Pop and Jazz Mix Ratios for Genre-Adaptive Chord Generation* (Lee, 2026). See the collection overview at `PearlLeeStudio/TheArtist-MusicTransformer-pop-baseline`.
+
+ ## Model summary
+
+ | Field | Value |
+ |---|---|
+ | Architecture | Music Transformer with relative positional attention |
+ | Parameters | 25,661,440 |
+ | Vocabulary size | 351 tokens |
+ | Max sequence length | 256 |
+ | d_model / heads / FFN / layers | 512 / 8 / 2048 / 8 |
+ | Fine-tune resumed from | Phase 0 pop baseline |
+ | Best epoch | 4 |
+
+ ## Training data
+
+ All 1,513 jazz training sequences plus 5,000 pop rehearsal sequences (seed 42). Pop:jazz ≈ 3.3:1.
+
+ ## Evaluation (held-out per-genre test sets)
+
+ | Metric | Pop test | Jazz test |
+ |---|---:|---:|
+ | Top-1 accuracy | 84.07% | 79.90% |
+ | Top-5 accuracy | 97.04% | 92.14% |
+ | Perplexity | 1.75 | 2.33 |
+ | Δ vs. Phase 0 baseline | −0.17 | +7.04 |
+
+ F2 is dominated by F3 on every axis. It is released for reproducibility of the saturation curve described in the paper (see paper §6.1, §7.3) but is not the recommended choice for any operating point. Prefer F3 for the balanced setting, F1 for pop-leaning, or F4 for jazz-leaning.
+
+ ## Intended use
+
+ Reference checkpoint for replication and saturation-curve analysis. Not recommended as a default for chord-composition workflows.
+
+ ## Usage
+
+ ```python
+ import torch
+ from huggingface_hub import hf_hub_download
+ from model import MusicTransformer
+ from tokenizer import ChordTokenizer
+
+ ckpt_path = hf_hub_download(
+     repo_id="PearlLeeStudio/TheArtist-MusicTransformer-ft-pop67",
+     filename="best.pt",
+ )
+ tokenizer = ChordTokenizer()
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+
+ model = MusicTransformer(
+     vocab_size=tokenizer.vocab_size,
+     d_model=512, n_heads=8, d_ff=2048, n_layers=8,
+     max_seq_len=256, dropout=0.0, pad_id=tokenizer.pad_id,
+ )
+ model.load_state_dict(ckpt["model_state_dict"])
+ model.eval()
+ ```
+
+ ## Training-data licenses
+
+ | Dataset | License |
+ |---|---|
+ | Chordonomicon | Public (user-generated) |
+ | McGill Billboard | CC0 |
+ | Jazz Harmony Treebank | Public |
+ | JazzStandards (iReal Pro) | Community redistribution |
+ | Weimar Jazz Database | ODbL |
+ | JAAH | Research-use public |
+
+ ## Citation
+
+ ```bibtex
+ @misc{lee2026chordmix,
+   title = {Empirical Study of Pop and Jazz Mix Ratios for Genre-Adaptive Chord Generation},
+   author = {Lee, Jinju},
+   year = {2026},
+   eprint = {arXiv:XXXX.XXXXX}
+ }
+ ```
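
Note on the README's parameter count: the figure of 25,661,440 can be cross-checked against the layer shapes defined in `model.py`. The sketch below is only a back-of-the-envelope tally, assuming the README configuration (vocabulary 351, `max_seq_len=256`) and counting the tied embedding/output matrix once. The helper name `count_params` is illustrative and is not part of the repository.

```python
def count_params(vocab_size=351, d_model=512, n_heads=8, d_ff=2048,
                 n_layers=8, max_seq_len=256):
    """Tally trainable parameters from the layer shapes in model.py."""
    d_k = d_model // n_heads
    # Token embedding; the output projection shares this matrix (weight tying).
    embedding = vocab_size * d_model
    # Four d_model x d_model projections (Q, K, V, O), each with a bias vector.
    attention = 4 * (d_model * d_model + d_model)
    # Relative position table: 2 * max_seq_len - 1 offsets, d_k dimensions each.
    rel_emb = (2 * max_seq_len - 1) * d_k
    # Two LayerNorms per block, each with a weight and a bias vector.
    norms = 2 * (2 * d_model)
    # Feed-forward network: d_model -> d_ff -> d_model, both with biases.
    ffn = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
    per_layer = attention + rel_emb + norms + ffn
    final_norm = 2 * d_model
    return embedding + n_layers * per_layer + final_norm


print(count_params())  # 25661440
```

Because the relative position table scales with `max_seq_len`, the total depends on the sequence length as well as on the width and depth of the model.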
best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f9aebcae5294c331aab43c509af698062f9bc7e50fb06e92280ee93126491d7b
+ size 308077642
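
Note on the LFS pointer: in a Git LFS pointer file, `oid sha256:` is the SHA-256 digest of the real file's contents and `size` is its length in bytes, so a downloaded `best.pt` can be checked against both. The following is a minimal sketch of that check; the function name `matches_lfs_pointer` and the local path in the usage comment are assumptions for illustration.

```python
import hashlib
import os


def matches_lfs_pointer(path, expected_oid, expected_size):
    """Check a local file against the oid and size from a Git LFS pointer."""
    # Cheap check first: the byte length must match the pointer's size field.
    if os.path.getsize(path) != expected_size:
        return False
    # Hash in 1 MiB chunks so large checkpoints are not read into memory at once.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_oid


# Hypothetical usage, assuming best.pt was downloaded to the working directory:
# matches_lfs_pointer(
#     "best.pt",
#     "f9aebcae5294c331aab43c509af698062f9bc7e50fb06e92280ee93126491d7b",
#     308077642,
# )
```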
config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "run_name": "ft_jazz_pop67",
+   "resume_from": "checkpoints/phase0_pop_baseline/best.pt",
+   "pop_mix_count": 5000,
+   "epochs": 10,
+   "batch_size": 64,
+   "gradient_accumulation_steps": 2,
+   "lr": 2e-05,
+   "weight_decay": 0.01,
+   "warmup_epochs": 2,
+   "max_grad_norm": 1.0,
+   "d_model": 512,
+   "n_heads": 8,
+   "d_ff": 2048,
+   "n_layers": 8,
+   "max_seq_len": 256,
+   "dropout": 0.1,
+   "use_amp": true,
+   "checkpoint_every": 1,
+   "patience": 5,
+   "num_workers": 4,
+   "persistent_workers": true,
+   "prefetch_factor": 4,
+   "log_every_steps": 200
+ }
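
Note on batch settings: `batch_size` and `gradient_accumulation_steps` together imply an effective batch of 128 sequences per optimizer step. The sketch below works through that arithmetic for the training mix stated in the README (1,513 jazz plus 5,000 pop sequences). It assumes the last partial batch is kept and that a trailing partial accumulation group still triggers an optimizer step; neither detail is confirmed by the files in this commit, and `steps_per_epoch` is a hypothetical helper.

```python
import math


def steps_per_epoch(n_sequences, batch_size, accumulation_steps):
    """Return (micro_batches, optimizer_steps) for one pass over the data.

    Assumes the last partial batch is kept and that a trailing partial
    accumulation group still produces an optimizer step.
    """
    micro_batches = math.ceil(n_sequences / batch_size)
    optimizer_steps = math.ceil(micro_batches / accumulation_steps)
    return micro_batches, optimizer_steps


n_train = 1513 + 5000      # jazz + pop rehearsal sequences (from the README)
effective_batch = 64 * 2   # batch_size * gradient_accumulation_steps
micro, steps = steps_per_epoch(n_train, batch_size=64, accumulation_steps=2)

print(effective_batch)         # 128
print(micro, steps)            # 102 51
print(round(5000 / 1513, 1))   # 3.3, the pop:jazz ratio quoted in the README
```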
eval_results.csv ADDED
@@ -0,0 +1,8 @@
+ epoch,lr,train_loss,val_loss,val_ppl,val_top1,val_top5,pop_loss,pop_ppl,pop_top1,pop_top5,jazz_loss,jazz_ppl,jazz_top1,jazz_top5
+ 3,2.56e-04,0.7450,0.5703,1.77,83.99,96.81,0.5490,1.73,84.21,97.09,1.3893,4.01,72.86,86.51
+ 4,2.07e-04,0.6459,0.5660,1.76,84.03,96.05,0.5599,1.75,84.06,96.17,0.8482,2.34,79.91,91.46
+ 5,1.50e-04,0.6020,0.5750,1.78,83.83,96.01,0.5694,1.77,83.87,96.12,0.8305,2.29,80.28,91.86
+ 6,9.26e-05,0.5770,0.5843,1.79,83.67,95.97,0.5790,1.78,83.69,96.08,0.8298,2.29,80.33,91.73
+ 7,4.39e-05,0.5587,0.5926,1.81,83.54,95.94,0.5879,1.80,83.54,96.05,0.8339,2.30,80.19,91.77
+ 8,1.14e-05,0.5471,0.5983,1.82,83.40,96.78,0.5937,1.81,83.41,96.88,0.8365,2.31,80.09,92.58
+ 9,0.00e+00,0.5410,0.6005,1.82,83.38,96.78,0.5958,1.81,83.39,96.87,0.8374,2.31,80.12,92.62
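
Note on the loss and perplexity columns: perplexity is the exponential of the mean cross-entropy loss when the loss is measured in nats, and the `*_loss` / `*_ppl` pairs in this CSV appear consistent with that relation. The sketch below checks a few rows copied from the file; the tolerance of 0.01 only allows for the two-decimal rounding of the logged perplexities.

```python
import math

# (epoch, pop_loss, pop_ppl, jazz_loss, jazz_ppl) copied from eval_results.csv
rows = [
    (4, 0.5599, 1.75, 0.8482, 2.34),
    (5, 0.5694, 1.77, 0.8305, 2.29),
    (9, 0.5958, 1.81, 0.8374, 2.31),
]


def perplexity(loss):
    """Perplexity of a mean cross-entropy loss measured in nats."""
    return math.exp(loss)


for epoch, pop_loss, pop_ppl, jazz_loss, jazz_ppl in rows:
    # Each logged perplexity should equal exp(loss) up to rounding.
    assert abs(perplexity(pop_loss) - pop_ppl) < 0.01
    assert abs(perplexity(jazz_loss) - jazz_ppl) < 0.01
```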
model.py ADDED
@@ -0,0 +1,294 @@
+ """Music Transformer with relative attention for chord generation.
+
+ Architecture: Transformer decoder (autoregressive) with relative position
+ encoding (Shaw et al. 2018, efficient skewing from Huang et al. 2018).
+
+ Default config (~25M params):
+     d_model=512, n_heads=8, d_ff=2048, n_layers=8
+ """
+
+ from __future__ import annotations
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class RelativeMultiHeadAttention(nn.Module):
+     """Multi-head self-attention with relative position bias."""
+
+     def __init__(
+         self,
+         d_model: int,
+         n_heads: int,
+         max_seq_len: int,
+         dropout: float = 0.1,
+     ) -> None:
+         super().__init__()
+         assert d_model % n_heads == 0
+         self.n_heads = n_heads
+         self.d_k = d_model // n_heads
+         self.scale = math.sqrt(self.d_k)
+
+         self.w_q = nn.Linear(d_model, d_model)
+         self.w_k = nn.Linear(d_model, d_model)
+         self.w_v = nn.Linear(d_model, d_model)
+         self.w_o = nn.Linear(d_model, d_model)
+
+         # Learnable relative position embeddings: positions in [-max_len+1, max_len-1]
+         self.max_seq_len = max_seq_len
+         self.rel_emb = nn.Embedding(2 * max_seq_len - 1, self.d_k)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
+         """
+         Args:
+             x: (B, L, D)
+             mask: (L, L) bool — True = masked (don't attend)
+         Returns:
+             (B, L, D)
+         """
+         B, L, _ = x.shape
+         H, dk = self.n_heads, self.d_k
+
+         Q = self.w_q(x).view(B, L, H, dk).transpose(1, 2)  # (B, H, L, dk)
+         K = self.w_k(x).view(B, L, H, dk).transpose(1, 2)
+         V = self.w_v(x).view(B, L, H, dk).transpose(1, 2)
+
+         # Content attention: Q K^T
+         content = torch.matmul(Q, K.transpose(-2, -1))  # (B, H, L, L)
+
+         # Relative position attention: Q R^T via efficient gather
+         rel = self._relative_attention(Q, L)  # (B, H, L, L)
+
+         attn = (content + rel) / self.scale
+
+         if mask is not None:
+             attn = attn.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+         attn = self.dropout(F.softmax(attn, dim=-1))
+         out = torch.matmul(attn, V)  # (B, H, L, dk)
+         out = out.transpose(1, 2).contiguous().view(B, L, -1)
+         return self.w_o(out)
+
+     def _relative_attention(self, Q: torch.Tensor, L: int) -> torch.Tensor:
+         """Compute Q @ R^T using relative position embeddings.
+
+         Uses the index-gather approach: for each (i, j) pair, the relative
+         position is j - i, shifted to a non-negative index.
+         """
+         device = Q.device
+         # Relative position indices: rel[i,j] = j - i + max_seq_len - 1
+         positions = torch.arange(L, device=device)
+         rel_idx = positions.unsqueeze(0) - positions.unsqueeze(1) + self.max_seq_len - 1
+         rel_idx = rel_idx.clamp(0, 2 * self.max_seq_len - 2)
+
+         R = self.rel_emb(rel_idx)  # (L, L, dk)
+
+         # Q: (B, H, L, dk) R: (L, L, dk) → need (B, H, L, L)
+         # Reshape Q to (B*H, L, dk), bmm with R^T reshaped
+         BH = Q.shape[0] * Q.shape[1]
+         Q_flat = Q.reshape(BH, L, self.d_k)  # (BH, L, dk)
+
+         # For each query position i, we want dot(Q[i], R[i, :, :]) → (BH, L, L)
+         # R: (L, L, dk) → transpose last two → (L, dk, L)
+         # Then Q_flat[:, i, :] @ R[i, :, :].T for each i
+         # Efficient: einsum
+         rel_score = torch.einsum("bld,lsd->bls", Q_flat, R)  # (BH, L, L)
+         return rel_score.view(Q.shape[0], Q.shape[1], L, L)
+
+
+ class TransformerBlock(nn.Module):
+     """Pre-norm Transformer decoder block."""
+
+     def __init__(
+         self,
+         d_model: int,
+         n_heads: int,
+         d_ff: int,
+         max_seq_len: int,
+         dropout: float = 0.1,
+     ) -> None:
+         super().__init__()
+         self.norm1 = nn.LayerNorm(d_model)
+         self.attn = RelativeMultiHeadAttention(d_model, n_heads, max_seq_len, dropout)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.ffn = nn.Sequential(
+             nn.Linear(d_model, d_ff),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(d_ff, d_model),
+             nn.Dropout(dropout),
+         )
+         self.drop = nn.Dropout(dropout)
+
+     def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
+         x = x + self.drop(self.attn(self.norm1(x), mask))
+         x = x + self.ffn(self.norm2(x))
+         return x
+
+
+ class MusicTransformer(nn.Module):
+     """Autoregressive Music Transformer for chord generation."""
+
+     def __init__(
+         self,
+         vocab_size: int,
+         d_model: int = 512,
+         n_heads: int = 8,
+         d_ff: int = 2048,
+         n_layers: int = 8,
+         max_seq_len: int = 512,
+         dropout: float = 0.1,
+         pad_id: int = 0,
+     ) -> None:
+         super().__init__()
+         self.d_model = d_model
+         self.max_seq_len = max_seq_len
+         self.pad_id = pad_id
+
+         self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
+         self.drop = nn.Dropout(dropout)
+
+         self.layers = nn.ModuleList([
+             TransformerBlock(d_model, n_heads, d_ff, max_seq_len, dropout)
+             for _ in range(n_layers)
+         ])
+
+         self.norm = nn.LayerNorm(d_model)
+         self.out_proj = nn.Linear(d_model, vocab_size, bias=False)
+
+         # Weight tying (embedding ↔ output projection)
+         self.out_proj.weight = self.token_emb.weight
+
+         self._init_weights()
+
+     def _init_weights(self) -> None:
+         for name, p in self.named_parameters():
+             if p.dim() > 1 and "token_emb" not in name:
+                 nn.init.xavier_uniform_(p)
+         # Embedding std=1/sqrt(d_model) so that after *sqrt(d_model) scaling
+         # inputs have unit variance, and weight-tied output logits stay small
+         nn.init.normal_(self.token_emb.weight, mean=0.0, std=self.d_model ** -0.5)
+
+     @staticmethod
+     def _causal_mask(L: int, device: torch.device) -> torch.Tensor:
+         """Upper-triangular causal mask (True = masked)."""
+         return torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)
+
+     def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             input_ids: (B, L) token IDs
+         Returns:
+             logits: (B, L, vocab_size)
+         """
+         B, L = input_ids.shape
+         x = self.token_emb(input_ids) * math.sqrt(self.d_model)
+         x = self.drop(x)
+
+         mask = self._causal_mask(L, input_ids.device)
+         for layer in self.layers:
+             x = layer(x, mask)
+
+         return self.out_proj(self.norm(x))
+
+     def count_parameters(self) -> int:
+         return sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+     @torch.no_grad()
+     def generate(
+         self,
+         prompt_ids: torch.Tensor,
+         max_new_tokens: int = 64,
+         temperature: float = 1.0,
+         top_k: int = 0,
+         top_p: float = 0.9,
+         eos_id: int = 2,
+         repetition_penalty: float = 1.0,
+         no_repeat_ngram_size: int = 0,
+         ignore_repeat_token_ids: set[int] | None = None,
+     ) -> torch.Tensor:
+         """Autoregressive generation from a prompt.
+
+         Args:
+             prompt_ids: (1, L) token IDs including [BOS] and context.
+             max_new_tokens: maximum tokens to generate.
+             temperature: sampling temperature (lower = more deterministic).
+             top_k: keep only top-k logits (0 = disabled).
+             top_p: nucleus sampling threshold.
+             eos_id: stop token.
+             repetition_penalty: divide logits of previously-seen tokens by
+                 this factor (HF convention). > 1.0 discourages repeats.
+                 1.0 disables. Typical: 1.2–1.5.
+             no_repeat_ngram_size: ban candidate tokens that would complete
+                 an n-gram already present in the current sequence (n =
+                 this value). 0 disables. Typical: 3 for chord sequences.
+             ignore_repeat_token_ids: token ids exempt from the two repetition
+                 controls above — e.g. [BAR] or other separators that
+                 *should* recur. If None, no exemptions.
+
+         Returns:
+             (1, L') full sequence including prompt and generated tokens.
+         """
+         self.eval()
+         ids = prompt_ids.clone()
+         exempt = ignore_repeat_token_ids or set()
+
+         for _ in range(max_new_tokens):
+             ctx = ids[:, -self.max_seq_len :]
+             logits = self(ctx)[:, -1, :] / max(temperature, 1e-8)
+
+             # Repetition penalty (HuggingFace-style): scale already-seen token
+             # logits so they are less attractive. Positive logits get divided,
+             # negative logits get multiplied (stays "less attractive" either sign).
+             if repetition_penalty != 1.0:
+                 seen = set(ids[0].tolist()) - exempt
+                 if seen:
+                     idx = torch.tensor(list(seen), device=logits.device, dtype=torch.long)
+                     vals = logits[0, idx]
+                     vals = torch.where(
+                         vals > 0,
+                         vals / repetition_penalty,
+                         vals * repetition_penalty,
+                     )
+                     logits[0, idx] = vals
+
+             # No-repeat n-gram: block any candidate token that would complete
+             # an n-gram already present earlier in the sequence.
+             if no_repeat_ngram_size > 0 and ids.shape[1] >= no_repeat_ngram_size:
+                 n = no_repeat_ngram_size
+                 seq = ids[0].tolist()
+                 prefix = tuple(seq[-(n - 1):]) if n > 1 else ()
+                 banned: set[int] = set()
+                 for i in range(len(seq) - n + 1):
+                     if tuple(seq[i : i + n - 1]) == prefix:
+                         banned.add(seq[i + n - 1])
+                 banned -= exempt
+                 if banned:
+                     bidx = torch.tensor(list(banned), device=logits.device, dtype=torch.long)
+                     logits[0, bidx] = float("-inf")
+
+             # Top-k
+             if top_k > 0:
+                 topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < topk_vals[:, -1:]] = float("-inf")
+
+             # Top-p (nucleus)
+             if 0 < top_p < 1.0:
+                 sorted_logits, sorted_idx = torch.sort(logits, descending=True)
+                 cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                 remove = cum_probs - F.softmax(sorted_logits, dim=-1) > top_p
+                 sorted_logits[remove] = float("-inf")
+                 logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
+
+             probs = F.softmax(logits, dim=-1)
+             next_id = torch.multinomial(probs, num_samples=1)
+             ids = torch.cat([ids, next_id], dim=-1)
+
+             if (next_id == eos_id).all():
+                 break
+
+         return ids
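
Note on the repetition controls in `generate`: the two mechanisms described in its docstring can be illustrated without PyTorch. The sketch below restates them on plain Python lists, under the assumption that it mirrors the logic in the diff. The function names `apply_repetition_penalty` and `banned_next_tokens` are illustrative, and the exemption set (`ignore_repeat_token_ids`) is left out for brevity.

```python
def apply_repetition_penalty(logits, seen_ids, penalty):
    """Return a copy of logits with already-seen tokens made less attractive."""
    out = list(logits)
    for i in seen_ids:
        # Positive logits shrink toward zero; negative logits move further down.
        out[i] = out[i] / penalty if out[i] > 0 else out[i] * penalty
    return out


def banned_next_tokens(seq, n):
    """Tokens that would complete an n-gram already present in seq."""
    if n <= 0 or len(seq) < n:
        return set()
    # The last n-1 tokens form the prefix of the n-gram about to be completed.
    prefix = tuple(seq[-(n - 1):]) if n > 1 else ()
    banned = set()
    for i in range(len(seq) - n + 1):
        # Wherever the same prefix occurred earlier, ban the token that followed it.
        if tuple(seq[i:i + n - 1]) == prefix:
            banned.add(seq[i + n - 1])
    return banned


print(apply_repetition_penalty([2.0, -1.0, 0.5], {0, 1}, 2.0))  # [1.0, -2.0, 0.5]
print(banned_next_tokens([5, 7, 9, 5, 7], 3))                   # {9}
```

In the second call the sequence ends in `5, 7`, which earlier was followed by `9`, so emitting `9` again would repeat the trigram `5, 7, 9`.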
tokenizer.json ADDED
@@ -0,0 +1,356 @@
+ {
+   "token2id": {
+     "[PAD]": 0,
+     "[BOS]": 1,
+     "[EOS]": 2,
+     "[BAR]": 3,
+     "[KEY:Cmaj]": 4,
+     "[KEY:Dbmaj]": 5,
+     "[KEY:Dmaj]": 6,
+     "[KEY:Ebmaj]": 7,
+     "[KEY:Emaj]": 8,
+     "[KEY:Fmaj]": 9,
+     "[KEY:F#maj]": 10,
+     "[KEY:Gmaj]": 11,
+     "[KEY:Abmaj]": 12,
+     "[KEY:Amaj]": 13,
+     "[KEY:Bbmaj]": 14,
+     "[KEY:Bmaj]": 15,
+     "[KEY:Cmin]": 16,
+     "[KEY:Dbmin]": 17,
+     "[KEY:Dmin]": 18,
+     "[KEY:Ebmin]": 19,
+     "[KEY:Emin]": 20,
+     "[KEY:Fmin]": 21,
+     "[KEY:F#min]": 22,
+     "[KEY:Gmin]": 23,
+     "[KEY:Abmin]": 24,
+     "[KEY:Amin]": 25,
+     "[KEY:Bbmin]": 26,
+     "[KEY:Bmin]": 27,
+     "[TIME:4/4]": 28,
+     "[TIME:3/4]": 29,
+     "[TIME:6/8]": 30,
+     "[TIME:2/4]": 31,
+     "[TIME:5/4]": 32,
+     "[GENRE:jazz]": 33,
+     "[GENRE:pop]": 34,
+     "[GENRE:rock]": 35,
+     "[GENRE:blues]": 36,
+     "[GENRE:bossa]": 37,
+     "[GENRE:none]": 38,
+     "Cmaj": 39,
+     "Cm": 40,
+     "C7": 41,
+     "Cmaj7": 42,
+     "Cm7": 43,
+     "Cm7b5": 44,
+     "Cdim7": 45,
+     "Cdim": 46,
+     "Caug": 47,
+     "Csus4": 48,
+     "Csus2": 49,
+     "C6": 50,
+     "Cm6": 51,
+     "C9": 52,
+     "Cm9": 53,
+     "Cmaj9": 54,
+     "C11": 55,
+     "Cm11": 56,
+     "C13": 57,
+     "Cm13": 58,
+     "Cadd9": 59,
+     "CmMaj7": 60,
+     "C7b9": 61,
+     "C7#9": 62,
+     "C7#11": 63,
+     "C7b13": 64,
+     "Dbmaj": 65,
+     "Dbm": 66,
+     "Db7": 67,
+     "Dbmaj7": 68,
+     "Dbm7": 69,
+     "Dbm7b5": 70,
+     "Dbdim7": 71,
+     "Dbdim": 72,
+     "Dbaug": 73,
+     "Dbsus4": 74,
+     "Dbsus2": 75,
+     "Db6": 76,
+     "Dbm6": 77,
+     "Db9": 78,
+     "Dbm9": 79,
+     "Dbmaj9": 80,
+     "Db11": 81,
+     "Dbm11": 82,
+     "Db13": 83,
+     "Dbm13": 84,
+     "Dbadd9": 85,
+     "DbmMaj7": 86,
+     "Db7b9": 87,
+     "Db7#9": 88,
+     "Db7#11": 89,
+     "Db7b13": 90,
+     "Dmaj": 91,
+     "Dm": 92,
+     "D7": 93,
+     "Dmaj7": 94,
+     "Dm7": 95,
+     "Dm7b5": 96,
+     "Ddim7": 97,
+     "Ddim": 98,
+     "Daug": 99,
+     "Dsus4": 100,
+     "Dsus2": 101,
+     "D6": 102,
+     "Dm6": 103,
+     "D9": 104,
+     "Dm9": 105,
+     "Dmaj9": 106,
+     "D11": 107,
+     "Dm11": 108,
+     "D13": 109,
+     "Dm13": 110,
+     "Dadd9": 111,
+     "DmMaj7": 112,
+     "D7b9": 113,
+     "D7#9": 114,
+     "D7#11": 115,
+     "D7b13": 116,
+     "Ebmaj": 117,
+     "Ebm": 118,
+     "Eb7": 119,
+     "Ebmaj7": 120,
+     "Ebm7": 121,
+     "Ebm7b5": 122,
+     "Ebdim7": 123,
+     "Ebdim": 124,
+     "Ebaug": 125,
+     "Ebsus4": 126,
+     "Ebsus2": 127,
+     "Eb6": 128,
+     "Ebm6": 129,
+     "Eb9": 130,
+     "Ebm9": 131,
+     "Ebmaj9": 132,
+     "Eb11": 133,
+     "Ebm11": 134,
+     "Eb13": 135,
+     "Ebm13": 136,
+     "Ebadd9": 137,
+     "EbmMaj7": 138,
+     "Eb7b9": 139,
+     "Eb7#9": 140,
+     "Eb7#11": 141,
+     "Eb7b13": 142,
+     "Emaj": 143,
+     "Em": 144,
+     "E7": 145,
+     "Emaj7": 146,
+     "Em7": 147,
+     "Em7b5": 148,
+     "Edim7": 149,
+     "Edim": 150,
+     "Eaug": 151,
+     "Esus4": 152,
+     "Esus2": 153,
+     "E6": 154,
+     "Em6": 155,
+     "E9": 156,
+     "Em9": 157,
+     "Emaj9": 158,
+     "E11": 159,
+     "Em11": 160,
+     "E13": 161,
+     "Em13": 162,
+     "Eadd9": 163,
+     "EmMaj7": 164,
+     "E7b9": 165,
+     "E7#9": 166,
+     "E7#11": 167,
+     "E7b13": 168,
+     "Fmaj": 169,
+     "Fm": 170,
+     "F7": 171,
+     "Fmaj7": 172,
+     "Fm7": 173,
+     "Fm7b5": 174,
+     "Fdim7": 175,
+     "Fdim": 176,
+     "Faug": 177,
+     "Fsus4": 178,
+     "Fsus2": 179,
+     "F6": 180,
+     "Fm6": 181,
+     "F9": 182,
+     "Fm9": 183,
+     "Fmaj9": 184,
+     "F11": 185,
+     "Fm11": 186,
+     "F13": 187,
+     "Fm13": 188,
+     "Fadd9": 189,
+     "FmMaj7": 190,
+     "F7b9": 191,
+     "F7#9": 192,
+     "F7#11": 193,
+     "F7b13": 194,
+     "F#maj": 195,
+     "F#m": 196,
+     "F#7": 197,
+     "F#maj7": 198,
+     "F#m7": 199,
+     "F#m7b5": 200,
+     "F#dim7": 201,
+     "F#dim": 202,
+     "F#aug": 203,
+     "F#sus4": 204,
+     "F#sus2": 205,
+     "F#6": 206,
+     "F#m6": 207,
+     "F#9": 208,
+     "F#m9": 209,
+     "F#maj9": 210,
+     "F#11": 211,
+     "F#m11": 212,
+     "F#13": 213,
+     "F#m13": 214,
+     "F#add9": 215,
+     "F#mMaj7": 216,
+     "F#7b9": 217,
+     "F#7#9": 218,
+     "F#7#11": 219,
+     "F#7b13": 220,
+     "Gmaj": 221,
+     "Gm": 222,
+     "G7": 223,
+     "Gmaj7": 224,
+     "Gm7": 225,
+     "Gm7b5": 226,
+     "Gdim7": 227,
+     "Gdim": 228,
+     "Gaug": 229,
+     "Gsus4": 230,
+     "Gsus2": 231,
+     "G6": 232,
+     "Gm6": 233,
+     "G9": 234,
+     "Gm9": 235,
+     "Gmaj9": 236,
+     "G11": 237,
+     "Gm11": 238,
+     "G13": 239,
+     "Gm13": 240,
+     "Gadd9": 241,
+     "GmMaj7": 242,
+     "G7b9": 243,
+     "G7#9": 244,
+     "G7#11": 245,
+     "G7b13": 246,
+     "Abmaj": 247,
+     "Abm": 248,
+     "Ab7": 249,
+     "Abmaj7": 250,
+     "Abm7": 251,
+     "Abm7b5": 252,
+     "Abdim7": 253,
+     "Abdim": 254,
+     "Abaug": 255,
+     "Absus4": 256,
+     "Absus2": 257,
+     "Ab6": 258,
+     "Abm6": 259,
+     "Ab9": 260,
+     "Abm9": 261,
+     "Abmaj9": 262,
+     "Ab11": 263,
+     "Abm11": 264,
+     "Ab13": 265,
+     "Abm13": 266,
+     "Abadd9": 267,
+     "AbmMaj7": 268,
+     "Ab7b9": 269,
+     "Ab7#9": 270,
+     "Ab7#11": 271,
+     "Ab7b13": 272,
+     "Amaj": 273,
+     "Am": 274,
+     "A7": 275,
+     "Amaj7": 276,
+     "Am7": 277,
+     "Am7b5": 278,
+     "Adim7": 279,
+     "Adim": 280,
+     "Aaug": 281,
+     "Asus4": 282,
+     "Asus2": 283,
+     "A6": 284,
+     "Am6": 285,
+     "A9": 286,
+     "Am9": 287,
+     "Amaj9": 288,
+     "A11": 289,
+     "Am11": 290,
+     "A13": 291,
+     "Am13": 292,
+     "Aadd9": 293,
+     "AmMaj7": 294,
+     "A7b9": 295,
+     "A7#9": 296,
+     "A7#11": 297,
+     "A7b13": 298,
+     "Bbmaj": 299,
+     "Bbm": 300,
+     "Bb7": 301,
+     "Bbmaj7": 302,
+     "Bbm7": 303,
+     "Bbm7b5": 304,
+     "Bbdim7": 305,
+     "Bbdim": 306,
+     "Bbaug": 307,
+     "Bbsus4": 308,
+     "Bbsus2": 309,
+     "Bb6": 310,
+     "Bbm6": 311,
+     "Bb9": 312,
+     "Bbm9": 313,
+     "Bbmaj9": 314,
+     "Bb11": 315,
+     "Bbm11": 316,
+     "Bb13": 317,
+     "Bbm13": 318,
+     "Bbadd9": 319,
+     "BbmMaj7": 320,
+     "Bb7b9": 321,
+     "Bb7#9": 322,
+     "Bb7#11": 323,
+     "Bb7b13": 324,
+     "Bmaj": 325,
+     "Bm": 326,
+     "B7": 327,
+     "Bmaj7": 328,
+     "Bm7": 329,
+     "Bm7b5": 330,
+     "Bdim7": 331,
+     "Bdim": 332,
+     "Baug": 333,
+     "Bsus4": 334,
+     "Bsus2": 335,
+     "B6": 336,
+     "Bm6": 337,
+     "B9": 338,
+     "Bm9": 339,
+     "Bmaj9": 340,
+     "B11": 341,
+     "Bm11": 342,
+     "B13": 343,
+     "Bm13": 344,
+     "Badd9": 345,
+     "BmMaj7": 346,
+     "B7b9": 347,
+     "B7#9": 348,
+     "B7#11": 349,
+     "B7b13": 350
+   },
+   "vocab_size": 351
+ }
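
Note on the ID layout: the chord IDs in `tokenizer.json` follow a regular pattern. The 39 special, key, time-signature, and genre tokens come first, then 26 qualities for each of the 12 roots in order. The sketch below reproduces that mapping as a closed-form index, assuming the root and quality orderings shown in this file; `chord_id` and `CHORD_OFFSET` are hypothetical names used only for illustration.

```python
ROOTS = ["C", "Db", "D", "Eb", "E", "F", "F#", "G", "Ab", "A", "Bb", "B"]
QUALITIES = [
    "maj", "m", "7", "maj7", "m7", "m7b5", "dim7", "dim", "aug",
    "sus4", "sus2", "6", "m6", "9", "m9", "maj9", "11", "m11",
    "13", "m13", "add9", "mMaj7", "7b9", "7#9", "7#11", "7b13",
]

# 4 special tokens + 24 keys + 5 time signatures + 6 genre tokens precede the chords.
CHORD_OFFSET = 4 + 24 + 5 + 6


def chord_id(root, quality):
    """Token ID of a canonical chord under the root-major, quality-minor layout."""
    return CHORD_OFFSET + ROOTS.index(root) * len(QUALITIES) + QUALITIES.index(quality)


print(chord_id("C", "maj"))    # 39
print(chord_id("D", "m7"))     # 95
print(chord_id("G", "7"))      # 223
print(chord_id("B", "7b13"))   # 350
```

The last chord ID is 350, which matches the stated `vocab_size` of 351.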
tokenizer.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chord sequence tokenizer for Music Transformer training.
2
+
3
+ Vocabulary (~350 tokens):
4
+ [PAD]=0, [BOS]=1, [EOS]=2, [BAR]=3
5
+ [KEY:Cmaj] ... [KEY:Bmin] (24 keys)
6
+ [TIME:4/4] ... [TIME:5/4] (5 time sigs)
7
+ [GENRE:jazz] ... [GENRE:none] (6 genres)
8
+ Cmaj, Cm, C7, ... B7b13 (12 roots x 26 qualities = 312 chords)
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ from pathlib import Path
15
+
16
+ # Canonical root names (jazz convention: prefer flats)
17
+ ROOTS = ["C", "Db", "D", "Eb", "E", "F", "F#", "G", "Ab", "A", "Bb", "B"]
18
+
19
+ # Root name aliases for normalization
20
+ ROOT_ALIASES: dict[str, str] = {
21
+ "C#": "Db", "D#": "Eb", "E#": "F", "Fb": "E",
22
+ "G#": "Ab", "A#": "Bb", "B#": "C", "Cb": "B",
23
+ "Gb": "F#",
24
+ # Lowercase
25
+ "c": "C", "d": "D", "e": "E", "f": "F", "g": "G", "a": "A", "b": "B",
26
+ "c#": "Db", "db": "Db", "d#": "Eb", "eb": "Eb",
27
+ "f#": "F#", "gb": "F#", "g#": "Ab", "ab": "Ab",
28
+ "a#": "Bb", "bb": "Bb", "cb": "B", "fb": "E",
29
+ }
30
+
31
+ # Chord qualities in our vocabulary
32
+ QUALITIES = [
33
+ "maj", "m", "7", "maj7", "m7", "m7b5", "dim7", "dim", "aug",
34
+ "sus4", "sus2", "6", "m6", "9", "m9", "maj9", "11", "m11",
35
+ "13", "m13", "add9", "mMaj7", "7b9", "7#9", "7#11", "7b13",
36
+ ]
37
+
38
+ # Quality alias mapping → canonical quality
39
+ _QUALITY_ALIASES: dict[str, str] = {
40
+ # Major
41
+ "major": "maj", "M": "maj",
42
+ # Minor
43
+ "min": "m", "minor": "m", "-": "m", "mi": "m",
44
+ # Dominant 7
45
+ "dom7": "7", "dom": "7",
46
+ # Major 7
47
+ "^7": "maj7", "M7": "maj7", "Maj7": "maj7", "major7": "maj7",
48
+ "j7": "maj7", "^": "maj7", "delta": "maj7",
49
+ # Minor 7
50
+ "min7": "m7", "-7": "m7", "mi7": "m7",
51
+ # Half-diminished
52
+ "hdim7": "m7b5", "hdim": "m7b5", "h7": "m7b5",
53
+ "%7": "m7b5", "%": "m7b5",
54
+ # Diminished
55
+ "o": "dim", "o7": "dim7",
56
+ # Augmented
57
+ "+": "aug",
58
+ # Suspended
59
+ "sus": "sus4",
60
+ # 6th
61
+ "min6": "m6", "-6": "m6",
62
+ # 9th
63
+ "min9": "m9", "-9": "m9", "M9": "maj9", "^9": "maj9", "Maj9": "maj9",
64
+ # 11th
65
+ "min11": "m11", "-11": "m11",
66
+ # 13th
67
+ "min13": "m13", "-13": "m13",
68
+ # Minor-major 7
69
+ "minmaj7": "mMaj7", "-^7": "mMaj7", "mM7": "mMaj7",
70
+ # Altered dominants
71
+ "7alt": "7b9",
72
+ }
73
+
74
+ # Keys and metadata
75
+ MAJOR_KEYS = [f"{r}maj" for r in ROOTS]
76
+ MINOR_KEYS = [f"{r}min" for r in ROOTS]
77
+ ALL_KEYS = MAJOR_KEYS + MINOR_KEYS
78
+ TIME_SIGS = ["4/4", "3/4", "6/8", "2/4", "5/4"]
79
+ GENRES = ["jazz", "pop", "rock", "blues", "bossa"]
80
+
81
+
82
+ class ChordTokenizer:
83
+ """Deterministic tokenizer for chord sequences."""
84
+
85
+ PAD = 0
86
+ BOS = 1
87
+ EOS = 2
88
+ BAR = 3
89
+
90
+ def __init__(self) -> None:
91
+ self.token2id: dict[str, int] = {}
92
+ self.id2token: dict[int, str] = {}
93
+ self._build_vocab()
94
+
95
+ # ------------------------------------------------------------------
96
+ # Vocab construction
97
+ # ------------------------------------------------------------------
98
+
99
+ def _build_vocab(self) -> None:
100
+ tokens: list[str] = ["[PAD]", "[BOS]", "[EOS]", "[BAR]"]
101
+ for key in ALL_KEYS:
102
+ tokens.append(f"[KEY:{key}]")
103
+ for ts in TIME_SIGS:
104
+ tokens.append(f"[TIME:{ts}]")
105
+ for genre in GENRES:
106
+ tokens.append(f"[GENRE:{genre}]")
107
+ tokens.append("[GENRE:none]")
108
+ for root in ROOTS:
109
+ for quality in QUALITIES:
110
+ tokens.append(f"{root}{quality}")
111
+ for i, tok in enumerate(tokens):
112
+ self.token2id[tok] = i
113
+ self.id2token[i] = tok
114
+
115
+ @property
116
+ def vocab_size(self) -> int:
117
+ return len(self.token2id)
118
+
119
+ @property
120
+ def pad_id(self) -> int:
121
+ return self.PAD
122
+
123
+ @property
124
+ def bos_id(self) -> int:
125
+ return self.BOS
126
+
127
+ @property
128
+ def eos_id(self) -> int:
129
+ return self.EOS
130
+
131
+ @property
132
+ def bar_id(self) -> int:
133
+ return self.BAR
134
+
135
+ # ------------------------------------------------------------------
136
+ # Encoding helpers
137
+ # ------------------------------------------------------------------
138
+
139
+ def encode_chord(self, chord_str: str) -> int | None:
140
+ token = self.normalize_chord(chord_str)
141
+ return self.token2id.get(token) if token else None
142
+
143
+ def encode_key(self, key_str: str) -> int | None:
144
+ return self.token2id.get(f"[KEY:{key_str}]")
145
+
146
+ def encode_time_sig(self, ts: str) -> int | None:
147
+ return self.token2id.get(f"[TIME:{ts}]")
148
+
149
+ def encode_genre(self, genre: str) -> int | None:
150
+ return self.token2id.get(f"[GENRE:{genre}]")
151
+
152
+ def encode_sequence(self, song: dict) -> list[int]:
153
+ """Encode a unified song dict to a token-ID sequence.
154
+
155
+ Expected *song* format::
156
+
157
+ {
158
+ "key": "Cmaj",
159
+ "time_signature": "4/4",
160
+ "genre": "jazz",
161
+ "bars": [["Cmaj7", "Am7"], ["Dm7", "G7"], ...]
162
+ }
163
+ """
164
+ ids: list[int] = [self.BOS]
165
+
166
+ kid = self.encode_key(song.get("key", "Cmaj"))
167
+ if kid is not None:
168
+ ids.append(kid)
169
+
170
+ tid = self.encode_time_sig(song.get("time_signature", "4/4"))
171
+ if tid is not None:
172
+ ids.append(tid)
173
+
174
+ gid = self.encode_genre(song.get("genre", "none"))
175
+ if gid is not None:
176
+ ids.append(gid)
177
+
178
+ for bar in song.get("bars", []):
179
+ ids.append(self.BAR)
180
+ for chord in bar:
181
+ cid = self.encode_chord(chord)
182
+ if cid is not None:
183
+ ids.append(cid)
184
+
185
+ ids.append(self.EOS)
186
+ return ids
187
+
188
+ def decode(self, ids: list[int]) -> list[str]:
189
+ return [self.id2token.get(i, "[UNK]") for i in ids]
190
+
191
+ # ------------------------------------------------------------------
192
+ # Chord normalization
193
+ # ------------------------------------------------------------------
194
+
195
+ @staticmethod
196
+ def normalize_root(root: str) -> str | None:
197
+ """Normalize a root note name to canonical form."""
198
+ if root in ROOTS:
199
+ return root
200
+ if root in ROOT_ALIASES:
201
+ return ROOT_ALIASES[root]
202
+ # Try capitalize first letter
203
+ cap = root[0].upper() + root[1:] if len(root) > 1 else root.upper()
204
+ if cap in ROOTS:
205
+ return cap
206
+ if cap in ROOT_ALIASES:
207
+ return ROOT_ALIASES[cap]
208
+ return None
209
+
210
+ @staticmethod
211
+ def normalize_chord(chord_str: str) -> str | None:
212
+ """Normalize any chord notation to ``{Root}{quality}`` in our vocab."""
213
+ if not chord_str or chord_str in (
214
+ "N", "NC", "N.C.", "X", "x",
215
+ "pause", "silence", "&pause", "end",
216
+ ):
217
+ return None
218
+
219
+ # Strip slash-chord bass
220
+ if "/" in chord_str:
221
+ chord_str = chord_str.split("/")[0]
222
+
223
+ # Billboard colon format Root:Quality
224
+ if ":" in chord_str:
225
+ root_part, qual_part = chord_str.split(":", 1)
226
+ # qual_part may also have /bass — already stripped above
227
+ else:
228
+ root_part = chord_str[0]
229
+ qual_part = chord_str[1:]
230
+ if qual_part and qual_part[0] in ("b", "#"):
231
+ root_part += qual_part[0]
232
+ qual_part = qual_part[1:]
233
+
234
+ norm_root = ChordTokenizer.normalize_root(root_part)
235
+ if norm_root is None:
236
+ return None
237
+
238
+ quality = ChordTokenizer._normalize_quality(qual_part)
239
+ if quality is None or quality not in QUALITIES:
240
+ return None
241
+
242
+ return f"{norm_root}{quality}"
243
+
244
+ @staticmethod
245
+ def _normalize_quality(q: str) -> str | None:
246
+ """Map various quality notations to our canonical set."""
247
+ if not q:
248
+ return "maj"
249
+
250
+ # Direct hit
251
+ if q in QUALITIES:
252
+ return q
253
+
254
+ # Alias table
255
+ if q in _QUALITY_ALIASES:
256
+ return _QUALITY_ALIASES[q]
257
+
258
+ # Case-insensitive alias search
259
+ for alias, canon in _QUALITY_ALIASES.items():
260
+ if q.lower() == alias.lower():
261
+ return canon
262
+
263
+ # ---- Heuristic fallbacks for unusual notations ----
264
+
265
+ # WJazzD altered dominants: "79b" → 7b9, "79#" → 7#9, etc.
266
+ if q.startswith("7"):
267
+ tail = q[1:]
268
+ if "b9" in tail or "9b" in tail:
269
+ return "7b9"
270
+ if "#9" in tail or "9#" in tail:
271
+ return "7#9"
272
+ if "#11" in tail or "11#" in tail:
273
+ return "7#11"
274
+ if "b13" in tail or "13b" in tail:
275
+ return "7b13"
276
+
277
+ # Compound minor qualities
278
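+ if (q.startswith("m") and not q.startswith("maj")) or q.startswith("-"):  # keep unmatched major-family strings such as "maj13" out of the minor branch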
279
+ inner = q.lstrip("m").lstrip("-")
280
+ if "7" in inner and "b5" in inner:
281
+ return "m7b5"
282
+ if "7" in inner:
283
+ return "m7"
284
+ if "9" in inner:
285
+ return "m9"
286
+ if "11" in inner:
287
+ return "m11"
288
+ if "13" in inner:
289
+ return "m13"
290
+ if "6" in inner:
291
+ return "m6"
292
+ return "m"
293
+
294
+ # Bare numbers
295
+ if q in ("7",):
296
+ return "7"
297
+ if q in ("9",):
298
+ return "9"
299
+ if q in ("6",):
300
+ return "6"
301
+ if q in ("11",):
302
+ return "11"
303
+ if q in ("13",):
304
+ return "13"
305
+
306
+ # If nothing matched, approximate as major
307
+ return "maj"
308
+
309
+ # ------------------------------------------------------------------
310
+ # Transposition
311
+ # ------------------------------------------------------------------
312
+
313
+ def transpose_chord_token(self, token: str, semitones: int) -> str | None:
314
+ """Transpose a chord token string by *semitones*."""
315
+ if not token or token.startswith("["):
316
+ return None
317
+ root = token[0]
318
+ rest = token[1:]
319
+ if rest and rest[0] in ("b", "#"):
320
+ root += rest[0]
321
+ rest = rest[1:]
322
+ norm_root = self.normalize_root(root)
323
+ if norm_root is None:
324
+ return None
325
+ new_root = ROOTS[(ROOTS.index(norm_root) + semitones) % 12]
326
+ return f"{new_root}{rest}"
327
+
328
+ def transpose_key_token(self, token: str, semitones: int) -> str:
329
+ """Transpose a key token like ``[KEY:Cmaj]``."""
330
+ inner = token[5:-1] # strip [KEY: and ]
331
+ if inner.endswith("maj"):
332
+ root, mode = inner[:-3], "maj"
333
+ elif inner.endswith("min"):
334
+ root, mode = inner[:-3], "min"
335
+ else:
336
+ return token
337
+ norm = self.normalize_root(root)
338
+ if norm is None:
339
+ return token
340
+ new_root = ROOTS[(ROOTS.index(norm) + semitones) % 12]
341
+ return f"[KEY:{new_root}{mode}]"
342
+
343
+ def transpose_sequence(self, ids: list[int], semitones: int) -> list[int]:
344
+ """Transpose every chord & key token in *ids* by *semitones*."""
345
+ if semitones % 12 == 0:
346
+ return list(ids)
347
+ out: list[int] = []
348
+ for tid in ids:
349
+ tok = self.id2token.get(tid)
350
+ if tok is None:
351
+ out.append(tid)
352
+ elif tok.startswith("[KEY:"):
353
+ new = self.transpose_key_token(tok, semitones)
354
+ out.append(self.token2id.get(new, tid))
355
+ elif tok.startswith("[") or tid <= self.BAR:
356
+ out.append(tid)
357
+ else:
358
+ new = self.transpose_chord_token(tok, semitones)
359
+ out.append(self.token2id.get(new, tid))  # falls back to tid when new is None or not in the vocab
360
+ return out
361
+
362
+ # ------------------------------------------------------------------
363
+ # Persistence
364
+ # ------------------------------------------------------------------
365
+
366
+ def save(self, path: str | Path) -> None:
367
+ Path(path).write_text(json.dumps({
368
+ "token2id": self.token2id,
369
+ "vocab_size": self.vocab_size,
370
+ }, indent=2, ensure_ascii=False), encoding="utf-8")
371
+
372
+ @classmethod
373
+ def load(cls, path: str | Path) -> ChordTokenizer:
374
+ tok = cls()  # vocab is rebuilt by the constructor; the file is only checked for a matching vocab size
375
+ data = json.loads(Path(path).read_text(encoding="utf-8"))
376
+ if data["vocab_size"] != tok.vocab_size:  # explicit check; an assert would be skipped under python -O
377
+ msg = f"Vocab mismatch: file={data['vocab_size']}, current={tok.vocab_size}"
378
+ raise ValueError(msg)
379
+ return tok
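The transposition helpers in `tokenizer.py` rely on modular arithmetic over a 12-entry root list. The following is a minimal standalone sketch of that idea, not the released implementation. The `ROOTS` spelling below is an assumption, since the module-level `ROOTS` list is not shown in this part of the diff. `split_root` and `transpose` are hypothetical names used only for illustration, and the sketch skips the alias lookup that `normalize_root` performs.

```python
from __future__ import annotations

# Assumed stand-in for the module-level ROOTS list (not shown in this diff):
# one flat-spelled name per semitone, starting at C.
ROOTS = ["C", "Db", "D", "Eb", "E", "F", "Gb", "G", "Ab", "A", "Bb", "B"]


def split_root(token: str) -> tuple[str, str]:
    """Split a chord token into (root, quality); a b or # belongs to the root."""
    root, rest = token[:1], token[1:]
    if rest[:1] in ("b", "#"):
        root, rest = root + rest[0], rest[1:]
    return root, rest


def transpose(token: str, semitones: int) -> str | None:
    """Shift the root by *semitones*, wrapping around the octave."""
    root, quality = split_root(token)
    if root not in ROOTS:
        # Special tokens such as "[BAR]" and unknown spellings are left to the caller.
        return None
    new_root = ROOTS[(ROOTS.index(root) + semitones) % 12]
    return f"{new_root}{quality}"
```

Python's `%` returns a non-negative result for a positive modulus, so negative shifts wrap correctly without extra handling. Under these assumptions, `transpose("Bbm7", 3)` yields `"Dbm7"` and `transpose("Cmaj", -1)` yields `"Bmaj"`.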