# midmid3 / midmid / nn.py
# markury's picture
# Initial commit
# d171350
"""Chart prediction model architecture.
FiLM-conditioned masked transformer for Guitar Hero chart generation.
"""
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Utility layers
# ---------------------------------------------------------------------------
def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0):
    """Gated activation over channels interleaved along the last axis.

    Even-indexed channels of the last dimension form the gate and
    odd-indexed channels form the linear branch, so the output's last
    dimension is half of the input's (for an even-sized input).

    Args:
        x: Input tensor; the last dimension holds the interleaved
            gate/linear channels.
        alpha: Slope applied to the gate inside the sigmoid.
        limit: Clamp bound. The gate is clamped from above only; the
            linear branch is clamped to ``[-limit, limit]``.

    Returns:
        ``gate * sigmoid(alpha * gate) * (linear + 1)``.
    """
    gate = x[..., 0::2].clamp(max=limit)
    linear = x[..., 1::2].clamp(min=-limit, max=limit)
    gated = gate * torch.sigmoid(alpha * gate)
    # The +1 offset means a zero linear branch passes the gate through.
    return gated * (linear + 1)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel scale.

    The statistics are computed over the last dimension in float32 and the
    result is cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        """Create the layer.

        Args:
            dim: Size of the last (normalised) dimension.
            eps: Constant added to the mean square before the reciprocal
                square root.
        """
        super().__init__()
        self.eps = eps
        # Initialised to ones, so the layer starts as a pure normalisation.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension and apply the scale."""
        input_dtype = x.dtype
        x32 = x.float()
        mean_square = x32.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = x32 * inv_rms
        return (normalised * self.scale).to(input_dtype)
class FeedForward(nn.Module):
    """Position-wise feed-forward layer with a gated (swiglu) activation.

    ``linear1`` expands to ``d_ff`` channels, ``swiglu`` pairs even/odd
    channels and halves the width to ``d_ff // 2``, and ``linear_out``
    projects back to ``d_model``. Dropout is applied between the activation
    and the output projection.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        """Create the layer.

        Args:
            d_model: Input and output width.
            d_ff: Width of the first projection (before gating). Must be a
                positive even number because the activation splits it into
                equally sized gate and linear halves.
            dropout: Dropout probability applied after the activation.

        Raises:
            ValueError: If ``d_ff`` is not a positive even integer.
        """
        super().__init__()
        # Fail at construction time: an odd d_ff would otherwise only
        # surface as a shape mismatch inside forward().
        if d_ff <= 0 or d_ff % 2 != 0:
            raise ValueError(
                f"d_ff must be a positive even integer, got {d_ff}"
            )
        self.linear1 = nn.Linear(d_model, d_ff, bias=False)
        self.linear_out = nn.Linear(d_ff // 2, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply expand -> swiglu -> dropout -> project to ``x``."""
        hidden = swiglu(self.linear1(x))
        return self.linear_out(self.dropout(hidden))
# ---------------------------------------------------------------------------
# Rotary position embeddings
# ---------------------------------------------------------------------------
def apply_rotary_emb(
    x: torch.Tensor, dim: int, base: float = 10000.0,
) -> torch.Tensor:
    """Apply RoPE to a tensor of shape [B, heads, T, head_dim].

    The first ``dim`` channels of the last axis are rotated using the
    half-split layout: channel ``i`` is paired with channel ``i + dim // 2``.
    Positions are ``0 .. T-1`` taken from axis 2. Channels beyond ``dim``
    (when ``dim < head_dim``) are returned unchanged, so the output always
    has the same shape and dtype as ``x``.

    Args:
        x: Tensor whose axis 2 is the sequence axis and whose last axis
            holds the per-head channels.
        dim: Number of leading channels to rotate. Must be even and no
            larger than ``x.size(-1)``.
        base: Base of the geometric frequency progression.

    Returns:
        Tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: If ``dim`` is negative, odd, or exceeds ``x.size(-1)``.
    """
    head_dim = x.size(-1)
    if dim < 0 or dim % 2 != 0:
        raise ValueError(f"dim must be a non-negative even integer, got {dim}")
    if dim > head_dim:
        raise ValueError(
            f"dim ({dim}) exceeds the size of the last axis ({head_dim})"
        )

    seq_len = x.size(2)
    half = dim // 2
    # Build the angles in at least float32. In bfloat16/float16 the integer
    # position indices stop being exactly representable after a few hundred
    # steps, which makes distinct positions share the same rotation.
    compute_dtype = torch.float64 if x.dtype == torch.float64 else torch.float32
    theta = base ** (
        -torch.arange(0, dim, 2, device=x.device, dtype=compute_dtype) / dim
    )
    positions = torch.arange(
        seq_len, device=x.device, dtype=compute_dtype
    ).unsqueeze(1)
    angles = positions * theta.unsqueeze(0)  # [T, dim // 2]
    # Add leading batch and head axes so the tables broadcast against x.
    cos = angles.cos()[None, None]
    sin = angles.sin()[None, None]

    x1 = x[..., :half].to(compute_dtype)
    x2 = x[..., half:dim].to(compute_dtype)
    rotated = torch.cat(
        [x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1
    ).to(x.dtype)

    if dim < head_dim:
        # Partial rotary: keep the un-rotated tail instead of dropping it.
        rotated = torch.cat([rotated, x[..., dim:]], dim=-1)
    return rotated
# ---------------------------------------------------------------------------
# Bidirectional multi-head self-attention
# ---------------------------------------------------------------------------
class BidirectionalAttention(nn.Module):
    """Multi-head self-attention without a causal mask, using RoPE on Q and K.

    Queries, keys and values are bias-free linear projections of the same
    input. Rotary position embeddings are applied to the per-head queries
    and keys before ``F.scaled_dot_product_attention``.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1,
                 rope_base: float = 10000.0):
        """Create the layer.

        Args:
            d_model: Model width; must be divisible by ``n_heads``.
            n_heads: Number of attention heads (positive).
            dropout: Attention dropout probability, used in training mode.
            rope_base: Frequency base forwarded to ``apply_rotary_emb``.

        Raises:
            ValueError: If ``n_heads`` is not positive, ``d_model`` is not
                divisible by ``n_heads``, or the per-head width is odd.
        """
        super().__init__()
        # Validate up front: floor division would otherwise hide the
        # problem until .view() fails in forward(), and an odd head width
        # cannot be split into the two equal halves RoPE rotates.
        if n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {n_heads}")
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        if (d_model // n_heads) % 2 != 0:
            raise ValueError(
                f"head dimension ({d_model // n_heads}) must be even for RoPE"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.rope_base = rope_base
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        # Only the probability ``p`` of this module is read in forward();
        # the dropout itself is applied inside scaled_dot_product_attention.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run self-attention over ``x``.

        Args:
            x: Input of shape ``[B, T, d_model]``.
            attn_mask: Optional 2-D per-key mask, presumably ``[B, T]``. It
                is cast to bool and broadcast over heads and query
                positions.
                NOTE(review): with a boolean mask, SDPA attends where the
                mask is True, so nonzero entries must mark *valid*
                positions -- confirm against the caller, which names this
                argument ``padding_mask``.

        Returns:
            Tensor of shape ``[B, T, d_model]``.
        """
        B, T, _ = x.shape
        # [B, T, d_model] -> [B, n_heads, T, d_k]
        Q = self.w_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        # Positions enter through Q and K only; V is left untouched.
        Q = apply_rotary_emb(Q, dim=self.d_k, base=self.rope_base)
        K = apply_rotary_emb(K, dim=self.d_k, base=self.rope_base)
        sdpa_mask = None
        if attn_mask is not None:
            # [B, T] -> [B, 1, 1, T]: the same key mask for every head and
            # every query position.
            sdpa_mask = attn_mask[:, None, None, :].bool()
        out = F.scaled_dot_product_attention(
            Q, K, V, attn_mask=sdpa_mask,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=False,
        )
        # Merge heads back: [B, n_heads, T, d_k] -> [B, T, d_model]
        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.out_proj(out)
# ---------------------------------------------------------------------------
# FiLM-conditioned encoder block
# ---------------------------------------------------------------------------
class FiLMEncoderBlock(nn.Module):
    """Pre-norm encoder block whose feed-forward branch is FiLM-modulated.

    The attention branch is a plain residual. The feed-forward output is
    scaled and shifted before being added back:
        h = (1 + gamma) * h + beta
    where gamma and beta come from a linear projection of the difficulty
    embedding. That projection is zero-initialised, so at initialisation
    gamma = beta = 0 and the modulation is the identity.
    """

    def __init__(self, d_model: int, d_ff: int, n_heads: int,
                 dropout: float = 0.1, rope_base: float = 10000.0):
        """Create the block.

        Args:
            d_model: Model width.
            d_ff: Feed-forward width passed to ``FeedForward``.
            n_heads: Number of attention heads.
            dropout: Dropout probability shared by the attention layer,
                the feed-forward layer, and both residual branches.
            rope_base: RoPE frequency base passed to the attention layer.
        """
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = BidirectionalAttention(d_model, n_heads, dropout, rope_base)
        self.norm2 = RMSNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)
        # One projection yields both gamma and beta (hence 2 * d_model).
        self.film_proj = nn.Linear(d_model, d_model * 2)
        for param in (self.film_proj.weight, self.film_proj.bias):
            nn.init.zeros_(param)

    def forward(self, x: torch.Tensor, diff_emb: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the attention and FiLM-modulated feed-forward branches.

        Args:
            x: Hidden states; presumably ``[B, T, d_model]``.
            diff_emb: Conditioning embedding with last dimension
                ``d_model``; a new axis is inserted at position 1 so it
                broadcasts against ``x`` (presumably ``[B, d_model]``).
            attn_mask: Optional mask forwarded to the attention layer.

        Returns:
            Updated hidden states with the same shape as ``x``.
        """
        # Attention residual branch.
        attended = self.attn(self.norm1(x), attn_mask)
        x = x + self.dropout(attended)

        # Feed-forward residual branch with FiLM scale/shift.
        ff_out = self.ff(self.norm2(x))
        gamma, beta = self.film_proj(diff_emb).unsqueeze(1).chunk(2, dim=-1)
        modulated = (1 + gamma) * ff_out + beta
        return x + self.dropout(modulated)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Token ids 0-31 are fret combinations; the two special tokens follow them.
SILENCE_TOKEN = 32
MASK_TOKEN = SILENCE_TOKEN + 1
# Embedding table size: every fret combination plus silence and MASK.
VOCAB_SIZE = MASK_TOKEN + 1
# Number of output classes of the duration head.
NUM_SUSTAIN_BUCKETS = 6
# ---------------------------------------------------------------------------
# Main model
# ---------------------------------------------------------------------------
class ChartMaskPredictor(nn.Module):
    """Masked prediction chart model (v3).

    Token vocabulary: 0-31 fret combos, 32 silence, 33 MASK.

    Projected audio features and chart-token embeddings are summed, passed
    through a stack of difficulty-conditioned ``FiLMEncoderBlock`` layers,
    normalised, and fed to three linear heads (token, sustain, duration).
    """

    def __init__(self, config: "ChartMaskPredictorConfig"):
        """Build the model from ``config``."""
        super().__init__()
        self.config = config
        width = config.d_model

        # Input side: audio features and chart tokens share one width.
        self.audio_projection = nn.Linear(config.audio_dim, width, bias=False)
        self.chart_embedding = nn.Embedding(VOCAB_SIZE, width)
        self.input_dropout = nn.Dropout(config.dropout)
        # Four difficulty levels, indexed 0-3.
        self.difficulty_embedding = nn.Embedding(4, width)

        self.layers = nn.ModuleList(
            FiLMEncoderBlock(
                d_model=width,
                d_ff=config.d_ff,
                n_heads=config.n_heads,
                dropout=config.dropout,
                rope_base=config.rope_base,
            )
            for _ in range(config.n_layers)
        )
        self.final_norm = RMSNorm(width)

        # Output heads. MASK is input-only, so it is not a predicted class.
        self.token_head = nn.Linear(width, VOCAB_SIZE - 1)  # 33 classes (no MASK)
        self.sustain_head = nn.Linear(width, 1)
        self.duration_head = nn.Linear(width, NUM_SUSTAIN_BUCKETS)

    def forward(self, audio_features: torch.Tensor, chart_tokens: torch.Tensor,
                difficulty: torch.Tensor,
                padding_mask: Optional[torch.Tensor] = None) -> dict[str, torch.Tensor]:
        """Predict per-position token, sustain and duration logits.

        Args:
            audio_features: Float features whose last dimension is
                ``config.audio_dim``; presumably ``[B, T, audio_dim]``.
            chart_tokens: Integer token ids below ``VOCAB_SIZE``;
                presumably ``[B, T]``.
            difficulty: Integer difficulty ids in ``0..3``; presumably
                ``[B]``.
            padding_mask: Optional mask handed to every layer as
                ``attn_mask``.

        Returns:
            Dict with keys ``"token_logits"``, ``"sustain_logits"`` and
            ``"duration_logits"``, each the output of the matching head on
            the final hidden states.
        """
        embedded = (
            self.audio_projection(audio_features)
            + self.chart_embedding(chart_tokens)
        )
        hidden = self.input_dropout(embedded)

        diff_emb = self.difficulty_embedding(difficulty)
        for block in self.layers:
            hidden = block(hidden, diff_emb, attn_mask=padding_mask)
        hidden = self.final_norm(hidden)

        heads = (
            ("token_logits", self.token_head),
            ("sustain_logits", self.sustain_head),
            ("duration_logits", self.duration_head),
        )
        return {name: head(hidden) for name, head in heads}
@dataclass
class ChartMaskPredictorConfig:
    """Hyper-parameters for ``ChartMaskPredictor``.

    Raises:
        ValueError: On construction, if the sizes violate the constraints
            the model code relies on (see ``__post_init__``).
    """

    # Size of the last dimension of the audio feature input.
    audio_dim: int = 771
    # Model width shared by all layers.
    d_model: int = 512
    # Attention heads per layer; must divide d_model.
    n_heads: int = 8
    # Number of encoder blocks.
    n_layers: int = 6
    # Feed-forward width before gating; the gated activation halves it.
    d_ff: int = 2048
    # Dropout probability used throughout the model.
    dropout: float = 0.15
    # Frequency base for rotary position embeddings.
    rope_base: float = 10000.0

    def __post_init__(self) -> None:
        """Reject size combinations that would only fail inside forward()."""
        for field_name in ("d_model", "n_heads", "d_ff"):
            value = getattr(self, field_name)
            if value <= 0:
                raise ValueError(f"{field_name} must be positive, got {value}")
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )
        # RoPE rotates the two halves of each head against each other.
        if (self.d_model // self.n_heads) % 2 != 0:
            raise ValueError(
                f"head dimension ({self.d_model // self.n_heads}) must be even"
            )
        # The gated activation pairs even and odd feed-forward channels.
        if self.d_ff % 2 != 0:
            raise ValueError(f"d_ff must be even, got {self.d_ff}")