import math
from typing import Optional

import torch
import torch.nn as nn

from vocab import NUM_CHORD_CLASSES, NUM_STRUCTURE_CLASSES


class EncoderBlock(nn.Module):
    """Thin wrapper around a bidirectional (non-causal) Transformer encoder stack.

    Uses pre-norm ("norm_first") layers with GELU activation and a 4x
    feed-forward expansion, operating on batch-first tensors.
    """

    def __init__(
        self,
        d_model: int,
        n_layers: int = 2,
        n_heads: int = 8,
        dropout: float = 0.0,
    ):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        # nn.TransformerEncoder deep-copies `layer`, so the stacked layers
        # do not share parameters.
        self.enc = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(
        self, x: torch.Tensor, pad_mask: Optional[torch.BoolTensor] = None
    ) -> torch.Tensor:
        """Encode `x` ([B, T, d_model]).

        pad_mask: [B, T] boolean; True marks PAD positions that must not be
        attended to. NOTE: a row that is True everywhere would yield NaNs
        (softmax over an empty key set) — callers must avoid that case.
        """
        return self.enc(x, src_key_padding_mask=pad_mask)


class ConditionEncoder(nn.Module):
    """
    Condition encoder for AdaLN-Zero:

    Inputs: two aligned sequences
        x_chord: [B,T,D_in]
        x_seg:   [B,T,D_in]
    Output:
        cond_expanded: [B,T,H] (feed into your AdaLN layers as cond_expanded)

    What it encodes per token:
      - token-level: chord/segment content at time t
      - position: global position (always) + optional segment-relative position
      - segment context: via x_seg + bidirectional transformer mixing

    Notes:
      - Non-causal (sees future): good for "guidance" conditions.
      - Compute once per sample at generation start; slice per step.
    """

    def __init__(
        self,
        hidden_size: int,
        chord_embed_dim: int = 512,
        structure_embed_dim: int = 512,
        n_layers: int = 2,
        n_heads: int = 8,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        # Index 0 is the PAD id for both vocabularies; padding_idx pins its
        # embedding to an all-zero vector that never receives gradient.
        self.chord_embedding = nn.Embedding(
            NUM_CHORD_CLASSES, chord_embed_dim, padding_idx=0
        )
        self.structure_embedding = nn.Embedding(
            NUM_STRUCTURE_CLASSES, structure_embed_dim, padding_idx=0
        )
        self.cond_dim = chord_embed_dim + structure_embed_dim
        self.cond_proj = nn.Linear(self.cond_dim, hidden_size)
        # Small bidirectional transformer that mixes context across time.
        self.encoder = EncoderBlock(
            d_model=hidden_size, n_layers=n_layers, n_heads=n_heads, dropout=dropout
        )
        self.proj_out = nn.Linear(hidden_size, hidden_size)

    @staticmethod
    def _sincos_pos(
        positions: torch.Tensor, dim: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        positions: [B, T], absolute positions (0..T-1)
        returns:   [B, T, dim] sinusoidal positional encoding

        Raises ValueError if dim <= 0. For dim == 1 (half == 0) a zero
        tensor is returned, since no sin/cos pair fits. For odd dim > 1
        the final channel stays zero.
        """
        if dim <= 0:
            raise ValueError("dim must be > 0 for positional encoding.")
        half = dim // 2
        if half == 0:
            return torch.zeros(
                positions.size(0),
                positions.size(1),
                dim,
                device=positions.device,
                dtype=dtype,
            )
        # Compute in float32 for numerical stability, cast at the end.
        pos = positions.to(dtype=torch.float32)
        freqs = torch.exp(
            -math.log(10000.0)
            * torch.arange(half, device=positions.device, dtype=torch.float32)
            / half
        )
        angles = pos.unsqueeze(-1) * freqs  # [B, T, half]
        enc = torch.zeros(
            positions.size(0),
            positions.size(1),
            dim,
            device=positions.device,
            dtype=torch.float32,
        )
        # Interleave: even channels = sin, odd channels = cos.
        enc[..., 0 : 2 * half : 2] = torch.sin(angles)
        enc[..., 1 : 2 * half : 2] = torch.cos(angles)
        return enc.to(dtype=dtype)

    def forward(
        self,
        chord_ids: torch.Tensor,  # [B, T], 0 == PAD
        structure_ids: torch.Tensor,  # [B, T], 0 == PAD
    ) -> torch.Tensor:
        """Embed, position-encode, and contextually mix the condition ids.

        Returns: [B, T, hidden_size] condition features.
        """
        chord_emb = self.chord_embedding(chord_ids)  # [B, T, chord_dim]
        structure_emb = self.structure_embedding(structure_ids)  # [B, T, struct_dim]
        cond = torch.cat([chord_emb, structure_emb], dim=-1)
        cond = self.cond_proj(cond)

        # Encoder attention mask is computed separately from condition content.
        # True means this token can be attended by the condition encoder.
        valid_tokens = chord_ids.ne(0) | structure_ids.ne(0)
        pad_mask = ~valid_tokens

        # Guard against rows that are entirely padding: with every key masked
        # the attention softmax would produce NaNs. Unmask such rows fully —
        # their content embeddings are all zeros (padding_idx=0) and their
        # positional encoding is zeroed below, so the output stays benign,
        # while rows with at least one valid token are unaffected.
        pad_mask = pad_mask & ~pad_mask.all(dim=1, keepdim=True)

        # Position ids are contiguous only on valid condition-id tokens.
        pos = valid_tokens.to(torch.long).cumsum(dim=1) - 1
        pos = torch.where(valid_tokens, pos, torch.zeros_like(pos))
        pos_enc = self._sincos_pos(pos, self.hidden_size, cond.dtype)

        # Add positional encoding only at valid positions; PAD stays zero.
        valid_mask = valid_tokens.unsqueeze(-1)
        cond = cond + pos_enc * valid_mask.to(dtype=cond.dtype)

        encoded = self.encoder(cond, pad_mask=pad_mask)  # [B, T, hidden_size]
        return self.proj_out(encoded)