| import math |
| import torch |
| import torch.nn as nn |
| from typing import Optional |
|
|
| from vocab import NUM_CHORD_CLASSES, NUM_STRUCTURE_CLASSES |
|
|
|
|
class EncoderBlock(nn.Module):
    """Bidirectional (non-causal) Transformer encoder stack.

    Thin wrapper around ``nn.TransformerEncoder`` configured with
    pre-norm, GELU, batch-first tensors, and a feed-forward width of
    ``4 * d_model``.
    """

    def __init__(
        self, d_model: int, n_layers: int = 2, n_heads: int = 8, dropout: float = 0.0
    ):
        super().__init__()
        # Prototype layer; nn.TransformerEncoder deep-copies it n_layers times.
        prototype = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,  # pre-norm: LayerNorm before attention/FFN
        )
        self.enc = nn.TransformerEncoder(prototype, num_layers=n_layers)

    def forward(
        self, x: torch.Tensor, pad_mask: Optional[torch.BoolTensor] = None
    ) -> torch.Tensor:
        """Encode ``x`` of shape [B, T, d_model].

        ``pad_mask`` is an optional [B, T] bool mask where True marks
        padding positions to be excluded from attention.
        """
        return self.enc(x, src_key_padding_mask=pad_mask)
|
|
|
|
class ConditionEncoder(nn.Module):
    """
    Condition encoder for AdaLN-Zero conditioning.

    Inputs: two aligned id sequences
        chord_ids:     [B, T] (0 = padding)
        structure_ids: [B, T] (0 = padding)

    Output:
        cond_expanded: [B, T, hidden_size] (feed into AdaLN layers)

    Per-token encoding combines:
        - content: chord + structure embeddings, concatenated and projected
        - position: sinusoidal encoding over valid (non-pad) positions only
        - context: bidirectional transformer mixing across the sequence

    Notes:
        - Non-causal (sees future): suitable for "guidance" conditions.
        - Compute once per sample at generation start; slice per step.
    """

    def __init__(
        self,
        hidden_size: int,
        chord_embed_dim: int = 512,
        structure_embed_dim: int = 512,
        n_layers: int = 2,
        n_heads: int = 8,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.hidden_size = hidden_size

        # Index 0 is the padding id for both vocabularies; its embedding
        # row is frozen at zero.
        self.chord_embedding = nn.Embedding(
            NUM_CHORD_CLASSES, chord_embed_dim, padding_idx=0
        )
        self.structure_embedding = nn.Embedding(
            NUM_STRUCTURE_CLASSES, structure_embed_dim, padding_idx=0
        )

        # Concatenated embedding width, projected down/up to hidden_size.
        self.cond_dim = chord_embed_dim + structure_embed_dim
        self.cond_proj = nn.Linear(self.cond_dim, hidden_size)

        self.encoder = EncoderBlock(
            d_model=hidden_size, n_layers=n_layers, n_heads=n_heads, dropout=dropout
        )
        self.proj_out = nn.Linear(hidden_size, hidden_size)

    @staticmethod
    def _sincos_pos(
        positions: torch.Tensor, dim: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """Sinusoidal positional encoding.

        Args:
            positions: [B, T] integer positions (0..T-1).
            dim: encoding width; for odd ``dim`` the last channel stays 0.
            dtype: dtype of the returned tensor.

        Returns:
            [B, T, dim] encoding with sin in even channels, cos in odd.

        Raises:
            ValueError: if ``dim`` is not positive.
        """
        if dim <= 0:
            raise ValueError("dim must be > 0 for positional encoding.")

        batch, seq_len = positions.shape
        half = dim // 2

        # Build in float32 for numerical stability, cast at the end.
        enc = torch.zeros(
            batch, seq_len, dim, device=positions.device, dtype=torch.float32
        )
        if half > 0:
            # Geometric frequency ladder: 10000^(-k/half) for k in [0, half).
            inv_freq = torch.exp(
                torch.arange(half, device=positions.device, dtype=torch.float32)
                * (-math.log(10000.0) / half)
            )
            phase = positions.to(dtype=torch.float32).unsqueeze(-1) * inv_freq
            enc[..., 0 : 2 * half : 2] = torch.sin(phase)
            enc[..., 1 : 2 * half : 2] = torch.cos(phase)

        return enc.to(dtype=dtype)

    def forward(
        self,
        chord_ids: torch.Tensor,
        structure_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Encode aligned chord/structure id sequences.

        Args:
            chord_ids: [B, T] chord class ids, 0 = padding.
            structure_ids: [B, T] structure class ids, 0 = padding.

        Returns:
            [B, T, hidden_size] condition sequence; pad positions are
            masked out of attention and carry no positional term.
        """
        # Token content: concatenated embeddings projected to hidden_size.
        content = torch.cat(
            [self.chord_embedding(chord_ids), self.structure_embedding(structure_ids)],
            dim=-1,
        )
        hidden = self.cond_proj(content)

        # A token counts as valid if either stream has a non-pad id.
        is_valid = (chord_ids != 0) | (structure_ids != 0)

        # Positions run 0..n-1 over valid tokens; pads are pinned to 0
        # (their positional term is zeroed below anyway).
        pos_idx = is_valid.long().cumsum(dim=1) - 1
        pos_idx = pos_idx.masked_fill(~is_valid, 0)

        pos_term = self._sincos_pos(pos_idx, self.hidden_size, hidden.dtype)
        hidden = hidden + pos_term * is_valid.unsqueeze(-1).to(dtype=hidden.dtype)

        # Bidirectional mixing; pads excluded via key-padding mask.
        mixed = self.encoder(hidden, pad_mask=~is_valid)
        return self.proj_out(mixed)
|
|