"""Music Transformer with relative attention for chord generation.
Architecture: autoregressive Transformer decoder with learned relative
position embeddings (Shaw et al., 2018). Relative attention logits are
computed with a direct index-gather rather than the memory-efficient
skewing trick of Huang et al. (2018).
Default config (~25M params):
d_model=512, n_heads=8, d_ff=2048, n_layers=8
"""
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class RelativeMultiHeadAttention(nn.Module):
"""Multi-head self-attention with relative position bias."""
def __init__(
self,
d_model: int,
n_heads: int,
max_seq_len: int,
dropout: float = 0.1,
) -> None:
super().__init__()
assert d_model % n_heads == 0
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.scale = math.sqrt(self.d_k)
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
        # Learnable relative position embeddings: relative offsets j - i lie in
        # [-(max_seq_len - 1), max_seq_len - 1], i.e. 2 * max_seq_len - 1 entries.
self.max_seq_len = max_seq_len
self.rel_emb = nn.Embedding(2 * max_seq_len - 1, self.d_k)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
"""
Args:
x: (B, L, D)
mask: (L, L) bool — True = masked (don't attend)
Returns:
(B, L, D)
"""
B, L, _ = x.shape
H, dk = self.n_heads, self.d_k
Q = self.w_q(x).view(B, L, H, dk).transpose(1, 2) # (B, H, L, dk)
K = self.w_k(x).view(B, L, H, dk).transpose(1, 2)
V = self.w_v(x).view(B, L, H, dk).transpose(1, 2)
# Content attention: Q K^T
content = torch.matmul(Q, K.transpose(-2, -1)) # (B, H, L, L)
# Relative position attention: Q R^T via efficient gather
rel = self._relative_attention(Q, L) # (B, H, L, L)
attn = (content + rel) / self.scale
if mask is not None:
attn = attn.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))
attn = self.dropout(F.softmax(attn, dim=-1))
out = torch.matmul(attn, V) # (B, H, L, dk)
out = out.transpose(1, 2).contiguous().view(B, L, -1)
return self.w_o(out)
def _relative_attention(self, Q: torch.Tensor, L: int) -> torch.Tensor:
"""Compute Q @ R^T using relative position embeddings.
Uses the index-gather approach: for each (i, j) pair, the relative
position is j - i, shifted to a non-negative index.
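        Worked example (illustrative): with max_seq_len=4 and L=3,
            rel_idx = [[3, 4, 5],
                       [2, 3, 4],
                       [1, 2, 3]]
        i.e. the diagonal (relative position 0) maps to embedding index
        max_seq_len - 1 = 3.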
"""
device = Q.device
# Relative position indices: rel[i,j] = j - i + max_seq_len - 1
positions = torch.arange(L, device=device)
rel_idx = positions.unsqueeze(0) - positions.unsqueeze(1) + self.max_seq_len - 1
rel_idx = rel_idx.clamp(0, 2 * self.max_seq_len - 2)
R = self.rel_emb(rel_idx) # (L, L, dk)
        # Q: (B, H, L, dk), R: (L, L, dk) → relative scores of shape (B, H, L, L).
        # Fold the head dim into the batch dim, then contract the dk axis so that
        # rel_score[b, i, j] = dot(Q_flat[b, i], R[i, j]).
        BH = Q.shape[0] * Q.shape[1]
        Q_flat = Q.reshape(BH, L, self.d_k)  # (BH, L, dk)
        rel_score = torch.einsum("bld,lsd->bls", Q_flat, R)  # (BH, L, L)
return rel_score.view(Q.shape[0], Q.shape[1], L, L)
class TransformerBlock(nn.Module):
"""Pre-norm Transformer decoder block."""
def __init__(
self,
d_model: int,
n_heads: int,
d_ff: int,
max_seq_len: int,
dropout: float = 0.1,
) -> None:
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = RelativeMultiHeadAttention(d_model, n_heads, max_seq_len, dropout)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout),
)
self.drop = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
x = x + self.drop(self.attn(self.norm1(x), mask))
x = x + self.ffn(self.norm2(x))
return x
class MusicTransformer(nn.Module):
"""Autoregressive Music Transformer for chord generation."""
def __init__(
self,
vocab_size: int,
d_model: int = 512,
n_heads: int = 8,
d_ff: int = 2048,
n_layers: int = 8,
max_seq_len: int = 512,
dropout: float = 0.1,
pad_id: int = 0,
) -> None:
super().__init__()
self.d_model = d_model
self.max_seq_len = max_seq_len
self.pad_id = pad_id
self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
self.drop = nn.Dropout(dropout)
self.layers = nn.ModuleList([
TransformerBlock(d_model, n_heads, d_ff, max_seq_len, dropout)
for _ in range(n_layers)
])
self.norm = nn.LayerNorm(d_model)
self.out_proj = nn.Linear(d_model, vocab_size, bias=False)
# Weight tying (embedding ↔ output projection)
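        # Both token_emb.weight and out_proj.weight have shape (vocab_size, d_model),
        # so the assignment below shares a single parameter matrix between them.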
self.out_proj.weight = self.token_emb.weight
self._init_weights()
def _init_weights(self) -> None:
for name, p in self.named_parameters():
if p.dim() > 1 and "token_emb" not in name:
nn.init.xavier_uniform_(p)
# Embedding std=1/sqrt(d_model) so that after *sqrt(d_model) scaling
# inputs have unit variance, and weight-tied output logits stay small
nn.init.normal_(self.token_emb.weight, mean=0.0, std=self.d_model ** -0.5)
@staticmethod
def _causal_mask(L: int, device: torch.device) -> torch.Tensor:
"""Upper-triangular causal mask (True = masked)."""
return torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
"""
Args:
input_ids: (B, L) token IDs
Returns:
logits: (B, L, vocab_size)
"""
B, L = input_ids.shape
x = self.token_emb(input_ids) * math.sqrt(self.d_model)
x = self.drop(x)
mask = self._causal_mask(L, input_ids.device)
for layer in self.layers:
x = layer(x, mask)
return self.out_proj(self.norm(x))
def count_parameters(self) -> int:
return sum(p.numel() for p in self.parameters() if p.requires_grad)
@torch.no_grad()
def generate(
self,
prompt_ids: torch.Tensor,
max_new_tokens: int = 64,
temperature: float = 1.0,
top_k: int = 0,
top_p: float = 0.9,
eos_id: int = 2,
repetition_penalty: float = 1.0,
no_repeat_ngram_size: int = 0,
ignore_repeat_token_ids: set[int] | None = None,
) -> torch.Tensor:
"""Autoregressive generation from a prompt.
Args:
prompt_ids: (1, L) token IDs including [BOS] and context.
max_new_tokens: maximum tokens to generate.
temperature: sampling temperature (lower = more deterministic).
top_k: keep only top-k logits (0 = disabled).
top_p: nucleus sampling threshold.
eos_id: stop token.
            repetition_penalty: penalize previously-seen tokens (HF/CTRL
                convention: positive logits are divided by this factor,
                negative logits multiplied by it). > 1.0 discourages
                repeats, 1.0 disables. Typical: 1.2–1.5.
no_repeat_ngram_size: ban candidate tokens that would complete
an n-gram already present in the current sequence (n =
this value). 0 disables. Typical: 3 for chord sequences.
ignore_repeat_token_ids: token ids exempt from the two repetition
controls above — e.g. [BAR] or other separators that
*should* recur. If None, no exemptions.
Returns:
(1, L') full sequence including prompt and generated tokens.
"""
self.eval()
ids = prompt_ids.clone()
exempt = ignore_repeat_token_ids or set()
for _ in range(max_new_tokens):
ctx = ids[:, -self.max_seq_len :]
logits = self(ctx)[:, -1, :] / max(temperature, 1e-8)
# Repetition penalty (HuggingFace-style): scale already-seen token
# logits so they are less attractive. Positive logits get divided,
            # negative logits get multiplied, so a seen token becomes less likely either way.
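            # Illustrative numbers: with repetition_penalty=1.25, a seen token's logit
            # of 2.0 becomes 1.6 and a logit of -1.0 becomes -1.25.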
if repetition_penalty != 1.0:
seen = set(ids[0].tolist()) - exempt
if seen:
idx = torch.tensor(list(seen), device=logits.device, dtype=torch.long)
vals = logits[0, idx]
vals = torch.where(
vals > 0,
vals / repetition_penalty,
vals * repetition_penalty,
)
logits[0, idx] = vals
# No-repeat n-gram: block any candidate token that would complete
# an n-gram already present earlier in the sequence.
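            # Illustrative numbers: with n=3 and ids [..., 12, 7, 30, 12, 7], the prefix
            # is (12, 7) and token 30 is banned, since the trigram (12, 7, 30) already occurred.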
if no_repeat_ngram_size > 0 and ids.shape[1] >= no_repeat_ngram_size:
n = no_repeat_ngram_size
seq = ids[0].tolist()
prefix = tuple(seq[-(n - 1):]) if n > 1 else ()
banned: set[int] = set()
for i in range(len(seq) - n + 1):
if tuple(seq[i : i + n - 1]) == prefix:
banned.add(seq[i + n - 1])
banned -= exempt
if banned:
bidx = torch.tensor(list(banned), device=logits.device, dtype=torch.long)
logits[0, bidx] = float("-inf")
# Top-k
if top_k > 0:
topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < topk_vals[:, -1:]] = float("-inf")
# Top-p (nucleus)
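            # Illustrative numbers: with top_p=0.9 and sorted probs [0.5, 0.3, 0.15, 0.05],
            # the cumulative mass *before* each token is [0.0, 0.5, 0.8, 0.95], so only
            # the last token (0.95 > 0.9) is removed; the top token is always kept.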
if 0 < top_p < 1.0:
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cum_probs = torch.cumsum(sorted_probs, dim=-1)
                remove = cum_probs - sorted_probs > top_p
sorted_logits[remove] = float("-inf")
logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
probs = F.softmax(logits, dim=-1)
next_id = torch.multinomial(probs, num_samples=1)
ids = torch.cat([ids, next_id], dim=-1)
if (next_id == eos_id).all():
break
return ids
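

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. The vocabulary size and the special
    # token ids below are illustrative placeholders, not values from the paper.
    torch.manual_seed(0)
    model = MusicTransformer(
        vocab_size=64, d_model=128, n_heads=4, d_ff=256, n_layers=2, max_seq_len=64
    )
    print(f"parameters: {model.count_parameters():,}")
    # Teacher-forced forward pass: logits at position t predict token t+1.
    dummy = torch.randint(3, 64, (2, 16))  # (B, L) batch of non-special token ids
    assert model(dummy).shape == (2, 16, 64)
    # Sampling from a short prompt; id 1 stands in for [BOS].
    prompt = torch.tensor([[1, 10, 22, 10]])
    out = model.generate(
        prompt, max_new_tokens=8, top_p=0.9, repetition_penalty=1.2, no_repeat_ngram_size=3
    )
    print("generated ids:", out[0].tolist())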