# PearlLeeStudio's picture
# Initial release: TheArtist chord-generation paper companion
# d206246 verified
"""Chord sequence tokenizer for Music Transformer training.
Vocabulary (~350 tokens):
[PAD]=0, [BOS]=1, [EOS]=2, [BAR]=3
[KEY:Cmaj] ... [KEY:Bmin] (24 keys)
[TIME:4/4] ... [TIME:5/4] (5 time sigs)
[GENRE:jazz] ... [GENRE:none] (6 genres)
Cmaj, Cm, C7, ... B7b13 (12 roots x 26 qualities = 312 chords)
"""
from __future__ import annotations
import json
from pathlib import Path
# The twelve canonical root spellings, in ascending semitone order starting
# at C (jazz convention: flats are preferred, F# being the one exception).
ROOTS = [
    "C", "Db", "D", "Eb", "E", "F",
    "F#", "G", "Ab", "A", "Bb", "B",
]

# Alternative root spellings -> canonical entry of ROOTS.
ROOT_ALIASES: dict[str, str] = {
    # Enharmonic spellings
    "C#": "Db",
    "D#": "Eb",
    "E#": "F",
    "Fb": "E",
    "G#": "Ab",
    "A#": "Bb",
    "B#": "C",
    "Cb": "B",
    "Gb": "F#",
    # Lowercase naturals
    "c": "C",
    "d": "D",
    "e": "E",
    "f": "F",
    "g": "G",
    "a": "A",
    "b": "B",
    # Lowercase accidentals
    "c#": "Db",
    "db": "Db",
    "d#": "Eb",
    "eb": "Eb",
    "f#": "F#",
    "gb": "F#",
    "g#": "Ab",
    "ab": "Ab",
    "a#": "Bb",
    "bb": "Bb",
    "cb": "B",
    "fb": "E",
}

# Every chord quality that has its own token; anything else must be mapped
# onto one of these.
QUALITIES = [
    # Triads and sevenths
    "maj", "m", "7", "maj7", "m7", "m7b5", "dim7", "dim", "aug",
    # Suspensions, sixths and extensions
    "sus4", "sus2", "6", "m6", "9", "m9", "maj9", "11", "m11",
    # Thirteenths, added ninth, minor-major and altered dominants
    "13", "m13", "add9", "mMaj7", "7b9", "7#9", "7#11", "7b13",
]

# Alternative quality spellings -> canonical entry of QUALITIES.
_QUALITY_ALIASES: dict[str, str] = {
    # Major
    "major": "maj",
    "M": "maj",
    # Minor
    "min": "m",
    "minor": "m",
    "-": "m",
    "mi": "m",
    # Dominant 7
    "dom7": "7",
    "dom": "7",
    # Major 7
    "^7": "maj7",
    "M7": "maj7",
    "Maj7": "maj7",
    "major7": "maj7",
    "j7": "maj7",
    "^": "maj7",
    "delta": "maj7",
    # Minor 7
    "min7": "m7",
    "-7": "m7",
    "mi7": "m7",
    # Half-diminished
    "hdim7": "m7b5",
    "hdim": "m7b5",
    "h7": "m7b5",
    "%7": "m7b5",
    "%": "m7b5",
    # Diminished
    "o": "dim",
    "o7": "dim7",
    # Augmented
    "+": "aug",
    # Suspended
    "sus": "sus4",
    # 6th
    "min6": "m6",
    "-6": "m6",
    # 9th
    "min9": "m9",
    "-9": "m9",
    "M9": "maj9",
    "^9": "maj9",
    "Maj9": "maj9",
    # 11th
    "min11": "m11",
    "-11": "m11",
    # 13th
    "min13": "m13",
    "-13": "m13",
    # Minor-major 7
    "minmaj7": "mMaj7",
    "-^7": "mMaj7",
    "mM7": "mMaj7",
    # Altered dominants (approximated by the flat-nine dominant)
    "7alt": "7b9",
}

# Key names: every root in major, then every root in minor.
MAJOR_KEYS = [root + "maj" for root in ROOTS]
MINOR_KEYS = [root + "min" for root in ROOTS]
ALL_KEYS = [*MAJOR_KEYS, *MINOR_KEYS]

# Remaining metadata vocabularies.
TIME_SIGS = [
    "4/4",
    "3/4",
    "6/8",
    "2/4",
    "5/4",
]
GENRES = [
    "jazz",
    "pop",
    "rock",
    "blues",
    "bossa",
]
class ChordTokenizer:
    """Deterministic tokenizer for chord sequences.

    The vocabulary is rebuilt from the module-level tables on every
    construction, always in the same order: the four special tokens
    (``[PAD]``, ``[BOS]``, ``[EOS]``, ``[BAR]``), one ``[KEY:...]`` token per
    entry of ``ALL_KEYS``, one ``[TIME:...]`` token per entry of
    ``TIME_SIGS``, one ``[GENRE:...]`` token per entry of ``GENRES`` followed
    by ``[GENRE:none]``, and finally every ``{root}{quality}`` combination of
    ``ROOTS`` and ``QUALITIES``.
    """

    # IDs of the special tokens; they match the first four slots assigned in
    # ``_build_vocab``.
    PAD = 0
    BOS = 1
    EOS = 2
    BAR = 3

    def __init__(self) -> None:
        """Create the tokenizer and populate both lookup tables."""
        self.token2id: dict[str, int] = {}
        self.id2token: dict[int, str] = {}
        self._build_vocab()

    # ------------------------------------------------------------------
    # Vocab construction
    # ------------------------------------------------------------------
    def _build_vocab(self) -> None:
        """Fill ``token2id`` / ``id2token`` in the fixed vocabulary order."""
        tokens: list[str] = ["[PAD]", "[BOS]", "[EOS]", "[BAR]"]
        tokens.extend(f"[KEY:{key}]" for key in ALL_KEYS)
        tokens.extend(f"[TIME:{ts}]" for ts in TIME_SIGS)
        tokens.extend(f"[GENRE:{genre}]" for genre in GENRES)
        tokens.append("[GENRE:none]")
        tokens.extend(
            f"{root}{quality}" for root in ROOTS for quality in QUALITIES
        )
        for i, tok in enumerate(tokens):
            self.token2id[tok] = i
            self.id2token[i] = tok

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the vocabulary."""
        return len(self.token2id)

    @property
    def pad_id(self) -> int:
        """ID of the ``[PAD]`` token."""
        return self.PAD

    @property
    def bos_id(self) -> int:
        """ID of the ``[BOS]`` token."""
        return self.BOS

    @property
    def eos_id(self) -> int:
        """ID of the ``[EOS]`` token."""
        return self.EOS

    @property
    def bar_id(self) -> int:
        """ID of the ``[BAR]`` token."""
        return self.BAR

    # ------------------------------------------------------------------
    # Encoding helpers
    # ------------------------------------------------------------------
    def encode_chord(self, chord_str: str) -> int | None:
        """Return the ID of *chord_str* after normalization, or ``None``.

        ``None`` is returned when the chord cannot be normalized (no-chord
        markers, unknown roots) or the normalized token is not in the vocab.
        """
        token = self.normalize_chord(chord_str)
        return self.token2id.get(token) if token else None

    def encode_key(self, key_str: str) -> int | None:
        """Return the ID of ``[KEY:{key_str}]``, or ``None`` if unknown."""
        return self.token2id.get(f"[KEY:{key_str}]")

    def encode_time_sig(self, ts: str) -> int | None:
        """Return the ID of ``[TIME:{ts}]``, or ``None`` if unknown."""
        return self.token2id.get(f"[TIME:{ts}]")

    def encode_genre(self, genre: str) -> int | None:
        """Return the ID of ``[GENRE:{genre}]``, or ``None`` if unknown."""
        return self.token2id.get(f"[GENRE:{genre}]")

    def encode_sequence(self, song: dict) -> list[int]:
        """Encode a unified song dict to a token-ID sequence.

        Expected *song* format::

            {
                "key": "Cmaj",
                "time_signature": "4/4",
                "genre": "jazz",
                "bars": [["Cmaj7", "Am7"], ["Dm7", "G7"], ...]
            }

        Missing ``key`` / ``time_signature`` / ``genre`` entries default to
        ``"Cmaj"`` / ``"4/4"`` / ``"none"``.  A header value that does not map
        to a token is skipped, as is any chord that cannot be encoded.  The
        result always starts with ``BOS`` and ends with ``EOS``; every bar
        contributes one ``BAR`` token followed by its encodable chords.
        """
        ids: list[int] = [self.BOS]

        # Header tokens, each emitted only when it maps to a known token.
        header = (
            self.encode_key(song.get("key", "Cmaj")),
            self.encode_time_sig(song.get("time_signature", "4/4")),
            self.encode_genre(song.get("genre", "none")),
        )
        ids.extend(tid for tid in header if tid is not None)

        for bar in song.get("bars", []):
            ids.append(self.BAR)
            for chord in bar:
                cid = self.encode_chord(chord)
                if cid is not None:
                    ids.append(cid)

        ids.append(self.EOS)
        return ids

    def decode(self, ids: list[int]) -> list[str]:
        """Map IDs back to token strings; unknown IDs become ``"[UNK]"``."""
        return [self.id2token.get(i, "[UNK]") for i in ids]

    # ------------------------------------------------------------------
    # Chord normalization
    # ------------------------------------------------------------------
    @staticmethod
    def normalize_root(root: str) -> str | None:
        """Normalize a root note name to its canonical ``ROOTS`` spelling.

        Tries, in order: an exact ``ROOTS`` match, an exact ``ROOT_ALIASES``
        match, then the same two lookups with the first letter upper-cased.
        Returns ``None`` when nothing matches (including the empty string).
        """
        if root in ROOTS:
            return root
        if root in ROOT_ALIASES:
            return ROOT_ALIASES[root]
        # Try capitalize first letter
        cap = root[0].upper() + root[1:] if len(root) > 1 else root.upper()
        if cap in ROOTS:
            return cap
        if cap in ROOT_ALIASES:
            return ROOT_ALIASES[cap]
        return None

    @staticmethod
    def normalize_chord(chord_str: str) -> str | None:
        """Normalize any chord notation to ``{Root}{quality}`` in our vocab.

        Handles plain symbols (``"Bb7"``), slash chords (the bass after ``/``
        is dropped) and the colon form ``Root:Quality``.  Surrounding
        whitespace is ignored.

        Returns ``None`` for empty input, for no-chord markers such as
        ``"N"`` or ``"N.C."``, when nothing is left once the slash bass is
        removed, or when the root cannot be recognized.
        """
        if not chord_str:
            return None
        # Stray whitespace would otherwise leak into the quality string and
        # defeat the exact-match lookups.
        chord_str = chord_str.strip()
        if not chord_str or chord_str in (
            "N", "NC", "N.C.", "X", "x",
            "pause", "silence", "&pause", "end",
        ):
            return None

        # Strip slash-chord bass
        chord_str = chord_str.split("/")[0].strip()
        if not chord_str:
            # e.g. "/C": there is no chord symbol in front of the bass note.
            return None

        if ":" in chord_str:
            # Colon format Root:Quality (any /bass was already stripped).
            root_part, qual_part = chord_str.split(":", 1)
            root_part = root_part.strip()
            qual_part = qual_part.strip()
        else:
            root_part = chord_str[0]
            qual_part = chord_str[1:]
            # A leading accidental belongs to the root, not the quality.
            if qual_part and qual_part[0] in ("b", "#"):
                root_part += qual_part[0]
                qual_part = qual_part[1:]

        norm_root = ChordTokenizer.normalize_root(root_part)
        if norm_root is None:
            return None
        quality = ChordTokenizer._normalize_quality(qual_part)
        if quality is None or quality not in QUALITIES:
            return None
        return f"{norm_root}{quality}"

    @staticmethod
    def _normalize_quality(q: str) -> str | None:
        """Map various quality notations to our canonical set.

        Lookup order: empty string (plain major triad), exact ``QUALITIES``
        hit, exact alias, case-insensitive alias, then heuristics:

        * ``7...``   -> the matching altered dominant, otherwise ``"7"``
        * ``maj...`` -> ``"maj7"`` / ``"maj9"`` / ``"6"`` / ``"maj"``
        * ``m...`` or ``-...`` -> the closest minor quality
        * ``13...``, ``11...``, ``9...``, ``6...`` -> that base quality

        Anything still unmatched is approximated as ``"maj"``.
        """
        if not q:
            return "maj"
        # Direct hit
        if q in QUALITIES:
            return q
        # Alias table
        if q in _QUALITY_ALIASES:
            return _QUALITY_ALIASES[q]
        # Case-insensitive alias search
        lowered = q.lower()
        for alias, canon in _QUALITY_ALIASES.items():
            if lowered == alias.lower():
                return canon

        # ---- Heuristic fallbacks for unusual notations ----
        # Altered dominants, including the digit-first spellings
        # ("79b" -> 7b9, "79#" -> 7#9, ...).
        if q.startswith("7"):
            tail = q[1:]
            if "b9" in tail or "9b" in tail:
                return "7b9"
            if "#9" in tail or "9#" in tail:
                return "7#9"
            if "#11" in tail or "11#" in tail:
                return "7#11"
            if "b13" in tail or "13b" in tail:
                return "7b13"
            # Still a dominant chord ("7sus4", "7#5", ...): keep the seventh
            # rather than degrading it to a major triad.
            return "7"

        # Compound major qualities.  This must run before the minor branch:
        # "maj..." also starts with "m" and would otherwise be turned into a
        # minor chord ("maj7#11" -> "m7").
        if lowered.startswith("maj"):
            rest = q[3:]
            if "7" in rest:
                return "maj7"
            if "9" in rest or "11" in rest or "13" in rest:
                # "maj9" is the richest major extension in the vocabulary.
                return "maj9"
            if "6" in rest:
                return "6"
            return "maj"

        # Compound minor qualities
        if q.startswith("m") or q.startswith("-"):
            inner = q.lstrip("m").lstrip("-")
            if "7" in inner and ("b5" in inner or "b5" in q):
                return "m7b5"
            if "7" in inner:
                return "m7"
            if "9" in inner:
                return "m9"
            if "11" in inner:
                return "m11"
            if "13" in inner:
                return "m13"
            if "6" in inner:
                return "m6"
            return "m"

        # Extended chords with an unrecognized suffix ("9sus4", "13#11",
        # "69", ...) collapse to their base quality.  Bare "9", "6", ... never
        # reach this point because they are direct QUALITIES hits.
        for base in ("13", "11", "9", "6"):
            if q.startswith(base):
                return base

        # If nothing matched, approximate as major
        return "maj"

    # ------------------------------------------------------------------
    # Transposition
    # ------------------------------------------------------------------
    def transpose_chord_token(self, token: str, semitones: int) -> str | None:
        """Transpose a chord token string by *semitones*.

        Returns ``None`` for an empty token, for bracketed (non-chord)
        tokens, and when the root cannot be recognized.  The quality suffix
        is carried over unchanged.
        """
        if not token or token.startswith("["):
            return None
        root = token[0]
        rest = token[1:]
        # A leading accidental belongs to the root.
        if rest and rest[0] in ("b", "#"):
            root += rest[0]
            rest = rest[1:]
        norm_root = self.normalize_root(root)
        if norm_root is None:
            return None
        new_root = ROOTS[(ROOTS.index(norm_root) + semitones) % 12]
        return f"{new_root}{rest}"

    def transpose_key_token(self, token: str, semitones: int) -> str:
        """Transpose a key token like ``[KEY:Cmaj]``.

        The token is returned unchanged when its body does not end in
        ``maj`` / ``min`` or its root cannot be recognized.
        """
        inner = token[5:-1]  # strip [KEY: and ]
        if inner.endswith("maj"):
            root, mode = inner[:-3], "maj"
        elif inner.endswith("min"):
            root, mode = inner[:-3], "min"
        else:
            return token
        norm = self.normalize_root(root)
        if norm is None:
            return token
        new_root = ROOTS[(ROOTS.index(norm) + semitones) % 12]
        return f"[KEY:{new_root}{mode}]"

    def transpose_sequence(self, ids: list[int], semitones: int) -> list[int]:
        """Transpose every chord & key token in *ids* by *semitones*.

        Always returns a new list.  Special, time-signature and genre tokens
        are copied through, as is any ID that is unknown or whose transposed
        token is not in the vocabulary.
        """
        if semitones % 12 == 0:
            # Whole octaves leave every pitch class where it is.
            return list(ids)
        out: list[int] = []
        for tid in ids:
            tok = self.id2token.get(tid)
            if tok is None:
                out.append(tid)
            elif tok.startswith("[KEY:"):
                new = self.transpose_key_token(tok, semitones)
                out.append(self.token2id.get(new, tid))
            elif tok.startswith("[") or tid <= self.BAR:
                out.append(tid)
            else:
                new = self.transpose_chord_token(tok, semitones)
                out.append(
                    self.token2id[new] if new and new in self.token2id else tid
                )
        return out

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def save(self, path: str | Path) -> None:
        """Write ``token2id`` and ``vocab_size`` to *path* as UTF-8 JSON."""
        payload = {
            "token2id": self.token2id,
            "vocab_size": self.vocab_size,
        }
        Path(path).write_text(
            json.dumps(payload, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )

    @classmethod
    def load(cls, path: str | Path) -> ChordTokenizer:
        """Build a tokenizer and check it against the file written by ``save``.

        The vocabulary itself is always rebuilt from the module tables; the
        file is only used to confirm that it describes the same vocabulary.

        Raises:
            ValueError: if the stored ``vocab_size`` differs from the current
                one, or the stored ``token2id`` (when present) is not
                identical to the current mapping.
            KeyError: if the file has no ``vocab_size`` entry.
        """
        tok = cls()
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        # Raised explicitly rather than asserted: assertions are removed when
        # Python runs with -O, which would silently disable this check.
        if data["vocab_size"] != tok.vocab_size:
            raise ValueError(
                f"Vocab mismatch: file={data['vocab_size']}, current={tok.vocab_size}"
            )
        saved_mapping = data.get("token2id")
        if saved_mapping is not None and saved_mapping != tok.token2id:
            # Same size but different token/ID assignment would make stored
            # ID sequences decode to the wrong chords.
            raise ValueError(
                "Vocab mismatch: token2id in file differs from current vocabulary"
            )
        return tok