| """Music Transformer with relative attention for chord generation. |
| |
| Architecture: Transformer decoder (autoregressive) with relative position |
| encoding (Shaw et al. 2018, efficient skewing from Huang et al. 2018). |
| |
| Default config (~25M params): |
| d_model=512, n_heads=8, d_ff=2048, n_layers=8 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class RelativeMultiHeadAttention(nn.Module):
    """Multi-head self-attention with relative position bias."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        max_seq_len: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.scale = math.sqrt(self.d_k)

        # Query/key/value/output projections, split into heads in forward().
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        # One d_k-dim embedding per relative offset in [-(max_seq_len-1), max_seq_len-1],
        # shared across heads.
        self.max_seq_len = max_seq_len
        self.rel_emb = nn.Embedding(2 * max_seq_len - 1, self.d_k)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """
        Args:
            x: (B, L, D)
            mask: (L, L) bool — True = masked (don't attend)
        Returns:
            (B, L, D)
        """
        B, L, _ = x.shape
        H, dk = self.n_heads, self.d_k

        # Project and reshape to (B, H, L, d_k).
        Q = self.w_q(x).view(B, L, H, dk).transpose(1, 2)
        K = self.w_k(x).view(B, L, H, dk).transpose(1, 2)
        V = self.w_v(x).view(B, L, H, dk).transpose(1, 2)

        # Content-content term: Q @ K^T, shape (B, H, L, L).
        content = torch.matmul(Q, K.transpose(-2, -1))

        # Content-position term from relative embeddings, shape (B, H, L, L).
        rel = self._relative_attention(Q, L)

        attn = (content + rel) / self.scale

        if mask is not None:
            # Broadcast the (L, L) mask over batch and heads.
            attn = attn.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))

        attn = self.dropout(F.softmax(attn, dim=-1))
        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)
        return self.w_o(out)

    def _relative_attention(self, Q: torch.Tensor, L: int) -> torch.Tensor:
        """Compute Q @ R^T using relative position embeddings.

        Uses the index-gather approach: for each (i, j) pair, the relative
        position is j - i, shifted to a non-negative index. For L = 3 the
        unshifted offsets are [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]].
        """
        device = Q.device

        # rel_idx[i, j] = (j - i) + (max_seq_len - 1), clamped to the table size.
        positions = torch.arange(L, device=device)
        rel_idx = positions.unsqueeze(0) - positions.unsqueeze(1) + self.max_seq_len - 1
        rel_idx = rel_idx.clamp(0, 2 * self.max_seq_len - 2)

        # R: (L, L, d_k), one embedding per (query position, key position) offset.
        R = self.rel_emb(rel_idx)

        # Fold batch and heads together so a single einsum handles every head.
        BH = Q.shape[0] * Q.shape[1]
        Q_flat = Q.reshape(BH, L, self.d_k)

        # rel_score[b, l, s] = Q_flat[b, l] . R[l, s]  ->  (B*H, L, L)
        rel_score = torch.einsum("bld,lsd->bls", Q_flat, R)
        return rel_score.view(Q.shape[0], Q.shape[1], L, L)


class TransformerBlock(nn.Module):
    """Pre-norm Transformer decoder block."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        max_seq_len: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = RelativeMultiHeadAttention(d_model, n_heads, max_seq_len, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        x = x + self.drop(self.attn(self.norm1(x), mask))
        x = x + self.ffn(self.norm2(x))
        return x


class MusicTransformer(nn.Module):
    """Autoregressive Music Transformer for chord generation."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        d_ff: int = 2048,
        n_layers: int = 8,
        max_seq_len: int = 512,
        dropout: float = 0.1,
        pad_id: int = 0,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.pad_id = pad_id

        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.drop = nn.Dropout(dropout)

        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, max_seq_len, dropout)
            for _ in range(n_layers)
        ])

        self.norm = nn.LayerNorm(d_model)
        self.out_proj = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying: the output projection shares the token embedding matrix.
        self.out_proj.weight = self.token_emb.weight

        self._init_weights()

    def _init_weights(self) -> None:
        for name, p in self.named_parameters():
            if p.dim() > 1 and "token_emb" not in name:
                nn.init.xavier_uniform_(p)

        # Scaled-normal init for the (tied) embedding / output matrix.
        nn.init.normal_(self.token_emb.weight, mean=0.0, std=self.d_model ** -0.5)
        # The normal init above overwrites the zero padding row; restore it so
        # padding_idx behaves as documented.
        with torch.no_grad():
            self.token_emb.weight[self.pad_id].zero_()

    @staticmethod
    def _causal_mask(L: int, device: torch.device) -> torch.Tensor:
        """Upper-triangular causal mask (True = masked)."""
        return torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_ids: (B, L) token IDs
        Returns:
            logits: (B, L, vocab_size)
        """
        B, L = input_ids.shape
        x = self.token_emb(input_ids) * math.sqrt(self.d_model)
        x = self.drop(x)

        mask = self._causal_mask(L, input_ids.device)
        for layer in self.layers:
            x = layer(x, mask)

        return self.out_proj(self.norm(x))

    def count_parameters(self) -> int:
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @torch.no_grad()
    def generate(
        self,
        prompt_ids: torch.Tensor,
        max_new_tokens: int = 64,
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 0.9,
        eos_id: int = 2,
        repetition_penalty: float = 1.0,
        no_repeat_ngram_size: int = 0,
        ignore_repeat_token_ids: set[int] | None = None,
    ) -> torch.Tensor:
        """Autoregressive generation from a prompt.

        Args:
            prompt_ids: (1, L) token IDs including [BOS] and context.
            max_new_tokens: maximum tokens to generate.
            temperature: sampling temperature (lower = more deterministic).
            top_k: keep only top-k logits (0 = disabled).
            top_p: nucleus sampling threshold.
            eos_id: stop token.
            repetition_penalty: penalise previously-seen tokens (HF convention:
                positive logits are divided by this factor, negative logits
                multiplied). > 1.0 discourages repeats, 1.0 disables.
                Typical: 1.2–1.5.
            no_repeat_ngram_size: ban candidate tokens that would complete
                an n-gram already present in the current sequence (n =
                this value). 0 disables. Typical: 3 for chord sequences.
            ignore_repeat_token_ids: token ids exempt from the two repetition
                controls above — e.g. [BAR] or other separators that
                *should* recur. If None, no exemptions.

        Returns:
            (1, L') full sequence including prompt and generated tokens.
        """
        # Note: this switches the module to eval mode and does not restore
        # training mode afterwards.
        self.eval()
        ids = prompt_ids.clone()
        exempt = ignore_repeat_token_ids or set()

        for _ in range(max_new_tokens):
            # Crop the context to the model's maximum sequence length and take
            # the logits for the last position, scaled by temperature.
            ctx = ids[:, -self.max_seq_len :]
            logits = self(ctx)[:, -1, :] / max(temperature, 1e-8)

            # Repetition penalty: push down logits of tokens already present
            # in the sequence (except exempt separators).
            if repetition_penalty != 1.0:
                seen = set(ids[0].tolist()) - exempt
                if seen:
                    idx = torch.tensor(list(seen), device=logits.device, dtype=torch.long)
                    vals = logits[0, idx]
                    vals = torch.where(
                        vals > 0,
                        vals / repetition_penalty,
                        vals * repetition_penalty,
                    )
                    logits[0, idx] = vals

            # No-repeat n-gram: ban any token that would recreate an n-gram
            # already seen in the sequence.
            if no_repeat_ngram_size > 0 and ids.shape[1] >= no_repeat_ngram_size:
                n = no_repeat_ngram_size
                seq = ids[0].tolist()
                prefix = tuple(seq[-(n - 1):]) if n > 1 else ()
                banned: set[int] = set()
                for i in range(len(seq) - n + 1):
                    if tuple(seq[i : i + n - 1]) == prefix:
                        banned.add(seq[i + n - 1])
                banned -= exempt
                if banned:
                    bidx = torch.tensor(list(banned), device=logits.device, dtype=torch.long)
                    logits[0, bidx] = float("-inf")

            # Top-k filtering: drop everything below the k-th largest logit.
            if top_k > 0:
                topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < topk_vals[:, -1:]] = float("-inf")

            # Nucleus (top-p) filtering: keep the smallest set of tokens whose
            # cumulative probability reaches top_p.
            if 0 < top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cum_probs = torch.cumsum(sorted_probs, dim=-1)
                # Remove tokens whose cumulative mass *excluding themselves*
                # already exceeds top_p, so the top token is always kept.
                remove = cum_probs - sorted_probs > top_p
                sorted_logits[remove] = float("-inf")
                # Un-sort back to vocabulary order (sorted_idx is a full
                # permutation, so every slot is overwritten).
                logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)

            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            ids = torch.cat([ids, next_id], dim=-1)

            if (next_id == eos_id).all():
                break

        return ids
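

# Minimal usage sketch (smoke test), assuming placeholder values: the
# vocabulary size, the [BOS]=1 / eos_id=2 token ids, and the prompt below
# are illustrative only; real values come from the chord tokenizer this
# model is trained with.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Small configuration so the sketch runs quickly on CPU.
    model = MusicTransformer(
        vocab_size=128,
        d_model=64,
        n_heads=4,
        d_ff=256,
        n_layers=2,
        max_seq_len=64,
    )
    print(f"parameters: {model.count_parameters():,}")

    # Teacher-forced forward pass: (B, L) token ids -> (B, L, vocab) logits.
    dummy = torch.randint(3, 128, (2, 16))
    print("logits shape:", tuple(model(dummy).shape))

    # Sample a short continuation with nucleus sampling and an n-gram ban.
    prompt = torch.tensor([[1, 10, 11, 12]])
    out = model.generate(
        prompt,
        max_new_tokens=8,
        temperature=0.9,
        top_p=0.9,
        no_repeat_ngram_size=3,
    )
    print("generated:", out[0].tolist())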