"""Chart prediction model architecture.

FiLM-conditioned masked transformer for Guitar Hero chart generation.
"""

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# Utility layers
# ---------------------------------------------------------------------------


def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    """SwiGLU activation over interleaved gate/linear channels.

    The last dimension is split into even-indexed (gate) and odd-indexed
    (linear) channels, so the output has half the input's channel count.
    Clamping bounds both paths for numerical stability; the ``+ 1`` on the
    linear path lets signal pass through even when the linear channel is 0.

    Args:
        x: Tensor whose last dimension is even (gate/linear interleaved).
        alpha: Sigmoid sharpness (1.702 approximates GELU with a sigmoid).
        limit: Clamp bound applied to both channel groups.

    Returns:
        Tensor with the last dimension halved.
    """
    x_glu, x_linear = x[..., ::2], x[..., 1::2]
    # Gate is only clamped from above: large negative gates already saturate
    # the sigmoid toward zero, so no lower clamp is needed there.
    x_glu = x_glu.clamp(max=limit)
    x_linear = x_linear.clamp(min=-limit, max=limit)
    return x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1)


class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction, no bias).

    Statistics are computed in float32 and the result is cast back to the
    input dtype, so the layer is safe under mixed-precision training.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = x.float()
        t = t * torch.rsqrt(t.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (t * self.scale).to(x.dtype)


class FeedForward(nn.Module):
    """SwiGLU feed-forward block: d_model -> d_ff -> (swiglu) -> d_ff//2 -> d_model."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff, bias=False)
        # swiglu halves the channel count, hence d_ff // 2 on the way out.
        self.linear_out = nn.Linear(d_ff // 2, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_out(self.dropout(swiglu(self.linear1(x))))


# ---------------------------------------------------------------------------
# Rotary position embeddings
# ---------------------------------------------------------------------------


def apply_rotary_emb(
    x: torch.Tensor,
    dim: int,
    base: float = 10000.0,
) -> torch.Tensor:
    """Apply RoPE to a tensor of shape [B, heads, T, head_dim].

    Uses the "rotate-half" formulation: the first and second halves of the
    head dimension form the (x1, x2) rotation pairs.

    Note: angles are always computed in float32. The previous revision used
    ``x.dtype``, which corrupts positions/frequencies under fp16 (fp16 cannot
    represent consecutive integers above 2048), degrading long-context
    attention. sin/cos are cast back to ``x.dtype`` before use.

    Args:
        x: Query or key tensor, shape [B, n_heads, T, head_dim].
        dim: Number of leading head-dim channels to rotate (head_dim here).
        base: RoPE frequency base.

    Returns:
        Tensor of the same shape and dtype as ``x`` with positions encoded.
    """
    seq_len = x.size(2)
    device = x.device
    # Frequencies and positions in float32 for numerical stability.
    theta = base ** (-torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
    positions = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
    angles = positions * theta.unsqueeze(0)  # [T, dim // 2]
    sin = angles.sin().to(x.dtype).unsqueeze(0).unsqueeze(0)  # [1, 1, T, dim // 2]
    cos = angles.cos().to(x.dtype).unsqueeze(0).unsqueeze(0)
    x1 = x[..., : dim // 2]
    x2 = x[..., dim // 2 : dim]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)


# ---------------------------------------------------------------------------
# Bidirectional multi-head self-attention
# ---------------------------------------------------------------------------


class BidirectionalAttention(nn.Module):
    """Non-causal multi-head self-attention with RoPE on Q and K.

    Attention is fully bidirectional (masked-prediction model, not
    autoregressive), so ``is_causal=False`` and the only masking is an
    optional key-padding mask.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1,
                 rope_base: float = 10000.0):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.rope_base = rope_base
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over the full sequence.

        Args:
            x: [B, T, d_model] input.
            attn_mask: Optional [B, T] key-padding mask; truthy = attend.

        Returns:
            [B, T, d_model] attention output.
        """
        B, T, _ = x.shape
        Q = self.w_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        # RoPE is applied to queries and keys only (never values).
        Q = apply_rotary_emb(Q, dim=self.d_k, base=self.rope_base)
        K = apply_rotary_emb(K, dim=self.d_k, base=self.rope_base)
        sdpa_mask = None
        if attn_mask is not None:
            # Broadcast [B, T] -> [B, 1, 1, T] over heads and query positions.
            sdpa_mask = attn_mask[:, None, None, :].bool()
        out = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=sdpa_mask,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=False,
        )
        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.out_proj(out)


# ---------------------------------------------------------------------------
# FiLM-conditioned encoder block
# ---------------------------------------------------------------------------


class FiLMEncoderBlock(nn.Module):
    """Encoder block with FiLM difficulty conditioning.

    After the feedforward, the output is modulated:

        h = (1 + gamma) * h + beta

    where gamma, beta are derived from the difficulty embedding. The FiLM
    projection is zero-initialized so that at the start of training
    gamma = beta = 0 and the block behaves as a plain pre-norm transformer
    encoder layer.
    """

    def __init__(self, d_model: int, d_ff: int, n_heads: int,
                 dropout: float = 0.1, rope_base: float = 10000.0):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = BidirectionalAttention(d_model, n_heads, dropout, rope_base)
        self.norm2 = RMSNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)
        self.film_proj = nn.Linear(d_model, d_model * 2)
        # Zero init => identity modulation at initialization.
        nn.init.zeros_(self.film_proj.weight)
        nn.init.zeros_(self.film_proj.bias)

    def forward(self, x: torch.Tensor, diff_emb: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Args:
            x: [B, T, d_model] hidden states.
            diff_emb: [B, d_model] difficulty embedding (one per sequence).
            attn_mask: Optional [B, T] key-padding mask; truthy = attend.

        Returns:
            [B, T, d_model] updated hidden states.
        """
        x = x + self.dropout(self.attn(self.norm1(x), attn_mask))
        h = self.ff(self.norm2(x))
        film = self.film_proj(diff_emb).unsqueeze(1)  # [B, 1, 2 * d_model]
        gamma, beta = film.chunk(2, dim=-1)
        h = (1 + gamma) * h + beta
        x = x + self.dropout(h)
        return x


# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

SILENCE_TOKEN = 32   # "no note" token
MASK_TOKEN = 33      # masked-position token (input only, never predicted)
VOCAB_SIZE = 34      # 32 fret combos + silence + mask
NUM_SUSTAIN_BUCKETS = 6


# ---------------------------------------------------------------------------
# Main model
# ---------------------------------------------------------------------------


class ChartMaskPredictor(nn.Module):
    """Masked prediction chart model (v3).

    Token vocabulary: 0-31 fret combos, 32 silence, 33 MASK.

    Per frame, the model fuses projected audio features with chart-token
    embeddings, runs FiLM-conditioned encoder layers, and predicts:
      - token_logits:    33-way token class (MASK excluded from the output),
      - sustain_logits:  binary is-sustain logit,
      - duration_logits: sustain-duration bucket.
    """

    def __init__(self, config: "ChartMaskPredictorConfig"):
        super().__init__()
        self.config = config
        d = config.d_model
        self.audio_projection = nn.Linear(config.audio_dim, d, bias=False)
        self.chart_embedding = nn.Embedding(VOCAB_SIZE, d)
        self.input_dropout = nn.Dropout(config.dropout)
        # 4 difficulty levels (e.g. easy/medium/hard/expert); indices 0-3.
        self.difficulty_embedding = nn.Embedding(4, d)
        self.layers = nn.ModuleList([
            FiLMEncoderBlock(
                d_model=d,
                d_ff=config.d_ff,
                n_heads=config.n_heads,
                dropout=config.dropout,
                rope_base=config.rope_base,
            )
            for _ in range(config.n_layers)
        ])
        self.final_norm = RMSNorm(d)
        self.token_head = nn.Linear(d, VOCAB_SIZE - 1)  # 33 classes (no MASK)
        self.sustain_head = nn.Linear(d, 1)
        self.duration_head = nn.Linear(d, NUM_SUSTAIN_BUCKETS)

    def forward(self,
                audio_features: torch.Tensor,
                chart_tokens: torch.Tensor,
                difficulty: torch.Tensor,
                padding_mask: Optional[torch.Tensor] = None
                ) -> dict[str, torch.Tensor]:
        """Run the masked predictor.

        Args:
            audio_features: [B, T, audio_dim] per-frame audio features.
            chart_tokens: [B, T] int token ids in [0, VOCAB_SIZE).
            difficulty: [B] int difficulty ids in [0, 4).
            padding_mask: Optional [B, T] mask; truthy = real frame.

        Returns:
            Dict with "token_logits" [B, T, 33], "sustain_logits" [B, T, 1],
            and "duration_logits" [B, T, NUM_SUSTAIN_BUCKETS].
        """
        audio = self.audio_projection(audio_features)
        chart = self.chart_embedding(chart_tokens)
        x = self.input_dropout(audio + chart)
        diff_emb = self.difficulty_embedding(difficulty)
        for layer in self.layers:
            x = layer(x, diff_emb, attn_mask=padding_mask)
        x = self.final_norm(x)
        return {
            "token_logits": self.token_head(x),
            "sustain_logits": self.sustain_head(x),
            "duration_logits": self.duration_head(x),
        }


@dataclass
class ChartMaskPredictorConfig:
    # audio_dim: per-frame audio feature width fed to the model.
    audio_dim: int = 771
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 6
    d_ff: int = 2048
    dropout: float = 0.15
    rope_base: float = 10000.0