# cond_gen/condition_encoders.py
import math
import torch
import torch.nn as nn
from typing import Optional
from vocab import NUM_CHORD_CLASSES, NUM_STRUCTURE_CLASSES
class EncoderBlock(nn.Module):
    """Stack of bidirectional (non-causal) Transformer encoder layers.

    Pre-norm ("norm_first") layers with GELU activation and a 4x
    feed-forward expansion, operating on batch-first tensors.
    """

    def __init__(
        self, d_model: int, n_layers: int = 2, n_heads: int = 8, dropout: float = 0.0
    ):
        super().__init__()
        base_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.enc = nn.TransformerEncoder(base_layer, num_layers=n_layers)

    def forward(
        self, x: torch.Tensor, pad_mask: Optional[torch.BoolTensor] = None
    ) -> torch.Tensor:
        """Encode x [B,T,D]; pad_mask [B,T] marks PAD positions with True."""
        encoded = self.enc(x, src_key_padding_mask=pad_mask)
        return encoded
class ConditionEncoder(nn.Module):
    """
    Condition encoder for AdaLN-Zero.

    Inputs (two aligned id sequences):
        chord_ids:     [B, T] integer chord class ids (0 = PAD)
        structure_ids: [B, T] integer structure class ids (0 = PAD)
    Output:
        cond_expanded: [B, T, hidden_size] (feed into AdaLN layers)

    What it encodes per token:
        - token-level: chord/segment content at time t
        - position: sinusoidal encoding counted over valid tokens only
        - segment context: via structure ids + bidirectional transformer mixing

    Notes:
        - Non-causal (sees future): good for "guidance" conditions.
        - Compute once per sample at generation start; slice per step.
    """

    def __init__(
        self,
        hidden_size: int,
        chord_embed_dim: int = 512,
        structure_embed_dim: int = 512,
        n_layers: int = 2,
        n_heads: int = 8,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        # padding_idx=0 keeps the PAD embedding fixed at zero (not trained).
        self.chord_embedding = nn.Embedding(
            NUM_CHORD_CLASSES, chord_embed_dim, padding_idx=0
        )
        self.structure_embedding = nn.Embedding(
            NUM_STRUCTURE_CLASSES, structure_embed_dim, padding_idx=0
        )
        self.cond_dim = chord_embed_dim + structure_embed_dim
        self.cond_proj = nn.Linear(self.cond_dim, hidden_size)
        # Small bidirectional transformer for cross-token mixing.
        self.encoder = EncoderBlock(
            d_model=hidden_size, n_layers=n_layers, n_heads=n_heads, dropout=dropout
        )
        self.proj_out = nn.Linear(hidden_size, hidden_size)

    @staticmethod
    def _sincos_pos(
        positions: torch.Tensor, dim: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Standard sinusoidal positional encoding.

        positions: [B, T], absolute positions (0..T-1)
        returns:   [B, T, dim]; if dim is odd the last channel stays zero.
        raises:    ValueError if dim <= 0.
        """
        if dim <= 0:
            raise ValueError("dim must be > 0 for positional encoding.")
        half = dim // 2
        if half == 0:
            # dim == 1: not even one sin/cos pair fits; return zeros.
            return torch.zeros(
                positions.size(0),
                positions.size(1),
                dim,
                device=positions.device,
                dtype=dtype,
            )
        pos = positions.to(dtype=torch.float32)
        # Geometric frequency ladder: 10000^(-i/half) for i in [0, half).
        freqs = torch.exp(
            -math.log(10000.0)
            * torch.arange(half, device=positions.device, dtype=torch.float32)
            / half
        )
        angles = pos.unsqueeze(-1) * freqs  # [B, T, half]
        enc = torch.zeros(
            positions.size(0),
            positions.size(1),
            dim,
            device=positions.device,
            dtype=torch.float32,
        )
        # Interleave: even channels get sin, odd channels get cos.
        enc[..., 0 : 2 * half : 2] = torch.sin(angles)
        enc[..., 1 : 2 * half : 2] = torch.cos(angles)
        return enc.to(dtype=dtype)

    def forward(
        self,
        chord_ids: torch.Tensor,  # [B, T]
        structure_ids: torch.Tensor,  # [B, T]
    ) -> torch.Tensor:
        """Return cond_expanded: [B, T, hidden_size]."""
        chord_emb = self.chord_embedding(chord_ids)  # [B, T, chord_dim]
        structure_emb = self.structure_embedding(structure_ids)  # [B, T, struct_dim]
        cond = torch.cat([chord_emb, structure_emb], dim=-1)
        cond = self.cond_proj(cond)
        # A token is attendable if either stream carries a real (non-PAD) id.
        valid_tokens = chord_ids.ne(0) | structure_ids.ne(0)
        pad_mask = ~valid_tokens  # True means PAD (masked) for the encoder
        # FIX: a fully-padded row yields an all-True key-padding mask, which
        # makes attention softmax over an empty key set and produces NaNs
        # that poison the whole output. Unmask such rows: their outputs are
        # meaningless but finite, and callers ignore padded positions anyway.
        fully_padded = pad_mask.all(dim=1, keepdim=True)
        pad_mask = pad_mask & ~fully_padded
        # Position ids run contiguously over valid tokens only; PADs pin to 0.
        pos = valid_tokens.to(torch.long).cumsum(dim=1) - 1
        pos = torch.where(valid_tokens, pos, torch.zeros_like(pos))
        pos_enc = self._sincos_pos(pos, self.hidden_size, cond.dtype)
        valid_mask = valid_tokens.unsqueeze(-1)
        # Zero the positional signal on PAD tokens so they carry no position.
        cond = cond + pos_enc * valid_mask.to(dtype=cond.dtype)
        encoded = self.encoder(cond, pad_mask=pad_mask)
        # [B, T, hidden_size]
        return self.proj_out(encoded)