File size: 10,957 Bytes
f2e6b6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
"""Music Transformer with relative attention for chord generation.

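Architecture: Transformer decoder (autoregressive) with learned relative
position encoding (Shaw et al. 2018), as used in the Music Transformer
(Huang et al. 2018). Relative scores are computed by gathering one embedding
per (query, key) offset, which materializes an (L, L, d_k) tensor per layer;
Huang et al.'s memory-efficient "skewing" trick is not used.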

Default config (~25M params):
    d_model=512, n_heads=8, d_ff=2048, n_layers=8
"""

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class RelativeMultiHeadAttention(nn.Module):
    """Multi-head self-attention with learned relative-position scores.

    Attention logits are ``(Q K^T + Q R^T) / sqrt(d_k)``, where ``R[i, j]`` is
    a learned ``d_k``-dim embedding of the offset ``j - i`` (Shaw et al. 2018).
    Offsets outside ``[-(max_seq_len - 1), max_seq_len - 1]`` are clipped to
    the nearest end of the embedding table.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        max_seq_len: largest relative distance that gets its own embedding.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        max_seq_len: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # Raise rather than assert so the check survives ``python -O``.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.scale = math.sqrt(self.d_k)

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        # Learnable relative position embeddings: offsets in [-max_len+1, max_len-1]
        self.max_seq_len = max_seq_len
        self.rel_emb = nn.Embedding(2 * max_seq_len - 1, self.d_k)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """
        Args:
            x: (B, L, D)
            mask: optional bool mask, True = masked (don't attend). Accepts
                (L, L) shared by the whole batch, (B, L, L) per example, or a
                4-D mask broadcastable to (B, H, L, L).
        Returns:
            (B, L, D)
        """
        B, L, _ = x.shape
        H, dk = self.n_heads, self.d_k

        Q = self.w_q(x).view(B, L, H, dk).transpose(1, 2)  # (B, H, L, dk)
        K = self.w_k(x).view(B, L, H, dk).transpose(1, 2)
        V = self.w_v(x).view(B, L, H, dk).transpose(1, 2)

        # Content attention Q K^T plus relative-position attention Q R^T.
        content = torch.matmul(Q, K.transpose(-2, -1))  # (B, H, L, L)
        rel = self._relative_attention(Q, L)  # (B, H, L, L)
        attn = (content + rel) / self.scale

        if mask is not None:
            mask = self._expand_mask(mask)
            attn = attn.masked_fill(mask, float("-inf"))

        attn = F.softmax(attn, dim=-1)
        if mask is not None:
            # A query row whose keys are all masked softmaxes to NaN; zero the
            # masked weights so that NaN cannot spread to other positions via
            # ``attn @ V``. For partially masked rows this is a no-op, because
            # exp(-inf) is already exactly 0.
            attn = attn.masked_fill(mask, 0.0)
        attn = self.dropout(attn)

        out = torch.matmul(attn, V)  # (B, H, L, dk)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)
        return self.w_o(out)

    @staticmethod
    def _expand_mask(mask: torch.Tensor) -> torch.Tensor:
        """Reshape a 2-D/3-D/4-D mask so it broadcasts against (B, H, L, L)."""
        if mask.dim() == 2:  # (L, L): shared across batch and heads
            return mask[None, None]
        if mask.dim() == 3:  # (B, L, L): shared across heads
            return mask[:, None]
        if mask.dim() == 4:
            return mask
        raise ValueError(f"mask must be 2-, 3- or 4-D, got shape {tuple(mask.shape)}")

    def _relative_attention(self, Q: torch.Tensor, L: int) -> torch.Tensor:
        """Compute Q @ R^T using relative position embeddings.

        For each (i, j) pair the relative position is ``j - i``, shifted by
        ``max_seq_len - 1`` to a non-negative table index and clipped to the
        table range.

        Args:
            Q: (B, H, L, dk) queries.
            L: sequence length.
        Returns:
            (B, H, L, L) scores with ``out[..., i, j] = Q[..., i, :] . R[i, j]``.
        """
        positions = torch.arange(L, device=Q.device)
        rel_idx = positions[None, :] - positions[:, None] + self.max_seq_len - 1
        rel_idx = rel_idx.clamp(0, 2 * self.max_seq_len - 2)

        R = self.rel_emb(rel_idx)  # (L, L, dk)
        # Contract over dk directly on the 4-D queries; no (B*H) reshape copy.
        return torch.einsum("bhid,ijd->bhij", Q, R)


class TransformerBlock(nn.Module):
    """Decoder layer in the pre-LayerNorm residual layout.

    Each sub-layer normalizes its input and adds its result back onto the
    un-normalized residual stream:

        h   = x + Dropout(SelfAttn(LN1(x), mask))
        out = h + FFN(LN2(h))        # FFN holds its own two dropouts
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        max_seq_len: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # Sub-modules are created in a fixed order so that seeded
        # initialization and state_dict keys stay stable.
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = RelativeMultiHeadAttention(
            d_model=d_model,
            n_heads=n_heads,
            max_seq_len=max_seq_len,
            dropout=dropout,
        )
        self.norm2 = nn.LayerNorm(d_model)
        ffn_layers = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*ffn_layers)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        attn_out = self.attn(self.norm1(x), mask)
        h = x + self.drop(attn_out)
        ffn_out = self.ffn(self.norm2(h))
        return h + ffn_out


class MusicTransformer(nn.Module):
    """Autoregressive Music Transformer for chord generation.

    Pipeline: token embedding (scaled by sqrt(d_model)) -> dropout ->
    ``n_layers`` decoder blocks under a causal mask -> final LayerNorm ->
    output projection whose weight is tied to the token embedding.

    Args:
        vocab_size: number of token ids.
        d_model: model width.
        n_heads: attention heads per block.
        d_ff: hidden width of each feed-forward sub-layer.
        n_layers: number of decoder blocks.
        max_seq_len: sizes each block's relative-position table; ``generate``
            also feeds at most this many trailing tokens as context.
        dropout: dropout probability used throughout.
        pad_id: token id passed to the embedding as ``padding_idx``.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        d_ff: int = 2048,
        n_layers: int = 8,
        max_seq_len: int = 512,
        dropout: float = 0.1,
        pad_id: int = 0,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.pad_id = pad_id

        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.drop = nn.Dropout(dropout)

        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, max_seq_len, dropout)
            for _ in range(n_layers)
        ])

        self.norm = nn.LayerNorm(d_model)
        self.out_proj = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying (embedding <-> output projection)
        self.out_proj.weight = self.token_emb.weight

        self._init_weights()

    def _init_weights(self) -> None:
        """Xavier-init every matrix except the tied token embedding.

        The tied token embedding is instead drawn from N(0, 1/d_model), and
        its padding row is then reset to zero.
        """
        emb_weight = self.token_emb.weight
        for p in self.parameters():
            # Compare by identity: the same tensor is also ``out_proj.weight``.
            if p.dim() > 1 and p is not emb_weight:
                nn.init.xavier_uniform_(p)
        # Embedding std=1/sqrt(d_model) so that after *sqrt(d_model) scaling
        # inputs have unit variance, and weight-tied output logits stay small
        nn.init.normal_(emb_weight, mean=0.0, std=self.d_model ** -0.5)
        # normal_ overwrote the all-zero row nn.Embedding creates for
        # padding_idx; restore it so the pad token starts as a null vector.
        pad = self.token_emb.padding_idx
        if pad is not None:
            with torch.no_grad():
                emb_weight[pad].zero_()

    @staticmethod
    def _causal_mask(L: int, device: torch.device) -> torch.Tensor:
        """Upper-triangular causal mask (True = masked)."""
        return torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_ids: (B, L) token IDs
        Returns:
            logits: (B, L, vocab_size)
        """
        B, L = input_ids.shape
        x = self.token_emb(input_ids) * math.sqrt(self.d_model)
        x = self.drop(x)

        mask = self._causal_mask(L, input_ids.device)
        for layer in self.layers:
            x = layer(x, mask)

        return self.out_proj(self.norm(x))

    def count_parameters(self) -> int:
        """Number of trainable parameters (the tied embedding counted once)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    # ------------------------------------------------------------------
    # Sampling helpers (used by ``generate``)
    # ------------------------------------------------------------------

    @staticmethod
    def _apply_repetition_penalty(
        row_logits: torch.Tensor, seen: set[int], penalty: float
    ) -> None:
        """In place: make logits of ``seen`` tokens less attractive (HF rule).

        Positive logits are divided by ``penalty`` and negative ones are
        multiplied by it, so for ``penalty > 1`` both move toward "less likely".
        """
        if not seen:
            return
        idx = torch.tensor(sorted(seen), device=row_logits.device, dtype=torch.long)
        vals = row_logits[idx]
        row_logits[idx] = torch.where(vals > 0, vals / penalty, vals * penalty)

    @staticmethod
    def _banned_ngram_tokens(seq: list[int], n: int) -> set[int]:
        """Tokens that would complete an n-gram already present in ``seq``."""
        if n <= 0 or len(seq) < n:
            return set()
        prefix = tuple(seq[len(seq) - n + 1:])  # last n-1 tokens; () when n == 1
        return {
            seq[i + n - 1]
            for i in range(len(seq) - n + 1)
            if tuple(seq[i:i + n - 1]) == prefix
        }

    @staticmethod
    def _ban_tokens(row_logits: torch.Tensor, banned: set[int]) -> None:
        """In place: set ``banned`` logits to -inf, unless that bans everything.

        If every candidate would be banned, softmax would produce NaN and
        sampling would crash, so in that case the ban is skipped.
        """
        if not banned:
            return
        idx = torch.tensor(sorted(banned), device=row_logits.device, dtype=torch.long)
        candidate = row_logits.index_fill(0, idx, float("-inf"))
        if torch.isneginf(candidate).all():
            return
        row_logits.copy_(candidate)

    @staticmethod
    def _filter_top_k_top_p(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """Return ``logits`` with tokens outside the top-k / nucleus set to -inf.

        ``top_k <= 0`` disables top-k; ``top_p`` outside (0, 1) disables
        nucleus filtering. The highest-scoring token is always kept.
        """
        if top_k > 0:
            kth = torch.topk(logits, min(top_k, logits.size(-1))).values[..., -1:]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        if 0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            # Drop a token once the mass of strictly higher-ranked tokens
            # already exceeds top_p.
            remove_sorted = (torch.cumsum(sorted_probs, dim=-1) - sorted_probs) > top_p
            # Map the mask from sorted order back to vocabulary order.
            remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
            logits = logits.masked_fill(remove, float("-inf"))
        return logits

    @torch.no_grad()
    def generate(
        self,
        prompt_ids: torch.Tensor,
        max_new_tokens: int = 64,
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 0.9,
        eos_id: int = 2,
        repetition_penalty: float = 1.0,
        no_repeat_ngram_size: int = 0,
        ignore_repeat_token_ids: set[int] | None = None,
    ) -> torch.Tensor:
        """Autoregressive generation from a prompt.

        The model runs in eval mode while sampling. Its previous train/eval
        mode is restored afterwards, even if sampling raises.

        Args:
            prompt_ids: (1, L) token IDs including [BOS] and context.
            max_new_tokens: maximum tokens to generate.
            temperature: sampling temperature (lower = more deterministic).
            top_k: keep only top-k logits (0 = disabled).
            top_p: nucleus sampling threshold.
            eos_id: stop token.
            repetition_penalty: divide logits of previously-seen tokens by
                this factor (HF convention). > 1.0 discourages repeats.
                1.0 disables. Typical: 1.2–1.5.
            no_repeat_ngram_size: ban candidate tokens that would complete
                an n-gram already present in the current sequence (n =
                this value). 0 disables. Typical: 3 for chord sequences.
                If the ban would exclude every token, it is skipped for
                that step rather than producing an unsampleable distribution.
            ignore_repeat_token_ids: token ids exempt from the two repetition
                controls above — e.g. [BAR] or other separators that
                *should* recur. Any iterable of ints is accepted. If None, no
                exemptions.

        Returns:
            (1, L') full sequence including prompt and generated tokens.
        """
        exempt = set(ignore_repeat_token_ids) if ignore_repeat_token_ids else set()
        was_training = self.training
        self.eval()
        try:
            ids = prompt_ids.clone()
            for _ in range(max_new_tokens):
                ctx = ids[:, -self.max_seq_len:]
                logits = self(ctx)[:, -1, :] / max(temperature, 1e-8)

                # Repetition controls work on each row's full history
                # (not only the cropped context).
                for row in range(ids.shape[0]):
                    history = ids[row].tolist()
                    if repetition_penalty != 1.0:
                        self._apply_repetition_penalty(
                            logits[row], set(history) - exempt, repetition_penalty
                        )
                    if no_repeat_ngram_size > 0:
                        banned = self._banned_ngram_tokens(history, no_repeat_ngram_size)
                        self._ban_tokens(logits[row], banned - exempt)

                logits = self._filter_top_k_top_p(logits, top_k, top_p)

                next_id = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                ids = torch.cat([ids, next_id], dim=-1)

                if (next_id == eos_id).all():
                    break
        finally:
            self.train(was_training)

        return ids