| """Chart prediction model architecture. |
| |
| FiLM-conditioned masked transformer for Guitar Hero chart generation. |
| """ |
|
|
| from dataclasses import dataclass |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0):
    """SwiGLU activation over interleaved gate/linear channels.

    Even-indexed channels of the last axis act as the gate, odd-indexed
    channels as the linear part.  Both are clamped for numerical
    stability before the gated product is formed, so the output has half
    as many channels as ``x``.
    """
    gate = x[..., 0::2].clamp(max=limit)
    linear = x[..., 1::2].clamp(min=-limit, max=limit)
    activated = gate * torch.sigmoid(alpha * gate)
    return activated * (linear + 1)
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel gain.

    Statistics are computed in float32 for numerical stability and the
    result is cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.to(torch.float32)
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x32 * inv_rms * self.scale).to(x.dtype)
|
|
|
|
class FeedForward(nn.Module):
    """SwiGLU feedforward: project up, gate, drop, project back down.

    ``linear1`` produces interleaved gate/linear channels which
    ``swiglu`` collapses to ``d_ff // 2`` channels, hence the narrower
    output projection.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff, bias=False)
        self.linear_out = nn.Linear(d_ff // 2, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = swiglu(self.linear1(x))
        return self.linear_out(self.dropout(hidden))
|
|
|
|
| |
| |
| |
|
|
def apply_rotary_emb(
    x: torch.Tensor, dim: int, base: float = 10000.0,
) -> torch.Tensor:
    """Apply RoPE to a tensor of shape [B, heads, T, head_dim].

    Uses the half-split convention: channel i is rotated against channel
    i + dim // 2 by a position-dependent angle.

    Args:
        x: Input of shape [B, heads, T, head_dim].
        dim: Number of leading channels to rotate; must be even.  Any
            channels past ``dim`` are dropped from the output (callers
            here always pass dim == head_dim).
        base: Frequency base of the inverse-frequency spectrum.

    Returns:
        Rotated tensor of shape [B, heads, T, dim], in x's dtype.
    """
    seq_len = x.size(2)
    # Fix: compute the angle table in float32 regardless of x's dtype.
    # In fp16/bf16 the `base ** (-k/dim)` exponentiation and the
    # position * frequency products lose precision badly at long
    # sequence positions; float32 results are identical for fp32 inputs.
    inv_freq = base ** (
        -torch.arange(0, dim, 2, device=x.device, dtype=torch.float32) / dim
    )
    positions = torch.arange(seq_len, device=x.device, dtype=torch.float32)
    angles = positions[:, None] * inv_freq[None, :]
    # Shape [1, 1, T, dim // 2] so the tables broadcast over batch/heads.
    sin = angles.sin()[None, None, :, :].to(x.dtype)
    cos = angles.cos()[None, None, :, :].to(x.dtype)
    x1 = x[..., : dim // 2]
    x2 = x[..., dim // 2 : dim]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
|
|
|
|
| |
| |
| |
|
|
class BidirectionalAttention(nn.Module):
    """Multi-head bidirectional (non-causal) self-attention with RoPE.

    Queries and keys receive rotary position embeddings; the attention
    itself runs through ``F.scaled_dot_product_attention``.  An optional
    per-token mask restricts which key positions are attended to
    (nonzero/True entries are kept -- presumably marking valid,
    non-padded tokens; confirm polarity against the caller).
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1,
                 rope_base: float = 10000.0):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.rope_base = rope_base

        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t: torch.Tensor, batch: int, seq: int) -> torch.Tensor:
        # [B, T, d_model] -> [B, heads, T, d_k]
        return t.view(batch, seq, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch, seq, _ = x.shape
        q = self._split_heads(self.w_q(x), batch, seq)
        k = self._split_heads(self.w_k(x), batch, seq)
        v = self._split_heads(self.w_v(x), batch, seq)

        # Position information enters only through rotating q and k.
        q = apply_rotary_emb(q, dim=self.d_k, base=self.rope_base)
        k = apply_rotary_emb(k, dim=self.d_k, base=self.rope_base)

        # Expand [B, T_key] -> [B, 1, 1, T_key] so the boolean mask
        # broadcasts over heads and query positions.
        mask = None if attn_mask is None else attn_mask[:, None, None, :].bool()

        # Attention-weight dropout is active only while training.
        drop_p = self.dropout.p if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=drop_p, is_causal=False,
        )
        merged = attended.transpose(1, 2).contiguous().view(batch, seq, self.d_model)
        return self.out_proj(merged)
|
|
|
|
| |
| |
| |
|
|
class FiLMEncoderBlock(nn.Module):
    """Pre-norm transformer encoder block with FiLM difficulty conditioning.

    The feedforward output is modulated as ``(1 + gamma) * h + beta``,
    where gamma and beta are projected from the difficulty embedding.
    The FiLM projection is zero-initialized, so at initialization every
    block behaves as a plain, unconditioned transformer block.
    """

    def __init__(self, d_model: int, d_ff: int, n_heads: int,
                 dropout: float = 0.1, rope_base: float = 10000.0):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = BidirectionalAttention(d_model, n_heads, dropout, rope_base)
        self.norm2 = RMSNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

        # Maps the difficulty embedding to per-channel (gamma, beta).
        # Zero init => gamma = beta = 0 => identity modulation at start.
        self.film_proj = nn.Linear(d_model, d_model * 2)
        nn.init.zeros_(self.film_proj.weight)
        nn.init.zeros_(self.film_proj.bias)

    def forward(self, x: torch.Tensor, diff_emb: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Attention sublayer (pre-norm residual).
        attn_out = self.attn(self.norm1(x), attn_mask)
        x = x + self.dropout(attn_out)

        # Feedforward sublayer, FiLM-modulated before the residual add.
        ff_out = self.ff(self.norm2(x))
        gamma, beta = self.film_proj(diff_emb).unsqueeze(1).chunk(2, dim=-1)
        modulated = (1 + gamma) * ff_out + beta
        return x + self.dropout(modulated)
|
|
|
|
| |
| |
| |
|
|
# Token id layout: ids 0-31 encode fret combinations (presumably a 5-bit
# bitmask of frets -- confirm against the tokenizer), 32 is explicit
# silence, and 33 is the MASK placeholder used as model input only (the
# token head emits VOCAB_SIZE - 1 logits, so MASK is never predicted).
SILENCE_TOKEN = 32
MASK_TOKEN = 33
VOCAB_SIZE = 34
# Number of discretized sustain-length buckets for the duration head.
NUM_SUSTAIN_BUCKETS = 6
|
|
|
|
| |
| |
| |
|
|
class ChartMaskPredictor(nn.Module):
    """Masked prediction chart model (v3).

    Token vocabulary: 0-31 fret combos, 32 silence, 33 MASK.
    """

    def __init__(self, config: "ChartMaskPredictorConfig"):
        super().__init__()
        self.config = config
        d = config.d_model

        # Per-frame input stream: projected audio features summed with
        # chart-token embeddings.
        self.audio_projection = nn.Linear(config.audio_dim, d, bias=False)
        self.chart_embedding = nn.Embedding(VOCAB_SIZE, d)
        self.input_dropout = nn.Dropout(config.dropout)
        # Four difficulty levels condition every FiLM block.
        self.difficulty_embedding = nn.Embedding(4, d)

        self.layers = nn.ModuleList(
            FiLMEncoderBlock(
                d_model=d,
                d_ff=config.d_ff,
                n_heads=config.n_heads,
                dropout=config.dropout,
                rope_base=config.rope_base,
            )
            for _ in range(config.n_layers)
        )

        self.final_norm = RMSNorm(d)
        # MASK (the last vocab id) is input-only, hence VOCAB_SIZE - 1
        # output classes.
        self.token_head = nn.Linear(d, VOCAB_SIZE - 1)
        self.sustain_head = nn.Linear(d, 1)
        self.duration_head = nn.Linear(d, NUM_SUSTAIN_BUCKETS)

    def forward(self, audio_features: torch.Tensor, chart_tokens: torch.Tensor,
                difficulty: torch.Tensor,
                padding_mask: Optional[torch.Tensor] = None) -> dict[str, torch.Tensor]:
        """Run the encoder stack and return per-frame logits.

        Args:
            audio_features: [B, T, audio_dim] float features.
            chart_tokens: [B, T] integer token ids (MASK where hidden).
            difficulty: [B] integer difficulty ids in [0, 4).
            padding_mask: Optional [B, T] key mask forwarded to attention.

        Returns:
            Dict with "token_logits" [B, T, VOCAB_SIZE - 1],
            "sustain_logits" [B, T, 1], and
            "duration_logits" [B, T, NUM_SUSTAIN_BUCKETS].
        """
        h = self.audio_projection(audio_features) + self.chart_embedding(chart_tokens)
        h = self.input_dropout(h)

        diff_emb = self.difficulty_embedding(difficulty)
        for block in self.layers:
            h = block(h, diff_emb, attn_mask=padding_mask)
        h = self.final_norm(h)

        return {
            "token_logits": self.token_head(h),
            "sustain_logits": self.sustain_head(h),
            "duration_logits": self.duration_head(h),
        }
|
|
|
|
@dataclass
class ChartMaskPredictorConfig:
    """Hyperparameters for ChartMaskPredictor."""

    # Dimensionality of the per-frame audio feature vectors.
    audio_dim: int = 771
    # Transformer hidden size.
    d_model: int = 512
    # Attention heads per layer (must evenly divide d_model).
    n_heads: int = 8
    # Number of FiLM encoder blocks in the stack.
    n_layers: int = 6
    # Feedforward expansion width (SwiGLU halves it internally).
    d_ff: int = 2048
    # Dropout probability shared by inputs, attention, and feedforward.
    dropout: float = 0.15
    # RoPE frequency base.
    rope_base: float = 10000.0
|
|